diff options
-rw-r--r-- | src/apscheduler/datastores/async_adapter.py | 15 | ||||
-rw-r--r-- | tests/test_datastores.py | 31 | ||||
-rw-r--r-- | tests/test_eventbrokers.py | 30 |
3 files changed, 61 insertions, 15 deletions
diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py index 99accd4..6ae51b9 100644 --- a/src/apscheduler/datastores/async_adapter.py +++ b/src/apscheduler/datastores/async_adapter.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from datetime import datetime from typing import Iterable from uuid import UUID @@ -31,12 +32,18 @@ class AsyncDataStoreAdapter(BaseAsyncDataStore): else: sync_event_broker = SyncEventBrokerAdapter(event_broker, self._portal) - await to_thread.run_sync(lambda: self.original.start(sync_event_broker)) + try: + await to_thread.run_sync(lambda: self.original.start(sync_event_broker)) + except BaseException: + await self._portal.__aexit__(*sys.exc_info()) + raise async def stop(self, *, force: bool = False) -> None: - await to_thread.run_sync(lambda: self.original.stop(force=force)) - await self._portal.__aexit__(None, None, None) - await super().stop(force=force) + try: + await to_thread.run_sync(lambda: self.original.stop(force=force)) + finally: + await self._portal.__aexit__(None, None, None) + await super().stop(force=force) async def add_task(self, task: Task) -> None: await to_thread.run_sync(self.original.add_task, task) diff --git a/tests/test_datastores.py b/tests/test_datastores.py index ea5416e..0b594cf 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -10,6 +10,7 @@ from typing import Any, AsyncGenerator, cast import anyio import pytest from _pytest.fixtures import SubRequest +from anyio import CancelScope from freezegun.api import FrozenDateTimeFactory from pytest_lazyfixture import lazy_fixture @@ -614,13 +615,17 @@ class TestAsyncDataStores: ), ] ) - async def datastore( + async def raw_datastore( self, request: SubRequest, event_broker: AsyncEventBroker + ) -> AsyncDataStore: + return cast(AsyncDataStore, request.param) + + async def datastore( + self, raw_datastore: AsyncDataStore, event_broker: AsyncEventBroker ) -> AsyncGenerator[AsyncDataStore, Any]: - datastore = cast(AsyncDataStore, request.param) - await datastore.start(event_broker) - yield datastore - await datastore.stop() + await raw_datastore.start(event_broker) + yield raw_datastore + await raw_datastore.stop() async def test_add_replace_task(self, datastore: AsyncDataStore) -> None: import math @@ -1012,3 +1017,19 @@ class TestAsyncDataStores: ) acquired_jobs = await datastore.acquire_jobs("worker1", 3) assert [job.id for job in acquired_jobs] == [jobs[2].id] + + async def test_cancel_start( + self, raw_datastore: AsyncDataStore, event_broker: AsyncEventBroker + ) -> None: + with CancelScope() as scope: + scope.cancel() + await raw_datastore.start(event_broker) + await raw_datastore.stop() + + async def test_cancel_stop( + self, raw_datastore: AsyncDataStore, event_broker: AsyncEventBroker + ) -> None: + with CancelScope() as scope: + await raw_datastore.start(event_broker) + scope.cancel() + await raw_datastore.stop() diff --git a/tests/test_eventbrokers.py b/tests/test_eventbrokers.py index ffa4be7..43bfddd 100644 --- a/tests/test_eventbrokers.py +++ b/tests/test_eventbrokers.py @@ -4,12 +4,12 @@ from collections.abc import AsyncGenerator, Generator from concurrent.futures import Future from datetime import datetime, timezone from queue import Empty, Queue -from typing import Any +from typing import Any, cast import pytest from _pytest.fixtures import SubRequest from _pytest.logging import LogCaptureFixture -from anyio import create_memory_object_stream, fail_after +from anyio import CancelScope, create_memory_object_stream, fail_after from pytest_lazyfixture import lazy_fixture from apscheduler.abc import AsyncEventBroker, EventBroker, Serializer @@ -90,10 +90,16 @@ def broker(request: SubRequest) -> Generator[EventBroker, Any, None]: ), ] ) -async def async_broker(request: SubRequest) -> AsyncGenerator[AsyncEventBroker, Any]: - await request.param.start() - yield request.param - await request.param.stop() +async def raw_async_broker(request: SubRequest) -> AsyncEventBroker: + return cast(AsyncEventBroker, request.param) + + +async def async_broker( + raw_async_broker: AsyncEventBroker, +) -> AsyncGenerator[AsyncEventBroker, Any]: + await raw_async_broker.start() + yield raw_async_broker + await raw_async_broker.stop() class TestEventBroker: @@ -252,3 +258,15 @@ class TestAsyncEventBroker: received_event = await receive.receive() assert received_event.timestamp == timestamp assert "Error delivering Event" in caplog.text + + async def test_cancel_start(self, raw_async_broker: AsyncEventBroker) -> None: + with CancelScope() as scope: + scope.cancel() + await raw_async_broker.start() + await raw_async_broker.stop() + + async def test_cancel_stop(self, raw_async_broker: AsyncEventBroker) -> None: + with CancelScope() as scope: + await raw_async_broker.start() + scope.cancel() + await raw_async_broker.stop() |