diff options
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/typing.py | 17 | ||||
| -rw-r--r-- | test/orm/declarative/test_tm_future_annotations.py | 82 |
3 files changed, 102 insertions, 3 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 7d7175678..3d9fe578d 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -51,9 +51,11 @@ from ..sql.schema import Column from ..sql.schema import SchemaConst from ..util.typing import de_optionalize_union_types from ..util.typing import de_stringify_annotation +from ..util.typing import de_stringify_union_elements from ..util.typing import is_fwd_ref from ..util.typing import is_optional_union from ..util.typing import is_pep593 +from ..util.typing import is_union from ..util.typing import Self from ..util.typing import typing_get_args @@ -655,6 +657,9 @@ class MappedColumn( if is_fwd_ref(argument): argument = de_stringify_annotation(cls, argument) + if is_union(argument): + argument = de_stringify_union_elements(cls, argument) + nullable = is_optional_union(argument) if not self._has_nullable: @@ -690,6 +695,7 @@ class MappedColumn( checks = (our_type,) for check_type in checks: + if registry.type_annotation_map: new_sqltype = registry.type_annotation_map.get(check_type) if new_sqltype is None: diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 85c1bae72..a0d59a630 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -120,6 +120,19 @@ def de_stringify_annotation( return annotation # type: ignore +def de_stringify_union_elements( + cls: Type[Any], + annotation: _AnnotationScanType, + str_cleanup_fn: Optional[Callable[[str], str]] = None, +) -> Type[Any]: + return make_union_type( + *[ + de_stringify_annotation(cls, anno, str_cleanup_fn) + for anno in annotation.__args__ # type: ignore + ] + ) + + def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: return type_ is not None and typing_get_origin(type_) is Annotated @@ -186,7 +199,7 @@ def expand_unions( return (type_,) -def is_optional(type_): +def is_optional(type_: Any) -> bool: return is_origin_of( type_, "Optional", @@ -199,7 +212,7 @@ def is_optional_union(type_: Any) -> bool: return is_optional(type_) and NoneType in typing_get_args(type_) -def is_union(type_): +def is_union(type_: Any) -> bool: return is_origin_of(type_, "Union") diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py index 74cbebb4d..76ee464fa 100644 --- a/test/orm/declarative/test_tm_future_annotations.py +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -1,13 +1,19 @@ from __future__ import annotations +from decimal import Decimal from typing import List +from typing import Optional from typing import Set from typing import TypeVar +from typing import Union from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import Integer +from sqlalchemy import Numeric +from sqlalchemy import Table from sqlalchemy.orm import attribute_mapped_collection +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedCollection @@ -16,7 +22,8 @@ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true -from .test_typed_mapping import MappedColumnTest # noqa +from sqlalchemy.util import compat +from .test_typed_mapping import MappedColumnTest as _MappedColumnTest from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest """runs the annotation-sensitive tests from test_typed_mappings while @@ -28,6 +35,79 @@ having ``from __future__ import annotations`` in effect. _R = TypeVar("_R") +class MappedColumnTest(_MappedColumnTest): + def test_unions(self): + our_type = Numeric(10, 2) + + class Base(DeclarativeBase): + type_annotation_map = {Union[float, Decimal]: our_type} + + class User(Base): + __tablename__ = "users" + __table__: Table + + id: Mapped[int] = mapped_column(primary_key=True) + + data: Mapped[Union[float, Decimal]] = mapped_column() + reverse_data: Mapped[Union[Decimal, float]] = mapped_column() + + optional_data: Mapped[ + Optional[Union[float, Decimal]] + ] = mapped_column() + + # use Optional directly + reverse_optional_data: Mapped[ + Optional[Union[Decimal, float]] + ] = mapped_column() + + # use Union with None, same as Optional but presents differently + # (Optional object with __origin__ Union vs. Union) + reverse_u_optional_data: Mapped[ + Union[Decimal, float, None] + ] = mapped_column() + + float_data: Mapped[float] = mapped_column() + decimal_data: Mapped[Decimal] = mapped_column() + + if compat.py310: + pep604_data: Mapped[float | Decimal] = mapped_column() + pep604_reverse: Mapped[Decimal | float] = mapped_column() + pep604_optional: Mapped[ + Decimal | float | None + ] = mapped_column() + pep604_data_fwd: Mapped["float | Decimal"] = mapped_column() + pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column() + pep604_optional_fwd: Mapped[ + "Decimal | float | None" + ] = mapped_column() + + is_(User.__table__.c.data.type, our_type) + is_false(User.__table__.c.data.nullable) + is_(User.__table__.c.reverse_data.type, our_type) + is_(User.__table__.c.optional_data.type, our_type) + is_true(User.__table__.c.optional_data.nullable) + + is_(User.__table__.c.reverse_optional_data.type, our_type) + is_(User.__table__.c.reverse_u_optional_data.type, our_type) + is_true(User.__table__.c.reverse_optional_data.nullable) + is_true(User.__table__.c.reverse_u_optional_data.nullable) + + is_(User.__table__.c.float_data.type, our_type) + is_(User.__table__.c.decimal_data.type, our_type) + + if compat.py310: + for suffix in ("", "_fwd"): + data_col = User.__table__.c[f"pep604_data{suffix}"] + reverse_col = User.__table__.c[f"pep604_reverse{suffix}"] + optional_col = User.__table__.c[f"pep604_optional{suffix}"] + is_(data_col.type, our_type) + is_false(data_col.nullable) + is_(reverse_col.type, our_type) + is_false(reverse_col.nullable) + is_(optional_col.type, our_type) + is_true(optional_col.nullable) + + class MappedOneArg(MappedCollection[str, _R]): pass |
