// Copyright 2016 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "media/midi/midi_manager_winrt.h" #pragma warning(disable : 4467) #include // Required by #include #include #include #include #include #include #include #include #include #include #include #include "base/bind.h" #include "base/scoped_generic.h" #include "base/single_thread_task_runner.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" #include "base/threading/thread_checker.h" #include "base/threading/thread_task_runner_handle.h" #include "base/timer/timer.h" #include "base/win/scoped_comptr.h" #include "media/midi/midi_scheduler.h" namespace midi { namespace { namespace WRL = Microsoft::WRL; using namespace ABI::Windows::Devices::Enumeration; using namespace ABI::Windows::Devices::Midi; using namespace ABI::Windows::Foundation; using namespace ABI::Windows::Storage::Streams; using base::win::ScopedComPtr; using mojom::PortState; using mojom::Result; // Helpers for printing HRESULTs. struct PrintHr { PrintHr(HRESULT hr) : hr(hr) {} HRESULT hr; }; std::ostream& operator<<(std::ostream& os, const PrintHr& phr) { std::ios_base::fmtflags ff = os.flags(); os << _com_error(phr.hr).ErrorMessage() << " (0x" << std::hex << std::uppercase << std::setfill('0') << std::setw(8) << phr.hr << ")"; os.flags(ff); return os; } // Provides access to functions in combase.dll which may not be available on // Windows 7. Loads functions dynamically at runtime to prevent library // dependencies. class CombaseFunctions { public: CombaseFunctions() = default; ~CombaseFunctions() { if (combase_dll_) ::FreeLibrary(combase_dll_); } bool LoadFunctions() { combase_dll_ = ::LoadLibrary(L"combase.dll"); if (!combase_dll_) return false; get_factory_func_ = reinterpret_cast( ::GetProcAddress(combase_dll_, "RoGetActivationFactory")); if (!get_factory_func_) return false; create_string_func_ = reinterpret_cast( ::GetProcAddress(combase_dll_, "WindowsCreateString")); if (!create_string_func_) return false; delete_string_func_ = reinterpret_cast( ::GetProcAddress(combase_dll_, "WindowsDeleteString")); if (!delete_string_func_) return false; get_string_raw_buffer_func_ = reinterpret_cast( ::GetProcAddress(combase_dll_, "WindowsGetStringRawBuffer")); if (!get_string_raw_buffer_func_) return false; return true; } HRESULT RoGetActivationFactory(HSTRING class_id, const IID& iid, void** out_factory) { DCHECK(get_factory_func_); return get_factory_func_(class_id, iid, out_factory); } HRESULT WindowsCreateString(const base::char16* src, uint32_t len, HSTRING* out_hstr) { DCHECK(create_string_func_); return create_string_func_(src, len, out_hstr); } HRESULT WindowsDeleteString(HSTRING hstr) { DCHECK(delete_string_func_); return delete_string_func_(hstr); } const base::char16* WindowsGetStringRawBuffer(HSTRING hstr, uint32_t* out_len) { DCHECK(get_string_raw_buffer_func_); return get_string_raw_buffer_func_(hstr, out_len); } private: HMODULE combase_dll_ = nullptr; decltype(&::RoGetActivationFactory) get_factory_func_ = nullptr; decltype(&::WindowsCreateString) create_string_func_ = nullptr; decltype(&::WindowsDeleteString) delete_string_func_ = nullptr; decltype(&::WindowsGetStringRawBuffer) get_string_raw_buffer_func_ = nullptr; }; CombaseFunctions* GetCombaseFunctions() { static CombaseFunctions* functions = new CombaseFunctions(); return functions; } // Scoped HSTRING class to maintain lifetime of HSTRINGs allocated with // WindowsCreateString(). class ScopedHStringTraits { public: static HSTRING InvalidValue() { return nullptr; } static void Free(HSTRING hstr) { GetCombaseFunctions()->WindowsDeleteString(hstr); } }; class ScopedHString : public base::ScopedGeneric { public: explicit ScopedHString(const base::char16* str) : ScopedGeneric(nullptr) { HSTRING hstr; HRESULT hr = GetCombaseFunctions()->WindowsCreateString( str, static_cast(wcslen(str)), &hstr); if (FAILED(hr)) VLOG(1) << "WindowsCreateString failed: " << PrintHr(hr); else reset(hstr); } }; // Factory functions that activate and create WinRT components. The caller takes // ownership of the returning ComPtr. template ScopedComPtr WrlStaticsFactory() { ScopedComPtr com_ptr; ScopedHString class_id_hstring(runtime_class_id); if (!class_id_hstring.is_valid()) { com_ptr = nullptr; return com_ptr; } HRESULT hr = GetCombaseFunctions()->RoGetActivationFactory( class_id_hstring.get(), IID_PPV_ARGS(&com_ptr)); if (FAILED(hr)) { VLOG(1) << "RoGetActivationFactory failed: " << PrintHr(hr); com_ptr = nullptr; } return com_ptr; } std::string HStringToString(HSTRING hstr) { // Note: empty HSTRINGs are represent as nullptr, and instantiating // std::string with nullptr (in base::WideToUTF8) is undefined behavior. const base::char16* buffer = GetCombaseFunctions()->WindowsGetStringRawBuffer(hstr, nullptr); if (buffer) return base::WideToUTF8(buffer); return std::string(); } template std::string GetIdString(T* obj) { HSTRING result; HRESULT hr = obj->get_Id(&result); if (FAILED(hr)) { VLOG(1) << "get_Id failed: " << PrintHr(hr); return std::string(); } return HStringToString(result); } template std::string GetDeviceIdString(T* obj) { HSTRING result; HRESULT hr = obj->get_DeviceId(&result); if (FAILED(hr)) { VLOG(1) << "get_DeviceId failed: " << PrintHr(hr); return std::string(); } return HStringToString(result); } std::string GetNameString(IDeviceInformation* info) { HSTRING result; HRESULT hr = info->get_Name(&result); if (FAILED(hr)) { VLOG(1) << "get_Name failed: " << PrintHr(hr); return std::string(); } return HStringToString(result); } HRESULT GetPointerToBufferData(IBuffer* buffer, uint8_t** out) { ScopedComPtr buffer_byte_access; HRESULT hr = buffer->QueryInterface(IID_PPV_ARGS(&buffer_byte_access)); if (FAILED(hr)) { VLOG(1) << "QueryInterface failed: " << PrintHr(hr); return hr; } // Lifetime of the pointing buffer is controlled by the buffer object. hr = buffer_byte_access->Buffer(out); if (FAILED(hr)) { VLOG(1) << "Buffer failed: " << PrintHr(hr); return hr; } return S_OK; } // Checks if given DeviceInformation represent a Microsoft GS Wavetable Synth // instance. bool IsMicrosoftSynthesizer(IDeviceInformation* info) { auto midi_synthesizer_statics = WrlStaticsFactory(); boolean result = FALSE; HRESULT hr = midi_synthesizer_statics->IsSynthesizer(info, &result); VLOG_IF(1, FAILED(hr)) << "IsSynthesizer failed: " << PrintHr(hr); return result != FALSE; } void GetDevPropString(DEVINST handle, const DEVPROPKEY* devprop_key, std::string* out) { DEVPROPTYPE devprop_type; unsigned long buffer_size = 0; // Retrieve |buffer_size| and allocate buffer later for receiving data. CONFIGRET cr = CM_Get_DevNode_Property(handle, devprop_key, &devprop_type, nullptr, &buffer_size, 0); if (cr != CR_BUFFER_SMALL) { // Here we print error codes in hex instead of using PrintHr() with // HRESULT_FROM_WIN32() and CM_MapCrToWin32Err(), since only a minor set of // CONFIGRET values are mapped to Win32 errors. Same for following VLOG()s. VLOG(1) << "CM_Get_DevNode_Property failed: CONFIGRET 0x" << std::hex << cr; return; } if (devprop_type != DEVPROP_TYPE_STRING) { VLOG(1) << "CM_Get_DevNode_Property returns wrong data type, " << "expected DEVPROP_TYPE_STRING"; return; } std::unique_ptr buffer(new uint8_t[buffer_size]); // Receive property data. cr = CM_Get_DevNode_Property(handle, devprop_key, &devprop_type, buffer.get(), &buffer_size, 0); if (cr != CR_SUCCESS) VLOG(1) << "CM_Get_DevNode_Property failed: CONFIGRET 0x" << std::hex << cr; else *out = base::WideToUTF8(reinterpret_cast(buffer.get())); } // Retrieves manufacturer (provider) and version information of underlying // device driver through PnP Configuration Manager, given device (interface) ID // provided by WinRT. |out_manufacturer| and |out_driver_version| won't be // modified if retrieval fails. // // Device instance ID is extracted from device (interface) ID provided by WinRT // APIs, for example from the following interface ID: // \\?\SWD#MMDEVAPI#MIDII_60F39FCA.P_0002#{504be32c-ccf6-4d2c-b73f-6f8b3747e22b} // we extract the device instance ID: SWD\MMDEVAPI\MIDII_60F39FCA.P_0002 // // However the extracted device instance ID represent a "software device" // provided by Microsoft, which is an interface on top of the hardware for each // input/output port. Therefore we further locate its parent device, which is // the actual hardware device, for driver information. void GetDriverInfoFromDeviceId(const std::string& dev_id, std::string* out_manufacturer, std::string* out_driver_version) { base::string16 dev_instance_id = base::UTF8ToWide(dev_id.substr(4, dev_id.size() - 43)); base::ReplaceChars(dev_instance_id, L"#", L"\\", &dev_instance_id); DEVINST dev_instance_handle; CONFIGRET cr = CM_Locate_DevNode(&dev_instance_handle, &dev_instance_id[0], CM_LOCATE_DEVNODE_NORMAL); if (cr != CR_SUCCESS) { VLOG(1) << "CM_Locate_DevNode failed: CONFIGRET 0x" << std::hex << cr; return; } DEVINST parent_handle; cr = CM_Get_Parent(&parent_handle, dev_instance_handle, 0); if (cr != CR_SUCCESS) { VLOG(1) << "CM_Get_Parent failed: CONFIGRET 0x" << std::hex << cr; return; } GetDevPropString(parent_handle, &DEVPKEY_Device_DriverProvider, out_manufacturer); GetDevPropString(parent_handle, &DEVPKEY_Device_DriverVersion, out_driver_version); } // Tokens with value = 0 are considered invalid (as in ). const int64_t kInvalidTokenValue = 0; template struct MidiPort { MidiPort() = default; uint32_t index; ScopedComPtr handle; EventRegistrationToken token_MessageReceived; private: DISALLOW_COPY_AND_ASSIGN(MidiPort); }; } // namespace template class MidiManagerWinrt::MidiPortManager { public: // MidiPortManager instances should be constructed on the COM thread. MidiPortManager(MidiManagerWinrt* midi_manager) : midi_manager_(midi_manager), task_runner_(base::ThreadTaskRunnerHandle::Get()) {} virtual ~MidiPortManager() { DCHECK(thread_checker_.CalledOnValidThread()); } bool StartWatcher() { DCHECK(thread_checker_.CalledOnValidThread()); HRESULT hr; midi_port_statics_ = WrlStaticsFactory(); if (!midi_port_statics_) return false; HSTRING device_selector = nullptr; hr = midi_port_statics_->GetDeviceSelector(&device_selector); if (FAILED(hr)) { VLOG(1) << "GetDeviceSelector failed: " << PrintHr(hr); return false; } auto dev_info_statics = WrlStaticsFactory< IDeviceInformationStatics, RuntimeClass_Windows_Devices_Enumeration_DeviceInformation>(); if (!dev_info_statics) return false; hr = dev_info_statics->CreateWatcherAqsFilter(device_selector, watcher_.GetAddressOf()); if (FAILED(hr)) { VLOG(1) << "CreateWatcherAqsFilter failed: " << PrintHr(hr); return false; } // Register callbacks to WinRT that post state-modifying jobs back to COM // thread. |weak_ptr| and |task_runner| are captured by lambda callbacks for // posting jobs. Note that WinRT callback arguments should not be passed // outside the callback since the pointers may be unavailable afterwards. base::WeakPtr weak_ptr = GetWeakPtrFromFactory(); scoped_refptr task_runner = task_runner_; hr = watcher_->add_Added( WRL::Callback>( [weak_ptr, task_runner](IDeviceWatcher* watcher, IDeviceInformation* info) { if (!info) { VLOG(1) << "DeviceWatcher.Added callback provides null " "pointer, ignoring"; return S_OK; } // Disable Microsoft GS Wavetable Synth due to security reasons. // http://crbug.com/499279 if (IsMicrosoftSynthesizer(info)) return S_OK; std::string dev_id = GetIdString(info), dev_name = GetNameString(info); task_runner->PostTask( FROM_HERE, base::Bind(&MidiPortManager::OnAdded, weak_ptr, dev_id, dev_name)); return S_OK; }) .Get(), &token_Added_); if (FAILED(hr)) { VLOG(1) << "add_Added failed: " << PrintHr(hr); return false; } hr = watcher_->add_EnumerationCompleted( WRL::Callback>( [weak_ptr, task_runner](IDeviceWatcher* watcher, IInspectable* insp) { task_runner->PostTask( FROM_HERE, base::Bind(&MidiPortManager::OnEnumerationCompleted, weak_ptr)); return S_OK; }) .Get(), &token_EnumerationCompleted_); if (FAILED(hr)) { VLOG(1) << "add_EnumerationCompleted failed: " << PrintHr(hr); return false; } hr = watcher_->add_Removed( WRL::Callback< ITypedEventHandler>( [weak_ptr, task_runner](IDeviceWatcher* watcher, IDeviceInformationUpdate* update) { if (!update) { VLOG(1) << "DeviceWatcher.Removed callback provides null " "pointer, ignoring"; return S_OK; } std::string dev_id = GetIdString(update); task_runner->PostTask( FROM_HERE, base::Bind(&MidiPortManager::OnRemoved, weak_ptr, dev_id)); return S_OK; }) .Get(), &token_Removed_); if (FAILED(hr)) { VLOG(1) << "add_Removed failed: " << PrintHr(hr); return false; } hr = watcher_->add_Stopped( WRL::Callback>( [](IDeviceWatcher* watcher, IInspectable* insp) { // Placeholder, does nothing for now. return S_OK; }) .Get(), &token_Stopped_); if (FAILED(hr)) { VLOG(1) << "add_Stopped failed: " << PrintHr(hr); return false; } hr = watcher_->add_Updated( WRL::Callback< ITypedEventHandler>( [](IDeviceWatcher* watcher, IDeviceInformationUpdate* update) { // TODO(shaochuan): Check for fields to be updated here. return S_OK; }) .Get(), &token_Updated_); if (FAILED(hr)) { VLOG(1) << "add_Updated failed: " << PrintHr(hr); return false; } hr = watcher_->Start(); if (FAILED(hr)) { VLOG(1) << "Start failed: " << PrintHr(hr); return false; } is_initialized_ = true; return true; } void StopWatcher() { DCHECK(thread_checker_.CalledOnValidThread()); HRESULT hr; for (const auto& entry : ports_) RemovePortEventHandlers(entry.second.get()); if (token_Added_.value != kInvalidTokenValue) { hr = watcher_->remove_Added(token_Added_); VLOG_IF(1, FAILED(hr)) << "remove_Added failed: " << PrintHr(hr); token_Added_.value = kInvalidTokenValue; } if (token_EnumerationCompleted_.value != kInvalidTokenValue) { hr = watcher_->remove_EnumerationCompleted(token_EnumerationCompleted_); VLOG_IF(1, FAILED(hr)) << "remove_EnumerationCompleted failed: " << PrintHr(hr); token_EnumerationCompleted_.value = kInvalidTokenValue; } if (token_Removed_.value != kInvalidTokenValue) { hr = watcher_->remove_Removed(token_Removed_); VLOG_IF(1, FAILED(hr)) << "remove_Removed failed: " << PrintHr(hr); token_Removed_.value = kInvalidTokenValue; } if (token_Stopped_.value != kInvalidTokenValue) { hr = watcher_->remove_Stopped(token_Stopped_); VLOG_IF(1, FAILED(hr)) << "remove_Stopped failed: " << PrintHr(hr); token_Stopped_.value = kInvalidTokenValue; } if (token_Updated_.value != kInvalidTokenValue) { hr = watcher_->remove_Updated(token_Updated_); VLOG_IF(1, FAILED(hr)) << "remove_Updated failed: " << PrintHr(hr); token_Updated_.value = kInvalidTokenValue; } if (is_initialized_) { hr = watcher_->Stop(); VLOG_IF(1, FAILED(hr)) << "Stop failed: " << PrintHr(hr); is_initialized_ = false; } } MidiPort* GetPortByDeviceId(std::string dev_id) { DCHECK(thread_checker_.CalledOnValidThread()); CHECK(is_initialized_); auto it = ports_.find(dev_id); if (it == ports_.end()) return nullptr; return it->second.get(); } MidiPort* GetPortByIndex(uint32_t port_index) { DCHECK(thread_checker_.CalledOnValidThread()); CHECK(is_initialized_); return GetPortByDeviceId(port_ids_[port_index]); } protected: // Points to the MidiManagerWinrt instance, which is expected to outlive the // MidiPortManager instance. MidiManagerWinrt* midi_manager_; // Task runner of the COM thread. scoped_refptr task_runner_; // Ensures all methods are called on the COM thread. base::ThreadChecker thread_checker_; private: // DeviceWatcher callbacks: void OnAdded(std::string dev_id, std::string dev_name) { DCHECK(thread_checker_.CalledOnValidThread()); CHECK(is_initialized_); port_names_[dev_id] = dev_name; ScopedHString dev_id_hstring(base::UTF8ToWide(dev_id).c_str()); if (!dev_id_hstring.is_valid()) return; IAsyncOperation* async_op; HRESULT hr = midi_port_statics_->FromIdAsync(dev_id_hstring.get(), &async_op); if (FAILED(hr)) { VLOG(1) << "FromIdAsync failed: " << PrintHr(hr); return; } base::WeakPtr weak_ptr = GetWeakPtrFromFactory(); scoped_refptr task_runner = task_runner_; hr = async_op->put_Completed( WRL::Callback>( [weak_ptr, task_runner](IAsyncOperation* async_op, AsyncStatus status) { // A reference to |async_op| is kept in |async_ops_|, safe to pass // outside. task_runner->PostTask( FROM_HERE, base::Bind(&MidiPortManager::OnCompletedGetPortFromIdAsync, weak_ptr, async_op)); return S_OK; }) .Get()); if (FAILED(hr)) { VLOG(1) << "put_Completed failed: " << PrintHr(hr); return; } // Keep a reference to incompleted |async_op| for releasing later. async_ops_.insert(async_op); } void OnEnumerationCompleted() { DCHECK(thread_checker_.CalledOnValidThread()); CHECK(is_initialized_); if (async_ops_.empty()) midi_manager_->OnPortManagerReady(); else enumeration_completed_not_ready_ = true; } void OnRemoved(std::string dev_id) { DCHECK(thread_checker_.CalledOnValidThread()); CHECK(is_initialized_); // Note: in case Microsoft GS Wavetable Synth triggers this event for some // reason, it will be ignored here with log emitted. MidiPort* port = GetPortByDeviceId(dev_id); if (!port) { VLOG(1) << "Removing non-existent port " << dev_id; return; } SetPortState(port->index, PortState::DISCONNECTED); RemovePortEventHandlers(port); port->handle = nullptr; } void OnCompletedGetPortFromIdAsync(IAsyncOperation* async_op) { DCHECK(thread_checker_.CalledOnValidThread()); CHECK(is_initialized_); InterfaceType* handle = nullptr; HRESULT hr = async_op->GetResults(&handle); if (FAILED(hr)) { VLOG(1) << "GetResults failed: " << PrintHr(hr); return; } // Manually release COM interface to completed |async_op|. auto it = async_ops_.find(async_op); CHECK(it != async_ops_.end()); (*it)->Release(); async_ops_.erase(it); if (!handle) { VLOG(1) << "Midi{In,Out}Port.FromIdAsync callback provides null pointer, " "ignoring"; return; } EventRegistrationToken token = {kInvalidTokenValue}; if (!RegisterOnMessageReceived(handle, &token)) return; std::string dev_id = GetDeviceIdString(handle); MidiPort* port = GetPortByDeviceId(dev_id); if (port == nullptr) { std::string manufacturer = "Unknown", driver_version = "Unknown"; GetDriverInfoFromDeviceId(dev_id, &manufacturer, &driver_version); AddPort(MidiPortInfo(dev_id, manufacturer, port_names_[dev_id], driver_version, PortState::OPENED)); port = new MidiPort; port->index = static_cast(port_ids_.size()); ports_[dev_id].reset(port); port_ids_.push_back(dev_id); } else { SetPortState(port->index, PortState::CONNECTED); } port->handle = handle; port->token_MessageReceived = token; if (enumeration_completed_not_ready_ && async_ops_.empty()) { midi_manager_->OnPortManagerReady(); enumeration_completed_not_ready_ = false; } } // Overrided by MidiInPortManager to listen to input ports. virtual bool RegisterOnMessageReceived(InterfaceType* handle, EventRegistrationToken* p_token) { return true; } // Overrided by MidiInPortManager to remove MessageReceived event handler. virtual void RemovePortEventHandlers(MidiPort* port) {} // Calls midi_manager_->Add{Input,Output}Port. virtual void AddPort(MidiPortInfo info) = 0; // Calls midi_manager_->Set{Input,Output}PortState. virtual void SetPortState(uint32_t port_index, PortState state) = 0; // WeakPtrFactory has to be declared in derived class, use this method to // retrieve upcasted WeakPtr for posting tasks. virtual base::WeakPtr GetWeakPtrFromFactory() = 0; // Midi{In,Out}PortStatics instance. ScopedComPtr midi_port_statics_; // DeviceWatcher instance and event registration tokens for unsubscribing // events in destructor. ScopedComPtr watcher_; EventRegistrationToken token_Added_ = {kInvalidTokenValue}, token_EnumerationCompleted_ = {kInvalidTokenValue}, token_Removed_ = {kInvalidTokenValue}, token_Stopped_ = {kInvalidTokenValue}, token_Updated_ = {kInvalidTokenValue}; // All manipulations to these fields should be done on COM thread. std::unordered_map>> ports_; std::vector port_ids_; std::unordered_map port_names_; // Keeps AsyncOperation references before the operation completes. Note that // raw pointers are used here and the COM interfaces should be released // manually. std::unordered_set*> async_ops_; // Set when device enumeration is completed but OnPortManagerReady() is not // called since some ports are not yet ready (i.e. |async_ops_| is not empty). // In such cases, OnPortManagerReady() will be called in // OnCompletedGetPortFromIdAsync() when the last pending port is ready. bool enumeration_completed_not_ready_ = false; // Set if the instance is initialized without error. Should be checked in all // methods on COM thread except StartWatcher(). bool is_initialized_ = false; }; class MidiManagerWinrt::MidiInPortManager final : public MidiPortManager { public: MidiInPortManager(MidiManagerWinrt* midi_manager) : MidiPortManager(midi_manager), weak_factory_(this) {} private: // MidiPortManager overrides: bool RegisterOnMessageReceived(IMidiInPort* handle, EventRegistrationToken* p_token) override { DCHECK(thread_checker_.CalledOnValidThread()); base::WeakPtr weak_ptr = weak_factory_.GetWeakPtr(); scoped_refptr task_runner = task_runner_; HRESULT hr = handle->add_MessageReceived( WRL::Callback< ITypedEventHandler>( [weak_ptr, task_runner](IMidiInPort* handle, IMidiMessageReceivedEventArgs* args) { const base::TimeTicks now = base::TimeTicks::Now(); std::string dev_id = GetDeviceIdString(handle); ScopedComPtr message; HRESULT hr = args->get_Message(message.GetAddressOf()); if (FAILED(hr)) { VLOG(1) << "get_Message failed: " << PrintHr(hr); return hr; } ScopedComPtr buffer; hr = message->get_RawData(buffer.GetAddressOf()); if (FAILED(hr)) { VLOG(1) << "get_RawData failed: " << PrintHr(hr); return hr; } uint8_t* p_buffer_data = nullptr; hr = GetPointerToBufferData(buffer.Get(), &p_buffer_data); if (FAILED(hr)) return hr; uint32_t data_length = 0; hr = buffer->get_Length(&data_length); if (FAILED(hr)) { VLOG(1) << "get_Length failed: " << PrintHr(hr); return hr; } std::vector data(p_buffer_data, p_buffer_data + data_length); task_runner->PostTask( FROM_HERE, base::Bind(&MidiInPortManager::OnMessageReceived, weak_ptr, dev_id, data, now)); return S_OK; }) .Get(), p_token); if (FAILED(hr)) { VLOG(1) << "add_MessageReceived failed: " << PrintHr(hr); return false; } return true; } void RemovePortEventHandlers(MidiPort* port) override { if (!(port->handle && port->token_MessageReceived.value != kInvalidTokenValue)) return; HRESULT hr = port->handle->remove_MessageReceived(port->token_MessageReceived); VLOG_IF(1, FAILED(hr)) << "remove_MessageReceived failed: " << PrintHr(hr); port->token_MessageReceived.value = kInvalidTokenValue; } void AddPort(MidiPortInfo info) final { midi_manager_->AddInputPort(info); } void SetPortState(uint32_t port_index, PortState state) final { midi_manager_->SetInputPortState(port_index, state); } base::WeakPtr GetWeakPtrFromFactory() final { DCHECK(thread_checker_.CalledOnValidThread()); return weak_factory_.GetWeakPtr(); } // Callback on receiving MIDI input message. void OnMessageReceived(std::string dev_id, std::vector data, base::TimeTicks time) { DCHECK(thread_checker_.CalledOnValidThread()); MidiPort* port = GetPortByDeviceId(dev_id); CHECK(port); midi_manager_->ReceiveMidiData(port->index, &data[0], data.size(), time); } // Last member to ensure destructed first. base::WeakPtrFactory weak_factory_; DISALLOW_COPY_AND_ASSIGN(MidiInPortManager); }; class MidiManagerWinrt::MidiOutPortManager final : public MidiPortManager { public: MidiOutPortManager(MidiManagerWinrt* midi_manager) : MidiPortManager(midi_manager), weak_factory_(this) {} private: // MidiPortManager overrides: void AddPort(MidiPortInfo info) final { midi_manager_->AddOutputPort(info); } void SetPortState(uint32_t port_index, PortState state) final { midi_manager_->SetOutputPortState(port_index, state); } base::WeakPtr GetWeakPtrFromFactory() final { DCHECK(thread_checker_.CalledOnValidThread()); return weak_factory_.GetWeakPtr(); } // Last member to ensure destructed first. base::WeakPtrFactory weak_factory_; DISALLOW_COPY_AND_ASSIGN(MidiOutPortManager); }; MidiManagerWinrt::MidiManagerWinrt(MidiService* service) : MidiManager(service), com_thread_("Windows MIDI COM Thread") {} MidiManagerWinrt::~MidiManagerWinrt() { base::AutoLock auto_lock(lazy_init_member_lock_); CHECK(!com_thread_checker_); CHECK(!port_manager_in_); CHECK(!port_manager_out_); CHECK(!scheduler_); } void MidiManagerWinrt::StartInitialization() { com_thread_.init_com_with_mta(true); com_thread_.Start(); com_thread_.task_runner()->PostTask( FROM_HERE, base::Bind(&MidiManagerWinrt::InitializeOnComThread, base::Unretained(this))); } void MidiManagerWinrt::Finalize() { com_thread_.task_runner()->PostTask( FROM_HERE, base::Bind(&MidiManagerWinrt::FinalizeOnComThread, base::Unretained(this))); // Blocks until FinalizeOnComThread() returns. Delayed MIDI send data tasks // will be ignored. com_thread_.Stop(); } void MidiManagerWinrt::DispatchSendMidiData(MidiManagerClient* client, uint32_t port_index, const std::vector& data, double timestamp) { CHECK(scheduler_); scheduler_->PostSendDataTask( client, data.size(), timestamp, base::Bind(&MidiManagerWinrt::SendOnComThread, base::Unretained(this), port_index, data)); } void MidiManagerWinrt::InitializeOnComThread() { base::AutoLock auto_lock(lazy_init_member_lock_); com_thread_checker_.reset(new base::ThreadChecker); if (!GetCombaseFunctions()->LoadFunctions()) { VLOG(1) << "Failed loading functions from combase.dll: " << PrintHr(HRESULT_FROM_WIN32(GetLastError())); CompleteInitialization(Result::INITIALIZATION_ERROR); return; } port_manager_in_.reset(new MidiInPortManager(this)); port_manager_out_.reset(new MidiOutPortManager(this)); scheduler_.reset(new MidiScheduler(this)); if (!(port_manager_in_->StartWatcher() && port_manager_out_->StartWatcher())) { port_manager_in_->StopWatcher(); port_manager_out_->StopWatcher(); CompleteInitialization(Result::INITIALIZATION_ERROR); } } void MidiManagerWinrt::FinalizeOnComThread() { base::AutoLock auto_lock(lazy_init_member_lock_); DCHECK(com_thread_checker_->CalledOnValidThread()); scheduler_.reset(); if (port_manager_in_) { port_manager_in_->StopWatcher(); port_manager_in_.reset(); } if (port_manager_out_) { port_manager_out_->StopWatcher(); port_manager_out_.reset(); } com_thread_checker_.reset(); } void MidiManagerWinrt::SendOnComThread(uint32_t port_index, const std::vector& data) { DCHECK(com_thread_checker_->CalledOnValidThread()); MidiPort* port = port_manager_out_->GetPortByIndex(port_index); if (!(port && port->handle)) { VLOG(1) << "Port not available: " << port_index; return; } auto buffer_factory = WrlStaticsFactory(); if (!buffer_factory) return; ScopedComPtr buffer; HRESULT hr = buffer_factory->Create(static_cast(data.size()), buffer.GetAddressOf()); if (FAILED(hr)) { VLOG(1) << "Create failed: " << PrintHr(hr); return; } hr = buffer->put_Length(static_cast(data.size())); if (FAILED(hr)) { VLOG(1) << "put_Length failed: " << PrintHr(hr); return; } uint8_t* p_buffer_data = nullptr; hr = GetPointerToBufferData(buffer.Get(), &p_buffer_data); if (FAILED(hr)) return; std::copy(data.begin(), data.end(), p_buffer_data); hr = port->handle->SendBuffer(buffer.Get()); if (FAILED(hr)) { VLOG(1) << "SendBuffer failed: " << PrintHr(hr); return; } } void MidiManagerWinrt::OnPortManagerReady() { DCHECK(com_thread_checker_->CalledOnValidThread()); DCHECK(port_manager_ready_count_ < 2); if (++port_manager_ready_count_ == 2) CompleteInitialization(Result::OK); } } // namespace midi