summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r--lib/sqlalchemy/sql/base.py94
1 files changed, 59 insertions, 35 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index bb51693cf..629e88a32 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -948,6 +948,7 @@ class Executable(roles.StatementRole, Generative):
supports_execution: bool = True
_execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT
+ _is_default_generator = False
_with_options: Tuple[ExecutableOption, ...] = ()
_with_context_options: Tuple[
Tuple[Callable[[CompileState], None], Any], ...
@@ -993,10 +994,17 @@ class Executable(roles.StatementRole, Generative):
connection: Connection,
distilled_params: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- _force: bool = False,
) -> CursorResult:
...
+ def _execute_on_scalar(
+ self,
+ connection: Connection,
+ distilled_params: _CoreMultiExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ ) -> Any:
+ ...
+
@util.ro_non_memoized_property
def _all_selected_columns(self):
raise NotImplementedError()
@@ -1243,10 +1251,12 @@ class SchemaVisitor(ClauseVisitor):
_COLKEY = TypeVar("_COLKEY", Union[None, str], str)
+
+_COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True)
_COL = TypeVar("_COL", bound="ColumnElement[Any]")
-class ColumnCollection(Generic[_COLKEY, _COL]):
+class ColumnCollection(Generic[_COLKEY, _COL_co]):
"""Collection of :class:`_expression.ColumnElement` instances,
typically for
:class:`_sql.FromClause` objects.
@@ -1357,12 +1367,12 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
__slots__ = "_collection", "_index", "_colset"
- _collection: List[Tuple[_COLKEY, _COL]]
- _index: Dict[Union[None, str, int], _COL]
- _colset: Set[_COL]
+ _collection: List[Tuple[_COLKEY, _COL_co]]
+ _index: Dict[Union[None, str, int], _COL_co]
+ _colset: Set[_COL_co]
def __init__(
- self, columns: Optional[Iterable[Tuple[_COLKEY, _COL]]] = None
+ self, columns: Optional[Iterable[Tuple[_COLKEY, _COL_co]]] = None
):
object.__setattr__(self, "_colset", set())
object.__setattr__(self, "_index", {})
@@ -1370,11 +1380,13 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
if columns:
self._initial_populate(columns)
- def _initial_populate(self, iter_: Iterable[Tuple[_COLKEY, _COL]]) -> None:
+ def _initial_populate(
+ self, iter_: Iterable[Tuple[_COLKEY, _COL_co]]
+ ) -> None:
self._populate_separate_keys(iter_)
@property
- def _all_columns(self) -> List[_COL]:
+ def _all_columns(self) -> List[_COL_co]:
return [col for (k, col) in self._collection]
def keys(self) -> List[_COLKEY]:
@@ -1382,13 +1394,13 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
collection."""
return [k for (k, col) in self._collection]
- def values(self) -> List[_COL]:
+ def values(self) -> List[_COL_co]:
"""Return a sequence of :class:`_sql.ColumnClause` or
:class:`_schema.Column` objects for all columns in this
collection."""
return [col for (k, col) in self._collection]
- def items(self) -> List[Tuple[_COLKEY, _COL]]:
+ def items(self) -> List[Tuple[_COLKEY, _COL_co]]:
"""Return a sequence of (key, column) tuples for all columns in this
collection each consisting of a string key name and a
:class:`_sql.ColumnClause` or
@@ -1403,11 +1415,11 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
def __len__(self) -> int:
return len(self._collection)
- def __iter__(self) -> Iterator[_COL]:
+ def __iter__(self) -> Iterator[_COL_co]:
# turn to a list first to maintain over a course of changes
return iter([col for k, col in self._collection])
- def __getitem__(self, key: Union[str, int]) -> _COL:
+ def __getitem__(self, key: Union[str, int]) -> _COL_co:
try:
return self._index[key]
except KeyError as err:
@@ -1416,7 +1428,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
else:
raise
- def __getattr__(self, key: str) -> _COL:
+ def __getattr__(self, key: str) -> _COL_co:
try:
return self._index[key]
except KeyError as err:
@@ -1445,7 +1457,9 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
def __eq__(self, other: Any) -> bool:
return self.compare(other)
- def get(self, key: str, default: Optional[_COL] = None) -> Optional[_COL]:
+ def get(
+ self, key: str, default: Optional[_COL_co] = None
+ ) -> Optional[_COL_co]:
"""Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object
based on a string key name from this
:class:`_expression.ColumnCollection`."""
@@ -1487,7 +1501,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
__hash__ = None # type: ignore
def _populate_separate_keys(
- self, iter_: Iterable[Tuple[_COLKEY, _COL]]
+ self, iter_: Iterable[Tuple[_COLKEY, _COL_co]]
) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
@@ -1498,7 +1512,9 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
)
self._index.update({k: col for k, col in reversed(self._collection)})
- def add(self, column: _COL, key: Optional[_COLKEY] = None) -> None:
+ def add(
+ self, column: ColumnElement[Any], key: Optional[_COLKEY] = None
+ ) -> None:
"""Add a column to this :class:`_sql.ColumnCollection`.
.. note::
@@ -1518,11 +1534,17 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
colkey = key
l = len(self._collection)
- self._collection.append((colkey, column))
- self._colset.add(column)
- self._index[l] = column
+
+ # don't really know how this part is supposed to work w/ the
+ # covariant thing
+
+ _column = cast(_COL_co, column)
+
+ self._collection.append((colkey, _column))
+ self._colset.add(_column)
+ self._index[l] = _column
if colkey not in self._index:
- self._index[colkey] = column
+ self._index[colkey] = _column
def __getstate__(self) -> Dict[str, Any]:
return {"_collection": self._collection, "_index": self._index}
@@ -1534,7 +1556,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
self, "_colset", {col for k, col in self._collection}
)
- def contains_column(self, col: _COL) -> bool:
+ def contains_column(self, col: ColumnElement[Any]) -> bool:
"""Checks if a column object exists in this collection"""
if col not in self._colset:
if isinstance(col, str):
@@ -1546,7 +1568,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
else:
return True
- def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL]:
+ def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]:
"""Return a "read only" form of this
:class:`_sql.ColumnCollection`."""
@@ -1554,7 +1576,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
def corresponding_column(
self, column: _COL, require_embedded: bool = False
- ) -> Optional[_COL]:
+ ) -> Optional[Union[_COL, _COL_co]]:
"""Given a :class:`_expression.ColumnElement`, return the exported
:class:`_expression.ColumnElement` object from this
:class:`_expression.ColumnCollection`
@@ -1670,14 +1692,16 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
"""
- def add(self, column: _NAMEDCOL, key: Optional[str] = None) -> None:
-
- if key is not None and column.key != key:
+ def add(
+ self, column: ColumnElement[Any], key: Optional[str] = None
+ ) -> None:
+ named_column = cast(_NAMEDCOL, column)
+ if key is not None and named_column.key != key:
raise exc.ArgumentError(
"DedupeColumnCollection requires columns be under "
"the same key as their .key"
)
- key = column.key
+ key = named_column.key
if key is None:
raise exc.ArgumentError(
@@ -1688,21 +1712,21 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
existing = self._index[key]
- if existing is column:
+ if existing is named_column:
return
- self.replace(column)
+ self.replace(named_column)
# pop out memoized proxy_set as this
# operation may very well be occurring
# in a _make_proxy operation
- util.memoized_property.reset(column, "proxy_set")
+ util.memoized_property.reset(named_column, "proxy_set")
else:
l = len(self._collection)
- self._collection.append((key, column))
- self._colset.add(column)
- self._index[l] = column
- self._index[key] = column
+ self._collection.append((key, named_column))
+ self._colset.add(named_column)
+ self._index[l] = named_column
+ self._index[key] = named_column
def _populate_separate_keys(
self, iter_: Iterable[Tuple[str, _NAMEDCOL]]
@@ -1805,7 +1829,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
class ReadOnlyColumnCollection(
- util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL]
+ util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL_co]
):
__slots__ = ("_parent",)