diff options
Diffstat (limited to 'lib/sqlalchemy/orm/decl_base.py')
-rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 414 |
1 files changed, 290 insertions, 124 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index b1f81cb6b..c3faac36c 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -4,16 +4,26 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php + """Internal implementation for declarative.""" from __future__ import annotations import collections from typing import Any +from typing import Callable +from typing import cast from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping +from typing import NoReturn +from typing import Optional from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union import weakref from . import attributes @@ -21,6 +31,8 @@ from . import clsregistry from . import exc as orm_exc from . import instrumentation from . import mapperlib +from ._typing import _O +from ._typing import attr_is_internal_proxy from .attributes import InstrumentedAttribute from .attributes import QueryableAttribute from .base import _is_mapped_class @@ -32,6 +44,7 @@ from .interfaces import _MappedAttribute from .interfaces import _MapsColumns from .interfaces import MapperProperty from .mapper import Mapper as mapper +from .mapper import Mapper from .properties import ColumnProperty from .properties import MappedColumn from .util import _is_mapped_annotation @@ -43,12 +56,41 @@ from ..sql import expression from ..sql.schema import Column from ..sql.schema import Table from ..util import topological +from ..util.typing import Protocol if TYPE_CHECKING: + from ._typing import _ClassDict from ._typing import _RegistryType + from .decl_api import declared_attr + from .instrumentation import ClassManager + from ..sql.schema import MetaData + from ..sql.selectable import FromClause + +_T = TypeVar("_T", bound=Any) + +_MapperKwArgs = Mapping[str, Any] + +_TableArgsType = Union[Tuple[Any, ...], Dict[str, Any]] -def _declared_mapping_info(cls): +class _DeclMappedClassProtocol(Protocol[_O]): + metadata: MetaData + __mapper__: Mapper[_O] + __table__: Table + __tablename__: str + __mapper_args__: Mapping[str, Any] + __table_args__: Optional[_TableArgsType] + + def __declare_first__(self) -> None: + pass + + def __declare_last__(self) -> None: + pass + + +def _declared_mapping_info( + cls: Type[Any], +) -> Optional[Union[_DeferredMapperConfig, Mapper[Any]]]: # deferred mapping if _DeferredMapperConfig.has_cls(cls): return _DeferredMapperConfig.config_for_cls(cls) @@ -59,13 +101,15 @@ def _declared_mapping_info(cls): return None -def _resolve_for_abstract_or_classical(cls): +def _resolve_for_abstract_or_classical(cls: Type[Any]) -> Optional[Type[Any]]: if cls is object: return None + sup: Optional[Type[Any]] + if cls.__dict__.get("__abstract__", False): - for sup in cls.__bases__: - sup = _resolve_for_abstract_or_classical(sup) + for base_ in cls.__bases__: + sup = _resolve_for_abstract_or_classical(base_) if sup is not None: return sup else: @@ -79,7 +123,9 @@ def _resolve_for_abstract_or_classical(cls): return cls -def _get_immediate_cls_attr(cls, attrname, strict=False): +def _get_immediate_cls_attr( + cls: Type[Any], attrname: str, strict: bool = False +) -> Optional[Any]: """return an attribute of the class that is either present directly on the class, e.g. not on a superclass, or is from a superclass but this superclass is a non-mapped mixin, that is, not a descendant of @@ -102,7 +148,7 @@ def _get_immediate_cls_attr(cls, attrname, strict=False): return getattr(cls, attrname) for base in cls.__mro__[1:]: - _is_classicial_inherits = _dive_for_cls_manager(base) + _is_classicial_inherits = _dive_for_cls_manager(base) is not None if attrname in base.__dict__ and ( base is cls @@ -116,33 +162,37 @@ def _get_immediate_cls_attr(cls, attrname, strict=False): return None -def _dive_for_cls_manager(cls): +def _dive_for_cls_manager(cls: Type[_O]) -> Optional[ClassManager[_O]]: # because the class manager registration is pluggable, # we need to do the search for every class in the hierarchy, # rather than just a simple "cls._sa_class_manager" - # python 2 old style class - if not hasattr(cls, "__mro__"): - return None - for base in cls.__mro__: - manager = attributes.opt_manager_of_class(base) + manager: Optional[ClassManager[_O]] = attributes.opt_manager_of_class( + base + ) if manager: return manager return None -def _as_declarative(registry, cls, dict_): +def _as_declarative( + registry: _RegistryType, cls: Type[Any], dict_: _ClassDict +) -> Optional[_MapperConfig]: # declarative scans the class for attributes. no table or mapper # args passed separately. - return _MapperConfig.setup_mapping(registry, cls, dict_, None, {}) -def _mapper(registry, cls, table, mapper_kw): +def _mapper( + registry: _RegistryType, + cls: Type[_O], + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, +) -> Mapper[_O]: _ImperativeMapperConfig(registry, cls, table, mapper_kw) - return cls.__mapper__ + return cast("_DeclMappedClassProtocol[_O]", cls).__mapper__ @util.preload_module("sqlalchemy.orm.decl_api") @@ -152,7 +202,9 @@ def _is_declarative_props(obj: Any) -> bool: return isinstance(obj, (declared_attr, util.classproperty)) -def _check_declared_props_nocascade(obj, name, cls): +def _check_declared_props_nocascade( + obj: Any, name: str, cls: Type[_O] +) -> bool: if _is_declarative_props(obj): if getattr(obj, "_cascading", False): util.warn( @@ -174,8 +226,20 @@ class _MapperConfig: "__weakref__", ) + cls: Type[Any] + classname: str + properties: util.OrderedDict[str, MapperProperty[Any]] + declared_attr_reg: Dict[declared_attr[Any], Any] + @classmethod - def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw): + def setup_mapping( + cls, + registry: _RegistryType, + cls_: Type[_O], + dict_: _ClassDict, + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, + ) -> Optional[_MapperConfig]: manager = attributes.opt_manager_of_class(cls) if manager and manager.class_ is cls_: raise exc.InvalidRequestError( @@ -183,24 +247,26 @@ class _MapperConfig: ) if cls_.__dict__.get("__abstract__", False): - return + return None defer_map = _get_immediate_cls_attr( cls_, "_sa_decl_prepare_nocascade", strict=True ) or hasattr(cls_, "_sa_decl_prepare") if defer_map: - cfg_cls = _DeferredMapperConfig + return _DeferredMapperConfig( + registry, cls_, dict_, table, mapper_kw + ) else: - cfg_cls = _ClassScanMapperConfig - - return cfg_cls(registry, cls_, dict_, table, mapper_kw) + return _ClassScanMapperConfig( + registry, cls_, dict_, table, mapper_kw + ) def __init__( self, registry: _RegistryType, cls_: Type[Any], - mapper_kw: Dict[str, Any], + mapper_kw: _MapperKwArgs, ): self.cls = util.assert_arg_type(cls_, type, "cls_") self.classname = cls_.__name__ @@ -224,13 +290,16 @@ class _MapperConfig: "Mapper." % self.cls ) - def set_cls_attribute(self, attrname, value): + def set_cls_attribute(self, attrname: str, value: _T) -> _T: manager = instrumentation.manager_of_class(self.cls) manager.install_member(attrname, value) return value - def _early_mapping(self, mapper_kw): + def map(self, mapper_kw: _MapperKwArgs = ...) -> Mapper[Any]: + raise NotImplementedError() + + def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None: self.map(mapper_kw) @@ -239,10 +308,10 @@ class _ImperativeMapperConfig(_MapperConfig): def __init__( self, - registry, - cls_, - table, - mapper_kw, + registry: _RegistryType, + cls_: Type[_O], + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, ): super(_ImperativeMapperConfig, self).__init__( registry, cls_, mapper_kw @@ -260,7 +329,7 @@ class _ImperativeMapperConfig(_MapperConfig): self._early_mapping(mapper_kw) - def map(self, mapper_kw=util.EMPTY_DICT): + def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: mapper_cls = mapper return self.set_cls_attribute( @@ -268,7 +337,7 @@ class _ImperativeMapperConfig(_MapperConfig): mapper_cls(self.cls, self.local_table, **mapper_kw), ) - def _setup_inheritance(self, mapper_kw): + def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None: cls = self.cls inherits = mapper_kw.get("inherits", None) @@ -277,8 +346,8 @@ class _ImperativeMapperConfig(_MapperConfig): # since we search for classical mappings now, search for # multiple mapped bases as well and raise an error. inherits_search = [] - for c in cls.__bases__: - c = _resolve_for_abstract_or_classical(c) + for base_ in cls.__bases__: + c = _resolve_for_abstract_or_classical(base_) if c is None: continue if _declared_mapping_info( @@ -318,13 +387,30 @@ class _ClassScanMapperConfig(_MapperConfig): "inherits", ) + registry: _RegistryType + clsdict_view: _ClassDict + collected_annotations: Dict[str, Tuple[Any, bool]] + collected_attributes: Dict[str, Any] + local_table: Optional[FromClause] + persist_selectable: Optional[FromClause] + declared_columns: util.OrderedSet[Column[Any]] + column_copies: Dict[ + Union[MappedColumn[Any], Column[Any]], + Union[MappedColumn[Any], Column[Any]], + ] + tablename: Optional[str] + mapper_args: Mapping[str, Any] + table_args: Optional[_TableArgsType] + mapper_args_fn: Optional[Callable[[], Dict[str, Any]]] + inherits: Optional[Type[Any]] + def __init__( self, - registry, - cls_, - dict_, - table, - mapper_kw, + registry: _RegistryType, + cls_: Type[_O], + dict_: _ClassDict, + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, ): # grab class dict before the instrumentation manager has been added. @@ -337,7 +423,7 @@ class _ClassScanMapperConfig(_MapperConfig): self.persist_selectable = None self.collected_attributes = {} - self.collected_annotations: Dict[str, Tuple[Any, bool]] = {} + self.collected_annotations = {} self.declared_columns = util.OrderedSet() self.column_copies = {} @@ -360,31 +446,37 @@ class _ClassScanMapperConfig(_MapperConfig): self._early_mapping(mapper_kw) - def _setup_declared_events(self): + def _setup_declared_events(self) -> None: if _get_immediate_cls_attr(self.cls, "__declare_last__"): @event.listens_for(mapper, "after_configured") - def after_configured(): - self.cls.__declare_last__() + def after_configured() -> None: + cast( + "_DeclMappedClassProtocol[Any]", self.cls + ).__declare_last__() if _get_immediate_cls_attr(self.cls, "__declare_first__"): @event.listens_for(mapper, "before_configured") - def before_configured(): - self.cls.__declare_first__() - - def _cls_attr_override_checker(self, cls): + def before_configured() -> None: + cast( + "_DeclMappedClassProtocol[Any]", self.cls + ).__declare_first__() + + def _cls_attr_override_checker( + self, cls: Type[_O] + ) -> Callable[[str, Any], bool]: """Produce a function that checks if a class has overridden an attribute, taking SQLAlchemy-enabled dataclass fields into account. """ sa_dataclass_metadata_key = _get_immediate_cls_attr( - cls, "__sa_dataclass_metadata_key__", None + cls, "__sa_dataclass_metadata_key__" ) if sa_dataclass_metadata_key is None: - def attribute_is_overridden(key, obj): + def attribute_is_overridden(key: str, obj: Any) -> bool: return getattr(cls, key) is not obj else: @@ -402,7 +494,7 @@ class _ClassScanMapperConfig(_MapperConfig): absent = object() - def attribute_is_overridden(key, obj): + def attribute_is_overridden(key: str, obj: Any) -> bool: if _is_declarative_props(obj): obj = obj.fget @@ -457,13 +549,15 @@ class _ClassScanMapperConfig(_MapperConfig): ] ) - def _cls_attr_resolver(self, cls): + def _cls_attr_resolver( + self, cls: Type[Any] + ) -> Callable[[], Iterable[Tuple[str, Any, Any, bool]]]: """produce a function to iterate the "attributes" of a class, adjusting for SQLAlchemy fields embedded in dataclass fields. """ - sa_dataclass_metadata_key = _get_immediate_cls_attr( - cls, "__sa_dataclass_metadata_key__", None + sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" ) cls_annotations = util.get_annotations(cls) @@ -477,7 +571,9 @@ class _ClassScanMapperConfig(_MapperConfig): ) if sa_dataclass_metadata_key is None: - def local_attributes_for_class(): + def local_attributes_for_class() -> Iterable[ + Tuple[str, Any, Any, bool] + ]: return ( ( name, @@ -493,12 +589,16 @@ class _ClassScanMapperConfig(_MapperConfig): field.name: field for field in util.local_dataclass_fields(cls) } - def local_attributes_for_class(): + fixed_sa_dataclass_metadata_key = sa_dataclass_metadata_key + + def local_attributes_for_class() -> Iterable[ + Tuple[str, Any, Any, bool] + ]: for name in names: field = dataclass_fields.get(name, None) if field and sa_dataclass_metadata_key in field.metadata: yield field.name, _as_dc_declaredattr( - field.metadata, sa_dataclass_metadata_key + field.metadata, fixed_sa_dataclass_metadata_key ), cls_annotations.get(field.name), True else: yield name, cls_vars.get(name), cls_annotations.get( @@ -507,14 +607,17 @@ class _ClassScanMapperConfig(_MapperConfig): return local_attributes_for_class - def _scan_attributes(self): + def _scan_attributes(self) -> None: cls = self.cls + cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls) + clsdict_view = self.clsdict_view collected_attributes = self.collected_attributes column_copies = self.column_copies mapper_args_fn = None table_args = inherited_table_args = None + tablename = None fixed_table = "__table__" in clsdict_view @@ -555,21 +658,23 @@ class _ClassScanMapperConfig(_MapperConfig): # make a copy of it so a class-level dictionary # is not overwritten when we update column-based # arguments. - def mapper_args_fn(): - return dict(cls.__mapper_args__) + def _mapper_args_fn() -> Dict[str, Any]: + return dict(cls_as_Decl.__mapper_args__) + + mapper_args_fn = _mapper_args_fn elif name == "__tablename__": check_decl = _check_declared_props_nocascade( obj, name, cls ) if not tablename and (not class_mapped or check_decl): - tablename = cls.__tablename__ + tablename = cls_as_Decl.__tablename__ elif name == "__table_args__": check_decl = _check_declared_props_nocascade( obj, name, cls ) if not table_args and (not class_mapped or check_decl): - table_args = cls.__table_args__ + table_args = cls_as_Decl.__table_args__ if not isinstance( table_args, (tuple, dict, type(None)) ): @@ -657,9 +762,10 @@ class _ClassScanMapperConfig(_MapperConfig): # or similar. note there is no known case that # produces nested proxies, so we are only # looking one level deep right now. + if ( isinstance(ret, InspectionAttr) - and ret._is_internal_proxy + and attr_is_internal_proxy(ret) and not isinstance( ret.original_property, MapperProperty ) @@ -669,6 +775,7 @@ class _ClassScanMapperConfig(_MapperConfig): collected_attributes[name] = column_copies[ obj ] = ret + if ( isinstance(ret, (Column, MapperProperty)) and ret.doc is None @@ -737,7 +844,9 @@ class _ClassScanMapperConfig(_MapperConfig): self.tablename = tablename self.mapper_args_fn = mapper_args_fn - def _warn_for_decl_attributes(self, cls, key, c): + def _warn_for_decl_attributes( + self, cls: Type[Any], key: str, c: Any + ) -> None: if isinstance(c, expression.ColumnClause): util.warn( f"Attribute '{key}' on class {cls} appears to " @@ -746,8 +855,12 @@ class _ClassScanMapperConfig(_MapperConfig): ) def _produce_column_copies( - self, attributes_for_class, attribute_is_overridden - ): + self, + attributes_for_class: Callable[ + [], Iterable[Tuple[str, Any, Any, bool]] + ], + attribute_is_overridden: Callable[[str, Any], bool], + ) -> None: cls = self.cls dict_ = self.clsdict_view collected_attributes = self.collected_attributes @@ -763,7 +876,8 @@ class _ClassScanMapperConfig(_MapperConfig): continue elif name not in dict_ and not ( "__table__" in dict_ - and (obj.name or name) in dict_["__table__"].c + and (getattr(obj, "name", None) or name) + in dict_["__table__"].c ): if obj.foreign_keys: for fk in obj.foreign_keys: @@ -786,7 +900,7 @@ class _ClassScanMapperConfig(_MapperConfig): setattr(cls, name, copy_) - def _extract_mappable_attributes(self): + def _extract_mappable_attributes(self) -> None: cls = self.cls collected_attributes = self.collected_attributes @@ -858,17 +972,19 @@ class _ClassScanMapperConfig(_MapperConfig): "declarative base class." ) elif isinstance(value, Column): - _undefer_column_name(k, self.column_copies.get(value, value)) + _undefer_column_name( + k, self.column_copies.get(value, value) # type: ignore + ) elif isinstance(value, _IntrospectsAnnotations): annotation, is_dataclass = self.collected_annotations.get( - k, (None, None) + k, (None, False) ) value.declarative_scan( self.registry, cls, k, annotation, is_dataclass ) our_stuff[k] = value - def _extract_declared_columns(self): + def _extract_declared_columns(self) -> None: our_stuff = self.properties # extract columns from the class dict @@ -914,8 +1030,10 @@ class _ClassScanMapperConfig(_MapperConfig): % (self.classname, name, (", ".join(sorted(keys)))) ) - def _setup_table(self, table=None): + def _setup_table(self, table: Optional[FromClause] = None) -> None: cls = self.cls + cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls) + tablename = self.tablename table_args = self.table_args clsdict_view = self.clsdict_view @@ -925,13 +1043,18 @@ class _ClassScanMapperConfig(_MapperConfig): if "__table__" not in clsdict_view and table is None: if hasattr(cls, "__table_cls__"): - table_cls = util.unbound_method_to_callable(cls.__table_cls__) + table_cls = cast( + Type[Table], + util.unbound_method_to_callable(cls.__table_cls__), # type: ignore # noqa: E501 + ) else: table_cls = Table if tablename is not None: - args, table_kw = (), {} + args: Tuple[Any, ...] = () + table_kw: Dict[str, Any] = {} + if table_args: if isinstance(table_args, dict): table_kw = table_args @@ -960,7 +1083,7 @@ class _ClassScanMapperConfig(_MapperConfig): ) else: if table is None: - table = cls.__table__ + table = cls_as_Decl.__table__ if declared_columns: for c in declared_columns: if not table.c.contains_column(c): @@ -968,15 +1091,16 @@ class _ClassScanMapperConfig(_MapperConfig): "Can't add additional column %r when " "specifying __table__" % c.key ) + self.local_table = table - def _metadata_for_cls(self, manager): + def _metadata_for_cls(self, manager: ClassManager[Any]) -> MetaData: if hasattr(self.cls, "metadata"): - return self.cls.metadata + return cast("_DeclMappedClassProtocol[Any]", self.cls).metadata else: return manager.registry.metadata - def _setup_inheritance(self, mapper_kw): + def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None: table = self.local_table cls = self.cls table_args = self.table_args @@ -988,8 +1112,8 @@ class _ClassScanMapperConfig(_MapperConfig): # since we search for classical mappings now, search for # multiple mapped bases as well and raise an error. inherits_search = [] - for c in cls.__bases__: - c = _resolve_for_abstract_or_classical(c) + for base_ in cls.__bases__: + c = _resolve_for_abstract_or_classical(base_) if c is None: continue if _declared_mapping_info( @@ -1024,9 +1148,12 @@ class _ClassScanMapperConfig(_MapperConfig): "table-mapped class." % cls ) elif self.inherits: - inherited_mapper = _declared_mapping_info(self.inherits) - inherited_table = inherited_mapper.local_table - inherited_persist_selectable = inherited_mapper.persist_selectable + inherited_mapper_or_config = _declared_mapping_info(self.inherits) + assert inherited_mapper_or_config is not None + inherited_table = inherited_mapper_or_config.local_table + inherited_persist_selectable = ( + inherited_mapper_or_config.persist_selectable + ) if table is None: # single table inheritance. @@ -1036,29 +1163,44 @@ class _ClassScanMapperConfig(_MapperConfig): "Can't place __table_args__ on an inherited class " "with no table." ) + # add any columns declared here to the inherited table. - for c in declared_columns: - if c.name in inherited_table.c: - if inherited_table.c[c.name] is c: + if declared_columns and not isinstance(inherited_table, Table): + raise exc.ArgumentError( + f"Can't declare columns on single-table-inherited " + f"subclass {self.cls}; superclass {self.inherits} " + "is not mapped to a Table" + ) + + for col in declared_columns: + assert inherited_table is not None + if col.name in inherited_table.c: + if inherited_table.c[col.name] is col: continue raise exc.ArgumentError( "Column '%s' on class %s conflicts with " "existing column '%s'" - % (c, cls, inherited_table.c[c.name]) + % (col, cls, inherited_table.c[col.name]) ) - if c.primary_key: + if col.primary_key: raise exc.ArgumentError( "Can't place primary key columns on an inherited " "class with no table." ) - inherited_table.append_column(c) + + if TYPE_CHECKING: + assert isinstance(inherited_table, Table) + + inherited_table.append_column(col) if ( inherited_persist_selectable is not None and inherited_persist_selectable is not inherited_table ): - inherited_persist_selectable._refresh_for_new_column(c) + inherited_persist_selectable._refresh_for_new_column( + col + ) - def _prepare_mapper_arguments(self, mapper_kw): + def _prepare_mapper_arguments(self, mapper_kw: _MapperKwArgs) -> None: properties = self.properties if self.mapper_args_fn: @@ -1100,6 +1242,7 @@ class _ClassScanMapperConfig(_MapperConfig): # not mapped on the parent class, to avoid # mapping columns specific to sibling/nephew classes inherited_mapper = _declared_mapping_info(self.inherits) + assert isinstance(inherited_mapper, Mapper) inherited_table = inherited_mapper.local_table if "exclude_properties" not in mapper_args: @@ -1133,11 +1276,14 @@ class _ClassScanMapperConfig(_MapperConfig): result_mapper_args["properties"] = properties self.mapper_args = result_mapper_args - def map(self, mapper_kw=util.EMPTY_DICT): + def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: self._prepare_mapper_arguments(mapper_kw) if hasattr(self.cls, "__mapper_cls__"): - mapper_cls = util.unbound_method_to_callable( - self.cls.__mapper_cls__ + mapper_cls = cast( + "Type[Mapper[Any]]", + util.unbound_method_to_callable( + self.cls.__mapper_cls__ # type: ignore + ), ) else: mapper_cls = mapper @@ -1149,7 +1295,9 @@ class _ClassScanMapperConfig(_MapperConfig): @util.preload_module("sqlalchemy.orm.decl_api") -def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key): +def _as_dc_declaredattr( + field_metadata: Mapping[str, Any], sa_dataclass_metadata_key: str +) -> Any: # wrap lambdas inside dataclass fields inside an ad-hoc declared_attr. # we can't write it because field.metadata is immutable :( so we have # to go through extra trouble to compare these @@ -1162,46 +1310,55 @@ def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key): class _DeferredMapperConfig(_ClassScanMapperConfig): - _configs = util.OrderedDict() + _cls: weakref.ref[Type[Any]] + + _configs: util.OrderedDict[ + weakref.ref[Type[Any]], _DeferredMapperConfig + ] = util.OrderedDict() - def _early_mapping(self, mapper_kw): + def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None: pass - @property - def cls(self): - return self._cls() + # mypy disallows plain property override of variable + @property # type: ignore + def cls(self) -> Type[Any]: # type: ignore + return self._cls() # type: ignore @cls.setter - def cls(self, class_): + def cls(self, class_: Type[Any]) -> None: self._cls = weakref.ref(class_, self._remove_config_cls) self._configs[self._cls] = self @classmethod - def _remove_config_cls(cls, ref): + def _remove_config_cls(cls, ref: weakref.ref[Type[Any]]) -> None: cls._configs.pop(ref, None) @classmethod - def has_cls(cls, class_): + def has_cls(cls, class_: Type[Any]) -> bool: # 2.6 fails on weakref if class_ is an old style class return isinstance(class_, type) and weakref.ref(class_) in cls._configs @classmethod - def raise_unmapped_for_cls(cls, class_): + def raise_unmapped_for_cls(cls, class_: Type[Any]) -> NoReturn: if hasattr(class_, "_sa_raise_deferred_config"): - class_._sa_raise_deferred_config() + class_._sa_raise_deferred_config() # type: ignore raise orm_exc.UnmappedClassError( class_, - msg="Class %s has a deferred mapping on it. It is not yet " - "usable as a mapped class." % orm_exc._safe_cls_name(class_), + msg=( + f"Class {orm_exc._safe_cls_name(class_)} has a deferred " + "mapping on it. It is not yet usable as a mapped class." + ), ) @classmethod - def config_for_cls(cls, class_): + def config_for_cls(cls, class_: Type[Any]) -> _DeferredMapperConfig: return cls._configs[weakref.ref(class_)] @classmethod - def classes_for_base(cls, base_cls, sort=True): + def classes_for_base( + cls, base_cls: Type[Any], sort: bool = True + ) -> List[_DeferredMapperConfig]: classes_for_base = [ m for m, cls_ in [(m, m.cls) for m in cls._configs.values()] @@ -1213,7 +1370,7 @@ class _DeferredMapperConfig(_ClassScanMapperConfig): all_m_by_cls = dict((m.cls, m) for m in classes_for_base) - tuples = [] + tuples: List[Tuple[_DeferredMapperConfig, _DeferredMapperConfig]] = [] for m_cls in all_m_by_cls: tuples.extend( (all_m_by_cls[base_cls], all_m_by_cls[m_cls]) @@ -1222,12 +1379,14 @@ class _DeferredMapperConfig(_ClassScanMapperConfig): ) return list(topological.sort(tuples, classes_for_base)) - def map(self, mapper_kw=util.EMPTY_DICT): + def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: self._configs.pop(self._cls, None) return super(_DeferredMapperConfig, self).map(mapper_kw) -def _add_attribute(cls, key, value): +def _add_attribute( + cls: Type[Any], key: str, value: MapperProperty[Any] +) -> None: """add an attribute to an existing declarative class. This runs through the logic to determine MapperProperty, @@ -1236,39 +1395,44 @@ def _add_attribute(cls, key, value): """ if "__mapper__" in cls.__dict__: + mapped_cls = cast("_DeclMappedClassProtocol[Any]", cls) if isinstance(value, Column): _undefer_column_name(key, value) - cls.__table__.append_column(value, replace_existing=True) - cls.__mapper__.add_property(key, value) + # TODO: raise for this is not a Table + mapped_cls.__table__.append_column(value, replace_existing=True) + mapped_cls.__mapper__.add_property(key, value) elif isinstance(value, _MapsColumns): mp = value.mapper_property_to_assign for col in value.columns_to_assign: _undefer_column_name(key, col) - cls.__table__.append_column(col, replace_existing=True) + # TODO: raise for this is not a Table + mapped_cls.__table__.append_column(col, replace_existing=True) if not mp: - cls.__mapper__.add_property(key, col) + mapped_cls.__mapper__.add_property(key, col) if mp: - cls.__mapper__.add_property(key, mp) + mapped_cls.__mapper__.add_property(key, mp) elif isinstance(value, MapperProperty): - cls.__mapper__.add_property(key, value) + mapped_cls.__mapper__.add_property(key, value) elif isinstance(value, QueryableAttribute) and value.key != key: # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() value = Synonym(value.key) - cls.__mapper__.add_property(key, value) + mapped_cls.__mapper__.add_property(key, value) else: type.__setattr__(cls, key, value) - cls.__mapper__._expire_memoizations() + mapped_cls.__mapper__._expire_memoizations() else: type.__setattr__(cls, key, value) -def _del_attribute(cls, key): +def _del_attribute(cls: Type[Any], key: str) -> None: if ( "__mapper__" in cls.__dict__ and key in cls.__dict__ - and not cls.__mapper__._dispose_called + and not cast( + "_DeclMappedClassProtocol[Any]", cls + ).__mapper__._dispose_called ): value = cls.__dict__[key] if isinstance( @@ -1279,12 +1443,14 @@ def _del_attribute(cls, key): ) else: type.__delattr__(cls, key) - cls.__mapper__._expire_memoizations() + cast( + "_DeclMappedClassProtocol[Any]", cls + ).__mapper__._expire_memoizations() else: type.__delattr__(cls, key) -def _declarative_constructor(self, **kwargs): +def _declarative_constructor(self: Any, **kwargs: Any) -> None: """A simple constructor that allows initialization from kwargs. Sets attributes on the constructed instance using the names and @@ -1306,7 +1472,7 @@ def _declarative_constructor(self, **kwargs): _declarative_constructor.__name__ = "__init__" -def _undefer_column_name(key, column): +def _undefer_column_name(key: str, column: Column[Any]) -> None: if column.key is None: column.key = key if column.name is None: |