diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 72 |
1 files changed, 61 insertions, 11 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 554a84112..619ff0848 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -710,7 +710,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): else: return None, None - def warn(self): + def warn(self, stmt_type="SELECT"): the_rest, start_with = self.lint() # FROMS left over? boom @@ -719,7 +719,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): froms = the_rest if froms: template = ( - "SELECT statement has a cartesian product between " + "{stmt_type} statement has a cartesian product between " "FROM element(s) {froms} and " 'FROM element "{start}". Apply join condition(s) ' "between each element to resolve." @@ -728,7 +728,9 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): f'"{self.froms[from_]}"' for from_ in froms ) message = template.format( - froms=froms_str, start=self.froms[start_with] + stmt_type=stmt_type, + froms=froms_str, + start=self.froms[start_with], ) util.warn(message) @@ -5997,6 +5999,7 @@ class SQLCompiler(Compiled): ) def visit_update(self, update_stmt, **kw): + compile_state = update_stmt._compile_state_factory( update_stmt, self, **kw ) @@ -6010,6 +6013,15 @@ class SQLCompiler(Compiled): if not self.compile_state: self.compile_state = compile_state + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + extra_froms = compile_state._extra_froms is_multitable = bool(extra_froms) @@ -6040,7 +6052,11 @@ class SQLCompiler(Compiled): ) table_text = self.update_tables_clause( - update_stmt, update_stmt.table, render_extra_froms, **kw + update_stmt, + update_stmt.table, + render_extra_froms, + from_linter=from_linter, + **kw, ) crud_params_struct = crud._get_crud_params( self, update_stmt, compile_state, toplevel, **kw @@ -6081,6 +6097,7 @@ class SQLCompiler(Compiled): update_stmt.table, render_extra_froms, dialect_hints, + from_linter=from_linter, **kw, ) if extra_from_text: @@ -6088,7 +6105,7 @@ class SQLCompiler(Compiled): if update_stmt._where_criteria: t = self._generate_delimited_and_list( - update_stmt._where_criteria, **kw + update_stmt._where_criteria, from_linter=from_linter, **kw ) if t: text += " WHERE " + t @@ -6110,6 +6127,10 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = self._render_cte_clause(nesting_level=nesting_level) + text + if warn_linting: + assert from_linter is not None + from_linter.warn(stmt_type="UPDATE") + self.stack.pop(-1) return text @@ -6130,8 +6151,10 @@ class SQLCompiler(Compiled): "criteria within DELETE" ) - def delete_table_clause(self, delete_stmt, from_table, extra_froms): - return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) + def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): + return from_table._compiler_dispatch( + self, asfrom=True, iscrud=True, **kw + ) def visit_delete(self, delete_stmt, **kw): compile_state = delete_stmt._compile_state_factory( @@ -6147,6 +6170,15 @@ class SQLCompiler(Compiled): if not self.compile_state: self.compile_state = compile_state + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + extra_froms = compile_state._extra_froms correlate_froms = {delete_stmt.table}.union(extra_froms) @@ -6166,9 +6198,22 @@ class SQLCompiler(Compiled): ) text += "FROM " - table_text = self.delete_table_clause( - delete_stmt, delete_stmt.table, extra_froms - ) + + try: + table_text = self.delete_table_clause( + delete_stmt, + delete_stmt.table, + extra_froms, + from_linter=from_linter, + ) + except TypeError: + # anticipate 3rd party dialects that don't include **kw + # TODO: remove in 2.1 + table_text = self.delete_table_clause( + delete_stmt, delete_stmt.table, extra_froms + ) + if from_linter: + _ = self.process(delete_stmt.table, from_linter=from_linter) crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw) @@ -6199,6 +6244,7 @@ class SQLCompiler(Compiled): delete_stmt.table, extra_froms, dialect_hints, + from_linter=from_linter, **kw, ) if extra_from_text: @@ -6206,7 +6252,7 @@ class SQLCompiler(Compiled): if delete_stmt._where_criteria: t = self._generate_delimited_and_list( - delete_stmt._where_criteria, **kw + delete_stmt._where_criteria, from_linter=from_linter, **kw ) if t: text += " WHERE " + t @@ -6224,6 +6270,10 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = self._render_cte_clause(nesting_level=nesting_level) + text + if warn_linting: + assert from_linter is not None + from_linter.warn(stmt_type="DELETE") + self.stack.pop(-1) return text |