diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-12 01:47:31 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-12 01:47:31 +0300 |
commit | a58fca290e0831d377d496a69101e5e3dc4c604e (patch) | |
tree | 8beb7504e7113ff1f01fb610513bb72745fa91ba /src/apscheduler/datastores | |
parent | 59ea7376985ef2c8b8b6b6d6df6b1b3be958480c (diff) | |
download | apscheduler-a58fca290e0831d377d496a69101e5e3dc4c604e.tar.gz |
Refactored event brokers to use exit stacks
Diffstat (limited to 'src/apscheduler/datastores')
-rw-r--r-- | src/apscheduler/datastores/async_adapter.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py index a5cf6a2..736685c 100644 --- a/src/apscheduler/datastores/async_adapter.py +++ b/src/apscheduler/datastores/async_adapter.py @@ -1,6 +1,8 @@ from __future__ import annotations +from contextlib import AsyncExitStack from datetime import datetime +from functools import partial from typing import Iterable, Optional from uuid import UUID @@ -21,23 +23,28 @@ class AsyncDataStoreAdapter(AsyncDataStore): original: DataStore _portal: BlockingPortal = attr.field(init=False) _events: AsyncEventBroker = attr.field(init=False) + _exit_stack: AsyncExitStack = attr.field(init=False) @property def events(self) -> EventSource: return self._events async def __aenter__(self) -> AsyncDataStoreAdapter: + self._exit_stack = AsyncExitStack() + self._portal = BlockingPortal() - await self._portal.__aenter__() + await self._exit_stack.enter_async_context(self._portal) + self._events = AsyncEventBrokerAdapter(self.original.events, self._portal) - await self._events.__aenter__() + await self._exit_stack.enter_async_context(self._events) + await to_thread.run_sync(self.original.__enter__) + self._exit_stack.push_async_exit(partial(to_thread.run_sync, self.original.__exit__)) + return self async def __aexit__(self, exc_type, exc_val, exc_tb): - await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb) - await self._events.__aexit__(exc_type, exc_val, exc_tb) - await self._portal.__aexit__(exc_type, exc_val, exc_tb) + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) async def add_task(self, task: Task) -> None: await to_thread.run_sync(self.original.add_task, task) |