diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2022-06-09 12:40:55 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2022-07-19 00:48:51 +0300 |
commit | f215c1ab45959095f6b499eb7b26356c5937ee8b (patch) | |
tree | 552687a0ed6e799b3da96eec5cd3fbb14d19f1b5 /src/apscheduler/eventbrokers | |
parent | e3158fdf59a7c92a9449a566a2b746a4024e582f (diff) | |
download | apscheduler-f215c1ab45959095f6b499eb7b26356c5937ee8b.tar.gz |
Added support for starting the sync scheduler (and worker) without the context manager
Diffstat (limited to 'src/apscheduler/eventbrokers')
-rw-r--r-- | src/apscheduler/eventbrokers/async_adapter.py | 68 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/async_local.py | 16 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/asyncpg.py | 46 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/base.py | 16 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/local.py | 16 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/mqtt.py | 16 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/redis.py | 13 |
7 files changed, 105 insertions, 86 deletions
diff --git a/src/apscheduler/eventbrokers/async_adapter.py b/src/apscheduler/eventbrokers/async_adapter.py index ebceaad..4c1cb6c 100644 --- a/src/apscheduler/eventbrokers/async_adapter.py +++ b/src/apscheduler/eventbrokers/async_adapter.py @@ -1,39 +1,65 @@ from __future__ import annotations -from functools import partial +from typing import Any, Callable, Iterable import attrs from anyio import to_thread from anyio.from_thread import BlockingPortal -from apscheduler.abc import EventBroker -from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker +from apscheduler.abc import AsyncEventBroker, EventBroker, Subscription from apscheduler.events import Event -from apscheduler.util import reentrant -@reentrant @attrs.define(eq=False) -class AsyncEventBrokerAdapter(LocalAsyncEventBroker): +class AsyncEventBrokerAdapter(AsyncEventBroker): original: EventBroker - portal: BlockingPortal - - async def __aenter__(self): - await super().__aenter__() - if not self.portal: - self.portal = BlockingPortal() - self._exit_stack.enter_async_context(self.portal) + async def start(self) -> None: + await to_thread.run_sync(self.original.start) - await to_thread.run_sync(self.original.__enter__) - self._exit_stack.push_async_exit( - partial(to_thread.run_sync, self.original.__exit__) - ) + async def stop(self, *, force: bool = False) -> None: + await to_thread.run_sync(lambda: self.original.stop(force=force)) - # Relay events from the original broker to this one - self._exit_stack.enter_context( - self.original.subscribe(partial(self.portal.call, self.publish_local)) - ) + async def publish_local(self, event: Event) -> None: + await to_thread.run_sync(self.original.publish_local, event) async def publish(self, event: Event) -> None: await to_thread.run_sync(self.original.publish, event) + + def subscribe( + self, + callback: Callable[[Event], Any], + event_types: Iterable[type[Event]] | None = None, + *, + one_shot: bool = False, + ) -> Subscription: + return self.original.subscribe(callback, event_types, one_shot=one_shot) + + +@attrs.define(eq=False) +class SyncEventBrokerAdapter(EventBroker): + original: AsyncEventBroker + portal: BlockingPortal + + def start(self) -> None: + pass + + def stop(self, *, force: bool = False) -> None: + pass + + def publish_local(self, event: Event) -> None: + self.portal.call(self.original.publish_local, event) + + def publish(self, event: Event) -> None: + self.portal.call(self.original.publish, event) + + def subscribe( + self, + callback: Callable[[Event], Any], + event_types: Iterable[type[Event]] | None = None, + *, + one_shot: bool = False, + ) -> Subscription: + return self.portal.call( + lambda: self.original.subscribe(callback, event_types, one_shot=one_shot) + ) diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py index 98031bf..dba37f1 100644 --- a/src/apscheduler/eventbrokers/async_local.py +++ b/src/apscheduler/eventbrokers/async_local.py @@ -1,7 +1,6 @@ from __future__ import annotations from asyncio import iscoroutine -from contextlib import AsyncExitStack from typing import Any, Callable import attrs @@ -10,26 +9,19 @@ from anyio.abc import TaskGroup from ..abc import AsyncEventBroker from ..events import Event -from ..util import reentrant from .base import BaseEventBroker -@reentrant @attrs.define(eq=False) class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker): _task_group: TaskGroup = attrs.field(init=False) - _exit_stack: AsyncExitStack = attrs.field(init=False) - - async def __aenter__(self) -> LocalAsyncEventBroker: - self._exit_stack = AsyncExitStack() + async def start(self) -> None: self._task_group = create_task_group() - await self._exit_stack.enter_async_context(self._task_group) - - return self + await self._task_group.__aenter__() - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + async def stop(self, *, force: bool = False) -> None: + await self._task_group.__aexit__(None, None, None) del self._task_group async def publish(self, event: Event) -> None: diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py index d97c229..c80033e 100644 --- a/src/apscheduler/eventbrokers/asyncpg.py +++ b/src/apscheduler/eventbrokers/asyncpg.py @@ -4,13 +4,12 @@ from contextlib import asynccontextmanager from typing import TYPE_CHECKING, AsyncContextManager, AsyncGenerator, Callable import attrs -from anyio import TASK_STATUS_IGNORED, sleep +from anyio import TASK_STATUS_IGNORED, CancelScope, sleep from asyncpg import Connection from asyncpg.pool import Pool from ..events import Event from ..exceptions import SerializationError -from ..util import reentrant from .async_local import LocalAsyncEventBroker from .base import DistributedEventBrokerMixin @@ -18,12 +17,12 @@ if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncEngine -@reentrant @attrs.define(eq=False) class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): connection_factory: Callable[[], AsyncContextManager[Connection]] channel: str = attrs.field(kw_only=True, default="apscheduler") max_idle_time: float = attrs.field(kw_only=True, default=30) + _listen_cancel_scope: CancelScope = attrs.field(init=False) @classmethod def from_asyncpg_pool(cls, pool: Pool) -> AsyncpgEventBroker: @@ -47,11 +46,15 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): return cls(connection_factory) - async def __aenter__(self) -> AsyncpgEventBroker: - await super().__aenter__() - await self._task_group.start(self._listen_notifications) - self._exit_stack.callback(self._task_group.cancel_scope.cancel) - return self + async def start(self) -> None: + await super().start() + self._listen_cancel_scope = await self._task_group.start( + self._listen_notifications + ) + + async def stop(self, *, force: bool = False) -> None: + self._listen_cancel_scope.cancel() + await super().stop(force=force) async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None: def callback(connection, pid, channel: str, payload: str) -> None: @@ -60,19 +63,20 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): self._task_group.start_soon(self.publish_local, event) task_started_sent = False - while True: - async with self.connection_factory() as conn: - await conn.add_listener(self.channel, callback) - if not task_started_sent: - task_status.started() - task_started_sent = True - - try: - while True: - await sleep(self.max_idle_time) - await conn.execute("SELECT 1") - finally: - await conn.remove_listener(self.channel, callback) + with CancelScope() as cancel_scope: + while True: + async with self.connection_factory() as conn: + await conn.add_listener(self.channel, callback) + if not task_started_sent: + task_status.started(cancel_scope) + task_started_sent = True + + try: + while True: + await sleep(self.max_idle_time) + await conn.execute("SELECT 1") + finally: + await conn.remove_listener(self.channel, callback) async def publish(self, event: Event) -> None: notification = self.generate_notification_str(event) diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py index 38c83c7..3aeee77 100644 --- a/src/apscheduler/eventbrokers/base.py +++ b/src/apscheduler/eventbrokers/base.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Iterable import attrs from .. import events -from ..abc import EventBroker, Serializer, Subscription +from ..abc import EventSource, Serializer, Subscription from ..events import Event from ..exceptions import DeserializationError @@ -25,7 +25,7 @@ class LocalSubscription(Subscription): @attrs.define(eq=False) -class BaseEventBroker(EventBroker): +class BaseEventBroker(EventSource): _logger: Logger = attrs.field(init=False) _subscriptions: dict[object, LocalSubscription] = attrs.field( init=False, factory=dict @@ -70,7 +70,7 @@ class DistributedEventBrokerMixin: self._logger.exception( "Failed to deserialize an event of type %s", event_type, - serialized=serialized, + extra={"serialized": serialized}, ) return None @@ -80,7 +80,7 @@ class DistributedEventBrokerMixin: self._logger.error( "Receive notification for a nonexistent event type: %s", event_type, - serialized=serialized, + extra={"serialized": serialized}, ) return None @@ -94,7 +94,9 @@ class DistributedEventBrokerMixin: try: event_type_bytes, serialized = payload.split(b" ", 1) except ValueError: - self._logger.error("Received malformatted notification", payload=payload) + self._logger.error( + "Received malformatted notification", extra={"payload": payload} + ) return None event_type = event_type_bytes.decode("ascii", errors="replace") @@ -104,7 +106,9 @@ class DistributedEventBrokerMixin: try: event_type, b64_serialized = payload.split(" ", 1) except ValueError: - self._logger.error("Received malformatted notification", payload=payload) + self._logger.error( + "Received malformatted notification", extra={"payload": payload} + ) return None return self._reconstitute_event(event_type, b64decode(b64_serialized)) diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py index cf19526..a870c45 100644 --- a/src/apscheduler/eventbrokers/local.py +++ b/src/apscheduler/eventbrokers/local.py @@ -8,26 +8,22 @@ from typing import Any, Callable, Iterable import attrs -from ..abc import Subscription +from ..abc import EventBroker, Subscription from ..events import Event -from ..util import reentrant from .base import BaseEventBroker -@reentrant @attrs.define(eq=False) -class LocalEventBroker(BaseEventBroker): +class LocalEventBroker(EventBroker, BaseEventBroker): _executor: ThreadPoolExecutor = attrs.field(init=False) _exit_stack: ExitStack = attrs.field(init=False) _subscriptions_lock: Lock = attrs.field(init=False, factory=Lock) - def __enter__(self): - self._exit_stack = ExitStack() - self._executor = self._exit_stack.enter_context(ThreadPoolExecutor(1)) - return self + def start(self) -> None: + self._executor = ThreadPoolExecutor(1) - def __exit__(self, exc_type, exc_val, exc_tb): - self._exit_stack.__exit__(exc_type, exc_val, exc_tb) + def stop(self, *, force: bool = False) -> None: + self._executor.shutdown(wait=not force) del self._executor def subscribe( diff --git a/src/apscheduler/eventbrokers/mqtt.py b/src/apscheduler/eventbrokers/mqtt.py index 8b0a512..9b0f859 100644 --- a/src/apscheduler/eventbrokers/mqtt.py +++ b/src/apscheduler/eventbrokers/mqtt.py @@ -11,12 +11,10 @@ from paho.mqtt.reasoncodes import ReasonCodes from ..abc import Serializer from ..events import Event from ..serializers.json import JSONSerializer -from ..util import reentrant from .base import DistributedEventBrokerMixin from .local import LocalEventBroker -@reentrant @attrs.define(eq=False) class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): client: Client @@ -28,8 +26,8 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): publish_qos: int = attrs.field(kw_only=True, default=0) _ready_future: Future[None] = attrs.field(init=False) - def __enter__(self): - super().__enter__() + def start(self) -> None: + super().start() self._ready_future = Future() self.client.enable_logger(self._logger) self.client.on_connect = self._on_connect @@ -38,11 +36,11 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): self.client.connect(self.host, self.port) self.client.loop_start() self._ready_future.result(10) - self._exit_stack.push( - lambda exc_type, *_: self.client.loop_stop(force=bool(exc_type)) - ) - self._exit_stack.callback(self.client.disconnect) - return self + + def stop(self, *, force: bool = False) -> None: + self.client.disconnect() + self.client.loop_stop(force=force) + super().stop() def _on_connect( self, diff --git a/src/apscheduler/eventbrokers/redis.py b/src/apscheduler/eventbrokers/redis.py index 0f14e3e..a41ef37 100644 --- a/src/apscheduler/eventbrokers/redis.py +++ b/src/apscheduler/eventbrokers/redis.py @@ -9,12 +9,10 @@ from redis import ConnectionPool, Redis from ..abc import Serializer from ..events import Event from ..serializers.json import JSONSerializer -from ..util import reentrant from .base import DistributedEventBrokerMixin from .local import LocalEventBroker -@reentrant @attrs.define(eq=False) class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): client: Redis @@ -23,6 +21,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): message_poll_interval: float = attrs.field(kw_only=True, default=0.05) _stopped: bool = attrs.field(init=False, default=True) _ready_future: Future[None] = attrs.field(init=False) + _thread: Thread = attrs.field(init=False) @classmethod def from_url(cls, url: str, **kwargs) -> RedisEventBroker: @@ -30,7 +29,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): client = Redis(connection_pool=pool) return cls(client) - def __enter__(self): + def start(self) -> None: self._stopped = False self._ready_future = Future() self._thread = Thread( @@ -38,14 +37,14 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): ) self._thread.start() self._ready_future.result(10) - return super().__enter__() + super().start() - def __exit__(self, exc_type, exc_val, exc_tb): + def stop(self, *, force: bool = False) -> None: self._stopped = True - if not exc_type: + if not force: self._thread.join(5) - super().__exit__(exc_type, exc_val, exc_tb) + super().stop(force=force) def _listen_messages(self) -> None: while not self._stopped: |