diff options
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 86 |
1 files changed, 60 insertions, 26 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 70c01d8d3..fbbf9f7f7 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -32,6 +32,7 @@ from typing import Mapping from typing import MutableMapping from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence from typing import Set from typing import Tuple @@ -64,6 +65,7 @@ if TYPE_CHECKING: from . import elements from . import type_api from .elements import BindParameter + from .elements import ClauseList from .elements import ColumnClause # noqa from .elements import ColumnElement from .elements import KeyedColumnElement @@ -1396,7 +1398,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): __slots__ = "_collection", "_index", "_colset" _collection: List[Tuple[_COLKEY, _COL_co]] - _index: Dict[Union[None, str, int], _COL_co] + _index: Dict[Union[None, str, int], Tuple[_COLKEY, _COL_co]] _colset: Set[_COL_co] def __init__( @@ -1408,6 +1410,16 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): if columns: self._initial_populate(columns) + @util.preload_module("sqlalchemy.sql.elements") + def __clause_element__(self) -> ClauseList: + elements = util.preloaded.sql_elements + + return elements.ClauseList( + _literal_as_text_role=roles.ColumnsClauseRole, + group=False, + *self._all_columns, + ) + def _initial_populate( self, iter_: Iterable[Tuple[_COLKEY, _COL_co]] ) -> None: @@ -1415,18 +1427,18 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): @property def _all_columns(self) -> List[_COL_co]: - return [col for (k, col) in self._collection] + return [col for (_, col) in self._collection] def keys(self) -> List[_COLKEY]: """Return a sequence of string key names for all columns in this collection.""" - return [k for (k, col) in self._collection] + return [k for (k, _) in self._collection] def values(self) -> List[_COL_co]: """Return a sequence of :class:`_sql.ColumnClause` or :class:`_schema.Column` objects for all columns in this collection.""" - return [col for (k, col) in self._collection] + return [col for (_, col) in self._collection] def items(self) -> List[Tuple[_COLKEY, _COL_co]]: """Return a sequence of (key, column) tuples for all columns in this @@ -1445,20 +1457,37 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): def __iter__(self) -> Iterator[_COL_co]: # turn to a list first to maintain over a course of changes - return iter([col for k, col in self._collection]) + return iter([col for _, col in self._collection]) + @overload def __getitem__(self, key: Union[str, int]) -> _COL_co: + ... + + @overload + def __getitem__( + self, key: Tuple[Union[str, int], ...] + ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: + ... + + def __getitem__( + self, key: Union[str, int, Tuple[Union[str, int], ...]] + ) -> Union[ReadOnlyColumnCollection[_COLKEY, _COL_co], _COL_co]: try: - return self._index[key] + if isinstance(key, tuple): + return ColumnCollection( # type: ignore + [self._index[sub_key] for sub_key in key] + ).as_readonly() + else: + return self._index[key][1] except KeyError as err: - if isinstance(key, int): - raise IndexError(key) from err + if isinstance(err.args[0], int): + raise IndexError(err.args[0]) from err else: raise def __getattr__(self, key: str) -> _COL_co: try: - return self._index[key] + return self._index[key][1] except KeyError as err: raise AttributeError(key) from err @@ -1493,7 +1522,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): :class:`_expression.ColumnCollection`.""" if key in self._index: - return self._index[key] + return self._index[key][1] else: return default @@ -1537,9 +1566,11 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): self._collection[:] = cols self._colset.update(c for k, c in self._collection) self._index.update( - (idx, c) for idx, (k, c) in enumerate(self._collection) + (idx, (k, c)) for idx, (k, c) in enumerate(self._collection) + ) + self._index.update( + {k: (k, col) for k, col in reversed(self._collection)} ) - self._index.update({k: col for k, col in reversed(self._collection)}) def add( self, column: ColumnElement[Any], key: Optional[_COLKEY] = None @@ -1571,12 +1602,15 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): self._collection.append((colkey, _column)) self._colset.add(_column) - self._index[l] = _column + self._index[l] = (colkey, _column) if colkey not in self._index: - self._index[colkey] = _column + self._index[colkey] = (colkey, _column) def __getstate__(self) -> Dict[str, Any]: - return {"_collection": self._collection, "_index": self._index} + return { + "_collection": self._collection, + "_index": self._index, + } def __setstate__(self, state: Dict[str, Any]) -> None: object.__setattr__(self, "_index", state["_index"]) @@ -1652,7 +1686,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): col, intersect = None, None target_set = column.proxy_set - cols = [c for (k, c) in self._collection] + cols = [c for (_, c) in self._collection] for c in cols: expanded_proxy_set = set(_expand_cloned(c.proxy_set)) i = target_set.intersection(expanded_proxy_set) @@ -1739,7 +1773,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): if key in self._index: - existing = self._index[key] + existing = self._index[key][1] if existing is named_column: return @@ -1754,8 +1788,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): l = len(self._collection) self._collection.append((key, named_column)) self._colset.add(named_column) - self._index[l] = named_column - self._index[key] = named_column + self._index[l] = (key, named_column) + self._index[key] = (key, named_column) def _populate_separate_keys( self, iter_: Iterable[Tuple[str, _NAMEDCOL]] @@ -1775,12 +1809,12 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): elif col.key in self._index: replace_col.append(col) else: - self._index[k] = col + self._index[k] = (k, col) self._collection.append((k, col)) self._colset.update(c for (k, c) in self._collection) self._index.update( - (idx, c) for idx, (k, c) in enumerate(self._collection) + (idx, (k, c)) for idx, (k, c) in enumerate(self._collection) ) for col in replace_col: self.replace(col) @@ -1801,7 +1835,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): ] self._index.update( - {idx: col for idx, (k, col) in enumerate(self._collection)} + {idx: (k, col) for idx, (k, col) in enumerate(self._collection)} ) # delete higher index del self._index[len(self._collection)] @@ -1826,12 +1860,12 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): remove_col = set() # remove up to two columns based on matches of name as well as key if column.name in self._index and column.key != column.name: - other = self._index[column.name] + other = self._index[column.name][1] if other.name == other.key: remove_col.add(other) if column.key in self._index: - remove_col.add(self._index[column.key]) + remove_col.add(self._index[column.key][1]) new_cols: List[Tuple[str, _NAMEDCOL]] = [] replaced = False @@ -1855,9 +1889,9 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): self._index.clear() self._index.update( - {idx: col for idx, (k, col) in enumerate(self._collection)} + {idx: (k, col) for idx, (k, col) in enumerate(self._collection)} ) - self._index.update(self._collection) + self._index.update({k: (k, col) for (k, col) in self._collection}) class ReadOnlyColumnCollection( |