summaryrefslogtreecommitdiff
path: root/src/apscheduler/schedulers/async_.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/apscheduler/schedulers/async_.py')
-rw-r--r--src/apscheduler/schedulers/async_.py507
1 files changed, 328 insertions, 179 deletions
diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py
index 44fee27..11a1c67 100644
--- a/src/apscheduler/schedulers/async_.py
+++ b/src/apscheduler/schedulers/async_.py
@@ -4,10 +4,10 @@ import os
import platform
import random
import sys
-from asyncio import CancelledError
from collections.abc import MutableMapping
from contextlib import AsyncExitStack
from datetime import datetime, timedelta, timezone
+from inspect import isclass
from logging import Logger, getLogger
from types import TracebackType
from typing import Any, Callable, Iterable, Mapping, cast
@@ -15,12 +15,19 @@ from uuid import UUID, uuid4
import anyio
import attrs
-from anyio import TASK_STATUS_IGNORED, create_task_group, move_on_after
+from anyio import (
+ TASK_STATUS_IGNORED,
+ CancelScope,
+ create_task_group,
+ get_cancelled_exc_class,
+ move_on_after,
+)
from anyio.abc import TaskGroup, TaskStatus
from attr.validators import instance_of
-from .._context import current_async_scheduler
-from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState
+from .. import JobAdded
+from .._context import current_async_scheduler, current_job
+from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState, SchedulerRole
from .._events import (
Event,
JobReleased,
@@ -35,8 +42,8 @@ from .._exceptions import (
JobLookupError,
ScheduleLookupError,
)
-from .._structures import Job, JobResult, Schedule, Task
-from .._worker import Worker
+from .._structures import Job, JobInfo, JobResult, Schedule, Task
+from .._validators import non_negative_number
from ..abc import DataStore, EventBroker, JobExecutor, Subscription, Trigger
from ..datastores.memory import MemoryDataStore
from ..eventbrokers.local import LocalEventBroker
@@ -54,32 +61,37 @@ _microsecond_delta = timedelta(microseconds=1)
_zero_timedelta = timedelta()
-@attrs.define(eq=False)
+@attrs.define(eq=False, kw_only=True)
class AsyncScheduler:
- """An asynchronous (AnyIO based) scheduler implementation."""
+ """
+ An asynchronous (AnyIO based) scheduler implementation.
+
+ :param data_store: the data store for tasks, schedules and jobs
+    :param event_broker: the event broker to use for publishing and subscribing to events
+ :param max_concurrent_jobs: Maximum number of jobs the worker will run at once
+ :param role: specifies what the scheduler should be doing when running
+ :param process_schedules: ``True`` to process due schedules in this scheduler
+ """
data_store: DataStore = attrs.field(
- validator=instance_of(DataStore), factory=MemoryDataStore
+ kw_only=False, validator=instance_of(DataStore), factory=MemoryDataStore
)
event_broker: EventBroker = attrs.field(
- validator=instance_of(EventBroker), factory=LocalEventBroker
- )
- identity: str = attrs.field(kw_only=True, default=None)
- process_jobs: bool = attrs.field(kw_only=True, default=True)
- job_executors: MutableMapping[str, JobExecutor] | None = attrs.field(
- kw_only=True, default=None
+ kw_only=False, validator=instance_of(EventBroker), factory=LocalEventBroker
)
- default_job_executor: str | None = attrs.field(kw_only=True, default=None)
- process_schedules: bool = attrs.field(kw_only=True, default=True)
- logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__))
+ identity: str = attrs.field(default=None)
+ role: SchedulerRole = attrs.field(default=SchedulerRole.both)
+ max_concurrent_jobs: int = attrs.field(validator=non_negative_number, default=100)
+ job_executors: MutableMapping[str, JobExecutor] | None = attrs.field(default=None)
+ default_job_executor: str | None = attrs.field(default=None)
+ logger: Logger | None = attrs.field(default=getLogger(__name__))
_state: RunState = attrs.field(init=False, default=RunState.stopped)
- _task_group: TaskGroup | None = attrs.field(init=False, default=None)
+ _services_task_group: TaskGroup | None = attrs.field(init=False, default=None)
_exit_stack: AsyncExitStack = attrs.field(init=False, factory=AsyncExitStack)
_services_initialized: bool = attrs.field(init=False, default=False)
- _wakeup_event: anyio.Event = attrs.field(init=False)
- _wakeup_deadline: datetime | None = attrs.field(init=False, default=None)
- _schedule_added_subscription: Subscription = attrs.field(init=False)
+ _scheduler_cancel_scope: CancelScope | None = attrs.field(init=False, default=None)
+    _running_jobs: set[UUID] = attrs.field(init=False, factory=set)
def __attrs_post_init__(self) -> None:
if not self.identity:
@@ -103,10 +115,10 @@ class AsyncScheduler:
await self._exit_stack.__aenter__()
try:
await self._ensure_services_initialized(self._exit_stack)
- self._task_group = await self._exit_stack.enter_async_context(
+ self._services_task_group = await self._exit_stack.enter_async_context(
create_task_group()
)
- self._exit_stack.callback(setattr, self, "_task_group", None)
+ self._exit_stack.callback(setattr, self, "_services_task_group", None)
except BaseException as exc:
await self._exit_stack.__aexit__(type(exc), exc, exc.__traceback__)
raise
@@ -123,7 +135,10 @@ class AsyncScheduler:
await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
async def _ensure_services_initialized(self, exit_stack: AsyncExitStack) -> None:
- """Ensure that the data store and event broker have been initialized."""
+ """
+ Initialize the data store and event broker if this hasn't already been done.
+
+ """
if not self._services_initialized:
self._services_initialized = True
exit_stack.callback(setattr, self, "_services_initialized", False)
@@ -132,6 +147,7 @@ class AsyncScheduler:
await self.data_store.start(exit_stack, self.event_broker)
def _check_initialized(self) -> None:
+ """Raise RuntimeError if the services have not been initialized yet."""
if not self._services_initialized:
raise RuntimeError(
"The scheduler has not been initialized yet. Use the scheduler as an "
@@ -139,16 +155,6 @@ class AsyncScheduler:
"than run_until_complete()."
)
- async def _schedule_added_or_modified(self, event: Event) -> None:
- event_ = cast("ScheduleAdded | ScheduleUpdated", event)
- if not self._wakeup_deadline or (
- event_.next_fire_time and event_.next_fire_time < self._wakeup_deadline
- ):
- self.logger.debug(
- "Detected a %s event – waking up the scheduler", type(event).__name__
- )
- self._wakeup_event.set()
-
@property
def state(self) -> RunState:
"""The current running state of the scheduler."""
@@ -157,7 +163,7 @@ class AsyncScheduler:
def subscribe(
self,
callback: Callable[[Event], Any],
- event_types: Iterable[type[Event]] | None = None,
+ event_types: type[Event] | Iterable[type[Event]] | None = None,
*,
one_shot: bool = False,
is_async: bool = True,
@@ -170,19 +176,43 @@ class AsyncScheduler:
:param callback: callable to be called with the event object when an event is
published
- :param event_types: an iterable of concrete Event classes to subscribe to
+    :param event_types: an event class or an iterable of event classes to subscribe to
:param one_shot: if ``True``, automatically unsubscribe after the first matching
event
:param is_async: ``True`` if the (synchronous) callback should be called on the
event loop thread, ``False`` if it should be called in a worker thread.
- If the callback is a coroutine function, this flag is ignored.
+ If ``callback`` is a coroutine function, this flag is ignored.
"""
self._check_initialized()
+ if isclass(event_types):
+ event_types = {event_types}
+
return self.event_broker.subscribe(
callback, event_types, is_async=is_async, one_shot=one_shot
)
+ async def get_next_event(
+ self, event_types: type[Event] | Iterable[type[Event]]
+ ) -> Event:
+ """
+ Wait until the next event matching one of the given types arrives.
+
+    :param event_types: an event class or an iterable of event classes to subscribe to
+
+ """
+ received_event: Event
+
+ def receive_event(ev: Event) -> None:
+ nonlocal received_event
+ received_event = ev
+ event.set()
+
+ event = anyio.Event()
+ with self.subscribe(receive_event, event_types, one_shot=True):
+ await event.wait()
+ return received_event
+
async def add_schedule(
self,
func_or_task_id: str | Callable,
@@ -433,9 +463,9 @@ class AsyncScheduler:
For that, see :meth:`wait_until_stopped`.
"""
- if self._state is RunState.started:
+ if self._state is RunState.started and self._scheduler_cancel_scope:
self._state = RunState.stopping
- self._wakeup_event.set()
+ self._scheduler_cancel_scope.cancel()
async def wait_until_stopped(self) -> None:
"""
@@ -446,22 +476,17 @@ class AsyncScheduler:
``SchedulerStopped`` event.
"""
- if self._state in (RunState.stopped, RunState.stopping):
- return
-
- event = anyio.Event()
- with self.event_broker.subscribe(
- lambda ev: event.set(), {SchedulerStopped}, one_shot=True
- ):
- await event.wait()
+ if self._state not in (RunState.stopped, RunState.stopping):
+ await self.get_next_event(SchedulerStopped)
async def start_in_background(self) -> None:
self._check_initialized()
- await self._task_group.start(self.run_until_stopped)
+ await self._services_task_group.start(self.run_until_stopped)
async def run_until_stopped(
self, *, task_status: TaskStatus = TASK_STATUS_IGNORED
) -> None:
+ """Run the scheduler until explicitly stopped."""
if self._state is not RunState.stopped:
raise RuntimeError(
f'Cannot start the scheduler when it is in the "{self._state}" '
@@ -470,150 +495,38 @@ class AsyncScheduler:
self._state = RunState.starting
async with AsyncExitStack() as exit_stack:
- self._wakeup_event = anyio.Event()
await self._ensure_services_initialized(exit_stack)
- # Wake up the scheduler if the data store emits a significant schedule event
- exit_stack.enter_context(
- self.event_broker.subscribe(
- self._schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated}
- )
- )
-
# Set this scheduler as the current scheduler
token = current_async_scheduler.set(self)
exit_stack.callback(current_async_scheduler.reset, token)
- # Start the built-in worker, if configured to do so
- if self.process_jobs:
- worker = Worker(job_executors=self.job_executors)
- await worker.start(
- exit_stack, self.data_store, self.event_broker, self.identity
- )
-
- # Signal that the scheduler has started
- self._state = RunState.started
- task_status.started()
- await self.event_broker.publish_local(SchedulerStarted())
-
exception: BaseException | None = None
try:
- while self._state is RunState.started:
- schedules = await self.data_store.acquire_schedules(
- self.identity, 100
- )
- now = datetime.now(timezone.utc)
- for schedule in schedules:
- # Calculate a next fire time for the schedule, if possible
- fire_times = [schedule.next_fire_time]
- calculate_next = schedule.trigger.next
- while True:
- try:
- fire_time = calculate_next()
- except Exception:
- self.logger.exception(
- "Error computing next fire time for schedule %r of "
- "task %r – removing schedule",
- schedule.id,
- schedule.task_id,
- )
- break
-
- # Stop if the calculated fire time is in the future
- if fire_time is None or fire_time > now:
- schedule.next_fire_time = fire_time
- break
-
- # Only keep all the fire times if coalesce policy = "all"
- if schedule.coalesce is CoalescePolicy.all:
- fire_times.append(fire_time)
- elif schedule.coalesce is CoalescePolicy.latest:
- fire_times[0] = fire_time
-
- # Add one or more jobs to the job queue
- max_jitter = (
- schedule.max_jitter.total_seconds()
- if schedule.max_jitter
- else 0
- )
- for i, fire_time in enumerate(fire_times):
- # Calculate a jitter if max_jitter > 0
- jitter = _zero_timedelta
- if max_jitter:
- if i + 1 < len(fire_times):
- next_fire_time = fire_times[i + 1]
- else:
- next_fire_time = schedule.next_fire_time
-
- if next_fire_time is not None:
- # Jitter must never be so high that it would cause a
- # fire time to equal or exceed the next fire time
- jitter_s = min(
- [
- max_jitter,
- (
- next_fire_time
- - fire_time
- - _microsecond_delta
- ).total_seconds(),
- ]
- )
- jitter = timedelta(
- seconds=random.uniform(0, jitter_s)
- )
- fire_time += jitter
-
- schedule.last_fire_time = fire_time
- job = Job(
- task_id=schedule.task_id,
- args=schedule.args,
- kwargs=schedule.kwargs,
- schedule_id=schedule.id,
- scheduled_fire_time=fire_time,
- jitter=jitter,
- start_deadline=schedule.next_deadline,
- tags=schedule.tags,
- )
- await self.data_store.add_job(job)
-
- # Update the schedules (and release the scheduler's claim on them)
- await self.data_store.release_schedules(self.identity, schedules)
-
- # If we received fewer schedules than the maximum amount, sleep
- # until the next schedule is due or the scheduler is explicitly
- # woken up
- wait_time = None
- if len(schedules) < 100:
- self._wakeup_deadline = (
- await self.data_store.get_next_schedule_run_time()
- )
- if self._wakeup_deadline:
- wait_time = (
- self._wakeup_deadline - datetime.now(timezone.utc)
- ).total_seconds()
- self.logger.debug(
- "Sleeping %.3f seconds until the next fire time (%s)",
- wait_time,
- self._wakeup_deadline,
- )
- else:
- self.logger.debug("Waiting for any due schedules to appear")
-
- with move_on_after(wait_time):
- await self._wakeup_event.wait()
- self._wakeup_event = anyio.Event()
- else:
- self.logger.debug(
- "Processing more schedules on the next iteration"
- )
+ async with create_task_group() as task_group:
+ self._scheduler_cancel_scope = task_group.cancel_scope
+ exit_stack.callback(setattr, self, "_scheduler_cancel_scope", None)
+
+ # Start processing due schedules, if configured to do so
+ if self.role in (SchedulerRole.scheduler, SchedulerRole.both):
+ await task_group.start(self._process_schedules)
+
+ # Start processing due jobs, if configured to do so
+ if self.role in (SchedulerRole.worker, SchedulerRole.both):
+ await task_group.start(self._process_jobs)
+
+ # Signal that the scheduler has started
+ self._state = RunState.started
+ self.logger.info("Scheduler started")
+ task_status.started()
+ await self.event_broker.publish_local(SchedulerStarted())
except BaseException as exc:
exception = exc
raise
finally:
self._state = RunState.stopped
- # CancelledError is a subclass of Exception in Python 3.7
- if not exception or isinstance(exception, CancelledError):
+ if not exception or isinstance(exception, get_cancelled_exc_class()):
self.logger.info("Scheduler stopped")
elif isinstance(exception, Exception):
self.logger.exception("Scheduler crashed")
@@ -626,3 +539,239 @@ class AsyncScheduler:
await self.event_broker.publish_local(
SchedulerStopped(exception=exception)
)
+
+ async def _process_schedules(self, *, task_status: TaskStatus) -> None:
+ wakeup_event = anyio.Event()
+ wakeup_deadline: datetime | None = None
+
+    async def schedule_added_or_modified(event: Event) -> None:
+ event_ = cast("ScheduleAdded | ScheduleUpdated", event)
+ if not wakeup_deadline or (
+ event_.next_fire_time and event_.next_fire_time < wakeup_deadline
+ ):
+ self.logger.debug(
+ "Detected a %s event – waking up the scheduler to process "
+ "schedules",
+ type(event).__name__,
+ )
+ wakeup_event.set()
+
+ subscription = self.event_broker.subscribe(
+ schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated}
+ )
+ with subscription:
+ # Signal that we are ready, and wait for the scheduler start event
+ task_status.started()
+ await self.get_next_event(SchedulerStarted)
+
+ while self._state is RunState.started:
+ schedules = await self.data_store.acquire_schedules(self.identity, 100)
+ now = datetime.now(timezone.utc)
+ for schedule in schedules:
+ # Calculate a next fire time for the schedule, if possible
+ fire_times = [schedule.next_fire_time]
+ calculate_next = schedule.trigger.next
+ while True:
+ try:
+ fire_time = calculate_next()
+ except Exception:
+ self.logger.exception(
+ "Error computing next fire time for schedule %r of "
+ "task %r – removing schedule",
+ schedule.id,
+ schedule.task_id,
+ )
+ break
+
+ # Stop if the calculated fire time is in the future
+ if fire_time is None or fire_time > now:
+ schedule.next_fire_time = fire_time
+ break
+
+ # Only keep all the fire times if coalesce policy = "all"
+ if schedule.coalesce is CoalescePolicy.all:
+ fire_times.append(fire_time)
+ elif schedule.coalesce is CoalescePolicy.latest:
+ fire_times[0] = fire_time
+
+ # Add one or more jobs to the job queue
+ max_jitter = (
+ schedule.max_jitter.total_seconds()
+ if schedule.max_jitter
+ else 0
+ )
+ for i, fire_time in enumerate(fire_times):
+ # Calculate a jitter if max_jitter > 0
+ jitter = _zero_timedelta
+ if max_jitter:
+ if i + 1 < len(fire_times):
+ next_fire_time = fire_times[i + 1]
+ else:
+ next_fire_time = schedule.next_fire_time
+
+ if next_fire_time is not None:
+ # Jitter must never be so high that it would cause a
+ # fire time to equal or exceed the next fire time
+ jitter_s = min(
+ [
+ max_jitter,
+ (
+ next_fire_time
+ - fire_time
+ - _microsecond_delta
+ ).total_seconds(),
+ ]
+ )
+ jitter = timedelta(seconds=random.uniform(0, jitter_s))
+ fire_time += jitter
+
+ schedule.last_fire_time = fire_time
+ job = Job(
+ task_id=schedule.task_id,
+ args=schedule.args,
+ kwargs=schedule.kwargs,
+ schedule_id=schedule.id,
+ scheduled_fire_time=fire_time,
+ jitter=jitter,
+ start_deadline=schedule.next_deadline,
+ tags=schedule.tags,
+ )
+ await self.data_store.add_job(job)
+
+ # Update the schedules (and release the scheduler's claim on them)
+ await self.data_store.release_schedules(self.identity, schedules)
+
+ # If we received fewer schedules than the maximum amount, sleep
+ # until the next schedule is due or the scheduler is explicitly
+ # woken up
+ wait_time = None
+ if len(schedules) < 100:
+ wakeup_deadline = await self.data_store.get_next_schedule_run_time()
+ if wakeup_deadline:
+ wait_time = (
+ wakeup_deadline - datetime.now(timezone.utc)
+ ).total_seconds()
+ self.logger.debug(
+ "Sleeping %.3f seconds until the next fire time (%s)",
+ wait_time,
+ wakeup_deadline,
+ )
+ else:
+ self.logger.debug("Waiting for any due schedules to appear")
+
+ with move_on_after(wait_time):
+ await wakeup_event.wait()
+ wakeup_event = anyio.Event()
+ else:
+ self.logger.debug("Processing more schedules on the next iteration")
+
+ async def _process_jobs(self, *, task_status: TaskStatus) -> None:
+ wakeup_event = anyio.Event()
+
+ async def job_added(event: Event) -> None:
+ if len(self._running_jobs) < self.max_concurrent_jobs:
+ wakeup_event.set()
+
+ async with AsyncExitStack() as exit_stack:
+ # Start the job executors
+ for job_executor in self.job_executors.values():
+ await job_executor.start(exit_stack)
+
+ task_group = await exit_stack.enter_async_context(create_task_group())
+
+ # Fetch new jobs every time
+ exit_stack.enter_context(self.event_broker.subscribe(job_added, {JobAdded}))
+
+ # Signal that we are ready, and wait for the scheduler start event
+ task_status.started()
+ await self.get_next_event(SchedulerStarted)
+
+ while self._state is RunState.started:
+ limit = self.max_concurrent_jobs - len(self._running_jobs)
+ if limit > 0:
+ jobs = await self.data_store.acquire_jobs(self.identity, limit)
+ for job in jobs:
+ task = await self.data_store.get_task(job.task_id)
+ self._running_jobs.add(job.id)
+ task_group.start_soon(
+ self._run_job, job, task.func, task.executor
+ )
+
+ await wakeup_event.wait()
+ wakeup_event = anyio.Event()
+
+ async def _run_job(self, job: Job, func: Callable, executor: str) -> None:
+ try:
+ # Check if the job started before the deadline
+ start_time = datetime.now(timezone.utc)
+ if job.start_deadline is not None and start_time > job.start_deadline:
+ result = JobResult.from_job(
+ job,
+ outcome=JobOutcome.missed_start_deadline,
+ finished_at=start_time,
+ )
+ await self.data_store.release_job(self.identity, job.task_id, result)
+ await self.event_broker.publish(
+ JobReleased.from_result(result, self.identity)
+ )
+ return
+
+ try:
+ job_executor = self.job_executors[executor]
+ except KeyError:
+ return
+
+ token = current_job.set(JobInfo.from_job(job))
+ try:
+ retval = await job_executor.run_job(func, job)
+ except get_cancelled_exc_class():
+ self.logger.info("Job %s was cancelled", job.id)
+ with CancelScope(shield=True):
+ result = JobResult.from_job(
+ job,
+ outcome=JobOutcome.cancelled,
+ )
+ await self.data_store.release_job(
+ self.identity, job.task_id, result
+ )
+ await self.event_broker.publish(
+ JobReleased.from_result(result, self.identity)
+ )
+ except BaseException as exc:
+ if isinstance(exc, Exception):
+ self.logger.exception("Job %s raised an exception", job.id)
+ else:
+ self.logger.error(
+ "Job %s was aborted due to %s", job.id, exc.__class__.__name__
+ )
+
+ result = JobResult.from_job(
+ job,
+ JobOutcome.error,
+ exception=exc,
+ )
+ await self.data_store.release_job(
+ self.identity,
+ job.task_id,
+ result,
+ )
+ await self.event_broker.publish(
+ JobReleased.from_result(result, self.identity)
+ )
+ if not isinstance(exc, Exception):
+ raise
+ else:
+ self.logger.info("Job %s completed successfully", job.id)
+ result = JobResult.from_job(
+ job,
+ JobOutcome.success,
+ return_value=retval,
+ )
+ await self.data_store.release_job(self.identity, job.task_id, result)
+ await self.event_broker.publish(
+ JobReleased.from_result(result, self.identity)
+ )
+ finally:
+ current_job.reset(token)
+ finally:
+ self._running_jobs.remove(job.id)