diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_autogen_render.py | 28 | ||||
-rw-r--r-- | tests/test_op.py | 44 |
2 files changed, 72 insertions, 0 deletions
diff --git a/tests/test_autogen_render.py b/tests/test_autogen_render.py index 14a0001..c4f474e 100644 --- a/tests/test_autogen_render.py +++ b/tests/test_autogen_render.py @@ -1636,6 +1636,34 @@ class AutogenRenderTest(TestBase): "sa.ARRAY(sa.DateTime(timezone=True))", ) + @config.combinations(None, "default", "my_module.") + def test_render_create_table_with_user_module_type(self, mod): + class MyType(UserDefinedType): + def get_col_spec(self): + return "MYTYPE" + + type_ = MyType() + uo = ops.UpgradeOps( + ops=[ops.CreateTableOp("sometable", [Column("x", type_)])] + ) + if mod != "default": + kw = {"user_module_prefix": mod} + else: + kw = {} + if mod and mod != "default": + prefix = mod + else: + prefix = f"{__name__}." + + eq_( + autogenerate.render_python_code(uo, **kw), + "# ### commands auto generated by Alembic - please adjust! ###\n" + " op.create_table('sometable',\n" + f" sa.Column('x', {prefix}MyType(), nullable=True)\n" + " )\n" + " # ### end Alembic commands ###", + ) + def test_render_array_no_context(self): uo = ops.UpgradeOps( ops=[ diff --git a/tests/test_op.py b/tests/test_op.py index 8ae22a0..35adeaf 100644 --- a/tests/test_op.py +++ b/tests/test_op.py @@ -1,5 +1,8 @@ """Test against the builders in the op.* module.""" +from unittest.mock import MagicMock +from unittest.mock import patch + from sqlalchemy import Boolean from sqlalchemy import CheckConstraint from sqlalchemy import Column @@ -30,6 +33,7 @@ from alembic.testing import eq_ from alembic.testing import expect_warnings from alembic.testing import is_not_ from alembic.testing import mock +from alembic.testing.assertions import expect_raises_message from alembic.testing.fixtures import op_fixture from alembic.testing.fixtures import TestBase from alembic.util import sqla_compat @@ -1156,6 +1160,46 @@ class OpTest(TestBase): ("after_drop", "tb_test"), ] + @config.requirements.sqlalchemy_14 + def test_run_async_error(self): + op_fixture() + + async def go(conn): + pass + + with expect_raises_message( + NotImplementedError, "SQLAlchemy 1.4.18. required" + ): + with patch.object(sqla_compat, "sqla_14_18", False): + op.run_async(go) + with expect_raises_message( + NotImplementedError, "Cannot call run_async in SQL mode" + ): + with patch.object(op._proxy, "get_bind", lambda: None): + op.run_async(go) + with expect_raises_message( + ValueError, "Cannot call run_async with a sync engine" + ): + op.run_async(go) + + @config.requirements.sqlalchemy_14 + def test_run_async_ok(self): + from sqlalchemy.ext.asyncio import AsyncConnection + + op_fixture() + conn = op.get_bind() + mock_conn = MagicMock() + mock_fn = MagicMock() + with patch.object(conn.dialect, "is_async", True), patch.object( + AsyncConnection, "_retrieve_proxy_for_target", mock_conn + ), patch("sqlalchemy.util.await_only") as mock_await: + res = op.run_async(mock_fn, 99, foo=42) + + eq_(res, mock_await.return_value) + mock_conn.assert_called_once_with(conn) + mock_await.assert_called_once_with(mock_fn.return_value) + mock_fn.assert_called_once_with(mock_conn.return_value, 99, foo=42) + class SQLModeOpTest(TestBase): def test_auto_literals(self): |