diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2022-05-16 02:32:44 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-05-16 02:32:44 +0000 |
commit | 5d080d17464712d33c0215d12513e529d848ee8c (patch) | |
tree | eec56f3138a48f55f2585a64f01b4fd9c14451b7 /lib/sqlalchemy/orm/util.py | |
parent | c4dad3695f4ab9fef3a4cb05893492afbec811f7 (diff) | |
parent | 18a73fb1d1c267842ead5dacd05a49f4344d8b22 (diff) | |
download | sqlalchemy-5d080d17464712d33c0215d12513e529d848ee8c.tar.gz |
Merge "revenge of pep 484" into main
Diffstat (limited to 'lib/sqlalchemy/orm/util.py')
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 70 |
1 files changed, 50 insertions, 20 deletions
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 4da0b7773..c50cc5bac 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -12,6 +12,7 @@ import re import types import typing from typing import Any +from typing import Callable from typing import cast from typing import Dict from typing import FrozenSet @@ -82,24 +83,29 @@ if typing.TYPE_CHECKING: from ._typing import _EntityType from ._typing import _IdentityKeyType from ._typing import _InternalEntityType - from ._typing import _ORMColumnExprArgument + from ._typing import _ORMCOLEXPR from .context import _MapperEntity from .context import ORMCompileState from .mapper import Mapper + from .query import Query from .relationships import Relationship from ..engine import Row from ..engine import RowMapping + from ..sql._typing import _CE from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _EquivalentColumnMap from ..sql._typing import _FromClauseArgument from ..sql._typing import _OnClauseArgument from ..sql._typing import _PropagateAttrsType + from ..sql.annotation import _SA from ..sql.base import ReadOnlyColumnCollection from ..sql.elements import BindParameter from ..sql.selectable import _ColumnsClauseElement from ..sql.selectable import Alias + from ..sql.selectable import Select from ..sql.selectable import Subquery from ..sql.visitors import anon_map + from ..util.typing import _AnnotationScanType _T = TypeVar("_T", bound=Any) @@ -144,9 +150,11 @@ class CascadeOptions(FrozenSet[str]): expunge: bool delete_orphan: bool - def __new__(cls, value_list): + def __new__( + cls, value_list: Optional[Union[Iterable[str], str]] + ) -> CascadeOptions: if isinstance(value_list, str) or value_list is None: - return cls.from_string(value_list) + return cls.from_string(value_list) # type: ignore values = set(value_list) if values.difference(cls._allowed_cascades): raise sa_exc.ArgumentError( @@ -864,7 +872,7 @@ class AliasedInsp( def _with_polymorphic_factory( cls, base: Union[_O, Mapper[_O]], - classes: Iterable[Type[Any]], + classes: Iterable[_EntityType[Any]], selectable: Union[Literal[False, None], FromClause] = False, flat: bool = False, polymorphic_on: Optional[ColumnElement[Any]] = None, @@ -1011,23 +1019,40 @@ class AliasedInsp( )._aliased_insp def _adapt_element( - self, elem: _ORMColumnExprArgument[_T], key: Optional[str] = None - ) -> _ORMColumnExprArgument[_T]: - assert isinstance(elem, ColumnElement) + self, expr: _ORMCOLEXPR, key: Optional[str] = None + ) -> _ORMCOLEXPR: + assert isinstance(expr, ColumnElement) d: Dict[str, Any] = { "parententity": self, "parentmapper": self.mapper, } if key: d["proxy_key"] = key + + # IMO mypy should see this one also as returning the same type + # we put into it, but it's not return ( - self._adapter.traverse(elem) + self._adapter.traverse(expr) # type: ignore ._annotate(d) ._set_propagate_attrs( {"compile_state_plugin": "orm", "plugin_subject": self} ) ) + if TYPE_CHECKING: + # establish compatibility with the _ORMAdapterProto protocol, + # which in turn is compatible with _CoreAdapterProto. + + def _orm_adapt_element( + self, + obj: _CE, + key: Optional[str] = None, + ) -> _CE: + ... + + else: + _orm_adapt_element = _adapt_element + def _entity_for_mapper(self, mapper): self_poly = self.with_polymorphic_mappers if mapper in self_poly: @@ -1469,7 +1494,12 @@ class Bundle( cloned.name = name return cloned - def create_row_processor(self, query, procs, labels): + def create_row_processor( + self, + query: Select[Any], + procs: Sequence[Callable[[Row[Any]], Any]], + labels: Sequence[str], + ) -> Callable[[Row[Any]], Any]: """Produce the "row processing" function for this :class:`.Bundle`. May be overridden by subclasses. @@ -1481,13 +1511,13 @@ class Bundle( """ keyed_tuple = result_tuple(labels, [() for l in labels]) - def proc(row): + def proc(row: Row[Any]) -> Any: return keyed_tuple([proc(row) for proc in procs]) return proc -def _orm_annotate(element, exclude=None): +def _orm_annotate(element: _SA, exclude: Optional[Any] = None) -> _SA: """Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag. @@ -1497,7 +1527,7 @@ def _orm_annotate(element, exclude=None): return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude) -def _orm_deannotate(element): +def _orm_deannotate(element: _SA) -> _SA: """Remove annotations that link a column to a particular mapping. Note this doesn't affect "remote" and "foreign" annotations @@ -1511,7 +1541,7 @@ def _orm_deannotate(element): ) -def _orm_full_deannotate(element): +def _orm_full_deannotate(element: _SA) -> _SA: return sql_util._deep_deannotate(element) @@ -1560,13 +1590,15 @@ class _ORMJoin(expression.Join): on_selectable = prop.parent.selectable else: prop = None + on_selectable = None if prop: left_selectable = left_info.selectable - + adapt_from: Optional[FromClause] if sql_util.clause_is_present(on_selectable, left_selectable): adapt_from = on_selectable else: + assert isinstance(left_selectable, FromClause) adapt_from = left_selectable ( @@ -1855,7 +1887,7 @@ def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool: return given.isa(mapper) -def _getitem(iterable_query, item): +def _getitem(iterable_query: Query[Any], item: Any) -> Any: """calculate __getitem__ in terms of an iterable query object that also has a slice() method. @@ -1881,17 +1913,15 @@ def _getitem(iterable_query, item): isinstance(stop, int) and stop < 0 ): _no_negative_indexes() - return list(iterable_query)[item] res = iterable_query.slice(start, stop) if step is not None: - return list(res)[None : None : item.step] + return list(res)[None : None : item.step] # type: ignore else: - return list(res) + return list(res) # type: ignore else: if item == -1: _no_negative_indexes() - return list(iterable_query)[-1] else: return list(iterable_query[item : item + 1])[0] @@ -1933,7 +1963,7 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str: def _extract_mapped_subtype( - raw_annotation: Union[type, str], + raw_annotation: Optional[_AnnotationScanType], cls: type, key: str, attr_cls: Type[Any], |