path: root/src/apscheduler/eventbrokers
author     Alex Grönholm <alex.gronholm@nextday.fi>   2022-06-09 12:40:55 +0300
committer  Alex Grönholm <alex.gronholm@nextday.fi>   2022-07-19 00:48:51 +0300
commit     f215c1ab45959095f6b499eb7b26356c5937ee8b (patch)
tree       552687a0ed6e799b3da96eec5cd3fbb14d19f1b5 /src/apscheduler/eventbrokers
parent     e3158fdf59a7c92a9449a566a2b746a4024e582f (diff)
download   apscheduler-f215c1ab45959095f6b499eb7b26356c5937ee8b.tar.gz
Added support for starting the sync scheduler (and worker) without the context manager
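The commit replaces the reentrant context-manager lifecycle of the event brokers with explicit start()/stop() methods. A minimal, hypothetical usage sketch of the resulting call pattern (LocalEventBroker, start(), stop() and subscribe() appear in the diff below; everything else is illustrative):

    from apscheduler.eventbrokers.local import LocalEventBroker

    broker = LocalEventBroker()
    broker.start()                  # previously: "with LocalEventBroker() as broker:"
    try:
        broker.subscribe(print)     # deliver every published event to print()
        ...                         # a scheduler or worker would run here
    finally:
        broker.stop()               # stop(force=True) skips waiting for pending work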
Diffstat (limited to 'src/apscheduler/eventbrokers')
-rw-r--r--  src/apscheduler/eventbrokers/async_adapter.py   68
-rw-r--r--  src/apscheduler/eventbrokers/async_local.py      16
-rw-r--r--  src/apscheduler/eventbrokers/asyncpg.py          46
-rw-r--r--  src/apscheduler/eventbrokers/base.py             16
-rw-r--r--  src/apscheduler/eventbrokers/local.py            16
-rw-r--r--  src/apscheduler/eventbrokers/mqtt.py             16
-rw-r--r--  src/apscheduler/eventbrokers/redis.py            13
7 files changed, 105 insertions, 86 deletions
diff --git a/src/apscheduler/eventbrokers/async_adapter.py b/src/apscheduler/eventbrokers/async_adapter.py
index ebceaad..4c1cb6c 100644
--- a/src/apscheduler/eventbrokers/async_adapter.py
+++ b/src/apscheduler/eventbrokers/async_adapter.py
@@ -1,39 +1,65 @@
from __future__ import annotations
-from functools import partial
+from typing import Any, Callable, Iterable
import attrs
from anyio import to_thread
from anyio.from_thread import BlockingPortal
-from apscheduler.abc import EventBroker
-from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker
+from apscheduler.abc import AsyncEventBroker, EventBroker, Subscription
from apscheduler.events import Event
-from apscheduler.util import reentrant
-@reentrant
@attrs.define(eq=False)
-class AsyncEventBrokerAdapter(LocalAsyncEventBroker):
+class AsyncEventBrokerAdapter(AsyncEventBroker):
original: EventBroker
- portal: BlockingPortal
-
- async def __aenter__(self):
- await super().__aenter__()
- if not self.portal:
- self.portal = BlockingPortal()
- self._exit_stack.enter_async_context(self.portal)
+ async def start(self) -> None:
+ await to_thread.run_sync(self.original.start)
- await to_thread.run_sync(self.original.__enter__)
- self._exit_stack.push_async_exit(
- partial(to_thread.run_sync, self.original.__exit__)
- )
+ async def stop(self, *, force: bool = False) -> None:
+ await to_thread.run_sync(lambda: self.original.stop(force=force))
- # Relay events from the original broker to this one
- self._exit_stack.enter_context(
- self.original.subscribe(partial(self.portal.call, self.publish_local))
- )
+ async def publish_local(self, event: Event) -> None:
+ await to_thread.run_sync(self.original.publish_local, event)
async def publish(self, event: Event) -> None:
await to_thread.run_sync(self.original.publish, event)
+
+ def subscribe(
+ self,
+ callback: Callable[[Event], Any],
+ event_types: Iterable[type[Event]] | None = None,
+ *,
+ one_shot: bool = False,
+ ) -> Subscription:
+ return self.original.subscribe(callback, event_types, one_shot=one_shot)
+
+
+@attrs.define(eq=False)
+class SyncEventBrokerAdapter(EventBroker):
+ original: AsyncEventBroker
+ portal: BlockingPortal
+
+ def start(self) -> None:
+ pass
+
+ def stop(self, *, force: bool = False) -> None:
+ pass
+
+ def publish_local(self, event: Event) -> None:
+ self.portal.call(self.original.publish_local, event)
+
+ def publish(self, event: Event) -> None:
+ self.portal.call(self.original.publish, event)
+
+ def subscribe(
+ self,
+ callback: Callable[[Event], Any],
+ event_types: Iterable[type[Event]] | None = None,
+ *,
+ one_shot: bool = False,
+ ) -> Subscription:
+ return self.portal.call(
+ lambda: self.original.subscribe(callback, event_types, one_shot=one_shot)
+ )
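The new SyncEventBrokerAdapter delegates every call through an anyio BlockingPortal. A standalone sketch of that pattern, independent of APScheduler (the coroutine and payload names are illustrative):

    from anyio.from_thread import start_blocking_portal

    async def async_publish(payload: str) -> None:
        print(f"published {payload}")

    with start_blocking_portal() as portal:
        # synchronous code blocks here until the coroutine finishes on the portal's loop
        portal.call(async_publish, "job_added")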
diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py
index 98031bf..dba37f1 100644
--- a/src/apscheduler/eventbrokers/async_local.py
+++ b/src/apscheduler/eventbrokers/async_local.py
@@ -1,7 +1,6 @@
from __future__ import annotations
from asyncio import iscoroutine
-from contextlib import AsyncExitStack
from typing import Any, Callable
import attrs
@@ -10,26 +9,19 @@ from anyio.abc import TaskGroup
from ..abc import AsyncEventBroker
from ..events import Event
-from ..util import reentrant
from .base import BaseEventBroker
-@reentrant
@attrs.define(eq=False)
class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker):
_task_group: TaskGroup = attrs.field(init=False)
- _exit_stack: AsyncExitStack = attrs.field(init=False)
-
- async def __aenter__(self) -> LocalAsyncEventBroker:
- self._exit_stack = AsyncExitStack()
+ async def start(self) -> None:
self._task_group = create_task_group()
- await self._exit_stack.enter_async_context(self._task_group)
-
- return self
+ await self._task_group.__aenter__()
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
+ async def stop(self, *, force: bool = False) -> None:
+ await self._task_group.__aexit__(None, None, None)
del self._task_group
async def publish(self, event: Event) -> None:
diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py
index d97c229..c80033e 100644
--- a/src/apscheduler/eventbrokers/asyncpg.py
+++ b/src/apscheduler/eventbrokers/asyncpg.py
@@ -4,13 +4,12 @@ from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncContextManager, AsyncGenerator, Callable
import attrs
-from anyio import TASK_STATUS_IGNORED, sleep
+from anyio import TASK_STATUS_IGNORED, CancelScope, sleep
from asyncpg import Connection
from asyncpg.pool import Pool
from ..events import Event
from ..exceptions import SerializationError
-from ..util import reentrant
from .async_local import LocalAsyncEventBroker
from .base import DistributedEventBrokerMixin
@@ -18,12 +17,12 @@ if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine
-@reentrant
@attrs.define(eq=False)
class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
connection_factory: Callable[[], AsyncContextManager[Connection]]
channel: str = attrs.field(kw_only=True, default="apscheduler")
max_idle_time: float = attrs.field(kw_only=True, default=30)
+ _listen_cancel_scope: CancelScope = attrs.field(init=False)
@classmethod
def from_asyncpg_pool(cls, pool: Pool) -> AsyncpgEventBroker:
@@ -47,11 +46,15 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
return cls(connection_factory)
- async def __aenter__(self) -> AsyncpgEventBroker:
- await super().__aenter__()
- await self._task_group.start(self._listen_notifications)
- self._exit_stack.callback(self._task_group.cancel_scope.cancel)
- return self
+ async def start(self) -> None:
+ await super().start()
+ self._listen_cancel_scope = await self._task_group.start(
+ self._listen_notifications
+ )
+
+ async def stop(self, *, force: bool = False) -> None:
+ self._listen_cancel_scope.cancel()
+ await super().stop(force=force)
async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None:
def callback(connection, pid, channel: str, payload: str) -> None:
@@ -60,19 +63,20 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
self._task_group.start_soon(self.publish_local, event)
task_started_sent = False
- while True:
- async with self.connection_factory() as conn:
- await conn.add_listener(self.channel, callback)
- if not task_started_sent:
- task_status.started()
- task_started_sent = True
-
- try:
- while True:
- await sleep(self.max_idle_time)
- await conn.execute("SELECT 1")
- finally:
- await conn.remove_listener(self.channel, callback)
+ with CancelScope() as cancel_scope:
+ while True:
+ async with self.connection_factory() as conn:
+ await conn.add_listener(self.channel, callback)
+ if not task_started_sent:
+ task_status.started(cancel_scope)
+ task_started_sent = True
+
+ try:
+ while True:
+ await sleep(self.max_idle_time)
+ await conn.execute("SELECT 1")
+ finally:
+ await conn.remove_listener(self.channel, callback)
async def publish(self, event: Event) -> None:
notification = self.generate_notification_str(event)
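start() now receives the listener's CancelScope via task_status.started(), and stop() cancels just that scope. A self-contained sketch of this anyio pattern (listener() and main() are invented for the example):

    import anyio
    from anyio import TASK_STATUS_IGNORED, CancelScope

    async def listener(*, task_status=TASK_STATUS_IGNORED) -> None:
        with CancelScope() as scope:
            task_status.started(scope)   # this value is returned by task_group.start()
            await anyio.sleep_forever()

    async def main() -> None:
        async with anyio.create_task_group() as tg:
            scope = await tg.start(listener)
            scope.cancel()               # cancels only the listener task

    anyio.run(main)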
diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py
index 38c83c7..3aeee77 100644
--- a/src/apscheduler/eventbrokers/base.py
+++ b/src/apscheduler/eventbrokers/base.py
@@ -7,7 +7,7 @@ from typing import Any, Callable, Iterable
import attrs
from .. import events
-from ..abc import EventBroker, Serializer, Subscription
+from ..abc import EventSource, Serializer, Subscription
from ..events import Event
from ..exceptions import DeserializationError
@@ -25,7 +25,7 @@ class LocalSubscription(Subscription):
@attrs.define(eq=False)
-class BaseEventBroker(EventBroker):
+class BaseEventBroker(EventSource):
_logger: Logger = attrs.field(init=False)
_subscriptions: dict[object, LocalSubscription] = attrs.field(
init=False, factory=dict
@@ -70,7 +70,7 @@ class DistributedEventBrokerMixin:
self._logger.exception(
"Failed to deserialize an event of type %s",
event_type,
- serialized=serialized,
+ extra={"serialized": serialized},
)
return None
@@ -80,7 +80,7 @@ class DistributedEventBrokerMixin:
self._logger.error(
"Receive notification for a nonexistent event type: %s",
event_type,
- serialized=serialized,
+ extra={"serialized": serialized},
)
return None
@@ -94,7 +94,9 @@ class DistributedEventBrokerMixin:
try:
event_type_bytes, serialized = payload.split(b" ", 1)
except ValueError:
- self._logger.error("Received malformatted notification", payload=payload)
+ self._logger.error(
+ "Received malformatted notification", extra={"payload": payload}
+ )
return None
event_type = event_type_bytes.decode("ascii", errors="replace")
@@ -104,7 +106,9 @@ class DistributedEventBrokerMixin:
try:
event_type, b64_serialized = payload.split(" ", 1)
except ValueError:
- self._logger.error("Received malformatted notification", payload=payload)
+ self._logger.error(
+ "Received malformatted notification", extra={"payload": payload}
+ )
return None
return self._reconstitute_event(event_type, b64decode(b64_serialized))
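The logging calls above move structured context from structlog-style keyword arguments to the stdlib `extra` mapping. A small sketch of that convention (logger name and format string are invented for the example):

    import logging

    # keys passed via `extra` become attributes of the LogRecord
    logging.basicConfig(format="%(levelname)s %(message)s payload=%(payload)r")
    logging.getLogger("apscheduler.demo").error(
        "Received malformatted notification", extra={"payload": b"bad data"}
    )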
diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py
index cf19526..a870c45 100644
--- a/src/apscheduler/eventbrokers/local.py
+++ b/src/apscheduler/eventbrokers/local.py
@@ -8,26 +8,22 @@ from typing import Any, Callable, Iterable
import attrs
-from ..abc import Subscription
+from ..abc import EventBroker, Subscription
from ..events import Event
-from ..util import reentrant
from .base import BaseEventBroker
-@reentrant
@attrs.define(eq=False)
-class LocalEventBroker(BaseEventBroker):
+class LocalEventBroker(EventBroker, BaseEventBroker):
_executor: ThreadPoolExecutor = attrs.field(init=False)
_exit_stack: ExitStack = attrs.field(init=False)
_subscriptions_lock: Lock = attrs.field(init=False, factory=Lock)
- def __enter__(self):
- self._exit_stack = ExitStack()
- self._executor = self._exit_stack.enter_context(ThreadPoolExecutor(1))
- return self
+ def start(self) -> None:
+ self._executor = ThreadPoolExecutor(1)
- def __exit__(self, exc_type, exc_val, exc_tb):
- self._exit_stack.__exit__(exc_type, exc_val, exc_tb)
+ def stop(self, *, force: bool = False) -> None:
+ self._executor.shutdown(wait=not force)
del self._executor
def subscribe(
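LocalEventBroker now owns its single-thread executor directly instead of going through an ExitStack; as the hunk shows, stop(force=True) translates to shutdown(wait=False). A minimal sketch of the equivalent lifecycle:

    from concurrent.futures import ThreadPoolExecutor

    executor = ThreadPoolExecutor(1)                     # one worker thread, as in start()
    executor.submit(print, "delivered to a subscriber")
    executor.shutdown(wait=True)                         # wait=False returns immediately
                                                         # instead of draining queued work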
diff --git a/src/apscheduler/eventbrokers/mqtt.py b/src/apscheduler/eventbrokers/mqtt.py
index 8b0a512..9b0f859 100644
--- a/src/apscheduler/eventbrokers/mqtt.py
+++ b/src/apscheduler/eventbrokers/mqtt.py
@@ -11,12 +11,10 @@ from paho.mqtt.reasoncodes import ReasonCodes
from ..abc import Serializer
from ..events import Event
from ..serializers.json import JSONSerializer
-from ..util import reentrant
from .base import DistributedEventBrokerMixin
from .local import LocalEventBroker
-@reentrant
@attrs.define(eq=False)
class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
client: Client
@@ -28,8 +26,8 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
publish_qos: int = attrs.field(kw_only=True, default=0)
_ready_future: Future[None] = attrs.field(init=False)
- def __enter__(self):
- super().__enter__()
+ def start(self) -> None:
+ super().start()
self._ready_future = Future()
self.client.enable_logger(self._logger)
self.client.on_connect = self._on_connect
@@ -38,11 +36,11 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
self.client.connect(self.host, self.port)
self.client.loop_start()
self._ready_future.result(10)
- self._exit_stack.push(
- lambda exc_type, *_: self.client.loop_stop(force=bool(exc_type))
- )
- self._exit_stack.callback(self.client.disconnect)
- return self
+
+ def stop(self, *, force: bool = False) -> None:
+ self.client.disconnect()
+ self.client.loop_stop(force=force)
+ super().stop()
def _on_connect(
self,
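start()/stop() mirror the paho-mqtt client lifecycle: loop_start() runs the network loop in a background thread and a Future gates start() until the connection is acknowledged. A rough sketch, assuming the paho-mqtt 1.x Client constructor current at the time of this commit and a broker reachable on localhost:1883:

    from concurrent.futures import Future
    from paho.mqtt.client import Client

    ready: Future[None] = Future()
    client = Client()
    client.on_connect = lambda *args: ready.set_result(None)
    client.connect("localhost", 1883)
    client.loop_start()        # network loop runs in a background thread
    ready.result(10)           # wait up to 10 s for the CONNACK
    client.disconnect()
    client.loop_stop()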
diff --git a/src/apscheduler/eventbrokers/redis.py b/src/apscheduler/eventbrokers/redis.py
index 0f14e3e..a41ef37 100644
--- a/src/apscheduler/eventbrokers/redis.py
+++ b/src/apscheduler/eventbrokers/redis.py
@@ -9,12 +9,10 @@ from redis import ConnectionPool, Redis
from ..abc import Serializer
from ..events import Event
from ..serializers.json import JSONSerializer
-from ..util import reentrant
from .base import DistributedEventBrokerMixin
from .local import LocalEventBroker
-@reentrant
@attrs.define(eq=False)
class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
client: Redis
@@ -23,6 +21,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
message_poll_interval: float = attrs.field(kw_only=True, default=0.05)
_stopped: bool = attrs.field(init=False, default=True)
_ready_future: Future[None] = attrs.field(init=False)
+ _thread: Thread = attrs.field(init=False)
@classmethod
def from_url(cls, url: str, **kwargs) -> RedisEventBroker:
@@ -30,7 +29,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
client = Redis(connection_pool=pool)
return cls(client)
- def __enter__(self):
+ def start(self) -> None:
self._stopped = False
self._ready_future = Future()
self._thread = Thread(
@@ -38,14 +37,14 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
)
self._thread.start()
self._ready_future.result(10)
- return super().__enter__()
+ super().start()
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def stop(self, *, force: bool = False) -> None:
self._stopped = True
- if not exc_type:
+ if not force:
self._thread.join(5)
- super().__exit__(exc_type, exc_val, exc_tb)
+ super().stop(force=force)
def _listen_messages(self) -> None:
while not self._stopped:
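stop() now flips the _stopped flag and joins the listener thread with a timeout. A minimal sketch of that shutdown handshake, with a sleep standing in for polling the Redis pub/sub connection:

    import threading
    import time

    stopped = False

    def listen_messages() -> None:
        while not stopped:
            time.sleep(0.05)   # placeholder for polling the pub/sub connection

    thread = threading.Thread(target=listen_messages, daemon=True)
    thread.start()
    stopped = True             # the loop observes the flag and exits
    thread.join(5)             # bounded join so shutdown cannot hang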