diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2022-06-09 12:40:55 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2022-07-19 00:48:51 +0300 |
commit | f215c1ab45959095f6b499eb7b26356c5937ee8b (patch) | |
tree | 552687a0ed6e799b3da96eec5cd3fbb14d19f1b5 /src | |
parent | e3158fdf59a7c92a9449a566a2b746a4024e582f (diff) | |
download | apscheduler-f215c1ab45959095f6b499eb7b26356c5937ee8b.tar.gz |
Added support for starting the sync scheduler (and worker) without the context manager
Diffstat (limited to 'src')
-rw-r--r-- | src/apscheduler/abc.py | 43 | ||||
-rw-r--r-- | src/apscheduler/converters.py | 11 | ||||
-rw-r--r-- | src/apscheduler/datastores/async_adapter.py | 42 | ||||
-rw-r--r-- | src/apscheduler/datastores/async_sqlalchemy.py | 24 | ||||
-rw-r--r-- | src/apscheduler/datastores/base.py | 37 | ||||
-rw-r--r-- | src/apscheduler/datastores/memory.py | 20 | ||||
-rw-r--r-- | src/apscheduler/datastores/mongodb.py | 29 | ||||
-rw-r--r-- | src/apscheduler/datastores/sqlalchemy.py | 24 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/async_adapter.py | 68 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/async_local.py | 16 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/asyncpg.py | 46 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/base.py | 16 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/local.py | 16 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/mqtt.py | 16 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/redis.py | 13 | ||||
-rw-r--r-- | src/apscheduler/schedulers/async_.py | 338 | ||||
-rw-r--r-- | src/apscheduler/schedulers/sync.py | 395 | ||||
-rw-r--r-- | src/apscheduler/util.py | 53 | ||||
-rw-r--r-- | src/apscheduler/workers/async_.py | 171 | ||||
-rw-r--r-- | src/apscheduler/workers/sync.py | 202 |
20 files changed, 827 insertions, 753 deletions
diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py index 00bb281..c270bc6 100644 --- a/src/apscheduler/abc.py +++ b/src/apscheduler/abc.py @@ -108,15 +108,16 @@ class EventSource(metaclass=ABCMeta): class EventBroker(EventSource): """ - Interface for objects that can be used to publish notifications to interested subscribers. - - Can be used as a context manager. + Interface for objects that can be used to publish notifications to interested + subscribers. """ - def __enter__(self): - return self + @abstractmethod + def start(self) -> None: + pass - def __exit__(self, exc_type, exc_val, exc_tb): + @abstractmethod + def stop(self, *, force: bool = False) -> None: pass @abstractmethod @@ -129,16 +130,14 @@ class EventBroker(EventSource): class AsyncEventBroker(EventSource): - """ - Asynchronous version of :class:`EventBroker`. - - Can be used as an asynchronous context manager. - """ + """Asynchronous version of :class:`EventBroker`.""" - async def __aenter__(self): - return self + @abstractmethod + async def start(self) -> None: + pass - async def __aexit__(self, exc_type, exc_val, exc_tb): + @abstractmethod + async def stop(self, *, force: bool = False) -> None: pass @abstractmethod @@ -151,10 +150,12 @@ class AsyncEventBroker(EventSource): class DataStore: - def __enter__(self): - return self + @abstractmethod + def start(self, event_broker: EventBroker) -> None: + pass - def __exit__(self, exc_type, exc_val, exc_tb): + @abstractmethod + def stop(self, *, force: bool = False) -> None: pass @property @@ -309,10 +310,12 @@ class DataStore: class AsyncDataStore: - async def __aenter__(self): - return self + @abstractmethod + async def start(self, event_broker: AsyncEventBroker) -> None: + pass - async def __aexit__(self, exc_type, exc_val, exc_tb): + @abstractmethod + async def stop(self, *, force: bool = False) -> None: pass @property diff --git a/src/apscheduler/converters.py b/src/apscheduler/converters.py index 7502327..3518d44 100644 --- a/src/apscheduler/converters.py +++ b/src/apscheduler/converters.py @@ -48,6 +48,17 @@ def as_enum(enum_class: type[TEnum]) -> Callable[[TEnum | str], TEnum]: return converter +def as_async_eventbroker( + value: abc.EventBroker | abc.AsyncEventBroker, +) -> abc.AsyncEventBroker: + if isinstance(value, abc.EventBroker): + from apscheduler.eventbrokers.async_adapter import AsyncEventBrokerAdapter + + return AsyncEventBrokerAdapter(value) + + return value + + def as_async_datastore(value: abc.DataStore | abc.AsyncDataStore) -> abc.AsyncDataStore: if isinstance(value, abc.DataStore): from apscheduler.datastores.async_adapter import AsyncDataStoreAdapter diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py index 3b2342e..99accd4 100644 --- a/src/apscheduler/datastores/async_adapter.py +++ b/src/apscheduler/datastores/async_adapter.py @@ -1,8 +1,6 @@ from __future__ import annotations -from contextlib import AsyncExitStack from datetime import datetime -from functools import partial from typing import Iterable from uuid import UUID @@ -10,43 +8,35 @@ import attrs from anyio import to_thread from anyio.from_thread import BlockingPortal -from ..abc import AsyncDataStore, AsyncEventBroker, DataStore, EventSource +from ..abc import AsyncEventBroker, DataStore from ..enums import ConflictPolicy -from ..eventbrokers.async_adapter import AsyncEventBrokerAdapter +from ..eventbrokers.async_adapter import AsyncEventBrokerAdapter, SyncEventBrokerAdapter from ..structures import Job, JobResult, Schedule, Task -from ..util import reentrant +from .base import BaseAsyncDataStore -@reentrant @attrs.define(eq=False) -class AsyncDataStoreAdapter(AsyncDataStore): +class AsyncDataStoreAdapter(BaseAsyncDataStore): original: DataStore _portal: BlockingPortal = attrs.field(init=False) - _events: AsyncEventBroker = attrs.field(init=False) - _exit_stack: AsyncExitStack = attrs.field(init=False) - @property - def events(self) -> EventSource: - return self._events - - async def __aenter__(self) -> AsyncDataStoreAdapter: - self._exit_stack = AsyncExitStack() + async def start(self, event_broker: AsyncEventBroker) -> None: + await super().start(event_broker) self._portal = BlockingPortal() - await self._exit_stack.enter_async_context(self._portal) - - self._events = AsyncEventBrokerAdapter(self.original.events, self._portal) - await self._exit_stack.enter_async_context(self._events) + await self._portal.__aenter__() - await to_thread.run_sync(self.original.__enter__) - self._exit_stack.push_async_exit( - partial(to_thread.run_sync, self.original.__exit__) - ) + if isinstance(event_broker, AsyncEventBrokerAdapter): + sync_event_broker = event_broker.original + else: + sync_event_broker = SyncEventBrokerAdapter(event_broker, self._portal) - return self + await to_thread.run_sync(lambda: self.original.start(sync_event_broker)) - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + async def stop(self, *, force: bool = False) -> None: + await to_thread.run_sync(lambda: self.original.stop(force=force)) + await self._portal.__aexit__(None, None, None) + await super().stop(force=force) async def add_task(self, task: Task) -> None: await to_thread.run_sync(self.original.add_task, task) diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py index aa56c6f..edf9b3e 100644 --- a/src/apscheduler/datastores/async_sqlalchemy.py +++ b/src/apscheduler/datastores/async_sqlalchemy.py @@ -17,9 +17,8 @@ from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.sql.ddl import DropTable from sqlalchemy.sql.elements import BindParameter -from ..abc import AsyncDataStore, AsyncEventBroker, EventSource, Job, Schedule +from ..abc import AsyncEventBroker, Job, Schedule from ..enums import ConflictPolicy -from ..eventbrokers.async_local import LocalAsyncEventBroker from ..events import ( DataStoreEvent, JobAcquired, @@ -36,17 +35,14 @@ from ..events import ( from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError from ..marshalling import callable_to_ref from ..structures import JobResult, Task -from ..util import reentrant +from .base import BaseAsyncDataStore from .sqlalchemy import _BaseSQLAlchemyDataStore -@reentrant @attrs.define(eq=False) -class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): +class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseAsyncDataStore): engine: AsyncEngine - _events: AsyncEventBroker = attrs.field(factory=LocalAsyncEventBroker) - @classmethod def from_url(cls, url: str | URL, **options) -> AsyncSQLAlchemyDataStore: engine = create_async_engine(url, future=True) @@ -63,7 +59,9 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): reraise=True, ) - async def __aenter__(self): + async def start(self, event_broker: AsyncEventBroker) -> None: + await super().start(event_broker) + asynclib = sniffio.current_async_library() or "(unknown)" if asynclib != "asyncio": raise RuntimeError( @@ -92,16 +90,6 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): f"only version 1 is supported by this version of APScheduler" ) - await self._events.__aenter__() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._events.__aexit__(exc_type, exc_val, exc_tb) - - @property - def events(self) -> EventSource: - return self._events - async def _deserialize_schedules(self, result: Result) -> list[Schedule]: schedules: list[Schedule] = [] for row in result: diff --git a/src/apscheduler/datastores/base.py b/src/apscheduler/datastores/base.py new file mode 100644 index 0000000..c05d28c --- /dev/null +++ b/src/apscheduler/datastores/base.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from apscheduler.abc import ( + AsyncDataStore, + AsyncEventBroker, + DataStore, + EventBroker, + EventSource, +) + + +class BaseDataStore(DataStore): + _events: EventBroker + + def start(self, event_broker: EventBroker) -> None: + self._events = event_broker + + def stop(self, *, force: bool = False) -> None: + del self._events + + @property + def events(self) -> EventSource: + return self._events + + +class BaseAsyncDataStore(AsyncDataStore): + _events: AsyncEventBroker + + async def start(self, event_broker: AsyncEventBroker) -> None: + self._events = event_broker + + async def stop(self, *, force: bool = False) -> None: + del self._events + + @property + def events(self) -> EventSource: + return self._events diff --git a/src/apscheduler/datastores/memory.py b/src/apscheduler/datastores/memory.py index c2d7420..516d141 100644 --- a/src/apscheduler/datastores/memory.py +++ b/src/apscheduler/datastores/memory.py @@ -9,9 +9,8 @@ from uuid import UUID import attrs -from ..abc import DataStore, EventBroker, EventSource, Job, Schedule +from ..abc import Job, Schedule from ..enums import ConflictPolicy -from ..eventbrokers.local import LocalEventBroker from ..events import ( JobAcquired, JobAdded, @@ -25,7 +24,7 @@ from ..events import ( ) from ..exceptions import ConflictingIdError, TaskLookupError from ..structures import JobResult, Task -from ..util import reentrant +from .base import BaseDataStore max_datetime = datetime(MAXYEAR, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc) @@ -83,11 +82,9 @@ class JobState: return hash(self.job.id) -@reentrant @attrs.define(eq=False) -class MemoryDataStore(DataStore): +class MemoryDataStore(BaseDataStore): lock_expiration_delay: float = 30 - _events: EventBroker = attrs.Factory(LocalEventBroker) _tasks: dict[str, TaskState] = attrs.Factory(dict) _schedules: list[ScheduleState] = attrs.Factory(list) _schedules_by_id: dict[str, ScheduleState] = attrs.Factory(dict) @@ -111,17 +108,6 @@ class MemoryDataStore(DataStore): right_index = bisect_left(self._jobs, state) return self._jobs.index(state, left_index, right_index + 1) - def __enter__(self): - self._events.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._events.__exit__(exc_type, exc_val, exc_tb) - - @property - def events(self) -> EventSource: - return self._events - def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]: return [ state.schedule diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py index 2669d03..59a345a 100644 --- a/src/apscheduler/datastores/mongodb.py +++ b/src/apscheduler/datastores/mongodb.py @@ -2,7 +2,6 @@ from __future__ import annotations import operator from collections import defaultdict -from contextlib import ExitStack from datetime import datetime, timedelta, timezone from logging import Logger, getLogger from typing import Any, Callable, ClassVar, Iterable @@ -18,7 +17,7 @@ from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne from pymongo.collection import Collection from pymongo.errors import ConnectionFailure, DuplicateKeyError -from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer +from ..abc import EventBroker, Job, Schedule, Serializer from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome from ..eventbrokers.local import LocalEventBroker from ..events import ( @@ -41,7 +40,7 @@ from ..exceptions import ( ) from ..serializers.pickle import PickleSerializer from ..structures import JobResult, RetrySettings, Task -from ..util import reentrant +from .base import BaseDataStore class CustomEncoder(TypeEncoder): @@ -62,14 +61,14 @@ def ensure_uuid_presentation(client: MongoClient) -> None: pass -@reentrant @attrs.define(eq=False) -class MongoDBDataStore(DataStore): +class MongoDBDataStore(BaseDataStore): client: MongoClient = attrs.field(validator=instance_of(MongoClient)) serializer: Serializer = attrs.field(factory=PickleSerializer, kw_only=True) + event_broker = attrs.field(factory=LocalEventBroker, kw_only=True) database: str = attrs.field(default="apscheduler", kw_only=True) lock_expiration_delay: float = attrs.field(default=30, kw_only=True) - retry_settings: RetrySettings = attrs.field(default=RetrySettings()) + retry_settings: RetrySettings = attrs.field(default=RetrySettings(), kw_only=True) start_from_scratch: bool = attrs.field(default=False, kw_only=True) _task_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Task)] @@ -79,8 +78,6 @@ class MongoDBDataStore(DataStore): _job_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Job)] _logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__)) - _exit_stack: ExitStack = attrs.field(init=False, factory=ExitStack) - _events: EventBroker = attrs.field(init=False, factory=LocalEventBroker) _local_tasks: dict[str, Task] = attrs.field(init=False, factory=dict) def __attrs_post_init__(self) -> None: @@ -108,10 +105,6 @@ class MongoDBDataStore(DataStore): client = MongoClient(uri) return cls(client, **options) - @property - def events(self) -> EventSource: - return self._events - def _retry(self) -> tenacity.Retrying: return tenacity.Retrying( stop=self.retry_settings.stop, @@ -128,7 +121,9 @@ class MongoDBDataStore(DataStore): retry_state.outcome.exception(), ) - def __enter__(self): + def start(self, event_broker: EventBroker) -> None: + super().start(event_broker) + server_info = self.client.server_info() if server_info["versionArray"] < [4, 0]: raise RuntimeError( @@ -136,9 +131,6 @@ class MongoDBDataStore(DataStore): f"{server_info['version']}" ) - self._exit_stack.__enter__() - self._exit_stack.enter_context(self._events) - for attempt in self._retry(): with attempt, self.client.start_session() as session: if self.start_from_scratch: @@ -153,11 +145,6 @@ class MongoDBDataStore(DataStore): self._jobs.create_index("tags", session=session) self._jobs_results.create_index("finished_at", session=session) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._exit_stack.__exit__(exc_type, exc_val, exc_tb) - def add_task(self, task: Task) -> None: for attempt in self._retry(): with attempt: diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py index cca223e..9ee03fa 100644 --- a/src/apscheduler/datastores/sqlalchemy.py +++ b/src/apscheduler/datastores/sqlalchemy.py @@ -31,9 +31,8 @@ from sqlalchemy.future import Engine, create_engine from sqlalchemy.sql.ddl import DropTable from sqlalchemy.sql.elements import BindParameter, literal -from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer +from ..abc import EventBroker, Job, Schedule, Serializer from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome -from ..eventbrokers.local import LocalEventBroker from ..events import ( Event, JobAcquired, @@ -52,7 +51,7 @@ from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError from ..marshalling import callable_to_ref from ..serializers.pickle import PickleSerializer from ..structures import JobResult, RetrySettings, Task -from ..util import reentrant +from .base import BaseDataStore class EmulatedUUID(TypeDecorator): @@ -219,13 +218,10 @@ class _BaseSQLAlchemyDataStore: return jobs -@reentrant @attrs.define(eq=False) -class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): +class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore): engine: Engine - _events: EventBroker = attrs.field(init=False, factory=LocalEventBroker) - @classmethod def from_url(cls, url: str | URL, **options) -> SQLAlchemyDataStore: engine = create_engine(url) @@ -240,7 +236,9 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): reraise=True, ) - def __enter__(self): + def start(self, event_broker: EventBroker) -> None: + super().start(event_broker) + for attempt in self._retry(): with attempt, self.engine.begin() as conn: if self.start_from_scratch: @@ -259,16 +257,6 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): f"only version 1 is supported by this version of APScheduler" ) - self._events.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._events.__exit__(exc_type, exc_val, exc_tb) - - @property - def events(self) -> EventSource: - return self._events - def add_task(self, task: Task) -> None: insert = self.t_tasks.insert().values( id=task.id, diff --git a/src/apscheduler/eventbrokers/async_adapter.py b/src/apscheduler/eventbrokers/async_adapter.py index ebceaad..4c1cb6c 100644 --- a/src/apscheduler/eventbrokers/async_adapter.py +++ b/src/apscheduler/eventbrokers/async_adapter.py @@ -1,39 +1,65 @@ from __future__ import annotations -from functools import partial +from typing import Any, Callable, Iterable import attrs from anyio import to_thread from anyio.from_thread import BlockingPortal -from apscheduler.abc import EventBroker -from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker +from apscheduler.abc import AsyncEventBroker, EventBroker, Subscription from apscheduler.events import Event -from apscheduler.util import reentrant -@reentrant @attrs.define(eq=False) -class AsyncEventBrokerAdapter(LocalAsyncEventBroker): +class AsyncEventBrokerAdapter(AsyncEventBroker): original: EventBroker - portal: BlockingPortal - - async def __aenter__(self): - await super().__aenter__() - if not self.portal: - self.portal = BlockingPortal() - self._exit_stack.enter_async_context(self.portal) + async def start(self) -> None: + await to_thread.run_sync(self.original.start) - await to_thread.run_sync(self.original.__enter__) - self._exit_stack.push_async_exit( - partial(to_thread.run_sync, self.original.__exit__) - ) + async def stop(self, *, force: bool = False) -> None: + await to_thread.run_sync(lambda: self.original.stop(force=force)) - # Relay events from the original broker to this one - self._exit_stack.enter_context( - self.original.subscribe(partial(self.portal.call, self.publish_local)) - ) + async def publish_local(self, event: Event) -> None: + await to_thread.run_sync(self.original.publish_local, event) async def publish(self, event: Event) -> None: await to_thread.run_sync(self.original.publish, event) + + def subscribe( + self, + callback: Callable[[Event], Any], + event_types: Iterable[type[Event]] | None = None, + *, + one_shot: bool = False, + ) -> Subscription: + return self.original.subscribe(callback, event_types, one_shot=one_shot) + + +@attrs.define(eq=False) +class SyncEventBrokerAdapter(EventBroker): + original: AsyncEventBroker + portal: BlockingPortal + + def start(self) -> None: + pass + + def stop(self, *, force: bool = False) -> None: + pass + + def publish_local(self, event: Event) -> None: + self.portal.call(self.original.publish_local, event) + + def publish(self, event: Event) -> None: + self.portal.call(self.original.publish, event) + + def subscribe( + self, + callback: Callable[[Event], Any], + event_types: Iterable[type[Event]] | None = None, + *, + one_shot: bool = False, + ) -> Subscription: + return self.portal.call( + lambda: self.original.subscribe(callback, event_types, one_shot=one_shot) + ) diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py index 98031bf..dba37f1 100644 --- a/src/apscheduler/eventbrokers/async_local.py +++ b/src/apscheduler/eventbrokers/async_local.py @@ -1,7 +1,6 @@ from __future__ import annotations from asyncio import iscoroutine -from contextlib import AsyncExitStack from typing import Any, Callable import attrs @@ -10,26 +9,19 @@ from anyio.abc import TaskGroup from ..abc import AsyncEventBroker from ..events import Event -from ..util import reentrant from .base import BaseEventBroker -@reentrant @attrs.define(eq=False) class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker): _task_group: TaskGroup = attrs.field(init=False) - _exit_stack: AsyncExitStack = attrs.field(init=False) - - async def __aenter__(self) -> LocalAsyncEventBroker: - self._exit_stack = AsyncExitStack() + async def start(self) -> None: self._task_group = create_task_group() - await self._exit_stack.enter_async_context(self._task_group) - - return self + await self._task_group.__aenter__() - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + async def stop(self, *, force: bool = False) -> None: + await self._task_group.__aexit__(None, None, None) del self._task_group async def publish(self, event: Event) -> None: diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py index d97c229..c80033e 100644 --- a/src/apscheduler/eventbrokers/asyncpg.py +++ b/src/apscheduler/eventbrokers/asyncpg.py @@ -4,13 +4,12 @@ from contextlib import asynccontextmanager from typing import TYPE_CHECKING, AsyncContextManager, AsyncGenerator, Callable import attrs -from anyio import TASK_STATUS_IGNORED, sleep +from anyio import TASK_STATUS_IGNORED, CancelScope, sleep from asyncpg import Connection from asyncpg.pool import Pool from ..events import Event from ..exceptions import SerializationError -from ..util import reentrant from .async_local import LocalAsyncEventBroker from .base import DistributedEventBrokerMixin @@ -18,12 +17,12 @@ if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncEngine -@reentrant @attrs.define(eq=False) class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): connection_factory: Callable[[], AsyncContextManager[Connection]] channel: str = attrs.field(kw_only=True, default="apscheduler") max_idle_time: float = attrs.field(kw_only=True, default=30) + _listen_cancel_scope: CancelScope = attrs.field(init=False) @classmethod def from_asyncpg_pool(cls, pool: Pool) -> AsyncpgEventBroker: @@ -47,11 +46,15 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): return cls(connection_factory) - async def __aenter__(self) -> AsyncpgEventBroker: - await super().__aenter__() - await self._task_group.start(self._listen_notifications) - self._exit_stack.callback(self._task_group.cancel_scope.cancel) - return self + async def start(self) -> None: + await super().start() + self._listen_cancel_scope = await self._task_group.start( + self._listen_notifications + ) + + async def stop(self, *, force: bool = False) -> None: + self._listen_cancel_scope.cancel() + await super().stop(force=force) async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None: def callback(connection, pid, channel: str, payload: str) -> None: @@ -60,19 +63,20 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): self._task_group.start_soon(self.publish_local, event) task_started_sent = False - while True: - async with self.connection_factory() as conn: - await conn.add_listener(self.channel, callback) - if not task_started_sent: - task_status.started() - task_started_sent = True - - try: - while True: - await sleep(self.max_idle_time) - await conn.execute("SELECT 1") - finally: - await conn.remove_listener(self.channel, callback) + with CancelScope() as cancel_scope: + while True: + async with self.connection_factory() as conn: + await conn.add_listener(self.channel, callback) + if not task_started_sent: + task_status.started(cancel_scope) + task_started_sent = True + + try: + while True: + await sleep(self.max_idle_time) + await conn.execute("SELECT 1") + finally: + await conn.remove_listener(self.channel, callback) async def publish(self, event: Event) -> None: notification = self.generate_notification_str(event) diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py index 38c83c7..3aeee77 100644 --- a/src/apscheduler/eventbrokers/base.py +++ b/src/apscheduler/eventbrokers/base.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Iterable import attrs from .. import events -from ..abc import EventBroker, Serializer, Subscription +from ..abc import EventSource, Serializer, Subscription from ..events import Event from ..exceptions import DeserializationError @@ -25,7 +25,7 @@ class LocalSubscription(Subscription): @attrs.define(eq=False) -class BaseEventBroker(EventBroker): +class BaseEventBroker(EventSource): _logger: Logger = attrs.field(init=False) _subscriptions: dict[object, LocalSubscription] = attrs.field( init=False, factory=dict @@ -70,7 +70,7 @@ class DistributedEventBrokerMixin: self._logger.exception( "Failed to deserialize an event of type %s", event_type, - serialized=serialized, + extra={"serialized": serialized}, ) return None @@ -80,7 +80,7 @@ class DistributedEventBrokerMixin: self._logger.error( "Receive notification for a nonexistent event type: %s", event_type, - serialized=serialized, + extra={"serialized": serialized}, ) return None @@ -94,7 +94,9 @@ class DistributedEventBrokerMixin: try: event_type_bytes, serialized = payload.split(b" ", 1) except ValueError: - self._logger.error("Received malformatted notification", payload=payload) + self._logger.error( + "Received malformatted notification", extra={"payload": payload} + ) return None event_type = event_type_bytes.decode("ascii", errors="replace") @@ -104,7 +106,9 @@ class DistributedEventBrokerMixin: try: event_type, b64_serialized = payload.split(" ", 1) except ValueError: - self._logger.error("Received malformatted notification", payload=payload) + self._logger.error( + "Received malformatted notification", extra={"payload": payload} + ) return None return self._reconstitute_event(event_type, b64decode(b64_serialized)) diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py index cf19526..a870c45 100644 --- a/src/apscheduler/eventbrokers/local.py +++ b/src/apscheduler/eventbrokers/local.py @@ -8,26 +8,22 @@ from typing import Any, Callable, Iterable import attrs -from ..abc import Subscription +from ..abc import EventBroker, Subscription from ..events import Event -from ..util import reentrant from .base import BaseEventBroker -@reentrant @attrs.define(eq=False) -class LocalEventBroker(BaseEventBroker): +class LocalEventBroker(EventBroker, BaseEventBroker): _executor: ThreadPoolExecutor = attrs.field(init=False) _exit_stack: ExitStack = attrs.field(init=False) _subscriptions_lock: Lock = attrs.field(init=False, factory=Lock) - def __enter__(self): - self._exit_stack = ExitStack() - self._executor = self._exit_stack.enter_context(ThreadPoolExecutor(1)) - return self + def start(self) -> None: + self._executor = ThreadPoolExecutor(1) - def __exit__(self, exc_type, exc_val, exc_tb): - self._exit_stack.__exit__(exc_type, exc_val, exc_tb) + def stop(self, *, force: bool = False) -> None: + self._executor.shutdown(wait=not force) del self._executor def subscribe( diff --git a/src/apscheduler/eventbrokers/mqtt.py b/src/apscheduler/eventbrokers/mqtt.py index 8b0a512..9b0f859 100644 --- a/src/apscheduler/eventbrokers/mqtt.py +++ b/src/apscheduler/eventbrokers/mqtt.py @@ -11,12 +11,10 @@ from paho.mqtt.reasoncodes import ReasonCodes from ..abc import Serializer from ..events import Event from ..serializers.json import JSONSerializer -from ..util import reentrant from .base import DistributedEventBrokerMixin from .local import LocalEventBroker -@reentrant @attrs.define(eq=False) class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): client: Client @@ -28,8 +26,8 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): publish_qos: int = attrs.field(kw_only=True, default=0) _ready_future: Future[None] = attrs.field(init=False) - def __enter__(self): - super().__enter__() + def start(self) -> None: + super().start() self._ready_future = Future() self.client.enable_logger(self._logger) self.client.on_connect = self._on_connect @@ -38,11 +36,11 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): self.client.connect(self.host, self.port) self.client.loop_start() self._ready_future.result(10) - self._exit_stack.push( - lambda exc_type, *_: self.client.loop_stop(force=bool(exc_type)) - ) - self._exit_stack.callback(self.client.disconnect) - return self + + def stop(self, *, force: bool = False) -> None: + self.client.disconnect() + self.client.loop_stop(force=force) + super().stop() def _on_connect( self, diff --git a/src/apscheduler/eventbrokers/redis.py b/src/apscheduler/eventbrokers/redis.py index 0f14e3e..a41ef37 100644 --- a/src/apscheduler/eventbrokers/redis.py +++ b/src/apscheduler/eventbrokers/redis.py @@ -9,12 +9,10 @@ from redis import ConnectionPool, Redis from ..abc import Serializer from ..events import Event from ..serializers.json import JSONSerializer -from ..util import reentrant from .base import DistributedEventBrokerMixin from .local import LocalEventBroker -@reentrant @attrs.define(eq=False) class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): client: Redis @@ -23,6 +21,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): message_poll_interval: float = attrs.field(kw_only=True, default=0.05) _stopped: bool = attrs.field(init=False, default=True) _ready_future: Future[None] = attrs.field(init=False) + _thread: Thread = attrs.field(init=False) @classmethod def from_url(cls, url: str, **kwargs) -> RedisEventBroker: @@ -30,7 +29,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): client = Redis(connection_pool=pool) return cls(client) - def __enter__(self): + def start(self) -> None: self._stopped = False self._ready_future = Future() self._thread = Thread( @@ -38,14 +37,14 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): ) self._thread.start() self._ready_future.result(10) - return super().__enter__() + super().start() - def __exit__(self, exc_type, exc_val, exc_tb): + def stop(self, *, force: bool = False) -> None: self._stopped = True - if not exc_type: + if not force: self._thread.join(5) - super().__exit__(exc_type, exc_val, exc_tb) + super().stop(force=force) def _listen_messages(self) -> None: while not self._stopped: diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py index 72c5994..889596d 100644 --- a/src/apscheduler/schedulers/async_.py +++ b/src/apscheduler/schedulers/async_.py @@ -17,10 +17,11 @@ from anyio import ( get_cancelled_exc_class, move_on_after, ) +from anyio.abc import TaskGroup -from ..abc import AsyncDataStore, EventSource, Job, Schedule, Trigger +from ..abc import AsyncDataStore, AsyncEventBroker, Job, Schedule, Subscription, Trigger from ..context import current_scheduler -from ..converters import as_async_datastore +from ..converters import as_async_datastore, as_async_eventbroker from ..datastores.memory import MemoryDataStore from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState from ..eventbrokers.async_local import LocalAsyncEventBroker @@ -48,6 +49,9 @@ class AsyncScheduler: data_store: AsyncDataStore = attrs.field( converter=as_async_datastore, factory=MemoryDataStore ) + event_broker: AsyncEventBroker = attrs.field( + converter=as_async_eventbroker, factory=LocalAsyncEventBroker + ) identity: str = attrs.field(kw_only=True, default=None) start_worker: bool = attrs.field(kw_only=True, default=True) logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) @@ -55,65 +59,24 @@ class AsyncScheduler: _state: RunState = attrs.field(init=False, default=RunState.stopped) _wakeup_event: anyio.Event = attrs.field(init=False) _wakeup_deadline: datetime | None = attrs.field(init=False, default=None) - _worker: AsyncWorker | None = attrs.field(init=False, default=None) - _events: LocalAsyncEventBroker = attrs.field( - init=False, factory=LocalAsyncEventBroker - ) - _exit_stack: AsyncExitStack = attrs.field(init=False) + _task_group: TaskGroup | None = attrs.field(init=False, default=None) + _schedule_added_subscription: Subscription = attrs.field(init=False) def __attrs_post_init__(self) -> None: if not self.identity: self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" - @property - def events(self) -> EventSource: - return self._events - - @property - def worker(self) -> AsyncWorker | None: - return self._worker - async def __aenter__(self): - self._state = RunState.starting - self._wakeup_event = anyio.Event() - self._exit_stack = AsyncExitStack() - await self._exit_stack.__aenter__() - await self._exit_stack.enter_async_context(self._events) - - # Initialize the data store and start relaying events to the scheduler's event broker - await self._exit_stack.enter_async_context(self.data_store) - self._exit_stack.enter_context( - self.data_store.events.subscribe(self._events.publish) - ) - - # Wake up the scheduler if the data store emits a significant schedule event - self._exit_stack.enter_context( - self.data_store.events.subscribe( - self._schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated} - ) - ) - - # Start the built-in worker, if configured to do so - if self.start_worker: - token = current_scheduler.set(self) - try: - self._worker = AsyncWorker(self.data_store) - await self._exit_stack.enter_async_context(self._worker) - finally: - current_scheduler.reset(token) - - # Start the worker and return when it has signalled readiness or raised an exception - task_group = create_task_group() - await self._exit_stack.enter_async_context(task_group) - await task_group.start(self.run) + self._task_group = create_task_group() + await self._task_group.__aenter__() + await self._task_group.start(self.run_until_stopped) return self async def __aexit__(self, exc_type, exc_val, exc_tb): self._state = RunState.stopping self._wakeup_event.set() - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) - self._state = RunState.stopped - del self._wakeup_event + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + self._task_group = None def _schedule_added_or_modified(self, event: Event) -> None: event_ = cast("ScheduleAdded | ScheduleUpdated", event) @@ -281,130 +244,169 @@ class AsyncScheduler: else: raise RuntimeError(f"Unknown job outcome: {result.outcome}") - async def run(self, *, task_status=TASK_STATUS_IGNORED) -> None: - if self._state is not RunState.starting: + async def run_until_stopped(self, *, task_status=TASK_STATUS_IGNORED) -> None: + if self._state is not RunState.stopped: raise RuntimeError( - f"This function cannot be called while the scheduler is in the " - f"{self._state} state" + f'Cannot start the scheduler when it is in the "{self._state}" ' + f"state" ) - # Signal that the scheduler has started - self._state = RunState.started - task_status.started() - await self._events.publish(SchedulerStarted()) - - exception: BaseException | None = None - try: - while self._state is RunState.started: - schedules = await self.data_store.acquire_schedules(self.identity, 100) - now = datetime.now(timezone.utc) - for schedule in schedules: - # Calculate a next fire time for the schedule, if possible - fire_times = [schedule.next_fire_time] - calculate_next = schedule.trigger.next - while True: - try: - fire_time = calculate_next() - except Exception: - self.logger.exception( - "Error computing next fire time for schedule %r of task %r – " - "removing schedule", - schedule.id, - schedule.task_id, - ) - break - - # Stop if the calculated fire time is in the future - if fire_time is None or fire_time > now: - schedule.next_fire_time = fire_time - break - - # Only keep all the fire times if coalesce policy = "all" - if schedule.coalesce is CoalescePolicy.all: - fire_times.append(fire_time) - elif schedule.coalesce is CoalescePolicy.latest: - fire_times[0] = fire_time - - # Add one or more jobs to the job queue - max_jitter = ( - schedule.max_jitter.total_seconds() - if schedule.max_jitter - else 0 + self._state = RunState.starting + async with AsyncExitStack() as exit_stack: + self._wakeup_event = anyio.Event() + + # Initialize the event broker + await self.event_broker.start() + exit_stack.push_async_exit( + lambda *exc_info: self.event_broker.stop(force=exc_info[0] is not None) + ) + + # Initialize the data store + await self.data_store.start(self.event_broker) + exit_stack.push_async_exit( + lambda *exc_info: self.data_store.stop(force=exc_info[0] is not None) + ) + + # Wake up the scheduler if the data store emits a significant schedule event + self._schedule_added_subscription = self.event_broker.subscribe( + self._schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated} + ) + + # Start the built-in worker, if configured to do so + if self.start_worker: + token = current_scheduler.set(self) + exit_stack.callback(current_scheduler.reset, token) + worker = AsyncWorker( + self.data_store, self.event_broker, is_internal=True + ) + await exit_stack.enter_async_context(worker) + + # Signal that the scheduler has started + self._state = RunState.started + task_status.started() + await self.event_broker.publish_local(SchedulerStarted()) + + exception: BaseException | None = None + try: + while self._state is RunState.started: + schedules = await self.data_store.acquire_schedules( + self.identity, 100 ) - for i, fire_time in enumerate(fire_times): - # Calculate a jitter if max_jitter > 0 - jitter = _zero_timedelta - if max_jitter: - if i + 1 < len(fire_times): - next_fire_time = fire_times[i + 1] - else: - next_fire_time = schedule.next_fire_time - - if next_fire_time is not None: - # Jitter must never be so high that it would cause a fire time to - # equal or exceed the next fire time - jitter_s = min( - [ - max_jitter, - ( - next_fire_time - - fire_time - - _microsecond_delta - ).total_seconds(), - ] + now = datetime.now(timezone.utc) + for schedule in schedules: + # Calculate a next fire time for the schedule, if possible + fire_times = [schedule.next_fire_time] + calculate_next = schedule.trigger.next + while True: + try: + fire_time = calculate_next() + except Exception: + self.logger.exception( + "Error computing next fire time for schedule %r of " + "task %r – removing schedule", + schedule.id, + schedule.task_id, ) - jitter = timedelta(seconds=random.uniform(0, jitter_s)) - fire_time += jitter - - schedule.last_fire_time = fire_time - job = Job( - task_id=schedule.task_id, - args=schedule.args, - kwargs=schedule.kwargs, - schedule_id=schedule.id, - scheduled_fire_time=fire_time, - jitter=jitter, - start_deadline=schedule.next_deadline, - tags=schedule.tags, + break + + # Stop if the calculated fire time is in the future + if fire_time is None or fire_time > now: + schedule.next_fire_time = fire_time + break + + # Only keep all the fire times if coalesce policy = "all" + if schedule.coalesce is CoalescePolicy.all: + fire_times.append(fire_time) + elif schedule.coalesce is CoalescePolicy.latest: + fire_times[0] = fire_time + + # Add one or more jobs to the job queue + max_jitter = ( + schedule.max_jitter.total_seconds() + if schedule.max_jitter + else 0 ) - await self.data_store.add_job(job) - - # Update the schedules (and release the scheduler's claim on them) - await self.data_store.release_schedules(self.identity, schedules) + for i, fire_time in enumerate(fire_times): + # Calculate a jitter if max_jitter > 0 + jitter = _zero_timedelta + if max_jitter: + if i + 1 < len(fire_times): + next_fire_time = fire_times[i + 1] + else: + next_fire_time = schedule.next_fire_time + + if next_fire_time is not None: + # Jitter must never be so high that it would cause a + # fire time to equal or exceed the next fire time + jitter_s = min( + [ + max_jitter, + ( + next_fire_time + - fire_time + - _microsecond_delta + ).total_seconds(), + ] + ) + jitter = timedelta( + seconds=random.uniform(0, jitter_s) + ) + fire_time += jitter + + schedule.last_fire_time = fire_time + job = Job( + task_id=schedule.task_id, + args=schedule.args, + kwargs=schedule.kwargs, + schedule_id=schedule.id, + scheduled_fire_time=fire_time, + jitter=jitter, + start_deadline=schedule.next_deadline, + tags=schedule.tags, + ) + await self.data_store.add_job(job) + + # Update the schedules (and release the scheduler's claim on them) + await self.data_store.release_schedules(self.identity, schedules) + + # If we received fewer schedules than the maximum amount, sleep + # until the next schedule is due or the scheduler is explicitly + # woken up + wait_time = None + if len(schedules) < 100: + self._wakeup_deadline = ( + await self.data_store.get_next_schedule_run_time() + ) + if self._wakeup_deadline: + wait_time = ( + self._wakeup_deadline - datetime.now(timezone.utc) + ).total_seconds() + self.logger.debug( + "Sleeping %.3f seconds until the next fire time (%s)", + wait_time, + self._wakeup_deadline, + ) + else: + self.logger.debug("Waiting for any due schedules to appear") - # If we received fewer schedules than the maximum amount, sleep until the next - # schedule is due or the scheduler is explicitly woken up - wait_time = None - if len(schedules) < 100: - self._wakeup_deadline = ( - await self.data_store.get_next_schedule_run_time() - ) - if self._wakeup_deadline: - wait_time = ( - self._wakeup_deadline - datetime.now(timezone.utc) - ).total_seconds() + with move_on_after(wait_time): + await self._wakeup_event.wait() + self._wakeup_event = anyio.Event() + else: self.logger.debug( - "Sleeping %.3f seconds until the next fire time (%s)", - wait_time, - self._wakeup_deadline, + "Processing more schedules on the next iteration" ) - else: - self.logger.debug("Waiting for any due schedules to appear") - - with move_on_after(wait_time): - await self._wakeup_event.wait() - self._wakeup_event = anyio.Event() - else: - self.logger.debug("Processing more schedules on the next iteration") - except get_cancelled_exc_class(): - pass - except BaseException as exc: - self.logger.exception("Scheduler crashed") - exception = exc - raise - else: - self.logger.info("Scheduler stopped") - finally: - self._state = RunState.stopped - with move_on_after(3, shield=True): - await self._events.publish(SchedulerStopped(exception=exception)) + except get_cancelled_exc_class(): + pass + except BaseException as exc: + self.logger.exception("Scheduler crashed") + exception = exc + raise + else: + self.logger.info("Scheduler stopped") + finally: + self._state = RunState.stopped + with move_on_after(3, shield=True): + await self.event_broker.publish_local( + SchedulerStopped(exception=exception) + ) diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py index 4c7095e..5ea0c8a 100644 --- a/src/apscheduler/schedulers/sync.py +++ b/src/apscheduler/schedulers/sync.py @@ -5,16 +5,17 @@ import os import platform import random import threading -from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait +from concurrent.futures import Future from contextlib import ExitStack from datetime import datetime, timedelta, timezone from logging import Logger, getLogger +from types import TracebackType from typing import Any, Callable, Iterable, Mapping, cast from uuid import UUID, uuid4 import attrs -from ..abc import DataStore, EventSource, Trigger +from ..abc import DataStore, EventBroker, Trigger from ..context import current_scheduler from ..datastores.memory import MemoryDataStore from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState @@ -41,39 +42,60 @@ class Scheduler: """A synchronous scheduler implementation.""" data_store: DataStore = attrs.field(factory=MemoryDataStore) + event_broker: EventBroker = attrs.field(factory=LocalEventBroker) identity: str = attrs.field(kw_only=True, default=None) start_worker: bool = attrs.field(kw_only=True, default=True) logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) _state: RunState = attrs.field(init=False, default=RunState.stopped) - _wakeup_event: threading.Event = attrs.field(init=False) + _thread: threading.Thread | None = attrs.field(init=False, default=None) + _wakeup_event: threading.Event = attrs.field(init=False, factory=threading.Event) _wakeup_deadline: datetime | None = attrs.field(init=False, default=None) - _worker: Worker | None = attrs.field(init=False, default=None) - _events: LocalEventBroker = attrs.field(init=False, factory=LocalEventBroker) + _services_initialized: bool = attrs.field(init=False, default=False) + _exit_stack: ExitStack = attrs.field(init=False, factory=ExitStack) + _lock: threading.RLock = attrs.field(init=False, factory=threading.RLock) def __attrs_post_init__(self) -> None: if not self.identity: self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" @property - def events(self) -> EventSource: - return self._events - - @property def state(self) -> RunState: return self._state - @property - def worker(self) -> Worker | None: - return self._worker - def __enter__(self) -> Scheduler: self.start_in_background() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: self.stop() + def _ensure_services_ready(self) -> None: + """Ensure that the data store and event broker have been initialized.""" + with self._lock: + if not self._services_initialized: + self._services_initialized = True + self.event_broker.start() + self._exit_stack.push( + lambda *exc_info: self.event_broker.stop( + force=exc_info[0] is not None + ) + ) + + # Initialize the data store + self.data_store.start(self.event_broker) + self._exit_stack.push( + lambda *exc_info: self.data_store.stop( + force=exc_info[0] is not None + ) + ) + atexit.register(self.stop) + def _schedule_added_or_modified(self, event: Event) -> None: event_ = cast("ScheduleAdded | ScheduleUpdated", event) if not self._wakeup_deadline or ( @@ -98,6 +120,7 @@ class Scheduler: tags: Iterable[str] | None = None, conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing, ) -> str: + self._ensure_services_ready() id = id or str(uuid4()) args = tuple(args or ()) kwargs = dict(kwargs or {}) @@ -133,10 +156,12 @@ class Scheduler: return schedule.id def get_schedule(self, id: str) -> Schedule: + self._ensure_services_ready() schedules = self.data_store.get_schedules({id}) return schedules[0] def remove_schedule(self, schedule_id: str) -> None: + self._ensure_services_ready() self.data_store.remove_schedules({schedule_id}) def add_job( @@ -157,6 +182,7 @@ class Scheduler: :return: the ID of the newly created job """ + self._ensure_services_ready() if callable(func_or_task_id): task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id) self.data_store.add_task(task) @@ -177,18 +203,19 @@ class Scheduler: Retrieve the result of a job. :param job_id: the ID of the job - :param wait: if ``True``, wait until the job has ended (one way or another), ``False`` to - raise an exception if the result is not yet available + :param wait: if ``True``, wait until the job has ended (one way or another), + ``False`` to raise an exception if the result is not yet available :raises JobLookupError: if the job does not exist in the data store """ + self._ensure_services_ready() wait_event = threading.Event() def listener(event: JobReleased) -> None: if event.job_id == job_id: wait_event.set() - with self.data_store.events.subscribe(listener, {JobReleased}): + with self.data_store.events.subscribe(listener, {JobReleased}, one_shot=True): result = self.data_store.get_job_result(job_id) if result: return result @@ -215,6 +242,7 @@ class Scheduler: :returns: the return value of the target function """ + self._ensure_services_ready() job_complete_event = threading.Event() def listener(event: JobReleased) -> None: @@ -239,181 +267,200 @@ class Scheduler: raise RuntimeError(f"Unknown job outcome: {result.outcome}") def start_in_background(self) -> None: - if self._state is not RunState.stopped: - raise RuntimeError("Scheduler already running") + start_future: Future[None] = Future() + self._thread = threading.Thread( + target=self._run, args=[start_future], daemon=True + ) + self._thread.start() + try: + start_future.result() + except BaseException: + self._thread = None + raise - thread = threading.Thread(target=self.run_until_stopped, daemon=True) - thread.start() atexit.register(self.stop) def stop(self) -> None: - if self._state is RunState.started: - f = Future() - with self._events.subscribe(f.set_result, one_shot=True): + atexit.unregister(self.stop) + with self._lock: + if self._state is RunState.started: self._state = RunState.stopping self._wakeup_event.set() - f.result() - def run_until_stopped(self): - self._state = RunState.starting - self._wakeup_event = threading.Event() - with ExitStack() as exit_stack: - exit_stack.enter_context(self._events) + if self._thread and threading.current_thread() != self._thread: + self._thread.join() + self._thread = None - # Initialize the data store and start relaying events to the scheduler's event broker - exit_stack.enter_context(self.data_store) - exit_stack.enter_context( - self.data_store.events.subscribe(self._events.publish) - ) + self._exit_stack.close() - # Wake up the scheduler if the data store emits a significant schedule event - exit_stack.enter_context( - self.data_store.events.subscribe( - self._schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated} - ) - ) + def run_until_stopped(self) -> None: + self._run(None) - # Start the built-in worker, if configured to do so - if self.start_worker: - token = current_scheduler.set(self) - try: - self._worker = Worker(self.data_store) - exit_stack.enter_context(self._worker) - finally: - current_scheduler.reset(token) - - # Start the scheduler and return when it has signalled readiness or raised an exception - start_future: Future[Event] = Future() - with self._events.subscribe(start_future.set_result, one_shot=True): - executor = ThreadPoolExecutor(1) - exit_stack.push( - lambda exc_type, *args: executor.shutdown(wait=exc_type is None) - ) - run_future = executor.submit(self._run) - wait([start_future, run_future], return_when=FIRST_COMPLETED) - - self._state = RunState.stopped - del self._wakeup_event - - if run_future.done(): - run_future.result() + def _run(self, start_future: Future[None] | None) -> None: + with ExitStack() as exit_stack: + try: + with self._lock: + if self._state is not RunState.stopped: + raise RuntimeError( + f'Cannot start the scheduler when it is in the "{self._state}" ' + f"state" + ) - def _run(self) -> None: - assert self._state is RunState.starting + self._state = RunState.starting - # Signal that the scheduler has started - self._state = RunState.started - self._events.publish(SchedulerStarted()) + self._ensure_services_ready() - try: - while self._state is RunState.started: - schedules = self.data_store.acquire_schedules(self.identity, 100) - self.logger.debug( - "Processing %d schedules retrieved from the data store", - len(schedules), + # Wake up the scheduler if the data store emits a significant schedule event + exit_stack.enter_context( + self.data_store.events.subscribe( + self._schedule_added_or_modified, + {ScheduleAdded, ScheduleUpdated}, + ) ) - now = datetime.now(timezone.utc) - for schedule in schedules: - # Calculate a next fire time for the schedule, if possible - fire_times = [schedule.next_fire_time] - calculate_next = schedule.trigger.next - while True: - try: - fire_time = calculate_next() - except Exception: - self.logger.exception( - "Error computing next fire time for schedule %r of task %r – " - "removing schedule", - schedule.id, - schedule.task_id, - ) - break - - # Stop if the calculated fire time is in the future - if fire_time is None or fire_time > now: - schedule.next_fire_time = fire_time - break - - # Only keep all the fire times if coalesce policy = "all" - if schedule.coalesce is CoalescePolicy.all: - fire_times.append(fire_time) - elif schedule.coalesce is CoalescePolicy.latest: - fire_times[0] = fire_time - - # Add one or more jobs to the job queue - max_jitter = ( - schedule.max_jitter.total_seconds() - if schedule.max_jitter - else 0 + + # Start the built-in worker, if configured to do so + if self.start_worker: + token = current_scheduler.set(self) + exit_stack.callback(current_scheduler.reset, token) + worker = Worker( + self.data_store, self.event_broker, is_internal=True + ) + exit_stack.enter_context(worker) + + # Signal that the scheduler has started + self._state = RunState.started + self.event_broker.publish_local(SchedulerStarted()) + except BaseException as exc: + if start_future: + start_future.set_exception(exc) + return + else: + raise + else: + if start_future: + start_future.set_result(None) + + try: + while self._state is RunState.started: + schedules = self.data_store.acquire_schedules(self.identity, 100) + self.logger.debug( + "Processing %d schedules retrieved from the data store", + len(schedules), ) - for i, fire_time in enumerate(fire_times): - # Calculate a jitter if max_jitter > 0 - jitter = _zero_timedelta - if max_jitter: - if i + 1 < len(fire_times): - next_fire_time = fire_times[i + 1] - else: - next_fire_time = schedule.next_fire_time - - if next_fire_time is not None: - # Jitter must never be so high that it would cause a fire time to - # equal or exceed the next fire time - jitter_s = min( - [ - max_jitter, - ( - next_fire_time - - fire_time - - _microsecond_delta - ).total_seconds(), - ] + now = datetime.now(timezone.utc) + for schedule in schedules: + # Calculate a next fire time for the schedule, if possible + fire_times = [schedule.next_fire_time] + calculate_next = schedule.trigger.next + while True: + try: + fire_time = calculate_next() + except Exception: + self.logger.exception( + "Error computing next fire time for schedule %r of task %r – " + "removing schedule", + schedule.id, + schedule.task_id, ) - jitter = timedelta(seconds=random.uniform(0, jitter_s)) - fire_time += jitter - - schedule.last_fire_time = fire_time - job = Job( - task_id=schedule.task_id, - args=schedule.args, - kwargs=schedule.kwargs, - schedule_id=schedule.id, - scheduled_fire_time=fire_time, - jitter=jitter, - start_deadline=schedule.next_deadline, - tags=schedule.tags, + break + + # Stop if the calculated fire time is in the future + if fire_time is None or fire_time > now: + schedule.next_fire_time = fire_time + break + + # Only keep all the fire times if coalesce policy = "all" + if schedule.coalesce is CoalescePolicy.all: + fire_times.append(fire_time) + elif schedule.coalesce is CoalescePolicy.latest: + fire_times[0] = fire_time + + # Add one or more jobs to the job queue + max_jitter = ( + schedule.max_jitter.total_seconds() + if schedule.max_jitter + else 0 ) - self.data_store.add_job(job) - - # Update the schedules (and release the scheduler's claim on them) - self.data_store.release_schedules(self.identity, schedules) - - # If we received fewer schedules than the maximum amount, sleep until the next - # schedule is due or the scheduler is explicitly woken up - wait_time = None - if len(schedules) < 100: - self._wakeup_deadline = self.data_store.get_next_schedule_run_time() - if self._wakeup_deadline: - wait_time = ( - self._wakeup_deadline - datetime.now(timezone.utc) - ).total_seconds() - self.logger.debug( - "Sleeping %.3f seconds until the next fire time (%s)", - wait_time, - self._wakeup_deadline, + for i, fire_time in enumerate(fire_times): + # Calculate a jitter if max_jitter > 0 + jitter = _zero_timedelta + if max_jitter: + if i + 1 < len(fire_times): + next_fire_time = fire_times[i + 1] + else: + next_fire_time = schedule.next_fire_time + + if next_fire_time is not None: + # Jitter must never be so high that it would cause + # a fire time to equal or exceed the next fire time + jitter_s = min( + [ + max_jitter, + ( + next_fire_time + - fire_time + - _microsecond_delta + ).total_seconds(), + ] + ) + jitter = timedelta( + seconds=random.uniform(0, jitter_s) + ) + fire_time += jitter + + schedule.last_fire_time = fire_time + job = Job( + task_id=schedule.task_id, + args=schedule.args, + kwargs=schedule.kwargs, + schedule_id=schedule.id, + scheduled_fire_time=fire_time, + jitter=jitter, + start_deadline=schedule.next_deadline, + tags=schedule.tags, + ) + self.data_store.add_job(job) + + # Update the schedules (and release the scheduler's claim on them) + self.data_store.release_schedules(self.identity, schedules) + + # If we received fewer schedules than the maximum amount, sleep + # until the next schedule is due or the scheduler is explicitly + # woken up + wait_time = None + if len(schedules) < 100: + self._wakeup_deadline = ( + self.data_store.get_next_schedule_run_time() ) - else: - self.logger.debug("Waiting for any due schedules to appear") + if self._wakeup_deadline: + wait_time = ( + self._wakeup_deadline - datetime.now(timezone.utc) + ).total_seconds() + self.logger.debug( + "Sleeping %.3f seconds until the next fire time (%s)", + wait_time, + self._wakeup_deadline, + ) + else: + self.logger.debug("Waiting for any due schedules to appear") - if self._wakeup_event.wait(wait_time): - self._wakeup_event = threading.Event() + if self._wakeup_event.wait(wait_time): + self._wakeup_event = threading.Event() + else: + self.logger.debug( + "Processing more schedules on the next iteration" + ) + except BaseException as exc: + self._state = RunState.stopped + if isinstance(exc, Exception): + self.logger.exception("Scheduler crashed") else: - self.logger.debug("Processing more schedules on the next iteration") - except BaseException as exc: - self._state = RunState.stopped - self.logger.exception("Scheduler crashed") - self._events.publish(SchedulerStopped(exception=exc)) - raise + self.logger.info( + f"Scheduler stopped due to {exc.__class__.__name__}" + ) - self._state = RunState.stopped - self.logger.info("Scheduler stopped") - self._events.publish(SchedulerStopped()) + self.event_broker.publish_local(SchedulerStopped(exception=exc)) + else: + self._state = RunState.stopped + self.logger.info("Scheduler stopped") + self.event_broker.publish_local(SchedulerStopped()) diff --git a/src/apscheduler/util.py b/src/apscheduler/util.py index 7ff4858..8c69203 100644 --- a/src/apscheduler/util.py +++ b/src/apscheduler/util.py @@ -2,9 +2,8 @@ from __future__ import annotations import sys -from collections import defaultdict from datetime import datetime, tzinfo -from typing import Callable, TypeVar +from typing import TypeVar if sys.version_info >= (3, 9): from zoneinfo import ZoneInfo @@ -36,53 +35,3 @@ def timezone_repr(timezone: tzinfo) -> str: def absolute_datetime_diff(dateval1: datetime, dateval2: datetime) -> float: return dateval1.timestamp() - dateval2.timestamp() - - -def reentrant(cls: type[T]) -> type[T]: - """ - Modifies a class so that its ``__enter__`` / ``__exit__`` (or ``__aenter__`` / ``__aexit__``) - methods track the number of times it has been entered and exited and only actually invoke - the ``__enter__()`` method on the first entry and ``__exit__()`` on the last exit. - - """ - - def __enter__(self: T) -> T: - loans[self] += 1 - if loans[self] == 1: - previous_enter(self) - - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - assert loans[self] - loans[self] -= 1 - if loans[self] == 0: - return previous_exit(self, exc_type, exc_val, exc_tb) - - async def __aenter__(self: T) -> T: - loans[self] += 1 - if loans[self] == 1: - await previous_aenter(self) - - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - assert loans[self] - loans[self] -= 1 - if loans[self] == 0: - del loans[self] - return await previous_aexit(self, exc_type, exc_val, exc_tb) - - loans: dict[T, int] = defaultdict(lambda: 0) - previous_enter: Callable = getattr(cls, "__enter__", None) - previous_exit: Callable = getattr(cls, "__exit__", None) - previous_aenter: Callable = getattr(cls, "__aenter__", None) - previous_aexit: Callable = getattr(cls, "__aexit__", None) - if previous_enter and previous_exit: - cls.__enter__ = __enter__ - cls.__exit__ = __exit__ - elif previous_aenter and previous_aexit: - cls.__aenter__ = __aenter__ - cls.__aexit__ = __aexit__ - - return cls diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py index e4ad44d..65870e5 100644 --- a/src/apscheduler/workers/async_.py +++ b/src/apscheduler/workers/async_.py @@ -6,6 +6,7 @@ from contextlib import AsyncExitStack from datetime import datetime, timezone from inspect import isawaitable from logging import Logger, getLogger +from types import TracebackType from typing import Callable from uuid import UUID @@ -17,11 +18,11 @@ from anyio import ( get_cancelled_exc_class, move_on_after, ) -from anyio.abc import CancelScope +from anyio.abc import CancelScope, TaskGroup -from ..abc import AsyncDataStore, EventSource, Job +from ..abc import AsyncDataStore, AsyncEventBroker, Job from ..context import current_worker, job_info -from ..converters import as_async_datastore +from ..converters import as_async_datastore, as_async_eventbroker from ..enums import JobOutcome, RunState from ..eventbrokers.async_local import LocalAsyncEventBroker from ..events import JobAdded, WorkerStarted, WorkerStopped @@ -34,105 +35,129 @@ class AsyncWorker: """Runs jobs locally in a task group.""" data_store: AsyncDataStore = attrs.field(converter=as_async_datastore) + event_broker: AsyncEventBroker = attrs.field( + converter=as_async_eventbroker, factory=LocalAsyncEventBroker + ) max_concurrent_jobs: int = attrs.field( kw_only=True, validator=positive_integer, default=100 ) identity: str = attrs.field(kw_only=True, default=None) logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) + # True if a scheduler owns this worker + _is_internal: bool = attrs.field(kw_only=True, default=False) _state: RunState = attrs.field(init=False, default=RunState.stopped) - _wakeup_event: anyio.Event = attrs.field(init=False, factory=anyio.Event) + _wakeup_event: anyio.Event = attrs.field(init=False) + _task_group: TaskGroup = attrs.field(init=False) _acquired_jobs: set[Job] = attrs.field(init=False, factory=set) - _events: LocalAsyncEventBroker = attrs.field( - init=False, factory=LocalAsyncEventBroker - ) _running_jobs: set[UUID] = attrs.field(init=False, factory=set) - _exit_stack: AsyncExitStack = attrs.field(init=False) def __attrs_post_init__(self) -> None: if not self.identity: self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" @property - def events(self) -> EventSource: - return self._events - - @property def state(self) -> RunState: return self._state - async def __aenter__(self): - self._state = RunState.starting - self._wakeup_event = anyio.Event() - self._exit_stack = AsyncExitStack() - await self._exit_stack.__aenter__() - await self._exit_stack.enter_async_context(self._events) - - # Initialize the data store and start relaying events to the worker's event broker - await self._exit_stack.enter_async_context(self.data_store) - self._exit_stack.enter_context( - self.data_store.events.subscribe(self._events.publish) - ) - - # Wake up the worker if the data store emits a significant job event - self._exit_stack.enter_context( - self.data_store.events.subscribe( - lambda event: self._wakeup_event.set(), {JobAdded} - ) - ) - - # Start the actual worker - task_group = create_task_group() - await self._exit_stack.enter_async_context(task_group) - await task_group.start(self.run) + async def __aenter__(self) -> AsyncWorker: + self._task_group = create_task_group() + await self._task_group.__aenter__() + await self._task_group.start(self.run_until_stopped) return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, + exc_type: type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: self._state = RunState.stopping self._wakeup_event.set() - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + del self._task_group del self._wakeup_event - async def run(self, *, task_status=TASK_STATUS_IGNORED) -> None: - if self._state is not RunState.starting: + async def run_until_stopped(self, *, task_status=TASK_STATUS_IGNORED) -> None: + if self._state is not RunState.stopped: raise RuntimeError( - f"This function cannot be called while the worker is in the " - f"{self._state} state" + f'Cannot start the worker when it is in the "{self._state}" ' f"state" ) - # Set the current worker - token = current_worker.set(self) + self._state = RunState.starting + self._wakeup_event = anyio.Event() + async with AsyncExitStack() as exit_stack: + if not self._is_internal: + # Initialize the event broker + await self.event_broker.start() + exit_stack.push_async_exit( + lambda *exc_info: self.event_broker.stop( + force=exc_info[0] is not None + ) + ) - # Signal that the worker has started - self._state = RunState.started - task_status.started() - await self._events.publish(WorkerStarted()) + # Initialize the data store + await self.data_store.start(self.event_broker) + exit_stack.push_async_exit( + lambda *exc_info: self.data_store.stop( + force=exc_info[0] is not None + ) + ) - exception: BaseException | None = None - try: - async with create_task_group() as tg: - while self._state is RunState.started: - limit = self.max_concurrent_jobs - len(self._running_jobs) - jobs = await self.data_store.acquire_jobs(self.identity, limit) - for job in jobs: - task = await self.data_store.get_task(job.task_id) - self._running_jobs.add(job.id) - tg.start_soon(self._run_job, job, task.func) - - await self._wakeup_event.wait() - self._wakeup_event = anyio.Event() - except get_cancelled_exc_class(): - pass - except BaseException as exc: - self.logger.exception("Worker crashed") - exception = exc - else: - self.logger.info("Worker stopped") - finally: - current_worker.reset(token) - self._state = RunState.stopped - with move_on_after(3, shield=True): - await self._events.publish(WorkerStopped(exception=exception)) + # Set the current worker + token = current_worker.set(self) + exit_stack.callback(current_worker.reset, token) + + # Wake up the worker if the data store emits a significant job event + self.event_broker.subscribe( + lambda event: self._wakeup_event.set(), {JobAdded} + ) + + # Signal that the worker has started + self._state = RunState.started + task_status.started() + exception: BaseException | None = None + try: + await self.event_broker.publish_local(WorkerStarted()) + + async with create_task_group() as tg: + while self._state is RunState.started: + limit = self.max_concurrent_jobs - len(self._running_jobs) + jobs = await self.data_store.acquire_jobs(self.identity, limit) + for job in jobs: + task = await self.data_store.get_task(job.task_id) + self._running_jobs.add(job.id) + tg.start_soon(self._run_job, job, task.func) + + await self._wakeup_event.wait() + self._wakeup_event = anyio.Event() + except get_cancelled_exc_class(): + pass + except BaseException as exc: + self.logger.exception("Worker crashed") + exception = exc + else: + self.logger.info("Worker stopped") + finally: + self._state = RunState.stopped + with move_on_after(3, shield=True): + await self.event_broker.publish_local( + WorkerStopped(exception=exception) + ) + + async def stop(self, *, force: bool = False) -> None: + if self._state in (RunState.starting, RunState.started): + self._state = RunState.stopping + event = anyio.Event() + self.event_broker.subscribe( + lambda ev: event.set(), {WorkerStopped}, one_shot=True + ) + if force: + self._task_group.cancel_scope.cancel() + else: + self._wakeup_event.set() + + await event.wait() async def _run_job(self, job: Job, func: Callable) -> None: try: diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py index 2da0045..5b5e6a9 100644 --- a/src/apscheduler/workers/sync.py +++ b/src/apscheduler/workers/sync.py @@ -1,19 +1,21 @@ from __future__ import annotations +import atexit import os import platform import threading -from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait +from concurrent.futures import Future, ThreadPoolExecutor from contextlib import ExitStack from contextvars import copy_context from datetime import datetime, timezone from logging import Logger, getLogger +from types import TracebackType from typing import Callable from uuid import UUID import attrs -from ..abc import DataStore, EventSource +from ..abc import DataStore, EventBroker from ..context import current_worker, job_info from ..enums import JobOutcome, RunState from ..eventbrokers.local import LocalEventBroker @@ -27,111 +29,151 @@ class Worker: """Runs jobs locally in a thread pool.""" data_store: DataStore + event_broker: EventBroker = attrs.field(factory=LocalEventBroker) max_concurrent_jobs: int = attrs.field( kw_only=True, validator=positive_integer, default=20 ) identity: str = attrs.field(kw_only=True, default=None) logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) + # True if a scheduler owns this worker + _is_internal: bool = attrs.field(kw_only=True, default=False) _state: RunState = attrs.field(init=False, default=RunState.stopped) - _wakeup_event: threading.Event = attrs.field(init=False) + _thread: threading.Thread | None = attrs.field(init=False, default=None) + _wakeup_event: threading.Event = attrs.field(init=False, factory=threading.Event) + _executor: ThreadPoolExecutor = attrs.field(init=False) _acquired_jobs: set[Job] = attrs.field(init=False, factory=set) - _events: LocalEventBroker = attrs.field(init=False, factory=LocalEventBroker) _running_jobs: set[UUID] = attrs.field(init=False, factory=set) - _exit_stack: ExitStack = attrs.field(init=False) - _executor: ThreadPoolExecutor = attrs.field(init=False) def __attrs_post_init__(self) -> None: if not self.identity: self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" @property - def events(self) -> EventSource: - return self._events - - @property def state(self) -> RunState: return self._state def __enter__(self) -> Worker: - self._state = RunState.starting - self._wakeup_event = threading.Event() - self._exit_stack = ExitStack() - self._exit_stack.__enter__() - self._exit_stack.enter_context(self._events) - - # Initialize the data store and start relaying events to the worker's event broker - self._exit_stack.enter_context(self.data_store) - self._exit_stack.enter_context( - self.data_store.events.subscribe(self._events.publish) - ) + self.start_in_background() + return self - # Wake up the worker if the data store emits a significant job event - self._exit_stack.enter_context( - self.data_store.events.subscribe( - lambda event: self._wakeup_event.set(), {JobAdded} - ) - ) + def __exit__( + self, + exc_type: type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: + self.stop() - # Start the worker and return when it has signalled readiness or raised an exception + def start_in_background(self) -> None: start_future: Future[None] = Future() - with self._events.subscribe(start_future.set_result, one_shot=True): - self._executor = ThreadPoolExecutor(1) - run_future = self._executor.submit(copy_context().run, self.run) - wait([start_future, run_future], return_when=FIRST_COMPLETED) + self._thread = threading.Thread( + target=copy_context().run, args=[self._run, start_future], daemon=True + ) + self._thread.start() + try: + start_future.result() + except BaseException: + self._thread = None + raise - if run_future.done(): - run_future.result() + atexit.register(self.stop) - return self + def stop(self) -> None: + atexit.unregister(self.stop) + if self._state is RunState.started: + self._state = RunState.stopping + self._wakeup_event.set() - def __exit__(self, exc_type, exc_val, exc_tb): - self._state = RunState.stopping - self._wakeup_event.set() - self._executor.shutdown(wait=exc_type is None) - self._exit_stack.__exit__(exc_type, exc_val, exc_tb) - del self._wakeup_event - - def run(self) -> None: - if self._state is not RunState.starting: - raise RuntimeError( - f"This function cannot be called while the worker is in the " - f"{self._state} state" - ) - - # Set the current worker - token = current_worker.set(self) - - # Signal that the worker has started - self._state = RunState.started - self._events.publish(WorkerStarted()) - - executor = ThreadPoolExecutor(max_workers=self.max_concurrent_jobs) - exception: BaseException | None = None - try: - while self._state is RunState.started: - available_slots = self.max_concurrent_jobs - len(self._running_jobs) - if available_slots: - jobs = self.data_store.acquire_jobs(self.identity, available_slots) - for job in jobs: - task = self.data_store.get_task(job.task_id) - self._running_jobs.add(job.id) - executor.submit( - copy_context().run, self._run_job, job, task.func + if threading.current_thread() != self._thread: + self._thread.join() + self._thread = None + + def run_until_stopped(self) -> None: + self._run(None) + + def _run(self, start_future: Future[None] | None) -> None: + with ExitStack() as exit_stack: + try: + if self._state is not RunState.stopped: + raise RuntimeError( + f'Cannot start the worker when it is in the "{self._state}" ' + f"state" + ) + + if not self._is_internal: + # Initialize the event broker + self.event_broker.start() + exit_stack.push( + lambda *exc_info: self.event_broker.stop( + force=exc_info[0] is not None ) + ) - self._wakeup_event.wait() - self._wakeup_event = threading.Event() - except BaseException as exc: - self.logger.exception("Worker crashed") - exception = exc - else: - self.logger.info("Worker stopped") - finally: - current_worker.reset(token) - self._state = RunState.stopped - self._events.publish(WorkerStopped(exception=exception)) - executor.shutdown() + # Initialize the data store + self.data_store.start(self.event_broker) + exit_stack.push( + lambda *exc_info: self.data_store.stop( + force=exc_info[0] is not None + ) + ) + + # Set the current worker + token = current_worker.set(self) + exit_stack.callback(current_worker.reset, token) + + # Wake up the worker if the data store emits a significant job event + exit_stack.enter_context( + self.event_broker.subscribe( + lambda event: self._wakeup_event.set(), {JobAdded} + ) + ) + + # Initialize the thread pool + executor = ThreadPoolExecutor(max_workers=self.max_concurrent_jobs) + exit_stack.enter_context(executor) + + # Signal that the worker has started + self._state = RunState.started + self.event_broker.publish_local(WorkerStarted()) + except BaseException as exc: + if start_future: + start_future.set_exception(exc) + return + else: + raise + else: + if start_future: + start_future.set_result(None) + + try: + while self._state is RunState.started: + available_slots = self.max_concurrent_jobs - len(self._running_jobs) + if available_slots: + jobs = self.data_store.acquire_jobs( + self.identity, available_slots + ) + for job in jobs: + task = self.data_store.get_task(job.task_id) + self._running_jobs.add(job.id) + executor.submit( + copy_context().run, self._run_job, job, task.func + ) + + self._wakeup_event.wait() + self._wakeup_event = threading.Event() + except BaseException as exc: + self._state = RunState.stopped + if isinstance(exc, Exception): + self.logger.exception("Worker crashed") + else: + self.logger.info(f"Worker stopped due to {exc.__class__.__name__}") + + self.event_broker.publish_local(WorkerStopped(exception=exc)) + else: + self._state = RunState.stopped + self.logger.info("Worker stopped") + self.event_broker.publish_local(WorkerStopped()) def _run_job(self, job: Job, func: Callable) -> None: try: |