summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/asyncio/engine.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/engine.py')
-rw-r--r--lib/sqlalchemy/ext/asyncio/engine.py461
1 files changed, 461 insertions, 0 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
new file mode 100644
index 000000000..2d9198d16
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -0,0 +1,461 @@
+from typing import Any
+from typing import Callable
+from typing import Mapping
+from typing import Optional
+
+from . import exc as async_exc
+from .base import StartableContext
+from .result import AsyncResult
+from ... import exc
+from ... import util
+from ...engine import Connection
+from ...engine import create_engine as _create_engine
+from ...engine import Engine
+from ...engine import Result
+from ...engine import Transaction
+from ...engine.base import OptionEngineMixin
+from ...sql import Executable
+from ...util.concurrency import greenlet_spawn
+
+
+def create_async_engine(*arg, **kw):
+ """Create a new async engine instance.
+
+ Arguments passed to :func:`_asyncio.create_async_engine` are mostly
+ identical to those passed to the :func:`_sa.create_engine` function.
+ The specified dialect must be an asyncio-compatible dialect
+ such as :ref:`dialect-postgresql-asyncpg`.
+
+ .. versionadded:: 1.4
+
+ """
+
+ if kw.get("server_side_cursors", False):
+ raise exc.AsyncMethodRequired(
+ "Can't set server_side_cursors for async engine globally; "
+ "use the connection.stream() method for an async "
+ "streaming result set"
+ )
+ kw["future"] = True
+ sync_engine = _create_engine(*arg, **kw)
+ return AsyncEngine(sync_engine)
+
+
+class AsyncConnection(StartableContext):
+ """An asyncio proxy for a :class:`_engine.Connection`.
+
+ :class:`_asyncio.AsyncConnection` is acquired using the
+ :meth:`_asyncio.AsyncEngine.connect`
+ method of :class:`_asyncio.AsyncEngine`::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
+
+ async with engine.connect() as conn:
+ result = await conn.execute(select(table))
+
+ .. versionadded:: 1.4
+
+ """ # noqa
+
+ __slots__ = (
+ "sync_engine",
+ "sync_connection",
+ )
+
+ def __init__(
+ self, sync_engine: Engine, sync_connection: Optional[Connection] = None
+ ):
+ self.sync_engine = sync_engine
+ self.sync_connection = sync_connection
+
+ async def start(self):
+ """Start this :class:`_asyncio.AsyncConnection` object's context
+ outside of using a Python ``with:`` block.
+
+ """
+ if self.sync_connection:
+ raise exc.InvalidRequestError("connection is already started")
+ self.sync_connection = await (greenlet_spawn(self.sync_engine.connect))
+ return self
+
+ def _sync_connection(self):
+ if not self.sync_connection:
+ self._raise_for_not_started()
+ return self.sync_connection
+
+ def begin(self) -> "AsyncTransaction":
+ """Begin a transaction prior to autobegin occurring.
+
+ """
+ self._sync_connection()
+ return AsyncTransaction(self)
+
+ def begin_nested(self) -> "AsyncTransaction":
+ """Begin a nested transaction and return a transaction handle.
+
+ """
+ self._sync_connection()
+ return AsyncTransaction(self, nested=True)
+
+ async def commit(self):
+ """Commit the transaction that is currently in progress.
+
+ This method commits the current transaction if one has been started.
+ If no transaction was started, the method has no effect, assuming
+ the connection is in a non-invalidated state.
+
+ A transaction is begun on a :class:`_future.Connection` automatically
+ whenever a statement is first executed, or when the
+ :meth:`_future.Connection.begin` method is called.
+
+ """
+ conn = self._sync_connection()
+ await greenlet_spawn(conn.commit)
+
+ async def rollback(self):
+ """Roll back the transaction that is currently in progress.
+
+ This method rolls back the current transaction if one has been started.
+ If no transaction was started, the method has no effect. If a
+ transaction was started and the connection is in an invalidated state,
+ the transaction is cleared using this method.
+
+ A transaction is begun on a :class:`_future.Connection` automatically
+ whenever a statement is first executed, or when the
+ :meth:`_future.Connection.begin` method is called.
+
+
+ """
+ conn = self._sync_connection()
+ await greenlet_spawn(conn.rollback)
+
+ async def close(self):
+ """Close this :class:`_asyncio.AsyncConnection`.
+
+ This has the effect of also rolling back the transaction if one
+ is in place.
+
+ """
+ conn = self._sync_connection()
+ await greenlet_spawn(conn.close)
+
+ async def exec_driver_sql(
+ self,
+ statement: Executable,
+ parameters: Optional[Mapping] = None,
+ execution_options: Mapping = util.EMPTY_DICT,
+ ) -> Result:
+ r"""Executes a driver-level SQL string and return buffered
+ :class:`_engine.Result`.
+
+ """
+
+ conn = self._sync_connection()
+
+ result = await greenlet_spawn(
+ conn.exec_driver_sql, statement, parameters, execution_options,
+ )
+ if result.context._is_server_side:
+ raise async_exc.AsyncMethodRequired(
+ "Can't use the connection.exec_driver_sql() method with a "
+ "server-side cursor."
+ "Use the connection.stream() method for an async "
+ "streaming result set."
+ )
+
+ return result
+
+ async def stream(
+ self,
+ statement: Executable,
+ parameters: Optional[Mapping] = None,
+ execution_options: Mapping = util.EMPTY_DICT,
+ ) -> AsyncResult:
+ """Execute a statement and return a streaming
+ :class:`_asyncio.AsyncResult` object."""
+
+ conn = self._sync_connection()
+
+ result = await greenlet_spawn(
+ conn._execute_20,
+ statement,
+ parameters,
+ util.EMPTY_DICT.merge_with(
+ execution_options, {"stream_results": True}
+ ),
+ )
+ if not result.context._is_server_side:
+ # TODO: real exception here
+ assert False, "server side result expected"
+ return AsyncResult(result)
+
+ async def execute(
+ self,
+ statement: Executable,
+ parameters: Optional[Mapping] = None,
+ execution_options: Mapping = util.EMPTY_DICT,
+ ) -> Result:
+ r"""Executes a SQL statement construct and return a buffered
+ :class:`_engine.Result`.
+
+ :param object: The statement to be executed. This is always
+ an object that is in both the :class:`_expression.ClauseElement` and
+ :class:`_expression.Executable` hierarchies, including:
+
+ * :class:`_expression.Select`
+ * :class:`_expression.Insert`, :class:`_expression.Update`,
+ :class:`_expression.Delete`
+ * :class:`_expression.TextClause` and
+ :class:`_expression.TextualSelect`
+ * :class:`_schema.DDL` and objects which inherit from
+ :class:`_schema.DDLElement`
+
+ :param parameters: parameters which will be bound into the statement.
+ This may be either a dictionary of parameter names to values,
+ or a mutable sequence (e.g. a list) of dictionaries. When a
+ list of dictionaries is passed, the underlying statement execution
+ will make use of the DBAPI ``cursor.executemany()`` method.
+ When a single dictionary is passed, the DBAPI ``cursor.execute()``
+ method will be used.
+
+ :param execution_options: optional dictionary of execution options,
+ which will be associated with the statement execution. This
+ dictionary can provide a subset of the options that are accepted
+ by :meth:`_future.Connection.execution_options`.
+
+ :return: a :class:`_engine.Result` object.
+
+ """
+ conn = self._sync_connection()
+
+ result = await greenlet_spawn(
+ conn._execute_20, statement, parameters, execution_options,
+ )
+ if result.context._is_server_side:
+ raise async_exc.AsyncMethodRequired(
+ "Can't use the connection.execute() method with a "
+ "server-side cursor."
+ "Use the connection.stream() method for an async "
+ "streaming result set."
+ )
+ return result
+
+ async def scalar(
+ self,
+ statement: Executable,
+ parameters: Optional[Mapping] = None,
+ execution_options: Mapping = util.EMPTY_DICT,
+ ) -> Any:
+ r"""Executes a SQL statement construct and returns a scalar object.
+
+ This method is shorthand for invoking the
+ :meth:`_engine.Result.scalar` method after invoking the
+ :meth:`_future.Connection.execute` method. Parameters are equivalent.
+
+ :return: a scalar Python value representing the first column of the
+ first row returned.
+
+ """
+ result = await self.execute(statement, parameters, execution_options)
+ return result.scalar()
+
+ async def run_sync(self, fn: Callable, *arg, **kw) -> Any:
+ """"Invoke the given sync callable passing self as the first argument.
+
+ This method maintains the asyncio event loop all the way through
+ to the database connection by running the given callable in a
+ specially instrumented greenlet.
+
+ E.g.::
+
+ with async_engine.begin() as conn:
+ await conn.run_sync(metadata.create_all)
+
+ """
+
+ conn = self._sync_connection()
+
+ return await greenlet_spawn(fn, conn, *arg, **kw)
+
+ def __await__(self):
+ return self.start().__await__()
+
+ async def __aexit__(self, type_, value, traceback):
+ await self.close()
+
+
+class AsyncEngine:
+ """An asyncio proxy for a :class:`_engine.Engine`.
+
+ :class:`_asyncio.AsyncEngine` is acquired using the
+ :func:`_asyncio.create_async_engine` function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
+
+ .. versionadded:: 1.4
+
+
+ """ # noqa
+
+ __slots__ = ("sync_engine",)
+
+ _connection_cls = AsyncConnection
+
+ _option_cls: type
+
+ class _trans_ctx(StartableContext):
+ def __init__(self, conn):
+ self.conn = conn
+
+ async def start(self):
+ await self.conn.start()
+ self.transaction = self.conn.begin()
+ await self.transaction.__aenter__()
+
+ return self.conn
+
+ async def __aexit__(self, type_, value, traceback):
+ if type_ is not None:
+ await self.transaction.rollback()
+ else:
+ if self.transaction.is_active:
+ await self.transaction.commit()
+ await self.conn.close()
+
+ def __init__(self, sync_engine: Engine):
+ self.sync_engine = sync_engine
+
+ def begin(self):
+ """Return a context manager which when entered will deliver an
+ :class:`_asyncio.AsyncConnection` with an
+ :class:`_asyncio.AsyncTransaction` established.
+
+ E.g.::
+
+ async with async_engine.begin() as conn:
+ await conn.execute(
+ text("insert into table (x, y, z) values (1, 2, 3)")
+ )
+ await conn.execute(text("my_special_procedure(5)"))
+
+
+ """
+ conn = self.connect()
+ return self._trans_ctx(conn)
+
+ def connect(self) -> AsyncConnection:
+ """Return an :class:`_asyncio.AsyncConnection` object.
+
+ The :class:`_asyncio.AsyncConnection` will procure a database
+ connection from the underlying connection pool when it is entered
+ as an async context manager::
+
+ async with async_engine.connect() as conn:
+ result = await conn.execute(select(user_table))
+
+ The :class:`_asyncio.AsyncConnection` may also be started outside of a
+ context manager by invoking its :meth:`_asyncio.AsyncConnection.start`
+ method.
+
+ """
+
+ return self._connection_cls(self.sync_engine)
+
+ async def raw_connection(self) -> Any:
+ """Return a "raw" DBAPI connection from the connection pool.
+
+ .. seealso::
+
+ :ref:`dbapi_connections`
+
+ """
+ return await greenlet_spawn(self.sync_engine.raw_connection)
+
+
+class AsyncOptionEngine(OptionEngineMixin, AsyncEngine):
+ pass
+
+
+AsyncEngine._option_cls = AsyncOptionEngine
+
+
+class AsyncTransaction(StartableContext):
+ """An asyncio proxy for a :class:`_engine.Transaction`."""
+
+ __slots__ = ("connection", "sync_transaction", "nested")
+
+ def __init__(self, connection: AsyncConnection, nested: bool = False):
+ self.connection = connection
+ self.sync_transaction: Optional[Transaction] = None
+ self.nested = nested
+
+ def _sync_transaction(self):
+ if not self.sync_transaction:
+ self._raise_for_not_started()
+ return self.sync_transaction
+
+ @property
+ def is_valid(self) -> bool:
+ return self._sync_transaction().is_valid
+
+ @property
+ def is_active(self) -> bool:
+ return self._sync_transaction().is_active
+
+ async def close(self):
+ """Close this :class:`.Transaction`.
+
+ If this transaction is the base transaction in a begin/commit
+ nesting, the transaction will rollback(). Otherwise, the
+ method returns.
+
+ This is used to cancel a Transaction without affecting the scope of
+ an enclosing transaction.
+
+ """
+ await greenlet_spawn(self._sync_transaction().close)
+
+ async def rollback(self):
+ """Roll back this :class:`.Transaction`.
+
+ """
+ await greenlet_spawn(self._sync_transaction().rollback)
+
+ async def commit(self):
+ """Commit this :class:`.Transaction`."""
+
+ await greenlet_spawn(self._sync_transaction().commit)
+
+ async def start(self):
+ """Start this :class:`_asyncio.AsyncTransaction` object's context
+ outside of using a Python ``with:`` block.
+
+ """
+
+ self.sync_transaction = await greenlet_spawn(
+ self.connection._sync_connection().begin_nested
+ if self.nested
+ else self.connection._sync_connection().begin
+ )
+ return self
+
+ async def __aexit__(self, type_, value, traceback):
+ if type_ is None and self.is_active:
+ try:
+ await self.commit()
+ except:
+ with util.safe_reraise():
+ await self.rollback()
+ else:
+ await self.rollback()
+
+
+def _get_sync_engine(async_engine):
+ try:
+ return async_engine.sync_engine
+ except AttributeError as e:
+ raise exc.ArgumentError(
+ "AsyncEngine expected, got %r" % async_engine
+ ) from e