summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r--lib/sqlalchemy/sql/base.py86
1 files changed, 60 insertions, 26 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 70c01d8d3..fbbf9f7f7 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -32,6 +32,7 @@ from typing import Mapping
from typing import MutableMapping
from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
@@ -64,6 +65,7 @@ if TYPE_CHECKING:
from . import elements
from . import type_api
from .elements import BindParameter
+ from .elements import ClauseList
from .elements import ColumnClause # noqa
from .elements import ColumnElement
from .elements import KeyedColumnElement
@@ -1396,7 +1398,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
__slots__ = "_collection", "_index", "_colset"
_collection: List[Tuple[_COLKEY, _COL_co]]
- _index: Dict[Union[None, str, int], _COL_co]
+ _index: Dict[Union[None, str, int], Tuple[_COLKEY, _COL_co]]
_colset: Set[_COL_co]
def __init__(
@@ -1408,6 +1410,16 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
if columns:
self._initial_populate(columns)
+ @util.preload_module("sqlalchemy.sql.elements")
+ def __clause_element__(self) -> ClauseList:
+ elements = util.preloaded.sql_elements
+
+ return elements.ClauseList(
+ _literal_as_text_role=roles.ColumnsClauseRole,
+ group=False,
+ *self._all_columns,
+ )
+
def _initial_populate(
self, iter_: Iterable[Tuple[_COLKEY, _COL_co]]
) -> None:
@@ -1415,18 +1427,18 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
@property
def _all_columns(self) -> List[_COL_co]:
- return [col for (k, col) in self._collection]
+ return [col for (_, col) in self._collection]
def keys(self) -> List[_COLKEY]:
"""Return a sequence of string key names for all columns in this
collection."""
- return [k for (k, col) in self._collection]
+ return [k for (k, _) in self._collection]
def values(self) -> List[_COL_co]:
"""Return a sequence of :class:`_sql.ColumnClause` or
:class:`_schema.Column` objects for all columns in this
collection."""
- return [col for (k, col) in self._collection]
+ return [col for (_, col) in self._collection]
def items(self) -> List[Tuple[_COLKEY, _COL_co]]:
"""Return a sequence of (key, column) tuples for all columns in this
@@ -1445,20 +1457,37 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
def __iter__(self) -> Iterator[_COL_co]:
# turn to a list first to maintain over a course of changes
- return iter([col for k, col in self._collection])
+ return iter([col for _, col in self._collection])
+ @overload
def __getitem__(self, key: Union[str, int]) -> _COL_co:
+ ...
+
+ @overload
+ def __getitem__(
+ self, key: Tuple[Union[str, int], ...]
+ ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]:
+ ...
+
+ def __getitem__(
+ self, key: Union[str, int, Tuple[Union[str, int], ...]]
+ ) -> Union[ReadOnlyColumnCollection[_COLKEY, _COL_co], _COL_co]:
try:
- return self._index[key]
+ if isinstance(key, tuple):
+ return ColumnCollection( # type: ignore
+ [self._index[sub_key] for sub_key in key]
+ ).as_readonly()
+ else:
+ return self._index[key][1]
except KeyError as err:
- if isinstance(key, int):
- raise IndexError(key) from err
+ if isinstance(err.args[0], int):
+ raise IndexError(err.args[0]) from err
else:
raise
def __getattr__(self, key: str) -> _COL_co:
try:
- return self._index[key]
+ return self._index[key][1]
except KeyError as err:
raise AttributeError(key) from err
@@ -1493,7 +1522,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
:class:`_expression.ColumnCollection`."""
if key in self._index:
- return self._index[key]
+ return self._index[key][1]
else:
return default
@@ -1537,9 +1566,11 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
self._collection[:] = cols
self._colset.update(c for k, c in self._collection)
self._index.update(
- (idx, c) for idx, (k, c) in enumerate(self._collection)
+ (idx, (k, c)) for idx, (k, c) in enumerate(self._collection)
+ )
+ self._index.update(
+ {k: (k, col) for k, col in reversed(self._collection)}
)
- self._index.update({k: col for k, col in reversed(self._collection)})
def add(
self, column: ColumnElement[Any], key: Optional[_COLKEY] = None
@@ -1571,12 +1602,15 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
self._collection.append((colkey, _column))
self._colset.add(_column)
- self._index[l] = _column
+ self._index[l] = (colkey, _column)
if colkey not in self._index:
- self._index[colkey] = _column
+ self._index[colkey] = (colkey, _column)
def __getstate__(self) -> Dict[str, Any]:
- return {"_collection": self._collection, "_index": self._index}
+ return {
+ "_collection": self._collection,
+ "_index": self._index,
+ }
def __setstate__(self, state: Dict[str, Any]) -> None:
object.__setattr__(self, "_index", state["_index"])
@@ -1652,7 +1686,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
col, intersect = None, None
target_set = column.proxy_set
- cols = [c for (k, c) in self._collection]
+ cols = [c for (_, c) in self._collection]
for c in cols:
expanded_proxy_set = set(_expand_cloned(c.proxy_set))
i = target_set.intersection(expanded_proxy_set)
@@ -1739,7 +1773,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
if key in self._index:
- existing = self._index[key]
+ existing = self._index[key][1]
if existing is named_column:
return
@@ -1754,8 +1788,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
l = len(self._collection)
self._collection.append((key, named_column))
self._colset.add(named_column)
- self._index[l] = named_column
- self._index[key] = named_column
+ self._index[l] = (key, named_column)
+ self._index[key] = (key, named_column)
def _populate_separate_keys(
self, iter_: Iterable[Tuple[str, _NAMEDCOL]]
@@ -1775,12 +1809,12 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
elif col.key in self._index:
replace_col.append(col)
else:
- self._index[k] = col
+ self._index[k] = (k, col)
self._collection.append((k, col))
self._colset.update(c for (k, c) in self._collection)
self._index.update(
- (idx, c) for idx, (k, c) in enumerate(self._collection)
+ (idx, (k, c)) for idx, (k, c) in enumerate(self._collection)
)
for col in replace_col:
self.replace(col)
@@ -1801,7 +1835,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
]
self._index.update(
- {idx: col for idx, (k, col) in enumerate(self._collection)}
+ {idx: (k, col) for idx, (k, col) in enumerate(self._collection)}
)
# delete higher index
del self._index[len(self._collection)]
@@ -1826,12 +1860,12 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
remove_col = set()
# remove up to two columns based on matches of name as well as key
if column.name in self._index and column.key != column.name:
- other = self._index[column.name]
+ other = self._index[column.name][1]
if other.name == other.key:
remove_col.add(other)
if column.key in self._index:
- remove_col.add(self._index[column.key])
+ remove_col.add(self._index[column.key][1])
new_cols: List[Tuple[str, _NAMEDCOL]] = []
replaced = False
@@ -1855,9 +1889,9 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
self._index.clear()
self._index.update(
- {idx: col for idx, (k, col) in enumerate(self._collection)}
+ {idx: (k, col) for idx, (k, col) in enumerate(self._collection)}
)
- self._index.update(self._collection)
+ self._index.update({k: (k, col) for (k, col) in self._collection})
class ReadOnlyColumnCollection(