Diffstat (limited to 'src/apscheduler/datastores')
-rw-r--r-- | src/apscheduler/datastores/async_adapter.py | 101
-rw-r--r-- | src/apscheduler/datastores/async_sqlalchemy.py | 602
-rw-r--r-- | src/apscheduler/datastores/base.py | 58
-rw-r--r-- | src/apscheduler/datastores/memory.py | 62
-rw-r--r-- | src/apscheduler/datastores/mongodb.py | 109
-rw-r--r-- | src/apscheduler/datastores/sqlalchemy.py | 865
6 files changed, 586 insertions, 1211 deletions
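The headline change in this set is the removal of the separate synchronous/asynchronous data store implementations (async_adapter.py and async_sqlalchemy.py) in favour of a single async-only API: every data store method becomes a coroutine, events are published through the shared event broker passed to start(), and SQLAlchemyDataStore now accepts either a synchronous Engine or an AsyncEngine, bridging sync engines to the async interface through worker threads. The sketch below illustrates that bridging pattern as it appears in the new sqlalchemy.py hunks; it is a minimal standalone illustration, not APScheduler's public API, and the UnifiedStore name and the example query are placeholders.

from __future__ import annotations

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

import anyio
from anyio import to_thread
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine


class UnifiedStore:
    """Sketch: one store class serving both sync and async SQLAlchemy engines."""

    def __init__(self, engine: Engine | AsyncEngine) -> None:
        self.engine = engine

    @asynccontextmanager
    async def _begin_transaction(
        self,
    ) -> AsyncGenerator[Connection | AsyncConnection, None]:
        if isinstance(self.engine, AsyncEngine):
            # Native async engine: use the async context manager directly
            async with self.engine.begin() as conn:
                yield conn
        else:
            # Sync engine: drive the blocking context manager in a worker thread
            cm = self.engine.begin()
            conn = await to_thread.run_sync(cm.__enter__)
            try:
                yield conn
            except BaseException as exc:
                await to_thread.run_sync(cm.__exit__, type(exc), exc, exc.__traceback__)
                raise
            else:
                await to_thread.run_sync(cm.__exit__, None, None, None)

    async def _execute(self, conn: Connection | AsyncConnection, statement):
        # Dispatch on the connection type so callers always await the result
        if isinstance(conn, AsyncConnection):
            return await conn.execute(statement)
        return await to_thread.run_sync(conn.execute, statement)


async def main() -> None:
    # Works the same way with create_async_engine("postgresql+asyncpg://...")
    store = UnifiedStore(create_engine("sqlite://"))
    async with store._begin_transaction() as conn:
        result = await store._execute(conn, text("SELECT 1"))
        print(result.scalar())


anyio.run(main)

With this pattern the rest of the store code is written once against the awaitable helpers, which is why the dedicated AsyncDataStoreAdapter and AsyncSQLAlchemyDataStore classes below could be deleted outright.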
diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py deleted file mode 100644 index d16ae56..0000000 --- a/src/apscheduler/datastores/async_adapter.py +++ /dev/null @@ -1,101 +0,0 @@ -from __future__ import annotations - -import sys -from datetime import datetime -from typing import Iterable -from uuid import UUID - -import attrs -from anyio import to_thread -from anyio.from_thread import BlockingPortal - -from .._enums import ConflictPolicy -from .._structures import Job, JobResult, Schedule, Task -from ..abc import AsyncEventBroker, DataStore -from ..eventbrokers.async_adapter import AsyncEventBrokerAdapter, SyncEventBrokerAdapter -from .base import BaseAsyncDataStore - - -@attrs.define(eq=False) -class AsyncDataStoreAdapter(BaseAsyncDataStore): - original: DataStore - _portal: BlockingPortal = attrs.field(init=False) - - async def start(self, event_broker: AsyncEventBroker) -> None: - await super().start(event_broker) - - self._portal = BlockingPortal() - await self._portal.__aenter__() - - if isinstance(event_broker, AsyncEventBrokerAdapter): - sync_event_broker = event_broker.original - else: - sync_event_broker = SyncEventBrokerAdapter(event_broker, self._portal) - - try: - await to_thread.run_sync(lambda: self.original.start(sync_event_broker)) - except BaseException: - await self._portal.__aexit__(*sys.exc_info()) - raise - - async def stop(self, *, force: bool = False) -> None: - try: - await to_thread.run_sync(lambda: self.original.stop(force=force)) - finally: - await self._portal.__aexit__(None, None, None) - await super().stop(force=force) - - async def add_task(self, task: Task) -> None: - await to_thread.run_sync(self.original.add_task, task) - - async def remove_task(self, task_id: str) -> None: - await to_thread.run_sync(self.original.remove_task, task_id) - - async def get_task(self, task_id: str) -> Task: - return await to_thread.run_sync(self.original.get_task, task_id) - - async def get_tasks(self) -> list[Task]: - return await to_thread.run_sync(self.original.get_tasks) - - async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: - return await to_thread.run_sync(self.original.get_schedules, ids) - - async def add_schedule( - self, schedule: Schedule, conflict_policy: ConflictPolicy - ) -> None: - await to_thread.run_sync(self.original.add_schedule, schedule, conflict_policy) - - async def remove_schedules(self, ids: Iterable[str]) -> None: - await to_thread.run_sync(self.original.remove_schedules, ids) - - async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: - return await to_thread.run_sync( - self.original.acquire_schedules, scheduler_id, limit - ) - - async def release_schedules( - self, scheduler_id: str, schedules: list[Schedule] - ) -> None: - await to_thread.run_sync( - self.original.release_schedules, scheduler_id, schedules - ) - - async def get_next_schedule_run_time(self) -> datetime | None: - return await to_thread.run_sync(self.original.get_next_schedule_run_time) - - async def add_job(self, job: Job) -> None: - await to_thread.run_sync(self.original.add_job, job) - - async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: - return await to_thread.run_sync(self.original.get_jobs, ids) - - async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: - return await to_thread.run_sync(self.original.acquire_jobs, worker_id, limit) - - async def release_job( - self, worker_id: str, task_id: str, result: JobResult - ) 
-> None: - await to_thread.run_sync(self.original.release_job, worker_id, task_id, result) - - async def get_job_result(self, job_id: UUID) -> JobResult | None: - return await to_thread.run_sync(self.original.get_job_result, job_id) diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py deleted file mode 100644 index 0c165b8..0000000 --- a/src/apscheduler/datastores/async_sqlalchemy.py +++ /dev/null @@ -1,602 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict -from datetime import datetime, timedelta, timezone -from typing import Any, Iterable -from uuid import UUID - -import anyio -import attrs -import sniffio -import tenacity -from sqlalchemy import and_, bindparam, or_, select -from sqlalchemy.engine import URL, Result -from sqlalchemy.exc import IntegrityError, InterfaceError -from sqlalchemy.ext.asyncio import create_async_engine -from sqlalchemy.ext.asyncio.engine import AsyncEngine -from sqlalchemy.sql.ddl import DropTable -from sqlalchemy.sql.elements import BindParameter - -from .._enums import ConflictPolicy -from .._events import ( - DataStoreEvent, - JobAcquired, - JobAdded, - JobDeserializationFailed, - ScheduleAdded, - ScheduleDeserializationFailed, - ScheduleRemoved, - ScheduleUpdated, - TaskAdded, - TaskRemoved, - TaskUpdated, -) -from .._exceptions import ConflictingIdError, SerializationError, TaskLookupError -from .._structures import Job, JobResult, Schedule, Task -from ..abc import AsyncEventBroker -from ..marshalling import callable_to_ref -from .base import BaseAsyncDataStore -from .sqlalchemy import _BaseSQLAlchemyDataStore - - -@attrs.define(eq=False) -class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseAsyncDataStore): - """ - Uses a relational database to store data. - - When started, this data store creates the appropriate tables on the given database - if they're not already present. - - Operations are retried (in accordance to ``retry_settings``) when an operation - raises :exc:`sqlalchemy.OperationalError`. - - This store has been tested to work with PostgreSQL (asyncpg driver) and MySQL - (asyncmy driver). - - :param engine: an asynchronous SQLAlchemy engine - :param schema: a database schema name to use, if not the default - :param serializer: the serializer used to (de)serialize tasks, schedules and jobs - for storage - :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler - or worker can keep a lock on a schedule or task - :param retry_settings: Tenacity settings for retrying operations in case of a - database connecitivty problem - :param start_from_scratch: erase all existing data during startup (useful for test - suites) - """ - - engine: AsyncEngine - - @classmethod - def from_url(cls, url: str | URL, **options) -> AsyncSQLAlchemyDataStore: - """ - Create a new asynchronous SQLAlchemy data store. 
- - :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine` - (must use an async dialect like ``asyncpg`` or ``asyncmy``) - :param kwargs: keyword arguments to pass to the initializer of this class - :return: the newly created data store - - """ - engine = create_async_engine(url, future=True) - return cls(engine, **options) - - def _retry(self) -> tenacity.AsyncRetrying: - # OSError is raised by asyncpg if it can't connect - return tenacity.AsyncRetrying( - stop=self.retry_settings.stop, - wait=self.retry_settings.wait, - retry=tenacity.retry_if_exception_type((InterfaceError, OSError)), - after=self._after_attempt, - sleep=anyio.sleep, - reraise=True, - ) - - async def start(self, event_broker: AsyncEventBroker) -> None: - await super().start(event_broker) - - asynclib = sniffio.current_async_library() or "(unknown)" - if asynclib != "asyncio": - raise RuntimeError( - f"This data store requires asyncio; currently running: {asynclib}" - ) - - # Verify that the schema is in place - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - if self.start_from_scratch: - for table in self._metadata.sorted_tables: - await conn.execute(DropTable(table, if_exists=True)) - - await conn.run_sync(self._metadata.create_all) - query = select(self.t_metadata.c.schema_version) - result = await conn.execute(query) - version = result.scalar() - if version is None: - await conn.execute( - self.t_metadata.insert(values={"schema_version": 1}) - ) - elif version > 1: - raise RuntimeError( - f"Unexpected schema version ({version}); " - f"only version 1 is supported by this version of " - f"APScheduler" - ) - - async def _deserialize_schedules(self, result: Result) -> list[Schedule]: - schedules: list[Schedule] = [] - for row in result: - try: - schedules.append(Schedule.unmarshal(self.serializer, row._asdict())) - except SerializationError as exc: - await self._events.publish( - ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc) - ) - - return schedules - - async def _deserialize_jobs(self, result: Result) -> list[Job]: - jobs: list[Job] = [] - for row in result: - try: - jobs.append(Job.unmarshal(self.serializer, row._asdict())) - except SerializationError as exc: - await self._events.publish( - JobDeserializationFailed(job_id=row["id"], exception=exc) - ) - - return jobs - - async def add_task(self, task: Task) -> None: - insert = self.t_tasks.insert().values( - id=task.id, - func=callable_to_ref(task.func), - max_running_jobs=task.max_running_jobs, - misfire_grace_time=task.misfire_grace_time, - ) - try: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - await conn.execute(insert) - except IntegrityError: - update = ( - self.t_tasks.update() - .values( - func=callable_to_ref(task.func), - max_running_jobs=task.max_running_jobs, - misfire_grace_time=task.misfire_grace_time, - ) - .where(self.t_tasks.c.id == task.id) - ) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - await conn.execute(update) - - await self._events.publish(TaskUpdated(task_id=task.id)) - else: - await self._events.publish(TaskAdded(task_id=task.id)) - - async def remove_task(self, task_id: str) -> None: - delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - result = await conn.execute(delete) - if result.rowcount == 0: - raise TaskLookupError(task_id) - else: - 
await self._events.publish(TaskRemoved(task_id=task_id)) - - async def get_task(self, task_id: str) -> Task: - query = select( - [ - self.t_tasks.c.id, - self.t_tasks.c.func, - self.t_tasks.c.max_running_jobs, - self.t_tasks.c.state, - self.t_tasks.c.misfire_grace_time, - ] - ).where(self.t_tasks.c.id == task_id) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - result = await conn.execute(query) - row = result.first() - - if row: - return Task.unmarshal(self.serializer, row._asdict()) - else: - raise TaskLookupError(task_id) - - async def get_tasks(self) -> list[Task]: - query = select( - [ - self.t_tasks.c.id, - self.t_tasks.c.func, - self.t_tasks.c.max_running_jobs, - self.t_tasks.c.state, - self.t_tasks.c.misfire_grace_time, - ] - ).order_by(self.t_tasks.c.id) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - result = await conn.execute(query) - tasks = [ - Task.unmarshal(self.serializer, row._asdict()) for row in result - ] - return tasks - - async def add_schedule( - self, schedule: Schedule, conflict_policy: ConflictPolicy - ) -> None: - event: DataStoreEvent - values = schedule.marshal(self.serializer) - insert = self.t_schedules.insert().values(**values) - try: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - await conn.execute(insert) - except IntegrityError: - if conflict_policy is ConflictPolicy.exception: - raise ConflictingIdError(schedule.id) from None - elif conflict_policy is ConflictPolicy.replace: - del values["id"] - update = ( - self.t_schedules.update() - .where(self.t_schedules.c.id == schedule.id) - .values(**values) - ) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - await conn.execute(update) - - event = ScheduleUpdated( - schedule_id=schedule.id, next_fire_time=schedule.next_fire_time - ) - await self._events.publish(event) - else: - event = ScheduleAdded( - schedule_id=schedule.id, next_fire_time=schedule.next_fire_time - ) - await self._events.publish(event) - - async def remove_schedules(self, ids: Iterable[str]) -> None: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - delete = self.t_schedules.delete().where( - self.t_schedules.c.id.in_(ids) - ) - if self._supports_update_returning: - delete = delete.returning(self.t_schedules.c.id) - removed_ids: Iterable[str] = [ - row[0] for row in await conn.execute(delete) - ] - else: - # TODO: actually check which rows were deleted? 
- await conn.execute(delete) - removed_ids = ids - - for schedule_id in removed_ids: - await self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) - - async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: - query = self.t_schedules.select().order_by(self.t_schedules.c.id) - if ids: - query = query.where(self.t_schedules.c.id.in_(ids)) - - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - result = await conn.execute(query) - return await self._deserialize_schedules(result) - - async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - schedules_cte = ( - select(self.t_schedules.c.id) - .where( - and_( - self.t_schedules.c.next_fire_time.isnot(None), - self.t_schedules.c.next_fire_time <= now, - or_( - self.t_schedules.c.acquired_until.is_(None), - self.t_schedules.c.acquired_until < now, - ), - ) - ) - .order_by(self.t_schedules.c.next_fire_time) - .limit(limit) - .with_for_update(skip_locked=True) - .cte() - ) - subselect = select([schedules_cte.c.id]) - update = ( - self.t_schedules.update() - .where(self.t_schedules.c.id.in_(subselect)) - .values(acquired_by=scheduler_id, acquired_until=acquired_until) - ) - if self._supports_update_returning: - update = update.returning(*self.t_schedules.columns) - result = await conn.execute(update) - else: - await conn.execute(update) - query = self.t_schedules.select().where( - and_(self.t_schedules.c.acquired_by == scheduler_id) - ) - result = await conn.execute(query) - - schedules = await self._deserialize_schedules(result) - - return schedules - - async def release_schedules( - self, scheduler_id: str, schedules: list[Schedule] - ) -> None: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - update_events: list[ScheduleUpdated] = [] - finished_schedule_ids: list[str] = [] - update_args: list[dict[str, Any]] = [] - for schedule in schedules: - if schedule.next_fire_time is not None: - try: - serialized_trigger = self.serializer.serialize( - schedule.trigger - ) - except SerializationError: - self._logger.exception( - "Error serializing trigger for schedule %r – " - "removing from data store", - schedule.id, - ) - finished_schedule_ids.append(schedule.id) - continue - - update_args.append( - { - "p_id": schedule.id, - "p_trigger": serialized_trigger, - "p_next_fire_time": schedule.next_fire_time, - } - ) - else: - finished_schedule_ids.append(schedule.id) - - # Update schedules that have a next fire time - if update_args: - p_id: BindParameter = bindparam("p_id") - p_trigger: BindParameter = bindparam("p_trigger") - p_next_fire_time: BindParameter = bindparam("p_next_fire_time") - update = ( - self.t_schedules.update() - .where( - and_( - self.t_schedules.c.id == p_id, - self.t_schedules.c.acquired_by == scheduler_id, - ) - ) - .values( - trigger=p_trigger, - next_fire_time=p_next_fire_time, - acquired_by=None, - acquired_until=None, - ) - ) - next_fire_times = { - arg["p_id"]: arg["p_next_fire_time"] for arg in update_args - } - # TODO: actually check which rows were updated? 
- await conn.execute(update, update_args) - updated_ids = list(next_fire_times) - - for schedule_id in updated_ids: - event = ScheduleUpdated( - schedule_id=schedule_id, - next_fire_time=next_fire_times[schedule_id], - ) - update_events.append(event) - - # Remove schedules that have no next fire time or failed to - # serialize - if finished_schedule_ids: - delete = self.t_schedules.delete().where( - self.t_schedules.c.id.in_(finished_schedule_ids) - ) - await conn.execute(delete) - - for event in update_events: - await self._events.publish(event) - - for schedule_id in finished_schedule_ids: - await self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) - - async def get_next_schedule_run_time(self) -> datetime | None: - statenent = ( - select(self.t_schedules.c.next_fire_time) - .where(self.t_schedules.c.next_fire_time.isnot(None)) - .order_by(self.t_schedules.c.next_fire_time) - .limit(1) - ) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - result = await conn.execute(statenent) - return result.scalar() - - async def add_job(self, job: Job) -> None: - marshalled = job.marshal(self.serializer) - insert = self.t_jobs.insert().values(**marshalled) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - await conn.execute(insert) - - event = JobAdded( - job_id=job.id, - task_id=job.task_id, - schedule_id=job.schedule_id, - tags=job.tags, - ) - await self._events.publish(event) - - async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: - query = self.t_jobs.select().order_by(self.t_jobs.c.id) - if ids: - job_ids = [job_id for job_id in ids] - query = query.where(self.t_jobs.c.id.in_(job_ids)) - - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - result = await conn.execute(query) - return await self._deserialize_jobs(result) - - async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - query = ( - self.t_jobs.select() - .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id) - .where( - or_( - self.t_jobs.c.acquired_until.is_(None), - self.t_jobs.c.acquired_until < now, - ) - ) - .order_by(self.t_jobs.c.created_at) - .with_for_update(skip_locked=True) - .limit(limit) - ) - - result = await conn.execute(query) - if not result: - return [] - - # Mark the jobs as acquired by this worker - jobs = await self._deserialize_jobs(result) - task_ids: set[str] = {job.task_id for job in jobs} - - # Retrieve the limits - query = select( - [ - self.t_tasks.c.id, - self.t_tasks.c.max_running_jobs - - self.t_tasks.c.running_jobs, - ] - ).where( - self.t_tasks.c.max_running_jobs.isnot(None), - self.t_tasks.c.id.in_(task_ids), - ) - result = await conn.execute(query) - job_slots_left: dict[str, int] = dict(result.fetchall()) - - # Filter out jobs that don't have free slots - acquired_jobs: list[Job] = [] - increments: dict[str, int] = defaultdict(lambda: 0) - for job in jobs: - # Don't acquire the job if there are no free slots left - slots_left = job_slots_left.get(job.task_id) - if slots_left == 0: - continue - elif slots_left is not None: - job_slots_left[job.task_id] -= 1 - - acquired_jobs.append(job) - increments[job.task_id] += 1 - - if acquired_jobs: - # Mark the acquired jobs as acquired by this worker - 
acquired_job_ids = [job.id for job in acquired_jobs] - update = ( - self.t_jobs.update() - .values( - acquired_by=worker_id, acquired_until=acquired_until - ) - .where(self.t_jobs.c.id.in_(acquired_job_ids)) - ) - await conn.execute(update) - - # Increment the running job counters on each task - p_id: BindParameter = bindparam("p_id") - p_increment: BindParameter = bindparam("p_increment") - params = [ - {"p_id": task_id, "p_increment": increment} - for task_id, increment in increments.items() - ] - update = ( - self.t_tasks.update() - .values( - running_jobs=self.t_tasks.c.running_jobs + p_increment - ) - .where(self.t_tasks.c.id == p_id) - ) - await conn.execute(update, params) - - # Publish the appropriate events - for job in acquired_jobs: - await self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) - - return acquired_jobs - - async def release_job( - self, worker_id: str, task_id: str, result: JobResult - ) -> None: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - # Record the job result - if result.expires_at > result.finished_at: - marshalled = result.marshal(self.serializer) - insert = self.t_job_results.insert().values(**marshalled) - await conn.execute(insert) - - # Decrement the number of running jobs for this task - update = ( - self.t_tasks.update() - .values(running_jobs=self.t_tasks.c.running_jobs - 1) - .where(self.t_tasks.c.id == task_id) - ) - await conn.execute(update) - - # Delete the job - delete = self.t_jobs.delete().where( - self.t_jobs.c.id == result.job_id - ) - await conn.execute(delete) - - async def get_job_result(self, job_id: UUID) -> JobResult | None: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - # Retrieve the result - query = self.t_job_results.select().where( - self.t_job_results.c.job_id == job_id - ) - row = (await conn.execute(query)).first() - - # Delete the result - delete = self.t_job_results.delete().where( - self.t_job_results.c.job_id == job_id - ) - await conn.execute(delete) - - return ( - JobResult.unmarshal(self.serializer, row._asdict()) - if row - else None - ) diff --git a/src/apscheduler/datastores/base.py b/src/apscheduler/datastores/base.py index c05d28c..5c7ef7d 100644 --- a/src/apscheduler/datastores/base.py +++ b/src/apscheduler/datastores/base.py @@ -1,37 +1,47 @@ from __future__ import annotations -from apscheduler.abc import ( - AsyncDataStore, - AsyncEventBroker, - DataStore, - EventBroker, - EventSource, -) +from contextlib import AsyncExitStack +from logging import Logger, getLogger +import attrs +from .._retry import RetryMixin +from ..abc import DataStore, EventBroker, Serializer +from ..serializers.pickle import PickleSerializer + + +@attrs.define(kw_only=True) class BaseDataStore(DataStore): - _events: EventBroker + """ + Base class for data stores. 
- def start(self, event_broker: EventBroker) -> None: - self._events = event_broker + :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler + or worker can keep a lock on a schedule or task + """ - def stop(self, *, force: bool = False) -> None: - del self._events + lock_expiration_delay: float = 30 + _event_broker: EventBroker = attrs.field(init=False) + _logger: Logger = attrs.field(init=False) - @property - def events(self) -> EventSource: - return self._events + async def start( + self, exit_stack: AsyncExitStack, event_broker: EventBroker + ) -> None: + self._event_broker = event_broker + def __attrs_post_init__(self): + self._logger = getLogger(self.__class__.__name__) -class BaseAsyncDataStore(AsyncDataStore): - _events: AsyncEventBroker - async def start(self, event_broker: AsyncEventBroker) -> None: - self._events = event_broker +@attrs.define(kw_only=True) +class BaseExternalDataStore(BaseDataStore, RetryMixin): + """ + Base class for data stores using an external service such as a database. - async def stop(self, *, force: bool = False) -> None: - del self._events + :param serializer: the serializer used to (de)serialize tasks, schedules and jobs + for storage + :param start_from_scratch: erase all existing data during startup (useful for test + suites) + """ - @property - def events(self) -> EventSource: - return self._events + serializer: Serializer = attrs.field(factory=PickleSerializer) + start_from_scratch: bool = attrs.field(default=False) diff --git a/src/apscheduler/datastores/memory.py b/src/apscheduler/datastores/memory.py index a9ff3cb..fd7e90e 100644 --- a/src/apscheduler/datastores/memory.py +++ b/src/apscheduler/datastores/memory.py @@ -85,13 +85,9 @@ class MemoryDataStore(BaseDataStore): """ Stores scheduler data in memory, without serializing it. - Can be shared between multiple schedulers and workers within the same event loop. - - :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler - or worker can keep a lock on a schedule or task + Can be shared between multiple schedulers within the same event loop. 
""" - lock_expiration_delay: float = 30 _tasks: dict[str, TaskState] = attrs.Factory(dict) _schedules: list[ScheduleState] = attrs.Factory(list) _schedules_by_id: dict[str, ScheduleState] = attrs.Factory(dict) @@ -115,41 +111,43 @@ class MemoryDataStore(BaseDataStore): right_index = bisect_left(self._jobs, state) return self._jobs.index(state, left_index, right_index + 1) - def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: + async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: return [ state.schedule for state in self._schedules if ids is None or state.schedule.id in ids ] - def add_task(self, task: Task) -> None: + async def add_task(self, task: Task) -> None: task_exists = task.id in self._tasks self._tasks[task.id] = TaskState(task) if task_exists: - self._events.publish(TaskUpdated(task_id=task.id)) + await self._event_broker.publish(TaskUpdated(task_id=task.id)) else: - self._events.publish(TaskAdded(task_id=task.id)) + await self._event_broker.publish(TaskAdded(task_id=task.id)) - def remove_task(self, task_id: str) -> None: + async def remove_task(self, task_id: str) -> None: try: del self._tasks[task_id] except KeyError: raise TaskLookupError(task_id) from None - self._events.publish(TaskRemoved(task_id=task_id)) + await self._event_broker.publish(TaskRemoved(task_id=task_id)) - def get_task(self, task_id: str) -> Task: + async def get_task(self, task_id: str) -> Task: try: return self._tasks[task_id].task except KeyError: raise TaskLookupError(task_id) from None - def get_tasks(self) -> list[Task]: + async def get_tasks(self) -> list[Task]: return sorted( (state.task for state in self._tasks.values()), key=lambda task: task.id ) - def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + async def add_schedule( + self, schedule: Schedule, conflict_policy: ConflictPolicy + ) -> None: old_state = self._schedules_by_id.get(schedule.id) if old_state is not None: if conflict_policy is ConflictPolicy.do_nothing: @@ -175,17 +173,17 @@ class MemoryDataStore(BaseDataStore): schedule_id=schedule.id, next_fire_time=schedule.next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) - def remove_schedules(self, ids: Iterable[str]) -> None: + async def remove_schedules(self, ids: Iterable[str]) -> None: for schedule_id in ids: state = self._schedules_by_id.pop(schedule_id, None) if state: self._schedules.remove(state) event = ScheduleRemoved(schedule_id=state.schedule.id) - self._events.publish(event) + await self._event_broker.publish(event) - def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: + async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: now = datetime.now(timezone.utc) schedules: list[Schedule] = [] for state in self._schedules: @@ -206,7 +204,9 @@ class MemoryDataStore(BaseDataStore): return schedules - def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None: + async def release_schedules( + self, scheduler_id: str, schedules: list[Schedule] + ) -> None: # Send update events for schedules that have a next time finished_schedule_ids: list[str] = [] for s in schedules: @@ -224,17 +224,17 @@ class MemoryDataStore(BaseDataStore): event = ScheduleUpdated( schedule_id=s.id, next_fire_time=s.next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) else: finished_schedule_ids.append(s.id) # Remove schedules that didn't get a new next fire time - 
self.remove_schedules(finished_schedule_ids) + await self.remove_schedules(finished_schedule_ids) - def get_next_schedule_run_time(self) -> datetime | None: + async def get_next_schedule_run_time(self) -> datetime | None: return self._schedules[0].next_fire_time if self._schedules else None - def add_job(self, job: Job) -> None: + async def add_job(self, job: Job) -> None: state = JobState(job) self._jobs.append(state) self._jobs_by_id[job.id] = state @@ -246,15 +246,15 @@ class MemoryDataStore(BaseDataStore): schedule_id=job.schedule_id, tags=job.tags, ) - self._events.publish(event) + await self._event_broker.publish(event) - def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: + async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: if ids is not None: ids = frozenset(ids) return [state.job for state in self._jobs if ids is None or state.job.id in ids] - def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: + async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: now = datetime.now(timezone.utc) jobs: list[Job] = [] for _index, job_state in enumerate(self._jobs): @@ -290,11 +290,15 @@ class MemoryDataStore(BaseDataStore): # Publish the appropriate events for job in jobs: - self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) + await self._event_broker.publish( + JobAcquired(job_id=job.id, worker_id=worker_id) + ) return jobs - def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None: + async def release_job( + self, worker_id: str, task_id: str, result: JobResult + ) -> None: # Record the job result if result.expires_at > result.finished_at: self._job_results[result.job_id] = result @@ -310,5 +314,5 @@ class MemoryDataStore(BaseDataStore): index = self._find_job_index(job_state) del self._jobs[index] - def get_job_result(self, job_id: UUID) -> JobResult | None: + async def get_job_result(self, job_id: UUID) -> JobResult | None: return self._job_results.pop(job_id, None) diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py index 13299bb..4899f54 100644 --- a/src/apscheduler/datastores/mongodb.py +++ b/src/apscheduler/datastores/mongodb.py @@ -2,14 +2,13 @@ from __future__ import annotations import operator from collections import defaultdict +from contextlib import AsyncExitStack from datetime import datetime, timedelta, timezone -from logging import Logger, getLogger from typing import Any, Callable, ClassVar, Iterable from uuid import UUID import attrs import pymongo -import tenacity from attrs.validators import instance_of from bson import CodecOptions, UuidRepresentation from bson.codec_options import TypeEncoder, TypeRegistry @@ -35,10 +34,9 @@ from .._exceptions import ( SerializationError, TaskLookupError, ) -from .._structures import Job, JobResult, RetrySettings, Schedule, Task -from ..abc import EventBroker, Serializer -from ..serializers.pickle import PickleSerializer -from .base import BaseDataStore +from .._structures import Job, JobResult, Schedule, Task +from ..abc import EventBroker +from .base import BaseExternalDataStore class CustomEncoder(TypeEncoder): @@ -55,7 +53,7 @@ class CustomEncoder(TypeEncoder): @attrs.define(eq=False) -class MongoDBDataStore(BaseDataStore): +class MongoDBDataStore(BaseExternalDataStore): """ Uses a MongoDB server to store data. @@ -66,23 +64,11 @@ class MongoDBDataStore(BaseDataStore): raises :exc:`pymongo.errors.ConnectionFailure`. 
:param client: a PyMongo client - :param serializer: the serializer used to (de)serialize tasks, schedules and jobs - for storage :param database: name of the database to use - :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler - or worker can keep a lock on a schedule or task - :param retry_settings: Tenacity settings for retrying operations in case of a - database connecitivty problem - :param start_from_scratch: erase all existing data during startup (useful for test - suites) """ client: MongoClient = attrs.field(validator=instance_of(MongoClient)) - serializer: Serializer = attrs.field(factory=PickleSerializer, kw_only=True) database: str = attrs.field(default="apscheduler", kw_only=True) - lock_expiration_delay: float = attrs.field(default=30, kw_only=True) - retry_settings: RetrySettings = attrs.field(default=RetrySettings(), kw_only=True) - start_from_scratch: bool = attrs.field(default=False, kw_only=True) _task_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Task)] _schedule_attrs: ClassVar[list[str]] = [ @@ -90,8 +76,8 @@ class MongoDBDataStore(BaseDataStore): ] _job_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Job)] - _logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__)) _local_tasks: dict[str, Task] = attrs.field(init=False, factory=dict) + _temporary_failure_exceptions = (ConnectionFailure,) def __attrs_post_init__(self) -> None: type_registry = TypeRegistry( @@ -118,25 +104,10 @@ class MongoDBDataStore(BaseDataStore): client = MongoClient(uri) return cls(client, **options) - def _retry(self) -> tenacity.Retrying: - return tenacity.Retrying( - stop=self.retry_settings.stop, - wait=self.retry_settings.wait, - retry=tenacity.retry_if_exception_type(ConnectionFailure), - after=self._after_attempt, - reraise=True, - ) - - def _after_attempt(self, retry_state: tenacity.RetryCallState) -> None: - self._logger.warning( - "Temporary data store error (attempt %d): %s", - retry_state.attempt_number, - retry_state.outcome.exception(), - ) - - def start(self, event_broker: EventBroker) -> None: - super().start(event_broker) - + async def start( + self, exit_stack: AsyncExitStack, event_broker: EventBroker + ) -> None: + await super().start(exit_stack, event_broker) server_info = self.client.server_info() if server_info["versionArray"] < [4, 0]: raise RuntimeError( @@ -144,7 +115,7 @@ class MongoDBDataStore(BaseDataStore): f"{server_info['version']}" ) - for attempt in self._retry(): + async for attempt in self._retry(): with attempt, self.client.start_session() as session: if self.start_from_scratch: self._tasks.delete_many({}, session=session) @@ -159,7 +130,7 @@ class MongoDBDataStore(BaseDataStore): self._jobs_results.create_index("finished_at", session=session) self._jobs_results.create_index("expires_at", session=session) - def add_task(self, task: Task) -> None: + async def add_task(self, task: Task) -> None: for attempt in self._retry(): with attempt: previous = self._tasks.find_one_and_update( @@ -173,20 +144,20 @@ class MongoDBDataStore(BaseDataStore): self._local_tasks[task.id] = task if previous: - self._events.publish(TaskUpdated(task_id=task.id)) + await self._event_broker.publish(TaskUpdated(task_id=task.id)) else: - self._events.publish(TaskAdded(task_id=task.id)) + await self._event_broker.publish(TaskAdded(task_id=task.id)) - def remove_task(self, task_id: str) -> None: + async def remove_task(self, task_id: str) -> None: for attempt in self._retry(): with attempt: if not 
self._tasks.find_one_and_delete({"_id": task_id}): raise TaskLookupError(task_id) del self._local_tasks[task_id] - self._events.publish(TaskRemoved(task_id=task_id)) + await self._event_broker.publish(TaskRemoved(task_id=task_id)) - def get_task(self, task_id: str) -> Task: + async def get_task(self, task_id: str) -> Task: try: return self._local_tasks[task_id] except KeyError: @@ -205,7 +176,7 @@ class MongoDBDataStore(BaseDataStore): ) return task - def get_tasks(self) -> list[Task]: + async def get_tasks(self) -> list[Task]: for attempt in self._retry(): with attempt: tasks: list[Task] = [] @@ -217,7 +188,7 @@ class MongoDBDataStore(BaseDataStore): return tasks - def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: + async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: filters = {"_id": {"$in": list(ids)}} if ids is not None else {} for attempt in self._retry(): with attempt: @@ -237,7 +208,9 @@ class MongoDBDataStore(BaseDataStore): return schedules - def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + async def add_schedule( + self, schedule: Schedule, conflict_policy: ConflictPolicy + ) -> None: event: DataStoreEvent document = schedule.marshal(self.serializer) document["_id"] = document.pop("id") @@ -258,14 +231,14 @@ class MongoDBDataStore(BaseDataStore): event = ScheduleUpdated( schedule_id=schedule.id, next_fire_time=schedule.next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) else: event = ScheduleAdded( schedule_id=schedule.id, next_fire_time=schedule.next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) - def remove_schedules(self, ids: Iterable[str]) -> None: + async def remove_schedules(self, ids: Iterable[str]) -> None: filters = {"_id": {"$in": list(ids)}} if ids is not None else {} for attempt in self._retry(): with attempt, self.client.start_session() as session: @@ -277,9 +250,9 @@ class MongoDBDataStore(BaseDataStore): self._schedules.delete_many(filters, session=session) for schedule_id in ids: - self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id)) - def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: + async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: for attempt in self._retry(): with attempt, self.client.start_session() as session: schedules: list[Schedule] = [] @@ -318,7 +291,9 @@ class MongoDBDataStore(BaseDataStore): return schedules - def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None: + async def release_schedules( + self, scheduler_id: str, schedules: list[Schedule] + ) -> None: updated_schedules: list[tuple[str, datetime]] = [] finished_schedule_ids: list[str] = [] @@ -365,12 +340,12 @@ class MongoDBDataStore(BaseDataStore): event = ScheduleUpdated( schedule_id=schedule_id, next_fire_time=next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) for schedule_id in finished_schedule_ids: - self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id)) - def get_next_schedule_run_time(self) -> datetime | None: + async def get_next_schedule_run_time(self) -> datetime | None: for attempt in self._retry(): with attempt: document = self._schedules.find_one( @@ -384,7 +359,7 @@ class MongoDBDataStore(BaseDataStore): else: return None - 
def add_job(self, job: Job) -> None: + async def add_job(self, job: Job) -> None: document = job.marshal(self.serializer) document["_id"] = document.pop("id") for attempt in self._retry(): @@ -397,9 +372,9 @@ class MongoDBDataStore(BaseDataStore): schedule_id=job.schedule_id, tags=job.tags, ) - self._events.publish(event) + await self._event_broker.publish(event) - def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: + async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: filters = {"_id": {"$in": list(ids)}} if ids is not None else {} for attempt in self._retry(): with attempt: @@ -419,7 +394,7 @@ class MongoDBDataStore(BaseDataStore): return jobs - def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: + async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: for attempt in self._retry(): with attempt, self.client.start_session() as session: cursor = self._jobs.find( @@ -488,11 +463,15 @@ class MongoDBDataStore(BaseDataStore): # Publish the appropriate events for job in acquired_jobs: - self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) + await self._event_broker.publish( + JobAcquired(job_id=job.id, worker_id=worker_id) + ) return acquired_jobs - def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None: + async def release_job( + self, worker_id: str, task_id: str, result: JobResult + ) -> None: for attempt in self._retry(): with attempt, self.client.start_session() as session: # Record the job result @@ -509,7 +488,7 @@ class MongoDBDataStore(BaseDataStore): # Delete the job self._jobs.delete_one({"_id": result.job_id}, session=session) - def get_job_result(self, job_id: UUID) -> JobResult | None: + async def get_job_result(self, job_id: UUID) -> JobResult | None: for attempt in self._retry(): with attempt: document = self._jobs_results.find_one_and_delete({"_id": job_id}) diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py index ebb076f..00d06aa 100644 --- a/src/apscheduler/datastores/sqlalchemy.py +++ b/src/apscheduler/datastores/sqlalchemy.py @@ -1,16 +1,20 @@ from __future__ import annotations +import sys from collections import defaultdict +from collections.abc import AsyncGenerator, Sequence +from contextlib import AsyncExitStack, asynccontextmanager from datetime import datetime, timedelta, timezone -from logging import Logger, getLogger from typing import Any, Iterable from uuid import UUID +import anyio import attrs +import sniffio import tenacity +from anyio import to_thread from sqlalchemy import ( JSON, - TIMESTAMP, BigInteger, Column, Enum, @@ -26,14 +30,15 @@ from sqlalchemy import ( select, ) from sqlalchemy.engine import URL, Dialect, Result -from sqlalchemy.exc import CompileError, IntegrityError, OperationalError -from sqlalchemy.future import Engine, create_engine +from sqlalchemy.exc import CompileError, IntegrityError, InterfaceError +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine +from sqlalchemy.future import Connection, Engine from sqlalchemy.sql.ddl import DropTable from sqlalchemy.sql.elements import BindParameter, literal from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome from .._events import ( - Event, + DataStoreEvent, JobAcquired, JobAdded, JobDeserializationFailed, @@ -46,11 +51,15 @@ from .._events import ( TaskUpdated, ) from .._exceptions import ConflictingIdError, SerializationError, TaskLookupError -from .._structures import 
Job, JobResult, RetrySettings, Schedule, Task -from ..abc import EventBroker, Serializer +from .._structures import Job, JobResult, Schedule, Task +from ..abc import EventBroker from ..marshalling import callable_to_ref -from ..serializers.pickle import PickleSerializer -from .base import BaseDataStore +from .base import BaseExternalDataStore + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self class EmulatedUUID(TypeDecorator): @@ -86,17 +95,30 @@ class EmulatedInterval(TypeDecorator): return timedelta(seconds=value) if value is not None else None -@attrs.define(kw_only=True, eq=False) -class _BaseSQLAlchemyDataStore: +@attrs.define(eq=False) +class SQLAlchemyDataStore(BaseExternalDataStore): + """ + Uses a relational database to store data. + + When started, this data store creates the appropriate tables on the given database + if they're not already present. + + Operations are retried (in accordance to ``retry_settings``) when an operation + raises :exc:`sqlalchemy.OperationalError`. + + This store has been tested to work with PostgreSQL (asyncpg driver) and MySQL + (asyncmy driver). + + :param engine: an asynchronous SQLAlchemy engine + :param schema: a database schema name to use, if not the default + """ + + engine: Engine | AsyncEngine schema: str | None = attrs.field(default=None) - serializer: Serializer = attrs.field(factory=PickleSerializer) - lock_expiration_delay: float = attrs.field(default=30) max_poll_time: float | None = attrs.field(default=1) max_idle_time: float = attrs.field(default=60) - retry_settings: RetrySettings = attrs.field(default=RetrySettings()) - start_from_scratch: bool = attrs.field(default=False) - _logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__)) + _is_async: bool = attrs.field(init=False) def __attrs_post_init__(self) -> None: # Generate the table definitions @@ -107,6 +129,7 @@ class _BaseSQLAlchemyDataStore: self.t_schedules = self._metadata.tables[prefix + "schedules"] self.t_jobs = self._metadata.tables[prefix + "jobs"] self.t_job_results = self._metadata.tables[prefix + "job_results"] + self._is_async = isinstance(self.engine, AsyncEngine) # Find out if the dialect supports UPDATE...RETURNING update = self.t_jobs.update().returning(self.t_jobs.c.id) @@ -117,18 +140,76 @@ class _BaseSQLAlchemyDataStore: else: self._supports_update_returning = True - def _after_attempt(self, retry_state: tenacity.RetryCallState) -> None: - self._logger.warning( - "Temporary data store error (attempt %d): %s", - retry_state.attempt_number, - retry_state.outcome.exception(), + @classmethod + def from_url(cls: type[Self], url: str | URL, **options) -> Self: + """ + Create a new asynchronous SQLAlchemy data store. 
+ + :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine` + (must use an async dialect like ``asyncpg`` or ``asyncmy``) + :param kwargs: keyword arguments to pass to the initializer of this class + :return: the newly created data store + + """ + engine = create_async_engine(url, future=True) + return cls(engine, **options) + + def _retry(self) -> tenacity.AsyncRetrying: + def after_attempt(self, retry_state: tenacity.RetryCallState) -> None: + self._logger.warning( + "Temporary data store error (attempt %d): %s", + retry_state.attempt_number, + retry_state.outcome.exception(), + ) + + # OSError is raised by asyncpg if it can't connect + return tenacity.AsyncRetrying( + stop=self.retry_settings.stop, + wait=self.retry_settings.wait, + retry=tenacity.retry_if_exception_type((InterfaceError, OSError)), + after=after_attempt, + sleep=anyio.sleep, + reraise=True, ) + @asynccontextmanager + async def _begin_transaction(self) -> AsyncGenerator[Connection | AsyncConnection]: + if isinstance(self.engine, AsyncEngine): + async with self.engine.begin() as conn: + yield conn + else: + cm = self.engine.begin() + conn = await to_thread.run_sync(cm.__enter__) + try: + yield conn + except BaseException as exc: + await to_thread.run_sync(cm.__exit__, type(exc), exc, exc.__traceback__) + raise + else: + await to_thread.run_sync(cm.__exit__, None, None, None) + + async def _create_metadata(self, conn: Connection | AsyncConnection) -> None: + if isinstance(conn, AsyncConnection): + await conn.run_sync(self._metadata.create_all) + else: + await to_thread.run_sync(self._metadata.create_all, conn) + + async def _execute( + self, + conn: Connection | AsyncConnection, + statement, + parameters: Sequence | None = None, + ): + if isinstance(conn, AsyncConnection): + return await conn.execute(statement, parameters) + else: + return await to_thread.run_sync(conn.execute, statement, parameters) + def get_table_definitions(self) -> MetaData: if self.engine.dialect.name == "postgresql": from sqlalchemy.dialects import postgresql - timestamp_type = TIMESTAMP(timezone=True) + timestamp_type = postgresql.TIMESTAMP(timezone=True) job_id_type = postgresql.UUID(as_uuid=True) interval_type = postgresql.INTERVAL(precision=6) tags_type = postgresql.ARRAY(Unicode) @@ -145,6 +226,7 @@ class _BaseSQLAlchemyDataStore: metadata, Column("id", Unicode(500), primary_key=True), Column("func", Unicode(500), nullable=False), + Column("executor", Unicode(500), nullable=False), Column("state", LargeBinary), Column("max_running_jobs", Integer), Column("misfire_grace_time", interval_type), @@ -197,186 +279,160 @@ class _BaseSQLAlchemyDataStore: ) return metadata - def _deserialize_schedules(self, result: Result) -> list[Schedule]: + async def start( + self, exit_stack: AsyncExitStack, event_broker: EventBroker + ) -> None: + await super().start(exit_stack, event_broker) + asynclib = sniffio.current_async_library() or "(unknown)" + if asynclib != "asyncio": + raise RuntimeError( + f"This data store requires asyncio; currently running: {asynclib}" + ) + + # Verify that the schema is in place + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + if self.start_from_scratch: + for table in self._metadata.sorted_tables: + await self._execute(conn, DropTable(table, if_exists=True)) + + await self._create_metadata(conn) + query = select(self.t_metadata.c.schema_version) + result = await self._execute(conn, query) + version = result.scalar() + if version is None: + await 
self._execute( + conn, self.t_metadata.insert(values={"schema_version": 1}) + ) + elif version > 1: + raise RuntimeError( + f"Unexpected schema version ({version}); " + f"only version 1 is supported by this version of " + f"APScheduler" + ) + + async def _deserialize_schedules(self, result: Result) -> list[Schedule]: schedules: list[Schedule] = [] for row in result: try: schedules.append(Schedule.unmarshal(self.serializer, row._asdict())) except SerializationError as exc: - self._events.publish( + await self._event_broker.publish( ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc) ) return schedules - def _deserialize_jobs(self, result: Result) -> list[Job]: + async def _deserialize_jobs(self, result: Result) -> list[Job]: jobs: list[Job] = [] for row in result: try: jobs.append(Job.unmarshal(self.serializer, row._asdict())) except SerializationError as exc: - self._events.publish( + await self._event_broker.publish( JobDeserializationFailed(job_id=row["id"], exception=exc) ) return jobs - -@attrs.define(eq=False) -class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore): - """ - Uses a relational database to store data. - - When started, this data store creates the appropriate tables on the given database - if they're not already present. - - Operations are retried (in accordance to ``retry_settings``) when an operation - raises :exc:`sqlalchemy.OperationalError`. - - This store has been tested to work with PostgreSQL (psycopg2 driver), MySQL - (pymysql driver) and SQLite. - - :param engine: a (synchronous) SQLAlchemy engine - :param schema: a database schema name to use, if not the default - :param serializer: the serializer used to (de)serialize tasks, schedules and jobs - for storage - :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler - or worker can keep a lock on a schedule or task - :param retry_settings: Tenacity settings for retrying operations in case of a - database connecitivty problem - :param start_from_scratch: erase all existing data during startup (useful for test - suites) - """ - - engine: Engine - - @classmethod - def from_url(cls, url: str | URL, **kwargs) -> SQLAlchemyDataStore: - """ - Create a new SQLAlchemy data store. 
- - :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine` - :param kwargs: keyword arguments to pass to the initializer of this class - :return: the newly created data store - - """ - engine = create_engine(url) - return cls(engine, **kwargs) - - def _retry(self) -> tenacity.Retrying: - return tenacity.Retrying( - stop=self.retry_settings.stop, - wait=self.retry_settings.wait, - retry=tenacity.retry_if_exception_type(OperationalError), - after=self._after_attempt, - reraise=True, - ) - - def start(self, event_broker: EventBroker) -> None: - super().start(event_broker) - - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - if self.start_from_scratch: - for table in self._metadata.sorted_tables: - conn.execute(DropTable(table, if_exists=True)) - - self._metadata.create_all(conn) - query = select(self.t_metadata.c.schema_version) - result = conn.execute(query) - version = result.scalar() - if version is None: - conn.execute(self.t_metadata.insert(values={"schema_version": 1})) - elif version > 1: - raise RuntimeError( - f"Unexpected schema version ({version}); " - f"only version 1 is supported by this version of APScheduler" - ) - - def add_task(self, task: Task) -> None: + async def add_task(self, task: Task) -> None: insert = self.t_tasks.insert().values( id=task.id, func=callable_to_ref(task.func), + executor=task.executor, max_running_jobs=task.max_running_jobs, misfire_grace_time=task.misfire_grace_time, ) try: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(insert) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, insert) except IntegrityError: update = ( self.t_tasks.update() .values( func=callable_to_ref(task.func), + executor=task.executor, max_running_jobs=task.max_running_jobs, misfire_grace_time=task.misfire_grace_time, ) .where(self.t_tasks.c.id == task.id) ) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(update) - self._events.publish(TaskUpdated(task_id=task.id)) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, update) + + await self._event_broker.publish(TaskUpdated(task_id=task.id)) else: - self._events.publish(TaskAdded(task_id=task.id)) + await self._event_broker.publish(TaskAdded(task_id=task.id)) - def remove_task(self, task_id: str) -> None: + async def remove_task(self, task_id: str) -> None: delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(delete) - if result.rowcount == 0: - raise TaskLookupError(task_id) - else: - self._events.publish(TaskRemoved(task_id=task_id)) - - def get_task(self, task_id: str) -> Task: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, delete) + if result.rowcount == 0: + raise TaskLookupError(task_id) + else: + await self._event_broker.publish(TaskRemoved(task_id=task_id)) + + async def get_task(self, task_id: str) -> Task: query = select( [ self.t_tasks.c.id, self.t_tasks.c.func, + self.t_tasks.c.executor, self.t_tasks.c.max_running_jobs, self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time, ] ).where(self.t_tasks.c.id == task_id) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - row = 
result.first() + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, query) + row = result.first() if row: return Task.unmarshal(self.serializer, row._asdict()) else: raise TaskLookupError(task_id) - def get_tasks(self) -> list[Task]: + async def get_tasks(self) -> list[Task]: query = select( [ self.t_tasks.c.id, self.t_tasks.c.func, + self.t_tasks.c.executor, self.t_tasks.c.max_running_jobs, self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time, ] ).order_by(self.t_tasks.c.id) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - tasks = [ - Task.unmarshal(self.serializer, row._asdict()) for row in result - ] - return tasks - - def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: - event: Event + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, query) + tasks = [ + Task.unmarshal(self.serializer, row._asdict()) for row in result + ] + return tasks + + async def add_schedule( + self, schedule: Schedule, conflict_policy: ConflictPolicy + ) -> None: + event: DataStoreEvent values = schedule.marshal(self.serializer) insert = self.t_schedules.insert().values(**values) try: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(insert) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, insert) except IntegrityError: if conflict_policy is ConflictPolicy.exception: raise ConflictingIdError(schedule.id) from None @@ -387,185 +443,197 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore): .where(self.t_schedules.c.id == schedule.id) .values(**values) ) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(update) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, update) event = ScheduleUpdated( schedule_id=schedule.id, next_fire_time=schedule.next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) else: event = ScheduleAdded( schedule_id=schedule.id, next_fire_time=schedule.next_fire_time ) - self._events.publish(event) - - def remove_schedules(self, ids: Iterable[str]) -> None: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - delete = self.t_schedules.delete().where(self.t_schedules.c.id.in_(ids)) - if self._supports_update_returning: - delete = delete.returning(self.t_schedules.c.id) - removed_ids: Iterable[str] = [ - row[0] for row in conn.execute(delete) - ] - else: - # TODO: actually check which rows were deleted? - conn.execute(delete) - removed_ids = ids + await self._event_broker.publish(event) + + async def remove_schedules(self, ids: Iterable[str]) -> None: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + delete = self.t_schedules.delete().where( + self.t_schedules.c.id.in_(ids) + ) + if self._supports_update_returning: + delete = delete.returning(self.t_schedules.c.id) + removed_ids: Iterable[str] = [ + row[0] for row in await self._execute(conn, delete) + ] + else: + # TODO: actually check which rows were deleted? 
+ await self._execute(conn, delete) + removed_ids = ids for schedule_id in removed_ids: - self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id)) - def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: + async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: query = self.t_schedules.select().order_by(self.t_schedules.c.id) if ids: query = query.where(self.t_schedules.c.id.in_(ids)) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - return self._deserialize_schedules(result) - - def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - schedules_cte = ( - select(self.t_schedules.c.id) - .where( - and_( - self.t_schedules.c.next_fire_time.isnot(None), - self.t_schedules.c.next_fire_time <= now, - or_( - self.t_schedules.c.acquired_until.is_(None), - self.t_schedules.c.acquired_until < now, - ), + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, query) + return await self._deserialize_schedules(result) + + async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + now = datetime.now(timezone.utc) + acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + schedules_cte = ( + select(self.t_schedules.c.id) + .where( + and_( + self.t_schedules.c.next_fire_time.isnot(None), + self.t_schedules.c.next_fire_time <= now, + or_( + self.t_schedules.c.acquired_until.is_(None), + self.t_schedules.c.acquired_until < now, + ), + ) ) + .order_by(self.t_schedules.c.next_fire_time) + .limit(limit) + .with_for_update(skip_locked=True) + .cte() ) - .order_by(self.t_schedules.c.next_fire_time) - .limit(limit) - .with_for_update(skip_locked=True) - .cte() - ) - subselect = select([schedules_cte.c.id]) - update = ( - self.t_schedules.update() - .where(self.t_schedules.c.id.in_(subselect)) - .values(acquired_by=scheduler_id, acquired_until=acquired_until) - ) - if self._supports_update_returning: - update = update.returning(*self.t_schedules.columns) - result = conn.execute(update) - else: - conn.execute(update) - query = self.t_schedules.select().where( - and_(self.t_schedules.c.acquired_by == scheduler_id) + subselect = select([schedules_cte.c.id]) + update = ( + self.t_schedules.update() + .where(self.t_schedules.c.id.in_(subselect)) + .values(acquired_by=scheduler_id, acquired_until=acquired_until) ) - result = conn.execute(query) + if self._supports_update_returning: + update = update.returning(*self.t_schedules.columns) + result = await self._execute(conn, update) + else: + await self._execute(conn, update) + query = self.t_schedules.select().where( + and_(self.t_schedules.c.acquired_by == scheduler_id) + ) + result = await self._execute(conn, query) - schedules = self._deserialize_schedules(result) + schedules = await self._deserialize_schedules(result) return schedules - def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - update_events: list[ScheduleUpdated] = [] - finished_schedule_ids: 
list[str] = [] - update_args: list[dict[str, Any]] = [] - for schedule in schedules: - if schedule.next_fire_time is not None: - try: - serialized_trigger = self.serializer.serialize( - schedule.trigger - ) - except SerializationError: - self._logger.exception( - "Error serializing trigger for schedule %r – " - "removing from data store", - schedule.id, + async def release_schedules( + self, scheduler_id: str, schedules: list[Schedule] + ) -> None: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + update_events: list[ScheduleUpdated] = [] + finished_schedule_ids: list[str] = [] + update_args: list[dict[str, Any]] = [] + for schedule in schedules: + if schedule.next_fire_time is not None: + try: + serialized_trigger = self.serializer.serialize( + schedule.trigger + ) + except SerializationError: + self._logger.exception( + "Error serializing trigger for schedule %r – " + "removing from data store", + schedule.id, + ) + finished_schedule_ids.append(schedule.id) + continue + + update_args.append( + { + "p_id": schedule.id, + "p_trigger": serialized_trigger, + "p_next_fire_time": schedule.next_fire_time, + } ) + else: finished_schedule_ids.append(schedule.id) - continue - - update_args.append( - { - "p_id": schedule.id, - "p_trigger": serialized_trigger, - "p_next_fire_time": schedule.next_fire_time, - } - ) - else: - finished_schedule_ids.append(schedule.id) - # Update schedules that have a next fire time - if update_args: - p_id: BindParameter = bindparam("p_id") - p_trigger: BindParameter = bindparam("p_trigger") - p_next_fire_time: BindParameter = bindparam("p_next_fire_time") - update = ( - self.t_schedules.update() - .where( - and_( - self.t_schedules.c.id == p_id, - self.t_schedules.c.acquired_by == scheduler_id, + # Update schedules that have a next fire time + if update_args: + p_id: BindParameter = bindparam("p_id") + p_trigger: BindParameter = bindparam("p_trigger") + p_next_fire_time: BindParameter = bindparam("p_next_fire_time") + update = ( + self.t_schedules.update() + .where( + and_( + self.t_schedules.c.id == p_id, + self.t_schedules.c.acquired_by == scheduler_id, + ) + ) + .values( + trigger=p_trigger, + next_fire_time=p_next_fire_time, + acquired_by=None, + acquired_until=None, ) ) - .values( - trigger=p_trigger, - next_fire_time=p_next_fire_time, - acquired_by=None, - acquired_until=None, - ) - ) - next_fire_times = { - arg["p_id"]: arg["p_next_fire_time"] for arg in update_args - } - # TODO: actually check which rows were updated? - conn.execute(update, update_args) - updated_ids = list(next_fire_times) - - for schedule_id in updated_ids: - event = ScheduleUpdated( - schedule_id=schedule_id, - next_fire_time=next_fire_times[schedule_id], - ) - update_events.append(event) + next_fire_times = { + arg["p_id"]: arg["p_next_fire_time"] for arg in update_args + } + # TODO: actually check which rows were updated? 
+                        await self._execute(conn, update, update_args)
+                        updated_ids = list(next_fire_times)
+
+                        for schedule_id in updated_ids:
+                            event = ScheduleUpdated(
+                                schedule_id=schedule_id,
+                                next_fire_time=next_fire_times[schedule_id],
+                            )
+                            update_events.append(event)
 
-                # Remove schedules that have no next fire time or failed to serialize
-                if finished_schedule_ids:
-                    delete = self.t_schedules.delete().where(
-                        self.t_schedules.c.id.in_(finished_schedule_ids)
-                    )
-                    conn.execute(delete)
+                    # Remove schedules that have no next fire time or failed to
+                    # serialize
+                    if finished_schedule_ids:
+                        delete = self.t_schedules.delete().where(
+                            self.t_schedules.c.id.in_(finished_schedule_ids)
+                        )
+                        await self._execute(conn, delete)
 
         for event in update_events:
-            self._events.publish(event)
+            await self._event_broker.publish(event)
 
         for schedule_id in finished_schedule_ids:
-            self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
+            await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id))
 
-    def get_next_schedule_run_time(self) -> datetime | None:
-        query = (
+    async def get_next_schedule_run_time(self) -> datetime | None:
+        statement = (
             select(self.t_schedules.c.next_fire_time)
             .where(self.t_schedules.c.next_fire_time.isnot(None))
             .order_by(self.t_schedules.c.next_fire_time)
             .limit(1)
         )
-        for attempt in self._retry():
-            with attempt, self.engine.begin() as conn:
-                result = conn.execute(query)
-                return result.scalar()
+        async for attempt in self._retry():
+            with attempt:
+                async with self._begin_transaction() as conn:
+                    result = await self._execute(conn, statement)
+                    return result.scalar()
 
-    def add_job(self, job: Job) -> None:
+    async def add_job(self, job: Job) -> None:
         marshalled = job.marshal(self.serializer)
         insert = self.t_jobs.insert().values(**marshalled)
-        for attempt in self._retry():
-            with attempt, self.engine.begin() as conn:
-                conn.execute(insert)
+        async for attempt in self._retry():
+            with attempt:
+                async with self._begin_transaction() as conn:
+                    await self._execute(conn, insert)
 
         event = JobAdded(
             job_id=job.id,
@@ -573,139 +641,156 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore):
             schedule_id=job.schedule_id,
             tags=job.tags,
         )
-        self._events.publish(event)
+        await self._event_broker.publish(event)
 
-    def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
+    async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
         query = self.t_jobs.select().order_by(self.t_jobs.c.id)
         if ids:
            job_ids = [job_id for job_id in ids]
            query = query.where(self.t_jobs.c.id.in_(job_ids))
 
-        for attempt in self._retry():
-            with attempt, self.engine.begin() as conn:
-                result = conn.execute(query)
-                return self._deserialize_jobs(result)
-
-    def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
-        for attempt in self._retry():
-            with attempt, self.engine.begin() as conn:
-                now = datetime.now(timezone.utc)
-                acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
-                query = (
-                    self.t_jobs.select()
-                    .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id)
-                    .where(
-                        or_(
-                            self.t_jobs.c.acquired_until.is_(None),
-                            self.t_jobs.c.acquired_until < now,
+        async for attempt in self._retry():
+            with attempt:
+                async with self._begin_transaction() as conn:
+                    result = await self._execute(conn, query)
+                    return await self._deserialize_jobs(result)
+
+    async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
+        async for attempt in self._retry():
+            with attempt:
+                async with 
self._begin_transaction() as conn: + now = datetime.now(timezone.utc) + acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + query = ( + self.t_jobs.select() + .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id) + .where( + or_( + self.t_jobs.c.acquired_until.is_(None), + self.t_jobs.c.acquired_until < now, + ) ) + .order_by(self.t_jobs.c.created_at) + .with_for_update(skip_locked=True) + .limit(limit) ) - .order_by(self.t_jobs.c.created_at) - .with_for_update(skip_locked=True) - .limit(limit) - ) - - result = conn.execute(query) - if not result: - return [] - # Mark the jobs as acquired by this worker - jobs = self._deserialize_jobs(result) - task_ids: set[str] = {job.task_id for job in jobs} - - # Retrieve the limits - query = select( - [ - self.t_tasks.c.id, - self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs, - ] - ).where( - self.t_tasks.c.max_running_jobs.isnot(None), - self.t_tasks.c.id.in_(task_ids), - ) - result = conn.execute(query) - job_slots_left = dict(result.fetchall()) - - # Filter out jobs that don't have free slots - acquired_jobs: list[Job] = [] - increments: dict[str, int] = defaultdict(lambda: 0) - for job in jobs: - # Don't acquire the job if there are no free slots left - slots_left = job_slots_left.get(job.task_id) - if slots_left == 0: - continue - elif slots_left is not None: - job_slots_left[job.task_id] -= 1 - - acquired_jobs.append(job) - increments[job.task_id] += 1 - - if acquired_jobs: - # Mark the acquired jobs as acquired by this worker - acquired_job_ids = [job.id for job in acquired_jobs] - update = ( - self.t_jobs.update() - .values(acquired_by=worker_id, acquired_until=acquired_until) - .where(self.t_jobs.c.id.in_(acquired_job_ids)) + result = await self._execute(conn, query) + if not result: + return [] + + # Mark the jobs as acquired by this worker + jobs = await self._deserialize_jobs(result) + task_ids: set[str] = {job.task_id for job in jobs} + + # Retrieve the limits + query = select( + [ + self.t_tasks.c.id, + self.t_tasks.c.max_running_jobs + - self.t_tasks.c.running_jobs, + ] + ).where( + self.t_tasks.c.max_running_jobs.isnot(None), + self.t_tasks.c.id.in_(task_ids), ) - conn.execute(update) - - # Increment the running job counters on each task - p_id: BindParameter = bindparam("p_id") - p_increment: BindParameter = bindparam("p_increment") - params = [ - {"p_id": task_id, "p_increment": increment} - for task_id, increment in increments.items() - ] - update = ( - self.t_tasks.update() - .values(running_jobs=self.t_tasks.c.running_jobs + p_increment) - .where(self.t_tasks.c.id == p_id) - ) - conn.execute(update, params) + result = await self._execute(conn, query) + job_slots_left: dict[str, int] = dict(result.fetchall()) + + # Filter out jobs that don't have free slots + acquired_jobs: list[Job] = [] + increments: dict[str, int] = defaultdict(lambda: 0) + for job in jobs: + # Don't acquire the job if there are no free slots left + slots_left = job_slots_left.get(job.task_id) + if slots_left == 0: + continue + elif slots_left is not None: + job_slots_left[job.task_id] -= 1 + + acquired_jobs.append(job) + increments[job.task_id] += 1 + + if acquired_jobs: + # Mark the acquired jobs as acquired by this worker + acquired_job_ids = [job.id for job in acquired_jobs] + update = ( + self.t_jobs.update() + .values( + acquired_by=worker_id, acquired_until=acquired_until + ) + .where(self.t_jobs.c.id.in_(acquired_job_ids)) + ) + await self._execute(conn, update) + + # Increment the running job counters on each 
task + p_id: BindParameter = bindparam("p_id") + p_increment: BindParameter = bindparam("p_increment") + params = [ + {"p_id": task_id, "p_increment": increment} + for task_id, increment in increments.items() + ] + update = ( + self.t_tasks.update() + .values( + running_jobs=self.t_tasks.c.running_jobs + p_increment + ) + .where(self.t_tasks.c.id == p_id) + ) + await self._execute(conn, update, params) # Publish the appropriate events for job in acquired_jobs: - self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) + await self._event_broker.publish( + JobAcquired(job_id=job.id, worker_id=worker_id) + ) return acquired_jobs - def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - # Insert the job result - if result.expires_at > result.finished_at: - marshalled = result.marshal(self.serializer) - insert = self.t_job_results.insert().values(**marshalled) - conn.execute(insert) + async def release_job( + self, worker_id: str, task_id: str, result: JobResult + ) -> None: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + # Record the job result + if result.expires_at > result.finished_at: + marshalled = result.marshal(self.serializer) + insert = self.t_job_results.insert().values(**marshalled) + await self._execute(conn, insert) + + # Decrement the number of running jobs for this task + update = ( + self.t_tasks.update() + .values(running_jobs=self.t_tasks.c.running_jobs - 1) + .where(self.t_tasks.c.id == task_id) + ) + await self._execute(conn, update) - # Decrement the running jobs counter - update = ( - self.t_tasks.update() - .values(running_jobs=self.t_tasks.c.running_jobs - 1) - .where(self.t_tasks.c.id == task_id) - ) - conn.execute(update) - - # Delete the job - delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id) - conn.execute(delete) - - def get_job_result(self, job_id: UUID) -> JobResult | None: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - # Retrieve the result - query = self.t_job_results.select().where( - self.t_job_results.c.job_id == job_id - ) - row = conn.execute(query).first() + # Delete the job + delete = self.t_jobs.delete().where( + self.t_jobs.c.id == result.job_id + ) + await self._execute(conn, delete) + + async def get_job_result(self, job_id: UUID) -> JobResult | None: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + # Retrieve the result + query = self.t_job_results.select().where( + self.t_job_results.c.job_id == job_id + ) + row = (await self._execute(conn, query)).first() - # Delete the result - delete = self.t_job_results.delete().where( - self.t_job_results.c.job_id == job_id - ) - conn.execute(delete) + # Delete the result + delete = self.t_job_results.delete().where( + self.t_job_results.c.job_id == job_id + ) + await self._execute(conn, delete) - return ( - JobResult.unmarshal(self.serializer, row._asdict()) if row else None - ) + return ( + JobResult.unmarshal(self.serializer, row._asdict()) + if row + else None + ) |
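
Note on the recurring pattern in the new code above: every converted method now wraps its work in "async for attempt in self._retry():", "with attempt:", "async with self._begin_transaction() as conn:" and "await self._execute(conn, ...)", instead of the old "with attempt, self.engine.begin() as conn:" plus "conn.execute(...)". A minimal, hypothetical sketch of what such helpers could look like is shown below; it assumes a tenacity-based retry policy and an AsyncEngine, reuses the retry_settings, _after_attempt and engine attributes visible earlier in this file, and may differ from the real helper implementations, which live outside this hunk.

from __future__ import annotations

from contextlib import asynccontextmanager
from typing import Any, AsyncIterator

import tenacity
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncConnection


class _AsyncExecutionHelpers:
    """Hypothetical mixin sketching the helpers the methods above call into."""

    def _retry(self) -> tenacity.AsyncRetrying:
        # Same retry policy as the synchronous _retry(), but AsyncRetrying is
        # iterable with "async for attempt in ...:" followed by "with attempt:".
        return tenacity.AsyncRetrying(
            stop=self.retry_settings.stop,
            wait=self.retry_settings.wait,
            retry=tenacity.retry_if_exception_type(OperationalError),
            after=self._after_attempt,
            reraise=True,
        )

    @asynccontextmanager
    async def _begin_transaction(self) -> AsyncIterator[AsyncConnection]:
        # AsyncEngine.begin() opens a connection plus a transaction that is
        # committed on success and rolled back if the block raises.
        async with self.engine.begin() as conn:
            yield conn

    async def _execute(self, conn: AsyncConnection, statement: Any, parameters: Any = None):
        # AsyncConnection.execute() accepts the same Core statements (and an
        # optional parameter list for executemany) as Connection.execute().
        return await conn.execute(statement, parameters)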
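
The release_schedules() and acquire_jobs() hunks also rely on SQLAlchemy's "executemany" form: a single UPDATE built with bindparam() placeholders (the p_id/p_trigger/p_next_fire_time and p_id/p_increment parameters) is executed once with a list of parameter dictionaries, one per row. Executemany does not report per-row matches and generally cannot use RETURNING, which is why the TODO comments note that the affected rows are not actually checked. The standalone snippet below illustrates the mechanism with a made-up table; it is not APScheduler code.

from sqlalchemy import Column, Integer, MetaData, String, Table, bindparam, create_engine

metadata = MetaData()
demo = Table(
    "demo",
    metadata,
    Column("id", String, primary_key=True),
    Column("counter", Integer),
)

engine = create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(demo.insert(), [{"id": "a", "counter": 0}, {"id": "b", "counter": 0}])

    # One UPDATE statement, many parameter dicts: SQLAlchemy runs this as an
    # "executemany", so no per-row result (or RETURNING) is available.
    update = (
        demo.update()
        .where(demo.c.id == bindparam("p_id"))
        .values(counter=bindparam("p_counter"))
    )
    conn.execute(update, [{"p_id": "a", "p_counter": 1}, {"p_id": "b", "p_counter": 2}])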