summaryrefslogtreecommitdiff
path: root/tests/test_batch.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-11-05 11:24:23 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2021-11-05 11:24:23 -0400
commitd5a368ca7dbfe8501632cbacc69f04ccbfde48ae (patch)
tree2184681308d023f30554f0a67e681ab709c766cd /tests/test_batch.py
parentda31584344894c66f35afb507122f4e5b660fc00 (diff)
downloadalembic-d5a368ca7dbfe8501632cbacc69f04ccbfde48ae.tar.gz
sqlalchemy 2.0 test updates
- disable branched connection tests for 2.x - dont use future flag for 2.x - adjust batch tests for autobegin, inconsistent SQLite transactional DDL behaviors Change-Id: I70caf6afecc83f880dc92fa6cbc29e2043c43bb9
Diffstat (limited to 'tests/test_batch.py')
-rw-r--r--tests/test_batch.py56
1 files changed, 31 insertions, 25 deletions
diff --git a/tests/test_batch.py b/tests/test_batch.py
index 2753bdc..700056a 100644
--- a/tests/test_batch.py
+++ b/tests/test_batch.py
@@ -39,6 +39,7 @@ from alembic.testing import mock
from alembic.testing import TestBase
from alembic.testing.fixtures import op_fixture
from alembic.util import exc as alembic_exc
+from alembic.util.sqla_compat import _safe_commit_connection_transaction
from alembic.util.sqla_compat import _select
from alembic.util.sqla_compat import has_computed
from alembic.util.sqla_compat import has_identity
@@ -1282,6 +1283,20 @@ class BatchRoundTripTest(TestBase):
context = MigrationContext.configure(self.conn)
self.op = Operations(context)
+ def tearDown(self):
+ # why commit? because SQLite has inconsistent treatment
+ # of transactional DDL. A test that runs CREATE TABLE and then
+ # ALTER TABLE to change the name of that table, will end up
+ # committing the CREATE TABLE but not the ALTER. As batch mode
+ # does this with a temp table name that's not even in the
+ # metadata collection, we don't have an explicit drop for it
+ # (though we could do that too). calling commit means the
+ # ALTER will go through and the drop_all() will then catch it.
+ _safe_commit_connection_transaction(self.conn)
+ with self.conn.begin():
+ self.metadata.drop_all(self.conn)
+ self.conn.close()
+
@contextmanager
def _sqlite_referential_integrity(self):
self.conn.exec_driver_sql("PRAGMA foreign_keys=ON")
@@ -1385,7 +1400,7 @@ class BatchRoundTripTest(TestBase):
type_=Integer,
existing_type=Boolean(create_constraint=True, name="ck1"),
)
- insp = inspect(config.db)
+ insp = inspect(self.conn)
eq_(
[
@@ -1440,7 +1455,7 @@ class BatchRoundTripTest(TestBase):
batch_op.drop_column(
"x", existing_type=Boolean(create_constraint=True, name="ck1")
)
- insp = inspect(config.db)
+ insp = inspect(self.conn)
assert "x" not in (c["name"] for c in insp.get_columns("hasbool"))
@@ -1450,7 +1465,7 @@ class BatchRoundTripTest(TestBase):
batch_op.alter_column(
"x", type_=Boolean(create_constraint=True, name="ck1")
)
- insp = inspect(config.db)
+ insp = inspect(self.conn)
if exclusions.against(config, "sqlite"):
eq_(
@@ -1471,14 +1486,6 @@ class BatchRoundTripTest(TestBase):
[Integer],
)
- def tearDown(self):
- in_t = getattr(self.conn, "in_transaction", lambda: False)
- if in_t():
- self.conn.rollback()
- with self.conn.begin():
- self.metadata.drop_all(self.conn)
- self.conn.close()
-
def _assert_data(self, data, tablename="foo"):
res = self.conn.execute(text("select * from %s" % tablename))
if sqla_14:
@@ -1492,7 +1499,7 @@ class BatchRoundTripTest(TestBase):
batch_op.alter_column("data", type_=String(30))
batch_op.create_index("ix_data", ["data"])
- insp = inspect(config.db)
+ insp = inspect(self.conn)
eq_(
set(
(ix["name"], tuple(ix["column_names"]))
@@ -1734,7 +1741,7 @@ class BatchRoundTripTest(TestBase):
)
def _assert_table_comment(self, tname, comment):
- insp = inspect(config.db)
+ insp = inspect(self.conn)
tcomment = insp.get_table_comment(tname)
eq_(tcomment, {"text": comment})
@@ -1794,7 +1801,7 @@ class BatchRoundTripTest(TestBase):
self._assert_table_comment("foo", None)
def _assert_column_comment(self, tname, cname, comment):
- insp = inspect(config.db)
+ insp = inspect(self.conn)
cols = {col["name"]: col for col in insp.get_columns(tname)}
eq_(cols[cname]["comment"], comment)
@@ -2037,7 +2044,7 @@ class BatchRoundTripTest(TestBase):
]
)
eq_(
- [col["name"] for col in inspect(config.db).get_columns("foo")],
+ [col["name"] for col in inspect(self.conn).get_columns("foo")],
["id", "data", "x", "data2"],
)
@@ -2063,7 +2070,7 @@ class BatchRoundTripTest(TestBase):
]
)
eq_(
- [col["name"] for col in inspect(config.db).get_columns("foo")],
+ [col["name"] for col in inspect(self.conn).get_columns("foo")],
["id", "data", "x", "data2"],
)
@@ -2084,7 +2091,7 @@ class BatchRoundTripTest(TestBase):
tablename="nopk",
)
eq_(
- [col["name"] for col in inspect(config.db).get_columns("foo")],
+ [col["name"] for col in inspect(self.conn).get_columns("foo")],
["id", "data", "x"],
)
@@ -2104,7 +2111,7 @@ class BatchRoundTripTest(TestBase):
]
)
eq_(
- [col["name"] for col in inspect(config.db).get_columns("foo")],
+ [col["name"] for col in inspect(self.conn).get_columns("foo")],
["id", "data2", "data", "x"],
)
@@ -2124,7 +2131,7 @@ class BatchRoundTripTest(TestBase):
]
)
eq_(
- [col["name"] for col in inspect(config.db).get_columns("foo")],
+ [col["name"] for col in inspect(self.conn).get_columns("foo")],
["id", "data", "data2", "x"],
)
@@ -2158,12 +2165,12 @@ class BatchRoundTripTest(TestBase):
]
)
eq_(
- [col["name"] for col in inspect(config.db).get_columns("foo")],
+ [col["name"] for col in inspect(self.conn).get_columns("foo")],
["id", "data", "x", "data2"],
)
def test_create_drop_index(self):
- insp = inspect(config.db)
+ insp = inspect(self.conn)
eq_(insp.get_indexes("foo"), [])
with self.op.batch_alter_table("foo", recreate="always") as batch_op:
@@ -2178,8 +2185,7 @@ class BatchRoundTripTest(TestBase):
{"id": 5, "data": "d5", "x": 9},
]
)
-
- insp = inspect(config.db)
+ insp = inspect(self.conn)
eq_(
[
dict(
@@ -2195,7 +2201,7 @@ class BatchRoundTripTest(TestBase):
with self.op.batch_alter_table("foo", recreate="always") as batch_op:
batch_op.drop_index("ix_data")
- insp = inspect(config.db)
+ insp = inspect(self.conn)
eq_(insp.get_indexes("foo"), [])
@@ -2316,7 +2322,7 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest):
) as batch_op:
batch_op.add_column(Column("data", Integer))
- insp = inspect(config.db)
+ insp = inspect(self.conn)
eq_(
[