diff options
Diffstat (limited to 'src/apscheduler/schedulers/async_.py')
-rw-r--r-- | src/apscheduler/schedulers/async_.py | 507 |
1 files changed, 328 insertions, 179 deletions
diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py index 44fee27..11a1c67 100644 --- a/src/apscheduler/schedulers/async_.py +++ b/src/apscheduler/schedulers/async_.py @@ -4,10 +4,10 @@ import os import platform import random import sys -from asyncio import CancelledError from collections.abc import MutableMapping from contextlib import AsyncExitStack from datetime import datetime, timedelta, timezone +from inspect import isclass from logging import Logger, getLogger from types import TracebackType from typing import Any, Callable, Iterable, Mapping, cast @@ -15,12 +15,19 @@ from uuid import UUID, uuid4 import anyio import attrs -from anyio import TASK_STATUS_IGNORED, create_task_group, move_on_after +from anyio import ( + TASK_STATUS_IGNORED, + CancelScope, + create_task_group, + get_cancelled_exc_class, + move_on_after, +) from anyio.abc import TaskGroup, TaskStatus from attr.validators import instance_of -from .._context import current_async_scheduler -from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState +from .. import JobAdded +from .._context import current_async_scheduler, current_job +from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState, SchedulerRole from .._events import ( Event, JobReleased, @@ -35,8 +42,8 @@ from .._exceptions import ( JobLookupError, ScheduleLookupError, ) -from .._structures import Job, JobResult, Schedule, Task -from .._worker import Worker +from .._structures import Job, JobInfo, JobResult, Schedule, Task +from .._validators import non_negative_number from ..abc import DataStore, EventBroker, JobExecutor, Subscription, Trigger from ..datastores.memory import MemoryDataStore from ..eventbrokers.local import LocalEventBroker @@ -54,32 +61,37 @@ _microsecond_delta = timedelta(microseconds=1) _zero_timedelta = timedelta() -@attrs.define(eq=False) +@attrs.define(eq=False, kw_only=True) class AsyncScheduler: - """An asynchronous (AnyIO based) scheduler implementation.""" + """ + An asynchronous (AnyIO based) scheduler implementation. + + :param data_store: the data store for tasks, schedules and jobs + :param event_broker: the event broker to use for publishing an subscribing events + :param max_concurrent_jobs: Maximum number of jobs the worker will run at once + :param role: specifies what the scheduler should be doing when running + :param process_schedules: ``True`` to process due schedules in this scheduler + """ data_store: DataStore = attrs.field( - validator=instance_of(DataStore), factory=MemoryDataStore + kw_only=False, validator=instance_of(DataStore), factory=MemoryDataStore ) event_broker: EventBroker = attrs.field( - validator=instance_of(EventBroker), factory=LocalEventBroker - ) - identity: str = attrs.field(kw_only=True, default=None) - process_jobs: bool = attrs.field(kw_only=True, default=True) - job_executors: MutableMapping[str, JobExecutor] | None = attrs.field( - kw_only=True, default=None + kw_only=False, validator=instance_of(EventBroker), factory=LocalEventBroker ) - default_job_executor: str | None = attrs.field(kw_only=True, default=None) - process_schedules: bool = attrs.field(kw_only=True, default=True) - logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) + identity: str = attrs.field(default=None) + role: SchedulerRole = attrs.field(default=SchedulerRole.both) + max_concurrent_jobs: int = attrs.field(validator=non_negative_number, default=100) + job_executors: MutableMapping[str, JobExecutor] | None = attrs.field(default=None) + default_job_executor: str | None = attrs.field(default=None) + logger: Logger | None = attrs.field(default=getLogger(__name__)) _state: RunState = attrs.field(init=False, default=RunState.stopped) - _task_group: TaskGroup | None = attrs.field(init=False, default=None) + _services_task_group: TaskGroup | None = attrs.field(init=False, default=None) _exit_stack: AsyncExitStack = attrs.field(init=False, factory=AsyncExitStack) _services_initialized: bool = attrs.field(init=False, default=False) - _wakeup_event: anyio.Event = attrs.field(init=False) - _wakeup_deadline: datetime | None = attrs.field(init=False, default=None) - _schedule_added_subscription: Subscription = attrs.field(init=False) + _scheduler_cancel_scope: CancelScope | None = attrs.field(init=False, default=None) + _running_jobs: set[Job] = attrs.field(init=False, factory=set) def __attrs_post_init__(self) -> None: if not self.identity: @@ -103,10 +115,10 @@ class AsyncScheduler: await self._exit_stack.__aenter__() try: await self._ensure_services_initialized(self._exit_stack) - self._task_group = await self._exit_stack.enter_async_context( + self._services_task_group = await self._exit_stack.enter_async_context( create_task_group() ) - self._exit_stack.callback(setattr, self, "_task_group", None) + self._exit_stack.callback(setattr, self, "_services_task_group", None) except BaseException as exc: await self._exit_stack.__aexit__(type(exc), exc, exc.__traceback__) raise @@ -123,7 +135,10 @@ class AsyncScheduler: await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) async def _ensure_services_initialized(self, exit_stack: AsyncExitStack) -> None: - """Ensure that the data store and event broker have been initialized.""" + """ + Initialize the data store and event broker if this hasn't already been done. + + """ if not self._services_initialized: self._services_initialized = True exit_stack.callback(setattr, self, "_services_initialized", False) @@ -132,6 +147,7 @@ class AsyncScheduler: await self.data_store.start(exit_stack, self.event_broker) def _check_initialized(self) -> None: + """Raise RuntimeError if the services have not been initialized yet.""" if not self._services_initialized: raise RuntimeError( "The scheduler has not been initialized yet. Use the scheduler as an " @@ -139,16 +155,6 @@ class AsyncScheduler: "than run_until_complete()." ) - async def _schedule_added_or_modified(self, event: Event) -> None: - event_ = cast("ScheduleAdded | ScheduleUpdated", event) - if not self._wakeup_deadline or ( - event_.next_fire_time and event_.next_fire_time < self._wakeup_deadline - ): - self.logger.debug( - "Detected a %s event – waking up the scheduler", type(event).__name__ - ) - self._wakeup_event.set() - @property def state(self) -> RunState: """The current running state of the scheduler.""" @@ -157,7 +163,7 @@ class AsyncScheduler: def subscribe( self, callback: Callable[[Event], Any], - event_types: Iterable[type[Event]] | None = None, + event_types: type[Event] | Iterable[type[Event]] | None = None, *, one_shot: bool = False, is_async: bool = True, @@ -170,19 +176,43 @@ class AsyncScheduler: :param callback: callable to be called with the event object when an event is published - :param event_types: an iterable of concrete Event classes to subscribe to + :param event_types: an event class or an iterable event classes to subscribe to :param one_shot: if ``True``, automatically unsubscribe after the first matching event :param is_async: ``True`` if the (synchronous) callback should be called on the event loop thread, ``False`` if it should be called in a worker thread. - If the callback is a coroutine function, this flag is ignored. + If ``callback`` is a coroutine function, this flag is ignored. """ self._check_initialized() + if isclass(event_types): + event_types = {event_types} + return self.event_broker.subscribe( callback, event_types, is_async=is_async, one_shot=one_shot ) + async def get_next_event( + self, event_types: type[Event] | Iterable[type[Event]] + ) -> Event: + """ + Wait until the next event matching one of the given types arrives. + + :param event_types: an event class or an iterable event classes to subscribe to + + """ + received_event: Event + + def receive_event(ev: Event) -> None: + nonlocal received_event + received_event = ev + event.set() + + event = anyio.Event() + with self.subscribe(receive_event, event_types, one_shot=True): + await event.wait() + return received_event + async def add_schedule( self, func_or_task_id: str | Callable, @@ -433,9 +463,9 @@ class AsyncScheduler: For that, see :meth:`wait_until_stopped`. """ - if self._state is RunState.started: + if self._state is RunState.started and self._scheduler_cancel_scope: self._state = RunState.stopping - self._wakeup_event.set() + self._scheduler_cancel_scope.cancel() async def wait_until_stopped(self) -> None: """ @@ -446,22 +476,17 @@ class AsyncScheduler: ``SchedulerStopped`` event. """ - if self._state in (RunState.stopped, RunState.stopping): - return - - event = anyio.Event() - with self.event_broker.subscribe( - lambda ev: event.set(), {SchedulerStopped}, one_shot=True - ): - await event.wait() + if self._state not in (RunState.stopped, RunState.stopping): + await self.get_next_event(SchedulerStopped) async def start_in_background(self) -> None: self._check_initialized() - await self._task_group.start(self.run_until_stopped) + await self._services_task_group.start(self.run_until_stopped) async def run_until_stopped( self, *, task_status: TaskStatus = TASK_STATUS_IGNORED ) -> None: + """Run the scheduler until explicitly stopped.""" if self._state is not RunState.stopped: raise RuntimeError( f'Cannot start the scheduler when it is in the "{self._state}" ' @@ -470,150 +495,38 @@ class AsyncScheduler: self._state = RunState.starting async with AsyncExitStack() as exit_stack: - self._wakeup_event = anyio.Event() await self._ensure_services_initialized(exit_stack) - # Wake up the scheduler if the data store emits a significant schedule event - exit_stack.enter_context( - self.event_broker.subscribe( - self._schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated} - ) - ) - # Set this scheduler as the current scheduler token = current_async_scheduler.set(self) exit_stack.callback(current_async_scheduler.reset, token) - # Start the built-in worker, if configured to do so - if self.process_jobs: - worker = Worker(job_executors=self.job_executors) - await worker.start( - exit_stack, self.data_store, self.event_broker, self.identity - ) - - # Signal that the scheduler has started - self._state = RunState.started - task_status.started() - await self.event_broker.publish_local(SchedulerStarted()) - exception: BaseException | None = None try: - while self._state is RunState.started: - schedules = await self.data_store.acquire_schedules( - self.identity, 100 - ) - now = datetime.now(timezone.utc) - for schedule in schedules: - # Calculate a next fire time for the schedule, if possible - fire_times = [schedule.next_fire_time] - calculate_next = schedule.trigger.next - while True: - try: - fire_time = calculate_next() - except Exception: - self.logger.exception( - "Error computing next fire time for schedule %r of " - "task %r – removing schedule", - schedule.id, - schedule.task_id, - ) - break - - # Stop if the calculated fire time is in the future - if fire_time is None or fire_time > now: - schedule.next_fire_time = fire_time - break - - # Only keep all the fire times if coalesce policy = "all" - if schedule.coalesce is CoalescePolicy.all: - fire_times.append(fire_time) - elif schedule.coalesce is CoalescePolicy.latest: - fire_times[0] = fire_time - - # Add one or more jobs to the job queue - max_jitter = ( - schedule.max_jitter.total_seconds() - if schedule.max_jitter - else 0 - ) - for i, fire_time in enumerate(fire_times): - # Calculate a jitter if max_jitter > 0 - jitter = _zero_timedelta - if max_jitter: - if i + 1 < len(fire_times): - next_fire_time = fire_times[i + 1] - else: - next_fire_time = schedule.next_fire_time - - if next_fire_time is not None: - # Jitter must never be so high that it would cause a - # fire time to equal or exceed the next fire time - jitter_s = min( - [ - max_jitter, - ( - next_fire_time - - fire_time - - _microsecond_delta - ).total_seconds(), - ] - ) - jitter = timedelta( - seconds=random.uniform(0, jitter_s) - ) - fire_time += jitter - - schedule.last_fire_time = fire_time - job = Job( - task_id=schedule.task_id, - args=schedule.args, - kwargs=schedule.kwargs, - schedule_id=schedule.id, - scheduled_fire_time=fire_time, - jitter=jitter, - start_deadline=schedule.next_deadline, - tags=schedule.tags, - ) - await self.data_store.add_job(job) - - # Update the schedules (and release the scheduler's claim on them) - await self.data_store.release_schedules(self.identity, schedules) - - # If we received fewer schedules than the maximum amount, sleep - # until the next schedule is due or the scheduler is explicitly - # woken up - wait_time = None - if len(schedules) < 100: - self._wakeup_deadline = ( - await self.data_store.get_next_schedule_run_time() - ) - if self._wakeup_deadline: - wait_time = ( - self._wakeup_deadline - datetime.now(timezone.utc) - ).total_seconds() - self.logger.debug( - "Sleeping %.3f seconds until the next fire time (%s)", - wait_time, - self._wakeup_deadline, - ) - else: - self.logger.debug("Waiting for any due schedules to appear") - - with move_on_after(wait_time): - await self._wakeup_event.wait() - self._wakeup_event = anyio.Event() - else: - self.logger.debug( - "Processing more schedules on the next iteration" - ) + async with create_task_group() as task_group: + self._scheduler_cancel_scope = task_group.cancel_scope + exit_stack.callback(setattr, self, "_scheduler_cancel_scope", None) + + # Start processing due schedules, if configured to do so + if self.role in (SchedulerRole.scheduler, SchedulerRole.both): + await task_group.start(self._process_schedules) + + # Start processing due jobs, if configured to do so + if self.role in (SchedulerRole.worker, SchedulerRole.both): + await task_group.start(self._process_jobs) + + # Signal that the scheduler has started + self._state = RunState.started + self.logger.info("Scheduler started") + task_status.started() + await self.event_broker.publish_local(SchedulerStarted()) except BaseException as exc: exception = exc raise finally: self._state = RunState.stopped - # CancelledError is a subclass of Exception in Python 3.7 - if not exception or isinstance(exception, CancelledError): + if not exception or isinstance(exception, get_cancelled_exc_class()): self.logger.info("Scheduler stopped") elif isinstance(exception, Exception): self.logger.exception("Scheduler crashed") @@ -626,3 +539,239 @@ class AsyncScheduler: await self.event_broker.publish_local( SchedulerStopped(exception=exception) ) + + async def _process_schedules(self, *, task_status: TaskStatus) -> None: + wakeup_event = anyio.Event() + wakeup_deadline: datetime | None = None + + async def schedule_added_or_modified(self, event: Event) -> None: + event_ = cast("ScheduleAdded | ScheduleUpdated", event) + if not wakeup_deadline or ( + event_.next_fire_time and event_.next_fire_time < wakeup_deadline + ): + self.logger.debug( + "Detected a %s event – waking up the scheduler to process " + "schedules", + type(event).__name__, + ) + wakeup_event.set() + + subscription = self.event_broker.subscribe( + schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated} + ) + with subscription: + # Signal that we are ready, and wait for the scheduler start event + task_status.started() + await self.get_next_event(SchedulerStarted) + + while self._state is RunState.started: + schedules = await self.data_store.acquire_schedules(self.identity, 100) + now = datetime.now(timezone.utc) + for schedule in schedules: + # Calculate a next fire time for the schedule, if possible + fire_times = [schedule.next_fire_time] + calculate_next = schedule.trigger.next + while True: + try: + fire_time = calculate_next() + except Exception: + self.logger.exception( + "Error computing next fire time for schedule %r of " + "task %r – removing schedule", + schedule.id, + schedule.task_id, + ) + break + + # Stop if the calculated fire time is in the future + if fire_time is None or fire_time > now: + schedule.next_fire_time = fire_time + break + + # Only keep all the fire times if coalesce policy = "all" + if schedule.coalesce is CoalescePolicy.all: + fire_times.append(fire_time) + elif schedule.coalesce is CoalescePolicy.latest: + fire_times[0] = fire_time + + # Add one or more jobs to the job queue + max_jitter = ( + schedule.max_jitter.total_seconds() + if schedule.max_jitter + else 0 + ) + for i, fire_time in enumerate(fire_times): + # Calculate a jitter if max_jitter > 0 + jitter = _zero_timedelta + if max_jitter: + if i + 1 < len(fire_times): + next_fire_time = fire_times[i + 1] + else: + next_fire_time = schedule.next_fire_time + + if next_fire_time is not None: + # Jitter must never be so high that it would cause a + # fire time to equal or exceed the next fire time + jitter_s = min( + [ + max_jitter, + ( + next_fire_time + - fire_time + - _microsecond_delta + ).total_seconds(), + ] + ) + jitter = timedelta(seconds=random.uniform(0, jitter_s)) + fire_time += jitter + + schedule.last_fire_time = fire_time + job = Job( + task_id=schedule.task_id, + args=schedule.args, + kwargs=schedule.kwargs, + schedule_id=schedule.id, + scheduled_fire_time=fire_time, + jitter=jitter, + start_deadline=schedule.next_deadline, + tags=schedule.tags, + ) + await self.data_store.add_job(job) + + # Update the schedules (and release the scheduler's claim on them) + await self.data_store.release_schedules(self.identity, schedules) + + # If we received fewer schedules than the maximum amount, sleep + # until the next schedule is due or the scheduler is explicitly + # woken up + wait_time = None + if len(schedules) < 100: + wakeup_deadline = await self.data_store.get_next_schedule_run_time() + if wakeup_deadline: + wait_time = ( + wakeup_deadline - datetime.now(timezone.utc) + ).total_seconds() + self.logger.debug( + "Sleeping %.3f seconds until the next fire time (%s)", + wait_time, + wakeup_deadline, + ) + else: + self.logger.debug("Waiting for any due schedules to appear") + + with move_on_after(wait_time): + await wakeup_event.wait() + wakeup_event = anyio.Event() + else: + self.logger.debug("Processing more schedules on the next iteration") + + async def _process_jobs(self, *, task_status: TaskStatus) -> None: + wakeup_event = anyio.Event() + + async def job_added(event: Event) -> None: + if len(self._running_jobs) < self.max_concurrent_jobs: + wakeup_event.set() + + async with AsyncExitStack() as exit_stack: + # Start the job executors + for job_executor in self.job_executors.values(): + await job_executor.start(exit_stack) + + task_group = await exit_stack.enter_async_context(create_task_group()) + + # Fetch new jobs every time + exit_stack.enter_context(self.event_broker.subscribe(job_added, {JobAdded})) + + # Signal that we are ready, and wait for the scheduler start event + task_status.started() + await self.get_next_event(SchedulerStarted) + + while self._state is RunState.started: + limit = self.max_concurrent_jobs - len(self._running_jobs) + if limit > 0: + jobs = await self.data_store.acquire_jobs(self.identity, limit) + for job in jobs: + task = await self.data_store.get_task(job.task_id) + self._running_jobs.add(job.id) + task_group.start_soon( + self._run_job, job, task.func, task.executor + ) + + await wakeup_event.wait() + wakeup_event = anyio.Event() + + async def _run_job(self, job: Job, func: Callable, executor: str) -> None: + try: + # Check if the job started before the deadline + start_time = datetime.now(timezone.utc) + if job.start_deadline is not None and start_time > job.start_deadline: + result = JobResult.from_job( + job, + outcome=JobOutcome.missed_start_deadline, + finished_at=start_time, + ) + await self.data_store.release_job(self.identity, job.task_id, result) + await self.event_broker.publish( + JobReleased.from_result(result, self.identity) + ) + return + + try: + job_executor = self.job_executors[executor] + except KeyError: + return + + token = current_job.set(JobInfo.from_job(job)) + try: + retval = await job_executor.run_job(func, job) + except get_cancelled_exc_class(): + self.logger.info("Job %s was cancelled", job.id) + with CancelScope(shield=True): + result = JobResult.from_job( + job, + outcome=JobOutcome.cancelled, + ) + await self.data_store.release_job( + self.identity, job.task_id, result + ) + await self.event_broker.publish( + JobReleased.from_result(result, self.identity) + ) + except BaseException as exc: + if isinstance(exc, Exception): + self.logger.exception("Job %s raised an exception", job.id) + else: + self.logger.error( + "Job %s was aborted due to %s", job.id, exc.__class__.__name__ + ) + + result = JobResult.from_job( + job, + JobOutcome.error, + exception=exc, + ) + await self.data_store.release_job( + self.identity, + job.task_id, + result, + ) + await self.event_broker.publish( + JobReleased.from_result(result, self.identity) + ) + if not isinstance(exc, Exception): + raise + else: + self.logger.info("Job %s completed successfully", job.id) + result = JobResult.from_job( + job, + JobOutcome.success, + return_value=retval, + ) + await self.data_store.release_job(self.identity, job.task_id, result) + await self.event_broker.publish( + JobReleased.from_result(result, self.identity) + ) + finally: + current_job.reset(token) + finally: + self._running_jobs.remove(job.id) |