From aceefb508ccd0911f52ff0e50324b3fefeaa3f16 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 7 Jul 2019 11:12:31 -0400 Subject: Allow duplicate columns in from clauses and selectables The :func:`.select` construct and related constructs now allow for duplication of column labels and columns themselves in the columns clause, mirroring exactly how column expressions were passed in. This allows the tuples returned by an executed result to match what was SELECTed for in the first place, which is how the ORM :class:`.Query` works, so this establishes better cross-compatibility between the two constructs. Additionally, it allows column-positioning-sensitive structures such as UNIONs (i.e. :class:`.CompoundSelect`) to be more intuitively constructed in those cases where a particular column might appear in more than one place. To support this change, the :class:`.ColumnCollection` has been revised to support duplicate columns as well as to allow integer index access. Fixes: #4753 Change-Id: Ie09a8116f05c367995c1e43623c51e07971d3bf0 --- lib/sqlalchemy/sql/base.py | 447 ++++++++++++++++++++++++++++++++------------- 1 file changed, 316 insertions(+), 131 deletions(-) (limited to 'lib/sqlalchemy/sql/base.py') diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a84843c4b..da384bdab 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -499,26 +499,209 @@ class SchemaVisitor(ClauseVisitor): __traverse_options__ = {"schema_visitor": True} -class ColumnCollection(util.OrderedProperties): - """An ordered dictionary that stores a list of ColumnElement - instances. +class ColumnCollection(object): + """Collection of :class:`.ColumnElement` instances, typically for + selectables. + + The :class:`.ColumnCollection` has both mapping- and sequence- like + behaviors. A :class:`.ColumnCollection` usually stores :class:`.Column` + objects, which are then accessible both via mapping style access as well + as attribute access style. The name for which a :class:`.Column` would + be present is normally that of the :paramref:`.Column.key` parameter, + however depending on the context, it may be stored under a special label + name:: + + >>> from sqlalchemy import Column, Integer + >>> from sqlalchemy.sql import ColumnCollection + >>> x, y = Column('x', Integer), Column('y', Integer) + >>> cc = ColumnCollection(columns=[x, y]) + >>> cc.x + Column('x', Integer(), table=None) + >>> cc.y + Column('y', Integer(), table=None) + >>> cc['x'] + Column('x', Integer(), table=None) + >>> cc['y'] + + :class`.ColumnCollection` also indexes the columns in order and allows + them to be accessible by their integer position:: + + >>> cc[0] + Column('x', Integer(), table=None) + >>> cc[1] + Column('y', Integer(), table=None) + + .. versionadded:: 1.4 :class:`.ColumnCollection` allows integer-based + index access to the collection. + + Iterating the collection yields the column expressions in order:: + + >>> list(cc) + [Column('x', Integer(), table=None), + Column('y', Integer(), table=None)] + + The base :class:`.ColumnCollection` object can store duplicates, which can + mean either two columns with the same key, in which case the column + returned by key access is **arbitrary**:: + + >>> x1, x2 = Column('x', Integer), Column('x', Integer) + >>> cc = ColumnCollection(columns=[x1, x2]) + >>> list(cc) + [Column('x', Integer(), table=None), + Column('x', Integer(), table=None)] + >>> cc['x'] is x1 + False + >>> cc['x'] is x2 + True + + Or it can also mean the same column multiple times. These cases are + supported as :class:`.ColumnCollection` is used to represent the columns in + a SELECT statement which may include duplicates. + + A special subclass :class:`.DedupeColumnCollection` exists which instead + maintains SQLAlchemy's older behavior of not allowing duplicates; this + collection is used for schema level objects like :class:`.Table` and + :class:`.PrimaryKeyConstraint` where this deduping is helpful. The + :class:`.DedupeColumnCollection` class also has additional mutation methods + as the schema constructs have more use cases that require removal and + replacement of columns. + + .. versionchanged:: 1.4 :class:`.ColumnCollection` now stores duplicate + column keys as well as the same column in multiple positions. The + :class:`.DedupeColumnCollection` class is added to maintain the + former behavior in those cases where deduplication as well as + additional replace/remove operations are needed. - Overrides the ``__eq__()`` method to produce SQL clauses between - sets of correlated columns. """ - __slots__ = "_all_columns" + __slots__ = "_collection", "_index", "_colset" - def __init__(self, *columns): - super(ColumnCollection, self).__init__() - object.__setattr__(self, "_all_columns", []) - for c in columns: - self.add(c) + def __init__(self, columns=None): + object.__setattr__(self, "_colset", set()) + object.__setattr__(self, "_index", {}) + object.__setattr__(self, "_collection", []) + if columns: + self._initial_populate(columns) + + def _initial_populate(self, iter_): + self._populate_separate_keys(iter_) + + @property + def _all_columns(self): + return [col for (k, col) in self._collection] + + def keys(self): + return [k for (k, col) in self._collection] + + def __len__(self): + return len(self._collection) + + def __iter__(self): + # turn to a list first to maintain over a course of changes + return iter([col for k, col in self._collection]) + + def __getitem__(self, key): + try: + return self._index[key] + except KeyError: + if isinstance(key, util.int_types): + raise IndexError(key) + else: + raise + + def __getattr__(self, key): + try: + return self._index[key] + except KeyError: + raise AttributeError(key) + + def __contains__(self, key): + if key not in self._index: + if not isinstance(key, util.string_types): + raise exc.ArgumentError( + "__contains__ requires a string argument" + ) + return False + else: + return True + + def compare(self, other): + for l, r in util.zip_longest(self, other): + if l is not r: + return False + else: + return True + + def __eq__(self, other): + return self.compare(other) + + def get(self, key, default=None): + if key in self._index: + return self._index[key] + else: + return default def __str__(self): return repr([str(c) for c in self]) + def __setitem__(self, key, value): + raise NotImplementedError() + + def __delitem__(self, key): + raise NotImplementedError() + + def __setattr__(self, key, obj): + raise NotImplementedError() + + def clear(self): + raise NotImplementedError() + + def remove(self, column): + raise NotImplementedError() + + def update(self, iter_): + raise NotImplementedError() + + __hash__ = None + + def _populate_separate_keys(self, iter_): + """populate from an iterator of (key, column)""" + cols = list(iter_) + self._collection[:] = cols + self._colset.update(c for k, c in self._collection) + self._index.update( + (idx, c) for idx, (k, c) in enumerate(self._collection) + ) + self._index.update({k: col for k, col in reversed(self._collection)}) + + def add(self, column, key=None): + if key is None: + key = column.key + + l = len(self._collection) + self._collection.append((key, column)) + self._colset.add(column) + self._index[l] = column + if key not in self._index: + self._index[key] = column + + def __getstate__(self): + return {"_collection": self._collection, "_index": self._index} + + def __setstate__(self, state): + object.__setattr__(self, "_index", state["_index"]) + object.__setattr__(self, "_collection", state["_collection"]) + object.__setattr__( + self, "_colset", {col for k, col in self._collection} + ) + + def contains_column(self, col): + return col in self._colset + + def as_immutable(self): + return ImmutableColumnCollection(self) + def corresponding_column(self, column, require_embedded=False): """Given a :class:`.ColumnElement`, return the exported :class:`.ColumnElement` object from this :class:`.ColumnCollection` @@ -554,11 +737,11 @@ class ColumnCollection(util.OrderedProperties): return True # don't dig around if the column is locally present - if self.contains_column(column): + if column in self._colset: return column col, intersect = None, None target_set = column.proxy_set - cols = self._all_columns + cols = [c for (k, c) in self._collection] for c in cols: expanded_proxy_set = set(_expand_cloned(c.proxy_set)) i = target_set.intersection(expanded_proxy_set) @@ -610,165 +793,167 @@ class ColumnCollection(util.OrderedProperties): col, intersect = c, i return col - def replace(self, column): - """add the given column to this collection, removing unaliased - versions of this column as well as existing columns with the - same key. - e.g.:: +class DedupeColumnCollection(ColumnCollection): + """A :class:`.ColumnCollection that maintains deduplicating behavior. - t = Table('sometable', metadata, Column('col1', Integer)) - t.columns.replace(Column('col1', Integer, key='columnone')) + This is useful by schema level objects such as :class:`.Table` and + :class:`.PrimaryKeyConstraint`. The collection includes more + sophisticated mutator methods as well to suit schema objects which + require mutable column collections. - will remove the original 'col1' from the collection, and add - the new column under the name 'columnname'. - - Used by schema.Column to override columns during table reflection. - - """ - remove_col = None - if column.name in self and column.key != column.name: - other = self[column.name] - if other.name == other.key: - remove_col = other - del self._data[other.key] - - if column.key in self._data: - remove_col = self._data[column.key] - - self._data[column.key] = column - if remove_col is not None: - self._all_columns[:] = [ - column if c is remove_col else c for c in self._all_columns - ] - else: - self._all_columns.append(column) + .. versionadded: 1.4 - def add(self, column): - """Add a column to this collection. + """ - The key attribute of the column will be used as the hash key - for this dictionary. + def add(self, column, key=None): + if key is not None and column.key != key: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + key = column.key - """ - if not column.key: + if key is None: raise exc.ArgumentError( "Can't add unnamed column to column collection" ) - self[column.key] = column - - def __delitem__(self, key): - raise NotImplementedError() - - def __setattr__(self, key, obj): - raise NotImplementedError() - def __setitem__(self, key, value): - if key in self: + if key in self._index: - # this warning is primarily to catch select() statements - # which have conflicting column names in their exported - # columns collection + existing = self._index[key] - existing = self[key] - - if existing is value: + if existing is column: return - if not existing.shares_lineage(value): - util.warn( - "Column %r on table %r being replaced by " - "%r, which has the same key. Consider " - "use_labels for select() statements." - % (key, getattr(existing, "table", None), value) - ) + self.replace(column) # pop out memoized proxy_set as this # operation may very well be occurring # in a _make_proxy operation - util.memoized_property.reset(value, "proxy_set") + util.memoized_property.reset(column, "proxy_set") + else: + l = len(self._collection) + self._collection.append((key, column)) + self._colset.add(column) + self._index[l] = column + self._index[key] = column + + def _populate_separate_keys(self, iter_): + """populate from an iterator of (key, column)""" + cols = list(iter_) - self._all_columns.append(value) - self._data[key] = value + replace_col = [] + for k, col in cols: + if col.key != k: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + if col.name in self._index and col.key != col.name: + replace_col.append(col) + elif col.key in self._index: + replace_col.append(col) + else: + self._index[k] = col + self._collection.append((k, col)) + self._colset.update(c for (k, c) in self._collection) + self._index.update( + (idx, c) for idx, (k, c) in enumerate(self._collection) + ) + for col in replace_col: + self.replace(col) - def clear(self): - raise NotImplementedError() + def extend(self, iter_): + self._populate_separate_keys((col.key, col) for col in iter_) def remove(self, column): - del self._data[column.key] - self._all_columns[:] = [ - c for c in self._all_columns if c is not column + if column not in self._colset: + raise ValueError( + "Can't remove column %r; column is not in this collection" + % column + ) + del self._index[column.key] + self._colset.remove(column) + self._collection[:] = [ + (k, c) for (k, c) in self._collection if c is not column ] - - def update(self, iter_): - cols = list(iter_) - all_col_set = set(self._all_columns) - self._all_columns.extend( - c for label, c in cols if c not in all_col_set + self._index.update( + {idx: col for idx, (k, col) in enumerate(self._collection)} ) - self._data.update((label, c) for label, c in cols) + # delete higher index + del self._index[len(self._collection)] - def extend(self, iter_): - cols = list(iter_) - all_col_set = set(self._all_columns) - self._all_columns.extend(c for c in cols if c not in all_col_set) - self._data.update((c.key, c) for c in cols) + def replace(self, column): + """add the given column to this collection, removing unaliased + versions of this column as well as existing columns with the + same key. - __hash__ = None + e.g.:: - def __eq__(self, other): - l = [] - for c in getattr(other, "_all_columns", other): - for local in self._all_columns: - if c.shares_lineage(local): - l.append(c == local) - return elements.and_(*l) + t = Table('sometable', metadata, Column('col1', Integer)) + t.columns.replace(Column('col1', Integer, key='columnone')) - def __contains__(self, other): - if not isinstance(other, util.string_types): - raise exc.ArgumentError("__contains__ requires a string argument") - return util.OrderedProperties.__contains__(self, other) + will remove the original 'col1' from the collection, and add + the new column under the name 'columnname'. - def __getstate__(self): - return {"_data": self._data, "_all_columns": self._all_columns} + Used by schema.Column to override columns during table reflection. - def __setstate__(self, state): - object.__setattr__(self, "_data", state["_data"]) - object.__setattr__(self, "_all_columns", state["_all_columns"]) + """ - def contains_column(self, col): - return col in set(self._all_columns) + remove_col = set() + # remove up to two columns based on matches of name as well as key + if column.name in self._index and column.key != column.name: + other = self._index[column.name] + if other.name == other.key: + remove_col.add(other) + + if column.key in self._index: + remove_col.add(self._index[column.key]) + + new_cols = [] + replaced = False + for k, col in self._collection: + if col in remove_col: + if not replaced: + replaced = True + new_cols.append((column.key, column)) + else: + new_cols.append((k, col)) - def as_immutable(self): - return ImmutableColumnCollection(self._data, self._all_columns) + if remove_col: + self._colset.difference_update(remove_col) + if not replaced: + new_cols.append((column.key, column)) -class SeparateKeyColumnCollection(ColumnCollection): - """Column collection that maintains a string name separate from the - column itself""" + self._colset.add(column) + self._collection[:] = new_cols - def __init__(self, cols_plus_names=None): - super(ColumnCollection, self).__init__() - object.__setattr__(self, "_all_columns", []) - if cols_plus_names: - self.update(cols_plus_names) + self._index.clear() + self._index.update( + {idx: col for idx, (k, col) in enumerate(self._collection)} + ) + self._index.update(self._collection) - def replace(self, column): - raise NotImplementedError() - def add(self, column): - raise NotImplementedError() +class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection): + __slots__ = ("_parent",) - def remove(self, column): - raise NotImplementedError() + def __init__(self, collection): + object.__setattr__(self, "_parent", collection) + object.__setattr__(self, "_colset", collection._colset) + object.__setattr__(self, "_index", collection._index) + object.__setattr__(self, "_collection", collection._collection) + def __getstate__(self): + return {"_parent": self._parent} -class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): - def __init__(self, data, all_columns): - util.ImmutableProperties.__init__(self, data) - object.__setattr__(self, "_all_columns", all_columns) + def __setstate__(self, state): + parent = state["_parent"] + self.__init__(parent) - extend = remove = util.ImmutableProperties._immutable + add = extend = remove = util.ImmutableContainer._immutable class ColumnSet(util.ordered_column_set): -- cgit v1.2.1