diff options
-rw-r--r-- | lib/sqlalchemy/orm/__init__.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/_orm_constructors.py | 167 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/decl_api.py | 142 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 315 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 48 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/instrumentation.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 97 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/properties.py | 63 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 26 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 45 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 25 | ||||
-rw-r--r-- | lib/sqlalchemy/util/compat.py | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/util/typing.py | 8 | ||||
-rw-r--r-- | pyproject.toml | 1 | ||||
-rw-r--r-- | setup.cfg | 2 | ||||
-rw-r--r-- | test/orm/declarative/test_dc_transforms.py | 816 | ||||
-rw-r--r-- | test/orm/declarative/test_typed_mapping.py | 46 |
17 files changed, 1661 insertions, 163 deletions
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index b7d1df532..4f19ba946 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -60,6 +60,7 @@ from .decl_api import DeclarativeBaseNoMeta as DeclarativeBaseNoMeta from .decl_api import DeclarativeMeta as DeclarativeMeta from .decl_api import declared_attr as declared_attr from .decl_api import has_inherited_table as has_inherited_table +from .decl_api import MappedAsDataclass as MappedAsDataclass from .decl_api import registry as registry from .decl_api import synonym_for as synonym_for from .descriptor_props import Composite as Composite diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 0692cac09..ece6a52be 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -21,9 +21,9 @@ from typing import Union from . import mapperlib as mapperlib from ._typing import _O -from .base import Mapped from .descriptor_props import Composite from .descriptor_props import Synonym +from .interfaces import _AttributeOptions from .properties import ColumnProperty from .properties import MappedColumn from .query import AliasOption @@ -37,6 +37,8 @@ from .util import LoaderCriteriaOption from .. import sql from .. import util from ..exc import InvalidRequestError +from ..sql._typing import _no_kw +from ..sql.base import _NoArg from ..sql.base import SchemaEventTarget from ..sql.schema import SchemaConst from ..sql.selectable import FromClause @@ -105,6 +107,10 @@ def mapped_column( Union[_TypeEngineArgument[Any], SchemaEventTarget] ] = None, *args: SchemaEventTarget, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, nullable: Optional[ Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] ] = SchemaConst.NULL_UNSPECIFIED, @@ -113,7 +119,6 @@ def mapped_column( name: Optional[str] = None, type_: Optional[_TypeEngineArgument[Any]] = None, autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, doc: Optional[str] = None, key: Optional[str] = None, index: Optional[bool] = None, @@ -300,6 +305,12 @@ def mapped_column( type_=type_, autoincrement=autoincrement, default=default, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), doc=doc, key=key, index=index, @@ -325,6 +336,10 @@ def column_property( deferred: bool = False, raiseload: bool = False, comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, active_history: bool = False, expire_on_flush: bool = True, info: Optional[_InfoType] = None, @@ -416,6 +431,12 @@ def column_property( return ColumnProperty( column, *additional_columns, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), group=group, deferred=deferred, raiseload=raiseload, @@ -429,25 +450,61 @@ def column_property( @overload def composite( - class_: Type[_CC], + _class_or_attr: Type[_CC], *attrs: _CompositeAttrType[Any], - **kwargs: Any, + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, ) -> Composite[_CC]: ... @overload def composite( + _class_or_attr: _CompositeAttrType[Any], *attrs: _CompositeAttrType[Any], - **kwargs: Any, + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, ) -> Composite[Any]: ... def composite( - class_: Any = None, + _class_or_attr: Union[ + None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any] + ] = None, *attrs: _CompositeAttrType[Any], - **kwargs: Any, + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, ) -> Composite[Any]: r"""Return a composite column-based property for use with a Mapper. @@ -497,7 +554,26 @@ def composite( :attr:`.MapperProperty.info` attribute of this object. """ - return Composite(class_, *attrs, **kwargs) + if __kw: + raise _no_kw() + + return Composite( + _class_or_attr, + *attrs, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), + group=group, + deferred=deferred, + raiseload=raiseload, + comparator_factory=comparator_factory, + active_history=active_history, + info=info, + doc=doc, + ) def with_loader_criteria( @@ -700,6 +776,10 @@ def relationship( post_update: bool = False, cascade: str = "save-update, merge", viewonly: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, lazy: _LazyLoadArgumentType = "select", passive_deletes: Union[Literal["all"], bool] = False, passive_updates: bool = True, @@ -1532,6 +1612,12 @@ def relationship( post_update=post_update, cascade=cascade, viewonly=viewonly, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), lazy=lazy, passive_deletes=passive_deletes, passive_updates=passive_updates, @@ -1559,6 +1645,10 @@ def synonym( map_column: Optional[bool] = None, descriptor: Optional[Any] = None, comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, info: Optional[_InfoType] = None, doc: Optional[str] = None, ) -> Synonym[Any]: @@ -1670,6 +1760,12 @@ def synonym( map_column=map_column, descriptor=descriptor, comparator_factory=comparator_factory, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), doc=doc, info=info, ) @@ -1784,7 +1880,17 @@ def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument: def deferred( column: _ORMColumnExprArgument[_T], *additional_columns: _ORMColumnExprArgument[Any], - **kw: Any, + group: Optional[str] = None, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, ) -> ColumnProperty[_T]: r"""Indicate a column-based mapped attribute that by default will not load unless accessed. @@ -1803,21 +1909,41 @@ def deferred( :ref:`deferred_raiseload` - :param \**kw: additional keyword arguments passed to - :class:`.ColumnProperty`. + Additional arguments are the same as that of :func:`_orm.column_property`. .. seealso:: :ref:`deferred` """ - kw["deferred"] = True - return ColumnProperty(column, *additional_columns, **kw) + return ColumnProperty( + column, + *additional_columns, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), + group=group, + deferred=True, + raiseload=raiseload, + comparator_factory=comparator_factory, + active_history=active_history, + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) def query_expression( default_expr: _ORMColumnExprArgument[_T] = sql.null(), -) -> Mapped[_T]: + *, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> ColumnProperty[_T]: """Indicate an attribute that populates from a query-time SQL expression. :param default_expr: Optional SQL expression object that will be used in @@ -1840,7 +1966,18 @@ def query_expression( :ref:`mapper_querytime_expression` """ - prop = ColumnProperty(default_expr) + prop = ColumnProperty( + default_expr, + attribute_options=_AttributeOptions( + _NoArg.NO_ARG, + repr, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + ), + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) prop.strategy_key = (("query_expression", True),) return prop diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 1c343b04c..553a50107 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -33,6 +33,13 @@ from . import clsregistry from . import instrumentation from . import interfaces from . import mapperlib +from ._orm_constructors import column_property +from ._orm_constructors import composite +from ._orm_constructors import deferred +from ._orm_constructors import mapped_column +from ._orm_constructors import query_expression +from ._orm_constructors import relationship +from ._orm_constructors import synonym from .attributes import InstrumentedAttribute from .base import _inspect_mapped_class from .base import Mapped @@ -42,8 +49,13 @@ from .decl_base import _declarative_constructor from .decl_base import _DeferredMapperConfig from .decl_base import _del_attribute from .decl_base import _mapper +from .descriptor_props import Composite +from .descriptor_props import Synonym from .descriptor_props import Synonym as _orm_synonym from .mapper import Mapper +from .properties import ColumnProperty +from .properties import MappedColumn +from .relationships import Relationship from .state import InstanceState from .. import exc from .. import inspection @@ -60,9 +72,9 @@ from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _O from ._typing import _RegistryType - from .descriptor_props import Synonym from .instrumentation import ClassManager from .interfaces import MapperProperty + from .state import InstanceState # noqa from ..sql._typing import _TypeEngineArgument _T = TypeVar("_T", bound=Any) @@ -120,6 +132,26 @@ class DeclarativeAttributeIntercept( """ +@compat_typing.dataclass_transform( + field_descriptors=( + MappedColumn[Any], + Relationship[Any], + Composite[Any], + ColumnProperty[Any], + Synonym[Any], + mapped_column, + relationship, + composite, + column_property, + synonym, + deferred, + query_expression, + ), +) +class DCTransformDeclarative(DeclarativeAttributeIntercept): + """metaclass that includes @dataclass_transforms""" + + class DeclarativeMeta( _DynamicAttributesType, inspection.Inspectable[Mapper[Any]] ): @@ -543,12 +575,42 @@ class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]): cls._sa_registry.map_declaratively(cls) +class MappedAsDataclass(metaclass=DCTransformDeclarative): + """Mixin class to indicate when mapping this class, also convert it to be + a dataclass. + + .. seealso:: + + :meth:`_orm.registry.mapped_as_dataclass` + + .. versionadded:: 2.0 + """ + + def __init_subclass__( + cls, + init: bool = True, + repr: bool = True, # noqa: A002 + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + ) -> None: + cls._sa_apply_dc_transforms = { + "init": init, + "repr": repr, + "eq": eq, + "order": order, + "unsafe_hash": unsafe_hash, + } + super().__init_subclass__() + + class DeclarativeBase( inspection.Inspectable[InstanceState[Any]], metaclass=DeclarativeAttributeIntercept, ): """Base class used for declarative class definitions. + The :class:`_orm.DeclarativeBase` allows for the creation of new declarative bases in such a way that is compatible with type checkers:: @@ -1121,7 +1183,7 @@ class registry: bases = not isinstance(cls, tuple) and (cls,) or cls - class_dict = dict(registry=self, metadata=metadata) + class_dict: Dict[str, Any] = dict(registry=self, metadata=metadata) if isinstance(cls, type): class_dict["__doc__"] = cls.__doc__ @@ -1142,6 +1204,78 @@ class registry: return metaclass(name, bases, class_dict) + @compat_typing.dataclass_transform( + field_descriptors=( + MappedColumn[Any], + Relationship[Any], + Composite[Any], + ColumnProperty[Any], + Synonym[Any], + mapped_column, + relationship, + composite, + column_property, + synonym, + deferred, + query_expression, + ), + ) + @overload + def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]: + ... + + @overload + def mapped_as_dataclass( + self, + __cls: Literal[None] = ..., + *, + init: bool = True, + repr: bool = True, # noqa: A002 + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + ) -> Callable[[Type[_O]], Type[_O]]: + ... + + def mapped_as_dataclass( + self, + __cls: Optional[Type[_O]] = None, + *, + init: bool = True, + repr: bool = True, # noqa: A002 + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]: + """Class decorator that will apply the Declarative mapping process + to a given class, and additionally convert the class to be a + Python dataclass. + + .. seealso:: + + :meth:`_orm.registry.mapped` + + .. versionadded:: 2.0 + + + """ + + def decorate(cls: Type[_O]) -> Type[_O]: + cls._sa_apply_dc_transforms = { + "init": init, + "repr": repr, + "eq": eq, + "order": order, + "unsafe_hash": unsafe_hash, + } + _as_declarative(self, cls, cls.__dict__) + return cls + + if __cls: + return decorate(__cls) + else: + return decorate + def mapped(self, cls: Type[_O]) -> Type[_O]: """Class decorator that will apply the Declarative mapping process to a given class. @@ -1174,6 +1308,10 @@ class registry: that will apply Declarative mapping to subclasses automatically using a Python metaclass. + .. seealso:: + + :meth:`_orm.registry.mapped_as_dataclass` + """ _as_declarative(self, cls, cls.__dict__) return cls diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index a66421e22..54a272f86 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -10,6 +10,8 @@ from __future__ import annotations import collections +import dataclasses +import re from typing import Any from typing import Callable from typing import cast @@ -40,6 +42,7 @@ from .base import _is_mapped_class from .base import InspectionAttr from .descriptor_props import Composite from .descriptor_props import Synonym +from .interfaces import _AttributeOptions from .interfaces import _IntrospectsAnnotations from .interfaces import _MappedAttribute from .interfaces import _MapsColumns @@ -48,15 +51,18 @@ from .mapper import Mapper as mapper from .mapper import Mapper from .properties import ColumnProperty from .properties import MappedColumn +from .util import _extract_mapped_subtype from .util import _is_mapped_annotation from .util import class_mapper from .. import event from .. import exc from .. import util from ..sql import expression +from ..sql.base import _NoArg from ..sql.schema import Column from ..sql.schema import Table from ..util import topological +from ..util.typing import _AnnotationScanType from ..util.typing import Protocol if TYPE_CHECKING: @@ -392,11 +398,13 @@ class _ClassScanMapperConfig(_MapperConfig): "mapper_args", "mapper_args_fn", "inherits", + "allow_dataclass_fields", + "dataclass_setup_arguments", ) registry: _RegistryType clsdict_view: _ClassDict - collected_annotations: Dict[str, Tuple[Any, bool]] + collected_annotations: Dict[str, Tuple[Any, Any, bool]] collected_attributes: Dict[str, Any] local_table: Optional[FromClause] persist_selectable: Optional[FromClause] @@ -411,6 +419,17 @@ class _ClassScanMapperConfig(_MapperConfig): mapper_args_fn: Optional[Callable[[], Dict[str, Any]]] inherits: Optional[Type[Any]] + dataclass_setup_arguments: Optional[Dict[str, Any]] + """if the class has SQLAlchemy native dataclass parameters, where + we will create a SQLAlchemy dataclass (not a real dataclass). + + """ + + allow_dataclass_fields: bool + """if true, look for dataclass-processed Field objects on the target + class as well as superclasses and extract ORM mapping directives from + the "metadata" attribute of each Field""" + def __init__( self, registry: _RegistryType, @@ -434,10 +453,37 @@ class _ClassScanMapperConfig(_MapperConfig): self.declared_columns = util.OrderedSet() self.column_copies = {} + self.dataclass_setup_arguments = dca = getattr( + self.cls, "_sa_apply_dc_transforms", None + ) + + cld = dataclasses.is_dataclass(cls_) + + sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__") + + # we don't want to consume Field objects from a not-already-dataclass. + # the Field objects won't have their "name" or "type" populated, + # and while it seems like we could just set these on Field as we + # read them, Field is documented as "user read only" and we need to + # stay far away from any off-label use of dataclasses APIs. + if (not cld or dca) and sdk: + raise exc.InvalidRequestError( + "SQLAlchemy mapped dataclasses can't consume mapping " + "information from dataclass.Field() objects if the immediate " + "class is not already a dataclass." + ) + + # if already a dataclass, and __sa_dataclass_metadata_key__ present, + # then also look inside of dataclass.Field() objects yielded by + # dataclasses.get_fields(cls) when scanning for attributes + self.allow_dataclass_fields = bool(sdk and cld) + self._setup_declared_events() self._scan_attributes() + self._setup_dataclasses_transforms() + with mapperlib._CONFIGURE_MUTEX: clsregistry.add_class( self.classname, self.cls, registry._class_registry @@ -477,11 +523,15 @@ class _ClassScanMapperConfig(_MapperConfig): attribute, taking SQLAlchemy-enabled dataclass fields into account. """ - sa_dataclass_metadata_key = _get_immediate_cls_attr( - cls, "__sa_dataclass_metadata_key__" - ) - if sa_dataclass_metadata_key is None: + if self.allow_dataclass_fields: + sa_dataclass_metadata_key = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" + ) + else: + sa_dataclass_metadata_key = None + + if not sa_dataclass_metadata_key: def attribute_is_overridden(key: str, obj: Any) -> bool: return getattr(cls, key) is not obj @@ -551,6 +601,7 @@ class _ClassScanMapperConfig(_MapperConfig): "__dict__", "__weakref__", "_sa_class_manager", + "_sa_apply_dc_transforms", "__dict__", "__weakref__", ] @@ -563,10 +614,6 @@ class _ClassScanMapperConfig(_MapperConfig): adjusting for SQLAlchemy fields embedded in dataclass fields. """ - sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( - cls, "__sa_dataclass_metadata_key__" - ) - cls_annotations = util.get_annotations(cls) cls_vars = vars(cls) @@ -576,7 +623,15 @@ class _ClassScanMapperConfig(_MapperConfig): names = util.merge_lists_w_ordering( [n for n in cls_vars if n not in skip], list(cls_annotations) ) - if sa_dataclass_metadata_key is None: + + if self.allow_dataclass_fields: + sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" + ) + else: + sa_dataclass_metadata_key = None + + if not sa_dataclass_metadata_key: def local_attributes_for_class() -> Iterable[ Tuple[str, Any, Any, bool] @@ -652,45 +707,51 @@ class _ClassScanMapperConfig(_MapperConfig): name, obj, annotation, - is_dataclass, + is_dataclass_field, ) in local_attributes_for_class(): - if name == "__mapper_args__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - if not mapper_args_fn and (not class_mapped or check_decl): - # don't even invoke __mapper_args__ until - # after we've determined everything about the - # mapped table. - # make a copy of it so a class-level dictionary - # is not overwritten when we update column-based - # arguments. - def _mapper_args_fn() -> Dict[str, Any]: - return dict(cls_as_Decl.__mapper_args__) - - mapper_args_fn = _mapper_args_fn - - elif name == "__tablename__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - if not tablename and (not class_mapped or check_decl): - tablename = cls_as_Decl.__tablename__ - elif name == "__table_args__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - if not table_args and (not class_mapped or check_decl): - table_args = cls_as_Decl.__table_args__ - if not isinstance( - table_args, (tuple, dict, type(None)) + if re.match(r"^__.+__$", name): + if name == "__mapper_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not mapper_args_fn and ( + not class_mapped or check_decl ): - raise exc.ArgumentError( - "__table_args__ value must be a tuple, " - "dict, or None" - ) - if base is not cls: - inherited_table_args = True + # don't even invoke __mapper_args__ until + # after we've determined everything about the + # mapped table. + # make a copy of it so a class-level dictionary + # is not overwritten when we update column-based + # arguments. + def _mapper_args_fn() -> Dict[str, Any]: + return dict(cls_as_Decl.__mapper_args__) + + mapper_args_fn = _mapper_args_fn + + elif name == "__tablename__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not tablename and (not class_mapped or check_decl): + tablename = cls_as_Decl.__tablename__ + elif name == "__table_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not table_args and (not class_mapped or check_decl): + table_args = cls_as_Decl.__table_args__ + if not isinstance( + table_args, (tuple, dict, type(None)) + ): + raise exc.ArgumentError( + "__table_args__ value must be a tuple, " + "dict, or None" + ) + if base is not cls: + inherited_table_args = True + else: + # skip all other dunder names + continue elif class_mapped: if _is_declarative_props(obj): util.warn( @@ -706,9 +767,8 @@ class _ClassScanMapperConfig(_MapperConfig): # acting like that for now. if isinstance(obj, (Column, MappedColumn)): - self.collected_annotations[name] = ( - annotation, - False, + self._collect_annotation( + name, annotation, is_dataclass_field, True, obj ) # already copied columns to the mapped class. continue @@ -745,7 +805,7 @@ class _ClassScanMapperConfig(_MapperConfig): ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) else: - if is_dataclass: + if is_dataclass_field: # access attribute using normal class access # first, to see if it's been mapped on a # superclass. note if the dataclasses.field() @@ -789,14 +849,16 @@ class _ClassScanMapperConfig(_MapperConfig): ): ret.doc = obj.__doc__ - self.collected_annotations[name] = ( + self._collect_annotation( + name, obj._collect_return_annotation(), False, + True, + obj, ) elif _is_mapped_annotation(annotation, cls): - self.collected_annotations[name] = ( - annotation, - is_dataclass, + self._collect_annotation( + name, annotation, is_dataclass_field, True, obj ) if obj is None: if not fixed_table: @@ -809,7 +871,7 @@ class _ClassScanMapperConfig(_MapperConfig): # declarative mapping. however, check for some # more common mistakes self._warn_for_decl_attributes(base, name, obj) - elif is_dataclass and ( + elif is_dataclass_field and ( name not in clsdict_view or clsdict_view[name] is not obj ): # here, we are definitely looking at the target class @@ -826,14 +888,12 @@ class _ClassScanMapperConfig(_MapperConfig): obj = obj.fget() collected_attributes[name] = obj - self.collected_annotations[name] = ( - annotation, - True, + self._collect_annotation( + name, annotation, True, False, obj ) else: - self.collected_annotations[name] = ( - annotation, - False, + self._collect_annotation( + name, annotation, False, None, obj ) if ( obj is None @@ -843,6 +903,10 @@ class _ClassScanMapperConfig(_MapperConfig): collected_attributes[name] = MappedColumn() elif name in clsdict_view: collected_attributes[name] = obj + # else if the name is not in the cls.__dict__, + # don't collect it as an attribute. + # we will see the annotation only, which is meaningful + # both for mapping and dataclasses setup if inherited_table_args and not tablename: table_args = None @@ -851,6 +915,77 @@ class _ClassScanMapperConfig(_MapperConfig): self.tablename = tablename self.mapper_args_fn = mapper_args_fn + def _setup_dataclasses_transforms(self) -> None: + + dataclass_setup_arguments = self.dataclass_setup_arguments + if not dataclass_setup_arguments: + return + + manager = instrumentation.manager_of_class(self.cls) + assert manager is not None + + field_list = [ + _AttributeOptions._get_arguments_for_make_dataclass( + key, + anno, + self.collected_attributes.get(key, _NoArg.NO_ARG), + ) + for key, anno in ( + (key, mapped_anno if mapped_anno else raw_anno) + for key, ( + raw_anno, + mapped_anno, + is_dc, + ) in self.collected_annotations.items() + ) + ] + + annotations = {} + defaults = {} + for item in field_list: + if len(item) == 2: + name, tp = item # type: ignore + elif len(item) == 3: + name, tp, spec = item # type: ignore + defaults[name] = spec + else: + assert False + annotations[name] = tp + + for k, v in defaults.items(): + setattr(self.cls, k, v) + self.cls.__annotations__ = annotations + + dataclasses.dataclass(self.cls, **dataclass_setup_arguments) + + def _collect_annotation( + self, + name: str, + raw_annotation: _AnnotationScanType, + is_dataclass: bool, + expect_mapped: Optional[bool], + attr_value: Any, + ) -> None: + + if expect_mapped is None: + expect_mapped = isinstance(attr_value, _MappedAttribute) + + extracted_mapped_annotation = _extract_mapped_subtype( + raw_annotation, + self.cls, + name, + type(attr_value), + required=False, + is_dataclass_field=False, + expect_mapped=expect_mapped and not self.allow_dataclass_fields, + ) + + self.collected_annotations[name] = ( + raw_annotation, + extracted_mapped_annotation, + is_dataclass, + ) + def _warn_for_decl_attributes( self, cls: Type[Any], key: str, c: Any ) -> None: @@ -982,13 +1117,53 @@ class _ClassScanMapperConfig(_MapperConfig): _undefer_column_name( k, self.column_copies.get(value, value) # type: ignore ) - elif isinstance(value, _IntrospectsAnnotations): - annotation, is_dataclass = self.collected_annotations.get( - k, (None, False) - ) - value.declarative_scan( - self.registry, cls, k, annotation, is_dataclass - ) + else: + if isinstance(value, _IntrospectsAnnotations): + ( + annotation, + extracted_mapped_annotation, + is_dataclass, + ) = self.collected_annotations.get(k, (None, None, False)) + value.declarative_scan( + self.registry, + cls, + k, + annotation, + extracted_mapped_annotation, + is_dataclass, + ) + + if ( + isinstance(value, (MapperProperty, _MapsColumns)) + and value._has_dataclass_arguments + and not self.dataclass_setup_arguments + ): + if isinstance(value, MapperProperty): + argnames = [ + "init", + "default_factory", + "repr", + "default", + ] + else: + argnames = ["init", "default_factory", "repr"] + + args = { + a + for a in argnames + if getattr( + value._attribute_options, f"dataclasses_{a}" + ) + is not _NoArg.NO_ARG + } + raise exc.ArgumentError( + f"Attribute '{k}' on class {cls} includes dataclasses " + f"argument(s): " + f"{', '.join(sorted(repr(a) for a in args))} but " + f"class does not specify " + "SQLAlchemy native dataclass configuration." + ) + our_stuff[k] = value def _extract_declared_columns(self) -> None: @@ -997,6 +1172,7 @@ class _ClassScanMapperConfig(_MapperConfig): # extract columns from the class dict declared_columns = self.declared_columns name_to_prop_key = collections.defaultdict(set) + for key, c in list(our_stuff.items()): if isinstance(c, _MapsColumns): @@ -1019,7 +1195,6 @@ class _ClassScanMapperConfig(_MapperConfig): # otherwise, Mapper will map it under the column key. if mp_to_assign is None and key != col.key: our_stuff[key] = col - elif isinstance(c, Column): # undefer previously occurred here, and now occurs earlier. # ensure every column we get here has been named diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 8c89f96aa..a366a9534 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -35,11 +35,11 @@ from .base import LoaderCallableStatus from .base import Mapped from .base import PassiveFlag from .base import SQLORMOperations +from .interfaces import _AttributeOptions from .interfaces import _IntrospectsAnnotations from .interfaces import _MapsColumns from .interfaces import MapperProperty from .interfaces import PropComparator -from .util import _extract_mapped_subtype from .util import _none_set from .. import event from .. import exc as sa_exc @@ -200,24 +200,26 @@ class Composite( def __init__( self, - class_: Union[ + _class_or_attr: Union[ None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any] ] = None, *attrs: _CompositeAttrType[Any], + attribute_options: Optional[_AttributeOptions] = None, active_history: bool = False, deferred: bool = False, group: Optional[str] = None, comparator_factory: Optional[Type[Comparator[_CC]]] = None, info: Optional[_InfoType] = None, + **kwargs: Any, ): - super().__init__() + super().__init__(attribute_options=attribute_options) - if isinstance(class_, (Mapped, str, sql.ColumnElement)): - self.attrs = (class_,) + attrs + if isinstance(_class_or_attr, (Mapped, str, sql.ColumnElement)): + self.attrs = (_class_or_attr,) + attrs # will initialize within declarative_scan self.composite_class = None # type: ignore else: - self.composite_class = class_ # type: ignore + self.composite_class = _class_or_attr # type: ignore self.attrs = attrs self.active_history = active_history @@ -332,19 +334,15 @@ class Composite( cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: - MappedColumn = util.preloaded.orm_properties.MappedColumn - - argument = _extract_mapped_subtype( - annotation, - cls, - key, - MappedColumn, - self.composite_class is None, - is_dataclass_field, - ) - + if ( + self.composite_class is None + and extracted_mapped_annotation is None + ): + self._raise_for_required(key, cls) + argument = extracted_mapped_annotation if argument and self.composite_class is None: if isinstance(argument, str) or hasattr( argument, "__forward_arg__" @@ -371,11 +369,18 @@ class Composite( for param, attr in itertools.zip_longest( insp.parameters.values(), self.attrs ): - if param is None or attr is None: + if param is None: raise sa_exc.ArgumentError( - f"number of arguments to {self.composite_class.__name__} " - f"class and number of attributes don't match" + f"number of composite attributes " + f"{len(self.attrs)} exceeds " + f"that of the number of attributes in class " + f"{self.composite_class.__name__} {len(insp.parameters)}" ) + if attr is None: + # fill in missing attr spots with empty MappedColumn + attr = MappedColumn() + self.attrs += (attr,) + if isinstance(attr, MappedColumn): attr.declarative_scan_for_composite( registry, cls, key, param.name, param.annotation @@ -800,10 +805,11 @@ class Synonym(DescriptorProperty[_T]): map_column: Optional[bool] = None, descriptor: Optional[Any] = None, comparator_factory: Optional[Type[PropComparator[_T]]] = None, + attribute_options: Optional[_AttributeOptions] = None, info: Optional[_InfoType] = None, doc: Optional[str] = None, ): - super().__init__() + super().__init__(attribute_options=attribute_options) self.name = name self.map_column = map_column diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 4fa61b7ce..33de2aee9 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -113,6 +113,7 @@ class ClassManager( "previously known as deferred_scalar_loader" init_method: Optional[Callable[..., None]] + original_init: Optional[Callable[..., None]] = None factory: Optional[_ManagerFactory] @@ -229,7 +230,7 @@ class ClassManager( if finalize and not self._finalized: self._finalize() - def _finalize(self): + def _finalize(self) -> None: if self._finalized: return self._finalized = True @@ -238,14 +239,14 @@ class ClassManager( _instrumentation_factory.dispatch.class_instrument(self.class_) - def __hash__(self): + def __hash__(self) -> int: # type: ignore[override] return id(self) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return other is self @property - def is_mapped(self): + def is_mapped(self) -> bool: return "mapper" in self.__dict__ @HasMemoized.memoized_attribute diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index b5569ce06..e0034061d 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -19,6 +19,7 @@ are exposed when inspecting mappings. from __future__ import annotations import collections +import dataclasses import typing from typing import Any from typing import Callable @@ -27,6 +28,8 @@ from typing import ClassVar from typing import Dict from typing import Iterator from typing import List +from typing import NamedTuple +from typing import NoReturn from typing import Optional from typing import Sequence from typing import Set @@ -51,11 +54,13 @@ from .base import ONETOMANY as ONETOMANY # noqa: F401 from .base import RelationshipDirection as RelationshipDirection # noqa: F401 from .base import SQLORMOperations from .. import ColumnElement +from .. import exc as sa_exc from .. import inspection from .. import util from ..sql import operators from ..sql import roles from ..sql import visitors +from ..sql.base import _NoArg from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey from ..sql.schema import Column @@ -141,6 +146,7 @@ class _IntrospectsAnnotations: cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: """Perform class-specific initializaton at early declarative scanning @@ -150,6 +156,70 @@ class _IntrospectsAnnotations: """ + def _raise_for_required(self, key: str, cls: Type[Any]) -> NoReturn: + raise sa_exc.ArgumentError( + f"Python typing annotation is required for attribute " + f'"{cls.__name__}.{key}" when primary argument(s) for ' + f'"{self.__class__.__name__}" construct are None or not present' + ) + + +class _AttributeOptions(NamedTuple): + """define Python-local attribute behavior options common to all + :class:`.MapperProperty` objects. + + Currently this includes dataclass-generation arguments. + + .. versionadded:: 2.0 + + """ + + dataclasses_init: Union[_NoArg, bool] + dataclasses_repr: Union[_NoArg, bool] + dataclasses_default: Union[_NoArg, Any] + dataclasses_default_factory: Union[_NoArg, Callable[[], Any]] + + def _as_dataclass_field(self) -> Any: + """Return a ``dataclasses.Field`` object given these arguments.""" + + kw: Dict[str, Any] = {} + if self.dataclasses_default_factory is not _NoArg.NO_ARG: + kw["default_factory"] = self.dataclasses_default_factory + if self.dataclasses_default is not _NoArg.NO_ARG: + kw["default"] = self.dataclasses_default + if self.dataclasses_init is not _NoArg.NO_ARG: + kw["init"] = self.dataclasses_init + if self.dataclasses_repr is not _NoArg.NO_ARG: + kw["repr"] = self.dataclasses_repr + + return dataclasses.field(**kw) + + @classmethod + def _get_arguments_for_make_dataclass( + cls, key: str, annotation: Type[Any], elem: _T + ) -> Union[ + Tuple[str, Type[Any]], Tuple[str, Type[Any], dataclasses.Field[Any]] + ]: + """given attribute key, annotation, and value from a class, return + the argument tuple we would pass to dataclasses.make_dataclass() + for this attribute. + + """ + if isinstance(elem, (MapperProperty, _MapsColumns)): + dc_field = elem._attribute_options._as_dataclass_field() + + return (key, annotation, dc_field) + elif elem is not _NoArg.NO_ARG: + # why is typing not erroring on this? + return (key, annotation, elem) + else: + return (key, annotation) + + +_DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions( + _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG +) + class _MapsColumns(_MappedAttribute[_T]): """interface for declarative-capable construct that delivers one or more @@ -158,6 +228,9 @@ class _MapsColumns(_MappedAttribute[_T]): __slots__ = () + _attribute_options: _AttributeOptions + _has_dataclass_arguments: bool + @property def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: """return a MapperProperty to be assigned to the declarative mapping""" @@ -199,6 +272,8 @@ class MapperProperty( __slots__ = ( "_configure_started", "_configure_finished", + "_attribute_options", + "_has_dataclass_arguments", "parent", "key", "info", @@ -241,6 +316,15 @@ class MapperProperty( doc: Optional[str] """optional documentation string""" + _attribute_options: _AttributeOptions + """behavioral options for ORM-enabled Python attributes + + .. versionadded:: 2.0 + + """ + + _has_dataclass_arguments: bool + def _memoized_attr_info(self) -> _InfoType: """Info dictionary associated with the object, allowing user-defined data to be associated with this :class:`.InspectionAttr`. @@ -349,9 +433,20 @@ class MapperProperty( """ - def __init__(self) -> None: + def __init__( + self, attribute_options: Optional[_AttributeOptions] = None + ) -> None: self._configure_started = False self._configure_finished = False + if ( + attribute_options + and attribute_options != _DEFAULT_ATTRIBUTE_OPTIONS + ): + self._has_dataclass_arguments = True + self._attribute_options = attribute_options + else: + self._has_dataclass_arguments = False + self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS def init(self) -> None: """Called after all mappers are created to assemble diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index ad3e9f248..7655f3ae2 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -30,13 +30,14 @@ from . import strategy_options from .descriptor_props import Composite from .descriptor_props import ConcreteInheritedProperty from .descriptor_props import Synonym +from .interfaces import _AttributeOptions +from .interfaces import _DEFAULT_ATTRIBUTE_OPTIONS from .interfaces import _IntrospectsAnnotations from .interfaces import _MapsColumns from .interfaces import MapperProperty from .interfaces import PropComparator from .interfaces import StrategizedProperty from .relationships import Relationship -from .util import _extract_mapped_subtype from .util import _orm_full_deannotate from .. import exc as sa_exc from .. import ForeignKey @@ -45,6 +46,7 @@ from .. import util from ..sql import coercions from ..sql import roles from ..sql import sqltypes +from ..sql.base import _NoArg from ..sql.elements import SQLCoreOperations from ..sql.schema import Column from ..sql.schema import SchemaConst @@ -131,6 +133,7 @@ class ColumnProperty( self, column: _ORMColumnExprArgument[_T], *additional_columns: _ORMColumnExprArgument[Any], + attribute_options: Optional[_AttributeOptions] = None, group: Optional[str] = None, deferred: bool = False, raiseload: bool = False, @@ -141,7 +144,9 @@ class ColumnProperty( doc: Optional[str] = None, _instrument: bool = True, ): - super(ColumnProperty, self).__init__() + super(ColumnProperty, self).__init__( + attribute_options=attribute_options + ) columns = (column,) + additional_columns self._orig_columns = [ coercions.expect(roles.LabeledColumnExprRole, c) for c in columns @@ -193,6 +198,7 @@ class ColumnProperty( cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: column = self.columns[0] @@ -487,13 +493,38 @@ class MappedColumn( "foreign_keys", "_has_nullable", "deferred", + "_attribute_options", + "_has_dataclass_arguments", ) deferred: bool column: Column[_T] foreign_keys: Optional[Set[ForeignKey]] + _attribute_options: _AttributeOptions def __init__(self, *arg: Any, **kw: Any): + self._attribute_options = attr_opts = kw.pop( + "attribute_options", _DEFAULT_ATTRIBUTE_OPTIONS + ) + + self._has_dataclass_arguments = False + + if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS: + if attr_opts.dataclasses_default_factory is not _NoArg.NO_ARG: + self._has_dataclass_arguments = True + kw["default"] = attr_opts.dataclasses_default_factory + elif attr_opts.dataclasses_default is not _NoArg.NO_ARG: + kw["default"] = attr_opts.dataclasses_default + + if ( + attr_opts.dataclasses_init is not _NoArg.NO_ARG + or attr_opts.dataclasses_repr is not _NoArg.NO_ARG + ): + self._has_dataclass_arguments = True + + if "default" in kw and kw["default"] is _NoArg.NO_ARG: + kw.pop("default") + self.deferred = kw.pop("deferred", False) self.column = cast("Column[_T]", Column(*arg, **kw)) self.foreign_keys = self.column.foreign_keys @@ -509,13 +540,19 @@ class MappedColumn( new.deferred = self.deferred new.foreign_keys = new.column.foreign_keys new._has_nullable = self._has_nullable + new._attribute_options = self._attribute_options + new._has_dataclass_arguments = self._has_dataclass_arguments util.set_creation_order(new) return new @property def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: if self.deferred: - return ColumnProperty(self.column, deferred=True) + return ColumnProperty( + self.column, + deferred=True, + attribute_options=self._attribute_options, + ) else: return None @@ -543,6 +580,7 @@ class MappedColumn( cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: column = self.column @@ -553,18 +591,15 @@ class MappedColumn( sqltype = column.type - argument = _extract_mapped_subtype( - annotation, - cls, - key, - MappedColumn, - sqltype._isnull and not self.column.foreign_keys, - is_dataclass_field, - ) - if argument is None: - return + if extracted_mapped_annotation is None: + if sqltype._isnull and not self.column.foreign_keys: + self._raise_for_required(key, cls) + else: + return - self._init_column_for_annotation(cls, registry, argument) + self._init_column_for_annotation( + cls, registry, extracted_mapped_annotation + ) @util.preload_module("sqlalchemy.orm.decl_base") def declarative_scan_for_composite( diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 1186f0f54..deaf52147 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -49,6 +49,7 @@ from .base import class_mapper from .base import LoaderCallableStatus from .base import PassiveFlag from .base import state_str +from .interfaces import _AttributeOptions from .interfaces import _IntrospectsAnnotations from .interfaces import MANYTOMANY from .interfaces import MANYTOONE @@ -56,7 +57,6 @@ from .interfaces import ONETOMANY from .interfaces import PropComparator from .interfaces import RelationshipDirection from .interfaces import StrategizedProperty -from .util import _extract_mapped_subtype from .util import _orm_annotate from .util import _orm_deannotate from .util import CascadeOptions @@ -355,6 +355,7 @@ class Relationship( post_update: bool = False, cascade: str = "save-update, merge", viewonly: bool = False, + attribute_options: Optional[_AttributeOptions] = None, lazy: _LazyLoadArgumentType = "select", passive_deletes: Union[Literal["all"], bool] = False, passive_updates: bool = True, @@ -380,7 +381,7 @@ class Relationship( _local_remote_pairs: Optional[_ColumnPairs] = None, _legacy_inactive_history_style: bool = False, ): - super(Relationship, self).__init__() + super(Relationship, self).__init__(attribute_options=attribute_options) self.uselist = uselist self.argument = argument @@ -1701,18 +1702,19 @@ class Relationship( cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: - argument = _extract_mapped_subtype( - annotation, - cls, - key, - Relationship, - self.argument is None, - is_dataclass_field, - ) - if argument is None: - return + argument = extracted_mapped_annotation + + if extracted_mapped_annotation is None: + + if self.argument is None: + self._raise_for_required(key, cls) + else: + return + + argument = extracted_mapped_annotation if hasattr(argument, "__origin__"): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index c50cc5bac..520c95672 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1927,7 +1927,7 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any: def _is_mapped_annotation( - raw_annotation: Union[type, str], cls: Type[Any] + raw_annotation: _AnnotationScanType, cls: Type[Any] ) -> bool: annotated = de_stringify_annotation(cls, raw_annotation) return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") @@ -1969,9 +1969,14 @@ def _extract_mapped_subtype( attr_cls: Type[Any], required: bool, is_dataclass_field: bool, - superclasses: Optional[Tuple[Type[Any], ...]] = None, + expect_mapped: bool = True, ) -> Optional[Union[type, str]]: + """given an annotation, figure out if it's ``Mapped[something]`` and if + so, return the ``something`` part. + Includes error raise scenarios and other options. + + """ if raw_annotation is None: if required: @@ -1989,25 +1994,29 @@ def _extract_mapped_subtype( if is_dataclass_field: return annotated else: - # TODO: there don't seem to be tests for the failure - # conditions here - if not hasattr(annotated, "__origin__") or ( - not issubclass( - annotated.__origin__, # type: ignore - superclasses if superclasses else attr_cls, - ) - and not issubclass(attr_cls, annotated.__origin__) # type: ignore + if not hasattr(annotated, "__origin__") or not is_origin_of( + annotated, "Mapped", module="sqlalchemy.orm" ): - our_annotated_str = ( - annotated.__name__ + anno_name = ( + getattr(annotated, "__name__", None) if not isinstance(annotated, str) - else repr(annotated) - ) - raise sa_exc.ArgumentError( - f'Type annotation for "{cls.__name__}.{key}" should use the ' - f'syntax "Mapped[{our_annotated_str}]" or ' - f'"{attr_cls.__name__}[{our_annotated_str}]".' + else None ) + if anno_name is None: + our_annotated_str = repr(annotated) + else: + our_annotated_str = anno_name + + if expect_mapped: + raise sa_exc.ArgumentError( + f'Type annotation for "{cls.__name__}.{key}" ' + "should use the " + f'syntax "Mapped[{our_annotated_str}]" or ' + f'"{attr_cls.__name__}[{our_annotated_str}]".' + ) + + else: + return annotated if len(annotated.__args__) != 1: # type: ignore raise sa_exc.ArgumentError( diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 53f76f3ce..d4e4d2dca 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -25,6 +25,7 @@ from .. import event from .. import util from ..orm import declarative_base from ..orm import DeclarativeBase +from ..orm import MappedAsDataclass from ..orm import registry from ..schema import sort_tables_and_constraints @@ -90,7 +91,14 @@ class TestBase: @config.fixture() def registry(self, metadata): - reg = registry(metadata=metadata) + reg = registry( + metadata=metadata, + type_annotation_map={ + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb" + ) + }, + ) yield reg reg.dispose() @@ -109,6 +117,21 @@ class TestBase: yield Base Base.registry.dispose() + @config.fixture + def dc_decl_base(self, metadata): + _md = metadata + + class Base(MappedAsDataclass, DeclarativeBase): + metadata = _md + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb" + ) + } + + yield Base + Base.registry.dispose() + @config.fixture() def future_connection(self, future_engine, connection): # integrate the future_engine and connection fixtures so diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index adbbf143f..4ce1e7ff3 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -230,7 +230,11 @@ def inspect_formatargspec( def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: """Return a sequence of all dataclasses.Field objects associated - with a class.""" + with a class as an already processed dataclass. + + The class must **already be a dataclass** for Field objects to be returned. + + """ if dataclasses.is_dataclass(cls): return dataclasses.fields(cls) @@ -240,7 +244,12 @@ def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: """Return a sequence of all dataclasses.Field objects associated with - a class, excluding those that originate from a superclass.""" + an already processed dataclass, excluding those that originate from a + superclass. + + The class must **already be a dataclass** for Field objects to be returned. + + """ if dataclasses.is_dataclass(cls): super_fields: Set[dataclasses.Field[Any]] = set() diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 44e26f609..454de100b 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -23,6 +23,14 @@ from typing_extensions import NotRequired as NotRequired # noqa: F401 from . import compat + +# more zimports issues +if True: + from typing_extensions import ( # noqa: F401 + dataclass_transform as dataclass_transform, + ) + + _T = TypeVar("_T", bound=Any) _KT = TypeVar("_KT") _KT_co = TypeVar("_KT_co", covariant=True) diff --git a/pyproject.toml b/pyproject.toml index 29d59ea69..812d60e91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ markers = [ [tool.pyright] - reportPrivateUsage = "none" reportUnusedClass = "none" reportUnusedFunction = "none" @@ -37,7 +37,7 @@ package_dir = install_requires = importlib-metadata;python_version<"3.8" greenlet != 0.4.17;(platform_machine=='aarch64' or (platform_machine=='ppc64le' or (platform_machine=='x86_64' or (platform_machine=='amd64' or (platform_machine=='AMD64' or (platform_machine=='win32' or platform_machine=='WIN32')))))) - typing-extensions >= 4 + typing-extensions >= 4.1.0 [options.extras_require] asyncio = diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py new file mode 100644 index 000000000..aac873723 --- /dev/null +++ b/test/orm/declarative/test_dc_transforms.py @@ -0,0 +1,816 @@ +import dataclasses +import inspect as pyinspect +from typing import Any +from typing import List +from typing import Optional +from typing import Set +from typing import Type +from unittest import mock + +from sqlalchemy import Column +from sqlalchemy import exc +from sqlalchemy import ForeignKey +from sqlalchemy import inspect +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import testing +from sqlalchemy.orm import column_property +from sqlalchemy.orm import composite +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import deferred +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import MappedAsDataclass +from sqlalchemy.orm import MappedColumn +from sqlalchemy.orm import registry as _RegistryType +from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session +from sqlalchemy.orm import synonym +from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import eq_regex +from sqlalchemy.testing import expect_raises +from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_true +from sqlalchemy.testing import ne_ + + +class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): + def test_basic_constructor_repr_base_cls( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + class B(dc_decl_base): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column(default=None) + + A.__qualname__ = "some_module.A" + B.__qualname__ = "some_module.B" + + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x", "bs"], + varargs=None, + varkw=None, + defaults=(None, mock.ANY), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + eq_( + pyinspect.getfullargspec(B.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x"], + varargs=None, + varkw=None, + defaults=(None,), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)]) + eq_( + repr(a2), + "some_module.A(id=None, data='10', x=5, " + "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), " + "some_module.B(id=None, data='data2', a_id=None, x=12)])", + ) + + a3 = A("data") + eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + + def test_basic_constructor_repr_cls_decorator( + self, registry: _RegistryType + ): + @registry.mapped_as_dataclass() + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + @registry.mapped_as_dataclass() + class B: + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column(default=None) + + A.__qualname__ = "some_module.A" + B.__qualname__ = "some_module.B" + + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x", "bs"], + varargs=None, + varkw=None, + defaults=(None, mock.ANY), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + eq_( + pyinspect.getfullargspec(B.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x"], + varargs=None, + varkw=None, + defaults=(None,), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)]) + eq_( + repr(a2), + "some_module.A(id=None, data='10', x=5, " + "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), " + "some_module.B(id=None, data='data2', a_id=None, x=12)])", + ) + + a3 = A("data") + eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + + def test_default_fn(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column(default="d1") + data2: Mapped[str] = mapped_column(default_factory=lambda: "d2") + + a1 = A() + eq_(a1.data, "d1") + eq_(a1.data2, "d2") + + def test_default_factory_vs_collection_class( + self, dc_decl_base: Type[MappedAsDataclass] + ): + # this is currently the error raised by dataclasses. We can instead + # do this validation ourselves, but overall I don't know that we + # can hit every validation and rule that's in dataclasses + with expect_raises_message( + ValueError, "cannot specify both default and default_factory" + ): + + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column( + default="d1", default_factory=lambda: "d2" + ) + + def test_inheritance(self, dc_decl_base: Type[MappedAsDataclass]): + class Person(dc_decl_base): + __tablename__ = "person" + person_id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + name: Mapped[str] + type: Mapped[str] = mapped_column(init=False) + + __mapper_args__ = {"polymorphic_on": type} + + class Engineer(Person): + __tablename__ = "engineer" + + person_id: Mapped[int] = mapped_column( + ForeignKey("person.person_id"), primary_key=True, init=False + ) + + status: Mapped[str] = mapped_column(String(30)) + engineer_name: Mapped[str] + primary_language: Mapped[str] + + e1 = Engineer("nm", "st", "en", "pl") + eq_(e1.name, "nm") + eq_(e1.status, "st") + eq_(e1.engineer_name, "en") + eq_(e1.primary_language, "pl") + + def test_integrated_dc(self, dc_decl_base: Type[MappedAsDataclass]): + """We will be telling users "this is a dataclass that is also + mapped". Therefore, they will want *any* kind of attribute to do what + it would normally do in a dataclass, including normal types without any + field and explicit use of dataclasses.field(). additionally, we'd like + ``Mapped`` to mean "persist this attribute". So the absence of + ``Mapped`` should also mean something too. + + """ + + class A(dc_decl_base): + __tablename__ = "a" + + ctrl_one: str + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + some_field: int = dataclasses.field(default=5) + + some_none_field: Optional[str] = None + + a1 = A("ctrlone", "datafield") + eq_(a1.some_field, 5) + eq_(a1.some_none_field, None) + + # only Mapped[] is mapped + self.assert_compile(select(A), "SELECT a.id, a.data FROM a") + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=[ + "self", + "ctrl_one", + "data", + "some_field", + "some_none_field", + ], + varargs=None, + varkw=None, + defaults=(5, None), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + def test_dc_on_top_of_non_dc(self, decl_base: Type[DeclarativeBase]): + class Person(decl_base): + __tablename__ = "person" + person_id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + type: Mapped[str] = mapped_column() + + __mapper_args__ = {"polymorphic_on": type} + + class Engineer(MappedAsDataclass, Person): + __tablename__ = "engineer" + + person_id: Mapped[int] = mapped_column( + ForeignKey("person.person_id"), primary_key=True, init=False + ) + + status: Mapped[str] = mapped_column(String(30)) + engineer_name: Mapped[str] + primary_language: Mapped[str] + + e1 = Engineer("st", "en", "pl") + eq_(e1.status, "st") + eq_(e1.engineer_name, "en") + eq_(e1.primary_language, "pl") + + eq_( + pyinspect.getfullargspec(Person.__init__), + # the boring **kw __init__ + pyinspect.FullArgSpec( + args=["self"], + varargs=None, + varkw="kwargs", + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + eq_( + pyinspect.getfullargspec(Engineer.__init__), + # the exciting dataclasses __init__ + pyinspect.FullArgSpec( + args=["self", "status", "engineer_name", "primary_language"], + varargs=None, + varkw=None, + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + +class RelationshipDefaultFactoryTest(fixtures.TestBase): + def test_list(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=lambda: [B(data="hi")] + ) + + class B(dc_decl_base): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + + a1 = A() + eq_(a1.bs[0].data, "hi") + + def test_set(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + bs: Mapped[Set["B"]] = relationship( # noqa: F821 + default_factory=lambda: {B(data="hi")} + ) + + class B(dc_decl_base, unsafe_hash=True): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + + a1 = A() + eq_(a1.bs.pop().data, "hi") + + def test_oh_no_mismatch(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + bs: Mapped[Set["B"]] = relationship( # noqa: F821 + default_factory=lambda: [B(data="hi")] + ) + + class B(dc_decl_base, unsafe_hash=True): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + + # old school collection mismatch error FTW + with expect_raises_message( + TypeError, "Incompatible collection type: list is not set-like" + ): + A() + + def test_replace_operation_works_w_history_etc( + self, registry: _RegistryType + ): + @registry.mapped_as_dataclass + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + @registry.mapped_as_dataclass + class B: + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column(default=None) + + registry.metadata.create_all(testing.db) + + with Session(testing.db) as sess: + a1 = A("data", 10, [B("b1"), B("b2", x=5), B("b3")]) + sess.add(a1) + sess.commit() + + a2 = dataclasses.replace(a1, x=12, bs=[B("b4")]) + + assert a1 in sess + assert not sess.is_modified(a1, include_collections=True) + assert a2 not in sess + eq_(inspect(a2).attrs.x.history, ([12], (), ())) + sess.add(a2) + sess.commit() + + eq_(sess.scalars(select(A.x).order_by(A.id)).all(), [10, 12]) + eq_( + sess.scalars(select(B.data).order_by(B.id)).all(), + ["b1", "b2", "b3", "b4"], + ) + + def test_post_init(self, registry: _RegistryType): + @registry.mapped_as_dataclass + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column(init=False) + + def __post_init__(self): + self.data = "some data" + + a1 = A() + eq_(a1.data, "some data") + + def test_no_field_args_w_new_style(self, registry: _RegistryType): + with expect_raises_message( + exc.InvalidRequestError, + "SQLAlchemy mapped dataclasses can't consume mapping information", + ): + + @registry.mapped_as_dataclass() + class A: + __tablename__ = "a" + __sa_dataclass_metadata_key__ = "sa" + + account_id: int = dataclasses.field( + init=False, + metadata={"sa": Column(Integer, primary_key=True)}, + ) + + def test_no_field_args_w_new_style_two(self, registry: _RegistryType): + @dataclasses.dataclass + class Base: + pass + + with expect_raises_message( + exc.InvalidRequestError, + "SQLAlchemy mapped dataclasses can't consume mapping information", + ): + + @registry.mapped_as_dataclass() + class A(Base): + __tablename__ = "a" + __sa_dataclass_metadata_key__ = "sa" + + account_id: int = dataclasses.field( + init=False, + metadata={"sa": Column(Integer, primary_key=True)}, + ) + + +class DataclassArgsTest(fixtures.TestBase): + dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash") + + @testing.fixture(params=dc_arg_names) + def dc_argument_fixture(self, request: Any, registry: _RegistryType): + name = request.param + + args = {n: n == name for n in self.dc_arg_names} + if args["order"]: + args["eq"] = True + yield args + + @testing.fixture( + params=["mapped_column", "synonym", "deferred", "column_property"] + ) + def mapped_expr_constructor(self, request): + name = request.param + + if name == "mapped_column": + yield mapped_column(default=7, init=True) + elif name == "synonym": + yield synonym("some_int", default=7, init=True) + elif name == "deferred": + yield deferred(Column(Integer), default=7, init=True) + elif name == "column_property": + yield column_property(Column(Integer), default=7, init=True) + + def test_attrs_rejected_if_not_a_dc( + self, mapped_expr_constructor, decl_base: Type[DeclarativeBase] + ): + if isinstance(mapped_expr_constructor, MappedColumn): + unwanted_args = "'init'" + else: + unwanted_args = "'default', 'init'" + with expect_raises_message( + exc.ArgumentError, + r"Attribute 'x' on class .*A.* includes dataclasses " + r"argument\(s\): " + rf"{unwanted_args} but class does not specify SQLAlchemy native " + "dataclass configuration", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + + x: Mapped[int] = mapped_expr_constructor + + def _assert_cls(self, cls, dc_arguments): + + if dc_arguments["init"]: + + def create(data, x): + return cls(data, x) + + else: + + def create(data, x): + a1 = cls() + a1.data = data + a1.x = x + return a1 + + for n in self.dc_arg_names: + if dc_arguments[n]: + getattr(self, f"_assert_{n}")(cls, create, dc_arguments) + else: + getattr(self, f"_assert_not_{n}")(cls, create, dc_arguments) + + if dc_arguments["init"]: + a1 = cls("some data") + eq_(a1.x, 7) + + a1 = create("some data", 15) + some_int = a1.some_int + eq_( + dataclasses.asdict(a1), + {"data": "some data", "id": None, "some_int": some_int, "x": 15}, + ) + eq_(dataclasses.astuple(a1), (None, "some data", some_int, 15)) + + def _assert_unsafe_hash(self, cls, create, dc_arguments): + a1 = create("d1", 5) + hash(a1) + + def _assert_not_unsafe_hash(self, cls, create, dc_arguments): + a1 = create("d1", 5) + + if dc_arguments["eq"]: + with expect_raises(TypeError): + hash(a1) + else: + hash(a1) + + def _assert_eq(self, cls, create, dc_arguments): + a1 = create("d1", 5) + a2 = create("d2", 10) + a3 = create("d1", 5) + + eq_(a1, a3) + ne_(a1, a2) + + def _assert_not_eq(self, cls, create, dc_arguments): + a1 = create("d1", 5) + a2 = create("d2", 10) + a3 = create("d1", 5) + + eq_(a1, a1) + ne_(a1, a3) + ne_(a1, a2) + + def _assert_order(self, cls, create, dc_arguments): + is_false(create("g", 10) < create("b", 7)) + + is_true(create("g", 10) > create("b", 7)) + + is_false(create("g", 10) <= create("b", 7)) + + is_true(create("g", 10) >= create("b", 7)) + + eq_( + list(sorted([create("g", 10), create("g", 5), create("b", 7)])), + [ + create("b", 7), + create("g", 5), + create("g", 10), + ], + ) + + def _assert_not_order(self, cls, create, dc_arguments): + with expect_raises(TypeError): + create("g", 10) < create("b", 7) + + with expect_raises(TypeError): + create("g", 10) > create("b", 7) + + with expect_raises(TypeError): + create("g", 10) <= create("b", 7) + + with expect_raises(TypeError): + create("g", 10) >= create("b", 7) + + def _assert_repr(self, cls, create, dc_arguments): + a1 = create("some data", 12) + eq_regex(repr(a1), r".*A\(id=None, data='some data', x=12\)") + + def _assert_not_repr(self, cls, create, dc_arguments): + a1 = create("some data", 12) + eq_regex(repr(a1), r"<.*A object at 0x.*>") + + def _assert_init(self, cls, create, dc_arguments): + a1 = cls("some data", 5) + + eq_(a1.data, "some data") + eq_(a1.x, 5) + + a2 = cls(data="some data", x=5) + eq_(a2.data, "some data") + eq_(a2.x, 5) + + a3 = cls(data="some data") + eq_(a3.data, "some data") + eq_(a3.x, 7) + + def _assert_not_init(self, cls, create, dc_arguments): + + with expect_raises(TypeError): + cls("Some data", 5) + + # we run real "dataclasses" on the class. so with init=False, it + # doesn't touch what was there, and the SQLA default constructor + # gets put on. + a1 = cls(data="some data") + eq_(a1.data, "some data") + eq_(a1.x, None) + + a1 = cls() + eq_(a1.data, None) + + # no constructor, it sets None for x...ok + eq_(a1.x, None) + + def test_dc_arguments_decorator( + self, + dc_argument_fixture, + mapped_expr_constructor, + registry: _RegistryType, + ): + @registry.mapped_as_dataclass(**dc_argument_fixture) + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_expr_constructor + + self._assert_cls(A, dc_argument_fixture) + + def test_dc_arguments_base( + self, + dc_argument_fixture, + mapped_expr_constructor, + registry: _RegistryType, + ): + reg = registry + + class Base(MappedAsDataclass, DeclarativeBase, **dc_argument_fixture): + registry = reg + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_expr_constructor + + self.A = A + + def test_dc_arguments_perclass( + self, + dc_argument_fixture, + mapped_expr_constructor, + decl_base: Type[DeclarativeBase], + ): + class A(MappedAsDataclass, decl_base, **dc_argument_fixture): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_expr_constructor + + self.A = A + + +class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + def test_composite_setup(self, dc_decl_base: Type[MappedAsDataclass]): + @dataclasses.dataclass + class Point: + x: int + y: int + + class Edge(dc_decl_base): + __tablename__ = "edge" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + graph_id: Mapped[int] = mapped_column( + ForeignKey("graph.id"), init=False + ) + + start: Mapped[Point] = composite( + Point, mapped_column("x1"), mapped_column("y1"), default=None + ) + + end: Mapped[Point] = composite( + Point, mapped_column("x2"), mapped_column("y2"), default=None + ) + + class Graph(dc_decl_base): + __tablename__ = "graph" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + edges: Mapped[List[Edge]] = relationship() + + Point.__qualname__ = "mymodel.Point" + Edge.__qualname__ = "mymodel.Edge" + Graph.__qualname__ = "mymodel.Graph" + g = Graph( + edges=[ + Edge(start=Point(1, 2), end=Point(3, 4)), + Edge(start=Point(7, 8), end=Point(5, 6)), + ] + ) + eq_( + repr(g), + "mymodel.Graph(id=None, edges=[mymodel.Edge(id=None, " + "graph_id=None, start=mymodel.Point(x=1, y=2), " + "end=mymodel.Point(x=3, y=4)), " + "mymodel.Edge(id=None, graph_id=None, " + "start=mymodel.Point(x=7, y=8), end=mymodel.Point(x=5, y=6))])", + ) + + def test_named_setup(self, dc_decl_base: Type[MappedAsDataclass]): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(dc_decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column( + primary_key=True, init=False, repr=False + ) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite( + Address, + mapped_column(), + mapped_column(), + mapped_column("zip"), + default=None, + ) + + Address.__qualname__ = "mymodule.Address" + User.__qualname__ = "mymodule.User" + u = User( + name="user 1", + address=Address("123 anywhere street", "NY", "12345"), + ) + u2 = User("u2") + eq_( + repr(u), + "mymodule.User(name='user 1', " + "address=mymodule.Address(street='123 anywhere street', " + "state='NY', zip_='12345'))", + ) + eq_(repr(u2), "mymodule.User(name='u2', address=None)") diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index d7d19821c..865735439 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -190,6 +190,18 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(User.__table__.c.data.nullable) assert isinstance(User.__table__.c.created_at.type, DateTime) + def test_column_default(self, decl_base): + class MyClass(decl_base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(default="some default") + + mc = MyClass() + assert "data" not in mc.__dict__ + + eq_(MyClass.__table__.c.data.default.arg, "some default") + def test_anno_w_fixed_table(self, decl_base): users = Table( "users", @@ -959,7 +971,7 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): with expect_raises_message( ArgumentError, r"Type annotation for \"User.address\" should use the syntax " - r"\"Mapped\['Address'\]\" or \"MappedColumn\['Address'\]\"", + r"\"Mapped\['Address'\]\"", ): class User(decl_base): @@ -1068,6 +1080,38 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): # round trip! eq_(u1.address, Address("123 anywhere street", "NY", "12345")) + def test_cls_annotated_no_mapped_cols_setup(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite() + + decl_base.metadata.create_all(testing.db) + + with fixture_session() as sess: + sess.add( + User( + name="user 1", + address=Address("123 anywhere street", "NY", "12345"), + ) + ) + sess.commit() + + with fixture_session() as sess: + u1 = sess.scalar(select(User)) + + # round trip! + eq_(u1.address, Address("123 anywhere street", "NY", "12345")) + def test_one_col_setup(self, decl_base): @dataclasses.dataclass class Address: |