summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-01-10 13:44:14 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2021-01-13 22:10:13 -0500
commitf1e96cb0874927a475d0c111393b7861796dd758 (patch)
tree810f3c43c0d2c6336805ebcf13d86d5cf1226efa /test
parent7f92fdbd8ec479a61c53c11921ce0688ad4dd94b (diff)
downloadsqlalchemy-f1e96cb0874927a475d0c111393b7861796dd758.tar.gz
reinvent xdist hooks in terms of pytest fixtures
To allow the "connection" pytest fixture and others work correctly in conjunction with setup/teardown that expects to be external to the transaction, remove and prevent any usage of "xdist" style names that are hardcoded by pytest to run inside of fixtures, even function level ones. Instead use pytest autouse fixtures to implement our own r"setup|teardown_test(?:_class)?" methods so that we can ensure function-scoped fixtures are run within them. A new more explicit flow is set up within plugin_base and pytestplugin such that the order of setup/teardown steps, which there are now many, is fully documented and controllable. New granularity has been added to the test teardown phase to distinguish between "end of the test" when lock-holding structures on connections should be released to allow for table drops, vs. "end of the test plus its teardown steps" when we can perform final cleanup on connections and run assertions that everything is closed out. From there we can remove most of the defensive "tear down everything" logic inside of engines which for many years would frequently dispose of pools over and over again, creating for a broken and expensive connection flow. A quick test shows that running test/sql/ against a single Postgresql engine with the new approach uses 75% fewer new connections, creating 42 new connections total, vs. 164 new connections total with the previous system. As part of this, the new fixtures metadata/connection/future_connection have been integrated such that they can be combined together effectively. The fixture_session(), provide_metadata() fixtures have been improved, including that fixture_session() now strongly references sessions which are explicitly torn down before table drops occur afer a test. Major changes have been made to the ConnectionKiller such that it now features different "scopes" for testing engines and will limit its cleanup to those testing engines corresponding to end of test, end of test class, or end of test session. The system by which it tracks DBAPI connections has been reworked, is ultimately somewhat similar to how it worked before but is organized more clearly along with the proxy-tracking logic. A "testing_engine" fixture is also added that works as a pytest fixture rather than a standalone function. The connection cleanup logic should now be very robust, as we now can use the same global connection pools for the whole suite without ever disposing them, while also running a query for PostgreSQL locks remaining after every test and assert there are no open transactions leaking between tests at all. Additional steps are added that also accommodate for asyncio connections not explicitly closed, as is the case for legacy sync-style tests as well as the async tests themselves. As always, hundreds of tests are further refined to use the new fixtures where problems with loose connections were identified, largely as a result of the new PostgreSQL assertions, many more tests have moved from legacy patterns into the newest. An unfortunate discovery during the creation of this system is that autouse fixtures (as well as if they are set up by @pytest.mark.usefixtures) are not usable at our current scale with pytest 4.6.11 running under Python 2. It's unclear if this is due to the older version of pytest or how it implements itself for Python 2, as well as if the issue is CPU slowness or just large memory use, but collecting the full span of tests takes over a minute for a single process when any autouse fixtures are in place and on CI the jobs just time out after ten minutes. So at the moment this patch also reinvents a small version of "autouse" fixtures when py2k is running, which skips generating the real fixture and instead uses two global pytest fixtures (which don't seem to impact performance) to invoke the "autouse" fixtures ourselves outside of pytest. This will limit our ability to do more with fixtures until we can remove py2k support. py.test is still observed to be much slower in collection in the 4.6.11 version compared to modern 6.2 versions, so add support for new TOX_POSTGRESQL_PY2K and TOX_MYSQL_PY2K environment variables that will run the suite for fewer backends under Python 2. For Python 3 pin pytest to modern 6.2 versions where performance for collection has been improved greatly. Includes the following improvements: Fixed bug in asyncio connection pool where ``asyncio.TimeoutError`` would be raised rather than :class:`.exc.TimeoutError`. Also repaired the :paramref:`_sa.create_engine.pool_timeout` parameter set to zero when using the async engine, which previously would ignore the timeout and block rather than timing out immediately as is the behavior with regular :class:`.QueuePool`. For asyncio the connection pool will now also not interact at all with an asyncio connection whose ConnectionFairy is being garbage collected; a warning that the connection was not properly closed is emitted and the connection is discarded. Within the test suite the ConnectionKiller is now maintaining strong references to all DBAPI connections and ensuring they are released when tests end, including those whose ConnectionFairy proxies are GCed. Identified cx_Oracle.stmtcachesize as a major factor in Oracle test scalability issues, this can be reset on a per-test basis rather than setting it to zero across the board. the addition of this flag has resolved the long-standing oracle "two task" error problem. For SQL Server, changed the temp table style used by the "suite" tests to be the double-pound-sign, i.e. global, variety, which is much easier to test generically. There are already reflection tests that are more finely tuned to both styles of temp table within the mssql test suite. Additionally, added an extra step to the "dropfirst" mechanism for SQL Server that will remove all foreign key constraints first as some issues were observed when using this flag when multiple schemas had not been torn down. Identified and fixed two subtle failure modes in the engine, when commit/rollback fails in a begin() context manager, the connection is explicitly closed, and when "initialize()" fails on the first new connection of a dialect, the transactional state on that connection is still rolled back. Fixes: #5826 Fixes: #5827 Change-Id: Ib1d05cb8c7cf84f9a4bfd23df397dc23c9329bfe
Diffstat (limited to 'test')
-rw-r--r--test/aaa_profiling/test_compiler.py2
-rw-r--r--test/aaa_profiling/test_memusage.py4
-rw-r--r--test/aaa_profiling/test_misc.py2
-rw-r--r--test/aaa_profiling/test_orm.py6
-rw-r--r--test/aaa_profiling/test_pool.py2
-rw-r--r--test/base/test_events.py22
-rw-r--r--test/base/test_inspect.py2
-rw-r--r--test/base/test_tutorials.py4
-rw-r--r--test/dialect/mssql/test_compiler.py2
-rw-r--r--test/dialect/mssql/test_deprecations.py2
-rw-r--r--test/dialect/mssql/test_query.py3
-rw-r--r--test/dialect/mysql/test_compiler.py4
-rw-r--r--test/dialect/mysql/test_reflection.py2
-rw-r--r--test/dialect/oracle/test_compiler.py2
-rw-r--r--test/dialect/oracle/test_dialect.py4
-rw-r--r--test/dialect/oracle/test_reflection.py16
-rw-r--r--test/dialect/oracle/test_types.py4
-rw-r--r--test/dialect/postgresql/test_async_pg_py3k.py4
-rw-r--r--test/dialect/postgresql/test_compiler.py8
-rw-r--r--test/dialect/postgresql/test_dialect.py21
-rw-r--r--test/dialect/postgresql/test_query.py6
-rw-r--r--test/dialect/postgresql/test_reflection.py499
-rw-r--r--test/dialect/postgresql/test_types.py692
-rw-r--r--test/dialect/test_sqlite.py32
-rw-r--r--test/engine/test_ddlevents.py4
-rw-r--r--test/engine/test_deprecations.py16
-rw-r--r--test/engine/test_execute.py364
-rw-r--r--test/engine/test_logging.py16
-rw-r--r--test/engine/test_pool.py57
-rw-r--r--test/engine/test_processors.py10
-rw-r--r--test/engine/test_reconnect.py20
-rw-r--r--test/engine/test_reflection.py9
-rw-r--r--test/engine/test_transaction.py21
-rw-r--r--test/ext/asyncio/test_engine_py3k.py16
-rw-r--r--test/ext/declarative/test_inheritance.py4
-rw-r--r--test/ext/declarative/test_reflection.py127
-rw-r--r--test/ext/test_associationproxy.py16
-rw-r--r--test/ext/test_baked.py2
-rw-r--r--test/ext/test_compiler.py2
-rw-r--r--test/ext/test_extendedattr.py6
-rw-r--r--test/ext/test_horizontal_shard.py32
-rw-r--r--test/ext/test_hybrid.py2
-rw-r--r--test/ext/test_mutable.py15
-rw-r--r--test/ext/test_orderinglist.py16
-rw-r--r--test/orm/declarative/test_basic.py4
-rw-r--r--test/orm/declarative/test_concurrency.py2
-rw-r--r--test/orm/declarative/test_inheritance.py4
-rw-r--r--test/orm/declarative/test_mixin.py4
-rw-r--r--test/orm/declarative/test_reflection.py5
-rw-r--r--test/orm/inheritance/test_basic.py49
-rw-r--r--test/orm/test_attributes.py8
-rw-r--r--test/orm/test_bind.py74
-rw-r--r--test/orm/test_collection.py5
-rw-r--r--test/orm/test_compile.py2
-rw-r--r--test/orm/test_cycles.py3
-rw-r--r--test/orm/test_deprecations.py80
-rw-r--r--test/orm/test_eager_relations.py14
-rw-r--r--test/orm/test_events.py7
-rw-r--r--test/orm/test_froms.py106
-rw-r--r--test/orm/test_lazy_relations.py57
-rw-r--r--test/orm/test_load_on_fks.py30
-rw-r--r--test/orm/test_mapper.py6
-rw-r--r--test/orm/test_options.py4
-rw-r--r--test/orm/test_query.py37
-rw-r--r--test/orm/test_rel_fn.py2
-rw-r--r--test/orm/test_relationships.py6
-rw-r--r--test/orm/test_selectin_relations.py50
-rw-r--r--test/orm/test_session.py20
-rw-r--r--test/orm/test_subquery_relations.py50
-rw-r--r--test/orm/test_transaction.py681
-rw-r--r--test/orm/test_unitofwork.py8
-rw-r--r--test/orm/test_unitofworkv2.py3
-rw-r--r--test/requirements.py16
-rw-r--r--test/sql/test_case_statement.py4
-rw-r--r--test/sql/test_compare.py2
-rw-r--r--test/sql/test_compiler.py2
-rw-r--r--test/sql/test_defaults.py5
-rw-r--r--test/sql/test_deprecations.py6
-rw-r--r--test/sql/test_external_traversal.py14
-rw-r--r--test/sql/test_from_linter.py8
-rw-r--r--test/sql/test_functions.py10
-rw-r--r--test/sql/test_metadata.py2
-rw-r--r--test/sql/test_operators.py8
-rw-r--r--test/sql/test_resultset.py15
-rw-r--r--test/sql/test_sequences.py4
-rw-r--r--test/sql/test_types.py11
86 files changed, 1780 insertions, 1748 deletions
diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py
index 0202768ae..968a74700 100644
--- a/test/aaa_profiling/test_compiler.py
+++ b/test/aaa_profiling/test_compiler.py
@@ -18,7 +18,7 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2, metadata
metadata = MetaData()
diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py
index 75a4f51cf..a41a8b9f1 100644
--- a/test/aaa_profiling/test_memusage.py
+++ b/test/aaa_profiling/test_memusage.py
@@ -241,7 +241,7 @@ def assert_no_mappers():
class EnsureZeroed(fixtures.ORMTest):
- def setup(self):
+ def setup_test(self):
_sessions.clear()
_mapper_registry.clear()
@@ -1032,7 +1032,7 @@ class MemUsageWBackendTest(EnsureZeroed):
t2_mapper = mapper(T2, t2)
t1_mapper.add_property("bar", relationship(t2_mapper))
- s1 = fixture_session()
+ s1 = Session(testing.db)
# this causes the path_registry to be invoked
s1.query(t1_mapper)._compile_context()
diff --git a/test/aaa_profiling/test_misc.py b/test/aaa_profiling/test_misc.py
index db6fd4b71..5b30a3968 100644
--- a/test/aaa_profiling/test_misc.py
+++ b/test/aaa_profiling/test_misc.py
@@ -19,7 +19,7 @@ from sqlalchemy.util import classproperty
class EnumTest(fixtures.TestBase):
__requires__ = ("cpython", "python_profiling_backend")
- def setup(self):
+ def setup_test(self):
class SomeEnum(object):
# Implements PEP 435 in the minimal fashion needed by SQLAlchemy
diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py
index f163078d8..8116e5f21 100644
--- a/test/aaa_profiling/test_orm.py
+++ b/test/aaa_profiling/test_orm.py
@@ -29,15 +29,13 @@ class NoCache(object):
run_setup_bind = "each"
@classmethod
- def setup_class(cls):
- super(NoCache, cls).setup_class()
+ def setup_test_class(cls):
cls._cache = config.db._compiled_cache
config.db._compiled_cache = None
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
config.db._compiled_cache = cls._cache
- super(NoCache, cls).teardown_class()
class MergeTest(NoCache, fixtures.MappedTest):
diff --git a/test/aaa_profiling/test_pool.py b/test/aaa_profiling/test_pool.py
index fd02f9139..da3c1c525 100644
--- a/test/aaa_profiling/test_pool.py
+++ b/test/aaa_profiling/test_pool.py
@@ -17,7 +17,7 @@ class QueuePoolTest(fixtures.TestBase, AssertsExecutionResults):
def close(self):
pass
- def setup(self):
+ def setup_test(self):
# create a throwaway pool which
# has the effect of initializing
# class-level event listeners on Pool,
diff --git a/test/base/test_events.py b/test/base/test_events.py
index 19f68e9a3..68db5207c 100644
--- a/test/base/test_events.py
+++ b/test/base/test_events.py
@@ -16,7 +16,7 @@ from sqlalchemy.testing.util import gc_collect
class TearDownLocalEventsFixture(object):
- def tearDown(self):
+ def teardown_test(self):
classes = set()
for entry in event.base._registrars.values():
for evt_cls in entry:
@@ -30,7 +30,7 @@ class TearDownLocalEventsFixture(object):
class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test class- and instance-level event registration."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, x, y):
pass
@@ -438,7 +438,7 @@ class NamedCallTest(TearDownLocalEventsFixture, fixtures.TestBase):
class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""test adaption of legacy args"""
- def setUp(self):
+ def setup_test(self):
class TargetEventsOne(event.Events):
@event._legacy_signature("0.9", ["x", "y"])
def event_three(self, x, y, z, q):
@@ -608,7 +608,7 @@ class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEventsOne(event.Events):
def event_one(self, x, y):
pass
@@ -677,7 +677,7 @@ class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test default target acceptance."""
- def setUp(self):
+ def setup_test(self):
class TargetEventsOne(event.Events):
def event_one(self, x, y):
pass
@@ -734,7 +734,7 @@ class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test custom target acceptance."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
@classmethod
def _accept_with(cls, target):
@@ -771,7 +771,7 @@ class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
class SubclassGrowthTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""test that ad-hoc subclasses are garbage collected."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def some_event(self, x, y):
pass
@@ -797,7 +797,7 @@ class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test custom listen functions which change the listener function
signature."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
@classmethod
def _listen(cls, event_key, add=False):
@@ -855,7 +855,7 @@ class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase):
class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, arg):
pass
@@ -889,7 +889,7 @@ class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, target, arg):
pass
@@ -1109,7 +1109,7 @@ class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
class DisableClsPropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, target, arg):
pass
diff --git a/test/base/test_inspect.py b/test/base/test_inspect.py
index 15b98c848..252d0d977 100644
--- a/test/base/test_inspect.py
+++ b/test/base/test_inspect.py
@@ -13,7 +13,7 @@ class TestFixture(object):
class TestInspection(fixtures.TestBase):
- def tearDown(self):
+ def teardown_test(self):
for type_ in list(inspection._registrars):
if issubclass(type_, TestFixture):
del inspection._registrars[type_]
diff --git a/test/base/test_tutorials.py b/test/base/test_tutorials.py
index 14e87ef69..6320ef052 100644
--- a/test/base/test_tutorials.py
+++ b/test/base/test_tutorials.py
@@ -48,11 +48,11 @@ class DocTest(fixtures.TestBase):
ddl.sort_tables_and_constraints = self.orig_sort
- def setup(self):
+ def setup_test(self):
self._setup_logger()
self._setup_create_table_patcher()
- def teardown(self):
+ def teardown_test(self):
self._teardown_create_table_patcher()
self._teardown_logger()
diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py
index 8119612e1..f0bb66aa9 100644
--- a/test/dialect/mssql/test_compiler.py
+++ b/test/dialect/mssql/test_compiler.py
@@ -1814,7 +1814,7 @@ class CompileIdentityTest(fixtures.TestBase, AssertsCompiledSQL):
class SchemaTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
t = Table(
"sometable",
MetaData(),
diff --git a/test/dialect/mssql/test_deprecations.py b/test/dialect/mssql/test_deprecations.py
index c869182c5..27709beb0 100644
--- a/test/dialect/mssql/test_deprecations.py
+++ b/test/dialect/mssql/test_deprecations.py
@@ -31,7 +31,7 @@ class LegacySchemaAliasingTest(fixtures.TestBase, AssertsCompiledSQL):
"""
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.t1 = table(
"t1",
diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py
index cdb37cc61..b806b9247 100644
--- a/test/dialect/mssql/test_query.py
+++ b/test/dialect/mssql/test_query.py
@@ -455,7 +455,7 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
return testing.db.execution_options(isolation_level="AUTOCOMMIT")
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.connect().execution_options(
isolation_level="AUTOCOMMIT"
) as conn:
@@ -463,7 +463,6 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
conn.exec_driver_sql("DROP FULLTEXT CATALOG Catalog")
except:
pass
- super(MatchTest, cls).setup_class()
@classmethod
def insert_data(cls, connection):
diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py
index 62292b9da..7fd24e8b5 100644
--- a/test/dialect/mysql/test_compiler.py
+++ b/test/dialect/mysql/test_compiler.py
@@ -991,7 +991,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = mysql.dialect()
- def setup(self):
+ def setup_test(self):
self.table = Table(
"foos",
MetaData(),
@@ -1062,7 +1062,7 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
class RegexpCommon(testing.AssertsCompiledSQL):
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py
index 40617e59c..795b2cbd3 100644
--- a/test/dialect/mysql/test_reflection.py
+++ b/test/dialect/mysql/test_reflection.py
@@ -1115,7 +1115,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
class RawReflectionTest(fixtures.TestBase):
__backend__ = True
- def setup(self):
+ def setup_test(self):
dialect = mysql.dialect()
self.parser = _reflection.MySQLTableDefinitionParser(
dialect, dialect.identifier_preparer
diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py
index 1b8b3fb89..f09346eb3 100644
--- a/test/dialect/oracle/test_compiler.py
+++ b/test/dialect/oracle/test_compiler.py
@@ -1355,7 +1355,7 @@ class SequenceTest(fixtures.TestBase, AssertsCompiledSQL):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "oracle"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py
index df87fe89f..32234bf65 100644
--- a/test/dialect/oracle/test_dialect.py
+++ b/test/dialect/oracle/test_dialect.py
@@ -439,7 +439,7 @@ class OutParamTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.begin() as c:
c.exec_driver_sql(
"""
@@ -471,7 +471,7 @@ end;
assert isinstance(result.out_parameters["x_out"], int)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
conn.execute(text("DROP PROCEDURE foo"))
diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py
index 81e4e4ab5..0df4236e2 100644
--- a/test/dialect/oracle/test_reflection.py
+++ b/test/dialect/oracle/test_reflection.py
@@ -39,7 +39,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
# currently assuming full DBA privs for the user.
# don't really know how else to go here unless
# we connect as the other user.
@@ -85,7 +85,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
conn.exec_driver_sql(stmt)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
for stmt in (
"""
@@ -379,7 +379,7 @@ class SystemTableTablenamesTest(fixtures.TestBase):
__only_on__ = "oracle"
__backend__ = True
- def setup(self):
+ def setup_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("create table my_table (id integer)")
conn.exec_driver_sql(
@@ -389,7 +389,7 @@ class SystemTableTablenamesTest(fixtures.TestBase):
"create table foo_table (id integer) tablespace SYSTEM"
)
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop table my_temp_table")
conn.exec_driver_sql("drop table my_table")
@@ -421,7 +421,7 @@ class DontReflectIOTTest(fixtures.TestBase):
__only_on__ = "oracle"
__backend__ = True
- def setup(self):
+ def setup_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql(
"""
@@ -438,7 +438,7 @@ class DontReflectIOTTest(fixtures.TestBase):
""",
)
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop table admin_docindex")
@@ -715,7 +715,7 @@ class DBLinkReflectionTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy.testing import config
cls.dblink = config.file_config.get("sqla_testing", "oracle_db_link")
@@ -734,7 +734,7 @@ class DBLinkReflectionTest(fixtures.TestBase):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop synonym test_table_syn")
conn.exec_driver_sql("drop table test_table")
diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py
index f008ea019..8ea7c0e04 100644
--- a/test/dialect/oracle/test_types.py
+++ b/test/dialect/oracle/test_types.py
@@ -1011,7 +1011,7 @@ class EuroNumericTest(fixtures.TestBase):
__only_on__ = "oracle+cx_oracle"
__backend__ = True
- def setup(self):
+ def setup_test(self):
connect = testing.db.pool._creator
def _creator():
@@ -1023,7 +1023,7 @@ class EuroNumericTest(fixtures.TestBase):
self.engine = testing_engine(options={"creator": _creator})
- def teardown(self):
+ def teardown_test(self):
self.engine.dispose()
def test_were_getting_a_comma(self):
diff --git a/test/dialect/postgresql/test_async_pg_py3k.py b/test/dialect/postgresql/test_async_pg_py3k.py
index fadf939b8..f6d48f3c6 100644
--- a/test/dialect/postgresql/test_async_pg_py3k.py
+++ b/test/dialect/postgresql/test_async_pg_py3k.py
@@ -27,7 +27,7 @@ class AsyncPgTest(fixtures.TestBase):
# TODO: remove when Iae6ab95938a7e92b6d42086aec534af27b5577d3
# merges
- from sqlalchemy.testing import engines
+ from sqlalchemy.testing import util as testing_util
from sqlalchemy.sql import schema
metadata = schema.MetaData()
@@ -35,7 +35,7 @@ class AsyncPgTest(fixtures.TestBase):
try:
yield metadata
finally:
- engines.drop_all_tables(metadata, testing.db)
+ testing_util.drop_all_tables_from_metadata(metadata, testing.db)
@async_test
async def test_detect_stale_ddl_cache_raise_recover(
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 1763b210b..b3a0b9bbd 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -1810,7 +1810,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = postgresql.dialect()
- def setup(self):
+ def setup_test(self):
self.table1 = table1 = table(
"mytable",
column("myid", Integer),
@@ -2222,7 +2222,7 @@ class DistinctOnTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = postgresql.dialect()
- def setup(self):
+ def setup_test(self):
self.table = Table(
"t",
MetaData(),
@@ -2373,7 +2373,7 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = postgresql.dialect()
- def setup(self):
+ def setup_test(self):
self.table = Table(
"t",
MetaData(),
@@ -2464,7 +2464,7 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "postgresql"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py
index f760a309b..9c9d817bb 100644
--- a/test/dialect/postgresql/test_dialect.py
+++ b/test/dialect/postgresql/test_dialect.py
@@ -198,17 +198,17 @@ class ExecuteManyMode(object):
@config.fixture()
def connection(self):
- eng = engines.testing_engine(options=self.options)
+ opts = dict(self.options)
+ opts["use_reaper"] = False
+ eng = engines.testing_engine(options=opts)
conn = eng.connect()
trans = conn.begin()
- try:
- yield conn
- finally:
- if trans.is_active:
- trans.rollback()
- conn.close()
- eng.dispose()
+ yield conn
+ if trans.is_active:
+ trans.rollback()
+ conn.close()
+ eng.dispose()
@classmethod
def define_tables(cls, metadata):
@@ -510,8 +510,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
# assert result.closed
assert result.cursor is None
- @testing.provide_metadata
- def test_insert_returning_preexecute_pk(self, connection):
+ def test_insert_returning_preexecute_pk(self, metadata, connection):
counter = itertools.count(1)
t = Table(
@@ -525,7 +524,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
),
Column("data", Integer),
)
- self.metadata.create_all(connection)
+ metadata.create_all(connection)
result = connection.execute(
t.insert().return_defaults(),
diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py
index 94af168ee..c51fd1943 100644
--- a/test/dialect/postgresql/test_query.py
+++ b/test/dialect/postgresql/test_query.py
@@ -40,10 +40,10 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql"
__backend__ = True
- def setup(self):
+ def setup_test(self):
self.metadata = MetaData()
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
self.metadata.drop_all(conn)
@@ -890,7 +890,7 @@ class ExtractTest(fixtures.TablesTest):
def setup_bind(cls):
from sqlalchemy import event
- eng = engines.testing_engine()
+ eng = engines.testing_engine(options={"scope": "class"})
@event.listens_for(eng, "connect")
def connect(dbapi_conn, rec):
diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py
index 754eff25a..6586a8308 100644
--- a/test/dialect/postgresql/test_reflection.py
+++ b/test/dialect/postgresql/test_reflection.py
@@ -80,26 +80,24 @@ class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
]:
sa.event.listen(metadata, "before_drop", sa.DDL(ddl))
- def test_foreign_table_is_reflected(self):
+ def test_foreign_table_is_reflected(self, connection):
metadata = MetaData()
- table = Table("test_foreigntable", metadata, autoload_with=testing.db)
+ table = Table("test_foreigntable", metadata, autoload_with=connection)
eq_(
set(table.columns.keys()),
set(["id", "data"]),
"Columns of reflected foreign table didn't equal expected columns",
)
- def test_get_foreign_table_names(self):
- inspector = inspect(testing.db)
- with testing.db.connect():
- ft_names = inspector.get_foreign_table_names()
- eq_(ft_names, ["test_foreigntable"])
+ def test_get_foreign_table_names(self, connection):
+ inspector = inspect(connection)
+ ft_names = inspector.get_foreign_table_names()
+ eq_(ft_names, ["test_foreigntable"])
- def test_get_table_names_no_foreign(self):
- inspector = inspect(testing.db)
- with testing.db.connect():
- names = inspector.get_table_names()
- eq_(names, ["testtable"])
+ def test_get_table_names_no_foreign(self, connection):
+ inspector = inspect(connection)
+ names = inspector.get_table_names()
+ eq_(names, ["testtable"])
class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
@@ -133,22 +131,22 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
if testing.against("postgresql >= 11"):
Index("my_index", dv.c.q)
- def test_get_tablenames(self):
+ def test_get_tablenames(self, connection):
assert {"data_values", "data_values_4_10"}.issubset(
- inspect(testing.db).get_table_names()
+ inspect(connection).get_table_names()
)
- def test_reflect_cols(self):
- cols = inspect(testing.db).get_columns("data_values")
+ def test_reflect_cols(self, connection):
+ cols = inspect(connection).get_columns("data_values")
eq_([c["name"] for c in cols], ["modulus", "data", "q"])
- def test_reflect_cols_from_partition(self):
- cols = inspect(testing.db).get_columns("data_values_4_10")
+ def test_reflect_cols_from_partition(self, connection):
+ cols = inspect(connection).get_columns("data_values_4_10")
eq_([c["name"] for c in cols], ["modulus", "data", "q"])
@testing.only_on("postgresql >= 11")
- def test_reflect_index(self):
- idx = inspect(testing.db).get_indexes("data_values")
+ def test_reflect_index(self, connection):
+ idx = inspect(connection).get_indexes("data_values")
eq_(
idx,
[
@@ -162,8 +160,8 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
)
@testing.only_on("postgresql >= 11")
- def test_reflect_index_from_partition(self):
- idx = inspect(testing.db).get_indexes("data_values_4_10")
+ def test_reflect_index_from_partition(self, connection):
+ idx = inspect(connection).get_indexes("data_values_4_10")
# note the name appears to be generated by PG, currently
# 'data_values_4_10_q_idx'
eq_(
@@ -220,44 +218,43 @@ class MaterializedViewReflectionTest(
testtable, "before_drop", sa.DDL("DROP VIEW test_regview")
)
- def test_mview_is_reflected(self):
+ def test_mview_is_reflected(self, connection):
metadata = MetaData()
- table = Table("test_mview", metadata, autoload_with=testing.db)
+ table = Table("test_mview", metadata, autoload_with=connection)
eq_(
set(table.columns.keys()),
set(["id", "data"]),
"Columns of reflected mview didn't equal expected columns",
)
- def test_mview_select(self):
+ def test_mview_select(self, connection):
metadata = MetaData()
- table = Table("test_mview", metadata, autoload_with=testing.db)
- with testing.db.connect() as conn:
- eq_(conn.execute(table.select()).fetchall(), [(89, "d1")])
+ table = Table("test_mview", metadata, autoload_with=connection)
+ eq_(connection.execute(table.select()).fetchall(), [(89, "d1")])
- def test_get_view_names(self):
- insp = inspect(testing.db)
+ def test_get_view_names(self, connection):
+ insp = inspect(connection)
eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
- def test_get_view_names_plain(self):
- insp = inspect(testing.db)
+ def test_get_view_names_plain(self, connection):
+ insp = inspect(connection)
eq_(
set(insp.get_view_names(include=("plain",))), set(["test_regview"])
)
- def test_get_view_names_plain_string(self):
- insp = inspect(testing.db)
+ def test_get_view_names_plain_string(self, connection):
+ insp = inspect(connection)
eq_(set(insp.get_view_names(include="plain")), set(["test_regview"]))
- def test_get_view_names_materialized(self):
- insp = inspect(testing.db)
+ def test_get_view_names_materialized(self, connection):
+ insp = inspect(connection)
eq_(
set(insp.get_view_names(include=("materialized",))),
set(["test_mview"]),
)
- def test_get_view_names_reflection_cache_ok(self):
- insp = inspect(testing.db)
+ def test_get_view_names_reflection_cache_ok(self, connection):
+ insp = inspect(connection)
eq_(
set(insp.get_view_names(include=("plain",))), set(["test_regview"])
)
@@ -267,12 +264,12 @@ class MaterializedViewReflectionTest(
)
eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
- def test_get_view_names_empty(self):
- insp = inspect(testing.db)
+ def test_get_view_names_empty(self, connection):
+ insp = inspect(connection)
assert_raises(ValueError, insp.get_view_names, include=())
- def test_get_view_definition(self):
- insp = inspect(testing.db)
+ def test_get_view_definition(self, connection):
+ insp = inspect(connection)
eq_(
re.sub(
r"[\n\t ]+",
@@ -290,7 +287,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.begin() as con:
for ddl in [
'CREATE SCHEMA "SomeSchema"',
@@ -334,7 +331,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as con:
con.exec_driver_sql("DROP TABLE testtable")
con.exec_driver_sql("DROP TABLE test_schema.testtable")
@@ -350,9 +347,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"')
con.exec_driver_sql('DROP SCHEMA "SomeSchema"')
- def test_table_is_reflected(self):
+ def test_table_is_reflected(self, connection):
metadata = MetaData()
- table = Table("testtable", metadata, autoload_with=testing.db)
+ table = Table("testtable", metadata, autoload_with=connection)
eq_(
set(table.columns.keys()),
set(["question", "answer"]),
@@ -360,9 +357,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
)
assert isinstance(table.c.answer.type, Integer)
- def test_domain_is_reflected(self):
+ def test_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("testtable", metadata, autoload_with=testing.db)
+ table = Table("testtable", metadata, autoload_with=connection)
eq_(
str(table.columns.answer.server_default.arg),
"42",
@@ -372,28 +369,28 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
not table.columns.answer.nullable
), "Expected reflected column to not be nullable."
- def test_enum_domain_is_reflected(self):
+ def test_enum_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("enum_test", metadata, autoload_with=testing.db)
+ table = Table("enum_test", metadata, autoload_with=connection)
eq_(table.c.data.type.enums, ["test"])
- def test_array_domain_is_reflected(self):
+ def test_array_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("array_test", metadata, autoload_with=testing.db)
+ table = Table("array_test", metadata, autoload_with=connection)
eq_(table.c.data.type.__class__, ARRAY)
eq_(table.c.data.type.item_type.__class__, INTEGER)
- def test_quoted_remote_schema_domain_is_reflected(self):
+ def test_quoted_remote_schema_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("quote_test", metadata, autoload_with=testing.db)
+ table = Table("quote_test", metadata, autoload_with=connection)
eq_(table.c.data.type.__class__, INTEGER)
- def test_table_is_reflected_test_schema(self):
+ def test_table_is_reflected_test_schema(self, connection):
metadata = MetaData()
table = Table(
"testtable",
metadata,
- autoload_with=testing.db,
+ autoload_with=connection,
schema="test_schema",
)
eq_(
@@ -403,12 +400,12 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
)
assert isinstance(table.c.anything.type, Integer)
- def test_schema_domain_is_reflected(self):
+ def test_schema_domain_is_reflected(self, connection):
metadata = MetaData()
table = Table(
"testtable",
metadata,
- autoload_with=testing.db,
+ autoload_with=connection,
schema="test_schema",
)
eq_(
@@ -420,9 +417,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
table.columns.answer.nullable
), "Expected reflected column to be nullable."
- def test_crosschema_domain_is_reflected(self):
+ def test_crosschema_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("crosschema", metadata, autoload_with=testing.db)
+ table = Table("crosschema", metadata, autoload_with=connection)
eq_(
str(table.columns.answer.server_default.arg),
"0",
@@ -432,7 +429,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
table.columns.answer.nullable
), "Expected reflected column to be nullable."
- def test_unknown_types(self):
+ def test_unknown_types(self, connection):
from sqlalchemy.dialects.postgresql import base
ischema_names = base.PGDialect.ischema_names
@@ -440,13 +437,13 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
try:
m2 = MetaData()
assert_raises(
- exc.SAWarning, Table, "testtable", m2, autoload_with=testing.db
+ exc.SAWarning, Table, "testtable", m2, autoload_with=connection
)
@testing.emits_warning("Did not recognize type")
def warns():
m3 = MetaData()
- t3 = Table("testtable", m3, autoload_with=testing.db)
+ t3 = Table("testtable", m3, autoload_with=connection)
assert t3.c.answer.type.__class__ == sa.types.NullType
finally:
@@ -471,9 +468,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
subject = Table("subject", meta2, autoload_with=connection)
eq_(subject.primary_key.columns.keys(), ["p2", "p1"])
- @testing.provide_metadata
- def test_pg_weirdchar_reflection(self):
- meta1 = self.metadata
+ def test_pg_weirdchar_reflection(self, metadata, connection):
+ meta1 = metadata
subject = Table(
"subject", meta1, Column("id$", Integer, primary_key=True)
)
@@ -483,101 +479,91 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("ref", Integer, ForeignKey("subject.id$")),
)
- meta1.create_all(testing.db)
+ meta1.create_all(connection)
meta2 = MetaData()
- subject = Table("subject", meta2, autoload_with=testing.db)
- referer = Table("referer", meta2, autoload_with=testing.db)
+ subject = Table("subject", meta2, autoload_with=connection)
+ referer = Table("referer", meta2, autoload_with=connection)
self.assert_(
(subject.c["id$"] == referer.c.ref).compare(
subject.join(referer).onclause
)
)
- @testing.provide_metadata
- def test_reflect_default_over_128_chars(self):
+ def test_reflect_default_over_128_chars(self, metadata, connection):
Table(
"t",
- self.metadata,
+ metadata,
Column("x", String(200), server_default="abcd" * 40),
- ).create(testing.db)
+ ).create(connection)
m = MetaData()
- t = Table("t", m, autoload_with=testing.db)
+ t = Table("t", m, autoload_with=connection)
eq_(
t.c.x.server_default.arg.text,
"'%s'::character varying" % ("abcd" * 40),
)
- @testing.fails_if("postgresql < 8.1", "schema name leaks in, not sure")
- @testing.provide_metadata
- def test_renamed_sequence_reflection(self):
- metadata = self.metadata
+ def test_renamed_sequence_reflection(self, metadata, connection):
Table("t", metadata, Column("id", Integer, primary_key=True))
- metadata.create_all(testing.db)
+ metadata.create_all(connection)
m2 = MetaData()
- t2 = Table("t", m2, autoload_with=testing.db, implicit_returning=False)
+ t2 = Table("t", m2, autoload_with=connection, implicit_returning=False)
eq_(t2.c.id.server_default.arg.text, "nextval('t_id_seq'::regclass)")
- with testing.db.begin() as conn:
- r = conn.execute(t2.insert())
- eq_(r.inserted_primary_key, (1,))
+ r = connection.execute(t2.insert())
+ eq_(r.inserted_primary_key, (1,))
- with testing.db.begin() as conn:
- conn.exec_driver_sql(
- "alter table t_id_seq rename to foobar_id_seq"
- )
+ connection.exec_driver_sql(
+ "alter table t_id_seq rename to foobar_id_seq"
+ )
m3 = MetaData()
- t3 = Table("t", m3, autoload_with=testing.db, implicit_returning=False)
+ t3 = Table("t", m3, autoload_with=connection, implicit_returning=False)
eq_(
t3.c.id.server_default.arg.text,
"nextval('foobar_id_seq'::regclass)",
)
- with testing.db.begin() as conn:
- r = conn.execute(t3.insert())
- eq_(r.inserted_primary_key, (2,))
+ r = connection.execute(t3.insert())
+ eq_(r.inserted_primary_key, (2,))
- @testing.provide_metadata
- def test_altered_type_autoincrement_pk_reflection(self):
- metadata = self.metadata
+ def test_altered_type_autoincrement_pk_reflection(
+ self, metadata, connection
+ ):
+ metadata = metadata
Table(
"t",
metadata,
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all(testing.db)
+ metadata.create_all(connection)
- with testing.db.begin() as conn:
- conn.exec_driver_sql(
- "alter table t alter column id type varchar(50)"
- )
+ connection.exec_driver_sql(
+ "alter table t alter column id type varchar(50)"
+ )
m2 = MetaData()
- t2 = Table("t", m2, autoload_with=testing.db)
+ t2 = Table("t", m2, autoload_with=connection)
eq_(t2.c.id.autoincrement, False)
eq_(t2.c.x.autoincrement, False)
- @testing.provide_metadata
- def test_renamed_pk_reflection(self):
- metadata = self.metadata
+ def test_renamed_pk_reflection(self, metadata, connection):
+ metadata = metadata
Table("t", metadata, Column("id", Integer, primary_key=True))
- metadata.create_all(testing.db)
- with testing.db.begin() as conn:
- conn.exec_driver_sql("alter table t rename id to t_id")
+ metadata.create_all(connection)
+ connection.exec_driver_sql("alter table t rename id to t_id")
m2 = MetaData()
- t2 = Table("t", m2, autoload_with=testing.db)
+ t2 = Table("t", m2, autoload_with=connection)
eq_([c.name for c in t2.primary_key], ["t_id"])
- @testing.provide_metadata
- def test_has_temporary_table(self):
- assert not inspect(testing.db).has_table("some_temp_table")
+ def test_has_temporary_table(self, metadata, connection):
+ assert not inspect(connection).has_table("some_temp_table")
user_tmp = Table(
"some_temp_table",
- self.metadata,
+ metadata,
Column("id", Integer, primary_key=True),
Column("name", String(50)),
prefixes=["TEMPORARY"],
)
- user_tmp.create(testing.db)
- assert inspect(testing.db).has_table("some_temp_table")
+ user_tmp.create(connection)
+ assert inspect(connection).has_table("some_temp_table")
def test_cross_schema_reflection_one(self, metadata, connection):
@@ -898,19 +884,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
A_table.create(connection, checkfirst=True)
assert inspect(connection).has_table("A")
- def test_uppercase_lowercase_sequence(self):
+ def test_uppercase_lowercase_sequence(self, connection):
a_seq = Sequence("a")
A_seq = Sequence("A")
- a_seq.create(testing.db)
- assert testing.db.dialect.has_sequence(testing.db, "a")
- assert not testing.db.dialect.has_sequence(testing.db, "A")
- A_seq.create(testing.db, checkfirst=True)
- assert testing.db.dialect.has_sequence(testing.db, "A")
+ a_seq.create(connection)
+ assert connection.dialect.has_sequence(connection, "a")
+ assert not connection.dialect.has_sequence(connection, "A")
+ A_seq.create(connection, checkfirst=True)
+ assert connection.dialect.has_sequence(connection, "A")
- a_seq.drop(testing.db)
- A_seq.drop(testing.db)
+ a_seq.drop(connection)
+ A_seq.drop(connection)
def test_index_reflection(self, metadata, connection):
"""Reflecting expression-based indexes should warn"""
@@ -960,11 +946,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_index_reflection_partial(self, connection):
+ def test_index_reflection_partial(self, metadata, connection):
"""Reflect the filter defintion on partial indexes"""
- metadata = self.metadata
+ metadata = metadata
t1 = Table(
"table1",
@@ -978,7 +963,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
metadata.create_all(connection)
- ind = testing.db.dialect.get_indexes(connection, t1, None)
+ ind = connection.dialect.get_indexes(connection, t1, None)
partial_definitions = []
for ix in ind:
@@ -1073,15 +1058,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
compile_exprs(r3.expressions),
)
- @testing.provide_metadata
- def test_index_reflection_modified(self):
+ def test_index_reflection_modified(self, metadata, connection):
"""reflect indexes when a column name has changed - PG 9
does not update the name of the column in the index def.
[ticket:2141]
"""
- metadata = self.metadata
+ metadata = metadata
Table(
"t",
@@ -1089,26 +1073,21 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all(testing.db)
- with testing.db.begin() as conn:
- conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
- conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
+ metadata.create_all(connection)
+ connection.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
+ connection.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
- ind = testing.db.dialect.get_indexes(conn, "t", None)
- expected = [
- {"name": "idx1", "unique": False, "column_names": ["y"]}
- ]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
+ ind = connection.dialect.get_indexes(connection, "t", None)
+ expected = [{"name": "idx1", "unique": False, "column_names": ["y"]}]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
- eq_(ind, expected)
+ eq_(ind, expected)
- @testing.fails_if("postgresql < 8.2", "reloptions not supported")
- @testing.provide_metadata
- def test_index_reflection_with_storage_options(self):
+ def test_index_reflection_with_storage_options(self, metadata, connection):
"""reflect indexes with storage options set"""
- metadata = self.metadata
+ metadata = metadata
Table(
"t",
@@ -1116,70 +1095,63 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all(testing.db)
+ metadata.create_all(connection)
- with testing.db.begin() as conn:
- conn.exec_driver_sql(
- "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)"
- )
+ connection.exec_driver_sql(
+ "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)"
+ )
- ind = testing.db.dialect.get_indexes(conn, "t", None)
+ ind = testing.db.dialect.get_indexes(connection, "t", None)
- expected = [
- {
- "unique": False,
- "column_names": ["x"],
- "name": "idx1",
- "dialect_options": {
- "postgresql_with": {"fillfactor": "50"}
- },
- }
- ]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
- eq_(ind, expected)
+ expected = [
+ {
+ "unique": False,
+ "column_names": ["x"],
+ "name": "idx1",
+ "dialect_options": {"postgresql_with": {"fillfactor": "50"}},
+ }
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ eq_(ind, expected)
- m = MetaData()
- t1 = Table("t", m, autoload_with=conn)
- eq_(
- list(t1.indexes)[0].dialect_options["postgresql"]["with"],
- {"fillfactor": "50"},
- )
+ m = MetaData()
+ t1 = Table("t", m, autoload_with=connection)
+ eq_(
+ list(t1.indexes)[0].dialect_options["postgresql"]["with"],
+ {"fillfactor": "50"},
+ )
- @testing.provide_metadata
- def test_index_reflection_with_access_method(self):
+ def test_index_reflection_with_access_method(self, metadata, connection):
"""reflect indexes with storage options set"""
- metadata = self.metadata
-
Table(
"t",
metadata,
Column("id", Integer, primary_key=True),
Column("x", ARRAY(Integer)),
)
- metadata.create_all(testing.db)
- with testing.db.begin() as conn:
- conn.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)")
+ metadata.create_all(connection)
+ connection.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)")
- ind = testing.db.dialect.get_indexes(conn, "t", None)
- expected = [
- {
- "unique": False,
- "column_names": ["x"],
- "name": "idx1",
- "dialect_options": {"postgresql_using": "gin"},
- }
- ]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
- eq_(ind, expected)
- m = MetaData()
- t1 = Table("t", m, autoload_with=conn)
- eq_(
- list(t1.indexes)[0].dialect_options["postgresql"]["using"],
- "gin",
- )
+ ind = testing.db.dialect.get_indexes(connection, "t", None)
+ expected = [
+ {
+ "unique": False,
+ "column_names": ["x"],
+ "name": "idx1",
+ "dialect_options": {"postgresql_using": "gin"},
+ }
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ eq_(ind, expected)
+ m = MetaData()
+ t1 = Table("t", m, autoload_with=connection)
+ eq_(
+ list(t1.indexes)[0].dialect_options["postgresql"]["using"],
+ "gin",
+ )
@testing.skip_if("postgresql < 11.0", "indnkeyatts not supported")
def test_index_reflection_with_include(self, metadata, connection):
@@ -1199,7 +1171,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
# [{'column_names': ['x', 'name'],
# 'name': 'idx1', 'unique': False}]
- ind = testing.db.dialect.get_indexes(connection, "t", None)
+ ind = connection.dialect.get_indexes(connection, "t", None)
eq_(
ind,
[
@@ -1286,15 +1258,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
for fk in fks:
eq_(fk, fk_ref[fk["name"]])
- @testing.provide_metadata
- def test_inspect_enums_schema(self, connection):
+ def test_inspect_enums_schema(self, metadata, connection):
enum_type = postgresql.ENUM(
"sad",
"ok",
"happy",
name="mood",
schema="test_schema",
- metadata=self.metadata,
+ metadata=metadata,
)
enum_type.create(connection)
inspector = inspect(connection)
@@ -1310,13 +1281,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enums(self):
+ def test_inspect_enums(self, metadata, connection):
enum_type = postgresql.ENUM(
- "cat", "dog", "rat", name="pet", metadata=self.metadata
+ "cat", "dog", "rat", name="pet", metadata=metadata
)
- enum_type.create(testing.db)
- inspector = inspect(testing.db)
+ enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums(),
[
@@ -1329,17 +1299,16 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enums_case_sensitive(self):
+ def test_inspect_enums_case_sensitive(self, metadata, connection):
sa.event.listen(
- self.metadata,
+ metadata,
"before_create",
sa.DDL('create schema "TestSchema"'),
)
sa.event.listen(
- self.metadata,
+ metadata,
"after_drop",
- sa.DDL('drop schema "TestSchema" cascade'),
+ sa.DDL('drop schema if exists "TestSchema" cascade'),
)
for enum in "lower_case", "UpperCase", "Name.With.Dot":
@@ -1350,11 +1319,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
"CapsTwo",
name=enum,
schema=schema,
- metadata=self.metadata,
+ metadata=metadata,
)
- self.metadata.create_all(testing.db)
- inspector = inspect(testing.db)
+ metadata.create_all(connection)
+ inspector = inspect(connection)
for schema in None, "test_schema", "TestSchema":
eq_(
sorted(
@@ -1382,17 +1351,18 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enums_case_sensitive_from_table(self):
+ def test_inspect_enums_case_sensitive_from_table(
+ self, metadata, connection
+ ):
sa.event.listen(
- self.metadata,
+ metadata,
"before_create",
sa.DDL('create schema "TestSchema"'),
)
sa.event.listen(
- self.metadata,
+ metadata,
"after_drop",
- sa.DDL('drop schema "TestSchema" cascade'),
+ sa.DDL('drop schema if exists "TestSchema" cascade'),
)
counter = itertools.count()
@@ -1403,19 +1373,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
"CapsOne",
"CapsTwo",
name=enum,
- metadata=self.metadata,
+ metadata=metadata,
schema=schema,
)
Table(
"t%d" % next(counter),
- self.metadata,
+ metadata,
Column("q", enum_type),
)
- self.metadata.create_all(testing.db)
+ metadata.create_all(connection)
- inspector = inspect(testing.db)
+ inspector = inspect(connection)
counter = itertools.count()
for enum in "lower_case", "UpperCase", "Name.With.Dot":
for schema in None, "test_schema", "TestSchema":
@@ -1439,10 +1409,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enums_star(self):
+ def test_inspect_enums_star(self, metadata, connection):
enum_type = postgresql.ENUM(
- "cat", "dog", "rat", name="pet", metadata=self.metadata
+ "cat", "dog", "rat", name="pet", metadata=metadata
)
schema_enum_type = postgresql.ENUM(
"sad",
@@ -1450,11 +1419,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
"happy",
name="mood",
schema="test_schema",
- metadata=self.metadata,
+ metadata=metadata,
)
- enum_type.create(testing.db)
- schema_enum_type.create(testing.db)
- inspector = inspect(testing.db)
+ enum_type.create(connection)
+ schema_enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums(),
@@ -1486,11 +1455,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enum_empty(self):
- enum_type = postgresql.ENUM(name="empty", metadata=self.metadata)
- enum_type.create(testing.db)
- inspector = inspect(testing.db)
+ def test_inspect_enum_empty(self, metadata, connection):
+ enum_type = postgresql.ENUM(name="empty", metadata=metadata)
+ enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums(),
@@ -1504,13 +1472,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enum_empty_from_table(self):
+ def test_inspect_enum_empty_from_table(self, metadata, connection):
Table(
- "t", self.metadata, Column("x", postgresql.ENUM(name="empty"))
- ).create(testing.db)
+ "t", metadata, Column("x", postgresql.ENUM(name="empty"))
+ ).create(connection)
- t = Table("t", MetaData(), autoload_with=testing.db)
+ t = Table("t", MetaData(), autoload_with=connection)
eq_(t.c.x.type.enums, [])
def test_reflection_with_unique_constraint(self, metadata, connection):
@@ -1749,12 +1716,12 @@ class CustomTypeReflectionTest(fixtures.TestBase):
ischema_names = None
- def setup(self):
+ def setup_test(self):
ischema_names = postgresql.PGDialect.ischema_names
postgresql.PGDialect.ischema_names = ischema_names.copy()
self.ischema_names = ischema_names
- def teardown(self):
+ def teardown_test(self):
postgresql.PGDialect.ischema_names = self.ischema_names
self.ischema_names = None
@@ -1788,55 +1755,51 @@ class IntervalReflectionTest(fixtures.TestBase):
__only_on__ = "postgresql"
__backend__ = True
- def test_interval_types(self):
- for sym in [
- "YEAR",
- "MONTH",
- "DAY",
- "HOUR",
- "MINUTE",
- "SECOND",
- "YEAR TO MONTH",
- "DAY TO HOUR",
- "DAY TO MINUTE",
- "DAY TO SECOND",
- "HOUR TO MINUTE",
- "HOUR TO SECOND",
- "MINUTE TO SECOND",
- ]:
- self._test_interval_symbol(sym)
-
- @testing.provide_metadata
- def _test_interval_symbol(self, sym):
+ @testing.combinations(
+ ("YEAR",),
+ ("MONTH",),
+ ("DAY",),
+ ("HOUR",),
+ ("MINUTE",),
+ ("SECOND",),
+ ("YEAR TO MONTH",),
+ ("DAY TO HOUR",),
+ ("DAY TO MINUTE",),
+ ("DAY TO SECOND",),
+ ("HOUR TO MINUTE",),
+ ("HOUR TO SECOND",),
+ ("MINUTE TO SECOND",),
+ argnames="sym",
+ )
+ def test_interval_types(self, sym, metadata, connection):
t = Table(
"i_test",
- self.metadata,
+ metadata,
Column("id", Integer, primary_key=True),
Column("data1", INTERVAL(fields=sym)),
)
- t.create(testing.db)
+ t.create(connection)
columns = {
rec["name"]: rec
- for rec in inspect(testing.db).get_columns("i_test")
+ for rec in inspect(connection).get_columns("i_test")
}
assert isinstance(columns["data1"]["type"], INTERVAL)
eq_(columns["data1"]["type"].fields, sym.lower())
eq_(columns["data1"]["type"].precision, None)
- @testing.provide_metadata
- def test_interval_precision(self):
+ def test_interval_precision(self, metadata, connection):
t = Table(
"i_test",
- self.metadata,
+ metadata,
Column("id", Integer, primary_key=True),
Column("data1", INTERVAL(precision=6)),
)
- t.create(testing.db)
+ t.create(connection)
columns = {
rec["name"]: rec
- for rec in inspect(testing.db).get_columns("i_test")
+ for rec in inspect(connection).get_columns("i_test")
}
assert isinstance(columns["data1"]["type"], INTERVAL)
eq_(columns["data1"]["type"].fields, None)
@@ -1871,8 +1834,8 @@ class IdentityReflectionTest(fixtures.TablesTest):
Column("id4", SmallInteger, Identity()),
)
- def test_reflect_identity(self):
- insp = inspect(testing.db)
+ def test_reflect_identity(self, connection):
+ insp = inspect(connection)
default = dict(
always=False,
start=1,
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index e8a1876c7..6202f8f86 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -49,7 +49,6 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.sql import operators
from sqlalchemy.sql import sqltypes
-from sqlalchemy.testing import engines
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import assert_raises_message
@@ -156,8 +155,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql > 8.3"
- @testing.provide_metadata
- def test_create_table(self, connection):
+ def test_create_table(self, metadata, connection):
metadata = self.metadata
t1 = Table(
"table",
@@ -177,8 +175,8 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
[(1, "two"), (2, "three"), (3, "three")],
)
- @testing.combinations(None, "foo")
- def test_create_table_schema_translate_map(self, symbol_name):
+ @testing.combinations(None, "foo", argnames="symbol_name")
+ def test_create_table_schema_translate_map(self, connection, symbol_name):
# note we can't use the fixture here because it will not drop
# from the correct schema
metadata = MetaData()
@@ -199,35 +197,30 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
),
schema=symbol_name,
)
- with testing.db.begin() as conn:
- conn = conn.execution_options(
- schema_translate_map={symbol_name: testing.config.test_schema}
- )
- t1.create(conn)
- assert "schema_enum" in [
- e["name"]
- for e in inspect(conn).get_enums(
- schema=testing.config.test_schema
- )
- ]
- t1.create(conn, checkfirst=True)
+ conn = connection.execution_options(
+ schema_translate_map={symbol_name: testing.config.test_schema}
+ )
+ t1.create(conn)
+ assert "schema_enum" in [
+ e["name"]
+ for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+ ]
+ t1.create(conn, checkfirst=True)
- conn.execute(t1.insert(), value="two")
- conn.execute(t1.insert(), value="three")
- conn.execute(t1.insert(), value="three")
- eq_(
- conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
- [(1, "two"), (2, "three"), (3, "three")],
- )
+ conn.execute(t1.insert(), value="two")
+ conn.execute(t1.insert(), value="three")
+ conn.execute(t1.insert(), value="three")
+ eq_(
+ conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
+ [(1, "two"), (2, "three"), (3, "three")],
+ )
- t1.drop(conn)
- assert "schema_enum" not in [
- e["name"]
- for e in inspect(conn).get_enums(
- schema=testing.config.test_schema
- )
- ]
- t1.drop(conn, checkfirst=True)
+ t1.drop(conn)
+ assert "schema_enum" not in [
+ e["name"]
+ for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+ ]
+ t1.drop(conn, checkfirst=True)
def test_name_required(self, metadata, connection):
etype = Enum("four", "five", "six", metadata=metadata)
@@ -270,8 +263,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
[util.u("réveillé"), util.u("drôle"), util.u("S’il")],
)
- @testing.provide_metadata
- def test_non_native_enum(self, connection):
+ def test_non_native_enum(self, metadata, connection):
metadata = self.metadata
t1 = Table(
"foo",
@@ -290,10 +282,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
)
def go():
- t1.create(testing.db)
+ t1.create(connection)
self.assert_sql(
- testing.db,
+ connection,
go,
[
(
@@ -307,8 +299,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
connection.execute(t1.insert(), {"bar": "two"})
eq_(connection.scalar(select(t1.c.bar)), "two")
- @testing.provide_metadata
- def test_non_native_enum_w_unicode(self, connection):
+ def test_non_native_enum_w_unicode(self, metadata, connection):
metadata = self.metadata
t1 = Table(
"foo",
@@ -326,10 +317,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
)
def go():
- t1.create(testing.db)
+ t1.create(connection)
self.assert_sql(
- testing.db,
+ connection,
go,
[
(
@@ -346,8 +337,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
connection.execute(t1.insert(), {"bar": util.u("Ü")})
eq_(connection.scalar(select(t1.c.bar)), util.u("Ü"))
- @testing.provide_metadata
- def test_disable_create(self):
+ def test_disable_create(self, metadata, connection):
metadata = self.metadata
e1 = postgresql.ENUM(
@@ -357,13 +347,12 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
t1 = Table("e1", metadata, Column("c1", e1))
# table can be created separately
# without conflict
- e1.create(bind=testing.db)
- t1.create(testing.db)
- t1.drop(testing.db)
- e1.drop(bind=testing.db)
+ e1.create(bind=connection)
+ t1.create(connection)
+ t1.drop(connection)
+ e1.drop(bind=connection)
- @testing.provide_metadata
- def test_dont_keep_checking(self, connection):
+ def test_dont_keep_checking(self, metadata, connection):
metadata = self.metadata
e1 = postgresql.ENUM("one", "two", "three", name="myenum")
@@ -560,11 +549,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
e["name"] for e in inspect(connection).get_enums()
]
- def test_non_native_dialect(self):
- engine = engines.testing_engine()
+ def test_non_native_dialect(self, metadata, testing_engine):
+ engine = testing_engine()
engine.connect()
engine.dialect.supports_native_enum = False
- metadata = MetaData()
t1 = Table(
"foo",
metadata,
@@ -583,21 +571,18 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
def go():
t1.create(engine)
- try:
- self.assert_sql(
- engine,
- go,
- [
- (
- "CREATE TABLE foo (bar "
- "VARCHAR(5), CONSTRAINT myenum CHECK "
- "(bar IN ('one', 'two', 'three')))",
- {},
- )
- ],
- )
- finally:
- metadata.drop_all(engine)
+ self.assert_sql(
+ engine,
+ go,
+ [
+ (
+ "CREATE TABLE foo (bar "
+ "VARCHAR(5), CONSTRAINT myenum CHECK "
+ "(bar IN ('one', 'two', 'three')))",
+ {},
+ )
+ ],
+ )
def test_standalone_enum(self, connection, metadata):
etype = Enum(
@@ -605,26 +590,26 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
)
etype.create(connection)
try:
- assert testing.db.dialect.has_type(connection, "fourfivesixtype")
+ assert connection.dialect.has_type(connection, "fourfivesixtype")
finally:
etype.drop(connection)
- assert not testing.db.dialect.has_type(
+ assert not connection.dialect.has_type(
connection, "fourfivesixtype"
)
metadata.create_all(connection)
try:
- assert testing.db.dialect.has_type(connection, "fourfivesixtype")
+ assert connection.dialect.has_type(connection, "fourfivesixtype")
finally:
metadata.drop_all(connection)
- assert not testing.db.dialect.has_type(
+ assert not connection.dialect.has_type(
connection, "fourfivesixtype"
)
- def test_no_support(self):
+ def test_no_support(self, testing_engine):
def server_version_info(self):
return (8, 2)
- e = engines.testing_engine()
+ e = testing_engine()
dialect = e.dialect
dialect._get_server_version_info = server_version_info
@@ -692,8 +677,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
eq_(t2.c.value2.type.name, "fourfivesixtype")
eq_(t2.c.value2.type.schema, "test_schema")
- @testing.provide_metadata
- def test_custom_subclass(self, connection):
+ def test_custom_subclass(self, metadata, connection):
class MyEnum(TypeDecorator):
impl = Enum("oneHI", "twoHI", "threeHI", name="myenum")
@@ -708,13 +692,12 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
return value
t1 = Table("table1", self.metadata, Column("data", MyEnum()))
- self.metadata.create_all(testing.db)
+ self.metadata.create_all(connection)
connection.execute(t1.insert(), {"data": "two"})
eq_(connection.scalar(select(t1.c.data)), "twoHITHERE")
- @testing.provide_metadata
- def test_generic_w_pg_variant(self, connection):
+ def test_generic_w_pg_variant(self, metadata, connection):
some_table = Table(
"some_table",
self.metadata,
@@ -752,8 +735,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
e["name"] for e in inspect(connection).get_enums()
]
- @testing.provide_metadata
- def test_generic_w_some_other_variant(self, connection):
+ def test_generic_w_some_other_variant(self, metadata, connection):
some_table = Table(
"some_table",
self.metadata,
@@ -809,26 +791,28 @@ class RegClassTest(fixtures.TestBase):
__only_on__ = "postgresql"
__backend__ = True
- @staticmethod
- def _scalar(expression):
- with testing.db.connect() as conn:
- return conn.scalar(select(expression))
+ @testing.fixture()
+ def scalar(self, connection):
+ def go(expression):
+ return connection.scalar(select(expression))
- def test_cast_name(self):
- eq_(self._scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class")
+ return go
- def test_cast_path(self):
+ def test_cast_name(self, scalar):
+ eq_(scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class")
+
+ def test_cast_path(self, scalar):
eq_(
- self._scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)),
+ scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)),
"pg_class",
)
- def test_cast_oid(self):
+ def test_cast_oid(self, scalar):
regclass = cast("pg_class", postgresql.REGCLASS)
- oid = self._scalar(cast(regclass, postgresql.OID))
+ oid = scalar(cast(regclass, postgresql.OID))
assert isinstance(oid, int)
eq_(
- self._scalar(
+ scalar(
cast(type_coerce(oid, postgresql.OID), postgresql.REGCLASS)
),
"pg_class",
@@ -1339,13 +1323,12 @@ class ArrayRoundTripTest(object):
Column("dimarr", ProcValue),
)
- def _fixture_456(self, table):
- with testing.db.begin() as conn:
- conn.execute(table.insert(), intarr=[4, 5, 6])
+ def _fixture_456(self, table, connection):
+ connection.execute(table.insert(), intarr=[4, 5, 6])
- def test_reflect_array_column(self):
+ def test_reflect_array_column(self, connection):
metadata2 = MetaData()
- tbl = Table("arrtable", metadata2, autoload_with=testing.db)
+ tbl = Table("arrtable", metadata2, autoload_with=connection)
assert isinstance(tbl.c.intarr.type, self.ARRAY)
assert isinstance(tbl.c.strarr.type, self.ARRAY)
assert isinstance(tbl.c.intarr.type.item_type, Integer)
@@ -1564,7 +1547,7 @@ class ArrayRoundTripTest(object):
def test_array_getitem_single_exec(self, connection):
arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
+ self._fixture_456(arrtable, connection)
eq_(connection.scalar(select(arrtable.c.intarr[2])), 5)
connection.execute(arrtable.update().values({arrtable.c.intarr[2]: 7}))
eq_(connection.scalar(select(arrtable.c.intarr[2])), 7)
@@ -1654,11 +1637,10 @@ class ArrayRoundTripTest(object):
set([("1", "2", "3"), ("4", "5", "6"), (("4", "5"), ("6", "7"))]),
)
- def test_array_plus_native_enum_create(self):
- m = MetaData()
+ def test_array_plus_native_enum_create(self, metadata, connection):
t = Table(
"t",
- m,
+ metadata,
Column(
"data_1",
self.ARRAY(postgresql.ENUM("a", "b", "c", name="my_enum_1")),
@@ -1669,13 +1651,13 @@ class ArrayRoundTripTest(object):
),
)
- t.create(testing.db)
+ t.create(connection)
eq_(
- set(e["name"] for e in inspect(testing.db).get_enums()),
+ set(e["name"] for e in inspect(connection).get_enums()),
set(["my_enum_1", "my_enum_2"]),
)
- t.drop(testing.db)
- eq_(inspect(testing.db).get_enums(), [])
+ t.drop(connection)
+ eq_(inspect(connection).get_enums(), [])
class CoreArrayRoundTripTest(
@@ -1690,33 +1672,35 @@ class PGArrayRoundTripTest(
):
ARRAY = postgresql.ARRAY
- @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),))
- def test_undim_array_contains_typed_exec(self, struct):
+ @testing.combinations(
+ (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"
+ )
+ def test_undim_array_contains_typed_exec(self, struct, connection):
arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
- with testing.db.begin() as conn:
- eq_(
- conn.scalar(
- select(arrtable.c.intarr).where(
- arrtable.c.intarr.contains(struct([4, 5]))
- )
- ),
- [4, 5, 6],
- )
+ self._fixture_456(arrtable, connection)
+ eq_(
+ connection.scalar(
+ select(arrtable.c.intarr).where(
+ arrtable.c.intarr.contains(struct([4, 5]))
+ )
+ ),
+ [4, 5, 6],
+ )
- @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),))
- def test_dim_array_contains_typed_exec(self, struct):
+ @testing.combinations(
+ (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"
+ )
+ def test_dim_array_contains_typed_exec(self, struct, connection):
dim_arrtable = self.tables.dim_arrtable
- self._fixture_456(dim_arrtable)
- with testing.db.begin() as conn:
- eq_(
- conn.scalar(
- select(dim_arrtable.c.intarr).where(
- dim_arrtable.c.intarr.contains(struct([4, 5]))
- )
- ),
- [4, 5, 6],
- )
+ self._fixture_456(dim_arrtable, connection)
+ eq_(
+ connection.scalar(
+ select(dim_arrtable.c.intarr).where(
+ dim_arrtable.c.intarr.contains(struct([4, 5]))
+ )
+ ),
+ [4, 5, 6],
+ )
def test_array_contained_by_exec(self, connection):
arrtable = self.tables.arrtable
@@ -1730,7 +1714,7 @@ class PGArrayRoundTripTest(
def test_undim_array_empty(self, connection):
arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
+ self._fixture_456(arrtable, connection)
eq_(
connection.scalar(
select(arrtable.c.intarr).where(arrtable.c.intarr.contains([]))
@@ -1782,8 +1766,9 @@ class ArrayEnum(fixtures.TestBase):
sqltypes.ARRAY, postgresql.ARRAY, argnames="array_cls"
)
@testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
- @testing.provide_metadata
- def test_raises_non_native_enums(self, array_cls, enum_cls):
+ def test_raises_non_native_enums(
+ self, metadata, connection, array_cls, enum_cls
+ ):
Table(
"my_table",
self.metadata,
@@ -1808,7 +1793,7 @@ class ArrayEnum(fixtures.TestBase):
"for ARRAY of non-native ENUM; please specify "
"create_constraint=False on this Enum datatype.",
self.metadata.create_all,
- testing.db,
+ connection,
)
@testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
@@ -1818,8 +1803,7 @@ class ArrayEnum(fixtures.TestBase):
(_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
argnames="array_cls",
)
- @testing.provide_metadata
- def test_array_of_enums(self, array_cls, enum_cls, connection):
+ def test_array_of_enums(self, array_cls, enum_cls, metadata, connection):
tbl = Table(
"enum_table",
self.metadata,
@@ -1875,8 +1859,7 @@ class ArrayJSON(fixtures.TestBase):
@testing.combinations(
sqltypes.JSON, postgresql.JSON, postgresql.JSONB, argnames="json_cls"
)
- @testing.provide_metadata
- def test_array_of_json(self, array_cls, json_cls, connection):
+ def test_array_of_json(self, array_cls, json_cls, metadata, connection):
tbl = Table(
"json_table",
self.metadata,
@@ -1982,19 +1965,38 @@ class HashableFlagORMTest(fixtures.TestBase):
},
],
),
+ (
+ "HSTORE",
+ postgresql.HSTORE(),
+ [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}],
+ testing.requires.hstore,
+ ),
+ (
+ "JSONB",
+ postgresql.JSONB(),
+ [
+ {"a": "1", "b": "2", "c": "3"},
+ {
+ "d": "4",
+ "e": {"e1": "5", "e2": "6"},
+ "f": {"f1": [9, 10, 11]},
+ },
+ ],
+ testing.requires.postgresql_jsonb,
+ ),
+ argnames="type_,data",
id_="iaa",
)
- @testing.provide_metadata
- def test_hashable_flag(self, type_, data):
- Base = declarative_base(metadata=self.metadata)
+ def test_hashable_flag(self, metadata, connection, type_, data):
+ Base = declarative_base(metadata=metadata)
class A(Base):
__tablename__ = "a1"
id = Column(Integer, primary_key=True)
data = Column(type_)
- Base.metadata.create_all(testing.db)
- s = Session(testing.db)
+ Base.metadata.create_all(connection)
+ s = Session(connection)
s.add_all([A(data=elem) for elem in data])
s.commit()
@@ -2006,27 +2008,6 @@ class HashableFlagORMTest(fixtures.TestBase):
list(enumerate(data, 1)),
)
- @testing.requires.hstore
- def test_hstore(self):
- self.test_hashable_flag(
- postgresql.HSTORE(),
- [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}],
- )
-
- @testing.requires.postgresql_jsonb
- def test_jsonb(self):
- self.test_hashable_flag(
- postgresql.JSONB(),
- [
- {"a": "1", "b": "2", "c": "3"},
- {
- "d": "4",
- "e": {"e1": "5", "e2": "6"},
- "f": {"f1": [9, 10, 11]},
- },
- ],
- )
-
class TimestampTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql"
@@ -2108,14 +2089,14 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
return table
- def test_reflection(self, special_types_table):
+ def test_reflection(self, special_types_table, connection):
# cheat so that the "strict type check"
# works
special_types_table.c.year_interval.type = postgresql.INTERVAL()
special_types_table.c.month_interval.type = postgresql.INTERVAL()
m = MetaData()
- t = Table("sometable", m, autoload_with=testing.db)
+ t = Table("sometable", m, autoload_with=connection)
self.assert_tables_equal(special_types_table, t, strict_types=True)
assert t.c.plain_interval.type.precision is None
@@ -2210,7 +2191,7 @@ class UUIDTest(fixtures.TestBase):
class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
__dialect__ = "postgresql"
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
@@ -2494,17 +2475,16 @@ class HStoreRoundTripTest(fixtures.TablesTest):
Column("data", HSTORE),
)
- def _fixture_data(self, engine):
+ def _fixture_data(self, connection):
data_table = self.tables.data_table
- with engine.begin() as conn:
- conn.execute(
- data_table.insert(),
- {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
- {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
- {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}},
- {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}},
- {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}},
- )
+ connection.execute(
+ data_table.insert(),
+ {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
+ {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
+ {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}},
+ {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}},
+ {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}},
+ )
def _assert_data(self, compare, conn):
data = conn.execute(
@@ -2514,26 +2494,32 @@ class HStoreRoundTripTest(fixtures.TablesTest):
).fetchall()
eq_([d for d, in data], compare)
- def _test_insert(self, engine):
- with engine.begin() as conn:
- conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
- )
- self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn)
+ def _test_insert(self, connection):
+ connection.execute(
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
+ )
+ self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection)
- def _non_native_engine(self):
- if testing.requires.psycopg2_native_hstore.enabled:
- engine = engines.testing_engine(
- options=dict(use_native_hstore=False)
- )
+ @testing.fixture
+ def non_native_hstore_connection(self, testing_engine):
+ local_engine = testing.requires.psycopg2_native_hstore.enabled
+
+ if local_engine:
+ engine = testing_engine(options=dict(use_native_hstore=False))
else:
engine = testing.db
- engine.connect().close()
- return engine
- def test_reflect(self):
- insp = inspect(testing.db)
+ conn = engine.connect()
+ trans = conn.begin()
+ yield conn
+ try:
+ trans.rollback()
+ finally:
+ conn.close()
+
+ def test_reflect(self, connection):
+ insp = inspect(connection)
cols = insp.get_columns("data_table")
assert isinstance(cols[2]["type"], HSTORE)
@@ -2548,106 +2534,88 @@ class HStoreRoundTripTest(fixtures.TablesTest):
eq_(connection.scalar(select(expr)), "3")
@testing.requires.psycopg2_native_hstore
- def test_insert_native(self):
- engine = testing.db
- self._test_insert(engine)
+ def test_insert_native(self, connection):
+ self._test_insert(connection)
- def test_insert_python(self):
- engine = self._non_native_engine()
- self._test_insert(engine)
+ def test_insert_python(self, non_native_hstore_connection):
+ self._test_insert(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_criterion_native(self):
- engine = testing.db
- self._fixture_data(engine)
- self._test_criterion(engine)
+ def test_criterion_native(self, connection):
+ self._fixture_data(connection)
+ self._test_criterion(connection)
- def test_criterion_python(self):
- engine = self._non_native_engine()
- self._fixture_data(engine)
- self._test_criterion(engine)
+ def test_criterion_python(self, non_native_hstore_connection):
+ self._fixture_data(non_native_hstore_connection)
+ self._test_criterion(non_native_hstore_connection)
- def _test_criterion(self, engine):
+ def _test_criterion(self, connection):
data_table = self.tables.data_table
- with engine.begin() as conn:
- result = conn.execute(
- select(data_table.c.data).where(
- data_table.c.data["k1"] == "r3v1"
- )
- ).first()
- eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
+ result = connection.execute(
+ select(data_table.c.data).where(data_table.c.data["k1"] == "r3v1")
+ ).first()
+ eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
- def _test_fixed_round_trip(self, engine):
- with engine.begin() as conn:
- s = select(
- hstore(
- array(["key1", "key2", "key3"]),
- array(["value1", "value2", "value3"]),
- )
- )
- eq_(
- conn.scalar(s),
- {"key1": "value1", "key2": "value2", "key3": "value3"},
+ def _test_fixed_round_trip(self, connection):
+ s = select(
+ hstore(
+ array(["key1", "key2", "key3"]),
+ array(["value1", "value2", "value3"]),
)
+ )
+ eq_(
+ connection.scalar(s),
+ {"key1": "value1", "key2": "value2", "key3": "value3"},
+ )
- def test_fixed_round_trip_python(self):
- engine = self._non_native_engine()
- self._test_fixed_round_trip(engine)
+ def test_fixed_round_trip_python(self, non_native_hstore_connection):
+ self._test_fixed_round_trip(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_fixed_round_trip_native(self):
- engine = testing.db
- self._test_fixed_round_trip(engine)
+ def test_fixed_round_trip_native(self, connection):
+ self._test_fixed_round_trip(connection)
- def _test_unicode_round_trip(self, engine):
- with engine.begin() as conn:
- s = select(
- hstore(
- array(
- [util.u("réveillé"), util.u("drôle"), util.u("S’il")]
- ),
- array(
- [util.u("réveillé"), util.u("drôle"), util.u("S’il")]
- ),
- )
- )
- eq_(
- conn.scalar(s),
- {
- util.u("réveillé"): util.u("réveillé"),
- util.u("drôle"): util.u("drôle"),
- util.u("S’il"): util.u("S’il"),
- },
+ def _test_unicode_round_trip(self, connection):
+ s = select(
+ hstore(
+ array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]),
+ array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]),
)
+ )
+ eq_(
+ connection.scalar(s),
+ {
+ util.u("réveillé"): util.u("réveillé"),
+ util.u("drôle"): util.u("drôle"),
+ util.u("S’il"): util.u("S’il"),
+ },
+ )
@testing.requires.psycopg2_native_hstore
- def test_unicode_round_trip_python(self):
- engine = self._non_native_engine()
- self._test_unicode_round_trip(engine)
+ def test_unicode_round_trip_python(self, non_native_hstore_connection):
+ self._test_unicode_round_trip(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_unicode_round_trip_native(self):
- engine = testing.db
- self._test_unicode_round_trip(engine)
+ def test_unicode_round_trip_native(self, connection):
+ self._test_unicode_round_trip(connection)
- def test_escaped_quotes_round_trip_python(self):
- engine = self._non_native_engine()
- self._test_escaped_quotes_round_trip(engine)
+ def test_escaped_quotes_round_trip_python(
+ self, non_native_hstore_connection
+ ):
+ self._test_escaped_quotes_round_trip(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_escaped_quotes_round_trip_native(self):
- engine = testing.db
- self._test_escaped_quotes_round_trip(engine)
+ def test_escaped_quotes_round_trip_native(self, connection):
+ self._test_escaped_quotes_round_trip(connection)
- def _test_escaped_quotes_round_trip(self, engine):
- with engine.begin() as conn:
- conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}},
- )
- self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], conn)
+ def _test_escaped_quotes_round_trip(self, connection):
+ connection.execute(
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}},
+ )
+ self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], connection)
- def test_orm_round_trip(self):
+ def test_orm_round_trip(self, connection):
from sqlalchemy import orm
class Data(object):
@@ -2656,13 +2624,14 @@ class HStoreRoundTripTest(fixtures.TablesTest):
self.data = data
orm.mapper(Data, self.tables.data_table)
- s = orm.Session(testing.db)
- d = Data(
- name="r1",
- data={"key1": "value1", "key2": "value2", "key3": "value3"},
- )
- s.add(d)
- eq_(s.query(Data.data, Data).all(), [(d.data, d)])
+
+ with orm.Session(connection) as s:
+ d = Data(
+ name="r1",
+ data={"key1": "value1", "key2": "value2", "key3": "value3"},
+ )
+ s.add(d)
+ eq_(s.query(Data.data, Data).all(), [(d.data, d)])
class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
@@ -2671,7 +2640,7 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
# operator tests
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
table = Table(
"data_table",
MetaData(),
@@ -2852,10 +2821,10 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
def test_actual_type(self):
eq_(str(self._col_type()), self._col_str)
- def test_reflect(self):
+ def test_reflect(self, connection):
from sqlalchemy import inspect
- insp = inspect(testing.db)
+ insp = inspect(connection)
cols = insp.get_columns("data_table")
assert isinstance(cols[0]["type"], self._col_type)
@@ -2986,8 +2955,8 @@ class _DateTimeTZRangeTests(object):
def tstzs(self):
if self._tstzs is None:
- with testing.db.begin() as conn:
- lower = conn.scalar(func.current_timestamp().select())
+ with testing.db.connect() as connection:
+ lower = connection.scalar(func.current_timestamp().select())
upper = lower + datetime.timedelta(1)
self._tstzs = (lower, upper)
return self._tstzs
@@ -3052,7 +3021,7 @@ class DateTimeTZRangeRoundTripTest(_DateTimeTZRangeTests, _RangeTypeRoundTrip):
class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
__dialect__ = "postgresql"
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
@@ -3151,7 +3120,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
Column("nulldata", cls.data_type(none_as_null=True)),
)
- def _fixture_data(self, engine):
+ def _fixture_data(self, connection):
data_table = self.tables.data_table
data = [
@@ -3162,8 +3131,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
{"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}},
{"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}},
]
- with engine.begin() as conn:
- conn.execute(data_table.insert(), data)
+ connection.execute(data_table.insert(), data)
return data
def _assert_data(self, compare, conn, column="data"):
@@ -3185,51 +3153,39 @@ class JSONRoundTripTest(fixtures.TablesTest):
).fetchall()
eq_([d for d, in data], [None])
- def _test_insert(self, conn):
- conn.execute(
+ def test_reflect(self, connection):
+ insp = inspect(connection)
+ cols = insp.get_columns("data_table")
+ assert isinstance(cols[2]["type"], self.data_type)
+
+ def test_insert(self, connection):
+ connection.execute(
self.tables.data_table.insert(),
{"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
)
- self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn)
+ self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection)
- def _test_insert_nulls(self, conn):
- conn.execute(
+ def test_insert_nulls(self, connection):
+ connection.execute(
self.tables.data_table.insert(), {"name": "r1", "data": null()}
)
- self._assert_data([None], conn)
+ self._assert_data([None], connection)
- def _test_insert_none_as_null(self, conn):
- conn.execute(
+ def test_insert_none_as_null(self, connection):
+ connection.execute(
self.tables.data_table.insert(),
{"name": "r1", "nulldata": None},
)
- self._assert_column_is_NULL(conn, column="nulldata")
+ self._assert_column_is_NULL(connection, column="nulldata")
- def _test_insert_nulljson_into_none_as_null(self, conn):
- conn.execute(
+ def test_insert_nulljson_into_none_as_null(self, connection):
+ connection.execute(
self.tables.data_table.insert(),
{"name": "r1", "nulldata": JSON.NULL},
)
- self._assert_column_is_JSON_NULL(conn, column="nulldata")
-
- def test_reflect(self):
- insp = inspect(testing.db)
- cols = insp.get_columns("data_table")
- assert isinstance(cols[2]["type"], self.data_type)
-
- def test_insert(self, connection):
- self._test_insert(connection)
-
- def test_insert_nulls(self, connection):
- self._test_insert_nulls(connection)
+ self._assert_column_is_JSON_NULL(connection, column="nulldata")
- def test_insert_none_as_null(self, connection):
- self._test_insert_none_as_null(connection)
-
- def test_insert_nulljson_into_none_as_null(self, connection):
- self._test_insert_nulljson_into_none_as_null(connection)
-
- def test_custom_serialize_deserialize(self):
+ def test_custom_serialize_deserialize(self, testing_engine):
import json
def loads(value):
@@ -3242,7 +3198,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
value["x"] = "dumps_y"
return json.dumps(value)
- engine = engines.testing_engine(
+ engine = testing_engine(
options=dict(json_serializer=dumps, json_deserializer=loads)
)
@@ -3250,14 +3206,26 @@ class JSONRoundTripTest(fixtures.TablesTest):
with engine.begin() as conn:
eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
- def test_criterion(self):
- engine = testing.db
- self._fixture_data(engine)
- self._test_criterion(engine)
+ def test_criterion(self, connection):
+ self._fixture_data(connection)
+ data_table = self.tables.data_table
+
+ result = connection.execute(
+ select(data_table.c.data).where(
+ data_table.c.data["k1"].astext == "r3v1"
+ )
+ ).first()
+ eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
+
+ result = connection.execute(
+ select(data_table.c.data).where(
+ data_table.c.data["k1"].astext.cast(String) == "r3v1"
+ )
+ ).first()
+ eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
def test_path_query(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
@@ -3271,8 +3239,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
"postgresql < 9.4", "Improvement in PostgreSQL behavior?"
)
def test_multi_index_query(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
@@ -3283,20 +3250,18 @@ class JSONRoundTripTest(fixtures.TablesTest):
eq_(result.scalar(), "r6")
def test_query_returned_as_text(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
select(data_table.c.data["k1"].astext)
).first()
- if engine.dialect.returns_unicode_strings:
+ if connection.dialect.returns_unicode_strings:
assert isinstance(result[0], util.text_type)
else:
assert isinstance(result[0], util.string_types)
def test_query_returned_as_int(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
select(data_table.c.data["k3"].astext.cast(Integer)).where(
@@ -3305,23 +3270,6 @@ class JSONRoundTripTest(fixtures.TablesTest):
).first()
assert isinstance(result[0], int)
- def _test_criterion(self, engine):
- data_table = self.tables.data_table
- with engine.begin() as conn:
- result = conn.execute(
- select(data_table.c.data).where(
- data_table.c.data["k1"].astext == "r3v1"
- )
- ).first()
- eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
-
- result = conn.execute(
- select(data_table.c.data).where(
- data_table.c.data["k1"].astext.cast(String) == "r3v1"
- )
- ).first()
- eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
-
def test_fixed_round_trip(self, connection):
s = select(
cast(
@@ -3352,42 +3300,41 @@ class JSONRoundTripTest(fixtures.TablesTest):
},
)
- def test_eval_none_flag_orm(self):
+ def test_eval_none_flag_orm(self, connection):
Base = declarative_base()
class Data(Base):
__table__ = self.tables.data_table
- s = Session(testing.db)
+ with Session(connection) as s:
+ d1 = Data(name="d1", data=None, nulldata=None)
+ s.add(d1)
+ s.commit()
- d1 = Data(name="d1", data=None, nulldata=None)
- s.add(d1)
- s.commit()
-
- s.bulk_insert_mappings(
- Data, [{"name": "d2", "data": None, "nulldata": None}]
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String),
- cast(self.tables.data_table.c.nulldata, String),
+ s.bulk_insert_mappings(
+ Data, [{"name": "d2", "data": None, "nulldata": None}]
)
- .filter(self.tables.data_table.c.name == "d1")
- .first(),
- ("null", None),
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String),
- cast(self.tables.data_table.c.nulldata, String),
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
)
- .filter(self.tables.data_table.c.name == "d2")
- .first(),
- ("null", None),
- )
def test_literal(self, connection):
- exp = self._fixture_data(testing.db)
+ exp = self._fixture_data(connection)
result = connection.exec_driver_sql(
"select data from data_table order by name"
)
@@ -3395,11 +3342,10 @@ class JSONRoundTripTest(fixtures.TablesTest):
eq_(len(res), len(exp))
for row, expected in zip(res, exp):
eq_(row[0], expected["data"])
- result.close()
class JSONBTest(JSONTest):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py
index 4658b40a8..1926c6065 100644
--- a/test/dialect/test_sqlite.py
+++ b/test/dialect/test_sqlite.py
@@ -37,6 +37,7 @@ from sqlalchemy import UniqueConstraint
from sqlalchemy import util
from sqlalchemy.dialects.sqlite import base as sqlite
from sqlalchemy.dialects.sqlite import insert
+from sqlalchemy.dialects.sqlite import provision
from sqlalchemy.dialects.sqlite import pysqlite as pysqlite_dialect
from sqlalchemy.engine.url import make_url
from sqlalchemy.schema import CreateTable
@@ -46,6 +47,7 @@ from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import combinations
+from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_warnings
@@ -774,7 +776,7 @@ class AttachedDBTest(fixtures.TestBase):
def _fixture(self):
meta = self.metadata
- self.conn = testing.db.connect()
+ self.conn = self.engine.connect()
Table("created", meta, Column("foo", Integer), Column("bar", String))
Table("local_only", meta, Column("q", Integer), Column("p", Integer))
@@ -798,14 +800,20 @@ class AttachedDBTest(fixtures.TestBase):
meta.create_all(self.conn)
return ct
- def setup(self):
- self.conn = testing.db.connect()
+ def setup_test(self):
+ self.engine = engines.testing_engine(options={"use_reaper": False})
+
+ provision._sqlite_post_configure_engine(
+ self.engine.url, self.engine, config.ident
+ )
+ self.conn = self.engine.connect()
self.metadata = MetaData()
- def teardown(self):
+ def teardown_test(self):
with self.conn.begin():
self.metadata.drop_all(self.conn)
self.conn.close()
+ self.engine.dispose()
def test_no_tables(self):
insp = inspect(self.conn)
@@ -1495,7 +1503,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
__skip_if__ = (full_text_search_missing,)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global metadata, cattable, matchtable
metadata = MetaData()
exec_sql(
@@ -1559,7 +1567,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
metadata.drop_all(testing.db)
def test_expression(self):
@@ -1681,7 +1689,7 @@ class AutoIncrementTest(fixtures.TestBase, AssertsCompiledSQL):
class ReflectHeadlessFKsTest(fixtures.TestBase):
__only_on__ = "sqlite"
- def setup(self):
+ def setup_test(self):
exec_sql(testing.db, "CREATE TABLE a (id INTEGER PRIMARY KEY)")
# this syntax actually works on other DBs perhaps we'd want to add
# tests to test_reflection
@@ -1689,7 +1697,7 @@ class ReflectHeadlessFKsTest(fixtures.TestBase):
testing.db, "CREATE TABLE b (id INTEGER PRIMARY KEY REFERENCES a)"
)
- def teardown(self):
+ def teardown_test(self):
exec_sql(testing.db, "drop table b")
exec_sql(testing.db, "drop table a")
@@ -1728,7 +1736,7 @@ class ConstraintReflectionTest(fixtures.TestBase):
__only_on__ = "sqlite"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.begin() as conn:
conn.exec_driver_sql("CREATE TABLE a1 (id INTEGER PRIMARY KEY)")
@@ -1876,7 +1884,7 @@ class ConstraintReflectionTest(fixtures.TestBase):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
for name in [
"implicit_referrer_comp_fake",
@@ -2370,7 +2378,7 @@ class SavepointTest(fixtures.TablesTest):
@classmethod
def setup_bind(cls):
- engine = engines.testing_engine(options={"use_reaper": False})
+ engine = engines.testing_engine(options={"scope": "class"})
@event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
@@ -2579,7 +2587,7 @@ class TypeReflectionTest(fixtures.TestBase):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "sqlite"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py
index 396b48aa4..baa766d48 100644
--- a/test/engine/test_ddlevents.py
+++ b/test/engine/test_ddlevents.py
@@ -21,7 +21,7 @@ from sqlalchemy.testing.schema import Table
class DDLEventTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.bind = engines.mock_engine()
self.metadata = MetaData()
self.table = Table("t", self.metadata, Column("id", Integer))
@@ -374,7 +374,7 @@ class DDLEventTest(fixtures.TestBase):
class DDLExecutionTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.engine = engines.mock_engine()
self.metadata = MetaData()
self.users = Table(
diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py
index a18cf756b..0a2c9abe5 100644
--- a/test/engine/test_deprecations.py
+++ b/test/engine/test_deprecations.py
@@ -965,7 +965,7 @@ class TransactionTest(fixtures.TablesTest):
class HandleInvalidatedOnConnectTest(fixtures.TestBase):
__requires__ = ("sqlite",)
- def setUp(self):
+ def setup_test(self):
e = create_engine("sqlite://")
connection = Mock(get_server_version_info=Mock(return_value="5.0"))
@@ -1021,18 +1021,18 @@ def MockDBAPI(): # noqa
class PoolTestBase(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
pool.clear_managers()
self._teardown_conns = []
- def teardown(self):
+ def teardown_test(self):
for ref in self._teardown_conns:
conn = ref()
if conn:
conn.close()
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
pool.clear_managers()
def _queuepool_fixture(self, **kw):
@@ -1597,7 +1597,7 @@ class EngineEventsTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
Engine.dispatch._clear()
Engine._has_events = False
@@ -1650,6 +1650,7 @@ class EngineEventsTest(fixtures.TestBase):
event.listen(
engine, "before_cursor_execute", cursor_execute, retval=True
)
+
with testing.expect_deprecated(
r"The argument signature for the "
r"\"ConnectionEvents.before_execute\" event listener",
@@ -1676,11 +1677,12 @@ class EngineEventsTest(fixtures.TestBase):
r"The argument signature for the "
r"\"ConnectionEvents.after_execute\" event listener",
):
- e1.execute(select(1))
+ result = e1.execute(select(1))
+ result.close()
class DDLExecutionTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.engine = engines.mock_engine()
self.metadata = MetaData()
self.users = Table(
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index 21d4e06e0..a1e4ea218 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -43,7 +43,6 @@ from sqlalchemy.testing import is_not
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertsql import CompiledSQL
-from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing.mock import call
from sqlalchemy.testing.mock import Mock
from sqlalchemy.testing.mock import patch
@@ -94,13 +93,13 @@ class ExecuteTest(fixtures.TablesTest):
).default_from()
)
- conn = testing.db.connect()
- result = (
- conn.execution_options(no_parameters=True)
- .exec_driver_sql(stmt)
- .scalar()
- )
- eq_(result, "%")
+ with testing.db.connect() as conn:
+ result = (
+ conn.execution_options(no_parameters=True)
+ .exec_driver_sql(stmt)
+ .scalar()
+ )
+ eq_(result, "%")
def test_raw_positional_invalid(self, connection):
assert_raises_message(
@@ -261,16 +260,15 @@ class ExecuteTest(fixtures.TablesTest):
(4, "sally"),
]
- @testing.engines.close_open_connections
def test_exception_wrapping_dbapi(self):
- conn = testing.db.connect()
- # engine does not have exec_driver_sql
- assert_raises_message(
- tsa.exc.DBAPIError,
- r"not_a_valid_statement",
- conn.exec_driver_sql,
- "not_a_valid_statement",
- )
+ with testing.db.connect() as conn:
+ # engine does not have exec_driver_sql
+ assert_raises_message(
+ tsa.exc.DBAPIError,
+ r"not_a_valid_statement",
+ conn.exec_driver_sql,
+ "not_a_valid_statement",
+ )
@testing.requires.sqlite
def test_exception_wrapping_non_dbapi_error(self):
@@ -864,12 +862,10 @@ class CompiledCacheTest(fixtures.TestBase):
["sqlite", "mysql", "postgresql"],
"uses blob value that is problematic for some DBAPIs",
)
- @testing.provide_metadata
- def test_cache_noleak_on_statement_values(self, connection):
+ def test_cache_noleak_on_statement_values(self, metadata, connection):
# This is a non regression test for an object reference leak caused
# by the compiled_cache.
- metadata = self.metadata
photo = Table(
"photo",
metadata,
@@ -1040,7 +1036,19 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
__requires__ = ("schemas",)
__backend__ = True
- def test_create_table(self):
+ @testing.fixture
+ def plain_tables(self, metadata):
+ t1 = Table(
+ "t1", metadata, Column("x", Integer), schema=config.test_schema
+ )
+ t2 = Table(
+ "t2", metadata, Column("x", Integer), schema=config.test_schema
+ )
+ t3 = Table("t3", metadata, Column("x", Integer), schema=None)
+
+ return t1, t2, t3
+
+ def test_create_table(self, plain_tables, connection):
map_ = {
None: config.test_schema,
"foo": config.test_schema,
@@ -1052,18 +1060,16 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
- with self.sql_execution_asserter(config.db) as asserter:
- with config.db.begin() as conn, conn.execution_options(
- schema_translate_map=map_
- ) as conn:
+ with self.sql_execution_asserter(connection) as asserter:
+ conn = connection.execution_options(schema_translate_map=map_)
- t1.create(conn)
- t2.create(conn)
- t3.create(conn)
+ t1.create(conn)
+ t2.create(conn)
+ t3.create(conn)
- t3.drop(conn)
- t2.drop(conn)
- t1.drop(conn)
+ t3.drop(conn)
+ t2.drop(conn)
+ t1.drop(conn)
asserter.assert_(
CompiledSQL("CREATE TABLE [SCHEMA__none].t1 (x INTEGER)"),
@@ -1074,14 +1080,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
CompiledSQL("DROP TABLE [SCHEMA__none].t1"),
)
- def _fixture(self):
- metadata = self.metadata
- Table("t1", metadata, Column("x", Integer), schema=config.test_schema)
- Table("t2", metadata, Column("x", Integer), schema=config.test_schema)
- Table("t3", metadata, Column("x", Integer), schema=None)
- metadata.create_all(testing.db)
-
- def test_ddl_hastable(self):
+ def test_ddl_hastable(self, plain_tables, connection):
map_ = {
None: config.test_schema,
@@ -1094,27 +1093,28 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
Table("t2", metadata, Column("x", Integer), schema="foo")
Table("t3", metadata, Column("x", Integer), schema="bar")
- with config.db.begin() as conn:
- conn = conn.execution_options(schema_translate_map=map_)
- metadata.create_all(conn)
+ conn = connection.execution_options(schema_translate_map=map_)
+ metadata.create_all(conn)
- insp = inspect(config.db)
+ insp = inspect(connection)
is_true(insp.has_table("t1", schema=config.test_schema))
is_true(insp.has_table("t2", schema=config.test_schema))
is_true(insp.has_table("t3", schema=None))
- with config.db.begin() as conn:
- conn = conn.execution_options(schema_translate_map=map_)
- metadata.drop_all(conn)
+ conn = connection.execution_options(schema_translate_map=map_)
+
+ # if this test fails, the tables won't get dropped. so need a
+ # more robust fixture for this
+ metadata.drop_all(conn)
- insp = inspect(config.db)
+ insp = inspect(connection)
is_false(insp.has_table("t1", schema=config.test_schema))
is_false(insp.has_table("t2", schema=config.test_schema))
is_false(insp.has_table("t3", schema=None))
- @testing.provide_metadata
- def test_option_on_execute(self):
- self._fixture()
+ def test_option_on_execute(self, plain_tables, connection):
+ # provided by metadata fixture provided by plain_tables fixture
+ self.metadata.create_all(connection)
map_ = {
None: config.test_schema,
@@ -1127,61 +1127,54 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
- with self.sql_execution_asserter(config.db) as asserter:
- with config.db.begin() as conn:
+ with self.sql_execution_asserter(connection) as asserter:
+ conn = connection
+ execution_options = {"schema_translate_map": map_}
+ conn._execute_20(
+ t1.insert(), {"x": 1}, execution_options=execution_options
+ )
+ conn._execute_20(
+ t2.insert(), {"x": 1}, execution_options=execution_options
+ )
+ conn._execute_20(
+ t3.insert(), {"x": 1}, execution_options=execution_options
+ )
- execution_options = {"schema_translate_map": map_}
- conn._execute_20(
- t1.insert(), {"x": 1}, execution_options=execution_options
- )
- conn._execute_20(
- t2.insert(), {"x": 1}, execution_options=execution_options
- )
- conn._execute_20(
- t3.insert(), {"x": 1}, execution_options=execution_options
- )
+ conn._execute_20(
+ t1.update().values(x=1).where(t1.c.x == 1),
+ execution_options=execution_options,
+ )
+ conn._execute_20(
+ t2.update().values(x=2).where(t2.c.x == 1),
+ execution_options=execution_options,
+ )
+ conn._execute_20(
+ t3.update().values(x=3).where(t3.c.x == 1),
+ execution_options=execution_options,
+ )
+ eq_(
conn._execute_20(
- t1.update().values(x=1).where(t1.c.x == 1),
- execution_options=execution_options,
- )
+ select(t1.c.x), execution_options=execution_options
+ ).scalar(),
+ 1,
+ )
+ eq_(
conn._execute_20(
- t2.update().values(x=2).where(t2.c.x == 1),
- execution_options=execution_options,
- )
+ select(t2.c.x), execution_options=execution_options
+ ).scalar(),
+ 2,
+ )
+ eq_(
conn._execute_20(
- t3.update().values(x=3).where(t3.c.x == 1),
- execution_options=execution_options,
- )
-
- eq_(
- conn._execute_20(
- select(t1.c.x), execution_options=execution_options
- ).scalar(),
- 1,
- )
- eq_(
- conn._execute_20(
- select(t2.c.x), execution_options=execution_options
- ).scalar(),
- 2,
- )
- eq_(
- conn._execute_20(
- select(t3.c.x), execution_options=execution_options
- ).scalar(),
- 3,
- )
+ select(t3.c.x), execution_options=execution_options
+ ).scalar(),
+ 3,
+ )
- conn._execute_20(
- t1.delete(), execution_options=execution_options
- )
- conn._execute_20(
- t2.delete(), execution_options=execution_options
- )
- conn._execute_20(
- t3.delete(), execution_options=execution_options
- )
+ conn._execute_20(t1.delete(), execution_options=execution_options)
+ conn._execute_20(t2.delete(), execution_options=execution_options)
+ conn._execute_20(t3.delete(), execution_options=execution_options)
asserter.assert_(
CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"),
@@ -1207,9 +1200,9 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
CompiledSQL("DELETE FROM [SCHEMA_bar].t3"),
)
- @testing.provide_metadata
- def test_crud(self):
- self._fixture()
+ def test_crud(self, plain_tables, connection):
+ # provided by metadata fixture provided by plain_tables fixture
+ self.metadata.create_all(connection)
map_ = {
None: config.test_schema,
@@ -1222,26 +1215,24 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
- with self.sql_execution_asserter(config.db) as asserter:
- with config.db.begin() as conn, conn.execution_options(
- schema_translate_map=map_
- ) as conn:
+ with self.sql_execution_asserter(connection) as asserter:
+ conn = connection.execution_options(schema_translate_map=map_)
- conn.execute(t1.insert(), {"x": 1})
- conn.execute(t2.insert(), {"x": 1})
- conn.execute(t3.insert(), {"x": 1})
+ conn.execute(t1.insert(), {"x": 1})
+ conn.execute(t2.insert(), {"x": 1})
+ conn.execute(t3.insert(), {"x": 1})
- conn.execute(t1.update().values(x=1).where(t1.c.x == 1))
- conn.execute(t2.update().values(x=2).where(t2.c.x == 1))
- conn.execute(t3.update().values(x=3).where(t3.c.x == 1))
+ conn.execute(t1.update().values(x=1).where(t1.c.x == 1))
+ conn.execute(t2.update().values(x=2).where(t2.c.x == 1))
+ conn.execute(t3.update().values(x=3).where(t3.c.x == 1))
- eq_(conn.scalar(select(t1.c.x)), 1)
- eq_(conn.scalar(select(t2.c.x)), 2)
- eq_(conn.scalar(select(t3.c.x)), 3)
+ eq_(conn.scalar(select(t1.c.x)), 1)
+ eq_(conn.scalar(select(t2.c.x)), 2)
+ eq_(conn.scalar(select(t3.c.x)), 3)
- conn.execute(t1.delete())
- conn.execute(t2.delete())
- conn.execute(t3.delete())
+ conn.execute(t1.delete())
+ conn.execute(t2.delete())
+ conn.execute(t3.delete())
asserter.assert_(
CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"),
@@ -1267,9 +1258,10 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
CompiledSQL("DELETE FROM [SCHEMA_bar].t3"),
)
- @testing.provide_metadata
- def test_via_engine(self):
- self._fixture()
+ def test_via_engine(self, plain_tables, metadata):
+
+ with config.db.begin() as connection:
+ metadata.create_all(connection)
map_ = {
None: config.test_schema,
@@ -1282,25 +1274,25 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
with self.sql_execution_asserter(config.db) as asserter:
eng = config.db.execution_options(schema_translate_map=map_)
- conn = eng.connect()
- conn.execute(select(t2.c.x))
+ with eng.connect() as conn:
+ conn.execute(select(t2.c.x))
asserter.assert_(
CompiledSQL("SELECT [SCHEMA_foo].t2.x FROM [SCHEMA_foo].t2")
)
class ExecutionOptionsTest(fixtures.TestBase):
- def test_dialect_conn_options(self):
+ def test_dialect_conn_options(self, testing_engine):
engine = testing_engine("sqlite://", options=dict(_initialize=False))
engine.dialect = Mock()
- conn = engine.connect()
- c2 = conn.execution_options(foo="bar")
- eq_(
- engine.dialect.set_connection_execution_options.mock_calls,
- [call(c2, {"foo": "bar"})],
- )
+ with engine.connect() as conn:
+ c2 = conn.execution_options(foo="bar")
+ eq_(
+ engine.dialect.set_connection_execution_options.mock_calls,
+ [call(c2, {"foo": "bar"})],
+ )
- def test_dialect_engine_options(self):
+ def test_dialect_engine_options(self, testing_engine):
engine = testing_engine("sqlite://")
engine.dialect = Mock()
e2 = engine.execution_options(foo="bar")
@@ -1319,14 +1311,14 @@ class ExecutionOptionsTest(fixtures.TestBase):
[call(engine, {"foo": "bar"})],
)
- def test_propagate_engine_to_connection(self):
+ def test_propagate_engine_to_connection(self, testing_engine):
engine = testing_engine(
"sqlite://", options=dict(execution_options={"foo": "bar"})
)
- conn = engine.connect()
- eq_(conn._execution_options, {"foo": "bar"})
+ with engine.connect() as conn:
+ eq_(conn._execution_options, {"foo": "bar"})
- def test_propagate_option_engine_to_connection(self):
+ def test_propagate_option_engine_to_connection(self, testing_engine):
e1 = testing_engine(
"sqlite://", options=dict(execution_options={"foo": "bar"})
)
@@ -1336,27 +1328,30 @@ class ExecutionOptionsTest(fixtures.TestBase):
eq_(c1._execution_options, {"foo": "bar"})
eq_(c2._execution_options, {"foo": "bar", "bat": "hoho"})
- def test_get_engine_execution_options(self):
+ c1.close()
+ c2.close()
+
+ def test_get_engine_execution_options(self, testing_engine):
engine = testing_engine("sqlite://")
engine.dialect = Mock()
e2 = engine.execution_options(foo="bar")
eq_(e2.get_execution_options(), {"foo": "bar"})
- def test_get_connection_execution_options(self):
+ def test_get_connection_execution_options(self, testing_engine):
engine = testing_engine("sqlite://", options=dict(_initialize=False))
engine.dialect = Mock()
- conn = engine.connect()
- c = conn.execution_options(foo="bar")
+ with engine.connect() as conn:
+ c = conn.execution_options(foo="bar")
- eq_(c.get_execution_options(), {"foo": "bar"})
+ eq_(c.get_execution_options(), {"foo": "bar"})
class EngineEventsTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
Engine.dispatch._clear()
Engine._has_events = False
@@ -1376,7 +1371,7 @@ class EngineEventsTest(fixtures.TestBase):
):
break
- def test_per_engine_independence(self):
+ def test_per_engine_independence(self, testing_engine):
e1 = testing_engine(config.db_url)
e2 = testing_engine(config.db_url)
@@ -1400,7 +1395,7 @@ class EngineEventsTest(fixtures.TestBase):
conn.execute(s2)
eq_([arg[1][1] for arg in canary.mock_calls], [s1, s1, s2])
- def test_per_engine_plus_global(self):
+ def test_per_engine_plus_global(self, testing_engine):
canary = Mock()
event.listen(Engine, "before_execute", canary.be1)
e1 = testing_engine(config.db_url)
@@ -1409,8 +1404,6 @@ class EngineEventsTest(fixtures.TestBase):
event.listen(e1, "before_execute", canary.be2)
event.listen(Engine, "before_execute", canary.be3)
- e1.connect()
- e2.connect()
with e1.connect() as conn:
conn.execute(select(1))
@@ -1424,7 +1417,7 @@ class EngineEventsTest(fixtures.TestBase):
eq_(canary.be2.call_count, 1)
eq_(canary.be3.call_count, 2)
- def test_per_connection_plus_engine(self):
+ def test_per_connection_plus_engine(self, testing_engine):
canary = Mock()
e1 = testing_engine(config.db_url)
@@ -1442,9 +1435,14 @@ class EngineEventsTest(fixtures.TestBase):
eq_(canary.be1.call_count, 2)
eq_(canary.be2.call_count, 2)
- @testing.combinations((True, False), (True, True), (False, False))
+ @testing.combinations(
+ (True, False),
+ (True, True),
+ (False, False),
+ argnames="mock_out_on_connect, add_our_own_onconnect",
+ )
def test_insert_connect_is_definitely_first(
- self, mock_out_on_connect, add_our_own_onconnect
+ self, mock_out_on_connect, add_our_own_onconnect, testing_engine
):
"""test issue #5708.
@@ -1478,7 +1476,7 @@ class EngineEventsTest(fixtures.TestBase):
patcher = util.nullcontext()
with patcher:
- e1 = create_engine(config.db_url)
+ e1 = testing_engine(config.db_url)
initialize = e1.dialect.initialize
@@ -1559,10 +1557,11 @@ class EngineEventsTest(fixtures.TestBase):
conn.exec_driver_sql(select1(testing.db))
eq_(m1.mock_calls, [])
- def test_add_event_after_connect(self):
+ def test_add_event_after_connect(self, testing_engine):
# new feature as of #2978
+
canary = Mock()
- e1 = create_engine(config.db_url)
+ e1 = testing_engine(config.db_url, future=False)
assert not e1._has_events
conn = e1.connect()
@@ -1575,9 +1574,9 @@ class EngineEventsTest(fixtures.TestBase):
conn._branch().execute(select(1))
eq_(canary.be1.call_count, 2)
- def test_force_conn_events_false(self):
+ def test_force_conn_events_false(self, testing_engine):
canary = Mock()
- e1 = create_engine(config.db_url)
+ e1 = testing_engine(config.db_url, future=False)
assert not e1._has_events
event.listen(e1, "before_execute", canary.be1)
@@ -1593,7 +1592,7 @@ class EngineEventsTest(fixtures.TestBase):
conn._branch().execute(select(1))
eq_(canary.be1.call_count, 0)
- def test_cursor_events_ctx_execute_scalar(self):
+ def test_cursor_events_ctx_execute_scalar(self, testing_engine):
canary = Mock()
e1 = testing_engine(config.db_url)
@@ -1620,7 +1619,7 @@ class EngineEventsTest(fixtures.TestBase):
[call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
)
- def test_cursor_events_execute(self):
+ def test_cursor_events_execute(self, testing_engine):
canary = Mock()
e1 = testing_engine(config.db_url)
@@ -1653,9 +1652,15 @@ class EngineEventsTest(fixtures.TestBase):
),
((), {"z": 10}, [], {"z": 10}, testing.requires.legacy_engine),
(({"z": 10},), {}, [], {"z": 10}),
+ argnames="multiparams, params, expected_multiparams, expected_params",
)
def test_modify_parameters_from_event_one(
- self, multiparams, params, expected_multiparams, expected_params
+ self,
+ multiparams,
+ params,
+ expected_multiparams,
+ expected_params,
+ testing_engine,
):
# this is testing both the normalization added to parameters
# as of I97cb4d06adfcc6b889f10d01cc7775925cffb116 as well as
@@ -1704,7 +1709,9 @@ class EngineEventsTest(fixtures.TestBase):
[(15,), (19,)],
)
- def test_modify_parameters_from_event_three(self, connection):
+ def test_modify_parameters_from_event_three(
+ self, connection, testing_engine
+ ):
def before_execute(
conn, clauseelement, multiparams, params, execution_options
):
@@ -1721,7 +1728,7 @@ class EngineEventsTest(fixtures.TestBase):
with e1.connect() as conn:
conn.execute(select(literal("1")))
- def test_argument_format_execute(self):
+ def test_argument_format_execute(self, testing_engine):
def before_execute(
conn, clauseelement, multiparams, params, execution_options
):
@@ -1956,9 +1963,9 @@ class EngineEventsTest(fixtures.TestBase):
)
@testing.requires.ad_hoc_engines
- def test_dispose_event(self):
+ def test_dispose_event(self, testing_engine):
canary = Mock()
- eng = create_engine(testing.db.url)
+ eng = testing_engine(testing.db.url)
event.listen(eng, "engine_disposed", canary)
conn = eng.connect()
@@ -2102,13 +2109,13 @@ class EngineEventsTest(fixtures.TestBase):
event.listen(engine, "commit", tracker("commit"))
event.listen(engine, "rollback", tracker("rollback"))
- conn = engine.connect()
- trans = conn.begin()
- conn.execute(select(1))
- trans.rollback()
- trans = conn.begin()
- conn.execute(select(1))
- trans.commit()
+ with engine.connect() as conn:
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.rollback()
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.commit()
eq_(
canary,
@@ -2145,13 +2152,13 @@ class EngineEventsTest(fixtures.TestBase):
event.listen(engine, "commit", tracker("commit"), named=True)
event.listen(engine, "rollback", tracker("rollback"), named=True)
- conn = engine.connect()
- trans = conn.begin()
- conn.execute(select(1))
- trans.rollback()
- trans = conn.begin()
- conn.execute(select(1))
- trans.commit()
+ with engine.connect() as conn:
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.rollback()
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.commit()
eq_(
canary,
@@ -2310,7 +2317,7 @@ class HandleErrorTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
Engine.dispatch._clear()
Engine._has_events = False
@@ -2742,7 +2749,7 @@ class HandleErrorTest(fixtures.TestBase):
class HandleInvalidatedOnConnectTest(fixtures.TestBase):
__requires__ = ("sqlite",)
- def setUp(self):
+ def setup_test(self):
e = create_engine("sqlite://")
connection = Mock(get_server_version_info=Mock(return_value="5.0"))
@@ -3014,6 +3021,9 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase):
],
)
+ c.close()
+ c2.close()
+
class DialectEventTest(fixtures.TestBase):
@contextmanager
@@ -3370,7 +3380,7 @@ class SetInputSizesTest(fixtures.TablesTest):
)
@testing.fixture
- def input_sizes_fixture(self):
+ def input_sizes_fixture(self, testing_engine):
canary = mock.Mock()
def do_set_input_sizes(cursor, list_of_tuples, context):
diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py
index 29b8132aa..c56589248 100644
--- a/test/engine/test_logging.py
+++ b/test/engine/test_logging.py
@@ -30,7 +30,7 @@ class LogParamsTest(fixtures.TestBase):
__only_on__ = "sqlite"
__requires__ = ("ad_hoc_engines",)
- def setup(self):
+ def setup_test(self):
self.eng = engines.testing_engine(options={"echo": True})
self.no_param_engine = engines.testing_engine(
options={"echo": True, "hide_parameters": True}
@@ -44,7 +44,7 @@ class LogParamsTest(fixtures.TestBase):
for log in [logging.getLogger("sqlalchemy.engine")]:
log.addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
exec_sql(self.eng, "drop table if exists foo")
for log in [logging.getLogger("sqlalchemy.engine")]:
log.removeHandler(self.buf)
@@ -413,14 +413,14 @@ class LogParamsTest(fixtures.TestBase):
class PoolLoggingTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.existing_level = logging.getLogger("sqlalchemy.pool").level
self.buf = logging.handlers.BufferingHandler(100)
for log in [logging.getLogger("sqlalchemy.pool")]:
log.addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
for log in [logging.getLogger("sqlalchemy.pool")]:
log.removeHandler(self.buf)
logging.getLogger("sqlalchemy.pool").setLevel(self.existing_level)
@@ -528,7 +528,7 @@ class LoggingNameTest(fixtures.TestBase):
kw.update({"echo": True})
return engines.testing_engine(options=kw)
- def setup(self):
+ def setup_test(self):
self.buf = logging.handlers.BufferingHandler(100)
for log in [
logging.getLogger("sqlalchemy.engine"),
@@ -536,7 +536,7 @@ class LoggingNameTest(fixtures.TestBase):
]:
log.addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
for log in [
logging.getLogger("sqlalchemy.engine"),
logging.getLogger("sqlalchemy.pool"),
@@ -588,13 +588,13 @@ class LoggingNameTest(fixtures.TestBase):
class EchoTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
- def setup(self):
+ def setup_test(self):
self.level = logging.getLogger("sqlalchemy.engine").level
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARN)
self.buf = logging.handlers.BufferingHandler(100)
logging.getLogger("sqlalchemy.engine").addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
logging.getLogger("sqlalchemy.engine").removeHandler(self.buf)
logging.getLogger("sqlalchemy.engine").setLevel(self.level)
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py
index 550fedb8e..decdce3f9 100644
--- a/test/engine/test_pool.py
+++ b/test/engine/test_pool.py
@@ -17,7 +17,9 @@ from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_none
from sqlalchemy.testing import is_not
+from sqlalchemy.testing import is_not_none
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing.engines import testing_engine
@@ -63,18 +65,18 @@ def MockDBAPI(): # noqa
class PoolTestBase(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
pool.clear_managers()
self._teardown_conns = []
- def teardown(self):
+ def teardown_test(self):
for ref in self._teardown_conns:
conn = ref()
if conn:
conn.close()
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
pool.clear_managers()
def _with_teardown(self, connection):
@@ -364,10 +366,17 @@ class PoolEventsTest(PoolTestBase):
p = self._queuepool_fixture()
canary = []
+ @event.listens_for(p, "checkin")
def checkin(*arg, **kw):
canary.append("checkin")
- event.listen(p, "checkin", checkin)
+ @event.listens_for(p, "close_detached")
+ def close_detached(*arg, **kw):
+ canary.append("close_detached")
+
+ @event.listens_for(p, "detach")
+ def detach(*arg, **kw):
+ canary.append("detach")
return p, canary
@@ -629,15 +638,35 @@ class PoolEventsTest(PoolTestBase):
assert canary.call_args_list[0][0][0] is dbapi_con
assert canary.call_args_list[0][0][2] is exc
+ @testing.combinations((True, testing.requires.python3), (False,))
@testing.requires.predictable_gc
- def test_checkin_event_gc(self):
+ def test_checkin_event_gc(self, detach_gced):
p, canary = self._checkin_event_fixture()
+ if detach_gced:
+ p._is_asyncio = True
+
c1 = p.connect()
+
+ dbapi_connection = weakref.ref(c1.connection)
+
eq_(canary, [])
del c1
lazy_gc()
- eq_(canary, ["checkin"])
+
+ if detach_gced:
+ # "close_detached" is not called because for asyncio the
+ # connection is just lost.
+ eq_(canary, ["detach"])
+
+ else:
+ eq_(canary, ["checkin"])
+
+ gc_collect()
+ if detach_gced:
+ is_none(dbapi_connection())
+ else:
+ is_not_none(dbapi_connection())
def test_checkin_event_on_subsequently_recreated(self):
p, canary = self._checkin_event_fixture()
@@ -744,7 +773,7 @@ class PoolEventsTest(PoolTestBase):
eq_(conn.info["important_flag"], True)
conn.close()
- def teardown(self):
+ def teardown_test(self):
# TODO: need to get remove() functionality
# going
pool.Pool.dispatch._clear()
@@ -1490,12 +1519,16 @@ class QueuePoolTest(PoolTestBase):
self._assert_cleanup_on_pooled_reconnect(dbapi, p)
+ @testing.combinations((True, testing.requires.python3), (False,))
@testing.requires.predictable_gc
- def test_userspace_disconnectionerror_weakref_finalizer(self):
+ def test_userspace_disconnectionerror_weakref_finalizer(self, detach_gced):
dbapi, pool = self._queuepool_dbapi_fixture(
pool_size=1, max_overflow=2
)
+ if detach_gced:
+ pool._is_asyncio = True
+
@event.listens_for(pool, "checkout")
def handle_checkout_event(dbapi_con, con_record, con_proxy):
if getattr(dbapi_con, "boom") == "yes":
@@ -1514,8 +1547,12 @@ class QueuePoolTest(PoolTestBase):
del conn
gc_collect()
- # new connection was reset on return appropriately
- eq_(dbapi_conn.mock_calls, [call.rollback()])
+ if detach_gced:
+ # new connection was detached + abandoned on return
+ eq_(dbapi_conn.mock_calls, [])
+ else:
+ # new connection reset and returned to pool
+ eq_(dbapi_conn.mock_calls, [call.rollback()])
# old connection was just closed - did not get an
# erroneous reset on return
diff --git a/test/engine/test_processors.py b/test/engine/test_processors.py
index 3810de06a..5a4220c82 100644
--- a/test/engine/test_processors.py
+++ b/test/engine/test_processors.py
@@ -25,7 +25,7 @@ class CBooleanProcessorTest(_BooleanProcessorTest):
__requires__ = ("cextensions",)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import cprocessors
cls.module = cprocessors
@@ -83,7 +83,7 @@ class _DateProcessorTest(fixtures.TestBase):
class PyDateProcessorTest(_DateProcessorTest):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import processors
cls.module = type(
@@ -100,7 +100,7 @@ class CDateProcessorTest(_DateProcessorTest):
__requires__ = ("cextensions",)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import cprocessors
cls.module = cprocessors
@@ -185,7 +185,7 @@ class _DistillArgsTest(fixtures.TestBase):
class PyDistillArgsTest(_DistillArgsTest):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy.engine import util
cls.module = type(
@@ -202,7 +202,7 @@ class CDistillArgsTest(_DistillArgsTest):
__requires__ = ("cextensions",)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import cutils as util
cls.module = util
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py
index 5fe7f6cc2..7a64b2550 100644
--- a/test/engine/test_reconnect.py
+++ b/test/engine/test_reconnect.py
@@ -162,7 +162,7 @@ def MockDBAPI():
class PrePingMockTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.dbapi = MockDBAPI()
def _pool_fixture(self, pre_ping, pool_kw=None):
@@ -182,7 +182,7 @@ class PrePingMockTest(fixtures.TestBase):
)
return _pool
- def teardown(self):
+ def teardown_test(self):
self.dbapi.dispose()
def test_ping_not_on_first_connect(self):
@@ -357,7 +357,7 @@ class PrePingMockTest(fixtures.TestBase):
class MockReconnectTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.dbapi = MockDBAPI()
self.db = testing_engine(
@@ -373,7 +373,7 @@ class MockReconnectTest(fixtures.TestBase):
e, MockDisconnect
)
- def teardown(self):
+ def teardown_test(self):
self.dbapi.dispose()
def test_reconnect(self):
@@ -1004,10 +1004,10 @@ class RealReconnectTest(fixtures.TestBase):
__backend__ = True
__requires__ = "graceful_disconnects", "ad_hoc_engines"
- def setup(self):
+ def setup_test(self):
self.engine = engines.reconnecting_engine()
- def teardown(self):
+ def teardown_test(self):
self.engine.dispose()
def test_reconnect(self):
@@ -1336,7 +1336,7 @@ class PrePingRealTest(fixtures.TestBase):
class InvalidateDuringResultTest(fixtures.TestBase):
__backend__ = True
- def setup(self):
+ def setup_test(self):
self.engine = engines.reconnecting_engine()
self.meta = MetaData()
table = Table(
@@ -1353,7 +1353,7 @@ class InvalidateDuringResultTest(fixtures.TestBase):
[{"id": i, "name": "row %d" % i} for i in range(1, 100)],
)
- def teardown(self):
+ def teardown_test(self):
with self.engine.begin() as conn:
self.meta.drop_all(conn)
self.engine.dispose()
@@ -1470,7 +1470,7 @@ class ReconnectRecipeTest(fixtures.TestBase):
__backend__ = True
- def setup(self):
+ def setup_test(self):
self.engine = engines.reconnecting_engine(
options=dict(future=self.future)
)
@@ -1483,7 +1483,7 @@ class ReconnectRecipeTest(fixtures.TestBase):
)
self.meta.create_all(self.engine)
- def teardown(self):
+ def teardown_test(self):
self.meta.drop_all(self.engine)
self.engine.dispose()
diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py
index 658cdd79f..0a46ddeec 100644
--- a/test/engine/test_reflection.py
+++ b/test/engine/test_reflection.py
@@ -796,7 +796,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables):
assert f1 in b1.constraints
assert len(b1.constraints) == 2
- def test_override_keys(self, connection, metadata):
+ def test_override_keys(self, metadata, connection):
"""test that columns can be overridden with a 'key',
and that ForeignKey targeting during reflection still works."""
@@ -1375,7 +1375,7 @@ class CreateDropTest(fixtures.TablesTest):
run_create_tables = None
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
# TablesTest is used here without
# run_create_tables, so add an explicit drop of whatever is in
# metadata
@@ -1658,7 +1658,6 @@ class SchemaTest(fixtures.TestBase):
@testing.requires.schemas
@testing.requires.cross_schema_fk_reflection
@testing.requires.implicit_default_schema
- @testing.provide_metadata
def test_blank_schema_arg(self, connection, metadata):
Table(
@@ -1913,7 +1912,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL):
__backend__ = True
@testing.requires.denormalized_names
- def setup(self):
+ def setup_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql(
"""
@@ -1926,7 +1925,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL):
)
@testing.requires.denormalized_names
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop table weird_casing")
diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py
index 79126fc5b..47504b60a 100644
--- a/test/engine/test_transaction.py
+++ b/test/engine/test_transaction.py
@@ -1,6 +1,5 @@
import sys
-from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import func
@@ -640,12 +639,12 @@ class AutoRollbackTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global metadata
metadata = MetaData()
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
metadata.drop_all(testing.db)
def test_rollback_deadlock(self):
@@ -871,11 +870,13 @@ class IsolationLevelTest(fixtures.TestBase):
def test_per_engine(self):
# new in 0.9
- eng = create_engine(
+ eng = testing_engine(
testing.db.url,
- execution_options={
- "isolation_level": self._non_default_isolation_level()
- },
+ options=dict(
+ execution_options={
+ "isolation_level": self._non_default_isolation_level()
+ }
+ ),
)
conn = eng.connect()
eq_(
@@ -884,7 +885,7 @@ class IsolationLevelTest(fixtures.TestBase):
)
def test_per_option_engine(self):
- eng = create_engine(testing.db.url).execution_options(
+ eng = testing_engine(testing.db.url).execution_options(
isolation_level=self._non_default_isolation_level()
)
@@ -895,14 +896,14 @@ class IsolationLevelTest(fixtures.TestBase):
)
def test_isolation_level_accessors_connection_default(self):
- eng = create_engine(testing.db.url)
+ eng = testing_engine(testing.db.url)
with eng.connect() as conn:
eq_(conn.default_isolation_level, self._default_isolation_level())
with eng.connect() as conn:
eq_(conn.get_isolation_level(), self._default_isolation_level())
def test_isolation_level_accessors_connection_option_modified(self):
- eng = create_engine(testing.db.url)
+ eng = testing_engine(testing.db.url)
with eng.connect() as conn:
c2 = conn.execution_options(
isolation_level=self._non_default_isolation_level()
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py
index 7dae1411e..59a44f8e2 100644
--- a/test/ext/asyncio/test_engine_py3k.py
+++ b/test/ext/asyncio/test_engine_py3k.py
@@ -269,7 +269,7 @@ class AsyncEngineTest(EngineFixture):
await trans.rollback(),
@async_test
- async def test_pool_exhausted(self, async_engine):
+ async def test_pool_exhausted_some_timeout(self, async_engine):
engine = create_async_engine(
testing.db.url,
pool_size=1,
@@ -277,7 +277,19 @@ class AsyncEngineTest(EngineFixture):
pool_timeout=0.1,
)
async with engine.connect():
- with expect_raises(asyncio.TimeoutError):
+ with expect_raises(exc.TimeoutError):
+ await engine.connect()
+
+ @async_test
+ async def test_pool_exhausted_no_timeout(self, async_engine):
+ engine = create_async_engine(
+ testing.db.url,
+ pool_size=1,
+ max_overflow=0,
+ pool_timeout=0,
+ )
+ async with engine.connect():
+ with expect_raises(exc.TimeoutError):
await engine.connect()
@async_test
diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py
index 2b80b753e..e25e7cfc2 100644
--- a/test/ext/declarative/test_inheritance.py
+++ b/test/ext/declarative/test_inheritance.py
@@ -27,11 +27,11 @@ Base = None
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
- def setup(self):
+ def setup_test(self):
global Base
Base = decl.declarative_base(testing.db)
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
Base.metadata.drop_all(testing.db)
diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py
index d7fcbf9e8..c327de7d4 100644
--- a/test/ext/declarative/test_reflection.py
+++ b/test/ext/declarative/test_reflection.py
@@ -4,7 +4,6 @@ from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.ext.declarative import DeferredReflection
from sqlalchemy.orm import clear_mappers
-from sqlalchemy.orm import create_session
from sqlalchemy.orm import decl_api as decl
from sqlalchemy.orm import declared_attr
from sqlalchemy.orm import exc as orm_exc
@@ -14,6 +13,7 @@ from sqlalchemy.orm.decl_base import _DeferredMapperConfig
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import gc_collect
@@ -22,20 +22,19 @@ from sqlalchemy.testing.util import gc_collect
class DeclarativeReflectionBase(fixtures.TablesTest):
__requires__ = ("reflectable_autoincrement",)
- def setup(self):
+ def setup_test(self):
global Base, registry
registry = decl.registry()
Base = registry.generate_base()
- def teardown(self):
- super(DeclarativeReflectionBase, self).teardown()
+ def teardown_test(self):
clear_mappers()
class DeferredReflectBase(DeclarativeReflectionBase):
- def teardown(self):
- super(DeferredReflectBase, self).teardown()
+ def teardown_test(self):
+ super(DeferredReflectBase, self).teardown_test()
_DeferredMapperConfig._configs.clear()
@@ -101,22 +100,23 @@ class DeferredReflectionTest(DeferredReflectBase):
u1 = User(
name="u1", addresses=[Address(email="one"), Address(email="two")]
)
- sess = create_session(testing.db)
- sess.add(u1)
- sess.flush()
- sess.expunge_all()
- eq_(
- sess.query(User).all(),
- [
- User(
- name="u1",
- addresses=[Address(email="one"), Address(email="two")],
- )
- ],
- )
- a1 = sess.query(Address).filter(Address.email == "two").one()
- eq_(a1, Address(email="two"))
- eq_(a1.user, User(name="u1"))
+ with fixture_session() as sess:
+ sess.add(u1)
+ sess.commit()
+
+ with fixture_session() as sess:
+ eq_(
+ sess.query(User).all(),
+ [
+ User(
+ name="u1",
+ addresses=[Address(email="one"), Address(email="two")],
+ )
+ ],
+ )
+ a1 = sess.query(Address).filter(Address.email == "two").one()
+ eq_(a1, Address(email="two"))
+ eq_(a1.user, User(name="u1"))
def test_exception_prepare_not_called(self):
class User(DeferredReflection, fixtures.ComparableEntity, Base):
@@ -191,15 +191,25 @@ class DeferredReflectionTest(DeferredReflectBase):
return {"primary_key": cls.__table__.c.id}
DeferredReflection.prepare(testing.db)
- sess = Session(testing.db)
- sess.add_all(
- [User(name="G"), User(name="Q"), User(name="A"), User(name="C")]
- )
- sess.commit()
- eq_(
- sess.query(User).order_by(User.name).all(),
- [User(name="A"), User(name="C"), User(name="G"), User(name="Q")],
- )
+ with fixture_session() as sess:
+ sess.add_all(
+ [
+ User(name="G"),
+ User(name="Q"),
+ User(name="A"),
+ User(name="C"),
+ ]
+ )
+ sess.commit()
+ eq_(
+ sess.query(User).order_by(User.name).all(),
+ [
+ User(name="A"),
+ User(name="C"),
+ User(name="G"),
+ User(name="Q"),
+ ],
+ )
@testing.requires.predictable_gc
def test_cls_not_strong_ref(self):
@@ -255,14 +265,14 @@ class DeferredSecondaryReflectionTest(DeferredReflectBase):
u1 = User(name="u1", items=[Item(name="i1"), Item(name="i2")])
- sess = Session(testing.db)
- sess.add(u1)
- sess.commit()
+ with fixture_session() as sess:
+ sess.add(u1)
+ sess.commit()
- eq_(
- sess.query(User).all(),
- [User(name="u1", items=[Item(name="i1"), Item(name="i2")])],
- )
+ eq_(
+ sess.query(User).all(),
+ [User(name="u1", items=[Item(name="i1"), Item(name="i2")])],
+ )
def test_string_resolution(self):
class User(DeferredReflection, fixtures.ComparableEntity, Base):
@@ -296,27 +306,26 @@ class DeferredInhReflectBase(DeferredReflectBase):
Foo = Base.registry._class_registry["Foo"]
Bar = Base.registry._class_registry["Bar"]
- s = Session(testing.db)
-
- s.add_all(
- [
- Bar(data="d1", bar_data="b1"),
- Bar(data="d2", bar_data="b2"),
- Bar(data="d3", bar_data="b3"),
- Foo(data="d4"),
- ]
- )
- s.commit()
-
- eq_(
- s.query(Foo).order_by(Foo.id).all(),
- [
- Bar(data="d1", bar_data="b1"),
- Bar(data="d2", bar_data="b2"),
- Bar(data="d3", bar_data="b3"),
- Foo(data="d4"),
- ],
- )
+ with fixture_session() as s:
+ s.add_all(
+ [
+ Bar(data="d1", bar_data="b1"),
+ Bar(data="d2", bar_data="b2"),
+ Bar(data="d3", bar_data="b3"),
+ Foo(data="d4"),
+ ]
+ )
+ s.commit()
+
+ eq_(
+ s.query(Foo).order_by(Foo.id).all(),
+ [
+ Bar(data="d1", bar_data="b1"),
+ Bar(data="d2", bar_data="b2"),
+ Bar(data="d3", bar_data="b3"),
+ Foo(data="d4"),
+ ],
+ )
class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py
index b1f5cc956..31ae050c1 100644
--- a/test/ext/test_associationproxy.py
+++ b/test/ext/test_associationproxy.py
@@ -101,7 +101,7 @@ class AutoFlushTest(fixtures.TablesTest):
Column("name", String(50)),
)
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def _fixture(self, collection_class, is_dict=False):
@@ -198,7 +198,7 @@ class AutoFlushTest(fixtures.TablesTest):
class _CollectionOperations(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
collection_class = self.collection_class
metadata = MetaData()
@@ -260,7 +260,7 @@ class _CollectionOperations(fixtures.TestBase):
self.session = fixture_session()
self.Parent, self.Child = Parent, Child
- def teardown(self):
+ def teardown_test(self):
self.metadata.drop_all(testing.db)
def roundtrip(self, obj):
@@ -885,7 +885,7 @@ class CustomObjectTest(_CollectionOperations):
class ProxyFactoryTest(ListTest):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
parents_table = Table(
@@ -1157,7 +1157,7 @@ class ScalarTest(fixtures.TestBase):
class LazyLoadTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
parents_table = Table(
@@ -1197,7 +1197,7 @@ class LazyLoadTest(fixtures.TestBase):
self.Parent, self.Child = Parent, Child
self.table = parents_table
- def teardown(self):
+ def teardown_test(self):
self.metadata.drop_all(testing.db)
def roundtrip(self, obj):
@@ -2294,7 +2294,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
class DictOfTupleUpdateTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
class B(object):
def __init__(self, key, elem):
self.key = key
@@ -2434,7 +2434,7 @@ class CompositeAccessTest(fixtures.DeclarativeMappedTest):
class AttributeAccessTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def test_resolve_aliased_class(self):
diff --git a/test/ext/test_baked.py b/test/ext/test_baked.py
index 71fabc629..2d4e9848e 100644
--- a/test/ext/test_baked.py
+++ b/test/ext/test_baked.py
@@ -27,7 +27,7 @@ class BakedTest(_fixtures.FixtureTest):
run_inserts = "once"
run_deletes = None
- def setup(self):
+ def setup_test(self):
self.bakery = baked.bakery()
diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py
index 058c1dfd7..d011417d7 100644
--- a/test/ext/test_compiler.py
+++ b/test/ext/test_compiler.py
@@ -426,7 +426,7 @@ class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
- def teardown(self):
+ def teardown_test(self):
for cls in (Select, BindParameter):
deregister(cls)
diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py
index ad9bf0bc0..f3eceb0dc 100644
--- a/test/ext/test_extendedattr.py
+++ b/test/ext/test_extendedattr.py
@@ -32,7 +32,7 @@ def modifies_instrumentation_finders(fn, *args, **kw):
class _ExtBase(object):
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
instrumentation._reinstall_default_lookups()
@@ -89,7 +89,7 @@ MyBaseClass, MyClass = None, None
class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global MyBaseClass, MyClass
class MyBaseClass(object):
@@ -143,7 +143,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
else:
del self._goofy_dict[key]
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def test_instance_dict(self):
diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py
index 038bdd83e..bb06d9648 100644
--- a/test/ext/test_horizontal_shard.py
+++ b/test/ext/test_horizontal_shard.py
@@ -19,7 +19,6 @@ from sqlalchemy import update
from sqlalchemy import util
from sqlalchemy.ext.horizontal_shard import ShardedSession
from sqlalchemy.orm import clear_mappers
-from sqlalchemy.orm import create_session
from sqlalchemy.orm import deferred
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
@@ -42,7 +41,7 @@ class ShardTest(object):
schema = None
- def setUp(self):
+ def setup_test(self):
global db1, db2, db3, db4, weather_locations, weather_reports
db1, db2, db3, db4 = self._dbs = self._init_dbs()
@@ -88,7 +87,7 @@ class ShardTest(object):
@classmethod
def setup_session(cls):
- global create_session
+ global sharded_session
shard_lookup = {
"North America": "north_america",
"Asia": "asia",
@@ -128,10 +127,10 @@ class ShardTest(object):
else:
return ids
- create_session = sessionmaker(
+ sharded_session = sessionmaker(
class_=ShardedSession, autoflush=True, autocommit=False
)
- create_session.configure(
+ sharded_session.configure(
shards={
"north_america": db1,
"asia": db2,
@@ -180,7 +179,7 @@ class ShardTest(object):
tokyo.reports.append(Report(80.0, id_=1))
newyork.reports.append(Report(75, id_=1))
quito.reports.append(Report(85))
- sess = create_session(future=True)
+ sess = sharded_session(future=True)
for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
sess.add(c)
sess.flush()
@@ -671,11 +670,10 @@ class DistinctEngineShardTest(ShardTest, fixtures.TestBase):
self.dbs = [db1, db2, db3, db4]
return self.dbs
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
- for db in self.dbs:
- db.connect().invalidate()
+ testing_reaper.checkin_all()
for i in range(1, 5):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
@@ -702,10 +700,10 @@ class AttachedFileShardTest(ShardTest, fixtures.TestBase):
self.engine = e
return db1, db2, db3, db4
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
- self.engine.connect().invalidate()
+ testing_reaper.checkin_all()
for i in range(1, 5):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
@@ -778,10 +776,13 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase):
self.postgresql_engine = e2
return db1, db2, db3, db4
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
- self.sqlite_engine.connect().invalidate()
+ # the tests in this suite don't cleanly close out the Session
+ # at the moment so use the reaper to close all connections
+ testing_reaper.checkin_all()
+
for i in [1, 3]:
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
@@ -789,6 +790,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase):
self.tables_test_metadata.drop_all(conn)
for i in [2, 4]:
conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,))
+ self.postgresql_engine.dispose()
class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest):
@@ -904,11 +906,11 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
return self.dbs
- def teardown(self):
+ def teardown_test(self):
for db in self.dbs:
db.connect().invalidate()
- testing_reaper.close_all()
+ testing_reaper.checkin_all()
for i in range(1, 3):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py
index 048a8b52d..3bab7db93 100644
--- a/test/ext/test_hybrid.py
+++ b/test/ext/test_hybrid.py
@@ -697,7 +697,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import literal
symbols = ("usd", "gbp", "cad", "eur", "aud")
diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py
index eba2ac0cb..21244de73 100644
--- a/test/ext/test_mutable.py
+++ b/test/ext/test_mutable.py
@@ -90,11 +90,10 @@ class _MutableDictTestFixture(object):
def _type_fixture(cls):
return MutableDict
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_MutableDictTestFixture, self).teardown()
class _MutableDictTestBase(_MutableDictTestFixture):
@@ -312,11 +311,10 @@ class _MutableListTestFixture(object):
def _type_fixture(cls):
return MutableList
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_MutableListTestFixture, self).teardown()
class _MutableListTestBase(_MutableListTestFixture):
@@ -619,11 +617,10 @@ class _MutableSetTestFixture(object):
def _type_fixture(cls):
return MutableSet
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_MutableSetTestFixture, self).teardown()
class _MutableSetTestBase(_MutableSetTestFixture):
@@ -1234,17 +1231,15 @@ class _CompositeTestBase(object):
Column("unrelated_data", String(50)),
)
- def setup(self):
+ def setup_test(self):
from sqlalchemy.ext import mutable
mutable._setup_composite_listener()
- super(_CompositeTestBase, self).setup()
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_CompositeTestBase, self).teardown()
@classmethod
def _type_fixture(cls):
diff --git a/test/ext/test_orderinglist.py b/test/ext/test_orderinglist.py
index f23d6cb57..280fad6cf 100644
--- a/test/ext/test_orderinglist.py
+++ b/test/ext/test_orderinglist.py
@@ -8,7 +8,7 @@ from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
-from sqlalchemy.testing.fixtures import create_session
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import picklers
@@ -60,7 +60,7 @@ def alpha_ordering(index, collection):
class OrderingListTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
global metadata, slides_table, bullets_table, Slide, Bullet
slides_table, bullets_table = None, None
Slide, Bullet = None, None
@@ -122,7 +122,7 @@ class OrderingListTest(fixtures.TestBase):
metadata.create_all(testing.db)
- def teardown(self):
+ def teardown_test(self):
metadata.drop_all(testing.db)
def test_append_no_reorder(self):
@@ -167,7 +167,7 @@ class OrderingListTest(fixtures.TestBase):
self.assert_(s1.bullets[2].position == 3)
self.assert_(s1.bullets[3].position == 4)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
@@ -232,7 +232,7 @@ class OrderingListTest(fixtures.TestBase):
s1.bullets._reorder()
self.assert_(s1.bullets[4].position == 5)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
@@ -289,7 +289,7 @@ class OrderingListTest(fixtures.TestBase):
self.assert_(len(s1.bullets) == 6)
self.assert_(s1.bullets[5].position == 5)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
@@ -338,7 +338,7 @@ class OrderingListTest(fixtures.TestBase):
self.assert_(s1.bullets[li].position == li)
self.assert_(s1.bullets[li] == b[bi])
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
@@ -365,7 +365,7 @@ class OrderingListTest(fixtures.TestBase):
self.assert_(len(s1.bullets) == 3)
self.assert_(s1.bullets[2].position == 2)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py
index 4c005d336..4d9162105 100644
--- a/test/orm/declarative/test_basic.py
+++ b/test/orm/declarative/test_basic.py
@@ -61,11 +61,11 @@ class DeclarativeTestBase(
):
__dialect__ = "default"
- def setup(self):
+ def setup_test(self):
global Base
Base = declarative_base(testing.db)
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
Base.metadata.drop_all(testing.db)
diff --git a/test/orm/declarative/test_concurrency.py b/test/orm/declarative/test_concurrency.py
index 5f12d8272..ecddc2e5f 100644
--- a/test/orm/declarative/test_concurrency.py
+++ b/test/orm/declarative/test_concurrency.py
@@ -17,7 +17,7 @@ from sqlalchemy.testing.fixtures import fixture_session
class ConcurrentUseDeclMappingTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
@classmethod
diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py
index cc29cab7d..e09b1570e 100644
--- a/test/orm/declarative/test_inheritance.py
+++ b/test/orm/declarative/test_inheritance.py
@@ -27,11 +27,11 @@ Base = None
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
- def setup(self):
+ def setup_test(self):
global Base
Base = decl.declarative_base(testing.db)
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
Base.metadata.drop_all(testing.db)
diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py
index 631527daf..ad4832c35 100644
--- a/test/orm/declarative/test_mixin.py
+++ b/test/orm/declarative/test_mixin.py
@@ -38,13 +38,13 @@ mapper_registry = None
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
- def setup(self):
+ def setup_test(self):
global Base, mapper_registry
mapper_registry = registry(metadata=MetaData())
Base = mapper_registry.generate_base()
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
with testing.db.begin() as conn:
diff --git a/test/orm/declarative/test_reflection.py b/test/orm/declarative/test_reflection.py
index 241528c44..e7b2a7058 100644
--- a/test/orm/declarative/test_reflection.py
+++ b/test/orm/declarative/test_reflection.py
@@ -17,14 +17,13 @@ from sqlalchemy.testing.schema import Table
class DeclarativeReflectionBase(fixtures.TablesTest):
__requires__ = ("reflectable_autoincrement",)
- def setup(self):
+ def setup_test(self):
global Base, registry
registry = decl.registry(metadata=MetaData())
Base = registry.generate_base()
- def teardown(self):
- super(DeclarativeReflectionBase, self).teardown()
+ def teardown_test(self):
clear_mappers()
diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py
index bdcdedc44..da07b4941 100644
--- a/test/orm/inheritance/test_basic.py
+++ b/test/orm/inheritance/test_basic.py
@@ -31,7 +31,6 @@ from sqlalchemy.orm import synonym
from sqlalchemy.orm.util import instance_str
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
-from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
@@ -1889,7 +1888,6 @@ class VersioningTest(fixtures.MappedTest):
@testing.emits_warning(r".*updated rowcount")
@testing.requires.sane_rowcount_w_returning
- @engines.close_open_connections
def test_save_update(self):
subtable, base, stuff = (
self.tables.subtable,
@@ -2927,7 +2925,7 @@ class NoPKOnSubTableWarningTest(fixtures.TestBase):
)
return parent, child
- def tearDown(self):
+ def teardown_test(self):
clear_mappers()
def test_warning_on_sub(self):
@@ -3417,27 +3415,26 @@ class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest):
@classmethod
def insert_data(cls, connection):
Parent, A, B = cls.classes("Parent", "A", "B")
- s = fixture_session()
-
- p1 = Parent(id=1)
- p2 = Parent(id=2)
- s.add_all([p1, p2])
- s.flush()
+ with Session(connection) as s:
+ p1 = Parent(id=1)
+ p2 = Parent(id=2)
+ s.add_all([p1, p2])
+ s.flush()
- s.add_all(
- [
- A(id=1, parent_id=1),
- B(id=2, parent_id=1),
- A(id=3, parent_id=1),
- B(id=4, parent_id=1),
- ]
- )
- s.flush()
+ s.add_all(
+ [
+ A(id=1, parent_id=1),
+ B(id=2, parent_id=1),
+ A(id=3, parent_id=1),
+ B(id=4, parent_id=1),
+ ]
+ )
+ s.flush()
- s.query(A).filter(A.id.in_([3, 4])).update(
- {A.type: None}, synchronize_session=False
- )
- s.commit()
+ s.query(A).filter(A.id.in_([3, 4])).update(
+ {A.type: None}, synchronize_session=False
+ )
+ s.commit()
def test_pk_is_null(self):
Parent, A = self.classes("Parent", "A")
@@ -3527,10 +3524,12 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest):
ASingleSubA, ASingleSubB, AJoinedSubA, AJoinedSubB = cls.classes(
"ASingleSubA", "ASingleSubB", "AJoinedSubA", "AJoinedSubB"
)
- s = fixture_session()
+ with Session(connection) as s:
- s.add_all([ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()])
- s.commit()
+ s.add_all(
+ [ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()]
+ )
+ s.commit()
def test_single_invalid_ident(self):
ASingle, ASingleSubA = self.classes("ASingle", "ASingleSubA")
diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py
index 8820aa6a4..0a0a5d12b 100644
--- a/test/orm/test_attributes.py
+++ b/test/orm/test_attributes.py
@@ -209,7 +209,7 @@ class AttributeImplAPITest(fixtures.MappedTest):
class AttributesTest(fixtures.ORMTest):
- def setup(self):
+ def setup_test(self):
global MyTest, MyTest2
class MyTest(object):
@@ -218,7 +218,7 @@ class AttributesTest(fixtures.ORMTest):
class MyTest2(object):
pass
- def teardown(self):
+ def teardown_test(self):
global MyTest, MyTest2
MyTest, MyTest2 = None, None
@@ -3690,7 +3690,7 @@ class EventPropagateTest(fixtures.TestBase):
class CollectionInitTest(fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class A(object):
pass
@@ -3749,7 +3749,7 @@ class CollectionInitTest(fixtures.TestBase):
class TestUnlink(fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class A(object):
pass
diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py
index 2f54f7fff..014fa152e 100644
--- a/test/orm/test_bind.py
+++ b/test/orm/test_bind.py
@@ -428,39 +428,46 @@ class BindIntegrationTest(_fixtures.FixtureTest):
User, users = self.classes.User, self.tables.users
mapper(User, users)
- c = testing.db.connect()
-
- sess = Session(bind=c, autocommit=False)
- u = User(name="u1")
- sess.add(u)
- sess.flush()
- sess.close()
- assert not c.in_transaction()
- assert c.exec_driver_sql("select count(1) from users").scalar() == 0
-
- sess = Session(bind=c, autocommit=False)
- u = User(name="u2")
- sess.add(u)
- sess.flush()
- sess.commit()
- assert not c.in_transaction()
- assert c.exec_driver_sql("select count(1) from users").scalar() == 1
-
- with c.begin():
- c.exec_driver_sql("delete from users")
- assert c.exec_driver_sql("select count(1) from users").scalar() == 0
-
- c = testing.db.connect()
-
- trans = c.begin()
- sess = Session(bind=c, autocommit=True)
- u = User(name="u3")
- sess.add(u)
- sess.flush()
- assert c.in_transaction()
- trans.commit()
- assert not c.in_transaction()
- assert c.exec_driver_sql("select count(1) from users").scalar() == 1
+ with testing.db.connect() as c:
+
+ sess = Session(bind=c, autocommit=False)
+ u = User(name="u1")
+ sess.add(u)
+ sess.flush()
+ sess.close()
+ assert not c.in_transaction()
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 0
+ )
+
+ sess = Session(bind=c, autocommit=False)
+ u = User(name="u2")
+ sess.add(u)
+ sess.flush()
+ sess.commit()
+ assert not c.in_transaction()
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 1
+ )
+
+ with c.begin():
+ c.exec_driver_sql("delete from users")
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 0
+ )
+
+ with testing.db.connect() as c:
+ trans = c.begin()
+ sess = Session(bind=c, autocommit=True)
+ u = User(name="u3")
+ sess.add(u)
+ sess.flush()
+ assert c.in_transaction()
+ trans.commit()
+ assert not c.in_transaction()
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 1
+ )
class SessionBindTest(fixtures.MappedTest):
@@ -506,6 +513,7 @@ class SessionBindTest(fixtures.MappedTest):
finally:
if hasattr(bind, "close"):
bind.close()
+ sess.close()
def test_session_unbound(self):
Foo = self.classes.Foo
diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py
index 3d09bd446..2a0aafbbc 100644
--- a/test/orm/test_collection.py
+++ b/test/orm/test_collection.py
@@ -92,13 +92,12 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
return str((id(self), self.a, self.b, self.c))
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
instrumentation.register_class(cls.Entity)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
instrumentation.unregister_class(cls.Entity)
- super(CollectionsTest, cls).teardown_class()
_entity_id = 1
diff --git a/test/orm/test_compile.py b/test/orm/test_compile.py
index df652daf4..20d8ecc2d 100644
--- a/test/orm/test_compile.py
+++ b/test/orm/test_compile.py
@@ -19,7 +19,7 @@ from sqlalchemy.testing import fixtures
class CompileTest(fixtures.ORMTest):
"""test various mapper compilation scenarios"""
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def test_with_polymorphic(self):
diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py
index e1ef67fed..ed11b89c9 100644
--- a/test/orm/test_cycles.py
+++ b/test/orm/test_cycles.py
@@ -1743,8 +1743,7 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest):
id = Column(Integer, primary_key=True)
a_id = Column(ForeignKey("a.id", name="a_fk"))
- def setup(self):
- super(PostUpdateOnUpdateTest, self).setup()
+ def setup_test(self):
PostUpdateOnUpdateTest.counter = count()
PostUpdateOnUpdateTest.db_counter = count()
diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py
index 6d946cfe6..15063ebe9 100644
--- a/test/orm/test_deprecations.py
+++ b/test/orm/test_deprecations.py
@@ -2199,6 +2199,7 @@ class SessionTest(fixtures.RemovesEvents, _LocalFixture):
class AutocommitClosesOnFailTest(fixtures.MappedTest):
__requires__ = ("deferrable_fks",)
+ __only_on__ = ("postgresql+psycopg2",) # needs #5824 for asyncpg
@classmethod
def define_tables(cls, metadata):
@@ -4498,44 +4499,49 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
warnings += (join_aliased_dep,)
# load a user who has an order that contains item id 3 and address
# id 1 (order 3, owned by jack)
- with testing.expect_deprecated_20(*warnings):
- result = (
- fixture_session()
- .query(User)
- .join("orders", "items", aliased=aliased_)
- .filter_by(id=3)
- .reset_joinpoint()
- .join("orders", "address", aliased=aliased_)
- .filter_by(id=1)
- .all()
- )
- assert [User(id=7, name="jack")] == result
- with testing.expect_deprecated_20(*warnings):
- result = (
- fixture_session()
- .query(User)
- .join("orders", "items", aliased=aliased_, isouter=True)
- .filter_by(id=3)
- .reset_joinpoint()
- .join("orders", "address", aliased=aliased_, isouter=True)
- .filter_by(id=1)
- .all()
- )
- assert [User(id=7, name="jack")] == result
-
- with testing.expect_deprecated_20(*warnings):
- result = (
- fixture_session()
- .query(User)
- .outerjoin("orders", "items", aliased=aliased_)
- .filter_by(id=3)
- .reset_joinpoint()
- .outerjoin("orders", "address", aliased=aliased_)
- .filter_by(id=1)
- .all()
- )
- assert [User(id=7, name="jack")] == result
+ with fixture_session() as sess:
+ with testing.expect_deprecated_20(*warnings):
+ result = (
+ sess.query(User)
+ .join("orders", "items", aliased=aliased_)
+ .filter_by(id=3)
+ .reset_joinpoint()
+ .join("orders", "address", aliased=aliased_)
+ .filter_by(id=1)
+ .all()
+ )
+ assert [User(id=7, name="jack")] == result
+
+ with fixture_session() as sess:
+ with testing.expect_deprecated_20(*warnings):
+ result = (
+ sess.query(User)
+ .join(
+ "orders", "items", aliased=aliased_, isouter=True
+ )
+ .filter_by(id=3)
+ .reset_joinpoint()
+ .join(
+ "orders", "address", aliased=aliased_, isouter=True
+ )
+ .filter_by(id=1)
+ .all()
+ )
+ assert [User(id=7, name="jack")] == result
+
+ with fixture_session() as sess:
+ with testing.expect_deprecated_20(*warnings):
+ result = (
+ sess.query(User)
+ .outerjoin("orders", "items", aliased=aliased_)
+ .filter_by(id=3)
+ .reset_joinpoint()
+ .outerjoin("orders", "address", aliased=aliased_)
+ .filter_by(id=1)
+ .all()
+ )
+ assert [User(id=7, name="jack")] == result
class AliasFromCorrectLeftTest(
diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py
index 4498fc1ff..7eedb37c9 100644
--- a/test/orm/test_eager_relations.py
+++ b/test/orm/test_eager_relations.py
@@ -559,15 +559,15 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
5,
),
]:
- sess = fixture_session()
+ with fixture_session() as sess:
- def go():
- eq_(
- sess.query(User).options(*opt).order_by(User.id).all(),
- self.static.user_item_keyword_result,
- )
+ def go():
+ eq_(
+ sess.query(User).options(*opt).order_by(User.id).all(),
+ self.static.user_item_keyword_result,
+ )
- self.assert_sql_count(testing.db, go, count)
+ self.assert_sql_count(testing.db, go, count)
def test_disable_dynamic(self):
"""test no joined option on a dynamic."""
diff --git a/test/orm/test_events.py b/test/orm/test_events.py
index 1c918a88c..e85c23d6f 100644
--- a/test/orm/test_events.py
+++ b/test/orm/test_events.py
@@ -45,13 +45,14 @@ from test.orm import _fixtures
class _RemoveListeners(object):
- def teardown(self):
+ @testing.fixture(autouse=True)
+ def _remove_listeners(self):
+ yield
events.MapperEvents._clear()
events.InstanceEvents._clear()
events.SessionEvents._clear()
events.InstrumentationEvents._clear()
events.QueryEvents._clear()
- super(_RemoveListeners, self).teardown()
class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
@@ -1174,7 +1175,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest):
argnames="target, event_name, fn",
)(fn)
- def teardown(self):
+ def teardown_test(self):
A = self.classes.A
A._sa_class_manager.dispatch._clear()
diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py
index cc9596466..f622bff02 100644
--- a/test/orm/test_froms.py
+++ b/test/orm/test_froms.py
@@ -2395,16 +2395,19 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
]
adalias = addresses.alias()
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name))
- .outerjoin(adalias, "addresses")
- .group_by(users)
- .order_by(users.c.id)
- )
- assert q.all() == expected
+ with fixture_session() as sess:
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(adalias.c.id), ("Name:" + users.c.name)
+ )
+ .outerjoin(adalias, "addresses")
+ .group_by(users)
+ .order_by(users.c.id)
+ )
+
+ eq_(q.all(), expected)
# test with a straight statement
s = (
@@ -2417,52 +2420,57 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
.group_by(*[c for c in users.c])
.order_by(users.c.id)
)
- q = fixture_session().query(User)
- result = (
- q.add_columns(s.selected_columns.count, s.selected_columns.concat)
- .from_statement(s)
- .all()
- )
- assert result == expected
-
- sess.expunge_all()
- # test with select_entity_from()
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name))
- .select_entity_from(users.outerjoin(addresses))
- .group_by(users)
- .order_by(users.c.id)
- )
+ with fixture_session() as sess:
+ q = sess.query(User)
+ result = (
+ q.add_columns(
+ s.selected_columns.count, s.selected_columns.concat
+ )
+ .from_statement(s)
+ .all()
+ )
+ eq_(result, expected)
- assert q.all() == expected
- sess.expunge_all()
+ with fixture_session() as sess:
+ # test with select_entity_from()
+ q = (
+ fixture_session()
+ .query(User)
+ .add_columns(
+ func.count(addresses.c.id), ("Name:" + users.c.name)
+ )
+ .select_entity_from(users.outerjoin(addresses))
+ .group_by(users)
+ .order_by(users.c.id)
+ )
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name))
- .outerjoin("addresses")
- .group_by(users)
- .order_by(users.c.id)
- )
+ eq_(q.all(), expected)
- assert q.all() == expected
- sess.expunge_all()
+ with fixture_session() as sess:
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(addresses.c.id), ("Name:" + users.c.name)
+ )
+ .outerjoin("addresses")
+ .group_by(users)
+ .order_by(users.c.id)
+ )
+ eq_(q.all(), expected)
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name))
- .outerjoin(adalias, "addresses")
- .group_by(users)
- .order_by(users.c.id)
- )
+ with fixture_session() as sess:
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(adalias.c.id), ("Name:" + users.c.name)
+ )
+ .outerjoin(adalias, "addresses")
+ .group_by(users)
+ .order_by(users.c.id)
+ )
- assert q.all() == expected
- sess.expunge_all()
+ eq_(q.all(), expected)
def test_expression_selectable_matches_mzero(self):
User, Address = self.classes.User, self.classes.Address
diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py
index 3061de309..43cf81e6d 100644
--- a/test/orm/test_lazy_relations.py
+++ b/test/orm/test_lazy_relations.py
@@ -717,24 +717,24 @@ class LazyTest(_fixtures.FixtureTest):
),
)
- sess = fixture_session()
+ with fixture_session() as sess:
- # load address
- a1 = (
- sess.query(Address)
- .filter_by(email_address="ed@wood.com")
- .one()
- )
+ # load address
+ a1 = (
+ sess.query(Address)
+ .filter_by(email_address="ed@wood.com")
+ .one()
+ )
- # load user that is attached to the address
- u1 = sess.query(User).get(8)
+ # load user that is attached to the address
+ u1 = sess.query(User).get(8)
- def go():
- # lazy load of a1.user should get it from the session
- assert a1.user is u1
+ def go():
+ # lazy load of a1.user should get it from the session
+ assert a1.user is u1
- self.assert_sql_count(testing.db, go, 0)
- sa.orm.clear_mappers()
+ self.assert_sql_count(testing.db, go, 0)
+ sa.orm.clear_mappers()
def test_uses_get_compatible_types(self):
"""test the use_get optimization with compatible
@@ -789,24 +789,23 @@ class LazyTest(_fixtures.FixtureTest):
properties=dict(user=relationship(mapper(User, users))),
)
- sess = fixture_session()
-
- # load address
- a1 = (
- sess.query(Address)
- .filter_by(email_address="ed@wood.com")
- .one()
- )
+ with fixture_session() as sess:
+ # load address
+ a1 = (
+ sess.query(Address)
+ .filter_by(email_address="ed@wood.com")
+ .one()
+ )
- # load user that is attached to the address
- u1 = sess.query(User).get(8)
+ # load user that is attached to the address
+ u1 = sess.query(User).get(8)
- def go():
- # lazy load of a1.user should get it from the session
- assert a1.user is u1
+ def go():
+ # lazy load of a1.user should get it from the session
+ assert a1.user is u1
- self.assert_sql_count(testing.db, go, 0)
- sa.orm.clear_mappers()
+ self.assert_sql_count(testing.db, go, 0)
+ sa.orm.clear_mappers()
def test_many_to_one(self):
users, Address, addresses, User = (
diff --git a/test/orm/test_load_on_fks.py b/test/orm/test_load_on_fks.py
index 0e8ac97e3..42b5b3e45 100644
--- a/test/orm/test_load_on_fks.py
+++ b/test/orm/test_load_on_fks.py
@@ -9,14 +9,12 @@ from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import instance_state
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
-engine = testing.db
-
-
class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
global Parent, Child, Base
Base = declarative_base()
@@ -36,27 +34,27 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
)
parent_id = Column(Integer, ForeignKey("parent.id"))
- Base.metadata.create_all(engine)
+ Base.metadata.create_all(testing.db)
- def tearDown(self):
- Base.metadata.drop_all(engine)
+ def teardown_test(self):
+ Base.metadata.drop_all(testing.db)
def test_annoying_autoflush_one(self):
- sess = Session(engine)
+ sess = fixture_session()
p1 = Parent()
sess.add(p1)
p1.children = []
def test_annoying_autoflush_two(self):
- sess = Session(engine)
+ sess = fixture_session()
p1 = Parent()
sess.add(p1)
assert p1.children == []
def test_dont_load_if_no_keys(self):
- sess = Session(engine)
+ sess = fixture_session()
p1 = Parent()
sess.add(p1)
@@ -68,7 +66,9 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
- def setUp(self):
+ __leave_connections_for_teardown__ = True
+
+ def setup_test(self):
global Parent, Child, Base
Base = declarative_base()
@@ -91,10 +91,10 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
parent = relationship(Parent, backref=backref("children"))
- Base.metadata.create_all(engine)
+ Base.metadata.create_all(testing.db)
global sess, p1, p2, c1, c2
- sess = Session(bind=engine)
+ sess = Session(bind=testing.db)
p1 = Parent()
p2 = Parent()
@@ -105,9 +105,9 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
sess.commit()
- def tearDown(self):
+ def teardown_test(self):
sess.rollback()
- Base.metadata.drop_all(engine)
+ Base.metadata.drop_all(testing.db)
def test_load_on_pending_allows_backref_event(self):
Child.parent.property.load_on_pending = True
diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py
index 013eb21e1..d182fd2c1 100644
--- a/test/orm/test_mapper.py
+++ b/test/orm/test_mapper.py
@@ -2560,7 +2560,7 @@ class MagicNamesTest(fixtures.MappedTest):
class DocumentTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.mapper = registry().map_imperatively
@@ -2624,14 +2624,14 @@ class DocumentTest(fixtures.TestBase):
class ORMLoggingTest(_fixtures.FixtureTest):
- def setup(self):
+ def setup_test(self):
self.buf = logging.handlers.BufferingHandler(100)
for log in [logging.getLogger("sqlalchemy.orm")]:
log.addHandler(self.buf)
self.mapper = registry().map_imperatively
- def teardown(self):
+ def teardown_test(self):
for log in [logging.getLogger("sqlalchemy.orm")]:
log.removeHandler(self.buf)
diff --git a/test/orm/test_options.py b/test/orm/test_options.py
index b22b318e9..6f47c1238 100644
--- a/test/orm/test_options.py
+++ b/test/orm/test_options.py
@@ -1523,9 +1523,7 @@ class PickleTest(PathTest, QueryTest):
class LocalOptsTest(PathTest, QueryTest):
@classmethod
- def setup_class(cls):
- super(LocalOptsTest, cls).setup_class()
-
+ def setup_test_class(cls):
@strategy_options.loader_option()
def some_col_opt_only(loadopt, key, opts):
return loadopt.set_column_strategy(
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index fd8e849fb..7546ba162 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -6000,12 +6000,19 @@ class SynonymTest(QueryTest, AssertsCompiledSQL):
[User.orders_syn, Order.items_syn],
[User.orders_syn_2, Order.items_syn],
):
- q = fixture_session().query(User)
- for path in j:
- q = q.join(path)
- q = q.filter_by(id=3)
- result = q.all()
- assert [User(id=7, name="jack"), User(id=9, name="fred")] == result
+ with fixture_session() as sess:
+ q = sess.query(User)
+ for path in j:
+ q = q.join(path)
+ q = q.filter_by(id=3)
+ result = q.all()
+ eq_(
+ result,
+ [
+ User(id=7, name="jack"),
+ User(id=9, name="fred"),
+ ],
+ )
def test_with_parent(self):
Order, User = self.classes.Order, self.classes.User
@@ -6018,17 +6025,17 @@ class SynonymTest(QueryTest, AssertsCompiledSQL):
("name_syn", "orders_syn"),
("name_syn", "orders_syn_2"),
):
- sess = fixture_session()
- q = sess.query(User)
+ with fixture_session() as sess:
+ q = sess.query(User)
- u1 = q.filter_by(**{nameprop: "jack"}).one()
+ u1 = q.filter_by(**{nameprop: "jack"}).one()
- o = sess.query(Order).with_parent(u1, property=orderprop).all()
- assert [
- Order(description="order 1"),
- Order(description="order 3"),
- Order(description="order 5"),
- ] == o
+ o = sess.query(Order).with_parent(u1, property=orderprop).all()
+ assert [
+ Order(description="order 1"),
+ Order(description="order 3"),
+ Order(description="order 5"),
+ ] == o
def test_froms_aliased_col(self):
Address, User = self.classes.Address, self.classes.User
diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py
index 12c084b2d..ef1bf2e60 100644
--- a/test/orm/test_rel_fn.py
+++ b/test/orm/test_rel_fn.py
@@ -26,7 +26,7 @@ from sqlalchemy.testing import mock
class _JoinFixtures(object):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
m = MetaData()
cls.left = Table(
"lft",
diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py
index 5979f08ae..8d73cd40e 100644
--- a/test/orm/test_relationships.py
+++ b/test/orm/test_relationships.py
@@ -637,7 +637,7 @@ class OverlappingFksSiblingTest(fixtures.TestBase):
"""
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def _fixture_one(
@@ -2474,7 +2474,7 @@ class JoinConditionErrorTest(fixtures.TestBase):
assert_raises(sa.exc.ArgumentError, configure_mappers)
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
@@ -4354,7 +4354,7 @@ class AmbiguousFKResolutionTest(_RelationshipErrors, fixtures.MappedTest):
class SecondaryArgTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
@testing.combinations((True,), (False,))
diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py
index 5535fe5d6..4895c7d3a 100644
--- a/test/orm/test_selectin_relations.py
+++ b/test/orm/test_selectin_relations.py
@@ -713,35 +713,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
def _do_query_tests(self, opts, count):
Order, User = self.classes.Order, self.classes.User
- sess = fixture_session()
+ with fixture_session() as sess:
- def go():
- eq_(
- sess.query(User).options(*opts).order_by(User.id).all(),
- self.static.user_item_keyword_result,
- )
+ def go():
+ eq_(
+ sess.query(User).options(*opts).order_by(User.id).all(),
+ self.static.user_item_keyword_result,
+ )
- self.assert_sql_count(testing.db, go, count)
+ self.assert_sql_count(testing.db, go, count)
- eq_(
- sess.query(User)
- .options(*opts)
- .filter(User.name == "fred")
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[2:3],
- )
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .filter(User.name == "fred")
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[2:3],
+ )
- sess = fixture_session()
- eq_(
- sess.query(User)
- .options(*opts)
- .join(User.orders)
- .filter(Order.id == 3)
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[0:1],
- )
+ with fixture_session() as sess:
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .join(User.orders)
+ .filter(Order.id == 3)
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[0:1],
+ )
def test_cyclical(self):
"""A circular eager relationship breaks the cycle with a lazy loader"""
diff --git a/test/orm/test_session.py b/test/orm/test_session.py
index 20c4752b8..3d4566af3 100644
--- a/test/orm/test_session.py
+++ b/test/orm/test_session.py
@@ -1273,7 +1273,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
run_inserts = None
- def setup(self):
+ def setup_test(self):
mapper(self.classes.User, self.tables.users)
def _assert_modified(self, u1):
@@ -1288,11 +1288,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
def _assert_no_cycle(self, u1):
assert sa.orm.attributes.instance_state(u1)._strong_obj is None
- def _persistent_fixture(self):
+ def _persistent_fixture(self, gc_collect=False):
User = self.classes.User
u1 = User()
u1.name = "ed"
- sess = fixture_session()
+ if gc_collect:
+ sess = Session(testing.db)
+ else:
+ sess = fixture_session()
sess.add(u1)
sess.flush()
return sess, u1
@@ -1389,14 +1392,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
@testing.requires.predictable_gc
def test_move_gc_session_persistent_dirty(self):
- sess, u1 = self._persistent_fixture()
+ sess, u1 = self._persistent_fixture(gc_collect=True)
u1.name = "edchanged"
self._assert_cycle(u1)
self._assert_modified(u1)
del sess
gc_collect()
self._assert_cycle(u1)
- s2 = fixture_session()
+ s2 = Session(testing.db)
s2.add(u1)
self._assert_cycle(u1)
self._assert_modified(u1)
@@ -1565,7 +1568,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest):
mapper(User, users)
- sess = fixture_session()
+ sess = Session(testing.db)
u1 = User(name="u1")
sess.add(u1)
@@ -1573,7 +1576,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest):
# can't add u1 to Session,
# already belongs to u2
- s2 = fixture_session()
+ s2 = Session(testing.db)
assert_raises_message(
sa.exc.InvalidRequestError,
r".*is already attached to session",
@@ -1725,11 +1728,10 @@ class DisposedStates(fixtures.MappedTest):
mapper(T, cls.tables.t1)
- def teardown(self):
+ def teardown_test(self):
from sqlalchemy.orm.session import _sessions
_sessions.clear()
- super(DisposedStates, self).teardown()
def _set_imap_in_disposal(self, sess, *objs):
"""remove selected objects from the given session, as though
diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py
index fe20442a3..150cee222 100644
--- a/test/orm/test_subquery_relations.py
+++ b/test/orm/test_subquery_relations.py
@@ -734,35 +734,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
def _do_query_tests(self, opts, count):
Order, User = self.classes.Order, self.classes.User
- sess = fixture_session()
+ with fixture_session() as sess:
- def go():
- eq_(
- sess.query(User).options(*opts).order_by(User.id).all(),
- self.static.user_item_keyword_result,
- )
+ def go():
+ eq_(
+ sess.query(User).options(*opts).order_by(User.id).all(),
+ self.static.user_item_keyword_result,
+ )
- self.assert_sql_count(testing.db, go, count)
+ self.assert_sql_count(testing.db, go, count)
- eq_(
- sess.query(User)
- .options(*opts)
- .filter(User.name == "fred")
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[2:3],
- )
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .filter(User.name == "fred")
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[2:3],
+ )
- sess = fixture_session()
- eq_(
- sess.query(User)
- .options(*opts)
- .join(User.orders)
- .filter(Order.id == 3)
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[0:1],
- )
+ with fixture_session() as sess:
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .join(User.orders)
+ .filter(Order.id == 3)
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[0:1],
+ )
def test_cyclical(self):
"""A circular eager relationship breaks the cycle with a lazy loader"""
diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py
index 550cf6535..7f77b01c7 100644
--- a/test/orm/test_transaction.py
+++ b/test/orm/test_transaction.py
@@ -66,17 +66,18 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
users, User = self.tables.users, self.classes.User
mapper(User, users)
- conn = testing.db.connect()
- trans = conn.begin()
- sess = Session(bind=conn, autocommit=False, autoflush=True)
- sess.begin(subtransactions=True)
- u = User(name="ed")
- sess.add(u)
- sess.flush()
- sess.commit() # commit does nothing
- trans.rollback() # rolls back
- assert len(sess.query(User).all()) == 0
- sess.close()
+
+ with testing.db.connect() as conn:
+ trans = conn.begin()
+ sess = Session(bind=conn, autocommit=False, autoflush=True)
+ sess.begin(subtransactions=True)
+ u = User(name="ed")
+ sess.add(u)
+ sess.flush()
+ sess.commit() # commit does nothing
+ trans.rollback() # rolls back
+ assert len(sess.query(User).all()) == 0
+ sess.close()
@engines.close_open_connections
def test_subtransaction_on_external_no_begin(self):
@@ -260,34 +261,33 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
users = self.tables.users
engine = Engine._future_facade(testing.db)
- session = Session(engine, autocommit=False)
-
- session.begin()
- session.connection().execute(users.insert().values(name="user1"))
- session.begin(subtransactions=True)
- session.begin_nested()
- session.connection().execute(users.insert().values(name="user2"))
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 2
- )
- session.rollback()
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 1
- )
- session.connection().execute(users.insert().values(name="user3"))
- session.commit()
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 2
- )
+ with Session(engine, autocommit=False) as session:
+ session.begin()
+ session.connection().execute(users.insert().values(name="user1"))
+ session.begin(subtransactions=True)
+ session.begin_nested()
+ session.connection().execute(users.insert().values(name="user2"))
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 2
+ )
+ session.rollback()
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 1
+ )
+ session.connection().execute(users.insert().values(name="user3"))
+ session.commit()
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 2
+ )
@testing.requires.savepoints
def test_dirty_state_transferred_deep_nesting(self):
@@ -295,27 +295,27 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
mapper(User, users)
- s = Session(testing.db)
- u1 = User(name="u1")
- s.add(u1)
- s.commit()
-
- nt1 = s.begin_nested()
- nt2 = s.begin_nested()
- u1.name = "u2"
- assert attributes.instance_state(u1) not in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
- s.flush()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
+ with fixture_session() as s:
+ u1 = User(name="u1")
+ s.add(u1)
+ s.commit()
+
+ nt1 = s.begin_nested()
+ nt2 = s.begin_nested()
+ u1.name = "u2"
+ assert attributes.instance_state(u1) not in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
+ s.flush()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
- s.commit()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) in nt1._dirty
+ s.commit()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) in nt1._dirty
- s.rollback()
- assert attributes.instance_state(u1).expired
- eq_(u1.name, "u1")
+ s.rollback()
+ assert attributes.instance_state(u1).expired
+ eq_(u1.name, "u1")
@testing.requires.savepoints
def test_dirty_state_transferred_deep_nesting_future(self):
@@ -323,27 +323,27 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
mapper(User, users)
- s = Session(testing.db, future=True)
- u1 = User(name="u1")
- s.add(u1)
- s.commit()
-
- nt1 = s.begin_nested()
- nt2 = s.begin_nested()
- u1.name = "u2"
- assert attributes.instance_state(u1) not in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
- s.flush()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
+ with fixture_session(future=True) as s:
+ u1 = User(name="u1")
+ s.add(u1)
+ s.commit()
+
+ nt1 = s.begin_nested()
+ nt2 = s.begin_nested()
+ u1.name = "u2"
+ assert attributes.instance_state(u1) not in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
+ s.flush()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
- nt2.commit()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) in nt1._dirty
+ nt2.commit()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) in nt1._dirty
- nt1.rollback()
- assert attributes.instance_state(u1).expired
- eq_(u1.name, "u1")
+ nt1.rollback()
+ assert attributes.instance_state(u1).expired
+ eq_(u1.name, "u1")
@testing.requires.independent_connections
def test_transactions_isolated(self):
@@ -1049,23 +1049,25 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
mapper(User, users)
- session = Session(testing.db)
+ with fixture_session() as session:
- with expect_warnings(".*during handling of a previous exception.*"):
- session.begin_nested()
- savepoint = session.connection()._nested_transaction._savepoint
+ with expect_warnings(
+ ".*during handling of a previous exception.*"
+ ):
+ session.begin_nested()
+ savepoint = session.connection()._nested_transaction._savepoint
- # force the savepoint to disappear
- session.connection().dialect.do_release_savepoint(
- session.connection(), savepoint
- )
+ # force the savepoint to disappear
+ session.connection().dialect.do_release_savepoint(
+ session.connection(), savepoint
+ )
- # now do a broken flush
- session.add_all([User(id=1), User(id=1)])
+ # now do a broken flush
+ session.add_all([User(id=1), User(id=1)])
- assert_raises_message(
- sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush
- )
+ assert_raises_message(
+ sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush
+ )
class _LocalFixture(FixtureTest):
@@ -1170,39 +1172,40 @@ class SubtransactionRecipeTest(FixtureTest):
def test_recipe_heavy_nesting(self, subtransaction_recipe):
users = self.tables.users
- session = Session(testing.db, future=self.future)
-
- with subtransaction_recipe(session):
- session.connection().execute(users.insert().values(name="user1"))
+ with fixture_session(future=self.future) as session:
with subtransaction_recipe(session):
- savepoint = session.begin_nested()
session.connection().execute(
- users.insert().values(name="user2")
+ users.insert().values(name="user1")
)
+ with subtransaction_recipe(session):
+ savepoint = session.begin_nested()
+ session.connection().execute(
+ users.insert().values(name="user2")
+ )
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 2
+ )
+ savepoint.rollback()
+
+ with subtransaction_recipe(session):
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 1
+ )
+ session.connection().execute(
+ users.insert().values(name="user3")
+ )
assert (
session.connection()
.exec_driver_sql("select count(1) from users")
.scalar()
== 2
)
- savepoint.rollback()
-
- with subtransaction_recipe(session):
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 1
- )
- session.connection().execute(
- users.insert().values(name="user3")
- )
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 2
- )
@engines.close_open_connections
def test_recipe_subtransaction_on_external_subtrans(
@@ -1228,13 +1231,12 @@ class SubtransactionRecipeTest(FixtureTest):
User, users = self.classes.User, self.tables.users
mapper(User, users)
- sess = Session(testing.db, future=self.future)
-
- with subtransaction_recipe(sess):
- u = User(name="u1")
- sess.add(u)
- sess.close()
- assert len(sess.query(User).all()) == 1
+ with fixture_session(future=self.future) as sess:
+ with subtransaction_recipe(sess):
+ u = User(name="u1")
+ sess.add(u)
+ sess.close()
+ assert len(sess.query(User).all()) == 1
def test_recipe_subtransaction_on_noautocommit(
self, subtransaction_recipe
@@ -1242,16 +1244,15 @@ class SubtransactionRecipeTest(FixtureTest):
User, users = self.classes.User, self.tables.users
mapper(User, users)
- sess = Session(testing.db, future=self.future)
-
- sess.begin()
- with subtransaction_recipe(sess):
- u = User(name="u1")
- sess.add(u)
- sess.flush()
- sess.rollback() # rolls back
- assert len(sess.query(User).all()) == 0
- sess.close()
+ with fixture_session(future=self.future) as sess:
+ sess.begin()
+ with subtransaction_recipe(sess):
+ u = User(name="u1")
+ sess.add(u)
+ sess.flush()
+ sess.rollback() # rolls back
+ assert len(sess.query(User).all()) == 0
+ sess.close()
@testing.requires.savepoints
def test_recipe_mixed_transaction_control(self, subtransaction_recipe):
@@ -1259,30 +1260,28 @@ class SubtransactionRecipeTest(FixtureTest):
mapper(User, users)
- sess = Session(testing.db, future=self.future)
+ with fixture_session(future=self.future) as sess:
- sess.begin()
- sess.begin_nested()
+ sess.begin()
+ sess.begin_nested()
- with subtransaction_recipe(sess):
+ with subtransaction_recipe(sess):
- sess.add(User(name="u1"))
+ sess.add(User(name="u1"))
- sess.commit()
- sess.commit()
+ sess.commit()
+ sess.commit()
- eq_(len(sess.query(User).all()), 1)
- sess.close()
+ eq_(len(sess.query(User).all()), 1)
+ sess.close()
- t1 = sess.begin()
- t2 = sess.begin_nested()
-
- sess.add(User(name="u2"))
+ t1 = sess.begin()
+ t2 = sess.begin_nested()
- t2.commit()
- assert sess._legacy_transaction() is t1
+ sess.add(User(name="u2"))
- sess.close()
+ t2.commit()
+ assert sess._legacy_transaction() is t1
def test_recipe_error_on_using_inactive_session_commands(
self, subtransaction_recipe
@@ -1290,56 +1289,55 @@ class SubtransactionRecipeTest(FixtureTest):
users, User = self.tables.users, self.classes.User
mapper(User, users)
- sess = Session(testing.db, future=self.future)
- sess.begin()
-
- try:
- with subtransaction_recipe(sess):
- sess.add(User(name="u1"))
- sess.flush()
- raise Exception("force rollback")
- except:
- pass
-
- if self.recipe_rollsback_early:
- # that was a real rollback, so no transaction
- assert not sess.in_transaction()
- is_(sess.get_transaction(), None)
- else:
- assert sess.in_transaction()
-
- sess.close()
- assert not sess.in_transaction()
-
- def test_recipe_multi_nesting(self, subtransaction_recipe):
- sess = Session(testing.db, future=self.future)
-
- with subtransaction_recipe(sess):
- assert sess.in_transaction()
+ with fixture_session(future=self.future) as sess:
+ sess.begin()
try:
with subtransaction_recipe(sess):
- assert sess._legacy_transaction()
+ sess.add(User(name="u1"))
+ sess.flush()
raise Exception("force rollback")
except:
pass
if self.recipe_rollsback_early:
+ # that was a real rollback, so no transaction
assert not sess.in_transaction()
+ is_(sess.get_transaction(), None)
else:
assert sess.in_transaction()
- assert not sess.in_transaction()
+ sess.close()
+ assert not sess.in_transaction()
+
+ def test_recipe_multi_nesting(self, subtransaction_recipe):
+ with fixture_session(future=self.future) as sess:
+ with subtransaction_recipe(sess):
+ assert sess.in_transaction()
+
+ try:
+ with subtransaction_recipe(sess):
+ assert sess._legacy_transaction()
+ raise Exception("force rollback")
+ except:
+ pass
+
+ if self.recipe_rollsback_early:
+ assert not sess.in_transaction()
+ else:
+ assert sess.in_transaction()
+
+ assert not sess.in_transaction()
def test_recipe_deactive_status_check(self, subtransaction_recipe):
- sess = Session(testing.db, future=self.future)
- sess.begin()
+ with fixture_session(future=self.future) as sess:
+ sess.begin()
- with subtransaction_recipe(sess):
- sess.rollback()
+ with subtransaction_recipe(sess):
+ sess.rollback()
- assert not sess.in_transaction()
- sess.commit() # no error
+ assert not sess.in_transaction()
+ sess.commit() # no error
class FixtureDataTest(_LocalFixture):
@@ -1394,28 +1392,28 @@ class CleanSavepointTest(FixtureTest):
mapper(User, users)
- s = Session(bind=testing.db, future=future)
- u1 = User(name="u1")
- u2 = User(name="u2")
- s.add_all([u1, u2])
- s.commit()
- u1.name
- u2.name
- trans = s._transaction
- assert trans is not None
- s.begin_nested()
- update_fn(s, u2)
- eq_(u2.name, "u2modified")
- s.rollback()
+ with fixture_session(future=future) as s:
+ u1 = User(name="u1")
+ u2 = User(name="u2")
+ s.add_all([u1, u2])
+ s.commit()
+ u1.name
+ u2.name
+ trans = s._transaction
+ assert trans is not None
+ s.begin_nested()
+ update_fn(s, u2)
+ eq_(u2.name, "u2modified")
+ s.rollback()
- if future:
- assert s._transaction is None
- assert "name" not in u1.__dict__
- else:
- assert s._transaction is trans
- eq_(u1.__dict__["name"], "u1")
- assert "name" not in u2.__dict__
- eq_(u2.name, "u2")
+ if future:
+ assert s._transaction is None
+ assert "name" not in u1.__dict__
+ else:
+ assert s._transaction is trans
+ eq_(u1.__dict__["name"], "u1")
+ assert "name" not in u2.__dict__
+ eq_(u2.name, "u2")
@testing.requires.savepoints
def test_rollback_ignores_clean_on_savepoint(self):
@@ -2116,82 +2114,108 @@ class ContextManagerPlusFutureTest(FixtureTest):
eq_(sess.query(User).count(), 1)
def test_explicit_begin(self):
- s1 = Session(testing.db)
- with s1.begin() as trans:
- is_(trans, s1._legacy_transaction())
- s1.connection()
+ with fixture_session() as s1:
+ with s1.begin() as trans:
+ is_(trans, s1._legacy_transaction())
+ s1.connection()
- is_(s1._transaction, None)
+ is_(s1._transaction, None)
def test_no_double_begin_explicit(self):
- s1 = Session(testing.db)
- s1.begin()
- assert_raises_message(
- sa_exc.InvalidRequestError,
- "A transaction is already begun on this Session.",
- s1.begin,
- )
+ with fixture_session() as s1:
+ s1.begin()
+ assert_raises_message(
+ sa_exc.InvalidRequestError,
+ "A transaction is already begun on this Session.",
+ s1.begin,
+ )
@testing.requires.savepoints
def test_future_rollback_is_global(self):
users = self.tables.users
- s1 = Session(testing.db, future=True)
+ with fixture_session(future=True) as s1:
+ s1.begin()
- s1.begin()
+ s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
- s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
+ s1.begin_nested()
- s1.begin_nested()
-
- s1.connection().execute(
- users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}]
- )
+ s1.connection().execute(
+ users.insert(),
+ [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}],
+ )
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 3,
+ )
- # rolls back the whole transaction
- s1.rollback()
- is_(s1._legacy_transaction(), None)
+ # rolls back the whole transaction
+ s1.rollback()
+ is_(s1._legacy_transaction(), None)
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 0)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 0,
+ )
- s1.commit()
- is_(s1._legacy_transaction(), None)
+ s1.commit()
+ is_(s1._legacy_transaction(), None)
@testing.requires.savepoints
def test_old_rollback_is_local(self):
users = self.tables.users
- s1 = Session(testing.db)
+ with fixture_session() as s1:
- t1 = s1.begin()
+ t1 = s1.begin()
- s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
+ s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
- s1.begin_nested()
+ s1.begin_nested()
- s1.connection().execute(
- users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}]
- )
+ s1.connection().execute(
+ users.insert(),
+ [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}],
+ )
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 3,
+ )
- # rolls back only the savepoint
- s1.rollback()
+ # rolls back only the savepoint
+ s1.rollback()
- is_(s1._legacy_transaction(), t1)
+ is_(s1._legacy_transaction(), t1)
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 1,
+ )
- s1.commit()
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1)
- is_not(s1._legacy_transaction(), None)
+ s1.commit()
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 1,
+ )
+ is_not(s1._legacy_transaction(), None)
def test_session_as_ctx_manager_one(self):
users = self.tables.users
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
is_not(sess._legacy_transaction(), None)
sess.connection().execute(
@@ -2212,7 +2236,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
def test_session_as_ctx_manager_future_one(self):
users = self.tables.users
- with Session(testing.db, future=True) as sess:
+ with fixture_session(future=True) as sess:
is_(sess._legacy_transaction(), None)
sess.connection().execute(
@@ -2234,7 +2258,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
users = self.tables.users
try:
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
is_not(sess._legacy_transaction(), None)
sess.connection().execute(
@@ -2250,7 +2274,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
users = self.tables.users
try:
- with Session(testing.db, future=True) as sess:
+ with fixture_session(future=True) as sess:
is_(sess._legacy_transaction(), None)
sess.connection().execute(
@@ -2265,7 +2289,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
def test_begin_context_manager(self):
users = self.tables.users
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
with sess.begin():
sess.connection().execute(
users.insert().values(id=1, name="user1")
@@ -2296,12 +2320,13 @@ class ContextManagerPlusFutureTest(FixtureTest):
# committed
eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+ sess.close()
def test_begin_context_manager_rollback_trans(self):
users = self.tables.users
try:
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
with sess.begin():
sess.connection().execute(
users.insert().values(id=1, name="user1")
@@ -2318,12 +2343,13 @@ class ContextManagerPlusFutureTest(FixtureTest):
# rolled back
eq_(sess.connection().execute(users.select()).all(), [])
+ sess.close()
def test_begin_context_manager_rollback_outer(self):
users = self.tables.users
try:
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
with sess.begin():
sess.connection().execute(
users.insert().values(id=1, name="user1")
@@ -2340,6 +2366,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
# committed
eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+ sess.close()
def test_sessionmaker_begin_context_manager_rollback_trans(self):
users = self.tables.users
@@ -2363,6 +2390,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
# rolled back
eq_(sess.connection().execute(users.select()).all(), [])
+ sess.close()
def test_sessionmaker_begin_context_manager_rollback_outer(self):
users = self.tables.users
@@ -2386,36 +2414,37 @@ class ContextManagerPlusFutureTest(FixtureTest):
# committed
eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+ sess.close()
class TransactionFlagsTest(fixtures.TestBase):
def test_in_transaction(self):
- s1 = Session(testing.db)
+ with fixture_session() as s1:
- eq_(s1.in_transaction(), False)
+ eq_(s1.in_transaction(), False)
- trans = s1.begin()
+ trans = s1.begin()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
- n1 = s1.begin_nested()
+ n1 = s1.begin_nested()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), n1)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), n1)
- n1.rollback()
+ n1.rollback()
- is_(s1.get_nested_transaction(), None)
- is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), None)
+ is_(s1.get_transaction(), trans)
- eq_(s1.in_transaction(), True)
+ eq_(s1.in_transaction(), True)
- s1.commit()
+ s1.commit()
- eq_(s1.in_transaction(), False)
- is_(s1.get_transaction(), None)
+ eq_(s1.in_transaction(), False)
+ is_(s1.get_transaction(), None)
def test_in_transaction_subtransactions(self):
"""we'd like to do away with subtransactions for future sessions
@@ -2425,72 +2454,71 @@ class TransactionFlagsTest(fixtures.TestBase):
the external API works.
"""
- s1 = Session(testing.db)
-
- eq_(s1.in_transaction(), False)
+ with fixture_session() as s1:
+ eq_(s1.in_transaction(), False)
- trans = s1.begin()
+ trans = s1.begin()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
- subtrans = s1.begin(_subtrans=True)
- is_(s1.get_transaction(), trans)
- eq_(s1.in_transaction(), True)
+ subtrans = s1.begin(_subtrans=True)
+ is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
- is_(s1._transaction, subtrans)
+ is_(s1._transaction, subtrans)
- s1.rollback()
+ s1.rollback()
- eq_(s1.in_transaction(), True)
- is_(s1._transaction, trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1._transaction, trans)
- s1.rollback()
+ s1.rollback()
- eq_(s1.in_transaction(), False)
- is_(s1._transaction, None)
+ eq_(s1.in_transaction(), False)
+ is_(s1._transaction, None)
def test_in_transaction_nesting(self):
- s1 = Session(testing.db)
+ with fixture_session() as s1:
- eq_(s1.in_transaction(), False)
+ eq_(s1.in_transaction(), False)
- trans = s1.begin()
+ trans = s1.begin()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
- sp1 = s1.begin_nested()
+ sp1 = s1.begin_nested()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), sp1)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), sp1)
- sp2 = s1.begin_nested()
+ sp2 = s1.begin_nested()
- eq_(s1.in_transaction(), True)
- eq_(s1.in_nested_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), sp2)
+ eq_(s1.in_transaction(), True)
+ eq_(s1.in_nested_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), sp2)
- sp2.rollback()
+ sp2.rollback()
- eq_(s1.in_transaction(), True)
- eq_(s1.in_nested_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), sp1)
+ eq_(s1.in_transaction(), True)
+ eq_(s1.in_nested_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), sp1)
- sp1.rollback()
+ sp1.rollback()
- is_(s1.get_nested_transaction(), None)
- eq_(s1.in_transaction(), True)
- eq_(s1.in_nested_transaction(), False)
- is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), None)
+ eq_(s1.in_transaction(), True)
+ eq_(s1.in_nested_transaction(), False)
+ is_(s1.get_transaction(), trans)
- s1.rollback()
+ s1.rollback()
- eq_(s1.in_transaction(), False)
- is_(s1.get_transaction(), None)
+ eq_(s1.in_transaction(), False)
+ is_(s1.get_transaction(), None)
class NaturalPKRollbackTest(fixtures.MappedTest):
@@ -2674,8 +2702,11 @@ class NaturalPKRollbackTest(fixtures.MappedTest):
class JoinIntoAnExternalTransactionFixture(object):
"""Test the "join into an external transaction" examples"""
- def setup(self):
- self.connection = testing.db.connect()
+ __leave_connections_for_teardown__ = True
+
+ def setup_test(self):
+ self.engine = testing.db
+ self.connection = self.engine.connect()
self.metadata = MetaData()
self.table = Table(
@@ -2686,6 +2717,17 @@ class JoinIntoAnExternalTransactionFixture(object):
self.setup_session()
+ def teardown_test(self):
+ self.teardown_session()
+
+ with self.connection.begin():
+ self._assert_count(0)
+
+ with self.connection.begin():
+ self.table.drop(self.connection)
+
+ self.connection.close()
+
def test_something(self):
A = self.A
@@ -2727,18 +2769,6 @@ class JoinIntoAnExternalTransactionFixture(object):
)
eq_(result, count)
- def teardown(self):
- self.teardown_session()
-
- with self.connection.begin():
- self._assert_count(0)
-
- with self.connection.begin():
- self.table.drop(self.connection)
-
- # return connection to the Engine
- self.connection.close()
-
class NewStyleJoinIntoAnExternalTransactionTest(
JoinIntoAnExternalTransactionFixture
@@ -2775,7 +2805,8 @@ class NewStyleJoinIntoAnExternalTransactionTest(
# rollback - everything that happened with the
# Session above (including calls to commit())
# is rolled back.
- self.trans.rollback()
+ if self.trans.is_active:
+ self.trans.rollback()
class FutureJoinIntoAnExternalTransactionTest(
diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py
index 84373b2dc..2c35bec45 100644
--- a/test/orm/test_unitofwork.py
+++ b/test/orm/test_unitofwork.py
@@ -203,14 +203,6 @@ class UnicodeSchemaTest(fixtures.MappedTest):
cls.tables["t1"] = t1
cls.tables["t2"] = t2
- @classmethod
- def setup_class(cls):
- super(UnicodeSchemaTest, cls).setup_class()
-
- @classmethod
- def teardown_class(cls):
- super(UnicodeSchemaTest, cls).teardown_class()
-
def test_mapping(self):
t2, t1 = self.tables.t2, self.tables.t1
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index 4e713627c..65089f773 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -771,14 +771,13 @@ class RudimentaryFlushTest(UOWTest):
class SingleCycleTest(UOWTest):
- def teardown(self):
+ def teardown_test(self):
engines.testing_reaper.rollback_all()
# mysql can't handle delete from nodes
# since it doesn't deal with the FKs correctly,
# so wipe out the parent_id first
with testing.db.begin() as conn:
conn.execute(self.tables.nodes.update().values(parent_id=None))
- super(SingleCycleTest, self).teardown()
def test_one_to_many_save(self):
Node, nodes = self.classes.Node, self.tables.nodes
diff --git a/test/requirements.py b/test/requirements.py
index d5a718372..3c9b39ac7 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -1624,15 +1624,15 @@ class DefaultRequirements(SuiteRequirements):
@property
def postgresql_utf8_server_encoding(self):
+ def go(config):
+ if not against(config, "postgresql"):
+ return False
- return only_if(
- lambda config: against(config, "postgresql")
- and config.db.connect(close_with_result=True)
- .exec_driver_sql("show server_encoding")
- .scalar()
- .lower()
- == "utf8"
- )
+ with config.db.connect() as conn:
+ enc = conn.exec_driver_sql("show server_encoding").scalar()
+ return enc.lower() == "utf8"
+
+ return only_if(go)
@property
def cxoracle6_or_greater(self):
diff --git a/test/sql/test_case_statement.py b/test/sql/test_case_statement.py
index 4bef1df7f..b44971cec 100644
--- a/test/sql/test_case_statement.py
+++ b/test/sql/test_case_statement.py
@@ -26,7 +26,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
metadata = MetaData()
global info_table
info_table = Table(
@@ -52,7 +52,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
info_table.drop(conn)
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index 70281d4e8..1ac3613f7 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py
@@ -1203,7 +1203,7 @@ class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
# TODO: we need to get dialects here somehow, perhaps in test_suite?
[
importlib.import_module("sqlalchemy.dialects.%s" % d)
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index fdffe04bf..4429753ec 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -4306,7 +4306,7 @@ class StringifySpecialTest(fixtures.TestBase):
class KwargPropagationTest(fixtures.TestBase):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy.sql.expression import ColumnClause, TableClause
class CatchCol(ColumnClause):
diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py
index 2a2e70bc3..8be7eed1f 100644
--- a/test/sql/test_defaults.py
+++ b/test/sql/test_defaults.py
@@ -503,9 +503,8 @@ class DefaultRoundTripTest(fixtures.TablesTest):
Column("col11", MyType(), default="foo"),
)
- def teardown(self):
+ def teardown_test(self):
self.default_generator["x"] = 50
- super(DefaultRoundTripTest, self).teardown()
def test_standalone(self, connection):
t = self.tables.default_test
@@ -1226,7 +1225,7 @@ class SpecialTypePKTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
class MyInteger(TypeDecorator):
impl = Integer
diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py
index acc12a5fe..777565220 100644
--- a/test/sql/test_deprecations.py
+++ b/test/sql/test_deprecations.py
@@ -2488,12 +2488,12 @@ class LegacySequenceExecTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
cls.seq = Sequence("my_sequence")
cls.seq.create(testing.db)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
cls.seq.drop(testing.db)
def _assert_seq_result(self, ret):
@@ -2574,7 +2574,7 @@ class LegacySequenceExecTest(fixtures.TestBase):
class DDLDeprecatedBindTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
if inspect(conn).has_table("foo"):
conn.execute(schema.DropTable(table("foo")))
diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py
index 4edc9d025..a6001ba9d 100644
--- a/test/sql/test_external_traversal.py
+++ b/test/sql/test_external_traversal.py
@@ -47,7 +47,7 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults):
ability to copy and modify a ClauseElement in place."""
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global A, B
# establish two fictitious ClauseElements.
@@ -321,7 +321,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2, t3
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
@@ -1012,7 +1012,7 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table(
"table1",
@@ -1196,7 +1196,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
@@ -1943,7 +1943,7 @@ class SpliceJoinsTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global table1, table2, table3, table4
def _table(name):
@@ -2031,7 +2031,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
@@ -2128,7 +2128,7 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL):
# fixme: consolidate converage from elsewhere here and expand
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py
index 6afe41aac..b0bcee18e 100644
--- a/test/sql/test_from_linter.py
+++ b/test/sql/test_from_linter.py
@@ -25,7 +25,7 @@ class TestFindUnmatchingFroms(fixtures.TablesTest):
Table("table_c", metadata, Column("col_c", Integer, primary_key=True))
Table("table_d", metadata, Column("col_d", Integer, primary_key=True))
- def setup(self):
+ def setup_test(self):
self.a = self.tables.table_a
self.b = self.tables.table_b
self.c = self.tables.table_c
@@ -267,8 +267,10 @@ class TestLinter(fixtures.TablesTest):
with self.bind.connect() as conn:
conn.execute(query)
- def test_no_linting(self):
- eng = engines.testing_engine(options={"enable_from_linting": False})
+ def test_no_linting(self, metadata, connection):
+ eng = engines.testing_engine(
+ options={"enable_from_linting": False, "use_reaper": False}
+ )
eng.pool = self.bind.pool # needed for SQLite
a, b = self.tables("table_a", "table_b")
query = select(a.c.col_a).where(b.c.col_b == 5)
diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py
index 91076f9c3..32ea642d7 100644
--- a/test/sql/test_functions.py
+++ b/test/sql/test_functions.py
@@ -54,10 +54,10 @@ table1 = table(
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
- def setup(self):
+ def setup_test(self):
self._registry = deepcopy(functions._registry)
- def teardown(self):
+ def teardown_test(self):
functions._registry = self._registry
def test_compile(self):
@@ -938,7 +938,7 @@ class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase):
class ExecuteTest(fixtures.TestBase):
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
pass
def test_conn_execute(self, connection):
@@ -1113,10 +1113,10 @@ class ExecuteTest(fixtures.TestBase):
class RegisterTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
- def setup(self):
+ def setup_test(self):
self._registry = deepcopy(functions._registry)
- def teardown(self):
+ def teardown_test(self):
functions._registry = self._registry
def test_GenericFunction_is_registered(self):
diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py
index 502f70ce7..9bf351f5c 100644
--- a/test/sql/test_metadata.py
+++ b/test/sql/test_metadata.py
@@ -4257,7 +4257,7 @@ class DialectKWArgTest(fixtures.TestBase):
with mock.patch("sqlalchemy.dialects.registry.load", load):
yield
- def teardown(self):
+ def teardown_test(self):
Index._kw_registry.clear()
def test_participating(self):
diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py
index aaeed68dd..270e79ba1 100644
--- a/test/sql/test_operators.py
+++ b/test/sql/test_operators.py
@@ -608,7 +608,7 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
- def setUp(self):
+ def setup_test(self):
class MyTypeCompiler(compiler.GenericTypeCompiler):
def visit_mytype(self, type_, **kw):
return "MYTYPE"
@@ -766,7 +766,7 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
- def setUp(self):
+ def setup_test(self):
class MyTypeCompiler(compiler.GenericTypeCompiler):
def visit_mytype(self, type_, **kw):
return "MYTYPE"
@@ -2370,7 +2370,7 @@ class MatchTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
@@ -2403,7 +2403,7 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class RegexpTestStrCompiler(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default_enhanced"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py
index 136f10cf4..7ad12c620 100644
--- a/test/sql/test_resultset.py
+++ b/test/sql/test_resultset.py
@@ -661,8 +661,8 @@ class CursorResultTest(fixtures.TablesTest):
assert_raises(KeyError, lambda: row._mapping["Case_insensitive"])
assert_raises(KeyError, lambda: row._mapping["casesensitive"])
- def test_row_case_sensitive_unoptimized(self):
- with engines.testing_engine().connect() as ins_conn:
+ def test_row_case_sensitive_unoptimized(self, testing_engine):
+ with testing_engine().connect() as ins_conn:
row = ins_conn.execute(
select(
literal_column("1").label("case_insensitive"),
@@ -1234,8 +1234,7 @@ class CursorResultTest(fixtures.TablesTest):
eq_(proxy[0], "value")
eq_(proxy._mapping["key"], "value")
- @testing.provide_metadata
- def test_no_rowcount_on_selects_inserts(self):
+ def test_no_rowcount_on_selects_inserts(self, metadata, testing_engine):
"""assert that rowcount is only called on deletes and updates.
This because cursor.rowcount may can be expensive on some dialects
@@ -1244,9 +1243,7 @@ class CursorResultTest(fixtures.TablesTest):
"""
- metadata = self.metadata
-
- engine = engines.testing_engine()
+ engine = testing_engine()
t = Table("t1", metadata, Column("data", String(10)))
metadata.create_all(engine)
@@ -2132,7 +2129,9 @@ class AlternateCursorResultTest(fixtures.TablesTest):
@classmethod
def setup_bind(cls):
- cls.engine = engine = engines.testing_engine("sqlite://")
+ cls.engine = engine = engines.testing_engine(
+ "sqlite://", options={"scope": "class"}
+ )
return engine
@classmethod
diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py
index 65325aa6f..5cfc2663f 100644
--- a/test/sql/test_sequences.py
+++ b/test/sql/test_sequences.py
@@ -100,12 +100,12 @@ class SequenceExecTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
cls.seq = Sequence("my_sequence")
cls.seq.create(testing.db)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
cls.seq.drop(testing.db)
def _assert_seq_result(self, ret):
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 0e1147800..64ace87df 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -1375,7 +1375,7 @@ class VariantBackendTest(fixtures.TestBase, AssertsCompiledSQL):
class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
- def setup(self):
+ def setup_test(self):
class UTypeOne(types.UserDefinedType):
def get_col_spec(self):
return "UTYPEONE"
@@ -2504,7 +2504,7 @@ class BinaryTest(fixtures.TablesTest, AssertsExecutionResults):
class JSONTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
@@ -3445,7 +3445,12 @@ class BooleanTest(
@testing.requires.non_native_boolean_unconstrained
def test_constraint(self, connection):
assert_raises(
- (exc.IntegrityError, exc.ProgrammingError, exc.OperationalError),
+ (
+ exc.IntegrityError,
+ exc.ProgrammingError,
+ exc.OperationalError,
+ exc.InternalError, # older pymysql's do this
+ ),
connection.exec_driver_sql,
"insert into boolean_table (id, value) values(1, 5)",
)