summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/util.py')
-rw-r--r--lib/sqlalchemy/orm/util.py45
1 files changed, 27 insertions, 18 deletions
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index c50cc5bac..520c95672 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -1927,7 +1927,7 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any:
def _is_mapped_annotation(
- raw_annotation: Union[type, str], cls: Type[Any]
+ raw_annotation: _AnnotationScanType, cls: Type[Any]
) -> bool:
annotated = de_stringify_annotation(cls, raw_annotation)
return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
@@ -1969,9 +1969,14 @@ def _extract_mapped_subtype(
attr_cls: Type[Any],
required: bool,
is_dataclass_field: bool,
- superclasses: Optional[Tuple[Type[Any], ...]] = None,
+ expect_mapped: bool = True,
) -> Optional[Union[type, str]]:
+ """given an annotation, figure out if it's ``Mapped[something]`` and if
+ so, return the ``something`` part.
+ Includes error raise scenarios and other options.
+
+ """
if raw_annotation is None:
if required:
@@ -1989,25 +1994,29 @@ def _extract_mapped_subtype(
if is_dataclass_field:
return annotated
else:
- # TODO: there don't seem to be tests for the failure
- # conditions here
- if not hasattr(annotated, "__origin__") or (
- not issubclass(
- annotated.__origin__, # type: ignore
- superclasses if superclasses else attr_cls,
- )
- and not issubclass(attr_cls, annotated.__origin__) # type: ignore
+ if not hasattr(annotated, "__origin__") or not is_origin_of(
+ annotated, "Mapped", module="sqlalchemy.orm"
):
- our_annotated_str = (
- annotated.__name__
+ anno_name = (
+ getattr(annotated, "__name__", None)
if not isinstance(annotated, str)
- else repr(annotated)
- )
- raise sa_exc.ArgumentError(
- f'Type annotation for "{cls.__name__}.{key}" should use the '
- f'syntax "Mapped[{our_annotated_str}]" or '
- f'"{attr_cls.__name__}[{our_annotated_str}]".'
+ else None
)
+ if anno_name is None:
+ our_annotated_str = repr(annotated)
+ else:
+ our_annotated_str = anno_name
+
+ if expect_mapped:
+ raise sa_exc.ArgumentError(
+ f'Type annotation for "{cls.__name__}.{key}" '
+ "should use the "
+ f'syntax "Mapped[{our_annotated_str}]" or '
+ f'"{attr_cls.__name__}[{our_annotated_str}]".'
+ )
+
+ else:
+ return annotated
if len(annotated.__args__) != 1: # type: ignore
raise sa_exc.ArgumentError(