diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-10-15 15:20:21 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-10-16 08:47:47 -0400 |
commit | 2b966de4196c8271934769337780f7d504d431cf (patch) | |
tree | 608cf4c6400faf6dccefbaefbcdd2e0db1e9bdae /test/sql/test_insert_exec.py | |
parent | e8da50ce0f0474bc89cee15603931760cb6c55ce (diff) | |
download | sqlalchemy-2b966de4196c8271934769337780f7d504d431cf.tar.gz |
accommodate arbitrary embedded params in insertmanyvalues
Fixed bug in new "insertmanyvalues" feature where INSERT that included a
subquery with :func:`_sql.bindparam` inside of it would fail to render
correctly in "insertmanyvalues" format. This affected psycopg2 most
directly as "insertmanyvalues" is used unconditionally with this driver.
Fixes: #8639
Change-Id: I67903fa86afe208899d4f23f940e0727d1be2ce3
Diffstat (limited to 'test/sql/test_insert_exec.py')
-rw-r--r-- | test/sql/test_insert_exec.py | 92 |
1 files changed, 92 insertions, 0 deletions
diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index 4ce093156..429ebf163 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -1,6 +1,7 @@ import itertools from sqlalchemy import and_ +from sqlalchemy import bindparam from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey @@ -8,6 +9,7 @@ from sqlalchemy import func from sqlalchemy import INT from sqlalchemy import Integer from sqlalchemy import literal +from sqlalchemy import select from sqlalchemy import Sequence from sqlalchemy import sql from sqlalchemy import String @@ -741,6 +743,14 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): Column("\u6e2c\u8a66", Integer), ) + Table( + "extra_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x_value", String(50)), + Column("y_value", String(50)), + ) + def test_insert_unicode_keys(self, connection): table = self.tables["Unitéble2"] @@ -807,6 +817,88 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): eq_(result.inserted_primary_key_rows, [(1,), (2,), (3,)]) + @testing.combinations(True, False, argnames="use_returning") + @testing.combinations(1, 2, argnames="num_embedded_params") + @testing.combinations(True, False, argnames="use_whereclause") + @testing.crashes( + "+mariadbconnector", + "returning crashes, regular executemany malfunctions", + ) + def test_insert_w_bindparam_in_subq( + self, connection, use_returning, num_embedded_params, use_whereclause + ): + """test #8639""" + + t = self.tables.data + extra = self.tables.extra_table + + conn = connection + connection.execute( + extra.insert(), + [ + {"x_value": "p1", "y_value": "yv1"}, + {"x_value": "p2", "y_value": "yv2"}, + {"x_value": "p1_p1", "y_value": "yv3"}, + {"x_value": "p2_p2", "y_value": "yv4"}, + ], + ) + + if num_embedded_params == 1: + if use_whereclause: + scalar_subq = select(bindparam("paramname")).scalar_subquery() + params = [ + {"paramname": "p1_p1", "y": "y1"}, + {"paramname": "p2_p2", "y": "y2"}, + ] + else: + scalar_subq = ( + select(extra.c.x_value) + .where(extra.c.y_value == bindparam("y_value")) + .scalar_subquery() + ) + params = [ + {"y_value": "yv3", "y": "y1"}, + {"y_value": "yv4", "y": "y2"}, + ] + + elif num_embedded_params == 2: + if use_whereclause: + scalar_subq = ( + select( + bindparam("paramname1", type_=String) + extra.c.x_value + ) + .where(extra.c.y_value == bindparam("y_value")) + .scalar_subquery() + ) + params = [ + {"paramname1": "p1_", "y_value": "yv1", "y": "y1"}, + {"paramname1": "p2_", "y_value": "yv2", "y": "y2"}, + ] + else: + scalar_subq = select( + bindparam("paramname1", type_=String) + + bindparam("paramname2", type_=String) + ).scalar_subquery() + params = [ + {"paramname1": "p1_", "paramname2": "p1", "y": "y1"}, + {"paramname1": "p2_", "paramname2": "p2", "y": "y2"}, + ] + else: + assert False + + stmt = t.insert().values(x=scalar_subq) + if use_returning: + stmt = stmt.returning(t.c["x", "y"]) + + result = conn.execute(stmt, params) + + if use_returning: + eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")]) + + result = conn.execute(select(t.c["x", "y"])) + + eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")]) + def test_insert_returning_defaults(self, connection): t = self.tables.data |