diff options
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 144 |
1 files changed, 144 insertions, 0 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 9df0c932f..a84843c4b 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -11,6 +11,7 @@ import itertools +import operator import re from .visitors import ClauseVisitor @@ -51,6 +52,38 @@ def _generative(fn, *args, **kw): return self +def _clone(element, **kw): + return element._clone() + + +def _expand_cloned(elements): + """expand the given set of ClauseElements to be the set of all 'cloned' + predecessors. + + """ + return itertools.chain(*[x._cloned_set for x in elements]) + + +def _cloned_intersection(a, b): + """return the intersection of sets a and b, counting + any overlap between 'cloned' predecessors. + + The returned set is in terms of the entities present within 'a'. + + """ + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set( + elem for elem in a if all_overlap.intersection(elem._cloned_set) + ) + + +def _cloned_difference(a, b): + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set( + elem for elem in a if not all_overlap.intersection(elem._cloned_set) + ) + + class _DialectArgView(util.collections_abc.MutableMapping): """A dictionary view of dialect-level arguments in the form <dialectname>_<argument_name>. @@ -486,6 +519,97 @@ class ColumnCollection(util.OrderedProperties): def __str__(self): return repr([str(c) for c in self]) + def corresponding_column(self, column, require_embedded=False): + """Given a :class:`.ColumnElement`, return the exported + :class:`.ColumnElement` object from this :class:`.ColumnCollection` + which corresponds to that original :class:`.ColumnElement` via a common + ancestor column. + + :param column: the target :class:`.ColumnElement` to be matched + + :param require_embedded: only return corresponding columns for + the given :class:`.ColumnElement`, if the given + :class:`.ColumnElement` is actually present within a sub-element + of this :class:`.Selectable`. Normally the column will match if + it merely shares a common ancestor with one of the exported + columns of this :class:`.Selectable`. + + .. seealso:: + + :meth:`.Selectable.corresponding_column` - invokes this method + against the collection returned by + :attr:`.Selectable.exported_columns`. + + .. versionchanged:: 1.4 the implementation for ``corresponding_column`` + was moved onto the :class:`.ColumnCollection` itself. + + """ + + def embedded(expanded_proxy_set, target_set): + for t in target_set.difference(expanded_proxy_set): + if not set(_expand_cloned([t])).intersection( + expanded_proxy_set + ): + return False + return True + + # don't dig around if the column is locally present + if self.contains_column(column): + return column + col, intersect = None, None + target_set = column.proxy_set + cols = self._all_columns + for c in cols: + expanded_proxy_set = set(_expand_cloned(c.proxy_set)) + i = target_set.intersection(expanded_proxy_set) + if i and ( + not require_embedded + or embedded(expanded_proxy_set, target_set) + ): + if col is None: + + # no corresponding column yet, pick this one. + + col, intersect = c, i + elif len(i) > len(intersect): + + # 'c' has a larger field of correspondence than + # 'col'. i.e. selectable.c.a1_x->a1.c.x->table.c.x + # matches a1.c.x->table.c.x better than + # selectable.c.x->table.c.x does. + + col, intersect = c, i + elif i == intersect: + # they have the same field of correspondence. see + # which proxy_set has fewer columns in it, which + # indicates a closer relationship with the root + # column. Also take into account the "weight" + # attribute which CompoundSelect() uses to give + # higher precedence to columns based on vertical + # position in the compound statement, and discard + # columns that have no reference to the target + # column (also occurs with CompoundSelect) + + col_distance = util.reduce( + operator.add, + [ + sc._annotations.get("weight", 1) + for sc in col._uncached_proxy_set() + if sc.shares_lineage(column) + ], + ) + c_distance = util.reduce( + operator.add, + [ + sc._annotations.get("weight", 1) + for sc in c._uncached_proxy_set() + if sc.shares_lineage(column) + ], + ) + if c_distance < col_distance: + col, intersect = c, i + return col + def replace(self, column): """add the given column to this collection, removing unaliased versions of this column as well as existing columns with the @@ -619,6 +743,26 @@ class ColumnCollection(util.OrderedProperties): return ImmutableColumnCollection(self._data, self._all_columns) +class SeparateKeyColumnCollection(ColumnCollection): + """Column collection that maintains a string name separate from the + column itself""" + + def __init__(self, cols_plus_names=None): + super(ColumnCollection, self).__init__() + object.__setattr__(self, "_all_columns", []) + if cols_plus_names: + self.update(cols_plus_names) + + def replace(self, column): + raise NotImplementedError() + + def add(self, column): + raise NotImplementedError() + + def remove(self, column): + raise NotImplementedError() + + class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): def __init__(self, data, all_columns): util.ImmutableProperties.__init__(self, data) |