From 02dd92a159b9ce99fe725949ab955d02f9450a75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 3 Sep 2022 13:42:45 +0300 Subject: Improved the asyncpg event broker * It now recovers from a server disconnection * Replaced from_asyncpg_pool() with from_dsn() * Only copy the connection options from the async engine in from_async_sqla_engine() --- src/apscheduler/eventbrokers/asyncpg.py | 147 +++++++++++++++++++++++--------- 1 file changed, 108 insertions(+), 39 deletions(-) (limited to 'src') diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py index 5666392..4ee2339 100644 --- a/src/apscheduler/eventbrokers/asyncpg.py +++ b/src/apscheduler/eventbrokers/asyncpg.py @@ -1,13 +1,24 @@ from __future__ import annotations -from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, AsyncContextManager, AsyncGenerator, Callable +from collections.abc import Awaitable, Mapping +from contextlib import AsyncExitStack +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, cast +import anyio +import asyncpg import attrs -from anyio import TASK_STATUS_IGNORED, CancelScope, sleep -from asyncpg import Connection -from asyncpg.pool import Pool - +import tenacity +from anyio import ( + TASK_STATUS_IGNORED, + EndOfStream, + create_memory_object_stream, + move_on_after, +) +from anyio.streams.memory import MemoryObjectSendStream +from asyncpg import Connection, InterfaceError + +from .. import RetrySettings from .._events import Event from .._exceptions import SerializationError from ..abc import Serializer @@ -27,41 +38,54 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): .. _asyncpg: https://pypi.org/project/asyncpg/ - :param connection_factory: a callable that creates an async context manager that - yields a new asyncpg connection + :param connection_factory: a callable that creates an asyncpg connection :param serializer: the serializer used to (de)serialize events for transport :param channel: the ``NOTIFY`` channel to use :param max_idle_time: maximum time to let the connection go idle, before sending a ``SELECT 1`` query to prevent a connection timeout """ - connection_factory: Callable[[], AsyncContextManager[Connection]] + connection_factory: Callable[[], Awaitable[Connection]] serializer: Serializer = attrs.field(kw_only=True, factory=JSONSerializer) channel: str = attrs.field(kw_only=True, default="apscheduler") - max_idle_time: float = attrs.field(kw_only=True, default=30) - _listen_cancel_scope: CancelScope = attrs.field(init=False) + retry_settings: RetrySettings = attrs.field(default=RetrySettings()) + max_idle_time: float = attrs.field(kw_only=True, default=10) + _send: MemoryObjectSendStream[str] = attrs.field(init=False) @classmethod - def from_asyncpg_pool(cls, pool: Pool, **kwargs) -> AsyncpgEventBroker: + def from_dsn( + cls, dsn: str, options: Mapping[str, Any] | None = None, **kwargs: Any + ) -> AsyncpgEventBroker: """ Create a new asyncpg event broker from an existing asyncpg connection pool. - :param pool: an asyncpg connection pool + :param dsn: data source name, passed as first positional argument to + :func:`asyncpg.connect` + :param options: keyword arguments passed to :func:`asyncpg.connect` :param kwargs: keyword arguments to pass to the initializer of this class :return: the newly created event broker """ - return cls(pool.acquire, **kwargs) + factory = partial(asyncpg.connect, dsn, **(options or {})) + return cls(factory, **kwargs) @classmethod def from_async_sqla_engine( - cls, engine: AsyncEngine, **kwargs + cls, + engine: AsyncEngine, + options: Mapping[str, Any] | None = None, + **kwargs: Any, ) -> AsyncpgEventBroker: """ Create a new asyncpg event broker from an SQLAlchemy engine. + The engine will only be used to create the appropriate options for + :func:`asyncpg.connect`. + :param engine: an asynchronous SQLAlchemy engine using asyncpg as the driver :type engine: ~sqlalchemy.ext.asyncio.AsyncEngine + :param options: extra keyword arguments passed to :func:`asyncpg.connect` (will + override any automatically generated arguments based on the engine) :param kwargs: keyword arguments to pass to the initializer of this class :return: the newly created event broker @@ -72,47 +96,94 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): f"{engine.dialect.driver})" ) - @asynccontextmanager - async def connection_factory() -> AsyncGenerator[Connection, None]: - conn = await engine.raw_connection() - try: - yield conn.connection._connection - finally: - conn.close() + connect_args = dict(engine.url.query) + for optname in ("host", "port", "database", "username", "password"): + value = getattr(engine.url, optname) + if value is not None: + if optname == "username": + optname = "user" + + connect_args[optname] = value + + if options: + connect_args |= options - return cls(connection_factory, **kwargs) + factory = partial(asyncpg.connect, **connect_args) + return cls(factory, **kwargs) + + def _retry(self) -> tenacity.AsyncRetrying: + def after_attempt(retry_state: tenacity.RetryCallState) -> None: + self._logger.warning( + f"{self.__class__.__name__}: connection failure " + f"(attempt {retry_state.attempt_number}): " + f"{retry_state.outcome.exception()}", + ) + + return tenacity.AsyncRetrying( + stop=self.retry_settings.stop, + wait=self.retry_settings.wait, + retry=tenacity.retry_if_exception_type((OSError, InterfaceError)), + after=after_attempt, + sleep=anyio.sleep, + reraise=True, + ) async def start(self) -> None: await super().start() - self._listen_cancel_scope = await self._task_group.start( - self._listen_notifications + self._send = cast( + MemoryObjectSendStream[str], + await self._task_group.start(self._listen_notifications), ) async def stop(self, *, force: bool = False) -> None: - self._listen_cancel_scope.cancel() + self._send.close() await super().stop(force=force) + self._logger.info("Stopped event broker") async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None: - def callback(connection, pid, channel: str, payload: str) -> None: + def callback( + connection: Connection, pid: int, channel: str, payload: str + ) -> None: event = self.reconstitute_event_str(payload) if event is not None: self._task_group.start_soon(self.publish_local, event) task_started_sent = False - with CancelScope() as cancel_scope: - while True: - async with self.connection_factory() as conn: + send, receive = create_memory_object_stream(100, str) + while True: + async with AsyncExitStack() as exit_stack: + async for attempt in self._retry(): + with attempt: + conn = await self.connection_factory() + + exit_stack.push_async_callback(conn.close) + self._logger.info("Connection established") + try: await conn.add_listener(self.channel, callback) + exit_stack.push_async_callback( + conn.remove_listener, self.channel, callback + ) if not task_started_sent: - task_status.started(cancel_scope) + task_status.started(send) task_started_sent = True - try: - while True: - await sleep(self.max_idle_time) + while True: + notification: str | None = None + with move_on_after(self.max_idle_time): + try: + notification = await receive.receive() + except EndOfStream: + self._logger.info("Stream finished") + return + + if notification: + await conn.execute( + "SELECT pg_notify($1, $2)", self.channel, notification + ) + else: await conn.execute("SELECT 1") - finally: - await conn.remove_listener(self.channel, callback) + except InterfaceError as exc: + self._logger.error("Connection error: %s", exc) async def publish(self, event: Event) -> None: notification = self.generate_notification_str(event) @@ -121,6 +192,4 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): "Serialized event object exceeds 7999 bytes in size" ) - async with self.connection_factory() as conn: - await conn.execute("SELECT pg_notify($1, $2)", self.channel, notification) - return + await self._send.send(notification) -- cgit v1.2.1