diff options
38 files changed, 2144 insertions, 4534 deletions
diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index efe69fd..559ee36 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -6,6 +6,30 @@ APScheduler, see the :doc:`migration section <migration>`. **UNRELEASED** +- **BREAKING** Workers can no longer be run independently. Instead, you can run a + scheduler that only starts a worker but does not process schedules by passing + ``process_schedules=False`` to the scheduler +- **BREAKING** The synchronous interfaces for event brokers and data stores have been + removed. Synchronous libraries can still be used to implement these services through + the use of ``anyio.to_thread.run_sync()``. +- **BREAKING** The ``current_worker`` context variable has been removed +- **BREAKING** The ``current_scheduler`` context variable is now specified to only + contain the currently running instance of a **synchronous** scheduler + (``apscheduler.schedulers.sync.Scheduler``). The asynchronous scheduler instance can + be fetched from the new ``current_async_scheduler`` context variable, and will always + be available when a scheduler is running in the current context, while + ``current_scheduler`` is only available when the synchronous wrapper is being run. +- **BREAKING** Changed the initialization of data stores and event brokers to use a + single ``start()`` method that accepts an ``AsyncExitStack`` (and, depending on the + interface, other arguments too) +- **BREAKING** Added a concept of "job executors". This determines how the task function + is executed once picked up by a worker. Several data structures and scheduler methods + have a new field/parameter for this, ``job_executor``. This addition requires database + schema changes too. 
+- Added the ability to run jobs in worker processes, courtesy of the ``processpool`` + executor +- The synchronous scheduler now runs an asyncio event loop in a thread, acting as a + façade for ``AsyncScheduler`` - Fixed the ``schema`` parameter in ``SQLAlchemyDataStore`` not being applied **4.0.0a2** diff --git a/examples/separate_worker/async_scheduler.py b/examples/separate_worker/async_scheduler.py index 6ffdbcd..2ac53c5 100644 --- a/examples/separate_worker/async_scheduler.py +++ b/examples/separate_worker/async_scheduler.py @@ -19,7 +19,7 @@ import logging from example_tasks import tick from sqlalchemy.ext.asyncio import create_async_engine -from apscheduler.datastores.async_sqlalchemy import AsyncSQLAlchemyDataStore +from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore from apscheduler.eventbrokers.asyncpg import AsyncpgEventBroker from apscheduler.schedulers.async_ import AsyncScheduler from apscheduler.triggers.interval import IntervalTrigger @@ -29,15 +29,15 @@ async def main(): engine = create_async_engine( "postgresql+asyncpg://postgres:secret@localhost/testdb" ) - data_store = AsyncSQLAlchemyDataStore(engine) + data_store = SQLAlchemyDataStore(engine) event_broker = AsyncpgEventBroker.from_async_sqla_engine(engine) # Uncomment the next two lines to use the Redis event broker instead - # from apscheduler.eventbrokers.async_redis import AsyncRedisEventBroker - # event_broker = AsyncRedisEventBroker.from_url("redis://localhost") + # from apscheduler.eventbrokers.redis import RedisEventBroker + # event_broker = RedisEventBroker.from_url("redis://localhost") async with AsyncScheduler( - data_store, event_broker, start_worker=False + data_store, event_broker, process_jobs=False ) as scheduler: await scheduler.add_schedule(tick, IntervalTrigger(seconds=1), id="tick") await scheduler.run_until_stopped() diff --git a/examples/separate_worker/async_worker.py b/examples/separate_worker/async_worker.py index 700720e..51c51e9 100644 ---
a/examples/separate_worker/async_worker.py +++ b/examples/separate_worker/async_worker.py @@ -18,24 +18,24 @@ import logging from sqlalchemy.ext.asyncio import create_async_engine -from apscheduler.datastores.async_sqlalchemy import AsyncSQLAlchemyDataStore +from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore from apscheduler.eventbrokers.asyncpg import AsyncpgEventBroker -from apscheduler.workers.async_ import AsyncWorker +from apscheduler.schedulers.async_ import AsyncScheduler async def main(): engine = create_async_engine( "postgresql+asyncpg://postgres:secret@localhost/testdb" ) - data_store = AsyncSQLAlchemyDataStore(engine) + data_store = SQLAlchemyDataStore(engine) event_broker = AsyncpgEventBroker.from_async_sqla_engine(engine) # Uncomment the next two lines to use the Redis event broker instead - # from apscheduler.eventbrokers.async_redis import AsyncRedisEventBroker - # event_broker = AsyncRedisEventBroker.from_url("redis://localhost") + # from apscheduler.eventbrokers.redis import RedisEventBroker + # event_broker = RedisEventBroker.from_url("redis://localhost") - worker = AsyncWorker(data_store, event_broker) - await worker.run_until_stopped() + scheduler = AsyncScheduler(data_store, event_broker, process_schedules=False) + await scheduler.run_until_stopped() logging.basicConfig(level=logging.INFO) diff --git a/examples/separate_worker/sync_worker.py b/examples/separate_worker/sync_worker.py index e57be64..4329d02 100644 --- a/examples/separate_worker/sync_worker.py +++ b/examples/separate_worker/sync_worker.py @@ -1,7 +1,7 @@ """ -Example demonstrating the separation of scheduler and worker. -This script runs the worker part. You need to be running both this and the scheduler -script simultaneously in order for the scheduled task to be run. +Example demonstrating a scheduler that only runs jobs but does not process schedules. 
+You need to be running both this and the scheduler script simultaneously in order for +the scheduled task to be run. Requires the "postgresql" and "redis" services to be running. To install prerequisites: pip install sqlalchemy psycopg2 redis @@ -19,7 +19,7 @@ from sqlalchemy.future import create_engine from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore from apscheduler.eventbrokers.redis import RedisEventBroker -from apscheduler.workers.sync import Worker +from apscheduler.schedulers.sync import Scheduler logging.basicConfig(level=logging.INFO) engine = create_engine("postgresql+psycopg2://postgres:secret@localhost/testdb") @@ -30,5 +30,5 @@ event_broker = RedisEventBroker.from_url("redis://localhost") # from apscheduler.eventbrokers.mqtt import MQTTEventBroker # event_broker = MQTTEventBroker() -worker = Worker(data_store, event_broker) -worker.run_until_stopped() +with Scheduler(data_store, event_broker, process_schedules=False) as scheduler: + scheduler.run_until_stopped() diff --git a/src/apscheduler/__init__.py b/src/apscheduler/__init__.py index 6b5828d..844a056 100644 --- a/src/apscheduler/__init__.py +++ b/src/apscheduler/__init__.py @@ -41,14 +41,14 @@ __all__ = [ "WorkerEvent", "WorkerStarted", "WorkerStopped", + "current_async_scheduler", "current_scheduler", - "current_worker", "current_job", ] from typing import Any -from ._context import current_job, current_scheduler, current_worker +from ._context import current_async_scheduler, current_job, current_scheduler from ._enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState from ._events import ( DataStoreEvent, @@ -84,7 +84,8 @@ from ._exceptions import ( SerializationError, TaskLookupError, ) -from ._structures import Job, JobInfo, JobResult, RetrySettings, Schedule, Task +from ._retry import RetrySettings +from ._structures import Job, JobInfo, JobResult, Schedule, Task # Re-export imports, so they look like they live directly in this package value: Any diff --git 
a/src/apscheduler/_context.py b/src/apscheduler/_context.py index cc5aff2..5edc310 100644 --- a/src/apscheduler/_context.py +++ b/src/apscheduler/_context.py @@ -7,16 +7,13 @@ if TYPE_CHECKING: from ._structures import JobInfo from .schedulers.async_ import AsyncScheduler from .schedulers.sync import Scheduler - from .workers.async_ import AsyncWorker - from .workers.sync import Worker #: The currently running (local) scheduler -current_scheduler: ContextVar[Scheduler | AsyncScheduler | None] = ContextVar( +current_scheduler: ContextVar[Scheduler | None] = ContextVar( "current_scheduler", default=None ) -#: The worker running the current job -current_worker: ContextVar[Worker | AsyncWorker | None] = ContextVar( - "current_worker", default=None +current_async_scheduler: ContextVar[AsyncScheduler | None] = ContextVar( + "current_async_scheduler", default=None ) #: Metadata about the current job current_job: ContextVar[JobInfo] = ContextVar("job_info") diff --git a/src/apscheduler/_converters.py b/src/apscheduler/_converters.py index 3518d44..9e299dd 100644 --- a/src/apscheduler/_converters.py +++ b/src/apscheduler/_converters.py @@ -6,8 +6,6 @@ from enum import Enum from typing import TypeVar from uuid import UUID -from . 
import abc - TEnum = TypeVar("TEnum", bound=Enum) @@ -46,23 +44,3 @@ def as_enum(enum_class: type[TEnum]) -> Callable[[TEnum | str], TEnum]: return value return converter - - -def as_async_eventbroker( - value: abc.EventBroker | abc.AsyncEventBroker, -) -> abc.AsyncEventBroker: - if isinstance(value, abc.EventBroker): - from apscheduler.eventbrokers.async_adapter import AsyncEventBrokerAdapter - - return AsyncEventBrokerAdapter(value) - - return value - - -def as_async_datastore(value: abc.DataStore | abc.AsyncDataStore) -> abc.AsyncDataStore: - if isinstance(value, abc.DataStore): - from apscheduler.datastores.async_adapter import AsyncDataStoreAdapter - - return AsyncDataStoreAdapter(value) - - return value diff --git a/src/apscheduler/_retry.py b/src/apscheduler/_retry.py new file mode 100644 index 0000000..19cf51b --- /dev/null +++ b/src/apscheduler/_retry.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import attrs +from attr.validators import instance_of +from tenacity import ( + AsyncRetrying, + RetryCallState, + retry_if_exception_type, + stop_after_delay, + wait_exponential, +) +from tenacity.stop import stop_base +from tenacity.wait import wait_base + + +@attrs.define(kw_only=True, frozen=True) +class RetrySettings: + """ + Settings for retrying an operation with Tenacity. + + :param stop: defines when to stop trying + :param wait: defines how long to wait between attempts + """ + + stop: stop_base = attrs.field( + validator=instance_of(stop_base), + default=stop_after_delay(60), + ) + wait: wait_base = attrs.field( + validator=instance_of(wait_base), + default=wait_exponential(min=0.5, max=20), + ) + + +@attrs.define(kw_only=True, slots=False) +class RetryMixin: + """ + Mixin that provides support for retrying operations. 
+ + :param retry_settings: Tenacity settings for retrying operations in case of a + database connectivity problem + """ + + retry_settings: RetrySettings = attrs.field(default=RetrySettings()) + + @property + def _temporary_failure_exceptions(self) -> tuple[type[Exception], ...]: + """ + Tuple of exception classes which indicate that the operation should be retried. + + """ + return () + + def _retry(self) -> AsyncRetrying: + def after_attempt(retry_state: RetryCallState) -> None: + self._logger.warning( + "Temporary data store error (attempt %d): %s", + retry_state.attempt_number, + retry_state.outcome.exception(), + ) + + return AsyncRetrying( + stop=self.retry_settings.stop, + wait=self.retry_settings.wait, + retry=retry_if_exception_type(self._temporary_failure_exceptions), + after=after_attempt, + reraise=True, + ) diff --git a/src/apscheduler/_structures.py b/src/apscheduler/_structures.py index 0a959e4..eb26a3a 100644 --- a/src/apscheduler/_structures.py +++ b/src/apscheduler/_structures.py @@ -7,9 +7,6 @@ from typing import TYPE_CHECKING, Any from uuid import UUID, uuid4 import attrs -import tenacity.stop -import tenacity.wait -from attrs.validators import instance_of from ._converters import as_enum, as_timedelta from ._enums import CoalescePolicy, JobOutcome @@ -43,6 +40,7 @@ class Task: id: str func: Callable = attrs.field(eq=False, order=False) + executor: str = attrs.field(eq=False) max_running_jobs: int | None = attrs.field(eq=False, order=False, default=None) misfire_grace_time: timedelta | None = attrs.field( eq=False, order=False, default=None @@ -339,22 +337,3 @@ class JobResult: ) return cls(**marshalled) - - -@attrs.define(kw_only=True, frozen=True) -class RetrySettings: - """ - Settings for retrying an operation with Tenacity.
- - :param stop: defines when to stop trying - :param wait: defines how long to wait between attempts - """ - - stop: tenacity.stop.stop_base = attrs.field( - validator=instance_of(tenacity.stop.stop_base), - default=tenacity.stop_after_delay(60), - ) - wait: tenacity.wait.wait_base = attrs.field( - validator=instance_of(tenacity.wait.wait_base), - default=tenacity.wait_exponential(min=0.5, max=20), - ) diff --git a/src/apscheduler/_worker.py b/src/apscheduler/_worker.py new file mode 100644 index 0000000..8c95d16 --- /dev/null +++ b/src/apscheduler/_worker.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +from collections.abc import Mapping +from contextlib import AsyncExitStack +from datetime import datetime, timezone +from logging import Logger, getLogger +from typing import Callable +from uuid import UUID + +import anyio +import attrs +from anyio import create_task_group, get_cancelled_exc_class, move_on_after +from anyio.abc import CancelScope + +from ._context import current_job +from ._enums import JobOutcome, RunState +from ._events import JobAdded, JobReleased, WorkerStarted, WorkerStopped +from ._structures import Job, JobInfo, JobResult +from ._validators import positive_integer +from .abc import DataStore, EventBroker, JobExecutor + + +@attrs.define(eq=False, kw_only=True) +class Worker: + """ + Runs jobs locally in a task group. 
+ + :param max_concurrent_jobs: Maximum number of jobs the worker will run at once + """ + + job_executors: Mapping[str, JobExecutor] = attrs.field(kw_only=True) + max_concurrent_jobs: int = attrs.field( + kw_only=True, validator=positive_integer, default=100 + ) + logger: Logger = attrs.field(kw_only=True, default=getLogger(__name__)) + + _data_store: DataStore = attrs.field(init=False) + _event_broker: EventBroker = attrs.field(init=False) + _identity: str = attrs.field(init=False) + _state: RunState = attrs.field(init=False, default=RunState.stopped) + _wakeup_event: anyio.Event = attrs.field(init=False) + _acquired_jobs: set[Job] = attrs.field(init=False, factory=set) + _running_jobs: set[UUID] = attrs.field(init=False, factory=set) + + async def start( + self, + exit_stack: AsyncExitStack, + data_store: DataStore, + event_broker: EventBroker, + identity: str, + ) -> None: + self._data_store = data_store + self._event_broker = event_broker + self._identity = identity + self._state = RunState.started + self._wakeup_event = anyio.Event() + + # Start the job executors + for job_executor in self.job_executors.values(): + await job_executor.start(exit_stack) + + # Start the worker in a background task + task_group = await exit_stack.enter_async_context(create_task_group()) + task_group.start_soon(self._run) + + # Stop the worker when the exit stack unwinds + exit_stack.callback(lambda: self._wakeup_event.set()) + exit_stack.callback(setattr, self, "_state", RunState.stopped) + + # Wake up the worker if the data store emits a significant job event + exit_stack.enter_context( + self._event_broker.subscribe( + lambda event: self._wakeup_event.set(), {JobAdded} + ) + ) + + # Signal that the worker has started + await self._event_broker.publish_local(WorkerStarted()) + + async def _run(self) -> None: + """Run the worker until it is explicitly stopped.""" + exception: BaseException | None = None + try: + async with create_task_group() as tg: + while self._state is 
RunState.started: + limit = self.max_concurrent_jobs - len(self._running_jobs) + jobs = await self._data_store.acquire_jobs(self._identity, limit) + for job in jobs: + task = await self._data_store.get_task(job.task_id) + self._running_jobs.add(job.id) + tg.start_soon(self._run_job, job, task.func, task.executor) + + await self._wakeup_event.wait() + self._wakeup_event = anyio.Event() + except get_cancelled_exc_class(): + pass + except BaseException as exc: + exception = exc + raise + finally: + if not exception: + self.logger.info("Worker stopped") + elif isinstance(exception, Exception): + self.logger.exception("Worker crashed") + elif exception: + self.logger.info( + f"Worker stopped due to {exception.__class__.__name__}" + ) + + with move_on_after(3, shield=True): + await self._event_broker.publish_local( + WorkerStopped(exception=exception) + ) + + async def _run_job(self, job: Job, func: Callable, executor: str) -> None: + try: + # Check if the job started before the deadline + start_time = datetime.now(timezone.utc) + if job.start_deadline is not None and start_time > job.start_deadline: + result = JobResult.from_job( + job, + outcome=JobOutcome.missed_start_deadline, + finished_at=start_time, + ) + await self._data_store.release_job(self._identity, job.task_id, result) + await self._event_broker.publish( + JobReleased.from_result(result, self._identity) + ) + return + + try: + job_executor = self.job_executors[executor] + except KeyError: + return + + token = current_job.set(JobInfo.from_job(job)) + try: + retval = await job_executor.run_job(func, job) + except get_cancelled_exc_class(): + self.logger.info("Job %s was cancelled", job.id) + with CancelScope(shield=True): + result = JobResult.from_job( + job, + outcome=JobOutcome.cancelled, + ) + await self._data_store.release_job( + self._identity, job.task_id, result + ) + await self._event_broker.publish( + JobReleased.from_result(result, self._identity) + ) + except BaseException as exc: + if 
isinstance(exc, Exception): + self.logger.exception("Job %s raised an exception", job.id) + else: + self.logger.error( + "Job %s was aborted due to %s", job.id, exc.__class__.__name__ + ) + + result = JobResult.from_job( + job, + JobOutcome.error, + exception=exc, + ) + await self._data_store.release_job( + self._identity, + job.task_id, + result, + ) + await self._event_broker.publish( + JobReleased.from_result(result, self._identity) + ) + if not isinstance(exc, Exception): + raise + else: + self.logger.info("Job %s completed successfully", job.id) + result = JobResult.from_job( + job, + JobOutcome.success, + return_value=retval, + ) + await self._data_store.release_job(self._identity, job.task_id, result) + await self._event_broker.publish( + JobReleased.from_result(result, self._identity) + ) + finally: + current_job.reset(token) + finally: + self._running_jobs.remove(job.id) diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py index 74920e5..bc30ddc 100644 --- a/src/apscheduler/abc.py +++ b/src/apscheduler/abc.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod +from contextlib import AsyncExitStack from datetime import datetime from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator from uuid import UUID @@ -91,63 +92,20 @@ class Subscription(metaclass=ABCMeta): """ -class EventSource(metaclass=ABCMeta): - """ - Interface for objects that can deliver notifications to interested subscribers. - """ - - @abstractmethod - def subscribe( - self, - callback: Callable[[Event], Any], - event_types: Iterable[type[Event]] | None = None, - *, - one_shot: bool = False, - ) -> Subscription: - """ - Subscribe to events from this event source. 
- - :param callback: callable to be called with the event object when an event is - published - :param event_types: an iterable of concrete Event classes to subscribe to - :param one_shot: if ``True``, automatically unsubscribe after the first matching - event - """ - - -class EventBroker(EventSource): +class EventBroker(metaclass=ABCMeta): """ Interface for objects that can be used to publish notifications to interested subscribers. """ @abstractmethod - def start(self) -> None: - pass - - @abstractmethod - def stop(self, *, force: bool = False) -> None: - pass - - @abstractmethod - def publish(self, event: Event) -> None: - """Publish an event.""" - - @abstractmethod - def publish_local(self, event: Event) -> None: - """Publish an event, but only to local subscribers.""" - - -class AsyncEventBroker(EventSource): - """Asynchronous version of :class:`EventBroker`. Expected to work on asyncio.""" - - @abstractmethod - async def start(self) -> None: - pass + async def start(self, exit_stack: AsyncExitStack) -> None: + """ + Start the event broker. - @abstractmethod - async def stop(self, *, force: bool = False) -> None: - pass + :param exit_stack: an asynchronous exit stack which will be processed when the + scheduler is shut down + """ @abstractmethod async def publish(self, event: Event) -> None: @@ -157,185 +115,45 @@ class AsyncEventBroker(EventSource): async def publish_local(self, event: Event) -> None: """Publish an event, but only to local subscribers.""" - -class DataStore(metaclass=ABCMeta): - @abstractmethod - def start(self, event_broker: EventBroker) -> None: - pass - - @abstractmethod - def stop(self, *, force: bool = False) -> None: - pass - - @property - @abstractmethod - def events(self) -> EventSource: - pass - - @abstractmethod - def add_task(self, task: Task) -> None: - """ - Add the given task to the store. - - If a task with the same ID already exists, it replaces the old one but does NOT - affect task accounting (# of running jobs). 
- - :param task: the task to be added - """ - - @abstractmethod - def remove_task(self, task_id: str) -> None: - """ - Remove the task with the given ID. - - :param task_id: ID of the task to be removed - :raises TaskLookupError: if no matching task was found - """ - - @abstractmethod - def get_task(self, task_id: str) -> Task: - """ - Get an existing task definition. - - :param task_id: ID of the task to be returned - :return: the matching task - :raises TaskLookupError: if no matching task was found - """ - - @abstractmethod - def get_tasks(self) -> list[Task]: - """ - Get all the tasks in this store. - - :return: a list of tasks, sorted by ID - """ - - @abstractmethod - def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: - """ - Get schedules from the data store. - - :param ids: a specific set of schedule IDs to return, or ``None`` to return all - schedules - :return: the list of matching schedules, in unspecified order - """ - - @abstractmethod - def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: - """ - Add or update the given schedule in the data store. - - :param schedule: schedule to be added - :param conflict_policy: policy that determines what to do if there is an - existing schedule with the same ID - """ - - @abstractmethod - def remove_schedules(self, ids: Iterable[str]) -> None: - """ - Remove schedules from the data store. - - :param ids: a specific set of schedule IDs to remove - """ - - @abstractmethod - def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: - """ - Acquire unclaimed due schedules for processing. - - This method claims up to the requested number of schedules for the given - scheduler and returns them. 
- - :param scheduler_id: unique identifier of the scheduler - :param limit: maximum number of schedules to claim - :return: the list of claimed schedules - """ - - @abstractmethod - def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None: - """ - Release the claims on the given schedules and update them on the store. - - :param scheduler_id: unique identifier of the scheduler - :param schedules: the previously claimed schedules - """ - - @abstractmethod - def get_next_schedule_run_time(self) -> datetime | None: - """ - Return the earliest upcoming run time of all the schedules in the store, or - ``None`` if there are no active schedules. - """ - @abstractmethod - def add_job(self, job: Job) -> None: - """ - Add a job to be executed by an eligible worker. - - :param job: the job object - """ - - @abstractmethod - def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: - """ - Get the list of pending jobs. - - :param ids: a specific set of job IDs to return, or ``None`` to return all jobs - :return: the list of matching pending jobs, in the order they will be given to - workers - """ - - @abstractmethod - def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: + def subscribe( + self, + callback: Callable[[Event], Any], + event_types: Iterable[type[Event]] | None = None, + *, + is_async: bool = True, + one_shot: bool = False, + ) -> Subscription: """ - Acquire unclaimed jobs for execution. - - This method claims up to the requested number of jobs for the given worker and - returns them. + Subscribe to events from this event broker. 
- :param worker_id: unique identifier of the worker - :param limit: maximum number of jobs to claim and return - :return: the list of claimed jobs + :param callback: callable to be called with the event object when an event is + published + :param event_types: an iterable of concrete Event classes to subscribe to + :param is_async: ``True`` if the (synchronous) callback should be called on the + event loop thread, ``False`` if it should be called in a worker thread. + If the callback is a coroutine function, this flag is ignored. + :param one_shot: if ``True``, automatically unsubscribe after the first matching + event """ - @abstractmethod - def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None: - """ - Release the claim on the given job and record the result. - :param worker_id: unique identifier of the worker - :param task_id: the job's task ID - :param result: the result of the job - """ +class DataStore(metaclass=ABCMeta): + """Interface for data stores.""" @abstractmethod - def get_job_result(self, job_id: UUID) -> JobResult | None: + async def start( + self, exit_stack: AsyncExitStack, event_broker: EventBroker + ) -> None: """ - Retrieve the result of a job. - - The result is removed from the store after retrieval. + Start the data store. - :param job_id: the identifier of the job - :return: the result, or ``None`` if the result was not found + :param exit_stack: an asynchronous exit stack which will be processed when the + scheduler is shut down + :param event_broker: the event broker shared between the scheduler, worker (if + any) and this data store """ - -class AsyncDataStore(metaclass=ABCMeta): - """Asynchronous version of :class:`DataStore`.
Expected to work on asyncio.""" - - @abstractmethod - async def start(self, event_broker: AsyncEventBroker) -> None: - pass - - @abstractmethod - async def stop(self, *, force: bool = False) -> None: - pass - - @property - @abstractmethod - def events(self) -> EventSource: - pass - @abstractmethod async def add_task(self, task: Task) -> None: """ @@ -488,3 +306,23 @@ class AsyncDataStore(metaclass=ABCMeta): :param job_id: the identifier of the job :return: the result, or ``None`` if the result was not found """ + + +class JobExecutor(metaclass=ABCMeta): + async def start(self, exit_stack: AsyncExitStack) -> None: + """ + Start the job executor. + + :param exit_stack: an asynchronous exit stack which will be processed when the + scheduler is shut down + """ + + @abstractmethod + async def run_job(self, func: Callable[..., Any], job: Job) -> Any: + """ + + :param func: the task function to call + :param job: the associated job + :return: the return value of ``func`` (potentially awaiting on the returned + awaitable, if any) + """ diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py deleted file mode 100644 index d16ae56..0000000 --- a/src/apscheduler/datastores/async_adapter.py +++ /dev/null @@ -1,101 +0,0 @@ -from __future__ import annotations - -import sys -from datetime import datetime -from typing import Iterable -from uuid import UUID - -import attrs -from anyio import to_thread -from anyio.from_thread import BlockingPortal - -from .._enums import ConflictPolicy -from .._structures import Job, JobResult, Schedule, Task -from ..abc import AsyncEventBroker, DataStore -from ..eventbrokers.async_adapter import AsyncEventBrokerAdapter, SyncEventBrokerAdapter -from .base import BaseAsyncDataStore - - -@attrs.define(eq=False) -class AsyncDataStoreAdapter(BaseAsyncDataStore): - original: DataStore - _portal: BlockingPortal = attrs.field(init=False) - async def start(self, event_broker: AsyncEventBroker) -> None: - await super().start(event_broker) -
self._portal = BlockingPortal() - await self._portal.__aenter__() - - if isinstance(event_broker, AsyncEventBrokerAdapter): - sync_event_broker = event_broker.original - else: - sync_event_broker = SyncEventBrokerAdapter(event_broker, self._portal) - - try: - await to_thread.run_sync(lambda: self.original.start(sync_event_broker)) - except BaseException: - await self._portal.__aexit__(*sys.exc_info()) - raise - - async def stop(self, *, force: bool = False) -> None: - try: - await to_thread.run_sync(lambda: self.original.stop(force=force)) - finally: - await self._portal.__aexit__(None, None, None) - await super().stop(force=force) - - async def add_task(self, task: Task) -> None: - await to_thread.run_sync(self.original.add_task, task) - - async def remove_task(self, task_id: str) -> None: - await to_thread.run_sync(self.original.remove_task, task_id) - - async def get_task(self, task_id: str) -> Task: - return await to_thread.run_sync(self.original.get_task, task_id) - - async def get_tasks(self) -> list[Task]: - return await to_thread.run_sync(self.original.get_tasks) - - async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: - return await to_thread.run_sync(self.original.get_schedules, ids) - - async def add_schedule( - self, schedule: Schedule, conflict_policy: ConflictPolicy - ) -> None: - await to_thread.run_sync(self.original.add_schedule, schedule, conflict_policy) - - async def remove_schedules(self, ids: Iterable[str]) -> None: - await to_thread.run_sync(self.original.remove_schedules, ids) - - async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: - return await to_thread.run_sync( - self.original.acquire_schedules, scheduler_id, limit - ) - - async def release_schedules( - self, scheduler_id: str, schedules: list[Schedule] - ) -> None: - await to_thread.run_sync( - self.original.release_schedules, scheduler_id, schedules - ) - - async def get_next_schedule_run_time(self) -> datetime | None: - return 
await to_thread.run_sync(self.original.get_next_schedule_run_time) - - async def add_job(self, job: Job) -> None: - await to_thread.run_sync(self.original.add_job, job) - - async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: - return await to_thread.run_sync(self.original.get_jobs, ids) - - async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: - return await to_thread.run_sync(self.original.acquire_jobs, worker_id, limit) - - async def release_job( - self, worker_id: str, task_id: str, result: JobResult - ) -> None: - await to_thread.run_sync(self.original.release_job, worker_id, task_id, result) - - async def get_job_result(self, job_id: UUID) -> JobResult | None: - return await to_thread.run_sync(self.original.get_job_result, job_id) diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py deleted file mode 100644 index 0c165b8..0000000 --- a/src/apscheduler/datastores/async_sqlalchemy.py +++ /dev/null @@ -1,602 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict -from datetime import datetime, timedelta, timezone -from typing import Any, Iterable -from uuid import UUID - -import anyio -import attrs -import sniffio -import tenacity -from sqlalchemy import and_, bindparam, or_, select -from sqlalchemy.engine import URL, Result -from sqlalchemy.exc import IntegrityError, InterfaceError -from sqlalchemy.ext.asyncio import create_async_engine -from sqlalchemy.ext.asyncio.engine import AsyncEngine -from sqlalchemy.sql.ddl import DropTable -from sqlalchemy.sql.elements import BindParameter - -from .._enums import ConflictPolicy -from .._events import ( - DataStoreEvent, - JobAcquired, - JobAdded, - JobDeserializationFailed, - ScheduleAdded, - ScheduleDeserializationFailed, - ScheduleRemoved, - ScheduleUpdated, - TaskAdded, - TaskRemoved, - TaskUpdated, -) -from .._exceptions import ConflictingIdError, SerializationError, 
TaskLookupError -from .._structures import Job, JobResult, Schedule, Task -from ..abc import AsyncEventBroker -from ..marshalling import callable_to_ref -from .base import BaseAsyncDataStore -from .sqlalchemy import _BaseSQLAlchemyDataStore - - -@attrs.define(eq=False) -class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseAsyncDataStore): - """ - Uses a relational database to store data. - - When started, this data store creates the appropriate tables on the given database - if they're not already present. - - Operations are retried (in accordance to ``retry_settings``) when an operation - raises :exc:`sqlalchemy.OperationalError`. - - This store has been tested to work with PostgreSQL (asyncpg driver) and MySQL - (asyncmy driver). - - :param engine: an asynchronous SQLAlchemy engine - :param schema: a database schema name to use, if not the default - :param serializer: the serializer used to (de)serialize tasks, schedules and jobs - for storage - :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler - or worker can keep a lock on a schedule or task - :param retry_settings: Tenacity settings for retrying operations in case of a - database connecitivty problem - :param start_from_scratch: erase all existing data during startup (useful for test - suites) - """ - - engine: AsyncEngine - - @classmethod - def from_url(cls, url: str | URL, **options) -> AsyncSQLAlchemyDataStore: - """ - Create a new asynchronous SQLAlchemy data store. 
- - :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine` - (must use an async dialect like ``asyncpg`` or ``asyncmy``) - :param kwargs: keyword arguments to pass to the initializer of this class - :return: the newly created data store - - """ - engine = create_async_engine(url, future=True) - return cls(engine, **options) - - def _retry(self) -> tenacity.AsyncRetrying: - # OSError is raised by asyncpg if it can't connect - return tenacity.AsyncRetrying( - stop=self.retry_settings.stop, - wait=self.retry_settings.wait, - retry=tenacity.retry_if_exception_type((InterfaceError, OSError)), - after=self._after_attempt, - sleep=anyio.sleep, - reraise=True, - ) - - async def start(self, event_broker: AsyncEventBroker) -> None: - await super().start(event_broker) - - asynclib = sniffio.current_async_library() or "(unknown)" - if asynclib != "asyncio": - raise RuntimeError( - f"This data store requires asyncio; currently running: {asynclib}" - ) - - # Verify that the schema is in place - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - if self.start_from_scratch: - for table in self._metadata.sorted_tables: - await conn.execute(DropTable(table, if_exists=True)) - - await conn.run_sync(self._metadata.create_all) - query = select(self.t_metadata.c.schema_version) - result = await conn.execute(query) - version = result.scalar() - if version is None: - await conn.execute( - self.t_metadata.insert(values={"schema_version": 1}) - ) - elif version > 1: - raise RuntimeError( - f"Unexpected schema version ({version}); " - f"only version 1 is supported by this version of " - f"APScheduler" - ) - - async def _deserialize_schedules(self, result: Result) -> list[Schedule]: - schedules: list[Schedule] = [] - for row in result: - try: - schedules.append(Schedule.unmarshal(self.serializer, row._asdict())) - except SerializationError as exc: - await self._events.publish( - 
ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc) - ) - - return schedules - - async def _deserialize_jobs(self, result: Result) -> list[Job]: - jobs: list[Job] = [] - for row in result: - try: - jobs.append(Job.unmarshal(self.serializer, row._asdict())) - except SerializationError as exc: - await self._events.publish( - JobDeserializationFailed(job_id=row["id"], exception=exc) - ) - - return jobs - - async def add_task(self, task: Task) -> None: - insert = self.t_tasks.insert().values( - id=task.id, - func=callable_to_ref(task.func), - max_running_jobs=task.max_running_jobs, - misfire_grace_time=task.misfire_grace_time, - ) - try: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - await conn.execute(insert) - except IntegrityError: - update = ( - self.t_tasks.update() - .values( - func=callable_to_ref(task.func), - max_running_jobs=task.max_running_jobs, - misfire_grace_time=task.misfire_grace_time, - ) - .where(self.t_tasks.c.id == task.id) - ) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - await conn.execute(update) - - await self._events.publish(TaskUpdated(task_id=task.id)) - else: - await self._events.publish(TaskAdded(task_id=task.id)) - - async def remove_task(self, task_id: str) -> None: - delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - result = await conn.execute(delete) - if result.rowcount == 0: - raise TaskLookupError(task_id) - else: - await self._events.publish(TaskRemoved(task_id=task_id)) - - async def get_task(self, task_id: str) -> Task: - query = select( - [ - self.t_tasks.c.id, - self.t_tasks.c.func, - self.t_tasks.c.max_running_jobs, - self.t_tasks.c.state, - self.t_tasks.c.misfire_grace_time, - ] - ).where(self.t_tasks.c.id == task_id) - async for attempt in self._retry(): - with attempt: - async with 
self.engine.begin() as conn: - result = await conn.execute(query) - row = result.first() - - if row: - return Task.unmarshal(self.serializer, row._asdict()) - else: - raise TaskLookupError(task_id) - - async def get_tasks(self) -> list[Task]: - query = select( - [ - self.t_tasks.c.id, - self.t_tasks.c.func, - self.t_tasks.c.max_running_jobs, - self.t_tasks.c.state, - self.t_tasks.c.misfire_grace_time, - ] - ).order_by(self.t_tasks.c.id) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - result = await conn.execute(query) - tasks = [ - Task.unmarshal(self.serializer, row._asdict()) for row in result - ] - return tasks - - async def add_schedule( - self, schedule: Schedule, conflict_policy: ConflictPolicy - ) -> None: - event: DataStoreEvent - values = schedule.marshal(self.serializer) - insert = self.t_schedules.insert().values(**values) - try: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - await conn.execute(insert) - except IntegrityError: - if conflict_policy is ConflictPolicy.exception: - raise ConflictingIdError(schedule.id) from None - elif conflict_policy is ConflictPolicy.replace: - del values["id"] - update = ( - self.t_schedules.update() - .where(self.t_schedules.c.id == schedule.id) - .values(**values) - ) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - await conn.execute(update) - - event = ScheduleUpdated( - schedule_id=schedule.id, next_fire_time=schedule.next_fire_time - ) - await self._events.publish(event) - else: - event = ScheduleAdded( - schedule_id=schedule.id, next_fire_time=schedule.next_fire_time - ) - await self._events.publish(event) - - async def remove_schedules(self, ids: Iterable[str]) -> None: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - delete = self.t_schedules.delete().where( - self.t_schedules.c.id.in_(ids) - ) - if 
self._supports_update_returning: - delete = delete.returning(self.t_schedules.c.id) - removed_ids: Iterable[str] = [ - row[0] for row in await conn.execute(delete) - ] - else: - # TODO: actually check which rows were deleted? - await conn.execute(delete) - removed_ids = ids - - for schedule_id in removed_ids: - await self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) - - async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: - query = self.t_schedules.select().order_by(self.t_schedules.c.id) - if ids: - query = query.where(self.t_schedules.c.id.in_(ids)) - - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - result = await conn.execute(query) - return await self._deserialize_schedules(result) - - async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - schedules_cte = ( - select(self.t_schedules.c.id) - .where( - and_( - self.t_schedules.c.next_fire_time.isnot(None), - self.t_schedules.c.next_fire_time <= now, - or_( - self.t_schedules.c.acquired_until.is_(None), - self.t_schedules.c.acquired_until < now, - ), - ) - ) - .order_by(self.t_schedules.c.next_fire_time) - .limit(limit) - .with_for_update(skip_locked=True) - .cte() - ) - subselect = select([schedules_cte.c.id]) - update = ( - self.t_schedules.update() - .where(self.t_schedules.c.id.in_(subselect)) - .values(acquired_by=scheduler_id, acquired_until=acquired_until) - ) - if self._supports_update_returning: - update = update.returning(*self.t_schedules.columns) - result = await conn.execute(update) - else: - await conn.execute(update) - query = self.t_schedules.select().where( - and_(self.t_schedules.c.acquired_by == scheduler_id) - ) - result = await conn.execute(query) - - schedules = await 
self._deserialize_schedules(result) - - return schedules - - async def release_schedules( - self, scheduler_id: str, schedules: list[Schedule] - ) -> None: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - update_events: list[ScheduleUpdated] = [] - finished_schedule_ids: list[str] = [] - update_args: list[dict[str, Any]] = [] - for schedule in schedules: - if schedule.next_fire_time is not None: - try: - serialized_trigger = self.serializer.serialize( - schedule.trigger - ) - except SerializationError: - self._logger.exception( - "Error serializing trigger for schedule %r – " - "removing from data store", - schedule.id, - ) - finished_schedule_ids.append(schedule.id) - continue - - update_args.append( - { - "p_id": schedule.id, - "p_trigger": serialized_trigger, - "p_next_fire_time": schedule.next_fire_time, - } - ) - else: - finished_schedule_ids.append(schedule.id) - - # Update schedules that have a next fire time - if update_args: - p_id: BindParameter = bindparam("p_id") - p_trigger: BindParameter = bindparam("p_trigger") - p_next_fire_time: BindParameter = bindparam("p_next_fire_time") - update = ( - self.t_schedules.update() - .where( - and_( - self.t_schedules.c.id == p_id, - self.t_schedules.c.acquired_by == scheduler_id, - ) - ) - .values( - trigger=p_trigger, - next_fire_time=p_next_fire_time, - acquired_by=None, - acquired_until=None, - ) - ) - next_fire_times = { - arg["p_id"]: arg["p_next_fire_time"] for arg in update_args - } - # TODO: actually check which rows were updated? 
- await conn.execute(update, update_args) - updated_ids = list(next_fire_times) - - for schedule_id in updated_ids: - event = ScheduleUpdated( - schedule_id=schedule_id, - next_fire_time=next_fire_times[schedule_id], - ) - update_events.append(event) - - # Remove schedules that have no next fire time or failed to - # serialize - if finished_schedule_ids: - delete = self.t_schedules.delete().where( - self.t_schedules.c.id.in_(finished_schedule_ids) - ) - await conn.execute(delete) - - for event in update_events: - await self._events.publish(event) - - for schedule_id in finished_schedule_ids: - await self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) - - async def get_next_schedule_run_time(self) -> datetime | None: - statenent = ( - select(self.t_schedules.c.next_fire_time) - .where(self.t_schedules.c.next_fire_time.isnot(None)) - .order_by(self.t_schedules.c.next_fire_time) - .limit(1) - ) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - result = await conn.execute(statenent) - return result.scalar() - - async def add_job(self, job: Job) -> None: - marshalled = job.marshal(self.serializer) - insert = self.t_jobs.insert().values(**marshalled) - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - await conn.execute(insert) - - event = JobAdded( - job_id=job.id, - task_id=job.task_id, - schedule_id=job.schedule_id, - tags=job.tags, - ) - await self._events.publish(event) - - async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: - query = self.t_jobs.select().order_by(self.t_jobs.c.id) - if ids: - job_ids = [job_id for job_id in ids] - query = query.where(self.t_jobs.c.id.in_(job_ids)) - - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - result = await conn.execute(query) - return await self._deserialize_jobs(result) - - async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> 
list[Job]: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - query = ( - self.t_jobs.select() - .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id) - .where( - or_( - self.t_jobs.c.acquired_until.is_(None), - self.t_jobs.c.acquired_until < now, - ) - ) - .order_by(self.t_jobs.c.created_at) - .with_for_update(skip_locked=True) - .limit(limit) - ) - - result = await conn.execute(query) - if not result: - return [] - - # Mark the jobs as acquired by this worker - jobs = await self._deserialize_jobs(result) - task_ids: set[str] = {job.task_id for job in jobs} - - # Retrieve the limits - query = select( - [ - self.t_tasks.c.id, - self.t_tasks.c.max_running_jobs - - self.t_tasks.c.running_jobs, - ] - ).where( - self.t_tasks.c.max_running_jobs.isnot(None), - self.t_tasks.c.id.in_(task_ids), - ) - result = await conn.execute(query) - job_slots_left: dict[str, int] = dict(result.fetchall()) - - # Filter out jobs that don't have free slots - acquired_jobs: list[Job] = [] - increments: dict[str, int] = defaultdict(lambda: 0) - for job in jobs: - # Don't acquire the job if there are no free slots left - slots_left = job_slots_left.get(job.task_id) - if slots_left == 0: - continue - elif slots_left is not None: - job_slots_left[job.task_id] -= 1 - - acquired_jobs.append(job) - increments[job.task_id] += 1 - - if acquired_jobs: - # Mark the acquired jobs as acquired by this worker - acquired_job_ids = [job.id for job in acquired_jobs] - update = ( - self.t_jobs.update() - .values( - acquired_by=worker_id, acquired_until=acquired_until - ) - .where(self.t_jobs.c.id.in_(acquired_job_ids)) - ) - await conn.execute(update) - - # Increment the running job counters on each task - p_id: BindParameter = bindparam("p_id") - p_increment: BindParameter = bindparam("p_increment") - params = [ - {"p_id": task_id, "p_increment": 
increment} - for task_id, increment in increments.items() - ] - update = ( - self.t_tasks.update() - .values( - running_jobs=self.t_tasks.c.running_jobs + p_increment - ) - .where(self.t_tasks.c.id == p_id) - ) - await conn.execute(update, params) - - # Publish the appropriate events - for job in acquired_jobs: - await self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) - - return acquired_jobs - - async def release_job( - self, worker_id: str, task_id: str, result: JobResult - ) -> None: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - # Record the job result - if result.expires_at > result.finished_at: - marshalled = result.marshal(self.serializer) - insert = self.t_job_results.insert().values(**marshalled) - await conn.execute(insert) - - # Decrement the number of running jobs for this task - update = ( - self.t_tasks.update() - .values(running_jobs=self.t_tasks.c.running_jobs - 1) - .where(self.t_tasks.c.id == task_id) - ) - await conn.execute(update) - - # Delete the job - delete = self.t_jobs.delete().where( - self.t_jobs.c.id == result.job_id - ) - await conn.execute(delete) - - async def get_job_result(self, job_id: UUID) -> JobResult | None: - async for attempt in self._retry(): - with attempt: - async with self.engine.begin() as conn: - # Retrieve the result - query = self.t_job_results.select().where( - self.t_job_results.c.job_id == job_id - ) - row = (await conn.execute(query)).first() - - # Delete the result - delete = self.t_job_results.delete().where( - self.t_job_results.c.job_id == job_id - ) - await conn.execute(delete) - - return ( - JobResult.unmarshal(self.serializer, row._asdict()) - if row - else None - ) diff --git a/src/apscheduler/datastores/base.py b/src/apscheduler/datastores/base.py index c05d28c..5c7ef7d 100644 --- a/src/apscheduler/datastores/base.py +++ b/src/apscheduler/datastores/base.py @@ -1,37 +1,47 @@ from __future__ import annotations -from apscheduler.abc 
import ( - AsyncDataStore, - AsyncEventBroker, - DataStore, - EventBroker, - EventSource, -) +from contextlib import AsyncExitStack +from logging import Logger, getLogger +import attrs +from .._retry import RetryMixin +from ..abc import DataStore, EventBroker, Serializer +from ..serializers.pickle import PickleSerializer + + +@attrs.define(kw_only=True) class BaseDataStore(DataStore): - _events: EventBroker + """ + Base class for data stores. - def start(self, event_broker: EventBroker) -> None: - self._events = event_broker + :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler + or worker can keep a lock on a schedule or task + """ - def stop(self, *, force: bool = False) -> None: - del self._events + lock_expiration_delay: float = 30 + _event_broker: EventBroker = attrs.field(init=False) + _logger: Logger = attrs.field(init=False) - @property - def events(self) -> EventSource: - return self._events + async def start( + self, exit_stack: AsyncExitStack, event_broker: EventBroker + ) -> None: + self._event_broker = event_broker + def __attrs_post_init__(self): + self._logger = getLogger(self.__class__.__name__) -class BaseAsyncDataStore(AsyncDataStore): - _events: AsyncEventBroker - async def start(self, event_broker: AsyncEventBroker) -> None: - self._events = event_broker +@attrs.define(kw_only=True) +class BaseExternalDataStore(BaseDataStore, RetryMixin): + """ + Base class for data stores using an external service such as a database. 
- async def stop(self, *, force: bool = False) -> None: - del self._events + :param serializer: the serializer used to (de)serialize tasks, schedules and jobs + for storage + :param start_from_scratch: erase all existing data during startup (useful for test + suites) + """ - @property - def events(self) -> EventSource: - return self._events + serializer: Serializer = attrs.field(factory=PickleSerializer) + start_from_scratch: bool = attrs.field(default=False) diff --git a/src/apscheduler/datastores/memory.py b/src/apscheduler/datastores/memory.py index a9ff3cb..fd7e90e 100644 --- a/src/apscheduler/datastores/memory.py +++ b/src/apscheduler/datastores/memory.py @@ -85,13 +85,9 @@ class MemoryDataStore(BaseDataStore): """ Stores scheduler data in memory, without serializing it. - Can be shared between multiple schedulers and workers within the same event loop. - - :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler - or worker can keep a lock on a schedule or task + Can be shared between multiple schedulers within the same event loop. 
""" - lock_expiration_delay: float = 30 _tasks: dict[str, TaskState] = attrs.Factory(dict) _schedules: list[ScheduleState] = attrs.Factory(list) _schedules_by_id: dict[str, ScheduleState] = attrs.Factory(dict) @@ -115,41 +111,43 @@ class MemoryDataStore(BaseDataStore): right_index = bisect_left(self._jobs, state) return self._jobs.index(state, left_index, right_index + 1) - def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: + async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: return [ state.schedule for state in self._schedules if ids is None or state.schedule.id in ids ] - def add_task(self, task: Task) -> None: + async def add_task(self, task: Task) -> None: task_exists = task.id in self._tasks self._tasks[task.id] = TaskState(task) if task_exists: - self._events.publish(TaskUpdated(task_id=task.id)) + await self._event_broker.publish(TaskUpdated(task_id=task.id)) else: - self._events.publish(TaskAdded(task_id=task.id)) + await self._event_broker.publish(TaskAdded(task_id=task.id)) - def remove_task(self, task_id: str) -> None: + async def remove_task(self, task_id: str) -> None: try: del self._tasks[task_id] except KeyError: raise TaskLookupError(task_id) from None - self._events.publish(TaskRemoved(task_id=task_id)) + await self._event_broker.publish(TaskRemoved(task_id=task_id)) - def get_task(self, task_id: str) -> Task: + async def get_task(self, task_id: str) -> Task: try: return self._tasks[task_id].task except KeyError: raise TaskLookupError(task_id) from None - def get_tasks(self) -> list[Task]: + async def get_tasks(self) -> list[Task]: return sorted( (state.task for state in self._tasks.values()), key=lambda task: task.id ) - def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + async def add_schedule( + self, schedule: Schedule, conflict_policy: ConflictPolicy + ) -> None: old_state = self._schedules_by_id.get(schedule.id) if old_state is not None: if conflict_policy is 
ConflictPolicy.do_nothing: @@ -175,17 +173,17 @@ class MemoryDataStore(BaseDataStore): schedule_id=schedule.id, next_fire_time=schedule.next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) - def remove_schedules(self, ids: Iterable[str]) -> None: + async def remove_schedules(self, ids: Iterable[str]) -> None: for schedule_id in ids: state = self._schedules_by_id.pop(schedule_id, None) if state: self._schedules.remove(state) event = ScheduleRemoved(schedule_id=state.schedule.id) - self._events.publish(event) + await self._event_broker.publish(event) - def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: + async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: now = datetime.now(timezone.utc) schedules: list[Schedule] = [] for state in self._schedules: @@ -206,7 +204,9 @@ class MemoryDataStore(BaseDataStore): return schedules - def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None: + async def release_schedules( + self, scheduler_id: str, schedules: list[Schedule] + ) -> None: # Send update events for schedules that have a next time finished_schedule_ids: list[str] = [] for s in schedules: @@ -224,17 +224,17 @@ class MemoryDataStore(BaseDataStore): event = ScheduleUpdated( schedule_id=s.id, next_fire_time=s.next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) else: finished_schedule_ids.append(s.id) # Remove schedules that didn't get a new next fire time - self.remove_schedules(finished_schedule_ids) + await self.remove_schedules(finished_schedule_ids) - def get_next_schedule_run_time(self) -> datetime | None: + async def get_next_schedule_run_time(self) -> datetime | None: return self._schedules[0].next_fire_time if self._schedules else None - def add_job(self, job: Job) -> None: + async def add_job(self, job: Job) -> None: state = JobState(job) self._jobs.append(state) self._jobs_by_id[job.id] = state @@ -246,15 
+246,15 @@ class MemoryDataStore(BaseDataStore): schedule_id=job.schedule_id, tags=job.tags, ) - self._events.publish(event) + await self._event_broker.publish(event) - def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: + async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: if ids is not None: ids = frozenset(ids) return [state.job for state in self._jobs if ids is None or state.job.id in ids] - def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: + async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: now = datetime.now(timezone.utc) jobs: list[Job] = [] for _index, job_state in enumerate(self._jobs): @@ -290,11 +290,15 @@ class MemoryDataStore(BaseDataStore): # Publish the appropriate events for job in jobs: - self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) + await self._event_broker.publish( + JobAcquired(job_id=job.id, worker_id=worker_id) + ) return jobs - def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None: + async def release_job( + self, worker_id: str, task_id: str, result: JobResult + ) -> None: # Record the job result if result.expires_at > result.finished_at: self._job_results[result.job_id] = result @@ -310,5 +314,5 @@ class MemoryDataStore(BaseDataStore): index = self._find_job_index(job_state) del self._jobs[index] - def get_job_result(self, job_id: UUID) -> JobResult | None: + async def get_job_result(self, job_id: UUID) -> JobResult | None: return self._job_results.pop(job_id, None) diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py index 13299bb..4899f54 100644 --- a/src/apscheduler/datastores/mongodb.py +++ b/src/apscheduler/datastores/mongodb.py @@ -2,14 +2,13 @@ from __future__ import annotations import operator from collections import defaultdict +from contextlib import AsyncExitStack from datetime import datetime, timedelta, timezone -from logging import Logger, 
getLogger from typing import Any, Callable, ClassVar, Iterable from uuid import UUID import attrs import pymongo -import tenacity from attrs.validators import instance_of from bson import CodecOptions, UuidRepresentation from bson.codec_options import TypeEncoder, TypeRegistry @@ -35,10 +34,9 @@ from .._exceptions import ( SerializationError, TaskLookupError, ) -from .._structures import Job, JobResult, RetrySettings, Schedule, Task -from ..abc import EventBroker, Serializer -from ..serializers.pickle import PickleSerializer -from .base import BaseDataStore +from .._structures import Job, JobResult, Schedule, Task +from ..abc import EventBroker +from .base import BaseExternalDataStore class CustomEncoder(TypeEncoder): @@ -55,7 +53,7 @@ class CustomEncoder(TypeEncoder): @attrs.define(eq=False) -class MongoDBDataStore(BaseDataStore): +class MongoDBDataStore(BaseExternalDataStore): """ Uses a MongoDB server to store data. @@ -66,23 +64,11 @@ class MongoDBDataStore(BaseDataStore): raises :exc:`pymongo.errors.ConnectionFailure`. 
:param client: a PyMongo client - :param serializer: the serializer used to (de)serialize tasks, schedules and jobs - for storage :param database: name of the database to use - :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler - or worker can keep a lock on a schedule or task - :param retry_settings: Tenacity settings for retrying operations in case of a - database connecitivty problem - :param start_from_scratch: erase all existing data during startup (useful for test - suites) """ client: MongoClient = attrs.field(validator=instance_of(MongoClient)) - serializer: Serializer = attrs.field(factory=PickleSerializer, kw_only=True) database: str = attrs.field(default="apscheduler", kw_only=True) - lock_expiration_delay: float = attrs.field(default=30, kw_only=True) - retry_settings: RetrySettings = attrs.field(default=RetrySettings(), kw_only=True) - start_from_scratch: bool = attrs.field(default=False, kw_only=True) _task_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Task)] _schedule_attrs: ClassVar[list[str]] = [ @@ -90,8 +76,8 @@ class MongoDBDataStore(BaseDataStore): ] _job_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Job)] - _logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__)) _local_tasks: dict[str, Task] = attrs.field(init=False, factory=dict) + _temporary_failure_exceptions = (ConnectionFailure,) def __attrs_post_init__(self) -> None: type_registry = TypeRegistry( @@ -118,25 +104,10 @@ class MongoDBDataStore(BaseDataStore): client = MongoClient(uri) return cls(client, **options) - def _retry(self) -> tenacity.Retrying: - return tenacity.Retrying( - stop=self.retry_settings.stop, - wait=self.retry_settings.wait, - retry=tenacity.retry_if_exception_type(ConnectionFailure), - after=self._after_attempt, - reraise=True, - ) - - def _after_attempt(self, retry_state: tenacity.RetryCallState) -> None: - self._logger.warning( - "Temporary data store error (attempt %d): 
%s", - retry_state.attempt_number, - retry_state.outcome.exception(), - ) - - def start(self, event_broker: EventBroker) -> None: - super().start(event_broker) - + async def start( + self, exit_stack: AsyncExitStack, event_broker: EventBroker + ) -> None: + await super().start(exit_stack, event_broker) server_info = self.client.server_info() if server_info["versionArray"] < [4, 0]: raise RuntimeError( @@ -144,7 +115,7 @@ class MongoDBDataStore(BaseDataStore): f"{server_info['version']}" ) - for attempt in self._retry(): + async for attempt in self._retry(): with attempt, self.client.start_session() as session: if self.start_from_scratch: self._tasks.delete_many({}, session=session) @@ -159,7 +130,7 @@ class MongoDBDataStore(BaseDataStore): self._jobs_results.create_index("finished_at", session=session) self._jobs_results.create_index("expires_at", session=session) - def add_task(self, task: Task) -> None: + async def add_task(self, task: Task) -> None: for attempt in self._retry(): with attempt: previous = self._tasks.find_one_and_update( @@ -173,20 +144,20 @@ class MongoDBDataStore(BaseDataStore): self._local_tasks[task.id] = task if previous: - self._events.publish(TaskUpdated(task_id=task.id)) + await self._event_broker.publish(TaskUpdated(task_id=task.id)) else: - self._events.publish(TaskAdded(task_id=task.id)) + await self._event_broker.publish(TaskAdded(task_id=task.id)) - def remove_task(self, task_id: str) -> None: + async def remove_task(self, task_id: str) -> None: for attempt in self._retry(): with attempt: if not self._tasks.find_one_and_delete({"_id": task_id}): raise TaskLookupError(task_id) del self._local_tasks[task_id] - self._events.publish(TaskRemoved(task_id=task_id)) + await self._event_broker.publish(TaskRemoved(task_id=task_id)) - def get_task(self, task_id: str) -> Task: + async def get_task(self, task_id: str) -> Task: try: return self._local_tasks[task_id] except KeyError: @@ -205,7 +176,7 @@ class MongoDBDataStore(BaseDataStore): ) 
return task - def get_tasks(self) -> list[Task]: + async def get_tasks(self) -> list[Task]: for attempt in self._retry(): with attempt: tasks: list[Task] = [] @@ -217,7 +188,7 @@ class MongoDBDataStore(BaseDataStore): return tasks - def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: + async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: filters = {"_id": {"$in": list(ids)}} if ids is not None else {} for attempt in self._retry(): with attempt: @@ -237,7 +208,9 @@ class MongoDBDataStore(BaseDataStore): return schedules - def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + async def add_schedule( + self, schedule: Schedule, conflict_policy: ConflictPolicy + ) -> None: event: DataStoreEvent document = schedule.marshal(self.serializer) document["_id"] = document.pop("id") @@ -258,14 +231,14 @@ class MongoDBDataStore(BaseDataStore): event = ScheduleUpdated( schedule_id=schedule.id, next_fire_time=schedule.next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) else: event = ScheduleAdded( schedule_id=schedule.id, next_fire_time=schedule.next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) - def remove_schedules(self, ids: Iterable[str]) -> None: + async def remove_schedules(self, ids: Iterable[str]) -> None: filters = {"_id": {"$in": list(ids)}} if ids is not None else {} for attempt in self._retry(): with attempt, self.client.start_session() as session: @@ -277,9 +250,9 @@ class MongoDBDataStore(BaseDataStore): self._schedules.delete_many(filters, session=session) for schedule_id in ids: - self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id)) - def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: + async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: for attempt in self._retry(): 
with attempt, self.client.start_session() as session: schedules: list[Schedule] = [] @@ -318,7 +291,9 @@ class MongoDBDataStore(BaseDataStore): return schedules - def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None: + async def release_schedules( + self, scheduler_id: str, schedules: list[Schedule] + ) -> None: updated_schedules: list[tuple[str, datetime]] = [] finished_schedule_ids: list[str] = [] @@ -365,12 +340,12 @@ class MongoDBDataStore(BaseDataStore): event = ScheduleUpdated( schedule_id=schedule_id, next_fire_time=next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) for schedule_id in finished_schedule_ids: - self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id)) - def get_next_schedule_run_time(self) -> datetime | None: + async def get_next_schedule_run_time(self) -> datetime | None: for attempt in self._retry(): with attempt: document = self._schedules.find_one( @@ -384,7 +359,7 @@ class MongoDBDataStore(BaseDataStore): else: return None - def add_job(self, job: Job) -> None: + async def add_job(self, job: Job) -> None: document = job.marshal(self.serializer) document["_id"] = document.pop("id") for attempt in self._retry(): @@ -397,9 +372,9 @@ class MongoDBDataStore(BaseDataStore): schedule_id=job.schedule_id, tags=job.tags, ) - self._events.publish(event) + await self._event_broker.publish(event) - def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: + async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: filters = {"_id": {"$in": list(ids)}} if ids is not None else {} for attempt in self._retry(): with attempt: @@ -419,7 +394,7 @@ class MongoDBDataStore(BaseDataStore): return jobs - def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: + async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: for attempt in 
self._retry(): with attempt, self.client.start_session() as session: cursor = self._jobs.find( @@ -488,11 +463,15 @@ class MongoDBDataStore(BaseDataStore): # Publish the appropriate events for job in acquired_jobs: - self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) + await self._event_broker.publish( + JobAcquired(job_id=job.id, worker_id=worker_id) + ) return acquired_jobs - def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None: + async def release_job( + self, worker_id: str, task_id: str, result: JobResult + ) -> None: for attempt in self._retry(): with attempt, self.client.start_session() as session: # Record the job result @@ -509,7 +488,7 @@ class MongoDBDataStore(BaseDataStore): # Delete the job self._jobs.delete_one({"_id": result.job_id}, session=session) - def get_job_result(self, job_id: UUID) -> JobResult | None: + async def get_job_result(self, job_id: UUID) -> JobResult | None: for attempt in self._retry(): with attempt: document = self._jobs_results.find_one_and_delete({"_id": job_id}) diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py index ebb076f..00d06aa 100644 --- a/src/apscheduler/datastores/sqlalchemy.py +++ b/src/apscheduler/datastores/sqlalchemy.py @@ -1,16 +1,20 @@ from __future__ import annotations +import sys from collections import defaultdict +from collections.abc import AsyncGenerator, Sequence +from contextlib import AsyncExitStack, asynccontextmanager from datetime import datetime, timedelta, timezone -from logging import Logger, getLogger from typing import Any, Iterable from uuid import UUID +import anyio import attrs +import sniffio import tenacity +from anyio import to_thread from sqlalchemy import ( JSON, - TIMESTAMP, BigInteger, Column, Enum, @@ -26,14 +30,15 @@ from sqlalchemy import ( select, ) from sqlalchemy.engine import URL, Dialect, Result -from sqlalchemy.exc import CompileError, IntegrityError, OperationalError -from 
sqlalchemy.future import Engine, create_engine +from sqlalchemy.exc import CompileError, IntegrityError, InterfaceError +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine +from sqlalchemy.future import Connection, Engine from sqlalchemy.sql.ddl import DropTable from sqlalchemy.sql.elements import BindParameter, literal from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome from .._events import ( - Event, + DataStoreEvent, JobAcquired, JobAdded, JobDeserializationFailed, @@ -46,11 +51,15 @@ from .._events import ( TaskUpdated, ) from .._exceptions import ConflictingIdError, SerializationError, TaskLookupError -from .._structures import Job, JobResult, RetrySettings, Schedule, Task -from ..abc import EventBroker, Serializer +from .._structures import Job, JobResult, Schedule, Task +from ..abc import EventBroker from ..marshalling import callable_to_ref -from ..serializers.pickle import PickleSerializer -from .base import BaseDataStore +from .base import BaseExternalDataStore + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self class EmulatedUUID(TypeDecorator): @@ -86,17 +95,30 @@ class EmulatedInterval(TypeDecorator): return timedelta(seconds=value) if value is not None else None -@attrs.define(kw_only=True, eq=False) -class _BaseSQLAlchemyDataStore: +@attrs.define(eq=False) +class SQLAlchemyDataStore(BaseExternalDataStore): + """ + Uses a relational database to store data. + + When started, this data store creates the appropriate tables on the given database + if they're not already present. + + Operations are retried (in accordance to ``retry_settings``) when an operation + raises :exc:`sqlalchemy.OperationalError`. + + This store has been tested to work with PostgreSQL (asyncpg driver) and MySQL + (asyncmy driver). 
+ + :param engine: an asynchronous SQLAlchemy engine + :param schema: a database schema name to use, if not the default + """ + + engine: Engine | AsyncEngine schema: str | None = attrs.field(default=None) - serializer: Serializer = attrs.field(factory=PickleSerializer) - lock_expiration_delay: float = attrs.field(default=30) max_poll_time: float | None = attrs.field(default=1) max_idle_time: float = attrs.field(default=60) - retry_settings: RetrySettings = attrs.field(default=RetrySettings()) - start_from_scratch: bool = attrs.field(default=False) - _logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__)) + _is_async: bool = attrs.field(init=False) def __attrs_post_init__(self) -> None: # Generate the table definitions @@ -107,6 +129,7 @@ class _BaseSQLAlchemyDataStore: self.t_schedules = self._metadata.tables[prefix + "schedules"] self.t_jobs = self._metadata.tables[prefix + "jobs"] self.t_job_results = self._metadata.tables[prefix + "job_results"] + self._is_async = isinstance(self.engine, AsyncEngine) # Find out if the dialect supports UPDATE...RETURNING update = self.t_jobs.update().returning(self.t_jobs.c.id) @@ -117,18 +140,76 @@ class _BaseSQLAlchemyDataStore: else: self._supports_update_returning = True - def _after_attempt(self, retry_state: tenacity.RetryCallState) -> None: - self._logger.warning( - "Temporary data store error (attempt %d): %s", - retry_state.attempt_number, - retry_state.outcome.exception(), + @classmethod + def from_url(cls: type[Self], url: str | URL, **options) -> Self: + """ + Create a new asynchronous SQLAlchemy data store. 
+ + :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine` + (must use an async dialect like ``asyncpg`` or ``asyncmy``) + :param options: keyword arguments to pass to the initializer of this class + :return: the newly created data store + + """ + engine = create_async_engine(url, future=True) + return cls(engine, **options) + + def _retry(self) -> tenacity.AsyncRetrying: + def after_attempt(retry_state: tenacity.RetryCallState) -> None: + self._logger.warning( + "Temporary data store error (attempt %d): %s", + retry_state.attempt_number, + retry_state.outcome.exception(), + ) + + # OSError is raised by asyncpg if it can't connect + return tenacity.AsyncRetrying( + stop=self.retry_settings.stop, + wait=self.retry_settings.wait, + retry=tenacity.retry_if_exception_type((InterfaceError, OSError)), + after=after_attempt, + sleep=anyio.sleep, + reraise=True, ) + + @asynccontextmanager + async def _begin_transaction(self) -> AsyncGenerator[Connection | AsyncConnection]: + if isinstance(self.engine, AsyncEngine): + async with self.engine.begin() as conn: + yield conn + else: + cm = self.engine.begin() + conn = await to_thread.run_sync(cm.__enter__) + try: + yield conn + except BaseException as exc: + await to_thread.run_sync(cm.__exit__, type(exc), exc, exc.__traceback__) + raise + else: + await to_thread.run_sync(cm.__exit__, None, None, None) + + async def _create_metadata(self, conn: Connection | AsyncConnection) -> None: + if isinstance(conn, AsyncConnection): + await conn.run_sync(self._metadata.create_all) + else: + await to_thread.run_sync(self._metadata.create_all, conn) + + async def _execute( + self, + conn: Connection | AsyncConnection, + statement, + parameters: Sequence | None = None, + ): + if isinstance(conn, AsyncConnection): + return await conn.execute(statement, parameters) + else: + return await to_thread.run_sync(conn.execute, statement, parameters) + def get_table_definitions(self) -> MetaData: if self.engine.dialect.name ==
"postgresql": from sqlalchemy.dialects import postgresql - timestamp_type = TIMESTAMP(timezone=True) + timestamp_type = postgresql.TIMESTAMP(timezone=True) job_id_type = postgresql.UUID(as_uuid=True) interval_type = postgresql.INTERVAL(precision=6) tags_type = postgresql.ARRAY(Unicode) @@ -145,6 +226,7 @@ class _BaseSQLAlchemyDataStore: metadata, Column("id", Unicode(500), primary_key=True), Column("func", Unicode(500), nullable=False), + Column("executor", Unicode(500), nullable=False), Column("state", LargeBinary), Column("max_running_jobs", Integer), Column("misfire_grace_time", interval_type), @@ -197,186 +279,160 @@ class _BaseSQLAlchemyDataStore: ) return metadata - def _deserialize_schedules(self, result: Result) -> list[Schedule]: + async def start( + self, exit_stack: AsyncExitStack, event_broker: EventBroker + ) -> None: + await super().start(exit_stack, event_broker) + asynclib = sniffio.current_async_library() or "(unknown)" + if asynclib != "asyncio": + raise RuntimeError( + f"This data store requires asyncio; currently running: {asynclib}" + ) + + # Verify that the schema is in place + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + if self.start_from_scratch: + for table in self._metadata.sorted_tables: + await self._execute(conn, DropTable(table, if_exists=True)) + + await self._create_metadata(conn) + query = select(self.t_metadata.c.schema_version) + result = await self._execute(conn, query) + version = result.scalar() + if version is None: + await self._execute( + conn, self.t_metadata.insert(values={"schema_version": 1}) + ) + elif version > 1: + raise RuntimeError( + f"Unexpected schema version ({version}); " + f"only version 1 is supported by this version of " + f"APScheduler" + ) + + async def _deserialize_schedules(self, result: Result) -> list[Schedule]: schedules: list[Schedule] = [] for row in result: try: schedules.append(Schedule.unmarshal(self.serializer, row._asdict())) except 
SerializationError as exc: - self._events.publish( + await self._event_broker.publish( ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc) ) return schedules - def _deserialize_jobs(self, result: Result) -> list[Job]: + async def _deserialize_jobs(self, result: Result) -> list[Job]: jobs: list[Job] = [] for row in result: try: jobs.append(Job.unmarshal(self.serializer, row._asdict())) except SerializationError as exc: - self._events.publish( + await self._event_broker.publish( JobDeserializationFailed(job_id=row["id"], exception=exc) ) return jobs - -@attrs.define(eq=False) -class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore): - """ - Uses a relational database to store data. - - When started, this data store creates the appropriate tables on the given database - if they're not already present. - - Operations are retried (in accordance to ``retry_settings``) when an operation - raises :exc:`sqlalchemy.OperationalError`. - - This store has been tested to work with PostgreSQL (psycopg2 driver), MySQL - (pymysql driver) and SQLite. - - :param engine: a (synchronous) SQLAlchemy engine - :param schema: a database schema name to use, if not the default - :param serializer: the serializer used to (de)serialize tasks, schedules and jobs - for storage - :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler - or worker can keep a lock on a schedule or task - :param retry_settings: Tenacity settings for retrying operations in case of a - database connecitivty problem - :param start_from_scratch: erase all existing data during startup (useful for test - suites) - """ - - engine: Engine - - @classmethod - def from_url(cls, url: str | URL, **kwargs) -> SQLAlchemyDataStore: - """ - Create a new SQLAlchemy data store. 
- - :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine` - :param kwargs: keyword arguments to pass to the initializer of this class - :return: the newly created data store - - """ - engine = create_engine(url) - return cls(engine, **kwargs) - - def _retry(self) -> tenacity.Retrying: - return tenacity.Retrying( - stop=self.retry_settings.stop, - wait=self.retry_settings.wait, - retry=tenacity.retry_if_exception_type(OperationalError), - after=self._after_attempt, - reraise=True, - ) - - def start(self, event_broker: EventBroker) -> None: - super().start(event_broker) - - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - if self.start_from_scratch: - for table in self._metadata.sorted_tables: - conn.execute(DropTable(table, if_exists=True)) - - self._metadata.create_all(conn) - query = select(self.t_metadata.c.schema_version) - result = conn.execute(query) - version = result.scalar() - if version is None: - conn.execute(self.t_metadata.insert(values={"schema_version": 1})) - elif version > 1: - raise RuntimeError( - f"Unexpected schema version ({version}); " - f"only version 1 is supported by this version of APScheduler" - ) - - def add_task(self, task: Task) -> None: + async def add_task(self, task: Task) -> None: insert = self.t_tasks.insert().values( id=task.id, func=callable_to_ref(task.func), + executor=task.executor, max_running_jobs=task.max_running_jobs, misfire_grace_time=task.misfire_grace_time, ) try: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(insert) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, insert) except IntegrityError: update = ( self.t_tasks.update() .values( func=callable_to_ref(task.func), + executor=task.executor, max_running_jobs=task.max_running_jobs, misfire_grace_time=task.misfire_grace_time, ) .where(self.t_tasks.c.id == task.id) ) - for attempt in 
self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(update) - self._events.publish(TaskUpdated(task_id=task.id)) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, update) + + await self._event_broker.publish(TaskUpdated(task_id=task.id)) else: - self._events.publish(TaskAdded(task_id=task.id)) + await self._event_broker.publish(TaskAdded(task_id=task.id)) - def remove_task(self, task_id: str) -> None: + async def remove_task(self, task_id: str) -> None: delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(delete) - if result.rowcount == 0: - raise TaskLookupError(task_id) - else: - self._events.publish(TaskRemoved(task_id=task_id)) - - def get_task(self, task_id: str) -> Task: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, delete) + if result.rowcount == 0: + raise TaskLookupError(task_id) + else: + await self._event_broker.publish(TaskRemoved(task_id=task_id)) + + async def get_task(self, task_id: str) -> Task: query = select( [ self.t_tasks.c.id, self.t_tasks.c.func, + self.t_tasks.c.executor, self.t_tasks.c.max_running_jobs, self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time, ] ).where(self.t_tasks.c.id == task_id) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - row = result.first() + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, query) + row = result.first() if row: return Task.unmarshal(self.serializer, row._asdict()) else: raise TaskLookupError(task_id) - def get_tasks(self) -> list[Task]: + async def get_tasks(self) -> list[Task]: query = select( [ self.t_tasks.c.id, self.t_tasks.c.func, + 
self.t_tasks.c.executor, self.t_tasks.c.max_running_jobs, self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time, ] ).order_by(self.t_tasks.c.id) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - tasks = [ - Task.unmarshal(self.serializer, row._asdict()) for row in result - ] - return tasks - - def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: - event: Event + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, query) + tasks = [ + Task.unmarshal(self.serializer, row._asdict()) for row in result + ] + return tasks + + async def add_schedule( + self, schedule: Schedule, conflict_policy: ConflictPolicy + ) -> None: + event: DataStoreEvent values = schedule.marshal(self.serializer) insert = self.t_schedules.insert().values(**values) try: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(insert) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, insert) except IntegrityError: if conflict_policy is ConflictPolicy.exception: raise ConflictingIdError(schedule.id) from None @@ -387,185 +443,197 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore): .where(self.t_schedules.c.id == schedule.id) .values(**values) ) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(update) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + await self._execute(conn, update) event = ScheduleUpdated( schedule_id=schedule.id, next_fire_time=schedule.next_fire_time ) - self._events.publish(event) + await self._event_broker.publish(event) else: event = ScheduleAdded( schedule_id=schedule.id, next_fire_time=schedule.next_fire_time ) - self._events.publish(event) - - def remove_schedules(self, 
ids: Iterable[str]) -> None: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - delete = self.t_schedules.delete().where(self.t_schedules.c.id.in_(ids)) - if self._supports_update_returning: - delete = delete.returning(self.t_schedules.c.id) - removed_ids: Iterable[str] = [ - row[0] for row in conn.execute(delete) - ] - else: - # TODO: actually check which rows were deleted? - conn.execute(delete) - removed_ids = ids + await self._event_broker.publish(event) + + async def remove_schedules(self, ids: Iterable[str]) -> None: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + delete = self.t_schedules.delete().where( + self.t_schedules.c.id.in_(ids) + ) + if self._supports_update_returning: + delete = delete.returning(self.t_schedules.c.id) + removed_ids: Iterable[str] = [ + row[0] for row in await self._execute(conn, delete) + ] + else: + # TODO: actually check which rows were deleted? + await self._execute(conn, delete) + removed_ids = ids for schedule_id in removed_ids: - self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id)) - def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: + async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: query = self.t_schedules.select().order_by(self.t_schedules.c.id) if ids: query = query.where(self.t_schedules.c.id.in_(ids)) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - return self._deserialize_schedules(result) - - def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - schedules_cte = ( - select(self.t_schedules.c.id) - .where( - and_( - 
self.t_schedules.c.next_fire_time.isnot(None), - self.t_schedules.c.next_fire_time <= now, - or_( - self.t_schedules.c.acquired_until.is_(None), - self.t_schedules.c.acquired_until < now, - ), + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, query) + return await self._deserialize_schedules(result) + + async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + now = datetime.now(timezone.utc) + acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + schedules_cte = ( + select(self.t_schedules.c.id) + .where( + and_( + self.t_schedules.c.next_fire_time.isnot(None), + self.t_schedules.c.next_fire_time <= now, + or_( + self.t_schedules.c.acquired_until.is_(None), + self.t_schedules.c.acquired_until < now, + ), + ) ) + .order_by(self.t_schedules.c.next_fire_time) + .limit(limit) + .with_for_update(skip_locked=True) + .cte() ) - .order_by(self.t_schedules.c.next_fire_time) - .limit(limit) - .with_for_update(skip_locked=True) - .cte() - ) - subselect = select([schedules_cte.c.id]) - update = ( - self.t_schedules.update() - .where(self.t_schedules.c.id.in_(subselect)) - .values(acquired_by=scheduler_id, acquired_until=acquired_until) - ) - if self._supports_update_returning: - update = update.returning(*self.t_schedules.columns) - result = conn.execute(update) - else: - conn.execute(update) - query = self.t_schedules.select().where( - and_(self.t_schedules.c.acquired_by == scheduler_id) + subselect = select([schedules_cte.c.id]) + update = ( + self.t_schedules.update() + .where(self.t_schedules.c.id.in_(subselect)) + .values(acquired_by=scheduler_id, acquired_until=acquired_until) ) - result = conn.execute(query) + if self._supports_update_returning: + update = update.returning(*self.t_schedules.columns) + result = await self._execute(conn, 
update) + else: + await self._execute(conn, update) + query = self.t_schedules.select().where( + and_(self.t_schedules.c.acquired_by == scheduler_id) + ) + result = await self._execute(conn, query) - schedules = self._deserialize_schedules(result) + schedules = await self._deserialize_schedules(result) return schedules - def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - update_events: list[ScheduleUpdated] = [] - finished_schedule_ids: list[str] = [] - update_args: list[dict[str, Any]] = [] - for schedule in schedules: - if schedule.next_fire_time is not None: - try: - serialized_trigger = self.serializer.serialize( - schedule.trigger - ) - except SerializationError: - self._logger.exception( - "Error serializing trigger for schedule %r – " - "removing from data store", - schedule.id, + async def release_schedules( + self, scheduler_id: str, schedules: list[Schedule] + ) -> None: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + update_events: list[ScheduleUpdated] = [] + finished_schedule_ids: list[str] = [] + update_args: list[dict[str, Any]] = [] + for schedule in schedules: + if schedule.next_fire_time is not None: + try: + serialized_trigger = self.serializer.serialize( + schedule.trigger + ) + except SerializationError: + self._logger.exception( + "Error serializing trigger for schedule %r – " + "removing from data store", + schedule.id, + ) + finished_schedule_ids.append(schedule.id) + continue + + update_args.append( + { + "p_id": schedule.id, + "p_trigger": serialized_trigger, + "p_next_fire_time": schedule.next_fire_time, + } ) + else: finished_schedule_ids.append(schedule.id) - continue - - update_args.append( - { - "p_id": schedule.id, - "p_trigger": serialized_trigger, - "p_next_fire_time": schedule.next_fire_time, - } - ) - else: - finished_schedule_ids.append(schedule.id) - # Update 
schedules that have a next fire time - if update_args: - p_id: BindParameter = bindparam("p_id") - p_trigger: BindParameter = bindparam("p_trigger") - p_next_fire_time: BindParameter = bindparam("p_next_fire_time") - update = ( - self.t_schedules.update() - .where( - and_( - self.t_schedules.c.id == p_id, - self.t_schedules.c.acquired_by == scheduler_id, + # Update schedules that have a next fire time + if update_args: + p_id: BindParameter = bindparam("p_id") + p_trigger: BindParameter = bindparam("p_trigger") + p_next_fire_time: BindParameter = bindparam("p_next_fire_time") + update = ( + self.t_schedules.update() + .where( + and_( + self.t_schedules.c.id == p_id, + self.t_schedules.c.acquired_by == scheduler_id, + ) + ) + .values( + trigger=p_trigger, + next_fire_time=p_next_fire_time, + acquired_by=None, + acquired_until=None, ) ) - .values( - trigger=p_trigger, - next_fire_time=p_next_fire_time, - acquired_by=None, - acquired_until=None, - ) - ) - next_fire_times = { - arg["p_id"]: arg["p_next_fire_time"] for arg in update_args - } - # TODO: actually check which rows were updated? - conn.execute(update, update_args) - updated_ids = list(next_fire_times) - - for schedule_id in updated_ids: - event = ScheduleUpdated( - schedule_id=schedule_id, - next_fire_time=next_fire_times[schedule_id], - ) - update_events.append(event) + next_fire_times = { + arg["p_id"]: arg["p_next_fire_time"] for arg in update_args + } + # TODO: actually check which rows were updated? 
+ await self._execute(conn, update, update_args) + updated_ids = list(next_fire_times) + + for schedule_id in updated_ids: + event = ScheduleUpdated( + schedule_id=schedule_id, + next_fire_time=next_fire_times[schedule_id], + ) + update_events.append(event) - # Remove schedules that have no next fire time or failed to serialize - if finished_schedule_ids: - delete = self.t_schedules.delete().where( - self.t_schedules.c.id.in_(finished_schedule_ids) - ) - conn.execute(delete) + # Remove schedules that have no next fire time or failed to + # serialize + if finished_schedule_ids: + delete = self.t_schedules.delete().where( + self.t_schedules.c.id.in_(finished_schedule_ids) + ) + await self._execute(conn, delete) for event in update_events: - self._events.publish(event) + await self._event_broker.publish(event) for schedule_id in finished_schedule_ids: - self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id)) - def get_next_schedule_run_time(self) -> datetime | None: - query = ( + async def get_next_schedule_run_time(self) -> datetime | None: + statement = ( select(self.t_schedules.c.next_fire_time) .where(self.t_schedules.c.next_fire_time.isnot(None)) .order_by(self.t_schedules.c.next_fire_time) .limit(1) ) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - return result.scalar() + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, statement) + return result.scalar() - def add_job(self, job: Job) -> None: + async def add_job(self, job: Job) -> None: marshalled = job.marshal(self.serializer) insert = self.t_jobs.insert().values(**marshalled) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - conn.execute(insert) + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: +
await self._execute(conn, insert) event = JobAdded( job_id=job.id, @@ -573,139 +641,156 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore): schedule_id=job.schedule_id, tags=job.tags, ) - self._events.publish(event) + await self._event_broker.publish(event) - def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: + async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]: query = self.t_jobs.select().order_by(self.t_jobs.c.id) if ids: job_ids = [job_id for job_id in ids] query = query.where(self.t_jobs.c.id.in_(job_ids)) - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - result = conn.execute(query) - return self._deserialize_jobs(result) - - def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - query = ( - self.t_jobs.select() - .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id) - .where( - or_( - self.t_jobs.c.acquired_until.is_(None), - self.t_jobs.c.acquired_until < now, + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + result = await self._execute(conn, query) + return await self._deserialize_jobs(result) + + async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + now = datetime.now(timezone.utc) + acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + query = ( + self.t_jobs.select() + .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id) + .where( + or_( + self.t_jobs.c.acquired_until.is_(None), + self.t_jobs.c.acquired_until < now, + ) ) + .order_by(self.t_jobs.c.created_at) + .with_for_update(skip_locked=True) + .limit(limit) ) - 
.order_by(self.t_jobs.c.created_at) - .with_for_update(skip_locked=True) - .limit(limit) - ) - - result = conn.execute(query) - if not result: - return [] - # Mark the jobs as acquired by this worker - jobs = self._deserialize_jobs(result) - task_ids: set[str] = {job.task_id for job in jobs} - - # Retrieve the limits - query = select( - [ - self.t_tasks.c.id, - self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs, - ] - ).where( - self.t_tasks.c.max_running_jobs.isnot(None), - self.t_tasks.c.id.in_(task_ids), - ) - result = conn.execute(query) - job_slots_left = dict(result.fetchall()) - - # Filter out jobs that don't have free slots - acquired_jobs: list[Job] = [] - increments: dict[str, int] = defaultdict(lambda: 0) - for job in jobs: - # Don't acquire the job if there are no free slots left - slots_left = job_slots_left.get(job.task_id) - if slots_left == 0: - continue - elif slots_left is not None: - job_slots_left[job.task_id] -= 1 - - acquired_jobs.append(job) - increments[job.task_id] += 1 - - if acquired_jobs: - # Mark the acquired jobs as acquired by this worker - acquired_job_ids = [job.id for job in acquired_jobs] - update = ( - self.t_jobs.update() - .values(acquired_by=worker_id, acquired_until=acquired_until) - .where(self.t_jobs.c.id.in_(acquired_job_ids)) + result = await self._execute(conn, query) + if not result: + return [] + + # Mark the jobs as acquired by this worker + jobs = await self._deserialize_jobs(result) + task_ids: set[str] = {job.task_id for job in jobs} + + # Retrieve the limits + query = select( + [ + self.t_tasks.c.id, + self.t_tasks.c.max_running_jobs + - self.t_tasks.c.running_jobs, + ] + ).where( + self.t_tasks.c.max_running_jobs.isnot(None), + self.t_tasks.c.id.in_(task_ids), ) - conn.execute(update) - - # Increment the running job counters on each task - p_id: BindParameter = bindparam("p_id") - p_increment: BindParameter = bindparam("p_increment") - params = [ - {"p_id": task_id, "p_increment": increment} - for 
task_id, increment in increments.items() - ] - update = ( - self.t_tasks.update() - .values(running_jobs=self.t_tasks.c.running_jobs + p_increment) - .where(self.t_tasks.c.id == p_id) - ) - conn.execute(update, params) + result = await self._execute(conn, query) + job_slots_left: dict[str, int] = dict(result.fetchall()) + + # Filter out jobs that don't have free slots + acquired_jobs: list[Job] = [] + increments: dict[str, int] = defaultdict(lambda: 0) + for job in jobs: + # Don't acquire the job if there are no free slots left + slots_left = job_slots_left.get(job.task_id) + if slots_left == 0: + continue + elif slots_left is not None: + job_slots_left[job.task_id] -= 1 + + acquired_jobs.append(job) + increments[job.task_id] += 1 + + if acquired_jobs: + # Mark the acquired jobs as acquired by this worker + acquired_job_ids = [job.id for job in acquired_jobs] + update = ( + self.t_jobs.update() + .values( + acquired_by=worker_id, acquired_until=acquired_until + ) + .where(self.t_jobs.c.id.in_(acquired_job_ids)) + ) + await self._execute(conn, update) + + # Increment the running job counters on each task + p_id: BindParameter = bindparam("p_id") + p_increment: BindParameter = bindparam("p_increment") + params = [ + {"p_id": task_id, "p_increment": increment} + for task_id, increment in increments.items() + ] + update = ( + self.t_tasks.update() + .values( + running_jobs=self.t_tasks.c.running_jobs + p_increment + ) + .where(self.t_tasks.c.id == p_id) + ) + await self._execute(conn, update, params) # Publish the appropriate events for job in acquired_jobs: - self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) + await self._event_broker.publish( + JobAcquired(job_id=job.id, worker_id=worker_id) + ) return acquired_jobs - def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - # Insert the job result - if result.expires_at > result.finished_at: - 
marshalled = result.marshal(self.serializer) - insert = self.t_job_results.insert().values(**marshalled) - conn.execute(insert) + async def release_job( + self, worker_id: str, task_id: str, result: JobResult + ) -> None: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + # Record the job result + if result.expires_at > result.finished_at: + marshalled = result.marshal(self.serializer) + insert = self.t_job_results.insert().values(**marshalled) + await self._execute(conn, insert) + + # Decrement the number of running jobs for this task + update = ( + self.t_tasks.update() + .values(running_jobs=self.t_tasks.c.running_jobs - 1) + .where(self.t_tasks.c.id == task_id) + ) + await self._execute(conn, update) - # Decrement the running jobs counter - update = ( - self.t_tasks.update() - .values(running_jobs=self.t_tasks.c.running_jobs - 1) - .where(self.t_tasks.c.id == task_id) - ) - conn.execute(update) - - # Delete the job - delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id) - conn.execute(delete) - - def get_job_result(self, job_id: UUID) -> JobResult | None: - for attempt in self._retry(): - with attempt, self.engine.begin() as conn: - # Retrieve the result - query = self.t_job_results.select().where( - self.t_job_results.c.job_id == job_id - ) - row = conn.execute(query).first() + # Delete the job + delete = self.t_jobs.delete().where( + self.t_jobs.c.id == result.job_id + ) + await self._execute(conn, delete) + + async def get_job_result(self, job_id: UUID) -> JobResult | None: + async for attempt in self._retry(): + with attempt: + async with self._begin_transaction() as conn: + # Retrieve the result + query = self.t_job_results.select().where( + self.t_job_results.c.job_id == job_id + ) + row = (await self._execute(conn, query)).first() - # Delete the result - delete = self.t_job_results.delete().where( - self.t_job_results.c.job_id == job_id - ) - conn.execute(delete) + # Delete the result 
+ delete = self.t_job_results.delete().where( + self.t_job_results.c.job_id == job_id + ) + await self._execute(conn, delete) - return ( - JobResult.unmarshal(self.serializer, row._asdict()) if row else None - ) + return ( + JobResult.unmarshal(self.serializer, row._asdict()) + if row + else None + ) diff --git a/src/apscheduler/eventbrokers/async_adapter.py b/src/apscheduler/eventbrokers/async_adapter.py deleted file mode 100644 index 4ff08a5..0000000 --- a/src/apscheduler/eventbrokers/async_adapter.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import annotations - -from typing import Any, Callable, Iterable - -import attrs -from anyio import to_thread -from anyio.from_thread import BlockingPortal - -from .._events import Event -from ..abc import AsyncEventBroker, EventBroker, Subscription - - -@attrs.define(eq=False) -class AsyncEventBrokerAdapter(AsyncEventBroker): - original: EventBroker - - async def start(self) -> None: - await to_thread.run_sync(self.original.start) - - async def stop(self, *, force: bool = False) -> None: - await to_thread.run_sync(lambda: self.original.stop(force=force)) - - async def publish_local(self, event: Event) -> None: - await to_thread.run_sync(self.original.publish_local, event) - - async def publish(self, event: Event) -> None: - await to_thread.run_sync(self.original.publish, event) - - def subscribe( - self, - callback: Callable[[Event], Any], - event_types: Iterable[type[Event]] | None = None, - *, - one_shot: bool = False, - ) -> Subscription: - return self.original.subscribe(callback, event_types, one_shot=one_shot) - - -@attrs.define(eq=False) -class SyncEventBrokerAdapter(EventBroker): - original: AsyncEventBroker - portal: BlockingPortal - - def start(self) -> None: - pass - - def stop(self, *, force: bool = False) -> None: - pass - - def publish_local(self, event: Event) -> None: - self.portal.call(self.original.publish_local, event) - - def publish(self, event: Event) -> None: - 
self.portal.call(self.original.publish, event) - - def subscribe( - self, - callback: Callable[[Event], Any], - event_types: Iterable[type[Event]] | None = None, - *, - one_shot: bool = False, - ) -> Subscription: - return self.portal.call( - lambda: self.original.subscribe(callback, event_types, one_shot=one_shot) - ) diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py deleted file mode 100644 index f69bfc7..0000000 --- a/src/apscheduler/eventbrokers/async_local.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import annotations - -from asyncio import iscoroutine -from typing import Any, Callable - -import attrs -from anyio import create_task_group -from anyio.abc import TaskGroup - -from .._events import Event -from ..abc import AsyncEventBroker -from .base import BaseEventBroker - - -@attrs.define(eq=False) -class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker): - """ - Asynchronous, local event broker. - - This event broker only broadcasts within the process it runs in, and is therefore - not suitable for multi-node or multiprocess use cases. - - Does not serialize events. 
- """ - - _task_group: TaskGroup = attrs.field(init=False) - - async def start(self) -> None: - self._task_group = create_task_group() - await self._task_group.__aenter__() - - async def stop(self, *, force: bool = False) -> None: - await self._task_group.__aexit__(None, None, None) - del self._task_group - - async def publish(self, event: Event) -> None: - await self.publish_local(event) - - async def publish_local(self, event: Event) -> None: - event_type = type(event) - one_shot_tokens: list[object] = [] - for _token, subscription in self._subscriptions.items(): - if ( - subscription.event_types is None - or event_type in subscription.event_types - ): - self._task_group.start_soon( - self._deliver_event, subscription.callback, event - ) - if subscription.one_shot: - one_shot_tokens.append(subscription.token) - - for token in one_shot_tokens: - super().unsubscribe(token) - - async def _deliver_event(self, func: Callable[[Event], Any], event: Event) -> None: - try: - retval = func(event) - if iscoroutine(retval): - await retval - except BaseException: - self._logger.exception( - "Error delivering %s event", event.__class__.__name__ - ) diff --git a/src/apscheduler/eventbrokers/async_redis.py b/src/apscheduler/eventbrokers/async_redis.py deleted file mode 100644 index 5e71621..0000000 --- a/src/apscheduler/eventbrokers/async_redis.py +++ /dev/null @@ -1,124 +0,0 @@ -from __future__ import annotations - -from asyncio import CancelledError - -import anyio -import attrs -import tenacity -from redis import ConnectionError -from redis.asyncio import Redis, RedisCluster -from redis.asyncio.client import PubSub -from redis.asyncio.connection import ConnectionPool - -from .. 
import RetrySettings -from .._events import Event -from ..abc import Serializer -from ..serializers.json import JSONSerializer -from .async_local import LocalAsyncEventBroker -from .base import DistributedEventBrokerMixin - - -@attrs.define(eq=False) -class AsyncRedisEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): - """ - An event broker that uses a Redis server to broadcast events. - - Requires the redis_ library to be installed. - - .. _redis: https://pypi.org/project/redis/ - - :param client: an asynchronous Redis client - :param serializer: the serializer used to (de)serialize events for transport - :param channel: channel on which to send the messages - :param retry_settings: Tenacity settings for retrying operations in case of a - broker connectivity problem - :param stop_check_interval: interval (in seconds) on which the channel listener - should check if it should stop (higher values mean slower reaction time but less - CPU use) - """ - - client: Redis | RedisCluster - serializer: Serializer = attrs.field(factory=JSONSerializer) - channel: str = attrs.field(kw_only=True, default="apscheduler") - retry_settings: RetrySettings = attrs.field(default=RetrySettings()) - stop_check_interval: float = attrs.field(kw_only=True, default=1) - _stopped: bool = attrs.field(init=False, default=True) - - @classmethod - def from_url(cls, url: str, **kwargs) -> AsyncRedisEventBroker: - """ - Create a new event broker from a URL. 
- - :param url: a Redis URL (```redis://...```) - :param kwargs: keyword arguments to pass to the initializer of this class - :return: the newly created event broker - - """ - pool = ConnectionPool.from_url(url) - client = Redis(connection_pool=pool) - return cls(client, **kwargs) - - def _retry(self) -> tenacity.AsyncRetrying: - def after_attempt(retry_state: tenacity.RetryCallState) -> None: - self._logger.warning( - f"{self.__class__.__name__}: connection failure " - f"(attempt {retry_state.attempt_number}): " - f"{retry_state.outcome.exception()}", - ) - - return tenacity.AsyncRetrying( - stop=self.retry_settings.stop, - wait=self.retry_settings.wait, - retry=tenacity.retry_if_exception_type(ConnectionError), - after=after_attempt, - sleep=anyio.sleep, - reraise=True, - ) - - async def start(self) -> None: - await super().start() - pubsub = self.client.pubsub() - try: - await pubsub.subscribe(self.channel) - except BaseException: - await self.stop(force=True) - raise - - self._stopped = False - self._task_group.start_soon( - self._listen_messages, pubsub, name="Redis subscriber" - ) - - async def stop(self, *, force: bool = False) -> None: - self._stopped = True - await super().stop(force=force) - - async def _listen_messages(self, pubsub: PubSub) -> None: - while not self._stopped: - try: - async for attempt in self._retry(): - with attempt: - msg = await pubsub.get_message( - ignore_subscribe_messages=True, - timeout=self.stop_check_interval, - ) - - if msg and isinstance(msg["data"], bytes): - event = self.reconstitute_event(msg["data"]) - if event is not None: - await self.publish_local(event) - except Exception as exc: - # CancelledError is a subclass of Exception in Python 3.7 - if not isinstance(exc, CancelledError): - self._logger.exception( - f"{self.__class__.__name__} listener crashed" - ) - - await pubsub.close() - raise - - async def publish(self, event: Event) -> None: - notification = self.generate_notification(event) - async for attempt in 
self._retry(): - with attempt: - await self.client.publish(self.channel, notification) diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py index 7e3045e..59f7292 100644 --- a/src/apscheduler/eventbrokers/asyncpg.py +++ b/src/apscheduler/eventbrokers/asyncpg.py @@ -5,10 +5,8 @@ from contextlib import AsyncExitStack from functools import partial from typing import TYPE_CHECKING, Any, Callable, cast -import anyio import asyncpg import attrs -import tenacity from anyio import ( TASK_STATUS_IGNORED, EndOfStream, @@ -18,20 +16,16 @@ from anyio import ( from anyio.streams.memory import MemoryObjectSendStream from asyncpg import Connection, InterfaceError -from .. import RetrySettings from .._events import Event from .._exceptions import SerializationError -from ..abc import Serializer -from ..serializers.json import JSONSerializer -from .async_local import LocalAsyncEventBroker -from .base import DistributedEventBrokerMixin +from .base import BaseExternalEventBroker if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncEngine @attrs.define(eq=False) -class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): +class AsyncpgEventBroker(BaseExternalEventBroker): """ An asynchronous, asyncpg_ based event broker that uses a PostgreSQL server to broadcast events using its ``NOTIFY`` mechanism. @@ -39,16 +33,13 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): .. 
_asyncpg: https://pypi.org/project/asyncpg/ :param connection_factory: a callable that creates an asyncpg connection - :param serializer: the serializer used to (de)serialize events for transport :param channel: the ``NOTIFY`` channel to use :param max_idle_time: maximum time to let the connection go idle, before sending a ``SELECT 1`` query to prevent a connection timeout """ connection_factory: Callable[[], Awaitable[Connection]] - serializer: Serializer = attrs.field(kw_only=True, factory=JSONSerializer) channel: str = attrs.field(kw_only=True, default="apscheduler") - retry_settings: RetrySettings = attrs.field(default=RetrySettings()) max_idle_time: float = attrs.field(kw_only=True, default=10) _send: MemoryObjectSendStream[str] = attrs.field(init=False) @@ -111,38 +102,17 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): factory = partial(asyncpg.connect, **connect_args) return cls(factory, **kwargs) - def _retry(self) -> tenacity.AsyncRetrying: - def after_attempt(retry_state: tenacity.RetryCallState) -> None: - self._logger.warning( - f"{self.__class__.__name__}: connection failure " - f"(attempt {retry_state.attempt_number}): " - f"{retry_state.outcome.exception()}", - ) + @property + def _temporary_failure_exceptions(self) -> tuple[type[Exception]]: + return OSError, InterfaceError - return tenacity.AsyncRetrying( - stop=self.retry_settings.stop, - wait=self.retry_settings.wait, - retry=tenacity.retry_if_exception_type((OSError, InterfaceError)), - after=after_attempt, - sleep=anyio.sleep, - reraise=True, + async def start(self, exit_stack: AsyncExitStack) -> None: + await super().start(exit_stack) + self._send = cast( + MemoryObjectSendStream[str], + await self._task_group.start(self._listen_notifications), ) - - async def start(self) -> None: - await super().start() - try: - self._send = cast( - MemoryObjectSendStream[str], - await self._task_group.start(self._listen_notifications), - ) - except BaseException: - await 
super().stop(force=True) - raise - - async def stop(self, *, force: bool = False) -> None: - self._send.close() - await super().stop(force=force) - self._logger.info("Stopped event broker") + await exit_stack.enter_async_context(self._send) async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None: conn: Connection diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py index 1373d54..5a3e718 100644 --- a/src/apscheduler/eventbrokers/base.py +++ b/src/apscheduler/eventbrokers/base.py @@ -1,15 +1,21 @@ from __future__ import annotations from base64 import b64decode, b64encode +from contextlib import AsyncExitStack +from inspect import iscoroutine from logging import Logger, getLogger from typing import Any, Callable, Iterable import attrs +from anyio import create_task_group, to_thread +from anyio.abc import TaskGroup from .. import _events from .._events import Event from .._exceptions import DeserializationError -from ..abc import EventSource, Serializer, Subscription +from .._retry import RetryMixin +from ..abc import EventBroker, Serializer, Subscription +from ..serializers.json import JSONSerializer @attrs.define(eq=False, frozen=True) @@ -17,6 +23,7 @@ class LocalSubscription(Subscription): callback: Callable[[Event], Any] event_types: set[type[Event]] | None one_shot: bool + is_async: bool token: object _source: BaseEventBroker @@ -24,36 +31,79 @@ class LocalSubscription(Subscription): self._source.unsubscribe(self.token) -@attrs.define(eq=False) -class BaseEventBroker(EventSource): +@attrs.define(kw_only=True) +class BaseEventBroker(EventBroker): _logger: Logger = attrs.field(init=False) _subscriptions: dict[object, LocalSubscription] = attrs.field( init=False, factory=dict ) + _task_group: TaskGroup = attrs.field(init=False) def __attrs_post_init__(self) -> None: self._logger = getLogger(self.__class__.__module__) + async def start(self, exit_stack: AsyncExitStack) -> None: + self._task_group = await 
exit_stack.enter_async_context(create_task_group()) + def subscribe( self, callback: Callable[[Event], Any], event_types: Iterable[type[Event]] | None = None, *, + is_async: bool = True, one_shot: bool = False, ) -> Subscription: types = set(event_types) if event_types else None token = object() - subscription = LocalSubscription(callback, types, one_shot, token, self) + subscription = LocalSubscription( + callback, types, one_shot, is_async, token, self + ) self._subscriptions[token] = subscription return subscription def unsubscribe(self, token: object) -> None: self._subscriptions.pop(token, None) + async def publish_local(self, event: Event) -> None: + event_type = type(event) + one_shot_tokens: list[object] = [] + for _token, subscription in self._subscriptions.items(): + if ( + subscription.event_types is None + or event_type in subscription.event_types + ): + self._task_group.start_soon(self._deliver_event, subscription, event) + if subscription.one_shot: + one_shot_tokens.append(subscription.token) + + for token in one_shot_tokens: + self.unsubscribe(token) + + async def _deliver_event( + self, subscription: LocalSubscription, event: Event + ) -> None: + try: + if subscription.is_async: + retval = subscription.callback(event) + if iscoroutine(retval): + await retval + else: + await to_thread.run_sync(subscription.callback, event) + except Exception: + self._logger.exception( + "Error delivering %s event", event.__class__.__name__ + ) + + +@attrs.define(kw_only=True) +class BaseExternalEventBroker(BaseEventBroker, RetryMixin): + """ + Base class for event brokers that use an external service. 
+ + :param serializer: the serializer used to (de)serialize events for transport + """ -class DistributedEventBrokerMixin: - serializer: Serializer - _logger: Logger + serializer: Serializer = attrs.field(factory=JSONSerializer) def generate_notification(self, event: Event) -> bytes: serialized = self.serializer.serialize(event.marshal(self.serializer)) diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py index 25ff2dd..27a3cfd 100644 --- a/src/apscheduler/eventbrokers/local.py +++ b/src/apscheduler/eventbrokers/local.py @@ -1,22 +1,15 @@ from __future__ import annotations -from asyncio import iscoroutinefunction -from concurrent.futures import ThreadPoolExecutor -from contextlib import ExitStack -from threading import Lock -from typing import Any, Callable, Iterable - import attrs from .._events import Event -from ..abc import EventBroker, Subscription from .base import BaseEventBroker @attrs.define(eq=False) -class LocalEventBroker(EventBroker, BaseEventBroker): +class LocalEventBroker(BaseEventBroker): """ - Synchronous, local event broker. + Asynchronous, local event broker. This event broker only broadcasts within the process it runs in, and is therefore not suitable for multi-node or multiprocess use cases. @@ -24,62 +17,5 @@ class LocalEventBroker(EventBroker, BaseEventBroker): Does not serialize events. 
""" - _executor: ThreadPoolExecutor = attrs.field(init=False) - _exit_stack: ExitStack = attrs.field(init=False) - _subscriptions_lock: Lock = attrs.field(init=False, factory=Lock) - - def start(self) -> None: - self._executor = ThreadPoolExecutor(1) - - def stop(self, *, force: bool = False) -> None: - self._executor.shutdown(wait=not force) - del self._executor - - def subscribe( - self, - callback: Callable[[Event], Any], - event_types: Iterable[type[Event]] | None = None, - *, - one_shot: bool = False, - ) -> Subscription: - if iscoroutinefunction(callback): - raise ValueError( - "Coroutine functions are not supported as callbacks on a synchronous " - "event source" - ) - - with self._subscriptions_lock: - return super().subscribe(callback, event_types, one_shot=one_shot) - - def unsubscribe(self, token: object) -> None: - with self._subscriptions_lock: - super().unsubscribe(token) - - def publish(self, event: Event) -> None: - self.publish_local(event) - - def publish_local(self, event: Event) -> None: - event_type = type(event) - with self._subscriptions_lock: - one_shot_tokens: list[object] = [] - for _token, subscription in self._subscriptions.items(): - if ( - subscription.event_types is None - or event_type in subscription.event_types - ): - self._executor.submit( - self._deliver_event, subscription.callback, event - ) - if subscription.one_shot: - one_shot_tokens.append(subscription.token) - - for token in one_shot_tokens: - super().unsubscribe(token) - - def _deliver_event(self, func: Callable[[Event], Any], event: Event) -> None: - try: - func(event) - except BaseException: - self._logger.exception( - "Error delivering %s event", event.__class__.__name__ - ) + async def publish(self, event: Event) -> None: + await self.publish_local(event) diff --git a/src/apscheduler/eventbrokers/mqtt.py b/src/apscheduler/eventbrokers/mqtt.py index 10ac605..567418d 100644 --- a/src/apscheduler/eventbrokers/mqtt.py +++ b/src/apscheduler/eventbrokers/mqtt.py @@ -2,22 
+2,22 @@ from __future__ import annotations import sys from concurrent.futures import Future +from contextlib import AsyncExitStack from typing import Any import attrs +from anyio import to_thread +from anyio.from_thread import BlockingPortal from paho.mqtt.client import Client, MQTTMessage from paho.mqtt.properties import Properties from paho.mqtt.reasoncodes import ReasonCodes from .._events import Event -from ..abc import Serializer -from ..serializers.json import JSONSerializer -from .base import DistributedEventBrokerMixin -from .local import LocalEventBroker +from .base import BaseExternalEventBroker @attrs.define(eq=False) -class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): +class MQTTEventBroker(BaseExternalEventBroker): """ An event broker that uses an MQTT (v3.1 or v5) broker to broadcast events. @@ -26,7 +26,6 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): .. _paho-mqtt: https://pypi.org/project/paho-mqtt/ :param client: a paho-mqtt client - :param serializer: the serializer used to (de)serialize events for transport :param host: host name or IP address to connect to :param port: TCP port number to connect to :param topic: topic on which to send the messages @@ -35,16 +34,17 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): """ client: Client = attrs.field(factory=Client) - serializer: Serializer = attrs.field(factory=JSONSerializer) host: str = attrs.field(kw_only=True, default="localhost") port: int = attrs.field(kw_only=True, default=1883) topic: str = attrs.field(kw_only=True, default="apscheduler") subscribe_qos: int = attrs.field(kw_only=True, default=0) publish_qos: int = attrs.field(kw_only=True, default=0) + _portal: BlockingPortal = attrs.field(init=False) _ready_future: Future[None] = attrs.field(init=False) - def start(self) -> None: - super().start() + async def start(self, exit_stack: AsyncExitStack) -> None: + await super().start(exit_stack) + self._portal = await 
exit_stack.enter_async_context(BlockingPortal()) self._ready_future = Future() self.client.on_connect = self._on_connect self.client.on_connect_fail = self._on_connect_fail @@ -53,12 +53,9 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): self.client.on_subscribe = self._on_subscribe self.client.connect(self.host, self.port) self.client.loop_start() - self._ready_future.result(10) - - def stop(self, *, force: bool = False) -> None: - self.client.disconnect() - self.client.loop_stop(force=force) - super().stop() + exit_stack.push_async_callback(to_thread.run_sync, self.client.loop_stop) + await to_thread.run_sync(self._ready_future.result, 10) + exit_stack.push_async_callback(to_thread.run_sync, self.client.disconnect) def _on_connect( self, @@ -97,8 +94,10 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): def _on_message(self, client: Client, userdata: Any, msg: MQTTMessage) -> None: event = self.reconstitute_event(msg.payload) if event is not None: - self.publish_local(event) + self._portal.call(self.publish_local, event) - def publish(self, event: Event) -> None: + async def publish(self, event: Event) -> None: notification = self.generate_notification(event) - self.client.publish(self.topic, notification, qos=self.publish_qos) + await to_thread.run_sync( + lambda: self.client.publish(self.topic, notification, qos=self.publish_qos) + ) diff --git a/src/apscheduler/eventbrokers/redis.py b/src/apscheduler/eventbrokers/redis.py index 6683276..10d2343 100644 --- a/src/apscheduler/eventbrokers/redis.py +++ b/src/apscheduler/eventbrokers/redis.py @@ -1,22 +1,22 @@ from __future__ import annotations -from threading import Thread +from asyncio import CancelledError +from contextlib import AsyncExitStack +import anyio import attrs import tenacity -from redis import ConnectionError, ConnectionPool, Redis, RedisCluster -from redis.client import PubSub +from redis import ConnectionError +from redis.asyncio import Redis, 
RedisCluster +from redis.asyncio.client import PubSub +from redis.asyncio.connection import ConnectionPool -from .. import RetrySettings from .._events import Event -from ..abc import Serializer -from ..serializers.json import JSONSerializer -from .base import DistributedEventBrokerMixin -from .local import LocalEventBroker +from .base import BaseExternalEventBroker @attrs.define(eq=False) -class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): +class RedisEventBroker(BaseExternalEventBroker): """ An event broker that uses a Redis server to broadcast events. @@ -24,8 +24,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): .. _redis: https://pypi.org/project/redis/ - :param client: a (synchronous) Redis client - :param serializer: the serializer used to (de)serialize events for transport + :param client: an asynchronous Redis client :param channel: channel on which to send the messages :param stop_check_interval: interval (in seconds) on which the channel listener should check if it should stop (higher values mean slower reaction time but less @@ -33,12 +32,9 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): """ client: Redis | RedisCluster - serializer: Serializer = attrs.field(factory=JSONSerializer) channel: str = attrs.field(kw_only=True, default="apscheduler") stop_check_interval: float = attrs.field(kw_only=True, default=1) - retry_settings: RetrySettings = attrs.field(default=RetrySettings()) _stopped: bool = attrs.field(init=False, default=True) - _thread: Thread = attrs.field(init=False) @classmethod def from_url(cls, url: str, **kwargs) -> RedisEventBroker: @@ -54,7 +50,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): client = Redis(connection_pool=pool) return cls(client, **kwargs) - def _retry(self) -> tenacity.Retrying: + def _retry(self) -> tenacity.AsyncRetrying: def after_attempt(retry_state: tenacity.RetryCallState) -> None: self._logger.warning( 
f"{self.__class__.__name__}: connection failure " @@ -62,40 +58,32 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): f"{retry_state.outcome.exception()}", ) - return tenacity.Retrying( + return tenacity.AsyncRetrying( stop=self.retry_settings.stop, wait=self.retry_settings.wait, retry=tenacity.retry_if_exception_type(ConnectionError), after=after_attempt, + sleep=anyio.sleep, reraise=True, ) - def start(self) -> None: + async def start(self, exit_stack: AsyncExitStack) -> None: + await super().start(exit_stack) pubsub = self.client.pubsub() - pubsub.subscribe(self.channel) + await pubsub.subscribe(self.channel) + self._stopped = False - super().start() - self._thread = Thread( - target=self._listen_messages, - args=[pubsub], - daemon=True, - name="Redis subscriber", + exit_stack.callback(setattr, self, "_stopped", True) + self._task_group.start_soon( + self._listen_messages, pubsub, name="Redis subscriber" ) - self._thread.start() - - def stop(self, *, force: bool = False) -> None: - self._stopped = True - if not force: - self._thread.join(5) - super().stop(force=force) - - def _listen_messages(self, pubsub: PubSub) -> None: + async def _listen_messages(self, pubsub: PubSub) -> None: while not self._stopped: try: - for attempt in self._retry(): + async for attempt in self._retry(): with attempt: - msg = pubsub.get_message( + msg = await pubsub.get_message( ignore_subscribe_messages=True, timeout=self.stop_check_interval, ) @@ -103,16 +91,19 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): if msg and isinstance(msg["data"], bytes): event = self.reconstitute_event(msg["data"]) if event is not None: - self.publish_local(event) - except Exception: - self._logger.exception(f"{self.__class__.__name__} listener crashed") - pubsub.close() + await self.publish_local(event) + except Exception as exc: + # CancelledError is a subclass of Exception in Python 3.7 + if not isinstance(exc, CancelledError): + self._logger.exception( 
+ f"{self.__class__.__name__} listener crashed" + ) + + await pubsub.close() raise - pubsub.close() - - def publish(self, event: Event) -> None: + async def publish(self, event: Event) -> None: notification = self.generate_notification(event) - for attempt in self._retry(): + async for attempt in self._retry(): with attempt: - self.client.publish(self.channel, notification) + await self.client.publish(self.channel, notification) diff --git a/src/apscheduler/workers/__init__.py b/src/apscheduler/executors/__init__.py index e69de29..e69de29 100644 --- a/src/apscheduler/workers/__init__.py +++ b/src/apscheduler/executors/__init__.py diff --git a/src/apscheduler/executors/async_.py b/src/apscheduler/executors/async_.py new file mode 100644 index 0000000..c7924ad --- /dev/null +++ b/src/apscheduler/executors/async_.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from collections.abc import Callable +from inspect import isawaitable +from typing import Any + +from .._structures import Job +from ..abc import JobExecutor + + +class AsyncJobExecutor(JobExecutor): + """ + Executes functions directly on the event loop thread. + + If the function returns a coroutine object (or another kind of awaitable), that is + awaited on and its return value is used as the job's return value. 
+ """ + + async def run_job(self, func: Callable[..., Any], job: Job) -> Any: + retval = func(*job.args, **job.kwargs) + if isawaitable(retval): + retval = await retval + + return retval diff --git a/src/apscheduler/executors/subprocess.py b/src/apscheduler/executors/subprocess.py new file mode 100644 index 0000000..e766e71 --- /dev/null +++ b/src/apscheduler/executors/subprocess.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from collections.abc import Callable +from contextlib import AsyncExitStack +from functools import partial +from typing import Any + +import attrs +from anyio import CapacityLimiter, to_process + +from .._structures import Job +from ..abc import JobExecutor + + +@attrs.define(eq=False, kw_only=True) +class ProcessPoolJobExecutor(JobExecutor): + """ + Executes functions in a process pool. + + :param max_workers: the maximum number of worker processes to keep + """ + + max_workers: int = 40 + _limiter: CapacityLimiter = attrs.field(init=False) + + async def start(self, exit_stack: AsyncExitStack) -> None: + self._limiter = CapacityLimiter(self.max_workers) + + async def run_job(self, func: Callable[..., Any], job: Job) -> Any: + wrapped = partial(func, *job.args, **job.kwargs) + return await to_process.run_sync( + wrapped, cancellable=True, limiter=self._limiter + ) diff --git a/src/apscheduler/executors/thread.py b/src/apscheduler/executors/thread.py new file mode 100644 index 0000000..9774093 --- /dev/null +++ b/src/apscheduler/executors/thread.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from collections.abc import Callable +from contextlib import AsyncExitStack +from functools import partial +from typing import Any + +import attrs +from anyio import CapacityLimiter, to_thread + +from .._structures import Job +from ..abc import JobExecutor + + +@attrs.define(eq=False, kw_only=True) +class ThreadPoolJobExecutor(JobExecutor): + """ + Executes functions in a thread pool. 
+ + :param max_workers: the maximum number of worker threads to keep + """ + + max_workers: int = 40 + _limiter: CapacityLimiter = attrs.field(init=False) + + async def start(self, exit_stack: AsyncExitStack) -> None: + self._limiter = CapacityLimiter(self.max_workers) + + async def run_job(self, func: Callable[..., Any], job: Job) -> Any: + wrapped = partial(func, *job.args, **job.kwargs) + return await to_thread.run_sync(wrapped, limiter=self._limiter) diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py index 75972b5..44fee27 100644 --- a/src/apscheduler/schedulers/async_.py +++ b/src/apscheduler/schedulers/async_.py @@ -5,6 +5,7 @@ import platform import random import sys from asyncio import CancelledError +from collections.abc import MutableMapping from contextlib import AsyncExitStack from datetime import datetime, timedelta, timezone from logging import Logger, getLogger @@ -16,9 +17,9 @@ import anyio import attrs from anyio import TASK_STATUS_IGNORED, create_task_group, move_on_after from anyio.abc import TaskGroup, TaskStatus +from attr.validators import instance_of -from .._context import current_scheduler -from .._converters import as_async_datastore, as_async_eventbroker +from .._context import current_async_scheduler from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState from .._events import ( Event, @@ -35,11 +36,14 @@ from .._exceptions import ( ScheduleLookupError, ) from .._structures import Job, JobResult, Schedule, Task -from ..abc import AsyncDataStore, AsyncEventBroker, Subscription, Trigger +from .._worker import Worker +from ..abc import DataStore, EventBroker, JobExecutor, Subscription, Trigger from ..datastores.memory import MemoryDataStore -from ..eventbrokers.async_local import LocalAsyncEventBroker +from ..eventbrokers.local import LocalEventBroker +from ..executors.async_ import AsyncJobExecutor +from ..executors.subprocess import ProcessPoolJobExecutor +from ..executors.thread 
import ThreadPoolJobExecutor from ..marshalling import callable_to_ref -from ..workers.async_ import AsyncWorker if sys.version_info >= (3, 11): from typing import Self @@ -54,19 +58,24 @@ _zero_timedelta = timedelta() class AsyncScheduler: """An asynchronous (AnyIO based) scheduler implementation.""" - data_store: AsyncDataStore = attrs.field( - converter=as_async_datastore, factory=MemoryDataStore + data_store: DataStore = attrs.field( + validator=instance_of(DataStore), factory=MemoryDataStore ) - event_broker: AsyncEventBroker = attrs.field( - converter=as_async_eventbroker, factory=LocalAsyncEventBroker + event_broker: EventBroker = attrs.field( + validator=instance_of(EventBroker), factory=LocalEventBroker ) identity: str = attrs.field(kw_only=True, default=None) - start_worker: bool = attrs.field(kw_only=True, default=True) + process_jobs: bool = attrs.field(kw_only=True, default=True) + job_executors: MutableMapping[str, JobExecutor] | None = attrs.field( + kw_only=True, default=None + ) + default_job_executor: str | None = attrs.field(kw_only=True, default=None) + process_schedules: bool = attrs.field(kw_only=True, default=True) logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) _state: RunState = attrs.field(init=False, default=RunState.stopped) _task_group: TaskGroup | None = attrs.field(init=False, default=None) - _exit_stack: AsyncExitStack | None = attrs.field(init=False, default=None) + _exit_stack: AsyncExitStack = attrs.field(init=False, factory=AsyncExitStack) _services_initialized: bool = attrs.field(init=False, default=False) _wakeup_event: anyio.Event = attrs.field(init=False) _wakeup_deadline: datetime | None = attrs.field(init=False, default=None) @@ -76,13 +85,32 @@ class AsyncScheduler: if not self.identity: self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" + if not self.job_executors: + self.job_executors = { + "async": AsyncJobExecutor(), + "threadpool": ThreadPoolJobExecutor(), + "processpool": 
ProcessPoolJobExecutor(), + } + + if not self.default_job_executor: + self.default_job_executor = next(iter(self.job_executors)) + elif self.default_job_executor not in self.job_executors: + raise ValueError( + "default_job_executor must be one of the given job executors" + ) + async def __aenter__(self: Self) -> Self: - self._exit_stack = AsyncExitStack() await self._exit_stack.__aenter__() - await self._ensure_services_ready(self._exit_stack) - self._task_group = await self._exit_stack.enter_async_context( - create_task_group() - ) + try: + await self._ensure_services_initialized(self._exit_stack) + self._task_group = await self._exit_stack.enter_async_context( + create_task_group() + ) + self._exit_stack.callback(setattr, self, "_task_group", None) + except BaseException as exc: + await self._exit_stack.__aexit__(type(exc), exc, exc.__traceback__) + raise + return self async def __aexit__( @@ -93,27 +121,25 @@ class AsyncScheduler: ) -> None: await self.stop() await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) - self._task_group = None - async def _ensure_services_ready(self, exit_stack: AsyncExitStack) -> None: + async def _ensure_services_initialized(self, exit_stack: AsyncExitStack) -> None: """Ensure that the data store and event broker have been initialized.""" if not self._services_initialized: self._services_initialized = True exit_stack.callback(setattr, self, "_services_initialized", False) - # Initialize the event broker - await self.event_broker.start() - exit_stack.push_async_exit( - lambda *exc_info: self.event_broker.stop(force=exc_info[0] is not None) - ) + await self.event_broker.start(exit_stack) + await self.data_store.start(exit_stack, self.event_broker) - # Initialize the data store - await self.data_store.start(self.event_broker) - exit_stack.push_async_exit( - lambda *exc_info: self.data_store.stop(force=exc_info[0] is not None) + def _check_initialized(self) -> None: + if not self._services_initialized: + raise RuntimeError( + 
"The scheduler has not been initialized yet. Use the scheduler as an " + "async context manager (async with ...) in order to call methods other " + "than run_until_stopped()." ) - def _schedule_added_or_modified(self, event: Event) -> None: + async def _schedule_added_or_modified(self, event: Event) -> None: event_ = cast("ScheduleAdded | ScheduleUpdated", event) if not self._wakeup_deadline or ( event_.next_fire_time and event_.next_fire_time < self._wakeup_deadline @@ -128,6 +154,35 @@ """The current running state of the scheduler.""" return self._state + def subscribe( + self, + callback: Callable[[Event], Any], + event_types: Iterable[type[Event]] | None = None, + *, + one_shot: bool = False, + is_async: bool = True, + ) -> Subscription: + """ + Subscribe to events. + + To unsubscribe, call the :meth:`Subscription.unsubscribe` method on the returned + object. + + :param callback: callable to be called with the event object when an event is + published + :param event_types: an iterable of concrete Event classes to subscribe to + :param one_shot: if ``True``, automatically unsubscribe after the first matching + event + :param is_async: ``True`` if the (synchronous) callback should be called on the + event loop thread, ``False`` if it should be called in a worker thread. + If the callback is a coroutine function, this flag is ignored.
+ + """ + self._check_initialized() + return self.event_broker.subscribe( + callback, event_types, is_async=is_async, one_shot=one_shot + ) + async def add_schedule( self, func_or_task_id: str | Callable, @@ -136,6 +191,7 @@ class AsyncScheduler: id: str | None = None, args: Iterable | None = None, kwargs: Mapping[str, Any] | None = None, + job_executor: str | None = None, coalesce: CoalescePolicy = CoalescePolicy.latest, misfire_grace_time: float | timedelta | None = None, max_jitter: float | timedelta | None = None, @@ -152,6 +208,7 @@ class AsyncScheduler: based ID will be assigned) :param args: positional arguments to be passed to the task function :param kwargs: keyword arguments to be passed to the task function + :param job_executor: name of the job executor to run the task with :param coalesce: determines what to do when processing the schedule if multiple fire times have become due for this schedule since the last processing :param misfire_grace_time: maximum number of seconds the scheduled job's actual @@ -165,6 +222,7 @@ class AsyncScheduler: :return: the ID of the newly added schedule """ + self._check_initialized() id = id or str(uuid4()) args = tuple(args or ()) kwargs = dict(kwargs or {}) @@ -173,7 +231,11 @@ class AsyncScheduler: misfire_grace_time = timedelta(seconds=misfire_grace_time) if callable(func_or_task_id): - task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id) + task = Task( + id=callable_to_ref(func_or_task_id), + func=func_or_task_id, + executor=job_executor or self.default_job_executor, + ) await self.data_store.add_task(task) else: task = await self.data_store.get_task(func_or_task_id) @@ -207,6 +269,7 @@ class AsyncScheduler: :raises ScheduleLookupError: if the schedule could not be found """ + self._check_initialized() schedules = await self.data_store.get_schedules({id}) if schedules: return schedules[0] @@ -220,6 +283,7 @@ class AsyncScheduler: :return: a list of schedules, in an unspecified order """ + 
self._check_initialized() return await self.data_store.get_schedules() async def remove_schedule(self, id: str) -> None: @@ -229,6 +293,7 @@ :param id: the unique identifier of the schedule """ + self._check_initialized() await self.data_store.remove_schedules({id}) async def add_job( @@ -237,6 +302,7 @@ *, args: Iterable | None = None, kwargs: Mapping[str, Any] | None = None, + job_executor: str | None = None, tags: Iterable[str] | None = None, result_expiration_time: timedelta | float = 0, ) -> UUID: @@ -244,8 +310,10 @@ Add a job to the data store. :param func_or_task_id: + either a callable or an ID of an existing task definition :param args: positional arguments to call the target callable with :param kwargs: keyword arguments to call the target callable with + :param job_executor: name of the job executor to run the task with :param tags: strings that can be used to categorize and filter the job :param result_expiration_time: the minimum time (as seconds, or timedelta) to keep the result of the job available for fetching (the result won't be @@ -253,8 +321,13 @@ :return: the ID of the newly created job """ + self._check_initialized() if callable(func_or_task_id): - task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id) + task = Task( + id=callable_to_ref(func_or_task_id), + func=func_or_task_id, + executor=job_executor or self.default_job_executor, + ) await self.data_store.add_task(task) else: task = await self.data_store.get_task(func_or_task_id) @@ -280,13 +353,14 @@ the data store """ + self._check_initialized() wait_event = anyio.Event() def listener(event: JobReleased) -> None: if event.job_id == job_id: wait_event.set() - with self.data_store.events.subscribe(listener, {JobReleased}): + with self.event_broker.subscribe(listener, {JobReleased}): result = await self.data_store.get_job_result(job_id) if result: return
result @@ -303,6 +377,7 @@ class AsyncScheduler: *, args: Iterable | None = None, kwargs: Mapping[str, Any] | None = None, + job_executor: str | None = None, tags: Iterable[str] | None = (), ) -> Any: """ @@ -314,10 +389,12 @@ class AsyncScheduler: definition :param args: positional arguments to be passed to the task function :param kwargs: keyword arguments to be passed to the task function + :param job_executor: name of the job executor to run the task with :param tags: strings that can be used to categorize and filter the job :returns: the return value of the task function """ + self._check_initialized() job_complete_event = anyio.Event() def listener(event: JobReleased) -> None: @@ -325,11 +402,12 @@ class AsyncScheduler: job_complete_event.set() job_id: UUID | None = None - with self.data_store.events.subscribe(listener, {JobReleased}): + with self.event_broker.subscribe(listener, {JobReleased}): job_id = await self.add_job( func_or_task_id, args=args, kwargs=kwargs, + job_executor=job_executor, tags=tags, result_expiration_time=timedelta(minutes=15), ) @@ -378,12 +456,7 @@ class AsyncScheduler: await event.wait() async def start_in_background(self) -> None: - if self._task_group is None: - raise RuntimeError( - "The scheduler must be used as an async context manager (async with " - "...) 
in order to be startable in the background" - ) - + self._check_initialized() await self._task_group.start(self.run_until_stopped) async def run_until_stopped( @@ -398,7 +471,7 @@ class AsyncScheduler: self._state = RunState.starting async with AsyncExitStack() as exit_stack: self._wakeup_event = anyio.Event() - await self._ensure_services_ready(exit_stack) + await self._ensure_services_initialized(exit_stack) # Wake up the scheduler if the data store emits a significant schedule event exit_stack.enter_context( @@ -407,14 +480,16 @@ class AsyncScheduler: ) ) + # Set this scheduler as the current scheduler + token = current_async_scheduler.set(self) + exit_stack.callback(current_async_scheduler.reset, token) + # Start the built-in worker, if configured to do so - if self.start_worker: - token = current_scheduler.set(self) - exit_stack.callback(current_scheduler.reset, token) - worker = AsyncWorker( - self.data_store, self.event_broker, is_internal=True + if self.process_jobs: + worker = Worker(job_executors=self.job_executors) + await worker.start( + exit_stack, self.data_store, self.event_broker, self.identity ) - await exit_stack.enter_async_context(worker) # Signal that the scheduler has started self._state = RunState.started diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py index d98161a..3a812a4 100644 --- a/src/apscheduler/schedulers/sync.py +++ b/src/apscheduler/schedulers/sync.py @@ -1,77 +1,100 @@ from __future__ import annotations import atexit -import os -import platform -import random +import logging import sys import threading -from concurrent.futures import Future +from collections.abc import MutableMapping from contextlib import ExitStack -from datetime import datetime, timedelta, timezone -from logging import Logger, getLogger +from datetime import timedelta +from functools import partial +from logging import Logger from types import TracebackType -from typing import Any, Callable, Iterable, Mapping, cast -from uuid 
import UUID, uuid4 - -import attrs - -from .._context import current_scheduler -from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState -from .._events import ( - Event, - JobReleased, - ScheduleAdded, - SchedulerStarted, - SchedulerStopped, - ScheduleUpdated, -) -from .._exceptions import ( - JobCancelled, - JobDeadlineMissed, - JobLookupError, - ScheduleLookupError, -) -from .._structures import Job, JobResult, Schedule, Task -from ..abc import DataStore, EventBroker, Trigger -from ..datastores.memory import MemoryDataStore -from ..eventbrokers.local import LocalEventBroker -from ..marshalling import callable_to_ref -from ..workers.sync import Worker +from typing import Any, Callable, Iterable, Mapping +from uuid import UUID + +from anyio import start_blocking_portal +from anyio.from_thread import BlockingPortal + +from .. import Event, current_scheduler +from .._enums import CoalescePolicy, ConflictPolicy, RunState +from .._structures import JobResult, Schedule +from ..abc import DataStore, EventBroker, JobExecutor, Subscription, Trigger +from .async_ import AsyncScheduler if sys.version_info >= (3, 11): from typing import Self else: from typing_extensions import Self -_microsecond_delta = timedelta(microseconds=1) -_zero_timedelta = timedelta() - -@attrs.define(eq=False) class Scheduler: """A synchronous scheduler implementation.""" - data_store: DataStore = attrs.field(factory=MemoryDataStore) - event_broker: EventBroker = attrs.field(factory=LocalEventBroker) - identity: str = attrs.field(kw_only=True, default=None) - start_worker: bool = attrs.field(kw_only=True, default=True) - logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) + def __init__( + self, + data_store: DataStore | None = None, + event_broker: EventBroker | None = None, + *, + identity: str | None = None, + process_schedules: bool = True, + start_worker: bool = True, + job_executors: Mapping[str, JobExecutor] | None = None, + default_job_executor: 
str | None = None, + logger: Logger | None = None, + ): + kwargs: dict[str, Any] = {} + if data_store is not None: + kwargs["data_store"] = data_store + if event_broker is not None: + kwargs["event_broker"] = event_broker + + if not default_job_executor and not job_executors: + default_job_executor = "threadpool" + + self._async_scheduler = AsyncScheduler( + identity=identity, + process_schedules=process_schedules, + process_jobs=start_worker, + job_executors=job_executors, + default_job_executor=default_job_executor, + logger=logger or logging.getLogger(__name__), + **kwargs, + ) + self._exit_stack = ExitStack() + self._portal: BlockingPortal | None = None + self._lock = threading.RLock() - _state: RunState = attrs.field(init=False, default=RunState.stopped) - _thread: threading.Thread | None = attrs.field(init=False, default=None) - _wakeup_event: threading.Event = attrs.field(init=False, factory=threading.Event) - _wakeup_deadline: datetime | None = attrs.field(init=False, default=None) - _services_initialized: bool = attrs.field(init=False, default=False) - _exit_stack: ExitStack | None = attrs.field(init=False, default=None) - _lock: threading.RLock = attrs.field(init=False, factory=threading.RLock) + @property + def data_store(self) -> DataStore: + return self._async_scheduler.data_store - def __attrs_post_init__(self) -> None: - if not self.identity: - self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" + @property + def event_broker(self) -> EventBroker: + return self._async_scheduler.event_broker + + @property + def identity(self) -> str: + return self._async_scheduler.identity + + @property + def process_schedules(self) -> bool: + return self._async_scheduler.process_schedules + + @property + def start_worker(self) -> bool: + return self._async_scheduler.process_jobs + + @property + def job_executors(self) -> MutableMapping[str, JobExecutor]: + return self._async_scheduler.job_executors + + @property + def state(self) -> RunState: + """The 
current running state of the scheduler.""" + return self._async_scheduler.state def __enter__(self: Self) -> Self: - self._exit_stack = ExitStack() self._ensure_services_ready(self._exit_stack) return self @@ -81,13 +104,12 @@ class Scheduler: exc_val: BaseException, exc_tb: TracebackType, ) -> None: - self.stop() self._exit_stack.__exit__(exc_type, exc_val, exc_tb) def _ensure_services_ready(self, exit_stack: ExitStack | None = None) -> None: - """Ensure that the data store and event broker have been initialized.""" + """Ensure that the underlying asynchronous scheduler has been initialized.""" with self._lock: - if not self._services_initialized: + if self._portal is None: if exit_stack is None: if self._exit_stack is None: self._exit_stack = exit_stack = ExitStack() @@ -95,43 +117,39 @@ class Scheduler: else: exit_stack = self._exit_stack - self._services_initialized = True - exit_stack.callback(setattr, self, "_services_initialized", False) + # Set this scheduler as the current synchronous scheduler + token = current_scheduler.set(self) + exit_stack.callback(current_scheduler.reset, token) - self.event_broker.start() - exit_stack.push( - lambda *exc_info: self.event_broker.stop( - force=exc_info[0] is not None - ) + self._portal = exit_stack.enter_context(start_blocking_portal()) + exit_stack.callback(setattr, self, "_portal", None) + exit_stack.enter_context( + self._portal.wrap_async_context_manager(self._async_scheduler) ) - # Initialize the data store - self.data_store.start(self.event_broker) - exit_stack.push( - lambda *exc_info: self.data_store.stop( - force=exc_info[0] is not None - ) - ) + def subscribe( + self, + callback: Callable[[Event], Any], + event_types: Iterable[type[Event]] | None = None, + *, + one_shot: bool = False, + ) -> Subscription: + """ + Subscribe to events. 
- def _schedule_added_or_modified(self, event: Event) -> None: - event_ = cast("ScheduleAdded | ScheduleUpdated", event) - if not self._wakeup_deadline or ( - event_.next_fire_time and event_.next_fire_time < self._wakeup_deadline - ): - self.logger.debug( - "Detected a %s event – waking up the scheduler", type(event).__name__ - ) - self._wakeup_event.set() + To unsubscribe, call the :meth:`Subscription.unsubscribe` method on the returned + object. - def _join_thread(self) -> None: - if self._thread: - self._thread.join() - self._thread = None + :param callback: callable to be called with the event object when an event is + published + :param event_types: an iterable of concrete Event classes to subscribe to + :param one_shot: if ``True``, automatically unsubscribe after the first matching + event - @property - def state(self) -> RunState: - """The current running state of the scheduler.""" - return self._state + """ + return self.event_broker.subscribe( + callback, event_types, is_async=False, one_shot=one_shot + ) def add_schedule( self, @@ -141,104 +159,42 @@ id: str | None = None, args: Iterable | None = None, kwargs: Mapping[str, Any] | None = None, + job_executor: str | None = None, coalesce: CoalescePolicy = CoalescePolicy.latest, misfire_grace_time: float | timedelta | None = None, max_jitter: float | timedelta | None = None, tags: Iterable[str] | None = None, conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing, ) -> str: - """ - Schedule a task to be run one or more times in the future.
- - :param func_or_task_id: either a callable or an ID of an existing task - definition - :param trigger: determines the times when the task should be run - :param id: an explicit identifier for the schedule (if omitted, a random, UUID - based ID will be assigned) - :param args: positional arguments to be passed to the task function - :param kwargs: keyword arguments to be passed to the task function - :param coalesce: determines what to do when processing the schedule if multiple - fire times have become due for this schedule since the last processing - :param misfire_grace_time: maximum number of seconds the scheduled job's actual - run time is allowed to be late, compared to the scheduled run time - :param max_jitter: maximum number of seconds to randomly add to the scheduled - time for each job created from this schedule - :param tags: strings that can be used to categorize and filter the schedule and - its derivative jobs - :param conflict_policy: determines what to do if a schedule with the same ID - already exists in the data store - :return: the ID of the newly added schedule - - """ self._ensure_services_ready() - id = id or str(uuid4()) - args = tuple(args or ()) - kwargs = dict(kwargs or {}) - tags = frozenset(tags or ()) - if isinstance(misfire_grace_time, (int, float)): - misfire_grace_time = timedelta(seconds=misfire_grace_time) - - if callable(func_or_task_id): - task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id) - self.data_store.add_task(task) - else: - task = self.data_store.get_task(func_or_task_id) - - schedule = Schedule( - id=id, - task_id=task.id, - trigger=trigger, - args=args, - kwargs=kwargs, - coalesce=coalesce, - misfire_grace_time=misfire_grace_time, - max_jitter=max_jitter, - tags=tags, - ) - schedule.next_fire_time = trigger.next() - self.data_store.add_schedule(schedule, conflict_policy) - self.logger.info( - "Added new schedule (task=%r, trigger=%r); next run time at %s", - task, - trigger, - 
schedule.next_fire_time, + return self._portal.call( + partial( + self._async_scheduler.add_schedule, + func_or_task_id, + trigger, + id=id, + args=args, + kwargs=kwargs, + job_executor=job_executor, + coalesce=coalesce, + misfire_grace_time=misfire_grace_time, + max_jitter=max_jitter, + tags=tags, + conflict_policy=conflict_policy, + ) ) - return schedule.id def get_schedule(self, id: str) -> Schedule: - """ - Retrieve a schedule from the data store. - - :param id: the unique identifier of the schedule - :raises ScheduleLookupError: if the schedule could not be found - - """ self._ensure_services_ready() - schedules = self.data_store.get_schedules({id}) - if schedules: - return schedules[0] - else: - raise ScheduleLookupError(id) + return self._portal.call(self._async_scheduler.get_schedule, id) def get_schedules(self) -> list[Schedule]: - """ - Retrieve all schedules from the data store. - - :return: a list of schedules, in an unspecified order - - """ self._ensure_services_ready() - return self.data_store.get_schedules() + return self._portal.call(self._async_scheduler.get_schedules) def remove_schedule(self, id: str) -> None: - """ - Remove the given schedule from the data store. - - :param id: the unique identifier of the schedule - - """ self._ensure_services_ready() - self.data_store.remove_schedules({id}) + self._portal.call(self._async_scheduler.remove_schedule, id) def add_job( self, @@ -246,68 +202,28 @@ class Scheduler: *, args: Iterable | None = None, kwargs: Mapping[str, Any] | None = None, + job_executor: str | None = None, tags: Iterable[str] | None = None, result_expiration_time: timedelta | float = 0, ) -> UUID: - """ - Add a job to the data store. 
- - :param func_or_task_id: either a callable or an ID of an existing task - definition - :param args: positional arguments to be passed to the task function - :param kwargs: keyword arguments to be passed to the task function - :param tags: strings that can be used to categorize and filter the job - :param result_expiration_time: the minimum time (as seconds, or timedelta) to - keep the result of the job available for fetching (the result won't be - saved at all if that time is 0) - :return: the ID of the newly created job - - """ self._ensure_services_ready() - if callable(func_or_task_id): - task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id) - self.data_store.add_task(task) - else: - task = self.data_store.get_task(func_or_task_id) - - job = Job( - task_id=task.id, - args=args or (), - kwargs=kwargs or {}, - tags=tags or frozenset(), - result_expiration_time=result_expiration_time, + return self._portal.call( + partial( + self._async_scheduler.add_job, + func_or_task_id, + args=args, + kwargs=kwargs, + job_executor=job_executor, + tags=tags, + result_expiration_time=result_expiration_time, + ) ) - self.data_store.add_job(job) - return job.id def get_job_result(self, job_id: UUID, *, wait: bool = True) -> JobResult: - """ - Retrieve the result of a job. 
- - :param job_id: the ID of the job - :param wait: if ``True``, wait until the job has ended (one way or another), - ``False`` to raise an exception if the result is not yet available - :raises JobLookupError: if ``wait=False`` and the job result does not exist in - the data store - - """ self._ensure_services_ready() - wait_event = threading.Event() - - def listener(event: JobReleased) -> None: - if event.job_id == job_id: - wait_event.set() - - with self.data_store.events.subscribe(listener, {JobReleased}, one_shot=True): - result = self.data_store.get_job_result(job_id) - if result: - return result - elif not wait: - raise JobLookupError(job_id) - - wait_event.wait() - - return self.data_store.get_job_result(job_id) + return self._portal.call( + partial(self._async_scheduler.get_job_result, job_id, wait=wait) + ) def run_job( self, @@ -315,50 +231,20 @@ class Scheduler: *, args: Iterable | None = None, kwargs: Mapping[str, Any] | None = None, + job_executor: str | None = None, tags: Iterable[str] | None = (), ) -> Any: - """ - Convenience method to add a job and then return its result. - - If the job raised an exception, that exception will be reraised here. 
- - :param func_or_task_id: either a callable or an ID of an existing task - definition - :param args: positional arguments to be passed to the task function - :param kwargs: keyword arguments to be passed to the task function - :param tags: strings that can be used to categorize and filter the job - :returns: the return value of the task function - - """ self._ensure_services_ready() - job_complete_event = threading.Event() - - def listener(event: JobReleased) -> None: - if event.job_id == job_id: - job_complete_event.set() - - job_id: UUID | None = None - with self.data_store.events.subscribe(listener, {JobReleased}): - job_id = self.add_job( + return self._portal.call( + partial( + self._async_scheduler.run_job, func_or_task_id, args=args, kwargs=kwargs, + job_executor=job_executor, tags=tags, - result_expiration_time=timedelta(minutes=15), ) - job_complete_event.wait() - - result = self.get_job_result(job_id) - if result.outcome is JobOutcome.success: - return result.return_value - elif result.outcome is JobOutcome.error: - raise result.exception - elif result.outcome is JobOutcome.missed_start_deadline: - raise JobDeadlineMissed - elif result.outcome is JobOutcome.cancelled: - raise JobCancelled - else: - raise RuntimeError(f"Unknown job outcome: {result.outcome}") + ) def start_in_background(self) -> None: """ @@ -370,241 +256,19 @@ class Scheduler: :raises RuntimeError: if the scheduler is not in the ``stopped`` state """ - with self._lock: - if self._state is not RunState.stopped: - raise RuntimeError( - f'Cannot start the scheduler when it is in the "{self._state}" ' - f"state" - ) - - self._state = RunState.starting - - start_future: Future[None] = Future() - self._thread = threading.Thread( - target=self._run, args=[start_future], daemon=True - ) - self._thread.start() - try: - start_future.result() - except BaseException: - self._thread = None - raise - - self._exit_stack.callback(self._join_thread) - self._exit_stack.callback(self.stop) + 
self._ensure_services_ready() + self._portal.call(self._async_scheduler.start_in_background) def stop(self) -> None: - """ - Signal the scheduler that it should stop processing schedules. - - This method does not wait for the scheduler to actually stop. - For that, see :meth:`wait_until_stopped`. - - """ - with self._lock: - if self._state is RunState.started: - self._state = RunState.stopping - self._wakeup_event.set() + if self._portal is not None: + self._portal.call(self._async_scheduler.stop) def wait_until_stopped(self) -> None: - """ - Wait until the scheduler is in the "stopped" or "stopping" state. - - If the scheduler is already stopped or in the process of stopping, this method - returns immediately. Otherwise, it waits until the scheduler posts the - ``SchedulerStopped`` event. - - """ - with self._lock: - if self._state in (RunState.stopped, RunState.stopping): - return - - event = threading.Event() - sub = self.event_broker.subscribe( - lambda ev: event.set(), {SchedulerStopped}, one_shot=True - ) - - with sub: - event.wait() + if self._portal is not None: + self._portal.call(self._async_scheduler.wait_until_stopped) def run_until_stopped(self) -> None: - """ - Run the scheduler (and its internal worker) until it is explicitly stopped. - - This method will only return if :meth:`stop` is called. 
- - """ - with self._lock: - if self._state is not RunState.stopped: - raise RuntimeError( - f'Cannot start the scheduler when it is in the "{self._state}" ' - f"state" - ) - - self._state = RunState.starting - - self._run(None) - - def _run(self, start_future: Future[None] | None) -> None: - assert self._state is RunState.starting - with self._exit_stack.pop_all() as exit_stack: - try: - self._ensure_services_ready(exit_stack) - - # Wake up the scheduler if the data store emits a significant schedule - # event - exit_stack.enter_context( - self.data_store.events.subscribe( - self._schedule_added_or_modified, - {ScheduleAdded, ScheduleUpdated}, - ) - ) - - # Start the built-in worker, if configured to do so - if self.start_worker: - token = current_scheduler.set(self) - exit_stack.callback(current_scheduler.reset, token) - worker = Worker( - self.data_store, self.event_broker, is_internal=True - ) - exit_stack.enter_context(worker) - - # Signal that the scheduler has started - self._state = RunState.started - self.event_broker.publish_local(SchedulerStarted()) - except BaseException as exc: - if start_future: - start_future.set_exception(exc) - return - else: - raise - else: - if start_future: - start_future.set_result(None) - - exception: BaseException | None = None - try: - while self._state is RunState.started: - schedules = self.data_store.acquire_schedules(self.identity, 100) - self.logger.debug( - "Processing %d schedules retrieved from the data store", - len(schedules), - ) - now = datetime.now(timezone.utc) - for schedule in schedules: - # Calculate a next fire time for the schedule, if possible - fire_times = [schedule.next_fire_time] - calculate_next = schedule.trigger.next - while True: - try: - fire_time = calculate_next() - except Exception: - self.logger.exception( - "Error computing next fire time for schedule %r of " - "task %r – removing schedule", - schedule.id, - schedule.task_id, - ) - break - - # Stop if the calculated fire time is in the 
future - if fire_time is None or fire_time > now: - schedule.next_fire_time = fire_time - break - - # Only keep all the fire times if coalesce policy = "all" - if schedule.coalesce is CoalescePolicy.all: - fire_times.append(fire_time) - elif schedule.coalesce is CoalescePolicy.latest: - fire_times[0] = fire_time - - # Add one or more jobs to the job queue - max_jitter = ( - schedule.max_jitter.total_seconds() - if schedule.max_jitter - else 0 - ) - for i, fire_time in enumerate(fire_times): - # Calculate a jitter if max_jitter > 0 - jitter = _zero_timedelta - if max_jitter: - if i + 1 < len(fire_times): - next_fire_time = fire_times[i + 1] - else: - next_fire_time = schedule.next_fire_time - - if next_fire_time is not None: - # Jitter must never be so high that it would cause - # a fire time to equal or exceed the next fire time - jitter_s = min( - [ - max_jitter, - ( - next_fire_time - - fire_time - - _microsecond_delta - ).total_seconds(), - ] - ) - jitter = timedelta( - seconds=random.uniform(0, jitter_s) - ) - fire_time += jitter - - schedule.last_fire_time = fire_time - job = Job( - task_id=schedule.task_id, - args=schedule.args, - kwargs=schedule.kwargs, - schedule_id=schedule.id, - scheduled_fire_time=fire_time, - jitter=jitter, - start_deadline=schedule.next_deadline, - tags=schedule.tags, - ) - self.data_store.add_job(job) - - # Update the schedules (and release the scheduler's claim on them) - self.data_store.release_schedules(self.identity, schedules) - - # If we received fewer schedules than the maximum amount, sleep - # until the next schedule is due or the scheduler is explicitly - # woken up - wait_time = None - if len(schedules) < 100: - self._wakeup_deadline = ( - self.data_store.get_next_schedule_run_time() - ) - if self._wakeup_deadline: - wait_time = ( - self._wakeup_deadline - datetime.now(timezone.utc) - ).total_seconds() - self.logger.debug( - "Sleeping %.3f seconds until the next fire time (%s)", - wait_time, - self._wakeup_deadline, - ) - 
else: - self.logger.debug("Waiting for any due schedules to appear") - - if self._wakeup_event.wait(wait_time): - self._wakeup_event = threading.Event() - else: - self.logger.debug( - "Processing more schedules on the next iteration" - ) - except BaseException as exc: - exception = exc - raise - finally: - self._state = RunState.stopped - if isinstance(exception, Exception): - self.logger.exception("Scheduler crashed") - elif exception: - self.logger.info( - f"Scheduler stopped due to {exception.__class__.__name__}" - ) - else: - self.logger.info("Scheduler stopped") - - self.event_broker.publish_local(SchedulerStopped(exception=exception)) + with ExitStack() as exit_stack: + # Run the async scheduler + self._ensure_services_ready(exit_stack) + self._portal.call(self._async_scheduler.run_until_stopped) diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py deleted file mode 100644 index 4ef1bff..0000000 --- a/src/apscheduler/workers/async_.py +++ /dev/null @@ -1,251 +0,0 @@ -from __future__ import annotations - -import os -import platform -from asyncio import CancelledError -from contextlib import AsyncExitStack -from datetime import datetime, timezone -from inspect import isawaitable -from logging import Logger, getLogger -from types import TracebackType -from typing import Callable -from uuid import UUID - -import anyio -import attrs -from anyio import ( - TASK_STATUS_IGNORED, - create_task_group, - get_cancelled_exc_class, - move_on_after, -) -from anyio.abc import CancelScope, TaskGroup - -from .._context import current_job, current_worker -from .._converters import as_async_datastore, as_async_eventbroker -from .._enums import JobOutcome, RunState -from .._events import JobAdded, JobReleased, WorkerStarted, WorkerStopped -from .._structures import Job, JobInfo, JobResult -from .._validators import positive_integer -from ..abc import AsyncDataStore, AsyncEventBroker -from ..eventbrokers.async_local import LocalAsyncEventBroker - - 
-@attrs.define(eq=False) -class AsyncWorker: - """Runs jobs locally in a task group.""" - - data_store: AsyncDataStore = attrs.field(converter=as_async_datastore) - event_broker: AsyncEventBroker = attrs.field( - converter=as_async_eventbroker, factory=LocalAsyncEventBroker - ) - max_concurrent_jobs: int = attrs.field( - kw_only=True, validator=positive_integer, default=100 - ) - identity: str = attrs.field(kw_only=True, default=None) - logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) - # True if a scheduler owns this worker - _is_internal: bool = attrs.field(kw_only=True, default=False) - - _state: RunState = attrs.field(init=False, default=RunState.stopped) - _wakeup_event: anyio.Event = attrs.field(init=False) - _task_group: TaskGroup = attrs.field(init=False) - _acquired_jobs: set[Job] = attrs.field(init=False, factory=set) - _running_jobs: set[UUID] = attrs.field(init=False, factory=set) - - def __attrs_post_init__(self) -> None: - if not self.identity: - self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" - - async def __aenter__(self) -> AsyncWorker: - self._task_group = create_task_group() - await self._task_group.__aenter__() - await self._task_group.start(self.run_until_stopped) - return self - - async def __aexit__( - self, - exc_type: type[BaseException], - exc_val: BaseException, - exc_tb: TracebackType, - ) -> None: - self._state = RunState.stopping - self._wakeup_event.set() - await self._task_group.__aexit__(exc_type, exc_val, exc_tb) - del self._task_group - del self._wakeup_event - - @property - def state(self) -> RunState: - """The current running state of the worker.""" - return self._state - - async def run_until_stopped(self, *, task_status=TASK_STATUS_IGNORED) -> None: - """Run the worker until it is explicitly stopped.""" - if self._state is not RunState.stopped: - raise RuntimeError( - f'Cannot start the worker when it is in the "{self._state}" ' f"state" - ) - - self._state = RunState.starting - 
self._wakeup_event = anyio.Event() - async with AsyncExitStack() as exit_stack: - if not self._is_internal: - # Initialize the event broker - await self.event_broker.start() - exit_stack.push_async_exit( - lambda *exc_info: self.event_broker.stop( - force=exc_info[0] is not None - ) - ) - - # Initialize the data store - await self.data_store.start(self.event_broker) - exit_stack.push_async_exit( - lambda *exc_info: self.data_store.stop( - force=exc_info[0] is not None - ) - ) - - # Set the current worker - token = current_worker.set(self) - exit_stack.callback(current_worker.reset, token) - - # Wake up the worker if the data store emits a significant job event - self.event_broker.subscribe( - lambda event: self._wakeup_event.set(), {JobAdded} - ) - - # Signal that the worker has started - self._state = RunState.started - task_status.started() - exception: BaseException | None = None - try: - await self.event_broker.publish_local(WorkerStarted()) - - async with create_task_group() as tg: - while self._state is RunState.started: - limit = self.max_concurrent_jobs - len(self._running_jobs) - jobs = await self.data_store.acquire_jobs(self.identity, limit) - for job in jobs: - task = await self.data_store.get_task(job.task_id) - self._running_jobs.add(job.id) - tg.start_soon(self._run_job, job, task.func) - - await self._wakeup_event.wait() - self._wakeup_event = anyio.Event() - except get_cancelled_exc_class(): - pass - except BaseException as exc: - exception = exc - raise - finally: - self._state = RunState.stopped - - # CancelledError is a subclass of Exception in Python 3.7 - if not exception or isinstance(exception, CancelledError): - self.logger.info("Worker stopped") - elif isinstance(exception, Exception): - self.logger.exception("Worker crashed") - elif exception: - self.logger.info( - f"Worker stopped due to {exception.__class__.__name__}" - ) - - with move_on_after(3, shield=True): - await self.event_broker.publish_local( - WorkerStopped(exception=exception) 
- ) - - async def stop(self, *, force: bool = False) -> None: - """ - Signal the worker that it should stop running jobs. - - This method does not wait for the worker to actually stop. - - """ - if self._state in (RunState.starting, RunState.started): - self._state = RunState.stopping - event = anyio.Event() - self.event_broker.subscribe( - lambda ev: event.set(), {WorkerStopped}, one_shot=True - ) - if force: - self._task_group.cancel_scope.cancel() - else: - self._wakeup_event.set() - - await event.wait() - - async def _run_job(self, job: Job, func: Callable) -> None: - try: - # Check if the job started before the deadline - start_time = datetime.now(timezone.utc) - if job.start_deadline is not None and start_time > job.start_deadline: - result = JobResult.from_job( - job, - outcome=JobOutcome.missed_start_deadline, - finished_at=start_time, - ) - await self.data_store.release_job(self.identity, job.task_id, result) - await self.event_broker.publish( - JobReleased.from_result(result, self.identity) - ) - return - - token = current_job.set(JobInfo.from_job(job)) - try: - retval = func(*job.args, **job.kwargs) - if isawaitable(retval): - retval = await retval - except get_cancelled_exc_class(): - self.logger.info("Job %s was cancelled", job.id) - with CancelScope(shield=True): - result = JobResult.from_job( - job, - outcome=JobOutcome.cancelled, - ) - await self.data_store.release_job( - self.identity, job.task_id, result - ) - await self.event_broker.publish( - JobReleased.from_result(result, self.identity) - ) - except BaseException as exc: - if isinstance(exc, Exception): - self.logger.exception("Job %s raised an exception", job.id) - else: - self.logger.error( - "Job %s was aborted due to %s", job.id, exc.__class__.__name__ - ) - - result = JobResult.from_job( - job, - JobOutcome.error, - exception=exc, - ) - await self.data_store.release_job( - self.identity, - job.task_id, - result, - ) - await self.event_broker.publish( - JobReleased.from_result(result, 
self.identity) - ) - if not isinstance(exc, Exception): - raise - else: - self.logger.info("Job %s completed successfully", job.id) - result = JobResult.from_job( - job, - JobOutcome.success, - return_value=retval, - ) - await self.data_store.release_job(self.identity, job.task_id, result) - await self.event_broker.publish( - JobReleased.from_result(result, self.identity) - ) - finally: - current_job.reset(token) - finally: - self._running_jobs.remove(job.id) diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py deleted file mode 100644 index 41279d9..0000000 --- a/src/apscheduler/workers/sync.py +++ /dev/null @@ -1,257 +0,0 @@ -from __future__ import annotations - -import atexit -import os -import platform -import threading -from concurrent.futures import Future, ThreadPoolExecutor -from contextlib import ExitStack -from contextvars import copy_context -from datetime import datetime, timezone -from logging import Logger, getLogger -from types import TracebackType -from typing import Callable -from uuid import UUID - -import attrs - -from .. 
import JobReleased -from .._context import current_job, current_worker -from .._enums import JobOutcome, RunState -from .._events import JobAdded, WorkerStarted, WorkerStopped -from .._structures import Job, JobInfo, JobResult -from .._validators import positive_integer -from ..abc import DataStore, EventBroker -from ..eventbrokers.local import LocalEventBroker - - -@attrs.define(eq=False) -class Worker: - """Runs jobs locally in a thread pool.""" - - data_store: DataStore - event_broker: EventBroker = attrs.field(factory=LocalEventBroker) - max_concurrent_jobs: int = attrs.field( - kw_only=True, validator=positive_integer, default=20 - ) - identity: str = attrs.field(kw_only=True, default=None) - logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) - # True if a scheduler owns this worker - _is_internal: bool = attrs.field(kw_only=True, default=False) - - _state: RunState = attrs.field(init=False, default=RunState.stopped) - _thread: threading.Thread | None = attrs.field(init=False, default=None) - _wakeup_event: threading.Event = attrs.field(init=False, factory=threading.Event) - _executor: ThreadPoolExecutor = attrs.field(init=False) - _acquired_jobs: set[Job] = attrs.field(init=False, factory=set) - _running_jobs: set[UUID] = attrs.field(init=False, factory=set) - - def __attrs_post_init__(self) -> None: - if not self.identity: - self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" - - def __enter__(self) -> Worker: - self.start_in_background() - return self - - def __exit__( - self, - exc_type: type[BaseException], - exc_val: BaseException, - exc_tb: TracebackType, - ) -> None: - self.stop() - - @property - def state(self) -> RunState: - """The current running state of the worker.""" - return self._state - - def start_in_background(self) -> None: - """ - Launch the worker in a new thread. - - This method registers an :mod:`atexit` hook to shut down the worker and wait - for the thread to finish. 
- - """ - start_future: Future[None] = Future() - self._thread = threading.Thread( - target=copy_context().run, args=[self._run, start_future], daemon=True - ) - self._thread.start() - try: - start_future.result() - except BaseException: - self._thread = None - raise - - atexit.register(self.stop) - - def stop(self) -> None: - """ - Signal the worker that it should stop running jobs. - - This method does not wait for the worker to actually stop. - - """ - atexit.unregister(self.stop) - if self._state is RunState.started: - self._state = RunState.stopping - self._wakeup_event.set() - - if threading.current_thread() != self._thread: - self._thread.join() - self._thread = None - - def run_until_stopped(self) -> None: - """ - Run the worker until it is explicitly stopped. - - This method will only return if :meth:`stop` is called. - - """ - self._run(None) - - def _run(self, start_future: Future[None] | None) -> None: - with ExitStack() as exit_stack: - try: - if self._state is not RunState.stopped: - raise RuntimeError( - f'Cannot start the worker when it is in the "{self._state}" ' - f"state" - ) - - if not self._is_internal: - # Initialize the event broker - self.event_broker.start() - exit_stack.push( - lambda *exc_info: self.event_broker.stop( - force=exc_info[0] is not None - ) - ) - - # Initialize the data store - self.data_store.start(self.event_broker) - exit_stack.push( - lambda *exc_info: self.data_store.stop( - force=exc_info[0] is not None - ) - ) - - # Set the current worker - token = current_worker.set(self) - exit_stack.callback(current_worker.reset, token) - - # Wake up the worker if the data store emits a significant job event - exit_stack.enter_context( - self.event_broker.subscribe( - lambda event: self._wakeup_event.set(), {JobAdded} - ) - ) - - # Initialize the thread pool - executor = ThreadPoolExecutor(max_workers=self.max_concurrent_jobs) - exit_stack.enter_context(executor) - - # Signal that the worker has started - self._state = 
RunState.started - self.event_broker.publish_local(WorkerStarted()) - except BaseException as exc: - if start_future: - start_future.set_exception(exc) - return - else: - raise - else: - if start_future: - start_future.set_result(None) - - exception: BaseException | None = None - try: - while self._state is RunState.started: - available_slots = self.max_concurrent_jobs - len(self._running_jobs) - if available_slots: - jobs = self.data_store.acquire_jobs( - self.identity, available_slots - ) - for job in jobs: - task = self.data_store.get_task(job.task_id) - self._running_jobs.add(job.id) - executor.submit( - copy_context().run, self._run_job, job, task.func - ) - - self._wakeup_event.wait() - self._wakeup_event = threading.Event() - except BaseException as exc: - exception = exc - raise - finally: - self._state = RunState.stopped - if not exception: - self.logger.info("Worker stopped") - elif isinstance(exception, Exception): - self.logger.exception("Worker crashed") - elif exception: - self.logger.info( - f"Worker stopped due to {exception.__class__.__name__}" - ) - - self.event_broker.publish_local(WorkerStopped(exception=exception)) - - def _run_job(self, job: Job, func: Callable) -> None: - try: - # Check if the job started before the deadline - start_time = datetime.now(timezone.utc) - if job.start_deadline is not None and start_time > job.start_deadline: - result = JobResult.from_job( - job, JobOutcome.missed_start_deadline, finished_at=start_time - ) - self.event_broker.publish( - JobReleased.from_result(result, self.identity) - ) - self.data_store.release_job(self.identity, job.task_id, result) - return - - token = current_job.set(JobInfo.from_job(job)) - try: - retval = func(*job.args, **job.kwargs) - except BaseException as exc: - if isinstance(exc, Exception): - self.logger.exception("Job %s raised an exception", job.id) - else: - self.logger.error( - "Job %s was aborted due to %s", job.id, exc.__class__.__name__ - ) - - result = JobResult.from_job( - 
job, - JobOutcome.error, - exception=exc, - ) - self.data_store.release_job( - self.identity, - job.task_id, - result, - ) - self.event_broker.publish( - JobReleased.from_result(result, self.identity) - ) - if not isinstance(exc, Exception): - raise - else: - self.logger.info("Job %s completed successfully", job.id) - result = JobResult.from_job( - job, - JobOutcome.success, - return_value=retval, - ) - self.data_store.release_job(self.identity, job.task_id, result) - self.event_broker.publish( - JobReleased.from_result(result, self.identity) - ) - finally: - current_job.reset(token) - finally: - self._running_jobs.remove(job.id) diff --git a/tests/conftest.py b/tests/conftest.py index 31ea9b0..7497b52 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,16 @@ from __future__ import annotations import sys +from contextlib import AsyncExitStack +from tempfile import TemporaryDirectory +from typing import Any, AsyncGenerator, cast import pytest +from _pytest.fixtures import SubRequest +from pytest_lazyfixture import lazy_fixture -from apscheduler.abc import Serializer +from apscheduler.abc import DataStore, EventBroker, Serializer +from apscheduler.datastores.memory import MemoryDataStore from apscheduler.serializers.cbor import CBORSerializer from apscheduler.serializers.json import JSONSerializer from apscheduler.serializers.pickle import PickleSerializer @@ -34,3 +40,189 @@ def serializer(request) -> Serializer | None: @pytest.fixture def anyio_backend() -> str: return "asyncio" + + +@pytest.fixture +def local_broker() -> EventBroker: + from apscheduler.eventbrokers.local import LocalEventBroker + + return LocalEventBroker() + + +@pytest.fixture +async def redis_broker(serializer: Serializer) -> EventBroker: + from apscheduler.eventbrokers.redis import RedisEventBroker + + broker = RedisEventBroker.from_url( + "redis://localhost:6379", serializer=serializer, stop_check_interval=0.05 + ) + await broker.client.flushdb() + return broker + + 
+@pytest.fixture +def mqtt_broker(serializer: Serializer) -> EventBroker: + from paho.mqtt.client import Client + + from apscheduler.eventbrokers.mqtt import MQTTEventBroker + + return MQTTEventBroker(Client(), serializer=serializer) + + +@pytest.fixture +async def asyncpg_broker(serializer: Serializer) -> EventBroker: + from apscheduler.eventbrokers.asyncpg import AsyncpgEventBroker + + broker = AsyncpgEventBroker.from_dsn( + "postgres://postgres:secret@localhost:5432/testdb", serializer=serializer + ) + return broker + + +@pytest.fixture( + params=[ + pytest.param(lazy_fixture("local_broker"), id="local"), + pytest.param( + lazy_fixture("asyncpg_broker"), + id="asyncpg", + marks=[pytest.mark.external_service], + ), + pytest.param( + lazy_fixture("redis_broker"), + id="redis", + marks=[pytest.mark.external_service], + ), + pytest.param( + lazy_fixture("mqtt_broker"), id="mqtt", marks=[pytest.mark.external_service] + ), + ] +) +async def raw_event_broker(request: SubRequest) -> EventBroker: + return cast(EventBroker, request.param) + + +@pytest.fixture +async def event_broker( + raw_event_broker: EventBroker, +) -> AsyncGenerator[EventBroker, Any]: + async with AsyncExitStack() as exit_stack: + await raw_event_broker.start(exit_stack) + yield raw_event_broker + + +@pytest.fixture +def memory_store() -> DataStore: + yield MemoryDataStore() + + +@pytest.fixture +def mongodb_store() -> DataStore: + from pymongo import MongoClient + + from apscheduler.datastores.mongodb import MongoDBDataStore + + with MongoClient(tz_aware=True, serverSelectionTimeoutMS=1000) as client: + yield MongoDBDataStore(client, start_from_scratch=True) + + +@pytest.fixture +def sqlite_store() -> DataStore: + from sqlalchemy.future import create_engine + + from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore + + with TemporaryDirectory("sqlite_") as tempdir: + engine = create_engine(f"sqlite:///{tempdir}/test.db") + try: + yield SQLAlchemyDataStore(engine) + finally: + 
engine.dispose() + + +@pytest.fixture +def psycopg2_store() -> DataStore: + from sqlalchemy.future import create_engine + + from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore + + engine = create_engine("postgresql+psycopg2://postgres:secret@localhost/testdb") + try: + yield SQLAlchemyDataStore(engine, schema="alter", start_from_scratch=True) + finally: + engine.dispose() + + +@pytest.fixture +def mysql_store() -> DataStore: + from sqlalchemy.future import create_engine + + from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore + + engine = create_engine("mysql+pymysql://root:secret@localhost/testdb") + try: + yield SQLAlchemyDataStore(engine, start_from_scratch=True) + finally: + engine.dispose() + + +@pytest.fixture +async def asyncpg_store() -> DataStore: + from sqlalchemy.ext.asyncio import create_async_engine + + from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore + + engine = create_async_engine( + "postgresql+asyncpg://postgres:secret@localhost/testdb", future=True + ) + try: + yield SQLAlchemyDataStore(engine, start_from_scratch=True) + finally: + await engine.dispose() + + +@pytest.fixture +async def asyncmy_store() -> DataStore: + from sqlalchemy.ext.asyncio import create_async_engine + + from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore + + engine = create_async_engine( + "mysql+asyncmy://root:secret@localhost/testdb?charset=utf8mb4", future=True + ) + try: + yield SQLAlchemyDataStore(engine, start_from_scratch=True) + finally: + await engine.dispose() + + +@pytest.fixture( + params=[ + pytest.param( + lazy_fixture("asyncpg_store"), + id="asyncpg", + marks=[pytest.mark.external_service], + ), + pytest.param( + lazy_fixture("asyncmy_store"), + id="asyncmy", + marks=[pytest.mark.external_service], + ), + pytest.param( + lazy_fixture("mongodb_store"), + id="mongodb", + marks=[pytest.mark.external_service], + ), + ] +) +async def raw_datastore(request: SubRequest) -> DataStore: + return cast(DataStore, 
request.param) + + +@pytest.fixture +async def datastore( + raw_datastore: DataStore, local_broker: EventBroker +) -> AsyncGenerator[DataStore, Any]: + async with AsyncExitStack() as exit_stack: + await local_broker.start(exit_stack) + await raw_datastore.start(exit_stack, local_broker) + yield raw_datastore diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 9c9a0cf..618369a 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -1,18 +1,13 @@ from __future__ import annotations -import threading -from collections.abc import Generator -from contextlib import asynccontextmanager, contextmanager +from contextlib import AsyncExitStack, asynccontextmanager from datetime import datetime, timedelta, timezone -from tempfile import TemporaryDirectory -from typing import Any, AsyncGenerator, cast +from typing import AsyncGenerator import anyio import pytest -from _pytest.fixtures import SubRequest from anyio import CancelScope from freezegun.api import FrozenDateTimeFactory -from pytest_lazyfixture import lazy_fixture from apscheduler import ( CoalescePolicy, @@ -30,1055 +25,483 @@ from apscheduler import ( TaskLookupError, TaskUpdated, ) -from apscheduler.abc import AsyncDataStore, AsyncEventBroker, DataStore, EventBroker -from apscheduler.datastores.async_adapter import AsyncDataStoreAdapter -from apscheduler.datastores.memory import MemoryDataStore -from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker -from apscheduler.eventbrokers.local import LocalEventBroker +from apscheduler.abc import DataStore, EventBroker from apscheduler.triggers.date import DateTrigger +pytestmark = pytest.mark.anyio + @pytest.fixture -def memory_store() -> DataStore: - yield MemoryDataStore() +def schedules() -> list[Schedule]: + trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) + schedule1 = Schedule(id="s1", task_id="task1", trigger=trigger) + schedule1.next_fire_time = trigger.next() + trigger = DateTrigger(datetime(2020, 
9, 14, tzinfo=timezone.utc)) + schedule2 = Schedule(id="s2", task_id="task2", trigger=trigger) + schedule2.next_fire_time = trigger.next() -@pytest.fixture -def adapted_memory_store() -> AsyncDataStore: - store = MemoryDataStore() - return AsyncDataStoreAdapter(store) + trigger = DateTrigger(datetime(2020, 9, 15, tzinfo=timezone.utc)) + schedule3 = Schedule(id="s3", task_id="task1", trigger=trigger) + return [schedule1, schedule2, schedule3] -@pytest.fixture -def mongodb_store() -> DataStore: - from pymongo import MongoClient +@asynccontextmanager +async def capture_events( + datastore: DataStore, + limit: int, + event_types: set[type[Event]] | None = None, +) -> AsyncGenerator[list[Event], None]: + def listener(event: Event) -> None: + events.append(event) + if len(events) == limit: + limit_event.set() + subscription.unsubscribe() + + events: list[Event] = [] + limit_event = anyio.Event() + subscription = datastore._event_broker.subscribe(listener, event_types) + yield events + if limit: + with anyio.fail_after(3): + await limit_event.wait() + + +async def test_add_replace_task(datastore: DataStore) -> None: + import math + + event_types = {TaskAdded, TaskUpdated} + async with capture_events(datastore, 3, event_types) as events: + await datastore.add_task(Task(id="test_task", func=print, executor="async")) + await datastore.add_task( + Task(id="test_task2", func=math.ceil, executor="async") + ) + await datastore.add_task(Task(id="test_task", func=repr, executor="async")) - from apscheduler.datastores.mongodb import MongoDBDataStore + tasks = await datastore.get_tasks() + assert len(tasks) == 2 + assert tasks[0].id == "test_task" + assert tasks[0].func is repr + assert tasks[1].id == "test_task2" + assert tasks[1].func is math.ceil - with MongoClient(tz_aware=True, serverSelectionTimeoutMS=1000) as client: - yield MongoDBDataStore(client, start_from_scratch=True) + received_event = events.pop(0) + assert isinstance(received_event, TaskAdded) + assert 
received_event.task_id == "test_task" + received_event = events.pop(0) + assert isinstance(received_event, TaskAdded) + assert received_event.task_id == "test_task2" -@pytest.fixture -def sqlite_store() -> DataStore: - from sqlalchemy.future import create_engine + received_event = events.pop(0) + assert isinstance(received_event, TaskUpdated) + assert received_event.task_id == "test_task" - from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore + assert not events - with TemporaryDirectory("sqlite_") as tempdir: - engine = create_engine(f"sqlite:///{tempdir}/test.db") - try: - yield SQLAlchemyDataStore(engine) - finally: - engine.dispose() +async def test_add_schedules(datastore: DataStore, schedules: list[Schedule]) -> None: + async with capture_events(datastore, 3, {ScheduleAdded}) as events: + for schedule in schedules: + await datastore.add_schedule(schedule, ConflictPolicy.exception) -@pytest.fixture -def psycopg2_store() -> DataStore: - from sqlalchemy.future import create_engine + assert await datastore.get_schedules() == schedules + assert await datastore.get_schedules({"s1", "s2", "s3"}) == schedules + assert await datastore.get_schedules({"s1"}) == [schedules[0]] + assert await datastore.get_schedules({"s2"}) == [schedules[1]] + assert await datastore.get_schedules({"s3"}) == [schedules[2]] - from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore + for event, schedule in zip(events, schedules): + assert event.schedule_id == schedule.id + assert event.next_fire_time == schedule.next_fire_time - engine = create_engine("postgresql+psycopg2://postgres:secret@localhost/testdb") - try: - yield SQLAlchemyDataStore(engine, schema="alter", start_from_scratch=True) - finally: - engine.dispose() +async def test_replace_schedules( + datastore: DataStore, schedules: list[Schedule] +) -> None: + async with capture_events(datastore, 1, {ScheduleUpdated}) as events: + for schedule in schedules: + await datastore.add_schedule(schedule, 
ConflictPolicy.exception) -@pytest.fixture -def mysql_store() -> DataStore: - from sqlalchemy.future import create_engine + next_fire_time = schedules[2].trigger.next() + schedule = Schedule( + id="s3", + task_id="foo", + trigger=schedules[2].trigger, + args=(), + kwargs={}, + coalesce=CoalescePolicy.earliest, + misfire_grace_time=None, + tags=frozenset(), + ) + schedule.next_fire_time = next_fire_time + await datastore.add_schedule(schedule, ConflictPolicy.replace) + + schedules = await datastore.get_schedules({schedule.id}) + assert schedules[0].task_id == "foo" + assert schedules[0].next_fire_time == next_fire_time + assert schedules[0].args == () + assert schedules[0].kwargs == {} + assert schedules[0].coalesce is CoalescePolicy.earliest + assert schedules[0].misfire_grace_time is None + assert schedules[0].tags == frozenset() + + received_event = events.pop(0) + assert received_event.schedule_id == "s3" + assert received_event.next_fire_time == datetime(2020, 9, 15, tzinfo=timezone.utc) + assert not events + + +async def test_remove_schedules( + datastore: DataStore, schedules: list[Schedule] +) -> None: + async with capture_events(datastore, 2, {ScheduleRemoved}) as events: + for schedule in schedules: + await datastore.add_schedule(schedule, ConflictPolicy.exception) - from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore + await datastore.remove_schedules(["s1", "s2"]) + assert await datastore.get_schedules() == [schedules[2]] - engine = create_engine("mysql+pymysql://root:secret@localhost/testdb") - try: - yield SQLAlchemyDataStore(engine, start_from_scratch=True) - finally: - engine.dispose() + received_event = events.pop(0) + assert received_event.schedule_id == "s1" + received_event = events.pop(0) + assert received_event.schedule_id == "s2" -@pytest.fixture -async def asyncpg_store() -> AsyncDataStore: - from sqlalchemy.ext.asyncio import create_async_engine + assert not events - from apscheduler.datastores.async_sqlalchemy import 
AsyncSQLAlchemyDataStore - engine = create_async_engine( - "postgresql+asyncpg://postgres:secret@localhost/testdb", future=True +@pytest.mark.freeze_time(datetime(2020, 9, 14, tzinfo=timezone.utc)) +async def test_acquire_release_schedules( + datastore: DataStore, schedules: list[Schedule] +) -> None: + event_types = {ScheduleRemoved, ScheduleUpdated} + async with capture_events(datastore, 2, event_types) as events: + for schedule in schedules: + await datastore.add_schedule(schedule, ConflictPolicy.exception) + + # The first scheduler gets the first due schedule + schedules1 = await datastore.acquire_schedules("dummy-id1", 1) + assert len(schedules1) == 1 + assert schedules1[0].id == "s1" + + # The second scheduler gets the second due schedule + schedules2 = await datastore.acquire_schedules("dummy-id2", 1) + assert len(schedules2) == 1 + assert schedules2[0].id == "s2" + + # The third scheduler gets nothing + schedules3 = await datastore.acquire_schedules("dummy-id3", 1) + assert not schedules3 + + # Update the schedules and check that the job store actually deletes the + # first one and updates the second one + schedules1[0].next_fire_time = None + schedules2[0].next_fire_time = datetime(2020, 9, 15, tzinfo=timezone.utc) + + # Release all the schedules + await datastore.release_schedules("dummy-id1", schedules1) + await datastore.release_schedules("dummy-id2", schedules2) + + # Check that the first schedule is gone + schedules = await datastore.get_schedules() + assert len(schedules) == 2 + assert schedules[0].id == "s2" + assert schedules[1].id == "s3" + + # Check for the appropriate update and delete events + received_event = events.pop(0) + assert isinstance(received_event, ScheduleRemoved) + assert received_event.schedule_id == "s1" + + received_event = events.pop(0) + assert isinstance(received_event, ScheduleUpdated) + assert received_event.schedule_id == "s2" + assert received_event.next_fire_time == datetime(2020, 9, 15, tzinfo=timezone.utc) + + assert 
not events + + +async def test_release_schedule_two_identical_fire_times(datastore: DataStore) -> None: + """Regression test for #616.""" + for i in range(1, 3): + trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) + schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger) + schedule.next_fire_time = trigger.next() + await datastore.add_schedule(schedule, ConflictPolicy.exception) + + schedules = await datastore.acquire_schedules("foo", 3) + schedules[0].next_fire_time = None + await datastore.release_schedules("foo", schedules) + + remaining = await datastore.get_schedules({s.id for s in schedules}) + assert len(remaining) == 1 + assert remaining[0].id == schedules[1].id + + +async def test_release_two_schedules_at_once(datastore: DataStore) -> None: + """Regression test for #621.""" + for i in range(2): + trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) + schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger) + schedule.next_fire_time = trigger.next() + await datastore.add_schedule(schedule, ConflictPolicy.exception) + + schedules = await datastore.acquire_schedules("foo", 3) + await datastore.release_schedules("foo", schedules) + + remaining = await datastore.get_schedules({s.id for s in schedules}) + assert len(remaining) == 2 + + +async def test_acquire_schedules_lock_timeout( + datastore: DataStore, schedules: list[Schedule], freezer +) -> None: + """ + Test that a scheduler can acquire schedules that were acquired by another + scheduler but not released within the lock timeout period. + + """ + await datastore.add_schedule(schedules[0], ConflictPolicy.exception) + + # First, one scheduler acquires the first available schedule + acquired1 = await datastore.acquire_schedules("dummy-id1", 1) + assert len(acquired1) == 1 + assert acquired1[0].id == "s1" + + # Try to acquire the schedule just at the threshold (now == acquired_until). + # This should not yield any schedules. 
+ freezer.tick(30) + acquired2 = await datastore.acquire_schedules("dummy-id2", 1) + assert not acquired2 + + # Right after that, the schedule should be available + freezer.tick(1) + acquired3 = await datastore.acquire_schedules("dummy-id2", 1) + assert len(acquired3) == 1 + assert acquired3[0].id == "s1" + + +async def test_acquire_multiple_workers(datastore: DataStore) -> None: + await datastore.add_task( + Task(id="task1", func=asynccontextmanager, executor="async") ) - try: - yield AsyncSQLAlchemyDataStore(engine, start_from_scratch=True) - finally: - await engine.dispose() + jobs = [Job(task_id="task1") for _ in range(2)] + for job in jobs: + await datastore.add_job(job) + # The first worker gets the first job in the queue + jobs1 = await datastore.acquire_jobs("worker1", 1) + assert len(jobs1) == 1 + assert jobs1[0].id == jobs[0].id -@pytest.fixture -async def asyncmy_store() -> AsyncDataStore: - from sqlalchemy.ext.asyncio import create_async_engine + # The second worker gets the second job + jobs2 = await datastore.acquire_jobs("worker2", 1) + assert len(jobs2) == 1 + assert jobs2[0].id == jobs[1].id + + # The third worker gets nothing + jobs3 = await datastore.acquire_jobs("worker3", 1) + assert not jobs3 - from apscheduler.datastores.async_sqlalchemy import AsyncSQLAlchemyDataStore - engine = create_async_engine( - "mysql+asyncmy://root:secret@localhost/testdb?charset=utf8mb4", future=True +async def test_job_release_success(datastore: DataStore) -> None: + await datastore.add_task( + Task(id="task1", func=asynccontextmanager, executor="async") ) - try: - yield AsyncSQLAlchemyDataStore(engine, start_from_scratch=True) - finally: - await engine.dispose() + job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) + await datastore.add_job(job) + + acquired = await datastore.acquire_jobs("worker_id", 2) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + await datastore.release_job( + "worker_id", + acquired[0].task_id, + 
JobResult.from_job( + acquired[0], + JobOutcome.success, + return_value="foo", + ), + ) + result = await datastore.get_job_result(acquired[0].id) + assert result.outcome is JobOutcome.success + assert result.exception is None + assert result.return_value == "foo" + # Check that the job and its result are gone + assert not await datastore.get_jobs({acquired[0].id}) + assert not await datastore.get_job_result(acquired[0].id) -@pytest.fixture -def schedules() -> list[Schedule]: - trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) - schedule1 = Schedule(id="s1", task_id="task1", trigger=trigger) - schedule1.next_fire_time = trigger.next() - trigger = DateTrigger(datetime(2020, 9, 14, tzinfo=timezone.utc)) - schedule2 = Schedule(id="s2", task_id="task2", trigger=trigger) - schedule2.next_fire_time = trigger.next() +async def test_job_release_failure(datastore: DataStore) -> None: + await datastore.add_task( + Task(id="task1", executor="async", func=asynccontextmanager) + ) + job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) + await datastore.add_job(job) + + acquired = await datastore.acquire_jobs("worker_id", 2) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + await datastore.release_job( + "worker_id", + acquired[0].task_id, + JobResult.from_job( + acquired[0], + JobOutcome.error, + exception=ValueError("foo"), + ), + ) + result = await datastore.get_job_result(acquired[0].id) + assert result.outcome is JobOutcome.error + assert isinstance(result.exception, ValueError) + assert result.exception.args == ("foo",) + assert result.return_value is None - trigger = DateTrigger(datetime(2020, 9, 15, tzinfo=timezone.utc)) - schedule3 = Schedule(id="s3", task_id="task1", trigger=trigger) - return [schedule1, schedule2, schedule3] + # Check that the job and its result are gone + assert not await datastore.get_jobs({acquired[0].id}) + assert not await datastore.get_job_result(acquired[0].id) -class TestDataStores: - 
@contextmanager - def capture_events( - self, - datastore: DataStore, - limit: int, - event_types: set[type[Event]] | None = None, - ) -> Generator[list[Event], None, None]: - def listener(event: Event) -> None: - events.append(event) - if len(events) == limit: - limit_event.set() - subscription.unsubscribe() - - events: list[Event] = [] - limit_event = threading.Event() - subscription = datastore.events.subscribe(listener, event_types) - yield events - if limit: - limit_event.wait(2) - - @pytest.fixture - def event_broker(self) -> Generator[EventBroker, Any, None]: - broker = LocalEventBroker() - broker.start() - yield broker - broker.stop() - - @pytest.fixture( - params=[ - pytest.param(lazy_fixture("memory_store"), id="memory"), - pytest.param(lazy_fixture("sqlite_store"), id="sqlite"), - pytest.param( - lazy_fixture("mongodb_store"), - id="mongodb", - marks=[pytest.mark.external_service], - ), - pytest.param( - lazy_fixture("psycopg2_store"), - id="psycopg2", - marks=[pytest.mark.external_service], - ), - pytest.param( - lazy_fixture("mysql_store"), - id="mysql", - marks=[pytest.mark.external_service], - ), - ] +async def test_job_release_missed_deadline(datastore: DataStore): + await datastore.add_task( + Task(id="task1", func=asynccontextmanager, executor="async") ) - def datastore( - self, request: SubRequest, event_broker: EventBroker - ) -> Generator[DataStore, Any, None]: - datastore = cast(DataStore, request.param) - datastore.start(event_broker) - yield datastore - datastore.stop() - - def test_add_replace_task(self, datastore: DataStore) -> None: - import math - - event_types = {TaskAdded, TaskUpdated} - with self.capture_events(datastore, 3, event_types) as events: - datastore.add_task(Task(id="test_task", func=print)) - datastore.add_task(Task(id="test_task2", func=math.ceil)) - datastore.add_task(Task(id="test_task", func=repr)) - - tasks = datastore.get_tasks() - assert len(tasks) == 2 - assert tasks[0].id == "test_task" - assert tasks[0].func is 
repr - assert tasks[1].id == "test_task2" - assert tasks[1].func is math.ceil - - received_event = events.pop(0) - assert isinstance(received_event, TaskAdded) - assert received_event.task_id == "test_task" - - received_event = events.pop(0) - assert isinstance(received_event, TaskAdded) - assert received_event.task_id == "test_task2" - - received_event = events.pop(0) - assert isinstance(received_event, TaskUpdated) - assert received_event.task_id == "test_task" - - assert not events - - def test_add_schedules( - self, datastore: DataStore, schedules: list[Schedule] - ) -> None: - with self.capture_events(datastore, 3, {ScheduleAdded}) as events: - for schedule in schedules: - datastore.add_schedule(schedule, ConflictPolicy.exception) - - assert datastore.get_schedules() == schedules - assert datastore.get_schedules({"s1", "s2", "s3"}) == schedules - assert datastore.get_schedules({"s1"}) == [schedules[0]] - assert datastore.get_schedules({"s2"}) == [schedules[1]] - assert datastore.get_schedules({"s3"}) == [schedules[2]] - - for event, schedule in zip(events, schedules): - assert event.schedule_id == schedule.id - assert event.next_fire_time == schedule.next_fire_time - - def test_replace_schedules( - self, datastore: DataStore, schedules: list[Schedule] - ) -> None: - with self.capture_events(datastore, 1, {ScheduleUpdated}) as events: - for schedule in schedules: - datastore.add_schedule(schedule, ConflictPolicy.exception) - - next_fire_time = schedules[2].trigger.next() - schedule = Schedule( - id="s3", - task_id="foo", - trigger=schedules[2].trigger, - args=(), - kwargs={}, - coalesce=CoalescePolicy.earliest, - misfire_grace_time=None, - tags=frozenset(), - ) - schedule.next_fire_time = next_fire_time - datastore.add_schedule(schedule, ConflictPolicy.replace) - - schedules = datastore.get_schedules({schedule.id}) - assert schedules[0].task_id == "foo" - assert schedules[0].next_fire_time == next_fire_time - assert schedules[0].args == () - assert 
schedules[0].kwargs == {} - assert schedules[0].coalesce is CoalescePolicy.earliest - assert schedules[0].misfire_grace_time is None - assert schedules[0].tags == frozenset() - - received_event = events.pop(0) - assert received_event.schedule_id == "s3" - assert received_event.next_fire_time == datetime( - 2020, 9, 15, tzinfo=timezone.utc - ) - assert not events - - def test_remove_schedules( - self, datastore: DataStore, schedules: list[Schedule] - ) -> None: - with self.capture_events(datastore, 2, {ScheduleRemoved}) as events: - for schedule in schedules: - datastore.add_schedule(schedule, ConflictPolicy.exception) - - datastore.remove_schedules(["s1", "s2"]) - assert datastore.get_schedules() == [schedules[2]] - - received_event = events.pop(0) - assert received_event.schedule_id == "s1" - - received_event = events.pop(0) - assert received_event.schedule_id == "s2" - - assert not events - - @pytest.mark.freeze_time(datetime(2020, 9, 14, tzinfo=timezone.utc)) - def test_acquire_release_schedules( - self, datastore: DataStore, schedules: list[Schedule] - ) -> None: - event_types = {ScheduleRemoved, ScheduleUpdated} - with self.capture_events(datastore, 2, event_types) as events: - for schedule in schedules: - datastore.add_schedule(schedule, ConflictPolicy.exception) - - # The first scheduler gets the first due schedule - schedules1 = datastore.acquire_schedules("dummy-id1", 1) - assert len(schedules1) == 1 - assert schedules1[0].id == "s1" - - # The second scheduler gets the second due schedule - schedules2 = datastore.acquire_schedules("dummy-id2", 1) - assert len(schedules2) == 1 - assert schedules2[0].id == "s2" - - # The third scheduler gets nothing - schedules3 = datastore.acquire_schedules("dummy-id3", 1) - assert not schedules3 - - # Update the schedules and check that the job store actually deletes the - # first one and updates the second one - schedules1[0].next_fire_time = None - schedules2[0].next_fire_time = datetime(2020, 9, 15, tzinfo=timezone.utc) 
- - # Release all the schedules - datastore.release_schedules("dummy-id1", schedules1) - datastore.release_schedules("dummy-id2", schedules2) - - # Check that the first schedule is gone - schedules = datastore.get_schedules() - assert len(schedules) == 2 - assert schedules[0].id == "s2" - assert schedules[1].id == "s3" - - # Check for the appropriate update and delete events - received_event = events.pop(0) - assert isinstance(received_event, ScheduleRemoved) - assert received_event.schedule_id == "s1" - - received_event = events.pop(0) - assert isinstance(received_event, ScheduleUpdated) - assert received_event.schedule_id == "s2" - assert received_event.next_fire_time == datetime( - 2020, 9, 15, tzinfo=timezone.utc - ) + job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) + await datastore.add_job(job) + + acquired = await datastore.acquire_jobs("worker_id", 2) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + await datastore.release_job( + "worker_id", + acquired[0].task_id, + JobResult.from_job( + acquired[0], + JobOutcome.missed_start_deadline, + ), + ) + result = await datastore.get_job_result(acquired[0].id) + assert result.outcome is JobOutcome.missed_start_deadline + assert result.exception is None + assert result.return_value is None - assert not events - - def test_release_schedule_two_identical_fire_times( - self, datastore: DataStore - ) -> None: - """Regression test for #616.""" - for i in range(1, 3): - trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) - schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger) - schedule.next_fire_time = trigger.next() - datastore.add_schedule(schedule, ConflictPolicy.exception) - - schedules = datastore.acquire_schedules("foo", 3) - schedules[0].next_fire_time = None - datastore.release_schedules("foo", schedules) - - remaining = datastore.get_schedules({s.id for s in schedules}) - assert len(remaining) == 1 - assert remaining[0].id == schedules[1].id - - 
def test_release_two_schedules_at_once(self, datastore: DataStore) -> None: - """Regression test for #621.""" - for i in range(2): - trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) - schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger) - schedule.next_fire_time = trigger.next() - datastore.add_schedule(schedule, ConflictPolicy.exception) - - schedules = datastore.acquire_schedules("foo", 3) - datastore.release_schedules("foo", schedules) - - remaining = datastore.get_schedules({s.id for s in schedules}) - assert len(remaining) == 2 - - def test_acquire_schedules_lock_timeout( - self, datastore: DataStore, schedules: list[Schedule], freezer - ) -> None: - """ - Test that a scheduler can acquire schedules that were acquired by another - scheduler but not released within the lock timeout period. - - """ - datastore.add_schedule(schedules[0], ConflictPolicy.exception) - - # First, one scheduler acquires the first available schedule - acquired1 = datastore.acquire_schedules("dummy-id1", 1) - assert len(acquired1) == 1 - assert acquired1[0].id == "s1" - - # Try to acquire the schedule just at the threshold (now == acquired_until). - # This should not yield any schedules. 
- freezer.tick(30) - acquired2 = datastore.acquire_schedules("dummy-id2", 1) - assert not acquired2 - - # Right after that, the schedule should be available - freezer.tick(1) - acquired3 = datastore.acquire_schedules("dummy-id2", 1) - assert len(acquired3) == 1 - assert acquired3[0].id == "s1" - - def test_acquire_multiple_workers(self, datastore: DataStore) -> None: - datastore.add_task(Task(id="task1", func=asynccontextmanager)) - jobs = [Job(task_id="task1") for _ in range(2)] - for job in jobs: - datastore.add_job(job) - - # The first worker gets the first job in the queue - jobs1 = datastore.acquire_jobs("worker1", 1) - assert len(jobs1) == 1 - assert jobs1[0].id == jobs[0].id - - # The second worker gets the second job - jobs2 = datastore.acquire_jobs("worker2", 1) - assert len(jobs2) == 1 - assert jobs2[0].id == jobs[1].id - - # The third worker gets nothing - jobs3 = datastore.acquire_jobs("worker3", 1) - assert not jobs3 - - def test_job_release_success(self, datastore: DataStore) -> None: - datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) - datastore.add_job(job) - - acquired = datastore.acquire_jobs("worker_id", 2) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - datastore.release_job( - "worker_id", - acquired[0].task_id, - JobResult.from_job( - acquired[0], - JobOutcome.success, - return_value="foo", - ), - ) - result = datastore.get_job_result(acquired[0].id) - assert result.outcome is JobOutcome.success - assert result.exception is None - assert result.return_value == "foo" - - # Check that the job and its result are gone - assert not datastore.get_jobs({acquired[0].id}) - assert not datastore.get_job_result(acquired[0].id) - - def test_job_release_failure(self, datastore: DataStore) -> None: - datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) - 
datastore.add_job(job) - - acquired = datastore.acquire_jobs("worker_id", 2) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - datastore.release_job( - "worker_id", - acquired[0].task_id, - JobResult.from_job( - acquired[0], - JobOutcome.error, - exception=ValueError("foo"), - ), - ) - result = datastore.get_job_result(acquired[0].id) - assert result.outcome is JobOutcome.error - assert isinstance(result.exception, ValueError) - assert result.exception.args == ("foo",) - assert result.return_value is None - - # Check that the job and its result are gone - assert not datastore.get_jobs({acquired[0].id}) - assert not datastore.get_job_result(acquired[0].id) - - def test_job_release_missed_deadline(self, datastore: DataStore): - datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) - datastore.add_job(job) - - acquired = datastore.acquire_jobs("worker_id", 2) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - datastore.release_job( - "worker_id", - acquired[0].task_id, - JobResult.from_job( - acquired[0], - JobOutcome.missed_start_deadline, - ), - ) - result = datastore.get_job_result(acquired[0].id) - assert result.outcome is JobOutcome.missed_start_deadline - assert result.exception is None - assert result.return_value is None - - # Check that the job and its result are gone - assert not datastore.get_jobs({acquired[0].id}) - assert not datastore.get_job_result(acquired[0].id) - - def test_job_release_cancelled(self, datastore: DataStore) -> None: - datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) - datastore.add_job(job) - - acquired = datastore.acquire_jobs("worker1", 2) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - datastore.release_job( - "worker1", - acquired[0].task_id, - JobResult.from_job( - acquired[0], - JobOutcome.cancelled, - ), - ) - 
result = datastore.get_job_result(acquired[0].id) - assert result.outcome is JobOutcome.cancelled - assert result.exception is None - assert result.return_value is None - - # Check that the job and its result are gone - assert not datastore.get_jobs({acquired[0].id}) - assert not datastore.get_job_result(acquired[0].id) - - def test_acquire_jobs_lock_timeout( - self, datastore: DataStore, freezer: FrozenDateTimeFactory - ) -> None: - """ - Test that a worker can acquire jobs that were acquired by another scheduler but - not released within the lock timeout period. - - """ - datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1") - datastore.add_job(job) - - # First, one worker acquires the first available job - acquired = datastore.acquire_jobs("worker1", 1) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - # Try to acquire the job just at the threshold (now == acquired_until). - # This should not yield any jobs. - freezer.tick(30) - assert not datastore.acquire_jobs("worker2", 1) - - # Right after that, the job should be available - freezer.tick(1) - acquired = datastore.acquire_jobs("worker2", 1) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - def test_acquire_jobs_max_number_exceeded(self, datastore: DataStore) -> None: - datastore.add_task( - Task(id="task1", func=asynccontextmanager, max_running_jobs=2) - ) - jobs = [Job(task_id="task1"), Job(task_id="task1"), Job(task_id="task1")] - for job in jobs: - datastore.add_job(job) - - # Check that only 2 jobs are returned from acquire_jobs() even though the limit - # wqas 3 - acquired_jobs = datastore.acquire_jobs("worker1", 3) - assert [job.id for job in acquired_jobs] == [job.id for job in jobs[:2]] - - # Release one job, and the worker should be able to acquire the third job - datastore.release_job( - "worker1", - acquired_jobs[0].task_id, - JobResult.from_job( - acquired_jobs[0], - JobOutcome.success, - return_value=None, - ), - ) - 
acquired_jobs = datastore.acquire_jobs("worker1", 3) - assert [job.id for job in acquired_jobs] == [jobs[2].id] - - def test_add_get_task(self, datastore: DataStore) -> None: - with pytest.raises(TaskLookupError): - datastore.get_task("dummyid") - - datastore.add_task(Task(id="dummyid", func=asynccontextmanager)) - task = datastore.get_task("dummyid") - assert task.id == "dummyid" - assert task.func is asynccontextmanager - - -@pytest.mark.anyio -class TestAsyncDataStores: - @asynccontextmanager - async def capture_events( - self, - datastore: AsyncDataStore, - limit: int, - event_types: set[type[Event]] | None = None, - ) -> AsyncGenerator[list[Event], None]: - def listener(event: Event) -> None: - events.append(event) - if len(events) == limit: - limit_event.set() - subscription.unsubscribe() - - events: list[Event] = [] - limit_event = anyio.Event() - subscription = datastore.events.subscribe(listener, event_types) - yield events - if limit: - with anyio.fail_after(3): - await limit_event.wait() - - @pytest.fixture - async def event_broker(self) -> AsyncGenerator[AsyncEventBroker, Any]: - broker = LocalAsyncEventBroker() - await broker.start() - yield broker - await broker.stop() - - @pytest.fixture( - params=[ - pytest.param(lazy_fixture("adapted_memory_store"), id="memory"), - pytest.param( - lazy_fixture("asyncpg_store"), - id="asyncpg", - marks=[pytest.mark.external_service], - ), - pytest.param( - lazy_fixture("asyncmy_store"), - id="asyncmy", - marks=[pytest.mark.external_service], - ), - ] + # Check that the job and its result are gone + assert not await datastore.get_jobs({acquired[0].id}) + assert not await datastore.get_job_result(acquired[0].id) + + +async def test_job_release_cancelled(datastore: DataStore) -> None: + await datastore.add_task( + Task(id="task1", func=asynccontextmanager, executor="async") ) - async def raw_datastore( - self, request: SubRequest, event_broker: AsyncEventBroker - ) -> AsyncDataStore: - return cast(AsyncDataStore, 
request.param) - - @pytest.fixture - async def datastore( - self, raw_datastore: AsyncDataStore, event_broker: AsyncEventBroker - ) -> AsyncGenerator[AsyncDataStore, Any]: - await raw_datastore.start(event_broker) - yield raw_datastore - await raw_datastore.stop() - - async def test_add_replace_task(self, datastore: AsyncDataStore) -> None: - import math - - event_types = {TaskAdded, TaskUpdated} - async with self.capture_events(datastore, 3, event_types) as events: - await datastore.add_task(Task(id="test_task", func=print)) - await datastore.add_task(Task(id="test_task2", func=math.ceil)) - await datastore.add_task(Task(id="test_task", func=repr)) - - tasks = await datastore.get_tasks() - assert len(tasks) == 2 - assert tasks[0].id == "test_task" - assert tasks[0].func is repr - assert tasks[1].id == "test_task2" - assert tasks[1].func is math.ceil - - received_event = events.pop(0) - assert isinstance(received_event, TaskAdded) - assert received_event.task_id == "test_task" - - received_event = events.pop(0) - assert isinstance(received_event, TaskAdded) - assert received_event.task_id == "test_task2" - - received_event = events.pop(0) - assert isinstance(received_event, TaskUpdated) - assert received_event.task_id == "test_task" - - assert not events - - async def test_add_schedules( - self, datastore: AsyncDataStore, schedules: list[Schedule] - ) -> None: - async with self.capture_events(datastore, 3, {ScheduleAdded}) as events: - for schedule in schedules: - await datastore.add_schedule(schedule, ConflictPolicy.exception) - - assert await datastore.get_schedules() == schedules - assert await datastore.get_schedules({"s1", "s2", "s3"}) == schedules - assert await datastore.get_schedules({"s1"}) == [schedules[0]] - assert await datastore.get_schedules({"s2"}) == [schedules[1]] - assert await datastore.get_schedules({"s3"}) == [schedules[2]] - - for event, schedule in zip(events, schedules): - assert event.schedule_id == schedule.id - assert event.next_fire_time 
== schedule.next_fire_time - - async def test_replace_schedules( - self, datastore: AsyncDataStore, schedules: list[Schedule] - ) -> None: - async with self.capture_events(datastore, 1, {ScheduleUpdated}) as events: - for schedule in schedules: - await datastore.add_schedule(schedule, ConflictPolicy.exception) - - next_fire_time = schedules[2].trigger.next() - schedule = Schedule( - id="s3", - task_id="foo", - trigger=schedules[2].trigger, - args=(), - kwargs={}, - coalesce=CoalescePolicy.earliest, - misfire_grace_time=None, - tags=frozenset(), - ) - schedule.next_fire_time = next_fire_time - await datastore.add_schedule(schedule, ConflictPolicy.replace) - - schedules = await datastore.get_schedules({schedule.id}) - assert schedules[0].task_id == "foo" - assert schedules[0].next_fire_time == next_fire_time - assert schedules[0].args == () - assert schedules[0].kwargs == {} - assert schedules[0].coalesce is CoalescePolicy.earliest - assert schedules[0].misfire_grace_time is None - assert schedules[0].tags == frozenset() - - received_event = events.pop(0) - assert received_event.schedule_id == "s3" - assert received_event.next_fire_time == datetime( - 2020, 9, 15, tzinfo=timezone.utc - ) - assert not events - - async def test_remove_schedules( - self, datastore: AsyncDataStore, schedules: list[Schedule] - ) -> None: - async with self.capture_events(datastore, 2, {ScheduleRemoved}) as events: - for schedule in schedules: - await datastore.add_schedule(schedule, ConflictPolicy.exception) - - await datastore.remove_schedules(["s1", "s2"]) - assert await datastore.get_schedules() == [schedules[2]] - - received_event = events.pop(0) - assert received_event.schedule_id == "s1" - - received_event = events.pop(0) - assert received_event.schedule_id == "s2" - - assert not events - - @pytest.mark.freeze_time(datetime(2020, 9, 14, tzinfo=timezone.utc)) - async def test_acquire_release_schedules( - self, datastore: AsyncDataStore, schedules: list[Schedule] - ) -> None: - 
event_types = {ScheduleRemoved, ScheduleUpdated} - async with self.capture_events(datastore, 2, event_types) as events: - for schedule in schedules: - await datastore.add_schedule(schedule, ConflictPolicy.exception) - - # The first scheduler gets the first due schedule - schedules1 = await datastore.acquire_schedules("dummy-id1", 1) - assert len(schedules1) == 1 - assert schedules1[0].id == "s1" - - # The second scheduler gets the second due schedule - schedules2 = await datastore.acquire_schedules("dummy-id2", 1) - assert len(schedules2) == 1 - assert schedules2[0].id == "s2" - - # The third scheduler gets nothing - schedules3 = await datastore.acquire_schedules("dummy-id3", 1) - assert not schedules3 - - # Update the schedules and check that the job store actually deletes the - # first one and updates the second one - schedules1[0].next_fire_time = None - schedules2[0].next_fire_time = datetime(2020, 9, 15, tzinfo=timezone.utc) - - # Release all the schedules - await datastore.release_schedules("dummy-id1", schedules1) - await datastore.release_schedules("dummy-id2", schedules2) - - # Check that the first schedule is gone - schedules = await datastore.get_schedules() - assert len(schedules) == 2 - assert schedules[0].id == "s2" - assert schedules[1].id == "s3" - - # Check for the appropriate update and delete events - received_event = events.pop(0) - assert isinstance(received_event, ScheduleRemoved) - assert received_event.schedule_id == "s1" - - received_event = events.pop(0) - assert isinstance(received_event, ScheduleUpdated) - assert received_event.schedule_id == "s2" - assert received_event.next_fire_time == datetime( - 2020, 9, 15, tzinfo=timezone.utc - ) + job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) + await datastore.add_job(job) - assert not events + acquired = await datastore.acquire_jobs("worker1", 2) + assert len(acquired) == 1 + assert acquired[0].id == job.id - async def test_release_schedule_two_identical_fire_times( - 
self, datastore: AsyncDataStore - ) -> None: - """Regression test for #616.""" - for i in range(1, 3): - trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) - schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger) - schedule.next_fire_time = trigger.next() - await datastore.add_schedule(schedule, ConflictPolicy.exception) + await datastore.release_job( + "worker1", + acquired[0].task_id, + JobResult.from_job(acquired[0], JobOutcome.cancelled), + ) + result = await datastore.get_job_result(acquired[0].id) + assert result.outcome is JobOutcome.cancelled + assert result.exception is None + assert result.return_value is None + + # Check that the job and its result are gone + assert not await datastore.get_jobs({acquired[0].id}) + assert not await datastore.get_job_result(acquired[0].id) + + +async def test_acquire_jobs_lock_timeout( + datastore: DataStore, freezer: FrozenDateTimeFactory +) -> None: + """ + Test that a worker can acquire jobs that were acquired by another scheduler but + not released within the lock timeout period. 
+ + """ + await datastore.add_task( + Task(id="task1", func=asynccontextmanager, executor="async") + ) + job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) + await datastore.add_job(job) - schedules = await datastore.acquire_schedules("foo", 3) - schedules[0].next_fire_time = None - await datastore.release_schedules("foo", schedules) - - remaining = await datastore.get_schedules({s.id for s in schedules}) - assert len(remaining) == 1 - assert remaining[0].id == schedules[1].id - - async def test_release_two_schedules_at_once( - self, datastore: AsyncDataStore - ) -> None: - """Regression test for #621.""" - for i in range(2): - trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) - schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger) - schedule.next_fire_time = trigger.next() - await datastore.add_schedule(schedule, ConflictPolicy.exception) + # First, one worker acquires the first available job + acquired = await datastore.acquire_jobs("worker1", 1) + assert len(acquired) == 1 + assert acquired[0].id == job.id - schedules = await datastore.acquire_schedules("foo", 3) - await datastore.release_schedules("foo", schedules) - - remaining = await datastore.get_schedules({s.id for s in schedules}) - assert len(remaining) == 2 - - async def test_acquire_schedules_lock_timeout( - self, datastore: AsyncDataStore, schedules: list[Schedule], freezer - ) -> None: - """ - Test that a scheduler can acquire schedules that were acquired by another - scheduler but not released within the lock timeout period. - - """ - await datastore.add_schedule(schedules[0], ConflictPolicy.exception) - - # First, one scheduler acquires the first available schedule - acquired1 = await datastore.acquire_schedules("dummy-id1", 1) - assert len(acquired1) == 1 - assert acquired1[0].id == "s1" - - # Try to acquire the schedule just at the threshold (now == acquired_until). - # This should not yield any schedules. 
- freezer.tick(30) - acquired2 = await datastore.acquire_schedules("dummy-id2", 1) - assert not acquired2 - - # Right after that, the schedule should be available - freezer.tick(1) - acquired3 = await datastore.acquire_schedules("dummy-id2", 1) - assert len(acquired3) == 1 - assert acquired3[0].id == "s1" - - async def test_acquire_multiple_workers(self, datastore: AsyncDataStore) -> None: - await datastore.add_task(Task(id="task1", func=asynccontextmanager)) - jobs = [Job(task_id="task1") for _ in range(2)] - for job in jobs: - await datastore.add_job(job) - - # The first worker gets the first job in the queue - jobs1 = await datastore.acquire_jobs("worker1", 1) - assert len(jobs1) == 1 - assert jobs1[0].id == jobs[0].id - - # The second worker gets the second job - jobs2 = await datastore.acquire_jobs("worker2", 1) - assert len(jobs2) == 1 - assert jobs2[0].id == jobs[1].id - - # The third worker gets nothing - jobs3 = await datastore.acquire_jobs("worker3", 1) - assert not jobs3 - - async def test_job_release_success(self, datastore: AsyncDataStore) -> None: - await datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) - await datastore.add_job(job) + # Try to acquire the job just at the threshold (now == acquired_until). + # This should not yield any jobs. 
+ freezer.tick(30) + assert not await datastore.acquire_jobs("worker2", 1) - acquired = await datastore.acquire_jobs("worker_id", 2) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - await datastore.release_job( - "worker_id", - acquired[0].task_id, - JobResult.from_job( - acquired[0], - JobOutcome.success, - return_value="foo", - ), - ) - result = await datastore.get_job_result(acquired[0].id) - assert result.outcome is JobOutcome.success - assert result.exception is None - assert result.return_value == "foo" - - # Check that the job and its result are gone - assert not await datastore.get_jobs({acquired[0].id}) - assert not await datastore.get_job_result(acquired[0].id) - - async def test_job_release_failure(self, datastore: AsyncDataStore) -> None: - await datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) - await datastore.add_job(job) + # Right after that, the job should be available + freezer.tick(1) + acquired = await datastore.acquire_jobs("worker2", 1) + assert len(acquired) == 1 + assert acquired[0].id == job.id - acquired = await datastore.acquire_jobs("worker_id", 2) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - await datastore.release_job( - "worker_id", - acquired[0].task_id, - JobResult.from_job( - acquired[0], - JobOutcome.error, - exception=ValueError("foo"), - ), - ) - result = await datastore.get_job_result(acquired[0].id) - assert result.outcome is JobOutcome.error - assert isinstance(result.exception, ValueError) - assert result.exception.args == ("foo",) - assert result.return_value is None - - # Check that the job and its result are gone - assert not await datastore.get_jobs({acquired[0].id}) - assert not await datastore.get_job_result(acquired[0].id) - - async def test_job_release_missed_deadline(self, datastore: AsyncDataStore): - await datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = 
Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) - await datastore.add_job(job) - acquired = await datastore.acquire_jobs("worker_id", 2) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - await datastore.release_job( - "worker_id", - acquired[0].task_id, - JobResult.from_job( - acquired[0], - JobOutcome.missed_start_deadline, - ), - ) - result = await datastore.get_job_result(acquired[0].id) - assert result.outcome is JobOutcome.missed_start_deadline - assert result.exception is None - assert result.return_value is None - - # Check that the job and its result are gone - assert not await datastore.get_jobs({acquired[0].id}) - assert not await datastore.get_job_result(acquired[0].id) - - async def test_job_release_cancelled(self, datastore: AsyncDataStore) -> None: - await datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) +async def test_acquire_jobs_max_number_exceeded(datastore: DataStore) -> None: + await datastore.add_task( + Task(id="task1", func=asynccontextmanager, executor="async", max_running_jobs=2) + ) + jobs = [Job(task_id="task1"), Job(task_id="task1"), Job(task_id="task1")] + for job in jobs: await datastore.add_job(job) - acquired = await datastore.acquire_jobs("worker1", 2) - assert len(acquired) == 1 - assert acquired[0].id == job.id + # Check that only 2 jobs are returned from acquire_jobs() even though the limit + # was 3 + acquired_jobs = await datastore.acquire_jobs("worker1", 3) + assert [job.id for job in acquired_jobs] == [job.id for job in jobs[:2]] + + # Release one job, and the worker should be able to acquire the third job + await datastore.release_job( + "worker1", + acquired_jobs[0].task_id, + JobResult.from_job( + acquired_jobs[0], + JobOutcome.success, + return_value=None, + ), + ) + acquired_jobs = await datastore.acquire_jobs("worker1", 3) + assert [job.id for job in acquired_jobs] == [jobs[2].id] - await 
datastore.release_job( - "worker1", - acquired[0].task_id, - JobResult.from_job(acquired[0], JobOutcome.cancelled), - ) - result = await datastore.get_job_result(acquired[0].id) - assert result.outcome is JobOutcome.cancelled - assert result.exception is None - assert result.return_value is None - - # Check that the job and its result are gone - assert not await datastore.get_jobs({acquired[0].id}) - assert not await datastore.get_job_result(acquired[0].id) - - async def test_acquire_jobs_lock_timeout( - self, datastore: AsyncDataStore, freezer: FrozenDateTimeFactory - ) -> None: - """ - Test that a worker can acquire jobs that were acquired by another scheduler but - not released within the lock timeout period. - - """ - await datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1)) - await datastore.add_job(job) - # First, one worker acquires the first available job - acquired = await datastore.acquire_jobs("worker1", 1) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - # Try to acquire the job just at the threshold (now == acquired_until). - # This should not yield any jobs. 
- freezer.tick(30) - assert not await datastore.acquire_jobs("worker2", 1) - - # Right after that, the job should be available - freezer.tick(1) - acquired = await datastore.acquire_jobs("worker2", 1) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - async def test_acquire_jobs_max_number_exceeded( - self, datastore: AsyncDataStore - ) -> None: - await datastore.add_task( - Task(id="task1", func=asynccontextmanager, max_running_jobs=2) - ) - jobs = [Job(task_id="task1"), Job(task_id="task1"), Job(task_id="task1")] - for job in jobs: - await datastore.add_job(job) - - # Check that only 2 jobs are returned from acquire_jobs() even though the limit - # wqas 3 - acquired_jobs = await datastore.acquire_jobs("worker1", 3) - assert [job.id for job in acquired_jobs] == [job.id for job in jobs[:2]] - - # Release one job, and the worker should be able to acquire the third job - await datastore.release_job( - "worker1", - acquired_jobs[0].task_id, - JobResult.from_job( - acquired_jobs[0], - JobOutcome.success, - return_value=None, - ), - ) - acquired_jobs = await datastore.acquire_jobs("worker1", 3) - assert [job.id for job in acquired_jobs] == [jobs[2].id] - - async def test_add_get_task(self, datastore: DataStore) -> None: - with pytest.raises(TaskLookupError): - await datastore.get_task("dummyid") - - await datastore.add_task(Task(id="dummyid", func=asynccontextmanager)) - task = await datastore.get_task("dummyid") - assert task.id == "dummyid" - assert task.func is asynccontextmanager - - async def test_cancel_start( - self, raw_datastore: AsyncDataStore, event_broker: AsyncEventBroker - ) -> None: - with CancelScope() as scope: - scope.cancel() - await raw_datastore.start(event_broker) - await raw_datastore.stop() - - async def test_cancel_stop( - self, raw_datastore: AsyncDataStore, event_broker: AsyncEventBroker - ) -> None: - with CancelScope() as scope: - await raw_datastore.start(event_broker) +async def test_add_get_task(datastore: DataStore) -> 
None: + with pytest.raises(TaskLookupError): + await datastore.get_task("dummyid") + + await datastore.add_task( + Task(id="dummyid", func=asynccontextmanager, executor="async") + ) + task = await datastore.get_task("dummyid") + assert task.id == "dummyid" + assert task.func is asynccontextmanager + + +async def test_cancel_start( + raw_datastore: DataStore, local_broker: EventBroker +) -> None: + with CancelScope() as scope: + scope.cancel() + async with AsyncExitStack() as exit_stack: + await raw_datastore.start(exit_stack, local_broker) + + +async def test_cancel_stop(raw_datastore: DataStore, local_broker: EventBroker) -> None: + with CancelScope() as scope: + async with AsyncExitStack() as exit_stack: + await raw_datastore.start(exit_stack, local_broker) scope.cancel() - await raw_datastore.stop() diff --git a/tests/test_eventbrokers.py b/tests/test_eventbrokers.py index 09942f5..9b98f60 100644 --- a/tests/test_eventbrokers.py +++ b/tests/test_eventbrokers.py @@ -1,286 +1,107 @@ from __future__ import annotations -from collections.abc import AsyncGenerator, Generator -from concurrent.futures import Future +from contextlib import AsyncExitStack from datetime import datetime, timezone -from queue import Empty, Queue -from typing import Any, cast import pytest -from _pytest.fixtures import SubRequest from _pytest.logging import LogCaptureFixture from anyio import CancelScope, create_memory_object_stream, fail_after -from pytest_lazyfixture import lazy_fixture from apscheduler import Event, ScheduleAdded -from apscheduler.abc import AsyncEventBroker, EventBroker, Serializer +from apscheduler.abc import EventBroker +pytestmark = pytest.mark.anyio -@pytest.fixture -def local_broker() -> EventBroker: - from apscheduler.eventbrokers.local import LocalEventBroker - return LocalEventBroker() - - -@pytest.fixture -def local_async_broker() -> AsyncEventBroker: - from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker - - return LocalAsyncEventBroker() - - 
-@pytest.fixture -def redis_broker(serializer: Serializer) -> EventBroker: - from apscheduler.eventbrokers.redis import RedisEventBroker - - broker = RedisEventBroker.from_url( - "redis://localhost:6379", serializer=serializer, stop_check_interval=0.05 +async def test_publish_subscribe(event_broker: EventBroker) -> None: + send, receive = create_memory_object_stream(2) + event_broker.subscribe(send.send) + event_broker.subscribe(send.send_nowait) + event = ScheduleAdded( + schedule_id="schedule1", + next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), ) - return broker - - -@pytest.fixture -async def async_redis_broker(serializer: Serializer) -> AsyncEventBroker: - from apscheduler.eventbrokers.async_redis import AsyncRedisEventBroker - - broker = AsyncRedisEventBroker.from_url( - "redis://localhost:6379", serializer=serializer, stop_check_interval=0.05 + await event_broker.publish(event) + + with fail_after(3): + event1 = await receive.receive() + event2 = await receive.receive() + + assert event1 == event2 + assert isinstance(event1, ScheduleAdded) + assert isinstance(event1.timestamp, datetime) + assert event1.schedule_id == "schedule1" + assert event1.next_fire_time == datetime( + 2021, 9, 11, 12, 31, 56, 254867, timezone.utc ) - return broker -@pytest.fixture -def mqtt_broker(serializer: Serializer) -> EventBroker: - from paho.mqtt.client import Client - - from apscheduler.eventbrokers.mqtt import MQTTEventBroker - - return MQTTEventBroker(Client(), serializer=serializer) - - -@pytest.fixture -async def asyncpg_broker(serializer: Serializer) -> AsyncEventBroker: - from apscheduler.eventbrokers.asyncpg import AsyncpgEventBroker - - broker = AsyncpgEventBroker.from_dsn( - "postgres://postgres:secret@localhost:5432/testdb", serializer=serializer +async def test_subscribe_one_shot(event_broker: EventBroker) -> None: + send, receive = create_memory_object_stream(2) + event_broker.subscribe(send.send, one_shot=True) + event = ScheduleAdded( + 
schedule_id="schedule1", + next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), ) - yield broker - - -@pytest.fixture( - params=[ - pytest.param(lazy_fixture("local_broker"), id="local"), - pytest.param( - lazy_fixture("redis_broker"), - id="redis", - marks=[pytest.mark.external_service], - ), - pytest.param( - lazy_fixture("mqtt_broker"), id="mqtt", marks=[pytest.mark.external_service] - ), - ] -) -def broker(request: SubRequest) -> Generator[EventBroker, Any, None]: - request.param.start() - yield request.param - request.param.stop() - - -@pytest.fixture( - params=[ - pytest.param(lazy_fixture("local_async_broker"), id="local"), - pytest.param( - lazy_fixture("asyncpg_broker"), - id="asyncpg", - marks=[pytest.mark.external_service], - ), - pytest.param( - lazy_fixture("async_redis_broker"), - id="async_redis", - marks=[pytest.mark.external_service], - ), - ] -) -async def raw_async_broker(request: SubRequest) -> AsyncEventBroker: - return cast(AsyncEventBroker, request.param) - - -@pytest.fixture -async def async_broker( - raw_async_broker: AsyncEventBroker, -) -> AsyncGenerator[AsyncEventBroker, Any]: - await raw_async_broker.start() - yield raw_async_broker - await raw_async_broker.stop() - - -class TestEventBroker: - def test_publish_subscribe(self, broker: EventBroker) -> None: - queue: Queue[Event] = Queue() - broker.subscribe(queue.put_nowait) - broker.subscribe(queue.put_nowait) - event = ScheduleAdded( - schedule_id="schedule1", - next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), - ) - broker.publish(event) - event1 = queue.get(timeout=3) - event2 = queue.get(timeout=1) - - assert event1 == event2 - assert isinstance(event1, ScheduleAdded) - assert isinstance(event1.timestamp, datetime) - assert event1.schedule_id == "schedule1" - assert event1.next_fire_time == datetime( - 2021, 9, 11, 12, 31, 56, 254867, timezone.utc - ) - - def test_subscribe_one_shot(self, broker: EventBroker) -> None: - queue: Queue[Event] = 
Queue() - broker.subscribe(queue.put_nowait, one_shot=True) - event = ScheduleAdded( - schedule_id="schedule1", - next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), - ) - broker.publish(event) - event = ScheduleAdded( - schedule_id="schedule2", - next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc), - ) - broker.publish(event) - received_event = queue.get(timeout=3) - with pytest.raises(Empty): - queue.get(timeout=0.1) - - assert isinstance(received_event, ScheduleAdded) - assert received_event.schedule_id == "schedule1" - - def test_unsubscribe(self, broker: EventBroker, caplog) -> None: - queue: Queue[Event] = Queue() - subscription = broker.subscribe(queue.put_nowait) - broker.publish(Event()) - queue.get(timeout=3) - - subscription.unsubscribe() - broker.publish(Event()) - with pytest.raises(Empty): - queue.get(timeout=0.1) - - def test_publish_no_subscribers( - self, broker: EventBroker, caplog: LogCaptureFixture - ) -> None: - broker.publish(Event()) - assert not caplog.text - - def test_publish_exception( - self, broker: EventBroker, caplog: LogCaptureFixture - ) -> None: - def bad_subscriber(event: Event) -> None: - raise Exception("foo") - - timestamp = datetime.now(timezone.utc) - event_future: Future[Event] = Future() - broker.subscribe(bad_subscriber) - broker.subscribe(event_future.set_result) - broker.publish(Event(timestamp=timestamp)) - - event = event_future.result(3) - assert isinstance(event, Event) - assert event.timestamp == timestamp - assert "Error delivering Event" in caplog.text + await event_broker.publish(event) + event = ScheduleAdded( + schedule_id="schedule2", + next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc), + ) + await event_broker.publish(event) + with fail_after(3): + received_event = await receive.receive() -@pytest.mark.anyio -class TestAsyncEventBroker: - async def test_publish_subscribe(self, async_broker: AsyncEventBroker) -> None: - send, receive = 
create_memory_object_stream(2) - async_broker.subscribe(send.send) - async_broker.subscribe(send.send_nowait) - event = ScheduleAdded( - schedule_id="schedule1", - next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), - ) - await async_broker.publish(event) + with pytest.raises(TimeoutError), fail_after(0.1): + await receive.receive() - with fail_after(3): - event1 = await receive.receive() - event2 = await receive.receive() + assert isinstance(received_event, ScheduleAdded) + assert received_event.schedule_id == "schedule1" - assert event1 == event2 - assert isinstance(event1, ScheduleAdded) - assert isinstance(event1.timestamp, datetime) - assert event1.schedule_id == "schedule1" - assert event1.next_fire_time == datetime( - 2021, 9, 11, 12, 31, 56, 254867, timezone.utc - ) - async def test_subscribe_one_shot(self, async_broker: AsyncEventBroker) -> None: - send, receive = create_memory_object_stream(2) - async_broker.subscribe(send.send, one_shot=True) - event = ScheduleAdded( - schedule_id="schedule1", - next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), - ) - await async_broker.publish(event) - event = ScheduleAdded( - schedule_id="schedule2", - next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc), - ) - await async_broker.publish(event) +async def test_unsubscribe(event_broker: EventBroker) -> None: + send, receive = create_memory_object_stream() + subscription = event_broker.subscribe(send.send) + await event_broker.publish(Event()) + with fail_after(3): + await receive.receive() - with fail_after(3): - received_event = await receive.receive() + subscription.unsubscribe() + await event_broker.publish(Event()) + with pytest.raises(TimeoutError), fail_after(0.1): + await receive.receive() - with pytest.raises(TimeoutError), fail_after(0.1): - await receive.receive() - assert isinstance(received_event, ScheduleAdded) - assert received_event.schedule_id == "schedule1" +async def 
test_publish_no_subscribers(event_broker, caplog: LogCaptureFixture) -> None: + await event_broker.publish(Event()) + assert not caplog.text - async def test_unsubscribe(self, async_broker: AsyncEventBroker) -> None: - send, receive = create_memory_object_stream() - subscription = async_broker.subscribe(send.send) - await async_broker.publish(Event()) - with fail_after(3): - await receive.receive() - subscription.unsubscribe() - await async_broker.publish(Event()) - with pytest.raises(TimeoutError), fail_after(0.1): - await receive.receive() +async def test_publish_exception(event_broker, caplog: LogCaptureFixture) -> None: + def bad_subscriber(event: Event) -> None: + raise Exception("foo") - async def test_publish_no_subscribers( - self, async_broker: AsyncEventBroker, caplog: LogCaptureFixture - ) -> None: - await async_broker.publish(Event()) - assert not caplog.text + timestamp = datetime.now(timezone.utc) + send, receive = create_memory_object_stream() + event_broker.subscribe(bad_subscriber) + event_broker.subscribe(send.send) + await event_broker.publish(Event(timestamp=timestamp)) - async def test_publish_exception( - self, async_broker: AsyncEventBroker, caplog: LogCaptureFixture - ) -> None: - def bad_subscriber(event: Event) -> None: - raise Exception("foo") + received_event = await receive.receive() + assert received_event.timestamp == timestamp + assert "Error delivering Event" in caplog.text - timestamp = datetime.now(timezone.utc) - send, receive = create_memory_object_stream() - async_broker.subscribe(bad_subscriber) - async_broker.subscribe(send.send) - await async_broker.publish(Event(timestamp=timestamp)) - received_event = await receive.receive() - assert received_event.timestamp == timestamp - assert "Error delivering Event" in caplog.text +async def test_cancel_start(raw_event_broker: EventBroker) -> None: + with CancelScope() as scope: + scope.cancel() + async with AsyncExitStack() as exit_stack: + await raw_event_broker.start(exit_stack) - 
async def test_cancel_start(self, raw_async_broker: AsyncEventBroker) -> None: - with CancelScope() as scope: - scope.cancel() - await raw_async_broker.start() - await raw_async_broker.stop() - async def test_cancel_stop(self, raw_async_broker: AsyncEventBroker) -> None: - with CancelScope() as scope: - await raw_async_broker.start() +async def test_cancel_stop(raw_event_broker: EventBroker) -> None: + with CancelScope() as scope: + async with AsyncExitStack() as exit_stack: + await raw_event_broker.start(exit_stack) scope.cancel() - await raw_async_broker.stop() diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py index 32c1c29..30c8dc9 100644 --- a/tests/test_schedulers.py +++ b/tests/test_schedulers.py @@ -26,15 +26,14 @@ from apscheduler import ( SchedulerStopped, Task, TaskAdded, + current_async_scheduler, current_job, current_scheduler, - current_worker, ) from apscheduler.schedulers.async_ import AsyncScheduler from apscheduler.schedulers.sync import Scheduler from apscheduler.triggers.date import DateTrigger from apscheduler.triggers.interval import IntervalTrigger -from apscheduler.workers.async_ import AsyncWorker if sys.version_info >= (3, 9): from zoneinfo import ZoneInfo @@ -70,7 +69,7 @@ class TestAsyncScheduler: received_events: list[Event] = [] event = anyio.Event() trigger = DateTrigger(datetime.now(timezone.utc)) - async with AsyncScheduler(start_worker=False) as scheduler: + async with AsyncScheduler(process_jobs=False) as scheduler: scheduler.event_broker.subscribe(listener) await scheduler.add_schedule(dummy_async_job, trigger, id="foo") await scheduler.start_in_background() @@ -111,7 +110,7 @@ class TestAsyncScheduler: assert not received_events async def test_add_get_schedule(self) -> None: - async with AsyncScheduler(start_worker=False) as scheduler: + async with AsyncScheduler(process_jobs=False) as scheduler: with pytest.raises(ScheduleLookupError): await scheduler.get_schedule("dummyid") @@ -121,7 +120,7 @@ class 
TestAsyncScheduler: assert isinstance(schedule, Schedule) async def test_add_get_schedules(self) -> None: - async with AsyncScheduler(start_worker=False) as scheduler: + async with AsyncScheduler(process_jobs=False) as scheduler: assert await scheduler.get_schedules() == [] schedule1_id = await scheduler.add_schedule( @@ -161,7 +160,7 @@ class TestAsyncScheduler: orig_start_time = datetime.now(timezone) - timedelta(seconds=1) fake_uniform = mocker.patch("random.uniform") fake_uniform.configure_mock(side_effect=lambda a, b: jitter) - async with AsyncScheduler(start_worker=False) as scheduler: + async with AsyncScheduler(process_jobs=False) as scheduler: trigger = IntervalTrigger(seconds=3, start_time=orig_start_time) job_added_event = anyio.Event() scheduler.event_broker.subscribe(job_added_listener, {JobAdded}) @@ -263,8 +262,7 @@ class TestAsyncScheduler: async def test_contextvars(self) -> None: def check_contextvars() -> None: - assert current_scheduler.get() is scheduler - assert isinstance(current_worker.get(), AsyncWorker) + assert current_async_scheduler.get() is scheduler info = current_job.get() assert info.task_id == "task_id" assert info.schedule_id == "foo" @@ -277,7 +275,7 @@ class TestAsyncScheduler: start_deadline = datetime.now(timezone.utc) + timedelta(seconds=10) async with AsyncScheduler() as scheduler: await scheduler.data_store.add_task( - Task(id="task_id", func=check_contextvars) + Task(id="task_id", func=check_contextvars, executor="async") ) job = Job( task_id="task_id", @@ -300,10 +298,7 @@ class TestAsyncScheduler: async def test_wait_until_stopped(self) -> None: async with AsyncScheduler() as scheduler: - trigger = DateTrigger( - datetime.now(timezone.utc) + timedelta(milliseconds=100) - ) - await scheduler.add_schedule(scheduler.stop, trigger) + await scheduler.add_job(scheduler.stop) await scheduler.wait_until_stopped() # This should be a no-op @@ -422,7 +417,7 @@ class TestSyncScheduler: # Check that the job was created with the 
proper amount of jitter in its # scheduled time - jobs = scheduler.data_store.get_jobs({job_id}) + jobs = scheduler._portal.call(scheduler.data_store.get_jobs, {job_id}) assert jobs[0].jitter == timedelta(seconds=jitter) assert jobs[0].scheduled_fire_time == orig_start_time + timedelta( seconds=jitter @@ -495,7 +490,6 @@ class TestSyncScheduler: def test_contextvars(self) -> None: def check_contextvars() -> None: assert current_scheduler.get() is scheduler - assert current_worker.get() is not None info = current_job.get() assert info.task_id == "task_id" assert info.schedule_id == "foo" @@ -507,7 +501,10 @@ class TestSyncScheduler: scheduled_fire_time = datetime.now(timezone.utc) start_deadline = datetime.now(timezone.utc) + timedelta(seconds=10) with Scheduler() as scheduler: - scheduler.data_store.add_task(Task(id="task_id", func=check_contextvars)) + scheduler._portal.call( + scheduler.data_store.add_task, + Task(id="task_id", func=check_contextvars, executor="threadpool"), + ) job = Job( task_id="task_id", schedule_id="foo", @@ -517,7 +514,7 @@ class TestSyncScheduler: tags={"foo", "bar"}, result_expiration_time=timedelta(seconds=10), ) - scheduler.data_store.add_job(job) + scheduler._portal.call(scheduler.data_store.add_job, job) scheduler.start_in_background() result = scheduler.get_job_result(job.id) if result.outcome is JobOutcome.error: @@ -527,10 +524,7 @@ class TestSyncScheduler: def test_wait_until_stopped(self) -> None: with Scheduler() as scheduler: - trigger = DateTrigger( - datetime.now(timezone.utc) + timedelta(milliseconds=100) - ) - scheduler.add_schedule(scheduler.stop, trigger) + scheduler.add_job(scheduler.stop) scheduler.start_in_background() scheduler.wait_until_stopped() diff --git a/tests/test_workers.py b/tests/test_workers.py deleted file mode 100644 index aecc63b..0000000 --- a/tests/test_workers.py +++ /dev/null @@ -1,281 +0,0 @@ -from __future__ import annotations - -import threading -from datetime import datetime, timezone -from 
typing import Callable - -import anyio -import pytest -from anyio import fail_after - -from apscheduler import ( - Event, - Job, - JobAcquired, - JobAdded, - JobOutcome, - JobReleased, - Task, - TaskAdded, - WorkerStopped, -) -from apscheduler.datastores.memory import MemoryDataStore -from apscheduler.workers.async_ import AsyncWorker -from apscheduler.workers.sync import Worker - -pytestmark = pytest.mark.anyio - - -def sync_func(*args, fail: bool, **kwargs): - if fail: - raise Exception("failing as requested") - else: - return args, kwargs - - -async def async_func(*args, fail: bool, **kwargs): - if fail: - raise Exception("failing as requested") - else: - return args, kwargs - - -def fail_func(): - pytest.fail("This function should never be run") - - -class TestAsyncWorker: - @pytest.mark.parametrize( - "target_func", [sync_func, async_func], ids=["sync", "async"] - ) - @pytest.mark.parametrize("fail", [False, True], ids=["success", "fail"]) - async def test_run_job_nonscheduled_success( - self, target_func: Callable, fail: bool - ) -> None: - def listener(received_event: Event): - received_events.append(received_event) - if isinstance(received_event, JobReleased): - event.set() - - received_events: list[Event] = [] - event = anyio.Event() - async with AsyncWorker(MemoryDataStore()) as worker: - worker.event_broker.subscribe(listener) - await worker.data_store.add_task(Task(id="task_id", func=target_func)) - job = Job(task_id="task_id", args=(1, 2), kwargs={"x": "foo", "fail": fail}) - await worker.data_store.add_job(job) - with fail_after(3): - await event.wait() - - # First, a task was added - received_event = received_events.pop(0) - assert isinstance(received_event, TaskAdded) - assert received_event.task_id == "task_id" - - # Then a job was added - received_event = received_events.pop(0) - assert isinstance(received_event, JobAdded) - assert received_event.job_id == job.id - assert received_event.task_id == "task_id" - assert received_event.schedule_id is 
None - - # Then the job was started - received_event = received_events.pop(0) - assert isinstance(received_event, JobAcquired) - assert received_event.job_id == job.id - assert received_event.worker_id == worker.identity - - received_event = received_events.pop(0) - if fail: - # Then the job failed - assert isinstance(received_event, JobReleased) - assert received_event.outcome is JobOutcome.error - assert received_event.exception_type == "Exception" - assert received_event.exception_message == "failing as requested" - assert isinstance(received_event.exception_traceback, list) - assert all( - isinstance(line, str) for line in received_event.exception_traceback - ) - else: - # Then the job finished successfully - assert isinstance(received_event, JobReleased) - assert received_event.outcome is JobOutcome.success - assert received_event.exception_type is None - assert received_event.exception_message is None - assert received_event.exception_traceback is None - - # Finally, the worker was stopped - received_event = received_events.pop(0) - assert isinstance(received_event, WorkerStopped) - - # There should be no more events on the list - assert not received_events - - async def test_run_deadline_missed(self) -> None: - def listener(received_event: Event): - received_events.append(received_event) - if isinstance(received_event, JobReleased): - event.set() - - scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc) - received_events: list[Event] = [] - event = anyio.Event() - async with AsyncWorker(MemoryDataStore()) as worker: - worker.event_broker.subscribe(listener) - await worker.data_store.add_task(Task(id="task_id", func=fail_func)) - job = Job( - task_id="task_id", - schedule_id="foo", - scheduled_fire_time=scheduled_start_time, - start_deadline=datetime(2020, 9, 14, 1, tzinfo=timezone.utc), - ) - await worker.data_store.add_job(job) - with fail_after(3): - await event.wait() - - # First, a task was added - received_event = received_events.pop(0) - 
assert isinstance(received_event, TaskAdded) - assert received_event.task_id == "task_id" - - # Then a job was added - received_event = received_events.pop(0) - assert isinstance(received_event, JobAdded) - assert received_event.job_id == job.id - assert received_event.task_id == "task_id" - assert received_event.schedule_id == "foo" - - # The worker acquired the job - received_event = received_events.pop(0) - assert isinstance(received_event, JobAcquired) - assert received_event.job_id == job.id - assert received_event.worker_id == worker.identity - - # The worker determined that the deadline has been missed - received_event = received_events.pop(0) - assert isinstance(received_event, JobReleased) - assert received_event.outcome is JobOutcome.missed_start_deadline - assert received_event.job_id == job.id - assert received_event.worker_id == worker.identity - - # Finally, the worker was stopped - received_event = received_events.pop(0) - assert isinstance(received_event, WorkerStopped) - - # There should be no more events on the list - assert not received_events - - -class TestSyncWorker: - @pytest.mark.parametrize("fail", [False, True], ids=["success", "fail"]) - def test_run_job_nonscheduled(self, fail: bool) -> None: - def listener(received_event: Event): - received_events.append(received_event) - if isinstance(received_event, JobReleased): - event.set() - - received_events: list[Event] = [] - event = threading.Event() - with Worker(MemoryDataStore()) as worker: - worker.event_broker.subscribe(listener) - worker.data_store.add_task(Task(id="task_id", func=sync_func)) - job = Job(task_id="task_id", args=(1, 2), kwargs={"x": "foo", "fail": fail}) - worker.data_store.add_job(job) - event.wait(3) - - # First, a task was added - received_event = received_events.pop(0) - assert isinstance(received_event, TaskAdded) - assert received_event.task_id == "task_id" - - # Then a job was added - received_event = received_events.pop(0) - assert isinstance(received_event, 
JobAdded) - assert received_event.job_id == job.id - assert received_event.task_id == "task_id" - assert received_event.schedule_id is None - - # Then the job was started - received_event = received_events.pop(0) - assert isinstance(received_event, JobAcquired) - assert received_event.job_id == job.id - assert received_event.worker_id == worker.identity - - received_event = received_events.pop(0) - if fail: - # Then the job failed - assert isinstance(received_event, JobReleased) - assert received_event.outcome is JobOutcome.error - assert received_event.exception_type == "Exception" - assert received_event.exception_message == "failing as requested" - assert isinstance(received_event.exception_traceback, list) - assert all( - isinstance(line, str) for line in received_event.exception_traceback - ) - else: - # Then the job finished successfully - assert isinstance(received_event, JobReleased) - assert received_event.outcome is JobOutcome.success - assert received_event.exception_type is None - assert received_event.exception_message is None - assert received_event.exception_traceback is None - - # Finally, the worker was stopped - received_event = received_events.pop(0) - assert isinstance(received_event, WorkerStopped) - - # There should be no more events on the list - assert not received_events - - def test_run_deadline_missed(self) -> None: - def listener(received_event: Event): - received_events.append(received_event) - if isinstance(received_event, JobReleased): - event.set() - - scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc) - received_events: list[Event] = [] - event = threading.Event() - with Worker(MemoryDataStore()) as worker: - worker.event_broker.subscribe(listener) - worker.data_store.add_task(Task(id="task_id", func=fail_func)) - job = Job( - task_id="task_id", - schedule_id="foo", - scheduled_fire_time=scheduled_start_time, - start_deadline=datetime(2020, 9, 14, 1, tzinfo=timezone.utc), - ) - worker.data_store.add_job(job) - 
event.wait(3) - - # First, a task was added - received_event = received_events.pop(0) - assert isinstance(received_event, TaskAdded) - assert received_event.task_id == "task_id" - - # Then a job was added - received_event = received_events.pop(0) - assert isinstance(received_event, JobAdded) - assert received_event.job_id == job.id - assert received_event.task_id == "task_id" - assert received_event.schedule_id == "foo" - - # The worker acquired the job - received_event = received_events.pop(0) - assert isinstance(received_event, JobAcquired) - assert received_event.job_id == job.id - assert received_event.worker_id == worker.identity - - # The worker determined that the deadline has been missed - received_event = received_events.pop(0) - assert isinstance(received_event, JobReleased) - assert received_event.outcome is JobOutcome.missed_start_deadline - assert received_event.job_id == job.id - assert received_event.worker_id == worker.identity - - # Finally, the worker was stopped - received_event = received_events.pop(0) - assert isinstance(received_event, WorkerStopped) - - # There should be no more events on the list - assert not received_events |