diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-07-07 11:12:31 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-07-11 14:20:10 -0400 |
commit | aceefb508ccd0911f52ff0e50324b3fefeaa3f16 (patch) | |
tree | e57124d3ea8b0e2cd7fe1d3ad22170fa956bcafb /lib/sqlalchemy | |
parent | 5c16367ee78fa1a41d6b715152dcc58f45323d2e (diff) | |
download | sqlalchemy-aceefb508ccd0911f52ff0e50324b3fefeaa3f16.tar.gz |
Allow duplicate columns in from clauses and selectables
The :func:`.select` construct and related constructs now allow for
duplication of column labels and columns themselves in the columns clause,
mirroring exactly how column expressions were passed in. This allows
the tuples returned by an executed result to match what was SELECTed
for in the first place, which is how the ORM :class:`.Query` works, so
this establishes better cross-compatibility between the two constructs.
Additionally, it allows column-positioning-sensitive structures such as
UNIONs (i.e. :class:`.CompoundSelect`) to be more intuitively constructed
in those cases where a particular column might appear in more than one
place. To support this change, the :class:`.ColumnCollection` has been
revised to support duplicate columns as well as to allow integer index
access.
Fixes: #4753
Change-Id: Ie09a8116f05c367995c1e43623c51e07971d3bf0
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 21 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 447 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 44 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 28 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/schema.py | 44 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 64 |
10 files changed, 449 insertions, 219 deletions
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 842730c5b..868c64ed3 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -878,7 +878,8 @@ class OracleCompiler(compiler.SQLCompiler): for_update._copy_internals() for elem in for_update.of: - select = select.column(elem) + if not select.selected_columns.contains_column(elem): + select = select.column(elem) # Wrap the middle select and add the hint inner_subquery = select.alias() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 33a474576..5e8d25647 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -45,6 +45,7 @@ from .. import log from .. import schema from .. import sql from .. import util +from ..sql import base as sql_base from ..sql import coercions from ..sql import expression from ..sql import operators @@ -1455,7 +1456,11 @@ class Mapper(InspectionAttr): def _configure_properties(self): # Column and other ClauseElement objects which are mapped - self.columns = self.c = util.OrderedProperties() + + # TODO: technically this should be a DedupeColumnCollection + # however DCC needs changes and more tests to fully cover + # storing columns under a separate key name + self.columns = self.c = sql_base.ColumnCollection() # object attribute names mapped to MapperProperty objects self._props = util.OrderedDict() @@ -1781,7 +1786,7 @@ class Mapper(InspectionAttr): or prop.columns[0] is self.polymorphic_on ) - self.columns[key] = col + self.columns.add(col, key) for col in prop.columns + prop._orig_columns: for col in col.proxy_set: self._columntoproperty[col] = prop diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 13402e7f4..b5c49ee05 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -528,7 +528,7 @@ class Query(object): """ - stmt = self._compile_context(labels=self._with_labels).statement + stmt = self._compile_context(for_statement=True).statement if self._params: stmt = stmt.params(self._params) @@ -3843,7 +3843,7 @@ class Query(object): update_op.exec_() return update_op.rowcount - def _compile_context(self, labels=True): + def _compile_context(self, for_statement=False): if self.dispatch.before_compile: for fn in self.dispatch.before_compile: new_query = fn(self) @@ -3855,7 +3855,8 @@ class Query(object): if context.statement is not None: return context - context.labels = labels + context.labels = not for_statement or self._with_labels + context.dedupe_cols = True context._for_update_arg = self._for_update_arg @@ -3909,7 +3910,9 @@ class Query(object): order_by_col_expr = [] inner = sql.select( - context.primary_columns + order_by_col_expr, + util.unique_list(context.primary_columns + order_by_col_expr) + if context.dedupe_cols + else (context.primary_columns + order_by_col_expr), context.whereclause, from_obj=context.froms, use_labels=context.labels, @@ -3979,7 +3982,11 @@ class Query(object): context.froms += tuple(context.eager_joins.values()) statement = sql.select( - context.primary_columns + context.secondary_columns, + util.unique_list( + context.primary_columns + context.secondary_columns + ) + if context.dedupe_cols + else (context.primary_columns + context.secondary_columns), context.whereclause, from_obj=context.froms, use_labels=context.labels, @@ -4290,8 +4297,7 @@ class Bundle(InspectionAttr): """ self.name = self._label = name self.exprs = exprs - self.c = self.columns = ColumnCollection() - self.columns.update( + self.c = self.columns = ColumnCollection( (getattr(col, "key", col._label), col) for col in exprs ) self.single_entity = kw.pop("single_entity", self.single_entity) @@ -4658,6 +4664,7 @@ class QueryContext(object): "whereclause", "order_by", "labels", + "dedupe_cols", "_for_update_arg", "runid", "partials", diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a84843c4b..da384bdab 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -499,26 +499,209 @@ class SchemaVisitor(ClauseVisitor): __traverse_options__ = {"schema_visitor": True} -class ColumnCollection(util.OrderedProperties): - """An ordered dictionary that stores a list of ColumnElement - instances. +class ColumnCollection(object): + """Collection of :class:`.ColumnElement` instances, typically for + selectables. + + The :class:`.ColumnCollection` has both mapping- and sequence- like + behaviors. A :class:`.ColumnCollection` usually stores :class:`.Column` + objects, which are then accessible both via mapping style access as well + as attribute access style. The name for which a :class:`.Column` would + be present is normally that of the :paramref:`.Column.key` parameter, + however depending on the context, it may be stored under a special label + name:: + + >>> from sqlalchemy import Column, Integer + >>> from sqlalchemy.sql import ColumnCollection + >>> x, y = Column('x', Integer), Column('y', Integer) + >>> cc = ColumnCollection(columns=[x, y]) + >>> cc.x + Column('x', Integer(), table=None) + >>> cc.y + Column('y', Integer(), table=None) + >>> cc['x'] + Column('x', Integer(), table=None) + >>> cc['y'] + + :class`.ColumnCollection` also indexes the columns in order and allows + them to be accessible by their integer position:: + + >>> cc[0] + Column('x', Integer(), table=None) + >>> cc[1] + Column('y', Integer(), table=None) + + .. versionadded:: 1.4 :class:`.ColumnCollection` allows integer-based + index access to the collection. + + Iterating the collection yields the column expressions in order:: + + >>> list(cc) + [Column('x', Integer(), table=None), + Column('y', Integer(), table=None)] + + The base :class:`.ColumnCollection` object can store duplicates, which can + mean either two columns with the same key, in which case the column + returned by key access is **arbitrary**:: + + >>> x1, x2 = Column('x', Integer), Column('x', Integer) + >>> cc = ColumnCollection(columns=[x1, x2]) + >>> list(cc) + [Column('x', Integer(), table=None), + Column('x', Integer(), table=None)] + >>> cc['x'] is x1 + False + >>> cc['x'] is x2 + True + + Or it can also mean the same column multiple times. These cases are + supported as :class:`.ColumnCollection` is used to represent the columns in + a SELECT statement which may include duplicates. + + A special subclass :class:`.DedupeColumnCollection` exists which instead + maintains SQLAlchemy's older behavior of not allowing duplicates; this + collection is used for schema level objects like :class:`.Table` and + :class:`.PrimaryKeyConstraint` where this deduping is helpful. The + :class:`.DedupeColumnCollection` class also has additional mutation methods + as the schema constructs have more use cases that require removal and + replacement of columns. + + .. versionchanged:: 1.4 :class:`.ColumnCollection` now stores duplicate + column keys as well as the same column in multiple positions. The + :class:`.DedupeColumnCollection` class is added to maintain the + former behavior in those cases where deduplication as well as + additional replace/remove operations are needed. - Overrides the ``__eq__()`` method to produce SQL clauses between - sets of correlated columns. """ - __slots__ = "_all_columns" + __slots__ = "_collection", "_index", "_colset" - def __init__(self, *columns): - super(ColumnCollection, self).__init__() - object.__setattr__(self, "_all_columns", []) - for c in columns: - self.add(c) + def __init__(self, columns=None): + object.__setattr__(self, "_colset", set()) + object.__setattr__(self, "_index", {}) + object.__setattr__(self, "_collection", []) + if columns: + self._initial_populate(columns) + + def _initial_populate(self, iter_): + self._populate_separate_keys(iter_) + + @property + def _all_columns(self): + return [col for (k, col) in self._collection] + + def keys(self): + return [k for (k, col) in self._collection] + + def __len__(self): + return len(self._collection) + + def __iter__(self): + # turn to a list first to maintain over a course of changes + return iter([col for k, col in self._collection]) + + def __getitem__(self, key): + try: + return self._index[key] + except KeyError: + if isinstance(key, util.int_types): + raise IndexError(key) + else: + raise + + def __getattr__(self, key): + try: + return self._index[key] + except KeyError: + raise AttributeError(key) + + def __contains__(self, key): + if key not in self._index: + if not isinstance(key, util.string_types): + raise exc.ArgumentError( + "__contains__ requires a string argument" + ) + return False + else: + return True + + def compare(self, other): + for l, r in util.zip_longest(self, other): + if l is not r: + return False + else: + return True + + def __eq__(self, other): + return self.compare(other) + + def get(self, key, default=None): + if key in self._index: + return self._index[key] + else: + return default def __str__(self): return repr([str(c) for c in self]) + def __setitem__(self, key, value): + raise NotImplementedError() + + def __delitem__(self, key): + raise NotImplementedError() + + def __setattr__(self, key, obj): + raise NotImplementedError() + + def clear(self): + raise NotImplementedError() + + def remove(self, column): + raise NotImplementedError() + + def update(self, iter_): + raise NotImplementedError() + + __hash__ = None + + def _populate_separate_keys(self, iter_): + """populate from an iterator of (key, column)""" + cols = list(iter_) + self._collection[:] = cols + self._colset.update(c for k, c in self._collection) + self._index.update( + (idx, c) for idx, (k, c) in enumerate(self._collection) + ) + self._index.update({k: col for k, col in reversed(self._collection)}) + + def add(self, column, key=None): + if key is None: + key = column.key + + l = len(self._collection) + self._collection.append((key, column)) + self._colset.add(column) + self._index[l] = column + if key not in self._index: + self._index[key] = column + + def __getstate__(self): + return {"_collection": self._collection, "_index": self._index} + + def __setstate__(self, state): + object.__setattr__(self, "_index", state["_index"]) + object.__setattr__(self, "_collection", state["_collection"]) + object.__setattr__( + self, "_colset", {col for k, col in self._collection} + ) + + def contains_column(self, col): + return col in self._colset + + def as_immutable(self): + return ImmutableColumnCollection(self) + def corresponding_column(self, column, require_embedded=False): """Given a :class:`.ColumnElement`, return the exported :class:`.ColumnElement` object from this :class:`.ColumnCollection` @@ -554,11 +737,11 @@ class ColumnCollection(util.OrderedProperties): return True # don't dig around if the column is locally present - if self.contains_column(column): + if column in self._colset: return column col, intersect = None, None target_set = column.proxy_set - cols = self._all_columns + cols = [c for (k, c) in self._collection] for c in cols: expanded_proxy_set = set(_expand_cloned(c.proxy_set)) i = target_set.intersection(expanded_proxy_set) @@ -610,165 +793,167 @@ class ColumnCollection(util.OrderedProperties): col, intersect = c, i return col - def replace(self, column): - """add the given column to this collection, removing unaliased - versions of this column as well as existing columns with the - same key. - e.g.:: +class DedupeColumnCollection(ColumnCollection): + """A :class:`.ColumnCollection that maintains deduplicating behavior. - t = Table('sometable', metadata, Column('col1', Integer)) - t.columns.replace(Column('col1', Integer, key='columnone')) + This is useful by schema level objects such as :class:`.Table` and + :class:`.PrimaryKeyConstraint`. The collection includes more + sophisticated mutator methods as well to suit schema objects which + require mutable column collections. - will remove the original 'col1' from the collection, and add - the new column under the name 'columnname'. - - Used by schema.Column to override columns during table reflection. - - """ - remove_col = None - if column.name in self and column.key != column.name: - other = self[column.name] - if other.name == other.key: - remove_col = other - del self._data[other.key] - - if column.key in self._data: - remove_col = self._data[column.key] - - self._data[column.key] = column - if remove_col is not None: - self._all_columns[:] = [ - column if c is remove_col else c for c in self._all_columns - ] - else: - self._all_columns.append(column) + .. versionadded: 1.4 - def add(self, column): - """Add a column to this collection. + """ - The key attribute of the column will be used as the hash key - for this dictionary. + def add(self, column, key=None): + if key is not None and column.key != key: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + key = column.key - """ - if not column.key: + if key is None: raise exc.ArgumentError( "Can't add unnamed column to column collection" ) - self[column.key] = column - - def __delitem__(self, key): - raise NotImplementedError() - - def __setattr__(self, key, obj): - raise NotImplementedError() - def __setitem__(self, key, value): - if key in self: + if key in self._index: - # this warning is primarily to catch select() statements - # which have conflicting column names in their exported - # columns collection + existing = self._index[key] - existing = self[key] - - if existing is value: + if existing is column: return - if not existing.shares_lineage(value): - util.warn( - "Column %r on table %r being replaced by " - "%r, which has the same key. Consider " - "use_labels for select() statements." - % (key, getattr(existing, "table", None), value) - ) + self.replace(column) # pop out memoized proxy_set as this # operation may very well be occurring # in a _make_proxy operation - util.memoized_property.reset(value, "proxy_set") + util.memoized_property.reset(column, "proxy_set") + else: + l = len(self._collection) + self._collection.append((key, column)) + self._colset.add(column) + self._index[l] = column + self._index[key] = column + + def _populate_separate_keys(self, iter_): + """populate from an iterator of (key, column)""" + cols = list(iter_) - self._all_columns.append(value) - self._data[key] = value + replace_col = [] + for k, col in cols: + if col.key != k: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + if col.name in self._index and col.key != col.name: + replace_col.append(col) + elif col.key in self._index: + replace_col.append(col) + else: + self._index[k] = col + self._collection.append((k, col)) + self._colset.update(c for (k, c) in self._collection) + self._index.update( + (idx, c) for idx, (k, c) in enumerate(self._collection) + ) + for col in replace_col: + self.replace(col) - def clear(self): - raise NotImplementedError() + def extend(self, iter_): + self._populate_separate_keys((col.key, col) for col in iter_) def remove(self, column): - del self._data[column.key] - self._all_columns[:] = [ - c for c in self._all_columns if c is not column + if column not in self._colset: + raise ValueError( + "Can't remove column %r; column is not in this collection" + % column + ) + del self._index[column.key] + self._colset.remove(column) + self._collection[:] = [ + (k, c) for (k, c) in self._collection if c is not column ] - - def update(self, iter_): - cols = list(iter_) - all_col_set = set(self._all_columns) - self._all_columns.extend( - c for label, c in cols if c not in all_col_set + self._index.update( + {idx: col for idx, (k, col) in enumerate(self._collection)} ) - self._data.update((label, c) for label, c in cols) + # delete higher index + del self._index[len(self._collection)] - def extend(self, iter_): - cols = list(iter_) - all_col_set = set(self._all_columns) - self._all_columns.extend(c for c in cols if c not in all_col_set) - self._data.update((c.key, c) for c in cols) + def replace(self, column): + """add the given column to this collection, removing unaliased + versions of this column as well as existing columns with the + same key. - __hash__ = None + e.g.:: - def __eq__(self, other): - l = [] - for c in getattr(other, "_all_columns", other): - for local in self._all_columns: - if c.shares_lineage(local): - l.append(c == local) - return elements.and_(*l) + t = Table('sometable', metadata, Column('col1', Integer)) + t.columns.replace(Column('col1', Integer, key='columnone')) - def __contains__(self, other): - if not isinstance(other, util.string_types): - raise exc.ArgumentError("__contains__ requires a string argument") - return util.OrderedProperties.__contains__(self, other) + will remove the original 'col1' from the collection, and add + the new column under the name 'columnname'. - def __getstate__(self): - return {"_data": self._data, "_all_columns": self._all_columns} + Used by schema.Column to override columns during table reflection. - def __setstate__(self, state): - object.__setattr__(self, "_data", state["_data"]) - object.__setattr__(self, "_all_columns", state["_all_columns"]) + """ - def contains_column(self, col): - return col in set(self._all_columns) + remove_col = set() + # remove up to two columns based on matches of name as well as key + if column.name in self._index and column.key != column.name: + other = self._index[column.name] + if other.name == other.key: + remove_col.add(other) + + if column.key in self._index: + remove_col.add(self._index[column.key]) + + new_cols = [] + replaced = False + for k, col in self._collection: + if col in remove_col: + if not replaced: + replaced = True + new_cols.append((column.key, column)) + else: + new_cols.append((k, col)) - def as_immutable(self): - return ImmutableColumnCollection(self._data, self._all_columns) + if remove_col: + self._colset.difference_update(remove_col) + if not replaced: + new_cols.append((column.key, column)) -class SeparateKeyColumnCollection(ColumnCollection): - """Column collection that maintains a string name separate from the - column itself""" + self._colset.add(column) + self._collection[:] = new_cols - def __init__(self, cols_plus_names=None): - super(ColumnCollection, self).__init__() - object.__setattr__(self, "_all_columns", []) - if cols_plus_names: - self.update(cols_plus_names) + self._index.clear() + self._index.update( + {idx: col for idx, (k, col) in enumerate(self._collection)} + ) + self._index.update(self._collection) - def replace(self, column): - raise NotImplementedError() - def add(self, column): - raise NotImplementedError() +class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection): + __slots__ = ("_parent",) - def remove(self, column): - raise NotImplementedError() + def __init__(self, collection): + object.__setattr__(self, "_parent", collection) + object.__setattr__(self, "_colset", collection._colset) + object.__setattr__(self, "_index", collection._index) + object.__setattr__(self, "_collection", collection._collection) + def __getstate__(self): + return {"_parent": self._parent} -class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): - def __init__(self, data, all_columns): - util.ImmutableProperties.__init__(self, data) - object.__setattr__(self, "_all_columns", all_columns) + def __setstate__(self, state): + parent = state["_parent"] + self.__init__(parent) - extend = remove = util.ImmutableProperties._immutable + add = extend = remove = util.ImmutableContainer._immutable class ColumnSet(util.ordered_column_set): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 13219ee68..ea7e890e7 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1828,7 +1828,6 @@ class SQLCompiler(Compiled): result_expr = _CompileLabel( col_expr, name, alt_names=(column._key_label,) ) - elif ( asfrom and isinstance(column, elements.ColumnClause) @@ -1897,6 +1896,7 @@ class SQLCompiler(Compiled): """ cloned = {} column_translate = [{}] + created = set() def visit(element, **kw): if element in column_translate[-1]: @@ -1906,7 +1906,6 @@ class SQLCompiler(Compiled): return cloned[element] newelem = cloned[element] = element._clone() - if ( newelem._is_from_clause and newelem._is_join @@ -1921,6 +1920,8 @@ class SQLCompiler(Compiled): selectable_ = selectable.Select( [right.element], use_labels=True ).alias() + created.add(selectable_) + created.update(selectable_.c) for c in selectable_.c: c._key_label = c.key @@ -1971,6 +1972,11 @@ class SQLCompiler(Compiled): if barrier_select: column_translate.append({}) kw["transform_clue"] = "inside_select" + if not newelem._is_select_container: + froms = newelem.froms + newelem._raw_columns = list(newelem.selected_columns) + newelem._from_obj.update(froms) + newelem._reset_memoizations() newelem._copy_internals(clone=visit, **kw) if barrier_select: del column_translate[-1] @@ -1984,17 +1990,33 @@ class SQLCompiler(Compiled): def _transform_result_map_for_nested_joins( self, select, transformed_select ): - inner_col = dict( - (c._key_label, c) for c in transformed_select.inner_columns - ) - - d = dict((inner_col[c._key_label], c) for c in select.inner_columns) - - self._result_columns = [ - (key, name, tuple([d.get(col, col) for col in objs]), typ) - for key, name, objs, typ in self._result_columns + self._result_columns[:] = [ + result_rec + if col is tcol + else ( + result_rec[0], + name, + tuple([col if obj is tcol else obj for obj in result_rec[2]]), + result_rec[3], + ) + for result_rec, (name, col), (tname, tcol) in zip( + self._result_columns, + select._columns_plus_names, + transformed_select._columns_plus_names, + ) ] + # TODO: it's not anticipated that we need to correct anon_map + # however if we do, this is what it looks like: + # for (name, col), (tname, tcol) in zip( + # select._columns_plus_names, + # transformed_select._columns_plus_names, + # ): + # if isinstance(name, elements._anonymous_label) and name != tname: + # m1 = re.match(r"^%\((\d+ .+?)\)s$", name) + # m2 = re.match(r"^%\((\d+ .+?)\)s$", tname) + # self.anon_map[m1.group(1)] = self.anon_map[m2.group(1)] + _default_stack_entry = util.immutabledict( [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] ) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 5a9be7c62..66e92a63c 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -48,8 +48,9 @@ class UpdateBase( named_with_column = False def _generate_fromclause_column_proxies(self, fromclause): - for col in self._returning: - col._make_proxy(fromclause) + fromclause._columns._populate_separate_keys( + col._make_proxy(fromclause) for col in self._returning + ) def _process_colparams(self, parameters): def process_single(p): diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 5b4442222..735a125d7 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -184,6 +184,7 @@ class ClauseElement(roles.SQLRole, Visitable): _is_returns_rows = False _is_text_clause = False _is_from_container = False + _is_select_container = False _is_select_statement = False _order_by_label_element = None @@ -856,8 +857,7 @@ class ColumnElement( co._proxies = [self] if selectable._is_clone_of is not None: co._is_clone_of = selectable._is_clone_of.columns.get(key) - selectable._columns[key] = co - return co + return key, co def cast(self, type_): """Produce a type cast, i.e. ``CAST(<expression> AS <type>)``. @@ -887,6 +887,12 @@ class ColumnElement( """ return Label(name, self, self.type) + def _anon_label(self, seed): + while self._is_clone_of is not None: + self = self._is_clone_of + + return _anonymous_label("%%(%d %s)s" % (id(self), seed or "anon")) + @util.memoized_property def anon_label(self): """provides a constant 'anonymous label' for this ColumnElement. @@ -901,12 +907,11 @@ class ColumnElement( expressions and function calls. """ - while self._is_clone_of is not None: - self = self._is_clone_of + return self._anon_label(getattr(self, "name", None)) - return _anonymous_label( - "%%(%d %s)s" % (id(self), getattr(self, "name", "anon")) - ) + @util.memoized_property + def _label_anon_label(self): + return self._anon_label(getattr(self, "_label", None)) class BindParameter(roles.InElementRole, ColumnElement): @@ -3951,7 +3956,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): return self.element._from_objects def _make_proxy(self, selectable, name=None, **kw): - e = self.element._make_proxy( + key, e = self.element._make_proxy( selectable, name=name if name else self.name, disallow_is_literal=True, @@ -3959,7 +3964,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): e._proxies.append(self) if self._type is not None: e.type = self._type - return e + return key, e class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): @@ -4214,7 +4219,6 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): self, selectable, name=None, - attach=True, name_is_truncatable=False, disallow_is_literal=False, **kw @@ -4249,9 +4253,7 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): if selectable._is_clone_of is not None: c._is_clone_of = selectable._is_clone_of.columns.get(c.key) - if attach: - selectable._columns[c.key] = c - return c + return c.key, c class CollationClause(ColumnElement): diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 04fb16a80..2feb6fd5f 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -118,7 +118,8 @@ class FunctionElement(Executable, ColumnElement, FromClause): """ - return ColumnCollection(self.label(None)) + col = self.label(None) + return ColumnCollection(columns=[(col.key, col)]) @util.memoized_property def clauses(self): diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index d39bc9832..23c58dc4e 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -40,7 +40,7 @@ from . import roles from . import type_api from . import visitors from .base import _bind_or_error -from .base import ColumnCollection +from .base import DedupeColumnCollection from .base import DialectKWArgs from .base import SchemaEventTarget from .coercions import _document_text_coercion @@ -538,7 +538,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): self.indexes = set() self.constraints = set() - self._columns = ColumnCollection() + self._columns = DedupeColumnCollection() PrimaryKeyConstraint( _implicit_generated=True )._set_parent_with_dispatch(self) @@ -1607,13 +1607,13 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): ) c.table = selectable - selectable._columns.add(c) if selectable._is_clone_of is not None: c._is_clone_of = selectable._is_clone_of.columns[c.key] if self.primary_key: selectable.primary_key.add(c) - c.dispatch.after_parent_attach(c, selectable) - return c + if fk: + selectable.foreign_keys.update(fk) + return c.key, c def get_children(self, schema_visitor=False, **kwargs): if schema_visitor: @@ -1983,19 +1983,20 @@ class ForeignKey(DialectKWArgs, SchemaItem): self._set_target_column(_column) def _set_target_column(self, column): + assert isinstance(self.parent.table, Table) + # propagate TypeEngine to parent if it didn't have one if self.parent.type._isnull: self.parent.type = column.type # super-edgy case, if other FKs point to our column, # they'd get the type propagated out also. - if isinstance(self.parent.table, Table): - def set_type(fk): - if fk.parent.type._isnull: - fk.parent.type = column.type + def set_type(fk): + if fk.parent.type._isnull: + fk.parent.type = column.type - self.parent._setup_on_memoized_fks(set_type) + self.parent._setup_on_memoized_fks(set_type) self.column = column @@ -2072,7 +2073,8 @@ class ForeignKey(DialectKWArgs, SchemaItem): def _set_table(self, column, table): # standalone ForeignKey - create ForeignKeyConstraint # on the hosting Table when attached to the Table. - if self.constraint is None and isinstance(table, Table): + assert isinstance(table, Table) + if self.constraint is None: self.constraint = ForeignKeyConstraint( [], [], @@ -2088,7 +2090,6 @@ class ForeignKey(DialectKWArgs, SchemaItem): self.constraint._append_element(column, self) self.constraint._set_parent_with_dispatch(table) table.foreign_keys.add(self) - # set up remote ".column" attribute, or a note to pick it # up when the other Table/Column shows up if isinstance(self._colspec, util.string_types): @@ -2760,7 +2761,7 @@ class ColumnCollectionMixin(object): def __init__(self, *columns, **kw): _autoattach = kw.pop("_autoattach", True) self._column_flag = kw.pop("_column_flag", False) - self.columns = ColumnCollection() + self.columns = DedupeColumnCollection() self._pending_colargs = [ _to_schema_column_or_string(c) for c in columns ] @@ -2885,14 +2886,10 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): return self.columns.contains_column(col) def __iter__(self): - # inlining of - # return iter(self.columns) - # ColumnCollection->OrderedProperties->OrderedDict - ordered_dict = self.columns._data - return (ordered_dict[key] for key in ordered_dict._list) + return iter(self.columns) def __len__(self): - return len(self.columns._data) + return len(self.columns) class CheckConstraint(ColumnCollectionConstraint): @@ -3368,11 +3365,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): table.constraints.add(self) table_pks = [c for c in table.c if c.primary_key] - if ( - self.columns - and table_pks - and set(table_pks) != set(self.columns.values()) - ): + if self.columns and table_pks and set(table_pks) != set(self.columns): util.warn( "Table '%s' specifies columns %s as primary_key=True, " "not matching locally specified columns %s; setting the " @@ -3390,7 +3383,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): for c in self.columns: c.primary_key = True c.nullable = False - self.columns.extend(table_pks) + if table_pks: + self.columns.extend(table_pks) def _reload(self, columns): """repopulate this :class:`.PrimaryKeyConstraint` given diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index c93de8d73..10643c9e4 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -27,10 +27,10 @@ from .base import _from_objects from .base import _generative from .base import ColumnCollection from .base import ColumnSet +from .base import DedupeColumnCollection from .base import Executable from .base import Generative from .base import Immutable -from .base import SeparateKeyColumnCollection from .coercions import _document_text_coercion from .elements import _anonymous_label from .elements import _select_iterables @@ -534,8 +534,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self._memoized_property.expire_instance(self) def _generate_fromclause_column_proxies(self, fromclause): - for col in self.c: - col._make_proxy(fromclause) + fromclause._columns._populate_separate_keys( + col._make_proxy(fromclause) for col in self.c + ) @property def exported_columns(self): @@ -791,7 +792,9 @@ class Join(FromClause): (c for c in columns if c.primary_key), self.onclause ) ) - self._columns.update((col._key_label, col) for col in columns) + self._columns._populate_separate_keys( + (col._key_label, col) for col in columns + ) self.foreign_keys.update( itertools.chain(*[col.foreign_keys for col in columns]) ) @@ -1861,7 +1864,7 @@ class TableClause(Immutable, FromClause): super(TableClause, self).__init__() self.name = self.fullname = name - self._columns = ColumnCollection() + self._columns = DedupeColumnCollection() self.primary_key = ColumnSet() self.foreign_keys = set() for c in columns: @@ -1881,7 +1884,7 @@ class TableClause(Immutable, FromClause): return self.name.encode("ascii", "backslashreplace") def append_column(self, c): - self._columns[c.key] = c + self._columns.add(c) c.table = self def get_children(self, column_collections=True, **kwargs): @@ -2328,6 +2331,8 @@ class SelectStatementGrouping(GroupedElement, SelectBase): __visit_name__ = "grouping" + _is_select_container = True + def __init__(self, element): # type: (SelectBase) self.element = coercions.expect(roles.SelectStatementRole, element) @@ -2526,6 +2531,7 @@ class GenerativeSelect(SelectBase): FROM clauses to produce a unique set of column names regardless of name conflicts among the individual FROM clauses. + """ self.use_labels = True @@ -4081,6 +4087,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ names = set() + cols = _select_iterables(self._raw_columns) + def name_for_col(c): # we use key_label since this name is intended for targeting # within the ColumnCollection only, it's not related to SQL @@ -4090,18 +4098,22 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): else: name = c._proxy_key if name in names: - name = c.anon_label + if self.use_labels: + name = c._label_anon_label + else: + name = c.anon_label else: names.add(name) return name - return SeparateKeyColumnCollection( - (name_for_col(c), c) - for c in util.unique_list(_select_iterables(self._raw_columns)) + return ColumnCollection( + (name_for_col(c), c) for c in cols ).as_immutable() @_memoized_property def _columns_plus_names(self): + cols = _select_iterables(self._raw_columns) + if self.use_labels: names = set() @@ -4111,23 +4123,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): name = c._label if name in names: - name = c.anon_label + name = c._label_anon_label else: names.add(name) return name, c - return [ - name_for_col(c) - for c in util.unique_list(_select_iterables(self._raw_columns)) - ] + return [name_for_col(c) for c in cols] else: - return [ - (None, c) - for c in util.unique_list(_select_iterables(self._raw_columns)) - ] + return [(None, c) for c in cols] def _generate_fromclause_column_proxies(self, subquery): keys_seen = set() + prox = [] for name, c in self._columns_plus_names: if not hasattr(c, "_make_proxy"): @@ -4137,14 +4144,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): elif self.use_labels: key = c._key_label if key is not None and key in keys_seen: - key = c.anon_label + key = c._label_anon_label keys_seen.add(key) else: key = None - - c._make_proxy( - subquery, key=key, name=name, name_is_truncatable=True + prox.append( + c._make_proxy( + subquery, key=key, name=name, name_is_truncatable=True + ) ) + subquery._columns._populate_separate_keys(prox) def _needs_parens_for_grouping(self): return ( @@ -4397,7 +4406,9 @@ class TextualSelect(SelectBase): .. versionadded:: 1.4 """ - return ColumnCollection(*self.column_args).as_immutable() + return ColumnCollection( + (c.key, c) for c in self.column_args + ).as_immutable() @property def _bind(self): @@ -4408,8 +4419,9 @@ class TextualSelect(SelectBase): self.element = self.element.bindparams(*binds, **bind_as_values) def _generate_fromclause_column_proxies(self, fromclause): - for c in self.column_args: - c._make_proxy(fromclause) + fromclause._columns._populate_separate_keys( + c._make_proxy(fromclause) for c in self.column_args + ) def _copy_internals(self, clone=_clone, **kw): self._reset_memoizations() |