summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/decl_api.py4
-rw-r--r--lib/sqlalchemy/orm/decl_base.py129
-rw-r--r--lib/sqlalchemy/orm/util.py6
3 files changed, 89 insertions, 50 deletions
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py
index 01766ad85..09397eb65 100644
--- a/lib/sqlalchemy/orm/decl_api.py
+++ b/lib/sqlalchemy/orm/decl_api.py
@@ -1553,6 +1553,10 @@ class registry:
RegistryType = registry
+if not TYPE_CHECKING:
+ # allow for runtime type resolution of ``ClassVar[_RegistryType]``
+ _RegistryType = registry # noqa
+
def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]:
"""
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index c23ea0311..21e3c3344 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -635,36 +635,37 @@ class _ClassScanMapperConfig(_MapperConfig):
return attribute_is_overridden
- _skip_attrs = frozenset(
- [
- "__module__",
- "__annotations__",
- "__doc__",
- "__dict__",
- "__weakref__",
- "_sa_class_manager",
- "_sa_apply_dc_transforms",
- "__dict__",
- "__weakref__",
- ]
- )
+ _include_dunders = {
+ "__table__",
+ "__mapper_args__",
+ "__tablename__",
+ "__table_args__",
+ }
+
+ _match_exclude_dunders = re.compile(r"^(?:_sa_|__)")
def _cls_attr_resolver(
self, cls: Type[Any]
) -> Callable[[], Iterable[Tuple[str, Any, Any, bool]]]:
- """produce a function to iterate the "attributes" of a class,
- adjusting for SQLAlchemy fields embedded in dataclass fields.
+ """produce a function to iterate the "attributes" of a class
+ which we want to consider for mapping, adjusting for SQLAlchemy fields
+ embedded in dataclass fields.
"""
cls_annotations = util.get_annotations(cls)
cls_vars = vars(cls)
- skip = self._skip_attrs
+ _include_dunders = self._include_dunders
+ _match_exclude_dunders = self._match_exclude_dunders
- names = util.merge_lists_w_ordering(
- [n for n in cls_vars if n not in skip], list(cls_annotations)
- )
+ names = [
+ n
+ for n in util.merge_lists_w_ordering(
+ list(cls_vars), list(cls_annotations)
+ )
+ if not _match_exclude_dunders.match(n) or n in _include_dunders
+ ]
if self.allow_dataclass_fields:
sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
@@ -719,6 +720,7 @@ class _ClassScanMapperConfig(_MapperConfig):
clsdict_view = self.clsdict_view
collected_attributes = self.collected_attributes
column_copies = self.column_copies
+ _include_dunders = self._include_dunders
mapper_args_fn = None
table_args = inherited_table_args = None
@@ -784,7 +786,7 @@ class _ClassScanMapperConfig(_MapperConfig):
annotation,
is_dataclass_field,
) in local_attributes_for_class():
- if re.match(r"^__.+__$", name):
+ if name in _include_dunders:
if name == "__mapper_args__":
check_decl = _check_declared_props_nocascade(
obj, name, cls
@@ -825,7 +827,8 @@ class _ClassScanMapperConfig(_MapperConfig):
if base is not cls:
inherited_table_args = True
else:
- # skip all other dunder names
+ # skip all other dunder names, which at the moment
+ # should only be __table__
continue
elif class_mapped:
if _is_declarative_props(obj):
@@ -965,14 +968,19 @@ class _ClassScanMapperConfig(_MapperConfig):
name, annotation, base, False, obj
)
else:
- generated_obj = self._collect_annotation(
+ collected_annotation = self._collect_annotation(
name, annotation, base, None, obj
)
- if (
- obj is None
- and not fixed_table
- and _is_mapped_annotation(annotation, cls, base)
- ):
+ is_mapped = (
+ collected_annotation is not None
+ and collected_annotation.mapped_container is not None
+ )
+ generated_obj = (
+ collected_annotation.attr_value
+ if collected_annotation is not None
+ else obj
+ )
+ if obj is None and not fixed_table and is_mapped:
collected_attributes[name] = (
generated_obj
if generated_obj is not None
@@ -1077,13 +1085,13 @@ class _ClassScanMapperConfig(_MapperConfig):
originating_class: Type[Any],
expect_mapped: Optional[bool],
attr_value: Any,
- ) -> Any:
+ ) -> Optional[_CollectedAnnotation]:
if name in self.collected_annotations:
- return self.collected_annotations[name][4]
+ return self.collected_annotations[name]
if raw_annotation is None:
- return attr_value
+ return None
is_dataclass = self.is_dataclass_prior_to_mapping
allow_unmapped = self.allow_unmapped_annotations
@@ -1116,7 +1124,7 @@ class _ClassScanMapperConfig(_MapperConfig):
if extracted is None:
# ClassVar can come out here
- return attr_value
+ return None
extracted_mapped_annotation, mapped_container = extracted
@@ -1136,7 +1144,7 @@ class _ClassScanMapperConfig(_MapperConfig):
if isinstance(elem, _IntrospectsAnnotations):
attr_value = elem.found_in_pep593_annotated()
- self.collected_annotations[name] = _CollectedAnnotation(
+ self.collected_annotations[name] = ca = _CollectedAnnotation(
raw_annotation,
mapped_container,
extracted_mapped_annotation,
@@ -1144,7 +1152,7 @@ class _ClassScanMapperConfig(_MapperConfig):
attr_value,
originating_class.__module__,
)
- return attr_value
+ return ca
def _warn_for_decl_attributes(
self, cls: Type[Any], key: str, c: Any
@@ -1177,9 +1185,14 @@ class _ClassScanMapperConfig(_MapperConfig):
and obj is None
and _is_mapped_annotation(annotation, cls, originating_class)
):
- obj = self._collect_annotation(
+ collected_annotation = self._collect_annotation(
name, annotation, originating_class, True, obj
)
+ obj = (
+ collected_annotation.attr_value
+ if collected_annotation is not None
+ else obj
+ )
if obj is None:
obj = MappedColumn()
@@ -1195,9 +1208,14 @@ class _ClassScanMapperConfig(_MapperConfig):
# either (issue #8718)
continue
- obj = self._collect_annotation(
+ collected_annotation = self._collect_annotation(
name, annotation, originating_class, True, obj
)
+ obj = (
+ collected_annotation.attr_value
+ if collected_annotation is not None
+ else obj
+ )
if name not in dict_ and not (
"__table__" in dict_
@@ -1233,6 +1251,8 @@ class _ClassScanMapperConfig(_MapperConfig):
our_stuff = self.properties
+ _include_dunders = self._include_dunders
+
late_mapped = _get_immediate_cls_attr(
cls, "_sa_decl_prepare_nocascade", strict=True
)
@@ -1244,7 +1264,7 @@ class _ClassScanMapperConfig(_MapperConfig):
for k in list(collected_attributes):
- if k in ("__table__", "__tablename__", "__mapper_args__"):
+ if k in _include_dunders:
continue
value = collected_attributes[k]
@@ -1297,11 +1317,12 @@ class _ClassScanMapperConfig(_MapperConfig):
# we expect to see the name 'metadata' in some valid cases;
# however at this point we see it's assigned to something trying
# to be mapped, so raise for that.
- elif k == "metadata":
+ # TODO: should "registry" here be also? might be too late
+ # to change that now (2.0 betas)
+ elif k in ("metadata",):
raise exc.InvalidRequestError(
- "Attribute name 'metadata' is reserved "
- "for the MetaData instance when using a "
- "declarative base class."
+ f"Attribute name '{k}' is reserved when using the "
+ "Declarative API."
)
elif isinstance(value, Column):
_undefer_column_name(
@@ -1326,16 +1347,24 @@ class _ClassScanMapperConfig(_MapperConfig):
# do declarative_scan so that the property can raise
# for required
if mapped_container is not None or annotation is None:
- value.declarative_scan(
- self.registry,
- cls,
- originating_module,
- k,
- mapped_container,
- annotation,
- extracted_mapped_annotation,
- is_dataclass,
- )
+ try:
+ value.declarative_scan(
+ self.registry,
+ cls,
+ originating_module,
+ k,
+ mapped_container,
+ annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ )
+ except NameError as ne:
+ raise exc.ArgumentError(
+ f"Could not resolve all types within mapped "
+ f'annotation: "{annotation}". Ensure all '
+ f"types are written correctly and are "
+ f"imported within the module in use."
+ ) from ne
else:
# assert that we were expecting annotations
# without Mapped[] were going to be passed.
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 6250cd104..58407a74d 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -2033,6 +2033,12 @@ def _is_mapped_annotation(
cls, raw_annotation, originating_cls.__module__
)
except NameError:
+ # in most cases, at least within our own tests, we can raise
+ # here, which is more accurate as it prevents us from returning
+ # false negatives. However, in the real world, try to avoid getting
+ # involved with end-user annotations that have nothing to do with us.
+ # see issue #8888 where we bypass using this function in the case
+ # that we want to detect an unresolvable Mapped[] type.
return False
else:
return is_origin_of_cls(annotated, _MappedAnnotationBase)