from __future__ import annotations import os import platform from contextlib import AsyncExitStack from datetime import datetime, timezone from inspect import isawaitable from logging import Logger, getLogger from types import TracebackType from typing import Callable from uuid import UUID import anyio import attrs from anyio import ( TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_class, move_on_after, ) from anyio.abc import CancelScope, TaskGroup from .._converters import as_async_datastore, as_async_eventbroker from ..abc import AsyncDataStore, AsyncEventBroker, Job from ..context import current_worker, job_info from ..enums import JobOutcome, RunState from ..eventbrokers.async_local import LocalAsyncEventBroker from ..events import JobAdded, WorkerStarted, WorkerStopped from ..structures import JobInfo, JobResult from ..validators import positive_integer @attrs.define(eq=False) class AsyncWorker: """Runs jobs locally in a task group.""" data_store: AsyncDataStore = attrs.field(converter=as_async_datastore) event_broker: AsyncEventBroker = attrs.field( converter=as_async_eventbroker, factory=LocalAsyncEventBroker ) max_concurrent_jobs: int = attrs.field( kw_only=True, validator=positive_integer, default=100 ) identity: str = attrs.field(kw_only=True, default=None) logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__)) # True if a scheduler owns this worker _is_internal: bool = attrs.field(kw_only=True, default=False) _state: RunState = attrs.field(init=False, default=RunState.stopped) _wakeup_event: anyio.Event = attrs.field(init=False) _task_group: TaskGroup = attrs.field(init=False) _acquired_jobs: set[Job] = attrs.field(init=False, factory=set) _running_jobs: set[UUID] = attrs.field(init=False, factory=set) def __attrs_post_init__(self) -> None: if not self.identity: self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}" async def __aenter__(self) -> AsyncWorker: self._task_group = create_task_group() await self._task_group.__aenter__() await self._task_group.start(self.run_until_stopped) return self async def __aexit__( self, exc_type: type[BaseException], exc_val: BaseException, exc_tb: TracebackType, ) -> None: self._state = RunState.stopping self._wakeup_event.set() await self._task_group.__aexit__(exc_type, exc_val, exc_tb) del self._task_group del self._wakeup_event @property def state(self) -> RunState: """The current running state of the worker.""" return self._state async def run_until_stopped(self, *, task_status=TASK_STATUS_IGNORED) -> None: """Run the worker until it is explicitly stopped.""" if self._state is not RunState.stopped: raise RuntimeError( f'Cannot start the worker when it is in the "{self._state}" ' f"state" ) self._state = RunState.starting self._wakeup_event = anyio.Event() async with AsyncExitStack() as exit_stack: if not self._is_internal: # Initialize the event broker await self.event_broker.start() exit_stack.push_async_exit( lambda *exc_info: self.event_broker.stop( force=exc_info[0] is not None ) ) # Initialize the data store await self.data_store.start(self.event_broker) exit_stack.push_async_exit( lambda *exc_info: self.data_store.stop( force=exc_info[0] is not None ) ) # Set the current worker token = current_worker.set(self) exit_stack.callback(current_worker.reset, token) # Wake up the worker if the data store emits a significant job event self.event_broker.subscribe( lambda event: self._wakeup_event.set(), {JobAdded} ) # Signal that the worker has started self._state = RunState.started task_status.started() exception: BaseException | None = None try: await self.event_broker.publish_local(WorkerStarted()) async with create_task_group() as tg: while self._state is RunState.started: limit = self.max_concurrent_jobs - len(self._running_jobs) jobs = await self.data_store.acquire_jobs(self.identity, limit) for job in jobs: task = await self.data_store.get_task(job.task_id) self._running_jobs.add(job.id) tg.start_soon(self._run_job, job, task.func) await self._wakeup_event.wait() self._wakeup_event = anyio.Event() except get_cancelled_exc_class(): pass except BaseException as exc: self.logger.exception("Worker crashed") exception = exc else: self.logger.info("Worker stopped") finally: self._state = RunState.stopped with move_on_after(3, shield=True): await self.event_broker.publish_local( WorkerStopped(exception=exception) ) async def stop(self, *, force: bool = False) -> None: """ Signal the worker that it should stop running jobs. This method does not wait for the worker to actually stop. """ if self._state in (RunState.starting, RunState.started): self._state = RunState.stopping event = anyio.Event() self.event_broker.subscribe( lambda ev: event.set(), {WorkerStopped}, one_shot=True ) if force: self._task_group.cancel_scope.cancel() else: self._wakeup_event.set() await event.wait() async def _run_job(self, job: Job, func: Callable) -> None: try: # Check if the job started before the deadline start_time = datetime.now(timezone.utc) if job.start_deadline is not None and start_time > job.start_deadline: result = JobResult( job_id=job.id, outcome=JobOutcome.missed_start_deadline ) await self.data_store.release_job(self.identity, job.task_id, result) return token = job_info.set(JobInfo.from_job(job)) try: retval = func(*job.args, **job.kwargs) if isawaitable(retval): retval = await retval except get_cancelled_exc_class(): with CancelScope(shield=True): result = JobResult(job_id=job.id, outcome=JobOutcome.cancelled) await self.data_store.release_job( self.identity, job.task_id, result ) except BaseException as exc: result = JobResult( job_id=job.id, outcome=JobOutcome.error, exception=exc ) await self.data_store.release_job(self.identity, job.task_id, result) if not isinstance(exc, Exception): raise else: result = JobResult( job_id=job.id, outcome=JobOutcome.success, return_value=retval ) await self.data_store.release_job(self.identity, job.task_id, result) finally: job_info.reset(token) finally: self._running_jobs.remove(job.id)