diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-04-28 12:07:09 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-05-09 10:08:52 -0400 |
commit | 4a62625d99470c8928422c4822df5234b93b6bb8 (patch) | |
tree | 280182818aea6846f1294705357b6a0754d51df4 /lib/sqlalchemy/sql/compiler.py | |
parent | 39c8e95b1f50190ff30a836b2bcf13ba2cacc052 (diff) | |
download | sqlalchemy-4a62625d99470c8928422c4822df5234b93b6bb8.tar.gz |
implement FromLinter for UPDATE, DELETE statements
Implemented the "cartesian product warning" for UPDATE and DELETE
statements, those which include multiple tables that are not correlated
together in some way.
Fixed issue where :func:`_dml.update` construct that included multiple
tables and no VALUES clause would raise with an internal error. Current
behavior for :class:`_dml.Update` with no values is to generate a SQL
UPDATE statement with an empty "set" clause, so this has been made
consistent for this specific sub-case.
Fixes: #9721
Change-Id: I556639811cc930d2e37532965d2ae751882af921
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 72 |
1 files changed, 61 insertions, 11 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 554a84112..619ff0848 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -710,7 +710,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): else: return None, None - def warn(self): + def warn(self, stmt_type="SELECT"): the_rest, start_with = self.lint() # FROMS left over? boom @@ -719,7 +719,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): froms = the_rest if froms: template = ( - "SELECT statement has a cartesian product between " + "{stmt_type} statement has a cartesian product between " "FROM element(s) {froms} and " 'FROM element "{start}". Apply join condition(s) ' "between each element to resolve." @@ -728,7 +728,9 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): f'"{self.froms[from_]}"' for from_ in froms ) message = template.format( - froms=froms_str, start=self.froms[start_with] + stmt_type=stmt_type, + froms=froms_str, + start=self.froms[start_with], ) util.warn(message) @@ -5997,6 +5999,7 @@ class SQLCompiler(Compiled): ) def visit_update(self, update_stmt, **kw): + compile_state = update_stmt._compile_state_factory( update_stmt, self, **kw ) @@ -6010,6 +6013,15 @@ class SQLCompiler(Compiled): if not self.compile_state: self.compile_state = compile_state + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + extra_froms = compile_state._extra_froms is_multitable = bool(extra_froms) @@ -6040,7 +6052,11 @@ class SQLCompiler(Compiled): ) table_text = self.update_tables_clause( - update_stmt, update_stmt.table, render_extra_froms, **kw + update_stmt, + update_stmt.table, + render_extra_froms, + from_linter=from_linter, + **kw, ) crud_params_struct = crud._get_crud_params( self, update_stmt, compile_state, toplevel, **kw @@ -6081,6 +6097,7 @@ class SQLCompiler(Compiled): update_stmt.table, render_extra_froms, dialect_hints, + from_linter=from_linter, **kw, ) if extra_from_text: @@ -6088,7 +6105,7 @@ class SQLCompiler(Compiled): if update_stmt._where_criteria: t = self._generate_delimited_and_list( - update_stmt._where_criteria, **kw + update_stmt._where_criteria, from_linter=from_linter, **kw ) if t: text += " WHERE " + t @@ -6110,6 +6127,10 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = self._render_cte_clause(nesting_level=nesting_level) + text + if warn_linting: + assert from_linter is not None + from_linter.warn(stmt_type="UPDATE") + self.stack.pop(-1) return text @@ -6130,8 +6151,10 @@ class SQLCompiler(Compiled): "criteria within DELETE" ) - def delete_table_clause(self, delete_stmt, from_table, extra_froms): - return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) + def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): + return from_table._compiler_dispatch( + self, asfrom=True, iscrud=True, **kw + ) def visit_delete(self, delete_stmt, **kw): compile_state = delete_stmt._compile_state_factory( @@ -6147,6 +6170,15 @@ class SQLCompiler(Compiled): if not self.compile_state: self.compile_state = compile_state + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + extra_froms = compile_state._extra_froms correlate_froms = {delete_stmt.table}.union(extra_froms) @@ -6166,9 +6198,22 @@ class SQLCompiler(Compiled): ) text += "FROM " - table_text = self.delete_table_clause( - delete_stmt, delete_stmt.table, extra_froms - ) + + try: + table_text = self.delete_table_clause( + delete_stmt, + delete_stmt.table, + extra_froms, + from_linter=from_linter, + ) + except TypeError: + # anticipate 3rd party dialects that don't include **kw + # TODO: remove in 2.1 + table_text = self.delete_table_clause( + delete_stmt, delete_stmt.table, extra_froms + ) + if from_linter: + _ = self.process(delete_stmt.table, from_linter=from_linter) crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw) @@ -6199,6 +6244,7 @@ class SQLCompiler(Compiled): delete_stmt.table, extra_froms, dialect_hints, + from_linter=from_linter, **kw, ) if extra_from_text: @@ -6206,7 +6252,7 @@ class SQLCompiler(Compiled): if delete_stmt._where_criteria: t = self._generate_delimited_and_list( - delete_stmt._where_criteria, **kw + delete_stmt._where_criteria, from_linter=from_linter, **kw ) if t: text += " WHERE " + t @@ -6224,6 +6270,10 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = self._render_cte_clause(nesting_level=nesting_level) + text + if warn_linting: + assert from_linter is not None + from_linter.warn(stmt_type="DELETE") + self.stack.pop(-1) return text |