summaryrefslogtreecommitdiff
path: root/test/sql/test_insert_exec.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-10-15 15:20:21 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-10-16 08:47:47 -0400
commit2b966de4196c8271934769337780f7d504d431cf (patch)
tree608cf4c6400faf6dccefbaefbcdd2e0db1e9bdae /test/sql/test_insert_exec.py
parente8da50ce0f0474bc89cee15603931760cb6c55ce (diff)
downloadsqlalchemy-2b966de4196c8271934769337780f7d504d431cf.tar.gz
accommodate arbitrary embedded params in insertmanyvalues
Fixed bug in new "insertmanyvalues" feature where INSERT that included a subquery with :func:`_sql.bindparam` inside of it would fail to render correctly in "insertmanyvalues" format. This affected psycopg2 most directly as "insertmanyvalues" is used unconditionally with this driver. Fixes: #8639 Change-Id: I67903fa86afe208899d4f23f940e0727d1be2ce3
Diffstat (limited to 'test/sql/test_insert_exec.py')
-rw-r--r--test/sql/test_insert_exec.py92
1 files changed, 92 insertions, 0 deletions
diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py
index 4ce093156..429ebf163 100644
--- a/test/sql/test_insert_exec.py
+++ b/test/sql/test_insert_exec.py
@@ -1,6 +1,7 @@
import itertools
from sqlalchemy import and_
+from sqlalchemy import bindparam
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import ForeignKey
@@ -8,6 +9,7 @@ from sqlalchemy import func
from sqlalchemy import INT
from sqlalchemy import Integer
from sqlalchemy import literal
+from sqlalchemy import select
from sqlalchemy import Sequence
from sqlalchemy import sql
from sqlalchemy import String
@@ -741,6 +743,14 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
Column("\u6e2c\u8a66", Integer),
)
+ Table(
+ "extra_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x_value", String(50)),
+ Column("y_value", String(50)),
+ )
+
def test_insert_unicode_keys(self, connection):
table = self.tables["Unitéble2"]
@@ -807,6 +817,88 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
eq_(result.inserted_primary_key_rows, [(1,), (2,), (3,)])
+ @testing.combinations(True, False, argnames="use_returning")
+ @testing.combinations(1, 2, argnames="num_embedded_params")
+ @testing.combinations(True, False, argnames="use_whereclause")
+ @testing.crashes(
+ "+mariadbconnector",
+ "returning crashes, regular executemany malfunctions",
+ )
+ def test_insert_w_bindparam_in_subq(
+ self, connection, use_returning, num_embedded_params, use_whereclause
+ ):
+ """test #8639"""
+
+ t = self.tables.data
+ extra = self.tables.extra_table
+
+ conn = connection
+ connection.execute(
+ extra.insert(),
+ [
+ {"x_value": "p1", "y_value": "yv1"},
+ {"x_value": "p2", "y_value": "yv2"},
+ {"x_value": "p1_p1", "y_value": "yv3"},
+ {"x_value": "p2_p2", "y_value": "yv4"},
+ ],
+ )
+
+ if num_embedded_params == 1:
+ if use_whereclause:
+ scalar_subq = select(bindparam("paramname")).scalar_subquery()
+ params = [
+ {"paramname": "p1_p1", "y": "y1"},
+ {"paramname": "p2_p2", "y": "y2"},
+ ]
+ else:
+ scalar_subq = (
+ select(extra.c.x_value)
+ .where(extra.c.y_value == bindparam("y_value"))
+ .scalar_subquery()
+ )
+ params = [
+ {"y_value": "yv3", "y": "y1"},
+ {"y_value": "yv4", "y": "y2"},
+ ]
+
+ elif num_embedded_params == 2:
+ if use_whereclause:
+ scalar_subq = (
+ select(
+ bindparam("paramname1", type_=String) + extra.c.x_value
+ )
+ .where(extra.c.y_value == bindparam("y_value"))
+ .scalar_subquery()
+ )
+ params = [
+ {"paramname1": "p1_", "y_value": "yv1", "y": "y1"},
+ {"paramname1": "p2_", "y_value": "yv2", "y": "y2"},
+ ]
+ else:
+ scalar_subq = select(
+ bindparam("paramname1", type_=String)
+ + bindparam("paramname2", type_=String)
+ ).scalar_subquery()
+ params = [
+ {"paramname1": "p1_", "paramname2": "p1", "y": "y1"},
+ {"paramname1": "p2_", "paramname2": "p2", "y": "y2"},
+ ]
+ else:
+ assert False
+
+ stmt = t.insert().values(x=scalar_subq)
+ if use_returning:
+ stmt = stmt.returning(t.c["x", "y"])
+
+ result = conn.execute(stmt, params)
+
+ if use_returning:
+ eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")])
+
+ result = conn.execute(select(t.c["x", "y"]))
+
+ eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")])
+
def test_insert_returning_defaults(self, connection):
t = self.tables.data