diff options
Diffstat (limited to 'lib/sqlalchemy/orm/util.py')
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 45 |
1 files changed, 27 insertions, 18 deletions
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index c50cc5bac..520c95672 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1927,7 +1927,7 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any: def _is_mapped_annotation( - raw_annotation: Union[type, str], cls: Type[Any] + raw_annotation: _AnnotationScanType, cls: Type[Any] ) -> bool: annotated = de_stringify_annotation(cls, raw_annotation) return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") @@ -1969,9 +1969,14 @@ def _extract_mapped_subtype( attr_cls: Type[Any], required: bool, is_dataclass_field: bool, - superclasses: Optional[Tuple[Type[Any], ...]] = None, + expect_mapped: bool = True, ) -> Optional[Union[type, str]]: + """given an annotation, figure out if it's ``Mapped[something]`` and if + so, return the ``something`` part. + Includes error raise scenarios and other options. + + """ if raw_annotation is None: if required: @@ -1989,25 +1994,29 @@ def _extract_mapped_subtype( if is_dataclass_field: return annotated else: - # TODO: there don't seem to be tests for the failure - # conditions here - if not hasattr(annotated, "__origin__") or ( - not issubclass( - annotated.__origin__, # type: ignore - superclasses if superclasses else attr_cls, - ) - and not issubclass(attr_cls, annotated.__origin__) # type: ignore + if not hasattr(annotated, "__origin__") or not is_origin_of( + annotated, "Mapped", module="sqlalchemy.orm" ): - our_annotated_str = ( - annotated.__name__ + anno_name = ( + getattr(annotated, "__name__", None) if not isinstance(annotated, str) - else repr(annotated) - ) - raise sa_exc.ArgumentError( - f'Type annotation for "{cls.__name__}.{key}" should use the ' - f'syntax "Mapped[{our_annotated_str}]" or ' - f'"{attr_cls.__name__}[{our_annotated_str}]".' + else None ) + if anno_name is None: + our_annotated_str = repr(annotated) + else: + our_annotated_str = anno_name + + if expect_mapped: + raise sa_exc.ArgumentError( + f'Type annotation for "{cls.__name__}.{key}" ' + "should use the " + f'syntax "Mapped[{our_annotated_str}]" or ' + f'"{attr_cls.__name__}[{our_annotated_str}]".' + ) + + else: + return annotated if len(annotated.__args__) != 1: # type: ignore raise sa_exc.ArgumentError( |