From a0b7da22086b1704f52a3a34a3edb25f2bcd5d7e Mon Sep 17 00:00:00 2001
From: Alex Grönholm
Date: Sun, 29 Aug 2021 15:58:06 +0300
Subject: Switched to the src/ layout

---
 src/apscheduler/__init__.py                     |  10 +
 src/apscheduler/abc.py                          | 278 ++++++++++++++
 src/apscheduler/adapters.py                     |  67 ++++
 src/apscheduler/datastores/__init__.py          |   0
 src/apscheduler/datastores/async_/__init__.py   |   0
 src/apscheduler/datastores/async_/sqlalchemy.py | 421 ++++++++++++++++++++++++
 src/apscheduler/datastores/sync/__init__.py     |   0
 src/apscheduler/datastores/sync/memory.py       | 232 +++++++++++++
 src/apscheduler/datastores/sync/mongodb.py      | 254 ++++++++++++++
 src/apscheduler/datastores/sync/sqlalchemy.py   | 325 ++++++++++++++++++
 src/apscheduler/enums.py                        |   8 +
 src/apscheduler/events.py                       | 251 ++++++++++++++
 src/apscheduler/exceptions.py                   |  59 ++++
 src/apscheduler/marshalling.py                  | 123 +++++++
 src/apscheduler/policies.py                     |  19 ++
 src/apscheduler/schedulers/__init__.py          |   0
 src/apscheduler/schedulers/async_.py            | 224 +++++++++++++
 src/apscheduler/schedulers/sync.py              | 209 ++++++++++++
 src/apscheduler/serializers/__init__.py         |   0
 src/apscheduler/serializers/cbor.py             |  36 ++
 src/apscheduler/serializers/json.py             |  45 +++
 src/apscheduler/serializers/pickle.py           |  15 +
 src/apscheduler/structures.py                   |  81 +++++
 src/apscheduler/triggers/__init__.py            |   0
 src/apscheduler/triggers/calendarinterval.py    | 145 ++++++++
 src/apscheduler/triggers/combining.py           | 141 ++++++++
 src/apscheduler/triggers/cron/__init__.py       | 216 ++++++++++++
 src/apscheduler/triggers/cron/expressions.py    | 209 ++++++++++++
 src/apscheduler/triggers/cron/fields.py         | 130 ++++++++
 src/apscheduler/triggers/date.py                |  45 +++
 src/apscheduler/triggers/interval.py            |  99 ++++++
 src/apscheduler/util.py                         |  82 +++++
 src/apscheduler/validators.py                   | 165 ++++++++++
 src/apscheduler/workers/__init__.py             |   0
 src/apscheduler/workers/async_.py               | 249 ++++++++++++++
 src/apscheduler/workers/sync.py                 | 161 +++++++++
 36 files changed, 4299 insertions(+)
 create mode 100644 src/apscheduler/__init__.py
 create mode 100644 src/apscheduler/abc.py
 create mode 100644 src/apscheduler/adapters.py
 create mode 100644 src/apscheduler/datastores/__init__.py
 create mode 100644 src/apscheduler/datastores/async_/__init__.py
 create mode 100644 src/apscheduler/datastores/async_/sqlalchemy.py
 create mode 100644 src/apscheduler/datastores/sync/__init__.py
 create mode 100644 src/apscheduler/datastores/sync/memory.py
 create mode 100644 src/apscheduler/datastores/sync/mongodb.py
 create mode 100644 src/apscheduler/datastores/sync/sqlalchemy.py
 create mode 100644 src/apscheduler/enums.py
 create mode 100644 src/apscheduler/events.py
 create mode 100644 src/apscheduler/exceptions.py
 create mode 100644 src/apscheduler/marshalling.py
 create mode 100644 src/apscheduler/policies.py
 create mode 100644 src/apscheduler/schedulers/__init__.py
 create mode 100644 src/apscheduler/schedulers/async_.py
 create mode 100644 src/apscheduler/schedulers/sync.py
 create mode 100644 src/apscheduler/serializers/__init__.py
 create mode 100644 src/apscheduler/serializers/cbor.py
 create mode 100644 src/apscheduler/serializers/json.py
 create mode 100644 src/apscheduler/serializers/pickle.py
 create mode 100644 src/apscheduler/structures.py
 create mode 100644 src/apscheduler/triggers/__init__.py
 create mode 100644 src/apscheduler/triggers/calendarinterval.py
 create mode 100644 src/apscheduler/triggers/combining.py
 create mode 100644 src/apscheduler/triggers/cron/__init__.py
 create mode 100644 src/apscheduler/triggers/cron/expressions.py
 create mode
100644 src/apscheduler/triggers/cron/fields.py create mode 100644 src/apscheduler/triggers/date.py create mode 100644 src/apscheduler/triggers/interval.py create mode 100644 src/apscheduler/util.py create mode 100644 src/apscheduler/validators.py create mode 100644 src/apscheduler/workers/__init__.py create mode 100644 src/apscheduler/workers/async_.py create mode 100644 src/apscheduler/workers/sync.py (limited to 'src') diff --git a/src/apscheduler/__init__.py b/src/apscheduler/__init__.py new file mode 100644 index 0000000..d9b86bb --- /dev/null +++ b/src/apscheduler/__init__.py @@ -0,0 +1,10 @@ +from pkg_resources import DistributionNotFound, get_distribution + +try: + release = get_distribution('APScheduler').version.split('-')[0] +except DistributionNotFound: + release = '3.5.0' + +version_info = tuple(int(x) if x.isdigit() else x for x in release.split('.')) +version = __version__ = '.'.join(str(x) for x in version_info[:3]) +del get_distribution, DistributionNotFound diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py new file mode 100644 index 0000000..2b88a86 --- /dev/null +++ b/src/apscheduler/abc.py @@ -0,0 +1,278 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from base64 import b64decode, b64encode +from datetime import datetime +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, List, Optional, Set, Type +from uuid import UUID + +from .policies import ConflictPolicy +from .structures import Job, Schedule + +if TYPE_CHECKING: + from . import events + + +class Trigger(Iterator[datetime], metaclass=ABCMeta): + """Abstract base class that defines the interface that every trigger must implement.""" + + __slots__ = () + + @abstractmethod + def next(self) -> Optional[datetime]: + """ + Return the next datetime to fire on. + + If no such datetime can be calculated, ``None`` is returned. + :raises apscheduler.exceptions.MaxIterationsReached: + """ + + @abstractmethod + def __getstate__(self): + """Return the (JSON compatible) serializable state of the trigger.""" + + @abstractmethod + def __setstate__(self, state): + """Initialize an empty instance from an existing state.""" + + def __iter__(self): + return self + + def __next__(self) -> datetime: + dateval = self.next() + if dateval is None: + raise StopIteration + else: + return dateval + + +class Serializer(metaclass=ABCMeta): + __slots__ = () + + @abstractmethod + def serialize(self, obj) -> bytes: + pass + + def serialize_to_unicode(self, obj) -> str: + return b64encode(self.serialize(obj)).decode('ascii') + + @abstractmethod + def deserialize(self, serialized: bytes): + pass + + def deserialize_from_unicode(self, serialized: str): + return self.deserialize(b64decode(serialized)) + + +class EventSource(metaclass=ABCMeta): + @abstractmethod + def subscribe( + self, callback: Callable[[events.Event], Any], + event_types: Optional[Iterable[Type[events.Event]]] = None + ) -> events.SubscriptionToken: + """ + Subscribe to events from this event source. + + :param callback: callable to be called with the event object when an event is published + :param event_types: an iterable of concrete Event classes to subscribe to + """ + + @abstractmethod + def unsubscribe(self, token: events.SubscriptionToken) -> None: + """ + Cancel an event subscription. 
+ + :param token: a token returned from :meth:`subscribe` + """ + + +class DataStore(EventSource): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + @abstractmethod + def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: + """ + Get schedules from the data store. + + :param ids: a specific set of schedule IDs to return, or ``None`` to return all schedules + :return: the list of matching schedules, in unspecified order + """ + + @abstractmethod + def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + """ + Add or update the given schedule in the data store. + + :param schedule: schedule to be added + :param conflict_policy: policy that determines what to do if there is an existing schedule + with the same ID + """ + + @abstractmethod + def remove_schedules(self, ids: Iterable[str]) -> None: + """ + Remove schedules from the data store. + + :param ids: a specific set of schedule IDs to remove + """ + + @abstractmethod + def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]: + """ + Acquire unclaimed due schedules for processing. + + This method claims up to the requested number of schedules for the given scheduler and + returns them. + + :param scheduler_id: unique identifier of the scheduler + :param limit: maximum number of schedules to claim + :return: the list of claimed schedules + """ + + @abstractmethod + def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None: + """ + Release the claims on the given schedules and update them on the store. + + :param scheduler_id: unique identifier of the scheduler + :param schedules: the previously claimed schedules + """ + + @abstractmethod + def add_job(self, job: Job) -> None: + """ + Add a job to be executed by an eligible worker. + + :param job: the job object + """ + + @abstractmethod + def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]: + """ + Get the list of pending jobs. + + :param ids: a specific set of job IDs to return, or ``None`` to return all jobs + :return: the list of matching pending jobs, in the order they will be given to workers + """ + + @abstractmethod + def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: + """ + Acquire unclaimed jobs for execution. + + This method claims up to the requested number of jobs for the given worker and returns + them. + + :param worker_id: unique identifier of the worker + :param limit: maximum number of jobs to claim and return + :return: the list of claimed jobs + """ + + @abstractmethod + def release_jobs(self, worker_id: str, jobs: List[Job]) -> None: + """ + Releases the claim on the given jobs + + :param worker_id: unique identifier of the worker + :param jobs: the previously claimed jobs + """ + + +class AsyncDataStore(EventSource): + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + @abstractmethod + async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: + """ + Get schedules from the data store. + + :param ids: a specific set of schedule IDs to return, or ``None`` to return all schedules + :return: the list of matching schedules, in unspecified order + """ + + @abstractmethod + async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + """ + Add or update the given schedule in the data store. 
+ + :param schedule: schedule to be added + :param conflict_policy: policy that determines what to do if there is an existing schedule + with the same ID + """ + + @abstractmethod + async def remove_schedules(self, ids: Iterable[str]) -> None: + """ + Remove schedules from the data store. + + :param ids: a specific set of schedule IDs to remove + """ + + @abstractmethod + async def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]: + """ + Acquire unclaimed due schedules for processing. + + This method claims up to the requested number of schedules for the given scheduler and + returns them. + + :param scheduler_id: unique identifier of the scheduler + :param limit: maximum number of schedules to claim + :return: the list of claimed schedules + """ + + @abstractmethod + async def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None: + """ + Release the claims on the given schedules and update them on the store. + + :param scheduler_id: unique identifier of the scheduler + :param schedules: the previously claimed schedules + """ + + @abstractmethod + async def add_job(self, job: Job) -> None: + """ + Add a job to be executed by an eligible worker. + + :param job: the job object + """ + + @abstractmethod + async def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]: + """ + Get the list of pending jobs. + + :param ids: a specific set of job IDs to return, or ``None`` to return all jobs + :return: the list of matching pending jobs, in the order they will be given to workers + """ + + @abstractmethod + async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: + """ + Acquire unclaimed jobs for execution. + + This method claims up to the requested number of jobs for the given worker and returns + them. + + :param worker_id: unique identifier of the worker + :param limit: maximum number of jobs to claim and return + :return: the list of claimed jobs + """ + + @abstractmethod + async def release_jobs(self, worker_id: str, jobs: List[Job]) -> None: + """ + Releases the claim on the given jobs + + :param worker_id: unique identifier of the worker + :param jobs: the previously claimed jobs + """ diff --git a/src/apscheduler/adapters.py b/src/apscheduler/adapters.py new file mode 100644 index 0000000..19b1e97 --- /dev/null +++ b/src/apscheduler/adapters.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Callable, Iterable, List, Optional, Set, Type +from uuid import UUID + +from anyio import to_thread +from anyio.from_thread import BlockingPortal + +from . 
import events +from .abc import AsyncDataStore, DataStore +from .events import Event, SubscriptionToken +from .policies import ConflictPolicy +from .structures import Job, Schedule +from .util import reentrant + + +@reentrant +@dataclass +class AsyncDataStoreAdapter(AsyncDataStore): + original: DataStore + _portal: BlockingPortal = field(init=False) + + async def __aenter__(self) -> AsyncDataStoreAdapter: + self._portal = BlockingPortal() + await self._portal.__aenter__() + await to_thread.run_sync(self.original.__enter__) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb) + await self._portal.__aexit__(exc_type, exc_val, exc_tb) + + async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: + return await to_thread.run_sync(self.original.get_schedules, ids) + + async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + await to_thread.run_sync(self.original.add_schedule, schedule, conflict_policy) + + async def remove_schedules(self, ids: Iterable[str]) -> None: + await to_thread.run_sync(self.original.remove_schedules, ids) + + async def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]: + return await to_thread.run_sync(self.original.acquire_schedules, scheduler_id, limit) + + async def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None: + await to_thread.run_sync(self.original.release_schedules, scheduler_id, schedules) + + async def add_job(self, job: Job) -> None: + await to_thread.run_sync(self.original.add_job, job) + + async def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]: + return await to_thread.run_sync(self.original.get_jobs, ids) + + async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: + return await to_thread.run_sync(self.original.acquire_jobs, worker_id, limit) + + async def release_jobs(self, worker_id: str, jobs: List[Job]) -> None: + return await to_thread.run_sync(self.original.release_jobs, worker_id, jobs) + + def subscribe(self, callback: Callable[[Event], Any], + event_types: Optional[Iterable[Type[Event]]] = None) -> SubscriptionToken: + return self.original.subscribe(partial(self._portal.call, callback), event_types) + + def unsubscribe(self, token: events.SubscriptionToken) -> None: + self.original.unsubscribe(token) diff --git a/src/apscheduler/datastores/__init__.py b/src/apscheduler/datastores/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/apscheduler/datastores/async_/__init__.py b/src/apscheduler/datastores/async_/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/apscheduler/datastores/async_/sqlalchemy.py b/src/apscheduler/datastores/async_/sqlalchemy.py new file mode 100644 index 0000000..a83d95d --- /dev/null +++ b/src/apscheduler/datastores/async_/sqlalchemy.py @@ -0,0 +1,421 @@ +from __future__ import annotations + +import json +import logging +from contextlib import AsyncExitStack, closing +from datetime import datetime, timedelta, timezone +from json import JSONDecodeError +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from uuid import UUID + +import sniffio +from anyio import TASK_STATUS_IGNORED, create_task_group, sleep +from attr import asdict +from sqlalchemy import ( + Column, DateTime, Integer, LargeBinary, MetaData, Table, Unicode, and_, bindparam, func, or_, + select) +from sqlalchemy.engine 
import URL +from sqlalchemy.exc import CompileError, IntegrityError +from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine +from sqlalchemy.ext.asyncio.engine import AsyncConnectable +from sqlalchemy.sql.ddl import DropTable + +from ... import events as events_module +from ...abc import AsyncDataStore, Job, Schedule, Serializer +from ...events import ( + AsyncEventHub, DataStoreEvent, Event, JobAdded, JobDeserializationFailed, ScheduleAdded, + ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken) +from ...exceptions import ConflictingIdError, SerializationError +from ...policies import ConflictPolicy +from ...serializers.pickle import PickleSerializer +from ...util import reentrant + +logger = logging.getLogger(__name__) + + +def default_json_handler(obj: Any) -> Any: + if isinstance(obj, datetime): + return obj.timestamp() + elif isinstance(obj, UUID): + return obj.hex + elif isinstance(obj, frozenset): + return list(obj) + + raise TypeError(f'Cannot JSON encode type {type(obj)}') + + +def json_object_hook(obj: Dict[str, Any]) -> Any: + for key, value in obj.items(): + if key == 'timestamp': + obj[key] = datetime.fromtimestamp(value, timezone.utc) + elif key == 'job_id': + obj[key] = UUID(value) + elif key == 'tags': + obj[key] = frozenset(value) + + return obj + + +@reentrant +class SQLAlchemyDataStore(AsyncDataStore): + _metadata = MetaData() + + t_metadata = Table( + 'metadata', + _metadata, + Column('schema_version', Integer, nullable=False) + ) + t_schedules = Table( + 'schedules', + _metadata, + Column('id', Unicode, primary_key=True), + Column('task_id', Unicode, nullable=False), + Column('serialized_data', LargeBinary, nullable=False), + Column('next_fire_time', DateTime(timezone=True), index=True), + Column('acquired_by', Unicode), + Column('acquired_until', DateTime(timezone=True)) + ) + t_jobs = Table( + 'jobs', + _metadata, + Column('id', Unicode(32), primary_key=True), + Column('task_id', Unicode, nullable=False, index=True), + Column('serialized_data', LargeBinary, nullable=False), + Column('created_at', DateTime(timezone=True), nullable=False), + Column('acquired_by', Unicode), + Column('acquired_until', DateTime(timezone=True)) + ) + + def __init__(self, bind: AsyncConnectable, *, schema: Optional[str] = None, + serializer: Optional[Serializer] = None, + lock_expiration_delay: float = 30, max_poll_time: Optional[float] = 1, + max_idle_time: float = 60, start_from_scratch: bool = False, + notify_channel: Optional[str] = 'apscheduler'): + self.bind = bind + self.schema = schema + self.serializer = serializer or PickleSerializer() + self.lock_expiration_delay = lock_expiration_delay + self.max_poll_time = max_poll_time + self.max_idle_time = max_idle_time + self.start_from_scratch = start_from_scratch + self._logger = logging.getLogger(__name__) + self._exit_stack = AsyncExitStack() + self._events = AsyncEventHub() + + # Find out if the dialect supports RETURNING + statement = self.t_jobs.update().returning(self.t_schedules.c.id) + try: + statement.compile(bind=self.bind) + except CompileError: + self._supports_update_returning = False + else: + self._supports_update_returning = True + + self.notify_channel = notify_channel + if notify_channel: + if self.bind.dialect.name != 'postgresql' or self.bind.dialect.driver != 'asyncpg': + self.notify_channel = None + + @classmethod + def from_url(cls, url: Union[str, URL], **options) -> 'SQLAlchemyDataStore': + engine = create_async_engine(url, future=True) + return cls(engine, 
**options) + + async def __aenter__(self): + asynclib = sniffio.current_async_library() or '(unknown)' + if asynclib != 'asyncio': + raise RuntimeError(f'This data store requires asyncio; currently running: {asynclib}') + + # Verify that the schema is in place + async with self.bind.begin() as conn: + if self.start_from_scratch: + for table in self._metadata.sorted_tables: + await conn.execute(DropTable(table, if_exists=True)) + + await conn.run_sync(self._metadata.create_all) + query = select(self.t_metadata.c.schema_version) + result = await conn.execute(query) + version = result.scalar() + if version is None: + await conn.execute(self.t_metadata.insert(values={'schema_version': 1})) + elif version > 1: + raise RuntimeError(f'Unexpected schema version ({version}); ' + f'only version 1 is supported by this version of APScheduler') + + await self._exit_stack.enter_async_context(self._events) + + if self.notify_channel: + task_group = create_task_group() + await self._exit_stack.enter_async_context(task_group) + await task_group.start(self._listen_notifications) + self._exit_stack.callback(task_group.cancel_scope.cancel) + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + + async def _publish(self, conn: AsyncConnection, event: DataStoreEvent) -> None: + if self.notify_channel: + event_type = event.__class__.__name__ + event_data = json.dumps(asdict(event), ensure_ascii=False, + default=default_json_handler) + notification = event_type + ' ' + event_data + if len(notification) < 8000: + await conn.execute(func.pg_notify(self.notify_channel, notification)) + return + + self._logger.warning( + 'Could not send %s notification because it is too long (%d >= 8000)', + event_type, len(notification)) + + self._events.publish(event) + + async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None: + def callback(connection, pid, channel: str, payload: str) -> None: + self._logger.debug('Received notification on channel %s: %s', channel, payload) + event_type, _, json_data = payload.partition(' ') + try: + event_data = json.loads(json_data, object_hook=json_object_hook) + except JSONDecodeError: + self._logger.exception('Failed decoding JSON payload of notification: %s', payload) + return + + event_class = getattr(events_module, event_type) + event = event_class(**event_data) + self._events.publish(event) + + task_started_sent = False + while True: + with closing(await self.bind.raw_connection()) as conn: + asyncpg_conn = conn.connection._connection + await asyncpg_conn.add_listener(self.notify_channel, callback) + if not task_started_sent: + task_status.started() + task_started_sent = True + + try: + while True: + await sleep(self.max_idle_time) + await asyncpg_conn.execute('SELECT 1') + finally: + await asyncpg_conn.remove_listener(self.notify_channel, callback) + + def _deserialize_jobs(self, serialized_jobs: Iterable[Tuple[UUID, bytes]]) -> List[Job]: + jobs: List[Job] = [] + for job_id, serialized_data in serialized_jobs: + try: + jobs.append(self.serializer.deserialize(serialized_data)) + except SerializationError as exc: + self._events.publish(JobDeserializationFailed(job_id=job_id, exception=exc)) + + return jobs + + def _deserialize_schedules( + self, serialized_schedules: Iterable[Tuple[str, bytes]]) -> List[Schedule]: + jobs: List[Schedule] = [] + for schedule_id, serialized_data in serialized_schedules: + try: + jobs.append(self.serializer.deserialize(serialized_data)) + except 
SerializationError as exc: + self._events.publish( + ScheduleDeserializationFailed(schedule_id=schedule_id, exception=exc)) + + return jobs + + def subscribe(self, callback: Callable[[Event], Any], + event_types: Optional[Iterable[Type[Event]]] = None) -> SubscriptionToken: + return self._events.subscribe(callback, event_types) + + def unsubscribe(self, token: SubscriptionToken) -> None: + self._events.unsubscribe(token) + + async def clear(self) -> None: + async with self.bind.begin() as conn: + await conn.execute(self.t_schedules.delete()) + await conn.execute(self.t_jobs.delete()) + + async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + serialized_data = self.serializer.serialize(schedule) + statement = self.t_schedules.insert().\ + values(id=schedule.id, task_id=schedule.task_id, serialized_data=serialized_data, + next_fire_time=schedule.next_fire_time) + try: + async with self.bind.begin() as conn: + await conn.execute(statement) + event = ScheduleAdded(schedule_id=schedule.id, + next_fire_time=schedule.next_fire_time) + await self._publish(conn, event) + except IntegrityError: + if conflict_policy is ConflictPolicy.exception: + raise ConflictingIdError(schedule.id) from None + elif conflict_policy is ConflictPolicy.replace: + statement = self.t_schedules.update().\ + where(self.t_schedules.c.id == schedule.id).\ + values(serialized_data=serialized_data, + next_fire_time=schedule.next_fire_time) + async with self.bind.begin() as conn: + await conn.execute(statement) + + event = ScheduleUpdated(schedule_id=schedule.id, + next_fire_time=schedule.next_fire_time) + await self._publish(conn, event) + + async def remove_schedules(self, ids: Iterable[str]) -> None: + async with self.bind.begin() as conn: + now = datetime.now(timezone.utc) + conditions = and_(self.t_schedules.c.id.in_(ids), + or_(self.t_schedules.c.acquired_until.is_(None), + self.t_schedules.c.acquired_until < now)) + statement = self.t_schedules.delete(conditions) + if self._supports_update_returning: + statement = statement.returning(self.t_schedules.c.id) + removed_ids = [row[0] for row in await conn.execute(statement)] + else: + await conn.execute(statement) + + for schedule_id in removed_ids: + await self._publish(conn, ScheduleRemoved(schedule_id=schedule_id)) + + async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: + query = select([self.t_schedules.c.id, self.t_schedules.c.serialized_data]).\ + order_by(self.t_schedules.c.id) + if ids: + query = query.where(self.t_schedules.c.id.in_(ids)) + + async with self.bind.begin() as conn: + result = await conn.execute(query) + + return self._deserialize_schedules(result) + + async def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]: + async with self.bind.begin() as conn: + now = datetime.now(timezone.utc) + acquired_until = datetime.fromtimestamp( + now.timestamp() + self.lock_expiration_delay, timezone.utc) + schedules_cte = select(self.t_schedules.c.id).\ + where(and_(self.t_schedules.c.next_fire_time.isnot(None), + self.t_schedules.c.next_fire_time <= now, + or_(self.t_schedules.c.acquired_until.is_(None), + self.t_schedules.c.acquired_until < now))).\ + limit(limit).cte() + subselect = select([schedules_cte.c.id]) + statement = self.t_schedules.update().where(self.t_schedules.c.id.in_(subselect)).\ + values(acquired_by=scheduler_id, acquired_until=acquired_until) + if self._supports_update_returning: + statement = statement.returning(self.t_schedules.c.id, + 
self.t_schedules.c.serialized_data) + result = await conn.execute(statement) + else: + await conn.execute(statement) + statement = select([self.t_schedules.c.id, self.t_schedules.c.serialized_data]).\ + where(and_(self.t_schedules.c.acquired_by == scheduler_id)) + result = await conn.execute(statement) + + return self._deserialize_schedules(result) + + async def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None: + update_events: List[ScheduleUpdated] = [] + finished_schedule_ids: List[str] = [] + async with self.bind.begin() as conn: + update_args: List[Dict[str, Any]] = [] + for schedule in schedules: + if schedule.next_fire_time is not None: + try: + serialized_data = self.serializer.serialize(schedule) + except SerializationError: + self._logger.exception('Error serializing schedule %r – ' + 'removing from data store', schedule.id) + finished_schedule_ids.append(schedule.id) + continue + + update_args.append({ + 'p_id': schedule.id, + 'p_serialized_data': serialized_data, + 'p_next_fire_time': schedule.next_fire_time + }) + else: + finished_schedule_ids.append(schedule.id) + + # Update schedules that have a next fire time + if update_args: + p_id = bindparam('p_id') + p_serialized = bindparam('p_serialized_data') + p_next_fire_time = bindparam('p_next_fire_time') + statement = self.t_schedules.update().\ + where(and_(self.t_schedules.c.id == p_id, + self.t_schedules.c.acquired_by == scheduler_id)).\ + values(serialized_data=p_serialized, next_fire_time=p_next_fire_time) + next_fire_times = {arg['p_id']: arg['p_next_fire_time'] for arg in update_args} + if self._supports_update_returning: + statement = statement.returning(self.t_schedules.c.id) + updated_ids = [row[0] for row in await conn.execute(statement, update_args)] + for schedule_id in updated_ids: + event = ScheduleUpdated(schedule_id=schedule_id, + next_fire_time=next_fire_times[schedule_id]) + update_events.append(event) + + # Remove schedules that have no next fire time or failed to serialize + if finished_schedule_ids: + statement = self.t_schedules.delete().\ + where(and_(self.t_schedules.c.id.in_(finished_schedule_ids), + self.t_schedules.c.acquired_by == scheduler_id)) + await conn.execute(statement) + + for event in update_events: + await self._publish(conn, event) + + for schedule_id in finished_schedule_ids: + await self._publish(conn, ScheduleRemoved(schedule_id=schedule_id)) + + async def add_job(self, job: Job) -> None: + now = datetime.now(timezone.utc) + serialized_data = self.serializer.serialize(job) + statement = self.t_jobs.insert().values(id=job.id.hex, task_id=job.task_id, + created_at=now, serialized_data=serialized_data) + async with self.bind.begin() as conn: + await conn.execute(statement) + + event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id, + tags=job.tags) + await self._publish(conn, event) + + async def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]: + query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\ + order_by(self.t_jobs.c.id) + if ids: + job_ids = [job_id.hex for job_id in ids] + query = query.where(self.t_jobs.c.id.in_(job_ids)) + + async with self.bind.begin() as conn: + result = await conn.execute(query) + + return self._deserialize_jobs(result) + + async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: + async with self.bind.begin() as conn: + now = datetime.now(timezone.utc) + acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + query = 
select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\ + where(or_(self.t_jobs.c.acquired_until.is_(None), + self.t_jobs.c.acquired_until < now)).\ + order_by(self.t_jobs.c.created_at).\ + limit(limit) + + serialized_jobs: Dict[str, bytes] = {row[0]: row[1] + for row in await conn.execute(query)} + if serialized_jobs: + query = self.t_jobs.update().\ + values(acquired_by=worker_id, acquired_until=acquired_until).\ + where(self.t_jobs.c.id.in_(serialized_jobs)) + await conn.execute(query) + + return self._deserialize_jobs(serialized_jobs.items()) + + async def release_jobs(self, worker_id: str, jobs: List[Job]) -> None: + job_ids = [job.id.hex for job in jobs] + statement = self.t_jobs.delete().\ + where(and_(self.t_jobs.c.acquired_by == worker_id, self.t_jobs.c.id.in_(job_ids))) + + async with self.bind.begin() as conn: + await conn.execute(statement) diff --git a/src/apscheduler/datastores/sync/__init__.py b/src/apscheduler/datastores/sync/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/apscheduler/datastores/sync/memory.py b/src/apscheduler/datastores/sync/memory.py new file mode 100644 index 0000000..2837102 --- /dev/null +++ b/src/apscheduler/datastores/sync/memory.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +from bisect import bisect_left, insort_right +from collections import defaultdict +from datetime import MAXYEAR, datetime, timedelta, timezone +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type +from uuid import UUID + +from ... import events +from ...abc import DataStore, Job, Schedule +from ...events import ( + EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken) +from ...exceptions import ConflictingIdError +from ...policies import ConflictPolicy +from ...util import reentrant + +max_datetime = datetime(MAXYEAR, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc) + + +class ScheduleState: + __slots__ = 'schedule', 'next_fire_time', 'acquired_by', 'acquired_until' + + def __init__(self, schedule: Schedule): + self.schedule = schedule + self.next_fire_time = self.schedule.next_fire_time + self.acquired_by: Optional[str] = None + self.acquired_until: Optional[datetime] = None + + def __lt__(self, other): + if self.next_fire_time is None: + return False + elif other.next_fire_time is None: + return self.next_fire_time is not None + else: + return self.next_fire_time < other.next_fire_time + + def __eq__(self, other): + if isinstance(other, ScheduleState): + return self.schedule == other.schedule + + return NotImplemented + + def __hash__(self): + return hash(self.schedule.id) + + def __repr__(self): + return (f'<{self.__class__.__name__} id={self.schedule.id!r} ' + f'task_id={self.schedule.task_id!r} next_fire_time={self.next_fire_time}>') + + +class JobState: + __slots__ = 'job', 'created_at', 'acquired_by', 'acquired_until' + + def __init__(self, job: Job): + self.job = job + self.created_at = datetime.now(timezone.utc) + self.acquired_by: Optional[str] = None + self.acquired_until: Optional[datetime] = None + + def __lt__(self, other): + return self.created_at < other.created_at + + def __eq__(self, other): + return self.job.id == other.job.id + + def __hash__(self): + return hash(self.job.id) + + def __repr__(self): + return f'<{self.__class__.__name__} id={self.job.id!r} task_id={self.job.task_id!r}>' + + +@reentrant +class MemoryDataStore(DataStore): + def __init__(self, lock_expiration_delay: float = 30): + self.lock_expiration_delay = lock_expiration_delay + self._events = 
EventHub() + self._schedules: List[ScheduleState] = [] + self._schedules_by_id: Dict[str, ScheduleState] = {} + self._schedules_by_task_id: Dict[str, Set[ScheduleState]] = defaultdict(set) + self._jobs: List[JobState] = [] + self._jobs_by_id: Dict[UUID, JobState] = {} + self._jobs_by_task_id: Dict[str, Set[JobState]] = defaultdict(set) + + def _find_schedule_index(self, state: ScheduleState) -> Optional[int]: + left_index = bisect_left(self._schedules, state) + right_index = bisect_left(self._schedules, state) + return self._schedules.index(state, left_index, right_index + 1) + + def _find_job_index(self, state: JobState) -> Optional[int]: + left_index = bisect_left(self._jobs, state) + right_index = bisect_left(self._jobs, state) + return self._jobs.index(state, left_index, right_index + 1) + + def __enter__(self): + self._events.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._events.__exit__(exc_type, exc_val, exc_tb) + + def subscribe(self, callback: Callable[[events.Event], Any], + event_types: Optional[Iterable[Type[events.Event]]] = None) -> SubscriptionToken: + return self._events.subscribe(callback, event_types) + + def unsubscribe(self, token: events.SubscriptionToken) -> None: + self._events.unsubscribe(token) + + def clear(self) -> None: + self._schedules.clear() + self._schedules_by_id.clear() + self._schedules_by_task_id.clear() + self._jobs.clear() + self._jobs_by_id.clear() + self._jobs_by_task_id.clear() + + def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: + return [state.schedule for state in self._schedules + if ids is None or state.schedule.id in ids] + + def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + old_state = self._schedules_by_id.get(schedule.id) + if old_state is not None: + if conflict_policy is ConflictPolicy.do_nothing: + return + elif conflict_policy is ConflictPolicy.exception: + raise ConflictingIdError(schedule.id) + + index = self._find_schedule_index(old_state) + del self._schedules[index] + self._schedules_by_task_id[old_state.schedule.task_id].remove(old_state) + + state = ScheduleState(schedule) + self._schedules_by_id[schedule.id] = state + self._schedules_by_task_id[schedule.task_id].add(state) + insort_right(self._schedules, state) + + if old_state is not None: + event = ScheduleUpdated(schedule_id=schedule.id, + next_fire_time=schedule.next_fire_time) + else: + event = ScheduleAdded(schedule_id=schedule.id, + next_fire_time=schedule.next_fire_time) + + self._events.publish(event) + + def remove_schedules(self, ids: Iterable[str]) -> None: + for schedule_id in ids: + state = self._schedules_by_id.pop(schedule_id, None) + if state: + self._schedules.remove(state) + event = ScheduleRemoved(schedule_id=state.schedule.id) + self._events.publish(event) + + def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]: + now = datetime.now(timezone.utc) + schedules: List[Schedule] = [] + for state in self._schedules: + if state.acquired_by is not None and state.acquired_until >= now: + continue + elif state.next_fire_time is None or state.next_fire_time > now: + break + + schedules.append(state.schedule) + state.acquired_by = scheduler_id + state.acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + if len(schedules) == limit: + break + + return schedules + + def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None: + # Send update events for schedules that have a next time + finished_schedule_ids: 
List[str] = [] + for s in schedules: + if s.next_fire_time is not None: + # Remove the schedule + schedule_state = self._schedules_by_id.get(s.id) + index = self._find_schedule_index(schedule_state) + del self._schedules[index] + + # Readd the schedule to its new position + schedule_state.next_fire_time = s.next_fire_time + schedule_state.acquired_by = None + schedule_state.acquired_until = None + insort_right(self._schedules, schedule_state) + event = ScheduleUpdated(schedule_id=s.id, next_fire_time=s.next_fire_time) + self._events.publish(event) + else: + finished_schedule_ids.append(s.id) + + # Remove schedules that didn't get a new next fire time + self.remove_schedules(finished_schedule_ids) + + def add_job(self, job: Job) -> None: + state = JobState(job) + self._jobs.append(state) + self._jobs_by_id[job.id] = state + self._jobs_by_task_id[job.task_id].add(state) + + event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id, + tags=job.tags) + self._events.publish(event) + + def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]: + if ids is not None: + ids = frozenset(ids) + + return [state.job for state in self._jobs if ids is None or state.job.id in ids] + + def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: + now = datetime.now(timezone.utc) + jobs: List[Job] = [] + for state in self._jobs: + if state.acquired_by is not None and state.acquired_until >= now: + continue + + jobs.append(state.job) + state.acquired_by = worker_id + state.acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + if len(jobs) == limit: + break + + return jobs + + def release_jobs(self, worker_id: str, jobs: List[Job]) -> None: + assert self._jobs + indexes = sorted((self._find_job_index(self._jobs_by_id[j.id]) for j in jobs), + reverse=True) + for i in indexes: + state = self._jobs.pop(i) + self._jobs_by_task_id[state.job.task_id].remove(state) diff --git a/src/apscheduler/datastores/sync/mongodb.py b/src/apscheduler/datastores/sync/mongodb.py new file mode 100644 index 0000000..01267b1 --- /dev/null +++ b/src/apscheduler/datastores/sync/mongodb.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import logging +from contextlib import ExitStack +from datetime import datetime, timezone +from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Type +from uuid import UUID + +from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne +from pymongo.collection import Collection +from pymongo.errors import DuplicateKeyError + +from ... 
import events +from ...abc import DataStore, Job, Schedule, Serializer +from ...events import ( + DataStoreEvent, EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, + SubscriptionToken) +from ...exceptions import ConflictingIdError, DeserializationError, SerializationError +from ...policies import ConflictPolicy +from ...serializers.pickle import PickleSerializer +from ...util import reentrant + + +@reentrant +class MongoDBDataStore(DataStore): + def __init__(self, client: MongoClient, *, serializer: Optional[Serializer] = None, + database: str = 'apscheduler', schedules_collection: str = 'schedules', + jobs_collection: str = 'jobs', lock_expiration_delay: float = 30, + start_from_scratch: bool = False): + super().__init__() + if not client.delegate.codec_options.tz_aware: + raise ValueError('MongoDB client must have tz_aware set to True') + + self.client = client + self.serializer = serializer or PickleSerializer() + self.lock_expiration_delay = lock_expiration_delay + self.start_from_scratch = start_from_scratch + self._database = client[database] + self._schedules: Collection = self._database[schedules_collection] + self._jobs: Collection = self._database[jobs_collection] + self._logger = logging.getLogger(__name__) + self._exit_stack = ExitStack() + self._events = EventHub() + + @classmethod + def from_url(cls, uri: str, **options) -> 'MongoDBDataStore': + client = MongoClient(uri) + return cls(client, **options) + + def __enter__(self): + server_info = self.client.server_info() + if server_info['versionArray'] < [4, 0]: + raise RuntimeError(f"MongoDB server must be at least v4.0; current version = " + f"{server_info['version']}") + + self._exit_stack.__enter__() + self._exit_stack.enter_context(self._events) + + if self.start_from_scratch: + self._schedules.delete_many({}) + self._jobs.delete_many({}) + + self._schedules.create_index('next_fire_time') + self._jobs.create_index('task_id') + self._jobs.create_index('created_at') + self._jobs.create_index('tags') + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._exit_stack.__exit__(exc_type, exc_val, exc_tb) + + def subscribe(self, callback: Callable[[events.Event], Any], + event_types: Optional[Iterable[Type[events.Event]]] = None) -> SubscriptionToken: + return self._events.subscribe(callback, event_types) + + def unsubscribe(self, token: events.SubscriptionToken) -> None: + self._events.unsubscribe(token) + + def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: + schedules: List[Schedule] = [] + filters = {'_id': {'$in': list(ids)}} if ids is not None else {} + cursor = self._schedules.find(filters, projection=['_id', 'serialized_data']).sort('_id') + for document in cursor: + try: + schedule = self.serializer.deserialize(document['serialized_data']) + except DeserializationError: + self._logger.warning('Failed to deserialize schedule %r', document['_id']) + continue + + schedules.append(schedule) + + return schedules + + def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + event: DataStoreEvent + serialized_data = self.serializer.serialize(schedule) + document = { + '_id': schedule.id, + 'task_id': schedule.task_id, + 'serialized_data': serialized_data, + 'next_fire_time': schedule.next_fire_time + } + try: + self._schedules.insert_one(document) + except DuplicateKeyError: + if conflict_policy is ConflictPolicy.exception: + raise ConflictingIdError(schedule.id) from None + elif conflict_policy is ConflictPolicy.replace: + 
self._schedules.replace_one({'_id': schedule.id}, document, True) + event = ScheduleUpdated( + schedule_id=schedule.id, + next_fire_time=schedule.next_fire_time) + self._events.publish(event) + else: + event = ScheduleAdded(schedule_id=schedule.id, + next_fire_time=schedule.next_fire_time) + self._events.publish(event) + + def remove_schedules(self, ids: Iterable[str]) -> None: + with self.client.start_session() as s, s.start_transaction(): + filters = {'_id': {'$in': list(ids)}} if ids is not None else {} + cursor = self._schedules.find(filters, projection=['_id']) + ids = [doc['_id'] for doc in cursor] + if ids: + self._schedules.delete_many(filters) + + for schedule_id in ids: + self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + + def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]: + schedules: List[Schedule] = [] + with self.client.start_session() as s, s.start_transaction(): + cursor = self._schedules.find( + {'next_fire_time': {'$ne': None}, + '$or': [{'acquired_until': {'$exists': False}}, + {'acquired_until': {'$lt': datetime.now(timezone.utc)}}] + }, + projection=['serialized_data'] + ).sort('next_fire_time').limit(limit) + for document in cursor: + schedule = self.serializer.deserialize(document['serialized_data']) + schedules.append(schedule) + + if schedules: + now = datetime.now(timezone.utc) + acquired_until = datetime.fromtimestamp( + now.timestamp() + self.lock_expiration_delay, now.tzinfo) + filters = {'_id': {'$in': [schedule.id for schedule in schedules]}} + update = {'$set': {'acquired_by': scheduler_id, + 'acquired_until': acquired_until}} + self._schedules.update_many(filters, update) + + return schedules + + def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None: + updated_schedules: List[Tuple[str, datetime]] = [] + finished_schedule_ids: List[str] = [] + with self.client.start_session() as s, s.start_transaction(): + # Update schedules that have a next fire time + requests = [] + for schedule in schedules: + filters = {'_id': schedule.id, 'acquired_by': scheduler_id} + if schedule.next_fire_time is not None: + try: + serialized_data = self.serializer.serialize(schedule) + except SerializationError: + self._logger.exception('Error serializing schedule %r – ' + 'removing from data store', schedule.id) + requests.append(DeleteOne(filters)) + finished_schedule_ids.append(schedule.id) + continue + + update = { + '$unset': { + 'acquired_by': True, + 'acquired_until': True, + }, + '$set': { + 'next_fire_time': schedule.next_fire_time, + 'serialized_data': serialized_data + } + } + requests.append(UpdateOne(filters, update)) + updated_schedules.append((schedule.id, schedule.next_fire_time)) + else: + requests.append(DeleteOne(filters)) + finished_schedule_ids.append(schedule.id) + + if requests: + self._schedules.bulk_write(requests, ordered=False) + for schedule_id, next_fire_time in updated_schedules: + event = ScheduleUpdated(schedule_id=schedule_id, next_fire_time=next_fire_time) + self._events.publish(event) + + for schedule_id in finished_schedule_ids: + self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + + def add_job(self, job: Job) -> None: + serialized_data = self.serializer.serialize(job) + document = { + '_id': job.id, + 'serialized_data': serialized_data, + 'task_id': job.task_id, + 'created_at': datetime.now(timezone.utc), + 'tags': list(job.tags) + } + self._jobs.insert_one(document) + event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id, + 
tags=job.tags) + self._events.publish(event) + + def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]: + jobs: List[Job] = [] + filters = {'_id': {'$in': list(ids)}} if ids is not None else {} + cursor = self._jobs.find(filters, projection=['_id', 'serialized_data']).sort('_id') + for document in cursor: + try: + job = self.serializer.deserialize(document['serialized_data']) + except DeserializationError: + self._logger.warning('Failed to deserialize job %r', document['_id']) + continue + + jobs.append(job) + + return jobs + + def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: + jobs: List[Job] = [] + with self.client.start_session() as s, s.start_transaction(): + cursor = self._jobs.find( + {'$or': [{'acquired_until': {'$exists': False}}, + {'acquired_until': {'$lt': datetime.now(timezone.utc)}}] + }, + projection=['serialized_data'], + sort=[('created_at', ASCENDING)], + limit=limit + ) + for document in cursor: + job = self.serializer.deserialize(document['serialized_data']) + jobs.append(job) + + if jobs: + now = datetime.now(timezone.utc) + acquired_until = datetime.fromtimestamp( + now.timestamp() + self.lock_expiration_delay, timezone.utc) + filters = {'_id': {'$in': [job.id for job in jobs]}} + update = {'$set': {'acquired_by': worker_id, + 'acquired_until': acquired_until}} + self._jobs.update_many(filters, update) + return jobs + + def release_jobs(self, worker_id: str, jobs: List[Job]) -> None: + filters = {'_id': {'$in': [job.id for job in jobs]}, 'acquired_by': worker_id} + self._jobs.delete_many(filters) diff --git a/src/apscheduler/datastores/sync/sqlalchemy.py b/src/apscheduler/datastores/sync/sqlalchemy.py new file mode 100644 index 0000000..34589cd --- /dev/null +++ b/src/apscheduler/datastores/sync/sqlalchemy.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +import logging +from datetime import datetime, timedelta, timezone +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from uuid import UUID + +from sqlalchemy import ( + Column, DateTime, Integer, LargeBinary, MetaData, Table, Unicode, and_, bindparam, or_, select) +from sqlalchemy.engine import URL, Connectable +from sqlalchemy.exc import CompileError, IntegrityError +from sqlalchemy.future import create_engine +from sqlalchemy.sql.ddl import DropTable + +from ...abc import AsyncDataStore, Job, Schedule, Serializer +from ...events import ( + Event, EventHub, JobAdded, JobDeserializationFailed, ScheduleAdded, + ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken) +from ...exceptions import ConflictingIdError, SerializationError +from ...policies import ConflictPolicy +from ...serializers.pickle import PickleSerializer +from ...util import reentrant + +logger = logging.getLogger(__name__) + + +@reentrant +class SQLAlchemyDataStore(AsyncDataStore): + _metadata = MetaData() + + t_metadata = Table( + 'metadata', + _metadata, + Column('schema_version', Integer, nullable=False) + ) + t_schedules = Table( + 'schedules', + _metadata, + Column('id', Unicode, primary_key=True), + Column('task_id', Unicode, nullable=False), + Column('serialized_data', LargeBinary, nullable=False), + Column('next_fire_time', DateTime(timezone=True), index=True), + Column('acquired_by', Unicode), + Column('acquired_until', DateTime(timezone=True)) + ) + t_jobs = Table( + 'jobs', + _metadata, + Column('id', Unicode(32), primary_key=True), + Column('task_id', Unicode, nullable=False, index=True), + 
Column('serialized_data', LargeBinary, nullable=False), + Column('created_at', DateTime(timezone=True), nullable=False), + Column('acquired_by', Unicode), + Column('acquired_until', DateTime(timezone=True)) + ) + + def __init__(self, bind: Connectable, *, schema: Optional[str] = None, + serializer: Optional[Serializer] = None, + lock_expiration_delay: float = 30, max_poll_time: Optional[float] = 1, + max_idle_time: float = 60, start_from_scratch: bool = False): + self.bind = bind + self.schema = schema + self.serializer = serializer or PickleSerializer() + self.lock_expiration_delay = lock_expiration_delay + self.max_poll_time = max_poll_time + self.max_idle_time = max_idle_time + self.start_from_scratch = start_from_scratch + self._logger = logging.getLogger(__name__) + self._events = EventHub() + + # Find out if the dialect supports RETURNING + statement = self.t_jobs.update().returning(self.t_schedules.c.id) + try: + statement.compile(bind=self.bind) + except CompileError: + self._supports_update_returning = False + else: + self._supports_update_returning = True + + @classmethod + def from_url(cls, url: Union[str, URL], **options) -> 'SQLAlchemyDataStore': + engine = create_engine(url) + return cls(engine, **options) + + def __enter__(self): + with self.bind.begin() as conn: + if self.start_from_scratch: + for table in self._metadata.sorted_tables: + conn.execute(DropTable(table, if_exists=True)) + + self._metadata.create_all(conn) + query = select(self.t_metadata.c.schema_version) + result = conn.execute(query) + version = result.scalar() + if version is None: + conn.execute(self.t_metadata.insert(values={'schema_version': 1})) + elif version > 1: + raise RuntimeError(f'Unexpected schema version ({version}); ' + f'only version 1 is supported by this version of APScheduler') + + self._events.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._events.__exit__(exc_type, exc_val, exc_tb) + + def _deserialize_jobs(self, serialized_jobs: Iterable[Tuple[UUID, bytes]]) -> List[Job]: + jobs: List[Job] = [] + for job_id, serialized_data in serialized_jobs: + try: + jobs.append(self.serializer.deserialize(serialized_data)) + except SerializationError as exc: + self._events.publish(JobDeserializationFailed(job_id=job_id, exception=exc)) + + return jobs + + def _deserialize_schedules( + self, serialized_schedules: Iterable[Tuple[str, bytes]]) -> List[Schedule]: + jobs: List[Schedule] = [] + for schedule_id, serialized_data in serialized_schedules: + try: + jobs.append(self.serializer.deserialize(serialized_data)) + except SerializationError as exc: + self._events.publish( + ScheduleDeserializationFailed(schedule_id=schedule_id, exception=exc)) + + return jobs + + def subscribe(self, callback: Callable[[Event], Any], + event_types: Optional[Iterable[Type[Event]]] = None) -> SubscriptionToken: + return self._events.subscribe(callback, event_types) + + def unsubscribe(self, token: SubscriptionToken) -> None: + self._events.unsubscribe(token) + + def clear(self) -> None: + with self.bind.begin() as conn: + conn.execute(self.t_schedules.delete()) + conn.execute(self.t_jobs.delete()) + + def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: + serialized_data = self.serializer.serialize(schedule) + statement = self.t_schedules.insert().\ + values(id=schedule.id, task_id=schedule.task_id, serialized_data=serialized_data, + next_fire_time=schedule.next_fire_time) + try: + with self.bind.begin() as conn: + conn.execute(statement) + except 
IntegrityError: + if conflict_policy is ConflictPolicy.exception: + raise ConflictingIdError(schedule.id) from None + elif conflict_policy is ConflictPolicy.replace: + statement = self.t_schedules.update().\ + where(self.t_schedules.c.id == schedule.id).\ + values(serialized_data=serialized_data, + next_fire_time=schedule.next_fire_time) + with self.bind.begin() as conn: + conn.execute(statement) + + event = ScheduleUpdated(schedule_id=schedule.id, + next_fire_time=schedule.next_fire_time) + self._events.publish(event) + else: + event = ScheduleAdded(schedule_id=schedule.id, + next_fire_time=schedule.next_fire_time) + self._events.publish(event) + + def remove_schedules(self, ids: Iterable[str]) -> None: + with self.bind.begin() as conn: + now = datetime.now(timezone.utc) + conditions = and_(self.t_schedules.c.id.in_(ids), + or_(self.t_schedules.c.acquired_until.is_(None), + self.t_schedules.c.acquired_until < now)) + statement = self.t_schedules.delete(conditions) + if self._supports_update_returning: + statement = statement.returning(self.t_schedules.c.id) + removed_ids = [row[0] for row in conn.execute(statement)] + else: + conn.execute(statement) + + for schedule_id in removed_ids: + self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + + def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: + query = select([self.t_schedules.c.id, self.t_schedules.c.serialized_data]).\ + order_by(self.t_schedules.c.id) + if ids: + query = query.where(self.t_schedules.c.id.in_(ids)) + + with self.bind.begin() as conn: + result = conn.execute(query) + + return self._deserialize_schedules(result) + + def acquire_schedules(self, scheduler_id: str, limit: int) -> List[Schedule]: + with self.bind.begin() as conn: + now = datetime.now(timezone.utc) + acquired_until = datetime.fromtimestamp( + now.timestamp() + self.lock_expiration_delay, timezone.utc) + schedules_cte = select(self.t_schedules.c.id).\ + where(and_(self.t_schedules.c.next_fire_time.isnot(None), + self.t_schedules.c.next_fire_time <= now, + or_(self.t_schedules.c.acquired_until.is_(None), + self.t_schedules.c.acquired_until < now))).\ + limit(limit).cte() + subselect = select([schedules_cte.c.id]) + statement = self.t_schedules.update().where(self.t_schedules.c.id.in_(subselect)).\ + values(acquired_by=scheduler_id, acquired_until=acquired_until) + if self._supports_update_returning: + statement = statement.returning(self.t_schedules.c.id, + self.t_schedules.c.serialized_data) + result = conn.execute(statement) + else: + conn.execute(statement) + statement = select([self.t_schedules.c.id, self.t_schedules.c.serialized_data]).\ + where(and_(self.t_schedules.c.acquired_by == scheduler_id)) + result = conn.execute(statement) + + return self._deserialize_schedules(result) + + def release_schedules(self, scheduler_id: str, schedules: List[Schedule]) -> None: + update_events: List[ScheduleUpdated] = [] + finished_schedule_ids: List[str] = [] + with self.bind.begin() as conn: + update_args: List[Dict[str, Any]] = [] + for schedule in schedules: + if schedule.next_fire_time is not None: + try: + serialized_data = self.serializer.serialize(schedule) + except SerializationError: + self._logger.exception('Error serializing schedule %r – ' + 'removing from data store', schedule.id) + finished_schedule_ids.append(schedule.id) + continue + + update_args.append({ + 'p_id': schedule.id, + 'p_serialized_data': serialized_data, + 'p_next_fire_time': schedule.next_fire_time + }) + else: + finished_schedule_ids.append(schedule.id) + 
+ # Update schedules that have a next fire time + if update_args: + p_id = bindparam('p_id') + p_serialized = bindparam('p_serialized_data') + p_next_fire_time = bindparam('p_next_fire_time') + statement = self.t_schedules.update().\ + where(and_(self.t_schedules.c.id == p_id, + self.t_schedules.c.acquired_by == scheduler_id)).\ + values(serialized_data=p_serialized, next_fire_time=p_next_fire_time) + next_fire_times = {arg['p_id']: arg['p_next_fire_time'] for arg in update_args} + if self._supports_update_returning: + statement = statement.returning(self.t_schedules.c.id) + updated_ids = [row[0] for row in conn.execute(statement, update_args)] + for schedule_id in updated_ids: + event = ScheduleUpdated(schedule_id=schedule_id, + next_fire_time=next_fire_times[schedule_id]) + update_events.append(event) + + # Remove schedules that have no next fire time or failed to serialize + if finished_schedule_ids: + statement = self.t_schedules.delete().\ + where(and_(self.t_schedules.c.id.in_(finished_schedule_ids), + self.t_schedules.c.acquired_by == scheduler_id)) + conn.execute(statement) + + for event in update_events: + self._events.publish(event) + + for schedule_id in finished_schedule_ids: + self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) + + def add_job(self, job: Job) -> None: + now = datetime.now(timezone.utc) + serialized_data = self.serializer.serialize(job) + statement = self.t_jobs.insert().values(id=job.id.hex, task_id=job.task_id, + created_at=now, serialized_data=serialized_data) + with self.bind.begin() as conn: + conn.execute(statement) + + event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id, + tags=job.tags) + self._events.publish(event) + + def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> List[Job]: + query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\ + order_by(self.t_jobs.c.id) + if ids: + job_ids = [job_id.hex for job_id in ids] + query = query.where(self.t_jobs.c.id.in_(job_ids)) + + with self.bind.begin() as conn: + result = conn.execute(query) + + return self._deserialize_jobs(result) + + def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: + with self.bind.begin() as conn: + now = datetime.now(timezone.utc) + acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\ + where(or_(self.t_jobs.c.acquired_until.is_(None), + self.t_jobs.c.acquired_until < now)).\ + order_by(self.t_jobs.c.created_at).\ + limit(limit) + + serialized_jobs: Dict[str, bytes] = {row[0]: row[1] + for row in conn.execute(query)} + if serialized_jobs: + query = self.t_jobs.update().\ + values(acquired_by=worker_id, acquired_until=acquired_until).\ + where(self.t_jobs.c.id.in_(serialized_jobs)) + conn.execute(query) + + return self._deserialize_jobs(serialized_jobs.items()) + + def release_jobs(self, worker_id: str, jobs: List[Job]) -> None: + job_ids = [job.id.hex for job in jobs] + statement = self.t_jobs.delete().\ + where(and_(self.t_jobs.c.acquired_by == worker_id, self.t_jobs.c.id.in_(job_ids))) + + with self.bind.begin() as conn: + conn.execute(statement) diff --git a/src/apscheduler/enums.py b/src/apscheduler/enums.py new file mode 100644 index 0000000..fc449da --- /dev/null +++ b/src/apscheduler/enums.py @@ -0,0 +1,8 @@ +from enum import Enum, auto + + +class RunState(Enum): + starting = auto() + started = auto() + stopping = auto() + stopped = auto() diff --git a/src/apscheduler/events.py 
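(Aside, not part of the patch: release_schedules() above updates many rows through a single statement by binding placeholder parameters with bindparam() and passing a list of parameter dictionaries, SQLAlchemy's "executemany" form; the p_ prefix keeps the parameter names from clashing with the column names being set. A standalone sketch of that pattern against a made-up table:)

    from sqlalchemy import Column, DateTime, MetaData, String, Table, bindparam, create_engine

    metadata = MetaData()
    t_demo = Table('demo', metadata,
                   Column('id', String, primary_key=True),
                   Column('next_fire_time', DateTime(timezone=True)))
    engine = create_engine('sqlite:///:memory:')
    metadata.create_all(engine)

    statement = t_demo.update().\
        where(t_demo.c.id == bindparam('p_id')).\
        values(next_fire_time=bindparam('p_next_fire_time'))
    with engine.begin() as conn:
        conn.execute(t_demo.insert(), [{'id': 'a', 'next_fire_time': None},
                                       {'id': 'b', 'next_fire_time': None}])
        # One compiled UPDATE, executed once per parameter dictionary
        conn.execute(statement, [{'p_id': 'a', 'p_next_fire_time': None},
                                 {'p_id': 'b', 'p_next_fire_time': None}])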
b/src/apscheduler/events.py new file mode 100644 index 0000000..0f09720 --- /dev/null +++ b/src/apscheduler/events.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +import logging +from abc import abstractmethod +from asyncio import iscoroutinefunction +from concurrent.futures.thread import ThreadPoolExecutor +from dataclasses import dataclass, field +from datetime import datetime, timezone +from functools import partial +from inspect import isawaitable, isclass +from logging import Logger +from typing import Any, Callable, Dict, FrozenSet, Iterable, NewType, Optional, Set, Type, Union +from uuid import UUID + +import attr +from anyio import create_task_group +from anyio.abc import TaskGroup + +from . import abc + +SubscriptionToken = NewType('SubscriptionToken', object) + + +def timestamp_to_datetime(value: Union[datetime, float, None]) -> Optional[datetime]: + if isinstance(value, float): + return datetime.fromtimestamp(value, timezone.utc) + + return value + + +@attr.define(kw_only=True, frozen=True) +class Event: + timestamp: datetime = attr.field(factory=partial(datetime.now, timezone.utc), + converter=timestamp_to_datetime) + + +# +# Data store events +# + +@attr.define(kw_only=True, frozen=True) +class DataStoreEvent(Event): + pass + + +@attr.define(kw_only=True, frozen=True) +class ScheduleAdded(DataStoreEvent): + schedule_id: str + next_fire_time: Optional[datetime] = attr.field(converter=timestamp_to_datetime) + + +@attr.define(kw_only=True, frozen=True) +class ScheduleUpdated(DataStoreEvent): + schedule_id: str + next_fire_time: Optional[datetime] = attr.field(converter=timestamp_to_datetime) + + +@attr.define(kw_only=True, frozen=True) +class ScheduleRemoved(DataStoreEvent): + schedule_id: str + + +@attr.define(kw_only=True, frozen=True) +class JobAdded(DataStoreEvent): + job_id: UUID + task_id: str + schedule_id: Optional[str] + tags: FrozenSet[str] + + +@attr.define(kw_only=True, frozen=True) +class JobRemoved(DataStoreEvent): + job_id: UUID + + +@attr.define(kw_only=True, frozen=True) +class ScheduleDeserializationFailed(DataStoreEvent): + schedule_id: str + exception: BaseException + + +@attr.define(kw_only=True, frozen=True) +class JobDeserializationFailed(DataStoreEvent): + job_id: UUID + exception: BaseException + + +# +# Scheduler events +# + +@attr.define(kw_only=True, frozen=True) +class SchedulerEvent(Event): + pass + + +@attr.define(kw_only=True, frozen=True) +class SchedulerStarted(SchedulerEvent): + pass + + +@attr.define(kw_only=True, frozen=True) +class SchedulerStopped(SchedulerEvent): + exception: Optional[BaseException] = None + + +# +# Worker events +# + +@attr.define(kw_only=True, frozen=True) +class WorkerEvent(Event): + pass + + +@attr.define(kw_only=True, frozen=True) +class WorkerStarted(WorkerEvent): + pass + + +@attr.define(kw_only=True, frozen=True) +class WorkerStopped(WorkerEvent): + exception: Optional[BaseException] = None + + +@attr.define(kw_only=True, frozen=True) +class JobExecutionEvent(WorkerEvent): + job_id: UUID + task_id: str + schedule_id: Optional[str] + scheduled_fire_time: Optional[datetime] + start_deadline: Optional[datetime] + start_time: datetime + + +@attr.define(kw_only=True, frozen=True) +class JobStarted(JobExecutionEvent): + """Signals that a worker has started running a job.""" + + +@attr.define(kw_only=True, frozen=True) +class JobDeadlineMissed(JobExecutionEvent): + """Signals that a worker has skipped a job because its deadline was missed.""" + + +@attr.define(kw_only=True, frozen=True) +class 
JobCompleted(JobExecutionEvent): + """Signals that a worker has successfully run a job.""" + return_value: Any + + +@attr.define(kw_only=True, frozen=True) +class JobFailed(JobExecutionEvent): + """Signals that a worker encountered an exception while running a job.""" + exception: str + traceback: str + + +_all_event_types = [x for x in locals().values() if isclass(x) and issubclass(x, Event)] + + +# +# Event delivery +# + +@dataclass(eq=False, frozen=True) +class Subscription: + callback: Callable[[Event], Any] + event_types: Optional[Set[Type[Event]]] + + +@dataclass +class _BaseEventHub(abc.EventSource): + _logger: Logger = field(init=False, default_factory=lambda: logging.getLogger(__name__)) + _subscriptions: Dict[SubscriptionToken, Subscription] = field(init=False, default_factory=dict) + + def subscribe(self, callback: Callable[[Event], Any], + event_types: Optional[Iterable[Type[Event]]] = None) -> SubscriptionToken: + types = set(event_types) if event_types else None + token = SubscriptionToken(object()) + subscription = Subscription(callback, types) + self._subscriptions[token] = subscription + return token + + def unsubscribe(self, token: SubscriptionToken) -> None: + self._subscriptions.pop(token, None) + + @abstractmethod + def publish(self, event: Event) -> None: + """Publish an event to all subscribers.""" + + def relay_events_from(self, source: abc.EventSource) -> SubscriptionToken: + return source.subscribe(self.publish) + + +class EventHub(_BaseEventHub): + _executor: ThreadPoolExecutor + + def __enter__(self) -> EventHub: + self._executor = ThreadPoolExecutor(1) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._executor.shutdown(wait=exc_type is None) + + def subscribe(self, callback: Callable[[Event], Any], + event_types: Optional[Iterable[Type[Event]]] = None) -> SubscriptionToken: + if iscoroutinefunction(callback): + raise ValueError('Coroutine functions are not supported as callbacks on a synchronous ' + 'event source') + + return super().subscribe(callback, event_types) + + def publish(self, event: Event) -> None: + def deliver_event(func: Callable[[Event], Any]) -> None: + try: + func(event) + except BaseException: + self._logger.exception('Error delivering %s event', event.__class__.__name__) + + event_type = type(event) + for subscription in list(self._subscriptions.values()): + if subscription.event_types is None or event_type in subscription.event_types: + self._executor.submit(deliver_event, subscription.callback) + + +class AsyncEventHub(_BaseEventHub): + _task_group: TaskGroup + + async def __aenter__(self) -> AsyncEventHub: + self._task_group = create_task_group() + await self._task_group.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + del self._task_group + + def publish(self, event: Event) -> None: + async def deliver_event(func: Callable[[Event], Any]) -> None: + try: + retval = func(event) + if isawaitable(retval): + await retval + except BaseException: + self._logger.exception('Error delivering %s event', event.__class__.__name__) + + event_type = type(event) + for subscription in self._subscriptions.values(): + if subscription.event_types is None or event_type in subscription.event_types: + self._task_group.start_soon(deliver_event, subscription.callback) diff --git a/src/apscheduler/exceptions.py b/src/apscheduler/exceptions.py new file mode 100644 index 0000000..ec04f2c --- /dev/null +++ b/src/apscheduler/exceptions.py @@ -0,0 
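(Aside, not part of the patch: a minimal usage sketch of the synchronous event hub defined above, assuming EventHub and ScheduleAdded are imported from apscheduler.events. subscribe() optionally filters by concrete event class, and publish() hands delivery to the hub's single worker thread, so callbacks run asynchronously with respect to the publisher.)

    def on_schedule_added(event: ScheduleAdded) -> None:
        print('schedule added:', event.schedule_id, event.next_fire_time)

    with EventHub() as hub:
        token = hub.subscribe(on_schedule_added, {ScheduleAdded})
        hub.publish(ScheduleAdded(schedule_id='demo', next_fire_time=None))
        hub.unsubscribe(token)  # the subscription could also simply be left in place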
+1,59 @@ + + +class JobLookupError(KeyError): + """Raised when the job store cannot find a job for update or removal.""" + + def __init__(self, job_id): + super().__init__(u'No job by the id of %s was found' % job_id) + + +class ConflictingIdError(KeyError): + """ + Raised when trying to add a schedule to a store that already contains a schedule by that ID, + and the conflict policy of ``exception`` is used. + """ + + def __init__(self, schedule_id): + super().__init__( + f'This data store already contains a schedule with the identifier {schedule_id!r}') + + +class TransientJobError(ValueError): + """ + Raised when an attempt to add transient (with no func_ref) job to a persistent job store is + detected. + """ + + def __init__(self, job_id): + super().__init__( + f'Job ({job_id}) cannot be added to this job store because a reference to the ' + f'callable could not be determined.') + + +class SerializationError(Exception): + """Raised when a serializer fails to serialize the given object.""" + + +class DeserializationError(Exception): + """Raised when a serializer fails to deserialize the given object.""" + + +class MaxIterationsReached(Exception): + """ + Raised when a trigger has reached its maximum number of allowed computation iterations when + trying to calculate the next fire time. + """ + + +class SchedulerAlreadyRunningError(Exception): + """Raised when attempting to start or configure the scheduler when it's already running.""" + + def __str__(self): + return 'Scheduler is already running' + + +class SchedulerNotRunningError(Exception): + """Raised when attempting to shutdown the scheduler when it's not running.""" + + def __str__(self): + return 'Scheduler is not running' diff --git a/src/apscheduler/marshalling.py b/src/apscheduler/marshalling.py new file mode 100644 index 0000000..4c5e660 --- /dev/null +++ b/src/apscheduler/marshalling.py @@ -0,0 +1,123 @@ +import sys +from datetime import date, datetime, tzinfo +from functools import partial +from typing import Any, Callable, Tuple, overload + +from .exceptions import DeserializationError, SerializationError + +if sys.version_info >= (3, 9): + from zoneinfo import ZoneInfo +else: + from backports.zoneinfo import ZoneInfo + + +def marshal_object(obj) -> Tuple[str, Any]: + return f'{obj.__class__.__module__}:{obj.__class__.__qualname__}', obj.__getstate__() + + +def unmarshal_object(ref: str, state): + cls = callable_from_ref(ref) + instance = cls.__new__(cls) + instance.__setstate__(state) + return instance + + +@overload +def marshal_date(value: None) -> None: + ... + + +@overload +def marshal_date(value: date) -> str: + ... + + +def marshal_date(value): + return value.isoformat() if value is not None else None + + +@overload +def unmarshal_date(value: None) -> None: + ... + + +@overload +def unmarshal_date(value: str) -> date: + ... + + +def unmarshal_date(value): + if value is None: + return None + elif len(value) == 10: + return date.fromisoformat(value) + else: + return datetime.fromisoformat(value) + + +def marshal_timezone(value: tzinfo) -> str: + if isinstance(value, ZoneInfo): + return value.key + elif hasattr(value, 'zone'): # pytz timezones + return value.zone + + raise SerializationError( + f'Unserializable time zone: {value!r}\n' + f'Only time zones from the zoneinfo or pytz modules can be serialized.') + + +def unmarshal_timezone(value: str) -> ZoneInfo: + return ZoneInfo(value) + + +def callable_to_ref(func: Callable) -> str: + """ + Return a reference to the given callable. 
+ + :raises SerializationError: if the given object is not callable, is a partial(), lambda or + local function or does not have the ``__module__`` and ``__qualname__`` attributes + + """ + if isinstance(func, partial): + raise SerializationError('Cannot create a reference to a partial()') + + if not hasattr(func, '__module__'): + raise SerializationError('Callable has no __module__ attribute') + if not hasattr(func, '__qualname__'): + raise SerializationError('Callable has no __qualname__ attribute') + if '<lambda>' in func.__qualname__: + raise SerializationError('Cannot create a reference to a lambda') + if '<locals>' in func.__qualname__: + raise SerializationError('Cannot create a reference to a nested function') + + return f'{func.__module__}:{func.__qualname__}' + + +def callable_from_ref(ref: str) -> Callable: + """ + Return the callable pointed to by ``ref``. + + :raises DeserializationError: if the reference could not be resolved or the looked up object is + not callable + + """ + if ':' not in ref: + raise ValueError(f'Invalid reference: {ref}') + + modulename, rest = ref.split(':', 1) + try: + obj = __import__(modulename, fromlist=[rest]) + except ImportError: + raise LookupError(f'Error resolving reference {ref!r}: could not import module') + + try: + for name in rest.split('.'): + obj = getattr(obj, name) + except Exception: + raise DeserializationError(f'Error resolving reference {ref!r}: error looking up object') + + if not callable(obj): + raise DeserializationError(f'{ref!r} points to an object of type ' + f'{obj.__class__.__qualname__} which is not callable') + + return obj diff --git a/src/apscheduler/policies.py b/src/apscheduler/policies.py new file mode 100644 index 0000000..91fbd31 --- /dev/null +++ b/src/apscheduler/policies.py @@ -0,0 +1,19 @@ +from enum import Enum, auto + + +class ConflictPolicy(Enum): + #: replace the existing schedule with a new one + replace = auto() + #: keep the existing schedule as-is and drop the new schedule + do_nothing = auto() + #: raise an exception if a conflict is detected + exception = auto() + + +class CoalescePolicy(Enum): + #: run once, with the earliest fire time + earliest = auto() + #: run once, with the latest fire time + latest = auto() + #: submit one job for every accumulated fire time + all = auto() diff --git a/src/apscheduler/schedulers/__init__.py b/src/apscheduler/schedulers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py new file mode 100644 index 0000000..6c7ef8c --- /dev/null +++ b/src/apscheduler/schedulers/async_.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import os +import platform +from contextlib import AsyncExitStack +from datetime import datetime, timedelta, timezone +from logging import Logger, getLogger +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Type, Union +from uuid import uuid4 + +import anyio +from anyio import TASK_STATUS_IGNORED, Event, create_task_group, get_cancelled_exc_class +from anyio.abc import TaskGroup + +from ..abc import AsyncDataStore, DataStore, EventSource, Job, Schedule, Trigger +from ..adapters import AsyncDataStoreAdapter +from ..datastores.sync.memory import MemoryDataStore +from ..enums import RunState +from ..events import ( + AsyncEventHub, ScheduleAdded, SchedulerStarted, SchedulerStopped, ScheduleUpdated, + SubscriptionToken) +from ..marshalling import callable_to_ref +from ..policies import CoalescePolicy, ConflictPolicy +from ..structures import Task
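(Aside, not part of the patch: a quick round trip through the marshalling helpers above, using only stdlib objects so the references resolve without any project-specific code.)

    from datetime import date
    from math import floor

    ref = callable_to_ref(floor)            # 'math:floor'
    assert callable_from_ref(ref) is floor

    # Dates are marshalled to ISO 8601 strings and parsed back based on the string length
    assert unmarshal_date(marshal_date(date(2021, 8, 29))) == date(2021, 8, 29)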
+from ..workers.async_ import AsyncWorker + + +class AsyncScheduler(EventSource): + """An asynchronous (AnyIO based) scheduler implementation.""" + + _state: RunState = RunState.stopped + _wakeup_event: anyio.Event + _worker: Optional[AsyncWorker] = None + _task_group: Optional[TaskGroup] = None + + def __init__(self, data_store: Union[DataStore, AsyncDataStore] = None, *, + identity: Optional[str] = None, logger: Optional[Logger] = None, + start_worker: bool = True): + self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}' + self.logger = logger or getLogger(__name__) + self.start_worker = start_worker + self._tasks: Dict[str, Task] = {} + self._exit_stack = AsyncExitStack() + self._events = AsyncEventHub() + + data_store = data_store or MemoryDataStore() + if isinstance(data_store, DataStore): + self.data_store = AsyncDataStoreAdapter(data_store) + else: + self.data_store = data_store + + @property + def worker(self) -> Optional[AsyncWorker]: + return self._worker + + async def __aenter__(self): + self._state = RunState.starting + self._wakeup_event = anyio.Event() + await self._exit_stack.__aenter__() + await self._exit_stack.enter_async_context(self._events) + + # Initialize the data store + await self._exit_stack.enter_async_context(self.data_store) + relay_token = self._events.relay_events_from(self.data_store) + self._exit_stack.callback(self.data_store.unsubscribe, relay_token) + + # Wake up the scheduler if the data store emits a significant schedule event + wakeup_token = self.data_store.subscribe( + lambda event: self._wakeup_event.set(), {ScheduleAdded, ScheduleUpdated}) + self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token) + + # Start the built-in worker, if configured to do so + if self.start_worker: + self._worker = AsyncWorker(self.data_store) + await self._exit_stack.enter_async_context(self._worker) + + # Start the worker and return when it has signalled readiness or raised an exception + self._task_group = create_task_group() + await self._exit_stack.enter_async_context(self._task_group) + await self._task_group.start(self.run) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self._state = RunState.stopping + self._wakeup_event.set() + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + del self._task_group + del self._wakeup_event + + def subscribe(self, callback: Callable[[Event], Any], + event_types: Optional[Iterable[Type[Event]]] = None) -> SubscriptionToken: + return self._events.subscribe(callback, event_types) + + def unsubscribe(self, token: SubscriptionToken) -> None: + self._events.unsubscribe(token) + + def _get_taskdef(self, func_or_id: Union[str, Callable]) -> Task: + task_id = func_or_id if isinstance(func_or_id, str) else callable_to_ref(func_or_id) + taskdef = self._tasks.get(task_id) + if not taskdef: + if isinstance(func_or_id, str): + raise LookupError('no task found with ID {!r}'.format(func_or_id)) + else: + taskdef = self._tasks[task_id] = Task(id=task_id, func=func_or_id) + + return taskdef + + def define_task(self, func: Callable, task_id: Optional[str] = None, **kwargs): + if task_id is None: + task_id = callable_to_ref(func) + + task = Task(id=task_id, **kwargs) + if self._tasks.setdefault(task_id, task) is not task: + pass + + async def add_schedule( + self, task: Union[str, Callable], trigger: Trigger, *, id: Optional[str] = None, + args: Optional[Iterable] = None, kwargs: Optional[Mapping[str, Any]] = None, + coalesce: CoalescePolicy = CoalescePolicy.latest, + 
misfire_grace_time: Union[float, timedelta, None] = None, + tags: Optional[Iterable[str]] = None, + conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing + ) -> str: + id = id or str(uuid4()) + args = tuple(args or ()) + kwargs = dict(kwargs or {}) + tags = frozenset(tags or ()) + if isinstance(misfire_grace_time, (int, float)): + misfire_grace_time = timedelta(seconds=misfire_grace_time) + + taskdef = self._get_taskdef(task) + schedule = Schedule(id=id, task_id=taskdef.id, trigger=trigger, args=args, kwargs=kwargs, + coalesce=coalesce, misfire_grace_time=misfire_grace_time, tags=tags, + next_fire_time=trigger.next()) + await self.data_store.add_schedule(schedule, conflict_policy) + self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', taskdef, + trigger, schedule.next_fire_time) + return schedule.id + + async def remove_schedule(self, schedule_id: str) -> None: + await self.data_store.remove_schedules({schedule_id}) + + async def run(self, *, task_status=TASK_STATUS_IGNORED) -> None: + if self._state is not RunState.starting: + raise RuntimeError(f'This function cannot be called while the scheduler is in the ' + f'{self._state} state') + + # Signal that the scheduler has started + self._state = RunState.started + task_status.started() + self._events.publish(SchedulerStarted()) + + try: + while self._state is RunState.started: + schedules = await self.data_store.acquire_schedules(self.identity, 100) + now = datetime.now(timezone.utc) + for schedule in schedules: + # Look up the task definition + try: + taskdef = self._get_taskdef(schedule.task_id) + except LookupError: + self.logger.error('Cannot locate task definition %r for schedule %r – ' + 'removing schedule', schedule.task_id, schedule.id) + schedule.next_fire_time = None + continue + + # Calculate a next fire time for the schedule, if possible + fire_times = [schedule.next_fire_time] + calculate_next = schedule.trigger.next + while True: + try: + fire_time = calculate_next() + except Exception: + self.logger.exception( + 'Error computing next fire time for schedule %r of task %r – ' + 'removing schedule', schedule.id, taskdef.id) + break + + # Stop if the calculated fire time is in the future + if fire_time is None or fire_time > now: + schedule.next_fire_time = fire_time + break + + # Only keep all the fire times if coalesce policy = "all" + if schedule.coalesce is CoalescePolicy.all: + fire_times.append(fire_time) + elif schedule.coalesce is CoalescePolicy.latest: + fire_times[0] = fire_time + + # Add one or more jobs to the job queue + for fire_time in fire_times: + schedule.last_fire_time = fire_time + job = Job(taskdef.id, taskdef.func, schedule.args, schedule.kwargs, + schedule.id, fire_time, schedule.next_deadline, + schedule.tags) + await self.data_store.add_job(job) + + await self.data_store.release_schedules(self.identity, schedules) + + await self._wakeup_event.wait() + self._wakeup_event = anyio.Event() + except get_cancelled_exc_class(): + pass + except BaseException as exc: + self._state = RunState.stopped + self._events.publish(SchedulerStopped(exception=exc)) + raise + + self._state = RunState.stopped + self._events.publish(SchedulerStopped()) + + # async def stop(self, force: bool = False) -> None: + # self._running = False + # if self._worker: + # await self._worker.stop(force) + # + # if self._acquire_cancel_scope: + # self._acquire_cancel_scope.cancel() + # if force and self._task_group: + # self._task_group.cancel_scope.cancel() + # + # async def wait_until_stopped(self) -> None: 
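(Aside, not part of the patch: a hypothetical end-to-end sketch of the AsyncScheduler defined above. The task function, schedule id and trigger are invented; with start_worker left at its default, the scheduler's built-in worker picks up the jobs queued for due schedules.)

    import anyio

    from apscheduler.schedulers.async_ import AsyncScheduler
    from apscheduler.triggers.calendarinterval import CalendarIntervalTrigger

    def tick() -> None:
        print('tick')

    async def main() -> None:
        async with AsyncScheduler() as scheduler:
            await scheduler.add_schedule(tick, CalendarIntervalTrigger(days=1, hour=9),
                                         id='daily-tick')
            await anyio.sleep(5)  # keep the scheduler alive for a moment

    anyio.run(main)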
+ # if self._stop_event: + # await self._stop_event.wait() diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py new file mode 100644 index 0000000..6f330c9 --- /dev/null +++ b/src/apscheduler/schedulers/sync.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +import os +import platform +import threading +from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait +from contextlib import ExitStack +from datetime import datetime, timedelta, timezone +from logging import Logger, getLogger +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Type, Union +from uuid import uuid4 + +from ..abc import DataStore, EventSource, Trigger +from ..datastores.sync.memory import MemoryDataStore +from ..enums import RunState +from ..events import ( + Event, EventHub, ScheduleAdded, SchedulerStarted, SchedulerStopped, ScheduleUpdated, + SubscriptionToken) +from ..marshalling import callable_to_ref +from ..policies import CoalescePolicy, ConflictPolicy +from ..structures import Job, Schedule, Task +from ..workers.sync import Worker + + +class Scheduler(EventSource): + """A synchronous scheduler implementation.""" + + _state: RunState = RunState.stopped + _wakeup_event: threading.Event + _worker: Optional[Worker] = None + + def __init__(self, data_store: Optional[DataStore] = None, *, identity: Optional[str] = None, + logger: Optional[Logger] = None, start_worker: bool = True): + self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}' + self.logger = logger or getLogger(__name__) + self.start_worker = start_worker + self.data_store = data_store or MemoryDataStore() + self._tasks: Dict[str, Task] = {} + self._exit_stack = ExitStack() + self._executor = ThreadPoolExecutor(max_workers=1) + self._events = EventHub() + + @property + def state(self) -> RunState: + return self._state + + @property + def worker(self) -> Optional[Worker]: + return self._worker + + def __enter__(self) -> Scheduler: + self._state = RunState.starting + self._wakeup_event = threading.Event() + self._exit_stack.__enter__() + self._exit_stack.enter_context(self._events) + + # Initialize the data store + self._exit_stack.enter_context(self.data_store) + relay_token = self._events.relay_events_from(self.data_store) + self._exit_stack.callback(self.data_store.unsubscribe, relay_token) + + # Wake up the scheduler if the data store emits a significant schedule event + wakeup_token = self.data_store.subscribe( + lambda event: self._wakeup_event.set(), {ScheduleAdded, ScheduleUpdated}) + self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token) + + # Start the built-in worker, if configured to do so + if self.start_worker: + self._worker = Worker(self.data_store) + self._exit_stack.enter_context(self._worker) + + # Start the worker and return when it has signalled readiness or raised an exception + start_future: Future[None] = Future() + token = self._events.subscribe(start_future.set_result) + run_future = self._executor.submit(self.run) + try: + wait([start_future, run_future], return_when=FIRST_COMPLETED) + finally: + self._events.unsubscribe(token) + + if run_future.done(): + run_future.result() + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._state = RunState.stopping + self._wakeup_event.set() + self._executor.shutdown(wait=exc_type is None) + self._exit_stack.__exit__(exc_type, exc_val, exc_tb) + del self._wakeup_event + + def subscribe(self, callback: Callable[[Event], Any], + event_types: 
Optional[Iterable[Type[Event]]] = None) -> SubscriptionToken: + return self._events.subscribe(callback, event_types) + + def unsubscribe(self, token: SubscriptionToken) -> None: + self._events.unsubscribe(token) + + def _get_taskdef(self, func_or_id: Union[str, Callable]) -> Task: + task_id = func_or_id if isinstance(func_or_id, str) else callable_to_ref(func_or_id) + taskdef = self._tasks.get(task_id) + if not taskdef: + if isinstance(func_or_id, str): + raise LookupError('no task found with ID {!r}'.format(func_or_id)) + else: + taskdef = self._tasks[task_id] = Task(id=task_id, func=func_or_id) + + return taskdef + + def add_schedule( + self, task: Union[str, Callable], trigger: Trigger, *, id: Optional[str] = None, + args: Optional[Iterable] = None, kwargs: Optional[Mapping[str, Any]] = None, + coalesce: CoalescePolicy = CoalescePolicy.latest, + misfire_grace_time: Union[float, timedelta, None] = None, + tags: Optional[Iterable[str]] = None, + conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing + ) -> str: + id = id or str(uuid4()) + args = tuple(args or ()) + kwargs = dict(kwargs or {}) + tags = frozenset(tags or ()) + if isinstance(misfire_grace_time, (int, float)): + misfire_grace_time = timedelta(seconds=misfire_grace_time) + + taskdef = self._get_taskdef(task) + schedule = Schedule(id=id, task_id=taskdef.id, trigger=trigger, args=args, kwargs=kwargs, + coalesce=coalesce, misfire_grace_time=misfire_grace_time, tags=tags, + next_fire_time=trigger.next()) + self.data_store.add_schedule(schedule, conflict_policy) + self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', taskdef, + trigger, schedule.next_fire_time) + return schedule.id + + def remove_schedule(self, schedule_id: str) -> None: + self.data_store.remove_schedules({schedule_id}) + + def run(self) -> None: + if self._state is not RunState.starting: + raise RuntimeError(f'This function cannot be called while the scheduler is in the ' + f'{self._state} state') + + # Signal that the scheduler has started + self._state = RunState.started + self._events.publish(SchedulerStarted()) + + try: + while self._state is RunState.started: + schedules = self.data_store.acquire_schedules(self.identity, 100) + now = datetime.now(timezone.utc) + for schedule in schedules: + # Look up the task definition + try: + taskdef = self._get_taskdef(schedule.task_id) + except LookupError: + self.logger.error('Cannot locate task definition %r for schedule %r – ' + 'putting schedule on hold', schedule.task_id, + schedule.id) + schedule.next_fire_time = None + continue + + # Calculate a next fire time for the schedule, if possible + fire_times = [schedule.next_fire_time] + calculate_next = schedule.trigger.next + while True: + try: + fire_time = calculate_next() + except Exception: + self.logger.exception( + 'Error computing next fire time for schedule %r of task %r – ' + 'removing schedule', schedule.id, taskdef.id) + break + + # Stop if the calculated fire time is in the future + if fire_time is None or fire_time > now: + schedule.next_fire_time = fire_time + break + + # Only keep all the fire times if coalesce policy = "all" + if schedule.coalesce is CoalescePolicy.all: + fire_times.append(fire_time) + elif schedule.coalesce is CoalescePolicy.latest: + fire_times[0] = fire_time + + # Add one or more jobs to the job queue + for fire_time in fire_times: + schedule.last_fire_time = fire_time + job = Job(taskdef.id, taskdef.func, schedule.args, schedule.kwargs, + schedule.id, fire_time, schedule.next_deadline, + 
schedule.tags) + self.data_store.add_job(job) + + self.data_store.release_schedules(self.identity, schedules) + + self._wakeup_event.wait() + self._wakeup_event = threading.Event() + except BaseException as exc: + self._state = RunState.stopped + self._events.publish(SchedulerStopped(exception=exc)) + raise + + self._state = RunState.stopped + self._events.publish(SchedulerStopped()) + + # def stop(self) -> None: + # self.portal.call(self._scheduler.stop) + # + # def wait_until_stopped(self) -> None: + # self.portal.call(self._scheduler.wait_until_stopped) diff --git a/src/apscheduler/serializers/__init__.py b/src/apscheduler/serializers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/apscheduler/serializers/cbor.py b/src/apscheduler/serializers/cbor.py new file mode 100644 index 0000000..9fdf17b --- /dev/null +++ b/src/apscheduler/serializers/cbor.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass, field +from typing import Any, Dict + +from cbor2 import CBOREncodeTypeError, CBORTag, dumps, loads + +from ..abc import Serializer +from ..marshalling import marshal_object, unmarshal_object + + +@dataclass +class CBORSerializer(Serializer): + type_tag: int = 4664 + dump_options: Dict[str, Any] = field(default_factory=dict) + load_options: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + self.dump_options.setdefault('default', self._default_hook) + self.load_options.setdefault('tag_hook', self._tag_hook) + + def _default_hook(self, encoder, value): + if hasattr(value, '__getstate__'): + marshalled = marshal_object(value) + encoder.encode(CBORTag(self.type_tag, marshalled)) + else: + raise CBOREncodeTypeError(f'cannot serialize type {value.__class__.__name__}') + + def _tag_hook(self, decoder, tag: CBORTag, shareable_index: int = None): + if tag.tag == self.type_tag: + cls_ref, state = tag.value + return unmarshal_object(cls_ref, state) + + def serialize(self, obj) -> bytes: + return dumps(obj, **self.dump_options) + + def deserialize(self, serialized: bytes): + return loads(serialized, **self.load_options) diff --git a/src/apscheduler/serializers/json.py b/src/apscheduler/serializers/json.py new file mode 100644 index 0000000..7c2513b --- /dev/null +++ b/src/apscheduler/serializers/json.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass, field +from json import dumps, loads +from typing import Any, Dict + +from ..abc import Serializer +from ..marshalling import marshal_object, unmarshal_object + + +@dataclass +class JSONSerializer(Serializer): + magic_key: str = '_apscheduler_json' + dump_options: Dict[str, Any] = field(default_factory=dict) + load_options: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + self.dump_options['default'] = self._default_hook + self.load_options['object_hook'] = self._object_hook + + @classmethod + def _default_hook(cls, obj): + if hasattr(obj, '__getstate__'): + cls_ref, state = marshal_object(obj) + return {cls.magic_key: [cls_ref, state]} + + raise TypeError(f'Object of type {obj.__class__.__name__!r} is not JSON serializable') + + @classmethod + def _object_hook(cls, obj_state: Dict[str, Any]): + if cls.magic_key in obj_state: + ref, *rest = obj_state[cls.magic_key] + return unmarshal_object(ref, *rest) + + return obj_state + + def serialize(self, obj) -> bytes: + return dumps(obj, ensure_ascii=False, **self.dump_options).encode('utf-8') + + def deserialize(self, serialized: bytes): + return loads(serialized, **self.load_options) + + def serialize_to_unicode(self, obj) -> 
str: + return dumps(obj, ensure_ascii=False, **self.dump_options) + + def deserialize_from_unicode(self, serialized: str): + return loads(serialized, **self.load_options) diff --git a/src/apscheduler/serializers/pickle.py b/src/apscheduler/serializers/pickle.py new file mode 100644 index 0000000..b429990 --- /dev/null +++ b/src/apscheduler/serializers/pickle.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from pickle import dumps, loads + +from ..abc import Serializer + + +@dataclass(frozen=True) +class PickleSerializer(Serializer): + protocol: int = 4 + + def serialize(self, obj) -> bytes: + return dumps(obj, self.protocol) + + def deserialize(self, serialized: bytes): + return loads(serialized) diff --git a/src/apscheduler/structures.py b/src/apscheduler/structures.py new file mode 100644 index 0000000..34a4485 --- /dev/null +++ b/src/apscheduler/structures.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Any, Callable, Dict, FrozenSet, Optional +from uuid import UUID, uuid4 + +from . import abc +from .policies import CoalescePolicy + + +@dataclass(eq=False) +class Task: + id: str + func: Callable + max_instances: Optional[int] = None + metadata_arg: Optional[str] = None + stateful: bool = False + misfire_grace_time: Optional[timedelta] = None + + def __eq__(self, other) -> bool: + if isinstance(other, Task): + return self.id == other.id + + return NotImplemented + + def __hash__(self) -> int: + return hash(self.id) + + +@dataclass(eq=False) +class Schedule: + id: str + task_id: str + trigger: abc.Trigger + args: tuple + kwargs: Dict[str, Any] + coalesce: CoalescePolicy + misfire_grace_time: Optional[timedelta] + tags: FrozenSet[str] + next_fire_time: Optional[datetime] = None + last_fire_time: Optional[datetime] = field(init=False, default=None) + + @property + def next_deadline(self) -> Optional[datetime]: + if self.next_fire_time and self.misfire_grace_time: + return self.next_fire_time + self.misfire_grace_time + + return None + + def __eq__(self, other) -> bool: + if isinstance(other, Schedule): + return self.id == other.id + + return NotImplemented + + def __hash__(self) -> int: + return hash(self.id) + + +@dataclass(eq=False) +class Job: + id: UUID = field(init=False, default_factory=uuid4) + task_id: str + func: Callable + args: tuple + kwargs: Dict[str, Any] + schedule_id: Optional[str] = None + scheduled_fire_time: Optional[datetime] = None + start_deadline: Optional[datetime] = None + tags: FrozenSet[str] = field(default_factory=frozenset) + started_at: Optional[datetime] = field(init=False, default=None) + + def __eq__(self, other) -> bool: + if isinstance(other, Job): + return self.id == other.id + + return NotImplemented + + def __hash__(self) -> int: + return hash(self.id) diff --git a/src/apscheduler/triggers/__init__.py b/src/apscheduler/triggers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/apscheduler/triggers/calendarinterval.py b/src/apscheduler/triggers/calendarinterval.py new file mode 100644 index 0000000..c6d63da --- /dev/null +++ b/src/apscheduler/triggers/calendarinterval.py @@ -0,0 +1,145 @@ +from datetime import date, datetime, time, timedelta, tzinfo +from typing import Optional, Union + +from ..abc import Trigger +from ..marshalling import marshal_date, marshal_timezone, unmarshal_date, unmarshal_timezone +from ..util import timezone_repr +from ..validators import as_date, as_timezone, require_state_version + + 
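(Aside, not part of the patch: the serializers above share the bytes-in, bytes-out Serializer interface, so they are interchangeable; JSONSerializer and CBORSerializer additionally round-trip any object that implements __getstate__()/__setstate__() via the marshalling helpers. A small sketch with plain data:)

    from apscheduler.serializers.json import JSONSerializer
    from apscheduler.serializers.pickle import PickleSerializer

    payload = {'id': 'demo', 'count': 3}
    for serializer in (JSONSerializer(), PickleSerializer()):
        data = serializer.serialize(payload)           # bytes
        assert serializer.deserialize(data) == payload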
+class CalendarIntervalTrigger(Trigger): + """ + Runs the task on specified calendar-based intervals always at the same exact time of day. + + When calculating the next date, the ``years`` and ``months`` parameters are first added to the + previous date while keeping the day of the month constant. This is repeated until the resulting + date is valid. After that, the ``weeks`` and ``days`` parameters are added to that date. + Finally, the date is combined with the given time (hour, minute, second) to form the final + datetime. + + This means that if the ``days`` or ``weeks`` parameters are not used, the task will always be + executed on the same day of the month at the same wall clock time, assuming the date and time + are valid. + + If the resulting datetime is invalid due to a daylight saving forward shift, the date is + discarded and the process moves on to the next date. If instead the datetime is ambiguous due + to a backward DST shift, the earlier of the two resulting datetimes is used. + + If no previous run time is specified when requesting a new run time (like when starting for the + first time or resuming after being paused), ``start_date`` is used as a reference and the next + valid datetime equal to or later than the current time will be returned. Otherwise, the next + valid datetime starting from the previous run time is returned, even if it's in the past. + + .. warning:: Be wary of setting a start date near the end of the month (29. – 31.) if you have + ``months`` specified in your interval, as this will skip the months where those days do not + exist. Likewise, setting the start date on the leap day (February 29th) and having + ``years`` defined may cause some years to be skipped. + + Users are also discouraged from using a time inside the target timezone's DST switching + period (typically around 2 am) since a date could either be skipped or repeated due to the + specified wall clock time either occurring twice or not at all. 
+ + :param years: number of years to wait + :param months: number of months to wait + :param weeks: number of weeks to wait + :param days: number of days to wait + :param hour: hour to run the task at + :param minute: minute to run the task at + :param second: second to run the task at + :param start_date: first date to trigger on (defaults to current date if omitted) + :param end_date: latest possible date to trigger on + :param timezone: time zone to use for calculating the next fire time + """ + + __slots__ = ('years', 'months', 'weeks', 'days', 'start_date', 'end_date', 'timezone', '_time', + '_last_fire_date') + + def __init__(self, *, years: int = 0, months: int = 0, weeks: int = 0, days: int = 0, + hour: int = 0, minute: int = 0, second: int = 0, + start_date: Union[date, str, None] = None, + end_date: Union[date, str, None] = None, + timezone: Union[str, tzinfo] = 'local'): + self.years = years + self.months = months + self.weeks = weeks + self.days = days + self.timezone = as_timezone(timezone) + self.start_date = as_date(start_date) or datetime.now(self.timezone).date() + self.end_date = as_date(end_date) + self._time = time(hour, minute, second, tzinfo=timezone) + self._last_fire_date: Optional[date] = None + + if self.years == self.months == self.weeks == self.days == 0: + raise ValueError('interval must be at least 1 day long') + + if self.start_date and self.end_date and self.start_date > self.end_date: + raise ValueError('end_date cannot be earlier than start_date') + + def next(self) -> Optional[datetime]: + previous_date: date = self._last_fire_date + while True: + if previous_date: + year, month = previous_date.year, previous_date.month + while True: + month += self.months + year += self.years + (month - 1) // 12 + month = (month - 1) % 12 + 1 + try: + next_date = date(year, month, previous_date.day) + except ValueError: + pass # Nonexistent date + else: + next_date += timedelta(self.days + self.weeks * 7) + break + else: + next_date = self.start_date + + # Don't return any date past end_date + if self.end_date and next_date > self.end_date: + return None + + # Combine the date with the designated time and normalize the result + timestamp = datetime.combine(next_date, self._time).timestamp() + next_time = datetime.fromtimestamp(timestamp, self.timezone) + + # Check if the time is off due to normalization and a forward DST shift + if next_time.time() != self._time: + previous_date = next_time.date() + else: + self._last_fire_date = next_date + return next_time + + def __getstate__(self): + return { + 'version': 1, + 'interval': [self.years, self.months, self.weeks, self.days], + 'time': [self._time.hour, self._time.minute, self._time.second], + 'start_date': marshal_date(self.start_date), + 'end_date': marshal_date(self.end_date), + 'timezone': marshal_timezone(self.timezone), + 'last_fire_date': marshal_date(self._last_fire_date) + } + + def __setstate__(self, state): + require_state_version(self, state, 1) + self.years, self.months, self.weeks, self.days = state['interval'] + self.start_date = unmarshal_date(state['start_date']) + self.end_date = unmarshal_date(state['end_date']) + self.timezone = unmarshal_timezone(state['timezone']) + self._time = time(*state['time'], tzinfo=self.timezone) + self._last_fire_date = unmarshal_date(state['last_fire_date']) + + def __repr__(self): + fields = [] + for field in 'years', 'months', 'weeks', 'days': + value = getattr(self, field) + if value > 0: + fields.append(f'{field}={value}') + + 
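(Aside, not part of the patch: iterating the trigger directly shows the calendar arithmetic described in the docstring above; the start date and timezone below are invented. Per the docstring, months that lack the requested day of month are skipped.)

    from zoneinfo import ZoneInfo

    from apscheduler.triggers.calendarinterval import CalendarIntervalTrigger

    trigger = CalendarIntervalTrigger(months=1, hour=9, start_date='2021-01-31',
                                      timezone=ZoneInfo('UTC'))
    print(trigger.next())  # 2021-01-31 09:00:00+00:00
    print(trigger.next())  # 2021-03-31 09:00:00+00:00; February has no 31st and is skipped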
fields.append(f'time={self._time.isoformat()!r}') + fields.append(f"start_date='{self.start_date}'") + if self.end_date: + fields.append(f"end_date='{self.end_date}'") + + fields.append(f'timezone={timezone_repr(self.timezone)!r}') + return f'{self.__class__.__name__}({", ".join(fields)})' diff --git a/src/apscheduler/triggers/combining.py b/src/apscheduler/triggers/combining.py new file mode 100644 index 0000000..04363fd --- /dev/null +++ b/src/apscheduler/triggers/combining.py @@ -0,0 +1,141 @@ +from abc import abstractmethod +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional, Sequence, Union + +from ..abc import Trigger +from ..exceptions import MaxIterationsReached +from ..marshalling import marshal_object, unmarshal_object +from ..validators import as_list, as_positive_integer, as_timedelta, require_state_version + + +class BaseCombiningTrigger(Trigger): + __slots__ = 'triggers', '_next_fire_times' + + def __init__(self, triggers: Sequence[Trigger]): + self.triggers = as_list(triggers, Trigger, 'triggers') + self._next_fire_times: List[Optional[datetime]] = [] + + def __getstate__(self) -> Dict[str, Any]: + return { + 'version': 1, + 'triggers': [marshal_object(trigger) for trigger in self.triggers], + 'next_fire_times': self._next_fire_times + } + + @abstractmethod + def __setstate__(self, state: Dict[str, Any]) -> None: + self.triggers = [unmarshal_object(*trigger_state) for trigger_state in state['triggers']] + self._next_fire_times = state['next_fire_times'] + + +class AndTrigger(BaseCombiningTrigger): + """ + Fires on times produced by the enclosed triggers whenever the fire times are within the given + threshold. + + If the produced fire times are not within the given threshold of each other, the trigger(s) + that produced the earliest fire time will be asked for their next fire time and the iteration + is restarted. If instead all of the triggers agree on a fire time, all the triggers are asked + for their next fire times and the earliest of the previously produced fire times will be + returned. + + This trigger will be finished when any of the enclosed trigger has finished. 
+ + :param triggers: triggers to combine + :param threshold: maximum time difference between the next fire times of the triggers in order + for the earliest of them to be returned from :meth:`next` (in seconds, or as timedelta) + :param max_iterations: maximum number of iterations of fire time calculations before giving up + """ + + __slots__ = 'threshold', 'max_iterations' + + def __init__(self, triggers: Sequence[Trigger], *, threshold: Union[float, timedelta] = 1, + max_iterations: Optional[int] = 10000): + super().__init__(triggers) + self.threshold = as_timedelta(threshold, 'threshold') + self.max_iterations = as_positive_integer(max_iterations, 'max_iterations') + + def next(self) -> Optional[datetime]: + if not self._next_fire_times: + # Fill out the fire times on the first run + self._next_fire_times = [t.next() for t in self.triggers] + + for _ in range(self.max_iterations): + # Find the earliest and latest fire times + earliest_fire_time: Optional[datetime] = None + latest_fire_time: Optional[datetime] = None + for fire_time in self._next_fire_times: + # If any of the fire times is None, this trigger is finished + if fire_time is None: + return None + + if earliest_fire_time is None or earliest_fire_time > fire_time: + earliest_fire_time = fire_time + + if latest_fire_time is None or latest_fire_time < fire_time: + latest_fire_time = fire_time + + # Replace all the fire times that were within the threshold + for i, trigger in enumerate(self.triggers): + if self._next_fire_times[i] - earliest_fire_time <= self.threshold: + self._next_fire_times[i] = self.triggers[i].next() + + # If all the fire times were within the threshold, return the earliest one + if latest_fire_time - earliest_fire_time <= self.threshold: + self._next_fire_times = [t.next() for t in self.triggers] + return earliest_fire_time + else: + raise MaxIterationsReached + + def __getstate__(self) -> Dict[str, Any]: + state = super().__getstate__() + state['threshold'] = self.threshold.total_seconds() + state['max_iterations'] = self.max_iterations + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + require_state_version(self, state, 1) + super().__setstate__(state) + self.threshold = timedelta(seconds=state['threshold']) + self.max_iterations = state['max_iterations'] + + def __repr__(self): + return f'{self.__class__.__name__}({self.triggers}, ' \ + f'threshold={self.threshold.total_seconds()}, max_iterations={self.max_iterations})' + + +class OrTrigger(BaseCombiningTrigger): + """ + Fires on every fire time of every trigger in chronological order. + If two or more triggers produce the same fire time, it will only be used once. + + This trigger will be finished when none of the enclosed triggers can produce any new fire + times. 
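(Aside, not part of the patch: a hypothetical combination of triggers. With this AndTrigger, the child triggers must produce fire times within the threshold of each other before a time is emitted, so the children below are chosen to coincide exactly; OrTrigger simply merges its children's fire times in chronological order. CronTrigger is defined further down in this patch.)

    from apscheduler.triggers.combining import AndTrigger, OrTrigger
    from apscheduler.triggers.cron import CronTrigger

    # 10:00 on the first day of the month, but only when that day falls on a weekend
    and_trigger = AndTrigger([CronTrigger(day=1, hour=10),
                              CronTrigger(day_of_week='sat,sun', hour=10)])
    print(and_trigger.next())

    # Every day at both 07:00 and 19:00, in chronological order
    or_trigger = OrTrigger([CronTrigger(hour=7), CronTrigger(hour=19)])
    print(or_trigger.next())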
+ + :param triggers: triggers to combine + """ + + __slots__ = () + + def next(self) -> Optional[datetime]: + # Fill out the fire times on the first run + if not self._next_fire_times: + self._next_fire_times = [t.next() for t in self.triggers] + + # Find out the earliest of the fire times + earliest_time: Optional[datetime] = min([fire_time for fire_time in self._next_fire_times + if fire_time is not None], default=None) + if earliest_time is not None: + # Generate new fire times for the trigger(s) that generated the earliest fire time + for i, fire_time in enumerate(self._next_fire_times): + if fire_time == earliest_time: + self._next_fire_times[i] = self.triggers[i].next() + + return earliest_time + + def __setstate__(self, state: Dict[str, Any]) -> None: + require_state_version(self, state, 1) + super().__setstate__(state) + + def __repr__(self): + return f'{self.__class__.__name__}({self.triggers})' diff --git a/src/apscheduler/triggers/cron/__init__.py b/src/apscheduler/triggers/cron/__init__.py new file mode 100644 index 0000000..682cd04 --- /dev/null +++ b/src/apscheduler/triggers/cron/__init__.py @@ -0,0 +1,216 @@ +from datetime import datetime, timedelta, tzinfo +from typing import ClassVar, List, Optional, Sequence, Tuple, Union + +from ...abc import Trigger +from ...marshalling import marshal_date, marshal_timezone, unmarshal_date, unmarshal_timezone +from ...util import timezone_repr +from ...validators import as_aware_datetime, as_timezone, require_state_version +from .fields import ( + DEFAULT_VALUES, BaseField, DayOfMonthField, DayOfWeekField, MonthField, WeekField) + + +class CronTrigger(Trigger): + """ + Triggers when current time matches all specified time constraints, similarly to how the UNIX + cron scheduler works. + + :param year: 4-digit year + :param month: month (1-12) + :param day: day of the (1-31) + :param week: ISO week (1-53) + :param day_of_week: number or name of weekday (0-7 or sun,mon,tue,wed,thu,fri,sat,sun) + :param hour: hour (0-23) + :param minute: minute (0-59) + :param second: second (0-59) + :param start_time: earliest possible date/time to trigger on (defaults to current time) + :param end_time: latest possible date/time to trigger on + :param timezone: time zone to use for the date/time calculations + (defaults to the local timezone) + + .. note:: The first weekday is always **monday**. 
+ """ + + __slots__ = 'timezone', 'start_time', 'end_time', '_fields', '_last_fire_time' + + FIELDS_MAP: ClassVar[List] = [ + ('year', BaseField), + ('month', MonthField), + ('day', DayOfMonthField), + ('week', WeekField), + ('day_of_week', DayOfWeekField), + ('hour', BaseField), + ('minute', BaseField), + ('second', BaseField) + ] + + def __init__(self, *, year: Union[int, str, None] = None, month: Union[int, str, None] = None, + day: Union[int, str, None] = None, week: Union[int, str, None] = None, + day_of_week: Union[int, str, None] = None, hour: Union[int, str, None] = None, + minute: Union[int, str, None] = None, second: Union[int, str, None] = None, + start_time: Union[datetime, str, None] = None, + end_time: Union[datetime, str, None] = None, + timezone: Union[str, tzinfo, None] = None): + self.timezone = as_timezone(timezone) + self.start_time = (as_aware_datetime(start_time, self.timezone) + or datetime.now(self.timezone)) + self.end_time = as_aware_datetime(end_time, self.timezone) + self._set_fields([year, month, day, week, day_of_week, hour, minute, second]) + self._last_fire_time: Optional[datetime] = None + + def _set_fields(self, values: Sequence[Union[int, str, None]]) -> None: + self._fields = [] + assigned_values = {field_name: value + for (field_name, _), value in zip(self.FIELDS_MAP, values) + if value is not None} + for field_name, field_class in self.FIELDS_MAP: + exprs = assigned_values.pop(field_name, None) + if exprs is None: + exprs = '*' if assigned_values else DEFAULT_VALUES[field_name] + + field = field_class(field_name, exprs) + self._fields.append(field) + + @classmethod + def from_crontab(cls, expr: str, timezone: Union[str, tzinfo] = 'local') -> 'CronTrigger': + """ + Create a :class:`~CronTrigger` from a standard crontab expression. + + See https://en.wikipedia.org/wiki/Cron for more information on the format accepted here. + + :param expr: minute, hour, day of month, month, day of week + :param timezone: time zone to use for the date/time calculations + (defaults to local timezone if omitted) + + """ + values = expr.split() + if len(values) != 5: + raise ValueError(f'Wrong number of fields; got {len(values)}, expected 5') + + return cls(minute=values[0], hour=values[1], day=values[2], month=values[3], + day_of_week=values[4], timezone=timezone) + + def _increment_field_value(self, dateval: datetime, fieldnum: int) -> Tuple[datetime, int]: + """ + Increments the designated field and resets all less significant fields to their minimum + values. 
+ + :return: a tuple containing the new date, and the number of the field that was actually + incremented + """ + + values = {} + i = 0 + while i < len(self._fields): + field = self._fields[i] + if not field.real: + if i == fieldnum: + fieldnum -= 1 + i -= 1 + else: + i += 1 + continue + + if i < fieldnum: + values[field.name] = field.get_value(dateval) + i += 1 + elif i > fieldnum: + values[field.name] = field.get_min(dateval) + i += 1 + else: + value = field.get_value(dateval) + maxval = field.get_max(dateval) + if value == maxval: + fieldnum -= 1 + i -= 1 + else: + values[field.name] = value + 1 + i += 1 + + difference = datetime(**values) - dateval.replace(tzinfo=None) + dateval = datetime.fromtimestamp(dateval.timestamp() + difference.total_seconds(), + self.timezone) + return dateval, fieldnum + # return datetime_normalize(dateval + difference), fieldnum + + def _set_field_value(self, dateval, fieldnum, new_value): + values = {} + for i, field in enumerate(self._fields): + if field.real: + if i < fieldnum: + values[field.name] = field.get_value(dateval) + elif i > fieldnum: + values[field.name] = field.get_min(dateval) + else: + values[field.name] = new_value + + return datetime(**values, tzinfo=self.timezone) + + def next(self) -> Optional[datetime]: + if self._last_fire_time: + start_time = self._last_fire_time + timedelta(microseconds=1) + else: + start_time = self.start_time + + fieldnum = 0 + next_time = datetime_ceil(start_time).astimezone(self.timezone) + while 0 <= fieldnum < len(self._fields): + field = self._fields[fieldnum] + curr_value = field.get_value(next_time) + next_value = field.get_next_value(next_time) + + if next_value is None: + # No valid value was found + next_time, fieldnum = self._increment_field_value(next_time, fieldnum - 1) + elif next_value > curr_value: + # A valid, but higher than the starting value, was found + if field.real: + next_time = self._set_field_value(next_time, fieldnum, next_value) + fieldnum += 1 + else: + next_time, fieldnum = self._increment_field_value(next_time, fieldnum) + else: + # A valid value was found, no changes necessary + fieldnum += 1 + + # Return if the date has rolled past the end date + if self.end_time and next_time > self.end_time: + return None + + if fieldnum >= 0: + self._last_fire_time = next_time + return next_time + + def __getstate__(self): + return { + 'version': 1, + 'timezone': marshal_timezone(self.timezone), + 'fields': [str(f) for f in self._fields], + 'start_time': marshal_date(self.start_time), + 'end_time': marshal_date(self.end_time), + 'last_fire_time': marshal_date(self._last_fire_time) + } + + def __setstate__(self, state): + require_state_version(self, state, 1) + self.timezone = unmarshal_timezone(state['timezone']) + self.start_time = unmarshal_date(state['start_time']) + self.end_time = unmarshal_date(state['end_time']) + self._last_fire_time = unmarshal_date(state['last_fire_time']) + self._set_fields(state['fields']) + + def __repr__(self): + fields = [f'{field.name}={str(field)!r}' for field in self._fields] + fields.append(f'start_time={self.start_time.isoformat()!r}') + if self.end_time: + fields.append(f'end_time={self.end_time.isoformat()!r}') + + fields.append(f'timezone={timezone_repr(self.timezone)!r}') + return f'CronTrigger({", ".join(fields)})' + + +def datetime_ceil(dateval: datetime) -> datetime: + """Round the given datetime object upwards.""" + if dateval.microsecond > 0: + return dateval + timedelta(seconds=1, microseconds=-dateval.microsecond) + + return dateval diff --git 
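(Aside, not part of the patch: a small usage sketch of the CronTrigger defined above; the crontab expression is invented.)

    from apscheduler.triggers.cron import CronTrigger

    # Standard five-field crontab order: minute, hour, day of month, month, day of week
    trigger = CronTrigger.from_crontab('30 7 * * mon-fri')
    for _ in range(3):
        print(trigger.next())  # three consecutive weekday mornings at 07:30 local time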
a/src/apscheduler/triggers/cron/expressions.py b/src/apscheduler/triggers/cron/expressions.py new file mode 100644 index 0000000..b2d098a --- /dev/null +++ b/src/apscheduler/triggers/cron/expressions.py @@ -0,0 +1,209 @@ +"""This module contains the expressions applicable for CronTrigger's fields.""" + +import re +from calendar import monthrange +from datetime import datetime +from typing import Optional, Union + +from ...validators import as_int + +WEEKDAYS = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun'] +MONTHS = ['jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct', 'nov', 'dec'] + + +def get_weekday_index(weekday: str) -> int: + try: + return WEEKDAYS.index(weekday.lower()) + except ValueError: + raise ValueError(f'Invalid weekday name {weekday!r}') from None + + +class AllExpression: + __slots__ = 'step' + + value_re = re.compile(r'\*(?:/(?P\d+))?$') + + def __init__(self, step: Union[str, int, None] = None): + self.step = as_int(step) + if self.step == 0: + raise ValueError('Step must be higher than 0') + + def validate_range(self, field_name: str, min_value: int, max_value: int) -> None: + value_range = max_value - min_value + if self.step and self.step > value_range: + raise ValueError(f'the step value ({self.step}) is higher than the total range of the ' + f'expression ({value_range})') + + def get_next_value(self, dateval: datetime, field) -> Optional[int]: + start = field.get_value(dateval) + minval = field.get_min(dateval) + maxval = field.get_max(dateval) + start = max(start, minval) + + if not self.step: + nextval = start + else: + distance_to_next = (self.step - (start - minval)) % self.step + nextval = start + distance_to_next + + return nextval if nextval <= maxval else None + + def __str__(self): + return f'*/{self.step}' if self.step else '*' + + +class RangeExpression(AllExpression): + __slots__ = 'first', 'last' + + value_re = re.compile(r'(?P\d+)(?:-(?P\d+))?(?:/(?P\d+))?$') + + def __init__(self, first: Union[str, int], last: Union[str, int, None] = None, + step: Union[str, int, None] = None): + super().__init__(step) + self.first = as_int(first) + self.last = as_int(last) + + if self.last is None and self.step is None: + self.last = self.first + if self.last is not None and self.first > self.last: + raise ValueError('The minimum value in a range must not be higher than the maximum') + + def validate_range(self, field_name: str, min_value: int, max_value: int) -> None: + super().validate_range(field_name, min_value, max_value) + if self.first < min_value: + raise ValueError(f'the first value ({self.first}) is lower than the minimum value ' + f'({min_value})') + if self.last is not None and self.last > max_value: + raise ValueError(f'the last value ({self.last}) is higher than the maximum value ' + f'({max_value})') + value_range = (self.last or max_value) - self.first + if self.step and self.step > value_range: + raise ValueError(f'the step value ({self.step}) is higher than the total range of the ' + f'expression ({value_range})') + + def get_next_value(self, date, field): + startval = field.get_value(date) + minval = field.get_min(date) + maxval = field.get_max(date) + + # Apply range limits + minval = max(minval, self.first) + maxval = min(maxval, self.last) if self.last is not None else maxval + nextval = max(minval, startval) + + # Apply the step if defined + if self.step: + distance_to_next = (self.step - (nextval - minval)) % self.step + nextval += distance_to_next + + return nextval if nextval <= maxval else None + + def __str__(self): + 
if self.last != self.first and self.last is not None: + rangeval = f'{self.first}-{self.last}' + else: + rangeval = str(self.first) + + if self.step: + return f'{rangeval}/{self.step}' + + return rangeval + + +class MonthRangeExpression(RangeExpression): + __slots__ = () + + value_re = re.compile(r'(?P[a-z]+)(?:-(?P[a-z]+))?', re.IGNORECASE) + + def __init__(self, first, last=None): + try: + first_num = MONTHS.index(first.lower()) + 1 + except ValueError: + raise ValueError(f'Invalid month name {first!r}') from None + + if last: + try: + last_num = MONTHS.index(last.lower()) + 1 + except ValueError: + raise ValueError(f'Invalid month name {last!r}') from None + else: + last_num = None + + super().__init__(first_num, last_num) + + def __str__(self): + if self.last != self.first and self.last is not None: + return f'{MONTHS[self.first - 1]}-{MONTHS[self.last - 1]}' + + return MONTHS[self.first - 1] + + +class WeekdayRangeExpression(RangeExpression): + __slots__ = () + + value_re = re.compile(r'(?P[a-z]+)(?:-(?P[a-z]+))?', re.IGNORECASE) + + def __init__(self, first: str, last: Optional[str] = None): + first_num = get_weekday_index(first) + last_num = get_weekday_index(last) if last else None + super().__init__(first_num, last_num) + + def __str__(self): + if self.last != self.first and self.last is not None: + return f'{WEEKDAYS[self.first]}-{WEEKDAYS[self.last]}' + + return WEEKDAYS[self.first] + + +class WeekdayPositionExpression(AllExpression): + __slots__ = 'option_num', 'weekday' + + options = ['1st', '2nd', '3rd', '4th', '5th', 'last'] + value_re = re.compile(r'(?P%s) +(?P(?:\d+|\w+))' % + '|'.join(options), re.IGNORECASE) + + def __init__(self, option_name: str, weekday_name: str): + super().__init__(None) + self.option_num = self.options.index(option_name.lower()) + try: + self.weekday = WEEKDAYS.index(weekday_name.lower()) + except ValueError: + raise ValueError(f'Invalid weekday name {weekday_name!r}') from None + + def get_next_value(self, dateval: datetime, field) -> Optional[int]: + # Figure out the weekday of the month's first day and the number of days in that month + first_day_wday, last_day = monthrange(dateval.year, dateval.month) + + # Calculate which day of the month is the first of the target weekdays + first_hit_day = self.weekday - first_day_wday + 1 + if first_hit_day <= 0: + first_hit_day += 7 + + # Calculate what day of the month the target weekday would be + if self.option_num < 5: + target_day = first_hit_day + self.option_num * 7 + else: + target_day = first_hit_day + ((last_day - first_hit_day) // 7) * 7 + + if last_day >= target_day >= dateval.day: + return target_day + else: + return None + + def __str__(self): + return f'{self.options[self.option_num]} {WEEKDAYS[self.weekday]}' + + +class LastDayOfMonthExpression(AllExpression): + __slots__ = () + + value_re = re.compile(r'last', re.IGNORECASE) + + def __init__(self): + super().__init__(None) + + def get_next_value(self, dateval: datetime, field): + return monthrange(dateval.year, dateval.month)[1] + + def __str__(self): + return 'last' diff --git a/src/apscheduler/triggers/cron/fields.py b/src/apscheduler/triggers/cron/fields.py new file mode 100644 index 0000000..1e19bb9 --- /dev/null +++ b/src/apscheduler/triggers/cron/fields.py @@ -0,0 +1,130 @@ +"""Fields represent CronTrigger options which map to :class:`~datetime.datetime` fields.""" + +import re +from calendar import monthrange +from datetime import datetime +from typing import Any, ClassVar, List, Optional, Sequence, Union + +from .expressions import 
( + WEEKDAYS, AllExpression, LastDayOfMonthExpression, MonthRangeExpression, RangeExpression, + WeekdayPositionExpression, WeekdayRangeExpression, get_weekday_index) + +MIN_VALUES = {'year': 1970, 'month': 1, 'day': 1, 'week': 1, 'day_of_week': 0, 'hour': 0, + 'minute': 0, 'second': 0} +MAX_VALUES = {'year': 9999, 'month': 12, 'day': 31, 'week': 53, 'day_of_week': 7, 'hour': 23, + 'minute': 59, 'second': 59} +DEFAULT_VALUES = {'year': '*', 'month': 1, 'day': 1, 'week': '*', 'day_of_week': '*', 'hour': 0, + 'minute': 0, 'second': 0} +SEPARATOR = re.compile(' *, *') + + +class BaseField: + __slots__ = 'name', 'expressions' + + real: ClassVar[bool] = True + compilers: ClassVar[Any] = (AllExpression, RangeExpression) + + def __init_subclass__(cls, real: bool = True, extra_compilers: Sequence = ()): + cls.real = real + if extra_compilers: + cls.compilers += extra_compilers + + def __init__(self, name: str, exprs: Union[int, str]): + self.name = name + self.expressions: List = [] + for expr in SEPARATOR.split(str(exprs).strip()): + self.append_expression(expr) + + def get_min(self, dateval: datetime) -> int: + return MIN_VALUES[self.name] + + def get_max(self, dateval: datetime) -> int: + return MAX_VALUES[self.name] + + def get_value(self, dateval: datetime) -> int: + return getattr(dateval, self.name) + + def get_next_value(self, dateval: datetime) -> Optional[int]: + smallest = None + for expr in self.expressions: + value = expr.get_next_value(dateval, self) + if smallest is None or (value is not None and value < smallest): + smallest = value + + return smallest + + def append_expression(self, expr: str) -> None: + for compiler in self.compilers: + match = compiler.value_re.match(expr) + if match: + compiled_expr = compiler(**match.groupdict()) + + try: + compiled_expr.validate_range(self.name, MIN_VALUES[self.name], + MAX_VALUES[self.name]) + except ValueError as exc: + raise ValueError(f'Error validating expression {expr!r}: {exc}') from exc + + self.expressions.append(compiled_expr) + return + + raise ValueError(f'Unrecognized expression {expr!r} for field {self.name!r}') + + def __str__(self): + expr_strings = (str(e) for e in self.expressions) + return ','.join(expr_strings) + + +class WeekField(BaseField, real=False): + __slots__ = () + + def get_value(self, dateval: datetime) -> int: + return dateval.isocalendar()[1] + + +class DayOfMonthField(BaseField, + extra_compilers=(WeekdayPositionExpression, LastDayOfMonthExpression)): + __slots__ = () + + def get_max(self, dateval: datetime) -> int: + return monthrange(dateval.year, dateval.month)[1] + + +class DayOfWeekField(BaseField, real=False, extra_compilers=(WeekdayRangeExpression,)): + __slots__ = () + + def append_expression(self, expr: str) -> None: + # Convert numeric weekday expressions into textual ones + match = RangeExpression.value_re.match(expr) + if match: + groups = match.groups() + first = int(groups[0]) - 1 + first = 6 if first < 0 else first + if groups[1]: + last = int(groups[1]) - 1 + last = 6 if last < 0 else last + else: + last = first + + expr = f'{WEEKDAYS[first]}-{WEEKDAYS[last]}' + + # For expressions like Sun-Tue or Sat-Mon, add two expressions that together cover the + # expected weekdays + match = WeekdayRangeExpression.value_re.match(expr) + if match and match.groups()[1]: + groups = match.groups() + first_index = get_weekday_index(groups[0]) + last_index = get_weekday_index(groups[1]) + if first_index > last_index: + super().append_expression(f'{WEEKDAYS[0]}-{groups[1]}') + 
super().append_expression(f'{groups[0]}-{WEEKDAYS[-1]}') + return + + super().append_expression(expr) + + def get_value(self, dateval: datetime) -> int: + return dateval.weekday() + + +class MonthField(BaseField, extra_compilers=(MonthRangeExpression,)): + __slots__ = () diff --git a/src/apscheduler/triggers/date.py b/src/apscheduler/triggers/date.py new file mode 100644 index 0000000..636d978 --- /dev/null +++ b/src/apscheduler/triggers/date.py @@ -0,0 +1,45 @@ +from datetime import datetime, tzinfo +from typing import Optional, Union + +from ..abc import Trigger +from ..marshalling import marshal_date, unmarshal_date +from ..validators import as_aware_datetime, as_timezone, require_state_version + + +class DateTrigger(Trigger): + """ + Triggers once on the given date/time. + + :param run_time: the date/time to run the job at + :param timezone: time zone to use to convert ``run_time`` into a timezone aware datetime, if it + isn't already + """ + + __slots__ = 'run_time', '_completed' + + def __init__(self, run_time: datetime, timezone: Union[tzinfo, str] = 'local'): + timezone = as_timezone(timezone) + self.run_time = as_aware_datetime(run_time, timezone) + self._completed = False + + def next(self) -> Optional[datetime]: + if not self._completed: + self._completed = True + return self.run_time + else: + return None + + def __getstate__(self): + return { + 'version': 1, + 'run_time': marshal_date(self.run_time), + 'completed': self._completed + } + + def __setstate__(self, state): + require_state_version(self, state, 1) + self.run_time = unmarshal_date(state['run_time']) + self._completed = state['completed'] + + def __repr__(self): + return f"{self.__class__.__name__}('{self.run_time}')" diff --git a/src/apscheduler/triggers/interval.py b/src/apscheduler/triggers/interval.py new file mode 100644 index 0000000..bd5029a --- /dev/null +++ b/src/apscheduler/triggers/interval.py @@ -0,0 +1,99 @@ +from datetime import datetime, timedelta, tzinfo +from typing import Optional, Union + +from ..abc import Trigger +from ..marshalling import marshal_date, unmarshal_date +from ..validators import as_aware_datetime, as_timezone, require_state_version + + +class IntervalTrigger(Trigger): + """ + Triggers on specified intervals. + + The first trigger time is on ``start_time`` which is the moment the trigger was created unless + specifically overridden. If ``end_time`` is specified, the last trigger time will be at or + before that time. If no ``end_time`` has been given, the trigger will produce new trigger times + as long as the resulting datetimes are valid datetimes in Python. 
+ + :param weeks: number of weeks to wait + :param days: number of days to wait + :param hours: number of hours to wait + :param minutes: number of minutes to wait + :param seconds: number of seconds to wait + :param microseconds: number of microseconds to wait + :param start_time: first trigger date/time + :param end_time: latest possible date/time to trigger on + :param timezone: time zone used to make any passed naive datetimes timezone aware + """ + + __slots__ = ('weeks', 'days', 'hours', 'minutes', 'seconds', 'microseconds', 'start_time', + 'end_time', '_interval', '_last_fire_time') + + def __init__(self, *, weeks: float = 0, days: float = 0, hours: float = 0, minutes: float = 0, + seconds: float = 0, microseconds: float = 0, + start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, + timezone: Union[tzinfo, str] = 'local'): + self.weeks = weeks + self.days = days + self.hours = hours + self.minutes = minutes + self.seconds = seconds + self.microseconds = microseconds + timezone = as_timezone(timezone) + self.start_time = as_aware_datetime(start_time or datetime.now(), timezone) + self.end_time = as_aware_datetime(end_time, timezone) + self._interval = timedelta(weeks=self.weeks, days=self.days, hours=self.hours, + minutes=self.minutes, seconds=self.seconds, + microseconds=self.microseconds) + self._last_fire_time = None + + if self._interval.total_seconds() <= 0: + raise ValueError('The time interval must be positive') + + if self.end_time and self.end_time < self.start_time: + raise ValueError('end_time cannot be earlier than start_time') + + def next(self) -> Optional[datetime]: + if self._last_fire_time is None: + self._last_fire_time = self.start_time + else: + self._last_fire_time = self._last_fire_time + self._interval + + if self.end_time is None or self._last_fire_time <= self.end_time: + return self._last_fire_time + else: + return None + + def __getstate__(self): + return { + 'version': 1, + 'interval': [self.weeks, self.days, self.hours, self.minutes, self.seconds, + self.microseconds], + 'start_time': marshal_date(self.start_time), + 'end_time': marshal_date(self.end_time), + 'last_fire_time': marshal_date(self._last_fire_time) + } + + def __setstate__(self, state): + require_state_version(self, state, 1) + self.weeks, self.days, self.hours, self.minutes, self.seconds, self.microseconds = \ + state['interval'] + self.start_time = unmarshal_date(state['start_time']) + self.end_time = unmarshal_date(state['end_time']) + self._last_fire_time = unmarshal_date(state['last_fire_time']) + self._interval = timedelta(weeks=self.weeks, days=self.days, hours=self.hours, + minutes=self.minutes, seconds=self.seconds, + microseconds=self.microseconds) + + def __repr__(self): + fields = [] + for field in 'weeks', 'days', 'hours', 'minutes', 'seconds', 'microseconds': + value = getattr(self, field) + if value > 0: + fields.append(f'{field}={value}') + + fields.append(f"start_time='{self.start_time}'") + if self.end_time: + fields.append(f"end_time='{self.end_time}'") + + return f'{self.__class__.__name__}({", ".join(fields)})' diff --git a/src/apscheduler/util.py b/src/apscheduler/util.py new file mode 100644 index 0000000..867c15a --- /dev/null +++ b/src/apscheduler/util.py @@ -0,0 +1,82 @@ +"""This module contains several handy functions primarily meant for internal use.""" +import sys +from datetime import datetime, tzinfo +from typing import Callable, TypeVar + +if sys.version_info >= (3, 9): + from zoneinfo import ZoneInfo +else: + from backports.zoneinfo 
import ZoneInfo + +T_Type = TypeVar('T_Type', bound=type) + + +class _Undefined: + def __bool__(self): + return False + + def __repr__(self): + return '' + + +undefined = _Undefined() #: a unique object that only signifies that no value is defined + + +def timezone_repr(timezone: tzinfo) -> str: + if isinstance(timezone, ZoneInfo): + return timezone.key + else: + return repr(timezone) + + +def absolute_datetime_diff(dateval1: datetime, dateval2: datetime) -> float: + return dateval1.timestamp() - dateval2.timestamp() + + +def reentrant(cls: T_Type) -> T_Type: + """ + Modifies a class so that its ``__enter__`` / ``__exit__`` (or ``__aenter__`` / ``__aexit__``) + methods track the number of times it has been entered and exited and only actually invoke + the ``__enter__()`` method on the first entry and ``__exit__()`` on the last exit. + + """ + cls._loans = 0 + + def __enter__(self): + self._loans += 1 + if self._loans == 1: + previous_enter(self) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self._loans + self._loans -= 1 + if self._loans == 0: + return previous_exit(self, exc_type, exc_val, exc_tb) + + async def __aenter__(self): + self._loans += 1 + if self._loans == 1: + await previous_aenter(self) + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + assert self._loans + self._loans -= 1 + if self._loans == 0: + return await previous_aexit(self, exc_type, exc_val, exc_tb) + + previous_enter: Callable = getattr(cls, '__enter__', None) + previous_exit: Callable = getattr(cls, '__exit__', None) + previous_aenter: Callable = getattr(cls, '__aenter__', None) + previous_aexit: Callable = getattr(cls, '__aexit__', None) + if previous_enter and previous_exit: + cls.__enter__ = __enter__ + cls.__exit__ = __exit__ + elif previous_aenter and previous_aexit: + cls.__aenter__ = __aenter__ + cls.__aexit__ = __aexit__ + + return cls diff --git a/src/apscheduler/validators.py b/src/apscheduler/validators.py new file mode 100644 index 0000000..10f47fd --- /dev/null +++ b/src/apscheduler/validators.py @@ -0,0 +1,165 @@ +import sys +from datetime import date, datetime, timedelta, timezone, tzinfo +from typing import Any, Dict, Optional, Union + +from tzlocal import get_localzone + +from .abc import Trigger +from .exceptions import DeserializationError + +if sys.version_info >= (3, 9): + from zoneinfo import ZoneInfo +else: + from backports.zoneinfo import ZoneInfo + + +def as_int(value) -> Optional[int]: + """Convert the value into an integer.""" + if value is None: + return None + + return int(value) + + +def as_timezone(value: Union[str, tzinfo, None]) -> tzinfo: + """ + Convert the value into a tzinfo object. + + If ``value`` is ``None`` or ``'local'``, use the local timezone. + + :param value: the value to be converted + :return: a timezone object + + """ + if value is None or value == 'local': + return get_localzone() + elif isinstance(value, str): + return ZoneInfo(value) + elif isinstance(value, tzinfo): + if value is timezone.utc: + return ZoneInfo('UTC') + else: + return value + + raise TypeError(f'Expected tzinfo instance or timezone name, got ' + f'{value.__class__.__qualname__} instead') + + +def as_date(value: Union[date, str, None]) -> Optional[date]: + """ + Convert the value to a date. 
+ + :param value: the value to convert to a date + :return: a date object, or ``None`` if ``None`` was given + + """ + if value is None: + return None + elif isinstance(value, int): + return date.fromordinal(value) + elif isinstance(value, str): + return date.fromisoformat(value) + elif isinstance(value, datetime): + return value.date() + elif isinstance(value, date): + return value + + raise TypeError(f'Expected string or date, got {value.__class__.__qualname__} instead') + + +def as_timestamp(value: Optional[datetime]) -> Optional[float]: + if value is None: + return None + + return value.timestamp() + + +def as_ordinal_date(value: Optional[date]) -> Optional[int]: + if value is None: + return None + + return value.toordinal() + + +def as_aware_datetime(value: Union[datetime, str, None], tz: tzinfo) -> Optional[datetime]: + """ + Convert the value to a timezone aware datetime. + + :param value: a datetime, an ISO 8601 representation of a datetime, or ``None`` + :param tz: timezone to use for making the datetime timezone aware + :return: a timezone aware datetime, or ``None`` if ``None`` was given + + """ + if value is None: + return None + + if isinstance(value, str): + if value.upper().endswith('Z'): + value = value[:-1] + '+00:00' + + value = datetime.fromisoformat(value) + + if isinstance(value, datetime): + if not value.tzinfo: + return value.replace(tzinfo=tz) + else: + return value + + raise TypeError(f'Expected string or datetime, got {value.__class__.__qualname__} instead') + + +def positive_number(instance, attribute, value) -> None: + if value <= 0: + raise ValueError(f'Expected positive number, got {value} instead') + + +def non_negative_number(instance, attribute, value) -> None: + if value < 0: + raise ValueError(f'Expected non-negative number, got {value} instead') + + +def as_positive_integer(value, name: str) -> int: + if isinstance(value, int): + if value > 0: + return value + else: + raise ValueError(f'{name} must be positive') + + raise TypeError(f'{name} must be an integer, got {value.__class__.__name__} instead') + + +def as_timedelta(value, name: str) -> timedelta: + if isinstance(value, (int, float)): + value = timedelta(seconds=value) + + if isinstance(value, timedelta): + if value.total_seconds() < 0: + raise ValueError(f'{name} cannot be negative') + else: + return value + + raise TypeError(f'{name} must be a timedelta or number of seconds, got ' + f'{value.__class__.__name__} instead') + + +def as_list(value, element_type: type, name: str) -> list: + value = list(value) + for i, element in enumerate(value): + if not isinstance(element, element_type): + raise TypeError(f'Element at index {i} of {name} is not of the expected type ' + f'({element_type.__name__}') + + return value + + +def require_state_version(trigger: Trigger, state: Dict[str, Any], max_version: int) -> None: + try: + if state['version'] > max_version: + raise DeserializationError( + f'{trigger.__class__.__name__} received a serialized state with version ' + f'{state["version"]}, but it only supports up to version {max_version}. ' + f'This can happen when an older version of APScheduler is being used with a data ' + f'store that was previously used with a newer APScheduler version.' 
+ ) + except KeyError as exc: + raise DeserializationError('Missing "version" key in the serialized state') from exc diff --git a/src/apscheduler/workers/__init__.py b/src/apscheduler/workers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py new file mode 100644 index 0000000..6f84937 --- /dev/null +++ b/src/apscheduler/workers/async_.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import os +import platform +from contextlib import AsyncExitStack +from datetime import datetime, timezone +from inspect import isawaitable +from logging import Logger, getLogger +from traceback import format_tb +from typing import Any, Callable, Iterable, Optional, Set, Type, Union +from uuid import UUID + +import anyio +from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_class +from anyio.abc import CancelScope, TaskGroup + +from ..abc import AsyncDataStore, DataStore, EventSource, Job +from ..adapters import AsyncDataStoreAdapter +from ..enums import RunState +from ..events import ( + AsyncEventHub, Event, JobAdded, JobCompleted, JobDeadlineMissed, JobFailed, JobStarted, + SubscriptionToken, WorkerStarted, WorkerStopped) + + +class AsyncWorker(EventSource): + """Runs jobs locally in a task group.""" + + _task_group: Optional[TaskGroup] = None + _stop_event: Optional[anyio.Event] = None + _state: RunState = RunState.stopped + _acquire_cancel_scope: Optional[CancelScope] = None + _datastore_subscription: SubscriptionToken + _wakeup_event: anyio.Event + + def __init__(self, data_store: Union[DataStore, AsyncDataStore], *, + max_concurrent_jobs: int = 100, identity: Optional[str] = None, + logger: Optional[Logger] = None): + self.max_concurrent_jobs = max_concurrent_jobs + self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}' + self.logger = logger or getLogger(__name__) + self._acquired_jobs: Set[Job] = set() + self._exit_stack = AsyncExitStack() + self._events = AsyncEventHub() + self._running_jobs: Set[UUID] = set() + + if self.max_concurrent_jobs < 1: + raise ValueError('max_concurrent_jobs must be at least 1') + + if isinstance(data_store, DataStore): + self.data_store = AsyncDataStoreAdapter(data_store) + else: + self.data_store = data_store + + @property + def state(self) -> RunState: + return self._state + + async def __aenter__(self): + self._state = RunState.starting + self._wakeup_event = anyio.Event() + await self._exit_stack.__aenter__() + await self._exit_stack.enter_async_context(self._events) + + # Initialize the data store + await self._exit_stack.enter_async_context(self.data_store) + relay_token = self._events.relay_events_from(self.data_store) + self._exit_stack.callback(self.data_store.unsubscribe, relay_token) + + # Wake up the worker if the data store emits a significant job event + wakeup_token = self.data_store.subscribe( + lambda event: self._wakeup_event.set(), {JobAdded}) + self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token) + + # Start the actual worker + self._task_group = create_task_group() + await self._exit_stack.enter_async_context(self._task_group) + await self._task_group.start(self.run) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self._state = RunState.stopping + self._wakeup_event.set() + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + del self._task_group + del self._wakeup_event + + def subscribe(self, callback: Callable[[Event], Any], + event_types: 
Optional[Iterable[Type[Event]]] = None) -> SubscriptionToken: + return self._events.subscribe(callback, event_types) + + def unsubscribe(self, token: SubscriptionToken) -> None: + self._events.unsubscribe(token) + + async def run(self, *, task_status=TASK_STATUS_IGNORED) -> None: + if self._state is not RunState.starting: + raise RuntimeError(f'This function cannot be called while the worker is in the ' + f'{self._state} state') + + # Signal that the worker has started + self._state = RunState.started + task_status.started() + self._events.publish(WorkerStarted()) + + try: + while self._state is RunState.started: + limit = self.max_concurrent_jobs - len(self._running_jobs) + with CancelScope() as self._acquire_cancel_scope: + try: + jobs = await self.data_store.acquire_jobs(self.identity, limit) + finally: + del self._acquire_cancel_scope + + for job in jobs: + self._running_jobs.add(job.id) + self._task_group.start_soon(self._run_job, job) + + await self._wakeup_event.wait() + self._wakeup_event = anyio.Event() + except get_cancelled_exc_class(): + pass + except BaseException as exc: + self._state = RunState.stopped + self._events.publish(WorkerStopped(exception=exc)) + raise + + self._state = RunState.stopped + self._events.publish(WorkerStopped()) + + # async def _run_job(self, job: Job) -> None: + # # Check if the job started before the deadline + # start_time = datetime.now(timezone.utc) + # if job.start_deadline is not None and start_time > job.start_deadline: + # event = JobDeadlineMissed( + # timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + # schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + # start_time=start_time, start_deadline=job.start_deadline) + # self._events.publish(event) + # return + # + # now = datetime.now(timezone.utc) + # if job.start_deadline is not None: + # if now.timestamp() > job.start_deadline.timestamp(): + # self.logger.info('Missed the deadline of job %r', job.id) + # event = JobDeadlineMissed( + # now, job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id, + # scheduled_fire_time=job.scheduled_fire_time, start_time=now, + # start_deadline=job.start_deadline + # ) + # await self.publish(event) + # return + # + # # Set the job as running and publish a job update event + # self.logger.info('Started job %r', job.id) + # job.started_at = now + # event = JobUpdated( + # timestamp=now, job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id + # ) + # await self.publish(event) + # + # self._num_running_jobs += 1 + # try: + # return_value = await self._call_job_func(job.func, job.args, job.kwargs) + # except BaseException as exc: + # self.logger.exception('Job %r raised an exception', job.id) + # event = JobFailed( + # timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + # schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + # start_time=now, start_deadline=job.start_deadline, + # traceback=format_exc(), exception=exc + # ) + # else: + # self.logger.info('Job %r completed successfully', job.id) + # event = JobSuccessful( + # timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + # schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + # start_time=now, start_deadline=job.start_deadline, return_value=return_value + # ) + # + # self._num_running_jobs -= 1 + # await self.data_store.release_jobs(self.identity, [job]) + # await self.publish(event) + # + # async def _call_job_func(self, func: Callable, args: tuple, 
kwargs: Dict[str, Any]): + # if not self.run_sync_functions_in_event_loop and not iscoroutinefunction(func): + # wrapped = partial(func, *args, **kwargs) + # return await to_thread.run_sync(wrapped) + # + # return_value = func(*args, **kwargs) + # if isinstance(return_value, Coroutine): + # return_value = await return_value + # + # return return_value + + async def _run_job(self, job: Job) -> None: + event: Event + try: + # Check if the job started before the deadline + start_time = datetime.now(timezone.utc) + if job.start_deadline is not None and start_time > job.start_deadline: + event = JobDeadlineMissed( + timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + start_time=start_time, start_deadline=job.start_deadline) + self._events.publish(event) + return + + event = JobStarted( + timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + start_time=start_time, start_deadline=job.start_deadline) + self._events.publish(event) + try: + retval = job.func(*job.args, **job.kwargs) + if isawaitable(retval): + retval = await retval + except BaseException as exc: + if exc.__class__.__module__ == 'builtins': + exc_name = exc.__class__.__qualname__ + else: + exc_name = f'{exc.__class__.__module__}.{exc.__class__.__qualname__}' + + formatted_traceback = '\n'.join(format_tb(exc.__traceback__)) + event = JobFailed( + timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + start_time=start_time, start_deadline=job.start_deadline, exception=exc_name, + traceback=formatted_traceback) + self._events.publish(event) + else: + event = JobCompleted( + timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + start_time=start_time, start_deadline=job.start_deadline, return_value=retval) + self._events.publish(event) + finally: + self._running_jobs.remove(job.id) + await self.data_store.release_jobs(self.identity, [job]) + + # async def stop(self, force: bool = False) -> None: + # self._running = False + # if self._acquire_cancel_scope: + # self._acquire_cancel_scope.cancel() + # + # if force and self._task_group: + # self._task_group.cancel_scope.cancel() + # + # async def wait_until_stopped(self) -> None: + # if self._stop_event: + # await self._stop_event.wait() diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py new file mode 100644 index 0000000..bd3b060 --- /dev/null +++ b/src/apscheduler/workers/sync.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import os +import platform +import threading +from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait +from contextlib import ExitStack +from datetime import datetime, timezone +from logging import Logger, getLogger +from traceback import format_tb +from typing import Any, Callable, Iterable, Optional, Set, Type +from uuid import UUID + +from .. 
import events +from ..abc import DataStore, EventSource +from ..enums import RunState +from ..events import ( + EventHub, JobAdded, JobCompleted, JobDeadlineMissed, JobFailed, JobStarted, SubscriptionToken, + WorkerStarted, WorkerStopped) +from ..structures import Job + + +class Worker(EventSource): + """Runs jobs locally in a thread pool.""" + + _state: RunState = RunState.stopped + _wakeup_event: threading.Event + + def __init__(self, data_store: DataStore, *, max_concurrent_jobs: int = 20, + identity: Optional[str] = None, logger: Optional[Logger] = None): + self.max_concurrent_jobs = max_concurrent_jobs + self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}' + self.logger = logger or getLogger(__name__) + self._acquired_jobs: Set[Job] = set() + self._exit_stack = ExitStack() + self._executor = ThreadPoolExecutor(max_workers=max_concurrent_jobs + 1) + self._events = EventHub() + self._running_jobs: Set[UUID] = set() + + if self.max_concurrent_jobs < 1: + raise ValueError('max_concurrent_jobs must be at least 1') + + self.data_store = data_store + + @property + def state(self) -> RunState: + return self._state + + def __enter__(self) -> Worker: + self._state = RunState.starting + self._wakeup_event = threading.Event() + self._exit_stack.__enter__() + self._exit_stack.enter_context(self._events) + + # Initialize the data store + self._exit_stack.enter_context(self.data_store) + relay_token = self._events.relay_events_from(self.data_store) + self._exit_stack.callback(self.data_store.unsubscribe, relay_token) + + # Wake up the worker if the data store emits a significant job event + wakeup_token = self.data_store.subscribe( + lambda event: self._wakeup_event.set(), {JobAdded}) + self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token) + + # Start the worker and return when it has signalled readiness or raised an exception + start_future: Future[None] = Future() + token = self._events.subscribe(start_future.set_result) + run_future = self._executor.submit(self.run) + try: + wait([start_future, run_future], return_when=FIRST_COMPLETED) + finally: + self._events.unsubscribe(token) + + if run_future.done(): + run_future.result() + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._state = RunState.stopping + self._wakeup_event.set() + self._executor.shutdown(wait=exc_type is None) + self._exit_stack.__exit__(exc_type, exc_val, exc_tb) + del self._wakeup_event + + def subscribe(self, callback: Callable[[events.Event], Any], + event_types: Optional[Iterable[Type[events.Event]]] = None) -> SubscriptionToken: + return self._events.subscribe(callback, event_types) + + def unsubscribe(self, token: events.SubscriptionToken) -> None: + self._events.unsubscribe(token) + + def run(self) -> None: + if self._state is not RunState.starting: + raise RuntimeError(f'This function cannot be called while the worker is in the ' + f'{self._state} state') + + # Signal that the worker has started + self._state = RunState.started + self._events.publish(WorkerStarted()) + + try: + while self._state is RunState.started: + available_slots = self.max_concurrent_jobs - len(self._running_jobs) + if available_slots: + jobs = self.data_store.acquire_jobs(self.identity, available_slots) + for job in jobs: + self._running_jobs.add(job.id) + self._executor.submit(self._run_job, job) + + self._wakeup_event.wait() + self._wakeup_event = threading.Event() + except BaseException as exc: + self._state = RunState.stopped + self._events.publish(WorkerStopped(exception=exc)) + raise 
+ + self._state = RunState.stopped + self._events.publish(WorkerStopped()) + + def _run_job(self, job: Job) -> None: + try: + # Check if the job started before the deadline + start_time = datetime.now(timezone.utc) + if job.start_deadline is not None and start_time > job.start_deadline: + event = JobDeadlineMissed( + timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + start_time=start_time, start_deadline=job.start_deadline) + self._events.publish(event) + return + + event = JobStarted( + timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + start_time=start_time, start_deadline=job.start_deadline) + self._events.publish(event) + try: + retval = job.func(*job.args, **job.kwargs) + except BaseException as exc: + if exc.__class__.__module__ == 'builtins': + exc_name = exc.__class__.__qualname__ + else: + exc_name = f'{exc.__class__.__module__}.{exc.__class__.__qualname__}' + + formatted_traceback = '\n'.join(format_tb(exc.__traceback__)) + event = JobFailed( + timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + start_time=start_time, start_deadline=job.start_deadline, exception=exc_name, + traceback=formatted_traceback) + self._events.publish(event) + else: + event = JobCompleted( + timestamp=datetime.now(timezone.utc), job_id=job.id, task_id=job.task_id, + schedule_id=job.schedule_id, scheduled_fire_time=job.scheduled_fire_time, + start_time=start_time, start_deadline=job.start_deadline, return_value=retval) + self._events.publish(event) + finally: + self._running_jobs.remove(job.id) + self.data_store.release_jobs(self.identity, [job]) -- cgit v1.2.1
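
As a quick orientation for reviewers, below is an illustrative usage sketch of the trigger API introduced by this patch. It is not part of the commit: it assumes the src/ layout is installed so the code imports as the top-level apscheduler package, and the crontab string, dates and the Europe/Helsinki zone are made-up example inputs.

from datetime import datetime, timezone

from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.date import DateTrigger
from apscheduler.triggers.interval import IntervalTrigger

# from_crontab() builds a trigger from a five-field crontab string;
# next() returns the following fire time, or None once the trigger is exhausted.
cron = CronTrigger.from_crontab('*/15 9-17 * * mon-fri', timezone='Europe/Helsinki')
print(cron.next())  # next quarter hour between 09:00 and 17:59 on a weekday

# DateTrigger fires exactly once.
once = DateTrigger(datetime(2021, 9, 1, 12, 0, tzinfo=timezone.utc))
print(once.next())  # 2021-09-01 12:00:00+00:00
print(once.next())  # None

# IntervalTrigger starts at start_time and stops once end_time would be exceeded.
every_hour = IntervalTrigger(
    hours=1,
    start_time=datetime(2021, 9, 1, tzinfo=timezone.utc),
    end_time=datetime(2021, 9, 1, 3, 0, tzinfo=timezone.utc),
)
while (fire_time := every_hour.next()) is not None:
    print(fire_time.isoformat())  # 00:00, 01:00, 02:00 and 03:00 UTC

# Triggers expose a versioned state dict that can be stored and restored.
state = every_hour.__getstate__()
restored = IntervalTrigger.__new__(IntervalTrigger)
restored.__setstate__(state)

The sketch calls next() directly to keep the dependencies minimal; in practice the schedulers and workers added elsewhere in this patch consume the triggers and handle the state round-tripping through the configured serializer.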