Diffstat (limited to 'src/apscheduler/datastores/async_sqlalchemy.py')
-rw-r--r--    src/apscheduler/datastores/async_sqlalchemy.py    353
1 file changed, 240 insertions(+), 113 deletions(-)
diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py
index ca97330..89163fa 100644
--- a/src/apscheduler/datastores/async_sqlalchemy.py
+++ b/src/apscheduler/datastores/async_sqlalchemy.py
@@ -21,9 +21,18 @@ from ..abc import AsyncDataStore, AsyncEventBroker, EventSource, Job, Schedule
 from ..enums import ConflictPolicy
 from ..eventbrokers.async_local import LocalAsyncEventBroker
 from ..events import (
-    DataStoreEvent, JobAcquired, JobAdded, JobDeserializationFailed, ScheduleAdded,
-    ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, TaskAdded, TaskRemoved,
-    TaskUpdated)
+    DataStoreEvent,
+    JobAcquired,
+    JobAdded,
+    JobDeserializationFailed,
+    ScheduleAdded,
+    ScheduleDeserializationFailed,
+    ScheduleRemoved,
+    ScheduleUpdated,
+    TaskAdded,
+    TaskRemoved,
+    TaskUpdated,
+)
 from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError
 from ..marshalling import callable_to_ref
 from ..structures import JobResult, Task
@@ -50,14 +59,20 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
         # Construct the Tenacity retry controller
         # OSError is raised by asyncpg if it can't connect
         self._retrying = tenacity.AsyncRetrying(
-            stop=self.retry_settings.stop, wait=self.retry_settings.wait,
+            stop=self.retry_settings.stop,
+            wait=self.retry_settings.wait,
             retry=tenacity.retry_if_exception_type((InterfaceError, OSError)),
-            after=self._after_attempt, sleep=anyio.sleep, reraise=True)
+            after=self._after_attempt,
+            sleep=anyio.sleep,
+            reraise=True,
+        )

     async def __aenter__(self):
-        asynclib = sniffio.current_async_library() or '(unknown)'
-        if asynclib != 'asyncio':
-            raise RuntimeError(f'This data store requires asyncio; currently running: {asynclib}')
+        asynclib = sniffio.current_async_library() or "(unknown)"
+        if asynclib != "asyncio":
+            raise RuntimeError(
+                f"This data store requires asyncio; currently running: {asynclib}"
+            )

         # Verify that the schema is in place
         async for attempt in self._retrying:
@@ -72,11 +87,14 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                     result = await conn.execute(query)
                     version = result.scalar()
                     if version is None:
-                        await conn.execute(self.t_metadata.insert(values={'schema_version': 1}))
+                        await conn.execute(
+                            self.t_metadata.insert(values={"schema_version": 1})
+                        )
                     elif version > 1:
                         raise RuntimeError(
-                            f'Unexpected schema version ({version}); '
-                            f'only version 1 is supported by this version of APScheduler')
+                            f"Unexpected schema version ({version}); "
+                            f"only version 1 is supported by this version of APScheduler"
+                        )

         await self._events.__aenter__()
         return self
@@ -95,7 +113,8 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                 schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
             except SerializationError as exc:
                 await self._events.publish(
-                    ScheduleDeserializationFailed(schedule_id=row['id'], exception=exc))
+                    ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc)
+                )

         return schedules

@@ -106,25 +125,33 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                 jobs.append(Job.unmarshal(self.serializer, row._asdict()))
             except SerializationError as exc:
                 await self._events.publish(
-                    JobDeserializationFailed(job_id=row['id'], exception=exc))
+                    JobDeserializationFailed(job_id=row["id"], exception=exc)
+                )

         return jobs

     async def add_task(self, task: Task) -> None:
-        insert = self.t_tasks.insert().\
-            values(id=task.id, func=callable_to_ref(task.func),
-                   max_running_jobs=task.max_running_jobs,
-                   misfire_grace_time=task.misfire_grace_time)
+        insert = self.t_tasks.insert().values(
+            id=task.id,
+            func=callable_to_ref(task.func),
+            max_running_jobs=task.max_running_jobs,
+            misfire_grace_time=task.misfire_grace_time,
+        )
         try:
             async for attempt in self._retrying:
                 with attempt:
                     async with self.engine.begin() as conn:
                         await conn.execute(insert)
         except IntegrityError:
-            update = self.t_tasks.update().\
-                values(func=callable_to_ref(task.func), max_running_jobs=task.max_running_jobs,
-                       misfire_grace_time=task.misfire_grace_time).\
-                where(self.t_tasks.c.id == task.id)
+            update = (
+                self.t_tasks.update()
+                .values(
+                    func=callable_to_ref(task.func),
+                    max_running_jobs=task.max_running_jobs,
+                    misfire_grace_time=task.misfire_grace_time,
+                )
+                .where(self.t_tasks.c.id == task.id)
+            )
             async for attempt in self._retrying:
                 with attempt:
                     async with self.engine.begin() as conn:
@@ -146,9 +173,15 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
         await self._events.publish(TaskRemoved(task_id=task_id))

     async def get_task(self, task_id: str) -> Task:
-        query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
-                        self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\
-            where(self.t_tasks.c.id == task_id)
+        query = select(
+            [
+                self.t_tasks.c.id,
+                self.t_tasks.c.func,
+                self.t_tasks.c.max_running_jobs,
+                self.t_tasks.c.state,
+                self.t_tasks.c.misfire_grace_time,
+            ]
+        ).where(self.t_tasks.c.id == task_id)
         async for attempt in self._retrying:
             with attempt:
                 async with self.engine.begin() as conn:
@@ -161,17 +194,27 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
             raise TaskLookupError

     async def get_tasks(self) -> list[Task]:
-        query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
-                        self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\
-            order_by(self.t_tasks.c.id)
+        query = select(
+            [
+                self.t_tasks.c.id,
+                self.t_tasks.c.func,
+                self.t_tasks.c.max_running_jobs,
+                self.t_tasks.c.state,
+                self.t_tasks.c.misfire_grace_time,
+            ]
+        ).order_by(self.t_tasks.c.id)
         async for attempt in self._retrying:
             with attempt:
                 async with self.engine.begin() as conn:
                     result = await conn.execute(query)
-                    tasks = [Task.unmarshal(self.serializer, row._asdict()) for row in result]
+                    tasks = [
+                        Task.unmarshal(self.serializer, row._asdict()) for row in result
+                    ]

         return tasks

-    async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
+    async def add_schedule(
+        self, schedule: Schedule, conflict_policy: ConflictPolicy
+    ) -> None:
         event: DataStoreEvent
         values = schedule.marshal(self.serializer)
         insert = self.t_schedules.insert().values(**values)
@@ -184,30 +227,38 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
             if conflict_policy is ConflictPolicy.exception:
                 raise ConflictingIdError(schedule.id) from None
             elif conflict_policy is ConflictPolicy.replace:
-                del values['id']
-                update = self.t_schedules.update().\
-                    where(self.t_schedules.c.id == schedule.id).\
-                    values(**values)
+                del values["id"]
+                update = (
+                    self.t_schedules.update()
+                    .where(self.t_schedules.c.id == schedule.id)
+                    .values(**values)
+                )
                 async for attempt in self._retrying:
                     async with attempt, self.engine.begin() as conn:
                         await conn.execute(update)

-                event = ScheduleUpdated(schedule_id=schedule.id,
-                                        next_fire_time=schedule.next_fire_time)
+                event = ScheduleUpdated(
+                    schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
+                )
                 await self._events.publish(event)
         else:
-            event = ScheduleAdded(schedule_id=schedule.id,
-                                  next_fire_time=schedule.next_fire_time)
+            event = ScheduleAdded(
+                schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
+            )
             await self._events.publish(event)

     async def remove_schedules(self, ids: Iterable[str]) -> None:
         async for attempt in self._retrying:
             with attempt:
                 async with self.engine.begin() as conn:
-                    delete = self.t_schedules.delete().where(self.t_schedules.c.id.in_(ids))
+                    delete = self.t_schedules.delete().where(
+                        self.t_schedules.c.id.in_(ids)
+                    )
                     if self._supports_update_returning:
                         delete = delete.returning(self.t_schedules.c.id)
-                        removed_ids: Iterable[str] = [row[0] for row in await conn.execute(delete)]
+                        removed_ids: Iterable[str] = [
+                            row[0] for row in await conn.execute(delete)
+                        ]
                     else:
                         # TODO: actually check which rows were deleted?
                         await conn.execute(delete)
@@ -233,31 +284,46 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                 async with self.engine.begin() as conn:
                     now = datetime.now(timezone.utc)
                     acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
-                    schedules_cte = select(self.t_schedules.c.id).\
-                        where(and_(self.t_schedules.c.next_fire_time.isnot(None),
-                                   self.t_schedules.c.next_fire_time <= now,
-                                   or_(self.t_schedules.c.acquired_until.is_(None),
-                                       self.t_schedules.c.acquired_until < now))).\
-                        order_by(self.t_schedules.c.next_fire_time).\
-                        limit(limit).with_for_update(skip_locked=True).cte()
+                    schedules_cte = (
+                        select(self.t_schedules.c.id)
+                        .where(
+                            and_(
+                                self.t_schedules.c.next_fire_time.isnot(None),
+                                self.t_schedules.c.next_fire_time <= now,
+                                or_(
+                                    self.t_schedules.c.acquired_until.is_(None),
+                                    self.t_schedules.c.acquired_until < now,
+                                ),
+                            )
+                        )
+                        .order_by(self.t_schedules.c.next_fire_time)
+                        .limit(limit)
+                        .with_for_update(skip_locked=True)
+                        .cte()
+                    )
                     subselect = select([schedules_cte.c.id])
-                    update = self.t_schedules.update().\
-                        where(self.t_schedules.c.id.in_(subselect)).\
-                        values(acquired_by=scheduler_id, acquired_until=acquired_until)
+                    update = (
+                        self.t_schedules.update()
+                        .where(self.t_schedules.c.id.in_(subselect))
+                        .values(acquired_by=scheduler_id, acquired_until=acquired_until)
+                    )
                     if self._supports_update_returning:
                         update = update.returning(*self.t_schedules.columns)
                         result = await conn.execute(update)
                     else:
                         await conn.execute(update)
-                        query = self.t_schedules.select().\
-                            where(and_(self.t_schedules.c.acquired_by == scheduler_id))
+                        query = self.t_schedules.select().where(
+                            and_(self.t_schedules.c.acquired_by == scheduler_id)
+                        )
                         result = conn.execute(query)

                     schedules = await self._deserialize_schedules(result)

         return schedules

-    async def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None:
+    async def release_schedules(
+        self, scheduler_id: str, schedules: list[Schedule]
+    ) -> None:
         async for attempt in self._retrying:
             with attempt:
                 async with self.engine.begin() as conn:
@@ -267,52 +333,74 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                     update_args: list[dict[str, Any]] = []
                     for schedule in schedules:
                         if schedule.next_fire_time is not None:
                             try:
-                                serialized_trigger = self.serializer.serialize(schedule.trigger)
+                                serialized_trigger = self.serializer.serialize(
+                                    schedule.trigger
+                                )
                             except SerializationError:
                                 self._logger.exception(
-                                    'Error serializing trigger for schedule %r – '
-                                    'removing from data store', schedule.id)
+                                    "Error serializing trigger for schedule %r – "
+                                    "removing from data store",
+                                    schedule.id,
+                                )
                                 finished_schedule_ids.append(schedule.id)
                                 continue

-                            update_args.append({
-                                'p_id': schedule.id,
-                                'p_trigger': serialized_trigger,
-                                'p_next_fire_time': schedule.next_fire_time
-                            })
+                            update_args.append(
+                                {
+                                    "p_id": schedule.id,
+                                    "p_trigger": serialized_trigger,
+                                    "p_next_fire_time": schedule.next_fire_time,
+                                }
+                            )
                         else:
                             finished_schedule_ids.append(schedule.id)

                     # Update schedules that have a next fire time
                     if update_args:
-                        p_id: BindParameter = bindparam('p_id')
-                        p_trigger: BindParameter = bindparam('p_trigger')
-                        p_next_fire_time: BindParameter = bindparam('p_next_fire_time')
-                        update = self.t_schedules.update().\
-                            where(and_(self.t_schedules.c.id == p_id,
-                                       self.t_schedules.c.acquired_by == scheduler_id)).\
-                            values(trigger=p_trigger, next_fire_time=p_next_fire_time,
-                                   acquired_by=None, acquired_until=None)
-                        next_fire_times = {arg['p_id']: arg['p_next_fire_time']
-                                           for arg in update_args}
+                        p_id: BindParameter = bindparam("p_id")
+                        p_trigger: BindParameter = bindparam("p_trigger")
+                        p_next_fire_time: BindParameter = bindparam("p_next_fire_time")
+                        update = (
+                            self.t_schedules.update()
+                            .where(
+                                and_(
+                                    self.t_schedules.c.id == p_id,
+                                    self.t_schedules.c.acquired_by == scheduler_id,
+                                )
+                            )
+                            .values(
+                                trigger=p_trigger,
+                                next_fire_time=p_next_fire_time,
+                                acquired_by=None,
+                                acquired_until=None,
+                            )
+                        )
+                        next_fire_times = {
+                            arg["p_id"]: arg["p_next_fire_time"] for arg in update_args
+                        }
                         if self._supports_update_returning:
                             update = update.returning(self.t_schedules.c.id)
                             updated_ids = [
-                                row[0] for row in await conn.execute(update, update_args)]
+                                row[0]
+                                for row in await conn.execute(update, update_args)
+                            ]
                         else:
                             # TODO: actually check which rows were updated?
                             await conn.execute(update, update_args)
                             updated_ids = list(next_fire_times)

                         for schedule_id in updated_ids:
-                            event = ScheduleUpdated(schedule_id=schedule_id,
-                                                    next_fire_time=next_fire_times[schedule_id])
+                            event = ScheduleUpdated(
+                                schedule_id=schedule_id,
+                                next_fire_time=next_fire_times[schedule_id],
+                            )
                             update_events.append(event)

                     # Remove schedules that have no next fire time or failed to serialize
                     if finished_schedule_ids:
-                        delete = self.t_schedules.delete().\
-                            where(self.t_schedules.c.id.in_(finished_schedule_ids))
+                        delete = self.t_schedules.delete().where(
+                            self.t_schedules.c.id.in_(finished_schedule_ids)
+                        )
                         await conn.execute(delete)

         for event in update_events:
@@ -322,10 +410,12 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
             await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))

     async def get_next_schedule_run_time(self) -> datetime | None:
-        statenent = select(self.t_schedules.c.next_fire_time).\
-            where(self.t_schedules.c.next_fire_time.isnot(None)).\
-            order_by(self.t_schedules.c.next_fire_time).\
-            limit(1)
+        statenent = (
+            select(self.t_schedules.c.next_fire_time)
+            .where(self.t_schedules.c.next_fire_time.isnot(None))
+            .order_by(self.t_schedules.c.next_fire_time)
+            .limit(1)
+        )
         async for attempt in self._retrying:
             with attempt:
                 async with self.engine.begin() as conn:
@@ -340,8 +430,12 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                 async with self.engine.begin() as conn:
                     await conn.execute(insert)

-        event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id,
-                         tags=job.tags)
+        event = JobAdded(
+            job_id=job.id,
+            task_id=job.task_id,
+            schedule_id=job.schedule_id,
+            tags=job.tags,
+        )
         await self._events.publish(event)

     async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
@@ -362,13 +456,19 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                 async with self.engine.begin() as conn:
                     now = datetime.now(timezone.utc)
                     acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
-                    query = self.t_jobs.select().\
-                        join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\
-                        where(or_(self.t_jobs.c.acquired_until.is_(None),
-                                  self.t_jobs.c.acquired_until < now)).\
-                        order_by(self.t_jobs.c.created_at).\
-                        with_for_update(skip_locked=True).\
-                        limit(limit)
+                    query = (
+                        self.t_jobs.select()
+                        .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id)
+                        .where(
+                            or_(
+                                self.t_jobs.c.acquired_until.is_(None),
+                                self.t_jobs.c.acquired_until < now,
+                            )
+                        )
+                        .order_by(self.t_jobs.c.created_at)
+                        .with_for_update(skip_locked=True)
+                        .limit(limit)
+                    )
                     result = await conn.execute(query)

                     if not result:
@@ -379,11 +479,16 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                     task_ids: set[str] = {job.task_id for job in jobs}

                     # Retrieve the limits
-                    query = select([
-                        self.t_tasks.c.id,
-                        self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs]).\
-                        where(self.t_tasks.c.max_running_jobs.isnot(None),
-                              self.t_tasks.c.id.in_(task_ids))
+                    query = select(
+                        [
+                            self.t_tasks.c.id,
+                            self.t_tasks.c.max_running_jobs
+                            - self.t_tasks.c.running_jobs,
+                        ]
+                    ).where(
+                        self.t_tasks.c.max_running_jobs.isnot(None),
+                        self.t_tasks.c.id.in_(task_ids),
+                    )
                     result = await conn.execute(query)
                     job_slots_left: dict[str, int] = dict(result.fetchall())
@@ -404,19 +509,29 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                     if acquired_jobs:
                         # Mark the acquired jobs as acquired by this worker
                         acquired_job_ids = [job.id for job in acquired_jobs]
-                        update = self.t_jobs.update().\
-                            values(acquired_by=worker_id, acquired_until=acquired_until).\
-                            where(self.t_jobs.c.id.in_(acquired_job_ids))
+                        update = (
+                            self.t_jobs.update()
+                            .values(
+                                acquired_by=worker_id, acquired_until=acquired_until
+                            )
+                            .where(self.t_jobs.c.id.in_(acquired_job_ids))
+                        )
                         await conn.execute(update)

                         # Increment the running job counters on each task
-                        p_id: BindParameter = bindparam('p_id')
-                        p_increment: BindParameter = bindparam('p_increment')
-                        params = [{'p_id': task_id, 'p_increment': increment}
-                                  for task_id, increment in increments.items()]
-                        update = self.t_tasks.update().\
-                            values(running_jobs=self.t_tasks.c.running_jobs + p_increment).\
-                            where(self.t_tasks.c.id == p_id)
+                        p_id: BindParameter = bindparam("p_id")
+                        p_increment: BindParameter = bindparam("p_increment")
+                        params = [
+                            {"p_id": task_id, "p_increment": increment}
+                            for task_id, increment in increments.items()
+                        ]
+                        update = (
+                            self.t_tasks.update()
+                            .values(
+                                running_jobs=self.t_tasks.c.running_jobs + p_increment
+                            )
+                            .where(self.t_tasks.c.id == p_id)
+                        )
                         await conn.execute(update, params)

             # Publish the appropriate events
@@ -425,7 +540,9 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):

         return acquired_jobs

-    async def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
+    async def release_job(
+        self, worker_id: str, task_id: str, result: JobResult
+    ) -> None:
         async for attempt in self._retrying:
             with attempt:
                 async with self.engine.begin() as conn:
@@ -435,13 +552,17 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                     await conn.execute(insert)

                     # Decrement the running jobs counter
-                    update = self.t_tasks.update().\
-                        values(running_jobs=self.t_tasks.c.running_jobs - 1).\
-                        where(self.t_tasks.c.id == task_id)
+                    update = (
+                        self.t_tasks.update()
+                        .values(running_jobs=self.t_tasks.c.running_jobs - 1)
+                        .where(self.t_tasks.c.id == task_id)
+                    )
                     await conn.execute(update)

                     # Delete the job
-                    delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id)
+                    delete = self.t_jobs.delete().where(
+                        self.t_jobs.c.id == result.job_id
+                    )
                     await conn.execute(delete)

     async def get_job_result(self, job_id: UUID) -> JobResult | None:
@@ -449,13 +570,19 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
             with attempt:
                 async with self.engine.begin() as conn:
                     # Retrieve the result
-                    query = self.t_job_results.select().\
-                        where(self.t_job_results.c.job_id == job_id)
+                    query = self.t_job_results.select().where(
+                        self.t_job_results.c.job_id == job_id
+                    )
                     row = (await conn.execute(query)).fetchone()

                     # Delete the result
-                    delete = self.t_job_results.delete().\
-                        where(self.t_job_results.c.job_id == job_id)
+                    delete = self.t_job_results.delete().where(
+                        self.t_job_results.c.job_id == job_id
+                    )
                     await conn.execute(delete)

-        return JobResult.unmarshal(self.serializer, row._asdict()) if row else None
+        return (
+            JobResult.unmarshal(self.serializer, row._asdict())
+            if row
+            else None
+        )
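
Every operation in this data store follows the same idiom that the reformatting above makes easier to see: the transaction is the retry unit, so a transient connection failure replays the whole "async with self.engine.begin()" block. A minimal standalone sketch of that idiom, with an illustrative engine URL and retry policy (the real class builds both from its retry_settings):

    import anyio
    import tenacity
    from sqlalchemy.exc import InterfaceError
    from sqlalchemy.ext.asyncio import create_async_engine


    async def main() -> None:
        # Illustrative DSN; any asyncpg-backed engine behaves the same way
        engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/app")
        retrying = tenacity.AsyncRetrying(
            stop=tenacity.stop_after_attempt(3),  # illustrative policy
            wait=tenacity.wait_exponential(max=10),  # illustrative policy
            # InterfaceError comes from SQLAlchemy, OSError from asyncpg
            retry=tenacity.retry_if_exception_type((InterfaceError, OSError)),
            sleep=anyio.sleep,
            reraise=True,
        )
        async for attempt in retrying:
            with attempt:  # an exception here marks the attempt as failed
                async with engine.begin() as conn:
                    ...  # execute statements; the whole transaction is retried


    anyio.run(main)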
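The acquire_schedules hunk likewise exposes the claim pattern that lets several schedulers share one table: select the due, unclaimed rows with FOR UPDATE SKIP LOCKED inside a CTE, then UPDATE those rows with an owner and a lease deadline. A sketch of the same construction against a simplified, hypothetical table (not the real APScheduler schema), suitable for execution inside the retry loop shown above:

    from datetime import datetime, timedelta, timezone

    from sqlalchemy import Column, DateTime, MetaData, String, Table, select

    metadata = MetaData()
    schedules = Table(
        "schedules",
        metadata,
        Column("id", String, primary_key=True),
        Column("next_fire_time", DateTime(timezone=True)),
        Column("acquired_by", String),
        Column("acquired_until", DateTime(timezone=True)),
    )

    now = datetime.now(timezone.utc)
    # Rows that are due and either never claimed or whose lease has expired;
    # SKIP LOCKED lets concurrent schedulers claim disjoint sets of rows
    due = (
        select(schedules.c.id)
        .where(
            schedules.c.next_fire_time.isnot(None),
            schedules.c.next_fire_time <= now,
            (schedules.c.acquired_until.is_(None)) | (schedules.c.acquired_until < now),
        )
        .order_by(schedules.c.next_fire_time)
        .limit(10)
        .with_for_update(skip_locked=True)
        .cte()
    )
    # Claim the selected rows with an owner and a lease deadline
    claim = (
        schedules.update()
        .where(schedules.c.id.in_(select(due.c.id)))
        .values(acquired_by="scheduler-1", acquired_until=now + timedelta(seconds=30))
        .returning(*schedules.columns)
    )
    # e.g. rows = (await conn.execute(claim)).fetchall()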