summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/base.py244
-rw-r--r--lib/sqlalchemy/sql/cache_key.py2
-rw-r--r--lib/sqlalchemy/sql/elements.py52
-rw-r--r--lib/sqlalchemy/sql/schema.py7
-rw-r--r--lib/sqlalchemy/sql/selectable.py146
5 files changed, 327 insertions, 124 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index a8901c144..34b295113 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -13,8 +13,8 @@
from __future__ import annotations
+import collections
from enum import Enum
-from functools import reduce
import itertools
from itertools import zip_longest
import operator
@@ -280,10 +280,6 @@ def _expand_cloned(elements):
"""
# TODO: cython candidate
- # and/or change approach: in
- # https://gerrit.sqlalchemy.org/c/sqlalchemy/sqlalchemy/+/3712 we propose
- # getting rid of _cloned_set.
- # turning this into chain.from_iterable adds all kinds of callcount
return itertools.chain(*[x._cloned_set for x in elements])
@@ -1316,6 +1312,50 @@ _COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True)
_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]")
+class _ColumnMetrics(Generic[_COL_co]):
+ __slots__ = ("column",)
+
+ column: _COL_co
+
+ def __init__(
+ self, collection: ColumnCollection[Any, _COL_co], col: _COL_co
+ ):
+ self.column = col
+
+ # proxy_index being non-empty means it was initialized.
+ # so we need to update it
+ pi = collection._proxy_index
+ if pi:
+ for eps_col in col._expanded_proxy_set:
+ pi[eps_col].add(self)
+
+ def get_expanded_proxy_set(self):
+ return self.column._expanded_proxy_set
+
+ def dispose(self, collection):
+ pi = collection._proxy_index
+ if not pi:
+ return
+ for col in self.column._expanded_proxy_set:
+ colset = pi.get(col, None)
+ if colset:
+ colset.discard(self)
+ if colset is not None and not colset:
+ del pi[col]
+
+ def embedded(
+ self,
+ target_set: Union[
+ Set[ColumnElement[Any]], FrozenSet[ColumnElement[Any]]
+ ],
+ ) -> bool:
+ expanded_proxy_set = self.column._expanded_proxy_set
+ for t in target_set.difference(expanded_proxy_set):
+ if not expanded_proxy_set.intersection(_expand_cloned([t])):
+ return False
+ return True
+
+
class ColumnCollection(Generic[_COLKEY, _COL_co]):
"""Collection of :class:`_expression.ColumnElement` instances,
typically for
@@ -1425,10 +1465,11 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
"""
- __slots__ = "_collection", "_index", "_colset"
+ __slots__ = "_collection", "_index", "_colset", "_proxy_index"
- _collection: List[Tuple[_COLKEY, _COL_co]]
+ _collection: List[Tuple[_COLKEY, _COL_co, _ColumnMetrics[_COL_co]]]
_index: Dict[Union[None, str, int], Tuple[_COLKEY, _COL_co]]
+ _proxy_index: Dict[ColumnElement[Any], Set[_ColumnMetrics[_COL_co]]]
_colset: Set[_COL_co]
def __init__(
@@ -1436,6 +1477,9 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
):
object.__setattr__(self, "_colset", set())
object.__setattr__(self, "_index", {})
+ object.__setattr__(
+ self, "_proxy_index", collections.defaultdict(util.OrderedSet)
+ )
object.__setattr__(self, "_collection", [])
if columns:
self._initial_populate(columns)
@@ -1457,18 +1501,18 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
@property
def _all_columns(self) -> List[_COL_co]:
- return [col for (_, col) in self._collection]
+ return [col for (_, col, _) in self._collection]
def keys(self) -> List[_COLKEY]:
"""Return a sequence of string key names for all columns in this
collection."""
- return [k for (k, _) in self._collection]
+ return [k for (k, _, _) in self._collection]
def values(self) -> List[_COL_co]:
"""Return a sequence of :class:`_sql.ColumnClause` or
:class:`_schema.Column` objects for all columns in this
collection."""
- return [col for (_, col) in self._collection]
+ return [col for (_, col, _) in self._collection]
def items(self) -> List[Tuple[_COLKEY, _COL_co]]:
"""Return a sequence of (key, column) tuples for all columns in this
@@ -1477,7 +1521,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
:class:`_schema.Column` object.
"""
- return list(self._collection)
+ return [(k, col) for (k, col, _) in self._collection]
def __bool__(self) -> bool:
return bool(self._collection)
@@ -1487,7 +1531,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
def __iter__(self) -> Iterator[_COL_co]:
# turn to a list first to maintain over a course of changes
- return iter([col for _, col in self._collection])
+ return iter([col for _, col, _ in self._collection])
@overload
def __getitem__(self, key: Union[str, int]) -> _COL_co:
@@ -1591,16 +1635,15 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
self, iter_: Iterable[Tuple[_COLKEY, _COL_co]]
) -> None:
"""populate from an iterator of (key, column)"""
- cols = list(iter_)
- self._collection[:] = cols
- self._colset.update(c for k, c in self._collection)
- self._index.update(
- (idx, (k, c)) for idx, (k, c) in enumerate(self._collection)
- )
+ self._collection[:] = collection = [
+ (k, c, _ColumnMetrics(self, c)) for k, c in iter_
+ ]
+ self._colset.update(c._deannotate() for _, c, _ in collection)
self._index.update(
- {k: (k, col) for k, col in reversed(self._collection)}
+ {idx: (k, c) for idx, (k, c, _) in enumerate(collection)}
)
+ self._index.update({k: (k, col) for k, col, _ in reversed(collection)})
def add(
self, column: ColumnElement[Any], key: Optional[_COLKEY] = None
@@ -1630,23 +1673,35 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
_column = cast(_COL_co, column)
- self._collection.append((colkey, _column))
- self._colset.add(_column)
+ self._collection.append(
+ (colkey, _column, _ColumnMetrics(self, _column))
+ )
+ self._colset.add(_column._deannotate())
self._index[l] = (colkey, _column)
if colkey not in self._index:
self._index[colkey] = (colkey, _column)
def __getstate__(self) -> Dict[str, Any]:
return {
- "_collection": self._collection,
+ "_collection": [(k, c) for k, c, _ in self._collection],
"_index": self._index,
}
def __setstate__(self, state: Dict[str, Any]) -> None:
object.__setattr__(self, "_index", state["_index"])
- object.__setattr__(self, "_collection", state["_collection"])
object.__setattr__(
- self, "_colset", {col for k, col in self._collection}
+ self, "_proxy_index", collections.defaultdict(util.OrderedSet)
+ )
+ object.__setattr__(
+ self,
+ "_collection",
+ [
+ (k, c, _ColumnMetrics(self, c))
+ for (k, c) in state["_collection"]
+ ],
+ )
+ object.__setattr__(
+ self, "_colset", {col for k, col, _ in self._collection}
)
def contains_column(self, col: ColumnElement[Any]) -> bool:
@@ -1667,6 +1722,32 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
return ReadOnlyColumnCollection(self)
+ def _init_proxy_index(self):
+ """populate the "proxy index", if empty.
+
+ proxy index is added in 2.0 to provide more efficient operation
+ for the corresponding_column() method.
+
+ For reasons of both time to construct new .c collections as well as
+ memory conservation for large numbers of large .c collections, the
+ proxy_index is only filled if corresponding_column() is called. once
+ filled it stays that way, and new _ColumnMetrics objects created after
+ that point will populate it with new data. Note this case would be
+ unusual, if not nonexistent, as it means a .c collection is being
+ mutated after corresponding_column() were used, however it is tested in
+ test/base/test_utils.py.
+
+ """
+ pi = self._proxy_index
+ if pi:
+ return
+
+ for _, _, metrics in self._collection:
+ eps = metrics.column._expanded_proxy_set
+
+ for eps_col in eps:
+ pi[eps_col].add(metrics)
+
def corresponding_column(
self, column: _COL, require_embedded: bool = False
) -> Optional[Union[_COL, _COL_co]]:
@@ -1706,38 +1787,40 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
if column in self._colset:
return column
- def embedded(expanded_proxy_set, target_set):
- for t in target_set.difference(expanded_proxy_set):
- if not set(_expand_cloned([t])).intersection(
- expanded_proxy_set
- ):
- return False
- return True
-
- col, intersect = None, None
+ selected_intersection, selected_metrics = None, None
target_set = column.proxy_set
- cols = [c for (_, c) in self._collection]
- for c in cols:
- expanded_proxy_set = set(_expand_cloned(c.proxy_set))
- i = target_set.intersection(expanded_proxy_set)
- if i and (
- not require_embedded
- or embedded(expanded_proxy_set, target_set)
- ):
- if col is None or intersect is None:
+ pi = self._proxy_index
+ if not pi:
+ self._init_proxy_index()
+
+ for current_metrics in (
+ mm for ts in target_set if ts in pi for mm in pi[ts]
+ ):
+ if not require_embedded or current_metrics.embedded(target_set):
+ if selected_metrics is None:
# no corresponding column yet, pick this one.
+ selected_metrics = current_metrics
+ continue
- col, intersect = c, i
- elif len(i) > len(intersect):
+ current_intersection = target_set.intersection(
+ current_metrics.column._expanded_proxy_set
+ )
+ if selected_intersection is None:
+ selected_intersection = target_set.intersection(
+ selected_metrics.column._expanded_proxy_set
+ )
- # 'c' has a larger field of correspondence than
- # 'col'. i.e. selectable.c.a1_x->a1.c.x->table.c.x
+ if len(current_intersection) > len(selected_intersection):
+
+ # 'current' has a larger field of correspondence than
+ # 'selected'. i.e. selectable.c.a1_x->a1.c.x->table.c.x
# matches a1.c.x->table.c.x better than
# selectable.c.x->table.c.x does.
- col, intersect = c, i
- elif i == intersect:
+ selected_metrics = current_metrics
+ selected_intersection = current_intersection
+ elif current_intersection == selected_intersection:
# they have the same field of correspondence. see
# which proxy_set has fewer columns in it, which
# indicates a closer relationship with the root
@@ -1748,25 +1831,29 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
# columns that have no reference to the target
# column (also occurs with CompoundSelect)
- col_distance = reduce(
- operator.add,
+ selected_col_distance = sum(
[
sc._annotations.get("weight", 1)
- for sc in col._uncached_proxy_set()
+ for sc in (
+ selected_metrics.column._uncached_proxy_list()
+ )
if sc.shares_lineage(column)
],
)
- c_distance = reduce(
- operator.add,
+ current_col_distance = sum(
[
sc._annotations.get("weight", 1)
- for sc in c._uncached_proxy_set()
+ for sc in (
+ current_metrics.column._uncached_proxy_list()
+ )
if sc.shares_lineage(column)
],
)
- if c_distance < col_distance:
- col, intersect = c, i
- return col
+ if current_col_distance < selected_col_distance:
+ selected_metrics = current_metrics
+ selected_intersection = current_intersection
+
+ return selected_metrics.column if selected_metrics else None
_NAMEDCOL = TypeVar("_NAMEDCOL", bound="NamedColumn[Any]")
@@ -1816,8 +1903,10 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
util.memoized_property.reset(named_column, "proxy_set")
else:
l = len(self._collection)
- self._collection.append((key, named_column))
- self._colset.add(named_column)
+ self._collection.append(
+ (key, named_column, _ColumnMetrics(self, named_column))
+ )
+ self._colset.add(named_column._deannotate())
self._index[l] = (key, named_column)
self._index[key] = (key, named_column)
@@ -1840,11 +1929,11 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
replace_col.append(col)
else:
self._index[k] = (k, col)
- self._collection.append((k, col))
- self._colset.update(c for (k, c) in self._collection)
+ self._collection.append((k, col, _ColumnMetrics(self, col)))
+ self._colset.update(c._deannotate() for (k, c, _) in self._collection)
self._index.update(
- (idx, (k, c)) for idx, (k, c) in enumerate(self._collection)
+ (idx, (k, c)) for idx, (k, c, _) in enumerate(self._collection)
)
for col in replace_col:
self.replace(col)
@@ -1861,11 +1950,15 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
del self._index[column.key]
self._colset.remove(column)
self._collection[:] = [
- (k, c) for (k, c) in self._collection if c is not column
+ (k, c, metrics)
+ for (k, c, metrics) in self._collection
+ if c is not column
]
+ for metrics in self._proxy_index.get(column, ()):
+ metrics.dispose(self)
self._index.update(
- {idx: (k, col) for idx, (k, col) in enumerate(self._collection)}
+ {idx: (k, col) for idx, (k, col, _) in enumerate(self._collection)}
)
# delete higher index
del self._index[len(self._collection)]
@@ -1897,31 +1990,37 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
if column.key in self._index:
remove_col.add(self._index[column.key][1])
- new_cols: List[Tuple[str, _NAMEDCOL]] = []
+ new_cols: List[Tuple[str, _NAMEDCOL, _ColumnMetrics[_NAMEDCOL]]] = []
replaced = False
- for k, col in self._collection:
+ for k, col, metrics in self._collection:
if col in remove_col:
if not replaced:
replaced = True
- new_cols.append((column.key, column))
+ new_cols.append(
+ (column.key, column, _ColumnMetrics(self, column))
+ )
else:
- new_cols.append((k, col))
+ new_cols.append((k, col, metrics))
if remove_col:
self._colset.difference_update(remove_col)
+ for rc in remove_col:
+ for metrics in self._proxy_index.get(rc, ()):
+ metrics.dispose(self)
+
if not replaced:
- new_cols.append((column.key, column))
+ new_cols.append((column.key, column, _ColumnMetrics(self, column)))
- self._colset.add(column)
+ self._colset.add(column._deannotate())
self._collection[:] = new_cols
self._index.clear()
self._index.update(
- {idx: (k, col) for idx, (k, col) in enumerate(self._collection)}
+ {idx: (k, col) for idx, (k, col, _) in enumerate(self._collection)}
)
- self._index.update({k: (k, col) for (k, col) in self._collection})
+ self._index.update({k: (k, col) for (k, col, _) in self._collection})
class ReadOnlyColumnCollection(
@@ -1934,6 +2033,7 @@ class ReadOnlyColumnCollection(
object.__setattr__(self, "_colset", collection._colset)
object.__setattr__(self, "_index", collection._index)
object.__setattr__(self, "_collection", collection._collection)
+ object.__setattr__(self, "_proxy_index", collection._proxy_index)
def __getstate__(self):
return {"_parent": self._parent}
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
index 39d09d3ab..94b19b7f2 100644
--- a/lib/sqlalchemy/sql/cache_key.py
+++ b/lib/sqlalchemy/sql/cache_key.py
@@ -949,7 +949,7 @@ class _CacheKeyTraversal(HasTraversalDispatch):
attrname,
tuple(
col._gen_cache_key(anon_map, bindparams)
- for k, col in obj._collection
+ for k, col, _ in obj._collection
),
)
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 6a5aa7db9..044bdf585 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -52,6 +52,7 @@ from ._typing import is_tuple_type
from .annotation import Annotated
from .annotation import SupportsWrappingAnnotations
from .base import _clone
+from .base import _expand_cloned
from .base import _generative
from .base import _NoArg
from .base import Executable
@@ -1464,21 +1465,32 @@ class ColumnElement(
@util.memoized_property
def proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
- return frozenset([self]).union(
- itertools.chain.from_iterable(c.proxy_set for c in self._proxies)
+ """set of all columns we are proxying
+
+ as of 2.0 this is explicitly deannotated columns. previously it was
+ effectively deannotated columns but wasn't enforced. annotated
+ columns should basically not go into sets if at all possible because
+ their hashing behavior is very non-performant.
+
+ """
+ return frozenset([self._deannotate()]).union(
+ itertools.chain(*[c.proxy_set for c in self._proxies])
)
- def _uncached_proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
+ @util.memoized_property
+ def _expanded_proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
+ return frozenset(_expand_cloned(self.proxy_set))
+
+ def _uncached_proxy_list(self) -> List[ColumnElement[Any]]:
"""An 'uncached' version of proxy set.
- This is so that we can read annotations from the list of columns
- without breaking the caching of the above proxy_set.
+ This list includes annotated columns which perform very poorly in
+ set operations.
"""
- return frozenset([self]).union(
- itertools.chain.from_iterable(
- c._uncached_proxy_set() for c in self._proxies
- )
+
+ return [self] + list(
+ itertools.chain(*[c._uncached_proxy_list() for c in self._proxies])
)
def shares_lineage(self, othercolumn: ColumnElement[Any]) -> bool:
@@ -1540,6 +1552,7 @@ class ColumnElement(
name: Optional[str] = None,
key: Optional[str] = None,
name_is_truncatable: bool = False,
+ compound_select_cols: Optional[Sequence[ColumnElement[Any]]] = None,
**kw: Any,
) -> typing_Tuple[str, ColumnClause[_T]]:
"""Create a new :class:`_expression.ColumnElement` representing this
@@ -1565,7 +1578,10 @@ class ColumnElement(
)
co._propagate_attrs = selectable._propagate_attrs
- co._proxies = [self]
+ if compound_select_cols:
+ co._proxies = list(compound_select_cols)
+ else:
+ co._proxies = [self]
if selectable._is_clone_of is not None:
co._is_clone_of = selectable._is_clone_of.columns.get(key)
return key, co
@@ -4303,6 +4319,7 @@ class NamedColumn(KeyedColumnElement[_T]):
name: Optional[str] = None,
key: Optional[str] = None,
name_is_truncatable: bool = False,
+ compound_select_cols: Optional[Sequence[ColumnElement[Any]]] = None,
disallow_is_literal: bool = False,
**kw: Any,
) -> typing_Tuple[str, ColumnClause[_T]]:
@@ -4318,7 +4335,11 @@ class NamedColumn(KeyedColumnElement[_T]):
c._propagate_attrs = selectable._propagate_attrs
if name is None:
c.key = self.key
- c._proxies = [self]
+ if compound_select_cols:
+ c._proxies = list(compound_select_cols)
+ else:
+ c._proxies = [self]
+
if selectable._is_clone_of is not None:
c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
return c.key, c
@@ -4466,6 +4487,7 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
selectable: FromClause,
*,
name: Optional[str] = None,
+ compound_select_cols: Optional[Sequence[ColumnElement[Any]]] = None,
**kw: Any,
) -> typing_Tuple[str, ColumnClause[_T]]:
name = self.name if not name else name
@@ -4475,6 +4497,7 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
name=name,
disallow_is_literal=True,
name_is_truncatable=isinstance(name, _truncated_label),
+ compound_select_cols=compound_select_cols,
)
# there was a note here to remove this assertion, which was here
@@ -4710,6 +4733,7 @@ class ColumnClause(
name: Optional[str] = None,
key: Optional[str] = None,
name_is_truncatable: bool = False,
+ compound_select_cols: Optional[Sequence[ColumnElement[Any]]] = None,
disallow_is_literal: bool = False,
**kw: Any,
) -> typing_Tuple[str, ColumnClause[_T]]:
@@ -4740,7 +4764,11 @@ class ColumnClause(
c._propagate_attrs = selectable._propagate_attrs
if name is None:
c.key = self.key
- c._proxies = [self]
+ if compound_select_cols:
+ c._proxies = list(compound_select_cols)
+ else:
+ c._proxies = [self]
+
if selectable._is_clone_of is not None:
c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
return c.key, c
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index dde5cd372..cd10d0c4a 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -2360,6 +2360,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
name: Optional[str] = None,
key: Optional[str] = None,
name_is_truncatable: bool = False,
+ compound_select_cols: Optional[
+ _typing_Sequence[ColumnElement[Any]]
+ ] = None,
**kw: Any,
) -> Tuple[str, ColumnClause[_T]]:
"""Create a *proxy* for this column.
@@ -2401,7 +2404,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
key=key if key else name if name else self.key,
primary_key=self.primary_key,
nullable=self.nullable,
- _proxies=[self],
+ _proxies=list(compound_select_cols)
+ if compound_select_cols
+ else [self],
*fk,
)
except TypeError as err:
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 08d01c883..8c64dea9d 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -3316,6 +3316,15 @@ class SelectBase(
"""
raise NotImplementedError()
+ def _generate_fromclause_column_proxies(
+ self,
+ subquery: FromClause,
+ proxy_compound_columns: Optional[
+ Iterable[Sequence[ColumnElement[Any]]]
+ ] = None,
+ ) -> None:
+ raise NotImplementedError()
+
@util.ro_non_memoized_property
def _all_selected_columns(self) -> _SelectIterable:
"""A sequence of expressions that correspond to what is rendered
@@ -3621,9 +3630,15 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
# return self.element._generate_columns_plus_names(anon_for_dupe_key)
def _generate_fromclause_column_proxies(
- self, subquery: FromClause
+ self,
+ subquery: FromClause,
+ proxy_compound_columns: Optional[
+ Iterable[Sequence[ColumnElement[Any]]]
+ ] = None,
) -> None:
- self.element._generate_fromclause_column_proxies(subquery)
+ self.element._generate_fromclause_column_proxies(
+ subquery, proxy_compound_columns=proxy_compound_columns
+ )
@util.ro_non_memoized_property
def _all_selected_columns(self) -> _SelectIterable:
@@ -4308,38 +4323,50 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
return self
def _generate_fromclause_column_proxies(
- self, subquery: FromClause
+ self,
+ subquery: FromClause,
+ proxy_compound_columns: Optional[
+ Iterable[Sequence[ColumnElement[Any]]]
+ ] = None,
) -> None:
# this is a slightly hacky thing - the union exports a
# column that resembles just that of the *first* selectable.
# to get at a "composite" column, particularly foreign keys,
# you have to dig through the proxies collection which we
- # generate below. We may want to improve upon this, such as
- # perhaps _make_proxy can accept a list of other columns
- # that are "shared" - schema.column can then copy all the
- # ForeignKeys in. this would allow the union() to have all
- # those fks too.
+ # generate below.
select_0 = self.selects[0]
if self._label_style is not LABEL_STYLE_DEFAULT:
select_0 = select_0.set_label_style(self._label_style)
- select_0._generate_fromclause_column_proxies(subquery)
# hand-construct the "_proxies" collection to include all
# derived columns place a 'weight' annotation corresponding
# to how low in the list of select()s the column occurs, so
# that the corresponding_column() operation can resolve
# conflicts
-
- for subq_col, select_cols in zip(
- subquery.c._all_columns,
- zip(*[s.selected_columns for s in self.selects]),
- ):
- subq_col._proxies = [
- c._annotate({"weight": i + 1})
- for (i, c) in enumerate(select_cols)
+ extra_col_iterator = zip(
+ *[
+ [
+ c._annotate(dd)
+ for c in stmt._all_selected_columns
+ if is_column_element(c)
+ ]
+ for dd, stmt in [
+ ({"weight": i + 1}, stmt)
+ for i, stmt in enumerate(self.selects)
+ ]
]
+ )
+
+ # the incoming proxy_compound_columns can be present also if this is
+ # a compound embedded in a compound. it's probably more appropriate
+ # that we generate new weights local to this nested compound, though
+ # i haven't tried to think what it means for compound nested in
+ # compound
+ select_0._generate_fromclause_column_proxies(
+ subquery, proxy_compound_columns=extra_col_iterator
+ )
def _refresh_for_new_column(self, column):
super(CompoundSelect, self)._refresh_for_new_column(column)
@@ -6172,27 +6199,60 @@ class Select(
return self
def _generate_fromclause_column_proxies(
- self, subquery: FromClause
+ self,
+ subquery: FromClause,
+ proxy_compound_columns: Optional[
+ Iterable[Sequence[ColumnElement[Any]]]
+ ] = None,
) -> None:
"""Generate column proxies to place in the exported ``.c``
collection of a subquery."""
- prox = [
- c._make_proxy(
- subquery,
- key=proxy_key,
- name=required_label_name,
- name_is_truncatable=True,
- )
- for (
- required_label_name,
- proxy_key,
- fallback_label_name,
- c,
- repeated,
- ) in (self._generate_columns_plus_names(False))
- if is_column_element(c)
- ]
+ if proxy_compound_columns:
+ extra_col_iterator = proxy_compound_columns
+ prox = [
+ c._make_proxy(
+ subquery,
+ key=proxy_key,
+ name=required_label_name,
+ name_is_truncatable=True,
+ compound_select_cols=extra_cols,
+ )
+ for (
+ (
+ required_label_name,
+ proxy_key,
+ fallback_label_name,
+ c,
+ repeated,
+ ),
+ extra_cols,
+ ) in (
+ zip(
+ self._generate_columns_plus_names(False),
+ extra_col_iterator,
+ )
+ )
+ if is_column_element(c)
+ ]
+ else:
+
+ prox = [
+ c._make_proxy(
+ subquery,
+ key=proxy_key,
+ name=required_label_name,
+ name_is_truncatable=True,
+ )
+ for (
+ required_label_name,
+ proxy_key,
+ fallback_label_name,
+ c,
+ repeated,
+ ) in (self._generate_columns_plus_names(False))
+ if is_column_element(c)
+ ]
subquery._columns._populate_separate_keys(prox)
@@ -6739,10 +6799,20 @@ class TextualSelect(SelectBase, Executable, Generative):
self.element = self.element.bindparams(*binds, **bind_as_values)
return self
- def _generate_fromclause_column_proxies(self, fromclause):
- fromclause._columns._populate_separate_keys(
- c._make_proxy(fromclause) for c in self.column_args
- )
+ def _generate_fromclause_column_proxies(
+ self, fromclause, proxy_compound_columns=None
+ ):
+ if proxy_compound_columns:
+ fromclause._columns._populate_separate_keys(
+ c._make_proxy(fromclause, compound_select_cols=extra_cols)
+ for c, extra_cols in zip(
+ self.column_args, proxy_compound_columns
+ )
+ )
+ else:
+ fromclause._columns._populate_separate_keys(
+ c._make_proxy(fromclause) for c in self.column_args
+ )
def _scalar_type(self):
return self.column_args[0].type