author | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-02-15 00:32:33 +0200
---|---|---
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-02-24 23:44:32 +0200
commit | f4e8c3a0242b082fa1ca6ed5c78094f8de5ba439 (patch) |
tree | 20573886078f60b59ba2811b3fc156c877ff386d |
parent | 5bdf7bd34300ec764b1c5c979deec1e468b0b822 (diff) |
download | apscheduler-f4e8c3a0242b082fa1ca6ed5c78094f8de5ba439.tar.gz |
Implemented data store sharing and proper async support
32 files changed, 2086 insertions, 795 deletions
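This commit replaces the old executor/event-hub split with a single shareable data store abstraction. A minimal usage sketch of the reworked API, inferred from the signatures in the diff below — `IntervalTrigger` is assumed here for illustration (it is not part of this diff, and its exact behaviour at this commit may differ):

```python
import anyio

from apscheduler.datastores.memory import MemoryDataStore
from apscheduler.policies import ConflictPolicy
from apscheduler.schedulers.async_ import AsyncScheduler
from apscheduler.triggers.interval import IntervalTrigger


def tick() -> None:
    print('tick')


async def main() -> None:
    # The scheduler now wraps a data store and, with start_worker=True (the
    # default), also runs a built-in AsyncWorker that executes queued jobs
    async with AsyncScheduler(MemoryDataStore()) as scheduler:
        # add_schedule() now takes an explicit ConflictPolicy instead of
        # silently replacing an existing schedule with the same ID
        await scheduler.add_schedule(tick, IntervalTrigger(seconds=1),
                                     id='tick',
                                     conflict_policy=ConflictPolicy.exception)
        await anyio.sleep(5)


anyio.run(main)
```

Because the store, not the scheduler, now owns the schedules and jobs, several schedulers and workers can point at the same persistent store; the new PostgreSQL and MongoDB stores in this diff exist for exactly that purpose.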
diff --git a/.github/workflows/codeqa-test.yml b/.github/workflows/codeqa-test.yml index 35b018f..8a87f9e 100644 --- a/.github/workflows/codeqa-test.yml +++ b/.github/workflows/codeqa-test.yml @@ -47,6 +47,6 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install the project and its dependencies - run: pip install .[test,cbor] + run: pip install -e .[test,cbor] - name: Test with pytest run: pytest diff --git a/apscheduler/abc.py b/apscheduler/abc.py index 4bab22b..0acf7fd 100644 --- a/apscheduler/abc.py +++ b/apscheduler/abc.py @@ -3,11 +3,10 @@ from base64 import b64decode, b64encode from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import ( - Any, AsyncContextManager, Callable, Dict, FrozenSet, Iterable, Iterator, List, Mapping, - Optional, Set) -from uuid import uuid4 + Any, Callable, Dict, FrozenSet, Iterable, Iterator, List, Optional, Set, Type) +from uuid import UUID, uuid4 -from apscheduler.events import Event +from .policies import CoalescePolicy, ConflictPolicy class Trigger(Iterator[datetime], metaclass=ABCMeta): @@ -60,7 +59,7 @@ class Schedule: trigger: Trigger = field(compare=False) args: tuple = field(compare=False) kwargs: Dict[str, Any] = field(compare=False) - coalesce: bool = field(compare=False) + coalesce: CoalescePolicy = field(compare=False) misfire_grace_time: Optional[timedelta] = field(compare=False) tags: FrozenSet[str] = field(compare=False) next_fire_time: Optional[datetime] = field(compare=False, default=None) @@ -76,13 +75,13 @@ class Schedule: @dataclass(unsafe_hash=True) class Job: - id: str = field(init=False, default_factory=lambda: str(uuid4())) + id: UUID = field(init=False, default_factory=uuid4) task_id: str = field(compare=False) func: Callable = field(compare=False) args: tuple = field(compare=False) kwargs: Dict[str, Any] = field(compare=False) schedule_id: Optional[str] = field(compare=False, default=None) - scheduled_start_time: Optional[datetime] = field(compare=False, default=None) + scheduled_fire_time: Optional[datetime] = field(compare=False, default=None) start_deadline: Optional[datetime] = field(compare=False, default=None) tags: Optional[FrozenSet[str]] = field(compare=False, default_factory=frozenset) started_at: Optional[datetime] = field(init=False, compare=False, default=None) @@ -106,22 +105,44 @@ class Serializer(metaclass=ABCMeta): return self.deserialize(b64decode(serialized)) -class EventSource(metaclass=ABCMeta): - __slots__ = () +@dataclass(frozen=True) +class Event: + timestamp: datetime + +@dataclass +class EventSource(metaclass=ABCMeta): @abstractmethod - async def subscribe(self, callback: Callable[[Event], Any]) -> None: + def subscribe(self, callback: Callable[[Event], Any], + event_types: Optional[Iterable[Type[Event]]] = None) -> None: """ Subscribe to events from this event source. :param callback: callable to be called with the event object when an event is published - :return: an async iterable yielding event objects + :param event_types: an iterable of concrete Event classes to subscribe to """ + @abstractmethod + def unsubscribe(self, callback: Callable[[Event], Any], + event_types: Optional[Iterable[Type[Event]]] = None) -> None: + """ + Cancel an event subscription -class DataStore(EventSource): - __slots__ = () + :param callback: + :param event_types: an iterable of concrete Event classes to unsubscribe from + :return: + """ + @abstractmethod + async def publish(self, event: Event) -> None: + """ + Publish an event. 
+ + :param event: the event to publish + """ + + +class DataStore(EventSource): async def __aenter__(self): return self @@ -138,112 +159,80 @@ class DataStore(EventSource): """ @abstractmethod - async def add_or_replace_schedules(self, schedules: Iterable[Schedule]) -> None: + async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: """ Add or update the given schedule in the data store. - :param schedules: schedules to be added or updated + :param schedule: schedule to be added + :param conflict_policy: policy that determines what to do if there is an existing schedule + with the same ID """ @abstractmethod - async def update_schedules(self, updates: Mapping[str, Dict[str, Any]]) -> Set[str]: - """ - Update one or more existing schedules. - - :param updates: mapping of schedule ID to attribute names to be updated - :return: the set of schedule IDs that were modified by this operation. - """ - - @abstractmethod - async def remove_schedules(self, ids: Optional[Set[str]] = None) -> None: + async def remove_schedules(self, ids: Iterable[str]) -> None: """ Remove schedules from the data store. - :param ids: a specific set of schedule IDs to remove, or ``None`` in which case all - schedules are removed + :param ids: a specific set of schedule IDs to remove """ @abstractmethod - async def get_next_fire_time(self) -> Optional[datetime]: + async def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]: """ - Return the earliest fire time among all unclaimed schedules. + Acquire unclaimed due schedules for processing. - If no running, unclaimed schedules exist, ``None`` is returned. + This method claims up to the requested number of schedules for the given scheduler and + returns them. + + :param scheduler_id: unique identifier of the scheduler + :param limit: maximum number of schedules to claim + :return: the list of claimed schedules """ @abstractmethod - async def acquire_due_schedules( - self, scheduler_id: str, - max_scheduled_time: datetime) -> AsyncContextManager[List[Schedule]]: + async def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None: """ - Acquire an undefined amount of due schedules not claimed by any other scheduler. + Release the claims on the given schedules and update them on the store. - This method claims due schedules for the given scheduler and returns them. - When the scheduler has updated the objects, it calls :meth:`release_due_schedules` to - release the claim on them. + :param scheduler_id: unique identifier of the scheduler + :param schedules: the previously claimed schedules """ - # @abstractmethod - # async def release_due_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None: - # """ - # Update the given schedules and release the claim on them held by this scheduler. - # - # This method should do the following: - # - # #. Remove any of the schedules in the store that have no next fire time - # #. Update the schedules that do have a next fire time - # #. Release any locks held on the schedules by this scheduler - # - # :param scheduler_id: identifier of the scheduler - # :param schedules: schedules previously acquired using :meth:`acquire_due_schedules` - # """ - - # @abstractmethod - # async def acquire_job(self, worker_id: str, tags: Set[str]) -> Job: - # """ - # Claim and return the next matching job from the queue. 
- # - # :return: the acquired job - # """ - # - # @abstractmethod - # async def release_job(self, job: Job) -> None: - # """Remove the given job from the queue.""" - - -class Executor(EventSource): @abstractmethod - async def submit_job(self, job: Job) -> None: + async def add_job(self, job: Job) -> None: """ - Submit a task to be run in this executor. - - The executor may alter the ``id`` attribute of the job before returning. + Add a job to be executed by an eligible worker. :param job: the job object """ @abstractmethod - async def get_jobs(self) -> List[Job]: + async def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]: """ - Get jobs currently queued or running in this executor. + Get the list of pending jobs. - :return: list of jobs + :param ids: a specific set of job IDs to return, or ``None`` to return all jobs + :return: the list of matching pending jobs, in the order they will be given to workers """ + @abstractmethod + async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: + """ + Acquire unclaimed jobs for execution. -class EventHub(metaclass=ABCMeta): - __slots__ = () - - async def start(self) -> None: - pass + This method claims up to the requested number of jobs for the given worker and returns + them. - async def stop(self) -> None: - pass + :param worker_id: unique identifier of the worker + :param limit: maximum number of jobs to claim and return + :return: the list of claimed jobs + """ @abstractmethod - async def publish(self, event: Event) -> None: - """Publish an event.""" + async def release_jobs(self, worker_id: str, jobs: List[Job]) -> None: + """ + Releases the claim on the given jobs - @abstractmethod - async def subscribe(self, callback: Callable[[Event], Any]) -> None: - """Add a callback to be called when a new event is published.""" + :param worker_id: unique identifier of the worker + :param jobs: the previously claimed jobs + """ diff --git a/apscheduler/datastores/memory.py b/apscheduler/datastores/memory.py index 3ad76c4..826981a 100644 --- a/apscheduler/datastores/memory.py +++ b/apscheduler/datastores/memory.py @@ -1,120 +1,274 @@ -from contextlib import asynccontextmanager +from bisect import bisect_left, insort_right +from collections import defaultdict from dataclasses import dataclass, field from datetime import MAXYEAR, datetime, timezone -from typing import Any, AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set +from functools import partial +from typing import Dict, Iterable, List, Optional, Set +from uuid import UUID -from apscheduler.abc import DataStore, Event, EventHub, Schedule -from apscheduler.eventhubs.local import LocalEventHub -from apscheduler.events import DataStoreEvent, SchedulesAdded, SchedulesRemoved, SchedulesUpdated -from sortedcontainers import SortedSet # type: ignore +from anyio import create_event, move_on_after +from anyio.abc import Event + +from ..abc import DataStore, Job, Schedule +from ..events import EventHub, ScheduleAdded, ScheduleEvent, ScheduleRemoved, ScheduleUpdated +from ..exceptions import ConflictingIdError +from ..policies import ConflictPolicy max_datetime = datetime(MAXYEAR, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc) -@dataclass(frozen=True, repr=False) -class MemoryScheduleStore(DataStore): - _event_hub: EventHub = field(init=False, default_factory=LocalEventHub) - _schedules: Set[Schedule] = field( - init=False, - default_factory=lambda: SortedSet(key=lambda x: x.next_fire_time or max_datetime)) - _schedules_by_id: Dict[str, 
Schedule] = field(default_factory=dict) +class ScheduleState: + __slots__ = 'schedule', 'next_fire_time', 'acquired_by', 'acquired_until' + + def __init__(self, schedule: Schedule): + self.schedule = schedule + self.next_fire_time = self.schedule.next_fire_time + self.acquired_by: Optional[str] = None + self.acquired_until: Optional[datetime] = None + + def __lt__(self, other): + if self.next_fire_time is None: + return False + elif other.next_fire_time is None: + return self.next_fire_time is not None + else: + return self.next_fire_time < other.next_fire_time + + def __eq__(self, other): + return self.schedule == other.schedule + + def __hash__(self): + return hash(self.schedule) + + def __repr__(self): + return (f'<{self.__class__.__name__} id={self.schedule.id!r} ' + f'next_fire_time={self.next_fire_time}>') + + +class JobState: + __slots__ = 'job', 'created_at', 'acquired_by', 'acquired_until' + + def __init__(self, job: Job): + self.job = job + self.created_at = datetime.now(timezone.utc) + self.acquired_by: Optional[str] = None + self.acquired_until: Optional[datetime] = None + + def __lt__(self, other): + return self.created_at < other.created_at + + def __eq__(self, other): + return self.job.id == other.job.id + + def __hash__(self): + return hash(self.job.id) + + def __repr__(self): + return f'<{self.__class__.__name__} id={self.job.id!r} task_id={self.job.task_id!r}' + + +@dataclass(repr=False) +class MemoryDataStore(DataStore, EventHub): + lock_expiration_delay: float = 30 + _schedules: List[ScheduleState] = field(init=False, default_factory=list) + _schedules_by_id: Dict[str, ScheduleState] = field(init=False, default_factory=dict) + _schedules_by_task_id: Dict[str, Set[ScheduleState]] = field( + init=False, default_factory=partial(defaultdict, set)) + _schedules_event: Optional[Event] = field(init=False, default=None) + _jobs: List[JobState] = field(init=False, default_factory=list) + _jobs_by_id: Dict[UUID, JobState] = field(init=False, default_factory=dict) + _jobs_by_task_id: Dict[str, Set[JobState]] = field(init=False, + default_factory=partial(defaultdict, set)) + _jobs_event: Optional[Event] = field(init=False, default=None) + _loans: int = 0 + + def _find_schedule_index(self, state: ScheduleState) -> Optional[int]: + left_index = bisect_left(self._schedules, state) + right_index = bisect_left(self._schedules, state) + return self._schedules.index(state, left_index, right_index + 1) + + def _find_job_index(self, state: JobState) -> Optional[int]: + left_index = bisect_left(self._jobs, state) + right_index = bisect_left(self._jobs, state) + return self._jobs.index(state, left_index, right_index + 1) + + async def __aenter__(self): + self._loans += 1 + if self._loans == 1: + self._schedules_event = create_event() + self._jobs_event = create_event() + + return await super().__aenter__() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self._loans -= 1 + if self._loans == 0: + self._schedules_event = None + self._jobs_event = None + + return await super().__aexit__(exc_type, exc_val, exc_tb) + + async def clear(self) -> None: + self._schedules.clear() + self._schedules_by_id.clear() + self._schedules_by_task_id.clear() + self._jobs.clear() + self._jobs_by_id.clear() + self._jobs_by_task_id.clear() async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: - return [sched for sched in self._schedules if ids is None or sched.id in ids] - - async def add_or_replace_schedules(self, schedules: Iterable[Schedule]) -> None: - added_schedule_ids = 
set() - updated_schedule_ids = set() - for schedule in schedules: - if schedule.id in self._schedules_by_id: - updated_schedule_ids.add(schedule.id) - old_schedule = self._schedules_by_id[schedule.id] - self._schedules.remove(old_schedule) - else: - added_schedule_ids.add(schedule.id) - - self._schedules_by_id[schedule.id] = schedule - self._schedules.add(schedule) - - event: DataStoreEvent - if added_schedule_ids: - earliest_fire_time = min( - (schedule.next_fire_time for schedule in schedules - if schedule.id in added_schedule_ids and schedule.next_fire_time), - default=None) - event = SchedulesAdded(datetime.now(timezone.utc), added_schedule_ids, - earliest_fire_time) - await self._event_hub.publish(event) - - if updated_schedule_ids: - earliest_fire_time = min( - (schedule.next_fire_time for schedule in schedules - if schedule.id in updated_schedule_ids and schedule.next_fire_time), - default=None) - event = SchedulesUpdated(datetime.now(timezone.utc), updated_schedule_ids, - earliest_fire_time) - await self._event_hub.publish(event) - - async def update_schedules(self, updates: Mapping[str, Dict[str, Any]]) -> Set[str]: - updated_schedule_ids: Set[str] = set() - earliest_fire_time: Optional[datetime] = None - for schedule_id, changes in updates.items(): - try: - schedule = self._schedules_by_id[schedule_id] - except KeyError: - continue - - if changes: - updated_schedule_ids.add(schedule.id) - for key, value in changes.items(): - setattr(schedule, key, value) - if key == 'next_fire_time': - if not earliest_fire_time: - earliest_fire_time = schedule.next_fire_time - elif (schedule.next_fire_time - and schedule.next_fire_time < earliest_fire_time): - earliest_fire_time = schedule.next_fire_time - - event = SchedulesUpdated(datetime.now(timezone.utc), updated_schedule_ids, - earliest_fire_time) - await self._event_hub.publish(event) - - return updated_schedule_ids - - async def remove_schedules(self, ids: Optional[Set[str]] = None) -> None: - if ids is None: - removed_schedule_ids = set(self._schedules_by_id) - self._schedules_by_id.clear() - self._schedules.clear() + return [state.schedule for state in self._schedules + if ids is None or state.schedule.id in ids] + + async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + old_state = self._schedules_by_id.get(schedule.id) + if old_state is not None: + if conflict_policy is ConflictPolicy.do_nothing: + return + elif conflict_policy is ConflictPolicy.exception: + raise ConflictingIdError(schedule.id) + + index = self._find_schedule_index(old_state) + del self._schedules[index] + self._schedules_by_task_id[old_state.schedule.task_id].remove(old_state) + + state = ScheduleState(schedule) + self._schedules_by_id[schedule.id] = state + self._schedules_by_task_id[schedule.task_id].add(state) + insort_right(self._schedules, state) + + if old_state is not None: + event = ScheduleUpdated(datetime.now(timezone.utc), schedule.id, + schedule.next_fire_time) else: - removed_schedule_ids = set() - for schedule_id in ids: - schedule = self._schedules_by_id.pop(schedule_id, None) - if schedule: - removed_schedule_ids.add(schedule.id) - self._schedules.remove(schedule) + event = ScheduleAdded(datetime.now(timezone.utc), schedule.id, + schedule.next_fire_time) + + await self.publish(event) + + old_event, self._schedules_event = self._schedules_event, create_event() + await old_event.set() + + async def remove_schedules(self, ids: Iterable[str]) -> None: + events: List[ScheduleRemoved] = [] + for schedule_id in ids: + 
state = self._schedules_by_id.pop(schedule_id, None) + if state: + self._schedules.remove(state) + events.append(ScheduleRemoved(datetime.now(timezone.utc), state.schedule.id)) + + for event in events: + await self.publish(event) + + async def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]: + while True: + now = datetime.now(timezone.utc) + schedules: List[Schedule] = [] + wait_time = None + for state in self._schedules: + if state.acquired_by is not None and state.acquired_until >= now: + continue + elif state.next_fire_time is None: + break + elif state.next_fire_time > now: + wait_time = state.next_fire_time.timestamp() - now.timestamp() + break + + schedules.append(state.schedule) + state.acquired_by = scheduler_id + state.acquired_until = datetime.fromtimestamp( + now.timestamp() + self.lock_expiration_delay, now.tzinfo) + if len(schedules) == limit: + break + + if schedules: + return schedules + + # Wait until reaching the next known fire time, or a schedule is added or updated + async with move_on_after(wait_time): + await self._schedules_event.wait() + + async def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None: + # Send update events for schedules that have a next time + now = datetime.now(timezone.utc) + update_events: List[ScheduleEvent] = [] + finished_schedule_ids: List[str] = [] + for s in schedules: + if s.next_fire_time is not None: + # Remove the schedule + schedule_state = self._schedules_by_id.get(s.id) + index = self._find_schedule_index(schedule_state) + del self._schedules[index] + + # Readd the schedule to its new position + schedule_state.next_fire_time = s.next_fire_time + schedule_state.acquired_by = None + schedule_state.acquired_until = None + insort_right(self._schedules, schedule_state) + update_events.append(ScheduleUpdated(now, s.id, s.next_fire_time)) + else: + finished_schedule_ids.append(s.id) + + # if updated_ids: + # next_fire_time = min(s.next_fire_time for s in schedules + # if s.next_fire_time is not None) + # event = ScheduleUpdated(datetime.now(timezone.utc), updated_ids, next_fire_time) + # await self.publish(event) + # old_event, self._schedules_event = self._schedules_event, create_event() + # await old_event.set() + + for event in update_events: + await self.publish(event) + + # Remove schedules that didn't get a new next fire time + await self.remove_schedules(finished_schedule_ids) + + # async def get_next_fire_time(self) -> Optional[datetime]: + # for schedule in self._schedules: + # return schedule.next_fire_time + # + # return None + + async def add_job(self, job: Job) -> None: + state = JobState(job) + self._jobs.append(state) + self._jobs_by_id[job.id] = state + self._jobs_by_task_id[job.task_id].add(state) + + old_event, self._jobs_event = self._jobs_event, create_event() + await old_event.set() - event = SchedulesRemoved(datetime.now(timezone.utc), removed_schedule_ids) - await self._event_hub.publish(event) + async def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]: + if ids is not None: + ids = frozenset(ids) - @asynccontextmanager - async def acquire_due_schedules( - self, scheduler_id: str, - max_scheduled_time: datetime) -> AsyncGenerator[List[Schedule], None]: - pending: List[Schedule] = [] - for schedule in self._schedules: - if not schedule.next_fire_time or schedule.next_fire_time > max_scheduled_time: - break + return [state.job for state in self._jobs if ids is None or state.job.id in ids] - pending.append(schedule) + async def acquire_jobs(self, 
worker_id: str, limit: Optional[int] = None) -> List[Job]: + while True: + now = datetime.now(timezone.utc) + jobs: List[Job] = [] + for state in self._jobs: + if state.acquired_by is not None and state.acquired_until >= now: + continue - yield pending + jobs.append(state.job) + state.acquired_by = worker_id + state.acquired_until = datetime.fromtimestamp( + now.timestamp() + self.lock_expiration_delay, now.tzinfo) + if len(jobs) == limit: + break - async def get_next_fire_time(self) -> Optional[datetime]: - for schedule in self._schedules: - return schedule.next_fire_time + if jobs: + return jobs - return None + # Wait until a new job is added + await self._jobs_event.wait() - async def subscribe(self, callback: Callable[[Event], Any]) -> None: - await self._event_hub.subscribe(callback) + async def release_jobs(self, worker_id: str, jobs: List[Job]) -> None: + assert self._jobs + indexes = sorted((self._find_job_index(self._jobs_by_id[j.id]) for j in jobs), + reverse=True) + for i in indexes: + state = self._jobs.pop(i) + self._jobs_by_task_id[state.job.task_id].remove(state) diff --git a/apscheduler/datastores/mongodb.py b/apscheduler/datastores/mongodb.py new file mode 100644 index 0000000..f547c90 --- /dev/null +++ b/apscheduler/datastores/mongodb.py @@ -0,0 +1,272 @@ +import asyncio +import logging +from datetime import datetime, timezone +from typing import Iterable, List, Optional, Set, Tuple + +import sniffio +from anyio import create_task_group, move_on_after +from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorCollection +from pymongo import ASCENDING, DeleteOne, UpdateOne +from pymongo.collection import Collection +from pymongo.errors import DuplicateKeyError + +from ..abc import DataStore, Job, Schedule, Serializer +from ..events import EventHub, ScheduleAdded, ScheduleRemoved, ScheduleUpdated +from ..exceptions import ConflictingIdError, DeserializationError, SerializationError +from ..policies import ConflictPolicy +from ..serializers.pickle import PickleSerializer + + +class MongoDBDataStore(DataStore, EventHub): + def __init__(self, client: AsyncIOMotorClient, *, serializer: Optional[Serializer] = None, + database: str = 'apscheduler', schedules_collection: str = 'schedules', + jobs_collection: str = 'jobs', lock_expiration_delay: float = 30, + max_poll_time: Optional[float] = 1, start_from_scratch: bool = False): + super().__init__() + if not client.delegate.codec_options.tz_aware: + raise ValueError('MongoDB client must have tz_aware set to True') + + self.client = client + self.serializer = serializer or PickleSerializer() + self.lock_expiration_delay = lock_expiration_delay + self.max_poll_time = max_poll_time + self.start_from_scratch = start_from_scratch + self._database = client[database] + self._schedules: AsyncIOMotorCollection = self._database[schedules_collection] + self._jobs: AsyncIOMotorCollection = self._database[jobs_collection] + self._logger = logging.getLogger(__name__) + self._watch_collections = hasattr(self.client, '_topology_settings') + self._loans = 0 + + async def __aenter__(self): + asynclib = sniffio.current_async_library() or '(unknown)' + if asynclib != 'asyncio': + raise RuntimeError(f'This data store requires asyncio; currently running: {asynclib}') + + if self._loans == 0: + self._schedules_event = asyncio.Event() + self._jobs_event = asyncio.Event() + await self._setup() + + self._loans += 1 + if self._loans == 1 and self._watch_collections: + self._task_group = create_task_group() + await self._task_group.__aenter__() + await 
self._task_group.spawn(self._watch_collection, self._schedules) + await self._task_group.spawn(self._watch_collection, self._jobs) + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + assert self._loans + self._loans -= 1 + if self._loans == 0 and self._watch_collections: + await self._task_group.cancel_scope.cancel() + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + async def _watch_collection(self, collection: Collection) -> None: + pipeline = [{'$match': {'operationType': {'$in': ['insert', 'update', 'replace']}}}, + {'$project': ['next_fire_time']}] + async with collection.watch(pipeline, batch_size=1, + max_await_time_ms=self.max_poll_time) as stream: + await stream.next() + + async def _setup(self) -> None: + server_info = await self.client.server_info() + if server_info['versionArray'] < [4, 0]: + raise RuntimeError(f"MongoDB server must be at least v4.0; current version = " + f"{server_info['version']}") + + if self.start_from_scratch: + await self._schedules.delete_many({}) + await self._jobs.delete_many({}) + + await self._schedules.create_index('next_fire_time') + await self._jobs.create_index('task_id') + await self._jobs.create_index('created_at') + await self._jobs.create_index('tags') + + async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: + schedules: List[Schedule] = [] + filters = {'_id': {'$in': list(ids)}} if ids is not None else {} + cursor = self._schedules.find(filters, projection=['_id', 'serialized_data']).sort('_id') + for document in await cursor.to_list(None): + try: + schedule = self.serializer.deserialize(document['serialized_data']) + except DeserializationError: + self._logger.warning('Failed to deserialize schedule %r', document['_id']) + continue + + schedules.append(schedule) + + return schedules + + async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + serialized_data = self.serializer.serialize(schedule) + document = { + '_id': schedule.id, + 'task_id': schedule.task_id, + 'serialized_data': serialized_data, + 'next_fire_time': schedule.next_fire_time + } + try: + await self._schedules.insert_one(document) + except DuplicateKeyError: + if conflict_policy is ConflictPolicy.exception: + raise ConflictingIdError(schedule.id) from None + elif conflict_policy is ConflictPolicy.replace: + await self._schedules.replace_one({'_id': schedule.id}, document, True) + await self.publish( + ScheduleUpdated(datetime.now(timezone.utc), schedule.id, + schedule.next_fire_time) + ) + else: + await self.publish( + ScheduleAdded(datetime.now(timezone.utc), schedule.id, schedule.next_fire_time) + ) + + self._schedules_event.set() + + async def remove_schedules(self, ids: Iterable[str]) -> None: + async with await self.client.start_session() as s, s.start_transaction(): + filters = {'_id': {'$in': list(ids)}} if ids is not None else {} + cursor = self._schedules.find(filters, projection=['_id']) + ids = [doc['_id'] for doc in await cursor.to_list(None)] + if ids: + await self._schedules.delete_many(filters) + + now = datetime.now(timezone.utc) + for schedule_id in ids: + await self.publish(ScheduleRemoved(now, schedule_id)) + + async def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]: + while True: + schedules: List[Schedule] = [] + async with await self.client.start_session() as s, s.start_transaction(): + cursor = self._schedules.find( + {'next_fire_time': {'$ne': None}, + '$or': [{'acquired_until': {'$exists': False}}, + {'acquired_until': 
{'$lt': datetime.now(timezone.utc)}}] + }, + projection=['serialized_data'] + ).sort('next_fire_time') + for document in await cursor.to_list(length=limit): + schedule = self.serializer.deserialize(document['serialized_data']) + schedules.append(schedule) + + if schedules: + now = datetime.now(timezone.utc) + acquired_until = datetime.fromtimestamp( + now.timestamp() + self.lock_expiration_delay, now.tzinfo) + filters = {'_id': {'$in': [schedule.id for schedule in schedules]}} + update = {'$set': {'acquired_by': scheduler_id, + 'acquired_until': acquired_until}} + await self._schedules.update_many(filters, update) + return schedules + + async with move_on_after(self.max_poll_time): + await self._schedules_event.wait() + + async def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None: + updated_schedules: List[Tuple[str, datetime]] = [] + finished_schedule_ids: List[str] = [] + async with await self.client.start_session() as s, s.start_transaction(): + # Update schedules that have a next fire time + requests = [] + for schedule in schedules: + filters = {'_id': schedule.id, 'acquired_by': scheduler_id} + if schedule.next_fire_time is not None: + try: + serialized_data = self.serializer.serialize(schedule) + except SerializationError: + self._logger.exception('Error serializing schedule %r – ' + 'removing from data store', schedule.id) + requests.append(DeleteOne(filters)) + finished_schedule_ids.append(schedule.id) + continue + + update = { + '$unset': { + 'acquired_by': True, + 'acquired_until': True, + }, + '$set': { + 'next_fire_time': schedule.next_fire_time, + 'serialized_data': serialized_data + } + } + requests.append(UpdateOne(filters, update)) + updated_schedules.append((schedule.id, schedule.next_fire_time)) + else: + requests.append(DeleteOne(filters)) + finished_schedule_ids.append(schedule.id) + + if requests: + await self._schedules.bulk_write(requests, ordered=False) + now = datetime.now(timezone.utc) + for schedule_id, next_fire_time in updated_schedules: + await self.publish(ScheduleUpdated(now, schedule_id, next_fire_time)) + + now = datetime.now(timezone.utc) + for schedule_id in finished_schedule_ids: + await self.publish(ScheduleRemoved(now, schedule_id)) + + async def add_job(self, job: Job) -> None: + serialized_data = self.serializer.serialize(job) + document = { + '_id': job.id, + 'serialized_data': serialized_data, + 'task_id': job.task_id, + 'created_at': datetime.now(timezone.utc), + 'tags': list(job.tags) + } + await self._jobs.insert_one(document) + self._jobs_event.set() + + async def get_jobs(self, ids: Optional[Iterable[str]] = None) -> List[Job]: + jobs: List[Job] = [] + filters = {'_id': {'$in': list(ids)}} if ids is not None else {} + cursor = self._jobs.find(filters, projection=['_id', 'serialized_data']).sort('_id') + for document in await cursor.to_list(None): + try: + job = self.serializer.deserialize(document['serialized_data']) + except DeserializationError: + self._logger.warning('Failed to deserialize job %r', document['_id']) + continue + + jobs.append(job) + + return jobs + + async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: + jobs: List[Job] = [] + while True: + async with await self.client.start_session() as s, s.start_transaction(): + cursor = self._jobs.find( + {'$or': [{'acquired_until': {'$exists': False}}, + {'acquired_until': {'$lt': datetime.now(timezone.utc)}}] + }, + projection=['serialized_data'], + sort=[('created_at', ASCENDING)] + ) + for document in await 
cursor.to_list(length=limit): + job = self.serializer.deserialize(document['serialized_data']) + jobs.append(job) + + if jobs: + now = datetime.now(timezone.utc) + acquired_until = datetime.fromtimestamp( + now.timestamp() + self.lock_expiration_delay, timezone.utc) + filters = {'_id': {'$in': [job.id for job in jobs]}} + update = {'$set': {'acquired_by': worker_id, + 'acquired_until': acquired_until}} + await self._jobs.update_many(filters, update) + return jobs + + async with move_on_after(self.max_poll_time): + await self._jobs_event.wait() + self._jobs_event.clear() + + async def release_jobs(self, worker_id: str, jobs: List[Job]) -> None: + filters = {'_id': {'$in': [job.id for job in jobs]}, 'acquired_by': worker_id} + await self._jobs.delete_many(filters) diff --git a/apscheduler/datastores/postgresql.py b/apscheduler/datastores/postgresql.py new file mode 100644 index 0000000..570badb --- /dev/null +++ b/apscheduler/datastores/postgresql.py @@ -0,0 +1,331 @@ +import asyncio +import logging +from asyncio import current_task +from datetime import datetime, timezone +from typing import Iterable, List, Optional, Set +from uuid import UUID + +import sniffio +from anyio import create_task_group, move_on_after, sleep +from anyio.abc import TaskGroup +from asyncpg import UniqueViolationError +from asyncpg.pool import Pool + +from ..abc import DataStore, Job, Schedule, Serializer +from ..events import ( + EventHub, JobAdded, ScheduleAdded, ScheduleEvent, ScheduleRemoved, + ScheduleUpdated) +from ..exceptions import ConflictingIdError, SerializationError +from ..policies import ConflictPolicy +from ..serializers.pickle import PickleSerializer + +logger = logging.getLogger(__name__) + + +class PostgresqlDataStore(DataStore, EventHub): + _task_group: TaskGroup + _schedules_event: Optional[asyncio.Event] = None + _jobs_event: Optional[asyncio.Event] = None + + def __init__(self, pool: Pool, *, schema: str = 'public', + notify_channel: Optional[str] = 'apscheduler', + serializer: Optional[Serializer] = None, + lock_expiration_delay: float = 30, max_poll_time: Optional[float] = 1, + max_idle_time: float = 60, start_from_scratch: bool = False): + super().__init__() + self.pool = pool + self.schema = schema + self.notify_channel = notify_channel + self.serializer = serializer or PickleSerializer() + self.lock_expiration_delay = lock_expiration_delay + self.max_poll_time = max_poll_time + self.max_idle_time = max_idle_time + self.start_from_scratch = start_from_scratch + self._logger = logging.getLogger(__name__) + self._loans = 0 + + async def __aenter__(self): + asynclib = sniffio.current_async_library() or '(unknown)' + if asynclib != 'asyncio': + raise RuntimeError(f'This data store requires asyncio; currently running: {asynclib}') + + if self._loans == 0: + self._schedules_event = asyncio.Event() + self._jobs_event = asyncio.Event() + await self._setup() + + self._loans += 1 + if self._loans == 1 and self.notify_channel: + print('entering postgresql data store in task', id(current_task())) + self._task_group = create_task_group() + await self._task_group.__aenter__() + await self._task_group.spawn(self._listen_notifications) + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + assert self._loans + self._loans -= 1 + if self._loans == 0 and self.notify_channel: + print('exiting postgresql data store in task', id(current_task())) + await self._task_group.cancel_scope.cancel() + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + del self._schedules_event + del 
self._jobs_event + + async def _listen_notifications(self) -> None: + def callback(connection, pid, channel: str, payload: str) -> None: + self._logger.debug('Received notification on channel %s: %s', channel, payload) + if payload == 'schedule': + self._schedules_event.set() + elif payload == 'job': + self._jobs_event.set() + + while True: + async with self.pool.acquire() as conn: + await conn.add_listener(self.notify_channel, callback) + try: + while True: + await sleep(self.max_idle_time) + await conn.execute('SELECT 1') + finally: + await conn.remove_listener(self.notify_channel, callback) + + async def _setup(self) -> None: + async with self.pool.acquire() as conn, conn.transaction(): + if self.start_from_scratch: + await conn.execute(f"DROP TABLE IF EXISTS {self.schema}.schedules") + await conn.execute(f"DROP TABLE IF EXISTS {self.schema}.jobs") + await conn.execute(f"DROP TABLE IF EXISTS {self.schema}.metadata") + + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS {self.schema}.metadata ( + schema_version INTEGER NOT NULL + ) + """) + version = await conn.fetchval(f"SELECT schema_version FROM {self.schema}.metadata") + if version is None: + await conn.execute(f"INSERT INTO {self.schema}.metadata VALUES (1)") + await conn.execute(f""" + CREATE TABLE {self.schema}.schedules ( + id TEXT PRIMARY KEY, + task_id TEXT NOT NULL, + serialized_data BYTEA NOT NULL, + next_fire_time TIMESTAMP WITH TIME ZONE, + acquired_by TEXT, + acquired_until TIMESTAMP WITH TIME ZONE + ) WITH (fillfactor = 80) + """) + await conn.execute(f"CREATE INDEX ON {self.schema}.schedules (next_fire_time)") + await conn.execute(f""" + CREATE TABLE {self.schema}.jobs ( + id UUID PRIMARY KEY, + serialized_data BYTEA NOT NULL, + task_id TEXT NOT NULL, + tags TEXT[] NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL, + acquired_by TEXT, + acquired_until TIMESTAMP WITH TIME ZONE + ) WITH (fillfactor = 80) + """) + await conn.execute(f"CREATE INDEX ON {self.schema}.jobs (task_id)") + await conn.execute(f"CREATE INDEX ON {self.schema}.jobs (tags)") + elif version > 1: + raise RuntimeError(f'Unexpected schema version ({version}); ' + f'only version 1 is supported by this version of APScheduler') + + async def clear(self) -> None: + async with self.pool.acquire() as conn, conn.transaction(): + await conn.execute(f"TRUNCATE TABLE {self.schema}.schedules") + await conn.execute(f"TRUNCATE TABLE {self.schema}.jobs") + + async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + event: Optional[ScheduleEvent] = None + serialized_data = self.serializer.serialize(schedule) + query = (f"INSERT INTO {self.schema}.schedules (id, serialized_data, task_id, " + f"next_fire_time) VALUES ($1, $2, $3, $4)") + async with self.pool.acquire() as conn: + try: + async with conn.transaction(): + await conn.execute(query, schedule.id, serialized_data, schedule.task_id, + schedule.next_fire_time) + except UniqueViolationError: + if conflict_policy is ConflictPolicy.exception: + raise ConflictingIdError(schedule.id) from None + elif conflict_policy is ConflictPolicy.replace: + query = (f"UPDATE {self.schema}.schedules SET serialized_data = $2, " + f"task_id = $3, next_fire_time = $4 WHERE id = $1") + async with conn.transaction(): + await conn.execute(query, schedule.id, serialized_data, schedule.task_id, + schedule.next_fire_time) + event = ScheduleUpdated(datetime.now(timezone.utc), schedule.id, + schedule.next_fire_time) + else: + event = ScheduleAdded(datetime.now(timezone.utc), schedule.id, + 
schedule.next_fire_time) + + if event: + await self.publish(event) + + if self.notify_channel: + await self.pool.execute(f"NOTIFY {self.notify_channel}, 'schedule'") + + async def remove_schedules(self, ids: Iterable[str]) -> None: + async with self.pool.acquire() as conn, conn.transaction(): + now = datetime.now(timezone.utc) + query = (f"DELETE FROM {self.schema}.schedules " + f"WHERE id = any($1::text[]) " + f" AND (acquired_until IS NULL OR acquired_until < $2) " + f"RETURNING id") + removed_ids = [row[0] for row in await conn.fetch(query, list(ids), now)] + + for schedule_id in removed_ids: + await self.publish(ScheduleRemoved(now, schedule_id)) + + async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: + query = f"SELECT serialized_data FROM {self.schema}.schedules" + args = () + if ids: + query += " WHERE id = any($1::text[])" + args = (ids,) + + query += " ORDER BY id" + records = await self.pool.fetch(query, *args) + return [self.serializer.deserialize(r[0]) for r in records] + + async def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]: + while True: + schedules: List[Schedule] = [] + async with self.pool.acquire() as conn, conn.transaction(): + acquired_until = datetime.fromtimestamp( + datetime.now(timezone.utc).timestamp() + self.lock_expiration_delay, + timezone.utc) + records = await conn.fetch(f""" + WITH schedule_ids AS ( + SELECT id FROM {self.schema}.schedules + WHERE next_fire_time IS NOT NULL AND next_fire_time <= $1 + AND (acquired_until IS NULL OR $1 > acquired_until) + ORDER BY next_fire_time + FOR NO KEY UPDATE SKIP LOCKED + FETCH FIRST $2 ROWS ONLY + ) + UPDATE {self.schema}.schedules SET acquired_by = $3, acquired_until = $4 + WHERE id IN (SELECT id FROM schedule_ids) + RETURNING serialized_data + """, datetime.now(timezone.utc), limit, scheduler_id, acquired_until) + + for record in records: + schedule = self.serializer.deserialize(record['serialized_data']) + schedules.append(schedule) + + if schedules: + return schedules + + async with move_on_after(self.max_poll_time): + await self._schedules_event.wait() + + async def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None: + update_events: List[ScheduleUpdated] = [] + finished_schedule_ids: List[str] = [] + async with self.pool.acquire() as conn, conn.transaction(): + update_args = [] + now = datetime.now(timezone.utc) + for schedule in schedules: + if schedule.next_fire_time is not None: + try: + serialized_data = self.serializer.serialize(schedule) + except SerializationError: + self._logger.exception('Error serializing schedule %r – ' + 'removing from data store', schedule.id) + finished_schedule_ids.append(schedule.id) + continue + + update_args.append((serialized_data, schedule.next_fire_time, schedule.id)) + update_events.append( + ScheduleUpdated(now, schedule.id, schedule.next_fire_time)) + else: + finished_schedule_ids.append(schedule.id) + + # Update schedules that have a next fire time + if update_args: + await conn.executemany( + f"UPDATE {self.schema}.schedules SET serialized_data = $1, " + f"next_fire_time = $2, acquired_by = NULL, acquired_until = NULL " + f"WHERE id = $3 AND acquired_by = {scheduler_id!r}", update_args) + + # Remove schedules that have no next fire time or failed to serialize + if finished_schedule_ids: + await conn.execute( + f"DELETE FROM {self.schema}.schedules " + f"WHERE id = any($1::text[]) AND acquired_by = $2", + list(finished_schedule_ids), scheduler_id) + + for event in update_events: + await 
self.publish(event) + + if update_events and self.notify_channel: + await self.pool.execute(f"NOTIFY {self.notify_channel}, 'schedule'") + + for schedule_id in finished_schedule_ids: + event = ScheduleRemoved(datetime.now(timezone.utc), schedule_id) + await self.publish(event) + + async def add_job(self, job: Job) -> None: + now = datetime.now(timezone.utc) + query = (f"INSERT INTO {self.schema}.jobs (id, task_id, created_at, serialized_data, " + f"tags) VALUES($1, $2, $3, $4, $5)") + await self.pool.execute(query, job.id, job.task_id, now, self.serializer.serialize(job), + job.tags) + await self.publish(JobAdded(now, job.id, job.task_id, job.schedule_id)) + + if self.notify_channel: + await self.pool.execute(f"NOTIFY {self.notify_channel}, 'job'") + + async def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]: + query = f"SELECT serialized_data FROM {self.schema}.jobs" + args = () + if ids: + query += " WHERE id = any($1::uuid[])" + args = (ids,) + + query += " ORDER BY id" + records = await self.pool.fetch(query, *args) + return [self.serializer.deserialize(r[0]) for r in records] + + async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: + while True: + print('acquiring jobs') + jobs: List[Job] = [] + async with self.pool.acquire() as conn, conn.transaction(): + now = datetime.now(timezone.utc) + acquired_until = datetime.fromtimestamp( + now.timestamp() + self.lock_expiration_delay, timezone.utc) + records = await conn.fetch(f""" + WITH job_ids AS ( + SELECT id FROM {self.schema}.jobs + WHERE acquired_until IS NULL OR acquired_until < $1 + ORDER BY created_at + FOR NO KEY UPDATE SKIP LOCKED + FETCH FIRST $2 ROWS ONLY + ) + UPDATE {self.schema}.jobs SET acquired_by = $3, acquired_until = $4 + WHERE id IN (SELECT id FROM job_ids) + RETURNING serialized_data + """, now, limit, worker_id, acquired_until) + + for record in records: + job = self.serializer.deserialize(record['serialized_data']) + jobs.append(job) + + if jobs: + return jobs + + async with move_on_after(self.max_poll_time): + await self._jobs_event.wait() + self._jobs_event.clear() + + async def release_jobs(self, worker_id: str, jobs: List[Job]) -> None: + job_ids = {j.id for j in jobs} + await self.pool.execute( + f"DELETE FROM {self.schema}.jobs WHERE acquired_by = $1 AND id = any($2::uuid[])", + worker_id, job_ids) diff --git a/apscheduler/eventhubs/__init__.py b/apscheduler/eventhubs/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/apscheduler/eventhubs/__init__.py +++ /dev/null diff --git a/apscheduler/eventhubs/local.py b/apscheduler/eventhubs/local.py deleted file mode 100644 index b429085..0000000 --- a/apscheduler/eventhubs/local.py +++ /dev/null @@ -1,25 +0,0 @@ -import logging -from dataclasses import dataclass, field -from typing import Any, Callable, Coroutine, List - -from apscheduler.abc import EventHub -from apscheduler.events import Event - -logger = logging.getLogger(__name__) - - -@dataclass -class LocalEventHub(EventHub): - _callbacks: List[Callable[[Event], Any]] = field(init=False, default_factory=list) - - async def publish(self, event: Event) -> None: - for callback in self._callbacks: - try: - retval = callback(event) - if isinstance(retval, Coroutine): - await retval - except Exception: - logger.exception('Error delivering event to callback %r', callback) - - async def subscribe(self, callback: Callable[[Event], Any]) -> None: - self._callbacks.append(callback) diff --git a/apscheduler/events.py b/apscheduler/events.py index 
d99fd9f..7a67304 100644 --- a/apscheduler/events.py +++ b/apscheduler/events.py @@ -1,72 +1,169 @@ -from dataclasses import dataclass +from collections import defaultdict +from contextlib import ExitStack, suppress +from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Optional, Set +from inspect import isawaitable, isclass +from typing import Any, Callable, Dict, Iterable, List, Optional, Type +from uuid import UUID +from warnings import warn +from anyio import run_sync_in_worker_thread, start_blocking_portal +from anyio.abc import BlockingPortal -@dataclass(frozen=True) -class Event: - timestamp: datetime +from .abc import Event, EventSource @dataclass(frozen=True) -class ExecutorEvent(Event): - job_id: str +class JobEvent(Event): + job_id: UUID task_id: str schedule_id: Optional[str] - scheduled_start_time: Optional[datetime] @dataclass(frozen=True) -class JobSubmissionFailed(ExecutorEvent): - exception: BaseException - formatted_traceback: str +class JobAdded(JobEvent): + pass @dataclass(frozen=True) -class JobAdded(ExecutorEvent): +class JobUpdated(JobEvent): pass @dataclass(frozen=True) -class JobUpdated(ExecutorEvent): +class JobExecutionEvent(JobEvent): + job_id: UUID + task_id: str + schedule_id: Optional[str] + scheduled_fire_time: Optional[datetime] + start_time: datetime + start_deadline: datetime + + +@dataclass(frozen=True) +class JobDeadlineMissed(JobExecutionEvent): pass @dataclass(frozen=True) -class JobDeadlineMissed(ExecutorEvent): - start_deadline: datetime +class JobStarted(JobExecutionEvent): + pass @dataclass(frozen=True) -class JobSuccessful(ExecutorEvent): - start_time: datetime - start_deadline: Optional[datetime] +class JobSuccessful(JobExecutionEvent): return_value: Any @dataclass(frozen=True) -class JobFailed(ExecutorEvent): - start_time: datetime - start_deadline: Optional[datetime] - formatted_traceback: str +class JobFailed(JobExecutionEvent): + traceback: str exception: Optional[BaseException] = None @dataclass(frozen=True) -class DataStoreEvent(Event): - schedule_ids: Set[str] +class ScheduleEvent(Event): + schedule_id: str @dataclass(frozen=True) -class SchedulesAdded(DataStoreEvent): - earliest_next_fire_time: Optional[datetime] +class ScheduleAdded(ScheduleEvent): + next_fire_time: Optional[datetime] @dataclass(frozen=True) -class SchedulesUpdated(DataStoreEvent): - earliest_next_fire_time: Optional[datetime] +class ScheduleUpdated(ScheduleEvent): + next_fire_time: Optional[datetime] @dataclass(frozen=True) -class SchedulesRemoved(DataStoreEvent): +class ScheduleRemoved(ScheduleEvent): pass + + +_all_event_types = [x for x in locals().values() if isclass(x) and issubclass(x, Event)] + + +# +# Event delivery +# + +@dataclass +class EventHub(EventSource): + _subscribers: Dict[Type[Event], List[Callable[[Event], Any]]] = field( + init=False, default_factory=lambda: defaultdict(list)) + + def subscribe(self, callback: Callable[[Event], Any], + event_types: Optional[Iterable[Type[Event]]] = None) -> None: + if event_types is None: + event_types = _all_event_types + + for event_type in event_types: + existing_callbacks = self._subscribers[event_type] + if callback not in existing_callbacks: + existing_callbacks.append(callback) + + def unsubscribe(self, callback: Callable[[Event], Any], + event_types: Optional[Iterable[Type[Event]]] = None) -> None: + if event_types is None: + event_types = _all_event_types + + for event_type in event_types: + existing_callbacks = self._subscribers.get(event_type, []) + with 
suppress(ValueError): + existing_callbacks.remove(callback) + + async def publish(self, event: Event) -> None: + for callback in self._subscribers[type(event)]: + try: + retval = callback(event) + if isawaitable(retval): + await retval + except Exception as exc: + warn(f'Failed to deliver {event.__class__.__name__} event to callback ' + f'{callback!r}: {exc.__class__.__name__}: {exc}') + + +class SyncEventSource: + _subscribers: Dict[Type[Event], List[Callable[[Event], Any]]] + + def __init__(self, async_event_source: EventSource, portal: Optional[BlockingPortal] = None): + self.portal = portal + self._async_event_source = async_event_source + self._async_event_source.subscribe(self._forward_async_event) + self._exit_stack = ExitStack() + self._subscribers = defaultdict(list) + + def __enter__(self): + self._exit_stack.__enter__() + if not self.portal: + portal_cm = start_blocking_portal() + self.portal = self._exit_stack.enter_context(portal_cm) + self._async_event_source.subscribe(self._forward_async_event) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._exit_stack.__exit__(exc_type, exc_val, exc_tb) + + async def _forward_async_event(self, event: Event) -> None: + for subscriber in self._subscribers.get(type(event), ()): + await run_sync_in_worker_thread(subscriber, event) + + def subscribe(self, callback: Callable[[Event], Any], + event_types: Optional[Iterable[Type[Event]]] = None) -> None: + if event_types is None: + event_types = _all_event_types + + for event_type in event_types: + existing_callbacks = self._subscribers[event_type] + if callback not in existing_callbacks: + existing_callbacks.append(callback) + + def unsubscribe(self, callback: Callable[[Event], Any], + event_types: Optional[Iterable[Type[Event]]] = None) -> None: + if event_types is None: + event_types = _all_event_types + + for event_type in event_types: + existing_callbacks = self._subscribers.get(event_type, []) + with suppress(ValueError): + existing_callbacks.remove(callback) diff --git a/apscheduler/exceptions.py b/apscheduler/exceptions.py index 48fe9c1..ec04f2c 100644 --- a/apscheduler/exceptions.py +++ b/apscheduler/exceptions.py @@ -8,11 +8,14 @@ class JobLookupError(KeyError): class ConflictingIdError(KeyError): - """Raised when the uniqueness of job IDs is being violated.""" + """ + Raised when trying to add a schedule to a store that already contains a schedule by that ID, + and the conflict policy of ``exception`` is used. + """ - def __init__(self, job_id): + def __init__(self, schedule_id): super().__init__( - u'Job identifier (%s) conflicts with an existing job' % job_id) + f'This data store already contains a schedule with the identifier {schedule_id!r}') class TransientJobError(ValueError): @@ -40,3 +43,17 @@ class MaxIterationsReached(Exception): Raised when a trigger has reached its maximum number of allowed computation iterations when trying to calculate the next fire time. 
""" + + +class SchedulerAlreadyRunningError(Exception): + """Raised when attempting to start or configure the scheduler when it's already running.""" + + def __str__(self): + return 'Scheduler is already running' + + +class SchedulerNotRunningError(Exception): + """Raised when attempting to shutdown the scheduler when it's not running.""" + + def __str__(self): + return 'Scheduler is not running' diff --git a/apscheduler/policies.py b/apscheduler/policies.py new file mode 100644 index 0000000..91fbd31 --- /dev/null +++ b/apscheduler/policies.py @@ -0,0 +1,19 @@ +from enum import Enum, auto + + +class ConflictPolicy(Enum): + #: replace the existing schedule with a new one + replace = auto() + #: keep the existing schedule as-is and drop the new schedule + do_nothing = auto() + #: raise an exception if a conflict is detected + exception = auto() + + +class CoalescePolicy(Enum): + #: run once, with the earliest fire time + earliest = auto() + #: run once, with the latest fire time + latest = auto() + #: submit one job for every accumulated fire time + all = auto() diff --git a/apscheduler/schedulers/__init__.py b/apscheduler/schedulers/__init__.py index bd8a790..e69de29 100644 --- a/apscheduler/schedulers/__init__.py +++ b/apscheduler/schedulers/__init__.py @@ -1,12 +0,0 @@ -class SchedulerAlreadyRunningError(Exception): - """Raised when attempting to start or configure the scheduler when it's already running.""" - - def __str__(self): - return 'Scheduler is already running' - - -class SchedulerNotRunningError(Exception): - """Raised when attempting to shutdown the scheduler when it's not running.""" - - def __str__(self): - return 'Scheduler is not running' diff --git a/apscheduler/schedulers/async_.py b/apscheduler/schedulers/async_.py index d6f8faf..6448aa5 100644 --- a/apscheduler/schedulers/async_.py +++ b/apscheduler/schedulers/async_.py @@ -1,53 +1,73 @@ -import logging import os import platform import threading from contextlib import AsyncExitStack -from dataclasses import dataclass, field -from datetime import datetime, timedelta, timezone, tzinfo -from traceback import format_exc -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Set, Union +from datetime import datetime, timedelta, timezone +from logging import Logger, getLogger +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Union from uuid import uuid4 -import tzlocal -from anyio import create_event, create_task_group, move_on_after -from anyio.abc import Event +from anyio import create_event, create_task_group, get_cancelled_exc_class, open_cancel_scope +from anyio.abc import CancelScope, Event, TaskGroup -from ..abc import DataStore, EventHub, EventSource, Executor, Job, Schedule, Task, Trigger -from ..datastores.memory import MemoryScheduleStore -from ..eventhubs.local import LocalEventHub -from ..events import JobSubmissionFailed, SchedulesAdded, SchedulesUpdated +from ..abc import DataStore, Job, Schedule, Task, Trigger +from ..datastores.memory import MemoryDataStore +from ..events import EventHub from ..marshalling import callable_to_ref -from ..validators import as_timezone -from ..workers.local import LocalExecutor - - -@dataclass -class AsyncScheduler(EventSource): - schedule_store: DataStore = field(default_factory=MemoryScheduleStore) - worker: Executor = field(default_factory=LocalExecutor) - timezone: tzinfo = field(default_factory=tzlocal.get_localzone) - identity: str = f'{platform.node()}-{os.getpid()}-{threading.get_ident()}' - logger: logging.Logger = 
field(default_factory=lambda: logging.getLogger(__name__)) - _event_hub: EventHub = field(init=False, default_factory=LocalEventHub) - _tasks: Dict[str, Task] = field(init=False, default_factory=dict) - _async_stack: AsyncExitStack = field(init=False, default_factory=AsyncExitStack) - _next_fire_time: Optional[datetime] = field(init=False, default=None) - _wakeup_event: Optional[Event] = field(init=False, default=None) - _closed: bool = field(init=False, default=False) - - def __post_init__(self): - self.timezone = as_timezone(self.timezone) +from ..policies import CoalescePolicy, ConflictPolicy +from ..workers.async_ import AsyncWorker + + +class AsyncScheduler(EventHub): + _task_group: Optional[TaskGroup] = None + _stop_event: Optional[Event] = None + _running: bool = False + _worker: Optional[AsyncWorker] = None + _acquire_cancel_scope: Optional[CancelScope] = None + + def __init__(self, data_store: Optional[DataStore] = None, *, identity: Optional[str] = None, + logger: Optional[Logger] = None, start_worker: bool = True): + super().__init__() + self.data_store = data_store or MemoryDataStore() + self.identity = identity or f'{platform.node()}-{os.getpid()}-{threading.get_ident()}' + self.logger = logger or getLogger(__name__) + self.start_worker = start_worker + self._tasks: Dict[str, Task] = {} + self._exit_stack = AsyncExitStack() + + @property + def worker(self) -> Optional[AsyncWorker]: + return self._worker async def __aenter__(self): - await self._async_stack.__aenter__() - task_group = create_task_group() - await self._async_stack.enter_async_context(task_group) - await task_group.spawn(self.run) + await self._exit_stack.__aenter__() + + # Start the built-in worker, if configured to do so + if self.start_worker: + # The worker handles initializing the data store + self._worker = AsyncWorker(self.data_store) + await self._exit_stack.enter_async_context(self._worker) + else: + # Otherwise, initialize the data store ourselves + await self._exit_stack.enter_async_context(self.data_store) + + # Start the actual scheduler + self._task_group = create_task_group() + await self._exit_stack.enter_async_context(self._task_group) + start_event = create_event() + await self._task_group.spawn(self.run, start_event) + await start_event.wait() + + # await self.start(self._task_group) return self async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._async_stack.__aexit__(exc_type, exc_val, exc_tb) + # Exit gracefully (wait for ongoing tasks to finish) if the context exited without errors + await self.stop(force=exc_type is not None) + if self._worker: + await self._worker.stop(force=exc_type is not None) + + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) def _get_taskdef(self, func_or_id: Union[str, Callable]) -> Task: task_id = func_or_id if isinstance(func_or_id, str) else callable_to_ref(func_or_id) @@ -71,8 +91,10 @@ class AsyncScheduler(EventSource): async def add_schedule( self, task: Union[str, Callable], trigger: Trigger, *, id: Optional[str] = None, args: Optional[Iterable] = None, kwargs: Optional[Mapping[str, Any]] = None, - coalesce: bool = True, misfire_grace_time: Union[float, timedelta, None] = None, - tags: Optional[Iterable[str]] = None + coalesce: CoalescePolicy = CoalescePolicy.latest, + misfire_grace_time: Union[float, timedelta, None] = None, + tags: Optional[Iterable[str]] = None, + conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing ) -> str: id = id or str(uuid4()) args = tuple(args or ()) @@ -85,36 +107,30 @@ class 
AsyncScheduler(EventSource): schedule = Schedule(id=id, task_id=taskdef.id, trigger=trigger, args=args, kwargs=kwargs, coalesce=coalesce, misfire_grace_time=misfire_grace_time, tags=tags, next_fire_time=trigger.next()) - await self.schedule_store.add_or_replace_schedules([schedule]) - self.logger.info('Added new schedule for task %s; next run time at %s', taskdef, - schedule.next_fire_time) + await self.data_store.add_schedule(schedule, conflict_policy) + self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', taskdef, + trigger, schedule.next_fire_time) return schedule.id async def remove_schedule(self, schedule_id: str) -> None: - await self.schedule_store.remove_schedules({schedule_id}) - - async def _handle_worker_event(self, event: Event) -> None: - await self._event_hub.publish(event) - - async def _handle_datastore_event(self, event: Event) -> None: - if isinstance(event, (SchedulesAdded, SchedulesUpdated)): - # Wake up the scheduler if any schedule has an earlier next fire time than the one - # we're currently waiting for - if event.earliest_next_fire_time: - if (self._next_fire_time is None - or self._next_fire_time > event.earliest_next_fire_time): - self.logger.debug('Job store reported an updated next fire time that requires ' - 'the scheduler to wake up: %s', - event.earliest_next_fire_time) - await self.wakeup() - - await self._event_hub.publish(event) - - async def _process_schedules(self): - async with self.schedule_store.acquire_due_schedules( - self.identity, datetime.now(timezone.utc)) as schedules: - schedule_ids_to_remove: Set[str] = set() - schedule_updates: Dict[str, Dict[str, Any]] = {} + await self.data_store.remove_schedules({schedule_id}) + + async def run(self, start_event: Optional[Event] = None) -> None: + self._stop_event = create_event() + self._running = True + if start_event: + await start_event.set() + + while self._running: + async with open_cancel_scope() as self._acquire_cancel_scope: + try: + schedules = await self.data_store.acquire_schedules(self.identity, 100) + except get_cancelled_exc_class(): + break + finally: + del self._acquire_cancel_scope + + now = datetime.now(timezone.utc) for schedule in schedules: # Look up the task definition try: @@ -122,76 +138,66 @@ class AsyncScheduler(EventSource): except LookupError: self.logger.error('Cannot locate task definition %r for schedule %r – ' 'removing schedule', schedule.task_id, schedule.id) - schedule_ids_to_remove.add(schedule.id) + schedule.next_fire_time = None continue # Calculate a next fire time for the schedule, if possible - try: - next_fire_time = schedule.trigger.next() - except Exception: - self.logger.exception('Error computing next fire time for schedule %r of task ' - '%r – removing schedule', schedule.id, taskdef.id) - next_fire_time = None - - # Queue a schedule update if a next fire time could be calculated. - # Otherwise, queue the schedule for removal. 
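For orientation, the + lines in this hunk replace the old single-fire-time logic with a coalescing loop driven by the CoalescePolicy enum from policies.py above. A minimal standalone sketch of the three policies' semantics (illustrative only; due_fire_times is a hypothetical list of fire times that have already come due):

from datetime import datetime
from typing import List

from apscheduler.policies import CoalescePolicy

def coalesce(due_fire_times: List[datetime], policy: CoalescePolicy) -> List[datetime]:
    # earliest: submit one job, using the first missed fire time
    if policy is CoalescePolicy.earliest:
        return due_fire_times[:1]
    # latest: submit one job, using the most recent missed fire time
    if policy is CoalescePolicy.latest:
        return due_fire_times[-1:]
    # all: submit one job per accumulated fire time
    return due_fire_times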
- if next_fire_time: - schedule_updates[schedule.id] = { - 'next_fire_time': next_fire_time, - 'last_fire_time': schedule.next_fire_time, - 'trigger': schedule.trigger - } - else: - schedule_ids_to_remove.add(schedule.id) - - # Submit a new job to the executor - job = Job(taskdef.id, taskdef.func, schedule.args, schedule.kwargs, schedule.id, - schedule.last_fire_time, schedule.next_deadline, schedule.tags) - try: - await self.worker.submit_job(job) - except Exception as exc: - self.logger.exception('Error submitting job to worker') - event = JobSubmissionFailed(datetime.now(self.timezone), job.id, job.task_id, - job.schedule_id, job.scheduled_start_time, - formatted_traceback=format_exc(), exception=exc) - await self._event_hub.publish(event) - - # Removed finished schedules - if schedule_ids_to_remove: - await self.schedule_store.remove_schedules(schedule_ids_to_remove) - - # Update the next fire times - if schedule_updates: - await self.schedule_store.update_schedules(schedule_updates) - - return await self.schedule_store.get_next_fire_time() - - async def run(self): - async with AsyncExitStack() as stack: - await stack.enter_async_context(self.schedule_store) - await stack.enter_async_context(self.worker) - await self.worker.subscribe(self._handle_worker_event) - await self.schedule_store.subscribe(self._handle_datastore_event) - self._wakeup_event = create_event() - while not self._closed: - self._next_fire_time = await self._process_schedules() - - if self._next_fire_time: - wait_time = (self._next_fire_time - datetime.now(timezone.utc)).total_seconds() - else: - wait_time = float('inf') - - self._wakeup_event = create_event() - async with move_on_after(wait_time): - await self._wakeup_event.wait() - - async def wakeup(self) -> None: - await self._wakeup_event.set() - - async def shutdown(self, force: bool = False) -> None: - if not self._closed: - self._closed = True - await self.wakeup() - - async def subscribe(self, callback: Callable[[Event], Any]) -> None: - await self._event_hub.subscribe(callback) + fire_times = [schedule.next_fire_time] + calculate_next = schedule.trigger.next + while True: + try: + fire_time = calculate_next() + except Exception: + self.logger.exception( + 'Error computing next fire time for schedule %r of task %r – ' + 'removing schedule', schedule.id, taskdef.id) + break + + # Stop if the calculated fire time is in the future + if fire_time is None or fire_time > now: + schedule.next_fire_time = fire_time + break + + # Only keep all the fire times if coalesce policy = "all" + if schedule.coalesce is CoalescePolicy.all: + fire_times.append(fire_time) + elif schedule.coalesce is CoalescePolicy.latest: + fire_times[0] = fire_time + + # Add one or more jobs to the job queue + for fire_time in fire_times: + schedule.last_fire_time = fire_time + job = Job(taskdef.id, taskdef.func, schedule.args, schedule.kwargs, + schedule.id, fire_time, schedule.next_deadline, + schedule.tags) + print('Added job', job.id, 'for schedule', schedule) + await self.data_store.add_job(job) + + self.logger.debug('Releasing %d schedules', len(schedules)) + await self.data_store.release_schedules(self.identity, schedules) + + await self._stop_event.set() + del self._stop_event + + # async def start(self, task_group: TaskGroup, *, reset_datastore: bool = False) -> None: + # start_event = create_event() + # await task_group.spawn(self.run, start_event) + # await start_event.wait() + # + # if self.start_worker: + # self._worker = AsyncWorker(self.data_store) + # await 
self._worker.start(task_group) + + async def stop(self, force: bool = False) -> None: + self._running = False + if self._worker: + await self._worker.stop(force) + + if self._acquire_cancel_scope: + await self._acquire_cancel_scope.cancel() + if force and self._task_group: + await self._task_group.cancel_scope.cancel() + + async def wait_until_stopped(self) -> None: + if self._stop_event: + await self._stop_event.wait() diff --git a/apscheduler/schedulers/sync.py b/apscheduler/schedulers/sync.py index 41a91ec..a049c9f 100644 --- a/apscheduler/schedulers/sync.py +++ b/apscheduler/schedulers/sync.py @@ -1,50 +1,109 @@ +from __future__ import annotations + +from concurrent.futures import Future from functools import partial +from logging import Logger from typing import Any, Callable, Iterable, Mapping, Optional, Union -from anyio import start_blocking_portal from anyio.abc import BlockingPortal -from apscheduler.abc import Trigger -from apscheduler.events import Event -from apscheduler.schedulers.async_ import AsyncScheduler +from .async_ import AsyncScheduler +from ..abc import DataStore, Trigger +from ..events import SyncEventSource +from ..workers.sync import SyncWorker + + +class SyncScheduler(SyncEventSource): + """ + A synchronous scheduler implementation. + + Wraps an :class:`~.async_.AsyncScheduler` using a :class:`~anyio.abc.BlockingPortal`. + + :ivar BlockingPortal portal: the blocking portal used by the scheduler + """ + + _task: Future + _worker: Optional[SyncWorker] = None -class SyncScheduler: - def __init__(self, *, portal: Optional[BlockingPortal] = None, **kwargs): - self._scheduler = AsyncScheduler(**kwargs) - self._portal = portal - self._shutdown_portal = False + def __init__(self, data_store: Optional[DataStore] = None, *, identity: Optional[str] = None, + logger: Optional[Logger] = None, start_worker: bool = True, + portal: Optional[BlockingPortal] = None): + self._scheduler = AsyncScheduler(data_store, identity=identity, logger=logger, + start_worker=False) + self.start_worker = start_worker + super().__init__(self._scheduler, portal) def __enter__(self): - if not self._portal: - self._portal = start_blocking_portal() - self._shutdown_portal = True + super().__enter__() + self._exit_stack.enter_context(self.portal.wrap_async_context_manager(self._scheduler)) + + if self.start_worker: + self._worker = SyncWorker(self.data_store, portal=self.portal) + self._exit_stack.enter_context(self.worker) - self._portal.call(self._scheduler.__aenter__) return self - def __exit__(self, exc_type, exc_val, exc_tb): - cancel_remaining = exc_type is not None - try: - self._portal.call(self._scheduler.__aexit__, exc_type, exc_val, exc_tb) - except BaseException: - cancel_remaining = True - raise - finally: - if self._shutdown_portal: - self._portal.stop_from_external_thread(cancel_remaining=cancel_remaining) - self._portal = None + # def __exit__(self, exc_type, exc_val, exc_tb): + # return self._exit_stack.__exit__(exc_type, exc_val, exc_tb) + + # def __enter__(self) -> SyncScheduler: + # self.start() + # return self + # + # def __exit__(self, exc_type, exc_val, exc_tb): + # self.stop(force=exc_type is not None) + + # @property + # def worker(self) -> AsyncWorker: + # return self._scheduler.worker + + @property + def data_store(self) -> DataStore: + return self._scheduler.data_store + + @property + def worker(self) -> Optional[SyncWorker]: + return self._worker def add_schedule(self, task: Union[str, Callable], trigger: Trigger, *, args: Optional[Iterable] = None, kwargs: 
Optional[Mapping[str, Any]] = None) -> str: func = partial(self._scheduler.add_schedule, task, trigger, args=args, kwargs=kwargs) - return self._portal.call(func) + return self.portal.call(func) def remove_schedule(self, schedule_id: str) -> None: - return self._portal.call(self._scheduler.remove_schedule, schedule_id) + self.portal.call(self._scheduler.remove_schedule, schedule_id) + + def stop(self) -> None: + self.portal.call(self._scheduler.stop) - def shutdown(self, force: bool = False) -> None: - self._portal.call(self._scheduler.shutdown, force) + def wait_until_stopped(self) -> None: + self.portal.call(self._scheduler.wait_until_stopped) - def subscribe(self, callback: Callable[[Event], Any]) -> None: - self._portal.call(self._scheduler.subscribe, callback) + # def start(self): + # if not self._scheduler._running: + # if not self.portal: + # self.portal = start_blocking_portal() + # self._shutdown_portal = True + # + # self.portal.call(self._scheduler.data_store.initialize) + # start_event = self.portal.call(create_event) + # self._task = self.portal.spawn_task(self._scheduler.run, start_event) + # self.portal.call(start_event.wait) + # print('start_event wait finished') + # + # if self._scheduler.start_worker: + # self._worker = SyncWorker(self.data_store, portal=self.portal) + # self._worker.start() + # + # def stop(self, force: bool = False) -> None: + # if self._worker: + # self._worker.stop(force) + # + # if self._scheduler._running: + # self._scheduler._running = False + # try: + # self.portal.call(partial(self._scheduler.stop, force=force)) + # finally: + # if self._shutdown_portal: + # self.portal.stop_from_external_thread(cancel_remaining=force) diff --git a/apscheduler/util.py b/apscheduler/util.py index 1b09d53..eee6da5 100644 --- a/apscheduler/util.py +++ b/apscheduler/util.py @@ -1,6 +1,6 @@ """This module contains several handy functions primarily meant for internal use.""" import sys -from datetime import tzinfo +from datetime import datetime, tzinfo if sys.version_info >= (3, 9): from zoneinfo import ZoneInfo @@ -24,3 +24,7 @@ def timezone_repr(timezone: tzinfo) -> str: return timezone.key else: return repr(timezone) + + +def absolute_datetime_diff(dateval1: datetime, dateval2: datetime) -> float: + return dateval1.timestamp() - dateval2.timestamp() diff --git a/apscheduler/watchers.py b/apscheduler/watchers.py new file mode 100644 index 0000000..d661d88 --- /dev/null +++ b/apscheduler/watchers.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass, field +from typing import Optional + +from anyio import create_event, move_on_after +from anyio.abc import Event + + +@dataclass +class DelayWatcher: + max_wait_time: Optional[float] + _event: Optional[Event] = field(init=False, default=None) + + async def wait(self) -> None: + async with move_on_after(self.max_wait_time): + self._event = create_event() + await self._event.wait() + + async def notify(self) -> None: + if self._event: + event = self._event + del self._event + await event.set() diff --git a/apscheduler/workers/async_.py b/apscheduler/workers/async_.py new file mode 100644 index 0000000..7949993 --- /dev/null +++ b/apscheduler/workers/async_.py @@ -0,0 +1,158 @@ +import os +import platform +import threading +from asyncio import current_task, iscoroutinefunction +from collections.abc import Coroutine +from contextlib import AsyncExitStack +from datetime import datetime, timezone +from functools import partial +from logging import Logger, getLogger +from traceback import format_exc +from typing import Any, 
Callable, Dict, Optional, Set + +from anyio import create_event, create_task_group, open_cancel_scope, run_sync_in_worker_thread +from anyio.abc import CancelScope, Event, TaskGroup + +from ..abc import DataStore, Job +from ..events import EventHub, JobDeadlineMissed, JobFailed, JobSuccessful, JobUpdated + + +class AsyncWorker(EventHub): + """Runs jobs locally in a task group.""" + + _task_group: Optional[TaskGroup] = None + _stop_event: Optional[Event] = None + _running: bool = False + _running_jobs: int = 0 + _acquire_cancel_scope: Optional[CancelScope] = None + + def __init__(self, data_store: DataStore, *, max_concurrent_jobs: int = 100, + identity: Optional[str] = None, logger: Optional[Logger] = None, + run_sync_functions_in_event_loop: bool = True): + super().__init__() + self.data_store = data_store + self.max_concurrent_jobs = max_concurrent_jobs + self.identity = identity or f'{platform.node()}-{os.getpid()}-{threading.get_ident()}' + self.logger = logger or getLogger(__name__) + self.run_sync_functions_in_event_loop = run_sync_functions_in_event_loop + self._acquired_jobs: Set[Job] = set() + self._exit_stack = AsyncExitStack() + + if self.max_concurrent_jobs < 1: + raise ValueError('max_concurrent_jobs must be at least 1') + + async def __aenter__(self): + print('entering worker in task', id(current_task())) + await self._exit_stack.__aenter__() + + # Initialize the data store + await self._exit_stack.enter_async_context(self.data_store) + + # Start the actual worker + self._task_group = create_task_group() + await self._exit_stack.enter_async_context(self._task_group) + start_event = create_event() + await self._task_group.spawn(self.run, start_event) + await start_event.wait() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + print('exiting worker in task', id(current_task())) + await self.stop(force=exc_type is not None) + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + + async def run(self, start_event: Optional[Event] = None) -> None: + self._stop_event = create_event() + self._running = True + if start_event: + await start_event.set() + + while self._running: + limit = self.max_concurrent_jobs - self._running_jobs + jobs = [] + async with open_cancel_scope() as self._acquire_cancel_scope: + try: + jobs = await self.data_store.acquire_jobs(self.identity, limit) + finally: + del self._acquire_cancel_scope + + for job in jobs: + await self._task_group.spawn(self._run_job, job) + + await self._stop_event.set() + del self._stop_event + del self._task_group + + async def _run_job(self, job: Job) -> None: + # Check if the job started before the deadline + now = datetime.now(timezone.utc) + if job.start_deadline is not None: + if now.timestamp() > job.start_deadline.timestamp(): + self.logger.info('Missed the deadline of job %r', job.id) + event = JobDeadlineMissed( + now, job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id, + scheduled_fire_time=job.scheduled_fire_time, start_time=now, + start_deadline=job.start_deadline + ) + await self.publish(event) + return + + # Set the job as running and publish a job update event + self.logger.info('Started job %r', job.id) + job.started_at = now + event = JobUpdated( + timestamp=now, job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id + ) + await self.publish(event) + + self._running_jobs += 1 + try: + return_value = await self._call_job_func(job.func, job.args, job.kwargs) + except BaseException as exc: + self.logger.exception('Job %r raised an exception', job.id) + event = 
JobFailed( + timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + start_time=now, start_deadline=job.start_deadline, + traceback=format_exc(), exception=exc + ) + else: + self.logger.info('Job %r completed successfully', job.id) + event = JobSuccessful( + timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + start_time=now, start_deadline=job.start_deadline, return_value=return_value + ) + + self._running_jobs -= 1 + await self.data_store.release_jobs(self.identity, [job]) + await self.publish(event) + + async def _call_job_func(self, func: Callable, args: tuple, kwargs: Dict[str, Any]): + if not self.run_sync_functions_in_event_loop and not iscoroutinefunction(func): + wrapped = partial(func, *args, **kwargs) + return await run_sync_in_worker_thread(wrapped) + + return_value = func(*args, **kwargs) + if isinstance(return_value, Coroutine): + return_value = await return_value + + return return_value + + # async def start(self, task_group: TaskGroup) -> None: + # start_event = create_event() + # print('spawning task for AsyncWorker.run() in task', id(current_task())) + # await task_group.spawn(self.run, start_event) + # await start_event.wait() + + async def stop(self, force: bool = False) -> None: + self._running = False + if self._acquire_cancel_scope: + await self._acquire_cancel_scope.cancel() + + if force and self._task_group: + await self._task_group.cancel_scope.cancel() + + async def wait_until_stopped(self) -> None: + if self._stop_event: + await self._stop_event.wait() diff --git a/apscheduler/workers/local.py b/apscheduler/workers/local.py deleted file mode 100644 index 9570612..0000000 --- a/apscheduler/workers/local.py +++ /dev/null @@ -1,95 +0,0 @@ -from __future__ import annotations - -from collections import OrderedDict -from collections.abc import Coroutine -from dataclasses import dataclass, field -from datetime import datetime, timezone -from traceback import format_exc -from typing import Any, Callable, List - -from anyio import create_task_group -from anyio.abc import TaskGroup -from apscheduler.abc import EventHub, Executor, Job -from apscheduler.eventhubs.local import LocalEventHub -from apscheduler.events import ( - Event, JobAdded, JobDeadlineMissed, JobFailed, JobSuccessful, JobUpdated) - - -@dataclass -class LocalExecutor(Executor): - """Runs jobs locally in a task group.""" - - _event_hub: EventHub = field(init=False, default_factory=LocalEventHub) - _task_group: TaskGroup = field(init=False) - _queued_jobs: OrderedDict[Job, None] = field(init=False, default_factory=OrderedDict) - _running_jobs: OrderedDict[Job, None] = field(init=False, default_factory=OrderedDict) - - async def __aenter__(self) -> LocalExecutor: - self._task_group = create_task_group() - await self._task_group.__aenter__() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - await self._task_group.__aexit__(exc_type, exc_val, exc_tb) - - async def _run_job(self) -> None: - job = self._queued_jobs.popitem(last=False)[0] - - # Check if the job started before the deadline - if job.start_deadline: - tz = job.scheduled_start_time.tzinfo - start_time = datetime.now(tz) - if start_time >= job.start_deadline: - event = JobDeadlineMissed( - start_time, job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id, - scheduled_start_time=job.scheduled_start_time, - 
start_deadline=job.start_deadline - ) - await self._event_hub.publish(event) - return - else: - tz = timezone.utc - start_time = datetime.now(tz) - - # Set the job as running and publish a job update event - self._running_jobs[job] = None - job.started_at = start_time - event = JobUpdated( - timestamp=datetime.now(tz), job_id=job.id, task_id=job.task_id, - schedule_id=job.schedule_id, scheduled_start_time=job.scheduled_start_time - ) - await self._event_hub.publish(event) - - try: - return_value = job.func(*job.args, **job.kwargs) - if isinstance(return_value, Coroutine): - return_value = await return_value - except Exception as exc: - event = JobFailed( - timestamp=datetime.now(tz), job_id=job.id, task_id=job.task_id, - schedule_id=job.schedule_id, scheduled_start_time=job.scheduled_start_time, - start_time=start_time, start_deadline=job.start_deadline, - formatted_traceback=format_exc(), exception=exc) - else: - event = JobSuccessful( - timestamp=datetime.now(tz), job_id=job.id, task_id=job.task_id, - schedule_id=job.schedule_id, scheduled_start_time=job.scheduled_start_time, - start_time=start_time, start_deadline=job.start_deadline, return_value=return_value - ) - - del self._running_jobs[job] - await self._event_hub.publish(event) - - async def submit_job(self, job: Job) -> None: - self._queued_jobs[job] = None - await self._task_group.spawn(self._run_job) - - event = JobAdded(datetime.now(timezone.utc), job.id, job.task_id, job.schedule_id, - job.scheduled_start_time) - await self._event_hub.publish(event) - - async def get_jobs(self) -> List[Job]: - return list(self._queued_jobs) - - async def subscribe(self, callback: Callable[[Event], Any]) -> None: - await self._event_hub.subscribe(callback) diff --git a/apscheduler/workers/sync.py b/apscheduler/workers/sync.py new file mode 100644 index 0000000..919c914 --- /dev/null +++ b/apscheduler/workers/sync.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import Optional + +from anyio.abc import BlockingPortal + +from .async_ import AsyncWorker +from ..abc import DataStore +from ..events import SyncEventSource + + +class SyncWorker(SyncEventSource): + def __init__(self, data_store: DataStore, *, portal: Optional[BlockingPortal] = None, + max_concurrent_jobs: int = 100, identity: Optional[str] = None): + self._worker = AsyncWorker(data_store, max_concurrent_jobs=max_concurrent_jobs, + identity=identity, run_sync_functions_in_event_loop=False) + super().__init__(self._worker, portal) + + def __enter__(self) -> SyncWorker: + super().__enter__() + self._exit_stack.enter_context(self.portal.wrap_async_context_manager(self._worker)) + return self + + def stop(self) -> None: + self.portal.call(self._worker.stop) diff --git a/docker-compose.yml b/docker-compose.yml index 001e048..ca5621f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,5 +1,13 @@ version: "2" services: + postgresql: + image: postgres + ports: + - 127.0.0.1:5432:5432 + environment: + POSTGRES_DB: testdb + POSTGRES_PASSWORD: secret + redis: image: redis ports: @@ -9,6 +17,8 @@ services: image: mongo ports: - 127.0.0.1:27017:27017 + environment: + MONGO_REPLICA_SET_NAME: foo rethinkdb: image: rethinkdb diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 5700b3d..264a0ab 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -7,9 +7,13 @@ APScheduler, see the :doc:`migration section <migration>`. 
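The headline entries under UNRELEASED just below revolve around shareable data stores. A rough sketch of the intended pattern, based on the APIs added in this commit (illustrative; uses the default MemoryDataStore and runs the worker separately from the scheduler via start_worker=False):

import anyio

from apscheduler.datastores.memory import MemoryDataStore
from apscheduler.schedulers.async_ import AsyncScheduler
from apscheduler.triggers.interval import IntervalTrigger
from apscheduler.workers.async_ import AsyncWorker

def tick():
    print('tick')

async def main():
    store = MemoryDataStore()
    # The worker and the scheduler share one data store; with start_worker=False
    # the scheduler only queues jobs and leaves execution to external workers
    async with AsyncWorker(store), AsyncScheduler(store, start_worker=False) as scheduler:
        await scheduler.add_schedule(tick, IntervalTrigger(seconds=1))
        await scheduler.wait_until_stopped()

anyio.run(main)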
UNRELEASED ---------- +* Added shareable data stores (can be used by multiple schedulers and workers) +* Added support for running workers independently from schedulers +* Added full async support via AnyIO_ (data store support varies) * Dropped support for Python 2.X, 3.5 and 3.6 * Removed the Qt scheduler due to maintenance difficulties +.. _AnyIO: https://github.com/agronholm/anyio 3.7.0 ----- diff --git a/examples/executors/processpool.py b/examples/executors/processpool.py index df1e82b..37f58bb 100644 --- a/examples/executors/processpool.py +++ b/examples/executors/processpool.py @@ -19,6 +19,6 @@ if __name__ == '__main__': print('Press Ctrl+{0} to exit'.format('Break' if os.name == 'nt' else 'C')) try: - scheduler.start() + scheduler.initialize() except (KeyboardInterrupt, SystemExit): pass diff --git a/examples/misc/reference.py b/examples/misc/reference.py index f6fc5dd..1479d2f 100644 --- a/examples/misc/reference.py +++ b/examples/misc/reference.py @@ -13,6 +13,6 @@ if __name__ == '__main__': print('Press Ctrl+{0} to exit'.format('Break' if os.name == 'nt' else 'C')) try: - scheduler.start() + scheduler.initialize() except (KeyboardInterrupt, SystemExit): pass diff --git a/examples/rpc/server.py b/examples/rpc/server.py index 6fb8594..4c7c705 100644 --- a/examples/rpc/server.py +++ b/examples/rpc/server.py @@ -45,12 +45,12 @@ class SchedulerService(rpyc.Service): if __name__ == '__main__': scheduler = BackgroundScheduler() - scheduler.start() + scheduler.initialize() protocol_config = {'allow_public_attrs': True} server = ThreadedServer(SchedulerService, port=12345, protocol_config=protocol_config) try: - server.start() + server.initialize() except (KeyboardInterrupt, SystemExit): pass finally: - scheduler.shutdown() + scheduler.stop() diff --git a/examples/schedulers/async_.py b/examples/schedulers/async_.py index 0f6677c..ab1784f 100644 --- a/examples/schedulers/async_.py +++ b/examples/schedulers/async_.py @@ -3,6 +3,7 @@ import logging import anyio from apscheduler.schedulers.async_ import AsyncScheduler from apscheduler.triggers.interval import IntervalTrigger +from apscheduler.workers.async_ import AsyncWorker def say_hello(): @@ -10,8 +11,9 @@ def say_hello(): async def main(): - async with AsyncScheduler() as scheduler: + async with AsyncScheduler() as scheduler, AsyncWorker(scheduler.data_store): await scheduler.add_schedule(say_hello, IntervalTrigger(seconds=1)) + await scheduler.wait_until_stopped() logging.basicConfig(level=logging.DEBUG) try: diff --git a/examples/schedulers/sync.py b/examples/schedulers/sync.py index df78df9..f845727 100644 --- a/examples/schedulers/sync.py +++ b/examples/schedulers/sync.py @@ -2,6 +2,7 @@ import logging from apscheduler.schedulers.sync import SyncScheduler from apscheduler.triggers.interval import IntervalTrigger +from apscheduler.workers.sync import SyncWorker def say_hello(): @@ -10,7 +11,8 @@ def say_hello(): logging.basicConfig(level=logging.DEBUG) try: - with SyncScheduler() as scheduler: + with SyncScheduler() as scheduler, SyncWorker(scheduler.data_store, portal=scheduler.portal): scheduler.add_schedule(say_hello, IntervalTrigger(seconds=1)) + scheduler.wait_until_stopped() except (KeyboardInterrupt, SystemExit): pass @@ -21,9 +21,8 @@ license = MIT packages = find: python_requires = >= 3.7 install_requires = - anyio ~= 2.0 + anyio ~= 2.1 backports.zoneinfo; python_version < '3.9' - sortedcontainers ~= 2.2 tzdata; platform_system == "Windows" tzlocal >= 3.0b1, < 4.0 @@ -40,6 +39,7 @@ test = curio pytest >= 5.0 pytest-cov + 
pytest-freezegun pytest-mock pytz trio diff --git a/tests/backends/test_memory.py b/tests/backends/test_memory.py deleted file mode 100644 index c8679f4..0000000 --- a/tests/backends/test_memory.py +++ /dev/null @@ -1,138 +0,0 @@ -from datetime import datetime, timedelta, timezone - -import pytest -from apscheduler.abc import Schedule -from apscheduler.datastores.memory import MemoryScheduleStore -from apscheduler.events import SchedulesAdded, SchedulesRemoved, SchedulesUpdated -from apscheduler.triggers.date import DateTrigger - -pytestmark = pytest.mark.anyio - - -@pytest.fixture -async def store(): - async with MemoryScheduleStore() as store_: - yield store_ - - -@pytest.fixture -async def events(store): - events = [] - await store.subscribe(events.append) - return events - - -@pytest.fixture -def schedules(): - trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) - schedule1 = Schedule(id='s1', task_id='bogus', trigger=trigger, args=(), kwargs={}, - coalesce=False, misfire_grace_time=None, tags=frozenset()) - schedule1.next_fire_time = trigger.next() - - trigger = DateTrigger(datetime(2020, 9, 14, tzinfo=timezone.utc)) - schedule2 = Schedule(id='s2', task_id='bogus', trigger=trigger, args=(), kwargs={}, - coalesce=False, misfire_grace_time=None, tags=frozenset()) - schedule2.next_fire_time = trigger.next() - - trigger = DateTrigger(datetime(2020, 9, 15, tzinfo=timezone.utc)) - schedule3 = Schedule(id='s3', task_id='bogus', trigger=trigger, args=(), kwargs={}, - coalesce=False, misfire_grace_time=None, tags=frozenset()) - return [schedule1, schedule2, schedule3] - - -async def test_add_schedules(store, schedules, events): - assert await store.get_next_fire_time() is None - await store.add_or_replace_schedules(schedules) - assert await store.get_next_fire_time() == datetime(2020, 9, 13, tzinfo=timezone.utc) - - assert await store.get_schedules() == schedules - assert await store.get_schedules({'s1', 's2', 's3'}) == schedules - assert await store.get_schedules({'s1'}) == [schedules[0]] - assert await store.get_schedules({'s2'}) == [schedules[1]] - assert await store.get_schedules({'s3'}) == [schedules[2]] - - assert len(events) == 1 - assert isinstance(events[0], SchedulesAdded) - assert events[0].schedule_ids == {'s1', 's2', 's3'} - assert events[0].earliest_next_fire_time == datetime(2020, 9, 13, tzinfo=timezone.utc) - - -async def test_replace_schedules(store, schedules, events): - await store.add_or_replace_schedules(schedules) - events.clear() - - next_fire_time = schedules[2].trigger.next() - schedule = Schedule(id='s3', task_id='foo', trigger=schedules[2].trigger, args=(), kwargs={}, - coalesce=False, misfire_grace_time=None, tags=frozenset()) - schedule.next_fire_time = next_fire_time - await store.add_or_replace_schedules([schedule]) - - schedules = await store.get_schedules([schedule.id]) - assert schedules[0].task_id == 'foo' - assert schedules[0].next_fire_time == next_fire_time - assert schedules[0].args == () - assert schedules[0].kwargs == {} - assert not schedules[0].coalesce - assert schedules[0].misfire_grace_time is None - assert schedules[0].tags == frozenset() - - assert len(events) == 1 - assert isinstance(events[0], SchedulesUpdated) - assert events[0].schedule_ids == {'s3'} - assert events[0].earliest_next_fire_time == datetime(2020, 9, 15, tzinfo=timezone.utc) - - -async def test_update_schedules(store, schedules, events): - await store.add_or_replace_schedules(schedules) - events.clear() - - next_fire_time = datetime(2020, 9, 12, tzinfo=timezone.utc) 
- updated_ids = await store.update_schedules({ - 's2': {'task_id': 'foo', 'next_fire_time': next_fire_time, 'args': (1,), - 'kwargs': {'a': 'x'}, 'coalesce': True, 'misfire_grace_time': timedelta(seconds=5), - 'tags': frozenset(['a', 'b'])}, - 'nonexistent': {} - }) - assert updated_ids == {'s2'} - - schedules = await store.get_schedules({'s2'}) - assert len(schedules) == 1 - assert schedules[0].task_id == 'foo' - assert schedules[0].args == (1,) - assert schedules[0].kwargs == {'a': 'x'} - assert schedules[0].coalesce - assert schedules[0].next_fire_time == next_fire_time - assert schedules[0].misfire_grace_time == timedelta(seconds=5) - assert schedules[0].tags == frozenset(['a', 'b']) - - assert len(events) == 1 - assert isinstance(events[0], SchedulesUpdated) - assert events[0].schedule_ids == {'s2'} - assert events[0].earliest_next_fire_time == next_fire_time - - -@pytest.mark.parametrize('ids', [None, {'s1', 's2'}], ids=['all', 'by_id']) -async def test_remove_schedules(store, schedules, events, ids): - await store.add_or_replace_schedules(schedules) - events.clear() - - await store.remove_schedules(ids) - assert len(events) == 1 - assert isinstance(events[0], SchedulesRemoved) - if ids: - assert events[0].schedule_ids == {'s1', 's2'} - assert await store.get_schedules() == [schedules[2]] - else: - assert events[0].schedule_ids == {'s1', 's2', 's3'} - assert await store.get_schedules() == [] - - -async def test_acquire_due_schedules(store, schedules, events): - await store.add_or_replace_schedules(schedules) - events.clear() - - now = datetime(2020, 9, 14, tzinfo=timezone.utc) - async with store.acquire_due_schedules('dummy-id', now) as schedules: - assert len(schedules) == 2 - assert schedules[0].id == 's1' - assert schedules[1].id == 's2' diff --git a/tests/conftest.py b/tests/conftest.py index 62e552f..9d0f64c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,15 @@ import sys -from datetime import datetime -from unittest.mock import Mock +from contextlib import AsyncExitStack, ExitStack +from functools import partial import pytest +from anyio import start_blocking_portal +from asyncpg import create_pool +from motor.motor_asyncio import AsyncIOMotorClient + +from apscheduler.datastores.memory import MemoryDataStore +from apscheduler.datastores.mongodb import MongoDBDataStore +from apscheduler.datastores.postgresql import PostgresqlDataStore from apscheduler.serializers.cbor import CBORSerializer from apscheduler.serializers.json import JSONSerializer from apscheduler.serializers.pickle import PickleSerializer @@ -12,42 +19,70 @@ if sys.version_info >= (3, 9): else: from backports.zoneinfo import ZoneInfo +store_params = [ + pytest.param(MemoryDataStore, id='memory'), + pytest.param(PostgresqlDataStore, id='postgresql'), + pytest.param(MongoDBDataStore, id='mongodb') +] + @pytest.fixture(scope='session') def timezone(): return ZoneInfo('Europe/Berlin') +@pytest.fixture(params=[None, PickleSerializer, CBORSerializer, JSONSerializer], + ids=['none', 'pickle', 'cbor', 'json']) +def serializer(request): + return request.param() if request.param else None + + @pytest.fixture -def freeze_time(monkeypatch, timezone): - class TimeFreezer: - def __init__(self, initial): - self.current = initial - self.increment = None +def anyio_backend(): + return 'asyncio' - def get(self, tzinfo=None): - now = self.current.astimezone(tzinfo) if tzinfo else self.current.replace(tzinfo=None) - if self.increment: - self.current += self.increment - return now - def set(self, new_time): - 
self.current = new_time +@pytest.fixture(params=store_params) +async def store(request): + async with AsyncExitStack() as stack: + if request.param is PostgresqlDataStore: + pool = await create_pool('postgresql://postgres:secret@localhost/testdb', + min_size=1, max_size=2) + await stack.enter_async_context(pool) + store = PostgresqlDataStore(pool, start_from_scratch=True) + elif request.param is MongoDBDataStore: + client = AsyncIOMotorClient(tz_aware=True) + stack.push(lambda *args: client.close()) + store = MongoDBDataStore(client, start_from_scratch=True) + else: + store = MemoryDataStore() - def next(self,): - return self.current + self.increment + await stack.enter_async_context(store) + yield store - def set_increment(self, delta): - self.increment = delta - freezer = TimeFreezer(timezone.localize(datetime(2011, 4, 3, 18, 40))) - fake_datetime = Mock(datetime, now=freezer.get) - monkeypatch.setattr('apscheduler.triggers.interval.datetime', fake_datetime) - monkeypatch.setattr('apscheduler.triggers.date.datetime', fake_datetime) - return freezer +@pytest.fixture +def portal(): + with start_blocking_portal() as portal: + yield portal -@pytest.fixture(params=[None, PickleSerializer, CBORSerializer, JSONSerializer], - ids=['none', 'pickle', 'cbor', 'json']) -def serializer(request): - return request.param() if request.param else None +@pytest.fixture(params=store_params) +def sync_store(request, portal): + with ExitStack() as stack: + if request.param is PostgresqlDataStore: + pool = portal.call( + partial(create_pool, 'postgresql://postgres:secret@localhost/testdb', + min_size=1, max_size=2) + ) + stack.enter_context(portal.wrap_async_context_manager(pool)) + store = PostgresqlDataStore(pool, start_from_scratch=True) + elif request.param is MongoDBDataStore: + client = portal.call(partial(AsyncIOMotorClient, tz_aware=True)) + stack.push(lambda *args: portal.call(client.close)) + store = MongoDBDataStore(client, start_from_scratch=True) + else: + store = MemoryDataStore() + + stack.enter_context(portal.wrap_async_context_manager(store)) + yield store diff --git a/tests/test_datastores.py b/tests/test_datastores.py new file mode 100644 index 0000000..0eb16d7 --- /dev/null +++ b/tests/test_datastores.py @@ -0,0 +1,245 @@ +from datetime import datetime, timezone + +import pytest +from anyio import move_on_after + +from apscheduler.abc import Job, Schedule +from apscheduler.events import ScheduleAdded, ScheduleRemoved, ScheduleUpdated +from apscheduler.policies import CoalescePolicy, ConflictPolicy +from apscheduler.triggers.date import DateTrigger + +pytestmark = [ + pytest.mark.anyio, + pytest.mark.parametrize('anyio_backend', ['asyncio']) +] + + +@pytest.fixture +async def events(store): + events = [] + store.subscribe(events.append) + return events + + +@pytest.fixture +def schedules(): + trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) + schedule1 = Schedule(id='s1', task_id='bogus', trigger=trigger, args=(), kwargs={}, + coalesce=CoalescePolicy.latest, misfire_grace_time=None, tags=frozenset()) + schedule1.next_fire_time = trigger.next() + + trigger = DateTrigger(datetime(2020, 9, 14, tzinfo=timezone.utc)) + schedule2 = Schedule(id='s2', task_id='bogus', trigger=trigger, args=(), kwargs={}, + coalesce=CoalescePolicy.latest, misfire_grace_time=None, tags=frozenset()) + schedule2.next_fire_time = trigger.next() + + trigger = DateTrigger(datetime(2020, 9, 15, tzinfo=timezone.utc)) + schedule3 = Schedule(id='s3', task_id='bogus', trigger=trigger, args=(), kwargs={}, + 
coalesce=CoalescePolicy.latest, misfire_grace_time=None, tags=frozenset()) + return [schedule1, schedule2, schedule3] + + +@pytest.fixture +def jobs(): + job1 = Job('task1', print, ('hello',), {'arg2': 'world'}, 'schedule1', + datetime(2020, 10, 10, tzinfo=timezone.utc), + datetime(2020, 10, 10, 1, tzinfo=timezone.utc), frozenset()) + job2 = Job('task2', print, ('hello',), {'arg2': 'world'}, None, + None, None, frozenset()) + return [job1, job2] + + +async def test_add_schedules(store, schedules, events): + for schedule in schedules: + await store.add_schedule(schedule, ConflictPolicy.exception) + + assert await store.get_schedules() == schedules + assert await store.get_schedules({'s1', 's2', 's3'}) == schedules + assert await store.get_schedules({'s1'}) == [schedules[0]] + assert await store.get_schedules({'s2'}) == [schedules[1]] + assert await store.get_schedules({'s3'}) == [schedules[2]] + + assert len(events) == 3 + for event, schedule in zip(events, schedules): + assert isinstance(event, ScheduleAdded) + assert event.schedule_id == schedule.id + assert event.next_fire_time == schedule.next_fire_time + + +async def test_replace_schedules(store, schedules, events): + for schedule in schedules: + await store.add_schedule(schedule, ConflictPolicy.exception) + + events.clear() + next_fire_time = schedules[2].trigger.next() + schedule = Schedule(id='s3', task_id='foo', trigger=schedules[2].trigger, args=(), + kwargs={}, coalesce=CoalescePolicy.earliest, misfire_grace_time=None, + tags=frozenset()) + schedule.next_fire_time = next_fire_time + await store.add_schedule(schedule, ConflictPolicy.replace) + + schedules = await store.get_schedules({schedule.id}) + assert schedules[0].task_id == 'foo' + assert schedules[0].next_fire_time == next_fire_time + assert schedules[0].args == () + assert schedules[0].kwargs == {} + assert schedules[0].coalesce is CoalescePolicy.earliest + assert schedules[0].misfire_grace_time is None + assert schedules[0].tags == frozenset() + + assert len(events) == 1 + assert isinstance(events[0], ScheduleUpdated) + assert events[0].schedule_id == 's3' + assert events[0].next_fire_time == datetime(2020, 9, 15, tzinfo=timezone.utc) + + +async def test_remove_schedules(store, schedules, events): + for schedule in schedules: + await store.add_schedule(schedule, ConflictPolicy.exception) + + events.clear() + await store.remove_schedules(['s1', 's2']) + + assert len(events) == 2 + assert isinstance(events[0], ScheduleRemoved) + assert events[0].schedule_id == 's1' + assert isinstance(events[1], ScheduleRemoved) + assert events[1].schedule_id == 's2' + + assert await store.get_schedules() == [schedules[2]] + + +@pytest.mark.freeze_time(datetime(2020, 9, 14, tzinfo=timezone.utc)) +async def test_acquire_release_schedules(store, schedules, events): + for schedule in schedules: + await store.add_schedule(schedule, ConflictPolicy.exception) + + events.clear() + + # The first scheduler gets the first due schedule + schedules1 = await store.acquire_schedules('dummy-id1', 1) + assert len(schedules1) == 1 + assert schedules1[0].id == 's1' + + # The second scheduler gets the second due schedule + schedules2 = await store.acquire_schedules('dummy-id2', 1) + assert len(schedules2) == 1 + assert schedules2[0].id == 's2' + + # The third scheduler gets nothing + async with move_on_after(0.2): + await store.acquire_schedules('dummy-id3', 1) + pytest.fail('The call should have timed out') + + # The schedules here have run their course, and releasing them should delete them + 
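+ # (Setting next_fire_time to None marks a schedule as finished; release_schedules()
+ # then deletes it, while schedules that still have a fire time are updated and kept.)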
schedules1[0].next_fire_time = None + schedules2[0].next_fire_time = datetime(2020, 9, 15, tzinfo=timezone.utc) + await store.release_schedules('dummy-id1', schedules1) + await store.release_schedules('dummy-id2', schedules2) + + # Check that the first schedule is gone + schedules = await store.get_schedules() + assert len(schedules) == 2 + assert schedules[0].id == 's2' + assert schedules[1].id == 's3' + + # Check for the appropriate update and delete events + assert len(events) == 2 + assert isinstance(events[0], ScheduleRemoved) + assert isinstance(events[1], ScheduleUpdated) + assert events[0].schedule_id == 's1' + assert events[1].schedule_id == 's2' + assert events[1].next_fire_time == datetime(2020, 9, 15, tzinfo=timezone.utc) + + +async def test_acquire_schedules_lock_timeout(store, schedules, events, freezer): + """ + Test that a scheduler can acquire schedules that were acquired by another scheduler but not + released within the lock timeout period. + + """ + # First, one scheduler acquires the first available schedule + await store.add_schedule(schedules[0], ConflictPolicy.exception) + acquired = await store.acquire_schedules('dummy-id1', 1) + assert len(acquired) == 1 + assert acquired[0].id == 's1' + + # Try to acquire the schedule just at the threshold (now == acquired_until). + # This should not yield any schedules. + freezer.tick(30) + async with move_on_after(0.2): + await store.acquire_schedules('dummy-id2', 1) + pytest.fail('The call should have timed out') + + # Right after that, the schedule should be available + freezer.tick(1) + acquired = await store.acquire_schedules('dummy-id2', 1) + assert len(acquired) == 1 + assert acquired[0].id == 's1' + + +async def test_acquire_release_jobs(store, jobs, events): + for job in jobs: + await store.add_job(job) + + events.clear() + + # The first worker gets the first job in the queue + jobs1 = await store.acquire_jobs('dummy-id1', 1) + assert len(jobs1) == 1 + assert jobs1[0].id == jobs[0].id + + # The second worker gets the second job + jobs2 = await store.acquire_jobs('dummy-id2', 1) + assert len(jobs2) == 1 + assert jobs2[0].id == jobs[1].id + + # The third worker gets nothing + async with move_on_after(0.2): + await store.acquire_jobs('dummy-id3', 1) + pytest.fail('The call should have timed out') + + # All the jobs should still be returned + visible_jobs = await store.get_jobs() + assert len(visible_jobs) == 2 + + await store.release_jobs('dummy-id1', jobs1) + await store.release_jobs('dummy-id2', jobs2) + + # All the jobs should be gone + visible_jobs = await store.get_jobs() + assert len(visible_jobs) == 0 + + # Check for the appropriate update and delete events + # assert len(events) == 2 + # assert isinstance(events[0], Job) + # assert isinstance(events[1], SchedulesUpdated) + # assert events[0].schedule_ids == {'s1'} + # assert events[1].schedule_ids == {'s2'} + # assert events[1].next_fire_time == datetime(2020, 9, 15, tzinfo=timezone.utc) + + +async def test_acquire_jobs_lock_timeout(store, jobs, events, freezer): + """ + Test that a worker can acquire jobs that were acquired by another scheduler but not + released within the lock timeout period. + + """ + # First, one worker acquires the first available job + await store.add_job(jobs[0]) + acquired = await store.acquire_jobs('dummy-id1', 1) + assert len(acquired) == 1 + assert acquired[0].id == jobs[0].id + + # Try to acquire the job just at the threshold (now == acquired_until). + # This should not yield any jobs. 
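+ # (The 30-second tick below matches the data stores' assumed lock timeout.)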
+ freezer.tick(30) + async with move_on_after(0.2): + await store.acquire_jobs('dummy-id2', 1) + pytest.fail('The call should have timed out') + + # Right after that, the job should be available + freezer.tick(1) + acquired = await store.acquire_jobs('dummy-id2', 1) + assert len(acquired) == 1 + assert acquired[0].id == jobs[0].id diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py new file mode 100644 index 0000000..b155f6c --- /dev/null +++ b/tests/test_schedulers.py @@ -0,0 +1,59 @@ +import logging +from datetime import datetime, timezone + +import pytest + +from apscheduler.events import JobSuccessful +from apscheduler.schedulers.async_ import AsyncScheduler +from apscheduler.schedulers.sync import SyncScheduler +from apscheduler.triggers.date import DateTrigger + +pytestmark = pytest.mark.anyio + + +async def dummy_async_job(): + return 'returnvalue' + + +def dummy_sync_job(): + return 'returnvalue' + + +class TestAsyncScheduler: + async def test_schedule_job(self, caplog, store): + async def listener(event): + events.append(event) + if isinstance(event, JobSuccessful): + await scheduler.stop() + + caplog.set_level(logging.DEBUG) + trigger = DateTrigger(datetime.now(timezone.utc)) + events = [] + async with AsyncScheduler(store) as scheduler: + scheduler.worker.subscribe(listener) + await scheduler.add_schedule(dummy_async_job, trigger) + await scheduler.wait_until_stopped() + + assert len(events) == 2 + assert isinstance(events[1], JobSuccessful) + assert events[1].return_value == 'returnvalue' + + +class TestSyncScheduler: + @pytest.mark.parametrize('anyio_backend', ['asyncio']) + def test_schedule_job(self, caplog, anyio_backend, sync_store, portal): + def listener(event): + events.append(event) + if isinstance(event, JobSuccessful): + scheduler.stop() + + caplog.set_level(logging.DEBUG) + events = [] + with SyncScheduler(sync_store, portal=portal) as scheduler: + scheduler.worker.subscribe(listener) + scheduler.add_schedule(dummy_sync_job, DateTrigger(datetime.now(timezone.utc))) + scheduler.wait_until_stopped() + + assert len(events) == 2 + assert isinstance(events[1], JobSuccessful) + assert events[1].return_value == 'returnvalue' diff --git a/tests/test_workers.py b/tests/test_workers.py new file mode 100644 index 0000000..89f88e3 --- /dev/null +++ b/tests/test_workers.py @@ -0,0 +1,152 @@ +import threading +from datetime import datetime + +import pytest +from anyio import create_event, fail_after + +from apscheduler.abc import Job +from apscheduler.events import JobDeadlineMissed, JobFailed, JobSuccessful, JobUpdated +from apscheduler.workers.async_ import AsyncWorker +from apscheduler.workers.sync import SyncWorker + +pytestmark = pytest.mark.anyio + + +def sync_func(*args, fail: bool, **kwargs): + if fail: + raise Exception('failing as requested') + else: + return args, kwargs + + +async def async_func(*args, fail: bool, **kwargs): + if fail: + raise Exception('failing as requested') + else: + return args, kwargs + + +def fail_func(): + pytest.fail('This function should never be run') + + +class TestAsyncWorker: + @pytest.mark.parametrize('target_func', [sync_func, async_func], ids=['sync', 'async']) + @pytest.mark.parametrize('fail', [False, True], ids=['success', 'fail']) + @pytest.mark.parametrize('anyio_backend', ['asyncio']) + async def test_run_job_nonscheduled_success(self, target_func, fail, store): + async def listener(worker_event): + worker_events.append(worker_event) + if len(worker_events) == 2: + await event.set() + + worker_events = [] + event = 
create_event() + job = Job('task_id', func=target_func, args=(1, 2), kwargs={'x': 'foo', 'fail': fail}) + async with AsyncWorker(store) as worker: + worker.subscribe(listener) + await store.add_job(job) + async with fail_after(2): + await event.wait() + + assert len(worker_events) == 2 + + assert isinstance(worker_events[0], JobUpdated) + assert worker_events[0].job_id == job.id + assert worker_events[0].task_id == 'task_id' + assert worker_events[0].schedule_id is None + + assert worker_events[1].job_id == job.id + assert worker_events[1].task_id == 'task_id' + assert worker_events[1].schedule_id is None + if fail: + assert isinstance(worker_events[1], JobFailed) + assert type(worker_events[1].exception) is Exception + assert isinstance(worker_events[1].traceback, str) + else: + assert isinstance(worker_events[1], JobSuccessful) + assert worker_events[1].return_value == ((1, 2), {'x': 'foo'}) + + @pytest.mark.parametrize('anyio_backend', ['asyncio']) + async def test_run_deadline_missed(self, store): + async def listener(worker_event): + worker_events.append(worker_event) + await event.set() + + scheduled_start_time = datetime(2020, 9, 14) + worker_events = [] + event = create_event() + job = Job('task_id', fail_func, args=(), kwargs={}, schedule_id='foo', + scheduled_fire_time=scheduled_start_time, + start_deadline=datetime(2020, 9, 14, 1)) + async with AsyncWorker(store) as worker: + worker.subscribe(listener) + await store.add_job(job) + async with fail_after(5): + await event.wait() + + assert len(worker_events) == 1 + assert isinstance(worker_events[0], JobDeadlineMissed) + assert worker_events[0].job_id == job.id + assert worker_events[0].task_id == 'task_id' + assert worker_events[0].schedule_id == 'foo' + assert worker_events[0].scheduled_fire_time == scheduled_start_time + + +class TestSyncWorker: + @pytest.mark.parametrize('target_func', [sync_func, async_func], ids=['sync', 'async']) + @pytest.mark.parametrize('fail', [False, True], ids=['success', 'fail']) + def test_run_job_nonscheduled(self, anyio_backend, target_func, fail, sync_store, portal): + def listener(worker_event): + print('received event:', worker_event) + worker_events.append(worker_event) + if len(worker_events) == 2: + event.set() + + worker_events = [] + event = threading.Event() + job = Job('task_id', func=target_func, args=(1, 2), kwargs={'x': 'foo', 'fail': fail}) + with SyncWorker(sync_store, portal=portal) as worker: + worker.subscribe(listener) + portal.call(sync_store.add_job, job) + event.wait(2) + + assert len(worker_events) == 2 + + assert isinstance(worker_events[0], JobUpdated) + assert worker_events[0].job_id == job.id + assert worker_events[0].task_id == 'task_id' + assert worker_events[0].schedule_id is None + + assert worker_events[1].job_id == job.id + assert worker_events[1].task_id == 'task_id' + assert worker_events[1].schedule_id is None + if fail: + assert isinstance(worker_events[1], JobFailed) + assert type(worker_events[1].exception) is Exception + assert isinstance(worker_events[1].traceback, str) + else: + assert isinstance(worker_events[1], JobSuccessful) + assert worker_events[1].return_value == ((1, 2), {'x': 'foo'}) + + def test_run_deadline_missed(self, anyio_backend, sync_store, portal): + def listener(worker_event): + worker_events.append(worker_event) + event.set() + + scheduled_start_time = datetime(2020, 9, 14) + worker_events = [] + event = threading.Event() + job = Job('task_id', fail_func, args=(), kwargs={}, schedule_id='foo', + scheduled_fire_time=scheduled_start_time, 
+ start_deadline=datetime(2020, 9, 14, 1)) + with SyncWorker(sync_store, portal=portal) as worker: + worker.subscribe(listener) + portal.call(sync_store.add_job, job) + event.wait(5) + + assert len(worker_events) == 1 + assert isinstance(worker_events[0], JobDeadlineMissed) + assert worker_events[0].job_id == job.id + assert worker_events[0].task_id == 'task_id' + assert worker_events[0].schedule_id == 'foo' diff --git a/tests/workers/test_local.py b/tests/workers/test_local.py deleted file mode 100644 index d72d559..0000000 --- a/tests/workers/test_local.py +++ /dev/null @@ -1,101 +0,0 @@ -from datetime import datetime - -import pytest -from anyio import fail_after, sleep -from apscheduler.abc import Job -from apscheduler.events import JobAdded, JobDeadlineMissed, JobFailed, JobSuccessful, JobUpdated -from apscheduler.workers.local import LocalExecutor - -pytestmark = pytest.mark.anyio - - -@pytest.mark.parametrize('sync', [True, False], ids=['sync', 'async']) -@pytest.mark.parametrize('fail', [False, True], ids=['success', 'fail']) -async def test_run_job_nonscheduled_success(sync, fail): - def sync_func(*args, **kwargs): - nonlocal received_args, received_kwargs - received_args = args - received_kwargs = kwargs - if fail: - raise Exception('failing as requested') - else: - return 'success' - - async def async_func(*args, **kwargs): - nonlocal received_args, received_kwargs - received_args = args - received_kwargs = kwargs - if fail: - raise Exception('failing as requested') - else: - return 'success' - - received_args = received_kwargs = None - events = [] - async with LocalExecutor() as worker: - await worker.subscribe(events.append) - - job = Job('task_id', sync_func if sync else async_func, args=(1, 2), kwargs={'x': 'foo'}) - await worker.submit_job(job) - - async with fail_after(1): - while len(events) < 3: - await sleep(0) - - assert received_args == (1, 2) - assert received_kwargs == {'x': 'foo'} - - assert isinstance(events[0], JobAdded) - assert events[0].job_id == job.id - assert events[0].task_id == 'task_id' - assert events[0].schedule_id is None - assert events[0].scheduled_start_time is None - - assert isinstance(events[1], JobUpdated) - assert events[1].job_id == job.id - assert events[1].task_id == 'task_id' - assert events[1].schedule_id is None - assert events[1].scheduled_start_time is None - - assert events[2].job_id == job.id - assert events[2].task_id == 'task_id' - assert events[2].schedule_id is None - assert events[2].scheduled_start_time is None - if fail: - assert isinstance(events[2], JobFailed) - assert type(events[2].exception) is Exception - assert isinstance(events[2].formatted_traceback, str) - else: - assert isinstance(events[2], JobSuccessful) - assert events[2].return_value == 'success' - - -async def test_run_deadline_missed(): - def func(): - pytest.fail('This function should never be run') - - scheduled_start_time = datetime(2020, 9, 14) - events = [] - async with LocalExecutor() as worker: - await worker.subscribe(events.append) - - job = Job('task_id', func, args=(), kwargs={}, schedule_id='foo', - scheduled_start_time=scheduled_start_time, - start_deadline=datetime(2020, 9, 14, 1)) - await worker.submit_job(job) - - async with fail_after(1): - while len(events) < 2: - await sleep(0) - - assert isinstance(events[0], JobAdded) - assert events[0].job_id == job.id - assert events[0].task_id == 'task_id' - assert events[0].schedule_id == 'foo' - assert events[0].scheduled_start_time == scheduled_start_time - - assert isinstance(events[1], 
JobDeadlineMissed) - assert events[1].job_id == job.id - assert events[1].task_id == 'task_id' - assert events[1].schedule_id == 'foo' - assert events[1].scheduled_start_time == scheduled_start_time
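For quick reference, minimal end-to-end usage after this commit, mirroring examples/schedulers/async_.py above (the scheduler starts its built-in worker by default, so no explicit AsyncWorker is required):

import logging

import anyio

from apscheduler.schedulers.async_ import AsyncScheduler
from apscheduler.triggers.interval import IntervalTrigger

def say_hello():
    print('Hello!')

async def main():
    # start_worker defaults to True, so jobs execute in the scheduler's own worker
    async with AsyncScheduler() as scheduler:
        await scheduler.add_schedule(say_hello, IntervalTrigger(seconds=1))
        await scheduler.wait_until_stopped()

logging.basicConfig(level=logging.DEBUG)
anyio.run(main)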