diff options
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 327 |
1 files changed, 196 insertions, 131 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 12cfe09d1..4feaf9938 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -15,15 +15,29 @@ from . import operators, visitors from itertools import chain from collections import deque -from .elements import BindParameter, ColumnClause, ColumnElement, \ - Null, UnaryExpression, literal_column, Label, _label_reference, \ - _textual_label_reference -from .selectable import SelectBase, ScalarSelect, Join, FromClause, FromGrouping +from .elements import ( + BindParameter, + ColumnClause, + ColumnElement, + Null, + UnaryExpression, + literal_column, + Label, + _label_reference, + _textual_label_reference, +) +from .selectable import ( + SelectBase, + ScalarSelect, + Join, + FromClause, + FromGrouping, +) from .schema import Column join_condition = util.langhelpers.public_factory( - Join._join_condition, - ".sql.util.join_condition") + Join._join_condition, ".sql.util.join_condition" +) # names that are still being imported from the outside from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate @@ -88,8 +102,9 @@ def find_left_clause_that_matches_given(clauses, join_from): for idx in liberal_idx: f = clauses[idx] for s in selectables: - if set(surface_selectables(f)).\ - intersection(surface_selectables(s)): + if set(surface_selectables(f)).intersection( + surface_selectables(s) + ): conservative_idx.append(idx) break if conservative_idx: @@ -184,8 +199,9 @@ def visit_binary_product(fn, expr): # we don't want to dig into correlated subqueries, # those are just column elements by themselves yield element - elif element.__visit_name__ == 'binary' and \ - operators.is_comparison(element.operator): + elif element.__visit_name__ == "binary" and operators.is_comparison( + element.operator + ): stack.insert(0, element) for l in visit(element.left): for r in visit(element.right): @@ -199,38 +215,47 @@ def visit_binary_product(fn, expr): for elem in element.get_children(): for e in visit(elem): yield e + list(visit(expr)) -def find_tables(clause, check_columns=False, - include_aliases=False, include_joins=False, - include_selects=False, include_crud=False): +def find_tables( + clause, + check_columns=False, + include_aliases=False, + include_joins=False, + include_selects=False, + include_crud=False, +): """locate Table objects within the given expression.""" tables = [] _visitors = {} if include_selects: - _visitors['select'] = _visitors['compound_select'] = tables.append + _visitors["select"] = _visitors["compound_select"] = tables.append if include_joins: - _visitors['join'] = tables.append + _visitors["join"] = tables.append if include_aliases: - _visitors['alias'] = tables.append + _visitors["alias"] = tables.append if include_crud: - _visitors['insert'] = _visitors['update'] = \ - _visitors['delete'] = lambda ent: tables.append(ent.table) + _visitors["insert"] = _visitors["update"] = _visitors[ + "delete" + ] = lambda ent: tables.append(ent.table) if check_columns: + def visit_column(column): tables.append(column.table) - _visitors['column'] = visit_column - _visitors['table'] = tables.append + _visitors["column"] = visit_column - visitors.traverse(clause, {'column_collections': False}, _visitors) + _visitors["table"] = tables.append + + visitors.traverse(clause, {"column_collections": False}, _visitors) return tables @@ -243,10 +268,9 @@ def unwrap_order_by(clause): stack = deque([clause]) while stack: t = stack.popleft() - if isinstance(t, ColumnElement) and \ - ( - not isinstance(t, UnaryExpression) or - not operators.is_ordering_modifier(t.modifier) + if isinstance(t, ColumnElement) and ( + not isinstance(t, UnaryExpression) + or not operators.is_ordering_modifier(t.modifier) ): if isinstance(t, _label_reference): t = t.element @@ -266,9 +290,7 @@ def unwrap_label_reference(element): if isinstance(elem, (_label_reference, _textual_label_reference)): return elem.element - return visitors.replacement_traverse( - element, {}, replace - ) + return visitors.replacement_traverse(element, {}, replace) def expand_column_list_from_order_by(collist, order_by): @@ -278,17 +300,16 @@ def expand_column_list_from_order_by(collist, order_by): in the collist. """ - cols_already_present = set([ - col.element if col._order_by_label_element is not None - else col for col in collist - ]) + cols_already_present = set( + [ + col.element if col._order_by_label_element is not None else col + for col in collist + ] + ) return [ - col for col in - chain(*[ - unwrap_order_by(o) - for o in order_by - ]) + col + for col in chain(*[unwrap_order_by(o) for o in order_by]) if col not in cols_already_present ] @@ -325,9 +346,9 @@ def surface_column_elements(clause, include_scalar_selects=True): be addressable in the WHERE clause of a SELECT if this element were in the columns clause.""" - filter_ = (FromGrouping, ) + filter_ = (FromGrouping,) if not include_scalar_selects: - filter_ += (SelectBase, ) + filter_ += (SelectBase,) stack = deque([clause]) while stack: @@ -343,9 +364,7 @@ def selectables_overlap(left, right): """Return True if left/right have some overlapping selectable""" return bool( - set(surface_selectables(left)).intersection( - surface_selectables(right) - ) + set(surface_selectables(left)).intersection(surface_selectables(right)) ) @@ -366,7 +385,7 @@ def bind_values(clause): def visit_bindparam(bind): v.append(bind.effective_value) - visitors.traverse(clause, {}, {'bindparam': visit_bindparam}) + visitors.traverse(clause, {}, {"bindparam": visit_bindparam}) return v @@ -383,7 +402,7 @@ class _repr_base(object): _TUPLE = 1 _DICT = 2 - __slots__ = 'max_chars', + __slots__ = ("max_chars",) def trunc(self, value): rep = repr(value) @@ -391,10 +410,12 @@ class _repr_base(object): if lenrep > self.max_chars: segment_length = self.max_chars // 2 rep = ( - rep[0:segment_length] + - (" ... (%d characters truncated) ... " - % (lenrep - self.max_chars)) + - rep[-segment_length:] + rep[0:segment_length] + + ( + " ... (%d characters truncated) ... " + % (lenrep - self.max_chars) + ) + + rep[-segment_length:] ) return rep @@ -402,7 +423,7 @@ class _repr_base(object): class _repr_row(_repr_base): """Provide a string view of a row.""" - __slots__ = 'row', + __slots__ = ("row",) def __init__(self, row, max_chars=300): self.row = row @@ -412,7 +433,7 @@ class _repr_row(_repr_base): trunc = self.trunc return "(%s%s)" % ( ", ".join(trunc(value) for value in self.row), - "," if len(self.row) == 1 else "" + "," if len(self.row) == 1 else "", ) @@ -424,7 +445,7 @@ class _repr_params(_repr_base): """ - __slots__ = 'params', 'batches', + __slots__ = "params", "batches" def __init__(self, params, batches, max_chars=300): self.params = params @@ -435,11 +456,13 @@ class _repr_params(_repr_base): if isinstance(self.params, list): typ = self._LIST ismulti = self.params and isinstance( - self.params[0], (list, dict, tuple)) + self.params[0], (list, dict, tuple) + ) elif isinstance(self.params, tuple): typ = self._TUPLE ismulti = self.params and isinstance( - self.params[0], (list, dict, tuple)) + self.params[0], (list, dict, tuple) + ) elif isinstance(self.params, dict): typ = self._DICT ismulti = False @@ -448,11 +471,15 @@ class _repr_params(_repr_base): if ismulti and len(self.params) > self.batches: msg = " ... displaying %i of %i total bound parameter sets ... " - return ' '.join(( - self._repr_multi(self.params[:self.batches - 2], typ)[0:-1], - msg % (self.batches, len(self.params)), - self._repr_multi(self.params[-2:], typ)[1:] - )) + return " ".join( + ( + self._repr_multi(self.params[: self.batches - 2], typ)[ + 0:-1 + ], + msg % (self.batches, len(self.params)), + self._repr_multi(self.params[-2:], typ)[1:], + ) + ) elif ismulti: return self._repr_multi(self.params, typ) else: @@ -467,12 +494,13 @@ class _repr_params(_repr_base): elif isinstance(multi_params[0], dict): elem_type = self._DICT else: - assert False, \ - "Unknown parameter type %s" % (type(multi_params[0])) + assert False, "Unknown parameter type %s" % ( + type(multi_params[0]) + ) elements = ", ".join( - self._repr_params(params, elem_type) - for params in multi_params) + self._repr_params(params, elem_type) for params in multi_params + ) else: elements = "" @@ -493,13 +521,10 @@ class _repr_params(_repr_base): elif typ is self._TUPLE: return "(%s%s)" % ( ", ".join(trunc(value) for value in params), - "," if len(params) == 1 else "" - + "," if len(params) == 1 else "", ) else: - return "[%s]" % ( - ", ".join(trunc(value) for value in params) - ) + return "[%s]" % (", ".join(trunc(value) for value in params)) def adapt_criterion_to_null(crit, nulls): @@ -509,20 +534,24 @@ def adapt_criterion_to_null(crit, nulls): """ def visit_binary(binary): - if isinstance(binary.left, BindParameter) \ - and binary.left._identifying_key in nulls: + if ( + isinstance(binary.left, BindParameter) + and binary.left._identifying_key in nulls + ): # reverse order if the NULL is on the left side binary.left = binary.right binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot - elif isinstance(binary.right, BindParameter) \ - and binary.right._identifying_key in nulls: + elif ( + isinstance(binary.right, BindParameter) + and binary.right._identifying_key in nulls + ): binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot - return visitors.cloned_traverse(crit, {}, {'binary': visit_binary}) + return visitors.cloned_traverse(crit, {}, {"binary": visit_binary}) def splice_joins(left, right, stop_on=None): @@ -570,8 +599,8 @@ def reduce_columns(columns, *clauses, **kw): in the selectable to just those that are not repeated. """ - ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False) - only_synonyms = kw.pop('only_synonyms', False) + ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False) + only_synonyms = kw.pop("only_synonyms", False) columns = util.ordered_column_set(columns) @@ -597,39 +626,48 @@ def reduce_columns(columns, *clauses, **kw): continue else: raise - if fk_col.shares_lineage(c) and \ - (not only_synonyms or - c.name == col.name): + if fk_col.shares_lineage(c) and ( + not only_synonyms or c.name == col.name + ): omit.add(col) break if clauses: + def visit_binary(binary): if binary.operator == operators.eq: cols = util.column_set( - chain(*[c.proxy_set for c in columns.difference(omit)])) + chain(*[c.proxy_set for c in columns.difference(omit)]) + ) if binary.left in cols and binary.right in cols: for c in reversed(columns): - if c.shares_lineage(binary.right) and \ - (not only_synonyms or - c.name == binary.left.name): + if c.shares_lineage(binary.right) and ( + not only_synonyms or c.name == binary.left.name + ): omit.add(c) break + for clause in clauses: if clause is not None: - visitors.traverse(clause, {}, {'binary': visit_binary}) + visitors.traverse(clause, {}, {"binary": visit_binary}) return ColumnSet(columns.difference(omit)) -def criterion_as_pairs(expression, consider_as_foreign_keys=None, - consider_as_referenced_keys=None, any_operator=False): +def criterion_as_pairs( + expression, + consider_as_foreign_keys=None, + consider_as_referenced_keys=None, + any_operator=False, +): """traverse an expression and locate binary criterion pairs.""" if consider_as_foreign_keys and consider_as_referenced_keys: - raise exc.ArgumentError("Can only specify one of " - "'consider_as_foreign_keys' or " - "'consider_as_referenced_keys'") + raise exc.ArgumentError( + "Can only specify one of " + "'consider_as_foreign_keys' or " + "'consider_as_referenced_keys'" + ) def col_is(a, b): # return a is b @@ -638,37 +676,44 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, def visit_binary(binary): if not any_operator and binary.operator is not operators.eq: return - if not isinstance(binary.left, ColumnElement) or \ - not isinstance(binary.right, ColumnElement): + if not isinstance(binary.left, ColumnElement) or not isinstance( + binary.right, ColumnElement + ): return if consider_as_foreign_keys: - if binary.left in consider_as_foreign_keys and \ - (col_is(binary.right, binary.left) or - binary.right not in consider_as_foreign_keys): + if binary.left in consider_as_foreign_keys and ( + col_is(binary.right, binary.left) + or binary.right not in consider_as_foreign_keys + ): pairs.append((binary.right, binary.left)) - elif binary.right in consider_as_foreign_keys and \ - (col_is(binary.left, binary.right) or - binary.left not in consider_as_foreign_keys): + elif binary.right in consider_as_foreign_keys and ( + col_is(binary.left, binary.right) + or binary.left not in consider_as_foreign_keys + ): pairs.append((binary.left, binary.right)) elif consider_as_referenced_keys: - if binary.left in consider_as_referenced_keys and \ - (col_is(binary.right, binary.left) or - binary.right not in consider_as_referenced_keys): + if binary.left in consider_as_referenced_keys and ( + col_is(binary.right, binary.left) + or binary.right not in consider_as_referenced_keys + ): pairs.append((binary.left, binary.right)) - elif binary.right in consider_as_referenced_keys and \ - (col_is(binary.left, binary.right) or - binary.left not in consider_as_referenced_keys): + elif binary.right in consider_as_referenced_keys and ( + col_is(binary.left, binary.right) + or binary.left not in consider_as_referenced_keys + ): pairs.append((binary.right, binary.left)) else: - if isinstance(binary.left, Column) and \ - isinstance(binary.right, Column): + if isinstance(binary.left, Column) and isinstance( + binary.right, Column + ): if binary.left.references(binary.right): pairs.append((binary.right, binary.left)) elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) + pairs = [] - visitors.traverse(expression, {}, {'binary': visit_binary}) + visitors.traverse(expression, {}, {"binary": visit_binary}) return pairs @@ -699,28 +744,38 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): """ - def __init__(self, selectable, equivalents=None, - include_fn=None, exclude_fn=None, - adapt_on_names=False, anonymize_labels=False): + def __init__( + self, + selectable, + equivalents=None, + include_fn=None, + exclude_fn=None, + adapt_on_names=False, + anonymize_labels=False, + ): self.__traverse_options__ = { - 'stop_on': [selectable], - 'anonymize_labels': anonymize_labels} + "stop_on": [selectable], + "anonymize_labels": anonymize_labels, + } self.selectable = selectable self.include_fn = include_fn self.exclude_fn = exclude_fn self.equivalents = util.column_dict(equivalents or {}) self.adapt_on_names = adapt_on_names - def _corresponding_column(self, col, require_embedded, - _seen=util.EMPTY_SET): + def _corresponding_column( + self, col, require_embedded, _seen=util.EMPTY_SET + ): newcol = self.selectable.corresponding_column( - col, - require_embedded=require_embedded) + col, require_embedded=require_embedded + ) if newcol is None and col in self.equivalents and col not in _seen: for equiv in self.equivalents[col]: newcol = self._corresponding_column( - equiv, require_embedded=require_embedded, - _seen=_seen.union([col])) + equiv, + require_embedded=require_embedded, + _seen=_seen.union([col]), + ) if newcol is not None: return newcol if self.adapt_on_names and newcol is None: @@ -728,8 +783,9 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): return newcol def replace(self, col): - if isinstance(col, FromClause) and \ - self.selectable.is_derived_from(col): + if isinstance(col, FromClause) and self.selectable.is_derived_from( + col + ): return self.selectable elif not isinstance(col, ColumnElement): return None @@ -772,16 +828,27 @@ class ColumnAdapter(ClauseAdapter): """ - def __init__(self, selectable, equivalents=None, - chain_to=None, adapt_required=False, - include_fn=None, exclude_fn=None, - adapt_on_names=False, - allow_label_resolve=True, - anonymize_labels=False): - ClauseAdapter.__init__(self, selectable, equivalents, - include_fn=include_fn, exclude_fn=exclude_fn, - adapt_on_names=adapt_on_names, - anonymize_labels=anonymize_labels) + def __init__( + self, + selectable, + equivalents=None, + chain_to=None, + adapt_required=False, + include_fn=None, + exclude_fn=None, + adapt_on_names=False, + allow_label_resolve=True, + anonymize_labels=False, + ): + ClauseAdapter.__init__( + self, + selectable, + equivalents, + include_fn=include_fn, + exclude_fn=exclude_fn, + adapt_on_names=adapt_on_names, + anonymize_labels=anonymize_labels, + ) if chain_to: self.chain(chain_to) @@ -800,9 +867,7 @@ class ColumnAdapter(ClauseAdapter): def __getitem__(self, key): if ( self.parent.include_fn and not self.parent.include_fn(key) - ) or ( - self.parent.exclude_fn and self.parent.exclude_fn(key) - ): + ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)): if self.parent._wrap: return self.parent._wrap.columns[key] else: @@ -843,7 +908,7 @@ class ColumnAdapter(ClauseAdapter): def __getstate__(self): d = self.__dict__.copy() - del d['columns'] + del d["columns"] return d def __setstate__(self, state): |