| field | value | date |
|---|---|---|
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
| commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
| tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/sql/compiler.py | |
| parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
| download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz | |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run formats code consistently; however, in some cases that are
prevalent in SQLAlchemy code it produces too-long lines. Those too-long
lines will be addressed in the following commit, which will resolve all
remaining flake8 issues, including shadowed builtins, long lines, import
order, unused imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
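For readers who want to reproduce or verify a reformat-only commit like this one, the sketch below shows one possible workflow. It is an illustration, not the exact command history behind this change: the `-l 79` line-length flag comes from the commit title, while the target paths and the verification step are assumptions.

```python
# Minimal sketch (not part of the commit) of running and verifying a
# "straight reformat run" with black. Assumes black is installed; the
# target paths below are illustrative, since the commit itself only says
# "all source files".
import subprocess

paths = ["lib/", "test/"]  # hypothetical targets

# Reformat in place with a 79-character line limit.
subprocess.run(["black", "-l", "79", *paths], check=True)

# A second pass in --check mode exits non-zero only if files would still
# be reformatted, confirming the tree is already in black's canonical form.
subprocess.run(["black", "-l", "79", "--check", *paths], check=True)
```

Running the formatter a second time in `--check` mode is a cheap way to confirm that a commit like this really is a straight reformat with no manual edits mixed in.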
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 2030 |
1 file changed, 1201 insertions, 829 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 80ed707ed..f641d0a84 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -25,133 +25,218 @@ To generate user-defined SQL strings, see import contextlib import re -from . import schema, sqltypes, operators, functions, visitors, \ - elements, selectable, crud +from . import ( + schema, + sqltypes, + operators, + functions, + visitors, + elements, + selectable, + crud, +) from .. import util, exc import itertools -RESERVED_WORDS = set([ - 'all', 'analyse', 'analyze', 'and', 'any', 'array', - 'as', 'asc', 'asymmetric', 'authorization', 'between', - 'binary', 'both', 'case', 'cast', 'check', 'collate', - 'column', 'constraint', 'create', 'cross', 'current_date', - 'current_role', 'current_time', 'current_timestamp', - 'current_user', 'default', 'deferrable', 'desc', - 'distinct', 'do', 'else', 'end', 'except', 'false', - 'for', 'foreign', 'freeze', 'from', 'full', 'grant', - 'group', 'having', 'ilike', 'in', 'initially', 'inner', - 'intersect', 'into', 'is', 'isnull', 'join', 'leading', - 'left', 'like', 'limit', 'localtime', 'localtimestamp', - 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset', - 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps', - 'placing', 'primary', 'references', 'right', 'select', - 'session_user', 'set', 'similar', 'some', 'symmetric', 'table', - 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user', - 'using', 'verbose', 'when', 'where']) - -LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I) -ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(['$']) - -BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE) -BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]*)(?![:\w\$])', re.UNICODE) +RESERVED_WORDS = set( + [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "authorization", + "between", + "binary", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "cross", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "for", + "foreign", + "freeze", + "from", + "full", + "grant", + "group", + "having", + "ilike", + "in", + "initially", + "inner", + "intersect", + "into", + "is", + "isnull", + "join", + "leading", + "left", + "like", + "limit", + "localtime", + "localtimestamp", + "natural", + "new", + "not", + "notnull", + "null", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "outer", + "overlaps", + "placing", + "primary", + "references", + "right", + "select", + "session_user", + "set", + "similar", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "verbose", + "when", + "where", + ] +) + +LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I) +ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"]) + +BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE) +BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE) BIND_TEMPLATES = { - 'pyformat': "%%(%(name)s)s", - 'qmark': "?", - 'format': "%%s", - 'numeric': ":[_POSITION]", - 'named': ":%(name)s" + "pyformat": "%%(%(name)s)s", + "qmark": "?", + "format": "%%s", + "numeric": ":[_POSITION]", + "named": ":%(name)s", } OPERATORS = { # binary - operators.and_: ' AND ', - 
operators.or_: ' OR ', - operators.add: ' + ', - operators.mul: ' * ', - operators.sub: ' - ', - operators.div: ' / ', - operators.mod: ' % ', - operators.truediv: ' / ', - operators.neg: '-', - operators.lt: ' < ', - operators.le: ' <= ', - operators.ne: ' != ', - operators.gt: ' > ', - operators.ge: ' >= ', - operators.eq: ' = ', - operators.is_distinct_from: ' IS DISTINCT FROM ', - operators.isnot_distinct_from: ' IS NOT DISTINCT FROM ', - operators.concat_op: ' || ', - operators.match_op: ' MATCH ', - operators.notmatch_op: ' NOT MATCH ', - operators.in_op: ' IN ', - operators.notin_op: ' NOT IN ', - operators.comma_op: ', ', - operators.from_: ' FROM ', - operators.as_: ' AS ', - operators.is_: ' IS ', - operators.isnot: ' IS NOT ', - operators.collate: ' COLLATE ', - + operators.and_: " AND ", + operators.or_: " OR ", + operators.add: " + ", + operators.mul: " * ", + operators.sub: " - ", + operators.div: " / ", + operators.mod: " % ", + operators.truediv: " / ", + operators.neg: "-", + operators.lt: " < ", + operators.le: " <= ", + operators.ne: " != ", + operators.gt: " > ", + operators.ge: " >= ", + operators.eq: " = ", + operators.is_distinct_from: " IS DISTINCT FROM ", + operators.isnot_distinct_from: " IS NOT DISTINCT FROM ", + operators.concat_op: " || ", + operators.match_op: " MATCH ", + operators.notmatch_op: " NOT MATCH ", + operators.in_op: " IN ", + operators.notin_op: " NOT IN ", + operators.comma_op: ", ", + operators.from_: " FROM ", + operators.as_: " AS ", + operators.is_: " IS ", + operators.isnot: " IS NOT ", + operators.collate: " COLLATE ", # unary - operators.exists: 'EXISTS ', - operators.distinct_op: 'DISTINCT ', - operators.inv: 'NOT ', - operators.any_op: 'ANY ', - operators.all_op: 'ALL ', - + operators.exists: "EXISTS ", + operators.distinct_op: "DISTINCT ", + operators.inv: "NOT ", + operators.any_op: "ANY ", + operators.all_op: "ALL ", # modifiers - operators.desc_op: ' DESC', - operators.asc_op: ' ASC', - operators.nullsfirst_op: ' NULLS FIRST', - operators.nullslast_op: ' NULLS LAST', - + operators.desc_op: " DESC", + operators.asc_op: " ASC", + operators.nullsfirst_op: " NULLS FIRST", + operators.nullslast_op: " NULLS LAST", } FUNCTIONS = { - functions.coalesce: 'coalesce', - functions.current_date: 'CURRENT_DATE', - functions.current_time: 'CURRENT_TIME', - functions.current_timestamp: 'CURRENT_TIMESTAMP', - functions.current_user: 'CURRENT_USER', - functions.localtime: 'LOCALTIME', - functions.localtimestamp: 'LOCALTIMESTAMP', - functions.random: 'random', - functions.sysdate: 'sysdate', - functions.session_user: 'SESSION_USER', - functions.user: 'USER', - functions.cube: 'CUBE', - functions.rollup: 'ROLLUP', - functions.grouping_sets: 'GROUPING SETS', + functions.coalesce: "coalesce", + functions.current_date: "CURRENT_DATE", + functions.current_time: "CURRENT_TIME", + functions.current_timestamp: "CURRENT_TIMESTAMP", + functions.current_user: "CURRENT_USER", + functions.localtime: "LOCALTIME", + functions.localtimestamp: "LOCALTIMESTAMP", + functions.random: "random", + functions.sysdate: "sysdate", + functions.session_user: "SESSION_USER", + functions.user: "USER", + functions.cube: "CUBE", + functions.rollup: "ROLLUP", + functions.grouping_sets: "GROUPING SETS", } EXTRACT_MAP = { - 'month': 'month', - 'day': 'day', - 'year': 'year', - 'second': 'second', - 'hour': 'hour', - 'doy': 'doy', - 'minute': 'minute', - 'quarter': 'quarter', - 'dow': 'dow', - 'week': 'week', - 'epoch': 'epoch', - 'milliseconds': 'milliseconds', - 'microseconds': 
'microseconds', - 'timezone_hour': 'timezone_hour', - 'timezone_minute': 'timezone_minute' + "month": "month", + "day": "day", + "year": "year", + "second": "second", + "hour": "hour", + "doy": "doy", + "minute": "minute", + "quarter": "quarter", + "dow": "dow", + "week": "week", + "epoch": "epoch", + "milliseconds": "milliseconds", + "microseconds": "microseconds", + "timezone_hour": "timezone_hour", + "timezone_minute": "timezone_minute", } COMPOUND_KEYWORDS = { - selectable.CompoundSelect.UNION: 'UNION', - selectable.CompoundSelect.UNION_ALL: 'UNION ALL', - selectable.CompoundSelect.EXCEPT: 'EXCEPT', - selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL', - selectable.CompoundSelect.INTERSECT: 'INTERSECT', - selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL' + selectable.CompoundSelect.UNION: "UNION", + selectable.CompoundSelect.UNION_ALL: "UNION ALL", + selectable.CompoundSelect.EXCEPT: "EXCEPT", + selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL", + selectable.CompoundSelect.INTERSECT: "INTERSECT", + selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL", } @@ -177,9 +262,14 @@ class Compiled(object): sub-elements of the statement can modify these. """ - def __init__(self, dialect, statement, bind=None, - schema_translate_map=None, - compile_kwargs=util.immutabledict()): + def __init__( + self, + dialect, + statement, + bind=None, + schema_translate_map=None, + compile_kwargs=util.immutabledict(), + ): """Construct a new :class:`.Compiled` object. :param dialect: :class:`.Dialect` to compile against. @@ -209,7 +299,8 @@ class Compiled(object): self.preparer = self.dialect.identifier_preparer if schema_translate_map: self.preparer = self.preparer._with_schema_translate( - schema_translate_map) + schema_translate_map + ) if statement is not None: self.statement = statement @@ -218,8 +309,10 @@ class Compiled(object): self.execution_options = statement._execution_options self.string = self.process(self.statement, **compile_kwargs) - @util.deprecated("0.7", ":class:`.Compiled` objects now compile " - "within the constructor.") + @util.deprecated( + "0.7", + ":class:`.Compiled` objects now compile " "within the constructor.", + ) def compile(self): """Produce the internal string representation of this element. """ @@ -247,7 +340,7 @@ class Compiled(object): def __str__(self): """Return the string text of the generated SQL or DDL.""" - return self.string or '' + return self.string or "" def construct_params(self, params=None): """Return the bind params for this compiled object. 
@@ -271,7 +364,9 @@ class Compiled(object): if e is None: raise exc.UnboundExecutionError( "This Compiled object is not bound to any Engine " - "or Connection.", code="2afi") + "or Connection.", + code="2afi", + ) return e._execute_compiled(self, multiparams, params) def scalar(self, *multiparams, **params): @@ -284,7 +379,7 @@ class Compiled(object): class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): """Produces DDL specification for TypeEngine objects.""" - ensure_kwarg = r'visit_\w+' + ensure_kwarg = r"visit_\w+" def __init__(self, dialect): self.dialect = dialect @@ -297,8 +392,8 @@ class _CompileLabel(visitors.Visitable): """lightweight label object which acts as an expression.Label.""" - __visit_name__ = 'label' - __slots__ = 'element', 'name' + __visit_name__ = "label" + __slots__ = "element", "name" def __init__(self, col, name, alt_names=()): self.element = col @@ -390,8 +485,9 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () - def __init__(self, dialect, statement, column_keys=None, - inline=False, **kwargs): + def __init__( + self, dialect, statement, column_keys=None, inline=False, **kwargs + ): """Construct a new :class:`.SQLCompiler` object. :param dialect: :class:`.Dialect` to be used @@ -412,7 +508,7 @@ class SQLCompiler(Compiled): # compile INSERT/UPDATE defaults/sequences inlined (no pre- # execute) - self.inline = inline or getattr(statement, 'inline', False) + self.inline = inline or getattr(statement, "inline", False) # a dictionary of bind parameter keys to BindParameter # instances. @@ -440,8 +536,9 @@ class SQLCompiler(Compiled): self.ctes = None - self.label_length = dialect.label_length \ - or dialect.max_identifier_length + self.label_length = ( + dialect.label_length or dialect.max_identifier_length + ) # a map which tracks "anonymous" identifiers that are created on # the fly here @@ -453,7 +550,7 @@ class SQLCompiler(Compiled): Compiled.__init__(self, dialect, statement, **kwargs) if ( - self.isinsert or self.isupdate or self.isdelete + self.isinsert or self.isupdate or self.isdelete ) and statement._returning: self.returning = statement._returning @@ -482,37 +579,43 @@ class SQLCompiler(Compiled): def _nested_result(self): """special API to support the use case of 'nested result sets'""" result_columns, ordered_columns = ( - self._result_columns, self._ordered_columns) + self._result_columns, + self._ordered_columns, + ) self._result_columns, self._ordered_columns = [], False try: if self.stack: entry = self.stack[-1] - entry['need_result_map_for_nested'] = True + entry["need_result_map_for_nested"] = True else: entry = None yield self._result_columns, self._ordered_columns finally: if entry: - entry.pop('need_result_map_for_nested') + entry.pop("need_result_map_for_nested") self._result_columns, self._ordered_columns = ( - result_columns, ordered_columns) + result_columns, + ordered_columns, + ) def _apply_numbered_params(self): poscount = itertools.count(1) self.string = re.sub( - r'\[_POSITION\]', - lambda m: str(util.next(poscount)), - self.string) + r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string + ) @util.memoized_property def _bind_processors(self): return dict( - (key, value) for key, value in - ((self.bind_names[bindparam], - bindparam.type._cached_bind_processor(self.dialect) - ) - for bindparam in self.bind_names) + (key, value) + for key, value in ( + ( + self.bind_names[bindparam], + bindparam.type._cached_bind_processor(self.dialect), + ) + for bindparam in self.bind_names + ) if value 
is not None ) @@ -539,12 +642,16 @@ class SQLCompiler(Compiled): if _group_number: raise exc.InvalidRequestError( "A value is required for bind parameter %r, " - "in parameter group %d" % - (bindparam.key, _group_number), code="cd3x") + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) else: raise exc.InvalidRequestError( "A value is required for bind parameter %r" - % bindparam.key, code="cd3x") + % bindparam.key, + code="cd3x", + ) elif bindparam.callable: pd[name] = bindparam.effective_value @@ -558,12 +665,16 @@ class SQLCompiler(Compiled): if _group_number: raise exc.InvalidRequestError( "A value is required for bind parameter %r, " - "in parameter group %d" % - (bindparam.key, _group_number), code="cd3x") + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) else: raise exc.InvalidRequestError( "A value is required for bind parameter %r" - % bindparam.key, code="cd3x") + % bindparam.key, + code="cd3x", + ) if bindparam.callable: pd[self.bind_names[bindparam]] = bindparam.effective_value @@ -595,9 +706,10 @@ class SQLCompiler(Compiled): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" def visit_label_reference( - self, element, within_columns_clause=False, **kwargs): + self, element, within_columns_clause=False, **kwargs + ): if self.stack and self.dialect.supports_simple_order_by_label: - selectable = self.stack[-1]['selectable'] + selectable = self.stack[-1]["selectable"] with_cols, only_froms, only_cols = selectable._label_resolve_dict if within_columns_clause: @@ -611,25 +723,30 @@ class SQLCompiler(Compiled): # to something else like a ColumnClause expression. order_by_elem = element.element._order_by_label_element - if order_by_elem is not None and order_by_elem.name in \ - resolve_dict and \ - order_by_elem.shares_lineage( - resolve_dict[order_by_elem.name]): - kwargs['render_label_as_label'] = \ - element.element._order_by_label_element + if ( + order_by_elem is not None + and order_by_elem.name in resolve_dict + and order_by_elem.shares_lineage( + resolve_dict[order_by_elem.name] + ) + ): + kwargs[ + "render_label_as_label" + ] = element.element._order_by_label_element return self.process( - element.element, within_columns_clause=within_columns_clause, - **kwargs) + element.element, + within_columns_clause=within_columns_clause, + **kwargs + ) def visit_textual_label_reference( - self, element, within_columns_clause=False, **kwargs): + self, element, within_columns_clause=False, **kwargs + ): if not self.stack: # compiling the element outside of the context of a SELECT - return self.process( - element._text_clause - ) + return self.process(element._text_clause) - selectable = self.stack[-1]['selectable'] + selectable = self.stack[-1]["selectable"] with_cols, only_froms, only_cols = selectable._label_resolve_dict try: if within_columns_clause: @@ -640,26 +757,30 @@ class SQLCompiler(Compiled): # treat it like text() util.warn_limited( "Can't resolve label reference %r; converting to text()", - util.ellipses_string(element.element)) - return self.process( - element._text_clause + util.ellipses_string(element.element), ) + return self.process(element._text_clause) else: - kwargs['render_label_as_label'] = col + kwargs["render_label_as_label"] = col return self.process( - col, within_columns_clause=within_columns_clause, **kwargs) - - def visit_label(self, label, - add_to_result_map=None, - within_label_clause=False, - within_columns_clause=False, - render_label_as_label=None, - **kw): + col, 
within_columns_clause=within_columns_clause, **kwargs + ) + + def visit_label( + self, + label, + add_to_result_map=None, + within_label_clause=False, + within_columns_clause=False, + render_label_as_label=None, + **kw + ): # only render labels within the columns clause # or ORDER BY clause of a select. dialect-specific compilers # can modify this behavior. - render_label_with_as = (within_columns_clause and not - within_label_clause) + render_label_with_as = ( + within_columns_clause and not within_label_clause + ) render_label_only = render_label_as_label is label if render_label_only or render_label_with_as: @@ -673,27 +794,35 @@ class SQLCompiler(Compiled): add_to_result_map( labelname, label.name, - (label, labelname, ) + label._alt_names, - label.type + (label, labelname) + label._alt_names, + label.type, ) - return label.element._compiler_dispatch( - self, within_columns_clause=True, - within_label_clause=True, **kw) + \ - OPERATORS[operators.as_] + \ - self.preparer.format_label(label, labelname) + return ( + label.element._compiler_dispatch( + self, + within_columns_clause=True, + within_label_clause=True, + **kw + ) + + OPERATORS[operators.as_] + + self.preparer.format_label(label, labelname) + ) elif render_label_only: return self.preparer.format_label(label, labelname) else: return label.element._compiler_dispatch( - self, within_columns_clause=False, **kw) + self, within_columns_clause=False, **kw + ) def _fallback_column_name(self, column): - raise exc.CompileError("Cannot compile Column object until " - "its 'name' is assigned.") + raise exc.CompileError( + "Cannot compile Column object until " "its 'name' is assigned." + ) - def visit_column(self, column, add_to_result_map=None, - include_table=True, **kwargs): + def visit_column( + self, column, add_to_result_map=None, include_table=True, **kwargs + ): name = orig_name = column.name if name is None: name = self._fallback_column_name(column) @@ -704,10 +833,7 @@ class SQLCompiler(Compiled): if add_to_result_map is not None: add_to_result_map( - name, - orig_name, - (column, name, column.key), - column.type + name, orig_name, (column, name, column.key), column.type ) if is_literal: @@ -721,17 +847,16 @@ class SQLCompiler(Compiled): effective_schema = self.preparer.schema_for_object(table) if effective_schema: - schema_prefix = self.preparer.quote_schema( - effective_schema) + '.' + schema_prefix = ( + self.preparer.quote_schema(effective_schema) + "." + ) else: - schema_prefix = '' + schema_prefix = "" tablename = table.name if isinstance(tablename, elements._truncated_label): tablename = self._truncated_identifier("alias", tablename) - return schema_prefix + \ - self.preparer.quote(tablename) + \ - "." + name + return schema_prefix + self.preparer.quote(tablename) + "." 
+ name def visit_collation(self, element, **kw): return self.preparer.format_collation(element.collation) @@ -743,17 +868,17 @@ class SQLCompiler(Compiled): return index.name def visit_typeclause(self, typeclause, **kw): - kw['type_expression'] = typeclause + kw["type_expression"] = typeclause return self.dialect.type_compiler.process(typeclause.type, **kw) def post_process_text(self, text): if self.preparer._double_percents: - text = text.replace('%', '%%') + text = text.replace("%", "%%") return text def escape_literal_column(self, text): if self.preparer._double_percents: - text = text.replace('%', '%%') + text = text.replace("%", "%%") return text def visit_textclause(self, textclause, **kw): @@ -771,30 +896,36 @@ class SQLCompiler(Compiled): return BIND_PARAMS_ESC.sub( lambda m: m.group(1), BIND_PARAMS.sub( - do_bindparam, - self.post_process_text(textclause.text)) + do_bindparam, self.post_process_text(textclause.text) + ), ) - def visit_text_as_from(self, taf, - compound_index=None, - asfrom=False, - parens=True, **kw): + def visit_text_as_from( + self, taf, compound_index=None, asfrom=False, parens=True, **kw + ): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - populate_result_map = toplevel or \ - ( - compound_index == 0 and entry.get( - 'need_result_map_for_compound', False) - ) or entry.get('need_result_map_for_nested', False) + populate_result_map = ( + toplevel + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) if populate_result_map: - self._ordered_columns = \ - self._textual_ordered_columns = taf.positional + self._ordered_columns = ( + self._textual_ordered_columns + ) = taf.positional for c in taf.column_args: - self.process(c, within_columns_clause=True, - add_to_result_map=self._add_to_result_map) + self.process( + c, + within_columns_clause=True, + add_to_result_map=self._add_to_result_map, + ) text = self.process(taf.element, **kw) if asfrom and parens: @@ -802,17 +933,17 @@ class SQLCompiler(Compiled): return text def visit_null(self, expr, **kw): - return 'NULL' + return "NULL" def visit_true(self, expr, **kw): if self.dialect.supports_native_boolean: - return 'true' + return "true" else: return "1" def visit_false(self, expr, **kw): if self.dialect.supports_native_boolean: - return 'false' + return "false" else: return "0" @@ -823,25 +954,29 @@ class SQLCompiler(Compiled): else: sep = OPERATORS[clauselist.operator] return sep.join( - s for s in - ( - c._compiler_dispatch(self, **kw) - for c in clauselist.clauses) - if s) + s + for s in ( + c._compiler_dispatch(self, **kw) for c in clauselist.clauses + ) + if s + ) def visit_case(self, clause, **kwargs): x = "CASE " if clause.value is not None: x += clause.value._compiler_dispatch(self, **kwargs) + " " for cond, result in clause.whens: - x += "WHEN " + cond._compiler_dispatch( - self, **kwargs - ) + " THEN " + result._compiler_dispatch( - self, **kwargs) + " " + x += ( + "WHEN " + + cond._compiler_dispatch(self, **kwargs) + + " THEN " + + result._compiler_dispatch(self, **kwargs) + + " " + ) if clause.else_ is not None: - x += "ELSE " + clause.else_._compiler_dispatch( - self, **kwargs - ) + " " + x += ( + "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " " + ) x += "END" return x @@ -849,79 +984,84 @@ class SQLCompiler(Compiled): return type_coerce.typed_expression._compiler_dispatch(self, **kw) def visit_cast(self, cast, **kwargs): - return "CAST(%s AS %s)" % \ - 
(cast.clause._compiler_dispatch(self, **kwargs), - cast.typeclause._compiler_dispatch(self, **kwargs)) + return "CAST(%s AS %s)" % ( + cast.clause._compiler_dispatch(self, **kwargs), + cast.typeclause._compiler_dispatch(self, **kwargs), + ) def _format_frame_clause(self, range_, **kw): - return '%s AND %s' % ( + return "%s AND %s" % ( "UNBOUNDED PRECEDING" if range_[0] is elements.RANGE_UNBOUNDED - else "CURRENT ROW" if range_[0] is elements.RANGE_CURRENT - else "%s PRECEDING" % ( - self.process(elements.literal(abs(range_[0])), **kw), ) + else "CURRENT ROW" + if range_[0] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[0])), **kw),) if range_[0] < 0 - else "%s FOLLOWING" % ( - self.process(elements.literal(range_[0]), **kw), ), - + else "%s FOLLOWING" + % (self.process(elements.literal(range_[0]), **kw),), "UNBOUNDED FOLLOWING" if range_[1] is elements.RANGE_UNBOUNDED - else "CURRENT ROW" if range_[1] is elements.RANGE_CURRENT - else "%s PRECEDING" % ( - self.process(elements.literal(abs(range_[1])), **kw), ) + else "CURRENT ROW" + if range_[1] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[1])), **kw),) if range_[1] < 0 - else "%s FOLLOWING" % ( - self.process(elements.literal(range_[1]), **kw), ), + else "%s FOLLOWING" + % (self.process(elements.literal(range_[1]), **kw),), ) def visit_over(self, over, **kwargs): if over.range_: range_ = "RANGE BETWEEN %s" % self._format_frame_clause( - over.range_, **kwargs) + over.range_, **kwargs + ) elif over.rows: range_ = "ROWS BETWEEN %s" % self._format_frame_clause( - over.rows, **kwargs) + over.rows, **kwargs + ) else: range_ = None return "%s OVER (%s)" % ( over.element._compiler_dispatch(self, **kwargs), - ' '.join([ - '%s BY %s' % ( - word, clause._compiler_dispatch(self, **kwargs) - ) - for word, clause in ( - ('PARTITION', over.partition_by), - ('ORDER', over.order_by) - ) - if clause is not None and len(clause) - ] + ([range_] if range_ else []) - ) + " ".join( + [ + "%s BY %s" + % (word, clause._compiler_dispatch(self, **kwargs)) + for word, clause in ( + ("PARTITION", over.partition_by), + ("ORDER", over.order_by), + ) + if clause is not None and len(clause) + ] + + ([range_] if range_ else []) + ), ) def visit_withingroup(self, withingroup, **kwargs): return "%s WITHIN GROUP (ORDER BY %s)" % ( withingroup.element._compiler_dispatch(self, **kwargs), - withingroup.order_by._compiler_dispatch(self, **kwargs) + withingroup.order_by._compiler_dispatch(self, **kwargs), ) def visit_funcfilter(self, funcfilter, **kwargs): return "%s FILTER (WHERE %s)" % ( funcfilter.func._compiler_dispatch(self, **kwargs), - funcfilter.criterion._compiler_dispatch(self, **kwargs) + funcfilter.criterion._compiler_dispatch(self, **kwargs), ) def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) return "EXTRACT(%s FROM %s)" % ( - field, extract.expr._compiler_dispatch(self, **kwargs)) + field, + extract.expr._compiler_dispatch(self, **kwargs), + ) def visit_function(self, func, add_to_result_map=None, **kwargs): if add_to_result_map is not None: - add_to_result_map( - func.name, func.name, (), func.type - ) + add_to_result_map(func.name, func.name, (), func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) if disp: @@ -933,51 +1073,63 @@ class SQLCompiler(Compiled): name += "%(expr)s" else: name = func.name + "%(expr)s" - return ".".join(list(func.packagenames) + [name]) % \ - {'expr': 
self.function_argspec(func, **kwargs)} + return ".".join(list(func.packagenames) + [name]) % { + "expr": self.function_argspec(func, **kwargs) + } def visit_next_value_func(self, next_value, **kw): return self.visit_sequence(next_value.sequence) def visit_sequence(self, sequence, **kw): raise NotImplementedError( - "Dialect '%s' does not support sequence increments." % - self.dialect.name + "Dialect '%s' does not support sequence increments." + % self.dialect.name ) def function_argspec(self, func, **kwargs): return func.clause_expr._compiler_dispatch(self, **kwargs) - def visit_compound_select(self, cs, asfrom=False, - parens=True, compound_index=0, **kwargs): + def visit_compound_select( + self, cs, asfrom=False, parens=True, compound_index=0, **kwargs + ): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - need_result_map = toplevel or \ - (compound_index == 0 - and entry.get('need_result_map_for_compound', False)) + need_result_map = toplevel or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) self.stack.append( { - 'correlate_froms': entry['correlate_froms'], - 'asfrom_froms': entry['asfrom_froms'], - 'selectable': cs, - 'need_result_map_for_compound': need_result_map - }) + "correlate_froms": entry["correlate_froms"], + "asfrom_froms": entry["asfrom_froms"], + "selectable": cs, + "need_result_map_for_compound": need_result_map, + } + ) keyword = self.compound_keywords.get(cs.keyword) text = (" " + keyword + " ").join( - (c._compiler_dispatch(self, - asfrom=asfrom, parens=False, - compound_index=i, **kwargs) - for i, c in enumerate(cs.selects)) + ( + c._compiler_dispatch( + self, + asfrom=asfrom, + parens=False, + compound_index=i, + **kwargs + ) + for i, c in enumerate(cs.selects) + ) ) text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs)) text += self.order_by_clause(cs, **kwargs) - text += (cs._limit_clause is not None - or cs._offset_clause is not None) and \ - self.limit_clause(cs, **kwargs) or "" + text += ( + (cs._limit_clause is not None or cs._offset_clause is not None) + and self.limit_clause(cs, **kwargs) + or "" + ) if self.ctes and toplevel: text = self._render_cte_clause() + text @@ -990,8 +1142,10 @@ class SQLCompiler(Compiled): def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): attrname = "visit_%s_%s%s" % ( - operator_.__name__, qualifier1, - "_" + qualifier2 if qualifier2 else "") + operator_.__name__, + qualifier1, + "_" + qualifier2 if qualifier2 else "", + ) return getattr(self, attrname, None) def visit_unary(self, unary, **kw): @@ -999,51 +1153,63 @@ class SQLCompiler(Compiled): if unary.modifier: raise exc.CompileError( "Unary expression does not support operator " - "and modifier simultaneously") + "and modifier simultaneously" + ) disp = self._get_operator_dispatch( - unary.operator, "unary", "operator") + unary.operator, "unary", "operator" + ) if disp: return disp(unary, unary.operator, **kw) else: return self._generate_generic_unary_operator( - unary, OPERATORS[unary.operator], **kw) + unary, OPERATORS[unary.operator], **kw + ) elif unary.modifier: disp = self._get_operator_dispatch( - unary.modifier, "unary", "modifier") + unary.modifier, "unary", "modifier" + ) if disp: return disp(unary, unary.modifier, **kw) else: return self._generate_generic_unary_modifier( - unary, OPERATORS[unary.modifier], **kw) + unary, OPERATORS[unary.modifier], **kw + ) else: raise exc.CompileError( - "Unary expression has no operator or modifier") + "Unary expression has 
no operator or modifier" + ) def visit_istrue_unary_operator(self, element, operator, **kw): - if element._is_implicitly_boolean or \ - self.dialect.supports_native_boolean: + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): return self.process(element.element, **kw) else: return "%s = 1" % self.process(element.element, **kw) def visit_isfalse_unary_operator(self, element, operator, **kw): - if element._is_implicitly_boolean or \ - self.dialect.supports_native_boolean: + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): return "NOT %s" % self.process(element.element, **kw) else: return "%s = 0" % self.process(element.element, **kw) def visit_notmatch_op_binary(self, binary, operator, **kw): return "NOT %s" % self.visit_binary( - binary, override_operator=operators.match_op) + binary, override_operator=operators.match_op + ) def _emit_empty_in_warning(self): util.warn( - 'The IN-predicate was invoked with an ' - 'empty sequence. This results in a ' - 'contradiction, which nonetheless can be ' - 'expensive to evaluate. Consider alternative ' - 'strategies for improved performance.') + "The IN-predicate was invoked with an " + "empty sequence. This results in a " + "contradiction, which nonetheless can be " + "expensive to evaluate. Consider alternative " + "strategies for improved performance." + ) def visit_empty_in_op_binary(self, binary, operator, **kw): if self.dialect._use_static_in: @@ -1063,18 +1229,21 @@ class SQLCompiler(Compiled): def visit_empty_set_expr(self, element_types): raise NotImplementedError( - "Dialect '%s' does not support empty set expression." % - self.dialect.name + "Dialect '%s' does not support empty set expression." + % self.dialect.name ) - def visit_binary(self, binary, override_operator=None, - eager_grouping=False, **kw): + def visit_binary( + self, binary, override_operator=None, eager_grouping=False, **kw + ): # don't allow "? = ?" 
to render - if self.ansi_bind_rules and \ - isinstance(binary.left, elements.BindParameter) and \ - isinstance(binary.right, elements.BindParameter): - kw['literal_binds'] = True + if ( + self.ansi_bind_rules + and isinstance(binary.left, elements.BindParameter) + and isinstance(binary.right, elements.BindParameter) + ): + kw["literal_binds"] = True operator_ = override_operator or binary.operator disp = self._get_operator_dispatch(operator_, "binary", None) @@ -1093,36 +1262,50 @@ class SQLCompiler(Compiled): def visit_mod_binary(self, binary, operator, **kw): if self.preparer._double_percents: - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) else: - return self.process(binary.left, **kw) + " % " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " % " + + self.process(binary.right, **kw) + ) def visit_custom_op_binary(self, element, operator, **kw): - kw['eager_grouping'] = operator.eager_grouping + kw["eager_grouping"] = operator.eager_grouping return self._generate_generic_binary( - element, " " + operator.opstring + " ", **kw) + element, " " + operator.opstring + " ", **kw + ) def visit_custom_op_unary_operator(self, element, operator, **kw): return self._generate_generic_unary_operator( - element, operator.opstring + " ", **kw) + element, operator.opstring + " ", **kw + ) def visit_custom_op_unary_modifier(self, element, operator, **kw): return self._generate_generic_unary_modifier( - element, " " + operator.opstring, **kw) + element, " " + operator.opstring, **kw + ) def _generate_generic_binary( - self, binary, opstring, eager_grouping=False, **kw): + self, binary, opstring, eager_grouping=False, **kw + ): - _in_binary = kw.get('_in_binary', False) + _in_binary = kw.get("_in_binary", False) - kw['_in_binary'] = True - text = binary.left._compiler_dispatch( - self, eager_grouping=eager_grouping, **kw) + \ - opstring + \ - binary.right._compiler_dispatch( - self, eager_grouping=eager_grouping, **kw) + kw["_in_binary"] = True + text = ( + binary.left._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + + opstring + + binary.right._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + ) if _in_binary and eager_grouping: text = "(%s)" % text @@ -1153,17 +1336,13 @@ class SQLCompiler(Compiled): def visit_startswith_op_binary(self, binary, operator, **kw): binary = binary._clone() percent = self._like_percent_literal - binary.right = percent.__radd__( - binary.right - ) + binary.right = percent.__radd__(binary.right) return self.visit_like_op_binary(binary, operator, **kw) def visit_notstartswith_op_binary(self, binary, operator, **kw): binary = binary._clone() percent = self._like_percent_literal - binary.right = percent.__radd__( - binary.right - ) + binary.right = percent.__radd__(binary.right) return self.visit_notlike_op_binary(binary, operator, **kw) def visit_endswith_op_binary(self, binary, operator, **kw): @@ -1182,98 +1361,105 @@ class SQLCompiler(Compiled): escape = binary.modifiers.get("escape", None) # TODO: use ternary here, not "and"/ "or" - return '%s LIKE %s' % ( + return "%s LIKE %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + 
self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_notlike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return '%s NOT LIKE %s' % ( + return "%s NOT LIKE %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return 'lower(%s) LIKE lower(%s)' % ( + return "lower(%s) LIKE lower(%s)" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return 'lower(%s) NOT LIKE lower(%s)' % ( + return "lower(%s) NOT LIKE lower(%s)" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_between_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) return self._generate_generic_binary( - binary, " BETWEEN SYMMETRIC " - if symmetric else " BETWEEN ", **kw) + binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw + ) def visit_notbetween_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) return self._generate_generic_binary( - binary, " NOT BETWEEN SYMMETRIC " - if symmetric else " NOT BETWEEN ", **kw) + binary, + " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ", + **kw + ) - def visit_bindparam(self, bindparam, within_columns_clause=False, - literal_binds=False, - skip_bind_expression=False, - **kwargs): + def visit_bindparam( + self, + bindparam, + within_columns_clause=False, + literal_binds=False, + skip_bind_expression=False, + **kwargs + ): if not skip_bind_expression: impl = bindparam.type.dialect_impl(self.dialect) if impl._has_bind_expression: bind_expression = impl.bind_expression(bindparam) return self.process( - bind_expression, skip_bind_expression=True, + bind_expression, + skip_bind_expression=True, within_columns_clause=within_columns_clause, literal_binds=literal_binds, **kwargs ) - if literal_binds or \ - (within_columns_clause and - self.ansi_bind_rules): + if literal_binds or (within_columns_clause and self.ansi_bind_rules): if bindparam.value is None and bindparam.callable is None: - raise exc.CompileError("Bind parameter '%s' without a " - "renderable value not allowed here." - % bindparam.key) + raise exc.CompileError( + "Bind parameter '%s' without a " + "renderable value not allowed here." 
% bindparam.key + ) return self.render_literal_bindparam( - bindparam, within_columns_clause=True, **kwargs) + bindparam, within_columns_clause=True, **kwargs + ) name = self._truncate_bindparam(bindparam) if name in self.binds: existing = self.binds[name] if existing is not bindparam: - if (existing.unique or bindparam.unique) and \ - not existing.proxy_set.intersection( - bindparam.proxy_set): + if ( + existing.unique or bindparam.unique + ) and not existing.proxy_set.intersection(bindparam.proxy_set): raise exc.CompileError( "Bind parameter '%s' conflicts with " - "unique bind parameter of the same name" % - bindparam.key + "unique bind parameter of the same name" + % bindparam.key ) elif existing._is_crud or bindparam._is_crud: raise exc.CompileError( @@ -1282,14 +1468,15 @@ class SQLCompiler(Compiled): "clause of this " "insert/update statement. Please use a " "name other than column name when using bindparam() " - "with insert() or update() (for example, 'b_%s')." % - (bindparam.key, bindparam.key) + "with insert() or update() (for example, 'b_%s')." + % (bindparam.key, bindparam.key) ) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string( - name, expanding=bindparam.expanding, **kwargs) + name, expanding=bindparam.expanding, **kwargs + ) def render_literal_bindparam(self, bindparam, **kw): value = bindparam.effective_value @@ -1311,7 +1498,8 @@ class SQLCompiler(Compiled): return processor(value) else: raise NotImplementedError( - "Don't know how to literal-quote value %r" % value) + "Don't know how to literal-quote value %r" % value + ) def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: @@ -1334,8 +1522,11 @@ class SQLCompiler(Compiled): if len(anonname) > self.label_length - 6: counter = self.truncated_names.get(ident_class, 1) - truncname = anonname[0:max(self.label_length - 6, 0)] + \ - "_" + hex(counter)[2:] + truncname = ( + anonname[0 : max(self.label_length - 6, 0)] + + "_" + + hex(counter)[2:] + ) self.truncated_names[ident_class] = counter + 1 else: truncname = anonname @@ -1346,13 +1537,14 @@ class SQLCompiler(Compiled): return name % self.anon_map def _process_anon(self, key): - (ident, derived) = key.split(' ', 1) + (ident, derived) = key.split(" ", 1) anonymous_counter = self.anon_map.get(derived, 1) self.anon_map[derived] = anonymous_counter + 1 return derived + "_" + str(anonymous_counter) def bindparam_string( - self, name, positional_names=None, expanding=False, **kw): + self, name, positional_names=None, expanding=False, **kw + ): if self.positional: if positional_names is not None: positional_names.append(name) @@ -1362,14 +1554,20 @@ class SQLCompiler(Compiled): self.contains_expanding_parameters = True return "([EXPANDING_%s])" % name else: - return self.bindtemplate % {'name': name} - - def visit_cte(self, cte, asfrom=False, ashint=False, - fromhints=None, visiting_cte=None, - **kwargs): + return self.bindtemplate % {"name": name} + + def visit_cte( + self, + cte, + asfrom=False, + ashint=False, + fromhints=None, + visiting_cte=None, + **kwargs + ): self._init_cte_state() - kwargs['visiting_cte'] = cte + kwargs["visiting_cte"] = cte if isinstance(cte.name, elements._truncated_label): cte_name = self._truncated_identifier("alias", cte.name) else: @@ -1394,8 +1592,8 @@ class SQLCompiler(Compiled): else: raise exc.CompileError( "Multiple, unrelated CTEs found with " - "the same name: %r" % - cte_name) + "the same name: %r" % cte_name + ) if asfrom or is_new_cte: if cte._cte_alias is not None: @@ -1403,7 
+1601,8 @@ class SQLCompiler(Compiled): cte_pre_alias_name = cte._cte_alias.name if isinstance(cte_pre_alias_name, elements._truncated_label): cte_pre_alias_name = self._truncated_identifier( - "alias", cte_pre_alias_name) + "alias", cte_pre_alias_name + ) else: pre_alias_cte = cte cte_pre_alias_name = None @@ -1412,11 +1611,17 @@ class SQLCompiler(Compiled): self.ctes_by_name[cte_name] = cte # look for embedded DML ctes and propagate autocommit - if 'autocommit' in cte.element._execution_options and \ - 'autocommit' not in self.execution_options: + if ( + "autocommit" in cte.element._execution_options + and "autocommit" not in self.execution_options + ): self.execution_options = self.execution_options.union( - {"autocommit": - cte.element._execution_options['autocommit']}) + { + "autocommit": cte.element._execution_options[ + "autocommit" + ] + } + ) if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -1432,25 +1637,30 @@ class SQLCompiler(Compiled): col_source = cte.original.selects[0] else: assert False - recur_cols = [c for c in - util.unique_list(col_source.inner_columns) - if c is not None] - - text += "(%s)" % (", ".join( - self.preparer.format_column(ident) - for ident in recur_cols)) + recur_cols = [ + c + for c in util.unique_list(col_source.inner_columns) + if c is not None + ] + + text += "(%s)" % ( + ", ".join( + self.preparer.format_column(ident) + for ident in recur_cols + ) + ) if self.positional: - kwargs['positional_names'] = self.cte_positional[cte] = [] + kwargs["positional_names"] = self.cte_positional[cte] = [] - text += " AS \n" + \ - cte.original._compiler_dispatch( - self, asfrom=True, **kwargs - ) + text += " AS \n" + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) if cte._suffixes: text += " " + self._generate_prefixes( - cte, cte._suffixes, **kwargs) + cte, cte._suffixes, **kwargs + ) self.ctes[cte] = text @@ -1467,9 +1677,15 @@ class SQLCompiler(Compiled): else: return self.preparer.format_alias(cte, cte_name) - def visit_alias(self, alias, asfrom=False, ashint=False, - iscrud=False, - fromhints=None, **kwargs): + def visit_alias( + self, + alias, + asfrom=False, + ashint=False, + iscrud=False, + fromhints=None, + **kwargs + ): if asfrom or ashint: if isinstance(alias.name, elements._truncated_label): alias_name = self._truncated_identifier("alias", alias.name) @@ -1479,31 +1695,35 @@ class SQLCompiler(Compiled): if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: - ret = alias.original._compiler_dispatch(self, - asfrom=True, **kwargs) + \ - self.get_render_as_alias_suffix( - self.preparer.format_alias(alias, alias_name)) + ret = alias.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name) + ) if fromhints and alias in fromhints: - ret = self.format_from_hint_text(ret, alias, - fromhints[alias], iscrud) + ret = self.format_from_hint_text( + ret, alias, fromhints[alias], iscrud + ) return ret else: return alias.original._compiler_dispatch(self, **kwargs) def visit_lateral(self, lateral, **kw): - kw['lateral'] = True + kw["lateral"] = True return "LATERAL %s" % self.visit_alias(lateral, **kw) def visit_tablesample(self, tablesample, asfrom=False, **kw): text = "%s TABLESAMPLE %s" % ( self.visit_alias(tablesample, asfrom=True, **kw), - tablesample._get_method()._compiler_dispatch(self, **kw)) + tablesample._get_method()._compiler_dispatch(self, **kw), + ) if tablesample.seed is not None: text += " 
REPEATABLE (%s)" % ( - tablesample.seed._compiler_dispatch(self, **kw)) + tablesample.seed._compiler_dispatch(self, **kw) + ) return text @@ -1513,22 +1733,27 @@ class SQLCompiler(Compiled): def _add_to_result_map(self, keyname, name, objects, type_): self._result_columns.append((keyname, name, objects, type_)) - def _label_select_column(self, select, column, - populate_result_map, - asfrom, column_clause_args, - name=None, - within_columns_clause=True): + def _label_select_column( + self, + select, + column, + populate_result_map, + asfrom, + column_clause_args, + name=None, + within_columns_clause=True, + ): """produce labeled columns present in a select().""" impl = column.type.dialect_impl(self.dialect) - if impl._has_column_expression and \ - populate_result_map: + if impl._has_column_expression and populate_result_map: col_expr = impl.column_expression(column) def add_to_result_map(keyname, name, objects, type_): self._add_to_result_map( - keyname, name, - (column,) + objects, type_) + keyname, name, (column,) + objects, type_ + ) + else: col_expr = column if populate_result_map: @@ -1541,58 +1766,56 @@ class SQLCompiler(Compiled): elif isinstance(column, elements.Label): if col_expr is not column: result_expr = _CompileLabel( - col_expr, - column.name, - alt_names=(column.element,) + col_expr, column.name, alt_names=(column.element,) ) else: result_expr = col_expr elif select is not None and name: result_expr = _CompileLabel( + col_expr, name, alt_names=(column._key_label,) + ) + + elif ( + asfrom + and isinstance(column, elements.ColumnClause) + and not column.is_literal + and column.table is not None + and not isinstance(column.table, selectable.Select) + ): + result_expr = _CompileLabel( col_expr, - name, - alt_names=(column._key_label,) - ) - - elif \ - asfrom and \ - isinstance(column, elements.ColumnClause) and \ - not column.is_literal and \ - column.table is not None and \ - not isinstance(column.table, selectable.Select): - result_expr = _CompileLabel(col_expr, - elements._as_truncated(column.name), - alt_names=(column.key,)) + elements._as_truncated(column.name), + alt_names=(column.key,), + ) elif ( - not isinstance(column, elements.TextClause) and - ( - not isinstance(column, elements.UnaryExpression) or - column.wraps_column_expression - ) and - ( - not hasattr(column, 'name') or - isinstance(column, functions.Function) + not isinstance(column, elements.TextClause) + and ( + not isinstance(column, elements.UnaryExpression) + or column.wraps_column_expression + ) + and ( + not hasattr(column, "name") + or isinstance(column, functions.Function) ) ): result_expr = _CompileLabel(col_expr, column.anon_label) elif col_expr is not column: # TODO: are we sure "column" has a .name and .key here ? 
# assert isinstance(column, elements.ColumnClause) - result_expr = _CompileLabel(col_expr, - elements._as_truncated(column.name), - alt_names=(column.key,)) + result_expr = _CompileLabel( + col_expr, + elements._as_truncated(column.name), + alt_names=(column.key,), + ) else: result_expr = col_expr column_clause_args.update( within_columns_clause=within_columns_clause, - add_to_result_map=add_to_result_map - ) - return result_expr._compiler_dispatch( - self, - **column_clause_args + add_to_result_map=add_to_result_map, ) + return result_expr._compiler_dispatch(self, **column_clause_args) def format_from_hint_text(self, sqltext, table, hint, iscrud): hinttext = self.get_from_hint_text(table, hint) @@ -1631,8 +1854,11 @@ class SQLCompiler(Compiled): newelem = cloned[element] = element._clone() - if newelem.is_selectable and newelem._is_join and \ - isinstance(newelem.right, selectable.FromGrouping): + if ( + newelem.is_selectable + and newelem._is_join + and isinstance(newelem.right, selectable.FromGrouping) + ): newelem._reset_exported() newelem.left = visit(newelem.left, **kw) @@ -1640,8 +1866,8 @@ class SQLCompiler(Compiled): right = visit(newelem.right, **kw) selectable_ = selectable.Select( - [right.element], - use_labels=True).alias() + [right.element], use_labels=True + ).alias() for c in selectable_.c: c._key_label = c.key @@ -1680,17 +1906,18 @@ class SQLCompiler(Compiled): elif newelem._is_from_container: # if we hit an Alias, CompoundSelect or ScalarSelect, put a # marker in the stack. - kw['transform_clue'] = 'select_container' + kw["transform_clue"] = "select_container" newelem._copy_internals(clone=visit, **kw) elif newelem.is_selectable and newelem._is_select: - barrier_select = kw.get('transform_clue', None) == \ - 'select_container' + barrier_select = ( + kw.get("transform_clue", None) == "select_container" + ) # if we're still descended from an # Alias/CompoundSelect/ScalarSelect, we're # in a FROM clause, so start with a new translate collection if barrier_select: column_translate.append({}) - kw['transform_clue'] = 'inside_select' + kw["transform_clue"] = "inside_select" newelem._copy_internals(clone=visit, **kw) if barrier_select: del column_translate[-1] @@ -1702,24 +1929,22 @@ class SQLCompiler(Compiled): return visit(select) def _transform_result_map_for_nested_joins( - self, select, transformed_select): - inner_col = dict((c._key_label, c) for - c in transformed_select.inner_columns) - - d = dict( - (inner_col[c._key_label], c) - for c in select.inner_columns + self, select, transformed_select + ): + inner_col = dict( + (c._key_label, c) for c in transformed_select.inner_columns ) + d = dict((inner_col[c._key_label], c) for c in select.inner_columns) + self._result_columns = [ (key, name, tuple([d.get(col, col) for col in objs]), typ) for key, name, objs, typ in self._result_columns ] - _default_stack_entry = util.immutabledict([ - ('correlate_froms', frozenset()), - ('asfrom_froms', frozenset()) - ]) + _default_stack_entry = util.immutabledict( + [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] + ) def _display_froms_for_select(self, select, asfrom, lateral=False): # utility method to help external dialects @@ -1729,72 +1954,88 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - correlate_froms = entry['correlate_froms'] - asfrom_froms = entry['asfrom_froms'] + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] if asfrom and not lateral: froms = 
select._get_display_froms( explicit_correlate_froms=correlate_froms.difference( - asfrom_froms), - implicit_correlate_froms=()) + asfrom_froms + ), + implicit_correlate_froms=(), + ) else: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms, - implicit_correlate_froms=asfrom_froms) + implicit_correlate_froms=asfrom_froms, + ) return froms - def visit_select(self, select, asfrom=False, parens=True, - fromhints=None, - compound_index=0, - nested_join_translation=False, - select_wraps_for=None, - lateral=False, - **kwargs): - - needs_nested_translation = \ - select.use_labels and \ - not nested_join_translation and \ - not self.stack and \ - not self.dialect.supports_right_nested_joins + def visit_select( + self, + select, + asfrom=False, + parens=True, + fromhints=None, + compound_index=0, + nested_join_translation=False, + select_wraps_for=None, + lateral=False, + **kwargs + ): + + needs_nested_translation = ( + select.use_labels + and not nested_join_translation + and not self.stack + and not self.dialect.supports_right_nested_joins + ) if needs_nested_translation: transformed_select = self._transform_select_for_nested_joins( - select) + select + ) text = self.visit_select( - transformed_select, asfrom=asfrom, parens=parens, + transformed_select, + asfrom=asfrom, + parens=parens, fromhints=fromhints, compound_index=compound_index, - nested_join_translation=True, **kwargs + nested_join_translation=True, + **kwargs ) toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - populate_result_map = toplevel or \ - ( - compound_index == 0 and entry.get( - 'need_result_map_for_compound', False) - ) or entry.get('need_result_map_for_nested', False) + populate_result_map = ( + toplevel + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) # this was first proposed as part of #3372; however, it is not # reached in current tests and could possibly be an assertion # instead. - if not populate_result_map and 'add_to_result_map' in kwargs: - del kwargs['add_to_result_map'] + if not populate_result_map and "add_to_result_map" in kwargs: + del kwargs["add_to_result_map"] if needs_nested_translation: if populate_result_map: self._transform_result_map_for_nested_joins( - select, transformed_select) + select, transformed_select + ) return text froms = self._setup_select_stack(select, entry, asfrom, lateral) column_clause_args = kwargs.copy() - column_clause_args.update({ - 'within_label_clause': False, - 'within_columns_clause': False - }) + column_clause_args.update( + {"within_label_clause": False, "within_columns_clause": False} + ) text = "SELECT " # we're off to a good start ! @@ -1806,19 +2047,21 @@ class SQLCompiler(Compiled): byfrom = None if select._prefixes: - text += self._generate_prefixes( - select, select._prefixes, **kwargs) + text += self._generate_prefixes(select, select._prefixes, **kwargs) text += self.get_select_precolumns(select, **kwargs) # the actual list of columns to print in the SELECT column list. 
inner_columns = [ - c for c in [ + c + for c in [ self._label_select_column( select, column, - populate_result_map, asfrom, + populate_result_map, + asfrom, column_clause_args, - name=name) + name=name, + ) for name, column in select._columns_plus_names ] if c is not None @@ -1831,8 +2074,11 @@ class SQLCompiler(Compiled): translate = dict( zip( [name for (key, name) in select._columns_plus_names], - [name for (key, name) in - select_wraps_for._columns_plus_names]) + [ + name + for (key, name) in select_wraps_for._columns_plus_names + ], + ) ) self._result_columns = [ @@ -1841,13 +2087,14 @@ class SQLCompiler(Compiled): ] text = self._compose_select_body( - text, select, inner_columns, froms, byfrom, kwargs) + text, select, inner_columns, froms, byfrom, kwargs + ) if select._statement_hints: per_dialect = [ - ht for (dialect_name, ht) - in select._statement_hints - if dialect_name in ('*', self.dialect.name) + ht + for (dialect_name, ht) in select._statement_hints + if dialect_name in ("*", self.dialect.name) ] if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) @@ -1857,7 +2104,8 @@ class SQLCompiler(Compiled): if select._suffixes: text += " " + self._generate_prefixes( - select, select._suffixes, **kwargs) + select, select._suffixes, **kwargs + ) self.stack.pop(-1) @@ -1867,60 +2115,73 @@ class SQLCompiler(Compiled): return text def _setup_select_hints(self, select): - byfrom = dict([ - (from_, hinttext % { - 'name': from_._compiler_dispatch( - self, ashint=True) - }) - for (from_, dialect), hinttext in - select._hints.items() - if dialect in ('*', self.dialect.name) - ]) + byfrom = dict( + [ + ( + from_, + hinttext + % {"name": from_._compiler_dispatch(self, ashint=True)}, + ) + for (from_, dialect), hinttext in select._hints.items() + if dialect in ("*", self.dialect.name) + ] + ) hint_text = self.get_select_hint_text(byfrom) return hint_text, byfrom def _setup_select_stack(self, select, entry, asfrom, lateral): - correlate_froms = entry['correlate_froms'] - asfrom_froms = entry['asfrom_froms'] + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] if asfrom and not lateral: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms.difference( - asfrom_froms), - implicit_correlate_froms=()) + asfrom_froms + ), + implicit_correlate_froms=(), + ) else: froms = select._get_display_froms( explicit_correlate_froms=correlate_froms, - implicit_correlate_froms=asfrom_froms) + implicit_correlate_froms=asfrom_froms, + ) new_correlate_froms = set(selectable._from_objects(*froms)) all_correlate_froms = new_correlate_froms.union(correlate_froms) new_entry = { - 'asfrom_froms': new_correlate_froms, - 'correlate_froms': all_correlate_froms, - 'selectable': select, + "asfrom_froms": new_correlate_froms, + "correlate_froms": all_correlate_froms, + "selectable": select, } self.stack.append(new_entry) return froms def _compose_select_body( - self, text, select, inner_columns, froms, byfrom, kwargs): - text += ', '.join(inner_columns) + self, text, select, inner_columns, froms, byfrom, kwargs + ): + text += ", ".join(inner_columns) if froms: text += " \nFROM " if select._hints: - text += ', '.join( - [f._compiler_dispatch(self, asfrom=True, - fromhints=byfrom, **kwargs) - for f in froms]) + text += ", ".join( + [ + f._compiler_dispatch( + self, asfrom=True, fromhints=byfrom, **kwargs + ) + for f in froms + ] + ) else: - text += ', '.join( - [f._compiler_dispatch(self, asfrom=True, **kwargs) - for f in froms]) + text += ", ".join( + [ + 
f._compiler_dispatch(self, asfrom=True, **kwargs) + for f in froms + ] + ) else: text += self.default_from() @@ -1940,8 +2201,10 @@ class SQLCompiler(Compiled): if select._order_by_clause.clauses: text += self.order_by_clause(select, **kwargs) - if (select._limit_clause is not None or - select._offset_clause is not None): + if ( + select._limit_clause is not None + or select._offset_clause is not None + ): text += self.limit_clause(select, **kwargs) if select._for_update_arg is not None: @@ -1953,8 +2216,7 @@ class SQLCompiler(Compiled): clause = " ".join( prefix._compiler_dispatch(self, **kw) for prefix, dialect_name in prefixes - if dialect_name is None or - dialect_name == self.dialect.name + if dialect_name is None or dialect_name == self.dialect.name ) if clause: clause += " " @@ -1962,14 +2224,12 @@ class SQLCompiler(Compiled): def _render_cte_clause(self): if self.positional: - self.positiontup = sum([ - self.cte_positional[cte] - for cte in self.ctes], []) + \ - self.positiontup + self.positiontup = ( + sum([self.cte_positional[cte] for cte in self.ctes], []) + + self.positiontup + ) cte_text = self.get_cte_preamble(self.ctes_recursive) + " " - cte_text += ", \n".join( - [txt for txt in self.ctes.values()] - ) + cte_text += ", \n".join([txt for txt in self.ctes.values()]) cte_text += "\n " return cte_text @@ -2010,7 +2270,8 @@ class SQLCompiler(Compiled): def returning_clause(self, stmt, returning_cols): raise exc.CompileError( "RETURNING is not supported by this " - "dialect's statement compiler.") + "dialect's statement compiler." + ) def limit_clause(self, select, **kw): text = "" @@ -2022,19 +2283,31 @@ class SQLCompiler(Compiled): text += " OFFSET " + self.process(select._offset_clause, **kw) return text - def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, - fromhints=None, use_schema=True, **kwargs): + def visit_table( + self, + table, + asfrom=False, + iscrud=False, + ashint=False, + fromhints=None, + use_schema=True, + **kwargs + ): if asfrom or ashint: effective_schema = self.preparer.schema_for_object(table) if use_schema and effective_schema: - ret = self.preparer.quote_schema(effective_schema) + \ - "." + self.preparer.quote(table.name) + ret = ( + self.preparer.quote_schema(effective_schema) + + "." 
+ + self.preparer.quote(table.name) + ) else: ret = self.preparer.quote(table.name) if fromhints and table in fromhints: - ret = self.format_from_hint_text(ret, table, - fromhints[table], iscrud) + ret = self.format_from_hint_text( + ret, table, fromhints[table], iscrud + ) return ret else: return "" @@ -2047,26 +2320,24 @@ class SQLCompiler(Compiled): else: join_type = " JOIN " return ( - join.left._compiler_dispatch(self, asfrom=True, **kwargs) + - join_type + - join.right._compiler_dispatch(self, asfrom=True, **kwargs) + - " ON " + - join.onclause._compiler_dispatch(self, **kwargs) + join.left._compiler_dispatch(self, asfrom=True, **kwargs) + + join_type + + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + + " ON " + + join.onclause._compiler_dispatch(self, **kwargs) ) def _setup_crud_hints(self, stmt, table_text): - dialect_hints = dict([ - (table, hint_text) - for (table, dialect), hint_text in - stmt._hints.items() - if dialect in ('*', self.dialect.name) - ]) + dialect_hints = dict( + [ + (table, hint_text) + for (table, dialect), hint_text in stmt._hints.items() + if dialect in ("*", self.dialect.name) + ] + ) if stmt.table in dialect_hints: table_text = self.format_from_hint_text( - table_text, - stmt.table, - dialect_hints[stmt.table], - True + table_text, stmt.table, dialect_hints[stmt.table], True ) return dialect_hints, table_text @@ -2074,28 +2345,35 @@ class SQLCompiler(Compiled): toplevel = not self.stack self.stack.append( - {'correlate_froms': set(), - "asfrom_froms": set(), - "selectable": insert_stmt}) + { + "correlate_froms": set(), + "asfrom_froms": set(), + "selectable": insert_stmt, + } + ) crud_params = crud._setup_crud_params( - self, insert_stmt, crud.ISINSERT, **kw) + self, insert_stmt, crud.ISINSERT, **kw + ) - if not crud_params and \ - not self.dialect.supports_default_values and \ - not self.dialect.supports_empty_insert: - raise exc.CompileError("The '%s' dialect with current database " - "version settings does not support empty " - "inserts." % - self.dialect.name) + if ( + not crud_params + and not self.dialect.supports_default_values + and not self.dialect.supports_empty_insert + ): + raise exc.CompileError( + "The '%s' dialect with current database " + "version settings does not support empty " + "inserts." % self.dialect.name + ) if insert_stmt._has_multi_parameters: if not self.dialect.supports_multivalues_insert: raise exc.CompileError( "The '%s' dialect with current database " "version settings does not support " - "in-place multirow inserts." % - self.dialect.name) + "in-place multirow inserts." 
% self.dialect.name + ) crud_params_single = crud_params[0] else: crud_params_single = crud_params @@ -2106,27 +2384,31 @@ class SQLCompiler(Compiled): text = "INSERT " if insert_stmt._prefixes: - text += self._generate_prefixes(insert_stmt, - insert_stmt._prefixes, **kw) + text += self._generate_prefixes( + insert_stmt, insert_stmt._prefixes, **kw + ) text += "INTO " table_text = preparer.format_table(insert_stmt.table) if insert_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( - insert_stmt, table_text) + insert_stmt, table_text + ) else: dialect_hints = None text += table_text if crud_params_single or not supports_default_values: - text += " (%s)" % ', '.join([preparer.format_column(c[0]) - for c in crud_params_single]) + text += " (%s)" % ", ".join( + [preparer.format_column(c[0]) for c in crud_params_single] + ) if self.returning or insert_stmt._returning: returning_clause = self.returning_clause( - insert_stmt, self.returning or insert_stmt._returning) + insert_stmt, self.returning or insert_stmt._returning + ) if self.returning_precedes_values: text += " " + returning_clause @@ -2145,19 +2427,17 @@ class SQLCompiler(Compiled): elif insert_stmt._has_multi_parameters: text += " VALUES %s" % ( ", ".join( - "(%s)" % ( - ', '.join(c[1] for c in crud_param_set) - ) + "(%s)" % (", ".join(c[1] for c in crud_param_set)) for crud_param_set in crud_params ) ) else: - text += " VALUES (%s)" % \ - ', '.join([c[1] for c in crud_params]) + text += " VALUES (%s)" % ", ".join([c[1] for c in crud_params]) if insert_stmt._post_values_clause is not None: post_values_clause = self.process( - insert_stmt._post_values_clause, **kw) + insert_stmt._post_values_clause, **kw + ) if post_values_clause: text += " " + post_values_clause @@ -2178,21 +2458,19 @@ class SQLCompiler(Compiled): """Provide a hook for MySQL to add LIMIT to the UPDATE""" return None - def update_tables_clause(self, update_stmt, from_table, - extra_froms, **kw): + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): """Provide a hook to override the initial table clause in an UPDATE statement. MySQL overrides this. """ - kw['asfrom'] = True + kw["asfrom"] = True return from_table._compiler_dispatch(self, iscrud=True, **kw) - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): """Provide a hook to override the generation of an UPDATE..FROM clause. 
@@ -2201,7 +2479,8 @@ class SQLCompiler(Compiled): """ raise NotImplementedError( "This backend does not support multiple-table " - "criteria within UPDATE") + "criteria within UPDATE" + ) def visit_update(self, update_stmt, asfrom=False, **kw): toplevel = not self.stack @@ -2221,49 +2500,61 @@ class SQLCompiler(Compiled): correlate_froms = {update_stmt.table} self.stack.append( - {'correlate_froms': correlate_froms, - "asfrom_froms": correlate_froms, - "selectable": update_stmt}) + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": update_stmt, + } + ) text = "UPDATE " if update_stmt._prefixes: - text += self._generate_prefixes(update_stmt, - update_stmt._prefixes, **kw) + text += self._generate_prefixes( + update_stmt, update_stmt._prefixes, **kw + ) - table_text = self.update_tables_clause(update_stmt, update_stmt.table, - render_extra_froms, **kw) + table_text = self.update_tables_clause( + update_stmt, update_stmt.table, render_extra_froms, **kw + ) crud_params = crud._setup_crud_params( - self, update_stmt, crud.ISUPDATE, **kw) + self, update_stmt, crud.ISUPDATE, **kw + ) if update_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( - update_stmt, table_text) + update_stmt, table_text + ) else: dialect_hints = None text += table_text - text += ' SET ' - include_table = is_multitable and \ - self.render_table_with_column_in_update_from - text += ', '.join( - c[0]._compiler_dispatch(self, - include_table=include_table) + - '=' + c[1] for c in crud_params + text += " SET " + include_table = ( + is_multitable and self.render_table_with_column_in_update_from + ) + text += ", ".join( + c[0]._compiler_dispatch(self, include_table=include_table) + + "=" + + c[1] + for c in crud_params ) if self.returning or update_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning or update_stmt._returning) + update_stmt, self.returning or update_stmt._returning + ) if extra_froms: extra_from_text = self.update_from_clause( update_stmt, update_stmt.table, render_extra_froms, - dialect_hints, **kw) + dialect_hints, + **kw + ) if extra_from_text: text += " " + extra_from_text @@ -2276,10 +2567,12 @@ class SQLCompiler(Compiled): if limit_clause: text += " " + limit_clause - if (self.returning or update_stmt._returning) and \ - not self.returning_precedes_values: + if ( + self.returning or update_stmt._returning + ) and not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning or update_stmt._returning) + update_stmt, self.returning or update_stmt._returning + ) if self.ctes and toplevel: text = self._render_cte_clause() + text @@ -2295,9 +2588,9 @@ class SQLCompiler(Compiled): def _key_getters_for_crud_column(self): return crud._key_getters_for_crud_column(self, self.statement) - def delete_extra_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, **kw): + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): """Provide a hook to override the generation of an DELETE..FROM clause. 
@@ -2308,10 +2601,10 @@ class SQLCompiler(Compiled): """ raise NotImplementedError( "This backend does not support multiple-table " - "criteria within DELETE") + "criteria within DELETE" + ) - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) def visit_delete(self, delete_stmt, asfrom=False, **kw): @@ -2322,23 +2615,30 @@ class SQLCompiler(Compiled): extra_froms = delete_stmt._extra_froms correlate_froms = {delete_stmt.table}.union(extra_froms) - self.stack.append({'correlate_froms': correlate_froms, - "asfrom_froms": correlate_froms, - "selectable": delete_stmt}) + self.stack.append( + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": delete_stmt, + } + ) text = "DELETE " if delete_stmt._prefixes: - text += self._generate_prefixes(delete_stmt, - delete_stmt._prefixes, **kw) + text += self._generate_prefixes( + delete_stmt, delete_stmt._prefixes, **kw + ) text += "FROM " - table_text = self.delete_table_clause(delete_stmt, delete_stmt.table, - extra_froms) + table_text = self.delete_table_clause( + delete_stmt, delete_stmt.table, extra_froms + ) if delete_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( - delete_stmt, table_text) + delete_stmt, table_text + ) else: dialect_hints = None @@ -2347,14 +2647,17 @@ class SQLCompiler(Compiled): if delete_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning) + delete_stmt, delete_stmt._returning + ) if extra_froms: extra_from_text = self.delete_extra_from_clause( delete_stmt, delete_stmt.table, extra_froms, - dialect_hints, **kw) + dialect_hints, + **kw + ) if extra_from_text: text += " " + extra_from_text @@ -2365,7 +2668,8 @@ class SQLCompiler(Compiled): if delete_stmt._returning and not self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning) + delete_stmt, delete_stmt._returning + ) if self.ctes and toplevel: text = self._render_cte_clause() + text @@ -2381,12 +2685,14 @@ class SQLCompiler(Compiled): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def visit_rollback_to_savepoint(self, savepoint_stmt): - return "ROLLBACK TO SAVEPOINT %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) def visit_release_savepoint(self, savepoint_stmt): - return "RELEASE SAVEPOINT %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) class StrSQLCompiler(SQLCompiler): @@ -2403,7 +2709,7 @@ class StrSQLCompiler(SQLCompiler): def visit_getitem_binary(self, binary, operator, **kw): return "%s[%s]" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def visit_json_getitem_op_binary(self, binary, operator, **kw): @@ -2421,29 +2727,26 @@ class StrSQLCompiler(SQLCompiler): for c in elements._select_iterables(returning_cols) ] - return 'RETURNING ' + ', '.join(columns) + return "RETURNING " + ", ".join(columns) - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + def update_from_clause( + self, update_stmt, 
from_table, extra_froms, from_hints, **kw + ): + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) - def delete_extra_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): - return ', ' + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + return ", " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) class DDLCompiler(Compiled): - @util.memoized_property def sql_compiler(self): return self.dialect.statement_compiler(self.dialect, None) @@ -2464,13 +2767,13 @@ class DDLCompiler(Compiled): preparer = self.preparer path = preparer.format_table_seq(ddl.target) if len(path) == 1: - table, sch = path[0], '' + table, sch = path[0], "" else: table, sch = path[-1], path[0] - context.setdefault('table', table) - context.setdefault('schema', sch) - context.setdefault('fullname', preparer.format_table(ddl.target)) + context.setdefault("table", table) + context.setdefault("schema", sch) + context.setdefault("fullname", preparer.format_table(ddl.target)) return self.sql_compiler.post_process_text(ddl.statement % context) @@ -2507,9 +2810,9 @@ class DDLCompiler(Compiled): for create_column in create.columns: column = create_column.element try: - processed = self.process(create_column, - first_pk=column.primary_key - and not first_pk) + processed = self.process( + create_column, first_pk=column.primary_key and not first_pk + ) if processed is not None: text += separator separator = ", \n" @@ -2519,13 +2822,15 @@ class DDLCompiler(Compiled): except exc.CompileError as ce: util.raise_from_cause( exc.CompileError( - util.u("(in table '%s', column '%s'): %s") % - (table.description, column.name, ce.args[0]) - )) + util.u("(in table '%s', column '%s'): %s") + % (table.description, column.name, ce.args[0]) + ) + ) const = self.create_table_constraints( - table, _include_foreign_key_constraints= # noqa - create.include_foreign_key_constraints) + table, + _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa + ) if const: text += separator + "\t" + const @@ -2538,20 +2843,18 @@ class DDLCompiler(Compiled): if column.system: return None - text = self.get_column_specification( - column, - first_pk=first_pk + text = self.get_column_specification(column, first_pk=first_pk) + const = " ".join( + self.process(constraint) for constraint in column.constraints ) - const = " ".join(self.process(constraint) - for constraint in column.constraints) if const: text += " " + const return text def create_table_constraints( - self, table, - _include_foreign_key_constraints=None): + self, table, _include_foreign_key_constraints=None + ): # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) @@ -2565,21 +2868,29 @@ class DDLCompiler(Compiled): else: omit_fkcs = set() - constraints.extend([c for c in table._sorted_constraints - if c is not table.primary_key and - c not in omit_fkcs]) + constraints.extend( + [ + c + for c in table._sorted_constraints + if c is not table.primary_key and c not in omit_fkcs + ] + ) return ", \n\t".join( - p for p in - (self.process(constraint) + p + for p in ( + self.process(constraint) for constraint in constraints if ( - constraint._create_rule is None or - constraint._create_rule(self)) + 
constraint._create_rule is None + or constraint._create_rule(self) + ) and ( - not self.dialect.supports_alter or - not getattr(constraint, 'use_alter', False) - )) if p is not None + not self.dialect.supports_alter + or not getattr(constraint, "use_alter", False) + ) + ) + if p is not None ) def visit_drop_table(self, drop): @@ -2590,34 +2901,38 @@ class DDLCompiler(Compiled): def _verify_index_table(self, index): if index.table is None: - raise exc.CompileError("Index '%s' is not associated " - "with any table." % index.name) + raise exc.CompileError( + "Index '%s' is not associated " "with any table." % index.name + ) - def visit_create_index(self, create, include_schema=False, - include_table_schema=True): + def visit_create_index( + self, create, include_schema=False, include_table_schema=True + ): index = create.element self._verify_index_table(index) preparer = self.preparer text = "CREATE " if index.unique: text += "UNIQUE " - text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=include_schema), - preparer.format_table(index.table, - use_schema=include_table_schema), - ', '.join( - self.sql_compiler.process( - expr, include_table=False, literal_binds=True) for - expr in index.expressions) - ) + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table( + index.table, use_schema=include_table_schema + ), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) return text def visit_drop_index(self, drop): index = drop.element return "\nDROP INDEX " + self._prepared_index_name( - index, include_schema=True) + index, include_schema=True + ) def _prepared_index_name(self, index, include_schema=False): if index.table is not None: @@ -2638,35 +2953,41 @@ class DDLCompiler(Compiled): def visit_add_constraint(self, create): return "ALTER TABLE %s ADD %s" % ( self.preparer.format_table(create.element.table), - self.process(create.element) + self.process(create.element), ) def visit_set_table_comment(self, create): return "COMMENT ON TABLE %s IS %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String()) + create.element.comment, sqltypes.String() + ), ) def visit_drop_table_comment(self, drop): - return "COMMENT ON TABLE %s IS NULL" % \ - self.preparer.format_table(drop.element) + return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( + drop.element + ) def visit_set_column_comment(self, create): return "COMMENT ON COLUMN %s IS %s" % ( self.preparer.format_column( - create.element, use_table=True, use_schema=True), + create.element, use_table=True, use_schema=True + ), self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String()) + create.element.comment, sqltypes.String() + ), ) def visit_drop_column_comment(self, drop): - return "COMMENT ON COLUMN %s IS NULL" % \ - self.preparer.format_column(drop.element, use_table=True) + return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column( + drop.element, use_table=True + ) def visit_create_sequence(self, create): - text = "CREATE SEQUENCE %s" % \ - self.preparer.format_sequence(create.element) + text = "CREATE SEQUENCE %s" % self.preparer.format_sequence( + create.element + ) if create.element.increment is not None: text += " INCREMENT BY %d" % create.element.increment if create.element.start is not None: @@ -2688,8 +3009,7 @@ class 
DDLCompiler(Compiled): return text def visit_drop_sequence(self, drop): - return "DROP SEQUENCE %s" % \ - self.preparer.format_sequence(drop.element) + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) def visit_drop_constraint(self, drop): constraint = drop.element @@ -2701,17 +3021,22 @@ class DDLCompiler(Compiled): if formatted_name is None: raise exc.CompileError( "Can't emit DROP CONSTRAINT for constraint %r; " - "it has no name" % drop.element) + "it has no name" % drop.element + ) return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( self.preparer.format_table(drop.element.table), formatted_name, - drop.cascade and " CASCADE" or "" + drop.cascade and " CASCADE" or "", ) def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process( - column.type, type_expression=column) + colspec = ( + self.preparer.format_column(column) + + " " + + self.dialect.type_compiler.process( + column.type, type_expression=column + ) + ) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -2721,19 +3046,21 @@ class DDLCompiler(Compiled): return colspec def create_table_suffix(self, table): - return '' + return "" def post_create_table(self, table): - return '' + return "" def get_column_default_string(self, column): if isinstance(column.server_default, schema.DefaultClause): if isinstance(column.server_default.arg, util.string_types): return self.sql_compiler.render_literal_value( - column.server_default.arg, sqltypes.STRINGTYPE) + column.server_default.arg, sqltypes.STRINGTYPE + ) else: return self.sql_compiler.process( - column.server_default.arg, literal_binds=True) + column.server_default.arg, literal_binds=True + ) else: return None @@ -2743,9 +3070,9 @@ class DDLCompiler(Compiled): formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name - text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext, - include_table=False, - literal_binds=True) + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) text += self.define_constraint_deferrability(constraint) return text @@ -2755,25 +3082,29 @@ class DDLCompiler(Compiled): formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name - text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext, - include_table=False, - literal_binds=True) + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) text += self.define_constraint_deferrability(constraint) return text def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name text += "PRIMARY KEY " - text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in (constraint.columns_autoinc_first - if constraint._implicit_generated - else constraint.columns)) + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) + for c in ( + constraint.columns_autoinc_first + if constraint._implicit_generated + else constraint.columns + ) + ) text += self.define_constraint_deferrability(constraint) return text @@ -2786,12 +3117,15 @@ class DDLCompiler(Compiled): 
text += "CONSTRAINT %s " % formatted_name remote_table = list(constraint.elements)[0].column.table text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - ', '.join(preparer.quote(f.parent.name) - for f in constraint.elements), + ", ".join( + preparer.quote(f.parent.name) for f in constraint.elements + ), self.define_constraint_remote_table( - constraint, remote_table, preparer), - ', '.join(preparer.quote(f.column.name) - for f in constraint.elements) + constraint, remote_table, preparer + ), + ", ".join( + preparer.quote(f.column.name) for f in constraint.elements + ), ) text += self.define_constraint_match(constraint) text += self.define_constraint_cascades(constraint) @@ -2805,14 +3139,14 @@ class DDLCompiler(Compiled): def visit_unique_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) text += "CONSTRAINT %s " % formatted_name text += "UNIQUE (%s)" % ( - ', '.join(self.preparer.quote(c.name) - for c in constraint)) + ", ".join(self.preparer.quote(c.name) for c in constraint) + ) text += self.define_constraint_deferrability(constraint) return text @@ -2843,7 +3177,6 @@ class DDLCompiler(Compiled): class GenericTypeCompiler(TypeCompiler): - def visit_FLOAT(self, type_, **kw): return "FLOAT" @@ -2854,23 +3187,23 @@ class GenericTypeCompiler(TypeCompiler): if type_.precision is None: return "NUMERIC" elif type_.scale is None: - return "NUMERIC(%(precision)s)" % \ - {'precision': type_.precision} + return "NUMERIC(%(precision)s)" % {"precision": type_.precision} else: - return "NUMERIC(%(precision)s, %(scale)s)" % \ - {'precision': type_.precision, - 'scale': type_.scale} + return "NUMERIC(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return "DECIMAL" elif type_.scale is None: - return "DECIMAL(%(precision)s)" % \ - {'precision': type_.precision} + return "DECIMAL(%(precision)s)" % {"precision": type_.precision} else: - return "DECIMAL(%(precision)s, %(scale)s)" % \ - {'precision': type_.precision, - 'scale': type_.scale} + return "DECIMAL(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } def visit_INTEGER(self, type_, **kw): return "INTEGER" @@ -2882,7 +3215,7 @@ class GenericTypeCompiler(TypeCompiler): return "BIGINT" def visit_TIMESTAMP(self, type_, **kw): - return 'TIMESTAMP' + return "TIMESTAMP" def visit_DATETIME(self, type_, **kw): return "DATETIME" @@ -2984,9 +3317,11 @@ class GenericTypeCompiler(TypeCompiler): return self.visit_VARCHAR(type_, **kw) def visit_null(self, type_, **kw): - raise exc.CompileError("Can't generate DDL for %r; " - "did you forget to specify a " - "type on this Column?" % type_) + raise exc.CompileError( + "Can't generate DDL for %r; " + "did you forget to specify a " + "type on this Column?" % type_ + ) def visit_type_decorator(self, type_, **kw): return self.process(type_.type_engine(self.dialect), **kw) @@ -3018,9 +3353,15 @@ class IdentifierPreparer(object): schema_for_object = schema._schema_getter(None) - def __init__(self, dialect, initial_quote='"', - final_quote=None, escape_quote='"', - quote_case_sensitive_collations=True, omit_schema=False): + def __init__( + self, + dialect, + initial_quote='"', + final_quote=None, + escape_quote='"', + quote_case_sensitive_collations=True, + omit_schema=False, + ): """Construct a new ``IdentifierPreparer`` object. 
initial_quote @@ -3043,7 +3384,10 @@ class IdentifierPreparer(object): self.omit_schema = omit_schema self.quote_case_sensitive_collations = quote_case_sensitive_collations self._strings = {} - self._double_percents = self.dialect.paramstyle in ('format', 'pyformat') + self._double_percents = self.dialect.paramstyle in ( + "format", + "pyformat", + ) def _with_schema_translate(self, schema_translate_map): prep = self.__class__.__new__(self.__class__) @@ -3060,7 +3404,7 @@ class IdentifierPreparer(object): value = value.replace(self.escape_quote, self.escape_to_quote) if self._double_percents: - value = value.replace('%', '%%') + value = value.replace("%", "%%") return value def _unescape_identifier(self, value): @@ -3079,17 +3423,21 @@ class IdentifierPreparer(object): quoting behavior. """ - return self.initial_quote + \ - self._escape_identifier(value) + \ - self.final_quote + return ( + self.initial_quote + + self._escape_identifier(value) + + self.final_quote + ) def _requires_quotes(self, value): """Return True if the given identifier requires quoting.""" lc_value = value.lower() - return (lc_value in self.reserved_words - or value[0] in self.illegal_initial_characters - or not self.legal_characters.match(util.text_type(value)) - or (lc_value != value)) + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(util.text_type(value)) + or (lc_value != value) + ) def quote_schema(self, schema, force=None): """Conditionally quote a schema. @@ -3135,8 +3483,11 @@ class IdentifierPreparer(object): effective_schema = self.schema_for_object(sequence) - if (not self.omit_schema and use_schema and - effective_schema is not None): + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): name = self.quote_schema(effective_schema) + "." + name return name @@ -3159,7 +3510,8 @@ class IdentifierPreparer(object): def format_constraint(self, naming, constraint): if isinstance(constraint.name, elements._defer_name): name = naming._constraint_name_for_table( - constraint, constraint.table) + constraint, constraint.table + ) if name is None: if isinstance(constraint.name, elements._defer_none_name): @@ -3170,14 +3522,15 @@ class IdentifierPreparer(object): name = constraint.name if isinstance(name, elements._truncated_label): - if constraint.__visit_name__ == 'index': - max_ = self.dialect.max_index_name_length or \ - self.dialect.max_identifier_length + if constraint.__visit_name__ == "index": + max_ = ( + self.dialect.max_index_name_length + or self.dialect.max_identifier_length + ) else: max_ = self.dialect.max_identifier_length if len(name) > max_: - name = name[0:max_ - 8] + \ - "_" + util.md5_hex(name)[-4:] + name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:] else: self.dialect.validate_identifier(name) @@ -3195,8 +3548,7 @@ class IdentifierPreparer(object): effective_schema = self.schema_for_object(table) - if not self.omit_schema and use_schema \ - and effective_schema: + if not self.omit_schema and use_schema and effective_schema: result = self.quote_schema(effective_schema) + "." 
+ result return result @@ -3205,17 +3557,27 @@ class IdentifierPreparer(object): return self.quote(name, quote) - def format_column(self, column, use_table=False, - name=None, table_name=None, use_schema=False): + def format_column( + self, + column, + use_table=False, + name=None, + table_name=None, + use_schema=False, + ): """Prepare a quoted column name.""" if name is None: name = column.name - if not getattr(column, 'is_literal', False): + if not getattr(column, "is_literal", False): if use_table: - return self.format_table( - column.table, use_schema=use_schema, - name=table_name) + "." + self.quote(name) + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + self.quote(name) + ) else: return self.quote(name) else: @@ -3223,9 +3585,13 @@ class IdentifierPreparer(object): # which shouldn't get quoted if use_table: - return self.format_table( - column.table, use_schema=use_schema, - name=table_name) + '.' + name + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + name + ) else: return name @@ -3238,31 +3604,37 @@ class IdentifierPreparer(object): effective_schema = self.schema_for_object(table) - if not self.omit_schema and use_schema and \ - effective_schema: - return (self.quote_schema(effective_schema), - self.format_table(table, use_schema=False)) + if not self.omit_schema and use_schema and effective_schema: + return ( + self.quote_schema(effective_schema), + self.format_table(table, use_schema=False), + ) else: - return (self.format_table(table, use_schema=False), ) + return (self.format_table(table, use_schema=False),) @util.memoized_property def _r_identifiers(self): - initial, final, escaped_final = \ - [re.escape(s) for s in - (self.initial_quote, self.final_quote, - self._escape_identifier(self.final_quote))] + initial, final, escaped_final = [ + re.escape(s) + for s in ( + self.initial_quote, + self.final_quote, + self._escape_identifier(self.final_quote), + ) + ] r = re.compile( - r'(?:' - r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' - r'|([^\.]+))(?=\.|$))+' % - {'initial': initial, - 'final': final, - 'escaped': escaped_final}) + r"(?:" + r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s" + r"|([^\.]+))(?=\.|$))+" + % {"initial": initial, "final": final, "escaped": escaped_final} + ) return r def unformat_identifiers(self, identifiers): """Unpack 'schema.table.column'-like strings into components.""" r = self._r_identifiers - return [self._unescape_identifier(i) - for i in [a or b for a, b in r.findall(identifiers)]] + return [ + self._unescape_identifier(i) + for i in [a or b for a, b in r.findall(identifiers)] + ] |
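
As an illustrative aside, the compiler and identifier-quoting code shown in this diff is normally driven indirectly, by compiling a statement through SQLAlchemy's public API rather than by calling the visit_* methods directly. The following is a minimal sketch using only that public API as it stood in the 1.x series; the table name "user" and the column names are invented for the example, and the exact rendered string depends on the dialect in use.

    from sqlalchemy import column, select, table

    # "user" is in the default preparer's reserved-word list, so
    # _requires_quotes() / quote() are expected to wrap it in double
    # quotes when the statement is rendered.
    users = table("user", column("id"), column("email"))

    stmt = select([users.c.id, users.c.email]).where(users.c.id == 5)

    # str(compiled) walks visit_select() and the related visit_*
    # methods; literal_binds renders the bound value inline instead
    # of emitting a dialect-specific placeholder.
    compiled = stmt.compile(compile_kwargs={"literal_binds": True})
    print(str(compiled))
    # Roughly:
    #   SELECT "user".id, "user".email
    #   FROM "user"
    #   WHERE "user".id = 5

Compiling against a specific dialect, for example stmt.compile(dialect=postgresql.dialect()) with postgresql imported from sqlalchemy.dialects, exercises the same code paths but with that dialect's compiler and IdentifierPreparer subclasses in place of the generic ones.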