summaryrefslogtreecommitdiff
path: root/src/apscheduler/eventbrokers/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/apscheduler/eventbrokers/base.py')
-rw-r--r--src/apscheduler/eventbrokers/base.py64
1 files changed, 57 insertions, 7 deletions
diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py
index 1373d54..5a3e718 100644
--- a/src/apscheduler/eventbrokers/base.py
+++ b/src/apscheduler/eventbrokers/base.py
@@ -1,15 +1,21 @@
from __future__ import annotations
from base64 import b64decode, b64encode
+from contextlib import AsyncExitStack
+from inspect import iscoroutine
from logging import Logger, getLogger
from typing import Any, Callable, Iterable
import attrs
+from anyio import create_task_group, to_thread
+from anyio.abc import TaskGroup
from .. import _events
from .._events import Event
from .._exceptions import DeserializationError
-from ..abc import EventSource, Serializer, Subscription
+from .._retry import RetryMixin
+from ..abc import EventBroker, Serializer, Subscription
+from ..serializers.json import JSONSerializer
@attrs.define(eq=False, frozen=True)
@@ -17,6 +23,7 @@ class LocalSubscription(Subscription):
callback: Callable[[Event], Any]
event_types: set[type[Event]] | None
one_shot: bool
+ is_async: bool
token: object
_source: BaseEventBroker
@@ -24,36 +31,79 @@ class LocalSubscription(Subscription):
self._source.unsubscribe(self.token)
-@attrs.define(eq=False)
-class BaseEventBroker(EventSource):
+@attrs.define(kw_only=True)
+class BaseEventBroker(EventBroker):
_logger: Logger = attrs.field(init=False)
_subscriptions: dict[object, LocalSubscription] = attrs.field(
init=False, factory=dict
)
+ _task_group: TaskGroup = attrs.field(init=False)
def __attrs_post_init__(self) -> None:
self._logger = getLogger(self.__class__.__module__)
+ async def start(self, exit_stack: AsyncExitStack) -> None:
+ self._task_group = await exit_stack.enter_async_context(create_task_group())
+
def subscribe(
self,
callback: Callable[[Event], Any],
event_types: Iterable[type[Event]] | None = None,
*,
+ is_async: bool = True,
one_shot: bool = False,
) -> Subscription:
types = set(event_types) if event_types else None
token = object()
- subscription = LocalSubscription(callback, types, one_shot, token, self)
+ subscription = LocalSubscription(
+ callback, types, one_shot, is_async, token, self
+ )
self._subscriptions[token] = subscription
return subscription
def unsubscribe(self, token: object) -> None:
self._subscriptions.pop(token, None)
+ async def publish_local(self, event: Event) -> None:
+ event_type = type(event)
+ one_shot_tokens: list[object] = []
+ for _token, subscription in self._subscriptions.items():
+ if (
+ subscription.event_types is None
+ or event_type in subscription.event_types
+ ):
+ self._task_group.start_soon(self._deliver_event, subscription, event)
+ if subscription.one_shot:
+ one_shot_tokens.append(subscription.token)
+
+ for token in one_shot_tokens:
+ self.unsubscribe(token)
+
+ async def _deliver_event(
+ self, subscription: LocalSubscription, event: Event
+ ) -> None:
+ try:
+ if subscription.is_async:
+ retval = subscription.callback(event)
+ if iscoroutine(retval):
+ await retval
+ else:
+ await to_thread.run_sync(subscription.callback, event)
+ except Exception:
+ self._logger.exception(
+ "Error delivering %s event", event.__class__.__name__
+ )
+
+
+@attrs.define(kw_only=True)
+class BaseExternalEventBroker(BaseEventBroker, RetryMixin):
+ """
+ Base class for event brokers that use an external service.
+
+ :param serializer: the serializer used to (de)serialize events for transport
+ """
-class DistributedEventBrokerMixin:
- serializer: Serializer
- _logger: Logger
+ serializer: Serializer = attrs.field(factory=JSONSerializer)
def generate_notification(self, event: Event) -> bytes:
serialized = self.serializer.serialize(event.marshal(self.serializer))