diff options
-rw-r--r-- | doc/build/changelog/unreleased_20/8688.rst | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 18 | ||||
-rw-r--r-- | test/orm/declarative/test_dc_transforms.py | 26 |
3 files changed, 49 insertions, 3 deletions
diff --git a/doc/build/changelog/unreleased_20/8688.rst b/doc/build/changelog/unreleased_20/8688.rst new file mode 100644 index 000000000..7ae4d2b0d --- /dev/null +++ b/doc/build/changelog/unreleased_20/8688.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, orm + :tickets: 8688 + + Fixed issue with new dataclass mapping feature where arguments passed to + the dataclasses API could sometimes be mis-ordered when dealing with mixins + that override :func:`_orm.mapped_column` declarations, leading to + initializer problems. diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index eed04025d..ef2c2f3c9 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -420,7 +420,7 @@ class _ClassScanMapperConfig(_MapperConfig): registry: _RegistryType clsdict_view: _ClassDict - collected_annotations: Dict[str, Tuple[Any, Any, Any, bool]] + collected_annotations: Dict[str, Tuple[Any, Any, Any, bool, Any]] collected_attributes: Dict[str, Any] local_table: Optional[FromClause] persist_selectable: Optional[FromClause] @@ -831,7 +831,6 @@ class _ClassScanMapperConfig(_MapperConfig): # acting like that for now. if isinstance(obj, (Column, MappedColumn)): - self._collect_annotation(name, annotation, True, obj) # already copied columns to the mapped class. continue elif isinstance(obj, MapperProperty): @@ -1000,6 +999,7 @@ class _ClassScanMapperConfig(_MapperConfig): mapped_container, mapped_anno, is_dc, + attr_value, ) in self.collected_annotations.items() ) ] @@ -1018,6 +1018,7 @@ class _ClassScanMapperConfig(_MapperConfig): for k, v in defaults.items(): setattr(self.cls, k, v) + self.cls.__annotations__ = annotations self._assert_dc_arguments(dataclass_setup_arguments) @@ -1056,6 +1057,10 @@ class _ClassScanMapperConfig(_MapperConfig): expect_mapped: Optional[bool], attr_value: Any, ) -> Any: + + if name in self.collected_annotations: + return self.collected_annotations[name][4] + if raw_annotation is None: return attr_value @@ -1105,6 +1110,7 @@ class _ClassScanMapperConfig(_MapperConfig): mapped_container, extracted_mapped_annotation, is_dataclass, + attr_value, ) return attr_value @@ -1133,6 +1139,7 @@ class _ClassScanMapperConfig(_MapperConfig): # copy mixin columns to the mapped class for name, obj, annotation, is_dataclass in attributes_for_class(): + if ( not fixed_table and obj is None @@ -1146,6 +1153,9 @@ class _ClassScanMapperConfig(_MapperConfig): setattr(cls, name, obj) elif isinstance(obj, (Column, MappedColumn)): + + obj = self._collect_annotation(name, annotation, True, obj) + if attribute_is_overridden(name, obj): # if column has been overridden # (like by the InstrumentedAttribute of the @@ -1176,6 +1186,7 @@ class _ClassScanMapperConfig(_MapperConfig): locally_collected_attributes[name] = copy_ setattr(cls, name, copy_) + return locally_collected_attributes def _extract_mappable_attributes(self) -> None: @@ -1260,8 +1271,9 @@ class _ClassScanMapperConfig(_MapperConfig): mapped_container, extracted_mapped_annotation, is_dataclass, + attr_value, ) = self.collected_annotations.get( - k, (None, None, None, False) + k, (None, None, None, False, None) ) value.declarative_scan( self.registry, diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index bff9482ec..b467644bf 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -14,6 +14,7 @@ from unittest import mock from typing_extensions import Annotated +from sqlalchemy import BigInteger from sqlalchemy import Column from sqlalchemy import exc from sqlalchemy import ForeignKey @@ -553,6 +554,31 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): eq_(fas.args, ["self", "id"]) eq_(fas.kwonlyargs, ["data"]) + def test_mapped_column_overrides(self, dc_decl_base): + """test #8688""" + + class TriggeringMixin: + mixin_value: Mapped[int] = mapped_column(BigInteger) + + class NonTriggeringMixin: + mixin_value: Mapped[int] + + class Foo(dc_decl_base, TriggeringMixin): + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + foo_value: Mapped[float] = mapped_column(default=78) + + class Bar(dc_decl_base, NonTriggeringMixin): + __tablename__ = "bar" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + bar_value: Mapped[float] = mapped_column(default=78) + + f1 = Foo(mixin_value=5) + eq_(f1.foo_value, 78) + + b1 = Bar(mixin_value=5) + eq_(b1.bar_value, 78) + class RelationshipDefaultFactoryTest(fixtures.TestBase): def test_list(self, dc_decl_base: Type[MappedAsDataclass]): |