summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-01-10 13:44:14 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2021-01-13 22:10:13 -0500
commitf1e96cb0874927a475d0c111393b7861796dd758 (patch)
tree810f3c43c0d2c6336805ebcf13d86d5cf1226efa
parent7f92fdbd8ec479a61c53c11921ce0688ad4dd94b (diff)
downloadsqlalchemy-f1e96cb0874927a475d0c111393b7861796dd758.tar.gz
reinvent xdist hooks in terms of pytest fixtures
To allow the "connection" pytest fixture and others work correctly in conjunction with setup/teardown that expects to be external to the transaction, remove and prevent any usage of "xdist" style names that are hardcoded by pytest to run inside of fixtures, even function level ones. Instead use pytest autouse fixtures to implement our own r"setup|teardown_test(?:_class)?" methods so that we can ensure function-scoped fixtures are run within them. A new more explicit flow is set up within plugin_base and pytestplugin such that the order of setup/teardown steps, which there are now many, is fully documented and controllable. New granularity has been added to the test teardown phase to distinguish between "end of the test" when lock-holding structures on connections should be released to allow for table drops, vs. "end of the test plus its teardown steps" when we can perform final cleanup on connections and run assertions that everything is closed out. From there we can remove most of the defensive "tear down everything" logic inside of engines which for many years would frequently dispose of pools over and over again, creating for a broken and expensive connection flow. A quick test shows that running test/sql/ against a single Postgresql engine with the new approach uses 75% fewer new connections, creating 42 new connections total, vs. 164 new connections total with the previous system. As part of this, the new fixtures metadata/connection/future_connection have been integrated such that they can be combined together effectively. The fixture_session(), provide_metadata() fixtures have been improved, including that fixture_session() now strongly references sessions which are explicitly torn down before table drops occur afer a test. Major changes have been made to the ConnectionKiller such that it now features different "scopes" for testing engines and will limit its cleanup to those testing engines corresponding to end of test, end of test class, or end of test session. The system by which it tracks DBAPI connections has been reworked, is ultimately somewhat similar to how it worked before but is organized more clearly along with the proxy-tracking logic. A "testing_engine" fixture is also added that works as a pytest fixture rather than a standalone function. The connection cleanup logic should now be very robust, as we now can use the same global connection pools for the whole suite without ever disposing them, while also running a query for PostgreSQL locks remaining after every test and assert there are no open transactions leaking between tests at all. Additional steps are added that also accommodate for asyncio connections not explicitly closed, as is the case for legacy sync-style tests as well as the async tests themselves. As always, hundreds of tests are further refined to use the new fixtures where problems with loose connections were identified, largely as a result of the new PostgreSQL assertions, many more tests have moved from legacy patterns into the newest. An unfortunate discovery during the creation of this system is that autouse fixtures (as well as if they are set up by @pytest.mark.usefixtures) are not usable at our current scale with pytest 4.6.11 running under Python 2. It's unclear if this is due to the older version of pytest or how it implements itself for Python 2, as well as if the issue is CPU slowness or just large memory use, but collecting the full span of tests takes over a minute for a single process when any autouse fixtures are in place and on CI the jobs just time out after ten minutes. So at the moment this patch also reinvents a small version of "autouse" fixtures when py2k is running, which skips generating the real fixture and instead uses two global pytest fixtures (which don't seem to impact performance) to invoke the "autouse" fixtures ourselves outside of pytest. This will limit our ability to do more with fixtures until we can remove py2k support. py.test is still observed to be much slower in collection in the 4.6.11 version compared to modern 6.2 versions, so add support for new TOX_POSTGRESQL_PY2K and TOX_MYSQL_PY2K environment variables that will run the suite for fewer backends under Python 2. For Python 3 pin pytest to modern 6.2 versions where performance for collection has been improved greatly. Includes the following improvements: Fixed bug in asyncio connection pool where ``asyncio.TimeoutError`` would be raised rather than :class:`.exc.TimeoutError`. Also repaired the :paramref:`_sa.create_engine.pool_timeout` parameter set to zero when using the async engine, which previously would ignore the timeout and block rather than timing out immediately as is the behavior with regular :class:`.QueuePool`. For asyncio the connection pool will now also not interact at all with an asyncio connection whose ConnectionFairy is being garbage collected; a warning that the connection was not properly closed is emitted and the connection is discarded. Within the test suite the ConnectionKiller is now maintaining strong references to all DBAPI connections and ensuring they are released when tests end, including those whose ConnectionFairy proxies are GCed. Identified cx_Oracle.stmtcachesize as a major factor in Oracle test scalability issues, this can be reset on a per-test basis rather than setting it to zero across the board. the addition of this flag has resolved the long-standing oracle "two task" error problem. For SQL Server, changed the temp table style used by the "suite" tests to be the double-pound-sign, i.e. global, variety, which is much easier to test generically. There are already reflection tests that are more finely tuned to both styles of temp table within the mssql test suite. Additionally, added an extra step to the "dropfirst" mechanism for SQL Server that will remove all foreign key constraints first as some issues were observed when using this flag when multiple schemas had not been torn down. Identified and fixed two subtle failure modes in the engine, when commit/rollback fails in a begin() context manager, the connection is explicitly closed, and when "initialize()" fails on the first new connection of a dialect, the transactional state on that connection is still rolled back. Fixes: #5826 Fixes: #5827 Change-Id: Ib1d05cb8c7cf84f9a4bfd23df397dc23c9329bfe
-rw-r--r--doc/build/changelog/unreleased_14/5823.rst13
-rw-r--r--doc/build/changelog/unreleased_14/5827.rst10
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py13
-rw-r--r--lib/sqlalchemy/dialects/mssql/provision.py34
-rw-r--r--lib/sqlalchemy/dialects/oracle/cx_oracle.py1
-rw-r--r--lib/sqlalchemy/dialects/oracle/provision.py42
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py1
-rw-r--r--lib/sqlalchemy/dialects/postgresql/provision.py21
-rw-r--r--lib/sqlalchemy/dialects/sqlite/provision.py6
-rw-r--r--lib/sqlalchemy/engine/base.py18
-rw-r--r--lib/sqlalchemy/engine/create.py9
-rw-r--r--lib/sqlalchemy/future/engine.py15
-rw-r--r--lib/sqlalchemy/pool/base.py35
-rw-r--r--lib/sqlalchemy/testing/__init__.py2
-rw-r--r--lib/sqlalchemy/testing/assertions.py8
-rw-r--r--lib/sqlalchemy/testing/config.py4
-rw-r--r--lib/sqlalchemy/testing/engines.py179
-rw-r--r--lib/sqlalchemy/testing/fixtures.py245
-rw-r--r--lib/sqlalchemy/testing/plugin/bootstrap.py5
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py39
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py188
-rw-r--r--lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py112
-rw-r--r--lib/sqlalchemy/testing/provision.py11
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py2
-rw-r--r--lib/sqlalchemy/testing/suite/test_results.py32
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py74
-rw-r--r--lib/sqlalchemy/testing/util.py60
-rw-r--r--lib/sqlalchemy/util/queue.py15
-rw-r--r--test/aaa_profiling/test_compiler.py2
-rw-r--r--test/aaa_profiling/test_memusage.py4
-rw-r--r--test/aaa_profiling/test_misc.py2
-rw-r--r--test/aaa_profiling/test_orm.py6
-rw-r--r--test/aaa_profiling/test_pool.py2
-rw-r--r--test/base/test_events.py22
-rw-r--r--test/base/test_inspect.py2
-rw-r--r--test/base/test_tutorials.py4
-rw-r--r--test/dialect/mssql/test_compiler.py2
-rw-r--r--test/dialect/mssql/test_deprecations.py2
-rw-r--r--test/dialect/mssql/test_query.py3
-rw-r--r--test/dialect/mysql/test_compiler.py4
-rw-r--r--test/dialect/mysql/test_reflection.py2
-rw-r--r--test/dialect/oracle/test_compiler.py2
-rw-r--r--test/dialect/oracle/test_dialect.py4
-rw-r--r--test/dialect/oracle/test_reflection.py16
-rw-r--r--test/dialect/oracle/test_types.py4
-rw-r--r--test/dialect/postgresql/test_async_pg_py3k.py4
-rw-r--r--test/dialect/postgresql/test_compiler.py8
-rw-r--r--test/dialect/postgresql/test_dialect.py21
-rw-r--r--test/dialect/postgresql/test_query.py6
-rw-r--r--test/dialect/postgresql/test_reflection.py499
-rw-r--r--test/dialect/postgresql/test_types.py692
-rw-r--r--test/dialect/test_sqlite.py32
-rw-r--r--test/engine/test_ddlevents.py4
-rw-r--r--test/engine/test_deprecations.py16
-rw-r--r--test/engine/test_execute.py364
-rw-r--r--test/engine/test_logging.py16
-rw-r--r--test/engine/test_pool.py57
-rw-r--r--test/engine/test_processors.py10
-rw-r--r--test/engine/test_reconnect.py20
-rw-r--r--test/engine/test_reflection.py9
-rw-r--r--test/engine/test_transaction.py21
-rw-r--r--test/ext/asyncio/test_engine_py3k.py16
-rw-r--r--test/ext/declarative/test_inheritance.py4
-rw-r--r--test/ext/declarative/test_reflection.py127
-rw-r--r--test/ext/test_associationproxy.py16
-rw-r--r--test/ext/test_baked.py2
-rw-r--r--test/ext/test_compiler.py2
-rw-r--r--test/ext/test_extendedattr.py6
-rw-r--r--test/ext/test_horizontal_shard.py32
-rw-r--r--test/ext/test_hybrid.py2
-rw-r--r--test/ext/test_mutable.py15
-rw-r--r--test/ext/test_orderinglist.py16
-rw-r--r--test/orm/declarative/test_basic.py4
-rw-r--r--test/orm/declarative/test_concurrency.py2
-rw-r--r--test/orm/declarative/test_inheritance.py4
-rw-r--r--test/orm/declarative/test_mixin.py4
-rw-r--r--test/orm/declarative/test_reflection.py5
-rw-r--r--test/orm/inheritance/test_basic.py49
-rw-r--r--test/orm/test_attributes.py8
-rw-r--r--test/orm/test_bind.py74
-rw-r--r--test/orm/test_collection.py5
-rw-r--r--test/orm/test_compile.py2
-rw-r--r--test/orm/test_cycles.py3
-rw-r--r--test/orm/test_deprecations.py80
-rw-r--r--test/orm/test_eager_relations.py14
-rw-r--r--test/orm/test_events.py7
-rw-r--r--test/orm/test_froms.py106
-rw-r--r--test/orm/test_lazy_relations.py57
-rw-r--r--test/orm/test_load_on_fks.py30
-rw-r--r--test/orm/test_mapper.py6
-rw-r--r--test/orm/test_options.py4
-rw-r--r--test/orm/test_query.py37
-rw-r--r--test/orm/test_rel_fn.py2
-rw-r--r--test/orm/test_relationships.py6
-rw-r--r--test/orm/test_selectin_relations.py50
-rw-r--r--test/orm/test_session.py20
-rw-r--r--test/orm/test_subquery_relations.py50
-rw-r--r--test/orm/test_transaction.py681
-rw-r--r--test/orm/test_unitofwork.py8
-rw-r--r--test/orm/test_unitofworkv2.py3
-rw-r--r--test/requirements.py16
-rw-r--r--test/sql/test_case_statement.py4
-rw-r--r--test/sql/test_compare.py2
-rw-r--r--test/sql/test_compiler.py2
-rw-r--r--test/sql/test_defaults.py5
-rw-r--r--test/sql/test_deprecations.py6
-rw-r--r--test/sql/test_external_traversal.py14
-rw-r--r--test/sql/test_from_linter.py8
-rw-r--r--test/sql/test_functions.py10
-rw-r--r--test/sql/test_metadata.py2
-rw-r--r--test/sql/test_operators.py8
-rw-r--r--test/sql/test_resultset.py15
-rw-r--r--test/sql/test_sequences.py4
-rw-r--r--test/sql/test_types.py11
-rw-r--r--tox.ini8
115 files changed, 2638 insertions, 2092 deletions
diff --git a/doc/build/changelog/unreleased_14/5823.rst b/doc/build/changelog/unreleased_14/5823.rst
new file mode 100644
index 000000000..74debdaa9
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/5823.rst
@@ -0,0 +1,13 @@
+.. change::
+ :tags: bug, pool, asyncio
+ :tickets: 5823
+
+ When using an asyncio engine, the connection pool will now detach and
+ discard a pooled connection that is was not explicitly closed/returned to
+ the pool when its tracking object is garbage collected, emitting a warning
+ that the connection was not properly closed. As this operation occurs
+ during Python gc finalizers, it's not safe to run any IO operations upon
+ the connection including transaction rollback or connection close as this
+ will often be outside of the event loop.
+
+
diff --git a/doc/build/changelog/unreleased_14/5827.rst b/doc/build/changelog/unreleased_14/5827.rst
new file mode 100644
index 000000000..d5c8acd8c
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/5827.rst
@@ -0,0 +1,10 @@
+.. change::
+ :tags: bug, asyncio
+ :tickets: 5827
+
+ Fixed bug in asyncio connection pool where ``asyncio.TimeoutError`` would
+ be raised rather than :class:`.exc.TimeoutError`. Also repaired the
+ :paramref:`_sa.create_engine.pool_timeout` parameter set to zero when using
+ the async engine, which previously would ignore the timeout and block
+ rather than timing out immediately as is the behavior with regular
+ :class:`.QueuePool`.
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 538679fcf..0227e515d 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -2785,15 +2785,14 @@ class MSDialect(default.DefaultDialect):
def has_table(self, connection, tablename, dbname, owner, schema):
if tablename.startswith("#"): # temporary table
tables = ischema.mssql_temp_table_columns
- result = connection.execute(
- sql.select(tables.c.table_name)
- .where(
- tables.c.table_name.like(
- self._temp_table_name_like_pattern(tablename)
- )
+
+ s = sql.select(tables.c.table_name).where(
+ tables.c.table_name.like(
+ self._temp_table_name_like_pattern(tablename)
)
- .limit(1)
)
+
+ result = connection.execute(s.limit(1))
return result.scalar() is not None
else:
tables = ischema.tables
diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py
index 269eb164f..56f3305a7 100644
--- a/lib/sqlalchemy/dialects/mssql/provision.py
+++ b/lib/sqlalchemy/dialects/mssql/provision.py
@@ -1,6 +1,14 @@
+from sqlalchemy import inspect
+from sqlalchemy import Integer
from ... import create_engine
from ... import exc
+from ...schema import Column
+from ...schema import DropConstraint
+from ...schema import ForeignKeyConstraint
+from ...schema import MetaData
+from ...schema import Table
from ...testing.provision import create_db
+from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import get_temp_table_name
from ...testing.provision import log
@@ -38,7 +46,6 @@ def _mssql_drop_ignore(conn, ident):
# "where database_id=db_id('%s')" % ident):
# log.info("killing SQL server session %s", row['session_id'])
# conn.exec_driver_sql("kill %s" % row['session_id'])
-
conn.exec_driver_sql("drop database %s" % ident)
log.info("Reaped db: %s", ident)
return True
@@ -83,4 +90,27 @@ def _mssql_temp_table_keyword_args(cfg, eng):
@get_temp_table_name.for_db("mssql")
def _mssql_get_temp_table_name(cfg, eng, base_name):
- return "#" + base_name
+ return "##" + base_name
+
+
+@drop_all_schema_objects_pre_tables.for_db("mssql")
+def drop_all_schema_objects_pre_tables(cfg, eng):
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ inspector = inspect(conn)
+ for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2):
+ for tname in inspector.get_table_names(schema=schema):
+ tb = Table(
+ tname,
+ MetaData(),
+ Column("x", Integer),
+ Column("y", Integer),
+ schema=schema,
+ )
+ for fk in inspect(conn).get_foreign_keys(tname, schema=schema):
+ conn.execute(
+ DropConstraint(
+ ForeignKeyConstraint(
+ [tb.c.x], [tb.c.y], name=fk["name"]
+ )
+ )
+ )
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index 042443692..b8b4df760 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -93,6 +93,7 @@ The parameters accepted by the cx_oracle dialect are as follows:
* ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail.
+
.. _cx_oracle_unicode:
Unicode
diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py
index d51131c0b..e0dadd58e 100644
--- a/lib/sqlalchemy/dialects/oracle/provision.py
+++ b/lib/sqlalchemy/dialects/oracle/provision.py
@@ -6,11 +6,11 @@ from ...testing.provision import create_db
from ...testing.provision import drop_db
from ...testing.provision import follower_url_from_main
from ...testing.provision import log
+from ...testing.provision import post_configure_engine
from ...testing.provision import run_reap_dbs
from ...testing.provision import set_default_schema_on_connection
-from ...testing.provision import stop_test_class
+from ...testing.provision import stop_test_class_outside_fixtures
from ...testing.provision import temp_table_keyword_args
-from ...testing.provision import update_db_opts
@create_db.for_db("oracle")
@@ -57,21 +57,39 @@ def _oracle_drop_db(cfg, eng, ident):
_ora_drop_ignore(conn, "%s_ts2" % ident)
-@update_db_opts.for_db("oracle")
-def _oracle_update_db_opts(db_url, db_opts):
- pass
+@stop_test_class_outside_fixtures.for_db("oracle")
+def stop_test_class_outside_fixtures(config, db, cls):
+ with db.begin() as conn:
+ # run magic command to get rid of identity sequences
+ # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa E501
+ conn.exec_driver_sql("purge recyclebin")
-@stop_test_class.for_db("oracle")
-def stop_test_class(config, db, cls):
- """run magic command to get rid of identity sequences
+ # clear statement cache on all connections that were used
+ # https://github.com/oracle/python-cx_Oracle/issues/519
- # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/
+ for cx_oracle_conn in _all_conns:
+ try:
+ sc = cx_oracle_conn.stmtcachesize
+ except db.dialect.dbapi.InterfaceError:
+ # connection closed
+ pass
+ else:
+ cx_oracle_conn.stmtcachesize = 0
+ cx_oracle_conn.stmtcachesize = sc
+ _all_conns.clear()
- """
- with db.begin() as conn:
- conn.exec_driver_sql("purge recyclebin")
+_all_conns = set()
+
+
+@post_configure_engine.for_db("oracle")
+def _oracle_post_configure_engine(url, engine, follower_ident):
+ from sqlalchemy import event
+
+ @event.listens_for(engine, "checkout")
+ def checkout(dbapi_con, con_record, con_proxy):
+ _all_conns.add(dbapi_con)
@run_reap_dbs.for_db("oracle")
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index 7c6e8fb02..e542c77f4 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -670,7 +670,6 @@ class AsyncAdapt_asyncpg_connection:
def rollback(self):
if self._started:
self.await_(self._transaction.rollback())
-
self._transaction = None
self._started = False
diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py
index d345cdfdf..70c390800 100644
--- a/lib/sqlalchemy/dialects/postgresql/provision.py
+++ b/lib/sqlalchemy/dialects/postgresql/provision.py
@@ -8,6 +8,7 @@ from ...testing.provision import drop_all_schema_objects_post_tables
from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import log
+from ...testing.provision import prepare_for_drop_tables
from ...testing.provision import set_default_schema_on_connection
from ...testing.provision import temp_table_keyword_args
@@ -102,3 +103,23 @@ def drop_all_schema_objects_post_tables(cfg, eng):
postgresql.ENUM(name=enum["name"], schema=enum["schema"])
)
)
+
+
+@prepare_for_drop_tables.for_db("postgresql")
+def prepare_for_drop_tables(config, connection):
+ """Ensure there are no locks on the current username/database."""
+
+ result = connection.exec_driver_sql(
+ "select pid, state, wait_event_type, query "
+ # "select pg_terminate_backend(pid), state, wait_event_type "
+ "from pg_stat_activity where "
+ "usename=current_user "
+ "and datname=current_database() and state='idle in transaction' "
+ "and pid != pg_backend_pid()"
+ )
+ rows = result.all() # noqa
+ assert not rows, (
+ "PostgreSQL may not be able to DROP tables due to "
+ "idle in transaction: %s"
+ % ("; ".join(row._mapping["query"] for row in rows))
+ )
diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py
index f26c21e22..a481be27e 100644
--- a/lib/sqlalchemy/dialects/sqlite/provision.py
+++ b/lib/sqlalchemy/dialects/sqlite/provision.py
@@ -7,7 +7,7 @@ from ...testing.provision import follower_url_from_main
from ...testing.provision import log
from ...testing.provision import post_configure_engine
from ...testing.provision import run_reap_dbs
-from ...testing.provision import stop_test_class
+from ...testing.provision import stop_test_class_outside_fixtures
from ...testing.provision import temp_table_keyword_args
@@ -57,8 +57,8 @@ def _sqlite_drop_db(cfg, eng, ident):
os.remove(path)
-@stop_test_class.for_db("sqlite")
-def stop_test_class(config, db, cls):
+@stop_test_class_outside_fixtures.for_db("sqlite")
+def stop_test_class_outside_fixtures(config, db, cls):
with db.connect() as conn:
files = [
row.file
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 50f00c025..72d66b7c8 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -2729,14 +2729,16 @@ class Engine(Connectable, log.Identified):
return self.conn
def __exit__(self, type_, value, traceback):
-
- if type_ is not None:
- self.transaction.rollback()
- else:
- if self.transaction.is_active:
- self.transaction.commit()
- if not self.close_with_result:
- self.conn.close()
+ try:
+ if type_ is not None:
+ if self.transaction.is_active:
+ self.transaction.rollback()
+ else:
+ if self.transaction.is_active:
+ self.transaction.commit()
+ finally:
+ if not self.close_with_result:
+ self.conn.close()
def begin(self, close_with_result=False):
"""Return a context manager delivering a :class:`_engine.Connection`
diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py
index f89be1809..72d232085 100644
--- a/lib/sqlalchemy/engine/create.py
+++ b/lib/sqlalchemy/engine/create.py
@@ -655,9 +655,12 @@ def create_engine(url, **kwargs):
c = base.Connection(
engine, connection=dbapi_connection, _has_events=False
)
- c._execution_options = util.immutabledict()
- dialect.initialize(c)
- dialect.do_rollback(c.connection)
+ c._execution_options = util.EMPTY_DICT
+
+ try:
+ dialect.initialize(c)
+ finally:
+ dialect.do_rollback(c.connection)
# previously, the "first_connect" event was used here, which was then
# scaled back if the "on_connect" handler were present. now,
diff --git a/lib/sqlalchemy/future/engine.py b/lib/sqlalchemy/future/engine.py
index d2f609326..bfdcdfc7f 100644
--- a/lib/sqlalchemy/future/engine.py
+++ b/lib/sqlalchemy/future/engine.py
@@ -368,12 +368,15 @@ class Engine(_LegacyEngine):
return self.conn
def __exit__(self, type_, value, traceback):
- if type_ is not None:
- self.transaction.rollback()
- else:
- if self.transaction.is_active:
- self.transaction.commit()
- self.conn.close()
+ try:
+ if type_ is not None:
+ if self.transaction.is_active:
+ self.transaction.rollback()
+ else:
+ if self.transaction.is_active:
+ self.transaction.commit()
+ finally:
+ self.conn.close()
def begin(self):
"""Return a :class:`_future.Connection` object with a transaction
diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py
index 7c9509e45..6c3aad037 100644
--- a/lib/sqlalchemy/pool/base.py
+++ b/lib/sqlalchemy/pool/base.py
@@ -426,6 +426,7 @@ class _ConnectionRecord(object):
rec._checkin_failed(err)
echo = pool._should_log_debug()
fairy = _ConnectionFairy(dbapi_connection, rec, echo)
+
rec.fairy_ref = weakref.ref(
fairy,
lambda ref: _finalize_fairy
@@ -609,6 +610,15 @@ def _finalize_fairy(
assert connection is None
connection = connection_record.connection
+ dont_restore_gced = pool._is_asyncio
+
+ if dont_restore_gced:
+ detach = not connection_record or ref
+ can_manipulate_connection = not ref
+ else:
+ detach = not connection_record
+ can_manipulate_connection = True
+
if connection is not None:
if connection_record and echo:
pool.logger.debug(
@@ -620,13 +630,26 @@ def _finalize_fairy(
connection, connection_record, echo
)
assert fairy.connection is connection
- fairy._reset(pool)
+ if can_manipulate_connection:
+ fairy._reset(pool)
+
+ if detach:
+ if connection_record:
+ fairy._pool = pool
+ fairy.detach()
+
+ if can_manipulate_connection:
+ if pool.dispatch.close_detached:
+ pool.dispatch.close_detached(connection)
+
+ pool._close_connection(connection)
+ else:
+ util.warn(
+ "asyncio connection is being garbage "
+ "collected without being properly closed: %r"
+ % connection
+ )
- # Immediately close detached instances
- if not connection_record:
- if pool.dispatch.close_detached:
- pool.dispatch.close_detached(connection)
- pool._close_connection(connection)
except BaseException as e:
pool.logger.error(
"Exception during reset or similar", exc_info=True
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index 191252bfb..9f2d0b857 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -29,8 +29,10 @@ from .assertions import in_ # noqa
from .assertions import is_ # noqa
from .assertions import is_false # noqa
from .assertions import is_instance_of # noqa
+from .assertions import is_none # noqa
from .assertions import is_not # noqa
from .assertions import is_not_ # noqa
+from .assertions import is_not_none # noqa
from .assertions import is_true # noqa
from .assertions import le_ # noqa
from .assertions import ne_ # noqa
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index 0a2aed9d8..db530a961 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -232,6 +232,14 @@ def is_false(a, msg=None):
is_(bool(a), False, msg=msg)
+def is_none(a, msg=None):
+ is_(a, None, msg=msg)
+
+
+def is_not_none(a, msg=None):
+ is_not(a, None, msg=msg)
+
+
def is_(a, b, msg=None):
"""Assert a is b, with repr messaging on failure."""
assert a is b, msg or "%r is not %r" % (a, b)
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index f64153f33..750671f9f 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -97,6 +97,10 @@ def get_current_test_name():
return _fixture_functions.get_current_test_name()
+def mark_base_test_class():
+ return _fixture_functions.mark_base_test_class()
+
+
class Config(object):
def __init__(self, db, db_opts, options, file_config):
self._set_name(db)
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index a4c1f3973..8b334fde2 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -7,6 +7,7 @@
from __future__ import absolute_import
+import collections
import re
import warnings
import weakref
@@ -20,26 +21,29 @@ from .. import pool
class ConnectionKiller(object):
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
- self.testing_engines = weakref.WeakKeyDictionary()
- self.conns = set()
+ self.testing_engines = collections.defaultdict(set)
+ self.dbapi_connections = set()
def add_pool(self, pool):
- event.listen(pool, "connect", self.connect)
- event.listen(pool, "checkout", self.checkout)
- event.listen(pool, "invalidate", self.invalidate)
-
- def add_engine(self, engine):
- self.add_pool(engine.pool)
- self.testing_engines[engine] = True
+ event.listen(pool, "checkout", self._add_conn)
+ event.listen(pool, "checkin", self._remove_conn)
+ event.listen(pool, "close", self._remove_conn)
+ event.listen(pool, "close_detached", self._remove_conn)
+ # note we are keeping "invalidated" here, as those are still
+ # opened connections we would like to roll back
+
+ def _add_conn(self, dbapi_con, con_record, con_proxy):
+ self.dbapi_connections.add(dbapi_con)
+ self.proxy_refs[con_proxy] = True
- def connect(self, dbapi_conn, con_record):
- self.conns.add((dbapi_conn, con_record))
+ def _remove_conn(self, dbapi_conn, *arg):
+ self.dbapi_connections.discard(dbapi_conn)
- def checkout(self, dbapi_con, con_record, con_proxy):
- self.proxy_refs[con_proxy] = True
+ def add_engine(self, engine, scope):
+ self.add_pool(engine.pool)
- def invalidate(self, dbapi_con, con_record, exception):
- self.conns.discard((dbapi_con, con_record))
+ assert scope in ("class", "global", "function", "fixture")
+ self.testing_engines[scope].add(engine)
def _safe(self, fn):
try:
@@ -54,53 +58,76 @@ class ConnectionKiller(object):
if rec is not None and rec.is_valid:
self._safe(rec.rollback)
- def close_all(self):
+ def checkin_all(self):
+ # run pool.checkin() for all ConnectionFairy instances we have
+ # tracked.
+
for rec in list(self.proxy_refs):
if rec is not None and rec.is_valid:
- self._safe(rec._close)
-
- def _after_test_ctx(self):
- # this can cause a deadlock with pg8000 - pg8000 acquires
- # prepared statement lock inside of rollback() - if async gc
- # is collecting in finalize_fairy, deadlock.
- # not sure if this should be for non-cpython only.
- # note that firebird/fdb definitely needs this though
- for conn, rec in list(self.conns):
- if rec.connection is None:
- # this is a hint that the connection is closed, which
- # is causing segfaults on mysqlclient due to
- # https://github.com/PyMySQL/mysqlclient-python/issues/270;
- # try to work around here
- continue
- self._safe(conn.rollback)
-
- def _stop_test_ctx(self):
- if config.options.low_connections:
- self._stop_test_ctx_minimal()
- else:
- self._stop_test_ctx_aggressive()
+ self.dbapi_connections.discard(rec.connection)
+ self._safe(rec._checkin)
- def _stop_test_ctx_minimal(self):
- self.close_all()
+ # for fairy refs that were GCed and could not close the connection,
+ # such as asyncio, roll back those remaining connections
+ for con in self.dbapi_connections:
+ self._safe(con.rollback)
+ self.dbapi_connections.clear()
- self.conns = set()
+ def close_all(self):
+ self.checkin_all()
- for rec in list(self.testing_engines):
- if rec is not config.db:
- rec.dispose()
+ def prepare_for_drop_tables(self, connection):
+ # don't do aggressive checks for third party test suites
+ if not config.bootstrapped_as_sqlalchemy:
+ return
- def _stop_test_ctx_aggressive(self):
- self.close_all()
- for conn, rec in list(self.conns):
- self._safe(conn.close)
- rec.connection = None
+ from . import provision
+
+ provision.prepare_for_drop_tables(connection.engine.url, connection)
+
+ def _drop_testing_engines(self, scope):
+ eng = self.testing_engines[scope]
+ for rec in list(eng):
+ for proxy_ref in list(self.proxy_refs):
+ if proxy_ref is not None and proxy_ref.is_valid:
+ if (
+ proxy_ref._pool is not None
+ and proxy_ref._pool is rec.pool
+ ):
+ self._safe(proxy_ref._checkin)
+ rec.dispose()
+ eng.clear()
+
+ def after_test(self):
+ self._drop_testing_engines("function")
+
+ def after_test_outside_fixtures(self, test):
+ # don't do aggressive checks for third party test suites
+ if not config.bootstrapped_as_sqlalchemy:
+ return
+
+ if test.__class__.__leave_connections_for_teardown__:
+ return
- self.conns = set()
- for rec in list(self.testing_engines):
- if hasattr(rec, "sync_engine"):
- rec.sync_engine.dispose()
- else:
- rec.dispose()
+ self.checkin_all()
+
+ # on PostgreSQL, this will test for any "idle in transaction"
+ # connections. useful to identify tests with unusual patterns
+ # that can't be cleaned up correctly.
+ from . import provision
+
+ with config.db.connect() as conn:
+ provision.prepare_for_drop_tables(conn.engine.url, conn)
+
+ def stop_test_class_inside_fixtures(self):
+ self.checkin_all()
+ self._drop_testing_engines("function")
+ self._drop_testing_engines("class")
+
+ def final_cleanup(self):
+ self.checkin_all()
+ for scope in self.testing_engines:
+ self._drop_testing_engines(scope)
def assert_all_closed(self):
for rec in self.proxy_refs:
@@ -111,20 +138,6 @@ class ConnectionKiller(object):
testing_reaper = ConnectionKiller()
-def drop_all_tables(metadata, bind):
- testing_reaper.close_all()
- if hasattr(bind, "close"):
- bind.close()
-
- if not config.db.dialect.supports_alter:
- from . import assertions
-
- with assertions.expect_warnings("Can't sort tables", assert_=False):
- metadata.drop_all(bind)
- else:
- metadata.drop_all(bind)
-
-
@decorator
def assert_conns_closed(fn, *args, **kw):
try:
@@ -147,7 +160,7 @@ def rollback_open_connections(fn, *args, **kw):
def close_first(fn, *args, **kw):
"""Decorator that closes all connections before fn execution."""
- testing_reaper.close_all()
+ testing_reaper.checkin_all()
fn(*args, **kw)
@@ -157,7 +170,7 @@ def close_open_connections(fn, *args, **kw):
try:
fn(*args, **kw)
finally:
- testing_reaper.close_all()
+ testing_reaper.checkin_all()
def all_dialects(exclude=None):
@@ -239,12 +252,14 @@ def reconnecting_engine(url=None, options=None):
return engine
-def testing_engine(url=None, options=None, future=False, asyncio=False):
+def testing_engine(url=None, options=None, future=None, asyncio=False):
"""Produce an engine configured by --options with optional overrides."""
if asyncio:
from sqlalchemy.ext.asyncio import create_async_engine as create_engine
- elif future or config.db and config.db._is_future:
+ elif future or (
+ config.db and config.db._is_future and future is not False
+ ):
from sqlalchemy.future import create_engine
else:
from sqlalchemy import create_engine
@@ -252,8 +267,10 @@ def testing_engine(url=None, options=None, future=False, asyncio=False):
if not options:
use_reaper = True
+ scope = "function"
else:
use_reaper = options.pop("use_reaper", True)
+ scope = options.pop("scope", "function")
url = url or config.db.url
@@ -268,16 +285,20 @@ def testing_engine(url=None, options=None, future=False, asyncio=False):
default_opt.update(options)
engine = create_engine(url, **options)
- if asyncio:
- engine.sync_engine._has_events = True
- else:
- engine._has_events = True # enable event blocks, helps with profiling
+
+ if scope == "global":
+ if asyncio:
+ engine.sync_engine._has_events = True
+ else:
+ engine._has_events = (
+ True # enable event blocks, helps with profiling
+ )
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
- engine.pool._max_overflow = 5
+ engine.pool._max_overflow = 0
if use_reaper:
- testing_reaper.add_engine(engine)
+ testing_reaper.add_engine(engine, scope)
return engine
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index ac4d3d8fa..f19b4652a 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -5,6 +5,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+import contextlib
import re
import sys
@@ -12,12 +13,11 @@ import sqlalchemy as sa
from . import assertions
from . import config
from . import schema
-from .engines import drop_all_tables
-from .engines import testing_engine
from .entities import BasicEntity
from .entities import ComparableEntity
from .entities import ComparableMixin # noqa
from .util import adict
+from .util import drop_all_tables_from_metadata
from .. import event
from .. import util
from ..orm import declarative_base
@@ -25,10 +25,8 @@ from ..orm import registry
from ..orm.decl_api import DeclarativeMeta
from ..schema import sort_tables_and_constraints
-# whether or not we use unittest changes things dramatically,
-# as far as how pytest collection works.
-
+@config.mark_base_test_class()
class TestBase(object):
# A sequence of database names to always run, regardless of the
# constraints below.
@@ -48,81 +46,114 @@ class TestBase(object):
# skipped.
__skip_if__ = None
+ # if True, the testing reaper will not attempt to touch connection
+ # state after a test is completed and before the outer teardown
+ # starts
+ __leave_connections_for_teardown__ = False
+
def assert_(self, val, msg=None):
assert val, msg
- # apparently a handful of tests are doing this....OK
- def setup(self):
- if hasattr(self, "setUp"):
- self.setUp()
-
- def teardown(self):
- if hasattr(self, "tearDown"):
- self.tearDown()
-
@config.fixture()
def connection(self):
- eng = getattr(self, "bind", config.db)
+ global _connection_fixture_connection
+
+ eng = getattr(self, "bind", None) or config.db
conn = eng.connect()
trans = conn.begin()
- try:
- yield conn
- finally:
- if trans.is_active:
- trans.rollback()
- conn.close()
+
+ _connection_fixture_connection = conn
+ yield conn
+
+ _connection_fixture_connection = None
+
+ if trans.is_active:
+ trans.rollback()
+ # trans would not be active here if the test is using
+ # the legacy @provide_metadata decorator still, as it will
+ # run a close all connections.
+ conn.close()
@config.fixture()
- def future_connection(self):
+ def future_connection(self, future_engine, connection):
+ # integrate the future_engine and connection fixtures so
+ # that users of the "connection" fixture will get at the
+ # "future" connection
+ yield connection
- eng = testing_engine(future=True)
- conn = eng.connect()
- trans = conn.begin()
- try:
- yield conn
- finally:
- if trans.is_active:
- trans.rollback()
- conn.close()
+ @config.fixture()
+ def future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
+
+ @config.fixture()
+ def testing_engine(self):
+ from . import engines
+
+ def gen_testing_engine(
+ url=None, options=None, future=False, asyncio=False
+ ):
+ if options is None:
+ options = {}
+ options["scope"] = "fixture"
+ return engines.testing_engine(
+ url=url, options=options, future=future, asyncio=asyncio
+ )
+
+ yield gen_testing_engine
+
+ engines.testing_reaper._drop_testing_engines("fixture")
@config.fixture()
- def metadata(self):
+ def metadata(self, request):
"""Provide bound MetaData for a single test, dropping afterwards."""
- from . import engines
from ..sql import schema
metadata = schema.MetaData()
- try:
- yield metadata
- finally:
- engines.drop_all_tables(metadata, config.db)
+ request.instance.metadata = metadata
+ yield metadata
+ del request.instance.metadata
+ if (
+ _connection_fixture_connection
+ and _connection_fixture_connection.in_transaction()
+ ):
+ trans = _connection_fixture_connection.get_transaction()
+ trans.rollback()
+ with _connection_fixture_connection.begin():
+ drop_all_tables_from_metadata(
+ metadata, _connection_fixture_connection
+ )
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
-class FutureEngineMixin(object):
- @classmethod
- def setup_class(cls):
- from ..future.engine import Engine
- from sqlalchemy import testing
+_connection_fixture_connection = None
- facade = Engine._future_facade(config.db)
- config._current.push_engine(facade, testing)
- super_ = super(FutureEngineMixin, cls)
- if hasattr(super_, "setup_class"):
- super_.setup_class()
+@contextlib.contextmanager
+def _push_future_engine(engine):
- @classmethod
- def teardown_class(cls):
- super_ = super(FutureEngineMixin, cls)
- if hasattr(super_, "teardown_class"):
- super_.teardown_class()
+ from ..future.engine import Engine
+ from sqlalchemy import testing
+
+ facade = Engine._future_facade(engine)
+ config._current.push_engine(facade, testing)
+
+ yield facade
- from sqlalchemy import testing
+ config._current.pop(testing)
- config._current.pop(testing)
+
+class FutureEngineMixin(object):
+ @config.fixture(autouse=True, scope="class")
+ def _push_future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
class TablesTest(TestBase):
@@ -151,18 +182,32 @@ class TablesTest(TestBase):
other = None
sequences = None
- @property
- def tables_test_metadata(self):
- return self._tables_metadata
-
- @classmethod
- def setup_class(cls):
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
cls._init_class()
cls._setup_once_tables()
cls._setup_once_inserts()
+ yield
+
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_inserts()
+
+ yield
+
+ self._teardown_each_tables()
+
+ @property
+ def tables_test_metadata(self):
+ return self._tables_metadata
+
@classmethod
def _init_class(cls):
if cls.run_define_tables == "each":
@@ -213,10 +258,10 @@ class TablesTest(TestBase):
if self.run_define_tables == "each":
self.tables.clear()
if self.run_create_tables == "each":
- drop_all_tables(self._tables_metadata, self.bind)
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
self._tables_metadata.clear()
elif self.run_create_tables == "each":
- drop_all_tables(self._tables_metadata, self.bind)
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
# no need to run deletes if tables are recreated on setup
if (
@@ -242,17 +287,10 @@ class TablesTest(TestBase):
file=sys.stderr,
)
- def setup(self):
- self._setup_each_tables()
- self._setup_each_inserts()
-
- def teardown(self):
- self._teardown_each_tables()
-
@classmethod
def _teardown_once_metadata_bind(cls):
if cls.run_create_tables:
- drop_all_tables(cls._tables_metadata, cls.bind)
+ drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
if cls.run_dispose_bind == "once":
cls.dispose_bind(cls.bind)
@@ -263,10 +301,6 @@ class TablesTest(TestBase):
cls.bind = None
@classmethod
- def teardown_class(cls):
- cls._teardown_once_metadata_bind()
-
- @classmethod
def setup_bind(cls):
return config.db
@@ -332,38 +366,47 @@ class RemovesEvents(object):
self._event_fns.add((target, name, fn))
event.listen(target, name, fn, **kw)
- def teardown(self):
+ @config.fixture(autouse=True, scope="function")
+ def _remove_events(self):
+ yield
for key in self._event_fns:
event.remove(*key)
- super_ = super(RemovesEvents, self)
- if hasattr(super_, "teardown"):
- super_.teardown()
-
-
-class _ORMTest(object):
- @classmethod
- def teardown_class(cls):
- sa.orm.session.close_all_sessions()
- sa.orm.clear_mappers()
-def create_session(**kw):
- kw.setdefault("autoflush", False)
- kw.setdefault("expire_on_commit", False)
- return sa.orm.Session(config.db, **kw)
+_fixture_sessions = set()
def fixture_session(**kw):
kw.setdefault("autoflush", True)
kw.setdefault("expire_on_commit", True)
- return sa.orm.Session(config.db, **kw)
+ sess = sa.orm.Session(config.db, **kw)
+ _fixture_sessions.add(sess)
+ return sess
+
+
+def _close_all_sessions():
+ # will close all still-referenced sessions
+ sa.orm.session.close_all_sessions()
+ _fixture_sessions.clear()
+
+
+def stop_test_class_inside_fixtures(cls):
+ _close_all_sessions()
+ sa.orm.clear_mappers()
-class ORMTest(_ORMTest, TestBase):
+def after_test():
+
+ if _fixture_sessions:
+
+ _close_all_sessions()
+
+
+class ORMTest(TestBase):
pass
-class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
+class MappedTest(TablesTest, assertions.AssertsExecutionResults):
# 'once', 'each', None
run_setup_classes = "once"
@@ -372,8 +415,9 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
classes = None
- @classmethod
- def setup_class(cls):
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
cls._init_class()
if cls.classes is None:
@@ -384,18 +428,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
cls._setup_once_mappers()
cls._setup_once_inserts()
- @classmethod
- def teardown_class(cls):
+ yield
+
cls._teardown_once_class()
cls._teardown_once_metadata_bind()
- def setup(self):
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
self._setup_each_tables()
self._setup_each_classes()
self._setup_each_mappers()
self._setup_each_inserts()
- def teardown(self):
+ yield
+
sa.orm.session.close_all_sessions()
self._teardown_each_mappers()
self._teardown_each_classes()
@@ -404,7 +450,6 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
@classmethod
def _teardown_once_class(cls):
cls.classes.clear()
- _ORMTest.teardown_class()
@classmethod
def _setup_once_classes(cls):
@@ -440,6 +485,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
"""
cls_registry = cls.classes
+ assert cls_registry is not None
+
class FindFixture(type):
def __init__(cls, classname, bases, dict_):
cls_registry[classname] = cls
diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py
index a95c947e2..1f568dfc8 100644
--- a/lib/sqlalchemy/testing/plugin/bootstrap.py
+++ b/lib/sqlalchemy/testing/plugin/bootstrap.py
@@ -40,6 +40,11 @@ def load_file_as_module(name):
if to_bootstrap == "pytest":
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
+ sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True
+ if sys.version_info < (3, 0):
+ sys.modules["sqla_reinvent_fixtures"] = load_file_as_module(
+ "reinvent_fixtures_py2k"
+ )
sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
else:
raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 3594cd276..7851fbb3e 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -21,6 +21,9 @@ import logging
import re
import sys
+# flag which indicates we are in the SQLAlchemy testing suite,
+# and not that of Alembic or a third party dialect.
+bootstrapped_as_sqlalchemy = False
log = logging.getLogger("sqlalchemy.testing.plugin_base")
@@ -381,7 +384,7 @@ def _init_symbols(options, file_config):
@post
def _set_disable_asyncio(opt, file_config):
- if opt.disable_asyncio:
+ if opt.disable_asyncio or not py3k:
from sqlalchemy.testing import asyncio
asyncio.ENABLE_ASYNCIO = False
@@ -458,6 +461,8 @@ def _setup_requirements(argument):
config.requirements = testing.requires = req_cls()
+ config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
+
@post
def _prep_testing_database(options, file_config):
@@ -566,17 +571,22 @@ def generate_sub_tests(cls, module):
yield cls
-def start_test_class(cls):
+def start_test_class_outside_fixtures(cls):
_do_skips(cls)
_setup_engine(cls)
def stop_test_class(cls):
- # from sqlalchemy import inspect
- # assert not inspect(testing.db).get_table_names()
+ # close sessions, immediate connections, etc.
+ fixtures.stop_test_class_inside_fixtures(cls)
+
+ # close outstanding connection pool connections, dispose of
+ # additional engines
+ engines.testing_reaper.stop_test_class_inside_fixtures()
- provision.stop_test_class(config, config.db, cls)
- engines.testing_reaper._stop_test_ctx()
+
+def stop_test_class_outside_fixtures(cls):
+ provision.stop_test_class_outside_fixtures(config, config.db, cls)
try:
if not options.low_connections:
assertions.global_cleanup_assertions()
@@ -590,14 +600,16 @@ def _restore_engine():
def final_process_cleanup():
- engines.testing_reaper._stop_test_ctx_aggressive()
+ engines.testing_reaper.final_cleanup()
assertions.global_cleanup_assertions()
_restore_engine()
def _setup_engine(cls):
if getattr(cls, "__engine_options__", None):
- eng = engines.testing_engine(options=cls.__engine_options__)
+ opts = dict(cls.__engine_options__)
+ opts["scope"] = "class"
+ eng = engines.testing_engine(options=opts)
config._current.push_engine(eng, testing)
@@ -614,7 +626,12 @@ def before_test(test, test_module_name, test_class, test_name):
def after_test(test):
- engines.testing_reaper._after_test_ctx()
+ fixtures.after_test()
+ engines.testing_reaper.after_test()
+
+
+def after_test_fixtures(test):
+ engines.testing_reaper.after_test_outside_fixtures(test)
def _possible_configs_for_cls(cls, reasons=None, sparse=False):
@@ -748,6 +765,10 @@ class FixtureFunctions(ABC):
def get_current_test_name(self):
raise NotImplementedError()
+ @abc.abstractmethod
+ def mark_base_test_class(self):
+ raise NotImplementedError()
+
_fixture_fn_class = None
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index 46468a07d..4eaaecebb 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -17,6 +17,7 @@ import sys
import pytest
+
try:
import typing
except ImportError:
@@ -33,6 +34,14 @@ except ImportError:
has_xdist = False
+py2k = sys.version_info < (3, 0)
+if py2k:
+ try:
+ import sqla_reinvent_fixtures as reinvent_fixtures_py2k
+ except ImportError:
+ from . import reinvent_fixtures_py2k
+
+
def pytest_addoption(parser):
group = parser.getgroup("sqlalchemy")
@@ -238,6 +247,10 @@ def pytest_collection_modifyitems(session, config, items):
else:
newitems.append(item)
+ if py2k:
+ for item in newitems:
+ reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item)
+
# seems like the functions attached to a test class aren't sorted already?
# is that true and why's that? (when using unittest, they're sorted)
items[:] = sorted(
@@ -251,7 +264,6 @@ def pytest_collection_modifyitems(session, config, items):
def pytest_pycollect_makeitem(collector, name, obj):
-
if inspect.isclass(obj) and plugin_base.want_class(name, obj):
from sqlalchemy.testing import config
@@ -259,7 +271,6 @@ def pytest_pycollect_makeitem(collector, name, obj):
obj = _apply_maybe_async(obj)
ctor = getattr(pytest.Class, "from_parent", pytest.Class)
-
return [
ctor(name=parametrize_cls.__name__, parent=collector)
for parametrize_cls in _parametrize_cls(collector.module, obj)
@@ -287,12 +298,11 @@ def _is_wrapped_coroutine_function(fn):
def _apply_maybe_async(obj, recurse=True):
from sqlalchemy.testing import asyncio
- setup_names = {"setup", "setup_class", "teardown", "teardown_class"}
for name, value in vars(obj).items():
if (
(callable(value) or isinstance(value, classmethod))
and not getattr(value, "_maybe_async_applied", False)
- and (name.startswith("test_") or name in setup_names)
+ and (name.startswith("test_"))
and not _is_wrapped_coroutine_function(value)
):
is_classmethod = False
@@ -317,9 +327,6 @@ def _apply_maybe_async(obj, recurse=True):
return obj
-_current_class = None
-
-
def _parametrize_cls(module, cls):
"""implement a class-based version of pytest parametrize."""
@@ -355,63 +362,153 @@ def _parametrize_cls(module, cls):
return classes
+_current_class = None
+
+
def pytest_runtest_setup(item):
from sqlalchemy.testing import asyncio
- # here we seem to get called only based on what we collected
- # in pytest_collection_modifyitems. So to do class-based stuff
- # we have to tear that out.
- global _current_class
-
if not isinstance(item, pytest.Function):
return
- # ... so we're doing a little dance here to figure it out...
+ # pytest_runtest_setup runs *before* pytest fixtures with scope="class".
+ # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
+ # for the whole class and has to run things that are across all current
+ # databases, so we run this outside of the pytest fixture system altogether
+ # and ensure asyncio greenlet if any engines are async
+
+ global _current_class
+
if _current_class is None:
- asyncio._maybe_async(class_setup, item.parent.parent)
+ asyncio._maybe_async_provisioning(
+ plugin_base.start_test_class_outside_fixtures,
+ item.parent.parent.cls,
+ )
_current_class = item.parent.parent
- # this is needed for the class-level, to ensure that the
- # teardown runs after the class is completed with its own
- # class-level teardown...
def finalize():
global _current_class
- asyncio._maybe_async(class_teardown, item.parent.parent)
_current_class = None
+ asyncio._maybe_async_provisioning(
+ plugin_base.stop_test_class_outside_fixtures,
+ item.parent.parent.cls,
+ )
+
item.parent.parent.addfinalizer(finalize)
- asyncio._maybe_async(test_setup, item)
+def pytest_runtest_call(item):
+ # runs inside of pytest function fixture scope
+ # before test function runs
-def pytest_runtest_teardown(item):
from sqlalchemy.testing import asyncio
- # ...but this works better as the hook here rather than
- # using a finalizer, as the finalizer seems to get in the way
- # of the test reporting failures correctly (you get a bunch of
- # pytest assertion stuff instead)
- asyncio._maybe_async(test_teardown, item)
+ asyncio._maybe_async(
+ plugin_base.before_test,
+ item,
+ item.parent.module.__name__,
+ item.parent.cls,
+ item.name,
+ )
-def test_setup(item):
- plugin_base.before_test(
- item, item.parent.module.__name__, item.parent.cls, item.name
- )
+def pytest_runtest_teardown(item, nextitem):
+ # runs inside of pytest function fixture scope
+ # after test function runs
+ from sqlalchemy.testing import asyncio
-def test_teardown(item):
- plugin_base.after_test(item)
+ asyncio._maybe_async(plugin_base.after_test, item)
-def class_setup(item):
+@pytest.fixture(scope="class")
+def setup_class_methods(request):
from sqlalchemy.testing import asyncio
- asyncio._maybe_async_provisioning(plugin_base.start_test_class, item.cls)
+ cls = request.cls
+
+ if hasattr(cls, "setup_test_class"):
+ asyncio._maybe_async(cls.setup_test_class)
+
+ if py2k:
+ reinvent_fixtures_py2k.run_class_fixture_setup(request)
+
+ yield
+
+ if py2k:
+ reinvent_fixtures_py2k.run_class_fixture_teardown(request)
+ if hasattr(cls, "teardown_test_class"):
+ asyncio._maybe_async(cls.teardown_test_class)
-def class_teardown(item):
- plugin_base.stop_test_class(item.cls)
+ asyncio._maybe_async(plugin_base.stop_test_class, cls)
+
+
+@pytest.fixture(scope="function")
+def setup_test_methods(request):
+ from sqlalchemy.testing import asyncio
+
+ # called for each test
+
+ self = request.instance
+
+ # 1. run outer xdist-style setup
+ if hasattr(self, "setup_test"):
+ asyncio._maybe_async(self.setup_test)
+
+ # alembic test suite is using setUp and tearDown
+ # xdist methods; support these in the test suite
+ # for the near term
+ if hasattr(self, "setUp"):
+ asyncio._maybe_async(self.setUp)
+
+ # 2. run homegrown function level "autouse" fixtures under py2k
+ if py2k:
+ reinvent_fixtures_py2k.run_fn_fixture_setup(request)
+
+ # inside the yield:
+
+ # 3. function level "autouse" fixtures under py3k (examples: TablesTest
+ # define tables / data, MappedTest define tables / mappers / data)
+
+ # 4. function level fixtures defined on test functions themselves,
+ # e.g. "connection", "metadata" run next
+
+ # 5. pytest hook pytest_runtest_call then runs
+
+ # 6. test itself runs
+
+ yield
+
+ # yield finishes:
+
+ # 7. pytest hook pytest_runtest_teardown hook runs, this is associated
+ # with fixtures close all sessions, provisioning.stop_test_class(),
+ # engines.testing_reaper -> ensure all connection pool connections
+ # are returned, engines created by testing_engine that aren't the
+ # config engine are disposed
+
+ # 8. function level fixtures defined on test functions
+ # themselves, e.g. "connection" rolls back the transaction, "metadata"
+ # emits drop all
+
+ # 9. function level "autouse" fixtures under py3k (examples: TablesTest /
+ # MappedTest delete table data, possibly drop tables and clear mappers
+ # depending on the flags defined by the test class)
+
+ # 10. run homegrown function-level "autouse" fixtures under py2k
+ if py2k:
+ reinvent_fixtures_py2k.run_fn_fixture_teardown(request)
+
+ asyncio._maybe_async(plugin_base.after_test_fixtures, self)
+
+ # 11. run outer xdist-style teardown
+ if hasattr(self, "tearDown"):
+ asyncio._maybe_async(self.tearDown)
+
+ if hasattr(self, "teardown_test"):
+ asyncio._maybe_async(self.teardown_test)
def getargspec(fn):
@@ -461,6 +558,8 @@ def %(name)s(%(args)s):
# for the wrapped function
decorated.__module__ = fn.__module__
decorated.__name__ = fn.__name__
+ if hasattr(fn, "pytestmark"):
+ decorated.pytestmark = fn.pytestmark
return decorated
return decorate
@@ -470,6 +569,11 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
def skip_test_exception(self, *arg, **kw):
return pytest.skip.Exception(*arg, **kw)
+ def mark_base_test_class(self):
+ return pytest.mark.usefixtures(
+ "setup_class_methods", "setup_test_methods"
+ )
+
_combination_id_fns = {
"i": lambda obj: obj,
"r": repr,
@@ -647,8 +751,18 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
fn = asyncio._maybe_async_wrapper(fn)
# other wrappers may be added here
- # now apply FixtureFunctionMarker
- fn = fixture(fn)
+ if py2k and "autouse" in kw:
+ # py2k workaround for too-slow collection of autouse fixtures
+ # in pytest 4.6.11. See notes in reinvent_fixtures_py2k for
+ # rationale.
+
+ # comment this condition out in order to disable the
+ # py2k workaround entirely.
+ reinvent_fixtures_py2k.add_fixture(fn, fixture)
+ else:
+ # now apply FixtureFunctionMarker
+ fn = fixture(fn)
+
return fn
if fn:
diff --git a/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
new file mode 100644
index 000000000..36b68417b
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
@@ -0,0 +1,112 @@
+"""
+invent a quick version of pytest autouse fixtures as pytest's unacceptably slow
+collection/high memory use in pytest 4.6.11, which is the highest version that
+works in py2k.
+
+by "too-slow" we mean the test suite can't even manage to be collected for a
+single process in less than 70 seconds or so and memory use seems to be very
+high as well. for two or four workers the job just times out after ten
+minutes.
+
+so instead we have invented a very limited form of these fixtures, as our
+current use of "autouse" fixtures are limited to those in fixtures.py.
+
+assumptions for these fixtures:
+
+1. we are only using "function" or "class" scope
+
+2. the functions must be associated with a test class
+
+3. the fixture functions cannot themselves use pytest fixtures
+
+4. the fixture functions must use yield, not return
+
+When py2k support is removed and we can stay on a modern pytest version, this
+can all be removed.
+
+
+"""
+import collections
+
+
+_py2k_fixture_fn_names = collections.defaultdict(set)
+_py2k_class_fixtures = collections.defaultdict(
+ lambda: collections.defaultdict(set)
+)
+_py2k_function_fixtures = collections.defaultdict(
+ lambda: collections.defaultdict(set)
+)
+
+_py2k_cls_fixture_stack = []
+_py2k_fn_fixture_stack = []
+
+
+def add_fixture(fn, fixture):
+ assert fixture.scope in ("class", "function")
+ _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope))
+
+
+def scan_for_fixtures_to_use_for_class(item):
+ test_class = item.parent.parent.obj
+
+ for name in _py2k_fixture_fn_names:
+ for fixture_fn, scope in _py2k_fixture_fn_names[name]:
+ meth = getattr(test_class, name, None)
+ if meth and meth.im_func is fixture_fn:
+ for sup in test_class.__mro__:
+ if name in sup.__dict__:
+ if scope == "class":
+ _py2k_class_fixtures[test_class][sup].add(meth)
+ elif scope == "function":
+ _py2k_function_fixtures[test_class][sup].add(meth)
+ break
+ break
+
+
+def run_class_fixture_setup(request):
+
+ cls = request.cls
+ self = cls.__new__(cls)
+
+ fixtures_for_this_class = _py2k_class_fixtures.get(cls)
+
+ if fixtures_for_this_class:
+ for sup_ in cls.__mro__:
+ for fn in fixtures_for_this_class.get(sup_, ()):
+ iter_ = fn(self)
+ next(iter_)
+
+ _py2k_cls_fixture_stack.append(iter_)
+
+
+def run_class_fixture_teardown(request):
+ while _py2k_cls_fixture_stack:
+ iter_ = _py2k_cls_fixture_stack.pop(-1)
+ try:
+ next(iter_)
+ except StopIteration:
+ pass
+
+
+def run_fn_fixture_setup(request):
+ cls = request.cls
+ self = request.instance
+
+ fixtures_for_this_class = _py2k_function_fixtures.get(cls)
+
+ if fixtures_for_this_class:
+ for sup_ in reversed(cls.__mro__):
+ for fn in fixtures_for_this_class.get(sup_, ()):
+ iter_ = fn(self)
+ next(iter_)
+
+ _py2k_fn_fixture_stack.append(iter_)
+
+
+def run_fn_fixture_teardown(request):
+ while _py2k_fn_fixture_stack:
+ iter_ = _py2k_fn_fixture_stack.pop(-1)
+ try:
+ next(iter_)
+ except StopIteration:
+ pass
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
index 4ee0567f2..2fade1c32 100644
--- a/lib/sqlalchemy/testing/provision.py
+++ b/lib/sqlalchemy/testing/provision.py
@@ -67,6 +67,7 @@ def setup_config(db_url, options, file_config, follower_ident):
db_url = follower_url_from_main(db_url, follower_ident)
db_opts = {}
update_db_opts(db_url, db_opts)
+ db_opts["scope"] = "global"
eng = engines.testing_engine(db_url, db_opts)
post_configure_engine(db_url, eng, follower_ident)
eng.connect().close()
@@ -264,6 +265,7 @@ def drop_all_schema_objects(cfg, eng):
if config.requirements.schemas.enabled_for_config(cfg):
util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
+ util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)
drop_all_schema_objects_post_tables(cfg, eng)
@@ -299,7 +301,7 @@ def update_db_opts(db_url, db_opts):
def post_configure_engine(url, engine, follower_ident):
"""Perform extra steps after configuring an engine for testing.
- (For the internal dialects, currently only used by sqlite.)
+ (For the internal dialects, currently only used by sqlite, oracle)
"""
pass
@@ -375,7 +377,12 @@ def temp_table_keyword_args(cfg, eng):
@register.init
-def stop_test_class(config, db, testcls):
+def prepare_for_drop_tables(config, connection):
+ pass
+
+
+@register.init
+def stop_test_class_outside_fixtures(config, db, testcls):
pass
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 6c3c1005a..de157d028 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -293,7 +293,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
from sqlalchemy import pool
return engines.testing_engine(
- options=dict(poolclass=pool.StaticPool)
+ options=dict(poolclass=pool.StaticPool, scope="class"),
)
else:
return config.db
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
index e0fdbe47a..e8dd6cf2c 100644
--- a/lib/sqlalchemy/testing/suite/test_results.py
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -261,10 +261,6 @@ class ServerSideCursorsTest(
)
return self.engine
- def tearDown(self):
- engines.testing_reaper.close_all()
- self.engine.dispose()
-
@testing.combinations(
("global_string", True, "select 1", True),
("global_text", True, text("select 1"), True),
@@ -309,24 +305,22 @@ class ServerSideCursorsTest(
def test_conn_option(self):
engine = self._fixture(False)
- # should be enabled for this one
- result = (
- engine.connect()
- .execution_options(stream_results=True)
- .exec_driver_sql("select 1")
- )
- assert self._is_server_side(result.cursor)
+ with engine.connect() as conn:
+ # should be enabled for this one
+ result = conn.execution_options(
+ stream_results=True
+ ).exec_driver_sql("select 1")
+ assert self._is_server_side(result.cursor)
def test_stmt_enabled_conn_option_disabled(self):
engine = self._fixture(False)
s = select(1).execution_options(stream_results=True)
- # not this one
- result = (
- engine.connect().execution_options(stream_results=False).execute(s)
- )
- assert not self._is_server_side(result.cursor)
+ with engine.connect() as conn:
+ # not this one
+ result = conn.execution_options(stream_results=False).execute(s)
+ assert not self._is_server_side(result.cursor)
def test_aliases_and_ss(self):
engine = self._fixture(False)
@@ -344,8 +338,7 @@ class ServerSideCursorsTest(
assert not self._is_server_side(result.cursor)
result.close()
- @testing.provide_metadata
- def test_roundtrip_fetchall(self):
+ def test_roundtrip_fetchall(self, metadata):
md = self.metadata
engine = self._fixture(True)
@@ -385,8 +378,7 @@ class ServerSideCursorsTest(
0,
)
- @testing.provide_metadata
- def test_roundtrip_fetchmany(self):
+ def test_roundtrip_fetchmany(self, metadata):
md = self.metadata
engine = self._fixture(True)
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index 3a5e02c32..ebcceaae7 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -511,24 +511,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
__backend__ = True
@testing.fixture
- def do_numeric_test(self, metadata):
+ def do_numeric_test(self, metadata, connection):
@testing.emits_warning(
r".*does \*not\* support Decimal objects natively"
)
def run(type_, input_, output, filter_=None, check_scale=False):
t = Table("t", metadata, Column("x", type_))
- t.create(testing.db)
- with config.db.begin() as conn:
- conn.execute(t.insert(), [{"x": x} for x in input_])
-
- result = {row[0] for row in conn.execute(t.select())}
- output = set(output)
- if filter_:
- result = set(filter_(x) for x in result)
- output = set(filter_(x) for x in output)
- eq_(result, output)
- if check_scale:
- eq_([str(x) for x in result], [str(x) for x in output])
+ t.create(connection)
+ connection.execute(t.insert(), [{"x": x} for x in input_])
+
+ result = {row[0] for row in connection.execute(t.select())}
+ output = set(output)
+ if filter_:
+ result = set(filter_(x) for x in result)
+ output = set(filter_(x) for x in output)
+ eq_(result, output)
+ if check_scale:
+ eq_([str(x) for x in result], [str(x) for x in output])
return run
@@ -1165,40 +1164,39 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
},
)
- def test_eval_none_flag_orm(self):
+ def test_eval_none_flag_orm(self, connection):
Base = declarative_base()
class Data(Base):
__table__ = self.tables.data_table
- s = Session(testing.db)
+ with Session(connection) as s:
+ d1 = Data(name="d1", data=None, nulldata=None)
+ s.add(d1)
+ s.commit()
- d1 = Data(name="d1", data=None, nulldata=None)
- s.add(d1)
- s.commit()
-
- s.bulk_insert_mappings(
- Data, [{"name": "d2", "data": None, "nulldata": None}]
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String()),
- cast(self.tables.data_table.c.nulldata, String),
+ s.bulk_insert_mappings(
+ Data, [{"name": "d2", "data": None, "nulldata": None}]
)
- .filter(self.tables.data_table.c.name == "d1")
- .first(),
- ("null", None),
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String()),
- cast(self.tables.data_table.c.nulldata, String),
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
)
- .filter(self.tables.data_table.c.name == "d2")
- .first(),
- ("null", None),
- )
class JSONLegacyStringCastIndexTest(
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
index eb9fcd1cd..01185c284 100644
--- a/lib/sqlalchemy/testing/util.py
+++ b/lib/sqlalchemy/testing/util.py
@@ -14,6 +14,7 @@ import types
from . import config
from . import mock
from .. import inspect
+from ..engine import Connection
from ..schema import Column
from ..schema import DropConstraint
from ..schema import DropTable
@@ -207,11 +208,13 @@ def fail(msg):
@decorator
def provide_metadata(fn, *args, **kw):
- """Provide bound MetaData for a single test, dropping afterwards."""
+ """Provide bound MetaData for a single test, dropping afterwards.
- # import cycle that only occurs with py2k's import resolver
- # in py3k this can be moved top level.
- from . import engines
+ Legacy; use the "metadata" pytest fixture.
+
+ """
+
+ from . import fixtures
metadata = schema.MetaData()
self = args[0]
@@ -220,7 +223,31 @@ def provide_metadata(fn, *args, **kw):
try:
return fn(*args, **kw)
finally:
- engines.drop_all_tables(metadata, config.db)
+ # close out some things that get in the way of dropping tables.
+ # when using the "metadata" fixture, there is a set ordering
+ # of things that makes sure things are cleaned up in order, however
+ # the simple "decorator" nature of this legacy function means
+ # we have to hardcode some of that cleanup ahead of time.
+
+ # close ORM sessions
+ fixtures._close_all_sessions()
+
+ # integrate with the "connection" fixture as there are many
+ # tests where it is used along with provide_metadata
+ if fixtures._connection_fixture_connection:
+ # TODO: this warning can be used to find all the places
+ # this is used with connection fixture
+ # warn("mixing legacy provide metadata with connection fixture")
+ drop_all_tables_from_metadata(
+ metadata, fixtures._connection_fixture_connection
+ )
+ # as the provide_metadata fixture is often used with "testing.db",
+ # when we do the drop we have to commit the transaction so that
+ # the DB is actually updated as the CREATE would have been
+ # committed
+ fixtures._connection_fixture_connection.get_transaction().commit()
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
self.metadata = prev_meta
@@ -359,6 +386,29 @@ class adict(dict):
get_all = __call__
+def drop_all_tables_from_metadata(metadata, engine_or_connection):
+ from . import engines
+
+ def go(connection):
+ engines.testing_reaper.prepare_for_drop_tables(connection)
+
+ if not connection.dialect.supports_alter:
+ from . import assertions
+
+ with assertions.expect_warnings(
+ "Can't sort tables", assert_=False
+ ):
+ metadata.drop_all(connection)
+ else:
+ metadata.drop_all(connection)
+
+ if not isinstance(engine_or_connection, Connection):
+ with engine_or_connection.begin() as connection:
+ go(connection)
+ else:
+ go(engine_or_connection)
+
+
def drop_all_tables(engine, inspector, schema=None, include_names=None):
if include_names is not None:
diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py
index 99ecb4fb3..ca5a3abde 100644
--- a/lib/sqlalchemy/util/queue.py
+++ b/lib/sqlalchemy/util/queue.py
@@ -230,13 +230,16 @@ class AsyncAdaptedQueue:
return self.put_nowait(item)
try:
- if timeout:
+ if timeout is not None:
return self.await_(
asyncio.wait_for(self._queue.put(item), timeout)
)
else:
return self.await_(self._queue.put(item))
- except asyncio.queues.QueueFull as err:
+ except (
+ asyncio.queues.QueueFull,
+ asyncio.exceptions.TimeoutError,
+ ) as err:
compat.raise_(
Full(),
replace_context=err,
@@ -254,14 +257,18 @@ class AsyncAdaptedQueue:
def get(self, block=True, timeout=None):
if not block:
return self.get_nowait()
+
try:
- if timeout:
+ if timeout is not None:
return self.await_(
asyncio.wait_for(self._queue.get(), timeout)
)
else:
return self.await_(self._queue.get())
- except asyncio.queues.QueueEmpty as err:
+ except (
+ asyncio.queues.QueueEmpty,
+ asyncio.exceptions.TimeoutError,
+ ) as err:
compat.raise_(
Empty(),
replace_context=err,
diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py
index 0202768ae..968a74700 100644
--- a/test/aaa_profiling/test_compiler.py
+++ b/test/aaa_profiling/test_compiler.py
@@ -18,7 +18,7 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2, metadata
metadata = MetaData()
diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py
index 75a4f51cf..a41a8b9f1 100644
--- a/test/aaa_profiling/test_memusage.py
+++ b/test/aaa_profiling/test_memusage.py
@@ -241,7 +241,7 @@ def assert_no_mappers():
class EnsureZeroed(fixtures.ORMTest):
- def setup(self):
+ def setup_test(self):
_sessions.clear()
_mapper_registry.clear()
@@ -1032,7 +1032,7 @@ class MemUsageWBackendTest(EnsureZeroed):
t2_mapper = mapper(T2, t2)
t1_mapper.add_property("bar", relationship(t2_mapper))
- s1 = fixture_session()
+ s1 = Session(testing.db)
# this causes the path_registry to be invoked
s1.query(t1_mapper)._compile_context()
diff --git a/test/aaa_profiling/test_misc.py b/test/aaa_profiling/test_misc.py
index db6fd4b71..5b30a3968 100644
--- a/test/aaa_profiling/test_misc.py
+++ b/test/aaa_profiling/test_misc.py
@@ -19,7 +19,7 @@ from sqlalchemy.util import classproperty
class EnumTest(fixtures.TestBase):
__requires__ = ("cpython", "python_profiling_backend")
- def setup(self):
+ def setup_test(self):
class SomeEnum(object):
# Implements PEP 435 in the minimal fashion needed by SQLAlchemy
diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py
index f163078d8..8116e5f21 100644
--- a/test/aaa_profiling/test_orm.py
+++ b/test/aaa_profiling/test_orm.py
@@ -29,15 +29,13 @@ class NoCache(object):
run_setup_bind = "each"
@classmethod
- def setup_class(cls):
- super(NoCache, cls).setup_class()
+ def setup_test_class(cls):
cls._cache = config.db._compiled_cache
config.db._compiled_cache = None
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
config.db._compiled_cache = cls._cache
- super(NoCache, cls).teardown_class()
class MergeTest(NoCache, fixtures.MappedTest):
diff --git a/test/aaa_profiling/test_pool.py b/test/aaa_profiling/test_pool.py
index fd02f9139..da3c1c525 100644
--- a/test/aaa_profiling/test_pool.py
+++ b/test/aaa_profiling/test_pool.py
@@ -17,7 +17,7 @@ class QueuePoolTest(fixtures.TestBase, AssertsExecutionResults):
def close(self):
pass
- def setup(self):
+ def setup_test(self):
# create a throwaway pool which
# has the effect of initializing
# class-level event listeners on Pool,
diff --git a/test/base/test_events.py b/test/base/test_events.py
index 19f68e9a3..68db5207c 100644
--- a/test/base/test_events.py
+++ b/test/base/test_events.py
@@ -16,7 +16,7 @@ from sqlalchemy.testing.util import gc_collect
class TearDownLocalEventsFixture(object):
- def tearDown(self):
+ def teardown_test(self):
classes = set()
for entry in event.base._registrars.values():
for evt_cls in entry:
@@ -30,7 +30,7 @@ class TearDownLocalEventsFixture(object):
class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test class- and instance-level event registration."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, x, y):
pass
@@ -438,7 +438,7 @@ class NamedCallTest(TearDownLocalEventsFixture, fixtures.TestBase):
class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""test adaption of legacy args"""
- def setUp(self):
+ def setup_test(self):
class TargetEventsOne(event.Events):
@event._legacy_signature("0.9", ["x", "y"])
def event_three(self, x, y, z, q):
@@ -608,7 +608,7 @@ class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEventsOne(event.Events):
def event_one(self, x, y):
pass
@@ -677,7 +677,7 @@ class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test default target acceptance."""
- def setUp(self):
+ def setup_test(self):
class TargetEventsOne(event.Events):
def event_one(self, x, y):
pass
@@ -734,7 +734,7 @@ class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test custom target acceptance."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
@classmethod
def _accept_with(cls, target):
@@ -771,7 +771,7 @@ class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
class SubclassGrowthTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""test that ad-hoc subclasses are garbage collected."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def some_event(self, x, y):
pass
@@ -797,7 +797,7 @@ class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test custom listen functions which change the listener function
signature."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
@classmethod
def _listen(cls, event_key, add=False):
@@ -855,7 +855,7 @@ class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase):
class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, arg):
pass
@@ -889,7 +889,7 @@ class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, target, arg):
pass
@@ -1109,7 +1109,7 @@ class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
class DisableClsPropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, target, arg):
pass
diff --git a/test/base/test_inspect.py b/test/base/test_inspect.py
index 15b98c848..252d0d977 100644
--- a/test/base/test_inspect.py
+++ b/test/base/test_inspect.py
@@ -13,7 +13,7 @@ class TestFixture(object):
class TestInspection(fixtures.TestBase):
- def tearDown(self):
+ def teardown_test(self):
for type_ in list(inspection._registrars):
if issubclass(type_, TestFixture):
del inspection._registrars[type_]
diff --git a/test/base/test_tutorials.py b/test/base/test_tutorials.py
index 14e87ef69..6320ef052 100644
--- a/test/base/test_tutorials.py
+++ b/test/base/test_tutorials.py
@@ -48,11 +48,11 @@ class DocTest(fixtures.TestBase):
ddl.sort_tables_and_constraints = self.orig_sort
- def setup(self):
+ def setup_test(self):
self._setup_logger()
self._setup_create_table_patcher()
- def teardown(self):
+ def teardown_test(self):
self._teardown_create_table_patcher()
self._teardown_logger()
diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py
index 8119612e1..f0bb66aa9 100644
--- a/test/dialect/mssql/test_compiler.py
+++ b/test/dialect/mssql/test_compiler.py
@@ -1814,7 +1814,7 @@ class CompileIdentityTest(fixtures.TestBase, AssertsCompiledSQL):
class SchemaTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
t = Table(
"sometable",
MetaData(),
diff --git a/test/dialect/mssql/test_deprecations.py b/test/dialect/mssql/test_deprecations.py
index c869182c5..27709beb0 100644
--- a/test/dialect/mssql/test_deprecations.py
+++ b/test/dialect/mssql/test_deprecations.py
@@ -31,7 +31,7 @@ class LegacySchemaAliasingTest(fixtures.TestBase, AssertsCompiledSQL):
"""
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.t1 = table(
"t1",
diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py
index cdb37cc61..b806b9247 100644
--- a/test/dialect/mssql/test_query.py
+++ b/test/dialect/mssql/test_query.py
@@ -455,7 +455,7 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
return testing.db.execution_options(isolation_level="AUTOCOMMIT")
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.connect().execution_options(
isolation_level="AUTOCOMMIT"
) as conn:
@@ -463,7 +463,6 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
conn.exec_driver_sql("DROP FULLTEXT CATALOG Catalog")
except:
pass
- super(MatchTest, cls).setup_class()
@classmethod
def insert_data(cls, connection):
diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py
index 62292b9da..7fd24e8b5 100644
--- a/test/dialect/mysql/test_compiler.py
+++ b/test/dialect/mysql/test_compiler.py
@@ -991,7 +991,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = mysql.dialect()
- def setup(self):
+ def setup_test(self):
self.table = Table(
"foos",
MetaData(),
@@ -1062,7 +1062,7 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
class RegexpCommon(testing.AssertsCompiledSQL):
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py
index 40617e59c..795b2cbd3 100644
--- a/test/dialect/mysql/test_reflection.py
+++ b/test/dialect/mysql/test_reflection.py
@@ -1115,7 +1115,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
class RawReflectionTest(fixtures.TestBase):
__backend__ = True
- def setup(self):
+ def setup_test(self):
dialect = mysql.dialect()
self.parser = _reflection.MySQLTableDefinitionParser(
dialect, dialect.identifier_preparer
diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py
index 1b8b3fb89..f09346eb3 100644
--- a/test/dialect/oracle/test_compiler.py
+++ b/test/dialect/oracle/test_compiler.py
@@ -1355,7 +1355,7 @@ class SequenceTest(fixtures.TestBase, AssertsCompiledSQL):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "oracle"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py
index df87fe89f..32234bf65 100644
--- a/test/dialect/oracle/test_dialect.py
+++ b/test/dialect/oracle/test_dialect.py
@@ -439,7 +439,7 @@ class OutParamTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.begin() as c:
c.exec_driver_sql(
"""
@@ -471,7 +471,7 @@ end;
assert isinstance(result.out_parameters["x_out"], int)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
conn.execute(text("DROP PROCEDURE foo"))
diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py
index 81e4e4ab5..0df4236e2 100644
--- a/test/dialect/oracle/test_reflection.py
+++ b/test/dialect/oracle/test_reflection.py
@@ -39,7 +39,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
# currently assuming full DBA privs for the user.
# don't really know how else to go here unless
# we connect as the other user.
@@ -85,7 +85,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
conn.exec_driver_sql(stmt)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
for stmt in (
"""
@@ -379,7 +379,7 @@ class SystemTableTablenamesTest(fixtures.TestBase):
__only_on__ = "oracle"
__backend__ = True
- def setup(self):
+ def setup_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("create table my_table (id integer)")
conn.exec_driver_sql(
@@ -389,7 +389,7 @@ class SystemTableTablenamesTest(fixtures.TestBase):
"create table foo_table (id integer) tablespace SYSTEM"
)
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop table my_temp_table")
conn.exec_driver_sql("drop table my_table")
@@ -421,7 +421,7 @@ class DontReflectIOTTest(fixtures.TestBase):
__only_on__ = "oracle"
__backend__ = True
- def setup(self):
+ def setup_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql(
"""
@@ -438,7 +438,7 @@ class DontReflectIOTTest(fixtures.TestBase):
""",
)
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop table admin_docindex")
@@ -715,7 +715,7 @@ class DBLinkReflectionTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy.testing import config
cls.dblink = config.file_config.get("sqla_testing", "oracle_db_link")
@@ -734,7 +734,7 @@ class DBLinkReflectionTest(fixtures.TestBase):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop synonym test_table_syn")
conn.exec_driver_sql("drop table test_table")
diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py
index f008ea019..8ea7c0e04 100644
--- a/test/dialect/oracle/test_types.py
+++ b/test/dialect/oracle/test_types.py
@@ -1011,7 +1011,7 @@ class EuroNumericTest(fixtures.TestBase):
__only_on__ = "oracle+cx_oracle"
__backend__ = True
- def setup(self):
+ def setup_test(self):
connect = testing.db.pool._creator
def _creator():
@@ -1023,7 +1023,7 @@ class EuroNumericTest(fixtures.TestBase):
self.engine = testing_engine(options={"creator": _creator})
- def teardown(self):
+ def teardown_test(self):
self.engine.dispose()
def test_were_getting_a_comma(self):
diff --git a/test/dialect/postgresql/test_async_pg_py3k.py b/test/dialect/postgresql/test_async_pg_py3k.py
index fadf939b8..f6d48f3c6 100644
--- a/test/dialect/postgresql/test_async_pg_py3k.py
+++ b/test/dialect/postgresql/test_async_pg_py3k.py
@@ -27,7 +27,7 @@ class AsyncPgTest(fixtures.TestBase):
# TODO: remove when Iae6ab95938a7e92b6d42086aec534af27b5577d3
# merges
- from sqlalchemy.testing import engines
+ from sqlalchemy.testing import util as testing_util
from sqlalchemy.sql import schema
metadata = schema.MetaData()
@@ -35,7 +35,7 @@ class AsyncPgTest(fixtures.TestBase):
try:
yield metadata
finally:
- engines.drop_all_tables(metadata, testing.db)
+ testing_util.drop_all_tables_from_metadata(metadata, testing.db)
@async_test
async def test_detect_stale_ddl_cache_raise_recover(
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 1763b210b..b3a0b9bbd 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -1810,7 +1810,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = postgresql.dialect()
- def setup(self):
+ def setup_test(self):
self.table1 = table1 = table(
"mytable",
column("myid", Integer),
@@ -2222,7 +2222,7 @@ class DistinctOnTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = postgresql.dialect()
- def setup(self):
+ def setup_test(self):
self.table = Table(
"t",
MetaData(),
@@ -2373,7 +2373,7 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = postgresql.dialect()
- def setup(self):
+ def setup_test(self):
self.table = Table(
"t",
MetaData(),
@@ -2464,7 +2464,7 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "postgresql"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py
index f760a309b..9c9d817bb 100644
--- a/test/dialect/postgresql/test_dialect.py
+++ b/test/dialect/postgresql/test_dialect.py
@@ -198,17 +198,17 @@ class ExecuteManyMode(object):
@config.fixture()
def connection(self):
- eng = engines.testing_engine(options=self.options)
+ opts = dict(self.options)
+ opts["use_reaper"] = False
+ eng = engines.testing_engine(options=opts)
conn = eng.connect()
trans = conn.begin()
- try:
- yield conn
- finally:
- if trans.is_active:
- trans.rollback()
- conn.close()
- eng.dispose()
+ yield conn
+ if trans.is_active:
+ trans.rollback()
+ conn.close()
+ eng.dispose()
@classmethod
def define_tables(cls, metadata):
@@ -510,8 +510,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
# assert result.closed
assert result.cursor is None
- @testing.provide_metadata
- def test_insert_returning_preexecute_pk(self, connection):
+ def test_insert_returning_preexecute_pk(self, metadata, connection):
counter = itertools.count(1)
t = Table(
@@ -525,7 +524,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
),
Column("data", Integer),
)
- self.metadata.create_all(connection)
+ metadata.create_all(connection)
result = connection.execute(
t.insert().return_defaults(),
diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py
index 94af168ee..c51fd1943 100644
--- a/test/dialect/postgresql/test_query.py
+++ b/test/dialect/postgresql/test_query.py
@@ -40,10 +40,10 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql"
__backend__ = True
- def setup(self):
+ def setup_test(self):
self.metadata = MetaData()
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
self.metadata.drop_all(conn)
@@ -890,7 +890,7 @@ class ExtractTest(fixtures.TablesTest):
def setup_bind(cls):
from sqlalchemy import event
- eng = engines.testing_engine()
+ eng = engines.testing_engine(options={"scope": "class"})
@event.listens_for(eng, "connect")
def connect(dbapi_conn, rec):
diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py
index 754eff25a..6586a8308 100644
--- a/test/dialect/postgresql/test_reflection.py
+++ b/test/dialect/postgresql/test_reflection.py
@@ -80,26 +80,24 @@ class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
]:
sa.event.listen(metadata, "before_drop", sa.DDL(ddl))
- def test_foreign_table_is_reflected(self):
+ def test_foreign_table_is_reflected(self, connection):
metadata = MetaData()
- table = Table("test_foreigntable", metadata, autoload_with=testing.db)
+ table = Table("test_foreigntable", metadata, autoload_with=connection)
eq_(
set(table.columns.keys()),
set(["id", "data"]),
"Columns of reflected foreign table didn't equal expected columns",
)
- def test_get_foreign_table_names(self):
- inspector = inspect(testing.db)
- with testing.db.connect():
- ft_names = inspector.get_foreign_table_names()
- eq_(ft_names, ["test_foreigntable"])
+ def test_get_foreign_table_names(self, connection):
+ inspector = inspect(connection)
+ ft_names = inspector.get_foreign_table_names()
+ eq_(ft_names, ["test_foreigntable"])
- def test_get_table_names_no_foreign(self):
- inspector = inspect(testing.db)
- with testing.db.connect():
- names = inspector.get_table_names()
- eq_(names, ["testtable"])
+ def test_get_table_names_no_foreign(self, connection):
+ inspector = inspect(connection)
+ names = inspector.get_table_names()
+ eq_(names, ["testtable"])
class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
@@ -133,22 +131,22 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
if testing.against("postgresql >= 11"):
Index("my_index", dv.c.q)
- def test_get_tablenames(self):
+ def test_get_tablenames(self, connection):
assert {"data_values", "data_values_4_10"}.issubset(
- inspect(testing.db).get_table_names()
+ inspect(connection).get_table_names()
)
- def test_reflect_cols(self):
- cols = inspect(testing.db).get_columns("data_values")
+ def test_reflect_cols(self, connection):
+ cols = inspect(connection).get_columns("data_values")
eq_([c["name"] for c in cols], ["modulus", "data", "q"])
- def test_reflect_cols_from_partition(self):
- cols = inspect(testing.db).get_columns("data_values_4_10")
+ def test_reflect_cols_from_partition(self, connection):
+ cols = inspect(connection).get_columns("data_values_4_10")
eq_([c["name"] for c in cols], ["modulus", "data", "q"])
@testing.only_on("postgresql >= 11")
- def test_reflect_index(self):
- idx = inspect(testing.db).get_indexes("data_values")
+ def test_reflect_index(self, connection):
+ idx = inspect(connection).get_indexes("data_values")
eq_(
idx,
[
@@ -162,8 +160,8 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
)
@testing.only_on("postgresql >= 11")
- def test_reflect_index_from_partition(self):
- idx = inspect(testing.db).get_indexes("data_values_4_10")
+ def test_reflect_index_from_partition(self, connection):
+ idx = inspect(connection).get_indexes("data_values_4_10")
# note the name appears to be generated by PG, currently
# 'data_values_4_10_q_idx'
eq_(
@@ -220,44 +218,43 @@ class MaterializedViewReflectionTest(
testtable, "before_drop", sa.DDL("DROP VIEW test_regview")
)
- def test_mview_is_reflected(self):
+ def test_mview_is_reflected(self, connection):
metadata = MetaData()
- table = Table("test_mview", metadata, autoload_with=testing.db)
+ table = Table("test_mview", metadata, autoload_with=connection)
eq_(
set(table.columns.keys()),
set(["id", "data"]),
"Columns of reflected mview didn't equal expected columns",
)
- def test_mview_select(self):
+ def test_mview_select(self, connection):
metadata = MetaData()
- table = Table("test_mview", metadata, autoload_with=testing.db)
- with testing.db.connect() as conn:
- eq_(conn.execute(table.select()).fetchall(), [(89, "d1")])
+ table = Table("test_mview", metadata, autoload_with=connection)
+ eq_(connection.execute(table.select()).fetchall(), [(89, "d1")])
- def test_get_view_names(self):
- insp = inspect(testing.db)
+ def test_get_view_names(self, connection):
+ insp = inspect(connection)
eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
- def test_get_view_names_plain(self):
- insp = inspect(testing.db)
+ def test_get_view_names_plain(self, connection):
+ insp = inspect(connection)
eq_(
set(insp.get_view_names(include=("plain",))), set(["test_regview"])
)
- def test_get_view_names_plain_string(self):
- insp = inspect(testing.db)
+ def test_get_view_names_plain_string(self, connection):
+ insp = inspect(connection)
eq_(set(insp.get_view_names(include="plain")), set(["test_regview"]))
- def test_get_view_names_materialized(self):
- insp = inspect(testing.db)
+ def test_get_view_names_materialized(self, connection):
+ insp = inspect(connection)
eq_(
set(insp.get_view_names(include=("materialized",))),
set(["test_mview"]),
)
- def test_get_view_names_reflection_cache_ok(self):
- insp = inspect(testing.db)
+ def test_get_view_names_reflection_cache_ok(self, connection):
+ insp = inspect(connection)
eq_(
set(insp.get_view_names(include=("plain",))), set(["test_regview"])
)
@@ -267,12 +264,12 @@ class MaterializedViewReflectionTest(
)
eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
- def test_get_view_names_empty(self):
- insp = inspect(testing.db)
+ def test_get_view_names_empty(self, connection):
+ insp = inspect(connection)
assert_raises(ValueError, insp.get_view_names, include=())
- def test_get_view_definition(self):
- insp = inspect(testing.db)
+ def test_get_view_definition(self, connection):
+ insp = inspect(connection)
eq_(
re.sub(
r"[\n\t ]+",
@@ -290,7 +287,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.begin() as con:
for ddl in [
'CREATE SCHEMA "SomeSchema"',
@@ -334,7 +331,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as con:
con.exec_driver_sql("DROP TABLE testtable")
con.exec_driver_sql("DROP TABLE test_schema.testtable")
@@ -350,9 +347,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"')
con.exec_driver_sql('DROP SCHEMA "SomeSchema"')
- def test_table_is_reflected(self):
+ def test_table_is_reflected(self, connection):
metadata = MetaData()
- table = Table("testtable", metadata, autoload_with=testing.db)
+ table = Table("testtable", metadata, autoload_with=connection)
eq_(
set(table.columns.keys()),
set(["question", "answer"]),
@@ -360,9 +357,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
)
assert isinstance(table.c.answer.type, Integer)
- def test_domain_is_reflected(self):
+ def test_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("testtable", metadata, autoload_with=testing.db)
+ table = Table("testtable", metadata, autoload_with=connection)
eq_(
str(table.columns.answer.server_default.arg),
"42",
@@ -372,28 +369,28 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
not table.columns.answer.nullable
), "Expected reflected column to not be nullable."
- def test_enum_domain_is_reflected(self):
+ def test_enum_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("enum_test", metadata, autoload_with=testing.db)
+ table = Table("enum_test", metadata, autoload_with=connection)
eq_(table.c.data.type.enums, ["test"])
- def test_array_domain_is_reflected(self):
+ def test_array_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("array_test", metadata, autoload_with=testing.db)
+ table = Table("array_test", metadata, autoload_with=connection)
eq_(table.c.data.type.__class__, ARRAY)
eq_(table.c.data.type.item_type.__class__, INTEGER)
- def test_quoted_remote_schema_domain_is_reflected(self):
+ def test_quoted_remote_schema_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("quote_test", metadata, autoload_with=testing.db)
+ table = Table("quote_test", metadata, autoload_with=connection)
eq_(table.c.data.type.__class__, INTEGER)
- def test_table_is_reflected_test_schema(self):
+ def test_table_is_reflected_test_schema(self, connection):
metadata = MetaData()
table = Table(
"testtable",
metadata,
- autoload_with=testing.db,
+ autoload_with=connection,
schema="test_schema",
)
eq_(
@@ -403,12 +400,12 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
)
assert isinstance(table.c.anything.type, Integer)
- def test_schema_domain_is_reflected(self):
+ def test_schema_domain_is_reflected(self, connection):
metadata = MetaData()
table = Table(
"testtable",
metadata,
- autoload_with=testing.db,
+ autoload_with=connection,
schema="test_schema",
)
eq_(
@@ -420,9 +417,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
table.columns.answer.nullable
), "Expected reflected column to be nullable."
- def test_crosschema_domain_is_reflected(self):
+ def test_crosschema_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("crosschema", metadata, autoload_with=testing.db)
+ table = Table("crosschema", metadata, autoload_with=connection)
eq_(
str(table.columns.answer.server_default.arg),
"0",
@@ -432,7 +429,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
table.columns.answer.nullable
), "Expected reflected column to be nullable."
- def test_unknown_types(self):
+ def test_unknown_types(self, connection):
from sqlalchemy.dialects.postgresql import base
ischema_names = base.PGDialect.ischema_names
@@ -440,13 +437,13 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
try:
m2 = MetaData()
assert_raises(
- exc.SAWarning, Table, "testtable", m2, autoload_with=testing.db
+ exc.SAWarning, Table, "testtable", m2, autoload_with=connection
)
@testing.emits_warning("Did not recognize type")
def warns():
m3 = MetaData()
- t3 = Table("testtable", m3, autoload_with=testing.db)
+ t3 = Table("testtable", m3, autoload_with=connection)
assert t3.c.answer.type.__class__ == sa.types.NullType
finally:
@@ -471,9 +468,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
subject = Table("subject", meta2, autoload_with=connection)
eq_(subject.primary_key.columns.keys(), ["p2", "p1"])
- @testing.provide_metadata
- def test_pg_weirdchar_reflection(self):
- meta1 = self.metadata
+ def test_pg_weirdchar_reflection(self, metadata, connection):
+ meta1 = metadata
subject = Table(
"subject", meta1, Column("id$", Integer, primary_key=True)
)
@@ -483,101 +479,91 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("ref", Integer, ForeignKey("subject.id$")),
)
- meta1.create_all(testing.db)
+ meta1.create_all(connection)
meta2 = MetaData()
- subject = Table("subject", meta2, autoload_with=testing.db)
- referer = Table("referer", meta2, autoload_with=testing.db)
+ subject = Table("subject", meta2, autoload_with=connection)
+ referer = Table("referer", meta2, autoload_with=connection)
self.assert_(
(subject.c["id$"] == referer.c.ref).compare(
subject.join(referer).onclause
)
)
- @testing.provide_metadata
- def test_reflect_default_over_128_chars(self):
+ def test_reflect_default_over_128_chars(self, metadata, connection):
Table(
"t",
- self.metadata,
+ metadata,
Column("x", String(200), server_default="abcd" * 40),
- ).create(testing.db)
+ ).create(connection)
m = MetaData()
- t = Table("t", m, autoload_with=testing.db)
+ t = Table("t", m, autoload_with=connection)
eq_(
t.c.x.server_default.arg.text,
"'%s'::character varying" % ("abcd" * 40),
)
- @testing.fails_if("postgresql < 8.1", "schema name leaks in, not sure")
- @testing.provide_metadata
- def test_renamed_sequence_reflection(self):
- metadata = self.metadata
+ def test_renamed_sequence_reflection(self, metadata, connection):
Table("t", metadata, Column("id", Integer, primary_key=True))
- metadata.create_all(testing.db)
+ metadata.create_all(connection)
m2 = MetaData()
- t2 = Table("t", m2, autoload_with=testing.db, implicit_returning=False)
+ t2 = Table("t", m2, autoload_with=connection, implicit_returning=False)
eq_(t2.c.id.server_default.arg.text, "nextval('t_id_seq'::regclass)")
- with testing.db.begin() as conn:
- r = conn.execute(t2.insert())
- eq_(r.inserted_primary_key, (1,))
+ r = connection.execute(t2.insert())
+ eq_(r.inserted_primary_key, (1,))
- with testing.db.begin() as conn:
- conn.exec_driver_sql(
- "alter table t_id_seq rename to foobar_id_seq"
- )
+ connection.exec_driver_sql(
+ "alter table t_id_seq rename to foobar_id_seq"
+ )
m3 = MetaData()
- t3 = Table("t", m3, autoload_with=testing.db, implicit_returning=False)
+ t3 = Table("t", m3, autoload_with=connection, implicit_returning=False)
eq_(
t3.c.id.server_default.arg.text,
"nextval('foobar_id_seq'::regclass)",
)
- with testing.db.begin() as conn:
- r = conn.execute(t3.insert())
- eq_(r.inserted_primary_key, (2,))
+ r = connection.execute(t3.insert())
+ eq_(r.inserted_primary_key, (2,))
- @testing.provide_metadata
- def test_altered_type_autoincrement_pk_reflection(self):
- metadata = self.metadata
+ def test_altered_type_autoincrement_pk_reflection(
+ self, metadata, connection
+ ):
+ metadata = metadata
Table(
"t",
metadata,
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all(testing.db)
+ metadata.create_all(connection)
- with testing.db.begin() as conn:
- conn.exec_driver_sql(
- "alter table t alter column id type varchar(50)"
- )
+ connection.exec_driver_sql(
+ "alter table t alter column id type varchar(50)"
+ )
m2 = MetaData()
- t2 = Table("t", m2, autoload_with=testing.db)
+ t2 = Table("t", m2, autoload_with=connection)
eq_(t2.c.id.autoincrement, False)
eq_(t2.c.x.autoincrement, False)
- @testing.provide_metadata
- def test_renamed_pk_reflection(self):
- metadata = self.metadata
+ def test_renamed_pk_reflection(self, metadata, connection):
+ metadata = metadata
Table("t", metadata, Column("id", Integer, primary_key=True))
- metadata.create_all(testing.db)
- with testing.db.begin() as conn:
- conn.exec_driver_sql("alter table t rename id to t_id")
+ metadata.create_all(connection)
+ connection.exec_driver_sql("alter table t rename id to t_id")
m2 = MetaData()
- t2 = Table("t", m2, autoload_with=testing.db)
+ t2 = Table("t", m2, autoload_with=connection)
eq_([c.name for c in t2.primary_key], ["t_id"])
- @testing.provide_metadata
- def test_has_temporary_table(self):
- assert not inspect(testing.db).has_table("some_temp_table")
+ def test_has_temporary_table(self, metadata, connection):
+ assert not inspect(connection).has_table("some_temp_table")
user_tmp = Table(
"some_temp_table",
- self.metadata,
+ metadata,
Column("id", Integer, primary_key=True),
Column("name", String(50)),
prefixes=["TEMPORARY"],
)
- user_tmp.create(testing.db)
- assert inspect(testing.db).has_table("some_temp_table")
+ user_tmp.create(connection)
+ assert inspect(connection).has_table("some_temp_table")
def test_cross_schema_reflection_one(self, metadata, connection):
@@ -898,19 +884,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
A_table.create(connection, checkfirst=True)
assert inspect(connection).has_table("A")
- def test_uppercase_lowercase_sequence(self):
+ def test_uppercase_lowercase_sequence(self, connection):
a_seq = Sequence("a")
A_seq = Sequence("A")
- a_seq.create(testing.db)
- assert testing.db.dialect.has_sequence(testing.db, "a")
- assert not testing.db.dialect.has_sequence(testing.db, "A")
- A_seq.create(testing.db, checkfirst=True)
- assert testing.db.dialect.has_sequence(testing.db, "A")
+ a_seq.create(connection)
+ assert connection.dialect.has_sequence(connection, "a")
+ assert not connection.dialect.has_sequence(connection, "A")
+ A_seq.create(connection, checkfirst=True)
+ assert connection.dialect.has_sequence(connection, "A")
- a_seq.drop(testing.db)
- A_seq.drop(testing.db)
+ a_seq.drop(connection)
+ A_seq.drop(connection)
def test_index_reflection(self, metadata, connection):
"""Reflecting expression-based indexes should warn"""
@@ -960,11 +946,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_index_reflection_partial(self, connection):
+ def test_index_reflection_partial(self, metadata, connection):
"""Reflect the filter defintion on partial indexes"""
- metadata = self.metadata
+ metadata = metadata
t1 = Table(
"table1",
@@ -978,7 +963,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
metadata.create_all(connection)
- ind = testing.db.dialect.get_indexes(connection, t1, None)
+ ind = connection.dialect.get_indexes(connection, t1, None)
partial_definitions = []
for ix in ind:
@@ -1073,15 +1058,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
compile_exprs(r3.expressions),
)
- @testing.provide_metadata
- def test_index_reflection_modified(self):
+ def test_index_reflection_modified(self, metadata, connection):
"""reflect indexes when a column name has changed - PG 9
does not update the name of the column in the index def.
[ticket:2141]
"""
- metadata = self.metadata
+ metadata = metadata
Table(
"t",
@@ -1089,26 +1073,21 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all(testing.db)
- with testing.db.begin() as conn:
- conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
- conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
+ metadata.create_all(connection)
+ connection.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
+ connection.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
- ind = testing.db.dialect.get_indexes(conn, "t", None)
- expected = [
- {"name": "idx1", "unique": False, "column_names": ["y"]}
- ]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
+ ind = connection.dialect.get_indexes(connection, "t", None)
+ expected = [{"name": "idx1", "unique": False, "column_names": ["y"]}]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
- eq_(ind, expected)
+ eq_(ind, expected)
- @testing.fails_if("postgresql < 8.2", "reloptions not supported")
- @testing.provide_metadata
- def test_index_reflection_with_storage_options(self):
+ def test_index_reflection_with_storage_options(self, metadata, connection):
"""reflect indexes with storage options set"""
- metadata = self.metadata
+ metadata = metadata
Table(
"t",
@@ -1116,70 +1095,63 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all(testing.db)
+ metadata.create_all(connection)
- with testing.db.begin() as conn:
- conn.exec_driver_sql(
- "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)"
- )
+ connection.exec_driver_sql(
+ "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)"
+ )
- ind = testing.db.dialect.get_indexes(conn, "t", None)
+ ind = testing.db.dialect.get_indexes(connection, "t", None)
- expected = [
- {
- "unique": False,
- "column_names": ["x"],
- "name": "idx1",
- "dialect_options": {
- "postgresql_with": {"fillfactor": "50"}
- },
- }
- ]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
- eq_(ind, expected)
+ expected = [
+ {
+ "unique": False,
+ "column_names": ["x"],
+ "name": "idx1",
+ "dialect_options": {"postgresql_with": {"fillfactor": "50"}},
+ }
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ eq_(ind, expected)
- m = MetaData()
- t1 = Table("t", m, autoload_with=conn)
- eq_(
- list(t1.indexes)[0].dialect_options["postgresql"]["with"],
- {"fillfactor": "50"},
- )
+ m = MetaData()
+ t1 = Table("t", m, autoload_with=connection)
+ eq_(
+ list(t1.indexes)[0].dialect_options["postgresql"]["with"],
+ {"fillfactor": "50"},
+ )
- @testing.provide_metadata
- def test_index_reflection_with_access_method(self):
+ def test_index_reflection_with_access_method(self, metadata, connection):
"""reflect indexes with storage options set"""
- metadata = self.metadata
-
Table(
"t",
metadata,
Column("id", Integer, primary_key=True),
Column("x", ARRAY(Integer)),
)
- metadata.create_all(testing.db)
- with testing.db.begin() as conn:
- conn.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)")
+ metadata.create_all(connection)
+ connection.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)")
- ind = testing.db.dialect.get_indexes(conn, "t", None)
- expected = [
- {
- "unique": False,
- "column_names": ["x"],
- "name": "idx1",
- "dialect_options": {"postgresql_using": "gin"},
- }
- ]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
- eq_(ind, expected)
- m = MetaData()
- t1 = Table("t", m, autoload_with=conn)
- eq_(
- list(t1.indexes)[0].dialect_options["postgresql"]["using"],
- "gin",
- )
+ ind = testing.db.dialect.get_indexes(connection, "t", None)
+ expected = [
+ {
+ "unique": False,
+ "column_names": ["x"],
+ "name": "idx1",
+ "dialect_options": {"postgresql_using": "gin"},
+ }
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ eq_(ind, expected)
+ m = MetaData()
+ t1 = Table("t", m, autoload_with=connection)
+ eq_(
+ list(t1.indexes)[0].dialect_options["postgresql"]["using"],
+ "gin",
+ )
@testing.skip_if("postgresql < 11.0", "indnkeyatts not supported")
def test_index_reflection_with_include(self, metadata, connection):
@@ -1199,7 +1171,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
# [{'column_names': ['x', 'name'],
# 'name': 'idx1', 'unique': False}]
- ind = testing.db.dialect.get_indexes(connection, "t", None)
+ ind = connection.dialect.get_indexes(connection, "t", None)
eq_(
ind,
[
@@ -1286,15 +1258,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
for fk in fks:
eq_(fk, fk_ref[fk["name"]])
- @testing.provide_metadata
- def test_inspect_enums_schema(self, connection):
+ def test_inspect_enums_schema(self, metadata, connection):
enum_type = postgresql.ENUM(
"sad",
"ok",
"happy",
name="mood",
schema="test_schema",
- metadata=self.metadata,
+ metadata=metadata,
)
enum_type.create(connection)
inspector = inspect(connection)
@@ -1310,13 +1281,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enums(self):
+ def test_inspect_enums(self, metadata, connection):
enum_type = postgresql.ENUM(
- "cat", "dog", "rat", name="pet", metadata=self.metadata
+ "cat", "dog", "rat", name="pet", metadata=metadata
)
- enum_type.create(testing.db)
- inspector = inspect(testing.db)
+ enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums(),
[
@@ -1329,17 +1299,16 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enums_case_sensitive(self):
+ def test_inspect_enums_case_sensitive(self, metadata, connection):
sa.event.listen(
- self.metadata,
+ metadata,
"before_create",
sa.DDL('create schema "TestSchema"'),
)
sa.event.listen(
- self.metadata,
+ metadata,
"after_drop",
- sa.DDL('drop schema "TestSchema" cascade'),
+ sa.DDL('drop schema if exists "TestSchema" cascade'),
)
for enum in "lower_case", "UpperCase", "Name.With.Dot":
@@ -1350,11 +1319,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
"CapsTwo",
name=enum,
schema=schema,
- metadata=self.metadata,
+ metadata=metadata,
)
- self.metadata.create_all(testing.db)
- inspector = inspect(testing.db)
+ metadata.create_all(connection)
+ inspector = inspect(connection)
for schema in None, "test_schema", "TestSchema":
eq_(
sorted(
@@ -1382,17 +1351,18 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enums_case_sensitive_from_table(self):
+ def test_inspect_enums_case_sensitive_from_table(
+ self, metadata, connection
+ ):
sa.event.listen(
- self.metadata,
+ metadata,
"before_create",
sa.DDL('create schema "TestSchema"'),
)
sa.event.listen(
- self.metadata,
+ metadata,
"after_drop",
- sa.DDL('drop schema "TestSchema" cascade'),
+ sa.DDL('drop schema if exists "TestSchema" cascade'),
)
counter = itertools.count()
@@ -1403,19 +1373,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
"CapsOne",
"CapsTwo",
name=enum,
- metadata=self.metadata,
+ metadata=metadata,
schema=schema,
)
Table(
"t%d" % next(counter),
- self.metadata,
+ metadata,
Column("q", enum_type),
)
- self.metadata.create_all(testing.db)
+ metadata.create_all(connection)
- inspector = inspect(testing.db)
+ inspector = inspect(connection)
counter = itertools.count()
for enum in "lower_case", "UpperCase", "Name.With.Dot":
for schema in None, "test_schema", "TestSchema":
@@ -1439,10 +1409,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enums_star(self):
+ def test_inspect_enums_star(self, metadata, connection):
enum_type = postgresql.ENUM(
- "cat", "dog", "rat", name="pet", metadata=self.metadata
+ "cat", "dog", "rat", name="pet", metadata=metadata
)
schema_enum_type = postgresql.ENUM(
"sad",
@@ -1450,11 +1419,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
"happy",
name="mood",
schema="test_schema",
- metadata=self.metadata,
+ metadata=metadata,
)
- enum_type.create(testing.db)
- schema_enum_type.create(testing.db)
- inspector = inspect(testing.db)
+ enum_type.create(connection)
+ schema_enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums(),
@@ -1486,11 +1455,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enum_empty(self):
- enum_type = postgresql.ENUM(name="empty", metadata=self.metadata)
- enum_type.create(testing.db)
- inspector = inspect(testing.db)
+ def test_inspect_enum_empty(self, metadata, connection):
+ enum_type = postgresql.ENUM(name="empty", metadata=metadata)
+ enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums(),
@@ -1504,13 +1472,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
],
)
- @testing.provide_metadata
- def test_inspect_enum_empty_from_table(self):
+ def test_inspect_enum_empty_from_table(self, metadata, connection):
Table(
- "t", self.metadata, Column("x", postgresql.ENUM(name="empty"))
- ).create(testing.db)
+ "t", metadata, Column("x", postgresql.ENUM(name="empty"))
+ ).create(connection)
- t = Table("t", MetaData(), autoload_with=testing.db)
+ t = Table("t", MetaData(), autoload_with=connection)
eq_(t.c.x.type.enums, [])
def test_reflection_with_unique_constraint(self, metadata, connection):
@@ -1749,12 +1716,12 @@ class CustomTypeReflectionTest(fixtures.TestBase):
ischema_names = None
- def setup(self):
+ def setup_test(self):
ischema_names = postgresql.PGDialect.ischema_names
postgresql.PGDialect.ischema_names = ischema_names.copy()
self.ischema_names = ischema_names
- def teardown(self):
+ def teardown_test(self):
postgresql.PGDialect.ischema_names = self.ischema_names
self.ischema_names = None
@@ -1788,55 +1755,51 @@ class IntervalReflectionTest(fixtures.TestBase):
__only_on__ = "postgresql"
__backend__ = True
- def test_interval_types(self):
- for sym in [
- "YEAR",
- "MONTH",
- "DAY",
- "HOUR",
- "MINUTE",
- "SECOND",
- "YEAR TO MONTH",
- "DAY TO HOUR",
- "DAY TO MINUTE",
- "DAY TO SECOND",
- "HOUR TO MINUTE",
- "HOUR TO SECOND",
- "MINUTE TO SECOND",
- ]:
- self._test_interval_symbol(sym)
-
- @testing.provide_metadata
- def _test_interval_symbol(self, sym):
+ @testing.combinations(
+ ("YEAR",),
+ ("MONTH",),
+ ("DAY",),
+ ("HOUR",),
+ ("MINUTE",),
+ ("SECOND",),
+ ("YEAR TO MONTH",),
+ ("DAY TO HOUR",),
+ ("DAY TO MINUTE",),
+ ("DAY TO SECOND",),
+ ("HOUR TO MINUTE",),
+ ("HOUR TO SECOND",),
+ ("MINUTE TO SECOND",),
+ argnames="sym",
+ )
+ def test_interval_types(self, sym, metadata, connection):
t = Table(
"i_test",
- self.metadata,
+ metadata,
Column("id", Integer, primary_key=True),
Column("data1", INTERVAL(fields=sym)),
)
- t.create(testing.db)
+ t.create(connection)
columns = {
rec["name"]: rec
- for rec in inspect(testing.db).get_columns("i_test")
+ for rec in inspect(connection).get_columns("i_test")
}
assert isinstance(columns["data1"]["type"], INTERVAL)
eq_(columns["data1"]["type"].fields, sym.lower())
eq_(columns["data1"]["type"].precision, None)
- @testing.provide_metadata
- def test_interval_precision(self):
+ def test_interval_precision(self, metadata, connection):
t = Table(
"i_test",
- self.metadata,
+ metadata,
Column("id", Integer, primary_key=True),
Column("data1", INTERVAL(precision=6)),
)
- t.create(testing.db)
+ t.create(connection)
columns = {
rec["name"]: rec
- for rec in inspect(testing.db).get_columns("i_test")
+ for rec in inspect(connection).get_columns("i_test")
}
assert isinstance(columns["data1"]["type"], INTERVAL)
eq_(columns["data1"]["type"].fields, None)
@@ -1871,8 +1834,8 @@ class IdentityReflectionTest(fixtures.TablesTest):
Column("id4", SmallInteger, Identity()),
)
- def test_reflect_identity(self):
- insp = inspect(testing.db)
+ def test_reflect_identity(self, connection):
+ insp = inspect(connection)
default = dict(
always=False,
start=1,
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index e8a1876c7..6202f8f86 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -49,7 +49,6 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.sql import operators
from sqlalchemy.sql import sqltypes
-from sqlalchemy.testing import engines
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import assert_raises_message
@@ -156,8 +155,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql > 8.3"
- @testing.provide_metadata
- def test_create_table(self, connection):
+ def test_create_table(self, metadata, connection):
metadata = self.metadata
t1 = Table(
"table",
@@ -177,8 +175,8 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
[(1, "two"), (2, "three"), (3, "three")],
)
- @testing.combinations(None, "foo")
- def test_create_table_schema_translate_map(self, symbol_name):
+ @testing.combinations(None, "foo", argnames="symbol_name")
+ def test_create_table_schema_translate_map(self, connection, symbol_name):
# note we can't use the fixture here because it will not drop
# from the correct schema
metadata = MetaData()
@@ -199,35 +197,30 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
),
schema=symbol_name,
)
- with testing.db.begin() as conn:
- conn = conn.execution_options(
- schema_translate_map={symbol_name: testing.config.test_schema}
- )
- t1.create(conn)
- assert "schema_enum" in [
- e["name"]
- for e in inspect(conn).get_enums(
- schema=testing.config.test_schema
- )
- ]
- t1.create(conn, checkfirst=True)
+ conn = connection.execution_options(
+ schema_translate_map={symbol_name: testing.config.test_schema}
+ )
+ t1.create(conn)
+ assert "schema_enum" in [
+ e["name"]
+ for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+ ]
+ t1.create(conn, checkfirst=True)
- conn.execute(t1.insert(), value="two")
- conn.execute(t1.insert(), value="three")
- conn.execute(t1.insert(), value="three")
- eq_(
- conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
- [(1, "two"), (2, "three"), (3, "three")],
- )
+ conn.execute(t1.insert(), value="two")
+ conn.execute(t1.insert(), value="three")
+ conn.execute(t1.insert(), value="three")
+ eq_(
+ conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
+ [(1, "two"), (2, "three"), (3, "three")],
+ )
- t1.drop(conn)
- assert "schema_enum" not in [
- e["name"]
- for e in inspect(conn).get_enums(
- schema=testing.config.test_schema
- )
- ]
- t1.drop(conn, checkfirst=True)
+ t1.drop(conn)
+ assert "schema_enum" not in [
+ e["name"]
+ for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+ ]
+ t1.drop(conn, checkfirst=True)
def test_name_required(self, metadata, connection):
etype = Enum("four", "five", "six", metadata=metadata)
@@ -270,8 +263,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
[util.u("réveillé"), util.u("drôle"), util.u("S’il")],
)
- @testing.provide_metadata
- def test_non_native_enum(self, connection):
+ def test_non_native_enum(self, metadata, connection):
metadata = self.metadata
t1 = Table(
"foo",
@@ -290,10 +282,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
)
def go():
- t1.create(testing.db)
+ t1.create(connection)
self.assert_sql(
- testing.db,
+ connection,
go,
[
(
@@ -307,8 +299,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
connection.execute(t1.insert(), {"bar": "two"})
eq_(connection.scalar(select(t1.c.bar)), "two")
- @testing.provide_metadata
- def test_non_native_enum_w_unicode(self, connection):
+ def test_non_native_enum_w_unicode(self, metadata, connection):
metadata = self.metadata
t1 = Table(
"foo",
@@ -326,10 +317,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
)
def go():
- t1.create(testing.db)
+ t1.create(connection)
self.assert_sql(
- testing.db,
+ connection,
go,
[
(
@@ -346,8 +337,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
connection.execute(t1.insert(), {"bar": util.u("Ü")})
eq_(connection.scalar(select(t1.c.bar)), util.u("Ü"))
- @testing.provide_metadata
- def test_disable_create(self):
+ def test_disable_create(self, metadata, connection):
metadata = self.metadata
e1 = postgresql.ENUM(
@@ -357,13 +347,12 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
t1 = Table("e1", metadata, Column("c1", e1))
# table can be created separately
# without conflict
- e1.create(bind=testing.db)
- t1.create(testing.db)
- t1.drop(testing.db)
- e1.drop(bind=testing.db)
+ e1.create(bind=connection)
+ t1.create(connection)
+ t1.drop(connection)
+ e1.drop(bind=connection)
- @testing.provide_metadata
- def test_dont_keep_checking(self, connection):
+ def test_dont_keep_checking(self, metadata, connection):
metadata = self.metadata
e1 = postgresql.ENUM("one", "two", "three", name="myenum")
@@ -560,11 +549,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
e["name"] for e in inspect(connection).get_enums()
]
- def test_non_native_dialect(self):
- engine = engines.testing_engine()
+ def test_non_native_dialect(self, metadata, testing_engine):
+ engine = testing_engine()
engine.connect()
engine.dialect.supports_native_enum = False
- metadata = MetaData()
t1 = Table(
"foo",
metadata,
@@ -583,21 +571,18 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
def go():
t1.create(engine)
- try:
- self.assert_sql(
- engine,
- go,
- [
- (
- "CREATE TABLE foo (bar "
- "VARCHAR(5), CONSTRAINT myenum CHECK "
- "(bar IN ('one', 'two', 'three')))",
- {},
- )
- ],
- )
- finally:
- metadata.drop_all(engine)
+ self.assert_sql(
+ engine,
+ go,
+ [
+ (
+ "CREATE TABLE foo (bar "
+ "VARCHAR(5), CONSTRAINT myenum CHECK "
+ "(bar IN ('one', 'two', 'three')))",
+ {},
+ )
+ ],
+ )
def test_standalone_enum(self, connection, metadata):
etype = Enum(
@@ -605,26 +590,26 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
)
etype.create(connection)
try:
- assert testing.db.dialect.has_type(connection, "fourfivesixtype")
+ assert connection.dialect.has_type(connection, "fourfivesixtype")
finally:
etype.drop(connection)
- assert not testing.db.dialect.has_type(
+ assert not connection.dialect.has_type(
connection, "fourfivesixtype"
)
metadata.create_all(connection)
try:
- assert testing.db.dialect.has_type(connection, "fourfivesixtype")
+ assert connection.dialect.has_type(connection, "fourfivesixtype")
finally:
metadata.drop_all(connection)
- assert not testing.db.dialect.has_type(
+ assert not connection.dialect.has_type(
connection, "fourfivesixtype"
)
- def test_no_support(self):
+ def test_no_support(self, testing_engine):
def server_version_info(self):
return (8, 2)
- e = engines.testing_engine()
+ e = testing_engine()
dialect = e.dialect
dialect._get_server_version_info = server_version_info
@@ -692,8 +677,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
eq_(t2.c.value2.type.name, "fourfivesixtype")
eq_(t2.c.value2.type.schema, "test_schema")
- @testing.provide_metadata
- def test_custom_subclass(self, connection):
+ def test_custom_subclass(self, metadata, connection):
class MyEnum(TypeDecorator):
impl = Enum("oneHI", "twoHI", "threeHI", name="myenum")
@@ -708,13 +692,12 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
return value
t1 = Table("table1", self.metadata, Column("data", MyEnum()))
- self.metadata.create_all(testing.db)
+ self.metadata.create_all(connection)
connection.execute(t1.insert(), {"data": "two"})
eq_(connection.scalar(select(t1.c.data)), "twoHITHERE")
- @testing.provide_metadata
- def test_generic_w_pg_variant(self, connection):
+ def test_generic_w_pg_variant(self, metadata, connection):
some_table = Table(
"some_table",
self.metadata,
@@ -752,8 +735,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
e["name"] for e in inspect(connection).get_enums()
]
- @testing.provide_metadata
- def test_generic_w_some_other_variant(self, connection):
+ def test_generic_w_some_other_variant(self, metadata, connection):
some_table = Table(
"some_table",
self.metadata,
@@ -809,26 +791,28 @@ class RegClassTest(fixtures.TestBase):
__only_on__ = "postgresql"
__backend__ = True
- @staticmethod
- def _scalar(expression):
- with testing.db.connect() as conn:
- return conn.scalar(select(expression))
+ @testing.fixture()
+ def scalar(self, connection):
+ def go(expression):
+ return connection.scalar(select(expression))
- def test_cast_name(self):
- eq_(self._scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class")
+ return go
- def test_cast_path(self):
+ def test_cast_name(self, scalar):
+ eq_(scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class")
+
+ def test_cast_path(self, scalar):
eq_(
- self._scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)),
+ scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)),
"pg_class",
)
- def test_cast_oid(self):
+ def test_cast_oid(self, scalar):
regclass = cast("pg_class", postgresql.REGCLASS)
- oid = self._scalar(cast(regclass, postgresql.OID))
+ oid = scalar(cast(regclass, postgresql.OID))
assert isinstance(oid, int)
eq_(
- self._scalar(
+ scalar(
cast(type_coerce(oid, postgresql.OID), postgresql.REGCLASS)
),
"pg_class",
@@ -1339,13 +1323,12 @@ class ArrayRoundTripTest(object):
Column("dimarr", ProcValue),
)
- def _fixture_456(self, table):
- with testing.db.begin() as conn:
- conn.execute(table.insert(), intarr=[4, 5, 6])
+ def _fixture_456(self, table, connection):
+ connection.execute(table.insert(), intarr=[4, 5, 6])
- def test_reflect_array_column(self):
+ def test_reflect_array_column(self, connection):
metadata2 = MetaData()
- tbl = Table("arrtable", metadata2, autoload_with=testing.db)
+ tbl = Table("arrtable", metadata2, autoload_with=connection)
assert isinstance(tbl.c.intarr.type, self.ARRAY)
assert isinstance(tbl.c.strarr.type, self.ARRAY)
assert isinstance(tbl.c.intarr.type.item_type, Integer)
@@ -1564,7 +1547,7 @@ class ArrayRoundTripTest(object):
def test_array_getitem_single_exec(self, connection):
arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
+ self._fixture_456(arrtable, connection)
eq_(connection.scalar(select(arrtable.c.intarr[2])), 5)
connection.execute(arrtable.update().values({arrtable.c.intarr[2]: 7}))
eq_(connection.scalar(select(arrtable.c.intarr[2])), 7)
@@ -1654,11 +1637,10 @@ class ArrayRoundTripTest(object):
set([("1", "2", "3"), ("4", "5", "6"), (("4", "5"), ("6", "7"))]),
)
- def test_array_plus_native_enum_create(self):
- m = MetaData()
+ def test_array_plus_native_enum_create(self, metadata, connection):
t = Table(
"t",
- m,
+ metadata,
Column(
"data_1",
self.ARRAY(postgresql.ENUM("a", "b", "c", name="my_enum_1")),
@@ -1669,13 +1651,13 @@ class ArrayRoundTripTest(object):
),
)
- t.create(testing.db)
+ t.create(connection)
eq_(
- set(e["name"] for e in inspect(testing.db).get_enums()),
+ set(e["name"] for e in inspect(connection).get_enums()),
set(["my_enum_1", "my_enum_2"]),
)
- t.drop(testing.db)
- eq_(inspect(testing.db).get_enums(), [])
+ t.drop(connection)
+ eq_(inspect(connection).get_enums(), [])
class CoreArrayRoundTripTest(
@@ -1690,33 +1672,35 @@ class PGArrayRoundTripTest(
):
ARRAY = postgresql.ARRAY
- @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),))
- def test_undim_array_contains_typed_exec(self, struct):
+ @testing.combinations(
+ (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"
+ )
+ def test_undim_array_contains_typed_exec(self, struct, connection):
arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
- with testing.db.begin() as conn:
- eq_(
- conn.scalar(
- select(arrtable.c.intarr).where(
- arrtable.c.intarr.contains(struct([4, 5]))
- )
- ),
- [4, 5, 6],
- )
+ self._fixture_456(arrtable, connection)
+ eq_(
+ connection.scalar(
+ select(arrtable.c.intarr).where(
+ arrtable.c.intarr.contains(struct([4, 5]))
+ )
+ ),
+ [4, 5, 6],
+ )
- @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),))
- def test_dim_array_contains_typed_exec(self, struct):
+ @testing.combinations(
+ (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"
+ )
+ def test_dim_array_contains_typed_exec(self, struct, connection):
dim_arrtable = self.tables.dim_arrtable
- self._fixture_456(dim_arrtable)
- with testing.db.begin() as conn:
- eq_(
- conn.scalar(
- select(dim_arrtable.c.intarr).where(
- dim_arrtable.c.intarr.contains(struct([4, 5]))
- )
- ),
- [4, 5, 6],
- )
+ self._fixture_456(dim_arrtable, connection)
+ eq_(
+ connection.scalar(
+ select(dim_arrtable.c.intarr).where(
+ dim_arrtable.c.intarr.contains(struct([4, 5]))
+ )
+ ),
+ [4, 5, 6],
+ )
def test_array_contained_by_exec(self, connection):
arrtable = self.tables.arrtable
@@ -1730,7 +1714,7 @@ class PGArrayRoundTripTest(
def test_undim_array_empty(self, connection):
arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
+ self._fixture_456(arrtable, connection)
eq_(
connection.scalar(
select(arrtable.c.intarr).where(arrtable.c.intarr.contains([]))
@@ -1782,8 +1766,9 @@ class ArrayEnum(fixtures.TestBase):
sqltypes.ARRAY, postgresql.ARRAY, argnames="array_cls"
)
@testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
- @testing.provide_metadata
- def test_raises_non_native_enums(self, array_cls, enum_cls):
+ def test_raises_non_native_enums(
+ self, metadata, connection, array_cls, enum_cls
+ ):
Table(
"my_table",
self.metadata,
@@ -1808,7 +1793,7 @@ class ArrayEnum(fixtures.TestBase):
"for ARRAY of non-native ENUM; please specify "
"create_constraint=False on this Enum datatype.",
self.metadata.create_all,
- testing.db,
+ connection,
)
@testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
@@ -1818,8 +1803,7 @@ class ArrayEnum(fixtures.TestBase):
(_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
argnames="array_cls",
)
- @testing.provide_metadata
- def test_array_of_enums(self, array_cls, enum_cls, connection):
+ def test_array_of_enums(self, array_cls, enum_cls, metadata, connection):
tbl = Table(
"enum_table",
self.metadata,
@@ -1875,8 +1859,7 @@ class ArrayJSON(fixtures.TestBase):
@testing.combinations(
sqltypes.JSON, postgresql.JSON, postgresql.JSONB, argnames="json_cls"
)
- @testing.provide_metadata
- def test_array_of_json(self, array_cls, json_cls, connection):
+ def test_array_of_json(self, array_cls, json_cls, metadata, connection):
tbl = Table(
"json_table",
self.metadata,
@@ -1982,19 +1965,38 @@ class HashableFlagORMTest(fixtures.TestBase):
},
],
),
+ (
+ "HSTORE",
+ postgresql.HSTORE(),
+ [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}],
+ testing.requires.hstore,
+ ),
+ (
+ "JSONB",
+ postgresql.JSONB(),
+ [
+ {"a": "1", "b": "2", "c": "3"},
+ {
+ "d": "4",
+ "e": {"e1": "5", "e2": "6"},
+ "f": {"f1": [9, 10, 11]},
+ },
+ ],
+ testing.requires.postgresql_jsonb,
+ ),
+ argnames="type_,data",
id_="iaa",
)
- @testing.provide_metadata
- def test_hashable_flag(self, type_, data):
- Base = declarative_base(metadata=self.metadata)
+ def test_hashable_flag(self, metadata, connection, type_, data):
+ Base = declarative_base(metadata=metadata)
class A(Base):
__tablename__ = "a1"
id = Column(Integer, primary_key=True)
data = Column(type_)
- Base.metadata.create_all(testing.db)
- s = Session(testing.db)
+ Base.metadata.create_all(connection)
+ s = Session(connection)
s.add_all([A(data=elem) for elem in data])
s.commit()
@@ -2006,27 +2008,6 @@ class HashableFlagORMTest(fixtures.TestBase):
list(enumerate(data, 1)),
)
- @testing.requires.hstore
- def test_hstore(self):
- self.test_hashable_flag(
- postgresql.HSTORE(),
- [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}],
- )
-
- @testing.requires.postgresql_jsonb
- def test_jsonb(self):
- self.test_hashable_flag(
- postgresql.JSONB(),
- [
- {"a": "1", "b": "2", "c": "3"},
- {
- "d": "4",
- "e": {"e1": "5", "e2": "6"},
- "f": {"f1": [9, 10, 11]},
- },
- ],
- )
-
class TimestampTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql"
@@ -2108,14 +2089,14 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
return table
- def test_reflection(self, special_types_table):
+ def test_reflection(self, special_types_table, connection):
# cheat so that the "strict type check"
# works
special_types_table.c.year_interval.type = postgresql.INTERVAL()
special_types_table.c.month_interval.type = postgresql.INTERVAL()
m = MetaData()
- t = Table("sometable", m, autoload_with=testing.db)
+ t = Table("sometable", m, autoload_with=connection)
self.assert_tables_equal(special_types_table, t, strict_types=True)
assert t.c.plain_interval.type.precision is None
@@ -2210,7 +2191,7 @@ class UUIDTest(fixtures.TestBase):
class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
__dialect__ = "postgresql"
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
@@ -2494,17 +2475,16 @@ class HStoreRoundTripTest(fixtures.TablesTest):
Column("data", HSTORE),
)
- def _fixture_data(self, engine):
+ def _fixture_data(self, connection):
data_table = self.tables.data_table
- with engine.begin() as conn:
- conn.execute(
- data_table.insert(),
- {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
- {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
- {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}},
- {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}},
- {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}},
- )
+ connection.execute(
+ data_table.insert(),
+ {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
+ {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
+ {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}},
+ {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}},
+ {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}},
+ )
def _assert_data(self, compare, conn):
data = conn.execute(
@@ -2514,26 +2494,32 @@ class HStoreRoundTripTest(fixtures.TablesTest):
).fetchall()
eq_([d for d, in data], compare)
- def _test_insert(self, engine):
- with engine.begin() as conn:
- conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
- )
- self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn)
+ def _test_insert(self, connection):
+ connection.execute(
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
+ )
+ self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection)
- def _non_native_engine(self):
- if testing.requires.psycopg2_native_hstore.enabled:
- engine = engines.testing_engine(
- options=dict(use_native_hstore=False)
- )
+ @testing.fixture
+ def non_native_hstore_connection(self, testing_engine):
+ local_engine = testing.requires.psycopg2_native_hstore.enabled
+
+ if local_engine:
+ engine = testing_engine(options=dict(use_native_hstore=False))
else:
engine = testing.db
- engine.connect().close()
- return engine
- def test_reflect(self):
- insp = inspect(testing.db)
+ conn = engine.connect()
+ trans = conn.begin()
+ yield conn
+ try:
+ trans.rollback()
+ finally:
+ conn.close()
+
+ def test_reflect(self, connection):
+ insp = inspect(connection)
cols = insp.get_columns("data_table")
assert isinstance(cols[2]["type"], HSTORE)
@@ -2548,106 +2534,88 @@ class HStoreRoundTripTest(fixtures.TablesTest):
eq_(connection.scalar(select(expr)), "3")
@testing.requires.psycopg2_native_hstore
- def test_insert_native(self):
- engine = testing.db
- self._test_insert(engine)
+ def test_insert_native(self, connection):
+ self._test_insert(connection)
- def test_insert_python(self):
- engine = self._non_native_engine()
- self._test_insert(engine)
+ def test_insert_python(self, non_native_hstore_connection):
+ self._test_insert(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_criterion_native(self):
- engine = testing.db
- self._fixture_data(engine)
- self._test_criterion(engine)
+ def test_criterion_native(self, connection):
+ self._fixture_data(connection)
+ self._test_criterion(connection)
- def test_criterion_python(self):
- engine = self._non_native_engine()
- self._fixture_data(engine)
- self._test_criterion(engine)
+ def test_criterion_python(self, non_native_hstore_connection):
+ self._fixture_data(non_native_hstore_connection)
+ self._test_criterion(non_native_hstore_connection)
- def _test_criterion(self, engine):
+ def _test_criterion(self, connection):
data_table = self.tables.data_table
- with engine.begin() as conn:
- result = conn.execute(
- select(data_table.c.data).where(
- data_table.c.data["k1"] == "r3v1"
- )
- ).first()
- eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
+ result = connection.execute(
+ select(data_table.c.data).where(data_table.c.data["k1"] == "r3v1")
+ ).first()
+ eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
- def _test_fixed_round_trip(self, engine):
- with engine.begin() as conn:
- s = select(
- hstore(
- array(["key1", "key2", "key3"]),
- array(["value1", "value2", "value3"]),
- )
- )
- eq_(
- conn.scalar(s),
- {"key1": "value1", "key2": "value2", "key3": "value3"},
+ def _test_fixed_round_trip(self, connection):
+ s = select(
+ hstore(
+ array(["key1", "key2", "key3"]),
+ array(["value1", "value2", "value3"]),
)
+ )
+ eq_(
+ connection.scalar(s),
+ {"key1": "value1", "key2": "value2", "key3": "value3"},
+ )
- def test_fixed_round_trip_python(self):
- engine = self._non_native_engine()
- self._test_fixed_round_trip(engine)
+ def test_fixed_round_trip_python(self, non_native_hstore_connection):
+ self._test_fixed_round_trip(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_fixed_round_trip_native(self):
- engine = testing.db
- self._test_fixed_round_trip(engine)
+ def test_fixed_round_trip_native(self, connection):
+ self._test_fixed_round_trip(connection)
- def _test_unicode_round_trip(self, engine):
- with engine.begin() as conn:
- s = select(
- hstore(
- array(
- [util.u("réveillé"), util.u("drôle"), util.u("S’il")]
- ),
- array(
- [util.u("réveillé"), util.u("drôle"), util.u("S’il")]
- ),
- )
- )
- eq_(
- conn.scalar(s),
- {
- util.u("réveillé"): util.u("réveillé"),
- util.u("drôle"): util.u("drôle"),
- util.u("S’il"): util.u("S’il"),
- },
+ def _test_unicode_round_trip(self, connection):
+ s = select(
+ hstore(
+ array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]),
+ array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]),
)
+ )
+ eq_(
+ connection.scalar(s),
+ {
+ util.u("réveillé"): util.u("réveillé"),
+ util.u("drôle"): util.u("drôle"),
+ util.u("S’il"): util.u("S’il"),
+ },
+ )
@testing.requires.psycopg2_native_hstore
- def test_unicode_round_trip_python(self):
- engine = self._non_native_engine()
- self._test_unicode_round_trip(engine)
+ def test_unicode_round_trip_python(self, non_native_hstore_connection):
+ self._test_unicode_round_trip(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_unicode_round_trip_native(self):
- engine = testing.db
- self._test_unicode_round_trip(engine)
+ def test_unicode_round_trip_native(self, connection):
+ self._test_unicode_round_trip(connection)
- def test_escaped_quotes_round_trip_python(self):
- engine = self._non_native_engine()
- self._test_escaped_quotes_round_trip(engine)
+ def test_escaped_quotes_round_trip_python(
+ self, non_native_hstore_connection
+ ):
+ self._test_escaped_quotes_round_trip(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_escaped_quotes_round_trip_native(self):
- engine = testing.db
- self._test_escaped_quotes_round_trip(engine)
+ def test_escaped_quotes_round_trip_native(self, connection):
+ self._test_escaped_quotes_round_trip(connection)
- def _test_escaped_quotes_round_trip(self, engine):
- with engine.begin() as conn:
- conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}},
- )
- self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], conn)
+ def _test_escaped_quotes_round_trip(self, connection):
+ connection.execute(
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}},
+ )
+ self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], connection)
- def test_orm_round_trip(self):
+ def test_orm_round_trip(self, connection):
from sqlalchemy import orm
class Data(object):
@@ -2656,13 +2624,14 @@ class HStoreRoundTripTest(fixtures.TablesTest):
self.data = data
orm.mapper(Data, self.tables.data_table)
- s = orm.Session(testing.db)
- d = Data(
- name="r1",
- data={"key1": "value1", "key2": "value2", "key3": "value3"},
- )
- s.add(d)
- eq_(s.query(Data.data, Data).all(), [(d.data, d)])
+
+ with orm.Session(connection) as s:
+ d = Data(
+ name="r1",
+ data={"key1": "value1", "key2": "value2", "key3": "value3"},
+ )
+ s.add(d)
+ eq_(s.query(Data.data, Data).all(), [(d.data, d)])
class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
@@ -2671,7 +2640,7 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
# operator tests
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
table = Table(
"data_table",
MetaData(),
@@ -2852,10 +2821,10 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
def test_actual_type(self):
eq_(str(self._col_type()), self._col_str)
- def test_reflect(self):
+ def test_reflect(self, connection):
from sqlalchemy import inspect
- insp = inspect(testing.db)
+ insp = inspect(connection)
cols = insp.get_columns("data_table")
assert isinstance(cols[0]["type"], self._col_type)
@@ -2986,8 +2955,8 @@ class _DateTimeTZRangeTests(object):
def tstzs(self):
if self._tstzs is None:
- with testing.db.begin() as conn:
- lower = conn.scalar(func.current_timestamp().select())
+ with testing.db.connect() as connection:
+ lower = connection.scalar(func.current_timestamp().select())
upper = lower + datetime.timedelta(1)
self._tstzs = (lower, upper)
return self._tstzs
@@ -3052,7 +3021,7 @@ class DateTimeTZRangeRoundTripTest(_DateTimeTZRangeTests, _RangeTypeRoundTrip):
class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
__dialect__ = "postgresql"
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
@@ -3151,7 +3120,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
Column("nulldata", cls.data_type(none_as_null=True)),
)
- def _fixture_data(self, engine):
+ def _fixture_data(self, connection):
data_table = self.tables.data_table
data = [
@@ -3162,8 +3131,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
{"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}},
{"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}},
]
- with engine.begin() as conn:
- conn.execute(data_table.insert(), data)
+ connection.execute(data_table.insert(), data)
return data
def _assert_data(self, compare, conn, column="data"):
@@ -3185,51 +3153,39 @@ class JSONRoundTripTest(fixtures.TablesTest):
).fetchall()
eq_([d for d, in data], [None])
- def _test_insert(self, conn):
- conn.execute(
+ def test_reflect(self, connection):
+ insp = inspect(connection)
+ cols = insp.get_columns("data_table")
+ assert isinstance(cols[2]["type"], self.data_type)
+
+ def test_insert(self, connection):
+ connection.execute(
self.tables.data_table.insert(),
{"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
)
- self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn)
+ self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection)
- def _test_insert_nulls(self, conn):
- conn.execute(
+ def test_insert_nulls(self, connection):
+ connection.execute(
self.tables.data_table.insert(), {"name": "r1", "data": null()}
)
- self._assert_data([None], conn)
+ self._assert_data([None], connection)
- def _test_insert_none_as_null(self, conn):
- conn.execute(
+ def test_insert_none_as_null(self, connection):
+ connection.execute(
self.tables.data_table.insert(),
{"name": "r1", "nulldata": None},
)
- self._assert_column_is_NULL(conn, column="nulldata")
+ self._assert_column_is_NULL(connection, column="nulldata")
- def _test_insert_nulljson_into_none_as_null(self, conn):
- conn.execute(
+ def test_insert_nulljson_into_none_as_null(self, connection):
+ connection.execute(
self.tables.data_table.insert(),
{"name": "r1", "nulldata": JSON.NULL},
)
- self._assert_column_is_JSON_NULL(conn, column="nulldata")
-
- def test_reflect(self):
- insp = inspect(testing.db)
- cols = insp.get_columns("data_table")
- assert isinstance(cols[2]["type"], self.data_type)
-
- def test_insert(self, connection):
- self._test_insert(connection)
-
- def test_insert_nulls(self, connection):
- self._test_insert_nulls(connection)
+ self._assert_column_is_JSON_NULL(connection, column="nulldata")
- def test_insert_none_as_null(self, connection):
- self._test_insert_none_as_null(connection)
-
- def test_insert_nulljson_into_none_as_null(self, connection):
- self._test_insert_nulljson_into_none_as_null(connection)
-
- def test_custom_serialize_deserialize(self):
+ def test_custom_serialize_deserialize(self, testing_engine):
import json
def loads(value):
@@ -3242,7 +3198,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
value["x"] = "dumps_y"
return json.dumps(value)
- engine = engines.testing_engine(
+ engine = testing_engine(
options=dict(json_serializer=dumps, json_deserializer=loads)
)
@@ -3250,14 +3206,26 @@ class JSONRoundTripTest(fixtures.TablesTest):
with engine.begin() as conn:
eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
- def test_criterion(self):
- engine = testing.db
- self._fixture_data(engine)
- self._test_criterion(engine)
+ def test_criterion(self, connection):
+ self._fixture_data(connection)
+ data_table = self.tables.data_table
+
+ result = connection.execute(
+ select(data_table.c.data).where(
+ data_table.c.data["k1"].astext == "r3v1"
+ )
+ ).first()
+ eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
+
+ result = connection.execute(
+ select(data_table.c.data).where(
+ data_table.c.data["k1"].astext.cast(String) == "r3v1"
+ )
+ ).first()
+ eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
def test_path_query(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
@@ -3271,8 +3239,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
"postgresql < 9.4", "Improvement in PostgreSQL behavior?"
)
def test_multi_index_query(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
@@ -3283,20 +3250,18 @@ class JSONRoundTripTest(fixtures.TablesTest):
eq_(result.scalar(), "r6")
def test_query_returned_as_text(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
select(data_table.c.data["k1"].astext)
).first()
- if engine.dialect.returns_unicode_strings:
+ if connection.dialect.returns_unicode_strings:
assert isinstance(result[0], util.text_type)
else:
assert isinstance(result[0], util.string_types)
def test_query_returned_as_int(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
select(data_table.c.data["k3"].astext.cast(Integer)).where(
@@ -3305,23 +3270,6 @@ class JSONRoundTripTest(fixtures.TablesTest):
).first()
assert isinstance(result[0], int)
- def _test_criterion(self, engine):
- data_table = self.tables.data_table
- with engine.begin() as conn:
- result = conn.execute(
- select(data_table.c.data).where(
- data_table.c.data["k1"].astext == "r3v1"
- )
- ).first()
- eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
-
- result = conn.execute(
- select(data_table.c.data).where(
- data_table.c.data["k1"].astext.cast(String) == "r3v1"
- )
- ).first()
- eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
-
def test_fixed_round_trip(self, connection):
s = select(
cast(
@@ -3352,42 +3300,41 @@ class JSONRoundTripTest(fixtures.TablesTest):
},
)
- def test_eval_none_flag_orm(self):
+ def test_eval_none_flag_orm(self, connection):
Base = declarative_base()
class Data(Base):
__table__ = self.tables.data_table
- s = Session(testing.db)
+ with Session(connection) as s:
+ d1 = Data(name="d1", data=None, nulldata=None)
+ s.add(d1)
+ s.commit()
- d1 = Data(name="d1", data=None, nulldata=None)
- s.add(d1)
- s.commit()
-
- s.bulk_insert_mappings(
- Data, [{"name": "d2", "data": None, "nulldata": None}]
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String),
- cast(self.tables.data_table.c.nulldata, String),
+ s.bulk_insert_mappings(
+ Data, [{"name": "d2", "data": None, "nulldata": None}]
)
- .filter(self.tables.data_table.c.name == "d1")
- .first(),
- ("null", None),
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String),
- cast(self.tables.data_table.c.nulldata, String),
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
)
- .filter(self.tables.data_table.c.name == "d2")
- .first(),
- ("null", None),
- )
def test_literal(self, connection):
- exp = self._fixture_data(testing.db)
+ exp = self._fixture_data(connection)
result = connection.exec_driver_sql(
"select data from data_table order by name"
)
@@ -3395,11 +3342,10 @@ class JSONRoundTripTest(fixtures.TablesTest):
eq_(len(res), len(exp))
for row, expected in zip(res, exp):
eq_(row[0], expected["data"])
- result.close()
class JSONBTest(JSONTest):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py
index 4658b40a8..1926c6065 100644
--- a/test/dialect/test_sqlite.py
+++ b/test/dialect/test_sqlite.py
@@ -37,6 +37,7 @@ from sqlalchemy import UniqueConstraint
from sqlalchemy import util
from sqlalchemy.dialects.sqlite import base as sqlite
from sqlalchemy.dialects.sqlite import insert
+from sqlalchemy.dialects.sqlite import provision
from sqlalchemy.dialects.sqlite import pysqlite as pysqlite_dialect
from sqlalchemy.engine.url import make_url
from sqlalchemy.schema import CreateTable
@@ -46,6 +47,7 @@ from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import combinations
+from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_warnings
@@ -774,7 +776,7 @@ class AttachedDBTest(fixtures.TestBase):
def _fixture(self):
meta = self.metadata
- self.conn = testing.db.connect()
+ self.conn = self.engine.connect()
Table("created", meta, Column("foo", Integer), Column("bar", String))
Table("local_only", meta, Column("q", Integer), Column("p", Integer))
@@ -798,14 +800,20 @@ class AttachedDBTest(fixtures.TestBase):
meta.create_all(self.conn)
return ct
- def setup(self):
- self.conn = testing.db.connect()
+ def setup_test(self):
+ self.engine = engines.testing_engine(options={"use_reaper": False})
+
+ provision._sqlite_post_configure_engine(
+ self.engine.url, self.engine, config.ident
+ )
+ self.conn = self.engine.connect()
self.metadata = MetaData()
- def teardown(self):
+ def teardown_test(self):
with self.conn.begin():
self.metadata.drop_all(self.conn)
self.conn.close()
+ self.engine.dispose()
def test_no_tables(self):
insp = inspect(self.conn)
@@ -1495,7 +1503,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
__skip_if__ = (full_text_search_missing,)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global metadata, cattable, matchtable
metadata = MetaData()
exec_sql(
@@ -1559,7 +1567,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
metadata.drop_all(testing.db)
def test_expression(self):
@@ -1681,7 +1689,7 @@ class AutoIncrementTest(fixtures.TestBase, AssertsCompiledSQL):
class ReflectHeadlessFKsTest(fixtures.TestBase):
__only_on__ = "sqlite"
- def setup(self):
+ def setup_test(self):
exec_sql(testing.db, "CREATE TABLE a (id INTEGER PRIMARY KEY)")
# this syntax actually works on other DBs perhaps we'd want to add
# tests to test_reflection
@@ -1689,7 +1697,7 @@ class ReflectHeadlessFKsTest(fixtures.TestBase):
testing.db, "CREATE TABLE b (id INTEGER PRIMARY KEY REFERENCES a)"
)
- def teardown(self):
+ def teardown_test(self):
exec_sql(testing.db, "drop table b")
exec_sql(testing.db, "drop table a")
@@ -1728,7 +1736,7 @@ class ConstraintReflectionTest(fixtures.TestBase):
__only_on__ = "sqlite"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.begin() as conn:
conn.exec_driver_sql("CREATE TABLE a1 (id INTEGER PRIMARY KEY)")
@@ -1876,7 +1884,7 @@ class ConstraintReflectionTest(fixtures.TestBase):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
for name in [
"implicit_referrer_comp_fake",
@@ -2370,7 +2378,7 @@ class SavepointTest(fixtures.TablesTest):
@classmethod
def setup_bind(cls):
- engine = engines.testing_engine(options={"use_reaper": False})
+ engine = engines.testing_engine(options={"scope": "class"})
@event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
@@ -2579,7 +2587,7 @@ class TypeReflectionTest(fixtures.TestBase):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "sqlite"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py
index 396b48aa4..baa766d48 100644
--- a/test/engine/test_ddlevents.py
+++ b/test/engine/test_ddlevents.py
@@ -21,7 +21,7 @@ from sqlalchemy.testing.schema import Table
class DDLEventTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.bind = engines.mock_engine()
self.metadata = MetaData()
self.table = Table("t", self.metadata, Column("id", Integer))
@@ -374,7 +374,7 @@ class DDLEventTest(fixtures.TestBase):
class DDLExecutionTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.engine = engines.mock_engine()
self.metadata = MetaData()
self.users = Table(
diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py
index a18cf756b..0a2c9abe5 100644
--- a/test/engine/test_deprecations.py
+++ b/test/engine/test_deprecations.py
@@ -965,7 +965,7 @@ class TransactionTest(fixtures.TablesTest):
class HandleInvalidatedOnConnectTest(fixtures.TestBase):
__requires__ = ("sqlite",)
- def setUp(self):
+ def setup_test(self):
e = create_engine("sqlite://")
connection = Mock(get_server_version_info=Mock(return_value="5.0"))
@@ -1021,18 +1021,18 @@ def MockDBAPI(): # noqa
class PoolTestBase(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
pool.clear_managers()
self._teardown_conns = []
- def teardown(self):
+ def teardown_test(self):
for ref in self._teardown_conns:
conn = ref()
if conn:
conn.close()
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
pool.clear_managers()
def _queuepool_fixture(self, **kw):
@@ -1597,7 +1597,7 @@ class EngineEventsTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
Engine.dispatch._clear()
Engine._has_events = False
@@ -1650,6 +1650,7 @@ class EngineEventsTest(fixtures.TestBase):
event.listen(
engine, "before_cursor_execute", cursor_execute, retval=True
)
+
with testing.expect_deprecated(
r"The argument signature for the "
r"\"ConnectionEvents.before_execute\" event listener",
@@ -1676,11 +1677,12 @@ class EngineEventsTest(fixtures.TestBase):
r"The argument signature for the "
r"\"ConnectionEvents.after_execute\" event listener",
):
- e1.execute(select(1))
+ result = e1.execute(select(1))
+ result.close()
class DDLExecutionTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.engine = engines.mock_engine()
self.metadata = MetaData()
self.users = Table(
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index 21d4e06e0..a1e4ea218 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -43,7 +43,6 @@ from sqlalchemy.testing import is_not
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertsql import CompiledSQL
-from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing.mock import call
from sqlalchemy.testing.mock import Mock
from sqlalchemy.testing.mock import patch
@@ -94,13 +93,13 @@ class ExecuteTest(fixtures.TablesTest):
).default_from()
)
- conn = testing.db.connect()
- result = (
- conn.execution_options(no_parameters=True)
- .exec_driver_sql(stmt)
- .scalar()
- )
- eq_(result, "%")
+ with testing.db.connect() as conn:
+ result = (
+ conn.execution_options(no_parameters=True)
+ .exec_driver_sql(stmt)
+ .scalar()
+ )
+ eq_(result, "%")
def test_raw_positional_invalid(self, connection):
assert_raises_message(
@@ -261,16 +260,15 @@ class ExecuteTest(fixtures.TablesTest):
(4, "sally"),
]
- @testing.engines.close_open_connections
def test_exception_wrapping_dbapi(self):
- conn = testing.db.connect()
- # engine does not have exec_driver_sql
- assert_raises_message(
- tsa.exc.DBAPIError,
- r"not_a_valid_statement",
- conn.exec_driver_sql,
- "not_a_valid_statement",
- )
+ with testing.db.connect() as conn:
+ # engine does not have exec_driver_sql
+ assert_raises_message(
+ tsa.exc.DBAPIError,
+ r"not_a_valid_statement",
+ conn.exec_driver_sql,
+ "not_a_valid_statement",
+ )
@testing.requires.sqlite
def test_exception_wrapping_non_dbapi_error(self):
@@ -864,12 +862,10 @@ class CompiledCacheTest(fixtures.TestBase):
["sqlite", "mysql", "postgresql"],
"uses blob value that is problematic for some DBAPIs",
)
- @testing.provide_metadata
- def test_cache_noleak_on_statement_values(self, connection):
+ def test_cache_noleak_on_statement_values(self, metadata, connection):
# This is a non regression test for an object reference leak caused
# by the compiled_cache.
- metadata = self.metadata
photo = Table(
"photo",
metadata,
@@ -1040,7 +1036,19 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
__requires__ = ("schemas",)
__backend__ = True
- def test_create_table(self):
+ @testing.fixture
+ def plain_tables(self, metadata):
+ t1 = Table(
+ "t1", metadata, Column("x", Integer), schema=config.test_schema
+ )
+ t2 = Table(
+ "t2", metadata, Column("x", Integer), schema=config.test_schema
+ )
+ t3 = Table("t3", metadata, Column("x", Integer), schema=None)
+
+ return t1, t2, t3
+
+ def test_create_table(self, plain_tables, connection):
map_ = {
None: config.test_schema,
"foo": config.test_schema,
@@ -1052,18 +1060,16 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
- with self.sql_execution_asserter(config.db) as asserter:
- with config.db.begin() as conn, conn.execution_options(
- schema_translate_map=map_
- ) as conn:
+ with self.sql_execution_asserter(connection) as asserter:
+ conn = connection.execution_options(schema_translate_map=map_)
- t1.create(conn)
- t2.create(conn)
- t3.create(conn)
+ t1.create(conn)
+ t2.create(conn)
+ t3.create(conn)
- t3.drop(conn)
- t2.drop(conn)
- t1.drop(conn)
+ t3.drop(conn)
+ t2.drop(conn)
+ t1.drop(conn)
asserter.assert_(
CompiledSQL("CREATE TABLE [SCHEMA__none].t1 (x INTEGER)"),
@@ -1074,14 +1080,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
CompiledSQL("DROP TABLE [SCHEMA__none].t1"),
)
- def _fixture(self):
- metadata = self.metadata
- Table("t1", metadata, Column("x", Integer), schema=config.test_schema)
- Table("t2", metadata, Column("x", Integer), schema=config.test_schema)
- Table("t3", metadata, Column("x", Integer), schema=None)
- metadata.create_all(testing.db)
-
- def test_ddl_hastable(self):
+ def test_ddl_hastable(self, plain_tables, connection):
map_ = {
None: config.test_schema,
@@ -1094,27 +1093,28 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
Table("t2", metadata, Column("x", Integer), schema="foo")
Table("t3", metadata, Column("x", Integer), schema="bar")
- with config.db.begin() as conn:
- conn = conn.execution_options(schema_translate_map=map_)
- metadata.create_all(conn)
+ conn = connection.execution_options(schema_translate_map=map_)
+ metadata.create_all(conn)
- insp = inspect(config.db)
+ insp = inspect(connection)
is_true(insp.has_table("t1", schema=config.test_schema))
is_true(insp.has_table("t2", schema=config.test_schema))
is_true(insp.has_table("t3", schema=None))
- with config.db.begin() as conn:
- conn = conn.execution_options(schema_translate_map=map_)
- metadata.drop_all(conn)
+ conn = connection.execution_options(schema_translate_map=map_)
+
+ # if this test fails, the tables won't get dropped. so need a
+ # more robust fixture for this
+ metadata.drop_all(conn)
- insp = inspect(config.db)
+ insp = inspect(connection)
is_false(insp.has_table("t1", schema=config.test_schema))
is_false(insp.has_table("t2", schema=config.test_schema))
is_false(insp.has_table("t3", schema=None))
- @testing.provide_metadata
- def test_option_on_execute(self):
- self._fixture()
+ def test_option_on_execute(self, plain_tables, connection):
+ # provided by metadata fixture provided by plain_tables fixture
+ self.metadata.create_all(connection)
map_ = {
None: config.test_schema,
@@ -1127,61 +1127,54 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
- with self.sql_execution_asserter(config.db) as asserter:
- with config.db.begin() as conn:
+ with self.sql_execution_asserter(connection) as asserter:
+ conn = connection
+ execution_options = {"schema_translate_map": map_}
+ conn._execute_20(
+ t1.insert(), {"x": 1}, execution_options=execution_options
+ )
+ conn._execute_20(
+ t2.insert(), {"x": 1}, execution_options=execution_options
+ )
+ conn._execute_20(
+ t3.insert(), {"x": 1}, execution_options=execution_options
+ )
- execution_options = {"schema_translate_map": map_}
- conn._execute_20(
- t1.insert(), {"x": 1}, execution_options=execution_options
- )
- conn._execute_20(
- t2.insert(), {"x": 1}, execution_options=execution_options
- )
- conn._execute_20(
- t3.insert(), {"x": 1}, execution_options=execution_options
- )
+ conn._execute_20(
+ t1.update().values(x=1).where(t1.c.x == 1),
+ execution_options=execution_options,
+ )
+ conn._execute_20(
+ t2.update().values(x=2).where(t2.c.x == 1),
+ execution_options=execution_options,
+ )
+ conn._execute_20(
+ t3.update().values(x=3).where(t3.c.x == 1),
+ execution_options=execution_options,
+ )
+ eq_(
conn._execute_20(
- t1.update().values(x=1).where(t1.c.x == 1),
- execution_options=execution_options,
- )
+ select(t1.c.x), execution_options=execution_options
+ ).scalar(),
+ 1,
+ )
+ eq_(
conn._execute_20(
- t2.update().values(x=2).where(t2.c.x == 1),
- execution_options=execution_options,
- )
+ select(t2.c.x), execution_options=execution_options
+ ).scalar(),
+ 2,
+ )
+ eq_(
conn._execute_20(
- t3.update().values(x=3).where(t3.c.x == 1),
- execution_options=execution_options,
- )
-
- eq_(
- conn._execute_20(
- select(t1.c.x), execution_options=execution_options
- ).scalar(),
- 1,
- )
- eq_(
- conn._execute_20(
- select(t2.c.x), execution_options=execution_options
- ).scalar(),
- 2,
- )
- eq_(
- conn._execute_20(
- select(t3.c.x), execution_options=execution_options
- ).scalar(),
- 3,
- )
+ select(t3.c.x), execution_options=execution_options
+ ).scalar(),
+ 3,
+ )
- conn._execute_20(
- t1.delete(), execution_options=execution_options
- )
- conn._execute_20(
- t2.delete(), execution_options=execution_options
- )
- conn._execute_20(
- t3.delete(), execution_options=execution_options
- )
+ conn._execute_20(t1.delete(), execution_options=execution_options)
+ conn._execute_20(t2.delete(), execution_options=execution_options)
+ conn._execute_20(t3.delete(), execution_options=execution_options)
asserter.assert_(
CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"),
@@ -1207,9 +1200,9 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
CompiledSQL("DELETE FROM [SCHEMA_bar].t3"),
)
- @testing.provide_metadata
- def test_crud(self):
- self._fixture()
+ def test_crud(self, plain_tables, connection):
+ # provided by metadata fixture provided by plain_tables fixture
+ self.metadata.create_all(connection)
map_ = {
None: config.test_schema,
@@ -1222,26 +1215,24 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
- with self.sql_execution_asserter(config.db) as asserter:
- with config.db.begin() as conn, conn.execution_options(
- schema_translate_map=map_
- ) as conn:
+ with self.sql_execution_asserter(connection) as asserter:
+ conn = connection.execution_options(schema_translate_map=map_)
- conn.execute(t1.insert(), {"x": 1})
- conn.execute(t2.insert(), {"x": 1})
- conn.execute(t3.insert(), {"x": 1})
+ conn.execute(t1.insert(), {"x": 1})
+ conn.execute(t2.insert(), {"x": 1})
+ conn.execute(t3.insert(), {"x": 1})
- conn.execute(t1.update().values(x=1).where(t1.c.x == 1))
- conn.execute(t2.update().values(x=2).where(t2.c.x == 1))
- conn.execute(t3.update().values(x=3).where(t3.c.x == 1))
+ conn.execute(t1.update().values(x=1).where(t1.c.x == 1))
+ conn.execute(t2.update().values(x=2).where(t2.c.x == 1))
+ conn.execute(t3.update().values(x=3).where(t3.c.x == 1))
- eq_(conn.scalar(select(t1.c.x)), 1)
- eq_(conn.scalar(select(t2.c.x)), 2)
- eq_(conn.scalar(select(t3.c.x)), 3)
+ eq_(conn.scalar(select(t1.c.x)), 1)
+ eq_(conn.scalar(select(t2.c.x)), 2)
+ eq_(conn.scalar(select(t3.c.x)), 3)
- conn.execute(t1.delete())
- conn.execute(t2.delete())
- conn.execute(t3.delete())
+ conn.execute(t1.delete())
+ conn.execute(t2.delete())
+ conn.execute(t3.delete())
asserter.assert_(
CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"),
@@ -1267,9 +1258,10 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
CompiledSQL("DELETE FROM [SCHEMA_bar].t3"),
)
- @testing.provide_metadata
- def test_via_engine(self):
- self._fixture()
+ def test_via_engine(self, plain_tables, metadata):
+
+ with config.db.begin() as connection:
+ metadata.create_all(connection)
map_ = {
None: config.test_schema,
@@ -1282,25 +1274,25 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
with self.sql_execution_asserter(config.db) as asserter:
eng = config.db.execution_options(schema_translate_map=map_)
- conn = eng.connect()
- conn.execute(select(t2.c.x))
+ with eng.connect() as conn:
+ conn.execute(select(t2.c.x))
asserter.assert_(
CompiledSQL("SELECT [SCHEMA_foo].t2.x FROM [SCHEMA_foo].t2")
)
class ExecutionOptionsTest(fixtures.TestBase):
- def test_dialect_conn_options(self):
+ def test_dialect_conn_options(self, testing_engine):
engine = testing_engine("sqlite://", options=dict(_initialize=False))
engine.dialect = Mock()
- conn = engine.connect()
- c2 = conn.execution_options(foo="bar")
- eq_(
- engine.dialect.set_connection_execution_options.mock_calls,
- [call(c2, {"foo": "bar"})],
- )
+ with engine.connect() as conn:
+ c2 = conn.execution_options(foo="bar")
+ eq_(
+ engine.dialect.set_connection_execution_options.mock_calls,
+ [call(c2, {"foo": "bar"})],
+ )
- def test_dialect_engine_options(self):
+ def test_dialect_engine_options(self, testing_engine):
engine = testing_engine("sqlite://")
engine.dialect = Mock()
e2 = engine.execution_options(foo="bar")
@@ -1319,14 +1311,14 @@ class ExecutionOptionsTest(fixtures.TestBase):
[call(engine, {"foo": "bar"})],
)
- def test_propagate_engine_to_connection(self):
+ def test_propagate_engine_to_connection(self, testing_engine):
engine = testing_engine(
"sqlite://", options=dict(execution_options={"foo": "bar"})
)
- conn = engine.connect()
- eq_(conn._execution_options, {"foo": "bar"})
+ with engine.connect() as conn:
+ eq_(conn._execution_options, {"foo": "bar"})
- def test_propagate_option_engine_to_connection(self):
+ def test_propagate_option_engine_to_connection(self, testing_engine):
e1 = testing_engine(
"sqlite://", options=dict(execution_options={"foo": "bar"})
)
@@ -1336,27 +1328,30 @@ class ExecutionOptionsTest(fixtures.TestBase):
eq_(c1._execution_options, {"foo": "bar"})
eq_(c2._execution_options, {"foo": "bar", "bat": "hoho"})
- def test_get_engine_execution_options(self):
+ c1.close()
+ c2.close()
+
+ def test_get_engine_execution_options(self, testing_engine):
engine = testing_engine("sqlite://")
engine.dialect = Mock()
e2 = engine.execution_options(foo="bar")
eq_(e2.get_execution_options(), {"foo": "bar"})
- def test_get_connection_execution_options(self):
+ def test_get_connection_execution_options(self, testing_engine):
engine = testing_engine("sqlite://", options=dict(_initialize=False))
engine.dialect = Mock()
- conn = engine.connect()
- c = conn.execution_options(foo="bar")
+ with engine.connect() as conn:
+ c = conn.execution_options(foo="bar")
- eq_(c.get_execution_options(), {"foo": "bar"})
+ eq_(c.get_execution_options(), {"foo": "bar"})
class EngineEventsTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
Engine.dispatch._clear()
Engine._has_events = False
@@ -1376,7 +1371,7 @@ class EngineEventsTest(fixtures.TestBase):
):
break
- def test_per_engine_independence(self):
+ def test_per_engine_independence(self, testing_engine):
e1 = testing_engine(config.db_url)
e2 = testing_engine(config.db_url)
@@ -1400,7 +1395,7 @@ class EngineEventsTest(fixtures.TestBase):
conn.execute(s2)
eq_([arg[1][1] for arg in canary.mock_calls], [s1, s1, s2])
- def test_per_engine_plus_global(self):
+ def test_per_engine_plus_global(self, testing_engine):
canary = Mock()
event.listen(Engine, "before_execute", canary.be1)
e1 = testing_engine(config.db_url)
@@ -1409,8 +1404,6 @@ class EngineEventsTest(fixtures.TestBase):
event.listen(e1, "before_execute", canary.be2)
event.listen(Engine, "before_execute", canary.be3)
- e1.connect()
- e2.connect()
with e1.connect() as conn:
conn.execute(select(1))
@@ -1424,7 +1417,7 @@ class EngineEventsTest(fixtures.TestBase):
eq_(canary.be2.call_count, 1)
eq_(canary.be3.call_count, 2)
- def test_per_connection_plus_engine(self):
+ def test_per_connection_plus_engine(self, testing_engine):
canary = Mock()
e1 = testing_engine(config.db_url)
@@ -1442,9 +1435,14 @@ class EngineEventsTest(fixtures.TestBase):
eq_(canary.be1.call_count, 2)
eq_(canary.be2.call_count, 2)
- @testing.combinations((True, False), (True, True), (False, False))
+ @testing.combinations(
+ (True, False),
+ (True, True),
+ (False, False),
+ argnames="mock_out_on_connect, add_our_own_onconnect",
+ )
def test_insert_connect_is_definitely_first(
- self, mock_out_on_connect, add_our_own_onconnect
+ self, mock_out_on_connect, add_our_own_onconnect, testing_engine
):
"""test issue #5708.
@@ -1478,7 +1476,7 @@ class EngineEventsTest(fixtures.TestBase):
patcher = util.nullcontext()
with patcher:
- e1 = create_engine(config.db_url)
+ e1 = testing_engine(config.db_url)
initialize = e1.dialect.initialize
@@ -1559,10 +1557,11 @@ class EngineEventsTest(fixtures.TestBase):
conn.exec_driver_sql(select1(testing.db))
eq_(m1.mock_calls, [])
- def test_add_event_after_connect(self):
+ def test_add_event_after_connect(self, testing_engine):
# new feature as of #2978
+
canary = Mock()
- e1 = create_engine(config.db_url)
+ e1 = testing_engine(config.db_url, future=False)
assert not e1._has_events
conn = e1.connect()
@@ -1575,9 +1574,9 @@ class EngineEventsTest(fixtures.TestBase):
conn._branch().execute(select(1))
eq_(canary.be1.call_count, 2)
- def test_force_conn_events_false(self):
+ def test_force_conn_events_false(self, testing_engine):
canary = Mock()
- e1 = create_engine(config.db_url)
+ e1 = testing_engine(config.db_url, future=False)
assert not e1._has_events
event.listen(e1, "before_execute", canary.be1)
@@ -1593,7 +1592,7 @@ class EngineEventsTest(fixtures.TestBase):
conn._branch().execute(select(1))
eq_(canary.be1.call_count, 0)
- def test_cursor_events_ctx_execute_scalar(self):
+ def test_cursor_events_ctx_execute_scalar(self, testing_engine):
canary = Mock()
e1 = testing_engine(config.db_url)
@@ -1620,7 +1619,7 @@ class EngineEventsTest(fixtures.TestBase):
[call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
)
- def test_cursor_events_execute(self):
+ def test_cursor_events_execute(self, testing_engine):
canary = Mock()
e1 = testing_engine(config.db_url)
@@ -1653,9 +1652,15 @@ class EngineEventsTest(fixtures.TestBase):
),
((), {"z": 10}, [], {"z": 10}, testing.requires.legacy_engine),
(({"z": 10},), {}, [], {"z": 10}),
+ argnames="multiparams, params, expected_multiparams, expected_params",
)
def test_modify_parameters_from_event_one(
- self, multiparams, params, expected_multiparams, expected_params
+ self,
+ multiparams,
+ params,
+ expected_multiparams,
+ expected_params,
+ testing_engine,
):
# this is testing both the normalization added to parameters
# as of I97cb4d06adfcc6b889f10d01cc7775925cffb116 as well as
@@ -1704,7 +1709,9 @@ class EngineEventsTest(fixtures.TestBase):
[(15,), (19,)],
)
- def test_modify_parameters_from_event_three(self, connection):
+ def test_modify_parameters_from_event_three(
+ self, connection, testing_engine
+ ):
def before_execute(
conn, clauseelement, multiparams, params, execution_options
):
@@ -1721,7 +1728,7 @@ class EngineEventsTest(fixtures.TestBase):
with e1.connect() as conn:
conn.execute(select(literal("1")))
- def test_argument_format_execute(self):
+ def test_argument_format_execute(self, testing_engine):
def before_execute(
conn, clauseelement, multiparams, params, execution_options
):
@@ -1956,9 +1963,9 @@ class EngineEventsTest(fixtures.TestBase):
)
@testing.requires.ad_hoc_engines
- def test_dispose_event(self):
+ def test_dispose_event(self, testing_engine):
canary = Mock()
- eng = create_engine(testing.db.url)
+ eng = testing_engine(testing.db.url)
event.listen(eng, "engine_disposed", canary)
conn = eng.connect()
@@ -2102,13 +2109,13 @@ class EngineEventsTest(fixtures.TestBase):
event.listen(engine, "commit", tracker("commit"))
event.listen(engine, "rollback", tracker("rollback"))
- conn = engine.connect()
- trans = conn.begin()
- conn.execute(select(1))
- trans.rollback()
- trans = conn.begin()
- conn.execute(select(1))
- trans.commit()
+ with engine.connect() as conn:
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.rollback()
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.commit()
eq_(
canary,
@@ -2145,13 +2152,13 @@ class EngineEventsTest(fixtures.TestBase):
event.listen(engine, "commit", tracker("commit"), named=True)
event.listen(engine, "rollback", tracker("rollback"), named=True)
- conn = engine.connect()
- trans = conn.begin()
- conn.execute(select(1))
- trans.rollback()
- trans = conn.begin()
- conn.execute(select(1))
- trans.commit()
+ with engine.connect() as conn:
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.rollback()
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.commit()
eq_(
canary,
@@ -2310,7 +2317,7 @@ class HandleErrorTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
Engine.dispatch._clear()
Engine._has_events = False
@@ -2742,7 +2749,7 @@ class HandleErrorTest(fixtures.TestBase):
class HandleInvalidatedOnConnectTest(fixtures.TestBase):
__requires__ = ("sqlite",)
- def setUp(self):
+ def setup_test(self):
e = create_engine("sqlite://")
connection = Mock(get_server_version_info=Mock(return_value="5.0"))
@@ -3014,6 +3021,9 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase):
],
)
+ c.close()
+ c2.close()
+
class DialectEventTest(fixtures.TestBase):
@contextmanager
@@ -3370,7 +3380,7 @@ class SetInputSizesTest(fixtures.TablesTest):
)
@testing.fixture
- def input_sizes_fixture(self):
+ def input_sizes_fixture(self, testing_engine):
canary = mock.Mock()
def do_set_input_sizes(cursor, list_of_tuples, context):
diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py
index 29b8132aa..c56589248 100644
--- a/test/engine/test_logging.py
+++ b/test/engine/test_logging.py
@@ -30,7 +30,7 @@ class LogParamsTest(fixtures.TestBase):
__only_on__ = "sqlite"
__requires__ = ("ad_hoc_engines",)
- def setup(self):
+ def setup_test(self):
self.eng = engines.testing_engine(options={"echo": True})
self.no_param_engine = engines.testing_engine(
options={"echo": True, "hide_parameters": True}
@@ -44,7 +44,7 @@ class LogParamsTest(fixtures.TestBase):
for log in [logging.getLogger("sqlalchemy.engine")]:
log.addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
exec_sql(self.eng, "drop table if exists foo")
for log in [logging.getLogger("sqlalchemy.engine")]:
log.removeHandler(self.buf)
@@ -413,14 +413,14 @@ class LogParamsTest(fixtures.TestBase):
class PoolLoggingTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.existing_level = logging.getLogger("sqlalchemy.pool").level
self.buf = logging.handlers.BufferingHandler(100)
for log in [logging.getLogger("sqlalchemy.pool")]:
log.addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
for log in [logging.getLogger("sqlalchemy.pool")]:
log.removeHandler(self.buf)
logging.getLogger("sqlalchemy.pool").setLevel(self.existing_level)
@@ -528,7 +528,7 @@ class LoggingNameTest(fixtures.TestBase):
kw.update({"echo": True})
return engines.testing_engine(options=kw)
- def setup(self):
+ def setup_test(self):
self.buf = logging.handlers.BufferingHandler(100)
for log in [
logging.getLogger("sqlalchemy.engine"),
@@ -536,7 +536,7 @@ class LoggingNameTest(fixtures.TestBase):
]:
log.addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
for log in [
logging.getLogger("sqlalchemy.engine"),
logging.getLogger("sqlalchemy.pool"),
@@ -588,13 +588,13 @@ class LoggingNameTest(fixtures.TestBase):
class EchoTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
- def setup(self):
+ def setup_test(self):
self.level = logging.getLogger("sqlalchemy.engine").level
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARN)
self.buf = logging.handlers.BufferingHandler(100)
logging.getLogger("sqlalchemy.engine").addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
logging.getLogger("sqlalchemy.engine").removeHandler(self.buf)
logging.getLogger("sqlalchemy.engine").setLevel(self.level)
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py
index 550fedb8e..decdce3f9 100644
--- a/test/engine/test_pool.py
+++ b/test/engine/test_pool.py
@@ -17,7 +17,9 @@ from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_none
from sqlalchemy.testing import is_not
+from sqlalchemy.testing import is_not_none
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing.engines import testing_engine
@@ -63,18 +65,18 @@ def MockDBAPI(): # noqa
class PoolTestBase(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
pool.clear_managers()
self._teardown_conns = []
- def teardown(self):
+ def teardown_test(self):
for ref in self._teardown_conns:
conn = ref()
if conn:
conn.close()
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
pool.clear_managers()
def _with_teardown(self, connection):
@@ -364,10 +366,17 @@ class PoolEventsTest(PoolTestBase):
p = self._queuepool_fixture()
canary = []
+ @event.listens_for(p, "checkin")
def checkin(*arg, **kw):
canary.append("checkin")
- event.listen(p, "checkin", checkin)
+ @event.listens_for(p, "close_detached")
+ def close_detached(*arg, **kw):
+ canary.append("close_detached")
+
+ @event.listens_for(p, "detach")
+ def detach(*arg, **kw):
+ canary.append("detach")
return p, canary
@@ -629,15 +638,35 @@ class PoolEventsTest(PoolTestBase):
assert canary.call_args_list[0][0][0] is dbapi_con
assert canary.call_args_list[0][0][2] is exc
+ @testing.combinations((True, testing.requires.python3), (False,))
@testing.requires.predictable_gc
- def test_checkin_event_gc(self):
+ def test_checkin_event_gc(self, detach_gced):
p, canary = self._checkin_event_fixture()
+ if detach_gced:
+ p._is_asyncio = True
+
c1 = p.connect()
+
+ dbapi_connection = weakref.ref(c1.connection)
+
eq_(canary, [])
del c1
lazy_gc()
- eq_(canary, ["checkin"])
+
+ if detach_gced:
+ # "close_detached" is not called because for asyncio the
+ # connection is just lost.
+ eq_(canary, ["detach"])
+
+ else:
+ eq_(canary, ["checkin"])
+
+ gc_collect()
+ if detach_gced:
+ is_none(dbapi_connection())
+ else:
+ is_not_none(dbapi_connection())
def test_checkin_event_on_subsequently_recreated(self):
p, canary = self._checkin_event_fixture()
@@ -744,7 +773,7 @@ class PoolEventsTest(PoolTestBase):
eq_(conn.info["important_flag"], True)
conn.close()
- def teardown(self):
+ def teardown_test(self):
# TODO: need to get remove() functionality
# going
pool.Pool.dispatch._clear()
@@ -1490,12 +1519,16 @@ class QueuePoolTest(PoolTestBase):
self._assert_cleanup_on_pooled_reconnect(dbapi, p)
+ @testing.combinations((True, testing.requires.python3), (False,))
@testing.requires.predictable_gc
- def test_userspace_disconnectionerror_weakref_finalizer(self):
+ def test_userspace_disconnectionerror_weakref_finalizer(self, detach_gced):
dbapi, pool = self._queuepool_dbapi_fixture(
pool_size=1, max_overflow=2
)
+ if detach_gced:
+ pool._is_asyncio = True
+
@event.listens_for(pool, "checkout")
def handle_checkout_event(dbapi_con, con_record, con_proxy):
if getattr(dbapi_con, "boom") == "yes":
@@ -1514,8 +1547,12 @@ class QueuePoolTest(PoolTestBase):
del conn
gc_collect()
- # new connection was reset on return appropriately
- eq_(dbapi_conn.mock_calls, [call.rollback()])
+ if detach_gced:
+ # new connection was detached + abandoned on return
+ eq_(dbapi_conn.mock_calls, [])
+ else:
+ # new connection reset and returned to pool
+ eq_(dbapi_conn.mock_calls, [call.rollback()])
# old connection was just closed - did not get an
# erroneous reset on return
diff --git a/test/engine/test_processors.py b/test/engine/test_processors.py
index 3810de06a..5a4220c82 100644
--- a/test/engine/test_processors.py
+++ b/test/engine/test_processors.py
@@ -25,7 +25,7 @@ class CBooleanProcessorTest(_BooleanProcessorTest):
__requires__ = ("cextensions",)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import cprocessors
cls.module = cprocessors
@@ -83,7 +83,7 @@ class _DateProcessorTest(fixtures.TestBase):
class PyDateProcessorTest(_DateProcessorTest):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import processors
cls.module = type(
@@ -100,7 +100,7 @@ class CDateProcessorTest(_DateProcessorTest):
__requires__ = ("cextensions",)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import cprocessors
cls.module = cprocessors
@@ -185,7 +185,7 @@ class _DistillArgsTest(fixtures.TestBase):
class PyDistillArgsTest(_DistillArgsTest):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy.engine import util
cls.module = type(
@@ -202,7 +202,7 @@ class CDistillArgsTest(_DistillArgsTest):
__requires__ = ("cextensions",)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import cutils as util
cls.module = util
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py
index 5fe7f6cc2..7a64b2550 100644
--- a/test/engine/test_reconnect.py
+++ b/test/engine/test_reconnect.py
@@ -162,7 +162,7 @@ def MockDBAPI():
class PrePingMockTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.dbapi = MockDBAPI()
def _pool_fixture(self, pre_ping, pool_kw=None):
@@ -182,7 +182,7 @@ class PrePingMockTest(fixtures.TestBase):
)
return _pool
- def teardown(self):
+ def teardown_test(self):
self.dbapi.dispose()
def test_ping_not_on_first_connect(self):
@@ -357,7 +357,7 @@ class PrePingMockTest(fixtures.TestBase):
class MockReconnectTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.dbapi = MockDBAPI()
self.db = testing_engine(
@@ -373,7 +373,7 @@ class MockReconnectTest(fixtures.TestBase):
e, MockDisconnect
)
- def teardown(self):
+ def teardown_test(self):
self.dbapi.dispose()
def test_reconnect(self):
@@ -1004,10 +1004,10 @@ class RealReconnectTest(fixtures.TestBase):
__backend__ = True
__requires__ = "graceful_disconnects", "ad_hoc_engines"
- def setup(self):
+ def setup_test(self):
self.engine = engines.reconnecting_engine()
- def teardown(self):
+ def teardown_test(self):
self.engine.dispose()
def test_reconnect(self):
@@ -1336,7 +1336,7 @@ class PrePingRealTest(fixtures.TestBase):
class InvalidateDuringResultTest(fixtures.TestBase):
__backend__ = True
- def setup(self):
+ def setup_test(self):
self.engine = engines.reconnecting_engine()
self.meta = MetaData()
table = Table(
@@ -1353,7 +1353,7 @@ class InvalidateDuringResultTest(fixtures.TestBase):
[{"id": i, "name": "row %d" % i} for i in range(1, 100)],
)
- def teardown(self):
+ def teardown_test(self):
with self.engine.begin() as conn:
self.meta.drop_all(conn)
self.engine.dispose()
@@ -1470,7 +1470,7 @@ class ReconnectRecipeTest(fixtures.TestBase):
__backend__ = True
- def setup(self):
+ def setup_test(self):
self.engine = engines.reconnecting_engine(
options=dict(future=self.future)
)
@@ -1483,7 +1483,7 @@ class ReconnectRecipeTest(fixtures.TestBase):
)
self.meta.create_all(self.engine)
- def teardown(self):
+ def teardown_test(self):
self.meta.drop_all(self.engine)
self.engine.dispose()
diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py
index 658cdd79f..0a46ddeec 100644
--- a/test/engine/test_reflection.py
+++ b/test/engine/test_reflection.py
@@ -796,7 +796,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables):
assert f1 in b1.constraints
assert len(b1.constraints) == 2
- def test_override_keys(self, connection, metadata):
+ def test_override_keys(self, metadata, connection):
"""test that columns can be overridden with a 'key',
and that ForeignKey targeting during reflection still works."""
@@ -1375,7 +1375,7 @@ class CreateDropTest(fixtures.TablesTest):
run_create_tables = None
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
# TablesTest is used here without
# run_create_tables, so add an explicit drop of whatever is in
# metadata
@@ -1658,7 +1658,6 @@ class SchemaTest(fixtures.TestBase):
@testing.requires.schemas
@testing.requires.cross_schema_fk_reflection
@testing.requires.implicit_default_schema
- @testing.provide_metadata
def test_blank_schema_arg(self, connection, metadata):
Table(
@@ -1913,7 +1912,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL):
__backend__ = True
@testing.requires.denormalized_names
- def setup(self):
+ def setup_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql(
"""
@@ -1926,7 +1925,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL):
)
@testing.requires.denormalized_names
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop table weird_casing")
diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py
index 79126fc5b..47504b60a 100644
--- a/test/engine/test_transaction.py
+++ b/test/engine/test_transaction.py
@@ -1,6 +1,5 @@
import sys
-from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import func
@@ -640,12 +639,12 @@ class AutoRollbackTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global metadata
metadata = MetaData()
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
metadata.drop_all(testing.db)
def test_rollback_deadlock(self):
@@ -871,11 +870,13 @@ class IsolationLevelTest(fixtures.TestBase):
def test_per_engine(self):
# new in 0.9
- eng = create_engine(
+ eng = testing_engine(
testing.db.url,
- execution_options={
- "isolation_level": self._non_default_isolation_level()
- },
+ options=dict(
+ execution_options={
+ "isolation_level": self._non_default_isolation_level()
+ }
+ ),
)
conn = eng.connect()
eq_(
@@ -884,7 +885,7 @@ class IsolationLevelTest(fixtures.TestBase):
)
def test_per_option_engine(self):
- eng = create_engine(testing.db.url).execution_options(
+ eng = testing_engine(testing.db.url).execution_options(
isolation_level=self._non_default_isolation_level()
)
@@ -895,14 +896,14 @@ class IsolationLevelTest(fixtures.TestBase):
)
def test_isolation_level_accessors_connection_default(self):
- eng = create_engine(testing.db.url)
+ eng = testing_engine(testing.db.url)
with eng.connect() as conn:
eq_(conn.default_isolation_level, self._default_isolation_level())
with eng.connect() as conn:
eq_(conn.get_isolation_level(), self._default_isolation_level())
def test_isolation_level_accessors_connection_option_modified(self):
- eng = create_engine(testing.db.url)
+ eng = testing_engine(testing.db.url)
with eng.connect() as conn:
c2 = conn.execution_options(
isolation_level=self._non_default_isolation_level()
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py
index 7dae1411e..59a44f8e2 100644
--- a/test/ext/asyncio/test_engine_py3k.py
+++ b/test/ext/asyncio/test_engine_py3k.py
@@ -269,7 +269,7 @@ class AsyncEngineTest(EngineFixture):
await trans.rollback(),
@async_test
- async def test_pool_exhausted(self, async_engine):
+ async def test_pool_exhausted_some_timeout(self, async_engine):
engine = create_async_engine(
testing.db.url,
pool_size=1,
@@ -277,7 +277,19 @@ class AsyncEngineTest(EngineFixture):
pool_timeout=0.1,
)
async with engine.connect():
- with expect_raises(asyncio.TimeoutError):
+ with expect_raises(exc.TimeoutError):
+ await engine.connect()
+
+ @async_test
+ async def test_pool_exhausted_no_timeout(self, async_engine):
+ engine = create_async_engine(
+ testing.db.url,
+ pool_size=1,
+ max_overflow=0,
+ pool_timeout=0,
+ )
+ async with engine.connect():
+ with expect_raises(exc.TimeoutError):
await engine.connect()
@async_test
diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py
index 2b80b753e..e25e7cfc2 100644
--- a/test/ext/declarative/test_inheritance.py
+++ b/test/ext/declarative/test_inheritance.py
@@ -27,11 +27,11 @@ Base = None
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
- def setup(self):
+ def setup_test(self):
global Base
Base = decl.declarative_base(testing.db)
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
Base.metadata.drop_all(testing.db)
diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py
index d7fcbf9e8..c327de7d4 100644
--- a/test/ext/declarative/test_reflection.py
+++ b/test/ext/declarative/test_reflection.py
@@ -4,7 +4,6 @@ from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.ext.declarative import DeferredReflection
from sqlalchemy.orm import clear_mappers
-from sqlalchemy.orm import create_session
from sqlalchemy.orm import decl_api as decl
from sqlalchemy.orm import declared_attr
from sqlalchemy.orm import exc as orm_exc
@@ -14,6 +13,7 @@ from sqlalchemy.orm.decl_base import _DeferredMapperConfig
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import gc_collect
@@ -22,20 +22,19 @@ from sqlalchemy.testing.util import gc_collect
class DeclarativeReflectionBase(fixtures.TablesTest):
__requires__ = ("reflectable_autoincrement",)
- def setup(self):
+ def setup_test(self):
global Base, registry
registry = decl.registry()
Base = registry.generate_base()
- def teardown(self):
- super(DeclarativeReflectionBase, self).teardown()
+ def teardown_test(self):
clear_mappers()
class DeferredReflectBase(DeclarativeReflectionBase):
- def teardown(self):
- super(DeferredReflectBase, self).teardown()
+ def teardown_test(self):
+ super(DeferredReflectBase, self).teardown_test()
_DeferredMapperConfig._configs.clear()
@@ -101,22 +100,23 @@ class DeferredReflectionTest(DeferredReflectBase):
u1 = User(
name="u1", addresses=[Address(email="one"), Address(email="two")]
)
- sess = create_session(testing.db)
- sess.add(u1)
- sess.flush()
- sess.expunge_all()
- eq_(
- sess.query(User).all(),
- [
- User(
- name="u1",
- addresses=[Address(email="one"), Address(email="two")],
- )
- ],
- )
- a1 = sess.query(Address).filter(Address.email == "two").one()
- eq_(a1, Address(email="two"))
- eq_(a1.user, User(name="u1"))
+ with fixture_session() as sess:
+ sess.add(u1)
+ sess.commit()
+
+ with fixture_session() as sess:
+ eq_(
+ sess.query(User).all(),
+ [
+ User(
+ name="u1",
+ addresses=[Address(email="one"), Address(email="two")],
+ )
+ ],
+ )
+ a1 = sess.query(Address).filter(Address.email == "two").one()
+ eq_(a1, Address(email="two"))
+ eq_(a1.user, User(name="u1"))
def test_exception_prepare_not_called(self):
class User(DeferredReflection, fixtures.ComparableEntity, Base):
@@ -191,15 +191,25 @@ class DeferredReflectionTest(DeferredReflectBase):
return {"primary_key": cls.__table__.c.id}
DeferredReflection.prepare(testing.db)
- sess = Session(testing.db)
- sess.add_all(
- [User(name="G"), User(name="Q"), User(name="A"), User(name="C")]
- )
- sess.commit()
- eq_(
- sess.query(User).order_by(User.name).all(),
- [User(name="A"), User(name="C"), User(name="G"), User(name="Q")],
- )
+ with fixture_session() as sess:
+ sess.add_all(
+ [
+ User(name="G"),
+ User(name="Q"),
+ User(name="A"),
+ User(name="C"),
+ ]
+ )
+ sess.commit()
+ eq_(
+ sess.query(User).order_by(User.name).all(),
+ [
+ User(name="A"),
+ User(name="C"),
+ User(name="G"),
+ User(name="Q"),
+ ],
+ )
@testing.requires.predictable_gc
def test_cls_not_strong_ref(self):
@@ -255,14 +265,14 @@ class DeferredSecondaryReflectionTest(DeferredReflectBase):
u1 = User(name="u1", items=[Item(name="i1"), Item(name="i2")])
- sess = Session(testing.db)
- sess.add(u1)
- sess.commit()
+ with fixture_session() as sess:
+ sess.add(u1)
+ sess.commit()
- eq_(
- sess.query(User).all(),
- [User(name="u1", items=[Item(name="i1"), Item(name="i2")])],
- )
+ eq_(
+ sess.query(User).all(),
+ [User(name="u1", items=[Item(name="i1"), Item(name="i2")])],
+ )
def test_string_resolution(self):
class User(DeferredReflection, fixtures.ComparableEntity, Base):
@@ -296,27 +306,26 @@ class DeferredInhReflectBase(DeferredReflectBase):
Foo = Base.registry._class_registry["Foo"]
Bar = Base.registry._class_registry["Bar"]
- s = Session(testing.db)
-
- s.add_all(
- [
- Bar(data="d1", bar_data="b1"),
- Bar(data="d2", bar_data="b2"),
- Bar(data="d3", bar_data="b3"),
- Foo(data="d4"),
- ]
- )
- s.commit()
-
- eq_(
- s.query(Foo).order_by(Foo.id).all(),
- [
- Bar(data="d1", bar_data="b1"),
- Bar(data="d2", bar_data="b2"),
- Bar(data="d3", bar_data="b3"),
- Foo(data="d4"),
- ],
- )
+ with fixture_session() as s:
+ s.add_all(
+ [
+ Bar(data="d1", bar_data="b1"),
+ Bar(data="d2", bar_data="b2"),
+ Bar(data="d3", bar_data="b3"),
+ Foo(data="d4"),
+ ]
+ )
+ s.commit()
+
+ eq_(
+ s.query(Foo).order_by(Foo.id).all(),
+ [
+ Bar(data="d1", bar_data="b1"),
+ Bar(data="d2", bar_data="b2"),
+ Bar(data="d3", bar_data="b3"),
+ Foo(data="d4"),
+ ],
+ )
class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py
index b1f5cc956..31ae050c1 100644
--- a/test/ext/test_associationproxy.py
+++ b/test/ext/test_associationproxy.py
@@ -101,7 +101,7 @@ class AutoFlushTest(fixtures.TablesTest):
Column("name", String(50)),
)
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def _fixture(self, collection_class, is_dict=False):
@@ -198,7 +198,7 @@ class AutoFlushTest(fixtures.TablesTest):
class _CollectionOperations(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
collection_class = self.collection_class
metadata = MetaData()
@@ -260,7 +260,7 @@ class _CollectionOperations(fixtures.TestBase):
self.session = fixture_session()
self.Parent, self.Child = Parent, Child
- def teardown(self):
+ def teardown_test(self):
self.metadata.drop_all(testing.db)
def roundtrip(self, obj):
@@ -885,7 +885,7 @@ class CustomObjectTest(_CollectionOperations):
class ProxyFactoryTest(ListTest):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
parents_table = Table(
@@ -1157,7 +1157,7 @@ class ScalarTest(fixtures.TestBase):
class LazyLoadTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
parents_table = Table(
@@ -1197,7 +1197,7 @@ class LazyLoadTest(fixtures.TestBase):
self.Parent, self.Child = Parent, Child
self.table = parents_table
- def teardown(self):
+ def teardown_test(self):
self.metadata.drop_all(testing.db)
def roundtrip(self, obj):
@@ -2294,7 +2294,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
class DictOfTupleUpdateTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
class B(object):
def __init__(self, key, elem):
self.key = key
@@ -2434,7 +2434,7 @@ class CompositeAccessTest(fixtures.DeclarativeMappedTest):
class AttributeAccessTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def test_resolve_aliased_class(self):
diff --git a/test/ext/test_baked.py b/test/ext/test_baked.py
index 71fabc629..2d4e9848e 100644
--- a/test/ext/test_baked.py
+++ b/test/ext/test_baked.py
@@ -27,7 +27,7 @@ class BakedTest(_fixtures.FixtureTest):
run_inserts = "once"
run_deletes = None
- def setup(self):
+ def setup_test(self):
self.bakery = baked.bakery()
diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py
index 058c1dfd7..d011417d7 100644
--- a/test/ext/test_compiler.py
+++ b/test/ext/test_compiler.py
@@ -426,7 +426,7 @@ class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
- def teardown(self):
+ def teardown_test(self):
for cls in (Select, BindParameter):
deregister(cls)
diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py
index ad9bf0bc0..f3eceb0dc 100644
--- a/test/ext/test_extendedattr.py
+++ b/test/ext/test_extendedattr.py
@@ -32,7 +32,7 @@ def modifies_instrumentation_finders(fn, *args, **kw):
class _ExtBase(object):
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
instrumentation._reinstall_default_lookups()
@@ -89,7 +89,7 @@ MyBaseClass, MyClass = None, None
class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global MyBaseClass, MyClass
class MyBaseClass(object):
@@ -143,7 +143,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
else:
del self._goofy_dict[key]
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def test_instance_dict(self):
diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py
index 038bdd83e..bb06d9648 100644
--- a/test/ext/test_horizontal_shard.py
+++ b/test/ext/test_horizontal_shard.py
@@ -19,7 +19,6 @@ from sqlalchemy import update
from sqlalchemy import util
from sqlalchemy.ext.horizontal_shard import ShardedSession
from sqlalchemy.orm import clear_mappers
-from sqlalchemy.orm import create_session
from sqlalchemy.orm import deferred
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
@@ -42,7 +41,7 @@ class ShardTest(object):
schema = None
- def setUp(self):
+ def setup_test(self):
global db1, db2, db3, db4, weather_locations, weather_reports
db1, db2, db3, db4 = self._dbs = self._init_dbs()
@@ -88,7 +87,7 @@ class ShardTest(object):
@classmethod
def setup_session(cls):
- global create_session
+ global sharded_session
shard_lookup = {
"North America": "north_america",
"Asia": "asia",
@@ -128,10 +127,10 @@ class ShardTest(object):
else:
return ids
- create_session = sessionmaker(
+ sharded_session = sessionmaker(
class_=ShardedSession, autoflush=True, autocommit=False
)
- create_session.configure(
+ sharded_session.configure(
shards={
"north_america": db1,
"asia": db2,
@@ -180,7 +179,7 @@ class ShardTest(object):
tokyo.reports.append(Report(80.0, id_=1))
newyork.reports.append(Report(75, id_=1))
quito.reports.append(Report(85))
- sess = create_session(future=True)
+ sess = sharded_session(future=True)
for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
sess.add(c)
sess.flush()
@@ -671,11 +670,10 @@ class DistinctEngineShardTest(ShardTest, fixtures.TestBase):
self.dbs = [db1, db2, db3, db4]
return self.dbs
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
- for db in self.dbs:
- db.connect().invalidate()
+ testing_reaper.checkin_all()
for i in range(1, 5):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
@@ -702,10 +700,10 @@ class AttachedFileShardTest(ShardTest, fixtures.TestBase):
self.engine = e
return db1, db2, db3, db4
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
- self.engine.connect().invalidate()
+ testing_reaper.checkin_all()
for i in range(1, 5):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
@@ -778,10 +776,13 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase):
self.postgresql_engine = e2
return db1, db2, db3, db4
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
- self.sqlite_engine.connect().invalidate()
+ # the tests in this suite don't cleanly close out the Session
+ # at the moment so use the reaper to close all connections
+ testing_reaper.checkin_all()
+
for i in [1, 3]:
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
@@ -789,6 +790,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase):
self.tables_test_metadata.drop_all(conn)
for i in [2, 4]:
conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,))
+ self.postgresql_engine.dispose()
class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest):
@@ -904,11 +906,11 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
return self.dbs
- def teardown(self):
+ def teardown_test(self):
for db in self.dbs:
db.connect().invalidate()
- testing_reaper.close_all()
+ testing_reaper.checkin_all()
for i in range(1, 3):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py
index 048a8b52d..3bab7db93 100644
--- a/test/ext/test_hybrid.py
+++ b/test/ext/test_hybrid.py
@@ -697,7 +697,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import literal
symbols = ("usd", "gbp", "cad", "eur", "aud")
diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py
index eba2ac0cb..21244de73 100644
--- a/test/ext/test_mutable.py
+++ b/test/ext/test_mutable.py
@@ -90,11 +90,10 @@ class _MutableDictTestFixture(object):
def _type_fixture(cls):
return MutableDict
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_MutableDictTestFixture, self).teardown()
class _MutableDictTestBase(_MutableDictTestFixture):
@@ -312,11 +311,10 @@ class _MutableListTestFixture(object):
def _type_fixture(cls):
return MutableList
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_MutableListTestFixture, self).teardown()
class _MutableListTestBase(_MutableListTestFixture):
@@ -619,11 +617,10 @@ class _MutableSetTestFixture(object):
def _type_fixture(cls):
return MutableSet
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_MutableSetTestFixture, self).teardown()
class _MutableSetTestBase(_MutableSetTestFixture):
@@ -1234,17 +1231,15 @@ class _CompositeTestBase(object):
Column("unrelated_data", String(50)),
)
- def setup(self):
+ def setup_test(self):
from sqlalchemy.ext import mutable
mutable._setup_composite_listener()
- super(_CompositeTestBase, self).setup()
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_CompositeTestBase, self).teardown()
@classmethod
def _type_fixture(cls):
diff --git a/test/ext/test_orderinglist.py b/test/ext/test_orderinglist.py
index f23d6cb57..280fad6cf 100644
--- a/test/ext/test_orderinglist.py
+++ b/test/ext/test_orderinglist.py
@@ -8,7 +8,7 @@ from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
-from sqlalchemy.testing.fixtures import create_session
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import picklers
@@ -60,7 +60,7 @@ def alpha_ordering(index, collection):
class OrderingListTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
global metadata, slides_table, bullets_table, Slide, Bullet
slides_table, bullets_table = None, None
Slide, Bullet = None, None
@@ -122,7 +122,7 @@ class OrderingListTest(fixtures.TestBase):
metadata.create_all(testing.db)
- def teardown(self):
+ def teardown_test(self):
metadata.drop_all(testing.db)
def test_append_no_reorder(self):
@@ -167,7 +167,7 @@ class OrderingListTest(fixtures.TestBase):
self.assert_(s1.bullets[2].position == 3)
self.assert_(s1.bullets[3].position == 4)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
@@ -232,7 +232,7 @@ class OrderingListTest(fixtures.TestBase):
s1.bullets._reorder()
self.assert_(s1.bullets[4].position == 5)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
@@ -289,7 +289,7 @@ class OrderingListTest(fixtures.TestBase):
self.assert_(len(s1.bullets) == 6)
self.assert_(s1.bullets[5].position == 5)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
@@ -338,7 +338,7 @@ class OrderingListTest(fixtures.TestBase):
self.assert_(s1.bullets[li].position == li)
self.assert_(s1.bullets[li] == b[bi])
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
@@ -365,7 +365,7 @@ class OrderingListTest(fixtures.TestBase):
self.assert_(len(s1.bullets) == 3)
self.assert_(s1.bullets[2].position == 2)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py
index 4c005d336..4d9162105 100644
--- a/test/orm/declarative/test_basic.py
+++ b/test/orm/declarative/test_basic.py
@@ -61,11 +61,11 @@ class DeclarativeTestBase(
):
__dialect__ = "default"
- def setup(self):
+ def setup_test(self):
global Base
Base = declarative_base(testing.db)
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
Base.metadata.drop_all(testing.db)
diff --git a/test/orm/declarative/test_concurrency.py b/test/orm/declarative/test_concurrency.py
index 5f12d8272..ecddc2e5f 100644
--- a/test/orm/declarative/test_concurrency.py
+++ b/test/orm/declarative/test_concurrency.py
@@ -17,7 +17,7 @@ from sqlalchemy.testing.fixtures import fixture_session
class ConcurrentUseDeclMappingTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
@classmethod
diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py
index cc29cab7d..e09b1570e 100644
--- a/test/orm/declarative/test_inheritance.py
+++ b/test/orm/declarative/test_inheritance.py
@@ -27,11 +27,11 @@ Base = None
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
- def setup(self):
+ def setup_test(self):
global Base
Base = decl.declarative_base(testing.db)
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
Base.metadata.drop_all(testing.db)
diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py
index 631527daf..ad4832c35 100644
--- a/test/orm/declarative/test_mixin.py
+++ b/test/orm/declarative/test_mixin.py
@@ -38,13 +38,13 @@ mapper_registry = None
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
- def setup(self):
+ def setup_test(self):
global Base, mapper_registry
mapper_registry = registry(metadata=MetaData())
Base = mapper_registry.generate_base()
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
with testing.db.begin() as conn:
diff --git a/test/orm/declarative/test_reflection.py b/test/orm/declarative/test_reflection.py
index 241528c44..e7b2a7058 100644
--- a/test/orm/declarative/test_reflection.py
+++ b/test/orm/declarative/test_reflection.py
@@ -17,14 +17,13 @@ from sqlalchemy.testing.schema import Table
class DeclarativeReflectionBase(fixtures.TablesTest):
__requires__ = ("reflectable_autoincrement",)
- def setup(self):
+ def setup_test(self):
global Base, registry
registry = decl.registry(metadata=MetaData())
Base = registry.generate_base()
- def teardown(self):
- super(DeclarativeReflectionBase, self).teardown()
+ def teardown_test(self):
clear_mappers()
diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py
index bdcdedc44..da07b4941 100644
--- a/test/orm/inheritance/test_basic.py
+++ b/test/orm/inheritance/test_basic.py
@@ -31,7 +31,6 @@ from sqlalchemy.orm import synonym
from sqlalchemy.orm.util import instance_str
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
-from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
@@ -1889,7 +1888,6 @@ class VersioningTest(fixtures.MappedTest):
@testing.emits_warning(r".*updated rowcount")
@testing.requires.sane_rowcount_w_returning
- @engines.close_open_connections
def test_save_update(self):
subtable, base, stuff = (
self.tables.subtable,
@@ -2927,7 +2925,7 @@ class NoPKOnSubTableWarningTest(fixtures.TestBase):
)
return parent, child
- def tearDown(self):
+ def teardown_test(self):
clear_mappers()
def test_warning_on_sub(self):
@@ -3417,27 +3415,26 @@ class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest):
@classmethod
def insert_data(cls, connection):
Parent, A, B = cls.classes("Parent", "A", "B")
- s = fixture_session()
-
- p1 = Parent(id=1)
- p2 = Parent(id=2)
- s.add_all([p1, p2])
- s.flush()
+ with Session(connection) as s:
+ p1 = Parent(id=1)
+ p2 = Parent(id=2)
+ s.add_all([p1, p2])
+ s.flush()
- s.add_all(
- [
- A(id=1, parent_id=1),
- B(id=2, parent_id=1),
- A(id=3, parent_id=1),
- B(id=4, parent_id=1),
- ]
- )
- s.flush()
+ s.add_all(
+ [
+ A(id=1, parent_id=1),
+ B(id=2, parent_id=1),
+ A(id=3, parent_id=1),
+ B(id=4, parent_id=1),
+ ]
+ )
+ s.flush()
- s.query(A).filter(A.id.in_([3, 4])).update(
- {A.type: None}, synchronize_session=False
- )
- s.commit()
+ s.query(A).filter(A.id.in_([3, 4])).update(
+ {A.type: None}, synchronize_session=False
+ )
+ s.commit()
def test_pk_is_null(self):
Parent, A = self.classes("Parent", "A")
@@ -3527,10 +3524,12 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest):
ASingleSubA, ASingleSubB, AJoinedSubA, AJoinedSubB = cls.classes(
"ASingleSubA", "ASingleSubB", "AJoinedSubA", "AJoinedSubB"
)
- s = fixture_session()
+ with Session(connection) as s:
- s.add_all([ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()])
- s.commit()
+ s.add_all(
+ [ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()]
+ )
+ s.commit()
def test_single_invalid_ident(self):
ASingle, ASingleSubA = self.classes("ASingle", "ASingleSubA")
diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py
index 8820aa6a4..0a0a5d12b 100644
--- a/test/orm/test_attributes.py
+++ b/test/orm/test_attributes.py
@@ -209,7 +209,7 @@ class AttributeImplAPITest(fixtures.MappedTest):
class AttributesTest(fixtures.ORMTest):
- def setup(self):
+ def setup_test(self):
global MyTest, MyTest2
class MyTest(object):
@@ -218,7 +218,7 @@ class AttributesTest(fixtures.ORMTest):
class MyTest2(object):
pass
- def teardown(self):
+ def teardown_test(self):
global MyTest, MyTest2
MyTest, MyTest2 = None, None
@@ -3690,7 +3690,7 @@ class EventPropagateTest(fixtures.TestBase):
class CollectionInitTest(fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class A(object):
pass
@@ -3749,7 +3749,7 @@ class CollectionInitTest(fixtures.TestBase):
class TestUnlink(fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class A(object):
pass
diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py
index 2f54f7fff..014fa152e 100644
--- a/test/orm/test_bind.py
+++ b/test/orm/test_bind.py
@@ -428,39 +428,46 @@ class BindIntegrationTest(_fixtures.FixtureTest):
User, users = self.classes.User, self.tables.users
mapper(User, users)
- c = testing.db.connect()
-
- sess = Session(bind=c, autocommit=False)
- u = User(name="u1")
- sess.add(u)
- sess.flush()
- sess.close()
- assert not c.in_transaction()
- assert c.exec_driver_sql("select count(1) from users").scalar() == 0
-
- sess = Session(bind=c, autocommit=False)
- u = User(name="u2")
- sess.add(u)
- sess.flush()
- sess.commit()
- assert not c.in_transaction()
- assert c.exec_driver_sql("select count(1) from users").scalar() == 1
-
- with c.begin():
- c.exec_driver_sql("delete from users")
- assert c.exec_driver_sql("select count(1) from users").scalar() == 0
-
- c = testing.db.connect()
-
- trans = c.begin()
- sess = Session(bind=c, autocommit=True)
- u = User(name="u3")
- sess.add(u)
- sess.flush()
- assert c.in_transaction()
- trans.commit()
- assert not c.in_transaction()
- assert c.exec_driver_sql("select count(1) from users").scalar() == 1
+ with testing.db.connect() as c:
+
+ sess = Session(bind=c, autocommit=False)
+ u = User(name="u1")
+ sess.add(u)
+ sess.flush()
+ sess.close()
+ assert not c.in_transaction()
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 0
+ )
+
+ sess = Session(bind=c, autocommit=False)
+ u = User(name="u2")
+ sess.add(u)
+ sess.flush()
+ sess.commit()
+ assert not c.in_transaction()
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 1
+ )
+
+ with c.begin():
+ c.exec_driver_sql("delete from users")
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 0
+ )
+
+ with testing.db.connect() as c:
+ trans = c.begin()
+ sess = Session(bind=c, autocommit=True)
+ u = User(name="u3")
+ sess.add(u)
+ sess.flush()
+ assert c.in_transaction()
+ trans.commit()
+ assert not c.in_transaction()
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 1
+ )
class SessionBindTest(fixtures.MappedTest):
@@ -506,6 +513,7 @@ class SessionBindTest(fixtures.MappedTest):
finally:
if hasattr(bind, "close"):
bind.close()
+ sess.close()
def test_session_unbound(self):
Foo = self.classes.Foo
diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py
index 3d09bd446..2a0aafbbc 100644
--- a/test/orm/test_collection.py
+++ b/test/orm/test_collection.py
@@ -92,13 +92,12 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
return str((id(self), self.a, self.b, self.c))
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
instrumentation.register_class(cls.Entity)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
instrumentation.unregister_class(cls.Entity)
- super(CollectionsTest, cls).teardown_class()
_entity_id = 1
diff --git a/test/orm/test_compile.py b/test/orm/test_compile.py
index df652daf4..20d8ecc2d 100644
--- a/test/orm/test_compile.py
+++ b/test/orm/test_compile.py
@@ -19,7 +19,7 @@ from sqlalchemy.testing import fixtures
class CompileTest(fixtures.ORMTest):
"""test various mapper compilation scenarios"""
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def test_with_polymorphic(self):
diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py
index e1ef67fed..ed11b89c9 100644
--- a/test/orm/test_cycles.py
+++ b/test/orm/test_cycles.py
@@ -1743,8 +1743,7 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest):
id = Column(Integer, primary_key=True)
a_id = Column(ForeignKey("a.id", name="a_fk"))
- def setup(self):
- super(PostUpdateOnUpdateTest, self).setup()
+ def setup_test(self):
PostUpdateOnUpdateTest.counter = count()
PostUpdateOnUpdateTest.db_counter = count()
diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py
index 6d946cfe6..15063ebe9 100644
--- a/test/orm/test_deprecations.py
+++ b/test/orm/test_deprecations.py
@@ -2199,6 +2199,7 @@ class SessionTest(fixtures.RemovesEvents, _LocalFixture):
class AutocommitClosesOnFailTest(fixtures.MappedTest):
__requires__ = ("deferrable_fks",)
+ __only_on__ = ("postgresql+psycopg2",) # needs #5824 for asyncpg
@classmethod
def define_tables(cls, metadata):
@@ -4498,44 +4499,49 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
warnings += (join_aliased_dep,)
# load a user who has an order that contains item id 3 and address
# id 1 (order 3, owned by jack)
- with testing.expect_deprecated_20(*warnings):
- result = (
- fixture_session()
- .query(User)
- .join("orders", "items", aliased=aliased_)
- .filter_by(id=3)
- .reset_joinpoint()
- .join("orders", "address", aliased=aliased_)
- .filter_by(id=1)
- .all()
- )
- assert [User(id=7, name="jack")] == result
- with testing.expect_deprecated_20(*warnings):
- result = (
- fixture_session()
- .query(User)
- .join("orders", "items", aliased=aliased_, isouter=True)
- .filter_by(id=3)
- .reset_joinpoint()
- .join("orders", "address", aliased=aliased_, isouter=True)
- .filter_by(id=1)
- .all()
- )
- assert [User(id=7, name="jack")] == result
-
- with testing.expect_deprecated_20(*warnings):
- result = (
- fixture_session()
- .query(User)
- .outerjoin("orders", "items", aliased=aliased_)
- .filter_by(id=3)
- .reset_joinpoint()
- .outerjoin("orders", "address", aliased=aliased_)
- .filter_by(id=1)
- .all()
- )
- assert [User(id=7, name="jack")] == result
+ with fixture_session() as sess:
+ with testing.expect_deprecated_20(*warnings):
+ result = (
+ sess.query(User)
+ .join("orders", "items", aliased=aliased_)
+ .filter_by(id=3)
+ .reset_joinpoint()
+ .join("orders", "address", aliased=aliased_)
+ .filter_by(id=1)
+ .all()
+ )
+ assert [User(id=7, name="jack")] == result
+
+ with fixture_session() as sess:
+ with testing.expect_deprecated_20(*warnings):
+ result = (
+ sess.query(User)
+ .join(
+ "orders", "items", aliased=aliased_, isouter=True
+ )
+ .filter_by(id=3)
+ .reset_joinpoint()
+ .join(
+ "orders", "address", aliased=aliased_, isouter=True
+ )
+ .filter_by(id=1)
+ .all()
+ )
+ assert [User(id=7, name="jack")] == result
+
+ with fixture_session() as sess:
+ with testing.expect_deprecated_20(*warnings):
+ result = (
+ sess.query(User)
+ .outerjoin("orders", "items", aliased=aliased_)
+ .filter_by(id=3)
+ .reset_joinpoint()
+ .outerjoin("orders", "address", aliased=aliased_)
+ .filter_by(id=1)
+ .all()
+ )
+ assert [User(id=7, name="jack")] == result
class AliasFromCorrectLeftTest(
diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py
index 4498fc1ff..7eedb37c9 100644
--- a/test/orm/test_eager_relations.py
+++ b/test/orm/test_eager_relations.py
@@ -559,15 +559,15 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
5,
),
]:
- sess = fixture_session()
+ with fixture_session() as sess:
- def go():
- eq_(
- sess.query(User).options(*opt).order_by(User.id).all(),
- self.static.user_item_keyword_result,
- )
+ def go():
+ eq_(
+ sess.query(User).options(*opt).order_by(User.id).all(),
+ self.static.user_item_keyword_result,
+ )
- self.assert_sql_count(testing.db, go, count)
+ self.assert_sql_count(testing.db, go, count)
def test_disable_dynamic(self):
"""test no joined option on a dynamic."""
diff --git a/test/orm/test_events.py b/test/orm/test_events.py
index 1c918a88c..e85c23d6f 100644
--- a/test/orm/test_events.py
+++ b/test/orm/test_events.py
@@ -45,13 +45,14 @@ from test.orm import _fixtures
class _RemoveListeners(object):
- def teardown(self):
+ @testing.fixture(autouse=True)
+ def _remove_listeners(self):
+ yield
events.MapperEvents._clear()
events.InstanceEvents._clear()
events.SessionEvents._clear()
events.InstrumentationEvents._clear()
events.QueryEvents._clear()
- super(_RemoveListeners, self).teardown()
class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
@@ -1174,7 +1175,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest):
argnames="target, event_name, fn",
)(fn)
- def teardown(self):
+ def teardown_test(self):
A = self.classes.A
A._sa_class_manager.dispatch._clear()
diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py
index cc9596466..f622bff02 100644
--- a/test/orm/test_froms.py
+++ b/test/orm/test_froms.py
@@ -2395,16 +2395,19 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
]
adalias = addresses.alias()
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name))
- .outerjoin(adalias, "addresses")
- .group_by(users)
- .order_by(users.c.id)
- )
- assert q.all() == expected
+ with fixture_session() as sess:
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(adalias.c.id), ("Name:" + users.c.name)
+ )
+ .outerjoin(adalias, "addresses")
+ .group_by(users)
+ .order_by(users.c.id)
+ )
+
+ eq_(q.all(), expected)
# test with a straight statement
s = (
@@ -2417,52 +2420,57 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
.group_by(*[c for c in users.c])
.order_by(users.c.id)
)
- q = fixture_session().query(User)
- result = (
- q.add_columns(s.selected_columns.count, s.selected_columns.concat)
- .from_statement(s)
- .all()
- )
- assert result == expected
-
- sess.expunge_all()
- # test with select_entity_from()
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name))
- .select_entity_from(users.outerjoin(addresses))
- .group_by(users)
- .order_by(users.c.id)
- )
+ with fixture_session() as sess:
+ q = sess.query(User)
+ result = (
+ q.add_columns(
+ s.selected_columns.count, s.selected_columns.concat
+ )
+ .from_statement(s)
+ .all()
+ )
+ eq_(result, expected)
- assert q.all() == expected
- sess.expunge_all()
+ with fixture_session() as sess:
+ # test with select_entity_from()
+ q = (
+ fixture_session()
+ .query(User)
+ .add_columns(
+ func.count(addresses.c.id), ("Name:" + users.c.name)
+ )
+ .select_entity_from(users.outerjoin(addresses))
+ .group_by(users)
+ .order_by(users.c.id)
+ )
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name))
- .outerjoin("addresses")
- .group_by(users)
- .order_by(users.c.id)
- )
+ eq_(q.all(), expected)
- assert q.all() == expected
- sess.expunge_all()
+ with fixture_session() as sess:
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(addresses.c.id), ("Name:" + users.c.name)
+ )
+ .outerjoin("addresses")
+ .group_by(users)
+ .order_by(users.c.id)
+ )
+ eq_(q.all(), expected)
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name))
- .outerjoin(adalias, "addresses")
- .group_by(users)
- .order_by(users.c.id)
- )
+ with fixture_session() as sess:
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(adalias.c.id), ("Name:" + users.c.name)
+ )
+ .outerjoin(adalias, "addresses")
+ .group_by(users)
+ .order_by(users.c.id)
+ )
- assert q.all() == expected
- sess.expunge_all()
+ eq_(q.all(), expected)
def test_expression_selectable_matches_mzero(self):
User, Address = self.classes.User, self.classes.Address
diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py
index 3061de309..43cf81e6d 100644
--- a/test/orm/test_lazy_relations.py
+++ b/test/orm/test_lazy_relations.py
@@ -717,24 +717,24 @@ class LazyTest(_fixtures.FixtureTest):
),
)
- sess = fixture_session()
+ with fixture_session() as sess:
- # load address
- a1 = (
- sess.query(Address)
- .filter_by(email_address="ed@wood.com")
- .one()
- )
+ # load address
+ a1 = (
+ sess.query(Address)
+ .filter_by(email_address="ed@wood.com")
+ .one()
+ )
- # load user that is attached to the address
- u1 = sess.query(User).get(8)
+ # load user that is attached to the address
+ u1 = sess.query(User).get(8)
- def go():
- # lazy load of a1.user should get it from the session
- assert a1.user is u1
+ def go():
+ # lazy load of a1.user should get it from the session
+ assert a1.user is u1
- self.assert_sql_count(testing.db, go, 0)
- sa.orm.clear_mappers()
+ self.assert_sql_count(testing.db, go, 0)
+ sa.orm.clear_mappers()
def test_uses_get_compatible_types(self):
"""test the use_get optimization with compatible
@@ -789,24 +789,23 @@ class LazyTest(_fixtures.FixtureTest):
properties=dict(user=relationship(mapper(User, users))),
)
- sess = fixture_session()
-
- # load address
- a1 = (
- sess.query(Address)
- .filter_by(email_address="ed@wood.com")
- .one()
- )
+ with fixture_session() as sess:
+ # load address
+ a1 = (
+ sess.query(Address)
+ .filter_by(email_address="ed@wood.com")
+ .one()
+ )
- # load user that is attached to the address
- u1 = sess.query(User).get(8)
+ # load user that is attached to the address
+ u1 = sess.query(User).get(8)
- def go():
- # lazy load of a1.user should get it from the session
- assert a1.user is u1
+ def go():
+ # lazy load of a1.user should get it from the session
+ assert a1.user is u1
- self.assert_sql_count(testing.db, go, 0)
- sa.orm.clear_mappers()
+ self.assert_sql_count(testing.db, go, 0)
+ sa.orm.clear_mappers()
def test_many_to_one(self):
users, Address, addresses, User = (
diff --git a/test/orm/test_load_on_fks.py b/test/orm/test_load_on_fks.py
index 0e8ac97e3..42b5b3e45 100644
--- a/test/orm/test_load_on_fks.py
+++ b/test/orm/test_load_on_fks.py
@@ -9,14 +9,12 @@ from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import instance_state
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
-engine = testing.db
-
-
class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
global Parent, Child, Base
Base = declarative_base()
@@ -36,27 +34,27 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
)
parent_id = Column(Integer, ForeignKey("parent.id"))
- Base.metadata.create_all(engine)
+ Base.metadata.create_all(testing.db)
- def tearDown(self):
- Base.metadata.drop_all(engine)
+ def teardown_test(self):
+ Base.metadata.drop_all(testing.db)
def test_annoying_autoflush_one(self):
- sess = Session(engine)
+ sess = fixture_session()
p1 = Parent()
sess.add(p1)
p1.children = []
def test_annoying_autoflush_two(self):
- sess = Session(engine)
+ sess = fixture_session()
p1 = Parent()
sess.add(p1)
assert p1.children == []
def test_dont_load_if_no_keys(self):
- sess = Session(engine)
+ sess = fixture_session()
p1 = Parent()
sess.add(p1)
@@ -68,7 +66,9 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
- def setUp(self):
+ __leave_connections_for_teardown__ = True
+
+ def setup_test(self):
global Parent, Child, Base
Base = declarative_base()
@@ -91,10 +91,10 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
parent = relationship(Parent, backref=backref("children"))
- Base.metadata.create_all(engine)
+ Base.metadata.create_all(testing.db)
global sess, p1, p2, c1, c2
- sess = Session(bind=engine)
+ sess = Session(bind=testing.db)
p1 = Parent()
p2 = Parent()
@@ -105,9 +105,9 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
sess.commit()
- def tearDown(self):
+ def teardown_test(self):
sess.rollback()
- Base.metadata.drop_all(engine)
+ Base.metadata.drop_all(testing.db)
def test_load_on_pending_allows_backref_event(self):
Child.parent.property.load_on_pending = True
diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py
index 013eb21e1..d182fd2c1 100644
--- a/test/orm/test_mapper.py
+++ b/test/orm/test_mapper.py
@@ -2560,7 +2560,7 @@ class MagicNamesTest(fixtures.MappedTest):
class DocumentTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.mapper = registry().map_imperatively
@@ -2624,14 +2624,14 @@ class DocumentTest(fixtures.TestBase):
class ORMLoggingTest(_fixtures.FixtureTest):
- def setup(self):
+ def setup_test(self):
self.buf = logging.handlers.BufferingHandler(100)
for log in [logging.getLogger("sqlalchemy.orm")]:
log.addHandler(self.buf)
self.mapper = registry().map_imperatively
- def teardown(self):
+ def teardown_test(self):
for log in [logging.getLogger("sqlalchemy.orm")]:
log.removeHandler(self.buf)
diff --git a/test/orm/test_options.py b/test/orm/test_options.py
index b22b318e9..6f47c1238 100644
--- a/test/orm/test_options.py
+++ b/test/orm/test_options.py
@@ -1523,9 +1523,7 @@ class PickleTest(PathTest, QueryTest):
class LocalOptsTest(PathTest, QueryTest):
@classmethod
- def setup_class(cls):
- super(LocalOptsTest, cls).setup_class()
-
+ def setup_test_class(cls):
@strategy_options.loader_option()
def some_col_opt_only(loadopt, key, opts):
return loadopt.set_column_strategy(
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index fd8e849fb..7546ba162 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -6000,12 +6000,19 @@ class SynonymTest(QueryTest, AssertsCompiledSQL):
[User.orders_syn, Order.items_syn],
[User.orders_syn_2, Order.items_syn],
):
- q = fixture_session().query(User)
- for path in j:
- q = q.join(path)
- q = q.filter_by(id=3)
- result = q.all()
- assert [User(id=7, name="jack"), User(id=9, name="fred")] == result
+ with fixture_session() as sess:
+ q = sess.query(User)
+ for path in j:
+ q = q.join(path)
+ q = q.filter_by(id=3)
+ result = q.all()
+ eq_(
+ result,
+ [
+ User(id=7, name="jack"),
+ User(id=9, name="fred"),
+ ],
+ )
def test_with_parent(self):
Order, User = self.classes.Order, self.classes.User
@@ -6018,17 +6025,17 @@ class SynonymTest(QueryTest, AssertsCompiledSQL):
("name_syn", "orders_syn"),
("name_syn", "orders_syn_2"),
):
- sess = fixture_session()
- q = sess.query(User)
+ with fixture_session() as sess:
+ q = sess.query(User)
- u1 = q.filter_by(**{nameprop: "jack"}).one()
+ u1 = q.filter_by(**{nameprop: "jack"}).one()
- o = sess.query(Order).with_parent(u1, property=orderprop).all()
- assert [
- Order(description="order 1"),
- Order(description="order 3"),
- Order(description="order 5"),
- ] == o
+ o = sess.query(Order).with_parent(u1, property=orderprop).all()
+ assert [
+ Order(description="order 1"),
+ Order(description="order 3"),
+ Order(description="order 5"),
+ ] == o
def test_froms_aliased_col(self):
Address, User = self.classes.Address, self.classes.User
diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py
index 12c084b2d..ef1bf2e60 100644
--- a/test/orm/test_rel_fn.py
+++ b/test/orm/test_rel_fn.py
@@ -26,7 +26,7 @@ from sqlalchemy.testing import mock
class _JoinFixtures(object):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
m = MetaData()
cls.left = Table(
"lft",
diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py
index 5979f08ae..8d73cd40e 100644
--- a/test/orm/test_relationships.py
+++ b/test/orm/test_relationships.py
@@ -637,7 +637,7 @@ class OverlappingFksSiblingTest(fixtures.TestBase):
"""
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def _fixture_one(
@@ -2474,7 +2474,7 @@ class JoinConditionErrorTest(fixtures.TestBase):
assert_raises(sa.exc.ArgumentError, configure_mappers)
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
@@ -4354,7 +4354,7 @@ class AmbiguousFKResolutionTest(_RelationshipErrors, fixtures.MappedTest):
class SecondaryArgTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
@testing.combinations((True,), (False,))
diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py
index 5535fe5d6..4895c7d3a 100644
--- a/test/orm/test_selectin_relations.py
+++ b/test/orm/test_selectin_relations.py
@@ -713,35 +713,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
def _do_query_tests(self, opts, count):
Order, User = self.classes.Order, self.classes.User
- sess = fixture_session()
+ with fixture_session() as sess:
- def go():
- eq_(
- sess.query(User).options(*opts).order_by(User.id).all(),
- self.static.user_item_keyword_result,
- )
+ def go():
+ eq_(
+ sess.query(User).options(*opts).order_by(User.id).all(),
+ self.static.user_item_keyword_result,
+ )
- self.assert_sql_count(testing.db, go, count)
+ self.assert_sql_count(testing.db, go, count)
- eq_(
- sess.query(User)
- .options(*opts)
- .filter(User.name == "fred")
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[2:3],
- )
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .filter(User.name == "fred")
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[2:3],
+ )
- sess = fixture_session()
- eq_(
- sess.query(User)
- .options(*opts)
- .join(User.orders)
- .filter(Order.id == 3)
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[0:1],
- )
+ with fixture_session() as sess:
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .join(User.orders)
+ .filter(Order.id == 3)
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[0:1],
+ )
def test_cyclical(self):
"""A circular eager relationship breaks the cycle with a lazy loader"""
diff --git a/test/orm/test_session.py b/test/orm/test_session.py
index 20c4752b8..3d4566af3 100644
--- a/test/orm/test_session.py
+++ b/test/orm/test_session.py
@@ -1273,7 +1273,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
run_inserts = None
- def setup(self):
+ def setup_test(self):
mapper(self.classes.User, self.tables.users)
def _assert_modified(self, u1):
@@ -1288,11 +1288,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
def _assert_no_cycle(self, u1):
assert sa.orm.attributes.instance_state(u1)._strong_obj is None
- def _persistent_fixture(self):
+ def _persistent_fixture(self, gc_collect=False):
User = self.classes.User
u1 = User()
u1.name = "ed"
- sess = fixture_session()
+ if gc_collect:
+ sess = Session(testing.db)
+ else:
+ sess = fixture_session()
sess.add(u1)
sess.flush()
return sess, u1
@@ -1389,14 +1392,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
@testing.requires.predictable_gc
def test_move_gc_session_persistent_dirty(self):
- sess, u1 = self._persistent_fixture()
+ sess, u1 = self._persistent_fixture(gc_collect=True)
u1.name = "edchanged"
self._assert_cycle(u1)
self._assert_modified(u1)
del sess
gc_collect()
self._assert_cycle(u1)
- s2 = fixture_session()
+ s2 = Session(testing.db)
s2.add(u1)
self._assert_cycle(u1)
self._assert_modified(u1)
@@ -1565,7 +1568,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest):
mapper(User, users)
- sess = fixture_session()
+ sess = Session(testing.db)
u1 = User(name="u1")
sess.add(u1)
@@ -1573,7 +1576,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest):
# can't add u1 to Session,
# already belongs to u2
- s2 = fixture_session()
+ s2 = Session(testing.db)
assert_raises_message(
sa.exc.InvalidRequestError,
r".*is already attached to session",
@@ -1725,11 +1728,10 @@ class DisposedStates(fixtures.MappedTest):
mapper(T, cls.tables.t1)
- def teardown(self):
+ def teardown_test(self):
from sqlalchemy.orm.session import _sessions
_sessions.clear()
- super(DisposedStates, self).teardown()
def _set_imap_in_disposal(self, sess, *objs):
"""remove selected objects from the given session, as though
diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py
index fe20442a3..150cee222 100644
--- a/test/orm/test_subquery_relations.py
+++ b/test/orm/test_subquery_relations.py
@@ -734,35 +734,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
def _do_query_tests(self, opts, count):
Order, User = self.classes.Order, self.classes.User
- sess = fixture_session()
+ with fixture_session() as sess:
- def go():
- eq_(
- sess.query(User).options(*opts).order_by(User.id).all(),
- self.static.user_item_keyword_result,
- )
+ def go():
+ eq_(
+ sess.query(User).options(*opts).order_by(User.id).all(),
+ self.static.user_item_keyword_result,
+ )
- self.assert_sql_count(testing.db, go, count)
+ self.assert_sql_count(testing.db, go, count)
- eq_(
- sess.query(User)
- .options(*opts)
- .filter(User.name == "fred")
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[2:3],
- )
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .filter(User.name == "fred")
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[2:3],
+ )
- sess = fixture_session()
- eq_(
- sess.query(User)
- .options(*opts)
- .join(User.orders)
- .filter(Order.id == 3)
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[0:1],
- )
+ with fixture_session() as sess:
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .join(User.orders)
+ .filter(Order.id == 3)
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[0:1],
+ )
def test_cyclical(self):
"""A circular eager relationship breaks the cycle with a lazy loader"""
diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py
index 550cf6535..7f77b01c7 100644
--- a/test/orm/test_transaction.py
+++ b/test/orm/test_transaction.py
@@ -66,17 +66,18 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
users, User = self.tables.users, self.classes.User
mapper(User, users)
- conn = testing.db.connect()
- trans = conn.begin()
- sess = Session(bind=conn, autocommit=False, autoflush=True)
- sess.begin(subtransactions=True)
- u = User(name="ed")
- sess.add(u)
- sess.flush()
- sess.commit() # commit does nothing
- trans.rollback() # rolls back
- assert len(sess.query(User).all()) == 0
- sess.close()
+
+ with testing.db.connect() as conn:
+ trans = conn.begin()
+ sess = Session(bind=conn, autocommit=False, autoflush=True)
+ sess.begin(subtransactions=True)
+ u = User(name="ed")
+ sess.add(u)
+ sess.flush()
+ sess.commit() # commit does nothing
+ trans.rollback() # rolls back
+ assert len(sess.query(User).all()) == 0
+ sess.close()
@engines.close_open_connections
def test_subtransaction_on_external_no_begin(self):
@@ -260,34 +261,33 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
users = self.tables.users
engine = Engine._future_facade(testing.db)
- session = Session(engine, autocommit=False)
-
- session.begin()
- session.connection().execute(users.insert().values(name="user1"))
- session.begin(subtransactions=True)
- session.begin_nested()
- session.connection().execute(users.insert().values(name="user2"))
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 2
- )
- session.rollback()
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 1
- )
- session.connection().execute(users.insert().values(name="user3"))
- session.commit()
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 2
- )
+ with Session(engine, autocommit=False) as session:
+ session.begin()
+ session.connection().execute(users.insert().values(name="user1"))
+ session.begin(subtransactions=True)
+ session.begin_nested()
+ session.connection().execute(users.insert().values(name="user2"))
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 2
+ )
+ session.rollback()
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 1
+ )
+ session.connection().execute(users.insert().values(name="user3"))
+ session.commit()
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 2
+ )
@testing.requires.savepoints
def test_dirty_state_transferred_deep_nesting(self):
@@ -295,27 +295,27 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
mapper(User, users)
- s = Session(testing.db)
- u1 = User(name="u1")
- s.add(u1)
- s.commit()
-
- nt1 = s.begin_nested()
- nt2 = s.begin_nested()
- u1.name = "u2"
- assert attributes.instance_state(u1) not in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
- s.flush()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
+ with fixture_session() as s:
+ u1 = User(name="u1")
+ s.add(u1)
+ s.commit()
+
+ nt1 = s.begin_nested()
+ nt2 = s.begin_nested()
+ u1.name = "u2"
+ assert attributes.instance_state(u1) not in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
+ s.flush()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
- s.commit()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) in nt1._dirty
+ s.commit()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) in nt1._dirty
- s.rollback()
- assert attributes.instance_state(u1).expired
- eq_(u1.name, "u1")
+ s.rollback()
+ assert attributes.instance_state(u1).expired
+ eq_(u1.name, "u1")
@testing.requires.savepoints
def test_dirty_state_transferred_deep_nesting_future(self):
@@ -323,27 +323,27 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
mapper(User, users)
- s = Session(testing.db, future=True)
- u1 = User(name="u1")
- s.add(u1)
- s.commit()
-
- nt1 = s.begin_nested()
- nt2 = s.begin_nested()
- u1.name = "u2"
- assert attributes.instance_state(u1) not in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
- s.flush()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
+ with fixture_session(future=True) as s:
+ u1 = User(name="u1")
+ s.add(u1)
+ s.commit()
+
+ nt1 = s.begin_nested()
+ nt2 = s.begin_nested()
+ u1.name = "u2"
+ assert attributes.instance_state(u1) not in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
+ s.flush()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
- nt2.commit()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) in nt1._dirty
+ nt2.commit()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) in nt1._dirty
- nt1.rollback()
- assert attributes.instance_state(u1).expired
- eq_(u1.name, "u1")
+ nt1.rollback()
+ assert attributes.instance_state(u1).expired
+ eq_(u1.name, "u1")
@testing.requires.independent_connections
def test_transactions_isolated(self):
@@ -1049,23 +1049,25 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
mapper(User, users)
- session = Session(testing.db)
+ with fixture_session() as session:
- with expect_warnings(".*during handling of a previous exception.*"):
- session.begin_nested()
- savepoint = session.connection()._nested_transaction._savepoint
+ with expect_warnings(
+ ".*during handling of a previous exception.*"
+ ):
+ session.begin_nested()
+ savepoint = session.connection()._nested_transaction._savepoint
- # force the savepoint to disappear
- session.connection().dialect.do_release_savepoint(
- session.connection(), savepoint
- )
+ # force the savepoint to disappear
+ session.connection().dialect.do_release_savepoint(
+ session.connection(), savepoint
+ )
- # now do a broken flush
- session.add_all([User(id=1), User(id=1)])
+ # now do a broken flush
+ session.add_all([User(id=1), User(id=1)])
- assert_raises_message(
- sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush
- )
+ assert_raises_message(
+ sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush
+ )
class _LocalFixture(FixtureTest):
@@ -1170,39 +1172,40 @@ class SubtransactionRecipeTest(FixtureTest):
def test_recipe_heavy_nesting(self, subtransaction_recipe):
users = self.tables.users
- session = Session(testing.db, future=self.future)
-
- with subtransaction_recipe(session):
- session.connection().execute(users.insert().values(name="user1"))
+ with fixture_session(future=self.future) as session:
with subtransaction_recipe(session):
- savepoint = session.begin_nested()
session.connection().execute(
- users.insert().values(name="user2")
+ users.insert().values(name="user1")
)
+ with subtransaction_recipe(session):
+ savepoint = session.begin_nested()
+ session.connection().execute(
+ users.insert().values(name="user2")
+ )
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 2
+ )
+ savepoint.rollback()
+
+ with subtransaction_recipe(session):
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 1
+ )
+ session.connection().execute(
+ users.insert().values(name="user3")
+ )
assert (
session.connection()
.exec_driver_sql("select count(1) from users")
.scalar()
== 2
)
- savepoint.rollback()
-
- with subtransaction_recipe(session):
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 1
- )
- session.connection().execute(
- users.insert().values(name="user3")
- )
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 2
- )
@engines.close_open_connections
def test_recipe_subtransaction_on_external_subtrans(
@@ -1228,13 +1231,12 @@ class SubtransactionRecipeTest(FixtureTest):
User, users = self.classes.User, self.tables.users
mapper(User, users)
- sess = Session(testing.db, future=self.future)
-
- with subtransaction_recipe(sess):
- u = User(name="u1")
- sess.add(u)
- sess.close()
- assert len(sess.query(User).all()) == 1
+ with fixture_session(future=self.future) as sess:
+ with subtransaction_recipe(sess):
+ u = User(name="u1")
+ sess.add(u)
+ sess.close()
+ assert len(sess.query(User).all()) == 1
def test_recipe_subtransaction_on_noautocommit(
self, subtransaction_recipe
@@ -1242,16 +1244,15 @@ class SubtransactionRecipeTest(FixtureTest):
User, users = self.classes.User, self.tables.users
mapper(User, users)
- sess = Session(testing.db, future=self.future)
-
- sess.begin()
- with subtransaction_recipe(sess):
- u = User(name="u1")
- sess.add(u)
- sess.flush()
- sess.rollback() # rolls back
- assert len(sess.query(User).all()) == 0
- sess.close()
+ with fixture_session(future=self.future) as sess:
+ sess.begin()
+ with subtransaction_recipe(sess):
+ u = User(name="u1")
+ sess.add(u)
+ sess.flush()
+ sess.rollback() # rolls back
+ assert len(sess.query(User).all()) == 0
+ sess.close()
@testing.requires.savepoints
def test_recipe_mixed_transaction_control(self, subtransaction_recipe):
@@ -1259,30 +1260,28 @@ class SubtransactionRecipeTest(FixtureTest):
mapper(User, users)
- sess = Session(testing.db, future=self.future)
+ with fixture_session(future=self.future) as sess:
- sess.begin()
- sess.begin_nested()
+ sess.begin()
+ sess.begin_nested()
- with subtransaction_recipe(sess):
+ with subtransaction_recipe(sess):
- sess.add(User(name="u1"))
+ sess.add(User(name="u1"))
- sess.commit()
- sess.commit()
+ sess.commit()
+ sess.commit()
- eq_(len(sess.query(User).all()), 1)
- sess.close()
+ eq_(len(sess.query(User).all()), 1)
+ sess.close()
- t1 = sess.begin()
- t2 = sess.begin_nested()
-
- sess.add(User(name="u2"))
+ t1 = sess.begin()
+ t2 = sess.begin_nested()
- t2.commit()
- assert sess._legacy_transaction() is t1
+ sess.add(User(name="u2"))
- sess.close()
+ t2.commit()
+ assert sess._legacy_transaction() is t1
def test_recipe_error_on_using_inactive_session_commands(
self, subtransaction_recipe
@@ -1290,56 +1289,55 @@ class SubtransactionRecipeTest(FixtureTest):
users, User = self.tables.users, self.classes.User
mapper(User, users)
- sess = Session(testing.db, future=self.future)
- sess.begin()
-
- try:
- with subtransaction_recipe(sess):
- sess.add(User(name="u1"))
- sess.flush()
- raise Exception("force rollback")
- except:
- pass
-
- if self.recipe_rollsback_early:
- # that was a real rollback, so no transaction
- assert not sess.in_transaction()
- is_(sess.get_transaction(), None)
- else:
- assert sess.in_transaction()
-
- sess.close()
- assert not sess.in_transaction()
-
- def test_recipe_multi_nesting(self, subtransaction_recipe):
- sess = Session(testing.db, future=self.future)
-
- with subtransaction_recipe(sess):
- assert sess.in_transaction()
+ with fixture_session(future=self.future) as sess:
+ sess.begin()
try:
with subtransaction_recipe(sess):
- assert sess._legacy_transaction()
+ sess.add(User(name="u1"))
+ sess.flush()
raise Exception("force rollback")
except:
pass
if self.recipe_rollsback_early:
+ # that was a real rollback, so no transaction
assert not sess.in_transaction()
+ is_(sess.get_transaction(), None)
else:
assert sess.in_transaction()
- assert not sess.in_transaction()
+ sess.close()
+ assert not sess.in_transaction()
+
+ def test_recipe_multi_nesting(self, subtransaction_recipe):
+ with fixture_session(future=self.future) as sess:
+ with subtransaction_recipe(sess):
+ assert sess.in_transaction()
+
+ try:
+ with subtransaction_recipe(sess):
+ assert sess._legacy_transaction()
+ raise Exception("force rollback")
+ except:
+ pass
+
+ if self.recipe_rollsback_early:
+ assert not sess.in_transaction()
+ else:
+ assert sess.in_transaction()
+
+ assert not sess.in_transaction()
def test_recipe_deactive_status_check(self, subtransaction_recipe):
- sess = Session(testing.db, future=self.future)
- sess.begin()
+ with fixture_session(future=self.future) as sess:
+ sess.begin()
- with subtransaction_recipe(sess):
- sess.rollback()
+ with subtransaction_recipe(sess):
+ sess.rollback()
- assert not sess.in_transaction()
- sess.commit() # no error
+ assert not sess.in_transaction()
+ sess.commit() # no error
class FixtureDataTest(_LocalFixture):
@@ -1394,28 +1392,28 @@ class CleanSavepointTest(FixtureTest):
mapper(User, users)
- s = Session(bind=testing.db, future=future)
- u1 = User(name="u1")
- u2 = User(name="u2")
- s.add_all([u1, u2])
- s.commit()
- u1.name
- u2.name
- trans = s._transaction
- assert trans is not None
- s.begin_nested()
- update_fn(s, u2)
- eq_(u2.name, "u2modified")
- s.rollback()
+ with fixture_session(future=future) as s:
+ u1 = User(name="u1")
+ u2 = User(name="u2")
+ s.add_all([u1, u2])
+ s.commit()
+ u1.name
+ u2.name
+ trans = s._transaction
+ assert trans is not None
+ s.begin_nested()
+ update_fn(s, u2)
+ eq_(u2.name, "u2modified")
+ s.rollback()
- if future:
- assert s._transaction is None
- assert "name" not in u1.__dict__
- else:
- assert s._transaction is trans
- eq_(u1.__dict__["name"], "u1")
- assert "name" not in u2.__dict__
- eq_(u2.name, "u2")
+ if future:
+ assert s._transaction is None
+ assert "name" not in u1.__dict__
+ else:
+ assert s._transaction is trans
+ eq_(u1.__dict__["name"], "u1")
+ assert "name" not in u2.__dict__
+ eq_(u2.name, "u2")
@testing.requires.savepoints
def test_rollback_ignores_clean_on_savepoint(self):
@@ -2116,82 +2114,108 @@ class ContextManagerPlusFutureTest(FixtureTest):
eq_(sess.query(User).count(), 1)
def test_explicit_begin(self):
- s1 = Session(testing.db)
- with s1.begin() as trans:
- is_(trans, s1._legacy_transaction())
- s1.connection()
+ with fixture_session() as s1:
+ with s1.begin() as trans:
+ is_(trans, s1._legacy_transaction())
+ s1.connection()
- is_(s1._transaction, None)
+ is_(s1._transaction, None)
def test_no_double_begin_explicit(self):
- s1 = Session(testing.db)
- s1.begin()
- assert_raises_message(
- sa_exc.InvalidRequestError,
- "A transaction is already begun on this Session.",
- s1.begin,
- )
+ with fixture_session() as s1:
+ s1.begin()
+ assert_raises_message(
+ sa_exc.InvalidRequestError,
+ "A transaction is already begun on this Session.",
+ s1.begin,
+ )
@testing.requires.savepoints
def test_future_rollback_is_global(self):
users = self.tables.users
- s1 = Session(testing.db, future=True)
+ with fixture_session(future=True) as s1:
+ s1.begin()
- s1.begin()
+ s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
- s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
+ s1.begin_nested()
- s1.begin_nested()
-
- s1.connection().execute(
- users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}]
- )
+ s1.connection().execute(
+ users.insert(),
+ [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}],
+ )
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 3,
+ )
- # rolls back the whole transaction
- s1.rollback()
- is_(s1._legacy_transaction(), None)
+ # rolls back the whole transaction
+ s1.rollback()
+ is_(s1._legacy_transaction(), None)
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 0)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 0,
+ )
- s1.commit()
- is_(s1._legacy_transaction(), None)
+ s1.commit()
+ is_(s1._legacy_transaction(), None)
@testing.requires.savepoints
def test_old_rollback_is_local(self):
users = self.tables.users
- s1 = Session(testing.db)
+ with fixture_session() as s1:
- t1 = s1.begin()
+ t1 = s1.begin()
- s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
+ s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
- s1.begin_nested()
+ s1.begin_nested()
- s1.connection().execute(
- users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}]
- )
+ s1.connection().execute(
+ users.insert(),
+ [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}],
+ )
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 3,
+ )
- # rolls back only the savepoint
- s1.rollback()
+ # rolls back only the savepoint
+ s1.rollback()
- is_(s1._legacy_transaction(), t1)
+ is_(s1._legacy_transaction(), t1)
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 1,
+ )
- s1.commit()
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1)
- is_not(s1._legacy_transaction(), None)
+ s1.commit()
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 1,
+ )
+ is_not(s1._legacy_transaction(), None)
def test_session_as_ctx_manager_one(self):
users = self.tables.users
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
is_not(sess._legacy_transaction(), None)
sess.connection().execute(
@@ -2212,7 +2236,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
def test_session_as_ctx_manager_future_one(self):
users = self.tables.users
- with Session(testing.db, future=True) as sess:
+ with fixture_session(future=True) as sess:
is_(sess._legacy_transaction(), None)
sess.connection().execute(
@@ -2234,7 +2258,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
users = self.tables.users
try:
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
is_not(sess._legacy_transaction(), None)
sess.connection().execute(
@@ -2250,7 +2274,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
users = self.tables.users
try:
- with Session(testing.db, future=True) as sess:
+ with fixture_session(future=True) as sess:
is_(sess._legacy_transaction(), None)
sess.connection().execute(
@@ -2265,7 +2289,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
def test_begin_context_manager(self):
users = self.tables.users
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
with sess.begin():
sess.connection().execute(
users.insert().values(id=1, name="user1")
@@ -2296,12 +2320,13 @@ class ContextManagerPlusFutureTest(FixtureTest):
# committed
eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+ sess.close()
def test_begin_context_manager_rollback_trans(self):
users = self.tables.users
try:
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
with sess.begin():
sess.connection().execute(
users.insert().values(id=1, name="user1")
@@ -2318,12 +2343,13 @@ class ContextManagerPlusFutureTest(FixtureTest):
# rolled back
eq_(sess.connection().execute(users.select()).all(), [])
+ sess.close()
def test_begin_context_manager_rollback_outer(self):
users = self.tables.users
try:
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
with sess.begin():
sess.connection().execute(
users.insert().values(id=1, name="user1")
@@ -2340,6 +2366,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
# committed
eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+ sess.close()
def test_sessionmaker_begin_context_manager_rollback_trans(self):
users = self.tables.users
@@ -2363,6 +2390,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
# rolled back
eq_(sess.connection().execute(users.select()).all(), [])
+ sess.close()
def test_sessionmaker_begin_context_manager_rollback_outer(self):
users = self.tables.users
@@ -2386,36 +2414,37 @@ class ContextManagerPlusFutureTest(FixtureTest):
# committed
eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+ sess.close()
class TransactionFlagsTest(fixtures.TestBase):
def test_in_transaction(self):
- s1 = Session(testing.db)
+ with fixture_session() as s1:
- eq_(s1.in_transaction(), False)
+ eq_(s1.in_transaction(), False)
- trans = s1.begin()
+ trans = s1.begin()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
- n1 = s1.begin_nested()
+ n1 = s1.begin_nested()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), n1)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), n1)
- n1.rollback()
+ n1.rollback()
- is_(s1.get_nested_transaction(), None)
- is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), None)
+ is_(s1.get_transaction(), trans)
- eq_(s1.in_transaction(), True)
+ eq_(s1.in_transaction(), True)
- s1.commit()
+ s1.commit()
- eq_(s1.in_transaction(), False)
- is_(s1.get_transaction(), None)
+ eq_(s1.in_transaction(), False)
+ is_(s1.get_transaction(), None)
def test_in_transaction_subtransactions(self):
"""we'd like to do away with subtransactions for future sessions
@@ -2425,72 +2454,71 @@ class TransactionFlagsTest(fixtures.TestBase):
the external API works.
"""
- s1 = Session(testing.db)
-
- eq_(s1.in_transaction(), False)
+ with fixture_session() as s1:
+ eq_(s1.in_transaction(), False)
- trans = s1.begin()
+ trans = s1.begin()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
- subtrans = s1.begin(_subtrans=True)
- is_(s1.get_transaction(), trans)
- eq_(s1.in_transaction(), True)
+ subtrans = s1.begin(_subtrans=True)
+ is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
- is_(s1._transaction, subtrans)
+ is_(s1._transaction, subtrans)
- s1.rollback()
+ s1.rollback()
- eq_(s1.in_transaction(), True)
- is_(s1._transaction, trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1._transaction, trans)
- s1.rollback()
+ s1.rollback()
- eq_(s1.in_transaction(), False)
- is_(s1._transaction, None)
+ eq_(s1.in_transaction(), False)
+ is_(s1._transaction, None)
def test_in_transaction_nesting(self):
- s1 = Session(testing.db)
+ with fixture_session() as s1:
- eq_(s1.in_transaction(), False)
+ eq_(s1.in_transaction(), False)
- trans = s1.begin()
+ trans = s1.begin()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
- sp1 = s1.begin_nested()
+ sp1 = s1.begin_nested()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), sp1)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), sp1)
- sp2 = s1.begin_nested()
+ sp2 = s1.begin_nested()
- eq_(s1.in_transaction(), True)
- eq_(s1.in_nested_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), sp2)
+ eq_(s1.in_transaction(), True)
+ eq_(s1.in_nested_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), sp2)
- sp2.rollback()
+ sp2.rollback()
- eq_(s1.in_transaction(), True)
- eq_(s1.in_nested_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), sp1)
+ eq_(s1.in_transaction(), True)
+ eq_(s1.in_nested_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), sp1)
- sp1.rollback()
+ sp1.rollback()
- is_(s1.get_nested_transaction(), None)
- eq_(s1.in_transaction(), True)
- eq_(s1.in_nested_transaction(), False)
- is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), None)
+ eq_(s1.in_transaction(), True)
+ eq_(s1.in_nested_transaction(), False)
+ is_(s1.get_transaction(), trans)
- s1.rollback()
+ s1.rollback()
- eq_(s1.in_transaction(), False)
- is_(s1.get_transaction(), None)
+ eq_(s1.in_transaction(), False)
+ is_(s1.get_transaction(), None)
class NaturalPKRollbackTest(fixtures.MappedTest):
@@ -2674,8 +2702,11 @@ class NaturalPKRollbackTest(fixtures.MappedTest):
class JoinIntoAnExternalTransactionFixture(object):
"""Test the "join into an external transaction" examples"""
- def setup(self):
- self.connection = testing.db.connect()
+ __leave_connections_for_teardown__ = True
+
+ def setup_test(self):
+ self.engine = testing.db
+ self.connection = self.engine.connect()
self.metadata = MetaData()
self.table = Table(
@@ -2686,6 +2717,17 @@ class JoinIntoAnExternalTransactionFixture(object):
self.setup_session()
+ def teardown_test(self):
+ self.teardown_session()
+
+ with self.connection.begin():
+ self._assert_count(0)
+
+ with self.connection.begin():
+ self.table.drop(self.connection)
+
+ self.connection.close()
+
def test_something(self):
A = self.A
@@ -2727,18 +2769,6 @@ class JoinIntoAnExternalTransactionFixture(object):
)
eq_(result, count)
- def teardown(self):
- self.teardown_session()
-
- with self.connection.begin():
- self._assert_count(0)
-
- with self.connection.begin():
- self.table.drop(self.connection)
-
- # return connection to the Engine
- self.connection.close()
-
class NewStyleJoinIntoAnExternalTransactionTest(
JoinIntoAnExternalTransactionFixture
@@ -2775,7 +2805,8 @@ class NewStyleJoinIntoAnExternalTransactionTest(
# rollback - everything that happened with the
# Session above (including calls to commit())
# is rolled back.
- self.trans.rollback()
+ if self.trans.is_active:
+ self.trans.rollback()
class FutureJoinIntoAnExternalTransactionTest(
diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py
index 84373b2dc..2c35bec45 100644
--- a/test/orm/test_unitofwork.py
+++ b/test/orm/test_unitofwork.py
@@ -203,14 +203,6 @@ class UnicodeSchemaTest(fixtures.MappedTest):
cls.tables["t1"] = t1
cls.tables["t2"] = t2
- @classmethod
- def setup_class(cls):
- super(UnicodeSchemaTest, cls).setup_class()
-
- @classmethod
- def teardown_class(cls):
- super(UnicodeSchemaTest, cls).teardown_class()
-
def test_mapping(self):
t2, t1 = self.tables.t2, self.tables.t1
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index 4e713627c..65089f773 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -771,14 +771,13 @@ class RudimentaryFlushTest(UOWTest):
class SingleCycleTest(UOWTest):
- def teardown(self):
+ def teardown_test(self):
engines.testing_reaper.rollback_all()
# mysql can't handle delete from nodes
# since it doesn't deal with the FKs correctly,
# so wipe out the parent_id first
with testing.db.begin() as conn:
conn.execute(self.tables.nodes.update().values(parent_id=None))
- super(SingleCycleTest, self).teardown()
def test_one_to_many_save(self):
Node, nodes = self.classes.Node, self.tables.nodes
diff --git a/test/requirements.py b/test/requirements.py
index d5a718372..3c9b39ac7 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -1624,15 +1624,15 @@ class DefaultRequirements(SuiteRequirements):
@property
def postgresql_utf8_server_encoding(self):
+ def go(config):
+ if not against(config, "postgresql"):
+ return False
- return only_if(
- lambda config: against(config, "postgresql")
- and config.db.connect(close_with_result=True)
- .exec_driver_sql("show server_encoding")
- .scalar()
- .lower()
- == "utf8"
- )
+ with config.db.connect() as conn:
+ enc = conn.exec_driver_sql("show server_encoding").scalar()
+ return enc.lower() == "utf8"
+
+ return only_if(go)
@property
def cxoracle6_or_greater(self):
diff --git a/test/sql/test_case_statement.py b/test/sql/test_case_statement.py
index 4bef1df7f..b44971cec 100644
--- a/test/sql/test_case_statement.py
+++ b/test/sql/test_case_statement.py
@@ -26,7 +26,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
metadata = MetaData()
global info_table
info_table = Table(
@@ -52,7 +52,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
info_table.drop(conn)
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index 70281d4e8..1ac3613f7 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py
@@ -1203,7 +1203,7 @@ class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
# TODO: we need to get dialects here somehow, perhaps in test_suite?
[
importlib.import_module("sqlalchemy.dialects.%s" % d)
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index fdffe04bf..4429753ec 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -4306,7 +4306,7 @@ class StringifySpecialTest(fixtures.TestBase):
class KwargPropagationTest(fixtures.TestBase):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy.sql.expression import ColumnClause, TableClause
class CatchCol(ColumnClause):
diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py
index 2a2e70bc3..8be7eed1f 100644
--- a/test/sql/test_defaults.py
+++ b/test/sql/test_defaults.py
@@ -503,9 +503,8 @@ class DefaultRoundTripTest(fixtures.TablesTest):
Column("col11", MyType(), default="foo"),
)
- def teardown(self):
+ def teardown_test(self):
self.default_generator["x"] = 50
- super(DefaultRoundTripTest, self).teardown()
def test_standalone(self, connection):
t = self.tables.default_test
@@ -1226,7 +1225,7 @@ class SpecialTypePKTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
class MyInteger(TypeDecorator):
impl = Integer
diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py
index acc12a5fe..777565220 100644
--- a/test/sql/test_deprecations.py
+++ b/test/sql/test_deprecations.py
@@ -2488,12 +2488,12 @@ class LegacySequenceExecTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
cls.seq = Sequence("my_sequence")
cls.seq.create(testing.db)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
cls.seq.drop(testing.db)
def _assert_seq_result(self, ret):
@@ -2574,7 +2574,7 @@ class LegacySequenceExecTest(fixtures.TestBase):
class DDLDeprecatedBindTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
if inspect(conn).has_table("foo"):
conn.execute(schema.DropTable(table("foo")))
diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py
index 4edc9d025..a6001ba9d 100644
--- a/test/sql/test_external_traversal.py
+++ b/test/sql/test_external_traversal.py
@@ -47,7 +47,7 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults):
ability to copy and modify a ClauseElement in place."""
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global A, B
# establish two fictitious ClauseElements.
@@ -321,7 +321,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2, t3
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
@@ -1012,7 +1012,7 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table(
"table1",
@@ -1196,7 +1196,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
@@ -1943,7 +1943,7 @@ class SpliceJoinsTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global table1, table2, table3, table4
def _table(name):
@@ -2031,7 +2031,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
@@ -2128,7 +2128,7 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL):
# fixme: consolidate converage from elsewhere here and expand
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py
index 6afe41aac..b0bcee18e 100644
--- a/test/sql/test_from_linter.py
+++ b/test/sql/test_from_linter.py
@@ -25,7 +25,7 @@ class TestFindUnmatchingFroms(fixtures.TablesTest):
Table("table_c", metadata, Column("col_c", Integer, primary_key=True))
Table("table_d", metadata, Column("col_d", Integer, primary_key=True))
- def setup(self):
+ def setup_test(self):
self.a = self.tables.table_a
self.b = self.tables.table_b
self.c = self.tables.table_c
@@ -267,8 +267,10 @@ class TestLinter(fixtures.TablesTest):
with self.bind.connect() as conn:
conn.execute(query)
- def test_no_linting(self):
- eng = engines.testing_engine(options={"enable_from_linting": False})
+ def test_no_linting(self, metadata, connection):
+ eng = engines.testing_engine(
+ options={"enable_from_linting": False, "use_reaper": False}
+ )
eng.pool = self.bind.pool # needed for SQLite
a, b = self.tables("table_a", "table_b")
query = select(a.c.col_a).where(b.c.col_b == 5)
diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py
index 91076f9c3..32ea642d7 100644
--- a/test/sql/test_functions.py
+++ b/test/sql/test_functions.py
@@ -54,10 +54,10 @@ table1 = table(
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
- def setup(self):
+ def setup_test(self):
self._registry = deepcopy(functions._registry)
- def teardown(self):
+ def teardown_test(self):
functions._registry = self._registry
def test_compile(self):
@@ -938,7 +938,7 @@ class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase):
class ExecuteTest(fixtures.TestBase):
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
pass
def test_conn_execute(self, connection):
@@ -1113,10 +1113,10 @@ class ExecuteTest(fixtures.TestBase):
class RegisterTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
- def setup(self):
+ def setup_test(self):
self._registry = deepcopy(functions._registry)
- def teardown(self):
+ def teardown_test(self):
functions._registry = self._registry
def test_GenericFunction_is_registered(self):
diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py
index 502f70ce7..9bf351f5c 100644
--- a/test/sql/test_metadata.py
+++ b/test/sql/test_metadata.py
@@ -4257,7 +4257,7 @@ class DialectKWArgTest(fixtures.TestBase):
with mock.patch("sqlalchemy.dialects.registry.load", load):
yield
- def teardown(self):
+ def teardown_test(self):
Index._kw_registry.clear()
def test_participating(self):
diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py
index aaeed68dd..270e79ba1 100644
--- a/test/sql/test_operators.py
+++ b/test/sql/test_operators.py
@@ -608,7 +608,7 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
- def setUp(self):
+ def setup_test(self):
class MyTypeCompiler(compiler.GenericTypeCompiler):
def visit_mytype(self, type_, **kw):
return "MYTYPE"
@@ -766,7 +766,7 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
- def setUp(self):
+ def setup_test(self):
class MyTypeCompiler(compiler.GenericTypeCompiler):
def visit_mytype(self, type_, **kw):
return "MYTYPE"
@@ -2370,7 +2370,7 @@ class MatchTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
@@ -2403,7 +2403,7 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
class RegexpTestStrCompiler(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default_enhanced"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py
index 136f10cf4..7ad12c620 100644
--- a/test/sql/test_resultset.py
+++ b/test/sql/test_resultset.py
@@ -661,8 +661,8 @@ class CursorResultTest(fixtures.TablesTest):
assert_raises(KeyError, lambda: row._mapping["Case_insensitive"])
assert_raises(KeyError, lambda: row._mapping["casesensitive"])
- def test_row_case_sensitive_unoptimized(self):
- with engines.testing_engine().connect() as ins_conn:
+ def test_row_case_sensitive_unoptimized(self, testing_engine):
+ with testing_engine().connect() as ins_conn:
row = ins_conn.execute(
select(
literal_column("1").label("case_insensitive"),
@@ -1234,8 +1234,7 @@ class CursorResultTest(fixtures.TablesTest):
eq_(proxy[0], "value")
eq_(proxy._mapping["key"], "value")
- @testing.provide_metadata
- def test_no_rowcount_on_selects_inserts(self):
+ def test_no_rowcount_on_selects_inserts(self, metadata, testing_engine):
"""assert that rowcount is only called on deletes and updates.
This because cursor.rowcount may can be expensive on some dialects
@@ -1244,9 +1243,7 @@ class CursorResultTest(fixtures.TablesTest):
"""
- metadata = self.metadata
-
- engine = engines.testing_engine()
+ engine = testing_engine()
t = Table("t1", metadata, Column("data", String(10)))
metadata.create_all(engine)
@@ -2132,7 +2129,9 @@ class AlternateCursorResultTest(fixtures.TablesTest):
@classmethod
def setup_bind(cls):
- cls.engine = engine = engines.testing_engine("sqlite://")
+ cls.engine = engine = engines.testing_engine(
+ "sqlite://", options={"scope": "class"}
+ )
return engine
@classmethod
diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py
index 65325aa6f..5cfc2663f 100644
--- a/test/sql/test_sequences.py
+++ b/test/sql/test_sequences.py
@@ -100,12 +100,12 @@ class SequenceExecTest(fixtures.TestBase):
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
cls.seq = Sequence("my_sequence")
cls.seq.create(testing.db)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
cls.seq.drop(testing.db)
def _assert_seq_result(self, ret):
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 0e1147800..64ace87df 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -1375,7 +1375,7 @@ class VariantBackendTest(fixtures.TestBase, AssertsCompiledSQL):
class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
- def setup(self):
+ def setup_test(self):
class UTypeOne(types.UserDefinedType):
def get_col_spec(self):
return "UTYPEONE"
@@ -2504,7 +2504,7 @@ class BinaryTest(fixtures.TablesTest, AssertsExecutionResults):
class JSONTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
@@ -3445,7 +3445,12 @@ class BooleanTest(
@testing.requires.non_native_boolean_unconstrained
def test_constraint(self, connection):
assert_raises(
- (exc.IntegrityError, exc.ProgrammingError, exc.OperationalError),
+ (
+ exc.IntegrityError,
+ exc.ProgrammingError,
+ exc.OperationalError,
+ exc.InternalError, # older pymysql's do this
+ ),
connection.exec_driver_sql,
"insert into boolean_table (id, value) values(1, 5)",
)
diff --git a/tox.ini b/tox.ini
index 8f2fd9de0..ea2b76e16 100644
--- a/tox.ini
+++ b/tox.ini
@@ -15,7 +15,9 @@ install_command=python -m pip install {env:TOX_PIP_OPTS:} {opts} {packages}
usedevelop=
cov: True
-deps=pytest>=4.6.11 # this can be 6.x once we are on python 3 only
+deps=
+ pytest>=4.6.11,<5.0; python_version < '3'
+ pytest>=6.2; python_version >= '3'
pytest-xdist
greenlet != 0.4.17
mock; python_version < '3.3'
@@ -74,9 +76,11 @@ setenv=
sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}
postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql}
+ py2{,7}-postgresql: POSTGRESQL={env:TOX_POSTGRESQL_PY2K:{env:TOX_POSTGRESQL:--db postgresql}}
py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000}
mysql: MYSQL={env:TOX_MYSQL:--db mysql}
+ py2{,7}-mysql: MYSQL={env:TOX_MYSQL_PY2K:{env:TOX_MYSQL:--db mysql}}
mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql}
py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql}
@@ -89,7 +93,7 @@ setenv=
# tox as of 2.0 blocks all environment variables from the
# outside, unless they are here (or in TOX_TESTENV_PASSENV,
# wildcards OK). Need at least these
-passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL TOX_MYSQL TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS EXTRA_PG_DRIVERS EXTRA_MYSQL_DRIVERS
+passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL TOX_POSTGRESQL_PY2K TOX_MYSQL TOX_MYSQL_PY2K TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS EXTRA_PG_DRIVERS EXTRA_MYSQL_DRIVERS
# for nocext, we rm *.so in lib in case we are doing usedevelop=True
commands=