diff options
Diffstat (limited to 'test/sql/test_insert_exec.py')
-rw-r--r-- | test/sql/test_insert_exec.py | 1625 |
1 files changed, 1624 insertions, 1 deletions
diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index 3b5a1856c..f545671e7 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -1,11 +1,19 @@ +import contextlib +import functools import itertools +import uuid from sqlalchemy import and_ +from sqlalchemy import ARRAY from sqlalchemy import bindparam +from sqlalchemy import DateTime from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import func +from sqlalchemy import Identity +from sqlalchemy import insert +from sqlalchemy import insert_sentinel from sqlalchemy import INT from sqlalchemy import Integer from sqlalchemy import literal @@ -14,16 +22,22 @@ from sqlalchemy import Sequence from sqlalchemy import sql from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import TypeDecorator +from sqlalchemy import Uuid from sqlalchemy import VARCHAR from sqlalchemy.engine import cursor as _cursor +from sqlalchemy.sql.compiler import InsertmanyvaluesSentinelOpts from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import config from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import mock from sqlalchemy.testing import provision +from sqlalchemy.testing.fixtures import insertmanyvalues_fixture from sqlalchemy.testing.provision import normalize_sequence from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -924,7 +938,7 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): config, data, (data,), - lambda inserted: {"x": inserted.x + " upserted"}, + set_lambda=lambda inserted: {"x": inserted.x + " upserted"}, ) result = connection.execute(stmt, upsert_data) @@ -1169,3 +1183,1612 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): "INSERT..RETURNING when executemany", ): conn.execute(stmt.returning(t.c.id), data) + + +class IMVSentinelTest(fixtures.TestBase): + __backend__ = True + + __requires__ = ("insert_returning",) + + def _expect_downgrade_warnings( + self, + *, + warn_for_downgrades, + sort_by_parameter_order, + separate_sentinel=False, + server_autoincrement=False, + client_side_pk=False, + autoincrement_is_sequence=False, + connection=None, + ): + + if connection: + dialect = connection.dialect + else: + dialect = testing.db.dialect + + if ( + sort_by_parameter_order + and warn_for_downgrades + and dialect.use_insertmanyvalues + ): + + if ( + not separate_sentinel + and ( + server_autoincrement + and ( + not ( + dialect.insertmanyvalues_implicit_sentinel # noqa: E501 + & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT + ) + or ( + autoincrement_is_sequence + and not ( + dialect.insertmanyvalues_implicit_sentinel # noqa: E501 + & InsertmanyvaluesSentinelOpts.SEQUENCE + ) + ) + ) + ) + or ( + not separate_sentinel + and not server_autoincrement + and not client_side_pk + ) + ): + return expect_warnings( + "Batches were downgraded", + raise_on_any_unexpected=True, + ) + + return contextlib.nullcontext() + + @testing.variation + def sort_by_parameter_order(self): + return [True, False] + + @testing.variation + def warn_for_downgrades(self): + return [True, False] + + @testing.variation + def randomize_returning(self): + return [True, False] + + @testing.requires.insertmanyvalues + def test_fixture_randomizing(self, connection, metadata): + t = Table( + "t", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("data", String(50)), + ) + metadata.create_all(connection) + + insertmanyvalues_fixture(connection, randomize_rows=True) + + results = set() + + for i in range(15): + result = connection.execute( + insert(t).returning(t.c.data, sort_by_parameter_order=False), + [{"data": "d1"}, {"data": "d2"}, {"data": "d3"}], + ) + + hashed_result = tuple(result.all()) + results.add(hashed_result) + if len(results) > 1: + return + else: + assert False, "got same order every time for 15 tries" + + @testing.only_on("postgresql>=13") + @testing.variation("downgrade", [True, False]) + def test_fixture_downgraded(self, connection, metadata, downgrade): + t = Table( + "t", + metadata, + Column( + "id", + Uuid(), + server_default=func.gen_random_uuid(), + primary_key=True, + ), + Column("data", String(50)), + ) + metadata.create_all(connection) + + r1 = connection.execute( + insert(t).returning(t.c.data, sort_by_parameter_order=True), + [{"data": "d1"}, {"data": "d2"}, {"data": "d3"}], + ) + eq_(r1.all(), [("d1",), ("d2",), ("d3",)]) + + if downgrade: + insertmanyvalues_fixture(connection, warn_on_downgraded=True) + + with self._expect_downgrade_warnings( + warn_for_downgrades=True, + sort_by_parameter_order=True, + ): + connection.execute( + insert(t).returning( + t.c.data, sort_by_parameter_order=True + ), + [{"data": "d1"}, {"data": "d2"}, {"data": "d3"}], + ) + else: + # run a plain test to help ensure the fixture doesn't leak to + # other tests + r1 = connection.execute( + insert(t).returning(t.c.data, sort_by_parameter_order=True), + [{"data": "d1"}, {"data": "d2"}, {"data": "d3"}], + ) + eq_(r1.all(), [("d1",), ("d2",), ("d3",)]) + + @testing.variation( + "sequence_type", + [ + ("sequence", testing.requires.sequences), + ("identity", testing.requires.identity_columns), + ], + ) + @testing.variation("increment", ["positive", "negative", "implicit"]) + @testing.variation("explicit_sentinel", [True, False]) + def test_invalid_identities( + self, + metadata, + connection, + warn_for_downgrades, + randomize_returning, + sort_by_parameter_order, + sequence_type: testing.Variation, + increment: testing.Variation, + explicit_sentinel, + ): + if sequence_type.sequence: + seq_cls = functools.partial(Sequence, name="t1_id_seq") + elif sequence_type.identity: + seq_cls = Identity + else: + sequence_type.fail() + + if increment.implicit: + sequence = seq_cls(start=1) + elif increment.positive: + sequence = seq_cls(start=1, increment=1) + elif increment.negative: + sequence = seq_cls(start=-1, increment=-1) + else: + increment.fail() + + t1 = Table( + "t1", + metadata, + Column( + "id", + Integer, + sequence, + primary_key=True, + insert_sentinel=bool(explicit_sentinel), + ), + Column("data", String(50)), + ) + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = insert(t1).returning( + t1.c.id, + t1.c.data, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [{"data": f"d{i}"} for i in range(10)] + + use_imv = testing.db.dialect.use_insertmanyvalues + if ( + use_imv + and increment.negative + and explicit_sentinel + and sort_by_parameter_order + ): + with expect_raises_message( + exc.InvalidRequestError, + rf"Can't use " + rf"{'SEQUENCE' if sequence_type.sequence else 'IDENTITY'} " + rf"default with negative increment", + ): + connection.execute(stmt, data) + return + elif ( + use_imv + and explicit_sentinel + and sort_by_parameter_order + and sequence_type.sequence + and not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.SEQUENCE + ) + ): + with expect_raises_message( + exc.InvalidRequestError, + r"Column t1.id can't be explicitly marked as a sentinel " + r"column .* as the particular type of default generation", + ): + connection.execute(stmt, data) + return + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=sort_by_parameter_order, + server_autoincrement=not increment.negative, + autoincrement_is_sequence=sequence_type.sequence, + ): + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + coll = list + else: + coll = set + + if increment.negative: + expected_data = [(-1 - i, f"d{i}") for i in range(10)] + else: + expected_data = [(i + 1, f"d{i}") for i in range(10)] + + eq_( + coll(result), + coll(expected_data), + ) + + @testing.combinations( + Integer(), + String(50), + (ARRAY(Integer()), testing.requires.array_type), + DateTime(), + Uuid(), + argnames="datatype", + ) + def test_inserts_w_all_nulls( + self, connection, metadata, sort_by_parameter_order, datatype + ): + """this test is geared towards the INSERT..SELECT VALUES case, + where if the VALUES have all NULL for some column, PostgreSQL assumes + the datatype must be TEXT and throws for other table datatypes. So an + additional layer of casts is applied to the SELECT p0,p1, p2... part of + the statement for all datatypes unconditionally. Even though the VALUES + clause also has bind casts for selected datatypes, this NULL handling + is needed even for simple datatypes. We'd prefer not to render bind + casts for all possible datatypes as that affects other kinds of + statements as well and also is very verbose for insertmanyvalues. + + + """ + t = Table( + "t", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("data", datatype), + ) + metadata.create_all(connection) + result = connection.execute( + insert(t).returning( + t.c.id, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), + [{"data": None}, {"data": None}, {"data": None}], + ) + eq_(set(result), {(1,), (2,), (3,)}) + + @testing.variation("pk_type", ["autoinc", "clientside"]) + @testing.variation("add_sentinel", ["none", "clientside", "sentinel"]) + def test_imv_w_additional_values( + self, + metadata, + connection, + sort_by_parameter_order, + pk_type: testing.Variation, + randomize_returning, + warn_for_downgrades, + add_sentinel, + ): + if pk_type.autoinc: + pk_col = Column("id", Integer(), Identity(), primary_key=True) + elif pk_type.clientside: + pk_col = Column("id", Uuid(), default=uuid.uuid4, primary_key=True) + else: + pk_type.fail() + + if add_sentinel.clientside: + extra_col = insert_sentinel( + "sentinel", type_=Uuid(), default=uuid.uuid4 + ) + elif add_sentinel.sentinel: + extra_col = insert_sentinel("sentinel") + else: + extra_col = Column("sentinel", Integer()) + + t1 = Table( + "t1", + metadata, + pk_col, + Column("data", String(30)), + Column("moredata", String(30)), + extra_col, + Column( + "has_server_default", + String(50), + server_default="some_server_default", + ), + ) + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = ( + insert(t1) + .values(moredata="more data") + .returning( + t1.c.data, + t1.c.moredata, + t1.c.has_server_default, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + ) + data = [{"data": f"d{i}"} for i in range(10)] + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=sort_by_parameter_order, + separate_sentinel=not add_sentinel.none, + server_autoincrement=pk_type.autoinc, + client_side_pk=pk_type.clientside, + ): + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + coll = list + else: + coll = set + + eq_( + coll(result), + coll( + [ + (f"d{i}", "more data", "some_server_default") + for i in range(10) + ] + ), + ) + + def test_sentinel_incorrect_rowcount( + self, metadata, connection, sort_by_parameter_order + ): + """test assertions to ensure sentinel values don't have duplicates""" + + uuids = [uuid.uuid4() for i in range(10)] + + # make some dupes + uuids[3] = uuids[5] + uuids[9] = uuids[5] + + t1 = Table( + "data", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("data", String(50)), + insert_sentinel( + "uuids", + Uuid(), + default=functools.partial(next, iter(uuids)), + ), + ) + + metadata.create_all(connection) + + stmt = insert(t1).returning( + t1.c.data, + t1.c.uuids, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [{"data": f"d{i}"} for i in range(10)] + + if testing.db.dialect.use_insertmanyvalues and sort_by_parameter_order: + with expect_raises_message( + exc.InvalidRequestError, + "Sentinel-keyed result set did not produce correct " + "number of rows 10; produced 8.", + ): + connection.execute(stmt, data) + else: + result = connection.execute(stmt, data) + eq_( + set(result.all()), + {(f"d{i}", uuids[i]) for i in range(10)}, + ) + + @testing.variation("resolve_sentinel_values", [True, False]) + def test_sentinel_cant_match_keys( + self, + metadata, + connection, + sort_by_parameter_order, + resolve_sentinel_values, + ): + """test assertions to ensure sentinel values passed in parameter + structures can be identified when they come back in cursor.fetchall(). + + Values that are further modified by the database driver or by + SQL expressions (as in the case below) before being INSERTed + won't match coming back out, so datatypes need to implement + _sentinel_value_resolver() if this is the case. + + """ + + class UnsymmetricDataType(TypeDecorator): + cache_ok = True + impl = String + + def bind_expression(self, bindparam): + return func.lower(bindparam) + + if resolve_sentinel_values: + + def _sentinel_value_resolver(self, dialect): + def fix_sentinels(value): + return value.lower() + + return fix_sentinels + + t1 = Table( + "data", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("data", String(50)), + insert_sentinel("unsym", UnsymmetricDataType(10)), + ) + + metadata.create_all(connection) + + stmt = insert(t1).returning( + t1.c.data, + t1.c.unsym, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [{"data": f"d{i}", "unsym": f"UPPER_d{i}"} for i in range(10)] + + if ( + testing.db.dialect.use_insertmanyvalues + and sort_by_parameter_order + and not resolve_sentinel_values + ): + with expect_raises_message( + exc.InvalidRequestError, + r"Can't match sentinel values in result set to parameter " + r"sets; key 'UPPER_d.' was not found.", + ): + connection.execute(stmt, data) + else: + result = connection.execute(stmt, data) + eq_( + set(result.all()), + {(f"d{i}", f"upper_d{i}") for i in range(10)}, + ) + + @testing.variation("add_insert_sentinel", [True, False]) + def test_sentinel_insert_default_pk_only( + self, + metadata, + connection, + sort_by_parameter_order, + add_insert_sentinel, + ): + t1 = Table( + "data", + metadata, + Column( + "id", + Integer, + Identity(), + insert_sentinel=bool(add_insert_sentinel), + primary_key=True, + ), + Column("data", String(50)), + ) + + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, randomize_rows=True, warn_on_downgraded=False + ) + + stmt = insert(t1).returning( + t1.c.id, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [{} for i in range(3)] + + if ( + testing.db.dialect.use_insertmanyvalues + and add_insert_sentinel + and sort_by_parameter_order + and not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT + ) + ): + with expect_raises_message( + exc.InvalidRequestError, + "Column data.id can't be explicitly marked as a " + f"sentinel column when using the {testing.db.dialect.name} " + "dialect", + ): + connection.execute(stmt, data) + return + else: + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + # if we used a client side default function, or we had no sentinel + # at all, we're sorted + coll = list + else: + # otherwise we are not, we randomized the order in any case + coll = set + + eq_( + coll(result), + coll( + [ + (1,), + (2,), + (3,), + ] + ), + ) + + @testing.only_on("postgresql>=13") + @testing.variation("default_type", ["server_side", "client_side"]) + @testing.variation("add_insert_sentinel", [True, False]) + def test_no_sentinel_on_non_int_ss_function( + self, + metadata, + connection, + add_insert_sentinel, + default_type, + sort_by_parameter_order, + ): + + t1 = Table( + "data", + metadata, + Column( + "id", + Uuid(), + server_default=func.gen_random_uuid() + if default_type.server_side + else None, + default=uuid.uuid4 if default_type.client_side else None, + primary_key=True, + insert_sentinel=bool(add_insert_sentinel), + ), + Column("data", String(50)), + ) + + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, randomize_rows=True, warn_on_downgraded=False + ) + + stmt = insert(t1).returning( + t1.c.data, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [ + {"data": "d1"}, + {"data": "d2"}, + {"data": "d3"}, + ] + + if ( + default_type.server_side + and add_insert_sentinel + and sort_by_parameter_order + ): + with expect_raises_message( + exc.InvalidRequestError, + r"Column data.id can't be a sentinel column because it uses " + r"an explicit server side default that's not the Identity\(\)", + ): + connection.execute(stmt, data) + return + else: + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + # if we used a client side default function, or we had no sentinel + # at all, we're sorted + coll = list + else: + # otherwise we are not, we randomized the order in any case + coll = set + + eq_( + coll(result), + coll( + [ + ("d1",), + ("d2",), + ("d3",), + ] + ), + ) + + @testing.variation( + "pk_type", + [ + ("plain_autoinc", testing.requires.autoincrement_without_sequence), + ("sequence", testing.requires.sequences), + ("identity", testing.requires.identity_columns), + ], + ) + @testing.variation( + "sentinel", + [ + "none", # passes because we automatically downgrade + # for no sentinel col + "implicit_not_omitted", + "implicit_omitted", + "explicit", + "explicit_but_nullable", + "default_uuid", + "default_string_uuid", + ("identity", testing.requires.multiple_identity_columns), + ("sequence", testing.requires.sequences), + ], + ) + def test_sentinel_col_configurations( + self, + pk_type: testing.Variation, + sentinel: testing.Variation, + sort_by_parameter_order, + randomize_returning, + metadata, + connection, + ): + + if pk_type.plain_autoinc: + pk_col = Column("id", Integer, primary_key=True) + elif pk_type.sequence: + pk_col = Column( + "id", + Integer, + Sequence("result_id_seq", start=1), + primary_key=True, + ) + elif pk_type.identity: + pk_col = Column("id", Integer, Identity(), primary_key=True) + else: + pk_type.fail() + + if sentinel.implicit_not_omitted or sentinel.implicit_omitted: + _sentinel = insert_sentinel( + "sentinel", + omit_from_statements=bool(sentinel.implicit_omitted), + ) + elif sentinel.explicit: + _sentinel = Column( + "some_uuid", Uuid(), nullable=False, insert_sentinel=True + ) + elif sentinel.explicit_but_nullable: + _sentinel = Column("some_uuid", Uuid(), insert_sentinel=True) + elif sentinel.default_uuid or sentinel.default_string_uuid: + _sentinel = Column( + "some_uuid", + Uuid(native_uuid=bool(sentinel.default_uuid)), + insert_sentinel=True, + default=uuid.uuid4, + ) + elif sentinel.identity: + _sentinel = Column( + "some_identity", + Integer, + Identity(), + insert_sentinel=True, + ) + elif sentinel.sequence: + _sentinel = Column( + "some_identity", + Integer, + Sequence("some_id_seq", start=1), + insert_sentinel=True, + ) + else: + _sentinel = Column("some_uuid", Uuid()) + + t = Table("t", metadata, pk_col, Column("data", String(50)), _sentinel) + + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=True, + ) + + stmt = insert(t).returning( + pk_col, + t.c.data, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + if sentinel.explicit: + data = [ + {"data": f"d{i}", "some_uuid": uuid.uuid4()} + for i in range(150) + ] + else: + data = [{"data": f"d{i}"} for i in range(150)] + + expect_sentinel_use = ( + sort_by_parameter_order + and testing.db.dialect.insert_returning + and testing.db.dialect.use_insertmanyvalues + ) + + if sentinel.explicit_but_nullable and expect_sentinel_use: + with expect_raises_message( + exc.InvalidRequestError, + "Column t.some_uuid has been marked as a sentinel column " + "with no default generation function; it at least needs to " + "be marked nullable=False", + ): + connection.execute(stmt, data) + return + + elif ( + expect_sentinel_use + and sentinel.sequence + and not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.SEQUENCE + ) + ): + with expect_raises_message( + exc.InvalidRequestError, + "Column t.some_identity can't be explicitly marked as a " + f"sentinel column when using the {testing.db.dialect.name} " + "dialect", + ): + connection.execute(stmt, data) + return + + elif ( + sentinel.none + and expect_sentinel_use + and stmt.compile( + dialect=testing.db.dialect + )._get_sentinel_column_for_table(t) + is None + ): + with expect_warnings( + "Batches were downgraded for sorted INSERT", + raise_on_any_unexpected=True, + ): + result = connection.execute(stmt, data) + else: + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + eq_(list(result), [(i + 1, f"d{i}") for i in range(150)]) + else: + eq_(set(result), {(i + 1, f"d{i}") for i in range(150)}) + + @testing.variation( + "return_type", ["include_sentinel", "default_only", "return_defaults"] + ) + @testing.variation("add_sentinel_flag_to_col", [True, False]) + def test_sentinel_on_non_autoinc_primary_key( + self, + metadata, + connection, + return_type: testing.Variation, + sort_by_parameter_order, + randomize_returning, + add_sentinel_flag_to_col, + ): + uuids = [uuid.uuid4() for i in range(10)] + _some_uuids = iter(uuids) + + t1 = Table( + "data", + metadata, + Column( + "id", + Uuid(), + default=functools.partial(next, _some_uuids), + primary_key=True, + insert_sentinel=bool(add_sentinel_flag_to_col), + ), + Column("data", String(50)), + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=True, + ) + + if sort_by_parameter_order: + collection_cls = list + else: + collection_cls = set + + metadata.create_all(connection) + + if sort_by_parameter_order: + kw = {"sort_by_parameter_order": True} + else: + kw = {} + + if return_type.include_sentinel: + stmt = t1.insert().returning( + t1.c.id, t1.c.data, t1.c.has_server_default, **kw + ) + elif return_type.default_only: + stmt = t1.insert().returning( + t1.c.data, t1.c.has_server_default, **kw + ) + elif return_type.return_defaults: + stmt = t1.insert().return_defaults(**kw) + + else: + return_type.fail() + + r = connection.execute( + stmt, + [{"data": f"d{i}"} for i in range(1, 6)], + ) + + if return_type.include_sentinel: + eq_(r.keys(), ["id", "data", "has_server_default"]) + eq_( + collection_cls(r), + collection_cls( + [ + (uuids[i], f"d{i+1}", "some_server_default") + for i in range(5) + ] + ), + ) + elif return_type.default_only: + eq_(r.keys(), ["data", "has_server_default"]) + eq_( + collection_cls(r), + collection_cls( + [ + ( + f"d{i+1}", + "some_server_default", + ) + for i in range(5) + ] + ), + ) + elif return_type.return_defaults: + eq_(r.keys(), ["has_server_default"]) + eq_(r.inserted_primary_key_rows, [(uuids[i],) for i in range(5)]) + eq_( + r.returned_defaults_rows, + [ + ("some_server_default",), + ("some_server_default",), + ("some_server_default",), + ("some_server_default",), + ("some_server_default",), + ], + ) + eq_(r.all(), []) + else: + return_type.fail() + + def test_client_composite_pk( + self, + metadata, + connection, + randomize_returning, + sort_by_parameter_order, + warn_for_downgrades, + ): + uuids = [uuid.uuid4() for i in range(10)] + + t1 = Table( + "data", + metadata, + Column( + "id1", + Uuid(), + default=functools.partial(next, iter(uuids)), + primary_key=True, + ), + Column( + "id2", + # note this is testing that plain populated PK cols + # also qualify as sentinels since they have to be there + String(30), + primary_key=True, + ), + Column("data", String(50)), + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + result = connection.execute( + insert(t1).returning( + t1.c.id1, + t1.c.id2, + t1.c.data, + t1.c.has_server_default, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), + [{"id2": f"id{i}", "data": f"d{i}"} for i in range(10)], + ) + + if sort_by_parameter_order: + coll = list + else: + coll = set + + eq_( + coll(result), + coll( + [ + (uuids[i], f"id{i}", f"d{i}", "some_server_default") + for i in range(10) + ] + ), + ) + + @testing.variation("add_sentinel", [True, False]) + @testing.variation( + "set_identity", [(True, testing.requires.identity_columns), False] + ) + def test_no_pk( + self, + metadata, + connection, + randomize_returning, + sort_by_parameter_order, + warn_for_downgrades, + add_sentinel, + set_identity, + ): + if set_identity: + id_col = Column("id", Integer(), Identity()) + else: + id_col = Column("id", Integer()) + + uuids = [uuid.uuid4() for i in range(10)] + + sentinel_col = Column( + "unique_id", + Uuid, + default=functools.partial(next, iter(uuids)), + insert_sentinel=bool(add_sentinel), + ) + t1 = Table( + "nopk", + metadata, + id_col, + Column("data", String(50)), + sentinel_col, + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = insert(t1).returning( + t1.c.id, + t1.c.data, + t1.c.has_server_default, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + if not set_identity: + data = [{"id": i + 1, "data": f"d{i}"} for i in range(10)] + else: + data = [{"data": f"d{i}"} for i in range(10)] + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=sort_by_parameter_order, + separate_sentinel=add_sentinel, + ): + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + coll = list + else: + coll = set + + eq_( + coll(result), + coll([(i + 1, f"d{i}", "some_server_default") for i in range(10)]), + ) + + @testing.variation("add_sentinel_to_col", [True, False]) + @testing.variation( + "set_autoincrement", [True, (False, testing.skip_if("mariadb"))] + ) + def test_hybrid_client_composite_pk( + self, + metadata, + connection, + randomize_returning, + sort_by_parameter_order, + warn_for_downgrades, + add_sentinel_to_col, + set_autoincrement, + ): + """test a pk that is part server generated part client generated. + + The server generated col by itself can be the sentinel. if it's + part of the PK and is autoincrement=True then it is automatically + used as such. if not, there's a graceful downgrade. + + """ + + t1 = Table( + "data", + metadata, + Column( + "idint", + Integer, + Identity(), + autoincrement=True if set_autoincrement else "auto", + primary_key=True, + insert_sentinel=bool(add_sentinel_to_col), + ), + Column( + "idstr", + String(30), + primary_key=True, + ), + Column("data", String(50)), + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + + no_autoincrement = ( + not testing.requires.supports_autoincrement_w_composite_pk.enabled # noqa: E501 + ) + if set_autoincrement and no_autoincrement: + with expect_raises_message( + exc.CompileError, + r".*SQLite does not support autoincrement for " + "composite primary keys", + ): + metadata.create_all(connection) + return + else: + + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = insert(t1).returning( + t1.c.idint, + t1.c.idstr, + t1.c.data, + t1.c.has_server_default, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + + if no_autoincrement: + data = [ + {"idint": i + 1, "idstr": f"id{i}", "data": f"d{i}"} + for i in range(10) + ] + else: + data = [{"idstr": f"id{i}", "data": f"d{i}"} for i in range(10)] + + if ( + testing.db.dialect.use_insertmanyvalues + and add_sentinel_to_col + and sort_by_parameter_order + and not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT + ) + ): + with expect_raises_message( + exc.InvalidRequestError, + "Column data.idint can't be explicitly marked as a sentinel " + "column when using the sqlite dialect", + ): + result = connection.execute(stmt, data) + return + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=sort_by_parameter_order, + separate_sentinel=not set_autoincrement and add_sentinel_to_col, + server_autoincrement=set_autoincrement, + ): + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + coll = list + else: + coll = set + + eq_( + coll(result), + coll( + [ + (i + 1, f"id{i}", f"d{i}", "some_server_default") + for i in range(10) + ] + ), + ) + + @testing.variation("composite_pk", [True, False]) + @testing.only_on( + [ + "+psycopg", + "+psycopg2", + "+pysqlite", + "+mysqlclient", + "+cx_oracle", + "+oracledb", + ] + ) + def test_failure_mode_if_i_dont_send_value( + self, metadata, connection, sort_by_parameter_order, composite_pk + ): + """test that we get a regular integrity error if a required + PK value was not sent, that is, imv does not get in the way + + """ + t1 = Table( + "data", + metadata, + Column("id", String(30), primary_key=True), + Column("data", String(50)), + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + if composite_pk: + t1.append_column(Column("uid", Uuid(), default=uuid.uuid4)) + + metadata.create_all(connection) + + with expect_warnings( + r".*but has no Python-side or server-side default ", + raise_on_any_unexpected=True, + ): + with expect_raises(exc.IntegrityError): + connection.execute( + insert(t1).returning( + t1.c.id, + t1.c.data, + t1.c.has_server_default, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), + [{"data": f"d{i}"} for i in range(10)], + ) + + @testing.variation("add_sentinel_flag_to_col", [True, False]) + @testing.variation( + "return_type", ["include_sentinel", "default_only", "return_defaults"] + ) + @testing.variation( + "sentinel_type", + [ + ("autoincrement", testing.requires.autoincrement_without_sequence), + "identity", + "sequence", + ], + ) + def test_implicit_autoincrement_sentinel( + self, + metadata, + connection, + return_type: testing.Variation, + sort_by_parameter_order, + randomize_returning, + sentinel_type, + add_sentinel_flag_to_col, + ): + + if sentinel_type.identity: + sentinel_args = [Identity()] + elif sentinel_type.sequence: + sentinel_args = [Sequence("id_seq", start=1)] + else: + sentinel_args = [] + t1 = Table( + "data", + metadata, + Column( + "id", + Integer, + *sentinel_args, + primary_key=True, + insert_sentinel=bool(add_sentinel_flag_to_col), + ), + Column("data", String(50)), + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=False, + ) + + if sort_by_parameter_order: + collection_cls = list + else: + collection_cls = set + + metadata.create_all(connection) + + if sort_by_parameter_order: + kw = {"sort_by_parameter_order": True} + else: + kw = {} + + if return_type.include_sentinel: + stmt = t1.insert().returning( + t1.c.id, t1.c.data, t1.c.has_server_default, **kw + ) + elif return_type.default_only: + stmt = t1.insert().returning( + t1.c.data, t1.c.has_server_default, **kw + ) + elif return_type.return_defaults: + stmt = t1.insert().return_defaults(**kw) + + else: + return_type.fail() + + if ( + testing.db.dialect.use_insertmanyvalues + and add_sentinel_flag_to_col + and sort_by_parameter_order + and ( + not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT + ) + or ( + # currently a SQL Server case, we dont yet render a + # syntax for SQL Server sequence w/ deterministic + # ordering. The INSERT..SELECT could be restructured + # further to support this at a later time however + # sequences with SQL Server are very unusual. + sentinel_type.sequence + and not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.SEQUENCE + ) + ) + ) + ): + with expect_raises_message( + exc.InvalidRequestError, + "Column data.id can't be explicitly marked as a " + f"sentinel column when using the {testing.db.dialect.name} " + "dialect", + ): + connection.execute( + stmt, + [{"data": f"d{i}"} for i in range(1, 6)], + ) + return + else: + r = connection.execute( + stmt, + [{"data": f"d{i}"} for i in range(1, 6)], + ) + + if return_type.include_sentinel: + eq_(r.keys(), ["id", "data", "has_server_default"]) + eq_( + collection_cls(r), + collection_cls( + [(i, f"d{i}", "some_server_default") for i in range(1, 6)] + ), + ) + elif return_type.default_only: + eq_(r.keys(), ["data", "has_server_default"]) + eq_( + collection_cls(r), + collection_cls( + [(f"d{i}", "some_server_default") for i in range(1, 6)] + ), + ) + elif return_type.return_defaults: + eq_(r.keys(), ["id", "has_server_default"]) + eq_( + collection_cls(r.inserted_primary_key_rows), + collection_cls([(i + 1,) for i in range(5)]), + ) + eq_( + collection_cls(r.returned_defaults_rows), + collection_cls( + [ + ( + 1, + "some_server_default", + ), + ( + 2, + "some_server_default", + ), + ( + 3, + "some_server_default", + ), + ( + 4, + "some_server_default", + ), + ( + 5, + "some_server_default", + ), + ] + ), + ) + eq_(r.all(), []) + else: + return_type.fail() + + @testing.variation("pk_type", ["serverside", "clientside"]) + @testing.variation( + "sentinel_type", + [ + "use_pk", + ("use_pk_explicit", testing.skip_if("sqlite")), + "separate_uuid", + "separate_sentinel", + ], + ) + @testing.requires.provisioned_upsert + def test_upsert_downgrades( + self, + metadata, + connection, + pk_type: testing.Variation, + sort_by_parameter_order, + randomize_returning, + sentinel_type, + warn_for_downgrades, + ): + if pk_type.serverside: + pk_col = Column( + "id", + Integer(), + primary_key=True, + insert_sentinel=bool(sentinel_type.use_pk_explicit), + ) + elif pk_type.clientside: + pk_col = Column( + "id", + Uuid(), + default=uuid.uuid4, + primary_key=True, + insert_sentinel=bool(sentinel_type.use_pk_explicit), + ) + else: + pk_type.fail() + + if sentinel_type.separate_uuid: + extra_col = Column( + "sent_col", + Uuid(), + default=uuid.uuid4, + insert_sentinel=True, + nullable=False, + ) + elif sentinel_type.separate_sentinel: + extra_col = insert_sentinel("sent_col") + else: + extra_col = Column("sent_col", Integer) + + t1 = Table( + "upsert_table", + metadata, + pk_col, + Column("data", String(50)), + extra_col, + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + metadata.create_all(connection) + + result = connection.execute( + insert(t1).returning( + t1.c.id, t1.c.data, sort_by_parameter_order=True + ), + [{"data": "d1"}, {"data": "d2"}], + ) + d1d2 = list(result) + + if pk_type.serverside: + new_ids = [10, 15, 3] + elif pk_type.clientside: + new_ids = [uuid.uuid4() for i in range(3)] + else: + pk_type.fail() + + upsert_data = [ + {"id": d1d2[0][0], "data": "d1 new"}, + {"id": new_ids[0], "data": "d10"}, + {"id": new_ids[1], "data": "d15"}, + {"id": d1d2[1][0], "data": "d2 new"}, + {"id": new_ids[2], "data": "d3"}, + ] + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = provision.upsert( + config, + t1, + (t1.c.data, t1.c.has_server_default), + set_lambda=lambda inserted: { + "data": inserted.data + " upserted", + }, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=sort_by_parameter_order, + ): + result = connection.execute(stmt, upsert_data) + + expected_data = [ + ("d1 new upserted", "some_server_default"), + ("d10", "some_server_default"), + ("d15", "some_server_default"), + ("d2 new upserted", "some_server_default"), + ("d3", "some_server_default"), + ] + if sort_by_parameter_order: + coll = list + else: + coll = set + + eq_(coll(result), coll(expected_data)) + + def test_auto_downgraded_non_mvi_dialect( + self, + metadata, + testing_engine, + randomize_returning, + warn_for_downgrades, + sort_by_parameter_order, + ): + """Accommodate the case of the dialect that supports RETURNING, but + does not support "multi values INSERT" syntax. + + These dialects should still provide insertmanyvalues/returning + support, using downgraded batching. + + For now, we are still keeping this entire thing "opt in" by requiring + that use_insertmanyvalues=True, which means we can't simplify the + ORM by not worrying about dialects where ordering is available or + not. + + However, dialects that use RETURNING, but don't support INSERT VALUES + (..., ..., ...) can set themselves up like this:: + + class MyDialect(DefaultDialect): + use_insertmanyvalues = True + supports_multivalues_insert = False + + This test runs for everyone **including** Oracle, where we + exercise Oracle using "insertmanyvalues" without "multivalues_insert". + + """ + engine = testing_engine() + engine.connect().close() + + engine.dialect.supports_multivalues_insert = False + engine.dialect.use_insertmanyvalues = True + + uuids = [uuid.uuid4() for i in range(10)] + + t1 = Table( + "t1", + metadata, + Column("id", Uuid(), default=functools.partial(next, iter(uuids))), + Column("data", String(50)), + ) + metadata.create_all(engine) + + with engine.connect() as conn: + + fixtures.insertmanyvalues_fixture( + conn, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = insert(t1).returning( + t1.c.id, + t1.c.data, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [{"data": f"d{i}"} for i in range(10)] + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=True, # will warn even if not sorted + connection=conn, + ): + result = conn.execute(stmt, data) + + expected_data = [(uuids[i], f"d{i}") for i in range(10)] + if sort_by_parameter_order: + coll = list + else: + coll = set + + eq_(coll(result), coll(expected_data)) |