summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlex Grönholm <alex.gronholm@nextday.fi>2022-06-09 12:40:55 +0300
committerAlex Grönholm <alex.gronholm@nextday.fi>2022-07-19 00:48:51 +0300
commitf215c1ab45959095f6b499eb7b26356c5937ee8b (patch)
tree552687a0ed6e799b3da96eec5cd3fbb14d19f1b5 /src
parente3158fdf59a7c92a9449a566a2b746a4024e582f (diff)
downloadapscheduler-f215c1ab45959095f6b499eb7b26356c5937ee8b.tar.gz
Added support for starting the sync scheduler (and worker) without the context manager
Diffstat (limited to 'src')
-rw-r--r--src/apscheduler/abc.py43
-rw-r--r--src/apscheduler/converters.py11
-rw-r--r--src/apscheduler/datastores/async_adapter.py42
-rw-r--r--src/apscheduler/datastores/async_sqlalchemy.py24
-rw-r--r--src/apscheduler/datastores/base.py37
-rw-r--r--src/apscheduler/datastores/memory.py20
-rw-r--r--src/apscheduler/datastores/mongodb.py29
-rw-r--r--src/apscheduler/datastores/sqlalchemy.py24
-rw-r--r--src/apscheduler/eventbrokers/async_adapter.py68
-rw-r--r--src/apscheduler/eventbrokers/async_local.py16
-rw-r--r--src/apscheduler/eventbrokers/asyncpg.py46
-rw-r--r--src/apscheduler/eventbrokers/base.py16
-rw-r--r--src/apscheduler/eventbrokers/local.py16
-rw-r--r--src/apscheduler/eventbrokers/mqtt.py16
-rw-r--r--src/apscheduler/eventbrokers/redis.py13
-rw-r--r--src/apscheduler/schedulers/async_.py338
-rw-r--r--src/apscheduler/schedulers/sync.py395
-rw-r--r--src/apscheduler/util.py53
-rw-r--r--src/apscheduler/workers/async_.py171
-rw-r--r--src/apscheduler/workers/sync.py202
20 files changed, 827 insertions, 753 deletions
diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py
index 00bb281..c270bc6 100644
--- a/src/apscheduler/abc.py
+++ b/src/apscheduler/abc.py
@@ -108,15 +108,16 @@ class EventSource(metaclass=ABCMeta):
class EventBroker(EventSource):
"""
- Interface for objects that can be used to publish notifications to interested subscribers.
-
- Can be used as a context manager.
+ Interface for objects that can be used to publish notifications to interested
+ subscribers.
"""
- def __enter__(self):
- return self
+ @abstractmethod
+ def start(self) -> None:
+ pass
- def __exit__(self, exc_type, exc_val, exc_tb):
+ @abstractmethod
+ def stop(self, *, force: bool = False) -> None:
pass
@abstractmethod
@@ -129,16 +130,14 @@ class EventBroker(EventSource):
class AsyncEventBroker(EventSource):
- """
- Asynchronous version of :class:`EventBroker`.
-
- Can be used as an asynchronous context manager.
- """
+ """Asynchronous version of :class:`EventBroker`."""
- async def __aenter__(self):
- return self
+ @abstractmethod
+ async def start(self) -> None:
+ pass
- async def __aexit__(self, exc_type, exc_val, exc_tb):
+ @abstractmethod
+ async def stop(self, *, force: bool = False) -> None:
pass
@abstractmethod
@@ -151,10 +150,12 @@ class AsyncEventBroker(EventSource):
class DataStore:
- def __enter__(self):
- return self
+ @abstractmethod
+ def start(self, event_broker: EventBroker) -> None:
+ pass
- def __exit__(self, exc_type, exc_val, exc_tb):
+ @abstractmethod
+ def stop(self, *, force: bool = False) -> None:
pass
@property
@@ -309,10 +310,12 @@ class DataStore:
class AsyncDataStore:
- async def __aenter__(self):
- return self
+ @abstractmethod
+ async def start(self, event_broker: AsyncEventBroker) -> None:
+ pass
- async def __aexit__(self, exc_type, exc_val, exc_tb):
+ @abstractmethod
+ async def stop(self, *, force: bool = False) -> None:
pass
@property
diff --git a/src/apscheduler/converters.py b/src/apscheduler/converters.py
index 7502327..3518d44 100644
--- a/src/apscheduler/converters.py
+++ b/src/apscheduler/converters.py
@@ -48,6 +48,17 @@ def as_enum(enum_class: type[TEnum]) -> Callable[[TEnum | str], TEnum]:
return converter
+def as_async_eventbroker(
+ value: abc.EventBroker | abc.AsyncEventBroker,
+) -> abc.AsyncEventBroker:
+ if isinstance(value, abc.EventBroker):
+ from apscheduler.eventbrokers.async_adapter import AsyncEventBrokerAdapter
+
+ return AsyncEventBrokerAdapter(value)
+
+ return value
+
+
def as_async_datastore(value: abc.DataStore | abc.AsyncDataStore) -> abc.AsyncDataStore:
if isinstance(value, abc.DataStore):
from apscheduler.datastores.async_adapter import AsyncDataStoreAdapter
diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py
index 3b2342e..99accd4 100644
--- a/src/apscheduler/datastores/async_adapter.py
+++ b/src/apscheduler/datastores/async_adapter.py
@@ -1,8 +1,6 @@
from __future__ import annotations
-from contextlib import AsyncExitStack
from datetime import datetime
-from functools import partial
from typing import Iterable
from uuid import UUID
@@ -10,43 +8,35 @@ import attrs
from anyio import to_thread
from anyio.from_thread import BlockingPortal
-from ..abc import AsyncDataStore, AsyncEventBroker, DataStore, EventSource
+from ..abc import AsyncEventBroker, DataStore
from ..enums import ConflictPolicy
-from ..eventbrokers.async_adapter import AsyncEventBrokerAdapter
+from ..eventbrokers.async_adapter import AsyncEventBrokerAdapter, SyncEventBrokerAdapter
from ..structures import Job, JobResult, Schedule, Task
-from ..util import reentrant
+from .base import BaseAsyncDataStore
-@reentrant
@attrs.define(eq=False)
-class AsyncDataStoreAdapter(AsyncDataStore):
+class AsyncDataStoreAdapter(BaseAsyncDataStore):
original: DataStore
_portal: BlockingPortal = attrs.field(init=False)
- _events: AsyncEventBroker = attrs.field(init=False)
- _exit_stack: AsyncExitStack = attrs.field(init=False)
- @property
- def events(self) -> EventSource:
- return self._events
-
- async def __aenter__(self) -> AsyncDataStoreAdapter:
- self._exit_stack = AsyncExitStack()
+ async def start(self, event_broker: AsyncEventBroker) -> None:
+ await super().start(event_broker)
self._portal = BlockingPortal()
- await self._exit_stack.enter_async_context(self._portal)
-
- self._events = AsyncEventBrokerAdapter(self.original.events, self._portal)
- await self._exit_stack.enter_async_context(self._events)
+ await self._portal.__aenter__()
- await to_thread.run_sync(self.original.__enter__)
- self._exit_stack.push_async_exit(
- partial(to_thread.run_sync, self.original.__exit__)
- )
+ if isinstance(event_broker, AsyncEventBrokerAdapter):
+ sync_event_broker = event_broker.original
+ else:
+ sync_event_broker = SyncEventBrokerAdapter(event_broker, self._portal)
- return self
+ await to_thread.run_sync(lambda: self.original.start(sync_event_broker))
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
+ async def stop(self, *, force: bool = False) -> None:
+ await to_thread.run_sync(lambda: self.original.stop(force=force))
+ await self._portal.__aexit__(None, None, None)
+ await super().stop(force=force)
async def add_task(self, task: Task) -> None:
await to_thread.run_sync(self.original.add_task, task)
diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py
index aa56c6f..edf9b3e 100644
--- a/src/apscheduler/datastores/async_sqlalchemy.py
+++ b/src/apscheduler/datastores/async_sqlalchemy.py
@@ -17,9 +17,8 @@ from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.sql.ddl import DropTable
from sqlalchemy.sql.elements import BindParameter
-from ..abc import AsyncDataStore, AsyncEventBroker, EventSource, Job, Schedule
+from ..abc import AsyncEventBroker, Job, Schedule
from ..enums import ConflictPolicy
-from ..eventbrokers.async_local import LocalAsyncEventBroker
from ..events import (
DataStoreEvent,
JobAcquired,
@@ -36,17 +35,14 @@ from ..events import (
from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError
from ..marshalling import callable_to_ref
from ..structures import JobResult, Task
-from ..util import reentrant
+from .base import BaseAsyncDataStore
from .sqlalchemy import _BaseSQLAlchemyDataStore
-@reentrant
@attrs.define(eq=False)
-class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
+class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseAsyncDataStore):
engine: AsyncEngine
- _events: AsyncEventBroker = attrs.field(factory=LocalAsyncEventBroker)
-
@classmethod
def from_url(cls, url: str | URL, **options) -> AsyncSQLAlchemyDataStore:
engine = create_async_engine(url, future=True)
@@ -63,7 +59,9 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
reraise=True,
)
- async def __aenter__(self):
+ async def start(self, event_broker: AsyncEventBroker) -> None:
+ await super().start(event_broker)
+
asynclib = sniffio.current_async_library() or "(unknown)"
if asynclib != "asyncio":
raise RuntimeError(
@@ -92,16 +90,6 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
f"only version 1 is supported by this version of APScheduler"
)
- await self._events.__aenter__()
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self._events.__aexit__(exc_type, exc_val, exc_tb)
-
- @property
- def events(self) -> EventSource:
- return self._events
-
async def _deserialize_schedules(self, result: Result) -> list[Schedule]:
schedules: list[Schedule] = []
for row in result:
diff --git a/src/apscheduler/datastores/base.py b/src/apscheduler/datastores/base.py
new file mode 100644
index 0000000..c05d28c
--- /dev/null
+++ b/src/apscheduler/datastores/base.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from apscheduler.abc import (
+ AsyncDataStore,
+ AsyncEventBroker,
+ DataStore,
+ EventBroker,
+ EventSource,
+)
+
+
+class BaseDataStore(DataStore):
+ _events: EventBroker
+
+ def start(self, event_broker: EventBroker) -> None:
+ self._events = event_broker
+
+ def stop(self, *, force: bool = False) -> None:
+ del self._events
+
+ @property
+ def events(self) -> EventSource:
+ return self._events
+
+
+class BaseAsyncDataStore(AsyncDataStore):
+ _events: AsyncEventBroker
+
+ async def start(self, event_broker: AsyncEventBroker) -> None:
+ self._events = event_broker
+
+ async def stop(self, *, force: bool = False) -> None:
+ del self._events
+
+ @property
+ def events(self) -> EventSource:
+ return self._events
diff --git a/src/apscheduler/datastores/memory.py b/src/apscheduler/datastores/memory.py
index c2d7420..516d141 100644
--- a/src/apscheduler/datastores/memory.py
+++ b/src/apscheduler/datastores/memory.py
@@ -9,9 +9,8 @@ from uuid import UUID
import attrs
-from ..abc import DataStore, EventBroker, EventSource, Job, Schedule
+from ..abc import Job, Schedule
from ..enums import ConflictPolicy
-from ..eventbrokers.local import LocalEventBroker
from ..events import (
JobAcquired,
JobAdded,
@@ -25,7 +24,7 @@ from ..events import (
)
from ..exceptions import ConflictingIdError, TaskLookupError
from ..structures import JobResult, Task
-from ..util import reentrant
+from .base import BaseDataStore
max_datetime = datetime(MAXYEAR, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc)
@@ -83,11 +82,9 @@ class JobState:
return hash(self.job.id)
-@reentrant
@attrs.define(eq=False)
-class MemoryDataStore(DataStore):
+class MemoryDataStore(BaseDataStore):
lock_expiration_delay: float = 30
- _events: EventBroker = attrs.Factory(LocalEventBroker)
_tasks: dict[str, TaskState] = attrs.Factory(dict)
_schedules: list[ScheduleState] = attrs.Factory(list)
_schedules_by_id: dict[str, ScheduleState] = attrs.Factory(dict)
@@ -111,17 +108,6 @@ class MemoryDataStore(DataStore):
right_index = bisect_left(self._jobs, state)
return self._jobs.index(state, left_index, right_index + 1)
- def __enter__(self):
- self._events.__enter__()
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self._events.__exit__(exc_type, exc_val, exc_tb)
-
- @property
- def events(self) -> EventSource:
- return self._events
-
def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
return [
state.schedule
diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py
index 2669d03..59a345a 100644
--- a/src/apscheduler/datastores/mongodb.py
+++ b/src/apscheduler/datastores/mongodb.py
@@ -2,7 +2,6 @@ from __future__ import annotations
import operator
from collections import defaultdict
-from contextlib import ExitStack
from datetime import datetime, timedelta, timezone
from logging import Logger, getLogger
from typing import Any, Callable, ClassVar, Iterable
@@ -18,7 +17,7 @@ from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne
from pymongo.collection import Collection
from pymongo.errors import ConnectionFailure, DuplicateKeyError
-from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer
+from ..abc import EventBroker, Job, Schedule, Serializer
from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome
from ..eventbrokers.local import LocalEventBroker
from ..events import (
@@ -41,7 +40,7 @@ from ..exceptions import (
)
from ..serializers.pickle import PickleSerializer
from ..structures import JobResult, RetrySettings, Task
-from ..util import reentrant
+from .base import BaseDataStore
class CustomEncoder(TypeEncoder):
@@ -62,14 +61,14 @@ def ensure_uuid_presentation(client: MongoClient) -> None:
pass
-@reentrant
@attrs.define(eq=False)
-class MongoDBDataStore(DataStore):
+class MongoDBDataStore(BaseDataStore):
client: MongoClient = attrs.field(validator=instance_of(MongoClient))
serializer: Serializer = attrs.field(factory=PickleSerializer, kw_only=True)
+ event_broker = attrs.field(factory=LocalEventBroker, kw_only=True)
database: str = attrs.field(default="apscheduler", kw_only=True)
lock_expiration_delay: float = attrs.field(default=30, kw_only=True)
- retry_settings: RetrySettings = attrs.field(default=RetrySettings())
+ retry_settings: RetrySettings = attrs.field(default=RetrySettings(), kw_only=True)
start_from_scratch: bool = attrs.field(default=False, kw_only=True)
_task_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Task)]
@@ -79,8 +78,6 @@ class MongoDBDataStore(DataStore):
_job_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Job)]
_logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__))
- _exit_stack: ExitStack = attrs.field(init=False, factory=ExitStack)
- _events: EventBroker = attrs.field(init=False, factory=LocalEventBroker)
_local_tasks: dict[str, Task] = attrs.field(init=False, factory=dict)
def __attrs_post_init__(self) -> None:
@@ -108,10 +105,6 @@ class MongoDBDataStore(DataStore):
client = MongoClient(uri)
return cls(client, **options)
- @property
- def events(self) -> EventSource:
- return self._events
-
def _retry(self) -> tenacity.Retrying:
return tenacity.Retrying(
stop=self.retry_settings.stop,
@@ -128,7 +121,9 @@ class MongoDBDataStore(DataStore):
retry_state.outcome.exception(),
)
- def __enter__(self):
+ def start(self, event_broker: EventBroker) -> None:
+ super().start(event_broker)
+
server_info = self.client.server_info()
if server_info["versionArray"] < [4, 0]:
raise RuntimeError(
@@ -136,9 +131,6 @@ class MongoDBDataStore(DataStore):
f"{server_info['version']}"
)
- self._exit_stack.__enter__()
- self._exit_stack.enter_context(self._events)
-
for attempt in self._retry():
with attempt, self.client.start_session() as session:
if self.start_from_scratch:
@@ -153,11 +145,6 @@ class MongoDBDataStore(DataStore):
self._jobs.create_index("tags", session=session)
self._jobs_results.create_index("finished_at", session=session)
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self._exit_stack.__exit__(exc_type, exc_val, exc_tb)
-
def add_task(self, task: Task) -> None:
for attempt in self._retry():
with attempt:
diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py
index cca223e..9ee03fa 100644
--- a/src/apscheduler/datastores/sqlalchemy.py
+++ b/src/apscheduler/datastores/sqlalchemy.py
@@ -31,9 +31,8 @@ from sqlalchemy.future import Engine, create_engine
from sqlalchemy.sql.ddl import DropTable
from sqlalchemy.sql.elements import BindParameter, literal
-from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer
+from ..abc import EventBroker, Job, Schedule, Serializer
from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome
-from ..eventbrokers.local import LocalEventBroker
from ..events import (
Event,
JobAcquired,
@@ -52,7 +51,7 @@ from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError
from ..marshalling import callable_to_ref
from ..serializers.pickle import PickleSerializer
from ..structures import JobResult, RetrySettings, Task
-from ..util import reentrant
+from .base import BaseDataStore
class EmulatedUUID(TypeDecorator):
@@ -219,13 +218,10 @@ class _BaseSQLAlchemyDataStore:
return jobs
-@reentrant
@attrs.define(eq=False)
-class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore):
+class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore):
engine: Engine
- _events: EventBroker = attrs.field(init=False, factory=LocalEventBroker)
-
@classmethod
def from_url(cls, url: str | URL, **options) -> SQLAlchemyDataStore:
engine = create_engine(url)
@@ -240,7 +236,9 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore):
reraise=True,
)
- def __enter__(self):
+ def start(self, event_broker: EventBroker) -> None:
+ super().start(event_broker)
+
for attempt in self._retry():
with attempt, self.engine.begin() as conn:
if self.start_from_scratch:
@@ -259,16 +257,6 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore):
f"only version 1 is supported by this version of APScheduler"
)
- self._events.__enter__()
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self._events.__exit__(exc_type, exc_val, exc_tb)
-
- @property
- def events(self) -> EventSource:
- return self._events
-
def add_task(self, task: Task) -> None:
insert = self.t_tasks.insert().values(
id=task.id,
diff --git a/src/apscheduler/eventbrokers/async_adapter.py b/src/apscheduler/eventbrokers/async_adapter.py
index ebceaad..4c1cb6c 100644
--- a/src/apscheduler/eventbrokers/async_adapter.py
+++ b/src/apscheduler/eventbrokers/async_adapter.py
@@ -1,39 +1,65 @@
from __future__ import annotations
-from functools import partial
+from typing import Any, Callable, Iterable
import attrs
from anyio import to_thread
from anyio.from_thread import BlockingPortal
-from apscheduler.abc import EventBroker
-from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker
+from apscheduler.abc import AsyncEventBroker, EventBroker, Subscription
from apscheduler.events import Event
-from apscheduler.util import reentrant
-@reentrant
@attrs.define(eq=False)
-class AsyncEventBrokerAdapter(LocalAsyncEventBroker):
+class AsyncEventBrokerAdapter(AsyncEventBroker):
original: EventBroker
- portal: BlockingPortal
-
- async def __aenter__(self):
- await super().__aenter__()
- if not self.portal:
- self.portal = BlockingPortal()
- self._exit_stack.enter_async_context(self.portal)
+ async def start(self) -> None:
+ await to_thread.run_sync(self.original.start)
- await to_thread.run_sync(self.original.__enter__)
- self._exit_stack.push_async_exit(
- partial(to_thread.run_sync, self.original.__exit__)
- )
+ async def stop(self, *, force: bool = False) -> None:
+ await to_thread.run_sync(lambda: self.original.stop(force=force))
- # Relay events from the original broker to this one
- self._exit_stack.enter_context(
- self.original.subscribe(partial(self.portal.call, self.publish_local))
- )
+ async def publish_local(self, event: Event) -> None:
+ await to_thread.run_sync(self.original.publish_local, event)
async def publish(self, event: Event) -> None:
await to_thread.run_sync(self.original.publish, event)
+
+ def subscribe(
+ self,
+ callback: Callable[[Event], Any],
+ event_types: Iterable[type[Event]] | None = None,
+ *,
+ one_shot: bool = False,
+ ) -> Subscription:
+ return self.original.subscribe(callback, event_types, one_shot=one_shot)
+
+
+@attrs.define(eq=False)
+class SyncEventBrokerAdapter(EventBroker):
+ original: AsyncEventBroker
+ portal: BlockingPortal
+
+ def start(self) -> None:
+ pass
+
+ def stop(self, *, force: bool = False) -> None:
+ pass
+
+ def publish_local(self, event: Event) -> None:
+ self.portal.call(self.original.publish_local, event)
+
+ def publish(self, event: Event) -> None:
+ self.portal.call(self.original.publish, event)
+
+ def subscribe(
+ self,
+ callback: Callable[[Event], Any],
+ event_types: Iterable[type[Event]] | None = None,
+ *,
+ one_shot: bool = False,
+ ) -> Subscription:
+ return self.portal.call(
+ lambda: self.original.subscribe(callback, event_types, one_shot=one_shot)
+ )
diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py
index 98031bf..dba37f1 100644
--- a/src/apscheduler/eventbrokers/async_local.py
+++ b/src/apscheduler/eventbrokers/async_local.py
@@ -1,7 +1,6 @@
from __future__ import annotations
from asyncio import iscoroutine
-from contextlib import AsyncExitStack
from typing import Any, Callable
import attrs
@@ -10,26 +9,19 @@ from anyio.abc import TaskGroup
from ..abc import AsyncEventBroker
from ..events import Event
-from ..util import reentrant
from .base import BaseEventBroker
-@reentrant
@attrs.define(eq=False)
class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker):
_task_group: TaskGroup = attrs.field(init=False)
- _exit_stack: AsyncExitStack = attrs.field(init=False)
-
- async def __aenter__(self) -> LocalAsyncEventBroker:
- self._exit_stack = AsyncExitStack()
+ async def start(self) -> None:
self._task_group = create_task_group()
- await self._exit_stack.enter_async_context(self._task_group)
-
- return self
+ await self._task_group.__aenter__()
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
+ async def stop(self, *, force: bool = False) -> None:
+ await self._task_group.__aexit__(None, None, None)
del self._task_group
async def publish(self, event: Event) -> None:
diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py
index d97c229..c80033e 100644
--- a/src/apscheduler/eventbrokers/asyncpg.py
+++ b/src/apscheduler/eventbrokers/asyncpg.py
@@ -4,13 +4,12 @@ from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncContextManager, AsyncGenerator, Callable
import attrs
-from anyio import TASK_STATUS_IGNORED, sleep
+from anyio import TASK_STATUS_IGNORED, CancelScope, sleep
from asyncpg import Connection
from asyncpg.pool import Pool
from ..events import Event
from ..exceptions import SerializationError
-from ..util import reentrant
from .async_local import LocalAsyncEventBroker
from .base import DistributedEventBrokerMixin
@@ -18,12 +17,12 @@ if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine
-@reentrant
@attrs.define(eq=False)
class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
connection_factory: Callable[[], AsyncContextManager[Connection]]
channel: str = attrs.field(kw_only=True, default="apscheduler")
max_idle_time: float = attrs.field(kw_only=True, default=30)
+ _listen_cancel_scope: CancelScope = attrs.field(init=False)
@classmethod
def from_asyncpg_pool(cls, pool: Pool) -> AsyncpgEventBroker:
@@ -47,11 +46,15 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
return cls(connection_factory)
- async def __aenter__(self) -> AsyncpgEventBroker:
- await super().__aenter__()
- await self._task_group.start(self._listen_notifications)
- self._exit_stack.callback(self._task_group.cancel_scope.cancel)
- return self
+ async def start(self) -> None:
+ await super().start()
+ self._listen_cancel_scope = await self._task_group.start(
+ self._listen_notifications
+ )
+
+ async def stop(self, *, force: bool = False) -> None:
+ self._listen_cancel_scope.cancel()
+ await super().stop(force=force)
async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None:
def callback(connection, pid, channel: str, payload: str) -> None:
@@ -60,19 +63,20 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
self._task_group.start_soon(self.publish_local, event)
task_started_sent = False
- while True:
- async with self.connection_factory() as conn:
- await conn.add_listener(self.channel, callback)
- if not task_started_sent:
- task_status.started()
- task_started_sent = True
-
- try:
- while True:
- await sleep(self.max_idle_time)
- await conn.execute("SELECT 1")
- finally:
- await conn.remove_listener(self.channel, callback)
+ with CancelScope() as cancel_scope:
+ while True:
+ async with self.connection_factory() as conn:
+ await conn.add_listener(self.channel, callback)
+ if not task_started_sent:
+ task_status.started(cancel_scope)
+ task_started_sent = True
+
+ try:
+ while True:
+ await sleep(self.max_idle_time)
+ await conn.execute("SELECT 1")
+ finally:
+ await conn.remove_listener(self.channel, callback)
async def publish(self, event: Event) -> None:
notification = self.generate_notification_str(event)
diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py
index 38c83c7..3aeee77 100644
--- a/src/apscheduler/eventbrokers/base.py
+++ b/src/apscheduler/eventbrokers/base.py
@@ -7,7 +7,7 @@ from typing import Any, Callable, Iterable
import attrs
from .. import events
-from ..abc import EventBroker, Serializer, Subscription
+from ..abc import EventSource, Serializer, Subscription
from ..events import Event
from ..exceptions import DeserializationError
@@ -25,7 +25,7 @@ class LocalSubscription(Subscription):
@attrs.define(eq=False)
-class BaseEventBroker(EventBroker):
+class BaseEventBroker(EventSource):
_logger: Logger = attrs.field(init=False)
_subscriptions: dict[object, LocalSubscription] = attrs.field(
init=False, factory=dict
@@ -70,7 +70,7 @@ class DistributedEventBrokerMixin:
self._logger.exception(
"Failed to deserialize an event of type %s",
event_type,
- serialized=serialized,
+ extra={"serialized": serialized},
)
return None
@@ -80,7 +80,7 @@ class DistributedEventBrokerMixin:
self._logger.error(
"Receive notification for a nonexistent event type: %s",
event_type,
- serialized=serialized,
+ extra={"serialized": serialized},
)
return None
@@ -94,7 +94,9 @@ class DistributedEventBrokerMixin:
try:
event_type_bytes, serialized = payload.split(b" ", 1)
except ValueError:
- self._logger.error("Received malformatted notification", payload=payload)
+ self._logger.error(
+ "Received malformatted notification", extra={"payload": payload}
+ )
return None
event_type = event_type_bytes.decode("ascii", errors="replace")
@@ -104,7 +106,9 @@ class DistributedEventBrokerMixin:
try:
event_type, b64_serialized = payload.split(" ", 1)
except ValueError:
- self._logger.error("Received malformatted notification", payload=payload)
+ self._logger.error(
+ "Received malformatted notification", extra={"payload": payload}
+ )
return None
return self._reconstitute_event(event_type, b64decode(b64_serialized))
diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py
index cf19526..a870c45 100644
--- a/src/apscheduler/eventbrokers/local.py
+++ b/src/apscheduler/eventbrokers/local.py
@@ -8,26 +8,22 @@ from typing import Any, Callable, Iterable
import attrs
-from ..abc import Subscription
+from ..abc import EventBroker, Subscription
from ..events import Event
-from ..util import reentrant
from .base import BaseEventBroker
-@reentrant
@attrs.define(eq=False)
-class LocalEventBroker(BaseEventBroker):
+class LocalEventBroker(EventBroker, BaseEventBroker):
_executor: ThreadPoolExecutor = attrs.field(init=False)
_exit_stack: ExitStack = attrs.field(init=False)
_subscriptions_lock: Lock = attrs.field(init=False, factory=Lock)
- def __enter__(self):
- self._exit_stack = ExitStack()
- self._executor = self._exit_stack.enter_context(ThreadPoolExecutor(1))
- return self
+ def start(self) -> None:
+ self._executor = ThreadPoolExecutor(1)
- def __exit__(self, exc_type, exc_val, exc_tb):
- self._exit_stack.__exit__(exc_type, exc_val, exc_tb)
+ def stop(self, *, force: bool = False) -> None:
+ self._executor.shutdown(wait=not force)
del self._executor
def subscribe(
diff --git a/src/apscheduler/eventbrokers/mqtt.py b/src/apscheduler/eventbrokers/mqtt.py
index 8b0a512..9b0f859 100644
--- a/src/apscheduler/eventbrokers/mqtt.py
+++ b/src/apscheduler/eventbrokers/mqtt.py
@@ -11,12 +11,10 @@ from paho.mqtt.reasoncodes import ReasonCodes
from ..abc import Serializer
from ..events import Event
from ..serializers.json import JSONSerializer
-from ..util import reentrant
from .base import DistributedEventBrokerMixin
from .local import LocalEventBroker
-@reentrant
@attrs.define(eq=False)
class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
client: Client
@@ -28,8 +26,8 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
publish_qos: int = attrs.field(kw_only=True, default=0)
_ready_future: Future[None] = attrs.field(init=False)
- def __enter__(self):
- super().__enter__()
+ def start(self) -> None:
+ super().start()
self._ready_future = Future()
self.client.enable_logger(self._logger)
self.client.on_connect = self._on_connect
@@ -38,11 +36,11 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
self.client.connect(self.host, self.port)
self.client.loop_start()
self._ready_future.result(10)
- self._exit_stack.push(
- lambda exc_type, *_: self.client.loop_stop(force=bool(exc_type))
- )
- self._exit_stack.callback(self.client.disconnect)
- return self
+
+ def stop(self, *, force: bool = False) -> None:
+ self.client.disconnect()
+ self.client.loop_stop(force=force)
+ super().stop()
def _on_connect(
self,
diff --git a/src/apscheduler/eventbrokers/redis.py b/src/apscheduler/eventbrokers/redis.py
index 0f14e3e..a41ef37 100644
--- a/src/apscheduler/eventbrokers/redis.py
+++ b/src/apscheduler/eventbrokers/redis.py
@@ -9,12 +9,10 @@ from redis import ConnectionPool, Redis
from ..abc import Serializer
from ..events import Event
from ..serializers.json import JSONSerializer
-from ..util import reentrant
from .base import DistributedEventBrokerMixin
from .local import LocalEventBroker
-@reentrant
@attrs.define(eq=False)
class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
client: Redis
@@ -23,6 +21,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
message_poll_interval: float = attrs.field(kw_only=True, default=0.05)
_stopped: bool = attrs.field(init=False, default=True)
_ready_future: Future[None] = attrs.field(init=False)
+ _thread: Thread = attrs.field(init=False)
@classmethod
def from_url(cls, url: str, **kwargs) -> RedisEventBroker:
@@ -30,7 +29,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
client = Redis(connection_pool=pool)
return cls(client)
- def __enter__(self):
+ def start(self) -> None:
self._stopped = False
self._ready_future = Future()
self._thread = Thread(
@@ -38,14 +37,14 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
)
self._thread.start()
self._ready_future.result(10)
- return super().__enter__()
+ super().start()
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def stop(self, *, force: bool = False) -> None:
self._stopped = True
- if not exc_type:
+ if not force:
self._thread.join(5)
- super().__exit__(exc_type, exc_val, exc_tb)
+ super().stop(force=force)
def _listen_messages(self) -> None:
while not self._stopped:
diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py
index 72c5994..889596d 100644
--- a/src/apscheduler/schedulers/async_.py
+++ b/src/apscheduler/schedulers/async_.py
@@ -17,10 +17,11 @@ from anyio import (
get_cancelled_exc_class,
move_on_after,
)
+from anyio.abc import TaskGroup
-from ..abc import AsyncDataStore, EventSource, Job, Schedule, Trigger
+from ..abc import AsyncDataStore, AsyncEventBroker, Job, Schedule, Subscription, Trigger
from ..context import current_scheduler
-from ..converters import as_async_datastore
+from ..converters import as_async_datastore, as_async_eventbroker
from ..datastores.memory import MemoryDataStore
from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState
from ..eventbrokers.async_local import LocalAsyncEventBroker
@@ -48,6 +49,9 @@ class AsyncScheduler:
data_store: AsyncDataStore = attrs.field(
converter=as_async_datastore, factory=MemoryDataStore
)
+ event_broker: AsyncEventBroker = attrs.field(
+ converter=as_async_eventbroker, factory=LocalAsyncEventBroker
+ )
identity: str = attrs.field(kw_only=True, default=None)
start_worker: bool = attrs.field(kw_only=True, default=True)
logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__))
@@ -55,65 +59,24 @@ class AsyncScheduler:
_state: RunState = attrs.field(init=False, default=RunState.stopped)
_wakeup_event: anyio.Event = attrs.field(init=False)
_wakeup_deadline: datetime | None = attrs.field(init=False, default=None)
- _worker: AsyncWorker | None = attrs.field(init=False, default=None)
- _events: LocalAsyncEventBroker = attrs.field(
- init=False, factory=LocalAsyncEventBroker
- )
- _exit_stack: AsyncExitStack = attrs.field(init=False)
+ _task_group: TaskGroup | None = attrs.field(init=False, default=None)
+ _schedule_added_subscription: Subscription = attrs.field(init=False)
def __attrs_post_init__(self) -> None:
if not self.identity:
self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}"
- @property
- def events(self) -> EventSource:
- return self._events
-
- @property
- def worker(self) -> AsyncWorker | None:
- return self._worker
-
async def __aenter__(self):
- self._state = RunState.starting
- self._wakeup_event = anyio.Event()
- self._exit_stack = AsyncExitStack()
- await self._exit_stack.__aenter__()
- await self._exit_stack.enter_async_context(self._events)
-
- # Initialize the data store and start relaying events to the scheduler's event broker
- await self._exit_stack.enter_async_context(self.data_store)
- self._exit_stack.enter_context(
- self.data_store.events.subscribe(self._events.publish)
- )
-
- # Wake up the scheduler if the data store emits a significant schedule event
- self._exit_stack.enter_context(
- self.data_store.events.subscribe(
- self._schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated}
- )
- )
-
- # Start the built-in worker, if configured to do so
- if self.start_worker:
- token = current_scheduler.set(self)
- try:
- self._worker = AsyncWorker(self.data_store)
- await self._exit_stack.enter_async_context(self._worker)
- finally:
- current_scheduler.reset(token)
-
- # Start the worker and return when it has signalled readiness or raised an exception
- task_group = create_task_group()
- await self._exit_stack.enter_async_context(task_group)
- await task_group.start(self.run)
+ self._task_group = create_task_group()
+ await self._task_group.__aenter__()
+ await self._task_group.start(self.run_until_stopped)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
self._state = RunState.stopping
self._wakeup_event.set()
- await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
- self._state = RunState.stopped
- del self._wakeup_event
+ await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
+ self._task_group = None
def _schedule_added_or_modified(self, event: Event) -> None:
event_ = cast("ScheduleAdded | ScheduleUpdated", event)
@@ -281,130 +244,169 @@ class AsyncScheduler:
else:
raise RuntimeError(f"Unknown job outcome: {result.outcome}")
- async def run(self, *, task_status=TASK_STATUS_IGNORED) -> None:
- if self._state is not RunState.starting:
+ async def run_until_stopped(self, *, task_status=TASK_STATUS_IGNORED) -> None:
+ if self._state is not RunState.stopped:
raise RuntimeError(
- f"This function cannot be called while the scheduler is in the "
- f"{self._state} state"
+ f'Cannot start the scheduler when it is in the "{self._state}" '
+ f"state"
)
- # Signal that the scheduler has started
- self._state = RunState.started
- task_status.started()
- await self._events.publish(SchedulerStarted())
-
- exception: BaseException | None = None
- try:
- while self._state is RunState.started:
- schedules = await self.data_store.acquire_schedules(self.identity, 100)
- now = datetime.now(timezone.utc)
- for schedule in schedules:
- # Calculate a next fire time for the schedule, if possible
- fire_times = [schedule.next_fire_time]
- calculate_next = schedule.trigger.next
- while True:
- try:
- fire_time = calculate_next()
- except Exception:
- self.logger.exception(
- "Error computing next fire time for schedule %r of task %r – "
- "removing schedule",
- schedule.id,
- schedule.task_id,
- )
- break
-
- # Stop if the calculated fire time is in the future
- if fire_time is None or fire_time > now:
- schedule.next_fire_time = fire_time
- break
-
- # Only keep all the fire times if coalesce policy = "all"
- if schedule.coalesce is CoalescePolicy.all:
- fire_times.append(fire_time)
- elif schedule.coalesce is CoalescePolicy.latest:
- fire_times[0] = fire_time
-
- # Add one or more jobs to the job queue
- max_jitter = (
- schedule.max_jitter.total_seconds()
- if schedule.max_jitter
- else 0
+ self._state = RunState.starting
+ async with AsyncExitStack() as exit_stack:
+ self._wakeup_event = anyio.Event()
+
+ # Initialize the event broker
+ await self.event_broker.start()
+ exit_stack.push_async_exit(
+ lambda *exc_info: self.event_broker.stop(force=exc_info[0] is not None)
+ )
+
+ # Initialize the data store
+ await self.data_store.start(self.event_broker)
+ exit_stack.push_async_exit(
+ lambda *exc_info: self.data_store.stop(force=exc_info[0] is not None)
+ )
+
+ # Wake up the scheduler if the data store emits a significant schedule event
+ self._schedule_added_subscription = self.event_broker.subscribe(
+ self._schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated}
+ )
+
+ # Start the built-in worker, if configured to do so
+ if self.start_worker:
+ token = current_scheduler.set(self)
+ exit_stack.callback(current_scheduler.reset, token)
+ worker = AsyncWorker(
+ self.data_store, self.event_broker, is_internal=True
+ )
+ await exit_stack.enter_async_context(worker)
+
+ # Signal that the scheduler has started
+ self._state = RunState.started
+ task_status.started()
+ await self.event_broker.publish_local(SchedulerStarted())
+
+ exception: BaseException | None = None
+ try:
+ while self._state is RunState.started:
+ schedules = await self.data_store.acquire_schedules(
+ self.identity, 100
)
- for i, fire_time in enumerate(fire_times):
- # Calculate a jitter if max_jitter > 0
- jitter = _zero_timedelta
- if max_jitter:
- if i + 1 < len(fire_times):
- next_fire_time = fire_times[i + 1]
- else:
- next_fire_time = schedule.next_fire_time
-
- if next_fire_time is not None:
- # Jitter must never be so high that it would cause a fire time to
- # equal or exceed the next fire time
- jitter_s = min(
- [
- max_jitter,
- (
- next_fire_time
- - fire_time
- - _microsecond_delta
- ).total_seconds(),
- ]
+ now = datetime.now(timezone.utc)
+ for schedule in schedules:
+ # Calculate a next fire time for the schedule, if possible
+ fire_times = [schedule.next_fire_time]
+ calculate_next = schedule.trigger.next
+ while True:
+ try:
+ fire_time = calculate_next()
+ except Exception:
+ self.logger.exception(
+ "Error computing next fire time for schedule %r of "
+ "task %r – removing schedule",
+ schedule.id,
+ schedule.task_id,
)
- jitter = timedelta(seconds=random.uniform(0, jitter_s))
- fire_time += jitter
-
- schedule.last_fire_time = fire_time
- job = Job(
- task_id=schedule.task_id,
- args=schedule.args,
- kwargs=schedule.kwargs,
- schedule_id=schedule.id,
- scheduled_fire_time=fire_time,
- jitter=jitter,
- start_deadline=schedule.next_deadline,
- tags=schedule.tags,
+ break
+
+ # Stop if the calculated fire time is in the future
+ if fire_time is None or fire_time > now:
+ schedule.next_fire_time = fire_time
+ break
+
+ # Only keep all the fire times if coalesce policy = "all"
+ if schedule.coalesce is CoalescePolicy.all:
+ fire_times.append(fire_time)
+ elif schedule.coalesce is CoalescePolicy.latest:
+ fire_times[0] = fire_time
+
+ # Add one or more jobs to the job queue
+ max_jitter = (
+ schedule.max_jitter.total_seconds()
+ if schedule.max_jitter
+ else 0
)
- await self.data_store.add_job(job)
-
- # Update the schedules (and release the scheduler's claim on them)
- await self.data_store.release_schedules(self.identity, schedules)
+ for i, fire_time in enumerate(fire_times):
+ # Calculate a jitter if max_jitter > 0
+ jitter = _zero_timedelta
+ if max_jitter:
+ if i + 1 < len(fire_times):
+ next_fire_time = fire_times[i + 1]
+ else:
+ next_fire_time = schedule.next_fire_time
+
+ if next_fire_time is not None:
+ # Jitter must never be so high that it would cause a
+ # fire time to equal or exceed the next fire time
+ jitter_s = min(
+ [
+ max_jitter,
+ (
+ next_fire_time
+ - fire_time
+ - _microsecond_delta
+ ).total_seconds(),
+ ]
+ )
+ jitter = timedelta(
+ seconds=random.uniform(0, jitter_s)
+ )
+ fire_time += jitter
+
+ schedule.last_fire_time = fire_time
+ job = Job(
+ task_id=schedule.task_id,
+ args=schedule.args,
+ kwargs=schedule.kwargs,
+ schedule_id=schedule.id,
+ scheduled_fire_time=fire_time,
+ jitter=jitter,
+ start_deadline=schedule.next_deadline,
+ tags=schedule.tags,
+ )
+ await self.data_store.add_job(job)
+
+ # Update the schedules (and release the scheduler's claim on them)
+ await self.data_store.release_schedules(self.identity, schedules)
+
+ # If we received fewer schedules than the maximum amount, sleep
+ # until the next schedule is due or the scheduler is explicitly
+ # woken up
+ wait_time = None
+ if len(schedules) < 100:
+ self._wakeup_deadline = (
+ await self.data_store.get_next_schedule_run_time()
+ )
+ if self._wakeup_deadline:
+ wait_time = (
+ self._wakeup_deadline - datetime.now(timezone.utc)
+ ).total_seconds()
+ self.logger.debug(
+ "Sleeping %.3f seconds until the next fire time (%s)",
+ wait_time,
+ self._wakeup_deadline,
+ )
+ else:
+ self.logger.debug("Waiting for any due schedules to appear")
- # If we received fewer schedules than the maximum amount, sleep until the next
- # schedule is due or the scheduler is explicitly woken up
- wait_time = None
- if len(schedules) < 100:
- self._wakeup_deadline = (
- await self.data_store.get_next_schedule_run_time()
- )
- if self._wakeup_deadline:
- wait_time = (
- self._wakeup_deadline - datetime.now(timezone.utc)
- ).total_seconds()
+ with move_on_after(wait_time):
+ await self._wakeup_event.wait()
+ self._wakeup_event = anyio.Event()
+ else:
self.logger.debug(
- "Sleeping %.3f seconds until the next fire time (%s)",
- wait_time,
- self._wakeup_deadline,
+ "Processing more schedules on the next iteration"
)
- else:
- self.logger.debug("Waiting for any due schedules to appear")
-
- with move_on_after(wait_time):
- await self._wakeup_event.wait()
- self._wakeup_event = anyio.Event()
- else:
- self.logger.debug("Processing more schedules on the next iteration")
- except get_cancelled_exc_class():
- pass
- except BaseException as exc:
- self.logger.exception("Scheduler crashed")
- exception = exc
- raise
- else:
- self.logger.info("Scheduler stopped")
- finally:
- self._state = RunState.stopped
- with move_on_after(3, shield=True):
- await self._events.publish(SchedulerStopped(exception=exception))
+ except get_cancelled_exc_class():
+ pass
+ except BaseException as exc:
+ self.logger.exception("Scheduler crashed")
+ exception = exc
+ raise
+ else:
+ self.logger.info("Scheduler stopped")
+ finally:
+ self._state = RunState.stopped
+ with move_on_after(3, shield=True):
+ await self.event_broker.publish_local(
+ SchedulerStopped(exception=exception)
+ )
diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py
index 4c7095e..5ea0c8a 100644
--- a/src/apscheduler/schedulers/sync.py
+++ b/src/apscheduler/schedulers/sync.py
@@ -5,16 +5,17 @@ import os
import platform
import random
import threading
-from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
+from concurrent.futures import Future
from contextlib import ExitStack
from datetime import datetime, timedelta, timezone
from logging import Logger, getLogger
+from types import TracebackType
from typing import Any, Callable, Iterable, Mapping, cast
from uuid import UUID, uuid4
import attrs
-from ..abc import DataStore, EventSource, Trigger
+from ..abc import DataStore, EventBroker, Trigger
from ..context import current_scheduler
from ..datastores.memory import MemoryDataStore
from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState
@@ -41,39 +42,60 @@ class Scheduler:
"""A synchronous scheduler implementation."""
data_store: DataStore = attrs.field(factory=MemoryDataStore)
+ event_broker: EventBroker = attrs.field(factory=LocalEventBroker)
identity: str = attrs.field(kw_only=True, default=None)
start_worker: bool = attrs.field(kw_only=True, default=True)
logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__))
_state: RunState = attrs.field(init=False, default=RunState.stopped)
- _wakeup_event: threading.Event = attrs.field(init=False)
+ _thread: threading.Thread | None = attrs.field(init=False, default=None)
+ _wakeup_event: threading.Event = attrs.field(init=False, factory=threading.Event)
_wakeup_deadline: datetime | None = attrs.field(init=False, default=None)
- _worker: Worker | None = attrs.field(init=False, default=None)
- _events: LocalEventBroker = attrs.field(init=False, factory=LocalEventBroker)
+ _services_initialized: bool = attrs.field(init=False, default=False)
+ _exit_stack: ExitStack = attrs.field(init=False, factory=ExitStack)
+ _lock: threading.RLock = attrs.field(init=False, factory=threading.RLock)
def __attrs_post_init__(self) -> None:
if not self.identity:
self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}"
@property
- def events(self) -> EventSource:
- return self._events
-
- @property
def state(self) -> RunState:
return self._state
- @property
- def worker(self) -> Worker | None:
- return self._worker
-
def __enter__(self) -> Scheduler:
self.start_in_background()
return self
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: type[BaseException],
+ exc_val: BaseException,
+ exc_tb: TracebackType,
+ ) -> None:
self.stop()
+ def _ensure_services_ready(self) -> None:
+ """Ensure that the data store and event broker have been initialized."""
+ with self._lock:
+ if not self._services_initialized:
+ self._services_initialized = True
+ self.event_broker.start()
+ self._exit_stack.push(
+ lambda *exc_info: self.event_broker.stop(
+ force=exc_info[0] is not None
+ )
+ )
+
+ # Initialize the data store
+ self.data_store.start(self.event_broker)
+ self._exit_stack.push(
+ lambda *exc_info: self.data_store.stop(
+ force=exc_info[0] is not None
+ )
+ )
+ atexit.register(self.stop)
+
def _schedule_added_or_modified(self, event: Event) -> None:
event_ = cast("ScheduleAdded | ScheduleUpdated", event)
if not self._wakeup_deadline or (
@@ -98,6 +120,7 @@ class Scheduler:
tags: Iterable[str] | None = None,
conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing,
) -> str:
+ self._ensure_services_ready()
id = id or str(uuid4())
args = tuple(args or ())
kwargs = dict(kwargs or {})
@@ -133,10 +156,12 @@ class Scheduler:
return schedule.id
def get_schedule(self, id: str) -> Schedule:
+ self._ensure_services_ready()
schedules = self.data_store.get_schedules({id})
return schedules[0]
def remove_schedule(self, schedule_id: str) -> None:
+ self._ensure_services_ready()
self.data_store.remove_schedules({schedule_id})
def add_job(
@@ -157,6 +182,7 @@ class Scheduler:
:return: the ID of the newly created job
"""
+ self._ensure_services_ready()
if callable(func_or_task_id):
task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id)
self.data_store.add_task(task)
@@ -177,18 +203,19 @@ class Scheduler:
Retrieve the result of a job.
:param job_id: the ID of the job
- :param wait: if ``True``, wait until the job has ended (one way or another), ``False`` to
- raise an exception if the result is not yet available
+ :param wait: if ``True``, wait until the job has ended (one way or another),
+ ``False`` to raise an exception if the result is not yet available
:raises JobLookupError: if the job does not exist in the data store
"""
+ self._ensure_services_ready()
wait_event = threading.Event()
def listener(event: JobReleased) -> None:
if event.job_id == job_id:
wait_event.set()
- with self.data_store.events.subscribe(listener, {JobReleased}):
+ with self.data_store.events.subscribe(listener, {JobReleased}, one_shot=True):
result = self.data_store.get_job_result(job_id)
if result:
return result
@@ -215,6 +242,7 @@ class Scheduler:
:returns: the return value of the target function
"""
+ self._ensure_services_ready()
job_complete_event = threading.Event()
def listener(event: JobReleased) -> None:
@@ -239,181 +267,200 @@ class Scheduler:
raise RuntimeError(f"Unknown job outcome: {result.outcome}")
def start_in_background(self) -> None:
- if self._state is not RunState.stopped:
- raise RuntimeError("Scheduler already running")
+ start_future: Future[None] = Future()
+ self._thread = threading.Thread(
+ target=self._run, args=[start_future], daemon=True
+ )
+ self._thread.start()
+ try:
+ start_future.result()
+ except BaseException:
+ self._thread = None
+ raise
- thread = threading.Thread(target=self.run_until_stopped, daemon=True)
- thread.start()
atexit.register(self.stop)
def stop(self) -> None:
- if self._state is RunState.started:
- f = Future()
- with self._events.subscribe(f.set_result, one_shot=True):
+ atexit.unregister(self.stop)
+ with self._lock:
+ if self._state is RunState.started:
self._state = RunState.stopping
self._wakeup_event.set()
- f.result()
- def run_until_stopped(self):
- self._state = RunState.starting
- self._wakeup_event = threading.Event()
- with ExitStack() as exit_stack:
- exit_stack.enter_context(self._events)
+ if self._thread and threading.current_thread() != self._thread:
+ self._thread.join()
+ self._thread = None
- # Initialize the data store and start relaying events to the scheduler's event broker
- exit_stack.enter_context(self.data_store)
- exit_stack.enter_context(
- self.data_store.events.subscribe(self._events.publish)
- )
+ self._exit_stack.close()
- # Wake up the scheduler if the data store emits a significant schedule event
- exit_stack.enter_context(
- self.data_store.events.subscribe(
- self._schedule_added_or_modified, {ScheduleAdded, ScheduleUpdated}
- )
- )
+ def run_until_stopped(self) -> None:
+ self._run(None)
- # Start the built-in worker, if configured to do so
- if self.start_worker:
- token = current_scheduler.set(self)
- try:
- self._worker = Worker(self.data_store)
- exit_stack.enter_context(self._worker)
- finally:
- current_scheduler.reset(token)
-
- # Start the scheduler and return when it has signalled readiness or raised an exception
- start_future: Future[Event] = Future()
- with self._events.subscribe(start_future.set_result, one_shot=True):
- executor = ThreadPoolExecutor(1)
- exit_stack.push(
- lambda exc_type, *args: executor.shutdown(wait=exc_type is None)
- )
- run_future = executor.submit(self._run)
- wait([start_future, run_future], return_when=FIRST_COMPLETED)
-
- self._state = RunState.stopped
- del self._wakeup_event
-
- if run_future.done():
- run_future.result()
+ def _run(self, start_future: Future[None] | None) -> None:
+ with ExitStack() as exit_stack:
+ try:
+ with self._lock:
+ if self._state is not RunState.stopped:
+ raise RuntimeError(
+ f'Cannot start the scheduler when it is in the "{self._state}" '
+ f"state"
+ )
- def _run(self) -> None:
- assert self._state is RunState.starting
+ self._state = RunState.starting
- # Signal that the scheduler has started
- self._state = RunState.started
- self._events.publish(SchedulerStarted())
+ self._ensure_services_ready()
- try:
- while self._state is RunState.started:
- schedules = self.data_store.acquire_schedules(self.identity, 100)
- self.logger.debug(
- "Processing %d schedules retrieved from the data store",
- len(schedules),
+ # Wake up the scheduler if the data store emits a significant schedule event
+ exit_stack.enter_context(
+ self.data_store.events.subscribe(
+ self._schedule_added_or_modified,
+ {ScheduleAdded, ScheduleUpdated},
+ )
)
- now = datetime.now(timezone.utc)
- for schedule in schedules:
- # Calculate a next fire time for the schedule, if possible
- fire_times = [schedule.next_fire_time]
- calculate_next = schedule.trigger.next
- while True:
- try:
- fire_time = calculate_next()
- except Exception:
- self.logger.exception(
- "Error computing next fire time for schedule %r of task %r – "
- "removing schedule",
- schedule.id,
- schedule.task_id,
- )
- break
-
- # Stop if the calculated fire time is in the future
- if fire_time is None or fire_time > now:
- schedule.next_fire_time = fire_time
- break
-
- # Only keep all the fire times if coalesce policy = "all"
- if schedule.coalesce is CoalescePolicy.all:
- fire_times.append(fire_time)
- elif schedule.coalesce is CoalescePolicy.latest:
- fire_times[0] = fire_time
-
- # Add one or more jobs to the job queue
- max_jitter = (
- schedule.max_jitter.total_seconds()
- if schedule.max_jitter
- else 0
+
+ # Start the built-in worker, if configured to do so
+ if self.start_worker:
+ token = current_scheduler.set(self)
+ exit_stack.callback(current_scheduler.reset, token)
+ worker = Worker(
+ self.data_store, self.event_broker, is_internal=True
+ )
+ exit_stack.enter_context(worker)
+
+ # Signal that the scheduler has started
+ self._state = RunState.started
+ self.event_broker.publish_local(SchedulerStarted())
+ except BaseException as exc:
+ if start_future:
+ start_future.set_exception(exc)
+ return
+ else:
+ raise
+ else:
+ if start_future:
+ start_future.set_result(None)
+
+ try:
+ while self._state is RunState.started:
+ schedules = self.data_store.acquire_schedules(self.identity, 100)
+ self.logger.debug(
+ "Processing %d schedules retrieved from the data store",
+ len(schedules),
)
- for i, fire_time in enumerate(fire_times):
- # Calculate a jitter if max_jitter > 0
- jitter = _zero_timedelta
- if max_jitter:
- if i + 1 < len(fire_times):
- next_fire_time = fire_times[i + 1]
- else:
- next_fire_time = schedule.next_fire_time
-
- if next_fire_time is not None:
- # Jitter must never be so high that it would cause a fire time to
- # equal or exceed the next fire time
- jitter_s = min(
- [
- max_jitter,
- (
- next_fire_time
- - fire_time
- - _microsecond_delta
- ).total_seconds(),
- ]
+ now = datetime.now(timezone.utc)
+ for schedule in schedules:
+ # Calculate a next fire time for the schedule, if possible
+ fire_times = [schedule.next_fire_time]
+ calculate_next = schedule.trigger.next
+ while True:
+ try:
+ fire_time = calculate_next()
+ except Exception:
+ self.logger.exception(
+ "Error computing next fire time for schedule %r of task %r – "
+ "removing schedule",
+ schedule.id,
+ schedule.task_id,
)
- jitter = timedelta(seconds=random.uniform(0, jitter_s))
- fire_time += jitter
-
- schedule.last_fire_time = fire_time
- job = Job(
- task_id=schedule.task_id,
- args=schedule.args,
- kwargs=schedule.kwargs,
- schedule_id=schedule.id,
- scheduled_fire_time=fire_time,
- jitter=jitter,
- start_deadline=schedule.next_deadline,
- tags=schedule.tags,
+ break
+
+ # Stop if the calculated fire time is in the future
+ if fire_time is None or fire_time > now:
+ schedule.next_fire_time = fire_time
+ break
+
+ # Only keep all the fire times if coalesce policy = "all"
+ if schedule.coalesce is CoalescePolicy.all:
+ fire_times.append(fire_time)
+ elif schedule.coalesce is CoalescePolicy.latest:
+ fire_times[0] = fire_time
+
+ # Add one or more jobs to the job queue
+ max_jitter = (
+ schedule.max_jitter.total_seconds()
+ if schedule.max_jitter
+ else 0
)
- self.data_store.add_job(job)
-
- # Update the schedules (and release the scheduler's claim on them)
- self.data_store.release_schedules(self.identity, schedules)
-
- # If we received fewer schedules than the maximum amount, sleep until the next
- # schedule is due or the scheduler is explicitly woken up
- wait_time = None
- if len(schedules) < 100:
- self._wakeup_deadline = self.data_store.get_next_schedule_run_time()
- if self._wakeup_deadline:
- wait_time = (
- self._wakeup_deadline - datetime.now(timezone.utc)
- ).total_seconds()
- self.logger.debug(
- "Sleeping %.3f seconds until the next fire time (%s)",
- wait_time,
- self._wakeup_deadline,
+ for i, fire_time in enumerate(fire_times):
+ # Calculate a jitter if max_jitter > 0
+ jitter = _zero_timedelta
+ if max_jitter:
+ if i + 1 < len(fire_times):
+ next_fire_time = fire_times[i + 1]
+ else:
+ next_fire_time = schedule.next_fire_time
+
+ if next_fire_time is not None:
+ # Jitter must never be so high that it would cause
+ # a fire time to equal or exceed the next fire time
+ jitter_s = min(
+ [
+ max_jitter,
+ (
+ next_fire_time
+ - fire_time
+ - _microsecond_delta
+ ).total_seconds(),
+ ]
+ )
+ jitter = timedelta(
+ seconds=random.uniform(0, jitter_s)
+ )
+ fire_time += jitter
+
+ schedule.last_fire_time = fire_time
+ job = Job(
+ task_id=schedule.task_id,
+ args=schedule.args,
+ kwargs=schedule.kwargs,
+ schedule_id=schedule.id,
+ scheduled_fire_time=fire_time,
+ jitter=jitter,
+ start_deadline=schedule.next_deadline,
+ tags=schedule.tags,
+ )
+ self.data_store.add_job(job)
+
+ # Update the schedules (and release the scheduler's claim on them)
+ self.data_store.release_schedules(self.identity, schedules)
+
+ # If we received fewer schedules than the maximum amount, sleep
+ # until the next schedule is due or the scheduler is explicitly
+ # woken up
+ wait_time = None
+ if len(schedules) < 100:
+ self._wakeup_deadline = (
+ self.data_store.get_next_schedule_run_time()
)
- else:
- self.logger.debug("Waiting for any due schedules to appear")
+ if self._wakeup_deadline:
+ wait_time = (
+ self._wakeup_deadline - datetime.now(timezone.utc)
+ ).total_seconds()
+ self.logger.debug(
+ "Sleeping %.3f seconds until the next fire time (%s)",
+ wait_time,
+ self._wakeup_deadline,
+ )
+ else:
+ self.logger.debug("Waiting for any due schedules to appear")
- if self._wakeup_event.wait(wait_time):
- self._wakeup_event = threading.Event()
+ if self._wakeup_event.wait(wait_time):
+ self._wakeup_event = threading.Event()
+ else:
+ self.logger.debug(
+ "Processing more schedules on the next iteration"
+ )
+ except BaseException as exc:
+ self._state = RunState.stopped
+ if isinstance(exc, Exception):
+ self.logger.exception("Scheduler crashed")
else:
- self.logger.debug("Processing more schedules on the next iteration")
- except BaseException as exc:
- self._state = RunState.stopped
- self.logger.exception("Scheduler crashed")
- self._events.publish(SchedulerStopped(exception=exc))
- raise
+ self.logger.info(
+ f"Scheduler stopped due to {exc.__class__.__name__}"
+ )
- self._state = RunState.stopped
- self.logger.info("Scheduler stopped")
- self._events.publish(SchedulerStopped())
+ self.event_broker.publish_local(SchedulerStopped(exception=exc))
+ else:
+ self._state = RunState.stopped
+ self.logger.info("Scheduler stopped")
+ self.event_broker.publish_local(SchedulerStopped())
diff --git a/src/apscheduler/util.py b/src/apscheduler/util.py
index 7ff4858..8c69203 100644
--- a/src/apscheduler/util.py
+++ b/src/apscheduler/util.py
@@ -2,9 +2,8 @@
from __future__ import annotations
import sys
-from collections import defaultdict
from datetime import datetime, tzinfo
-from typing import Callable, TypeVar
+from typing import TypeVar
if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
@@ -36,53 +35,3 @@ def timezone_repr(timezone: tzinfo) -> str:
def absolute_datetime_diff(dateval1: datetime, dateval2: datetime) -> float:
return dateval1.timestamp() - dateval2.timestamp()
-
-
-def reentrant(cls: type[T]) -> type[T]:
- """
- Modifies a class so that its ``__enter__`` / ``__exit__`` (or ``__aenter__`` / ``__aexit__``)
- methods track the number of times it has been entered and exited and only actually invoke
- the ``__enter__()`` method on the first entry and ``__exit__()`` on the last exit.
-
- """
-
- def __enter__(self: T) -> T:
- loans[self] += 1
- if loans[self] == 1:
- previous_enter(self)
-
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb) -> None:
- assert loans[self]
- loans[self] -= 1
- if loans[self] == 0:
- return previous_exit(self, exc_type, exc_val, exc_tb)
-
- async def __aenter__(self: T) -> T:
- loans[self] += 1
- if loans[self] == 1:
- await previous_aenter(self)
-
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- assert loans[self]
- loans[self] -= 1
- if loans[self] == 0:
- del loans[self]
- return await previous_aexit(self, exc_type, exc_val, exc_tb)
-
- loans: dict[T, int] = defaultdict(lambda: 0)
- previous_enter: Callable = getattr(cls, "__enter__", None)
- previous_exit: Callable = getattr(cls, "__exit__", None)
- previous_aenter: Callable = getattr(cls, "__aenter__", None)
- previous_aexit: Callable = getattr(cls, "__aexit__", None)
- if previous_enter and previous_exit:
- cls.__enter__ = __enter__
- cls.__exit__ = __exit__
- elif previous_aenter and previous_aexit:
- cls.__aenter__ = __aenter__
- cls.__aexit__ = __aexit__
-
- return cls
diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py
index e4ad44d..65870e5 100644
--- a/src/apscheduler/workers/async_.py
+++ b/src/apscheduler/workers/async_.py
@@ -6,6 +6,7 @@ from contextlib import AsyncExitStack
from datetime import datetime, timezone
from inspect import isawaitable
from logging import Logger, getLogger
+from types import TracebackType
from typing import Callable
from uuid import UUID
@@ -17,11 +18,11 @@ from anyio import (
get_cancelled_exc_class,
move_on_after,
)
-from anyio.abc import CancelScope
+from anyio.abc import CancelScope, TaskGroup
-from ..abc import AsyncDataStore, EventSource, Job
+from ..abc import AsyncDataStore, AsyncEventBroker, Job
from ..context import current_worker, job_info
-from ..converters import as_async_datastore
+from ..converters import as_async_datastore, as_async_eventbroker
from ..enums import JobOutcome, RunState
from ..eventbrokers.async_local import LocalAsyncEventBroker
from ..events import JobAdded, WorkerStarted, WorkerStopped
@@ -34,105 +35,129 @@ class AsyncWorker:
"""Runs jobs locally in a task group."""
data_store: AsyncDataStore = attrs.field(converter=as_async_datastore)
+ event_broker: AsyncEventBroker = attrs.field(
+ converter=as_async_eventbroker, factory=LocalAsyncEventBroker
+ )
max_concurrent_jobs: int = attrs.field(
kw_only=True, validator=positive_integer, default=100
)
identity: str = attrs.field(kw_only=True, default=None)
logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__))
+ # True if a scheduler owns this worker
+ _is_internal: bool = attrs.field(kw_only=True, default=False)
_state: RunState = attrs.field(init=False, default=RunState.stopped)
- _wakeup_event: anyio.Event = attrs.field(init=False, factory=anyio.Event)
+ _wakeup_event: anyio.Event = attrs.field(init=False)
+ _task_group: TaskGroup = attrs.field(init=False)
_acquired_jobs: set[Job] = attrs.field(init=False, factory=set)
- _events: LocalAsyncEventBroker = attrs.field(
- init=False, factory=LocalAsyncEventBroker
- )
_running_jobs: set[UUID] = attrs.field(init=False, factory=set)
- _exit_stack: AsyncExitStack = attrs.field(init=False)
def __attrs_post_init__(self) -> None:
if not self.identity:
self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}"
@property
- def events(self) -> EventSource:
- return self._events
-
- @property
def state(self) -> RunState:
return self._state
- async def __aenter__(self):
- self._state = RunState.starting
- self._wakeup_event = anyio.Event()
- self._exit_stack = AsyncExitStack()
- await self._exit_stack.__aenter__()
- await self._exit_stack.enter_async_context(self._events)
-
- # Initialize the data store and start relaying events to the worker's event broker
- await self._exit_stack.enter_async_context(self.data_store)
- self._exit_stack.enter_context(
- self.data_store.events.subscribe(self._events.publish)
- )
-
- # Wake up the worker if the data store emits a significant job event
- self._exit_stack.enter_context(
- self.data_store.events.subscribe(
- lambda event: self._wakeup_event.set(), {JobAdded}
- )
- )
-
- # Start the actual worker
- task_group = create_task_group()
- await self._exit_stack.enter_async_context(task_group)
- await task_group.start(self.run)
+ async def __aenter__(self) -> AsyncWorker:
+ self._task_group = create_task_group()
+ await self._task_group.__aenter__()
+ await self._task_group.start(self.run_until_stopped)
return self
- async def __aexit__(self, exc_type, exc_val, exc_tb):
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException],
+ exc_val: BaseException,
+ exc_tb: TracebackType,
+ ) -> None:
self._state = RunState.stopping
self._wakeup_event.set()
- await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
+ await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
+ del self._task_group
del self._wakeup_event
- async def run(self, *, task_status=TASK_STATUS_IGNORED) -> None:
- if self._state is not RunState.starting:
+ async def run_until_stopped(self, *, task_status=TASK_STATUS_IGNORED) -> None:
+ if self._state is not RunState.stopped:
raise RuntimeError(
- f"This function cannot be called while the worker is in the "
- f"{self._state} state"
+ f'Cannot start the worker when it is in the "{self._state}" ' f"state"
)
- # Set the current worker
- token = current_worker.set(self)
+ self._state = RunState.starting
+ self._wakeup_event = anyio.Event()
+ async with AsyncExitStack() as exit_stack:
+ if not self._is_internal:
+ # Initialize the event broker
+ await self.event_broker.start()
+ exit_stack.push_async_exit(
+ lambda *exc_info: self.event_broker.stop(
+ force=exc_info[0] is not None
+ )
+ )
- # Signal that the worker has started
- self._state = RunState.started
- task_status.started()
- await self._events.publish(WorkerStarted())
+ # Initialize the data store
+ await self.data_store.start(self.event_broker)
+ exit_stack.push_async_exit(
+ lambda *exc_info: self.data_store.stop(
+ force=exc_info[0] is not None
+ )
+ )
- exception: BaseException | None = None
- try:
- async with create_task_group() as tg:
- while self._state is RunState.started:
- limit = self.max_concurrent_jobs - len(self._running_jobs)
- jobs = await self.data_store.acquire_jobs(self.identity, limit)
- for job in jobs:
- task = await self.data_store.get_task(job.task_id)
- self._running_jobs.add(job.id)
- tg.start_soon(self._run_job, job, task.func)
-
- await self._wakeup_event.wait()
- self._wakeup_event = anyio.Event()
- except get_cancelled_exc_class():
- pass
- except BaseException as exc:
- self.logger.exception("Worker crashed")
- exception = exc
- else:
- self.logger.info("Worker stopped")
- finally:
- current_worker.reset(token)
- self._state = RunState.stopped
- with move_on_after(3, shield=True):
- await self._events.publish(WorkerStopped(exception=exception))
+ # Set the current worker
+ token = current_worker.set(self)
+ exit_stack.callback(current_worker.reset, token)
+
+ # Wake up the worker if the data store emits a significant job event
+ self.event_broker.subscribe(
+ lambda event: self._wakeup_event.set(), {JobAdded}
+ )
+
+ # Signal that the worker has started
+ self._state = RunState.started
+ task_status.started()
+ exception: BaseException | None = None
+ try:
+ await self.event_broker.publish_local(WorkerStarted())
+
+ async with create_task_group() as tg:
+ while self._state is RunState.started:
+ limit = self.max_concurrent_jobs - len(self._running_jobs)
+ jobs = await self.data_store.acquire_jobs(self.identity, limit)
+ for job in jobs:
+ task = await self.data_store.get_task(job.task_id)
+ self._running_jobs.add(job.id)
+ tg.start_soon(self._run_job, job, task.func)
+
+ await self._wakeup_event.wait()
+ self._wakeup_event = anyio.Event()
+ except get_cancelled_exc_class():
+ pass
+ except BaseException as exc:
+ self.logger.exception("Worker crashed")
+ exception = exc
+ else:
+ self.logger.info("Worker stopped")
+ finally:
+ self._state = RunState.stopped
+ with move_on_after(3, shield=True):
+ await self.event_broker.publish_local(
+ WorkerStopped(exception=exception)
+ )
+
+ async def stop(self, *, force: bool = False) -> None:
+ if self._state in (RunState.starting, RunState.started):
+ self._state = RunState.stopping
+ event = anyio.Event()
+ self.event_broker.subscribe(
+ lambda ev: event.set(), {WorkerStopped}, one_shot=True
+ )
+ if force:
+ self._task_group.cancel_scope.cancel()
+ else:
+ self._wakeup_event.set()
+
+ await event.wait()
async def _run_job(self, job: Job, func: Callable) -> None:
try:
diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py
index 2da0045..5b5e6a9 100644
--- a/src/apscheduler/workers/sync.py
+++ b/src/apscheduler/workers/sync.py
@@ -1,19 +1,21 @@
from __future__ import annotations
+import atexit
import os
import platform
import threading
-from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
+from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import ExitStack
from contextvars import copy_context
from datetime import datetime, timezone
from logging import Logger, getLogger
+from types import TracebackType
from typing import Callable
from uuid import UUID
import attrs
-from ..abc import DataStore, EventSource
+from ..abc import DataStore, EventBroker
from ..context import current_worker, job_info
from ..enums import JobOutcome, RunState
from ..eventbrokers.local import LocalEventBroker
@@ -27,111 +29,151 @@ class Worker:
"""Runs jobs locally in a thread pool."""
data_store: DataStore
+ event_broker: EventBroker = attrs.field(factory=LocalEventBroker)
max_concurrent_jobs: int = attrs.field(
kw_only=True, validator=positive_integer, default=20
)
identity: str = attrs.field(kw_only=True, default=None)
logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__))
+ # True if a scheduler owns this worker
+ _is_internal: bool = attrs.field(kw_only=True, default=False)
_state: RunState = attrs.field(init=False, default=RunState.stopped)
- _wakeup_event: threading.Event = attrs.field(init=False)
+ _thread: threading.Thread | None = attrs.field(init=False, default=None)
+ _wakeup_event: threading.Event = attrs.field(init=False, factory=threading.Event)
+ _executor: ThreadPoolExecutor = attrs.field(init=False)
_acquired_jobs: set[Job] = attrs.field(init=False, factory=set)
- _events: LocalEventBroker = attrs.field(init=False, factory=LocalEventBroker)
_running_jobs: set[UUID] = attrs.field(init=False, factory=set)
- _exit_stack: ExitStack = attrs.field(init=False)
- _executor: ThreadPoolExecutor = attrs.field(init=False)
def __attrs_post_init__(self) -> None:
if not self.identity:
self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}"
@property
- def events(self) -> EventSource:
- return self._events
-
- @property
def state(self) -> RunState:
return self._state
def __enter__(self) -> Worker:
- self._state = RunState.starting
- self._wakeup_event = threading.Event()
- self._exit_stack = ExitStack()
- self._exit_stack.__enter__()
- self._exit_stack.enter_context(self._events)
-
- # Initialize the data store and start relaying events to the worker's event broker
- self._exit_stack.enter_context(self.data_store)
- self._exit_stack.enter_context(
- self.data_store.events.subscribe(self._events.publish)
- )
+ self.start_in_background()
+ return self
- # Wake up the worker if the data store emits a significant job event
- self._exit_stack.enter_context(
- self.data_store.events.subscribe(
- lambda event: self._wakeup_event.set(), {JobAdded}
- )
- )
+ def __exit__(
+ self,
+ exc_type: type[BaseException],
+ exc_val: BaseException,
+ exc_tb: TracebackType,
+ ) -> None:
+ self.stop()
- # Start the worker and return when it has signalled readiness or raised an exception
+ def start_in_background(self) -> None:
start_future: Future[None] = Future()
- with self._events.subscribe(start_future.set_result, one_shot=True):
- self._executor = ThreadPoolExecutor(1)
- run_future = self._executor.submit(copy_context().run, self.run)
- wait([start_future, run_future], return_when=FIRST_COMPLETED)
+ self._thread = threading.Thread(
+ target=copy_context().run, args=[self._run, start_future], daemon=True
+ )
+ self._thread.start()
+ try:
+ start_future.result()
+ except BaseException:
+ self._thread = None
+ raise
- if run_future.done():
- run_future.result()
+ atexit.register(self.stop)
- return self
+ def stop(self) -> None:
+ atexit.unregister(self.stop)
+ if self._state is RunState.started:
+ self._state = RunState.stopping
+ self._wakeup_event.set()
- def __exit__(self, exc_type, exc_val, exc_tb):
- self._state = RunState.stopping
- self._wakeup_event.set()
- self._executor.shutdown(wait=exc_type is None)
- self._exit_stack.__exit__(exc_type, exc_val, exc_tb)
- del self._wakeup_event
-
- def run(self) -> None:
- if self._state is not RunState.starting:
- raise RuntimeError(
- f"This function cannot be called while the worker is in the "
- f"{self._state} state"
- )
-
- # Set the current worker
- token = current_worker.set(self)
-
- # Signal that the worker has started
- self._state = RunState.started
- self._events.publish(WorkerStarted())
-
- executor = ThreadPoolExecutor(max_workers=self.max_concurrent_jobs)
- exception: BaseException | None = None
- try:
- while self._state is RunState.started:
- available_slots = self.max_concurrent_jobs - len(self._running_jobs)
- if available_slots:
- jobs = self.data_store.acquire_jobs(self.identity, available_slots)
- for job in jobs:
- task = self.data_store.get_task(job.task_id)
- self._running_jobs.add(job.id)
- executor.submit(
- copy_context().run, self._run_job, job, task.func
+ if threading.current_thread() != self._thread:
+ self._thread.join()
+ self._thread = None
+
+ def run_until_stopped(self) -> None:
+ self._run(None)
+
+ def _run(self, start_future: Future[None] | None) -> None:
+ with ExitStack() as exit_stack:
+ try:
+ if self._state is not RunState.stopped:
+ raise RuntimeError(
+ f'Cannot start the worker when it is in the "{self._state}" '
+ f"state"
+ )
+
+ if not self._is_internal:
+ # Initialize the event broker
+ self.event_broker.start()
+ exit_stack.push(
+ lambda *exc_info: self.event_broker.stop(
+ force=exc_info[0] is not None
)
+ )
- self._wakeup_event.wait()
- self._wakeup_event = threading.Event()
- except BaseException as exc:
- self.logger.exception("Worker crashed")
- exception = exc
- else:
- self.logger.info("Worker stopped")
- finally:
- current_worker.reset(token)
- self._state = RunState.stopped
- self._events.publish(WorkerStopped(exception=exception))
- executor.shutdown()
+ # Initialize the data store
+ self.data_store.start(self.event_broker)
+ exit_stack.push(
+ lambda *exc_info: self.data_store.stop(
+ force=exc_info[0] is not None
+ )
+ )
+
+ # Set the current worker
+ token = current_worker.set(self)
+ exit_stack.callback(current_worker.reset, token)
+
+ # Wake up the worker if the data store emits a significant job event
+ exit_stack.enter_context(
+ self.event_broker.subscribe(
+ lambda event: self._wakeup_event.set(), {JobAdded}
+ )
+ )
+
+ # Initialize the thread pool
+ executor = ThreadPoolExecutor(max_workers=self.max_concurrent_jobs)
+ exit_stack.enter_context(executor)
+
+ # Signal that the worker has started
+ self._state = RunState.started
+ self.event_broker.publish_local(WorkerStarted())
+ except BaseException as exc:
+ if start_future:
+ start_future.set_exception(exc)
+ return
+ else:
+ raise
+ else:
+ if start_future:
+ start_future.set_result(None)
+
+ try:
+ while self._state is RunState.started:
+ available_slots = self.max_concurrent_jobs - len(self._running_jobs)
+ if available_slots:
+ jobs = self.data_store.acquire_jobs(
+ self.identity, available_slots
+ )
+ for job in jobs:
+ task = self.data_store.get_task(job.task_id)
+ self._running_jobs.add(job.id)
+ executor.submit(
+ copy_context().run, self._run_job, job, task.func
+ )
+
+ self._wakeup_event.wait()
+ self._wakeup_event = threading.Event()
+ except BaseException as exc:
+ self._state = RunState.stopped
+ if isinstance(exc, Exception):
+ self.logger.exception("Worker crashed")
+ else:
+ self.logger.info(f"Worker stopped due to {exc.__class__.__name__}")
+
+ self.event_broker.publish_local(WorkerStopped(exception=exc))
+ else:
+ self._state = RunState.stopped
+ self.logger.info("Worker stopped")
+ self.event_broker.publish_local(WorkerStopped())
def _run_job(self, job: Job, func: Callable) -> None:
try: