summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/asyncio/engine.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-06-02 12:23:31 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2021-06-02 16:09:14 -0400
commit97d922663a0350c6ce026ecfbde8010ca1bc0c37 (patch)
tree438b4341441b33cee08b8f01022cd2ff383277f2 /lib/sqlalchemy/ext/asyncio/engine.py
parentf51c56b8dca0569269a69bd85c25fcfed39a3c9e (diff)
downloadsqlalchemy-97d922663a0350c6ce026ecfbde8010ca1bc0c37.tar.gz
Implement proxy back reference system for asyncio
Implemented a new registry architecture that allows the ``Async`` version of an object, like ``AsyncSession``, ``AsyncConnection``, etc., to be locatable given the proxied "sync" object, i.e. ``Session``, ``Connection``. Previously, to the degree such lookup functions were used, an ``Async`` object would be re-created each time, which was less than ideal as the identity and state of the "async" object would not be preserved across calls. From there, new helper functions :func:`_asyncio.async_object_session`, :func:`_asyncio.async_session` as well as a new :class:`_orm.InstanceState` attribute :attr:`_orm.InstanceState.asyncio_session` have been added, which are used to retrieve the original :class:`_asyncio.AsyncSession` associated with an ORM mapped object, a :class:`_orm.Session` associated with an :class:`_asyncio.AsyncSession`, and an :class:`_asyncio.AsyncSession` associated with an :class:`_orm.InstanceState`, respectively. This patch also implements new methods :meth:`_asyncio.AsyncSession.in_nested_transaction`, :meth:`_asyncio.AsyncSession.get_transaction`, :meth:`_asyncio.AsyncSession.get_nested_transaction`. Fixes: #6319 Change-Id: Ia452a7e7ce9bad3ff8846c7dea8d45c839ac9fac
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/engine.py')
-rw-r--r--lib/sqlalchemy/ext/asyncio/engine.py56
1 files changed, 38 insertions, 18 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
index 9cd3cb2f8..8e5c01919 100644
--- a/lib/sqlalchemy/ext/asyncio/engine.py
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -11,6 +11,7 @@ from .result import AsyncResult
from ... import exc
from ... import util
from ...engine import create_engine as _create_engine
+from ...engine.base import NestedTransaction
from ...future import Connection
from ...future import Engine
from ...util.concurrency import greenlet_spawn
@@ -86,7 +87,13 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
def __init__(self, async_engine, sync_connection=None):
self.engine = async_engine
self.sync_engine = async_engine.sync_engine
- self.sync_connection = sync_connection
+ self.sync_connection = self._assign_proxied(sync_connection)
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ return AsyncConnection(
+ AsyncEngine._retrieve_proxy_for_target(target.engine), target
+ )
async def start(self, is_ctxmanager=False):
"""Start this :class:`_asyncio.AsyncConnection` object's context
@@ -95,7 +102,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"""
if self.sync_connection:
raise exc.InvalidRequestError("connection is already started")
- self.sync_connection = await (greenlet_spawn(self.sync_engine.connect))
+ self.sync_connection = self._assign_proxied(
+ await (greenlet_spawn(self.sync_engine.connect))
+ )
return self
@property
@@ -216,7 +225,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
trans = conn.get_transaction()
if trans is not None:
- return AsyncTransaction._from_existing_transaction(self, trans)
+ return AsyncTransaction._retrieve_proxy_for_target(trans)
else:
return None
@@ -236,9 +245,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
trans = conn.get_nested_transaction()
if trans is not None:
- return AsyncTransaction._from_existing_transaction(
- self, trans, True
- )
+ return AsyncTransaction._retrieve_proxy_for_target(trans)
else:
return None
@@ -522,7 +529,11 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
"The asyncio extension requires an async driver to be used. "
f"The loaded {sync_engine.dialect.driver!r} is not async."
)
- self.sync_engine = self._proxied = sync_engine
+ self.sync_engine = self._proxied = self._assign_proxied(sync_engine)
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ return AsyncEngine(target)
def begin(self):
"""Return a context manager which when entered will deliver an
@@ -605,17 +616,24 @@ class AsyncTransaction(ProxyComparable, StartableContext):
__slots__ = ("connection", "sync_transaction", "nested")
def __init__(self, connection, nested=False):
- self.connection = connection
- self.sync_transaction = None
+ self.connection = connection # AsyncConnection
+ self.sync_transaction = None # sqlalchemy.engine.Transaction
self.nested = nested
@classmethod
- def _from_existing_transaction(
- cls, connection, sync_transaction, nested=False
- ):
+ def _regenerate_proxy_for_target(cls, target):
+ sync_connection = target.connection
+ sync_transaction = target
+ nested = isinstance(target, NestedTransaction)
+
+ async_connection = AsyncConnection._retrieve_proxy_for_target(
+ sync_connection
+ )
+ assert async_connection is not None
+
obj = cls.__new__(cls)
- obj.connection = connection
- obj.sync_transaction = sync_transaction
+ obj.connection = async_connection
+ obj.sync_transaction = obj._assign_proxied(sync_transaction)
obj.nested = nested
return obj
@@ -664,10 +682,12 @@ class AsyncTransaction(ProxyComparable, StartableContext):
"""
- self.sync_transaction = await greenlet_spawn(
- self.connection._sync_connection().begin_nested
- if self.nested
- else self.connection._sync_connection().begin
+ self.sync_transaction = self._assign_proxied(
+ await greenlet_spawn(
+ self.connection._sync_connection().begin_nested
+ if self.nested
+ else self.connection._sync_connection().begin
+ )
)
if is_ctxmanager:
self.sync_transaction.__enter__()