diff options
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 94 |
1 files changed, 59 insertions, 35 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index bb51693cf..629e88a32 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -948,6 +948,7 @@ class Executable(roles.StatementRole, Generative): supports_execution: bool = True _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT + _is_default_generator = False _with_options: Tuple[ExecutableOption, ...] = () _with_context_options: Tuple[ Tuple[Callable[[CompileState], None], Any], ... @@ -993,10 +994,17 @@ class Executable(roles.StatementRole, Generative): connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - _force: bool = False, ) -> CursorResult: ... + def _execute_on_scalar( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptionsParameter, + ) -> Any: + ... + @util.ro_non_memoized_property def _all_selected_columns(self): raise NotImplementedError() @@ -1243,10 +1251,12 @@ class SchemaVisitor(ClauseVisitor): _COLKEY = TypeVar("_COLKEY", Union[None, str], str) + +_COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True) _COL = TypeVar("_COL", bound="ColumnElement[Any]") -class ColumnCollection(Generic[_COLKEY, _COL]): +class ColumnCollection(Generic[_COLKEY, _COL_co]): """Collection of :class:`_expression.ColumnElement` instances, typically for :class:`_sql.FromClause` objects. @@ -1357,12 +1367,12 @@ class ColumnCollection(Generic[_COLKEY, _COL]): __slots__ = "_collection", "_index", "_colset" - _collection: List[Tuple[_COLKEY, _COL]] - _index: Dict[Union[None, str, int], _COL] - _colset: Set[_COL] + _collection: List[Tuple[_COLKEY, _COL_co]] + _index: Dict[Union[None, str, int], _COL_co] + _colset: Set[_COL_co] def __init__( - self, columns: Optional[Iterable[Tuple[_COLKEY, _COL]]] = None + self, columns: Optional[Iterable[Tuple[_COLKEY, _COL_co]]] = None ): object.__setattr__(self, "_colset", set()) object.__setattr__(self, "_index", {}) @@ -1370,11 +1380,13 @@ class ColumnCollection(Generic[_COLKEY, _COL]): if columns: self._initial_populate(columns) - def _initial_populate(self, iter_: Iterable[Tuple[_COLKEY, _COL]]) -> None: + def _initial_populate( + self, iter_: Iterable[Tuple[_COLKEY, _COL_co]] + ) -> None: self._populate_separate_keys(iter_) @property - def _all_columns(self) -> List[_COL]: + def _all_columns(self) -> List[_COL_co]: return [col for (k, col) in self._collection] def keys(self) -> List[_COLKEY]: @@ -1382,13 +1394,13 @@ class ColumnCollection(Generic[_COLKEY, _COL]): collection.""" return [k for (k, col) in self._collection] - def values(self) -> List[_COL]: + def values(self) -> List[_COL_co]: """Return a sequence of :class:`_sql.ColumnClause` or :class:`_schema.Column` objects for all columns in this collection.""" return [col for (k, col) in self._collection] - def items(self) -> List[Tuple[_COLKEY, _COL]]: + def items(self) -> List[Tuple[_COLKEY, _COL_co]]: """Return a sequence of (key, column) tuples for all columns in this collection each consisting of a string key name and a :class:`_sql.ColumnClause` or @@ -1403,11 +1415,11 @@ class ColumnCollection(Generic[_COLKEY, _COL]): def __len__(self) -> int: return len(self._collection) - def __iter__(self) -> Iterator[_COL]: + def __iter__(self) -> Iterator[_COL_co]: # turn to a list first to maintain over a course of changes return iter([col for k, col in self._collection]) - def __getitem__(self, key: Union[str, int]) -> _COL: + def __getitem__(self, key: Union[str, int]) -> _COL_co: try: return self._index[key] except KeyError as err: @@ -1416,7 +1428,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]): else: raise - def __getattr__(self, key: str) -> _COL: + def __getattr__(self, key: str) -> _COL_co: try: return self._index[key] except KeyError as err: @@ -1445,7 +1457,9 @@ class ColumnCollection(Generic[_COLKEY, _COL]): def __eq__(self, other: Any) -> bool: return self.compare(other) - def get(self, key: str, default: Optional[_COL] = None) -> Optional[_COL]: + def get( + self, key: str, default: Optional[_COL_co] = None + ) -> Optional[_COL_co]: """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object based on a string key name from this :class:`_expression.ColumnCollection`.""" @@ -1487,7 +1501,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]): __hash__ = None # type: ignore def _populate_separate_keys( - self, iter_: Iterable[Tuple[_COLKEY, _COL]] + self, iter_: Iterable[Tuple[_COLKEY, _COL_co]] ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) @@ -1498,7 +1512,9 @@ class ColumnCollection(Generic[_COLKEY, _COL]): ) self._index.update({k: col for k, col in reversed(self._collection)}) - def add(self, column: _COL, key: Optional[_COLKEY] = None) -> None: + def add( + self, column: ColumnElement[Any], key: Optional[_COLKEY] = None + ) -> None: """Add a column to this :class:`_sql.ColumnCollection`. .. note:: @@ -1518,11 +1534,17 @@ class ColumnCollection(Generic[_COLKEY, _COL]): colkey = key l = len(self._collection) - self._collection.append((colkey, column)) - self._colset.add(column) - self._index[l] = column + + # don't really know how this part is supposed to work w/ the + # covariant thing + + _column = cast(_COL_co, column) + + self._collection.append((colkey, _column)) + self._colset.add(_column) + self._index[l] = _column if colkey not in self._index: - self._index[colkey] = column + self._index[colkey] = _column def __getstate__(self) -> Dict[str, Any]: return {"_collection": self._collection, "_index": self._index} @@ -1534,7 +1556,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]): self, "_colset", {col for k, col in self._collection} ) - def contains_column(self, col: _COL) -> bool: + def contains_column(self, col: ColumnElement[Any]) -> bool: """Checks if a column object exists in this collection""" if col not in self._colset: if isinstance(col, str): @@ -1546,7 +1568,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]): else: return True - def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL]: + def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: """Return a "read only" form of this :class:`_sql.ColumnCollection`.""" @@ -1554,7 +1576,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]): def corresponding_column( self, column: _COL, require_embedded: bool = False - ) -> Optional[_COL]: + ) -> Optional[Union[_COL, _COL_co]]: """Given a :class:`_expression.ColumnElement`, return the exported :class:`_expression.ColumnElement` object from this :class:`_expression.ColumnCollection` @@ -1670,14 +1692,16 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): """ - def add(self, column: _NAMEDCOL, key: Optional[str] = None) -> None: - - if key is not None and column.key != key: + def add( + self, column: ColumnElement[Any], key: Optional[str] = None + ) -> None: + named_column = cast(_NAMEDCOL, column) + if key is not None and named_column.key != key: raise exc.ArgumentError( "DedupeColumnCollection requires columns be under " "the same key as their .key" ) - key = column.key + key = named_column.key if key is None: raise exc.ArgumentError( @@ -1688,21 +1712,21 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): existing = self._index[key] - if existing is column: + if existing is named_column: return - self.replace(column) + self.replace(named_column) # pop out memoized proxy_set as this # operation may very well be occurring # in a _make_proxy operation - util.memoized_property.reset(column, "proxy_set") + util.memoized_property.reset(named_column, "proxy_set") else: l = len(self._collection) - self._collection.append((key, column)) - self._colset.add(column) - self._index[l] = column - self._index[key] = column + self._collection.append((key, named_column)) + self._colset.add(named_column) + self._index[l] = named_column + self._index[key] = named_column def _populate_separate_keys( self, iter_: Iterable[Tuple[str, _NAMEDCOL]] @@ -1805,7 +1829,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): class ReadOnlyColumnCollection( - util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL] + util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL_co] ): __slots__ = ("_parent",) |