diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/util/typing.py | 17 |
2 files changed, 21 insertions, 2 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 7d7175678..3d9fe578d 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -51,9 +51,11 @@ from ..sql.schema import Column from ..sql.schema import SchemaConst from ..util.typing import de_optionalize_union_types from ..util.typing import de_stringify_annotation +from ..util.typing import de_stringify_union_elements from ..util.typing import is_fwd_ref from ..util.typing import is_optional_union from ..util.typing import is_pep593 +from ..util.typing import is_union from ..util.typing import Self from ..util.typing import typing_get_args @@ -655,6 +657,9 @@ class MappedColumn( if is_fwd_ref(argument): argument = de_stringify_annotation(cls, argument) + if is_union(argument): + argument = de_stringify_union_elements(cls, argument) + nullable = is_optional_union(argument) if not self._has_nullable: @@ -690,6 +695,7 @@ class MappedColumn( checks = (our_type,) for check_type in checks: + if registry.type_annotation_map: new_sqltype = registry.type_annotation_map.get(check_type) if new_sqltype is None: diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 85c1bae72..a0d59a630 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -120,6 +120,19 @@ def de_stringify_annotation( return annotation # type: ignore +def de_stringify_union_elements( + cls: Type[Any], + annotation: _AnnotationScanType, + str_cleanup_fn: Optional[Callable[[str], str]] = None, +) -> Type[Any]: + return make_union_type( + *[ + de_stringify_annotation(cls, anno, str_cleanup_fn) + for anno in annotation.__args__ # type: ignore + ] + ) + + def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: return type_ is not None and typing_get_origin(type_) is Annotated @@ -186,7 +199,7 @@ def expand_unions( return (type_,) -def is_optional(type_): +def is_optional(type_: Any) -> bool: return is_origin_of( type_, "Optional", @@ -199,7 +212,7 @@ def is_optional_union(type_: Any) -> bool: return is_optional(type_) and NoneType in typing_get_args(type_) -def is_union(type_): +def is_union(type_: Any) -> bool: return is_origin_of(type_, "Union") |