diff options
author | Brian Jarrett <celttechie@gmail.com> | 2014-07-20 12:44:40 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-07-20 12:44:40 -0400 |
commit | cca03097f47f22783d42d1853faac6cf84607c5a (patch) | |
tree | 4fe1a63d03a2d88d1cf37e1167759dfaf84f4ce7 /lib/sqlalchemy/sql/compiler.py | |
parent | 827329a0cca5351094a1a86b6b2be2b9182f0ae2 (diff) | |
download | sqlalchemy-cca03097f47f22783d42d1853faac6cf84607c5a.tar.gz |
- apply pep8 formatting to sqlalchemy/sql, sqlalchemy/util, sqlalchemy/dialects,
sqlalchemy/orm, sqlalchemy/event, sqlalchemy/testing
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 1019 |
1 files changed, 515 insertions, 504 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 384cf27c2..ac45054ae 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -25,7 +25,7 @@ To generate user-defined SQL strings, see import re from . import schema, sqltypes, operators, functions, \ - util as sql_util, visitors, elements, selectable, base + util as sql_util, visitors, elements, selectable, base from .. import util, exc import decimal import itertools @@ -158,7 +158,9 @@ COMPOUND_KEYWORDS = { selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL' } + class Compiled(object): + """Represent a compiled SQL or DDL expression. The ``__str__`` method of the ``Compiled`` object should produce @@ -174,7 +176,7 @@ class Compiled(object): _cached_metadata = None def __init__(self, dialect, statement, bind=None, - compile_kwargs=util.immutabledict()): + compile_kwargs=util.immutabledict()): """Construct a new ``Compiled`` object. :param dialect: ``Dialect`` to compile against. @@ -199,7 +201,7 @@ class Compiled(object): self.string = self.process(self.statement, **compile_kwargs) @util.deprecated("0.7", ":class:`.Compiled` objects now compile " - "within the constructor.") + "within the constructor.") def compile(self): """Produce the internal string representation of this element. """ @@ -247,8 +249,8 @@ class Compiled(object): e = self.bind if e is None: raise exc.UnboundExecutionError( - "This Compiled object is not bound to any Engine " - "or Connection.") + "This Compiled object is not bound to any Engine " + "or Connection.") return e._execute_compiled(self, multiparams, params) def scalar(self, *multiparams, **params): @@ -259,6 +261,7 @@ class Compiled(object): class TypeCompiler(object): + """Produces DDL specification for TypeEngine objects.""" def __init__(self, dialect): @@ -268,8 +271,8 @@ class TypeCompiler(object): return type_._compiler_dispatch(self) - class _CompileLabel(visitors.Visitable): + """lightweight label object which acts as an expression.Label.""" __visit_name__ = 'label' @@ -290,6 +293,7 @@ class _CompileLabel(visitors.Visitable): class SQLCompiler(Compiled): + """Default implementation of Compiled. Compiles ClauseElements into SQL strings. Uses a similar visit @@ -333,7 +337,7 @@ class SQLCompiler(Compiled): """ def __init__(self, dialect, statement, column_keys=None, - inline=False, **kwargs): + inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. dialect @@ -412,19 +416,19 @@ class SQLCompiler(Compiled): def _apply_numbered_params(self): poscount = itertools.count(1) self.string = re.sub( - r'\[_POSITION\]', - lambda m: str(util.next(poscount)), - self.string) + r'\[_POSITION\]', + lambda m: str(util.next(poscount)), + self.string) @util.memoized_property def _bind_processors(self): return dict( - (key, value) for key, value in - ((self.bind_names[bindparam], - bindparam.type._cached_bind_processor(self.dialect)) - for bindparam in self.bind_names) - if value is not None - ) + (key, value) for key, value in + ((self.bind_names[bindparam], + bindparam.type._cached_bind_processor(self.dialect)) + for bindparam in self.bind_names) + if value is not None + ) def is_subquery(self): return len(self.stack) > 1 @@ -491,15 +495,16 @@ class SQLCompiler(Compiled): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" def visit_label(self, label, - add_to_result_map=None, - within_label_clause=False, - within_columns_clause=False, - render_label_as_label=None, - **kw): + add_to_result_map=None, + within_label_clause=False, + within_columns_clause=False, + render_label_as_label=None, + **kw): # only render labels within the columns clause # or ORDER BY clause of a select. dialect-specific compilers # can modify this behavior. - render_label_with_as = within_columns_clause and not within_label_clause + render_label_with_as = (within_columns_clause and not + within_label_clause) render_label_only = render_label_as_label is label if render_label_only or render_label_with_as: @@ -511,27 +516,25 @@ class SQLCompiler(Compiled): if render_label_with_as: if add_to_result_map is not None: add_to_result_map( - labelname, - label.name, - (label, labelname, ) + label._alt_names, - label.type + labelname, + label.name, + (label, labelname, ) + label._alt_names, + label.type ) - return label.element._compiler_dispatch(self, - within_columns_clause=True, - within_label_clause=True, - **kw) + \ - OPERATORS[operators.as_] + \ - self.preparer.format_label(label, labelname) + return label.element._compiler_dispatch( + self, within_columns_clause=True, + within_label_clause=True, **kw) + \ + OPERATORS[operators.as_] + \ + self.preparer.format_label(label, labelname) elif render_label_only: return self.preparer.format_label(label, labelname) else: - return label.element._compiler_dispatch(self, - within_columns_clause=False, - **kw) + return label.element._compiler_dispatch( + self, within_columns_clause=False, **kw) def visit_column(self, column, add_to_result_map=None, - include_table=True, **kwargs): + include_table=True, **kwargs): name = orig_name = column.name if name is None: raise exc.CompileError("Cannot compile Column object until " @@ -567,8 +570,8 @@ class SQLCompiler(Compiled): tablename = self._truncated_identifier("alias", tablename) return schema_prefix + \ - self.preparer.quote(tablename) + \ - "." + name + self.preparer.quote(tablename) + \ + "." + name def escape_literal_column(self, text): """provide escaping for the literal_column() construct.""" @@ -597,37 +600,38 @@ class SQLCompiler(Compiled): return self.bindparam_string(name, **kw) # un-escape any \:params - return BIND_PARAMS_ESC.sub(lambda m: m.group(1), - BIND_PARAMS.sub(do_bindparam, - self.post_process_text(textclause.text)) + return BIND_PARAMS_ESC.sub( + lambda m: m.group(1), + BIND_PARAMS.sub( + do_bindparam, + self.post_process_text(textclause.text)) ) def visit_text_as_from(self, taf, iswrapper=False, - compound_index=0, force_result_map=False, - asfrom=False, - parens=True, **kw): + compound_index=0, force_result_map=False, + asfrom=False, + parens=True, **kw): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] populate_result_map = force_result_map or ( - compound_index == 0 and ( - toplevel or \ - entry['iswrapper'] - ) - ) + compound_index == 0 and ( + toplevel or + entry['iswrapper'] + ) + ) if populate_result_map: for c in taf.column_args: self.process(c, within_columns_clause=True, - add_to_result_map=self._add_to_result_map) + add_to_result_map=self._add_to_result_map) text = self.process(taf.element, **kw) if asfrom and parens: text = "(%s)" % text return text - def visit_null(self, expr, **kw): return 'NULL' @@ -646,7 +650,7 @@ class SQLCompiler(Compiled): def visit_clauselist(self, clauselist, order_by_select=None, **kw): if order_by_select is not None: return self._order_by_clauselist( - clauselist, order_by_select, **kw) + clauselist, order_by_select, **kw) sep = clauselist.operator if sep is None: @@ -654,11 +658,11 @@ class SQLCompiler(Compiled): else: sep = OPERATORS[clauselist.operator] return sep.join( - s for s in - ( - c._compiler_dispatch(self, **kw) - for c in clauselist.clauses) - if s) + s for s in + ( + c._compiler_dispatch(self, **kw) + for c in clauselist.clauses) + if s) def _order_by_clauselist(self, clauselist, order_by_select, **kw): # look through raw columns collection for labels. @@ -667,21 +671,21 @@ class SQLCompiler(Compiled): # label expression in the columns clause. raw_col = set(l._order_by_label_element.name - for l in order_by_select._raw_columns - if l._order_by_label_element is not None) + for l in order_by_select._raw_columns + if l._order_by_label_element is not None) return ", ".join( - s for s in - ( - c._compiler_dispatch(self, - render_label_as_label= - c._order_by_label_element if - c._order_by_label_element is not None and - c._order_by_label_element.name in raw_col - else None, - **kw) - for c in clauselist.clauses) - if s) + s for s in + ( + c._compiler_dispatch( + self, + render_label_as_label=c._order_by_label_element if + c._order_by_label_element is not None and + c._order_by_label_element.name in raw_col + else None, + **kw) + for c in clauselist.clauses) + if s) def visit_case(self, clause, **kwargs): x = "CASE " @@ -689,38 +693,38 @@ class SQLCompiler(Compiled): x += clause.value._compiler_dispatch(self, **kwargs) + " " for cond, result in clause.whens: x += "WHEN " + cond._compiler_dispatch( - self, **kwargs - ) + " THEN " + result._compiler_dispatch( - self, **kwargs) + " " + self, **kwargs + ) + " THEN " + result._compiler_dispatch( + self, **kwargs) + " " if clause.else_ is not None: x += "ELSE " + clause.else_._compiler_dispatch( - self, **kwargs - ) + " " + self, **kwargs + ) + " " x += "END" return x def visit_cast(self, cast, **kwargs): return "CAST(%s AS %s)" % \ - (cast.clause._compiler_dispatch(self, **kwargs), - cast.typeclause._compiler_dispatch(self, **kwargs)) + (cast.clause._compiler_dispatch(self, **kwargs), + cast.typeclause._compiler_dispatch(self, **kwargs)) def visit_over(self, over, **kwargs): return "%s OVER (%s)" % ( over.func._compiler_dispatch(self, **kwargs), ' '.join( - '%s BY %s' % (word, clause._compiler_dispatch(self, **kwargs)) - for word, clause in ( - ('PARTITION', over.partition_by), - ('ORDER', over.order_by) - ) - if clause is not None and len(clause) + '%s BY %s' % (word, clause._compiler_dispatch(self, **kwargs)) + for word, clause in ( + ('PARTITION', over.partition_by), + ('ORDER', over.order_by) + ) + if clause is not None and len(clause) ) ) def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) - return "EXTRACT(%s FROM %s)" % (field, - extract.expr._compiler_dispatch(self, **kwargs)) + return "EXTRACT(%s FROM %s)" % ( + field, extract.expr._compiler_dispatch(self, **kwargs)) def visit_function(self, func, add_to_result_map=None, **kwargs): if add_to_result_map is not None: @@ -734,7 +738,7 @@ class SQLCompiler(Compiled): else: name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s") return ".".join(list(func.packagenames) + [name]) % \ - {'expr': self.function_argspec(func, **kwargs)} + {'expr': self.function_argspec(func, **kwargs)} def visit_next_value_func(self, next_value, **kw): return self.visit_sequence(next_value.sequence) @@ -748,39 +752,38 @@ class SQLCompiler(Compiled): def function_argspec(self, func, **kwargs): return func.clause_expr._compiler_dispatch(self, **kwargs) - def visit_compound_select(self, cs, asfrom=False, - parens=True, compound_index=0, **kwargs): + parens=True, compound_index=0, **kwargs): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] self.stack.append( - { - 'correlate_froms': entry['correlate_froms'], - 'iswrapper': toplevel, - 'asfrom_froms': entry['asfrom_froms'] - }) + { + 'correlate_froms': entry['correlate_froms'], + 'iswrapper': toplevel, + 'asfrom_froms': entry['asfrom_froms'] + }) keyword = self.compound_keywords.get(cs.keyword) text = (" " + keyword + " ").join( - (c._compiler_dispatch(self, - asfrom=asfrom, parens=False, - compound_index=i, **kwargs) - for i, c in enumerate(cs.selects)) - ) + (c._compiler_dispatch(self, + asfrom=asfrom, parens=False, + compound_index=i, **kwargs) + for i, c in enumerate(cs.selects)) + ) group_by = cs._group_by_clause._compiler_dispatch( - self, asfrom=asfrom, **kwargs) + self, asfrom=asfrom, **kwargs) if group_by: text += " GROUP BY " + group_by text += self.order_by_clause(cs, **kwargs) text += (cs._limit_clause is not None or cs._offset_clause is not None) and \ - self.limit_clause(cs) or "" + self.limit_clause(cs) or "" if self.ctes and \ - compound_index == 0 and toplevel: + compound_index == 0 and toplevel: text = self._render_cte_clause() + text self.stack.pop(-1) @@ -793,26 +796,26 @@ class SQLCompiler(Compiled): if unary.operator: if unary.modifier: raise exc.CompileError( - "Unary expression does not support operator " - "and modifier simultaneously") + "Unary expression does not support operator " + "and modifier simultaneously") disp = getattr(self, "visit_%s_unary_operator" % - unary.operator.__name__, None) + unary.operator.__name__, None) if disp: return disp(unary, unary.operator, **kw) else: - return self._generate_generic_unary_operator(unary, - OPERATORS[unary.operator], **kw) + return self._generate_generic_unary_operator( + unary, OPERATORS[unary.operator], **kw) elif unary.modifier: disp = getattr(self, "visit_%s_unary_modifier" % - unary.modifier.__name__, None) + unary.modifier.__name__, None) if disp: return disp(unary, unary.modifier, **kw) else: - return self._generate_generic_unary_modifier(unary, - OPERATORS[unary.modifier], **kw) + return self._generate_generic_unary_modifier( + unary, OPERATORS[unary.modifier], **kw) else: raise exc.CompileError( - "Unary expression has no operator or modifier") + "Unary expression has no operator or modifier") def visit_istrue_unary_operator(self, element, operator, **kw): if self.dialect.supports_native_boolean: @@ -829,8 +832,8 @@ class SQLCompiler(Compiled): def visit_binary(self, binary, **kw): # don't allow "? = ?" to render if self.ansi_bind_rules and \ - isinstance(binary.left, elements.BindParameter) and \ - isinstance(binary.right, elements.BindParameter): + isinstance(binary.left, elements.BindParameter) and \ + isinstance(binary.right, elements.BindParameter): kw['literal_binds'] = True operator = binary.operator @@ -846,21 +849,21 @@ class SQLCompiler(Compiled): return self._generate_generic_binary(binary, opstring, **kw) def visit_custom_op_binary(self, element, operator, **kw): - return self._generate_generic_binary(element, - " " + operator.opstring + " ", **kw) + return self._generate_generic_binary( + element, " " + operator.opstring + " ", **kw) def visit_custom_op_unary_operator(self, element, operator, **kw): - return self._generate_generic_unary_operator(element, - operator.opstring + " ", **kw) + return self._generate_generic_unary_operator( + element, operator.opstring + " ", **kw) def visit_custom_op_unary_modifier(self, element, operator, **kw): - return self._generate_generic_unary_modifier(element, - " " + operator.opstring, **kw) + return self._generate_generic_unary_modifier( + element, " " + operator.opstring, **kw) def _generate_generic_binary(self, binary, opstring, **kw): return binary.left._compiler_dispatch(self, **kw) + \ - opstring + \ - binary.right._compiler_dispatch(self, **kw) + opstring + \ + binary.right._compiler_dispatch(self, **kw) def _generate_generic_unary_operator(self, unary, opstring, **kw): return opstring + unary.element._compiler_dispatch(self, **kw) @@ -888,16 +891,16 @@ class SQLCompiler(Compiled): binary = binary._clone() percent = self._like_percent_literal binary.right = percent.__radd__( - binary.right - ) + binary.right + ) return self.visit_like_op_binary(binary, operator, **kw) def visit_notstartswith_op_binary(self, binary, operator, **kw): binary = binary._clone() percent = self._like_percent_literal binary.right = percent.__radd__( - binary.right - ) + binary.right + ) return self.visit_notlike_op_binary(binary, operator, **kw) def visit_endswith_op_binary(self, binary, operator, **kw): @@ -917,77 +920,77 @@ class SQLCompiler(Compiled): # TODO: use ternary here, not "and"/ "or" return '%s LIKE %s' % ( - binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) \ + ( ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_notlike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return '%s NOT LIKE %s' % ( - binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) \ + ( ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) LIKE lower(%s)' % ( - binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) \ + ( ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) NOT LIKE lower(%s)' % ( - binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) \ + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw)) \ + ( ' ESCAPE ' + self.render_literal_value(escape, sqltypes.STRINGTYPE) if escape else '' - ) + ) def visit_between_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) return self._generate_generic_binary( - binary, " BETWEEN SYMMETRIC " - if symmetric else " BETWEEN ", **kw) + binary, " BETWEEN SYMMETRIC " + if symmetric else " BETWEEN ", **kw) def visit_notbetween_op_binary(self, binary, operator, **kw): symmetric = binary.modifiers.get("symmetric", False) return self._generate_generic_binary( - binary, " NOT BETWEEN SYMMETRIC " - if symmetric else " NOT BETWEEN ", **kw) + binary, " NOT BETWEEN SYMMETRIC " + if symmetric else " NOT BETWEEN ", **kw) def visit_bindparam(self, bindparam, within_columns_clause=False, - literal_binds=False, - skip_bind_expression=False, - **kwargs): + literal_binds=False, + skip_bind_expression=False, + **kwargs): if not skip_bind_expression and bindparam.type._has_bind_expression: bind_expression = bindparam.type.bind_expression(bindparam) return self.process(bind_expression, skip_bind_expression=True) if literal_binds or \ - (within_columns_clause and \ + (within_columns_clause and self.ansi_bind_rules): if bindparam.value is None and bindparam.callable is None: raise exc.CompileError("Bind parameter '%s' without a " - "renderable value not allowed here." - % bindparam.key) - return self.render_literal_bindparam(bindparam, - within_columns_clause=True, **kwargs) + "renderable value not allowed here." + % bindparam.key) + return self.render_literal_bindparam( + bindparam, within_columns_clause=True, **kwargs) name = self._truncate_bindparam(bindparam) @@ -995,13 +998,13 @@ class SQLCompiler(Compiled): existing = self.binds[name] if existing is not bindparam: if (existing.unique or bindparam.unique) and \ - not existing.proxy_set.intersection( - bindparam.proxy_set): + not existing.proxy_set.intersection( + bindparam.proxy_set): raise exc.CompileError( - "Bind parameter '%s' conflicts with " - "unique bind parameter of the same name" % - bindparam.key - ) + "Bind parameter '%s' conflicts with " + "unique bind parameter of the same name" % + bindparam.key + ) elif existing._is_crud or bindparam._is_crud: raise exc.CompileError( "bindparam() name '%s' is reserved " @@ -1009,8 +1012,8 @@ class SQLCompiler(Compiled): "clause of this " "insert/update statement. Please use a " "name other than column name when using bindparam() " - "with insert() or update() (for example, 'b_%s')." - % (bindparam.key, bindparam.key) + "with insert() or update() (for example, 'b_%s')." % + (bindparam.key, bindparam.key) ) self.binds[bindparam.key] = self.binds[name] = bindparam @@ -1037,7 +1040,7 @@ class SQLCompiler(Compiled): return processor(value) else: raise NotImplementedError( - "Don't know how to literal-quote value %r" % value) + "Don't know how to literal-quote value %r" % value) def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: @@ -1061,7 +1064,7 @@ class SQLCompiler(Compiled): if len(anonname) > self.label_length: counter = self.truncated_names.get(ident_class, 1) truncname = anonname[0:max(self.label_length - 6, 0)] + \ - "_" + hex(counter)[2:] + "_" + hex(counter)[2:] self.truncated_names[ident_class] = counter + 1 else: truncname = anonname @@ -1086,8 +1089,8 @@ class SQLCompiler(Compiled): return self.bindtemplate % {'name': name} def visit_cte(self, cte, asfrom=False, ashint=False, - fromhints=None, - **kwargs): + fromhints=None, + **kwargs): self._init_cte_state() if isinstance(cte.name, elements._truncated_label): @@ -1108,9 +1111,9 @@ class SQLCompiler(Compiled): del self.ctes[existing_cte] else: raise exc.CompileError( - "Multiple, unrelated CTEs found with " - "the same name: %r" % - cte_name) + "Multiple, unrelated CTEs found with " + "the same name: %r" % + cte_name) self.ctes_by_name[cte_name] = cte @@ -1120,7 +1123,8 @@ class SQLCompiler(Compiled): self.visit_cte(orig_cte) cte_alias_name = cte._cte_alias.name if isinstance(cte_alias_name, elements._truncated_label): - cte_alias_name = self._truncated_identifier("alias", cte_alias_name) + cte_alias_name = self._truncated_identifier( + "alias", cte_alias_name) else: orig_cte = cte cte_alias_name = None @@ -1136,20 +1140,20 @@ class SQLCompiler(Compiled): else: assert False recur_cols = [c for c in - util.unique_list(col_source.inner_columns) - if c is not None] + util.unique_list(col_source.inner_columns) + if c is not None] text += "(%s)" % (", ".join( - self.preparer.format_column(ident) - for ident in recur_cols)) + self.preparer.format_column(ident) + for ident in recur_cols)) if self.positional: kwargs['positional_names'] = self.cte_positional[cte] = [] text += " AS \n" + \ - cte.original._compiler_dispatch( - self, asfrom=True, **kwargs - ) + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) self.ctes[cte] = text @@ -1162,8 +1166,8 @@ class SQLCompiler(Compiled): return text def visit_alias(self, alias, asfrom=False, ashint=False, - iscrud=False, - fromhints=None, **kwargs): + iscrud=False, + fromhints=None, **kwargs): if asfrom or ashint: if isinstance(alias.name, elements._truncated_label): alias_name = self._truncated_identifier("alias", alias.name) @@ -1174,13 +1178,13 @@ class SQLCompiler(Compiled): return self.preparer.format_alias(alias, alias_name) elif asfrom: ret = alias.original._compiler_dispatch(self, - asfrom=True, **kwargs) + \ - " AS " + \ - self.preparer.format_alias(alias, alias_name) + asfrom=True, **kwargs) + \ + " AS " + \ + self.preparer.format_alias(alias, alias_name) if fromhints and alias in fromhints: ret = self.format_from_hint_text(ret, alias, - fromhints[alias], iscrud) + fromhints[alias], iscrud) return ret else: @@ -1201,19 +1205,19 @@ class SQLCompiler(Compiled): self.result_map[keyname] = name, objects, type_ def _label_select_column(self, select, column, - populate_result_map, - asfrom, column_clause_args, - name=None, - within_columns_clause=True): + populate_result_map, + asfrom, column_clause_args, + name=None, + within_columns_clause=True): """produce labeled columns present in a select().""" if column.type._has_column_expression and \ populate_result_map: col_expr = column.type.column_expression(column) add_to_result_map = lambda keyname, name, objects, type_: \ - self._add_to_result_map( - keyname, name, - objects + (column,), type_) + self._add_to_result_map( + keyname, name, + objects + (column,), type_) else: col_expr = column if populate_result_map: @@ -1226,19 +1230,19 @@ class SQLCompiler(Compiled): elif isinstance(column, elements.Label): if col_expr is not column: result_expr = _CompileLabel( - col_expr, - column.name, - alt_names=(column.element,) - ) + col_expr, + column.name, + alt_names=(column.element,) + ) else: result_expr = col_expr elif select is not None and name: result_expr = _CompileLabel( - col_expr, - name, - alt_names=(column._key_label,) - ) + col_expr, + name, + alt_names=(column._key_label,) + ) elif \ asfrom and \ @@ -1247,30 +1251,30 @@ class SQLCompiler(Compiled): column.table is not None and \ not isinstance(column.table, selectable.Select): result_expr = _CompileLabel(col_expr, - elements._as_truncated(column.name), - alt_names=(column.key,)) + elements._as_truncated(column.name), + alt_names=(column.key,)) elif not isinstance(column, - (elements.UnaryExpression, elements.TextClause)) \ - and (not hasattr(column, 'name') or \ - isinstance(column, functions.Function)): + (elements.UnaryExpression, elements.TextClause)) \ + and (not hasattr(column, 'name') or + isinstance(column, functions.Function)): result_expr = _CompileLabel(col_expr, column.anon_label) elif col_expr is not column: # TODO: are we sure "column" has a .name and .key here ? # assert isinstance(column, elements.ColumnClause) result_expr = _CompileLabel(col_expr, - elements._as_truncated(column.name), - alt_names=(column.key,)) + elements._as_truncated(column.name), + alt_names=(column.key,)) else: result_expr = col_expr column_clause_args.update( - within_columns_clause=within_columns_clause, - add_to_result_map=add_to_result_map - ) + within_columns_clause=within_columns_clause, + add_to_result_map=add_to_result_map + ) return result_expr._compiler_dispatch( - self, - **column_clause_args - ) + self, + **column_clause_args + ) def format_from_hint_text(self, sqltext, table, hint, iscrud): hinttext = self.get_from_hint_text(table, hint) @@ -1307,7 +1311,7 @@ class SQLCompiler(Compiled): newelem = cloned[element] = element._clone() if newelem.is_selectable and newelem._is_join and \ - isinstance(newelem.right, selectable.FromGrouping): + isinstance(newelem.right, selectable.FromGrouping): newelem._reset_exported() newelem.left = visit(newelem.left, **kw) @@ -1376,24 +1380,24 @@ class SQLCompiler(Compiled): return visit(select) - def _transform_result_map_for_nested_joins(self, select, transformed_select): + def _transform_result_map_for_nested_joins( + self, select, transformed_select): inner_col = dict((c._key_label, c) for - c in transformed_select.inner_columns) + c in transformed_select.inner_columns) d = dict( - (inner_col[c._key_label], c) - for c in select.inner_columns - ) + (inner_col[c._key_label], c) + for c in select.inner_columns + ) for key, (name, objs, typ) in list(self.result_map.items()): objs = tuple([d.get(col, col) for col in objs]) self.result_map[key] = (name, objs, typ) - _default_stack_entry = util.immutabledict([ - ('iswrapper', False), - ('correlate_froms', frozenset()), - ('asfrom_froms', frozenset()) - ]) + ('iswrapper', False), + ('correlate_froms', frozenset()), + ('asfrom_froms', frozenset()) + ]) def _display_froms_for_select(self, select, asfrom): # utility method to help external dialects @@ -1408,53 +1412,53 @@ class SQLCompiler(Compiled): if asfrom: froms = select._get_display_froms( - explicit_correlate_froms=\ - correlate_froms.difference(asfrom_froms), - implicit_correlate_froms=()) + explicit_correlate_froms=correlate_froms.difference( + asfrom_froms), + implicit_correlate_froms=()) else: froms = select._get_display_froms( - explicit_correlate_froms=correlate_froms, - implicit_correlate_froms=asfrom_froms) + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms) return froms def visit_select(self, select, asfrom=False, parens=True, - iswrapper=False, fromhints=None, - compound_index=0, - force_result_map=False, - nested_join_translation=False, - **kwargs): + iswrapper=False, fromhints=None, + compound_index=0, + force_result_map=False, + nested_join_translation=False, + **kwargs): needs_nested_translation = \ - select.use_labels and \ - not nested_join_translation and \ - not self.stack and \ - not self.dialect.supports_right_nested_joins + select.use_labels and \ + not nested_join_translation and \ + not self.stack and \ + not self.dialect.supports_right_nested_joins if needs_nested_translation: - transformed_select = self._transform_select_for_nested_joins(select) + transformed_select = self._transform_select_for_nested_joins( + select) text = self.visit_select( - transformed_select, asfrom=asfrom, parens=parens, - iswrapper=iswrapper, fromhints=fromhints, - compound_index=compound_index, - force_result_map=force_result_map, - nested_join_translation=True, **kwargs - ) + transformed_select, asfrom=asfrom, parens=parens, + iswrapper=iswrapper, fromhints=fromhints, + compound_index=compound_index, + force_result_map=force_result_map, + nested_join_translation=True, **kwargs + ) toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - populate_result_map = force_result_map or ( - compound_index == 0 and ( - toplevel or \ - entry['iswrapper'] - ) - ) + compound_index == 0 and ( + toplevel or + entry['iswrapper'] + ) + ) if needs_nested_translation: if populate_result_map: self._transform_result_map_for_nested_joins( - select, transformed_select) + select, transformed_select) return text correlate_froms = entry['correlate_froms'] @@ -1462,48 +1466,49 @@ class SQLCompiler(Compiled): if asfrom: froms = select._get_display_froms( - explicit_correlate_froms= - correlate_froms.difference(asfrom_froms), - implicit_correlate_froms=()) + explicit_correlate_froms=correlate_froms.difference( + asfrom_froms), + implicit_correlate_froms=()) else: froms = select._get_display_froms( - explicit_correlate_froms=correlate_froms, - implicit_correlate_froms=asfrom_froms) + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms) new_correlate_froms = set(selectable._from_objects(*froms)) all_correlate_froms = new_correlate_froms.union(correlate_froms) new_entry = { - 'asfrom_froms': new_correlate_froms, - 'iswrapper': iswrapper, - 'correlate_froms': all_correlate_froms - } + 'asfrom_froms': new_correlate_froms, + 'iswrapper': iswrapper, + 'correlate_froms': all_correlate_froms + } self.stack.append(new_entry) column_clause_args = kwargs.copy() column_clause_args.update({ - 'within_label_clause': False, - 'within_columns_clause': False - }) + 'within_label_clause': False, + 'within_columns_clause': False + }) text = "SELECT " # we're off to a good start ! if select._hints: byfrom = dict([ - (from_, hinttext % { - 'name':from_._compiler_dispatch( - self, ashint=True) - }) - for (from_, dialect), hinttext in - select._hints.items() - if dialect in ('*', self.dialect.name) - ]) + (from_, hinttext % { + 'name': from_._compiler_dispatch( + self, ashint=True) + }) + for (from_, dialect), hinttext in + select._hints.items() + if dialect in ('*', self.dialect.name) + ]) hint_text = self.get_select_hint_text(byfrom) if hint_text: text += hint_text + " " if select._prefixes: - text += self._generate_prefixes(select, select._prefixes, **kwargs) + text += self._generate_prefixes( + select, select._prefixes, **kwargs) text += self.get_select_precolumns(select) @@ -1511,12 +1516,12 @@ class SQLCompiler(Compiled): inner_columns = [ c for c in [ self._label_select_column(select, - column, - populate_result_map, asfrom, - column_clause_args, - name=name) + column, + populate_result_map, asfrom, + column_clause_args, + name=name) for name, column in select._columns_plus_names - ] + ] if c is not None ] @@ -1526,14 +1531,14 @@ class SQLCompiler(Compiled): text += " \nFROM " if select._hints: - text += ', '.join([f._compiler_dispatch(self, - asfrom=True, fromhints=byfrom, - **kwargs) - for f in froms]) + text += ', '.join( + [f._compiler_dispatch(self, asfrom=True, + fromhints=byfrom, **kwargs) + for f in froms]) else: - text += ', '.join([f._compiler_dispatch(self, - asfrom=True, **kwargs) - for f in froms]) + text += ', '.join( + [f._compiler_dispatch(self, asfrom=True, **kwargs) + for f in froms]) else: text += self.default_from() @@ -1544,7 +1549,7 @@ class SQLCompiler(Compiled): if select._group_by_clause.clauses: group_by = select._group_by_clause._compiler_dispatch( - self, **kwargs) + self, **kwargs) if group_by: text += " GROUP BY " + group_by @@ -1559,17 +1564,18 @@ class SQLCompiler(Compiled): else: order_by_select = None - text += self.order_by_clause(select, - order_by_select=order_by_select, **kwargs) + text += self.order_by_clause( + select, order_by_select=order_by_select, **kwargs) - if select._limit_clause is not None or select._offset_clause is not None: + if (select._limit_clause is not None or + select._offset_clause is not None): text += self.limit_clause(select) if select._for_update_arg is not None: text += self.for_update_clause(select) if self.ctes and \ - compound_index == 0 and toplevel: + compound_index == 0 and toplevel: text = self._render_cte_clause() + text self.stack.pop(-1) @@ -1581,11 +1587,11 @@ class SQLCompiler(Compiled): def _generate_prefixes(self, stmt, prefixes, **kw): clause = " ".join( - prefix._compiler_dispatch(self, **kw) - for prefix, dialect_name in prefixes - if dialect_name is None or - dialect_name == self.dialect.name - ) + prefix._compiler_dispatch(self, **kw) + for prefix, dialect_name in prefixes + if dialect_name is None or + dialect_name == self.dialect.name + ) if clause: clause += " " return clause @@ -1593,9 +1599,9 @@ class SQLCompiler(Compiled): def _render_cte_clause(self): if self.positional: self.positiontup = sum([ - self.cte_positional[cte] - for cte in self.ctes], []) + \ - self.positiontup + self.cte_positional[cte] + for cte in self.ctes], []) + \ + self.positiontup cte_text = self.get_cte_preamble(self.ctes_recursive) + " " cte_text += ", \n".join( [txt for txt in self.ctes.values()] @@ -1628,8 +1634,8 @@ class SQLCompiler(Compiled): def returning_clause(self, stmt, returning_cols): raise exc.CompileError( - "RETURNING is not supported by this " - "dialect's statement compiler.") + "RETURNING is not supported by this " + "dialect's statement compiler.") def limit_clause(self, select): text = "" @@ -1642,16 +1648,16 @@ class SQLCompiler(Compiled): return text def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, - fromhints=None, **kwargs): + fromhints=None, **kwargs): if asfrom or ashint: if getattr(table, "schema", None): ret = self.preparer.quote_schema(table.schema) + \ - "." + self.preparer.quote(table.name) + "." + self.preparer.quote(table.name) else: ret = self.preparer.quote(table.name) if fromhints and table in fromhints: ret = self.format_from_hint_text(ret, table, - fromhints[table], iscrud) + fromhints[table], iscrud) return ret else: return "" @@ -1673,21 +1679,21 @@ class SQLCompiler(Compiled): not self.dialect.supports_default_values and \ not self.dialect.supports_empty_insert: raise exc.CompileError("The '%s' dialect with current database " - "version settings does not support empty " - "inserts." % - self.dialect.name) + "version settings does not support empty " + "inserts." % + self.dialect.name) if insert_stmt._has_multi_parameters: if not self.dialect.supports_multivalues_insert: - raise exc.CompileError("The '%s' dialect with current database " - "version settings does not support " - "in-place multirow inserts." % - self.dialect.name) + raise exc.CompileError( + "The '%s' dialect with current database " + "version settings does not support " + "in-place multirow inserts." % + self.dialect.name) colparams_single = colparams[0] else: colparams_single = colparams - preparer = self.preparer supports_default_values = self.dialect.supports_default_values @@ -1695,7 +1701,7 @@ class SQLCompiler(Compiled): if insert_stmt._prefixes: text += self._generate_prefixes(insert_stmt, - insert_stmt._prefixes, **kw) + insert_stmt._prefixes, **kw) text += "INTO " table_text = preparer.format_table(insert_stmt.table) @@ -1709,22 +1715,22 @@ class SQLCompiler(Compiled): ]) if insert_stmt.table in dialect_hints: table_text = self.format_from_hint_text( - table_text, - insert_stmt.table, - dialect_hints[insert_stmt.table], - True - ) + table_text, + insert_stmt.table, + dialect_hints[insert_stmt.table], + True + ) text += table_text if colparams_single or not supports_default_values: text += " (%s)" % ', '.join([preparer.format_column(c[0]) - for c in colparams_single]) + for c in colparams_single]) if self.returning or insert_stmt._returning: self.returning = self.returning or insert_stmt._returning returning_clause = self.returning_clause( - insert_stmt, self.returning) + insert_stmt, self.returning) if self.returning_precedes_values: text += " " + returning_clause @@ -1735,16 +1741,16 @@ class SQLCompiler(Compiled): text += " DEFAULT VALUES" elif insert_stmt._has_multi_parameters: text += " VALUES %s" % ( - ", ".join( - "(%s)" % ( - ', '.join(c[1] for c in colparam_set) - ) - for colparam_set in colparams - ) - ) + ", ".join( + "(%s)" % ( + ', '.join(c[1] for c in colparam_set) + ) + for colparam_set in colparams + ) + ) else: text += " VALUES (%s)" % \ - ', '.join([c[1] for c in colparams]) + ', '.join([c[1] for c in colparams]) if self.returning and not self.returning_precedes_values: text += " " + returning_clause @@ -1756,7 +1762,7 @@ class SQLCompiler(Compiled): return None def update_tables_clause(self, update_stmt, from_table, - extra_froms, **kw): + extra_froms, **kw): """Provide a hook to override the initial table clause in an UPDATE statement. @@ -1764,12 +1770,12 @@ class SQLCompiler(Compiled): """ return from_table._compiler_dispatch(self, asfrom=True, - iscrud=True, **kw) + iscrud=True, **kw) def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): + from_table, extra_froms, + from_hints, + **kw): """Provide a hook to override the generation of an UPDATE..FROM clause. @@ -1777,15 +1783,15 @@ class SQLCompiler(Compiled): """ return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + t._compiler_dispatch(self, asfrom=True, + fromhints=from_hints, **kw) + for t in extra_froms) def visit_update(self, update_stmt, **kw): self.stack.append( - {'correlate_froms': set([update_stmt.table]), - "iswrapper": False, - "asfrom_froms": set([update_stmt.table])}) + {'correlate_froms': set([update_stmt.table]), + "iswrapper": False, + "asfrom_froms": set([update_stmt.table])}) self.isupdate = True @@ -1795,7 +1801,7 @@ class SQLCompiler(Compiled): if update_stmt._prefixes: text += self._generate_prefixes(update_stmt, - update_stmt._prefixes, **kw) + update_stmt._prefixes, **kw) table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) @@ -1811,11 +1817,11 @@ class SQLCompiler(Compiled): ]) if update_stmt.table in dialect_hints: table_text = self.format_from_hint_text( - table_text, - update_stmt.table, - dialect_hints[update_stmt.table], - True - ) + table_text, + update_stmt.table, + dialect_hints[update_stmt.table], + True + ) else: dialect_hints = None @@ -1823,26 +1829,26 @@ class SQLCompiler(Compiled): text += ' SET ' include_table = extra_froms and \ - self.render_table_with_column_in_update_from + self.render_table_with_column_in_update_from text += ', '.join( - c[0]._compiler_dispatch(self, - include_table=include_table) + - '=' + c[1] for c in colparams - ) + c[0]._compiler_dispatch(self, + include_table=include_table) + + '=' + c[1] for c in colparams + ) if self.returning or update_stmt._returning: if not self.returning: self.returning = update_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning) + update_stmt, self.returning) if extra_froms: extra_from_text = self.update_from_clause( - update_stmt, - update_stmt.table, - extra_froms, - dialect_hints, **kw) + update_stmt, + update_stmt.table, + extra_froms, + dialect_hints, **kw) if extra_from_text: text += " " + extra_from_text @@ -1857,7 +1863,7 @@ class SQLCompiler(Compiled): if self.returning and not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning) + update_stmt, self.returning) self.stack.pop(-1) @@ -1867,7 +1873,7 @@ class SQLCompiler(Compiled): if name is None: name = col.key bindparam = elements.BindParameter(name, value, - type_=col.type, required=required) + type_=col.type, required=required) bindparam._is_crud = True return bindparam._compiler_dispatch(self) @@ -1881,17 +1887,20 @@ class SQLCompiler(Compiled): # allowing the most compatibility with a non-multi-table # statement. _et = set(self.statement._extra_froms) + def _column_as_key(key): str_key = elements._column_as_key(key) if hasattr(key, 'table') and key.table in _et: return (key.table.name, str_key) else: return str_key + def _getattr_col_key(col): if col.table in _et: return (col.table.name, col.key) else: return col.key + def _col_bind_name(col): if col.table in _et: return "%s_%s" % (col.table.name, col.key) @@ -1923,10 +1932,10 @@ class SQLCompiler(Compiled): # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: return [ - (c, self._create_crud_bind_param(c, - None, required=True)) - for c in stmt.table.columns - ] + (c, self._create_crud_bind_param(c, + None, required=True)) + for c in stmt.table.columns + ] if stmt._has_multi_parameters: stmt_parameters = stmt.parameters[0] @@ -1937,7 +1946,7 @@ class SQLCompiler(Compiled): # but in the case of mysql multi-table update, the rules for # .key must conditionally take tablename into account _column_as_key, _getattr_col_key, _col_bind_name = \ - self._key_getters_for_crud_column + self._key_getters_for_crud_column # if we have statement parameters - set defaults in the # compiled params @@ -1963,26 +1972,27 @@ class SQLCompiler(Compiled): # coercing right side to bound param if elements._is_literal(v): v = self.process( - elements.BindParameter(None, v, type_=k.type), - **kw) + elements.BindParameter(None, v, type_=k.type), + **kw) else: v = self.process(v.self_group(), **kw) values.append((k, v)) need_pks = self.isinsert and \ - not self.inline and \ - not stmt._returning + not self.inline and \ + not stmt._returning implicit_returning = need_pks and \ - self.dialect.implicit_returning and \ - stmt.table.implicit_returning + self.dialect.implicit_returning and \ + stmt.table.implicit_returning if self.isinsert: - implicit_return_defaults = implicit_returning and stmt._return_defaults + implicit_return_defaults = (implicit_returning and + stmt._return_defaults) elif self.isupdate: - implicit_return_defaults = self.dialect.implicit_returning and \ - stmt.table.implicit_returning and \ - stmt._return_defaults + implicit_return_defaults = (self.dialect.implicit_returning and + stmt.table.implicit_returning and + stmt._return_defaults) else: implicit_return_defaults = False @@ -2025,20 +2035,21 @@ class SQLCompiler(Compiled): for c in t.c: if c in normalized_params: continue - elif c.onupdate is not None and not c.onupdate.is_sequence: + elif (c.onupdate is not None and not + c.onupdate.is_sequence): if c.onupdate.is_clause_element: values.append( (c, self.process( - c.onupdate.arg.self_group(), - **kw) - ) + c.onupdate.arg.self_group(), + **kw) + ) ) self.postfetch.append(c) else: values.append( (c, self._create_crud_bind_param( - c, None, name=_col_bind_name(c) - ) + c, None, name=_col_bind_name(c) + ) ) ) self.prefetch.append(c) @@ -2049,7 +2060,7 @@ class SQLCompiler(Compiled): # for an insert from select, we can only use names that # are given, so only select for those names. cols = (stmt.table.c[_column_as_key(name)] - for name in stmt.select_names) + for name in stmt.select_names) else: # iterate through all table columns to maintain # ordering, even for those cols that aren't included @@ -2061,14 +2072,14 @@ class SQLCompiler(Compiled): value = parameters.pop(col_key) if elements._is_literal(value): value = self._create_crud_bind_param( - c, value, required=value is REQUIRED, - name=_col_bind_name(c) - if not stmt._has_multi_parameters - else "%s_0" % _col_bind_name(c) - ) + c, value, required=value is REQUIRED, + name=_col_bind_name(c) + if not stmt._has_multi_parameters + else "%s_0" % _col_bind_name(c) + ) else: if isinstance(value, elements.BindParameter) and \ - value.type._isnull: + value.type._isnull: value = value._clone() value.type = c.type @@ -2076,7 +2087,7 @@ class SQLCompiler(Compiled): self.returning.append(c) value = self.process(value.self_group(), **kw) elif implicit_return_defaults and \ - c in implicit_return_defaults: + c in implicit_return_defaults: self.returning.append(c) value = self.process(value.self_group(), **kw) else: @@ -2086,26 +2097,26 @@ class SQLCompiler(Compiled): elif self.isinsert: if c.primary_key and \ - need_pks and \ - ( - implicit_returning or - not postfetch_lastrowid or - c is not stmt.table._autoincrement_column - ): + need_pks and \ + ( + implicit_returning or + not postfetch_lastrowid or + c is not stmt.table._autoincrement_column + ): if implicit_returning: if c.default is not None: if c.default.is_sequence: if self.dialect.supports_sequences and \ - (not c.default.optional or \ - not self.dialect.sequences_optional): + (not c.default.optional or + not self.dialect.sequences_optional): proc = self.process(c.default, **kw) values.append((c, proc)) self.returning.append(c) elif c.default.is_clause_element: values.append( - (c, - self.process(c.default.arg.self_group(), **kw)) + (c, self.process( + c.default.arg.self_group(), **kw)) ) self.returning.append(c) else: @@ -2117,16 +2128,14 @@ class SQLCompiler(Compiled): self.returning.append(c) else: if ( - c.default is not None and - ( - not c.default.is_sequence or - self.dialect.supports_sequences - ) - ) or \ - c is stmt.table._autoincrement_column and ( - self.dialect.supports_sequences or - self.dialect.preexecute_autoincrement_sequences - ): + (c.default is not None and + (not c.default.is_sequence or + self.dialect.supports_sequences)) or + c is stmt.table._autoincrement_column and + (self.dialect.supports_sequences or + self.dialect. + preexecute_autoincrement_sequences) + ): values.append( (c, self._create_crud_bind_param(c, None)) @@ -2137,22 +2146,23 @@ class SQLCompiler(Compiled): elif c.default is not None: if c.default.is_sequence: if self.dialect.supports_sequences and \ - (not c.default.optional or \ - not self.dialect.sequences_optional): + (not c.default.optional or + not self.dialect.sequences_optional): proc = self.process(c.default, **kw) values.append((c, proc)) if implicit_return_defaults and \ - c in implicit_return_defaults: + c in implicit_return_defaults: self.returning.append(c) elif not c.primary_key: self.postfetch.append(c) elif c.default.is_clause_element: values.append( - (c, self.process(c.default.arg.self_group(), **kw)) + (c, self.process( + c.default.arg.self_group(), **kw)) ) if implicit_return_defaults and \ - c in implicit_return_defaults: + c in implicit_return_defaults: self.returning.append(c) elif not c.primary_key: # don't add primary key column to postfetch @@ -2164,22 +2174,23 @@ class SQLCompiler(Compiled): self.prefetch.append(c) elif c.server_default is not None: if implicit_return_defaults and \ - c in implicit_return_defaults: + c in implicit_return_defaults: self.returning.append(c) elif not c.primary_key: self.postfetch.append(c) elif implicit_return_defaults and \ c in implicit_return_defaults: - self.returning.append(c) + self.returning.append(c) elif self.isupdate: if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, self.process(c.onupdate.arg.self_group(), **kw)) + (c, self.process( + c.onupdate.arg.self_group(), **kw)) ) if implicit_return_defaults and \ - c in implicit_return_defaults: + c in implicit_return_defaults: self.returning.append(c) else: self.postfetch.append(c) @@ -2190,13 +2201,13 @@ class SQLCompiler(Compiled): self.prefetch.append(c) elif c.server_onupdate is not None: if implicit_return_defaults and \ - c in implicit_return_defaults: + c in implicit_return_defaults: self.returning.append(c) else: self.postfetch.append(c) elif implicit_return_defaults and \ c in implicit_return_defaults: - self.returning.append(c) + self.returning.append(c) if parameters and stmt_parameters: check = set(parameters).intersection( @@ -2216,13 +2227,13 @@ class SQLCompiler(Compiled): [ ( c, - (self._create_crud_bind_param( - c, row[c.key], - name="%s_%d" % (c.key, i + 1) - ) if elements._is_literal(row[c.key]) - else self.process( - row[c.key].self_group(), **kw)) - if c.key in row else param + (self._create_crud_bind_param( + c, row[c.key], + name="%s_%d" % (c.key, i + 1) + ) if elements._is_literal(row[c.key]) + else self.process( + row[c.key].self_group(), **kw)) + if c.key in row else param ) for (c, param) in values_0 ] @@ -2233,19 +2244,19 @@ class SQLCompiler(Compiled): def visit_delete(self, delete_stmt, **kw): self.stack.append({'correlate_froms': set([delete_stmt.table]), - "iswrapper": False, - "asfrom_froms": set([delete_stmt.table])}) + "iswrapper": False, + "asfrom_froms": set([delete_stmt.table])}) self.isdelete = True text = "DELETE " if delete_stmt._prefixes: text += self._generate_prefixes(delete_stmt, - delete_stmt._prefixes, **kw) + delete_stmt._prefixes, **kw) text += "FROM " - table_text = delete_stmt.table._compiler_dispatch(self, - asfrom=True, iscrud=True) + table_text = delete_stmt.table._compiler_dispatch( + self, asfrom=True, iscrud=True) if delete_stmt._hints: dialect_hints = dict([ @@ -2256,11 +2267,11 @@ class SQLCompiler(Compiled): ]) if delete_stmt.table in dialect_hints: table_text = self.format_from_hint_text( - table_text, - delete_stmt.table, - dialect_hints[delete_stmt.table], - True - ) + table_text, + delete_stmt.table, + dialect_hints[delete_stmt.table], + True + ) else: dialect_hints = None @@ -2271,7 +2282,7 @@ class SQLCompiler(Compiled): self.returning = delete_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning) + delete_stmt, delete_stmt._returning) if delete_stmt._whereclause is not None: t = delete_stmt._whereclause._compiler_dispatch(self) @@ -2280,7 +2291,7 @@ class SQLCompiler(Compiled): if self.returning and not self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning) + delete_stmt, delete_stmt._returning) self.stack.pop(-1) @@ -2291,11 +2302,11 @@ class SQLCompiler(Compiled): def visit_rollback_to_savepoint(self, savepoint_stmt): return "ROLLBACK TO SAVEPOINT %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + self.preparer.format_savepoint(savepoint_stmt) def visit_release_savepoint(self, savepoint_stmt): return "RELEASE SAVEPOINT %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + self.preparer.format_savepoint(savepoint_stmt) class DDLCompiler(Compiled): @@ -2349,11 +2360,11 @@ class DDLCompiler(Compiled): table = create.element preparer = self.dialect.identifier_preparer - text = "\n" + " ".join(['CREATE'] + \ - table._prefixes + \ - ['TABLE', - preparer.format_table(table), - "("]) + text = "\n" + " ".join(['CREATE'] + + table._prefixes + + ['TABLE', + preparer.format_table(table), + "("]) separator = "\n" # if only one primary key, specify it along with the column @@ -2362,8 +2373,8 @@ class DDLCompiler(Compiled): column = create_column.element try: processed = self.process(create_column, - first_pk=column.primary_key - and not first_pk) + first_pk=column.primary_key + and not first_pk) if processed is not None: text += separator separator = ", \n" @@ -2372,11 +2383,10 @@ class DDLCompiler(Compiled): first_pk = True except exc.CompileError as ce: util.raise_from_cause( - exc.CompileError(util.u("(in table '%s', column '%s'): %s") % ( - table.description, - column.name, - ce.args[0] - ))) + exc.CompileError( + util.u("(in table '%s', column '%s'): %s") % + (table.description, column.name, ce.args[0]) + )) const = self.create_table_constraints(table) if const: @@ -2392,11 +2402,11 @@ class DDLCompiler(Compiled): return None text = self.get_column_specification( - column, - first_pk=first_pk - ) - const = " ".join(self.process(constraint) \ - for constraint in column.constraints) + column, + first_pk=first_pk + ) + const = " ".join(self.process(constraint) + for constraint in column.constraints) if const: text += " " + const @@ -2411,19 +2421,19 @@ class DDLCompiler(Compiled): constraints.append(table.primary_key) constraints.extend([c for c in table._sorted_constraints - if c is not table.primary_key]) + if c is not table.primary_key]) return ", \n\t".join(p for p in - (self.process(constraint) - for constraint in constraints - if ( - constraint._create_rule is None or - constraint._create_rule(self)) - and ( - not self.dialect.supports_alter or - not getattr(constraint, 'use_alter', False) - )) if p is not None - ) + (self.process(constraint) + for constraint in constraints + if ( + constraint._create_rule is None or + constraint._create_rule(self)) + and ( + not self.dialect.supports_alter or + not getattr(constraint, 'use_alter', False) + )) if p is not None + ) def visit_drop_table(self, drop): return "\nDROP TABLE " + self.preparer.format_table(drop.element) @@ -2431,15 +2441,13 @@ class DDLCompiler(Compiled): def visit_drop_view(self, drop): return "\nDROP VIEW " + self.preparer.format_table(drop.element) - def _verify_index_table(self, index): if index.table is None: raise exc.CompileError("Index '%s' is not associated " - "with any table." % index.name) - + "with any table." % index.name) def visit_create_index(self, create, include_schema=False, - include_table_schema=True): + include_table_schema=True): index = create.element self._verify_index_table(index) preparer = self.preparer @@ -2447,22 +2455,22 @@ class DDLCompiler(Compiled): if index.unique: text += "UNIQUE " text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=include_schema), - preparer.format_table(index.table, - use_schema=include_table_schema), - ', '.join( - self.sql_compiler.process(expr, - include_table=False, literal_binds=True) for - expr in index.expressions) - ) + % ( + self._prepared_index_name(index, + include_schema=include_schema), + preparer.format_table(index.table, + use_schema=include_table_schema), + ', '.join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True) for + expr in index.expressions) + ) return text def visit_drop_index(self, drop): index = drop.element - return "\nDROP INDEX " + self._prepared_index_name(index, - include_schema=True) + return "\nDROP INDEX " + self._prepared_index_name( + index, include_schema=True) def _prepared_index_name(self, index, include_schema=False): if include_schema and index.table is not None and index.table.schema: @@ -2474,10 +2482,10 @@ class DDLCompiler(Compiled): ident = index.name if isinstance(ident, elements._truncated_label): max_ = self.dialect.max_index_name_length or \ - self.dialect.max_identifier_length + self.dialect.max_identifier_length if len(ident) > max_: ident = ident[0:max_ - 8] + \ - "_" + util.md5_hex(ident)[-4:] + "_" + util.md5_hex(ident)[-4:] else: self.dialect.validate_identifier(ident) @@ -2495,7 +2503,7 @@ class DDLCompiler(Compiled): def visit_create_sequence(self, create): text = "CREATE SEQUENCE %s" % \ - self.preparer.format_sequence(create.element) + self.preparer.format_sequence(create.element) if create.element.increment is not None: text += " INCREMENT BY %d" % create.element.increment if create.element.start is not None: @@ -2504,7 +2512,7 @@ class DDLCompiler(Compiled): def visit_drop_sequence(self, drop): return "DROP SEQUENCE %s" % \ - self.preparer.format_sequence(drop.element) + self.preparer.format_sequence(drop.element) def visit_drop_constraint(self, drop): return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( @@ -2515,7 +2523,7 @@ class DDLCompiler(Compiled): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process(column.type) + self.dialect.type_compiler.process(column.type) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -2543,8 +2551,8 @@ class DDLCompiler(Compiled): if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext, - include_table=False, - literal_binds=True) + include_table=False, + literal_binds=True) text += self.define_constraint_deferrability(constraint) return text @@ -2568,7 +2576,7 @@ class DDLCompiler(Compiled): text += "CONSTRAINT %s " % formatted_name text += "PRIMARY KEY " text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in constraint) + for c in constraint) text += self.define_constraint_deferrability(constraint) return text @@ -2607,7 +2615,7 @@ class DDLCompiler(Compiled): text += "CONSTRAINT %s " % formatted_name text += "UNIQUE (%s)" % ( ', '.join(self.preparer.quote(c.name) - for c in constraint)) + for c in constraint)) text += self.define_constraint_deferrability(constraint) return text @@ -2650,22 +2658,22 @@ class GenericTypeCompiler(TypeCompiler): return "NUMERIC" elif type_.scale is None: return "NUMERIC(%(precision)s)" % \ - {'precision': type_.precision} + {'precision': type_.precision} else: return "NUMERIC(%(precision)s, %(scale)s)" % \ - {'precision': type_.precision, - 'scale': type_.scale} + {'precision': type_.precision, + 'scale': type_.scale} def visit_DECIMAL(self, type_): if type_.precision is None: return "DECIMAL" elif type_.scale is None: return "DECIMAL(%(precision)s)" % \ - {'precision': type_.precision} + {'precision': type_.precision} else: return "DECIMAL(%(precision)s, %(scale)s)" % \ - {'precision': type_.precision, - 'scale': type_.scale} + {'precision': type_.precision, + 'scale': type_.scale} def visit_INTEGER(self, type_): return "INTEGER" @@ -2780,8 +2788,8 @@ class GenericTypeCompiler(TypeCompiler): def visit_null(self, type_): raise exc.CompileError("Can't generate DDL for %r; " - "did you forget to specify a " - "type on this Column?" % type_) + "did you forget to specify a " + "type on this Column?" % type_) def visit_type_decorator(self, type_): return self.process(type_.type_engine(self.dialect)) @@ -2791,6 +2799,7 @@ class GenericTypeCompiler(TypeCompiler): class IdentifierPreparer(object): + """Handle quoting and case-folding of identifiers based on options.""" reserved_words = RESERVED_WORDS @@ -2800,7 +2809,7 @@ class IdentifierPreparer(object): illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS def __init__(self, dialect, initial_quote='"', - final_quote=None, escape_quote='"', omit_schema=False): + final_quote=None, escape_quote='"', omit_schema=False): """Construct a new ``IdentifierPreparer`` object. initial_quote @@ -2849,8 +2858,8 @@ class IdentifierPreparer(object): """ return self.initial_quote + \ - self._escape_identifier(value) + \ - self.final_quote + self._escape_identifier(value) + \ + self.final_quote def _requires_quotes(self, value): """Return True if the given identifier requires quoting.""" @@ -2895,7 +2904,8 @@ class IdentifierPreparer(object): def format_sequence(self, sequence, use_schema=True): name = self.quote(sequence.name) - if not self.omit_schema and use_schema and sequence.schema is not None: + if (not self.omit_schema and use_schema and + sequence.schema is not None): name = self.quote_schema(sequence.schema) + "." + name return name @@ -2912,7 +2922,7 @@ class IdentifierPreparer(object): def format_constraint(self, naming, constraint): if isinstance(constraint.name, elements._defer_name): name = naming._constraint_name_for_table( - constraint, constraint.table) + constraint, constraint.table) if name: return self.quote(name) elif isinstance(constraint.name, elements._defer_none_name): @@ -2926,7 +2936,7 @@ class IdentifierPreparer(object): name = table.name result = self.quote(name) if not self.omit_schema and use_schema \ - and getattr(table, "schema", None): + and getattr(table, "schema", None): result = self.quote_schema(table.schema) + "." + result return result @@ -2936,7 +2946,7 @@ class IdentifierPreparer(object): return self.quote(name, quote) def format_column(self, column, use_table=False, - name=None, table_name=None): + name=None, table_name=None): """Prepare a quoted column name.""" if name is None: @@ -2944,8 +2954,8 @@ class IdentifierPreparer(object): if not getattr(column, 'is_literal', False): if use_table: return self.format_table( - column.table, use_schema=False, - name=table_name) + "." + self.quote(name) + column.table, use_schema=False, + name=table_name) + "." + self.quote(name) else: return self.quote(name) else: @@ -2953,8 +2963,9 @@ class IdentifierPreparer(object): # which shouldn't get quoted if use_table: - return self.format_table(column.table, - use_schema=False, name=table_name) + '.' + name + return self.format_table( + column.table, use_schema=False, + name=table_name) + '.' + name else: return name @@ -2975,9 +2986,9 @@ class IdentifierPreparer(object): @util.memoized_property def _r_identifiers(self): initial, final, escaped_final = \ - [re.escape(s) for s in - (self.initial_quote, self.final_quote, - self._escape_identifier(self.final_quote))] + [re.escape(s) for s in + (self.initial_quote, self.final_quote, + self._escape_identifier(self.final_quote))] r = re.compile( r'(?:' r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' |