summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/asyncio
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio')
-rw-r--r--lib/sqlalchemy/ext/asyncio/__init__.py29
-rw-r--r--lib/sqlalchemy/ext/asyncio/base.py118
-rw-r--r--lib/sqlalchemy/ext/asyncio/engine.py401
-rw-r--r--lib/sqlalchemy/ext/asyncio/events.py44
-rw-r--r--lib/sqlalchemy/ext/asyncio/result.py180
-rw-r--r--lib/sqlalchemy/ext/asyncio/scoping.py241
-rw-r--r--lib/sqlalchemy/ext/asyncio/session.py406
7 files changed, 956 insertions, 463 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py
index 15b2cb015..dfe89a154 100644
--- a/lib/sqlalchemy/ext/asyncio/__init__.py
+++ b/lib/sqlalchemy/ext/asyncio/__init__.py
@@ -5,18 +5,17 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-from .engine import async_engine_from_config
-from .engine import AsyncConnection
-from .engine import AsyncEngine
-from .engine import AsyncTransaction
-from .engine import create_async_engine
-from .events import AsyncConnectionEvents
-from .events import AsyncSessionEvents
-from .result import AsyncMappingResult
-from .result import AsyncResult
-from .result import AsyncScalarResult
-from .scoping import async_scoped_session
-from .session import async_object_session
-from .session import async_session
-from .session import AsyncSession
-from .session import AsyncSessionTransaction
+from .engine import async_engine_from_config as async_engine_from_config
+from .engine import AsyncConnection as AsyncConnection
+from .engine import AsyncEngine as AsyncEngine
+from .engine import AsyncTransaction as AsyncTransaction
+from .engine import create_async_engine as create_async_engine
+from .result import AsyncMappingResult as AsyncMappingResult
+from .result import AsyncResult as AsyncResult
+from .result import AsyncScalarResult as AsyncScalarResult
+from .scoping import async_scoped_session as async_scoped_session
+from .session import async_object_session as async_object_session
+from .session import async_session as async_session
+from .session import async_sessionmaker as async_sessionmaker
+from .session import AsyncSession as AsyncSession
+from .session import AsyncSessionTransaction as AsyncSessionTransaction
diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py
index 3f77f5500..7fdd2d7e0 100644
--- a/lib/sqlalchemy/ext/asyncio/base.py
+++ b/lib/sqlalchemy/ext/asyncio/base.py
@@ -1,36 +1,103 @@
+# ext/asyncio/base.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import annotations
+
import abc
import functools
+from typing import Any
+from typing import ClassVar
+from typing import Dict
+from typing import Generic
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Type
+from typing import TypeVar
import weakref
from . import exc as async_exc
+from ... import util
+from ...util.typing import Literal
+
+_T = TypeVar("_T", bound=Any)
+
+
+_PT = TypeVar("_PT", bound=Any)
-class ReversibleProxy:
- # weakref.ref(async proxy object) -> weakref.ref(sync proxied object)
- _proxy_objects = {}
+SelfReversibleProxy = TypeVar(
+ "SelfReversibleProxy", bound="ReversibleProxy[Any]"
+)
+
+
+class ReversibleProxy(Generic[_PT]):
+ _proxy_objects: ClassVar[
+ Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]]
+ ] = {}
__slots__ = ("__weakref__",)
- def _assign_proxied(self, target):
+ @overload
+ def _assign_proxied(self, target: _PT) -> _PT:
+ ...
+
+ @overload
+ def _assign_proxied(self, target: None) -> None:
+ ...
+
+ def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]:
if target is not None:
- target_ref = weakref.ref(target, ReversibleProxy._target_gced)
+ target_ref: weakref.ref[_PT] = weakref.ref(
+ target, ReversibleProxy._target_gced
+ )
proxy_ref = weakref.ref(
self,
- functools.partial(ReversibleProxy._target_gced, target_ref),
+ functools.partial( # type: ignore
+ ReversibleProxy._target_gced, target_ref
+ ),
)
ReversibleProxy._proxy_objects[target_ref] = proxy_ref
return target
@classmethod
- def _target_gced(cls, ref, proxy_ref=None):
+ def _target_gced(
+ cls: Type[SelfReversibleProxy],
+ ref: weakref.ref[_PT],
+ proxy_ref: Optional[weakref.ref[SelfReversibleProxy]] = None,
+ ) -> None:
cls._proxy_objects.pop(ref, None)
@classmethod
- def _regenerate_proxy_for_target(cls, target):
+ def _regenerate_proxy_for_target(
+ cls: Type[SelfReversibleProxy], target: _PT
+ ) -> SelfReversibleProxy:
raise NotImplementedError()
+ @overload
@classmethod
- def _retrieve_proxy_for_target(cls, target, regenerate=True):
+ def _retrieve_proxy_for_target(
+ cls: Type[SelfReversibleProxy],
+ target: _PT,
+ regenerate: Literal[True] = ...,
+ ) -> SelfReversibleProxy:
+ ...
+
+ @overload
+ @classmethod
+ def _retrieve_proxy_for_target(
+ cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True
+ ) -> Optional[SelfReversibleProxy]:
+ ...
+
+ @classmethod
+ def _retrieve_proxy_for_target(
+ cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True
+ ) -> Optional[SelfReversibleProxy]:
try:
proxy_ref = cls._proxy_objects[weakref.ref(target)]
except KeyError:
@@ -38,7 +105,7 @@ class ReversibleProxy:
else:
proxy = proxy_ref()
if proxy is not None:
- return proxy
+ return proxy # type: ignore
if regenerate:
return cls._regenerate_proxy_for_target(target)
@@ -46,43 +113,54 @@ class ReversibleProxy:
return None
+SelfStartableContext = TypeVar(
+ "SelfStartableContext", bound="StartableContext"
+)
+
+
class StartableContext(abc.ABC):
__slots__ = ()
@abc.abstractmethod
- async def start(self, is_ctxmanager=False):
- pass
+ async def start(
+ self: SelfStartableContext, is_ctxmanager: bool = False
+ ) -> Any:
+ raise NotImplementedError()
- def __await__(self):
+ def __await__(self) -> Any:
return self.start().__await__()
- async def __aenter__(self):
+ async def __aenter__(self: SelfStartableContext) -> Any:
return await self.start(is_ctxmanager=True)
@abc.abstractmethod
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
pass
- def _raise_for_not_started(self):
+ def _raise_for_not_started(self) -> NoReturn:
raise async_exc.AsyncContextNotStarted(
"%s context has not been started and object has not been awaited."
% (self.__class__.__name__)
)
-class ProxyComparable(ReversibleProxy):
+class ProxyComparable(ReversibleProxy[_PT]):
__slots__ = ()
- def __hash__(self):
+ @util.ro_non_memoized_property
+ def _proxied(self) -> _PT:
+ raise NotImplementedError()
+
+ def __hash__(self) -> int:
return id(self)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return (
isinstance(other, self.__class__)
and self._proxied == other._proxied
)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return (
not isinstance(other, self.__class__)
or self._proxied != other._proxied
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
index 3b54405c1..bb51a4d22 100644
--- a/lib/sqlalchemy/ext/asyncio/engine.py
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -7,23 +7,56 @@
from __future__ import annotations
from typing import Any
+from typing import Dict
+from typing import Generator
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from . import exc as async_exc
from .base import ProxyComparable
from .base import StartableContext
from .result import _ensure_sync_result
from .result import AsyncResult
+from .result import AsyncScalarResult
from ... import exc
from ... import inspection
from ... import util
+from ...engine import Connection
from ...engine import create_engine as _create_engine
+from ...engine import Engine
from ...engine.base import NestedTransaction
-from ...future import Connection
-from ...future import Engine
+from ...engine.base import Transaction
from ...util.concurrency import greenlet_spawn
-
-
-def create_async_engine(*arg, **kw):
+from ...util.typing import Protocol
+
+if TYPE_CHECKING:
+ from ...engine import Connection
+ from ...engine import Engine
+ from ...engine.cursor import CursorResult
+ from ...engine.interfaces import _CoreAnyExecuteParams
+ from ...engine.interfaces import _CoreSingleExecuteParams
+ from ...engine.interfaces import _DBAPIAnyExecuteParams
+ from ...engine.interfaces import _ExecuteOptions
+ from ...engine.interfaces import _ExecuteOptionsParameter
+ from ...engine.interfaces import _IsolationLevel
+ from ...engine.interfaces import Dialect
+ from ...engine.result import ScalarResult
+ from ...engine.url import URL
+ from ...pool import Pool
+ from ...pool import PoolProxiedConnection
+ from ...sql.base import Executable
+
+
+class _SyncConnectionCallable(Protocol):
+ def __call__(self, connection: Connection, *arg: Any, **kw: Any) -> Any:
+ ...
+
+
+def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine:
"""Create a new async engine instance.
Arguments passed to :func:`_asyncio.create_async_engine` are mostly
@@ -43,11 +76,13 @@ def create_async_engine(*arg, **kw):
)
kw["future"] = True
kw["_is_async"] = True
- sync_engine = _create_engine(*arg, **kw)
+ sync_engine = _create_engine(url, **kw)
return AsyncEngine(sync_engine)
-def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
+def async_engine_from_config(
+ configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any
+) -> AsyncEngine:
"""Create a new AsyncEngine instance using a configuration dictionary.
This function is analogous to the :func:`_sa.engine_from_config` function
@@ -73,6 +108,14 @@ def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
class AsyncConnectable:
__slots__ = "_slots_dispatch", "__weakref__"
+ @classmethod
+ def _no_async_engine_events(cls) -> NoReturn:
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncEngine.sync_engine or "
+ "AsyncConnection.sync_connection attributes."
+ )
+
@util.create_proxy_methods(
Connection,
@@ -87,7 +130,9 @@ class AsyncConnectable:
"default_isolation_level",
],
)
-class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
+class AsyncConnection(
+ ProxyComparable[Connection], StartableContext, AsyncConnectable
+):
"""An asyncio proxy for a :class:`_engine.Connection`.
:class:`_asyncio.AsyncConnection` is acquired using the
@@ -115,12 +160,16 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"sync_connection",
)
- def __init__(self, async_engine, sync_connection=None):
+ def __init__(
+ self,
+ async_engine: AsyncEngine,
+ sync_connection: Optional[Connection] = None,
+ ):
self.engine = async_engine
self.sync_engine = async_engine.sync_engine
self.sync_connection = self._assign_proxied(sync_connection)
- sync_connection: Connection
+ sync_connection: Optional[Connection]
"""Reference to the sync-style :class:`_engine.Connection` this
:class:`_asyncio.AsyncConnection` proxies requests towards.
@@ -146,12 +195,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"""
@classmethod
- def _regenerate_proxy_for_target(cls, target):
+ def _regenerate_proxy_for_target(
+ cls, target: Connection
+ ) -> AsyncConnection:
return AsyncConnection(
AsyncEngine._retrieve_proxy_for_target(target.engine), target
)
- async def start(self, is_ctxmanager=False):
+ async def start(self, is_ctxmanager: bool = False) -> AsyncConnection:
"""Start this :class:`_asyncio.AsyncConnection` object's context
outside of using a Python ``with:`` block.
@@ -164,7 +215,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
return self
@property
- def connection(self):
+ def connection(self) -> NoReturn:
"""Not implemented for async; call
:meth:`_asyncio.AsyncConnection.get_raw_connection`.
"""
@@ -174,7 +225,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"Use the get_raw_connection() method."
)
- async def get_raw_connection(self):
+ async def get_raw_connection(self) -> PoolProxiedConnection:
"""Return the pooled DBAPI-level connection in use by this
:class:`_asyncio.AsyncConnection`.
@@ -187,16 +238,11 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
adapts the driver connection to the DBAPI protocol.
"""
- conn = self._sync_connection()
-
- return await greenlet_spawn(getattr, conn, "connection")
- @property
- def _proxied(self):
- return self.sync_connection
+ return await greenlet_spawn(getattr, self._proxied, "connection")
@property
- def info(self):
+ def info(self) -> Dict[str, Any]:
"""Return the :attr:`_engine.Connection.info` dictionary of the
underlying :class:`_engine.Connection`.
@@ -211,24 +257,28 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
.. versionadded:: 1.4.0b2
"""
- return self.sync_connection.info
+ return self._proxied.info
- def _sync_connection(self):
+ @util.ro_non_memoized_property
+ def _proxied(self) -> Connection:
if not self.sync_connection:
self._raise_for_not_started()
return self.sync_connection
- def begin(self):
+ def begin(self) -> AsyncTransaction:
"""Begin a transaction prior to autobegin occurring."""
- self._sync_connection()
+ assert self._proxied
return AsyncTransaction(self)
- def begin_nested(self):
+ def begin_nested(self) -> AsyncTransaction:
"""Begin a nested transaction and return a transaction handle."""
- self._sync_connection()
+ assert self._proxied
return AsyncTransaction(self, nested=True)
- async def invalidate(self, exception=None):
+ async def invalidate(
+ self, exception: Optional[BaseException] = None
+ ) -> None:
+
"""Invalidate the underlying DBAPI connection associated with
this :class:`_engine.Connection`.
@@ -237,39 +287,27 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"""
- conn = self._sync_connection()
- return await greenlet_spawn(conn.invalidate, exception=exception)
-
- async def get_isolation_level(self):
- conn = self._sync_connection()
- return await greenlet_spawn(conn.get_isolation_level)
-
- async def set_isolation_level(self):
- conn = self._sync_connection()
- return await greenlet_spawn(conn.get_isolation_level)
-
- def in_transaction(self):
- """Return True if a transaction is in progress.
-
- .. versionadded:: 1.4.0b2
+ return await greenlet_spawn(
+ self._proxied.invalidate, exception=exception
+ )
- """
+ async def get_isolation_level(self) -> _IsolationLevel:
+ return await greenlet_spawn(self._proxied.get_isolation_level)
- conn = self._sync_connection()
+ def in_transaction(self) -> bool:
+ """Return True if a transaction is in progress."""
- return conn.in_transaction()
+ return self._proxied.in_transaction()
- def in_nested_transaction(self):
+ def in_nested_transaction(self) -> bool:
"""Return True if a transaction is in progress.
.. versionadded:: 1.4.0b2
"""
- conn = self._sync_connection()
-
- return conn.in_nested_transaction()
+ return self._proxied.in_nested_transaction()
- def get_transaction(self):
+ def get_transaction(self) -> Optional[AsyncTransaction]:
"""Return an :class:`.AsyncTransaction` representing the current
transaction, if any.
@@ -281,15 +319,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
.. versionadded:: 1.4.0b2
"""
- conn = self._sync_connection()
- trans = conn.get_transaction()
+ trans = self._proxied.get_transaction()
if trans is not None:
return AsyncTransaction._retrieve_proxy_for_target(trans)
else:
return None
- def get_nested_transaction(self):
+ def get_nested_transaction(self) -> Optional[AsyncTransaction]:
"""Return an :class:`.AsyncTransaction` representing the current
nested (savepoint) transaction, if any.
@@ -301,15 +338,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
.. versionadded:: 1.4.0b2
"""
- conn = self._sync_connection()
- trans = conn.get_nested_transaction()
+ trans = self._proxied.get_nested_transaction()
if trans is not None:
return AsyncTransaction._retrieve_proxy_for_target(trans)
else:
return None
- async def execution_options(self, **opt):
+ async def execution_options(self, **opt: Any) -> AsyncConnection:
r"""Set non-SQL options for the connection which take effect
during execution.
@@ -321,12 +357,12 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"""
- conn = self._sync_connection()
+ conn = self._proxied
c2 = await greenlet_spawn(conn.execution_options, **opt)
assert c2 is conn
return self
- async def commit(self):
+ async def commit(self) -> None:
"""Commit the transaction that is currently in progress.
This method commits the current transaction if one has been started.
@@ -338,10 +374,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
:meth:`_future.Connection.begin` method is called.
"""
- conn = self._sync_connection()
- await greenlet_spawn(conn.commit)
+ await greenlet_spawn(self._proxied.commit)
- async def rollback(self):
+ async def rollback(self) -> None:
"""Roll back the transaction that is currently in progress.
This method rolls back the current transaction if one has been started.
@@ -355,34 +390,30 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"""
- conn = self._sync_connection()
- await greenlet_spawn(conn.rollback)
+ await greenlet_spawn(self._proxied.rollback)
- async def close(self):
+ async def close(self) -> None:
"""Close this :class:`_asyncio.AsyncConnection`.
This has the effect of also rolling back the transaction if one
is in place.
"""
- conn = self._sync_connection()
- await greenlet_spawn(conn.close)
+ await greenlet_spawn(self._proxied.close)
async def exec_driver_sql(
self,
- statement,
- parameters=None,
- execution_options=util.EMPTY_DICT,
- ):
+ statement: str,
+ parameters: Optional[_DBAPIAnyExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> CursorResult:
r"""Executes a driver-level SQL string and return buffered
:class:`_engine.Result`.
"""
- conn = self._sync_connection()
-
result = await greenlet_spawn(
- conn.exec_driver_sql,
+ self._proxied.exec_driver_sql,
statement,
parameters,
execution_options,
@@ -393,17 +424,15 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
async def stream(
self,
- statement,
- parameters=None,
- execution_options=util.EMPTY_DICT,
- ):
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> AsyncResult:
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object."""
- conn = self._sync_connection()
-
result = await greenlet_spawn(
- conn.execute,
+ self._proxied.execute,
statement,
parameters,
util.EMPTY_DICT.merge_with(
@@ -418,10 +447,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
async def execute(
self,
- statement,
- parameters=None,
- execution_options=util.EMPTY_DICT,
- ):
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> CursorResult:
r"""Executes a SQL statement construct and return a buffered
:class:`_engine.Result`.
@@ -453,10 +482,8 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
:return: a :class:`_engine.Result` object.
"""
- conn = self._sync_connection()
-
result = await greenlet_spawn(
- conn.execute,
+ self._proxied.execute,
statement,
parameters,
execution_options,
@@ -466,10 +493,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
async def scalar(
self,
- statement,
- parameters=None,
- execution_options=util.EMPTY_DICT,
- ):
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Any:
r"""Executes a SQL statement construct and returns a scalar object.
This method is shorthand for invoking the
@@ -485,10 +512,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
async def scalars(
self,
- statement,
- parameters=None,
- execution_options=util.EMPTY_DICT,
- ):
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> ScalarResult[Any]:
r"""Executes a SQL statement construct and returns a scalar objects.
This method is shorthand for invoking the
@@ -505,10 +532,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
async def stream_scalars(
self,
- statement,
- parameters=None,
- execution_options=util.EMPTY_DICT,
- ):
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> AsyncScalarResult[Any]:
r"""Executes a SQL statement and returns a streaming scalar result
object.
@@ -524,7 +551,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
result = await self.stream(statement, parameters, execution_options)
return result.scalars()
- async def run_sync(self, fn, *arg, **kw):
+ async def run_sync(
+ self, fn: _SyncConnectionCallable, *arg: Any, **kw: Any
+ ) -> Any:
"""Invoke the given sync callable passing self as the first argument.
This method maintains the asyncio event loop all the way through
@@ -548,14 +577,12 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
:ref:`session_run_sync`
"""
- conn = self._sync_connection()
-
- return await greenlet_spawn(fn, conn, *arg, **kw)
+ return await greenlet_spawn(fn, self._proxied, *arg, **kw)
- def __await__(self):
+ def __await__(self) -> Generator[Any, None, AsyncConnection]:
return self.start().__await__()
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await self.close()
# START PROXY METHODS AsyncConnection
@@ -661,7 +688,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
],
attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
)
-class AsyncEngine(ProxyComparable, AsyncConnectable):
+class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
"""An asyncio proxy for a :class:`_engine.Engine`.
:class:`_asyncio.AsyncEngine` is acquired using the
@@ -679,51 +706,60 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
# current transaction, info, etc. It should be possible to
# create a new AsyncEngine that matches this one given only the
# "sync" elements.
- __slots__ = ("sync_engine", "_proxied")
+ __slots__ = "sync_engine"
- _connection_cls = AsyncConnection
+ _connection_cls: Type[AsyncConnection] = AsyncConnection
- _option_cls: type
+ sync_engine: Engine
+ """Reference to the sync-style :class:`_engine.Engine` this
+ :class:`_asyncio.AsyncEngine` proxies requests towards.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+ """
class _trans_ctx(StartableContext):
- def __init__(self, conn):
+ __slots__ = ("conn", "transaction")
+
+ conn: AsyncConnection
+ transaction: AsyncTransaction
+
+ def __init__(self, conn: AsyncConnection):
self.conn = conn
- async def start(self, is_ctxmanager=False):
+ async def start(self, is_ctxmanager: bool = False) -> AsyncConnection:
await self.conn.start(is_ctxmanager=is_ctxmanager)
self.transaction = self.conn.begin()
await self.transaction.__aenter__()
return self.conn
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(
+ self, type_: Any, value: Any, traceback: Any
+ ) -> None:
await self.transaction.__aexit__(type_, value, traceback)
await self.conn.close()
- def __init__(self, sync_engine):
+ def __init__(self, sync_engine: Engine):
if not sync_engine.dialect.is_async:
raise exc.InvalidRequestError(
"The asyncio extension requires an async driver to be used. "
f"The loaded {sync_engine.dialect.driver!r} is not async."
)
- self.sync_engine = self._proxied = self._assign_proxied(sync_engine)
-
- sync_engine: Engine
- """Reference to the sync-style :class:`_engine.Engine` this
- :class:`_asyncio.AsyncEngine` proxies requests towards.
+ self.sync_engine = self._assign_proxied(sync_engine)
- This instance can be used as an event target.
-
- .. seealso::
-
- :ref:`asyncio_events`
- """
+ @util.ro_non_memoized_property
+ def _proxied(self) -> Engine:
+ return self.sync_engine
@classmethod
- def _regenerate_proxy_for_target(cls, target):
+ def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine:
return AsyncEngine(target)
- def begin(self):
+ def begin(self) -> AsyncEngine._trans_ctx:
"""Return a context manager which when entered will deliver an
:class:`_asyncio.AsyncConnection` with an
:class:`_asyncio.AsyncTransaction` established.
@@ -741,7 +777,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
conn = self.connect()
return self._trans_ctx(conn)
- def connect(self):
+ def connect(self) -> AsyncConnection:
"""Return an :class:`_asyncio.AsyncConnection` object.
The :class:`_asyncio.AsyncConnection` will procure a database
@@ -759,7 +795,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
return self._connection_cls(self)
- async def raw_connection(self):
+ async def raw_connection(self) -> PoolProxiedConnection:
"""Return a "raw" DBAPI connection from the connection pool.
.. seealso::
@@ -769,7 +805,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
"""
return await greenlet_spawn(self.sync_engine.raw_connection)
- def execution_options(self, **opt):
+ def execution_options(self, **opt: Any) -> AsyncEngine:
"""Return a new :class:`_asyncio.AsyncEngine` that will provide
:class:`_asyncio.AsyncConnection` objects with the given execution
options.
@@ -781,21 +817,31 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
return AsyncEngine(self.sync_engine.execution_options(**opt))
- async def dispose(self):
+ async def dispose(self, close: bool = True) -> None:
+
"""Dispose of the connection pool used by this
:class:`_asyncio.AsyncEngine`.
- This will close all connection pool connections that are
- **currently checked in**. See the documentation for the underlying
- :meth:`_future.Engine.dispose` method for further notes.
+ :param close: if left at its default of ``True``, has the
+ effect of fully closing all **currently checked in**
+ database connections. Connections that are still checked out
+ will **not** be closed, however they will no longer be associated
+ with this :class:`_engine.Engine`,
+ so when they are closed individually, eventually the
+ :class:`_pool.Pool` which they are associated with will
+ be garbage collected and they will be closed out fully, if
+ not already closed on checkin.
+
+ If set to ``False``, the previous connection pool is de-referenced,
+ and otherwise not touched in any way.
.. seealso::
- :meth:`_future.Engine.dispose`
+ :meth:`_engine.Engine.dispose`
"""
- return await greenlet_spawn(self.sync_engine.dispose)
+ return await greenlet_spawn(self.sync_engine.dispose, close=close)
# START PROXY METHODS AsyncEngine
@@ -973,18 +1019,24 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
# END PROXY METHODS AsyncEngine
-class AsyncTransaction(ProxyComparable, StartableContext):
+class AsyncTransaction(ProxyComparable[Transaction], StartableContext):
"""An asyncio proxy for a :class:`_engine.Transaction`."""
__slots__ = ("connection", "sync_transaction", "nested")
- def __init__(self, connection, nested=False):
- self.connection = connection # AsyncConnection
- self.sync_transaction = None # sqlalchemy.engine.Transaction
+ sync_transaction: Optional[Transaction]
+ connection: AsyncConnection
+ nested: bool
+
+ def __init__(self, connection: AsyncConnection, nested: bool = False):
+ self.connection = connection
+ self.sync_transaction = None
self.nested = nested
@classmethod
- def _regenerate_proxy_for_target(cls, target):
+ def _regenerate_proxy_for_target(
+ cls, target: Transaction
+ ) -> AsyncTransaction:
sync_connection = target.connection
sync_transaction = target
nested = isinstance(target, NestedTransaction)
@@ -1000,25 +1052,22 @@ class AsyncTransaction(ProxyComparable, StartableContext):
obj.nested = nested
return obj
- def _sync_transaction(self):
+ @util.ro_non_memoized_property
+ def _proxied(self) -> Transaction:
if not self.sync_transaction:
self._raise_for_not_started()
return self.sync_transaction
@property
- def _proxied(self):
- return self.sync_transaction
+ def is_valid(self) -> bool:
+ return self._proxied.is_valid
@property
- def is_valid(self):
- return self._sync_transaction().is_valid
+ def is_active(self) -> bool:
+ return self._proxied.is_active
- @property
- def is_active(self):
- return self._sync_transaction().is_active
-
- async def close(self):
- """Close this :class:`.Transaction`.
+ async def close(self) -> None:
+ """Close this :class:`.AsyncTransaction`.
If this transaction is the base transaction in a begin/commit
nesting, the transaction will rollback(). Otherwise, the
@@ -1028,18 +1077,18 @@ class AsyncTransaction(ProxyComparable, StartableContext):
an enclosing transaction.
"""
- await greenlet_spawn(self._sync_transaction().close)
+ await greenlet_spawn(self._proxied.close)
- async def rollback(self):
- """Roll back this :class:`.Transaction`."""
- await greenlet_spawn(self._sync_transaction().rollback)
+ async def rollback(self) -> None:
+ """Roll back this :class:`.AsyncTransaction`."""
+ await greenlet_spawn(self._proxied.rollback)
- async def commit(self):
- """Commit this :class:`.Transaction`."""
+ async def commit(self) -> None:
+ """Commit this :class:`.AsyncTransaction`."""
- await greenlet_spawn(self._sync_transaction().commit)
+ await greenlet_spawn(self._proxied.commit)
- async def start(self, is_ctxmanager=False):
+ async def start(self, is_ctxmanager: bool = False) -> AsyncTransaction:
"""Start this :class:`_asyncio.AsyncTransaction` object's context
outside of using a Python ``with:`` block.
@@ -1047,24 +1096,36 @@ class AsyncTransaction(ProxyComparable, StartableContext):
self.sync_transaction = self._assign_proxied(
await greenlet_spawn(
- self.connection._sync_connection().begin_nested
+ self.connection._proxied.begin_nested
if self.nested
- else self.connection._sync_connection().begin
+ else self.connection._proxied.begin
)
)
if is_ctxmanager:
self.sync_transaction.__enter__()
return self
- async def __aexit__(self, type_, value, traceback):
- await greenlet_spawn(
- self._sync_transaction().__exit__, type_, value, traceback
- )
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
+ await greenlet_spawn(self._proxied.__exit__, type_, value, traceback)
+
+
+@overload
+def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine:
+ ...
+
+
+@overload
+def _get_sync_engine_or_connection(
+ async_engine: AsyncConnection,
+) -> Connection:
+ ...
-def _get_sync_engine_or_connection(async_engine):
+def _get_sync_engine_or_connection(
+ async_engine: Union[AsyncEngine, AsyncConnection]
+) -> Union[Engine, Connection]:
if isinstance(async_engine, AsyncConnection):
- return async_engine.sync_connection
+ return async_engine._proxied
try:
return async_engine.sync_engine
@@ -1075,7 +1136,7 @@ def _get_sync_engine_or_connection(async_engine):
@inspection._inspects(AsyncConnection)
-def _no_insp_for_async_conn_yet(subject):
+def _no_insp_for_async_conn_yet(subject: AsyncConnection) -> NoReturn:
raise exc.NoInspectionAvailable(
"Inspection on an AsyncConnection is currently not supported. "
"Please use ``run_sync`` to pass a callable where it's possible "
@@ -1085,7 +1146,7 @@ def _no_insp_for_async_conn_yet(subject):
@inspection._inspects(AsyncEngine)
-def _no_insp_for_async_engine_xyet(subject):
+def _no_insp_for_async_engine_xyet(subject: AsyncEngine) -> NoReturn:
raise exc.NoInspectionAvailable(
"Inspection on an AsyncEngine is currently not supported. "
"Please obtain a connection then use ``conn.run_sync`` to pass a "
diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py
deleted file mode 100644
index c5d5e0126..000000000
--- a/lib/sqlalchemy/ext/asyncio/events.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# ext/asyncio/events.py
-# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
-# <see AUTHORS file>
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: https://www.opensource.org/licenses/mit-license.php
-
-from .engine import AsyncConnectable
-from .session import AsyncSession
-from ...engine import events as engine_event
-from ...orm import events as orm_event
-
-
-class AsyncConnectionEvents(engine_event.ConnectionEvents):
- _target_class_doc = "SomeEngine"
- _dispatch_target = AsyncConnectable
-
- @classmethod
- def _no_async_engine_events(cls):
- raise NotImplementedError(
- "asynchronous events are not implemented at this time. Apply "
- "synchronous listeners to the AsyncEngine.sync_engine or "
- "AsyncConnection.sync_connection attributes."
- )
-
- @classmethod
- def _listen(cls, event_key, retval=False):
- cls._no_async_engine_events()
-
-
-class AsyncSessionEvents(orm_event.SessionEvents):
- _target_class_doc = "SomeSession"
- _dispatch_target = AsyncSession
-
- @classmethod
- def _no_async_engine_events(cls):
- raise NotImplementedError(
- "asynchronous events are not implemented at this time. Apply "
- "synchronous listeners to the AsyncSession.sync_session."
- )
-
- @classmethod
- def _listen(cls, event_key, retval=False):
- cls._no_async_engine_events()
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
index 39718735c..a9db822a6 100644
--- a/lib/sqlalchemy/ext/asyncio/result.py
+++ b/lib/sqlalchemy/ext/asyncio/result.py
@@ -4,25 +4,49 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
import operator
+from typing import Any
+from typing import AsyncIterator
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import TypeVar
from . import exc as async_exc
from ...engine.result import _NO_ROW
+from ...engine.result import _R
from ...engine.result import FilterResult
from ...engine.result import FrozenResult
from ...engine.result import MergedResult
+from ...engine.result import ResultMetaData
+from ...engine.row import Row
+from ...engine.row import RowMapping
from ...util.concurrency import greenlet_spawn
+if TYPE_CHECKING:
+ from ...engine import CursorResult
+ from ...engine import Result
+ from ...engine.result import _KeyIndexType
+ from ...engine.result import _UniqueFilterType
+ from ...engine.result import RMKeyView
-class AsyncCommon(FilterResult):
- async def close(self):
+
+class AsyncCommon(FilterResult[_R]):
+ _real_result: Result
+ _metadata: ResultMetaData
+
+ async def close(self) -> None:
"""Close this result."""
await greenlet_spawn(self._real_result.close)
-class AsyncResult(AsyncCommon):
+SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult")
+
+
+class AsyncResult(AsyncCommon[Row]):
"""An asyncio wrapper around a :class:`_result.Result` object.
The :class:`_asyncio.AsyncResult` only applies to statement executions that
@@ -43,7 +67,7 @@ class AsyncResult(AsyncCommon):
"""
- def __init__(self, real_result):
+ def __init__(self, real_result: Result):
self._real_result = real_result
self._metadata = real_result._metadata
@@ -56,14 +80,16 @@ class AsyncResult(AsyncCommon):
"_row_getter", real_result.__dict__["_row_getter"]
)
- def keys(self):
+ def keys(self) -> RMKeyView:
"""Return the :meth:`_engine.Result.keys` collection from the
underlying :class:`_engine.Result`.
"""
return self._metadata.keys
- def unique(self, strategy=None):
+ def unique(
+ self: SelfAsyncResult, strategy: Optional[_UniqueFilterType] = None
+ ) -> SelfAsyncResult:
"""Apply unique filtering to the objects returned by this
:class:`_asyncio.AsyncResult`.
@@ -75,7 +101,9 @@ class AsyncResult(AsyncCommon):
self._unique_filter_state = (set(), strategy)
return self
- def columns(self, *col_expressions):
+ def columns(
+ self: SelfAsyncResult, *col_expressions: _KeyIndexType
+ ) -> SelfAsyncResult:
r"""Establish the columns that should be returned in each row.
Refer to :meth:`_engine.Result.columns` in the synchronous
@@ -85,7 +113,9 @@ class AsyncResult(AsyncCommon):
"""
return self._column_slices(col_expressions)
- async def partitions(self, size=None):
+ async def partitions(
+ self, size: Optional[int] = None
+ ) -> AsyncIterator[List[Row]]:
"""Iterate through sub-lists of rows of the size given.
An async iterator is returned::
@@ -111,7 +141,7 @@ class AsyncResult(AsyncCommon):
else:
break
- async def fetchone(self):
+ async def fetchone(self) -> Optional[Row]:
"""Fetch one row.
When all rows are exhausted, returns None.
@@ -131,9 +161,9 @@ class AsyncResult(AsyncCommon):
if row is _NO_ROW:
return None
else:
- return row # type: ignore[return-value]
+ return row
- async def fetchmany(self, size=None):
+ async def fetchmany(self, size: Optional[int] = None) -> List[Row]:
"""Fetch many rows.
When all rows are exhausted, returns an empty list.
@@ -152,11 +182,9 @@ class AsyncResult(AsyncCommon):
"""
- return await greenlet_spawn(
- self._manyrow_getter, self, size # type: ignore
- )
+ return await greenlet_spawn(self._manyrow_getter, self, size)
- async def all(self):
+ async def all(self) -> List[Row]:
"""Return all rows in a list.
Closes the result set after invocation. Subsequent invocations
@@ -166,19 +194,19 @@ class AsyncResult(AsyncCommon):
"""
- return await greenlet_spawn(self._allrows) # type: ignore
+ return await greenlet_spawn(self._allrows)
- def __aiter__(self):
+ def __aiter__(self) -> AsyncResult:
return self
- async def __anext__(self):
+ async def __anext__(self) -> Row:
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
- async def first(self):
+ async def first(self) -> Optional[Row]:
"""Fetch the first row or None if no row is present.
Closes the result set and discards remaining rows.
@@ -201,7 +229,7 @@ class AsyncResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
- async def one_or_none(self):
+ async def one_or_none(self) -> Optional[Row]:
"""Return at most one result or raise an exception.
Returns ``None`` if the result has no rows.
@@ -223,7 +251,7 @@ class AsyncResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
- async def scalar_one(self):
+ async def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
@@ -238,7 +266,7 @@ class AsyncResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, True, True)
- async def scalar_one_or_none(self):
+ async def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one or no scalar result.
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
@@ -253,7 +281,7 @@ class AsyncResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, False, True)
- async def one(self):
+ async def one(self) -> Row:
"""Return exactly one row or raise an exception.
Raises :class:`.NoResultFound` if the result returns no
@@ -284,7 +312,7 @@ class AsyncResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, True, False)
- async def scalar(self):
+ async def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result set.
Returns None if there are no rows to fetch.
@@ -300,7 +328,7 @@ class AsyncResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, False, False, True)
- async def freeze(self):
+ async def freeze(self) -> FrozenResult:
"""Return a callable object that will produce copies of this
:class:`_asyncio.AsyncResult` when invoked.
@@ -323,7 +351,7 @@ class AsyncResult(AsyncCommon):
return await greenlet_spawn(FrozenResult, self)
- def merge(self, *others):
+ def merge(self, *others: AsyncResult) -> MergedResult:
"""Merge this :class:`_asyncio.AsyncResult` with other compatible result
objects.
@@ -337,9 +365,12 @@ class AsyncResult(AsyncCommon):
undefined.
"""
- return MergedResult(self._metadata, (self,) + others)
+ return MergedResult(
+ self._metadata,
+ (self._real_result,) + tuple(o._real_result for o in others),
+ )
- def scalars(self, index=0):
+ def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
"""Return an :class:`_asyncio.AsyncScalarResult` filtering object which
will return single elements rather than :class:`_row.Row` objects.
@@ -355,7 +386,7 @@ class AsyncResult(AsyncCommon):
"""
return AsyncScalarResult(self._real_result, index)
- def mappings(self):
+ def mappings(self) -> AsyncMappingResult:
"""Apply a mappings filter to returned rows, returning an instance of
:class:`_asyncio.AsyncMappingResult`.
@@ -373,7 +404,12 @@ class AsyncResult(AsyncCommon):
return AsyncMappingResult(self._real_result)
-class AsyncScalarResult(AsyncCommon):
+SelfAsyncScalarResult = TypeVar(
+ "SelfAsyncScalarResult", bound="AsyncScalarResult[Any]"
+)
+
+
+class AsyncScalarResult(AsyncCommon[_R]):
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
rather than :class:`_row.Row` values.
@@ -389,7 +425,7 @@ class AsyncScalarResult(AsyncCommon):
_generate_rows = False
- def __init__(self, real_result, index):
+ def __init__(self, real_result: Result, index: _KeyIndexType):
self._real_result = real_result
if real_result._source_supports_scalars:
@@ -401,7 +437,10 @@ class AsyncScalarResult(AsyncCommon):
self._unique_filter_state = real_result._unique_filter_state
- def unique(self, strategy=None):
+ def unique(
+ self: SelfAsyncScalarResult,
+ strategy: Optional[_UniqueFilterType] = None,
+ ) -> SelfAsyncScalarResult:
"""Apply unique filtering to the objects returned by this
:class:`_asyncio.AsyncScalarResult`.
@@ -411,7 +450,9 @@ class AsyncScalarResult(AsyncCommon):
self._unique_filter_state = (set(), strategy)
return self
- async def partitions(self, size=None):
+ async def partitions(
+ self, size: Optional[int] = None
+ ) -> AsyncIterator[List[_R]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
@@ -429,12 +470,12 @@ class AsyncScalarResult(AsyncCommon):
else:
break
- async def fetchall(self):
+ async def fetchall(self) -> List[_R]:
"""A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
return await greenlet_spawn(self._allrows)
- async def fetchmany(self, size=None):
+ async def fetchmany(self, size: Optional[int] = None) -> List[_R]:
"""Fetch many objects.
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
@@ -444,7 +485,7 @@ class AsyncScalarResult(AsyncCommon):
"""
return await greenlet_spawn(self._manyrow_getter, self, size)
- async def all(self):
+ async def all(self) -> List[_R]:
"""Return all scalar values in a list.
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
@@ -454,17 +495,17 @@ class AsyncScalarResult(AsyncCommon):
"""
return await greenlet_spawn(self._allrows)
- def __aiter__(self):
+ def __aiter__(self) -> AsyncScalarResult[_R]:
return self
- async def __anext__(self):
+ async def __anext__(self) -> _R:
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
- async def first(self):
+ async def first(self) -> Optional[_R]:
"""Fetch the first object or None if no object is present.
Equivalent to :meth:`_asyncio.AsyncResult.first` except that
@@ -474,7 +515,7 @@ class AsyncScalarResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
- async def one_or_none(self):
+ async def one_or_none(self) -> Optional[_R]:
"""Return at most one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
@@ -484,7 +525,7 @@ class AsyncScalarResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
- async def one(self):
+ async def one(self) -> _R:
"""Return exactly one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one` except that
@@ -495,7 +536,12 @@ class AsyncScalarResult(AsyncCommon):
return await greenlet_spawn(self._only_one_row, True, True, False)
-class AsyncMappingResult(AsyncCommon):
+SelfAsyncMappingResult = TypeVar(
+ "SelfAsyncMappingResult", bound="AsyncMappingResult"
+)
+
+
+class AsyncMappingResult(AsyncCommon[RowMapping]):
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary values
rather than :class:`_engine.Row` values.
@@ -513,14 +559,14 @@ class AsyncMappingResult(AsyncCommon):
_post_creational_filter = operator.attrgetter("_mapping")
- def __init__(self, result):
+ def __init__(self, result: Result):
self._real_result = result
self._unique_filter_state = result._unique_filter_state
self._metadata = result._metadata
if result._source_supports_scalars:
self._metadata = self._metadata._reduce([0])
- def keys(self):
+ def keys(self) -> RMKeyView:
"""Return an iterable view which yields the string keys that would
be represented by each :class:`.Row`.
@@ -535,7 +581,10 @@ class AsyncMappingResult(AsyncCommon):
"""
return self._metadata.keys
- def unique(self, strategy=None):
+ def unique(
+ self: SelfAsyncMappingResult,
+ strategy: Optional[_UniqueFilterType] = None,
+ ) -> SelfAsyncMappingResult:
"""Apply unique filtering to the objects returned by this
:class:`_asyncio.AsyncMappingResult`.
@@ -545,11 +594,16 @@ class AsyncMappingResult(AsyncCommon):
self._unique_filter_state = (set(), strategy)
return self
- def columns(self, *col_expressions):
+ def columns(
+ self: SelfAsyncMappingResult, *col_expressions: _KeyIndexType
+ ) -> SelfAsyncMappingResult:
r"""Establish the columns that should be returned in each row."""
return self._column_slices(col_expressions)
- async def partitions(self, size=None):
+ async def partitions(
+ self, size: Optional[int] = None
+ ) -> AsyncIterator[List[RowMapping]]:
+
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
@@ -567,12 +621,12 @@ class AsyncMappingResult(AsyncCommon):
else:
break
- async def fetchall(self):
+ async def fetchall(self) -> List[RowMapping]:
"""A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
return await greenlet_spawn(self._allrows)
- async def fetchone(self):
+ async def fetchone(self) -> Optional[RowMapping]:
"""Fetch one object.
Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that
@@ -587,8 +641,8 @@ class AsyncMappingResult(AsyncCommon):
else:
return row
- async def fetchmany(self, size=None):
- """Fetch many objects.
+ async def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]:
+ """Fetch many rows.
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
:class:`_result.RowMapping` values, rather than :class:`_result.Row`
@@ -598,8 +652,8 @@ class AsyncMappingResult(AsyncCommon):
return await greenlet_spawn(self._manyrow_getter, self, size)
- async def all(self):
- """Return all scalar values in a list.
+ async def all(self) -> List[RowMapping]:
+ """Return all rows in a list.
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
:class:`_result.RowMapping` values, rather than :class:`_result.Row`
@@ -609,17 +663,17 @@ class AsyncMappingResult(AsyncCommon):
return await greenlet_spawn(self._allrows)
- def __aiter__(self):
+ def __aiter__(self) -> AsyncMappingResult:
return self
- async def __anext__(self):
+ async def __anext__(self) -> RowMapping:
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
- async def first(self):
+ async def first(self) -> Optional[RowMapping]:
"""Fetch the first object or None if no object is present.
Equivalent to :meth:`_asyncio.AsyncResult.first` except that
@@ -630,7 +684,7 @@ class AsyncMappingResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
- async def one_or_none(self):
+ async def one_or_none(self) -> Optional[RowMapping]:
"""Return at most one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
@@ -640,7 +694,7 @@ class AsyncMappingResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
- async def one(self):
+ async def one(self) -> RowMapping:
"""Return exactly one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one` except that
@@ -651,11 +705,15 @@ class AsyncMappingResult(AsyncCommon):
return await greenlet_spawn(self._only_one_row, True, True, False)
-async def _ensure_sync_result(result, calling_method):
+_RT = TypeVar("_RT", bound="Result")
+
+
+async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT:
+ cursor_result: CursorResult
if not result._is_cursor:
- cursor_result = getattr(result, "raw", None)
+ cursor_result = getattr(result, "raw", None) # type: ignore
else:
- cursor_result = result
+ cursor_result = result # type: ignore
if cursor_result and cursor_result.context._is_server_side:
await greenlet_spawn(cursor_result.close)
raise async_exc.AsyncMethodRequired(
diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py
index 0503076aa..0d6ae92b4 100644
--- a/lib/sqlalchemy/ext/asyncio/scoping.py
+++ b/lib/sqlalchemy/ext/asyncio/scoping.py
@@ -8,12 +8,50 @@
from __future__ import annotations
from typing import Any
-
+from typing import Callable
+from typing import Iterable
+from typing import Iterator
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
+
+from .session import async_sessionmaker
from .session import AsyncSession
+from ... import exc as sa_exc
from ... import util
-from ...orm.scoping import ScopedSessionMixin
+from ...orm.session import Session
from ...util import create_proxy_methods
from ...util import ScopedRegistry
+from ...util import warn
+from ...util import warn_deprecated
+
+if TYPE_CHECKING:
+ from .engine import AsyncConnection
+ from .result import AsyncResult
+ from .result import AsyncScalarResult
+ from .session import AsyncSessionTransaction
+ from ...engine import Connection
+ from ...engine import Engine
+ from ...engine import Result
+ from ...engine import Row
+ from ...engine.interfaces import _CoreAnyExecuteParams
+ from ...engine.interfaces import _CoreSingleExecuteParams
+ from ...engine.interfaces import _ExecuteOptions
+ from ...engine.interfaces import _ExecuteOptionsParameter
+ from ...engine.result import ScalarResult
+ from ...orm._typing import _IdentityKeyType
+ from ...orm._typing import _O
+ from ...orm.interfaces import ORMOption
+ from ...orm.session import _BindArguments
+ from ...orm.session import _EntityBindKey
+ from ...orm.session import _PKIdentityArgument
+ from ...orm.session import _SessionBind
+ from ...sql.base import Executable
+ from ...sql.elements import ClauseElement
+ from ...sql.selectable import ForUpdateArg
@create_proxy_methods(
@@ -62,7 +100,7 @@ from ...util import ScopedRegistry
"info",
],
)
-class async_scoped_session(ScopedSessionMixin):
+class async_scoped_session:
"""Provides scoped management of :class:`.AsyncSession` objects.
See the section :ref:`asyncio_scoped_session` for usage details.
@@ -74,17 +112,23 @@ class async_scoped_session(ScopedSessionMixin):
_support_async = True
- def __init__(self, session_factory, scopefunc):
+ session_factory: async_sessionmaker
+ """The `session_factory` provided to `__init__` is stored in this
+ attribute and may be accessed at a later time. This can be useful when
+ a new non-scoped :class:`.AsyncSession` is needed."""
+
+ registry: ScopedRegistry[AsyncSession]
+
+ def __init__(
+ self,
+ session_factory: async_sessionmaker,
+ scopefunc: Callable[[], Any],
+ ):
"""Construct a new :class:`_asyncio.async_scoped_session`.
:param session_factory: a factory to create new :class:`_asyncio.AsyncSession`
instances. This is usually, but not necessarily, an instance
- of :class:`_orm.sessionmaker` which itself was passed the
- :class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_`
- parameter::
-
- async_session_factory = sessionmaker(some_async_engine, class_= AsyncSession)
- AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task)
+ of :class:`_asyncio.async_sessionmaker`.
:param scopefunc: function which defines
the current scope. A function such as ``asyncio.current_task``
@@ -96,10 +140,59 @@ class async_scoped_session(ScopedSessionMixin):
self.registry = ScopedRegistry(session_factory, scopefunc)
@property
- def _proxied(self):
+ def _proxied(self) -> AsyncSession:
return self.registry()
- async def remove(self):
+ def __call__(self, **kw: Any) -> AsyncSession:
+ r"""Return the current :class:`.AsyncSession`, creating it
+ using the :attr:`.scoped_session.session_factory` if not present.
+
+ :param \**kw: Keyword arguments will be passed to the
+ :attr:`.scoped_session.session_factory` callable, if an existing
+ :class:`.AsyncSession` is not present. If the
+ :class:`.AsyncSession` is present
+ and keyword arguments have been passed,
+ :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
+
+ """
+ if kw:
+ if self.registry.has():
+ raise sa_exc.InvalidRequestError(
+ "Scoped session is already present; "
+ "no new arguments may be specified."
+ )
+ else:
+ sess = self.session_factory(**kw)
+ self.registry.set(sess)
+ else:
+ sess = self.registry()
+ if not self._support_async and sess._is_asyncio:
+ warn_deprecated(
+ "Using `scoped_session` with asyncio is deprecated and "
+ "will raise an error in a future version. "
+ "Please use `async_scoped_session` instead.",
+ "1.4.23",
+ )
+ return sess
+
+ def configure(self, **kwargs: Any) -> None:
+ """reconfigure the :class:`.sessionmaker` used by this
+ :class:`.scoped_session`.
+
+ See :meth:`.sessionmaker.configure`.
+
+ """
+
+ if self.registry.has():
+ warn(
+ "At least one scoped session is already present. "
+ " configure() can not affect sessions that have "
+ "already been created."
+ )
+
+ self.session_factory.configure(**kwargs)
+
+ async def remove(self) -> None:
"""Dispose of the current :class:`.AsyncSession`, if present.
Different from scoped_session's remove method, this method would use
@@ -152,7 +245,9 @@ class async_scoped_session(ScopedSessionMixin):
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
- """
+
+
+ """ # noqa: E501
return self._proxied.__iter__()
@@ -199,7 +294,7 @@ class async_scoped_session(ScopedSessionMixin):
return self._proxied.add_all(instances)
- def begin(self):
+ def begin(self) -> AsyncSessionTransaction:
r"""Return an :class:`_asyncio.AsyncSessionTransaction` object.
.. container:: class_bases
@@ -228,7 +323,7 @@ class async_scoped_session(ScopedSessionMixin):
return self._proxied.begin()
- def begin_nested(self):
+ def begin_nested(self) -> AsyncSessionTransaction:
r"""Return an :class:`_asyncio.AsyncSessionTransaction` object
which will begin a "nested" transaction, e.g. SAVEPOINT.
@@ -247,7 +342,7 @@ class async_scoped_session(ScopedSessionMixin):
return self._proxied.begin_nested()
- async def close(self):
+ async def close(self) -> None:
r"""Close out the transactional resources and ORM objects used by this
:class:`_asyncio.AsyncSession`.
@@ -284,7 +379,7 @@ class async_scoped_session(ScopedSessionMixin):
return await self._proxied.close()
- async def commit(self):
+ async def commit(self) -> None:
r"""Commit the current transaction in progress.
.. container:: class_bases
@@ -296,7 +391,7 @@ class async_scoped_session(ScopedSessionMixin):
return await self._proxied.commit()
- async def connection(self, **kw):
+ async def connection(self, **kw: Any) -> AsyncConnection:
r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
this :class:`.Session` object's transactional state.
@@ -321,7 +416,7 @@ class async_scoped_session(ScopedSessionMixin):
return await self._proxied.connection(**kw)
- async def delete(self, instance):
+ async def delete(self, instance: object) -> None:
r"""Mark an instance as deleted.
.. container:: class_bases
@@ -345,12 +440,12 @@ class async_scoped_session(ScopedSessionMixin):
async def execute(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Result:
r"""Execute a statement and return a buffered
:class:`_engine.Result` object.
@@ -519,7 +614,7 @@ class async_scoped_session(ScopedSessionMixin):
return self._proxied.expunge_all()
- async def flush(self, objects=None):
+ async def flush(self, objects: Optional[Sequence[Any]] = None) -> None:
r"""Flush all the object changes to the database.
.. container:: class_bases
@@ -538,13 +633,15 @@ class async_scoped_session(ScopedSessionMixin):
async def get(
self,
- entity,
- ident,
- options=None,
- populate_existing=False,
- with_for_update=None,
- identity_token=None,
- ):
+ entity: _EntityBindKey[_O],
+ ident: _PKIdentityArgument,
+ *,
+ options: Optional[Sequence[ORMOption]] = None,
+ populate_existing: bool = False,
+ with_for_update: Optional[ForUpdateArg] = None,
+ identity_token: Optional[Any] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ ) -> Optional[_O]:
r"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
@@ -568,9 +665,16 @@ class async_scoped_session(ScopedSessionMixin):
populate_existing=populate_existing,
with_for_update=with_for_update,
identity_token=identity_token,
+ execution_options=execution_options,
)
- def get_bind(self, mapper=None, clause=None, bind=None, **kw):
+ def get_bind(
+ self,
+ mapper: Optional[_EntityBindKey[_O]] = None,
+ clause: Optional[ClauseElement] = None,
+ bind: Optional[_SessionBind] = None,
+ **kw: Any,
+ ) -> Union[Engine, Connection]:
r"""Return a "bind" to which the synchronous proxied :class:`_orm.Session`
is bound.
@@ -724,7 +828,7 @@ class async_scoped_session(ScopedSessionMixin):
instance, include_collections=include_collections
)
- async def invalidate(self):
+ async def invalidate(self) -> None:
r"""Close this Session, using connection invalidation.
.. container:: class_bases
@@ -738,7 +842,13 @@ class async_scoped_session(ScopedSessionMixin):
return await self._proxied.invalidate()
- async def merge(self, instance, load=True, options=None):
+ async def merge(
+ self,
+ instance: _O,
+ *,
+ load: bool = True,
+ options: Optional[Sequence[ORMOption]] = None,
+ ) -> _O:
r"""Copy the state of a given instance into a corresponding instance
within this :class:`_asyncio.AsyncSession`.
@@ -757,8 +867,11 @@ class async_scoped_session(ScopedSessionMixin):
return await self._proxied.merge(instance, load=load, options=options)
async def refresh(
- self, instance, attribute_names=None, with_for_update=None
- ):
+ self,
+ instance: object,
+ attribute_names: Optional[Iterable[str]] = None,
+ with_for_update: Optional[ForUpdateArg] = None,
+ ) -> None:
r"""Expire and refresh the attributes on the given instance.
.. container:: class_bases
@@ -785,7 +898,7 @@ class async_scoped_session(ScopedSessionMixin):
with_for_update=with_for_update,
)
- async def rollback(self):
+ async def rollback(self) -> None:
r"""Rollback the current transaction in progress.
.. container:: class_bases
@@ -799,12 +912,12 @@ class async_scoped_session(ScopedSessionMixin):
async def scalar(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
r"""Execute a statement and return a scalar result.
.. container:: class_bases
@@ -829,12 +942,12 @@ class async_scoped_session(ScopedSessionMixin):
async def scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
r"""Execute a statement and return scalar results.
.. container:: class_bases
@@ -865,12 +978,12 @@ class async_scoped_session(ScopedSessionMixin):
async def stream(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult:
r"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object.
@@ -892,12 +1005,12 @@ class async_scoped_session(ScopedSessionMixin):
async def stream_scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[Any]:
r"""Execute a statement and return a stream of scalar results.
.. container:: class_bases
@@ -1159,7 +1272,7 @@ class async_scoped_session(ScopedSessionMixin):
return self._proxied.info
@classmethod
- async def close_all(self):
+ async def close_all(self) -> None:
r"""Close all :class:`_asyncio.AsyncSession` sessions.
.. container:: class_bases
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
index 769fe05bd..7d63b084c 100644
--- a/lib/sqlalchemy/ext/asyncio/session.py
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -7,17 +7,65 @@
from __future__ import annotations
from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
+from typing import NoReturn
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from . import engine
-from . import result as _result
from .base import ReversibleProxy
from .base import StartableContext
from .result import _ensure_sync_result
+from .result import AsyncResult
+from .result import AsyncScalarResult
from ... import util
from ...orm import object_session
from ...orm import Session
+from ...orm import SessionTransaction
from ...orm import state as _instance_state
from ...util.concurrency import greenlet_spawn
+from ...util.typing import Protocol
+
+if TYPE_CHECKING:
+ from .engine import AsyncConnection
+ from .engine import AsyncEngine
+ from ...engine import Connection
+ from ...engine import Engine
+ from ...engine import Result
+ from ...engine import Row
+ from ...engine import ScalarResult
+ from ...engine import Transaction
+ from ...engine.interfaces import _CoreAnyExecuteParams
+ from ...engine.interfaces import _CoreSingleExecuteParams
+ from ...engine.interfaces import _ExecuteOptions
+ from ...engine.interfaces import _ExecuteOptionsParameter
+ from ...event import dispatcher
+ from ...orm._typing import _IdentityKeyType
+ from ...orm._typing import _O
+ from ...orm.identity import IdentityMap
+ from ...orm.interfaces import ORMOption
+ from ...orm.session import _BindArguments
+ from ...orm.session import _EntityBindKey
+ from ...orm.session import _PKIdentityArgument
+ from ...orm.session import _SessionBind
+ from ...orm.session import _SessionBindKey
+ from ...sql.base import Executable
+ from ...sql.elements import ClauseElement
+ from ...sql.selectable import ForUpdateArg
+
+_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
+
+
+class _SyncSessionCallable(Protocol):
+ def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any:
+ ...
+
_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True})
_STREAM_OPTIONS = util.immutabledict({"stream_results": True})
@@ -52,7 +100,7 @@ _STREAM_OPTIONS = util.immutabledict({"stream_results": True})
"info",
],
)
-class AsyncSession(ReversibleProxy):
+class AsyncSession(ReversibleProxy[Session]):
"""Asyncio version of :class:`_orm.Session`.
The :class:`_asyncio.AsyncSession` is a proxy for a traditional
@@ -69,9 +117,15 @@ class AsyncSession(ReversibleProxy):
_is_asyncio = True
- dispatch = None
+ dispatch: dispatcher[Session]
- def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
+ def __init__(
+ self,
+ bind: Optional[_AsyncSessionBind] = None,
+ binds: Optional[Dict[_SessionBindKey, _AsyncSessionBind]] = None,
+ sync_session_class: Optional[Type[Session]] = None,
+ **kw: Any,
+ ):
r"""Construct a new :class:`_asyncio.AsyncSession`.
All parameters other than ``sync_session_class`` are passed to the
@@ -90,14 +144,15 @@ class AsyncSession(ReversibleProxy):
.. versionadded:: 1.4.24
"""
- kw["future"] = True
+ sync_bind = sync_binds = None
+
if bind:
self.bind = bind
- bind = engine._get_sync_engine_or_connection(bind)
+ sync_bind = engine._get_sync_engine_or_connection(bind)
if binds:
self.binds = binds
- binds = {
+ sync_binds = {
key: engine._get_sync_engine_or_connection(b)
for key, b in binds.items()
}
@@ -106,10 +161,10 @@ class AsyncSession(ReversibleProxy):
self.sync_session_class = sync_session_class
self.sync_session = self._proxied = self._assign_proxied(
- self.sync_session_class(bind=bind, binds=binds, **kw)
+ self.sync_session_class(bind=sync_bind, binds=sync_binds, **kw)
)
- sync_session_class = Session
+ sync_session_class: Type[Session] = Session
"""The class or callable that provides the
underlying :class:`_orm.Session` instance for a particular
:class:`_asyncio.AsyncSession`.
@@ -138,9 +193,19 @@ class AsyncSession(ReversibleProxy):
"""
+ @classmethod
+ def _no_async_engine_events(cls) -> NoReturn:
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncSession.sync_session."
+ )
+
async def refresh(
- self, instance, attribute_names=None, with_for_update=None
- ):
+ self,
+ instance: object,
+ attribute_names: Optional[Iterable[str]] = None,
+ with_for_update: Optional[ForUpdateArg] = None,
+ ) -> None:
"""Expire and refresh the attributes on the given instance.
A query will be issued to the database and all attributes will be
@@ -155,14 +220,16 @@ class AsyncSession(ReversibleProxy):
"""
- return await greenlet_spawn(
+ await greenlet_spawn(
self.sync_session.refresh,
instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
)
- async def run_sync(self, fn, *arg, **kw):
+ async def run_sync(
+ self, fn: _SyncSessionCallable, *arg: Any, **kw: Any
+ ) -> Any:
"""Invoke the given sync callable passing sync self as the first
argument.
@@ -191,12 +258,12 @@ class AsyncSession(ReversibleProxy):
async def execute(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Result:
"""Execute a statement and return a buffered
:class:`_engine.Result` object.
@@ -225,12 +292,12 @@ class AsyncSession(ReversibleProxy):
async def scalar(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
"""Execute a statement and return a scalar result.
.. seealso::
@@ -250,12 +317,12 @@ class AsyncSession(ReversibleProxy):
async def scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
"""Execute a statement and return scalar results.
:return: a :class:`_result.ScalarResult` object
@@ -281,13 +348,16 @@ class AsyncSession(ReversibleProxy):
async def get(
self,
- entity,
- ident,
- options=None,
- populate_existing=False,
- with_for_update=None,
- identity_token=None,
- ):
+ entity: _EntityBindKey[_O],
+ ident: _PKIdentityArgument,
+ *,
+ options: Optional[Sequence[ORMOption]] = None,
+ populate_existing: bool = False,
+ with_for_update: Optional[ForUpdateArg] = None,
+ identity_token: Optional[Any] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ ) -> Optional[_O]:
+
"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
@@ -297,7 +367,8 @@ class AsyncSession(ReversibleProxy):
"""
- return await greenlet_spawn(
+
+ result_obj = await greenlet_spawn(
self.sync_session.get,
entity,
ident,
@@ -306,15 +377,17 @@ class AsyncSession(ReversibleProxy):
with_for_update=with_for_update,
identity_token=identity_token,
)
+ return result_obj
async def stream(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult:
+
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object.
@@ -335,16 +408,16 @@ class AsyncSession(ReversibleProxy):
bind_arguments=bind_arguments,
**kw,
)
- return _result.AsyncResult(result)
+ return AsyncResult(result)
async def stream_scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[Any]:
"""Execute a statement and return a stream of scalar results.
:return: an :class:`_asyncio.AsyncScalarResult` object
@@ -368,7 +441,7 @@ class AsyncSession(ReversibleProxy):
)
return result.scalars()
- async def delete(self, instance):
+ async def delete(self, instance: object) -> None:
"""Mark an instance as deleted.
The database delete operation occurs upon ``flush()``.
@@ -381,9 +454,15 @@ class AsyncSession(ReversibleProxy):
:meth:`_orm.Session.delete` - main documentation for delete
"""
- return await greenlet_spawn(self.sync_session.delete, instance)
+ await greenlet_spawn(self.sync_session.delete, instance)
- async def merge(self, instance, load=True, options=None):
+ async def merge(
+ self,
+ instance: _O,
+ *,
+ load: bool = True,
+ options: Optional[Sequence[ORMOption]] = None,
+ ) -> _O:
"""Copy the state of a given instance into a corresponding instance
within this :class:`_asyncio.AsyncSession`.
@@ -396,7 +475,7 @@ class AsyncSession(ReversibleProxy):
self.sync_session.merge, instance, load=load, options=options
)
- async def flush(self, objects=None):
+ async def flush(self, objects: Optional[Sequence[Any]] = None) -> None:
"""Flush all the object changes to the database.
.. seealso::
@@ -406,7 +485,7 @@ class AsyncSession(ReversibleProxy):
"""
await greenlet_spawn(self.sync_session.flush, objects=objects)
- def get_transaction(self):
+ def get_transaction(self) -> Optional[AsyncSessionTransaction]:
"""Return the current root transaction in progress, if any.
:return: an :class:`_asyncio.AsyncSessionTransaction` object, or
@@ -421,7 +500,7 @@ class AsyncSession(ReversibleProxy):
else:
return None
- def get_nested_transaction(self):
+ def get_nested_transaction(self) -> Optional[AsyncSessionTransaction]:
"""Return the current nested transaction in progress, if any.
:return: an :class:`_asyncio.AsyncSessionTransaction` object, or
@@ -437,7 +516,13 @@ class AsyncSession(ReversibleProxy):
else:
return None
- def get_bind(self, mapper=None, clause=None, bind=None, **kw):
+ def get_bind(
+ self,
+ mapper: Optional[_EntityBindKey[_O]] = None,
+ clause: Optional[ClauseElement] = None,
+ bind: Optional[_SessionBind] = None,
+ **kw: Any,
+ ) -> Union[Engine, Connection]:
"""Return a "bind" to which the synchronous proxied :class:`_orm.Session`
is bound.
@@ -515,7 +600,7 @@ class AsyncSession(ReversibleProxy):
mapper=mapper, clause=clause, bind=bind, **kw
)
- async def connection(self, **kw):
+ async def connection(self, **kw: Any) -> AsyncConnection:
r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
this :class:`.Session` object's transactional state.
@@ -539,7 +624,7 @@ class AsyncSession(ReversibleProxy):
sync_connection
)
- def begin(self):
+ def begin(self) -> AsyncSessionTransaction:
"""Return an :class:`_asyncio.AsyncSessionTransaction` object.
The underlying :class:`_orm.Session` will perform the
@@ -562,7 +647,7 @@ class AsyncSession(ReversibleProxy):
return AsyncSessionTransaction(self)
- def begin_nested(self):
+ def begin_nested(self) -> AsyncSessionTransaction:
"""Return an :class:`_asyncio.AsyncSessionTransaction` object
which will begin a "nested" transaction, e.g. SAVEPOINT.
@@ -575,15 +660,15 @@ class AsyncSession(ReversibleProxy):
return AsyncSessionTransaction(self, nested=True)
- async def rollback(self):
+ async def rollback(self) -> None:
"""Rollback the current transaction in progress."""
- return await greenlet_spawn(self.sync_session.rollback)
+ await greenlet_spawn(self.sync_session.rollback)
- async def commit(self):
+ async def commit(self) -> None:
"""Commit the current transaction in progress."""
- return await greenlet_spawn(self.sync_session.commit)
+ await greenlet_spawn(self.sync_session.commit)
- async def close(self):
+ async def close(self) -> None:
"""Close out the transactional resources and ORM objects used by this
:class:`_asyncio.AsyncSession`.
@@ -613,25 +698,25 @@ class AsyncSession(ReversibleProxy):
"""
return await greenlet_spawn(self.sync_session.close)
- async def invalidate(self):
+ async def invalidate(self) -> None:
"""Close this Session, using connection invalidation.
For a complete description, see :meth:`_orm.Session.invalidate`.
"""
- return await greenlet_spawn(self.sync_session.invalidate)
+ await greenlet_spawn(self.sync_session.invalidate)
@classmethod
- async def close_all(self):
+ async def close_all(self) -> None:
"""Close all :class:`_asyncio.AsyncSession` sessions."""
- return await greenlet_spawn(self.sync_session.close_all)
+ await greenlet_spawn(self.sync_session.close_all)
- async def __aenter__(self):
+ async def __aenter__(self) -> AsyncSession:
return self
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await self.close()
- def _maker_context_manager(self):
+ def _maker_context_manager(self) -> _AsyncSessionContextManager:
# TODO: can this use asynccontextmanager ??
return _AsyncSessionContextManager(self)
@@ -1142,21 +1227,159 @@ class AsyncSession(ReversibleProxy):
# END PROXY METHODS AsyncSession
+class async_sessionmaker:
+ """A configurable :class:`.AsyncSession` factory.
+
+ The :class:`.async_sessionmaker` factory works in the same way as the
+ :class:`.sessionmaker` factory, to generate new :class:`.AsyncSession`
+ objects when called, creating them given
+ the configurational arguments established here.
+
+ e.g.::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ from sqlalchemy.ext.asyncio import async_sessionmaker
+
+ async def main():
+ # an AsyncEngine, which the AsyncSession will use for connection
+ # resources
+ engine = create_async_engine('postgresql+asycncpg://scott:tiger@localhost/')
+
+ AsyncSession = async_sessionmaker(engine)
+
+ async with async_session() as session:
+ session.add(some_object)
+ session.add(some_other_object)
+ await session.commit()
+
+ .. versionadded:: 2.0 :class:`.asyncio_sessionmaker` provides a
+ :class:`.sessionmaker` class that's dedicated to the
+ :class:`.AsyncSession` object, including pep-484 typing support.
+
+ .. seealso::
+
+ :ref:`asyncio_orm` - shows example use
+
+ :class:`.sessionmaker` - general overview of the
+ :class:`.sessionmaker` architecture
+
+
+ :ref:`session_getting` - introductory text on creating
+ sessions using :class:`.sessionmaker`.
+
+ """ # noqa E501
+
+ class_: Type[AsyncSession]
+
+ def __init__(
+ self,
+ bind: Optional[_AsyncSessionBind] = None,
+ class_: Type[AsyncSession] = AsyncSession,
+ autoflush: bool = True,
+ expire_on_commit: bool = True,
+ info: Optional[Dict[Any, Any]] = None,
+ **kw: Any,
+ ):
+ r"""Construct a new :class:`.async_sessionmaker`.
+
+ All arguments here except for ``class_`` correspond to arguments
+ accepted by :class:`.Session` directly. See the
+ :meth:`.AsyncSession.__init__` docstring for more details on
+ parameters.
+
+
+ """
+ kw["bind"] = bind
+ kw["autoflush"] = autoflush
+ kw["expire_on_commit"] = expire_on_commit
+ if info is not None:
+ kw["info"] = info
+ self.kw = kw
+ self.class_ = class_
+
+ def begin(self) -> _AsyncSessionContextManager:
+ """Produce a context manager that both provides a new
+ :class:`_orm.AsyncSession` as well as a transaction that commits.
+
+
+ e.g.::
+
+ async def main():
+ Session = async_sessionmaker(some_engine)
+
+ async with Session.begin() as session:
+ session.add(some_object)
+
+ # commits transaction, closes session
+
+
+ """
+
+ session = self()
+ return session._maker_context_manager()
+
+ def __call__(self, **local_kw: Any) -> AsyncSession:
+ """Produce a new :class:`.AsyncSession` object using the configuration
+ established in this :class:`.async_sessionmaker`.
+
+ In Python, the ``__call__`` method is invoked on an object when
+ it is "called" in the same way as a function::
+
+ AsyncSession = async_sessionmaker(async_engine, expire_on_commit=False)
+ session = AsyncSession() # invokes sessionmaker.__call__()
+
+ """ # noqa E501
+ for k, v in self.kw.items():
+ if k == "info" and "info" in local_kw:
+ d = v.copy()
+ d.update(local_kw["info"])
+ local_kw["info"] = d
+ else:
+ local_kw.setdefault(k, v)
+ return self.class_(**local_kw)
+
+ def configure(self, **new_kw: Any) -> None:
+ """(Re)configure the arguments for this async_sessionmaker.
+
+ e.g.::
+
+ AsyncSession = async_sessionmaker(some_engine)
+
+ AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://'))
+ """ # noqa E501
+
+ self.kw.update(new_kw)
+
+ def __repr__(self) -> str:
+ return "%s(class_=%r, %s)" % (
+ self.__class__.__name__,
+ self.class_.__name__,
+ ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()),
+ )
+
+
class _AsyncSessionContextManager:
- def __init__(self, async_session):
+ __slots__ = ("async_session", "trans")
+
+ async_session: AsyncSession
+ trans: AsyncSessionTransaction
+
+ def __init__(self, async_session: AsyncSession):
self.async_session = async_session
- async def __aenter__(self):
+ async def __aenter__(self) -> AsyncSession:
self.trans = self.async_session.begin()
await self.trans.__aenter__()
return self.async_session
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await self.trans.__aexit__(type_, value, traceback)
await self.async_session.__aexit__(type_, value, traceback)
-class AsyncSessionTransaction(ReversibleProxy, StartableContext):
+class AsyncSessionTransaction(
+ ReversibleProxy[SessionTransaction], StartableContext
+):
"""A wrapper for the ORM :class:`_orm.SessionTransaction` object.
This object is provided so that a transaction-holding object
@@ -1174,36 +1397,41 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext):
__slots__ = ("session", "sync_transaction", "nested")
- def __init__(self, session, nested=False):
+ session: AsyncSession
+ sync_transaction: Optional[SessionTransaction]
+
+ def __init__(self, session: AsyncSession, nested: bool = False):
self.session = session
self.nested = nested
self.sync_transaction = None
@property
- def is_active(self):
+ def is_active(self) -> bool:
return (
self._sync_transaction() is not None
and self._sync_transaction().is_active
)
- def _sync_transaction(self):
+ def _sync_transaction(self) -> SessionTransaction:
if not self.sync_transaction:
self._raise_for_not_started()
return self.sync_transaction
- async def rollback(self):
+ async def rollback(self) -> None:
"""Roll back this :class:`_asyncio.AsyncTransaction`."""
await greenlet_spawn(self._sync_transaction().rollback)
- async def commit(self):
+ async def commit(self) -> None:
"""Commit this :class:`_asyncio.AsyncTransaction`."""
await greenlet_spawn(self._sync_transaction().commit)
- async def start(self, is_ctxmanager=False):
+ async def start(
+ self, is_ctxmanager: bool = False
+ ) -> AsyncSessionTransaction:
self.sync_transaction = self._assign_proxied(
await greenlet_spawn(
- self.session.sync_session.begin_nested
+ self.session.sync_session.begin_nested # type: ignore
if self.nested
else self.session.sync_session.begin
)
@@ -1212,13 +1440,13 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext):
self.sync_transaction.__enter__()
return self
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await greenlet_spawn(
self._sync_transaction().__exit__, type_, value, traceback
)
-def async_object_session(instance):
+def async_object_session(instance: object) -> Optional[AsyncSession]:
"""Return the :class:`_asyncio.AsyncSession` to which the given instance
belongs.
@@ -1247,7 +1475,7 @@ def async_object_session(instance):
return None
-def async_session(session: Session) -> AsyncSession:
+def async_session(session: Session) -> Optional[AsyncSession]:
"""Return the :class:`_asyncio.AsyncSession` which is proxying the given
:class:`_orm.Session` object, if any.
@@ -1260,4 +1488,4 @@ def async_session(session: Session) -> AsyncSession:
return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
-_instance_state._async_provider = async_session
+_instance_state._async_provider = async_session # type: ignore