diff options
Diffstat (limited to 'src/apscheduler/eventbrokers/mqtt.py')
-rw-r--r-- | src/apscheduler/eventbrokers/mqtt.py | 35 |
1 files changed, 17 insertions, 18 deletions
diff --git a/src/apscheduler/eventbrokers/mqtt.py b/src/apscheduler/eventbrokers/mqtt.py index 10ac605..567418d 100644 --- a/src/apscheduler/eventbrokers/mqtt.py +++ b/src/apscheduler/eventbrokers/mqtt.py @@ -2,22 +2,22 @@ from __future__ import annotations import sys from concurrent.futures import Future +from contextlib import AsyncExitStack from typing import Any import attrs +from anyio import to_thread +from anyio.from_thread import BlockingPortal from paho.mqtt.client import Client, MQTTMessage from paho.mqtt.properties import Properties from paho.mqtt.reasoncodes import ReasonCodes from .._events import Event -from ..abc import Serializer -from ..serializers.json import JSONSerializer -from .base import DistributedEventBrokerMixin -from .local import LocalEventBroker +from .base import BaseExternalEventBroker @attrs.define(eq=False) -class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): +class MQTTEventBroker(BaseExternalEventBroker): """ An event broker that uses an MQTT (v3.1 or v5) broker to broadcast events. @@ -26,7 +26,6 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): .. _paho-mqtt: https://pypi.org/project/paho-mqtt/ :param client: a paho-mqtt client - :param serializer: the serializer used to (de)serialize events for transport :param host: host name or IP address to connect to :param port: TCP port number to connect to :param topic: topic on which to send the messages @@ -35,16 +34,17 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): """ client: Client = attrs.field(factory=Client) - serializer: Serializer = attrs.field(factory=JSONSerializer) host: str = attrs.field(kw_only=True, default="localhost") port: int = attrs.field(kw_only=True, default=1883) topic: str = attrs.field(kw_only=True, default="apscheduler") subscribe_qos: int = attrs.field(kw_only=True, default=0) publish_qos: int = attrs.field(kw_only=True, default=0) + _portal: BlockingPortal = attrs.field(init=False) _ready_future: Future[None] = attrs.field(init=False) - def start(self) -> None: - super().start() + async def start(self, exit_stack: AsyncExitStack) -> None: + await super().start(exit_stack) + self._portal = await exit_stack.enter_async_context(BlockingPortal()) self._ready_future = Future() self.client.on_connect = self._on_connect self.client.on_connect_fail = self._on_connect_fail @@ -53,12 +53,9 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): self.client.on_subscribe = self._on_subscribe self.client.connect(self.host, self.port) self.client.loop_start() - self._ready_future.result(10) - - def stop(self, *, force: bool = False) -> None: - self.client.disconnect() - self.client.loop_stop(force=force) - super().stop() + exit_stack.push_async_callback(to_thread.run_sync, self.client.loop_stop) + await to_thread.run_sync(self._ready_future.result, 10) + exit_stack.push_async_callback(to_thread.run_sync, self.client.disconnect) def _on_connect( self, @@ -97,8 +94,10 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): def _on_message(self, client: Client, userdata: Any, msg: MQTTMessage) -> None: event = self.reconstitute_event(msg.payload) if event is not None: - self.publish_local(event) + self._portal.call(self.publish_local, event) - def publish(self, event: Event) -> None: + async def publish(self, event: Event) -> None: notification = self.generate_notification(event) - self.client.publish(self.topic, notification, qos=self.publish_qos) + await to_thread.run_sync( + lambda: self.client.publish(self.topic, notification, qos=self.publish_qos) + ) |