summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py2030
1 files changed, 1201 insertions, 829 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 80ed707ed..f641d0a84 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -25,133 +25,218 @@ To generate user-defined SQL strings, see
import contextlib
import re
-from . import schema, sqltypes, operators, functions, visitors, \
- elements, selectable, crud
+from . import (
+ schema,
+ sqltypes,
+ operators,
+ functions,
+ visitors,
+ elements,
+ selectable,
+ crud,
+)
from .. import util, exc
import itertools
-RESERVED_WORDS = set([
- 'all', 'analyse', 'analyze', 'and', 'any', 'array',
- 'as', 'asc', 'asymmetric', 'authorization', 'between',
- 'binary', 'both', 'case', 'cast', 'check', 'collate',
- 'column', 'constraint', 'create', 'cross', 'current_date',
- 'current_role', 'current_time', 'current_timestamp',
- 'current_user', 'default', 'deferrable', 'desc',
- 'distinct', 'do', 'else', 'end', 'except', 'false',
- 'for', 'foreign', 'freeze', 'from', 'full', 'grant',
- 'group', 'having', 'ilike', 'in', 'initially', 'inner',
- 'intersect', 'into', 'is', 'isnull', 'join', 'leading',
- 'left', 'like', 'limit', 'localtime', 'localtimestamp',
- 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset',
- 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps',
- 'placing', 'primary', 'references', 'right', 'select',
- 'session_user', 'set', 'similar', 'some', 'symmetric', 'table',
- 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user',
- 'using', 'verbose', 'when', 'where'])
-
-LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I)
-ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(['$'])
-
-BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)
-BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]*)(?![:\w\$])', re.UNICODE)
+RESERVED_WORDS = set(
+ [
+ "all",
+ "analyse",
+ "analyze",
+ "and",
+ "any",
+ "array",
+ "as",
+ "asc",
+ "asymmetric",
+ "authorization",
+ "between",
+ "binary",
+ "both",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "constraint",
+ "create",
+ "cross",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "default",
+ "deferrable",
+ "desc",
+ "distinct",
+ "do",
+ "else",
+ "end",
+ "except",
+ "false",
+ "for",
+ "foreign",
+ "freeze",
+ "from",
+ "full",
+ "grant",
+ "group",
+ "having",
+ "ilike",
+ "in",
+ "initially",
+ "inner",
+ "intersect",
+ "into",
+ "is",
+ "isnull",
+ "join",
+ "leading",
+ "left",
+ "like",
+ "limit",
+ "localtime",
+ "localtimestamp",
+ "natural",
+ "new",
+ "not",
+ "notnull",
+ "null",
+ "off",
+ "offset",
+ "old",
+ "on",
+ "only",
+ "or",
+ "order",
+ "outer",
+ "overlaps",
+ "placing",
+ "primary",
+ "references",
+ "right",
+ "select",
+ "session_user",
+ "set",
+ "similar",
+ "some",
+ "symmetric",
+ "table",
+ "then",
+ "to",
+ "trailing",
+ "true",
+ "union",
+ "unique",
+ "user",
+ "using",
+ "verbose",
+ "when",
+ "where",
+ ]
+)
+
+LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
+ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])
+
+BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
+BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
BIND_TEMPLATES = {
- 'pyformat': "%%(%(name)s)s",
- 'qmark': "?",
- 'format': "%%s",
- 'numeric': ":[_POSITION]",
- 'named': ":%(name)s"
+ "pyformat": "%%(%(name)s)s",
+ "qmark": "?",
+ "format": "%%s",
+ "numeric": ":[_POSITION]",
+ "named": ":%(name)s",
}
OPERATORS = {
# binary
- operators.and_: ' AND ',
- operators.or_: ' OR ',
- operators.add: ' + ',
- operators.mul: ' * ',
- operators.sub: ' - ',
- operators.div: ' / ',
- operators.mod: ' % ',
- operators.truediv: ' / ',
- operators.neg: '-',
- operators.lt: ' < ',
- operators.le: ' <= ',
- operators.ne: ' != ',
- operators.gt: ' > ',
- operators.ge: ' >= ',
- operators.eq: ' = ',
- operators.is_distinct_from: ' IS DISTINCT FROM ',
- operators.isnot_distinct_from: ' IS NOT DISTINCT FROM ',
- operators.concat_op: ' || ',
- operators.match_op: ' MATCH ',
- operators.notmatch_op: ' NOT MATCH ',
- operators.in_op: ' IN ',
- operators.notin_op: ' NOT IN ',
- operators.comma_op: ', ',
- operators.from_: ' FROM ',
- operators.as_: ' AS ',
- operators.is_: ' IS ',
- operators.isnot: ' IS NOT ',
- operators.collate: ' COLLATE ',
-
+ operators.and_: " AND ",
+ operators.or_: " OR ",
+ operators.add: " + ",
+ operators.mul: " * ",
+ operators.sub: " - ",
+ operators.div: " / ",
+ operators.mod: " % ",
+ operators.truediv: " / ",
+ operators.neg: "-",
+ operators.lt: " < ",
+ operators.le: " <= ",
+ operators.ne: " != ",
+ operators.gt: " > ",
+ operators.ge: " >= ",
+ operators.eq: " = ",
+ operators.is_distinct_from: " IS DISTINCT FROM ",
+ operators.isnot_distinct_from: " IS NOT DISTINCT FROM ",
+ operators.concat_op: " || ",
+ operators.match_op: " MATCH ",
+ operators.notmatch_op: " NOT MATCH ",
+ operators.in_op: " IN ",
+ operators.notin_op: " NOT IN ",
+ operators.comma_op: ", ",
+ operators.from_: " FROM ",
+ operators.as_: " AS ",
+ operators.is_: " IS ",
+ operators.isnot: " IS NOT ",
+ operators.collate: " COLLATE ",
# unary
- operators.exists: 'EXISTS ',
- operators.distinct_op: 'DISTINCT ',
- operators.inv: 'NOT ',
- operators.any_op: 'ANY ',
- operators.all_op: 'ALL ',
-
+ operators.exists: "EXISTS ",
+ operators.distinct_op: "DISTINCT ",
+ operators.inv: "NOT ",
+ operators.any_op: "ANY ",
+ operators.all_op: "ALL ",
# modifiers
- operators.desc_op: ' DESC',
- operators.asc_op: ' ASC',
- operators.nullsfirst_op: ' NULLS FIRST',
- operators.nullslast_op: ' NULLS LAST',
-
+ operators.desc_op: " DESC",
+ operators.asc_op: " ASC",
+ operators.nullsfirst_op: " NULLS FIRST",
+ operators.nullslast_op: " NULLS LAST",
}
FUNCTIONS = {
- functions.coalesce: 'coalesce',
- functions.current_date: 'CURRENT_DATE',
- functions.current_time: 'CURRENT_TIME',
- functions.current_timestamp: 'CURRENT_TIMESTAMP',
- functions.current_user: 'CURRENT_USER',
- functions.localtime: 'LOCALTIME',
- functions.localtimestamp: 'LOCALTIMESTAMP',
- functions.random: 'random',
- functions.sysdate: 'sysdate',
- functions.session_user: 'SESSION_USER',
- functions.user: 'USER',
- functions.cube: 'CUBE',
- functions.rollup: 'ROLLUP',
- functions.grouping_sets: 'GROUPING SETS',
+ functions.coalesce: "coalesce",
+ functions.current_date: "CURRENT_DATE",
+ functions.current_time: "CURRENT_TIME",
+ functions.current_timestamp: "CURRENT_TIMESTAMP",
+ functions.current_user: "CURRENT_USER",
+ functions.localtime: "LOCALTIME",
+ functions.localtimestamp: "LOCALTIMESTAMP",
+ functions.random: "random",
+ functions.sysdate: "sysdate",
+ functions.session_user: "SESSION_USER",
+ functions.user: "USER",
+ functions.cube: "CUBE",
+ functions.rollup: "ROLLUP",
+ functions.grouping_sets: "GROUPING SETS",
}
EXTRACT_MAP = {
- 'month': 'month',
- 'day': 'day',
- 'year': 'year',
- 'second': 'second',
- 'hour': 'hour',
- 'doy': 'doy',
- 'minute': 'minute',
- 'quarter': 'quarter',
- 'dow': 'dow',
- 'week': 'week',
- 'epoch': 'epoch',
- 'milliseconds': 'milliseconds',
- 'microseconds': 'microseconds',
- 'timezone_hour': 'timezone_hour',
- 'timezone_minute': 'timezone_minute'
+ "month": "month",
+ "day": "day",
+ "year": "year",
+ "second": "second",
+ "hour": "hour",
+ "doy": "doy",
+ "minute": "minute",
+ "quarter": "quarter",
+ "dow": "dow",
+ "week": "week",
+ "epoch": "epoch",
+ "milliseconds": "milliseconds",
+ "microseconds": "microseconds",
+ "timezone_hour": "timezone_hour",
+ "timezone_minute": "timezone_minute",
}
COMPOUND_KEYWORDS = {
- selectable.CompoundSelect.UNION: 'UNION',
- selectable.CompoundSelect.UNION_ALL: 'UNION ALL',
- selectable.CompoundSelect.EXCEPT: 'EXCEPT',
- selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL',
- selectable.CompoundSelect.INTERSECT: 'INTERSECT',
- selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL'
+ selectable.CompoundSelect.UNION: "UNION",
+ selectable.CompoundSelect.UNION_ALL: "UNION ALL",
+ selectable.CompoundSelect.EXCEPT: "EXCEPT",
+ selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL",
+ selectable.CompoundSelect.INTERSECT: "INTERSECT",
+ selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL",
}
@@ -177,9 +262,14 @@ class Compiled(object):
sub-elements of the statement can modify these.
"""
- def __init__(self, dialect, statement, bind=None,
- schema_translate_map=None,
- compile_kwargs=util.immutabledict()):
+ def __init__(
+ self,
+ dialect,
+ statement,
+ bind=None,
+ schema_translate_map=None,
+ compile_kwargs=util.immutabledict(),
+ ):
"""Construct a new :class:`.Compiled` object.
:param dialect: :class:`.Dialect` to compile against.
@@ -209,7 +299,8 @@ class Compiled(object):
self.preparer = self.dialect.identifier_preparer
if schema_translate_map:
self.preparer = self.preparer._with_schema_translate(
- schema_translate_map)
+ schema_translate_map
+ )
if statement is not None:
self.statement = statement
@@ -218,8 +309,10 @@ class Compiled(object):
self.execution_options = statement._execution_options
self.string = self.process(self.statement, **compile_kwargs)
- @util.deprecated("0.7", ":class:`.Compiled` objects now compile "
- "within the constructor.")
+ @util.deprecated(
+ "0.7",
+ ":class:`.Compiled` objects now compile " "within the constructor.",
+ )
def compile(self):
"""Produce the internal string representation of this element.
"""
@@ -247,7 +340,7 @@ class Compiled(object):
def __str__(self):
"""Return the string text of the generated SQL or DDL."""
- return self.string or ''
+ return self.string or ""
def construct_params(self, params=None):
"""Return the bind params for this compiled object.
@@ -271,7 +364,9 @@ class Compiled(object):
if e is None:
raise exc.UnboundExecutionError(
"This Compiled object is not bound to any Engine "
- "or Connection.", code="2afi")
+ "or Connection.",
+ code="2afi",
+ )
return e._execute_compiled(self, multiparams, params)
def scalar(self, *multiparams, **params):
@@ -284,7 +379,7 @@ class Compiled(object):
class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
"""Produces DDL specification for TypeEngine objects."""
- ensure_kwarg = r'visit_\w+'
+ ensure_kwarg = r"visit_\w+"
def __init__(self, dialect):
self.dialect = dialect
@@ -297,8 +392,8 @@ class _CompileLabel(visitors.Visitable):
"""lightweight label object which acts as an expression.Label."""
- __visit_name__ = 'label'
- __slots__ = 'element', 'name'
+ __visit_name__ = "label"
+ __slots__ = "element", "name"
def __init__(self, col, name, alt_names=()):
self.element = col
@@ -390,8 +485,9 @@ class SQLCompiler(Compiled):
insert_prefetch = update_prefetch = ()
- def __init__(self, dialect, statement, column_keys=None,
- inline=False, **kwargs):
+ def __init__(
+ self, dialect, statement, column_keys=None, inline=False, **kwargs
+ ):
"""Construct a new :class:`.SQLCompiler` object.
:param dialect: :class:`.Dialect` to be used
@@ -412,7 +508,7 @@ class SQLCompiler(Compiled):
# compile INSERT/UPDATE defaults/sequences inlined (no pre-
# execute)
- self.inline = inline or getattr(statement, 'inline', False)
+ self.inline = inline or getattr(statement, "inline", False)
# a dictionary of bind parameter keys to BindParameter
# instances.
@@ -440,8 +536,9 @@ class SQLCompiler(Compiled):
self.ctes = None
- self.label_length = dialect.label_length \
- or dialect.max_identifier_length
+ self.label_length = (
+ dialect.label_length or dialect.max_identifier_length
+ )
# a map which tracks "anonymous" identifiers that are created on
# the fly here
@@ -453,7 +550,7 @@ class SQLCompiler(Compiled):
Compiled.__init__(self, dialect, statement, **kwargs)
if (
- self.isinsert or self.isupdate or self.isdelete
+ self.isinsert or self.isupdate or self.isdelete
) and statement._returning:
self.returning = statement._returning
@@ -482,37 +579,43 @@ class SQLCompiler(Compiled):
def _nested_result(self):
"""special API to support the use case of 'nested result sets'"""
result_columns, ordered_columns = (
- self._result_columns, self._ordered_columns)
+ self._result_columns,
+ self._ordered_columns,
+ )
self._result_columns, self._ordered_columns = [], False
try:
if self.stack:
entry = self.stack[-1]
- entry['need_result_map_for_nested'] = True
+ entry["need_result_map_for_nested"] = True
else:
entry = None
yield self._result_columns, self._ordered_columns
finally:
if entry:
- entry.pop('need_result_map_for_nested')
+ entry.pop("need_result_map_for_nested")
self._result_columns, self._ordered_columns = (
- result_columns, ordered_columns)
+ result_columns,
+ ordered_columns,
+ )
def _apply_numbered_params(self):
poscount = itertools.count(1)
self.string = re.sub(
- r'\[_POSITION\]',
- lambda m: str(util.next(poscount)),
- self.string)
+ r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string
+ )
@util.memoized_property
def _bind_processors(self):
return dict(
- (key, value) for key, value in
- ((self.bind_names[bindparam],
- bindparam.type._cached_bind_processor(self.dialect)
- )
- for bindparam in self.bind_names)
+ (key, value)
+ for key, value in (
+ (
+ self.bind_names[bindparam],
+ bindparam.type._cached_bind_processor(self.dialect),
+ )
+ for bindparam in self.bind_names
+ )
if value is not None
)
@@ -539,12 +642,16 @@ class SQLCompiler(Compiled):
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
- "in parameter group %d" %
- (bindparam.key, _group_number), code="cd3x")
+ "in parameter group %d"
+ % (bindparam.key, _group_number),
+ code="cd3x",
+ )
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
- % bindparam.key, code="cd3x")
+ % bindparam.key,
+ code="cd3x",
+ )
elif bindparam.callable:
pd[name] = bindparam.effective_value
@@ -558,12 +665,16 @@ class SQLCompiler(Compiled):
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
- "in parameter group %d" %
- (bindparam.key, _group_number), code="cd3x")
+ "in parameter group %d"
+ % (bindparam.key, _group_number),
+ code="cd3x",
+ )
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
- % bindparam.key, code="cd3x")
+ % bindparam.key,
+ code="cd3x",
+ )
if bindparam.callable:
pd[self.bind_names[bindparam]] = bindparam.effective_value
@@ -595,9 +706,10 @@ class SQLCompiler(Compiled):
return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
def visit_label_reference(
- self, element, within_columns_clause=False, **kwargs):
+ self, element, within_columns_clause=False, **kwargs
+ ):
if self.stack and self.dialect.supports_simple_order_by_label:
- selectable = self.stack[-1]['selectable']
+ selectable = self.stack[-1]["selectable"]
with_cols, only_froms, only_cols = selectable._label_resolve_dict
if within_columns_clause:
@@ -611,25 +723,30 @@ class SQLCompiler(Compiled):
# to something else like a ColumnClause expression.
order_by_elem = element.element._order_by_label_element
- if order_by_elem is not None and order_by_elem.name in \
- resolve_dict and \
- order_by_elem.shares_lineage(
- resolve_dict[order_by_elem.name]):
- kwargs['render_label_as_label'] = \
- element.element._order_by_label_element
+ if (
+ order_by_elem is not None
+ and order_by_elem.name in resolve_dict
+ and order_by_elem.shares_lineage(
+ resolve_dict[order_by_elem.name]
+ )
+ ):
+ kwargs[
+ "render_label_as_label"
+ ] = element.element._order_by_label_element
return self.process(
- element.element, within_columns_clause=within_columns_clause,
- **kwargs)
+ element.element,
+ within_columns_clause=within_columns_clause,
+ **kwargs
+ )
def visit_textual_label_reference(
- self, element, within_columns_clause=False, **kwargs):
+ self, element, within_columns_clause=False, **kwargs
+ ):
if not self.stack:
# compiling the element outside of the context of a SELECT
- return self.process(
- element._text_clause
- )
+ return self.process(element._text_clause)
- selectable = self.stack[-1]['selectable']
+ selectable = self.stack[-1]["selectable"]
with_cols, only_froms, only_cols = selectable._label_resolve_dict
try:
if within_columns_clause:
@@ -640,26 +757,30 @@ class SQLCompiler(Compiled):
# treat it like text()
util.warn_limited(
"Can't resolve label reference %r; converting to text()",
- util.ellipses_string(element.element))
- return self.process(
- element._text_clause
+ util.ellipses_string(element.element),
)
+ return self.process(element._text_clause)
else:
- kwargs['render_label_as_label'] = col
+ kwargs["render_label_as_label"] = col
return self.process(
- col, within_columns_clause=within_columns_clause, **kwargs)
-
- def visit_label(self, label,
- add_to_result_map=None,
- within_label_clause=False,
- within_columns_clause=False,
- render_label_as_label=None,
- **kw):
+ col, within_columns_clause=within_columns_clause, **kwargs
+ )
+
+ def visit_label(
+ self,
+ label,
+ add_to_result_map=None,
+ within_label_clause=False,
+ within_columns_clause=False,
+ render_label_as_label=None,
+ **kw
+ ):
# only render labels within the columns clause
# or ORDER BY clause of a select. dialect-specific compilers
# can modify this behavior.
- render_label_with_as = (within_columns_clause and not
- within_label_clause)
+ render_label_with_as = (
+ within_columns_clause and not within_label_clause
+ )
render_label_only = render_label_as_label is label
if render_label_only or render_label_with_as:
@@ -673,27 +794,35 @@ class SQLCompiler(Compiled):
add_to_result_map(
labelname,
label.name,
- (label, labelname, ) + label._alt_names,
- label.type
+ (label, labelname) + label._alt_names,
+ label.type,
)
- return label.element._compiler_dispatch(
- self, within_columns_clause=True,
- within_label_clause=True, **kw) + \
- OPERATORS[operators.as_] + \
- self.preparer.format_label(label, labelname)
+ return (
+ label.element._compiler_dispatch(
+ self,
+ within_columns_clause=True,
+ within_label_clause=True,
+ **kw
+ )
+ + OPERATORS[operators.as_]
+ + self.preparer.format_label(label, labelname)
+ )
elif render_label_only:
return self.preparer.format_label(label, labelname)
else:
return label.element._compiler_dispatch(
- self, within_columns_clause=False, **kw)
+ self, within_columns_clause=False, **kw
+ )
def _fallback_column_name(self, column):
- raise exc.CompileError("Cannot compile Column object until "
- "its 'name' is assigned.")
+ raise exc.CompileError(
+ "Cannot compile Column object until " "its 'name' is assigned."
+ )
- def visit_column(self, column, add_to_result_map=None,
- include_table=True, **kwargs):
+ def visit_column(
+ self, column, add_to_result_map=None, include_table=True, **kwargs
+ ):
name = orig_name = column.name
if name is None:
name = self._fallback_column_name(column)
@@ -704,10 +833,7 @@ class SQLCompiler(Compiled):
if add_to_result_map is not None:
add_to_result_map(
- name,
- orig_name,
- (column, name, column.key),
- column.type
+ name, orig_name, (column, name, column.key), column.type
)
if is_literal:
@@ -721,17 +847,16 @@ class SQLCompiler(Compiled):
effective_schema = self.preparer.schema_for_object(table)
if effective_schema:
- schema_prefix = self.preparer.quote_schema(
- effective_schema) + '.'
+ schema_prefix = (
+ self.preparer.quote_schema(effective_schema) + "."
+ )
else:
- schema_prefix = ''
+ schema_prefix = ""
tablename = table.name
if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
- return schema_prefix + \
- self.preparer.quote(tablename) + \
- "." + name
+ return schema_prefix + self.preparer.quote(tablename) + "." + name
def visit_collation(self, element, **kw):
return self.preparer.format_collation(element.collation)
@@ -743,17 +868,17 @@ class SQLCompiler(Compiled):
return index.name
def visit_typeclause(self, typeclause, **kw):
- kw['type_expression'] = typeclause
+ kw["type_expression"] = typeclause
return self.dialect.type_compiler.process(typeclause.type, **kw)
def post_process_text(self, text):
if self.preparer._double_percents:
- text = text.replace('%', '%%')
+ text = text.replace("%", "%%")
return text
def escape_literal_column(self, text):
if self.preparer._double_percents:
- text = text.replace('%', '%%')
+ text = text.replace("%", "%%")
return text
def visit_textclause(self, textclause, **kw):
@@ -771,30 +896,36 @@ class SQLCompiler(Compiled):
return BIND_PARAMS_ESC.sub(
lambda m: m.group(1),
BIND_PARAMS.sub(
- do_bindparam,
- self.post_process_text(textclause.text))
+ do_bindparam, self.post_process_text(textclause.text)
+ ),
)
- def visit_text_as_from(self, taf,
- compound_index=None,
- asfrom=False,
- parens=True, **kw):
+ def visit_text_as_from(
+ self, taf, compound_index=None, asfrom=False, parens=True, **kw
+ ):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- populate_result_map = toplevel or \
- (
- compound_index == 0 and entry.get(
- 'need_result_map_for_compound', False)
- ) or entry.get('need_result_map_for_nested', False)
+ populate_result_map = (
+ toplevel
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
if populate_result_map:
- self._ordered_columns = \
- self._textual_ordered_columns = taf.positional
+ self._ordered_columns = (
+ self._textual_ordered_columns
+ ) = taf.positional
for c in taf.column_args:
- self.process(c, within_columns_clause=True,
- add_to_result_map=self._add_to_result_map)
+ self.process(
+ c,
+ within_columns_clause=True,
+ add_to_result_map=self._add_to_result_map,
+ )
text = self.process(taf.element, **kw)
if asfrom and parens:
@@ -802,17 +933,17 @@ class SQLCompiler(Compiled):
return text
def visit_null(self, expr, **kw):
- return 'NULL'
+ return "NULL"
def visit_true(self, expr, **kw):
if self.dialect.supports_native_boolean:
- return 'true'
+ return "true"
else:
return "1"
def visit_false(self, expr, **kw):
if self.dialect.supports_native_boolean:
- return 'false'
+ return "false"
else:
return "0"
@@ -823,25 +954,29 @@ class SQLCompiler(Compiled):
else:
sep = OPERATORS[clauselist.operator]
return sep.join(
- s for s in
- (
- c._compiler_dispatch(self, **kw)
- for c in clauselist.clauses)
- if s)
+ s
+ for s in (
+ c._compiler_dispatch(self, **kw) for c in clauselist.clauses
+ )
+ if s
+ )
def visit_case(self, clause, **kwargs):
x = "CASE "
if clause.value is not None:
x += clause.value._compiler_dispatch(self, **kwargs) + " "
for cond, result in clause.whens:
- x += "WHEN " + cond._compiler_dispatch(
- self, **kwargs
- ) + " THEN " + result._compiler_dispatch(
- self, **kwargs) + " "
+ x += (
+ "WHEN "
+ + cond._compiler_dispatch(self, **kwargs)
+ + " THEN "
+ + result._compiler_dispatch(self, **kwargs)
+ + " "
+ )
if clause.else_ is not None:
- x += "ELSE " + clause.else_._compiler_dispatch(
- self, **kwargs
- ) + " "
+ x += (
+ "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " "
+ )
x += "END"
return x
@@ -849,79 +984,84 @@ class SQLCompiler(Compiled):
return type_coerce.typed_expression._compiler_dispatch(self, **kw)
def visit_cast(self, cast, **kwargs):
- return "CAST(%s AS %s)" % \
- (cast.clause._compiler_dispatch(self, **kwargs),
- cast.typeclause._compiler_dispatch(self, **kwargs))
+ return "CAST(%s AS %s)" % (
+ cast.clause._compiler_dispatch(self, **kwargs),
+ cast.typeclause._compiler_dispatch(self, **kwargs),
+ )
def _format_frame_clause(self, range_, **kw):
- return '%s AND %s' % (
+ return "%s AND %s" % (
"UNBOUNDED PRECEDING"
if range_[0] is elements.RANGE_UNBOUNDED
- else "CURRENT ROW" if range_[0] is elements.RANGE_CURRENT
- else "%s PRECEDING" % (
- self.process(elements.literal(abs(range_[0])), **kw), )
+ else "CURRENT ROW"
+ if range_[0] is elements.RANGE_CURRENT
+ else "%s PRECEDING"
+ % (self.process(elements.literal(abs(range_[0])), **kw),)
if range_[0] < 0
- else "%s FOLLOWING" % (
- self.process(elements.literal(range_[0]), **kw), ),
-
+ else "%s FOLLOWING"
+ % (self.process(elements.literal(range_[0]), **kw),),
"UNBOUNDED FOLLOWING"
if range_[1] is elements.RANGE_UNBOUNDED
- else "CURRENT ROW" if range_[1] is elements.RANGE_CURRENT
- else "%s PRECEDING" % (
- self.process(elements.literal(abs(range_[1])), **kw), )
+ else "CURRENT ROW"
+ if range_[1] is elements.RANGE_CURRENT
+ else "%s PRECEDING"
+ % (self.process(elements.literal(abs(range_[1])), **kw),)
if range_[1] < 0
- else "%s FOLLOWING" % (
- self.process(elements.literal(range_[1]), **kw), ),
+ else "%s FOLLOWING"
+ % (self.process(elements.literal(range_[1]), **kw),),
)
def visit_over(self, over, **kwargs):
if over.range_:
range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
- over.range_, **kwargs)
+ over.range_, **kwargs
+ )
elif over.rows:
range_ = "ROWS BETWEEN %s" % self._format_frame_clause(
- over.rows, **kwargs)
+ over.rows, **kwargs
+ )
else:
range_ = None
return "%s OVER (%s)" % (
over.element._compiler_dispatch(self, **kwargs),
- ' '.join([
- '%s BY %s' % (
- word, clause._compiler_dispatch(self, **kwargs)
- )
- for word, clause in (
- ('PARTITION', over.partition_by),
- ('ORDER', over.order_by)
- )
- if clause is not None and len(clause)
- ] + ([range_] if range_ else [])
- )
+ " ".join(
+ [
+ "%s BY %s"
+ % (word, clause._compiler_dispatch(self, **kwargs))
+ for word, clause in (
+ ("PARTITION", over.partition_by),
+ ("ORDER", over.order_by),
+ )
+ if clause is not None and len(clause)
+ ]
+ + ([range_] if range_ else [])
+ ),
)
def visit_withingroup(self, withingroup, **kwargs):
return "%s WITHIN GROUP (ORDER BY %s)" % (
withingroup.element._compiler_dispatch(self, **kwargs),
- withingroup.order_by._compiler_dispatch(self, **kwargs)
+ withingroup.order_by._compiler_dispatch(self, **kwargs),
)
def visit_funcfilter(self, funcfilter, **kwargs):
return "%s FILTER (WHERE %s)" % (
funcfilter.func._compiler_dispatch(self, **kwargs),
- funcfilter.criterion._compiler_dispatch(self, **kwargs)
+ funcfilter.criterion._compiler_dispatch(self, **kwargs),
)
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
return "EXTRACT(%s FROM %s)" % (
- field, extract.expr._compiler_dispatch(self, **kwargs))
+ field,
+ extract.expr._compiler_dispatch(self, **kwargs),
+ )
def visit_function(self, func, add_to_result_map=None, **kwargs):
if add_to_result_map is not None:
- add_to_result_map(
- func.name, func.name, (), func.type
- )
+ add_to_result_map(func.name, func.name, (), func.type)
disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
if disp:
@@ -933,51 +1073,63 @@ class SQLCompiler(Compiled):
name += "%(expr)s"
else:
name = func.name + "%(expr)s"
- return ".".join(list(func.packagenames) + [name]) % \
- {'expr': self.function_argspec(func, **kwargs)}
+ return ".".join(list(func.packagenames) + [name]) % {
+ "expr": self.function_argspec(func, **kwargs)
+ }
def visit_next_value_func(self, next_value, **kw):
return self.visit_sequence(next_value.sequence)
def visit_sequence(self, sequence, **kw):
raise NotImplementedError(
- "Dialect '%s' does not support sequence increments." %
- self.dialect.name
+ "Dialect '%s' does not support sequence increments."
+ % self.dialect.name
)
def function_argspec(self, func, **kwargs):
return func.clause_expr._compiler_dispatch(self, **kwargs)
- def visit_compound_select(self, cs, asfrom=False,
- parens=True, compound_index=0, **kwargs):
+ def visit_compound_select(
+ self, cs, asfrom=False, parens=True, compound_index=0, **kwargs
+ ):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- need_result_map = toplevel or \
- (compound_index == 0
- and entry.get('need_result_map_for_compound', False))
+ need_result_map = toplevel or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
self.stack.append(
{
- 'correlate_froms': entry['correlate_froms'],
- 'asfrom_froms': entry['asfrom_froms'],
- 'selectable': cs,
- 'need_result_map_for_compound': need_result_map
- })
+ "correlate_froms": entry["correlate_froms"],
+ "asfrom_froms": entry["asfrom_froms"],
+ "selectable": cs,
+ "need_result_map_for_compound": need_result_map,
+ }
+ )
keyword = self.compound_keywords.get(cs.keyword)
text = (" " + keyword + " ").join(
- (c._compiler_dispatch(self,
- asfrom=asfrom, parens=False,
- compound_index=i, **kwargs)
- for i, c in enumerate(cs.selects))
+ (
+ c._compiler_dispatch(
+ self,
+ asfrom=asfrom,
+ parens=False,
+ compound_index=i,
+ **kwargs
+ )
+ for i, c in enumerate(cs.selects)
+ )
)
text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
text += self.order_by_clause(cs, **kwargs)
- text += (cs._limit_clause is not None
- or cs._offset_clause is not None) and \
- self.limit_clause(cs, **kwargs) or ""
+ text += (
+ (cs._limit_clause is not None or cs._offset_clause is not None)
+ and self.limit_clause(cs, **kwargs)
+ or ""
+ )
if self.ctes and toplevel:
text = self._render_cte_clause() + text
@@ -990,8 +1142,10 @@ class SQLCompiler(Compiled):
def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
attrname = "visit_%s_%s%s" % (
- operator_.__name__, qualifier1,
- "_" + qualifier2 if qualifier2 else "")
+ operator_.__name__,
+ qualifier1,
+ "_" + qualifier2 if qualifier2 else "",
+ )
return getattr(self, attrname, None)
def visit_unary(self, unary, **kw):
@@ -999,51 +1153,63 @@ class SQLCompiler(Compiled):
if unary.modifier:
raise exc.CompileError(
"Unary expression does not support operator "
- "and modifier simultaneously")
+ "and modifier simultaneously"
+ )
disp = self._get_operator_dispatch(
- unary.operator, "unary", "operator")
+ unary.operator, "unary", "operator"
+ )
if disp:
return disp(unary, unary.operator, **kw)
else:
return self._generate_generic_unary_operator(
- unary, OPERATORS[unary.operator], **kw)
+ unary, OPERATORS[unary.operator], **kw
+ )
elif unary.modifier:
disp = self._get_operator_dispatch(
- unary.modifier, "unary", "modifier")
+ unary.modifier, "unary", "modifier"
+ )
if disp:
return disp(unary, unary.modifier, **kw)
else:
return self._generate_generic_unary_modifier(
- unary, OPERATORS[unary.modifier], **kw)
+ unary, OPERATORS[unary.modifier], **kw
+ )
else:
raise exc.CompileError(
- "Unary expression has no operator or modifier")
+ "Unary expression has no operator or modifier"
+ )
def visit_istrue_unary_operator(self, element, operator, **kw):
- if element._is_implicitly_boolean or \
- self.dialect.supports_native_boolean:
+ if (
+ element._is_implicitly_boolean
+ or self.dialect.supports_native_boolean
+ ):
return self.process(element.element, **kw)
else:
return "%s = 1" % self.process(element.element, **kw)
def visit_isfalse_unary_operator(self, element, operator, **kw):
- if element._is_implicitly_boolean or \
- self.dialect.supports_native_boolean:
+ if (
+ element._is_implicitly_boolean
+ or self.dialect.supports_native_boolean
+ ):
return "NOT %s" % self.process(element.element, **kw)
else:
return "%s = 0" % self.process(element.element, **kw)
def visit_notmatch_op_binary(self, binary, operator, **kw):
return "NOT %s" % self.visit_binary(
- binary, override_operator=operators.match_op)
+ binary, override_operator=operators.match_op
+ )
def _emit_empty_in_warning(self):
util.warn(
- 'The IN-predicate was invoked with an '
- 'empty sequence. This results in a '
- 'contradiction, which nonetheless can be '
- 'expensive to evaluate. Consider alternative '
- 'strategies for improved performance.')
+ "The IN-predicate was invoked with an "
+ "empty sequence. This results in a "
+ "contradiction, which nonetheless can be "
+ "expensive to evaluate. Consider alternative "
+ "strategies for improved performance."
+ )
def visit_empty_in_op_binary(self, binary, operator, **kw):
if self.dialect._use_static_in:
@@ -1063,18 +1229,21 @@ class SQLCompiler(Compiled):
def visit_empty_set_expr(self, element_types):
raise NotImplementedError(
- "Dialect '%s' does not support empty set expression." %
- self.dialect.name
+ "Dialect '%s' does not support empty set expression."
+ % self.dialect.name
)
- def visit_binary(self, binary, override_operator=None,
- eager_grouping=False, **kw):
+ def visit_binary(
+ self, binary, override_operator=None, eager_grouping=False, **kw
+ ):
# don't allow "? = ?" to render
- if self.ansi_bind_rules and \
- isinstance(binary.left, elements.BindParameter) and \
- isinstance(binary.right, elements.BindParameter):
- kw['literal_binds'] = True
+ if (
+ self.ansi_bind_rules
+ and isinstance(binary.left, elements.BindParameter)
+ and isinstance(binary.right, elements.BindParameter)
+ ):
+ kw["literal_binds"] = True
operator_ = override_operator or binary.operator
disp = self._get_operator_dispatch(operator_, "binary", None)
@@ -1093,36 +1262,50 @@ class SQLCompiler(Compiled):
def visit_mod_binary(self, binary, operator, **kw):
if self.preparer._double_percents:
- return self.process(binary.left, **kw) + " %% " + \
- self.process(binary.right, **kw)
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
else:
- return self.process(binary.left, **kw) + " % " + \
- self.process(binary.right, **kw)
+ return (
+ self.process(binary.left, **kw)
+ + " % "
+ + self.process(binary.right, **kw)
+ )
def visit_custom_op_binary(self, element, operator, **kw):
- kw['eager_grouping'] = operator.eager_grouping
+ kw["eager_grouping"] = operator.eager_grouping
return self._generate_generic_binary(
- element, " " + operator.opstring + " ", **kw)
+ element, " " + operator.opstring + " ", **kw
+ )
def visit_custom_op_unary_operator(self, element, operator, **kw):
return self._generate_generic_unary_operator(
- element, operator.opstring + " ", **kw)
+ element, operator.opstring + " ", **kw
+ )
def visit_custom_op_unary_modifier(self, element, operator, **kw):
return self._generate_generic_unary_modifier(
- element, " " + operator.opstring, **kw)
+ element, " " + operator.opstring, **kw
+ )
def _generate_generic_binary(
- self, binary, opstring, eager_grouping=False, **kw):
+ self, binary, opstring, eager_grouping=False, **kw
+ ):
- _in_binary = kw.get('_in_binary', False)
+ _in_binary = kw.get("_in_binary", False)
- kw['_in_binary'] = True
- text = binary.left._compiler_dispatch(
- self, eager_grouping=eager_grouping, **kw) + \
- opstring + \
- binary.right._compiler_dispatch(
- self, eager_grouping=eager_grouping, **kw)
+ kw["_in_binary"] = True
+ text = (
+ binary.left._compiler_dispatch(
+ self, eager_grouping=eager_grouping, **kw
+ )
+ + opstring
+ + binary.right._compiler_dispatch(
+ self, eager_grouping=eager_grouping, **kw
+ )
+ )
if _in_binary and eager_grouping:
text = "(%s)" % text
@@ -1153,17 +1336,13 @@ class SQLCompiler(Compiled):
def visit_startswith_op_binary(self, binary, operator, **kw):
binary = binary._clone()
percent = self._like_percent_literal
- binary.right = percent.__radd__(
- binary.right
- )
+ binary.right = percent.__radd__(binary.right)
return self.visit_like_op_binary(binary, operator, **kw)
def visit_notstartswith_op_binary(self, binary, operator, **kw):
binary = binary._clone()
percent = self._like_percent_literal
- binary.right = percent.__radd__(
- binary.right
- )
+ binary.right = percent.__radd__(binary.right)
return self.visit_notlike_op_binary(binary, operator, **kw)
def visit_endswith_op_binary(self, binary, operator, **kw):
@@ -1182,98 +1361,105 @@ class SQLCompiler(Compiled):
escape = binary.modifiers.get("escape", None)
# TODO: use ternary here, not "and"/ "or"
- return '%s LIKE %s' % (
+ return "%s LIKE %s" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_notlike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return '%s NOT LIKE %s' % (
+ return "%s NOT LIKE %s" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_ilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return 'lower(%s) LIKE lower(%s)' % (
+ return "lower(%s) LIKE lower(%s)" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_notilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return 'lower(%s) NOT LIKE lower(%s)' % (
+ return "lower(%s) NOT LIKE lower(%s)" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_between_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
return self._generate_generic_binary(
- binary, " BETWEEN SYMMETRIC "
- if symmetric else " BETWEEN ", **kw)
+ binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw
+ )
def visit_notbetween_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
return self._generate_generic_binary(
- binary, " NOT BETWEEN SYMMETRIC "
- if symmetric else " NOT BETWEEN ", **kw)
+ binary,
+ " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ",
+ **kw
+ )
- def visit_bindparam(self, bindparam, within_columns_clause=False,
- literal_binds=False,
- skip_bind_expression=False,
- **kwargs):
+ def visit_bindparam(
+ self,
+ bindparam,
+ within_columns_clause=False,
+ literal_binds=False,
+ skip_bind_expression=False,
+ **kwargs
+ ):
if not skip_bind_expression:
impl = bindparam.type.dialect_impl(self.dialect)
if impl._has_bind_expression:
bind_expression = impl.bind_expression(bindparam)
return self.process(
- bind_expression, skip_bind_expression=True,
+ bind_expression,
+ skip_bind_expression=True,
within_columns_clause=within_columns_clause,
literal_binds=literal_binds,
**kwargs
)
- if literal_binds or \
- (within_columns_clause and
- self.ansi_bind_rules):
+ if literal_binds or (within_columns_clause and self.ansi_bind_rules):
if bindparam.value is None and bindparam.callable is None:
- raise exc.CompileError("Bind parameter '%s' without a "
- "renderable value not allowed here."
- % bindparam.key)
+ raise exc.CompileError(
+ "Bind parameter '%s' without a "
+ "renderable value not allowed here." % bindparam.key
+ )
return self.render_literal_bindparam(
- bindparam, within_columns_clause=True, **kwargs)
+ bindparam, within_columns_clause=True, **kwargs
+ )
name = self._truncate_bindparam(bindparam)
if name in self.binds:
existing = self.binds[name]
if existing is not bindparam:
- if (existing.unique or bindparam.unique) and \
- not existing.proxy_set.intersection(
- bindparam.proxy_set):
+ if (
+ existing.unique or bindparam.unique
+ ) and not existing.proxy_set.intersection(bindparam.proxy_set):
raise exc.CompileError(
"Bind parameter '%s' conflicts with "
- "unique bind parameter of the same name" %
- bindparam.key
+ "unique bind parameter of the same name"
+ % bindparam.key
)
elif existing._is_crud or bindparam._is_crud:
raise exc.CompileError(
@@ -1282,14 +1468,15 @@ class SQLCompiler(Compiled):
"clause of this "
"insert/update statement. Please use a "
"name other than column name when using bindparam() "
- "with insert() or update() (for example, 'b_%s')." %
- (bindparam.key, bindparam.key)
+ "with insert() or update() (for example, 'b_%s')."
+ % (bindparam.key, bindparam.key)
)
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(
- name, expanding=bindparam.expanding, **kwargs)
+ name, expanding=bindparam.expanding, **kwargs
+ )
def render_literal_bindparam(self, bindparam, **kw):
value = bindparam.effective_value
@@ -1311,7 +1498,8 @@ class SQLCompiler(Compiled):
return processor(value)
else:
raise NotImplementedError(
- "Don't know how to literal-quote value %r" % value)
+ "Don't know how to literal-quote value %r" % value
+ )
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
@@ -1334,8 +1522,11 @@ class SQLCompiler(Compiled):
if len(anonname) > self.label_length - 6:
counter = self.truncated_names.get(ident_class, 1)
- truncname = anonname[0:max(self.label_length - 6, 0)] + \
- "_" + hex(counter)[2:]
+ truncname = (
+ anonname[0 : max(self.label_length - 6, 0)]
+ + "_"
+ + hex(counter)[2:]
+ )
self.truncated_names[ident_class] = counter + 1
else:
truncname = anonname
@@ -1346,13 +1537,14 @@ class SQLCompiler(Compiled):
return name % self.anon_map
def _process_anon(self, key):
- (ident, derived) = key.split(' ', 1)
+ (ident, derived) = key.split(" ", 1)
anonymous_counter = self.anon_map.get(derived, 1)
self.anon_map[derived] = anonymous_counter + 1
return derived + "_" + str(anonymous_counter)
def bindparam_string(
- self, name, positional_names=None, expanding=False, **kw):
+ self, name, positional_names=None, expanding=False, **kw
+ ):
if self.positional:
if positional_names is not None:
positional_names.append(name)
@@ -1362,14 +1554,20 @@ class SQLCompiler(Compiled):
self.contains_expanding_parameters = True
return "([EXPANDING_%s])" % name
else:
- return self.bindtemplate % {'name': name}
-
- def visit_cte(self, cte, asfrom=False, ashint=False,
- fromhints=None, visiting_cte=None,
- **kwargs):
+ return self.bindtemplate % {"name": name}
+
+ def visit_cte(
+ self,
+ cte,
+ asfrom=False,
+ ashint=False,
+ fromhints=None,
+ visiting_cte=None,
+ **kwargs
+ ):
self._init_cte_state()
- kwargs['visiting_cte'] = cte
+ kwargs["visiting_cte"] = cte
if isinstance(cte.name, elements._truncated_label):
cte_name = self._truncated_identifier("alias", cte.name)
else:
@@ -1394,8 +1592,8 @@ class SQLCompiler(Compiled):
else:
raise exc.CompileError(
"Multiple, unrelated CTEs found with "
- "the same name: %r" %
- cte_name)
+ "the same name: %r" % cte_name
+ )
if asfrom or is_new_cte:
if cte._cte_alias is not None:
@@ -1403,7 +1601,8 @@ class SQLCompiler(Compiled):
cte_pre_alias_name = cte._cte_alias.name
if isinstance(cte_pre_alias_name, elements._truncated_label):
cte_pre_alias_name = self._truncated_identifier(
- "alias", cte_pre_alias_name)
+ "alias", cte_pre_alias_name
+ )
else:
pre_alias_cte = cte
cte_pre_alias_name = None
@@ -1412,11 +1611,17 @@ class SQLCompiler(Compiled):
self.ctes_by_name[cte_name] = cte
# look for embedded DML ctes and propagate autocommit
- if 'autocommit' in cte.element._execution_options and \
- 'autocommit' not in self.execution_options:
+ if (
+ "autocommit" in cte.element._execution_options
+ and "autocommit" not in self.execution_options
+ ):
self.execution_options = self.execution_options.union(
- {"autocommit":
- cte.element._execution_options['autocommit']})
+ {
+ "autocommit": cte.element._execution_options[
+ "autocommit"
+ ]
+ }
+ )
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
@@ -1432,25 +1637,30 @@ class SQLCompiler(Compiled):
col_source = cte.original.selects[0]
else:
assert False
- recur_cols = [c for c in
- util.unique_list(col_source.inner_columns)
- if c is not None]
-
- text += "(%s)" % (", ".join(
- self.preparer.format_column(ident)
- for ident in recur_cols))
+ recur_cols = [
+ c
+ for c in util.unique_list(col_source.inner_columns)
+ if c is not None
+ ]
+
+ text += "(%s)" % (
+ ", ".join(
+ self.preparer.format_column(ident)
+ for ident in recur_cols
+ )
+ )
if self.positional:
- kwargs['positional_names'] = self.cte_positional[cte] = []
+ kwargs["positional_names"] = self.cte_positional[cte] = []
- text += " AS \n" + \
- cte.original._compiler_dispatch(
- self, asfrom=True, **kwargs
- )
+ text += " AS \n" + cte.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ )
if cte._suffixes:
text += " " + self._generate_prefixes(
- cte, cte._suffixes, **kwargs)
+ cte, cte._suffixes, **kwargs
+ )
self.ctes[cte] = text
@@ -1467,9 +1677,15 @@ class SQLCompiler(Compiled):
else:
return self.preparer.format_alias(cte, cte_name)
- def visit_alias(self, alias, asfrom=False, ashint=False,
- iscrud=False,
- fromhints=None, **kwargs):
+ def visit_alias(
+ self,
+ alias,
+ asfrom=False,
+ ashint=False,
+ iscrud=False,
+ fromhints=None,
+ **kwargs
+ ):
if asfrom or ashint:
if isinstance(alias.name, elements._truncated_label):
alias_name = self._truncated_identifier("alias", alias.name)
@@ -1479,31 +1695,35 @@ class SQLCompiler(Compiled):
if ashint:
return self.preparer.format_alias(alias, alias_name)
elif asfrom:
- ret = alias.original._compiler_dispatch(self,
- asfrom=True, **kwargs) + \
- self.get_render_as_alias_suffix(
- self.preparer.format_alias(alias, alias_name))
+ ret = alias.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ ) + self.get_render_as_alias_suffix(
+ self.preparer.format_alias(alias, alias_name)
+ )
if fromhints and alias in fromhints:
- ret = self.format_from_hint_text(ret, alias,
- fromhints[alias], iscrud)
+ ret = self.format_from_hint_text(
+ ret, alias, fromhints[alias], iscrud
+ )
return ret
else:
return alias.original._compiler_dispatch(self, **kwargs)
def visit_lateral(self, lateral, **kw):
- kw['lateral'] = True
+ kw["lateral"] = True
return "LATERAL %s" % self.visit_alias(lateral, **kw)
def visit_tablesample(self, tablesample, asfrom=False, **kw):
text = "%s TABLESAMPLE %s" % (
self.visit_alias(tablesample, asfrom=True, **kw),
- tablesample._get_method()._compiler_dispatch(self, **kw))
+ tablesample._get_method()._compiler_dispatch(self, **kw),
+ )
if tablesample.seed is not None:
text += " REPEATABLE (%s)" % (
- tablesample.seed._compiler_dispatch(self, **kw))
+ tablesample.seed._compiler_dispatch(self, **kw)
+ )
return text
@@ -1513,22 +1733,27 @@ class SQLCompiler(Compiled):
def _add_to_result_map(self, keyname, name, objects, type_):
self._result_columns.append((keyname, name, objects, type_))
- def _label_select_column(self, select, column,
- populate_result_map,
- asfrom, column_clause_args,
- name=None,
- within_columns_clause=True):
+ def _label_select_column(
+ self,
+ select,
+ column,
+ populate_result_map,
+ asfrom,
+ column_clause_args,
+ name=None,
+ within_columns_clause=True,
+ ):
"""produce labeled columns present in a select()."""
impl = column.type.dialect_impl(self.dialect)
- if impl._has_column_expression and \
- populate_result_map:
+ if impl._has_column_expression and populate_result_map:
col_expr = impl.column_expression(column)
def add_to_result_map(keyname, name, objects, type_):
self._add_to_result_map(
- keyname, name,
- (column,) + objects, type_)
+ keyname, name, (column,) + objects, type_
+ )
+
else:
col_expr = column
if populate_result_map:
@@ -1541,58 +1766,56 @@ class SQLCompiler(Compiled):
elif isinstance(column, elements.Label):
if col_expr is not column:
result_expr = _CompileLabel(
- col_expr,
- column.name,
- alt_names=(column.element,)
+ col_expr, column.name, alt_names=(column.element,)
)
else:
result_expr = col_expr
elif select is not None and name:
result_expr = _CompileLabel(
+ col_expr, name, alt_names=(column._key_label,)
+ )
+
+ elif (
+ asfrom
+ and isinstance(column, elements.ColumnClause)
+ and not column.is_literal
+ and column.table is not None
+ and not isinstance(column.table, selectable.Select)
+ ):
+ result_expr = _CompileLabel(
col_expr,
- name,
- alt_names=(column._key_label,)
- )
-
- elif \
- asfrom and \
- isinstance(column, elements.ColumnClause) and \
- not column.is_literal and \
- column.table is not None and \
- not isinstance(column.table, selectable.Select):
- result_expr = _CompileLabel(col_expr,
- elements._as_truncated(column.name),
- alt_names=(column.key,))
+ elements._as_truncated(column.name),
+ alt_names=(column.key,),
+ )
elif (
- not isinstance(column, elements.TextClause) and
- (
- not isinstance(column, elements.UnaryExpression) or
- column.wraps_column_expression
- ) and
- (
- not hasattr(column, 'name') or
- isinstance(column, functions.Function)
+ not isinstance(column, elements.TextClause)
+ and (
+ not isinstance(column, elements.UnaryExpression)
+ or column.wraps_column_expression
+ )
+ and (
+ not hasattr(column, "name")
+ or isinstance(column, functions.Function)
)
):
result_expr = _CompileLabel(col_expr, column.anon_label)
elif col_expr is not column:
# TODO: are we sure "column" has a .name and .key here ?
# assert isinstance(column, elements.ColumnClause)
- result_expr = _CompileLabel(col_expr,
- elements._as_truncated(column.name),
- alt_names=(column.key,))
+ result_expr = _CompileLabel(
+ col_expr,
+ elements._as_truncated(column.name),
+ alt_names=(column.key,),
+ )
else:
result_expr = col_expr
column_clause_args.update(
within_columns_clause=within_columns_clause,
- add_to_result_map=add_to_result_map
- )
- return result_expr._compiler_dispatch(
- self,
- **column_clause_args
+ add_to_result_map=add_to_result_map,
)
+ return result_expr._compiler_dispatch(self, **column_clause_args)
def format_from_hint_text(self, sqltext, table, hint, iscrud):
hinttext = self.get_from_hint_text(table, hint)
@@ -1631,8 +1854,11 @@ class SQLCompiler(Compiled):
newelem = cloned[element] = element._clone()
- if newelem.is_selectable and newelem._is_join and \
- isinstance(newelem.right, selectable.FromGrouping):
+ if (
+ newelem.is_selectable
+ and newelem._is_join
+ and isinstance(newelem.right, selectable.FromGrouping)
+ ):
newelem._reset_exported()
newelem.left = visit(newelem.left, **kw)
@@ -1640,8 +1866,8 @@ class SQLCompiler(Compiled):
right = visit(newelem.right, **kw)
selectable_ = selectable.Select(
- [right.element],
- use_labels=True).alias()
+ [right.element], use_labels=True
+ ).alias()
for c in selectable_.c:
c._key_label = c.key
@@ -1680,17 +1906,18 @@ class SQLCompiler(Compiled):
elif newelem._is_from_container:
# if we hit an Alias, CompoundSelect or ScalarSelect, put a
# marker in the stack.
- kw['transform_clue'] = 'select_container'
+ kw["transform_clue"] = "select_container"
newelem._copy_internals(clone=visit, **kw)
elif newelem.is_selectable and newelem._is_select:
- barrier_select = kw.get('transform_clue', None) == \
- 'select_container'
+ barrier_select = (
+ kw.get("transform_clue", None) == "select_container"
+ )
# if we're still descended from an
# Alias/CompoundSelect/ScalarSelect, we're
# in a FROM clause, so start with a new translate collection
if barrier_select:
column_translate.append({})
- kw['transform_clue'] = 'inside_select'
+ kw["transform_clue"] = "inside_select"
newelem._copy_internals(clone=visit, **kw)
if barrier_select:
del column_translate[-1]
@@ -1702,24 +1929,22 @@ class SQLCompiler(Compiled):
return visit(select)
def _transform_result_map_for_nested_joins(
- self, select, transformed_select):
- inner_col = dict((c._key_label, c) for
- c in transformed_select.inner_columns)
-
- d = dict(
- (inner_col[c._key_label], c)
- for c in select.inner_columns
+ self, select, transformed_select
+ ):
+ inner_col = dict(
+ (c._key_label, c) for c in transformed_select.inner_columns
)
+ d = dict((inner_col[c._key_label], c) for c in select.inner_columns)
+
self._result_columns = [
(key, name, tuple([d.get(col, col) for col in objs]), typ)
for key, name, objs, typ in self._result_columns
]
- _default_stack_entry = util.immutabledict([
- ('correlate_froms', frozenset()),
- ('asfrom_froms', frozenset())
- ])
+ _default_stack_entry = util.immutabledict(
+ [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
+ )
def _display_froms_for_select(self, select, asfrom, lateral=False):
# utility method to help external dialects
@@ -1729,72 +1954,88 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- correlate_froms = entry['correlate_froms']
- asfrom_froms = entry['asfrom_froms']
+ correlate_froms = entry["correlate_froms"]
+ asfrom_froms = entry["asfrom_froms"]
if asfrom and not lateral:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
- asfrom_froms),
- implicit_correlate_froms=())
+ asfrom_froms
+ ),
+ implicit_correlate_froms=(),
+ )
else:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
+ implicit_correlate_froms=asfrom_froms,
+ )
return froms
- def visit_select(self, select, asfrom=False, parens=True,
- fromhints=None,
- compound_index=0,
- nested_join_translation=False,
- select_wraps_for=None,
- lateral=False,
- **kwargs):
-
- needs_nested_translation = \
- select.use_labels and \
- not nested_join_translation and \
- not self.stack and \
- not self.dialect.supports_right_nested_joins
+ def visit_select(
+ self,
+ select,
+ asfrom=False,
+ parens=True,
+ fromhints=None,
+ compound_index=0,
+ nested_join_translation=False,
+ select_wraps_for=None,
+ lateral=False,
+ **kwargs
+ ):
+
+ needs_nested_translation = (
+ select.use_labels
+ and not nested_join_translation
+ and not self.stack
+ and not self.dialect.supports_right_nested_joins
+ )
if needs_nested_translation:
transformed_select = self._transform_select_for_nested_joins(
- select)
+ select
+ )
text = self.visit_select(
- transformed_select, asfrom=asfrom, parens=parens,
+ transformed_select,
+ asfrom=asfrom,
+ parens=parens,
fromhints=fromhints,
compound_index=compound_index,
- nested_join_translation=True, **kwargs
+ nested_join_translation=True,
+ **kwargs
)
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- populate_result_map = toplevel or \
- (
- compound_index == 0 and entry.get(
- 'need_result_map_for_compound', False)
- ) or entry.get('need_result_map_for_nested', False)
+ populate_result_map = (
+ toplevel
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
# this was first proposed as part of #3372; however, it is not
# reached in current tests and could possibly be an assertion
# instead.
- if not populate_result_map and 'add_to_result_map' in kwargs:
- del kwargs['add_to_result_map']
+ if not populate_result_map and "add_to_result_map" in kwargs:
+ del kwargs["add_to_result_map"]
if needs_nested_translation:
if populate_result_map:
self._transform_result_map_for_nested_joins(
- select, transformed_select)
+ select, transformed_select
+ )
return text
froms = self._setup_select_stack(select, entry, asfrom, lateral)
column_clause_args = kwargs.copy()
- column_clause_args.update({
- 'within_label_clause': False,
- 'within_columns_clause': False
- })
+ column_clause_args.update(
+ {"within_label_clause": False, "within_columns_clause": False}
+ )
text = "SELECT " # we're off to a good start !
@@ -1806,19 +2047,21 @@ class SQLCompiler(Compiled):
byfrom = None
if select._prefixes:
- text += self._generate_prefixes(
- select, select._prefixes, **kwargs)
+ text += self._generate_prefixes(select, select._prefixes, **kwargs)
text += self.get_select_precolumns(select, **kwargs)
# the actual list of columns to print in the SELECT column list.
inner_columns = [
- c for c in [
+ c
+ for c in [
self._label_select_column(
select,
column,
- populate_result_map, asfrom,
+ populate_result_map,
+ asfrom,
column_clause_args,
- name=name)
+ name=name,
+ )
for name, column in select._columns_plus_names
]
if c is not None
@@ -1831,8 +2074,11 @@ class SQLCompiler(Compiled):
translate = dict(
zip(
[name for (key, name) in select._columns_plus_names],
- [name for (key, name) in
- select_wraps_for._columns_plus_names])
+ [
+ name
+ for (key, name) in select_wraps_for._columns_plus_names
+ ],
+ )
)
self._result_columns = [
@@ -1841,13 +2087,14 @@ class SQLCompiler(Compiled):
]
text = self._compose_select_body(
- text, select, inner_columns, froms, byfrom, kwargs)
+ text, select, inner_columns, froms, byfrom, kwargs
+ )
if select._statement_hints:
per_dialect = [
- ht for (dialect_name, ht)
- in select._statement_hints
- if dialect_name in ('*', self.dialect.name)
+ ht
+ for (dialect_name, ht) in select._statement_hints
+ if dialect_name in ("*", self.dialect.name)
]
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
@@ -1857,7 +2104,8 @@ class SQLCompiler(Compiled):
if select._suffixes:
text += " " + self._generate_prefixes(
- select, select._suffixes, **kwargs)
+ select, select._suffixes, **kwargs
+ )
self.stack.pop(-1)
@@ -1867,60 +2115,73 @@ class SQLCompiler(Compiled):
return text
def _setup_select_hints(self, select):
- byfrom = dict([
- (from_, hinttext % {
- 'name': from_._compiler_dispatch(
- self, ashint=True)
- })
- for (from_, dialect), hinttext in
- select._hints.items()
- if dialect in ('*', self.dialect.name)
- ])
+ byfrom = dict(
+ [
+ (
+ from_,
+ hinttext
+ % {"name": from_._compiler_dispatch(self, ashint=True)},
+ )
+ for (from_, dialect), hinttext in select._hints.items()
+ if dialect in ("*", self.dialect.name)
+ ]
+ )
hint_text = self.get_select_hint_text(byfrom)
return hint_text, byfrom
def _setup_select_stack(self, select, entry, asfrom, lateral):
- correlate_froms = entry['correlate_froms']
- asfrom_froms = entry['asfrom_froms']
+ correlate_froms = entry["correlate_froms"]
+ asfrom_froms = entry["asfrom_froms"]
if asfrom and not lateral:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
- asfrom_froms),
- implicit_correlate_froms=())
+ asfrom_froms
+ ),
+ implicit_correlate_froms=(),
+ )
else:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
+ implicit_correlate_froms=asfrom_froms,
+ )
new_correlate_froms = set(selectable._from_objects(*froms))
all_correlate_froms = new_correlate_froms.union(correlate_froms)
new_entry = {
- 'asfrom_froms': new_correlate_froms,
- 'correlate_froms': all_correlate_froms,
- 'selectable': select,
+ "asfrom_froms": new_correlate_froms,
+ "correlate_froms": all_correlate_froms,
+ "selectable": select,
}
self.stack.append(new_entry)
return froms
def _compose_select_body(
- self, text, select, inner_columns, froms, byfrom, kwargs):
- text += ', '.join(inner_columns)
+ self, text, select, inner_columns, froms, byfrom, kwargs
+ ):
+ text += ", ".join(inner_columns)
if froms:
text += " \nFROM "
if select._hints:
- text += ', '.join(
- [f._compiler_dispatch(self, asfrom=True,
- fromhints=byfrom, **kwargs)
- for f in froms])
+ text += ", ".join(
+ [
+ f._compiler_dispatch(
+ self, asfrom=True, fromhints=byfrom, **kwargs
+ )
+ for f in froms
+ ]
+ )
else:
- text += ', '.join(
- [f._compiler_dispatch(self, asfrom=True, **kwargs)
- for f in froms])
+ text += ", ".join(
+ [
+ f._compiler_dispatch(self, asfrom=True, **kwargs)
+ for f in froms
+ ]
+ )
else:
text += self.default_from()
@@ -1940,8 +2201,10 @@ class SQLCompiler(Compiled):
if select._order_by_clause.clauses:
text += self.order_by_clause(select, **kwargs)
- if (select._limit_clause is not None or
- select._offset_clause is not None):
+ if (
+ select._limit_clause is not None
+ or select._offset_clause is not None
+ ):
text += self.limit_clause(select, **kwargs)
if select._for_update_arg is not None:
@@ -1953,8 +2216,7 @@ class SQLCompiler(Compiled):
clause = " ".join(
prefix._compiler_dispatch(self, **kw)
for prefix, dialect_name in prefixes
- if dialect_name is None or
- dialect_name == self.dialect.name
+ if dialect_name is None or dialect_name == self.dialect.name
)
if clause:
clause += " "
@@ -1962,14 +2224,12 @@ class SQLCompiler(Compiled):
def _render_cte_clause(self):
if self.positional:
- self.positiontup = sum([
- self.cte_positional[cte]
- for cte in self.ctes], []) + \
- self.positiontup
+ self.positiontup = (
+ sum([self.cte_positional[cte] for cte in self.ctes], [])
+ + self.positiontup
+ )
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
- cte_text += ", \n".join(
- [txt for txt in self.ctes.values()]
- )
+ cte_text += ", \n".join([txt for txt in self.ctes.values()])
cte_text += "\n "
return cte_text
@@ -2010,7 +2270,8 @@ class SQLCompiler(Compiled):
def returning_clause(self, stmt, returning_cols):
raise exc.CompileError(
"RETURNING is not supported by this "
- "dialect's statement compiler.")
+ "dialect's statement compiler."
+ )
def limit_clause(self, select, **kw):
text = ""
@@ -2022,19 +2283,31 @@ class SQLCompiler(Compiled):
text += " OFFSET " + self.process(select._offset_clause, **kw)
return text
- def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
- fromhints=None, use_schema=True, **kwargs):
+ def visit_table(
+ self,
+ table,
+ asfrom=False,
+ iscrud=False,
+ ashint=False,
+ fromhints=None,
+ use_schema=True,
+ **kwargs
+ ):
if asfrom or ashint:
effective_schema = self.preparer.schema_for_object(table)
if use_schema and effective_schema:
- ret = self.preparer.quote_schema(effective_schema) + \
- "." + self.preparer.quote(table.name)
+ ret = (
+ self.preparer.quote_schema(effective_schema)
+ + "."
+ + self.preparer.quote(table.name)
+ )
else:
ret = self.preparer.quote(table.name)
if fromhints and table in fromhints:
- ret = self.format_from_hint_text(ret, table,
- fromhints[table], iscrud)
+ ret = self.format_from_hint_text(
+ ret, table, fromhints[table], iscrud
+ )
return ret
else:
return ""
@@ -2047,26 +2320,24 @@ class SQLCompiler(Compiled):
else:
join_type = " JOIN "
return (
- join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
- join_type +
- join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
- " ON " +
- join.onclause._compiler_dispatch(self, **kwargs)
+ join.left._compiler_dispatch(self, asfrom=True, **kwargs)
+ + join_type
+ + join.right._compiler_dispatch(self, asfrom=True, **kwargs)
+ + " ON "
+ + join.onclause._compiler_dispatch(self, **kwargs)
)
def _setup_crud_hints(self, stmt, table_text):
- dialect_hints = dict([
- (table, hint_text)
- for (table, dialect), hint_text in
- stmt._hints.items()
- if dialect in ('*', self.dialect.name)
- ])
+ dialect_hints = dict(
+ [
+ (table, hint_text)
+ for (table, dialect), hint_text in stmt._hints.items()
+ if dialect in ("*", self.dialect.name)
+ ]
+ )
if stmt.table in dialect_hints:
table_text = self.format_from_hint_text(
- table_text,
- stmt.table,
- dialect_hints[stmt.table],
- True
+ table_text, stmt.table, dialect_hints[stmt.table], True
)
return dialect_hints, table_text
@@ -2074,28 +2345,35 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
self.stack.append(
- {'correlate_froms': set(),
- "asfrom_froms": set(),
- "selectable": insert_stmt})
+ {
+ "correlate_froms": set(),
+ "asfrom_froms": set(),
+ "selectable": insert_stmt,
+ }
+ )
crud_params = crud._setup_crud_params(
- self, insert_stmt, crud.ISINSERT, **kw)
+ self, insert_stmt, crud.ISINSERT, **kw
+ )
- if not crud_params and \
- not self.dialect.supports_default_values and \
- not self.dialect.supports_empty_insert:
- raise exc.CompileError("The '%s' dialect with current database "
- "version settings does not support empty "
- "inserts." %
- self.dialect.name)
+ if (
+ not crud_params
+ and not self.dialect.supports_default_values
+ and not self.dialect.supports_empty_insert
+ ):
+ raise exc.CompileError(
+ "The '%s' dialect with current database "
+ "version settings does not support empty "
+ "inserts." % self.dialect.name
+ )
if insert_stmt._has_multi_parameters:
if not self.dialect.supports_multivalues_insert:
raise exc.CompileError(
"The '%s' dialect with current database "
"version settings does not support "
- "in-place multirow inserts." %
- self.dialect.name)
+ "in-place multirow inserts." % self.dialect.name
+ )
crud_params_single = crud_params[0]
else:
crud_params_single = crud_params
@@ -2106,27 +2384,31 @@ class SQLCompiler(Compiled):
text = "INSERT "
if insert_stmt._prefixes:
- text += self._generate_prefixes(insert_stmt,
- insert_stmt._prefixes, **kw)
+ text += self._generate_prefixes(
+ insert_stmt, insert_stmt._prefixes, **kw
+ )
text += "INTO "
table_text = preparer.format_table(insert_stmt.table)
if insert_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
- insert_stmt, table_text)
+ insert_stmt, table_text
+ )
else:
dialect_hints = None
text += table_text
if crud_params_single or not supports_default_values:
- text += " (%s)" % ', '.join([preparer.format_column(c[0])
- for c in crud_params_single])
+ text += " (%s)" % ", ".join(
+ [preparer.format_column(c[0]) for c in crud_params_single]
+ )
if self.returning or insert_stmt._returning:
returning_clause = self.returning_clause(
- insert_stmt, self.returning or insert_stmt._returning)
+ insert_stmt, self.returning or insert_stmt._returning
+ )
if self.returning_precedes_values:
text += " " + returning_clause
@@ -2145,19 +2427,17 @@ class SQLCompiler(Compiled):
elif insert_stmt._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
- "(%s)" % (
- ', '.join(c[1] for c in crud_param_set)
- )
+ "(%s)" % (", ".join(c[1] for c in crud_param_set))
for crud_param_set in crud_params
)
)
else:
- text += " VALUES (%s)" % \
- ', '.join([c[1] for c in crud_params])
+ text += " VALUES (%s)" % ", ".join([c[1] for c in crud_params])
if insert_stmt._post_values_clause is not None:
post_values_clause = self.process(
- insert_stmt._post_values_clause, **kw)
+ insert_stmt._post_values_clause, **kw
+ )
if post_values_clause:
text += " " + post_values_clause
@@ -2178,21 +2458,19 @@ class SQLCompiler(Compiled):
"""Provide a hook for MySQL to add LIMIT to the UPDATE"""
return None
- def update_tables_clause(self, update_stmt, from_table,
- extra_froms, **kw):
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
"""Provide a hook to override the initial table clause
in an UPDATE statement.
MySQL overrides this.
"""
- kw['asfrom'] = True
+ kw["asfrom"] = True
return from_table._compiler_dispatch(self, iscrud=True, **kw)
- def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Provide a hook to override the generation of an
UPDATE..FROM clause.
@@ -2201,7 +2479,8 @@ class SQLCompiler(Compiled):
"""
raise NotImplementedError(
"This backend does not support multiple-table "
- "criteria within UPDATE")
+ "criteria within UPDATE"
+ )
def visit_update(self, update_stmt, asfrom=False, **kw):
toplevel = not self.stack
@@ -2221,49 +2500,61 @@ class SQLCompiler(Compiled):
correlate_froms = {update_stmt.table}
self.stack.append(
- {'correlate_froms': correlate_froms,
- "asfrom_froms": correlate_froms,
- "selectable": update_stmt})
+ {
+ "correlate_froms": correlate_froms,
+ "asfrom_froms": correlate_froms,
+ "selectable": update_stmt,
+ }
+ )
text = "UPDATE "
if update_stmt._prefixes:
- text += self._generate_prefixes(update_stmt,
- update_stmt._prefixes, **kw)
+ text += self._generate_prefixes(
+ update_stmt, update_stmt._prefixes, **kw
+ )
- table_text = self.update_tables_clause(update_stmt, update_stmt.table,
- render_extra_froms, **kw)
+ table_text = self.update_tables_clause(
+ update_stmt, update_stmt.table, render_extra_froms, **kw
+ )
crud_params = crud._setup_crud_params(
- self, update_stmt, crud.ISUPDATE, **kw)
+ self, update_stmt, crud.ISUPDATE, **kw
+ )
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
- update_stmt, table_text)
+ update_stmt, table_text
+ )
else:
dialect_hints = None
text += table_text
- text += ' SET '
- include_table = is_multitable and \
- self.render_table_with_column_in_update_from
- text += ', '.join(
- c[0]._compiler_dispatch(self,
- include_table=include_table) +
- '=' + c[1] for c in crud_params
+ text += " SET "
+ include_table = (
+ is_multitable and self.render_table_with_column_in_update_from
+ )
+ text += ", ".join(
+ c[0]._compiler_dispatch(self, include_table=include_table)
+ + "="
+ + c[1]
+ for c in crud_params
)
if self.returning or update_stmt._returning:
if self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning or update_stmt._returning)
+ update_stmt, self.returning or update_stmt._returning
+ )
if extra_froms:
extra_from_text = self.update_from_clause(
update_stmt,
update_stmt.table,
render_extra_froms,
- dialect_hints, **kw)
+ dialect_hints,
+ **kw
+ )
if extra_from_text:
text += " " + extra_from_text
@@ -2276,10 +2567,12 @@ class SQLCompiler(Compiled):
if limit_clause:
text += " " + limit_clause
- if (self.returning or update_stmt._returning) and \
- not self.returning_precedes_values:
+ if (
+ self.returning or update_stmt._returning
+ ) and not self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning or update_stmt._returning)
+ update_stmt, self.returning or update_stmt._returning
+ )
if self.ctes and toplevel:
text = self._render_cte_clause() + text
@@ -2295,9 +2588,9 @@ class SQLCompiler(Compiled):
def _key_getters_for_crud_column(self):
return crud._key_getters_for_crud_column(self, self.statement)
- def delete_extra_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints, **kw):
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Provide a hook to override the generation of an
DELETE..FROM clause.
@@ -2308,10 +2601,10 @@ class SQLCompiler(Compiled):
"""
raise NotImplementedError(
"This backend does not support multiple-table "
- "criteria within DELETE")
+ "criteria within DELETE"
+ )
- def delete_table_clause(self, delete_stmt, from_table,
- extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
def visit_delete(self, delete_stmt, asfrom=False, **kw):
@@ -2322,23 +2615,30 @@ class SQLCompiler(Compiled):
extra_froms = delete_stmt._extra_froms
correlate_froms = {delete_stmt.table}.union(extra_froms)
- self.stack.append({'correlate_froms': correlate_froms,
- "asfrom_froms": correlate_froms,
- "selectable": delete_stmt})
+ self.stack.append(
+ {
+ "correlate_froms": correlate_froms,
+ "asfrom_froms": correlate_froms,
+ "selectable": delete_stmt,
+ }
+ )
text = "DELETE "
if delete_stmt._prefixes:
- text += self._generate_prefixes(delete_stmt,
- delete_stmt._prefixes, **kw)
+ text += self._generate_prefixes(
+ delete_stmt, delete_stmt._prefixes, **kw
+ )
text += "FROM "
- table_text = self.delete_table_clause(delete_stmt, delete_stmt.table,
- extra_froms)
+ table_text = self.delete_table_clause(
+ delete_stmt, delete_stmt.table, extra_froms
+ )
if delete_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
- delete_stmt, table_text)
+ delete_stmt, table_text
+ )
else:
dialect_hints = None
@@ -2347,14 +2647,17 @@ class SQLCompiler(Compiled):
if delete_stmt._returning:
if self.returning_precedes_values:
text += " " + self.returning_clause(
- delete_stmt, delete_stmt._returning)
+ delete_stmt, delete_stmt._returning
+ )
if extra_froms:
extra_from_text = self.delete_extra_from_clause(
delete_stmt,
delete_stmt.table,
extra_froms,
- dialect_hints, **kw)
+ dialect_hints,
+ **kw
+ )
if extra_from_text:
text += " " + extra_from_text
@@ -2365,7 +2668,8 @@ class SQLCompiler(Compiled):
if delete_stmt._returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
- delete_stmt, delete_stmt._returning)
+ delete_stmt, delete_stmt._returning
+ )
if self.ctes and toplevel:
text = self._render_cte_clause() + text
@@ -2381,12 +2685,14 @@ class SQLCompiler(Compiled):
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
def visit_rollback_to_savepoint(self, savepoint_stmt):
- return "ROLLBACK TO SAVEPOINT %s" % \
- self.preparer.format_savepoint(savepoint_stmt)
+ return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
def visit_release_savepoint(self, savepoint_stmt):
- return "RELEASE SAVEPOINT %s" % \
- self.preparer.format_savepoint(savepoint_stmt)
+ return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
class StrSQLCompiler(SQLCompiler):
@@ -2403,7 +2709,7 @@ class StrSQLCompiler(SQLCompiler):
def visit_getitem_binary(self, binary, operator, **kw):
return "%s[%s]" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw)
+ self.process(binary.right, **kw),
)
def visit_json_getitem_op_binary(self, binary, operator, **kw):
@@ -2421,29 +2727,26 @@ class StrSQLCompiler(SQLCompiler):
for c in elements._select_iterables(returning_cols)
]
- return 'RETURNING ' + ', '.join(columns)
+ return "RETURNING " + ", ".join(columns)
- def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
- return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in extra_froms)
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
- def delete_extra_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
- return ', ' + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in extra_froms)
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ return ", " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
class DDLCompiler(Compiled):
-
@util.memoized_property
def sql_compiler(self):
return self.dialect.statement_compiler(self.dialect, None)
@@ -2464,13 +2767,13 @@ class DDLCompiler(Compiled):
preparer = self.preparer
path = preparer.format_table_seq(ddl.target)
if len(path) == 1:
- table, sch = path[0], ''
+ table, sch = path[0], ""
else:
table, sch = path[-1], path[0]
- context.setdefault('table', table)
- context.setdefault('schema', sch)
- context.setdefault('fullname', preparer.format_table(ddl.target))
+ context.setdefault("table", table)
+ context.setdefault("schema", sch)
+ context.setdefault("fullname", preparer.format_table(ddl.target))
return self.sql_compiler.post_process_text(ddl.statement % context)
@@ -2507,9 +2810,9 @@ class DDLCompiler(Compiled):
for create_column in create.columns:
column = create_column.element
try:
- processed = self.process(create_column,
- first_pk=column.primary_key
- and not first_pk)
+ processed = self.process(
+ create_column, first_pk=column.primary_key and not first_pk
+ )
if processed is not None:
text += separator
separator = ", \n"
@@ -2519,13 +2822,15 @@ class DDLCompiler(Compiled):
except exc.CompileError as ce:
util.raise_from_cause(
exc.CompileError(
- util.u("(in table '%s', column '%s'): %s") %
- (table.description, column.name, ce.args[0])
- ))
+ util.u("(in table '%s', column '%s'): %s")
+ % (table.description, column.name, ce.args[0])
+ )
+ )
const = self.create_table_constraints(
- table, _include_foreign_key_constraints= # noqa
- create.include_foreign_key_constraints)
+ table,
+ _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa
+ )
if const:
text += separator + "\t" + const
@@ -2538,20 +2843,18 @@ class DDLCompiler(Compiled):
if column.system:
return None
- text = self.get_column_specification(
- column,
- first_pk=first_pk
+ text = self.get_column_specification(column, first_pk=first_pk)
+ const = " ".join(
+ self.process(constraint) for constraint in column.constraints
)
- const = " ".join(self.process(constraint)
- for constraint in column.constraints)
if const:
text += " " + const
return text
def create_table_constraints(
- self, table,
- _include_foreign_key_constraints=None):
+ self, table, _include_foreign_key_constraints=None
+ ):
# On some DB order is significant: visit PK first, then the
# other constraints (engine.ReflectionTest.testbasic failed on FB2)
@@ -2565,21 +2868,29 @@ class DDLCompiler(Compiled):
else:
omit_fkcs = set()
- constraints.extend([c for c in table._sorted_constraints
- if c is not table.primary_key and
- c not in omit_fkcs])
+ constraints.extend(
+ [
+ c
+ for c in table._sorted_constraints
+ if c is not table.primary_key and c not in omit_fkcs
+ ]
+ )
return ", \n\t".join(
- p for p in
- (self.process(constraint)
+ p
+ for p in (
+ self.process(constraint)
for constraint in constraints
if (
- constraint._create_rule is None or
- constraint._create_rule(self))
+ constraint._create_rule is None
+ or constraint._create_rule(self)
+ )
and (
- not self.dialect.supports_alter or
- not getattr(constraint, 'use_alter', False)
- )) if p is not None
+ not self.dialect.supports_alter
+ or not getattr(constraint, "use_alter", False)
+ )
+ )
+ if p is not None
)
def visit_drop_table(self, drop):
@@ -2590,34 +2901,38 @@ class DDLCompiler(Compiled):
def _verify_index_table(self, index):
if index.table is None:
- raise exc.CompileError("Index '%s' is not associated "
- "with any table." % index.name)
+ raise exc.CompileError(
+ "Index '%s' is not associated " "with any table." % index.name
+ )
- def visit_create_index(self, create, include_schema=False,
- include_table_schema=True):
+ def visit_create_index(
+ self, create, include_schema=False, include_table_schema=True
+ ):
index = create.element
self._verify_index_table(index)
preparer = self.preparer
text = "CREATE "
if index.unique:
text += "UNIQUE "
- text += "INDEX %s ON %s (%s)" \
- % (
- self._prepared_index_name(index,
- include_schema=include_schema),
- preparer.format_table(index.table,
- use_schema=include_table_schema),
- ', '.join(
- self.sql_compiler.process(
- expr, include_table=False, literal_binds=True) for
- expr in index.expressions)
- )
+ text += "INDEX %s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=include_schema),
+ preparer.format_table(
+ index.table, use_schema=include_table_schema
+ ),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
return text
def visit_drop_index(self, drop):
index = drop.element
return "\nDROP INDEX " + self._prepared_index_name(
- index, include_schema=True)
+ index, include_schema=True
+ )
def _prepared_index_name(self, index, include_schema=False):
if index.table is not None:
@@ -2638,35 +2953,41 @@ class DDLCompiler(Compiled):
def visit_add_constraint(self, create):
return "ALTER TABLE %s ADD %s" % (
self.preparer.format_table(create.element.table),
- self.process(create.element)
+ self.process(create.element),
)
def visit_set_table_comment(self, create):
return "COMMENT ON TABLE %s IS %s" % (
self.preparer.format_table(create.element),
self.sql_compiler.render_literal_value(
- create.element.comment, sqltypes.String())
+ create.element.comment, sqltypes.String()
+ ),
)
def visit_drop_table_comment(self, drop):
- return "COMMENT ON TABLE %s IS NULL" % \
- self.preparer.format_table(drop.element)
+ return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
+ drop.element
+ )
def visit_set_column_comment(self, create):
return "COMMENT ON COLUMN %s IS %s" % (
self.preparer.format_column(
- create.element, use_table=True, use_schema=True),
+ create.element, use_table=True, use_schema=True
+ ),
self.sql_compiler.render_literal_value(
- create.element.comment, sqltypes.String())
+ create.element.comment, sqltypes.String()
+ ),
)
def visit_drop_column_comment(self, drop):
- return "COMMENT ON COLUMN %s IS NULL" % \
- self.preparer.format_column(drop.element, use_table=True)
+ return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
+ drop.element, use_table=True
+ )
def visit_create_sequence(self, create):
- text = "CREATE SEQUENCE %s" % \
- self.preparer.format_sequence(create.element)
+ text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(
+ create.element
+ )
if create.element.increment is not None:
text += " INCREMENT BY %d" % create.element.increment
if create.element.start is not None:
@@ -2688,8 +3009,7 @@ class DDLCompiler(Compiled):
return text
def visit_drop_sequence(self, drop):
- return "DROP SEQUENCE %s" % \
- self.preparer.format_sequence(drop.element)
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
def visit_drop_constraint(self, drop):
constraint = drop.element
@@ -2701,17 +3021,22 @@ class DDLCompiler(Compiled):
if formatted_name is None:
raise exc.CompileError(
"Can't emit DROP CONSTRAINT for constraint %r; "
- "it has no name" % drop.element)
+ "it has no name" % drop.element
+ )
return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
self.preparer.format_table(drop.element.table),
formatted_name,
- drop.cascade and " CASCADE" or ""
+ drop.cascade and " CASCADE" or "",
)
def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + \
- self.dialect.type_compiler.process(
- column.type, type_expression=column)
+ colspec = (
+ self.preparer.format_column(column)
+ + " "
+ + self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+ )
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -2721,19 +3046,21 @@ class DDLCompiler(Compiled):
return colspec
def create_table_suffix(self, table):
- return ''
+ return ""
def post_create_table(self, table):
- return ''
+ return ""
def get_column_default_string(self, column):
if isinstance(column.server_default, schema.DefaultClause):
if isinstance(column.server_default.arg, util.string_types):
return self.sql_compiler.render_literal_value(
- column.server_default.arg, sqltypes.STRINGTYPE)
+ column.server_default.arg, sqltypes.STRINGTYPE
+ )
else:
return self.sql_compiler.process(
- column.server_default.arg, literal_binds=True)
+ column.server_default.arg, literal_binds=True
+ )
else:
return None
@@ -2743,9 +3070,9 @@ class DDLCompiler(Compiled):
formatted_name = self.preparer.format_constraint(constraint)
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
- text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
- include_table=False,
- literal_binds=True)
+ text += "CHECK (%s)" % self.sql_compiler.process(
+ constraint.sqltext, include_table=False, literal_binds=True
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -2755,25 +3082,29 @@ class DDLCompiler(Compiled):
formatted_name = self.preparer.format_constraint(constraint)
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
- text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
- include_table=False,
- literal_binds=True)
+ text += "CHECK (%s)" % self.sql_compiler.process(
+ constraint.sqltext, include_table=False, literal_binds=True
+ )
text += self.define_constraint_deferrability(constraint)
return text
def visit_primary_key_constraint(self, constraint):
if len(constraint) == 0:
- return ''
+ return ""
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
text += "PRIMARY KEY "
- text += "(%s)" % ', '.join(self.preparer.quote(c.name)
- for c in (constraint.columns_autoinc_first
- if constraint._implicit_generated
- else constraint.columns))
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name)
+ for c in (
+ constraint.columns_autoinc_first
+ if constraint._implicit_generated
+ else constraint.columns
+ )
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -2786,12 +3117,15 @@ class DDLCompiler(Compiled):
text += "CONSTRAINT %s " % formatted_name
remote_table = list(constraint.elements)[0].column.table
text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- ', '.join(preparer.quote(f.parent.name)
- for f in constraint.elements),
+ ", ".join(
+ preparer.quote(f.parent.name) for f in constraint.elements
+ ),
self.define_constraint_remote_table(
- constraint, remote_table, preparer),
- ', '.join(preparer.quote(f.column.name)
- for f in constraint.elements)
+ constraint, remote_table, preparer
+ ),
+ ", ".join(
+ preparer.quote(f.column.name) for f in constraint.elements
+ ),
)
text += self.define_constraint_match(constraint)
text += self.define_constraint_cascades(constraint)
@@ -2805,14 +3139,14 @@ class DDLCompiler(Compiled):
def visit_unique_constraint(self, constraint):
if len(constraint) == 0:
- return ''
+ return ""
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
text += "CONSTRAINT %s " % formatted_name
text += "UNIQUE (%s)" % (
- ', '.join(self.preparer.quote(c.name)
- for c in constraint))
+ ", ".join(self.preparer.quote(c.name) for c in constraint)
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -2843,7 +3177,6 @@ class DDLCompiler(Compiled):
class GenericTypeCompiler(TypeCompiler):
-
def visit_FLOAT(self, type_, **kw):
return "FLOAT"
@@ -2854,23 +3187,23 @@ class GenericTypeCompiler(TypeCompiler):
if type_.precision is None:
return "NUMERIC"
elif type_.scale is None:
- return "NUMERIC(%(precision)s)" % \
- {'precision': type_.precision}
+ return "NUMERIC(%(precision)s)" % {"precision": type_.precision}
else:
- return "NUMERIC(%(precision)s, %(scale)s)" % \
- {'precision': type_.precision,
- 'scale': type_.scale}
+ return "NUMERIC(%(precision)s, %(scale)s)" % {
+ "precision": type_.precision,
+ "scale": type_.scale,
+ }
def visit_DECIMAL(self, type_, **kw):
if type_.precision is None:
return "DECIMAL"
elif type_.scale is None:
- return "DECIMAL(%(precision)s)" % \
- {'precision': type_.precision}
+ return "DECIMAL(%(precision)s)" % {"precision": type_.precision}
else:
- return "DECIMAL(%(precision)s, %(scale)s)" % \
- {'precision': type_.precision,
- 'scale': type_.scale}
+ return "DECIMAL(%(precision)s, %(scale)s)" % {
+ "precision": type_.precision,
+ "scale": type_.scale,
+ }
def visit_INTEGER(self, type_, **kw):
return "INTEGER"
@@ -2882,7 +3215,7 @@ class GenericTypeCompiler(TypeCompiler):
return "BIGINT"
def visit_TIMESTAMP(self, type_, **kw):
- return 'TIMESTAMP'
+ return "TIMESTAMP"
def visit_DATETIME(self, type_, **kw):
return "DATETIME"
@@ -2984,9 +3317,11 @@ class GenericTypeCompiler(TypeCompiler):
return self.visit_VARCHAR(type_, **kw)
def visit_null(self, type_, **kw):
- raise exc.CompileError("Can't generate DDL for %r; "
- "did you forget to specify a "
- "type on this Column?" % type_)
+ raise exc.CompileError(
+ "Can't generate DDL for %r; "
+ "did you forget to specify a "
+ "type on this Column?" % type_
+ )
def visit_type_decorator(self, type_, **kw):
return self.process(type_.type_engine(self.dialect), **kw)
@@ -3018,9 +3353,15 @@ class IdentifierPreparer(object):
schema_for_object = schema._schema_getter(None)
- def __init__(self, dialect, initial_quote='"',
- final_quote=None, escape_quote='"',
- quote_case_sensitive_collations=True, omit_schema=False):
+ def __init__(
+ self,
+ dialect,
+ initial_quote='"',
+ final_quote=None,
+ escape_quote='"',
+ quote_case_sensitive_collations=True,
+ omit_schema=False,
+ ):
"""Construct a new ``IdentifierPreparer`` object.
initial_quote
@@ -3043,7 +3384,10 @@ class IdentifierPreparer(object):
self.omit_schema = omit_schema
self.quote_case_sensitive_collations = quote_case_sensitive_collations
self._strings = {}
- self._double_percents = self.dialect.paramstyle in ('format', 'pyformat')
+ self._double_percents = self.dialect.paramstyle in (
+ "format",
+ "pyformat",
+ )
def _with_schema_translate(self, schema_translate_map):
prep = self.__class__.__new__(self.__class__)
@@ -3060,7 +3404,7 @@ class IdentifierPreparer(object):
value = value.replace(self.escape_quote, self.escape_to_quote)
if self._double_percents:
- value = value.replace('%', '%%')
+ value = value.replace("%", "%%")
return value
def _unescape_identifier(self, value):
@@ -3079,17 +3423,21 @@ class IdentifierPreparer(object):
quoting behavior.
"""
- return self.initial_quote + \
- self._escape_identifier(value) + \
- self.final_quote
+ return (
+ self.initial_quote
+ + self._escape_identifier(value)
+ + self.final_quote
+ )
def _requires_quotes(self, value):
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
- return (lc_value in self.reserved_words
- or value[0] in self.illegal_initial_characters
- or not self.legal_characters.match(util.text_type(value))
- or (lc_value != value))
+ return (
+ lc_value in self.reserved_words
+ or value[0] in self.illegal_initial_characters
+ or not self.legal_characters.match(util.text_type(value))
+ or (lc_value != value)
+ )
def quote_schema(self, schema, force=None):
"""Conditionally quote a schema.
@@ -3135,8 +3483,11 @@ class IdentifierPreparer(object):
effective_schema = self.schema_for_object(sequence)
- if (not self.omit_schema and use_schema and
- effective_schema is not None):
+ if (
+ not self.omit_schema
+ and use_schema
+ and effective_schema is not None
+ ):
name = self.quote_schema(effective_schema) + "." + name
return name
@@ -3159,7 +3510,8 @@ class IdentifierPreparer(object):
def format_constraint(self, naming, constraint):
if isinstance(constraint.name, elements._defer_name):
name = naming._constraint_name_for_table(
- constraint, constraint.table)
+ constraint, constraint.table
+ )
if name is None:
if isinstance(constraint.name, elements._defer_none_name):
@@ -3170,14 +3522,15 @@ class IdentifierPreparer(object):
name = constraint.name
if isinstance(name, elements._truncated_label):
- if constraint.__visit_name__ == 'index':
- max_ = self.dialect.max_index_name_length or \
- self.dialect.max_identifier_length
+ if constraint.__visit_name__ == "index":
+ max_ = (
+ self.dialect.max_index_name_length
+ or self.dialect.max_identifier_length
+ )
else:
max_ = self.dialect.max_identifier_length
if len(name) > max_:
- name = name[0:max_ - 8] + \
- "_" + util.md5_hex(name)[-4:]
+ name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
else:
self.dialect.validate_identifier(name)
@@ -3195,8 +3548,7 @@ class IdentifierPreparer(object):
effective_schema = self.schema_for_object(table)
- if not self.omit_schema and use_schema \
- and effective_schema:
+ if not self.omit_schema and use_schema and effective_schema:
result = self.quote_schema(effective_schema) + "." + result
return result
@@ -3205,17 +3557,27 @@ class IdentifierPreparer(object):
return self.quote(name, quote)
- def format_column(self, column, use_table=False,
- name=None, table_name=None, use_schema=False):
+ def format_column(
+ self,
+ column,
+ use_table=False,
+ name=None,
+ table_name=None,
+ use_schema=False,
+ ):
"""Prepare a quoted column name."""
if name is None:
name = column.name
- if not getattr(column, 'is_literal', False):
+ if not getattr(column, "is_literal", False):
if use_table:
- return self.format_table(
- column.table, use_schema=use_schema,
- name=table_name) + "." + self.quote(name)
+ return (
+ self.format_table(
+ column.table, use_schema=use_schema, name=table_name
+ )
+ + "."
+ + self.quote(name)
+ )
else:
return self.quote(name)
else:
@@ -3223,9 +3585,13 @@ class IdentifierPreparer(object):
# which shouldn't get quoted
if use_table:
- return self.format_table(
- column.table, use_schema=use_schema,
- name=table_name) + '.' + name
+ return (
+ self.format_table(
+ column.table, use_schema=use_schema, name=table_name
+ )
+ + "."
+ + name
+ )
else:
return name
@@ -3238,31 +3604,37 @@ class IdentifierPreparer(object):
effective_schema = self.schema_for_object(table)
- if not self.omit_schema and use_schema and \
- effective_schema:
- return (self.quote_schema(effective_schema),
- self.format_table(table, use_schema=False))
+ if not self.omit_schema and use_schema and effective_schema:
+ return (
+ self.quote_schema(effective_schema),
+ self.format_table(table, use_schema=False),
+ )
else:
- return (self.format_table(table, use_schema=False), )
+ return (self.format_table(table, use_schema=False),)
@util.memoized_property
def _r_identifiers(self):
- initial, final, escaped_final = \
- [re.escape(s) for s in
- (self.initial_quote, self.final_quote,
- self._escape_identifier(self.final_quote))]
+ initial, final, escaped_final = [
+ re.escape(s)
+ for s in (
+ self.initial_quote,
+ self.final_quote,
+ self._escape_identifier(self.final_quote),
+ )
+ ]
r = re.compile(
- r'(?:'
- r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
- r'|([^\.]+))(?=\.|$))+' %
- {'initial': initial,
- 'final': final,
- 'escaped': escaped_final})
+ r"(?:"
+ r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
+ r"|([^\.]+))(?=\.|$))+"
+ % {"initial": initial, "final": final, "escaped": escaped_final}
+ )
return r
def unformat_identifiers(self, identifiers):
"""Unpack 'schema.table.column'-like strings into components."""
r = self._r_identifiers
- return [self._unescape_identifier(i)
- for i in [a or b for a, b in r.findall(identifiers)]]
+ return [
+ self._unescape_identifier(i)
+ for i in [a or b for a, b in r.findall(identifiers)]
+ ]