summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2023-04-28 12:07:09 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2023-05-09 10:08:52 -0400
commit4a62625d99470c8928422c4822df5234b93b6bb8 (patch)
tree280182818aea6846f1294705357b6a0754d51df4 /lib/sqlalchemy/sql/compiler.py
parent39c8e95b1f50190ff30a836b2bcf13ba2cacc052 (diff)
downloadsqlalchemy-4a62625d99470c8928422c4822df5234b93b6bb8.tar.gz
implement FromLinter for UPDATE, DELETE statements
Implemented the "cartesian product warning" for UPDATE and DELETE statements, those which include multiple tables that are not correlated together in some way. Fixed issue where :func:`_dml.update` construct that included multiple tables and no VALUES clause would raise with an internal error. Current behavior for :class:`_dml.Update` with no values is to generate a SQL UPDATE statement with an empty "set" clause, so this has been made consistent for this specific sub-case. Fixes: #9721 Change-Id: I556639811cc930d2e37532965d2ae751882af921
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py72
1 files changed, 61 insertions, 11 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 554a84112..619ff0848 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -710,7 +710,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
else:
return None, None
- def warn(self):
+ def warn(self, stmt_type="SELECT"):
the_rest, start_with = self.lint()
# FROMS left over? boom
@@ -719,7 +719,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
froms = the_rest
if froms:
template = (
- "SELECT statement has a cartesian product between "
+ "{stmt_type} statement has a cartesian product between "
"FROM element(s) {froms} and "
'FROM element "{start}". Apply join condition(s) '
"between each element to resolve."
@@ -728,7 +728,9 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
f'"{self.froms[from_]}"' for from_ in froms
)
message = template.format(
- froms=froms_str, start=self.froms[start_with]
+ stmt_type=stmt_type,
+ froms=froms_str,
+ start=self.froms[start_with],
)
util.warn(message)
@@ -5997,6 +5999,7 @@ class SQLCompiler(Compiled):
)
def visit_update(self, update_stmt, **kw):
+
compile_state = update_stmt._compile_state_factory(
update_stmt, self, **kw
)
@@ -6010,6 +6013,15 @@ class SQLCompiler(Compiled):
if not self.compile_state:
self.compile_state = compile_state
+ if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+ from_linter = FromLinter({}, set())
+ warn_linting = self.linting & WARN_LINTING
+ if toplevel:
+ self.from_linter = from_linter
+ else:
+ from_linter = None
+ warn_linting = False
+
extra_froms = compile_state._extra_froms
is_multitable = bool(extra_froms)
@@ -6040,7 +6052,11 @@ class SQLCompiler(Compiled):
)
table_text = self.update_tables_clause(
- update_stmt, update_stmt.table, render_extra_froms, **kw
+ update_stmt,
+ update_stmt.table,
+ render_extra_froms,
+ from_linter=from_linter,
+ **kw,
)
crud_params_struct = crud._get_crud_params(
self, update_stmt, compile_state, toplevel, **kw
@@ -6081,6 +6097,7 @@ class SQLCompiler(Compiled):
update_stmt.table,
render_extra_froms,
dialect_hints,
+ from_linter=from_linter,
**kw,
)
if extra_from_text:
@@ -6088,7 +6105,7 @@ class SQLCompiler(Compiled):
if update_stmt._where_criteria:
t = self._generate_delimited_and_list(
- update_stmt._where_criteria, **kw
+ update_stmt._where_criteria, from_linter=from_linter, **kw
)
if t:
text += " WHERE " + t
@@ -6110,6 +6127,10 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = self._render_cte_clause(nesting_level=nesting_level) + text
+ if warn_linting:
+ assert from_linter is not None
+ from_linter.warn(stmt_type="UPDATE")
+
self.stack.pop(-1)
return text
@@ -6130,8 +6151,10 @@ class SQLCompiler(Compiled):
"criteria within DELETE"
)
- def delete_table_clause(self, delete_stmt, from_table, extra_froms):
- return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
+ return from_table._compiler_dispatch(
+ self, asfrom=True, iscrud=True, **kw
+ )
def visit_delete(self, delete_stmt, **kw):
compile_state = delete_stmt._compile_state_factory(
@@ -6147,6 +6170,15 @@ class SQLCompiler(Compiled):
if not self.compile_state:
self.compile_state = compile_state
+ if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+ from_linter = FromLinter({}, set())
+ warn_linting = self.linting & WARN_LINTING
+ if toplevel:
+ self.from_linter = from_linter
+ else:
+ from_linter = None
+ warn_linting = False
+
extra_froms = compile_state._extra_froms
correlate_froms = {delete_stmt.table}.union(extra_froms)
@@ -6166,9 +6198,22 @@ class SQLCompiler(Compiled):
)
text += "FROM "
- table_text = self.delete_table_clause(
- delete_stmt, delete_stmt.table, extra_froms
- )
+
+ try:
+ table_text = self.delete_table_clause(
+ delete_stmt,
+ delete_stmt.table,
+ extra_froms,
+ from_linter=from_linter,
+ )
+ except TypeError:
+ # anticipate 3rd party dialects that don't include **kw
+ # TODO: remove in 2.1
+ table_text = self.delete_table_clause(
+ delete_stmt, delete_stmt.table, extra_froms
+ )
+ if from_linter:
+ _ = self.process(delete_stmt.table, from_linter=from_linter)
crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)
@@ -6199,6 +6244,7 @@ class SQLCompiler(Compiled):
delete_stmt.table,
extra_froms,
dialect_hints,
+ from_linter=from_linter,
**kw,
)
if extra_from_text:
@@ -6206,7 +6252,7 @@ class SQLCompiler(Compiled):
if delete_stmt._where_criteria:
t = self._generate_delimited_and_list(
- delete_stmt._where_criteria, **kw
+ delete_stmt._where_criteria, from_linter=from_linter, **kw
)
if t:
text += " WHERE " + t
@@ -6224,6 +6270,10 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = self._render_cte_clause(nesting_level=nesting_level) + text
+ if warn_linting:
+ assert from_linter is not None
+ from_linter.warn(stmt_type="DELETE")
+
self.stack.pop(-1)
return text