summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/orm/properties.py6
-rw-r--r--lib/sqlalchemy/util/typing.py17
-rw-r--r--test/orm/declarative/test_tm_future_annotations.py82
3 files changed, 102 insertions, 3 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 7d7175678..3d9fe578d 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -51,9 +51,11 @@ from ..sql.schema import Column
from ..sql.schema import SchemaConst
from ..util.typing import de_optionalize_union_types
from ..util.typing import de_stringify_annotation
+from ..util.typing import de_stringify_union_elements
from ..util.typing import is_fwd_ref
from ..util.typing import is_optional_union
from ..util.typing import is_pep593
+from ..util.typing import is_union
from ..util.typing import Self
from ..util.typing import typing_get_args
@@ -655,6 +657,9 @@ class MappedColumn(
if is_fwd_ref(argument):
argument = de_stringify_annotation(cls, argument)
+ if is_union(argument):
+ argument = de_stringify_union_elements(cls, argument)
+
nullable = is_optional_union(argument)
if not self._has_nullable:
@@ -690,6 +695,7 @@ class MappedColumn(
checks = (our_type,)
for check_type in checks:
+
if registry.type_annotation_map:
new_sqltype = registry.type_annotation_map.get(check_type)
if new_sqltype is None:
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 85c1bae72..a0d59a630 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -120,6 +120,19 @@ def de_stringify_annotation(
return annotation # type: ignore
+def de_stringify_union_elements(
+ cls: Type[Any],
+ annotation: _AnnotationScanType,
+ str_cleanup_fn: Optional[Callable[[str], str]] = None,
+) -> Type[Any]:
+ return make_union_type(
+ *[
+ de_stringify_annotation(cls, anno, str_cleanup_fn)
+ for anno in annotation.__args__ # type: ignore
+ ]
+ )
+
+
def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
return type_ is not None and typing_get_origin(type_) is Annotated
@@ -186,7 +199,7 @@ def expand_unions(
return (type_,)
-def is_optional(type_):
+def is_optional(type_: Any) -> bool:
return is_origin_of(
type_,
"Optional",
@@ -199,7 +212,7 @@ def is_optional_union(type_: Any) -> bool:
return is_optional(type_) and NoneType in typing_get_args(type_)
-def is_union(type_):
+def is_union(type_: Any) -> bool:
return is_origin_of(type_, "Union")
diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py
index 74cbebb4d..76ee464fa 100644
--- a/test/orm/declarative/test_tm_future_annotations.py
+++ b/test/orm/declarative/test_tm_future_annotations.py
@@ -1,13 +1,19 @@
from __future__ import annotations
+from decimal import Decimal
from typing import List
+from typing import Optional
from typing import Set
from typing import TypeVar
+from typing import Union
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
+from sqlalchemy import Numeric
+from sqlalchemy import Table
from sqlalchemy.orm import attribute_mapped_collection
+from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import MappedCollection
@@ -16,7 +22,8 @@ from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
-from .test_typed_mapping import MappedColumnTest # noqa
+from sqlalchemy.util import compat
+from .test_typed_mapping import MappedColumnTest as _MappedColumnTest
from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest
"""runs the annotation-sensitive tests from test_typed_mappings while
@@ -28,6 +35,79 @@ having ``from __future__ import annotations`` in effect.
_R = TypeVar("_R")
+class MappedColumnTest(_MappedColumnTest):
+ def test_unions(self):
+ our_type = Numeric(10, 2)
+
+ class Base(DeclarativeBase):
+ type_annotation_map = {Union[float, Decimal]: our_type}
+
+ class User(Base):
+ __tablename__ = "users"
+ __table__: Table
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ data: Mapped[Union[float, Decimal]] = mapped_column()
+ reverse_data: Mapped[Union[Decimal, float]] = mapped_column()
+
+ optional_data: Mapped[
+ Optional[Union[float, Decimal]]
+ ] = mapped_column()
+
+ # use Optional directly
+ reverse_optional_data: Mapped[
+ Optional[Union[Decimal, float]]
+ ] = mapped_column()
+
+ # use Union with None, same as Optional but presents differently
+ # (Optional object with __origin__ Union vs. Union)
+ reverse_u_optional_data: Mapped[
+ Union[Decimal, float, None]
+ ] = mapped_column()
+
+ float_data: Mapped[float] = mapped_column()
+ decimal_data: Mapped[Decimal] = mapped_column()
+
+ if compat.py310:
+ pep604_data: Mapped[float | Decimal] = mapped_column()
+ pep604_reverse: Mapped[Decimal | float] = mapped_column()
+ pep604_optional: Mapped[
+ Decimal | float | None
+ ] = mapped_column()
+ pep604_data_fwd: Mapped["float | Decimal"] = mapped_column()
+ pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column()
+ pep604_optional_fwd: Mapped[
+ "Decimal | float | None"
+ ] = mapped_column()
+
+ is_(User.__table__.c.data.type, our_type)
+ is_false(User.__table__.c.data.nullable)
+ is_(User.__table__.c.reverse_data.type, our_type)
+ is_(User.__table__.c.optional_data.type, our_type)
+ is_true(User.__table__.c.optional_data.nullable)
+
+ is_(User.__table__.c.reverse_optional_data.type, our_type)
+ is_(User.__table__.c.reverse_u_optional_data.type, our_type)
+ is_true(User.__table__.c.reverse_optional_data.nullable)
+ is_true(User.__table__.c.reverse_u_optional_data.nullable)
+
+ is_(User.__table__.c.float_data.type, our_type)
+ is_(User.__table__.c.decimal_data.type, our_type)
+
+ if compat.py310:
+ for suffix in ("", "_fwd"):
+ data_col = User.__table__.c[f"pep604_data{suffix}"]
+ reverse_col = User.__table__.c[f"pep604_reverse{suffix}"]
+ optional_col = User.__table__.c[f"pep604_optional{suffix}"]
+ is_(data_col.type, our_type)
+ is_false(data_col.nullable)
+ is_(reverse_col.type, our_type)
+ is_false(reverse_col.nullable)
+ is_(optional_col.type, our_type)
+ is_true(optional_col.nullable)
+
+
class MappedOneArg(MappedCollection[str, _R]):
pass