summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/provision.py16
-rw-r--r--lib/sqlalchemy/testing/requirements.py9
-rw-r--r--lib/sqlalchemy/testing/suite/test_dialect.py154
-rw-r--r--lib/sqlalchemy/testing/suite/test_insert.py21
-rw-r--r--lib/sqlalchemy/testing/suite/test_results.py20
5 files changed, 220 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
index 7ba89b505..a8650f222 100644
--- a/lib/sqlalchemy/testing/provision.py
+++ b/lib/sqlalchemy/testing/provision.py
@@ -459,3 +459,19 @@ def set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
"backend does not implement a schema name set function: %s"
% (cfg.db.url,)
)
+
+
+@register.init
+def upsert(cfg, table, returning, set_lambda=None):
+ """return the backends insert..on conflict / on dupe etc. construct.
+
+ while we should add a backend-neutral upsert construct as well, such as
+ insert().upsert(), it's important that we continue to test the
+ backend-specific insert() constructs since if we do implement
+ insert().upsert(), that would be using a different codepath for the things
+ we need to test like insertmanyvalues, etc.
+
+ """
+ raise NotImplementedError(
+ f"backend does not include an upsert implementation: {cfg.db.url}"
+ )
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index 874383394..3a0fc818d 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -424,6 +424,15 @@ class SuiteRequirements(Requirements):
)
@property
+ def insertmanyvalues(self):
+ return exclusions.only_if(
+ lambda config: config.db.dialect.supports_multivalues_insert
+ and config.db.dialect.insert_returning
+ and config.db.dialect.use_insertmanyvalues,
+ "%(database)s %(does_support)s 'insertmanyvalues functionality",
+ )
+
+ @property
def tuple_in(self):
"""Target platform supports the syntax
"(x, y) IN ((x1, y1), (x2, y2), ...)"
diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py
index bb2dd6574..efad81930 100644
--- a/lib/sqlalchemy/testing/suite/test_dialect.py
+++ b/lib/sqlalchemy/testing/suite/test_dialect.py
@@ -11,6 +11,7 @@ from .. import fixtures
from .. import is_true
from .. import ne_
from .. import provide_metadata
+from ..assertions import expect_raises
from ..assertions import expect_raises_message
from ..config import requirements
from ..provision import set_default_schema_on_connection
@@ -412,3 +413,156 @@ class DifficultParametersTest(fixtures.TestBase):
# name works as the key from cursor.description
eq_(row._mapping[name], "some name")
+
+
+class ReturningGuardsTest(fixtures.TablesTest):
+ """test that the various 'returning' flags are set appropriately"""
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+
+ Table(
+ "t",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+
+ @testing.fixture
+ def run_stmt(self, connection):
+ t = self.tables.t
+
+ def go(stmt, executemany, id_param_name, expect_success):
+ stmt = stmt.returning(t.c.id)
+
+ if executemany:
+ if not expect_success:
+ # for RETURNING executemany(), we raise our own
+ # error as this is independent of general RETURNING
+ # support
+ with expect_raises_message(
+ exc.StatementError,
+ rf"Dialect {connection.dialect.name}\+"
+ f"{connection.dialect.driver} with "
+ f"current server capabilities does not support "
+ f".*RETURNING when executemany is used",
+ ):
+ result = connection.execute(
+ stmt,
+ [
+ {id_param_name: 1, "data": "d1"},
+ {id_param_name: 2, "data": "d2"},
+ {id_param_name: 3, "data": "d3"},
+ ],
+ )
+ else:
+ result = connection.execute(
+ stmt,
+ [
+ {id_param_name: 1, "data": "d1"},
+ {id_param_name: 2, "data": "d2"},
+ {id_param_name: 3, "data": "d3"},
+ ],
+ )
+ eq_(result.all(), [(1,), (2,), (3,)])
+ else:
+ if not expect_success:
+ # for RETURNING execute(), we pass all the way to the DB
+ # and let it fail
+ with expect_raises(exc.DBAPIError):
+ connection.execute(
+ stmt, {id_param_name: 1, "data": "d1"}
+ )
+ else:
+ result = connection.execute(
+ stmt, {id_param_name: 1, "data": "d1"}
+ )
+ eq_(result.all(), [(1,)])
+
+ return go
+
+ def test_insert_single(self, connection, run_stmt):
+ t = self.tables.t
+
+ stmt = t.insert()
+
+ run_stmt(stmt, False, "id", connection.dialect.insert_returning)
+
+ def test_insert_many(self, connection, run_stmt):
+ t = self.tables.t
+
+ stmt = t.insert()
+
+ run_stmt(
+ stmt, True, "id", connection.dialect.insert_executemany_returning
+ )
+
+ def test_update_single(self, connection, run_stmt):
+ t = self.tables.t
+
+ connection.execute(
+ t.insert(),
+ [
+ {"id": 1, "data": "d1"},
+ {"id": 2, "data": "d2"},
+ {"id": 3, "data": "d3"},
+ ],
+ )
+
+ stmt = t.update().where(t.c.id == bindparam("b_id"))
+
+ run_stmt(stmt, False, "b_id", connection.dialect.update_returning)
+
+ def test_update_many(self, connection, run_stmt):
+ t = self.tables.t
+
+ connection.execute(
+ t.insert(),
+ [
+ {"id": 1, "data": "d1"},
+ {"id": 2, "data": "d2"},
+ {"id": 3, "data": "d3"},
+ ],
+ )
+
+ stmt = t.update().where(t.c.id == bindparam("b_id"))
+
+ run_stmt(
+ stmt, True, "b_id", connection.dialect.update_executemany_returning
+ )
+
+ def test_delete_single(self, connection, run_stmt):
+ t = self.tables.t
+
+ connection.execute(
+ t.insert(),
+ [
+ {"id": 1, "data": "d1"},
+ {"id": 2, "data": "d2"},
+ {"id": 3, "data": "d3"},
+ ],
+ )
+
+ stmt = t.delete().where(t.c.id == bindparam("b_id"))
+
+ run_stmt(stmt, False, "b_id", connection.dialect.delete_returning)
+
+ def test_delete_many(self, connection, run_stmt):
+ t = self.tables.t
+
+ connection.execute(
+ t.insert(),
+ [
+ {"id": 1, "data": "d1"},
+ {"id": 2, "data": "d2"},
+ {"id": 3, "data": "d3"},
+ ],
+ )
+
+ stmt = t.delete().where(t.c.id == bindparam("b_id"))
+
+ run_stmt(
+ stmt, True, "b_id", connection.dialect.delete_executemany_returning
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py
index 2307d3b3f..ae54f6bcd 100644
--- a/lib/sqlalchemy/testing/suite/test_insert.py
+++ b/lib/sqlalchemy/testing/suite/test_insert.py
@@ -338,6 +338,7 @@ class ReturningTest(fixtures.TablesTest):
r = connection.execute(
table.insert().returning(table.c.id), dict(data="some data")
)
+
pk = r.first()[0]
fetched_pk = connection.scalar(select(table.c.id))
eq_(fetched_pk, pk)
@@ -357,5 +358,25 @@ class ReturningTest(fixtures.TablesTest):
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
eq_(r.inserted_primary_key, (pk,))
+ @requirements.insert_executemany_returning
+ def test_insertmanyvalues_returning(self, connection):
+ r = connection.execute(
+ self.tables.autoinc_pk.insert().returning(
+ self.tables.autoinc_pk.c.id
+ ),
+ [
+ {"data": "d1"},
+ {"data": "d2"},
+ {"data": "d3"},
+ {"data": "d4"},
+ {"data": "d5"},
+ ],
+ )
+ rall = r.all()
+
+ pks = connection.execute(select(self.tables.autoinc_pk.c.id))
+
+ eq_(rall, pks.all())
+
__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
index 59e9cc7f4..7d79c67ae 100644
--- a/lib/sqlalchemy/testing/suite/test_results.py
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -164,6 +164,26 @@ class PercentSchemaNamesTest(fixtures.TablesTest):
)
self._assert_table(connection)
+ @requirements.insert_executemany_returning
+ def test_executemany_returning_roundtrip(self, connection):
+ percent_table = self.tables.percent_table
+ connection.execute(
+ percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
+ )
+ result = connection.execute(
+ percent_table.insert().returning(
+ percent_table.c["percent%"],
+ percent_table.c["spaces % more spaces"],
+ ),
+ [
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
+ ],
+ )
+ eq_(result.all(), [(7, 11), (9, 10), (11, 9)])
+ self._assert_table(connection)
+
def _assert_table(self, conn):
percent_table = self.tables.percent_table
lightweight_percent_table = self.tables.lightweight_percent_table