summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_autogen_render.py28
-rw-r--r--tests/test_op.py44
2 files changed, 72 insertions, 0 deletions
diff --git a/tests/test_autogen_render.py b/tests/test_autogen_render.py
index 14a0001..c4f474e 100644
--- a/tests/test_autogen_render.py
+++ b/tests/test_autogen_render.py
@@ -1636,6 +1636,34 @@ class AutogenRenderTest(TestBase):
"sa.ARRAY(sa.DateTime(timezone=True))",
)
+ @config.combinations(None, "default", "my_module.")
+ def test_render_create_table_with_user_module_type(self, mod):
+ class MyType(UserDefinedType):
+ def get_col_spec(self):
+ return "MYTYPE"
+
+ type_ = MyType()
+ uo = ops.UpgradeOps(
+ ops=[ops.CreateTableOp("sometable", [Column("x", type_)])]
+ )
+ if mod != "default":
+ kw = {"user_module_prefix": mod}
+ else:
+ kw = {}
+ if mod and mod != "default":
+ prefix = mod
+ else:
+ prefix = f"{__name__}."
+
+ eq_(
+ autogenerate.render_python_code(uo, **kw),
+ "# ### commands auto generated by Alembic - please adjust! ###\n"
+ " op.create_table('sometable',\n"
+ f" sa.Column('x', {prefix}MyType(), nullable=True)\n"
+ " )\n"
+ " # ### end Alembic commands ###",
+ )
+
def test_render_array_no_context(self):
uo = ops.UpgradeOps(
ops=[
diff --git a/tests/test_op.py b/tests/test_op.py
index 8ae22a0..35adeaf 100644
--- a/tests/test_op.py
+++ b/tests/test_op.py
@@ -1,5 +1,8 @@
"""Test against the builders in the op.* module."""
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
from sqlalchemy import Boolean
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
@@ -30,6 +33,7 @@ from alembic.testing import eq_
from alembic.testing import expect_warnings
from alembic.testing import is_not_
from alembic.testing import mock
+from alembic.testing.assertions import expect_raises_message
from alembic.testing.fixtures import op_fixture
from alembic.testing.fixtures import TestBase
from alembic.util import sqla_compat
@@ -1156,6 +1160,46 @@ class OpTest(TestBase):
("after_drop", "tb_test"),
]
+ @config.requirements.sqlalchemy_14
+ def test_run_async_error(self):
+ op_fixture()
+
+ async def go(conn):
+ pass
+
+ with expect_raises_message(
+ NotImplementedError, "SQLAlchemy 1.4.18. required"
+ ):
+ with patch.object(sqla_compat, "sqla_14_18", False):
+ op.run_async(go)
+ with expect_raises_message(
+ NotImplementedError, "Cannot call run_async in SQL mode"
+ ):
+ with patch.object(op._proxy, "get_bind", lambda: None):
+ op.run_async(go)
+ with expect_raises_message(
+ ValueError, "Cannot call run_async with a sync engine"
+ ):
+ op.run_async(go)
+
+ @config.requirements.sqlalchemy_14
+ def test_run_async_ok(self):
+ from sqlalchemy.ext.asyncio import AsyncConnection
+
+ op_fixture()
+ conn = op.get_bind()
+ mock_conn = MagicMock()
+ mock_fn = MagicMock()
+ with patch.object(conn.dialect, "is_async", True), patch.object(
+ AsyncConnection, "_retrieve_proxy_for_target", mock_conn
+ ), patch("sqlalchemy.util.await_only") as mock_await:
+ res = op.run_async(mock_fn, 99, foo=42)
+
+ eq_(res, mock_await.return_value)
+ mock_conn.assert_called_once_with(conn)
+ mock_await.assert_called_once_with(mock_fn.return_value)
+ mock_fn.assert_called_once_with(mock_conn.return_value, 99, foo=42)
+
class SQLModeOpTest(TestBase):
def test_auto_literals(self):