summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py72
1 files changed, 61 insertions, 11 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 554a84112..619ff0848 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -710,7 +710,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
else:
return None, None
- def warn(self):
+ def warn(self, stmt_type="SELECT"):
the_rest, start_with = self.lint()
# FROMS left over? boom
@@ -719,7 +719,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
froms = the_rest
if froms:
template = (
- "SELECT statement has a cartesian product between "
+ "{stmt_type} statement has a cartesian product between "
"FROM element(s) {froms} and "
'FROM element "{start}". Apply join condition(s) '
"between each element to resolve."
@@ -728,7 +728,9 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
f'"{self.froms[from_]}"' for from_ in froms
)
message = template.format(
- froms=froms_str, start=self.froms[start_with]
+ stmt_type=stmt_type,
+ froms=froms_str,
+ start=self.froms[start_with],
)
util.warn(message)
@@ -5997,6 +5999,7 @@ class SQLCompiler(Compiled):
)
def visit_update(self, update_stmt, **kw):
+
compile_state = update_stmt._compile_state_factory(
update_stmt, self, **kw
)
@@ -6010,6 +6013,15 @@ class SQLCompiler(Compiled):
if not self.compile_state:
self.compile_state = compile_state
+ if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+ from_linter = FromLinter({}, set())
+ warn_linting = self.linting & WARN_LINTING
+ if toplevel:
+ self.from_linter = from_linter
+ else:
+ from_linter = None
+ warn_linting = False
+
extra_froms = compile_state._extra_froms
is_multitable = bool(extra_froms)
@@ -6040,7 +6052,11 @@ class SQLCompiler(Compiled):
)
table_text = self.update_tables_clause(
- update_stmt, update_stmt.table, render_extra_froms, **kw
+ update_stmt,
+ update_stmt.table,
+ render_extra_froms,
+ from_linter=from_linter,
+ **kw,
)
crud_params_struct = crud._get_crud_params(
self, update_stmt, compile_state, toplevel, **kw
@@ -6081,6 +6097,7 @@ class SQLCompiler(Compiled):
update_stmt.table,
render_extra_froms,
dialect_hints,
+ from_linter=from_linter,
**kw,
)
if extra_from_text:
@@ -6088,7 +6105,7 @@ class SQLCompiler(Compiled):
if update_stmt._where_criteria:
t = self._generate_delimited_and_list(
- update_stmt._where_criteria, **kw
+ update_stmt._where_criteria, from_linter=from_linter, **kw
)
if t:
text += " WHERE " + t
@@ -6110,6 +6127,10 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = self._render_cte_clause(nesting_level=nesting_level) + text
+ if warn_linting:
+ assert from_linter is not None
+ from_linter.warn(stmt_type="UPDATE")
+
self.stack.pop(-1)
return text
@@ -6130,8 +6151,10 @@ class SQLCompiler(Compiled):
"criteria within DELETE"
)
- def delete_table_clause(self, delete_stmt, from_table, extra_froms):
- return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
+ return from_table._compiler_dispatch(
+ self, asfrom=True, iscrud=True, **kw
+ )
def visit_delete(self, delete_stmt, **kw):
compile_state = delete_stmt._compile_state_factory(
@@ -6147,6 +6170,15 @@ class SQLCompiler(Compiled):
if not self.compile_state:
self.compile_state = compile_state
+ if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+ from_linter = FromLinter({}, set())
+ warn_linting = self.linting & WARN_LINTING
+ if toplevel:
+ self.from_linter = from_linter
+ else:
+ from_linter = None
+ warn_linting = False
+
extra_froms = compile_state._extra_froms
correlate_froms = {delete_stmt.table}.union(extra_froms)
@@ -6166,9 +6198,22 @@ class SQLCompiler(Compiled):
)
text += "FROM "
- table_text = self.delete_table_clause(
- delete_stmt, delete_stmt.table, extra_froms
- )
+
+ try:
+ table_text = self.delete_table_clause(
+ delete_stmt,
+ delete_stmt.table,
+ extra_froms,
+ from_linter=from_linter,
+ )
+ except TypeError:
+ # anticipate 3rd party dialects that don't include **kw
+ # TODO: remove in 2.1
+ table_text = self.delete_table_clause(
+ delete_stmt, delete_stmt.table, extra_froms
+ )
+ if from_linter:
+ _ = self.process(delete_stmt.table, from_linter=from_linter)
crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)
@@ -6199,6 +6244,7 @@ class SQLCompiler(Compiled):
delete_stmt.table,
extra_froms,
dialect_hints,
+ from_linter=from_linter,
**kw,
)
if extra_from_text:
@@ -6206,7 +6252,7 @@ class SQLCompiler(Compiled):
if delete_stmt._where_criteria:
t = self._generate_delimited_and_list(
- delete_stmt._where_criteria, **kw
+ delete_stmt._where_criteria, from_linter=from_linter, **kw
)
if t:
text += " WHERE " + t
@@ -6224,6 +6270,10 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = self._render_cte_clause(nesting_level=nesting_level) + text
+ if warn_linting:
+ assert from_linter is not None
+ from_linter.warn(stmt_type="DELETE")
+
self.stack.pop(-1)
return text