diff options
Diffstat (limited to 'lib/sqlalchemy/orm/clsregistry.py')
-rw-r--r-- | lib/sqlalchemy/orm/clsregistry.py | 177 |
1 files changed, 129 insertions, 48 deletions
diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index 473468c6c..b3fcd29ea 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors """Routines to handle the string class registry used by declarative. @@ -16,7 +15,22 @@ This system allows specification of classes and expressions used in from __future__ import annotations import re +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generator +from typing import Iterable +from typing import List +from typing import Mapping from typing import MutableMapping +from typing import NoReturn +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union import weakref @@ -29,6 +43,14 @@ from .. import exc from .. import inspection from .. import util from ..sql.schema import _get_table_key +from ..util.typing import CallableReference + +if TYPE_CHECKING: + from .relationships import Relationship + from ..sql.schema import MetaData + from ..sql.schema import Table + +_T = TypeVar("_T", bound=Any) _ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]] @@ -36,10 +58,12 @@ _ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]] # the _decl_class_registry, which is usually weak referencing. # the internal registries here link to classes with weakrefs and remove # themselves when all references to contained classes are removed. -_registries = set() +_registries: Set[ClsRegistryToken] = set() -def add_class(classname, cls, decl_class_registry): +def add_class( + classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType +) -> None: """Add a class to the _decl_class_registry associated with the given declarative class. @@ -49,13 +73,15 @@ def add_class(classname, cls, decl_class_registry): existing = decl_class_registry[classname] if not isinstance(existing, _MultipleClassMarker): existing = decl_class_registry[classname] = _MultipleClassMarker( - [cls, existing] + [cls, cast("Type[Any]", existing)] ) else: decl_class_registry[classname] = cls try: - root_module = decl_class_registry["_sa_module_registry"] + root_module = cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ) except KeyError: decl_class_registry[ "_sa_module_registry" @@ -79,7 +105,9 @@ def add_class(classname, cls, decl_class_registry): module.add_class(classname, cls) -def remove_class(classname, cls, decl_class_registry): +def remove_class( + classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType +) -> None: if classname in decl_class_registry: existing = decl_class_registry[classname] if isinstance(existing, _MultipleClassMarker): @@ -88,7 +116,9 @@ def remove_class(classname, cls, decl_class_registry): del decl_class_registry[classname] try: - root_module = decl_class_registry["_sa_module_registry"] + root_module = cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ) except KeyError: return @@ -102,7 +132,11 @@ def remove_class(classname, cls, decl_class_registry): module.remove_class(classname, cls) -def _key_is_empty(key, decl_class_registry, test): +def _key_is_empty( + key: str, + decl_class_registry: _ClsRegistryType, + test: Callable[[Any], bool], +) -> bool: """test if a key is empty of a certain object. used for unit tests against the registry to see if garbage collection @@ -124,6 +158,8 @@ def _key_is_empty(key, decl_class_registry, test): for sub_thing in thing.contents: if test(sub_thing): return False + else: + raise NotImplementedError("unknown codepath") else: return not test(thing) @@ -142,20 +178,27 @@ class _MultipleClassMarker(ClsRegistryToken): __slots__ = "on_remove", "contents", "__weakref__" - def __init__(self, classes, on_remove=None): + contents: Set[weakref.ref[Type[Any]]] + on_remove: CallableReference[Optional[Callable[[], None]]] + + def __init__( + self, + classes: Iterable[Type[Any]], + on_remove: Optional[Callable[[], None]] = None, + ): self.on_remove = on_remove self.contents = set( [weakref.ref(item, self._remove_item) for item in classes] ) _registries.add(self) - def remove_item(self, cls): + def remove_item(self, cls: Type[Any]) -> None: self._remove_item(weakref.ref(cls)) - def __iter__(self): + def __iter__(self) -> Generator[Optional[Type[Any]], None, None]: return (ref() for ref in self.contents) - def attempt_get(self, path, key): + def attempt_get(self, path: List[str], key: str) -> Type[Any]: if len(self.contents) > 1: raise exc.InvalidRequestError( 'Multiple classes found for path "%s" ' @@ -170,14 +213,14 @@ class _MultipleClassMarker(ClsRegistryToken): raise NameError(key) return cls - def _remove_item(self, ref): + def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None: self.contents.discard(ref) if not self.contents: _registries.discard(self) if self.on_remove: self.on_remove() - def add_item(self, item): + def add_item(self, item: Type[Any]) -> None: # protect against class registration race condition against # asynchronous garbage collection calling _remove_item, # [ticket:3208] @@ -206,7 +249,12 @@ class _ModuleMarker(ClsRegistryToken): __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__" - def __init__(self, name, parent): + parent: Optional[_ModuleMarker] + contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]] + mod_ns: _ModNS + path: List[str] + + def __init__(self, name: str, parent: Optional[_ModuleMarker]): self.parent = parent self.name = name self.contents = {} @@ -217,51 +265,53 @@ class _ModuleMarker(ClsRegistryToken): self.path = [] _registries.add(self) - def __contains__(self, name): + def __contains__(self, name: str) -> bool: return name in self.contents - def __getitem__(self, name): + def __getitem__(self, name: str) -> ClsRegistryToken: return self.contents[name] - def _remove_item(self, name): + def _remove_item(self, name: str) -> None: self.contents.pop(name, None) if not self.contents and self.parent is not None: self.parent._remove_item(self.name) _registries.discard(self) - def resolve_attr(self, key): - return getattr(self.mod_ns, key) + def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]: + return self.mod_ns.__getattr__(key) - def get_module(self, name): + def get_module(self, name: str) -> _ModuleMarker: if name not in self.contents: marker = _ModuleMarker(name, self) self.contents[name] = marker else: - marker = self.contents[name] + marker = cast(_ModuleMarker, self.contents[name]) return marker - def add_class(self, name, cls): + def add_class(self, name: str, cls: Type[Any]) -> None: if name in self.contents: - existing = self.contents[name] + existing = cast(_MultipleClassMarker, self.contents[name]) existing.add_item(cls) else: existing = self.contents[name] = _MultipleClassMarker( [cls], on_remove=lambda: self._remove_item(name) ) - def remove_class(self, name, cls): + def remove_class(self, name: str, cls: Type[Any]) -> None: if name in self.contents: - existing = self.contents[name] + existing = cast(_MultipleClassMarker, self.contents[name]) existing.remove_item(cls) class _ModNS: __slots__ = ("__parent",) - def __init__(self, parent): + __parent: _ModuleMarker + + def __init__(self, parent: _ModuleMarker): self.__parent = parent - def __getattr__(self, key): + def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]: try: value = self.__parent.contents[key] except KeyError: @@ -282,10 +332,12 @@ class _ModNS: class _GetColumns: __slots__ = ("cls",) - def __init__(self, cls): + cls: Type[Any] + + def __init__(self, cls: Type[Any]): self.cls = cls - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: mp = class_mapper(self.cls, configure=False) if mp: if key not in mp.all_orm_descriptors: @@ -296,6 +348,7 @@ class _GetColumns: desc = mp.all_orm_descriptors[key] if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION: + assert isinstance(desc, attributes.QueryableAttribute) prop = desc.property if isinstance(prop, Synonym): key = prop.name @@ -316,15 +369,18 @@ inspection._inspects(_GetColumns)( class _GetTable: __slots__ = "key", "metadata" - def __init__(self, key, metadata): + key: str + metadata: MetaData + + def __init__(self, key: str, metadata: MetaData): self.key = key self.metadata = metadata - def __getattr__(self, key): + def __getattr__(self, key: str) -> Table: return self.metadata.tables[_get_table_key(key, self.key)] -def _determine_container(key, value): +def _determine_container(key: str, value: Any) -> _GetColumns: if isinstance(value, _MultipleClassMarker): value = value.attempt_get([], key) return _GetColumns(value) @@ -341,7 +397,21 @@ class _class_resolver: "favor_tables", ) - def __init__(self, cls, prop, fallback, arg, favor_tables=False): + cls: Type[Any] + prop: Relationship[Any] + fallback: Mapping[str, Any] + arg: str + favor_tables: bool + _resolvers: Tuple[Callable[[str], Any], ...] + + def __init__( + self, + cls: Type[Any], + prop: Relationship[Any], + fallback: Mapping[str, Any], + arg: str, + favor_tables: bool = False, + ): self.cls = cls self.prop = prop self.arg = arg @@ -350,11 +420,12 @@ class _class_resolver: self._resolvers = () self.favor_tables = favor_tables - def _access_cls(self, key): + def _access_cls(self, key: str) -> Any: cls = self.cls manager = attributes.manager_of_class(cls) decl_base = manager.registry + assert decl_base is not None decl_class_registry = decl_base._class_registry metadata = decl_base.metadata @@ -362,7 +433,7 @@ class _class_resolver: if key in metadata.tables: return metadata.tables[key] elif key in metadata._schemas: - return _GetTable(key, cls.metadata) + return _GetTable(key, getattr(cls, "metadata", metadata)) if key in decl_class_registry: return _determine_container(key, decl_class_registry[key]) @@ -371,13 +442,14 @@ class _class_resolver: if key in metadata.tables: return metadata.tables[key] elif key in metadata._schemas: - return _GetTable(key, cls.metadata) + return _GetTable(key, getattr(cls, "metadata", metadata)) - if ( - "_sa_module_registry" in decl_class_registry - and key in decl_class_registry["_sa_module_registry"] + if "_sa_module_registry" in decl_class_registry and key in cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] ): - registry = decl_class_registry["_sa_module_registry"] + registry = cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ) return registry.resolve_attr(key) elif self._resolvers: for resolv in self._resolvers: @@ -387,7 +459,7 @@ class _class_resolver: return self.fallback[key] - def _raise_for_name(self, name, err): + def _raise_for_name(self, name: str, err: Exception) -> NoReturn: generic_match = re.match(r"(.+)\[(.+)\]", name) if generic_match: @@ -409,7 +481,7 @@ class _class_resolver: % (self.prop.parent, self.arg, name, self.cls) ) from err - def _resolve_name(self): + def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]: name = self.arg d = self._dict rval = None @@ -427,9 +499,11 @@ class _class_resolver: if isinstance(rval, _GetColumns): return rval.cls else: + if TYPE_CHECKING: + assert isinstance(rval, (type, Table, _ModNS)) return rval - def __call__(self): + def __call__(self) -> Any: try: x = eval(self.arg, globals(), self._dict) @@ -441,10 +515,15 @@ class _class_resolver: self._raise_for_name(n.args[0], n) -_fallback_dict = None +_fallback_dict: Mapping[str, Any] = None # type: ignore -def _resolver(cls, prop): +def _resolver( + cls: Type[Any], prop: Relationship[Any] +) -> Tuple[ + Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]], + Callable[[str, bool], _class_resolver], +]: global _fallback_dict @@ -456,12 +535,14 @@ def _resolver(cls, prop): {"foreign": foreign, "remote": remote} ) - def resolve_arg(arg, favor_tables=False): + def resolve_arg(arg: str, favor_tables: bool = False) -> _class_resolver: return _class_resolver( cls, prop, _fallback_dict, arg, favor_tables=favor_tables ) - def resolve_name(arg): + def resolve_name( + arg: str, + ) -> Callable[[], Union[Type[Any], Table, _ModNS]]: return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name return resolve_name, resolve_arg |