diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 21 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 447 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 44 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 28 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/schema.py | 44 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 64 |
10 files changed, 449 insertions, 219 deletions
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 842730c5b..868c64ed3 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -878,7 +878,8 @@ class OracleCompiler(compiler.SQLCompiler): for_update._copy_internals() for elem in for_update.of: - select = select.column(elem) + if not select.selected_columns.contains_column(elem): + select = select.column(elem) # Wrap the middle select and add the hint inner_subquery = select.alias() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 33a474576..5e8d25647 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -45,6 +45,7 @@ from .. import log from .. import schema from .. import sql from .. import util +from ..sql import base as sql_base from ..sql import coercions from ..sql import expression from ..sql import operators @@ -1455,7 +1456,11 @@ class Mapper(InspectionAttr): def _configure_properties(self): # Column and other ClauseElement objects which are mapped - self.columns = self.c = util.OrderedProperties() + + # TODO: technically this should be a DedupeColumnCollection + # however DCC needs changes and more tests to fully cover + # storing columns under a separate key name + self.columns = self.c = sql_base.ColumnCollection() # object attribute names mapped to MapperProperty objects self._props = util.OrderedDict() @@ -1781,7 +1786,7 @@ class Mapper(InspectionAttr): or prop.columns[0] is self.polymorphic_on ) - self.columns[key] = col + self.columns.add(col, key) for col in prop.columns + prop._orig_columns: for col in col.proxy_set: self._columntoproperty[col] = prop diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 13402e7f4..b5c49ee05 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -528,7 +528,7 @@ class Query(object): """ - stmt = self._compile_context(labels=self._with_labels).statement + stmt = self._compile_context(for_statement=True).statement if self._params: stmt = stmt.params(self._params) @@ -3843,7 +3843,7 @@ class Query(object): update_op.exec_() return update_op.rowcount - def _compile_context(self, labels=True): + def _compile_context(self, for_statement=False): if self.dispatch.before_compile: for fn in self.dispatch.before_compile: new_query = fn(self) @@ -3855,7 +3855,8 @@ class Query(object): if context.statement is not None: return context - context.labels = labels + context.labels = not for_statement or self._with_labels + context.dedupe_cols = True context._for_update_arg = self._for_update_arg @@ -3909,7 +3910,9 @@ class Query(object): order_by_col_expr = [] inner = sql.select( - context.primary_columns + order_by_col_expr, + util.unique_list(context.primary_columns + order_by_col_expr) + if context.dedupe_cols + else (context.primary_columns + order_by_col_expr), context.whereclause, from_obj=context.froms, use_labels=context.labels, @@ -3979,7 +3982,11 @@ class Query(object): context.froms += tuple(context.eager_joins.values()) statement = sql.select( - context.primary_columns + context.secondary_columns, + util.unique_list( + context.primary_columns + context.secondary_columns + ) + if context.dedupe_cols + else (context.primary_columns + context.secondary_columns), context.whereclause, from_obj=context.froms, use_labels=context.labels, @@ -4290,8 +4297,7 @@ class Bundle(InspectionAttr): """ self.name = self._label = name self.exprs = exprs - self.c = self.columns = ColumnCollection() - self.columns.update( + self.c = self.columns = ColumnCollection( (getattr(col, "key", col._label), col) for col in exprs ) self.single_entity = kw.pop("single_entity", self.single_entity) @@ -4658,6 +4664,7 @@ class QueryContext(object): "whereclause", "order_by", "labels", + "dedupe_cols", "_for_update_arg", "runid", "partials", diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a84843c4b..da384bdab 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -499,26 +499,209 @@ class SchemaVisitor(ClauseVisitor): __traverse_options__ = {"schema_visitor": True} -class ColumnCollection(util.OrderedProperties): - """An ordered dictionary that stores a list of ColumnElement - instances. +class ColumnCollection(object): + """Collection of :class:`.ColumnElement` instances, typically for + selectables. + + The :class:`.ColumnCollection` has both mapping- and sequence- like + behaviors. A :class:`.ColumnCollection` usually stores :class:`.Column` + objects, which are then accessible both via mapping style access as well + as attribute access style. The name for which a :class:`.Column` would + be present is normally that of the :paramref:`.Column.key` parameter, + however depending on the context, it may be stored under a special label + name:: + + >>> from sqlalchemy import Column, Integer + >>> from sqlalchemy.sql import ColumnCollection + >>> x, y = Column('x', Integer), Column('y', Integer) + >>> cc = ColumnCollection(columns=[x, y]) + >>> cc.x + Column('x', Integer(), table=None) + >>> cc.y + Column('y', Integer(), table=None) + >>> cc['x'] + Column('x', Integer(), table=None) + >>> cc['y'] + + :class`.ColumnCollection` also indexes the columns in order and allows + them to be accessible by their integer position:: + + >>> cc[0] + Column('x', Integer(), table=None) + >>> cc[1] + Column('y', Integer(), table=None) + + .. versionadded:: 1.4 :class:`.ColumnCollection` allows integer-based + index access to the collection. + + Iterating the collection yields the column expressions in order:: + + >>> list(cc) + [Column('x', Integer(), table=None), + Column('y', Integer(), table=None)] + + The base :class:`.ColumnCollection` object can store duplicates, which can + mean either two columns with the same key, in which case the column + returned by key access is **arbitrary**:: + + >>> x1, x2 = Column('x', Integer), Column('x', Integer) + >>> cc = ColumnCollection(columns=[x1, x2]) + >>> list(cc) + [Column('x', Integer(), table=None), + Column('x', Integer(), table=None)] + >>> cc['x'] is x1 + False + >>> cc['x'] is x2 + True + + Or it can also mean the same column multiple times. These cases are + supported as :class:`.ColumnCollection` is used to represent the columns in + a SELECT statement which may include duplicates. + + A special subclass :class:`.DedupeColumnCollection` exists which instead + maintains SQLAlchemy's older behavior of not allowing duplicates; this + collection is used for schema level objects like :class:`.Table` and + :class:`.PrimaryKeyConstraint` where this deduping is helpful. The + :class:`.DedupeColumnCollection` class also has additional mutation methods + as the schema constructs have more use cases that require removal and + replacement of columns. + + .. versionchanged:: 1.4 :class:`.ColumnCollection` now stores duplicate + column keys as well as the same column in multiple positions. The + :class:`.DedupeColumnCollection` class is added to maintain the + former behavior in those cases where deduplication as well as + additional replace/remove operations are needed. - Overrides the ``__eq__()`` method to produce SQL clauses between - sets of correlated columns. """ - __slots__ = "_all_columns" + __slots__ = "_collection", "_index", "_colset" - def __init__(self, *columns): - super(ColumnCollection, self).__init__() - object.__setattr__(self, "_all_columns", []) - for c in columns: - self.add(c) + def __init__(self, columns=None): + object.__setattr__(self, "_colset", set()) + object.__setattr__(self, "_index", {}) + object.__setattr__(self, "_collection", []) + if columns: + self._initial_populate(columns) + + def _initial_populate(self, iter_): + self._populate_separate_keys(iter_) + + @property + def _all_columns(self): + return [col for (k, col) in self._collection] + + def keys(self): + return [k for (k, col) in self._collection] + + def __len__(self): + return len(self._collection) + + def __iter__(self): + # turn to a list first to maintain over a course of changes + return iter([col for k, col in self._collection]) + + def __getitem__(self, key): + try: + return self._index[key] + except KeyError: + if isinstance(key, util.int_types): + raise IndexError(key) + else: + raise + + def __getattr__(self, key): + try: + return self._index[key] + except KeyError: + raise AttributeError(key) + + def __contains__(self, key): + if key not in self._index: + if not isinstance(key, util.string_types): + raise exc.ArgumentError( + "__contains__ requires a string argument" + ) + return False + else: + return True + + def compare(self, other): + for l, r in util.zip_longest(self, other): + if l is not r: + return False + else: + return True + + def __eq__(self, other): + return self.compare(other) + + def get(self, key, default=None): + if key in self._index: + return self._index[key] + else: + return default def __str__(self): return repr([str(c) for c in self]) + def __setitem__(self, key, value): + raise NotImplementedError() + + def __delitem__(self, key): + raise NotImplementedError() + + def __setattr__(self, key, obj): + raise NotImplementedError() + + def clear(self): + raise NotImplementedError() + + def remove(self, column): + raise NotImplementedError() + + def update(self, iter_): + raise NotImplementedError() + + __hash__ = None + + def _populate_separate_keys(self, iter_): + """populate from an iterator of (key, column)""" + cols = list(iter_) + self._collection[:] = cols + self._colset.update(c for k, c in self._collection) + self._index.update( + (idx, c) for idx, (k, c) in enumerate(self._collection) + ) + self._index.update({k: col for k, col in reversed(self._collection)}) + + def add(self, column, key=None): + if key is None: + key = column.key + + l = len(self._collection) + self._collection.append((key, column)) + self._colset.add(column) + self._index[l] = column + if key not in self._index: + self._index[key] = column + + def __getstate__(self): + return {"_collection": self._collection, "_index": self._index} + + def __setstate__(self, state): + object.__setattr__(self, "_index", state["_index"]) + object.__setattr__(self, "_collection", state["_collection"]) + object.__setattr__( + self, "_colset", {col for k, col in self._collection} + ) + + def contains_column(self, col): + return col in self._colset + + def as_immutable(self): + return ImmutableColumnCollection(self) + def corresponding_column(self, column, require_embedded=False): """Given a :class:`.ColumnElement`, return the exported :class:`.ColumnElement` object from this :class:`.ColumnCollection` @@ -554,11 +737,11 @@ class ColumnCollection(util.OrderedProperties): return True # don't dig around if the column is locally present - if self.contains_column(column): + if column in self._colset: return column col, intersect = None, None target_set = column.proxy_set - cols = self._all_columns + cols = [c for (k, c) in self._collection] for c in cols: expanded_proxy_set = set(_expand_cloned(c.proxy_set)) i = target_set.intersection(expanded_proxy_set) @@ -610,165 +793,167 @@ class ColumnCollection(util.OrderedProperties): col, intersect = c, i return col - def replace(self, column): - """add the given column to this collection, removing unaliased - versions of this column as well as existing columns with the - same key. - e.g.:: +class DedupeColumnCollection(ColumnCollection): + """A :class:`.ColumnCollection that maintains deduplicating behavior. - t = Table('sometable', metadata, Column('col1', Integer)) - t.columns.replace(Column('col1', Integer, key='columnone')) + This is useful by schema level objects such as :class:`.Table` and + :class:`.PrimaryKeyConstraint`. The collection includes more + sophisticated mutator methods as well to suit schema objects which + require mutable column collections. - will remove the original 'col1' from the collection, and add - the new column under the name 'columnname'. - - Used by schema.Column to override columns during table reflection. - - """ - remove_col = None - if column.name in self and column.key != column.name: - other = self[column.name] - if other.name == other.key: - remove_col = other - del self._data[other.key] - - if column.key in self._data: - remove_col = self._data[column.key] - - self._data[column.key] = column - if remove_col is not None: - self._all_columns[:] = [ - column if c is remove_col else c for c in self._all_columns - ] - else: - self._all_columns.append(column) + .. versionadded: 1.4 - def add(self, column): - """Add a column to this collection. + """ - The key attribute of the column will be used as the hash key - for this dictionary. + def add(self, column, key=None): + if key is not None and column.key != key: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + key = column.key - """ - if not column.key: + if key is None: raise exc.ArgumentError( "Can't add unnamed column to column collection" ) - self[column.key] = column - - def __delitem__(self, key): - raise NotImplementedError() - - def __setattr__(self, key, obj): - raise NotImplementedError() - def __setitem__(self, key, value): - if key in self: + if key in self._index: - # this warning is primarily to catch select() statements - # which have conflicting column names in their exported - # columns collection + existing = self._index[key] - existing = self[key] - - if existing is value: + if existing is column: return - if not existing.shares_lineage(value): - util.warn( - "Column %r on table %r being replaced by " - "%r, which has the same key. Consider " - "use_labels for select() statements." - % (key, getattr(existing, "table", None), value) - ) + self.replace(column) # pop out memoized proxy_set as this # operation may very well be occurring # in a _make_proxy operation - util.memoized_property.reset(value, "proxy_set") + util.memoized_property.reset(column, "proxy_set") + else: + l = len(self._collection) + self._collection.append((key, column)) + self._colset.add(column) + self._index[l] = column + self._index[key] = column + + def _populate_separate_keys(self, iter_): + """populate from an iterator of (key, column)""" + cols = list(iter_) - self._all_columns.append(value) - self._data[key] = value + replace_col = [] + for k, col in cols: + if col.key != k: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + if col.name in self._index and col.key != col.name: + replace_col.append(col) + elif col.key in self._index: + replace_col.append(col) + else: + self._index[k] = col + self._collection.append((k, col)) + self._colset.update(c for (k, c) in self._collection) + self._index.update( + (idx, c) for idx, (k, c) in enumerate(self._collection) + ) + for col in replace_col: + self.replace(col) - def clear(self): - raise NotImplementedError() + def extend(self, iter_): + self._populate_separate_keys((col.key, col) for col in iter_) def remove(self, column): - del self._data[column.key] - self._all_columns[:] = [ - c for c in self._all_columns if c is not column + if column not in self._colset: + raise ValueError( + "Can't remove column %r; column is not in this collection" + % column + ) + del self._index[column.key] + self._colset.remove(column) + self._collection[:] = [ + (k, c) for (k, c) in self._collection if c is not column ] - - def update(self, iter_): - cols = list(iter_) - all_col_set = set(self._all_columns) - self._all_columns.extend( - c for label, c in cols if c not in all_col_set + self._index.update( + {idx: col for idx, (k, col) in enumerate(self._collection)} ) - self._data.update((label, c) for label, c in cols) + # delete higher index + del self._index[len(self._collection)] - def extend(self, iter_): - cols = list(iter_) - all_col_set = set(self._all_columns) - self._all_columns.extend(c for c in cols if c not in all_col_set) - self._data.update((c.key, c) for c in cols) + def replace(self, column): + """add the given column to this collection, removing unaliased + versions of this column as well as existing columns with the + same key. - __hash__ = None + e.g.:: - def __eq__(self, other): - l = [] - for c in getattr(other, "_all_columns", other): - for local in self._all_columns: - if c.shares_lineage(local): - l.append(c == local) - return elements.and_(*l) + t = Table('sometable', metadata, Column('col1', Integer)) + t.columns.replace(Column('col1', Integer, key='columnone')) - def __contains__(self, other): - if not isinstance(other, util.string_types): - raise exc.ArgumentError("__contains__ requires a string argument") - return util.OrderedProperties.__contains__(self, other) + will remove the original 'col1' from the collection, and add + the new column under the name 'columnname'. - def __getstate__(self): - return {"_data": self._data, "_all_columns": self._all_columns} + Used by schema.Column to override columns during table reflection. - def __setstate__(self, state): - object.__setattr__(self, "_data", state["_data"]) - object.__setattr__(self, "_all_columns", state["_all_columns"]) + """ - def contains_column(self, col): - return col in set(self._all_columns) + remove_col = set() + # remove up to two columns based on matches of name as well as key + if column.name in self._index and column.key != column.name: + other = self._index[column.name] + if other.name == other.key: + remove_col.add(other) + + if column.key in self._index: + remove_col.add(self._index[column.key]) + + new_cols = [] + replaced = False + for k, col in self._collection: + if col in remove_col: + if not replaced: + replaced = True + new_cols.append((column.key, column)) + else: + new_cols.append((k, col)) - def as_immutable(self): - return ImmutableColumnCollection(self._data, self._all_columns) + if remove_col: + self._colset.difference_update(remove_col) + if not replaced: + new_cols.append((column.key, column)) -class SeparateKeyColumnCollection(ColumnCollection): - """Column collection that maintains a string name separate from the - column itself""" + self._colset.add(column) + self._collection[:] = new_cols - def __init__(self, cols_plus_names=None): - super(ColumnCollection, self).__init__() - object.__setattr__(self, "_all_columns", []) - if cols_plus_names: - self.update(cols_plus_names) + self._index.clear() + self._index.update( + {idx: col for idx, (k, col) in enumerate(self._collection)} + ) + self._index.update(self._collection) - def replace(self, column): - raise NotImplementedError() - def add(self, column): - raise NotImplementedError() +class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection): + __slots__ = ("_parent",) - def remove(self, column): - raise NotImplementedError() + def __init__(self, collection): + object.__setattr__(self, "_parent", collection) + object.__setattr__(self, "_colset", collection._colset) + object.__setattr__(self, "_index", collection._index) + object.__setattr__(self, "_collection", collection._collection) + def __getstate__(self): + return {"_parent": self._parent} -class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): - def __init__(self, data, all_columns): - util.ImmutableProperties.__init__(self, data) - object.__setattr__(self, "_all_columns", all_columns) + def __setstate__(self, state): + parent = state["_parent"] + self.__init__(parent) - extend = remove = util.ImmutableProperties._immutable + add = extend = remove = util.ImmutableContainer._immutable class ColumnSet(util.ordered_column_set): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 13219ee68..ea7e890e7 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1828,7 +1828,6 @@ class SQLCompiler(Compiled): result_expr = _CompileLabel( col_expr, name, alt_names=(column._key_label,) ) - elif ( asfrom and isinstance(column, elements.ColumnClause) @@ -1897,6 +1896,7 @@ class SQLCompiler(Compiled): """ cloned = {} column_translate = [{}] + created = set() def visit(element, **kw): if element in column_translate[-1]: @@ -1906,7 +1906,6 @@ class SQLCompiler(Compiled): return cloned[element] newelem = cloned[element] = element._clone() - if ( newelem._is_from_clause and newelem._is_join @@ -1921,6 +1920,8 @@ class SQLCompiler(Compiled): selectable_ = selectable.Select( [right.element], use_labels=True ).alias() + created.add(selectable_) + created.update(selectable_.c) for c in selectable_.c: c._key_label = c.key @@ -1971,6 +1972,11 @@ class SQLCompiler(Compiled): if barrier_select: column_translate.append({}) kw["transform_clue"] = "inside_select" + if not newelem._is_select_container: + froms = newelem.froms + newelem._raw_columns = list(newelem.selected_columns) + newelem._from_obj.update(froms) + newelem._reset_memoizations() newelem._copy_internals(clone=visit, **kw) if barrier_select: del column_translate[-1] @@ -1984,17 +1990,33 @@ class SQLCompiler(Compiled): def _transform_result_map_for_nested_joins( self, select, transformed_select ): - inner_col = dict( - (c._key_label, c) for c in transformed_select.inner_columns - ) - - d = dict((inner_col[c._key_label], c) for c in select.inner_columns) - - self._result_columns = [ - (key, name, tuple([d.get(col, col) for col in objs]), typ) - for key, name, objs, typ in self._result_columns + self._result_columns[:] = [ + result_rec + if col is tcol + else ( + result_rec[0], + name, + tuple([col if obj is tcol else obj for obj in result_rec[2]]), + result_rec[3], + ) + for result_rec, (name, col), (tname, tcol) in zip( + self._result_columns, + select._columns_plus_names, + transformed_select._columns_plus_names, + ) ] + # TODO: it's not anticipated that we need to correct anon_map + # however if we do, this is what it looks like: + # for (name, col), (tname, tcol) in zip( + # select._columns_plus_names, + # transformed_select._columns_plus_names, + # ): + # if isinstance(name, elements._anonymous_label) and name != tname: + # m1 = re.match(r"^%\((\d+ .+?)\)s$", name) + # m2 = re.match(r"^%\((\d+ .+?)\)s$", tname) + # self.anon_map[m1.group(1)] = self.anon_map[m2.group(1)] + _default_stack_entry = util.immutabledict( [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] ) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 5a9be7c62..66e92a63c 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -48,8 +48,9 @@ class UpdateBase( named_with_column = False def _generate_fromclause_column_proxies(self, fromclause): - for col in self._returning: - col._make_proxy(fromclause) + fromclause._columns._populate_separate_keys( + col._make_proxy(fromclause) for col in self._returning + ) def _process_colparams(self, parameters): def process_single(p): diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 5b4442222..735a125d7 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -184,6 +184,7 @@ class ClauseElement(roles.SQLRole, Visitable): _is_returns_rows = False _is_text_clause = False _is_from_container = False + _is_select_container = False _is_select_statement = False _order_by_label_element = None @@ -856,8 +857,7 @@ class ColumnElement( co._proxies = [self] if selectable._is_clone_of is not None: co._is_clone_of = selectable._is_clone_of.columns.get(key) - selectable._columns[key] = co - return co + return key, co def cast(self, type_): """Produce a type cast, i.e. ``CAST(<expression> AS <type>)``. @@ -887,6 +887,12 @@ class ColumnElement( """ return Label(name, self, self.type) + def _anon_label(self, seed): + while self._is_clone_of is not None: + self = self._is_clone_of + + return _anonymous_label("%%(%d %s)s" % (id(self), seed or "anon")) + @util.memoized_property def anon_label(self): """provides a constant 'anonymous label' for this ColumnElement. @@ -901,12 +907,11 @@ class ColumnElement( expressions and function calls. """ - while self._is_clone_of is not None: - self = self._is_clone_of + return self._anon_label(getattr(self, "name", None)) - return _anonymous_label( - "%%(%d %s)s" % (id(self), getattr(self, "name", "anon")) - ) + @util.memoized_property + def _label_anon_label(self): + return self._anon_label(getattr(self, "_label", None)) class BindParameter(roles.InElementRole, ColumnElement): @@ -3951,7 +3956,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): return self.element._from_objects def _make_proxy(self, selectable, name=None, **kw): - e = self.element._make_proxy( + key, e = self.element._make_proxy( selectable, name=name if name else self.name, disallow_is_literal=True, @@ -3959,7 +3964,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): e._proxies.append(self) if self._type is not None: e.type = self._type - return e + return key, e class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): @@ -4214,7 +4219,6 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): self, selectable, name=None, - attach=True, name_is_truncatable=False, disallow_is_literal=False, **kw @@ -4249,9 +4253,7 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): if selectable._is_clone_of is not None: c._is_clone_of = selectable._is_clone_of.columns.get(c.key) - if attach: - selectable._columns[c.key] = c - return c + return c.key, c class CollationClause(ColumnElement): diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 04fb16a80..2feb6fd5f 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -118,7 +118,8 @@ class FunctionElement(Executable, ColumnElement, FromClause): """ - return ColumnCollection(self.label(None)) + col = self.label(None) + return ColumnCollection(columns=[(col.key, col)]) @util.memoized_property def clauses(self): diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index d39bc9832..23c58dc4e 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -40,7 +40,7 @@ from . import roles from . import type_api from . import visitors from .base import _bind_or_error -from .base import ColumnCollection +from .base import DedupeColumnCollection from .base import DialectKWArgs from .base import SchemaEventTarget from .coercions import _document_text_coercion @@ -538,7 +538,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): self.indexes = set() self.constraints = set() - self._columns = ColumnCollection() + self._columns = DedupeColumnCollection() PrimaryKeyConstraint( _implicit_generated=True )._set_parent_with_dispatch(self) @@ -1607,13 +1607,13 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): ) c.table = selectable - selectable._columns.add(c) if selectable._is_clone_of is not None: c._is_clone_of = selectable._is_clone_of.columns[c.key] if self.primary_key: selectable.primary_key.add(c) - c.dispatch.after_parent_attach(c, selectable) - return c + if fk: + selectable.foreign_keys.update(fk) + return c.key, c def get_children(self, schema_visitor=False, **kwargs): if schema_visitor: @@ -1983,19 +1983,20 @@ class ForeignKey(DialectKWArgs, SchemaItem): self._set_target_column(_column) def _set_target_column(self, column): + assert isinstance(self.parent.table, Table) + # propagate TypeEngine to parent if it didn't have one if self.parent.type._isnull: self.parent.type = column.type # super-edgy case, if other FKs point to our column, # they'd get the type propagated out also. - if isinstance(self.parent.table, Table): - def set_type(fk): - if fk.parent.type._isnull: - fk.parent.type = column.type + def set_type(fk): + if fk.parent.type._isnull: + fk.parent.type = column.type - self.parent._setup_on_memoized_fks(set_type) + self.parent._setup_on_memoized_fks(set_type) self.column = column @@ -2072,7 +2073,8 @@ class ForeignKey(DialectKWArgs, SchemaItem): def _set_table(self, column, table): # standalone ForeignKey - create ForeignKeyConstraint # on the hosting Table when attached to the Table. - if self.constraint is None and isinstance(table, Table): + assert isinstance(table, Table) + if self.constraint is None: self.constraint = ForeignKeyConstraint( [], [], @@ -2088,7 +2090,6 @@ class ForeignKey(DialectKWArgs, SchemaItem): self.constraint._append_element(column, self) self.constraint._set_parent_with_dispatch(table) table.foreign_keys.add(self) - # set up remote ".column" attribute, or a note to pick it # up when the other Table/Column shows up if isinstance(self._colspec, util.string_types): @@ -2760,7 +2761,7 @@ class ColumnCollectionMixin(object): def __init__(self, *columns, **kw): _autoattach = kw.pop("_autoattach", True) self._column_flag = kw.pop("_column_flag", False) - self.columns = ColumnCollection() + self.columns = DedupeColumnCollection() self._pending_colargs = [ _to_schema_column_or_string(c) for c in columns ] @@ -2885,14 +2886,10 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): return self.columns.contains_column(col) def __iter__(self): - # inlining of - # return iter(self.columns) - # ColumnCollection->OrderedProperties->OrderedDict - ordered_dict = self.columns._data - return (ordered_dict[key] for key in ordered_dict._list) + return iter(self.columns) def __len__(self): - return len(self.columns._data) + return len(self.columns) class CheckConstraint(ColumnCollectionConstraint): @@ -3368,11 +3365,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): table.constraints.add(self) table_pks = [c for c in table.c if c.primary_key] - if ( - self.columns - and table_pks - and set(table_pks) != set(self.columns.values()) - ): + if self.columns and table_pks and set(table_pks) != set(self.columns): util.warn( "Table '%s' specifies columns %s as primary_key=True, " "not matching locally specified columns %s; setting the " @@ -3390,7 +3383,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): for c in self.columns: c.primary_key = True c.nullable = False - self.columns.extend(table_pks) + if table_pks: + self.columns.extend(table_pks) def _reload(self, columns): """repopulate this :class:`.PrimaryKeyConstraint` given diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index c93de8d73..10643c9e4 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -27,10 +27,10 @@ from .base import _from_objects from .base import _generative from .base import ColumnCollection from .base import ColumnSet +from .base import DedupeColumnCollection from .base import Executable from .base import Generative from .base import Immutable -from .base import SeparateKeyColumnCollection from .coercions import _document_text_coercion from .elements import _anonymous_label from .elements import _select_iterables @@ -534,8 +534,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self._memoized_property.expire_instance(self) def _generate_fromclause_column_proxies(self, fromclause): - for col in self.c: - col._make_proxy(fromclause) + fromclause._columns._populate_separate_keys( + col._make_proxy(fromclause) for col in self.c + ) @property def exported_columns(self): @@ -791,7 +792,9 @@ class Join(FromClause): (c for c in columns if c.primary_key), self.onclause ) ) - self._columns.update((col._key_label, col) for col in columns) + self._columns._populate_separate_keys( + (col._key_label, col) for col in columns + ) self.foreign_keys.update( itertools.chain(*[col.foreign_keys for col in columns]) ) @@ -1861,7 +1864,7 @@ class TableClause(Immutable, FromClause): super(TableClause, self).__init__() self.name = self.fullname = name - self._columns = ColumnCollection() + self._columns = DedupeColumnCollection() self.primary_key = ColumnSet() self.foreign_keys = set() for c in columns: @@ -1881,7 +1884,7 @@ class TableClause(Immutable, FromClause): return self.name.encode("ascii", "backslashreplace") def append_column(self, c): - self._columns[c.key] = c + self._columns.add(c) c.table = self def get_children(self, column_collections=True, **kwargs): @@ -2328,6 +2331,8 @@ class SelectStatementGrouping(GroupedElement, SelectBase): __visit_name__ = "grouping" + _is_select_container = True + def __init__(self, element): # type: (SelectBase) self.element = coercions.expect(roles.SelectStatementRole, element) @@ -2526,6 +2531,7 @@ class GenerativeSelect(SelectBase): FROM clauses to produce a unique set of column names regardless of name conflicts among the individual FROM clauses. + """ self.use_labels = True @@ -4081,6 +4087,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ names = set() + cols = _select_iterables(self._raw_columns) + def name_for_col(c): # we use key_label since this name is intended for targeting # within the ColumnCollection only, it's not related to SQL @@ -4090,18 +4098,22 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): else: name = c._proxy_key if name in names: - name = c.anon_label + if self.use_labels: + name = c._label_anon_label + else: + name = c.anon_label else: names.add(name) return name - return SeparateKeyColumnCollection( - (name_for_col(c), c) - for c in util.unique_list(_select_iterables(self._raw_columns)) + return ColumnCollection( + (name_for_col(c), c) for c in cols ).as_immutable() @_memoized_property def _columns_plus_names(self): + cols = _select_iterables(self._raw_columns) + if self.use_labels: names = set() @@ -4111,23 +4123,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): name = c._label if name in names: - name = c.anon_label + name = c._label_anon_label else: names.add(name) return name, c - return [ - name_for_col(c) - for c in util.unique_list(_select_iterables(self._raw_columns)) - ] + return [name_for_col(c) for c in cols] else: - return [ - (None, c) - for c in util.unique_list(_select_iterables(self._raw_columns)) - ] + return [(None, c) for c in cols] def _generate_fromclause_column_proxies(self, subquery): keys_seen = set() + prox = [] for name, c in self._columns_plus_names: if not hasattr(c, "_make_proxy"): @@ -4137,14 +4144,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): elif self.use_labels: key = c._key_label if key is not None and key in keys_seen: - key = c.anon_label + key = c._label_anon_label keys_seen.add(key) else: key = None - - c._make_proxy( - subquery, key=key, name=name, name_is_truncatable=True + prox.append( + c._make_proxy( + subquery, key=key, name=name, name_is_truncatable=True + ) ) + subquery._columns._populate_separate_keys(prox) def _needs_parens_for_grouping(self): return ( @@ -4397,7 +4406,9 @@ class TextualSelect(SelectBase): .. versionadded:: 1.4 """ - return ColumnCollection(*self.column_args).as_immutable() + return ColumnCollection( + (c.key, c) for c in self.column_args + ).as_immutable() @property def _bind(self): @@ -4408,8 +4419,9 @@ class TextualSelect(SelectBase): self.element = self.element.bindparams(*binds, **bind_as_values) def _generate_fromclause_column_proxies(self, fromclause): - for c in self.column_args: - c._make_proxy(fromclause) + fromclause._columns._populate_separate_keys( + c._make_proxy(fromclause) for c in self.column_args + ) def _copy_internals(self, clone=_clone, **kw): self._reset_memoizations() |