diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2022-09-12 22:09:05 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2022-09-21 02:40:02 +0300 |
commit | c5727432736b55b7d76753307f14efdb962c2edf (patch) | |
tree | 005bd129694b56bd601d65c4cdf43828cfcd4381 /src/apscheduler/datastores/sqlalchemy.py | |
parent | 26c4db062145fcb4f623ecfda96c42ce2414e8e1 (diff) | |
download | apscheduler-c5727432736b55b7d76753307f14efdb962c2edf.tar.gz |
Major refactoring
- Made SyncScheduler a synchronous wrapper for AsyncScheduler
- Removed workers as a user interface
- Removed synchronous interfaces for data stores and event brokers and refactored existing implementations to use the async interface
- Added the current_async_scheduler contextvar
- Added job executors
Diffstat (limited to 'src/apscheduler/datastores/sqlalchemy.py')
-rw-r--r-- | src/apscheduler/datastores/sqlalchemy.py | 865 |
1 file changed, 475 insertions, 390 deletions
diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py index ebb076f..00d06aa 100644 --- a/src/apscheduler/datastores/sqlalchemy.py +++ b/src/apscheduler/datastores/sqlalchemy.py @@ -1,16 +1,20 @@ from __future__ import annotations +import sys from collections import defaultdict +from collections.abc import AsyncGenerator, Sequence +from contextlib import AsyncExitStack, asynccontextmanager from datetime import datetime, timedelta, timezone -from logging import Logger, getLogger from typing import Any, Iterable from uuid import UUID +import anyio import attrs +import sniffio import tenacity +from anyio import to_thread from sqlalchemy import ( JSON, - TIMESTAMP, BigInteger, Column, Enum, @@ -26,14 +30,15 @@ from sqlalchemy import ( select, ) from sqlalchemy.engine import URL, Dialect, Result -from sqlalchemy.exc import CompileError, IntegrityError, OperationalError -from sqlalchemy.future import Engine, create_engine +from sqlalchemy.exc import CompileError, IntegrityError, InterfaceError +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine +from sqlalchemy.future import Connection, Engine from sqlalchemy.sql.ddl import DropTable from sqlalchemy.sql.elements import BindParameter, literal from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome from .._events import ( - Event, + DataStoreEvent, JobAcquired, JobAdded, JobDeserializationFailed, @@ -46,11 +51,15 @@ from .._events import ( TaskUpdated, ) from .._exceptions import ConflictingIdError, SerializationError, TaskLookupError -from .._structures import Job, JobResult, RetrySettings, Schedule, Task -from ..abc import EventBroker, Serializer +from .._structures import Job, JobResult, Schedule, Task +from ..abc import EventBroker from ..marshalling import callable_to_ref -from ..serializers.pickle import PickleSerializer -from .base import BaseDataStore +from .base import BaseExternalDataStore + +if sys.version_info >= (3, 11): + 
from typing import Self +else: + from typing_extensions import Self class EmulatedUUID(TypeDecorator): @@ -86,17 +95,30 @@ class EmulatedInterval(TypeDecorator): return timedelta(seconds=value) if value is not None else None -@attrs.define(kw_only=True, eq=False) -class _BaseSQLAlchemyDataStore: +@attrs.define(eq=False) +class SQLAlchemyDataStore(BaseExternalDataStore): + """ + Uses a relational database to store data. + + When started, this data store creates the appropriate tables on the given database + if they're not already present. + + Operations are retried (in accordance to ``retry_settings``) when an operation + raises :exc:`sqlalchemy.OperationalError`. + + This store has been tested to work with PostgreSQL (asyncpg driver) and MySQL + (asyncmy driver). + + :param engine: an asynchronous SQLAlchemy engine + :param schema: a database schema name to use, if not the default + """ + + engine: Engine | AsyncEngine schema: str | None = attrs.field(default=None) - serializer: Serializer = attrs.field(factory=PickleSerializer) - lock_expiration_delay: float = attrs.field(default=30) max_poll_time: float | None = attrs.field(default=1) max_idle_time: float = attrs.field(default=60) - retry_settings: RetrySettings = attrs.field(default=RetrySettings()) - start_from_scratch: bool = attrs.field(default=False) - _logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__)) + _is_async: bool = attrs.field(init=False) def __attrs_post_init__(self) -> None: # Generate the table definitions @@ -107,6 +129,7 @@ class _BaseSQLAlchemyDataStore: self.t_schedules = self._metadata.tables[prefix + "schedules"] self.t_jobs = self._metadata.tables[prefix + "jobs"] self.t_job_results = self._metadata.tables[prefix + "job_results"] + self._is_async = isinstance(self.engine, AsyncEngine) # Find out if the dialect supports UPDATE...RETURNING update = self.t_jobs.update().returning(self.t_jobs.c.id) @@ -117,18 +140,76 @@ class _BaseSQLAlchemyDataStore: else: 
self._supports_update_returning = True - def _after_attempt(self, retry_state: tenacity.RetryCallState) -> None: - self._logger.warning( - "Temporary data store error (attempt %d): %s", - retry_state.attempt_number, - retry_state.outcome.exception(), + @classmethod + def from_url(cls: type[Self], url: str | URL, **options) -> Self: + """ + Create a new asynchronous SQLAlchemy data store. + + :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine` + (must use an async dialect like ``asyncpg`` or ``asyncmy``) + :param kwargs: keyword arguments to pass to the initializer of this class + :return: the newly created data store + + """ + engine = create_async_engine(url, future=True) + return cls(engine, **options) + + def _retry(self) -> tenacity.AsyncRetrying: + def after_attempt(self, retry_state: tenacity.RetryCallState) -> None: + self._logger.warning( + "Temporary data store error (attempt %d): %s", + retry_state.attempt_number, + retry_state.outcome.exception(), + ) + + # OSError is raised by asyncpg if it can't connect + return tenacity.AsyncRetrying( + stop=self.retry_settings.stop, + wait=self.retry_settings.wait, + retry=tenacity.retry_if_exception_type((InterfaceError, OSError)), + after=after_attempt, + sleep=anyio.sleep, + reraise=True, ) + @asynccontextmanager + async def _begin_transaction(self) -> AsyncGenerator[Connection | AsyncConnection]: + if isinstance(self.engine, AsyncEngine): + async with self.engine.begin() as conn: + yield conn + else: + cm = self.engine.begin() + conn = await to_thread.run_sync(cm.__enter__) + try: + yield conn + except BaseException as exc: + await to_thread.run_sync(cm.__exit__, type(exc), exc, exc.__traceback__) + raise + else: + await to_thread.run_sync(cm.__exit__, None, None, None) + + async def _create_metadata(self, conn: Connection | AsyncConnection) -> None: + if isinstance(conn, AsyncConnection): + await conn.run_sync(self._metadata.create_all) + else: + await 
to_thread.run_sync(self._metadata.create_all, conn) + + async def _execute( + self, + conn: Connection | AsyncConnection, + statement, + parameters: Sequence | None = None, + ): + if isinstance(conn, AsyncConnection): + return await conn.execute(statement, parameters) + else: + return await to_thread.run_sync(conn.execute, statement, parameters) + def get_table_definitions(self) -> MetaData: if self.engine.dialect.name == "postgresql": from sqlalchemy.dialects import postgresql - timestamp_type = TIMESTAMP(timezone=True) + timestamp_type = postgresql.TIMESTAMP(timezone=True) job_id_type = postgresql.UUID(as_uuid=True) interval_type = postgresql.INTERVAL(precision=6) tags_type = postgresql.ARRAY(Unicode) @@ -145,6 +226,7 @@ class _BaseSQLAlchemyDataStore: metadata, Column("id", Unicode(500), primary_key=True), Column("func", Unicode(500), nullable=False), + Column("executor", Unicode(500), nullable=False), Column("state", LargeBinary), Column("max_running_jobs", Integer), Column("misfire_grace_time", interval_type), @@ -197,186 +279,160 @@ class _BaseSQLAlchemyDataStore: ) return metadata - def _deserialize_schedules(self, result: Result) -> list[Schedule]: + async def start( + self, exit_stack: AsyncExitStack, event_broker: EventBroker + ) -> None: + await super().start(exit_stack, event_broker) + asynclib = sniffio.current_async_library() or "(unknown)" + if asynclib != "asyncio": + raise RuntimeError( + f"This data store requires asyncio; currently running: {asynclib}" + ) + + # Verify that the schema is in place + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + if self.start_from_scratch: + for table in self._metadata.sorted_tables: + await self._execute(conn, DropTable(table, if_exists=True)) + + await self._create_metadata(conn) + query = select(self.t_metadata.c.schema_version) + result = await self._execute(conn, query) + version = result.scalar() + if version is None: + await self._execute( + conn, 
self.t_metadata.insert(values={"schema_version": 1}) + ) + elif version > 1: + raise RuntimeError( + f"Unexpected schema version ({version}); " + f"only version 1 is supported by this version of " + f"APScheduler" + ) + + async def _deserialize_schedules(self, result: Result) -> list[Schedule]: schedules: list[Schedule] = [] for row in result: try: schedules.append(Schedule.unmarshal(self.serializer, row._asdict())) except SerializationError as exc: - self._events.publish( + await self._event_broker.publish( ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc) ) return schedules - def _deserialize_jobs(self, result: Result) -> list[Job]: + async def _deserialize_jobs(self, result: Result) -> list[Job]: jobs: list[Job] = [] for row in result: try: jobs.append(Job.unmarshal(self.serializer, row._asdict())) except SerializationError as exc: - self._events.publish( + await self._event_broker.publish( JobDeserializationFailed(job_id=row["id"], exception=exc) ) return jobs - -@attrs.define(eq=False) -class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore): - """ - Uses a relational database to store data. - - When started, this data store creates the appropriate tables on the given database - if they're not already present. - - Operations are retried (in accordance to ``retry_settings``) when an operation - raises :exc:`sqlalchemy.OperationalError`. - - This store has been tested to work with PostgreSQL (psycopg2 driver), MySQL - (pymysql driver) and SQLite. 
- - :param engine: a (synchronous) SQLAlchemy engine - :param schema: a database schema name to use, if not the default - :param serializer: the serializer used to (de)serialize tasks, schedules and jobs - for storage - :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler - or worker can keep a lock on a schedule or task - :param retry_settings: Tenacity settings for retrying operations in case of a - database connecitivty problem - :param start_from_scratch: erase all existing data during startup (useful for test - suites) - """ - - engine: Engine - - @classmethod - def from_url(cls, url: str | URL, **kwargs) -> SQLAlchemyDataStore: - """ - Create a new SQLAlchemy data store. - - :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine` - :param kwargs: keyword arguments to pass to the initializer of this class - :return: the newly created data store - - """ - engine = create_engine(url) - return cls(engine, **kwargs) - - def _retry(self) -> tenacity.Retrying: - return tenacity.Retrying( - stop=self.retry_settings.stop, - wait=self.retry_settings.wait, - retry=tenacity.retry_if_exception_type(OperationalError), - after=self._after_attempt, - reraise=True, - ) - - def start(self, event_broker: EventBroker) -> None: - super().start(event_broker) - - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - if self.start_from_scratch: - for table in self._metadata.sorted_tables: - conn.execute(DropTable(table, if_exists=True)) - - self._metadata.create_all(conn) - query = select(self.t_metadata.c.schema_version) - result = conn.execute(query) - version = result.scalar() - if version is None: - conn.execute(self.t_metadata.insert(values={"schema_version": 1})) - elif version > 1: - raise RuntimeError( - f"Unexpected schema version ({version}); " - f"only version 1 is supported by this version of APScheduler" - ) - - def add_task(self, task: Task) -> None: + async def add_task(self, task: Task) -> None: 
insert = self.t_tasks.insert().values( id=task.id, func=callable_to_ref(task.func), + executor=task.executor, max_running_jobs=task.max_running_jobs, misfire_grace_time=task.misfire_grace_time, ) try: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(insert) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, insert) except IntegrityError: update = ( self.t_tasks.update() .values( func=callable_to_ref(task.func), + executor=task.executor, max_running_jobs=task.max_running_jobs, misfire_grace_time=task.misfire_grace_time, ) .where(self.t_tasks.c.id == task.id) ) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(update) - self._events.publish(TaskUpdated(task_id=task.id)) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, update) + + await self._event_broker.publish(TaskUpdated(task_id=task.id)) else: - self._events.publish(TaskAdded(task_id=task.id)) + await self._event_broker.publish(TaskAdded(task_id=task.id)) - def remove_task(self, task_id: str) -> None: + async def remove_task(self, task_id: str) -> None: delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(delete) - if result.rowcount == 0: - raise TaskLookupError(task_id) - else: - self._events.publish(TaskRemoved(task_id=task_id)) - - def get_task(self, task_id: str) -> Task: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, delete) + if result.rowcount == 0: + raise TaskLookupError(task_id) + else: + await self._event_broker.publish(TaskRemoved(task_id=task_id)) + + async def get_task(self, task_id: str) -> Task: query = select( [ self.t_tasks.c.id, self.t_tasks.c.func, 
+ self.t_tasks.c.executor, self.t_tasks.c.max_running_jobs, self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time, ] ).where(self.t_tasks.c.id == task_id) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - row = result.first() + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, query) + row = result.first() if row: return Task.unmarshal(self.serializer, row._asdict()) else: raise TaskLookupError(task_id) - def get_tasks(self) -> list[Task]: + async def get_tasks(self) -> list[Task]: query = select( [ self.t_tasks.c.id, self.t_tasks.c.func, + self.t_tasks.c.executor, self.t_tasks.c.max_running_jobs, self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time, ] ).order_by(self.t_tasks.c.id) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - tasks = [ - Task.unmarshal(self.serializer, row._asdict()) for row in result - ] - return tasks - - def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: - event: Event + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, query) + tasks = [ + Task.unmarshal(self.serializer, row._asdict()) for row in result + ] + return tasks + + async def add_schedule( + self, schedule: Schedule, conflict_policy: ConflictPolicy + ) -> None: + event: DataStoreEvent values = schedule.marshal(self.serializer) insert = self.t_schedules.insert().values(**values) try: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(insert) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, insert) except IntegrityError: if conflict_policy is ConflictPolicy.exception: raise ConflictingIdError(schedule.id) from None @@ -387,185 +443,197 
@@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore): .where(self.t_schedules.c.id == schedule.id) .values(**values) ) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(update) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, update) event = ScheduleUpdated( schedule_id=schedule.id, next_fire_time=schedule.next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) else: event = ScheduleAdded( schedule_id=schedule.id, next_fire_time=schedule.next_fire_time ) - self._events.publish(event) - - def remove_schedules(self, ids: Iterable[str]) -> None: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - delete = self.t_schedules.delete().where(self.t_schedules.c.id.in_(ids)) - if self._supports_update_returning: - delete = delete.returning(self.t_schedules.c.id) - removed_ids: Iterable[str] = [ - row[0] for row in conn.execute(delete) - ] - else: - # TODO: actually check which rows were deleted? - conn.execute(delete) - removed_ids = ids + await self._event_broker.publish(event) + + async def remove_schedules(self, ids: Iterable[str]) -> None: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + delete = self.t_schedules.delete().where( + self.t_schedules.c.id.in_(ids) + ) + if self._supports_update_returning: + delete = delete.returning(self.t_schedules.c.id) + removed_ids: Iterable[str] = [ + row[0] for row in await self._execute(conn, delete) + ] + else: + # TODO: actually check which rows were deleted? 
+ await self._execute(conn, delete) + removed_ids = ids for schedule_id in removed_ids: - self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id)) - def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: + async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: query = self.t_schedules.select().order_by(self.t_schedules.c.id) if ids: query = query.where(self.t_schedules.c.id.in_(ids)) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - return self._deserialize_schedules(result) - - def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - schedules_cte = ( - select(self.t_schedules.c.id) - .where( - and_( - self.t_schedules.c.next_fire_time.isnot(None), - self.t_schedules.c.next_fire_time <= now, - or_( - self.t_schedules.c.acquired_until.is_(None), - self.t_schedules.c.acquired_until < now, - ), + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, query) + return await self._deserialize_schedules(result) + + async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + now = datetime.now(timezone.utc) + acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + schedules_cte = ( + select(self.t_schedules.c.id) + .where( + and_( + self.t_schedules.c.next_fire_time.isnot(None), + self.t_schedules.c.next_fire_time <= now, + or_( + self.t_schedules.c.acquired_until.is_(None), + self.t_schedules.c.acquired_until < now, + ), + ) ) + 
.order_by(self.t_schedules.c.next_fire_time) + .limit(limit) + .with_for_update(skip_locked=True) + .cte() ) - .order_by(self.t_schedules.c.next_fire_time) - .limit(limit) - .with_for_update(skip_locked=True) - .cte() - ) - subselect = select([schedules_cte.c.id]) - update = ( - self.t_schedules.update() - .where(self.t_schedules.c.id.in_(subselect)) - .values(acquired_by=scheduler_id, acquired_until=acquired_until) - ) - if self._supports_update_returning: - update = update.returning(*self.t_schedules.columns) - result = conn.execute(update) - else: - conn.execute(update) - query = self.t_schedules.select().where( - and_(self.t_schedules.c.acquired_by == scheduler_id) + subselect = select([schedules_cte.c.id]) + update = ( + self.t_schedules.update() + .where(self.t_schedules.c.id.in_(subselect)) + .values(acquired_by=scheduler_id, acquired_until=acquired_until) ) - result = conn.execute(query) + if self._supports_update_returning: + update = update.returning(*self.t_schedules.columns) + result = await self._execute(conn, update) + else: + await self._execute(conn, update) + query = self.t_schedules.select().where( + and_(self.t_schedules.c.acquired_by == scheduler_id) + ) + result = await self._execute(conn, query) - schedules = self._deserialize_schedules(result) + schedules = await self._deserialize_schedules(result) return schedules - def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - update_events: list[ScheduleUpdated] = [] - finished_schedule_ids: list[str] = [] - update_args: list[dict[str, Any]] = [] - for schedule in schedules: - if schedule.next_fire_time is not None: - try: - serialized_trigger = self.serializer.serialize( - schedule.trigger - ) - except SerializationError: - self._logger.exception( - "Error serializing trigger for schedule %r – " - "removing from data store", - schedule.id, + async def release_schedules( + self, scheduler_id: 
str, schedules: list[Schedule] + ) -> None: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + update_events: list[ScheduleUpdated] = [] + finished_schedule_ids: list[str] = [] + update_args: list[dict[str, Any]] = [] + for schedule in schedules: + if schedule.next_fire_time is not None: + try: + serialized_trigger = self.serializer.serialize( + schedule.trigger + ) + except SerializationError: + self._logger.exception( + "Error serializing trigger for schedule %r – " + "removing from data store", + schedule.id, + ) + finished_schedule_ids.append(schedule.id) + continue + + update_args.append( + { + "p_id": schedule.id, + "p_trigger": serialized_trigger, + "p_next_fire_time": schedule.next_fire_time, + } ) + else: finished_schedule_ids.append(schedule.id) - continue - - update_args.append( - { - "p_id": schedule.id, - "p_trigger": serialized_trigger, - "p_next_fire_time": schedule.next_fire_time, - } - ) - else: - finished_schedule_ids.append(schedule.id) - # Update schedules that have a next fire time - if update_args: - p_id: BindParameter = bindparam("p_id") - p_trigger: BindParameter = bindparam("p_trigger") - p_next_fire_time: BindParameter = bindparam("p_next_fire_time") - update = ( - self.t_schedules.update() - .where( - and_( - self.t_schedules.c.id == p_id, - self.t_schedules.c.acquired_by == scheduler_id, + # Update schedules that have a next fire time + if update_args: + p_id: BindParameter = bindparam("p_id") + p_trigger: BindParameter = bindparam("p_trigger") + p_next_fire_time: BindParameter = bindparam("p_next_fire_time") + update = ( + self.t_schedules.update() + .where( + and_( + self.t_schedules.c.id == p_id, + self.t_schedules.c.acquired_by == scheduler_id, + ) + ) + .values( + trigger=p_trigger, + next_fire_time=p_next_fire_time, + acquired_by=None, + acquired_until=None, ) ) - .values( - trigger=p_trigger, - next_fire_time=p_next_fire_time, - acquired_by=None, - acquired_until=None, - ) - ) 
- next_fire_times = { - arg["p_id"]: arg["p_next_fire_time"] for arg in update_args - } - # TODO: actually check which rows were updated? - conn.execute(update, update_args) - updated_ids = list(next_fire_times) - - for schedule_id in updated_ids: - event = ScheduleUpdated( - schedule_id=schedule_id, - next_fire_time=next_fire_times[schedule_id], - ) - update_events.append(event) + next_fire_times = { + arg["p_id"]: arg["p_next_fire_time"] for arg in update_args + } + # TODO: actually check which rows were updated? + await self._execute(conn, update, update_args) + updated_ids = list(next_fire_times) + + for schedule_id in updated_ids: + event = ScheduleUpdated( + schedule_id=schedule_id, + next_fire_time=next_fire_times[schedule_id], + ) + update_events.append(event) - # Remove schedules that have no next fire time or failed to serialize - if finished_schedule_ids: - delete = self.t_schedules.delete().where( - self.t_schedules.c.id.in_(finished_schedule_ids) - ) - conn.execute(delete) + # Remove schedules that have no next fire time or failed to + # serialize + if finished_schedule_ids: + delete = self.t_schedules.delete().where( + self.t_schedules.c.id.in_(finished_schedule_ids) + ) + await self._execute(conn, delete) for event in update_events: - self._events.publish(event) + await self._event_broker.publish(event) for schedule_id in finished_schedule_ids: - self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id)) - def get_next_schedule_run_time(self) -> datetime | None: - query = ( + async def get_next_schedule_run_time(self) -> datetime | None: + statenent = ( select(self.t_schedules.c.next_fire_time) .where(self.t_schedules.c.next_fire_time.isnot(None)) .order_by(self.t_schedules.c.next_fire_time) .limit(1) ) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - return result.scalar() + async for attempt in self._retry(): 
+ with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, statenent) + return result.scalar() - def add_job(self, job: Job) -> None: + async def add_job(self, job: Job) -> None: marshalled = job.marshal(self.serializer) insert = self.t_jobs.insert().values(**marshalled) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(insert) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, insert) event = JobAdded( job_id=job.id, @@ -573,139 +641,156 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore): schedule_id=job.schedule_id, tags=job.tags, ) - self._events.publish(event) + await self._event_broker.publish(event) - def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: + async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: query = self.t_jobs.select().order_by(self.t_jobs.c.id) if ids: job_ids = [job_id for job_id in ids] query = query.where(self.t_jobs.c.id.in_(job_ids)) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - return self._deserialize_jobs(result) - - def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - query = ( - self.t_jobs.select() - .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id) - .where( - or_( - self.t_jobs.c.acquired_until.is_(None), - self.t_jobs.c.acquired_until < now, + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, query) + return await self._deserialize_jobs(result) + + async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: + async for 
attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + now = datetime.now(timezone.utc) + acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + query = ( + self.t_jobs.select() + .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id) + .where( + or_( + self.t_jobs.c.acquired_until.is_(None), + self.t_jobs.c.acquired_until < now, + ) ) + .order_by(self.t_jobs.c.created_at) + .with_for_update(skip_locked=True) + .limit(limit) ) - .order_by(self.t_jobs.c.created_at) - .with_for_update(skip_locked=True) - .limit(limit) - ) - - result = conn.execute(query) - if not result: - return [] - # Mark the jobs as acquired by this worker - jobs = self._deserialize_jobs(result) - task_ids: set[str] = {job.task_id for job in jobs} - - # Retrieve the limits - query = select( - [ - self.t_tasks.c.id, - self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs, - ] - ).where( - self.t_tasks.c.max_running_jobs.isnot(None), - self.t_tasks.c.id.in_(task_ids), - ) - result = conn.execute(query) - job_slots_left = dict(result.fetchall()) - - # Filter out jobs that don't have free slots - acquired_jobs: list[Job] = [] - increments: dict[str, int] = defaultdict(lambda: 0) - for job in jobs: - # Don't acquire the job if there are no free slots left - slots_left = job_slots_left.get(job.task_id) - if slots_left == 0: - continue - elif slots_left is not None: - job_slots_left[job.task_id] -= 1 - - acquired_jobs.append(job) - increments[job.task_id] += 1 - - if acquired_jobs: - # Mark the acquired jobs as acquired by this worker - acquired_job_ids = [job.id for job in acquired_jobs] - update = ( - self.t_jobs.update() - .values(acquired_by=worker_id, acquired_until=acquired_until) - .where(self.t_jobs.c.id.in_(acquired_job_ids)) + result = await self._execute(conn, query) + if not result: + return [] + + # Mark the jobs as acquired by this worker + jobs = await self._deserialize_jobs(result) + task_ids: set[str] = {job.task_id 
for job in jobs} + + # Retrieve the limits + query = select( + [ + self.t_tasks.c.id, + self.t_tasks.c.max_running_jobs + - self.t_tasks.c.running_jobs, + ] + ).where( + self.t_tasks.c.max_running_jobs.isnot(None), + self.t_tasks.c.id.in_(task_ids), ) - conn.execute(update) - - # Increment the running job counters on each task - p_id: BindParameter = bindparam("p_id") - p_increment: BindParameter = bindparam("p_increment") - params = [ - {"p_id": task_id, "p_increment": increment} - for task_id, increment in increments.items() - ] - update = ( - self.t_tasks.update() - .values(running_jobs=self.t_tasks.c.running_jobs + p_increment) - .where(self.t_tasks.c.id == p_id) - ) - conn.execute(update, params) + result = await self._execute(conn, query) + job_slots_left: dict[str, int] = dict(result.fetchall()) + + # Filter out jobs that don't have free slots + acquired_jobs: list[Job] = [] + increments: dict[str, int] = defaultdict(lambda: 0) + for job in jobs: + # Don't acquire the job if there are no free slots left + slots_left = job_slots_left.get(job.task_id) + if slots_left == 0: + continue + elif slots_left is not None: + job_slots_left[job.task_id] -= 1 + + acquired_jobs.append(job) + increments[job.task_id] += 1 + + if acquired_jobs: + # Mark the acquired jobs as acquired by this worker + acquired_job_ids = [job.id for job in acquired_jobs] + update = ( + self.t_jobs.update() + .values( + acquired_by=worker_id, acquired_until=acquired_until + ) + .where(self.t_jobs.c.id.in_(acquired_job_ids)) + ) + await self._execute(conn, update) + + # Increment the running job counters on each task + p_id: BindParameter = bindparam("p_id") + p_increment: BindParameter = bindparam("p_increment") + params = [ + {"p_id": task_id, "p_increment": increment} + for task_id, increment in increments.items() + ] + update = ( + self.t_tasks.update() + .values( + running_jobs=self.t_tasks.c.running_jobs + p_increment + ) + .where(self.t_tasks.c.id == p_id) + ) + await self._execute(conn, 
update, params) # Publish the appropriate events for job in acquired_jobs: - self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) + await self._event_broker.publish( + JobAcquired(job_id=job.id, worker_id=worker_id) + ) return acquired_jobs - def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - # Insert the job result - if result.expires_at > result.finished_at: - marshalled = result.marshal(self.serializer) - insert = self.t_job_results.insert().values(**marshalled) - conn.execute(insert) + async def release_job( + self, worker_id: str, task_id: str, result: JobResult + ) -> None: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + # Record the job result + if result.expires_at > result.finished_at: + marshalled = result.marshal(self.serializer) + insert = self.t_job_results.insert().values(**marshalled) + await self._execute(conn, insert) + + # Decrement the number of running jobs for this task + update = ( + self.t_tasks.update() + .values(running_jobs=self.t_tasks.c.running_jobs - 1) + .where(self.t_tasks.c.id == task_id) + ) + await self._execute(conn, update) - # Decrement the running jobs counter - update = ( - self.t_tasks.update() - .values(running_jobs=self.t_tasks.c.running_jobs - 1) - .where(self.t_tasks.c.id == task_id) - ) - conn.execute(update) - - # Delete the job - delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id) - conn.execute(delete) - - def get_job_result(self, job_id: UUID) -> JobResult | None: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - # Retrieve the result - query = self.t_job_results.select().where( - self.t_job_results.c.job_id == job_id - ) - row = conn.execute(query).first() + # Delete the job + delete = self.t_jobs.delete().where( + self.t_jobs.c.id == result.job_id + ) + await self._execute(conn, delete) + + 
async def get_job_result(self, job_id: UUID) -> JobResult | None: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + # Retrieve the result + query = self.t_job_results.select().where( + self.t_job_results.c.job_id == job_id + ) + row = (await self._execute(conn, query)).first() - # Delete the result - delete = self.t_job_results.delete().where( - self.t_job_results.c.job_id == job_id - ) - conn.execute(delete) + # Delete the result + delete = self.t_job_results.delete().where( + self.t_job_results.c.job_id == job_id + ) + await self._execute(conn, delete) - return ( - JobResult.unmarshal(self.serializer, row._asdict()) if row else None - ) + return ( + JobResult.unmarshal(self.serializer, row._asdict()) + if row + else None + ) |