diff options
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 447 |
1 files changed, 316 insertions, 131 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a84843c4b..da384bdab 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -499,26 +499,209 @@ class SchemaVisitor(ClauseVisitor): __traverse_options__ = {"schema_visitor": True} -class ColumnCollection(util.OrderedProperties): - """An ordered dictionary that stores a list of ColumnElement - instances. +class ColumnCollection(object): + """Collection of :class:`.ColumnElement` instances, typically for + selectables. + + The :class:`.ColumnCollection` has both mapping- and sequence- like + behaviors. A :class:`.ColumnCollection` usually stores :class:`.Column` + objects, which are then accessible both via mapping style access as well + as attribute access style. The name for which a :class:`.Column` would + be present is normally that of the :paramref:`.Column.key` parameter, + however depending on the context, it may be stored under a special label + name:: + + >>> from sqlalchemy import Column, Integer + >>> from sqlalchemy.sql import ColumnCollection + >>> x, y = Column('x', Integer), Column('y', Integer) + >>> cc = ColumnCollection(columns=[x, y]) + >>> cc.x + Column('x', Integer(), table=None) + >>> cc.y + Column('y', Integer(), table=None) + >>> cc['x'] + Column('x', Integer(), table=None) + >>> cc['y'] + + :class`.ColumnCollection` also indexes the columns in order and allows + them to be accessible by their integer position:: + + >>> cc[0] + Column('x', Integer(), table=None) + >>> cc[1] + Column('y', Integer(), table=None) + + .. versionadded:: 1.4 :class:`.ColumnCollection` allows integer-based + index access to the collection. + + Iterating the collection yields the column expressions in order:: + + >>> list(cc) + [Column('x', Integer(), table=None), + Column('y', Integer(), table=None)] + + The base :class:`.ColumnCollection` object can store duplicates, which can + mean either two columns with the same key, in which case the column + returned by key access is **arbitrary**:: + + >>> x1, x2 = Column('x', Integer), Column('x', Integer) + >>> cc = ColumnCollection(columns=[x1, x2]) + >>> list(cc) + [Column('x', Integer(), table=None), + Column('x', Integer(), table=None)] + >>> cc['x'] is x1 + False + >>> cc['x'] is x2 + True + + Or it can also mean the same column multiple times. These cases are + supported as :class:`.ColumnCollection` is used to represent the columns in + a SELECT statement which may include duplicates. + + A special subclass :class:`.DedupeColumnCollection` exists which instead + maintains SQLAlchemy's older behavior of not allowing duplicates; this + collection is used for schema level objects like :class:`.Table` and + :class:`.PrimaryKeyConstraint` where this deduping is helpful. The + :class:`.DedupeColumnCollection` class also has additional mutation methods + as the schema constructs have more use cases that require removal and + replacement of columns. + + .. versionchanged:: 1.4 :class:`.ColumnCollection` now stores duplicate + column keys as well as the same column in multiple positions. The + :class:`.DedupeColumnCollection` class is added to maintain the + former behavior in those cases where deduplication as well as + additional replace/remove operations are needed. - Overrides the ``__eq__()`` method to produce SQL clauses between - sets of correlated columns. """ - __slots__ = "_all_columns" + __slots__ = "_collection", "_index", "_colset" - def __init__(self, *columns): - super(ColumnCollection, self).__init__() - object.__setattr__(self, "_all_columns", []) - for c in columns: - self.add(c) + def __init__(self, columns=None): + object.__setattr__(self, "_colset", set()) + object.__setattr__(self, "_index", {}) + object.__setattr__(self, "_collection", []) + if columns: + self._initial_populate(columns) + + def _initial_populate(self, iter_): + self._populate_separate_keys(iter_) + + @property + def _all_columns(self): + return [col for (k, col) in self._collection] + + def keys(self): + return [k for (k, col) in self._collection] + + def __len__(self): + return len(self._collection) + + def __iter__(self): + # turn to a list first to maintain over a course of changes + return iter([col for k, col in self._collection]) + + def __getitem__(self, key): + try: + return self._index[key] + except KeyError: + if isinstance(key, util.int_types): + raise IndexError(key) + else: + raise + + def __getattr__(self, key): + try: + return self._index[key] + except KeyError: + raise AttributeError(key) + + def __contains__(self, key): + if key not in self._index: + if not isinstance(key, util.string_types): + raise exc.ArgumentError( + "__contains__ requires a string argument" + ) + return False + else: + return True + + def compare(self, other): + for l, r in util.zip_longest(self, other): + if l is not r: + return False + else: + return True + + def __eq__(self, other): + return self.compare(other) + + def get(self, key, default=None): + if key in self._index: + return self._index[key] + else: + return default def __str__(self): return repr([str(c) for c in self]) + def __setitem__(self, key, value): + raise NotImplementedError() + + def __delitem__(self, key): + raise NotImplementedError() + + def __setattr__(self, key, obj): + raise NotImplementedError() + + def clear(self): + raise NotImplementedError() + + def remove(self, column): + raise NotImplementedError() + + def update(self, iter_): + raise NotImplementedError() + + __hash__ = None + + def _populate_separate_keys(self, iter_): + """populate from an iterator of (key, column)""" + cols = list(iter_) + self._collection[:] = cols + self._colset.update(c for k, c in self._collection) + self._index.update( + (idx, c) for idx, (k, c) in enumerate(self._collection) + ) + self._index.update({k: col for k, col in reversed(self._collection)}) + + def add(self, column, key=None): + if key is None: + key = column.key + + l = len(self._collection) + self._collection.append((key, column)) + self._colset.add(column) + self._index[l] = column + if key not in self._index: + self._index[key] = column + + def __getstate__(self): + return {"_collection": self._collection, "_index": self._index} + + def __setstate__(self, state): + object.__setattr__(self, "_index", state["_index"]) + object.__setattr__(self, "_collection", state["_collection"]) + object.__setattr__( + self, "_colset", {col for k, col in self._collection} + ) + + def contains_column(self, col): + return col in self._colset + + def as_immutable(self): + return ImmutableColumnCollection(self) + def corresponding_column(self, column, require_embedded=False): """Given a :class:`.ColumnElement`, return the exported :class:`.ColumnElement` object from this :class:`.ColumnCollection` @@ -554,11 +737,11 @@ class ColumnCollection(util.OrderedProperties): return True # don't dig around if the column is locally present - if self.contains_column(column): + if column in self._colset: return column col, intersect = None, None target_set = column.proxy_set - cols = self._all_columns + cols = [c for (k, c) in self._collection] for c in cols: expanded_proxy_set = set(_expand_cloned(c.proxy_set)) i = target_set.intersection(expanded_proxy_set) @@ -610,165 +793,167 @@ class ColumnCollection(util.OrderedProperties): col, intersect = c, i return col - def replace(self, column): - """add the given column to this collection, removing unaliased - versions of this column as well as existing columns with the - same key. - e.g.:: +class DedupeColumnCollection(ColumnCollection): + """A :class:`.ColumnCollection that maintains deduplicating behavior. - t = Table('sometable', metadata, Column('col1', Integer)) - t.columns.replace(Column('col1', Integer, key='columnone')) + This is useful by schema level objects such as :class:`.Table` and + :class:`.PrimaryKeyConstraint`. The collection includes more + sophisticated mutator methods as well to suit schema objects which + require mutable column collections. - will remove the original 'col1' from the collection, and add - the new column under the name 'columnname'. - - Used by schema.Column to override columns during table reflection. - - """ - remove_col = None - if column.name in self and column.key != column.name: - other = self[column.name] - if other.name == other.key: - remove_col = other - del self._data[other.key] - - if column.key in self._data: - remove_col = self._data[column.key] - - self._data[column.key] = column - if remove_col is not None: - self._all_columns[:] = [ - column if c is remove_col else c for c in self._all_columns - ] - else: - self._all_columns.append(column) + .. versionadded: 1.4 - def add(self, column): - """Add a column to this collection. + """ - The key attribute of the column will be used as the hash key - for this dictionary. + def add(self, column, key=None): + if key is not None and column.key != key: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + key = column.key - """ - if not column.key: + if key is None: raise exc.ArgumentError( "Can't add unnamed column to column collection" ) - self[column.key] = column - - def __delitem__(self, key): - raise NotImplementedError() - - def __setattr__(self, key, obj): - raise NotImplementedError() - def __setitem__(self, key, value): - if key in self: + if key in self._index: - # this warning is primarily to catch select() statements - # which have conflicting column names in their exported - # columns collection + existing = self._index[key] - existing = self[key] - - if existing is value: + if existing is column: return - if not existing.shares_lineage(value): - util.warn( - "Column %r on table %r being replaced by " - "%r, which has the same key. Consider " - "use_labels for select() statements." - % (key, getattr(existing, "table", None), value) - ) + self.replace(column) # pop out memoized proxy_set as this # operation may very well be occurring # in a _make_proxy operation - util.memoized_property.reset(value, "proxy_set") + util.memoized_property.reset(column, "proxy_set") + else: + l = len(self._collection) + self._collection.append((key, column)) + self._colset.add(column) + self._index[l] = column + self._index[key] = column + + def _populate_separate_keys(self, iter_): + """populate from an iterator of (key, column)""" + cols = list(iter_) - self._all_columns.append(value) - self._data[key] = value + replace_col = [] + for k, col in cols: + if col.key != k: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + if col.name in self._index and col.key != col.name: + replace_col.append(col) + elif col.key in self._index: + replace_col.append(col) + else: + self._index[k] = col + self._collection.append((k, col)) + self._colset.update(c for (k, c) in self._collection) + self._index.update( + (idx, c) for idx, (k, c) in enumerate(self._collection) + ) + for col in replace_col: + self.replace(col) - def clear(self): - raise NotImplementedError() + def extend(self, iter_): + self._populate_separate_keys((col.key, col) for col in iter_) def remove(self, column): - del self._data[column.key] - self._all_columns[:] = [ - c for c in self._all_columns if c is not column + if column not in self._colset: + raise ValueError( + "Can't remove column %r; column is not in this collection" + % column + ) + del self._index[column.key] + self._colset.remove(column) + self._collection[:] = [ + (k, c) for (k, c) in self._collection if c is not column ] - - def update(self, iter_): - cols = list(iter_) - all_col_set = set(self._all_columns) - self._all_columns.extend( - c for label, c in cols if c not in all_col_set + self._index.update( + {idx: col for idx, (k, col) in enumerate(self._collection)} ) - self._data.update((label, c) for label, c in cols) + # delete higher index + del self._index[len(self._collection)] - def extend(self, iter_): - cols = list(iter_) - all_col_set = set(self._all_columns) - self._all_columns.extend(c for c in cols if c not in all_col_set) - self._data.update((c.key, c) for c in cols) + def replace(self, column): + """add the given column to this collection, removing unaliased + versions of this column as well as existing columns with the + same key. - __hash__ = None + e.g.:: - def __eq__(self, other): - l = [] - for c in getattr(other, "_all_columns", other): - for local in self._all_columns: - if c.shares_lineage(local): - l.append(c == local) - return elements.and_(*l) + t = Table('sometable', metadata, Column('col1', Integer)) + t.columns.replace(Column('col1', Integer, key='columnone')) - def __contains__(self, other): - if not isinstance(other, util.string_types): - raise exc.ArgumentError("__contains__ requires a string argument") - return util.OrderedProperties.__contains__(self, other) + will remove the original 'col1' from the collection, and add + the new column under the name 'columnname'. - def __getstate__(self): - return {"_data": self._data, "_all_columns": self._all_columns} + Used by schema.Column to override columns during table reflection. - def __setstate__(self, state): - object.__setattr__(self, "_data", state["_data"]) - object.__setattr__(self, "_all_columns", state["_all_columns"]) + """ - def contains_column(self, col): - return col in set(self._all_columns) + remove_col = set() + # remove up to two columns based on matches of name as well as key + if column.name in self._index and column.key != column.name: + other = self._index[column.name] + if other.name == other.key: + remove_col.add(other) + + if column.key in self._index: + remove_col.add(self._index[column.key]) + + new_cols = [] + replaced = False + for k, col in self._collection: + if col in remove_col: + if not replaced: + replaced = True + new_cols.append((column.key, column)) + else: + new_cols.append((k, col)) - def as_immutable(self): - return ImmutableColumnCollection(self._data, self._all_columns) + if remove_col: + self._colset.difference_update(remove_col) + if not replaced: + new_cols.append((column.key, column)) -class SeparateKeyColumnCollection(ColumnCollection): - """Column collection that maintains a string name separate from the - column itself""" + self._colset.add(column) + self._collection[:] = new_cols - def __init__(self, cols_plus_names=None): - super(ColumnCollection, self).__init__() - object.__setattr__(self, "_all_columns", []) - if cols_plus_names: - self.update(cols_plus_names) + self._index.clear() + self._index.update( + {idx: col for idx, (k, col) in enumerate(self._collection)} + ) + self._index.update(self._collection) - def replace(self, column): - raise NotImplementedError() - def add(self, column): - raise NotImplementedError() +class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection): + __slots__ = ("_parent",) - def remove(self, column): - raise NotImplementedError() + def __init__(self, collection): + object.__setattr__(self, "_parent", collection) + object.__setattr__(self, "_colset", collection._colset) + object.__setattr__(self, "_index", collection._index) + object.__setattr__(self, "_collection", collection._collection) + def __getstate__(self): + return {"_parent": self._parent} -class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): - def __init__(self, data, all_columns): - util.ImmutableProperties.__init__(self, data) - object.__setattr__(self, "_all_columns", all_columns) + def __setstate__(self, state): + parent = state["_parent"] + self.__init__(parent) - extend = remove = util.ImmutableProperties._immutable + add = extend = remove = util.ImmutableContainer._immutable class ColumnSet(util.ordered_column_set): |