Diffstat (limited to 'src/apscheduler/workers/sync.py')
-rw-r--r-- | src/apscheduler/workers/sync.py | 202
1 file changed, 122 insertions(+), 80 deletions(-)
diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py
index 2da0045..5b5e6a9 100644
--- a/src/apscheduler/workers/sync.py
+++ b/src/apscheduler/workers/sync.py
@@ -1,19 +1,21 @@
 from __future__ import annotations
 
+import atexit
 import os
 import platform
 import threading
-from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
+from concurrent.futures import Future, ThreadPoolExecutor
 from contextlib import ExitStack
 from contextvars import copy_context
 from datetime import datetime, timezone
 from logging import Logger, getLogger
+from types import TracebackType
 from typing import Callable
 from uuid import UUID
 
 import attrs
 
-from ..abc import DataStore, EventSource
+from ..abc import DataStore, EventBroker
 from ..context import current_worker, job_info
 from ..enums import JobOutcome, RunState
 from ..eventbrokers.local import LocalEventBroker
@@ -27,111 +29,151 @@ class Worker:
     """Runs jobs locally in a thread pool."""
 
     data_store: DataStore
+    event_broker: EventBroker = attrs.field(factory=LocalEventBroker)
     max_concurrent_jobs: int = attrs.field(
         kw_only=True, validator=positive_integer, default=20
     )
     identity: str = attrs.field(kw_only=True, default=None)
     logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__))
 
+    # True if a scheduler owns this worker
+    _is_internal: bool = attrs.field(kw_only=True, default=False)
+
     _state: RunState = attrs.field(init=False, default=RunState.stopped)
-    _wakeup_event: threading.Event = attrs.field(init=False)
+    _thread: threading.Thread | None = attrs.field(init=False, default=None)
+    _wakeup_event: threading.Event = attrs.field(init=False, factory=threading.Event)
+    _executor: ThreadPoolExecutor = attrs.field(init=False)
     _acquired_jobs: set[Job] = attrs.field(init=False, factory=set)
-    _events: LocalEventBroker = attrs.field(init=False, factory=LocalEventBroker)
     _running_jobs: set[UUID] = attrs.field(init=False, factory=set)
-    _exit_stack: ExitStack = attrs.field(init=False)
-    _executor: ThreadPoolExecutor = attrs.field(init=False)
 
     def __attrs_post_init__(self) -> None:
         if not self.identity:
             self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}"
 
     @property
-    def events(self) -> EventSource:
-        return self._events
-
-    @property
     def state(self) -> RunState:
         return self._state
 
     def __enter__(self) -> Worker:
-        self._state = RunState.starting
-        self._wakeup_event = threading.Event()
-        self._exit_stack = ExitStack()
-        self._exit_stack.__enter__()
-        self._exit_stack.enter_context(self._events)
-
-        # Initialize the data store and start relaying events to the worker's event broker
-        self._exit_stack.enter_context(self.data_store)
-        self._exit_stack.enter_context(
-            self.data_store.events.subscribe(self._events.publish)
-        )
+        self.start_in_background()
+        return self
 
-        # Wake up the worker if the data store emits a significant job event
-        self._exit_stack.enter_context(
-            self.data_store.events.subscribe(
-                lambda event: self._wakeup_event.set(), {JobAdded}
-            )
-        )
+    def __exit__(
+        self,
+        exc_type: type[BaseException],
+        exc_val: BaseException,
+        exc_tb: TracebackType,
+    ) -> None:
+        self.stop()
 
-        # Start the worker and return when it has signalled readiness or raised an exception
+    def start_in_background(self) -> None:
         start_future: Future[None] = Future()
-        with self._events.subscribe(start_future.set_result, one_shot=True):
-            self._executor = ThreadPoolExecutor(1)
-            run_future = self._executor.submit(copy_context().run, self.run)
-            wait([start_future, run_future], return_when=FIRST_COMPLETED)
+        self._thread = threading.Thread(
+            target=copy_context().run, args=[self._run, start_future], daemon=True
+        )
+        self._thread.start()
+        try:
+            start_future.result()
+        except BaseException:
+            self._thread = None
+            raise
 
-        if run_future.done():
-            run_future.result()
+        atexit.register(self.stop)
 
-        return self
+    def stop(self) -> None:
+        atexit.unregister(self.stop)
+        if self._state is RunState.started:
+            self._state = RunState.stopping
+            self._wakeup_event.set()
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self._state = RunState.stopping
-        self._wakeup_event.set()
-        self._executor.shutdown(wait=exc_type is None)
-        self._exit_stack.__exit__(exc_type, exc_val, exc_tb)
-        del self._wakeup_event
-
-    def run(self) -> None:
-        if self._state is not RunState.starting:
-            raise RuntimeError(
-                f"This function cannot be called while the worker is in the "
-                f"{self._state} state"
-            )
-
-        # Set the current worker
-        token = current_worker.set(self)
-
-        # Signal that the worker has started
-        self._state = RunState.started
-        self._events.publish(WorkerStarted())
-
-        executor = ThreadPoolExecutor(max_workers=self.max_concurrent_jobs)
-        exception: BaseException | None = None
-        try:
-            while self._state is RunState.started:
-                available_slots = self.max_concurrent_jobs - len(self._running_jobs)
-                if available_slots:
-                    jobs = self.data_store.acquire_jobs(self.identity, available_slots)
-                    for job in jobs:
-                        task = self.data_store.get_task(job.task_id)
-                        self._running_jobs.add(job.id)
-                        executor.submit(
-                            copy_context().run, self._run_job, job, task.func
+            if threading.current_thread() != self._thread:
+                self._thread.join()
+                self._thread = None
+
+    def run_until_stopped(self) -> None:
+        self._run(None)
+
+    def _run(self, start_future: Future[None] | None) -> None:
+        with ExitStack() as exit_stack:
+            try:
+                if self._state is not RunState.stopped:
+                    raise RuntimeError(
+                        f'Cannot start the worker when it is in the "{self._state}" '
+                        f"state"
+                    )
+
+                if not self._is_internal:
+                    # Initialize the event broker
+                    self.event_broker.start()
+                    exit_stack.push(
+                        lambda *exc_info: self.event_broker.stop(
+                            force=exc_info[0] is not None
                         )
+                    )
 
-                self._wakeup_event.wait()
-                self._wakeup_event = threading.Event()
-        except BaseException as exc:
-            self.logger.exception("Worker crashed")
-            exception = exc
-        else:
-            self.logger.info("Worker stopped")
-        finally:
-            current_worker.reset(token)
-            self._state = RunState.stopped
-            self._events.publish(WorkerStopped(exception=exception))
-            executor.shutdown()
+                    # Initialize the data store
+                    self.data_store.start(self.event_broker)
+                    exit_stack.push(
+                        lambda *exc_info: self.data_store.stop(
+                            force=exc_info[0] is not None
+                        )
+                    )
+
+                # Set the current worker
+                token = current_worker.set(self)
+                exit_stack.callback(current_worker.reset, token)
+
+                # Wake up the worker if the data store emits a significant job event
+                exit_stack.enter_context(
+                    self.event_broker.subscribe(
+                        lambda event: self._wakeup_event.set(), {JobAdded}
+                    )
+                )
+
+                # Initialize the thread pool
+                executor = ThreadPoolExecutor(max_workers=self.max_concurrent_jobs)
+                exit_stack.enter_context(executor)
+
+                # Signal that the worker has started
+                self._state = RunState.started
+                self.event_broker.publish_local(WorkerStarted())
+            except BaseException as exc:
+                if start_future:
+                    start_future.set_exception(exc)
+                    return
+                else:
+                    raise
+            else:
+                if start_future:
+                    start_future.set_result(None)
+
+            try:
+                while self._state is RunState.started:
+                    available_slots = self.max_concurrent_jobs - len(self._running_jobs)
+                    if available_slots:
+                        jobs = self.data_store.acquire_jobs(
+                            self.identity, available_slots
+                        )
+                        for job in jobs:
+                            task = self.data_store.get_task(job.task_id)
+                            self._running_jobs.add(job.id)
+                            executor.submit(
+                                copy_context().run, self._run_job, job, task.func
+                            )
+
+                    self._wakeup_event.wait()
+                    self._wakeup_event = threading.Event()
+            except BaseException as exc:
+                self._state = RunState.stopped
+                if isinstance(exc, Exception):
+                    self.logger.exception("Worker crashed")
+                else:
+                    self.logger.info(f"Worker stopped due to {exc.__class__.__name__}")
+
+                self.event_broker.publish_local(WorkerStopped(exception=exc))
+            else:
+                self._state = RunState.stopped
+                self.logger.info("Worker stopped")
+                self.event_broker.publish_local(WorkerStopped())
 
     def _run_job(self, job: Job, func: Callable) -> None:
         try:
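Taken together, the diff replaces the old blocking __enter__/run() pair with three entry points: __enter__ now delegates to start_in_background(), __exit__ calls stop(), and run_until_stopped() runs the worker on the calling thread. The snippet below is a minimal usage sketch of that lifecycle, not part of the diff; the MemoryDataStore import path is an assumption for illustration (any DataStore would do), and the event broker falls back to the LocalEventBroker factory default shown above.

# Usage sketch (hypothetical; assumes an in-memory store at
# apscheduler.datastores.memory.MemoryDataStore).
from apscheduler.datastores.memory import MemoryDataStore
from apscheduler.workers.sync import Worker

# Background mode: __enter__ calls start_in_background(), which spawns a
# daemon thread, blocks only until start_future resolves (i.e. startup
# succeeded or raised), and registers stop() with atexit.
with Worker(MemoryDataStore()) as worker:
    pass  # jobs acquired from the data store run in the worker's thread pool

# Foreground mode: run_until_stopped() drives _run() on the calling thread
# and blocks until stop() is called from another thread or a signal handler.
worker = Worker(MemoryDataStore())
worker.run_until_stopped()

The atexit registration matters because the worker thread is daemonic: without it, interpreter shutdown could kill the thread mid-job, whereas stop() flips the run state, wakes the loop, and joins the thread so the ExitStack can close the executor and data store cleanly.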