diff options
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/engine.py')
-rw-r--r-- | lib/sqlalchemy/ext/asyncio/engine.py | 461 |
1 files changed, 461 insertions, 0 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py new file mode 100644 index 000000000..2d9198d16 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -0,0 +1,461 @@ +from typing import Any +from typing import Callable +from typing import Mapping +from typing import Optional + +from . import exc as async_exc +from .base import StartableContext +from .result import AsyncResult +from ... import exc +from ... import util +from ...engine import Connection +from ...engine import create_engine as _create_engine +from ...engine import Engine +from ...engine import Result +from ...engine import Transaction +from ...engine.base import OptionEngineMixin +from ...sql import Executable +from ...util.concurrency import greenlet_spawn + + +def create_async_engine(*arg, **kw): + """Create a new async engine instance. + + Arguments passed to :func:`_asyncio.create_async_engine` are mostly + identical to those passed to the :func:`_sa.create_engine` function. + The specified dialect must be an asyncio-compatible dialect + such as :ref:`dialect-postgresql-asyncpg`. + + .. versionadded:: 1.4 + + """ + + if kw.get("server_side_cursors", False): + raise exc.AsyncMethodRequired( + "Can't set server_side_cursors for async engine globally; " + "use the connection.stream() method for an async " + "streaming result set" + ) + kw["future"] = True + sync_engine = _create_engine(*arg, **kw) + return AsyncEngine(sync_engine) + + +class AsyncConnection(StartableContext): + """An asyncio proxy for a :class:`_engine.Connection`. + + :class:`_asyncio.AsyncConnection` is acquired using the + :meth:`_asyncio.AsyncEngine.connect` + method of :class:`_asyncio.AsyncEngine`:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") + + async with engine.connect() as conn: + result = await conn.execute(select(table)) + + .. versionadded:: 1.4 + + """ # noqa + + __slots__ = ( + "sync_engine", + "sync_connection", + ) + + def __init__( + self, sync_engine: Engine, sync_connection: Optional[Connection] = None + ): + self.sync_engine = sync_engine + self.sync_connection = sync_connection + + async def start(self): + """Start this :class:`_asyncio.AsyncConnection` object's context + outside of using a Python ``with:`` block. + + """ + if self.sync_connection: + raise exc.InvalidRequestError("connection is already started") + self.sync_connection = await (greenlet_spawn(self.sync_engine.connect)) + return self + + def _sync_connection(self): + if not self.sync_connection: + self._raise_for_not_started() + return self.sync_connection + + def begin(self) -> "AsyncTransaction": + """Begin a transaction prior to autobegin occurring. + + """ + self._sync_connection() + return AsyncTransaction(self) + + def begin_nested(self) -> "AsyncTransaction": + """Begin a nested transaction and return a transaction handle. + + """ + self._sync_connection() + return AsyncTransaction(self, nested=True) + + async def commit(self): + """Commit the transaction that is currently in progress. + + This method commits the current transaction if one has been started. + If no transaction was started, the method has no effect, assuming + the connection is in a non-invalidated state. + + A transaction is begun on a :class:`_future.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_future.Connection.begin` method is called. + + """ + conn = self._sync_connection() + await greenlet_spawn(conn.commit) + + async def rollback(self): + """Roll back the transaction that is currently in progress. + + This method rolls back the current transaction if one has been started. + If no transaction was started, the method has no effect. If a + transaction was started and the connection is in an invalidated state, + the transaction is cleared using this method. + + A transaction is begun on a :class:`_future.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_future.Connection.begin` method is called. + + + """ + conn = self._sync_connection() + await greenlet_spawn(conn.rollback) + + async def close(self): + """Close this :class:`_asyncio.AsyncConnection`. + + This has the effect of also rolling back the transaction if one + is in place. + + """ + conn = self._sync_connection() + await greenlet_spawn(conn.close) + + async def exec_driver_sql( + self, + statement: Executable, + parameters: Optional[Mapping] = None, + execution_options: Mapping = util.EMPTY_DICT, + ) -> Result: + r"""Executes a driver-level SQL string and return buffered + :class:`_engine.Result`. + + """ + + conn = self._sync_connection() + + result = await greenlet_spawn( + conn.exec_driver_sql, statement, parameters, execution_options, + ) + if result.context._is_server_side: + raise async_exc.AsyncMethodRequired( + "Can't use the connection.exec_driver_sql() method with a " + "server-side cursor." + "Use the connection.stream() method for an async " + "streaming result set." + ) + + return result + + async def stream( + self, + statement: Executable, + parameters: Optional[Mapping] = None, + execution_options: Mapping = util.EMPTY_DICT, + ) -> AsyncResult: + """Execute a statement and return a streaming + :class:`_asyncio.AsyncResult` object.""" + + conn = self._sync_connection() + + result = await greenlet_spawn( + conn._execute_20, + statement, + parameters, + util.EMPTY_DICT.merge_with( + execution_options, {"stream_results": True} + ), + ) + if not result.context._is_server_side: + # TODO: real exception here + assert False, "server side result expected" + return AsyncResult(result) + + async def execute( + self, + statement: Executable, + parameters: Optional[Mapping] = None, + execution_options: Mapping = util.EMPTY_DICT, + ) -> Result: + r"""Executes a SQL statement construct and return a buffered + :class:`_engine.Result`. + + :param object: The statement to be executed. This is always + an object that is in both the :class:`_expression.ClauseElement` and + :class:`_expression.Executable` hierarchies, including: + + * :class:`_expression.Select` + * :class:`_expression.Insert`, :class:`_expression.Update`, + :class:`_expression.Delete` + * :class:`_expression.TextClause` and + :class:`_expression.TextualSelect` + * :class:`_schema.DDL` and objects which inherit from + :class:`_schema.DDLElement` + + :param parameters: parameters which will be bound into the statement. + This may be either a dictionary of parameter names to values, + or a mutable sequence (e.g. a list) of dictionaries. When a + list of dictionaries is passed, the underlying statement execution + will make use of the DBAPI ``cursor.executemany()`` method. + When a single dictionary is passed, the DBAPI ``cursor.execute()`` + method will be used. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. This + dictionary can provide a subset of the options that are accepted + by :meth:`_future.Connection.execution_options`. + + :return: a :class:`_engine.Result` object. + + """ + conn = self._sync_connection() + + result = await greenlet_spawn( + conn._execute_20, statement, parameters, execution_options, + ) + if result.context._is_server_side: + raise async_exc.AsyncMethodRequired( + "Can't use the connection.execute() method with a " + "server-side cursor." + "Use the connection.stream() method for an async " + "streaming result set." + ) + return result + + async def scalar( + self, + statement: Executable, + parameters: Optional[Mapping] = None, + execution_options: Mapping = util.EMPTY_DICT, + ) -> Any: + r"""Executes a SQL statement construct and returns a scalar object. + + This method is shorthand for invoking the + :meth:`_engine.Result.scalar` method after invoking the + :meth:`_future.Connection.execute` method. Parameters are equivalent. + + :return: a scalar Python value representing the first column of the + first row returned. + + """ + result = await self.execute(statement, parameters, execution_options) + return result.scalar() + + async def run_sync(self, fn: Callable, *arg, **kw) -> Any: + """"Invoke the given sync callable passing self as the first argument. + + This method maintains the asyncio event loop all the way through + to the database connection by running the given callable in a + specially instrumented greenlet. + + E.g.:: + + with async_engine.begin() as conn: + await conn.run_sync(metadata.create_all) + + """ + + conn = self._sync_connection() + + return await greenlet_spawn(fn, conn, *arg, **kw) + + def __await__(self): + return self.start().__await__() + + async def __aexit__(self, type_, value, traceback): + await self.close() + + +class AsyncEngine: + """An asyncio proxy for a :class:`_engine.Engine`. + + :class:`_asyncio.AsyncEngine` is acquired using the + :func:`_asyncio.create_async_engine` function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") + + .. versionadded:: 1.4 + + + """ # noqa + + __slots__ = ("sync_engine",) + + _connection_cls = AsyncConnection + + _option_cls: type + + class _trans_ctx(StartableContext): + def __init__(self, conn): + self.conn = conn + + async def start(self): + await self.conn.start() + self.transaction = self.conn.begin() + await self.transaction.__aenter__() + + return self.conn + + async def __aexit__(self, type_, value, traceback): + if type_ is not None: + await self.transaction.rollback() + else: + if self.transaction.is_active: + await self.transaction.commit() + await self.conn.close() + + def __init__(self, sync_engine: Engine): + self.sync_engine = sync_engine + + def begin(self): + """Return a context manager which when entered will deliver an + :class:`_asyncio.AsyncConnection` with an + :class:`_asyncio.AsyncTransaction` established. + + E.g.:: + + async with async_engine.begin() as conn: + await conn.execute( + text("insert into table (x, y, z) values (1, 2, 3)") + ) + await conn.execute(text("my_special_procedure(5)")) + + + """ + conn = self.connect() + return self._trans_ctx(conn) + + def connect(self) -> AsyncConnection: + """Return an :class:`_asyncio.AsyncConnection` object. + + The :class:`_asyncio.AsyncConnection` will procure a database + connection from the underlying connection pool when it is entered + as an async context manager:: + + async with async_engine.connect() as conn: + result = await conn.execute(select(user_table)) + + The :class:`_asyncio.AsyncConnection` may also be started outside of a + context manager by invoking its :meth:`_asyncio.AsyncConnection.start` + method. + + """ + + return self._connection_cls(self.sync_engine) + + async def raw_connection(self) -> Any: + """Return a "raw" DBAPI connection from the connection pool. + + .. seealso:: + + :ref:`dbapi_connections` + + """ + return await greenlet_spawn(self.sync_engine.raw_connection) + + +class AsyncOptionEngine(OptionEngineMixin, AsyncEngine): + pass + + +AsyncEngine._option_cls = AsyncOptionEngine + + +class AsyncTransaction(StartableContext): + """An asyncio proxy for a :class:`_engine.Transaction`.""" + + __slots__ = ("connection", "sync_transaction", "nested") + + def __init__(self, connection: AsyncConnection, nested: bool = False): + self.connection = connection + self.sync_transaction: Optional[Transaction] = None + self.nested = nested + + def _sync_transaction(self): + if not self.sync_transaction: + self._raise_for_not_started() + return self.sync_transaction + + @property + def is_valid(self) -> bool: + return self._sync_transaction().is_valid + + @property + def is_active(self) -> bool: + return self._sync_transaction().is_active + + async def close(self): + """Close this :class:`.Transaction`. + + If this transaction is the base transaction in a begin/commit + nesting, the transaction will rollback(). Otherwise, the + method returns. + + This is used to cancel a Transaction without affecting the scope of + an enclosing transaction. + + """ + await greenlet_spawn(self._sync_transaction().close) + + async def rollback(self): + """Roll back this :class:`.Transaction`. + + """ + await greenlet_spawn(self._sync_transaction().rollback) + + async def commit(self): + """Commit this :class:`.Transaction`.""" + + await greenlet_spawn(self._sync_transaction().commit) + + async def start(self): + """Start this :class:`_asyncio.AsyncTransaction` object's context + outside of using a Python ``with:`` block. + + """ + + self.sync_transaction = await greenlet_spawn( + self.connection._sync_connection().begin_nested + if self.nested + else self.connection._sync_connection().begin + ) + return self + + async def __aexit__(self, type_, value, traceback): + if type_ is None and self.is_active: + try: + await self.commit() + except: + with util.safe_reraise(): + await self.rollback() + else: + await self.rollback() + + +def _get_sync_engine(async_engine): + try: + return async_engine.sync_engine + except AttributeError as e: + raise exc.ArgumentError( + "AsyncEngine expected, got %r" % async_engine + ) from e |