Diffstat (limited to 'src/apscheduler/datastores/async_sqlalchemy.py')
-rw-r--r-- | src/apscheduler/datastores/async_sqlalchemy.py | 602
1 file changed, 0 insertions(+), 602 deletions(-)
diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py
deleted file mode 100644
index 0c165b8..0000000
--- a/src/apscheduler/datastores/async_sqlalchemy.py
+++ /dev/null
@@ -1,602 +0,0 @@
-from __future__ import annotations
-
-from collections import defaultdict
-from datetime import datetime, timedelta, timezone
-from typing import Any, Iterable
-from uuid import UUID
-
-import anyio
-import attrs
-import sniffio
-import tenacity
-from sqlalchemy import and_, bindparam, or_, select
-from sqlalchemy.engine import URL, Result
-from sqlalchemy.exc import IntegrityError, InterfaceError
-from sqlalchemy.ext.asyncio import create_async_engine
-from sqlalchemy.ext.asyncio.engine import AsyncEngine
-from sqlalchemy.sql.ddl import DropTable
-from sqlalchemy.sql.elements import BindParameter
-
-from .._enums import ConflictPolicy
-from .._events import (
-    DataStoreEvent,
-    JobAcquired,
-    JobAdded,
-    JobDeserializationFailed,
-    ScheduleAdded,
-    ScheduleDeserializationFailed,
-    ScheduleRemoved,
-    ScheduleUpdated,
-    TaskAdded,
-    TaskRemoved,
-    TaskUpdated,
-)
-from .._exceptions import ConflictingIdError, SerializationError, TaskLookupError
-from .._structures import Job, JobResult, Schedule, Task
-from ..abc import AsyncEventBroker
-from ..marshalling import callable_to_ref
-from .base import BaseAsyncDataStore
-from .sqlalchemy import _BaseSQLAlchemyDataStore
-
-
-@attrs.define(eq=False)
-class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseAsyncDataStore):
-    """
-    Uses a relational database to store data.
-
-    When started, this data store creates the appropriate tables on the given database
-    if they're not already present.
-
-    Operations are retried (in accordance with ``retry_settings``) when an operation
-    raises :exc:`sqlalchemy.exc.InterfaceError` or :exc:`OSError`.
-
-    This store has been tested to work with PostgreSQL (asyncpg driver) and MySQL
-    (asyncmy driver).
-
-    :param engine: an asynchronous SQLAlchemy engine
-    :param schema: a database schema name to use, if not the default
-    :param serializer: the serializer used to (de)serialize tasks, schedules and jobs
-        for storage
-    :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler
-        or worker can keep a lock on a schedule or task
-    :param retry_settings: Tenacity settings for retrying operations in case of a
-        database connectivity problem
-    :param start_from_scratch: erase all existing data during startup (useful for test
-        suites)
-    """
-
-    engine: AsyncEngine
-
-    @classmethod
-    def from_url(cls, url: str | URL, **options) -> AsyncSQLAlchemyDataStore:
-        """
-        Create a new asynchronous SQLAlchemy data store.
-
-        :param url: an SQLAlchemy URL to pass to
-            :func:`~sqlalchemy.ext.asyncio.create_async_engine` (must use an async
-            dialect like ``asyncpg`` or ``asyncmy``)
-        :param options: keyword arguments to pass to the initializer of this class
-        :return: the newly created data store
-
-        """
-        engine = create_async_engine(url, future=True)
-        return cls(engine, **options)
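
[Editor's note] A minimal usage sketch for from_url (the DSN below is illustrative; per the docstring above, asyncpg and asyncmy are the tested drivers):

    from apscheduler.datastores.async_sqlalchemy import AsyncSQLAlchemyDataStore

    # Hypothetical connection URL; swap in your own credentials and database
    store = AsyncSQLAlchemyDataStore.from_url(
        "postgresql+asyncpg://user:secret@localhost/scheduler"
    )
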
-
-    def _retry(self) -> tenacity.AsyncRetrying:
-        # OSError is raised by asyncpg if it can't connect
-        return tenacity.AsyncRetrying(
-            stop=self.retry_settings.stop,
-            wait=self.retry_settings.wait,
-            retry=tenacity.retry_if_exception_type((InterfaceError, OSError)),
-            after=self._after_attempt,
-            sleep=anyio.sleep,
-            reraise=True,
-        )
-
-    async def start(self, event_broker: AsyncEventBroker) -> None:
-        await super().start(event_broker)
-
-        asynclib = sniffio.current_async_library() or "(unknown)"
-        if asynclib != "asyncio":
-            raise RuntimeError(
-                f"This data store requires asyncio; currently running: {asynclib}"
-            )
-
-        # Verify that the schema is in place
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    if self.start_from_scratch:
-                        for table in self._metadata.sorted_tables:
-                            await conn.execute(DropTable(table, if_exists=True))
-
-                    await conn.run_sync(self._metadata.create_all)
-                    query = select(self.t_metadata.c.schema_version)
-                    result = await conn.execute(query)
-                    version = result.scalar()
-                    if version is None:
-                        await conn.execute(
-                            self.t_metadata.insert(values={"schema_version": 1})
-                        )
-                    elif version > 1:
-                        raise RuntimeError(
-                            f"Unexpected schema version ({version}); "
-                            f"only version 1 is supported by this version of "
-                            f"APScheduler"
-                        )
-
-    async def _deserialize_schedules(self, result: Result) -> list[Schedule]:
-        schedules: list[Schedule] = []
-        for row in result:
-            try:
-                schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
-            except SerializationError as exc:
-                await self._events.publish(
-                    ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc)
-                )
-
-        return schedules
-
-    async def _deserialize_jobs(self, result: Result) -> list[Job]:
-        jobs: list[Job] = []
-        for row in result:
-            try:
-                jobs.append(Job.unmarshal(self.serializer, row._asdict()))
-            except SerializationError as exc:
-                await self._events.publish(
-                    JobDeserializationFailed(job_id=row["id"], exception=exc)
-                )
-
-        return jobs
-
-    async def add_task(self, task: Task) -> None:
-        insert = self.t_tasks.insert().values(
-            id=task.id,
-            func=callable_to_ref(task.func),
-            max_running_jobs=task.max_running_jobs,
-            misfire_grace_time=task.misfire_grace_time,
-        )
-        try:
-            async for attempt in self._retry():
-                with attempt:
-                    async with self.engine.begin() as conn:
-                        await conn.execute(insert)
-        except IntegrityError:
-            update = (
-                self.t_tasks.update()
-                .values(
-                    func=callable_to_ref(task.func),
-                    max_running_jobs=task.max_running_jobs,
-                    misfire_grace_time=task.misfire_grace_time,
-                )
-                .where(self.t_tasks.c.id == task.id)
-            )
-            async for attempt in self._retry():
-                with attempt:
-                    async with self.engine.begin() as conn:
-                        await conn.execute(update)
-
-            await self._events.publish(TaskUpdated(task_id=task.id))
-        else:
-            await self._events.publish(TaskAdded(task_id=task.id))
-
-    async def remove_task(self, task_id: str) -> None:
-        delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id)
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    result = await conn.execute(delete)
-                    if result.rowcount == 0:
-                        raise TaskLookupError(task_id)
-                    else:
-                        await self._events.publish(TaskRemoved(task_id=task_id))
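
[Editor's note] Every public method below wraps its database work in the retry loop built by _retry(). A self-contained sketch of that tenacity pattern (the stop/wait policies here are illustrative; the store takes them from retry_settings):

    import anyio
    import tenacity

    async def run_with_retry() -> None:
        async for attempt in tenacity.AsyncRetrying(
            stop=tenacity.stop_after_attempt(3),
            wait=tenacity.wait_exponential(),
            retry=tenacity.retry_if_exception_type(OSError),
            sleep=anyio.sleep,  # sleep via anyio, as _retry() does
            reraise=True,
        ):
            with attempt:
                ...  # the fallible database operation goes here
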
-
-    async def get_task(self, task_id: str) -> Task:
-        query = select(
-            [
-                self.t_tasks.c.id,
-                self.t_tasks.c.func,
-                self.t_tasks.c.max_running_jobs,
-                self.t_tasks.c.state,
-                self.t_tasks.c.misfire_grace_time,
-            ]
-        ).where(self.t_tasks.c.id == task_id)
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    result = await conn.execute(query)
-                    row = result.first()
-
-        if row:
-            return Task.unmarshal(self.serializer, row._asdict())
-        else:
-            raise TaskLookupError(task_id)
-
-    async def get_tasks(self) -> list[Task]:
-        query = select(
-            [
-                self.t_tasks.c.id,
-                self.t_tasks.c.func,
-                self.t_tasks.c.max_running_jobs,
-                self.t_tasks.c.state,
-                self.t_tasks.c.misfire_grace_time,
-            ]
-        ).order_by(self.t_tasks.c.id)
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    result = await conn.execute(query)
-                    tasks = [
-                        Task.unmarshal(self.serializer, row._asdict()) for row in result
-                    ]
-                    return tasks
-
-    async def add_schedule(
-        self, schedule: Schedule, conflict_policy: ConflictPolicy
-    ) -> None:
-        event: DataStoreEvent
-        values = schedule.marshal(self.serializer)
-        insert = self.t_schedules.insert().values(**values)
-        try:
-            async for attempt in self._retry():
-                with attempt:
-                    async with self.engine.begin() as conn:
-                        await conn.execute(insert)
-        except IntegrityError:
-            if conflict_policy is ConflictPolicy.exception:
-                raise ConflictingIdError(schedule.id) from None
-            elif conflict_policy is ConflictPolicy.replace:
-                del values["id"]
-                update = (
-                    self.t_schedules.update()
-                    .where(self.t_schedules.c.id == schedule.id)
-                    .values(**values)
-                )
-                async for attempt in self._retry():
-                    with attempt:
-                        async with self.engine.begin() as conn:
-                            await conn.execute(update)
-
-                event = ScheduleUpdated(
-                    schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
-                )
-                await self._events.publish(event)
-        else:
-            event = ScheduleAdded(
-                schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
-            )
-            await self._events.publish(event)
-
-    async def remove_schedules(self, ids: Iterable[str]) -> None:
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    delete = self.t_schedules.delete().where(
-                        self.t_schedules.c.id.in_(ids)
-                    )
-                    if self._supports_update_returning:
-                        delete = delete.returning(self.t_schedules.c.id)
-                        removed_ids: Iterable[str] = [
-                            row[0] for row in await conn.execute(delete)
-                        ]
-                    else:
-                        # TODO: actually check which rows were deleted?
-                        await conn.execute(delete)
-                        removed_ids = ids
-
-        for schedule_id in removed_ids:
-            await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
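
[Editor's note] Caller-side sketch of the conflict policies handled by add_schedule (assumes a started store and a prepared schedule; the import paths mirror this module's own relative imports):

    from apscheduler._enums import ConflictPolicy
    from apscheduler._exceptions import ConflictingIdError

    async def upsert_schedule(store, schedule) -> None:
        try:
            # Raises ConflictingIdError when the schedule ID already exists
            await store.add_schedule(schedule, ConflictPolicy.exception)
        except ConflictingIdError:
            # Overwrite the existing schedule instead
            await store.add_schedule(schedule, ConflictPolicy.replace)
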
-
-    async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
-        query = self.t_schedules.select().order_by(self.t_schedules.c.id)
-        if ids:
-            query = query.where(self.t_schedules.c.id.in_(ids))
-
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    result = await conn.execute(query)
-                    return await self._deserialize_schedules(result)
-
-    async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    now = datetime.now(timezone.utc)
-                    acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
-                    schedules_cte = (
-                        select(self.t_schedules.c.id)
-                        .where(
-                            and_(
-                                self.t_schedules.c.next_fire_time.isnot(None),
-                                self.t_schedules.c.next_fire_time <= now,
-                                or_(
-                                    self.t_schedules.c.acquired_until.is_(None),
-                                    self.t_schedules.c.acquired_until < now,
-                                ),
-                            )
-                        )
-                        .order_by(self.t_schedules.c.next_fire_time)
-                        .limit(limit)
-                        .with_for_update(skip_locked=True)
-                        .cte()
-                    )
-                    subselect = select([schedules_cte.c.id])
-                    update = (
-                        self.t_schedules.update()
-                        .where(self.t_schedules.c.id.in_(subselect))
-                        .values(acquired_by=scheduler_id, acquired_until=acquired_until)
-                    )
-                    if self._supports_update_returning:
-                        update = update.returning(*self.t_schedules.columns)
-                        result = await conn.execute(update)
-                    else:
-                        await conn.execute(update)
-                        query = self.t_schedules.select().where(
-                            and_(self.t_schedules.c.acquired_by == scheduler_id)
-                        )
-                        result = await conn.execute(query)
-
-                    schedules = await self._deserialize_schedules(result)
-
-        return schedules
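
[Editor's note] acquire_schedules claims due rows with an UPDATE whose row set comes from a SELECT ... FOR UPDATE SKIP LOCKED CTE, so concurrent schedulers skip rows already locked by another transaction instead of blocking on them. A stripped-down sketch of the same construct against a hypothetical table t:

    from sqlalchemy import select

    def build_acquire_update(t, scheduler_id: str, limit: int):
        # Pick claimable rows, skipping any locked by other transactions
        cte = (
            select(t.c.id)
            .where(t.c.acquired_until.is_(None))
            .limit(limit)
            .with_for_update(skip_locked=True)
            .cte()
        )
        # Stamp the picked rows as owned by this scheduler
        return (
            t.update()
            .where(t.c.id.in_(select(cte.c.id)))
            .values(acquired_by=scheduler_id)
        )
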
-
-    async def release_schedules(
-        self, scheduler_id: str, schedules: list[Schedule]
-    ) -> None:
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    update_events: list[ScheduleUpdated] = []
-                    finished_schedule_ids: list[str] = []
-                    update_args: list[dict[str, Any]] = []
-                    for schedule in schedules:
-                        if schedule.next_fire_time is not None:
-                            try:
-                                serialized_trigger = self.serializer.serialize(
-                                    schedule.trigger
-                                )
-                            except SerializationError:
-                                self._logger.exception(
-                                    "Error serializing trigger for schedule %r – "
-                                    "removing from data store",
-                                    schedule.id,
-                                )
-                                finished_schedule_ids.append(schedule.id)
-                                continue
-
-                            update_args.append(
-                                {
-                                    "p_id": schedule.id,
-                                    "p_trigger": serialized_trigger,
-                                    "p_next_fire_time": schedule.next_fire_time,
-                                }
-                            )
-                        else:
-                            finished_schedule_ids.append(schedule.id)
-
-                    # Update schedules that have a next fire time
-                    if update_args:
-                        p_id: BindParameter = bindparam("p_id")
-                        p_trigger: BindParameter = bindparam("p_trigger")
-                        p_next_fire_time: BindParameter = bindparam("p_next_fire_time")
-                        update = (
-                            self.t_schedules.update()
-                            .where(
-                                and_(
-                                    self.t_schedules.c.id == p_id,
-                                    self.t_schedules.c.acquired_by == scheduler_id,
-                                )
-                            )
-                            .values(
-                                trigger=p_trigger,
-                                next_fire_time=p_next_fire_time,
-                                acquired_by=None,
-                                acquired_until=None,
-                            )
-                        )
-                        next_fire_times = {
-                            arg["p_id"]: arg["p_next_fire_time"] for arg in update_args
-                        }
-                        # TODO: actually check which rows were updated?
-                        await conn.execute(update, update_args)
-                        updated_ids = list(next_fire_times)
-
-                        for schedule_id in updated_ids:
-                            event = ScheduleUpdated(
-                                schedule_id=schedule_id,
-                                next_fire_time=next_fire_times[schedule_id],
-                            )
-                            update_events.append(event)
-
-                    # Remove schedules that have no next fire time or failed to
-                    # serialize
-                    if finished_schedule_ids:
-                        delete = self.t_schedules.delete().where(
-                            self.t_schedules.c.id.in_(finished_schedule_ids)
-                        )
-                        await conn.execute(delete)
-
-        for event in update_events:
-            await self._events.publish(event)
-
-        for schedule_id in finished_schedule_ids:
-            await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
-
-    async def get_next_schedule_run_time(self) -> datetime | None:
-        statement = (
-            select(self.t_schedules.c.next_fire_time)
-            .where(self.t_schedules.c.next_fire_time.isnot(None))
-            .order_by(self.t_schedules.c.next_fire_time)
-            .limit(1)
-        )
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    result = await conn.execute(statement)
-                    return result.scalar()
-
-    async def add_job(self, job: Job) -> None:
-        marshalled = job.marshal(self.serializer)
-        insert = self.t_jobs.insert().values(**marshalled)
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    await conn.execute(insert)
-
-        event = JobAdded(
-            job_id=job.id,
-            task_id=job.task_id,
-            schedule_id=job.schedule_id,
-            tags=job.tags,
-        )
-        await self._events.publish(event)
-
-    async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
-        query = self.t_jobs.select().order_by(self.t_jobs.c.id)
-        if ids:
-            job_ids = [job_id for job_id in ids]
-            query = query.where(self.t_jobs.c.id.in_(job_ids))
-
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    result = await conn.execute(query)
-                    return await self._deserialize_jobs(result)
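
[Editor's note] release_schedules above updates many rows in one round trip by pairing bindparam placeholders with an executemany-style execution. The same pattern in isolation, again with a hypothetical table t:

    from sqlalchemy import bindparam

    def build_bulk_release(t):
        # Compiled once; executed with one parameter dict per row
        return (
            t.update()
            .where(t.c.id == bindparam("p_id"))
            .values(next_fire_time=bindparam("p_next_fire_time"))
        )

    # await conn.execute(build_bulk_release(t), [
    #     {"p_id": "s1", "p_next_fire_time": time1},
    #     {"p_id": "s2", "p_next_fire_time": time2},
    # ])
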
-
-    async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    now = datetime.now(timezone.utc)
-                    acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
-                    query = (
-                        self.t_jobs.select()
-                        .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id)
-                        .where(
-                            or_(
-                                self.t_jobs.c.acquired_until.is_(None),
-                                self.t_jobs.c.acquired_until < now,
-                            )
-                        )
-                        .order_by(self.t_jobs.c.created_at)
-                        .with_for_update(skip_locked=True)
-                        .limit(limit)
-                    )
-
-                    result = await conn.execute(query)
-                    if not result:
-                        return []
-
-                    # Mark the jobs as acquired by this worker
-                    jobs = await self._deserialize_jobs(result)
-                    task_ids: set[str] = {job.task_id for job in jobs}
-
-                    # Retrieve the limits
-                    query = select(
-                        [
-                            self.t_tasks.c.id,
-                            self.t_tasks.c.max_running_jobs
-                            - self.t_tasks.c.running_jobs,
-                        ]
-                    ).where(
-                        self.t_tasks.c.max_running_jobs.isnot(None),
-                        self.t_tasks.c.id.in_(task_ids),
-                    )
-                    result = await conn.execute(query)
-                    job_slots_left: dict[str, int] = dict(result.fetchall())
-
-                    # Filter out jobs that don't have free slots
-                    acquired_jobs: list[Job] = []
-                    increments: dict[str, int] = defaultdict(lambda: 0)
-                    for job in jobs:
-                        # Don't acquire the job if there are no free slots left
-                        slots_left = job_slots_left.get(job.task_id)
-                        if slots_left == 0:
-                            continue
-                        elif slots_left is not None:
-                            job_slots_left[job.task_id] -= 1
-
-                        acquired_jobs.append(job)
-                        increments[job.task_id] += 1
-
-                    if acquired_jobs:
-                        # Mark the acquired jobs as acquired by this worker
-                        acquired_job_ids = [job.id for job in acquired_jobs]
-                        update = (
-                            self.t_jobs.update()
-                            .values(
-                                acquired_by=worker_id, acquired_until=acquired_until
-                            )
-                            .where(self.t_jobs.c.id.in_(acquired_job_ids))
-                        )
-                        await conn.execute(update)
-
-                        # Increment the running job counters on each task
-                        p_id: BindParameter = bindparam("p_id")
-                        p_increment: BindParameter = bindparam("p_increment")
-                        params = [
-                            {"p_id": task_id, "p_increment": increment}
-                            for task_id, increment in increments.items()
-                        ]
-                        update = (
-                            self.t_tasks.update()
-                            .values(
-                                running_jobs=self.t_tasks.c.running_jobs + p_increment
-                            )
-                            .where(self.t_tasks.c.id == p_id)
-                        )
-                        await conn.execute(update, params)
-
-        # Publish the appropriate events
-        for job in acquired_jobs:
-            await self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id))
-
-        return acquired_jobs
-
-    async def release_job(
-        self, worker_id: str, task_id: str, result: JobResult
-    ) -> None:
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    # Record the job result
-                    if result.expires_at > result.finished_at:
-                        marshalled = result.marshal(self.serializer)
-                        insert = self.t_job_results.insert().values(**marshalled)
-                        await conn.execute(insert)
-
-                    # Decrement the number of running jobs for this task
-                    update = (
-                        self.t_tasks.update()
-                        .values(running_jobs=self.t_tasks.c.running_jobs - 1)
-                        .where(self.t_tasks.c.id == task_id)
-                    )
-                    await conn.execute(update)
-
-                    # Delete the job
-                    delete = self.t_jobs.delete().where(
-                        self.t_jobs.c.id == result.job_id
-                    )
-                    await conn.execute(delete)
-
-    async def get_job_result(self, job_id: UUID) -> JobResult | None:
-        async for attempt in self._retry():
-            with attempt:
-                async with self.engine.begin() as conn:
-                    # Retrieve the result
-                    query = self.t_job_results.select().where(
-                        self.t_job_results.c.job_id == job_id
-                    )
-                    row = (await conn.execute(query)).first()
-
-                    # Delete the result
-                    delete = self.t_job_results.delete().where(
-                        self.t_job_results.c.job_id == job_id
-                    )
-                    await conn.execute(delete)
-
-                    return (
-                        JobResult.unmarshal(self.serializer, row._asdict())
-                        if row
-                        else None
-                    )
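
[Editor's note] A hedged lifecycle sketch tying the pieces together (any concrete AsyncEventBroker implementation works; names other than those imported from this codebase are illustrative):

    from uuid import UUID

    from apscheduler.abc import AsyncEventBroker
    from apscheduler.datastores.async_sqlalchemy import AsyncSQLAlchemyDataStore

    async def fetch_result(broker: AsyncEventBroker, job_id: UUID) -> None:
        store = AsyncSQLAlchemyDataStore.from_url(
            "postgresql+asyncpg://user:secret@localhost/scheduler"
        )
        # The store must be started with an event broker before use
        await store.start(broker)
        # Note: per the code above, get_job_result() also deletes the stored row
        result = await store.get_job_result(job_id)
        print(result)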