diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2022-06-09 12:40:55 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2022-07-19 00:48:51 +0300 |
commit | f215c1ab45959095f6b499eb7b26356c5937ee8b (patch) | |
tree | 552687a0ed6e799b3da96eec5cd3fbb14d19f1b5 | |
parent | e3158fdf59a7c92a9449a566a2b746a4024e582f (diff) | |
download | apscheduler-f215c1ab45959095f6b499eb7b26356c5937ee8b.tar.gz |
Added support for starting the sync scheduler (and worker) without the context manager
25 files changed, 1649 insertions, 1174 deletions
diff --git a/pyproject.toml b/pyproject.toml index 00f51e1..004c1dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [build-system] requires = [ "setuptools >= 61", - "setuptools_scm[toml] >= 6.4" + "setuptools_scm >= 6.4" ] build-backend = "setuptools.build_meta" @@ -26,7 +26,7 @@ license = {text = "MIT"} urls = {Homepage = "https://github.com/agronholm/apscheduler"} requires-python = ">= 3.7" dependencies = [ - "anyio ~= 3.0", + "anyio ~= 3.6", "attrs >= 21.3", "tenacity ~= 8.0", "tzlocal >= 3.0", diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py index 00bb281..c270bc6 100644 --- a/src/apscheduler/abc.py +++ b/src/apscheduler/abc.py @@ -108,15 +108,16 @@ class EventSource(metaclass=ABCMeta): class EventBroker(EventSource): """ - Interface for objects that can be used to publish notifications to interested subscribers. - - Can be used as a context manager. + Interface for objects that can be used to publish notifications to interested + subscribers. """ - def __enter__(self): - return self + @abstractmethod + def start(self) -> None: + pass - def __exit__(self, exc_type, exc_val, exc_tb): + @abstractmethod + def stop(self, *, force: bool = False) -> None: pass @abstractmethod @@ -129,16 +130,14 @@ class EventBroker(EventSource): class AsyncEventBroker(EventSource): - """ - Asynchronous version of :class:`EventBroker`. - - Can be used as an asynchronous context manager. - """ + """Asynchronous version of :class:`EventBroker`.""" - async def __aenter__(self): - return self + @abstractmethod + async def start(self) -> None: + pass - async def __aexit__(self, exc_type, exc_val, exc_tb): + @abstractmethod + async def stop(self, *, force: bool = False) -> None: pass @abstractmethod @@ -151,10 +150,12 @@ class AsyncEventBroker(EventSource): class DataStore: - def __enter__(self): - return self + @abstractmethod + def start(self, event_broker: EventBroker) -> None: + pass - def __exit__(self, exc_type, exc_val, exc_tb): + @abstractmethod + def stop(self, *, force: bool = False) -> None: pass @property @@ -309,10 +310,12 @@ class DataStore: class AsyncDataStore: - async def __aenter__(self): - return self + @abstractmethod + async def start(self, event_broker: AsyncEventBroker) -> None: + pass - async def __aexit__(self, exc_type, exc_val, exc_tb): + @abstractmethod + async def stop(self, *, force: bool = False) -> None: pass @property diff --git a/src/apscheduler/converters.py b/src/apscheduler/converters.py index 7502327..3518d44 100644 --- a/src/apscheduler/converters.py +++ b/src/apscheduler/converters.py @@ -48,6 +48,17 @@ def as_enum(enum_class: type[TEnum]) -> Callable[[TEnum | str], TEnum]: return converter +def as_async_eventbroker( + value: abc.EventBroker | abc.AsyncEventBroker, +) -> abc.AsyncEventBroker: + if isinstance(value, abc.EventBroker): + from apscheduler.eventbrokers.async_adapter import AsyncEventBrokerAdapter + + return AsyncEventBrokerAdapter(value) + + return value + + def as_async_datastore(value: abc.DataStore | abc.AsyncDataStore) -> abc.AsyncDataStore: if isinstance(value, abc.DataStore): from apscheduler.datastores.async_adapter import AsyncDataStoreAdapter diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py index 3b2342e..99accd4 100644 --- a/src/apscheduler/datastores/async_adapter.py +++ b/src/apscheduler/datastores/async_adapter.py @@ -1,8 +1,6 @@ from __future__ import annotations -from contextlib import AsyncExitStack from datetime import datetime -from functools import partial from typing import Iterable from uuid import UUID @@ -10,43 +8,35 @@ import attrs from anyio import to_thread from anyio.from_thread import BlockingPortal -from ..abc import AsyncDataStore, AsyncEventBroker, DataStore, EventSource +from ..abc import AsyncEventBroker, DataStore from ..enums import ConflictPolicy -from ..eventbrokers.async_adapter import AsyncEventBrokerAdapter +from ..eventbrokers.async_adapter import AsyncEventBrokerAdapter, SyncEventBrokerAdapter from ..structures import Job, JobResult, Schedule, Task -from ..util import reentrant +from .base import BaseAsyncDataStore -@reentrant @attrs.define(eq=False) -class AsyncDataStoreAdapter(AsyncDataStore): +class AsyncDataStoreAdapter(BaseAsyncDataStore): original: DataStore _portal: BlockingPortal = attrs.field(init=False) - _events: AsyncEventBroker = attrs.field(init=False) - _exit_stack: AsyncExitStack = attrs.field(init=False) - @property - def events(self) -> EventSource: - return self._events - - async def __aenter__(self) -> AsyncDataStoreAdapter: - self._exit_stack = AsyncExitStack() + async def start(self, event_broker: AsyncEventBroker) -> None: + await super().start(event_broker) self._portal = BlockingPortal() - await self._exit_stack.enter_async_context(self._portal) - - self._events = AsyncEventBrokerAdapter(self.original.events, self._portal) - await self._exit_stack.enter_async_context(self._events) + await self._portal.__aenter__() - await to_thread.run_sync(self.original.__enter__) - self._exit_stack.push_async_exit( - partial(to_thread.run_sync, self.original.__exit__) - ) + if isinstance(event_broker, AsyncEventBrokerAdapter): + sync_event_broker = event_broker.original + else: + sync_event_broker = SyncEventBrokerAdapter(event_broker, self._portal) - return self + await to_thread.run_sync(lambda: self.original.start(sync_event_broker)) - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + async def stop(self, *, force: bool = False) -> None: + await to_thread.run_sync(lambda: self.original.stop(force=force)) + await self._portal.__aexit__(None, None, None) + await super().stop(force=force) async def add_task(self, task: Task) -> None: await to_thread.run_sync(self.original.add_task, task) diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py index aa56c6f..edf9b3e 100644 --- a/src/apscheduler/datastores/async_sqlalchemy.py +++ b/src/apscheduler/datastores/async_sqlalchemy.py @@ -17,9 +17,8 @@ from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.sql.ddl import DropTable from sqlalchemy.sql.elements import BindParameter -from ..abc import AsyncDataStore, AsyncEventBroker, EventSource, Job, Schedule +from ..abc import AsyncEventBroker, Job, Schedule from ..enums import ConflictPolicy -from ..eventbrokers.async_local import LocalAsyncEventBroker from ..events import ( DataStoreEvent, JobAcquired, @@ -36,17 +35,14 @@ from ..events import ( from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError from ..marshalling import callable_to_ref from ..structures import JobResult, Task -from ..util import reentrant +from .base import BaseAsyncDataStore from .sqlalchemy import _BaseSQLAlchemyDataStore -@reentrant @attrs.define(eq=False) -class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): +class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseAsyncDataStore): engine: AsyncEngine - _events: AsyncEventBroker = attrs.field(factory=LocalAsyncEventBroker) - @classmethod def from_url(cls, url: str | URL, **options) -> AsyncSQLAlchemyDataStore: engine = create_async_engine(url, future=True) @@ -63,7 +59,9 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): reraise=True, ) - async def __aenter__(self): + async def start(self, event_broker: AsyncEventBroker) -> None: + await super().start(event_broker) + asynclib = sniffio.current_async_library() or "(unknown)" if asynclib != "asyncio": raise RuntimeError( @@ -92,16 +90,6 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): f"only version 1 is supported by this version of APScheduler" ) - await self._events.__aenter__() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._events.__aexit__(exc_type, exc_val, exc_tb) - - @property - def events(self) -> EventSource: - return self._events - async def _deserialize_schedules(self, result: Result) -> list[Schedule]: schedules: list[Schedule] = [] for row in result: diff --git a/src/apscheduler/datastores/base.py b/src/apscheduler/datastores/base.py new file mode 100644 index 0000000..c05d28c --- /dev/null +++ b/src/apscheduler/datastores/base.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from apscheduler.abc import ( + AsyncDataStore, + AsyncEventBroker, + DataStore, + EventBroker, + EventSource, +) + + +class BaseDataStore(DataStore): + _events: EventBroker + + def start(self, event_broker: EventBroker) -> None: + self._events = event_broker + + def stop(self, *, force: bool = False) -> None: + del self._events + + @property + def events(self) -> EventSource: + return self._events + + +class BaseAsyncDataStore(AsyncDataStore): + _events: AsyncEventBroker + + async def start(self, event_broker: AsyncEventBroker) -> None: + self._events = event_broker + + async def stop(self, *, force: bool = False) -> None: + del self._events + + @property + def events(self) -> EventSource: + return self._events diff --git a/src/apscheduler/datastores/memory.py b/src/apscheduler/datastores/memory.py index c2d7420..516d141 100644 --- a/src/apscheduler/datastores/memory.py +++ b/src/apscheduler/datastores/memory.py @@ -9,9 +9,8 @@ from uuid import UUID import attrs -from ..abc import DataStore, EventBroker, EventSource, Job, Schedule +from ..abc import Job, Schedule from ..enums import ConflictPolicy -from ..eventbrokers.local import LocalEventBroker from ..events import ( JobAcquired, JobAdded, @@ -25,7 +24,7 @@ from ..events import ( ) from ..exceptions import ConflictingIdError, TaskLookupError from ..structures import JobResult, Task -from ..util import reentrant +from .base import BaseDataStore max_datetime = datetime(MAXYEAR, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc) @@ -83,11 +82,9 @@ class JobState: return hash(self.job.id) -@reentrant @attrs.define(eq=False) -class MemoryDataStore(DataStore): +class MemoryDataStore(BaseDataStore): lock_expiration_delay: float = 30 - _events: EventBroker = attrs.Factory(LocalEventBroker) _tasks: dict[str, TaskState] = attrs.Factory(dict) _schedules: list[ScheduleState] = attrs.Factory(list) _schedules_by_id: dict[str, ScheduleState] = attrs.Factory(dict) @@ -111,17 +108,6 @@ class MemoryDataStore(DataStore): right_index = bisect_left(self._jobs, state) return self._jobs.index(state, left_index, right_index + 1) - def __enter__(self): - self._events.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._events.__exit__(exc_type, exc_val, exc_tb) - - @property - def events(self) -> EventSource: - return self._events - def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: return [ state.schedule diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py index 2669d03..59a345a 100644 --- a/src/apscheduler/datastores/mongodb.py +++ b/src/apscheduler/datastores/mongodb.py @@ -2,7 +2,6 @@ from __future__ import annotations import operator from collections import defaultdict -from contextlib import ExitStack from datetime import datetime, timedelta, timezone from logging import Logger, getLogger from typing import Any, Callable, ClassVar, Iterable @@ -18,7 +17,7 @@ from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne from pymongo.collection import Collection from pymongo.errors import ConnectionFailure, DuplicateKeyError -from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer +from ..abc import EventBroker, Job, Schedule, Serializer from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome from ..eventbrokers.local import LocalEventBroker from ..events import ( @@ -41,7 +40,7 @@ from ..exceptions import ( ) from ..serializers.pickle import PickleSerializer from ..structures import JobResult, RetrySettings, Task -from ..util import reentrant +from .base import BaseDataStore class CustomEncoder(TypeEncoder): @@ -62,14 +61,14 @@ def ensure_uuid_presentation(client: MongoClient) -> None: pass -@reentrant @attrs.define(eq=False) -class MongoDBDataStore(DataStore): +class MongoDBDataStore(BaseDataStore): client: MongoClient = attrs.field(validator=instance_of(MongoClient)) serializer: Serializer = attrs.field(factory=PickleSerializer, kw_only=True) + event_broker = attrs.field(factory=LocalEventBroker, kw_only=True) database: str = attrs.field(default="apscheduler", kw_only=True) lock_expiration_delay: float = attrs.field(default=30, kw_only=True) - retry_settings: RetrySettings = attrs.field(default=RetrySettings()) + retry_settings: RetrySettings = attrs.field(default=RetrySettings(), kw_only=True) start_from_scratch: bool = attrs.field(default=False, kw_only=True) _task_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Task)] @@ -79,8 +78,6 @@ class MongoDBDataStore(DataStore): _job_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Job)] _logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__)) - _exit_stack: ExitStack = attrs.field(init=False, factory=ExitStack) - _events: EventBroker = attrs.field(init=False, factory=LocalEventBroker) _local_tasks: dict[str, Task] = attrs.field(init=False, factory=dict) def __attrs_post_init__(self) -> None: @@ -108,10 +105,6 @@ class MongoDBDataStore(DataStore): client = MongoClient(uri) return cls(client, **options) - @property - def events(self) -> EventSource: - return self._events - def _retry(self) -> tenacity.Retrying: return tenacity.Retrying( stop=self.retry_settings.stop, @@ -128,7 +121,9 @@ class MongoDBDataStore(DataStore): retry_state.outcome.exception(), ) - def __enter__(self): + def start(self, event_broker: EventBroker) -> None: + super().start(event_broker) + server_info = self.client.server_info() if server_info["versionArray"] < [4, 0]: raise RuntimeError( @@ -136,9 +131,6 @@ class MongoDBDataStore(DataStore): f"{server_info['version']}" ) - self._exit_stack.__enter__() - self._exit_stack.enter_context(self._events) - for attempt in self._retry(): with attempt, self.client.start_session() as session: if self.start_from_scratch: @@ -153,11 +145,6 @@ class MongoDBDataStore(DataStore): self._jobs.create_index("tags", session=session) self._jobs_results.create_index("finished_at", session=session) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._exit_stack.__exit__(exc_type, exc_val, exc_tb) - def add_task(self, task: Task) -> None: for attempt in self._retry(): with attempt: diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py index cca223e..9ee03fa 100644 --- a/src/apscheduler/datastores/sqlalchemy.py +++ b/src/apscheduler/datastores/sqlalchemy.py @@ -31,9 +31,8 @@ from sqlalchemy.future import Engine, create_engine from sqlalchemy.sql.ddl import DropTable from sqlalchemy.sql.elements import BindParameter, literal -from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer +from ..abc import EventBroker, Job, Schedule, Serializer from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome -from ..eventbrokers.local import LocalEventBroker from ..events import ( Event, JobAcquired, @@ -52,7 +51,7 @@ from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError from ..marshalling import callable_to_ref from ..serializers.pickle import PickleSerializer from ..structures import JobResult, RetrySettings, Task -from ..util import reentrant +from .base import BaseDataStore class EmulatedUUID(TypeDecorator): @@ -219,13 +218,10 @@ class _BaseSQLAlchemyDataStore: return jobs -@reentrant @attrs.define(eq=False) -class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): +class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore): engine: Engine - _events: EventBroker = attrs.field(init=False, factory=LocalEventBroker) - @classmethod def from_url(cls, url: str | URL, **options) -> SQLAlchemyDataStore: engine = create_engine(url) @@ -240,7 +236,9 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): reraise=True, ) - def __enter__(self): + def start(self, event_broker: EventBroker) -> None: + super().start(event_broker) + for attempt in self._retry(): with attempt, self.engine.begin() as conn: if self.start_from_scratch: @@ -259,16 +257,6 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): f"only version 1 is supported by this version of APScheduler" ) - self._events.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._events.__exit__(exc_type, exc_val, exc_tb) - - @property - def events(self) -> EventSource: - return self._events - def add_task(self, task: Task) -> None: insert = self.t_tasks.insert().values( id=task.id, diff --git a/src/apscheduler/eventbrokers/async_adapter.py b/src/apscheduler/eventbrokers/async_adapter.py index ebceaad..4c1cb6c 100644 --- a/src/apscheduler/eventbrokers/async_adapter.py +++ b/src/apscheduler/eventbrokers/async_adapter.py @@ -1,39 +1,65 @@ from __future__ import annotations -from functools import partial +from typing import Any, Callable, Iterable import attrs from anyio import to_thread from anyio.from_thread import BlockingPortal -from apscheduler.abc import EventBroker -from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker +from apscheduler.abc import AsyncEventBroker, EventBroker, Subscription from apscheduler.events import Event -from apscheduler.util import reentrant -@reentrant @attrs.define(eq=False) -class AsyncEventBrokerAdapter(LocalAsyncEventBroker): +class AsyncEventBrokerAdapter(AsyncEventBroker): original: EventBroker - portal: BlockingPortal - - async def __aenter__(self): - await super().__aenter__() - if not self.portal: - self.portal = BlockingPortal() - self._exit_stack.enter_async_context(self.portal) + async def start(self) -> None: + await to_thread.run_sync(self.original.start) - await to_thread.run_sync(self.original.__enter__) - self._exit_stack.push_async_exit( - partial(to_thread.run_sync, self.original.__exit__) - ) + async def stop(self, *, force: bool = False) -> None: + await to_thread.run_sync(lambda: self.original.stop(force=force)) - # Relay events from the original broker to this one - self._exit_stack.enter_context( - self.original.subscribe(partial(self.portal.call, self.publish_local)) - ) + async def publish_local(self, event: Event) -> None: + await to_thread.run_sync(self.original.publish_local, event) async def publish(self, event: Event) -> None: await to_thread.run_sync(self.original.publish, event) + + def subscribe( + self, + callback: Callable[[Event], Any], + event_types: Iterable[type[Event]] | None = None, + *, + one_shot: bool = False, + ) -> Subscription: + return self.original.subscribe(callback, event_types, one_shot=one_shot) + + +@attrs.define(eq=False) +class SyncEventBrokerAdapter(EventBroker): + original: AsyncEventBroker + portal: BlockingPortal + + def start(self) -> None: + pass + + def stop(self, *, force: bool = False) -> None: + pass + + def publish_local(self, event: Event) -> None: + self.portal.call(self.original.publish_local, event) + + def publish(self, event: Event) -> None: + self.portal.call(self.original.publish, event) + + def subscribe( + self, + callback: Callable[[Event], Any], + event_types: Iterable[type[Event]] | None = None, + *, + one_shot: bool = False, + ) -> Subscription: + return self.portal.call( + lambda: self.original.subscribe(callback, event_types, one_shot=one_shot) + ) diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py index 98031bf..dba37f1 100644 --- a/src/apscheduler/eventbrokers/async_local.py +++ b/src/apscheduler/eventbrokers/async_local.py @@ -1,7 +1,6 @@ from __future__ import annotations from asyncio import iscoroutine -from contextlib import AsyncExitStack from typing import Any, Callable import attrs @@ -10,26 +9,19 @@ from anyio.abc import TaskGroup from ..abc import AsyncEventBroker from ..events import Event -from ..util import reentrant from .base import BaseEventBroker -@reentrant @attrs.define(eq=False) class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker): _task_group: TaskGroup = attrs.field(init=False) - _exit_stack: AsyncExitStack = attrs.field(init=False) - - async def __aenter__(self) -> LocalAsyncEventBroker: - self._exit_stack = AsyncExitStack() + async def start(self) -> None: self._task_group = create_task_group() - await self._exit_stack.enter_async_context(self._task_group) - - return self + await self._task_group.__aenter__() - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + async def stop(self, *, force: bool = False) -> None: + await self._task_group.__aexit__(None, None, None) del self._task_group async def publish(self, event: Event) -> None: diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py index d97c229..c80033e 100644 --- a/src/apscheduler/eventbrokers/asyncpg.py +++ b/src/apscheduler/eventbrokers/asyncpg.py @@ -4,13 +4,12 @@ from contextlib import asynccontextmanager from typing import TYPE_CHECKING, AsyncContextManager, AsyncGenerator, Callable import attrs -from anyio import TASK_STATUS_IGNORED, sleep +from anyio import TASK_STATUS_IGNORED, CancelScope, sleep from asyncpg import Connection from asyncpg.pool import Pool from ..events import Event from ..exceptions import SerializationError -from ..util import reentrant from .async_local import LocalAsyncEventBroker from .base import DistributedEventBrokerMixin @@ -18,12 +17,12 @@ if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncEngine -@reentrant @attrs.define(eq=False) class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): connection_factory: Callable[[], AsyncContextManager[Connection]] channel: str = attrs.field(kw_only=True, default="apscheduler") max_idle_time: float = attrs.field(kw_only=True, default=30) + _listen_cancel_scope: CancelScope = attrs.field(init=False) @classmethod def from_asyncpg_pool(cls, pool: Pool) -> AsyncpgEventBroker: @@ -47,11 +46,15 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): return cls(connection_factory) - async def __aenter__(self) -> AsyncpgEventBroker: - await super().__aenter__() - await self._task_group.start(self._listen_notifications) - self._exit_stack.callback(self._task_group.cancel_scope.cancel) - return self + async def start(self) -> None: + await super().start() + self._listen_cancel_scope = await self._task_group.start( + self._listen_notifications + ) + + async def stop(self, *, force: bool = False) -> None: + self._listen_cancel_scope.cancel() + await super().stop(force=force) async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None: def callback(connection, pid, channel: str, payload: str) -> None: @@ -60,19 +63,20 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): self._task_group.start_soon(self.publish_local, event) task_started_sent = False - while True: - async with self.connection_factory() as conn: - await conn.add_listener(self.channel, callback) - if not task_started_sent: - task_status.started() - task_started_sent = True - - try: - while True: - await sleep(self.max_idle_time) - await conn.execute("SELECT 1") - finally: - await conn.remove_listener(self.channel, callback) + with CancelScope() as cancel_scope: + while True: + async with self.connection_factory() as conn: + await conn.add_listener(self.channel, callback) + if not task_started_sent: + task_status.started(cancel_scope) + task_started_sent = True + + try: + while True: + await sleep(self.max_idle_time) + await conn.execute("SELECT 1") + finally: + await conn.remove_listener(self.channel, callback) async def publish(self, event: Event) -> None: notification = self.generate_notification_str(event) diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py index 38c83c7..3aeee77 100644 --- a/src/apscheduler/eventbrokers/base.py +++ b/src/apscheduler/eventbrokers/base.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Iterable import attrs from .. import events -from ..abc import EventBroker, Serializer, Subscription +from ..abc import EventSource, Serializer, Subscription from ..events import Event from ..exceptions import DeserializationError @@ -25,7 +25,7 @@ class LocalSubscription(Subscription): @attrs.define(eq=False) -class BaseEventBroker(EventBroker): +class BaseEventBroker(EventSource): _logger: Logger = attrs.field(init=False) _subscriptions: dict[object, LocalSubscription] = attrs.field( init=False, factory=dict @@ -70,7 +70,7 @@ class DistributedEventBrokerMixin: self._logger.exception( "Failed to deserialize an event of type %s", event_type, - serialized=serialized, + extra={"serialized": serialized}, ) return None @@ -80,7 +80,7 @@ class DistributedEventBrokerMixin: self._logger.error( "Receive notification for a nonexistent event type: %s", event_type, - serialized=serialized, + extra={"serialized": serialized}, ) return None @@ -94,7 +94,9 @@ class DistributedEventBrokerMixin: try: event_type_bytes, serialized = payload.split(b" ", 1) except ValueError: - self._logger.error("Received malformatted notification", payload=payload) + self._logger.error( + "Received malformatted notification", extra={"payload": payload} + ) return None event_type = event_type_bytes.decode("ascii", errors="replace") @@ -104,7 +106,9 @@ class DistributedEventBrokerMixin: try: event_type, b64_serialized = payload.split(" ", 1) except ValueError: - self._logger.error("Received malformatted notification", payload=payload) + self._logger.error( + "Received malformatted notification", extra={"payload": payload} + ) return None return self._reconstitute_event(event_type, b64decode(b64_serialized)) diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py index cf19526..a870c45 100644 --- a/src/apscheduler/eventbrokers/local.py +++ b/src/apscheduler/eventbrokers/local.py @@ -8,26 +8,22 @@ from typing import Any, Callable, Iterable import attrs -from ..abc import Subscription +from ..abc import EventBroker, Subscription from ..events import Event -from ..util import reentrant from .base import BaseEventBroker -@reentrant @attrs.define(eq=False) -class LocalEventBroker(BaseEventBroker): +class LocalEventBroker(EventBroker, BaseEventBroker): _executor: ThreadPoolExecutor = attrs.field(init=False) _exit_stack: ExitStack = attrs.field(init=False) _subscriptions_lock: Lock = attrs.field(init=False, factory=Lock) - def __enter__(self): - self._exit_stack = ExitStack() - self._executor = self._exit_stack.enter_context(ThreadPoolExecutor(1)) - return self + def start(self) -> None: + self._executor = ThreadPoolExecutor(1) - def __exit__(self, exc_type, exc_val, exc_tb): - self._exit_stack.__exit__(exc_type, exc_val, exc_tb) + def stop(self, *, force: bool = False) -> None: + self._executor.shutdown(wait=not force) del self._executor def subscribe( diff --git a/src/apscheduler/eventbrokers/mqtt.py b/src/apscheduler/eventbrokers/mqtt.py index 8b0a512..9b0f859 100644 --- a/src/apscheduler/eventbrokers/mqtt.py +++ b/src/apscheduler/eventbrokers/mqtt.py @@ -11,12 +11,10 @@ from paho.mqtt.reasoncodes import ReasonCodes from ..abc import Serializer from ..events import Event from ..serializers.json import JSONSerializer -from ..util import reentrant from .base import DistributedEventBrokerMixin from .local import LocalEventBroker -@reentrant @attrs.define(eq=False) class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): client: Client @@ -28,8 +26,8 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): publish_qos: int = attrs.field(kw_only=True, default=0) _ready_future: Future[None] = attrs.field(init=False) - def __enter__(self): - super().__enter__() + def start(self) -> None: + super().start() self._ready_future = Future() self.client.enable_logger(self._logger) self.client.on_connect = self._on_connect @@ -38,11 +36,11 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): self.client.connect(self.host, self.port) self.client.loop_start() self._ready_future.result(10) - self._exit_stack.push( - lambda exc_type, *_: self.client.loop_stop(force=bool(exc_type)) - ) - self._exit_stack.callback(self.client.disconnect) - return self + + def stop(self, *, force: bool = False) -> None: + self.client.disconnect() + self.client.loop_stop(force=force) + super().stop() def _on_connect( self, diff --git a/src/apscheduler/eventbrokers/redis.py b/src/apscheduler/eventbrokers/redis.py index 0f14e3e..a41ef37 100644 --- a/src/apscheduler/eventbrokers/redis.py +++ b/src/apscheduler/eventbrokers/redis.py @@ -9,12 +9,10 @@ from redis import ConnectionPool, Redis from ..abc import Serializer from ..events import Event from ..serializers.json import JSONSerializer -from ..util import reentrant from .base import DistributedEventBrokerMixin from .local import LocalEventBroker -@reentrant @attrs.define(eq=False) class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): client: Redis @@ -23,6 +21,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): message_poll_interval: float = attrs.field(kw_only=True, default=0.05) _stopped: bool = attrs.field(init=False, default=True) _ready_future: Future[None] = attrs.field(init=False) + _thread: Thread = attrs.field(init=False) @classmethod def from_url(cls, url: str, **kwargs) -> RedisEventBroker: @@ -30,7 +29,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): client = Redis(connection_pool=pool) return cls(client) - def __enter__(self): + def start(self) -> None: self._stopped = False self._ready_future = Future() self._thread = Thread( @@ -38,14 +37,14 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): ) self._thread.start() self._ready_future.result(10) - return super().__enter__() + super().start() - def __exit__(self, exc_type, exc_val, exc_tb): + def stop(self, *, force: bool = False) -> None: self._stopped = True - if not exc_type: + if not force: self._thread.join(5) - super().__exit__(exc_type, exc_val, exc_tb) + super().stop(force=force) def _listen_messages(self) -> None: while not self._stopped: diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py index 72c5994..889596d 100644 --- a/src/apscheduler/schedulers/async_.py +++ b/src/apscheduler/schedulers/async_.py @@ -17,10 +17,11 @@ from anyio import ( get_cancelled_exc_class, move_on_after, ) +from anyio.abc import TaskGroup -from ..abc import AsyncDataStore, EventSource, Job, Schedule, Trigger +from ..abc import AsyncDataStore, AsyncEventBroker, Job, Schedule, Subscription, Trigger from ..context import current_scheduler -from ..converters import as_async_datastore +from ..converters import as_async_datastore, as_async_eventbroker from ..datastores.memory import MemoryDataStore from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState from ..eventbrokers.async_local import LocalAsyncEventBroker @@ -48,6 +49,9 @@ class AsyncScheduler: data_store: AsyncDataStore = attrs.field( converter=as_async_datastore, factory=MemoryDataStore ) + event_broker: AsyncEventBroker = attrs.field( + converter=as_async_eventbroker, factory=LocalAsyncEventBroker + ) identity: str = attrs.field(kw_only=True, default=None) start_worker: bool = attrs.field(kw_only=True, default=True) logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) @@ -55,65 +59,24 @@ class AsyncScheduler: _state: RunState = attrs.field(init=False, default=RunState.stopped) _wakeup_event: anyio.Event = attrs.field(init=False) _wakeup_deadline: datetime | None = attrs.field(init=False, default=None) - _worker: AsyncWorker | None = attrs.field(init=False, default=None) - _events: LocalAsyncEventBroker = attrs.field( - init=False, factory=LocalAsyncEventBroker - ) - _exit_stack: AsyncExitStack = attrs.field(init=False) + _task_group: TaskGroup | None = attrs.field(init=False, default=None) + _schedule_added_subscription: Subscription = attrs.field(init=False) def __attrs_post_init__(self) -> None: if not self.identity: self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" - @property - def events(self) -> EventSource: - return self._events - - @property - def worker(self) -> AsyncWorker | None: - return self._worker - async def __aenter__(self): - self._state = RunState.starting - self._wakeup_event = anyio.Event() - self._exit_stack = AsyncExitStack() - await self._exit_stack.__aenter__() - await self._exit_stack.enter_async_context(self._events) - - # Initialize the data store and start relaying events to the scheduler's event broker - await self._exit_stack.enter_async_context(self.data_store) - self._exit_stack.enter_context( - self.data_store.events.subscribe(self._events.publish) - ) - - # Wake up the scheduler if the data store emits a significant schedule event - self._exit_stack.enter_context( - self.data_store.events.subscribe( - self._schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated} - ) - ) - - # Start the built-in worker, if configured to do so - if self.start_worker: - token = current_scheduler.set(self) - try: - self._worker = AsyncWorker(self.data_store) - await self._exit_stack.enter_async_context(self._worker) - finally: - current_scheduler.reset(token) - - # Start the worker and return when it has signalled readiness or raised an exception - task_group = create_task_group() - await self._exit_stack.enter_async_context(task_group) - await task_group.start(self.run) + self._task_group = create_task_group() + await self._task_group.__aenter__() + await self._task_group.start(self.run_until_stopped) return self async def __aexit__(self, exc_type, exc_val, exc_tb): self._state = RunState.stopping self._wakeup_event.set() - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) - self._state = RunState.stopped - del self._wakeup_event + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + self._task_group = None def _schedule_added_or_modified(self, event: Event) -> None: event_ = cast("ScheduleAdded | ScheduleUpdated", event) @@ -281,130 +244,169 @@ class AsyncScheduler: else: raise RuntimeError(f"Unknown job outcome: {result.outcome}") - async def run(self, *, task_status=TASK_STATUS_IGNORED) -> None: - if self._state is not RunState.starting: + async def run_until_stopped(self, *, task_status=TASK_STATUS_IGNORED) -> None: + if self._state is not RunState.stopped: raise RuntimeError( - f"This function cannot be called while the scheduler is in the " - f"{self._state} state" + f'Cannot start the scheduler when it is in the "{self._state}" ' + f"state" ) - # Signal that the scheduler has started - self._state = RunState.started - task_status.started() - await self._events.publish(SchedulerStarted()) - - exception: BaseException | None = None - try: - while self._state is RunState.started: - schedules = await self.data_store.acquire_schedules(self.identity, 100) - now = datetime.now(timezone.utc) - for schedule in schedules: - # Calculate a next fire time for the schedule, if possible - fire_times = [schedule.next_fire_time] - calculate_next = schedule.trigger.next - while True: - try: - fire_time = calculate_next() - except Exception: - self.logger.exception( - "Error computing next fire time for schedule %r of task %r – " - "removing schedule", - schedule.id, - schedule.task_id, - ) - break - - # Stop if the calculated fire time is in the future - if fire_time is None or fire_time > now: - schedule.next_fire_time = fire_time - break - - # Only keep all the fire times if coalesce policy = "all" - if schedule.coalesce is CoalescePolicy.all: - fire_times.append(fire_time) - elif schedule.coalesce is CoalescePolicy.latest: - fire_times[0] = fire_time - - # Add one or more jobs to the job queue - max_jitter = ( - schedule.max_jitter.total_seconds() - if schedule.max_jitter - else 0 + self._state = RunState.starting + async with AsyncExitStack() as exit_stack: + self._wakeup_event = anyio.Event() + + # Initialize the event broker + await self.event_broker.start() + exit_stack.push_async_exit( + lambda *exc_info: self.event_broker.stop(force=exc_info[0] is not None) + ) + + # Initialize the data store + await self.data_store.start(self.event_broker) + exit_stack.push_async_exit( + lambda *exc_info: self.data_store.stop(force=exc_info[0] is not None) + ) + + # Wake up the scheduler if the data store emits a significant schedule event + self._schedule_added_subscription = self.event_broker.subscribe( + self._schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated} + ) + + # Start the built-in worker, if configured to do so + if self.start_worker: + token = current_scheduler.set(self) + exit_stack.callback(current_scheduler.reset, token) + worker = AsyncWorker( + self.data_store, self.event_broker, is_internal=True + ) + await exit_stack.enter_async_context(worker) + + # Signal that the scheduler has started + self._state = RunState.started + task_status.started() + await self.event_broker.publish_local(SchedulerStarted()) + + exception: BaseException | None = None + try: + while self._state is RunState.started: + schedules = await self.data_store.acquire_schedules( + self.identity, 100 ) - for i, fire_time in enumerate(fire_times): - # Calculate a jitter if max_jitter > 0 - jitter = _zero_timedelta - if max_jitter: - if i + 1 < len(fire_times): - next_fire_time = fire_times[i + 1] - else: - next_fire_time = schedule.next_fire_time - - if next_fire_time is not None: - # Jitter must never be so high that it would cause a fire time to - # equal or exceed the next fire time - jitter_s = min( - [ - max_jitter, - ( - next_fire_time - - fire_time - - _microsecond_delta - ).total_seconds(), - ] + now = datetime.now(timezone.utc) + for schedule in schedules: + # Calculate a next fire time for the schedule, if possible + fire_times = [schedule.next_fire_time] + calculate_next = schedule.trigger.next + while True: + try: + fire_time = calculate_next() + except Exception: + self.logger.exception( + "Error computing next fire time for schedule %r of " + "task %r – removing schedule", + schedule.id, + schedule.task_id, ) - jitter = timedelta(seconds=random.uniform(0, jitter_s)) - fire_time += jitter - - schedule.last_fire_time = fire_time - job = Job( - task_id=schedule.task_id, - args=schedule.args, - kwargs=schedule.kwargs, - schedule_id=schedule.id, - scheduled_fire_time=fire_time, - jitter=jitter, - start_deadline=schedule.next_deadline, - tags=schedule.tags, + break + + # Stop if the calculated fire time is in the future + if fire_time is None or fire_time > now: + schedule.next_fire_time = fire_time + break + + # Only keep all the fire times if coalesce policy = "all" + if schedule.coalesce is CoalescePolicy.all: + fire_times.append(fire_time) + elif schedule.coalesce is CoalescePolicy.latest: + fire_times[0] = fire_time + + # Add one or more jobs to the job queue + max_jitter = ( + schedule.max_jitter.total_seconds() + if schedule.max_jitter + else 0 ) - await self.data_store.add_job(job) - - # Update the schedules (and release the scheduler's claim on them) - await self.data_store.release_schedules(self.identity, schedules) + for i, fire_time in enumerate(fire_times): + # Calculate a jitter if max_jitter > 0 + jitter = _zero_timedelta + if max_jitter: + if i + 1 < len(fire_times): + next_fire_time = fire_times[i + 1] + else: + next_fire_time = schedule.next_fire_time + + if next_fire_time is not None: + # Jitter must never be so high that it would cause a + # fire time to equal or exceed the next fire time + jitter_s = min( + [ + max_jitter, + ( + next_fire_time + - fire_time + - _microsecond_delta + ).total_seconds(), + ] + ) + jitter = timedelta( + seconds=random.uniform(0, jitter_s) + ) + fire_time += jitter + + schedule.last_fire_time = fire_time + job = Job( + task_id=schedule.task_id, + args=schedule.args, + kwargs=schedule.kwargs, + schedule_id=schedule.id, + scheduled_fire_time=fire_time, + jitter=jitter, + start_deadline=schedule.next_deadline, + tags=schedule.tags, + ) + await self.data_store.add_job(job) + + # Update the schedules (and release the scheduler's claim on them) + await self.data_store.release_schedules(self.identity, schedules) + + # If we received fewer schedules than the maximum amount, sleep + # until the next schedule is due or the scheduler is explicitly + # woken up + wait_time = None + if len(schedules) < 100: + self._wakeup_deadline = ( + await self.data_store.get_next_schedule_run_time() + ) + if self._wakeup_deadline: + wait_time = ( + self._wakeup_deadline - datetime.now(timezone.utc) + ).total_seconds() + self.logger.debug( + "Sleeping %.3f seconds until the next fire time (%s)", + wait_time, + self._wakeup_deadline, + ) + else: + self.logger.debug("Waiting for any due schedules to appear") - # If we received fewer schedules than the maximum amount, sleep until the next - # schedule is due or the scheduler is explicitly woken up - wait_time = None - if len(schedules) < 100: - self._wakeup_deadline = ( - await self.data_store.get_next_schedule_run_time() - ) - if self._wakeup_deadline: - wait_time = ( - self._wakeup_deadline - datetime.now(timezone.utc) - ).total_seconds() + with move_on_after(wait_time): + await self._wakeup_event.wait() + self._wakeup_event = anyio.Event() + else: self.logger.debug( - "Sleeping %.3f seconds until the next fire time (%s)", - wait_time, - self._wakeup_deadline, + "Processing more schedules on the next iteration" ) - else: - self.logger.debug("Waiting for any due schedules to appear") - - with move_on_after(wait_time): - await self._wakeup_event.wait() - self._wakeup_event = anyio.Event() - else: - self.logger.debug("Processing more schedules on the next iteration") - except get_cancelled_exc_class(): - pass - except BaseException as exc: - self.logger.exception("Scheduler crashed") - exception = exc - raise - else: - self.logger.info("Scheduler stopped") - finally: - self._state = RunState.stopped - with move_on_after(3, shield=True): - await self._events.publish(SchedulerStopped(exception=exception)) + except get_cancelled_exc_class(): + pass + except BaseException as exc: + self.logger.exception("Scheduler crashed") + exception = exc + raise + else: + self.logger.info("Scheduler stopped") + finally: + self._state = RunState.stopped + with move_on_after(3, shield=True): + await self.event_broker.publish_local( + SchedulerStopped(exception=exception) + ) diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py index 4c7095e..5ea0c8a 100644 --- a/src/apscheduler/schedulers/sync.py +++ b/src/apscheduler/schedulers/sync.py @@ -5,16 +5,17 @@ import os import platform import random import threading -from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait +from concurrent.futures import Future from contextlib import ExitStack from datetime import datetime, timedelta, timezone from logging import Logger, getLogger +from types import TracebackType from typing import Any, Callable, Iterable, Mapping, cast from uuid import UUID, uuid4 import attrs -from ..abc import DataStore, EventSource, Trigger +from ..abc import DataStore, EventBroker, Trigger from ..context import current_scheduler from ..datastores.memory import MemoryDataStore from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState @@ -41,39 +42,60 @@ class Scheduler: """A synchronous scheduler implementation.""" data_store: DataStore = attrs.field(factory=MemoryDataStore) + event_broker: EventBroker = attrs.field(factory=LocalEventBroker) identity: str = attrs.field(kw_only=True, default=None) start_worker: bool = attrs.field(kw_only=True, default=True) logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) _state: RunState = attrs.field(init=False, default=RunState.stopped) - _wakeup_event: threading.Event = attrs.field(init=False) + _thread: threading.Thread | None = attrs.field(init=False, default=None) + _wakeup_event: threading.Event = attrs.field(init=False, factory=threading.Event) _wakeup_deadline: datetime | None = attrs.field(init=False, default=None) - _worker: Worker | None = attrs.field(init=False, default=None) - _events: LocalEventBroker = attrs.field(init=False, factory=LocalEventBroker) + _services_initialized: bool = attrs.field(init=False, default=False) + _exit_stack: ExitStack = attrs.field(init=False, factory=ExitStack) + _lock: threading.RLock = attrs.field(init=False, factory=threading.RLock) def __attrs_post_init__(self) -> None: if not self.identity: self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" @property - def events(self) -> EventSource: - return self._events - - @property def state(self) -> RunState: return self._state - @property - def worker(self) -> Worker | None: - return self._worker - def __enter__(self) -> Scheduler: self.start_in_background() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: self.stop() + def _ensure_services_ready(self) -> None: + """Ensure that the data store and event broker have been initialized.""" + with self._lock: + if not self._services_initialized: + self._services_initialized = True + self.event_broker.start() + self._exit_stack.push( + lambda *exc_info: self.event_broker.stop( + force=exc_info[0] is not None + ) + ) + + # Initialize the data store + self.data_store.start(self.event_broker) + self._exit_stack.push( + lambda *exc_info: self.data_store.stop( + force=exc_info[0] is not None + ) + ) + atexit.register(self.stop) + def _schedule_added_or_modified(self, event: Event) -> None: event_ = cast("ScheduleAdded | ScheduleUpdated", event) if not self._wakeup_deadline or ( @@ -98,6 +120,7 @@ class Scheduler: tags: Iterable[str] | None = None, conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing, ) -> str: + self._ensure_services_ready() id = id or str(uuid4()) args = tuple(args or ()) kwargs = dict(kwargs or {}) @@ -133,10 +156,12 @@ class Scheduler: return schedule.id def get_schedule(self, id: str) -> Schedule: + self._ensure_services_ready() schedules = self.data_store.get_schedules({id}) return schedules[0] def remove_schedule(self, schedule_id: str) -> None: + self._ensure_services_ready() self.data_store.remove_schedules({schedule_id}) def add_job( @@ -157,6 +182,7 @@ class Scheduler: :return: the ID of the newly created job """ + self._ensure_services_ready() if callable(func_or_task_id): task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id) self.data_store.add_task(task) @@ -177,18 +203,19 @@ class Scheduler: Retrieve the result of a job. :param job_id: the ID of the job - :param wait: if ``True``, wait until the job has ended (one way or another), ``False`` to - raise an exception if the result is not yet available + :param wait: if ``True``, wait until the job has ended (one way or another), + ``False`` to raise an exception if the result is not yet available :raises JobLookupError: if the job does not exist in the data store """ + self._ensure_services_ready() wait_event = threading.Event() def listener(event: JobReleased) -> None: if event.job_id == job_id: wait_event.set() - with self.data_store.events.subscribe(listener, {JobReleased}): + with self.data_store.events.subscribe(listener, {JobReleased}, one_shot=True): result = self.data_store.get_job_result(job_id) if result: return result @@ -215,6 +242,7 @@ class Scheduler: :returns: the return value of the target function """ + self._ensure_services_ready() job_complete_event = threading.Event() def listener(event: JobReleased) -> None: @@ -239,181 +267,200 @@ class Scheduler: raise RuntimeError(f"Unknown job outcome: {result.outcome}") def start_in_background(self) -> None: - if self._state is not RunState.stopped: - raise RuntimeError("Scheduler already running") + start_future: Future[None] = Future() + self._thread = threading.Thread( + target=self._run, args=[start_future], daemon=True + ) + self._thread.start() + try: + start_future.result() + except BaseException: + self._thread = None + raise - thread = threading.Thread(target=self.run_until_stopped, daemon=True) - thread.start() atexit.register(self.stop) def stop(self) -> None: - if self._state is RunState.started: - f = Future() - with self._events.subscribe(f.set_result, one_shot=True): + atexit.unregister(self.stop) + with self._lock: + if self._state is RunState.started: self._state = RunState.stopping self._wakeup_event.set() - f.result() - def run_until_stopped(self): - self._state = RunState.starting - self._wakeup_event = threading.Event() - with ExitStack() as exit_stack: - exit_stack.enter_context(self._events) + if self._thread and threading.current_thread() != self._thread: + self._thread.join() + self._thread = None - # Initialize the data store and start relaying events to the scheduler's event broker - exit_stack.enter_context(self.data_store) - exit_stack.enter_context( - self.data_store.events.subscribe(self._events.publish) - ) + self._exit_stack.close() - # Wake up the scheduler if the data store emits a significant schedule event - exit_stack.enter_context( - self.data_store.events.subscribe( - self._schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated} - ) - ) + def run_until_stopped(self) -> None: + self._run(None) - # Start the built-in worker, if configured to do so - if self.start_worker: - token = current_scheduler.set(self) - try: - self._worker = Worker(self.data_store) - exit_stack.enter_context(self._worker) - finally: - current_scheduler.reset(token) - - # Start the scheduler and return when it has signalled readiness or raised an exception - start_future: Future[Event] = Future() - with self._events.subscribe(start_future.set_result, one_shot=True): - executor = ThreadPoolExecutor(1) - exit_stack.push( - lambda exc_type, *args: executor.shutdown(wait=exc_type is None) - ) - run_future = executor.submit(self._run) - wait([start_future, run_future], return_when=FIRST_COMPLETED) - - self._state = RunState.stopped - del self._wakeup_event - - if run_future.done(): - run_future.result() + def _run(self, start_future: Future[None] | None) -> None: + with ExitStack() as exit_stack: + try: + with self._lock: + if self._state is not RunState.stopped: + raise RuntimeError( + f'Cannot start the scheduler when it is in the "{self._state}" ' + f"state" + ) - def _run(self) -> None: - assert self._state is RunState.starting + self._state = RunState.starting - # Signal that the scheduler has started - self._state = RunState.started - self._events.publish(SchedulerStarted()) + self._ensure_services_ready() - try: - while self._state is RunState.started: - schedules = self.data_store.acquire_schedules(self.identity, 100) - self.logger.debug( - "Processing %d schedules retrieved from the data store", - len(schedules), + # Wake up the scheduler if the data store emits a significant schedule event + exit_stack.enter_context( + self.data_store.events.subscribe( + self._schedule_added_or_modified, + {ScheduleAdded, ScheduleUpdated}, + ) ) - now = datetime.now(timezone.utc) - for schedule in schedules: - # Calculate a next fire time for the schedule, if possible - fire_times = [schedule.next_fire_time] - calculate_next = schedule.trigger.next - while True: - try: - fire_time = calculate_next() - except Exception: - self.logger.exception( - "Error computing next fire time for schedule %r of task %r – " - "removing schedule", - schedule.id, - schedule.task_id, - ) - break - - # Stop if the calculated fire time is in the future - if fire_time is None or fire_time > now: - schedule.next_fire_time = fire_time - break - - # Only keep all the fire times if coalesce policy = "all" - if schedule.coalesce is CoalescePolicy.all: - fire_times.append(fire_time) - elif schedule.coalesce is CoalescePolicy.latest: - fire_times[0] = fire_time - - # Add one or more jobs to the job queue - max_jitter = ( - schedule.max_jitter.total_seconds() - if schedule.max_jitter - else 0 + + # Start the built-in worker, if configured to do so + if self.start_worker: + token = current_scheduler.set(self) + exit_stack.callback(current_scheduler.reset, token) + worker = Worker( + self.data_store, self.event_broker, is_internal=True + ) + exit_stack.enter_context(worker) + + # Signal that the scheduler has started + self._state = RunState.started + self.event_broker.publish_local(SchedulerStarted()) + except BaseException as exc: + if start_future: + start_future.set_exception(exc) + return + else: + raise + else: + if start_future: + start_future.set_result(None) + + try: + while self._state is RunState.started: + schedules = self.data_store.acquire_schedules(self.identity, 100) + self.logger.debug( + "Processing %d schedules retrieved from the data store", + len(schedules), ) - for i, fire_time in enumerate(fire_times): - # Calculate a jitter if max_jitter > 0 - jitter = _zero_timedelta - if max_jitter: - if i + 1 < len(fire_times): - next_fire_time = fire_times[i + 1] - else: - next_fire_time = schedule.next_fire_time - - if next_fire_time is not None: - # Jitter must never be so high that it would cause a fire time to - # equal or exceed the next fire time - jitter_s = min( - [ - max_jitter, - ( - next_fire_time - - fire_time - - _microsecond_delta - ).total_seconds(), - ] + now = datetime.now(timezone.utc) + for schedule in schedules: + # Calculate a next fire time for the schedule, if possible + fire_times = [schedule.next_fire_time] + calculate_next = schedule.trigger.next + while True: + try: + fire_time = calculate_next() + except Exception: + self.logger.exception( + "Error computing next fire time for schedule %r of task %r – " + "removing schedule", + schedule.id, + schedule.task_id, ) - jitter = timedelta(seconds=random.uniform(0, jitter_s)) - fire_time += jitter - - schedule.last_fire_time = fire_time - job = Job( - task_id=schedule.task_id, - args=schedule.args, - kwargs=schedule.kwargs, - schedule_id=schedule.id, - scheduled_fire_time=fire_time, - jitter=jitter, - start_deadline=schedule.next_deadline, - tags=schedule.tags, + break + + # Stop if the calculated fire time is in the future + if fire_time is None or fire_time > now: + schedule.next_fire_time = fire_time + break + + # Only keep all the fire times if coalesce policy = "all" + if schedule.coalesce is CoalescePolicy.all: + fire_times.append(fire_time) + elif schedule.coalesce is CoalescePolicy.latest: + fire_times[0] = fire_time + + # Add one or more jobs to the job queue + max_jitter = ( + schedule.max_jitter.total_seconds() + if schedule.max_jitter + else 0 ) - self.data_store.add_job(job) - - # Update the schedules (and release the scheduler's claim on them) - self.data_store.release_schedules(self.identity, schedules) - - # If we received fewer schedules than the maximum amount, sleep until the next - # schedule is due or the scheduler is explicitly woken up - wait_time = None - if len(schedules) < 100: - self._wakeup_deadline = self.data_store.get_next_schedule_run_time() - if self._wakeup_deadline: - wait_time = ( - self._wakeup_deadline - datetime.now(timezone.utc) - ).total_seconds() - self.logger.debug( - "Sleeping %.3f seconds until the next fire time (%s)", - wait_time, - self._wakeup_deadline, + for i, fire_time in enumerate(fire_times): + # Calculate a jitter if max_jitter > 0 + jitter = _zero_timedelta + if max_jitter: + if i + 1 < len(fire_times): + next_fire_time = fire_times[i + 1] + else: + next_fire_time = schedule.next_fire_time + + if next_fire_time is not None: + # Jitter must never be so high that it would cause + # a fire time to equal or exceed the next fire time + jitter_s = min( + [ + max_jitter, + ( + next_fire_time + - fire_time + - _microsecond_delta + ).total_seconds(), + ] + ) + jitter = timedelta( + seconds=random.uniform(0, jitter_s) + ) + fire_time += jitter + + schedule.last_fire_time = fire_time + job = Job( + task_id=schedule.task_id, + args=schedule.args, + kwargs=schedule.kwargs, + schedule_id=schedule.id, + scheduled_fire_time=fire_time, + jitter=jitter, + start_deadline=schedule.next_deadline, + tags=schedule.tags, + ) + self.data_store.add_job(job) + + # Update the schedules (and release the scheduler's claim on them) + self.data_store.release_schedules(self.identity, schedules) + + # If we received fewer schedules than the maximum amount, sleep + # until the next schedule is due or the scheduler is explicitly + # woken up + wait_time = None + if len(schedules) < 100: + self._wakeup_deadline = ( + self.data_store.get_next_schedule_run_time() ) - else: - self.logger.debug("Waiting for any due schedules to appear") + if self._wakeup_deadline: + wait_time = ( + self._wakeup_deadline - datetime.now(timezone.utc) + ).total_seconds() + self.logger.debug( + "Sleeping %.3f seconds until the next fire time (%s)", + wait_time, + self._wakeup_deadline, + ) + else: + self.logger.debug("Waiting for any due schedules to appear") - if self._wakeup_event.wait(wait_time): - self._wakeup_event = threading.Event() + if self._wakeup_event.wait(wait_time): + self._wakeup_event = threading.Event() + else: + self.logger.debug( + "Processing more schedules on the next iteration" + ) + except BaseException as exc: + self._state = RunState.stopped + if isinstance(exc, Exception): + self.logger.exception("Scheduler crashed") else: - self.logger.debug("Processing more schedules on the next iteration") - except BaseException as exc: - self._state = RunState.stopped - self.logger.exception("Scheduler crashed") - self._events.publish(SchedulerStopped(exception=exc)) - raise + self.logger.info( + f"Scheduler stopped due to {exc.__class__.__name__}" + ) - self._state = RunState.stopped - self.logger.info("Scheduler stopped") - self._events.publish(SchedulerStopped()) + self.event_broker.publish_local(SchedulerStopped(exception=exc)) + else: + self._state = RunState.stopped + self.logger.info("Scheduler stopped") + self.event_broker.publish_local(SchedulerStopped()) diff --git a/src/apscheduler/util.py b/src/apscheduler/util.py index 7ff4858..8c69203 100644 --- a/src/apscheduler/util.py +++ b/src/apscheduler/util.py @@ -2,9 +2,8 @@ from __future__ import annotations import sys -from collections import defaultdict from datetime import datetime, tzinfo -from typing import Callable, TypeVar +from typing import TypeVar if sys.version_info >= (3, 9): from zoneinfo import ZoneInfo @@ -36,53 +35,3 @@ def timezone_repr(timezone: tzinfo) -> str: def absolute_datetime_diff(dateval1: datetime, dateval2: datetime) -> float: return dateval1.timestamp() - dateval2.timestamp() - - -def reentrant(cls: type[T]) -> type[T]: - """ - Modifies a class so that its ``__enter__`` / ``__exit__`` (or ``__aenter__`` / ``__aexit__``) - methods track the number of times it has been entered and exited and only actually invoke - the ``__enter__()`` method on the first entry and ``__exit__()`` on the last exit. - - """ - - def __enter__(self: T) -> T: - loans[self] += 1 - if loans[self] == 1: - previous_enter(self) - - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - assert loans[self] - loans[self] -= 1 - if loans[self] == 0: - return previous_exit(self, exc_type, exc_val, exc_tb) - - async def __aenter__(self: T) -> T: - loans[self] += 1 - if loans[self] == 1: - await previous_aenter(self) - - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - assert loans[self] - loans[self] -= 1 - if loans[self] == 0: - del loans[self] - return await previous_aexit(self, exc_type, exc_val, exc_tb) - - loans: dict[T, int] = defaultdict(lambda: 0) - previous_enter: Callable = getattr(cls, "__enter__", None) - previous_exit: Callable = getattr(cls, "__exit__", None) - previous_aenter: Callable = getattr(cls, "__aenter__", None) - previous_aexit: Callable = getattr(cls, "__aexit__", None) - if previous_enter and previous_exit: - cls.__enter__ = __enter__ - cls.__exit__ = __exit__ - elif previous_aenter and previous_aexit: - cls.__aenter__ = __aenter__ - cls.__aexit__ = __aexit__ - - return cls diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py index e4ad44d..65870e5 100644 --- a/src/apscheduler/workers/async_.py +++ b/src/apscheduler/workers/async_.py @@ -6,6 +6,7 @@ from contextlib import AsyncExitStack from datetime import datetime, timezone from inspect import isawaitable from logging import Logger, getLogger +from types import TracebackType from typing import Callable from uuid import UUID @@ -17,11 +18,11 @@ from anyio import ( get_cancelled_exc_class, move_on_after, ) -from anyio.abc import CancelScope +from anyio.abc import CancelScope, TaskGroup -from ..abc import AsyncDataStore, EventSource, Job +from ..abc import AsyncDataStore, AsyncEventBroker, Job from ..context import current_worker, job_info -from ..converters import as_async_datastore +from ..converters import as_async_datastore, as_async_eventbroker from ..enums import JobOutcome, RunState from ..eventbrokers.async_local import LocalAsyncEventBroker from ..events import JobAdded, WorkerStarted, WorkerStopped @@ -34,105 +35,129 @@ class AsyncWorker: """Runs jobs locally in a task group.""" data_store: AsyncDataStore = attrs.field(converter=as_async_datastore) + event_broker: AsyncEventBroker = attrs.field( + converter=as_async_eventbroker, factory=LocalAsyncEventBroker + ) max_concurrent_jobs: int = attrs.field( kw_only=True, validator=positive_integer, default=100 ) identity: str = attrs.field(kw_only=True, default=None) logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) + # True if a scheduler owns this worker + _is_internal: bool = attrs.field(kw_only=True, default=False) _state: RunState = attrs.field(init=False, default=RunState.stopped) - _wakeup_event: anyio.Event = attrs.field(init=False, factory=anyio.Event) + _wakeup_event: anyio.Event = attrs.field(init=False) + _task_group: TaskGroup = attrs.field(init=False) _acquired_jobs: set[Job] = attrs.field(init=False, factory=set) - _events: LocalAsyncEventBroker = attrs.field( - init=False, factory=LocalAsyncEventBroker - ) _running_jobs: set[UUID] = attrs.field(init=False, factory=set) - _exit_stack: AsyncExitStack = attrs.field(init=False) def __attrs_post_init__(self) -> None: if not self.identity: self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" @property - def events(self) -> EventSource: - return self._events - - @property def state(self) -> RunState: return self._state - async def __aenter__(self): - self._state = RunState.starting - self._wakeup_event = anyio.Event() - self._exit_stack = AsyncExitStack() - await self._exit_stack.__aenter__() - await self._exit_stack.enter_async_context(self._events) - - # Initialize the data store and start relaying events to the worker's event broker - await self._exit_stack.enter_async_context(self.data_store) - self._exit_stack.enter_context( - self.data_store.events.subscribe(self._events.publish) - ) - - # Wake up the worker if the data store emits a significant job event - self._exit_stack.enter_context( - self.data_store.events.subscribe( - lambda event: self._wakeup_event.set(), {JobAdded} - ) - ) - - # Start the actual worker - task_group = create_task_group() - await self._exit_stack.enter_async_context(task_group) - await task_group.start(self.run) + async def __aenter__(self) -> AsyncWorker: + self._task_group = create_task_group() + await self._task_group.__aenter__() + await self._task_group.start(self.run_until_stopped) return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, + exc_type: type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: self._state = RunState.stopping self._wakeup_event.set() - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + del self._task_group del self._wakeup_event - async def run(self, *, task_status=TASK_STATUS_IGNORED) -> None: - if self._state is not RunState.starting: + async def run_until_stopped(self, *, task_status=TASK_STATUS_IGNORED) -> None: + if self._state is not RunState.stopped: raise RuntimeError( - f"This function cannot be called while the worker is in the " - f"{self._state} state" + f'Cannot start the worker when it is in the "{self._state}" ' f"state" ) - # Set the current worker - token = current_worker.set(self) + self._state = RunState.starting + self._wakeup_event = anyio.Event() + async with AsyncExitStack() as exit_stack: + if not self._is_internal: + # Initialize the event broker + await self.event_broker.start() + exit_stack.push_async_exit( + lambda *exc_info: self.event_broker.stop( + force=exc_info[0] is not None + ) + ) - # Signal that the worker has started - self._state = RunState.started - task_status.started() - await self._events.publish(WorkerStarted()) + # Initialize the data store + await self.data_store.start(self.event_broker) + exit_stack.push_async_exit( + lambda *exc_info: self.data_store.stop( + force=exc_info[0] is not None + ) + ) - exception: BaseException | None = None - try: - async with create_task_group() as tg: - while self._state is RunState.started: - limit = self.max_concurrent_jobs - len(self._running_jobs) - jobs = await self.data_store.acquire_jobs(self.identity, limit) - for job in jobs: - task = await self.data_store.get_task(job.task_id) - self._running_jobs.add(job.id) - tg.start_soon(self._run_job, job, task.func) - - await self._wakeup_event.wait() - self._wakeup_event = anyio.Event() - except get_cancelled_exc_class(): - pass - except BaseException as exc: - self.logger.exception("Worker crashed") - exception = exc - else: - self.logger.info("Worker stopped") - finally: - current_worker.reset(token) - self._state = RunState.stopped - with move_on_after(3, shield=True): - await self._events.publish(WorkerStopped(exception=exception)) + # Set the current worker + token = current_worker.set(self) + exit_stack.callback(current_worker.reset, token) + + # Wake up the worker if the data store emits a significant job event + self.event_broker.subscribe( + lambda event: self._wakeup_event.set(), {JobAdded} + ) + + # Signal that the worker has started + self._state = RunState.started + task_status.started() + exception: BaseException | None = None + try: + await self.event_broker.publish_local(WorkerStarted()) + + async with create_task_group() as tg: + while self._state is RunState.started: + limit = self.max_concurrent_jobs - len(self._running_jobs) + jobs = await self.data_store.acquire_jobs(self.identity, limit) + for job in jobs: + task = await self.data_store.get_task(job.task_id) + self._running_jobs.add(job.id) + tg.start_soon(self._run_job, job, task.func) + + await self._wakeup_event.wait() + self._wakeup_event = anyio.Event() + except get_cancelled_exc_class(): + pass + except BaseException as exc: + self.logger.exception("Worker crashed") + exception = exc + else: + self.logger.info("Worker stopped") + finally: + self._state = RunState.stopped + with move_on_after(3, shield=True): + await self.event_broker.publish_local( + WorkerStopped(exception=exception) + ) + + async def stop(self, *, force: bool = False) -> None: + if self._state in (RunState.starting, RunState.started): + self._state = RunState.stopping + event = anyio.Event() + self.event_broker.subscribe( + lambda ev: event.set(), {WorkerStopped}, one_shot=True + ) + if force: + self._task_group.cancel_scope.cancel() + else: + self._wakeup_event.set() + + await event.wait() async def _run_job(self, job: Job, func: Callable) -> None: try: diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py index 2da0045..5b5e6a9 100644 --- a/src/apscheduler/workers/sync.py +++ b/src/apscheduler/workers/sync.py @@ -1,19 +1,21 @@ from __future__ import annotations +import atexit import os import platform import threading -from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait +from concurrent.futures import Future, ThreadPoolExecutor from contextlib import ExitStack from contextvars import copy_context from datetime import datetime, timezone from logging import Logger, getLogger +from types import TracebackType from typing import Callable from uuid import UUID import attrs -from ..abc import DataStore, EventSource +from ..abc import DataStore, EventBroker from ..context import current_worker, job_info from ..enums import JobOutcome, RunState from ..eventbrokers.local import LocalEventBroker @@ -27,111 +29,151 @@ class Worker: """Runs jobs locally in a thread pool.""" data_store: DataStore + event_broker: EventBroker = attrs.field(factory=LocalEventBroker) max_concurrent_jobs: int = attrs.field( kw_only=True, validator=positive_integer, default=20 ) identity: str = attrs.field(kw_only=True, default=None) logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) + # True if a scheduler owns this worker + _is_internal: bool = attrs.field(kw_only=True, default=False) _state: RunState = attrs.field(init=False, default=RunState.stopped) - _wakeup_event: threading.Event = attrs.field(init=False) + _thread: threading.Thread | None = attrs.field(init=False, default=None) + _wakeup_event: threading.Event = attrs.field(init=False, factory=threading.Event) + _executor: ThreadPoolExecutor = attrs.field(init=False) _acquired_jobs: set[Job] = attrs.field(init=False, factory=set) - _events: LocalEventBroker = attrs.field(init=False, factory=LocalEventBroker) _running_jobs: set[UUID] = attrs.field(init=False, factory=set) - _exit_stack: ExitStack = attrs.field(init=False) - _executor: ThreadPoolExecutor = attrs.field(init=False) def __attrs_post_init__(self) -> None: if not self.identity: self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" @property - def events(self) -> EventSource: - return self._events - - @property def state(self) -> RunState: return self._state def __enter__(self) -> Worker: - self._state = RunState.starting - self._wakeup_event = threading.Event() - self._exit_stack = ExitStack() - self._exit_stack.__enter__() - self._exit_stack.enter_context(self._events) - - # Initialize the data store and start relaying events to the worker's event broker - self._exit_stack.enter_context(self.data_store) - self._exit_stack.enter_context( - self.data_store.events.subscribe(self._events.publish) - ) + self.start_in_background() + return self - # Wake up the worker if the data store emits a significant job event - self._exit_stack.enter_context( - self.data_store.events.subscribe( - lambda event: self._wakeup_event.set(), {JobAdded} - ) - ) + def __exit__( + self, + exc_type: type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: + self.stop() - # Start the worker and return when it has signalled readiness or raised an exception + def start_in_background(self) -> None: start_future: Future[None] = Future() - with self._events.subscribe(start_future.set_result, one_shot=True): - self._executor = ThreadPoolExecutor(1) - run_future = self._executor.submit(copy_context().run, self.run) - wait([start_future, run_future], return_when=FIRST_COMPLETED) + self._thread = threading.Thread( + target=copy_context().run, args=[self._run, start_future], daemon=True + ) + self._thread.start() + try: + start_future.result() + except BaseException: + self._thread = None + raise - if run_future.done(): - run_future.result() + atexit.register(self.stop) - return self + def stop(self) -> None: + atexit.unregister(self.stop) + if self._state is RunState.started: + self._state = RunState.stopping + self._wakeup_event.set() - def __exit__(self, exc_type, exc_val, exc_tb): - self._state = RunState.stopping - self._wakeup_event.set() - self._executor.shutdown(wait=exc_type is None) - self._exit_stack.__exit__(exc_type, exc_val, exc_tb) - del self._wakeup_event - - def run(self) -> None: - if self._state is not RunState.starting: - raise RuntimeError( - f"This function cannot be called while the worker is in the " - f"{self._state} state" - ) - - # Set the current worker - token = current_worker.set(self) - - # Signal that the worker has started - self._state = RunState.started - self._events.publish(WorkerStarted()) - - executor = ThreadPoolExecutor(max_workers=self.max_concurrent_jobs) - exception: BaseException | None = None - try: - while self._state is RunState.started: - available_slots = self.max_concurrent_jobs - len(self._running_jobs) - if available_slots: - jobs = self.data_store.acquire_jobs(self.identity, available_slots) - for job in jobs: - task = self.data_store.get_task(job.task_id) - self._running_jobs.add(job.id) - executor.submit( - copy_context().run, self._run_job, job, task.func + if threading.current_thread() != self._thread: + self._thread.join() + self._thread = None + + def run_until_stopped(self) -> None: + self._run(None) + + def _run(self, start_future: Future[None] | None) -> None: + with ExitStack() as exit_stack: + try: + if self._state is not RunState.stopped: + raise RuntimeError( + f'Cannot start the worker when it is in the "{self._state}" ' + f"state" + ) + + if not self._is_internal: + # Initialize the event broker + self.event_broker.start() + exit_stack.push( + lambda *exc_info: self.event_broker.stop( + force=exc_info[0] is not None ) + ) - self._wakeup_event.wait() - self._wakeup_event = threading.Event() - except BaseException as exc: - self.logger.exception("Worker crashed") - exception = exc - else: - self.logger.info("Worker stopped") - finally: - current_worker.reset(token) - self._state = RunState.stopped - self._events.publish(WorkerStopped(exception=exception)) - executor.shutdown() + # Initialize the data store + self.data_store.start(self.event_broker) + exit_stack.push( + lambda *exc_info: self.data_store.stop( + force=exc_info[0] is not None + ) + ) + + # Set the current worker + token = current_worker.set(self) + exit_stack.callback(current_worker.reset, token) + + # Wake up the worker if the data store emits a significant job event + exit_stack.enter_context( + self.event_broker.subscribe( + lambda event: self._wakeup_event.set(), {JobAdded} + ) + ) + + # Initialize the thread pool + executor = ThreadPoolExecutor(max_workers=self.max_concurrent_jobs) + exit_stack.enter_context(executor) + + # Signal that the worker has started + self._state = RunState.started + self.event_broker.publish_local(WorkerStarted()) + except BaseException as exc: + if start_future: + start_future.set_exception(exc) + return + else: + raise + else: + if start_future: + start_future.set_result(None) + + try: + while self._state is RunState.started: + available_slots = self.max_concurrent_jobs - len(self._running_jobs) + if available_slots: + jobs = self.data_store.acquire_jobs( + self.identity, available_slots + ) + for job in jobs: + task = self.data_store.get_task(job.task_id) + self._running_jobs.add(job.id) + executor.submit( + copy_context().run, self._run_job, job, task.func + ) + + self._wakeup_event.wait() + self._wakeup_event = threading.Event() + except BaseException as exc: + self._state = RunState.stopped + if isinstance(exc, Exception): + self.logger.exception("Worker crashed") + else: + self.logger.info(f"Worker stopped due to {exc.__class__.__name__}") + + self.event_broker.publish_local(WorkerStopped(exception=exc)) + else: + self._state = RunState.stopped + self.logger.info("Worker stopped") + self.event_broker.publish_local(WorkerStopped()) def _run_job(self, job: Job, func: Callable) -> None: try: diff --git a/tests/test_datastores.py b/tests/test_datastores.py index d6376ea..ea5416e 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -1,19 +1,31 @@ from __future__ import annotations -from contextlib import asynccontextmanager +import threading +from collections.abc import Generator +from contextlib import asynccontextmanager, contextmanager from datetime import datetime, timezone from tempfile import TemporaryDirectory -from typing import AsyncGenerator +from typing import Any, AsyncGenerator, cast import anyio import pytest +from _pytest.fixtures import SubRequest from freezegun.api import FrozenDateTimeFactory from pytest_lazyfixture import lazy_fixture -from apscheduler.abc import AsyncDataStore, DataStore, Job, Schedule +from apscheduler.abc import ( + AsyncDataStore, + AsyncEventBroker, + DataStore, + EventBroker, + Job, + Schedule, +) from apscheduler.datastores.async_adapter import AsyncDataStoreAdapter from apscheduler.datastores.memory import MemoryDataStore from apscheduler.enums import CoalescePolicy, ConflictPolicy, JobOutcome +from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker +from apscheduler.eventbrokers.local import LocalEventBroker from apscheduler.events import ( Event, ScheduleAdded, @@ -32,6 +44,12 @@ def memory_store() -> DataStore: @pytest.fixture +def adapted_memory_store() -> AsyncDataStore: + store = MemoryDataStore() + return AsyncDataStoreAdapter(store) + + +@pytest.fixture def mongodb_store() -> DataStore: from pymongo import MongoClient @@ -96,39 +114,6 @@ async def asyncpg_store() -> AsyncDataStore: await engine.dispose() -@pytest.fixture( - params=[ - pytest.param(lazy_fixture("memory_store"), id="memory"), - pytest.param(lazy_fixture("sqlite_store"), id="sqlite"), - pytest.param( - lazy_fixture("mongodb_store"), - id="mongodb", - marks=[pytest.mark.external_service], - ), - pytest.param( - lazy_fixture("psycopg2_store"), - id="psycopg2", - marks=[pytest.mark.external_service], - ), - pytest.param( - lazy_fixture("asyncpg_store"), - id="asyncpg", - marks=[pytest.mark.external_service], - ), - pytest.param( - lazy_fixture("mysql_store"), - id="mysql", - marks=[pytest.mark.external_service], - ), - ] -) -async def datastore(request): - if isinstance(request.param, DataStore): - return AsyncDataStoreAdapter(request.param) - else: - return request.param - - @pytest.fixture def schedules() -> list[Schedule]: trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) @@ -144,32 +129,504 @@ def schedules() -> list[Schedule]: return [schedule1, schedule2, schedule3] -@asynccontextmanager -async def capture_events( - datastore: AsyncDataStore, limit: int, event_types: set[type[Event]] | None = None -) -> AsyncGenerator[list[Event], None]: - def listener(event: Event) -> None: - events.append(event) - if len(events) == limit: - limit_event.set() - subscription.unsubscribe() +class TestDataStores: + @contextmanager + def capture_events( + self, + datastore: DataStore, + limit: int, + event_types: set[type[Event]] | None = None, + ) -> Generator[list[Event], None, None]: + def listener(event: Event) -> None: + events.append(event) + if len(events) == limit: + limit_event.set() + subscription.unsubscribe() + + events: list[Event] = [] + limit_event = threading.Event() + subscription = datastore.events.subscribe(listener, event_types) + yield events + if limit: + limit_event.wait(2) + + @pytest.fixture + def event_broker(self) -> Generator[EventBroker, Any, None]: + broker = LocalEventBroker() + broker.start() + yield broker + broker.stop() + + @pytest.fixture( + params=[ + pytest.param(lazy_fixture("memory_store"), id="memory"), + pytest.param(lazy_fixture("sqlite_store"), id="sqlite"), + pytest.param( + lazy_fixture("mongodb_store"), + id="mongodb", + marks=[pytest.mark.external_service], + ), + pytest.param( + lazy_fixture("psycopg2_store"), + id="psycopg2", + marks=[pytest.mark.external_service], + ), + pytest.param( + lazy_fixture("mysql_store"), + id="mysql", + marks=[pytest.mark.external_service], + ), + ] + ) + def datastore( + self, request: SubRequest, event_broker: EventBroker + ) -> Generator[DataStore, Any, None]: + datastore = cast(DataStore, request.param) + datastore.start(event_broker) + yield datastore + datastore.stop() + + def test_add_replace_task(self, datastore: DataStore) -> None: + import math - events: list[Event] = [] - limit_event = anyio.Event() - subscription = datastore.events.subscribe(listener, event_types) - yield events - if limit: - with anyio.fail_after(3): - await limit_event.wait() + event_types = {TaskAdded, TaskUpdated} + with self.capture_events(datastore, 3, event_types) as events: + datastore.add_task(Task(id="test_task", func=print)) + datastore.add_task(Task(id="test_task2", func=math.ceil)) + datastore.add_task(Task(id="test_task", func=repr)) + + tasks = datastore.get_tasks() + assert len(tasks) == 2 + assert tasks[0].id == "test_task" + assert tasks[0].func is repr + assert tasks[1].id == "test_task2" + assert tasks[1].func is math.ceil + + received_event = events.pop(0) + assert isinstance(received_event, TaskAdded) + assert received_event.task_id == "test_task" + + received_event = events.pop(0) + assert isinstance(received_event, TaskAdded) + assert received_event.task_id == "test_task2" + + received_event = events.pop(0) + assert isinstance(received_event, TaskUpdated) + assert received_event.task_id == "test_task" + + assert not events + + def test_add_schedules( + self, datastore: DataStore, schedules: list[Schedule] + ) -> None: + with self.capture_events(datastore, 3, {ScheduleAdded}) as events: + for schedule in schedules: + datastore.add_schedule(schedule, ConflictPolicy.exception) + + assert datastore.get_schedules() == schedules + assert datastore.get_schedules({"s1", "s2", "s3"}) == schedules + assert datastore.get_schedules({"s1"}) == [schedules[0]] + assert datastore.get_schedules({"s2"}) == [schedules[1]] + assert datastore.get_schedules({"s3"}) == [schedules[2]] + + for event, schedule in zip(events, schedules): + assert event.schedule_id == schedule.id + assert event.next_fire_time == schedule.next_fire_time + + def test_replace_schedules( + self, datastore: DataStore, schedules: list[Schedule] + ) -> None: + with self.capture_events(datastore, 1, {ScheduleUpdated}) as events: + for schedule in schedules: + datastore.add_schedule(schedule, ConflictPolicy.exception) + + next_fire_time = schedules[2].trigger.next() + schedule = Schedule( + id="s3", + task_id="foo", + trigger=schedules[2].trigger, + args=(), + kwargs={}, + coalesce=CoalescePolicy.earliest, + misfire_grace_time=None, + tags=frozenset(), + ) + schedule.next_fire_time = next_fire_time + datastore.add_schedule(schedule, ConflictPolicy.replace) + + schedules = datastore.get_schedules({schedule.id}) + assert schedules[0].task_id == "foo" + assert schedules[0].next_fire_time == next_fire_time + assert schedules[0].args == () + assert schedules[0].kwargs == {} + assert schedules[0].coalesce is CoalescePolicy.earliest + assert schedules[0].misfire_grace_time is None + assert schedules[0].tags == frozenset() + + received_event = events.pop(0) + assert received_event.schedule_id == "s3" + assert received_event.next_fire_time == datetime( + 2020, 9, 15, tzinfo=timezone.utc + ) + assert not events + + def test_remove_schedules( + self, datastore: DataStore, schedules: list[Schedule] + ) -> None: + with self.capture_events(datastore, 2, {ScheduleRemoved}) as events: + for schedule in schedules: + datastore.add_schedule(schedule, ConflictPolicy.exception) + + datastore.remove_schedules(["s1", "s2"]) + assert datastore.get_schedules() == [schedules[2]] + + received_event = events.pop(0) + assert received_event.schedule_id == "s1" + + received_event = events.pop(0) + assert received_event.schedule_id == "s2" + + assert not events + + @pytest.mark.freeze_time(datetime(2020, 9, 14, tzinfo=timezone.utc)) + def test_acquire_release_schedules( + self, datastore: DataStore, schedules: list[Schedule] + ) -> None: + event_types = {ScheduleRemoved, ScheduleUpdated} + with self.capture_events(datastore, 2, event_types) as events: + for schedule in schedules: + datastore.add_schedule(schedule, ConflictPolicy.exception) + + # The first scheduler gets the first due schedule + schedules1 = datastore.acquire_schedules("dummy-id1", 1) + assert len(schedules1) == 1 + assert schedules1[0].id == "s1" + + # The second scheduler gets the second due schedule + schedules2 = datastore.acquire_schedules("dummy-id2", 1) + assert len(schedules2) == 1 + assert schedules2[0].id == "s2" + + # The third scheduler gets nothing + schedules3 = datastore.acquire_schedules("dummy-id3", 1) + assert not schedules3 + + # Update the schedules and check that the job store actually deletes the first + # one and updates the second one + schedules1[0].next_fire_time = None + schedules2[0].next_fire_time = datetime(2020, 9, 15, tzinfo=timezone.utc) + + # Release all the schedules + datastore.release_schedules("dummy-id1", schedules1) + datastore.release_schedules("dummy-id2", schedules2) + + # Check that the first schedule is gone + schedules = datastore.get_schedules() + assert len(schedules) == 2 + assert schedules[0].id == "s2" + assert schedules[1].id == "s3" + + # Check for the appropriate update and delete events + received_event = events.pop(0) + assert isinstance(received_event, ScheduleRemoved) + assert received_event.schedule_id == "s1" + + received_event = events.pop(0) + assert isinstance(received_event, ScheduleUpdated) + assert received_event.schedule_id == "s2" + assert received_event.next_fire_time == datetime( + 2020, 9, 15, tzinfo=timezone.utc + ) + + assert not events + + def test_release_schedule_two_identical_fire_times( + self, datastore: DataStore + ) -> None: + """Regression test for #616.""" + for i in range(1, 3): + trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) + schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger) + schedule.next_fire_time = trigger.next() + datastore.add_schedule(schedule, ConflictPolicy.exception) + + schedules = datastore.acquire_schedules("foo", 3) + schedules[0].next_fire_time = None + datastore.release_schedules("foo", schedules) + + remaining = datastore.get_schedules({s.id for s in schedules}) + assert len(remaining) == 1 + assert remaining[0].id == schedules[1].id + + def test_release_two_schedules_at_once(self, datastore: DataStore) -> None: + """Regression test for #621.""" + for i in range(2): + trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) + schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger) + schedule.next_fire_time = trigger.next() + datastore.add_schedule(schedule, ConflictPolicy.exception) + + schedules = datastore.acquire_schedules("foo", 3) + datastore.release_schedules("foo", schedules) + + remaining = datastore.get_schedules({s.id for s in schedules}) + assert len(remaining) == 2 + + def test_acquire_schedules_lock_timeout( + self, datastore: DataStore, schedules: list[Schedule], freezer + ) -> None: + """ + Test that a scheduler can acquire schedules that were acquired by another scheduler but + not released within the lock timeout period. + + """ + datastore.add_schedule(schedules[0], ConflictPolicy.exception) + + # First, one scheduler acquires the first available schedule + acquired1 = datastore.acquire_schedules("dummy-id1", 1) + assert len(acquired1) == 1 + assert acquired1[0].id == "s1" + + # Try to acquire the schedule just at the threshold (now == acquired_until). + # This should not yield any schedules. + freezer.tick(30) + acquired2 = datastore.acquire_schedules("dummy-id2", 1) + assert not acquired2 + + # Right after that, the schedule should be available + freezer.tick(1) + acquired3 = datastore.acquire_schedules("dummy-id2", 1) + assert len(acquired3) == 1 + assert acquired3[0].id == "s1" + + def test_acquire_multiple_workers(self, datastore: DataStore) -> None: + datastore.add_task(Task(id="task1", func=asynccontextmanager)) + jobs = [Job(task_id="task1") for _ in range(2)] + for job in jobs: + datastore.add_job(job) + + # The first worker gets the first job in the queue + jobs1 = datastore.acquire_jobs("worker1", 1) + assert len(jobs1) == 1 + assert jobs1[0].id == jobs[0].id + + # The second worker gets the second job + jobs2 = datastore.acquire_jobs("worker2", 1) + assert len(jobs2) == 1 + assert jobs2[0].id == jobs[1].id + + # The third worker gets nothing + jobs3 = datastore.acquire_jobs("worker3", 1) + assert not jobs3 + + def test_job_release_success(self, datastore: DataStore) -> None: + datastore.add_task(Task(id="task1", func=asynccontextmanager)) + job = Job(task_id="task1") + datastore.add_job(job) + + acquired = datastore.acquire_jobs("worker_id", 2) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + datastore.release_job( + "worker_id", + acquired[0].task_id, + JobResult( + job_id=acquired[0].id, + outcome=JobOutcome.success, + return_value="foo", + ), + ) + result = datastore.get_job_result(acquired[0].id) + assert result.outcome is JobOutcome.success + assert result.exception is None + assert result.return_value == "foo" + + # Check that the job and its result are gone + assert not datastore.get_jobs({acquired[0].id}) + assert not datastore.get_job_result(acquired[0].id) + + def test_job_release_failure(self, datastore: DataStore) -> None: + datastore.add_task(Task(id="task1", func=asynccontextmanager)) + job = Job(task_id="task1") + datastore.add_job(job) + + acquired = datastore.acquire_jobs("worker_id", 2) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + datastore.release_job( + "worker_id", + acquired[0].task_id, + JobResult( + job_id=acquired[0].id, + outcome=JobOutcome.error, + exception=ValueError("foo"), + ), + ) + result = datastore.get_job_result(acquired[0].id) + assert result.outcome is JobOutcome.error + assert isinstance(result.exception, ValueError) + assert result.exception.args == ("foo",) + assert result.return_value is None + + # Check that the job and its result are gone + assert not datastore.get_jobs({acquired[0].id}) + assert not datastore.get_job_result(acquired[0].id) + + def test_job_release_missed_deadline(self, datastore: DataStore): + datastore.add_task(Task(id="task1", func=asynccontextmanager)) + job = Job(task_id="task1") + datastore.add_job(job) + + acquired = datastore.acquire_jobs("worker_id", 2) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + datastore.release_job( + "worker_id", + acquired[0].task_id, + JobResult(job_id=acquired[0].id, outcome=JobOutcome.missed_start_deadline), + ) + result = datastore.get_job_result(acquired[0].id) + assert result.outcome is JobOutcome.missed_start_deadline + assert result.exception is None + assert result.return_value is None + + # Check that the job and its result are gone + assert not datastore.get_jobs({acquired[0].id}) + assert not datastore.get_job_result(acquired[0].id) + + def test_job_release_cancelled(self, datastore: DataStore) -> None: + datastore.add_task(Task(id="task1", func=asynccontextmanager)) + job = Job(task_id="task1") + datastore.add_job(job) + + acquired = datastore.acquire_jobs("worker1", 2) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + datastore.release_job( + "worker1", + acquired[0].task_id, + JobResult(job_id=acquired[0].id, outcome=JobOutcome.cancelled), + ) + result = datastore.get_job_result(acquired[0].id) + assert result.outcome is JobOutcome.cancelled + assert result.exception is None + assert result.return_value is None + + # Check that the job and its result are gone + assert not datastore.get_jobs({acquired[0].id}) + assert not datastore.get_job_result(acquired[0].id) + + def test_acquire_jobs_lock_timeout( + self, datastore: DataStore, freezer: FrozenDateTimeFactory + ) -> None: + """ + Test that a worker can acquire jobs that were acquired by another scheduler but not + released within the lock timeout period. + + """ + datastore.add_task(Task(id="task1", func=asynccontextmanager)) + job = Job(task_id="task1") + datastore.add_job(job) + + # First, one worker acquires the first available job + acquired = datastore.acquire_jobs("worker1", 1) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + # Try to acquire the job just at the threshold (now == acquired_until). + # This should not yield any jobs. + freezer.tick(30) + assert not datastore.acquire_jobs("worker2", 1) + + # Right after that, the job should be available + freezer.tick(1) + acquired = datastore.acquire_jobs("worker2", 1) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + def test_acquire_jobs_max_number_exceeded(self, datastore: DataStore) -> None: + datastore.add_task( + Task(id="task1", func=asynccontextmanager, max_running_jobs=2) + ) + jobs = [Job(task_id="task1"), Job(task_id="task1"), Job(task_id="task1")] + for job in jobs: + datastore.add_job(job) + + # Check that only 2 jobs are returned from acquire_jobs() even though the limit wqas 3 + acquired_jobs = datastore.acquire_jobs("worker1", 3) + assert [job.id for job in acquired_jobs] == [job.id for job in jobs[:2]] + + # Release one job, and the worker should be able to acquire the third job + datastore.release_job( + "worker1", + acquired_jobs[0].task_id, + JobResult( + job_id=acquired_jobs[0].id, + outcome=JobOutcome.success, + return_value=None, + ), + ) + acquired_jobs = datastore.acquire_jobs("worker1", 3) + assert [job.id for job in acquired_jobs] == [jobs[2].id] @pytest.mark.anyio -class TestAsyncStores: +class TestAsyncDataStores: + @asynccontextmanager + async def capture_events( + self, + datastore: AsyncDataStore, + limit: int, + event_types: set[type[Event]] | None = None, + ) -> AsyncGenerator[list[Event], None]: + def listener(event: Event) -> None: + events.append(event) + if len(events) == limit: + limit_event.set() + subscription.unsubscribe() + + events: list[Event] = [] + limit_event = anyio.Event() + subscription = datastore.events.subscribe(listener, event_types) + yield events + if limit: + with anyio.fail_after(3): + await limit_event.wait() + + @pytest.fixture + async def event_broker(self) -> AsyncGenerator[AsyncEventBroker, Any]: + broker = LocalAsyncEventBroker() + await broker.start() + yield broker + await broker.stop() + + @pytest.fixture( + params=[ + pytest.param(lazy_fixture("adapted_memory_store"), id="memory"), + pytest.param( + lazy_fixture("asyncpg_store"), + id="asyncpg", + marks=[pytest.mark.external_service], + ), + ] + ) + async def datastore( + self, request: SubRequest, event_broker: AsyncEventBroker + ) -> AsyncGenerator[AsyncDataStore, Any]: + datastore = cast(AsyncDataStore, request.param) + await datastore.start(event_broker) + yield datastore + await datastore.stop() + async def test_add_replace_task(self, datastore: AsyncDataStore) -> None: import math event_types = {TaskAdded, TaskUpdated} - async with datastore, capture_events(datastore, 3, event_types) as events: + async with self.capture_events(datastore, 3, event_types) as events: await datastore.add_task(Task(id="test_task", func=print)) await datastore.add_task(Task(id="test_task2", func=math.ceil)) await datastore.add_task(Task(id="test_task", func=repr)) @@ -198,7 +655,7 @@ class TestAsyncStores: async def test_add_schedules( self, datastore: AsyncDataStore, schedules: list[Schedule] ) -> None: - async with datastore, capture_events(datastore, 3, {ScheduleAdded}) as events: + async with self.capture_events(datastore, 3, {ScheduleAdded}) as events: for schedule in schedules: await datastore.add_schedule(schedule, ConflictPolicy.exception) @@ -215,7 +672,7 @@ class TestAsyncStores: async def test_replace_schedules( self, datastore: AsyncDataStore, schedules: list[Schedule] ) -> None: - async with datastore, capture_events(datastore, 1, {ScheduleUpdated}) as events: + async with self.capture_events(datastore, 1, {ScheduleUpdated}) as events: for schedule in schedules: await datastore.add_schedule(schedule, ConflictPolicy.exception) @@ -252,7 +709,7 @@ class TestAsyncStores: async def test_remove_schedules( self, datastore: AsyncDataStore, schedules: list[Schedule] ) -> None: - async with datastore, capture_events(datastore, 2, {ScheduleRemoved}) as events: + async with self.capture_events(datastore, 2, {ScheduleRemoved}) as events: for schedule in schedules: await datastore.add_schedule(schedule, ConflictPolicy.exception) @@ -272,7 +729,7 @@ class TestAsyncStores: self, datastore: AsyncDataStore, schedules: list[Schedule] ) -> None: event_types = {ScheduleRemoved, ScheduleUpdated} - async with datastore, capture_events(datastore, 2, event_types) as events: + async with self.capture_events(datastore, 2, event_types) as events: for schedule in schedules: await datastore.add_schedule(schedule, ConflictPolicy.exception) @@ -323,16 +780,15 @@ class TestAsyncStores: self, datastore: AsyncDataStore ) -> None: """Regression test for #616.""" - async with datastore: - for i in range(1, 3): - trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) - schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger) - schedule.next_fire_time = trigger.next() - await datastore.add_schedule(schedule, ConflictPolicy.exception) + for i in range(1, 3): + trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) + schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger) + schedule.next_fire_time = trigger.next() + await datastore.add_schedule(schedule, ConflictPolicy.exception) - schedules = await datastore.acquire_schedules("foo", 3) - schedules[0].next_fire_time = None - await datastore.release_schedules("foo", schedules) + schedules = await datastore.acquire_schedules("foo", 3) + schedules[0].next_fire_time = None + await datastore.release_schedules("foo", schedules) remaining = await datastore.get_schedules({s.id for s in schedules}) assert len(remaining) == 1 @@ -342,15 +798,14 @@ class TestAsyncStores: self, datastore: AsyncDataStore ) -> None: """Regression test for #621.""" - async with datastore: - for i in range(2): - trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) - schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger) - schedule.next_fire_time = trigger.next() - await datastore.add_schedule(schedule, ConflictPolicy.exception) + for i in range(2): + trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) + schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger) + schedule.next_fire_time = trigger.next() + await datastore.add_schedule(schedule, ConflictPolicy.exception) - schedules = await datastore.acquire_schedules("foo", 3) - await datastore.release_schedules("foo", schedules) + schedules = await datastore.acquire_schedules("foo", 3) + await datastore.release_schedules("foo", schedules) remaining = await datastore.get_schedules({s.id for s in schedules}) assert len(remaining) == 2 @@ -363,153 +818,145 @@ class TestAsyncStores: not released within the lock timeout period. """ - async with datastore: - await datastore.add_schedule(schedules[0], ConflictPolicy.exception) - - # First, one scheduler acquires the first available schedule - acquired1 = await datastore.acquire_schedules("dummy-id1", 1) - assert len(acquired1) == 1 - assert acquired1[0].id == "s1" - - # Try to acquire the schedule just at the threshold (now == acquired_until). - # This should not yield any schedules. - freezer.tick(30) - acquired2 = await datastore.acquire_schedules("dummy-id2", 1) - assert not acquired2 - - # Right after that, the schedule should be available - freezer.tick(1) - acquired3 = await datastore.acquire_schedules("dummy-id2", 1) - assert len(acquired3) == 1 - assert acquired3[0].id == "s1" + await datastore.add_schedule(schedules[0], ConflictPolicy.exception) - async def test_acquire_multiple_workers(self, datastore: AsyncDataStore) -> None: - async with datastore: - await datastore.add_task(Task(id="task1", func=asynccontextmanager)) - jobs = [Job(task_id="task1") for _ in range(2)] - for job in jobs: - await datastore.add_job(job) - - # The first worker gets the first job in the queue - jobs1 = await datastore.acquire_jobs("worker1", 1) - assert len(jobs1) == 1 - assert jobs1[0].id == jobs[0].id - - # The second worker gets the second job - jobs2 = await datastore.acquire_jobs("worker2", 1) - assert len(jobs2) == 1 - assert jobs2[0].id == jobs[1].id - - # The third worker gets nothing - jobs3 = await datastore.acquire_jobs("worker3", 1) - assert not jobs3 - - async def test_job_release_success(self, datastore: AsyncDataStore) -> None: - async with datastore: - await datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1") - await datastore.add_job(job) + # First, one scheduler acquires the first available schedule + acquired1 = await datastore.acquire_schedules("dummy-id1", 1) + assert len(acquired1) == 1 + assert acquired1[0].id == "s1" - acquired = await datastore.acquire_jobs("worker_id", 2) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - await datastore.release_job( - "worker_id", - acquired[0].task_id, - JobResult( - job_id=acquired[0].id, - outcome=JobOutcome.success, - return_value="foo", - ), - ) - result = await datastore.get_job_result(acquired[0].id) - assert result.outcome is JobOutcome.success - assert result.exception is None - assert result.return_value == "foo" + # Try to acquire the schedule just at the threshold (now == acquired_until). + # This should not yield any schedules. + freezer.tick(30) + acquired2 = await datastore.acquire_schedules("dummy-id2", 1) + assert not acquired2 - # Check that the job and its result are gone - assert not await datastore.get_jobs({acquired[0].id}) - assert not await datastore.get_job_result(acquired[0].id) + # Right after that, the schedule should be available + freezer.tick(1) + acquired3 = await datastore.acquire_schedules("dummy-id2", 1) + assert len(acquired3) == 1 + assert acquired3[0].id == "s1" - async def test_job_release_failure(self, datastore: AsyncDataStore) -> None: - async with datastore: - await datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1") + async def test_acquire_multiple_workers(self, datastore: AsyncDataStore) -> None: + await datastore.add_task(Task(id="task1", func=asynccontextmanager)) + jobs = [Job(task_id="task1") for _ in range(2)] + for job in jobs: await datastore.add_job(job) - acquired = await datastore.acquire_jobs("worker_id", 2) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - await datastore.release_job( - "worker_id", - acquired[0].task_id, - JobResult( - job_id=acquired[0].id, - outcome=JobOutcome.error, - exception=ValueError("foo"), - ), - ) - result = await datastore.get_job_result(acquired[0].id) - assert result.outcome is JobOutcome.error - assert isinstance(result.exception, ValueError) - assert result.exception.args == ("foo",) - assert result.return_value is None + # The first worker gets the first job in the queue + jobs1 = await datastore.acquire_jobs("worker1", 1) + assert len(jobs1) == 1 + assert jobs1[0].id == jobs[0].id - # Check that the job and its result are gone - assert not await datastore.get_jobs({acquired[0].id}) - assert not await datastore.get_job_result(acquired[0].id) + # The second worker gets the second job + jobs2 = await datastore.acquire_jobs("worker2", 1) + assert len(jobs2) == 1 + assert jobs2[0].id == jobs[1].id - async def test_job_release_missed_deadline(self, datastore: AsyncDataStore): - async with datastore: - await datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1") - await datastore.add_job(job) + # The third worker gets nothing + jobs3 = await datastore.acquire_jobs("worker3", 1) + assert not jobs3 + + async def test_job_release_success(self, datastore: AsyncDataStore) -> None: + await datastore.add_task(Task(id="task1", func=asynccontextmanager)) + job = Job(task_id="task1") + await datastore.add_job(job) + + acquired = await datastore.acquire_jobs("worker_id", 2) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + await datastore.release_job( + "worker_id", + acquired[0].task_id, + JobResult( + job_id=acquired[0].id, + outcome=JobOutcome.success, + return_value="foo", + ), + ) + result = await datastore.get_job_result(acquired[0].id) + assert result.outcome is JobOutcome.success + assert result.exception is None + assert result.return_value == "foo" - acquired = await datastore.acquire_jobs("worker_id", 2) - assert len(acquired) == 1 - assert acquired[0].id == job.id + # Check that the job and its result are gone + assert not await datastore.get_jobs({acquired[0].id}) + assert not await datastore.get_job_result(acquired[0].id) - await datastore.release_job( - "worker_id", - acquired[0].task_id, - JobResult( - job_id=acquired[0].id, outcome=JobOutcome.missed_start_deadline - ), - ) - result = await datastore.get_job_result(acquired[0].id) - assert result.outcome is JobOutcome.missed_start_deadline - assert result.exception is None - assert result.return_value is None + async def test_job_release_failure(self, datastore: AsyncDataStore) -> None: + await datastore.add_task(Task(id="task1", func=asynccontextmanager)) + job = Job(task_id="task1") + await datastore.add_job(job) + + acquired = await datastore.acquire_jobs("worker_id", 2) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + await datastore.release_job( + "worker_id", + acquired[0].task_id, + JobResult( + job_id=acquired[0].id, + outcome=JobOutcome.error, + exception=ValueError("foo"), + ), + ) + result = await datastore.get_job_result(acquired[0].id) + assert result.outcome is JobOutcome.error + assert isinstance(result.exception, ValueError) + assert result.exception.args == ("foo",) + assert result.return_value is None - # Check that the job and its result are gone - assert not await datastore.get_jobs({acquired[0].id}) - assert not await datastore.get_job_result(acquired[0].id) + # Check that the job and its result are gone + assert not await datastore.get_jobs({acquired[0].id}) + assert not await datastore.get_job_result(acquired[0].id) - async def test_job_release_cancelled(self, datastore: AsyncDataStore) -> None: - async with datastore: - await datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1") - await datastore.add_job(job) + async def test_job_release_missed_deadline(self, datastore: AsyncDataStore): + await datastore.add_task(Task(id="task1", func=asynccontextmanager)) + job = Job(task_id="task1") + await datastore.add_job(job) + + acquired = await datastore.acquire_jobs("worker_id", 2) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + await datastore.release_job( + "worker_id", + acquired[0].task_id, + JobResult(job_id=acquired[0].id, outcome=JobOutcome.missed_start_deadline), + ) + result = await datastore.get_job_result(acquired[0].id) + assert result.outcome is JobOutcome.missed_start_deadline + assert result.exception is None + assert result.return_value is None - acquired = await datastore.acquire_jobs("worker1", 2) - assert len(acquired) == 1 - assert acquired[0].id == job.id + # Check that the job and its result are gone + assert not await datastore.get_jobs({acquired[0].id}) + assert not await datastore.get_job_result(acquired[0].id) - await datastore.release_job( - "worker1", - acquired[0].task_id, - JobResult(job_id=acquired[0].id, outcome=JobOutcome.cancelled), - ) - result = await datastore.get_job_result(acquired[0].id) - assert result.outcome is JobOutcome.cancelled - assert result.exception is None - assert result.return_value is None + async def test_job_release_cancelled(self, datastore: AsyncDataStore) -> None: + await datastore.add_task(Task(id="task1", func=asynccontextmanager)) + job = Job(task_id="task1") + await datastore.add_job(job) + + acquired = await datastore.acquire_jobs("worker1", 2) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + await datastore.release_job( + "worker1", + acquired[0].task_id, + JobResult(job_id=acquired[0].id, outcome=JobOutcome.cancelled), + ) + result = await datastore.get_job_result(acquired[0].id) + assert result.outcome is JobOutcome.cancelled + assert result.exception is None + assert result.return_value is None - # Check that the job and its result are gone - assert not await datastore.get_jobs({acquired[0].id}) - assert not await datastore.get_job_result(acquired[0].id) + # Check that the job and its result are gone + assert not await datastore.get_jobs({acquired[0].id}) + assert not await datastore.get_job_result(acquired[0].id) async def test_acquire_jobs_lock_timeout( self, datastore: AsyncDataStore, freezer: FrozenDateTimeFactory @@ -519,51 +966,49 @@ class TestAsyncStores: released within the lock timeout period. """ - async with datastore: - await datastore.add_task(Task(id="task1", func=asynccontextmanager)) - job = Job(task_id="task1") - await datastore.add_job(job) - - # First, one worker acquires the first available job - acquired = await datastore.acquire_jobs("worker1", 1) - assert len(acquired) == 1 - assert acquired[0].id == job.id - - # Try to acquire the job just at the threshold (now == acquired_until). - # This should not yield any jobs. - freezer.tick(30) - assert not await datastore.acquire_jobs("worker2", 1) - - # Right after that, the job should be available - freezer.tick(1) - acquired = await datastore.acquire_jobs("worker2", 1) - assert len(acquired) == 1 - assert acquired[0].id == job.id + await datastore.add_task(Task(id="task1", func=asynccontextmanager)) + job = Job(task_id="task1") + await datastore.add_job(job) + + # First, one worker acquires the first available job + acquired = await datastore.acquire_jobs("worker1", 1) + assert len(acquired) == 1 + assert acquired[0].id == job.id + + # Try to acquire the job just at the threshold (now == acquired_until). + # This should not yield any jobs. + freezer.tick(30) + assert not await datastore.acquire_jobs("worker2", 1) + + # Right after that, the job should be available + freezer.tick(1) + acquired = await datastore.acquire_jobs("worker2", 1) + assert len(acquired) == 1 + assert acquired[0].id == job.id async def test_acquire_jobs_max_number_exceeded( self, datastore: AsyncDataStore ) -> None: - async with datastore: - await datastore.add_task( - Task(id="task1", func=asynccontextmanager, max_running_jobs=2) - ) - jobs = [Job(task_id="task1"), Job(task_id="task1"), Job(task_id="task1")] - for job in jobs: - await datastore.add_job(job) - - # Check that only 2 jobs are returned from acquire_jobs() even though the limit wqas 3 - acquired_jobs = await datastore.acquire_jobs("worker1", 3) - assert [job.id for job in acquired_jobs] == [job.id for job in jobs[:2]] - - # Release one job, and the worker should be able to acquire the third job - await datastore.release_job( - "worker1", - acquired_jobs[0].task_id, - JobResult( - job_id=acquired_jobs[0].id, - outcome=JobOutcome.success, - return_value=None, - ), - ) - acquired_jobs = await datastore.acquire_jobs("worker1", 3) - assert [job.id for job in acquired_jobs] == [jobs[2].id] + await datastore.add_task( + Task(id="task1", func=asynccontextmanager, max_running_jobs=2) + ) + jobs = [Job(task_id="task1"), Job(task_id="task1"), Job(task_id="task1")] + for job in jobs: + await datastore.add_job(job) + + # Check that only 2 jobs are returned from acquire_jobs() even though the limit wqas 3 + acquired_jobs = await datastore.acquire_jobs("worker1", 3) + assert [job.id for job in acquired_jobs] == [job.id for job in jobs[:2]] + + # Release one job, and the worker should be able to acquire the third job + await datastore.release_job( + "worker1", + acquired_jobs[0].task_id, + JobResult( + job_id=acquired_jobs[0].id, + outcome=JobOutcome.success, + return_value=None, + ), + ) + acquired_jobs = await datastore.acquire_jobs("worker1", 3) + assert [job.id for job in acquired_jobs] == [jobs[2].id] diff --git a/tests/test_eventbrokers.py b/tests/test_eventbrokers.py index 7ec90ab..ffa4be7 100644 --- a/tests/test_eventbrokers.py +++ b/tests/test_eventbrokers.py @@ -1,9 +1,10 @@ from __future__ import annotations +from collections.abc import AsyncGenerator, Generator from concurrent.futures import Future from datetime import datetime, timezone from queue import Empty, Queue -from typing import Callable +from typing import Any import pytest from _pytest.fixtures import SubRequest @@ -73,8 +74,10 @@ async def asyncpg_broker(serializer: Serializer) -> AsyncEventBroker: ), ] ) -def broker(request: SubRequest) -> Callable[[], EventBroker]: - return request.param +def broker(request: SubRequest) -> Generator[EventBroker, Any, None]: + request.param.start() + yield request.param + request.param.stop() @pytest.fixture( @@ -87,23 +90,24 @@ def broker(request: SubRequest) -> Callable[[], EventBroker]: ), ] ) -def async_broker(request: SubRequest) -> Callable[[], AsyncEventBroker]: - return request.param +async def async_broker(request: SubRequest) -> AsyncGenerator[AsyncEventBroker, Any]: + await request.param.start() + yield request.param + await request.param.stop() class TestEventBroker: def test_publish_subscribe(self, broker: EventBroker) -> None: queue: Queue[Event] = Queue() - with broker: - broker.subscribe(queue.put_nowait) - broker.subscribe(queue.put_nowait) - event = ScheduleAdded( - schedule_id="schedule1", - next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), - ) - broker.publish(event) - event1 = queue.get(timeout=3) - event2 = queue.get(timeout=1) + broker.subscribe(queue.put_nowait) + broker.subscribe(queue.put_nowait) + event = ScheduleAdded( + schedule_id="schedule1", + next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), + ) + broker.publish(event) + event1 = queue.get(timeout=3) + event2 = queue.get(timeout=1) assert event1 == event2 assert isinstance(event1, ScheduleAdded) @@ -115,43 +119,39 @@ class TestEventBroker: def test_subscribe_one_shot(self, broker: EventBroker) -> None: queue: Queue[Event] = Queue() - with broker: - broker.subscribe(queue.put_nowait, one_shot=True) - event = ScheduleAdded( - schedule_id="schedule1", - next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), - ) - broker.publish(event) - event = ScheduleAdded( - schedule_id="schedule2", - next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc), - ) - broker.publish(event) - received_event = queue.get(timeout=3) - with pytest.raises(Empty): - queue.get(timeout=0.1) + broker.subscribe(queue.put_nowait, one_shot=True) + event = ScheduleAdded( + schedule_id="schedule1", + next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), + ) + broker.publish(event) + event = ScheduleAdded( + schedule_id="schedule2", + next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc), + ) + broker.publish(event) + received_event = queue.get(timeout=3) + with pytest.raises(Empty): + queue.get(timeout=0.1) assert isinstance(received_event, ScheduleAdded) assert received_event.schedule_id == "schedule1" def test_unsubscribe(self, broker: EventBroker, caplog) -> None: queue: Queue[Event] = Queue() - with broker: - subscription = broker.subscribe(queue.put_nowait) - broker.publish(Event()) - queue.get(timeout=3) + subscription = broker.subscribe(queue.put_nowait) + broker.publish(Event()) + queue.get(timeout=3) - subscription.unsubscribe() - broker.publish(Event()) - with pytest.raises(Empty): - queue.get(timeout=0.1) + subscription.unsubscribe() + broker.publish(Event()) + with pytest.raises(Empty): + queue.get(timeout=0.1) def test_publish_no_subscribers( self, broker: EventBroker, caplog: LogCaptureFixture ) -> None: - with broker: - broker.publish(Event()) - + broker.publish(Event()) assert not caplog.text def test_publish_exception( @@ -162,33 +162,31 @@ class TestEventBroker: timestamp = datetime.now(timezone.utc) event_future: Future[Event] = Future() - with broker: - broker.subscribe(bad_subscriber) - broker.subscribe(event_future.set_result) - broker.publish(Event(timestamp=timestamp)) + broker.subscribe(bad_subscriber) + broker.subscribe(event_future.set_result) + broker.publish(Event(timestamp=timestamp)) - event = event_future.result(3) - assert isinstance(event, Event) - assert event.timestamp == timestamp - assert "Error delivering Event" in caplog.text + event = event_future.result(3) + assert isinstance(event, Event) + assert event.timestamp == timestamp + assert "Error delivering Event" in caplog.text @pytest.mark.anyio class TestAsyncEventBroker: async def test_publish_subscribe(self, async_broker: AsyncEventBroker) -> None: send, receive = create_memory_object_stream(2) - async with async_broker: - async_broker.subscribe(send.send) - async_broker.subscribe(send.send_nowait) - event = ScheduleAdded( - schedule_id="schedule1", - next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), - ) - await async_broker.publish(event) - - with fail_after(3): - event1 = await receive.receive() - event2 = await receive.receive() + async_broker.subscribe(send.send) + async_broker.subscribe(send.send_nowait) + event = ScheduleAdded( + schedule_id="schedule1", + next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), + ) + await async_broker.publish(event) + + with fail_after(3): + event1 = await receive.receive() + event2 = await receive.receive() assert event1 == event2 assert isinstance(event1, ScheduleAdded) @@ -200,47 +198,43 @@ class TestAsyncEventBroker: async def test_subscribe_one_shot(self, async_broker: AsyncEventBroker) -> None: send, receive = create_memory_object_stream(2) - async with async_broker: - async_broker.subscribe(send.send, one_shot=True) - event = ScheduleAdded( - schedule_id="schedule1", - next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), - ) - await async_broker.publish(event) - event = ScheduleAdded( - schedule_id="schedule2", - next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc), - ) - await async_broker.publish(event) - - with fail_after(3): - received_event = await receive.receive() - - with pytest.raises(TimeoutError), fail_after(0.1): - await receive.receive() + async_broker.subscribe(send.send, one_shot=True) + event = ScheduleAdded( + schedule_id="schedule1", + next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc), + ) + await async_broker.publish(event) + event = ScheduleAdded( + schedule_id="schedule2", + next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc), + ) + await async_broker.publish(event) + + with fail_after(3): + received_event = await receive.receive() + + with pytest.raises(TimeoutError), fail_after(0.1): + await receive.receive() assert isinstance(received_event, ScheduleAdded) assert received_event.schedule_id == "schedule1" async def test_unsubscribe(self, async_broker: AsyncEventBroker) -> None: send, receive = create_memory_object_stream() - async with async_broker: - subscription = async_broker.subscribe(send.send) - await async_broker.publish(Event()) - with fail_after(3): - await receive.receive() + subscription = async_broker.subscribe(send.send) + await async_broker.publish(Event()) + with fail_after(3): + await receive.receive() - subscription.unsubscribe() - await async_broker.publish(Event()) - with pytest.raises(TimeoutError), fail_after(0.1): - await receive.receive() + subscription.unsubscribe() + await async_broker.publish(Event()) + with pytest.raises(TimeoutError), fail_after(0.1): + await receive.receive() async def test_publish_no_subscribers( self, async_broker: AsyncEventBroker, caplog: LogCaptureFixture ) -> None: - async with async_broker: - await async_broker.publish(Event()) - + await async_broker.publish(Event()) assert not caplog.text async def test_publish_exception( @@ -251,11 +245,10 @@ class TestAsyncEventBroker: timestamp = datetime.now(timezone.utc) send, receive = create_memory_object_stream() - async with async_broker: - async_broker.subscribe(bad_subscriber) - async_broker.subscribe(send.send) - await async_broker.publish(Event(timestamp=timestamp)) + async_broker.subscribe(bad_subscriber) + async_broker.subscribe(send.send) + await async_broker.publish(Event(timestamp=timestamp)) - received_event = await receive.receive() - assert received_event.timestamp == timestamp - assert "Error delivering Event" in caplog.text + received_event = await receive.receive() + assert received_event.timestamp == timestamp + assert "Error delivering Event" in caplog.text diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py index b8c64a2..e33a91e 100644 --- a/tests/test_schedulers.py +++ b/tests/test_schedulers.py @@ -18,7 +18,6 @@ from apscheduler.events import ( JobAdded, ScheduleAdded, ScheduleRemoved, - SchedulerStarted, SchedulerStopped, TaskAdded, ) @@ -28,6 +27,7 @@ from apscheduler.schedulers.sync import Scheduler from apscheduler.structures import Job, Task from apscheduler.triggers.date import DateTrigger from apscheduler.triggers.interval import IntervalTrigger +from apscheduler.workers.async_ import AsyncWorker if sys.version_info >= (3, 9): from zoneinfo import ZoneInfo @@ -57,23 +57,18 @@ class TestAsyncScheduler: async def test_schedule_job(self) -> None: def listener(received_event: Event) -> None: received_events.append(received_event) - if len(received_events) == 5: + if isinstance(received_event, ScheduleRemoved): event.set() received_events: list[Event] = [] event = anyio.Event() - scheduler = AsyncScheduler(start_worker=False) - scheduler.events.subscribe(listener) trigger = DateTrigger(datetime.now(timezone.utc)) - async with scheduler: + async with AsyncScheduler(start_worker=False) as scheduler: + scheduler.event_broker.subscribe(listener) await scheduler.add_schedule(dummy_async_job, trigger, id="foo") with fail_after(3): await event.wait() - # The scheduler was first started - received_event = received_events.pop(0) - assert isinstance(received_event, SchedulerStarted) - # Then the task was added received_event = received_events.pop(0) assert isinstance(received_event, TaskAdded) @@ -129,7 +124,7 @@ class TestAsyncScheduler: async with AsyncScheduler(start_worker=False) as scheduler: trigger = IntervalTrigger(seconds=3, start_time=orig_start_time) job_added_event = anyio.Event() - scheduler.events.subscribe(job_added_listener, {JobAdded}) + scheduler.event_broker.subscribe(job_added_listener, {JobAdded}) schedule_id = await scheduler.add_schedule( dummy_async_job, trigger, max_jitter=max_jitter ) @@ -142,7 +137,8 @@ class TestAsyncScheduler: fake_uniform.assert_called_once_with(0, expected_upper_bound) - # Check that the job was created with the proper amount of jitter in its scheduled time + # Check that the job was created with the proper amount of jitter in its + # scheduled time jobs = await scheduler.data_store.get_jobs({job_id}) assert jobs[0].jitter == timedelta(seconds=jitter) assert jobs[0].scheduled_fire_time == orig_start_time + timedelta( @@ -188,7 +184,7 @@ class TestAsyncScheduler: async def test_contextvars(self) -> None: def check_contextvars() -> None: assert current_scheduler.get() is scheduler - assert current_worker.get() is scheduler.worker + assert isinstance(current_worker.get(), AsyncWorker) info = job_info.get() assert info.task_id == "task_id" assert info.schedule_id == "foo" @@ -223,23 +219,18 @@ class TestSyncScheduler: def test_schedule_job(self): def listener(received_event: Event) -> None: received_events.append(received_event) - if len(received_events) == 5: + if isinstance(received_event, ScheduleRemoved): event.set() received_events: list[Event] = [] event = threading.Event() - scheduler = Scheduler(start_worker=False) - scheduler.events.subscribe(listener) trigger = DateTrigger(datetime.now(timezone.utc)) - with scheduler: + with Scheduler(start_worker=False) as scheduler: + scheduler.event_broker.subscribe(listener) scheduler.add_schedule(dummy_sync_job, trigger, id="foo") event.wait(3) - # The scheduler was first started - received_event = received_events.pop(0) - assert isinstance(received_event, SchedulerStarted) - - # Then the task was added + # First, a task was added received_event = received_events.pop(0) assert isinstance(received_event, TaskAdded) assert received_event.task_id == "test_schedulers:dummy_sync_job" @@ -264,9 +255,6 @@ class TestSyncScheduler: received_event = received_events.pop(0) assert isinstance(received_event, SchedulerStopped) - # There should be no more events on the list - assert not received_events - @pytest.mark.parametrize( "max_jitter, expected_upper_bound", [pytest.param(2, 2, id="within"), pytest.param(4, 2.999999, id="exceed")], @@ -293,7 +281,7 @@ class TestSyncScheduler: with Scheduler(start_worker=False) as scheduler: trigger = IntervalTrigger(seconds=3, start_time=orig_start_time) job_added_event = threading.Event() - scheduler.events.subscribe(job_added_listener, {JobAdded}) + scheduler.event_broker.subscribe(job_added_listener, {JobAdded}) schedule_id = scheduler.add_schedule( dummy_async_job, trigger, max_jitter=max_jitter ) @@ -350,7 +338,7 @@ class TestSyncScheduler: def test_contextvars(self) -> None: def check_contextvars() -> None: assert current_scheduler.get() is scheduler - assert current_worker.get() is scheduler.worker + assert current_worker.get() is not None info = job_info.get() assert info.task_id == "task_id" assert info.schedule_id == "foo" diff --git a/tests/test_workers.py b/tests/test_workers.py index 6e5568f..e1c7421 100644 --- a/tests/test_workers.py +++ b/tests/test_workers.py @@ -17,7 +17,6 @@ from apscheduler.events import ( JobAdded, JobReleased, TaskAdded, - WorkerStarted, WorkerStopped, ) from apscheduler.structures import Task @@ -55,26 +54,20 @@ class TestAsyncWorker: ) -> None: def listener(received_event: Event): received_events.append(received_event) - if len(received_events) == 5: + if isinstance(received_event, JobReleased): event.set() received_events: list[Event] = [] event = anyio.Event() - data_store = MemoryDataStore() - worker = AsyncWorker(data_store) - worker.events.subscribe(listener) - async with worker: + async with AsyncWorker(MemoryDataStore()) as worker: + worker.event_broker.subscribe(listener) await worker.data_store.add_task(Task(id="task_id", func=target_func)) job = Job(task_id="task_id", args=(1, 2), kwargs={"x": "foo", "fail": fail}) await worker.data_store.add_job(job) with fail_after(3): await event.wait() - # The worker was first started - received_event = received_events.pop(0) - assert isinstance(received_event, WorkerStarted) - - # Then the task was added + # First, a task was added received_event = received_events.pop(0) assert isinstance(received_event, TaskAdded) assert received_event.task_id == "task_id" @@ -112,16 +105,14 @@ class TestAsyncWorker: async def test_run_deadline_missed(self) -> None: def listener(received_event: Event): received_events.append(received_event) - if len(received_events) == 5: + if isinstance(received_event, JobReleased): event.set() scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc) received_events: list[Event] = [] event = anyio.Event() - data_store = MemoryDataStore() - worker = AsyncWorker(data_store) - worker.events.subscribe(listener) - async with worker: + async with AsyncWorker(MemoryDataStore()) as worker: + worker.event_broker.subscribe(listener) await worker.data_store.add_task(Task(id="task_id", func=fail_func)) job = Job( task_id="task_id", @@ -133,11 +124,7 @@ class TestAsyncWorker: with fail_after(3): await event.wait() - # The worker was first started - received_event = received_events.pop(0) - assert isinstance(received_event, WorkerStarted) - - # Then the task was added + # First, a task was added received_event = received_events.pop(0) assert isinstance(received_event, TaskAdded) assert received_event.task_id == "task_id" @@ -175,25 +162,19 @@ class TestSyncWorker: def test_run_job_nonscheduled(self, fail: bool) -> None: def listener(received_event: Event): received_events.append(received_event) - if len(received_events) == 5: + if isinstance(received_event, JobReleased): event.set() received_events: list[Event] = [] event = threading.Event() - data_store = MemoryDataStore() - worker = Worker(data_store) - worker.events.subscribe(listener) - with worker: + with Worker(MemoryDataStore()) as worker: + worker.event_broker.subscribe(listener) worker.data_store.add_task(Task(id="task_id", func=sync_func)) job = Job(task_id="task_id", args=(1, 2), kwargs={"x": "foo", "fail": fail}) worker.data_store.add_job(job) - event.wait(5) - - # The worker was first started - received_event = received_events.pop(0) - assert isinstance(received_event, WorkerStarted) + event.wait(3) - # Then the task was added + # First, a task was added received_event = received_events.pop(0) assert isinstance(received_event, TaskAdded) assert received_event.task_id == "task_id" @@ -229,18 +210,16 @@ class TestSyncWorker: assert not received_events def test_run_deadline_missed(self) -> None: - def listener(worker_event: Event): - received_events.append(worker_event) - if len(received_events) == 5: + def listener(received_event: Event): + received_events.append(received_event) + if isinstance(received_event, JobReleased): event.set() scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc) received_events: list[Event] = [] event = threading.Event() - data_store = MemoryDataStore() - worker = Worker(data_store) - worker.events.subscribe(listener) - with worker: + with Worker(MemoryDataStore()) as worker: + worker.event_broker.subscribe(listener) worker.data_store.add_task(Task(id="task_id", func=fail_func)) job = Job( task_id="task_id", @@ -251,11 +230,7 @@ class TestSyncWorker: worker.data_store.add_job(job) event.wait(3) - # The worker was first started - received_event = received_events.pop(0) - assert isinstance(received_event, WorkerStarted) - - # Then the task was added + # First, a task was added received_event = received_events.pop(0) assert isinstance(received_event, TaskAdded) assert received_event.task_id == "task_id" |