summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/properties.py6
-rw-r--r--lib/sqlalchemy/util/typing.py17
2 files changed, 21 insertions, 2 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 7d7175678..3d9fe578d 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -51,9 +51,11 @@ from ..sql.schema import Column
from ..sql.schema import SchemaConst
from ..util.typing import de_optionalize_union_types
from ..util.typing import de_stringify_annotation
+from ..util.typing import de_stringify_union_elements
from ..util.typing import is_fwd_ref
from ..util.typing import is_optional_union
from ..util.typing import is_pep593
+from ..util.typing import is_union
from ..util.typing import Self
from ..util.typing import typing_get_args
@@ -655,6 +657,9 @@ class MappedColumn(
if is_fwd_ref(argument):
argument = de_stringify_annotation(cls, argument)
+ if is_union(argument):
+ argument = de_stringify_union_elements(cls, argument)
+
nullable = is_optional_union(argument)
if not self._has_nullable:
@@ -690,6 +695,7 @@ class MappedColumn(
checks = (our_type,)
for check_type in checks:
+
if registry.type_annotation_map:
new_sqltype = registry.type_annotation_map.get(check_type)
if new_sqltype is None:
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 85c1bae72..a0d59a630 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -120,6 +120,19 @@ def de_stringify_annotation(
return annotation # type: ignore
+def de_stringify_union_elements(
+ cls: Type[Any],
+ annotation: _AnnotationScanType,
+ str_cleanup_fn: Optional[Callable[[str], str]] = None,
+) -> Type[Any]:
+ return make_union_type(
+ *[
+ de_stringify_annotation(cls, anno, str_cleanup_fn)
+ for anno in annotation.__args__ # type: ignore
+ ]
+ )
+
+
def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
return type_ is not None and typing_get_origin(type_) is Annotated
@@ -186,7 +199,7 @@ def expand_unions(
return (type_,)
-def is_optional(type_):
+def is_optional(type_: Any) -> bool:
return is_origin_of(
type_,
"Optional",
@@ -199,7 +212,7 @@ def is_optional_union(type_: Any) -> bool:
return is_optional(type_) and NoneType in typing_get_args(type_)
-def is_union(type_):
+def is_union(type_: Any) -> bool:
return is_origin_of(type_, "Union")