summaryrefslogtreecommitdiff
path: root/alembic
diff options
context:
space:
mode:
authorFederico Caselli <cfederico87@gmail.com>2023-04-29 23:25:21 +0200
committerMike Bayer <mike_mp@zzzcomputing.com>2023-05-12 12:06:14 -0400
commit6d2df6c92c981d220bd1481ebdc7d86274ccf845 (patch)
treebe009a39f769dd4cff23d1fd917f651adeccee72 /alembic
parent92e54a0e1c96cecd99397cb1aee9c3bb28f780c6 (diff)
downloadalembic-6d2df6c92c981d220bd1481ebdc7d86274ccf845.tar.gz
Added ``op.run_async``.
Added :meth:`.Operations.run_async` to the operation module to allow running async functions in the ``upgrade`` or ``downgrade`` migration function when running alembic using an async dialect. This function will receive as first argument an class:`~sqlalchemy.ext.asyncio.AsyncConnection` sharing the transaction used in the migration context. also restore the .execute() method to BatchOperations Fixes: #1231 Change-Id: I3c3237d570be3c9bd9834e4c61bb3231bfb82765
Diffstat (limited to 'alembic')
-rw-r--r--alembic/op.pyi29
-rw-r--r--alembic/operations/base.py45
-rw-r--r--alembic/operations/ops.py1
-rw-r--r--alembic/util/sqla_compat.py2
4 files changed, 75 insertions, 2 deletions
diff --git a/alembic/op.pyi b/alembic/op.pyi
index aa3ad2d..4395f77 100644
--- a/alembic/op.pyi
+++ b/alembic/op.pyi
@@ -4,6 +4,7 @@ from __future__ import annotations
from contextlib import contextmanager
from typing import Any
+from typing import Awaitable
from typing import Callable
from typing import Dict
from typing import Iterator
@@ -15,6 +16,7 @@ from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from sqlalchemy.sql.expression import TableClause
@@ -38,6 +40,8 @@ if TYPE_CHECKING:
from .operations.ops import MigrateOperation
from .runtime.migration import MigrationContext
from .util.sqla_compat import _literal_bindparam
+
+_T = TypeVar("_T")
### end imports ###
def add_column(
@@ -1238,3 +1242,28 @@ def rename_table(
:class:`~sqlalchemy.sql.elements.quoted_name`.
"""
+
+def run_async(
+ async_function: Callable[..., Awaitable[_T]], *args: Any, **kw_args: Any
+) -> _T:
+ """Invoke the given asynchronous callable, passing an asynchronous
+ :class:`~sqlalchemy.ext.asyncio.AsyncConnection` as the first
+ argument.
+
+ This method allows calling async functions from within the
+ synchronous ``upgrade()`` or ``downgrade()`` alembic migration
+ method.
+
+ The async connection passed to the callable shares the same
+ transaction as the connection running in the migration context.
+
+ Any additional arg or kw_arg passed to this function are passed
+ to the provided async function.
+
+ .. versionadded: 1.11
+
+ .. note::
+
+ This method can be called only when alembic is called using
+ an async dialect.
+ """
diff --git a/alembic/operations/base.py b/alembic/operations/base.py
index 6e45a11..b4190dc 100644
--- a/alembic/operations/base.py
+++ b/alembic/operations/base.py
@@ -4,6 +4,7 @@ from contextlib import contextmanager
import re
import textwrap
from typing import Any
+from typing import Awaitable
from typing import Callable
from typing import Dict
from typing import Iterator
@@ -14,6 +15,7 @@ from typing import Sequence # noqa
from typing import Tuple
from typing import Type # noqa
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from sqlalchemy.sql.elements import conv
@@ -28,8 +30,6 @@ from ..util.compat import inspect_getfullargspec
from ..util.sqla_compat import _literal_bindparam
-NoneType = type(None)
-
if TYPE_CHECKING:
from typing import Literal
@@ -51,6 +51,7 @@ if TYPE_CHECKING:
from ..ddl import DefaultImpl
from ..runtime.migration import MigrationContext
__all__ = ("Operations", "BatchOperations")
+_T = TypeVar("_T")
class AbstractOperations(util.ModuleClsProxy):
@@ -483,6 +484,46 @@ class AbstractOperations(util.ModuleClsProxy):
"""
return self.migration_context.impl.bind # type: ignore[return-value]
+ def run_async(
+ self,
+ async_function: Callable[..., Awaitable[_T]],
+ *args: Any,
+ **kw_args: Any,
+ ) -> _T:
+ """Invoke the given asynchronous callable, passing an asynchronous
+ :class:`~sqlalchemy.ext.asyncio.AsyncConnection` as the first
+ argument.
+
+ This method allows calling async functions from within the
+ synchronous ``upgrade()`` or ``downgrade()`` alembic migration
+ method.
+
+ The async connection passed to the callable shares the same
+ transaction as the connection running in the migration context.
+
+ Any additional arg or kw_arg passed to this function are passed
+ to the provided async function.
+
+ .. versionadded: 1.11
+
+ .. note::
+
+ This method can be called only when alembic is called using
+ an async dialect.
+ """
+ if not sqla_compat.sqla_14_18:
+ raise NotImplementedError("SQLAlchemy 1.4.18+ required")
+ sync_conn = self.get_bind()
+ if sync_conn is None:
+ raise NotImplementedError("Cannot call run_async in SQL mode")
+ if not sync_conn.dialect.is_async:
+ raise ValueError("Cannot call run_async with a sync engine")
+ from sqlalchemy.ext.asyncio import AsyncConnection
+ from sqlalchemy.util import await_only
+
+ async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn)
+ return await_only(async_function(async_conn, *args, **kw_args))
+
class Operations(AbstractOperations):
"""Define high level migration operations.
diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py
index 99d21d9..3a002c1 100644
--- a/alembic/operations/ops.py
+++ b/alembic/operations/ops.py
@@ -2375,6 +2375,7 @@ class BulkInsertOp(MigrateOperation):
@Operations.register_operation("execute")
+@BatchOperations.register_operation("execute")
class ExecuteSQLOp(MigrateOperation):
"""Represent an execute SQL operation."""
diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py
index 0070337..37e1ee1 100644
--- a/alembic/util/sqla_compat.py
+++ b/alembic/util/sqla_compat.py
@@ -61,6 +61,8 @@ _vers = tuple(
)
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)
+# https://docs.sqlalchemy.org/en/latest/changelog/changelog_14.html#change-0c6e0cc67dfe6fac5164720e57ef307d
+sqla_14_18 = _vers >= (1, 4, 18)
sqla_14_26 = _vers >= (1, 4, 26)
sqla_2 = _vers >= (2,)
sqlalchemy_version = __version__