summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/decl_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/decl_base.py')
-rw-r--r--  lib/sqlalchemy/orm/decl_base.py | 315
1 files changed, 245 insertions, 70 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index a66421e22..54a272f86 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -10,6 +10,8 @@
from __future__ import annotations
import collections
+import dataclasses
+import re
from typing import Any
from typing import Callable
from typing import cast
@@ -40,6 +42,7 @@ from .base import _is_mapped_class
from .base import InspectionAttr
from .descriptor_props import Composite
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MappedAttribute
from .interfaces import _MapsColumns
@@ -48,15 +51,18 @@ from .mapper import Mapper as mapper
from .mapper import Mapper
from .properties import ColumnProperty
from .properties import MappedColumn
+from .util import _extract_mapped_subtype
from .util import _is_mapped_annotation
from .util import class_mapper
from .. import event
from .. import exc
from .. import util
from ..sql import expression
+from ..sql.base import _NoArg
from ..sql.schema import Column
from ..sql.schema import Table
from ..util import topological
+from ..util.typing import _AnnotationScanType
from ..util.typing import Protocol
if TYPE_CHECKING:
@@ -392,11 +398,13 @@ class _ClassScanMapperConfig(_MapperConfig):
"mapper_args",
"mapper_args_fn",
"inherits",
+ "allow_dataclass_fields",
+ "dataclass_setup_arguments",
)
registry: _RegistryType
clsdict_view: _ClassDict
- collected_annotations: Dict[str, Tuple[Any, bool]]
+ collected_annotations: Dict[str, Tuple[Any, Any, bool]]
collected_attributes: Dict[str, Any]
local_table: Optional[FromClause]
persist_selectable: Optional[FromClause]
@@ -411,6 +419,17 @@ class _ClassScanMapperConfig(_MapperConfig):
mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
inherits: Optional[Type[Any]]
+ dataclass_setup_arguments: Optional[Dict[str, Any]]
+    """Arguments for SQLAlchemy-native dataclass setup on this class;
+    SQLAlchemy applies the dataclass transforms itself (not a real dataclass).
+
+    """
+
+ allow_dataclass_fields: bool
+ """if true, look for dataclass-processed Field objects on the target
+ class as well as superclasses and extract ORM mapping directives from
+ the "metadata" attribute of each Field"""
+
def __init__(
self,
registry: _RegistryType,
@@ -434,10 +453,37 @@ class _ClassScanMapperConfig(_MapperConfig):
self.declared_columns = util.OrderedSet()
self.column_copies = {}
+ self.dataclass_setup_arguments = dca = getattr(
+ self.cls, "_sa_apply_dc_transforms", None
+ )
+
+ cld = dataclasses.is_dataclass(cls_)
+
+ sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__")
+
+ # we don't want to consume Field objects from a not-already-dataclass.
+ # the Field objects won't have their "name" or "type" populated,
+ # and while it seems like we could just set these on Field as we
+ # read them, Field is documented as "user read only" and we need to
+ # stay far away from any off-label use of dataclasses APIs.
+ if (not cld or dca) and sdk:
+ raise exc.InvalidRequestError(
+ "SQLAlchemy mapped dataclasses can't consume mapping "
+ "information from dataclass.Field() objects if the immediate "
+ "class is not already a dataclass."
+ )
+
+    # if already a dataclass, and __sa_dataclass_metadata_key__ present,
+    # then also look inside of dataclasses.Field() objects yielded by
+    # dataclasses.fields(cls) when scanning for attributes
+ self.allow_dataclass_fields = bool(sdk and cld)
+
self._setup_declared_events()
self._scan_attributes()
+ self._setup_dataclasses_transforms()
+
with mapperlib._CONFIGURE_MUTEX:
clsregistry.add_class(
self.classname, self.cls, registry._class_registry
@@ -477,11 +523,15 @@ class _ClassScanMapperConfig(_MapperConfig):
attribute, taking SQLAlchemy-enabled dataclass fields into account.
"""
- sa_dataclass_metadata_key = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__"
- )
- if sa_dataclass_metadata_key is None:
+ if self.allow_dataclass_fields:
+ sa_dataclass_metadata_key = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
+ )
+ else:
+ sa_dataclass_metadata_key = None
+
+ if not sa_dataclass_metadata_key:
def attribute_is_overridden(key: str, obj: Any) -> bool:
return getattr(cls, key) is not obj
@@ -551,6 +601,7 @@ class _ClassScanMapperConfig(_MapperConfig):
"__dict__",
"__weakref__",
"_sa_class_manager",
+ "_sa_apply_dc_transforms",
"__dict__",
"__weakref__",
]
@@ -563,10 +614,6 @@ class _ClassScanMapperConfig(_MapperConfig):
adjusting for SQLAlchemy fields embedded in dataclass fields.
"""
- sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__"
- )
-
cls_annotations = util.get_annotations(cls)
cls_vars = vars(cls)
@@ -576,7 +623,15 @@ class _ClassScanMapperConfig(_MapperConfig):
names = util.merge_lists_w_ordering(
[n for n in cls_vars if n not in skip], list(cls_annotations)
)
- if sa_dataclass_metadata_key is None:
+
+ if self.allow_dataclass_fields:
+ sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
+ )
+ else:
+ sa_dataclass_metadata_key = None
+
+ if not sa_dataclass_metadata_key:
def local_attributes_for_class() -> Iterable[
Tuple[str, Any, Any, bool]
@@ -652,45 +707,51 @@ class _ClassScanMapperConfig(_MapperConfig):
name,
obj,
annotation,
- is_dataclass,
+ is_dataclass_field,
) in local_attributes_for_class():
- if name == "__mapper_args__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not mapper_args_fn and (not class_mapped or check_decl):
- # don't even invoke __mapper_args__ until
- # after we've determined everything about the
- # mapped table.
- # make a copy of it so a class-level dictionary
- # is not overwritten when we update column-based
- # arguments.
- def _mapper_args_fn() -> Dict[str, Any]:
- return dict(cls_as_Decl.__mapper_args__)
-
- mapper_args_fn = _mapper_args_fn
-
- elif name == "__tablename__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not tablename and (not class_mapped or check_decl):
- tablename = cls_as_Decl.__tablename__
- elif name == "__table_args__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not table_args and (not class_mapped or check_decl):
- table_args = cls_as_Decl.__table_args__
- if not isinstance(
- table_args, (tuple, dict, type(None))
+ if re.match(r"^__.+__$", name):
+ if name == "__mapper_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not mapper_args_fn and (
+ not class_mapped or check_decl
):
- raise exc.ArgumentError(
- "__table_args__ value must be a tuple, "
- "dict, or None"
- )
- if base is not cls:
- inherited_table_args = True
+ # don't even invoke __mapper_args__ until
+ # after we've determined everything about the
+ # mapped table.
+ # make a copy of it so a class-level dictionary
+ # is not overwritten when we update column-based
+ # arguments.
+ def _mapper_args_fn() -> Dict[str, Any]:
+ return dict(cls_as_Decl.__mapper_args__)
+
+ mapper_args_fn = _mapper_args_fn
+
+ elif name == "__tablename__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not tablename and (not class_mapped or check_decl):
+ tablename = cls_as_Decl.__tablename__
+ elif name == "__table_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not table_args and (not class_mapped or check_decl):
+ table_args = cls_as_Decl.__table_args__
+ if not isinstance(
+ table_args, (tuple, dict, type(None))
+ ):
+ raise exc.ArgumentError(
+ "__table_args__ value must be a tuple, "
+ "dict, or None"
+ )
+ if base is not cls:
+ inherited_table_args = True
+ else:
+ # skip all other dunder names
+ continue
elif class_mapped:
if _is_declarative_props(obj):
util.warn(
@@ -706,9 +767,8 @@ class _ClassScanMapperConfig(_MapperConfig):
# acting like that for now.
if isinstance(obj, (Column, MappedColumn)):
- self.collected_annotations[name] = (
- annotation,
- False,
+ self._collect_annotation(
+ name, annotation, is_dataclass_field, True, obj
)
# already copied columns to the mapped class.
continue
@@ -745,7 +805,7 @@ class _ClassScanMapperConfig(_MapperConfig):
] = ret = obj.__get__(obj, cls)
setattr(cls, name, ret)
else:
- if is_dataclass:
+ if is_dataclass_field:
# access attribute using normal class access
# first, to see if it's been mapped on a
# superclass. note if the dataclasses.field()
@@ -789,14 +849,16 @@ class _ClassScanMapperConfig(_MapperConfig):
):
ret.doc = obj.__doc__
- self.collected_annotations[name] = (
+ self._collect_annotation(
+ name,
obj._collect_return_annotation(),
False,
+ True,
+ obj,
)
elif _is_mapped_annotation(annotation, cls):
- self.collected_annotations[name] = (
- annotation,
- is_dataclass,
+ self._collect_annotation(
+ name, annotation, is_dataclass_field, True, obj
)
if obj is None:
if not fixed_table:
@@ -809,7 +871,7 @@ class _ClassScanMapperConfig(_MapperConfig):
# declarative mapping. however, check for some
# more common mistakes
self._warn_for_decl_attributes(base, name, obj)
- elif is_dataclass and (
+ elif is_dataclass_field and (
name not in clsdict_view or clsdict_view[name] is not obj
):
# here, we are definitely looking at the target class
@@ -826,14 +888,12 @@ class _ClassScanMapperConfig(_MapperConfig):
obj = obj.fget()
collected_attributes[name] = obj
- self.collected_annotations[name] = (
- annotation,
- True,
+ self._collect_annotation(
+ name, annotation, True, False, obj
)
else:
- self.collected_annotations[name] = (
- annotation,
- False,
+ self._collect_annotation(
+ name, annotation, False, None, obj
)
if (
obj is None
@@ -843,6 +903,10 @@ class _ClassScanMapperConfig(_MapperConfig):
collected_attributes[name] = MappedColumn()
elif name in clsdict_view:
collected_attributes[name] = obj
+ # else if the name is not in the cls.__dict__,
+ # don't collect it as an attribute.
+ # we will see the annotation only, which is meaningful
+ # both for mapping and dataclasses setup
if inherited_table_args and not tablename:
table_args = None
@@ -851,6 +915,77 @@ class _ClassScanMapperConfig(_MapperConfig):
self.tablename = tablename
self.mapper_args_fn = mapper_args_fn
+ def _setup_dataclasses_transforms(self) -> None:
+
+ dataclass_setup_arguments = self.dataclass_setup_arguments
+ if not dataclass_setup_arguments:
+ return
+
+ manager = instrumentation.manager_of_class(self.cls)
+ assert manager is not None
+
+ field_list = [
+ _AttributeOptions._get_arguments_for_make_dataclass(
+ key,
+ anno,
+ self.collected_attributes.get(key, _NoArg.NO_ARG),
+ )
+ for key, anno in (
+ (key, mapped_anno if mapped_anno else raw_anno)
+ for key, (
+ raw_anno,
+ mapped_anno,
+ is_dc,
+ ) in self.collected_annotations.items()
+ )
+ ]
+
+ annotations = {}
+ defaults = {}
+ for item in field_list:
+ if len(item) == 2:
+ name, tp = item # type: ignore
+ elif len(item) == 3:
+ name, tp, spec = item # type: ignore
+ defaults[name] = spec
+ else:
+ assert False
+ annotations[name] = tp
+
+ for k, v in defaults.items():
+ setattr(self.cls, k, v)
+ self.cls.__annotations__ = annotations
+
+ dataclasses.dataclass(self.cls, **dataclass_setup_arguments)
+
+ def _collect_annotation(
+ self,
+ name: str,
+ raw_annotation: _AnnotationScanType,
+ is_dataclass: bool,
+ expect_mapped: Optional[bool],
+ attr_value: Any,
+ ) -> None:
+
+ if expect_mapped is None:
+ expect_mapped = isinstance(attr_value, _MappedAttribute)
+
+ extracted_mapped_annotation = _extract_mapped_subtype(
+ raw_annotation,
+ self.cls,
+ name,
+ type(attr_value),
+ required=False,
+ is_dataclass_field=False,
+ expect_mapped=expect_mapped and not self.allow_dataclass_fields,
+ )
+
+ self.collected_annotations[name] = (
+ raw_annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ )
+
def _warn_for_decl_attributes(
self, cls: Type[Any], key: str, c: Any
) -> None:
@@ -982,13 +1117,53 @@ class _ClassScanMapperConfig(_MapperConfig):
_undefer_column_name(
k, self.column_copies.get(value, value) # type: ignore
)
- elif isinstance(value, _IntrospectsAnnotations):
- annotation, is_dataclass = self.collected_annotations.get(
- k, (None, False)
- )
- value.declarative_scan(
- self.registry, cls, k, annotation, is_dataclass
- )
+ else:
+ if isinstance(value, _IntrospectsAnnotations):
+ (
+ annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ ) = self.collected_annotations.get(k, (None, None, False))
+ value.declarative_scan(
+ self.registry,
+ cls,
+ k,
+ annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ )
+
+ if (
+ isinstance(value, (MapperProperty, _MapsColumns))
+ and value._has_dataclass_arguments
+ and not self.dataclass_setup_arguments
+ ):
+ if isinstance(value, MapperProperty):
+ argnames = [
+ "init",
+ "default_factory",
+ "repr",
+ "default",
+ ]
+ else:
+ argnames = ["init", "default_factory", "repr"]
+
+ args = {
+ a
+ for a in argnames
+ if getattr(
+ value._attribute_options, f"dataclasses_{a}"
+ )
+ is not _NoArg.NO_ARG
+ }
+ raise exc.ArgumentError(
+ f"Attribute '{k}' on class {cls} includes dataclasses "
+ f"argument(s): "
+ f"{', '.join(sorted(repr(a) for a in args))} but "
+ f"class does not specify "
+ "SQLAlchemy native dataclass configuration."
+ )
+
our_stuff[k] = value
def _extract_declared_columns(self) -> None:
@@ -997,6 +1172,7 @@ class _ClassScanMapperConfig(_MapperConfig):
# extract columns from the class dict
declared_columns = self.declared_columns
name_to_prop_key = collections.defaultdict(set)
+
for key, c in list(our_stuff.items()):
if isinstance(c, _MapsColumns):
@@ -1019,7 +1195,6 @@ class _ClassScanMapperConfig(_MapperConfig):
# otherwise, Mapper will map it under the column key.
if mp_to_assign is None and key != col.key:
our_stuff[key] = col
-
elif isinstance(c, Column):
# undefer previously occurred here, and now occurs earlier.
# ensure every column we get here has been named