summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r--lib/sqlalchemy/sql/base.py447
1 files changed, 316 insertions, 131 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index a84843c4b..da384bdab 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -499,26 +499,209 @@ class SchemaVisitor(ClauseVisitor):
__traverse_options__ = {"schema_visitor": True}
-class ColumnCollection(util.OrderedProperties):
- """An ordered dictionary that stores a list of ColumnElement
- instances.
+class ColumnCollection(object):
+ """Collection of :class:`.ColumnElement` instances, typically for
+ selectables.
+
+ The :class:`.ColumnCollection` has both mapping- and sequence- like
+ behaviors. A :class:`.ColumnCollection` usually stores :class:`.Column`
+ objects, which are then accessible both via mapping style access as well
+ as attribute access style. The name for which a :class:`.Column` would
+ be present is normally that of the :paramref:`.Column.key` parameter,
+ however depending on the context, it may be stored under a special label
+ name::
+
+ >>> from sqlalchemy import Column, Integer
+ >>> from sqlalchemy.sql import ColumnCollection
+ >>> x, y = Column('x', Integer), Column('y', Integer)
+ >>> cc = ColumnCollection(columns=[x, y])
+ >>> cc.x
+ Column('x', Integer(), table=None)
+ >>> cc.y
+ Column('y', Integer(), table=None)
+ >>> cc['x']
+ Column('x', Integer(), table=None)
+ >>> cc['y']
+
+ :class`.ColumnCollection` also indexes the columns in order and allows
+ them to be accessible by their integer position::
+
+ >>> cc[0]
+ Column('x', Integer(), table=None)
+ >>> cc[1]
+ Column('y', Integer(), table=None)
+
+ .. versionadded:: 1.4 :class:`.ColumnCollection` allows integer-based
+ index access to the collection.
+
+ Iterating the collection yields the column expressions in order::
+
+ >>> list(cc)
+ [Column('x', Integer(), table=None),
+ Column('y', Integer(), table=None)]
+
+ The base :class:`.ColumnCollection` object can store duplicates, which can
+ mean either two columns with the same key, in which case the column
+ returned by key access is **arbitrary**::
+
+ >>> x1, x2 = Column('x', Integer), Column('x', Integer)
+ >>> cc = ColumnCollection(columns=[x1, x2])
+ >>> list(cc)
+ [Column('x', Integer(), table=None),
+ Column('x', Integer(), table=None)]
+ >>> cc['x'] is x1
+ False
+ >>> cc['x'] is x2
+ True
+
+ Or it can also mean the same column multiple times. These cases are
+ supported as :class:`.ColumnCollection` is used to represent the columns in
+ a SELECT statement which may include duplicates.
+
+ A special subclass :class:`.DedupeColumnCollection` exists which instead
+ maintains SQLAlchemy's older behavior of not allowing duplicates; this
+ collection is used for schema level objects like :class:`.Table` and
+ :class:`.PrimaryKeyConstraint` where this deduping is helpful. The
+ :class:`.DedupeColumnCollection` class also has additional mutation methods
+ as the schema constructs have more use cases that require removal and
+ replacement of columns.
+
+ .. versionchanged:: 1.4 :class:`.ColumnCollection` now stores duplicate
+ column keys as well as the same column in multiple positions. The
+ :class:`.DedupeColumnCollection` class is added to maintain the
+ former behavior in those cases where deduplication as well as
+ additional replace/remove operations are needed.
- Overrides the ``__eq__()`` method to produce SQL clauses between
- sets of correlated columns.
"""
- __slots__ = "_all_columns"
+ __slots__ = "_collection", "_index", "_colset"
- def __init__(self, *columns):
- super(ColumnCollection, self).__init__()
- object.__setattr__(self, "_all_columns", [])
- for c in columns:
- self.add(c)
+ def __init__(self, columns=None):
+ object.__setattr__(self, "_colset", set())
+ object.__setattr__(self, "_index", {})
+ object.__setattr__(self, "_collection", [])
+ if columns:
+ self._initial_populate(columns)
+
+ def _initial_populate(self, iter_):
+ self._populate_separate_keys(iter_)
+
+ @property
+ def _all_columns(self):
+ return [col for (k, col) in self._collection]
+
+ def keys(self):
+ return [k for (k, col) in self._collection]
+
+ def __len__(self):
+ return len(self._collection)
+
+ def __iter__(self):
+ # turn to a list first to maintain over a course of changes
+ return iter([col for k, col in self._collection])
+
+ def __getitem__(self, key):
+ try:
+ return self._index[key]
+ except KeyError:
+ if isinstance(key, util.int_types):
+ raise IndexError(key)
+ else:
+ raise
+
+ def __getattr__(self, key):
+ try:
+ return self._index[key]
+ except KeyError:
+ raise AttributeError(key)
+
+ def __contains__(self, key):
+ if key not in self._index:
+ if not isinstance(key, util.string_types):
+ raise exc.ArgumentError(
+ "__contains__ requires a string argument"
+ )
+ return False
+ else:
+ return True
+
+ def compare(self, other):
+ for l, r in util.zip_longest(self, other):
+ if l is not r:
+ return False
+ else:
+ return True
+
+ def __eq__(self, other):
+ return self.compare(other)
+
+ def get(self, key, default=None):
+ if key in self._index:
+ return self._index[key]
+ else:
+ return default
def __str__(self):
return repr([str(c) for c in self])
+ def __setitem__(self, key, value):
+ raise NotImplementedError()
+
+ def __delitem__(self, key):
+ raise NotImplementedError()
+
+ def __setattr__(self, key, obj):
+ raise NotImplementedError()
+
+ def clear(self):
+ raise NotImplementedError()
+
+ def remove(self, column):
+ raise NotImplementedError()
+
+ def update(self, iter_):
+ raise NotImplementedError()
+
+ __hash__ = None
+
+ def _populate_separate_keys(self, iter_):
+ """populate from an iterator of (key, column)"""
+ cols = list(iter_)
+ self._collection[:] = cols
+ self._colset.update(c for k, c in self._collection)
+ self._index.update(
+ (idx, c) for idx, (k, c) in enumerate(self._collection)
+ )
+ self._index.update({k: col for k, col in reversed(self._collection)})
+
+ def add(self, column, key=None):
+ if key is None:
+ key = column.key
+
+ l = len(self._collection)
+ self._collection.append((key, column))
+ self._colset.add(column)
+ self._index[l] = column
+ if key not in self._index:
+ self._index[key] = column
+
+ def __getstate__(self):
+ return {"_collection": self._collection, "_index": self._index}
+
+ def __setstate__(self, state):
+ object.__setattr__(self, "_index", state["_index"])
+ object.__setattr__(self, "_collection", state["_collection"])
+ object.__setattr__(
+ self, "_colset", {col for k, col in self._collection}
+ )
+
+ def contains_column(self, col):
+ return col in self._colset
+
+ def as_immutable(self):
+ return ImmutableColumnCollection(self)
+
def corresponding_column(self, column, require_embedded=False):
"""Given a :class:`.ColumnElement`, return the exported
:class:`.ColumnElement` object from this :class:`.ColumnCollection`
@@ -554,11 +737,11 @@ class ColumnCollection(util.OrderedProperties):
return True
# don't dig around if the column is locally present
- if self.contains_column(column):
+ if column in self._colset:
return column
col, intersect = None, None
target_set = column.proxy_set
- cols = self._all_columns
+ cols = [c for (k, c) in self._collection]
for c in cols:
expanded_proxy_set = set(_expand_cloned(c.proxy_set))
i = target_set.intersection(expanded_proxy_set)
@@ -610,165 +793,167 @@ class ColumnCollection(util.OrderedProperties):
col, intersect = c, i
return col
- def replace(self, column):
- """add the given column to this collection, removing unaliased
- versions of this column as well as existing columns with the
- same key.
- e.g.::
+class DedupeColumnCollection(ColumnCollection):
+ """A :class:`.ColumnCollection that maintains deduplicating behavior.
- t = Table('sometable', metadata, Column('col1', Integer))
- t.columns.replace(Column('col1', Integer, key='columnone'))
+ This is useful by schema level objects such as :class:`.Table` and
+ :class:`.PrimaryKeyConstraint`. The collection includes more
+ sophisticated mutator methods as well to suit schema objects which
+ require mutable column collections.
- will remove the original 'col1' from the collection, and add
- the new column under the name 'columnname'.
-
- Used by schema.Column to override columns during table reflection.
-
- """
- remove_col = None
- if column.name in self and column.key != column.name:
- other = self[column.name]
- if other.name == other.key:
- remove_col = other
- del self._data[other.key]
-
- if column.key in self._data:
- remove_col = self._data[column.key]
-
- self._data[column.key] = column
- if remove_col is not None:
- self._all_columns[:] = [
- column if c is remove_col else c for c in self._all_columns
- ]
- else:
- self._all_columns.append(column)
+ .. versionadded: 1.4
- def add(self, column):
- """Add a column to this collection.
+ """
- The key attribute of the column will be used as the hash key
- for this dictionary.
+ def add(self, column, key=None):
+ if key is not None and column.key != key:
+ raise exc.ArgumentError(
+ "DedupeColumnCollection requires columns be under "
+ "the same key as their .key"
+ )
+ key = column.key
- """
- if not column.key:
+ if key is None:
raise exc.ArgumentError(
"Can't add unnamed column to column collection"
)
- self[column.key] = column
-
- def __delitem__(self, key):
- raise NotImplementedError()
-
- def __setattr__(self, key, obj):
- raise NotImplementedError()
- def __setitem__(self, key, value):
- if key in self:
+ if key in self._index:
- # this warning is primarily to catch select() statements
- # which have conflicting column names in their exported
- # columns collection
+ existing = self._index[key]
- existing = self[key]
-
- if existing is value:
+ if existing is column:
return
- if not existing.shares_lineage(value):
- util.warn(
- "Column %r on table %r being replaced by "
- "%r, which has the same key. Consider "
- "use_labels for select() statements."
- % (key, getattr(existing, "table", None), value)
- )
+ self.replace(column)
# pop out memoized proxy_set as this
# operation may very well be occurring
# in a _make_proxy operation
- util.memoized_property.reset(value, "proxy_set")
+ util.memoized_property.reset(column, "proxy_set")
+ else:
+ l = len(self._collection)
+ self._collection.append((key, column))
+ self._colset.add(column)
+ self._index[l] = column
+ self._index[key] = column
+
+ def _populate_separate_keys(self, iter_):
+ """populate from an iterator of (key, column)"""
+ cols = list(iter_)
- self._all_columns.append(value)
- self._data[key] = value
+ replace_col = []
+ for k, col in cols:
+ if col.key != k:
+ raise exc.ArgumentError(
+ "DedupeColumnCollection requires columns be under "
+ "the same key as their .key"
+ )
+ if col.name in self._index and col.key != col.name:
+ replace_col.append(col)
+ elif col.key in self._index:
+ replace_col.append(col)
+ else:
+ self._index[k] = col
+ self._collection.append((k, col))
+ self._colset.update(c for (k, c) in self._collection)
+ self._index.update(
+ (idx, c) for idx, (k, c) in enumerate(self._collection)
+ )
+ for col in replace_col:
+ self.replace(col)
- def clear(self):
- raise NotImplementedError()
+ def extend(self, iter_):
+ self._populate_separate_keys((col.key, col) for col in iter_)
def remove(self, column):
- del self._data[column.key]
- self._all_columns[:] = [
- c for c in self._all_columns if c is not column
+ if column not in self._colset:
+ raise ValueError(
+ "Can't remove column %r; column is not in this collection"
+ % column
+ )
+ del self._index[column.key]
+ self._colset.remove(column)
+ self._collection[:] = [
+ (k, c) for (k, c) in self._collection if c is not column
]
-
- def update(self, iter_):
- cols = list(iter_)
- all_col_set = set(self._all_columns)
- self._all_columns.extend(
- c for label, c in cols if c not in all_col_set
+ self._index.update(
+ {idx: col for idx, (k, col) in enumerate(self._collection)}
)
- self._data.update((label, c) for label, c in cols)
+ # delete higher index
+ del self._index[len(self._collection)]
- def extend(self, iter_):
- cols = list(iter_)
- all_col_set = set(self._all_columns)
- self._all_columns.extend(c for c in cols if c not in all_col_set)
- self._data.update((c.key, c) for c in cols)
+ def replace(self, column):
+ """add the given column to this collection, removing unaliased
+ versions of this column as well as existing columns with the
+ same key.
- __hash__ = None
+ e.g.::
- def __eq__(self, other):
- l = []
- for c in getattr(other, "_all_columns", other):
- for local in self._all_columns:
- if c.shares_lineage(local):
- l.append(c == local)
- return elements.and_(*l)
+ t = Table('sometable', metadata, Column('col1', Integer))
+ t.columns.replace(Column('col1', Integer, key='columnone'))
- def __contains__(self, other):
- if not isinstance(other, util.string_types):
- raise exc.ArgumentError("__contains__ requires a string argument")
- return util.OrderedProperties.__contains__(self, other)
+ will remove the original 'col1' from the collection, and add
+ the new column under the name 'columnname'.
- def __getstate__(self):
- return {"_data": self._data, "_all_columns": self._all_columns}
+ Used by schema.Column to override columns during table reflection.
- def __setstate__(self, state):
- object.__setattr__(self, "_data", state["_data"])
- object.__setattr__(self, "_all_columns", state["_all_columns"])
+ """
- def contains_column(self, col):
- return col in set(self._all_columns)
+ remove_col = set()
+ # remove up to two columns based on matches of name as well as key
+ if column.name in self._index and column.key != column.name:
+ other = self._index[column.name]
+ if other.name == other.key:
+ remove_col.add(other)
+
+ if column.key in self._index:
+ remove_col.add(self._index[column.key])
+
+ new_cols = []
+ replaced = False
+ for k, col in self._collection:
+ if col in remove_col:
+ if not replaced:
+ replaced = True
+ new_cols.append((column.key, column))
+ else:
+ new_cols.append((k, col))
- def as_immutable(self):
- return ImmutableColumnCollection(self._data, self._all_columns)
+ if remove_col:
+ self._colset.difference_update(remove_col)
+ if not replaced:
+ new_cols.append((column.key, column))
-class SeparateKeyColumnCollection(ColumnCollection):
- """Column collection that maintains a string name separate from the
- column itself"""
+ self._colset.add(column)
+ self._collection[:] = new_cols
- def __init__(self, cols_plus_names=None):
- super(ColumnCollection, self).__init__()
- object.__setattr__(self, "_all_columns", [])
- if cols_plus_names:
- self.update(cols_plus_names)
+ self._index.clear()
+ self._index.update(
+ {idx: col for idx, (k, col) in enumerate(self._collection)}
+ )
+ self._index.update(self._collection)
- def replace(self, column):
- raise NotImplementedError()
- def add(self, column):
- raise NotImplementedError()
+class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
+ __slots__ = ("_parent",)
- def remove(self, column):
- raise NotImplementedError()
+ def __init__(self, collection):
+ object.__setattr__(self, "_parent", collection)
+ object.__setattr__(self, "_colset", collection._colset)
+ object.__setattr__(self, "_index", collection._index)
+ object.__setattr__(self, "_collection", collection._collection)
+ def __getstate__(self):
+ return {"_parent": self._parent}
-class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
- def __init__(self, data, all_columns):
- util.ImmutableProperties.__init__(self, data)
- object.__setattr__(self, "_all_columns", all_columns)
+ def __setstate__(self, state):
+ parent = state["_parent"]
+ self.__init__(parent)
- extend = remove = util.ImmutableProperties._immutable
+ add = extend = remove = util.ImmutableContainer._immutable
class ColumnSet(util.ordered_column_set):