diff options
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/engine.py')
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/engine.py | 56 |
1 files changed, 38 insertions, 18 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 9cd3cb2f8..8e5c01919 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -11,6 +11,7 @@ from .result import AsyncResult from ... import exc from ... import util from ...engine import create_engine as _create_engine +from ...engine.base import NestedTransaction from ...future import Connection from ...future import Engine from ...util.concurrency import greenlet_spawn @@ -86,7 +87,13 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): def __init__(self, async_engine, sync_connection=None): self.engine = async_engine self.sync_engine = async_engine.sync_engine - self.sync_connection = sync_connection + self.sync_connection = self._assign_proxied(sync_connection) + + @classmethod + def _regenerate_proxy_for_target(cls, target): + return AsyncConnection( + AsyncEngine._retrieve_proxy_for_target(target.engine), target + ) async def start(self, is_ctxmanager=False): """Start this :class:`_asyncio.AsyncConnection` object's context @@ -95,7 +102,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): """ if self.sync_connection: raise exc.InvalidRequestError("connection is already started") - self.sync_connection = await (greenlet_spawn(self.sync_engine.connect)) + self.sync_connection = self._assign_proxied( + await (greenlet_spawn(self.sync_engine.connect)) + ) return self @property @@ -216,7 +225,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): trans = conn.get_transaction() if trans is not None: - return AsyncTransaction._from_existing_transaction(self, trans) + return AsyncTransaction._retrieve_proxy_for_target(trans) else: return None @@ -236,9 +245,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): trans = conn.get_nested_transaction() if trans is not None: - return AsyncTransaction._from_existing_transaction( - self, trans, True - ) + return AsyncTransaction._retrieve_proxy_for_target(trans) else: return None @@ -522,7 +529,11 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): "The asyncio extension requires an async driver to be used. " f"The loaded {sync_engine.dialect.driver!r} is not async." ) - self.sync_engine = self._proxied = sync_engine + self.sync_engine = self._proxied = self._assign_proxied(sync_engine) + + @classmethod + def _regenerate_proxy_for_target(cls, target): + return AsyncEngine(target) def begin(self): """Return a context manager which when entered will deliver an @@ -605,17 +616,24 @@ class AsyncTransaction(ProxyComparable, StartableContext): __slots__ = ("connection", "sync_transaction", "nested") def __init__(self, connection, nested=False): - self.connection = connection - self.sync_transaction = None + self.connection = connection # AsyncConnection + self.sync_transaction = None # sqlalchemy.engine.Transaction self.nested = nested @classmethod - def _from_existing_transaction( - cls, connection, sync_transaction, nested=False - ): + def _regenerate_proxy_for_target(cls, target): + sync_connection = target.connection + sync_transaction = target + nested = isinstance(target, NestedTransaction) + + async_connection = AsyncConnection._retrieve_proxy_for_target( + sync_connection + ) + assert async_connection is not None + obj = cls.__new__(cls) - obj.connection = connection - obj.sync_transaction = sync_transaction + obj.connection = async_connection + obj.sync_transaction = obj._assign_proxied(sync_transaction) obj.nested = nested return obj @@ -664,10 +682,12 @@ class AsyncTransaction(ProxyComparable, StartableContext): """ - self.sync_transaction = await greenlet_spawn( - self.connection._sync_connection().begin_nested - if self.nested - else self.connection._sync_connection().begin + self.sync_transaction = self._assign_proxied( + await greenlet_spawn( + self.connection._sync_connection().begin_nested + if self.nested + else self.connection._sync_connection().begin + ) ) if is_ctxmanager: self.sync_transaction.__enter__() |