summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Grönholm <alex.gronholm@nextday.fi>2021-02-15 00:32:33 +0200
committerAlex Grönholm <alex.gronholm@nextday.fi>2021-02-24 23:44:32 +0200
commitf4e8c3a0242b082fa1ca6ed5c78094f8de5ba439 (patch)
tree20573886078f60b59ba2811b3fc156c877ff386d
parent5bdf7bd34300ec764b1c5c979deec1e468b0b822 (diff)
downloadapscheduler-f4e8c3a0242b082fa1ca6ed5c78094f8de5ba439.tar.gz
Implemented data store sharing and proper async support
-rw-r--r--.github/workflows/codeqa-test.yml2
-rw-r--r--apscheduler/abc.py159
-rw-r--r--apscheduler/datastores/memory.py360
-rw-r--r--apscheduler/datastores/mongodb.py272
-rw-r--r--apscheduler/datastores/postgresql.py331
-rw-r--r--apscheduler/eventhubs/__init__.py0
-rw-r--r--apscheduler/eventhubs/local.py25
-rw-r--r--apscheduler/events.py155
-rw-r--r--apscheduler/exceptions.py23
-rw-r--r--apscheduler/policies.py19
-rw-r--r--apscheduler/schedulers/__init__.py12
-rw-r--r--apscheduler/schedulers/async_.py278
-rw-r--r--apscheduler/schedulers/sync.py119
-rw-r--r--apscheduler/util.py6
-rw-r--r--apscheduler/watchers.py22
-rw-r--r--apscheduler/workers/async_.py158
-rw-r--r--apscheduler/workers/local.py95
-rw-r--r--apscheduler/workers/sync.py25
-rw-r--r--docker-compose.yml10
-rw-r--r--docs/versionhistory.rst4
-rw-r--r--examples/executors/processpool.py2
-rw-r--r--examples/misc/reference.py2
-rw-r--r--examples/rpc/server.py6
-rw-r--r--examples/schedulers/async_.py4
-rw-r--r--examples/schedulers/sync.py4
-rw-r--r--setup.cfg4
-rw-r--r--tests/backends/test_memory.py138
-rw-r--r--tests/conftest.py89
-rw-r--r--tests/test_datastores.py245
-rw-r--r--tests/test_schedulers.py59
-rw-r--r--tests/test_workers.py152
-rw-r--r--tests/workers/test_local.py101
32 files changed, 2086 insertions, 795 deletions
diff --git a/.github/workflows/codeqa-test.yml b/.github/workflows/codeqa-test.yml
index 35b018f..8a87f9e 100644
--- a/.github/workflows/codeqa-test.yml
+++ b/.github/workflows/codeqa-test.yml
@@ -47,6 +47,6 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install the project and its dependencies
- run: pip install .[test,cbor]
+ run: pip install -e .[test,cbor]
- name: Test with pytest
run: pytest
diff --git a/apscheduler/abc.py b/apscheduler/abc.py
index 4bab22b..0acf7fd 100644
--- a/apscheduler/abc.py
+++ b/apscheduler/abc.py
@@ -3,11 +3,10 @@ from base64 import b64decode, b64encode
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import (
- Any, AsyncContextManager, Callable, Dict, FrozenSet, Iterable, Iterator, List, Mapping,
- Optional, Set)
-from uuid import uuid4
+ Any, Callable, Dict, FrozenSet, Iterable, Iterator, List, Optional, Set, Type)
+from uuid import UUID, uuid4
-from apscheduler.events import Event
+from .policies import CoalescePolicy, ConflictPolicy
class Trigger(Iterator[datetime], metaclass=ABCMeta):
@@ -60,7 +59,7 @@ class Schedule:
trigger: Trigger = field(compare=False)
args: tuple = field(compare=False)
kwargs: Dict[str, Any] = field(compare=False)
- coalesce: bool = field(compare=False)
+ coalesce: CoalescePolicy = field(compare=False)
misfire_grace_time: Optional[timedelta] = field(compare=False)
tags: FrozenSet[str] = field(compare=False)
next_fire_time: Optional[datetime] = field(compare=False, default=None)
@@ -76,13 +75,13 @@ class Schedule:
@dataclass(unsafe_hash=True)
class Job:
- id: str = field(init=False, default_factory=lambda: str(uuid4()))
+ id: UUID = field(init=False, default_factory=uuid4)
task_id: str = field(compare=False)
func: Callable = field(compare=False)
args: tuple = field(compare=False)
kwargs: Dict[str, Any] = field(compare=False)
schedule_id: Optional[str] = field(compare=False, default=None)
- scheduled_start_time: Optional[datetime] = field(compare=False, default=None)
+ scheduled_fire_time: Optional[datetime] = field(compare=False, default=None)
start_deadline: Optional[datetime] = field(compare=False, default=None)
tags: Optional[FrozenSet[str]] = field(compare=False, default_factory=frozenset)
started_at: Optional[datetime] = field(init=False, compare=False, default=None)
@@ -106,22 +105,44 @@ class Serializer(metaclass=ABCMeta):
return self.deserialize(b64decode(serialized))
-class EventSource(metaclass=ABCMeta):
- __slots__ = ()
+@dataclass(frozen=True)
+class Event:
+ timestamp: datetime
+
+@dataclass
+class EventSource(metaclass=ABCMeta):
@abstractmethod
- async def subscribe(self, callback: Callable[[Event], Any]) -> None:
+ def subscribe(self, callback: Callable[[Event], Any],
+ event_types: Optional[Iterable[Type[Event]]] = None) -> None:
"""
Subscribe to events from this event source.
:param callback: callable to be called with the event object when an event is published
- :return: an async iterable yielding event objects
+ :param event_types: an iterable of concrete Event classes to subscribe to
"""
+ @abstractmethod
+ def unsubscribe(self, callback: Callable[[Event], Any],
+ event_types: Optional[Iterable[Type[Event]]] = None) -> None:
+ """
+ Cancel an event subscription
-class DataStore(EventSource):
- __slots__ = ()
+ :param callback:
+ :param event_types: an iterable of concrete Event classes to unsubscribe from
+ :return:
+ """
+ @abstractmethod
+ async def publish(self, event: Event) -> None:
+ """
+ Publish an event.
+
+ :param event: the event to publish
+ """
+
+
+class DataStore(EventSource):
async def __aenter__(self):
return self
@@ -138,112 +159,80 @@ class DataStore(EventSource):
"""
@abstractmethod
- async def add_or_replace_schedules(self, schedules: Iterable[Schedule]) -> None:
+ async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
"""
Add or update the given schedule in the data store.
- :param schedules: schedules to be added or updated
+ :param schedule: schedule to be added
+ :param conflict_policy: policy that determines what to do if there is an existing schedule
+ with the same ID
"""
@abstractmethod
- async def update_schedules(self, updates: Mapping[str, Dict[str, Any]]) -> Set[str]:
- """
- Update one or more existing schedules.
-
- :param updates: mapping of schedule ID to attribute names to be updated
- :return: the set of schedule IDs that were modified by this operation.
- """
-
- @abstractmethod
- async def remove_schedules(self, ids: Optional[Set[str]] = None) -> None:
+ async def remove_schedules(self, ids: Iterable[str]) -> None:
"""
Remove schedules from the data store.
- :param ids: a specific set of schedule IDs to remove, or ``None`` in which case all
- schedules are removed
+ :param ids: a specific set of schedule IDs to remove
"""
@abstractmethod
- async def get_next_fire_time(self) -> Optional[datetime]:
+ async def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]:
"""
- Return the earliest fire time among all unclaimed schedules.
+ Acquire unclaimed due schedules for processing.
- If no running, unclaimed schedules exist, ``None`` is returned.
+ This method claims up to the requested number of schedules for the given scheduler and
+ returns them.
+
+ :param scheduler_id: unique identifier of the scheduler
+ :param limit: maximum number of schedules to claim
+ :return: the list of claimed schedules
"""
@abstractmethod
- async def acquire_due_schedules(
- self, scheduler_id: str,
- max_scheduled_time: datetime) -> AsyncContextManager[List[Schedule]]:
+ async def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None:
"""
- Acquire an undefined amount of due schedules not claimed by any other scheduler.
+ Release the claims on the given schedules and update them on the store.
- This method claims due schedules for the given scheduler and returns them.
- When the scheduler has updated the objects, it calls :meth:`release_due_schedules` to
- release the claim on them.
+ :param scheduler_id: unique identifier of the scheduler
+ :param schedules: the previously claimed schedules
"""
- # @abstractmethod
- # async def release_due_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None:
- # """
- # Update the given schedules and release the claim on them held by this scheduler.
- #
- # This method should do the following:
- #
- # #. Remove any of the schedules in the store that have no next fire time
- # #. Update the schedules that do have a next fire time
- # #. Release any locks held on the schedules by this scheduler
- #
- # :param scheduler_id: identifier of the scheduler
- # :param schedules: schedules previously acquired using :meth:`acquire_due_schedules`
- # """
-
- # @abstractmethod
- # async def acquire_job(self, worker_id: str, tags: Set[str]) -> Job:
- # """
- # Claim and return the next matching job from the queue.
- #
- # :return: the acquired job
- # """
- #
- # @abstractmethod
- # async def release_job(self, job: Job) -> None:
- # """Remove the given job from the queue."""
-
-
-class Executor(EventSource):
@abstractmethod
- async def submit_job(self, job: Job) -> None:
+ async def add_job(self, job: Job) -> None:
"""
- Submit a task to be run in this executor.
-
- The executor may alter the ``id`` attribute of the job before returning.
+ Add a job to be executed by an eligible worker.
:param job: the job object
"""
@abstractmethod
- async def get_jobs(self) -> List[Job]:
+ async def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]:
"""
- Get jobs currently queued or running in this executor.
+ Get the list of pending jobs.
- :return: list of jobs
+ :param ids: a specific set of job IDs to return, or ``None`` to return all jobs
+ :return: the list of matching pending jobs, in the order they will be given to workers
"""
+ @abstractmethod
+ async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]:
+ """
+ Acquire unclaimed jobs for execution.
-class EventHub(metaclass=ABCMeta):
- __slots__ = ()
-
- async def start(self) -> None:
- pass
+ This method claims up to the requested number of jobs for the given worker and returns
+ them.
- async def stop(self) -> None:
- pass
+ :param worker_id: unique identifier of the worker
+ :param limit: maximum number of jobs to claim and return
+ :return: the list of claimed jobs
+ """
@abstractmethod
- async def publish(self, event: Event) -> None:
- """Publish an event."""
+ async def release_jobs(self, worker_id: str, jobs: List[Job]) -> None:
+ """
+ Releases the claim on the given jobs
- @abstractmethod
- async def subscribe(self, callback: Callable[[Event], Any]) -> None:
- """Add a callback to be called when a new event is published."""
+ :param worker_id: unique identifier of the worker
+ :param jobs: the previously claimed jobs
+ """
diff --git a/apscheduler/datastores/memory.py b/apscheduler/datastores/memory.py
index 3ad76c4..826981a 100644
--- a/apscheduler/datastores/memory.py
+++ b/apscheduler/datastores/memory.py
@@ -1,120 +1,274 @@
-from contextlib import asynccontextmanager
+from bisect import bisect_left, insort_right
+from collections import defaultdict
from dataclasses import dataclass, field
from datetime import MAXYEAR, datetime, timezone
-from typing import Any, AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set
+from functools import partial
+from typing import Dict, Iterable, List, Optional, Set
+from uuid import UUID
-from apscheduler.abc import DataStore, Event, EventHub, Schedule
-from apscheduler.eventhubs.local import LocalEventHub
-from apscheduler.events import DataStoreEvent, SchedulesAdded, SchedulesRemoved, SchedulesUpdated
-from sortedcontainers import SortedSet # type: ignore
+from anyio import create_event, move_on_after
+from anyio.abc import Event
+
+from ..abc import DataStore, Job, Schedule
+from ..events import EventHub, ScheduleAdded, ScheduleEvent, ScheduleRemoved, ScheduleUpdated
+from ..exceptions import ConflictingIdError
+from ..policies import ConflictPolicy
max_datetime = datetime(MAXYEAR, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc)
-@dataclass(frozen=True, repr=False)
-class MemoryScheduleStore(DataStore):
- _event_hub: EventHub = field(init=False, default_factory=LocalEventHub)
- _schedules: Set[Schedule] = field(
- init=False,
- default_factory=lambda: SortedSet(key=lambda x: x.next_fire_time or max_datetime))
- _schedules_by_id: Dict[str, Schedule] = field(default_factory=dict)
+class ScheduleState:
+ __slots__ = 'schedule', 'next_fire_time', 'acquired_by', 'acquired_until'
+
+ def __init__(self, schedule: Schedule):
+ self.schedule = schedule
+ self.next_fire_time = self.schedule.next_fire_time
+ self.acquired_by: Optional[str] = None
+ self.acquired_until: Optional[datetime] = None
+
+ def __lt__(self, other):
+ if self.next_fire_time is None:
+ return False
+ elif other.next_fire_time is None:
+ return self.next_fire_time is not None
+ else:
+ return self.next_fire_time < other.next_fire_time
+
+ def __eq__(self, other):
+ return self.schedule == other.schedule
+
+ def __hash__(self):
+ return hash(self.schedule)
+
+ def __repr__(self):
+ return (f'<{self.__class__.__name__} id={self.schedule.id!r} '
+ f'next_fire_time={self.next_fire_time}>')
+
+
+class JobState:
+ __slots__ = 'job', 'created_at', 'acquired_by', 'acquired_until'
+
+ def __init__(self, job: Job):
+ self.job = job
+ self.created_at = datetime.now(timezone.utc)
+ self.acquired_by: Optional[str] = None
+ self.acquired_until: Optional[datetime] = None
+
+ def __lt__(self, other):
+ return self.created_at < other.created_at
+
+ def __eq__(self, other):
+ return self.job.id == other.job.id
+
+ def __hash__(self):
+ return hash(self.job.id)
+
+ def __repr__(self):
+ return f'<{self.__class__.__name__} id={self.job.id!r} task_id={self.job.task_id!r}'
+
+
+@dataclass(repr=False)
+class MemoryDataStore(DataStore, EventHub):
+ lock_expiration_delay: float = 30
+ _schedules: List[ScheduleState] = field(init=False, default_factory=list)
+ _schedules_by_id: Dict[str, ScheduleState] = field(init=False, default_factory=dict)
+ _schedules_by_task_id: Dict[str, Set[ScheduleState]] = field(
+ init=False, default_factory=partial(defaultdict, set))
+ _schedules_event: Optional[Event] = field(init=False, default=None)
+ _jobs: List[JobState] = field(init=False, default_factory=list)
+ _jobs_by_id: Dict[UUID, JobState] = field(init=False, default_factory=dict)
+ _jobs_by_task_id: Dict[str, Set[JobState]] = field(init=False,
+ default_factory=partial(defaultdict, set))
+ _jobs_event: Optional[Event] = field(init=False, default=None)
+ _loans: int = 0
+
+ def _find_schedule_index(self, state: ScheduleState) -> Optional[int]:
+ left_index = bisect_left(self._schedules, state)
+ right_index = bisect_left(self._schedules, state)
+ return self._schedules.index(state, left_index, right_index + 1)
+
+ def _find_job_index(self, state: JobState) -> Optional[int]:
+ left_index = bisect_left(self._jobs, state)
+ right_index = bisect_left(self._jobs, state)
+ return self._jobs.index(state, left_index, right_index + 1)
+
+ async def __aenter__(self):
+ self._loans += 1
+ if self._loans == 1:
+ self._schedules_event = create_event()
+ self._jobs_event = create_event()
+
+ return await super().__aenter__()
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ self._loans -= 1
+ if self._loans == 0:
+ self._schedules_event = None
+ self._jobs_event = None
+
+ return await super().__aexit__(exc_type, exc_val, exc_tb)
+
+ async def clear(self) -> None:
+ self._schedules.clear()
+ self._schedules_by_id.clear()
+ self._schedules_by_task_id.clear()
+ self._jobs.clear()
+ self._jobs_by_id.clear()
+ self._jobs_by_task_id.clear()
async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]:
- return [sched for sched in self._schedules if ids is None or sched.id in ids]
-
- async def add_or_replace_schedules(self, schedules: Iterable[Schedule]) -> None:
- added_schedule_ids = set()
- updated_schedule_ids = set()
- for schedule in schedules:
- if schedule.id in self._schedules_by_id:
- updated_schedule_ids.add(schedule.id)
- old_schedule = self._schedules_by_id[schedule.id]
- self._schedules.remove(old_schedule)
- else:
- added_schedule_ids.add(schedule.id)
-
- self._schedules_by_id[schedule.id] = schedule
- self._schedules.add(schedule)
-
- event: DataStoreEvent
- if added_schedule_ids:
- earliest_fire_time = min(
- (schedule.next_fire_time for schedule in schedules
- if schedule.id in added_schedule_ids and schedule.next_fire_time),
- default=None)
- event = SchedulesAdded(datetime.now(timezone.utc), added_schedule_ids,
- earliest_fire_time)
- await self._event_hub.publish(event)
-
- if updated_schedule_ids:
- earliest_fire_time = min(
- (schedule.next_fire_time for schedule in schedules
- if schedule.id in updated_schedule_ids and schedule.next_fire_time),
- default=None)
- event = SchedulesUpdated(datetime.now(timezone.utc), updated_schedule_ids,
- earliest_fire_time)
- await self._event_hub.publish(event)
-
- async def update_schedules(self, updates: Mapping[str, Dict[str, Any]]) -> Set[str]:
- updated_schedule_ids: Set[str] = set()
- earliest_fire_time: Optional[datetime] = None
- for schedule_id, changes in updates.items():
- try:
- schedule = self._schedules_by_id[schedule_id]
- except KeyError:
- continue
-
- if changes:
- updated_schedule_ids.add(schedule.id)
- for key, value in changes.items():
- setattr(schedule, key, value)
- if key == 'next_fire_time':
- if not earliest_fire_time:
- earliest_fire_time = schedule.next_fire_time
- elif (schedule.next_fire_time
- and schedule.next_fire_time < earliest_fire_time):
- earliest_fire_time = schedule.next_fire_time
-
- event = SchedulesUpdated(datetime.now(timezone.utc), updated_schedule_ids,
- earliest_fire_time)
- await self._event_hub.publish(event)
-
- return updated_schedule_ids
-
- async def remove_schedules(self, ids: Optional[Set[str]] = None) -> None:
- if ids is None:
- removed_schedule_ids = set(self._schedules_by_id)
- self._schedules_by_id.clear()
- self._schedules.clear()
+ return [state.schedule for state in self._schedules
+ if ids is None or state.schedule.id in ids]
+
+ async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
+ old_state = self._schedules_by_id.get(schedule.id)
+ if old_state is not None:
+ if conflict_policy is ConflictPolicy.do_nothing:
+ return
+ elif conflict_policy is ConflictPolicy.exception:
+ raise ConflictingIdError(schedule.id)
+
+ index = self._find_schedule_index(old_state)
+ del self._schedules[index]
+ self._schedules_by_task_id[old_state.schedule.task_id].remove(old_state)
+
+ state = ScheduleState(schedule)
+ self._schedules_by_id[schedule.id] = state
+ self._schedules_by_task_id[schedule.task_id].add(state)
+ insort_right(self._schedules, state)
+
+ if old_state is not None:
+ event = ScheduleUpdated(datetime.now(timezone.utc), schedule.id,
+ schedule.next_fire_time)
else:
- removed_schedule_ids = set()
- for schedule_id in ids:
- schedule = self._schedules_by_id.pop(schedule_id, None)
- if schedule:
- removed_schedule_ids.add(schedule.id)
- self._schedules.remove(schedule)
+ event = ScheduleAdded(datetime.now(timezone.utc), schedule.id,
+ schedule.next_fire_time)
+
+ await self.publish(event)
+
+ old_event, self._schedules_event = self._schedules_event, create_event()
+ await old_event.set()
+
+ async def remove_schedules(self, ids: Iterable[str]) -> None:
+ events: List[ScheduleRemoved] = []
+ for schedule_id in ids:
+ state = self._schedules_by_id.pop(schedule_id, None)
+ if state:
+ self._schedules.remove(state)
+ events.append(ScheduleRemoved(datetime.now(timezone.utc), state.schedule.id))
+
+ for event in events:
+ await self.publish(event)
+
+ async def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]:
+ while True:
+ now = datetime.now(timezone.utc)
+ schedules: List[Schedule] = []
+ wait_time = None
+ for state in self._schedules:
+ if state.acquired_by is not None and state.acquired_until >= now:
+ continue
+ elif state.next_fire_time is None:
+ break
+ elif state.next_fire_time > now:
+ wait_time = state.next_fire_time.timestamp() - now.timestamp()
+ break
+
+ schedules.append(state.schedule)
+ state.acquired_by = scheduler_id
+ state.acquired_until = datetime.fromtimestamp(
+ now.timestamp() + self.lock_expiration_delay, now.tzinfo)
+ if len(schedules) == limit:
+ break
+
+ if schedules:
+ return schedules
+
+ # Wait until reaching the next known fire time, or a schedule is added or updated
+ async with move_on_after(wait_time):
+ await self._schedules_event.wait()
+
+ async def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None:
+ # Send update events for schedules that have a next time
+ now = datetime.now(timezone.utc)
+ update_events: List[ScheduleEvent] = []
+ finished_schedule_ids: List[str] = []
+ for s in schedules:
+ if s.next_fire_time is not None:
+ # Remove the schedule
+ schedule_state = self._schedules_by_id.get(s.id)
+ index = self._find_schedule_index(schedule_state)
+ del self._schedules[index]
+
+ # Readd the schedule to its new position
+ schedule_state.next_fire_time = s.next_fire_time
+ schedule_state.acquired_by = None
+ schedule_state.acquired_until = None
+ insort_right(self._schedules, schedule_state)
+ update_events.append(ScheduleUpdated(now, s.id, s.next_fire_time))
+ else:
+ finished_schedule_ids.append(s.id)
+
+ # if updated_ids:
+ # next_fire_time = min(s.next_fire_time for s in schedules
+ # if s.next_fire_time is not None)
+ # event = ScheduleUpdated(datetime.now(timezone.utc), updated_ids, next_fire_time)
+ # await self.publish(event)
+ # old_event, self._schedules_event = self._schedules_event, create_event()
+ # await old_event.set()
+
+ for event in update_events:
+ await self.publish(event)
+
+ # Remove schedules that didn't get a new next fire time
+ await self.remove_schedules(finished_schedule_ids)
+
+ # async def get_next_fire_time(self) -> Optional[datetime]:
+ # for schedule in self._schedules:
+ # return schedule.next_fire_time
+ #
+ # return None
+
+ async def add_job(self, job: Job) -> None:
+ state = JobState(job)
+ self._jobs.append(state)
+ self._jobs_by_id[job.id] = state
+ self._jobs_by_task_id[job.task_id].add(state)
+
+ old_event, self._jobs_event = self._jobs_event, create_event()
+ await old_event.set()
- event = SchedulesRemoved(datetime.now(timezone.utc), removed_schedule_ids)
- await self._event_hub.publish(event)
+ async def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]:
+ if ids is not None:
+ ids = frozenset(ids)
- @asynccontextmanager
- async def acquire_due_schedules(
- self, scheduler_id: str,
- max_scheduled_time: datetime) -> AsyncGenerator[List[Schedule], None]:
- pending: List[Schedule] = []
- for schedule in self._schedules:
- if not schedule.next_fire_time or schedule.next_fire_time > max_scheduled_time:
- break
+ return [state.job for state in self._jobs if ids is None or state.job.id in ids]
- pending.append(schedule)
+ async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]:
+ while True:
+ now = datetime.now(timezone.utc)
+ jobs: List[Job] = []
+ for state in self._jobs:
+ if state.acquired_by is not None and state.acquired_until >= now:
+ continue
- yield pending
+ jobs.append(state.job)
+ state.acquired_by = worker_id
+ state.acquired_until = datetime.fromtimestamp(
+ now.timestamp() + self.lock_expiration_delay, now.tzinfo)
+ if len(jobs) == limit:
+ break
- async def get_next_fire_time(self) -> Optional[datetime]:
- for schedule in self._schedules:
- return schedule.next_fire_time
+ if jobs:
+ return jobs
- return None
+ # Wait until a new job is added
+ await self._jobs_event.wait()
- async def subscribe(self, callback: Callable[[Event], Any]) -> None:
- await self._event_hub.subscribe(callback)
+ async def release_jobs(self, worker_id: str, jobs: List[Job]) -> None:
+ assert self._jobs
+ indexes = sorted((self._find_job_index(self._jobs_by_id[j.id]) for j in jobs),
+ reverse=True)
+ for i in indexes:
+ state = self._jobs.pop(i)
+ self._jobs_by_task_id[state.job.task_id].remove(state)
diff --git a/apscheduler/datastores/mongodb.py b/apscheduler/datastores/mongodb.py
new file mode 100644
index 0000000..f547c90
--- /dev/null
+++ b/apscheduler/datastores/mongodb.py
@@ -0,0 +1,272 @@
+import asyncio
+import logging
+from datetime import datetime, timezone
+from typing import Iterable, List, Optional, Set, Tuple
+
+import sniffio
+from anyio import create_task_group, move_on_after
+from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorCollection
+from pymongo import ASCENDING, DeleteOne, UpdateOne
+from pymongo.collection import Collection
+from pymongo.errors import DuplicateKeyError
+
+from ..abc import DataStore, Job, Schedule, Serializer
+from ..events import EventHub, ScheduleAdded, ScheduleRemoved, ScheduleUpdated
+from ..exceptions import ConflictingIdError, DeserializationError, SerializationError
+from ..policies import ConflictPolicy
+from ..serializers.pickle import PickleSerializer
+
+
+class MongoDBDataStore(DataStore, EventHub):
+ def __init__(self, client: AsyncIOMotorClient, *, serializer: Optional[Serializer] = None,
+ database: str = 'apscheduler', schedules_collection: str = 'schedules',
+ jobs_collection: str = 'jobs', lock_expiration_delay: float = 30,
+ max_poll_time: Optional[float] = 1, start_from_scratch: bool = False):
+ super().__init__()
+ if not client.delegate.codec_options.tz_aware:
+ raise ValueError('MongoDB client must have tz_aware set to True')
+
+ self.client = client
+ self.serializer = serializer or PickleSerializer()
+ self.lock_expiration_delay = lock_expiration_delay
+ self.max_poll_time = max_poll_time
+ self.start_from_scratch = start_from_scratch
+ self._database = client[database]
+ self._schedules: AsyncIOMotorCollection = self._database[schedules_collection]
+ self._jobs: AsyncIOMotorCollection = self._database[jobs_collection]
+ self._logger = logging.getLogger(__name__)
+ self._watch_collections = hasattr(self.client, '_topology_settings')
+ self._loans = 0
+
+ async def __aenter__(self):
+ asynclib = sniffio.current_async_library() or '(unknown)'
+ if asynclib != 'asyncio':
+ raise RuntimeError(f'This data store requires asyncio; currently running: {asynclib}')
+
+ if self._loans == 0:
+ self._schedules_event = asyncio.Event()
+ self._jobs_event = asyncio.Event()
+ await self._setup()
+
+ self._loans += 1
+ if self._loans == 1 and self._watch_collections:
+ self._task_group = create_task_group()
+ await self._task_group.__aenter__()
+ await self._task_group.spawn(self._watch_collection, self._schedules)
+ await self._task_group.spawn(self._watch_collection, self._jobs)
+
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ assert self._loans
+ self._loans -= 1
+ if self._loans == 0 and self._watch_collections:
+ await self._task_group.cancel_scope.cancel()
+ await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
+
+ async def _watch_collection(self, collection: Collection) -> None:
+ pipeline = [{'$match': {'operationType': {'$in': ['insert', 'update', 'replace']}}},
+ {'$project': ['next_fire_time']}]
+ async with collection.watch(pipeline, batch_size=1,
+ max_await_time_ms=self.max_poll_time) as stream:
+ await stream.next()
+
+ async def _setup(self) -> None:
+ server_info = await self.client.server_info()
+ if server_info['versionArray'] < [4, 0]:
+ raise RuntimeError(f"MongoDB server must be at least v4.0; current version = "
+ f"{server_info['version']}")
+
+ if self.start_from_scratch:
+ await self._schedules.delete_many({})
+ await self._jobs.delete_many({})
+
+ await self._schedules.create_index('next_fire_time')
+ await self._jobs.create_index('task_id')
+ await self._jobs.create_index('created_at')
+ await self._jobs.create_index('tags')
+
+ async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]:
+ schedules: List[Schedule] = []
+ filters = {'_id': {'$in': list(ids)}} if ids is not None else {}
+ cursor = self._schedules.find(filters, projection=['_id', 'serialized_data']).sort('_id')
+ for document in await cursor.to_list(None):
+ try:
+ schedule = self.serializer.deserialize(document['serialized_data'])
+ except DeserializationError:
+ self._logger.warning('Failed to deserialize schedule %r', document['_id'])
+ continue
+
+ schedules.append(schedule)
+
+ return schedules
+
+ async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
+ serialized_data = self.serializer.serialize(schedule)
+ document = {
+ '_id': schedule.id,
+ 'task_id': schedule.task_id,
+ 'serialized_data': serialized_data,
+ 'next_fire_time': schedule.next_fire_time
+ }
+ try:
+ await self._schedules.insert_one(document)
+ except DuplicateKeyError:
+ if conflict_policy is ConflictPolicy.exception:
+ raise ConflictingIdError(schedule.id) from None
+ elif conflict_policy is ConflictPolicy.replace:
+ await self._schedules.replace_one({'_id': schedule.id}, document, True)
+ await self.publish(
+ ScheduleUpdated(datetime.now(timezone.utc), schedule.id,
+ schedule.next_fire_time)
+ )
+ else:
+ await self.publish(
+ ScheduleAdded(datetime.now(timezone.utc), schedule.id, schedule.next_fire_time)
+ )
+
+ self._schedules_event.set()
+
+ async def remove_schedules(self, ids: Iterable[str]) -> None:
+ async with await self.client.start_session() as s, s.start_transaction():
+ filters = {'_id': {'$in': list(ids)}} if ids is not None else {}
+ cursor = self._schedules.find(filters, projection=['_id'])
+ ids = [doc['_id'] for doc in await cursor.to_list(None)]
+ if ids:
+ await self._schedules.delete_many(filters)
+
+ now = datetime.now(timezone.utc)
+ for schedule_id in ids:
+ await self.publish(ScheduleRemoved(now, schedule_id))
+
+ async def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]:
+ while True:
+ schedules: List[Schedule] = []
+ async with await self.client.start_session() as s, s.start_transaction():
+ cursor = self._schedules.find(
+ {'next_fire_time': {'$ne': None},
+ '$or': [{'acquired_until': {'$exists': False}},
+ {'acquired_until': {'$lt': datetime.now(timezone.utc)}}]
+ },
+ projection=['serialized_data']
+ ).sort('next_fire_time')
+ for document in await cursor.to_list(length=limit):
+ schedule = self.serializer.deserialize(document['serialized_data'])
+ schedules.append(schedule)
+
+ if schedules:
+ now = datetime.now(timezone.utc)
+ acquired_until = datetime.fromtimestamp(
+ now.timestamp() + self.lock_expiration_delay, now.tzinfo)
+ filters = {'_id': {'$in': [schedule.id for schedule in schedules]}}
+ update = {'$set': {'acquired_by': scheduler_id,
+ 'acquired_until': acquired_until}}
+ await self._schedules.update_many(filters, update)
+ return schedules
+
+ async with move_on_after(self.max_poll_time):
+ await self._schedules_event.wait()
+
+ async def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None:
+ updated_schedules: List[Tuple[str, datetime]] = []
+ finished_schedule_ids: List[str] = []
+ async with await self.client.start_session() as s, s.start_transaction():
+ # Update schedules that have a next fire time
+ requests = []
+ for schedule in schedules:
+ filters = {'_id': schedule.id, 'acquired_by': scheduler_id}
+ if schedule.next_fire_time is not None:
+ try:
+ serialized_data = self.serializer.serialize(schedule)
+ except SerializationError:
+ self._logger.exception('Error serializing schedule %r – '
+ 'removing from data store', schedule.id)
+ requests.append(DeleteOne(filters))
+ finished_schedule_ids.append(schedule.id)
+ continue
+
+ update = {
+ '$unset': {
+ 'acquired_by': True,
+ 'acquired_until': True,
+ },
+ '$set': {
+ 'next_fire_time': schedule.next_fire_time,
+ 'serialized_data': serialized_data
+ }
+ }
+ requests.append(UpdateOne(filters, update))
+ updated_schedules.append((schedule.id, schedule.next_fire_time))
+ else:
+ requests.append(DeleteOne(filters))
+ finished_schedule_ids.append(schedule.id)
+
+ if requests:
+ await self._schedules.bulk_write(requests, ordered=False)
+ now = datetime.now(timezone.utc)
+ for schedule_id, next_fire_time in updated_schedules:
+ await self.publish(ScheduleUpdated(now, schedule_id, next_fire_time))
+
+ now = datetime.now(timezone.utc)
+ for schedule_id in finished_schedule_ids:
+ await self.publish(ScheduleRemoved(now, schedule_id))
+
+ async def add_job(self, job: Job) -> None:
+ serialized_data = self.serializer.serialize(job)
+ document = {
+ '_id': job.id,
+ 'serialized_data': serialized_data,
+ 'task_id': job.task_id,
+ 'created_at': datetime.now(timezone.utc),
+ 'tags': list(job.tags)
+ }
+ await self._jobs.insert_one(document)
+ self._jobs_event.set()
+
+ async def get_jobs(self, ids: Optional[Iterable[str]] = None) -> List[Job]:
+ jobs: List[Job] = []
+ filters = {'_id': {'$in': list(ids)}} if ids is not None else {}
+ cursor = self._jobs.find(filters, projection=['_id', 'serialized_data']).sort('_id')
+ for document in await cursor.to_list(None):
+ try:
+ job = self.serializer.deserialize(document['serialized_data'])
+ except DeserializationError:
+ self._logger.warning('Failed to deserialize job %r', document['_id'])
+ continue
+
+ jobs.append(job)
+
+ return jobs
+
+ async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]:
+ jobs: List[Job] = []
+ while True:
+ async with await self.client.start_session() as s, s.start_transaction():
+ cursor = self._jobs.find(
+ {'$or': [{'acquired_until': {'$exists': False}},
+ {'acquired_until': {'$lt': datetime.now(timezone.utc)}}]
+ },
+ projection=['serialized_data'],
+ sort=[('created_at', ASCENDING)]
+ )
+ for document in await cursor.to_list(length=limit):
+ job = self.serializer.deserialize(document['serialized_data'])
+ jobs.append(job)
+
+ if jobs:
+ now = datetime.now(timezone.utc)
+ acquired_until = datetime.fromtimestamp(
+ now.timestamp() + self.lock_expiration_delay, timezone.utc)
+ filters = {'_id': {'$in': [job.id for job in jobs]}}
+ update = {'$set': {'acquired_by': worker_id,
+ 'acquired_until': acquired_until}}
+ await self._jobs.update_many(filters, update)
+ return jobs
+
+ async with move_on_after(self.max_poll_time):
+ await self._jobs_event.wait()
+ self._jobs_event.clear()
+
+ async def release_jobs(self, worker_id: str, jobs: List[Job]) -> None:
+ filters = {'_id': {'$in': [job.id for job in jobs]}, 'acquired_by': worker_id}
+ await self._jobs.delete_many(filters)
diff --git a/apscheduler/datastores/postgresql.py b/apscheduler/datastores/postgresql.py
new file mode 100644
index 0000000..570badb
--- /dev/null
+++ b/apscheduler/datastores/postgresql.py
@@ -0,0 +1,331 @@
+import asyncio
+import logging
+from asyncio import current_task
+from datetime import datetime, timezone
+from typing import Iterable, List, Optional, Set
+from uuid import UUID
+
+import sniffio
+from anyio import create_task_group, move_on_after, sleep
+from anyio.abc import TaskGroup
+from asyncpg import UniqueViolationError
+from asyncpg.pool import Pool
+
+from ..abc import DataStore, Job, Schedule, Serializer
+from ..events import (
+ EventHub, JobAdded, ScheduleAdded, ScheduleEvent, ScheduleRemoved,
+ ScheduleUpdated)
+from ..exceptions import ConflictingIdError, SerializationError
+from ..policies import ConflictPolicy
+from ..serializers.pickle import PickleSerializer
+
+logger = logging.getLogger(__name__)
+
+
+class PostgresqlDataStore(DataStore, EventHub):
+ _task_group: TaskGroup
+ _schedules_event: Optional[asyncio.Event] = None
+ _jobs_event: Optional[asyncio.Event] = None
+
+ def __init__(self, pool: Pool, *, schema: str = 'public',
+ notify_channel: Optional[str] = 'apscheduler',
+ serializer: Optional[Serializer] = None,
+ lock_expiration_delay: float = 30, max_poll_time: Optional[float] = 1,
+ max_idle_time: float = 60, start_from_scratch: bool = False):
+ super().__init__()
+ self.pool = pool
+ self.schema = schema
+ self.notify_channel = notify_channel
+ self.serializer = serializer or PickleSerializer()
+ self.lock_expiration_delay = lock_expiration_delay
+ self.max_poll_time = max_poll_time
+ self.max_idle_time = max_idle_time
+ self.start_from_scratch = start_from_scratch
+ self._logger = logging.getLogger(__name__)
+ self._loans = 0
+
+ async def __aenter__(self):
+ asynclib = sniffio.current_async_library() or '(unknown)'
+ if asynclib != 'asyncio':
+ raise RuntimeError(f'This data store requires asyncio; currently running: {asynclib}')
+
+ if self._loans == 0:
+ self._schedules_event = asyncio.Event()
+ self._jobs_event = asyncio.Event()
+ await self._setup()
+
+ self._loans += 1
+ if self._loans == 1 and self.notify_channel:
+            self._logger.debug('Entering data store in task %s', id(current_task()))
+ self._task_group = create_task_group()
+ await self._task_group.__aenter__()
+ await self._task_group.spawn(self._listen_notifications)
+
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ assert self._loans
+ self._loans -= 1
+ if self._loans == 0 and self.notify_channel:
+            self._logger.debug('Exiting data store in task %s', id(current_task()))
+ await self._task_group.cancel_scope.cancel()
+ await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
+ del self._schedules_event
+ del self._jobs_event
+
+ async def _listen_notifications(self) -> None:
+ def callback(connection, pid, channel: str, payload: str) -> None:
+ self._logger.debug('Received notification on channel %s: %s', channel, payload)
+ if payload == 'schedule':
+ self._schedules_event.set()
+ elif payload == 'job':
+ self._jobs_event.set()
+
+ while True:
+ async with self.pool.acquire() as conn:
+ await conn.add_listener(self.notify_channel, callback)
+ try:
+ while True:
+ await sleep(self.max_idle_time)
+ await conn.execute('SELECT 1')
+ finally:
+ await conn.remove_listener(self.notify_channel, callback)
+
+ async def _setup(self) -> None:
+ async with self.pool.acquire() as conn, conn.transaction():
+ if self.start_from_scratch:
+ await conn.execute(f"DROP TABLE IF EXISTS {self.schema}.schedules")
+ await conn.execute(f"DROP TABLE IF EXISTS {self.schema}.jobs")
+ await conn.execute(f"DROP TABLE IF EXISTS {self.schema}.metadata")
+
+ await conn.execute(f"""
+ CREATE TABLE IF NOT EXISTS {self.schema}.metadata (
+ schema_version INTEGER NOT NULL
+ )
+ """)
+ version = await conn.fetchval(f"SELECT schema_version FROM {self.schema}.metadata")
+ if version is None:
+ await conn.execute(f"INSERT INTO {self.schema}.metadata VALUES (1)")
+ await conn.execute(f"""
+ CREATE TABLE {self.schema}.schedules (
+ id TEXT PRIMARY KEY,
+ task_id TEXT NOT NULL,
+ serialized_data BYTEA NOT NULL,
+ next_fire_time TIMESTAMP WITH TIME ZONE,
+ acquired_by TEXT,
+ acquired_until TIMESTAMP WITH TIME ZONE
+ ) WITH (fillfactor = 80)
+ """)
+ await conn.execute(f"CREATE INDEX ON {self.schema}.schedules (next_fire_time)")
+ await conn.execute(f"""
+ CREATE TABLE {self.schema}.jobs (
+ id UUID PRIMARY KEY,
+ serialized_data BYTEA NOT NULL,
+ task_id TEXT NOT NULL,
+ tags TEXT[] NOT NULL,
+ created_at TIMESTAMP WITH TIME ZONE NOT NULL,
+ acquired_by TEXT,
+ acquired_until TIMESTAMP WITH TIME ZONE
+ ) WITH (fillfactor = 80)
+ """)
+ await conn.execute(f"CREATE INDEX ON {self.schema}.jobs (task_id)")
+ await conn.execute(f"CREATE INDEX ON {self.schema}.jobs (tags)")
+ elif version > 1:
+ raise RuntimeError(f'Unexpected schema version ({version}); '
+ f'only version 1 is supported by this version of APScheduler')
+
+ async def clear(self) -> None:
+ async with self.pool.acquire() as conn, conn.transaction():
+ await conn.execute(f"TRUNCATE TABLE {self.schema}.schedules")
+ await conn.execute(f"TRUNCATE TABLE {self.schema}.jobs")
+
+ async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
+ event: Optional[ScheduleEvent] = None
+ serialized_data = self.serializer.serialize(schedule)
+ query = (f"INSERT INTO {self.schema}.schedules (id, serialized_data, task_id, "
+ f"next_fire_time) VALUES ($1, $2, $3, $4)")
+ async with self.pool.acquire() as conn:
+ try:
+ async with conn.transaction():
+ await conn.execute(query, schedule.id, serialized_data, schedule.task_id,
+ schedule.next_fire_time)
+ except UniqueViolationError:
+ if conflict_policy is ConflictPolicy.exception:
+ raise ConflictingIdError(schedule.id) from None
+ elif conflict_policy is ConflictPolicy.replace:
+ query = (f"UPDATE {self.schema}.schedules SET serialized_data = $2, "
+ f"task_id = $3, next_fire_time = $4 WHERE id = $1")
+ async with conn.transaction():
+ await conn.execute(query, schedule.id, serialized_data, schedule.task_id,
+ schedule.next_fire_time)
+ event = ScheduleUpdated(datetime.now(timezone.utc), schedule.id,
+ schedule.next_fire_time)
+ else:
+ event = ScheduleAdded(datetime.now(timezone.utc), schedule.id,
+ schedule.next_fire_time)
+
+ if event:
+ await self.publish(event)
+
+ if self.notify_channel:
+ await self.pool.execute(f"NOTIFY {self.notify_channel}, 'schedule'")
+
+ async def remove_schedules(self, ids: Iterable[str]) -> None:
+ async with self.pool.acquire() as conn, conn.transaction():
+ now = datetime.now(timezone.utc)
+ query = (f"DELETE FROM {self.schema}.schedules "
+ f"WHERE id = any($1::text[]) "
+ f" AND (acquired_until IS NULL OR acquired_until < $2) "
+ f"RETURNING id")
+ removed_ids = [row[0] for row in await conn.fetch(query, list(ids), now)]
+
+ for schedule_id in removed_ids:
+ await self.publish(ScheduleRemoved(now, schedule_id))
+
+ async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]:
+ query = f"SELECT serialized_data FROM {self.schema}.schedules"
+ args = ()
+        if ids:
+            query += " WHERE id = any($1::text[])"
+            args = (list(ids),)
+
+ query += " ORDER BY id"
+ records = await self.pool.fetch(query, *args)
+ return [self.serializer.deserialize(r[0]) for r in records]
+
+ async def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]:
+ while True:
+ schedules: List[Schedule] = []
+ async with self.pool.acquire() as conn, conn.transaction():
+ acquired_until = datetime.fromtimestamp(
+ datetime.now(timezone.utc).timestamp() + self.lock_expiration_delay,
+ timezone.utc)
+ records = await conn.fetch(f"""
+ WITH schedule_ids AS (
+ SELECT id FROM {self.schema}.schedules
+ WHERE next_fire_time IS NOT NULL AND next_fire_time <= $1
+ AND (acquired_until IS NULL OR $1 > acquired_until)
+ ORDER BY next_fire_time
+ FOR NO KEY UPDATE SKIP LOCKED
+ FETCH FIRST $2 ROWS ONLY
+ )
+ UPDATE {self.schema}.schedules SET acquired_by = $3, acquired_until = $4
+ WHERE id IN (SELECT id FROM schedule_ids)
+ RETURNING serialized_data
+ """, datetime.now(timezone.utc), limit, scheduler_id, acquired_until)
+
+ for record in records:
+ schedule = self.serializer.deserialize(record['serialized_data'])
+ schedules.append(schedule)
+
+ if schedules:
+ return schedules
+
+            async with move_on_after(self.max_poll_time):
+                await self._schedules_event.wait()
+                self._schedules_event.clear()
+
+ async def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None:
+ update_events: List[ScheduleUpdated] = []
+ finished_schedule_ids: List[str] = []
+ async with self.pool.acquire() as conn, conn.transaction():
+ update_args = []
+ now = datetime.now(timezone.utc)
+ for schedule in schedules:
+ if schedule.next_fire_time is not None:
+ try:
+ serialized_data = self.serializer.serialize(schedule)
+ except SerializationError:
+ self._logger.exception('Error serializing schedule %r – '
+ 'removing from data store', schedule.id)
+ finished_schedule_ids.append(schedule.id)
+ continue
+
+                        update_args.append((serialized_data, schedule.next_fire_time,
+                                            schedule.id, scheduler_id))
+                        update_events.append(
+                            ScheduleUpdated(now, schedule.id, schedule.next_fire_time))
+                    else:
+                        finished_schedule_ids.append(schedule.id)
+
+                # Update schedules that have a next fire time
+                if update_args:
+                    await conn.executemany(
+                        f"UPDATE {self.schema}.schedules SET serialized_data = $1, "
+                        f"next_fire_time = $2, acquired_by = NULL, acquired_until = NULL "
+                        f"WHERE id = $3 AND acquired_by = $4", update_args)
+
+ # Remove schedules that have no next fire time or failed to serialize
+ if finished_schedule_ids:
+ await conn.execute(
+ f"DELETE FROM {self.schema}.schedules "
+ f"WHERE id = any($1::text[]) AND acquired_by = $2",
+ list(finished_schedule_ids), scheduler_id)
+
+ for event in update_events:
+ await self.publish(event)
+
+ if update_events and self.notify_channel:
+ await self.pool.execute(f"NOTIFY {self.notify_channel}, 'schedule'")
+
+ for schedule_id in finished_schedule_ids:
+ event = ScheduleRemoved(datetime.now(timezone.utc), schedule_id)
+ await self.publish(event)
+
+ async def add_job(self, job: Job) -> None:
+ now = datetime.now(timezone.utc)
+ query = (f"INSERT INTO {self.schema}.jobs (id, task_id, created_at, serialized_data, "
+ f"tags) VALUES($1, $2, $3, $4, $5)")
+ await self.pool.execute(query, job.id, job.task_id, now, self.serializer.serialize(job),
+ job.tags)
+ await self.publish(JobAdded(now, job.id, job.task_id, job.schedule_id))
+
+ if self.notify_channel:
+ await self.pool.execute(f"NOTIFY {self.notify_channel}, 'job'")
+
+ async def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]:
+ query = f"SELECT serialized_data FROM {self.schema}.jobs"
+ args = ()
+        if ids:
+            query += " WHERE id = any($1::uuid[])"
+            args = (list(ids),)
+
+ query += " ORDER BY id"
+ records = await self.pool.fetch(query, *args)
+ return [self.serializer.deserialize(r[0]) for r in records]
+
+ async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]:
+ while True:
+            self._logger.debug('Attempting to acquire jobs')
+ jobs: List[Job] = []
+ async with self.pool.acquire() as conn, conn.transaction():
+ now = datetime.now(timezone.utc)
+ acquired_until = datetime.fromtimestamp(
+ now.timestamp() + self.lock_expiration_delay, timezone.utc)
+ records = await conn.fetch(f"""
+ WITH job_ids AS (
+ SELECT id FROM {self.schema}.jobs
+ WHERE acquired_until IS NULL OR acquired_until < $1
+ ORDER BY created_at
+ FOR NO KEY UPDATE SKIP LOCKED
+ FETCH FIRST $2 ROWS ONLY
+ )
+ UPDATE {self.schema}.jobs SET acquired_by = $3, acquired_until = $4
+ WHERE id IN (SELECT id FROM job_ids)
+ RETURNING serialized_data
+ """, now, limit, worker_id, acquired_until)
+
+ for record in records:
+ job = self.serializer.deserialize(record['serialized_data'])
+ jobs.append(job)
+
+ if jobs:
+ return jobs
+
+ async with move_on_after(self.max_poll_time):
+ await self._jobs_event.wait()
+ self._jobs_event.clear()
+
+ async def release_jobs(self, worker_id: str, jobs: List[Job]) -> None:
+        job_ids = [j.id for j in jobs]
+        await self.pool.execute(
+            f"DELETE FROM {self.schema}.jobs WHERE acquired_by = $1 AND id = any($2::uuid[])",
+            worker_id, job_ids)
diff --git a/apscheduler/eventhubs/__init__.py b/apscheduler/eventhubs/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/apscheduler/eventhubs/__init__.py
+++ /dev/null
diff --git a/apscheduler/eventhubs/local.py b/apscheduler/eventhubs/local.py
deleted file mode 100644
index b429085..0000000
--- a/apscheduler/eventhubs/local.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import logging
-from dataclasses import dataclass, field
-from typing import Any, Callable, Coroutine, List
-
-from apscheduler.abc import EventHub
-from apscheduler.events import Event
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class LocalEventHub(EventHub):
- _callbacks: List[Callable[[Event], Any]] = field(init=False, default_factory=list)
-
- async def publish(self, event: Event) -> None:
- for callback in self._callbacks:
- try:
- retval = callback(event)
- if isinstance(retval, Coroutine):
- await retval
- except Exception:
- logger.exception('Error delivering event to callback %r', callback)
-
- async def subscribe(self, callback: Callable[[Event], Any]) -> None:
- self._callbacks.append(callback)
diff --git a/apscheduler/events.py b/apscheduler/events.py
index d99fd9f..7a67304 100644
--- a/apscheduler/events.py
+++ b/apscheduler/events.py
@@ -1,72 +1,169 @@
-from dataclasses import dataclass
+from collections import defaultdict
+from contextlib import ExitStack, suppress
+from dataclasses import dataclass, field
from datetime import datetime
-from typing import Any, Optional, Set
+from inspect import isawaitable, isclass
+from typing import Any, Callable, Dict, Iterable, List, Optional, Type
+from uuid import UUID
+from warnings import warn
+from anyio import run_sync_in_worker_thread, start_blocking_portal
+from anyio.abc import BlockingPortal
-@dataclass(frozen=True)
-class Event:
- timestamp: datetime
+from .abc import Event, EventSource
@dataclass(frozen=True)
-class ExecutorEvent(Event):
- job_id: str
+class JobEvent(Event):
+ job_id: UUID
task_id: str
schedule_id: Optional[str]
- scheduled_start_time: Optional[datetime]
@dataclass(frozen=True)
-class JobSubmissionFailed(ExecutorEvent):
- exception: BaseException
- formatted_traceback: str
+class JobAdded(JobEvent):
+ pass
@dataclass(frozen=True)
-class JobAdded(ExecutorEvent):
+class JobUpdated(JobEvent):
pass
@dataclass(frozen=True)
-class JobUpdated(ExecutorEvent):
+class JobExecutionEvent(JobEvent):
+ job_id: UUID
+ task_id: str
+ schedule_id: Optional[str]
+ scheduled_fire_time: Optional[datetime]
+ start_time: datetime
+ start_deadline: datetime
+
+
+@dataclass(frozen=True)
+class JobDeadlineMissed(JobExecutionEvent):
pass
@dataclass(frozen=True)
-class JobDeadlineMissed(ExecutorEvent):
- start_deadline: datetime
+class JobStarted(JobExecutionEvent):
+ pass
@dataclass(frozen=True)
-class JobSuccessful(ExecutorEvent):
- start_time: datetime
- start_deadline: Optional[datetime]
+class JobSuccessful(JobExecutionEvent):
return_value: Any
@dataclass(frozen=True)
-class JobFailed(ExecutorEvent):
- start_time: datetime
- start_deadline: Optional[datetime]
- formatted_traceback: str
+class JobFailed(JobExecutionEvent):
+ traceback: str
exception: Optional[BaseException] = None
@dataclass(frozen=True)
-class DataStoreEvent(Event):
- schedule_ids: Set[str]
+class ScheduleEvent(Event):
+ schedule_id: str
@dataclass(frozen=True)
-class SchedulesAdded(DataStoreEvent):
- earliest_next_fire_time: Optional[datetime]
+class ScheduleAdded(ScheduleEvent):
+ next_fire_time: Optional[datetime]
@dataclass(frozen=True)
-class SchedulesUpdated(DataStoreEvent):
- earliest_next_fire_time: Optional[datetime]
+class ScheduleUpdated(ScheduleEvent):
+ next_fire_time: Optional[datetime]
@dataclass(frozen=True)
-class SchedulesRemoved(DataStoreEvent):
+class ScheduleRemoved(ScheduleEvent):
pass
+
+
+_all_event_types = [x for x in locals().values() if isclass(x) and issubclass(x, Event)]
+
+
+#
+# Event delivery
+#
+
+@dataclass
+class EventHub(EventSource):
+ _subscribers: Dict[Type[Event], List[Callable[[Event], Any]]] = field(
+ init=False, default_factory=lambda: defaultdict(list))
+
+ def subscribe(self, callback: Callable[[Event], Any],
+ event_types: Optional[Iterable[Type[Event]]] = None) -> None:
+ if event_types is None:
+ event_types = _all_event_types
+
+ for event_type in event_types:
+ existing_callbacks = self._subscribers[event_type]
+ if callback not in existing_callbacks:
+ existing_callbacks.append(callback)
+
+ def unsubscribe(self, callback: Callable[[Event], Any],
+ event_types: Optional[Iterable[Type[Event]]] = None) -> None:
+ if event_types is None:
+ event_types = _all_event_types
+
+ for event_type in event_types:
+ existing_callbacks = self._subscribers.get(event_type, [])
+ with suppress(ValueError):
+ existing_callbacks.remove(callback)
+
+ async def publish(self, event: Event) -> None:
+ for callback in self._subscribers[type(event)]:
+ try:
+ retval = callback(event)
+ if isawaitable(retval):
+ await retval
+ except Exception as exc:
+ warn(f'Failed to deliver {event.__class__.__name__} event to callback '
+ f'{callback!r}: {exc.__class__.__name__}: {exc}')
+
+
+class SyncEventSource:
+ _subscribers: Dict[Type[Event], List[Callable[[Event], Any]]]
+
+ def __init__(self, async_event_source: EventSource, portal: Optional[BlockingPortal] = None):
+ self.portal = portal
+ self._async_event_source = async_event_source
+ self._async_event_source.subscribe(self._forward_async_event)
+ self._exit_stack = ExitStack()
+ self._subscribers = defaultdict(list)
+
+    def __enter__(self):
+        self._exit_stack.__enter__()
+        if not self.portal:
+            portal_cm = start_blocking_portal()
+            self.portal = self._exit_stack.enter_context(portal_cm)
+        return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._exit_stack.__exit__(exc_type, exc_val, exc_tb)
+
+ async def _forward_async_event(self, event: Event) -> None:
+ for subscriber in self._subscribers.get(type(event), ()):
+ await run_sync_in_worker_thread(subscriber, event)
+
+ def subscribe(self, callback: Callable[[Event], Any],
+ event_types: Optional[Iterable[Type[Event]]] = None) -> None:
+ if event_types is None:
+ event_types = _all_event_types
+
+ for event_type in event_types:
+ existing_callbacks = self._subscribers[event_type]
+ if callback not in existing_callbacks:
+ existing_callbacks.append(callback)
+
+ def unsubscribe(self, callback: Callable[[Event], Any],
+ event_types: Optional[Iterable[Type[Event]]] = None) -> None:
+ if event_types is None:
+ event_types = _all_event_types
+
+ for event_type in event_types:
+ existing_callbacks = self._subscribers.get(event_type, [])
+ with suppress(ValueError):
+ existing_callbacks.remove(callback)
diff --git a/apscheduler/exceptions.py b/apscheduler/exceptions.py
index 48fe9c1..ec04f2c 100644
--- a/apscheduler/exceptions.py
+++ b/apscheduler/exceptions.py
@@ -8,11 +8,14 @@ class JobLookupError(KeyError):
class ConflictingIdError(KeyError):
- """Raised when the uniqueness of job IDs is being violated."""
+ """
+ Raised when trying to add a schedule to a store that already contains a schedule by that ID,
+ and the conflict policy of ``exception`` is used.
+ """
- def __init__(self, job_id):
+ def __init__(self, schedule_id):
super().__init__(
- u'Job identifier (%s) conflicts with an existing job' % job_id)
+ f'This data store already contains a schedule with the identifier {schedule_id!r}')
class TransientJobError(ValueError):
@@ -40,3 +43,17 @@ class MaxIterationsReached(Exception):
Raised when a trigger has reached its maximum number of allowed computation iterations when
trying to calculate the next fire time.
"""
+
+
+class SchedulerAlreadyRunningError(Exception):
+ """Raised when attempting to start or configure the scheduler when it's already running."""
+
+ def __str__(self):
+ return 'Scheduler is already running'
+
+
+class SchedulerNotRunningError(Exception):
+ """Raised when attempting to shutdown the scheduler when it's not running."""
+
+ def __str__(self):
+ return 'Scheduler is not running'
diff --git a/apscheduler/policies.py b/apscheduler/policies.py
new file mode 100644
index 0000000..91fbd31
--- /dev/null
+++ b/apscheduler/policies.py
@@ -0,0 +1,19 @@
+from enum import Enum, auto
+
+
+class ConflictPolicy(Enum):
+ #: replace the existing schedule with a new one
+ replace = auto()
+ #: keep the existing schedule as-is and drop the new schedule
+ do_nothing = auto()
+ #: raise an exception if a conflict is detected
+ exception = auto()
+
+
+class CoalescePolicy(Enum):
+ #: run once, with the earliest fire time
+ earliest = auto()
+ #: run once, with the latest fire time
+ latest = auto()
+ #: submit one job for every accumulated fire time
+ all = auto()
diff --git a/apscheduler/schedulers/__init__.py b/apscheduler/schedulers/__init__.py
index bd8a790..e69de29 100644
--- a/apscheduler/schedulers/__init__.py
+++ b/apscheduler/schedulers/__init__.py
@@ -1,12 +0,0 @@
-class SchedulerAlreadyRunningError(Exception):
- """Raised when attempting to start or configure the scheduler when it's already running."""
-
- def __str__(self):
- return 'Scheduler is already running'
-
-
-class SchedulerNotRunningError(Exception):
- """Raised when attempting to shutdown the scheduler when it's not running."""
-
- def __str__(self):
- return 'Scheduler is not running'
diff --git a/apscheduler/schedulers/async_.py b/apscheduler/schedulers/async_.py
index d6f8faf..6448aa5 100644
--- a/apscheduler/schedulers/async_.py
+++ b/apscheduler/schedulers/async_.py
@@ -1,53 +1,73 @@
-import logging
import os
import platform
import threading
from contextlib import AsyncExitStack
-from dataclasses import dataclass, field
-from datetime import datetime, timedelta, timezone, tzinfo
-from traceback import format_exc
-from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Set, Union
+from datetime import datetime, timedelta, timezone
+from logging import Logger, getLogger
+from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Union
from uuid import uuid4
-import tzlocal
-from anyio import create_event, create_task_group, move_on_after
-from anyio.abc import Event
+from anyio import create_event, create_task_group, get_cancelled_exc_class, open_cancel_scope
+from anyio.abc import CancelScope, Event, TaskGroup
-from ..abc import DataStore, EventHub, EventSource, Executor, Job, Schedule, Task, Trigger
-from ..datastores.memory import MemoryScheduleStore
-from ..eventhubs.local import LocalEventHub
-from ..events import JobSubmissionFailed, SchedulesAdded, SchedulesUpdated
+from ..abc import DataStore, Job, Schedule, Task, Trigger
+from ..datastores.memory import MemoryDataStore
+from ..events import EventHub
from ..marshalling import callable_to_ref
-from ..validators import as_timezone
-from ..workers.local import LocalExecutor
-
-
-@dataclass
-class AsyncScheduler(EventSource):
- schedule_store: DataStore = field(default_factory=MemoryScheduleStore)
- worker: Executor = field(default_factory=LocalExecutor)
- timezone: tzinfo = field(default_factory=tzlocal.get_localzone)
- identity: str = f'{platform.node()}-{os.getpid()}-{threading.get_ident()}'
- logger: logging.Logger = field(default_factory=lambda: logging.getLogger(__name__))
- _event_hub: EventHub = field(init=False, default_factory=LocalEventHub)
- _tasks: Dict[str, Task] = field(init=False, default_factory=dict)
- _async_stack: AsyncExitStack = field(init=False, default_factory=AsyncExitStack)
- _next_fire_time: Optional[datetime] = field(init=False, default=None)
- _wakeup_event: Optional[Event] = field(init=False, default=None)
- _closed: bool = field(init=False, default=False)
-
- def __post_init__(self):
- self.timezone = as_timezone(self.timezone)
+from ..policies import CoalescePolicy, ConflictPolicy
+from ..workers.async_ import AsyncWorker
+
+
+class AsyncScheduler(EventHub):
+ _task_group: Optional[TaskGroup] = None
+ _stop_event: Optional[Event] = None
+ _running: bool = False
+ _worker: Optional[AsyncWorker] = None
+ _acquire_cancel_scope: Optional[CancelScope] = None
+
+ def __init__(self, data_store: Optional[DataStore] = None, *, identity: Optional[str] = None,
+ logger: Optional[Logger] = None, start_worker: bool = True):
+ super().__init__()
+ self.data_store = data_store or MemoryDataStore()
+ self.identity = identity or f'{platform.node()}-{os.getpid()}-{threading.get_ident()}'
+ self.logger = logger or getLogger(__name__)
+ self.start_worker = start_worker
+ self._tasks: Dict[str, Task] = {}
+ self._exit_stack = AsyncExitStack()
+
+ @property
+ def worker(self) -> Optional[AsyncWorker]:
+ return self._worker
async def __aenter__(self):
- await self._async_stack.__aenter__()
- task_group = create_task_group()
- await self._async_stack.enter_async_context(task_group)
- await task_group.spawn(self.run)
+ await self._exit_stack.__aenter__()
+
+ # Start the built-in worker, if configured to do so
+ if self.start_worker:
+ # The worker handles initializing the data store
+ self._worker = AsyncWorker(self.data_store)
+ await self._exit_stack.enter_async_context(self._worker)
+ else:
+ # Otherwise, initialize the data store ourselves
+ await self._exit_stack.enter_async_context(self.data_store)
+
+ # Start the actual scheduler
+ self._task_group = create_task_group()
+ await self._exit_stack.enter_async_context(self._task_group)
+ start_event = create_event()
+ await self._task_group.spawn(self.run, start_event)
+ await start_event.wait()
+
+ # await self.start(self._task_group)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self._async_stack.__aexit__(exc_type, exc_val, exc_tb)
+ # Exit gracefully (wait for ongoing tasks to finish) if the context exited without errors
+ await self.stop(force=exc_type is not None)
+ if self._worker:
+ await self._worker.stop(force=exc_type is not None)
+
+ await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
def _get_taskdef(self, func_or_id: Union[str, Callable]) -> Task:
task_id = func_or_id if isinstance(func_or_id, str) else callable_to_ref(func_or_id)
@@ -71,8 +91,10 @@ class AsyncScheduler(EventSource):
async def add_schedule(
self, task: Union[str, Callable], trigger: Trigger, *, id: Optional[str] = None,
args: Optional[Iterable] = None, kwargs: Optional[Mapping[str, Any]] = None,
- coalesce: bool = True, misfire_grace_time: Union[float, timedelta, None] = None,
- tags: Optional[Iterable[str]] = None
+ coalesce: CoalescePolicy = CoalescePolicy.latest,
+ misfire_grace_time: Union[float, timedelta, None] = None,
+ tags: Optional[Iterable[str]] = None,
+ conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing
) -> str:
id = id or str(uuid4())
args = tuple(args or ())
@@ -85,36 +107,30 @@ class AsyncScheduler(EventSource):
schedule = Schedule(id=id, task_id=taskdef.id, trigger=trigger, args=args, kwargs=kwargs,
coalesce=coalesce, misfire_grace_time=misfire_grace_time, tags=tags,
next_fire_time=trigger.next())
- await self.schedule_store.add_or_replace_schedules([schedule])
- self.logger.info('Added new schedule for task %s; next run time at %s', taskdef,
- schedule.next_fire_time)
+ await self.data_store.add_schedule(schedule, conflict_policy)
+ self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', taskdef,
+ trigger, schedule.next_fire_time)
return schedule.id
async def remove_schedule(self, schedule_id: str) -> None:
- await self.schedule_store.remove_schedules({schedule_id})
-
- async def _handle_worker_event(self, event: Event) -> None:
- await self._event_hub.publish(event)
-
- async def _handle_datastore_event(self, event: Event) -> None:
- if isinstance(event, (SchedulesAdded, SchedulesUpdated)):
- # Wake up the scheduler if any schedule has an earlier next fire time than the one
- # we're currently waiting for
- if event.earliest_next_fire_time:
- if (self._next_fire_time is None
- or self._next_fire_time > event.earliest_next_fire_time):
- self.logger.debug('Job store reported an updated next fire time that requires '
- 'the scheduler to wake up: %s',
- event.earliest_next_fire_time)
- await self.wakeup()
-
- await self._event_hub.publish(event)
-
- async def _process_schedules(self):
- async with self.schedule_store.acquire_due_schedules(
- self.identity, datetime.now(timezone.utc)) as schedules:
- schedule_ids_to_remove: Set[str] = set()
- schedule_updates: Dict[str, Dict[str, Any]] = {}
+ await self.data_store.remove_schedules({schedule_id})
+
+ async def run(self, start_event: Optional[Event] = None) -> None:
+ self._stop_event = create_event()
+ self._running = True
+ if start_event:
+ await start_event.set()
+
+ while self._running:
+ async with open_cancel_scope() as self._acquire_cancel_scope:
+ try:
+ schedules = await self.data_store.acquire_schedules(self.identity, 100)
+ except get_cancelled_exc_class():
+ break
+ finally:
+ del self._acquire_cancel_scope
+
+ now = datetime.now(timezone.utc)
for schedule in schedules:
# Look up the task definition
try:
@@ -122,76 +138,66 @@ class AsyncScheduler(EventSource):
except LookupError:
self.logger.error('Cannot locate task definition %r for schedule %r – '
'removing schedule', schedule.task_id, schedule.id)
- schedule_ids_to_remove.add(schedule.id)
+ schedule.next_fire_time = None
continue
# Calculate a next fire time for the schedule, if possible
- try:
- next_fire_time = schedule.trigger.next()
- except Exception:
- self.logger.exception('Error computing next fire time for schedule %r of task '
- '%r – removing schedule', schedule.id, taskdef.id)
- next_fire_time = None
-
- # Queue a schedule update if a next fire time could be calculated.
- # Otherwise, queue the schedule for removal.
- if next_fire_time:
- schedule_updates[schedule.id] = {
- 'next_fire_time': next_fire_time,
- 'last_fire_time': schedule.next_fire_time,
- 'trigger': schedule.trigger
- }
- else:
- schedule_ids_to_remove.add(schedule.id)
-
- # Submit a new job to the executor
- job = Job(taskdef.id, taskdef.func, schedule.args, schedule.kwargs, schedule.id,
- schedule.last_fire_time, schedule.next_deadline, schedule.tags)
- try:
- await self.worker.submit_job(job)
- except Exception as exc:
- self.logger.exception('Error submitting job to worker')
- event = JobSubmissionFailed(datetime.now(self.timezone), job.id, job.task_id,
- job.schedule_id, job.scheduled_start_time,
- formatted_traceback=format_exc(), exception=exc)
- await self._event_hub.publish(event)
-
- # Removed finished schedules
- if schedule_ids_to_remove:
- await self.schedule_store.remove_schedules(schedule_ids_to_remove)
-
- # Update the next fire times
- if schedule_updates:
- await self.schedule_store.update_schedules(schedule_updates)
-
- return await self.schedule_store.get_next_fire_time()
-
- async def run(self):
- async with AsyncExitStack() as stack:
- await stack.enter_async_context(self.schedule_store)
- await stack.enter_async_context(self.worker)
- await self.worker.subscribe(self._handle_worker_event)
- await self.schedule_store.subscribe(self._handle_datastore_event)
- self._wakeup_event = create_event()
- while not self._closed:
- self._next_fire_time = await self._process_schedules()
-
- if self._next_fire_time:
- wait_time = (self._next_fire_time - datetime.now(timezone.utc)).total_seconds()
- else:
- wait_time = float('inf')
-
- self._wakeup_event = create_event()
- async with move_on_after(wait_time):
- await self._wakeup_event.wait()
-
- async def wakeup(self) -> None:
- await self._wakeup_event.set()
-
- async def shutdown(self, force: bool = False) -> None:
- if not self._closed:
- self._closed = True
- await self.wakeup()
-
- async def subscribe(self, callback: Callable[[Event], Any]) -> None:
- await self._event_hub.subscribe(callback)
+ fire_times = [schedule.next_fire_time]
+ calculate_next = schedule.trigger.next
+ while True:
+ try:
+ fire_time = calculate_next()
+ except Exception:
+ self.logger.exception(
+ 'Error computing next fire time for schedule %r of task %r – '
+ 'removing schedule', schedule.id, taskdef.id)
+ break
+
+ # Stop if the calculated fire time is in the future
+ if fire_time is None or fire_time > now:
+ schedule.next_fire_time = fire_time
+ break
+
+ # Only keep all the fire times if coalesce policy = "all"
+ if schedule.coalesce is CoalescePolicy.all:
+ fire_times.append(fire_time)
+ elif schedule.coalesce is CoalescePolicy.latest:
+ fire_times[0] = fire_time
+
+ # Add one or more jobs to the job queue
+ for fire_time in fire_times:
+ schedule.last_fire_time = fire_time
+ job = Job(taskdef.id, taskdef.func, schedule.args, schedule.kwargs,
+ schedule.id, fire_time, schedule.next_deadline,
+ schedule.tags)
+                    self.logger.debug('Added job %s for schedule %s', job.id, schedule)
+ await self.data_store.add_job(job)
+
+ self.logger.debug('Releasing %d schedules', len(schedules))
+ await self.data_store.release_schedules(self.identity, schedules)
+
+ await self._stop_event.set()
+ del self._stop_event
+
+ # async def start(self, task_group: TaskGroup, *, reset_datastore: bool = False) -> None:
+ # start_event = create_event()
+ # await task_group.spawn(self.run, start_event)
+ # await start_event.wait()
+ #
+ # if self.start_worker:
+ # self._worker = AsyncWorker(self.data_store)
+ # await self._worker.start(task_group)
+
+ async def stop(self, force: bool = False) -> None:
+ self._running = False
+ if self._worker:
+ await self._worker.stop(force)
+
+ if self._acquire_cancel_scope:
+ await self._acquire_cancel_scope.cancel()
+ if force and self._task_group:
+ await self._task_group.cancel_scope.cancel()
+
+ async def wait_until_stopped(self) -> None:
+ if self._stop_event:
+ await self._stop_event.wait()
diff --git a/apscheduler/schedulers/sync.py b/apscheduler/schedulers/sync.py
index 41a91ec..a049c9f 100644
--- a/apscheduler/schedulers/sync.py
+++ b/apscheduler/schedulers/sync.py
@@ -1,50 +1,109 @@
+from __future__ import annotations
+
+from concurrent.futures import Future
from functools import partial
+from logging import Logger
from typing import Any, Callable, Iterable, Mapping, Optional, Union
-from anyio import start_blocking_portal
from anyio.abc import BlockingPortal
-from apscheduler.abc import Trigger
-from apscheduler.events import Event
-from apscheduler.schedulers.async_ import AsyncScheduler
+from .async_ import AsyncScheduler
+from ..abc import DataStore, Trigger
+from ..events import SyncEventSource
+from ..workers.sync import SyncWorker
+
+
+class SyncScheduler(SyncEventSource):
+ """
+ A synchronous scheduler implementation.
+
+ Wraps an :class:`~.async_.AsyncScheduler` using a :class:`~anyio.abc.BlockingPortal`.
+
+ :ivar BlockingPortal portal: the blocking portal used by the scheduler
+ """
+
+ _task: Future
+ _worker: Optional[SyncWorker] = None
-class SyncScheduler:
- def __init__(self, *, portal: Optional[BlockingPortal] = None, **kwargs):
- self._scheduler = AsyncScheduler(**kwargs)
- self._portal = portal
- self._shutdown_portal = False
+ def __init__(self, data_store: Optional[DataStore] = None, *, identity: Optional[str] = None,
+ logger: Optional[Logger] = None, start_worker: bool = True,
+ portal: Optional[BlockingPortal] = None):
+ self._scheduler = AsyncScheduler(data_store, identity=identity, logger=logger,
+ start_worker=False)
+ self.start_worker = start_worker
+ super().__init__(self._scheduler, portal)
def __enter__(self):
- if not self._portal:
- self._portal = start_blocking_portal()
- self._shutdown_portal = True
+ super().__enter__()
+ self._exit_stack.enter_context(self.portal.wrap_async_context_manager(self._scheduler))
+
+ if self.start_worker:
+ self._worker = SyncWorker(self.data_store, portal=self.portal)
+ self._exit_stack.enter_context(self.worker)
- self._portal.call(self._scheduler.__aenter__)
return self
- def __exit__(self, exc_type, exc_val, exc_tb):
- cancel_remaining = exc_type is not None
- try:
- self._portal.call(self._scheduler.__aexit__, exc_type, exc_val, exc_tb)
- except BaseException:
- cancel_remaining = True
- raise
- finally:
- if self._shutdown_portal:
- self._portal.stop_from_external_thread(cancel_remaining=cancel_remaining)
- self._portal = None
+ # def __exit__(self, exc_type, exc_val, exc_tb):
+ # return self._exit_stack.__exit__(exc_type, exc_val, exc_tb)
+
+ # def __enter__(self) -> SyncScheduler:
+ # self.start()
+ # return self
+ #
+ # def __exit__(self, exc_type, exc_val, exc_tb):
+ # self.stop(force=exc_type is not None)
+
+ # @property
+ # def worker(self) -> AsyncWorker:
+ # return self._scheduler.worker
+
+ @property
+ def data_store(self) -> DataStore:
+ return self._scheduler.data_store
+
+ @property
+ def worker(self) -> Optional[SyncWorker]:
+ return self._worker
def add_schedule(self, task: Union[str, Callable], trigger: Trigger, *,
args: Optional[Iterable] = None,
kwargs: Optional[Mapping[str, Any]] = None) -> str:
func = partial(self._scheduler.add_schedule, task, trigger, args=args, kwargs=kwargs)
- return self._portal.call(func)
+ return self.portal.call(func)
def remove_schedule(self, schedule_id: str) -> None:
- return self._portal.call(self._scheduler.remove_schedule, schedule_id)
+ self.portal.call(self._scheduler.remove_schedule, schedule_id)
+
+ def stop(self) -> None:
+ self.portal.call(self._scheduler.stop)
- def shutdown(self, force: bool = False) -> None:
- self._portal.call(self._scheduler.shutdown, force)
+ def wait_until_stopped(self) -> None:
+ self.portal.call(self._scheduler.wait_until_stopped)
- def subscribe(self, callback: Callable[[Event], Any]) -> None:
- self._portal.call(self._scheduler.subscribe, callback)
+ # def start(self):
+ # if not self._scheduler._running:
+ # if not self.portal:
+ # self.portal = start_blocking_portal()
+ # self._shutdown_portal = True
+ #
+ # self.portal.call(self._scheduler.data_store.initialize)
+ # start_event = self.portal.call(create_event)
+ # self._task = self.portal.spawn_task(self._scheduler.run, start_event)
+ # self.portal.call(start_event.wait)
+ # print('start_event wait finished')
+ #
+ # if self._scheduler.start_worker:
+ # self._worker = SyncWorker(self.data_store, portal=self.portal)
+ # self._worker.start()
+ #
+ # def stop(self, force: bool = False) -> None:
+ # if self._worker:
+ # self._worker.stop(force)
+ #
+ # if self._scheduler._running:
+ # self._scheduler._running = False
+ # try:
+ # self.portal.call(partial(self._scheduler.stop, force=force))
+ # finally:
+ # if self._shutdown_portal:
+ # self.portal.stop_from_external_thread(cancel_remaining=force)
diff --git a/apscheduler/util.py b/apscheduler/util.py
index 1b09d53..eee6da5 100644
--- a/apscheduler/util.py
+++ b/apscheduler/util.py
@@ -1,6 +1,6 @@
"""This module contains several handy functions primarily meant for internal use."""
import sys
-from datetime import tzinfo
+from datetime import datetime, tzinfo
if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
@@ -24,3 +24,7 @@ def timezone_repr(timezone: tzinfo) -> str:
return timezone.key
else:
return repr(timezone)
+
+
+def absolute_datetime_diff(dateval1: datetime, dateval2: datetime) -> float:
+ return dateval1.timestamp() - dateval2.timestamp()
diff --git a/apscheduler/watchers.py b/apscheduler/watchers.py
new file mode 100644
index 0000000..d661d88
--- /dev/null
+++ b/apscheduler/watchers.py
@@ -0,0 +1,22 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+from anyio import create_event, move_on_after
+from anyio.abc import Event
+
+
+@dataclass
+class DelayWatcher:
+ max_wait_time: Optional[float]
+ _event: Optional[Event] = field(init=False, default=None)
+
+ async def wait(self) -> None:
+ async with move_on_after(self.max_wait_time):
+ self._event = create_event()
+ await self._event.wait()
+
+ async def notify(self) -> None:
+ if self._event:
+ event = self._event
+ del self._event
+ await event.set()
diff --git a/apscheduler/workers/async_.py b/apscheduler/workers/async_.py
new file mode 100644
index 0000000..7949993
--- /dev/null
+++ b/apscheduler/workers/async_.py
@@ -0,0 +1,158 @@
+import os
+import platform
+import threading
+from asyncio import current_task, iscoroutinefunction
+from collections.abc import Coroutine
+from contextlib import AsyncExitStack
+from datetime import datetime, timezone
+from functools import partial
+from logging import Logger, getLogger
+from traceback import format_exc
+from typing import Any, Callable, Dict, Optional, Set
+
+from anyio import create_event, create_task_group, open_cancel_scope, run_sync_in_worker_thread
+from anyio.abc import CancelScope, Event, TaskGroup
+
+from ..abc import DataStore, Job
+from ..events import EventHub, JobDeadlineMissed, JobFailed, JobSuccessful, JobUpdated
+
+
+class AsyncWorker(EventHub):
+ """Runs jobs locally in a task group."""
+
+ _task_group: Optional[TaskGroup] = None
+ _stop_event: Optional[Event] = None
+ _running: bool = False
+ _running_jobs: int = 0
+ _acquire_cancel_scope: Optional[CancelScope] = None
+
+ def __init__(self, data_store: DataStore, *, max_concurrent_jobs: int = 100,
+ identity: Optional[str] = None, logger: Optional[Logger] = None,
+ run_sync_functions_in_event_loop: bool = True):
+ super().__init__()
+ self.data_store = data_store
+ self.max_concurrent_jobs = max_concurrent_jobs
+ self.identity = identity or f'{platform.node()}-{os.getpid()}-{threading.get_ident()}'
+ self.logger = logger or getLogger(__name__)
+ self.run_sync_functions_in_event_loop = run_sync_functions_in_event_loop
+ self._acquired_jobs: Set[Job] = set()
+ self._exit_stack = AsyncExitStack()
+
+ if self.max_concurrent_jobs < 1:
+ raise ValueError('max_concurrent_jobs must be at least 1')
+
+ async def __aenter__(self):
+ print('entering worker in task', id(current_task()))
+ await self._exit_stack.__aenter__()
+
+ # Initialize the data store
+ await self._exit_stack.enter_async_context(self.data_store)
+
+ # Start the actual worker
+ self._task_group = create_task_group()
+ await self._exit_stack.enter_async_context(self._task_group)
+ start_event = create_event()
+ await self._task_group.spawn(self.run, start_event)
+ await start_event.wait()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ print('exiting worker in task', id(current_task()))
+ await self.stop(force=exc_type is not None)
+ await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
+
+ async def run(self, start_event: Optional[Event] = None) -> None:
+ self._stop_event = create_event()
+ self._running = True
+ if start_event:
+ await start_event.set()
+
+ while self._running:
+ limit = self.max_concurrent_jobs - self._running_jobs
+ jobs = []
+ async with open_cancel_scope() as self._acquire_cancel_scope:
+ try:
+ jobs = await self.data_store.acquire_jobs(self.identity, limit)
+ finally:
+ del self._acquire_cancel_scope
+
+ for job in jobs:
+ await self._task_group.spawn(self._run_job, job)
+
+ await self._stop_event.set()
+ del self._stop_event
+ del self._task_group
+
+ async def _run_job(self, job: Job) -> None:
+        # Skip the job (publishing JobDeadlineMissed) if its start deadline has already passed
+ now = datetime.now(timezone.utc)
+ if job.start_deadline is not None:
+ if now.timestamp() > job.start_deadline.timestamp():
+ self.logger.info('Missed the deadline of job %r', job.id)
+ event = JobDeadlineMissed(
+ now, job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id,
+ scheduled_fire_time=job.scheduled_fire_time, start_time=now,
+ start_deadline=job.start_deadline
+ )
+ await self.publish(event)
+ return
+
+ # Set the job as running and publish a job update event
+ self.logger.info('Started job %r', job.id)
+ job.started_at = now
+ event = JobUpdated(
+ timestamp=now, job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id
+ )
+ await self.publish(event)
+
+ self._running_jobs += 1
+ try:
+ return_value = await self._call_job_func(job.func, job.args, job.kwargs)
+ except BaseException as exc:
+ self.logger.exception('Job %r raised an exception', job.id)
+ event = JobFailed(
+ timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id,
+ schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time,
+ start_time=now, start_deadline=job.start_deadline,
+ traceback=format_exc(), exception=exc
+ )
+ else:
+ self.logger.info('Job %r completed successfully', job.id)
+ event = JobSuccessful(
+ timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id,
+ schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time,
+ start_time=now, start_deadline=job.start_deadline, return_value=return_value
+ )
+
+ self._running_jobs -= 1
+ await self.data_store.release_jobs(self.identity, [job])
+ await self.publish(event)
+
+ async def _call_job_func(self, func: Callable, args: tuple, kwargs: Dict[str, Any]):
+ if not self.run_sync_functions_in_event_loop and not iscoroutinefunction(func):
+ wrapped = partial(func, *args, **kwargs)
+ return await run_sync_in_worker_thread(wrapped)
+
+ return_value = func(*args, **kwargs)
+ if isinstance(return_value, Coroutine):
+ return_value = await return_value
+
+ return return_value
+
+ # async def start(self, task_group: TaskGroup) -> None:
+ # start_event = create_event()
+ # print('spawning task for AsyncWorker.run() in task', id(current_task()))
+ # await task_group.spawn(self.run, start_event)
+ # await start_event.wait()
+
+ async def stop(self, force: bool = False) -> None:
+ self._running = False
+ if self._acquire_cancel_scope:
+ await self._acquire_cancel_scope.cancel()
+
+ if force and self._task_group:
+ await self._task_group.cancel_scope.cancel()
+
+ async def wait_until_stopped(self) -> None:
+ if self._stop_event:
+ await self._stop_event.wait()
diff --git a/apscheduler/workers/local.py b/apscheduler/workers/local.py
deleted file mode 100644
index 9570612..0000000
--- a/apscheduler/workers/local.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from __future__ import annotations
-
-from collections import OrderedDict
-from collections.abc import Coroutine
-from dataclasses import dataclass, field
-from datetime import datetime, timezone
-from traceback import format_exc
-from typing import Any, Callable, List
-
-from anyio import create_task_group
-from anyio.abc import TaskGroup
-from apscheduler.abc import EventHub, Executor, Job
-from apscheduler.eventhubs.local import LocalEventHub
-from apscheduler.events import (
- Event, JobAdded, JobDeadlineMissed, JobFailed, JobSuccessful, JobUpdated)
-
-
-@dataclass
-class LocalExecutor(Executor):
- """Runs jobs locally in a task group."""
-
- _event_hub: EventHub = field(init=False, default_factory=LocalEventHub)
- _task_group: TaskGroup = field(init=False)
- _queued_jobs: OrderedDict[Job, None] = field(init=False, default_factory=OrderedDict)
- _running_jobs: OrderedDict[Job, None] = field(init=False, default_factory=OrderedDict)
-
- async def __aenter__(self) -> LocalExecutor:
- self._task_group = create_task_group()
- await self._task_group.__aenter__()
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
- await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
-
- async def _run_job(self) -> None:
- job = self._queued_jobs.popitem(last=False)[0]
-
- # Check if the job started before the deadline
- if job.start_deadline:
- tz = job.scheduled_start_time.tzinfo
- start_time = datetime.now(tz)
- if start_time >= job.start_deadline:
- event = JobDeadlineMissed(
- start_time, job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id,
- scheduled_start_time=job.scheduled_start_time,
- start_deadline=job.start_deadline
- )
- await self._event_hub.publish(event)
- return
- else:
- tz = timezone.utc
- start_time = datetime.now(tz)
-
- # Set the job as running and publish a job update event
- self._running_jobs[job] = None
- job.started_at = start_time
- event = JobUpdated(
- timestamp=datetime.now(tz), job_id=job.id, task_id=job.task_id,
- schedule_id=job.schedule_id, scheduled_start_time=job.scheduled_start_time
- )
- await self._event_hub.publish(event)
-
- try:
- return_value = job.func(*job.args, **job.kwargs)
- if isinstance(return_value, Coroutine):
- return_value = await return_value
- except Exception as exc:
- event = JobFailed(
- timestamp=datetime.now(tz), job_id=job.id, task_id=job.task_id,
- schedule_id=job.schedule_id, scheduled_start_time=job.scheduled_start_time,
- start_time=start_time, start_deadline=job.start_deadline,
- formatted_traceback=format_exc(), exception=exc)
- else:
- event = JobSuccessful(
- timestamp=datetime.now(tz), job_id=job.id, task_id=job.task_id,
- schedule_id=job.schedule_id, scheduled_start_time=job.scheduled_start_time,
- start_time=start_time, start_deadline=job.start_deadline, return_value=return_value
- )
-
- del self._running_jobs[job]
- await self._event_hub.publish(event)
-
- async def submit_job(self, job: Job) -> None:
- self._queued_jobs[job] = None
- await self._task_group.spawn(self._run_job)
-
- event = JobAdded(datetime.now(timezone.utc), job.id, job.task_id, job.schedule_id,
- job.scheduled_start_time)
- await self._event_hub.publish(event)
-
- async def get_jobs(self) -> List[Job]:
- return list(self._queued_jobs)
-
- async def subscribe(self, callback: Callable[[Event], Any]) -> None:
- await self._event_hub.subscribe(callback)
diff --git a/apscheduler/workers/sync.py b/apscheduler/workers/sync.py
new file mode 100644
index 0000000..919c914
--- /dev/null
+++ b/apscheduler/workers/sync.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from anyio.abc import BlockingPortal
+
+from .async_ import AsyncWorker
+from ..abc import DataStore
+from ..events import SyncEventSource
+
+
+class SyncWorker(SyncEventSource):
+ def __init__(self, data_store: DataStore, *, portal: Optional[BlockingPortal] = None,
+ max_concurrent_jobs: int = 100, identity: Optional[str] = None):
+ self._worker = AsyncWorker(data_store, max_concurrent_jobs=max_concurrent_jobs,
+ identity=identity, run_sync_functions_in_event_loop=False)
+ super().__init__(self._worker, portal)
+
+ def __enter__(self) -> SyncWorker:
+ super().__enter__()
+ self._exit_stack.enter_context(self.portal.wrap_async_context_manager(self._worker))
+ return self
+
+ def stop(self) -> None:
+ self.portal.call(self._worker.stop)
diff --git a/docker-compose.yml b/docker-compose.yml
index 001e048..ca5621f 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -1,5 +1,13 @@
version: "2"
services:
+ postgresql:
+ image: postgres
+ ports:
+ - 127.0.0.1:5432:5432
+ environment:
+ POSTGRES_DB: testdb
+ POSTGRES_PASSWORD: secret
+
redis:
image: redis
ports:
@@ -9,6 +17,8 @@ services:
image: mongo
ports:
- 127.0.0.1:27017:27017
+ environment:
+ MONGO_REPLICA_SET_NAME: foo
rethinkdb:
image: rethinkdb
diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index 5700b3d..264a0ab 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -7,9 +7,13 @@ APScheduler, see the :doc:`migration section <migration>`.
UNRELEASED
----------
+* Added shareable data stores (can be used by multiple schedulers and workers)
+* Added support for running workers independently from schedulers
+* Added full async support via AnyIO_ (data store support varies)
* Dropped support for Python 2.X, 3.5 and 3.6
* Removed the Qt scheduler due to maintenance difficulties
+.. _AnyIO: https://github.com/agronholm/anyio
3.7.0
-----
diff --git a/examples/executors/processpool.py b/examples/executors/processpool.py
index df1e82b..37f58bb 100644
--- a/examples/executors/processpool.py
+++ b/examples/executors/processpool.py
@@ -19,6 +19,6 @@ if __name__ == '__main__':
print('Press Ctrl+{0} to exit'.format('Break' if os.name == 'nt' else 'C'))
try:
- scheduler.start()
+ scheduler.initialize()
except (KeyboardInterrupt, SystemExit):
pass
diff --git a/examples/misc/reference.py b/examples/misc/reference.py
index f6fc5dd..1479d2f 100644
--- a/examples/misc/reference.py
+++ b/examples/misc/reference.py
@@ -13,6 +13,6 @@ if __name__ == '__main__':
print('Press Ctrl+{0} to exit'.format('Break' if os.name == 'nt' else 'C'))
try:
- scheduler.start()
+ scheduler.initialize()
except (KeyboardInterrupt, SystemExit):
pass
diff --git a/examples/rpc/server.py b/examples/rpc/server.py
index 6fb8594..4c7c705 100644
--- a/examples/rpc/server.py
+++ b/examples/rpc/server.py
@@ -45,12 +45,12 @@ class SchedulerService(rpyc.Service):
if __name__ == '__main__':
scheduler = BackgroundScheduler()
- scheduler.start()
+ scheduler.initialize()
protocol_config = {'allow_public_attrs': True}
server = ThreadedServer(SchedulerService, port=12345, protocol_config=protocol_config)
try:
- server.start()
+ server.initialize()
except (KeyboardInterrupt, SystemExit):
pass
finally:
- scheduler.shutdown()
+ scheduler.stop()
diff --git a/examples/schedulers/async_.py b/examples/schedulers/async_.py
index 0f6677c..ab1784f 100644
--- a/examples/schedulers/async_.py
+++ b/examples/schedulers/async_.py
@@ -3,6 +3,7 @@ import logging
import anyio
from apscheduler.schedulers.async_ import AsyncScheduler
from apscheduler.triggers.interval import IntervalTrigger
+from apscheduler.workers.async_ import AsyncWorker
def say_hello():
@@ -10,8 +11,9 @@ def say_hello():
async def main():
- async with AsyncScheduler() as scheduler:
+ async with AsyncScheduler() as scheduler, AsyncWorker(scheduler.data_store):
await scheduler.add_schedule(say_hello, IntervalTrigger(seconds=1))
+ await scheduler.wait_until_stopped()
logging.basicConfig(level=logging.DEBUG)
try:
diff --git a/examples/schedulers/sync.py b/examples/schedulers/sync.py
index df78df9..f845727 100644
--- a/examples/schedulers/sync.py
+++ b/examples/schedulers/sync.py
@@ -2,6 +2,7 @@ import logging
from apscheduler.schedulers.sync import SyncScheduler
from apscheduler.triggers.interval import IntervalTrigger
+from apscheduler.workers.sync import SyncWorker
def say_hello():
@@ -10,7 +11,8 @@ def say_hello():
logging.basicConfig(level=logging.DEBUG)
try:
- with SyncScheduler() as scheduler:
+ with SyncScheduler() as scheduler, SyncWorker(scheduler.data_store, portal=scheduler.portal):
scheduler.add_schedule(say_hello, IntervalTrigger(seconds=1))
+ scheduler.wait_until_stopped()
except (KeyboardInterrupt, SystemExit):
pass
diff --git a/setup.cfg b/setup.cfg
index 86eee5b..647dcd7 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -21,9 +21,8 @@ license = MIT
packages = find:
python_requires = >= 3.7
install_requires =
- anyio ~= 2.0
+ anyio ~= 2.1
backports.zoneinfo; python_version < '3.9'
- sortedcontainers ~= 2.2
tzdata; platform_system == "Windows"
tzlocal >= 3.0b1, < 4.0
@@ -40,6 +39,7 @@ test =
curio
pytest >= 5.0
pytest-cov
+ pytest-freezegun
pytest-mock
pytz
trio
diff --git a/tests/backends/test_memory.py b/tests/backends/test_memory.py
deleted file mode 100644
index c8679f4..0000000
--- a/tests/backends/test_memory.py
+++ /dev/null
@@ -1,138 +0,0 @@
-from datetime import datetime, timedelta, timezone
-
-import pytest
-from apscheduler.abc import Schedule
-from apscheduler.datastores.memory import MemoryScheduleStore
-from apscheduler.events import SchedulesAdded, SchedulesRemoved, SchedulesUpdated
-from apscheduler.triggers.date import DateTrigger
-
-pytestmark = pytest.mark.anyio
-
-
-@pytest.fixture
-async def store():
- async with MemoryScheduleStore() as store_:
- yield store_
-
-
-@pytest.fixture
-async def events(store):
- events = []
- await store.subscribe(events.append)
- return events
-
-
-@pytest.fixture
-def schedules():
- trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc))
- schedule1 = Schedule(id='s1', task_id='bogus', trigger=trigger, args=(), kwargs={},
- coalesce=False, misfire_grace_time=None, tags=frozenset())
- schedule1.next_fire_time = trigger.next()
-
- trigger = DateTrigger(datetime(2020, 9, 14, tzinfo=timezone.utc))
- schedule2 = Schedule(id='s2', task_id='bogus', trigger=trigger, args=(), kwargs={},
- coalesce=False, misfire_grace_time=None, tags=frozenset())
- schedule2.next_fire_time = trigger.next()
-
- trigger = DateTrigger(datetime(2020, 9, 15, tzinfo=timezone.utc))
- schedule3 = Schedule(id='s3', task_id='bogus', trigger=trigger, args=(), kwargs={},
- coalesce=False, misfire_grace_time=None, tags=frozenset())
- return [schedule1, schedule2, schedule3]
-
-
-async def test_add_schedules(store, schedules, events):
- assert await store.get_next_fire_time() is None
- await store.add_or_replace_schedules(schedules)
- assert await store.get_next_fire_time() == datetime(2020, 9, 13, tzinfo=timezone.utc)
-
- assert await store.get_schedules() == schedules
- assert await store.get_schedules({'s1', 's2', 's3'}) == schedules
- assert await store.get_schedules({'s1'}) == [schedules[0]]
- assert await store.get_schedules({'s2'}) == [schedules[1]]
- assert await store.get_schedules({'s3'}) == [schedules[2]]
-
- assert len(events) == 1
- assert isinstance(events[0], SchedulesAdded)
- assert events[0].schedule_ids == {'s1', 's2', 's3'}
- assert events[0].earliest_next_fire_time == datetime(2020, 9, 13, tzinfo=timezone.utc)
-
-
-async def test_replace_schedules(store, schedules, events):
- await store.add_or_replace_schedules(schedules)
- events.clear()
-
- next_fire_time = schedules[2].trigger.next()
- schedule = Schedule(id='s3', task_id='foo', trigger=schedules[2].trigger, args=(), kwargs={},
- coalesce=False, misfire_grace_time=None, tags=frozenset())
- schedule.next_fire_time = next_fire_time
- await store.add_or_replace_schedules([schedule])
-
- schedules = await store.get_schedules([schedule.id])
- assert schedules[0].task_id == 'foo'
- assert schedules[0].next_fire_time == next_fire_time
- assert schedules[0].args == ()
- assert schedules[0].kwargs == {}
- assert not schedules[0].coalesce
- assert schedules[0].misfire_grace_time is None
- assert schedules[0].tags == frozenset()
-
- assert len(events) == 1
- assert isinstance(events[0], SchedulesUpdated)
- assert events[0].schedule_ids == {'s3'}
- assert events[0].earliest_next_fire_time == datetime(2020, 9, 15, tzinfo=timezone.utc)
-
-
-async def test_update_schedules(store, schedules, events):
- await store.add_or_replace_schedules(schedules)
- events.clear()
-
- next_fire_time = datetime(2020, 9, 12, tzinfo=timezone.utc)
- updated_ids = await store.update_schedules({
- 's2': {'task_id': 'foo', 'next_fire_time': next_fire_time, 'args': (1,),
- 'kwargs': {'a': 'x'}, 'coalesce': True, 'misfire_grace_time': timedelta(seconds=5),
- 'tags': frozenset(['a', 'b'])},
- 'nonexistent': {}
- })
- assert updated_ids == {'s2'}
-
- schedules = await store.get_schedules({'s2'})
- assert len(schedules) == 1
- assert schedules[0].task_id == 'foo'
- assert schedules[0].args == (1,)
- assert schedules[0].kwargs == {'a': 'x'}
- assert schedules[0].coalesce
- assert schedules[0].next_fire_time == next_fire_time
- assert schedules[0].misfire_grace_time == timedelta(seconds=5)
- assert schedules[0].tags == frozenset(['a', 'b'])
-
- assert len(events) == 1
- assert isinstance(events[0], SchedulesUpdated)
- assert events[0].schedule_ids == {'s2'}
- assert events[0].earliest_next_fire_time == next_fire_time
-
-
-@pytest.mark.parametrize('ids', [None, {'s1', 's2'}], ids=['all', 'by_id'])
-async def test_remove_schedules(store, schedules, events, ids):
- await store.add_or_replace_schedules(schedules)
- events.clear()
-
- await store.remove_schedules(ids)
- assert len(events) == 1
- assert isinstance(events[0], SchedulesRemoved)
- if ids:
- assert events[0].schedule_ids == {'s1', 's2'}
- assert await store.get_schedules() == [schedules[2]]
- else:
- assert events[0].schedule_ids == {'s1', 's2', 's3'}
- assert await store.get_schedules() == []
-
-
-async def test_acquire_due_schedules(store, schedules, events):
- await store.add_or_replace_schedules(schedules)
- events.clear()
-
- now = datetime(2020, 9, 14, tzinfo=timezone.utc)
- async with store.acquire_due_schedules('dummy-id', now) as schedules:
- assert len(schedules) == 2
- assert schedules[0].id == 's1'
- assert schedules[1].id == 's2'
diff --git a/tests/conftest.py b/tests/conftest.py
index 62e552f..9d0f64c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,15 @@
import sys
-from datetime import datetime
-from unittest.mock import Mock
+from contextlib import AsyncExitStack, ExitStack
+from functools import partial
import pytest
+from anyio import start_blocking_portal
+from asyncpg import create_pool
+from motor.motor_asyncio import AsyncIOMotorClient
+
+from apscheduler.datastores.memory import MemoryDataStore
+from apscheduler.datastores.mongodb import MongoDBDataStore
+from apscheduler.datastores.postgresql import PostgresqlDataStore
from apscheduler.serializers.cbor import CBORSerializer
from apscheduler.serializers.json import JSONSerializer
from apscheduler.serializers.pickle import PickleSerializer
@@ -12,42 +19,70 @@ if sys.version_info >= (3, 9):
else:
from backports.zoneinfo import ZoneInfo
+store_params = [
+ pytest.param(MemoryDataStore, id='memory'),
+ pytest.param(PostgresqlDataStore, id='postgresql'),
+ pytest.param(MongoDBDataStore, id='mongodb')
+]
+
@pytest.fixture(scope='session')
def timezone():
return ZoneInfo('Europe/Berlin')
+@pytest.fixture(params=[None, PickleSerializer, CBORSerializer, JSONSerializer],
+ ids=['none', 'pickle', 'cbor', 'json'])
+def serializer(request):
+ return request.param() if request.param else None
+
+
@pytest.fixture
-def freeze_time(monkeypatch, timezone):
- class TimeFreezer:
- def __init__(self, initial):
- self.current = initial
- self.increment = None
+def anyio_backend():
+ return 'asyncio'
- def get(self, tzinfo=None):
- now = self.current.astimezone(tzinfo) if tzinfo else self.current.replace(tzinfo=None)
- if self.increment:
- self.current += self.increment
- return now
- def set(self, new_time):
- self.current = new_time
+@pytest.fixture(params=store_params)
+async def store(request):
+ async with AsyncExitStack() as stack:
+ if request.param is PostgresqlDataStore:
+ pool = await create_pool('postgresql://postgres:secret@localhost/testdb',
+ min_size=1, max_size=2)
+ await stack.enter_async_context(pool)
+ store = PostgresqlDataStore(pool, start_from_scratch=True)
+ elif request.param is MongoDBDataStore:
+ client = AsyncIOMotorClient(tz_aware=True)
+ stack.push(lambda *args: client.close())
+ store = MongoDBDataStore(client, start_from_scratch=True)
+ else:
+ store = MemoryDataStore()
- def next(self,):
- return self.current + self.increment
+ await stack.enter_async_context(store)
+ yield store
- def set_increment(self, delta):
- self.increment = delta
- freezer = TimeFreezer(timezone.localize(datetime(2011, 4, 3, 18, 40)))
- fake_datetime = Mock(datetime, now=freezer.get)
- monkeypatch.setattr('apscheduler.triggers.interval.datetime', fake_datetime)
- monkeypatch.setattr('apscheduler.triggers.date.datetime', fake_datetime)
- return freezer
+@pytest.fixture
+def portal():
+ with start_blocking_portal() as portal:
+ yield portal
-@pytest.fixture(params=[None, PickleSerializer, CBORSerializer, JSONSerializer],
- ids=['none', 'pickle', 'cbor', 'json'])
-def serializer(request):
- return request.param() if request.param else None
+@pytest.fixture(params=store_params)
+def sync_store(request, portal):
+ with ExitStack() as stack:
+ if request.param is PostgresqlDataStore:
+ pool = portal.call(
+ partial(create_pool, 'postgresql://postgres:secret@localhost/testdb',
+ min_size=1, max_size=2)
+ )
+ stack.enter_context(portal.wrap_async_context_manager(pool))
+ store = PostgresqlDataStore(pool, start_from_scratch=True)
+ elif request.param is MongoDBDataStore:
+ client = portal.call(partial(AsyncIOMotorClient, tz_aware=True))
+ stack.push(lambda *args: portal.call(client.close))
+ store = MongoDBDataStore(client, start_from_scratch=True)
+ else:
+ store = MemoryDataStore()
+
+ stack.enter_context(portal.wrap_async_context_manager(store))
+ yield store
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
new file mode 100644
index 0000000..0eb16d7
--- /dev/null
+++ b/tests/test_datastores.py
@@ -0,0 +1,245 @@
+from datetime import datetime, timezone
+
+import pytest
+from anyio import move_on_after
+
+from apscheduler.abc import Job, Schedule
+from apscheduler.events import ScheduleAdded, ScheduleRemoved, ScheduleUpdated
+from apscheduler.policies import CoalescePolicy, ConflictPolicy
+from apscheduler.triggers.date import DateTrigger
+
+pytestmark = [
+ pytest.mark.anyio,
+ pytest.mark.parametrize('anyio_backend', ['asyncio'])
+]
+
+
+@pytest.fixture
+async def events(store):
+ events = []
+ store.subscribe(events.append)
+ return events
+
+
+@pytest.fixture
+def schedules():
+ trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc))
+ schedule1 = Schedule(id='s1', task_id='bogus', trigger=trigger, args=(), kwargs={},
+ coalesce=CoalescePolicy.latest, misfire_grace_time=None, tags=frozenset())
+ schedule1.next_fire_time = trigger.next()
+
+ trigger = DateTrigger(datetime(2020, 9, 14, tzinfo=timezone.utc))
+ schedule2 = Schedule(id='s2', task_id='bogus', trigger=trigger, args=(), kwargs={},
+ coalesce=CoalescePolicy.latest, misfire_grace_time=None, tags=frozenset())
+ schedule2.next_fire_time = trigger.next()
+
+ trigger = DateTrigger(datetime(2020, 9, 15, tzinfo=timezone.utc))
+ schedule3 = Schedule(id='s3', task_id='bogus', trigger=trigger, args=(), kwargs={},
+ coalesce=CoalescePolicy.latest, misfire_grace_time=None, tags=frozenset())
+ return [schedule1, schedule2, schedule3]
+
+
+@pytest.fixture
+def jobs():
+ job1 = Job('task1', print, ('hello',), {'arg2': 'world'}, 'schedule1',
+ datetime(2020, 10, 10, tzinfo=timezone.utc),
+ datetime(2020, 10, 10, 1, tzinfo=timezone.utc), frozenset())
+ job2 = Job('task2', print, ('hello',), {'arg2': 'world'}, None,
+ None, None, frozenset())
+ return [job1, job2]
+
+
+async def test_add_schedules(store, schedules, events):
+ for schedule in schedules:
+ await store.add_schedule(schedule, ConflictPolicy.exception)
+
+ assert await store.get_schedules() == schedules
+ assert await store.get_schedules({'s1', 's2', 's3'}) == schedules
+ assert await store.get_schedules({'s1'}) == [schedules[0]]
+ assert await store.get_schedules({'s2'}) == [schedules[1]]
+ assert await store.get_schedules({'s3'}) == [schedules[2]]
+
+ assert len(events) == 3
+ for event, schedule in zip(events, schedules):
+ assert isinstance(event, ScheduleAdded)
+ assert event.schedule_id == schedule.id
+ assert event.next_fire_time == schedule.next_fire_time
+
+
+async def test_replace_schedules(store, schedules, events):
+ for schedule in schedules:
+ await store.add_schedule(schedule, ConflictPolicy.exception)
+
+ events.clear()
+ next_fire_time = schedules[2].trigger.next()
+ schedule = Schedule(id='s3', task_id='foo', trigger=schedules[2].trigger, args=(),
+ kwargs={}, coalesce=CoalescePolicy.earliest, misfire_grace_time=None,
+ tags=frozenset())
+ schedule.next_fire_time = next_fire_time
+ await store.add_schedule(schedule, ConflictPolicy.replace)
+
+ schedules = await store.get_schedules({schedule.id})
+ assert schedules[0].task_id == 'foo'
+ assert schedules[0].next_fire_time == next_fire_time
+ assert schedules[0].args == ()
+ assert schedules[0].kwargs == {}
+ assert schedules[0].coalesce is CoalescePolicy.earliest
+ assert schedules[0].misfire_grace_time is None
+ assert schedules[0].tags == frozenset()
+
+ assert len(events) == 1
+ assert isinstance(events[0], ScheduleUpdated)
+ assert events[0].schedule_id == 's3'
+ assert events[0].next_fire_time == datetime(2020, 9, 15, tzinfo=timezone.utc)
+
+
+async def test_remove_schedules(store, schedules, events):
+ for schedule in schedules:
+ await store.add_schedule(schedule, ConflictPolicy.exception)
+
+ events.clear()
+ await store.remove_schedules(['s1', 's2'])
+
+ assert len(events) == 2
+ assert isinstance(events[0], ScheduleRemoved)
+ assert events[0].schedule_id == 's1'
+ assert isinstance(events[1], ScheduleRemoved)
+ assert events[1].schedule_id == 's2'
+
+ assert await store.get_schedules() == [schedules[2]]
+
+
+@pytest.mark.freeze_time(datetime(2020, 9, 14, tzinfo=timezone.utc))
+async def test_acquire_release_schedules(store, schedules, events):
+ for schedule in schedules:
+ await store.add_schedule(schedule, ConflictPolicy.exception)
+
+ events.clear()
+
+ # The first scheduler gets the first due schedule
+ schedules1 = await store.acquire_schedules('dummy-id1', 1)
+ assert len(schedules1) == 1
+ assert schedules1[0].id == 's1'
+
+ # The second scheduler gets the second due schedule
+ schedules2 = await store.acquire_schedules('dummy-id2', 1)
+ assert len(schedules2) == 1
+ assert schedules2[0].id == 's2'
+
+ # The third scheduler gets nothing
+ async with move_on_after(0.2):
+ await store.acquire_schedules('dummy-id3', 1)
+ pytest.fail('The call should have timed out')
+
+ # The schedules here have run their course, and releasing them should delete them
+ schedules1[0].next_fire_time = None
+ schedules2[0].next_fire_time = datetime(2020, 9, 15, tzinfo=timezone.utc)
+ await store.release_schedules('dummy-id1', schedules1)
+ await store.release_schedules('dummy-id2', schedules2)
+
+ # Check that the first schedule is gone
+ schedules = await store.get_schedules()
+ assert len(schedules) == 2
+ assert schedules[0].id == 's2'
+ assert schedules[1].id == 's3'
+
+ # Check for the appropriate update and delete events
+ assert len(events) == 2
+ assert isinstance(events[0], ScheduleRemoved)
+ assert isinstance(events[1], ScheduleUpdated)
+ assert events[0].schedule_id == 's1'
+ assert events[1].schedule_id == 's2'
+ assert events[1].next_fire_time == datetime(2020, 9, 15, tzinfo=timezone.utc)
+
+
+async def test_acquire_schedules_lock_timeout(store, schedules, events, freezer):
+ """
+ Test that a scheduler can acquire schedules that were acquired by another scheduler but not
+ released within the lock timeout period.
+
+ """
+ # First, one scheduler acquires the first available schedule
+ await store.add_schedule(schedules[0], ConflictPolicy.exception)
+ acquired = await store.acquire_schedules('dummy-id1', 1)
+ assert len(acquired) == 1
+ assert acquired[0].id == 's1'
+
+ # Try to acquire the schedule just at the threshold (now == acquired_until).
+ # This should not yield any schedules.
+ freezer.tick(30)
+ async with move_on_after(0.2):
+ await store.acquire_schedules('dummy-id2', 1)
+ pytest.fail('The call should have timed out')
+
+ # Right after that, the schedule should be available
+ freezer.tick(1)
+ acquired = await store.acquire_schedules('dummy-id2', 1)
+ assert len(acquired) == 1
+ assert acquired[0].id == 's1'
+
+
+async def test_acquire_release_jobs(store, jobs, events):
+ for job in jobs:
+ await store.add_job(job)
+
+ events.clear()
+
+ # The first worker gets the first job in the queue
+ jobs1 = await store.acquire_jobs('dummy-id1', 1)
+ assert len(jobs1) == 1
+ assert jobs1[0].id == jobs[0].id
+
+ # The second worker gets the second job
+ jobs2 = await store.acquire_jobs('dummy-id2', 1)
+ assert len(jobs2) == 1
+ assert jobs2[0].id == jobs[1].id
+
+ # The third worker gets nothing
+ async with move_on_after(0.2):
+ await store.acquire_jobs('dummy-id3', 1)
+ pytest.fail('The call should have timed out')
+
+ # All the jobs should still be returned
+ visible_jobs = await store.get_jobs()
+ assert len(visible_jobs) == 2
+
+ await store.release_jobs('dummy-id1', jobs1)
+ await store.release_jobs('dummy-id2', jobs2)
+
+ # All the jobs should be gone
+ visible_jobs = await store.get_jobs()
+ assert len(visible_jobs) == 0
+
+ # Check for the appropriate update and delete events
+ # assert len(events) == 2
+ # assert isinstance(events[0], Job)
+ # assert isinstance(events[1], SchedulesUpdated)
+ # assert events[0].schedule_ids == {'s1'}
+ # assert events[1].schedule_ids == {'s2'}
+ # assert events[1].next_fire_time == datetime(2020, 9, 15, tzinfo=timezone.utc)
+
+
+async def test_acquire_jobs_lock_timeout(store, jobs, events, freezer):
+ """
+    Test that a worker can acquire jobs that were acquired by another worker but not
+ released within the lock timeout period.
+
+ """
+ # First, one worker acquires the first available job
+ await store.add_job(jobs[0])
+ acquired = await store.acquire_jobs('dummy-id1', 1)
+ assert len(acquired) == 1
+ assert acquired[0].id == jobs[0].id
+
+ # Try to acquire the job just at the threshold (now == acquired_until).
+ # This should not yield any jobs.
+ freezer.tick(30)
+ async with move_on_after(0.2):
+ await store.acquire_jobs('dummy-id2', 1)
+ pytest.fail('The call should have timed out')
+
+ # Right after that, the job should be available
+ freezer.tick(1)
+ acquired = await store.acquire_jobs('dummy-id2', 1)
+ assert len(acquired) == 1
+ assert acquired[0].id == jobs[0].id
diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py
new file mode 100644
index 0000000..b155f6c
--- /dev/null
+++ b/tests/test_schedulers.py
@@ -0,0 +1,59 @@
+import logging
+from datetime import datetime, timezone
+
+import pytest
+
+from apscheduler.events import JobSuccessful
+from apscheduler.schedulers.async_ import AsyncScheduler
+from apscheduler.schedulers.sync import SyncScheduler
+from apscheduler.triggers.date import DateTrigger
+
+pytestmark = pytest.mark.anyio
+
+
+async def dummy_async_job():
+ return 'returnvalue'
+
+
+def dummy_sync_job():
+ return 'returnvalue'
+
+
+class TestAsyncScheduler:
+ async def test_schedule_job(self, caplog, store):
+ async def listener(event):
+ events.append(event)
+ if isinstance(event, JobSuccessful):
+ await scheduler.stop()
+
+ caplog.set_level(logging.DEBUG)
+ trigger = DateTrigger(datetime.now(timezone.utc))
+ events = []
+ async with AsyncScheduler(store) as scheduler:
+ scheduler.worker.subscribe(listener)
+ await scheduler.add_schedule(dummy_async_job, trigger)
+ await scheduler.wait_until_stopped()
+
+ assert len(events) == 2
+ assert isinstance(events[1], JobSuccessful)
+ assert events[1].return_value == 'returnvalue'
+
+
+class TestSyncScheduler:
+ @pytest.mark.parametrize('anyio_backend', ['asyncio'])
+ def test_schedule_job(self, caplog, anyio_backend, sync_store, portal):
+ def listener(event):
+ events.append(event)
+ if isinstance(event, JobSuccessful):
+ scheduler.stop()
+
+ caplog.set_level(logging.DEBUG)
+ events = []
+ with SyncScheduler(sync_store, portal=portal) as scheduler:
+ scheduler.worker.subscribe(listener)
+ scheduler.add_schedule(dummy_sync_job, DateTrigger(datetime.now(timezone.utc)))
+ scheduler.wait_until_stopped()
+
+ assert len(events) == 2
+ assert isinstance(events[1], JobSuccessful)
+ assert events[1].return_value == 'returnvalue'
diff --git a/tests/test_workers.py b/tests/test_workers.py
new file mode 100644
index 0000000..89f88e3
--- /dev/null
+++ b/tests/test_workers.py
@@ -0,0 +1,152 @@
+import threading
+from datetime import datetime
+
+import pytest
+from anyio import create_event, fail_after
+
+from apscheduler.abc import Job
+from apscheduler.events import JobDeadlineMissed, JobFailed, JobSuccessful, JobUpdated
+from apscheduler.workers.async_ import AsyncWorker
+from apscheduler.workers.sync import SyncWorker
+
+pytestmark = pytest.mark.anyio
+
+
+def sync_func(*args, fail: bool, **kwargs):
+ if fail:
+ raise Exception('failing as requested')
+ else:
+ return args, kwargs
+
+
+async def async_func(*args, fail: bool, **kwargs):
+ if fail:
+ raise Exception('failing as requested')
+ else:
+ return args, kwargs
+
+
+def fail_func():
+ pytest.fail('This function should never be run')
+
+
+class TestAsyncWorker:
+ @pytest.mark.parametrize('target_func', [sync_func, async_func], ids=['sync', 'async'])
+ @pytest.mark.parametrize('fail', [False, True], ids=['success', 'fail'])
+ @pytest.mark.parametrize('anyio_backend', ['asyncio'])
+ async def test_run_job_nonscheduled_success(self, target_func, fail, store):
+ async def listener(worker_event):
+ worker_events.append(worker_event)
+ if len(worker_events) == 2:
+ await event.set()
+
+ worker_events = []
+ event = create_event()
+ job = Job('task_id', func=target_func, args=(1, 2), kwargs={'x': 'foo', 'fail': fail})
+ async with AsyncWorker(store) as worker:
+ worker.subscribe(listener)
+ await store.add_job(job)
+ async with fail_after(2):
+ await event.wait()
+
+ assert len(worker_events) == 2
+
+ assert isinstance(worker_events[0], JobUpdated)
+ assert worker_events[0].job_id == job.id
+ assert worker_events[0].task_id == 'task_id'
+ assert worker_events[0].schedule_id is None
+
+ assert worker_events[1].job_id == job.id
+ assert worker_events[1].task_id == 'task_id'
+ assert worker_events[1].schedule_id is None
+ if fail:
+ assert isinstance(worker_events[1], JobFailed)
+ assert type(worker_events[1].exception) is Exception
+ assert isinstance(worker_events[1].traceback, str)
+ else:
+ assert isinstance(worker_events[1], JobSuccessful)
+ assert worker_events[1].return_value == ((1, 2), {'x': 'foo'})
+
+ @pytest.mark.parametrize('anyio_backend', ['asyncio'])
+ async def test_run_deadline_missed(self, store):
+ async def listener(worker_event):
+ worker_events.append(worker_event)
+ await event.set()
+
+ scheduled_start_time = datetime(2020, 9, 14)
+ worker_events = []
+ event = create_event()
+ job = Job('task_id', fail_func, args=(), kwargs={}, schedule_id='foo',
+ scheduled_fire_time=scheduled_start_time,
+ start_deadline=datetime(2020, 9, 14, 1))
+ async with AsyncWorker(store) as worker:
+ worker.subscribe(listener)
+ await store.add_job(job)
+ async with fail_after(5):
+ await event.wait()
+
+ assert len(worker_events) == 1
+ assert isinstance(worker_events[0], JobDeadlineMissed)
+ assert worker_events[0].job_id == job.id
+ assert worker_events[0].task_id == 'task_id'
+ assert worker_events[0].schedule_id == 'foo'
+ assert worker_events[0].scheduled_fire_time == scheduled_start_time
+
+
+class TestSyncWorker:
+ @pytest.mark.parametrize('target_func', [sync_func, async_func], ids=['sync', 'async'])
+ @pytest.mark.parametrize('fail', [False, True], ids=['success', 'fail'])
+ def test_run_job_nonscheduled(self, anyio_backend, target_func, fail, sync_store, portal):
+ def listener(worker_event):
+ print('received event:', worker_event)
+ worker_events.append(worker_event)
+ if len(worker_events) == 2:
+ event.set()
+
+ worker_events = []
+ event = threading.Event()
+ job = Job('task_id', func=target_func, args=(1, 2), kwargs={'x': 'foo', 'fail': fail})
+ with SyncWorker(sync_store, portal=portal) as worker:
+ worker.subscribe(listener)
+ portal.call(sync_store.add_job, job)
+ event.wait(2)
+
+ assert len(worker_events) == 2
+
+ assert isinstance(worker_events[0], JobUpdated)
+ assert worker_events[0].job_id == job.id
+ assert worker_events[0].task_id == 'task_id'
+ assert worker_events[0].schedule_id is None
+
+ assert worker_events[1].job_id == job.id
+ assert worker_events[1].task_id == 'task_id'
+ assert worker_events[1].schedule_id is None
+ if fail:
+ assert isinstance(worker_events[1], JobFailed)
+ assert type(worker_events[1].exception) is Exception
+ assert isinstance(worker_events[1].traceback, str)
+ else:
+ assert isinstance(worker_events[1], JobSuccessful)
+ assert worker_events[1].return_value == ((1, 2), {'x': 'foo'})
+
+ def test_run_deadline_missed(self, anyio_backend, sync_store, portal):
+ def listener(worker_event):
+ worker_events.append(worker_event)
+ event.set()
+
+ scheduled_start_time = datetime(2020, 9, 14)
+ worker_events = []
+ event = threading.Event()
+ job = Job('task_id', fail_func, args=(), kwargs={}, schedule_id='foo',
+ scheduled_fire_time=scheduled_start_time,
+ start_deadline=datetime(2020, 9, 14, 1))
+ with SyncWorker(sync_store, portal=portal) as worker:
+ worker.subscribe(listener)
+ portal.call(sync_store.add_job, job)
+ event.wait(5)
+
+ assert len(worker_events) == 1
+ assert isinstance(worker_events[0], JobDeadlineMissed)
+ assert worker_events[0].job_id == job.id
+ assert worker_events[0].task_id == 'task_id'
+ assert worker_events[0].schedule_id == 'foo'
diff --git a/tests/workers/test_local.py b/tests/workers/test_local.py
deleted file mode 100644
index d72d559..0000000
--- a/tests/workers/test_local.py
+++ /dev/null
@@ -1,101 +0,0 @@
-from datetime import datetime
-
-import pytest
-from anyio import fail_after, sleep
-from apscheduler.abc import Job
-from apscheduler.events import JobAdded, JobDeadlineMissed, JobFailed, JobSuccessful, JobUpdated
-from apscheduler.workers.local import LocalExecutor
-
-pytestmark = pytest.mark.anyio
-
-
-@pytest.mark.parametrize('sync', [True, False], ids=['sync', 'async'])
-@pytest.mark.parametrize('fail', [False, True], ids=['success', 'fail'])
-async def test_run_job_nonscheduled_success(sync, fail):
- def sync_func(*args, **kwargs):
- nonlocal received_args, received_kwargs
- received_args = args
- received_kwargs = kwargs
- if fail:
- raise Exception('failing as requested')
- else:
- return 'success'
-
- async def async_func(*args, **kwargs):
- nonlocal received_args, received_kwargs
- received_args = args
- received_kwargs = kwargs
- if fail:
- raise Exception('failing as requested')
- else:
- return 'success'
-
- received_args = received_kwargs = None
- events = []
- async with LocalExecutor() as worker:
- await worker.subscribe(events.append)
-
- job = Job('task_id', sync_func if sync else async_func, args=(1, 2), kwargs={'x': 'foo'})
- await worker.submit_job(job)
-
- async with fail_after(1):
- while len(events) < 3:
- await sleep(0)
-
- assert received_args == (1, 2)
- assert received_kwargs == {'x': 'foo'}
-
- assert isinstance(events[0], JobAdded)
- assert events[0].job_id == job.id
- assert events[0].task_id == 'task_id'
- assert events[0].schedule_id is None
- assert events[0].scheduled_start_time is None
-
- assert isinstance(events[1], JobUpdated)
- assert events[1].job_id == job.id
- assert events[1].task_id == 'task_id'
- assert events[1].schedule_id is None
- assert events[1].scheduled_start_time is None
-
- assert events[2].job_id == job.id
- assert events[2].task_id == 'task_id'
- assert events[2].schedule_id is None
- assert events[2].scheduled_start_time is None
- if fail:
- assert isinstance(events[2], JobFailed)
- assert type(events[2].exception) is Exception
- assert isinstance(events[2].formatted_traceback, str)
- else:
- assert isinstance(events[2], JobSuccessful)
- assert events[2].return_value == 'success'
-
-
-async def test_run_deadline_missed():
- def func():
- pytest.fail('This function should never be run')
-
- scheduled_start_time = datetime(2020, 9, 14)
- events = []
- async with LocalExecutor() as worker:
- await worker.subscribe(events.append)
-
- job = Job('task_id', func, args=(), kwargs={}, schedule_id='foo',
- scheduled_start_time=scheduled_start_time,
- start_deadline=datetime(2020, 9, 14, 1))
- await worker.submit_job(job)
-
- async with fail_after(1):
- while len(events) < 2:
- await sleep(0)
-
- assert isinstance(events[0], JobAdded)
- assert events[0].job_id == job.id
- assert events[0].task_id == 'task_id'
- assert events[0].schedule_id == 'foo'
- assert events[0].scheduled_start_time == scheduled_start_time
-
- assert isinstance(events[1], JobDeadlineMissed)
- assert events[1].job_id == job.id
- assert events[1].task_id == 'task_id'
- assert events[1].schedule_id == 'foo'
- assert events[1].scheduled_start_time == scheduled_start_time