diff options
Diffstat (limited to 'lib/sqlalchemy/orm/decl_base.py')
-rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 315 |
1 files changed, 245 insertions, 70 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index a66421e22..54a272f86 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -10,6 +10,8 @@ from __future__ import annotations import collections +import dataclasses +import re from typing import Any from typing import Callable from typing import cast @@ -40,6 +42,7 @@ from .base import _is_mapped_class from .base import InspectionAttr from .descriptor_props import Composite from .descriptor_props import Synonym +from .interfaces import _AttributeOptions from .interfaces import _IntrospectsAnnotations from .interfaces import _MappedAttribute from .interfaces import _MapsColumns @@ -48,15 +51,18 @@ from .mapper import Mapper as mapper from .mapper import Mapper from .properties import ColumnProperty from .properties import MappedColumn +from .util import _extract_mapped_subtype from .util import _is_mapped_annotation from .util import class_mapper from .. import event from .. import exc from .. import util from ..sql import expression +from ..sql.base import _NoArg from ..sql.schema import Column from ..sql.schema import Table from ..util import topological +from ..util.typing import _AnnotationScanType from ..util.typing import Protocol if TYPE_CHECKING: @@ -392,11 +398,13 @@ class _ClassScanMapperConfig(_MapperConfig): "mapper_args", "mapper_args_fn", "inherits", + "allow_dataclass_fields", + "dataclass_setup_arguments", ) registry: _RegistryType clsdict_view: _ClassDict - collected_annotations: Dict[str, Tuple[Any, bool]] + collected_annotations: Dict[str, Tuple[Any, Any, bool]] collected_attributes: Dict[str, Any] local_table: Optional[FromClause] persist_selectable: Optional[FromClause] @@ -411,6 +419,17 @@ class _ClassScanMapperConfig(_MapperConfig): mapper_args_fn: Optional[Callable[[], Dict[str, Any]]] inherits: Optional[Type[Any]] + dataclass_setup_arguments: Optional[Dict[str, Any]] + """if the class has SQLAlchemy native dataclass parameters, where + we will create a SQLAlchemy dataclass (not a real dataclass). + + """ + + allow_dataclass_fields: bool + """if true, look for dataclass-processed Field objects on the target + class as well as superclasses and extract ORM mapping directives from + the "metadata" attribute of each Field""" + def __init__( self, registry: _RegistryType, @@ -434,10 +453,37 @@ class _ClassScanMapperConfig(_MapperConfig): self.declared_columns = util.OrderedSet() self.column_copies = {} + self.dataclass_setup_arguments = dca = getattr( + self.cls, "_sa_apply_dc_transforms", None + ) + + cld = dataclasses.is_dataclass(cls_) + + sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__") + + # we don't want to consume Field objects from a not-already-dataclass. + # the Field objects won't have their "name" or "type" populated, + # and while it seems like we could just set these on Field as we + # read them, Field is documented as "user read only" and we need to + # stay far away from any off-label use of dataclasses APIs. + if (not cld or dca) and sdk: + raise exc.InvalidRequestError( + "SQLAlchemy mapped dataclasses can't consume mapping " + "information from dataclass.Field() objects if the immediate " + "class is not already a dataclass." + ) + + # if already a dataclass, and __sa_dataclass_metadata_key__ present, + # then also look inside of dataclass.Field() objects yielded by + # dataclasses.get_fields(cls) when scanning for attributes + self.allow_dataclass_fields = bool(sdk and cld) + self._setup_declared_events() self._scan_attributes() + self._setup_dataclasses_transforms() + with mapperlib._CONFIGURE_MUTEX: clsregistry.add_class( self.classname, self.cls, registry._class_registry @@ -477,11 +523,15 @@ class _ClassScanMapperConfig(_MapperConfig): attribute, taking SQLAlchemy-enabled dataclass fields into account. """ - sa_dataclass_metadata_key = _get_immediate_cls_attr( - cls, "__sa_dataclass_metadata_key__" - ) - if sa_dataclass_metadata_key is None: + if self.allow_dataclass_fields: + sa_dataclass_metadata_key = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" + ) + else: + sa_dataclass_metadata_key = None + + if not sa_dataclass_metadata_key: def attribute_is_overridden(key: str, obj: Any) -> bool: return getattr(cls, key) is not obj @@ -551,6 +601,7 @@ class _ClassScanMapperConfig(_MapperConfig): "__dict__", "__weakref__", "_sa_class_manager", + "_sa_apply_dc_transforms", "__dict__", "__weakref__", ] @@ -563,10 +614,6 @@ class _ClassScanMapperConfig(_MapperConfig): adjusting for SQLAlchemy fields embedded in dataclass fields. """ - sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( - cls, "__sa_dataclass_metadata_key__" - ) - cls_annotations = util.get_annotations(cls) cls_vars = vars(cls) @@ -576,7 +623,15 @@ class _ClassScanMapperConfig(_MapperConfig): names = util.merge_lists_w_ordering( [n for n in cls_vars if n not in skip], list(cls_annotations) ) - if sa_dataclass_metadata_key is None: + + if self.allow_dataclass_fields: + sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" + ) + else: + sa_dataclass_metadata_key = None + + if not sa_dataclass_metadata_key: def local_attributes_for_class() -> Iterable[ Tuple[str, Any, Any, bool] @@ -652,45 +707,51 @@ class _ClassScanMapperConfig(_MapperConfig): name, obj, annotation, - is_dataclass, + is_dataclass_field, ) in local_attributes_for_class(): - if name == "__mapper_args__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - if not mapper_args_fn and (not class_mapped or check_decl): - # don't even invoke __mapper_args__ until - # after we've determined everything about the - # mapped table. - # make a copy of it so a class-level dictionary - # is not overwritten when we update column-based - # arguments. - def _mapper_args_fn() -> Dict[str, Any]: - return dict(cls_as_Decl.__mapper_args__) - - mapper_args_fn = _mapper_args_fn - - elif name == "__tablename__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - if not tablename and (not class_mapped or check_decl): - tablename = cls_as_Decl.__tablename__ - elif name == "__table_args__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - if not table_args and (not class_mapped or check_decl): - table_args = cls_as_Decl.__table_args__ - if not isinstance( - table_args, (tuple, dict, type(None)) + if re.match(r"^__.+__$", name): + if name == "__mapper_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not mapper_args_fn and ( + not class_mapped or check_decl ): - raise exc.ArgumentError( - "__table_args__ value must be a tuple, " - "dict, or None" - ) - if base is not cls: - inherited_table_args = True + # don't even invoke __mapper_args__ until + # after we've determined everything about the + # mapped table. + # make a copy of it so a class-level dictionary + # is not overwritten when we update column-based + # arguments. + def _mapper_args_fn() -> Dict[str, Any]: + return dict(cls_as_Decl.__mapper_args__) + + mapper_args_fn = _mapper_args_fn + + elif name == "__tablename__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not tablename and (not class_mapped or check_decl): + tablename = cls_as_Decl.__tablename__ + elif name == "__table_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not table_args and (not class_mapped or check_decl): + table_args = cls_as_Decl.__table_args__ + if not isinstance( + table_args, (tuple, dict, type(None)) + ): + raise exc.ArgumentError( + "__table_args__ value must be a tuple, " + "dict, or None" + ) + if base is not cls: + inherited_table_args = True + else: + # skip all other dunder names + continue elif class_mapped: if _is_declarative_props(obj): util.warn( @@ -706,9 +767,8 @@ class _ClassScanMapperConfig(_MapperConfig): # acting like that for now. if isinstance(obj, (Column, MappedColumn)): - self.collected_annotations[name] = ( - annotation, - False, + self._collect_annotation( + name, annotation, is_dataclass_field, True, obj ) # already copied columns to the mapped class. continue @@ -745,7 +805,7 @@ class _ClassScanMapperConfig(_MapperConfig): ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) else: - if is_dataclass: + if is_dataclass_field: # access attribute using normal class access # first, to see if it's been mapped on a # superclass. note if the dataclasses.field() @@ -789,14 +849,16 @@ class _ClassScanMapperConfig(_MapperConfig): ): ret.doc = obj.__doc__ - self.collected_annotations[name] = ( + self._collect_annotation( + name, obj._collect_return_annotation(), False, + True, + obj, ) elif _is_mapped_annotation(annotation, cls): - self.collected_annotations[name] = ( - annotation, - is_dataclass, + self._collect_annotation( + name, annotation, is_dataclass_field, True, obj ) if obj is None: if not fixed_table: @@ -809,7 +871,7 @@ class _ClassScanMapperConfig(_MapperConfig): # declarative mapping. however, check for some # more common mistakes self._warn_for_decl_attributes(base, name, obj) - elif is_dataclass and ( + elif is_dataclass_field and ( name not in clsdict_view or clsdict_view[name] is not obj ): # here, we are definitely looking at the target class @@ -826,14 +888,12 @@ class _ClassScanMapperConfig(_MapperConfig): obj = obj.fget() collected_attributes[name] = obj - self.collected_annotations[name] = ( - annotation, - True, + self._collect_annotation( + name, annotation, True, False, obj ) else: - self.collected_annotations[name] = ( - annotation, - False, + self._collect_annotation( + name, annotation, False, None, obj ) if ( obj is None @@ -843,6 +903,10 @@ class _ClassScanMapperConfig(_MapperConfig): collected_attributes[name] = MappedColumn() elif name in clsdict_view: collected_attributes[name] = obj + # else if the name is not in the cls.__dict__, + # don't collect it as an attribute. + # we will see the annotation only, which is meaningful + # both for mapping and dataclasses setup if inherited_table_args and not tablename: table_args = None @@ -851,6 +915,77 @@ class _ClassScanMapperConfig(_MapperConfig): self.tablename = tablename self.mapper_args_fn = mapper_args_fn + def _setup_dataclasses_transforms(self) -> None: + + dataclass_setup_arguments = self.dataclass_setup_arguments + if not dataclass_setup_arguments: + return + + manager = instrumentation.manager_of_class(self.cls) + assert manager is not None + + field_list = [ + _AttributeOptions._get_arguments_for_make_dataclass( + key, + anno, + self.collected_attributes.get(key, _NoArg.NO_ARG), + ) + for key, anno in ( + (key, mapped_anno if mapped_anno else raw_anno) + for key, ( + raw_anno, + mapped_anno, + is_dc, + ) in self.collected_annotations.items() + ) + ] + + annotations = {} + defaults = {} + for item in field_list: + if len(item) == 2: + name, tp = item # type: ignore + elif len(item) == 3: + name, tp, spec = item # type: ignore + defaults[name] = spec + else: + assert False + annotations[name] = tp + + for k, v in defaults.items(): + setattr(self.cls, k, v) + self.cls.__annotations__ = annotations + + dataclasses.dataclass(self.cls, **dataclass_setup_arguments) + + def _collect_annotation( + self, + name: str, + raw_annotation: _AnnotationScanType, + is_dataclass: bool, + expect_mapped: Optional[bool], + attr_value: Any, + ) -> None: + + if expect_mapped is None: + expect_mapped = isinstance(attr_value, _MappedAttribute) + + extracted_mapped_annotation = _extract_mapped_subtype( + raw_annotation, + self.cls, + name, + type(attr_value), + required=False, + is_dataclass_field=False, + expect_mapped=expect_mapped and not self.allow_dataclass_fields, + ) + + self.collected_annotations[name] = ( + raw_annotation, + extracted_mapped_annotation, + is_dataclass, + ) + def _warn_for_decl_attributes( self, cls: Type[Any], key: str, c: Any ) -> None: @@ -982,13 +1117,53 @@ class _ClassScanMapperConfig(_MapperConfig): _undefer_column_name( k, self.column_copies.get(value, value) # type: ignore ) - elif isinstance(value, _IntrospectsAnnotations): - annotation, is_dataclass = self.collected_annotations.get( - k, (None, False) - ) - value.declarative_scan( - self.registry, cls, k, annotation, is_dataclass - ) + else: + if isinstance(value, _IntrospectsAnnotations): + ( + annotation, + extracted_mapped_annotation, + is_dataclass, + ) = self.collected_annotations.get(k, (None, None, False)) + value.declarative_scan( + self.registry, + cls, + k, + annotation, + extracted_mapped_annotation, + is_dataclass, + ) + + if ( + isinstance(value, (MapperProperty, _MapsColumns)) + and value._has_dataclass_arguments + and not self.dataclass_setup_arguments + ): + if isinstance(value, MapperProperty): + argnames = [ + "init", + "default_factory", + "repr", + "default", + ] + else: + argnames = ["init", "default_factory", "repr"] + + args = { + a + for a in argnames + if getattr( + value._attribute_options, f"dataclasses_{a}" + ) + is not _NoArg.NO_ARG + } + raise exc.ArgumentError( + f"Attribute '{k}' on class {cls} includes dataclasses " + f"argument(s): " + f"{', '.join(sorted(repr(a) for a in args))} but " + f"class does not specify " + "SQLAlchemy native dataclass configuration." + ) + our_stuff[k] = value def _extract_declared_columns(self) -> None: @@ -997,6 +1172,7 @@ class _ClassScanMapperConfig(_MapperConfig): # extract columns from the class dict declared_columns = self.declared_columns name_to_prop_key = collections.defaultdict(set) + for key, c in list(our_stuff.items()): if isinstance(c, _MapsColumns): @@ -1019,7 +1195,6 @@ class _ClassScanMapperConfig(_MapperConfig): # otherwise, Mapper will map it under the column key. if mp_to_assign is None and key != col.key: our_stuff[key] = col - elif isinstance(c, Column): # undefer previously occurred here, and now occurs earlier. # ensure every column we get here has been named |