summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2023-04-27 00:13:00 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2023-04-27 00:13:00 +0000
commit11535752b94acb41ff684cf8d9c745038addc447 (patch)
treefec970fe35228d33c45280ad558ed9bc251b5208 /test
parentc89c2b3d9a18bd0eb4c8ace50ef875101c9f4b70 (diff)
parent8ec396873c9bbfcc4416e55b5f9d8653554a1df0 (diff)
downloadsqlalchemy-11535752b94acb41ff684cf8d9c745038addc447.tar.gz
Merge "support parameters in all ORM insert modes" into main
Diffstat (limited to 'test')
-rw-r--r--test/orm/dml/test_bulk_statements.py320
-rw-r--r--test/orm/dml/test_update_delete_where.py15
-rw-r--r--test/sql/test_utils.py26
3 files changed, 345 insertions, 16 deletions
diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py
index 84ea7c82c..ab03b251d 100644
--- a/test/orm/dml/test_bulk_statements.py
+++ b/test/orm/dml/test_bulk_statements.py
@@ -7,6 +7,7 @@ from typing import Optional
from typing import Set
import uuid
+from sqlalchemy import bindparam
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import ForeignKey
@@ -14,6 +15,7 @@ from sqlalchemy import func
from sqlalchemy import Identity
from sqlalchemy import insert
from sqlalchemy import inspect
+from sqlalchemy import Integer
from sqlalchemy import literal
from sqlalchemy import literal_column
from sqlalchemy import select
@@ -226,6 +228,310 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
eq_(result.all(), [User(id=1, name="John", age=30)])
+ @testing.variation(
+ "use_returning", [(True, testing.requires.insert_returning), False]
+ )
+ @testing.variation("use_multiparams", [True, False])
+ @testing.variation("bindparam_in_expression", [True, False])
+ @testing.combinations(
+ "auto", "raw", "bulk", "orm", argnames="dml_strategy"
+ )
+ def test_alt_bindparam_names(
+ self,
+ use_returning,
+ decl_base,
+ use_multiparams,
+ dml_strategy,
+ bindparam_in_expression,
+ ):
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+
+ x: Mapped[int]
+ y: Mapped[int]
+
+ decl_base.metadata.create_all(testing.db)
+
+ s = fixture_session()
+
+ if bindparam_in_expression:
+ stmt = insert(A).values(y=literal(3) * (bindparam("q") + 15))
+ else:
+ stmt = insert(A).values(y=bindparam("q"))
+
+ if dml_strategy != "auto":
+ # it really should work with any strategy
+ stmt = stmt.execution_options(dml_strategy=dml_strategy)
+
+ if use_returning:
+ stmt = stmt.returning(A.x, A.y)
+
+ if use_multiparams:
+ if bindparam_in_expression:
+ expected_qs = [60, 69, 81]
+ else:
+ expected_qs = [5, 8, 12]
+
+ result = s.execute(
+ stmt,
+ [
+ {"q": 5, "x": 10},
+ {"q": 8, "x": 11},
+ {"q": 12, "x": 12},
+ ],
+ )
+ else:
+ if bindparam_in_expression:
+ expected_qs = [60]
+ else:
+ expected_qs = [5]
+
+ result = s.execute(stmt, {"q": 5, "x": 10})
+ if use_returning:
+ if use_multiparams:
+ eq_(
+ result.all(),
+ [
+ (10, expected_qs[0]),
+ (11, expected_qs[1]),
+ (12, expected_qs[2]),
+ ],
+ )
+ else:
+ eq_(result.first(), (10, expected_qs[0]))
+
+
+class UpdateStmtTest(fixtures.TestBase):
+ __backend__ = True
+
+ @testing.variation(
+ "returning_executemany",
+ [
+ ("returning", testing.requires.update_returning),
+ "executemany",
+ "plain",
+ ],
+ )
+ @testing.variation("bindparam_in_expression", [True, False])
+ # TODO: setting "bulk" here is all over the place as well, UPDATE is not
+ # too settled
+ @testing.combinations("auto", "orm", argnames="dml_strategy")
+ @testing.combinations(
+ "evaluate", "fetch", None, argnames="synchronize_strategy"
+ )
+ def test_alt_bindparam_names(
+ self,
+ decl_base,
+ returning_executemany,
+ dml_strategy,
+ bindparam_in_expression,
+ synchronize_strategy,
+ ):
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(
+ primary_key=True, autoincrement=False
+ )
+
+ x: Mapped[int]
+ y: Mapped[int]
+
+ decl_base.metadata.create_all(testing.db)
+
+ s = fixture_session()
+
+ s.add_all(
+ [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)],
+ )
+ s.commit()
+
+ if bindparam_in_expression:
+ stmt = (
+ update(A)
+ .values(y=literal(3) * (bindparam("q") + 15))
+ .where(A.id == bindparam("b_id"))
+ )
+ else:
+ stmt = (
+ update(A)
+ .values(y=bindparam("q"))
+ .where(A.id == bindparam("b_id"))
+ )
+
+ if dml_strategy != "auto":
+ # it really should work with any strategy
+ stmt = stmt.execution_options(dml_strategy=dml_strategy)
+
+ if returning_executemany.returning:
+ stmt = stmt.returning(A.x, A.y)
+
+ if synchronize_strategy in (None, "evaluate", "fetch"):
+ stmt = stmt.execution_options(
+ synchronize_session=synchronize_strategy
+ )
+
+ if returning_executemany.executemany:
+ if bindparam_in_expression:
+ expected_qs = [60, 69, 81]
+ else:
+ expected_qs = [5, 8, 12]
+
+ if dml_strategy != "orm":
+ params = [
+ {"id": 1, "b_id": 1, "q": 5, "x": 10},
+ {"id": 2, "b_id": 2, "q": 8, "x": 11},
+ {"id": 3, "b_id": 3, "q": 12, "x": 12},
+ ]
+ else:
+ params = [
+ {"b_id": 1, "q": 5, "x": 10},
+ {"b_id": 2, "q": 8, "x": 11},
+ {"b_id": 3, "q": 12, "x": 12},
+ ]
+
+ _expect_raises = None
+
+ if synchronize_strategy == "fetch":
+ if dml_strategy != "orm":
+ _expect_raises = expect_raises_message(
+ exc.InvalidRequestError,
+ r"The 'fetch' synchronization strategy is not "
+ r"available for 'bulk' ORM updates "
+ r"\(i.e. multiple parameter sets\)",
+ )
+ elif not testing.db.dialect.update_executemany_returning:
+ # no backend supports this except Oracle
+ _expect_raises = expect_raises_message(
+ exc.InvalidRequestError,
+ r"For synchronize_session='fetch', can't use multiple "
+ r"parameter sets in ORM mode, which this backend does "
+ r"not support with RETURNING",
+ )
+
+ elif synchronize_strategy == "evaluate" and dml_strategy != "orm":
+ _expect_raises = expect_raises_message(
+ exc.InvalidRequestError,
+ "bulk synchronize of persistent objects not supported",
+ )
+
+ if _expect_raises:
+ with _expect_raises:
+ result = s.execute(stmt, params)
+ return
+
+ result = s.execute(stmt, params)
+ else:
+ if bindparam_in_expression:
+ expected_qs = [60]
+ else:
+ expected_qs = [5]
+
+ result = s.execute(stmt, {"b_id": 1, "q": 5, "x": 10})
+
+ if returning_executemany.returning:
+ eq_(result.first(), (10, expected_qs[0]))
+
+ elif returning_executemany.executemany:
+ eq_(
+ s.execute(select(A.x, A.y).order_by(A.id)).all(),
+ [
+ (10, expected_qs[0]),
+ (11, expected_qs[1]),
+ (12, expected_qs[2]),
+ ],
+ )
+
+ def test_bulk_update_w_where_one(self, decl_base):
+ """test use case in #9595"""
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(
+ primary_key=True, autoincrement=False
+ )
+
+ x: Mapped[int]
+ y: Mapped[int]
+
+ decl_base.metadata.create_all(testing.db)
+
+ s = fixture_session()
+
+ s.add_all(
+ [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)],
+ )
+ s.commit()
+
+ stmt = (
+ update(A)
+ .where(A.x > 1)
+ .execution_options(synchronize_session=None)
+ )
+
+ s.execute(
+ stmt,
+ [
+ {"id": 1, "x": 3, "y": 8},
+ {"id": 2, "x": 5, "y": 9},
+ {"id": 3, "x": 12, "y": 15},
+ ],
+ )
+
+ eq_(
+ s.execute(select(A.id, A.x, A.y).order_by(A.id)).all(),
+ [(1, 1, 1), (2, 5, 9), (3, 12, 15)],
+ )
+
+ def test_bulk_update_w_where_two(self, decl_base):
+ class User(decl_base):
+ __tablename__ = "user"
+
+ id: Mapped[int] = mapped_column(
+ primary_key=True, autoincrement=False
+ )
+ name: Mapped[str]
+ age: Mapped[int]
+
+ decl_base.metadata.create_all(testing.db)
+
+ sess = fixture_session()
+ sess.execute(
+ insert(User),
+ [
+ dict(id=1, name="john", age=25),
+ dict(id=2, name="jack", age=47),
+ dict(id=3, name="jill", age=29),
+ dict(id=4, name="jane", age=37),
+ ],
+ )
+
+ sess.execute(
+ update(User)
+ .where(User.age > bindparam("gtage"))
+ .values(age=bindparam("dest_age"))
+ .execution_options(synchronize_session=None),
+ [
+ {"id": 1, "gtage": 28, "dest_age": 40},
+ {"id": 2, "gtage": 20, "dest_age": 45},
+ ],
+ )
+
+ eq_(
+ sess.execute(
+ select(User.id, User.name, User.age).order_by(User.id)
+ ).all(),
+ [
+ (1, "john", 25),
+ (2, "jack", 45),
+ (3, "jill", 29),
+ (4, "jane", 37),
+ ],
+ )
+
class BulkDMLReturningInhTest:
use_sentinel = False
@@ -965,7 +1271,10 @@ class BulkDMLReturningInhTest:
eq_(coll(ids), coll(actual_ids))
- @testing.variation("insert_strategy", ["orm", "bulk", "bulk_ordered"])
+ @testing.variation(
+ "insert_strategy",
+ ["orm", "bulk", "bulk_ordered", "bulk_w_embedded_bindparam"],
+ )
@testing.requires.provisioned_upsert
def test_base_class_upsert(self, insert_strategy):
"""upsert is really tricky. if you dont have any data updated,
@@ -1036,6 +1345,15 @@ class BulkDMLReturningInhTest:
sort_by_parameter_order=insert_strategy.bulk_ordered
):
result = s.scalars(stmt, upsert_data)
+ elif insert_strategy.bulk_w_embedded_bindparam:
+ # test related to #9583, specific user case in
+ # https://github.com/sqlalchemy/sqlalchemy/discussions/9581#discussioncomment-5504077 # noqa: E501
+ stmt = stmt.values(
+ y=select(bindparam("qq1", type_=Integer)).scalar_subquery()
+ )
+ for d in upsert_data:
+ d["qq1"] = d.pop("y")
+ result = s.scalars(stmt, upsert_data)
else:
insert_strategy.fail()
diff --git a/test/orm/dml/test_update_delete_where.py b/test/orm/dml/test_update_delete_where.py
index 19e557fd9..e45d92659 100644
--- a/test/orm/dml/test_update_delete_where.py
+++ b/test/orm/dml/test_update_delete_where.py
@@ -1,4 +1,3 @@
-from sqlalchemy import bindparam
from sqlalchemy import Boolean
from sqlalchemy import case
from sqlalchemy import column
@@ -810,20 +809,6 @@ class UpdateDeleteTest(fixtures.MappedTest):
eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane])
- def test_update_multirow_not_supported(self):
- User = self.classes.User
-
- sess = fixture_session()
-
- with expect_raises_message(
- exc.InvalidRequestError,
- "WHERE clause with bulk ORM UPDATE not supported " "right now.",
- ):
- sess.execute(
- update(User).where(User.id == bindparam("id")),
- [{"id": 1, "age": 27}, {"id": 2, "age": 37}],
- )
-
def test_delete_bulk_not_supported(self):
User = self.classes.User
diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py
index 61777def5..615995c73 100644
--- a/test/sql/test_utils.py
+++ b/test/sql/test_utils.py
@@ -1,5 +1,6 @@
from itertools import zip_longest
+from sqlalchemy import bindparam
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
@@ -7,6 +8,7 @@ from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
+from sqlalchemy import TypeDecorator
from sqlalchemy.sql import base as sql_base
from sqlalchemy.sql import coercions
from sqlalchemy.sql import column
@@ -18,6 +20,8 @@ from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_not_none
class MiscTest(fixtures.TestBase):
@@ -41,6 +45,28 @@ class MiscTest(fixtures.TestBase):
eq_(set(sql_util.find_tables(subset_select)), {common})
+ @testing.variation("has_cache_key", [True, False])
+ def test_get_embedded_bindparams(self, has_cache_key):
+ bp = bindparam("x")
+
+ if not has_cache_key:
+
+ class NotCacheable(TypeDecorator):
+ impl = String
+ cache_ok = False
+
+ stmt = select(column("q", NotCacheable())).where(column("y") == bp)
+
+ else:
+ stmt = select(column("q")).where(column("y") == bp)
+
+ eq_(stmt._get_embedded_bindparams(), [bp])
+
+ if not has_cache_key:
+ is_(stmt._generate_cache_key(), None)
+ else:
+ is_not_none(stmt._generate_cache_key())
+
def test_find_tables_aliases(self):
metadata = MetaData()
common = Table(