diff options
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r-- | lib/sqlalchemy/testing/provision.py | 16 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_dialect.py | 154 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_insert.py | 21 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_results.py | 20 |
5 files changed, 220 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 7ba89b505..a8650f222 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -459,3 +459,19 @@ def set_default_schema_on_connection(cfg, dbapi_connection, schema_name): "backend does not implement a schema name set function: %s" % (cfg.db.url,) ) + + +@register.init +def upsert(cfg, table, returning, set_lambda=None): + """return the backends insert..on conflict / on dupe etc. construct. + + while we should add a backend-neutral upsert construct as well, such as + insert().upsert(), it's important that we continue to test the + backend-specific insert() constructs since if we do implement + insert().upsert(), that would be using a different codepath for the things + we need to test like insertmanyvalues, etc. + + """ + raise NotImplementedError( + f"backend does not include an upsert implementation: {cfg.db.url}" + ) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 874383394..3a0fc818d 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -424,6 +424,15 @@ class SuiteRequirements(Requirements): ) @property + def insertmanyvalues(self): + return exclusions.only_if( + lambda config: config.db.dialect.supports_multivalues_insert + and config.db.dialect.insert_returning + and config.db.dialect.use_insertmanyvalues, + "%(database)s %(does_support)s 'insertmanyvalues functionality", + ) + + @property def tuple_in(self): """Target platform supports the syntax "(x, y) IN ((x1, y1), (x2, y2), ...)" diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index bb2dd6574..efad81930 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -11,6 +11,7 @@ from .. import fixtures from .. import is_true from .. import ne_ from .. import provide_metadata +from ..assertions import expect_raises from ..assertions import expect_raises_message from ..config import requirements from ..provision import set_default_schema_on_connection @@ -412,3 +413,156 @@ class DifficultParametersTest(fixtures.TestBase): # name works as the key from cursor.description eq_(row._mapping[name], "some name") + + +class ReturningGuardsTest(fixtures.TablesTest): + """test that the various 'returning' flags are set appropriately""" + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + + Table( + "t", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + + @testing.fixture + def run_stmt(self, connection): + t = self.tables.t + + def go(stmt, executemany, id_param_name, expect_success): + stmt = stmt.returning(t.c.id) + + if executemany: + if not expect_success: + # for RETURNING executemany(), we raise our own + # error as this is independent of general RETURNING + # support + with expect_raises_message( + exc.StatementError, + rf"Dialect {connection.dialect.name}\+" + f"{connection.dialect.driver} with " + f"current server capabilities does not support " + f".*RETURNING when executemany is used", + ): + result = connection.execute( + stmt, + [ + {id_param_name: 1, "data": "d1"}, + {id_param_name: 2, "data": "d2"}, + {id_param_name: 3, "data": "d3"}, + ], + ) + else: + result = connection.execute( + stmt, + [ + {id_param_name: 1, "data": "d1"}, + {id_param_name: 2, "data": "d2"}, + {id_param_name: 3, "data": "d3"}, + ], + ) + eq_(result.all(), [(1,), (2,), (3,)]) + else: + if not expect_success: + # for RETURNING execute(), we pass all the way to the DB + # and let it fail + with expect_raises(exc.DBAPIError): + connection.execute( + stmt, {id_param_name: 1, "data": "d1"} + ) + else: + result = connection.execute( + stmt, {id_param_name: 1, "data": "d1"} + ) + eq_(result.all(), [(1,)]) + + return go + + def test_insert_single(self, connection, run_stmt): + t = self.tables.t + + stmt = t.insert() + + run_stmt(stmt, False, "id", connection.dialect.insert_returning) + + def test_insert_many(self, connection, run_stmt): + t = self.tables.t + + stmt = t.insert() + + run_stmt( + stmt, True, "id", connection.dialect.insert_executemany_returning + ) + + def test_update_single(self, connection, run_stmt): + t = self.tables.t + + connection.execute( + t.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + stmt = t.update().where(t.c.id == bindparam("b_id")) + + run_stmt(stmt, False, "b_id", connection.dialect.update_returning) + + def test_update_many(self, connection, run_stmt): + t = self.tables.t + + connection.execute( + t.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + stmt = t.update().where(t.c.id == bindparam("b_id")) + + run_stmt( + stmt, True, "b_id", connection.dialect.update_executemany_returning + ) + + def test_delete_single(self, connection, run_stmt): + t = self.tables.t + + connection.execute( + t.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + stmt = t.delete().where(t.c.id == bindparam("b_id")) + + run_stmt(stmt, False, "b_id", connection.dialect.delete_returning) + + def test_delete_many(self, connection, run_stmt): + t = self.tables.t + + connection.execute( + t.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + stmt = t.delete().where(t.c.id == bindparam("b_id")) + + run_stmt( + stmt, True, "b_id", connection.dialect.delete_executemany_returning + ) diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index 2307d3b3f..ae54f6bcd 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -338,6 +338,7 @@ class ReturningTest(fixtures.TablesTest): r = connection.execute( table.insert().returning(table.c.id), dict(data="some data") ) + pk = r.first()[0] fetched_pk = connection.scalar(select(table.c.id)) eq_(fetched_pk, pk) @@ -357,5 +358,25 @@ class ReturningTest(fixtures.TablesTest): pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) eq_(r.inserted_primary_key, (pk,)) + @requirements.insert_executemany_returning + def test_insertmanyvalues_returning(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert().returning( + self.tables.autoinc_pk.c.id + ), + [ + {"data": "d1"}, + {"data": "d2"}, + {"data": "d3"}, + {"data": "d4"}, + {"data": "d5"}, + ], + ) + rall = r.all() + + pks = connection.execute(select(self.tables.autoinc_pk.c.id)) + + eq_(rall, pks.all()) + __all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest") diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index 59e9cc7f4..7d79c67ae 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -164,6 +164,26 @@ class PercentSchemaNamesTest(fixtures.TablesTest): ) self._assert_table(connection) + @requirements.insert_executemany_returning + def test_executemany_returning_roundtrip(self, connection): + percent_table = self.tables.percent_table + connection.execute( + percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12} + ) + result = connection.execute( + percent_table.insert().returning( + percent_table.c["percent%"], + percent_table.c["spaces % more spaces"], + ), + [ + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ], + ) + eq_(result.all(), [(7, 11), (9, 10), (11, 9)]) + self._assert_table(connection) + def _assert_table(self, conn): percent_table = self.tables.percent_table lightweight_percent_table = self.tables.lightweight_percent_table |