diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2023-04-27 00:13:00 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2023-04-27 00:13:00 +0000 |
| commit | 11535752b94acb41ff684cf8d9c745038addc447 (patch) | |
| tree | fec970fe35228d33c45280ad558ed9bc251b5208 /test | |
| parent | c89c2b3d9a18bd0eb4c8ace50ef875101c9f4b70 (diff) | |
| parent | 8ec396873c9bbfcc4416e55b5f9d8653554a1df0 (diff) | |
| download | sqlalchemy-11535752b94acb41ff684cf8d9c745038addc447.tar.gz | |
Merge "support parameters in all ORM insert modes" into main
Diffstat (limited to 'test')
| -rw-r--r-- | test/orm/dml/test_bulk_statements.py | 320 | ||||
| -rw-r--r-- | test/orm/dml/test_update_delete_where.py | 15 | ||||
| -rw-r--r-- | test/sql/test_utils.py | 26 |
3 files changed, 345 insertions, 16 deletions
diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py index 84ea7c82c..ab03b251d 100644 --- a/test/orm/dml/test_bulk_statements.py +++ b/test/orm/dml/test_bulk_statements.py @@ -7,6 +7,7 @@ from typing import Optional from typing import Set import uuid +from sqlalchemy import bindparam from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey @@ -14,6 +15,7 @@ from sqlalchemy import func from sqlalchemy import Identity from sqlalchemy import insert from sqlalchemy import inspect +from sqlalchemy import Integer from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import select @@ -226,6 +228,310 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): eq_(result.all(), [User(id=1, name="John", age=30)]) + @testing.variation( + "use_returning", [(True, testing.requires.insert_returning), False] + ) + @testing.variation("use_multiparams", [True, False]) + @testing.variation("bindparam_in_expression", [True, False]) + @testing.combinations( + "auto", "raw", "bulk", "orm", argnames="dml_strategy" + ) + def test_alt_bindparam_names( + self, + use_returning, + decl_base, + use_multiparams, + dml_strategy, + bindparam_in_expression, + ): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + + x: Mapped[int] + y: Mapped[int] + + decl_base.metadata.create_all(testing.db) + + s = fixture_session() + + if bindparam_in_expression: + stmt = insert(A).values(y=literal(3) * (bindparam("q") + 15)) + else: + stmt = insert(A).values(y=bindparam("q")) + + if dml_strategy != "auto": + # it really should work with any strategy + stmt = stmt.execution_options(dml_strategy=dml_strategy) + + if use_returning: + stmt = stmt.returning(A.x, A.y) + + if use_multiparams: + if bindparam_in_expression: + expected_qs = [60, 69, 81] + else: + expected_qs = [5, 8, 12] + + result = s.execute( + stmt, + [ + {"q": 5, "x": 10}, + {"q": 8, "x": 11}, + {"q": 12, "x": 12}, + ], + ) + else: + if bindparam_in_expression: + expected_qs = [60] + else: + expected_qs = [5] + + result = s.execute(stmt, {"q": 5, "x": 10}) + if use_returning: + if use_multiparams: + eq_( + result.all(), + [ + (10, expected_qs[0]), + (11, expected_qs[1]), + (12, expected_qs[2]), + ], + ) + else: + eq_(result.first(), (10, expected_qs[0])) + + +class UpdateStmtTest(fixtures.TestBase): + __backend__ = True + + @testing.variation( + "returning_executemany", + [ + ("returning", testing.requires.update_returning), + "executemany", + "plain", + ], + ) + @testing.variation("bindparam_in_expression", [True, False]) + # TODO: setting "bulk" here is all over the place as well, UPDATE is not + # too settled + @testing.combinations("auto", "orm", argnames="dml_strategy") + @testing.combinations( + "evaluate", "fetch", None, argnames="synchronize_strategy" + ) + def test_alt_bindparam_names( + self, + decl_base, + returning_executemany, + dml_strategy, + bindparam_in_expression, + synchronize_strategy, + ): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column( + primary_key=True, autoincrement=False + ) + + x: Mapped[int] + y: Mapped[int] + + decl_base.metadata.create_all(testing.db) + + s = fixture_session() + + s.add_all( + [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)], + ) + s.commit() + + if bindparam_in_expression: + stmt = ( + update(A) + .values(y=literal(3) * (bindparam("q") + 15)) + .where(A.id == bindparam("b_id")) + ) + else: + stmt = ( + update(A) + .values(y=bindparam("q")) + .where(A.id == bindparam("b_id")) + ) + + if dml_strategy != "auto": + # it really should work with any strategy + stmt = stmt.execution_options(dml_strategy=dml_strategy) + + if returning_executemany.returning: + stmt = stmt.returning(A.x, A.y) + + if synchronize_strategy in (None, "evaluate", "fetch"): + stmt = stmt.execution_options( + synchronize_session=synchronize_strategy + ) + + if returning_executemany.executemany: + if bindparam_in_expression: + expected_qs = [60, 69, 81] + else: + expected_qs = [5, 8, 12] + + if dml_strategy != "orm": + params = [ + {"id": 1, "b_id": 1, "q": 5, "x": 10}, + {"id": 2, "b_id": 2, "q": 8, "x": 11}, + {"id": 3, "b_id": 3, "q": 12, "x": 12}, + ] + else: + params = [ + {"b_id": 1, "q": 5, "x": 10}, + {"b_id": 2, "q": 8, "x": 11}, + {"b_id": 3, "q": 12, "x": 12}, + ] + + _expect_raises = None + + if synchronize_strategy == "fetch": + if dml_strategy != "orm": + _expect_raises = expect_raises_message( + exc.InvalidRequestError, + r"The 'fetch' synchronization strategy is not " + r"available for 'bulk' ORM updates " + r"\(i.e. multiple parameter sets\)", + ) + elif not testing.db.dialect.update_executemany_returning: + # no backend supports this except Oracle + _expect_raises = expect_raises_message( + exc.InvalidRequestError, + r"For synchronize_session='fetch', can't use multiple " + r"parameter sets in ORM mode, which this backend does " + r"not support with RETURNING", + ) + + elif synchronize_strategy == "evaluate" and dml_strategy != "orm": + _expect_raises = expect_raises_message( + exc.InvalidRequestError, + "bulk synchronize of persistent objects not supported", + ) + + if _expect_raises: + with _expect_raises: + result = s.execute(stmt, params) + return + + result = s.execute(stmt, params) + else: + if bindparam_in_expression: + expected_qs = [60] + else: + expected_qs = [5] + + result = s.execute(stmt, {"b_id": 1, "q": 5, "x": 10}) + + if returning_executemany.returning: + eq_(result.first(), (10, expected_qs[0])) + + elif returning_executemany.executemany: + eq_( + s.execute(select(A.x, A.y).order_by(A.id)).all(), + [ + (10, expected_qs[0]), + (11, expected_qs[1]), + (12, expected_qs[2]), + ], + ) + + def test_bulk_update_w_where_one(self, decl_base): + """test use case in #9595""" + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column( + primary_key=True, autoincrement=False + ) + + x: Mapped[int] + y: Mapped[int] + + decl_base.metadata.create_all(testing.db) + + s = fixture_session() + + s.add_all( + [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)], + ) + s.commit() + + stmt = ( + update(A) + .where(A.x > 1) + .execution_options(synchronize_session=None) + ) + + s.execute( + stmt, + [ + {"id": 1, "x": 3, "y": 8}, + {"id": 2, "x": 5, "y": 9}, + {"id": 3, "x": 12, "y": 15}, + ], + ) + + eq_( + s.execute(select(A.id, A.x, A.y).order_by(A.id)).all(), + [(1, 1, 1), (2, 5, 9), (3, 12, 15)], + ) + + def test_bulk_update_w_where_two(self, decl_base): + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column( + primary_key=True, autoincrement=False + ) + name: Mapped[str] + age: Mapped[int] + + decl_base.metadata.create_all(testing.db) + + sess = fixture_session() + sess.execute( + insert(User), + [ + dict(id=1, name="john", age=25), + dict(id=2, name="jack", age=47), + dict(id=3, name="jill", age=29), + dict(id=4, name="jane", age=37), + ], + ) + + sess.execute( + update(User) + .where(User.age > bindparam("gtage")) + .values(age=bindparam("dest_age")) + .execution_options(synchronize_session=None), + [ + {"id": 1, "gtage": 28, "dest_age": 40}, + {"id": 2, "gtage": 20, "dest_age": 45}, + ], + ) + + eq_( + sess.execute( + select(User.id, User.name, User.age).order_by(User.id) + ).all(), + [ + (1, "john", 25), + (2, "jack", 45), + (3, "jill", 29), + (4, "jane", 37), + ], + ) + class BulkDMLReturningInhTest: use_sentinel = False @@ -965,7 +1271,10 @@ class BulkDMLReturningInhTest: eq_(coll(ids), coll(actual_ids)) - @testing.variation("insert_strategy", ["orm", "bulk", "bulk_ordered"]) + @testing.variation( + "insert_strategy", + ["orm", "bulk", "bulk_ordered", "bulk_w_embedded_bindparam"], + ) @testing.requires.provisioned_upsert def test_base_class_upsert(self, insert_strategy): """upsert is really tricky. if you dont have any data updated, @@ -1036,6 +1345,15 @@ class BulkDMLReturningInhTest: sort_by_parameter_order=insert_strategy.bulk_ordered ): result = s.scalars(stmt, upsert_data) + elif insert_strategy.bulk_w_embedded_bindparam: + # test related to #9583, specific user case in + # https://github.com/sqlalchemy/sqlalchemy/discussions/9581#discussioncomment-5504077 # noqa: E501 + stmt = stmt.values( + y=select(bindparam("qq1", type_=Integer)).scalar_subquery() + ) + for d in upsert_data: + d["qq1"] = d.pop("y") + result = s.scalars(stmt, upsert_data) else: insert_strategy.fail() diff --git a/test/orm/dml/test_update_delete_where.py b/test/orm/dml/test_update_delete_where.py index 19e557fd9..e45d92659 100644 --- a/test/orm/dml/test_update_delete_where.py +++ b/test/orm/dml/test_update_delete_where.py @@ -1,4 +1,3 @@ -from sqlalchemy import bindparam from sqlalchemy import Boolean from sqlalchemy import case from sqlalchemy import column @@ -810,20 +809,6 @@ class UpdateDeleteTest(fixtures.MappedTest): eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane]) - def test_update_multirow_not_supported(self): - User = self.classes.User - - sess = fixture_session() - - with expect_raises_message( - exc.InvalidRequestError, - "WHERE clause with bulk ORM UPDATE not supported " "right now.", - ): - sess.execute( - update(User).where(User.id == bindparam("id")), - [{"id": 1, "age": 27}, {"id": 2, "age": 37}], - ) - def test_delete_bulk_not_supported(self): User = self.classes.User diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py index 61777def5..615995c73 100644 --- a/test/sql/test_utils.py +++ b/test/sql/test_utils.py @@ -1,5 +1,6 @@ from itertools import zip_longest +from sqlalchemy import bindparam from sqlalchemy import Column from sqlalchemy import Integer from sqlalchemy import MetaData @@ -7,6 +8,7 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing +from sqlalchemy import TypeDecorator from sqlalchemy.sql import base as sql_base from sqlalchemy.sql import coercions from sqlalchemy.sql import column @@ -18,6 +20,8 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_not_none class MiscTest(fixtures.TestBase): @@ -41,6 +45,28 @@ class MiscTest(fixtures.TestBase): eq_(set(sql_util.find_tables(subset_select)), {common}) + @testing.variation("has_cache_key", [True, False]) + def test_get_embedded_bindparams(self, has_cache_key): + bp = bindparam("x") + + if not has_cache_key: + + class NotCacheable(TypeDecorator): + impl = String + cache_ok = False + + stmt = select(column("q", NotCacheable())).where(column("y") == bp) + + else: + stmt = select(column("q")).where(column("y") == bp) + + eq_(stmt._get_embedded_bindparams(), [bp]) + + if not has_cache_key: + is_(stmt._generate_cache_key(), None) + else: + is_not_none(stmt._generate_cache_key()) + def test_find_tables_aliases(self): metadata = MetaData() common = Table( |
