summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/util.py')
-rw-r--r--lib/sqlalchemy/orm/util.py65
1 files changed, 62 insertions, 3 deletions
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 1ef0d7159..d3e36a494 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -9,6 +9,7 @@
from __future__ import annotations
import enum
+import functools
import re
import types
import typing
@@ -46,6 +47,7 @@ from .base import attribute_str as attribute_str # noqa: F401
from .base import class_mapper as class_mapper
from .base import InspectionAttr as InspectionAttr
from .base import instance_str as instance_str # noqa: F401
+from .base import Mapped
from .base import object_mapper as object_mapper
from .base import object_state as object_state # noqa: F401
from .base import opt_manager_of_class
@@ -79,10 +81,14 @@ from ..sql.elements import ColumnElement
from ..sql.elements import KeyedColumnElement
from ..sql.selectable import FromClause
from ..util.langhelpers import MemoizedSlots
-from ..util.typing import de_stringify_annotation
-from ..util.typing import eval_name_only
+from ..util.typing import de_stringify_annotation as _de_stringify_annotation
+from ..util.typing import (
+ de_stringify_union_elements as _de_stringify_union_elements,
+)
+from ..util.typing import eval_name_only as _eval_name_only
from ..util.typing import is_origin_of_cls
from ..util.typing import Literal
+from ..util.typing import Protocol
from ..util.typing import typing_get_origin
if typing.TYPE_CHECKING:
@@ -113,6 +119,7 @@ if typing.TYPE_CHECKING:
from ..sql.selectable import Subquery
from ..sql.visitors import anon_map
from ..util.typing import _AnnotationScanType
+ from ..util.typing import ArgsTypeProcotol
_T = TypeVar("_T", bound=Any)
@@ -130,6 +137,58 @@ all_cascades = frozenset(
)
+_de_stringify_partial = functools.partial(
+ functools.partial, locals_=util.immutabledict({"Mapped": Mapped})
+)
+
+# partial is practically useless as we have to write out the whole
+# function and maintain the signature anyway
+
+
+class _DeStringifyAnnotation(Protocol):
+ def __call__(
+ self,
+ cls: Type[Any],
+ annotation: _AnnotationScanType,
+ originating_module: str,
+ *,
+ str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
+ include_generic: bool = False,
+ ) -> Type[Any]:
+ ...
+
+
+de_stringify_annotation = cast(
+ _DeStringifyAnnotation, _de_stringify_partial(_de_stringify_annotation)
+)
+
+
+class _DeStringifyUnionElements(Protocol):
+ def __call__(
+ self,
+ cls: Type[Any],
+ annotation: ArgsTypeProcotol,
+ originating_module: str,
+ *,
+ str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
+ ) -> Type[Any]:
+ ...
+
+
+de_stringify_union_elements = cast(
+ _DeStringifyUnionElements,
+ _de_stringify_partial(_de_stringify_union_elements),
+)
+
+
+class _EvalNameOnly(Protocol):
+ def __call__(self, name: str, module_name: str) -> Any:
+ ...
+
+
+eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only))
+
+
class CascadeOptions(FrozenSet[str]):
"""Keeps track of the options sent to
:paramref:`.relationship.cascade`"""
@@ -2271,7 +2330,7 @@ def _extract_mapped_subtype(
cls,
raw_annotation,
originating_module,
- _cleanup_mapped_str_annotation,
+ str_cleanup_fn=_cleanup_mapped_str_annotation,
)
except _CleanupError as ce:
raise sa_exc.ArgumentError(