summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorBrian Jarrett <celttechie@gmail.com>2014-07-20 12:44:40 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2014-07-20 12:44:40 -0400
commitcca03097f47f22783d42d1853faac6cf84607c5a (patch)
tree4fe1a63d03a2d88d1cf37e1167759dfaf84f4ce7 /lib/sqlalchemy/sql/compiler.py
parent827329a0cca5351094a1a86b6b2be2b9182f0ae2 (diff)
downloadsqlalchemy-cca03097f47f22783d42d1853faac6cf84607c5a.tar.gz
- apply pep8 formatting to sqlalchemy/sql, sqlalchemy/util, sqlalchemy/dialects,
sqlalchemy/orm, sqlalchemy/event, sqlalchemy/testing
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py1019
1 files changed, 515 insertions, 504 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 384cf27c2..ac45054ae 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -25,7 +25,7 @@ To generate user-defined SQL strings, see
import re
from . import schema, sqltypes, operators, functions, \
- util as sql_util, visitors, elements, selectable, base
+ util as sql_util, visitors, elements, selectable, base
from .. import util, exc
import decimal
import itertools
@@ -158,7 +158,9 @@ COMPOUND_KEYWORDS = {
selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL'
}
+
class Compiled(object):
+
"""Represent a compiled SQL or DDL expression.
The ``__str__`` method of the ``Compiled`` object should produce
@@ -174,7 +176,7 @@ class Compiled(object):
_cached_metadata = None
def __init__(self, dialect, statement, bind=None,
- compile_kwargs=util.immutabledict()):
+ compile_kwargs=util.immutabledict()):
"""Construct a new ``Compiled`` object.
:param dialect: ``Dialect`` to compile against.
@@ -199,7 +201,7 @@ class Compiled(object):
self.string = self.process(self.statement, **compile_kwargs)
@util.deprecated("0.7", ":class:`.Compiled` objects now compile "
- "within the constructor.")
+ "within the constructor.")
def compile(self):
"""Produce the internal string representation of this element.
"""
@@ -247,8 +249,8 @@ class Compiled(object):
e = self.bind
if e is None:
raise exc.UnboundExecutionError(
- "This Compiled object is not bound to any Engine "
- "or Connection.")
+ "This Compiled object is not bound to any Engine "
+ "or Connection.")
return e._execute_compiled(self, multiparams, params)
def scalar(self, *multiparams, **params):
@@ -259,6 +261,7 @@ class Compiled(object):
class TypeCompiler(object):
+
"""Produces DDL specification for TypeEngine objects."""
def __init__(self, dialect):
@@ -268,8 +271,8 @@ class TypeCompiler(object):
return type_._compiler_dispatch(self)
-
class _CompileLabel(visitors.Visitable):
+
"""lightweight label object which acts as an expression.Label."""
__visit_name__ = 'label'
@@ -290,6 +293,7 @@ class _CompileLabel(visitors.Visitable):
class SQLCompiler(Compiled):
+
"""Default implementation of Compiled.
Compiles ClauseElements into SQL strings. Uses a similar visit
@@ -333,7 +337,7 @@ class SQLCompiler(Compiled):
"""
def __init__(self, dialect, statement, column_keys=None,
- inline=False, **kwargs):
+ inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
dialect
@@ -412,19 +416,19 @@ class SQLCompiler(Compiled):
def _apply_numbered_params(self):
poscount = itertools.count(1)
self.string = re.sub(
- r'\[_POSITION\]',
- lambda m: str(util.next(poscount)),
- self.string)
+ r'\[_POSITION\]',
+ lambda m: str(util.next(poscount)),
+ self.string)
@util.memoized_property
def _bind_processors(self):
return dict(
- (key, value) for key, value in
- ((self.bind_names[bindparam],
- bindparam.type._cached_bind_processor(self.dialect))
- for bindparam in self.bind_names)
- if value is not None
- )
+ (key, value) for key, value in
+ ((self.bind_names[bindparam],
+ bindparam.type._cached_bind_processor(self.dialect))
+ for bindparam in self.bind_names)
+ if value is not None
+ )
def is_subquery(self):
return len(self.stack) > 1
@@ -491,15 +495,16 @@ class SQLCompiler(Compiled):
return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
def visit_label(self, label,
- add_to_result_map=None,
- within_label_clause=False,
- within_columns_clause=False,
- render_label_as_label=None,
- **kw):
+ add_to_result_map=None,
+ within_label_clause=False,
+ within_columns_clause=False,
+ render_label_as_label=None,
+ **kw):
# only render labels within the columns clause
# or ORDER BY clause of a select. dialect-specific compilers
# can modify this behavior.
- render_label_with_as = within_columns_clause and not within_label_clause
+ render_label_with_as = (within_columns_clause and not
+ within_label_clause)
render_label_only = render_label_as_label is label
if render_label_only or render_label_with_as:
@@ -511,27 +516,25 @@ class SQLCompiler(Compiled):
if render_label_with_as:
if add_to_result_map is not None:
add_to_result_map(
- labelname,
- label.name,
- (label, labelname, ) + label._alt_names,
- label.type
+ labelname,
+ label.name,
+ (label, labelname, ) + label._alt_names,
+ label.type
)
- return label.element._compiler_dispatch(self,
- within_columns_clause=True,
- within_label_clause=True,
- **kw) + \
- OPERATORS[operators.as_] + \
- self.preparer.format_label(label, labelname)
+ return label.element._compiler_dispatch(
+ self, within_columns_clause=True,
+ within_label_clause=True, **kw) + \
+ OPERATORS[operators.as_] + \
+ self.preparer.format_label(label, labelname)
elif render_label_only:
return self.preparer.format_label(label, labelname)
else:
- return label.element._compiler_dispatch(self,
- within_columns_clause=False,
- **kw)
+ return label.element._compiler_dispatch(
+ self, within_columns_clause=False, **kw)
def visit_column(self, column, add_to_result_map=None,
- include_table=True, **kwargs):
+ include_table=True, **kwargs):
name = orig_name = column.name
if name is None:
raise exc.CompileError("Cannot compile Column object until "
@@ -567,8 +570,8 @@ class SQLCompiler(Compiled):
tablename = self._truncated_identifier("alias", tablename)
return schema_prefix + \
- self.preparer.quote(tablename) + \
- "." + name
+ self.preparer.quote(tablename) + \
+ "." + name
def escape_literal_column(self, text):
"""provide escaping for the literal_column() construct."""
@@ -597,37 +600,38 @@ class SQLCompiler(Compiled):
return self.bindparam_string(name, **kw)
# un-escape any \:params
- return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
- BIND_PARAMS.sub(do_bindparam,
- self.post_process_text(textclause.text))
+ return BIND_PARAMS_ESC.sub(
+ lambda m: m.group(1),
+ BIND_PARAMS.sub(
+ do_bindparam,
+ self.post_process_text(textclause.text))
)
def visit_text_as_from(self, taf, iswrapper=False,
- compound_index=0, force_result_map=False,
- asfrom=False,
- parens=True, **kw):
+ compound_index=0, force_result_map=False,
+ asfrom=False,
+ parens=True, **kw):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
populate_result_map = force_result_map or (
- compound_index == 0 and (
- toplevel or \
- entry['iswrapper']
- )
- )
+ compound_index == 0 and (
+ toplevel or
+ entry['iswrapper']
+ )
+ )
if populate_result_map:
for c in taf.column_args:
self.process(c, within_columns_clause=True,
- add_to_result_map=self._add_to_result_map)
+ add_to_result_map=self._add_to_result_map)
text = self.process(taf.element, **kw)
if asfrom and parens:
text = "(%s)" % text
return text
-
def visit_null(self, expr, **kw):
return 'NULL'
@@ -646,7 +650,7 @@ class SQLCompiler(Compiled):
def visit_clauselist(self, clauselist, order_by_select=None, **kw):
if order_by_select is not None:
return self._order_by_clauselist(
- clauselist, order_by_select, **kw)
+ clauselist, order_by_select, **kw)
sep = clauselist.operator
if sep is None:
@@ -654,11 +658,11 @@ class SQLCompiler(Compiled):
else:
sep = OPERATORS[clauselist.operator]
return sep.join(
- s for s in
- (
- c._compiler_dispatch(self, **kw)
- for c in clauselist.clauses)
- if s)
+ s for s in
+ (
+ c._compiler_dispatch(self, **kw)
+ for c in clauselist.clauses)
+ if s)
def _order_by_clauselist(self, clauselist, order_by_select, **kw):
# look through raw columns collection for labels.
@@ -667,21 +671,21 @@ class SQLCompiler(Compiled):
# label expression in the columns clause.
raw_col = set(l._order_by_label_element.name
- for l in order_by_select._raw_columns
- if l._order_by_label_element is not None)
+ for l in order_by_select._raw_columns
+ if l._order_by_label_element is not None)
return ", ".join(
- s for s in
- (
- c._compiler_dispatch(self,
- render_label_as_label=
- c._order_by_label_element if
- c._order_by_label_element is not None and
- c._order_by_label_element.name in raw_col
- else None,
- **kw)
- for c in clauselist.clauses)
- if s)
+ s for s in
+ (
+ c._compiler_dispatch(
+ self,
+ render_label_as_label=c._order_by_label_element if
+ c._order_by_label_element is not None and
+ c._order_by_label_element.name in raw_col
+ else None,
+ **kw)
+ for c in clauselist.clauses)
+ if s)
def visit_case(self, clause, **kwargs):
x = "CASE "
@@ -689,38 +693,38 @@ class SQLCompiler(Compiled):
x += clause.value._compiler_dispatch(self, **kwargs) + " "
for cond, result in clause.whens:
x += "WHEN " + cond._compiler_dispatch(
- self, **kwargs
- ) + " THEN " + result._compiler_dispatch(
- self, **kwargs) + " "
+ self, **kwargs
+ ) + " THEN " + result._compiler_dispatch(
+ self, **kwargs) + " "
if clause.else_ is not None:
x += "ELSE " + clause.else_._compiler_dispatch(
- self, **kwargs
- ) + " "
+ self, **kwargs
+ ) + " "
x += "END"
return x
def visit_cast(self, cast, **kwargs):
return "CAST(%s AS %s)" % \
- (cast.clause._compiler_dispatch(self, **kwargs),
- cast.typeclause._compiler_dispatch(self, **kwargs))
+ (cast.clause._compiler_dispatch(self, **kwargs),
+ cast.typeclause._compiler_dispatch(self, **kwargs))
def visit_over(self, over, **kwargs):
return "%s OVER (%s)" % (
over.func._compiler_dispatch(self, **kwargs),
' '.join(
- '%s BY %s' % (word, clause._compiler_dispatch(self, **kwargs))
- for word, clause in (
- ('PARTITION', over.partition_by),
- ('ORDER', over.order_by)
- )
- if clause is not None and len(clause)
+ '%s BY %s' % (word, clause._compiler_dispatch(self, **kwargs))
+ for word, clause in (
+ ('PARTITION', over.partition_by),
+ ('ORDER', over.order_by)
+ )
+ if clause is not None and len(clause)
)
)
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
- return "EXTRACT(%s FROM %s)" % (field,
- extract.expr._compiler_dispatch(self, **kwargs))
+ return "EXTRACT(%s FROM %s)" % (
+ field, extract.expr._compiler_dispatch(self, **kwargs))
def visit_function(self, func, add_to_result_map=None, **kwargs):
if add_to_result_map is not None:
@@ -734,7 +738,7 @@ class SQLCompiler(Compiled):
else:
name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
return ".".join(list(func.packagenames) + [name]) % \
- {'expr': self.function_argspec(func, **kwargs)}
+ {'expr': self.function_argspec(func, **kwargs)}
def visit_next_value_func(self, next_value, **kw):
return self.visit_sequence(next_value.sequence)
@@ -748,39 +752,38 @@ class SQLCompiler(Compiled):
def function_argspec(self, func, **kwargs):
return func.clause_expr._compiler_dispatch(self, **kwargs)
-
def visit_compound_select(self, cs, asfrom=False,
- parens=True, compound_index=0, **kwargs):
+ parens=True, compound_index=0, **kwargs):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
self.stack.append(
- {
- 'correlate_froms': entry['correlate_froms'],
- 'iswrapper': toplevel,
- 'asfrom_froms': entry['asfrom_froms']
- })
+ {
+ 'correlate_froms': entry['correlate_froms'],
+ 'iswrapper': toplevel,
+ 'asfrom_froms': entry['asfrom_froms']
+ })
keyword = self.compound_keywords.get(cs.keyword)
text = (" " + keyword + " ").join(
- (c._compiler_dispatch(self,
- asfrom=asfrom, parens=False,
- compound_index=i, **kwargs)
- for i, c in enumerate(cs.selects))
- )
+ (c._compiler_dispatch(self,
+ asfrom=asfrom, parens=False,
+ compound_index=i, **kwargs)
+ for i, c in enumerate(cs.selects))
+ )
group_by = cs._group_by_clause._compiler_dispatch(
- self, asfrom=asfrom, **kwargs)
+ self, asfrom=asfrom, **kwargs)
if group_by:
text += " GROUP BY " + group_by
text += self.order_by_clause(cs, **kwargs)
text += (cs._limit_clause is not None or cs._offset_clause is not None) and \
- self.limit_clause(cs) or ""
+ self.limit_clause(cs) or ""
if self.ctes and \
- compound_index == 0 and toplevel:
+ compound_index == 0 and toplevel:
text = self._render_cte_clause() + text
self.stack.pop(-1)
@@ -793,26 +796,26 @@ class SQLCompiler(Compiled):
if unary.operator:
if unary.modifier:
raise exc.CompileError(
- "Unary expression does not support operator "
- "and modifier simultaneously")
+ "Unary expression does not support operator "
+ "and modifier simultaneously")
disp = getattr(self, "visit_%s_unary_operator" %
- unary.operator.__name__, None)
+ unary.operator.__name__, None)
if disp:
return disp(unary, unary.operator, **kw)
else:
- return self._generate_generic_unary_operator(unary,
- OPERATORS[unary.operator], **kw)
+ return self._generate_generic_unary_operator(
+ unary, OPERATORS[unary.operator], **kw)
elif unary.modifier:
disp = getattr(self, "visit_%s_unary_modifier" %
- unary.modifier.__name__, None)
+ unary.modifier.__name__, None)
if disp:
return disp(unary, unary.modifier, **kw)
else:
- return self._generate_generic_unary_modifier(unary,
- OPERATORS[unary.modifier], **kw)
+ return self._generate_generic_unary_modifier(
+ unary, OPERATORS[unary.modifier], **kw)
else:
raise exc.CompileError(
- "Unary expression has no operator or modifier")
+ "Unary expression has no operator or modifier")
def visit_istrue_unary_operator(self, element, operator, **kw):
if self.dialect.supports_native_boolean:
@@ -829,8 +832,8 @@ class SQLCompiler(Compiled):
def visit_binary(self, binary, **kw):
# don't allow "? = ?" to render
if self.ansi_bind_rules and \
- isinstance(binary.left, elements.BindParameter) and \
- isinstance(binary.right, elements.BindParameter):
+ isinstance(binary.left, elements.BindParameter) and \
+ isinstance(binary.right, elements.BindParameter):
kw['literal_binds'] = True
operator = binary.operator
@@ -846,21 +849,21 @@ class SQLCompiler(Compiled):
return self._generate_generic_binary(binary, opstring, **kw)
def visit_custom_op_binary(self, element, operator, **kw):
- return self._generate_generic_binary(element,
- " " + operator.opstring + " ", **kw)
+ return self._generate_generic_binary(
+ element, " " + operator.opstring + " ", **kw)
def visit_custom_op_unary_operator(self, element, operator, **kw):
- return self._generate_generic_unary_operator(element,
- operator.opstring + " ", **kw)
+ return self._generate_generic_unary_operator(
+ element, operator.opstring + " ", **kw)
def visit_custom_op_unary_modifier(self, element, operator, **kw):
- return self._generate_generic_unary_modifier(element,
- " " + operator.opstring, **kw)
+ return self._generate_generic_unary_modifier(
+ element, " " + operator.opstring, **kw)
def _generate_generic_binary(self, binary, opstring, **kw):
return binary.left._compiler_dispatch(self, **kw) + \
- opstring + \
- binary.right._compiler_dispatch(self, **kw)
+ opstring + \
+ binary.right._compiler_dispatch(self, **kw)
def _generate_generic_unary_operator(self, unary, opstring, **kw):
return opstring + unary.element._compiler_dispatch(self, **kw)
@@ -888,16 +891,16 @@ class SQLCompiler(Compiled):
binary = binary._clone()
percent = self._like_percent_literal
binary.right = percent.__radd__(
- binary.right
- )
+ binary.right
+ )
return self.visit_like_op_binary(binary, operator, **kw)
def visit_notstartswith_op_binary(self, binary, operator, **kw):
binary = binary._clone()
percent = self._like_percent_literal
binary.right = percent.__radd__(
- binary.right
- )
+ binary.right
+ )
return self.visit_notlike_op_binary(binary, operator, **kw)
def visit_endswith_op_binary(self, binary, operator, **kw):
@@ -917,77 +920,77 @@ class SQLCompiler(Compiled):
# TODO: use ternary here, not "and"/ "or"
return '%s LIKE %s' % (
- binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw)) \
+ (
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_notlike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
return '%s NOT LIKE %s' % (
- binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw)) \
+ (
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_ilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) LIKE lower(%s)' % (
- binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw)) \
+ (
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_notilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) NOT LIKE lower(%s)' % (
- binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw)) \
+ (
' ESCAPE ' +
self.render_literal_value(escape, sqltypes.STRINGTYPE)
if escape else ''
- )
+ )
def visit_between_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
return self._generate_generic_binary(
- binary, " BETWEEN SYMMETRIC "
- if symmetric else " BETWEEN ", **kw)
+ binary, " BETWEEN SYMMETRIC "
+ if symmetric else " BETWEEN ", **kw)
def visit_notbetween_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
return self._generate_generic_binary(
- binary, " NOT BETWEEN SYMMETRIC "
- if symmetric else " NOT BETWEEN ", **kw)
+ binary, " NOT BETWEEN SYMMETRIC "
+ if symmetric else " NOT BETWEEN ", **kw)
def visit_bindparam(self, bindparam, within_columns_clause=False,
- literal_binds=False,
- skip_bind_expression=False,
- **kwargs):
+ literal_binds=False,
+ skip_bind_expression=False,
+ **kwargs):
if not skip_bind_expression and bindparam.type._has_bind_expression:
bind_expression = bindparam.type.bind_expression(bindparam)
return self.process(bind_expression,
skip_bind_expression=True)
if literal_binds or \
- (within_columns_clause and \
+ (within_columns_clause and
self.ansi_bind_rules):
if bindparam.value is None and bindparam.callable is None:
raise exc.CompileError("Bind parameter '%s' without a "
- "renderable value not allowed here."
- % bindparam.key)
- return self.render_literal_bindparam(bindparam,
- within_columns_clause=True, **kwargs)
+ "renderable value not allowed here."
+ % bindparam.key)
+ return self.render_literal_bindparam(
+ bindparam, within_columns_clause=True, **kwargs)
name = self._truncate_bindparam(bindparam)
@@ -995,13 +998,13 @@ class SQLCompiler(Compiled):
existing = self.binds[name]
if existing is not bindparam:
if (existing.unique or bindparam.unique) and \
- not existing.proxy_set.intersection(
- bindparam.proxy_set):
+ not existing.proxy_set.intersection(
+ bindparam.proxy_set):
raise exc.CompileError(
- "Bind parameter '%s' conflicts with "
- "unique bind parameter of the same name" %
- bindparam.key
- )
+ "Bind parameter '%s' conflicts with "
+ "unique bind parameter of the same name" %
+ bindparam.key
+ )
elif existing._is_crud or bindparam._is_crud:
raise exc.CompileError(
"bindparam() name '%s' is reserved "
@@ -1009,8 +1012,8 @@ class SQLCompiler(Compiled):
"clause of this "
"insert/update statement. Please use a "
"name other than column name when using bindparam() "
- "with insert() or update() (for example, 'b_%s')."
- % (bindparam.key, bindparam.key)
+ "with insert() or update() (for example, 'b_%s')." %
+ (bindparam.key, bindparam.key)
)
self.binds[bindparam.key] = self.binds[name] = bindparam
@@ -1037,7 +1040,7 @@ class SQLCompiler(Compiled):
return processor(value)
else:
raise NotImplementedError(
- "Don't know how to literal-quote value %r" % value)
+ "Don't know how to literal-quote value %r" % value)
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
@@ -1061,7 +1064,7 @@ class SQLCompiler(Compiled):
if len(anonname) > self.label_length:
counter = self.truncated_names.get(ident_class, 1)
truncname = anonname[0:max(self.label_length - 6, 0)] + \
- "_" + hex(counter)[2:]
+ "_" + hex(counter)[2:]
self.truncated_names[ident_class] = counter + 1
else:
truncname = anonname
@@ -1086,8 +1089,8 @@ class SQLCompiler(Compiled):
return self.bindtemplate % {'name': name}
def visit_cte(self, cte, asfrom=False, ashint=False,
- fromhints=None,
- **kwargs):
+ fromhints=None,
+ **kwargs):
self._init_cte_state()
if isinstance(cte.name, elements._truncated_label):
@@ -1108,9 +1111,9 @@ class SQLCompiler(Compiled):
del self.ctes[existing_cte]
else:
raise exc.CompileError(
- "Multiple, unrelated CTEs found with "
- "the same name: %r" %
- cte_name)
+ "Multiple, unrelated CTEs found with "
+ "the same name: %r" %
+ cte_name)
self.ctes_by_name[cte_name] = cte
@@ -1120,7 +1123,8 @@ class SQLCompiler(Compiled):
self.visit_cte(orig_cte)
cte_alias_name = cte._cte_alias.name
if isinstance(cte_alias_name, elements._truncated_label):
- cte_alias_name = self._truncated_identifier("alias", cte_alias_name)
+ cte_alias_name = self._truncated_identifier(
+ "alias", cte_alias_name)
else:
orig_cte = cte
cte_alias_name = None
@@ -1136,20 +1140,20 @@ class SQLCompiler(Compiled):
else:
assert False
recur_cols = [c for c in
- util.unique_list(col_source.inner_columns)
- if c is not None]
+ util.unique_list(col_source.inner_columns)
+ if c is not None]
text += "(%s)" % (", ".join(
- self.preparer.format_column(ident)
- for ident in recur_cols))
+ self.preparer.format_column(ident)
+ for ident in recur_cols))
if self.positional:
kwargs['positional_names'] = self.cte_positional[cte] = []
text += " AS \n" + \
- cte.original._compiler_dispatch(
- self, asfrom=True, **kwargs
- )
+ cte.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ )
self.ctes[cte] = text
@@ -1162,8 +1166,8 @@ class SQLCompiler(Compiled):
return text
def visit_alias(self, alias, asfrom=False, ashint=False,
- iscrud=False,
- fromhints=None, **kwargs):
+ iscrud=False,
+ fromhints=None, **kwargs):
if asfrom or ashint:
if isinstance(alias.name, elements._truncated_label):
alias_name = self._truncated_identifier("alias", alias.name)
@@ -1174,13 +1178,13 @@ class SQLCompiler(Compiled):
return self.preparer.format_alias(alias, alias_name)
elif asfrom:
ret = alias.original._compiler_dispatch(self,
- asfrom=True, **kwargs) + \
- " AS " + \
- self.preparer.format_alias(alias, alias_name)
+ asfrom=True, **kwargs) + \
+ " AS " + \
+ self.preparer.format_alias(alias, alias_name)
if fromhints and alias in fromhints:
ret = self.format_from_hint_text(ret, alias,
- fromhints[alias], iscrud)
+ fromhints[alias], iscrud)
return ret
else:
@@ -1201,19 +1205,19 @@ class SQLCompiler(Compiled):
self.result_map[keyname] = name, objects, type_
def _label_select_column(self, select, column,
- populate_result_map,
- asfrom, column_clause_args,
- name=None,
- within_columns_clause=True):
+ populate_result_map,
+ asfrom, column_clause_args,
+ name=None,
+ within_columns_clause=True):
"""produce labeled columns present in a select()."""
if column.type._has_column_expression and \
populate_result_map:
col_expr = column.type.column_expression(column)
add_to_result_map = lambda keyname, name, objects, type_: \
- self._add_to_result_map(
- keyname, name,
- objects + (column,), type_)
+ self._add_to_result_map(
+ keyname, name,
+ objects + (column,), type_)
else:
col_expr = column
if populate_result_map:
@@ -1226,19 +1230,19 @@ class SQLCompiler(Compiled):
elif isinstance(column, elements.Label):
if col_expr is not column:
result_expr = _CompileLabel(
- col_expr,
- column.name,
- alt_names=(column.element,)
- )
+ col_expr,
+ column.name,
+ alt_names=(column.element,)
+ )
else:
result_expr = col_expr
elif select is not None and name:
result_expr = _CompileLabel(
- col_expr,
- name,
- alt_names=(column._key_label,)
- )
+ col_expr,
+ name,
+ alt_names=(column._key_label,)
+ )
elif \
asfrom and \
@@ -1247,30 +1251,30 @@ class SQLCompiler(Compiled):
column.table is not None and \
not isinstance(column.table, selectable.Select):
result_expr = _CompileLabel(col_expr,
- elements._as_truncated(column.name),
- alt_names=(column.key,))
+ elements._as_truncated(column.name),
+ alt_names=(column.key,))
elif not isinstance(column,
- (elements.UnaryExpression, elements.TextClause)) \
- and (not hasattr(column, 'name') or \
- isinstance(column, functions.Function)):
+ (elements.UnaryExpression, elements.TextClause)) \
+ and (not hasattr(column, 'name') or
+ isinstance(column, functions.Function)):
result_expr = _CompileLabel(col_expr, column.anon_label)
elif col_expr is not column:
# TODO: are we sure "column" has a .name and .key here ?
# assert isinstance(column, elements.ColumnClause)
result_expr = _CompileLabel(col_expr,
- elements._as_truncated(column.name),
- alt_names=(column.key,))
+ elements._as_truncated(column.name),
+ alt_names=(column.key,))
else:
result_expr = col_expr
column_clause_args.update(
- within_columns_clause=within_columns_clause,
- add_to_result_map=add_to_result_map
- )
+ within_columns_clause=within_columns_clause,
+ add_to_result_map=add_to_result_map
+ )
return result_expr._compiler_dispatch(
- self,
- **column_clause_args
- )
+ self,
+ **column_clause_args
+ )
def format_from_hint_text(self, sqltext, table, hint, iscrud):
hinttext = self.get_from_hint_text(table, hint)
@@ -1307,7 +1311,7 @@ class SQLCompiler(Compiled):
newelem = cloned[element] = element._clone()
if newelem.is_selectable and newelem._is_join and \
- isinstance(newelem.right, selectable.FromGrouping):
+ isinstance(newelem.right, selectable.FromGrouping):
newelem._reset_exported()
newelem.left = visit(newelem.left, **kw)
@@ -1376,24 +1380,24 @@ class SQLCompiler(Compiled):
return visit(select)
- def _transform_result_map_for_nested_joins(self, select, transformed_select):
+ def _transform_result_map_for_nested_joins(
+ self, select, transformed_select):
inner_col = dict((c._key_label, c) for
- c in transformed_select.inner_columns)
+ c in transformed_select.inner_columns)
d = dict(
- (inner_col[c._key_label], c)
- for c in select.inner_columns
- )
+ (inner_col[c._key_label], c)
+ for c in select.inner_columns
+ )
for key, (name, objs, typ) in list(self.result_map.items()):
objs = tuple([d.get(col, col) for col in objs])
self.result_map[key] = (name, objs, typ)
-
_default_stack_entry = util.immutabledict([
- ('iswrapper', False),
- ('correlate_froms', frozenset()),
- ('asfrom_froms', frozenset())
- ])
+ ('iswrapper', False),
+ ('correlate_froms', frozenset()),
+ ('asfrom_froms', frozenset())
+ ])
def _display_froms_for_select(self, select, asfrom):
# utility method to help external dialects
@@ -1408,53 +1412,53 @@ class SQLCompiler(Compiled):
if asfrom:
froms = select._get_display_froms(
- explicit_correlate_froms=\
- correlate_froms.difference(asfrom_froms),
- implicit_correlate_froms=())
+ explicit_correlate_froms=correlate_froms.difference(
+ asfrom_froms),
+ implicit_correlate_froms=())
else:
froms = select._get_display_froms(
- explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
+ explicit_correlate_froms=correlate_froms,
+ implicit_correlate_froms=asfrom_froms)
return froms
def visit_select(self, select, asfrom=False, parens=True,
- iswrapper=False, fromhints=None,
- compound_index=0,
- force_result_map=False,
- nested_join_translation=False,
- **kwargs):
+ iswrapper=False, fromhints=None,
+ compound_index=0,
+ force_result_map=False,
+ nested_join_translation=False,
+ **kwargs):
needs_nested_translation = \
- select.use_labels and \
- not nested_join_translation and \
- not self.stack and \
- not self.dialect.supports_right_nested_joins
+ select.use_labels and \
+ not nested_join_translation and \
+ not self.stack and \
+ not self.dialect.supports_right_nested_joins
if needs_nested_translation:
- transformed_select = self._transform_select_for_nested_joins(select)
+ transformed_select = self._transform_select_for_nested_joins(
+ select)
text = self.visit_select(
- transformed_select, asfrom=asfrom, parens=parens,
- iswrapper=iswrapper, fromhints=fromhints,
- compound_index=compound_index,
- force_result_map=force_result_map,
- nested_join_translation=True, **kwargs
- )
+ transformed_select, asfrom=asfrom, parens=parens,
+ iswrapper=iswrapper, fromhints=fromhints,
+ compound_index=compound_index,
+ force_result_map=force_result_map,
+ nested_join_translation=True, **kwargs
+ )
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
-
populate_result_map = force_result_map or (
- compound_index == 0 and (
- toplevel or \
- entry['iswrapper']
- )
- )
+ compound_index == 0 and (
+ toplevel or
+ entry['iswrapper']
+ )
+ )
if needs_nested_translation:
if populate_result_map:
self._transform_result_map_for_nested_joins(
- select, transformed_select)
+ select, transformed_select)
return text
correlate_froms = entry['correlate_froms']
@@ -1462,48 +1466,49 @@ class SQLCompiler(Compiled):
if asfrom:
froms = select._get_display_froms(
- explicit_correlate_froms=
- correlate_froms.difference(asfrom_froms),
- implicit_correlate_froms=())
+ explicit_correlate_froms=correlate_froms.difference(
+ asfrom_froms),
+ implicit_correlate_froms=())
else:
froms = select._get_display_froms(
- explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
+ explicit_correlate_froms=correlate_froms,
+ implicit_correlate_froms=asfrom_froms)
new_correlate_froms = set(selectable._from_objects(*froms))
all_correlate_froms = new_correlate_froms.union(correlate_froms)
new_entry = {
- 'asfrom_froms': new_correlate_froms,
- 'iswrapper': iswrapper,
- 'correlate_froms': all_correlate_froms
- }
+ 'asfrom_froms': new_correlate_froms,
+ 'iswrapper': iswrapper,
+ 'correlate_froms': all_correlate_froms
+ }
self.stack.append(new_entry)
column_clause_args = kwargs.copy()
column_clause_args.update({
- 'within_label_clause': False,
- 'within_columns_clause': False
- })
+ 'within_label_clause': False,
+ 'within_columns_clause': False
+ })
text = "SELECT " # we're off to a good start !
if select._hints:
byfrom = dict([
- (from_, hinttext % {
- 'name':from_._compiler_dispatch(
- self, ashint=True)
- })
- for (from_, dialect), hinttext in
- select._hints.items()
- if dialect in ('*', self.dialect.name)
- ])
+ (from_, hinttext % {
+ 'name': from_._compiler_dispatch(
+ self, ashint=True)
+ })
+ for (from_, dialect), hinttext in
+ select._hints.items()
+ if dialect in ('*', self.dialect.name)
+ ])
hint_text = self.get_select_hint_text(byfrom)
if hint_text:
text += hint_text + " "
if select._prefixes:
- text += self._generate_prefixes(select, select._prefixes, **kwargs)
+ text += self._generate_prefixes(
+ select, select._prefixes, **kwargs)
text += self.get_select_precolumns(select)
@@ -1511,12 +1516,12 @@ class SQLCompiler(Compiled):
inner_columns = [
c for c in [
self._label_select_column(select,
- column,
- populate_result_map, asfrom,
- column_clause_args,
- name=name)
+ column,
+ populate_result_map, asfrom,
+ column_clause_args,
+ name=name)
for name, column in select._columns_plus_names
- ]
+ ]
if c is not None
]
@@ -1526,14 +1531,14 @@ class SQLCompiler(Compiled):
text += " \nFROM "
if select._hints:
- text += ', '.join([f._compiler_dispatch(self,
- asfrom=True, fromhints=byfrom,
- **kwargs)
- for f in froms])
+ text += ', '.join(
+ [f._compiler_dispatch(self, asfrom=True,
+ fromhints=byfrom, **kwargs)
+ for f in froms])
else:
- text += ', '.join([f._compiler_dispatch(self,
- asfrom=True, **kwargs)
- for f in froms])
+ text += ', '.join(
+ [f._compiler_dispatch(self, asfrom=True, **kwargs)
+ for f in froms])
else:
text += self.default_from()
@@ -1544,7 +1549,7 @@ class SQLCompiler(Compiled):
if select._group_by_clause.clauses:
group_by = select._group_by_clause._compiler_dispatch(
- self, **kwargs)
+ self, **kwargs)
if group_by:
text += " GROUP BY " + group_by
@@ -1559,17 +1564,18 @@ class SQLCompiler(Compiled):
else:
order_by_select = None
- text += self.order_by_clause(select,
- order_by_select=order_by_select, **kwargs)
+ text += self.order_by_clause(
+ select, order_by_select=order_by_select, **kwargs)
- if select._limit_clause is not None or select._offset_clause is not None:
+ if (select._limit_clause is not None or
+ select._offset_clause is not None):
text += self.limit_clause(select)
if select._for_update_arg is not None:
text += self.for_update_clause(select)
if self.ctes and \
- compound_index == 0 and toplevel:
+ compound_index == 0 and toplevel:
text = self._render_cte_clause() + text
self.stack.pop(-1)
@@ -1581,11 +1587,11 @@ class SQLCompiler(Compiled):
def _generate_prefixes(self, stmt, prefixes, **kw):
clause = " ".join(
- prefix._compiler_dispatch(self, **kw)
- for prefix, dialect_name in prefixes
- if dialect_name is None or
- dialect_name == self.dialect.name
- )
+ prefix._compiler_dispatch(self, **kw)
+ for prefix, dialect_name in prefixes
+ if dialect_name is None or
+ dialect_name == self.dialect.name
+ )
if clause:
clause += " "
return clause
@@ -1593,9 +1599,9 @@ class SQLCompiler(Compiled):
def _render_cte_clause(self):
if self.positional:
self.positiontup = sum([
- self.cte_positional[cte]
- for cte in self.ctes], []) + \
- self.positiontup
+ self.cte_positional[cte]
+ for cte in self.ctes], []) + \
+ self.positiontup
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
cte_text += ", \n".join(
[txt for txt in self.ctes.values()]
@@ -1628,8 +1634,8 @@ class SQLCompiler(Compiled):
def returning_clause(self, stmt, returning_cols):
raise exc.CompileError(
- "RETURNING is not supported by this "
- "dialect's statement compiler.")
+ "RETURNING is not supported by this "
+ "dialect's statement compiler.")
def limit_clause(self, select):
text = ""
@@ -1642,16 +1648,16 @@ class SQLCompiler(Compiled):
return text
def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
- fromhints=None, **kwargs):
+ fromhints=None, **kwargs):
if asfrom or ashint:
if getattr(table, "schema", None):
ret = self.preparer.quote_schema(table.schema) + \
- "." + self.preparer.quote(table.name)
+ "." + self.preparer.quote(table.name)
else:
ret = self.preparer.quote(table.name)
if fromhints and table in fromhints:
ret = self.format_from_hint_text(ret, table,
- fromhints[table], iscrud)
+ fromhints[table], iscrud)
return ret
else:
return ""
@@ -1673,21 +1679,21 @@ class SQLCompiler(Compiled):
not self.dialect.supports_default_values and \
not self.dialect.supports_empty_insert:
raise exc.CompileError("The '%s' dialect with current database "
- "version settings does not support empty "
- "inserts." %
- self.dialect.name)
+ "version settings does not support empty "
+ "inserts." %
+ self.dialect.name)
if insert_stmt._has_multi_parameters:
if not self.dialect.supports_multivalues_insert:
- raise exc.CompileError("The '%s' dialect with current database "
- "version settings does not support "
- "in-place multirow inserts." %
- self.dialect.name)
+ raise exc.CompileError(
+ "The '%s' dialect with current database "
+ "version settings does not support "
+ "in-place multirow inserts." %
+ self.dialect.name)
colparams_single = colparams[0]
else:
colparams_single = colparams
-
preparer = self.preparer
supports_default_values = self.dialect.supports_default_values
@@ -1695,7 +1701,7 @@ class SQLCompiler(Compiled):
if insert_stmt._prefixes:
text += self._generate_prefixes(insert_stmt,
- insert_stmt._prefixes, **kw)
+ insert_stmt._prefixes, **kw)
text += "INTO "
table_text = preparer.format_table(insert_stmt.table)
@@ -1709,22 +1715,22 @@ class SQLCompiler(Compiled):
])
if insert_stmt.table in dialect_hints:
table_text = self.format_from_hint_text(
- table_text,
- insert_stmt.table,
- dialect_hints[insert_stmt.table],
- True
- )
+ table_text,
+ insert_stmt.table,
+ dialect_hints[insert_stmt.table],
+ True
+ )
text += table_text
if colparams_single or not supports_default_values:
text += " (%s)" % ', '.join([preparer.format_column(c[0])
- for c in colparams_single])
+ for c in colparams_single])
if self.returning or insert_stmt._returning:
self.returning = self.returning or insert_stmt._returning
returning_clause = self.returning_clause(
- insert_stmt, self.returning)
+ insert_stmt, self.returning)
if self.returning_precedes_values:
text += " " + returning_clause
@@ -1735,16 +1741,16 @@ class SQLCompiler(Compiled):
text += " DEFAULT VALUES"
elif insert_stmt._has_multi_parameters:
text += " VALUES %s" % (
- ", ".join(
- "(%s)" % (
- ', '.join(c[1] for c in colparam_set)
- )
- for colparam_set in colparams
- )
- )
+ ", ".join(
+ "(%s)" % (
+ ', '.join(c[1] for c in colparam_set)
+ )
+ for colparam_set in colparams
+ )
+ )
else:
text += " VALUES (%s)" % \
- ', '.join([c[1] for c in colparams])
+ ', '.join([c[1] for c in colparams])
if self.returning and not self.returning_precedes_values:
text += " " + returning_clause
@@ -1756,7 +1762,7 @@ class SQLCompiler(Compiled):
return None
def update_tables_clause(self, update_stmt, from_table,
- extra_froms, **kw):
+ extra_froms, **kw):
"""Provide a hook to override the initial table clause
in an UPDATE statement.
@@ -1764,12 +1770,12 @@ class SQLCompiler(Compiled):
"""
return from_table._compiler_dispatch(self, asfrom=True,
- iscrud=True, **kw)
+ iscrud=True, **kw)
def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
+ from_table, extra_froms,
+ from_hints,
+ **kw):
"""Provide a hook to override the generation of an
UPDATE..FROM clause.
@@ -1777,15 +1783,15 @@ class SQLCompiler(Compiled):
"""
return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in extra_froms)
+ t._compiler_dispatch(self, asfrom=True,
+ fromhints=from_hints, **kw)
+ for t in extra_froms)
def visit_update(self, update_stmt, **kw):
self.stack.append(
- {'correlate_froms': set([update_stmt.table]),
- "iswrapper": False,
- "asfrom_froms": set([update_stmt.table])})
+ {'correlate_froms': set([update_stmt.table]),
+ "iswrapper": False,
+ "asfrom_froms": set([update_stmt.table])})
self.isupdate = True
@@ -1795,7 +1801,7 @@ class SQLCompiler(Compiled):
if update_stmt._prefixes:
text += self._generate_prefixes(update_stmt,
- update_stmt._prefixes, **kw)
+ update_stmt._prefixes, **kw)
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
@@ -1811,11 +1817,11 @@ class SQLCompiler(Compiled):
])
if update_stmt.table in dialect_hints:
table_text = self.format_from_hint_text(
- table_text,
- update_stmt.table,
- dialect_hints[update_stmt.table],
- True
- )
+ table_text,
+ update_stmt.table,
+ dialect_hints[update_stmt.table],
+ True
+ )
else:
dialect_hints = None
@@ -1823,26 +1829,26 @@ class SQLCompiler(Compiled):
text += ' SET '
include_table = extra_froms and \
- self.render_table_with_column_in_update_from
+ self.render_table_with_column_in_update_from
text += ', '.join(
- c[0]._compiler_dispatch(self,
- include_table=include_table) +
- '=' + c[1] for c in colparams
- )
+ c[0]._compiler_dispatch(self,
+ include_table=include_table) +
+ '=' + c[1] for c in colparams
+ )
if self.returning or update_stmt._returning:
if not self.returning:
self.returning = update_stmt._returning
if self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning)
+ update_stmt, self.returning)
if extra_froms:
extra_from_text = self.update_from_clause(
- update_stmt,
- update_stmt.table,
- extra_froms,
- dialect_hints, **kw)
+ update_stmt,
+ update_stmt.table,
+ extra_froms,
+ dialect_hints, **kw)
if extra_from_text:
text += " " + extra_from_text
@@ -1857,7 +1863,7 @@ class SQLCompiler(Compiled):
if self.returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning)
+ update_stmt, self.returning)
self.stack.pop(-1)
@@ -1867,7 +1873,7 @@ class SQLCompiler(Compiled):
if name is None:
name = col.key
bindparam = elements.BindParameter(name, value,
- type_=col.type, required=required)
+ type_=col.type, required=required)
bindparam._is_crud = True
return bindparam._compiler_dispatch(self)
@@ -1881,17 +1887,20 @@ class SQLCompiler(Compiled):
# allowing the most compatibility with a non-multi-table
# statement.
_et = set(self.statement._extra_froms)
+
def _column_as_key(key):
str_key = elements._column_as_key(key)
if hasattr(key, 'table') and key.table in _et:
return (key.table.name, str_key)
else:
return str_key
+
def _getattr_col_key(col):
if col.table in _et:
return (col.table.name, col.key)
else:
return col.key
+
def _col_bind_name(col):
if col.table in _et:
return "%s_%s" % (col.table.name, col.key)
@@ -1923,10 +1932,10 @@ class SQLCompiler(Compiled):
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
return [
- (c, self._create_crud_bind_param(c,
- None, required=True))
- for c in stmt.table.columns
- ]
+ (c, self._create_crud_bind_param(c,
+ None, required=True))
+ for c in stmt.table.columns
+ ]
if stmt._has_multi_parameters:
stmt_parameters = stmt.parameters[0]
@@ -1937,7 +1946,7 @@ class SQLCompiler(Compiled):
# but in the case of mysql multi-table update, the rules for
# .key must conditionally take tablename into account
_column_as_key, _getattr_col_key, _col_bind_name = \
- self._key_getters_for_crud_column
+ self._key_getters_for_crud_column
# if we have statement parameters - set defaults in the
# compiled params
@@ -1963,26 +1972,27 @@ class SQLCompiler(Compiled):
# coercing right side to bound param
if elements._is_literal(v):
v = self.process(
- elements.BindParameter(None, v, type_=k.type),
- **kw)
+ elements.BindParameter(None, v, type_=k.type),
+ **kw)
else:
v = self.process(v.self_group(), **kw)
values.append((k, v))
need_pks = self.isinsert and \
- not self.inline and \
- not stmt._returning
+ not self.inline and \
+ not stmt._returning
implicit_returning = need_pks and \
- self.dialect.implicit_returning and \
- stmt.table.implicit_returning
+ self.dialect.implicit_returning and \
+ stmt.table.implicit_returning
if self.isinsert:
- implicit_return_defaults = implicit_returning and stmt._return_defaults
+ implicit_return_defaults = (implicit_returning and
+ stmt._return_defaults)
elif self.isupdate:
- implicit_return_defaults = self.dialect.implicit_returning and \
- stmt.table.implicit_returning and \
- stmt._return_defaults
+ implicit_return_defaults = (self.dialect.implicit_returning and
+ stmt.table.implicit_returning and
+ stmt._return_defaults)
else:
implicit_return_defaults = False
@@ -2025,20 +2035,21 @@ class SQLCompiler(Compiled):
for c in t.c:
if c in normalized_params:
continue
- elif c.onupdate is not None and not c.onupdate.is_sequence:
+ elif (c.onupdate is not None and not
+ c.onupdate.is_sequence):
if c.onupdate.is_clause_element:
values.append(
(c, self.process(
- c.onupdate.arg.self_group(),
- **kw)
- )
+ c.onupdate.arg.self_group(),
+ **kw)
+ )
)
self.postfetch.append(c)
else:
values.append(
(c, self._create_crud_bind_param(
- c, None, name=_col_bind_name(c)
- )
+ c, None, name=_col_bind_name(c)
+ )
)
)
self.prefetch.append(c)
@@ -2049,7 +2060,7 @@ class SQLCompiler(Compiled):
# for an insert from select, we can only use names that
# are given, so only select for those names.
cols = (stmt.table.c[_column_as_key(name)]
- for name in stmt.select_names)
+ for name in stmt.select_names)
else:
# iterate through all table columns to maintain
# ordering, even for those cols that aren't included
@@ -2061,14 +2072,14 @@ class SQLCompiler(Compiled):
value = parameters.pop(col_key)
if elements._is_literal(value):
value = self._create_crud_bind_param(
- c, value, required=value is REQUIRED,
- name=_col_bind_name(c)
- if not stmt._has_multi_parameters
- else "%s_0" % _col_bind_name(c)
- )
+ c, value, required=value is REQUIRED,
+ name=_col_bind_name(c)
+ if not stmt._has_multi_parameters
+ else "%s_0" % _col_bind_name(c)
+ )
else:
if isinstance(value, elements.BindParameter) and \
- value.type._isnull:
+ value.type._isnull:
value = value._clone()
value.type = c.type
@@ -2076,7 +2087,7 @@ class SQLCompiler(Compiled):
self.returning.append(c)
value = self.process(value.self_group(), **kw)
elif implicit_return_defaults and \
- c in implicit_return_defaults:
+ c in implicit_return_defaults:
self.returning.append(c)
value = self.process(value.self_group(), **kw)
else:
@@ -2086,26 +2097,26 @@ class SQLCompiler(Compiled):
elif self.isinsert:
if c.primary_key and \
- need_pks and \
- (
- implicit_returning or
- not postfetch_lastrowid or
- c is not stmt.table._autoincrement_column
- ):
+ need_pks and \
+ (
+ implicit_returning or
+ not postfetch_lastrowid or
+ c is not stmt.table._autoincrement_column
+ ):
if implicit_returning:
if c.default is not None:
if c.default.is_sequence:
if self.dialect.supports_sequences and \
- (not c.default.optional or \
- not self.dialect.sequences_optional):
+ (not c.default.optional or
+ not self.dialect.sequences_optional):
proc = self.process(c.default, **kw)
values.append((c, proc))
self.returning.append(c)
elif c.default.is_clause_element:
values.append(
- (c,
- self.process(c.default.arg.self_group(), **kw))
+ (c, self.process(
+ c.default.arg.self_group(), **kw))
)
self.returning.append(c)
else:
@@ -2117,16 +2128,14 @@ class SQLCompiler(Compiled):
self.returning.append(c)
else:
if (
- c.default is not None and
- (
- not c.default.is_sequence or
- self.dialect.supports_sequences
- )
- ) or \
- c is stmt.table._autoincrement_column and (
- self.dialect.supports_sequences or
- self.dialect.preexecute_autoincrement_sequences
- ):
+ (c.default is not None and
+ (not c.default.is_sequence or
+ self.dialect.supports_sequences)) or
+ c is stmt.table._autoincrement_column and
+ (self.dialect.supports_sequences or
+ self.dialect.
+ preexecute_autoincrement_sequences)
+ ):
values.append(
(c, self._create_crud_bind_param(c, None))
@@ -2137,22 +2146,23 @@ class SQLCompiler(Compiled):
elif c.default is not None:
if c.default.is_sequence:
if self.dialect.supports_sequences and \
- (not c.default.optional or \
- not self.dialect.sequences_optional):
+ (not c.default.optional or
+ not self.dialect.sequences_optional):
proc = self.process(c.default, **kw)
values.append((c, proc))
if implicit_return_defaults and \
- c in implicit_return_defaults:
+ c in implicit_return_defaults:
self.returning.append(c)
elif not c.primary_key:
self.postfetch.append(c)
elif c.default.is_clause_element:
values.append(
- (c, self.process(c.default.arg.self_group(), **kw))
+ (c, self.process(
+ c.default.arg.self_group(), **kw))
)
if implicit_return_defaults and \
- c in implicit_return_defaults:
+ c in implicit_return_defaults:
self.returning.append(c)
elif not c.primary_key:
# don't add primary key column to postfetch
@@ -2164,22 +2174,23 @@ class SQLCompiler(Compiled):
self.prefetch.append(c)
elif c.server_default is not None:
if implicit_return_defaults and \
- c in implicit_return_defaults:
+ c in implicit_return_defaults:
self.returning.append(c)
elif not c.primary_key:
self.postfetch.append(c)
elif implicit_return_defaults and \
c in implicit_return_defaults:
- self.returning.append(c)
+ self.returning.append(c)
elif self.isupdate:
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, self.process(c.onupdate.arg.self_group(), **kw))
+ (c, self.process(
+ c.onupdate.arg.self_group(), **kw))
)
if implicit_return_defaults and \
- c in implicit_return_defaults:
+ c in implicit_return_defaults:
self.returning.append(c)
else:
self.postfetch.append(c)
@@ -2190,13 +2201,13 @@ class SQLCompiler(Compiled):
self.prefetch.append(c)
elif c.server_onupdate is not None:
if implicit_return_defaults and \
- c in implicit_return_defaults:
+ c in implicit_return_defaults:
self.returning.append(c)
else:
self.postfetch.append(c)
elif implicit_return_defaults and \
c in implicit_return_defaults:
- self.returning.append(c)
+ self.returning.append(c)
if parameters and stmt_parameters:
check = set(parameters).intersection(
@@ -2216,13 +2227,13 @@ class SQLCompiler(Compiled):
[
(
c,
- (self._create_crud_bind_param(
- c, row[c.key],
- name="%s_%d" % (c.key, i + 1)
- ) if elements._is_literal(row[c.key])
- else self.process(
- row[c.key].self_group(), **kw))
- if c.key in row else param
+ (self._create_crud_bind_param(
+ c, row[c.key],
+ name="%s_%d" % (c.key, i + 1)
+ ) if elements._is_literal(row[c.key])
+ else self.process(
+ row[c.key].self_group(), **kw))
+ if c.key in row else param
)
for (c, param) in values_0
]
@@ -2233,19 +2244,19 @@ class SQLCompiler(Compiled):
def visit_delete(self, delete_stmt, **kw):
self.stack.append({'correlate_froms': set([delete_stmt.table]),
- "iswrapper": False,
- "asfrom_froms": set([delete_stmt.table])})
+ "iswrapper": False,
+ "asfrom_froms": set([delete_stmt.table])})
self.isdelete = True
text = "DELETE "
if delete_stmt._prefixes:
text += self._generate_prefixes(delete_stmt,
- delete_stmt._prefixes, **kw)
+ delete_stmt._prefixes, **kw)
text += "FROM "
- table_text = delete_stmt.table._compiler_dispatch(self,
- asfrom=True, iscrud=True)
+ table_text = delete_stmt.table._compiler_dispatch(
+ self, asfrom=True, iscrud=True)
if delete_stmt._hints:
dialect_hints = dict([
@@ -2256,11 +2267,11 @@ class SQLCompiler(Compiled):
])
if delete_stmt.table in dialect_hints:
table_text = self.format_from_hint_text(
- table_text,
- delete_stmt.table,
- dialect_hints[delete_stmt.table],
- True
- )
+ table_text,
+ delete_stmt.table,
+ dialect_hints[delete_stmt.table],
+ True
+ )
else:
dialect_hints = None
@@ -2271,7 +2282,7 @@ class SQLCompiler(Compiled):
self.returning = delete_stmt._returning
if self.returning_precedes_values:
text += " " + self.returning_clause(
- delete_stmt, delete_stmt._returning)
+ delete_stmt, delete_stmt._returning)
if delete_stmt._whereclause is not None:
t = delete_stmt._whereclause._compiler_dispatch(self)
@@ -2280,7 +2291,7 @@ class SQLCompiler(Compiled):
if self.returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
- delete_stmt, delete_stmt._returning)
+ delete_stmt, delete_stmt._returning)
self.stack.pop(-1)
@@ -2291,11 +2302,11 @@ class SQLCompiler(Compiled):
def visit_rollback_to_savepoint(self, savepoint_stmt):
return "ROLLBACK TO SAVEPOINT %s" % \
- self.preparer.format_savepoint(savepoint_stmt)
+ self.preparer.format_savepoint(savepoint_stmt)
def visit_release_savepoint(self, savepoint_stmt):
return "RELEASE SAVEPOINT %s" % \
- self.preparer.format_savepoint(savepoint_stmt)
+ self.preparer.format_savepoint(savepoint_stmt)
class DDLCompiler(Compiled):
@@ -2349,11 +2360,11 @@ class DDLCompiler(Compiled):
table = create.element
preparer = self.dialect.identifier_preparer
- text = "\n" + " ".join(['CREATE'] + \
- table._prefixes + \
- ['TABLE',
- preparer.format_table(table),
- "("])
+ text = "\n" + " ".join(['CREATE'] +
+ table._prefixes +
+ ['TABLE',
+ preparer.format_table(table),
+ "("])
separator = "\n"
# if only one primary key, specify it along with the column
@@ -2362,8 +2373,8 @@ class DDLCompiler(Compiled):
column = create_column.element
try:
processed = self.process(create_column,
- first_pk=column.primary_key
- and not first_pk)
+ first_pk=column.primary_key
+ and not first_pk)
if processed is not None:
text += separator
separator = ", \n"
@@ -2372,11 +2383,10 @@ class DDLCompiler(Compiled):
first_pk = True
except exc.CompileError as ce:
util.raise_from_cause(
- exc.CompileError(util.u("(in table '%s', column '%s'): %s") % (
- table.description,
- column.name,
- ce.args[0]
- )))
+ exc.CompileError(
+ util.u("(in table '%s', column '%s'): %s") %
+ (table.description, column.name, ce.args[0])
+ ))
const = self.create_table_constraints(table)
if const:
@@ -2392,11 +2402,11 @@ class DDLCompiler(Compiled):
return None
text = self.get_column_specification(
- column,
- first_pk=first_pk
- )
- const = " ".join(self.process(constraint) \
- for constraint in column.constraints)
+ column,
+ first_pk=first_pk
+ )
+ const = " ".join(self.process(constraint)
+ for constraint in column.constraints)
if const:
text += " " + const
@@ -2411,19 +2421,19 @@ class DDLCompiler(Compiled):
constraints.append(table.primary_key)
constraints.extend([c for c in table._sorted_constraints
- if c is not table.primary_key])
+ if c is not table.primary_key])
return ", \n\t".join(p for p in
- (self.process(constraint)
- for constraint in constraints
- if (
- constraint._create_rule is None or
- constraint._create_rule(self))
- and (
- not self.dialect.supports_alter or
- not getattr(constraint, 'use_alter', False)
- )) if p is not None
- )
+ (self.process(constraint)
+ for constraint in constraints
+ if (
+ constraint._create_rule is None or
+ constraint._create_rule(self))
+ and (
+ not self.dialect.supports_alter or
+ not getattr(constraint, 'use_alter', False)
+ )) if p is not None
+ )
def visit_drop_table(self, drop):
return "\nDROP TABLE " + self.preparer.format_table(drop.element)
@@ -2431,15 +2441,13 @@ class DDLCompiler(Compiled):
def visit_drop_view(self, drop):
return "\nDROP VIEW " + self.preparer.format_table(drop.element)
-
def _verify_index_table(self, index):
if index.table is None:
raise exc.CompileError("Index '%s' is not associated "
- "with any table." % index.name)
-
+ "with any table." % index.name)
def visit_create_index(self, create, include_schema=False,
- include_table_schema=True):
+ include_table_schema=True):
index = create.element
self._verify_index_table(index)
preparer = self.preparer
@@ -2447,22 +2455,22 @@ class DDLCompiler(Compiled):
if index.unique:
text += "UNIQUE "
text += "INDEX %s ON %s (%s)" \
- % (
- self._prepared_index_name(index,
- include_schema=include_schema),
- preparer.format_table(index.table,
- use_schema=include_table_schema),
- ', '.join(
- self.sql_compiler.process(expr,
- include_table=False, literal_binds=True) for
- expr in index.expressions)
- )
+ % (
+ self._prepared_index_name(index,
+ include_schema=include_schema),
+ preparer.format_table(index.table,
+ use_schema=include_table_schema),
+ ', '.join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True) for
+ expr in index.expressions)
+ )
return text
def visit_drop_index(self, drop):
index = drop.element
- return "\nDROP INDEX " + self._prepared_index_name(index,
- include_schema=True)
+ return "\nDROP INDEX " + self._prepared_index_name(
+ index, include_schema=True)
def _prepared_index_name(self, index, include_schema=False):
if include_schema and index.table is not None and index.table.schema:
@@ -2474,10 +2482,10 @@ class DDLCompiler(Compiled):
ident = index.name
if isinstance(ident, elements._truncated_label):
max_ = self.dialect.max_index_name_length or \
- self.dialect.max_identifier_length
+ self.dialect.max_identifier_length
if len(ident) > max_:
ident = ident[0:max_ - 8] + \
- "_" + util.md5_hex(ident)[-4:]
+ "_" + util.md5_hex(ident)[-4:]
else:
self.dialect.validate_identifier(ident)
@@ -2495,7 +2503,7 @@ class DDLCompiler(Compiled):
def visit_create_sequence(self, create):
text = "CREATE SEQUENCE %s" % \
- self.preparer.format_sequence(create.element)
+ self.preparer.format_sequence(create.element)
if create.element.increment is not None:
text += " INCREMENT BY %d" % create.element.increment
if create.element.start is not None:
@@ -2504,7 +2512,7 @@ class DDLCompiler(Compiled):
def visit_drop_sequence(self, drop):
return "DROP SEQUENCE %s" % \
- self.preparer.format_sequence(drop.element)
+ self.preparer.format_sequence(drop.element)
def visit_drop_constraint(self, drop):
return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
@@ -2515,7 +2523,7 @@ class DDLCompiler(Compiled):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + \
- self.dialect.type_compiler.process(column.type)
+ self.dialect.type_compiler.process(column.type)
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -2543,8 +2551,8 @@ class DDLCompiler(Compiled):
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
- include_table=False,
- literal_binds=True)
+ include_table=False,
+ literal_binds=True)
text += self.define_constraint_deferrability(constraint)
return text
@@ -2568,7 +2576,7 @@ class DDLCompiler(Compiled):
text += "CONSTRAINT %s " % formatted_name
text += "PRIMARY KEY "
text += "(%s)" % ', '.join(self.preparer.quote(c.name)
- for c in constraint)
+ for c in constraint)
text += self.define_constraint_deferrability(constraint)
return text
@@ -2607,7 +2615,7 @@ class DDLCompiler(Compiled):
text += "CONSTRAINT %s " % formatted_name
text += "UNIQUE (%s)" % (
', '.join(self.preparer.quote(c.name)
- for c in constraint))
+ for c in constraint))
text += self.define_constraint_deferrability(constraint)
return text
@@ -2650,22 +2658,22 @@ class GenericTypeCompiler(TypeCompiler):
return "NUMERIC"
elif type_.scale is None:
return "NUMERIC(%(precision)s)" % \
- {'precision': type_.precision}
+ {'precision': type_.precision}
else:
return "NUMERIC(%(precision)s, %(scale)s)" % \
- {'precision': type_.precision,
- 'scale': type_.scale}
+ {'precision': type_.precision,
+ 'scale': type_.scale}
def visit_DECIMAL(self, type_):
if type_.precision is None:
return "DECIMAL"
elif type_.scale is None:
return "DECIMAL(%(precision)s)" % \
- {'precision': type_.precision}
+ {'precision': type_.precision}
else:
return "DECIMAL(%(precision)s, %(scale)s)" % \
- {'precision': type_.precision,
- 'scale': type_.scale}
+ {'precision': type_.precision,
+ 'scale': type_.scale}
def visit_INTEGER(self, type_):
return "INTEGER"
@@ -2780,8 +2788,8 @@ class GenericTypeCompiler(TypeCompiler):
def visit_null(self, type_):
raise exc.CompileError("Can't generate DDL for %r; "
- "did you forget to specify a "
- "type on this Column?" % type_)
+ "did you forget to specify a "
+ "type on this Column?" % type_)
def visit_type_decorator(self, type_):
return self.process(type_.type_engine(self.dialect))
@@ -2791,6 +2799,7 @@ class GenericTypeCompiler(TypeCompiler):
class IdentifierPreparer(object):
+
"""Handle quoting and case-folding of identifiers based on options."""
reserved_words = RESERVED_WORDS
@@ -2800,7 +2809,7 @@ class IdentifierPreparer(object):
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
def __init__(self, dialect, initial_quote='"',
- final_quote=None, escape_quote='"', omit_schema=False):
+ final_quote=None, escape_quote='"', omit_schema=False):
"""Construct a new ``IdentifierPreparer`` object.
initial_quote
@@ -2849,8 +2858,8 @@ class IdentifierPreparer(object):
"""
return self.initial_quote + \
- self._escape_identifier(value) + \
- self.final_quote
+ self._escape_identifier(value) + \
+ self.final_quote
def _requires_quotes(self, value):
"""Return True if the given identifier requires quoting."""
@@ -2895,7 +2904,8 @@ class IdentifierPreparer(object):
def format_sequence(self, sequence, use_schema=True):
name = self.quote(sequence.name)
- if not self.omit_schema and use_schema and sequence.schema is not None:
+ if (not self.omit_schema and use_schema and
+ sequence.schema is not None):
name = self.quote_schema(sequence.schema) + "." + name
return name
@@ -2912,7 +2922,7 @@ class IdentifierPreparer(object):
def format_constraint(self, naming, constraint):
if isinstance(constraint.name, elements._defer_name):
name = naming._constraint_name_for_table(
- constraint, constraint.table)
+ constraint, constraint.table)
if name:
return self.quote(name)
elif isinstance(constraint.name, elements._defer_none_name):
@@ -2926,7 +2936,7 @@ class IdentifierPreparer(object):
name = table.name
result = self.quote(name)
if not self.omit_schema and use_schema \
- and getattr(table, "schema", None):
+ and getattr(table, "schema", None):
result = self.quote_schema(table.schema) + "." + result
return result
@@ -2936,7 +2946,7 @@ class IdentifierPreparer(object):
return self.quote(name, quote)
def format_column(self, column, use_table=False,
- name=None, table_name=None):
+ name=None, table_name=None):
"""Prepare a quoted column name."""
if name is None:
@@ -2944,8 +2954,8 @@ class IdentifierPreparer(object):
if not getattr(column, 'is_literal', False):
if use_table:
return self.format_table(
- column.table, use_schema=False,
- name=table_name) + "." + self.quote(name)
+ column.table, use_schema=False,
+ name=table_name) + "." + self.quote(name)
else:
return self.quote(name)
else:
@@ -2953,8 +2963,9 @@ class IdentifierPreparer(object):
# which shouldn't get quoted
if use_table:
- return self.format_table(column.table,
- use_schema=False, name=table_name) + '.' + name
+ return self.format_table(
+ column.table, use_schema=False,
+ name=table_name) + '.' + name
else:
return name
@@ -2975,9 +2986,9 @@ class IdentifierPreparer(object):
@util.memoized_property
def _r_identifiers(self):
initial, final, escaped_final = \
- [re.escape(s) for s in
- (self.initial_quote, self.final_quote,
- self._escape_identifier(self.final_quote))]
+ [re.escape(s) for s in
+ (self.initial_quote, self.final_quote,
+ self._escape_identifier(self.final_quote))]
r = re.compile(
r'(?:'
r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'