summaryrefslogtreecommitdiff
path: root/src/apscheduler/datastores/async_sqlalchemy.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/apscheduler/datastores/async_sqlalchemy.py')
-rw-r--r--src/apscheduler/datastores/async_sqlalchemy.py353
1 files changed, 240 insertions, 113 deletions
diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py
index ca97330..89163fa 100644
--- a/src/apscheduler/datastores/async_sqlalchemy.py
+++ b/src/apscheduler/datastores/async_sqlalchemy.py
@@ -21,9 +21,18 @@ from ..abc import AsyncDataStore, AsyncEventBroker, EventSource, Job, Schedule
from ..enums import ConflictPolicy
from ..eventbrokers.async_local import LocalAsyncEventBroker
from ..events import (
- DataStoreEvent, JobAcquired, JobAdded, JobDeserializationFailed, ScheduleAdded,
- ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, TaskAdded, TaskRemoved,
- TaskUpdated)
+ DataStoreEvent,
+ JobAcquired,
+ JobAdded,
+ JobDeserializationFailed,
+ ScheduleAdded,
+ ScheduleDeserializationFailed,
+ ScheduleRemoved,
+ ScheduleUpdated,
+ TaskAdded,
+ TaskRemoved,
+ TaskUpdated,
+)
from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError
from ..marshalling import callable_to_ref
from ..structures import JobResult, Task
@@ -50,14 +59,20 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
# Construct the Tenacity retry controller
# OSError is raised by asyncpg if it can't connect
self._retrying = tenacity.AsyncRetrying(
- stop=self.retry_settings.stop, wait=self.retry_settings.wait,
+ stop=self.retry_settings.stop,
+ wait=self.retry_settings.wait,
retry=tenacity.retry_if_exception_type((InterfaceError, OSError)),
- after=self._after_attempt, sleep=anyio.sleep, reraise=True)
+ after=self._after_attempt,
+ sleep=anyio.sleep,
+ reraise=True,
+ )
async def __aenter__(self):
- asynclib = sniffio.current_async_library() or '(unknown)'
- if asynclib != 'asyncio':
- raise RuntimeError(f'This data store requires asyncio; currently running: {asynclib}')
+ asynclib = sniffio.current_async_library() or "(unknown)"
+ if asynclib != "asyncio":
+ raise RuntimeError(
+ f"This data store requires asyncio; currently running: {asynclib}"
+ )
# Verify that the schema is in place
async for attempt in self._retrying:
@@ -72,11 +87,14 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
result = await conn.execute(query)
version = result.scalar()
if version is None:
- await conn.execute(self.t_metadata.insert(values={'schema_version': 1}))
+ await conn.execute(
+ self.t_metadata.insert(values={"schema_version": 1})
+ )
elif version > 1:
raise RuntimeError(
- f'Unexpected schema version ({version}); '
- f'only version 1 is supported by this version of APScheduler')
+ f"Unexpected schema version ({version}); "
+ f"only version 1 is supported by this version of APScheduler"
+ )
await self._events.__aenter__()
return self
@@ -95,7 +113,8 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
except SerializationError as exc:
await self._events.publish(
- ScheduleDeserializationFailed(schedule_id=row['id'], exception=exc))
+ ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc)
+ )
return schedules
@@ -106,25 +125,33 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
jobs.append(Job.unmarshal(self.serializer, row._asdict()))
except SerializationError as exc:
await self._events.publish(
- JobDeserializationFailed(job_id=row['id'], exception=exc))
+ JobDeserializationFailed(job_id=row["id"], exception=exc)
+ )
return jobs
async def add_task(self, task: Task) -> None:
- insert = self.t_tasks.insert().\
- values(id=task.id, func=callable_to_ref(task.func),
- max_running_jobs=task.max_running_jobs,
- misfire_grace_time=task.misfire_grace_time)
+ insert = self.t_tasks.insert().values(
+ id=task.id,
+ func=callable_to_ref(task.func),
+ max_running_jobs=task.max_running_jobs,
+ misfire_grace_time=task.misfire_grace_time,
+ )
try:
async for attempt in self._retrying:
with attempt:
async with self.engine.begin() as conn:
await conn.execute(insert)
except IntegrityError:
- update = self.t_tasks.update().\
- values(func=callable_to_ref(task.func), max_running_jobs=task.max_running_jobs,
- misfire_grace_time=task.misfire_grace_time).\
- where(self.t_tasks.c.id == task.id)
+ update = (
+ self.t_tasks.update()
+ .values(
+ func=callable_to_ref(task.func),
+ max_running_jobs=task.max_running_jobs,
+ misfire_grace_time=task.misfire_grace_time,
+ )
+ .where(self.t_tasks.c.id == task.id)
+ )
async for attempt in self._retrying:
with attempt:
async with self.engine.begin() as conn:
@@ -146,9 +173,15 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
await self._events.publish(TaskRemoved(task_id=task_id))
async def get_task(self, task_id: str) -> Task:
- query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
- self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\
- where(self.t_tasks.c.id == task_id)
+ query = select(
+ [
+ self.t_tasks.c.id,
+ self.t_tasks.c.func,
+ self.t_tasks.c.max_running_jobs,
+ self.t_tasks.c.state,
+ self.t_tasks.c.misfire_grace_time,
+ ]
+ ).where(self.t_tasks.c.id == task_id)
async for attempt in self._retrying:
with attempt:
async with self.engine.begin() as conn:
@@ -161,17 +194,27 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
raise TaskLookupError
async def get_tasks(self) -> list[Task]:
- query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
- self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\
- order_by(self.t_tasks.c.id)
+ query = select(
+ [
+ self.t_tasks.c.id,
+ self.t_tasks.c.func,
+ self.t_tasks.c.max_running_jobs,
+ self.t_tasks.c.state,
+ self.t_tasks.c.misfire_grace_time,
+ ]
+ ).order_by(self.t_tasks.c.id)
async for attempt in self._retrying:
with attempt:
async with self.engine.begin() as conn:
result = await conn.execute(query)
- tasks = [Task.unmarshal(self.serializer, row._asdict()) for row in result]
+ tasks = [
+ Task.unmarshal(self.serializer, row._asdict()) for row in result
+ ]
return tasks
- async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
+ async def add_schedule(
+ self, schedule: Schedule, conflict_policy: ConflictPolicy
+ ) -> None:
event: DataStoreEvent
values = schedule.marshal(self.serializer)
insert = self.t_schedules.insert().values(**values)
@@ -184,30 +227,38 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
if conflict_policy is ConflictPolicy.exception:
raise ConflictingIdError(schedule.id) from None
elif conflict_policy is ConflictPolicy.replace:
- del values['id']
- update = self.t_schedules.update().\
- where(self.t_schedules.c.id == schedule.id).\
- values(**values)
+ del values["id"]
+ update = (
+ self.t_schedules.update()
+ .where(self.t_schedules.c.id == schedule.id)
+ .values(**values)
+ )
async for attempt in self._retrying:
async with attempt, self.engine.begin() as conn:
await conn.execute(update)
- event = ScheduleUpdated(schedule_id=schedule.id,
- next_fire_time=schedule.next_fire_time)
+ event = ScheduleUpdated(
+ schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
+ )
await self._events.publish(event)
else:
- event = ScheduleAdded(schedule_id=schedule.id,
- next_fire_time=schedule.next_fire_time)
+ event = ScheduleAdded(
+ schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
+ )
await self._events.publish(event)
async def remove_schedules(self, ids: Iterable[str]) -> None:
async for attempt in self._retrying:
with attempt:
async with self.engine.begin() as conn:
- delete = self.t_schedules.delete().where(self.t_schedules.c.id.in_(ids))
+ delete = self.t_schedules.delete().where(
+ self.t_schedules.c.id.in_(ids)
+ )
if self._supports_update_returning:
delete = delete.returning(self.t_schedules.c.id)
- removed_ids: Iterable[str] = [row[0] for row in await conn.execute(delete)]
+ removed_ids: Iterable[str] = [
+ row[0] for row in await conn.execute(delete)
+ ]
else:
# TODO: actually check which rows were deleted?
await conn.execute(delete)
@@ -233,31 +284,46 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
async with self.engine.begin() as conn:
now = datetime.now(timezone.utc)
acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
- schedules_cte = select(self.t_schedules.c.id).\
- where(and_(self.t_schedules.c.next_fire_time.isnot(None),
- self.t_schedules.c.next_fire_time <= now,
- or_(self.t_schedules.c.acquired_until.is_(None),
- self.t_schedules.c.acquired_until < now))).\
- order_by(self.t_schedules.c.next_fire_time).\
- limit(limit).with_for_update(skip_locked=True).cte()
+ schedules_cte = (
+ select(self.t_schedules.c.id)
+ .where(
+ and_(
+ self.t_schedules.c.next_fire_time.isnot(None),
+ self.t_schedules.c.next_fire_time <= now,
+ or_(
+ self.t_schedules.c.acquired_until.is_(None),
+ self.t_schedules.c.acquired_until < now,
+ ),
+ )
+ )
+ .order_by(self.t_schedules.c.next_fire_time)
+ .limit(limit)
+ .with_for_update(skip_locked=True)
+ .cte()
+ )
subselect = select([schedules_cte.c.id])
- update = self.t_schedules.update().\
- where(self.t_schedules.c.id.in_(subselect)).\
- values(acquired_by=scheduler_id, acquired_until=acquired_until)
+ update = (
+ self.t_schedules.update()
+ .where(self.t_schedules.c.id.in_(subselect))
+ .values(acquired_by=scheduler_id, acquired_until=acquired_until)
+ )
if self._supports_update_returning:
update = update.returning(*self.t_schedules.columns)
result = await conn.execute(update)
else:
await conn.execute(update)
- query = self.t_schedules.select().\
- where(and_(self.t_schedules.c.acquired_by == scheduler_id))
+ query = self.t_schedules.select().where(
+ and_(self.t_schedules.c.acquired_by == scheduler_id)
+ )
result = conn.execute(query)
schedules = await self._deserialize_schedules(result)
return schedules
- async def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None:
+ async def release_schedules(
+ self, scheduler_id: str, schedules: list[Schedule]
+ ) -> None:
async for attempt in self._retrying:
with attempt:
async with self.engine.begin() as conn:
@@ -267,52 +333,74 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
for schedule in schedules:
if schedule.next_fire_time is not None:
try:
- serialized_trigger = self.serializer.serialize(schedule.trigger)
+ serialized_trigger = self.serializer.serialize(
+ schedule.trigger
+ )
except SerializationError:
self._logger.exception(
- 'Error serializing trigger for schedule %r – '
- 'removing from data store', schedule.id)
+ "Error serializing trigger for schedule %r – "
+ "removing from data store",
+ schedule.id,
+ )
finished_schedule_ids.append(schedule.id)
continue
- update_args.append({
- 'p_id': schedule.id,
- 'p_trigger': serialized_trigger,
- 'p_next_fire_time': schedule.next_fire_time
- })
+ update_args.append(
+ {
+ "p_id": schedule.id,
+ "p_trigger": serialized_trigger,
+ "p_next_fire_time": schedule.next_fire_time,
+ }
+ )
else:
finished_schedule_ids.append(schedule.id)
# Update schedules that have a next fire time
if update_args:
- p_id: BindParameter = bindparam('p_id')
- p_trigger: BindParameter = bindparam('p_trigger')
- p_next_fire_time: BindParameter = bindparam('p_next_fire_time')
- update = self.t_schedules.update().\
- where(and_(self.t_schedules.c.id == p_id,
- self.t_schedules.c.acquired_by == scheduler_id)).\
- values(trigger=p_trigger, next_fire_time=p_next_fire_time,
- acquired_by=None, acquired_until=None)
- next_fire_times = {arg['p_id']: arg['p_next_fire_time']
- for arg in update_args}
+ p_id: BindParameter = bindparam("p_id")
+ p_trigger: BindParameter = bindparam("p_trigger")
+ p_next_fire_time: BindParameter = bindparam("p_next_fire_time")
+ update = (
+ self.t_schedules.update()
+ .where(
+ and_(
+ self.t_schedules.c.id == p_id,
+ self.t_schedules.c.acquired_by == scheduler_id,
+ )
+ )
+ .values(
+ trigger=p_trigger,
+ next_fire_time=p_next_fire_time,
+ acquired_by=None,
+ acquired_until=None,
+ )
+ )
+ next_fire_times = {
+ arg["p_id"]: arg["p_next_fire_time"] for arg in update_args
+ }
if self._supports_update_returning:
update = update.returning(self.t_schedules.c.id)
updated_ids = [
- row[0] for row in await conn.execute(update, update_args)]
+ row[0]
+ for row in await conn.execute(update, update_args)
+ ]
else:
# TODO: actually check which rows were updated?
await conn.execute(update, update_args)
updated_ids = list(next_fire_times)
for schedule_id in updated_ids:
- event = ScheduleUpdated(schedule_id=schedule_id,
- next_fire_time=next_fire_times[schedule_id])
+ event = ScheduleUpdated(
+ schedule_id=schedule_id,
+ next_fire_time=next_fire_times[schedule_id],
+ )
update_events.append(event)
# Remove schedules that have no next fire time or failed to serialize
if finished_schedule_ids:
- delete = self.t_schedules.delete().\
- where(self.t_schedules.c.id.in_(finished_schedule_ids))
+ delete = self.t_schedules.delete().where(
+ self.t_schedules.c.id.in_(finished_schedule_ids)
+ )
await conn.execute(delete)
for event in update_events:
@@ -322,10 +410,12 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
async def get_next_schedule_run_time(self) -> datetime | None:
- statenent = select(self.t_schedules.c.next_fire_time).\
- where(self.t_schedules.c.next_fire_time.isnot(None)).\
- order_by(self.t_schedules.c.next_fire_time).\
- limit(1)
+ statenent = (
+ select(self.t_schedules.c.next_fire_time)
+ .where(self.t_schedules.c.next_fire_time.isnot(None))
+ .order_by(self.t_schedules.c.next_fire_time)
+ .limit(1)
+ )
async for attempt in self._retrying:
with attempt:
async with self.engine.begin() as conn:
@@ -340,8 +430,12 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
async with self.engine.begin() as conn:
await conn.execute(insert)
- event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id,
- tags=job.tags)
+ event = JobAdded(
+ job_id=job.id,
+ task_id=job.task_id,
+ schedule_id=job.schedule_id,
+ tags=job.tags,
+ )
await self._events.publish(event)
async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
@@ -362,13 +456,19 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
async with self.engine.begin() as conn:
now = datetime.now(timezone.utc)
acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
- query = self.t_jobs.select().\
- join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\
- where(or_(self.t_jobs.c.acquired_until.is_(None),
- self.t_jobs.c.acquired_until < now)).\
- order_by(self.t_jobs.c.created_at).\
- with_for_update(skip_locked=True).\
- limit(limit)
+ query = (
+ self.t_jobs.select()
+ .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id)
+ .where(
+ or_(
+ self.t_jobs.c.acquired_until.is_(None),
+ self.t_jobs.c.acquired_until < now,
+ )
+ )
+ .order_by(self.t_jobs.c.created_at)
+ .with_for_update(skip_locked=True)
+ .limit(limit)
+ )
result = await conn.execute(query)
if not result:
@@ -379,11 +479,16 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
task_ids: set[str] = {job.task_id for job in jobs}
# Retrieve the limits
- query = select([
- self.t_tasks.c.id,
- self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs]).\
- where(self.t_tasks.c.max_running_jobs.isnot(None),
- self.t_tasks.c.id.in_(task_ids))
+ query = select(
+ [
+ self.t_tasks.c.id,
+ self.t_tasks.c.max_running_jobs
+ - self.t_tasks.c.running_jobs,
+ ]
+ ).where(
+ self.t_tasks.c.max_running_jobs.isnot(None),
+ self.t_tasks.c.id.in_(task_ids),
+ )
result = await conn.execute(query)
job_slots_left: dict[str, int] = dict(result.fetchall())
@@ -404,19 +509,29 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
if acquired_jobs:
# Mark the acquired jobs as acquired by this worker
acquired_job_ids = [job.id for job in acquired_jobs]
- update = self.t_jobs.update().\
- values(acquired_by=worker_id, acquired_until=acquired_until).\
- where(self.t_jobs.c.id.in_(acquired_job_ids))
+ update = (
+ self.t_jobs.update()
+ .values(
+ acquired_by=worker_id, acquired_until=acquired_until
+ )
+ .where(self.t_jobs.c.id.in_(acquired_job_ids))
+ )
await conn.execute(update)
# Increment the running job counters on each task
- p_id: BindParameter = bindparam('p_id')
- p_increment: BindParameter = bindparam('p_increment')
- params = [{'p_id': task_id, 'p_increment': increment}
- for task_id, increment in increments.items()]
- update = self.t_tasks.update().\
- values(running_jobs=self.t_tasks.c.running_jobs + p_increment).\
- where(self.t_tasks.c.id == p_id)
+ p_id: BindParameter = bindparam("p_id")
+ p_increment: BindParameter = bindparam("p_increment")
+ params = [
+ {"p_id": task_id, "p_increment": increment}
+ for task_id, increment in increments.items()
+ ]
+ update = (
+ self.t_tasks.update()
+ .values(
+ running_jobs=self.t_tasks.c.running_jobs + p_increment
+ )
+ .where(self.t_tasks.c.id == p_id)
+ )
await conn.execute(update, params)
# Publish the appropriate events
@@ -425,7 +540,9 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
return acquired_jobs
- async def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
+ async def release_job(
+ self, worker_id: str, task_id: str, result: JobResult
+ ) -> None:
async for attempt in self._retrying:
with attempt:
async with self.engine.begin() as conn:
@@ -435,13 +552,17 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
await conn.execute(insert)
# Decrement the running jobs counter
- update = self.t_tasks.update().\
- values(running_jobs=self.t_tasks.c.running_jobs - 1).\
- where(self.t_tasks.c.id == task_id)
+ update = (
+ self.t_tasks.update()
+ .values(running_jobs=self.t_tasks.c.running_jobs - 1)
+ .where(self.t_tasks.c.id == task_id)
+ )
await conn.execute(update)
# Delete the job
- delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id)
+ delete = self.t_jobs.delete().where(
+ self.t_jobs.c.id == result.job_id
+ )
await conn.execute(delete)
async def get_job_result(self, job_id: UUID) -> JobResult | None:
@@ -449,13 +570,19 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
with attempt:
async with self.engine.begin() as conn:
# Retrieve the result
- query = self.t_job_results.select().\
- where(self.t_job_results.c.job_id == job_id)
+ query = self.t_job_results.select().where(
+ self.t_job_results.c.job_id == job_id
+ )
row = (await conn.execute(query)).fetchone()
# Delete the result
- delete = self.t_job_results.delete().\
- where(self.t_job_results.c.job_id == job_id)
+ delete = self.t_job_results.delete().where(
+ self.t_job_results.c.job_id == job_id
+ )
await conn.execute(delete)
- return JobResult.unmarshal(self.serializer, row._asdict()) if row else None
+ return (
+ JobResult.unmarshal(self.serializer, row._asdict())
+ if row
+ else None
+ )