diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/orm/decl_api.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 129 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 6 |
3 files changed, 89 insertions, 50 deletions
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 01766ad85..09397eb65 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -1553,6 +1553,10 @@ class registry: RegistryType = registry +if not TYPE_CHECKING: + # allow for runtime type resolution of ``ClassVar[_RegistryType]`` + _RegistryType = registry # noqa + def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]: """ diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index c23ea0311..21e3c3344 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -635,36 +635,37 @@ class _ClassScanMapperConfig(_MapperConfig): return attribute_is_overridden - _skip_attrs = frozenset( - [ - "__module__", - "__annotations__", - "__doc__", - "__dict__", - "__weakref__", - "_sa_class_manager", - "_sa_apply_dc_transforms", - "__dict__", - "__weakref__", - ] - ) + _include_dunders = { + "__table__", + "__mapper_args__", + "__tablename__", + "__table_args__", + } + + _match_exclude_dunders = re.compile(r"^(?:_sa_|__)") def _cls_attr_resolver( self, cls: Type[Any] ) -> Callable[[], Iterable[Tuple[str, Any, Any, bool]]]: - """produce a function to iterate the "attributes" of a class, - adjusting for SQLAlchemy fields embedded in dataclass fields. + """produce a function to iterate the "attributes" of a class + which we want to consider for mapping, adjusting for SQLAlchemy fields + embedded in dataclass fields. """ cls_annotations = util.get_annotations(cls) cls_vars = vars(cls) - skip = self._skip_attrs + _include_dunders = self._include_dunders + _match_exclude_dunders = self._match_exclude_dunders - names = util.merge_lists_w_ordering( - [n for n in cls_vars if n not in skip], list(cls_annotations) - ) + names = [ + n + for n in util.merge_lists_w_ordering( + list(cls_vars), list(cls_annotations) + ) + if not _match_exclude_dunders.match(n) or n in _include_dunders + ] if self.allow_dataclass_fields: sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( @@ -719,6 +720,7 @@ class _ClassScanMapperConfig(_MapperConfig): clsdict_view = self.clsdict_view collected_attributes = self.collected_attributes column_copies = self.column_copies + _include_dunders = self._include_dunders mapper_args_fn = None table_args = inherited_table_args = None @@ -784,7 +786,7 @@ class _ClassScanMapperConfig(_MapperConfig): annotation, is_dataclass_field, ) in local_attributes_for_class(): - if re.match(r"^__.+__$", name): + if name in _include_dunders: if name == "__mapper_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -825,7 +827,8 @@ class _ClassScanMapperConfig(_MapperConfig): if base is not cls: inherited_table_args = True else: - # skip all other dunder names + # skip all other dunder names, which at the moment + # should only be __table__ continue elif class_mapped: if _is_declarative_props(obj): @@ -965,14 +968,19 @@ class _ClassScanMapperConfig(_MapperConfig): name, annotation, base, False, obj ) else: - generated_obj = self._collect_annotation( + collected_annotation = self._collect_annotation( name, annotation, base, None, obj ) - if ( - obj is None - and not fixed_table - and _is_mapped_annotation(annotation, cls, base) - ): + is_mapped = ( + collected_annotation is not None + and collected_annotation.mapped_container is not None + ) + generated_obj = ( + collected_annotation.attr_value + if collected_annotation is not None + else obj + ) + if obj is None and not fixed_table and is_mapped: collected_attributes[name] = ( generated_obj if generated_obj is not None @@ -1077,13 +1085,13 @@ class _ClassScanMapperConfig(_MapperConfig): originating_class: Type[Any], expect_mapped: Optional[bool], attr_value: Any, - ) -> Any: + ) -> Optional[_CollectedAnnotation]: if name in self.collected_annotations: - return self.collected_annotations[name][4] + return self.collected_annotations[name] if raw_annotation is None: - return attr_value + return None is_dataclass = self.is_dataclass_prior_to_mapping allow_unmapped = self.allow_unmapped_annotations @@ -1116,7 +1124,7 @@ class _ClassScanMapperConfig(_MapperConfig): if extracted is None: # ClassVar can come out here - return attr_value + return None extracted_mapped_annotation, mapped_container = extracted @@ -1136,7 +1144,7 @@ class _ClassScanMapperConfig(_MapperConfig): if isinstance(elem, _IntrospectsAnnotations): attr_value = elem.found_in_pep593_annotated() - self.collected_annotations[name] = _CollectedAnnotation( + self.collected_annotations[name] = ca = _CollectedAnnotation( raw_annotation, mapped_container, extracted_mapped_annotation, @@ -1144,7 +1152,7 @@ class _ClassScanMapperConfig(_MapperConfig): attr_value, originating_class.__module__, ) - return attr_value + return ca def _warn_for_decl_attributes( self, cls: Type[Any], key: str, c: Any @@ -1177,9 +1185,14 @@ class _ClassScanMapperConfig(_MapperConfig): and obj is None and _is_mapped_annotation(annotation, cls, originating_class) ): - obj = self._collect_annotation( + collected_annotation = self._collect_annotation( name, annotation, originating_class, True, obj ) + obj = ( + collected_annotation.attr_value + if collected_annotation is not None + else obj + ) if obj is None: obj = MappedColumn() @@ -1195,9 +1208,14 @@ class _ClassScanMapperConfig(_MapperConfig): # either (issue #8718) continue - obj = self._collect_annotation( + collected_annotation = self._collect_annotation( name, annotation, originating_class, True, obj ) + obj = ( + collected_annotation.attr_value + if collected_annotation is not None + else obj + ) if name not in dict_ and not ( "__table__" in dict_ @@ -1233,6 +1251,8 @@ class _ClassScanMapperConfig(_MapperConfig): our_stuff = self.properties + _include_dunders = self._include_dunders + late_mapped = _get_immediate_cls_attr( cls, "_sa_decl_prepare_nocascade", strict=True ) @@ -1244,7 +1264,7 @@ class _ClassScanMapperConfig(_MapperConfig): for k in list(collected_attributes): - if k in ("__table__", "__tablename__", "__mapper_args__"): + if k in _include_dunders: continue value = collected_attributes[k] @@ -1297,11 +1317,12 @@ class _ClassScanMapperConfig(_MapperConfig): # we expect to see the name 'metadata' in some valid cases; # however at this point we see it's assigned to something trying # to be mapped, so raise for that. - elif k == "metadata": + # TODO: should "registry" here be also? might be too late + # to change that now (2.0 betas) + elif k in ("metadata",): raise exc.InvalidRequestError( - "Attribute name 'metadata' is reserved " - "for the MetaData instance when using a " - "declarative base class." + f"Attribute name '{k}' is reserved when using the " + "Declarative API." ) elif isinstance(value, Column): _undefer_column_name( @@ -1326,16 +1347,24 @@ class _ClassScanMapperConfig(_MapperConfig): # do declarative_scan so that the property can raise # for required if mapped_container is not None or annotation is None: - value.declarative_scan( - self.registry, - cls, - originating_module, - k, - mapped_container, - annotation, - extracted_mapped_annotation, - is_dataclass, - ) + try: + value.declarative_scan( + self.registry, + cls, + originating_module, + k, + mapped_container, + annotation, + extracted_mapped_annotation, + is_dataclass, + ) + except NameError as ne: + raise exc.ArgumentError( + f"Could not resolve all types within mapped " + f'annotation: "{annotation}". Ensure all ' + f"types are written correctly and are " + f"imported within the module in use." + ) from ne else: # assert that we were expecting annotations # without Mapped[] were going to be passed. diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 6250cd104..58407a74d 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -2033,6 +2033,12 @@ def _is_mapped_annotation( cls, raw_annotation, originating_cls.__module__ ) except NameError: + # in most cases, at least within our own tests, we can raise + # here, which is more accurate as it prevents us from returning + # false negatives. However, in the real world, try to avoid getting + # involved with end-user annotations that have nothing to do with us. + # see issue #8888 where we bypass using this function in the case + # that we want to detect an unresolvable Mapped[] type. return False else: return is_origin_of_cls(annotated, _MappedAnnotationBase) |