summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/decl_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/decl_base.py')
-rw-r--r--lib/sqlalchemy/orm/decl_base.py414
1 files changed, 290 insertions, 124 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index b1f81cb6b..c3faac36c 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -4,16 +4,26 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
"""Internal implementation for declarative."""
from __future__ import annotations
import collections
from typing import Any
+from typing import Callable
+from typing import cast
from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Mapping
+from typing import NoReturn
+from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
import weakref
from . import attributes
@@ -21,6 +31,8 @@ from . import clsregistry
from . import exc as orm_exc
from . import instrumentation
from . import mapperlib
+from ._typing import _O
+from ._typing import attr_is_internal_proxy
from .attributes import InstrumentedAttribute
from .attributes import QueryableAttribute
from .base import _is_mapped_class
@@ -32,6 +44,7 @@ from .interfaces import _MappedAttribute
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .mapper import Mapper as mapper
+from .mapper import Mapper
from .properties import ColumnProperty
from .properties import MappedColumn
from .util import _is_mapped_annotation
@@ -43,12 +56,41 @@ from ..sql import expression
from ..sql.schema import Column
from ..sql.schema import Table
from ..util import topological
+from ..util.typing import Protocol
if TYPE_CHECKING:
+ from ._typing import _ClassDict
from ._typing import _RegistryType
+ from .decl_api import declared_attr
+ from .instrumentation import ClassManager
+ from ..sql.schema import MetaData
+ from ..sql.selectable import FromClause
+
+_T = TypeVar("_T", bound=Any)
+
+_MapperKwArgs = Mapping[str, Any]
+
+_TableArgsType = Union[Tuple[Any, ...], Dict[str, Any]]
-def _declared_mapping_info(cls):
+class _DeclMappedClassProtocol(Protocol[_O]):
+ metadata: MetaData
+ __mapper__: Mapper[_O]
+ __table__: Table
+ __tablename__: str
+ __mapper_args__: Mapping[str, Any]
+ __table_args__: Optional[_TableArgsType]
+
+ def __declare_first__(self) -> None:
+ pass
+
+ def __declare_last__(self) -> None:
+ pass
+
+
+def _declared_mapping_info(
+ cls: Type[Any],
+) -> Optional[Union[_DeferredMapperConfig, Mapper[Any]]]:
# deferred mapping
if _DeferredMapperConfig.has_cls(cls):
return _DeferredMapperConfig.config_for_cls(cls)
@@ -59,13 +101,15 @@ def _declared_mapping_info(cls):
return None
-def _resolve_for_abstract_or_classical(cls):
+def _resolve_for_abstract_or_classical(cls: Type[Any]) -> Optional[Type[Any]]:
if cls is object:
return None
+ sup: Optional[Type[Any]]
+
if cls.__dict__.get("__abstract__", False):
- for sup in cls.__bases__:
- sup = _resolve_for_abstract_or_classical(sup)
+ for base_ in cls.__bases__:
+ sup = _resolve_for_abstract_or_classical(base_)
if sup is not None:
return sup
else:
@@ -79,7 +123,9 @@ def _resolve_for_abstract_or_classical(cls):
return cls
-def _get_immediate_cls_attr(cls, attrname, strict=False):
+def _get_immediate_cls_attr(
+ cls: Type[Any], attrname: str, strict: bool = False
+) -> Optional[Any]:
"""return an attribute of the class that is either present directly
on the class, e.g. not on a superclass, or is from a superclass but
this superclass is a non-mapped mixin, that is, not a descendant of
@@ -102,7 +148,7 @@ def _get_immediate_cls_attr(cls, attrname, strict=False):
return getattr(cls, attrname)
for base in cls.__mro__[1:]:
- _is_classicial_inherits = _dive_for_cls_manager(base)
+ _is_classicial_inherits = _dive_for_cls_manager(base) is not None
if attrname in base.__dict__ and (
base is cls
@@ -116,33 +162,37 @@ def _get_immediate_cls_attr(cls, attrname, strict=False):
return None
-def _dive_for_cls_manager(cls):
+def _dive_for_cls_manager(cls: Type[_O]) -> Optional[ClassManager[_O]]:
# because the class manager registration is pluggable,
# we need to do the search for every class in the hierarchy,
# rather than just a simple "cls._sa_class_manager"
- # python 2 old style class
- if not hasattr(cls, "__mro__"):
- return None
-
for base in cls.__mro__:
- manager = attributes.opt_manager_of_class(base)
+ manager: Optional[ClassManager[_O]] = attributes.opt_manager_of_class(
+ base
+ )
if manager:
return manager
return None
-def _as_declarative(registry, cls, dict_):
+def _as_declarative(
+ registry: _RegistryType, cls: Type[Any], dict_: _ClassDict
+) -> Optional[_MapperConfig]:
# declarative scans the class for attributes. no table or mapper
# args passed separately.
-
return _MapperConfig.setup_mapping(registry, cls, dict_, None, {})
-def _mapper(registry, cls, table, mapper_kw):
+def _mapper(
+ registry: _RegistryType,
+ cls: Type[_O],
+ table: Optional[FromClause],
+ mapper_kw: _MapperKwArgs,
+) -> Mapper[_O]:
_ImperativeMapperConfig(registry, cls, table, mapper_kw)
- return cls.__mapper__
+ return cast("_DeclMappedClassProtocol[_O]", cls).__mapper__
@util.preload_module("sqlalchemy.orm.decl_api")
@@ -152,7 +202,9 @@ def _is_declarative_props(obj: Any) -> bool:
return isinstance(obj, (declared_attr, util.classproperty))
-def _check_declared_props_nocascade(obj, name, cls):
+def _check_declared_props_nocascade(
+ obj: Any, name: str, cls: Type[_O]
+) -> bool:
if _is_declarative_props(obj):
if getattr(obj, "_cascading", False):
util.warn(
@@ -174,8 +226,20 @@ class _MapperConfig:
"__weakref__",
)
+ cls: Type[Any]
+ classname: str
+ properties: util.OrderedDict[str, MapperProperty[Any]]
+ declared_attr_reg: Dict[declared_attr[Any], Any]
+
@classmethod
- def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw):
+ def setup_mapping(
+ cls,
+ registry: _RegistryType,
+ cls_: Type[_O],
+ dict_: _ClassDict,
+ table: Optional[FromClause],
+ mapper_kw: _MapperKwArgs,
+ ) -> Optional[_MapperConfig]:
manager = attributes.opt_manager_of_class(cls)
if manager and manager.class_ is cls_:
raise exc.InvalidRequestError(
@@ -183,24 +247,26 @@ class _MapperConfig:
)
if cls_.__dict__.get("__abstract__", False):
- return
+ return None
defer_map = _get_immediate_cls_attr(
cls_, "_sa_decl_prepare_nocascade", strict=True
) or hasattr(cls_, "_sa_decl_prepare")
if defer_map:
- cfg_cls = _DeferredMapperConfig
+ return _DeferredMapperConfig(
+ registry, cls_, dict_, table, mapper_kw
+ )
else:
- cfg_cls = _ClassScanMapperConfig
-
- return cfg_cls(registry, cls_, dict_, table, mapper_kw)
+ return _ClassScanMapperConfig(
+ registry, cls_, dict_, table, mapper_kw
+ )
def __init__(
self,
registry: _RegistryType,
cls_: Type[Any],
- mapper_kw: Dict[str, Any],
+ mapper_kw: _MapperKwArgs,
):
self.cls = util.assert_arg_type(cls_, type, "cls_")
self.classname = cls_.__name__
@@ -224,13 +290,16 @@ class _MapperConfig:
"Mapper." % self.cls
)
- def set_cls_attribute(self, attrname, value):
+ def set_cls_attribute(self, attrname: str, value: _T) -> _T:
manager = instrumentation.manager_of_class(self.cls)
manager.install_member(attrname, value)
return value
- def _early_mapping(self, mapper_kw):
+ def map(self, mapper_kw: _MapperKwArgs = ...) -> Mapper[Any]:
+ raise NotImplementedError()
+
+ def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None:
self.map(mapper_kw)
@@ -239,10 +308,10 @@ class _ImperativeMapperConfig(_MapperConfig):
def __init__(
self,
- registry,
- cls_,
- table,
- mapper_kw,
+ registry: _RegistryType,
+ cls_: Type[_O],
+ table: Optional[FromClause],
+ mapper_kw: _MapperKwArgs,
):
super(_ImperativeMapperConfig, self).__init__(
registry, cls_, mapper_kw
@@ -260,7 +329,7 @@ class _ImperativeMapperConfig(_MapperConfig):
self._early_mapping(mapper_kw)
- def map(self, mapper_kw=util.EMPTY_DICT):
+ def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
mapper_cls = mapper
return self.set_cls_attribute(
@@ -268,7 +337,7 @@ class _ImperativeMapperConfig(_MapperConfig):
mapper_cls(self.cls, self.local_table, **mapper_kw),
)
- def _setup_inheritance(self, mapper_kw):
+ def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None:
cls = self.cls
inherits = mapper_kw.get("inherits", None)
@@ -277,8 +346,8 @@ class _ImperativeMapperConfig(_MapperConfig):
# since we search for classical mappings now, search for
# multiple mapped bases as well and raise an error.
inherits_search = []
- for c in cls.__bases__:
- c = _resolve_for_abstract_or_classical(c)
+ for base_ in cls.__bases__:
+ c = _resolve_for_abstract_or_classical(base_)
if c is None:
continue
if _declared_mapping_info(
@@ -318,13 +387,30 @@ class _ClassScanMapperConfig(_MapperConfig):
"inherits",
)
+ registry: _RegistryType
+ clsdict_view: _ClassDict
+ collected_annotations: Dict[str, Tuple[Any, bool]]
+ collected_attributes: Dict[str, Any]
+ local_table: Optional[FromClause]
+ persist_selectable: Optional[FromClause]
+ declared_columns: util.OrderedSet[Column[Any]]
+ column_copies: Dict[
+ Union[MappedColumn[Any], Column[Any]],
+ Union[MappedColumn[Any], Column[Any]],
+ ]
+ tablename: Optional[str]
+ mapper_args: Mapping[str, Any]
+ table_args: Optional[_TableArgsType]
+ mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
+ inherits: Optional[Type[Any]]
+
def __init__(
self,
- registry,
- cls_,
- dict_,
- table,
- mapper_kw,
+ registry: _RegistryType,
+ cls_: Type[_O],
+ dict_: _ClassDict,
+ table: Optional[FromClause],
+ mapper_kw: _MapperKwArgs,
):
# grab class dict before the instrumentation manager has been added.
@@ -337,7 +423,7 @@ class _ClassScanMapperConfig(_MapperConfig):
self.persist_selectable = None
self.collected_attributes = {}
- self.collected_annotations: Dict[str, Tuple[Any, bool]] = {}
+ self.collected_annotations = {}
self.declared_columns = util.OrderedSet()
self.column_copies = {}
@@ -360,31 +446,37 @@ class _ClassScanMapperConfig(_MapperConfig):
self._early_mapping(mapper_kw)
- def _setup_declared_events(self):
+ def _setup_declared_events(self) -> None:
if _get_immediate_cls_attr(self.cls, "__declare_last__"):
@event.listens_for(mapper, "after_configured")
- def after_configured():
- self.cls.__declare_last__()
+ def after_configured() -> None:
+ cast(
+ "_DeclMappedClassProtocol[Any]", self.cls
+ ).__declare_last__()
if _get_immediate_cls_attr(self.cls, "__declare_first__"):
@event.listens_for(mapper, "before_configured")
- def before_configured():
- self.cls.__declare_first__()
-
- def _cls_attr_override_checker(self, cls):
+ def before_configured() -> None:
+ cast(
+ "_DeclMappedClassProtocol[Any]", self.cls
+ ).__declare_first__()
+
+ def _cls_attr_override_checker(
+ self, cls: Type[_O]
+ ) -> Callable[[str, Any], bool]:
"""Produce a function that checks if a class has overridden an
attribute, taking SQLAlchemy-enabled dataclass fields into account.
"""
sa_dataclass_metadata_key = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__", None
+ cls, "__sa_dataclass_metadata_key__"
)
if sa_dataclass_metadata_key is None:
- def attribute_is_overridden(key, obj):
+ def attribute_is_overridden(key: str, obj: Any) -> bool:
return getattr(cls, key) is not obj
else:
@@ -402,7 +494,7 @@ class _ClassScanMapperConfig(_MapperConfig):
absent = object()
- def attribute_is_overridden(key, obj):
+ def attribute_is_overridden(key: str, obj: Any) -> bool:
if _is_declarative_props(obj):
obj = obj.fget
@@ -457,13 +549,15 @@ class _ClassScanMapperConfig(_MapperConfig):
]
)
- def _cls_attr_resolver(self, cls):
+ def _cls_attr_resolver(
+ self, cls: Type[Any]
+ ) -> Callable[[], Iterable[Tuple[str, Any, Any, bool]]]:
"""produce a function to iterate the "attributes" of a class,
adjusting for SQLAlchemy fields embedded in dataclass fields.
"""
- sa_dataclass_metadata_key = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__", None
+ sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
)
cls_annotations = util.get_annotations(cls)
@@ -477,7 +571,9 @@ class _ClassScanMapperConfig(_MapperConfig):
)
if sa_dataclass_metadata_key is None:
- def local_attributes_for_class():
+ def local_attributes_for_class() -> Iterable[
+ Tuple[str, Any, Any, bool]
+ ]:
return (
(
name,
@@ -493,12 +589,16 @@ class _ClassScanMapperConfig(_MapperConfig):
field.name: field for field in util.local_dataclass_fields(cls)
}
- def local_attributes_for_class():
+ fixed_sa_dataclass_metadata_key = sa_dataclass_metadata_key
+
+ def local_attributes_for_class() -> Iterable[
+ Tuple[str, Any, Any, bool]
+ ]:
for name in names:
field = dataclass_fields.get(name, None)
if field and sa_dataclass_metadata_key in field.metadata:
yield field.name, _as_dc_declaredattr(
- field.metadata, sa_dataclass_metadata_key
+ field.metadata, fixed_sa_dataclass_metadata_key
), cls_annotations.get(field.name), True
else:
yield name, cls_vars.get(name), cls_annotations.get(
@@ -507,14 +607,17 @@ class _ClassScanMapperConfig(_MapperConfig):
return local_attributes_for_class
- def _scan_attributes(self):
+ def _scan_attributes(self) -> None:
cls = self.cls
+ cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls)
+
clsdict_view = self.clsdict_view
collected_attributes = self.collected_attributes
column_copies = self.column_copies
mapper_args_fn = None
table_args = inherited_table_args = None
+
tablename = None
fixed_table = "__table__" in clsdict_view
@@ -555,21 +658,23 @@ class _ClassScanMapperConfig(_MapperConfig):
# make a copy of it so a class-level dictionary
# is not overwritten when we update column-based
# arguments.
- def mapper_args_fn():
- return dict(cls.__mapper_args__)
+ def _mapper_args_fn() -> Dict[str, Any]:
+ return dict(cls_as_Decl.__mapper_args__)
+
+ mapper_args_fn = _mapper_args_fn
elif name == "__tablename__":
check_decl = _check_declared_props_nocascade(
obj, name, cls
)
if not tablename and (not class_mapped or check_decl):
- tablename = cls.__tablename__
+ tablename = cls_as_Decl.__tablename__
elif name == "__table_args__":
check_decl = _check_declared_props_nocascade(
obj, name, cls
)
if not table_args and (not class_mapped or check_decl):
- table_args = cls.__table_args__
+ table_args = cls_as_Decl.__table_args__
if not isinstance(
table_args, (tuple, dict, type(None))
):
@@ -657,9 +762,10 @@ class _ClassScanMapperConfig(_MapperConfig):
# or similar. note there is no known case that
# produces nested proxies, so we are only
# looking one level deep right now.
+
if (
isinstance(ret, InspectionAttr)
- and ret._is_internal_proxy
+ and attr_is_internal_proxy(ret)
and not isinstance(
ret.original_property, MapperProperty
)
@@ -669,6 +775,7 @@ class _ClassScanMapperConfig(_MapperConfig):
collected_attributes[name] = column_copies[
obj
] = ret
+
if (
isinstance(ret, (Column, MapperProperty))
and ret.doc is None
@@ -737,7 +844,9 @@ class _ClassScanMapperConfig(_MapperConfig):
self.tablename = tablename
self.mapper_args_fn = mapper_args_fn
- def _warn_for_decl_attributes(self, cls, key, c):
+ def _warn_for_decl_attributes(
+ self, cls: Type[Any], key: str, c: Any
+ ) -> None:
if isinstance(c, expression.ColumnClause):
util.warn(
f"Attribute '{key}' on class {cls} appears to "
@@ -746,8 +855,12 @@ class _ClassScanMapperConfig(_MapperConfig):
)
def _produce_column_copies(
- self, attributes_for_class, attribute_is_overridden
- ):
+ self,
+ attributes_for_class: Callable[
+ [], Iterable[Tuple[str, Any, Any, bool]]
+ ],
+ attribute_is_overridden: Callable[[str, Any], bool],
+ ) -> None:
cls = self.cls
dict_ = self.clsdict_view
collected_attributes = self.collected_attributes
@@ -763,7 +876,8 @@ class _ClassScanMapperConfig(_MapperConfig):
continue
elif name not in dict_ and not (
"__table__" in dict_
- and (obj.name or name) in dict_["__table__"].c
+ and (getattr(obj, "name", None) or name)
+ in dict_["__table__"].c
):
if obj.foreign_keys:
for fk in obj.foreign_keys:
@@ -786,7 +900,7 @@ class _ClassScanMapperConfig(_MapperConfig):
setattr(cls, name, copy_)
- def _extract_mappable_attributes(self):
+ def _extract_mappable_attributes(self) -> None:
cls = self.cls
collected_attributes = self.collected_attributes
@@ -858,17 +972,19 @@ class _ClassScanMapperConfig(_MapperConfig):
"declarative base class."
)
elif isinstance(value, Column):
- _undefer_column_name(k, self.column_copies.get(value, value))
+ _undefer_column_name(
+ k, self.column_copies.get(value, value) # type: ignore
+ )
elif isinstance(value, _IntrospectsAnnotations):
annotation, is_dataclass = self.collected_annotations.get(
- k, (None, None)
+ k, (None, False)
)
value.declarative_scan(
self.registry, cls, k, annotation, is_dataclass
)
our_stuff[k] = value
- def _extract_declared_columns(self):
+ def _extract_declared_columns(self) -> None:
our_stuff = self.properties
# extract columns from the class dict
@@ -914,8 +1030,10 @@ class _ClassScanMapperConfig(_MapperConfig):
% (self.classname, name, (", ".join(sorted(keys))))
)
- def _setup_table(self, table=None):
+ def _setup_table(self, table: Optional[FromClause] = None) -> None:
cls = self.cls
+ cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls)
+
tablename = self.tablename
table_args = self.table_args
clsdict_view = self.clsdict_view
@@ -925,13 +1043,18 @@ class _ClassScanMapperConfig(_MapperConfig):
if "__table__" not in clsdict_view and table is None:
if hasattr(cls, "__table_cls__"):
- table_cls = util.unbound_method_to_callable(cls.__table_cls__)
+ table_cls = cast(
+ Type[Table],
+ util.unbound_method_to_callable(cls.__table_cls__), # type: ignore # noqa: E501
+ )
else:
table_cls = Table
if tablename is not None:
- args, table_kw = (), {}
+ args: Tuple[Any, ...] = ()
+ table_kw: Dict[str, Any] = {}
+
if table_args:
if isinstance(table_args, dict):
table_kw = table_args
@@ -960,7 +1083,7 @@ class _ClassScanMapperConfig(_MapperConfig):
)
else:
if table is None:
- table = cls.__table__
+ table = cls_as_Decl.__table__
if declared_columns:
for c in declared_columns:
if not table.c.contains_column(c):
@@ -968,15 +1091,16 @@ class _ClassScanMapperConfig(_MapperConfig):
"Can't add additional column %r when "
"specifying __table__" % c.key
)
+
self.local_table = table
- def _metadata_for_cls(self, manager):
+ def _metadata_for_cls(self, manager: ClassManager[Any]) -> MetaData:
if hasattr(self.cls, "metadata"):
- return self.cls.metadata
+ return cast("_DeclMappedClassProtocol[Any]", self.cls).metadata
else:
return manager.registry.metadata
- def _setup_inheritance(self, mapper_kw):
+ def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None:
table = self.local_table
cls = self.cls
table_args = self.table_args
@@ -988,8 +1112,8 @@ class _ClassScanMapperConfig(_MapperConfig):
# since we search for classical mappings now, search for
# multiple mapped bases as well and raise an error.
inherits_search = []
- for c in cls.__bases__:
- c = _resolve_for_abstract_or_classical(c)
+ for base_ in cls.__bases__:
+ c = _resolve_for_abstract_or_classical(base_)
if c is None:
continue
if _declared_mapping_info(
@@ -1024,9 +1148,12 @@ class _ClassScanMapperConfig(_MapperConfig):
"table-mapped class." % cls
)
elif self.inherits:
- inherited_mapper = _declared_mapping_info(self.inherits)
- inherited_table = inherited_mapper.local_table
- inherited_persist_selectable = inherited_mapper.persist_selectable
+ inherited_mapper_or_config = _declared_mapping_info(self.inherits)
+ assert inherited_mapper_or_config is not None
+ inherited_table = inherited_mapper_or_config.local_table
+ inherited_persist_selectable = (
+ inherited_mapper_or_config.persist_selectable
+ )
if table is None:
# single table inheritance.
@@ -1036,29 +1163,44 @@ class _ClassScanMapperConfig(_MapperConfig):
"Can't place __table_args__ on an inherited class "
"with no table."
)
+
# add any columns declared here to the inherited table.
- for c in declared_columns:
- if c.name in inherited_table.c:
- if inherited_table.c[c.name] is c:
+ if declared_columns and not isinstance(inherited_table, Table):
+ raise exc.ArgumentError(
+ f"Can't declare columns on single-table-inherited "
+ f"subclass {self.cls}; superclass {self.inherits} "
+ "is not mapped to a Table"
+ )
+
+ for col in declared_columns:
+ assert inherited_table is not None
+ if col.name in inherited_table.c:
+ if inherited_table.c[col.name] is col:
continue
raise exc.ArgumentError(
"Column '%s' on class %s conflicts with "
"existing column '%s'"
- % (c, cls, inherited_table.c[c.name])
+ % (col, cls, inherited_table.c[col.name])
)
- if c.primary_key:
+ if col.primary_key:
raise exc.ArgumentError(
"Can't place primary key columns on an inherited "
"class with no table."
)
- inherited_table.append_column(c)
+
+ if TYPE_CHECKING:
+ assert isinstance(inherited_table, Table)
+
+ inherited_table.append_column(col)
if (
inherited_persist_selectable is not None
and inherited_persist_selectable is not inherited_table
):
- inherited_persist_selectable._refresh_for_new_column(c)
+ inherited_persist_selectable._refresh_for_new_column(
+ col
+ )
- def _prepare_mapper_arguments(self, mapper_kw):
+ def _prepare_mapper_arguments(self, mapper_kw: _MapperKwArgs) -> None:
properties = self.properties
if self.mapper_args_fn:
@@ -1100,6 +1242,7 @@ class _ClassScanMapperConfig(_MapperConfig):
# not mapped on the parent class, to avoid
# mapping columns specific to sibling/nephew classes
inherited_mapper = _declared_mapping_info(self.inherits)
+ assert isinstance(inherited_mapper, Mapper)
inherited_table = inherited_mapper.local_table
if "exclude_properties" not in mapper_args:
@@ -1133,11 +1276,14 @@ class _ClassScanMapperConfig(_MapperConfig):
result_mapper_args["properties"] = properties
self.mapper_args = result_mapper_args
- def map(self, mapper_kw=util.EMPTY_DICT):
+ def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
self._prepare_mapper_arguments(mapper_kw)
if hasattr(self.cls, "__mapper_cls__"):
- mapper_cls = util.unbound_method_to_callable(
- self.cls.__mapper_cls__
+ mapper_cls = cast(
+ "Type[Mapper[Any]]",
+ util.unbound_method_to_callable(
+ self.cls.__mapper_cls__ # type: ignore
+ ),
)
else:
mapper_cls = mapper
@@ -1149,7 +1295,9 @@ class _ClassScanMapperConfig(_MapperConfig):
@util.preload_module("sqlalchemy.orm.decl_api")
-def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key):
+def _as_dc_declaredattr(
+ field_metadata: Mapping[str, Any], sa_dataclass_metadata_key: str
+) -> Any:
# wrap lambdas inside dataclass fields inside an ad-hoc declared_attr.
# we can't write it because field.metadata is immutable :( so we have
# to go through extra trouble to compare these
@@ -1162,46 +1310,55 @@ def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key):
class _DeferredMapperConfig(_ClassScanMapperConfig):
- _configs = util.OrderedDict()
+ _cls: weakref.ref[Type[Any]]
+
+ _configs: util.OrderedDict[
+ weakref.ref[Type[Any]], _DeferredMapperConfig
+ ] = util.OrderedDict()
- def _early_mapping(self, mapper_kw):
+ def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None:
pass
- @property
- def cls(self):
- return self._cls()
+ # mypy disallows plain property override of variable
+ @property # type: ignore
+ def cls(self) -> Type[Any]: # type: ignore
+ return self._cls() # type: ignore
@cls.setter
- def cls(self, class_):
+ def cls(self, class_: Type[Any]) -> None:
self._cls = weakref.ref(class_, self._remove_config_cls)
self._configs[self._cls] = self
@classmethod
- def _remove_config_cls(cls, ref):
+ def _remove_config_cls(cls, ref: weakref.ref[Type[Any]]) -> None:
cls._configs.pop(ref, None)
@classmethod
- def has_cls(cls, class_):
+ def has_cls(cls, class_: Type[Any]) -> bool:
# 2.6 fails on weakref if class_ is an old style class
return isinstance(class_, type) and weakref.ref(class_) in cls._configs
@classmethod
- def raise_unmapped_for_cls(cls, class_):
+ def raise_unmapped_for_cls(cls, class_: Type[Any]) -> NoReturn:
if hasattr(class_, "_sa_raise_deferred_config"):
- class_._sa_raise_deferred_config()
+ class_._sa_raise_deferred_config() # type: ignore
raise orm_exc.UnmappedClassError(
class_,
- msg="Class %s has a deferred mapping on it. It is not yet "
- "usable as a mapped class." % orm_exc._safe_cls_name(class_),
+ msg=(
+ f"Class {orm_exc._safe_cls_name(class_)} has a deferred "
+ "mapping on it. It is not yet usable as a mapped class."
+ ),
)
@classmethod
- def config_for_cls(cls, class_):
+ def config_for_cls(cls, class_: Type[Any]) -> _DeferredMapperConfig:
return cls._configs[weakref.ref(class_)]
@classmethod
- def classes_for_base(cls, base_cls, sort=True):
+ def classes_for_base(
+ cls, base_cls: Type[Any], sort: bool = True
+ ) -> List[_DeferredMapperConfig]:
classes_for_base = [
m
for m, cls_ in [(m, m.cls) for m in cls._configs.values()]
@@ -1213,7 +1370,7 @@ class _DeferredMapperConfig(_ClassScanMapperConfig):
all_m_by_cls = dict((m.cls, m) for m in classes_for_base)
- tuples = []
+ tuples: List[Tuple[_DeferredMapperConfig, _DeferredMapperConfig]] = []
for m_cls in all_m_by_cls:
tuples.extend(
(all_m_by_cls[base_cls], all_m_by_cls[m_cls])
@@ -1222,12 +1379,14 @@ class _DeferredMapperConfig(_ClassScanMapperConfig):
)
return list(topological.sort(tuples, classes_for_base))
- def map(self, mapper_kw=util.EMPTY_DICT):
+ def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
self._configs.pop(self._cls, None)
return super(_DeferredMapperConfig, self).map(mapper_kw)
-def _add_attribute(cls, key, value):
+def _add_attribute(
+ cls: Type[Any], key: str, value: MapperProperty[Any]
+) -> None:
"""add an attribute to an existing declarative class.
This runs through the logic to determine MapperProperty,
@@ -1236,39 +1395,44 @@ def _add_attribute(cls, key, value):
"""
if "__mapper__" in cls.__dict__:
+ mapped_cls = cast("_DeclMappedClassProtocol[Any]", cls)
if isinstance(value, Column):
_undefer_column_name(key, value)
- cls.__table__.append_column(value, replace_existing=True)
- cls.__mapper__.add_property(key, value)
+ # TODO: raise for this is not a Table
+ mapped_cls.__table__.append_column(value, replace_existing=True)
+ mapped_cls.__mapper__.add_property(key, value)
elif isinstance(value, _MapsColumns):
mp = value.mapper_property_to_assign
for col in value.columns_to_assign:
_undefer_column_name(key, col)
- cls.__table__.append_column(col, replace_existing=True)
+ # TODO: raise for this is not a Table
+ mapped_cls.__table__.append_column(col, replace_existing=True)
if not mp:
- cls.__mapper__.add_property(key, col)
+ mapped_cls.__mapper__.add_property(key, col)
if mp:
- cls.__mapper__.add_property(key, mp)
+ mapped_cls.__mapper__.add_property(key, mp)
elif isinstance(value, MapperProperty):
- cls.__mapper__.add_property(key, value)
+ mapped_cls.__mapper__.add_property(key, value)
elif isinstance(value, QueryableAttribute) and value.key != key:
# detect a QueryableAttribute that's already mapped being
# assigned elsewhere in userland, turn into a synonym()
value = Synonym(value.key)
- cls.__mapper__.add_property(key, value)
+ mapped_cls.__mapper__.add_property(key, value)
else:
type.__setattr__(cls, key, value)
- cls.__mapper__._expire_memoizations()
+ mapped_cls.__mapper__._expire_memoizations()
else:
type.__setattr__(cls, key, value)
-def _del_attribute(cls, key):
+def _del_attribute(cls: Type[Any], key: str) -> None:
if (
"__mapper__" in cls.__dict__
and key in cls.__dict__
- and not cls.__mapper__._dispose_called
+ and not cast(
+ "_DeclMappedClassProtocol[Any]", cls
+ ).__mapper__._dispose_called
):
value = cls.__dict__[key]
if isinstance(
@@ -1279,12 +1443,14 @@ def _del_attribute(cls, key):
)
else:
type.__delattr__(cls, key)
- cls.__mapper__._expire_memoizations()
+ cast(
+ "_DeclMappedClassProtocol[Any]", cls
+ ).__mapper__._expire_memoizations()
else:
type.__delattr__(cls, key)
-def _declarative_constructor(self, **kwargs):
+def _declarative_constructor(self: Any, **kwargs: Any) -> None:
"""A simple constructor that allows initialization from kwargs.
Sets attributes on the constructed instance using the names and
@@ -1306,7 +1472,7 @@ def _declarative_constructor(self, **kwargs):
_declarative_constructor.__name__ = "__init__"
-def _undefer_column_name(key, column):
+def _undefer_column_name(key: str, column: Column[Any]) -> None:
if column.key is None:
column.key = key
if column.name is None: