summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/base.py447
-rw-r--r--lib/sqlalchemy/sql/compiler.py44
-rw-r--r--lib/sqlalchemy/sql/dml.py5
-rw-r--r--lib/sqlalchemy/sql/elements.py28
-rw-r--r--lib/sqlalchemy/sql/functions.py3
-rw-r--r--lib/sqlalchemy/sql/schema.py44
-rw-r--r--lib/sqlalchemy/sql/selectable.py64
7 files changed, 426 insertions, 209 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index a84843c4b..da384bdab 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -499,26 +499,209 @@ class SchemaVisitor(ClauseVisitor):
__traverse_options__ = {"schema_visitor": True}
-class ColumnCollection(util.OrderedProperties):
- """An ordered dictionary that stores a list of ColumnElement
- instances.
+class ColumnCollection(object):
+ """Collection of :class:`.ColumnElement` instances, typically for
+ selectables.
+
+ The :class:`.ColumnCollection` has both mapping- and sequence- like
+ behaviors. A :class:`.ColumnCollection` usually stores :class:`.Column`
+ objects, which are then accessible both via mapping style access as well
+ as attribute access style. The name for which a :class:`.Column` would
+ be present is normally that of the :paramref:`.Column.key` parameter,
+ however depending on the context, it may be stored under a special label
+ name::
+
+ >>> from sqlalchemy import Column, Integer
+ >>> from sqlalchemy.sql import ColumnCollection
+ >>> x, y = Column('x', Integer), Column('y', Integer)
+ >>> cc = ColumnCollection(columns=[x, y])
+ >>> cc.x
+ Column('x', Integer(), table=None)
+ >>> cc.y
+ Column('y', Integer(), table=None)
+ >>> cc['x']
+ Column('x', Integer(), table=None)
+ >>> cc['y']
+
+ :class`.ColumnCollection` also indexes the columns in order and allows
+ them to be accessible by their integer position::
+
+ >>> cc[0]
+ Column('x', Integer(), table=None)
+ >>> cc[1]
+ Column('y', Integer(), table=None)
+
+ .. versionadded:: 1.4 :class:`.ColumnCollection` allows integer-based
+ index access to the collection.
+
+ Iterating the collection yields the column expressions in order::
+
+ >>> list(cc)
+ [Column('x', Integer(), table=None),
+ Column('y', Integer(), table=None)]
+
+ The base :class:`.ColumnCollection` object can store duplicates, which can
+ mean either two columns with the same key, in which case the column
+ returned by key access is **arbitrary**::
+
+ >>> x1, x2 = Column('x', Integer), Column('x', Integer)
+ >>> cc = ColumnCollection(columns=[x1, x2])
+ >>> list(cc)
+ [Column('x', Integer(), table=None),
+ Column('x', Integer(), table=None)]
+ >>> cc['x'] is x1
+ False
+ >>> cc['x'] is x2
+ True
+
+ Or it can also mean the same column multiple times. These cases are
+ supported as :class:`.ColumnCollection` is used to represent the columns in
+ a SELECT statement which may include duplicates.
+
+ A special subclass :class:`.DedupeColumnCollection` exists which instead
+ maintains SQLAlchemy's older behavior of not allowing duplicates; this
+ collection is used for schema level objects like :class:`.Table` and
+ :class:`.PrimaryKeyConstraint` where this deduping is helpful. The
+ :class:`.DedupeColumnCollection` class also has additional mutation methods
+ as the schema constructs have more use cases that require removal and
+ replacement of columns.
+
+ .. versionchanged:: 1.4 :class:`.ColumnCollection` now stores duplicate
+ column keys as well as the same column in multiple positions. The
+ :class:`.DedupeColumnCollection` class is added to maintain the
+ former behavior in those cases where deduplication as well as
+ additional replace/remove operations are needed.
- Overrides the ``__eq__()`` method to produce SQL clauses between
- sets of correlated columns.
"""
- __slots__ = "_all_columns"
+ __slots__ = "_collection", "_index", "_colset"
- def __init__(self, *columns):
- super(ColumnCollection, self).__init__()
- object.__setattr__(self, "_all_columns", [])
- for c in columns:
- self.add(c)
+ def __init__(self, columns=None):
+ object.__setattr__(self, "_colset", set())
+ object.__setattr__(self, "_index", {})
+ object.__setattr__(self, "_collection", [])
+ if columns:
+ self._initial_populate(columns)
+
+ def _initial_populate(self, iter_):
+ self._populate_separate_keys(iter_)
+
+ @property
+ def _all_columns(self):
+ return [col for (k, col) in self._collection]
+
+ def keys(self):
+ return [k for (k, col) in self._collection]
+
+ def __len__(self):
+ return len(self._collection)
+
+ def __iter__(self):
+ # turn to a list first to maintain over a course of changes
+ return iter([col for k, col in self._collection])
+
+ def __getitem__(self, key):
+ try:
+ return self._index[key]
+ except KeyError:
+ if isinstance(key, util.int_types):
+ raise IndexError(key)
+ else:
+ raise
+
+ def __getattr__(self, key):
+ try:
+ return self._index[key]
+ except KeyError:
+ raise AttributeError(key)
+
+ def __contains__(self, key):
+ if key not in self._index:
+ if not isinstance(key, util.string_types):
+ raise exc.ArgumentError(
+ "__contains__ requires a string argument"
+ )
+ return False
+ else:
+ return True
+
+ def compare(self, other):
+ for l, r in util.zip_longest(self, other):
+ if l is not r:
+ return False
+ else:
+ return True
+
+ def __eq__(self, other):
+ return self.compare(other)
+
+ def get(self, key, default=None):
+ if key in self._index:
+ return self._index[key]
+ else:
+ return default
def __str__(self):
return repr([str(c) for c in self])
+ def __setitem__(self, key, value):
+ raise NotImplementedError()
+
+ def __delitem__(self, key):
+ raise NotImplementedError()
+
+ def __setattr__(self, key, obj):
+ raise NotImplementedError()
+
+ def clear(self):
+ raise NotImplementedError()
+
+ def remove(self, column):
+ raise NotImplementedError()
+
+ def update(self, iter_):
+ raise NotImplementedError()
+
+ __hash__ = None
+
+ def _populate_separate_keys(self, iter_):
+ """populate from an iterator of (key, column)"""
+ cols = list(iter_)
+ self._collection[:] = cols
+ self._colset.update(c for k, c in self._collection)
+ self._index.update(
+ (idx, c) for idx, (k, c) in enumerate(self._collection)
+ )
+ self._index.update({k: col for k, col in reversed(self._collection)})
+
+ def add(self, column, key=None):
+ if key is None:
+ key = column.key
+
+ l = len(self._collection)
+ self._collection.append((key, column))
+ self._colset.add(column)
+ self._index[l] = column
+ if key not in self._index:
+ self._index[key] = column
+
+ def __getstate__(self):
+ return {"_collection": self._collection, "_index": self._index}
+
+ def __setstate__(self, state):
+ object.__setattr__(self, "_index", state["_index"])
+ object.__setattr__(self, "_collection", state["_collection"])
+ object.__setattr__(
+ self, "_colset", {col for k, col in self._collection}
+ )
+
+ def contains_column(self, col):
+ return col in self._colset
+
+ def as_immutable(self):
+ return ImmutableColumnCollection(self)
+
def corresponding_column(self, column, require_embedded=False):
"""Given a :class:`.ColumnElement`, return the exported
:class:`.ColumnElement` object from this :class:`.ColumnCollection`
@@ -554,11 +737,11 @@ class ColumnCollection(util.OrderedProperties):
return True
# don't dig around if the column is locally present
- if self.contains_column(column):
+ if column in self._colset:
return column
col, intersect = None, None
target_set = column.proxy_set
- cols = self._all_columns
+ cols = [c for (k, c) in self._collection]
for c in cols:
expanded_proxy_set = set(_expand_cloned(c.proxy_set))
i = target_set.intersection(expanded_proxy_set)
@@ -610,165 +793,167 @@ class ColumnCollection(util.OrderedProperties):
col, intersect = c, i
return col
- def replace(self, column):
- """add the given column to this collection, removing unaliased
- versions of this column as well as existing columns with the
- same key.
- e.g.::
+class DedupeColumnCollection(ColumnCollection):
+ """A :class:`.ColumnCollection that maintains deduplicating behavior.
- t = Table('sometable', metadata, Column('col1', Integer))
- t.columns.replace(Column('col1', Integer, key='columnone'))
+ This is useful by schema level objects such as :class:`.Table` and
+ :class:`.PrimaryKeyConstraint`. The collection includes more
+ sophisticated mutator methods as well to suit schema objects which
+ require mutable column collections.
- will remove the original 'col1' from the collection, and add
- the new column under the name 'columnname'.
-
- Used by schema.Column to override columns during table reflection.
-
- """
- remove_col = None
- if column.name in self and column.key != column.name:
- other = self[column.name]
- if other.name == other.key:
- remove_col = other
- del self._data[other.key]
-
- if column.key in self._data:
- remove_col = self._data[column.key]
-
- self._data[column.key] = column
- if remove_col is not None:
- self._all_columns[:] = [
- column if c is remove_col else c for c in self._all_columns
- ]
- else:
- self._all_columns.append(column)
+ .. versionadded: 1.4
- def add(self, column):
- """Add a column to this collection.
+ """
- The key attribute of the column will be used as the hash key
- for this dictionary.
+ def add(self, column, key=None):
+ if key is not None and column.key != key:
+ raise exc.ArgumentError(
+ "DedupeColumnCollection requires columns be under "
+ "the same key as their .key"
+ )
+ key = column.key
- """
- if not column.key:
+ if key is None:
raise exc.ArgumentError(
"Can't add unnamed column to column collection"
)
- self[column.key] = column
-
- def __delitem__(self, key):
- raise NotImplementedError()
-
- def __setattr__(self, key, obj):
- raise NotImplementedError()
- def __setitem__(self, key, value):
- if key in self:
+ if key in self._index:
- # this warning is primarily to catch select() statements
- # which have conflicting column names in their exported
- # columns collection
+ existing = self._index[key]
- existing = self[key]
-
- if existing is value:
+ if existing is column:
return
- if not existing.shares_lineage(value):
- util.warn(
- "Column %r on table %r being replaced by "
- "%r, which has the same key. Consider "
- "use_labels for select() statements."
- % (key, getattr(existing, "table", None), value)
- )
+ self.replace(column)
# pop out memoized proxy_set as this
# operation may very well be occurring
# in a _make_proxy operation
- util.memoized_property.reset(value, "proxy_set")
+ util.memoized_property.reset(column, "proxy_set")
+ else:
+ l = len(self._collection)
+ self._collection.append((key, column))
+ self._colset.add(column)
+ self._index[l] = column
+ self._index[key] = column
+
+ def _populate_separate_keys(self, iter_):
+ """populate from an iterator of (key, column)"""
+ cols = list(iter_)
- self._all_columns.append(value)
- self._data[key] = value
+ replace_col = []
+ for k, col in cols:
+ if col.key != k:
+ raise exc.ArgumentError(
+ "DedupeColumnCollection requires columns be under "
+ "the same key as their .key"
+ )
+ if col.name in self._index and col.key != col.name:
+ replace_col.append(col)
+ elif col.key in self._index:
+ replace_col.append(col)
+ else:
+ self._index[k] = col
+ self._collection.append((k, col))
+ self._colset.update(c for (k, c) in self._collection)
+ self._index.update(
+ (idx, c) for idx, (k, c) in enumerate(self._collection)
+ )
+ for col in replace_col:
+ self.replace(col)
- def clear(self):
- raise NotImplementedError()
+ def extend(self, iter_):
+ self._populate_separate_keys((col.key, col) for col in iter_)
def remove(self, column):
- del self._data[column.key]
- self._all_columns[:] = [
- c for c in self._all_columns if c is not column
+ if column not in self._colset:
+ raise ValueError(
+ "Can't remove column %r; column is not in this collection"
+ % column
+ )
+ del self._index[column.key]
+ self._colset.remove(column)
+ self._collection[:] = [
+ (k, c) for (k, c) in self._collection if c is not column
]
-
- def update(self, iter_):
- cols = list(iter_)
- all_col_set = set(self._all_columns)
- self._all_columns.extend(
- c for label, c in cols if c not in all_col_set
+ self._index.update(
+ {idx: col for idx, (k, col) in enumerate(self._collection)}
)
- self._data.update((label, c) for label, c in cols)
+ # delete higher index
+ del self._index[len(self._collection)]
- def extend(self, iter_):
- cols = list(iter_)
- all_col_set = set(self._all_columns)
- self._all_columns.extend(c for c in cols if c not in all_col_set)
- self._data.update((c.key, c) for c in cols)
+ def replace(self, column):
+ """add the given column to this collection, removing unaliased
+ versions of this column as well as existing columns with the
+ same key.
- __hash__ = None
+ e.g.::
- def __eq__(self, other):
- l = []
- for c in getattr(other, "_all_columns", other):
- for local in self._all_columns:
- if c.shares_lineage(local):
- l.append(c == local)
- return elements.and_(*l)
+ t = Table('sometable', metadata, Column('col1', Integer))
+ t.columns.replace(Column('col1', Integer, key='columnone'))
- def __contains__(self, other):
- if not isinstance(other, util.string_types):
- raise exc.ArgumentError("__contains__ requires a string argument")
- return util.OrderedProperties.__contains__(self, other)
+ will remove the original 'col1' from the collection, and add
+ the new column under the name 'columnname'.
- def __getstate__(self):
- return {"_data": self._data, "_all_columns": self._all_columns}
+ Used by schema.Column to override columns during table reflection.
- def __setstate__(self, state):
- object.__setattr__(self, "_data", state["_data"])
- object.__setattr__(self, "_all_columns", state["_all_columns"])
+ """
- def contains_column(self, col):
- return col in set(self._all_columns)
+ remove_col = set()
+ # remove up to two columns based on matches of name as well as key
+ if column.name in self._index and column.key != column.name:
+ other = self._index[column.name]
+ if other.name == other.key:
+ remove_col.add(other)
+
+ if column.key in self._index:
+ remove_col.add(self._index[column.key])
+
+ new_cols = []
+ replaced = False
+ for k, col in self._collection:
+ if col in remove_col:
+ if not replaced:
+ replaced = True
+ new_cols.append((column.key, column))
+ else:
+ new_cols.append((k, col))
- def as_immutable(self):
- return ImmutableColumnCollection(self._data, self._all_columns)
+ if remove_col:
+ self._colset.difference_update(remove_col)
+ if not replaced:
+ new_cols.append((column.key, column))
-class SeparateKeyColumnCollection(ColumnCollection):
- """Column collection that maintains a string name separate from the
- column itself"""
+ self._colset.add(column)
+ self._collection[:] = new_cols
- def __init__(self, cols_plus_names=None):
- super(ColumnCollection, self).__init__()
- object.__setattr__(self, "_all_columns", [])
- if cols_plus_names:
- self.update(cols_plus_names)
+ self._index.clear()
+ self._index.update(
+ {idx: col for idx, (k, col) in enumerate(self._collection)}
+ )
+ self._index.update(self._collection)
- def replace(self, column):
- raise NotImplementedError()
- def add(self, column):
- raise NotImplementedError()
+class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
+ __slots__ = ("_parent",)
- def remove(self, column):
- raise NotImplementedError()
+ def __init__(self, collection):
+ object.__setattr__(self, "_parent", collection)
+ object.__setattr__(self, "_colset", collection._colset)
+ object.__setattr__(self, "_index", collection._index)
+ object.__setattr__(self, "_collection", collection._collection)
+ def __getstate__(self):
+ return {"_parent": self._parent}
-class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
- def __init__(self, data, all_columns):
- util.ImmutableProperties.__init__(self, data)
- object.__setattr__(self, "_all_columns", all_columns)
+ def __setstate__(self, state):
+ parent = state["_parent"]
+ self.__init__(parent)
- extend = remove = util.ImmutableProperties._immutable
+ add = extend = remove = util.ImmutableContainer._immutable
class ColumnSet(util.ordered_column_set):
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 13219ee68..ea7e890e7 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1828,7 +1828,6 @@ class SQLCompiler(Compiled):
result_expr = _CompileLabel(
col_expr, name, alt_names=(column._key_label,)
)
-
elif (
asfrom
and isinstance(column, elements.ColumnClause)
@@ -1897,6 +1896,7 @@ class SQLCompiler(Compiled):
"""
cloned = {}
column_translate = [{}]
+ created = set()
def visit(element, **kw):
if element in column_translate[-1]:
@@ -1906,7 +1906,6 @@ class SQLCompiler(Compiled):
return cloned[element]
newelem = cloned[element] = element._clone()
-
if (
newelem._is_from_clause
and newelem._is_join
@@ -1921,6 +1920,8 @@ class SQLCompiler(Compiled):
selectable_ = selectable.Select(
[right.element], use_labels=True
).alias()
+ created.add(selectable_)
+ created.update(selectable_.c)
for c in selectable_.c:
c._key_label = c.key
@@ -1971,6 +1972,11 @@ class SQLCompiler(Compiled):
if barrier_select:
column_translate.append({})
kw["transform_clue"] = "inside_select"
+ if not newelem._is_select_container:
+ froms = newelem.froms
+ newelem._raw_columns = list(newelem.selected_columns)
+ newelem._from_obj.update(froms)
+ newelem._reset_memoizations()
newelem._copy_internals(clone=visit, **kw)
if barrier_select:
del column_translate[-1]
@@ -1984,17 +1990,33 @@ class SQLCompiler(Compiled):
def _transform_result_map_for_nested_joins(
self, select, transformed_select
):
- inner_col = dict(
- (c._key_label, c) for c in transformed_select.inner_columns
- )
-
- d = dict((inner_col[c._key_label], c) for c in select.inner_columns)
-
- self._result_columns = [
- (key, name, tuple([d.get(col, col) for col in objs]), typ)
- for key, name, objs, typ in self._result_columns
+ self._result_columns[:] = [
+ result_rec
+ if col is tcol
+ else (
+ result_rec[0],
+ name,
+ tuple([col if obj is tcol else obj for obj in result_rec[2]]),
+ result_rec[3],
+ )
+ for result_rec, (name, col), (tname, tcol) in zip(
+ self._result_columns,
+ select._columns_plus_names,
+ transformed_select._columns_plus_names,
+ )
]
+ # TODO: it's not anticipated that we need to correct anon_map
+ # however if we do, this is what it looks like:
+ # for (name, col), (tname, tcol) in zip(
+ # select._columns_plus_names,
+ # transformed_select._columns_plus_names,
+ # ):
+ # if isinstance(name, elements._anonymous_label) and name != tname:
+ # m1 = re.match(r"^%\((\d+ .+?)\)s$", name)
+ # m2 = re.match(r"^%\((\d+ .+?)\)s$", tname)
+ # self.anon_map[m1.group(1)] = self.anon_map[m2.group(1)]
+
_default_stack_entry = util.immutabledict(
[("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
)
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 5a9be7c62..66e92a63c 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -48,8 +48,9 @@ class UpdateBase(
named_with_column = False
def _generate_fromclause_column_proxies(self, fromclause):
- for col in self._returning:
- col._make_proxy(fromclause)
+ fromclause._columns._populate_separate_keys(
+ col._make_proxy(fromclause) for col in self._returning
+ )
def _process_colparams(self, parameters):
def process_single(p):
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 5b4442222..735a125d7 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -184,6 +184,7 @@ class ClauseElement(roles.SQLRole, Visitable):
_is_returns_rows = False
_is_text_clause = False
_is_from_container = False
+ _is_select_container = False
_is_select_statement = False
_order_by_label_element = None
@@ -856,8 +857,7 @@ class ColumnElement(
co._proxies = [self]
if selectable._is_clone_of is not None:
co._is_clone_of = selectable._is_clone_of.columns.get(key)
- selectable._columns[key] = co
- return co
+ return key, co
def cast(self, type_):
"""Produce a type cast, i.e. ``CAST(<expression> AS <type>)``.
@@ -887,6 +887,12 @@ class ColumnElement(
"""
return Label(name, self, self.type)
+ def _anon_label(self, seed):
+ while self._is_clone_of is not None:
+ self = self._is_clone_of
+
+ return _anonymous_label("%%(%d %s)s" % (id(self), seed or "anon"))
+
@util.memoized_property
def anon_label(self):
"""provides a constant 'anonymous label' for this ColumnElement.
@@ -901,12 +907,11 @@ class ColumnElement(
expressions and function calls.
"""
- while self._is_clone_of is not None:
- self = self._is_clone_of
+ return self._anon_label(getattr(self, "name", None))
- return _anonymous_label(
- "%%(%d %s)s" % (id(self), getattr(self, "name", "anon"))
- )
+ @util.memoized_property
+ def _label_anon_label(self):
+ return self._anon_label(getattr(self, "_label", None))
class BindParameter(roles.InElementRole, ColumnElement):
@@ -3951,7 +3956,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
return self.element._from_objects
def _make_proxy(self, selectable, name=None, **kw):
- e = self.element._make_proxy(
+ key, e = self.element._make_proxy(
selectable,
name=name if name else self.name,
disallow_is_literal=True,
@@ -3959,7 +3964,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
e._proxies.append(self)
if self._type is not None:
e.type = self._type
- return e
+ return key, e
class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
@@ -4214,7 +4219,6 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
self,
selectable,
name=None,
- attach=True,
name_is_truncatable=False,
disallow_is_literal=False,
**kw
@@ -4249,9 +4253,7 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
if selectable._is_clone_of is not None:
c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
- if attach:
- selectable._columns[c.key] = c
- return c
+ return c.key, c
class CollationClause(ColumnElement):
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 04fb16a80..2feb6fd5f 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -118,7 +118,8 @@ class FunctionElement(Executable, ColumnElement, FromClause):
"""
- return ColumnCollection(self.label(None))
+ col = self.label(None)
+ return ColumnCollection(columns=[(col.key, col)])
@util.memoized_property
def clauses(self):
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index d39bc9832..23c58dc4e 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -40,7 +40,7 @@ from . import roles
from . import type_api
from . import visitors
from .base import _bind_or_error
-from .base import ColumnCollection
+from .base import DedupeColumnCollection
from .base import DialectKWArgs
from .base import SchemaEventTarget
from .coercions import _document_text_coercion
@@ -538,7 +538,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
self.indexes = set()
self.constraints = set()
- self._columns = ColumnCollection()
+ self._columns = DedupeColumnCollection()
PrimaryKeyConstraint(
_implicit_generated=True
)._set_parent_with_dispatch(self)
@@ -1607,13 +1607,13 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
)
c.table = selectable
- selectable._columns.add(c)
if selectable._is_clone_of is not None:
c._is_clone_of = selectable._is_clone_of.columns[c.key]
if self.primary_key:
selectable.primary_key.add(c)
- c.dispatch.after_parent_attach(c, selectable)
- return c
+ if fk:
+ selectable.foreign_keys.update(fk)
+ return c.key, c
def get_children(self, schema_visitor=False, **kwargs):
if schema_visitor:
@@ -1983,19 +1983,20 @@ class ForeignKey(DialectKWArgs, SchemaItem):
self._set_target_column(_column)
def _set_target_column(self, column):
+ assert isinstance(self.parent.table, Table)
+
# propagate TypeEngine to parent if it didn't have one
if self.parent.type._isnull:
self.parent.type = column.type
# super-edgy case, if other FKs point to our column,
# they'd get the type propagated out also.
- if isinstance(self.parent.table, Table):
- def set_type(fk):
- if fk.parent.type._isnull:
- fk.parent.type = column.type
+ def set_type(fk):
+ if fk.parent.type._isnull:
+ fk.parent.type = column.type
- self.parent._setup_on_memoized_fks(set_type)
+ self.parent._setup_on_memoized_fks(set_type)
self.column = column
@@ -2072,7 +2073,8 @@ class ForeignKey(DialectKWArgs, SchemaItem):
def _set_table(self, column, table):
# standalone ForeignKey - create ForeignKeyConstraint
# on the hosting Table when attached to the Table.
- if self.constraint is None and isinstance(table, Table):
+ assert isinstance(table, Table)
+ if self.constraint is None:
self.constraint = ForeignKeyConstraint(
[],
[],
@@ -2088,7 +2090,6 @@ class ForeignKey(DialectKWArgs, SchemaItem):
self.constraint._append_element(column, self)
self.constraint._set_parent_with_dispatch(table)
table.foreign_keys.add(self)
-
# set up remote ".column" attribute, or a note to pick it
# up when the other Table/Column shows up
if isinstance(self._colspec, util.string_types):
@@ -2760,7 +2761,7 @@ class ColumnCollectionMixin(object):
def __init__(self, *columns, **kw):
_autoattach = kw.pop("_autoattach", True)
self._column_flag = kw.pop("_column_flag", False)
- self.columns = ColumnCollection()
+ self.columns = DedupeColumnCollection()
self._pending_colargs = [
_to_schema_column_or_string(c) for c in columns
]
@@ -2885,14 +2886,10 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
return self.columns.contains_column(col)
def __iter__(self):
- # inlining of
- # return iter(self.columns)
- # ColumnCollection->OrderedProperties->OrderedDict
- ordered_dict = self.columns._data
- return (ordered_dict[key] for key in ordered_dict._list)
+ return iter(self.columns)
def __len__(self):
- return len(self.columns._data)
+ return len(self.columns)
class CheckConstraint(ColumnCollectionConstraint):
@@ -3368,11 +3365,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
table.constraints.add(self)
table_pks = [c for c in table.c if c.primary_key]
- if (
- self.columns
- and table_pks
- and set(table_pks) != set(self.columns.values())
- ):
+ if self.columns and table_pks and set(table_pks) != set(self.columns):
util.warn(
"Table '%s' specifies columns %s as primary_key=True, "
"not matching locally specified columns %s; setting the "
@@ -3390,7 +3383,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
for c in self.columns:
c.primary_key = True
c.nullable = False
- self.columns.extend(table_pks)
+ if table_pks:
+ self.columns.extend(table_pks)
def _reload(self, columns):
"""repopulate this :class:`.PrimaryKeyConstraint` given
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index c93de8d73..10643c9e4 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -27,10 +27,10 @@ from .base import _from_objects
from .base import _generative
from .base import ColumnCollection
from .base import ColumnSet
+from .base import DedupeColumnCollection
from .base import Executable
from .base import Generative
from .base import Immutable
-from .base import SeparateKeyColumnCollection
from .coercions import _document_text_coercion
from .elements import _anonymous_label
from .elements import _select_iterables
@@ -534,8 +534,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
self._memoized_property.expire_instance(self)
def _generate_fromclause_column_proxies(self, fromclause):
- for col in self.c:
- col._make_proxy(fromclause)
+ fromclause._columns._populate_separate_keys(
+ col._make_proxy(fromclause) for col in self.c
+ )
@property
def exported_columns(self):
@@ -791,7 +792,9 @@ class Join(FromClause):
(c for c in columns if c.primary_key), self.onclause
)
)
- self._columns.update((col._key_label, col) for col in columns)
+ self._columns._populate_separate_keys(
+ (col._key_label, col) for col in columns
+ )
self.foreign_keys.update(
itertools.chain(*[col.foreign_keys for col in columns])
)
@@ -1861,7 +1864,7 @@ class TableClause(Immutable, FromClause):
super(TableClause, self).__init__()
self.name = self.fullname = name
- self._columns = ColumnCollection()
+ self._columns = DedupeColumnCollection()
self.primary_key = ColumnSet()
self.foreign_keys = set()
for c in columns:
@@ -1881,7 +1884,7 @@ class TableClause(Immutable, FromClause):
return self.name.encode("ascii", "backslashreplace")
def append_column(self, c):
- self._columns[c.key] = c
+ self._columns.add(c)
c.table = self
def get_children(self, column_collections=True, **kwargs):
@@ -2328,6 +2331,8 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
__visit_name__ = "grouping"
+ _is_select_container = True
+
def __init__(self, element):
# type: (SelectBase)
self.element = coercions.expect(roles.SelectStatementRole, element)
@@ -2526,6 +2531,7 @@ class GenerativeSelect(SelectBase):
FROM clauses to produce a unique set of column names regardless of
name conflicts among the individual FROM clauses.
+
"""
self.use_labels = True
@@ -4081,6 +4087,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
names = set()
+ cols = _select_iterables(self._raw_columns)
+
def name_for_col(c):
# we use key_label since this name is intended for targeting
# within the ColumnCollection only, it's not related to SQL
@@ -4090,18 +4098,22 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
else:
name = c._proxy_key
if name in names:
- name = c.anon_label
+ if self.use_labels:
+ name = c._label_anon_label
+ else:
+ name = c.anon_label
else:
names.add(name)
return name
- return SeparateKeyColumnCollection(
- (name_for_col(c), c)
- for c in util.unique_list(_select_iterables(self._raw_columns))
+ return ColumnCollection(
+ (name_for_col(c), c) for c in cols
).as_immutable()
@_memoized_property
def _columns_plus_names(self):
+ cols = _select_iterables(self._raw_columns)
+
if self.use_labels:
names = set()
@@ -4111,23 +4123,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
name = c._label
if name in names:
- name = c.anon_label
+ name = c._label_anon_label
else:
names.add(name)
return name, c
- return [
- name_for_col(c)
- for c in util.unique_list(_select_iterables(self._raw_columns))
- ]
+ return [name_for_col(c) for c in cols]
else:
- return [
- (None, c)
- for c in util.unique_list(_select_iterables(self._raw_columns))
- ]
+ return [(None, c) for c in cols]
def _generate_fromclause_column_proxies(self, subquery):
keys_seen = set()
+ prox = []
for name, c in self._columns_plus_names:
if not hasattr(c, "_make_proxy"):
@@ -4137,14 +4144,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
elif self.use_labels:
key = c._key_label
if key is not None and key in keys_seen:
- key = c.anon_label
+ key = c._label_anon_label
keys_seen.add(key)
else:
key = None
-
- c._make_proxy(
- subquery, key=key, name=name, name_is_truncatable=True
+ prox.append(
+ c._make_proxy(
+ subquery, key=key, name=name, name_is_truncatable=True
+ )
)
+ subquery._columns._populate_separate_keys(prox)
def _needs_parens_for_grouping(self):
return (
@@ -4397,7 +4406,9 @@ class TextualSelect(SelectBase):
.. versionadded:: 1.4
"""
- return ColumnCollection(*self.column_args).as_immutable()
+ return ColumnCollection(
+ (c.key, c) for c in self.column_args
+ ).as_immutable()
@property
def _bind(self):
@@ -4408,8 +4419,9 @@ class TextualSelect(SelectBase):
self.element = self.element.bindparams(*binds, **bind_as_values)
def _generate_fromclause_column_proxies(self, fromclause):
- for c in self.column_args:
- c._make_proxy(fromclause)
+ fromclause._columns._populate_separate_keys(
+ c._make_proxy(fromclause) for c in self.column_args
+ )
def _copy_internals(self, clone=_clone, **kw):
self._reset_memoizations()