diff options
Diffstat (limited to 'test')
86 files changed, 1780 insertions, 1748 deletions
diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py index 0202768ae..968a74700 100644 --- a/test/aaa_profiling/test_compiler.py +++ b/test/aaa_profiling/test_compiler.py @@ -18,7 +18,7 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): global t1, t2, metadata metadata = MetaData() diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 75a4f51cf..a41a8b9f1 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -241,7 +241,7 @@ def assert_no_mappers(): class EnsureZeroed(fixtures.ORMTest): - def setup(self): + def setup_test(self): _sessions.clear() _mapper_registry.clear() @@ -1032,7 +1032,7 @@ class MemUsageWBackendTest(EnsureZeroed): t2_mapper = mapper(T2, t2) t1_mapper.add_property("bar", relationship(t2_mapper)) - s1 = fixture_session() + s1 = Session(testing.db) # this causes the path_registry to be invoked s1.query(t1_mapper)._compile_context() diff --git a/test/aaa_profiling/test_misc.py b/test/aaa_profiling/test_misc.py index db6fd4b71..5b30a3968 100644 --- a/test/aaa_profiling/test_misc.py +++ b/test/aaa_profiling/test_misc.py @@ -19,7 +19,7 @@ from sqlalchemy.util import classproperty class EnumTest(fixtures.TestBase): __requires__ = ("cpython", "python_profiling_backend") - def setup(self): + def setup_test(self): class SomeEnum(object): # Implements PEP 435 in the minimal fashion needed by SQLAlchemy diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index f163078d8..8116e5f21 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -29,15 +29,13 @@ class NoCache(object): run_setup_bind = "each" @classmethod - def setup_class(cls): - super(NoCache, cls).setup_class() + def setup_test_class(cls): cls._cache = config.db._compiled_cache config.db._compiled_cache = None @classmethod - def teardown_class(cls): + def teardown_test_class(cls): config.db._compiled_cache = cls._cache - super(NoCache, cls).teardown_class() class MergeTest(NoCache, fixtures.MappedTest): diff --git a/test/aaa_profiling/test_pool.py b/test/aaa_profiling/test_pool.py index fd02f9139..da3c1c525 100644 --- a/test/aaa_profiling/test_pool.py +++ b/test/aaa_profiling/test_pool.py @@ -17,7 +17,7 @@ class QueuePoolTest(fixtures.TestBase, AssertsExecutionResults): def close(self): pass - def setup(self): + def setup_test(self): # create a throwaway pool which # has the effect of initializing # class-level event listeners on Pool, diff --git a/test/base/test_events.py b/test/base/test_events.py index 19f68e9a3..68db5207c 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -16,7 +16,7 @@ from sqlalchemy.testing.util import gc_collect class TearDownLocalEventsFixture(object): - def tearDown(self): + def teardown_test(self): classes = set() for entry in event.base._registrars.values(): for evt_cls in entry: @@ -30,7 +30,7 @@ class TearDownLocalEventsFixture(object): class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase): """Test class- and instance-level event registration.""" - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): def event_one(self, x, y): pass @@ -438,7 +438,7 @@ class NamedCallTest(TearDownLocalEventsFixture, fixtures.TestBase): class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase): """test adaption of legacy args""" - def setUp(self): + def setup_test(self): class TargetEventsOne(event.Events): @event._legacy_signature("0.9", ["x", "y"]) def event_three(self, x, y, z, q): @@ -608,7 +608,7 @@ class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase): class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase): - def setUp(self): + def setup_test(self): class TargetEventsOne(event.Events): def event_one(self, x, y): pass @@ -677,7 +677,7 @@ class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase): class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase): """Test default target acceptance.""" - def setUp(self): + def setup_test(self): class TargetEventsOne(event.Events): def event_one(self, x, y): pass @@ -734,7 +734,7 @@ class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase): class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase): """Test custom target acceptance.""" - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): @classmethod def _accept_with(cls, target): @@ -771,7 +771,7 @@ class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase): class SubclassGrowthTest(TearDownLocalEventsFixture, fixtures.TestBase): """test that ad-hoc subclasses are garbage collected.""" - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): def some_event(self, x, y): pass @@ -797,7 +797,7 @@ class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase): """Test custom listen functions which change the listener function signature.""" - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): @classmethod def _listen(cls, event_key, add=False): @@ -855,7 +855,7 @@ class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase): class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase): - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): def event_one(self, arg): pass @@ -889,7 +889,7 @@ class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase): class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase): - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): def event_one(self, target, arg): pass @@ -1109,7 +1109,7 @@ class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase): class DisableClsPropagateTest(TearDownLocalEventsFixture, fixtures.TestBase): - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): def event_one(self, target, arg): pass diff --git a/test/base/test_inspect.py b/test/base/test_inspect.py index 15b98c848..252d0d977 100644 --- a/test/base/test_inspect.py +++ b/test/base/test_inspect.py @@ -13,7 +13,7 @@ class TestFixture(object): class TestInspection(fixtures.TestBase): - def tearDown(self): + def teardown_test(self): for type_ in list(inspection._registrars): if issubclass(type_, TestFixture): del inspection._registrars[type_] diff --git a/test/base/test_tutorials.py b/test/base/test_tutorials.py index 14e87ef69..6320ef052 100644 --- a/test/base/test_tutorials.py +++ b/test/base/test_tutorials.py @@ -48,11 +48,11 @@ class DocTest(fixtures.TestBase): ddl.sort_tables_and_constraints = self.orig_sort - def setup(self): + def setup_test(self): self._setup_logger() self._setup_create_table_patcher() - def teardown(self): + def teardown_test(self): self._teardown_create_table_patcher() self._teardown_logger() diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 8119612e1..f0bb66aa9 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -1814,7 +1814,7 @@ class CompileIdentityTest(fixtures.TestBase, AssertsCompiledSQL): class SchemaTest(fixtures.TestBase): - def setup(self): + def setup_test(self): t = Table( "sometable", MetaData(), diff --git a/test/dialect/mssql/test_deprecations.py b/test/dialect/mssql/test_deprecations.py index c869182c5..27709beb0 100644 --- a/test/dialect/mssql/test_deprecations.py +++ b/test/dialect/mssql/test_deprecations.py @@ -31,7 +31,7 @@ class LegacySchemaAliasingTest(fixtures.TestBase, AssertsCompiledSQL): """ - def setup(self): + def setup_test(self): metadata = MetaData() self.t1 = table( "t1", diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py index cdb37cc61..b806b9247 100644 --- a/test/dialect/mssql/test_query.py +++ b/test/dialect/mssql/test_query.py @@ -455,7 +455,7 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL): return testing.db.execution_options(isolation_level="AUTOCOMMIT") @classmethod - def setup_class(cls): + def setup_test_class(cls): with testing.db.connect().execution_options( isolation_level="AUTOCOMMIT" ) as conn: @@ -463,7 +463,6 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL): conn.exec_driver_sql("DROP FULLTEXT CATALOG Catalog") except: pass - super(MatchTest, cls).setup_class() @classmethod def insert_data(cls, connection): diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 62292b9da..7fd24e8b5 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -991,7 +991,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = mysql.dialect() - def setup(self): + def setup_test(self): self.table = Table( "foos", MetaData(), @@ -1062,7 +1062,7 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): class RegexpCommon(testing.AssertsCompiledSQL): - def setUp(self): + def setup_test(self): self.table = table( "mytable", column("myid", Integer), column("name", String) ) diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py index 40617e59c..795b2cbd3 100644 --- a/test/dialect/mysql/test_reflection.py +++ b/test/dialect/mysql/test_reflection.py @@ -1115,7 +1115,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): class RawReflectionTest(fixtures.TestBase): __backend__ = True - def setup(self): + def setup_test(self): dialect = mysql.dialect() self.parser = _reflection.MySQLTableDefinitionParser( dialect, dialect.identifier_preparer diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index 1b8b3fb89..f09346eb3 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -1355,7 +1355,7 @@ class SequenceTest(fixtures.TestBase, AssertsCompiledSQL): class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "oracle" - def setUp(self): + def setup_test(self): self.table = table( "mytable", column("myid", Integer), column("name", String) ) diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index df87fe89f..32234bf65 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -439,7 +439,7 @@ class OutParamTest(fixtures.TestBase, AssertsExecutionResults): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): with testing.db.begin() as c: c.exec_driver_sql( """ @@ -471,7 +471,7 @@ end; assert isinstance(result.out_parameters["x_out"], int) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): with testing.db.begin() as conn: conn.execute(text("DROP PROCEDURE foo")) diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py index 81e4e4ab5..0df4236e2 100644 --- a/test/dialect/oracle/test_reflection.py +++ b/test/dialect/oracle/test_reflection.py @@ -39,7 +39,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): # currently assuming full DBA privs for the user. # don't really know how else to go here unless # we connect as the other user. @@ -85,7 +85,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): conn.exec_driver_sql(stmt) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): with testing.db.begin() as conn: for stmt in ( """ @@ -379,7 +379,7 @@ class SystemTableTablenamesTest(fixtures.TestBase): __only_on__ = "oracle" __backend__ = True - def setup(self): + def setup_test(self): with testing.db.begin() as conn: conn.exec_driver_sql("create table my_table (id integer)") conn.exec_driver_sql( @@ -389,7 +389,7 @@ class SystemTableTablenamesTest(fixtures.TestBase): "create table foo_table (id integer) tablespace SYSTEM" ) - def teardown(self): + def teardown_test(self): with testing.db.begin() as conn: conn.exec_driver_sql("drop table my_temp_table") conn.exec_driver_sql("drop table my_table") @@ -421,7 +421,7 @@ class DontReflectIOTTest(fixtures.TestBase): __only_on__ = "oracle" __backend__ = True - def setup(self): + def setup_test(self): with testing.db.begin() as conn: conn.exec_driver_sql( """ @@ -438,7 +438,7 @@ class DontReflectIOTTest(fixtures.TestBase): """, ) - def teardown(self): + def teardown_test(self): with testing.db.begin() as conn: conn.exec_driver_sql("drop table admin_docindex") @@ -715,7 +715,7 @@ class DBLinkReflectionTest(fixtures.TestBase): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy.testing import config cls.dblink = config.file_config.get("sqla_testing", "oracle_db_link") @@ -734,7 +734,7 @@ class DBLinkReflectionTest(fixtures.TestBase): ) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): with testing.db.begin() as conn: conn.exec_driver_sql("drop synonym test_table_syn") conn.exec_driver_sql("drop table test_table") diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index f008ea019..8ea7c0e04 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -1011,7 +1011,7 @@ class EuroNumericTest(fixtures.TestBase): __only_on__ = "oracle+cx_oracle" __backend__ = True - def setup(self): + def setup_test(self): connect = testing.db.pool._creator def _creator(): @@ -1023,7 +1023,7 @@ class EuroNumericTest(fixtures.TestBase): self.engine = testing_engine(options={"creator": _creator}) - def teardown(self): + def teardown_test(self): self.engine.dispose() def test_were_getting_a_comma(self): diff --git a/test/dialect/postgresql/test_async_pg_py3k.py b/test/dialect/postgresql/test_async_pg_py3k.py index fadf939b8..f6d48f3c6 100644 --- a/test/dialect/postgresql/test_async_pg_py3k.py +++ b/test/dialect/postgresql/test_async_pg_py3k.py @@ -27,7 +27,7 @@ class AsyncPgTest(fixtures.TestBase): # TODO: remove when Iae6ab95938a7e92b6d42086aec534af27b5577d3 # merges - from sqlalchemy.testing import engines + from sqlalchemy.testing import util as testing_util from sqlalchemy.sql import schema metadata = schema.MetaData() @@ -35,7 +35,7 @@ class AsyncPgTest(fixtures.TestBase): try: yield metadata finally: - engines.drop_all_tables(metadata, testing.db) + testing_util.drop_all_tables_from_metadata(metadata, testing.db) @async_test async def test_detect_stale_ddl_cache_raise_recover( diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 1763b210b..b3a0b9bbd 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1810,7 +1810,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = postgresql.dialect() - def setup(self): + def setup_test(self): self.table1 = table1 = table( "mytable", column("myid", Integer), @@ -2222,7 +2222,7 @@ class DistinctOnTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = postgresql.dialect() - def setup(self): + def setup_test(self): self.table = Table( "t", MetaData(), @@ -2373,7 +2373,7 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = postgresql.dialect() - def setup(self): + def setup_test(self): self.table = Table( "t", MetaData(), @@ -2464,7 +2464,7 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL): class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "postgresql" - def setUp(self): + def setup_test(self): self.table = table( "mytable", column("myid", Integer), column("name", String) ) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index f760a309b..9c9d817bb 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -198,17 +198,17 @@ class ExecuteManyMode(object): @config.fixture() def connection(self): - eng = engines.testing_engine(options=self.options) + opts = dict(self.options) + opts["use_reaper"] = False + eng = engines.testing_engine(options=opts) conn = eng.connect() trans = conn.begin() - try: - yield conn - finally: - if trans.is_active: - trans.rollback() - conn.close() - eng.dispose() + yield conn + if trans.is_active: + trans.rollback() + conn.close() + eng.dispose() @classmethod def define_tables(cls, metadata): @@ -510,8 +510,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): # assert result.closed assert result.cursor is None - @testing.provide_metadata - def test_insert_returning_preexecute_pk(self, connection): + def test_insert_returning_preexecute_pk(self, metadata, connection): counter = itertools.count(1) t = Table( @@ -525,7 +524,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): ), Column("data", Integer), ) - self.metadata.create_all(connection) + metadata.create_all(connection) result = connection.execute( t.insert().return_defaults(), diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index 94af168ee..c51fd1943 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -40,10 +40,10 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql" __backend__ = True - def setup(self): + def setup_test(self): self.metadata = MetaData() - def teardown(self): + def teardown_test(self): with testing.db.begin() as conn: self.metadata.drop_all(conn) @@ -890,7 +890,7 @@ class ExtractTest(fixtures.TablesTest): def setup_bind(cls): from sqlalchemy import event - eng = engines.testing_engine() + eng = engines.testing_engine(options={"scope": "class"}) @event.listens_for(eng, "connect") def connect(dbapi_conn, rec): diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index 754eff25a..6586a8308 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -80,26 +80,24 @@ class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults): ]: sa.event.listen(metadata, "before_drop", sa.DDL(ddl)) - def test_foreign_table_is_reflected(self): + def test_foreign_table_is_reflected(self, connection): metadata = MetaData() - table = Table("test_foreigntable", metadata, autoload_with=testing.db) + table = Table("test_foreigntable", metadata, autoload_with=connection) eq_( set(table.columns.keys()), set(["id", "data"]), "Columns of reflected foreign table didn't equal expected columns", ) - def test_get_foreign_table_names(self): - inspector = inspect(testing.db) - with testing.db.connect(): - ft_names = inspector.get_foreign_table_names() - eq_(ft_names, ["test_foreigntable"]) + def test_get_foreign_table_names(self, connection): + inspector = inspect(connection) + ft_names = inspector.get_foreign_table_names() + eq_(ft_names, ["test_foreigntable"]) - def test_get_table_names_no_foreign(self): - inspector = inspect(testing.db) - with testing.db.connect(): - names = inspector.get_table_names() - eq_(names, ["testtable"]) + def test_get_table_names_no_foreign(self, connection): + inspector = inspect(connection) + names = inspector.get_table_names() + eq_(names, ["testtable"]) class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults): @@ -133,22 +131,22 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults): if testing.against("postgresql >= 11"): Index("my_index", dv.c.q) - def test_get_tablenames(self): + def test_get_tablenames(self, connection): assert {"data_values", "data_values_4_10"}.issubset( - inspect(testing.db).get_table_names() + inspect(connection).get_table_names() ) - def test_reflect_cols(self): - cols = inspect(testing.db).get_columns("data_values") + def test_reflect_cols(self, connection): + cols = inspect(connection).get_columns("data_values") eq_([c["name"] for c in cols], ["modulus", "data", "q"]) - def test_reflect_cols_from_partition(self): - cols = inspect(testing.db).get_columns("data_values_4_10") + def test_reflect_cols_from_partition(self, connection): + cols = inspect(connection).get_columns("data_values_4_10") eq_([c["name"] for c in cols], ["modulus", "data", "q"]) @testing.only_on("postgresql >= 11") - def test_reflect_index(self): - idx = inspect(testing.db).get_indexes("data_values") + def test_reflect_index(self, connection): + idx = inspect(connection).get_indexes("data_values") eq_( idx, [ @@ -162,8 +160,8 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults): ) @testing.only_on("postgresql >= 11") - def test_reflect_index_from_partition(self): - idx = inspect(testing.db).get_indexes("data_values_4_10") + def test_reflect_index_from_partition(self, connection): + idx = inspect(connection).get_indexes("data_values_4_10") # note the name appears to be generated by PG, currently # 'data_values_4_10_q_idx' eq_( @@ -220,44 +218,43 @@ class MaterializedViewReflectionTest( testtable, "before_drop", sa.DDL("DROP VIEW test_regview") ) - def test_mview_is_reflected(self): + def test_mview_is_reflected(self, connection): metadata = MetaData() - table = Table("test_mview", metadata, autoload_with=testing.db) + table = Table("test_mview", metadata, autoload_with=connection) eq_( set(table.columns.keys()), set(["id", "data"]), "Columns of reflected mview didn't equal expected columns", ) - def test_mview_select(self): + def test_mview_select(self, connection): metadata = MetaData() - table = Table("test_mview", metadata, autoload_with=testing.db) - with testing.db.connect() as conn: - eq_(conn.execute(table.select()).fetchall(), [(89, "d1")]) + table = Table("test_mview", metadata, autoload_with=connection) + eq_(connection.execute(table.select()).fetchall(), [(89, "d1")]) - def test_get_view_names(self): - insp = inspect(testing.db) + def test_get_view_names(self, connection): + insp = inspect(connection) eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"])) - def test_get_view_names_plain(self): - insp = inspect(testing.db) + def test_get_view_names_plain(self, connection): + insp = inspect(connection) eq_( set(insp.get_view_names(include=("plain",))), set(["test_regview"]) ) - def test_get_view_names_plain_string(self): - insp = inspect(testing.db) + def test_get_view_names_plain_string(self, connection): + insp = inspect(connection) eq_(set(insp.get_view_names(include="plain")), set(["test_regview"])) - def test_get_view_names_materialized(self): - insp = inspect(testing.db) + def test_get_view_names_materialized(self, connection): + insp = inspect(connection) eq_( set(insp.get_view_names(include=("materialized",))), set(["test_mview"]), ) - def test_get_view_names_reflection_cache_ok(self): - insp = inspect(testing.db) + def test_get_view_names_reflection_cache_ok(self, connection): + insp = inspect(connection) eq_( set(insp.get_view_names(include=("plain",))), set(["test_regview"]) ) @@ -267,12 +264,12 @@ class MaterializedViewReflectionTest( ) eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"])) - def test_get_view_names_empty(self): - insp = inspect(testing.db) + def test_get_view_names_empty(self, connection): + insp = inspect(connection) assert_raises(ValueError, insp.get_view_names, include=()) - def test_get_view_definition(self): - insp = inspect(testing.db) + def test_get_view_definition(self, connection): + insp = inspect(connection) eq_( re.sub( r"[\n\t ]+", @@ -290,7 +287,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): with testing.db.begin() as con: for ddl in [ 'CREATE SCHEMA "SomeSchema"', @@ -334,7 +331,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): ) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): with testing.db.begin() as con: con.exec_driver_sql("DROP TABLE testtable") con.exec_driver_sql("DROP TABLE test_schema.testtable") @@ -350,9 +347,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"') con.exec_driver_sql('DROP SCHEMA "SomeSchema"') - def test_table_is_reflected(self): + def test_table_is_reflected(self, connection): metadata = MetaData() - table = Table("testtable", metadata, autoload_with=testing.db) + table = Table("testtable", metadata, autoload_with=connection) eq_( set(table.columns.keys()), set(["question", "answer"]), @@ -360,9 +357,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): ) assert isinstance(table.c.answer.type, Integer) - def test_domain_is_reflected(self): + def test_domain_is_reflected(self, connection): metadata = MetaData() - table = Table("testtable", metadata, autoload_with=testing.db) + table = Table("testtable", metadata, autoload_with=connection) eq_( str(table.columns.answer.server_default.arg), "42", @@ -372,28 +369,28 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): not table.columns.answer.nullable ), "Expected reflected column to not be nullable." - def test_enum_domain_is_reflected(self): + def test_enum_domain_is_reflected(self, connection): metadata = MetaData() - table = Table("enum_test", metadata, autoload_with=testing.db) + table = Table("enum_test", metadata, autoload_with=connection) eq_(table.c.data.type.enums, ["test"]) - def test_array_domain_is_reflected(self): + def test_array_domain_is_reflected(self, connection): metadata = MetaData() - table = Table("array_test", metadata, autoload_with=testing.db) + table = Table("array_test", metadata, autoload_with=connection) eq_(table.c.data.type.__class__, ARRAY) eq_(table.c.data.type.item_type.__class__, INTEGER) - def test_quoted_remote_schema_domain_is_reflected(self): + def test_quoted_remote_schema_domain_is_reflected(self, connection): metadata = MetaData() - table = Table("quote_test", metadata, autoload_with=testing.db) + table = Table("quote_test", metadata, autoload_with=connection) eq_(table.c.data.type.__class__, INTEGER) - def test_table_is_reflected_test_schema(self): + def test_table_is_reflected_test_schema(self, connection): metadata = MetaData() table = Table( "testtable", metadata, - autoload_with=testing.db, + autoload_with=connection, schema="test_schema", ) eq_( @@ -403,12 +400,12 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): ) assert isinstance(table.c.anything.type, Integer) - def test_schema_domain_is_reflected(self): + def test_schema_domain_is_reflected(self, connection): metadata = MetaData() table = Table( "testtable", metadata, - autoload_with=testing.db, + autoload_with=connection, schema="test_schema", ) eq_( @@ -420,9 +417,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): table.columns.answer.nullable ), "Expected reflected column to be nullable." - def test_crosschema_domain_is_reflected(self): + def test_crosschema_domain_is_reflected(self, connection): metadata = MetaData() - table = Table("crosschema", metadata, autoload_with=testing.db) + table = Table("crosschema", metadata, autoload_with=connection) eq_( str(table.columns.answer.server_default.arg), "0", @@ -432,7 +429,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): table.columns.answer.nullable ), "Expected reflected column to be nullable." - def test_unknown_types(self): + def test_unknown_types(self, connection): from sqlalchemy.dialects.postgresql import base ischema_names = base.PGDialect.ischema_names @@ -440,13 +437,13 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): try: m2 = MetaData() assert_raises( - exc.SAWarning, Table, "testtable", m2, autoload_with=testing.db + exc.SAWarning, Table, "testtable", m2, autoload_with=connection ) @testing.emits_warning("Did not recognize type") def warns(): m3 = MetaData() - t3 = Table("testtable", m3, autoload_with=testing.db) + t3 = Table("testtable", m3, autoload_with=connection) assert t3.c.answer.type.__class__ == sa.types.NullType finally: @@ -471,9 +468,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): subject = Table("subject", meta2, autoload_with=connection) eq_(subject.primary_key.columns.keys(), ["p2", "p1"]) - @testing.provide_metadata - def test_pg_weirdchar_reflection(self): - meta1 = self.metadata + def test_pg_weirdchar_reflection(self, metadata, connection): + meta1 = metadata subject = Table( "subject", meta1, Column("id$", Integer, primary_key=True) ) @@ -483,101 +479,91 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("ref", Integer, ForeignKey("subject.id$")), ) - meta1.create_all(testing.db) + meta1.create_all(connection) meta2 = MetaData() - subject = Table("subject", meta2, autoload_with=testing.db) - referer = Table("referer", meta2, autoload_with=testing.db) + subject = Table("subject", meta2, autoload_with=connection) + referer = Table("referer", meta2, autoload_with=connection) self.assert_( (subject.c["id$"] == referer.c.ref).compare( subject.join(referer).onclause ) ) - @testing.provide_metadata - def test_reflect_default_over_128_chars(self): + def test_reflect_default_over_128_chars(self, metadata, connection): Table( "t", - self.metadata, + metadata, Column("x", String(200), server_default="abcd" * 40), - ).create(testing.db) + ).create(connection) m = MetaData() - t = Table("t", m, autoload_with=testing.db) + t = Table("t", m, autoload_with=connection) eq_( t.c.x.server_default.arg.text, "'%s'::character varying" % ("abcd" * 40), ) - @testing.fails_if("postgresql < 8.1", "schema name leaks in, not sure") - @testing.provide_metadata - def test_renamed_sequence_reflection(self): - metadata = self.metadata + def test_renamed_sequence_reflection(self, metadata, connection): Table("t", metadata, Column("id", Integer, primary_key=True)) - metadata.create_all(testing.db) + metadata.create_all(connection) m2 = MetaData() - t2 = Table("t", m2, autoload_with=testing.db, implicit_returning=False) + t2 = Table("t", m2, autoload_with=connection, implicit_returning=False) eq_(t2.c.id.server_default.arg.text, "nextval('t_id_seq'::regclass)") - with testing.db.begin() as conn: - r = conn.execute(t2.insert()) - eq_(r.inserted_primary_key, (1,)) + r = connection.execute(t2.insert()) + eq_(r.inserted_primary_key, (1,)) - with testing.db.begin() as conn: - conn.exec_driver_sql( - "alter table t_id_seq rename to foobar_id_seq" - ) + connection.exec_driver_sql( + "alter table t_id_seq rename to foobar_id_seq" + ) m3 = MetaData() - t3 = Table("t", m3, autoload_with=testing.db, implicit_returning=False) + t3 = Table("t", m3, autoload_with=connection, implicit_returning=False) eq_( t3.c.id.server_default.arg.text, "nextval('foobar_id_seq'::regclass)", ) - with testing.db.begin() as conn: - r = conn.execute(t3.insert()) - eq_(r.inserted_primary_key, (2,)) + r = connection.execute(t3.insert()) + eq_(r.inserted_primary_key, (2,)) - @testing.provide_metadata - def test_altered_type_autoincrement_pk_reflection(self): - metadata = self.metadata + def test_altered_type_autoincrement_pk_reflection( + self, metadata, connection + ): + metadata = metadata Table( "t", metadata, Column("id", Integer, primary_key=True), Column("x", Integer), ) - metadata.create_all(testing.db) + metadata.create_all(connection) - with testing.db.begin() as conn: - conn.exec_driver_sql( - "alter table t alter column id type varchar(50)" - ) + connection.exec_driver_sql( + "alter table t alter column id type varchar(50)" + ) m2 = MetaData() - t2 = Table("t", m2, autoload_with=testing.db) + t2 = Table("t", m2, autoload_with=connection) eq_(t2.c.id.autoincrement, False) eq_(t2.c.x.autoincrement, False) - @testing.provide_metadata - def test_renamed_pk_reflection(self): - metadata = self.metadata + def test_renamed_pk_reflection(self, metadata, connection): + metadata = metadata Table("t", metadata, Column("id", Integer, primary_key=True)) - metadata.create_all(testing.db) - with testing.db.begin() as conn: - conn.exec_driver_sql("alter table t rename id to t_id") + metadata.create_all(connection) + connection.exec_driver_sql("alter table t rename id to t_id") m2 = MetaData() - t2 = Table("t", m2, autoload_with=testing.db) + t2 = Table("t", m2, autoload_with=connection) eq_([c.name for c in t2.primary_key], ["t_id"]) - @testing.provide_metadata - def test_has_temporary_table(self): - assert not inspect(testing.db).has_table("some_temp_table") + def test_has_temporary_table(self, metadata, connection): + assert not inspect(connection).has_table("some_temp_table") user_tmp = Table( "some_temp_table", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("name", String(50)), prefixes=["TEMPORARY"], ) - user_tmp.create(testing.db) - assert inspect(testing.db).has_table("some_temp_table") + user_tmp.create(connection) + assert inspect(connection).has_table("some_temp_table") def test_cross_schema_reflection_one(self, metadata, connection): @@ -898,19 +884,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): A_table.create(connection, checkfirst=True) assert inspect(connection).has_table("A") - def test_uppercase_lowercase_sequence(self): + def test_uppercase_lowercase_sequence(self, connection): a_seq = Sequence("a") A_seq = Sequence("A") - a_seq.create(testing.db) - assert testing.db.dialect.has_sequence(testing.db, "a") - assert not testing.db.dialect.has_sequence(testing.db, "A") - A_seq.create(testing.db, checkfirst=True) - assert testing.db.dialect.has_sequence(testing.db, "A") + a_seq.create(connection) + assert connection.dialect.has_sequence(connection, "a") + assert not connection.dialect.has_sequence(connection, "A") + A_seq.create(connection, checkfirst=True) + assert connection.dialect.has_sequence(connection, "A") - a_seq.drop(testing.db) - A_seq.drop(testing.db) + a_seq.drop(connection) + A_seq.drop(connection) def test_index_reflection(self, metadata, connection): """Reflecting expression-based indexes should warn""" @@ -960,11 +946,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_index_reflection_partial(self, connection): + def test_index_reflection_partial(self, metadata, connection): """Reflect the filter defintion on partial indexes""" - metadata = self.metadata + metadata = metadata t1 = Table( "table1", @@ -978,7 +963,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): metadata.create_all(connection) - ind = testing.db.dialect.get_indexes(connection, t1, None) + ind = connection.dialect.get_indexes(connection, t1, None) partial_definitions = [] for ix in ind: @@ -1073,15 +1058,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): compile_exprs(r3.expressions), ) - @testing.provide_metadata - def test_index_reflection_modified(self): + def test_index_reflection_modified(self, metadata, connection): """reflect indexes when a column name has changed - PG 9 does not update the name of the column in the index def. [ticket:2141] """ - metadata = self.metadata + metadata = metadata Table( "t", @@ -1089,26 +1073,21 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("x", Integer), ) - metadata.create_all(testing.db) - with testing.db.begin() as conn: - conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)") - conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y") + metadata.create_all(connection) + connection.exec_driver_sql("CREATE INDEX idx1 ON t (x)") + connection.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y") - ind = testing.db.dialect.get_indexes(conn, "t", None) - expected = [ - {"name": "idx1", "unique": False, "column_names": ["y"]} - ] - if testing.requires.index_reflects_included_columns.enabled: - expected[0]["include_columns"] = [] + ind = connection.dialect.get_indexes(connection, "t", None) + expected = [{"name": "idx1", "unique": False, "column_names": ["y"]}] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] - eq_(ind, expected) + eq_(ind, expected) - @testing.fails_if("postgresql < 8.2", "reloptions not supported") - @testing.provide_metadata - def test_index_reflection_with_storage_options(self): + def test_index_reflection_with_storage_options(self, metadata, connection): """reflect indexes with storage options set""" - metadata = self.metadata + metadata = metadata Table( "t", @@ -1116,70 +1095,63 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("x", Integer), ) - metadata.create_all(testing.db) + metadata.create_all(connection) - with testing.db.begin() as conn: - conn.exec_driver_sql( - "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)" - ) + connection.exec_driver_sql( + "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)" + ) - ind = testing.db.dialect.get_indexes(conn, "t", None) + ind = testing.db.dialect.get_indexes(connection, "t", None) - expected = [ - { - "unique": False, - "column_names": ["x"], - "name": "idx1", - "dialect_options": { - "postgresql_with": {"fillfactor": "50"} - }, - } - ] - if testing.requires.index_reflects_included_columns.enabled: - expected[0]["include_columns"] = [] - eq_(ind, expected) + expected = [ + { + "unique": False, + "column_names": ["x"], + "name": "idx1", + "dialect_options": {"postgresql_with": {"fillfactor": "50"}}, + } + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + eq_(ind, expected) - m = MetaData() - t1 = Table("t", m, autoload_with=conn) - eq_( - list(t1.indexes)[0].dialect_options["postgresql"]["with"], - {"fillfactor": "50"}, - ) + m = MetaData() + t1 = Table("t", m, autoload_with=connection) + eq_( + list(t1.indexes)[0].dialect_options["postgresql"]["with"], + {"fillfactor": "50"}, + ) - @testing.provide_metadata - def test_index_reflection_with_access_method(self): + def test_index_reflection_with_access_method(self, metadata, connection): """reflect indexes with storage options set""" - metadata = self.metadata - Table( "t", metadata, Column("id", Integer, primary_key=True), Column("x", ARRAY(Integer)), ) - metadata.create_all(testing.db) - with testing.db.begin() as conn: - conn.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)") + metadata.create_all(connection) + connection.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)") - ind = testing.db.dialect.get_indexes(conn, "t", None) - expected = [ - { - "unique": False, - "column_names": ["x"], - "name": "idx1", - "dialect_options": {"postgresql_using": "gin"}, - } - ] - if testing.requires.index_reflects_included_columns.enabled: - expected[0]["include_columns"] = [] - eq_(ind, expected) - m = MetaData() - t1 = Table("t", m, autoload_with=conn) - eq_( - list(t1.indexes)[0].dialect_options["postgresql"]["using"], - "gin", - ) + ind = testing.db.dialect.get_indexes(connection, "t", None) + expected = [ + { + "unique": False, + "column_names": ["x"], + "name": "idx1", + "dialect_options": {"postgresql_using": "gin"}, + } + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + eq_(ind, expected) + m = MetaData() + t1 = Table("t", m, autoload_with=connection) + eq_( + list(t1.indexes)[0].dialect_options["postgresql"]["using"], + "gin", + ) @testing.skip_if("postgresql < 11.0", "indnkeyatts not supported") def test_index_reflection_with_include(self, metadata, connection): @@ -1199,7 +1171,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): # [{'column_names': ['x', 'name'], # 'name': 'idx1', 'unique': False}] - ind = testing.db.dialect.get_indexes(connection, "t", None) + ind = connection.dialect.get_indexes(connection, "t", None) eq_( ind, [ @@ -1286,15 +1258,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): for fk in fks: eq_(fk, fk_ref[fk["name"]]) - @testing.provide_metadata - def test_inspect_enums_schema(self, connection): + def test_inspect_enums_schema(self, metadata, connection): enum_type = postgresql.ENUM( "sad", "ok", "happy", name="mood", schema="test_schema", - metadata=self.metadata, + metadata=metadata, ) enum_type.create(connection) inspector = inspect(connection) @@ -1310,13 +1281,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_inspect_enums(self): + def test_inspect_enums(self, metadata, connection): enum_type = postgresql.ENUM( - "cat", "dog", "rat", name="pet", metadata=self.metadata + "cat", "dog", "rat", name="pet", metadata=metadata ) - enum_type.create(testing.db) - inspector = inspect(testing.db) + enum_type.create(connection) + inspector = inspect(connection) eq_( inspector.get_enums(), [ @@ -1329,17 +1299,16 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_inspect_enums_case_sensitive(self): + def test_inspect_enums_case_sensitive(self, metadata, connection): sa.event.listen( - self.metadata, + metadata, "before_create", sa.DDL('create schema "TestSchema"'), ) sa.event.listen( - self.metadata, + metadata, "after_drop", - sa.DDL('drop schema "TestSchema" cascade'), + sa.DDL('drop schema if exists "TestSchema" cascade'), ) for enum in "lower_case", "UpperCase", "Name.With.Dot": @@ -1350,11 +1319,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): "CapsTwo", name=enum, schema=schema, - metadata=self.metadata, + metadata=metadata, ) - self.metadata.create_all(testing.db) - inspector = inspect(testing.db) + metadata.create_all(connection) + inspector = inspect(connection) for schema in None, "test_schema", "TestSchema": eq_( sorted( @@ -1382,17 +1351,18 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_inspect_enums_case_sensitive_from_table(self): + def test_inspect_enums_case_sensitive_from_table( + self, metadata, connection + ): sa.event.listen( - self.metadata, + metadata, "before_create", sa.DDL('create schema "TestSchema"'), ) sa.event.listen( - self.metadata, + metadata, "after_drop", - sa.DDL('drop schema "TestSchema" cascade'), + sa.DDL('drop schema if exists "TestSchema" cascade'), ) counter = itertools.count() @@ -1403,19 +1373,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): "CapsOne", "CapsTwo", name=enum, - metadata=self.metadata, + metadata=metadata, schema=schema, ) Table( "t%d" % next(counter), - self.metadata, + metadata, Column("q", enum_type), ) - self.metadata.create_all(testing.db) + metadata.create_all(connection) - inspector = inspect(testing.db) + inspector = inspect(connection) counter = itertools.count() for enum in "lower_case", "UpperCase", "Name.With.Dot": for schema in None, "test_schema", "TestSchema": @@ -1439,10 +1409,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_inspect_enums_star(self): + def test_inspect_enums_star(self, metadata, connection): enum_type = postgresql.ENUM( - "cat", "dog", "rat", name="pet", metadata=self.metadata + "cat", "dog", "rat", name="pet", metadata=metadata ) schema_enum_type = postgresql.ENUM( "sad", @@ -1450,11 +1419,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): "happy", name="mood", schema="test_schema", - metadata=self.metadata, + metadata=metadata, ) - enum_type.create(testing.db) - schema_enum_type.create(testing.db) - inspector = inspect(testing.db) + enum_type.create(connection) + schema_enum_type.create(connection) + inspector = inspect(connection) eq_( inspector.get_enums(), @@ -1486,11 +1455,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_inspect_enum_empty(self): - enum_type = postgresql.ENUM(name="empty", metadata=self.metadata) - enum_type.create(testing.db) - inspector = inspect(testing.db) + def test_inspect_enum_empty(self, metadata, connection): + enum_type = postgresql.ENUM(name="empty", metadata=metadata) + enum_type.create(connection) + inspector = inspect(connection) eq_( inspector.get_enums(), @@ -1504,13 +1472,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_inspect_enum_empty_from_table(self): + def test_inspect_enum_empty_from_table(self, metadata, connection): Table( - "t", self.metadata, Column("x", postgresql.ENUM(name="empty")) - ).create(testing.db) + "t", metadata, Column("x", postgresql.ENUM(name="empty")) + ).create(connection) - t = Table("t", MetaData(), autoload_with=testing.db) + t = Table("t", MetaData(), autoload_with=connection) eq_(t.c.x.type.enums, []) def test_reflection_with_unique_constraint(self, metadata, connection): @@ -1749,12 +1716,12 @@ class CustomTypeReflectionTest(fixtures.TestBase): ischema_names = None - def setup(self): + def setup_test(self): ischema_names = postgresql.PGDialect.ischema_names postgresql.PGDialect.ischema_names = ischema_names.copy() self.ischema_names = ischema_names - def teardown(self): + def teardown_test(self): postgresql.PGDialect.ischema_names = self.ischema_names self.ischema_names = None @@ -1788,55 +1755,51 @@ class IntervalReflectionTest(fixtures.TestBase): __only_on__ = "postgresql" __backend__ = True - def test_interval_types(self): - for sym in [ - "YEAR", - "MONTH", - "DAY", - "HOUR", - "MINUTE", - "SECOND", - "YEAR TO MONTH", - "DAY TO HOUR", - "DAY TO MINUTE", - "DAY TO SECOND", - "HOUR TO MINUTE", - "HOUR TO SECOND", - "MINUTE TO SECOND", - ]: - self._test_interval_symbol(sym) - - @testing.provide_metadata - def _test_interval_symbol(self, sym): + @testing.combinations( + ("YEAR",), + ("MONTH",), + ("DAY",), + ("HOUR",), + ("MINUTE",), + ("SECOND",), + ("YEAR TO MONTH",), + ("DAY TO HOUR",), + ("DAY TO MINUTE",), + ("DAY TO SECOND",), + ("HOUR TO MINUTE",), + ("HOUR TO SECOND",), + ("MINUTE TO SECOND",), + argnames="sym", + ) + def test_interval_types(self, sym, metadata, connection): t = Table( "i_test", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("data1", INTERVAL(fields=sym)), ) - t.create(testing.db) + t.create(connection) columns = { rec["name"]: rec - for rec in inspect(testing.db).get_columns("i_test") + for rec in inspect(connection).get_columns("i_test") } assert isinstance(columns["data1"]["type"], INTERVAL) eq_(columns["data1"]["type"].fields, sym.lower()) eq_(columns["data1"]["type"].precision, None) - @testing.provide_metadata - def test_interval_precision(self): + def test_interval_precision(self, metadata, connection): t = Table( "i_test", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("data1", INTERVAL(precision=6)), ) - t.create(testing.db) + t.create(connection) columns = { rec["name"]: rec - for rec in inspect(testing.db).get_columns("i_test") + for rec in inspect(connection).get_columns("i_test") } assert isinstance(columns["data1"]["type"], INTERVAL) eq_(columns["data1"]["type"].fields, None) @@ -1871,8 +1834,8 @@ class IdentityReflectionTest(fixtures.TablesTest): Column("id4", SmallInteger, Identity()), ) - def test_reflect_identity(self): - insp = inspect(testing.db) + def test_reflect_identity(self, connection): + insp = inspect(connection) default = dict( always=False, start=1, diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index e8a1876c7..6202f8f86 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -49,7 +49,6 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session from sqlalchemy.sql import operators from sqlalchemy.sql import sqltypes -from sqlalchemy.testing import engines from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_raises_message @@ -156,8 +155,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql > 8.3" - @testing.provide_metadata - def test_create_table(self, connection): + def test_create_table(self, metadata, connection): metadata = self.metadata t1 = Table( "table", @@ -177,8 +175,8 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): [(1, "two"), (2, "three"), (3, "three")], ) - @testing.combinations(None, "foo") - def test_create_table_schema_translate_map(self, symbol_name): + @testing.combinations(None, "foo", argnames="symbol_name") + def test_create_table_schema_translate_map(self, connection, symbol_name): # note we can't use the fixture here because it will not drop # from the correct schema metadata = MetaData() @@ -199,35 +197,30 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ), schema=symbol_name, ) - with testing.db.begin() as conn: - conn = conn.execution_options( - schema_translate_map={symbol_name: testing.config.test_schema} - ) - t1.create(conn) - assert "schema_enum" in [ - e["name"] - for e in inspect(conn).get_enums( - schema=testing.config.test_schema - ) - ] - t1.create(conn, checkfirst=True) + conn = connection.execution_options( + schema_translate_map={symbol_name: testing.config.test_schema} + ) + t1.create(conn) + assert "schema_enum" in [ + e["name"] + for e in inspect(conn).get_enums(schema=testing.config.test_schema) + ] + t1.create(conn, checkfirst=True) - conn.execute(t1.insert(), value="two") - conn.execute(t1.insert(), value="three") - conn.execute(t1.insert(), value="three") - eq_( - conn.execute(t1.select().order_by(t1.c.id)).fetchall(), - [(1, "two"), (2, "three"), (3, "three")], - ) + conn.execute(t1.insert(), value="two") + conn.execute(t1.insert(), value="three") + conn.execute(t1.insert(), value="three") + eq_( + conn.execute(t1.select().order_by(t1.c.id)).fetchall(), + [(1, "two"), (2, "three"), (3, "three")], + ) - t1.drop(conn) - assert "schema_enum" not in [ - e["name"] - for e in inspect(conn).get_enums( - schema=testing.config.test_schema - ) - ] - t1.drop(conn, checkfirst=True) + t1.drop(conn) + assert "schema_enum" not in [ + e["name"] + for e in inspect(conn).get_enums(schema=testing.config.test_schema) + ] + t1.drop(conn, checkfirst=True) def test_name_required(self, metadata, connection): etype = Enum("four", "five", "six", metadata=metadata) @@ -270,8 +263,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): [util.u("réveillé"), util.u("drôle"), util.u("S’il")], ) - @testing.provide_metadata - def test_non_native_enum(self, connection): + def test_non_native_enum(self, metadata, connection): metadata = self.metadata t1 = Table( "foo", @@ -290,10 +282,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ) def go(): - t1.create(testing.db) + t1.create(connection) self.assert_sql( - testing.db, + connection, go, [ ( @@ -307,8 +299,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): connection.execute(t1.insert(), {"bar": "two"}) eq_(connection.scalar(select(t1.c.bar)), "two") - @testing.provide_metadata - def test_non_native_enum_w_unicode(self, connection): + def test_non_native_enum_w_unicode(self, metadata, connection): metadata = self.metadata t1 = Table( "foo", @@ -326,10 +317,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ) def go(): - t1.create(testing.db) + t1.create(connection) self.assert_sql( - testing.db, + connection, go, [ ( @@ -346,8 +337,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): connection.execute(t1.insert(), {"bar": util.u("Ü")}) eq_(connection.scalar(select(t1.c.bar)), util.u("Ü")) - @testing.provide_metadata - def test_disable_create(self): + def test_disable_create(self, metadata, connection): metadata = self.metadata e1 = postgresql.ENUM( @@ -357,13 +347,12 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): t1 = Table("e1", metadata, Column("c1", e1)) # table can be created separately # without conflict - e1.create(bind=testing.db) - t1.create(testing.db) - t1.drop(testing.db) - e1.drop(bind=testing.db) + e1.create(bind=connection) + t1.create(connection) + t1.drop(connection) + e1.drop(bind=connection) - @testing.provide_metadata - def test_dont_keep_checking(self, connection): + def test_dont_keep_checking(self, metadata, connection): metadata = self.metadata e1 = postgresql.ENUM("one", "two", "three", name="myenum") @@ -560,11 +549,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): e["name"] for e in inspect(connection).get_enums() ] - def test_non_native_dialect(self): - engine = engines.testing_engine() + def test_non_native_dialect(self, metadata, testing_engine): + engine = testing_engine() engine.connect() engine.dialect.supports_native_enum = False - metadata = MetaData() t1 = Table( "foo", metadata, @@ -583,21 +571,18 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): def go(): t1.create(engine) - try: - self.assert_sql( - engine, - go, - [ - ( - "CREATE TABLE foo (bar " - "VARCHAR(5), CONSTRAINT myenum CHECK " - "(bar IN ('one', 'two', 'three')))", - {}, - ) - ], - ) - finally: - metadata.drop_all(engine) + self.assert_sql( + engine, + go, + [ + ( + "CREATE TABLE foo (bar " + "VARCHAR(5), CONSTRAINT myenum CHECK " + "(bar IN ('one', 'two', 'three')))", + {}, + ) + ], + ) def test_standalone_enum(self, connection, metadata): etype = Enum( @@ -605,26 +590,26 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ) etype.create(connection) try: - assert testing.db.dialect.has_type(connection, "fourfivesixtype") + assert connection.dialect.has_type(connection, "fourfivesixtype") finally: etype.drop(connection) - assert not testing.db.dialect.has_type( + assert not connection.dialect.has_type( connection, "fourfivesixtype" ) metadata.create_all(connection) try: - assert testing.db.dialect.has_type(connection, "fourfivesixtype") + assert connection.dialect.has_type(connection, "fourfivesixtype") finally: metadata.drop_all(connection) - assert not testing.db.dialect.has_type( + assert not connection.dialect.has_type( connection, "fourfivesixtype" ) - def test_no_support(self): + def test_no_support(self, testing_engine): def server_version_info(self): return (8, 2) - e = engines.testing_engine() + e = testing_engine() dialect = e.dialect dialect._get_server_version_info = server_version_info @@ -692,8 +677,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): eq_(t2.c.value2.type.name, "fourfivesixtype") eq_(t2.c.value2.type.schema, "test_schema") - @testing.provide_metadata - def test_custom_subclass(self, connection): + def test_custom_subclass(self, metadata, connection): class MyEnum(TypeDecorator): impl = Enum("oneHI", "twoHI", "threeHI", name="myenum") @@ -708,13 +692,12 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): return value t1 = Table("table1", self.metadata, Column("data", MyEnum())) - self.metadata.create_all(testing.db) + self.metadata.create_all(connection) connection.execute(t1.insert(), {"data": "two"}) eq_(connection.scalar(select(t1.c.data)), "twoHITHERE") - @testing.provide_metadata - def test_generic_w_pg_variant(self, connection): + def test_generic_w_pg_variant(self, metadata, connection): some_table = Table( "some_table", self.metadata, @@ -752,8 +735,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): e["name"] for e in inspect(connection).get_enums() ] - @testing.provide_metadata - def test_generic_w_some_other_variant(self, connection): + def test_generic_w_some_other_variant(self, metadata, connection): some_table = Table( "some_table", self.metadata, @@ -809,26 +791,28 @@ class RegClassTest(fixtures.TestBase): __only_on__ = "postgresql" __backend__ = True - @staticmethod - def _scalar(expression): - with testing.db.connect() as conn: - return conn.scalar(select(expression)) + @testing.fixture() + def scalar(self, connection): + def go(expression): + return connection.scalar(select(expression)) - def test_cast_name(self): - eq_(self._scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class") + return go - def test_cast_path(self): + def test_cast_name(self, scalar): + eq_(scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class") + + def test_cast_path(self, scalar): eq_( - self._scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)), + scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)), "pg_class", ) - def test_cast_oid(self): + def test_cast_oid(self, scalar): regclass = cast("pg_class", postgresql.REGCLASS) - oid = self._scalar(cast(regclass, postgresql.OID)) + oid = scalar(cast(regclass, postgresql.OID)) assert isinstance(oid, int) eq_( - self._scalar( + scalar( cast(type_coerce(oid, postgresql.OID), postgresql.REGCLASS) ), "pg_class", @@ -1339,13 +1323,12 @@ class ArrayRoundTripTest(object): Column("dimarr", ProcValue), ) - def _fixture_456(self, table): - with testing.db.begin() as conn: - conn.execute(table.insert(), intarr=[4, 5, 6]) + def _fixture_456(self, table, connection): + connection.execute(table.insert(), intarr=[4, 5, 6]) - def test_reflect_array_column(self): + def test_reflect_array_column(self, connection): metadata2 = MetaData() - tbl = Table("arrtable", metadata2, autoload_with=testing.db) + tbl = Table("arrtable", metadata2, autoload_with=connection) assert isinstance(tbl.c.intarr.type, self.ARRAY) assert isinstance(tbl.c.strarr.type, self.ARRAY) assert isinstance(tbl.c.intarr.type.item_type, Integer) @@ -1564,7 +1547,7 @@ class ArrayRoundTripTest(object): def test_array_getitem_single_exec(self, connection): arrtable = self.tables.arrtable - self._fixture_456(arrtable) + self._fixture_456(arrtable, connection) eq_(connection.scalar(select(arrtable.c.intarr[2])), 5) connection.execute(arrtable.update().values({arrtable.c.intarr[2]: 7})) eq_(connection.scalar(select(arrtable.c.intarr[2])), 7) @@ -1654,11 +1637,10 @@ class ArrayRoundTripTest(object): set([("1", "2", "3"), ("4", "5", "6"), (("4", "5"), ("6", "7"))]), ) - def test_array_plus_native_enum_create(self): - m = MetaData() + def test_array_plus_native_enum_create(self, metadata, connection): t = Table( "t", - m, + metadata, Column( "data_1", self.ARRAY(postgresql.ENUM("a", "b", "c", name="my_enum_1")), @@ -1669,13 +1651,13 @@ class ArrayRoundTripTest(object): ), ) - t.create(testing.db) + t.create(connection) eq_( - set(e["name"] for e in inspect(testing.db).get_enums()), + set(e["name"] for e in inspect(connection).get_enums()), set(["my_enum_1", "my_enum_2"]), ) - t.drop(testing.db) - eq_(inspect(testing.db).get_enums(), []) + t.drop(connection) + eq_(inspect(connection).get_enums(), []) class CoreArrayRoundTripTest( @@ -1690,33 +1672,35 @@ class PGArrayRoundTripTest( ): ARRAY = postgresql.ARRAY - @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),)) - def test_undim_array_contains_typed_exec(self, struct): + @testing.combinations( + (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct" + ) + def test_undim_array_contains_typed_exec(self, struct, connection): arrtable = self.tables.arrtable - self._fixture_456(arrtable) - with testing.db.begin() as conn: - eq_( - conn.scalar( - select(arrtable.c.intarr).where( - arrtable.c.intarr.contains(struct([4, 5])) - ) - ), - [4, 5, 6], - ) + self._fixture_456(arrtable, connection) + eq_( + connection.scalar( + select(arrtable.c.intarr).where( + arrtable.c.intarr.contains(struct([4, 5])) + ) + ), + [4, 5, 6], + ) - @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),)) - def test_dim_array_contains_typed_exec(self, struct): + @testing.combinations( + (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct" + ) + def test_dim_array_contains_typed_exec(self, struct, connection): dim_arrtable = self.tables.dim_arrtable - self._fixture_456(dim_arrtable) - with testing.db.begin() as conn: - eq_( - conn.scalar( - select(dim_arrtable.c.intarr).where( - dim_arrtable.c.intarr.contains(struct([4, 5])) - ) - ), - [4, 5, 6], - ) + self._fixture_456(dim_arrtable, connection) + eq_( + connection.scalar( + select(dim_arrtable.c.intarr).where( + dim_arrtable.c.intarr.contains(struct([4, 5])) + ) + ), + [4, 5, 6], + ) def test_array_contained_by_exec(self, connection): arrtable = self.tables.arrtable @@ -1730,7 +1714,7 @@ class PGArrayRoundTripTest( def test_undim_array_empty(self, connection): arrtable = self.tables.arrtable - self._fixture_456(arrtable) + self._fixture_456(arrtable, connection) eq_( connection.scalar( select(arrtable.c.intarr).where(arrtable.c.intarr.contains([])) @@ -1782,8 +1766,9 @@ class ArrayEnum(fixtures.TestBase): sqltypes.ARRAY, postgresql.ARRAY, argnames="array_cls" ) @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls") - @testing.provide_metadata - def test_raises_non_native_enums(self, array_cls, enum_cls): + def test_raises_non_native_enums( + self, metadata, connection, array_cls, enum_cls + ): Table( "my_table", self.metadata, @@ -1808,7 +1793,7 @@ class ArrayEnum(fixtures.TestBase): "for ARRAY of non-native ENUM; please specify " "create_constraint=False on this Enum datatype.", self.metadata.create_all, - testing.db, + connection, ) @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls") @@ -1818,8 +1803,7 @@ class ArrayEnum(fixtures.TestBase): (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")), argnames="array_cls", ) - @testing.provide_metadata - def test_array_of_enums(self, array_cls, enum_cls, connection): + def test_array_of_enums(self, array_cls, enum_cls, metadata, connection): tbl = Table( "enum_table", self.metadata, @@ -1875,8 +1859,7 @@ class ArrayJSON(fixtures.TestBase): @testing.combinations( sqltypes.JSON, postgresql.JSON, postgresql.JSONB, argnames="json_cls" ) - @testing.provide_metadata - def test_array_of_json(self, array_cls, json_cls, connection): + def test_array_of_json(self, array_cls, json_cls, metadata, connection): tbl = Table( "json_table", self.metadata, @@ -1982,19 +1965,38 @@ class HashableFlagORMTest(fixtures.TestBase): }, ], ), + ( + "HSTORE", + postgresql.HSTORE(), + [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}], + testing.requires.hstore, + ), + ( + "JSONB", + postgresql.JSONB(), + [ + {"a": "1", "b": "2", "c": "3"}, + { + "d": "4", + "e": {"e1": "5", "e2": "6"}, + "f": {"f1": [9, 10, 11]}, + }, + ], + testing.requires.postgresql_jsonb, + ), + argnames="type_,data", id_="iaa", ) - @testing.provide_metadata - def test_hashable_flag(self, type_, data): - Base = declarative_base(metadata=self.metadata) + def test_hashable_flag(self, metadata, connection, type_, data): + Base = declarative_base(metadata=metadata) class A(Base): __tablename__ = "a1" id = Column(Integer, primary_key=True) data = Column(type_) - Base.metadata.create_all(testing.db) - s = Session(testing.db) + Base.metadata.create_all(connection) + s = Session(connection) s.add_all([A(data=elem) for elem in data]) s.commit() @@ -2006,27 +2008,6 @@ class HashableFlagORMTest(fixtures.TestBase): list(enumerate(data, 1)), ) - @testing.requires.hstore - def test_hstore(self): - self.test_hashable_flag( - postgresql.HSTORE(), - [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}], - ) - - @testing.requires.postgresql_jsonb - def test_jsonb(self): - self.test_hashable_flag( - postgresql.JSONB(), - [ - {"a": "1", "b": "2", "c": "3"}, - { - "d": "4", - "e": {"e1": "5", "e2": "6"}, - "f": {"f1": [9, 10, 11]}, - }, - ], - ) - class TimestampTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql" @@ -2108,14 +2089,14 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): return table - def test_reflection(self, special_types_table): + def test_reflection(self, special_types_table, connection): # cheat so that the "strict type check" # works special_types_table.c.year_interval.type = postgresql.INTERVAL() special_types_table.c.month_interval.type = postgresql.INTERVAL() m = MetaData() - t = Table("sometable", m, autoload_with=testing.db) + t = Table("sometable", m, autoload_with=connection) self.assert_tables_equal(special_types_table, t, strict_types=True) assert t.c.plain_interval.type.precision is None @@ -2210,7 +2191,7 @@ class UUIDTest(fixtures.TestBase): class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): __dialect__ = "postgresql" - def setup(self): + def setup_test(self): metadata = MetaData() self.test_table = Table( "test_table", @@ -2494,17 +2475,16 @@ class HStoreRoundTripTest(fixtures.TablesTest): Column("data", HSTORE), ) - def _fixture_data(self, engine): + def _fixture_data(self, connection): data_table = self.tables.data_table - with engine.begin() as conn: - conn.execute( - data_table.insert(), - {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, - {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}}, - {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}}, - {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}}, - {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}}, - ) + connection.execute( + data_table.insert(), + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, + {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}}, + {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}}, + {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}}, + {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}}, + ) def _assert_data(self, compare, conn): data = conn.execute( @@ -2514,26 +2494,32 @@ class HStoreRoundTripTest(fixtures.TablesTest): ).fetchall() eq_([d for d, in data], compare) - def _test_insert(self, engine): - with engine.begin() as conn: - conn.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, - ) - self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn) + def _test_insert(self, connection): + connection.execute( + self.tables.data_table.insert(), + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, + ) + self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection) - def _non_native_engine(self): - if testing.requires.psycopg2_native_hstore.enabled: - engine = engines.testing_engine( - options=dict(use_native_hstore=False) - ) + @testing.fixture + def non_native_hstore_connection(self, testing_engine): + local_engine = testing.requires.psycopg2_native_hstore.enabled + + if local_engine: + engine = testing_engine(options=dict(use_native_hstore=False)) else: engine = testing.db - engine.connect().close() - return engine - def test_reflect(self): - insp = inspect(testing.db) + conn = engine.connect() + trans = conn.begin() + yield conn + try: + trans.rollback() + finally: + conn.close() + + def test_reflect(self, connection): + insp = inspect(connection) cols = insp.get_columns("data_table") assert isinstance(cols[2]["type"], HSTORE) @@ -2548,106 +2534,88 @@ class HStoreRoundTripTest(fixtures.TablesTest): eq_(connection.scalar(select(expr)), "3") @testing.requires.psycopg2_native_hstore - def test_insert_native(self): - engine = testing.db - self._test_insert(engine) + def test_insert_native(self, connection): + self._test_insert(connection) - def test_insert_python(self): - engine = self._non_native_engine() - self._test_insert(engine) + def test_insert_python(self, non_native_hstore_connection): + self._test_insert(non_native_hstore_connection) @testing.requires.psycopg2_native_hstore - def test_criterion_native(self): - engine = testing.db - self._fixture_data(engine) - self._test_criterion(engine) + def test_criterion_native(self, connection): + self._fixture_data(connection) + self._test_criterion(connection) - def test_criterion_python(self): - engine = self._non_native_engine() - self._fixture_data(engine) - self._test_criterion(engine) + def test_criterion_python(self, non_native_hstore_connection): + self._fixture_data(non_native_hstore_connection) + self._test_criterion(non_native_hstore_connection) - def _test_criterion(self, engine): + def _test_criterion(self, connection): data_table = self.tables.data_table - with engine.begin() as conn: - result = conn.execute( - select(data_table.c.data).where( - data_table.c.data["k1"] == "r3v1" - ) - ).first() - eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) + result = connection.execute( + select(data_table.c.data).where(data_table.c.data["k1"] == "r3v1") + ).first() + eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) - def _test_fixed_round_trip(self, engine): - with engine.begin() as conn: - s = select( - hstore( - array(["key1", "key2", "key3"]), - array(["value1", "value2", "value3"]), - ) - ) - eq_( - conn.scalar(s), - {"key1": "value1", "key2": "value2", "key3": "value3"}, + def _test_fixed_round_trip(self, connection): + s = select( + hstore( + array(["key1", "key2", "key3"]), + array(["value1", "value2", "value3"]), ) + ) + eq_( + connection.scalar(s), + {"key1": "value1", "key2": "value2", "key3": "value3"}, + ) - def test_fixed_round_trip_python(self): - engine = self._non_native_engine() - self._test_fixed_round_trip(engine) + def test_fixed_round_trip_python(self, non_native_hstore_connection): + self._test_fixed_round_trip(non_native_hstore_connection) @testing.requires.psycopg2_native_hstore - def test_fixed_round_trip_native(self): - engine = testing.db - self._test_fixed_round_trip(engine) + def test_fixed_round_trip_native(self, connection): + self._test_fixed_round_trip(connection) - def _test_unicode_round_trip(self, engine): - with engine.begin() as conn: - s = select( - hstore( - array( - [util.u("réveillé"), util.u("drôle"), util.u("S’il")] - ), - array( - [util.u("réveillé"), util.u("drôle"), util.u("S’il")] - ), - ) - ) - eq_( - conn.scalar(s), - { - util.u("réveillé"): util.u("réveillé"), - util.u("drôle"): util.u("drôle"), - util.u("S’il"): util.u("S’il"), - }, + def _test_unicode_round_trip(self, connection): + s = select( + hstore( + array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]), + array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]), ) + ) + eq_( + connection.scalar(s), + { + util.u("réveillé"): util.u("réveillé"), + util.u("drôle"): util.u("drôle"), + util.u("S’il"): util.u("S’il"), + }, + ) @testing.requires.psycopg2_native_hstore - def test_unicode_round_trip_python(self): - engine = self._non_native_engine() - self._test_unicode_round_trip(engine) + def test_unicode_round_trip_python(self, non_native_hstore_connection): + self._test_unicode_round_trip(non_native_hstore_connection) @testing.requires.psycopg2_native_hstore - def test_unicode_round_trip_native(self): - engine = testing.db - self._test_unicode_round_trip(engine) + def test_unicode_round_trip_native(self, connection): + self._test_unicode_round_trip(connection) - def test_escaped_quotes_round_trip_python(self): - engine = self._non_native_engine() - self._test_escaped_quotes_round_trip(engine) + def test_escaped_quotes_round_trip_python( + self, non_native_hstore_connection + ): + self._test_escaped_quotes_round_trip(non_native_hstore_connection) @testing.requires.psycopg2_native_hstore - def test_escaped_quotes_round_trip_native(self): - engine = testing.db - self._test_escaped_quotes_round_trip(engine) + def test_escaped_quotes_round_trip_native(self, connection): + self._test_escaped_quotes_round_trip(connection) - def _test_escaped_quotes_round_trip(self, engine): - with engine.begin() as conn: - conn.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}}, - ) - self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], conn) + def _test_escaped_quotes_round_trip(self, connection): + connection.execute( + self.tables.data_table.insert(), + {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}}, + ) + self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], connection) - def test_orm_round_trip(self): + def test_orm_round_trip(self, connection): from sqlalchemy import orm class Data(object): @@ -2656,13 +2624,14 @@ class HStoreRoundTripTest(fixtures.TablesTest): self.data = data orm.mapper(Data, self.tables.data_table) - s = orm.Session(testing.db) - d = Data( - name="r1", - data={"key1": "value1", "key2": "value2", "key3": "value3"}, - ) - s.add(d) - eq_(s.query(Data.data, Data).all(), [(d.data, d)]) + + with orm.Session(connection) as s: + d = Data( + name="r1", + data={"key1": "value1", "key2": "value2", "key3": "value3"}, + ) + s.add(d) + eq_(s.query(Data.data, Data).all(), [(d.data, d)]) class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): @@ -2671,7 +2640,7 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): # operator tests @classmethod - def setup_class(cls): + def setup_test_class(cls): table = Table( "data_table", MetaData(), @@ -2852,10 +2821,10 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): def test_actual_type(self): eq_(str(self._col_type()), self._col_str) - def test_reflect(self): + def test_reflect(self, connection): from sqlalchemy import inspect - insp = inspect(testing.db) + insp = inspect(connection) cols = insp.get_columns("data_table") assert isinstance(cols[0]["type"], self._col_type) @@ -2986,8 +2955,8 @@ class _DateTimeTZRangeTests(object): def tstzs(self): if self._tstzs is None: - with testing.db.begin() as conn: - lower = conn.scalar(func.current_timestamp().select()) + with testing.db.connect() as connection: + lower = connection.scalar(func.current_timestamp().select()) upper = lower + datetime.timedelta(1) self._tstzs = (lower, upper) return self._tstzs @@ -3052,7 +3021,7 @@ class DateTimeTZRangeRoundTripTest(_DateTimeTZRangeTests, _RangeTypeRoundTrip): class JSONTest(AssertsCompiledSQL, fixtures.TestBase): __dialect__ = "postgresql" - def setup(self): + def setup_test(self): metadata = MetaData() self.test_table = Table( "test_table", @@ -3151,7 +3120,7 @@ class JSONRoundTripTest(fixtures.TablesTest): Column("nulldata", cls.data_type(none_as_null=True)), ) - def _fixture_data(self, engine): + def _fixture_data(self, connection): data_table = self.tables.data_table data = [ @@ -3162,8 +3131,7 @@ class JSONRoundTripTest(fixtures.TablesTest): {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}}, {"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}}, ] - with engine.begin() as conn: - conn.execute(data_table.insert(), data) + connection.execute(data_table.insert(), data) return data def _assert_data(self, compare, conn, column="data"): @@ -3185,51 +3153,39 @@ class JSONRoundTripTest(fixtures.TablesTest): ).fetchall() eq_([d for d, in data], [None]) - def _test_insert(self, conn): - conn.execute( + def test_reflect(self, connection): + insp = inspect(connection) + cols = insp.get_columns("data_table") + assert isinstance(cols[2]["type"], self.data_type) + + def test_insert(self, connection): + connection.execute( self.tables.data_table.insert(), {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, ) - self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn) + self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection) - def _test_insert_nulls(self, conn): - conn.execute( + def test_insert_nulls(self, connection): + connection.execute( self.tables.data_table.insert(), {"name": "r1", "data": null()} ) - self._assert_data([None], conn) + self._assert_data([None], connection) - def _test_insert_none_as_null(self, conn): - conn.execute( + def test_insert_none_as_null(self, connection): + connection.execute( self.tables.data_table.insert(), {"name": "r1", "nulldata": None}, ) - self._assert_column_is_NULL(conn, column="nulldata") + self._assert_column_is_NULL(connection, column="nulldata") - def _test_insert_nulljson_into_none_as_null(self, conn): - conn.execute( + def test_insert_nulljson_into_none_as_null(self, connection): + connection.execute( self.tables.data_table.insert(), {"name": "r1", "nulldata": JSON.NULL}, ) - self._assert_column_is_JSON_NULL(conn, column="nulldata") - - def test_reflect(self): - insp = inspect(testing.db) - cols = insp.get_columns("data_table") - assert isinstance(cols[2]["type"], self.data_type) - - def test_insert(self, connection): - self._test_insert(connection) - - def test_insert_nulls(self, connection): - self._test_insert_nulls(connection) + self._assert_column_is_JSON_NULL(connection, column="nulldata") - def test_insert_none_as_null(self, connection): - self._test_insert_none_as_null(connection) - - def test_insert_nulljson_into_none_as_null(self, connection): - self._test_insert_nulljson_into_none_as_null(connection) - - def test_custom_serialize_deserialize(self): + def test_custom_serialize_deserialize(self, testing_engine): import json def loads(value): @@ -3242,7 +3198,7 @@ class JSONRoundTripTest(fixtures.TablesTest): value["x"] = "dumps_y" return json.dumps(value) - engine = engines.testing_engine( + engine = testing_engine( options=dict(json_serializer=dumps, json_deserializer=loads) ) @@ -3250,14 +3206,26 @@ class JSONRoundTripTest(fixtures.TablesTest): with engine.begin() as conn: eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"}) - def test_criterion(self): - engine = testing.db - self._fixture_data(engine) - self._test_criterion(engine) + def test_criterion(self, connection): + self._fixture_data(connection) + data_table = self.tables.data_table + + result = connection.execute( + select(data_table.c.data).where( + data_table.c.data["k1"].astext == "r3v1" + ) + ).first() + eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) + + result = connection.execute( + select(data_table.c.data).where( + data_table.c.data["k1"].astext.cast(String) == "r3v1" + ) + ).first() + eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) def test_path_query(self, connection): - engine = testing.db - self._fixture_data(engine) + self._fixture_data(connection) data_table = self.tables.data_table result = connection.execute( @@ -3271,8 +3239,7 @@ class JSONRoundTripTest(fixtures.TablesTest): "postgresql < 9.4", "Improvement in PostgreSQL behavior?" ) def test_multi_index_query(self, connection): - engine = testing.db - self._fixture_data(engine) + self._fixture_data(connection) data_table = self.tables.data_table result = connection.execute( @@ -3283,20 +3250,18 @@ class JSONRoundTripTest(fixtures.TablesTest): eq_(result.scalar(), "r6") def test_query_returned_as_text(self, connection): - engine = testing.db - self._fixture_data(engine) + self._fixture_data(connection) data_table = self.tables.data_table result = connection.execute( select(data_table.c.data["k1"].astext) ).first() - if engine.dialect.returns_unicode_strings: + if connection.dialect.returns_unicode_strings: assert isinstance(result[0], util.text_type) else: assert isinstance(result[0], util.string_types) def test_query_returned_as_int(self, connection): - engine = testing.db - self._fixture_data(engine) + self._fixture_data(connection) data_table = self.tables.data_table result = connection.execute( select(data_table.c.data["k3"].astext.cast(Integer)).where( @@ -3305,23 +3270,6 @@ class JSONRoundTripTest(fixtures.TablesTest): ).first() assert isinstance(result[0], int) - def _test_criterion(self, engine): - data_table = self.tables.data_table - with engine.begin() as conn: - result = conn.execute( - select(data_table.c.data).where( - data_table.c.data["k1"].astext == "r3v1" - ) - ).first() - eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) - - result = conn.execute( - select(data_table.c.data).where( - data_table.c.data["k1"].astext.cast(String) == "r3v1" - ) - ).first() - eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) - def test_fixed_round_trip(self, connection): s = select( cast( @@ -3352,42 +3300,41 @@ class JSONRoundTripTest(fixtures.TablesTest): }, ) - def test_eval_none_flag_orm(self): + def test_eval_none_flag_orm(self, connection): Base = declarative_base() class Data(Base): __table__ = self.tables.data_table - s = Session(testing.db) + with Session(connection) as s: + d1 = Data(name="d1", data=None, nulldata=None) + s.add(d1) + s.commit() - d1 = Data(name="d1", data=None, nulldata=None) - s.add(d1) - s.commit() - - s.bulk_insert_mappings( - Data, [{"name": "d2", "data": None, "nulldata": None}] - ) - eq_( - s.query( - cast(self.tables.data_table.c.data, String), - cast(self.tables.data_table.c.nulldata, String), + s.bulk_insert_mappings( + Data, [{"name": "d2", "data": None, "nulldata": None}] ) - .filter(self.tables.data_table.c.name == "d1") - .first(), - ("null", None), - ) - eq_( - s.query( - cast(self.tables.data_table.c.data, String), - cast(self.tables.data_table.c.nulldata, String), + eq_( + s.query( + cast(self.tables.data_table.c.data, String), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), ) - .filter(self.tables.data_table.c.name == "d2") - .first(), - ("null", None), - ) def test_literal(self, connection): - exp = self._fixture_data(testing.db) + exp = self._fixture_data(connection) result = connection.exec_driver_sql( "select data from data_table order by name" ) @@ -3395,11 +3342,10 @@ class JSONRoundTripTest(fixtures.TablesTest): eq_(len(res), len(exp)) for row, expected in zip(res, exp): eq_(row[0], expected["data"]) - result.close() class JSONBTest(JSONTest): - def setup(self): + def setup_test(self): metadata = MetaData() self.test_table = Table( "test_table", diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 4658b40a8..1926c6065 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -37,6 +37,7 @@ from sqlalchemy import UniqueConstraint from sqlalchemy import util from sqlalchemy.dialects.sqlite import base as sqlite from sqlalchemy.dialects.sqlite import insert +from sqlalchemy.dialects.sqlite import provision from sqlalchemy.dialects.sqlite import pysqlite as pysqlite_dialect from sqlalchemy.engine.url import make_url from sqlalchemy.schema import CreateTable @@ -46,6 +47,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import AssertsExecutionResults from sqlalchemy.testing import combinations +from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings @@ -774,7 +776,7 @@ class AttachedDBTest(fixtures.TestBase): def _fixture(self): meta = self.metadata - self.conn = testing.db.connect() + self.conn = self.engine.connect() Table("created", meta, Column("foo", Integer), Column("bar", String)) Table("local_only", meta, Column("q", Integer), Column("p", Integer)) @@ -798,14 +800,20 @@ class AttachedDBTest(fixtures.TestBase): meta.create_all(self.conn) return ct - def setup(self): - self.conn = testing.db.connect() + def setup_test(self): + self.engine = engines.testing_engine(options={"use_reaper": False}) + + provision._sqlite_post_configure_engine( + self.engine.url, self.engine, config.ident + ) + self.conn = self.engine.connect() self.metadata = MetaData() - def teardown(self): + def teardown_test(self): with self.conn.begin(): self.metadata.drop_all(self.conn) self.conn.close() + self.engine.dispose() def test_no_tables(self): insp = inspect(self.conn) @@ -1495,7 +1503,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): __skip_if__ = (full_text_search_missing,) @classmethod - def setup_class(cls): + def setup_test_class(cls): global metadata, cattable, matchtable metadata = MetaData() exec_sql( @@ -1559,7 +1567,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): ) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): metadata.drop_all(testing.db) def test_expression(self): @@ -1681,7 +1689,7 @@ class AutoIncrementTest(fixtures.TestBase, AssertsCompiledSQL): class ReflectHeadlessFKsTest(fixtures.TestBase): __only_on__ = "sqlite" - def setup(self): + def setup_test(self): exec_sql(testing.db, "CREATE TABLE a (id INTEGER PRIMARY KEY)") # this syntax actually works on other DBs perhaps we'd want to add # tests to test_reflection @@ -1689,7 +1697,7 @@ class ReflectHeadlessFKsTest(fixtures.TestBase): testing.db, "CREATE TABLE b (id INTEGER PRIMARY KEY REFERENCES a)" ) - def teardown(self): + def teardown_test(self): exec_sql(testing.db, "drop table b") exec_sql(testing.db, "drop table a") @@ -1728,7 +1736,7 @@ class ConstraintReflectionTest(fixtures.TestBase): __only_on__ = "sqlite" @classmethod - def setup_class(cls): + def setup_test_class(cls): with testing.db.begin() as conn: conn.exec_driver_sql("CREATE TABLE a1 (id INTEGER PRIMARY KEY)") @@ -1876,7 +1884,7 @@ class ConstraintReflectionTest(fixtures.TestBase): ) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): with testing.db.begin() as conn: for name in [ "implicit_referrer_comp_fake", @@ -2370,7 +2378,7 @@ class SavepointTest(fixtures.TablesTest): @classmethod def setup_bind(cls): - engine = engines.testing_engine(options={"use_reaper": False}) + engine = engines.testing_engine(options={"scope": "class"}) @event.listens_for(engine, "connect") def do_connect(dbapi_connection, connection_record): @@ -2579,7 +2587,7 @@ class TypeReflectionTest(fixtures.TestBase): class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "sqlite" - def setUp(self): + def setup_test(self): self.table = table( "mytable", column("myid", Integer), column("name", String) ) diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py index 396b48aa4..baa766d48 100644 --- a/test/engine/test_ddlevents.py +++ b/test/engine/test_ddlevents.py @@ -21,7 +21,7 @@ from sqlalchemy.testing.schema import Table class DDLEventTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.bind = engines.mock_engine() self.metadata = MetaData() self.table = Table("t", self.metadata, Column("id", Integer)) @@ -374,7 +374,7 @@ class DDLEventTest(fixtures.TestBase): class DDLExecutionTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.engine = engines.mock_engine() self.metadata = MetaData() self.users = Table( diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py index a18cf756b..0a2c9abe5 100644 --- a/test/engine/test_deprecations.py +++ b/test/engine/test_deprecations.py @@ -965,7 +965,7 @@ class TransactionTest(fixtures.TablesTest): class HandleInvalidatedOnConnectTest(fixtures.TestBase): __requires__ = ("sqlite",) - def setUp(self): + def setup_test(self): e = create_engine("sqlite://") connection = Mock(get_server_version_info=Mock(return_value="5.0")) @@ -1021,18 +1021,18 @@ def MockDBAPI(): # noqa class PoolTestBase(fixtures.TestBase): - def setup(self): + def setup_test(self): pool.clear_managers() self._teardown_conns = [] - def teardown(self): + def teardown_test(self): for ref in self._teardown_conns: conn = ref() if conn: conn.close() @classmethod - def teardown_class(cls): + def teardown_test_class(cls): pool.clear_managers() def _queuepool_fixture(self, **kw): @@ -1597,7 +1597,7 @@ class EngineEventsTest(fixtures.TestBase): __requires__ = ("ad_hoc_engines",) __backend__ = True - def tearDown(self): + def teardown_test(self): Engine.dispatch._clear() Engine._has_events = False @@ -1650,6 +1650,7 @@ class EngineEventsTest(fixtures.TestBase): event.listen( engine, "before_cursor_execute", cursor_execute, retval=True ) + with testing.expect_deprecated( r"The argument signature for the " r"\"ConnectionEvents.before_execute\" event listener", @@ -1676,11 +1677,12 @@ class EngineEventsTest(fixtures.TestBase): r"The argument signature for the " r"\"ConnectionEvents.after_execute\" event listener", ): - e1.execute(select(1)) + result = e1.execute(select(1)) + result.close() class DDLExecutionTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.engine = engines.mock_engine() self.metadata = MetaData() self.users = Table( diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 21d4e06e0..a1e4ea218 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -43,7 +43,6 @@ from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.assertsql import CompiledSQL -from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing.mock import call from sqlalchemy.testing.mock import Mock from sqlalchemy.testing.mock import patch @@ -94,13 +93,13 @@ class ExecuteTest(fixtures.TablesTest): ).default_from() ) - conn = testing.db.connect() - result = ( - conn.execution_options(no_parameters=True) - .exec_driver_sql(stmt) - .scalar() - ) - eq_(result, "%") + with testing.db.connect() as conn: + result = ( + conn.execution_options(no_parameters=True) + .exec_driver_sql(stmt) + .scalar() + ) + eq_(result, "%") def test_raw_positional_invalid(self, connection): assert_raises_message( @@ -261,16 +260,15 @@ class ExecuteTest(fixtures.TablesTest): (4, "sally"), ] - @testing.engines.close_open_connections def test_exception_wrapping_dbapi(self): - conn = testing.db.connect() - # engine does not have exec_driver_sql - assert_raises_message( - tsa.exc.DBAPIError, - r"not_a_valid_statement", - conn.exec_driver_sql, - "not_a_valid_statement", - ) + with testing.db.connect() as conn: + # engine does not have exec_driver_sql + assert_raises_message( + tsa.exc.DBAPIError, + r"not_a_valid_statement", + conn.exec_driver_sql, + "not_a_valid_statement", + ) @testing.requires.sqlite def test_exception_wrapping_non_dbapi_error(self): @@ -864,12 +862,10 @@ class CompiledCacheTest(fixtures.TestBase): ["sqlite", "mysql", "postgresql"], "uses blob value that is problematic for some DBAPIs", ) - @testing.provide_metadata - def test_cache_noleak_on_statement_values(self, connection): + def test_cache_noleak_on_statement_values(self, metadata, connection): # This is a non regression test for an object reference leak caused # by the compiled_cache. - metadata = self.metadata photo = Table( "photo", metadata, @@ -1040,7 +1036,19 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): __requires__ = ("schemas",) __backend__ = True - def test_create_table(self): + @testing.fixture + def plain_tables(self, metadata): + t1 = Table( + "t1", metadata, Column("x", Integer), schema=config.test_schema + ) + t2 = Table( + "t2", metadata, Column("x", Integer), schema=config.test_schema + ) + t3 = Table("t3", metadata, Column("x", Integer), schema=None) + + return t1, t2, t3 + + def test_create_table(self, plain_tables, connection): map_ = { None: config.test_schema, "foo": config.test_schema, @@ -1052,18 +1060,16 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): t2 = Table("t2", metadata, Column("x", Integer), schema="foo") t3 = Table("t3", metadata, Column("x", Integer), schema="bar") - with self.sql_execution_asserter(config.db) as asserter: - with config.db.begin() as conn, conn.execution_options( - schema_translate_map=map_ - ) as conn: + with self.sql_execution_asserter(connection) as asserter: + conn = connection.execution_options(schema_translate_map=map_) - t1.create(conn) - t2.create(conn) - t3.create(conn) + t1.create(conn) + t2.create(conn) + t3.create(conn) - t3.drop(conn) - t2.drop(conn) - t1.drop(conn) + t3.drop(conn) + t2.drop(conn) + t1.drop(conn) asserter.assert_( CompiledSQL("CREATE TABLE [SCHEMA__none].t1 (x INTEGER)"), @@ -1074,14 +1080,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): CompiledSQL("DROP TABLE [SCHEMA__none].t1"), ) - def _fixture(self): - metadata = self.metadata - Table("t1", metadata, Column("x", Integer), schema=config.test_schema) - Table("t2", metadata, Column("x", Integer), schema=config.test_schema) - Table("t3", metadata, Column("x", Integer), schema=None) - metadata.create_all(testing.db) - - def test_ddl_hastable(self): + def test_ddl_hastable(self, plain_tables, connection): map_ = { None: config.test_schema, @@ -1094,27 +1093,28 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): Table("t2", metadata, Column("x", Integer), schema="foo") Table("t3", metadata, Column("x", Integer), schema="bar") - with config.db.begin() as conn: - conn = conn.execution_options(schema_translate_map=map_) - metadata.create_all(conn) + conn = connection.execution_options(schema_translate_map=map_) + metadata.create_all(conn) - insp = inspect(config.db) + insp = inspect(connection) is_true(insp.has_table("t1", schema=config.test_schema)) is_true(insp.has_table("t2", schema=config.test_schema)) is_true(insp.has_table("t3", schema=None)) - with config.db.begin() as conn: - conn = conn.execution_options(schema_translate_map=map_) - metadata.drop_all(conn) + conn = connection.execution_options(schema_translate_map=map_) + + # if this test fails, the tables won't get dropped. so need a + # more robust fixture for this + metadata.drop_all(conn) - insp = inspect(config.db) + insp = inspect(connection) is_false(insp.has_table("t1", schema=config.test_schema)) is_false(insp.has_table("t2", schema=config.test_schema)) is_false(insp.has_table("t3", schema=None)) - @testing.provide_metadata - def test_option_on_execute(self): - self._fixture() + def test_option_on_execute(self, plain_tables, connection): + # provided by metadata fixture provided by plain_tables fixture + self.metadata.create_all(connection) map_ = { None: config.test_schema, @@ -1127,61 +1127,54 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): t2 = Table("t2", metadata, Column("x", Integer), schema="foo") t3 = Table("t3", metadata, Column("x", Integer), schema="bar") - with self.sql_execution_asserter(config.db) as asserter: - with config.db.begin() as conn: + with self.sql_execution_asserter(connection) as asserter: + conn = connection + execution_options = {"schema_translate_map": map_} + conn._execute_20( + t1.insert(), {"x": 1}, execution_options=execution_options + ) + conn._execute_20( + t2.insert(), {"x": 1}, execution_options=execution_options + ) + conn._execute_20( + t3.insert(), {"x": 1}, execution_options=execution_options + ) - execution_options = {"schema_translate_map": map_} - conn._execute_20( - t1.insert(), {"x": 1}, execution_options=execution_options - ) - conn._execute_20( - t2.insert(), {"x": 1}, execution_options=execution_options - ) - conn._execute_20( - t3.insert(), {"x": 1}, execution_options=execution_options - ) + conn._execute_20( + t1.update().values(x=1).where(t1.c.x == 1), + execution_options=execution_options, + ) + conn._execute_20( + t2.update().values(x=2).where(t2.c.x == 1), + execution_options=execution_options, + ) + conn._execute_20( + t3.update().values(x=3).where(t3.c.x == 1), + execution_options=execution_options, + ) + eq_( conn._execute_20( - t1.update().values(x=1).where(t1.c.x == 1), - execution_options=execution_options, - ) + select(t1.c.x), execution_options=execution_options + ).scalar(), + 1, + ) + eq_( conn._execute_20( - t2.update().values(x=2).where(t2.c.x == 1), - execution_options=execution_options, - ) + select(t2.c.x), execution_options=execution_options + ).scalar(), + 2, + ) + eq_( conn._execute_20( - t3.update().values(x=3).where(t3.c.x == 1), - execution_options=execution_options, - ) - - eq_( - conn._execute_20( - select(t1.c.x), execution_options=execution_options - ).scalar(), - 1, - ) - eq_( - conn._execute_20( - select(t2.c.x), execution_options=execution_options - ).scalar(), - 2, - ) - eq_( - conn._execute_20( - select(t3.c.x), execution_options=execution_options - ).scalar(), - 3, - ) + select(t3.c.x), execution_options=execution_options + ).scalar(), + 3, + ) - conn._execute_20( - t1.delete(), execution_options=execution_options - ) - conn._execute_20( - t2.delete(), execution_options=execution_options - ) - conn._execute_20( - t3.delete(), execution_options=execution_options - ) + conn._execute_20(t1.delete(), execution_options=execution_options) + conn._execute_20(t2.delete(), execution_options=execution_options) + conn._execute_20(t3.delete(), execution_options=execution_options) asserter.assert_( CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"), @@ -1207,9 +1200,9 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): CompiledSQL("DELETE FROM [SCHEMA_bar].t3"), ) - @testing.provide_metadata - def test_crud(self): - self._fixture() + def test_crud(self, plain_tables, connection): + # provided by metadata fixture provided by plain_tables fixture + self.metadata.create_all(connection) map_ = { None: config.test_schema, @@ -1222,26 +1215,24 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): t2 = Table("t2", metadata, Column("x", Integer), schema="foo") t3 = Table("t3", metadata, Column("x", Integer), schema="bar") - with self.sql_execution_asserter(config.db) as asserter: - with config.db.begin() as conn, conn.execution_options( - schema_translate_map=map_ - ) as conn: + with self.sql_execution_asserter(connection) as asserter: + conn = connection.execution_options(schema_translate_map=map_) - conn.execute(t1.insert(), {"x": 1}) - conn.execute(t2.insert(), {"x": 1}) - conn.execute(t3.insert(), {"x": 1}) + conn.execute(t1.insert(), {"x": 1}) + conn.execute(t2.insert(), {"x": 1}) + conn.execute(t3.insert(), {"x": 1}) - conn.execute(t1.update().values(x=1).where(t1.c.x == 1)) - conn.execute(t2.update().values(x=2).where(t2.c.x == 1)) - conn.execute(t3.update().values(x=3).where(t3.c.x == 1)) + conn.execute(t1.update().values(x=1).where(t1.c.x == 1)) + conn.execute(t2.update().values(x=2).where(t2.c.x == 1)) + conn.execute(t3.update().values(x=3).where(t3.c.x == 1)) - eq_(conn.scalar(select(t1.c.x)), 1) - eq_(conn.scalar(select(t2.c.x)), 2) - eq_(conn.scalar(select(t3.c.x)), 3) + eq_(conn.scalar(select(t1.c.x)), 1) + eq_(conn.scalar(select(t2.c.x)), 2) + eq_(conn.scalar(select(t3.c.x)), 3) - conn.execute(t1.delete()) - conn.execute(t2.delete()) - conn.execute(t3.delete()) + conn.execute(t1.delete()) + conn.execute(t2.delete()) + conn.execute(t3.delete()) asserter.assert_( CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"), @@ -1267,9 +1258,10 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): CompiledSQL("DELETE FROM [SCHEMA_bar].t3"), ) - @testing.provide_metadata - def test_via_engine(self): - self._fixture() + def test_via_engine(self, plain_tables, metadata): + + with config.db.begin() as connection: + metadata.create_all(connection) map_ = { None: config.test_schema, @@ -1282,25 +1274,25 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): with self.sql_execution_asserter(config.db) as asserter: eng = config.db.execution_options(schema_translate_map=map_) - conn = eng.connect() - conn.execute(select(t2.c.x)) + with eng.connect() as conn: + conn.execute(select(t2.c.x)) asserter.assert_( CompiledSQL("SELECT [SCHEMA_foo].t2.x FROM [SCHEMA_foo].t2") ) class ExecutionOptionsTest(fixtures.TestBase): - def test_dialect_conn_options(self): + def test_dialect_conn_options(self, testing_engine): engine = testing_engine("sqlite://", options=dict(_initialize=False)) engine.dialect = Mock() - conn = engine.connect() - c2 = conn.execution_options(foo="bar") - eq_( - engine.dialect.set_connection_execution_options.mock_calls, - [call(c2, {"foo": "bar"})], - ) + with engine.connect() as conn: + c2 = conn.execution_options(foo="bar") + eq_( + engine.dialect.set_connection_execution_options.mock_calls, + [call(c2, {"foo": "bar"})], + ) - def test_dialect_engine_options(self): + def test_dialect_engine_options(self, testing_engine): engine = testing_engine("sqlite://") engine.dialect = Mock() e2 = engine.execution_options(foo="bar") @@ -1319,14 +1311,14 @@ class ExecutionOptionsTest(fixtures.TestBase): [call(engine, {"foo": "bar"})], ) - def test_propagate_engine_to_connection(self): + def test_propagate_engine_to_connection(self, testing_engine): engine = testing_engine( "sqlite://", options=dict(execution_options={"foo": "bar"}) ) - conn = engine.connect() - eq_(conn._execution_options, {"foo": "bar"}) + with engine.connect() as conn: + eq_(conn._execution_options, {"foo": "bar"}) - def test_propagate_option_engine_to_connection(self): + def test_propagate_option_engine_to_connection(self, testing_engine): e1 = testing_engine( "sqlite://", options=dict(execution_options={"foo": "bar"}) ) @@ -1336,27 +1328,30 @@ class ExecutionOptionsTest(fixtures.TestBase): eq_(c1._execution_options, {"foo": "bar"}) eq_(c2._execution_options, {"foo": "bar", "bat": "hoho"}) - def test_get_engine_execution_options(self): + c1.close() + c2.close() + + def test_get_engine_execution_options(self, testing_engine): engine = testing_engine("sqlite://") engine.dialect = Mock() e2 = engine.execution_options(foo="bar") eq_(e2.get_execution_options(), {"foo": "bar"}) - def test_get_connection_execution_options(self): + def test_get_connection_execution_options(self, testing_engine): engine = testing_engine("sqlite://", options=dict(_initialize=False)) engine.dialect = Mock() - conn = engine.connect() - c = conn.execution_options(foo="bar") + with engine.connect() as conn: + c = conn.execution_options(foo="bar") - eq_(c.get_execution_options(), {"foo": "bar"}) + eq_(c.get_execution_options(), {"foo": "bar"}) class EngineEventsTest(fixtures.TestBase): __requires__ = ("ad_hoc_engines",) __backend__ = True - def tearDown(self): + def teardown_test(self): Engine.dispatch._clear() Engine._has_events = False @@ -1376,7 +1371,7 @@ class EngineEventsTest(fixtures.TestBase): ): break - def test_per_engine_independence(self): + def test_per_engine_independence(self, testing_engine): e1 = testing_engine(config.db_url) e2 = testing_engine(config.db_url) @@ -1400,7 +1395,7 @@ class EngineEventsTest(fixtures.TestBase): conn.execute(s2) eq_([arg[1][1] for arg in canary.mock_calls], [s1, s1, s2]) - def test_per_engine_plus_global(self): + def test_per_engine_plus_global(self, testing_engine): canary = Mock() event.listen(Engine, "before_execute", canary.be1) e1 = testing_engine(config.db_url) @@ -1409,8 +1404,6 @@ class EngineEventsTest(fixtures.TestBase): event.listen(e1, "before_execute", canary.be2) event.listen(Engine, "before_execute", canary.be3) - e1.connect() - e2.connect() with e1.connect() as conn: conn.execute(select(1)) @@ -1424,7 +1417,7 @@ class EngineEventsTest(fixtures.TestBase): eq_(canary.be2.call_count, 1) eq_(canary.be3.call_count, 2) - def test_per_connection_plus_engine(self): + def test_per_connection_plus_engine(self, testing_engine): canary = Mock() e1 = testing_engine(config.db_url) @@ -1442,9 +1435,14 @@ class EngineEventsTest(fixtures.TestBase): eq_(canary.be1.call_count, 2) eq_(canary.be2.call_count, 2) - @testing.combinations((True, False), (True, True), (False, False)) + @testing.combinations( + (True, False), + (True, True), + (False, False), + argnames="mock_out_on_connect, add_our_own_onconnect", + ) def test_insert_connect_is_definitely_first( - self, mock_out_on_connect, add_our_own_onconnect + self, mock_out_on_connect, add_our_own_onconnect, testing_engine ): """test issue #5708. @@ -1478,7 +1476,7 @@ class EngineEventsTest(fixtures.TestBase): patcher = util.nullcontext() with patcher: - e1 = create_engine(config.db_url) + e1 = testing_engine(config.db_url) initialize = e1.dialect.initialize @@ -1559,10 +1557,11 @@ class EngineEventsTest(fixtures.TestBase): conn.exec_driver_sql(select1(testing.db)) eq_(m1.mock_calls, []) - def test_add_event_after_connect(self): + def test_add_event_after_connect(self, testing_engine): # new feature as of #2978 + canary = Mock() - e1 = create_engine(config.db_url) + e1 = testing_engine(config.db_url, future=False) assert not e1._has_events conn = e1.connect() @@ -1575,9 +1574,9 @@ class EngineEventsTest(fixtures.TestBase): conn._branch().execute(select(1)) eq_(canary.be1.call_count, 2) - def test_force_conn_events_false(self): + def test_force_conn_events_false(self, testing_engine): canary = Mock() - e1 = create_engine(config.db_url) + e1 = testing_engine(config.db_url, future=False) assert not e1._has_events event.listen(e1, "before_execute", canary.be1) @@ -1593,7 +1592,7 @@ class EngineEventsTest(fixtures.TestBase): conn._branch().execute(select(1)) eq_(canary.be1.call_count, 0) - def test_cursor_events_ctx_execute_scalar(self): + def test_cursor_events_ctx_execute_scalar(self, testing_engine): canary = Mock() e1 = testing_engine(config.db_url) @@ -1620,7 +1619,7 @@ class EngineEventsTest(fixtures.TestBase): [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)], ) - def test_cursor_events_execute(self): + def test_cursor_events_execute(self, testing_engine): canary = Mock() e1 = testing_engine(config.db_url) @@ -1653,9 +1652,15 @@ class EngineEventsTest(fixtures.TestBase): ), ((), {"z": 10}, [], {"z": 10}, testing.requires.legacy_engine), (({"z": 10},), {}, [], {"z": 10}), + argnames="multiparams, params, expected_multiparams, expected_params", ) def test_modify_parameters_from_event_one( - self, multiparams, params, expected_multiparams, expected_params + self, + multiparams, + params, + expected_multiparams, + expected_params, + testing_engine, ): # this is testing both the normalization added to parameters # as of I97cb4d06adfcc6b889f10d01cc7775925cffb116 as well as @@ -1704,7 +1709,9 @@ class EngineEventsTest(fixtures.TestBase): [(15,), (19,)], ) - def test_modify_parameters_from_event_three(self, connection): + def test_modify_parameters_from_event_three( + self, connection, testing_engine + ): def before_execute( conn, clauseelement, multiparams, params, execution_options ): @@ -1721,7 +1728,7 @@ class EngineEventsTest(fixtures.TestBase): with e1.connect() as conn: conn.execute(select(literal("1"))) - def test_argument_format_execute(self): + def test_argument_format_execute(self, testing_engine): def before_execute( conn, clauseelement, multiparams, params, execution_options ): @@ -1956,9 +1963,9 @@ class EngineEventsTest(fixtures.TestBase): ) @testing.requires.ad_hoc_engines - def test_dispose_event(self): + def test_dispose_event(self, testing_engine): canary = Mock() - eng = create_engine(testing.db.url) + eng = testing_engine(testing.db.url) event.listen(eng, "engine_disposed", canary) conn = eng.connect() @@ -2102,13 +2109,13 @@ class EngineEventsTest(fixtures.TestBase): event.listen(engine, "commit", tracker("commit")) event.listen(engine, "rollback", tracker("rollback")) - conn = engine.connect() - trans = conn.begin() - conn.execute(select(1)) - trans.rollback() - trans = conn.begin() - conn.execute(select(1)) - trans.commit() + with engine.connect() as conn: + trans = conn.begin() + conn.execute(select(1)) + trans.rollback() + trans = conn.begin() + conn.execute(select(1)) + trans.commit() eq_( canary, @@ -2145,13 +2152,13 @@ class EngineEventsTest(fixtures.TestBase): event.listen(engine, "commit", tracker("commit"), named=True) event.listen(engine, "rollback", tracker("rollback"), named=True) - conn = engine.connect() - trans = conn.begin() - conn.execute(select(1)) - trans.rollback() - trans = conn.begin() - conn.execute(select(1)) - trans.commit() + with engine.connect() as conn: + trans = conn.begin() + conn.execute(select(1)) + trans.rollback() + trans = conn.begin() + conn.execute(select(1)) + trans.commit() eq_( canary, @@ -2310,7 +2317,7 @@ class HandleErrorTest(fixtures.TestBase): __requires__ = ("ad_hoc_engines",) __backend__ = True - def tearDown(self): + def teardown_test(self): Engine.dispatch._clear() Engine._has_events = False @@ -2742,7 +2749,7 @@ class HandleErrorTest(fixtures.TestBase): class HandleInvalidatedOnConnectTest(fixtures.TestBase): __requires__ = ("sqlite",) - def setUp(self): + def setup_test(self): e = create_engine("sqlite://") connection = Mock(get_server_version_info=Mock(return_value="5.0")) @@ -3014,6 +3021,9 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): ], ) + c.close() + c2.close() + class DialectEventTest(fixtures.TestBase): @contextmanager @@ -3370,7 +3380,7 @@ class SetInputSizesTest(fixtures.TablesTest): ) @testing.fixture - def input_sizes_fixture(self): + def input_sizes_fixture(self, testing_engine): canary = mock.Mock() def do_set_input_sizes(cursor, list_of_tuples, context): diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py index 29b8132aa..c56589248 100644 --- a/test/engine/test_logging.py +++ b/test/engine/test_logging.py @@ -30,7 +30,7 @@ class LogParamsTest(fixtures.TestBase): __only_on__ = "sqlite" __requires__ = ("ad_hoc_engines",) - def setup(self): + def setup_test(self): self.eng = engines.testing_engine(options={"echo": True}) self.no_param_engine = engines.testing_engine( options={"echo": True, "hide_parameters": True} @@ -44,7 +44,7 @@ class LogParamsTest(fixtures.TestBase): for log in [logging.getLogger("sqlalchemy.engine")]: log.addHandler(self.buf) - def teardown(self): + def teardown_test(self): exec_sql(self.eng, "drop table if exists foo") for log in [logging.getLogger("sqlalchemy.engine")]: log.removeHandler(self.buf) @@ -413,14 +413,14 @@ class LogParamsTest(fixtures.TestBase): class PoolLoggingTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.existing_level = logging.getLogger("sqlalchemy.pool").level self.buf = logging.handlers.BufferingHandler(100) for log in [logging.getLogger("sqlalchemy.pool")]: log.addHandler(self.buf) - def teardown(self): + def teardown_test(self): for log in [logging.getLogger("sqlalchemy.pool")]: log.removeHandler(self.buf) logging.getLogger("sqlalchemy.pool").setLevel(self.existing_level) @@ -528,7 +528,7 @@ class LoggingNameTest(fixtures.TestBase): kw.update({"echo": True}) return engines.testing_engine(options=kw) - def setup(self): + def setup_test(self): self.buf = logging.handlers.BufferingHandler(100) for log in [ logging.getLogger("sqlalchemy.engine"), @@ -536,7 +536,7 @@ class LoggingNameTest(fixtures.TestBase): ]: log.addHandler(self.buf) - def teardown(self): + def teardown_test(self): for log in [ logging.getLogger("sqlalchemy.engine"), logging.getLogger("sqlalchemy.pool"), @@ -588,13 +588,13 @@ class LoggingNameTest(fixtures.TestBase): class EchoTest(fixtures.TestBase): __requires__ = ("ad_hoc_engines",) - def setup(self): + def setup_test(self): self.level = logging.getLogger("sqlalchemy.engine").level logging.getLogger("sqlalchemy.engine").setLevel(logging.WARN) self.buf = logging.handlers.BufferingHandler(100) logging.getLogger("sqlalchemy.engine").addHandler(self.buf) - def teardown(self): + def teardown_test(self): logging.getLogger("sqlalchemy.engine").removeHandler(self.buf) logging.getLogger("sqlalchemy.engine").setLevel(self.level) diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 550fedb8e..decdce3f9 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -17,7 +17,9 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_none from sqlalchemy.testing import is_not +from sqlalchemy.testing import is_not_none from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.engines import testing_engine @@ -63,18 +65,18 @@ def MockDBAPI(): # noqa class PoolTestBase(fixtures.TestBase): - def setup(self): + def setup_test(self): pool.clear_managers() self._teardown_conns = [] - def teardown(self): + def teardown_test(self): for ref in self._teardown_conns: conn = ref() if conn: conn.close() @classmethod - def teardown_class(cls): + def teardown_test_class(cls): pool.clear_managers() def _with_teardown(self, connection): @@ -364,10 +366,17 @@ class PoolEventsTest(PoolTestBase): p = self._queuepool_fixture() canary = [] + @event.listens_for(p, "checkin") def checkin(*arg, **kw): canary.append("checkin") - event.listen(p, "checkin", checkin) + @event.listens_for(p, "close_detached") + def close_detached(*arg, **kw): + canary.append("close_detached") + + @event.listens_for(p, "detach") + def detach(*arg, **kw): + canary.append("detach") return p, canary @@ -629,15 +638,35 @@ class PoolEventsTest(PoolTestBase): assert canary.call_args_list[0][0][0] is dbapi_con assert canary.call_args_list[0][0][2] is exc + @testing.combinations((True, testing.requires.python3), (False,)) @testing.requires.predictable_gc - def test_checkin_event_gc(self): + def test_checkin_event_gc(self, detach_gced): p, canary = self._checkin_event_fixture() + if detach_gced: + p._is_asyncio = True + c1 = p.connect() + + dbapi_connection = weakref.ref(c1.connection) + eq_(canary, []) del c1 lazy_gc() - eq_(canary, ["checkin"]) + + if detach_gced: + # "close_detached" is not called because for asyncio the + # connection is just lost. + eq_(canary, ["detach"]) + + else: + eq_(canary, ["checkin"]) + + gc_collect() + if detach_gced: + is_none(dbapi_connection()) + else: + is_not_none(dbapi_connection()) def test_checkin_event_on_subsequently_recreated(self): p, canary = self._checkin_event_fixture() @@ -744,7 +773,7 @@ class PoolEventsTest(PoolTestBase): eq_(conn.info["important_flag"], True) conn.close() - def teardown(self): + def teardown_test(self): # TODO: need to get remove() functionality # going pool.Pool.dispatch._clear() @@ -1490,12 +1519,16 @@ class QueuePoolTest(PoolTestBase): self._assert_cleanup_on_pooled_reconnect(dbapi, p) + @testing.combinations((True, testing.requires.python3), (False,)) @testing.requires.predictable_gc - def test_userspace_disconnectionerror_weakref_finalizer(self): + def test_userspace_disconnectionerror_weakref_finalizer(self, detach_gced): dbapi, pool = self._queuepool_dbapi_fixture( pool_size=1, max_overflow=2 ) + if detach_gced: + pool._is_asyncio = True + @event.listens_for(pool, "checkout") def handle_checkout_event(dbapi_con, con_record, con_proxy): if getattr(dbapi_con, "boom") == "yes": @@ -1514,8 +1547,12 @@ class QueuePoolTest(PoolTestBase): del conn gc_collect() - # new connection was reset on return appropriately - eq_(dbapi_conn.mock_calls, [call.rollback()]) + if detach_gced: + # new connection was detached + abandoned on return + eq_(dbapi_conn.mock_calls, []) + else: + # new connection reset and returned to pool + eq_(dbapi_conn.mock_calls, [call.rollback()]) # old connection was just closed - did not get an # erroneous reset on return diff --git a/test/engine/test_processors.py b/test/engine/test_processors.py index 3810de06a..5a4220c82 100644 --- a/test/engine/test_processors.py +++ b/test/engine/test_processors.py @@ -25,7 +25,7 @@ class CBooleanProcessorTest(_BooleanProcessorTest): __requires__ = ("cextensions",) @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy import cprocessors cls.module = cprocessors @@ -83,7 +83,7 @@ class _DateProcessorTest(fixtures.TestBase): class PyDateProcessorTest(_DateProcessorTest): @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy import processors cls.module = type( @@ -100,7 +100,7 @@ class CDateProcessorTest(_DateProcessorTest): __requires__ = ("cextensions",) @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy import cprocessors cls.module = cprocessors @@ -185,7 +185,7 @@ class _DistillArgsTest(fixtures.TestBase): class PyDistillArgsTest(_DistillArgsTest): @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy.engine import util cls.module = type( @@ -202,7 +202,7 @@ class CDistillArgsTest(_DistillArgsTest): __requires__ = ("cextensions",) @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy import cutils as util cls.module = util diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index 5fe7f6cc2..7a64b2550 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -162,7 +162,7 @@ def MockDBAPI(): class PrePingMockTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.dbapi = MockDBAPI() def _pool_fixture(self, pre_ping, pool_kw=None): @@ -182,7 +182,7 @@ class PrePingMockTest(fixtures.TestBase): ) return _pool - def teardown(self): + def teardown_test(self): self.dbapi.dispose() def test_ping_not_on_first_connect(self): @@ -357,7 +357,7 @@ class PrePingMockTest(fixtures.TestBase): class MockReconnectTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.dbapi = MockDBAPI() self.db = testing_engine( @@ -373,7 +373,7 @@ class MockReconnectTest(fixtures.TestBase): e, MockDisconnect ) - def teardown(self): + def teardown_test(self): self.dbapi.dispose() def test_reconnect(self): @@ -1004,10 +1004,10 @@ class RealReconnectTest(fixtures.TestBase): __backend__ = True __requires__ = "graceful_disconnects", "ad_hoc_engines" - def setup(self): + def setup_test(self): self.engine = engines.reconnecting_engine() - def teardown(self): + def teardown_test(self): self.engine.dispose() def test_reconnect(self): @@ -1336,7 +1336,7 @@ class PrePingRealTest(fixtures.TestBase): class InvalidateDuringResultTest(fixtures.TestBase): __backend__ = True - def setup(self): + def setup_test(self): self.engine = engines.reconnecting_engine() self.meta = MetaData() table = Table( @@ -1353,7 +1353,7 @@ class InvalidateDuringResultTest(fixtures.TestBase): [{"id": i, "name": "row %d" % i} for i in range(1, 100)], ) - def teardown(self): + def teardown_test(self): with self.engine.begin() as conn: self.meta.drop_all(conn) self.engine.dispose() @@ -1470,7 +1470,7 @@ class ReconnectRecipeTest(fixtures.TestBase): __backend__ = True - def setup(self): + def setup_test(self): self.engine = engines.reconnecting_engine( options=dict(future=self.future) ) @@ -1483,7 +1483,7 @@ class ReconnectRecipeTest(fixtures.TestBase): ) self.meta.create_all(self.engine) - def teardown(self): + def teardown_test(self): self.meta.drop_all(self.engine) self.engine.dispose() diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 658cdd79f..0a46ddeec 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -796,7 +796,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert f1 in b1.constraints assert len(b1.constraints) == 2 - def test_override_keys(self, connection, metadata): + def test_override_keys(self, metadata, connection): """test that columns can be overridden with a 'key', and that ForeignKey targeting during reflection still works.""" @@ -1375,7 +1375,7 @@ class CreateDropTest(fixtures.TablesTest): run_create_tables = None @classmethod - def teardown_class(cls): + def teardown_test_class(cls): # TablesTest is used here without # run_create_tables, so add an explicit drop of whatever is in # metadata @@ -1658,7 +1658,6 @@ class SchemaTest(fixtures.TestBase): @testing.requires.schemas @testing.requires.cross_schema_fk_reflection @testing.requires.implicit_default_schema - @testing.provide_metadata def test_blank_schema_arg(self, connection, metadata): Table( @@ -1913,7 +1912,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL): __backend__ = True @testing.requires.denormalized_names - def setup(self): + def setup_test(self): with testing.db.begin() as conn: conn.exec_driver_sql( """ @@ -1926,7 +1925,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL): ) @testing.requires.denormalized_names - def teardown(self): + def teardown_test(self): with testing.db.begin() as conn: conn.exec_driver_sql("drop table weird_casing") diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index 79126fc5b..47504b60a 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -1,6 +1,5 @@ import sys -from sqlalchemy import create_engine from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import func @@ -640,12 +639,12 @@ class AutoRollbackTest(fixtures.TestBase): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): global metadata metadata = MetaData() @classmethod - def teardown_class(cls): + def teardown_test_class(cls): metadata.drop_all(testing.db) def test_rollback_deadlock(self): @@ -871,11 +870,13 @@ class IsolationLevelTest(fixtures.TestBase): def test_per_engine(self): # new in 0.9 - eng = create_engine( + eng = testing_engine( testing.db.url, - execution_options={ - "isolation_level": self._non_default_isolation_level() - }, + options=dict( + execution_options={ + "isolation_level": self._non_default_isolation_level() + } + ), ) conn = eng.connect() eq_( @@ -884,7 +885,7 @@ class IsolationLevelTest(fixtures.TestBase): ) def test_per_option_engine(self): - eng = create_engine(testing.db.url).execution_options( + eng = testing_engine(testing.db.url).execution_options( isolation_level=self._non_default_isolation_level() ) @@ -895,14 +896,14 @@ class IsolationLevelTest(fixtures.TestBase): ) def test_isolation_level_accessors_connection_default(self): - eng = create_engine(testing.db.url) + eng = testing_engine(testing.db.url) with eng.connect() as conn: eq_(conn.default_isolation_level, self._default_isolation_level()) with eng.connect() as conn: eq_(conn.get_isolation_level(), self._default_isolation_level()) def test_isolation_level_accessors_connection_option_modified(self): - eng = create_engine(testing.db.url) + eng = testing_engine(testing.db.url) with eng.connect() as conn: c2 = conn.execution_options( isolation_level=self._non_default_isolation_level() diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 7dae1411e..59a44f8e2 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -269,7 +269,7 @@ class AsyncEngineTest(EngineFixture): await trans.rollback(), @async_test - async def test_pool_exhausted(self, async_engine): + async def test_pool_exhausted_some_timeout(self, async_engine): engine = create_async_engine( testing.db.url, pool_size=1, @@ -277,7 +277,19 @@ class AsyncEngineTest(EngineFixture): pool_timeout=0.1, ) async with engine.connect(): - with expect_raises(asyncio.TimeoutError): + with expect_raises(exc.TimeoutError): + await engine.connect() + + @async_test + async def test_pool_exhausted_no_timeout(self, async_engine): + engine = create_async_engine( + testing.db.url, + pool_size=1, + max_overflow=0, + pool_timeout=0, + ) + async with engine.connect(): + with expect_raises(exc.TimeoutError): await engine.connect() @async_test diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py index 2b80b753e..e25e7cfc2 100644 --- a/test/ext/declarative/test_inheritance.py +++ b/test/ext/declarative/test_inheritance.py @@ -27,11 +27,11 @@ Base = None class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): - def setup(self): + def setup_test(self): global Base Base = decl.declarative_base(testing.db) - def teardown(self): + def teardown_test(self): close_all_sessions() clear_mappers() Base.metadata.drop_all(testing.db) diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py index d7fcbf9e8..c327de7d4 100644 --- a/test/ext/declarative/test_reflection.py +++ b/test/ext/declarative/test_reflection.py @@ -4,7 +4,6 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.ext.declarative import DeferredReflection from sqlalchemy.orm import clear_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import decl_api as decl from sqlalchemy.orm import declared_attr from sqlalchemy.orm import exc as orm_exc @@ -14,6 +13,7 @@ from sqlalchemy.orm.decl_base import _DeferredMapperConfig from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import gc_collect @@ -22,20 +22,19 @@ from sqlalchemy.testing.util import gc_collect class DeclarativeReflectionBase(fixtures.TablesTest): __requires__ = ("reflectable_autoincrement",) - def setup(self): + def setup_test(self): global Base, registry registry = decl.registry() Base = registry.generate_base() - def teardown(self): - super(DeclarativeReflectionBase, self).teardown() + def teardown_test(self): clear_mappers() class DeferredReflectBase(DeclarativeReflectionBase): - def teardown(self): - super(DeferredReflectBase, self).teardown() + def teardown_test(self): + super(DeferredReflectBase, self).teardown_test() _DeferredMapperConfig._configs.clear() @@ -101,22 +100,23 @@ class DeferredReflectionTest(DeferredReflectBase): u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session(testing.db) - sess.add(u1) - sess.flush() - sess.expunge_all() - eq_( - sess.query(User).all(), - [ - User( - name="u1", - addresses=[Address(email="one"), Address(email="two")], - ) - ], - ) - a1 = sess.query(Address).filter(Address.email == "two").one() - eq_(a1, Address(email="two")) - eq_(a1.user, User(name="u1")) + with fixture_session() as sess: + sess.add(u1) + sess.commit() + + with fixture_session() as sess: + eq_( + sess.query(User).all(), + [ + User( + name="u1", + addresses=[Address(email="one"), Address(email="two")], + ) + ], + ) + a1 = sess.query(Address).filter(Address.email == "two").one() + eq_(a1, Address(email="two")) + eq_(a1.user, User(name="u1")) def test_exception_prepare_not_called(self): class User(DeferredReflection, fixtures.ComparableEntity, Base): @@ -191,15 +191,25 @@ class DeferredReflectionTest(DeferredReflectBase): return {"primary_key": cls.__table__.c.id} DeferredReflection.prepare(testing.db) - sess = Session(testing.db) - sess.add_all( - [User(name="G"), User(name="Q"), User(name="A"), User(name="C")] - ) - sess.commit() - eq_( - sess.query(User).order_by(User.name).all(), - [User(name="A"), User(name="C"), User(name="G"), User(name="Q")], - ) + with fixture_session() as sess: + sess.add_all( + [ + User(name="G"), + User(name="Q"), + User(name="A"), + User(name="C"), + ] + ) + sess.commit() + eq_( + sess.query(User).order_by(User.name).all(), + [ + User(name="A"), + User(name="C"), + User(name="G"), + User(name="Q"), + ], + ) @testing.requires.predictable_gc def test_cls_not_strong_ref(self): @@ -255,14 +265,14 @@ class DeferredSecondaryReflectionTest(DeferredReflectBase): u1 = User(name="u1", items=[Item(name="i1"), Item(name="i2")]) - sess = Session(testing.db) - sess.add(u1) - sess.commit() + with fixture_session() as sess: + sess.add(u1) + sess.commit() - eq_( - sess.query(User).all(), - [User(name="u1", items=[Item(name="i1"), Item(name="i2")])], - ) + eq_( + sess.query(User).all(), + [User(name="u1", items=[Item(name="i1"), Item(name="i2")])], + ) def test_string_resolution(self): class User(DeferredReflection, fixtures.ComparableEntity, Base): @@ -296,27 +306,26 @@ class DeferredInhReflectBase(DeferredReflectBase): Foo = Base.registry._class_registry["Foo"] Bar = Base.registry._class_registry["Bar"] - s = Session(testing.db) - - s.add_all( - [ - Bar(data="d1", bar_data="b1"), - Bar(data="d2", bar_data="b2"), - Bar(data="d3", bar_data="b3"), - Foo(data="d4"), - ] - ) - s.commit() - - eq_( - s.query(Foo).order_by(Foo.id).all(), - [ - Bar(data="d1", bar_data="b1"), - Bar(data="d2", bar_data="b2"), - Bar(data="d3", bar_data="b3"), - Foo(data="d4"), - ], - ) + with fixture_session() as s: + s.add_all( + [ + Bar(data="d1", bar_data="b1"), + Bar(data="d2", bar_data="b2"), + Bar(data="d3", bar_data="b3"), + Foo(data="d4"), + ] + ) + s.commit() + + eq_( + s.query(Foo).order_by(Foo.id).all(), + [ + Bar(data="d1", bar_data="b1"), + Bar(data="d2", bar_data="b2"), + Bar(data="d3", bar_data="b3"), + Foo(data="d4"), + ], + ) class DeferredSingleInhReflectionTest(DeferredInhReflectBase): diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index b1f5cc956..31ae050c1 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -101,7 +101,7 @@ class AutoFlushTest(fixtures.TablesTest): Column("name", String(50)), ) - def teardown(self): + def teardown_test(self): clear_mappers() def _fixture(self, collection_class, is_dict=False): @@ -198,7 +198,7 @@ class AutoFlushTest(fixtures.TablesTest): class _CollectionOperations(fixtures.TestBase): - def setup(self): + def setup_test(self): collection_class = self.collection_class metadata = MetaData() @@ -260,7 +260,7 @@ class _CollectionOperations(fixtures.TestBase): self.session = fixture_session() self.Parent, self.Child = Parent, Child - def teardown(self): + def teardown_test(self): self.metadata.drop_all(testing.db) def roundtrip(self, obj): @@ -885,7 +885,7 @@ class CustomObjectTest(_CollectionOperations): class ProxyFactoryTest(ListTest): - def setup(self): + def setup_test(self): metadata = MetaData() parents_table = Table( @@ -1157,7 +1157,7 @@ class ScalarTest(fixtures.TestBase): class LazyLoadTest(fixtures.TestBase): - def setup(self): + def setup_test(self): metadata = MetaData() parents_table = Table( @@ -1197,7 +1197,7 @@ class LazyLoadTest(fixtures.TestBase): self.Parent, self.Child = Parent, Child self.table = parents_table - def teardown(self): + def teardown_test(self): self.metadata.drop_all(testing.db) def roundtrip(self, obj): @@ -2294,7 +2294,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): class DictOfTupleUpdateTest(fixtures.TestBase): - def setup(self): + def setup_test(self): class B(object): def __init__(self, key, elem): self.key = key @@ -2434,7 +2434,7 @@ class CompositeAccessTest(fixtures.DeclarativeMappedTest): class AttributeAccessTest(fixtures.TestBase): - def teardown(self): + def teardown_test(self): clear_mappers() def test_resolve_aliased_class(self): diff --git a/test/ext/test_baked.py b/test/ext/test_baked.py index 71fabc629..2d4e9848e 100644 --- a/test/ext/test_baked.py +++ b/test/ext/test_baked.py @@ -27,7 +27,7 @@ class BakedTest(_fixtures.FixtureTest): run_inserts = "once" run_deletes = None - def setup(self): + def setup_test(self): self.bakery = baked.bakery() diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index 058c1dfd7..d011417d7 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -426,7 +426,7 @@ class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" - def teardown(self): + def teardown_test(self): for cls in (Select, BindParameter): deregister(cls) diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py index ad9bf0bc0..f3eceb0dc 100644 --- a/test/ext/test_extendedattr.py +++ b/test/ext/test_extendedattr.py @@ -32,7 +32,7 @@ def modifies_instrumentation_finders(fn, *args, **kw): class _ExtBase(object): @classmethod - def teardown_class(cls): + def teardown_test_class(cls): instrumentation._reinstall_default_lookups() @@ -89,7 +89,7 @@ MyBaseClass, MyClass = None, None class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): @classmethod - def setup_class(cls): + def setup_test_class(cls): global MyBaseClass, MyClass class MyBaseClass(object): @@ -143,7 +143,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): else: del self._goofy_dict[key] - def teardown(self): + def teardown_test(self): clear_mappers() def test_instance_dict(self): diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 038bdd83e..bb06d9648 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -19,7 +19,6 @@ from sqlalchemy import update from sqlalchemy import util from sqlalchemy.ext.horizontal_shard import ShardedSession from sqlalchemy.orm import clear_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import deferred from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship @@ -42,7 +41,7 @@ class ShardTest(object): schema = None - def setUp(self): + def setup_test(self): global db1, db2, db3, db4, weather_locations, weather_reports db1, db2, db3, db4 = self._dbs = self._init_dbs() @@ -88,7 +87,7 @@ class ShardTest(object): @classmethod def setup_session(cls): - global create_session + global sharded_session shard_lookup = { "North America": "north_america", "Asia": "asia", @@ -128,10 +127,10 @@ class ShardTest(object): else: return ids - create_session = sessionmaker( + sharded_session = sessionmaker( class_=ShardedSession, autoflush=True, autocommit=False ) - create_session.configure( + sharded_session.configure( shards={ "north_america": db1, "asia": db2, @@ -180,7 +179,7 @@ class ShardTest(object): tokyo.reports.append(Report(80.0, id_=1)) newyork.reports.append(Report(75, id_=1)) quito.reports.append(Report(85)) - sess = create_session(future=True) + sess = sharded_session(future=True) for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]: sess.add(c) sess.flush() @@ -671,11 +670,10 @@ class DistinctEngineShardTest(ShardTest, fixtures.TestBase): self.dbs = [db1, db2, db3, db4] return self.dbs - def teardown(self): + def teardown_test(self): clear_mappers() - for db in self.dbs: - db.connect().invalidate() + testing_reaper.checkin_all() for i in range(1, 5): os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT)) @@ -702,10 +700,10 @@ class AttachedFileShardTest(ShardTest, fixtures.TestBase): self.engine = e return db1, db2, db3, db4 - def teardown(self): + def teardown_test(self): clear_mappers() - self.engine.connect().invalidate() + testing_reaper.checkin_all() for i in range(1, 5): os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT)) @@ -778,10 +776,13 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase): self.postgresql_engine = e2 return db1, db2, db3, db4 - def teardown(self): + def teardown_test(self): clear_mappers() - self.sqlite_engine.connect().invalidate() + # the tests in this suite don't cleanly close out the Session + # at the moment so use the reaper to close all connections + testing_reaper.checkin_all() + for i in [1, 3]: os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT)) @@ -789,6 +790,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase): self.tables_test_metadata.drop_all(conn) for i in [2, 4]: conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,)) + self.postgresql_engine.dispose() class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest): @@ -904,11 +906,11 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest): return self.dbs - def teardown(self): + def teardown_test(self): for db in self.dbs: db.connect().invalidate() - testing_reaper.close_all() + testing_reaper.checkin_all() for i in range(1, 3): os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT)) diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index 048a8b52d..3bab7db93 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -697,7 +697,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy import literal symbols = ("usd", "gbp", "cad", "eur", "aud") diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index eba2ac0cb..21244de73 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -90,11 +90,10 @@ class _MutableDictTestFixture(object): def _type_fixture(cls): return MutableDict - def teardown(self): + def teardown_test(self): # clear out mapper events Mapper.dispatch._clear() ClassManager.dispatch._clear() - super(_MutableDictTestFixture, self).teardown() class _MutableDictTestBase(_MutableDictTestFixture): @@ -312,11 +311,10 @@ class _MutableListTestFixture(object): def _type_fixture(cls): return MutableList - def teardown(self): + def teardown_test(self): # clear out mapper events Mapper.dispatch._clear() ClassManager.dispatch._clear() - super(_MutableListTestFixture, self).teardown() class _MutableListTestBase(_MutableListTestFixture): @@ -619,11 +617,10 @@ class _MutableSetTestFixture(object): def _type_fixture(cls): return MutableSet - def teardown(self): + def teardown_test(self): # clear out mapper events Mapper.dispatch._clear() ClassManager.dispatch._clear() - super(_MutableSetTestFixture, self).teardown() class _MutableSetTestBase(_MutableSetTestFixture): @@ -1234,17 +1231,15 @@ class _CompositeTestBase(object): Column("unrelated_data", String(50)), ) - def setup(self): + def setup_test(self): from sqlalchemy.ext import mutable mutable._setup_composite_listener() - super(_CompositeTestBase, self).setup() - def teardown(self): + def teardown_test(self): # clear out mapper events Mapper.dispatch._clear() ClassManager.dispatch._clear() - super(_CompositeTestBase, self).teardown() @classmethod def _type_fixture(cls): diff --git a/test/ext/test_orderinglist.py b/test/ext/test_orderinglist.py index f23d6cb57..280fad6cf 100644 --- a/test/ext/test_orderinglist.py +++ b/test/ext/test_orderinglist.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import picklers @@ -60,7 +60,7 @@ def alpha_ordering(index, collection): class OrderingListTest(fixtures.TestBase): - def setup(self): + def setup_test(self): global metadata, slides_table, bullets_table, Slide, Bullet slides_table, bullets_table = None, None Slide, Bullet = None, None @@ -122,7 +122,7 @@ class OrderingListTest(fixtures.TestBase): metadata.create_all(testing.db) - def teardown(self): + def teardown_test(self): metadata.drop_all(testing.db) def test_append_no_reorder(self): @@ -167,7 +167,7 @@ class OrderingListTest(fixtures.TestBase): self.assert_(s1.bullets[2].position == 3) self.assert_(s1.bullets[3].position == 4) - session = create_session() + session = fixture_session() session.add(s1) session.flush() @@ -232,7 +232,7 @@ class OrderingListTest(fixtures.TestBase): s1.bullets._reorder() self.assert_(s1.bullets[4].position == 5) - session = create_session() + session = fixture_session() session.add(s1) session.flush() @@ -289,7 +289,7 @@ class OrderingListTest(fixtures.TestBase): self.assert_(len(s1.bullets) == 6) self.assert_(s1.bullets[5].position == 5) - session = create_session() + session = fixture_session() session.add(s1) session.flush() @@ -338,7 +338,7 @@ class OrderingListTest(fixtures.TestBase): self.assert_(s1.bullets[li].position == li) self.assert_(s1.bullets[li] == b[bi]) - session = create_session() + session = fixture_session() session.add(s1) session.flush() @@ -365,7 +365,7 @@ class OrderingListTest(fixtures.TestBase): self.assert_(len(s1.bullets) == 3) self.assert_(s1.bullets[2].position == 2) - session = create_session() + session = fixture_session() session.add(s1) session.flush() diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 4c005d336..4d9162105 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -61,11 +61,11 @@ class DeclarativeTestBase( ): __dialect__ = "default" - def setup(self): + def setup_test(self): global Base Base = declarative_base(testing.db) - def teardown(self): + def teardown_test(self): close_all_sessions() clear_mappers() Base.metadata.drop_all(testing.db) diff --git a/test/orm/declarative/test_concurrency.py b/test/orm/declarative/test_concurrency.py index 5f12d8272..ecddc2e5f 100644 --- a/test/orm/declarative/test_concurrency.py +++ b/test/orm/declarative/test_concurrency.py @@ -17,7 +17,7 @@ from sqlalchemy.testing.fixtures import fixture_session class ConcurrentUseDeclMappingTest(fixtures.TestBase): - def teardown(self): + def teardown_test(self): clear_mappers() @classmethod diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py index cc29cab7d..e09b1570e 100644 --- a/test/orm/declarative/test_inheritance.py +++ b/test/orm/declarative/test_inheritance.py @@ -27,11 +27,11 @@ Base = None class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): - def setup(self): + def setup_test(self): global Base Base = decl.declarative_base(testing.db) - def teardown(self): + def teardown_test(self): close_all_sessions() clear_mappers() Base.metadata.drop_all(testing.db) diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index 631527daf..ad4832c35 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -38,13 +38,13 @@ mapper_registry = None class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): - def setup(self): + def setup_test(self): global Base, mapper_registry mapper_registry = registry(metadata=MetaData()) Base = mapper_registry.generate_base() - def teardown(self): + def teardown_test(self): close_all_sessions() clear_mappers() with testing.db.begin() as conn: diff --git a/test/orm/declarative/test_reflection.py b/test/orm/declarative/test_reflection.py index 241528c44..e7b2a7058 100644 --- a/test/orm/declarative/test_reflection.py +++ b/test/orm/declarative/test_reflection.py @@ -17,14 +17,13 @@ from sqlalchemy.testing.schema import Table class DeclarativeReflectionBase(fixtures.TablesTest): __requires__ = ("reflectable_autoincrement",) - def setup(self): + def setup_test(self): global Base, registry registry = decl.registry(metadata=MetaData()) Base = registry.generate_base() - def teardown(self): - super(DeclarativeReflectionBase, self).teardown() + def teardown_test(self): clear_mappers() diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index bdcdedc44..da07b4941 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -31,7 +31,6 @@ from sqlalchemy.orm import synonym from sqlalchemy.orm.util import instance_str from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message -from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -1889,7 +1888,6 @@ class VersioningTest(fixtures.MappedTest): @testing.emits_warning(r".*updated rowcount") @testing.requires.sane_rowcount_w_returning - @engines.close_open_connections def test_save_update(self): subtable, base, stuff = ( self.tables.subtable, @@ -2927,7 +2925,7 @@ class NoPKOnSubTableWarningTest(fixtures.TestBase): ) return parent, child - def tearDown(self): + def teardown_test(self): clear_mappers() def test_warning_on_sub(self): @@ -3417,27 +3415,26 @@ class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest): @classmethod def insert_data(cls, connection): Parent, A, B = cls.classes("Parent", "A", "B") - s = fixture_session() - - p1 = Parent(id=1) - p2 = Parent(id=2) - s.add_all([p1, p2]) - s.flush() + with Session(connection) as s: + p1 = Parent(id=1) + p2 = Parent(id=2) + s.add_all([p1, p2]) + s.flush() - s.add_all( - [ - A(id=1, parent_id=1), - B(id=2, parent_id=1), - A(id=3, parent_id=1), - B(id=4, parent_id=1), - ] - ) - s.flush() + s.add_all( + [ + A(id=1, parent_id=1), + B(id=2, parent_id=1), + A(id=3, parent_id=1), + B(id=4, parent_id=1), + ] + ) + s.flush() - s.query(A).filter(A.id.in_([3, 4])).update( - {A.type: None}, synchronize_session=False - ) - s.commit() + s.query(A).filter(A.id.in_([3, 4])).update( + {A.type: None}, synchronize_session=False + ) + s.commit() def test_pk_is_null(self): Parent, A = self.classes("Parent", "A") @@ -3527,10 +3524,12 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest): ASingleSubA, ASingleSubB, AJoinedSubA, AJoinedSubB = cls.classes( "ASingleSubA", "ASingleSubB", "AJoinedSubA", "AJoinedSubB" ) - s = fixture_session() + with Session(connection) as s: - s.add_all([ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()]) - s.commit() + s.add_all( + [ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()] + ) + s.commit() def test_single_invalid_ident(self): ASingle, ASingleSubA = self.classes("ASingle", "ASingleSubA") diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index 8820aa6a4..0a0a5d12b 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -209,7 +209,7 @@ class AttributeImplAPITest(fixtures.MappedTest): class AttributesTest(fixtures.ORMTest): - def setup(self): + def setup_test(self): global MyTest, MyTest2 class MyTest(object): @@ -218,7 +218,7 @@ class AttributesTest(fixtures.ORMTest): class MyTest2(object): pass - def teardown(self): + def teardown_test(self): global MyTest, MyTest2 MyTest, MyTest2 = None, None @@ -3690,7 +3690,7 @@ class EventPropagateTest(fixtures.TestBase): class CollectionInitTest(fixtures.TestBase): - def setUp(self): + def setup_test(self): class A(object): pass @@ -3749,7 +3749,7 @@ class CollectionInitTest(fixtures.TestBase): class TestUnlink(fixtures.TestBase): - def setUp(self): + def setup_test(self): class A(object): pass diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py index 2f54f7fff..014fa152e 100644 --- a/test/orm/test_bind.py +++ b/test/orm/test_bind.py @@ -428,39 +428,46 @@ class BindIntegrationTest(_fixtures.FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - c = testing.db.connect() - - sess = Session(bind=c, autocommit=False) - u = User(name="u1") - sess.add(u) - sess.flush() - sess.close() - assert not c.in_transaction() - assert c.exec_driver_sql("select count(1) from users").scalar() == 0 - - sess = Session(bind=c, autocommit=False) - u = User(name="u2") - sess.add(u) - sess.flush() - sess.commit() - assert not c.in_transaction() - assert c.exec_driver_sql("select count(1) from users").scalar() == 1 - - with c.begin(): - c.exec_driver_sql("delete from users") - assert c.exec_driver_sql("select count(1) from users").scalar() == 0 - - c = testing.db.connect() - - trans = c.begin() - sess = Session(bind=c, autocommit=True) - u = User(name="u3") - sess.add(u) - sess.flush() - assert c.in_transaction() - trans.commit() - assert not c.in_transaction() - assert c.exec_driver_sql("select count(1) from users").scalar() == 1 + with testing.db.connect() as c: + + sess = Session(bind=c, autocommit=False) + u = User(name="u1") + sess.add(u) + sess.flush() + sess.close() + assert not c.in_transaction() + assert ( + c.exec_driver_sql("select count(1) from users").scalar() == 0 + ) + + sess = Session(bind=c, autocommit=False) + u = User(name="u2") + sess.add(u) + sess.flush() + sess.commit() + assert not c.in_transaction() + assert ( + c.exec_driver_sql("select count(1) from users").scalar() == 1 + ) + + with c.begin(): + c.exec_driver_sql("delete from users") + assert ( + c.exec_driver_sql("select count(1) from users").scalar() == 0 + ) + + with testing.db.connect() as c: + trans = c.begin() + sess = Session(bind=c, autocommit=True) + u = User(name="u3") + sess.add(u) + sess.flush() + assert c.in_transaction() + trans.commit() + assert not c.in_transaction() + assert ( + c.exec_driver_sql("select count(1) from users").scalar() == 1 + ) class SessionBindTest(fixtures.MappedTest): @@ -506,6 +513,7 @@ class SessionBindTest(fixtures.MappedTest): finally: if hasattr(bind, "close"): bind.close() + sess.close() def test_session_unbound(self): Foo = self.classes.Foo diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index 3d09bd446..2a0aafbbc 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -92,13 +92,12 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): return str((id(self), self.a, self.b, self.c)) @classmethod - def setup_class(cls): + def setup_test_class(cls): instrumentation.register_class(cls.Entity) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): instrumentation.unregister_class(cls.Entity) - super(CollectionsTest, cls).teardown_class() _entity_id = 1 diff --git a/test/orm/test_compile.py b/test/orm/test_compile.py index df652daf4..20d8ecc2d 100644 --- a/test/orm/test_compile.py +++ b/test/orm/test_compile.py @@ -19,7 +19,7 @@ from sqlalchemy.testing import fixtures class CompileTest(fixtures.ORMTest): """test various mapper compilation scenarios""" - def teardown(self): + def teardown_test(self): clear_mappers() def test_with_polymorphic(self): diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index e1ef67fed..ed11b89c9 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -1743,8 +1743,7 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): id = Column(Integer, primary_key=True) a_id = Column(ForeignKey("a.id", name="a_fk")) - def setup(self): - super(PostUpdateOnUpdateTest, self).setup() + def setup_test(self): PostUpdateOnUpdateTest.counter = count() PostUpdateOnUpdateTest.db_counter = count() diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index 6d946cfe6..15063ebe9 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -2199,6 +2199,7 @@ class SessionTest(fixtures.RemovesEvents, _LocalFixture): class AutocommitClosesOnFailTest(fixtures.MappedTest): __requires__ = ("deferrable_fks",) + __only_on__ = ("postgresql+psycopg2",) # needs #5824 for asyncpg @classmethod def define_tables(cls, metadata): @@ -4498,44 +4499,49 @@ class JoinTest(QueryTest, AssertsCompiledSQL): warnings += (join_aliased_dep,) # load a user who has an order that contains item id 3 and address # id 1 (order 3, owned by jack) - with testing.expect_deprecated_20(*warnings): - result = ( - fixture_session() - .query(User) - .join("orders", "items", aliased=aliased_) - .filter_by(id=3) - .reset_joinpoint() - .join("orders", "address", aliased=aliased_) - .filter_by(id=1) - .all() - ) - assert [User(id=7, name="jack")] == result - with testing.expect_deprecated_20(*warnings): - result = ( - fixture_session() - .query(User) - .join("orders", "items", aliased=aliased_, isouter=True) - .filter_by(id=3) - .reset_joinpoint() - .join("orders", "address", aliased=aliased_, isouter=True) - .filter_by(id=1) - .all() - ) - assert [User(id=7, name="jack")] == result - - with testing.expect_deprecated_20(*warnings): - result = ( - fixture_session() - .query(User) - .outerjoin("orders", "items", aliased=aliased_) - .filter_by(id=3) - .reset_joinpoint() - .outerjoin("orders", "address", aliased=aliased_) - .filter_by(id=1) - .all() - ) - assert [User(id=7, name="jack")] == result + with fixture_session() as sess: + with testing.expect_deprecated_20(*warnings): + result = ( + sess.query(User) + .join("orders", "items", aliased=aliased_) + .filter_by(id=3) + .reset_joinpoint() + .join("orders", "address", aliased=aliased_) + .filter_by(id=1) + .all() + ) + assert [User(id=7, name="jack")] == result + + with fixture_session() as sess: + with testing.expect_deprecated_20(*warnings): + result = ( + sess.query(User) + .join( + "orders", "items", aliased=aliased_, isouter=True + ) + .filter_by(id=3) + .reset_joinpoint() + .join( + "orders", "address", aliased=aliased_, isouter=True + ) + .filter_by(id=1) + .all() + ) + assert [User(id=7, name="jack")] == result + + with fixture_session() as sess: + with testing.expect_deprecated_20(*warnings): + result = ( + sess.query(User) + .outerjoin("orders", "items", aliased=aliased_) + .filter_by(id=3) + .reset_joinpoint() + .outerjoin("orders", "address", aliased=aliased_) + .filter_by(id=1) + .all() + ) + assert [User(id=7, name="jack")] == result class AliasFromCorrectLeftTest( diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index 4498fc1ff..7eedb37c9 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -559,15 +559,15 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): 5, ), ]: - sess = fixture_session() + with fixture_session() as sess: - def go(): - eq_( - sess.query(User).options(*opt).order_by(User.id).all(), - self.static.user_item_keyword_result, - ) + def go(): + eq_( + sess.query(User).options(*opt).order_by(User.id).all(), + self.static.user_item_keyword_result, + ) - self.assert_sql_count(testing.db, go, count) + self.assert_sql_count(testing.db, go, count) def test_disable_dynamic(self): """test no joined option on a dynamic.""" diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 1c918a88c..e85c23d6f 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -45,13 +45,14 @@ from test.orm import _fixtures class _RemoveListeners(object): - def teardown(self): + @testing.fixture(autouse=True) + def _remove_listeners(self): + yield events.MapperEvents._clear() events.InstanceEvents._clear() events.SessionEvents._clear() events.InstrumentationEvents._clear() events.QueryEvents._clear() - super(_RemoveListeners, self).teardown() class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): @@ -1174,7 +1175,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest): argnames="target, event_name, fn", )(fn) - def teardown(self): + def teardown_test(self): A = self.classes.A A._sa_class_manager.dispatch._clear() diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index cc9596466..f622bff02 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -2395,16 +2395,19 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): ] adalias = addresses.alias() - q = ( - fixture_session() - .query(User) - .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name)) - .outerjoin(adalias, "addresses") - .group_by(users) - .order_by(users.c.id) - ) - assert q.all() == expected + with fixture_session() as sess: + q = ( + sess.query(User) + .add_columns( + func.count(adalias.c.id), ("Name:" + users.c.name) + ) + .outerjoin(adalias, "addresses") + .group_by(users) + .order_by(users.c.id) + ) + + eq_(q.all(), expected) # test with a straight statement s = ( @@ -2417,52 +2420,57 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): .group_by(*[c for c in users.c]) .order_by(users.c.id) ) - q = fixture_session().query(User) - result = ( - q.add_columns(s.selected_columns.count, s.selected_columns.concat) - .from_statement(s) - .all() - ) - assert result == expected - - sess.expunge_all() - # test with select_entity_from() - q = ( - fixture_session() - .query(User) - .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name)) - .select_entity_from(users.outerjoin(addresses)) - .group_by(users) - .order_by(users.c.id) - ) + with fixture_session() as sess: + q = sess.query(User) + result = ( + q.add_columns( + s.selected_columns.count, s.selected_columns.concat + ) + .from_statement(s) + .all() + ) + eq_(result, expected) - assert q.all() == expected - sess.expunge_all() + with fixture_session() as sess: + # test with select_entity_from() + q = ( + fixture_session() + .query(User) + .add_columns( + func.count(addresses.c.id), ("Name:" + users.c.name) + ) + .select_entity_from(users.outerjoin(addresses)) + .group_by(users) + .order_by(users.c.id) + ) - q = ( - fixture_session() - .query(User) - .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name)) - .outerjoin("addresses") - .group_by(users) - .order_by(users.c.id) - ) + eq_(q.all(), expected) - assert q.all() == expected - sess.expunge_all() + with fixture_session() as sess: + q = ( + sess.query(User) + .add_columns( + func.count(addresses.c.id), ("Name:" + users.c.name) + ) + .outerjoin("addresses") + .group_by(users) + .order_by(users.c.id) + ) + eq_(q.all(), expected) - q = ( - fixture_session() - .query(User) - .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name)) - .outerjoin(adalias, "addresses") - .group_by(users) - .order_by(users.c.id) - ) + with fixture_session() as sess: + q = ( + sess.query(User) + .add_columns( + func.count(adalias.c.id), ("Name:" + users.c.name) + ) + .outerjoin(adalias, "addresses") + .group_by(users) + .order_by(users.c.id) + ) - assert q.all() == expected - sess.expunge_all() + eq_(q.all(), expected) def test_expression_selectable_matches_mzero(self): User, Address = self.classes.User, self.classes.Address diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py index 3061de309..43cf81e6d 100644 --- a/test/orm/test_lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -717,24 +717,24 @@ class LazyTest(_fixtures.FixtureTest): ), ) - sess = fixture_session() + with fixture_session() as sess: - # load address - a1 = ( - sess.query(Address) - .filter_by(email_address="ed@wood.com") - .one() - ) + # load address + a1 = ( + sess.query(Address) + .filter_by(email_address="ed@wood.com") + .one() + ) - # load user that is attached to the address - u1 = sess.query(User).get(8) + # load user that is attached to the address + u1 = sess.query(User).get(8) - def go(): - # lazy load of a1.user should get it from the session - assert a1.user is u1 + def go(): + # lazy load of a1.user should get it from the session + assert a1.user is u1 - self.assert_sql_count(testing.db, go, 0) - sa.orm.clear_mappers() + self.assert_sql_count(testing.db, go, 0) + sa.orm.clear_mappers() def test_uses_get_compatible_types(self): """test the use_get optimization with compatible @@ -789,24 +789,23 @@ class LazyTest(_fixtures.FixtureTest): properties=dict(user=relationship(mapper(User, users))), ) - sess = fixture_session() - - # load address - a1 = ( - sess.query(Address) - .filter_by(email_address="ed@wood.com") - .one() - ) + with fixture_session() as sess: + # load address + a1 = ( + sess.query(Address) + .filter_by(email_address="ed@wood.com") + .one() + ) - # load user that is attached to the address - u1 = sess.query(User).get(8) + # load user that is attached to the address + u1 = sess.query(User).get(8) - def go(): - # lazy load of a1.user should get it from the session - assert a1.user is u1 + def go(): + # lazy load of a1.user should get it from the session + assert a1.user is u1 - self.assert_sql_count(testing.db, go, 0) - sa.orm.clear_mappers() + self.assert_sql_count(testing.db, go, 0) + sa.orm.clear_mappers() def test_many_to_one(self): users, Address, addresses, User = ( diff --git a/test/orm/test_load_on_fks.py b/test/orm/test_load_on_fks.py index 0e8ac97e3..42b5b3e45 100644 --- a/test/orm/test_load_on_fks.py +++ b/test/orm/test_load_on_fks.py @@ -9,14 +9,12 @@ from sqlalchemy.orm import Session from sqlalchemy.orm.attributes import instance_state from sqlalchemy.testing import AssertsExecutionResults from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column -engine = testing.db - - class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase): - def setUp(self): + def setup_test(self): global Parent, Child, Base Base = declarative_base() @@ -36,27 +34,27 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase): ) parent_id = Column(Integer, ForeignKey("parent.id")) - Base.metadata.create_all(engine) + Base.metadata.create_all(testing.db) - def tearDown(self): - Base.metadata.drop_all(engine) + def teardown_test(self): + Base.metadata.drop_all(testing.db) def test_annoying_autoflush_one(self): - sess = Session(engine) + sess = fixture_session() p1 = Parent() sess.add(p1) p1.children = [] def test_annoying_autoflush_two(self): - sess = Session(engine) + sess = fixture_session() p1 = Parent() sess.add(p1) assert p1.children == [] def test_dont_load_if_no_keys(self): - sess = Session(engine) + sess = fixture_session() p1 = Parent() sess.add(p1) @@ -68,7 +66,9 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase): class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase): - def setUp(self): + __leave_connections_for_teardown__ = True + + def setup_test(self): global Parent, Child, Base Base = declarative_base() @@ -91,10 +91,10 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase): parent = relationship(Parent, backref=backref("children")) - Base.metadata.create_all(engine) + Base.metadata.create_all(testing.db) global sess, p1, p2, c1, c2 - sess = Session(bind=engine) + sess = Session(bind=testing.db) p1 = Parent() p2 = Parent() @@ -105,9 +105,9 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase): sess.commit() - def tearDown(self): + def teardown_test(self): sess.rollback() - Base.metadata.drop_all(engine) + Base.metadata.drop_all(testing.db) def test_load_on_pending_allows_backref_event(self): Child.parent.property.load_on_pending = True diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 013eb21e1..d182fd2c1 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -2560,7 +2560,7 @@ class MagicNamesTest(fixtures.MappedTest): class DocumentTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.mapper = registry().map_imperatively @@ -2624,14 +2624,14 @@ class DocumentTest(fixtures.TestBase): class ORMLoggingTest(_fixtures.FixtureTest): - def setup(self): + def setup_test(self): self.buf = logging.handlers.BufferingHandler(100) for log in [logging.getLogger("sqlalchemy.orm")]: log.addHandler(self.buf) self.mapper = registry().map_imperatively - def teardown(self): + def teardown_test(self): for log in [logging.getLogger("sqlalchemy.orm")]: log.removeHandler(self.buf) diff --git a/test/orm/test_options.py b/test/orm/test_options.py index b22b318e9..6f47c1238 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -1523,9 +1523,7 @@ class PickleTest(PathTest, QueryTest): class LocalOptsTest(PathTest, QueryTest): @classmethod - def setup_class(cls): - super(LocalOptsTest, cls).setup_class() - + def setup_test_class(cls): @strategy_options.loader_option() def some_col_opt_only(loadopt, key, opts): return loadopt.set_column_strategy( diff --git a/test/orm/test_query.py b/test/orm/test_query.py index fd8e849fb..7546ba162 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -6000,12 +6000,19 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): [User.orders_syn, Order.items_syn], [User.orders_syn_2, Order.items_syn], ): - q = fixture_session().query(User) - for path in j: - q = q.join(path) - q = q.filter_by(id=3) - result = q.all() - assert [User(id=7, name="jack"), User(id=9, name="fred")] == result + with fixture_session() as sess: + q = sess.query(User) + for path in j: + q = q.join(path) + q = q.filter_by(id=3) + result = q.all() + eq_( + result, + [ + User(id=7, name="jack"), + User(id=9, name="fred"), + ], + ) def test_with_parent(self): Order, User = self.classes.Order, self.classes.User @@ -6018,17 +6025,17 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): ("name_syn", "orders_syn"), ("name_syn", "orders_syn_2"), ): - sess = fixture_session() - q = sess.query(User) + with fixture_session() as sess: + q = sess.query(User) - u1 = q.filter_by(**{nameprop: "jack"}).one() + u1 = q.filter_by(**{nameprop: "jack"}).one() - o = sess.query(Order).with_parent(u1, property=orderprop).all() - assert [ - Order(description="order 1"), - Order(description="order 3"), - Order(description="order 5"), - ] == o + o = sess.query(Order).with_parent(u1, property=orderprop).all() + assert [ + Order(description="order 1"), + Order(description="order 3"), + Order(description="order 5"), + ] == o def test_froms_aliased_col(self): Address, User = self.classes.Address, self.classes.User diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py index 12c084b2d..ef1bf2e60 100644 --- a/test/orm/test_rel_fn.py +++ b/test/orm/test_rel_fn.py @@ -26,7 +26,7 @@ from sqlalchemy.testing import mock class _JoinFixtures(object): @classmethod - def setup_class(cls): + def setup_test_class(cls): m = MetaData() cls.left = Table( "lft", diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index 5979f08ae..8d73cd40e 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -637,7 +637,7 @@ class OverlappingFksSiblingTest(fixtures.TestBase): """ - def teardown(self): + def teardown_test(self): clear_mappers() def _fixture_one( @@ -2474,7 +2474,7 @@ class JoinConditionErrorTest(fixtures.TestBase): assert_raises(sa.exc.ArgumentError, configure_mappers) - def teardown(self): + def teardown_test(self): clear_mappers() @@ -4354,7 +4354,7 @@ class AmbiguousFKResolutionTest(_RelationshipErrors, fixtures.MappedTest): class SecondaryArgTest(fixtures.TestBase): - def teardown(self): + def teardown_test(self): clear_mappers() @testing.combinations((True,), (False,)) diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py index 5535fe5d6..4895c7d3a 100644 --- a/test/orm/test_selectin_relations.py +++ b/test/orm/test_selectin_relations.py @@ -713,35 +713,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def _do_query_tests(self, opts, count): Order, User = self.classes.Order, self.classes.User - sess = fixture_session() + with fixture_session() as sess: - def go(): - eq_( - sess.query(User).options(*opts).order_by(User.id).all(), - self.static.user_item_keyword_result, - ) + def go(): + eq_( + sess.query(User).options(*opts).order_by(User.id).all(), + self.static.user_item_keyword_result, + ) - self.assert_sql_count(testing.db, go, count) + self.assert_sql_count(testing.db, go, count) - eq_( - sess.query(User) - .options(*opts) - .filter(User.name == "fred") - .order_by(User.id) - .all(), - self.static.user_item_keyword_result[2:3], - ) + eq_( + sess.query(User) + .options(*opts) + .filter(User.name == "fred") + .order_by(User.id) + .all(), + self.static.user_item_keyword_result[2:3], + ) - sess = fixture_session() - eq_( - sess.query(User) - .options(*opts) - .join(User.orders) - .filter(Order.id == 3) - .order_by(User.id) - .all(), - self.static.user_item_keyword_result[0:1], - ) + with fixture_session() as sess: + eq_( + sess.query(User) + .options(*opts) + .join(User.orders) + .filter(Order.id == 3) + .order_by(User.id) + .all(), + self.static.user_item_keyword_result[0:1], + ) def test_cyclical(self): """A circular eager relationship breaks the cycle with a lazy loader""" diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 20c4752b8..3d4566af3 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -1273,7 +1273,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): run_inserts = None - def setup(self): + def setup_test(self): mapper(self.classes.User, self.tables.users) def _assert_modified(self, u1): @@ -1288,11 +1288,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): def _assert_no_cycle(self, u1): assert sa.orm.attributes.instance_state(u1)._strong_obj is None - def _persistent_fixture(self): + def _persistent_fixture(self, gc_collect=False): User = self.classes.User u1 = User() u1.name = "ed" - sess = fixture_session() + if gc_collect: + sess = Session(testing.db) + else: + sess = fixture_session() sess.add(u1) sess.flush() return sess, u1 @@ -1389,14 +1392,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): @testing.requires.predictable_gc def test_move_gc_session_persistent_dirty(self): - sess, u1 = self._persistent_fixture() + sess, u1 = self._persistent_fixture(gc_collect=True) u1.name = "edchanged" self._assert_cycle(u1) self._assert_modified(u1) del sess gc_collect() self._assert_cycle(u1) - s2 = fixture_session() + s2 = Session(testing.db) s2.add(u1) self._assert_cycle(u1) self._assert_modified(u1) @@ -1565,7 +1568,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): mapper(User, users) - sess = fixture_session() + sess = Session(testing.db) u1 = User(name="u1") sess.add(u1) @@ -1573,7 +1576,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): # can't add u1 to Session, # already belongs to u2 - s2 = fixture_session() + s2 = Session(testing.db) assert_raises_message( sa.exc.InvalidRequestError, r".*is already attached to session", @@ -1725,11 +1728,10 @@ class DisposedStates(fixtures.MappedTest): mapper(T, cls.tables.t1) - def teardown(self): + def teardown_test(self): from sqlalchemy.orm.session import _sessions _sessions.clear() - super(DisposedStates, self).teardown() def _set_imap_in_disposal(self, sess, *objs): """remove selected objects from the given session, as though diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py index fe20442a3..150cee222 100644 --- a/test/orm/test_subquery_relations.py +++ b/test/orm/test_subquery_relations.py @@ -734,35 +734,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def _do_query_tests(self, opts, count): Order, User = self.classes.Order, self.classes.User - sess = fixture_session() + with fixture_session() as sess: - def go(): - eq_( - sess.query(User).options(*opts).order_by(User.id).all(), - self.static.user_item_keyword_result, - ) + def go(): + eq_( + sess.query(User).options(*opts).order_by(User.id).all(), + self.static.user_item_keyword_result, + ) - self.assert_sql_count(testing.db, go, count) + self.assert_sql_count(testing.db, go, count) - eq_( - sess.query(User) - .options(*opts) - .filter(User.name == "fred") - .order_by(User.id) - .all(), - self.static.user_item_keyword_result[2:3], - ) + eq_( + sess.query(User) + .options(*opts) + .filter(User.name == "fred") + .order_by(User.id) + .all(), + self.static.user_item_keyword_result[2:3], + ) - sess = fixture_session() - eq_( - sess.query(User) - .options(*opts) - .join(User.orders) - .filter(Order.id == 3) - .order_by(User.id) - .all(), - self.static.user_item_keyword_result[0:1], - ) + with fixture_session() as sess: + eq_( + sess.query(User) + .options(*opts) + .join(User.orders) + .filter(Order.id == 3) + .order_by(User.id) + .all(), + self.static.user_item_keyword_result[0:1], + ) def test_cyclical(self): """A circular eager relationship breaks the cycle with a lazy loader""" diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index 550cf6535..7f77b01c7 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -66,17 +66,18 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - conn = testing.db.connect() - trans = conn.begin() - sess = Session(bind=conn, autocommit=False, autoflush=True) - sess.begin(subtransactions=True) - u = User(name="ed") - sess.add(u) - sess.flush() - sess.commit() # commit does nothing - trans.rollback() # rolls back - assert len(sess.query(User).all()) == 0 - sess.close() + + with testing.db.connect() as conn: + trans = conn.begin() + sess = Session(bind=conn, autocommit=False, autoflush=True) + sess.begin(subtransactions=True) + u = User(name="ed") + sess.add(u) + sess.flush() + sess.commit() # commit does nothing + trans.rollback() # rolls back + assert len(sess.query(User).all()) == 0 + sess.close() @engines.close_open_connections def test_subtransaction_on_external_no_begin(self): @@ -260,34 +261,33 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): users = self.tables.users engine = Engine._future_facade(testing.db) - session = Session(engine, autocommit=False) - - session.begin() - session.connection().execute(users.insert().values(name="user1")) - session.begin(subtransactions=True) - session.begin_nested() - session.connection().execute(users.insert().values(name="user2")) - assert ( - session.connection() - .exec_driver_sql("select count(1) from users") - .scalar() - == 2 - ) - session.rollback() - assert ( - session.connection() - .exec_driver_sql("select count(1) from users") - .scalar() - == 1 - ) - session.connection().execute(users.insert().values(name="user3")) - session.commit() - assert ( - session.connection() - .exec_driver_sql("select count(1) from users") - .scalar() - == 2 - ) + with Session(engine, autocommit=False) as session: + session.begin() + session.connection().execute(users.insert().values(name="user1")) + session.begin(subtransactions=True) + session.begin_nested() + session.connection().execute(users.insert().values(name="user2")) + assert ( + session.connection() + .exec_driver_sql("select count(1) from users") + .scalar() + == 2 + ) + session.rollback() + assert ( + session.connection() + .exec_driver_sql("select count(1) from users") + .scalar() + == 1 + ) + session.connection().execute(users.insert().values(name="user3")) + session.commit() + assert ( + session.connection() + .exec_driver_sql("select count(1) from users") + .scalar() + == 2 + ) @testing.requires.savepoints def test_dirty_state_transferred_deep_nesting(self): @@ -295,27 +295,27 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) - s = Session(testing.db) - u1 = User(name="u1") - s.add(u1) - s.commit() - - nt1 = s.begin_nested() - nt2 = s.begin_nested() - u1.name = "u2" - assert attributes.instance_state(u1) not in nt2._dirty - assert attributes.instance_state(u1) not in nt1._dirty - s.flush() - assert attributes.instance_state(u1) in nt2._dirty - assert attributes.instance_state(u1) not in nt1._dirty + with fixture_session() as s: + u1 = User(name="u1") + s.add(u1) + s.commit() + + nt1 = s.begin_nested() + nt2 = s.begin_nested() + u1.name = "u2" + assert attributes.instance_state(u1) not in nt2._dirty + assert attributes.instance_state(u1) not in nt1._dirty + s.flush() + assert attributes.instance_state(u1) in nt2._dirty + assert attributes.instance_state(u1) not in nt1._dirty - s.commit() - assert attributes.instance_state(u1) in nt2._dirty - assert attributes.instance_state(u1) in nt1._dirty + s.commit() + assert attributes.instance_state(u1) in nt2._dirty + assert attributes.instance_state(u1) in nt1._dirty - s.rollback() - assert attributes.instance_state(u1).expired - eq_(u1.name, "u1") + s.rollback() + assert attributes.instance_state(u1).expired + eq_(u1.name, "u1") @testing.requires.savepoints def test_dirty_state_transferred_deep_nesting_future(self): @@ -323,27 +323,27 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) - s = Session(testing.db, future=True) - u1 = User(name="u1") - s.add(u1) - s.commit() - - nt1 = s.begin_nested() - nt2 = s.begin_nested() - u1.name = "u2" - assert attributes.instance_state(u1) not in nt2._dirty - assert attributes.instance_state(u1) not in nt1._dirty - s.flush() - assert attributes.instance_state(u1) in nt2._dirty - assert attributes.instance_state(u1) not in nt1._dirty + with fixture_session(future=True) as s: + u1 = User(name="u1") + s.add(u1) + s.commit() + + nt1 = s.begin_nested() + nt2 = s.begin_nested() + u1.name = "u2" + assert attributes.instance_state(u1) not in nt2._dirty + assert attributes.instance_state(u1) not in nt1._dirty + s.flush() + assert attributes.instance_state(u1) in nt2._dirty + assert attributes.instance_state(u1) not in nt1._dirty - nt2.commit() - assert attributes.instance_state(u1) in nt2._dirty - assert attributes.instance_state(u1) in nt1._dirty + nt2.commit() + assert attributes.instance_state(u1) in nt2._dirty + assert attributes.instance_state(u1) in nt1._dirty - nt1.rollback() - assert attributes.instance_state(u1).expired - eq_(u1.name, "u1") + nt1.rollback() + assert attributes.instance_state(u1).expired + eq_(u1.name, "u1") @testing.requires.independent_connections def test_transactions_isolated(self): @@ -1049,23 +1049,25 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) - session = Session(testing.db) + with fixture_session() as session: - with expect_warnings(".*during handling of a previous exception.*"): - session.begin_nested() - savepoint = session.connection()._nested_transaction._savepoint + with expect_warnings( + ".*during handling of a previous exception.*" + ): + session.begin_nested() + savepoint = session.connection()._nested_transaction._savepoint - # force the savepoint to disappear - session.connection().dialect.do_release_savepoint( - session.connection(), savepoint - ) + # force the savepoint to disappear + session.connection().dialect.do_release_savepoint( + session.connection(), savepoint + ) - # now do a broken flush - session.add_all([User(id=1), User(id=1)]) + # now do a broken flush + session.add_all([User(id=1), User(id=1)]) - assert_raises_message( - sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush - ) + assert_raises_message( + sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush + ) class _LocalFixture(FixtureTest): @@ -1170,39 +1172,40 @@ class SubtransactionRecipeTest(FixtureTest): def test_recipe_heavy_nesting(self, subtransaction_recipe): users = self.tables.users - session = Session(testing.db, future=self.future) - - with subtransaction_recipe(session): - session.connection().execute(users.insert().values(name="user1")) + with fixture_session(future=self.future) as session: with subtransaction_recipe(session): - savepoint = session.begin_nested() session.connection().execute( - users.insert().values(name="user2") + users.insert().values(name="user1") ) + with subtransaction_recipe(session): + savepoint = session.begin_nested() + session.connection().execute( + users.insert().values(name="user2") + ) + assert ( + session.connection() + .exec_driver_sql("select count(1) from users") + .scalar() + == 2 + ) + savepoint.rollback() + + with subtransaction_recipe(session): + assert ( + session.connection() + .exec_driver_sql("select count(1) from users") + .scalar() + == 1 + ) + session.connection().execute( + users.insert().values(name="user3") + ) assert ( session.connection() .exec_driver_sql("select count(1) from users") .scalar() == 2 ) - savepoint.rollback() - - with subtransaction_recipe(session): - assert ( - session.connection() - .exec_driver_sql("select count(1) from users") - .scalar() - == 1 - ) - session.connection().execute( - users.insert().values(name="user3") - ) - assert ( - session.connection() - .exec_driver_sql("select count(1) from users") - .scalar() - == 2 - ) @engines.close_open_connections def test_recipe_subtransaction_on_external_subtrans( @@ -1228,13 +1231,12 @@ class SubtransactionRecipeTest(FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = Session(testing.db, future=self.future) - - with subtransaction_recipe(sess): - u = User(name="u1") - sess.add(u) - sess.close() - assert len(sess.query(User).all()) == 1 + with fixture_session(future=self.future) as sess: + with subtransaction_recipe(sess): + u = User(name="u1") + sess.add(u) + sess.close() + assert len(sess.query(User).all()) == 1 def test_recipe_subtransaction_on_noautocommit( self, subtransaction_recipe @@ -1242,16 +1244,15 @@ class SubtransactionRecipeTest(FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = Session(testing.db, future=self.future) - - sess.begin() - with subtransaction_recipe(sess): - u = User(name="u1") - sess.add(u) - sess.flush() - sess.rollback() # rolls back - assert len(sess.query(User).all()) == 0 - sess.close() + with fixture_session(future=self.future) as sess: + sess.begin() + with subtransaction_recipe(sess): + u = User(name="u1") + sess.add(u) + sess.flush() + sess.rollback() # rolls back + assert len(sess.query(User).all()) == 0 + sess.close() @testing.requires.savepoints def test_recipe_mixed_transaction_control(self, subtransaction_recipe): @@ -1259,30 +1260,28 @@ class SubtransactionRecipeTest(FixtureTest): mapper(User, users) - sess = Session(testing.db, future=self.future) + with fixture_session(future=self.future) as sess: - sess.begin() - sess.begin_nested() + sess.begin() + sess.begin_nested() - with subtransaction_recipe(sess): + with subtransaction_recipe(sess): - sess.add(User(name="u1")) + sess.add(User(name="u1")) - sess.commit() - sess.commit() + sess.commit() + sess.commit() - eq_(len(sess.query(User).all()), 1) - sess.close() + eq_(len(sess.query(User).all()), 1) + sess.close() - t1 = sess.begin() - t2 = sess.begin_nested() - - sess.add(User(name="u2")) + t1 = sess.begin() + t2 = sess.begin_nested() - t2.commit() - assert sess._legacy_transaction() is t1 + sess.add(User(name="u2")) - sess.close() + t2.commit() + assert sess._legacy_transaction() is t1 def test_recipe_error_on_using_inactive_session_commands( self, subtransaction_recipe @@ -1290,56 +1289,55 @@ class SubtransactionRecipeTest(FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session(testing.db, future=self.future) - sess.begin() - - try: - with subtransaction_recipe(sess): - sess.add(User(name="u1")) - sess.flush() - raise Exception("force rollback") - except: - pass - - if self.recipe_rollsback_early: - # that was a real rollback, so no transaction - assert not sess.in_transaction() - is_(sess.get_transaction(), None) - else: - assert sess.in_transaction() - - sess.close() - assert not sess.in_transaction() - - def test_recipe_multi_nesting(self, subtransaction_recipe): - sess = Session(testing.db, future=self.future) - - with subtransaction_recipe(sess): - assert sess.in_transaction() + with fixture_session(future=self.future) as sess: + sess.begin() try: with subtransaction_recipe(sess): - assert sess._legacy_transaction() + sess.add(User(name="u1")) + sess.flush() raise Exception("force rollback") except: pass if self.recipe_rollsback_early: + # that was a real rollback, so no transaction assert not sess.in_transaction() + is_(sess.get_transaction(), None) else: assert sess.in_transaction() - assert not sess.in_transaction() + sess.close() + assert not sess.in_transaction() + + def test_recipe_multi_nesting(self, subtransaction_recipe): + with fixture_session(future=self.future) as sess: + with subtransaction_recipe(sess): + assert sess.in_transaction() + + try: + with subtransaction_recipe(sess): + assert sess._legacy_transaction() + raise Exception("force rollback") + except: + pass + + if self.recipe_rollsback_early: + assert not sess.in_transaction() + else: + assert sess.in_transaction() + + assert not sess.in_transaction() def test_recipe_deactive_status_check(self, subtransaction_recipe): - sess = Session(testing.db, future=self.future) - sess.begin() + with fixture_session(future=self.future) as sess: + sess.begin() - with subtransaction_recipe(sess): - sess.rollback() + with subtransaction_recipe(sess): + sess.rollback() - assert not sess.in_transaction() - sess.commit() # no error + assert not sess.in_transaction() + sess.commit() # no error class FixtureDataTest(_LocalFixture): @@ -1394,28 +1392,28 @@ class CleanSavepointTest(FixtureTest): mapper(User, users) - s = Session(bind=testing.db, future=future) - u1 = User(name="u1") - u2 = User(name="u2") - s.add_all([u1, u2]) - s.commit() - u1.name - u2.name - trans = s._transaction - assert trans is not None - s.begin_nested() - update_fn(s, u2) - eq_(u2.name, "u2modified") - s.rollback() + with fixture_session(future=future) as s: + u1 = User(name="u1") + u2 = User(name="u2") + s.add_all([u1, u2]) + s.commit() + u1.name + u2.name + trans = s._transaction + assert trans is not None + s.begin_nested() + update_fn(s, u2) + eq_(u2.name, "u2modified") + s.rollback() - if future: - assert s._transaction is None - assert "name" not in u1.__dict__ - else: - assert s._transaction is trans - eq_(u1.__dict__["name"], "u1") - assert "name" not in u2.__dict__ - eq_(u2.name, "u2") + if future: + assert s._transaction is None + assert "name" not in u1.__dict__ + else: + assert s._transaction is trans + eq_(u1.__dict__["name"], "u1") + assert "name" not in u2.__dict__ + eq_(u2.name, "u2") @testing.requires.savepoints def test_rollback_ignores_clean_on_savepoint(self): @@ -2116,82 +2114,108 @@ class ContextManagerPlusFutureTest(FixtureTest): eq_(sess.query(User).count(), 1) def test_explicit_begin(self): - s1 = Session(testing.db) - with s1.begin() as trans: - is_(trans, s1._legacy_transaction()) - s1.connection() + with fixture_session() as s1: + with s1.begin() as trans: + is_(trans, s1._legacy_transaction()) + s1.connection() - is_(s1._transaction, None) + is_(s1._transaction, None) def test_no_double_begin_explicit(self): - s1 = Session(testing.db) - s1.begin() - assert_raises_message( - sa_exc.InvalidRequestError, - "A transaction is already begun on this Session.", - s1.begin, - ) + with fixture_session() as s1: + s1.begin() + assert_raises_message( + sa_exc.InvalidRequestError, + "A transaction is already begun on this Session.", + s1.begin, + ) @testing.requires.savepoints def test_future_rollback_is_global(self): users = self.tables.users - s1 = Session(testing.db, future=True) + with fixture_session(future=True) as s1: + s1.begin() - s1.begin() + s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}]) - s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}]) + s1.begin_nested() - s1.begin_nested() - - s1.connection().execute( - users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}] - ) + s1.connection().execute( + users.insert(), + [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}], + ) - eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3) + eq_( + s1.connection().scalar( + select(func.count()).select_from(users) + ), + 3, + ) - # rolls back the whole transaction - s1.rollback() - is_(s1._legacy_transaction(), None) + # rolls back the whole transaction + s1.rollback() + is_(s1._legacy_transaction(), None) - eq_(s1.connection().scalar(select(func.count()).select_from(users)), 0) + eq_( + s1.connection().scalar( + select(func.count()).select_from(users) + ), + 0, + ) - s1.commit() - is_(s1._legacy_transaction(), None) + s1.commit() + is_(s1._legacy_transaction(), None) @testing.requires.savepoints def test_old_rollback_is_local(self): users = self.tables.users - s1 = Session(testing.db) + with fixture_session() as s1: - t1 = s1.begin() + t1 = s1.begin() - s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}]) + s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}]) - s1.begin_nested() + s1.begin_nested() - s1.connection().execute( - users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}] - ) + s1.connection().execute( + users.insert(), + [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}], + ) - eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3) + eq_( + s1.connection().scalar( + select(func.count()).select_from(users) + ), + 3, + ) - # rolls back only the savepoint - s1.rollback() + # rolls back only the savepoint + s1.rollback() - is_(s1._legacy_transaction(), t1) + is_(s1._legacy_transaction(), t1) - eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1) + eq_( + s1.connection().scalar( + select(func.count()).select_from(users) + ), + 1, + ) - s1.commit() - eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1) - is_not(s1._legacy_transaction(), None) + s1.commit() + eq_( + s1.connection().scalar( + select(func.count()).select_from(users) + ), + 1, + ) + is_not(s1._legacy_transaction(), None) def test_session_as_ctx_manager_one(self): users = self.tables.users - with Session(testing.db) as sess: + with fixture_session() as sess: is_not(sess._legacy_transaction(), None) sess.connection().execute( @@ -2212,7 +2236,7 @@ class ContextManagerPlusFutureTest(FixtureTest): def test_session_as_ctx_manager_future_one(self): users = self.tables.users - with Session(testing.db, future=True) as sess: + with fixture_session(future=True) as sess: is_(sess._legacy_transaction(), None) sess.connection().execute( @@ -2234,7 +2258,7 @@ class ContextManagerPlusFutureTest(FixtureTest): users = self.tables.users try: - with Session(testing.db) as sess: + with fixture_session() as sess: is_not(sess._legacy_transaction(), None) sess.connection().execute( @@ -2250,7 +2274,7 @@ class ContextManagerPlusFutureTest(FixtureTest): users = self.tables.users try: - with Session(testing.db, future=True) as sess: + with fixture_session(future=True) as sess: is_(sess._legacy_transaction(), None) sess.connection().execute( @@ -2265,7 +2289,7 @@ class ContextManagerPlusFutureTest(FixtureTest): def test_begin_context_manager(self): users = self.tables.users - with Session(testing.db) as sess: + with fixture_session() as sess: with sess.begin(): sess.connection().execute( users.insert().values(id=1, name="user1") @@ -2296,12 +2320,13 @@ class ContextManagerPlusFutureTest(FixtureTest): # committed eq_(sess.connection().execute(users.select()).all(), [(1, "user1")]) + sess.close() def test_begin_context_manager_rollback_trans(self): users = self.tables.users try: - with Session(testing.db) as sess: + with fixture_session() as sess: with sess.begin(): sess.connection().execute( users.insert().values(id=1, name="user1") @@ -2318,12 +2343,13 @@ class ContextManagerPlusFutureTest(FixtureTest): # rolled back eq_(sess.connection().execute(users.select()).all(), []) + sess.close() def test_begin_context_manager_rollback_outer(self): users = self.tables.users try: - with Session(testing.db) as sess: + with fixture_session() as sess: with sess.begin(): sess.connection().execute( users.insert().values(id=1, name="user1") @@ -2340,6 +2366,7 @@ class ContextManagerPlusFutureTest(FixtureTest): # committed eq_(sess.connection().execute(users.select()).all(), [(1, "user1")]) + sess.close() def test_sessionmaker_begin_context_manager_rollback_trans(self): users = self.tables.users @@ -2363,6 +2390,7 @@ class ContextManagerPlusFutureTest(FixtureTest): # rolled back eq_(sess.connection().execute(users.select()).all(), []) + sess.close() def test_sessionmaker_begin_context_manager_rollback_outer(self): users = self.tables.users @@ -2386,36 +2414,37 @@ class ContextManagerPlusFutureTest(FixtureTest): # committed eq_(sess.connection().execute(users.select()).all(), [(1, "user1")]) + sess.close() class TransactionFlagsTest(fixtures.TestBase): def test_in_transaction(self): - s1 = Session(testing.db) + with fixture_session() as s1: - eq_(s1.in_transaction(), False) + eq_(s1.in_transaction(), False) - trans = s1.begin() + trans = s1.begin() - eq_(s1.in_transaction(), True) - is_(s1.get_transaction(), trans) + eq_(s1.in_transaction(), True) + is_(s1.get_transaction(), trans) - n1 = s1.begin_nested() + n1 = s1.begin_nested() - eq_(s1.in_transaction(), True) - is_(s1.get_transaction(), trans) - is_(s1.get_nested_transaction(), n1) + eq_(s1.in_transaction(), True) + is_(s1.get_transaction(), trans) + is_(s1.get_nested_transaction(), n1) - n1.rollback() + n1.rollback() - is_(s1.get_nested_transaction(), None) - is_(s1.get_transaction(), trans) + is_(s1.get_nested_transaction(), None) + is_(s1.get_transaction(), trans) - eq_(s1.in_transaction(), True) + eq_(s1.in_transaction(), True) - s1.commit() + s1.commit() - eq_(s1.in_transaction(), False) - is_(s1.get_transaction(), None) + eq_(s1.in_transaction(), False) + is_(s1.get_transaction(), None) def test_in_transaction_subtransactions(self): """we'd like to do away with subtransactions for future sessions @@ -2425,72 +2454,71 @@ class TransactionFlagsTest(fixtures.TestBase): the external API works. """ - s1 = Session(testing.db) - - eq_(s1.in_transaction(), False) + with fixture_session() as s1: + eq_(s1.in_transaction(), False) - trans = s1.begin() + trans = s1.begin() - eq_(s1.in_transaction(), True) - is_(s1.get_transaction(), trans) + eq_(s1.in_transaction(), True) + is_(s1.get_transaction(), trans) - subtrans = s1.begin(_subtrans=True) - is_(s1.get_transaction(), trans) - eq_(s1.in_transaction(), True) + subtrans = s1.begin(_subtrans=True) + is_(s1.get_transaction(), trans) + eq_(s1.in_transaction(), True) - is_(s1._transaction, subtrans) + is_(s1._transaction, subtrans) - s1.rollback() + s1.rollback() - eq_(s1.in_transaction(), True) - is_(s1._transaction, trans) + eq_(s1.in_transaction(), True) + is_(s1._transaction, trans) - s1.rollback() + s1.rollback() - eq_(s1.in_transaction(), False) - is_(s1._transaction, None) + eq_(s1.in_transaction(), False) + is_(s1._transaction, None) def test_in_transaction_nesting(self): - s1 = Session(testing.db) + with fixture_session() as s1: - eq_(s1.in_transaction(), False) + eq_(s1.in_transaction(), False) - trans = s1.begin() + trans = s1.begin() - eq_(s1.in_transaction(), True) - is_(s1.get_transaction(), trans) + eq_(s1.in_transaction(), True) + is_(s1.get_transaction(), trans) - sp1 = s1.begin_nested() + sp1 = s1.begin_nested() - eq_(s1.in_transaction(), True) - is_(s1.get_transaction(), trans) - is_(s1.get_nested_transaction(), sp1) + eq_(s1.in_transaction(), True) + is_(s1.get_transaction(), trans) + is_(s1.get_nested_transaction(), sp1) - sp2 = s1.begin_nested() + sp2 = s1.begin_nested() - eq_(s1.in_transaction(), True) - eq_(s1.in_nested_transaction(), True) - is_(s1.get_transaction(), trans) - is_(s1.get_nested_transaction(), sp2) + eq_(s1.in_transaction(), True) + eq_(s1.in_nested_transaction(), True) + is_(s1.get_transaction(), trans) + is_(s1.get_nested_transaction(), sp2) - sp2.rollback() + sp2.rollback() - eq_(s1.in_transaction(), True) - eq_(s1.in_nested_transaction(), True) - is_(s1.get_transaction(), trans) - is_(s1.get_nested_transaction(), sp1) + eq_(s1.in_transaction(), True) + eq_(s1.in_nested_transaction(), True) + is_(s1.get_transaction(), trans) + is_(s1.get_nested_transaction(), sp1) - sp1.rollback() + sp1.rollback() - is_(s1.get_nested_transaction(), None) - eq_(s1.in_transaction(), True) - eq_(s1.in_nested_transaction(), False) - is_(s1.get_transaction(), trans) + is_(s1.get_nested_transaction(), None) + eq_(s1.in_transaction(), True) + eq_(s1.in_nested_transaction(), False) + is_(s1.get_transaction(), trans) - s1.rollback() + s1.rollback() - eq_(s1.in_transaction(), False) - is_(s1.get_transaction(), None) + eq_(s1.in_transaction(), False) + is_(s1.get_transaction(), None) class NaturalPKRollbackTest(fixtures.MappedTest): @@ -2674,8 +2702,11 @@ class NaturalPKRollbackTest(fixtures.MappedTest): class JoinIntoAnExternalTransactionFixture(object): """Test the "join into an external transaction" examples""" - def setup(self): - self.connection = testing.db.connect() + __leave_connections_for_teardown__ = True + + def setup_test(self): + self.engine = testing.db + self.connection = self.engine.connect() self.metadata = MetaData() self.table = Table( @@ -2686,6 +2717,17 @@ class JoinIntoAnExternalTransactionFixture(object): self.setup_session() + def teardown_test(self): + self.teardown_session() + + with self.connection.begin(): + self._assert_count(0) + + with self.connection.begin(): + self.table.drop(self.connection) + + self.connection.close() + def test_something(self): A = self.A @@ -2727,18 +2769,6 @@ class JoinIntoAnExternalTransactionFixture(object): ) eq_(result, count) - def teardown(self): - self.teardown_session() - - with self.connection.begin(): - self._assert_count(0) - - with self.connection.begin(): - self.table.drop(self.connection) - - # return connection to the Engine - self.connection.close() - class NewStyleJoinIntoAnExternalTransactionTest( JoinIntoAnExternalTransactionFixture @@ -2775,7 +2805,8 @@ class NewStyleJoinIntoAnExternalTransactionTest( # rollback - everything that happened with the # Session above (including calls to commit()) # is rolled back. - self.trans.rollback() + if self.trans.is_active: + self.trans.rollback() class FutureJoinIntoAnExternalTransactionTest( diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index 84373b2dc..2c35bec45 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -203,14 +203,6 @@ class UnicodeSchemaTest(fixtures.MappedTest): cls.tables["t1"] = t1 cls.tables["t2"] = t2 - @classmethod - def setup_class(cls): - super(UnicodeSchemaTest, cls).setup_class() - - @classmethod - def teardown_class(cls): - super(UnicodeSchemaTest, cls).teardown_class() - def test_mapping(self): t2, t1 = self.tables.t2, self.tables.t1 diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 4e713627c..65089f773 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -771,14 +771,13 @@ class RudimentaryFlushTest(UOWTest): class SingleCycleTest(UOWTest): - def teardown(self): + def teardown_test(self): engines.testing_reaper.rollback_all() # mysql can't handle delete from nodes # since it doesn't deal with the FKs correctly, # so wipe out the parent_id first with testing.db.begin() as conn: conn.execute(self.tables.nodes.update().values(parent_id=None)) - super(SingleCycleTest, self).teardown() def test_one_to_many_save(self): Node, nodes = self.classes.Node, self.tables.nodes diff --git a/test/requirements.py b/test/requirements.py index d5a718372..3c9b39ac7 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1624,15 +1624,15 @@ class DefaultRequirements(SuiteRequirements): @property def postgresql_utf8_server_encoding(self): + def go(config): + if not against(config, "postgresql"): + return False - return only_if( - lambda config: against(config, "postgresql") - and config.db.connect(close_with_result=True) - .exec_driver_sql("show server_encoding") - .scalar() - .lower() - == "utf8" - ) + with config.db.connect() as conn: + enc = conn.exec_driver_sql("show server_encoding").scalar() + return enc.lower() == "utf8" + + return only_if(go) @property def cxoracle6_or_greater(self): diff --git a/test/sql/test_case_statement.py b/test/sql/test_case_statement.py index 4bef1df7f..b44971cec 100644 --- a/test/sql/test_case_statement.py +++ b/test/sql/test_case_statement.py @@ -26,7 +26,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): metadata = MetaData() global info_table info_table = Table( @@ -52,7 +52,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): ) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): with testing.db.begin() as conn: info_table.drop(conn) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 70281d4e8..1ac3613f7 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -1203,7 +1203,7 @@ class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase): class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): @classmethod - def setup_class(cls): + def setup_test_class(cls): # TODO: we need to get dialects here somehow, perhaps in test_suite? [ importlib.import_module("sqlalchemy.dialects.%s" % d) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index fdffe04bf..4429753ec 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -4306,7 +4306,7 @@ class StringifySpecialTest(fixtures.TestBase): class KwargPropagationTest(fixtures.TestBase): @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy.sql.expression import ColumnClause, TableClause class CatchCol(ColumnClause): diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 2a2e70bc3..8be7eed1f 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -503,9 +503,8 @@ class DefaultRoundTripTest(fixtures.TablesTest): Column("col11", MyType(), default="foo"), ) - def teardown(self): + def teardown_test(self): self.default_generator["x"] = 50 - super(DefaultRoundTripTest, self).teardown() def test_standalone(self, connection): t = self.tables.default_test @@ -1226,7 +1225,7 @@ class SpecialTypePKTest(fixtures.TestBase): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): class MyInteger(TypeDecorator): impl = Integer diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py index acc12a5fe..777565220 100644 --- a/test/sql/test_deprecations.py +++ b/test/sql/test_deprecations.py @@ -2488,12 +2488,12 @@ class LegacySequenceExecTest(fixtures.TestBase): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): cls.seq = Sequence("my_sequence") cls.seq.create(testing.db) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): cls.seq.drop(testing.db) def _assert_seq_result(self, ret): @@ -2574,7 +2574,7 @@ class LegacySequenceExecTest(fixtures.TestBase): class DDLDeprecatedBindTest(fixtures.TestBase): - def teardown(self): + def teardown_test(self): with testing.db.begin() as conn: if inspect(conn).has_table("foo"): conn.execute(schema.DropTable(table("foo"))) diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 4edc9d025..a6001ba9d 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -47,7 +47,7 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): ability to copy and modify a ClauseElement in place.""" @classmethod - def setup_class(cls): + def setup_test_class(cls): global A, B # establish two fictitious ClauseElements. @@ -321,7 +321,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): global t1, t2, t3 t1 = table("table1", column("col1"), column("col2"), column("col3")) t2 = table("table2", column("col1"), column("col2"), column("col3")) @@ -1012,7 +1012,7 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): global t1, t2 t1 = table( "table1", @@ -1196,7 +1196,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): global t1, t2 t1 = table("table1", column("col1"), column("col2"), column("col3")) t2 = table("table2", column("col1"), column("col2"), column("col3")) @@ -1943,7 +1943,7 @@ class SpliceJoinsTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): global table1, table2, table3, table4 def _table(name): @@ -2031,7 +2031,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): global t1, t2 t1 = table("table1", column("col1"), column("col2"), column("col3")) t2 = table("table2", column("col1"), column("col2"), column("col3")) @@ -2128,7 +2128,7 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): # fixme: consolidate converage from elsewhere here and expand @classmethod - def setup_class(cls): + def setup_test_class(cls): global t1, t2 t1 = table("table1", column("col1"), column("col2"), column("col3")) t2 = table("table2", column("col1"), column("col2"), column("col3")) diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py index 6afe41aac..b0bcee18e 100644 --- a/test/sql/test_from_linter.py +++ b/test/sql/test_from_linter.py @@ -25,7 +25,7 @@ class TestFindUnmatchingFroms(fixtures.TablesTest): Table("table_c", metadata, Column("col_c", Integer, primary_key=True)) Table("table_d", metadata, Column("col_d", Integer, primary_key=True)) - def setup(self): + def setup_test(self): self.a = self.tables.table_a self.b = self.tables.table_b self.c = self.tables.table_c @@ -267,8 +267,10 @@ class TestLinter(fixtures.TablesTest): with self.bind.connect() as conn: conn.execute(query) - def test_no_linting(self): - eng = engines.testing_engine(options={"enable_from_linting": False}) + def test_no_linting(self, metadata, connection): + eng = engines.testing_engine( + options={"enable_from_linting": False, "use_reaper": False} + ) eng.pool = self.bind.pool # needed for SQLite a, b = self.tables("table_a", "table_b") query = select(a.c.col_a).where(b.c.col_b == 5) diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 91076f9c3..32ea642d7 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -54,10 +54,10 @@ table1 = table( class CompileTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" - def setup(self): + def setup_test(self): self._registry = deepcopy(functions._registry) - def teardown(self): + def teardown_test(self): functions._registry = self._registry def test_compile(self): @@ -938,7 +938,7 @@ class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase): class ExecuteTest(fixtures.TestBase): __backend__ = True - def tearDown(self): + def teardown_test(self): pass def test_conn_execute(self, connection): @@ -1113,10 +1113,10 @@ class ExecuteTest(fixtures.TestBase): class RegisterTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" - def setup(self): + def setup_test(self): self._registry = deepcopy(functions._registry) - def teardown(self): + def teardown_test(self): functions._registry = self._registry def test_GenericFunction_is_registered(self): diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 502f70ce7..9bf351f5c 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -4257,7 +4257,7 @@ class DialectKWArgTest(fixtures.TestBase): with mock.patch("sqlalchemy.dialects.registry.load", load): yield - def teardown(self): + def teardown_test(self): Index._kw_registry.clear() def test_participating(self): diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index aaeed68dd..270e79ba1 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -608,7 +608,7 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): - def setUp(self): + def setup_test(self): class MyTypeCompiler(compiler.GenericTypeCompiler): def visit_mytype(self, type_, **kw): return "MYTYPE" @@ -766,7 +766,7 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): - def setUp(self): + def setup_test(self): class MyTypeCompiler(compiler.GenericTypeCompiler): def visit_mytype(self, type_, **kw): return "MYTYPE" @@ -2370,7 +2370,7 @@ class MatchTest(fixtures.TestBase, testing.AssertsCompiledSQL): class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" - def setUp(self): + def setup_test(self): self.table = table( "mytable", column("myid", Integer), column("name", String) ) @@ -2403,7 +2403,7 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): class RegexpTestStrCompiler(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default_enhanced" - def setUp(self): + def setup_test(self): self.table = table( "mytable", column("myid", Integer), column("name", String) ) diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 136f10cf4..7ad12c620 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -661,8 +661,8 @@ class CursorResultTest(fixtures.TablesTest): assert_raises(KeyError, lambda: row._mapping["Case_insensitive"]) assert_raises(KeyError, lambda: row._mapping["casesensitive"]) - def test_row_case_sensitive_unoptimized(self): - with engines.testing_engine().connect() as ins_conn: + def test_row_case_sensitive_unoptimized(self, testing_engine): + with testing_engine().connect() as ins_conn: row = ins_conn.execute( select( literal_column("1").label("case_insensitive"), @@ -1234,8 +1234,7 @@ class CursorResultTest(fixtures.TablesTest): eq_(proxy[0], "value") eq_(proxy._mapping["key"], "value") - @testing.provide_metadata - def test_no_rowcount_on_selects_inserts(self): + def test_no_rowcount_on_selects_inserts(self, metadata, testing_engine): """assert that rowcount is only called on deletes and updates. This because cursor.rowcount may can be expensive on some dialects @@ -1244,9 +1243,7 @@ class CursorResultTest(fixtures.TablesTest): """ - metadata = self.metadata - - engine = engines.testing_engine() + engine = testing_engine() t = Table("t1", metadata, Column("data", String(10))) metadata.create_all(engine) @@ -2132,7 +2129,9 @@ class AlternateCursorResultTest(fixtures.TablesTest): @classmethod def setup_bind(cls): - cls.engine = engine = engines.testing_engine("sqlite://") + cls.engine = engine = engines.testing_engine( + "sqlite://", options={"scope": "class"} + ) return engine @classmethod diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py index 65325aa6f..5cfc2663f 100644 --- a/test/sql/test_sequences.py +++ b/test/sql/test_sequences.py @@ -100,12 +100,12 @@ class SequenceExecTest(fixtures.TestBase): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): cls.seq = Sequence("my_sequence") cls.seq.create(testing.db) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): cls.seq.drop(testing.db) def _assert_seq_result(self, ret): diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 0e1147800..64ace87df 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1375,7 +1375,7 @@ class VariantBackendTest(fixtures.TestBase, AssertsCompiledSQL): class VariantTest(fixtures.TestBase, AssertsCompiledSQL): - def setup(self): + def setup_test(self): class UTypeOne(types.UserDefinedType): def get_col_spec(self): return "UTYPEONE" @@ -2504,7 +2504,7 @@ class BinaryTest(fixtures.TablesTest, AssertsExecutionResults): class JSONTest(fixtures.TestBase): - def setup(self): + def setup_test(self): metadata = MetaData() self.test_table = Table( "test_table", @@ -3445,7 +3445,12 @@ class BooleanTest( @testing.requires.non_native_boolean_unconstrained def test_constraint(self, connection): assert_raises( - (exc.IntegrityError, exc.ProgrammingError, exc.OperationalError), + ( + exc.IntegrityError, + exc.ProgrammingError, + exc.OperationalError, + exc.InternalError, # older pymysql's do this + ), connection.exec_driver_sql, "insert into boolean_table (id, value) values(1, 5)", ) |