summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlex Grönholm <alex.gronholm@nextday.fi>2022-09-03 13:42:45 +0300
committerAlex Grönholm <alex.gronholm@nextday.fi>2022-09-03 21:34:40 +0300
commit02dd92a159b9ce99fe725949ab955d02f9450a75 (patch)
treefaf39e3efb939786d88832de89ea9b0e4515c50b /src
parent38e257e4b30f6affadebf60b6c7a0ee0282d9fe3 (diff)
downloadapscheduler-02dd92a159b9ce99fe725949ab955d02f9450a75.tar.gz
Improved the asyncpg event broker
* It now recovers from a server disconnection
* Replaced from_asyncpg_pool() with from_dsn()
* Only copy the connection options from the async engine in from_async_sqla_engine()
Diffstat (limited to 'src')
-rw-r--r--src/apscheduler/eventbrokers/asyncpg.py147
1 file changed, 108 insertions, 39 deletions
diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py
index 5666392..4ee2339 100644
--- a/src/apscheduler/eventbrokers/asyncpg.py
+++ b/src/apscheduler/eventbrokers/asyncpg.py
@@ -1,13 +1,24 @@
from __future__ import annotations
-from contextlib import asynccontextmanager
-from typing import TYPE_CHECKING, AsyncContextManager, AsyncGenerator, Callable
+from collections.abc import Awaitable, Mapping
+from contextlib import AsyncExitStack
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, cast
+import anyio
+import asyncpg
import attrs
-from anyio import TASK_STATUS_IGNORED, CancelScope, sleep
-from asyncpg import Connection
-from asyncpg.pool import Pool
-
+import tenacity
+from anyio import (
+ TASK_STATUS_IGNORED,
+ EndOfStream,
+ create_memory_object_stream,
+ move_on_after,
+)
+from anyio.streams.memory import MemoryObjectSendStream
+from asyncpg import Connection, InterfaceError
+
+from .. import RetrySettings
from .._events import Event
from .._exceptions import SerializationError
from ..abc import Serializer
@@ -27,41 +38,54 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
.. _asyncpg: https://pypi.org/project/asyncpg/
- :param connection_factory: a callable that creates an async context manager that
- yields a new asyncpg connection
+ :param connection_factory: a callable that creates an asyncpg connection
:param serializer: the serializer used to (de)serialize events for transport
:param channel: the ``NOTIFY`` channel to use
:param max_idle_time: maximum time to let the connection go idle, before sending a
``SELECT 1`` query to prevent a connection timeout
"""
- connection_factory: Callable[[], AsyncContextManager[Connection]]
+ connection_factory: Callable[[], Awaitable[Connection]]
serializer: Serializer = attrs.field(kw_only=True, factory=JSONSerializer)
channel: str = attrs.field(kw_only=True, default="apscheduler")
- max_idle_time: float = attrs.field(kw_only=True, default=30)
- _listen_cancel_scope: CancelScope = attrs.field(init=False)
+ retry_settings: RetrySettings = attrs.field(default=RetrySettings())
+ max_idle_time: float = attrs.field(kw_only=True, default=10)
+ _send: MemoryObjectSendStream[str] = attrs.field(init=False)
@classmethod
- def from_asyncpg_pool(cls, pool: Pool, **kwargs) -> AsyncpgEventBroker:
+ def from_dsn(
+ cls, dsn: str, options: Mapping[str, Any] | None = None, **kwargs: Any
+ ) -> AsyncpgEventBroker:
"""
Create a new asyncpg event broker from an existing asyncpg connection pool.
- :param pool: an asyncpg connection pool
+ :param dsn: data source name, passed as first positional argument to
+ :func:`asyncpg.connect`
+ :param options: keyword arguments passed to :func:`asyncpg.connect`
:param kwargs: keyword arguments to pass to the initializer of this class
:return: the newly created event broker
"""
- return cls(pool.acquire, **kwargs)
+ factory = partial(asyncpg.connect, dsn, **(options or {}))
+ return cls(factory, **kwargs)
@classmethod
def from_async_sqla_engine(
- cls, engine: AsyncEngine, **kwargs
+ cls,
+ engine: AsyncEngine,
+ options: Mapping[str, Any] | None = None,
+ **kwargs: Any,
) -> AsyncpgEventBroker:
"""
Create a new asyncpg event broker from an SQLAlchemy engine.
+ The engine will only be used to create the appropriate options for
+ :func:`asyncpg.connect`.
+
:param engine: an asynchronous SQLAlchemy engine using asyncpg as the driver
:type engine: ~sqlalchemy.ext.asyncio.AsyncEngine
+ :param options: extra keyword arguments passed to :func:`asyncpg.connect` (will
+ override any automatically generated arguments based on the engine)
:param kwargs: keyword arguments to pass to the initializer of this class
:return: the newly created event broker
@@ -72,47 +96,94 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
f"{engine.dialect.driver})"
)
- @asynccontextmanager
- async def connection_factory() -> AsyncGenerator[Connection, None]:
- conn = await engine.raw_connection()
- try:
- yield conn.connection._connection
- finally:
- conn.close()
+ connect_args = dict(engine.url.query)
+ for optname in ("host", "port", "database", "username", "password"):
+ value = getattr(engine.url, optname)
+ if value is not None:
+ if optname == "username":
+ optname = "user"
+
+ connect_args[optname] = value
+
+ if options:
+ connect_args |= options
- return cls(connection_factory, **kwargs)
+ factory = partial(asyncpg.connect, **connect_args)
+ return cls(factory, **kwargs)
+
+ def _retry(self) -> tenacity.AsyncRetrying:
+ def after_attempt(retry_state: tenacity.RetryCallState) -> None:
+ self._logger.warning(
+ f"{self.__class__.__name__}: connection failure "
+ f"(attempt {retry_state.attempt_number}): "
+ f"{retry_state.outcome.exception()}",
+ )
+
+ return tenacity.AsyncRetrying(
+ stop=self.retry_settings.stop,
+ wait=self.retry_settings.wait,
+ retry=tenacity.retry_if_exception_type((OSError, InterfaceError)),
+ after=after_attempt,
+ sleep=anyio.sleep,
+ reraise=True,
+ )
async def start(self) -> None:
await super().start()
- self._listen_cancel_scope = await self._task_group.start(
- self._listen_notifications
+ self._send = cast(
+ MemoryObjectSendStream[str],
+ await self._task_group.start(self._listen_notifications),
)
async def stop(self, *, force: bool = False) -> None:
- self._listen_cancel_scope.cancel()
+ self._send.close()
await super().stop(force=force)
+ self._logger.info("Stopped event broker")
async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None:
- def callback(connection, pid, channel: str, payload: str) -> None:
+ def callback(
+ connection: Connection, pid: int, channel: str, payload: str
+ ) -> None:
event = self.reconstitute_event_str(payload)
if event is not None:
self._task_group.start_soon(self.publish_local, event)
task_started_sent = False
- with CancelScope() as cancel_scope:
- while True:
- async with self.connection_factory() as conn:
+ send, receive = create_memory_object_stream(100, str)
+ while True:
+ async with AsyncExitStack() as exit_stack:
+ async for attempt in self._retry():
+ with attempt:
+ conn = await self.connection_factory()
+
+ exit_stack.push_async_callback(conn.close)
+ self._logger.info("Connection established")
+ try:
await conn.add_listener(self.channel, callback)
+ exit_stack.push_async_callback(
+ conn.remove_listener, self.channel, callback
+ )
if not task_started_sent:
- task_status.started(cancel_scope)
+ task_status.started(send)
task_started_sent = True
- try:
- while True:
- await sleep(self.max_idle_time)
+ while True:
+ notification: str | None = None
+ with move_on_after(self.max_idle_time):
+ try:
+ notification = await receive.receive()
+ except EndOfStream:
+ self._logger.info("Stream finished")
+ return
+
+ if notification:
+ await conn.execute(
+ "SELECT pg_notify($1, $2)", self.channel, notification
+ )
+ else:
await conn.execute("SELECT 1")
- finally:
- await conn.remove_listener(self.channel, callback)
+ except InterfaceError as exc:
+ self._logger.error("Connection error: %s", exc)
async def publish(self, event: Event) -> None:
notification = self.generate_notification_str(event)
@@ -121,6 +192,4 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
"Serialized event object exceeds 7999 bytes in size"
)
- async with self.connection_factory() as conn:
- await conn.execute("SELECT pg_notify($1, $2)", self.channel, notification)
- return
+ await self._send.send(notification)