diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 2030 |
1 files changed, 1201 insertions, 829 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 80ed707ed..f641d0a84 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -25,133 +25,218 @@ To generate user-defined SQL strings, see import contextlib import re -from . import schema, sqltypes, operators, functions, visitors, \ - elements, selectable, crud +from . import ( + schema, + sqltypes, + operators, + functions, + visitors, + elements, + selectable, + crud, +) from .. import util, exc import itertools -RESERVED_WORDS = set([ - 'all', 'analyse', 'analyze', 'and', 'any', 'array', - 'as', 'asc', 'asymmetric', 'authorization', 'between', - 'binary', 'both', 'case', 'cast', 'check', 'collate', - 'column', 'constraint', 'create', 'cross', 'current_date', - 'current_role', 'current_time', 'current_timestamp', - 'current_user', 'default', 'deferrable', 'desc', - 'distinct', 'do', 'else', 'end', 'except', 'false', - 'for', 'foreign', 'freeze', 'from', 'full', 'grant', - 'group', 'having', 'ilike', 'in', 'initially', 'inner', - 'intersect', 'into', 'is', 'isnull', 'join', 'leading', - 'left', 'like', 'limit', 'localtime', 'localtimestamp', - 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset', - 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps', - 'placing', 'primary', 'references', 'right', 'select', - 'session_user', 'set', 'similar', 'some', 'symmetric', 'table', - 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user', - 'using', 'verbose', 'when', 'where']) - -LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I) -ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(['$']) - -BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE) -BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]*)(?![:\w\$])', re.UNICODE) +RESERVED_WORDS = set( + [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "authorization", + "between", + "binary", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "cross", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "for", + "foreign", + "freeze", + "from", + "full", + "grant", + "group", + "having", + "ilike", + "in", + "initially", + "inner", + "intersect", + "into", + "is", + "isnull", + "join", + "leading", + "left", + "like", + "limit", + "localtime", + "localtimestamp", + "natural", + "new", + "not", + "notnull", + "null", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "outer", + "overlaps", + "placing", + "primary", + "references", + "right", + "select", + "session_user", + "set", + "similar", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "verbose", + "when", + "where", + ] +) + +LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I) +ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"]) + +BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE) +BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE) BIND_TEMPLATES = { - 'pyformat': "%%(%(name)s)s", - 'qmark': "?", - 'format': "%%s", - 'numeric': ":[_POSITION]", - 'named': ":%(name)s" + "pyformat": "%%(%(name)s)s", + "qmark": "?", + "format": "%%s", + "numeric": ":[_POSITION]", + "named": ":%(name)s", } OPERATORS = { # binary - operators.and_: ' AND ', - operators.or_: ' OR ', - operators.add: ' + ', - operators.mul: ' * ', - operators.sub: ' - ', - operators.div: ' / ', - operators.mod: ' % ', - operators.truediv: ' / ', - operators.neg: '-', - operators.lt: ' < ', - operators.le: ' <= ', - operators.ne: ' != ', - operators.gt: ' > ', - operators.ge: ' >= ', - operators.eq: ' = ', - operators.is_distinct_from: ' IS DISTINCT FROM ', - operators.isnot_distinct_from: ' IS NOT DISTINCT FROM ', - operators.concat_op: ' || ', - operators.match_op: ' MATCH ', - operators.notmatch_op: ' NOT MATCH ', - operators.in_op: ' IN ', - operators.notin_op: ' NOT IN ', - operators.comma_op: ', ', - operators.from_: ' FROM ', - operators.as_: ' AS ', - operators.is_: ' IS ', - operators.isnot: ' IS NOT ', - operators.collate: ' COLLATE ', - + operators.and_: " AND ", + operators.or_: " OR ", + operators.add: " + ", + operators.mul: " * ", + operators.sub: " - ", + operators.div: " / ", + operators.mod: " % ", + operators.truediv: " / ", + operators.neg: "-", + operators.lt: " < ", + operators.le: " <= ", + operators.ne: " != ", + operators.gt: " > ", + operators.ge: " >= ", + operators.eq: " = ", + operators.is_distinct_from: " IS DISTINCT FROM ", + operators.isnot_distinct_from: " IS NOT DISTINCT FROM ", + operators.concat_op: " || ", + operators.match_op: " MATCH ", + operators.notmatch_op: " NOT MATCH ", + operators.in_op: " IN ", + operators.notin_op: " NOT IN ", + operators.comma_op: ", ", + operators.from_: " FROM ", + operators.as_: " AS ", + operators.is_: " IS ", + operators.isnot: " IS NOT ", + operators.collate: " COLLATE ", # unary - operators.exists: 'EXISTS ', - operators.distinct_op: 'DISTINCT ', - operators.inv: 'NOT ', - operators.any_op: 'ANY ', - operators.all_op: 'ALL ', - + operators.exists: "EXISTS ", + operators.distinct_op: "DISTINCT ", + operators.inv: "NOT ", + operators.any_op: "ANY ", + operators.all_op: "ALL ", # modifiers - operators.desc_op: ' DESC', - operators.asc_op: ' ASC', - operators.nullsfirst_op: ' NULLS FIRST', - operators.nullslast_op: ' NULLS LAST', - + operators.desc_op: " DESC", + operators.asc_op: " ASC", + operators.nullsfirst_op: " NULLS FIRST", + operators.nullslast_op: " NULLS LAST", } FUNCTIONS = { - functions.coalesce: 'coalesce', - functions.current_date: 'CURRENT_DATE', - functions.current_time: 'CURRENT_TIME', - functions.current_timestamp: 'CURRENT_TIMESTAMP', - functions.current_user: 'CURRENT_USER', - functions.localtime: 'LOCALTIME', - functions.localtimestamp: 'LOCALTIMESTAMP', - functions.random: 'random', - functions.sysdate: 'sysdate', - functions.session_user: 'SESSION_USER', - functions.user: 'USER', - functions.cube: 'CUBE', - functions.rollup: 'ROLLUP', - functions.grouping_sets: 'GROUPING SETS', + functions.coalesce: "coalesce", + functions.current_date: "CURRENT_DATE", + functions.current_time: "CURRENT_TIME", + functions.current_timestamp: "CURRENT_TIMESTAMP", + functions.current_user: "CURRENT_USER", + functions.localtime: "LOCALTIME", + functions.localtimestamp: "LOCALTIMESTAMP", + functions.random: "random", + functions.sysdate: "sysdate", + functions.session_user: "SESSION_USER", + functions.user: "USER", + functions.cube: "CUBE", + functions.rollup: "ROLLUP", + functions.grouping_sets: "GROUPING SETS", } EXTRACT_MAP = { - 'month': 'month', - 'day': 'day', - 'year': 'year', - 'second': 'second', - 'hour': 'hour', - 'doy': 'doy', - 'minute': 'minute', - 'quarter': 'quarter', - 'dow': 'dow', - 'week': 'week', - 'epoch': 'epoch', - 'milliseconds': 'milliseconds', - 'microseconds': 'microseconds', - 'timezone_hour': 'timezone_hour', - 'timezone_minute': 'timezone_minute' + "month": "month", + "day": "day", + "year": "year", + "second": "second", + "hour": "hour", + "doy": "doy", + "minute": "minute", + "quarter": "quarter", + "dow": "dow", + "week": "week", + "epoch": "epoch", + "milliseconds": "milliseconds", + "microseconds": "microseconds", + "timezone_hour": "timezone_hour", + "timezone_minute": "timezone_minute", } COMPOUND_KEYWORDS = { - selectable.CompoundSelect.UNION: 'UNION', - selectable.CompoundSelect.UNION_ALL: 'UNION ALL', - selectable.CompoundSelect.EXCEPT: 'EXCEPT', - selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL', - selectable.CompoundSelect.INTERSECT: 'INTERSECT', - selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL' + selectable.CompoundSelect.UNION: "UNION", + selectable.CompoundSelect.UNION_ALL: "UNION ALL", + selectable.CompoundSelect.EXCEPT: "EXCEPT", + selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL", + selectable.CompoundSelect.INTERSECT: "INTERSECT", + selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL", } @@ -177,9 +262,14 @@ class Compiled(object): sub-elements of the statement can modify these. """ - def __init__(self, dialect, statement, bind=None, - schema_translate_map=None, - compile_kwargs=util.immutabledict()): + def __init__( + self, + dialect, + statement, + bind=None, + schema_translate_map=None, + compile_kwargs=util.immutabledict(), + ): """Construct a new :class:`.Compiled` object. :param dialect: :class:`.Dialect` to compile against. @@ -209,7 +299,8 @@ class Compiled(object): self.preparer = self.dialect.identifier_preparer if schema_translate_map: self.preparer = self.preparer._with_schema_translate( - schema_translate_map) + schema_translate_map + ) if statement is not None: self.statement = statement @@ -218,8 +309,10 @@ class Compiled(object): self.execution_options = statement._execution_options self.string = self.process(self.statement, **compile_kwargs) - @util.deprecated("0.7", ":class:`.Compiled` objects now compile " - "within the constructor.") + @util.deprecated( + "0.7", + ":class:`.Compiled` objects now compile " "within the constructor.", + ) def compile(self): """Produce the internal string representation of this element. """ @@ -247,7 +340,7 @@ class Compiled(object): def __str__(self): """Return the string text of the generated SQL or DDL.""" - return self.string or '' + return self.string or "" def construct_params(self, params=None): """Return the bind params for this compiled object. @@ -271,7 +364,9 @@ class Compiled(object): if e is None: raise exc.UnboundExecutionError( "This Compiled object is not bound to any Engine " - "or Connection.", code="2afi") + "or Connection.", + code="2afi", + ) return e._execute_compiled(self, multiparams, params) def scalar(self, *multiparams, **params): @@ -284,7 +379,7 @@ class Compiled(object): class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): """Produces DDL specification for TypeEngine objects.""" - ensure_kwarg = r'visit_\w+' + ensure_kwarg = r"visit_\w+" def __init__(self, dialect): self.dialect = dialect @@ -297,8 +392,8 @@ class _CompileLabel(visitors.Visitable): """lightweight label object which acts as an expression.Label.""" - __visit_name__ = 'label' - __slots__ = 'element', 'name' + __visit_name__ = "label" + __slots__ = "element", "name" def __init__(self, col, name, alt_names=()): self.element = col @@ -390,8 +485,9 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () - def __init__(self, dialect, statement, column_keys=None, - inline=False, **kwargs): + def __init__( + self, dialect, statement, column_keys=None, inline=False, **kwargs + ): """Construct a new :class:`.SQLCompiler` object. :param dialect: :class:`.Dialect` to be used @@ -412,7 +508,7 @@ class SQLCompiler(Compiled): # compile INSERT/UPDATE defaults/sequences inlined (no pre- # execute) - self.inline = inline or getattr(statement, 'inline', False) + self.inline = inline or getattr(statement, "inline", False) # a dictionary of bind parameter keys to BindParameter # instances. @@ -440,8 +536,9 @@ class SQLCompiler(Compiled): self.ctes = None - self.label_length = dialect.label_length \ - or dialect.max_identifier_length + self.label_length = ( + dialect.label_length or dialect.max_identifier_length + ) # a map which tracks "anonymous" identifiers that are created on # the fly here @@ -453,7 +550,7 @@ class SQLCompiler(Compiled): Compiled.__init__(self, dialect, statement, **kwargs) if ( - self.isinsert or self.isupdate or self.isdelete + self.isinsert or self.isupdate or self.isdelete ) and statement._returning: self.returning = statement._returning @@ -482,37 +579,43 @@ class SQLCompiler(Compiled): def _nested_result(self): """special API to support the use case of 'nested result sets'""" result_columns, ordered_columns = ( - self._result_columns, self._ordered_columns) + self._result_columns, + self._ordered_columns, + ) self._result_columns, self._ordered_columns = [], False try: if self.stack: entry = self.stack[-1] - entry['need_result_map_for_nested'] = True + entry["need_result_map_for_nested"] = True else: entry = None yield self._result_columns, self._ordered_columns finally: if entry: - entry.pop('need_result_map_for_nested') + entry.pop("need_result_map_for_nested") self._result_columns, self._ordered_columns = ( - result_columns, ordered_columns) + result_columns, + ordered_columns, + ) def _apply_numbered_params(self): poscount = itertools.count(1) self.string = re.sub( - r'\[_POSITION\]', - lambda m: str(util.next(poscount)), - self.string) + r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string + ) @util.memoized_property def _bind_processors(self): return dict( - (key, value) for key, value in - ((self.bind_names[bindparam], - bindparam.type._cached_bind_processor(self.dialect) - ) - for bindparam in self.bind_names) + (key, value) + for key, value in ( + ( + self.bind_names[bindparam], + bindparam.type._cached_bind_processor(self.dialect), + ) + for bindparam in self.bind_names + ) if value is not None ) @@ -539,12 +642,16 @@ class SQLCompiler(Compiled): if _group_number: raise exc.InvalidRequestError( "A value is required for bind parameter %r, " - "in parameter group %d" % - (bindparam.key, _group_number), code="cd3x") + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) else: raise exc.InvalidRequestError( "A value is required for bind parameter %r" - % bindparam.key, code="cd3x") + % bindparam.key, + code="cd3x", + ) elif bindparam.callable: pd[name] = bindparam.effective_value @@ -558,12 +665,16 @@ class SQLCompiler(Compiled): if _group_number: raise exc.InvalidRequestError( "A value is required for bind parameter %r, " - "in parameter group %d" % - (bindparam.key, _group_number), code="cd3x") + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) else: raise exc.InvalidRequestError( "A value is required for bind parameter %r" - % bindparam.key, code="cd3x") + % bindparam.key, + code="cd3x", + ) if bindparam.callable: pd[self.bind_names[bindparam]] = bindparam.effective_value @@ -595,9 +706,10 @@ class SQLCompiler(Compiled): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" def visit_label_reference( - self, element, within_columns_clause=False, **kwargs): + self, element, within_columns_clause=False, **kwargs + ): if self.stack and self.dialect.supports_simple_order_by_label: - selectable = self.stack[-1]['selectable'] + selectable = self.stack[-1]["selectable"] with_cols, only_froms, only_cols = selectable._label_resolve_dict if within_columns_clause: @@ -611,25 +723,30 @@ class SQLCompiler(Compiled): # to something else like a ColumnClause expression. order_by_elem = element.element._order_by_label_element - if order_by_elem is not None and order_by_elem.name in \ - resolve_dict and \ - order_by_elem.shares_lineage( - resolve_dict[order_by_elem.name]): - kwargs['render_label_as_label'] = \ - element.element._order_by_label_element + if ( + order_by_elem is not None + and order_by_elem.name in resolve_dict + and order_by_elem.shares_lineage( + resolve_dict[order_by_elem.name] + ) + ): + kwargs[ + "render_label_as_label" + ] = element.element._order_by_label_element return self.process( - element.element, within_columns_clause=within_columns_clause, - **kwargs) + element.element, + within_columns_clause=within_columns_clause, + **kwargs + ) def visit_textual_label_reference( - self, element, within_columns_clause=False, **kwargs): + self, element, within_columns_clause=False, **kwargs + ): if not self.stack: # compiling the element outside of the context of a SELECT - return self.process( - element._text_clause - ) + return self.process(element._text_clause) - selectable = self.stack[-1]['selectable'] + selectable = self.stack[-1]["selectable"] with_cols, only_froms, only_cols = selectable._label_resolve_dict try: if within_columns_clause: @@ -640,26 +757,30 @@ class SQLCompiler(Compiled): # treat it like text() util.warn_limited( "Can't resolve label reference %r; converting to text()", - util.ellipses_string(element.element)) - return self.process( - element._text_clause + util.ellipses_string(element.element), ) + return self.process(element._text_clause) else: - kwargs['render_label_as_label'] = col + kwargs["render_label_as_label"] = col return self.process( - col, within_columns_clause=within_columns_clause, **kwargs) - - def visit_label(self, label, - add_to_result_map=None, - within_label_clause=False, - within_columns_clause=False, - render_label_as_label=None, - **kw): + col, within_columns_clause=within_columns_clause, **kwargs + ) + + def visit_label( + self, + label, + add_to_result_map=None, + within_label_clause=False, + within_columns_clause=False, + render_label_as_label=None, + **kw + ): # only render labels within the columns clause # or ORDER BY clause of a select. dialect-specific compilers # can modify this behavior. - render_label_with_as = (within_columns_clause and not - within_label_clause) + render_label_with_as = ( + within_columns_clause and not within_label_clause + ) render_label_only = render_label_as_label is label if render_label_only or render_label_with_as: @@ -673,27 +794,35 @@ class SQLCompiler(Compiled): add_to_result_map( labelname, label.name, - (label, labelname, ) + label._alt_names, - label.type + (label, labelname) + label._alt_names, + label.type, ) - return label.element._compiler_dispatch( - self, within_columns_clause=True, - within_label_clause=True, **kw) + \ - OPERATORS[operators.as_] + \ - self.preparer.format_label(label, labelname) + return ( + label.element._compiler_dispatch( + self, + within_columns_clause=True, + within_label_clause=True, + **kw + ) + + OPERATORS[operators.as_] + + self.preparer.format_label(label, labelname) + ) elif render_label_only: return self.preparer.format_label(label, labelname) else: return label.element._compiler_dispatch( - self, within_columns_clause=False, **kw) + self, within_columns_clause=False, **kw + ) def _fallback_column_name(self, column): - raise exc.CompileError("Cannot compile Column object until " - "its 'name' is assigned.") + raise exc.CompileError( + "Cannot compile Column object until " "its 'name' is assigned." + ) - def visit_column(self, column, add_to_result_map=None, - include_table=True, **kwargs): + def visit_column( + self, column, add_to_result_map=None, include_table=True, **kwargs + ): name = orig_name = column.name if name is None: name = self._fallback_column_name(column) @@ -704,10 +833,7 @@ class SQLCompiler(Compiled): if add_to_result_map is not None: add_to_result_map( - name, - orig_name, - (column, name, column.key), - column.type + name, orig_name, (column, name, column.key), column.type ) if is_literal: @@ -721,17 +847,16 @@ class SQLCompiler(Compiled): effective_schema = self.preparer.schema_for_object(table) if effective_schema: - schema_prefix = self.preparer.quote_schema( - effective_schema) + '.' + schema_prefix = ( + self.preparer.quote_schema(effective_schema) + "." + ) else: - schema_prefix = '' + schema_prefix = "" tablename = table.name if isinstance(tablename, elements._truncated_label): tablename = self._truncated_identifier("alias", tablename) - return schema_prefix + \ - self.preparer.quote(tablename) + \ - "." + name + return schema_prefix + self.preparer.quote(tablename) + "." + name def visit_collation(self, element, **kw): return self.preparer.format_collation(element.collation) @@ -743,17 +868,17 @@ class SQLCompiler(Compiled): return index.name def visit_typeclause(self, typeclause, **kw): - kw['type_expression'] = typeclause + kw["type_expression"] = typeclause return self.dialect.type_compiler.process(typeclause.type, **kw) def post_process_text(self, text): if self.preparer._double_percents: - text = text.replace('%', '%%') + text = text.replace("%", "%%") return text def escape_literal_column(self, text): if self.preparer._double_percents: - text = text.replace('%', '%%') + text = text.replace("%", "%%") return text def visit_textclause(self, textclause, **kw): @@ -771,30 +896,36 @@ class SQLCompiler(Compiled): return BIND_PARAMS_ESC.sub( lambda m: m.group(1), BIND_PARAMS.sub( - do_bindparam, - self.post_process_text(textclause.text)) + do_bindparam, self.post_process_text(textclause.text) + ), ) - def visit_text_as_from(self, taf, - compound_index=None, - asfrom=False, - parens=True, **kw): + def visit_text_as_from( + self, taf, compound_index=None, asfrom=False, parens=True, **kw + ): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - populate_result_map = toplevel or \ - ( - compound_index == 0 and entry.get( - 'need_result_map_for_compound', False) - ) or entry.get('need_result_map_for_nested', False) + populate_result_map = ( + toplevel + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) if populate_result_map: - self._ordered_columns = \ - self._textual_ordered_columns = taf.positional + self._ordered_columns = ( + self._textual_ordered_columns + ) = taf.positional for c in taf.column_args: - self.process(c, within_columns_clause=True, - add_to_result_map=self._add_to_result_map) + self.process( + c, + within_columns_clause=True, + add_to_result_map=self._add_to_result_map, + ) text = self.process(taf.element, **kw) if asfrom and parens: @@ -802,17 +933,17 @@ class SQLCompiler(Compiled): return text def visit_null(self, expr, **kw): - return 'NULL' + return "NULL" def visit_true(self, expr, **kw): if self.dialect.supports_native_boolean: - return 'true' + return "true" else: return "1" def visit_false(self, expr, **kw): if self.dialect.supports_native_boolean: - return 'false' + return "false" else: return "0" @@ -823,25 +954,29 @@ class SQLCompiler(Compiled): else: sep = OPERATORS[clauselist.operator] return sep.join( - s for s in - ( - c._compiler_dispatch(self, **kw) - for c in clauselist.clauses) - if s) + s + for s in ( + c._compiler_dispatch(self, **kw) for c in clauselist.clauses + ) + if s + ) def visit_case(self, clause, **kwargs): x = "CASE " if clause.value is not None: x += clause.value._compiler_dispatch(self, **kwargs) + " " for cond, result in clause.whens: - x += "WHEN " + cond._compiler_dispatch( - self, **kwargs - ) + " THEN " + result._compiler_dispatch( - self, **kwargs) + " " + x += ( + "WHEN " + + cond._compiler_dispatch(self, **kwargs) + + " THEN " + + result._compiler_dispatch(self, **kwargs) + + " " + ) if clause.else_ is not None: - x += "ELSE " + clause.else_._compiler_dispatch( - self, **kwargs - ) + " " + x += ( + "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " " + ) x += "END" return x @@ -849,79 +984,84 @@ class SQLCompiler(Compiled): return type_coerce.typed_expression._compiler_dispatch(self, **kw) def visit_cast(self, cast, **kwargs): - return "CAST(%s AS %s)" % \ - (cast.clause._compiler_dispatch(self, **kwargs), - cast.typeclause._compiler_dispatch(self, **kwargs)) + return "CAST(%s AS %s)" % ( + cast.clause._compiler_dispatch(self, **kwargs), + cast.typeclause._compiler_dispatch(self, **kwargs), + ) def _format_frame_clause(self, range_, **kw): - return '%s AND %s' % ( + return "%s AND %s" % ( "UNBOUNDED PRECEDING" if range_[0] is elements.RANGE_UNBOUNDED - else "CURRENT ROW" if range_[0] is elements.RANGE_CURRENT - else "%s PRECEDING" % ( - self.process(elements.literal(abs(range_[0])), **kw), ) + else "CURRENT ROW" + if range_[0] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[0])), **kw),) if range_[0] < 0 - else "%s FOLLOWING" % ( - self.process(elements.literal(range_[0]), **kw), ), - + else "%s FOLLOWING" + % (self.process(elements.literal(range_[0]), **kw),), "UNBOUNDED FOLLOWING" if range_[1] is elements.RANGE_UNBOUNDED - else "CURRENT ROW" if range_[1] is elements.RANGE_CURRENT - else "%s PRECEDING" % ( - self.process(elements.literal(abs(range_[1])), **kw), ) + else "CURRENT ROW" + if range_[1] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[1])), **kw),) if range_[1] < 0 - else "%s FOLLOWING" % ( - self.process(elements.literal(range_[1]), **kw), ), + else "%s FOLLOWING" + % (self.process(elements.literal(range_[1]), **kw),), ) def visit_over(self, over, **kwargs): if over.range_: range_ = "RANGE BETWEEN %s" % self._format_frame_clause( - over.range_, **kwargs) + over.range_, **kwargs + ) elif over.rows: range_ = "ROWS BETWEEN %s" % self._format_frame_clause( - over.rows, **kwargs) + over.rows, **kwargs + ) else: range_ = None return "%s OVER (%s)" % ( over.element._compiler_dispatch(self, **kwargs), - ' '.join([ - '%s BY %s' % ( - word, clause._compiler_dispatch(self, **kwargs) - ) - for word, clause in ( - ('PARTITION', over.partition_by), - ('ORDER', over.order_by) - ) - if clause is not None and len(clause) - ] + ([range_] if range_ else []) - ) + " ".join( + [ + "%s BY %s" + % (word, clause._compiler_dispatch(self, **kwargs)) + for word, clause in ( + ("PARTITION", over.partition_by), + ("ORDER", over.order_by), + ) + if clause is not None and len(clause) + ] + + ([range_] if range_ else []) + ), ) def visit_withingroup(self, withingroup, **kwargs): return "%s WITHIN GROUP (ORDER BY %s)" % ( withingroup.element._compiler_dispatch(self, **kwargs), - withingroup.order_by._compiler_dispatch(self, **kwargs) + withingroup.order_by._compiler_dispatch(self, **kwargs), ) def visit_funcfilter(self, funcfilter, **kwargs): return "%s FILTER (WHERE %s)" % ( funcfilter.func._compiler_dispatch(self, **kwargs), - funcfilter.criterion._compiler_dispatch(self, **kwargs) + funcfilter.criterion._compiler_dispatch(self, **kwargs), ) def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) return "EXTRACT(%s FROM %s)" % ( - field, extract.expr._compiler_dispatch(self, **kwargs)) + field, + extract.expr._compiler_dispatch(self, **kwargs), + ) def visit_function(self, func, add_to_result_map=None, **kwargs): if add_to_result_map is not None: - add_to_result_map( - func.name, func.name, (), func.type - ) + add_to_result_map(func.name, func.name, (), func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) if disp: @@ -933,51 +1073,63 @@ class SQLCompiler(Compiled): name += "%(expr)s" else: name = func.name + "%(expr)s" - return ".".join(list(func.packagenames) + [name]) % \ - {'expr': self.function_argspec(func, **kwargs)} + return ".".join(list(func.packagenames) + [name]) % { + "expr": self.function_argspec(func, **kwargs) + } def visit_next_value_func(self, next_value, **kw): return self.visit_sequence(next_value.sequence) def visit_sequence(self, sequence, **kw): raise NotImplementedError( - "Dialect '%s' does not support sequence increments." % - self.dialect.name + "Dialect '%s' does not support sequence increments." + % self.dialect.name ) def function_argspec(self, func, **kwargs): return func.clause_expr._compiler_dispatch(self, **kwargs) - def visit_compound_select(self, cs, asfrom=False, - parens=True, compound_index=0, **kwargs): + def visit_compound_select( + self, cs, asfrom=False, parens=True, compound_index=0, **kwargs + ): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - need_result_map = toplevel or \ - (compound_index == 0 - and entry.get('need_result_map_for_compound', False)) + need_result_map = toplevel or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) self.stack.append( { - 'correlate_froms': entry['correlate_froms'], - 'asfrom_froms': entry['asfrom_froms'], - 'selectable': cs, - 'need_result_map_for_compound': need_result_map - }) + "correlate_froms": entry["correlate_froms"], + "asfrom_froms": entry["asfrom_froms"], + "selectable": cs, + "need_result_map_for_compound": need_result_map, + } + ) keyword = self.compound_keywords.get(cs.keyword) text = (" " + keyword + " ").join( - (c._compiler_dispatch(self, - asfrom=asfrom, parens=False, - compound_index=i, **kwargs) - for i, c in enumerate(cs.selects)) + ( + c._compiler_dispatch( + self, + asfrom=asfrom, + parens=False, + compound_index=i, + **kwargs + ) + for i, c in enumerate(cs.selects) + ) ) text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs)) text += self.order_by_clause(cs, **kwargs) - text += (cs._limit_clause is not None - or cs._offset_clause is not None) and \ - self.limit_clause(cs, **kwargs) or "" + text += ( + (cs._limit_clause is not None or cs._offset_clause is not None) + and self.limit_clause(cs, **kwargs) + or "" + ) if self.ctes and toplevel: text = self._render_cte_clause() + text @@ -990,8 +1142,10 @@ class SQLCompiler(Compiled): def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): attrname = "visit_%s_%s%s" % ( - operator_.__name__, qualifier1, - "_" + qualifier2 if qualifier2 else "") + operator_.__name__, + qualifier1, + "_" + qualifier2 if qualifier2 else "", + ) return getattr(self, attrname, None) def visit_unary(self, unary, **kw): @@ -999,51 +1153,63 @@ class SQLCompiler(Compiled): if unary.modifier: raise exc.CompileError( "Unary expression does not support operator " - "and modifier simultaneously") + "and modifier simultaneously" + ) disp = self._get_operator_dispatch( - unary.operator, "unary", "operator") + unary.operator, "unary", "operator" + ) if disp: return disp(unary, unary.operator, **kw) else: return self._generate_generic_unary_operator( - unary, OPERATORS[unary.operator], **kw) + unary, OPERATORS[unary.operator], **kw + ) elif unary.modifier: disp = self._get_operator_dispatch( - unary.modifier, "unary", "modifier") + unary.modifier, "unary", "modifier" + ) if disp: return disp(unary, unary.modifier, **kw) else: return self._generate_generic_unary_modifier( - unary, OPERATORS[unary.modifier], **kw) + unary, OPERATORS[unary.modifier], **kw + ) else: raise exc.CompileError( - "Unary expression has no operator or modifier") + "Unary expression has no operator or modifier" + ) def visit_istrue_unary_operator(self, element, operator, **kw): - if element._is_implicitly_boolean or \ - self.dialect.supports_native_boolean: + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): return self.process(element.element, **kw) else: return "%s = 1" % self.process(element.element, **kw) def visit_isfalse_unary_operator(self, element, operator, **kw): - if element._is_implicitly_boolean or \ - self.dialect.supports_native_boolean: + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): return "NOT %s" % self.process(element.element, **kw) else: return "%s = 0" % self.process(element.element, **kw) def visit_notmatch_op_binary(self, binary, operator, **kw): return "NOT %s" % self.visit_binary( - binary, override_operator=operators.match_op) + binary, override_operator=operators.match_op + ) def _emit_empty_in_warning(self): util.warn( - 'The IN-predicate was invoked with an ' - 'empty sequence. This results in a ' - 'contradiction, which nonetheless can be ' - 'expensive to evaluate. Consider alternative ' - 'strategies for improved performance.') + "The IN-predicate was invoked with an " + "empty sequence. This results in a " + "contradiction, which nonetheless can be " + "expensive to evaluate. Consider alternative " + "strategies for improved performance." + ) def visit_empty_in_op_binary(self, binary, operator, **kw): if self.dialect._use_static_in: @@ -1063,18 +1229,21 @@ class SQLCompiler(Compiled): def visit_empty_set_expr(self, element_types): raise NotImplementedError( - "Dialect '%s' does not support empty set expression." % - self.dialect.name + "Dialect '%s' does not support empty set expression." + % self.dialect.name ) - def visit_binary(self, binary, override_operator=None, - eager_grouping=False, **kw): + def visit_binary( + self, binary, override_operator=None, eager_grouping=False, **kw + ): # don't allow "? = ?" to render - if self.ansi_bind_rules and \ - isinstance(binary.left, elements.BindParameter) and \ - isinstance(binary.right, elements.BindParameter): - kw['literal_binds'] = True + if ( + self.ansi_bind_rules + and isinstance(binary.left, elements.BindParameter) + and isinstance(binary.right, elements.BindParameter) + ): + kw["literal_binds"] = True operator_ = override_operator or binary.operator disp = self._get_operator_dispatch(operator_, "binary", None) @@ -1093,36 +1262,50 @@ class SQLCompiler(Compiled): def visit_mod_binary(self, binary, operator, **kw): if self.preparer._double_percents: - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) else: - return self.process(binary.left, **kw) + " % " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " % " + + self.process(binary.right, **kw) + ) def visit_custom_op_binary(self, element, operator, **kw): - kw['eager_grouping'] = operator.eager_grouping + kw["eager_grouping"] = operator.eager_grouping return self._generate_generic_binary( - element, " " + operator.opstring + " ", **kw) + element, " " + operator.opstring + " ", **kw + ) def visit_custom_op_unary_operator(self, element, operator, **kw): return self._generate_generic_unary_operator( - element, operator.opstring + " ", **kw) + element, operator.opstring + " ", **kw + ) def visit_custom_op_unary_modifier(self, element, operator, **kw): return self._generate_generic_unary_modifier( - element, " " + operator.opstring, **kw) + element, " " + operator.opstring, **kw + ) def _generate_generic_binary( - self, binary, opstring, eager_grouping=False, **kw): + self, binary, opstring, eager_grouping=False, **kw + ): - _in_binary = kw.get('_in_binary', False) + _in_binary = kw.get("_in_binary", False) - kw['_in_binary'] = True - text = binary.left._compiler_dispatch( - self, eager_grouping=eager_grouping, **kw) + \ - opstring + \ - binary.right._compiler_dispatch( - self, eager_grouping=eager_grouping, **kw) + kw["_in_binary"] = True + text = ( + binary.left._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + + opstring + + binary.right._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + ) if _in_binary and eager_grouping: text = "(%s)" % text @@ -1153,17 +1336,13 @@ class SQLCompiler(Compiled): def visit_startswith_op_binary(self, binary, operator, **kw): binary = binary._clone() percent = self._like_percent_literal - binary.right = percent.__radd__( - binary.right - ) + binary.right = percent.__radd__(binary.right) return self.visit_like_op_binary(binary, operator, **kw) def visit_notstartswith_op_binary(self, binary, operator, **kw): binary = binary._clone() percent = self._like_percent_literal - binary.right = percent.__radd__( - binary.right - ) + binary.right = percent.__radd__(binary.right) return self.visit_notlike_op_binary(binary, operator, **kw) def visit_endswith_op_binary(self, binary, operator, **kw): @@ -1182,98 +1361,105 @@ class SQLCompiler(Compiled): escape = binary.modifiers.get("escape", None) # TODO: use ternary here, not "and"/ "or" - return '%s LIKE %s' % ( + return "%s LIKE %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_notlike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return '%s NOT LIKE %s' % ( + return "%s NOT LIKE %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return 'lower(%s) LIKE lower(%s)' % ( + return "lower(%s) LIKE lower(%s)" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return 'lower(%s) NOT LIKE lower(%s)' % ( + return "lower(%s) NOT LIKE lower(%s)" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_between_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) return self._generate_generic_binary( - binary, " BETWEEN SYMMETRIC " - if symmetric else " BETWEEN ", **kw) + binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw + ) def visit_notbetween_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) return self._generate_generic_binary( - binary, " NOT BETWEEN SYMMETRIC " - if symmetric else " NOT BETWEEN ", **kw) + binary, + " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ", + **kw + ) - def visit_bindparam(self, bindparam, within_columns_clause=False, - literal_binds=False, - skip_bind_expression=False, - **kwargs): + def visit_bindparam( + self, + bindparam, + within_columns_clause=False, + literal_binds=False, + skip_bind_expression=False, + **kwargs + ): if not skip_bind_expression: impl = bindparam.type.dialect_impl(self.dialect) if impl._has_bind_expression: bind_expression = impl.bind_expression(bindparam) return self.process( - bind_expression, skip_bind_expression=True, + bind_expression, + skip_bind_expression=True, within_columns_clause=within_columns_clause, literal_binds=literal_binds, **kwargs ) - if literal_binds or \ - (within_columns_clause and - self.ansi_bind_rules): + if literal_binds or (within_columns_clause and self.ansi_bind_rules): if bindparam.value is None and bindparam.callable is None: - raise exc.CompileError("Bind parameter '%s' without a " - "renderable value not allowed here." - % bindparam.key) + raise exc.CompileError( + "Bind parameter '%s' without a " + "renderable value not allowed here." % bindparam.key + ) return self.render_literal_bindparam( - bindparam, within_columns_clause=True, **kwargs) + bindparam, within_columns_clause=True, **kwargs + ) name = self._truncate_bindparam(bindparam) if name in self.binds: existing = self.binds[name] if existing is not bindparam: - if (existing.unique or bindparam.unique) and \ - not existing.proxy_set.intersection( - bindparam.proxy_set): + if ( + existing.unique or bindparam.unique + ) and not existing.proxy_set.intersection(bindparam.proxy_set): raise exc.CompileError( "Bind parameter '%s' conflicts with " - "unique bind parameter of the same name" % - bindparam.key + "unique bind parameter of the same name" + % bindparam.key ) elif existing._is_crud or bindparam._is_crud: raise exc.CompileError( @@ -1282,14 +1468,15 @@ class SQLCompiler(Compiled): "clause of this " "insert/update statement. Please use a " "name other than column name when using bindparam() " - "with insert() or update() (for example, 'b_%s')." % - (bindparam.key, bindparam.key) + "with insert() or update() (for example, 'b_%s')." + % (bindparam.key, bindparam.key) ) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string( - name, expanding=bindparam.expanding, **kwargs) + name, expanding=bindparam.expanding, **kwargs + ) def render_literal_bindparam(self, bindparam, **kw): value = bindparam.effective_value @@ -1311,7 +1498,8 @@ class SQLCompiler(Compiled): return processor(value) else: raise NotImplementedError( - "Don't know how to literal-quote value %r" % value) + "Don't know how to literal-quote value %r" % value + ) def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: @@ -1334,8 +1522,11 @@ class SQLCompiler(Compiled): if len(anonname) > self.label_length - 6: counter = self.truncated_names.get(ident_class, 1) - truncname = anonname[0:max(self.label_length - 6, 0)] + \ - "_" + hex(counter)[2:] + truncname = ( + anonname[0 : max(self.label_length - 6, 0)] + + "_" + + hex(counter)[2:] + ) self.truncated_names[ident_class] = counter + 1 else: truncname = anonname @@ -1346,13 +1537,14 @@ class SQLCompiler(Compiled): return name % self.anon_map def _process_anon(self, key): - (ident, derived) = key.split(' ', 1) + (ident, derived) = key.split(" ", 1) anonymous_counter = self.anon_map.get(derived, 1) self.anon_map[derived] = anonymous_counter + 1 return derived + "_" + str(anonymous_counter) def bindparam_string( - self, name, positional_names=None, expanding=False, **kw): + self, name, positional_names=None, expanding=False, **kw + ): if self.positional: if positional_names is not None: positional_names.append(name) @@ -1362,14 +1554,20 @@ class SQLCompiler(Compiled): self.contains_expanding_parameters = True return "([EXPANDING_%s])" % name else: - return self.bindtemplate % {'name': name} - - def visit_cte(self, cte, asfrom=False, ashint=False, - fromhints=None, visiting_cte=None, - **kwargs): + return self.bindtemplate % {"name": name} + + def visit_cte( + self, + cte, + asfrom=False, + ashint=False, + fromhints=None, + visiting_cte=None, + **kwargs + ): self._init_cte_state() - kwargs['visiting_cte'] = cte + kwargs["visiting_cte"] = cte if isinstance(cte.name, elements._truncated_label): cte_name = self._truncated_identifier("alias", cte.name) else: @@ -1394,8 +1592,8 @@ class SQLCompiler(Compiled): else: raise exc.CompileError( "Multiple, unrelated CTEs found with " - "the same name: %r" % - cte_name) + "the same name: %r" % cte_name + ) if asfrom or is_new_cte: if cte._cte_alias is not None: @@ -1403,7 +1601,8 @@ class SQLCompiler(Compiled): cte_pre_alias_name = cte._cte_alias.name if isinstance(cte_pre_alias_name, elements._truncated_label): cte_pre_alias_name = self._truncated_identifier( - "alias", cte_pre_alias_name) + "alias", cte_pre_alias_name + ) else: pre_alias_cte = cte cte_pre_alias_name = None @@ -1412,11 +1611,17 @@ class SQLCompiler(Compiled): self.ctes_by_name[cte_name] = cte # look for embedded DML ctes and propagate autocommit - if 'autocommit' in cte.element._execution_options and \ - 'autocommit' not in self.execution_options: + if ( + "autocommit" in cte.element._execution_options + and "autocommit" not in self.execution_options + ): self.execution_options = self.execution_options.union( - {"autocommit": - cte.element._execution_options['autocommit']}) + { + "autocommit": cte.element._execution_options[ + "autocommit" + ] + } + ) if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -1432,25 +1637,30 @@ class SQLCompiler(Compiled): col_source = cte.original.selects[0] else: assert False - recur_cols = [c for c in - util.unique_list(col_source.inner_columns) - if c is not None] - - text += "(%s)" % (", ".join( - self.preparer.format_column(ident) - for ident in recur_cols)) + recur_cols = [ + c + for c in util.unique_list(col_source.inner_columns) + if c is not None + ] + + text += "(%s)" % ( + ", ".join( + self.preparer.format_column(ident) + for ident in recur_cols + ) + ) if self.positional: - kwargs['positional_names'] = self.cte_positional[cte] = [] + kwargs["positional_names"] = self.cte_positional[cte] = [] - text += " AS \n" + \ - cte.original._compiler_dispatch( - self, asfrom=True, **kwargs - ) + text += " AS \n" + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) if cte._suffixes: text += " " + self._generate_prefixes( - cte, cte._suffixes, **kwargs) + cte, cte._suffixes, **kwargs + ) self.ctes[cte] = text @@ -1467,9 +1677,15 @@ class SQLCompiler(Compiled): else: return self.preparer.format_alias(cte, cte_name) - def visit_alias(self, alias, asfrom=False, ashint=False, - iscrud=False, - fromhints=None, **kwargs): + def visit_alias( + self, + alias, + asfrom=False, + ashint=False, + iscrud=False, + fromhints=None, + **kwargs + ): if asfrom or ashint: if isinstance(alias.name, elements._truncated_label): alias_name = self._truncated_identifier("alias", alias.name) @@ -1479,31 +1695,35 @@ class SQLCompiler(Compiled): if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: - ret = alias.original._compiler_dispatch(self, - asfrom=True, **kwargs) + \ - self.get_render_as_alias_suffix( - self.preparer.format_alias(alias, alias_name)) + ret = alias.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name) + ) if fromhints and alias in fromhints: - ret = self.format_from_hint_text(ret, alias, - fromhints[alias], iscrud) + ret = self.format_from_hint_text( + ret, alias, fromhints[alias], iscrud + ) return ret else: return alias.original._compiler_dispatch(self, **kwargs) def visit_lateral(self, lateral, **kw): - kw['lateral'] = True + kw["lateral"] = True return "LATERAL %s" % self.visit_alias(lateral, **kw) def visit_tablesample(self, tablesample, asfrom=False, **kw): text = "%s TABLESAMPLE %s" % ( self.visit_alias(tablesample, asfrom=True, **kw), - tablesample._get_method()._compiler_dispatch(self, **kw)) + tablesample._get_method()._compiler_dispatch(self, **kw), + ) if tablesample.seed is not None: text += " REPEATABLE (%s)" % ( - tablesample.seed._compiler_dispatch(self, **kw)) + tablesample.seed._compiler_dispatch(self, **kw) + ) return text @@ -1513,22 +1733,27 @@ class SQLCompiler(Compiled): def _add_to_result_map(self, keyname, name, objects, type_): self._result_columns.append((keyname, name, objects, type_)) - def _label_select_column(self, select, column, - populate_result_map, - asfrom, column_clause_args, - name=None, - within_columns_clause=True): + def _label_select_column( + self, + select, + column, + populate_result_map, + asfrom, + column_clause_args, + name=None, + within_columns_clause=True, + ): """produce labeled columns present in a select().""" impl = column.type.dialect_impl(self.dialect) - if impl._has_column_expression and \ - populate_result_map: + if impl._has_column_expression and populate_result_map: col_expr = impl.column_expression(column) def add_to_result_map(keyname, name, objects, type_): self._add_to_result_map( - keyname, name, - (column,) + objects, type_) + keyname, name, (column,) + objects, type_ + ) + else: col_expr = column if populate_result_map: @@ -1541,58 +1766,56 @@ class SQLCompiler(Compiled): elif isinstance(column, elements.Label): if col_expr is not column: result_expr = _CompileLabel( - col_expr, - column.name, - alt_names=(column.element,) + col_expr, column.name, alt_names=(column.element,) ) else: result_expr = col_expr elif select is not None and name: result_expr = _CompileLabel( + col_expr, name, alt_names=(column._key_label,) + ) + + elif ( + asfrom + and isinstance(column, elements.ColumnClause) + and not column.is_literal + and column.table is not None + and not isinstance(column.table, selectable.Select) + ): + result_expr = _CompileLabel( col_expr, - name, - alt_names=(column._key_label,) - ) - - elif \ - asfrom and \ - isinstance(column, elements.ColumnClause) and \ - not column.is_literal and \ - column.table is not None and \ - not isinstance(column.table, selectable.Select): - result_expr = _CompileLabel(col_expr, - elements._as_truncated(column.name), - alt_names=(column.key,)) + elements._as_truncated(column.name), + alt_names=(column.key,), + ) elif ( - not isinstance(column, elements.TextClause) and - ( - not isinstance(column, elements.UnaryExpression) or - column.wraps_column_expression - ) and - ( - not hasattr(column, 'name') or - isinstance(column, functions.Function) + not isinstance(column, elements.TextClause) + and ( + not isinstance(column, elements.UnaryExpression) + or column.wraps_column_expression + ) + and ( + not hasattr(column, "name") + or isinstance(column, functions.Function) ) ): result_expr = _CompileLabel(col_expr, column.anon_label) elif col_expr is not column: # TODO: are we sure "column" has a .name and .key here ? # assert isinstance(column, elements.ColumnClause) - result_expr = _CompileLabel(col_expr, - elements._as_truncated(column.name), - alt_names=(column.key,)) + result_expr = _CompileLabel( + col_expr, + elements._as_truncated(column.name), + alt_names=(column.key,), + ) else: result_expr = col_expr column_clause_args.update( within_columns_clause=within_columns_clause, - add_to_result_map=add_to_result_map - ) - return result_expr._compiler_dispatch( - self, - **column_clause_args + add_to_result_map=add_to_result_map, ) + return result_expr._compiler_dispatch(self, **column_clause_args) def format_from_hint_text(self, sqltext, table, hint, iscrud): hinttext = self.get_from_hint_text(table, hint) @@ -1631,8 +1854,11 @@ class SQLCompiler(Compiled): newelem = cloned[element] = element._clone() - if newelem.is_selectable and newelem._is_join and \ - isinstance(newelem.right, selectable.FromGrouping): + if ( + newelem.is_selectable + and newelem._is_join + and isinstance(newelem.right, selectable.FromGrouping) + ): newelem._reset_exported() newelem.left = visit(newelem.left, **kw) @@ -1640,8 +1866,8 @@ class SQLCompiler(Compiled): right = visit(newelem.right, **kw) selectable_ = selectable.Select( - [right.element], - use_labels=True).alias() + [right.element], use_labels=True + ).alias() for c in selectable_.c: c._key_label = c.key @@ -1680,17 +1906,18 @@ class SQLCompiler(Compiled): elif newelem._is_from_container: # if we hit an Alias, CompoundSelect or ScalarSelect, put a # marker in the stack. - kw['transform_clue'] = 'select_container' + kw["transform_clue"] = "select_container" newelem._copy_internals(clone=visit, **kw) elif newelem.is_selectable and newelem._is_select: - barrier_select = kw.get('transform_clue', None) == \ - 'select_container' + barrier_select = ( + kw.get("transform_clue", None) == "select_container" + ) # if we're still descended from an # Alias/CompoundSelect/ScalarSelect, we're # in a FROM clause, so start with a new translate collection if barrier_select: column_translate.append({}) - kw['transform_clue'] = 'inside_select' + kw["transform_clue"] = "inside_select" newelem._copy_internals(clone=visit, **kw) if barrier_select: del column_translate[-1] @@ -1702,24 +1929,22 @@ class SQLCompiler(Compiled): return visit(select) def _transform_result_map_for_nested_joins( - self, select, transformed_select): - inner_col = dict((c._key_label, c) for - c in transformed_select.inner_columns) - - d = dict( - (inner_col[c._key_label], c) - for c in select.inner_columns + self, select, transformed_select + ): + inner_col = dict( + (c._key_label, c) for c in transformed_select.inner_columns ) + d = dict((inner_col[c._key_label], c) for c in select.inner_columns) + self._result_columns = [ (key, name, tuple([d.get(col, col) for col in objs]), typ) for key, name, objs, typ in self._result_columns ] - _default_stack_entry = util.immutabledict([ - ('correlate_froms', frozenset()), - ('asfrom_froms', frozenset()) - ]) + _default_stack_entry = util.immutabledict( + [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] + ) def _display_froms_for_select(self, select, asfrom, lateral=False): # utility method to help external dialects @@ -1729,72 +1954,88 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - correlate_froms = entry['correlate_froms'] - asfrom_froms = entry['asfrom_froms'] + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] if asfrom and not lateral: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms.difference( - asfrom_froms), - implicit_correlate_froms=()) + asfrom_froms + ), + implicit_correlate_froms=(), + ) else: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms, - implicit_correlate_froms=asfrom_froms) + implicit_correlate_froms=asfrom_froms, + ) return froms - def visit_select(self, select, asfrom=False, parens=True, - fromhints=None, - compound_index=0, - nested_join_translation=False, - select_wraps_for=None, - lateral=False, - **kwargs): - - needs_nested_translation = \ - select.use_labels and \ - not nested_join_translation and \ - not self.stack and \ - not self.dialect.supports_right_nested_joins + def visit_select( + self, + select, + asfrom=False, + parens=True, + fromhints=None, + compound_index=0, + nested_join_translation=False, + select_wraps_for=None, + lateral=False, + **kwargs + ): + + needs_nested_translation = ( + select.use_labels + and not nested_join_translation + and not self.stack + and not self.dialect.supports_right_nested_joins + ) if needs_nested_translation: transformed_select = self._transform_select_for_nested_joins( - select) + select + ) text = self.visit_select( - transformed_select, asfrom=asfrom, parens=parens, + transformed_select, + asfrom=asfrom, + parens=parens, fromhints=fromhints, compound_index=compound_index, - nested_join_translation=True, **kwargs + nested_join_translation=True, + **kwargs ) toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - populate_result_map = toplevel or \ - ( - compound_index == 0 and entry.get( - 'need_result_map_for_compound', False) - ) or entry.get('need_result_map_for_nested', False) + populate_result_map = ( + toplevel + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) # this was first proposed as part of #3372; however, it is not # reached in current tests and could possibly be an assertion # instead. - if not populate_result_map and 'add_to_result_map' in kwargs: - del kwargs['add_to_result_map'] + if not populate_result_map and "add_to_result_map" in kwargs: + del kwargs["add_to_result_map"] if needs_nested_translation: if populate_result_map: self._transform_result_map_for_nested_joins( - select, transformed_select) + select, transformed_select + ) return text froms = self._setup_select_stack(select, entry, asfrom, lateral) column_clause_args = kwargs.copy() - column_clause_args.update({ - 'within_label_clause': False, - 'within_columns_clause': False - }) + column_clause_args.update( + {"within_label_clause": False, "within_columns_clause": False} + ) text = "SELECT " # we're off to a good start ! @@ -1806,19 +2047,21 @@ class SQLCompiler(Compiled): byfrom = None if select._prefixes: - text += self._generate_prefixes( - select, select._prefixes, **kwargs) + text += self._generate_prefixes(select, select._prefixes, **kwargs) text += self.get_select_precolumns(select, **kwargs) # the actual list of columns to print in the SELECT column list. inner_columns = [ - c for c in [ + c + for c in [ self._label_select_column( select, column, - populate_result_map, asfrom, + populate_result_map, + asfrom, column_clause_args, - name=name) + name=name, + ) for name, column in select._columns_plus_names ] if c is not None @@ -1831,8 +2074,11 @@ class SQLCompiler(Compiled): translate = dict( zip( [name for (key, name) in select._columns_plus_names], - [name for (key, name) in - select_wraps_for._columns_plus_names]) + [ + name + for (key, name) in select_wraps_for._columns_plus_names + ], + ) ) self._result_columns = [ @@ -1841,13 +2087,14 @@ class SQLCompiler(Compiled): ] text = self._compose_select_body( - text, select, inner_columns, froms, byfrom, kwargs) + text, select, inner_columns, froms, byfrom, kwargs + ) if select._statement_hints: per_dialect = [ - ht for (dialect_name, ht) - in select._statement_hints - if dialect_name in ('*', self.dialect.name) + ht + for (dialect_name, ht) in select._statement_hints + if dialect_name in ("*", self.dialect.name) ] if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) @@ -1857,7 +2104,8 @@ class SQLCompiler(Compiled): if select._suffixes: text += " " + self._generate_prefixes( - select, select._suffixes, **kwargs) + select, select._suffixes, **kwargs + ) self.stack.pop(-1) @@ -1867,60 +2115,73 @@ class SQLCompiler(Compiled): return text def _setup_select_hints(self, select): - byfrom = dict([ - (from_, hinttext % { - 'name': from_._compiler_dispatch( - self, ashint=True) - }) - for (from_, dialect), hinttext in - select._hints.items() - if dialect in ('*', self.dialect.name) - ]) + byfrom = dict( + [ + ( + from_, + hinttext + % {"name": from_._compiler_dispatch(self, ashint=True)}, + ) + for (from_, dialect), hinttext in select._hints.items() + if dialect in ("*", self.dialect.name) + ] + ) hint_text = self.get_select_hint_text(byfrom) return hint_text, byfrom def _setup_select_stack(self, select, entry, asfrom, lateral): - correlate_froms = entry['correlate_froms'] - asfrom_froms = entry['asfrom_froms'] + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] if asfrom and not lateral: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms.difference( - asfrom_froms), - implicit_correlate_froms=()) + asfrom_froms + ), + implicit_correlate_froms=(), + ) else: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms, - implicit_correlate_froms=asfrom_froms) + implicit_correlate_froms=asfrom_froms, + ) new_correlate_froms = set(selectable._from_objects(*froms)) all_correlate_froms = new_correlate_froms.union(correlate_froms) new_entry = { - 'asfrom_froms': new_correlate_froms, - 'correlate_froms': all_correlate_froms, - 'selectable': select, + "asfrom_froms": new_correlate_froms, + "correlate_froms": all_correlate_froms, + "selectable": select, } self.stack.append(new_entry) return froms def _compose_select_body( - self, text, select, inner_columns, froms, byfrom, kwargs): - text += ', '.join(inner_columns) + self, text, select, inner_columns, froms, byfrom, kwargs + ): + text += ", ".join(inner_columns) if froms: text += " \nFROM " if select._hints: - text += ', '.join( - [f._compiler_dispatch(self, asfrom=True, - fromhints=byfrom, **kwargs) - for f in froms]) + text += ", ".join( + [ + f._compiler_dispatch( + self, asfrom=True, fromhints=byfrom, **kwargs + ) + for f in froms + ] + ) else: - text += ', '.join( - [f._compiler_dispatch(self, asfrom=True, **kwargs) - for f in froms]) + text += ", ".join( + [ + f._compiler_dispatch(self, asfrom=True, **kwargs) + for f in froms + ] + ) else: text += self.default_from() @@ -1940,8 +2201,10 @@ class SQLCompiler(Compiled): if select._order_by_clause.clauses: text += self.order_by_clause(select, **kwargs) - if (select._limit_clause is not None or - select._offset_clause is not None): + if ( + select._limit_clause is not None + or select._offset_clause is not None + ): text += self.limit_clause(select, **kwargs) if select._for_update_arg is not None: @@ -1953,8 +2216,7 @@ class SQLCompiler(Compiled): clause = " ".join( prefix._compiler_dispatch(self, **kw) for prefix, dialect_name in prefixes - if dialect_name is None or - dialect_name == self.dialect.name + if dialect_name is None or dialect_name == self.dialect.name ) if clause: clause += " " @@ -1962,14 +2224,12 @@ class SQLCompiler(Compiled): def _render_cte_clause(self): if self.positional: - self.positiontup = sum([ - self.cte_positional[cte] - for cte in self.ctes], []) + \ - self.positiontup + self.positiontup = ( + sum([self.cte_positional[cte] for cte in self.ctes], []) + + self.positiontup + ) cte_text = self.get_cte_preamble(self.ctes_recursive) + " " - cte_text += ", \n".join( - [txt for txt in self.ctes.values()] - ) + cte_text += ", \n".join([txt for txt in self.ctes.values()]) cte_text += "\n " return cte_text @@ -2010,7 +2270,8 @@ class SQLCompiler(Compiled): def returning_clause(self, stmt, returning_cols): raise exc.CompileError( "RETURNING is not supported by this " - "dialect's statement compiler.") + "dialect's statement compiler." + ) def limit_clause(self, select, **kw): text = "" @@ -2022,19 +2283,31 @@ class SQLCompiler(Compiled): text += " OFFSET " + self.process(select._offset_clause, **kw) return text - def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, - fromhints=None, use_schema=True, **kwargs): + def visit_table( + self, + table, + asfrom=False, + iscrud=False, + ashint=False, + fromhints=None, + use_schema=True, + **kwargs + ): if asfrom or ashint: effective_schema = self.preparer.schema_for_object(table) if use_schema and effective_schema: - ret = self.preparer.quote_schema(effective_schema) + \ - "." + self.preparer.quote(table.name) + ret = ( + self.preparer.quote_schema(effective_schema) + + "." + + self.preparer.quote(table.name) + ) else: ret = self.preparer.quote(table.name) if fromhints and table in fromhints: - ret = self.format_from_hint_text(ret, table, - fromhints[table], iscrud) + ret = self.format_from_hint_text( + ret, table, fromhints[table], iscrud + ) return ret else: return "" @@ -2047,26 +2320,24 @@ class SQLCompiler(Compiled): else: join_type = " JOIN " return ( - join.left._compiler_dispatch(self, asfrom=True, **kwargs) + - join_type + - join.right._compiler_dispatch(self, asfrom=True, **kwargs) + - " ON " + - join.onclause._compiler_dispatch(self, **kwargs) + join.left._compiler_dispatch(self, asfrom=True, **kwargs) + + join_type + + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + + " ON " + + join.onclause._compiler_dispatch(self, **kwargs) ) def _setup_crud_hints(self, stmt, table_text): - dialect_hints = dict([ - (table, hint_text) - for (table, dialect), hint_text in - stmt._hints.items() - if dialect in ('*', self.dialect.name) - ]) + dialect_hints = dict( + [ + (table, hint_text) + for (table, dialect), hint_text in stmt._hints.items() + if dialect in ("*", self.dialect.name) + ] + ) if stmt.table in dialect_hints: table_text = self.format_from_hint_text( - table_text, - stmt.table, - dialect_hints[stmt.table], - True + table_text, stmt.table, dialect_hints[stmt.table], True ) return dialect_hints, table_text @@ -2074,28 +2345,35 @@ class SQLCompiler(Compiled): toplevel = not self.stack self.stack.append( - {'correlate_froms': set(), - "asfrom_froms": set(), - "selectable": insert_stmt}) + { + "correlate_froms": set(), + "asfrom_froms": set(), + "selectable": insert_stmt, + } + ) crud_params = crud._setup_crud_params( - self, insert_stmt, crud.ISINSERT, **kw) + self, insert_stmt, crud.ISINSERT, **kw + ) - if not crud_params and \ - not self.dialect.supports_default_values and \ - not self.dialect.supports_empty_insert: - raise exc.CompileError("The '%s' dialect with current database " - "version settings does not support empty " - "inserts." % - self.dialect.name) + if ( + not crud_params + and not self.dialect.supports_default_values + and not self.dialect.supports_empty_insert + ): + raise exc.CompileError( + "The '%s' dialect with current database " + "version settings does not support empty " + "inserts." % self.dialect.name + ) if insert_stmt._has_multi_parameters: if not self.dialect.supports_multivalues_insert: raise exc.CompileError( "The '%s' dialect with current database " "version settings does not support " - "in-place multirow inserts." % - self.dialect.name) + "in-place multirow inserts." % self.dialect.name + ) crud_params_single = crud_params[0] else: crud_params_single = crud_params @@ -2106,27 +2384,31 @@ class SQLCompiler(Compiled): text = "INSERT " if insert_stmt._prefixes: - text += self._generate_prefixes(insert_stmt, - insert_stmt._prefixes, **kw) + text += self._generate_prefixes( + insert_stmt, insert_stmt._prefixes, **kw + ) text += "INTO " table_text = preparer.format_table(insert_stmt.table) if insert_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( - insert_stmt, table_text) + insert_stmt, table_text + ) else: dialect_hints = None text += table_text if crud_params_single or not supports_default_values: - text += " (%s)" % ', '.join([preparer.format_column(c[0]) - for c in crud_params_single]) + text += " (%s)" % ", ".join( + [preparer.format_column(c[0]) for c in crud_params_single] + ) if self.returning or insert_stmt._returning: returning_clause = self.returning_clause( - insert_stmt, self.returning or insert_stmt._returning) + insert_stmt, self.returning or insert_stmt._returning + ) if self.returning_precedes_values: text += " " + returning_clause @@ -2145,19 +2427,17 @@ class SQLCompiler(Compiled): elif insert_stmt._has_multi_parameters: text += " VALUES %s" % ( ", ".join( - "(%s)" % ( - ', '.join(c[1] for c in crud_param_set) - ) + "(%s)" % (", ".join(c[1] for c in crud_param_set)) for crud_param_set in crud_params ) ) else: - text += " VALUES (%s)" % \ - ', '.join([c[1] for c in crud_params]) + text += " VALUES (%s)" % ", ".join([c[1] for c in crud_params]) if insert_stmt._post_values_clause is not None: post_values_clause = self.process( - insert_stmt._post_values_clause, **kw) + insert_stmt._post_values_clause, **kw + ) if post_values_clause: text += " " + post_values_clause @@ -2178,21 +2458,19 @@ class SQLCompiler(Compiled): """Provide a hook for MySQL to add LIMIT to the UPDATE""" return None - def update_tables_clause(self, update_stmt, from_table, - extra_froms, **kw): + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): """Provide a hook to override the initial table clause in an UPDATE statement. MySQL overrides this. """ - kw['asfrom'] = True + kw["asfrom"] = True return from_table._compiler_dispatch(self, iscrud=True, **kw) - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): """Provide a hook to override the generation of an UPDATE..FROM clause. @@ -2201,7 +2479,8 @@ class SQLCompiler(Compiled): """ raise NotImplementedError( "This backend does not support multiple-table " - "criteria within UPDATE") + "criteria within UPDATE" + ) def visit_update(self, update_stmt, asfrom=False, **kw): toplevel = not self.stack @@ -2221,49 +2500,61 @@ class SQLCompiler(Compiled): correlate_froms = {update_stmt.table} self.stack.append( - {'correlate_froms': correlate_froms, - "asfrom_froms": correlate_froms, - "selectable": update_stmt}) + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": update_stmt, + } + ) text = "UPDATE " if update_stmt._prefixes: - text += self._generate_prefixes(update_stmt, - update_stmt._prefixes, **kw) + text += self._generate_prefixes( + update_stmt, update_stmt._prefixes, **kw + ) - table_text = self.update_tables_clause(update_stmt, update_stmt.table, - render_extra_froms, **kw) + table_text = self.update_tables_clause( + update_stmt, update_stmt.table, render_extra_froms, **kw + ) crud_params = crud._setup_crud_params( - self, update_stmt, crud.ISUPDATE, **kw) + self, update_stmt, crud.ISUPDATE, **kw + ) if update_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( - update_stmt, table_text) + update_stmt, table_text + ) else: dialect_hints = None text += table_text - text += ' SET ' - include_table = is_multitable and \ - self.render_table_with_column_in_update_from - text += ', '.join( - c[0]._compiler_dispatch(self, - include_table=include_table) + - '=' + c[1] for c in crud_params + text += " SET " + include_table = ( + is_multitable and self.render_table_with_column_in_update_from + ) + text += ", ".join( + c[0]._compiler_dispatch(self, include_table=include_table) + + "=" + + c[1] + for c in crud_params ) if self.returning or update_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning or update_stmt._returning) + update_stmt, self.returning or update_stmt._returning + ) if extra_froms: extra_from_text = self.update_from_clause( update_stmt, update_stmt.table, render_extra_froms, - dialect_hints, **kw) + dialect_hints, + **kw + ) if extra_from_text: text += " " + extra_from_text @@ -2276,10 +2567,12 @@ class SQLCompiler(Compiled): if limit_clause: text += " " + limit_clause - if (self.returning or update_stmt._returning) and \ - not self.returning_precedes_values: + if ( + self.returning or update_stmt._returning + ) and not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning or update_stmt._returning) + update_stmt, self.returning or update_stmt._returning + ) if self.ctes and toplevel: text = self._render_cte_clause() + text @@ -2295,9 +2588,9 @@ class SQLCompiler(Compiled): def _key_getters_for_crud_column(self): return crud._key_getters_for_crud_column(self, self.statement) - def delete_extra_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, **kw): + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): """Provide a hook to override the generation of an DELETE..FROM clause. @@ -2308,10 +2601,10 @@ class SQLCompiler(Compiled): """ raise NotImplementedError( "This backend does not support multiple-table " - "criteria within DELETE") + "criteria within DELETE" + ) - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) def visit_delete(self, delete_stmt, asfrom=False, **kw): @@ -2322,23 +2615,30 @@ class SQLCompiler(Compiled): extra_froms = delete_stmt._extra_froms correlate_froms = {delete_stmt.table}.union(extra_froms) - self.stack.append({'correlate_froms': correlate_froms, - "asfrom_froms": correlate_froms, - "selectable": delete_stmt}) + self.stack.append( + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": delete_stmt, + } + ) text = "DELETE " if delete_stmt._prefixes: - text += self._generate_prefixes(delete_stmt, - delete_stmt._prefixes, **kw) + text += self._generate_prefixes( + delete_stmt, delete_stmt._prefixes, **kw + ) text += "FROM " - table_text = self.delete_table_clause(delete_stmt, delete_stmt.table, - extra_froms) + table_text = self.delete_table_clause( + delete_stmt, delete_stmt.table, extra_froms + ) if delete_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( - delete_stmt, table_text) + delete_stmt, table_text + ) else: dialect_hints = None @@ -2347,14 +2647,17 @@ class SQLCompiler(Compiled): if delete_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning) + delete_stmt, delete_stmt._returning + ) if extra_froms: extra_from_text = self.delete_extra_from_clause( delete_stmt, delete_stmt.table, extra_froms, - dialect_hints, **kw) + dialect_hints, + **kw + ) if extra_from_text: text += " " + extra_from_text @@ -2365,7 +2668,8 @@ class SQLCompiler(Compiled): if delete_stmt._returning and not self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning) + delete_stmt, delete_stmt._returning + ) if self.ctes and toplevel: text = self._render_cte_clause() + text @@ -2381,12 +2685,14 @@ class SQLCompiler(Compiled): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def visit_rollback_to_savepoint(self, savepoint_stmt): - return "ROLLBACK TO SAVEPOINT %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) def visit_release_savepoint(self, savepoint_stmt): - return "RELEASE SAVEPOINT %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) class StrSQLCompiler(SQLCompiler): @@ -2403,7 +2709,7 @@ class StrSQLCompiler(SQLCompiler): def visit_getitem_binary(self, binary, operator, **kw): return "%s[%s]" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def visit_json_getitem_op_binary(self, binary, operator, **kw): @@ -2421,29 +2727,26 @@ class StrSQLCompiler(SQLCompiler): for c in elements._select_iterables(returning_cols) ] - return 'RETURNING ' + ', '.join(columns) + return "RETURNING " + ", ".join(columns) - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) - def delete_extra_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): - return ', ' + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + return ", " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) class DDLCompiler(Compiled): - @util.memoized_property def sql_compiler(self): return self.dialect.statement_compiler(self.dialect, None) @@ -2464,13 +2767,13 @@ class DDLCompiler(Compiled): preparer = self.preparer path = preparer.format_table_seq(ddl.target) if len(path) == 1: - table, sch = path[0], '' + table, sch = path[0], "" else: table, sch = path[-1], path[0] - context.setdefault('table', table) - context.setdefault('schema', sch) - context.setdefault('fullname', preparer.format_table(ddl.target)) + context.setdefault("table", table) + context.setdefault("schema", sch) + context.setdefault("fullname", preparer.format_table(ddl.target)) return self.sql_compiler.post_process_text(ddl.statement % context) @@ -2507,9 +2810,9 @@ class DDLCompiler(Compiled): for create_column in create.columns: column = create_column.element try: - processed = self.process(create_column, - first_pk=column.primary_key - and not first_pk) + processed = self.process( + create_column, first_pk=column.primary_key and not first_pk + ) if processed is not None: text += separator separator = ", \n" @@ -2519,13 +2822,15 @@ class DDLCompiler(Compiled): except exc.CompileError as ce: util.raise_from_cause( exc.CompileError( - util.u("(in table '%s', column '%s'): %s") % - (table.description, column.name, ce.args[0]) - )) + util.u("(in table '%s', column '%s'): %s") + % (table.description, column.name, ce.args[0]) + ) + ) const = self.create_table_constraints( - table, _include_foreign_key_constraints= # noqa - create.include_foreign_key_constraints) + table, + _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa + ) if const: text += separator + "\t" + const @@ -2538,20 +2843,18 @@ class DDLCompiler(Compiled): if column.system: return None - text = self.get_column_specification( - column, - first_pk=first_pk + text = self.get_column_specification(column, first_pk=first_pk) + const = " ".join( + self.process(constraint) for constraint in column.constraints ) - const = " ".join(self.process(constraint) - for constraint in column.constraints) if const: text += " " + const return text def create_table_constraints( - self, table, - _include_foreign_key_constraints=None): + self, table, _include_foreign_key_constraints=None + ): # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) @@ -2565,21 +2868,29 @@ class DDLCompiler(Compiled): else: omit_fkcs = set() - constraints.extend([c for c in table._sorted_constraints - if c is not table.primary_key and - c not in omit_fkcs]) + constraints.extend( + [ + c + for c in table._sorted_constraints + if c is not table.primary_key and c not in omit_fkcs + ] + ) return ", \n\t".join( - p for p in - (self.process(constraint) + p + for p in ( + self.process(constraint) for constraint in constraints if ( - constraint._create_rule is None or - constraint._create_rule(self)) + constraint._create_rule is None + or constraint._create_rule(self) + ) and ( - not self.dialect.supports_alter or - not getattr(constraint, 'use_alter', False) - )) if p is not None + not self.dialect.supports_alter + or not getattr(constraint, "use_alter", False) + ) + ) + if p is not None ) def visit_drop_table(self, drop): @@ -2590,34 +2901,38 @@ class DDLCompiler(Compiled): def _verify_index_table(self, index): if index.table is None: - raise exc.CompileError("Index '%s' is not associated " - "with any table." % index.name) + raise exc.CompileError( + "Index '%s' is not associated " "with any table." % index.name + ) - def visit_create_index(self, create, include_schema=False, - include_table_schema=True): + def visit_create_index( + self, create, include_schema=False, include_table_schema=True + ): index = create.element self._verify_index_table(index) preparer = self.preparer text = "CREATE " if index.unique: text += "UNIQUE " - text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=include_schema), - preparer.format_table(index.table, - use_schema=include_table_schema), - ', '.join( - self.sql_compiler.process( - expr, include_table=False, literal_binds=True) for - expr in index.expressions) - ) + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table( + index.table, use_schema=include_table_schema + ), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) return text def visit_drop_index(self, drop): index = drop.element return "\nDROP INDEX " + self._prepared_index_name( - index, include_schema=True) + index, include_schema=True + ) def _prepared_index_name(self, index, include_schema=False): if index.table is not None: @@ -2638,35 +2953,41 @@ class DDLCompiler(Compiled): def visit_add_constraint(self, create): return "ALTER TABLE %s ADD %s" % ( self.preparer.format_table(create.element.table), - self.process(create.element) + self.process(create.element), ) def visit_set_table_comment(self, create): return "COMMENT ON TABLE %s IS %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String()) + create.element.comment, sqltypes.String() + ), ) def visit_drop_table_comment(self, drop): - return "COMMENT ON TABLE %s IS NULL" % \ - self.preparer.format_table(drop.element) + return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( + drop.element + ) def visit_set_column_comment(self, create): return "COMMENT ON COLUMN %s IS %s" % ( self.preparer.format_column( - create.element, use_table=True, use_schema=True), + create.element, use_table=True, use_schema=True + ), self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String()) + create.element.comment, sqltypes.String() + ), ) def visit_drop_column_comment(self, drop): - return "COMMENT ON COLUMN %s IS NULL" % \ - self.preparer.format_column(drop.element, use_table=True) + return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column( + drop.element, use_table=True + ) def visit_create_sequence(self, create): - text = "CREATE SEQUENCE %s" % \ - self.preparer.format_sequence(create.element) + text = "CREATE SEQUENCE %s" % self.preparer.format_sequence( + create.element + ) if create.element.increment is not None: text += " INCREMENT BY %d" % create.element.increment if create.element.start is not None: @@ -2688,8 +3009,7 @@ class DDLCompiler(Compiled): return text def visit_drop_sequence(self, drop): - return "DROP SEQUENCE %s" % \ - self.preparer.format_sequence(drop.element) + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) def visit_drop_constraint(self, drop): constraint = drop.element @@ -2701,17 +3021,22 @@ class DDLCompiler(Compiled): if formatted_name is None: raise exc.CompileError( "Can't emit DROP CONSTRAINT for constraint %r; " - "it has no name" % drop.element) + "it has no name" % drop.element + ) return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( self.preparer.format_table(drop.element.table), formatted_name, - drop.cascade and " CASCADE" or "" + drop.cascade and " CASCADE" or "", ) def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process( - column.type, type_expression=column) + colspec = ( + self.preparer.format_column(column) + + " " + + self.dialect.type_compiler.process( + column.type, type_expression=column + ) + ) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -2721,19 +3046,21 @@ class DDLCompiler(Compiled): return colspec def create_table_suffix(self, table): - return '' + return "" def post_create_table(self, table): - return '' + return "" def get_column_default_string(self, column): if isinstance(column.server_default, schema.DefaultClause): if isinstance(column.server_default.arg, util.string_types): return self.sql_compiler.render_literal_value( - column.server_default.arg, sqltypes.STRINGTYPE) + column.server_default.arg, sqltypes.STRINGTYPE + ) else: return self.sql_compiler.process( - column.server_default.arg, literal_binds=True) + column.server_default.arg, literal_binds=True + ) else: return None @@ -2743,9 +3070,9 @@ class DDLCompiler(Compiled): formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name - text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext, - include_table=False, - literal_binds=True) + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) text += self.define_constraint_deferrability(constraint) return text @@ -2755,25 +3082,29 @@ class DDLCompiler(Compiled): formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name - text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext, - include_table=False, - literal_binds=True) + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) text += self.define_constraint_deferrability(constraint) return text def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name text += "PRIMARY KEY " - text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in (constraint.columns_autoinc_first - if constraint._implicit_generated - else constraint.columns)) + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) + for c in ( + constraint.columns_autoinc_first + if constraint._implicit_generated + else constraint.columns + ) + ) text += self.define_constraint_deferrability(constraint) return text @@ -2786,12 +3117,15 @@ class DDLCompiler(Compiled): text += "CONSTRAINT %s " % formatted_name remote_table = list(constraint.elements)[0].column.table text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - ', '.join(preparer.quote(f.parent.name) - for f in constraint.elements), + ", ".join( + preparer.quote(f.parent.name) for f in constraint.elements + ), self.define_constraint_remote_table( - constraint, remote_table, preparer), - ', '.join(preparer.quote(f.column.name) - for f in constraint.elements) + constraint, remote_table, preparer + ), + ", ".join( + preparer.quote(f.column.name) for f in constraint.elements + ), ) text += self.define_constraint_match(constraint) text += self.define_constraint_cascades(constraint) @@ -2805,14 +3139,14 @@ class DDLCompiler(Compiled): def visit_unique_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) text += "CONSTRAINT %s " % formatted_name text += "UNIQUE (%s)" % ( - ', '.join(self.preparer.quote(c.name) - for c in constraint)) + ", ".join(self.preparer.quote(c.name) for c in constraint) + ) text += self.define_constraint_deferrability(constraint) return text @@ -2843,7 +3177,6 @@ class DDLCompiler(Compiled): class GenericTypeCompiler(TypeCompiler): - def visit_FLOAT(self, type_, **kw): return "FLOAT" @@ -2854,23 +3187,23 @@ class GenericTypeCompiler(TypeCompiler): if type_.precision is None: return "NUMERIC" elif type_.scale is None: - return "NUMERIC(%(precision)s)" % \ - {'precision': type_.precision} + return "NUMERIC(%(precision)s)" % {"precision": type_.precision} else: - return "NUMERIC(%(precision)s, %(scale)s)" % \ - {'precision': type_.precision, - 'scale': type_.scale} + return "NUMERIC(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return "DECIMAL" elif type_.scale is None: - return "DECIMAL(%(precision)s)" % \ - {'precision': type_.precision} + return "DECIMAL(%(precision)s)" % {"precision": type_.precision} else: - return "DECIMAL(%(precision)s, %(scale)s)" % \ - {'precision': type_.precision, - 'scale': type_.scale} + return "DECIMAL(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } def visit_INTEGER(self, type_, **kw): return "INTEGER" @@ -2882,7 +3215,7 @@ class GenericTypeCompiler(TypeCompiler): return "BIGINT" def visit_TIMESTAMP(self, type_, **kw): - return 'TIMESTAMP' + return "TIMESTAMP" def visit_DATETIME(self, type_, **kw): return "DATETIME" @@ -2984,9 +3317,11 @@ class GenericTypeCompiler(TypeCompiler): return self.visit_VARCHAR(type_, **kw) def visit_null(self, type_, **kw): - raise exc.CompileError("Can't generate DDL for %r; " - "did you forget to specify a " - "type on this Column?" % type_) + raise exc.CompileError( + "Can't generate DDL for %r; " + "did you forget to specify a " + "type on this Column?" % type_ + ) def visit_type_decorator(self, type_, **kw): return self.process(type_.type_engine(self.dialect), **kw) @@ -3018,9 +3353,15 @@ class IdentifierPreparer(object): schema_for_object = schema._schema_getter(None) - def __init__(self, dialect, initial_quote='"', - final_quote=None, escape_quote='"', - quote_case_sensitive_collations=True, omit_schema=False): + def __init__( + self, + dialect, + initial_quote='"', + final_quote=None, + escape_quote='"', + quote_case_sensitive_collations=True, + omit_schema=False, + ): """Construct a new ``IdentifierPreparer`` object. initial_quote @@ -3043,7 +3384,10 @@ class IdentifierPreparer(object): self.omit_schema = omit_schema self.quote_case_sensitive_collations = quote_case_sensitive_collations self._strings = {} - self._double_percents = self.dialect.paramstyle in ('format', 'pyformat') + self._double_percents = self.dialect.paramstyle in ( + "format", + "pyformat", + ) def _with_schema_translate(self, schema_translate_map): prep = self.__class__.__new__(self.__class__) @@ -3060,7 +3404,7 @@ class IdentifierPreparer(object): value = value.replace(self.escape_quote, self.escape_to_quote) if self._double_percents: - value = value.replace('%', '%%') + value = value.replace("%", "%%") return value def _unescape_identifier(self, value): @@ -3079,17 +3423,21 @@ class IdentifierPreparer(object): quoting behavior. """ - return self.initial_quote + \ - self._escape_identifier(value) + \ - self.final_quote + return ( + self.initial_quote + + self._escape_identifier(value) + + self.final_quote + ) def _requires_quotes(self, value): """Return True if the given identifier requires quoting.""" lc_value = value.lower() - return (lc_value in self.reserved_words - or value[0] in self.illegal_initial_characters - or not self.legal_characters.match(util.text_type(value)) - or (lc_value != value)) + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(util.text_type(value)) + or (lc_value != value) + ) def quote_schema(self, schema, force=None): """Conditionally quote a schema. @@ -3135,8 +3483,11 @@ class IdentifierPreparer(object): effective_schema = self.schema_for_object(sequence) - if (not self.omit_schema and use_schema and - effective_schema is not None): + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): name = self.quote_schema(effective_schema) + "." + name return name @@ -3159,7 +3510,8 @@ class IdentifierPreparer(object): def format_constraint(self, naming, constraint): if isinstance(constraint.name, elements._defer_name): name = naming._constraint_name_for_table( - constraint, constraint.table) + constraint, constraint.table + ) if name is None: if isinstance(constraint.name, elements._defer_none_name): @@ -3170,14 +3522,15 @@ class IdentifierPreparer(object): name = constraint.name if isinstance(name, elements._truncated_label): - if constraint.__visit_name__ == 'index': - max_ = self.dialect.max_index_name_length or \ - self.dialect.max_identifier_length + if constraint.__visit_name__ == "index": + max_ = ( + self.dialect.max_index_name_length + or self.dialect.max_identifier_length + ) else: max_ = self.dialect.max_identifier_length if len(name) > max_: - name = name[0:max_ - 8] + \ - "_" + util.md5_hex(name)[-4:] + name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:] else: self.dialect.validate_identifier(name) @@ -3195,8 +3548,7 @@ class IdentifierPreparer(object): effective_schema = self.schema_for_object(table) - if not self.omit_schema and use_schema \ - and effective_schema: + if not self.omit_schema and use_schema and effective_schema: result = self.quote_schema(effective_schema) + "." + result return result @@ -3205,17 +3557,27 @@ class IdentifierPreparer(object): return self.quote(name, quote) - def format_column(self, column, use_table=False, - name=None, table_name=None, use_schema=False): + def format_column( + self, + column, + use_table=False, + name=None, + table_name=None, + use_schema=False, + ): """Prepare a quoted column name.""" if name is None: name = column.name - if not getattr(column, 'is_literal', False): + if not getattr(column, "is_literal", False): if use_table: - return self.format_table( - column.table, use_schema=use_schema, - name=table_name) + "." + self.quote(name) + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + self.quote(name) + ) else: return self.quote(name) else: @@ -3223,9 +3585,13 @@ class IdentifierPreparer(object): # which shouldn't get quoted if use_table: - return self.format_table( - column.table, use_schema=use_schema, - name=table_name) + '.' + name + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + name + ) else: return name @@ -3238,31 +3604,37 @@ class IdentifierPreparer(object): effective_schema = self.schema_for_object(table) - if not self.omit_schema and use_schema and \ - effective_schema: - return (self.quote_schema(effective_schema), - self.format_table(table, use_schema=False)) + if not self.omit_schema and use_schema and effective_schema: + return ( + self.quote_schema(effective_schema), + self.format_table(table, use_schema=False), + ) else: - return (self.format_table(table, use_schema=False), ) + return (self.format_table(table, use_schema=False),) @util.memoized_property def _r_identifiers(self): - initial, final, escaped_final = \ - [re.escape(s) for s in - (self.initial_quote, self.final_quote, - self._escape_identifier(self.final_quote))] + initial, final, escaped_final = [ + re.escape(s) + for s in ( + self.initial_quote, + self.final_quote, + self._escape_identifier(self.final_quote), + ) + ] r = re.compile( - r'(?:' - r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' - r'|([^\.]+))(?=\.|$))+' % - {'initial': initial, - 'final': final, - 'escaped': escaped_final}) + r"(?:" + r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s" + r"|([^\.]+))(?=\.|$))+" + % {"initial": initial, "final": final, "escaped": escaped_final} + ) return r def unformat_identifiers(self, identifiers): """Unpack 'schema.table.column'-like strings into components.""" r = self._r_identifiers - return [self._unescape_identifier(i) - for i in [a or b for a, b in r.findall(identifiers)]] + return [ + self._unescape_identifier(i) + for i in [a or b for a, b in r.findall(identifiers)] + ] |