diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
---|---|---|
committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/sql/util.py | |
parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run will format code consistently, however in
some cases that are prevalent in SQLAlchemy code it produces
too-long lines. The too-long lines will be resolved in the
following commit that will resolve all remaining flake8 issues
including shadowed builtins, long lines, import order, unused
imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 327 |
1 files changed, 196 insertions, 131 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 12cfe09d1..4feaf9938 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -15,15 +15,29 @@ from . import operators, visitors from itertools import chain from collections import deque -from .elements import BindParameter, ColumnClause, ColumnElement, \ - Null, UnaryExpression, literal_column, Label, _label_reference, \ - _textual_label_reference -from .selectable import SelectBase, ScalarSelect, Join, FromClause, FromGrouping +from .elements import ( + BindParameter, + ColumnClause, + ColumnElement, + Null, + UnaryExpression, + literal_column, + Label, + _label_reference, + _textual_label_reference, +) +from .selectable import ( + SelectBase, + ScalarSelect, + Join, + FromClause, + FromGrouping, +) from .schema import Column join_condition = util.langhelpers.public_factory( - Join._join_condition, - ".sql.util.join_condition") + Join._join_condition, ".sql.util.join_condition" +) # names that are still being imported from the outside from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate @@ -88,8 +102,9 @@ def find_left_clause_that_matches_given(clauses, join_from): for idx in liberal_idx: f = clauses[idx] for s in selectables: - if set(surface_selectables(f)).\ - intersection(surface_selectables(s)): + if set(surface_selectables(f)).intersection( + surface_selectables(s) + ): conservative_idx.append(idx) break if conservative_idx: @@ -184,8 +199,9 @@ def visit_binary_product(fn, expr): # we don't want to dig into correlated subqueries, # those are just column elements by themselves yield element - elif element.__visit_name__ == 'binary' and \ - operators.is_comparison(element.operator): + elif element.__visit_name__ == "binary" and operators.is_comparison( + element.operator + ): stack.insert(0, element) for l in visit(element.left): for r in visit(element.right): @@ -199,38 +215,47 @@ def visit_binary_product(fn, expr): for elem in element.get_children(): for e in visit(elem): yield e + list(visit(expr)) -def find_tables(clause, check_columns=False, - include_aliases=False, include_joins=False, - include_selects=False, include_crud=False): +def find_tables( + clause, + check_columns=False, + include_aliases=False, + include_joins=False, + include_selects=False, + include_crud=False, +): """locate Table objects within the given expression.""" tables = [] _visitors = {} if include_selects: - _visitors['select'] = _visitors['compound_select'] = tables.append + _visitors["select"] = _visitors["compound_select"] = tables.append if include_joins: - _visitors['join'] = tables.append + _visitors["join"] = tables.append if include_aliases: - _visitors['alias'] = tables.append + _visitors["alias"] = tables.append if include_crud: - _visitors['insert'] = _visitors['update'] = \ - _visitors['delete'] = lambda ent: tables.append(ent.table) + _visitors["insert"] = _visitors["update"] = _visitors[ + "delete" + ] = lambda ent: tables.append(ent.table) if check_columns: + def visit_column(column): tables.append(column.table) - _visitors['column'] = visit_column - _visitors['table'] = tables.append + _visitors["column"] = visit_column - visitors.traverse(clause, {'column_collections': False}, _visitors) + _visitors["table"] = tables.append + + visitors.traverse(clause, {"column_collections": False}, _visitors) return tables @@ -243,10 +268,9 @@ def unwrap_order_by(clause): stack = deque([clause]) while stack: t = stack.popleft() - if isinstance(t, ColumnElement) and \ - ( - not isinstance(t, UnaryExpression) or - not operators.is_ordering_modifier(t.modifier) + if isinstance(t, ColumnElement) and ( + not isinstance(t, UnaryExpression) + or not operators.is_ordering_modifier(t.modifier) ): if isinstance(t, _label_reference): t = t.element @@ -266,9 +290,7 @@ def unwrap_label_reference(element): if isinstance(elem, (_label_reference, _textual_label_reference)): return elem.element - return visitors.replacement_traverse( - element, {}, replace - ) + return visitors.replacement_traverse(element, {}, replace) def expand_column_list_from_order_by(collist, order_by): @@ -278,17 +300,16 @@ def expand_column_list_from_order_by(collist, order_by): in the collist. """ - cols_already_present = set([ - col.element if col._order_by_label_element is not None - else col for col in collist - ]) + cols_already_present = set( + [ + col.element if col._order_by_label_element is not None else col + for col in collist + ] + ) return [ - col for col in - chain(*[ - unwrap_order_by(o) - for o in order_by - ]) + col + for col in chain(*[unwrap_order_by(o) for o in order_by]) if col not in cols_already_present ] @@ -325,9 +346,9 @@ def surface_column_elements(clause, include_scalar_selects=True): be addressable in the WHERE clause of a SELECT if this element were in the columns clause.""" - filter_ = (FromGrouping, ) + filter_ = (FromGrouping,) if not include_scalar_selects: - filter_ += (SelectBase, ) + filter_ += (SelectBase,) stack = deque([clause]) while stack: @@ -343,9 +364,7 @@ def selectables_overlap(left, right): """Return True if left/right have some overlapping selectable""" return bool( - set(surface_selectables(left)).intersection( - surface_selectables(right) - ) + set(surface_selectables(left)).intersection(surface_selectables(right)) ) @@ -366,7 +385,7 @@ def bind_values(clause): def visit_bindparam(bind): v.append(bind.effective_value) - visitors.traverse(clause, {}, {'bindparam': visit_bindparam}) + visitors.traverse(clause, {}, {"bindparam": visit_bindparam}) return v @@ -383,7 +402,7 @@ class _repr_base(object): _TUPLE = 1 _DICT = 2 - __slots__ = 'max_chars', + __slots__ = ("max_chars",) def trunc(self, value): rep = repr(value) @@ -391,10 +410,12 @@ class _repr_base(object): if lenrep > self.max_chars: segment_length = self.max_chars // 2 rep = ( - rep[0:segment_length] + - (" ... (%d characters truncated) ... " - % (lenrep - self.max_chars)) + - rep[-segment_length:] + rep[0:segment_length] + + ( + " ... (%d characters truncated) ... " + % (lenrep - self.max_chars) + ) + + rep[-segment_length:] ) return rep @@ -402,7 +423,7 @@ class _repr_base(object): class _repr_row(_repr_base): """Provide a string view of a row.""" - __slots__ = 'row', + __slots__ = ("row",) def __init__(self, row, max_chars=300): self.row = row @@ -412,7 +433,7 @@ class _repr_row(_repr_base): trunc = self.trunc return "(%s%s)" % ( ", ".join(trunc(value) for value in self.row), - "," if len(self.row) == 1 else "" + "," if len(self.row) == 1 else "", ) @@ -424,7 +445,7 @@ class _repr_params(_repr_base): """ - __slots__ = 'params', 'batches', + __slots__ = "params", "batches" def __init__(self, params, batches, max_chars=300): self.params = params @@ -435,11 +456,13 @@ class _repr_params(_repr_base): if isinstance(self.params, list): typ = self._LIST ismulti = self.params and isinstance( - self.params[0], (list, dict, tuple)) + self.params[0], (list, dict, tuple) + ) elif isinstance(self.params, tuple): typ = self._TUPLE ismulti = self.params and isinstance( - self.params[0], (list, dict, tuple)) + self.params[0], (list, dict, tuple) + ) elif isinstance(self.params, dict): typ = self._DICT ismulti = False @@ -448,11 +471,15 @@ class _repr_params(_repr_base): if ismulti and len(self.params) > self.batches: msg = " ... displaying %i of %i total bound parameter sets ... " - return ' '.join(( - self._repr_multi(self.params[:self.batches - 2], typ)[0:-1], - msg % (self.batches, len(self.params)), - self._repr_multi(self.params[-2:], typ)[1:] - )) + return " ".join( + ( + self._repr_multi(self.params[: self.batches - 2], typ)[ + 0:-1 + ], + msg % (self.batches, len(self.params)), + self._repr_multi(self.params[-2:], typ)[1:], + ) + ) elif ismulti: return self._repr_multi(self.params, typ) else: @@ -467,12 +494,13 @@ class _repr_params(_repr_base): elif isinstance(multi_params[0], dict): elem_type = self._DICT else: - assert False, \ - "Unknown parameter type %s" % (type(multi_params[0])) + assert False, "Unknown parameter type %s" % ( + type(multi_params[0]) + ) elements = ", ".join( - self._repr_params(params, elem_type) - for params in multi_params) + self._repr_params(params, elem_type) for params in multi_params + ) else: elements = "" @@ -493,13 +521,10 @@ class _repr_params(_repr_base): elif typ is self._TUPLE: return "(%s%s)" % ( ", ".join(trunc(value) for value in params), - "," if len(params) == 1 else "" - + "," if len(params) == 1 else "", ) else: - return "[%s]" % ( - ", ".join(trunc(value) for value in params) - ) + return "[%s]" % (", ".join(trunc(value) for value in params)) def adapt_criterion_to_null(crit, nulls): @@ -509,20 +534,24 @@ def adapt_criterion_to_null(crit, nulls): """ def visit_binary(binary): - if isinstance(binary.left, BindParameter) \ - and binary.left._identifying_key in nulls: + if ( + isinstance(binary.left, BindParameter) + and binary.left._identifying_key in nulls + ): # reverse order if the NULL is on the left side binary.left = binary.right binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot - elif isinstance(binary.right, BindParameter) \ - and binary.right._identifying_key in nulls: + elif ( + isinstance(binary.right, BindParameter) + and binary.right._identifying_key in nulls + ): binary.right = Null() binary.operator = operators.is_ binary.negate = operators.isnot - return visitors.cloned_traverse(crit, {}, {'binary': visit_binary}) + return visitors.cloned_traverse(crit, {}, {"binary": visit_binary}) def splice_joins(left, right, stop_on=None): @@ -570,8 +599,8 @@ def reduce_columns(columns, *clauses, **kw): in the selectable to just those that are not repeated. """ - ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False) - only_synonyms = kw.pop('only_synonyms', False) + ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False) + only_synonyms = kw.pop("only_synonyms", False) columns = util.ordered_column_set(columns) @@ -597,39 +626,48 @@ def reduce_columns(columns, *clauses, **kw): continue else: raise - if fk_col.shares_lineage(c) and \ - (not only_synonyms or - c.name == col.name): + if fk_col.shares_lineage(c) and ( + not only_synonyms or c.name == col.name + ): omit.add(col) break if clauses: + def visit_binary(binary): if binary.operator == operators.eq: cols = util.column_set( - chain(*[c.proxy_set for c in columns.difference(omit)])) + chain(*[c.proxy_set for c in columns.difference(omit)]) + ) if binary.left in cols and binary.right in cols: for c in reversed(columns): - if c.shares_lineage(binary.right) and \ - (not only_synonyms or - c.name == binary.left.name): + if c.shares_lineage(binary.right) and ( + not only_synonyms or c.name == binary.left.name + ): omit.add(c) break + for clause in clauses: if clause is not None: - visitors.traverse(clause, {}, {'binary': visit_binary}) + visitors.traverse(clause, {}, {"binary": visit_binary}) return ColumnSet(columns.difference(omit)) -def criterion_as_pairs(expression, consider_as_foreign_keys=None, - consider_as_referenced_keys=None, any_operator=False): +def criterion_as_pairs( + expression, + consider_as_foreign_keys=None, + consider_as_referenced_keys=None, + any_operator=False, +): """traverse an expression and locate binary criterion pairs.""" if consider_as_foreign_keys and consider_as_referenced_keys: - raise exc.ArgumentError("Can only specify one of " - "'consider_as_foreign_keys' or " - "'consider_as_referenced_keys'") + raise exc.ArgumentError( + "Can only specify one of " + "'consider_as_foreign_keys' or " + "'consider_as_referenced_keys'" + ) def col_is(a, b): # return a is b @@ -638,37 +676,44 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, def visit_binary(binary): if not any_operator and binary.operator is not operators.eq: return - if not isinstance(binary.left, ColumnElement) or \ - not isinstance(binary.right, ColumnElement): + if not isinstance(binary.left, ColumnElement) or not isinstance( + binary.right, ColumnElement + ): return if consider_as_foreign_keys: - if binary.left in consider_as_foreign_keys and \ - (col_is(binary.right, binary.left) or - binary.right not in consider_as_foreign_keys): + if binary.left in consider_as_foreign_keys and ( + col_is(binary.right, binary.left) + or binary.right not in consider_as_foreign_keys + ): pairs.append((binary.right, binary.left)) - elif binary.right in consider_as_foreign_keys and \ - (col_is(binary.left, binary.right) or - binary.left not in consider_as_foreign_keys): + elif binary.right in consider_as_foreign_keys and ( + col_is(binary.left, binary.right) + or binary.left not in consider_as_foreign_keys + ): pairs.append((binary.left, binary.right)) elif consider_as_referenced_keys: - if binary.left in consider_as_referenced_keys and \ - (col_is(binary.right, binary.left) or - binary.right not in consider_as_referenced_keys): + if binary.left in consider_as_referenced_keys and ( + col_is(binary.right, binary.left) + or binary.right not in consider_as_referenced_keys + ): pairs.append((binary.left, binary.right)) - elif binary.right in consider_as_referenced_keys and \ - (col_is(binary.left, binary.right) or - binary.left not in consider_as_referenced_keys): + elif binary.right in consider_as_referenced_keys and ( + col_is(binary.left, binary.right) + or binary.left not in consider_as_referenced_keys + ): pairs.append((binary.right, binary.left)) else: - if isinstance(binary.left, Column) and \ - isinstance(binary.right, Column): + if isinstance(binary.left, Column) and isinstance( + binary.right, Column + ): if binary.left.references(binary.right): pairs.append((binary.right, binary.left)) elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) + pairs = [] - visitors.traverse(expression, {}, {'binary': visit_binary}) + visitors.traverse(expression, {}, {"binary": visit_binary}) return pairs @@ -699,28 +744,38 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): """ - def __init__(self, selectable, equivalents=None, - include_fn=None, exclude_fn=None, - adapt_on_names=False, anonymize_labels=False): + def __init__( + self, + selectable, + equivalents=None, + include_fn=None, + exclude_fn=None, + adapt_on_names=False, + anonymize_labels=False, + ): self.__traverse_options__ = { - 'stop_on': [selectable], - 'anonymize_labels': anonymize_labels} + "stop_on": [selectable], + "anonymize_labels": anonymize_labels, + } self.selectable = selectable self.include_fn = include_fn self.exclude_fn = exclude_fn self.equivalents = util.column_dict(equivalents or {}) self.adapt_on_names = adapt_on_names - def _corresponding_column(self, col, require_embedded, - _seen=util.EMPTY_SET): + def _corresponding_column( + self, col, require_embedded, _seen=util.EMPTY_SET + ): newcol = self.selectable.corresponding_column( - col, - require_embedded=require_embedded) + col, require_embedded=require_embedded + ) if newcol is None and col in self.equivalents and col not in _seen: for equiv in self.equivalents[col]: newcol = self._corresponding_column( - equiv, require_embedded=require_embedded, - _seen=_seen.union([col])) + equiv, + require_embedded=require_embedded, + _seen=_seen.union([col]), + ) if newcol is not None: return newcol if self.adapt_on_names and newcol is None: @@ -728,8 +783,9 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): return newcol def replace(self, col): - if isinstance(col, FromClause) and \ - self.selectable.is_derived_from(col): + if isinstance(col, FromClause) and self.selectable.is_derived_from( + col + ): return self.selectable elif not isinstance(col, ColumnElement): return None @@ -772,16 +828,27 @@ class ColumnAdapter(ClauseAdapter): """ - def __init__(self, selectable, equivalents=None, - chain_to=None, adapt_required=False, - include_fn=None, exclude_fn=None, - adapt_on_names=False, - allow_label_resolve=True, - anonymize_labels=False): - ClauseAdapter.__init__(self, selectable, equivalents, - include_fn=include_fn, exclude_fn=exclude_fn, - adapt_on_names=adapt_on_names, - anonymize_labels=anonymize_labels) + def __init__( + self, + selectable, + equivalents=None, + chain_to=None, + adapt_required=False, + include_fn=None, + exclude_fn=None, + adapt_on_names=False, + allow_label_resolve=True, + anonymize_labels=False, + ): + ClauseAdapter.__init__( + self, + selectable, + equivalents, + include_fn=include_fn, + exclude_fn=exclude_fn, + adapt_on_names=adapt_on_names, + anonymize_labels=anonymize_labels, + ) if chain_to: self.chain(chain_to) @@ -800,9 +867,7 @@ class ColumnAdapter(ClauseAdapter): def __getitem__(self, key): if ( self.parent.include_fn and not self.parent.include_fn(key) - ) or ( - self.parent.exclude_fn and self.parent.exclude_fn(key) - ): + ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)): if self.parent._wrap: return self.parent._wrap.columns[key] else: @@ -843,7 +908,7 @@ class ColumnAdapter(ClauseAdapter): def __getstate__(self): d = self.__dict__.copy() - del d['columns'] + del d["columns"] return d def __setstate__(self, state): |