Diffstat (limited to 'src/apscheduler/datastores')
-rw-r--r--  src/apscheduler/datastores/async_adapter.py      101
-rw-r--r--  src/apscheduler/datastores/async_sqlalchemy.py    602
-rw-r--r--  src/apscheduler/datastores/base.py                 58
-rw-r--r--  src/apscheduler/datastores/memory.py               62
-rw-r--r--  src/apscheduler/datastores/mongodb.py             109
-rw-r--r--  src/apscheduler/datastores/sqlalchemy.py          865
6 files changed, 586 insertions, 1211 deletions
diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py
deleted file mode 100644
index d16ae56..0000000
--- a/src/apscheduler/datastores/async_adapter.py
+++ /dev/null
@@ -1,101 +0,0 @@
-from __future__ import annotations
-
-import sys
-from datetime import datetime
-from typing import Iterable
-from uuid import UUID
-
-import attrs
-from anyio import to_thread
-from anyio.from_thread import BlockingPortal
-
-from .._enums import ConflictPolicy
-from .._structures import Job, JobResult, Schedule, Task
-from ..abc import AsyncEventBroker, DataStore
-from ..eventbrokers.async_adapter import AsyncEventBrokerAdapter, SyncEventBrokerAdapter
-from .base import BaseAsyncDataStore
-
-
-@attrs.define(eq=False)
-class AsyncDataStoreAdapter(BaseAsyncDataStore):
- original: DataStore
- _portal: BlockingPortal = attrs.field(init=False)
-
- async def start(self, event_broker: AsyncEventBroker) -> None:
- await super().start(event_broker)
-
- self._portal = BlockingPortal()
- await self._portal.__aenter__()
-
- if isinstance(event_broker, AsyncEventBrokerAdapter):
- sync_event_broker = event_broker.original
- else:
- sync_event_broker = SyncEventBrokerAdapter(event_broker, self._portal)
-
- try:
- await to_thread.run_sync(lambda: self.original.start(sync_event_broker))
- except BaseException:
- await self._portal.__aexit__(*sys.exc_info())
- raise
-
- async def stop(self, *, force: bool = False) -> None:
- try:
- await to_thread.run_sync(lambda: self.original.stop(force=force))
- finally:
- await self._portal.__aexit__(None, None, None)
- await super().stop(force=force)
-
- async def add_task(self, task: Task) -> None:
- await to_thread.run_sync(self.original.add_task, task)
-
- async def remove_task(self, task_id: str) -> None:
- await to_thread.run_sync(self.original.remove_task, task_id)
-
- async def get_task(self, task_id: str) -> Task:
- return await to_thread.run_sync(self.original.get_task, task_id)
-
- async def get_tasks(self) -> list[Task]:
- return await to_thread.run_sync(self.original.get_tasks)
-
- async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
- return await to_thread.run_sync(self.original.get_schedules, ids)
-
- async def add_schedule(
- self, schedule: Schedule, conflict_policy: ConflictPolicy
- ) -> None:
- await to_thread.run_sync(self.original.add_schedule, schedule, conflict_policy)
-
- async def remove_schedules(self, ids: Iterable[str]) -> None:
- await to_thread.run_sync(self.original.remove_schedules, ids)
-
- async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
- return await to_thread.run_sync(
- self.original.acquire_schedules, scheduler_id, limit
- )
-
- async def release_schedules(
- self, scheduler_id: str, schedules: list[Schedule]
- ) -> None:
- await to_thread.run_sync(
- self.original.release_schedules, scheduler_id, schedules
- )
-
- async def get_next_schedule_run_time(self) -> datetime | None:
- return await to_thread.run_sync(self.original.get_next_schedule_run_time)
-
- async def add_job(self, job: Job) -> None:
- await to_thread.run_sync(self.original.add_job, job)
-
- async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
- return await to_thread.run_sync(self.original.get_jobs, ids)
-
- async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
- return await to_thread.run_sync(self.original.acquire_jobs, worker_id, limit)
-
- async def release_job(
- self, worker_id: str, task_id: str, result: JobResult
- ) -> None:
- await to_thread.run_sync(self.original.release_job, worker_id, task_id, result)
-
- async def get_job_result(self, job_id: UUID) -> JobResult | None:
- return await to_thread.run_sync(self.original.get_job_result, job_id)
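
For context: the removed AsyncDataStoreAdapter worked by pushing every blocking DataStore call into a worker thread with anyio's to_thread.run_sync. A minimal, self-contained sketch of that thread-offloading pattern follows; it is not APScheduler code, SyncStore and AsyncAdapter are hypothetical stand-ins, and anyio is assumed to be installed.

import anyio
from anyio import to_thread

class SyncStore:
    """Hypothetical blocking data store."""

    def get_tasks(self) -> list[str]:
        return ["task-a", "task-b"]  # imagine a blocking database query here

class AsyncAdapter:
    def __init__(self, original: SyncStore) -> None:
        self.original = original

    async def get_tasks(self) -> list[str]:
        # Run the blocking call in a worker thread so the event loop stays free
        return await to_thread.run_sync(self.original.get_tasks)

async def main() -> None:
    print(await AsyncAdapter(SyncStore()).get_tasks())

anyio.run(main)
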
diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py
deleted file mode 100644
index 0c165b8..0000000
--- a/src/apscheduler/datastores/async_sqlalchemy.py
+++ /dev/null
@@ -1,602 +0,0 @@
-from __future__ import annotations
-
-from collections import defaultdict
-from datetime import datetime, timedelta, timezone
-from typing import Any, Iterable
-from uuid import UUID
-
-import anyio
-import attrs
-import sniffio
-import tenacity
-from sqlalchemy import and_, bindparam, or_, select
-from sqlalchemy.engine import URL, Result
-from sqlalchemy.exc import IntegrityError, InterfaceError
-from sqlalchemy.ext.asyncio import create_async_engine
-from sqlalchemy.ext.asyncio.engine import AsyncEngine
-from sqlalchemy.sql.ddl import DropTable
-from sqlalchemy.sql.elements import BindParameter
-
-from .._enums import ConflictPolicy
-from .._events import (
- DataStoreEvent,
- JobAcquired,
- JobAdded,
- JobDeserializationFailed,
- ScheduleAdded,
- ScheduleDeserializationFailed,
- ScheduleRemoved,
- ScheduleUpdated,
- TaskAdded,
- TaskRemoved,
- TaskUpdated,
-)
-from .._exceptions import ConflictingIdError, SerializationError, TaskLookupError
-from .._structures import Job, JobResult, Schedule, Task
-from ..abc import AsyncEventBroker
-from ..marshalling import callable_to_ref
-from .base import BaseAsyncDataStore
-from .sqlalchemy import _BaseSQLAlchemyDataStore
-
-
-@attrs.define(eq=False)
-class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseAsyncDataStore):
- """
- Uses a relational database to store data.
-
- When started, this data store creates the appropriate tables on the given database
- if they're not already present.
-
-    Operations are retried (in accordance with ``retry_settings``) when an operation
- raises :exc:`sqlalchemy.OperationalError`.
-
- This store has been tested to work with PostgreSQL (asyncpg driver) and MySQL
- (asyncmy driver).
-
- :param engine: an asynchronous SQLAlchemy engine
- :param schema: a database schema name to use, if not the default
- :param serializer: the serializer used to (de)serialize tasks, schedules and jobs
- for storage
- :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler
- or worker can keep a lock on a schedule or task
- :param retry_settings: Tenacity settings for retrying operations in case of a
-        database connectivity problem
- :param start_from_scratch: erase all existing data during startup (useful for test
- suites)
- """
-
- engine: AsyncEngine
-
- @classmethod
- def from_url(cls, url: str | URL, **options) -> AsyncSQLAlchemyDataStore:
- """
- Create a new asynchronous SQLAlchemy data store.
-
- :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine`
- (must use an async dialect like ``asyncpg`` or ``asyncmy``)
- :param kwargs: keyword arguments to pass to the initializer of this class
- :return: the newly created data store
-
- """
- engine = create_async_engine(url, future=True)
- return cls(engine, **options)
-
- def _retry(self) -> tenacity.AsyncRetrying:
- # OSError is raised by asyncpg if it can't connect
- return tenacity.AsyncRetrying(
- stop=self.retry_settings.stop,
- wait=self.retry_settings.wait,
- retry=tenacity.retry_if_exception_type((InterfaceError, OSError)),
- after=self._after_attempt,
- sleep=anyio.sleep,
- reraise=True,
- )
-
- async def start(self, event_broker: AsyncEventBroker) -> None:
- await super().start(event_broker)
-
- asynclib = sniffio.current_async_library() or "(unknown)"
- if asynclib != "asyncio":
- raise RuntimeError(
- f"This data store requires asyncio; currently running: {asynclib}"
- )
-
- # Verify that the schema is in place
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- if self.start_from_scratch:
- for table in self._metadata.sorted_tables:
- await conn.execute(DropTable(table, if_exists=True))
-
- await conn.run_sync(self._metadata.create_all)
- query = select(self.t_metadata.c.schema_version)
- result = await conn.execute(query)
- version = result.scalar()
- if version is None:
- await conn.execute(
- self.t_metadata.insert(values={"schema_version": 1})
- )
- elif version > 1:
- raise RuntimeError(
- f"Unexpected schema version ({version}); "
- f"only version 1 is supported by this version of "
- f"APScheduler"
- )
-
- async def _deserialize_schedules(self, result: Result) -> list[Schedule]:
- schedules: list[Schedule] = []
- for row in result:
- try:
- schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
- except SerializationError as exc:
- await self._events.publish(
- ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc)
- )
-
- return schedules
-
- async def _deserialize_jobs(self, result: Result) -> list[Job]:
- jobs: list[Job] = []
- for row in result:
- try:
- jobs.append(Job.unmarshal(self.serializer, row._asdict()))
- except SerializationError as exc:
- await self._events.publish(
- JobDeserializationFailed(job_id=row["id"], exception=exc)
- )
-
- return jobs
-
- async def add_task(self, task: Task) -> None:
- insert = self.t_tasks.insert().values(
- id=task.id,
- func=callable_to_ref(task.func),
- max_running_jobs=task.max_running_jobs,
- misfire_grace_time=task.misfire_grace_time,
- )
- try:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- await conn.execute(insert)
- except IntegrityError:
- update = (
- self.t_tasks.update()
- .values(
- func=callable_to_ref(task.func),
- max_running_jobs=task.max_running_jobs,
- misfire_grace_time=task.misfire_grace_time,
- )
- .where(self.t_tasks.c.id == task.id)
- )
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- await conn.execute(update)
-
- await self._events.publish(TaskUpdated(task_id=task.id))
- else:
- await self._events.publish(TaskAdded(task_id=task.id))
-
- async def remove_task(self, task_id: str) -> None:
- delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id)
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- result = await conn.execute(delete)
- if result.rowcount == 0:
- raise TaskLookupError(task_id)
- else:
- await self._events.publish(TaskRemoved(task_id=task_id))
-
- async def get_task(self, task_id: str) -> Task:
- query = select(
- [
- self.t_tasks.c.id,
- self.t_tasks.c.func,
- self.t_tasks.c.max_running_jobs,
- self.t_tasks.c.state,
- self.t_tasks.c.misfire_grace_time,
- ]
- ).where(self.t_tasks.c.id == task_id)
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- result = await conn.execute(query)
- row = result.first()
-
- if row:
- return Task.unmarshal(self.serializer, row._asdict())
- else:
- raise TaskLookupError(task_id)
-
- async def get_tasks(self) -> list[Task]:
- query = select(
- [
- self.t_tasks.c.id,
- self.t_tasks.c.func,
- self.t_tasks.c.max_running_jobs,
- self.t_tasks.c.state,
- self.t_tasks.c.misfire_grace_time,
- ]
- ).order_by(self.t_tasks.c.id)
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- result = await conn.execute(query)
- tasks = [
- Task.unmarshal(self.serializer, row._asdict()) for row in result
- ]
- return tasks
-
- async def add_schedule(
- self, schedule: Schedule, conflict_policy: ConflictPolicy
- ) -> None:
- event: DataStoreEvent
- values = schedule.marshal(self.serializer)
- insert = self.t_schedules.insert().values(**values)
- try:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- await conn.execute(insert)
- except IntegrityError:
- if conflict_policy is ConflictPolicy.exception:
- raise ConflictingIdError(schedule.id) from None
- elif conflict_policy is ConflictPolicy.replace:
- del values["id"]
- update = (
- self.t_schedules.update()
- .where(self.t_schedules.c.id == schedule.id)
- .values(**values)
- )
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- await conn.execute(update)
-
- event = ScheduleUpdated(
- schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
- )
- await self._events.publish(event)
- else:
- event = ScheduleAdded(
- schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
- )
- await self._events.publish(event)
-
- async def remove_schedules(self, ids: Iterable[str]) -> None:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- delete = self.t_schedules.delete().where(
- self.t_schedules.c.id.in_(ids)
- )
- if self._supports_update_returning:
- delete = delete.returning(self.t_schedules.c.id)
- removed_ids: Iterable[str] = [
- row[0] for row in await conn.execute(delete)
- ]
- else:
- # TODO: actually check which rows were deleted?
- await conn.execute(delete)
- removed_ids = ids
-
- for schedule_id in removed_ids:
- await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
-
- async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
- query = self.t_schedules.select().order_by(self.t_schedules.c.id)
- if ids:
- query = query.where(self.t_schedules.c.id.in_(ids))
-
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- result = await conn.execute(query)
- return await self._deserialize_schedules(result)
-
- async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- now = datetime.now(timezone.utc)
- acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
- schedules_cte = (
- select(self.t_schedules.c.id)
- .where(
- and_(
- self.t_schedules.c.next_fire_time.isnot(None),
- self.t_schedules.c.next_fire_time <= now,
- or_(
- self.t_schedules.c.acquired_until.is_(None),
- self.t_schedules.c.acquired_until < now,
- ),
- )
- )
- .order_by(self.t_schedules.c.next_fire_time)
- .limit(limit)
- .with_for_update(skip_locked=True)
- .cte()
- )
- subselect = select([schedules_cte.c.id])
- update = (
- self.t_schedules.update()
- .where(self.t_schedules.c.id.in_(subselect))
- .values(acquired_by=scheduler_id, acquired_until=acquired_until)
- )
- if self._supports_update_returning:
- update = update.returning(*self.t_schedules.columns)
- result = await conn.execute(update)
- else:
- await conn.execute(update)
- query = self.t_schedules.select().where(
- and_(self.t_schedules.c.acquired_by == scheduler_id)
- )
- result = await conn.execute(query)
-
- schedules = await self._deserialize_schedules(result)
-
- return schedules
-
- async def release_schedules(
- self, scheduler_id: str, schedules: list[Schedule]
- ) -> None:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- update_events: list[ScheduleUpdated] = []
- finished_schedule_ids: list[str] = []
- update_args: list[dict[str, Any]] = []
- for schedule in schedules:
- if schedule.next_fire_time is not None:
- try:
- serialized_trigger = self.serializer.serialize(
- schedule.trigger
- )
- except SerializationError:
- self._logger.exception(
- "Error serializing trigger for schedule %r – "
- "removing from data store",
- schedule.id,
- )
- finished_schedule_ids.append(schedule.id)
- continue
-
- update_args.append(
- {
- "p_id": schedule.id,
- "p_trigger": serialized_trigger,
- "p_next_fire_time": schedule.next_fire_time,
- }
- )
- else:
- finished_schedule_ids.append(schedule.id)
-
- # Update schedules that have a next fire time
- if update_args:
- p_id: BindParameter = bindparam("p_id")
- p_trigger: BindParameter = bindparam("p_trigger")
- p_next_fire_time: BindParameter = bindparam("p_next_fire_time")
- update = (
- self.t_schedules.update()
- .where(
- and_(
- self.t_schedules.c.id == p_id,
- self.t_schedules.c.acquired_by == scheduler_id,
- )
- )
- .values(
- trigger=p_trigger,
- next_fire_time=p_next_fire_time,
- acquired_by=None,
- acquired_until=None,
- )
- )
- next_fire_times = {
- arg["p_id"]: arg["p_next_fire_time"] for arg in update_args
- }
- # TODO: actually check which rows were updated?
- await conn.execute(update, update_args)
- updated_ids = list(next_fire_times)
-
- for schedule_id in updated_ids:
- event = ScheduleUpdated(
- schedule_id=schedule_id,
- next_fire_time=next_fire_times[schedule_id],
- )
- update_events.append(event)
-
- # Remove schedules that have no next fire time or failed to
- # serialize
- if finished_schedule_ids:
- delete = self.t_schedules.delete().where(
- self.t_schedules.c.id.in_(finished_schedule_ids)
- )
- await conn.execute(delete)
-
- for event in update_events:
- await self._events.publish(event)
-
- for schedule_id in finished_schedule_ids:
- await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
-
- async def get_next_schedule_run_time(self) -> datetime | None:
-        statement = (
- select(self.t_schedules.c.next_fire_time)
- .where(self.t_schedules.c.next_fire_time.isnot(None))
- .order_by(self.t_schedules.c.next_fire_time)
- .limit(1)
- )
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
-                    result = await conn.execute(statement)
- return result.scalar()
-
- async def add_job(self, job: Job) -> None:
- marshalled = job.marshal(self.serializer)
- insert = self.t_jobs.insert().values(**marshalled)
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- await conn.execute(insert)
-
- event = JobAdded(
- job_id=job.id,
- task_id=job.task_id,
- schedule_id=job.schedule_id,
- tags=job.tags,
- )
- await self._events.publish(event)
-
- async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
- query = self.t_jobs.select().order_by(self.t_jobs.c.id)
- if ids:
- job_ids = [job_id for job_id in ids]
- query = query.where(self.t_jobs.c.id.in_(job_ids))
-
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- result = await conn.execute(query)
- return await self._deserialize_jobs(result)
-
- async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- now = datetime.now(timezone.utc)
- acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
- query = (
- self.t_jobs.select()
- .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id)
- .where(
- or_(
- self.t_jobs.c.acquired_until.is_(None),
- self.t_jobs.c.acquired_until < now,
- )
- )
- .order_by(self.t_jobs.c.created_at)
- .with_for_update(skip_locked=True)
- .limit(limit)
- )
-
- result = await conn.execute(query)
- if not result:
- return []
-
- # Mark the jobs as acquired by this worker
- jobs = await self._deserialize_jobs(result)
- task_ids: set[str] = {job.task_id for job in jobs}
-
- # Retrieve the limits
- query = select(
- [
- self.t_tasks.c.id,
- self.t_tasks.c.max_running_jobs
- - self.t_tasks.c.running_jobs,
- ]
- ).where(
- self.t_tasks.c.max_running_jobs.isnot(None),
- self.t_tasks.c.id.in_(task_ids),
- )
- result = await conn.execute(query)
- job_slots_left: dict[str, int] = dict(result.fetchall())
-
- # Filter out jobs that don't have free slots
- acquired_jobs: list[Job] = []
- increments: dict[str, int] = defaultdict(lambda: 0)
- for job in jobs:
- # Don't acquire the job if there are no free slots left
- slots_left = job_slots_left.get(job.task_id)
- if slots_left == 0:
- continue
- elif slots_left is not None:
- job_slots_left[job.task_id] -= 1
-
- acquired_jobs.append(job)
- increments[job.task_id] += 1
-
- if acquired_jobs:
- # Mark the acquired jobs as acquired by this worker
- acquired_job_ids = [job.id for job in acquired_jobs]
- update = (
- self.t_jobs.update()
- .values(
- acquired_by=worker_id, acquired_until=acquired_until
- )
- .where(self.t_jobs.c.id.in_(acquired_job_ids))
- )
- await conn.execute(update)
-
- # Increment the running job counters on each task
- p_id: BindParameter = bindparam("p_id")
- p_increment: BindParameter = bindparam("p_increment")
- params = [
- {"p_id": task_id, "p_increment": increment}
- for task_id, increment in increments.items()
- ]
- update = (
- self.t_tasks.update()
- .values(
- running_jobs=self.t_tasks.c.running_jobs + p_increment
- )
- .where(self.t_tasks.c.id == p_id)
- )
- await conn.execute(update, params)
-
- # Publish the appropriate events
- for job in acquired_jobs:
- await self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id))
-
- return acquired_jobs
-
- async def release_job(
- self, worker_id: str, task_id: str, result: JobResult
- ) -> None:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- # Record the job result
- if result.expires_at > result.finished_at:
- marshalled = result.marshal(self.serializer)
- insert = self.t_job_results.insert().values(**marshalled)
- await conn.execute(insert)
-
- # Decrement the number of running jobs for this task
- update = (
- self.t_tasks.update()
- .values(running_jobs=self.t_tasks.c.running_jobs - 1)
- .where(self.t_tasks.c.id == task_id)
- )
- await conn.execute(update)
-
- # Delete the job
- delete = self.t_jobs.delete().where(
- self.t_jobs.c.id == result.job_id
- )
- await conn.execute(delete)
-
- async def get_job_result(self, job_id: UUID) -> JobResult | None:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- # Retrieve the result
- query = self.t_job_results.select().where(
- self.t_job_results.c.job_id == job_id
- )
- row = (await conn.execute(query)).first()
-
- # Delete the result
- delete = self.t_job_results.delete().where(
- self.t_job_results.c.job_id == job_id
- )
- await conn.execute(delete)
-
- return (
- JobResult.unmarshal(self.serializer, row._asdict())
- if row
- else None
- )
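
For context: the removed store wrapped every database operation in a tenacity.AsyncRetrying loop (``async for attempt in self._retry(): with attempt: ...``) that retries on InterfaceError/OSError and sleeps through anyio. A minimal runnable sketch of that loop, assuming tenacity and anyio are installed; flaky_connect and the retry limits are illustrative only.

import anyio
import tenacity

calls = {"count": 0}

async def flaky_connect() -> str:
    calls["count"] += 1
    if calls["count"] < 3:
        raise OSError("connection refused")  # simulated transient failure
    return "connected"

async def main() -> None:
    retrying = tenacity.AsyncRetrying(
        stop=tenacity.stop_after_attempt(5),
        wait=tenacity.wait_fixed(0.1),
        retry=tenacity.retry_if_exception_type(OSError),
        sleep=anyio.sleep,  # cooperate with the running event loop while waiting
        reraise=True,
    )
    async for attempt in retrying:
        with attempt:
            print(await flaky_connect())  # succeeds on the third attempt

anyio.run(main)
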
diff --git a/src/apscheduler/datastores/base.py b/src/apscheduler/datastores/base.py
index c05d28c..5c7ef7d 100644
--- a/src/apscheduler/datastores/base.py
+++ b/src/apscheduler/datastores/base.py
@@ -1,37 +1,47 @@
from __future__ import annotations
-from apscheduler.abc import (
- AsyncDataStore,
- AsyncEventBroker,
- DataStore,
- EventBroker,
- EventSource,
-)
+from contextlib import AsyncExitStack
+from logging import Logger, getLogger
+import attrs
+from .._retry import RetryMixin
+from ..abc import DataStore, EventBroker, Serializer
+from ..serializers.pickle import PickleSerializer
+
+
+@attrs.define(kw_only=True)
class BaseDataStore(DataStore):
- _events: EventBroker
+ """
+ Base class for data stores.
- def start(self, event_broker: EventBroker) -> None:
- self._events = event_broker
+ :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler
+ or worker can keep a lock on a schedule or task
+ """
- def stop(self, *, force: bool = False) -> None:
- del self._events
+ lock_expiration_delay: float = 30
+ _event_broker: EventBroker = attrs.field(init=False)
+ _logger: Logger = attrs.field(init=False)
- @property
- def events(self) -> EventSource:
- return self._events
+ async def start(
+ self, exit_stack: AsyncExitStack, event_broker: EventBroker
+ ) -> None:
+ self._event_broker = event_broker
+ def __attrs_post_init__(self):
+ self._logger = getLogger(self.__class__.__name__)
-class BaseAsyncDataStore(AsyncDataStore):
- _events: AsyncEventBroker
- async def start(self, event_broker: AsyncEventBroker) -> None:
- self._events = event_broker
+@attrs.define(kw_only=True)
+class BaseExternalDataStore(BaseDataStore, RetryMixin):
+ """
+ Base class for data stores using an external service such as a database.
- async def stop(self, *, force: bool = False) -> None:
- del self._events
+ :param serializer: the serializer used to (de)serialize tasks, schedules and jobs
+ for storage
+ :param start_from_scratch: erase all existing data during startup (useful for test
+ suites)
+ """
- @property
- def events(self) -> EventSource:
- return self._events
+ serializer: Serializer = attrs.field(factory=PickleSerializer)
+ start_from_scratch: bool = attrs.field(default=False)
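
For context: the new base classes are keyword-only attrs classes whose non-init fields (the logger, the event broker) are filled in later, in __attrs_post_init__ or start(). A minimal sketch of that attrs pattern, assuming attrs is installed; ExampleStore is a hypothetical stand-in, not an APScheduler class.

from logging import Logger, getLogger

import attrs

@attrs.define(kw_only=True)
class ExampleStore:
    lock_expiration_delay: float = 30
    _logger: Logger = attrs.field(init=False)  # not settable via __init__

    def __attrs_post_init__(self) -> None:
        # Populate the non-init field once the instance exists
        self._logger = getLogger(self.__class__.__name__)

store = ExampleStore(lock_expiration_delay=15)
store._logger.warning("lock delay is %s seconds", store.lock_expiration_delay)
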
diff --git a/src/apscheduler/datastores/memory.py b/src/apscheduler/datastores/memory.py
index a9ff3cb..fd7e90e 100644
--- a/src/apscheduler/datastores/memory.py
+++ b/src/apscheduler/datastores/memory.py
@@ -85,13 +85,9 @@ class MemoryDataStore(BaseDataStore):
"""
Stores scheduler data in memory, without serializing it.
- Can be shared between multiple schedulers and workers within the same event loop.
-
- :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler
- or worker can keep a lock on a schedule or task
+ Can be shared between multiple schedulers within the same event loop.
"""
- lock_expiration_delay: float = 30
_tasks: dict[str, TaskState] = attrs.Factory(dict)
_schedules: list[ScheduleState] = attrs.Factory(list)
_schedules_by_id: dict[str, ScheduleState] = attrs.Factory(dict)
@@ -115,41 +111,43 @@ class MemoryDataStore(BaseDataStore):
right_index = bisect_left(self._jobs, state)
return self._jobs.index(state, left_index, right_index + 1)
- def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
+ async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
return [
state.schedule
for state in self._schedules
if ids is None or state.schedule.id in ids
]
- def add_task(self, task: Task) -> None:
+ async def add_task(self, task: Task) -> None:
task_exists = task.id in self._tasks
self._tasks[task.id] = TaskState(task)
if task_exists:
- self._events.publish(TaskUpdated(task_id=task.id))
+ await self._event_broker.publish(TaskUpdated(task_id=task.id))
else:
- self._events.publish(TaskAdded(task_id=task.id))
+ await self._event_broker.publish(TaskAdded(task_id=task.id))
- def remove_task(self, task_id: str) -> None:
+ async def remove_task(self, task_id: str) -> None:
try:
del self._tasks[task_id]
except KeyError:
raise TaskLookupError(task_id) from None
- self._events.publish(TaskRemoved(task_id=task_id))
+ await self._event_broker.publish(TaskRemoved(task_id=task_id))
- def get_task(self, task_id: str) -> Task:
+ async def get_task(self, task_id: str) -> Task:
try:
return self._tasks[task_id].task
except KeyError:
raise TaskLookupError(task_id) from None
- def get_tasks(self) -> list[Task]:
+ async def get_tasks(self) -> list[Task]:
return sorted(
(state.task for state in self._tasks.values()), key=lambda task: task.id
)
- def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
+ async def add_schedule(
+ self, schedule: Schedule, conflict_policy: ConflictPolicy
+ ) -> None:
old_state = self._schedules_by_id.get(schedule.id)
if old_state is not None:
if conflict_policy is ConflictPolicy.do_nothing:
@@ -175,17 +173,17 @@ class MemoryDataStore(BaseDataStore):
schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def remove_schedules(self, ids: Iterable[str]) -> None:
+ async def remove_schedules(self, ids: Iterable[str]) -> None:
for schedule_id in ids:
state = self._schedules_by_id.pop(schedule_id, None)
if state:
self._schedules.remove(state)
event = ScheduleRemoved(schedule_id=state.schedule.id)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
+ async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
now = datetime.now(timezone.utc)
schedules: list[Schedule] = []
for state in self._schedules:
@@ -206,7 +204,9 @@ class MemoryDataStore(BaseDataStore):
return schedules
- def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None:
+ async def release_schedules(
+ self, scheduler_id: str, schedules: list[Schedule]
+ ) -> None:
# Send update events for schedules that have a next time
finished_schedule_ids: list[str] = []
for s in schedules:
@@ -224,17 +224,17 @@ class MemoryDataStore(BaseDataStore):
event = ScheduleUpdated(
schedule_id=s.id, next_fire_time=s.next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
else:
finished_schedule_ids.append(s.id)
# Remove schedules that didn't get a new next fire time
- self.remove_schedules(finished_schedule_ids)
+ await self.remove_schedules(finished_schedule_ids)
- def get_next_schedule_run_time(self) -> datetime | None:
+ async def get_next_schedule_run_time(self) -> datetime | None:
return self._schedules[0].next_fire_time if self._schedules else None
- def add_job(self, job: Job) -> None:
+ async def add_job(self, job: Job) -> None:
state = JobState(job)
self._jobs.append(state)
self._jobs_by_id[job.id] = state
@@ -246,15 +246,15 @@ class MemoryDataStore(BaseDataStore):
schedule_id=job.schedule_id,
tags=job.tags,
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
+ async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
if ids is not None:
ids = frozenset(ids)
return [state.job for state in self._jobs if ids is None or state.job.id in ids]
- def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
+ async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
now = datetime.now(timezone.utc)
jobs: list[Job] = []
for _index, job_state in enumerate(self._jobs):
@@ -290,11 +290,15 @@ class MemoryDataStore(BaseDataStore):
# Publish the appropriate events
for job in jobs:
- self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id))
+ await self._event_broker.publish(
+ JobAcquired(job_id=job.id, worker_id=worker_id)
+ )
return jobs
- def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
+ async def release_job(
+ self, worker_id: str, task_id: str, result: JobResult
+ ) -> None:
# Record the job result
if result.expires_at > result.finished_at:
self._job_results[result.job_id] = result
@@ -310,5 +314,5 @@ class MemoryDataStore(BaseDataStore):
index = self._find_job_index(job_state)
del self._jobs[index]
- def get_job_result(self, job_id: UUID) -> JobResult | None:
+ async def get_job_result(self, job_id: UUID) -> JobResult | None:
return self._job_results.pop(job_id, None)
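
For context: MemoryDataStore keeps job states in a list ordered by creation time and locates entries with bisect, as in _find_job_index above. A minimal sketch of that bisect-based lookup on plain integers; the values are arbitrary.

from bisect import bisect_left, insort

jobs: list[int] = []
for created_at in (3, 1, 2):
    insort(jobs, created_at)  # keep the list sorted as items arrive

index = bisect_left(jobs, 2)  # leftmost position where 2 belongs
assert jobs[index] == 2
print(jobs, "index of 2 is", index)
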
diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py
index 13299bb..4899f54 100644
--- a/src/apscheduler/datastores/mongodb.py
+++ b/src/apscheduler/datastores/mongodb.py
@@ -2,14 +2,13 @@ from __future__ import annotations
import operator
from collections import defaultdict
+from contextlib import AsyncExitStack
from datetime import datetime, timedelta, timezone
-from logging import Logger, getLogger
from typing import Any, Callable, ClassVar, Iterable
from uuid import UUID
import attrs
import pymongo
-import tenacity
from attrs.validators import instance_of
from bson import CodecOptions, UuidRepresentation
from bson.codec_options import TypeEncoder, TypeRegistry
@@ -35,10 +34,9 @@ from .._exceptions import (
SerializationError,
TaskLookupError,
)
-from .._structures import Job, JobResult, RetrySettings, Schedule, Task
-from ..abc import EventBroker, Serializer
-from ..serializers.pickle import PickleSerializer
-from .base import BaseDataStore
+from .._structures import Job, JobResult, Schedule, Task
+from ..abc import EventBroker
+from .base import BaseExternalDataStore
class CustomEncoder(TypeEncoder):
@@ -55,7 +53,7 @@ class CustomEncoder(TypeEncoder):
@attrs.define(eq=False)
-class MongoDBDataStore(BaseDataStore):
+class MongoDBDataStore(BaseExternalDataStore):
"""
Uses a MongoDB server to store data.
@@ -66,23 +64,11 @@ class MongoDBDataStore(BaseDataStore):
raises :exc:`pymongo.errors.ConnectionFailure`.
:param client: a PyMongo client
- :param serializer: the serializer used to (de)serialize tasks, schedules and jobs
- for storage
:param database: name of the database to use
- :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler
- or worker can keep a lock on a schedule or task
- :param retry_settings: Tenacity settings for retrying operations in case of a
-        database connectivity problem
- :param start_from_scratch: erase all existing data during startup (useful for test
- suites)
"""
client: MongoClient = attrs.field(validator=instance_of(MongoClient))
- serializer: Serializer = attrs.field(factory=PickleSerializer, kw_only=True)
database: str = attrs.field(default="apscheduler", kw_only=True)
- lock_expiration_delay: float = attrs.field(default=30, kw_only=True)
- retry_settings: RetrySettings = attrs.field(default=RetrySettings(), kw_only=True)
- start_from_scratch: bool = attrs.field(default=False, kw_only=True)
_task_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Task)]
_schedule_attrs: ClassVar[list[str]] = [
@@ -90,8 +76,8 @@ class MongoDBDataStore(BaseDataStore):
]
_job_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Job)]
- _logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__))
_local_tasks: dict[str, Task] = attrs.field(init=False, factory=dict)
+ _temporary_failure_exceptions = (ConnectionFailure,)
def __attrs_post_init__(self) -> None:
type_registry = TypeRegistry(
@@ -118,25 +104,10 @@ class MongoDBDataStore(BaseDataStore):
client = MongoClient(uri)
return cls(client, **options)
- def _retry(self) -> tenacity.Retrying:
- return tenacity.Retrying(
- stop=self.retry_settings.stop,
- wait=self.retry_settings.wait,
- retry=tenacity.retry_if_exception_type(ConnectionFailure),
- after=self._after_attempt,
- reraise=True,
- )
-
- def _after_attempt(self, retry_state: tenacity.RetryCallState) -> None:
- self._logger.warning(
- "Temporary data store error (attempt %d): %s",
- retry_state.attempt_number,
- retry_state.outcome.exception(),
- )
-
- def start(self, event_broker: EventBroker) -> None:
- super().start(event_broker)
-
+ async def start(
+ self, exit_stack: AsyncExitStack, event_broker: EventBroker
+ ) -> None:
+ await super().start(exit_stack, event_broker)
server_info = self.client.server_info()
if server_info["versionArray"] < [4, 0]:
raise RuntimeError(
@@ -144,7 +115,7 @@ class MongoDBDataStore(BaseDataStore):
f"{server_info['version']}"
)
- for attempt in self._retry():
+ async for attempt in self._retry():
with attempt, self.client.start_session() as session:
if self.start_from_scratch:
self._tasks.delete_many({}, session=session)
@@ -159,7 +130,7 @@ class MongoDBDataStore(BaseDataStore):
self._jobs_results.create_index("finished_at", session=session)
self._jobs_results.create_index("expires_at", session=session)
- def add_task(self, task: Task) -> None:
+ async def add_task(self, task: Task) -> None:
for attempt in self._retry():
with attempt:
previous = self._tasks.find_one_and_update(
@@ -173,20 +144,20 @@ class MongoDBDataStore(BaseDataStore):
self._local_tasks[task.id] = task
if previous:
- self._events.publish(TaskUpdated(task_id=task.id))
+ await self._event_broker.publish(TaskUpdated(task_id=task.id))
else:
- self._events.publish(TaskAdded(task_id=task.id))
+ await self._event_broker.publish(TaskAdded(task_id=task.id))
- def remove_task(self, task_id: str) -> None:
+ async def remove_task(self, task_id: str) -> None:
for attempt in self._retry():
with attempt:
if not self._tasks.find_one_and_delete({"_id": task_id}):
raise TaskLookupError(task_id)
del self._local_tasks[task_id]
- self._events.publish(TaskRemoved(task_id=task_id))
+ await self._event_broker.publish(TaskRemoved(task_id=task_id))
- def get_task(self, task_id: str) -> Task:
+ async def get_task(self, task_id: str) -> Task:
try:
return self._local_tasks[task_id]
except KeyError:
@@ -205,7 +176,7 @@ class MongoDBDataStore(BaseDataStore):
)
return task
- def get_tasks(self) -> list[Task]:
+ async def get_tasks(self) -> list[Task]:
for attempt in self._retry():
with attempt:
tasks: list[Task] = []
@@ -217,7 +188,7 @@ class MongoDBDataStore(BaseDataStore):
return tasks
- def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
+ async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
filters = {"_id": {"$in": list(ids)}} if ids is not None else {}
for attempt in self._retry():
with attempt:
@@ -237,7 +208,9 @@ class MongoDBDataStore(BaseDataStore):
return schedules
- def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
+ async def add_schedule(
+ self, schedule: Schedule, conflict_policy: ConflictPolicy
+ ) -> None:
event: DataStoreEvent
document = schedule.marshal(self.serializer)
document["_id"] = document.pop("id")
@@ -258,14 +231,14 @@ class MongoDBDataStore(BaseDataStore):
event = ScheduleUpdated(
schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
else:
event = ScheduleAdded(
schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def remove_schedules(self, ids: Iterable[str]) -> None:
+ async def remove_schedules(self, ids: Iterable[str]) -> None:
filters = {"_id": {"$in": list(ids)}} if ids is not None else {}
for attempt in self._retry():
with attempt, self.client.start_session() as session:
@@ -277,9 +250,9 @@ class MongoDBDataStore(BaseDataStore):
self._schedules.delete_many(filters, session=session)
for schedule_id in ids:
- self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
+ await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id))
- def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
+ async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
for attempt in self._retry():
with attempt, self.client.start_session() as session:
schedules: list[Schedule] = []
@@ -318,7 +291,9 @@ class MongoDBDataStore(BaseDataStore):
return schedules
- def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None:
+ async def release_schedules(
+ self, scheduler_id: str, schedules: list[Schedule]
+ ) -> None:
updated_schedules: list[tuple[str, datetime]] = []
finished_schedule_ids: list[str] = []
@@ -365,12 +340,12 @@ class MongoDBDataStore(BaseDataStore):
event = ScheduleUpdated(
schedule_id=schedule_id, next_fire_time=next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
for schedule_id in finished_schedule_ids:
- self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
+ await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id))
- def get_next_schedule_run_time(self) -> datetime | None:
+ async def get_next_schedule_run_time(self) -> datetime | None:
for attempt in self._retry():
with attempt:
document = self._schedules.find_one(
@@ -384,7 +359,7 @@ class MongoDBDataStore(BaseDataStore):
else:
return None
- def add_job(self, job: Job) -> None:
+ async def add_job(self, job: Job) -> None:
document = job.marshal(self.serializer)
document["_id"] = document.pop("id")
for attempt in self._retry():
@@ -397,9 +372,9 @@ class MongoDBDataStore(BaseDataStore):
schedule_id=job.schedule_id,
tags=job.tags,
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
+ async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
filters = {"_id": {"$in": list(ids)}} if ids is not None else {}
for attempt in self._retry():
with attempt:
@@ -419,7 +394,7 @@ class MongoDBDataStore(BaseDataStore):
return jobs
- def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
+ async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
for attempt in self._retry():
with attempt, self.client.start_session() as session:
cursor = self._jobs.find(
@@ -488,11 +463,15 @@ class MongoDBDataStore(BaseDataStore):
# Publish the appropriate events
for job in acquired_jobs:
- self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id))
+ await self._event_broker.publish(
+ JobAcquired(job_id=job.id, worker_id=worker_id)
+ )
return acquired_jobs
- def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
+ async def release_job(
+ self, worker_id: str, task_id: str, result: JobResult
+ ) -> None:
for attempt in self._retry():
with attempt, self.client.start_session() as session:
# Record the job result
@@ -509,7 +488,7 @@ class MongoDBDataStore(BaseDataStore):
# Delete the job
self._jobs.delete_one({"_id": result.job_id}, session=session)
- def get_job_result(self, job_id: UUID) -> JobResult | None:
+ async def get_job_result(self, job_id: UUID) -> JobResult | None:
for attempt in self._retry():
with attempt:
document = self._jobs_results.find_one_and_delete({"_id": job_id})
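
For context: add_task in this store is an upsert built on find_one_and_update, and whether the returned previous document is None decides between the TaskAdded and TaskUpdated events. A hedged sketch of that upsert pattern with plain pymongo, assuming a MongoDB server is reachable on localhost:27017; the database, collection, and field names are illustrative, not APScheduler's schema.

from pymongo import MongoClient, ReturnDocument

client = MongoClient("mongodb://localhost:27017")
tasks = client["example_db"]["tasks"]

previous = tasks.find_one_and_update(
    {"_id": "my-task"},
    {"$set": {"max_running_jobs": 5}},
    upsert=True,                            # insert the document if it is missing
    return_document=ReturnDocument.BEFORE,  # None on first insert, a doc afterwards
)
print("updated" if previous else "added")
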
diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py
index ebb076f..00d06aa 100644
--- a/src/apscheduler/datastores/sqlalchemy.py
+++ b/src/apscheduler/datastores/sqlalchemy.py
@@ -1,16 +1,20 @@
from __future__ import annotations
+import sys
from collections import defaultdict
+from collections.abc import AsyncGenerator, Sequence
+from contextlib import AsyncExitStack, asynccontextmanager
from datetime import datetime, timedelta, timezone
-from logging import Logger, getLogger
from typing import Any, Iterable
from uuid import UUID
+import anyio
import attrs
+import sniffio
import tenacity
+from anyio import to_thread
from sqlalchemy import (
JSON,
- TIMESTAMP,
BigInteger,
Column,
Enum,
@@ -26,14 +30,15 @@ from sqlalchemy import (
select,
)
from sqlalchemy.engine import URL, Dialect, Result
-from sqlalchemy.exc import CompileError, IntegrityError, OperationalError
-from sqlalchemy.future import Engine, create_engine
+from sqlalchemy.exc import CompileError, IntegrityError, InterfaceError
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
+from sqlalchemy.future import Connection, Engine
from sqlalchemy.sql.ddl import DropTable
from sqlalchemy.sql.elements import BindParameter, literal
from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome
from .._events import (
- Event,
+ DataStoreEvent,
JobAcquired,
JobAdded,
JobDeserializationFailed,
@@ -46,11 +51,15 @@ from .._events import (
TaskUpdated,
)
from .._exceptions import ConflictingIdError, SerializationError, TaskLookupError
-from .._structures import Job, JobResult, RetrySettings, Schedule, Task
-from ..abc import EventBroker, Serializer
+from .._structures import Job, JobResult, Schedule, Task
+from ..abc import EventBroker
from ..marshalling import callable_to_ref
-from ..serializers.pickle import PickleSerializer
-from .base import BaseDataStore
+from .base import BaseExternalDataStore
+
+if sys.version_info >= (3, 11):
+ from typing import Self
+else:
+ from typing_extensions import Self
class EmulatedUUID(TypeDecorator):
@@ -86,17 +95,30 @@ class EmulatedInterval(TypeDecorator):
return timedelta(seconds=value) if value is not None else None
-@attrs.define(kw_only=True, eq=False)
-class _BaseSQLAlchemyDataStore:
+@attrs.define(eq=False)
+class SQLAlchemyDataStore(BaseExternalDataStore):
+ """
+ Uses a relational database to store data.
+
+ When started, this data store creates the appropriate tables on the given database
+ if they're not already present.
+
+    Operations are retried (in accordance with ``retry_settings``) when an operation
+ raises :exc:`sqlalchemy.OperationalError`.
+
+ This store has been tested to work with PostgreSQL (asyncpg driver) and MySQL
+ (asyncmy driver).
+
+    :param engine: a synchronous or asynchronous SQLAlchemy engine
+ :param schema: a database schema name to use, if not the default
+ """
+
+ engine: Engine | AsyncEngine
schema: str | None = attrs.field(default=None)
- serializer: Serializer = attrs.field(factory=PickleSerializer)
- lock_expiration_delay: float = attrs.field(default=30)
max_poll_time: float | None = attrs.field(default=1)
max_idle_time: float = attrs.field(default=60)
- retry_settings: RetrySettings = attrs.field(default=RetrySettings())
- start_from_scratch: bool = attrs.field(default=False)
- _logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__))
+ _is_async: bool = attrs.field(init=False)
def __attrs_post_init__(self) -> None:
# Generate the table definitions
@@ -107,6 +129,7 @@ class _BaseSQLAlchemyDataStore:
self.t_schedules = self._metadata.tables[prefix + "schedules"]
self.t_jobs = self._metadata.tables[prefix + "jobs"]
self.t_job_results = self._metadata.tables[prefix + "job_results"]
+ self._is_async = isinstance(self.engine, AsyncEngine)
# Find out if the dialect supports UPDATE...RETURNING
update = self.t_jobs.update().returning(self.t_jobs.c.id)
@@ -117,18 +140,76 @@ class _BaseSQLAlchemyDataStore:
else:
self._supports_update_returning = True
- def _after_attempt(self, retry_state: tenacity.RetryCallState) -> None:
- self._logger.warning(
- "Temporary data store error (attempt %d): %s",
- retry_state.attempt_number,
- retry_state.outcome.exception(),
+ @classmethod
+ def from_url(cls: type[Self], url: str | URL, **options) -> Self:
+ """
+ Create a new asynchronous SQLAlchemy data store.
+
+ :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine`
+ (must use an async dialect like ``asyncpg`` or ``asyncmy``)
+ :param kwargs: keyword arguments to pass to the initializer of this class
+ :return: the newly created data store
+
+ """
+ engine = create_async_engine(url, future=True)
+ return cls(engine, **options)
+
+ def _retry(self) -> tenacity.AsyncRetrying:
+        def after_attempt(retry_state: tenacity.RetryCallState) -> None:
+ self._logger.warning(
+ "Temporary data store error (attempt %d): %s",
+ retry_state.attempt_number,
+ retry_state.outcome.exception(),
+ )
+
+ # OSError is raised by asyncpg if it can't connect
+ return tenacity.AsyncRetrying(
+ stop=self.retry_settings.stop,
+ wait=self.retry_settings.wait,
+ retry=tenacity.retry_if_exception_type((InterfaceError, OSError)),
+ after=after_attempt,
+ sleep=anyio.sleep,
+ reraise=True,
)
+ @asynccontextmanager
+ async def _begin_transaction(self) -> AsyncGenerator[Connection | AsyncConnection]:
+ if isinstance(self.engine, AsyncEngine):
+ async with self.engine.begin() as conn:
+ yield conn
+ else:
+ cm = self.engine.begin()
+ conn = await to_thread.run_sync(cm.__enter__)
+ try:
+ yield conn
+ except BaseException as exc:
+ await to_thread.run_sync(cm.__exit__, type(exc), exc, exc.__traceback__)
+ raise
+ else:
+ await to_thread.run_sync(cm.__exit__, None, None, None)
+
+ async def _create_metadata(self, conn: Connection | AsyncConnection) -> None:
+ if isinstance(conn, AsyncConnection):
+ await conn.run_sync(self._metadata.create_all)
+ else:
+ await to_thread.run_sync(self._metadata.create_all, conn)
+
+ async def _execute(
+ self,
+ conn: Connection | AsyncConnection,
+ statement,
+ parameters: Sequence | None = None,
+ ):
+ if isinstance(conn, AsyncConnection):
+ return await conn.execute(statement, parameters)
+ else:
+ return await to_thread.run_sync(conn.execute, statement, parameters)
+
def get_table_definitions(self) -> MetaData:
if self.engine.dialect.name == "postgresql":
from sqlalchemy.dialects import postgresql
- timestamp_type = TIMESTAMP(timezone=True)
+ timestamp_type = postgresql.TIMESTAMP(timezone=True)
job_id_type = postgresql.UUID(as_uuid=True)
interval_type = postgresql.INTERVAL(precision=6)
tags_type = postgresql.ARRAY(Unicode)
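
For context: the _begin_transaction helper added above lets one code path serve both engine types by entering a synchronous Engine.begin() context manager inside a worker thread. A minimal runnable sketch of that "drive a sync context manager from async code" pattern, assuming anyio is installed; resource() and begin() are illustrative stand-ins for a SQLAlchemy connection context.

from contextlib import asynccontextmanager, contextmanager

import anyio
from anyio import to_thread

@contextmanager
def resource():
    print("enter (in worker thread)")
    try:
        yield "conn"
    finally:
        print("exit (in worker thread)")

@asynccontextmanager
async def begin():
    cm = resource()
    conn = await to_thread.run_sync(cm.__enter__)  # enter the sync CM off the loop
    try:
        yield conn
    except BaseException as exc:
        await to_thread.run_sync(cm.__exit__, type(exc), exc, exc.__traceback__)
        raise
    else:
        await to_thread.run_sync(cm.__exit__, None, None, None)

async def main() -> None:
    async with begin() as conn:
        print("using", conn)

anyio.run(main)
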
@@ -145,6 +226,7 @@ class _BaseSQLAlchemyDataStore:
metadata,
Column("id", Unicode(500), primary_key=True),
Column("func", Unicode(500), nullable=False),
+ Column("executor", Unicode(500), nullable=False),
Column("state", LargeBinary),
Column("max_running_jobs", Integer),
Column("misfire_grace_time", interval_type),
@@ -197,186 +279,160 @@ class _BaseSQLAlchemyDataStore:
)
return metadata
- def _deserialize_schedules(self, result: Result) -> list[Schedule]:
+ async def start(
+ self, exit_stack: AsyncExitStack, event_broker: EventBroker
+ ) -> None:
+ await super().start(exit_stack, event_broker)
+ asynclib = sniffio.current_async_library() or "(unknown)"
+ if asynclib != "asyncio":
+ raise RuntimeError(
+ f"This data store requires asyncio; currently running: {asynclib}"
+ )
+
+ # Verify that the schema is in place
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ if self.start_from_scratch:
+ for table in self._metadata.sorted_tables:
+ await self._execute(conn, DropTable(table, if_exists=True))
+
+ await self._create_metadata(conn)
+ query = select(self.t_metadata.c.schema_version)
+ result = await self._execute(conn, query)
+ version = result.scalar()
+ if version is None:
+ await self._execute(
+ conn, self.t_metadata.insert(values={"schema_version": 1})
+ )
+ elif version > 1:
+ raise RuntimeError(
+ f"Unexpected schema version ({version}); "
+ f"only version 1 is supported by this version of "
+ f"APScheduler"
+ )
+
+ async def _deserialize_schedules(self, result: Result) -> list[Schedule]:
schedules: list[Schedule] = []
for row in result:
try:
schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
except SerializationError as exc:
- self._events.publish(
+ await self._event_broker.publish(
ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc)
)
return schedules
- def _deserialize_jobs(self, result: Result) -> list[Job]:
+ async def _deserialize_jobs(self, result: Result) -> list[Job]:
jobs: list[Job] = []
for row in result:
try:
jobs.append(Job.unmarshal(self.serializer, row._asdict()))
except SerializationError as exc:
- self._events.publish(
+ await self._event_broker.publish(
JobDeserializationFailed(job_id=row["id"], exception=exc)
)
return jobs
-
-@attrs.define(eq=False)
-class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore):
- """
- Uses a relational database to store data.
-
- When started, this data store creates the appropriate tables on the given database
- if they're not already present.
-
-    Operations are retried (in accordance with ``retry_settings``) when an operation
- raises :exc:`sqlalchemy.OperationalError`.
-
- This store has been tested to work with PostgreSQL (psycopg2 driver), MySQL
- (pymysql driver) and SQLite.
-
- :param engine: a (synchronous) SQLAlchemy engine
- :param schema: a database schema name to use, if not the default
- :param serializer: the serializer used to (de)serialize tasks, schedules and jobs
- for storage
- :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler
- or worker can keep a lock on a schedule or task
- :param retry_settings: Tenacity settings for retrying operations in case of a
-        database connectivity problem
- :param start_from_scratch: erase all existing data during startup (useful for test
- suites)
- """
-
- engine: Engine
-
- @classmethod
- def from_url(cls, url: str | URL, **kwargs) -> SQLAlchemyDataStore:
- """
- Create a new SQLAlchemy data store.
-
- :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine`
- :param kwargs: keyword arguments to pass to the initializer of this class
- :return: the newly created data store
-
- """
- engine = create_engine(url)
- return cls(engine, **kwargs)
-
- def _retry(self) -> tenacity.Retrying:
- return tenacity.Retrying(
- stop=self.retry_settings.stop,
- wait=self.retry_settings.wait,
- retry=tenacity.retry_if_exception_type(OperationalError),
- after=self._after_attempt,
- reraise=True,
- )
-
- def start(self, event_broker: EventBroker) -> None:
- super().start(event_broker)
-
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- if self.start_from_scratch:
- for table in self._metadata.sorted_tables:
- conn.execute(DropTable(table, if_exists=True))
-
- self._metadata.create_all(conn)
- query = select(self.t_metadata.c.schema_version)
- result = conn.execute(query)
- version = result.scalar()
- if version is None:
- conn.execute(self.t_metadata.insert(values={"schema_version": 1}))
- elif version > 1:
- raise RuntimeError(
- f"Unexpected schema version ({version}); "
- f"only version 1 is supported by this version of APScheduler"
- )
-
- def add_task(self, task: Task) -> None:
+ async def add_task(self, task: Task) -> None:
insert = self.t_tasks.insert().values(
id=task.id,
func=callable_to_ref(task.func),
+ executor=task.executor,
max_running_jobs=task.max_running_jobs,
misfire_grace_time=task.misfire_grace_time,
)
try:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- conn.execute(insert)
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ await self._execute(conn, insert)
except IntegrityError:
update = (
self.t_tasks.update()
.values(
func=callable_to_ref(task.func),
+ executor=task.executor,
max_running_jobs=task.max_running_jobs,
misfire_grace_time=task.misfire_grace_time,
)
.where(self.t_tasks.c.id == task.id)
)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- conn.execute(update)
- self._events.publish(TaskUpdated(task_id=task.id))
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ await self._execute(conn, update)
+
+ await self._event_broker.publish(TaskUpdated(task_id=task.id))
else:
- self._events.publish(TaskAdded(task_id=task.id))
+ await self._event_broker.publish(TaskAdded(task_id=task.id))
- def remove_task(self, task_id: str) -> None:
+ async def remove_task(self, task_id: str) -> None:
delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- result = conn.execute(delete)
- if result.rowcount == 0:
- raise TaskLookupError(task_id)
- else:
- self._events.publish(TaskRemoved(task_id=task_id))
-
- def get_task(self, task_id: str) -> Task:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ result = await self._execute(conn, delete)
+ if result.rowcount == 0:
+ raise TaskLookupError(task_id)
+ else:
+ await self._event_broker.publish(TaskRemoved(task_id=task_id))
+
+ async def get_task(self, task_id: str) -> Task:
query = select(
[
self.t_tasks.c.id,
self.t_tasks.c.func,
+ self.t_tasks.c.executor,
self.t_tasks.c.max_running_jobs,
self.t_tasks.c.state,
self.t_tasks.c.misfire_grace_time,
]
).where(self.t_tasks.c.id == task_id)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- result = conn.execute(query)
- row = result.first()
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ result = await self._execute(conn, query)
+ row = result.first()
if row:
return Task.unmarshal(self.serializer, row._asdict())
else:
raise TaskLookupError(task_id)
- def get_tasks(self) -> list[Task]:
+ async def get_tasks(self) -> list[Task]:
query = select(
[
self.t_tasks.c.id,
self.t_tasks.c.func,
+ self.t_tasks.c.executor,
self.t_tasks.c.max_running_jobs,
self.t_tasks.c.state,
self.t_tasks.c.misfire_grace_time,
]
).order_by(self.t_tasks.c.id)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- result = conn.execute(query)
- tasks = [
- Task.unmarshal(self.serializer, row._asdict()) for row in result
- ]
- return tasks
-
- def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
- event: Event
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ result = await self._execute(conn, query)
+ tasks = [
+ Task.unmarshal(self.serializer, row._asdict()) for row in result
+ ]
+ return tasks
+
+ async def add_schedule(
+ self, schedule: Schedule, conflict_policy: ConflictPolicy
+ ) -> None:
+ event: DataStoreEvent
values = schedule.marshal(self.serializer)
insert = self.t_schedules.insert().values(**values)
try:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- conn.execute(insert)
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ await self._execute(conn, insert)
except IntegrityError:
if conflict_policy is ConflictPolicy.exception:
raise ConflictingIdError(schedule.id) from None
@@ -387,185 +443,197 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore):
.where(self.t_schedules.c.id == schedule.id)
.values(**values)
)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- conn.execute(update)
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ await self._execute(conn, update)
event = ScheduleUpdated(
schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
else:
event = ScheduleAdded(
schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
)
- self._events.publish(event)
-
- def remove_schedules(self, ids: Iterable[str]) -> None:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- delete = self.t_schedules.delete().where(self.t_schedules.c.id.in_(ids))
- if self._supports_update_returning:
- delete = delete.returning(self.t_schedules.c.id)
- removed_ids: Iterable[str] = [
- row[0] for row in conn.execute(delete)
- ]
- else:
- # TODO: actually check which rows were deleted?
- conn.execute(delete)
- removed_ids = ids
+ await self._event_broker.publish(event)
+
+ async def remove_schedules(self, ids: Iterable[str]) -> None:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ delete = self.t_schedules.delete().where(
+ self.t_schedules.c.id.in_(ids)
+ )
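+ # Use DELETE ... RETURNING where available so removal events are
+ # only published for rows that actually existed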
+ if self._supports_update_returning:
+ delete = delete.returning(self.t_schedules.c.id)
+ removed_ids: Iterable[str] = [
+ row[0] for row in await self._execute(conn, delete)
+ ]
+ else:
+ # TODO: actually check which rows were deleted?
+ await self._execute(conn, delete)
+ removed_ids = ids
for schedule_id in removed_ids:
- self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
+ await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id))
- def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
+ async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
query = self.t_schedules.select().order_by(self.t_schedules.c.id)
if ids:
query = query.where(self.t_schedules.c.id.in_(ids))
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- result = conn.execute(query)
- return self._deserialize_schedules(result)
-
- def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- now = datetime.now(timezone.utc)
- acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
- schedules_cte = (
- select(self.t_schedules.c.id)
- .where(
- and_(
- self.t_schedules.c.next_fire_time.isnot(None),
- self.t_schedules.c.next_fire_time <= now,
- or_(
- self.t_schedules.c.acquired_until.is_(None),
- self.t_schedules.c.acquired_until < now,
- ),
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ result = await self._execute(conn, query)
+ return await self._deserialize_schedules(result)
+
+ async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ now = datetime.now(timezone.utc)
+ acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
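+ # Find schedules that are due to run and not under an unexpired
+ # claim; SKIP LOCKED keeps concurrent schedulers from blocking on
+ # rows another scheduler is already processing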
+ schedules_cte = (
+ select(self.t_schedules.c.id)
+ .where(
+ and_(
+ self.t_schedules.c.next_fire_time.isnot(None),
+ self.t_schedules.c.next_fire_time <= now,
+ or_(
+ self.t_schedules.c.acquired_until.is_(None),
+ self.t_schedules.c.acquired_until < now,
+ ),
+ )
)
+ .order_by(self.t_schedules.c.next_fire_time)
+ .limit(limit)
+ .with_for_update(skip_locked=True)
+ .cte()
)
- .order_by(self.t_schedules.c.next_fire_time)
- .limit(limit)
- .with_for_update(skip_locked=True)
- .cte()
- )
- subselect = select([schedules_cte.c.id])
- update = (
- self.t_schedules.update()
- .where(self.t_schedules.c.id.in_(subselect))
- .values(acquired_by=scheduler_id, acquired_until=acquired_until)
- )
- if self._supports_update_returning:
- update = update.returning(*self.t_schedules.columns)
- result = conn.execute(update)
- else:
- conn.execute(update)
- query = self.t_schedules.select().where(
- and_(self.t_schedules.c.acquired_by == scheduler_id)
+ subselect = select([schedules_cte.c.id])
+ update = (
+ self.t_schedules.update()
+ .where(self.t_schedules.c.id.in_(subselect))
+ .values(acquired_by=scheduler_id, acquired_until=acquired_until)
)
- result = conn.execute(query)
+ if self._supports_update_returning:
+ update = update.returning(*self.t_schedules.columns)
+ result = await self._execute(conn, update)
+ else:
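+ # No UPDATE ... RETURNING support; fetch the rows this
+ # scheduler just claimed with a separate SELECT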
+ await self._execute(conn, update)
+ query = self.t_schedules.select().where(
+ and_(self.t_schedules.c.acquired_by == scheduler_id)
+ )
+ result = await self._execute(conn, query)
- schedules = self._deserialize_schedules(result)
+ schedules = await self._deserialize_schedules(result)
return schedules
- def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- update_events: list[ScheduleUpdated] = []
- finished_schedule_ids: list[str] = []
- update_args: list[dict[str, Any]] = []
- for schedule in schedules:
- if schedule.next_fire_time is not None:
- try:
- serialized_trigger = self.serializer.serialize(
- schedule.trigger
- )
- except SerializationError:
- self._logger.exception(
- "Error serializing trigger for schedule %r – "
- "removing from data store",
- schedule.id,
+ async def release_schedules(
+ self, scheduler_id: str, schedules: list[Schedule]
+ ) -> None:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ update_events: list[ScheduleUpdated] = []
+ finished_schedule_ids: list[str] = []
+ update_args: list[dict[str, Any]] = []
+ for schedule in schedules:
+ if schedule.next_fire_time is not None:
+ try:
+ serialized_trigger = self.serializer.serialize(
+ schedule.trigger
+ )
+ except SerializationError:
+ self._logger.exception(
+ "Error serializing trigger for schedule %r – "
+ "removing from data store",
+ schedule.id,
+ )
+ finished_schedule_ids.append(schedule.id)
+ continue
+
+ update_args.append(
+ {
+ "p_id": schedule.id,
+ "p_trigger": serialized_trigger,
+ "p_next_fire_time": schedule.next_fire_time,
+ }
)
+ else:
finished_schedule_ids.append(schedule.id)
- continue
-
- update_args.append(
- {
- "p_id": schedule.id,
- "p_trigger": serialized_trigger,
- "p_next_fire_time": schedule.next_fire_time,
- }
- )
- else:
- finished_schedule_ids.append(schedule.id)
- # Update schedules that have a next fire time
- if update_args:
- p_id: BindParameter = bindparam("p_id")
- p_trigger: BindParameter = bindparam("p_trigger")
- p_next_fire_time: BindParameter = bindparam("p_next_fire_time")
- update = (
- self.t_schedules.update()
- .where(
- and_(
- self.t_schedules.c.id == p_id,
- self.t_schedules.c.acquired_by == scheduler_id,
+ # Update schedules that have a next fire time
+ if update_args:
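+ # Bind each schedule's id, trigger and next fire time to one
+ # parametrized UPDATE, executed once per schedule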
+ p_id: BindParameter = bindparam("p_id")
+ p_trigger: BindParameter = bindparam("p_trigger")
+ p_next_fire_time: BindParameter = bindparam("p_next_fire_time")
+ update = (
+ self.t_schedules.update()
+ .where(
+ and_(
+ self.t_schedules.c.id == p_id,
+ self.t_schedules.c.acquired_by == scheduler_id,
+ )
+ )
+ .values(
+ trigger=p_trigger,
+ next_fire_time=p_next_fire_time,
+ acquired_by=None,
+ acquired_until=None,
)
)
- .values(
- trigger=p_trigger,
- next_fire_time=p_next_fire_time,
- acquired_by=None,
- acquired_until=None,
- )
- )
- next_fire_times = {
- arg["p_id"]: arg["p_next_fire_time"] for arg in update_args
- }
- # TODO: actually check which rows were updated?
- conn.execute(update, update_args)
- updated_ids = list(next_fire_times)
-
- for schedule_id in updated_ids:
- event = ScheduleUpdated(
- schedule_id=schedule_id,
- next_fire_time=next_fire_times[schedule_id],
- )
- update_events.append(event)
+ next_fire_times = {
+ arg["p_id"]: arg["p_next_fire_time"] for arg in update_args
+ }
+ # TODO: actually check which rows were updated?
+ await self._execute(conn, update, update_args)
+ updated_ids = list(next_fire_times)
+
+ for schedule_id in updated_ids:
+ event = ScheduleUpdated(
+ schedule_id=schedule_id,
+ next_fire_time=next_fire_times[schedule_id],
+ )
+ update_events.append(event)
- # Remove schedules that have no next fire time or failed to serialize
- if finished_schedule_ids:
- delete = self.t_schedules.delete().where(
- self.t_schedules.c.id.in_(finished_schedule_ids)
- )
- conn.execute(delete)
+ # Remove schedules that have no next fire time or failed to
+ # serialize
+ if finished_schedule_ids:
+ delete = self.t_schedules.delete().where(
+ self.t_schedules.c.id.in_(finished_schedule_ids)
+ )
+ await self._execute(conn, delete)
for event in update_events:
- self._events.publish(event)
+ await self._event_broker.publish(event)
for schedule_id in finished_schedule_ids:
- self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
+ await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id))
- def get_next_schedule_run_time(self) -> datetime | None:
- query = (
+ async def get_next_schedule_run_time(self) -> datetime | None:
+ statement = (
select(self.t_schedules.c.next_fire_time)
.where(self.t_schedules.c.next_fire_time.isnot(None))
.order_by(self.t_schedules.c.next_fire_time)
.limit(1)
)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- result = conn.execute(query)
- return result.scalar()
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ result = await self._execute(conn, statement)
+ return result.scalar()
- def add_job(self, job: Job) -> None:
+ async def add_job(self, job: Job) -> None:
marshalled = job.marshal(self.serializer)
insert = self.t_jobs.insert().values(**marshalled)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- conn.execute(insert)
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ await self._execute(conn, insert)
event = JobAdded(
job_id=job.id,
@@ -573,139 +641,156 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore):
schedule_id=job.schedule_id,
tags=job.tags,
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
+ async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
query = self.t_jobs.select().order_by(self.t_jobs.c.id)
if ids:
job_ids = [job_id for job_id in ids]
query = query.where(self.t_jobs.c.id.in_(job_ids))
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- result = conn.execute(query)
- return self._deserialize_jobs(result)
-
- def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- now = datetime.now(timezone.utc)
- acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
- query = (
- self.t_jobs.select()
- .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id)
- .where(
- or_(
- self.t_jobs.c.acquired_until.is_(None),
- self.t_jobs.c.acquired_until < now,
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ result = await self._execute(conn, query)
+ return await self._deserialize_jobs(result)
+
+ async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ now = datetime.now(timezone.utc)
+ acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
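+ # Look up jobs that are unclaimed or whose worker lock has
+ # expired, oldest first; SKIP LOCKED avoids blocking on rows
+ # already locked by another worker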
+ query = (
+ self.t_jobs.select()
+ .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id)
+ .where(
+ or_(
+ self.t_jobs.c.acquired_until.is_(None),
+ self.t_jobs.c.acquired_until < now,
+ )
)
+ .order_by(self.t_jobs.c.created_at)
+ .with_for_update(skip_locked=True)
+ .limit(limit)
)
- .order_by(self.t_jobs.c.created_at)
- .with_for_update(skip_locked=True)
- .limit(limit)
- )
-
- result = conn.execute(query)
- if not result:
- return []
- # Mark the jobs as acquired by this worker
- jobs = self._deserialize_jobs(result)
- task_ids: set[str] = {job.task_id for job in jobs}
-
- # Retrieve the limits
- query = select(
- [
- self.t_tasks.c.id,
- self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs,
- ]
- ).where(
- self.t_tasks.c.max_running_jobs.isnot(None),
- self.t_tasks.c.id.in_(task_ids),
- )
- result = conn.execute(query)
- job_slots_left = dict(result.fetchall())
-
- # Filter out jobs that don't have free slots
- acquired_jobs: list[Job] = []
- increments: dict[str, int] = defaultdict(lambda: 0)
- for job in jobs:
- # Don't acquire the job if there are no free slots left
- slots_left = job_slots_left.get(job.task_id)
- if slots_left == 0:
- continue
- elif slots_left is not None:
- job_slots_left[job.task_id] -= 1
-
- acquired_jobs.append(job)
- increments[job.task_id] += 1
-
- if acquired_jobs:
- # Mark the acquired jobs as acquired by this worker
- acquired_job_ids = [job.id for job in acquired_jobs]
- update = (
- self.t_jobs.update()
- .values(acquired_by=worker_id, acquired_until=acquired_until)
- .where(self.t_jobs.c.id.in_(acquired_job_ids))
+ result = await self._execute(conn, query)
+ if not result:
+ return []
+
+ # Deserialize the candidate jobs and collect their task IDs
+ jobs = await self._deserialize_jobs(result)
+ task_ids: set[str] = {job.task_id for job in jobs}
+
+ # Retrieve how many more jobs each of these tasks may still run
+ # (only tasks that define max_running_jobs are considered)
+ query = select(
+ [
+ self.t_tasks.c.id,
+ self.t_tasks.c.max_running_jobs
+ - self.t_tasks.c.running_jobs,
+ ]
+ ).where(
+ self.t_tasks.c.max_running_jobs.isnot(None),
+ self.t_tasks.c.id.in_(task_ids),
)
- conn.execute(update)
-
- # Increment the running job counters on each task
- p_id: BindParameter = bindparam("p_id")
- p_increment: BindParameter = bindparam("p_increment")
- params = [
- {"p_id": task_id, "p_increment": increment}
- for task_id, increment in increments.items()
- ]
- update = (
- self.t_tasks.update()
- .values(running_jobs=self.t_tasks.c.running_jobs + p_increment)
- .where(self.t_tasks.c.id == p_id)
- )
- conn.execute(update, params)
+ result = await self._execute(conn, query)
+ job_slots_left: dict[str, int] = dict(result.fetchall())
+
+ # Filter out jobs that don't have free slots
+ acquired_jobs: list[Job] = []
+ increments: dict[str, int] = defaultdict(lambda: 0)
+ for job in jobs:
+ # Don't acquire the job if there are no free slots left
+ slots_left = job_slots_left.get(job.task_id)
+ if slots_left == 0:
+ continue
+ elif slots_left is not None:
+ job_slots_left[job.task_id] -= 1
+
+ acquired_jobs.append(job)
+ increments[job.task_id] += 1
+
+ if acquired_jobs:
+ # Mark the acquired jobs as acquired by this worker
+ acquired_job_ids = [job.id for job in acquired_jobs]
+ update = (
+ self.t_jobs.update()
+ .values(
+ acquired_by=worker_id, acquired_until=acquired_until
+ )
+ .where(self.t_jobs.c.id.in_(acquired_job_ids))
+ )
+ await self._execute(conn, update)
+
+ # Increment the running job counters on each task
+ p_id: BindParameter = bindparam("p_id")
+ p_increment: BindParameter = bindparam("p_increment")
+ params = [
+ {"p_id": task_id, "p_increment": increment}
+ for task_id, increment in increments.items()
+ ]
+ update = (
+ self.t_tasks.update()
+ .values(
+ running_jobs=self.t_tasks.c.running_jobs + p_increment
+ )
+ .where(self.t_tasks.c.id == p_id)
+ )
+ await self._execute(conn, update, params)
# Publish the appropriate events
for job in acquired_jobs:
- self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id))
+ await self._event_broker.publish(
+ JobAcquired(job_id=job.id, worker_id=worker_id)
+ )
return acquired_jobs
- def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- # Insert the job result
- if result.expires_at > result.finished_at:
- marshalled = result.marshal(self.serializer)
- insert = self.t_job_results.insert().values(**marshalled)
- conn.execute(insert)
+ async def release_job(
+ self, worker_id: str, task_id: str, result: JobResult
+ ) -> None:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ # Record the job result, unless it expires immediately
+ if result.expires_at > result.finished_at:
+ marshalled = result.marshal(self.serializer)
+ insert = self.t_job_results.insert().values(**marshalled)
+ await self._execute(conn, insert)
+
+ # Decrement the number of running jobs for this task
+ update = (
+ self.t_tasks.update()
+ .values(running_jobs=self.t_tasks.c.running_jobs - 1)
+ .where(self.t_tasks.c.id == task_id)
+ )
+ await self._execute(conn, update)
- # Decrement the running jobs counter
- update = (
- self.t_tasks.update()
- .values(running_jobs=self.t_tasks.c.running_jobs - 1)
- .where(self.t_tasks.c.id == task_id)
- )
- conn.execute(update)
-
- # Delete the job
- delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id)
- conn.execute(delete)
-
- def get_job_result(self, job_id: UUID) -> JobResult | None:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- # Retrieve the result
- query = self.t_job_results.select().where(
- self.t_job_results.c.job_id == job_id
- )
- row = conn.execute(query).first()
+ # Delete the job
+ delete = self.t_jobs.delete().where(
+ self.t_jobs.c.id == result.job_id
+ )
+ await self._execute(conn, delete)
+
+ async def get_job_result(self, job_id: UUID) -> JobResult | None:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ # Retrieve the result
+ query = self.t_job_results.select().where(
+ self.t_job_results.c.job_id == job_id
+ )
+ row = (await self._execute(conn, query)).first()
- # Delete the result
- delete = self.t_job_results.delete().where(
- self.t_job_results.c.job_id == job_id
- )
- conn.execute(delete)
+ # Delete the result; job results can only be fetched once
+ delete = self.t_job_results.delete().where(
+ self.t_job_results.c.job_id == job_id
+ )
+ await self._execute(conn, delete)
- return (
- JobResult.unmarshal(self.serializer, row._asdict()) if row else None
- )
+ return (
+ JobResult.unmarshal(self.serializer, row._asdict())
+ if row
+ else None
+ )