diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 27 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/util/typing.py | 2 |
4 files changed, 32 insertions, 12 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 268a1d57a..c23ea0311 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -64,6 +64,8 @@ from ..sql.schema import Column from ..sql.schema import Table from ..util import topological from ..util.typing import _AnnotationScanType +from ..util.typing import de_stringify_annotation +from ..util.typing import is_fwd_ref from ..util.typing import Protocol from ..util.typing import TypedDict from ..util.typing import typing_get_args @@ -1120,6 +1122,15 @@ class _ClassScanMapperConfig(_MapperConfig): if attr_value is None: for elem in typing_get_args(extracted_mapped_annotation): + if isinstance(elem, str) or is_fwd_ref( + elem, check_generic=True + ): + elem = de_stringify_annotation( + self.cls, + elem, + originating_class.__module__, + include_generic=True, + ) # look in Annotated[...] for an ORM construct, # such as Annotated[int, mapped_column(primary_key=True)] if isinstance(elem, _IntrospectsAnnotations): diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 84d15360d..55c7e9290 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -52,6 +52,8 @@ from .. import util from ..sql import expression from ..sql import operators from ..sql.elements import BindParameter +from ..util.typing import de_stringify_annotation +from ..util.typing import is_fwd_ref from ..util.typing import is_pep593 from ..util.typing import typing_get_args @@ -351,18 +353,23 @@ class CompositeProperty( argument = typing_get_args(argument)[0] if argument and self.composite_class is None: - if isinstance(argument, str) or hasattr( - argument, "__forward_arg__" + if isinstance(argument, str) or is_fwd_ref( + argument, check_generic=True ): - str_arg = ( - argument.__forward_arg__ - if hasattr(argument, "__forward_arg__") - else str(argument) - ) - raise sa_exc.ArgumentError( - f"Can't use forward ref {argument} for composite " - f"class argument; set up the type as Mapped[{str_arg}]" + if originating_module is None: + str_arg = ( + argument.__forward_arg__ + if hasattr(argument, "__forward_arg__") + else str(argument) + ) + raise sa_exc.ArgumentError( + f"Can't use forward ref {argument} for composite " + f"class argument; set up the type as Mapped[{str_arg}]" + ) + argument = de_stringify_annotation( + cls, argument, originating_module, include_generic=True ) + self.composite_class = argument if is_dataclass(self.composite_class): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index b8e1521a2..e89e3c356 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -684,7 +684,9 @@ class MappedColumn( ) -> None: sqltype = self.column.type - if is_fwd_ref(argument, check_generic=True): + if isinstance(argument, str) or is_fwd_ref( + argument, check_generic=True + ): assert originating_module is not None argument = de_stringify_annotation( cls, argument, originating_module, include_generic=True diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index dcbc15825..0c8e5a633 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -174,7 +174,7 @@ def eval_expression(expression: str, module_name: str) -> Any: annotation = eval(expression, base_globals, None) except Exception as err: raise NameError( - f"Could not de-stringify annotation {expression}" + f"Could not de-stringify annotation {expression!r}" ) from err else: return annotation |