summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py327
1 files changed, 196 insertions, 131 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 12cfe09d1..4feaf9938 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -15,15 +15,29 @@ from . import operators, visitors
from itertools import chain
from collections import deque
-from .elements import BindParameter, ColumnClause, ColumnElement, \
- Null, UnaryExpression, literal_column, Label, _label_reference, \
- _textual_label_reference
-from .selectable import SelectBase, ScalarSelect, Join, FromClause, FromGrouping
+from .elements import (
+ BindParameter,
+ ColumnClause,
+ ColumnElement,
+ Null,
+ UnaryExpression,
+ literal_column,
+ Label,
+ _label_reference,
+ _textual_label_reference,
+)
+from .selectable import (
+ SelectBase,
+ ScalarSelect,
+ Join,
+ FromClause,
+ FromGrouping,
+)
from .schema import Column
join_condition = util.langhelpers.public_factory(
- Join._join_condition,
- ".sql.util.join_condition")
+ Join._join_condition, ".sql.util.join_condition"
+)
# names that are still being imported from the outside
from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate
@@ -88,8 +102,9 @@ def find_left_clause_that_matches_given(clauses, join_from):
for idx in liberal_idx:
f = clauses[idx]
for s in selectables:
- if set(surface_selectables(f)).\
- intersection(surface_selectables(s)):
+ if set(surface_selectables(f)).intersection(
+ surface_selectables(s)
+ ):
conservative_idx.append(idx)
break
if conservative_idx:
@@ -184,8 +199,9 @@ def visit_binary_product(fn, expr):
# we don't want to dig into correlated subqueries,
# those are just column elements by themselves
yield element
- elif element.__visit_name__ == 'binary' and \
- operators.is_comparison(element.operator):
+ elif element.__visit_name__ == "binary" and operators.is_comparison(
+ element.operator
+ ):
stack.insert(0, element)
for l in visit(element.left):
for r in visit(element.right):
@@ -199,38 +215,47 @@ def visit_binary_product(fn, expr):
for elem in element.get_children():
for e in visit(elem):
yield e
+
list(visit(expr))
-def find_tables(clause, check_columns=False,
- include_aliases=False, include_joins=False,
- include_selects=False, include_crud=False):
+def find_tables(
+ clause,
+ check_columns=False,
+ include_aliases=False,
+ include_joins=False,
+ include_selects=False,
+ include_crud=False,
+):
"""locate Table objects within the given expression."""
tables = []
_visitors = {}
if include_selects:
- _visitors['select'] = _visitors['compound_select'] = tables.append
+ _visitors["select"] = _visitors["compound_select"] = tables.append
if include_joins:
- _visitors['join'] = tables.append
+ _visitors["join"] = tables.append
if include_aliases:
- _visitors['alias'] = tables.append
+ _visitors["alias"] = tables.append
if include_crud:
- _visitors['insert'] = _visitors['update'] = \
- _visitors['delete'] = lambda ent: tables.append(ent.table)
+ _visitors["insert"] = _visitors["update"] = _visitors[
+ "delete"
+ ] = lambda ent: tables.append(ent.table)
if check_columns:
+
def visit_column(column):
tables.append(column.table)
- _visitors['column'] = visit_column
- _visitors['table'] = tables.append
+ _visitors["column"] = visit_column
- visitors.traverse(clause, {'column_collections': False}, _visitors)
+ _visitors["table"] = tables.append
+
+ visitors.traverse(clause, {"column_collections": False}, _visitors)
return tables
@@ -243,10 +268,9 @@ def unwrap_order_by(clause):
stack = deque([clause])
while stack:
t = stack.popleft()
- if isinstance(t, ColumnElement) and \
- (
- not isinstance(t, UnaryExpression) or
- not operators.is_ordering_modifier(t.modifier)
+ if isinstance(t, ColumnElement) and (
+ not isinstance(t, UnaryExpression)
+ or not operators.is_ordering_modifier(t.modifier)
):
if isinstance(t, _label_reference):
t = t.element
@@ -266,9 +290,7 @@ def unwrap_label_reference(element):
if isinstance(elem, (_label_reference, _textual_label_reference)):
return elem.element
- return visitors.replacement_traverse(
- element, {}, replace
- )
+ return visitors.replacement_traverse(element, {}, replace)
def expand_column_list_from_order_by(collist, order_by):
@@ -278,17 +300,16 @@ def expand_column_list_from_order_by(collist, order_by):
in the collist.
"""
- cols_already_present = set([
- col.element if col._order_by_label_element is not None
- else col for col in collist
- ])
+ cols_already_present = set(
+ [
+ col.element if col._order_by_label_element is not None else col
+ for col in collist
+ ]
+ )
return [
- col for col in
- chain(*[
- unwrap_order_by(o)
- for o in order_by
- ])
+ col
+ for col in chain(*[unwrap_order_by(o) for o in order_by])
if col not in cols_already_present
]
@@ -325,9 +346,9 @@ def surface_column_elements(clause, include_scalar_selects=True):
be addressable in the WHERE clause of a SELECT if this element were
in the columns clause."""
- filter_ = (FromGrouping, )
+ filter_ = (FromGrouping,)
if not include_scalar_selects:
- filter_ += (SelectBase, )
+ filter_ += (SelectBase,)
stack = deque([clause])
while stack:
@@ -343,9 +364,7 @@ def selectables_overlap(left, right):
"""Return True if left/right have some overlapping selectable"""
return bool(
- set(surface_selectables(left)).intersection(
- surface_selectables(right)
- )
+ set(surface_selectables(left)).intersection(surface_selectables(right))
)
@@ -366,7 +385,7 @@ def bind_values(clause):
def visit_bindparam(bind):
v.append(bind.effective_value)
- visitors.traverse(clause, {}, {'bindparam': visit_bindparam})
+ visitors.traverse(clause, {}, {"bindparam": visit_bindparam})
return v
@@ -383,7 +402,7 @@ class _repr_base(object):
_TUPLE = 1
_DICT = 2
- __slots__ = 'max_chars',
+ __slots__ = ("max_chars",)
def trunc(self, value):
rep = repr(value)
@@ -391,10 +410,12 @@ class _repr_base(object):
if lenrep > self.max_chars:
segment_length = self.max_chars // 2
rep = (
- rep[0:segment_length] +
- (" ... (%d characters truncated) ... "
- % (lenrep - self.max_chars)) +
- rep[-segment_length:]
+ rep[0:segment_length]
+ + (
+ " ... (%d characters truncated) ... "
+ % (lenrep - self.max_chars)
+ )
+ + rep[-segment_length:]
)
return rep
@@ -402,7 +423,7 @@ class _repr_base(object):
class _repr_row(_repr_base):
"""Provide a string view of a row."""
- __slots__ = 'row',
+ __slots__ = ("row",)
def __init__(self, row, max_chars=300):
self.row = row
@@ -412,7 +433,7 @@ class _repr_row(_repr_base):
trunc = self.trunc
return "(%s%s)" % (
", ".join(trunc(value) for value in self.row),
- "," if len(self.row) == 1 else ""
+ "," if len(self.row) == 1 else "",
)
@@ -424,7 +445,7 @@ class _repr_params(_repr_base):
"""
- __slots__ = 'params', 'batches',
+ __slots__ = "params", "batches"
def __init__(self, params, batches, max_chars=300):
self.params = params
@@ -435,11 +456,13 @@ class _repr_params(_repr_base):
if isinstance(self.params, list):
typ = self._LIST
ismulti = self.params and isinstance(
- self.params[0], (list, dict, tuple))
+ self.params[0], (list, dict, tuple)
+ )
elif isinstance(self.params, tuple):
typ = self._TUPLE
ismulti = self.params and isinstance(
- self.params[0], (list, dict, tuple))
+ self.params[0], (list, dict, tuple)
+ )
elif isinstance(self.params, dict):
typ = self._DICT
ismulti = False
@@ -448,11 +471,15 @@ class _repr_params(_repr_base):
if ismulti and len(self.params) > self.batches:
msg = " ... displaying %i of %i total bound parameter sets ... "
- return ' '.join((
- self._repr_multi(self.params[:self.batches - 2], typ)[0:-1],
- msg % (self.batches, len(self.params)),
- self._repr_multi(self.params[-2:], typ)[1:]
- ))
+ return " ".join(
+ (
+ self._repr_multi(self.params[: self.batches - 2], typ)[
+ 0:-1
+ ],
+ msg % (self.batches, len(self.params)),
+ self._repr_multi(self.params[-2:], typ)[1:],
+ )
+ )
elif ismulti:
return self._repr_multi(self.params, typ)
else:
@@ -467,12 +494,13 @@ class _repr_params(_repr_base):
elif isinstance(multi_params[0], dict):
elem_type = self._DICT
else:
- assert False, \
- "Unknown parameter type %s" % (type(multi_params[0]))
+ assert False, "Unknown parameter type %s" % (
+ type(multi_params[0])
+ )
elements = ", ".join(
- self._repr_params(params, elem_type)
- for params in multi_params)
+ self._repr_params(params, elem_type) for params in multi_params
+ )
else:
elements = ""
@@ -493,13 +521,10 @@ class _repr_params(_repr_base):
elif typ is self._TUPLE:
return "(%s%s)" % (
", ".join(trunc(value) for value in params),
- "," if len(params) == 1 else ""
-
+ "," if len(params) == 1 else "",
)
else:
- return "[%s]" % (
- ", ".join(trunc(value) for value in params)
- )
+ return "[%s]" % (", ".join(trunc(value) for value in params))
def adapt_criterion_to_null(crit, nulls):
@@ -509,20 +534,24 @@ def adapt_criterion_to_null(crit, nulls):
"""
def visit_binary(binary):
- if isinstance(binary.left, BindParameter) \
- and binary.left._identifying_key in nulls:
+ if (
+ isinstance(binary.left, BindParameter)
+ and binary.left._identifying_key in nulls
+ ):
# reverse order if the NULL is on the left side
binary.left = binary.right
binary.right = Null()
binary.operator = operators.is_
binary.negate = operators.isnot
- elif isinstance(binary.right, BindParameter) \
- and binary.right._identifying_key in nulls:
+ elif (
+ isinstance(binary.right, BindParameter)
+ and binary.right._identifying_key in nulls
+ ):
binary.right = Null()
binary.operator = operators.is_
binary.negate = operators.isnot
- return visitors.cloned_traverse(crit, {}, {'binary': visit_binary})
+ return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
def splice_joins(left, right, stop_on=None):
@@ -570,8 +599,8 @@ def reduce_columns(columns, *clauses, **kw):
in the selectable to just those that are not repeated.
"""
- ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False)
- only_synonyms = kw.pop('only_synonyms', False)
+ ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
+ only_synonyms = kw.pop("only_synonyms", False)
columns = util.ordered_column_set(columns)
@@ -597,39 +626,48 @@ def reduce_columns(columns, *clauses, **kw):
continue
else:
raise
- if fk_col.shares_lineage(c) and \
- (not only_synonyms or
- c.name == col.name):
+ if fk_col.shares_lineage(c) and (
+ not only_synonyms or c.name == col.name
+ ):
omit.add(col)
break
if clauses:
+
def visit_binary(binary):
if binary.operator == operators.eq:
cols = util.column_set(
- chain(*[c.proxy_set for c in columns.difference(omit)]))
+ chain(*[c.proxy_set for c in columns.difference(omit)])
+ )
if binary.left in cols and binary.right in cols:
for c in reversed(columns):
- if c.shares_lineage(binary.right) and \
- (not only_synonyms or
- c.name == binary.left.name):
+ if c.shares_lineage(binary.right) and (
+ not only_synonyms or c.name == binary.left.name
+ ):
omit.add(c)
break
+
for clause in clauses:
if clause is not None:
- visitors.traverse(clause, {}, {'binary': visit_binary})
+ visitors.traverse(clause, {}, {"binary": visit_binary})
return ColumnSet(columns.difference(omit))
-def criterion_as_pairs(expression, consider_as_foreign_keys=None,
- consider_as_referenced_keys=None, any_operator=False):
+def criterion_as_pairs(
+ expression,
+ consider_as_foreign_keys=None,
+ consider_as_referenced_keys=None,
+ any_operator=False,
+):
"""traverse an expression and locate binary criterion pairs."""
if consider_as_foreign_keys and consider_as_referenced_keys:
- raise exc.ArgumentError("Can only specify one of "
- "'consider_as_foreign_keys' or "
- "'consider_as_referenced_keys'")
+ raise exc.ArgumentError(
+ "Can only specify one of "
+ "'consider_as_foreign_keys' or "
+ "'consider_as_referenced_keys'"
+ )
def col_is(a, b):
# return a is b
@@ -638,37 +676,44 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
def visit_binary(binary):
if not any_operator and binary.operator is not operators.eq:
return
- if not isinstance(binary.left, ColumnElement) or \
- not isinstance(binary.right, ColumnElement):
+ if not isinstance(binary.left, ColumnElement) or not isinstance(
+ binary.right, ColumnElement
+ ):
return
if consider_as_foreign_keys:
- if binary.left in consider_as_foreign_keys and \
- (col_is(binary.right, binary.left) or
- binary.right not in consider_as_foreign_keys):
+ if binary.left in consider_as_foreign_keys and (
+ col_is(binary.right, binary.left)
+ or binary.right not in consider_as_foreign_keys
+ ):
pairs.append((binary.right, binary.left))
- elif binary.right in consider_as_foreign_keys and \
- (col_is(binary.left, binary.right) or
- binary.left not in consider_as_foreign_keys):
+ elif binary.right in consider_as_foreign_keys and (
+ col_is(binary.left, binary.right)
+ or binary.left not in consider_as_foreign_keys
+ ):
pairs.append((binary.left, binary.right))
elif consider_as_referenced_keys:
- if binary.left in consider_as_referenced_keys and \
- (col_is(binary.right, binary.left) or
- binary.right not in consider_as_referenced_keys):
+ if binary.left in consider_as_referenced_keys and (
+ col_is(binary.right, binary.left)
+ or binary.right not in consider_as_referenced_keys
+ ):
pairs.append((binary.left, binary.right))
- elif binary.right in consider_as_referenced_keys and \
- (col_is(binary.left, binary.right) or
- binary.left not in consider_as_referenced_keys):
+ elif binary.right in consider_as_referenced_keys and (
+ col_is(binary.left, binary.right)
+ or binary.left not in consider_as_referenced_keys
+ ):
pairs.append((binary.right, binary.left))
else:
- if isinstance(binary.left, Column) and \
- isinstance(binary.right, Column):
+ if isinstance(binary.left, Column) and isinstance(
+ binary.right, Column
+ ):
if binary.left.references(binary.right):
pairs.append((binary.right, binary.left))
elif binary.right.references(binary.left):
pairs.append((binary.left, binary.right))
+
pairs = []
- visitors.traverse(expression, {}, {'binary': visit_binary})
+ visitors.traverse(expression, {}, {"binary": visit_binary})
return pairs
@@ -699,28 +744,38 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
"""
- def __init__(self, selectable, equivalents=None,
- include_fn=None, exclude_fn=None,
- adapt_on_names=False, anonymize_labels=False):
+ def __init__(
+ self,
+ selectable,
+ equivalents=None,
+ include_fn=None,
+ exclude_fn=None,
+ adapt_on_names=False,
+ anonymize_labels=False,
+ ):
self.__traverse_options__ = {
- 'stop_on': [selectable],
- 'anonymize_labels': anonymize_labels}
+ "stop_on": [selectable],
+ "anonymize_labels": anonymize_labels,
+ }
self.selectable = selectable
self.include_fn = include_fn
self.exclude_fn = exclude_fn
self.equivalents = util.column_dict(equivalents or {})
self.adapt_on_names = adapt_on_names
- def _corresponding_column(self, col, require_embedded,
- _seen=util.EMPTY_SET):
+ def _corresponding_column(
+ self, col, require_embedded, _seen=util.EMPTY_SET
+ ):
newcol = self.selectable.corresponding_column(
- col,
- require_embedded=require_embedded)
+ col, require_embedded=require_embedded
+ )
if newcol is None and col in self.equivalents and col not in _seen:
for equiv in self.equivalents[col]:
newcol = self._corresponding_column(
- equiv, require_embedded=require_embedded,
- _seen=_seen.union([col]))
+ equiv,
+ require_embedded=require_embedded,
+ _seen=_seen.union([col]),
+ )
if newcol is not None:
return newcol
if self.adapt_on_names and newcol is None:
@@ -728,8 +783,9 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
return newcol
def replace(self, col):
- if isinstance(col, FromClause) and \
- self.selectable.is_derived_from(col):
+ if isinstance(col, FromClause) and self.selectable.is_derived_from(
+ col
+ ):
return self.selectable
elif not isinstance(col, ColumnElement):
return None
@@ -772,16 +828,27 @@ class ColumnAdapter(ClauseAdapter):
"""
- def __init__(self, selectable, equivalents=None,
- chain_to=None, adapt_required=False,
- include_fn=None, exclude_fn=None,
- adapt_on_names=False,
- allow_label_resolve=True,
- anonymize_labels=False):
- ClauseAdapter.__init__(self, selectable, equivalents,
- include_fn=include_fn, exclude_fn=exclude_fn,
- adapt_on_names=adapt_on_names,
- anonymize_labels=anonymize_labels)
+ def __init__(
+ self,
+ selectable,
+ equivalents=None,
+ chain_to=None,
+ adapt_required=False,
+ include_fn=None,
+ exclude_fn=None,
+ adapt_on_names=False,
+ allow_label_resolve=True,
+ anonymize_labels=False,
+ ):
+ ClauseAdapter.__init__(
+ self,
+ selectable,
+ equivalents,
+ include_fn=include_fn,
+ exclude_fn=exclude_fn,
+ adapt_on_names=adapt_on_names,
+ anonymize_labels=anonymize_labels,
+ )
if chain_to:
self.chain(chain_to)
@@ -800,9 +867,7 @@ class ColumnAdapter(ClauseAdapter):
def __getitem__(self, key):
if (
self.parent.include_fn and not self.parent.include_fn(key)
- ) or (
- self.parent.exclude_fn and self.parent.exclude_fn(key)
- ):
+ ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
if self.parent._wrap:
return self.parent._wrap.columns[key]
else:
@@ -843,7 +908,7 @@ class ColumnAdapter(ClauseAdapter):
def __getstate__(self):
d = self.__dict__.copy()
- del d['columns']
+ del d["columns"]
return d
def __setstate__(self, state):