summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/base.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-13 09:45:29 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-15 10:29:23 -0400
commitc932123bacad9bf047d160b85e3f95d396c513ae (patch)
tree3f84221c467ff8fba468d7ca78dc4b0c158d8970 /lib/sqlalchemy/sql/base.py
parent0bfb620009f668e97ad3c2c25a564ca36428b9ae (diff)
downloadsqlalchemy-c932123bacad9bf047d160b85e3f95d396c513ae.tar.gz
pep484: schema API
implement strict typing for schema.py this module has lots of public API, lots of old decisions and very hard to follow construction sequences in many cases, and is also where we get a lot of new feature requests, so strict typing should help keep things clean. among improvements here, fixed the pool .info getters and also figured out how to get ColumnCollection and related to be covariant so that we may set them up as returning Column or ColumnClause without any conflicts. DDL was affected, noting that superclasses of DDLElement (_DDLCompiles, added recently) can now be passed into "ddl_if" callables; reorganized ddl into ExecutableDDLElement as a new name for DDLElement and _DDLCompiles renamed to BaseDDLElement. setting up strict also located an API use case that is completely broken, which is connection.execute(some_default) returns a scalar value. This case has been deprecated and new paths have been set up so that connection.scalar() may be used. This likely wasn't possible in previous versions because scalar() would assume a CursorResult. The scalar() change also impacts Session as we have explicit support (since someone had reported it as a regression) for session.execute(Sequence()) to work. They will get the same deprecation message (which omits the word "Connection", just uses ".execute()" and ".scalar()") and they can then use Session.scalar() as well. Getting this to type correctly while still supporting ORM use cases required some refactoring, and I also set up a keyword only delimeter for Session.execute() and related as execution_options / bind_arguments should always be keyword only, applied these changes to AsyncSession as well. Additionally simpify Table __init__ now that we are Python 3 only, we can have positional plus explicit kwargs finally. Simplify Column.__init__ as well again taking advantage of kw only arguments. Fill in most/all __init__ methods in sqltypes.py as the constructor for types is most of the API. should likely do this for dialect-specific types as well. Apply _InfoType for all info attributes as should have been done originally and update descriptor decorators. Change-Id: I3f9f8ff3f1c8858471ff4545ac83d68c88107527
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
-rw-r--r--lib/sqlalchemy/sql/base.py94
1 files changed, 59 insertions, 35 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index bb51693cf..629e88a32 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -948,6 +948,7 @@ class Executable(roles.StatementRole, Generative):
supports_execution: bool = True
_execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT
+ _is_default_generator = False
_with_options: Tuple[ExecutableOption, ...] = ()
_with_context_options: Tuple[
Tuple[Callable[[CompileState], None], Any], ...
@@ -993,10 +994,17 @@ class Executable(roles.StatementRole, Generative):
connection: Connection,
distilled_params: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- _force: bool = False,
) -> CursorResult:
...
+ def _execute_on_scalar(
+ self,
+ connection: Connection,
+ distilled_params: _CoreMultiExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ ) -> Any:
+ ...
+
@util.ro_non_memoized_property
def _all_selected_columns(self):
raise NotImplementedError()
@@ -1243,10 +1251,12 @@ class SchemaVisitor(ClauseVisitor):
_COLKEY = TypeVar("_COLKEY", Union[None, str], str)
+
+_COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True)
_COL = TypeVar("_COL", bound="ColumnElement[Any]")
-class ColumnCollection(Generic[_COLKEY, _COL]):
+class ColumnCollection(Generic[_COLKEY, _COL_co]):
"""Collection of :class:`_expression.ColumnElement` instances,
typically for
:class:`_sql.FromClause` objects.
@@ -1357,12 +1367,12 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
__slots__ = "_collection", "_index", "_colset"
- _collection: List[Tuple[_COLKEY, _COL]]
- _index: Dict[Union[None, str, int], _COL]
- _colset: Set[_COL]
+ _collection: List[Tuple[_COLKEY, _COL_co]]
+ _index: Dict[Union[None, str, int], _COL_co]
+ _colset: Set[_COL_co]
def __init__(
- self, columns: Optional[Iterable[Tuple[_COLKEY, _COL]]] = None
+ self, columns: Optional[Iterable[Tuple[_COLKEY, _COL_co]]] = None
):
object.__setattr__(self, "_colset", set())
object.__setattr__(self, "_index", {})
@@ -1370,11 +1380,13 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
if columns:
self._initial_populate(columns)
- def _initial_populate(self, iter_: Iterable[Tuple[_COLKEY, _COL]]) -> None:
+ def _initial_populate(
+ self, iter_: Iterable[Tuple[_COLKEY, _COL_co]]
+ ) -> None:
self._populate_separate_keys(iter_)
@property
- def _all_columns(self) -> List[_COL]:
+ def _all_columns(self) -> List[_COL_co]:
return [col for (k, col) in self._collection]
def keys(self) -> List[_COLKEY]:
@@ -1382,13 +1394,13 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
collection."""
return [k for (k, col) in self._collection]
- def values(self) -> List[_COL]:
+ def values(self) -> List[_COL_co]:
"""Return a sequence of :class:`_sql.ColumnClause` or
:class:`_schema.Column` objects for all columns in this
collection."""
return [col for (k, col) in self._collection]
- def items(self) -> List[Tuple[_COLKEY, _COL]]:
+ def items(self) -> List[Tuple[_COLKEY, _COL_co]]:
"""Return a sequence of (key, column) tuples for all columns in this
collection each consisting of a string key name and a
:class:`_sql.ColumnClause` or
@@ -1403,11 +1415,11 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
def __len__(self) -> int:
return len(self._collection)
- def __iter__(self) -> Iterator[_COL]:
+ def __iter__(self) -> Iterator[_COL_co]:
# turn to a list first to maintain over a course of changes
return iter([col for k, col in self._collection])
- def __getitem__(self, key: Union[str, int]) -> _COL:
+ def __getitem__(self, key: Union[str, int]) -> _COL_co:
try:
return self._index[key]
except KeyError as err:
@@ -1416,7 +1428,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
else:
raise
- def __getattr__(self, key: str) -> _COL:
+ def __getattr__(self, key: str) -> _COL_co:
try:
return self._index[key]
except KeyError as err:
@@ -1445,7 +1457,9 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
def __eq__(self, other: Any) -> bool:
return self.compare(other)
- def get(self, key: str, default: Optional[_COL] = None) -> Optional[_COL]:
+ def get(
+ self, key: str, default: Optional[_COL_co] = None
+ ) -> Optional[_COL_co]:
"""Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object
based on a string key name from this
:class:`_expression.ColumnCollection`."""
@@ -1487,7 +1501,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
__hash__ = None # type: ignore
def _populate_separate_keys(
- self, iter_: Iterable[Tuple[_COLKEY, _COL]]
+ self, iter_: Iterable[Tuple[_COLKEY, _COL_co]]
) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
@@ -1498,7 +1512,9 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
)
self._index.update({k: col for k, col in reversed(self._collection)})
- def add(self, column: _COL, key: Optional[_COLKEY] = None) -> None:
+ def add(
+ self, column: ColumnElement[Any], key: Optional[_COLKEY] = None
+ ) -> None:
"""Add a column to this :class:`_sql.ColumnCollection`.
.. note::
@@ -1518,11 +1534,17 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
colkey = key
l = len(self._collection)
- self._collection.append((colkey, column))
- self._colset.add(column)
- self._index[l] = column
+
+ # don't really know how this part is supposed to work w/ the
+ # covariant thing
+
+ _column = cast(_COL_co, column)
+
+ self._collection.append((colkey, _column))
+ self._colset.add(_column)
+ self._index[l] = _column
if colkey not in self._index:
- self._index[colkey] = column
+ self._index[colkey] = _column
def __getstate__(self) -> Dict[str, Any]:
return {"_collection": self._collection, "_index": self._index}
@@ -1534,7 +1556,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
self, "_colset", {col for k, col in self._collection}
)
- def contains_column(self, col: _COL) -> bool:
+ def contains_column(self, col: ColumnElement[Any]) -> bool:
"""Checks if a column object exists in this collection"""
if col not in self._colset:
if isinstance(col, str):
@@ -1546,7 +1568,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
else:
return True
- def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL]:
+ def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]:
"""Return a "read only" form of this
:class:`_sql.ColumnCollection`."""
@@ -1554,7 +1576,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
def corresponding_column(
self, column: _COL, require_embedded: bool = False
- ) -> Optional[_COL]:
+ ) -> Optional[Union[_COL, _COL_co]]:
"""Given a :class:`_expression.ColumnElement`, return the exported
:class:`_expression.ColumnElement` object from this
:class:`_expression.ColumnCollection`
@@ -1670,14 +1692,16 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
"""
- def add(self, column: _NAMEDCOL, key: Optional[str] = None) -> None:
-
- if key is not None and column.key != key:
+ def add(
+ self, column: ColumnElement[Any], key: Optional[str] = None
+ ) -> None:
+ named_column = cast(_NAMEDCOL, column)
+ if key is not None and named_column.key != key:
raise exc.ArgumentError(
"DedupeColumnCollection requires columns be under "
"the same key as their .key"
)
- key = column.key
+ key = named_column.key
if key is None:
raise exc.ArgumentError(
@@ -1688,21 +1712,21 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
existing = self._index[key]
- if existing is column:
+ if existing is named_column:
return
- self.replace(column)
+ self.replace(named_column)
# pop out memoized proxy_set as this
# operation may very well be occurring
# in a _make_proxy operation
- util.memoized_property.reset(column, "proxy_set")
+ util.memoized_property.reset(named_column, "proxy_set")
else:
l = len(self._collection)
- self._collection.append((key, column))
- self._colset.add(column)
- self._index[l] = column
- self._index[key] = column
+ self._collection.append((key, named_column))
+ self._colset.add(named_column)
+ self._index[l] = named_column
+ self._index[key] = named_column
def _populate_separate_keys(
self, iter_: Iterable[Tuple[str, _NAMEDCOL]]
@@ -1805,7 +1829,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
class ReadOnlyColumnCollection(
- util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL]
+ util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL_co]
):
__slots__ = ("_parent",)