from __future__ import annotations

from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Any, Iterable
from uuid import UUID

import anyio
import attrs
import sniffio
import tenacity
from sqlalchemy import and_, bindparam, or_, select
from sqlalchemy.engine import URL, Result
from sqlalchemy.exc import IntegrityError, InterfaceError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.sql.ddl import DropTable
from sqlalchemy.sql.elements import BindParameter

from ..abc import AsyncDataStore, AsyncEventBroker, EventSource, Job, Schedule
from ..enums import ConflictPolicy
from ..eventbrokers.async_local import LocalAsyncEventBroker
from ..events import (
    DataStoreEvent,
    JobAcquired,
    JobAdded,
    JobDeserializationFailed,
    ScheduleAdded,
    ScheduleDeserializationFailed,
    ScheduleRemoved,
    ScheduleUpdated,
    TaskAdded,
    TaskRemoved,
    TaskUpdated,
)
from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError
from ..marshalling import callable_to_ref
from ..structures import JobResult, Task
from ..util import reentrant
from .sqlalchemy import _BaseSQLAlchemyDataStore


@reentrant
@attrs.define(eq=False)
class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
    engine: AsyncEngine
    _events: AsyncEventBroker = attrs.field(factory=LocalAsyncEventBroker)

    @classmethod
    def from_url(cls, url: str | URL, **options) -> AsyncSQLAlchemyDataStore:
        engine = create_async_engine(url, future=True)
        return cls(engine, **options)

    def _retry(self) -> tenacity.AsyncRetrying:
        # OSError is raised by asyncpg if it can't connect
        return tenacity.AsyncRetrying(
            stop=self.retry_settings.stop,
            wait=self.retry_settings.wait,
            retry=tenacity.retry_if_exception_type((InterfaceError, OSError)),
            after=self._after_attempt,
            sleep=anyio.sleep,
            reraise=True,
        )

    async def __aenter__(self):
        asynclib = sniffio.current_async_library() or "(unknown)"
        if asynclib != "asyncio":
            raise RuntimeError(
                f"This data store requires asyncio; currently running: {asynclib}"
            )

        # Verify that the schema is in place
        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    if self.start_from_scratch:
                        for table in self._metadata.sorted_tables:
                            await conn.execute(DropTable(table, if_exists=True))

                    await conn.run_sync(self._metadata.create_all)
                    query = select(self.t_metadata.c.schema_version)
                    result = await conn.execute(query)
                    version = result.scalar()
                    if version is None:
                        await conn.execute(
                            self.t_metadata.insert(values={"schema_version": 1})
                        )
                    elif version > 1:
                        raise RuntimeError(
                            f"Unexpected schema version ({version}); "
                            f"only version 1 is supported by this version of "
                            f"APScheduler"
                        )

        await self._events.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self._events.__aexit__(exc_type, exc_val, exc_tb)

    @property
    def events(self) -> EventSource:
        return self._events
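    # Construction sketch (illustrative only; the asyncpg URL is a
    # hypothetical example). Either wrap an existing AsyncEngine or build one
    # from a URL:
    #
    #     engine = create_async_engine("postgresql+asyncpg://user:pw@host/db")
    #     store = AsyncSQLAlchemyDataStore(engine)
    #     # or, equivalently:
    #     store = AsyncSQLAlchemyDataStore.from_url(
    #         "postgresql+asyncpg://user:pw@host/db"
    #     )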
    async def _deserialize_schedules(self, result: Result) -> list[Schedule]:
        schedules: list[Schedule] = []
        for row in result:
            try:
                schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
            except SerializationError as exc:
                await self._events.publish(
                    ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc)
                )

        return schedules

    async def _deserialize_jobs(self, result: Result) -> list[Job]:
        jobs: list[Job] = []
        for row in result:
            try:
                jobs.append(Job.unmarshal(self.serializer, row._asdict()))
            except SerializationError as exc:
                await self._events.publish(
                    JobDeserializationFailed(job_id=row["id"], exception=exc)
                )

        return jobs

    async def add_task(self, task: Task) -> None:
        insert = self.t_tasks.insert().values(
            id=task.id,
            func=callable_to_ref(task.func),
            max_running_jobs=task.max_running_jobs,
            misfire_grace_time=task.misfire_grace_time,
        )
        try:
            async for attempt in self._retry():
                with attempt:
                    async with self.engine.begin() as conn:
                        await conn.execute(insert)
        except IntegrityError:
            update = (
                self.t_tasks.update()
                .values(
                    func=callable_to_ref(task.func),
                    max_running_jobs=task.max_running_jobs,
                    misfire_grace_time=task.misfire_grace_time,
                )
                .where(self.t_tasks.c.id == task.id)
            )
            async for attempt in self._retry():
                with attempt:
                    async with self.engine.begin() as conn:
                        await conn.execute(update)

            await self._events.publish(TaskUpdated(task_id=task.id))
        else:
            await self._events.publish(TaskAdded(task_id=task.id))

    async def remove_task(self, task_id: str) -> None:
        delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id)
        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    result = await conn.execute(delete)
                    if result.rowcount == 0:
                        raise TaskLookupError(task_id)
                    else:
                        await self._events.publish(TaskRemoved(task_id=task_id))

    async def get_task(self, task_id: str) -> Task:
        query = select(
            [
                self.t_tasks.c.id,
                self.t_tasks.c.func,
                self.t_tasks.c.max_running_jobs,
                self.t_tasks.c.state,
                self.t_tasks.c.misfire_grace_time,
            ]
        ).where(self.t_tasks.c.id == task_id)
        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    result = await conn.execute(query)
                    row = result.first()

        if row:
            return Task.unmarshal(self.serializer, row._asdict())
        else:
            raise TaskLookupError(task_id)

    async def get_tasks(self) -> list[Task]:
        query = select(
            [
                self.t_tasks.c.id,
                self.t_tasks.c.func,
                self.t_tasks.c.max_running_jobs,
                self.t_tasks.c.state,
                self.t_tasks.c.misfire_grace_time,
            ]
        ).order_by(self.t_tasks.c.id)
        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    result = await conn.execute(query)
                    tasks = [
                        Task.unmarshal(self.serializer, row._asdict())
                        for row in result
                    ]
                    return tasks

    async def add_schedule(
        self, schedule: Schedule, conflict_policy: ConflictPolicy
    ) -> None:
        event: DataStoreEvent
        values = schedule.marshal(self.serializer)
        insert = self.t_schedules.insert().values(**values)
        try:
            async for attempt in self._retry():
                with attempt:
                    async with self.engine.begin() as conn:
                        await conn.execute(insert)
        except IntegrityError:
            if conflict_policy is ConflictPolicy.exception:
                raise ConflictingIdError(schedule.id) from None
            elif conflict_policy is ConflictPolicy.replace:
                del values["id"]
                update = (
                    self.t_schedules.update()
                    .where(self.t_schedules.c.id == schedule.id)
                    .values(**values)
                )
                async for attempt in self._retry():
                    with attempt:
                        async with self.engine.begin() as conn:
                            await conn.execute(update)

                event = ScheduleUpdated(
                    schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
                )
                await self._events.publish(event)
        else:
            event = ScheduleAdded(
                schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
            )
            await self._events.publish(event)
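    # Note on the insert-then-update path above: the emulated upsert is not a
    # single atomic statement. A hypothetical PostgreSQL-only alternative
    # would be a native upsert via ON CONFLICT, sketched here:
    #
    #     from sqlalchemy.dialects.postgresql import insert as pg_insert
    #
    #     statement = pg_insert(self.t_schedules).values(**values)
    #     statement = statement.on_conflict_do_update(
    #         index_elements=[self.t_schedules.c.id],
    #         set_={key: val for key, val in values.items() if key != "id"},
    #     )
    #     await conn.execute(statement)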
    async def remove_schedules(self, ids: Iterable[str]) -> None:
        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    delete = self.t_schedules.delete().where(
                        self.t_schedules.c.id.in_(ids)
                    )
                    if self._supports_update_returning:
                        delete = delete.returning(self.t_schedules.c.id)
                        removed_ids: Iterable[str] = [
                            row[0] for row in await conn.execute(delete)
                        ]
                    else:
                        # TODO: actually check which rows were deleted?
                        await conn.execute(delete)
                        removed_ids = ids

        for schedule_id in removed_ids:
            await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))

    async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
        query = self.t_schedules.select().order_by(self.t_schedules.c.id)
        if ids:
            query = query.where(self.t_schedules.c.id.in_(ids))

        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    result = await conn.execute(query)
                    return await self._deserialize_schedules(result)

    async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    now = datetime.now(timezone.utc)
                    acquired_until = now + timedelta(
                        seconds=self.lock_expiration_delay
                    )
                    schedules_cte = (
                        select(self.t_schedules.c.id)
                        .where(
                            and_(
                                self.t_schedules.c.next_fire_time.isnot(None),
                                self.t_schedules.c.next_fire_time <= now,
                                or_(
                                    self.t_schedules.c.acquired_until.is_(None),
                                    self.t_schedules.c.acquired_until < now,
                                ),
                            )
                        )
                        .order_by(self.t_schedules.c.next_fire_time)
                        .limit(limit)
                        .with_for_update(skip_locked=True)
                        .cte()
                    )
                    subselect = select([schedules_cte.c.id])
                    update = (
                        self.t_schedules.update()
                        .where(self.t_schedules.c.id.in_(subselect))
                        .values(acquired_by=scheduler_id, acquired_until=acquired_until)
                    )
                    if self._supports_update_returning:
                        update = update.returning(*self.t_schedules.columns)
                        result = await conn.execute(update)
                    else:
                        await conn.execute(update)
                        query = self.t_schedules.select().where(
                            and_(self.t_schedules.c.acquired_by == scheduler_id)
                        )
                        result = await conn.execute(query)

                    schedules = await self._deserialize_schedules(result)

        return schedules
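    # Rough approximation of the SQL the acquisition above emits on
    # PostgreSQL when UPDATE ... RETURNING is supported (illustrative only;
    # actual table and column names come from _BaseSQLAlchemyDataStore):
    #
    #     WITH candidates AS (
    #         SELECT id FROM schedules
    #         WHERE next_fire_time IS NOT NULL
    #           AND next_fire_time <= :now
    #           AND (acquired_until IS NULL OR acquired_until < :now)
    #         ORDER BY next_fire_time
    #         LIMIT :limit
    #         FOR UPDATE SKIP LOCKED
    #     )
    #     UPDATE schedules
    #     SET acquired_by = :scheduler_id, acquired_until = :acquired_until
    #     WHERE id IN (SELECT id FROM candidates)
    #     RETURNING *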
    async def release_schedules(
        self, scheduler_id: str, schedules: list[Schedule]
    ) -> None:
        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    update_events: list[ScheduleUpdated] = []
                    finished_schedule_ids: list[str] = []
                    update_args: list[dict[str, Any]] = []
                    for schedule in schedules:
                        if schedule.next_fire_time is not None:
                            try:
                                serialized_trigger = self.serializer.serialize(
                                    schedule.trigger
                                )
                            except SerializationError:
                                self._logger.exception(
                                    "Error serializing trigger for schedule %r – "
                                    "removing from data store",
                                    schedule.id,
                                )
                                finished_schedule_ids.append(schedule.id)
                                continue

                            update_args.append(
                                {
                                    "p_id": schedule.id,
                                    "p_trigger": serialized_trigger,
                                    "p_next_fire_time": schedule.next_fire_time,
                                }
                            )
                        else:
                            finished_schedule_ids.append(schedule.id)

                    # Update schedules that have a next fire time
                    if update_args:
                        p_id: BindParameter = bindparam("p_id")
                        p_trigger: BindParameter = bindparam("p_trigger")
                        p_next_fire_time: BindParameter = bindparam("p_next_fire_time")
                        update = (
                            self.t_schedules.update()
                            .where(
                                and_(
                                    self.t_schedules.c.id == p_id,
                                    self.t_schedules.c.acquired_by == scheduler_id,
                                )
                            )
                            .values(
                                trigger=p_trigger,
                                next_fire_time=p_next_fire_time,
                                acquired_by=None,
                                acquired_until=None,
                            )
                        )
                        next_fire_times = {
                            arg["p_id"]: arg["p_next_fire_time"] for arg in update_args
                        }
                        if self._supports_update_returning:
                            update = update.returning(self.t_schedules.c.id)
                            updated_ids = [
                                row[0]
                                for row in await conn.execute(update, update_args)
                            ]
                        else:
                            # TODO: actually check which rows were updated?
                            await conn.execute(update, update_args)
                            updated_ids = list(next_fire_times)

                        for schedule_id in updated_ids:
                            event = ScheduleUpdated(
                                schedule_id=schedule_id,
                                next_fire_time=next_fire_times[schedule_id],
                            )
                            update_events.append(event)

                    # Remove schedules that have no next fire time or failed
                    # to serialize
                    if finished_schedule_ids:
                        delete = self.t_schedules.delete().where(
                            self.t_schedules.c.id.in_(finished_schedule_ids)
                        )
                        await conn.execute(delete)

        for event in update_events:
            await self._events.publish(event)

        for schedule_id in finished_schedule_ids:
            await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))

    async def get_next_schedule_run_time(self) -> datetime | None:
        statement = (
            select(self.t_schedules.c.next_fire_time)
            .where(self.t_schedules.c.next_fire_time.isnot(None))
            .order_by(self.t_schedules.c.next_fire_time)
            .limit(1)
        )
        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    result = await conn.execute(statement)
                    return result.scalar()

    async def add_job(self, job: Job) -> None:
        marshalled = job.marshal(self.serializer)
        insert = self.t_jobs.insert().values(**marshalled)
        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    await conn.execute(insert)

        event = JobAdded(
            job_id=job.id,
            task_id=job.task_id,
            schedule_id=job.schedule_id,
            tags=job.tags,
        )
        await self._events.publish(event)

    async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
        query = self.t_jobs.select().order_by(self.t_jobs.c.id)
        if ids:
            job_ids = [job_id for job_id in ids]
            query = query.where(self.t_jobs.c.id.in_(job_ids))

        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    result = await conn.execute(query)
                    return await self._deserialize_jobs(result)
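    # The acquisition below caps per-task concurrency: each task exposes
    # max_running_jobs - running_jobs free "slots", and every acquired job
    # consumes one. A tiny worked example (hypothetical numbers): with
    # max_running_jobs=2 and running_jobs=1, job_slots_left["mytask"] starts
    # at 1, so only the first pending job of "mytask" is acquired and the
    # rest are left for a later pass.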
    async def acquire_jobs(
        self, worker_id: str, limit: int | None = None
    ) -> list[Job]:
        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    now = datetime.now(timezone.utc)
                    acquired_until = now + timedelta(
                        seconds=self.lock_expiration_delay
                    )
                    query = (
                        self.t_jobs.select()
                        .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id)
                        .where(
                            or_(
                                self.t_jobs.c.acquired_until.is_(None),
                                self.t_jobs.c.acquired_until < now,
                            )
                        )
                        .order_by(self.t_jobs.c.created_at)
                        .with_for_update(skip_locked=True)
                        .limit(limit)
                    )

                    result = await conn.execute(query)
                    if not result:
                        return []

                    # Mark the jobs as acquired by this worker
                    jobs = await self._deserialize_jobs(result)
                    task_ids: set[str] = {job.task_id for job in jobs}

                    # Retrieve the limits
                    query = select(
                        [
                            self.t_tasks.c.id,
                            self.t_tasks.c.max_running_jobs
                            - self.t_tasks.c.running_jobs,
                        ]
                    ).where(
                        self.t_tasks.c.max_running_jobs.isnot(None),
                        self.t_tasks.c.id.in_(task_ids),
                    )
                    result = await conn.execute(query)
                    job_slots_left: dict[str, int] = dict(result.fetchall())

                    # Filter out jobs that don't have free slots
                    acquired_jobs: list[Job] = []
                    increments: dict[str, int] = defaultdict(lambda: 0)
                    for job in jobs:
                        # Don't acquire the job if there are no free slots left
                        slots_left = job_slots_left.get(job.task_id)
                        if slots_left == 0:
                            continue
                        elif slots_left is not None:
                            job_slots_left[job.task_id] -= 1

                        acquired_jobs.append(job)
                        increments[job.task_id] += 1

                    if acquired_jobs:
                        # Mark the acquired jobs as acquired by this worker
                        acquired_job_ids = [job.id for job in acquired_jobs]
                        update = (
                            self.t_jobs.update()
                            .values(
                                acquired_by=worker_id, acquired_until=acquired_until
                            )
                            .where(self.t_jobs.c.id.in_(acquired_job_ids))
                        )
                        await conn.execute(update)

                        # Increment the running job counters on each task
                        p_id: BindParameter = bindparam("p_id")
                        p_increment: BindParameter = bindparam("p_increment")
                        params = [
                            {"p_id": task_id, "p_increment": increment}
                            for task_id, increment in increments.items()
                        ]
                        update = (
                            self.t_tasks.update()
                            .values(
                                running_jobs=self.t_tasks.c.running_jobs + p_increment
                            )
                            .where(self.t_tasks.c.id == p_id)
                        )
                        await conn.execute(update, params)

                    # Publish the appropriate events
                    for job in acquired_jobs:
                        await self._events.publish(
                            JobAcquired(job_id=job.id, worker_id=worker_id)
                        )

                    return acquired_jobs

    async def release_job(
        self, worker_id: str, task_id: str, result: JobResult
    ) -> None:
        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    # Insert the job result
                    marshalled = result.marshal(self.serializer)
                    insert = self.t_job_results.insert().values(**marshalled)
                    await conn.execute(insert)

                    # Decrement the running jobs counter
                    update = (
                        self.t_tasks.update()
                        .values(running_jobs=self.t_tasks.c.running_jobs - 1)
                        .where(self.t_tasks.c.id == task_id)
                    )
                    await conn.execute(update)

                    # Delete the job
                    delete = self.t_jobs.delete().where(
                        self.t_jobs.c.id == result.job_id
                    )
                    await conn.execute(delete)

    async def get_job_result(self, job_id: UUID) -> JobResult | None:
        async for attempt in self._retry():
            with attempt:
                async with self.engine.begin() as conn:
                    # Retrieve the result
                    query = self.t_job_results.select().where(
                        self.t_job_results.c.job_id == job_id
                    )
                    row = (await conn.execute(query)).fetchone()

                    # Delete the result
                    delete = self.t_job_results.delete().where(
                        self.t_job_results.c.job_id == job_id
                    )
                    await conn.execute(delete)

                    return (
                        JobResult.unmarshal(self.serializer, row._asdict())
                        if row
                        else None
                    )
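
# A minimal end-to-end usage sketch (illustrative only; the connection URL,
# worker id and surrounding scheduler machinery are hypothetical):
#
#     import anyio
#
#     async def main() -> None:
#         store = AsyncSQLAlchemyDataStore.from_url(
#             "postgresql+asyncpg://user:secret@localhost/apscheduler"
#         )
#         async with store:
#             jobs = await store.acquire_jobs("worker-1", limit=1)
#             # ... run each job, then store its outcome:
#             # await store.release_job("worker-1", job.task_id, job_result)
#
#     anyio.run(main)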