diff options
Diffstat (limited to 'lib/sqlalchemy/orm/util.py')
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 65 |
1 files changed, 62 insertions, 3 deletions
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 1ef0d7159..d3e36a494 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -9,6 +9,7 @@ from __future__ import annotations import enum +import functools import re import types import typing @@ -46,6 +47,7 @@ from .base import attribute_str as attribute_str # noqa: F401 from .base import class_mapper as class_mapper from .base import InspectionAttr as InspectionAttr from .base import instance_str as instance_str # noqa: F401 +from .base import Mapped from .base import object_mapper as object_mapper from .base import object_state as object_state # noqa: F401 from .base import opt_manager_of_class @@ -79,10 +81,14 @@ from ..sql.elements import ColumnElement from ..sql.elements import KeyedColumnElement from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots -from ..util.typing import de_stringify_annotation -from ..util.typing import eval_name_only +from ..util.typing import de_stringify_annotation as _de_stringify_annotation +from ..util.typing import ( + de_stringify_union_elements as _de_stringify_union_elements, +) +from ..util.typing import eval_name_only as _eval_name_only from ..util.typing import is_origin_of_cls from ..util.typing import Literal +from ..util.typing import Protocol from ..util.typing import typing_get_origin if typing.TYPE_CHECKING: @@ -113,6 +119,7 @@ if typing.TYPE_CHECKING: from ..sql.selectable import Subquery from ..sql.visitors import anon_map from ..util.typing import _AnnotationScanType + from ..util.typing import ArgsTypeProcotol _T = TypeVar("_T", bound=Any) @@ -130,6 +137,58 @@ all_cascades = frozenset( ) +_de_stringify_partial = functools.partial( + functools.partial, locals_=util.immutabledict({"Mapped": Mapped}) +) + +# partial is practically useless as we have to write out the whole +# function and maintain the signature anyway + + +class _DeStringifyAnnotation(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: _AnnotationScanType, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + include_generic: bool = False, + ) -> Type[Any]: + ... + + +de_stringify_annotation = cast( + _DeStringifyAnnotation, _de_stringify_partial(_de_stringify_annotation) +) + + +class _DeStringifyUnionElements(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: ArgsTypeProcotol, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + ) -> Type[Any]: + ... + + +de_stringify_union_elements = cast( + _DeStringifyUnionElements, + _de_stringify_partial(_de_stringify_union_elements), +) + + +class _EvalNameOnly(Protocol): + def __call__(self, name: str, module_name: str) -> Any: + ... + + +eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only)) + + class CascadeOptions(FrozenSet[str]): """Keeps track of the options sent to :paramref:`.relationship.cascade`""" @@ -2271,7 +2330,7 @@ def _extract_mapped_subtype( cls, raw_annotation, originating_module, - _cleanup_mapped_str_annotation, + str_cleanup_fn=_cleanup_mapped_str_annotation, ) except _CleanupError as ce: raise sa_exc.ArgumentError( |