summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/aaa_profiling/test_compiler.py2
-rw-r--r--test/aaa_profiling/test_memusage.py4
-rw-r--r--test/aaa_profiling/test_misc.py2
-rw-r--r--test/aaa_profiling/test_orm.py6
-rw-r--r--test/aaa_profiling/test_pool.py2
-rw-r--r--test/base/test_events.py22
-rw-r--r--test/base/test_inspect.py2
-rw-r--r--test/base/test_tutorials.py4
-rw-r--r--test/dialect/mssql/test_compiler.py2
-rw-r--r--test/dialect/mssql/test_deprecations.py2
-rw-r--r--test/dialect/mssql/test_query.py3
-rw-r--r--test/dialect/mysql/test_compiler.py4
-rw-r--r--test/dialect/mysql/test_reflection.py2
-rw-r--r--test/dialect/oracle/test_compiler.py2
-rw-r--r--test/dialect/oracle/test_dialect.py4
-rw-r--r--test/dialect/oracle/test_reflection.py16
-rw-r--r--test/dialect/oracle/test_types.py4
-rw-r--r--test/dialect/postgresql/test_async_pg_py3k.py4
-rw-r--r--test/dialect/postgresql/test_compiler.py8
-rw-r--r--test/dialect/postgresql/test_dialect.py21
-rw-r--r--test/dialect/postgresql/test_query.py6
-rw-r--r--test/dialect/postgresql/test_reflection.py499
-rw-r--r--test/dialect/postgresql/test_types.py692
-rw-r--r--test/dialect/test_sqlite.py32
-rw-r--r--test/engine/test_ddlevents.py4
-rw-r--r--test/engine/test_deprecations.py16
-rw-r--r--test/engine/test_execute.py364
-rw-r--r--test/engine/test_logging.py16
-rw-r--r--test/engine/test_pool.py57
-rw-r--r--test/engine/test_processors.py10
-rw-r--r--test/engine/test_reconnect.py20
-rw-r--r--test/engine/test_reflection.py9
-rw-r--r--test/engine/test_transaction.py21
-rw-r--r--test/ext/asyncio/test_engine_py3k.py16
-rw-r--r--test/ext/declarative/test_inheritance.py4
-rw-r--r--test/ext/declarative/test_reflection.py127
-rw-r--r--test/ext/test_associationproxy.py16
-rw-r--r--test/ext/test_baked.py2
-rw-r--r--test/ext/test_compiler.py2
-rw-r--r--test/ext/test_extendedattr.py6
-rw-r--r--test/ext/test_horizontal_shard.py32
-rw-r--r--test/ext/test_hybrid.py2
-rw-r--r--test/ext/test_mutable.py15
-rw-r--r--test/ext/test_orderinglist.py16
-rw-r--r--test/orm/declarative/test_basic.py4
-rw-r--r--test/orm/declarative/test_concurrency.py2
-rw-r--r--test/orm/declarative/test_inheritance.py4
-rw-r--r--test/orm/declarative/test_mixin.py4
-rw-r--r--test/orm/declarative/test_reflection.py5
-rw-r--r--test/orm/inheritance/test_basic.py49
-rw-r--r--test/orm/test_attributes.py8
-rw-r--r--test/orm/test_bind.py74
-rw-r--r--test/orm/test_collection.py5
-rw-r--r--test/orm/test_compile.py2
-rw-r--r--test/orm/test_cycles.py3
-rw-r--r--test/orm/test_deprecations.py80
-rw-r--r--test/orm/test_eager_relations.py14
-rw-r--r--test/orm/test_events.py7
-rw-r--r--test/orm/test_froms.py106
-rw-r--r--test/orm/test_lazy_relations.py57
-rw-r--r--test/orm/test_load_on_fks.py30
-rw-r--r--test/orm/test_mapper.py6
-rw-r--r--test/orm/test_options.py4
-rw-r--r--test/orm/test_query.py37
-rw-r--r--test/orm/test_rel_fn.py2
-rw-r--r--test/orm/test_relationships.py6
-rw-r--r--test/orm/test_selectin_relations.py50
-rw-r--r--test/orm/test_session.py20
-rw-r--r--test/orm/test_subquery_relations.py50
-rw-r--r--test/orm/test_transaction.py681
-rw-r--r--test/orm/test_unitofwork.py8
-rw-r--r--test/orm/test_unitofworkv2.py3
-rw-r--r--test/requirements.py16
-rw-r--r--test/sql/test_case_statement.py4
-rw-r--r--test/sql/test_compare.py2
-rw-r--r--test/sql/test_compiler.py2
-rw-r--r--test/sql/test_defaults.py5
-rw-r--r--test/sql/test_deprecations.py6
-rw-r--r--test/sql/test_external_traversal.py14
-rw-r--r--test/sql/test_from_linter.py8
-rw-r--r--test/sql/test_functions.py10
-rw-r--r--test/sql/test_metadata.py2
-rw-r--r--test/sql/test_operators.py8
-rw-r--r--test/sql/test_resultset.py15
-rw-r--r--test/sql/test_sequences.py4
-rw-r--r--test/sql/test_types.py11
86 files changed, 1780 insertions, 1748 deletions
diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py
index 0202768ae..968a74700 100644
--- a/test/aaa_profiling/test_compiler.py
+++ b/test/aaa_profiling/test_compiler.py
@@ -18,7 +18,7 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2, metadata
metadata = MetaData()
diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py
index 75a4f51cf..a41a8b9f1 100644
--- a/test/aaa_profiling/test_memusage.py
+++ b/test/aaa_profiling/test_memusage.py
@@ -241,7 +241,7 @@ def assert_no_mappers():
class EnsureZeroed(fixtures.ORMTest):
- def setup(self):
+ def setup_test(self):
_sessions.clear()
_mapper_registry.clear()
@@ -1032,7 +1032,7 @@ class MemUsageWBackendTest(EnsureZeroed):
t2_mapper = mapper(T2, t2)
t1_mapper.add_property("bar", relationship(t2_mapper))
- s1 = fixture_session()
+ s1 = Session(testing.db)
# this causes the path_registry to be invoked
s1.query(t1_mapper)._compile_context()
diff --git a/test/aaa_profiling/test_misc.py b/test/aaa_profiling/test_misc.py
index db6fd4b71..5b30a3968 100644
--- a/test/aaa_profiling/test_misc.py
+++ b/test/aaa_profiling/test_misc.py
@@ -19,7 +19,7 @@ from sqlalchemy.util import classproperty
class EnumTest(fixtures.TestBase):
__requires__ = ("cpython", "python_profiling_backend")
- def setup(self):
+ def setup_test(self):
class SomeEnum(object):
# Implements PEP 435 in the minimal fashion needed by SQLAlchemy
diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py
index f163078d8..8116e5f21 100644
--- a/test/aaa_profiling/test_orm.py
+++ b/test/aaa_profiling/test_orm.py
@@ -29,15 +29,13 @@ class NoCache(object):
run_setup_bind = "each"
@classmethod
- def setup_class(cls):
- super(NoCache, cls).setup_class()
+ def setup_test_class(cls):
cls._cache = config.db._compiled_cache
config.db._compiled_cache = None
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
config.db._compiled_cache = cls._cache
- super(NoCache, cls).teardown_class()
class MergeTest(NoCache, fixtures.MappedTest):
diff --git a/test/aaa_profiling/test_pool.py b/test/aaa_profiling/test_pool.py
index fd02f9139..da3c1c525 100644
--- a/test/aaa_profiling/test_pool.py
+++ b/test/aaa_profiling/test_pool.py
@@ -17,7 +17,7 @@ class QueuePoolTest(fixtures.TestBase, AssertsExecutionResults):
def close(self):
pass
- def setup(self):
+ def setup_test(self):
# create a throwaway pool which
# has the effect of initializing
# class-level event listeners on Pool,
diff --git a/test/base/test_events.py b/test/base/test_events.py
index 19f68e9a3..68db5207c 100644
--- a/test/base/test_events.py
+++ b/test/base/test_events.py
@@ -16,7 +16,7 @@ from sqlalchemy.testing.util import gc_collect
class TearDownLocalEventsFixture(object):
- def tearDown(self):
+ def teardown_test(self):
classes = set()
for entry in event.base._registrars.values():
for evt_cls in entry:
@@ -30,7 +30,7 @@ class TearDownLocalEventsFixture(object):
class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test class- and instance-level event registration."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, x, y):
pass
@@ -438,7 +438,7 @@ class NamedCallTest(TearDownLocalEventsFixture, fixtures.TestBase):
class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""test adaption of legacy args"""
- def setUp(self):
+ def setup_test(self):
class TargetEventsOne(event.Events):
@event._legacy_signature("0.9", ["x", "y"])
def event_three(self, x, y, z, q):
@@ -608,7 +608,7 @@ class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEventsOne(event.Events):
def event_one(self, x, y):
pass
@@ -677,7 +677,7 @@ class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test default target acceptance."""
- def setUp(self):
+ def setup_test(self):
class TargetEventsOne(event.Events):
def event_one(self, x, y):
pass
@@ -734,7 +734,7 @@ class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test custom target acceptance."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
@classmethod
def _accept_with(cls, target):
@@ -771,7 +771,7 @@ class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
class SubclassGrowthTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""test that ad-hoc subclasses are garbage collected."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def some_event(self, x, y):
pass
@@ -797,7 +797,7 @@ class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test custom listen functions which change the listener function
signature."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
@classmethod
def _listen(cls, event_key, add=False):
@@ -855,7 +855,7 @@ class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase):
class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, arg):
pass
@@ -889,7 +889,7 @@ class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, target, arg):
pass
@@ -1109,7 +1109,7 @@ class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
class DisableClsPropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, target, arg):
pass
diff --git a/test/base/test_inspect.py b/test/base/test_inspect.py
index 15b98c848..252d0d977 100644
--- a/test/base/test_inspect.py
+++ b/test/base/test_inspect.py
@@ -13,7 +13,7 @@ class TestFixture(object):
class TestInspection(fixtures.TestBase):
- def tearDown(self):
+ def teardown_test(self):
for type_ in list(inspection._registrars):
if issubclass(type_, TestFixture):
del inspection._registrars[type_]
diff --git a/test/base/test_tutorials.py b/test/base/test_tutorials.py
index 14e87ef69..6320ef052 100644
--- a/test/base/test_tutorials.py
+++ b/test/base/test_tutorials.py
@@ -48,11 +48,11 @@ class DocTest(fixtures.TestBase):
ddl.sort_tables_and_constraints = self.orig_sort
- def setup(self):
+ def setup_test(self):
self._setup_logger()
self._setup_create_table_patcher()
- def teardown(self):
+ def teardown_test(self):
self._teardown_create_table_patcher()
self._teardown_logger()
diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py
index 8119612e1..f0bb66aa9 100644
--- a/test/dialect/mssql/test_compiler.py
+++ b/test/dialect/mssql/test_compiler.py
@@ -1814,7 +1814,7 @@ class CompileIdentityTest(fixtures.TestBase, AssertsCompiledSQL):
class SchemaTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
t = Table(
"sometable",
MetaData(),
diff --git a/test/dialect/mssql/test_deprecations.py b/test/dialect/mssql/test_deprecations.py
index c869182c5..27709beb0 100644
--- a/test/dialect/mssql/test_deprecations.py
+++ b/test/dialect/mssql/test_deprecations.py
@@ -31,7 +31,7 @@ class LegacySchemaAliasingTest(fixtures.TestBase, AssertsCompiledSQL):
"""
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.t1 = table(
"t1",
diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py
index cdb37cc61..b806b9247 100644
--- a/test/dialect/mssql/test_query.py
+++ b/test/dialect/mssql/test_query.py
@@ -455,7 +455,7 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
return testing.db.execution_options(isolation_level="AUTOCOMMIT")
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.connect().execution_options(
isolation_level="AUTOCOMMIT"
) as conn:
@@ -463,7 +463,6 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
conn.exec_driver_sql("DROP FULLTEXT CATALOG Catalog")
except:
pass
- super(MatchTest, cls).setup_class()
@classmethod
def insert_data(cls, connection):
diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py
index 62292b9da..7fd24e8b5 100644
--- a/test/dialect/mysql/test_compiler.py
+++ b/test/dialect/mysql/test_compiler.py
@@ -991,7 +991,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = mysql.dialect()
- def setup(self):
+ def setup_test(self):
self.table = Table(
"foos",
MetaData(),
@@ -1062,7 +1062,7 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
class RegexpCommon(testing.AssertsCompiledSQL):
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py
index 40617e59c..795b2cbd3 100644
--- a/test/dialect/mysql/test_reflection.py
+++ b/test/dialect/mysql/test_reflection.py
@@ -1115,7 +1115,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
class RawReflectionTest(fixtures.TestBase):
__backend__ = True
- def setup(self):
+ def setup_test(self):
dialect = mysql.dialect()
self.parser = _reflection.MySQLTableDefinitionParser(
dialect, dialect.identifier_preparer
diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py
index 1b8b3fb89..f09346eb3 100644
--- a/test/dialect/oracle/test_compiler.py
+++ b/test/dialect/oracle/test_compiler.py
@@ -1355,7 +1355,7 @@ class SequenceTest(fixtures.TestBase, AssertsCompiledSQL):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "oracle"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py
index df87fe89f..32234bf65 100644
--- a/test/dialect/oracle/test_dialect.py
+++ b/test/dialect/oracle/test_dialect.py
@@ -439,7 +439,7 @@ class OutParamTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.begin() as c:
c.exec_driver_sql(
"""
@@ -471,7 +471,7 @@ end;
assert isinstance(result.out_parameters["x_out"], int)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
conn.execute(text("DROP PROCEDURE foo"))
diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py
index 81e4e4ab5..0df4236e2 100644
--- a/test/dialect/oracle/test_reflection.py
+++ b/test/dialect/oracle/test_reflection.py
@@ -39,7 +39,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
# currently assuming full DBA privs for the user.
# don't really know how else to go here unless
# we connect as the other user.
@@ -85,7 +85,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
conn.exec_driver_sql(stmt)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
for stmt in (
"""
@@ -379,7 +379,7 @@ class SystemTableTablenamesTest(fixtures.TestBase):
__only_on__ = "oracle"
__backend__ = True
- def setup(self):
+ def setup_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("create table my_table (id integer)")
conn.exec_driver_sql(
@@ -389,7 +389,7 @@ class SystemTableTablenamesTest(fixtures.TestBase):
"create table foo_table (id integer) tablespace SYSTEM"
)
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop table my_temp_table")
conn.exec_driver_sql("drop table my_table")
@@ -421,7 +421,7 @@ class DontReflectIOTTest(fixtures.TestBase):
__only_on__ = "oracle"
__backend__ = True
- def setup(self):
+ def setup_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql(
"""
@@ -438,7 +438,7 @@ class DontReflectIOTTest(fixtures.TestBase):
""",
)
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop table admin_docindex")
@@ -715,7 +715,7 @@ class DBLinkReflectionTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy.testing import config
cls.dblink = config.file_config.get("sqla_testing", "oracle_db_link")
@@ -734,7 +734,7 @@ class DBLinkReflectionTest(fixtures.TestBase):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop synonym test_table_syn")
conn.exec_driver_sql("drop table test_table")
diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py
index f008ea019..8ea7c0e04 100644
--- a/test/dialect/oracle/test_types.py
+++ b/test/dialect/oracle/test_types.py
@@ -1011,7 +1011,7 @@ class EuroNumericTest(fixtures.TestBase):
__only_on__ = "oracle+cx_oracle"
__backend__ = True
- def setup(self):
+ def setup_test(self):
connect = testing.db.pool._creator
def _creator():
@@ -1023,7 +1023,7 @@ class EuroNumericTest(fixtures.TestBase):
self.engine = testing_engine(options={"creator": _creator})
- def teardown(self):
+ def teardown_test(self):
self.engine.dispose()
def test_were_getting_a_comma(self):
diff --git a/test/dialect/postgresql/test_async_pg_py3k.py b/test/dialect/postgresql/test_async_pg_py3k.py
index fadf939b8..f6d48f3c6 100644
--- a/test/dialect/postgresql/test_async_pg_py3k.py
+++ b/test/dialect/postgresql/test_async_pg_py3k.py
@@ -27,7 +27,7 @@ class AsyncPgTest(fixtures.TestBase):
# TODO: remove when Iae6ab95938a7e92b6d42086aec534af27b5577d3
# merges
- from sqlalchemy.testing import engines
+ from sqlalchemy.testing import util as testing_util
from sqlalchemy.sql import schema
metadata = schema.MetaData()
@@ -35,7 +35,7 @@ class AsyncPgTest(fixtures.TestBase):
try:
yield metadata
finally:
- engines.drop_all_tables(metadata, testing.db)
+ testing_util.drop_all_tables_from_metadata(metadata, testing.db)
@async_test
async def test_detect_stale_ddl_cache_raise_recover(
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 1763b210b..b3a0b9bbd 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -1810,7 +1810,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = postgresql.dialect()
- def setup(self):
+ def setup_test(self):
self.table1 = table1 = table(
"mytable",
column("myid", Integer),
@@ -2222,7 +2222,7 @@ class DistinctOnTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = postgresql.dialect()
- def setup(self):
+ def setup_test(self):
self.table = Table(
"t",
MetaData(),
@@ -2373,7 +2373,7 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = postgresql.dialect()
- def setup(self):
+ def setup_test(self):
self.table = Table(
"t",
MetaData(),
@@ -2464,7 +2464,7 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "postgresql"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py
index f760a309b..9c9d817bb 100644
--- a/test/dialect/postgresql/test_dialect.py
+++ b/test/dialect/postgresql/test_dialect.py
@@ -198,17 +198,17 @@ class ExecuteManyMode(object):
@config.fixture()
def connection(self):
- eng = engines.testing_engine(options=self.options)
+ opts = dict(self.options)
+ opts["use_reaper"] = False
+ eng = engines.testing_engine(options=opts)
conn = eng.connect()
trans = conn.begin()
- try:
- yield conn
- finally:
- if trans.is_active:
- trans.rollback()
- conn.close()
- eng.dispose()
+ yield conn
+ if trans.is_active:
+ trans.rollback()
+ conn.close()
+ eng.dispose()
@classmethod
def define_tables(cls, metadata):
@@ -510,8 +510,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
# assert result.closed
assert result.cursor is None
- @testing.provide_metadata
- def test_insert_returning_preexecute_pk(self, connection):
+ def test_insert_returning_preexecute_pk(self, metadata, connection):
counter = itertools.count(1)
t = Table(
@@ -525,7 +524,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
),
Column("data", Integer),
)
- self.metadata.create_all(connection)
+ metadata.create_all(connection)
result = connection.execute(
t.insert().return_defaults(),
diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py
index 94af168ee..c51fd1943 100644
--- a/test/dialect/postgresql/test_query.py
+++ b/test/dialect/postgresql/test_query.py
@@ -40,10 +40,10 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql"
__backend__ = True
- def setup(self):
+ def setup_test(self):
self.metadata = MetaData()
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
self.metadata.drop_all(conn)
@@ -890,7 +890,7 @@ class ExtractTest(fixtures.TablesTest):
def setup_bind(cls):
from sqlalchemy import event
- eng = engines.testing_engine()
+ eng = engines.testing_engine(options={"scope": "class"})
@event.listens_for(eng, "connect")
def connect(dbapi_conn, rec):
diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py
index 754eff25a..6586a8308 100644
--- a/test/dialect/postgresql/test_reflection.py
+++ b/test/dialect/postgresql/test_reflection.py
@@ -80,26 +80,24 @@ class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
]:
sa.event.listen(metadata, "before_drop", sa.DDL(ddl))
- def test_foreign_table_is_reflected(self):
+ def test_foreign_table_is_reflected(self, connection):
metadata = MetaData()
- table = Table("test_foreigntable", metadata, autoload_with=testing.db)
+ table = Table("test_foreigntable", metadata, autoload_with=connection)
eq_(
set(table.columns.keys()),
set(["id", "data"]),
"Columns of reflected foreign table didn't equal expected columns",
)
- def test_get_foreign_table_names(self):
- inspector = inspect(testing.db)
- with testing.db.connect():
- ft_names = inspector.get_foreign_table_names()
- eq_(ft_names, ["test_foreigntable"])
+ def test_get_foreign_table_names(self, connection):
+ inspector = inspect(connection)
+ ft_names = inspector.get_foreign_table_names()
+ eq_(ft_names, ["test_foreigntable"])
- def test_get_table_names_no_foreign(self):
- inspector = inspect(testing.db)
- with testing.db.connect():
- names = inspector.get_table_names()
- eq_(names, ["testtable"])
+ def test_get_table_names_no_foreign(self, connection):
+ inspector = inspect(connection)
+ names = inspector.get_table_names()
+ eq_(names, ["testtable"])
class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
@@ -133,22 +131,22 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
if testing.against("postgresql >= 11"):
Index("my_index", dv.c.q)
- def test_get_tablenames(self):
+ def test_get_tablenames(self, connection):
assert {"data_values", "data_values_4_10"}.issubset(
- inspect(testing.db).get_table_names()
+ inspect(connection).get_table_names()
)
- def test_reflect_cols(self):
- cols = inspect(testing.db).get_columns("data_values")
+ def test_reflect_cols(self, connection):
+ cols = inspect(connection).get_columns("data_values")
eq_([c["name"] for c in cols], ["modulus", "data", "q"])
- def test_reflect_cols_from_partition(self):
- cols = inspect(testing.db).get_columns("data_values_4_10")
+ def test_reflect_cols_from_partition(self, connection):
+ cols = inspect(connection).get_columns("data_values_4_10")
eq_([c["name"] for c in cols], ["modulus", "data", "q"])
@testing.only_on("postgresql >= 11")
- def test_reflect_index(self):
- idx = inspect(testing.db).get_indexes("data_values")
+ def test_reflect_index(self, connection):
+ idx = inspect(connection).get_indexes("data_values")
eq_(
idx,
[
@@ -162,8 +160,8 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
)
@testing.only_on("postgresql >= 11")
- def test_reflect_index_from_partition(self):
- idx = inspect(testing.db).get_indexes("data_values_4_10")
+ def test_reflect_index_from_partition(self, connection):
+ idx = inspect(connection).get_indexes("data_values_4_10")
# note the name appears to be generated by PG, currently
# 'data_values_4_10_q_idx'
eq_(
@@ -220,44 +218,43 @@ class MaterializedViewReflectionTest(
testtable, "before_drop", sa.DDL("DROP VIEW test_regview")
)
- def test_mview_is_reflected(self):
+ def test_mview_is_reflected(self, connection):
metadata = MetaData()
- table = Table("test_mview", metadata, autoload_with=testing.db)
+ table = Table("test_mview", metadata, autoload_with=connection)
eq_(
set(table.columns.keys()),
set(["id", "data"]),
"Columns of reflected mview didn't equal expected columns",
)
- def test_mview_select(self):
+ def test_mview_select(self, connection):
metadata = MetaData()
- table = Table("test_mview", metadata, autoload_with=testing.db)
- with testing.db.connect() as conn:
- eq_(conn.execute(table.select()).fetchall(), [(89, "d1")])
+ table = Table("test_mview", metadata, autoload_with=connection)
+ eq_(connection.execute(table.select()).fetchall(), [(89, "d1")])
- def test_get_view_names(self):
- insp = inspect(testing.db)
+ def test_get_view_names(self, connection):
+ insp = inspect(connection)
eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
- def test_get_view_names_plain(self):
- insp = inspect(testing.db)
+ def test_get_view_names_plain(self, connection):
+ insp = inspect(connection)
eq_(
set(insp.get_view_names(include=("plain",))), set(["test_regview"])
)
- def test_get_view_names_plain_string(self):
- insp = inspect(testing.db)
+ def test_get_view_names_plain_string(self, connection):
+ insp = inspect(connection)
eq_(set(insp.get_view_names(include="plain")), set(["test_regview"]))
- def test_get_view_names_materialized(self):
- insp = inspect(testing.db)
+ def test_get_view_names_materialized(self, connection):
+ insp = inspect(connection)
eq_(
set(insp.get_view_names(include=("materialized",))),
set(["test_mview"]),
)
- def test_get_view_names_reflection_cache_ok(self):
- insp = inspect(testing.db)
+ def test_get_view_names_reflection_cache_ok(self, connection):
+ insp = inspect(connection)
eq_(
set(insp.get_view_names(include=("plain",))), set(["test_regview"])
)
@@ -267,12 +264,12 @@ class MaterializedViewReflectionTest(
)
eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
- def test_get_view_names_empty(self):
- insp = inspect(testing.db)
+ def test_get_view_names_empty(self, connection):
+ insp = inspect(connection)
assert_raises(ValueError, insp.get_view_names, include=())
- def test_get_view_definition(self):
- insp = inspect(testing.db)
+ def test_get_view_definition(self, connection):
+ insp = inspect(connection)
eq_(
re.sub(
r"[\n\t ]+",
@@ -290,7 +287,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.begin() as con:
for ddl in [
'CREATE SCHEMA "SomeSchema"',
@@ -334,7 +331,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as con:
con.exec_driver_sql("DROP TABLE testtable")
con.exec_driver_sql("DROP TABLE test_schema.testtable")
@@ -350,9 +347,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"')
con.exec_driver_sql('DROP SCHEMA "SomeSchema"')
- def test_table_is_reflected(self):
+ def test_table_is_reflected(self, connection):
metadata = MetaData()
- table = Table("testtable", metadata, autoload_with=testing.db)
+ table = Table("testtable", metadata, autoload_with=connection)
eq_(
set(table.columns.keys()),
set(["question", "answer"]),
@@ -360,9 +357,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
)
assert isinstance(table.c.answer.type, Integer)
- def test_domain_is_reflected(self):
+ def test_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("testtable", metadata, autoload_with=testing.db)
+ table = Table("testtable", metadata, autoload_with=connection)
eq_(
str(table.columns.answer.server_default.arg),
"42",
@@ -372,28 +369,28 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
not table.columns.answer.nullable
), "Expected reflected column to not be nullable."
- def test_enum_domain_is_reflected(self):
+ def test_enum_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("enum_test", metadata, autoload_with=testing.db)
+ table = Table("enum_test", metadata, autoload_with=connection)
eq_(table.c.data.type.enums, ["test"])
- def test_array_domain_is_reflected(self):
+ def test_array_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("array_test", metadata, autoload_with=testing.db)
+ table = Table("array_test", metadata, autoload_with=connection)
eq_(table.c.data.type.__class__, ARRAY)
eq_(table.c.data.type.item_type.__class__, INTEGER)
- def test_quoted_remote_schema_domain_is_reflected(self):
+ def test_quoted_remote_schema_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("quote_test", metadata, autoload_with=testing.db)
+ table = Table("quote_test", metadata, autoload_with=connection)
eq_(table.c.data.type.__class__, INTEGER)
- def test_table_is_reflected_test_schema(self):
+ def test_table_is_reflected_test_schema(self, connection):
metadata = MetaData()
table = Table(
"testtable",
metadata,
- autoload_with=testing.db,
+ autoload_with=connection,
schema="test_schema",
)
eq_(
@@ -403,12 +400,12 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
)
assert isinstance(table.c.anything.type, Integer)
- def test_schema_domain_is_reflected(self):
+ def test_schema_domain_is_reflected(self, connection):
metadata = MetaData()
table = Table(
"testtable",
metadata,
- autoload_with=testing.db,
+ autoload_with=connection,
schema="test_schema",
)
eq_(
@@ -420,9 +417,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
table.columns.answer.nullable
), "Expected reflected column to be nullable."
- def test_crosschema_domain_is_reflected(self):
+ def test_crosschema_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("crosschema", metadata, autoload_with=testing.db)
+ table = Table("crosschema", metadata, autoload_with=connection)
eq_(
str(table.columns.answer.server_default.arg),
"0",
@@ -432,7 +429,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
table.columns.answer.nullable
), "Expected reflected column to be nullable."
- def test_unknown_types(self):
+ def test_unknown_types(self, connection):
from sqlalchemy.dialects.postgresql import base
ischema_names = base.PGDialect.ischema_names
@@ -440,13 +437,13 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
try:
m2 = MetaData()
assert_raises(
- exc.SAWarning, Table, "testtable", m2, autoload_with=testing.db
+ exc.SAWarning, Table, "testtable", m2, autoload_with=connection
)
@testing.emits_warning("Did not recognize type")
def warns():
m3 = MetaData()
- t3 = Table("testtable", m3, autoload_with=testing.db)
+ t3 = Table("testtable", m3, autoload_with=connection)
assert t3.c.answer.type.__class__ == sa.types.NullType
finally:
@@ -471,9 +468,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
subject = Table("subject", meta2, autoload_with=connection)
eq_(subject.primary_key.columns.keys(), ["p2", "p1"])
- @testing.provide_metadata
- def test_pg_weirdchar_reflection(self):
- meta1 = self.metadata
+ def test_pg_weirdchar_reflection(self, metadata, connection):
+ meta1 = metadata
subject = Table(
"subject", meta1, Column("id$", Integer, primary_key=True)
)
@@ -483,101 +479,91 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("ref", Integer, ForeignKey("subject.id$")),
)
- meta1.create_all(testing.db)
+ meta1.create_all(connection)
meta2 = MetaData()
- subject = Table("subject", meta2, autoload_with=testing.db)
- referer = Table("referer", meta2, autoload_with=testing.db)
+ subject = Table("subject", meta2, autoload_with=connection)
+ referer = Table("referer", meta2, autoload_with=connection)
self.assert_(
(subject.c["id$"] == referer.c.ref).compare(
subject.join(referer).onclause
)
)
- @testing.provide_metadata
- def test_reflect_default_over_128_chars(self):
+ def test_reflect_default_over_128_chars(self, metadata, connection):
Table(
"t",
- self.metadata,
+ metadata,
Column("x", String(200), server_default="abcd" * 40),
- ).create(testing.db)
+ ).create(connection)
m = MetaData()
- t = Table("t", m, autoload_with=testing.db)
+ t = Table("t", m, autoload_with=connection)
eq_(
t.c.x.server_default.arg.text,
"'%s'::character varying" % ("abcd" * 40),
)
- @testing.fails_if("postgresql < 8.1", "schema name leaks in, not sure")
- @testing.provide_metadata
- def test_renamed_sequence_reflection(self):
- metadata = self.metadata
+ def test_renamed_sequence_reflection(self, metadata, connection):
Table("t", metadata, Column("id", Integer, primary_key=True))
- metadata.create_all(testing.db)
+ metadata.create_all(connection)
m2 = MetaData()
- t2 = Table("t", m2, autoload_with=testing.db, implicit_returning=False)
+ t2 = Table("t", m2, autoload_with=connection, implicit_returning=False)
eq_(t2.c.id.server_default.arg.text, "nextval('t_id_seq'::regclass)")
- with testing.db.begin() as conn:
- r = conn.execute(t2.insert())
- eq_(r.inserted_primary_key, (1,))
+ r = connection.execute(t2.insert())
+ eq_(r.inserted_primary_key, (1,))
- with testing.db.begin() as conn:
- conn.exec_driver_sql(
- "alter table t_id_seq rename to foobar_id_seq"
- )
+ connection.exec_driver_sql(
+ "alter table t_id_seq rename to foobar_id_seq"
+ )
m3 = MetaData()
- t3 = Table("t", m3, autoload_with=testing.db, implicit_returning=False)
+ t3 = Table("t", m3, autoload_with=connection, implicit_returning=False)
eq_(
t3.c.id.server_default.arg.text,
"nextval('foobar_id_seq'::regclass)",
)
- with testing.db.begin() as conn:
- r = conn.execute(t3.insert())
- eq_(r.inserted_primary_key, (2,))
+ r = connection.execute(t3.insert())
+ eq_(r.inserted_primary_key, (2,))
- @testing.provide_metadata
- def test_altered_type_autoincrement_pk_reflection(self):
- metadata = self.metadata
+ def test_altered_type_autoincrement_pk_reflection(
+ self, metadata, connection
+ ):
+ metadata = metadata
Table(
"t",
metadata,
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all(testing.db)
+ metadata.create_all(connection)
- with testing.db.begin() as conn:
- conn.exec_driver_sql(
- "alter table t alter column id type varchar(50)"
- )
+ connection.exec_driver_sql(
+ "alter table t alter column id type varchar(50)"
+ )
m2 = MetaData()
- t2 = Table("t", m2, autoload_with=testing.db)
+ t2 = Table("t", m2, autoload_with=connection)
eq_(t2.c.id.autoincrement, False)
eq_(t2.c.x.autoincrement, False)
- @testing.provide_metadata
- def test_renamed_pk_reflection(self):
- metadata = self.metadata
+ def test_renamed_pk_reflection(self, metadata, connection):
+ metadata = metadata
Table("t", metadata, Column("id", Integer, primary_key=True))
- metadata.create_all(testing.db)
- with testing.db.begin() as conn:
- conn.exec_driver_sql("alter table t rename id to t_id")
+ metadata.create_all(connection)
+ connection.exec_driver_sql("alter table t rename id to t_id")
m2 = MetaData()
- t2 = Table("t", m2, autoload_with=testing.db)
+ t2 = Table("t", m2, autoload_with=connection)
eq_([c.name for c in t2.primary_key], ["t_id"])
- @testing.provide_metadata
- def test_has_temporary_table(self):
- assert not inspect(testing.db).has_table("some_temp_table")
+ def test_has_temporary_table(self, metadata, connection):
+ assert not inspect(connection).has_table("some_temp_table")
user_tmp = Table(
"some_temp_table",
- self.metadata,
+ metadata,
Column("id", Integer, primary_key=True),
Column("name", String(50)),
prefixes=["TEMPORARY"],
)
- user_tmp.create(testing.db)
- assert inspect(testing.db).has_table("some_temp_table")
+ user_tmp.create(connection)
+ assert inspect(connection).has_table("some_temp_table")
def test_cross_schema_reflection_one(self, metadata, connection):
@@ -898,19 +884,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
A_table.create(connection, checkfirst=True)
assert inspect(connection).has_table("A")
- def test_uppercase_lowercase_sequence(self):
+ def test_uppercase_lowercase_sequence(self, connection):
a_seq = Sequence("a")
A_seq = Sequence("A")
- a_seq.create(testing.db)
- assert testing.db.dialect.has_sequence(testing.db, "a")
- assert not testing.db.dialect.has_sequence(testing.db, "A")
- A_seq.create(testing.db, checkfirst=True)
- assert testing.db.dialect.has_sequence(testing.db, "A")
+ a_seq.create(connection)
+ assert connection.dialect.has_sequence(connection, "a")
+ assert not connection.dialect.has_sequence(connection, "A")
+ A_seq.create(connection, checkfirst=True)
+ assert connection.dialect.has_sequence(connection, "A")
- a_seq.drop(testing.db)
- A_seq.drop(testing.db)
+ a_seq.drop(connection)
+ A_seq.drop(connection)
def test_index_reflection(self, metadata, connection):
"""Reflecting expression-based indexes should warn"""
@@ -960,11 +946,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_index_reflection_partial(self, connection):
+ def test_index_reflection_partial(self, metadata, connection):
"""Reflect the filter defintion on partial indexes"""
- metadata = self.metadata
+ metadata = metadata
t1 = Table(
"table1",
@@ -978,7 +963,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
metadata.create_all(connection)
- ind = testing.db.dialect.get_indexes(connection, t1, None)
+ ind = connection.dialect.get_indexes(connection, t1, None)
partial_definitions = []
for ix in ind:
@@ -1073,15 +1058,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
compile_exprs(r3.expressions),
)
- @testing.provide_metadata
- def test_index_reflection_modified(self):
+ def test_index_reflection_modified(self, metadata, connection):
"""reflect indexes when a column name has changed - PG 9
does not update the name of the column in the index def.
[ticket:2141]
"""
- metadata = self.metadata
+ metadata = metadata
Table(
"t",
@@ -1089,26 +1073,21 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all(testing.db)
- with testing.db.begin() as conn:
- conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
- conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
+ metadata.create_all(connection)
+ connection.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
+ connection.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
- ind = testing.db.dialect.get_indexes(conn, "t", None)
- expected = [
- {"name": "idx1", "unique": False, "column_names": ["y"]}
- ]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
+ ind = connection.dialect.get_indexes(connection, "t", None)
+ expected = [{"name": "idx1", "unique": False, "column_names": ["y"]}]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
- eq_(ind, expected)
+ eq_(ind, expected)
- @testing.fails_if("postgresql < 8.2", "reloptions not supported")
- @testing.provide_metadata
- def test_index_reflection_with_storage_options(self):
+ def test_index_reflection_with_storage_options(self, metadata, connection):
"""reflect indexes with storage options set"""
- metadata = self.metadata
+ metadata = metadata
Table(
"t",
@@ -1116,70 +1095,63 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all(testing.db)
+ metadata.create_all(connection)
- with testing.db.begin() as conn:
- conn.exec_driver_sql(
- "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)"
- )
+ connection.exec_driver_sql(
+ "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)"
+ )
- ind = testing.db.dialect.get_indexes(conn, "t", None)
+ ind = testing.db.dialect.get_indexes(connection, "t", None)
- expected = [
- {
- "unique": False,
- "column_names": ["x"],
- "name": "idx1",
- "dialect_options": {
- "postgresql_with": {"fillfactor": "50"}
- },
- }
- ]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
- eq_(ind, expected)
+ expected = [
+ {
+ "unique": False,
+ "column_names": ["x"],
+ "name": "idx1",
+ "dialect_options": {"postgresql_with": {"fillfactor": "50"}},
+ }
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ eq_(ind, expected)
- m = MetaData()
- t1 = Table("t", m, autoload_with=conn)
- eq_(
- list(t1.indexes)[0].dialect_options["postgresql"]["with"],
- {"fillfactor": "50"},
- )
+ m = MetaData()
+ t1 = Table("t", m, autoload_with=connection)
+ eq_(
+ list(t1.indexes)[0].dialect_options["postgresql"]["with"],
+ {"fillfactor": "50"},
+ )
- @testing.provide_metadata
- def test_index_reflection_with_access_method(self):
+ def test_index_reflection_with_access_method(self, metadata, connection):
"""reflect indexes with storage options set"""
- metadata = self.metadata
-
Table(
"t",
metadata,
Column("id", Integer, primary_key=True),
Column("x", ARRAY(Integer)),
)
- metadata.create_all(testing.db)
- with testing.db.begin() as conn:
- conn.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)")
+ metadata.create_all(connection)
+ connection.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)")
- ind = testing.db.dialect.get_indexes(conn, "t", None)
- expected = [
- {
- "unique": False,
- "column_names": ["x"],
- "name": "idx1",
- "dialect_options": {"postgresql_using": "gin"},
- }
- ]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
- eq_(ind, expected)
- m = MetaData()
- t1 = Table("t", m, autoload_with=conn)
- eq_(
- list(t1.indexes)[0].dialect_options["postgresql"]["using"],
- "gin",
- )
+ ind = testing.db.dialect.get_indexes(connection, "t", None)
+ expected = [
+ {
+ "unique": False,
+ "column_names": ["x"],
+ "name": "idx1",
+ "dialect_options": {"postgresql_using": "gin"},
+ }
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ eq_(ind, expected)
+ m = MetaData()
+ t1 = Table("t", m, autoload_with=connection)
+ eq_(
+ list(t1.indexes)[0].dialect_options["postgresql"]["using"],
+ "gin",
+ )
@testing.skip_if("postgresql < 11.0", "indnkeyatts not supported")
def test_index_reflection_with_include(self, metadata, connection):
@@ -1199,7 +1171,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
# [{'column_names': ['x', 'name'],
# 'name': 'idx1', 'unique': False}]
- ind = testing.db.dialect.get_indexes(connection, "t", None)
+ ind = connection.dialect.get_indexes(connection, "t", None)
eq_(
ind,
[
@@ -1286,15 +1258,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
for fk in fks:
eq_(fk, fk_ref[fk["name"]])
- @testing.provide_metadata
- def test_inspect_enums_schema(self, connection):
+ def test_inspect_enums_schema(self, metadata, connection):
enum_type = postgresql.ENUM(
"sad",
"ok",
"happy",
name="mood",
schema="test_schema",
- metadata=self.metadata,
+ metadata=metadata,
)
enum_type.create(connection)
inspector = inspect(connection)
@@ -1310,13 +1281,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enums(self):
+ def test_inspect_enums(self, metadata, connection):
enum_type = postgresql.ENUM(
- "cat", "dog", "rat", name="pet", metadata=self.metadata
+ "cat", "dog", "rat", name="pet", metadata=metadata
)
- enum_type.create(testing.db)
- inspector = inspect(testing.db)
+ enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums(),
[
@@ -1329,17 +1299,16 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enums_case_sensitive(self):
+ def test_inspect_enums_case_sensitive(self, metadata, connection):
sa.event.listen(
- self.metadata,
+ metadata,
"before_create",
sa.DDL('create schema "TestSchema"'),
)
sa.event.listen(
- self.metadata,
+ metadata,
"after_drop",
- sa.DDL('drop schema "TestSchema" cascade'),
+ sa.DDL('drop schema if exists "TestSchema" cascade'),
)
for enum in "lower_case", "UpperCase", "Name.With.Dot":
@@ -1350,11 +1319,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
"CapsTwo",
name=enum,
schema=schema,
- metadata=self.metadata,
+ metadata=metadata,
)
- self.metadata.create_all(testing.db)
- inspector = inspect(testing.db)
+ metadata.create_all(connection)
+ inspector = inspect(connection)
for schema in None, "test_schema", "TestSchema":
eq_(
sorted(
@@ -1382,17 +1351,18 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enums_case_sensitive_from_table(self):
+ def test_inspect_enums_case_sensitive_from_table(
+ self, metadata, connection
+ ):
sa.event.listen(
- self.metadata,
+ metadata,
"before_create",
sa.DDL('create schema "TestSchema"'),
)
sa.event.listen(
- self.metadata,
+ metadata,
"after_drop",
- sa.DDL('drop schema "TestSchema" cascade'),
+ sa.DDL('drop schema if exists "TestSchema" cascade'),
)
counter = itertools.count()
@@ -1403,19 +1373,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
"CapsOne",
"CapsTwo",
name=enum,
- metadata=self.metadata,
+ metadata=metadata,
schema=schema,
)
Table(
"t%d" % next(counter),
- self.metadata,
+ metadata,
Column("q", enum_type),
)
- self.metadata.create_all(testing.db)
+ metadata.create_all(connection)
- inspector = inspect(testing.db)
+ inspector = inspect(connection)
counter = itertools.count()
for enum in "lower_case", "UpperCase", "Name.With.Dot":
for schema in None, "test_schema", "TestSchema":
@@ -1439,10 +1409,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enums_star(self):
+ def test_inspect_enums_star(self, metadata, connection):
enum_type = postgresql.ENUM(
- "cat", "dog", "rat", name="pet", metadata=self.metadata
+ "cat", "dog", "rat", name="pet", metadata=metadata
)
schema_enum_type = postgresql.ENUM(
"sad",
@@ -1450,11 +1419,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
"happy",
name="mood",
schema="test_schema",
- metadata=self.metadata,
+ metadata=metadata,
)
- enum_type.create(testing.db)
- schema_enum_type.create(testing.db)
- inspector = inspect(testing.db)
+ enum_type.create(connection)
+ schema_enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums(),
@@ -1486,11 +1455,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enum_empty(self):
- enum_type = postgresql.ENUM(name="empty", metadata=self.metadata)
- enum_type.create(testing.db)
- inspector = inspect(testing.db)
+ def test_inspect_enum_empty(self, metadata, connection):
+ enum_type = postgresql.ENUM(name="empty", metadata=metadata)
+ enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums(),
@@ -1504,13 +1472,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enum_empty_from_table(self):
+ def test_inspect_enum_empty_from_table(self, metadata, connection):
Table(
- "t", self.metadata, Column("x", postgresql.ENUM(name="empty"))
- ).create(testing.db)
+ "t", metadata, Column("x", postgresql.ENUM(name="empty"))
+ ).create(connection)
- t = Table("t", MetaData(), autoload_with=testing.db)
+ t = Table("t", MetaData(), autoload_with=connection)
eq_(t.c.x.type.enums, [])
def test_reflection_with_unique_constraint(self, metadata, connection):
@@ -1749,12 +1716,12 @@ class CustomTypeReflectionTest(fixtures.TestBase):
ischema_names = None
- def setup(self):
+ def setup_test(self):
ischema_names = postgresql.PGDialect.ischema_names
postgresql.PGDialect.ischema_names = ischema_names.copy()
self.ischema_names = ischema_names
- def teardown(self):
+ def teardown_test(self):
postgresql.PGDialect.ischema_names = self.ischema_names
self.ischema_names = None
@@ -1788,55 +1755,51 @@ class IntervalReflectionTest(fixtures.TestBase):
__only_on__ = "postgresql"
__backend__ = True
- def test_interval_types(self):
- for sym in [
- "YEAR",
- "MONTH",
- "DAY",
- "HOUR",
- "MINUTE",
- "SECOND",
- "YEAR TO MONTH",
- "DAY TO HOUR",
- "DAY TO MINUTE",
- "DAY TO SECOND",
- "HOUR TO MINUTE",
- "HOUR TO SECOND",
- "MINUTE TO SECOND",
- ]:
- self._test_interval_symbol(sym)
-
- @testing.provide_metadata
- def _test_interval_symbol(self, sym):
+ @testing.combinations(
+ ("YEAR",),
+ ("MONTH",),
+ ("DAY",),
+ ("HOUR",),
+ ("MINUTE",),
+ ("SECOND",),
+ ("YEAR TO MONTH",),
+ ("DAY TO HOUR",),
+ ("DAY TO MINUTE",),
+ ("DAY TO SECOND",),
+ ("HOUR TO MINUTE",),
+ ("HOUR TO SECOND",),
+ ("MINUTE TO SECOND",),
+ argnames="sym",
+ )
+ def test_interval_types(self, sym, metadata, connection):
t = Table(
"i_test",
- self.metadata,
+ metadata,
Column("id", Integer, primary_key=True),
Column("data1", INTERVAL(fields=sym)),
)
- t.create(testing.db)
+ t.create(connection)
columns = {
rec["name"]: rec
- for rec in inspect(testing.db).get_columns("i_test")
+ for rec in inspect(connection).get_columns("i_test")
}
assert isinstance(columns["data1"]["type"], INTERVAL)
eq_(columns["data1"]["type"].fields, sym.lower())
eq_(columns["data1"]["type"].precision, None)
- @testing.provide_metadata
- def test_interval_precision(self):
+ def test_interval_precision(self, metadata, connection):
t = Table(
"i_test",
- self.metadata,
+ metadata,
Column("id", Integer, primary_key=True),
Column("data1", INTERVAL(precision=6)),
)
- t.create(testing.db)
+ t.create(connection)
columns = {
rec["name"]: rec
- for rec in inspect(testing.db).get_columns("i_test")
+ for rec in inspect(connection).get_columns("i_test")
}
assert isinstance(columns["data1"]["type"], INTERVAL)
eq_(columns["data1"]["type"].fields, None)
@@ -1871,8 +1834,8 @@ class IdentityReflectionTest(fixtures.TablesTest):
Column("id4", SmallInteger, Identity()),
)
- def test_reflect_identity(self):
- insp = inspect(testing.db)
+ def test_reflect_identity(self, connection):
+ insp = inspect(connection)
default = dict(
always=False,
start=1,
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index e8a1876c7..6202f8f86 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -49,7 +49,6 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.sql import operators
from sqlalchemy.sql import sqltypes
-from sqlalchemy.testing import engines
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import assert_raises_message
@@ -156,8 +155,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql > 8.3"
- @testing.provide_metadata
- def test_create_table(self, connection):
+ def test_create_table(self, metadata, connection):
metadata = self.metadata
t1 = Table(
"table",
@@ -177,8 +175,8 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
[(1, "two"), (2, "three"), (3, "three")],
)
- @testing.combinations(None, "foo")
- def test_create_table_schema_translate_map(self, symbol_name):
+ @testing.combinations(None, "foo", argnames="symbol_name")
+ def test_create_table_schema_translate_map(self, connection, symbol_name):
# note we can't use the fixture here because it will not drop
# from the correct schema
metadata = MetaData()
@@ -199,35 +197,30 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
),
schema=symbol_name,
)
- with testing.db.begin() as conn:
- conn = conn.execution_options(
- schema_translate_map={symbol_name: testing.config.test_schema}
- )
- t1.create(conn)
- assert "schema_enum" in [
- e["name"]
- for e in inspect(conn).get_enums(
- schema=testing.config.test_schema
- )
- ]
- t1.create(conn, checkfirst=True)
+ conn = connection.execution_options(
+ schema_translate_map={symbol_name: testing.config.test_schema}
+ )
+ t1.create(conn)
+ assert "schema_enum" in [
+ e["name"]
+ for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+ ]
+ t1.create(conn, checkfirst=True)
- conn.execute(t1.insert(), value="two")
- conn.execute(t1.insert(), value="three")
- conn.execute(t1.insert(), value="three")
- eq_(
- conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
- [(1, "two"), (2, "three"), (3, "three")],
- )
+ conn.execute(t1.insert(), value="two")
+ conn.execute(t1.insert(), value="three")
+ conn.execute(t1.insert(), value="three")
+ eq_(
+ conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
+ [(1, "two"), (2, "three"), (3, "three")],
+ )
- t1.drop(conn)
- assert "schema_enum" not in [
- e["name"]
- for e in inspect(conn).get_enums(
- schema=testing.config.test_schema
- )
- ]
- t1.drop(conn, checkfirst=True)
+ t1.drop(conn)
+ assert "schema_enum" not in [
+ e["name"]
+ for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+ ]
+ t1.drop(conn, checkfirst=True)
def test_name_required(self, metadata, connection):
etype = Enum("four", "five", "six", metadata=metadata)
@@ -270,8 +263,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
[util.u("réveillé"), util.u("drôle"), util.u("S’il")],
)
- @testing.provide_metadata
- def test_non_native_enum(self, connection):
+ def test_non_native_enum(self, metadata, connection):
metadata = self.metadata
t1 = Table(
"foo",
@@ -290,10 +282,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
)
def go():
- t1.create(testing.db)
+ t1.create(connection)
self.assert_sql(
- testing.db,
+ connection,
go,
[
(
@@ -307,8 +299,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
connection.execute(t1.insert(), {"bar": "two"})
eq_(connection.scalar(select(t1.c.bar)), "two")
- @testing.provide_metadata
- def test_non_native_enum_w_unicode(self, connection):
+ def test_non_native_enum_w_unicode(self, metadata, connection):
metadata = self.metadata
t1 = Table(
"foo",
@@ -326,10 +317,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
)
def go():
- t1.create(testing.db)
+ t1.create(connection)
self.assert_sql(
- testing.db,
+ connection,
go,
[
(
@@ -346,8 +337,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
connection.execute(t1.insert(), {"bar": util.u("Ü")})
eq_(connection.scalar(select(t1.c.bar)), util.u("Ü"))
- @testing.provide_metadata
- def test_disable_create(self):
+ def test_disable_create(self, metadata, connection):
metadata = self.metadata
e1 = postgresql.ENUM(
@@ -357,13 +347,12 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
t1 = Table("e1", metadata, Column("c1", e1))
# table can be created separately
# without conflict
- e1.create(bind=testing.db)
- t1.create(testing.db)
- t1.drop(testing.db)
- e1.drop(bind=testing.db)
+ e1.create(bind=connection)
+ t1.create(connection)
+ t1.drop(connection)
+ e1.drop(bind=connection)
- @testing.provide_metadata
- def test_dont_keep_checking(self, connection):
+ def test_dont_keep_checking(self, metadata, connection):
metadata = self.metadata
e1 = postgresql.ENUM("one", "two", "three", name="myenum")
@@ -560,11 +549,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
e["name"] for e in inspect(connection).get_enums()
]
- def test_non_native_dialect(self):
- engine = engines.testing_engine()
+ def test_non_native_dialect(self, metadata, testing_engine):
+ engine = testing_engine()
engine.connect()
engine.dialect.supports_native_enum = False
- metadata = MetaData()
t1 = Table(
"foo",
metadata,
@@ -583,21 +571,18 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
def go():
t1.create(engine)
- try:
- self.assert_sql(
- engine,
- go,
- [
- (
- "CREATE TABLE foo (bar "
- "VARCHAR(5), CONSTRAINT myenum CHECK "
- "(bar IN ('one', 'two', 'three')))",
- {},
- )
- ],
- )
- finally:
- metadata.drop_all(engine)
+ self.assert_sql(
+ engine,
+ go,
+ [
+ (
+ "CREATE TABLE foo (bar "
+ "VARCHAR(5), CONSTRAINT myenum CHECK "
+ "(bar IN ('one', 'two', 'three')))",
+ {},
+ )
+ ],
+ )
def test_standalone_enum(self, connection, metadata):
etype = Enum(
@@ -605,26 +590,26 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
)
etype.create(connection)
try:
- assert testing.db.dialect.has_type(connection, "fourfivesixtype")
+ assert connection.dialect.has_type(connection, "fourfivesixtype")
finally:
etype.drop(connection)
- assert not testing.db.dialect.has_type(
+ assert not connection.dialect.has_type(
connection, "fourfivesixtype"
)
metadata.create_all(connection)
try:
- assert testing.db.dialect.has_type(connection, "fourfivesixtype")
+ assert connection.dialect.has_type(connection, "fourfivesixtype")
finally:
metadata.drop_all(connection)
- assert not testing.db.dialect.has_type(
+ assert not connection.dialect.has_type(
connection, "fourfivesixtype"
)
- def test_no_support(self):
+ def test_no_support(self, testing_engine):
def server_version_info(self):
return (8, 2)
- e = engines.testing_engine()
+ e = testing_engine()
dialect = e.dialect
dialect._get_server_version_info = server_version_info
@@ -692,8 +677,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
eq_(t2.c.value2.type.name, "fourfivesixtype")
eq_(t2.c.value2.type.schema, "test_schema")
- @testing.provide_metadata
- def test_custom_subclass(self, connection):
+ def test_custom_subclass(self, metadata, connection):
class MyEnum(TypeDecorator):
impl = Enum("oneHI", "twoHI", "threeHI", name="myenum")
@@ -708,13 +692,12 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
return value
t1 = Table("table1", self.metadata, Column("data", MyEnum()))
- self.metadata.create_all(testing.db)
+ self.metadata.create_all(connection)
connection.execute(t1.insert(), {"data": "two"})
eq_(connection.scalar(select(t1.c.data)), "twoHITHERE")
- @testing.provide_metadata
- def test_generic_w_pg_variant(self, connection):
+ def test_generic_w_pg_variant(self, metadata, connection):
some_table = Table(
"some_table",
self.metadata,
@@ -752,8 +735,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
e["name"] for e in inspect(connection).get_enums()
]
- @testing.provide_metadata
- def test_generic_w_some_other_variant(self, connection):
+ def test_generic_w_some_other_variant(self, metadata, connection):
some_table = Table(
"some_table",
self.metadata,
@@ -809,26 +791,28 @@ class RegClassTest(fixtures.TestBase):
__only_on__ = "postgresql"
__backend__ = True
- @staticmethod
- def _scalar(expression):
- with testing.db.connect() as conn:
- return conn.scalar(select(expression))
+ @testing.fixture()
+ def scalar(self, connection):
+ def go(expression):
+ return connection.scalar(select(expression))
- def test_cast_name(self):
- eq_(self._scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class")
+ return go
- def test_cast_path(self):
+ def test_cast_name(self, scalar):
+ eq_(scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class")
+
+ def test_cast_path(self, scalar):
eq_(
- self._scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)),
+ scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)),
"pg_class",
)
- def test_cast_oid(self):
+ def test_cast_oid(self, scalar):
regclass = cast("pg_class", postgresql.REGCLASS)
- oid = self._scalar(cast(regclass, postgresql.OID))
+ oid = scalar(cast(regclass, postgresql.OID))
assert isinstance(oid, int)
eq_(
- self._scalar(
+ scalar(
cast(type_coerce(oid, postgresql.OID), postgresql.REGCLASS)
),
"pg_class",
@@ -1339,13 +1323,12 @@ class ArrayRoundTripTest(object):
Column("dimarr", ProcValue),
)
- def _fixture_456(self, table):
- with testing.db.begin() as conn:
- conn.execute(table.insert(), intarr=[4, 5, 6])
+ def _fixture_456(self, table, connection):
+ connection.execute(table.insert(), intarr=[4, 5, 6])
- def test_reflect_array_column(self):
+ def test_reflect_array_column(self, connection):
metadata2 = MetaData()
- tbl = Table("arrtable", metadata2, autoload_with=testing.db)
+ tbl = Table("arrtable", metadata2, autoload_with=connection)
assert isinstance(tbl.c.intarr.type, self.ARRAY)
assert isinstance(tbl.c.strarr.type, self.ARRAY)
assert isinstance(tbl.c.intarr.type.item_type, Integer)
@@ -1564,7 +1547,7 @@ class ArrayRoundTripTest(object):
def test_array_getitem_single_exec(self, connection):
arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
+ self._fixture_456(arrtable, connection)
eq_(connection.scalar(select(arrtable.c.intarr[2])), 5)
connection.execute(arrtable.update().values({arrtable.c.intarr[2]: 7}))
eq_(connection.scalar(select(arrtable.c.intarr[2])), 7)
@@ -1654,11 +1637,10 @@ class ArrayRoundTripTest(object):
set([("1", "2", "3"), ("4", "5", "6"), (("4", "5"), ("6", "7"))]),
)
- def test_array_plus_native_enum_create(self):
- m = MetaData()
+ def test_array_plus_native_enum_create(self, metadata, connection):
t = Table(
"t",
- m,
+ metadata,
Column(
"data_1",
self.ARRAY(postgresql.ENUM("a", "b", "c", name="my_enum_1")),
@@ -1669,13 +1651,13 @@ class ArrayRoundTripTest(object):
),
)
- t.create(testing.db)
+ t.create(connection)
eq_(
- set(e["name"] for e in inspect(testing.db).get_enums()),
+ set(e["name"] for e in inspect(connection).get_enums()),
set(["my_enum_1", "my_enum_2"]),
)
- t.drop(testing.db)
- eq_(inspect(testing.db).get_enums(), [])
+ t.drop(connection)
+ eq_(inspect(connection).get_enums(), [])
class CoreArrayRoundTripTest(
@@ -1690,33 +1672,35 @@ class PGArrayRoundTripTest(
):
ARRAY = postgresql.ARRAY
- @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),))
- def test_undim_array_contains_typed_exec(self, struct):
+ @testing.combinations(
+ (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"
+ )
+ def test_undim_array_contains_typed_exec(self, struct, connection):
arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
- with testing.db.begin() as conn:
- eq_(
- conn.scalar(
- select(arrtable.c.intarr).where(
- arrtable.c.intarr.contains(struct([4, 5]))
- )
- ),
- [4, 5, 6],
- )
+ self._fixture_456(arrtable, connection)
+ eq_(
+ connection.scalar(
+ select(arrtable.c.intarr).where(
+ arrtable.c.intarr.contains(struct([4, 5]))
+ )
+ ),
+ [4, 5, 6],
+ )
- @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),))
- def test_dim_array_contains_typed_exec(self, struct):
+ @testing.combinations(
+ (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"
+ )
+ def test_dim_array_contains_typed_exec(self, struct, connection):
dim_arrtable = self.tables.dim_arrtable
- self._fixture_456(dim_arrtable)
- with testing.db.begin() as conn:
- eq_(
- conn.scalar(
- select(dim_arrtable.c.intarr).where(
- dim_arrtable.c.intarr.contains(struct([4, 5]))
- )
- ),
- [4, 5, 6],
- )
+ self._fixture_456(dim_arrtable, connection)
+ eq_(
+ connection.scalar(
+ select(dim_arrtable.c.intarr).where(
+ dim_arrtable.c.intarr.contains(struct([4, 5]))
+ )
+ ),
+ [4, 5, 6],
+ )
def test_array_contained_by_exec(self, connection):
arrtable = self.tables.arrtable
@@ -1730,7 +1714,7 @@ class PGArrayRoundTripTest(
def test_undim_array_empty(self, connection):
arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
+ self._fixture_456(arrtable, connection)
eq_(
connection.scalar(
select(arrtable.c.intarr).where(arrtable.c.intarr.contains([]))
@@ -1782,8 +1766,9 @@ class ArrayEnum(fixtures.TestBase):
sqltypes.ARRAY, postgresql.ARRAY, argnames="array_cls"
)
@testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
- @testing.provide_metadata
- def test_raises_non_native_enums(self, array_cls, enum_cls):
+ def test_raises_non_native_enums(
+ self, metadata, connection, array_cls, enum_cls
+ ):
Table(
"my_table",
self.metadata,
@@ -1808,7 +1793,7 @@ class ArrayEnum(fixtures.TestBase):
"for ARRAY of non-native ENUM; please specify "
"create_constraint=False on this Enum datatype.",
self.metadata.create_all,
- testing.db,
+ connection,
)
@testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
@@ -1818,8 +1803,7 @@ class ArrayEnum(fixtures.TestBase):
(_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
argnames="array_cls",
)
- @testing.provide_metadata
- def test_array_of_enums(self, array_cls, enum_cls, connection):
+ def test_array_of_enums(self, array_cls, enum_cls, metadata, connection):
tbl = Table(
"enum_table",
self.metadata,
@@ -1875,8 +1859,7 @@ class ArrayJSON(fixtures.TestBase):
@testing.combinations(
sqltypes.JSON, postgresql.JSON, postgresql.JSONB, argnames="json_cls"
)
- @testing.provide_metadata
- def test_array_of_json(self, array_cls, json_cls, connection):
+ def test_array_of_json(self, array_cls, json_cls, metadata, connection):
tbl = Table(
"json_table",
self.metadata,
@@ -1982,19 +1965,38 @@ class HashableFlagORMTest(fixtures.TestBase):
},
],
),
+ (
+ "HSTORE",
+ postgresql.HSTORE(),
+ [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}],
+ testing.requires.hstore,
+ ),
+ (
+ "JSONB",
+ postgresql.JSONB(),
+ [
+ {"a": "1", "b": "2", "c": "3"},
+ {
+ "d": "4",
+ "e": {"e1": "5", "e2": "6"},
+ "f": {"f1": [9, 10, 11]},
+ },
+ ],
+ testing.requires.postgresql_jsonb,
+ ),
+ argnames="type_,data",
id_="iaa",
)
- @testing.provide_metadata
- def test_hashable_flag(self, type_, data):
- Base = declarative_base(metadata=self.metadata)
+ def test_hashable_flag(self, metadata, connection, type_, data):
+ Base = declarative_base(metadata=metadata)
class A(Base):
__tablename__ = "a1"
id = Column(Integer, primary_key=True)
data = Column(type_)
- Base.metadata.create_all(testing.db)
- s = Session(testing.db)
+ Base.metadata.create_all(connection)
+ s = Session(connection)
s.add_all([A(data=elem) for elem in data])
s.commit()
@@ -2006,27 +2008,6 @@ class HashableFlagORMTest(fixtures.TestBase):
list(enumerate(data, 1)),
)
- @testing.requires.hstore
- def test_hstore(self):
- self.test_hashable_flag(
- postgresql.HSTORE(),
- [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}],
- )
-
- @testing.requires.postgresql_jsonb
- def test_jsonb(self):
- self.test_hashable_flag(
- postgresql.JSONB(),
- [
- {"a": "1", "b": "2", "c": "3"},
- {
- "d": "4",
- "e": {"e1": "5", "e2": "6"},
- "f": {"f1": [9, 10, 11]},
- },
- ],
- )
-
class TimestampTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql"
@@ -2108,14 +2089,14 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
return table
- def test_reflection(self, special_types_table):
+ def test_reflection(self, special_types_table, connection):
# cheat so that the "strict type check"
# works
special_types_table.c.year_interval.type = postgresql.INTERVAL()
special_types_table.c.month_interval.type = postgresql.INTERVAL()
m = MetaData()
- t = Table("sometable", m, autoload_with=testing.db)
+ t = Table("sometable", m, autoload_with=connection)
self.assert_tables_equal(special_types_table, t, strict_types=True)
assert t.c.plain_interval.type.precision is None
@@ -2210,7 +2191,7 @@ class UUIDTest(fixtures.TestBase):
class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
__dialect__ = "postgresql"
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
@@ -2494,17 +2475,16 @@ class HStoreRoundTripTest(fixtures.TablesTest):
Column("data", HSTORE),
)
- def _fixture_data(self, engine):
+ def _fixture_data(self, connection):
data_table = self.tables.data_table
- with engine.begin() as conn:
- conn.execute(
- data_table.insert(),
- {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
- {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
- {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}},
- {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}},
- {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}},
- )
+ connection.execute(
+ data_table.insert(),
+ {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
+ {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
+ {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}},
+ {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}},
+ {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}},
+ )
def _assert_data(self, compare, conn):
data = conn.execute(
@@ -2514,26 +2494,32 @@ class HStoreRoundTripTest(fixtures.TablesTest):
).fetchall()
eq_([d for d, in data], compare)
- def _test_insert(self, engine):
- with engine.begin() as conn:
- conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
- )
- self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn)
+ def _test_insert(self, connection):
+ connection.execute(
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
+ )
+ self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection)
- def _non_native_engine(self):
- if testing.requires.psycopg2_native_hstore.enabled:
- engine = engines.testing_engine(
- options=dict(use_native_hstore=False)
- )
+ @testing.fixture
+ def non_native_hstore_connection(self, testing_engine):
+ local_engine = testing.requires.psycopg2_native_hstore.enabled
+
+ if local_engine:
+ engine = testing_engine(options=dict(use_native_hstore=False))
else:
engine = testing.db
- engine.connect().close()
- return engine
- def test_reflect(self):
- insp = inspect(testing.db)
+ conn = engine.connect()
+ trans = conn.begin()
+ yield conn
+ try:
+ trans.rollback()
+ finally:
+ conn.close()
+
+ def test_reflect(self, connection):
+ insp = inspect(connection)
cols = insp.get_columns("data_table")
assert isinstance(cols[2]["type"], HSTORE)
@@ -2548,106 +2534,88 @@ class HStoreRoundTripTest(fixtures.TablesTest):
eq_(connection.scalar(select(expr)), "3")
@testing.requires.psycopg2_native_hstore
- def test_insert_native(self):
- engine = testing.db
- self._test_insert(engine)
+ def test_insert_native(self, connection):
+ self._test_insert(connection)
- def test_insert_python(self):
- engine = self._non_native_engine()
- self._test_insert(engine)
+ def test_insert_python(self, non_native_hstore_connection):
+ self._test_insert(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_criterion_native(self):
- engine = testing.db
- self._fixture_data(engine)
- self._test_criterion(engine)
+ def test_criterion_native(self, connection):
+ self._fixture_data(connection)
+ self._test_criterion(connection)
- def test_criterion_python(self):
- engine = self._non_native_engine()
- self._fixture_data(engine)
- self._test_criterion(engine)
+ def test_criterion_python(self, non_native_hstore_connection):
+ self._fixture_data(non_native_hstore_connection)
+ self._test_criterion(non_native_hstore_connection)
- def _test_criterion(self, engine):
+ def _test_criterion(self, connection):
data_table = self.tables.data_table
- with engine.begin() as conn:
- result = conn.execute(
- select(data_table.c.data).where(
- data_table.c.data["k1"] == "r3v1"
- )
- ).first()
- eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
+ result = connection.execute(
+ select(data_table.c.data).where(data_table.c.data["k1"] == "r3v1")
+ ).first()
+ eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
- def _test_fixed_round_trip(self, engine):
- with engine.begin() as conn:
- s = select(
- hstore(
- array(["key1", "key2", "key3"]),
- array(["value1", "value2", "value3"]),
- )
- )
- eq_(
- conn.scalar(s),
- {"key1": "value1", "key2": "value2", "key3": "value3"},
+ def _test_fixed_round_trip(self, connection):
+ s = select(
+ hstore(
+ array(["key1", "key2", "key3"]),
+ array(["value1", "value2", "value3"]),
)
+ )
+ eq_(
+ connection.scalar(s),
+ {"key1": "value1", "key2": "value2", "key3": "value3"},
+ )
- def test_fixed_round_trip_python(self):
- engine = self._non_native_engine()
- self._test_fixed_round_trip(engine)
+ def test_fixed_round_trip_python(self, non_native_hstore_connection):
+ self._test_fixed_round_trip(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_fixed_round_trip_native(self):
- engine = testing.db
- self._test_fixed_round_trip(engine)
+ def test_fixed_round_trip_native(self, connection):
+ self._test_fixed_round_trip(connection)
- def _test_unicode_round_trip(self, engine):
- with engine.begin() as conn:
- s = select(
- hstore(
- array(
- [util.u("réveillé"), util.u("drôle"), util.u("S’il")]
- ),
- array(
- [util.u("réveillé"), util.u("drôle"), util.u("S’il")]
- ),
- )
- )
- eq_(
- conn.scalar(s),
- {
- util.u("réveillé"): util.u("réveillé"),
- util.u("drôle"): util.u("drôle"),
- util.u("S’il"): util.u("S’il"),
- },
+ def _test_unicode_round_trip(self, connection):
+ s = select(
+ hstore(
+ array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]),
+ array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]),
)
+ )
+ eq_(
+ connection.scalar(s),
+ {
+ util.u("réveillé"): util.u("réveillé"),
+ util.u("drôle"): util.u("drôle"),
+ util.u("S’il"): util.u("S’il"),
+ },
+ )
@testing.requires.psycopg2_native_hstore
- def test_unicode_round_trip_python(self):
- engine = self._non_native_engine()
- self._test_unicode_round_trip(engine)
+ def test_unicode_round_trip_python(self, non_native_hstore_connection):
+ self._test_unicode_round_trip(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_unicode_round_trip_native(self):
- engine = testing.db
- self._test_unicode_round_trip(engine)
+ def test_unicode_round_trip_native(self, connection):
+ self._test_unicode_round_trip(connection)
- def test_escaped_quotes_round_trip_python(self):
- engine = self._non_native_engine()
- self._test_escaped_quotes_round_trip(engine)
+ def test_escaped_quotes_round_trip_python(
+ self, non_native_hstore_connection
+ ):
+ self._test_escaped_quotes_round_trip(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_escaped_quotes_round_trip_native(self):
- engine = testing.db
- self._test_escaped_quotes_round_trip(engine)
+ def test_escaped_quotes_round_trip_native(self, connection):
+ self._test_escaped_quotes_round_trip(connection)
- def _test_escaped_quotes_round_trip(self, engine):
- with engine.begin() as conn:
- conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}},
- )
- self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], conn)
+ def _test_escaped_quotes_round_trip(self, connection):
+ connection.execute(
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}},
+ )
+ self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], connection)
- def test_orm_round_trip(self):
+ def test_orm_round_trip(self, connection):
from sqlalchemy import orm
class Data(object):
@@ -2656,13 +2624,14 @@ class HStoreRoundTripTest(fixtures.TablesTest):
self.data = data
orm.mapper(Data, self.tables.data_table)
- s = orm.Session(testing.db)
- d = Data(
- name="r1",
- data={"key1": "value1", "key2": "value2", "key3": "value3"},
- )
- s.add(d)
- eq_(s.query(Data.data, Data).all(), [(d.data, d)])
+
+ with orm.Session(connection) as s:
+ d = Data(
+ name="r1",
+ data={"key1": "value1", "key2": "value2", "key3": "value3"},
+ )
+ s.add(d)
+ eq_(s.query(Data.data, Data).all(), [(d.data, d)])
class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
@@ -2671,7 +2640,7 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
# operator tests
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
table = Table(
"data_table",
MetaData(),
@@ -2852,10 +2821,10 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
def test_actual_type(self):
eq_(str(self._col_type()), self._col_str)
- def test_reflect(self):
+ def test_reflect(self, connection):
from sqlalchemy import inspect
- insp = inspect(testing.db)
+ insp = inspect(connection)
cols = insp.get_columns("data_table")
assert isinstance(cols[0]["type"], self._col_type)
@@ -2986,8 +2955,8 @@ class _DateTimeTZRangeTests(object):
def tstzs(self):
if self._tstzs is None:
- with testing.db.begin() as conn:
- lower = conn.scalar(func.current_timestamp().select())
+ with testing.db.connect() as connection:
+ lower = connection.scalar(func.current_timestamp().select())
upper = lower + datetime.timedelta(1)
self._tstzs = (lower, upper)
return self._tstzs
@@ -3052,7 +3021,7 @@ class DateTimeTZRangeRoundTripTest(_DateTimeTZRangeTests, _RangeTypeRoundTrip):
class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
__dialect__ = "postgresql"
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
@@ -3151,7 +3120,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
Column("nulldata", cls.data_type(none_as_null=True)),
)
- def _fixture_data(self, engine):
+ def _fixture_data(self, connection):
data_table = self.tables.data_table
data = [
@@ -3162,8 +3131,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
{"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}},
{"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}},
]
- with engine.begin() as conn:
- conn.execute(data_table.insert(), data)
+ connection.execute(data_table.insert(), data)
return data
def _assert_data(self, compare, conn, column="data"):
@@ -3185,51 +3153,39 @@ class JSONRoundTripTest(fixtures.TablesTest):
).fetchall()
eq_([d for d, in data], [None])
- def _test_insert(self, conn):
- conn.execute(
+ def test_reflect(self, connection):
+ insp = inspect(connection)
+ cols = insp.get_columns("data_table")
+ assert isinstance(cols[2]["type"], self.data_type)
+
+ def test_insert(self, connection):
+ connection.execute(
self.tables.data_table.insert(),
{"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
)
- self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn)
+ self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection)
- def _test_insert_nulls(self, conn):
- conn.execute(
+ def test_insert_nulls(self, connection):
+ connection.execute(
self.tables.data_table.insert(), {"name": "r1", "data": null()}
)
- self._assert_data([None], conn)
+ self._assert_data([None], connection)
- def _test_insert_none_as_null(self, conn):
- conn.execute(
+ def test_insert_none_as_null(self, connection):
+ connection.execute(
self.tables.data_table.insert(),
{"name": "r1", "nulldata": None},
)
- self._assert_column_is_NULL(conn, column="nulldata")
+ self._assert_column_is_NULL(connection, column="nulldata")
- def _test_insert_nulljson_into_none_as_null(self, conn):
- conn.execute(
+ def test_insert_nulljson_into_none_as_null(self, connection):
+ connection.execute(
self.tables.data_table.insert(),
{"name": "r1", "nulldata": JSON.NULL},
)
- self._assert_column_is_JSON_NULL(conn, column="nulldata")
-
- def test_reflect(self):
- insp = inspect(testing.db)
- cols = insp.get_columns("data_table")
- assert isinstance(cols[2]["type"], self.data_type)
-
- def test_insert(self, connection):
- self._test_insert(connection)
-
- def test_insert_nulls(self, connection):
- self._test_insert_nulls(connection)
+ self._assert_column_is_JSON_NULL(connection, column="nulldata")
- def test_insert_none_as_null(self, connection):
- self._test_insert_none_as_null(connection)
-
- def test_insert_nulljson_into_none_as_null(self, connection):
- self._test_insert_nulljson_into_none_as_null(connection)
-
- def test_custom_serialize_deserialize(self):
+ def test_custom_serialize_deserialize(self, testing_engine):
import json
def loads(value):
@@ -3242,7 +3198,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
value["x"] = "dumps_y"
return json.dumps(value)
- engine = engines.testing_engine(
+ engine = testing_engine(
options=dict(json_serializer=dumps, json_deserializer=loads)
)
@@ -3250,14 +3206,26 @@ class JSONRoundTripTest(fixtures.TablesTest):
with engine.begin() as conn:
eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
- def test_criterion(self):
- engine = testing.db
- self._fixture_data(engine)
- self._test_criterion(engine)
+ def test_criterion(self, connection):
+ self._fixture_data(connection)
+ data_table = self.tables.data_table
+
+ result = connection.execute(
+ select(data_table.c.data).where(
+ data_table.c.data["k1"].astext == "r3v1"
+ )
+ ).first()
+ eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
+
+ result = connection.execute(
+ select(data_table.c.data).where(
+ data_table.c.data["k1"].astext.cast(String) == "r3v1"
+ )
+ ).first()
+ eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
def test_path_query(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
@@ -3271,8 +3239,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
"postgresql < 9.4", "Improvement in PostgreSQL behavior?"
)
def test_multi_index_query(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
@@ -3283,20 +3250,18 @@ class JSONRoundTripTest(fixtures.TablesTest):
eq_(result.scalar(), "r6")
def test_query_returned_as_text(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
select(data_table.c.data["k1"].astext)
).first()
- if engine.dialect.returns_unicode_strings:
+ if connection.dialect.returns_unicode_strings:
assert isinstance(result[0], util.text_type)
else:
assert isinstance(result[0], util.string_types)
def test_query_returned_as_int(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
select(data_table.c.data["k3"].astext.cast(Integer)).where(
@@ -3305,23 +3270,6 @@ class JSONRoundTripTest(fixtures.TablesTest):
).first()
assert isinstance(result[0], int)
- def _test_criterion(self, engine):
- data_table = self.tables.data_table
- with engine.begin() as conn:
- result = conn.execute(
- select(data_table.c.data).where(
- data_table.c.data["k1"].astext == "r3v1"
- )
- ).first()
- eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
-
- result = conn.execute(
- select(data_table.c.data).where(
- data_table.c.data["k1"].astext.cast(String) == "r3v1"
- )
- ).first()
- eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
-
def test_fixed_round_trip(self, connection):
s = select(
cast(
@@ -3352,42 +3300,41 @@ class JSONRoundTripTest(fixtures.TablesTest):
},
)
- def test_eval_none_flag_orm(self):
+ def test_eval_none_flag_orm(self, connection):
Base = declarative_base()
class Data(Base):
__table__ = self.tables.data_table
- s = Session(testing.db)
+ with Session(connection) as s:
+ d1 = Data(name="d1", data=None, nulldata=None)
+ s.add(d1)
+ s.commit()
- d1 = Data(name="d1", data=None, nulldata=None)
- s.add(d1)
- s.commit()
-
- s.bulk_insert_mappings(
- Data, [{"name": "d2", "data": None, "nulldata": None}]
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String),
- cast(self.tables.data_table.c.nulldata, String),
+ s.bulk_insert_mappings(
+ Data, [{"name": "d2", "data": None, "nulldata": None}]
)
- .filter(self.tables.data_table.c.name == "d1")
- .first(),
- ("null", None),
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String),
- cast(self.tables.data_table.c.nulldata, String),
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
)
- .filter(self.tables.data_table.c.name == "d2")
- .first(),
- ("null", None),
- )
def test_literal(self, connection):
- exp = self._fixture_data(testing.db)
+ exp = self._fixture_data(connection)
result = connection.exec_driver_sql(
"select data from data_table order by name"
)
@@ -3395,11 +3342,10 @@ class JSONRoundTripTest(fixtures.TablesTest):
eq_(len(res), len(exp))
for row, expected in zip(res, exp):
eq_(row[0], expected["data"])
- result.close()
class JSONBTest(JSONTest):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py
index 4658b40a8..1926c6065 100644
--- a/test/dialect/test_sqlite.py
+++ b/test/dialect/test_sqlite.py
@@ -37,6 +37,7 @@ from sqlalchemy import UniqueConstraint
from sqlalchemy import util
from sqlalchemy.dialects.sqlite import base as sqlite
from sqlalchemy.dialects.sqlite import insert
+from sqlalchemy.dialects.sqlite import provision
from sqlalchemy.dialects.sqlite import pysqlite as pysqlite_dialect
from sqlalchemy.engine.url import make_url
from sqlalchemy.schema import CreateTable
@@ -46,6 +47,7 @@ from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import combinations
+from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_warnings
@@ -774,7 +776,7 @@ class AttachedDBTest(fixtures.TestBase):
def _fixture(self):
meta = self.metadata
- self.conn = testing.db.connect()
+ self.conn = self.engine.connect()
Table("created", meta, Column("foo", Integer), Column("bar", String))
Table("local_only", meta, Column("q", Integer), Column("p", Integer))
@@ -798,14 +800,20 @@ class AttachedDBTest(fixtures.TestBase):
meta.create_all(self.conn)
return ct
- def setup(self):
- self.conn = testing.db.connect()
+ def setup_test(self):
+ self.engine = engines.testing_engine(options={"use_reaper": False})
+
+ provision._sqlite_post_configure_engine(
+ self.engine.url, self.engine, config.ident
+ )
+ self.conn = self.engine.connect()
self.metadata = MetaData()
- def teardown(self):
+ def teardown_test(self):
with self.conn.begin():
self.metadata.drop_all(self.conn)
self.conn.close()
+ self.engine.dispose()
def test_no_tables(self):
insp = inspect(self.conn)
@@ -1495,7 +1503,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
__skip_if__ = (full_text_search_missing,)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global metadata, cattable, matchtable
metadata = MetaData()
exec_sql(
@@ -1559,7 +1567,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
metadata.drop_all(testing.db)
def test_expression(self):
@@ -1681,7 +1689,7 @@ class AutoIncrementTest(fixtures.TestBase, AssertsCompiledSQL):
class ReflectHeadlessFKsTest(fixtures.TestBase):
__only_on__ = "sqlite"
- def setup(self):
+ def setup_test(self):
exec_sql(testing.db, "CREATE TABLE a (id INTEGER PRIMARY KEY)")
# this syntax actually works on other DBs perhaps we'd want to add
# tests to test_reflection
@@ -1689,7 +1697,7 @@ class ReflectHeadlessFKsTest(fixtures.TestBase):
testing.db, "CREATE TABLE b (id INTEGER PRIMARY KEY REFERENCES a)"
)
- def teardown(self):
+ def teardown_test(self):
exec_sql(testing.db, "drop table b")
exec_sql(testing.db, "drop table a")
@@ -1728,7 +1736,7 @@ class ConstraintReflectionTest(fixtures.TestBase):
__only_on__ = "sqlite"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.begin() as conn:
conn.exec_driver_sql("CREATE TABLE a1 (id INTEGER PRIMARY KEY)")
@@ -1876,7 +1884,7 @@ class ConstraintReflectionTest(fixtures.TestBase):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
for name in [
"implicit_referrer_comp_fake",
@@ -2370,7 +2378,7 @@ class SavepointTest(fixtures.TablesTest):
@classmethod
def setup_bind(cls):
- engine = engines.testing_engine(options={"use_reaper": False})
+ engine = engines.testing_engine(options={"scope": "class"})
@event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
@@ -2579,7 +2587,7 @@ class TypeReflectionTest(fixtures.TestBase):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "sqlite"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py
index 396b48aa4..baa766d48 100644
--- a/test/engine/test_ddlevents.py
+++ b/test/engine/test_ddlevents.py
@@ -21,7 +21,7 @@ from sqlalchemy.testing.schema import Table
class DDLEventTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.bind = engines.mock_engine()
self.metadata = MetaData()
self.table = Table("t", self.metadata, Column("id", Integer))
@@ -374,7 +374,7 @@ class DDLEventTest(fixtures.TestBase):
class DDLExecutionTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.engine = engines.mock_engine()
self.metadata = MetaData()
self.users = Table(
diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py
index a18cf756b..0a2c9abe5 100644
--- a/test/engine/test_deprecations.py
+++ b/test/engine/test_deprecations.py
@@ -965,7 +965,7 @@ class TransactionTest(fixtures.TablesTest):
class HandleInvalidatedOnConnectTest(fixtures.TestBase):
__requires__ = ("sqlite",)
- def setUp(self):
+ def setup_test(self):
e = create_engine("sqlite://")
connection = Mock(get_server_version_info=Mock(return_value="5.0"))
@@ -1021,18 +1021,18 @@ def MockDBAPI(): # noqa
class PoolTestBase(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
pool.clear_managers()
self._teardown_conns = []
- def teardown(self):
+ def teardown_test(self):
for ref in self._teardown_conns:
conn = ref()
if conn:
conn.close()
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
pool.clear_managers()
def _queuepool_fixture(self, **kw):
@@ -1597,7 +1597,7 @@ class EngineEventsTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
Engine.dispatch._clear()
Engine._has_events = False
@@ -1650,6 +1650,7 @@ class EngineEventsTest(fixtures.TestBase):
event.listen(
engine, "before_cursor_execute", cursor_execute, retval=True
)
+
with testing.expect_deprecated(
r"The argument signature for the "
r"\"ConnectionEvents.before_execute\" event listener",
@@ -1676,11 +1677,12 @@ class EngineEventsTest(fixtures.TestBase):
r"The argument signature for the "
r"\"ConnectionEvents.after_execute\" event listener",
):
- e1.execute(select(1))
+ result = e1.execute(select(1))
+ result.close()
class DDLExecutionTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.engine = engines.mock_engine()
self.metadata = MetaData()
self.users = Table(
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index 21d4e06e0..a1e4ea218 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -43,7 +43,6 @@ from sqlalchemy.testing import is_not
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertsql import CompiledSQL
-from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing.mock import call
from sqlalchemy.testing.mock import Mock
from sqlalchemy.testing.mock import patch
@@ -94,13 +93,13 @@ class ExecuteTest(fixtures.TablesTest):
).default_from()
)
- conn = testing.db.connect()
- result = (
- conn.execution_options(no_parameters=True)
- .exec_driver_sql(stmt)
- .scalar()
- )
- eq_(result, "%")
+ with testing.db.connect() as conn:
+ result = (
+ conn.execution_options(no_parameters=True)
+ .exec_driver_sql(stmt)
+ .scalar()
+ )
+ eq_(result, "%")
def test_raw_positional_invalid(self, connection):
assert_raises_message(
@@ -261,16 +260,15 @@ class ExecuteTest(fixtures.TablesTest):
(4, "sally"),
]
- @testing.engines.close_open_connections
def test_exception_wrapping_dbapi(self):
- conn = testing.db.connect()
- # engine does not have exec_driver_sql
- assert_raises_message(
- tsa.exc.DBAPIError,
- r"not_a_valid_statement",
- conn.exec_driver_sql,
- "not_a_valid_statement",
- )
+ with testing.db.connect() as conn:
+ # engine does not have exec_driver_sql
+ assert_raises_message(
+ tsa.exc.DBAPIError,
+ r"not_a_valid_statement",
+ conn.exec_driver_sql,
+ "not_a_valid_statement",
+ )
@testing.requires.sqlite
def test_exception_wrapping_non_dbapi_error(self):
@@ -864,12 +862,10 @@ class CompiledCacheTest(fixtures.TestBase):
["sqlite", "mysql", "postgresql"],
"uses blob value that is problematic for some DBAPIs",
)
- @testing.provide_metadata
- def test_cache_noleak_on_statement_values(self, connection):
+ def test_cache_noleak_on_statement_values(self, metadata, connection):
# This is a non regression test for an object reference leak caused
# by the compiled_cache.
- metadata = self.metadata
photo = Table(
"photo",
metadata,
@@ -1040,7 +1036,19 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
__requires__ = ("schemas",)
__backend__ = True
- def test_create_table(self):
+ @testing.fixture
+ def plain_tables(self, metadata):
+ t1 = Table(
+ "t1", metadata, Column("x", Integer), schema=config.test_schema
+ )
+ t2 = Table(
+ "t2", metadata, Column("x", Integer), schema=config.test_schema
+ )
+ t3 = Table("t3", metadata, Column("x", Integer), schema=None)
+
+ return t1, t2, t3
+
+ def test_create_table(self, plain_tables, connection):
map_ = {
None: config.test_schema,
"foo": config.test_schema,
@@ -1052,18 +1060,16 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
- with self.sql_execution_asserter(config.db) as asserter:
- with config.db.begin() as conn, conn.execution_options(
- schema_translate_map=map_
- ) as conn:
+ with self.sql_execution_asserter(connection) as asserter:
+ conn = connection.execution_options(schema_translate_map=map_)
- t1.create(conn)
- t2.create(conn)
- t3.create(conn)
+ t1.create(conn)
+ t2.create(conn)
+ t3.create(conn)
- t3.drop(conn)
- t2.drop(conn)
- t1.drop(conn)
+ t3.drop(conn)
+ t2.drop(conn)
+ t1.drop(conn)
asserter.assert_(
CompiledSQL("CREATE TABLE [SCHEMA__none].t1 (x INTEGER)"),
@@ -1074,14 +1080,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
CompiledSQL("DROP TABLE [SCHEMA__none].t1"),
)
- def _fixture(self):
- metadata = self.metadata
- Table("t1", metadata, Column("x", Integer), schema=config.test_schema)
- Table("t2", metadata, Column("x", Integer), schema=config.test_schema)
- Table("t3", metadata, Column("x", Integer), schema=None)
- metadata.create_all(testing.db)
-
- def test_ddl_hastable(self):
+ def test_ddl_hastable(self, plain_tables, connection):
map_ = {
None: config.test_schema,
@@ -1094,27 +1093,28 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
Table("t2", metadata, Column("x", Integer), schema="foo")
Table("t3", metadata, Column("x", Integer), schema="bar")
- with config.db.begin() as conn:
- conn = conn.execution_options(schema_translate_map=map_)
- metadata.create_all(conn)
+ conn = connection.execution_options(schema_translate_map=map_)
+ metadata.create_all(conn)
- insp = inspect(config.db)
+ insp = inspect(connection)
is_true(insp.has_table("t1", schema=config.test_schema))
is_true(insp.has_table("t2", schema=config.test_schema))
is_true(insp.has_table("t3", schema=None))
- with config.db.begin() as conn:
- conn = conn.execution_options(schema_translate_map=map_)
- metadata.drop_all(conn)
+ conn = connection.execution_options(schema_translate_map=map_)
+
+ # if this test fails, the tables won't get dropped. so need a
+ # more robust fixture for this
+ metadata.drop_all(conn)
- insp = inspect(config.db)
+ insp = inspect(connection)
is_false(insp.has_table("t1", schema=config.test_schema))
is_false(insp.has_table("t2", schema=config.test_schema))
is_false(insp.has_table("t3", schema=None))
- @testing.provide_metadata
- def test_option_on_execute(self):
- self._fixture()
+ def test_option_on_execute(self, plain_tables, connection):
+ # provided by metadata fixture provided by plain_tables fixture
+ self.metadata.create_all(connection)
map_ = {
None: config.test_schema,
@@ -1127,61 +1127,54 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
- with self.sql_execution_asserter(config.db) as asserter:
- with config.db.begin() as conn:
+ with self.sql_execution_asserter(connection) as asserter:
+ conn = connection
+ execution_options = {"schema_translate_map": map_}
+ conn._execute_20(
+ t1.insert(), {"x": 1}, execution_options=execution_options
+ )
+ conn._execute_20(
+ t2.insert(), {"x": 1}, execution_options=execution_options
+ )
+ conn._execute_20(
+ t3.insert(), {"x": 1}, execution_options=execution_options
+ )
- execution_options = {"schema_translate_map": map_}
- conn._execute_20(
- t1.insert(), {"x": 1}, execution_options=execution_options
- )
- conn._execute_20(
- t2.insert(), {"x": 1}, execution_options=execution_options
- )
- conn._execute_20(
- t3.insert(), {"x": 1}, execution_options=execution_options
- )
+ conn._execute_20(
+ t1.update().values(x=1).where(t1.c.x == 1),
+ execution_options=execution_options,
+ )
+ conn._execute_20(
+ t2.update().values(x=2).where(t2.c.x == 1),
+ execution_options=execution_options,
+ )
+ conn._execute_20(
+ t3.update().values(x=3).where(t3.c.x == 1),
+ execution_options=execution_options,
+ )
+ eq_(
conn._execute_20(
- t1.update().values(x=1).where(t1.c.x == 1),
- execution_options=execution_options,
- )
+ select(t1.c.x), execution_options=execution_options
+ ).scalar(),
+ 1,
+ )
+ eq_(
conn._execute_20(
- t2.update().values(x=2).where(t2.c.x == 1),
- execution_options=execution_options,
- )
+ select(t2.c.x), execution_options=execution_options
+ ).scalar(),
+ 2,
+ )
+ eq_(
conn._execute_20(
- t3.update().values(x=3).where(t3.c.x == 1),
- execution_options=execution_options,
- )
-
- eq_(
- conn._execute_20(
- select(t1.c.x), execution_options=execution_options
- ).scalar(),
- 1,
- )
- eq_(
- conn._execute_20(
- select(t2.c.x), execution_options=execution_options
- ).scalar(),
- 2,
- )
- eq_(
- conn._execute_20(
- select(t3.c.x), execution_options=execution_options
- ).scalar(),
- 3,
- )
+ select(t3.c.x), execution_options=execution_options
+ ).scalar(),
+ 3,
+ )
- conn._execute_20(
- t1.delete(), execution_options=execution_options
- )
- conn._execute_20(
- t2.delete(), execution_options=execution_options
- )
- conn._execute_20(
- t3.delete(), execution_options=execution_options
- )
+ conn._execute_20(t1.delete(), execution_options=execution_options)
+ conn._execute_20(t2.delete(), execution_options=execution_options)
+ conn._execute_20(t3.delete(), execution_options=execution_options)
asserter.assert_(
CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"),
@@ -1207,9 +1200,9 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
CompiledSQL("DELETE FROM [SCHEMA_bar].t3"),
)
- @testing.provide_metadata
- def test_crud(self):
- self._fixture()
+ def test_crud(self, plain_tables, connection):
+ # provided by metadata fixture provided by plain_tables fixture
+ self.metadata.create_all(connection)
map_ = {
None: config.test_schema,
@@ -1222,26 +1215,24 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
- with self.sql_execution_asserter(config.db) as asserter:
- with config.db.begin() as conn, conn.execution_options(
- schema_translate_map=map_
- ) as conn:
+ with self.sql_execution_asserter(connection) as asserter:
+ conn = connection.execution_options(schema_translate_map=map_)
- conn.execute(t1.insert(), {"x": 1})
- conn.execute(t2.insert(), {"x": 1})
- conn.execute(t3.insert(), {"x": 1})
+ conn.execute(t1.insert(), {"x": 1})
+ conn.execute(t2.insert(), {"x": 1})
+ conn.execute(t3.insert(), {"x": 1})
- conn.execute(t1.update().values(x=1).where(t1.c.x == 1))
- conn.execute(t2.update().values(x=2).where(t2.c.x == 1))
- conn.execute(t3.update().values(x=3).where(t3.c.x == 1))
+ conn.execute(t1.update().values(x=1).where(t1.c.x == 1))
+ conn.execute(t2.update().values(x=2).where(t2.c.x == 1))
+ conn.execute(t3.update().values(x=3).where(t3.c.x == 1))
- eq_(conn.scalar(select(t1.c.x)), 1)
- eq_(conn.scalar(select(t2.c.x)), 2)
- eq_(conn.scalar(select(t3.c.x)), 3)
+ eq_(conn.scalar(select(t1.c.x)), 1)
+ eq_(conn.scalar(select(t2.c.x)), 2)
+ eq_(conn.scalar(select(t3.c.x)), 3)
- conn.execute(t1.delete())
- conn.execute(t2.delete())
- conn.execute(t3.delete())
+ conn.execute(t1.delete())
+ conn.execute(t2.delete())
+ conn.execute(t3.delete())
asserter.assert_(
CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"),
@@ -1267,9 +1258,10 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
CompiledSQL("DELETE FROM [SCHEMA_bar].t3"),
)
- @testing.provide_metadata
- def test_via_engine(self):
- self._fixture()
+ def test_via_engine(self, plain_tables, metadata):
+
+ with config.db.begin() as connection:
+ metadata.create_all(connection)
map_ = {
None: config.test_schema,
@@ -1282,25 +1274,25 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
with self.sql_execution_asserter(config.db) as asserter:
eng = config.db.execution_options(schema_translate_map=map_)
- conn = eng.connect()
- conn.execute(select(t2.c.x))
+ with eng.connect() as conn:
+ conn.execute(select(t2.c.x))
asserter.assert_(
CompiledSQL("SELECT [SCHEMA_foo].t2.x FROM [SCHEMA_foo].t2")
)
class ExecutionOptionsTest(fixtures.TestBase):
- def test_dialect_conn_options(self):
+ def test_dialect_conn_options(self, testing_engine):
engine = testing_engine("sqlite://", options=dict(_initialize=False))
engine.dialect = Mock()
- conn = engine.connect()
- c2 = conn.execution_options(foo="bar")
- eq_(
- engine.dialect.set_connection_execution_options.mock_calls,
- [call(c2, {"foo": "bar"})],
- )
+ with engine.connect() as conn:
+ c2 = conn.execution_options(foo="bar")
+ eq_(
+ engine.dialect.set_connection_execution_options.mock_calls,
+ [call(c2, {"foo": "bar"})],
+ )
- def test_dialect_engine_options(self):
+ def test_dialect_engine_options(self, testing_engine):
engine = testing_engine("sqlite://")
engine.dialect = Mock()
e2 = engine.execution_options(foo="bar")
@@ -1319,14 +1311,14 @@ class ExecutionOptionsTest(fixtures.TestBase):
[call(engine, {"foo": "bar"})],
)
- def test_propagate_engine_to_connection(self):
+ def test_propagate_engine_to_connection(self, testing_engine):
engine = testing_engine(
"sqlite://", options=dict(execution_options={"foo": "bar"})
)
- conn = engine.connect()
- eq_(conn._execution_options, {"foo": "bar"})
+ with engine.connect() as conn:
+ eq_(conn._execution_options, {"foo": "bar"})
- def test_propagate_option_engine_to_connection(self):
+ def test_propagate_option_engine_to_connection(self, testing_engine):
e1 = testing_engine(
"sqlite://", options=dict(execution_options={"foo": "bar"})
)
@@ -1336,27 +1328,30 @@ class ExecutionOptionsTest(fixtures.TestBase):
eq_(c1._execution_options, {"foo": "bar"})
eq_(c2._execution_options, {"foo": "bar", "bat": "hoho"})
- def test_get_engine_execution_options(self):
+ c1.close()
+ c2.close()
+
+ def test_get_engine_execution_options(self, testing_engine):
engine = testing_engine("sqlite://")
engine.dialect = Mock()
e2 = engine.execution_options(foo="bar")
eq_(e2.get_execution_options(), {"foo": "bar"})
- def test_get_connection_execution_options(self):
+ def test_get_connection_execution_options(self, testing_engine):
engine = testing_engine("sqlite://", options=dict(_initialize=False))
engine.dialect = Mock()
- conn = engine.connect()
- c = conn.execution_options(foo="bar")
+ with engine.connect() as conn:
+ c = conn.execution_options(foo="bar")
- eq_(c.get_execution_options(), {"foo": "bar"})
+ eq_(c.get_execution_options(), {"foo": "bar"})
class EngineEventsTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
Engine.dispatch._clear()
Engine._has_events = False
@@ -1376,7 +1371,7 @@ class EngineEventsTest(fixtures.TestBase):
):
break
- def test_per_engine_independence(self):
+ def test_per_engine_independence(self, testing_engine):
e1 = testing_engine(config.db_url)
e2 = testing_engine(config.db_url)
@@ -1400,7 +1395,7 @@ class EngineEventsTest(fixtures.TestBase):
conn.execute(s2)
eq_([arg[1][1] for arg in canary.mock_calls], [s1, s1, s2])
- def test_per_engine_plus_global(self):
+ def test_per_engine_plus_global(self, testing_engine):
canary = Mock()
event.listen(Engine, "before_execute", canary.be1)
e1 = testing_engine(config.db_url)
@@ -1409,8 +1404,6 @@ class EngineEventsTest(fixtures.TestBase):
event.listen(e1, "before_execute", canary.be2)
event.listen(Engine, "before_execute", canary.be3)
- e1.connect()
- e2.connect()
with e1.connect() as conn:
conn.execute(select(1))
@@ -1424,7 +1417,7 @@ class EngineEventsTest(fixtures.TestBase):
eq_(canary.be2.call_count, 1)
eq_(canary.be3.call_count, 2)
- def test_per_connection_plus_engine(self):
+ def test_per_connection_plus_engine(self, testing_engine):
canary = Mock()
e1 = testing_engine(config.db_url)
@@ -1442,9 +1435,14 @@ class EngineEventsTest(fixtures.TestBase):
eq_(canary.be1.call_count, 2)
eq_(canary.be2.call_count, 2)
- @testing.combinations((True, False), (True, True), (False, False))
+ @testing.combinations(
+ (True, False),
+ (True, True),
+ (False, False),
+ argnames="mock_out_on_connect, add_our_own_onconnect",
+ )
def test_insert_connect_is_definitely_first(
- self, mock_out_on_connect, add_our_own_onconnect
+ self, mock_out_on_connect, add_our_own_onconnect, testing_engine
):
"""test issue #5708.
@@ -1478,7 +1476,7 @@ class EngineEventsTest(fixtures.TestBase):
patcher = util.nullcontext()
with patcher:
- e1 = create_engine(config.db_url)
+ e1 = testing_engine(config.db_url)
initialize = e1.dialect.initialize
@@ -1559,10 +1557,11 @@ class EngineEventsTest(fixtures.TestBase):
conn.exec_driver_sql(select1(testing.db))
eq_(m1.mock_calls, [])
- def test_add_event_after_connect(self):
+ def test_add_event_after_connect(self, testing_engine):
# new feature as of #2978
+
canary = Mock()
- e1 = create_engine(config.db_url)
+ e1 = testing_engine(config.db_url, future=False)
assert not e1._has_events
conn = e1.connect()
@@ -1575,9 +1574,9 @@ class EngineEventsTest(fixtures.TestBase):
conn._branch().execute(select(1))
eq_(canary.be1.call_count, 2)
- def test_force_conn_events_false(self):
+ def test_force_conn_events_false(self, testing_engine):
canary = Mock()
- e1 = create_engine(config.db_url)
+ e1 = testing_engine(config.db_url, future=False)
assert not e1._has_events
event.listen(e1, "before_execute", canary.be1)
@@ -1593,7 +1592,7 @@ class EngineEventsTest(fixtures.TestBase):
conn._branch().execute(select(1))
eq_(canary.be1.call_count, 0)
- def test_cursor_events_ctx_execute_scalar(self):
+ def test_cursor_events_ctx_execute_scalar(self, testing_engine):
canary = Mock()
e1 = testing_engine(config.db_url)
@@ -1620,7 +1619,7 @@ class EngineEventsTest(fixtures.TestBase):
[call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
)
- def test_cursor_events_execute(self):
+ def test_cursor_events_execute(self, testing_engine):
canary = Mock()
e1 = testing_engine(config.db_url)
@@ -1653,9 +1652,15 @@ class EngineEventsTest(fixtures.TestBase):
),
((), {"z": 10}, [], {"z": 10}, testing.requires.legacy_engine),
(({"z": 10},), {}, [], {"z": 10}),
+ argnames="multiparams, params, expected_multiparams, expected_params",
)
def test_modify_parameters_from_event_one(
- self, multiparams, params, expected_multiparams, expected_params
+ self,
+ multiparams,
+ params,
+ expected_multiparams,
+ expected_params,
+ testing_engine,
):
# this is testing both the normalization added to parameters
# as of I97cb4d06adfcc6b889f10d01cc7775925cffb116 as well as
@@ -1704,7 +1709,9 @@ class EngineEventsTest(fixtures.TestBase):
[(15,), (19,)],
)
- def test_modify_parameters_from_event_three(self, connection):
+ def test_modify_parameters_from_event_three(
+ self, connection, testing_engine
+ ):
def before_execute(
conn, clauseelement, multiparams, params, execution_options
):
@@ -1721,7 +1728,7 @@ class EngineEventsTest(fixtures.TestBase):
with e1.connect() as conn:
conn.execute(select(literal("1")))
- def test_argument_format_execute(self):
+ def test_argument_format_execute(self, testing_engine):
def before_execute(
conn, clauseelement, multiparams, params, execution_options
):
@@ -1956,9 +1963,9 @@ class EngineEventsTest(fixtures.TestBase):
)
@testing.requires.ad_hoc_engines
- def test_dispose_event(self):
+ def test_dispose_event(self, testing_engine):
canary = Mock()
- eng = create_engine(testing.db.url)
+ eng = testing_engine(testing.db.url)
event.listen(eng, "engine_disposed", canary)
conn = eng.connect()
@@ -2102,13 +2109,13 @@ class EngineEventsTest(fixtures.TestBase):
event.listen(engine, "commit", tracker("commit"))
event.listen(engine, "rollback", tracker("rollback"))
- conn = engine.connect()
- trans = conn.begin()
- conn.execute(select(1))
- trans.rollback()
- trans = conn.begin()
- conn.execute(select(1))
- trans.commit()
+ with engine.connect() as conn:
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.rollback()
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.commit()
eq_(
canary,
@@ -2145,13 +2152,13 @@ class EngineEventsTest(fixtures.TestBase):
event.listen(engine, "commit", tracker("commit"), named=True)
event.listen(engine, "rollback", tracker("rollback"), named=True)
- conn = engine.connect()
- trans = conn.begin()
- conn.execute(select(1))
- trans.rollback()
- trans = conn.begin()
- conn.execute(select(1))
- trans.commit()
+ with engine.connect() as conn:
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.rollback()
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.commit()
eq_(
canary,
@@ -2310,7 +2317,7 @@ class HandleErrorTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
Engine.dispatch._clear()
Engine._has_events = False
@@ -2742,7 +2749,7 @@ class HandleErrorTest(fixtures.TestBase):
class HandleInvalidatedOnConnectTest(fixtures.TestBase):
__requires__ = ("sqlite",)
- def setUp(self):
+ def setup_test(self):
e = create_engine("sqlite://")
connection = Mock(get_server_version_info=Mock(return_value="5.0"))
@@ -3014,6 +3021,9 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase):
],
)
+ c.close()
+ c2.close()
+
class DialectEventTest(fixtures.TestBase):
@contextmanager
@@ -3370,7 +3380,7 @@ class SetInputSizesTest(fixtures.TablesTest):
)
@testing.fixture
- def input_sizes_fixture(self):
+ def input_sizes_fixture(self, testing_engine):
canary = mock.Mock()
def do_set_input_sizes(cursor, list_of_tuples, context):
diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py
index 29b8132aa..c56589248 100644
--- a/test/engine/test_logging.py
+++ b/test/engine/test_logging.py
@@ -30,7 +30,7 @@ class LogParamsTest(fixtures.TestBase):
__only_on__ = "sqlite"
__requires__ = ("ad_hoc_engines",)
- def setup(self):
+ def setup_test(self):
self.eng = engines.testing_engine(options={"echo": True})
self.no_param_engine = engines.testing_engine(
options={"echo": True, "hide_parameters": True}
@@ -44,7 +44,7 @@ class LogParamsTest(fixtures.TestBase):
for log in [logging.getLogger("sqlalchemy.engine")]:
log.addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
exec_sql(self.eng, "drop table if exists foo")
for log in [logging.getLogger("sqlalchemy.engine")]:
log.removeHandler(self.buf)
@@ -413,14 +413,14 @@ class LogParamsTest(fixtures.TestBase):
class PoolLoggingTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.existing_level = logging.getLogger("sqlalchemy.pool").level
self.buf = logging.handlers.BufferingHandler(100)
for log in [logging.getLogger("sqlalchemy.pool")]:
log.addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
for log in [logging.getLogger("sqlalchemy.pool")]:
log.removeHandler(self.buf)
logging.getLogger("sqlalchemy.pool").setLevel(self.existing_level)
@@ -528,7 +528,7 @@ class LoggingNameTest(fixtures.TestBase):
kw.update({"echo": True})
return engines.testing_engine(options=kw)
- def setup(self):
+ def setup_test(self):
self.buf = logging.handlers.BufferingHandler(100)
for log in [
logging.getLogger("sqlalchemy.engine"),
@@ -536,7 +536,7 @@ class LoggingNameTest(fixtures.TestBase):
]:
log.addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
for log in [
logging.getLogger("sqlalchemy.engine"),
logging.getLogger("sqlalchemy.pool"),
@@ -588,13 +588,13 @@ class LoggingNameTest(fixtures.TestBase):
class EchoTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
- def setup(self):
+ def setup_test(self):
self.level = logging.getLogger("sqlalchemy.engine").level
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARN)
self.buf = logging.handlers.BufferingHandler(100)
logging.getLogger("sqlalchemy.engine").addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
logging.getLogger("sqlalchemy.engine").removeHandler(self.buf)
logging.getLogger("sqlalchemy.engine").setLevel(self.level)
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py
index 550fedb8e..decdce3f9 100644
--- a/test/engine/test_pool.py
+++ b/test/engine/test_pool.py
@@ -17,7 +17,9 @@ from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_none
from sqlalchemy.testing import is_not
+from sqlalchemy.testing import is_not_none
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing.engines import testing_engine
@@ -63,18 +65,18 @@ def MockDBAPI(): # noqa
class PoolTestBase(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
pool.clear_managers()
self._teardown_conns = []
- def teardown(self):
+ def teardown_test(self):
for ref in self._teardown_conns:
conn = ref()
if conn:
conn.close()
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
pool.clear_managers()
def _with_teardown(self, connection):
@@ -364,10 +366,17 @@ class PoolEventsTest(PoolTestBase):
p = self._queuepool_fixture()
canary = []
+ @event.listens_for(p, "checkin")
def checkin(*arg, **kw):
canary.append("checkin")
- event.listen(p, "checkin", checkin)
+ @event.listens_for(p, "close_detached")
+ def close_detached(*arg, **kw):
+ canary.append("close_detached")
+
+ @event.listens_for(p, "detach")
+ def detach(*arg, **kw):
+ canary.append("detach")
return p, canary
@@ -629,15 +638,35 @@ class PoolEventsTest(PoolTestBase):
assert canary.call_args_list[0][0][0] is dbapi_con
assert canary.call_args_list[0][0][2] is exc
+ @testing.combinations((True, testing.requires.python3), (False,))
@testing.requires.predictable_gc
- def test_checkin_event_gc(self):
+ def test_checkin_event_gc(self, detach_gced):
p, canary = self._checkin_event_fixture()
+ if detach_gced:
+ p._is_asyncio = True
+
c1 = p.connect()
+
+ dbapi_connection = weakref.ref(c1.connection)
+
eq_(canary, [])
del c1
lazy_gc()
- eq_(canary, ["checkin"])
+
+ if detach_gced:
+ # "close_detached" is not called because for asyncio the
+ # connection is just lost.
+ eq_(canary, ["detach"])
+
+ else:
+ eq_(canary, ["checkin"])
+
+ gc_collect()
+ if detach_gced:
+ is_none(dbapi_connection())
+ else:
+ is_not_none(dbapi_connection())
def test_checkin_event_on_subsequently_recreated(self):
p, canary = self._checkin_event_fixture()
@@ -744,7 +773,7 @@ class PoolEventsTest(PoolTestBase):
eq_(conn.info["important_flag"], True)
conn.close()
- def teardown(self):
+ def teardown_test(self):
# TODO: need to get remove() functionality
# going
pool.Pool.dispatch._clear()
@@ -1490,12 +1519,16 @@ class QueuePoolTest(PoolTestBase):
self._assert_cleanup_on_pooled_reconnect(dbapi, p)
+ @testing.combinations((True, testing.requires.python3), (False,))
@testing.requires.predictable_gc
- def test_userspace_disconnectionerror_weakref_finalizer(self):
+ def test_userspace_disconnectionerror_weakref_finalizer(self, detach_gced):
dbapi, pool = self._queuepool_dbapi_fixture(
pool_size=1, max_overflow=2
)
+ if detach_gced:
+ pool._is_asyncio = True
+
@event.listens_for(pool, "checkout")
def handle_checkout_event(dbapi_con, con_record, con_proxy):
if getattr(dbapi_con, "boom") == "yes":
@@ -1514,8 +1547,12 @@ class QueuePoolTest(PoolTestBase):
del conn
gc_collect()
- # new connection was reset on return appropriately
- eq_(dbapi_conn.mock_calls, [call.rollback()])
+ if detach_gced:
+ # new connection was detached + abandoned on return
+ eq_(dbapi_conn.mock_calls, [])
+ else:
+ # new connection reset and returned to pool
+ eq_(dbapi_conn.mock_calls, [call.rollback()])
# old connection was just closed - did not get an
# erroneous reset on return
diff --git a/test/engine/test_processors.py b/test/engine/test_processors.py
index 3810de06a..5a4220c82 100644
--- a/test/engine/test_processors.py
+++ b/test/engine/test_processors.py
@@ -25,7 +25,7 @@ class CBooleanProcessorTest(_BooleanProcessorTest):
__requires__ = ("cextensions",)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import cprocessors
cls.module = cprocessors
@@ -83,7 +83,7 @@ class _DateProcessorTest(fixtures.TestBase):
class PyDateProcessorTest(_DateProcessorTest):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import processors
cls.module = type(
@@ -100,7 +100,7 @@ class CDateProcessorTest(_DateProcessorTest):
__requires__ = ("cextensions",)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import cprocessors
cls.module = cprocessors
@@ -185,7 +185,7 @@ class _DistillArgsTest(fixtures.TestBase):
class PyDistillArgsTest(_DistillArgsTest):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy.engine import util
cls.module = type(
@@ -202,7 +202,7 @@ class CDistillArgsTest(_DistillArgsTest):
__requires__ = ("cextensions",)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import cutils as util
cls.module = util
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py
index 5fe7f6cc2..7a64b2550 100644
--- a/test/engine/test_reconnect.py
+++ b/test/engine/test_reconnect.py
@@ -162,7 +162,7 @@ def MockDBAPI():
class PrePingMockTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.dbapi = MockDBAPI()
def _pool_fixture(self, pre_ping, pool_kw=None):
@@ -182,7 +182,7 @@ class PrePingMockTest(fixtures.TestBase):
)
return _pool
- def teardown(self):
+ def teardown_test(self):
self.dbapi.dispose()
def test_ping_not_on_first_connect(self):
@@ -357,7 +357,7 @@ class PrePingMockTest(fixtures.TestBase):
class MockReconnectTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.dbapi = MockDBAPI()
self.db = testing_engine(
@@ -373,7 +373,7 @@ class MockReconnectTest(fixtures.TestBase):
e, MockDisconnect
)
- def teardown(self):
+ def teardown_test(self):
self.dbapi.dispose()
def test_reconnect(self):
@@ -1004,10 +1004,10 @@ class RealReconnectTest(fixtures.TestBase):
__backend__ = True
__requires__ = "graceful_disconnects", "ad_hoc_engines"
- def setup(self):
+ def setup_test(self):
self.engine = engines.reconnecting_engine()
- def teardown(self):
+ def teardown_test(self):
self.engine.dispose()
def test_reconnect(self):
@@ -1336,7 +1336,7 @@ class PrePingRealTest(fixtures.TestBase):
class InvalidateDuringResultTest(fixtures.TestBase):
__backend__ = True
- def setup(self):
+ def setup_test(self):
self.engine = engines.reconnecting_engine()
self.meta = MetaData()
table = Table(
@@ -1353,7 +1353,7 @@ class InvalidateDuringResultTest(fixtures.TestBase):
[{"id": i, "name": "row %d" % i} for i in range(1, 100)],
)
- def teardown(self):
+ def teardown_test(self):
with self.engine.begin() as conn:
self.meta.drop_all(conn)
self.engine.dispose()
@@ -1470,7 +1470,7 @@ class ReconnectRecipeTest(fixtures.TestBase):
__backend__ = True
- def setup(self):
+ def setup_test(self):
self.engine = engines.reconnecting_engine(
options=dict(future=self.future)
)
@@ -1483,7 +1483,7 @@ class ReconnectRecipeTest(fixtures.TestBase):
)
self.meta.create_all(self.engine)
- def teardown(self):
+ def teardown_test(self):
self.meta.drop_all(self.engine)
self.engine.dispose()
diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py
index 658cdd79f..0a46ddeec 100644
--- a/test/engine/test_reflection.py
+++ b/test/engine/test_reflection.py
@@ -796,7 +796,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables):
assert f1 in b1.constraints
assert len(b1.constraints) == 2
- def test_override_keys(self, connection, metadata):
+ def test_override_keys(self, metadata, connection):
"""test that columns can be overridden with a 'key',
and that ForeignKey targeting during reflection still works."""
@@ -1375,7 +1375,7 @@ class CreateDropTest(fixtures.TablesTest):
run_create_tables = None
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
# TablesTest is used here without
# run_create_tables, so add an explicit drop of whatever is in
# metadata
@@ -1658,7 +1658,6 @@ class SchemaTest(fixtures.TestBase):
@testing.requires.schemas
@testing.requires.cross_schema_fk_reflection
@testing.requires.implicit_default_schema
- @testing.provide_metadata
def test_blank_schema_arg(self, connection, metadata):
Table(
@@ -1913,7 +1912,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL):
__backend__ = True
@testing.requires.denormalized_names
- def setup(self):
+ def setup_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql(
"""
@@ -1926,7 +1925,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL):
)
@testing.requires.denormalized_names
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop table weird_casing")
diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py
index 79126fc5b..47504b60a 100644
--- a/test/engine/test_transaction.py
+++ b/test/engine/test_transaction.py
@@ -1,6 +1,5 @@
import sys
-from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import func
@@ -640,12 +639,12 @@ class AutoRollbackTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global metadata
metadata = MetaData()
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
metadata.drop_all(testing.db)
def test_rollback_deadlock(self):
@@ -871,11 +870,13 @@ class IsolationLevelTest(fixtures.TestBase):
def test_per_engine(self):
# new in 0.9
- eng = create_engine(
+ eng = testing_engine(
testing.db.url,
- execution_options={
- "isolation_level": self._non_default_isolation_level()
- },
+ options=dict(
+ execution_options={
+ "isolation_level": self._non_default_isolation_level()
+ }
+ ),
)
conn = eng.connect()
eq_(
@@ -884,7 +885,7 @@ class IsolationLevelTest(fixtures.TestBase):
)
def test_per_option_engine(self):
- eng = create_engine(testing.db.url).execution_options(
+ eng = testing_engine(testing.db.url).execution_options(
isolation_level=self._non_default_isolation_level()
)
@@ -895,14 +896,14 @@ class IsolationLevelTest(fixtures.TestBase):
)
def test_isolation_level_accessors_connection_default(self):
- eng = create_engine(testing.db.url)
+ eng = testing_engine(testing.db.url)
with eng.connect() as conn:
eq_(conn.default_isolation_level, self._default_isolation_level())
with eng.connect() as conn:
eq_(conn.get_isolation_level(), self._default_isolation_level())
def test_isolation_level_accessors_connection_option_modified(self):
- eng = create_engine(testing.db.url)
+ eng = testing_engine(testing.db.url)
with eng.connect() as conn:
c2 = conn.execution_options(
isolation_level=self._non_default_isolation_level()
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py
index 7dae1411e..59a44f8e2 100644
--- a/test/ext/asyncio/test_engine_py3k.py
+++ b/test/ext/asyncio/test_engine_py3k.py
@@ -269,7 +269,7 @@ class AsyncEngineTest(EngineFixture):
await trans.rollback(),
@async_test
- async def test_pool_exhausted(self, async_engine):
+ async def test_pool_exhausted_some_timeout(self, async_engine):
engine = create_async_engine(
testing.db.url,
pool_size=1,
@@ -277,7 +277,19 @@ class AsyncEngineTest(EngineFixture):
pool_timeout=0.1,
)
async with engine.connect():
- with expect_raises(asyncio.TimeoutError):
+ with expect_raises(exc.TimeoutError):
+ await engine.connect()
+
+ @async_test
+ async def test_pool_exhausted_no_timeout(self, async_engine):
+ engine = create_async_engine(
+ testing.db.url,
+ pool_size=1,
+ max_overflow=0,
+ pool_timeout=0,
+ )
+ async with engine.connect():
+ with expect_raises(exc.TimeoutError):
await engine.connect()
@async_test
diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py
index 2b80b753e..e25e7cfc2 100644
--- a/test/ext/declarative/test_inheritance.py
+++ b/test/ext/declarative/test_inheritance.py
@@ -27,11 +27,11 @@ Base = None
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
- def setup(self):
+ def setup_test(self):
global Base
Base = decl.declarative_base(testing.db)
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
Base.metadata.drop_all(testing.db)
diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py
index d7fcbf9e8..c327de7d4 100644
--- a/test/ext/declarative/test_reflection.py
+++ b/test/ext/declarative/test_reflection.py
@@ -4,7 +4,6 @@ from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.ext.declarative import DeferredReflection
from sqlalchemy.orm import clear_mappers
-from sqlalchemy.orm import create_session
from sqlalchemy.orm import decl_api as decl
from sqlalchemy.orm import declared_attr
from sqlalchemy.orm import exc as orm_exc
@@ -14,6 +13,7 @@ from sqlalchemy.orm.decl_base import _DeferredMapperConfig
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import gc_collect
@@ -22,20 +22,19 @@ from sqlalchemy.testing.util import gc_collect
class DeclarativeReflectionBase(fixtures.TablesTest):
__requires__ = ("reflectable_autoincrement",)
- def setup(self):
+ def setup_test(self):
global Base, registry
registry = decl.registry()
Base = registry.generate_base()
- def teardown(self):
- super(DeclarativeReflectionBase, self).teardown()
+ def teardown_test(self):
clear_mappers()
class DeferredReflectBase(DeclarativeReflectionBase):
- def teardown(self):
- super(DeferredReflectBase, self).teardown()
+ def teardown_test(self):
+ super(DeferredReflectBase, self).teardown_test()
_DeferredMapperConfig._configs.clear()
@@ -101,22 +100,23 @@ class DeferredReflectionTest(DeferredReflectBase):
u1 = User(
name="u1", addresses=[Address(email="one"), Address(email="two")]
)
- sess = create_session(testing.db)
- sess.add(u1)
- sess.flush()
- sess.expunge_all()
- eq_(
- sess.query(User).all(),
- [
- User(
- name="u1",
- addresses=[Address(email="one"), Address(email="two")],
- )
- ],
- )
- a1 = sess.query(Address).filter(Address.email == "two").one()
- eq_(a1, Address(email="two"))
- eq_(a1.user, User(name="u1"))
+ with fixture_session() as sess:
+ sess.add(u1)
+ sess.commit()
+
+ with fixture_session() as sess:
+ eq_(
+ sess.query(User).all(),
+ [
+ User(
+ name="u1",
+ addresses=[Address(email="one"), Address(email="two")],
+ )
+ ],
+ )
+ a1 = sess.query(Address).filter(Address.email == "two").one()
+ eq_(a1, Address(email="two"))
+ eq_(a1.user, User(name="u1"))
def test_exception_prepare_not_called(self):
class User(DeferredReflection, fixtures.ComparableEntity, Base):
@@ -191,15 +191,25 @@ class DeferredReflectionTest(DeferredReflectBase):
return {"primary_key": cls.__table__.c.id}
DeferredReflection.prepare(testing.db)
- sess = Session(testing.db)
- sess.add_all(
- [User(name="G"), User(name="Q"), User(name="A"), User(name="C")]
- )
- sess.commit()
- eq_(
- sess.query(User).order_by(User.name).all(),
- [User(name="A"), User(name="C"), User(name="G"), User(name="Q")],
- )
+ with fixture_session() as sess:
+ sess.add_all(
+ [
+ User(name="G"),
+ User(name="Q"),
+ User(name="A"),
+ User(name="C"),
+ ]
+ )
+ sess.commit()
+ eq_(
+ sess.query(User).order_by(User.name).all(),
+ [
+ User(name="A"),
+ User(name="C"),
+ User(name="G"),
+ User(name="Q"),
+ ],
+ )
@testing.requires.predictable_gc
def test_cls_not_strong_ref(self):
@@ -255,14 +265,14 @@ class DeferredSecondaryReflectionTest(DeferredReflectBase):
u1 = User(name="u1", items=[Item(name="i1"), Item(name="i2")])
- sess = Session(testing.db)
- sess.add(u1)
- sess.commit()
+ with fixture_session() as sess:
+ sess.add(u1)
+ sess.commit()
- eq_(
- sess.query(User).all(),
- [User(name="u1", items=[Item(name="i1"), Item(name="i2")])],
- )
+ eq_(
+ sess.query(User).all(),
+ [User(name="u1", items=[Item(name="i1"), Item(name="i2")])],
+ )
def test_string_resolution(self):
class User(DeferredReflection, fixtures.ComparableEntity, Base):
@@ -296,27 +306,26 @@ class DeferredInhReflectBase(DeferredReflectBase):
Foo = Base.registry._class_registry["Foo"]
Bar = Base.registry._class_registry["Bar"]
- s = Session(testing.db)
-
- s.add_all(
- [
- Bar(data="d1", bar_data="b1"),
- Bar(data="d2", bar_data="b2"),
- Bar(data="d3", bar_data="b3"),
- Foo(data="d4"),
- ]
- )
- s.commit()
-
- eq_(
- s.query(Foo).order_by(Foo.id).all(),
- [
- Bar(data="d1", bar_data="b1"),
- Bar(data="d2", bar_data="b2"),
- Bar(data="d3", bar_data="b3"),
- Foo(data="d4"),
- ],
- )
+ with fixture_session() as s:
+ s.add_all(
+ [
+ Bar(data="d1", bar_data="b1"),
+ Bar(data="d2", bar_data="b2"),
+ Bar(data="d3", bar_data="b3"),
+ Foo(data="d4"),
+ ]
+ )
+ s.commit()
+
+ eq_(
+ s.query(Foo).order_by(Foo.id).all(),
+ [
+ Bar(data="d1", bar_data="b1"),
+ Bar(data="d2", bar_data="b2"),
+ Bar(data="d3", bar_data="b3"),
+ Foo(data="d4"),
+ ],
+ )
class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py
index b1f5cc956..31ae050c1 100644
--- a/test/ext/test_associationproxy.py
+++ b/test/ext/test_associationproxy.py
@@ -101,7 +101,7 @@ class AutoFlushTest(fixtures.TablesTest):
Column("name", String(50)),
)
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def _fixture(self, collection_class, is_dict=False):
@@ -198,7 +198,7 @@ class AutoFlushTest(fixtures.TablesTest):
class _CollectionOperations(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
collection_class = self.collection_class
metadata = MetaData()
@@ -260,7 +260,7 @@ class _CollectionOperations(fixtures.TestBase):
self.session = fixture_session()
self.Parent, self.Child = Parent, Child
- def teardown(self):
+ def teardown_test(self):
self.metadata.drop_all(testing.db)
def roundtrip(self, obj):
@@ -885,7 +885,7 @@ class CustomObjectTest(_CollectionOperations):
class ProxyFactoryTest(ListTest):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
parents_table = Table(
@@ -1157,7 +1157,7 @@ class ScalarTest(fixtures.TestBase):
class LazyLoadTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
parents_table = Table(
@@ -1197,7 +1197,7 @@ class LazyLoadTest(fixtures.TestBase):
self.Parent, self.Child = Parent, Child
self.table = parents_table
- def teardown(self):
+ def teardown_test(self):
self.metadata.drop_all(testing.db)
def roundtrip(self, obj):
@@ -2294,7 +2294,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
class DictOfTupleUpdateTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
class B(object):
def __init__(self, key, elem):
self.key = key
@@ -2434,7 +2434,7 @@ class CompositeAccessTest(fixtures.DeclarativeMappedTest):
class AttributeAccessTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def test_resolve_aliased_class(self):
diff --git a/test/ext/test_baked.py b/test/ext/test_baked.py
index 71fabc629..2d4e9848e 100644
--- a/test/ext/test_baked.py
+++ b/test/ext/test_baked.py
@@ -27,7 +27,7 @@ class BakedTest(_fixtures.FixtureTest):
run_inserts = "once"
run_deletes = None
- def setup(self):
+ def setup_test(self):
self.bakery = baked.bakery()
diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py
index 058c1dfd7..d011417d7 100644
--- a/test/ext/test_compiler.py
+++ b/test/ext/test_compiler.py
@@ -426,7 +426,7 @@ class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
- def teardown(self):
+ def teardown_test(self):
for cls in (Select, BindParameter):
deregister(cls)
diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py
index ad9bf0bc0..f3eceb0dc 100644
--- a/test/ext/test_extendedattr.py
+++ b/test/ext/test_extendedattr.py
@@ -32,7 +32,7 @@ def modifies_instrumentation_finders(fn, *args, **kw):
class _ExtBase(object):
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
instrumentation._reinstall_default_lookups()
@@ -89,7 +89,7 @@ MyBaseClass, MyClass = None, None
class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global MyBaseClass, MyClass
class MyBaseClass(object):
@@ -143,7 +143,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
else:
del self._goofy_dict[key]
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def test_instance_dict(self):
diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py
index 038bdd83e..bb06d9648 100644
--- a/test/ext/test_horizontal_shard.py
+++ b/test/ext/test_horizontal_shard.py
@@ -19,7 +19,6 @@ from sqlalchemy import update
from sqlalchemy import util
from sqlalchemy.ext.horizontal_shard import ShardedSession
from sqlalchemy.orm import clear_mappers
-from sqlalchemy.orm import create_session
from sqlalchemy.orm import deferred
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
@@ -42,7 +41,7 @@ class ShardTest(object):
schema = None
- def setUp(self):
+ def setup_test(self):
global db1, db2, db3, db4, weather_locations, weather_reports
db1, db2, db3, db4 = self._dbs = self._init_dbs()
@@ -88,7 +87,7 @@ class ShardTest(object):
@classmethod
def setup_session(cls):
- global create_session
+ global sharded_session
shard_lookup = {
"North America": "north_america",
"Asia": "asia",
@@ -128,10 +127,10 @@ class ShardTest(object):
else:
return ids
- create_session = sessionmaker(
+ sharded_session = sessionmaker(
class_=ShardedSession, autoflush=True, autocommit=False
)
- create_session.configure(
+ sharded_session.configure(
shards={
"north_america": db1,
"asia": db2,
@@ -180,7 +179,7 @@ class ShardTest(object):
tokyo.reports.append(Report(80.0, id_=1))
newyork.reports.append(Report(75, id_=1))
quito.reports.append(Report(85))
- sess = create_session(future=True)
+ sess = sharded_session(future=True)
for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
sess.add(c)
sess.flush()
@@ -671,11 +670,10 @@ class DistinctEngineShardTest(ShardTest, fixtures.TestBase):
self.dbs = [db1, db2, db3, db4]
return self.dbs
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
- for db in self.dbs:
- db.connect().invalidate()
+ testing_reaper.checkin_all()
for i in range(1, 5):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
@@ -702,10 +700,10 @@ class AttachedFileShardTest(ShardTest, fixtures.TestBase):
self.engine = e
return db1, db2, db3, db4
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
- self.engine.connect().invalidate()
+ testing_reaper.checkin_all()
for i in range(1, 5):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
@@ -778,10 +776,13 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase):
self.postgresql_engine = e2
return db1, db2, db3, db4
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
- self.sqlite_engine.connect().invalidate()
+ # the tests in this suite don't cleanly close out the Session
+ # at the moment so use the reaper to close all connections
+ testing_reaper.checkin_all()
+
for i in [1, 3]:
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
@@ -789,6 +790,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase):
self.tables_test_metadata.drop_all(conn)
for i in [2, 4]:
conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,))
+ self.postgresql_engine.dispose()
class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest):
@@ -904,11 +906,11 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
return self.dbs
- def teardown(self):
+ def teardown_test(self):
for db in self.dbs:
db.connect().invalidate()
- testing_reaper.close_all()
+ testing_reaper.checkin_all()
for i in range(1, 3):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py
index 048a8b52d..3bab7db93 100644
--- a/test/ext/test_hybrid.py
+++ b/test/ext/test_hybrid.py
@@ -697,7 +697,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import literal
symbols = ("usd", "gbp", "cad", "eur", "aud")
diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py
index eba2ac0cb..21244de73 100644
--- a/test/ext/test_mutable.py
+++ b/test/ext/test_mutable.py
@@ -90,11 +90,10 @@ class _MutableDictTestFixture(object):
def _type_fixture(cls):
return MutableDict
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_MutableDictTestFixture, self).teardown()
class _MutableDictTestBase(_MutableDictTestFixture):
@@ -312,11 +311,10 @@ class _MutableListTestFixture(object):
def _type_fixture(cls):
return MutableList
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_MutableListTestFixture, self).teardown()
class _MutableListTestBase(_MutableListTestFixture):
@@ -619,11 +617,10 @@ class _MutableSetTestFixture(object):
def _type_fixture(cls):
return MutableSet
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_MutableSetTestFixture, self).teardown()
class _MutableSetTestBase(_MutableSetTestFixture):
@@ -1234,17 +1231,15 @@ class _CompositeTestBase(object):
Column("unrelated_data", String(50)),
)
- def setup(self):
+ def setup_test(self):
from sqlalchemy.ext import mutable
mutable._setup_composite_listener()
- super(_CompositeTestBase, self).setup()
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_CompositeTestBase, self).teardown()
@classmethod
def _type_fixture(cls):
diff --git a/test/ext/test_orderinglist.py b/test/ext/test_orderinglist.py
index f23d6cb57..280fad6cf 100644
--- a/test/ext/test_orderinglist.py
+++ b/test/ext/test_orderinglist.py
@@ -8,7 +8,7 @@ from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
-from sqlalchemy.testing.fixtures import create_session
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import picklers
@@ -60,7 +60,7 @@ def alpha_ordering(index, collection):
class OrderingListTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
global metadata, slides_table, bullets_table, Slide, Bullet
slides_table, bullets_table = None, None
Slide, Bullet = None, None
@@ -122,7 +122,7 @@ class OrderingListTest(fixtures.TestBase):
metadata.create_all(testing.db)
- def teardown(self):
+ def teardown_test(self):
metadata.drop_all(testing.db)
def test_append_no_reorder(self):
@@ -167,7 +167,7 @@ class OrderingListTest(fixtures.TestBase):
self.assert_(s1.bullets[2].position == 3)
self.assert_(s1.bullets[3].position == 4)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
@@ -232,7 +232,7 @@ class OrderingListTest(fixtures.TestBase):
s1.bullets._reorder()
self.assert_(s1.bullets[4].position == 5)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
@@ -289,7 +289,7 @@ class OrderingListTest(fixtures.TestBase):
self.assert_(len(s1.bullets) == 6)
self.assert_(s1.bullets[5].position == 5)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
@@ -338,7 +338,7 @@ class OrderingListTest(fixtures.TestBase):
self.assert_(s1.bullets[li].position == li)
self.assert_(s1.bullets[li] == b[bi])
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
@@ -365,7 +365,7 @@ class OrderingListTest(fixtures.TestBase):
self.assert_(len(s1.bullets) == 3)
self.assert_(s1.bullets[2].position == 2)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py
index 4c005d336..4d9162105 100644
--- a/test/orm/declarative/test_basic.py
+++ b/test/orm/declarative/test_basic.py
@@ -61,11 +61,11 @@ class DeclarativeTestBase(
):
__dialect__ = "default"
- def setup(self):
+ def setup_test(self):
global Base
Base = declarative_base(testing.db)
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
Base.metadata.drop_all(testing.db)
diff --git a/test/orm/declarative/test_concurrency.py b/test/orm/declarative/test_concurrency.py
index 5f12d8272..ecddc2e5f 100644
--- a/test/orm/declarative/test_concurrency.py
+++ b/test/orm/declarative/test_concurrency.py
@@ -17,7 +17,7 @@ from sqlalchemy.testing.fixtures import fixture_session
class ConcurrentUseDeclMappingTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
@classmethod
diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py
index cc29cab7d..e09b1570e 100644
--- a/test/orm/declarative/test_inheritance.py
+++ b/test/orm/declarative/test_inheritance.py
@@ -27,11 +27,11 @@ Base = None
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
- def setup(self):
+ def setup_test(self):
global Base
Base = decl.declarative_base(testing.db)
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
Base.metadata.drop_all(testing.db)
diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py
index 631527daf..ad4832c35 100644
--- a/test/orm/declarative/test_mixin.py
+++ b/test/orm/declarative/test_mixin.py
@@ -38,13 +38,13 @@ mapper_registry = None
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
- def setup(self):
+ def setup_test(self):
global Base, mapper_registry
mapper_registry = registry(metadata=MetaData())
Base = mapper_registry.generate_base()
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
with testing.db.begin() as conn:
diff --git a/test/orm/declarative/test_reflection.py b/test/orm/declarative/test_reflection.py
index 241528c44..e7b2a7058 100644
--- a/test/orm/declarative/test_reflection.py
+++ b/test/orm/declarative/test_reflection.py
@@ -17,14 +17,13 @@ from sqlalchemy.testing.schema import Table
class DeclarativeReflectionBase(fixtures.TablesTest):
__requires__ = ("reflectable_autoincrement",)
- def setup(self):
+ def setup_test(self):
global Base, registry
registry = decl.registry(metadata=MetaData())
Base = registry.generate_base()
- def teardown(self):
- super(DeclarativeReflectionBase, self).teardown()
+ def teardown_test(self):
clear_mappers()
diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py
index bdcdedc44..da07b4941 100644
--- a/test/orm/inheritance/test_basic.py
+++ b/test/orm/inheritance/test_basic.py
@@ -31,7 +31,6 @@ from sqlalchemy.orm import synonym
from sqlalchemy.orm.util import instance_str
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
-from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
@@ -1889,7 +1888,6 @@ class VersioningTest(fixtures.MappedTest):
@testing.emits_warning(r".*updated rowcount")
@testing.requires.sane_rowcount_w_returning
- @engines.close_open_connections
def test_save_update(self):
subtable, base, stuff = (
self.tables.subtable,
@@ -2927,7 +2925,7 @@ class NoPKOnSubTableWarningTest(fixtures.TestBase):
)
return parent, child
- def tearDown(self):
+ def teardown_test(self):
clear_mappers()
def test_warning_on_sub(self):
@@ -3417,27 +3415,26 @@ class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest):
@classmethod
def insert_data(cls, connection):
Parent, A, B = cls.classes("Parent", "A", "B")
- s = fixture_session()
-
- p1 = Parent(id=1)
- p2 = Parent(id=2)
- s.add_all([p1, p2])
- s.flush()
+ with Session(connection) as s:
+ p1 = Parent(id=1)
+ p2 = Parent(id=2)
+ s.add_all([p1, p2])
+ s.flush()
- s.add_all(
- [
- A(id=1, parent_id=1),
- B(id=2, parent_id=1),
- A(id=3, parent_id=1),
- B(id=4, parent_id=1),
- ]
- )
- s.flush()
+ s.add_all(
+ [
+ A(id=1, parent_id=1),
+ B(id=2, parent_id=1),
+ A(id=3, parent_id=1),
+ B(id=4, parent_id=1),
+ ]
+ )
+ s.flush()
- s.query(A).filter(A.id.in_([3, 4])).update(
- {A.type: None}, synchronize_session=False
- )
- s.commit()
+ s.query(A).filter(A.id.in_([3, 4])).update(
+ {A.type: None}, synchronize_session=False
+ )
+ s.commit()
def test_pk_is_null(self):
Parent, A = self.classes("Parent", "A")
@@ -3527,10 +3524,12 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest):
ASingleSubA, ASingleSubB, AJoinedSubA, AJoinedSubB = cls.classes(
"ASingleSubA", "ASingleSubB", "AJoinedSubA", "AJoinedSubB"
)
- s = fixture_session()
+ with Session(connection) as s:
- s.add_all([ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()])
- s.commit()
+ s.add_all(
+ [ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()]
+ )
+ s.commit()
def test_single_invalid_ident(self):
ASingle, ASingleSubA = self.classes("ASingle", "ASingleSubA")
diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py
index 8820aa6a4..0a0a5d12b 100644
--- a/test/orm/test_attributes.py
+++ b/test/orm/test_attributes.py
@@ -209,7 +209,7 @@ class AttributeImplAPITest(fixtures.MappedTest):
class AttributesTest(fixtures.ORMTest):
- def setup(self):
+ def setup_test(self):
global MyTest, MyTest2
class MyTest(object):
@@ -218,7 +218,7 @@ class AttributesTest(fixtures.ORMTest):
class MyTest2(object):
pass
- def teardown(self):
+ def teardown_test(self):
global MyTest, MyTest2
MyTest, MyTest2 = None, None
@@ -3690,7 +3690,7 @@ class EventPropagateTest(fixtures.TestBase):
class CollectionInitTest(fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class A(object):
pass
@@ -3749,7 +3749,7 @@ class CollectionInitTest(fixtures.TestBase):
class TestUnlink(fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class A(object):
pass
diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py
index 2f54f7fff..014fa152e 100644
--- a/test/orm/test_bind.py
+++ b/test/orm/test_bind.py
@@ -428,39 +428,46 @@ class BindIntegrationTest(_fixtures.FixtureTest):
User, users = self.classes.User, self.tables.users
mapper(User, users)
- c = testing.db.connect()
-
- sess = Session(bind=c, autocommit=False)
- u = User(name="u1")
- sess.add(u)
- sess.flush()
- sess.close()
- assert not c.in_transaction()
- assert c.exec_driver_sql("select count(1) from users").scalar() == 0
-
- sess = Session(bind=c, autocommit=False)
- u = User(name="u2")
- sess.add(u)
- sess.flush()
- sess.commit()
- assert not c.in_transaction()
- assert c.exec_driver_sql("select count(1) from users").scalar() == 1
-
- with c.begin():
- c.exec_driver_sql("delete from users")
- assert c.exec_driver_sql("select count(1) from users").scalar() == 0
-
- c = testing.db.connect()
-
- trans = c.begin()
- sess = Session(bind=c, autocommit=True)
- u = User(name="u3")
- sess.add(u)
- sess.flush()
- assert c.in_transaction()
- trans.commit()
- assert not c.in_transaction()
- assert c.exec_driver_sql("select count(1) from users").scalar() == 1
+ with testing.db.connect() as c:
+
+ sess = Session(bind=c, autocommit=False)
+ u = User(name="u1")
+ sess.add(u)
+ sess.flush()
+ sess.close()
+ assert not c.in_transaction()
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 0
+ )
+
+ sess = Session(bind=c, autocommit=False)
+ u = User(name="u2")
+ sess.add(u)
+ sess.flush()
+ sess.commit()
+ assert not c.in_transaction()
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 1
+ )
+
+ with c.begin():
+ c.exec_driver_sql("delete from users")
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 0
+ )
+
+ with testing.db.connect() as c:
+ trans = c.begin()
+ sess = Session(bind=c, autocommit=True)
+ u = User(name="u3")
+ sess.add(u)
+ sess.flush()
+ assert c.in_transaction()
+ trans.commit()
+ assert not c.in_transaction()
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 1
+ )
class SessionBindTest(fixtures.MappedTest):
@@ -506,6 +513,7 @@ class SessionBindTest(fixtures.MappedTest):
finally:
if hasattr(bind, "close"):
bind.close()
+ sess.close()
def test_session_unbound(self):
Foo = self.classes.Foo
diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py
index 3d09bd446..2a0aafbbc 100644
--- a/test/orm/test_collection.py
+++ b/test/orm/test_collection.py
@@ -92,13 +92,12 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
return str((id(self), self.a, self.b, self.c))
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
instrumentation.register_class(cls.Entity)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
instrumentation.unregister_class(cls.Entity)
- super(CollectionsTest, cls).teardown_class()
_entity_id = 1
diff --git a/test/orm/test_compile.py b/test/orm/test_compile.py
index df652daf4..20d8ecc2d 100644
--- a/test/orm/test_compile.py
+++ b/test/orm/test_compile.py
@@ -19,7 +19,7 @@ from sqlalchemy.testing import fixtures
class CompileTest(fixtures.ORMTest):
"""test various mapper compilation scenarios"""
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def test_with_polymorphic(self):
diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py
index e1ef67fed..ed11b89c9 100644
--- a/test/orm/test_cycles.py
+++ b/test/orm/test_cycles.py
@@ -1743,8 +1743,7 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest):
id = Column(Integer, primary_key=True)
a_id = Column(ForeignKey("a.id", name="a_fk"))
- def setup(self):
- super(PostUpdateOnUpdateTest, self).setup()
+ def setup_test(self):
PostUpdateOnUpdateTest.counter = count()
PostUpdateOnUpdateTest.db_counter = count()
diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py
index 6d946cfe6..15063ebe9 100644
--- a/test/orm/test_deprecations.py
+++ b/test/orm/test_deprecations.py
@@ -2199,6 +2199,7 @@ class SessionTest(fixtures.RemovesEvents, _LocalFixture):
class AutocommitClosesOnFailTest(fixtures.MappedTest):
__requires__ = ("deferrable_fks",)
+ __only_on__ = ("postgresql+psycopg2",) # needs #5824 for asyncpg
@classmethod
def define_tables(cls, metadata):
@@ -4498,44 +4499,49 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
warnings += (join_aliased_dep,)
# load a user who has an order that contains item id 3 and address
# id 1 (order 3, owned by jack)
- with testing.expect_deprecated_20(*warnings):
- result = (
- fixture_session()
- .query(User)
- .join("orders", "items", aliased=aliased_)
- .filter_by(id=3)
- .reset_joinpoint()
- .join("orders", "address", aliased=aliased_)
- .filter_by(id=1)
- .all()
- )
- assert [User(id=7, name="jack")] == result
- with testing.expect_deprecated_20(*warnings):
- result = (
- fixture_session()
- .query(User)
- .join("orders", "items", aliased=aliased_, isouter=True)
- .filter_by(id=3)
- .reset_joinpoint()
- .join("orders", "address", aliased=aliased_, isouter=True)
- .filter_by(id=1)
- .all()
- )
- assert [User(id=7, name="jack")] == result
-
- with testing.expect_deprecated_20(*warnings):
- result = (
- fixture_session()
- .query(User)
- .outerjoin("orders", "items", aliased=aliased_)
- .filter_by(id=3)
- .reset_joinpoint()
- .outerjoin("orders", "address", aliased=aliased_)
- .filter_by(id=1)
- .all()
- )
- assert [User(id=7, name="jack")] == result
+ with fixture_session() as sess:
+ with testing.expect_deprecated_20(*warnings):
+ result = (
+ sess.query(User)
+ .join("orders", "items", aliased=aliased_)
+ .filter_by(id=3)
+ .reset_joinpoint()
+ .join("orders", "address", aliased=aliased_)
+ .filter_by(id=1)
+ .all()
+ )
+ assert [User(id=7, name="jack")] == result
+
+ with fixture_session() as sess:
+ with testing.expect_deprecated_20(*warnings):
+ result = (
+ sess.query(User)
+ .join(
+ "orders", "items", aliased=aliased_, isouter=True
+ )
+ .filter_by(id=3)
+ .reset_joinpoint()
+ .join(
+ "orders", "address", aliased=aliased_, isouter=True
+ )
+ .filter_by(id=1)
+ .all()
+ )
+ assert [User(id=7, name="jack")] == result
+
+ with fixture_session() as sess:
+ with testing.expect_deprecated_20(*warnings):
+ result = (
+ sess.query(User)
+ .outerjoin("orders", "items", aliased=aliased_)
+ .filter_by(id=3)
+ .reset_joinpoint()
+ .outerjoin("orders", "address", aliased=aliased_)
+ .filter_by(id=1)
+ .all()
+ )
+ assert [User(id=7, name="jack")] == result
class AliasFromCorrectLeftTest(
diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py
index 4498fc1ff..7eedb37c9 100644
--- a/test/orm/test_eager_relations.py
+++ b/test/orm/test_eager_relations.py
@@ -559,15 +559,15 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
5,
),
]:
- sess = fixture_session()
+ with fixture_session() as sess:
- def go():
- eq_(
- sess.query(User).options(*opt).order_by(User.id).all(),
- self.static.user_item_keyword_result,
- )
+ def go():
+ eq_(
+ sess.query(User).options(*opt).order_by(User.id).all(),
+ self.static.user_item_keyword_result,
+ )
- self.assert_sql_count(testing.db, go, count)
+ self.assert_sql_count(testing.db, go, count)
def test_disable_dynamic(self):
"""test no joined option on a dynamic."""
diff --git a/test/orm/test_events.py b/test/orm/test_events.py
index 1c918a88c..e85c23d6f 100644
--- a/test/orm/test_events.py
+++ b/test/orm/test_events.py
@@ -45,13 +45,14 @@ from test.orm import _fixtures
class _RemoveListeners(object):
- def teardown(self):
+ @testing.fixture(autouse=True)
+ def _remove_listeners(self):
+ yield
events.MapperEvents._clear()
events.InstanceEvents._clear()
events.SessionEvents._clear()
events.InstrumentationEvents._clear()
events.QueryEvents._clear()
- super(_RemoveListeners, self).teardown()
class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
@@ -1174,7 +1175,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest):
argnames="target, event_name, fn",
)(fn)
- def teardown(self):
+ def teardown_test(self):
A = self.classes.A
A._sa_class_manager.dispatch._clear()
diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py
index cc9596466..f622bff02 100644
--- a/test/orm/test_froms.py
+++ b/test/orm/test_froms.py
@@ -2395,16 +2395,19 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
]
adalias = addresses.alias()
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name))
- .outerjoin(adalias, "addresses")
- .group_by(users)
- .order_by(users.c.id)
- )
- assert q.all() == expected
+ with fixture_session() as sess:
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(adalias.c.id), ("Name:" + users.c.name)
+ )
+ .outerjoin(adalias, "addresses")
+ .group_by(users)
+ .order_by(users.c.id)
+ )
+
+ eq_(q.all(), expected)
# test with a straight statement
s = (
@@ -2417,52 +2420,57 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
.group_by(*[c for c in users.c])
.order_by(users.c.id)
)
- q = fixture_session().query(User)
- result = (
- q.add_columns(s.selected_columns.count, s.selected_columns.concat)
- .from_statement(s)
- .all()
- )
- assert result == expected
-
- sess.expunge_all()
- # test with select_entity_from()
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name))
- .select_entity_from(users.outerjoin(addresses))
- .group_by(users)
- .order_by(users.c.id)
- )
+ with fixture_session() as sess:
+ q = sess.query(User)
+ result = (
+ q.add_columns(
+ s.selected_columns.count, s.selected_columns.concat
+ )
+ .from_statement(s)
+ .all()
+ )
+ eq_(result, expected)
- assert q.all() == expected
- sess.expunge_all()
+ with fixture_session() as sess:
+ # test with select_entity_from()
+ q = (
+ fixture_session()
+ .query(User)
+ .add_columns(
+ func.count(addresses.c.id), ("Name:" + users.c.name)
+ )
+ .select_entity_from(users.outerjoin(addresses))
+ .group_by(users)
+ .order_by(users.c.id)
+ )
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name))
- .outerjoin("addresses")
- .group_by(users)
- .order_by(users.c.id)
- )
+ eq_(q.all(), expected)
- assert q.all() == expected
- sess.expunge_all()
+ with fixture_session() as sess:
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(addresses.c.id), ("Name:" + users.c.name)
+ )
+ .outerjoin("addresses")
+ .group_by(users)
+ .order_by(users.c.id)
+ )
+ eq_(q.all(), expected)
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name))
- .outerjoin(adalias, "addresses")
- .group_by(users)
- .order_by(users.c.id)
- )
+ with fixture_session() as sess:
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(adalias.c.id), ("Name:" + users.c.name)
+ )
+ .outerjoin(adalias, "addresses")
+ .group_by(users)
+ .order_by(users.c.id)
+ )
- assert q.all() == expected
- sess.expunge_all()
+ eq_(q.all(), expected)
def test_expression_selectable_matches_mzero(self):
User, Address = self.classes.User, self.classes.Address
diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py
index 3061de309..43cf81e6d 100644
--- a/test/orm/test_lazy_relations.py
+++ b/test/orm/test_lazy_relations.py
@@ -717,24 +717,24 @@ class LazyTest(_fixtures.FixtureTest):
),
)
- sess = fixture_session()
+ with fixture_session() as sess:
- # load address
- a1 = (
- sess.query(Address)
- .filter_by(email_address="ed@wood.com")
- .one()
- )
+ # load address
+ a1 = (
+ sess.query(Address)
+ .filter_by(email_address="ed@wood.com")
+ .one()
+ )
- # load user that is attached to the address
- u1 = sess.query(User).get(8)
+ # load user that is attached to the address
+ u1 = sess.query(User).get(8)
- def go():
- # lazy load of a1.user should get it from the session
- assert a1.user is u1
+ def go():
+ # lazy load of a1.user should get it from the session
+ assert a1.user is u1
- self.assert_sql_count(testing.db, go, 0)
- sa.orm.clear_mappers()
+ self.assert_sql_count(testing.db, go, 0)
+ sa.orm.clear_mappers()
def test_uses_get_compatible_types(self):
"""test the use_get optimization with compatible
@@ -789,24 +789,23 @@ class LazyTest(_fixtures.FixtureTest):
properties=dict(user=relationship(mapper(User, users))),
)
- sess = fixture_session()
-
- # load address
- a1 = (
- sess.query(Address)
- .filter_by(email_address="ed@wood.com")
- .one()
- )
+ with fixture_session() as sess:
+ # load address
+ a1 = (
+ sess.query(Address)
+ .filter_by(email_address="ed@wood.com")
+ .one()
+ )
- # load user that is attached to the address
- u1 = sess.query(User).get(8)
+ # load user that is attached to the address
+ u1 = sess.query(User).get(8)
- def go():
- # lazy load of a1.user should get it from the session
- assert a1.user is u1
+ def go():
+ # lazy load of a1.user should get it from the session
+ assert a1.user is u1
- self.assert_sql_count(testing.db, go, 0)
- sa.orm.clear_mappers()
+ self.assert_sql_count(testing.db, go, 0)
+ sa.orm.clear_mappers()
def test_many_to_one(self):
users, Address, addresses, User = (
diff --git a/test/orm/test_load_on_fks.py b/test/orm/test_load_on_fks.py
index 0e8ac97e3..42b5b3e45 100644
--- a/test/orm/test_load_on_fks.py
+++ b/test/orm/test_load_on_fks.py
@@ -9,14 +9,12 @@ from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import instance_state
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
-engine = testing.db
-
-
class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
global Parent, Child, Base
Base = declarative_base()
@@ -36,27 +34,27 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
)
parent_id = Column(Integer, ForeignKey("parent.id"))
- Base.metadata.create_all(engine)
+ Base.metadata.create_all(testing.db)
- def tearDown(self):
- Base.metadata.drop_all(engine)
+ def teardown_test(self):
+ Base.metadata.drop_all(testing.db)
def test_annoying_autoflush_one(self):
- sess = Session(engine)
+ sess = fixture_session()
p1 = Parent()
sess.add(p1)
p1.children = []
def test_annoying_autoflush_two(self):
- sess = Session(engine)
+ sess = fixture_session()
p1 = Parent()
sess.add(p1)
assert p1.children == []
def test_dont_load_if_no_keys(self):
- sess = Session(engine)
+ sess = fixture_session()
p1 = Parent()
sess.add(p1)
@@ -68,7 +66,9 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
- def setUp(self):
+ __leave_connections_for_teardown__ = True
+
+ def setup_test(self):
global Parent, Child, Base
Base = declarative_base()
@@ -91,10 +91,10 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
parent = relationship(Parent, backref=backref("children"))
- Base.metadata.create_all(engine)
+ Base.metadata.create_all(testing.db)
global sess, p1, p2, c1, c2
- sess = Session(bind=engine)
+ sess = Session(bind=testing.db)
p1 = Parent()
p2 = Parent()
@@ -105,9 +105,9 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
sess.commit()
- def tearDown(self):
+ def teardown_test(self):
sess.rollback()
- Base.metadata.drop_all(engine)
+ Base.metadata.drop_all(testing.db)
def test_load_on_pending_allows_backref_event(self):
Child.parent.property.load_on_pending = True
diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py
index 013eb21e1..d182fd2c1 100644
--- a/test/orm/test_mapper.py
+++ b/test/orm/test_mapper.py
@@ -2560,7 +2560,7 @@ class MagicNamesTest(fixtures.MappedTest):
class DocumentTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.mapper = registry().map_imperatively
@@ -2624,14 +2624,14 @@ class DocumentTest(fixtures.TestBase):
class ORMLoggingTest(_fixtures.FixtureTest):
- def setup(self):
+ def setup_test(self):
self.buf = logging.handlers.BufferingHandler(100)
for log in [logging.getLogger("sqlalchemy.orm")]:
log.addHandler(self.buf)
self.mapper = registry().map_imperatively
- def teardown(self):
+ def teardown_test(self):
for log in [logging.getLogger("sqlalchemy.orm")]:
log.removeHandler(self.buf)
diff --git a/test/orm/test_options.py b/test/orm/test_options.py
index b22b318e9..6f47c1238 100644
--- a/test/orm/test_options.py
+++ b/test/orm/test_options.py
@@ -1523,9 +1523,7 @@ class PickleTest(PathTest, QueryTest):
class LocalOptsTest(PathTest, QueryTest):
@classmethod
- def setup_class(cls):
- super(LocalOptsTest, cls).setup_class()
-
+ def setup_test_class(cls):
@strategy_options.loader_option()
def some_col_opt_only(loadopt, key, opts):
return loadopt.set_column_strategy(
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index fd8e849fb..7546ba162 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -6000,12 +6000,19 @@ class SynonymTest(QueryTest, AssertsCompiledSQL):
[User.orders_syn, Order.items_syn],
[User.orders_syn_2, Order.items_syn],
):
- q = fixture_session().query(User)
- for path in j:
- q = q.join(path)
- q = q.filter_by(id=3)
- result = q.all()
- assert [User(id=7, name="jack"), User(id=9, name="fred")] == result
+ with fixture_session() as sess:
+ q = sess.query(User)
+ for path in j:
+ q = q.join(path)
+ q = q.filter_by(id=3)
+ result = q.all()
+ eq_(
+ result,
+ [
+ User(id=7, name="jack"),
+ User(id=9, name="fred"),
+ ],
+ )
def test_with_parent(self):
Order, User = self.classes.Order, self.classes.User
@@ -6018,17 +6025,17 @@ class SynonymTest(QueryTest, AssertsCompiledSQL):
("name_syn", "orders_syn"),
("name_syn", "orders_syn_2"),
):
- sess = fixture_session()
- q = sess.query(User)
+ with fixture_session() as sess:
+ q = sess.query(User)
- u1 = q.filter_by(**{nameprop: "jack"}).one()
+ u1 = q.filter_by(**{nameprop: "jack"}).one()
- o = sess.query(Order).with_parent(u1, property=orderprop).all()
- assert [
- Order(description="order 1"),
- Order(description="order 3"),
- Order(description="order 5"),
- ] == o
+ o = sess.query(Order).with_parent(u1, property=orderprop).all()
+ assert [
+ Order(description="order 1"),
+ Order(description="order 3"),
+ Order(description="order 5"),
+ ] == o
def test_froms_aliased_col(self):
Address, User = self.classes.Address, self.classes.User
diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py
index 12c084b2d..ef1bf2e60 100644
--- a/test/orm/test_rel_fn.py
+++ b/test/orm/test_rel_fn.py
@@ -26,7 +26,7 @@ from sqlalchemy.testing import mock
class _JoinFixtures(object):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
m = MetaData()
cls.left = Table(
"lft",
diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py
index 5979f08ae..8d73cd40e 100644
--- a/test/orm/test_relationships.py
+++ b/test/orm/test_relationships.py
@@ -637,7 +637,7 @@ class OverlappingFksSiblingTest(fixtures.TestBase):
"""
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def _fixture_one(
@@ -2474,7 +2474,7 @@ class JoinConditionErrorTest(fixtures.TestBase):
assert_raises(sa.exc.ArgumentError, configure_mappers)
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
@@ -4354,7 +4354,7 @@ class AmbiguousFKResolutionTest(_RelationshipErrors, fixtures.MappedTest):
class SecondaryArgTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
@testing.combinations((True,), (False,))
diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py
index 5535fe5d6..4895c7d3a 100644
--- a/test/orm/test_selectin_relations.py
+++ b/test/orm/test_selectin_relations.py
@@ -713,35 +713,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
def _do_query_tests(self, opts, count):
Order, User = self.classes.Order, self.classes.User
- sess = fixture_session()
+ with fixture_session() as sess:
- def go():
- eq_(
- sess.query(User).options(*opts).order_by(User.id).all(),
- self.static.user_item_keyword_result,
- )
+ def go():
+ eq_(
+ sess.query(User).options(*opts).order_by(User.id).all(),
+ self.static.user_item_keyword_result,
+ )
- self.assert_sql_count(testing.db, go, count)
+ self.assert_sql_count(testing.db, go, count)
- eq_(
- sess.query(User)
- .options(*opts)
- .filter(User.name == "fred")
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[2:3],
- )
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .filter(User.name == "fred")
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[2:3],
+ )
- sess = fixture_session()
- eq_(
- sess.query(User)
- .options(*opts)
- .join(User.orders)
- .filter(Order.id == 3)
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[0:1],
- )
+ with fixture_session() as sess:
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .join(User.orders)
+ .filter(Order.id == 3)
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[0:1],
+ )
def test_cyclical(self):
"""A circular eager relationship breaks the cycle with a lazy loader"""
diff --git a/test/orm/test_session.py b/test/orm/test_session.py
index 20c4752b8..3d4566af3 100644
--- a/test/orm/test_session.py
+++ b/test/orm/test_session.py
@@ -1273,7 +1273,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
run_inserts = None
- def setup(self):
+ def setup_test(self):
mapper(self.classes.User, self.tables.users)
def _assert_modified(self, u1):
@@ -1288,11 +1288,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
def _assert_no_cycle(self, u1):
assert sa.orm.attributes.instance_state(u1)._strong_obj is None
- def _persistent_fixture(self):
+ def _persistent_fixture(self, gc_collect=False):
User = self.classes.User
u1 = User()
u1.name = "ed"
- sess = fixture_session()
+ if gc_collect:
+ sess = Session(testing.db)
+ else:
+ sess = fixture_session()
sess.add(u1)
sess.flush()
return sess, u1
@@ -1389,14 +1392,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
@testing.requires.predictable_gc
def test_move_gc_session_persistent_dirty(self):
- sess, u1 = self._persistent_fixture()
+ sess, u1 = self._persistent_fixture(gc_collect=True)
u1.name = "edchanged"
self._assert_cycle(u1)
self._assert_modified(u1)
del sess
gc_collect()
self._assert_cycle(u1)
- s2 = fixture_session()
+ s2 = Session(testing.db)
s2.add(u1)
self._assert_cycle(u1)
self._assert_modified(u1)
@@ -1565,7 +1568,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest):
mapper(User, users)
- sess = fixture_session()
+ sess = Session(testing.db)
u1 = User(name="u1")
sess.add(u1)
@@ -1573,7 +1576,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest):
# can't add u1 to Session,
# already belongs to u2
- s2 = fixture_session()
+ s2 = Session(testing.db)
assert_raises_message(
sa.exc.InvalidRequestError,
r".*is already attached to session",
@@ -1725,11 +1728,10 @@ class DisposedStates(fixtures.MappedTest):
mapper(T, cls.tables.t1)
- def teardown(self):
+ def teardown_test(self):
from sqlalchemy.orm.session import _sessions
_sessions.clear()
- super(DisposedStates, self).teardown()
def _set_imap_in_disposal(self, sess, *objs):
"""remove selected objects from the given session, as though
diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py
index fe20442a3..150cee222 100644
--- a/test/orm/test_subquery_relations.py
+++ b/test/orm/test_subquery_relations.py
@@ -734,35 +734,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
def _do_query_tests(self, opts, count):
Order, User = self.classes.Order, self.classes.User
- sess = fixture_session()
+ with fixture_session() as sess:
- def go():
- eq_(
- sess.query(User).options(*opts).order_by(User.id).all(),
- self.static.user_item_keyword_result,
- )
+ def go():
+ eq_(
+ sess.query(User).options(*opts).order_by(User.id).all(),
+ self.static.user_item_keyword_result,
+ )
- self.assert_sql_count(testing.db, go, count)
+ self.assert_sql_count(testing.db, go, count)
- eq_(
- sess.query(User)
- .options(*opts)
- .filter(User.name == "fred")
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[2:3],
- )
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .filter(User.name == "fred")
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[2:3],
+ )
- sess = fixture_session()
- eq_(
- sess.query(User)
- .options(*opts)
- .join(User.orders)
- .filter(Order.id == 3)
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[0:1],
- )
+ with fixture_session() as sess:
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .join(User.orders)
+ .filter(Order.id == 3)
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[0:1],
+ )
def test_cyclical(self):
"""A circular eager relationship breaks the cycle with a lazy loader"""
diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py
index 550cf6535..7f77b01c7 100644
--- a/test/orm/test_transaction.py
+++ b/test/orm/test_transaction.py
@@ -66,17 +66,18 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
users, User = self.tables.users, self.classes.User
mapper(User, users)
- conn = testing.db.connect()
- trans = conn.begin()
- sess = Session(bind=conn, autocommit=False, autoflush=True)
- sess.begin(subtransactions=True)
- u = User(name="ed")
- sess.add(u)
- sess.flush()
- sess.commit() # commit does nothing
- trans.rollback() # rolls back
- assert len(sess.query(User).all()) == 0
- sess.close()
+
+ with testing.db.connect() as conn:
+ trans = conn.begin()
+ sess = Session(bind=conn, autocommit=False, autoflush=True)
+ sess.begin(subtransactions=True)
+ u = User(name="ed")
+ sess.add(u)
+ sess.flush()
+ sess.commit() # commit does nothing
+ trans.rollback() # rolls back
+ assert len(sess.query(User).all()) == 0
+ sess.close()
@engines.close_open_connections
def test_subtransaction_on_external_no_begin(self):
@@ -260,34 +261,33 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
users = self.tables.users
engine = Engine._future_facade(testing.db)
- session = Session(engine, autocommit=False)
-
- session.begin()
- session.connection().execute(users.insert().values(name="user1"))
- session.begin(subtransactions=True)
- session.begin_nested()
- session.connection().execute(users.insert().values(name="user2"))
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 2
- )
- session.rollback()
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 1
- )
- session.connection().execute(users.insert().values(name="user3"))
- session.commit()
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 2
- )
+ with Session(engine, autocommit=False) as session:
+ session.begin()
+ session.connection().execute(users.insert().values(name="user1"))
+ session.begin(subtransactions=True)
+ session.begin_nested()
+ session.connection().execute(users.insert().values(name="user2"))
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 2
+ )
+ session.rollback()
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 1
+ )
+ session.connection().execute(users.insert().values(name="user3"))
+ session.commit()
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 2
+ )
@testing.requires.savepoints
def test_dirty_state_transferred_deep_nesting(self):
@@ -295,27 +295,27 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
mapper(User, users)
- s = Session(testing.db)
- u1 = User(name="u1")
- s.add(u1)
- s.commit()
-
- nt1 = s.begin_nested()
- nt2 = s.begin_nested()
- u1.name = "u2"
- assert attributes.instance_state(u1) not in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
- s.flush()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
+ with fixture_session() as s:
+ u1 = User(name="u1")
+ s.add(u1)
+ s.commit()
+
+ nt1 = s.begin_nested()
+ nt2 = s.begin_nested()
+ u1.name = "u2"
+ assert attributes.instance_state(u1) not in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
+ s.flush()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
- s.commit()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) in nt1._dirty
+ s.commit()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) in nt1._dirty
- s.rollback()
- assert attributes.instance_state(u1).expired
- eq_(u1.name, "u1")
+ s.rollback()
+ assert attributes.instance_state(u1).expired
+ eq_(u1.name, "u1")
@testing.requires.savepoints
def test_dirty_state_transferred_deep_nesting_future(self):
@@ -323,27 +323,27 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
mapper(User, users)
- s = Session(testing.db, future=True)
- u1 = User(name="u1")
- s.add(u1)
- s.commit()
-
- nt1 = s.begin_nested()
- nt2 = s.begin_nested()
- u1.name = "u2"
- assert attributes.instance_state(u1) not in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
- s.flush()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
+ with fixture_session(future=True) as s:
+ u1 = User(name="u1")
+ s.add(u1)
+ s.commit()
+
+ nt1 = s.begin_nested()
+ nt2 = s.begin_nested()
+ u1.name = "u2"
+ assert attributes.instance_state(u1) not in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
+ s.flush()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
- nt2.commit()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) in nt1._dirty
+ nt2.commit()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) in nt1._dirty
- nt1.rollback()
- assert attributes.instance_state(u1).expired
- eq_(u1.name, "u1")
+ nt1.rollback()
+ assert attributes.instance_state(u1).expired
+ eq_(u1.name, "u1")
@testing.requires.independent_connections
def test_transactions_isolated(self):
@@ -1049,23 +1049,25 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
mapper(User, users)
- session = Session(testing.db)
+ with fixture_session() as session:
- with expect_warnings(".*during handling of a previous exception.*"):
- session.begin_nested()
- savepoint = session.connection()._nested_transaction._savepoint
+ with expect_warnings(
+ ".*during handling of a previous exception.*"
+ ):
+ session.begin_nested()
+ savepoint = session.connection()._nested_transaction._savepoint
- # force the savepoint to disappear
- session.connection().dialect.do_release_savepoint(
- session.connection(), savepoint
- )
+ # force the savepoint to disappear
+ session.connection().dialect.do_release_savepoint(
+ session.connection(), savepoint
+ )
- # now do a broken flush
- session.add_all([User(id=1), User(id=1)])
+ # now do a broken flush
+ session.add_all([User(id=1), User(id=1)])
- assert_raises_message(
- sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush
- )
+ assert_raises_message(
+ sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush
+ )
class _LocalFixture(FixtureTest):
@@ -1170,39 +1172,40 @@ class SubtransactionRecipeTest(FixtureTest):
def test_recipe_heavy_nesting(self, subtransaction_recipe):
users = self.tables.users
- session = Session(testing.db, future=self.future)
-
- with subtransaction_recipe(session):
- session.connection().execute(users.insert().values(name="user1"))
+ with fixture_session(future=self.future) as session:
with subtransaction_recipe(session):
- savepoint = session.begin_nested()
session.connection().execute(
- users.insert().values(name="user2")
+ users.insert().values(name="user1")
)
+ with subtransaction_recipe(session):
+ savepoint = session.begin_nested()
+ session.connection().execute(
+ users.insert().values(name="user2")
+ )
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 2
+ )
+ savepoint.rollback()
+
+ with subtransaction_recipe(session):
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 1
+ )
+ session.connection().execute(
+ users.insert().values(name="user3")
+ )
assert (
session.connection()
.exec_driver_sql("select count(1) from users")
.scalar()
== 2
)
- savepoint.rollback()
-
- with subtransaction_recipe(session):
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 1
- )
- session.connection().execute(
- users.insert().values(name="user3")
- )
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 2
- )
@engines.close_open_connections
def test_recipe_subtransaction_on_external_subtrans(
@@ -1228,13 +1231,12 @@ class SubtransactionRecipeTest(FixtureTest):
User, users = self.classes.User, self.tables.users
mapper(User, users)
- sess = Session(testing.db, future=self.future)
-
- with subtransaction_recipe(sess):
- u = User(name="u1")
- sess.add(u)
- sess.close()
- assert len(sess.query(User).all()) == 1
+ with fixture_session(future=self.future) as sess:
+ with subtransaction_recipe(sess):
+ u = User(name="u1")
+ sess.add(u)
+ sess.close()
+ assert len(sess.query(User).all()) == 1
def test_recipe_subtransaction_on_noautocommit(
self, subtransaction_recipe
@@ -1242,16 +1244,15 @@ class SubtransactionRecipeTest(FixtureTest):
User, users = self.classes.User, self.tables.users
mapper(User, users)
- sess = Session(testing.db, future=self.future)
-
- sess.begin()
- with subtransaction_recipe(sess):
- u = User(name="u1")
- sess.add(u)
- sess.flush()
- sess.rollback() # rolls back
- assert len(sess.query(User).all()) == 0
- sess.close()
+ with fixture_session(future=self.future) as sess:
+ sess.begin()
+ with subtransaction_recipe(sess):
+ u = User(name="u1")
+ sess.add(u)
+ sess.flush()
+ sess.rollback() # rolls back
+ assert len(sess.query(User).all()) == 0
+ sess.close()
@testing.requires.savepoints
def test_recipe_mixed_transaction_control(self, subtransaction_recipe):
@@ -1259,30 +1260,28 @@ class SubtransactionRecipeTest(FixtureTest):
mapper(User, users)
- sess = Session(testing.db, future=self.future)
+ with fixture_session(future=self.future) as sess:
- sess.begin()
- sess.begin_nested()
+ sess.begin()
+ sess.begin_nested()
- with subtransaction_recipe(sess):
+ with subtransaction_recipe(sess):
- sess.add(User(name="u1"))
+ sess.add(User(name="u1"))
- sess.commit()
- sess.commit()
+ sess.commit()
+ sess.commit()
- eq_(len(sess.query(User).all()), 1)
- sess.close()
+ eq_(len(sess.query(User).all()), 1)
+ sess.close()
- t1 = sess.begin()
- t2 = sess.begin_nested()
-
- sess.add(User(name="u2"))
+ t1 = sess.begin()
+ t2 = sess.begin_nested()
- t2.commit()
- assert sess._legacy_transaction() is t1
+ sess.add(User(name="u2"))
- sess.close()
+ t2.commit()
+ assert sess._legacy_transaction() is t1
def test_recipe_error_on_using_inactive_session_commands(
self, subtransaction_recipe
@@ -1290,56 +1289,55 @@ class SubtransactionRecipeTest(FixtureTest):
users, User = self.tables.users, self.classes.User
mapper(User, users)
- sess = Session(testing.db, future=self.future)
- sess.begin()
-
- try:
- with subtransaction_recipe(sess):
- sess.add(User(name="u1"))
- sess.flush()
- raise Exception("force rollback")
- except:
- pass
-
- if self.recipe_rollsback_early:
- # that was a real rollback, so no transaction
- assert not sess.in_transaction()
- is_(sess.get_transaction(), None)
- else:
- assert sess.in_transaction()
-
- sess.close()
- assert not sess.in_transaction()
-
- def test_recipe_multi_nesting(self, subtransaction_recipe):
- sess = Session(testing.db, future=self.future)
-
- with subtransaction_recipe(sess):
- assert sess.in_transaction()
+ with fixture_session(future=self.future) as sess:
+ sess.begin()
try:
with subtransaction_recipe(sess):
- assert sess._legacy_transaction()
+ sess.add(User(name="u1"))
+ sess.flush()
raise Exception("force rollback")
except:
pass
if self.recipe_rollsback_early:
+ # that was a real rollback, so no transaction
assert not sess.in_transaction()
+ is_(sess.get_transaction(), None)
else:
assert sess.in_transaction()
- assert not sess.in_transaction()
+ sess.close()
+ assert not sess.in_transaction()
+
+ def test_recipe_multi_nesting(self, subtransaction_recipe):
+ with fixture_session(future=self.future) as sess:
+ with subtransaction_recipe(sess):
+ assert sess.in_transaction()
+
+ try:
+ with subtransaction_recipe(sess):
+ assert sess._legacy_transaction()
+ raise Exception("force rollback")
+ except:
+ pass
+
+ if self.recipe_rollsback_early:
+ assert not sess.in_transaction()
+ else:
+ assert sess.in_transaction()
+
+ assert not sess.in_transaction()
def test_recipe_deactive_status_check(self, subtransaction_recipe):
- sess = Session(testing.db, future=self.future)
- sess.begin()
+ with fixture_session(future=self.future) as sess:
+ sess.begin()
- with subtransaction_recipe(sess):
- sess.rollback()
+ with subtransaction_recipe(sess):
+ sess.rollback()
- assert not sess.in_transaction()
- sess.commit() # no error
+ assert not sess.in_transaction()
+ sess.commit() # no error
class FixtureDataTest(_LocalFixture):
@@ -1394,28 +1392,28 @@ class CleanSavepointTest(FixtureTest):
mapper(User, users)
- s = Session(bind=testing.db, future=future)
- u1 = User(name="u1")
- u2 = User(name="u2")
- s.add_all([u1, u2])
- s.commit()
- u1.name
- u2.name
- trans = s._transaction
- assert trans is not None
- s.begin_nested()
- update_fn(s, u2)
- eq_(u2.name, "u2modified")
- s.rollback()
+ with fixture_session(future=future) as s:
+ u1 = User(name="u1")
+ u2 = User(name="u2")
+ s.add_all([u1, u2])
+ s.commit()
+ u1.name
+ u2.name
+ trans = s._transaction
+ assert trans is not None
+ s.begin_nested()
+ update_fn(s, u2)
+ eq_(u2.name, "u2modified")
+ s.rollback()
- if future:
- assert s._transaction is None
- assert "name" not in u1.__dict__
- else:
- assert s._transaction is trans
- eq_(u1.__dict__["name"], "u1")
- assert "name" not in u2.__dict__
- eq_(u2.name, "u2")
+ if future:
+ assert s._transaction is None
+ assert "name" not in u1.__dict__
+ else:
+ assert s._transaction is trans
+ eq_(u1.__dict__["name"], "u1")
+ assert "name" not in u2.__dict__
+ eq_(u2.name, "u2")
@testing.requires.savepoints
def test_rollback_ignores_clean_on_savepoint(self):
@@ -2116,82 +2114,108 @@ class ContextManagerPlusFutureTest(FixtureTest):
eq_(sess.query(User).count(), 1)
def test_explicit_begin(self):
- s1 = Session(testing.db)
- with s1.begin() as trans:
- is_(trans, s1._legacy_transaction())
- s1.connection()
+ with fixture_session() as s1:
+ with s1.begin() as trans:
+ is_(trans, s1._legacy_transaction())
+ s1.connection()
- is_(s1._transaction, None)
+ is_(s1._transaction, None)
def test_no_double_begin_explicit(self):
- s1 = Session(testing.db)
- s1.begin()
- assert_raises_message(
- sa_exc.InvalidRequestError,
- "A transaction is already begun on this Session.",
- s1.begin,
- )
+ with fixture_session() as s1:
+ s1.begin()
+ assert_raises_message(
+ sa_exc.InvalidRequestError,
+ "A transaction is already begun on this Session.",
+ s1.begin,
+ )
@testing.requires.savepoints
def test_future_rollback_is_global(self):
users = self.tables.users
- s1 = Session(testing.db, future=True)
+ with fixture_session(future=True) as s1:
+ s1.begin()
- s1.begin()
+ s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
- s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
+ s1.begin_nested()
- s1.begin_nested()
-
- s1.connection().execute(
- users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}]
- )
+ s1.connection().execute(
+ users.insert(),
+ [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}],
+ )
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 3,
+ )
- # rolls back the whole transaction
- s1.rollback()
- is_(s1._legacy_transaction(), None)
+ # rolls back the whole transaction
+ s1.rollback()
+ is_(s1._legacy_transaction(), None)
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 0)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 0,
+ )
- s1.commit()
- is_(s1._legacy_transaction(), None)
+ s1.commit()
+ is_(s1._legacy_transaction(), None)
@testing.requires.savepoints
def test_old_rollback_is_local(self):
users = self.tables.users
- s1 = Session(testing.db)
+ with fixture_session() as s1:
- t1 = s1.begin()
+ t1 = s1.begin()
- s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
+ s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
- s1.begin_nested()
+ s1.begin_nested()
- s1.connection().execute(
- users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}]
- )
+ s1.connection().execute(
+ users.insert(),
+ [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}],
+ )
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 3,
+ )
- # rolls back only the savepoint
- s1.rollback()
+ # rolls back only the savepoint
+ s1.rollback()
- is_(s1._legacy_transaction(), t1)
+ is_(s1._legacy_transaction(), t1)
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 1,
+ )
- s1.commit()
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1)
- is_not(s1._legacy_transaction(), None)
+ s1.commit()
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 1,
+ )
+ is_not(s1._legacy_transaction(), None)
def test_session_as_ctx_manager_one(self):
users = self.tables.users
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
is_not(sess._legacy_transaction(), None)
sess.connection().execute(
@@ -2212,7 +2236,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
def test_session_as_ctx_manager_future_one(self):
users = self.tables.users
- with Session(testing.db, future=True) as sess:
+ with fixture_session(future=True) as sess:
is_(sess._legacy_transaction(), None)
sess.connection().execute(
@@ -2234,7 +2258,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
users = self.tables.users
try:
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
is_not(sess._legacy_transaction(), None)
sess.connection().execute(
@@ -2250,7 +2274,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
users = self.tables.users
try:
- with Session(testing.db, future=True) as sess:
+ with fixture_session(future=True) as sess:
is_(sess._legacy_transaction(), None)
sess.connection().execute(
@@ -2265,7 +2289,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
def test_begin_context_manager(self):
users = self.tables.users
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
with sess.begin():
sess.connection().execute(
users.insert().values(id=1, name="user1")
@@ -2296,12 +2320,13 @@ class ContextManagerPlusFutureTest(FixtureTest):
# committed
eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+ sess.close()
def test_begin_context_manager_rollback_trans(self):
users = self.tables.users
try:
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
with sess.begin():
sess.connection().execute(
users.insert().values(id=1, name="user1")
@@ -2318,12 +2343,13 @@ class ContextManagerPlusFutureTest(FixtureTest):
# rolled back
eq_(sess.connection().execute(users.select()).all(), [])
+ sess.close()
def test_begin_context_manager_rollback_outer(self):
users = self.tables.users
try:
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
with sess.begin():
sess.connection().execute(
users.insert().values(id=1, name="user1")
@@ -2340,6 +2366,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
# committed
eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+ sess.close()
def test_sessionmaker_begin_context_manager_rollback_trans(self):
users = self.tables.users
@@ -2363,6 +2390,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
# rolled back
eq_(sess.connection().execute(users.select()).all(), [])
+ sess.close()
def test_sessionmaker_begin_context_manager_rollback_outer(self):
users = self.tables.users
@@ -2386,36 +2414,37 @@ class ContextManagerPlusFutureTest(FixtureTest):
# committed
eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+ sess.close()
class TransactionFlagsTest(fixtures.TestBase):
def test_in_transaction(self):
- s1 = Session(testing.db)
+ with fixture_session() as s1:
- eq_(s1.in_transaction(), False)
+ eq_(s1.in_transaction(), False)
- trans = s1.begin()
+ trans = s1.begin()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
- n1 = s1.begin_nested()
+ n1 = s1.begin_nested()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), n1)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), n1)
- n1.rollback()
+ n1.rollback()
- is_(s1.get_nested_transaction(), None)
- is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), None)
+ is_(s1.get_transaction(), trans)
- eq_(s1.in_transaction(), True)
+ eq_(s1.in_transaction(), True)
- s1.commit()
+ s1.commit()
- eq_(s1.in_transaction(), False)
- is_(s1.get_transaction(), None)
+ eq_(s1.in_transaction(), False)
+ is_(s1.get_transaction(), None)
def test_in_transaction_subtransactions(self):
"""we'd like to do away with subtransactions for future sessions
@@ -2425,72 +2454,71 @@ class TransactionFlagsTest(fixtures.TestBase):
the external API works.
"""
- s1 = Session(testing.db)
-
- eq_(s1.in_transaction(), False)
+ with fixture_session() as s1:
+ eq_(s1.in_transaction(), False)
- trans = s1.begin()
+ trans = s1.begin()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
- subtrans = s1.begin(_subtrans=True)
- is_(s1.get_transaction(), trans)
- eq_(s1.in_transaction(), True)
+ subtrans = s1.begin(_subtrans=True)
+ is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
- is_(s1._transaction, subtrans)
+ is_(s1._transaction, subtrans)
- s1.rollback()
+ s1.rollback()
- eq_(s1.in_transaction(), True)
- is_(s1._transaction, trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1._transaction, trans)
- s1.rollback()
+ s1.rollback()
- eq_(s1.in_transaction(), False)
- is_(s1._transaction, None)
+ eq_(s1.in_transaction(), False)
+ is_(s1._transaction, None)
def test_in_transaction_nesting(self):
- s1 = Session(testing.db)
+ with fixture_session() as s1:
- eq_(s1.in_transaction(), False)
+ eq_(s1.in_transaction(), False)
- trans = s1.begin()
+ trans = s1.begin()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
- sp1 = s1.begin_nested()
+ sp1 = s1.begin_nested()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), sp1)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), sp1)
- sp2 = s1.begin_nested()
+ sp2 = s1.begin_nested()
- eq_(s1.in_transaction(), True)
- eq_(s1.in_nested_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), sp2)
+ eq_(s1.in_transaction(), True)
+ eq_(s1.in_nested_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), sp2)
- sp2.rollback()
+ sp2.rollback()
- eq_(s1.in_transaction(), True)
- eq_(s1.in_nested_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), sp1)
+ eq_(s1.in_transaction(), True)
+ eq_(s1.in_nested_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), sp1)
- sp1.rollback()
+ sp1.rollback()
- is_(s1.get_nested_transaction(), None)
- eq_(s1.in_transaction(), True)
- eq_(s1.in_nested_transaction(), False)
- is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), None)
+ eq_(s1.in_transaction(), True)
+ eq_(s1.in_nested_transaction(), False)
+ is_(s1.get_transaction(), trans)
- s1.rollback()
+ s1.rollback()
- eq_(s1.in_transaction(), False)
- is_(s1.get_transaction(), None)
+ eq_(s1.in_transaction(), False)
+ is_(s1.get_transaction(), None)
class NaturalPKRollbackTest(fixtures.MappedTest):
@@ -2674,8 +2702,11 @@ class NaturalPKRollbackTest(fixtures.MappedTest):
class JoinIntoAnExternalTransactionFixture(object):
"""Test the "join into an external transaction" examples"""
- def setup(self):
- self.connection = testing.db.connect()
+ __leave_connections_for_teardown__ = True
+
+ def setup_test(self):
+ self.engine = testing.db
+ self.connection = self.engine.connect()
self.metadata = MetaData()
self.table = Table(
@@ -2686,6 +2717,17 @@ class JoinIntoAnExternalTransactionFixture(object):
self.setup_session()
+ def teardown_test(self):
+ self.teardown_session()
+
+ with self.connection.begin():
+ self._assert_count(0)
+
+ with self.connection.begin():
+ self.table.drop(self.connection)
+
+ self.connection.close()
+
def test_something(self):
A = self.A
@@ -2727,18 +2769,6 @@ class JoinIntoAnExternalTransactionFixture(object):
)
eq_(result, count)
- def teardown(self):
- self.teardown_session()
-
- with self.connection.begin():
- self._assert_count(0)
-
- with self.connection.begin():
- self.table.drop(self.connection)
-
- # return connection to the Engine
- self.connection.close()
-
class NewStyleJoinIntoAnExternalTransactionTest(
JoinIntoAnExternalTransactionFixture
@@ -2775,7 +2805,8 @@ class NewStyleJoinIntoAnExternalTransactionTest(
# rollback - everything that happened with the
# Session above (including calls to commit())
# is rolled back.
- self.trans.rollback()
+ if self.trans.is_active:
+ self.trans.rollback()
class FutureJoinIntoAnExternalTransactionTest(
diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py
index 84373b2dc..2c35bec45 100644
--- a/test/orm/test_unitofwork.py
+++ b/test/orm/test_unitofwork.py
@@ -203,14 +203,6 @@ class UnicodeSchemaTest(fixtures.MappedTest):
cls.tables["t1"] = t1
cls.tables["t2"] = t2
- @classmethod
- def setup_class(cls):
- super(UnicodeSchemaTest, cls).setup_class()
-
- @classmethod
- def teardown_class(cls):
- super(UnicodeSchemaTest, cls).teardown_class()
-
def test_mapping(self):
t2, t1 = self.tables.t2, self.tables.t1
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index 4e713627c..65089f773 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -771,14 +771,13 @@ class RudimentaryFlushTest(UOWTest):
class SingleCycleTest(UOWTest):
- def teardown(self):
+ def teardown_test(self):
engines.testing_reaper.rollback_all()
# mysql can't handle delete from nodes
# since it doesn't deal with the FKs correctly,
# so wipe out the parent_id first
with testing.db.begin() as conn:
conn.execute(self.tables.nodes.update().values(parent_id=None))
- super(SingleCycleTest, self).teardown()
def test_one_to_many_save(self):
Node, nodes = self.classes.Node, self.tables.nodes
diff --git a/test/requirements.py b/test/requirements.py
index d5a718372..3c9b39ac7 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -1624,15 +1624,15 @@ class DefaultRequirements(SuiteRequirements):
@property
def postgresql_utf8_server_encoding(self):
+ def go(config):
+ if not against(config, "postgresql"):
+ return False
- return only_if(
- lambda config: against(config, "postgresql")
- and config.db.connect(close_with_result=True)
- .exec_driver_sql("show server_encoding")
- .scalar()
- .lower()
- == "utf8"
- )
+ with config.db.connect() as conn:
+ enc = conn.exec_driver_sql("show server_encoding").scalar()
+ return enc.lower() == "utf8"
+
+ return only_if(go)
@property
def cxoracle6_or_greater(self):
diff --git a/test/sql/test_case_statement.py b/test/sql/test_case_statement.py
index 4bef1df7f..b44971cec 100644
--- a/test/sql/test_case_statement.py
+++ b/test/sql/test_case_statement.py
@@ -26,7 +26,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
metadata = MetaData()
global info_table
info_table = Table(
@@ -52,7 +52,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
info_table.drop(conn)
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index 70281d4e8..1ac3613f7 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py
@@ -1203,7 +1203,7 @@ class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
# TODO: we need to get dialects here somehow, perhaps in test_suite?
[
importlib.import_module("sqlalchemy.dialects.%s" % d)
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index fdffe04bf..4429753ec 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -4306,7 +4306,7 @@ class StringifySpecialTest(fixtures.TestBase):
class KwargPropagationTest(fixtures.TestBase):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy.sql.expression import ColumnClause, TableClause
class CatchCol(ColumnClause):
diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py
index 2a2e70bc3..8be7eed1f 100644
--- a/test/sql/test_defaults.py
+++ b/test/sql/test_defaults.py
@@ -503,9 +503,8 @@ class DefaultRoundTripTest(fixtures.TablesTest):
Column("col11", MyType(), default="foo"),
)
- def teardown(self):
+ def teardown_test(self):
self.default_generator["x"] = 50
- super(DefaultRoundTripTest, self).teardown()
def test_standalone(self, connection):
t = self.tables.default_test
@@ -1226,7 +1225,7 @@ class SpecialTypePKTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
class MyInteger(TypeDecorator):
impl = Integer
diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py
index acc12a5fe..777565220 100644
--- a/test/sql/test_deprecations.py
+++ b/test/sql/test_deprecations.py
@@ -2488,12 +2488,12 @@ class LegacySequenceExecTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
cls.seq = Sequence("my_sequence")
cls.seq.create(testing.db)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
cls.seq.drop(testing.db)
def _assert_seq_result(self, ret):
@@ -2574,7 +2574,7 @@ class LegacySequenceExecTest(fixtures.TestBase):
class DDLDeprecatedBindTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
if inspect(conn).has_table("foo"):
conn.execute(schema.DropTable(table("foo")))
diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py
index 4edc9d025..a6001ba9d 100644
--- a/test/sql/test_external_traversal.py
+++ b/test/sql/test_external_traversal.py
@@ -47,7 +47,7 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults):
ability to copy and modify a ClauseElement in place."""
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global A, B
# establish two fictitious ClauseElements.
@@ -321,7 +321,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2, t3
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
@@ -1012,7 +1012,7 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table(
"table1",
@@ -1196,7 +1196,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
@@ -1943,7 +1943,7 @@ class SpliceJoinsTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global table1, table2, table3, table4
def _table(name):
@@ -2031,7 +2031,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
@@ -2128,7 +2128,7 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL):
# fixme: consolidate converage from elsewhere here and expand
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py
index 6afe41aac..b0bcee18e 100644
--- a/test/sql/test_from_linter.py
+++ b/test/sql/test_from_linter.py
@@ -25,7 +25,7 @@ class TestFindUnmatchingFroms(fixtures.TablesTest):
Table("table_c", metadata, Column("col_c", Integer, primary_key=True))
Table("table_d", metadata, Column("col_d", Integer, primary_key=True))
- def setup(self):
+ def setup_test(self):
self.a = self.tables.table_a
self.b = self.tables.table_b
self.c = self.tables.table_c
@@ -267,8 +267,10 @@ class TestLinter(fixtures.TablesTest):
with self.bind.connect() as conn:
conn.execute(query)
- def test_no_linting(self):
- eng = engines.testing_engine(options={"enable_from_linting": False})
+ def test_no_linting(self, metadata, connection):
+ eng = engines.testing_engine(
+ options={"enable_from_linting": False, "use_reaper": False}
+ )
eng.pool = self.bind.pool # needed for SQLite
a, b = self.tables("table_a", "table_b")
query = select(a.c.col_a).where(b.c.col_b == 5)
diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py
index 91076f9c3..32ea642d7 100644
--- a/test/sql/test_functions.py
+++ b/test/sql/test_functions.py
@@ -54,10 +54,10 @@ table1 = table(
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
- def setup(self):
+ def setup_test(self):
self._registry = deepcopy(functions._registry)
- def teardown(self):
+ def teardown_test(self):
functions._registry = self._registry
def test_compile(self):
@@ -938,7 +938,7 @@ class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase):
class ExecuteTest(fixtures.TestBase):
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
pass
def test_conn_execute(self, connection):
@@ -1113,10 +1113,10 @@ class ExecuteTest(fixtures.TestBase):
class RegisterTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
- def setup(self):
+ def setup_test(self):
self._registry = deepcopy(functions._registry)
- def teardown(self):
+ def teardown_test(self):
functions._registry = self._registry
def test_GenericFunction_is_registered(self):
diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py
index 502f70ce7..9bf351f5c 100644
--- a/test/sql/test_metadata.py
+++ b/test/sql/test_metadata.py
@@ -4257,7 +4257,7 @@ class DialectKWArgTest(fixtures.TestBase):
with mock.patch("sqlalchemy.dialects.registry.load", load):
yield
- def teardown(self):
+ def teardown_test(self):
Index._kw_registry.clear()
def test_participating(self):
diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py
index aaeed68dd..270e79ba1 100644
--- a/test/sql/test_operators.py
+++ b/test/sql/test_operators.py
@@ -608,7 +608,7 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
- def setUp(self):
+ def setup_test(self):
class MyTypeCompiler(compiler.GenericTypeCompiler):
def visit_mytype(self, type_, **kw):
return "MYTYPE"
@@ -766,7 +766,7 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
- def setUp(self):
+ def setup_test(self):
class MyTypeCompiler(compiler.GenericTypeCompiler):
def visit_mytype(self, type_, **kw):
return "MYTYPE"
@@ -2370,7 +2370,7 @@ class MatchTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
@@ -2403,7 +2403,7 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class RegexpTestStrCompiler(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default_enhanced"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py
index 136f10cf4..7ad12c620 100644
--- a/test/sql/test_resultset.py
+++ b/test/sql/test_resultset.py
@@ -661,8 +661,8 @@ class CursorResultTest(fixtures.TablesTest):
assert_raises(KeyError, lambda: row._mapping["Case_insensitive"])
assert_raises(KeyError, lambda: row._mapping["casesensitive"])
- def test_row_case_sensitive_unoptimized(self):
- with engines.testing_engine().connect() as ins_conn:
+ def test_row_case_sensitive_unoptimized(self, testing_engine):
+ with testing_engine().connect() as ins_conn:
row = ins_conn.execute(
select(
literal_column("1").label("case_insensitive"),
@@ -1234,8 +1234,7 @@ class CursorResultTest(fixtures.TablesTest):
eq_(proxy[0], "value")
eq_(proxy._mapping["key"], "value")
- @testing.provide_metadata
- def test_no_rowcount_on_selects_inserts(self):
+ def test_no_rowcount_on_selects_inserts(self, metadata, testing_engine):
"""assert that rowcount is only called on deletes and updates.
This because cursor.rowcount may can be expensive on some dialects
@@ -1244,9 +1243,7 @@ class CursorResultTest(fixtures.TablesTest):
"""
- metadata = self.metadata
-
- engine = engines.testing_engine()
+ engine = testing_engine()
t = Table("t1", metadata, Column("data", String(10)))
metadata.create_all(engine)
@@ -2132,7 +2129,9 @@ class AlternateCursorResultTest(fixtures.TablesTest):
@classmethod
def setup_bind(cls):
- cls.engine = engine = engines.testing_engine("sqlite://")
+ cls.engine = engine = engines.testing_engine(
+ "sqlite://", options={"scope": "class"}
+ )
return engine
@classmethod
diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py
index 65325aa6f..5cfc2663f 100644
--- a/test/sql/test_sequences.py
+++ b/test/sql/test_sequences.py
@@ -100,12 +100,12 @@ class SequenceExecTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
cls.seq = Sequence("my_sequence")
cls.seq.create(testing.db)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
cls.seq.drop(testing.db)
def _assert_seq_result(self, ret):
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 0e1147800..64ace87df 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -1375,7 +1375,7 @@ class VariantBackendTest(fixtures.TestBase, AssertsCompiledSQL):
class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
- def setup(self):
+ def setup_test(self):
class UTypeOne(types.UserDefinedType):
def get_col_spec(self):
return "UTYPEONE"
@@ -2504,7 +2504,7 @@ class BinaryTest(fixtures.TablesTest, AssertsExecutionResults):
class JSONTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
@@ -3445,7 +3445,12 @@ class BooleanTest(
@testing.requires.non_native_boolean_unconstrained
def test_constraint(self, connection):
assert_raises(
- (exc.IntegrityError, exc.ProgrammingError, exc.OperationalError),
+ (
+ exc.IntegrityError,
+ exc.ProgrammingError,
+ exc.OperationalError,
+ exc.InternalError, # older pymysql's do this
+ ),
connection.exec_driver_sql,
"insert into boolean_table (id, value) values(1, 5)",
)