summaryrefslogtreecommitdiff
path: root/src/apscheduler/datastores
diff options
context:
space:
mode:
authorAlex Grönholm <alex.gronholm@nextday.fi>2021-09-12 01:47:31 +0300
committerAlex Grönholm <alex.gronholm@nextday.fi>2021-09-12 01:47:31 +0300
commita58fca290e0831d377d496a69101e5e3dc4c604e (patch)
tree8beb7504e7113ff1f01fb610513bb72745fa91ba /src/apscheduler/datastores
parent59ea7376985ef2c8b8b6b6d6df6b1b3be958480c (diff)
downloadapscheduler-a58fca290e0831d377d496a69101e5e3dc4c604e.tar.gz
Refactored event brokers to use exit stacks
Diffstat (limited to 'src/apscheduler/datastores')
-rw-r--r--src/apscheduler/datastores/async_adapter.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py
index a5cf6a2..736685c 100644
--- a/src/apscheduler/datastores/async_adapter.py
+++ b/src/apscheduler/datastores/async_adapter.py
@@ -1,6 +1,8 @@
from __future__ import annotations
+from contextlib import AsyncExitStack
from datetime import datetime
+from functools import partial
from typing import Iterable, Optional
from uuid import UUID
@@ -21,23 +23,28 @@ class AsyncDataStoreAdapter(AsyncDataStore):
original: DataStore
_portal: BlockingPortal = attr.field(init=False)
_events: AsyncEventBroker = attr.field(init=False)
+ _exit_stack: AsyncExitStack = attr.field(init=False)
@property
def events(self) -> EventSource:
return self._events
async def __aenter__(self) -> AsyncDataStoreAdapter:
+ self._exit_stack = AsyncExitStack()
+
self._portal = BlockingPortal()
- await self._portal.__aenter__()
+ await self._exit_stack.enter_async_context(self._portal)
+
self._events = AsyncEventBrokerAdapter(self.original.events, self._portal)
- await self._events.__aenter__()
+ await self._exit_stack.enter_async_context(self._events)
+
await to_thread.run_sync(self.original.__enter__)
+ self._exit_stack.push_async_exit(partial(to_thread.run_sync, self.original.__exit__))
+
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
- await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb)
- await self._events.__aexit__(exc_type, exc_val, exc_tb)
- await self._portal.__aexit__(exc_type, exc_val, exc_tb)
+ await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
async def add_task(self, task: Task) -> None:
await to_thread.run_sync(self.original.add_task, task)