diff options
Diffstat (limited to 'src/apscheduler/eventbrokers/base.py')
-rw-r--r-- | src/apscheduler/eventbrokers/base.py | 64 |
1 files changed, 57 insertions, 7 deletions
diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py index 1373d54..5a3e718 100644 --- a/src/apscheduler/eventbrokers/base.py +++ b/src/apscheduler/eventbrokers/base.py @@ -1,15 +1,21 @@ from __future__ import annotations from base64 import b64decode, b64encode +from contextlib import AsyncExitStack +from inspect import iscoroutine from logging import Logger, getLogger from typing import Any, Callable, Iterable import attrs +from anyio import create_task_group, to_thread +from anyio.abc import TaskGroup from .. import _events from .._events import Event from .._exceptions import DeserializationError -from ..abc import EventSource, Serializer, Subscription +from .._retry import RetryMixin +from ..abc import EventBroker, Serializer, Subscription +from ..serializers.json import JSONSerializer @attrs.define(eq=False, frozen=True) @@ -17,6 +23,7 @@ class LocalSubscription(Subscription): callback: Callable[[Event], Any] event_types: set[type[Event]] | None one_shot: bool + is_async: bool token: object _source: BaseEventBroker @@ -24,36 +31,79 @@ class LocalSubscription(Subscription): self._source.unsubscribe(self.token) -@attrs.define(eq=False) -class BaseEventBroker(EventSource): +@attrs.define(kw_only=True) +class BaseEventBroker(EventBroker): _logger: Logger = attrs.field(init=False) _subscriptions: dict[object, LocalSubscription] = attrs.field( init=False, factory=dict ) + _task_group: TaskGroup = attrs.field(init=False) def __attrs_post_init__(self) -> None: self._logger = getLogger(self.__class__.__module__) + async def start(self, exit_stack: AsyncExitStack) -> None: + self._task_group = await exit_stack.enter_async_context(create_task_group()) + def subscribe( self, callback: Callable[[Event], Any], event_types: Iterable[type[Event]] | None = None, *, + is_async: bool = True, one_shot: bool = False, ) -> Subscription: types = set(event_types) if event_types else None token = object() - subscription = LocalSubscription(callback, types, one_shot, token, self) + subscription = LocalSubscription( + callback, types, one_shot, is_async, token, self + ) self._subscriptions[token] = subscription return subscription def unsubscribe(self, token: object) -> None: self._subscriptions.pop(token, None) + async def publish_local(self, event: Event) -> None: + event_type = type(event) + one_shot_tokens: list[object] = [] + for _token, subscription in self._subscriptions.items(): + if ( + subscription.event_types is None + or event_type in subscription.event_types + ): + self._task_group.start_soon(self._deliver_event, subscription, event) + if subscription.one_shot: + one_shot_tokens.append(subscription.token) + + for token in one_shot_tokens: + self.unsubscribe(token) + + async def _deliver_event( + self, subscription: LocalSubscription, event: Event + ) -> None: + try: + if subscription.is_async: + retval = subscription.callback(event) + if iscoroutine(retval): + await retval + else: + await to_thread.run_sync(subscription.callback, event) + except Exception: + self._logger.exception( + "Error delivering %s event", event.__class__.__name__ + ) + + +@attrs.define(kw_only=True) +class BaseExternalEventBroker(BaseEventBroker, RetryMixin): + """ + Base class for event brokers that use an external service. + + :param serializer: the serializer used to (de)serialize events for transport + """ -class DistributedEventBrokerMixin: - serializer: Serializer - _logger: Logger + serializer: Serializer = attrs.field(factory=JSONSerializer) def generate_notification(self, event: Event) -> bytes: serialized = self.serializer.serialize(event.marshal(self.serializer)) |