diff options
-rw-r--r-- | src/apscheduler/context.py | 20 | ||||
-rw-r--r-- | src/apscheduler/schedulers/async_.py | 9 | ||||
-rw-r--r-- | src/apscheduler/schedulers/sync.py | 9 | ||||
-rw-r--r-- | src/apscheduler/structures.py | 17 | ||||
-rw-r--r-- | src/apscheduler/workers/async_.py | 10 | ||||
-rw-r--r-- | src/apscheduler/workers/sync.py | 15 | ||||
-rw-r--r-- | tests/test_schedulers.py | 54 |
7 files changed, 126 insertions, 8 deletions
diff --git a/src/apscheduler/context.py b/src/apscheduler/context.py new file mode 100644 index 0000000..aa9977e --- /dev/null +++ b/src/apscheduler/context.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from contextvars import ContextVar +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .schedulers.async_ import AsyncScheduler + from .schedulers.sync import Scheduler + from .structures import JobInfo + from .workers.async_ import AsyncWorker + from .workers.sync import Worker + +#: The currently running (local) scheduler +current_scheduler: ContextVar[Scheduler | AsyncScheduler | None] = ContextVar('current_scheduler', + default=None) +#: The worker running the current job +current_worker: ContextVar[Worker | AsyncWorker | None] = ContextVar('current_worker', + default=None) +#: Metadata about the current job +job_info: ContextVar[JobInfo] = ContextVar('job_info') diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py index a08c090..a10e772 100644 --- a/src/apscheduler/schedulers/async_.py +++ b/src/apscheduler/schedulers/async_.py @@ -14,6 +14,7 @@ from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_clas from anyio.abc import TaskGroup from ..abc import AsyncDataStore, DataStore, EventSource, Job, Schedule, Trigger +from ..context import current_scheduler from ..datastores.async_adapter import AsyncDataStoreAdapter from ..datastores.memory import MemoryDataStore from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState @@ -80,8 +81,12 @@ class AsyncScheduler: # Start the built-in worker, if configured to do so if self.start_worker: - self._worker = AsyncWorker(self.data_store) - await self._exit_stack.enter_async_context(self._worker) + token = current_scheduler.set(self) + try: + self._worker = AsyncWorker(self.data_store) + await self._exit_stack.enter_async_context(self._worker) + finally: + current_scheduler.reset(token) # Start the worker and return when it has signalled readiness or raised an exception self._task_group = create_task_group() diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py index 905efac..2af4f8e 100644 --- a/src/apscheduler/schedulers/sync.py +++ b/src/apscheduler/schedulers/sync.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Iterable, Mapping, Optional from uuid import UUID, uuid4 from ..abc import DataStore, EventSource, Trigger +from ..context import current_scheduler from ..datastores.memory import MemoryDataStore from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState from ..eventbrokers.local import LocalEventBroker @@ -74,8 +75,12 @@ class Scheduler: # Start the built-in worker, if configured to do so if self.start_worker: - self._worker = Worker(self.data_store) - self._exit_stack.enter_context(self._worker) + token = current_scheduler.set(self) + try: + self._worker = Worker(self.data_store) + self._exit_stack.enter_context(self._worker) + finally: + current_scheduler.reset(token) # Start the scheduler and return when it has signalled readiness or raised an exception start_future: Future[Event] = Future() diff --git a/src/apscheduler/structures.py b/src/apscheduler/structures.py index a1e653b..d49e44e 100644 --- a/src/apscheduler/structures.py +++ b/src/apscheduler/structures.py @@ -124,6 +124,23 @@ class Job: return cls(**marshalled) +@attr.define(kw_only=True) +class JobInfo: + job_id: UUID + task_id: str + schedule_id: Optional[str] + scheduled_fire_time: Optional[datetime] + jitter: timedelta + start_deadline: Optional[datetime] + tags: frozenset[str] + + @classmethod + def from_job(cls, job: Job) -> JobInfo: + return cls(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id, + scheduled_fire_time=job.scheduled_fire_time, jitter=job.jitter, + start_deadline=job.start_deadline, tags=job.tags) + + @attr.define(kw_only=True, frozen=True) class JobResult: job_id: UUID diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py index dfa81dc..e037a05 100644 --- a/src/apscheduler/workers/async_.py +++ b/src/apscheduler/workers/async_.py @@ -14,11 +14,12 @@ from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_clas from anyio.abc import CancelScope from ..abc import AsyncDataStore, DataStore, EventSource, Job +from ..context import current_worker, job_info from ..datastores.async_adapter import AsyncDataStoreAdapter from ..enums import JobOutcome, RunState from ..eventbrokers.async_local import LocalAsyncEventBroker from ..events import JobAdded, WorkerStarted, WorkerStopped -from ..structures import JobResult +from ..structures import JobInfo, JobResult class AsyncWorker: @@ -88,6 +89,9 @@ class AsyncWorker: raise RuntimeError(f'This function cannot be called while the worker is in the ' f'{self._state} state') + # Set the current worker + token = current_worker.set(self) + # Signal that the worker has started self._state = RunState.started task_status.started() @@ -114,6 +118,7 @@ class AsyncWorker: raise + current_worker.reset(token) self._state = RunState.stopped await self._events.publish(WorkerStopped()) @@ -126,6 +131,7 @@ class AsyncWorker: await self.data_store.release_job(self.identity, job.task_id, result) return + token = job_info.set(JobInfo.from_job(job)) try: retval = func(*job.args, **job.kwargs) if isawaitable(retval): @@ -142,5 +148,7 @@ class AsyncWorker: else: result = JobResult(job_id=job.id, outcome=JobOutcome.success, return_value=retval) await self.data_store.release_job(self.identity, job.task_id, result) + finally: + job_info.reset(token) finally: self._running_jobs.remove(job.id) diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py index 48922da..0afa8c7 100644 --- a/src/apscheduler/workers/sync.py +++ b/src/apscheduler/workers/sync.py @@ -5,16 +5,18 @@ import platform import threading from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait from contextlib import ExitStack +from contextvars import copy_context from datetime import datetime, timezone from logging import Logger, getLogger from typing import Callable, Optional from uuid import UUID from ..abc import DataStore, EventSource +from ..context import current_worker, job_info from ..enums import JobOutcome, RunState from ..eventbrokers.local import LocalEventBroker from ..events import JobAdded, WorkerStarted, WorkerStopped -from ..structures import Job, JobResult +from ..structures import Job, JobInfo, JobResult class Worker: @@ -66,7 +68,7 @@ class Worker: start_future: Future[None] = Future() with self._events.subscribe(start_future.set_result, one_shot=True): self._executor = ThreadPoolExecutor(1) - run_future = self._executor.submit(self.run) + run_future = self._executor.submit(copy_context().run, self.run) wait([start_future, run_future], return_when=FIRST_COMPLETED) if run_future.done(): @@ -86,6 +88,9 @@ class Worker: raise RuntimeError(f'This function cannot be called while the worker is in the ' f'{self._state} state') + # Set the current worker + token = current_worker.set(self) + # Signal that the worker has started self._state = RunState.started self._events.publish(WorkerStarted()) @@ -99,7 +104,7 @@ class Worker: for job in jobs: task = self.data_store.get_task(job.task_id) self._running_jobs.add(job.id) - executor.submit(self._run_job, job, task.func) + executor.submit(copy_context().run, self._run_job, job, task.func) self._wakeup_event.wait() self._wakeup_event = threading.Event() @@ -110,6 +115,7 @@ class Worker: raise executor.shutdown() + current_worker.reset(token) self._state = RunState.stopped self._events.publish(WorkerStopped()) @@ -122,6 +128,7 @@ class Worker: self.data_store.release_job(self.identity, job.task_id, result) return + token = job_info.set(JobInfo.from_job(job)) try: retval = func(*job.args, **job.kwargs) except BaseException as exc: @@ -132,5 +139,7 @@ class Worker: else: result = JobResult(job_id=job.id, outcome=JobOutcome.success, return_value=retval) self.data_store.release_job(self.identity, job.task_id, result) + finally: + job_info.reset(token) finally: self._running_jobs.remove(job.id) diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py index 240b7bf..2605535 100644 --- a/tests/test_schedulers.py +++ b/tests/test_schedulers.py @@ -10,12 +10,14 @@ import pytest from anyio import fail_after from pytest_mock import MockerFixture +from apscheduler.context import current_scheduler, current_worker, job_info from apscheduler.enums import JobOutcome from apscheduler.events import ( Event, JobAdded, ScheduleAdded, ScheduleRemoved, SchedulerStarted, SchedulerStopped, TaskAdded) from apscheduler.exceptions import JobLookupError from apscheduler.schedulers.async_ import AsyncScheduler from apscheduler.schedulers.sync import Scheduler +from apscheduler.structures import Job, Task from apscheduler.triggers.date import DateTrigger from apscheduler.triggers.interval import IntervalTrigger @@ -164,6 +166,32 @@ class TestAsyncScheduler: with pytest.raises(RuntimeError, match='failing as requested'): await scheduler.run_job(dummy_async_job, kwargs={'fail': True}) + async def test_contextvars(self) -> None: + def check_contextvars() -> None: + assert current_scheduler.get() is scheduler + assert current_worker.get() is scheduler.worker + info = job_info.get() + assert info.task_id == 'task_id' + assert info.schedule_id == 'foo' + assert info.scheduled_fire_time == scheduled_fire_time + assert info.jitter == timedelta(seconds=2.16) + assert info.start_deadline == start_deadline + assert info.tags == {'foo', 'bar'} + + scheduled_fire_time = datetime.now(timezone.utc) + start_deadline = datetime.now(timezone.utc) + timedelta(seconds=10) + async with AsyncScheduler() as scheduler: + await scheduler.data_store.add_task(Task(id='task_id', func=check_contextvars)) + job = Job(task_id='task_id', schedule_id='foo', + scheduled_fire_time=scheduled_fire_time, jitter=timedelta(seconds=2.16), + start_deadline=start_deadline, tags={'foo', 'bar'}) + await scheduler.data_store.add_job(job) + result = await scheduler.get_job_result(job.id) + if result.outcome is JobOutcome.error: + raise result.exception + else: + assert result.outcome is JobOutcome.success + class TestSyncScheduler: def test_schedule_job(self): @@ -280,3 +308,29 @@ class TestSyncScheduler: with Scheduler() as scheduler: with pytest.raises(RuntimeError, match='failing as requested'): scheduler.run_job(dummy_sync_job, kwargs={'fail': True}) + + def test_contextvars(self) -> None: + def check_contextvars() -> None: + assert current_scheduler.get() is scheduler + assert current_worker.get() is scheduler.worker + info = job_info.get() + assert info.task_id == 'task_id' + assert info.schedule_id == 'foo' + assert info.scheduled_fire_time == scheduled_fire_time + assert info.jitter == timedelta(seconds=2.16) + assert info.start_deadline == start_deadline + assert info.tags == {'foo', 'bar'} + + scheduled_fire_time = datetime.now(timezone.utc) + start_deadline = datetime.now(timezone.utc) + timedelta(seconds=10) + with Scheduler() as scheduler: + scheduler.data_store.add_task(Task(id='task_id', func=check_contextvars)) + job = Job(task_id='task_id', schedule_id='foo', + scheduled_fire_time=scheduled_fire_time, jitter=timedelta(seconds=2.16), + start_deadline=start_deadline, tags={'foo', 'bar'}) + scheduler.data_store.add_job(job) + result = scheduler.get_job_result(job.id) + if result.outcome is JobOutcome.error: + raise result.exception + else: + assert result.outcome is JobOutcome.success |