summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/util.py
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-05-16 02:32:44 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-05-16 02:32:44 +0000
commit5d080d17464712d33c0215d12513e529d848ee8c (patch)
treeeec56f3138a48f55f2585a64f01b4fd9c14451b7 /lib/sqlalchemy/orm/util.py
parentc4dad3695f4ab9fef3a4cb05893492afbec811f7 (diff)
parent18a73fb1d1c267842ead5dacd05a49f4344d8b22 (diff)
downloadsqlalchemy-5d080d17464712d33c0215d12513e529d848ee8c.tar.gz
Merge "revenge of pep 484" into main
Diffstat (limited to 'lib/sqlalchemy/orm/util.py')
-rw-r--r--lib/sqlalchemy/orm/util.py70
1 files changed, 50 insertions, 20 deletions
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 4da0b7773..c50cc5bac 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -12,6 +12,7 @@ import re
import types
import typing
from typing import Any
+from typing import Callable
from typing import cast
from typing import Dict
from typing import FrozenSet
@@ -82,24 +83,29 @@ if typing.TYPE_CHECKING:
from ._typing import _EntityType
from ._typing import _IdentityKeyType
from ._typing import _InternalEntityType
- from ._typing import _ORMColumnExprArgument
+ from ._typing import _ORMCOLEXPR
from .context import _MapperEntity
from .context import ORMCompileState
from .mapper import Mapper
+ from .query import Query
from .relationships import Relationship
from ..engine import Row
from ..engine import RowMapping
+ from ..sql._typing import _CE
from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _EquivalentColumnMap
from ..sql._typing import _FromClauseArgument
from ..sql._typing import _OnClauseArgument
from ..sql._typing import _PropagateAttrsType
+ from ..sql.annotation import _SA
from ..sql.base import ReadOnlyColumnCollection
from ..sql.elements import BindParameter
from ..sql.selectable import _ColumnsClauseElement
from ..sql.selectable import Alias
+ from ..sql.selectable import Select
from ..sql.selectable import Subquery
from ..sql.visitors import anon_map
+ from ..util.typing import _AnnotationScanType
_T = TypeVar("_T", bound=Any)
@@ -144,9 +150,11 @@ class CascadeOptions(FrozenSet[str]):
expunge: bool
delete_orphan: bool
- def __new__(cls, value_list):
+ def __new__(
+ cls, value_list: Optional[Union[Iterable[str], str]]
+ ) -> CascadeOptions:
if isinstance(value_list, str) or value_list is None:
- return cls.from_string(value_list)
+ return cls.from_string(value_list) # type: ignore
values = set(value_list)
if values.difference(cls._allowed_cascades):
raise sa_exc.ArgumentError(
@@ -864,7 +872,7 @@ class AliasedInsp(
def _with_polymorphic_factory(
cls,
base: Union[_O, Mapper[_O]],
- classes: Iterable[Type[Any]],
+ classes: Iterable[_EntityType[Any]],
selectable: Union[Literal[False, None], FromClause] = False,
flat: bool = False,
polymorphic_on: Optional[ColumnElement[Any]] = None,
@@ -1011,23 +1019,40 @@ class AliasedInsp(
)._aliased_insp
def _adapt_element(
- self, elem: _ORMColumnExprArgument[_T], key: Optional[str] = None
- ) -> _ORMColumnExprArgument[_T]:
- assert isinstance(elem, ColumnElement)
+ self, expr: _ORMCOLEXPR, key: Optional[str] = None
+ ) -> _ORMCOLEXPR:
+ assert isinstance(expr, ColumnElement)
d: Dict[str, Any] = {
"parententity": self,
"parentmapper": self.mapper,
}
if key:
d["proxy_key"] = key
+
+ # IMO mypy should see this one also as returning the same type
+ # we put into it, but it's not
return (
- self._adapter.traverse(elem)
+ self._adapter.traverse(expr) # type: ignore
._annotate(d)
._set_propagate_attrs(
{"compile_state_plugin": "orm", "plugin_subject": self}
)
)
+ if TYPE_CHECKING:
+ # establish compatibility with the _ORMAdapterProto protocol,
+ # which in turn is compatible with _CoreAdapterProto.
+
+ def _orm_adapt_element(
+ self,
+ obj: _CE,
+ key: Optional[str] = None,
+ ) -> _CE:
+ ...
+
+ else:
+ _orm_adapt_element = _adapt_element
+
def _entity_for_mapper(self, mapper):
self_poly = self.with_polymorphic_mappers
if mapper in self_poly:
@@ -1469,7 +1494,12 @@ class Bundle(
cloned.name = name
return cloned
- def create_row_processor(self, query, procs, labels):
+ def create_row_processor(
+ self,
+ query: Select[Any],
+ procs: Sequence[Callable[[Row[Any]], Any]],
+ labels: Sequence[str],
+ ) -> Callable[[Row[Any]], Any]:
"""Produce the "row processing" function for this :class:`.Bundle`.
May be overridden by subclasses.
@@ -1481,13 +1511,13 @@ class Bundle(
"""
keyed_tuple = result_tuple(labels, [() for l in labels])
- def proc(row):
+ def proc(row: Row[Any]) -> Any:
return keyed_tuple([proc(row) for proc in procs])
return proc
-def _orm_annotate(element, exclude=None):
+def _orm_annotate(element: _SA, exclude: Optional[Any] = None) -> _SA:
"""Deep copy the given ClauseElement, annotating each element with the
"_orm_adapt" flag.
@@ -1497,7 +1527,7 @@ def _orm_annotate(element, exclude=None):
return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude)
-def _orm_deannotate(element):
+def _orm_deannotate(element: _SA) -> _SA:
"""Remove annotations that link a column to a particular mapping.
Note this doesn't affect "remote" and "foreign" annotations
@@ -1511,7 +1541,7 @@ def _orm_deannotate(element):
)
-def _orm_full_deannotate(element):
+def _orm_full_deannotate(element: _SA) -> _SA:
return sql_util._deep_deannotate(element)
@@ -1560,13 +1590,15 @@ class _ORMJoin(expression.Join):
on_selectable = prop.parent.selectable
else:
prop = None
+ on_selectable = None
if prop:
left_selectable = left_info.selectable
-
+ adapt_from: Optional[FromClause]
if sql_util.clause_is_present(on_selectable, left_selectable):
adapt_from = on_selectable
else:
+ assert isinstance(left_selectable, FromClause)
adapt_from = left_selectable
(
@@ -1855,7 +1887,7 @@ def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool:
return given.isa(mapper)
-def _getitem(iterable_query, item):
+def _getitem(iterable_query: Query[Any], item: Any) -> Any:
"""calculate __getitem__ in terms of an iterable query object
that also has a slice() method.
@@ -1881,17 +1913,15 @@ def _getitem(iterable_query, item):
isinstance(stop, int) and stop < 0
):
_no_negative_indexes()
- return list(iterable_query)[item]
res = iterable_query.slice(start, stop)
if step is not None:
- return list(res)[None : None : item.step]
+ return list(res)[None : None : item.step] # type: ignore
else:
- return list(res)
+ return list(res) # type: ignore
else:
if item == -1:
_no_negative_indexes()
- return list(iterable_query)[-1]
else:
return list(iterable_query[item : item + 1])[0]
@@ -1933,7 +1963,7 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str:
def _extract_mapped_subtype(
- raw_annotation: Union[type, str],
+ raw_annotation: Optional[_AnnotationScanType],
cls: type,
key: str,
attr_cls: Type[Any],