summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py176
1 files changed, 163 insertions, 13 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 8499484f3..ed463ebe3 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -41,7 +41,6 @@ from .base import NO_ARG
from .. import exc
from .. import util
-
RESERVED_WORDS = set(
[
"all",
@@ -270,6 +269,89 @@ ExpandedState = collections.namedtuple(
)
+NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0)
+
+COLLECT_CARTESIAN_PRODUCTS = util.symbol(
+ "COLLECT_CARTESIAN_PRODUCTS",
+ "Collect data on FROMs and cartesian products and gather "
+ "into 'self.from_linter'",
+ canonical=1,
+)
+
+WARN_LINTING = util.symbol(
+ "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2
+)
+
+FROM_LINTING = util.symbol(
+ "FROM_LINTING",
+ "Warn for cartesian products; "
+ "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING",
+ canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING,
+)
+
+
+class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
+ def lint(self, start=None):
+ froms = self.froms
+ if not froms:
+ return None, None
+
+ edges = set(self.edges)
+ the_rest = set(froms)
+
+ if start is not None:
+ start_with = start
+ the_rest.remove(start_with)
+ else:
+ start_with = the_rest.pop()
+
+ stack = collections.deque([start_with])
+
+ while stack and the_rest:
+ node = stack.popleft()
+ the_rest.discard(node)
+
+ # comparison of nodes in edges here is based on hash equality, as
+ # there are "annotated" elements that match the non-annotated ones.
+ # to remove the need for in-python hash() calls, use native
+ # containment routines (e.g. "node in edge", "edge.index(node)")
+ to_remove = {edge for edge in edges if node in edge}
+
+ # appendleft the node in each edge that is not
+ # the one that matched.
+ stack.extendleft(edge[not edge.index(node)] for edge in to_remove)
+ edges.difference_update(to_remove)
+
+ # FROMS left over? boom
+ if the_rest:
+ return the_rest, start_with
+ else:
+ return None, None
+
+ def warn(self):
+ the_rest, start_with = self.lint()
+
+ # FROMS left over? boom
+ if the_rest:
+
+ froms = the_rest
+ if froms:
+ template = (
+ "SELECT statement has a cartesian product between "
+ "FROM element(s) {froms} and "
+ 'FROM element "{start}". Apply join condition(s) '
+ "between each element to resolve."
+ )
+ froms_str = ", ".join(
+ '"{elem}"'.format(elem=self.froms[from_])
+ for from_ in froms
+ )
+ message = template.format(
+ froms=froms_str, start=self.froms[start_with]
+ )
+ util.warn(message)
+
+
class Compiled(object):
"""Represent a compiled SQL or DDL expression.
@@ -568,7 +650,13 @@ class SQLCompiler(Compiled):
insert_prefetch = update_prefetch = ()
def __init__(
- self, dialect, statement, column_keys=None, inline=False, **kwargs
+ self,
+ dialect,
+ statement,
+ column_keys=None,
+ inline=False,
+ linting=NO_LINTING,
+ **kwargs
):
"""Construct a new :class:`.SQLCompiler` object.
@@ -592,6 +680,8 @@ class SQLCompiler(Compiled):
# execute)
self.inline = inline or getattr(statement, "inline", False)
+ self.linting = linting
+
# a dictionary of bind parameter keys to BindParameter
# instances.
self.binds = {}
@@ -1547,9 +1637,21 @@ class SQLCompiler(Compiled):
return to_update, replacement_expression
def visit_binary(
- self, binary, override_operator=None, eager_grouping=False, **kw
+ self,
+ binary,
+ override_operator=None,
+ eager_grouping=False,
+ from_linter=None,
+ **kw
):
+ if from_linter and operators.is_comparison(binary.operator):
+ from_linter.edges.update(
+ itertools.product(
+ binary.left._from_objects, binary.right._from_objects
+ )
+ )
+
# don't allow "? = ?" to render
if (
self.ansi_bind_rules
@@ -1568,7 +1670,9 @@ class SQLCompiler(Compiled):
except KeyError:
raise exc.UnsupportedCompilationError(self, operator_)
else:
- return self._generate_generic_binary(binary, opstring, **kw)
+ return self._generate_generic_binary(
+ binary, opstring, from_linter=from_linter, **kw
+ )
def visit_function_as_comparison_op_binary(self, element, operator, **kw):
return self.process(element.sql_function, **kw)
@@ -1916,6 +2020,7 @@ class SQLCompiler(Compiled):
ashint=False,
fromhints=None,
visiting_cte=None,
+ from_linter=None,
**kwargs
):
self._init_cte_state()
@@ -2021,6 +2126,9 @@ class SQLCompiler(Compiled):
self.ctes[cte] = text
if asfrom:
+ if from_linter:
+ from_linter.froms[cte] = cte_name
+
if not is_new_cte and embedded_in_current_named_cte:
return self.preparer.format_alias(cte, cte_name)
@@ -2043,6 +2151,7 @@ class SQLCompiler(Compiled):
subquery=False,
lateral=False,
enclosing_alias=None,
+ from_linter=None,
**kwargs
):
if enclosing_alias is not None and enclosing_alias.element is alias:
@@ -2071,6 +2180,9 @@ class SQLCompiler(Compiled):
if ashint:
return self.preparer.format_alias(alias, alias_name)
elif asfrom:
+ if from_linter:
+ from_linter.froms[alias] = alias_name
+
inner = alias.element._compiler_dispatch(
self, asfrom=True, lateral=lateral, **kwargs
)
@@ -2284,6 +2396,7 @@ class SQLCompiler(Compiled):
compound_index=0,
select_wraps_for=None,
lateral=False,
+ from_linter=None,
**kwargs
):
@@ -2373,7 +2486,7 @@ class SQLCompiler(Compiled):
]
text = self._compose_select_body(
- text, select, inner_columns, froms, byfrom, kwargs
+ text, select, inner_columns, froms, byfrom, toplevel, kwargs
)
if select._statement_hints:
@@ -2465,10 +2578,17 @@ class SQLCompiler(Compiled):
return froms
def _compose_select_body(
- self, text, select, inner_columns, froms, byfrom, kwargs
+ self, text, select, inner_columns, froms, byfrom, toplevel, kwargs
):
text += ", ".join(inner_columns)
+ if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+ from_linter = FromLinter({}, set())
+ if toplevel:
+ self.from_linter = from_linter
+ else:
+ from_linter = None
+
if froms:
text += " \nFROM "
@@ -2476,7 +2596,11 @@ class SQLCompiler(Compiled):
text += ", ".join(
[
f._compiler_dispatch(
- self, asfrom=True, fromhints=byfrom, **kwargs
+ self,
+ asfrom=True,
+ fromhints=byfrom,
+ from_linter=from_linter,
+ **kwargs
)
for f in froms
]
@@ -2484,7 +2608,12 @@ class SQLCompiler(Compiled):
else:
text += ", ".join(
[
- f._compiler_dispatch(self, asfrom=True, **kwargs)
+ f._compiler_dispatch(
+ self,
+ asfrom=True,
+ from_linter=from_linter,
+ **kwargs
+ )
for f in froms
]
)
@@ -2492,10 +2621,18 @@ class SQLCompiler(Compiled):
text += self.default_from()
if select._whereclause is not None:
- t = select._whereclause._compiler_dispatch(self, **kwargs)
+ t = select._whereclause._compiler_dispatch(
+ self, from_linter=from_linter, **kwargs
+ )
if t:
text += " \nWHERE " + t
+ if (
+ self.linting & COLLECT_CARTESIAN_PRODUCTS
+ and self.linting & WARN_LINTING
+ ):
+ from_linter.warn()
+
if select._group_by_clause.clauses:
text += self.group_by_clause(select, **kwargs)
@@ -2597,8 +2734,12 @@ class SQLCompiler(Compiled):
ashint=False,
fromhints=None,
use_schema=True,
+ from_linter=None,
**kwargs
):
+ if from_linter:
+ from_linter.froms[table] = table.fullname
+
if asfrom or ashint:
effective_schema = self.preparer.schema_for_object(table)
@@ -2618,7 +2759,10 @@ class SQLCompiler(Compiled):
else:
return ""
- def visit_join(self, join, asfrom=False, **kwargs):
+ def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
+ if from_linter:
+ from_linter.edges.add((join.left, join.right))
+
if join.full:
join_type = " FULL OUTER JOIN "
elif join.isouter:
@@ -2626,12 +2770,18 @@ class SQLCompiler(Compiled):
else:
join_type = " JOIN "
return (
- join.left._compiler_dispatch(self, asfrom=True, **kwargs)
+ join.left._compiler_dispatch(
+ self, asfrom=True, from_linter=from_linter, **kwargs
+ )
+ join_type
- + join.right._compiler_dispatch(self, asfrom=True, **kwargs)
+ + join.right._compiler_dispatch(
+ self, asfrom=True, from_linter=from_linter, **kwargs
+ )
+ " ON "
# TODO: likely need asfrom=True here?
- + join.onclause._compiler_dispatch(self, **kwargs)
+ + join.onclause._compiler_dispatch(
+ self, from_linter=from_linter, **kwargs
+ )
)
def _setup_crud_hints(self, stmt, table_text):