summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-07-21 11:18:01 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2021-07-21 13:57:22 -0400
commita34a4af8a80f4edd12b022753b69065025818e20 (patch)
tree379a4060304439f46a6515b9b8c6cd74553c477e
parente7119aea7870f0322e78d3a2cb28337b1640f0c2 (diff)
downloadsqlalchemy-a34a4af8a80f4edd12b022753b69065025818e20.tar.gz
implement cache key for return_defaults token
Fixed critical caching issue where the ORM's persistence feature using INSERT..RETURNING would cache an incorrect query when mixing the "bulk save" and standard "flush" forms of INSERT. Fixes: #6793 Change-Id: Ifeb61c1226d3fa6d5e1c2e29b6f5ff77a27d6a2d
-rw-r--r--doc/build/changelog/unreleased_14/6793.rst7
-rw-r--r--lib/sqlalchemy/sql/crud.py6
-rw-r--r--lib/sqlalchemy/sql/dml.py21
-rw-r--r--test/orm/test_bulk.py54
-rw-r--r--test/sql/test_compare.py6
5 files changed, 88 insertions, 6 deletions
diff --git a/doc/build/changelog/unreleased_14/6793.rst b/doc/build/changelog/unreleased_14/6793.rst
new file mode 100644
index 000000000..059bdac65
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/6793.rst
@@ -0,0 +1,7 @@
+.. change::
+ :tags: bug, orm, regression
+ :tickets: 6793
+
+ Fixed critical caching issue where the ORM's persistence feature using
+ INSERT..RETURNING would cache an incorrect query when mixing the "bulk
+ save" and standard "flush" forms of INSERT.
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 74f5a1d05..b8f8cb4ce 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -760,7 +760,7 @@ def _append_param_update(
compiler.postfetch.append(c)
elif (
implicit_return_defaults
- and stmt._return_defaults is not True
+ and (stmt._return_defaults_columns or not stmt._return_defaults)
and c in implicit_return_defaults
):
compiler.returning.append(c)
@@ -1024,10 +1024,10 @@ def _get_returning_modifiers(compiler, stmt, compile_state):
implicit_return_defaults = False # pragma: no cover
if implicit_return_defaults:
- if stmt._return_defaults is True:
+ if not stmt._return_defaults_columns:
implicit_return_defaults = set(stmt.table.c)
else:
- implicit_return_defaults = set(stmt._return_defaults)
+ implicit_return_defaults = set(stmt._return_defaults_columns)
postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 048475040..158cb40f2 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -214,7 +214,8 @@ class UpdateBase(
_hints = util.immutabledict()
named_with_column = False
- _return_defaults = None
+ _return_defaults = False
+ _return_defaults_columns = None
_returning = ()
is_dml = True
@@ -794,7 +795,8 @@ class ValuesBase(UpdateBase):
:attr:`_engine.CursorResult.inserted_primary_key_rows`
"""
- self._return_defaults = cols or True
+ self._return_defaults = True
+ self._return_defaults_columns = cols
class Insert(ValuesBase):
@@ -825,6 +827,11 @@ class Insert(ValuesBase):
("_post_values_clause", InternalTraversal.dp_clauseelement),
("_returning", InternalTraversal.dp_clauseelement_list),
("_hints", InternalTraversal.dp_table_hint_list),
+ ("_return_defaults", InternalTraversal.dp_boolean),
+ (
+ "_return_defaults_columns",
+ InternalTraversal.dp_clauseelement_list,
+ ),
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
@@ -929,7 +936,10 @@ class Insert(ValuesBase):
if dialect_kw:
self._validate_dialect_kwargs_deprecated(dialect_kw)
- self._return_defaults = return_defaults
+ if return_defaults:
+ self._return_defaults = True
+ if not isinstance(return_defaults, bool):
+ self._return_defaults_columns = return_defaults
@_generative
def inline(self):
@@ -1116,6 +1126,11 @@ class Update(DMLWhereBase, ValuesBase):
("_values", InternalTraversal.dp_dml_values),
("_returning", InternalTraversal.dp_clauseelement_list),
("_hints", InternalTraversal.dp_table_hint_list),
+ ("_return_defaults", InternalTraversal.dp_boolean),
+ (
+ "_return_defaults_columns",
+ InternalTraversal.dp_clauseelement_list,
+ ),
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
diff --git a/test/orm/test_bulk.py b/test/orm/test_bulk.py
index 32ee80708..7e47507c0 100644
--- a/test/orm/test_bulk.py
+++ b/test/orm/test_bulk.py
@@ -866,3 +866,57 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
],
),
)
+
+
+class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest):
+ @classmethod
+ def setup_classes(cls):
+ Base = cls.DeclarativeBasic
+
+ class User(Base):
+ __tablename__ = "users"
+ id = Column(Integer, primary_key=True)
+ name = Column(String(255), nullable=False)
+
+ def test_issue_6793(self):
+ User = self.classes.User
+
+ session = fixture_session()
+
+ with self.sql_execution_asserter() as asserter:
+
+ session.bulk_save_objects([User(name="A"), User(name="B")])
+
+ session.add(User(name="C"))
+ session.add(User(name="D"))
+ session.flush()
+
+ asserter.assert_(
+ Conditional(
+ testing.db.dialect.insert_executemany_returning,
+ [
+ CompiledSQL(
+ "INSERT INTO users (name) VALUES (:name)",
+ [{"name": "A"}, {"name": "B"}],
+ ),
+ CompiledSQL(
+ "INSERT INTO users (name) VALUES (:name)",
+ [{"name": "C"}, {"name": "D"}],
+ ),
+ ],
+ [
+ CompiledSQL(
+ "INSERT INTO users (name) VALUES (:name)",
+ [{"name": "A"}, {"name": "B"}],
+ ),
+ CompiledSQL(
+ "INSERT INTO users (name) VALUES (:name)",
+ [{"name": "C"}],
+ ),
+ CompiledSQL(
+ "INSERT INTO users (name) VALUES (:name)",
+ [{"name": "D"}],
+ ),
+ ],
+ )
+ )
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index 188d9337e..371e68a8a 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py
@@ -534,6 +534,9 @@ class CoreFixtures(object):
),
lambda: (
table_a.insert(),
+ table_a.insert().return_defaults(),
+ table_a.insert().return_defaults(table_a.c.a),
+ table_a.insert().return_defaults(table_a.c.b),
table_a.insert().values({})._annotate({"nocache": True}),
table_b.insert(),
table_b.insert().with_dialect_options(sqlite_foo="some value"),
@@ -570,6 +573,9 @@ class CoreFixtures(object):
),
lambda: (
table_b.update(),
+ table_b.update().return_defaults(),
+ table_b.update().return_defaults(table_b.c.a),
+ table_b.update().return_defaults(table_b.c.b),
table_b.update().where(table_b.c.a == 5),
table_b.update().where(table_b.c.b == 5),
table_b.update()