diff options
Diffstat (limited to 'tests/test_op.py')
-rw-r--r-- | tests/test_op.py | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/tests/test_op.py b/tests/test_op.py index 8ae22a0..35adeaf 100644 --- a/tests/test_op.py +++ b/tests/test_op.py @@ -1,5 +1,8 @@ """Test against the builders in the op.* module.""" +from unittest.mock import MagicMock +from unittest.mock import patch + from sqlalchemy import Boolean from sqlalchemy import CheckConstraint from sqlalchemy import Column @@ -30,6 +33,7 @@ from alembic.testing import eq_ from alembic.testing import expect_warnings from alembic.testing import is_not_ from alembic.testing import mock +from alembic.testing.assertions import expect_raises_message from alembic.testing.fixtures import op_fixture from alembic.testing.fixtures import TestBase from alembic.util import sqla_compat @@ -1156,6 +1160,46 @@ class OpTest(TestBase): ("after_drop", "tb_test"), ] + @config.requirements.sqlalchemy_14 + def test_run_async_error(self): + op_fixture() + + async def go(conn): + pass + + with expect_raises_message( + NotImplementedError, "SQLAlchemy 1.4.18. required" + ): + with patch.object(sqla_compat, "sqla_14_18", False): + op.run_async(go) + with expect_raises_message( + NotImplementedError, "Cannot call run_async in SQL mode" + ): + with patch.object(op._proxy, "get_bind", lambda: None): + op.run_async(go) + with expect_raises_message( + ValueError, "Cannot call run_async with a sync engine" + ): + op.run_async(go) + + @config.requirements.sqlalchemy_14 + def test_run_async_ok(self): + from sqlalchemy.ext.asyncio import AsyncConnection + + op_fixture() + conn = op.get_bind() + mock_conn = MagicMock() + mock_fn = MagicMock() + with patch.object(conn.dialect, "is_async", True), patch.object( + AsyncConnection, "_retrieve_proxy_for_target", mock_conn + ), patch("sqlalchemy.util.await_only") as mock_await: + res = op.run_async(mock_fn, 99, foo=42) + + eq_(res, mock_await.return_value) + mock_conn.assert_called_once_with(conn) + mock_await.assert_called_once_with(mock_fn.return_value) + mock_fn.assert_called_once_with(mock_conn.return_value, 99, foo=42) + class SQLModeOpTest(TestBase): def test_auto_literals(self): |