diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 176 |
1 files changed, 163 insertions, 13 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8499484f3..ed463ebe3 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -41,7 +41,6 @@ from .base import NO_ARG from .. import exc from .. import util - RESERVED_WORDS = set( [ "all", @@ -270,6 +269,89 @@ ExpandedState = collections.namedtuple( ) +NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0) + +COLLECT_CARTESIAN_PRODUCTS = util.symbol( + "COLLECT_CARTESIAN_PRODUCTS", + "Collect data on FROMs and cartesian products and gather " + "into 'self.from_linter'", + canonical=1, +) + +WARN_LINTING = util.symbol( + "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2 +) + +FROM_LINTING = util.symbol( + "FROM_LINTING", + "Warn for cartesian products; " + "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING", + canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING, +) + + +class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): + def lint(self, start=None): + froms = self.froms + if not froms: + return None, None + + edges = set(self.edges) + the_rest = set(froms) + + if start is not None: + start_with = start + the_rest.remove(start_with) + else: + start_with = the_rest.pop() + + stack = collections.deque([start_with]) + + while stack and the_rest: + node = stack.popleft() + the_rest.discard(node) + + # comparison of nodes in edges here is based on hash equality, as + # there are "annotated" elements that match the non-annotated ones. + # to remove the need for in-python hash() calls, use native + # containment routines (e.g. "node in edge", "edge.index(node)") + to_remove = {edge for edge in edges if node in edge} + + # appendleft the node in each edge that is not + # the one that matched. + stack.extendleft(edge[not edge.index(node)] for edge in to_remove) + edges.difference_update(to_remove) + + # FROMS left over? boom + if the_rest: + return the_rest, start_with + else: + return None, None + + def warn(self): + the_rest, start_with = self.lint() + + # FROMS left over? boom + if the_rest: + + froms = the_rest + if froms: + template = ( + "SELECT statement has a cartesian product between " + "FROM element(s) {froms} and " + 'FROM element "{start}". Apply join condition(s) ' + "between each element to resolve." + ) + froms_str = ", ".join( + '"{elem}"'.format(elem=self.froms[from_]) + for from_ in froms + ) + message = template.format( + froms=froms_str, start=self.froms[start_with] + ) + util.warn(message) + + class Compiled(object): """Represent a compiled SQL or DDL expression. @@ -568,7 +650,13 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () def __init__( - self, dialect, statement, column_keys=None, inline=False, **kwargs + self, + dialect, + statement, + column_keys=None, + inline=False, + linting=NO_LINTING, + **kwargs ): """Construct a new :class:`.SQLCompiler` object. @@ -592,6 +680,8 @@ class SQLCompiler(Compiled): # execute) self.inline = inline or getattr(statement, "inline", False) + self.linting = linting + # a dictionary of bind parameter keys to BindParameter # instances. self.binds = {} @@ -1547,9 +1637,21 @@ class SQLCompiler(Compiled): return to_update, replacement_expression def visit_binary( - self, binary, override_operator=None, eager_grouping=False, **kw + self, + binary, + override_operator=None, + eager_grouping=False, + from_linter=None, + **kw ): + if from_linter and operators.is_comparison(binary.operator): + from_linter.edges.update( + itertools.product( + binary.left._from_objects, binary.right._from_objects + ) + ) + # don't allow "? = ?" to render if ( self.ansi_bind_rules @@ -1568,7 +1670,9 @@ class SQLCompiler(Compiled): except KeyError: raise exc.UnsupportedCompilationError(self, operator_) else: - return self._generate_generic_binary(binary, opstring, **kw) + return self._generate_generic_binary( + binary, opstring, from_linter=from_linter, **kw + ) def visit_function_as_comparison_op_binary(self, element, operator, **kw): return self.process(element.sql_function, **kw) @@ -1916,6 +2020,7 @@ class SQLCompiler(Compiled): ashint=False, fromhints=None, visiting_cte=None, + from_linter=None, **kwargs ): self._init_cte_state() @@ -2021,6 +2126,9 @@ class SQLCompiler(Compiled): self.ctes[cte] = text if asfrom: + if from_linter: + from_linter.froms[cte] = cte_name + if not is_new_cte and embedded_in_current_named_cte: return self.preparer.format_alias(cte, cte_name) @@ -2043,6 +2151,7 @@ class SQLCompiler(Compiled): subquery=False, lateral=False, enclosing_alias=None, + from_linter=None, **kwargs ): if enclosing_alias is not None and enclosing_alias.element is alias: @@ -2071,6 +2180,9 @@ class SQLCompiler(Compiled): if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: + if from_linter: + from_linter.froms[alias] = alias_name + inner = alias.element._compiler_dispatch( self, asfrom=True, lateral=lateral, **kwargs ) @@ -2284,6 +2396,7 @@ class SQLCompiler(Compiled): compound_index=0, select_wraps_for=None, lateral=False, + from_linter=None, **kwargs ): @@ -2373,7 +2486,7 @@ class SQLCompiler(Compiled): ] text = self._compose_select_body( - text, select, inner_columns, froms, byfrom, kwargs + text, select, inner_columns, froms, byfrom, toplevel, kwargs ) if select._statement_hints: @@ -2465,10 +2578,17 @@ class SQLCompiler(Compiled): return froms def _compose_select_body( - self, text, select, inner_columns, froms, byfrom, kwargs + self, text, select, inner_columns, froms, byfrom, toplevel, kwargs ): text += ", ".join(inner_columns) + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + if froms: text += " \nFROM " @@ -2476,7 +2596,11 @@ class SQLCompiler(Compiled): text += ", ".join( [ f._compiler_dispatch( - self, asfrom=True, fromhints=byfrom, **kwargs + self, + asfrom=True, + fromhints=byfrom, + from_linter=from_linter, + **kwargs ) for f in froms ] @@ -2484,7 +2608,12 @@ class SQLCompiler(Compiled): else: text += ", ".join( [ - f._compiler_dispatch(self, asfrom=True, **kwargs) + f._compiler_dispatch( + self, + asfrom=True, + from_linter=from_linter, + **kwargs + ) for f in froms ] ) @@ -2492,10 +2621,18 @@ class SQLCompiler(Compiled): text += self.default_from() if select._whereclause is not None: - t = select._whereclause._compiler_dispatch(self, **kwargs) + t = select._whereclause._compiler_dispatch( + self, from_linter=from_linter, **kwargs + ) if t: text += " \nWHERE " + t + if ( + self.linting & COLLECT_CARTESIAN_PRODUCTS + and self.linting & WARN_LINTING + ): + from_linter.warn() + if select._group_by_clause.clauses: text += self.group_by_clause(select, **kwargs) @@ -2597,8 +2734,12 @@ class SQLCompiler(Compiled): ashint=False, fromhints=None, use_schema=True, + from_linter=None, **kwargs ): + if from_linter: + from_linter.froms[table] = table.fullname + if asfrom or ashint: effective_schema = self.preparer.schema_for_object(table) @@ -2618,7 +2759,10 @@ class SQLCompiler(Compiled): else: return "" - def visit_join(self, join, asfrom=False, **kwargs): + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.add((join.left, join.right)) + if join.full: join_type = " FULL OUTER JOIN " elif join.isouter: @@ -2626,12 +2770,18 @@ class SQLCompiler(Compiled): else: join_type = " JOIN " return ( - join.left._compiler_dispatch(self, asfrom=True, **kwargs) + join.left._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + join_type - + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + + join.right._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + " ON " # TODO: likely need asfrom=True here? - + join.onclause._compiler_dispatch(self, **kwargs) + + join.onclause._compiler_dispatch( + self, from_linter=from_linter, **kwargs + ) ) def _setup_crud_hints(self, stmt, table_text): |