summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-07 11:17:47 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-07 11:18:25 -0500
commita8102ba496c4c11eae6b904a962cf352902f0de7 (patch)
treef181217c8f5c6a9ef7e354c8cb86561ae880d949
parentde95dc4ce5ed44cc63d9fd8b2e00a78858a73d2a (diff)
downloadsqlalchemy-a8102ba496c4c11eae6b904a962cf352902f0de7.tar.gz
test sqlite w/ savepoint workaround in session fixture test
Fixes: #7795 Change-Id: Ib790581555656c088f86c00080c70d19ca295a03 (cherry picked from commit fbacb1991585202a5bf22acb0d36b5c979bcfad8)
-rw-r--r--lib/sqlalchemy/testing/engines.py14
-rw-r--r--test/orm/test_transaction.py12
-rw-r--r--test/requirements.py10
3 files changed, 30 insertions, 6 deletions
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index a92d476ac..b8be6b9bd 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -276,10 +276,12 @@ def testing_engine(
future=None,
asyncio=False,
transfer_staticpool=False,
+ _sqlite_savepoint=False,
):
"""Produce an engine configured by --options with optional overrides."""
if asyncio:
+ assert not _sqlite_savepoint
from sqlalchemy.ext.asyncio import (
create_async_engine as create_engine,
)
@@ -294,9 +296,11 @@ def testing_engine(
if not options:
use_reaper = True
scope = "function"
+ sqlite_savepoint = False
else:
use_reaper = options.pop("use_reaper", True)
scope = options.pop("scope", "function")
+ sqlite_savepoint = options.pop("sqlite_savepoint", False)
url = url or config.db.url
@@ -312,6 +316,16 @@ def testing_engine(
engine = create_engine(url, **options)
+ if sqlite_savepoint and engine.name == "sqlite":
+ # apply SQLite savepoint workaround
+ @event.listens_for(engine, "connect")
+ def do_connect(dbapi_connection, connection_record):
+ dbapi_connection.isolation_level = None
+
+ @event.listens_for(engine, "begin")
+ def do_begin(conn):
+ conn.exec_driver_sql("BEGIN")
+
if transfer_staticpool:
from sqlalchemy.pool import StaticPool
diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py
index 603ec079a..e077220e1 100644
--- a/test/orm/test_transaction.py
+++ b/test/orm/test_transaction.py
@@ -2526,10 +2526,10 @@ class NaturalPKRollbackTest(fixtures.MappedTest):
class JoinIntoAnExternalTransactionFixture(object):
"""Test the "join into an external transaction" examples"""
- __leave_connections_for_teardown__ = True
-
def setup_test(self):
- self.engine = testing.db
+ self.engine = engines.testing_engine(
+ options={"use_reaper": False, "sqlite_savepoint": True}
+ )
self.connection = self.engine.connect()
self.metadata = MetaData()
@@ -2590,7 +2590,7 @@ class NewStyleJoinIntoAnExternalTransactionTest(
# bind an individual Session to the connection
self.session = Session(bind=self.connection, future=True)
- if testing.requires.savepoints.enabled:
+ if testing.requires.compat_savepoints.enabled:
self.nested = self.connection.begin_nested()
@event.listens_for(self.session, "after_transaction_end")
@@ -2607,7 +2607,7 @@ class NewStyleJoinIntoAnExternalTransactionTest(
if self.trans.is_active:
self.trans.rollback()
- @testing.requires.savepoints
+ @testing.requires.compat_savepoints
def test_something_with_context_managers(self):
A = self.A
@@ -2673,7 +2673,7 @@ class LegacyJoinIntoAnExternalTransactionTest(
# bind an individual Session to the connection
self.session = Session(bind=self.connection)
- if testing.requires.savepoints.enabled:
+ if testing.requires.compat_savepoints.enabled:
# start the session in a SAVEPOINT...
self.session.begin_nested()
diff --git a/test/requirements.py b/test/requirements.py
index 1780e3b21..4c9ac40c5 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -559,6 +559,16 @@ class DefaultRequirements(SuiteRequirements):
)
@property
+ def compat_savepoints(self):
+ """Target database must support savepoints, or a compat
+ recipe e.g. for sqlite will be used"""
+
+ return skip_if(
+ ["sybase", ("mysql", "<", (5, 0, 3))],
+ "savepoints not supported",
+ )
+
+ @property
def savepoints_w_release(self):
return self.savepoints + skip_if(
["oracle", "mssql"],