summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/clsregistry.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/clsregistry.py')
-rw-r--r--lib/sqlalchemy/orm/clsregistry.py177
1 files changed, 129 insertions, 48 deletions
diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py
index 473468c6c..b3fcd29ea 100644
--- a/lib/sqlalchemy/orm/clsregistry.py
+++ b/lib/sqlalchemy/orm/clsregistry.py
@@ -4,7 +4,6 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
"""Routines to handle the string class registry used by declarative.
@@ -16,7 +15,22 @@ This system allows specification of classes and expressions used in
from __future__ import annotations
import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Mapping
from typing import MutableMapping
+from typing import NoReturn
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import weakref
@@ -29,6 +43,14 @@ from .. import exc
from .. import inspection
from .. import util
from ..sql.schema import _get_table_key
+from ..util.typing import CallableReference
+
+if TYPE_CHECKING:
+ from .relationships import Relationship
+ from ..sql.schema import MetaData
+ from ..sql.schema import Table
+
+_T = TypeVar("_T", bound=Any)
_ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]]
@@ -36,10 +58,12 @@ _ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]]
# the _decl_class_registry, which is usually weak referencing.
# the internal registries here link to classes with weakrefs and remove
# themselves when all references to contained classes are removed.
-_registries = set()
+_registries: Set[ClsRegistryToken] = set()
-def add_class(classname, cls, decl_class_registry):
+def add_class(
+ classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType
+) -> None:
"""Add a class to the _decl_class_registry associated with the
given declarative class.
@@ -49,13 +73,15 @@ def add_class(classname, cls, decl_class_registry):
existing = decl_class_registry[classname]
if not isinstance(existing, _MultipleClassMarker):
existing = decl_class_registry[classname] = _MultipleClassMarker(
- [cls, existing]
+ [cls, cast("Type[Any]", existing)]
)
else:
decl_class_registry[classname] = cls
try:
- root_module = decl_class_registry["_sa_module_registry"]
+ root_module = cast(
+ _ModuleMarker, decl_class_registry["_sa_module_registry"]
+ )
except KeyError:
decl_class_registry[
"_sa_module_registry"
@@ -79,7 +105,9 @@ def add_class(classname, cls, decl_class_registry):
module.add_class(classname, cls)
-def remove_class(classname, cls, decl_class_registry):
+def remove_class(
+ classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType
+) -> None:
if classname in decl_class_registry:
existing = decl_class_registry[classname]
if isinstance(existing, _MultipleClassMarker):
@@ -88,7 +116,9 @@ def remove_class(classname, cls, decl_class_registry):
del decl_class_registry[classname]
try:
- root_module = decl_class_registry["_sa_module_registry"]
+ root_module = cast(
+ _ModuleMarker, decl_class_registry["_sa_module_registry"]
+ )
except KeyError:
return
@@ -102,7 +132,11 @@ def remove_class(classname, cls, decl_class_registry):
module.remove_class(classname, cls)
-def _key_is_empty(key, decl_class_registry, test):
+def _key_is_empty(
+ key: str,
+ decl_class_registry: _ClsRegistryType,
+ test: Callable[[Any], bool],
+) -> bool:
"""test if a key is empty of a certain object.
used for unit tests against the registry to see if garbage collection
@@ -124,6 +158,8 @@ def _key_is_empty(key, decl_class_registry, test):
for sub_thing in thing.contents:
if test(sub_thing):
return False
+ else:
+ raise NotImplementedError("unknown codepath")
else:
return not test(thing)
@@ -142,20 +178,27 @@ class _MultipleClassMarker(ClsRegistryToken):
__slots__ = "on_remove", "contents", "__weakref__"
- def __init__(self, classes, on_remove=None):
+ contents: Set[weakref.ref[Type[Any]]]
+ on_remove: CallableReference[Optional[Callable[[], None]]]
+
+ def __init__(
+ self,
+ classes: Iterable[Type[Any]],
+ on_remove: Optional[Callable[[], None]] = None,
+ ):
self.on_remove = on_remove
self.contents = set(
[weakref.ref(item, self._remove_item) for item in classes]
)
_registries.add(self)
- def remove_item(self, cls):
+ def remove_item(self, cls: Type[Any]) -> None:
self._remove_item(weakref.ref(cls))
- def __iter__(self):
+ def __iter__(self) -> Generator[Optional[Type[Any]], None, None]:
return (ref() for ref in self.contents)
- def attempt_get(self, path, key):
+ def attempt_get(self, path: List[str], key: str) -> Type[Any]:
if len(self.contents) > 1:
raise exc.InvalidRequestError(
'Multiple classes found for path "%s" '
@@ -170,14 +213,14 @@ class _MultipleClassMarker(ClsRegistryToken):
raise NameError(key)
return cls
- def _remove_item(self, ref):
+ def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None:
self.contents.discard(ref)
if not self.contents:
_registries.discard(self)
if self.on_remove:
self.on_remove()
- def add_item(self, item):
+ def add_item(self, item: Type[Any]) -> None:
# protect against class registration race condition against
# asynchronous garbage collection calling _remove_item,
# [ticket:3208]
@@ -206,7 +249,12 @@ class _ModuleMarker(ClsRegistryToken):
__slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"
- def __init__(self, name, parent):
+ parent: Optional[_ModuleMarker]
+ contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]]
+ mod_ns: _ModNS
+ path: List[str]
+
+ def __init__(self, name: str, parent: Optional[_ModuleMarker]):
self.parent = parent
self.name = name
self.contents = {}
@@ -217,51 +265,53 @@ class _ModuleMarker(ClsRegistryToken):
self.path = []
_registries.add(self)
- def __contains__(self, name):
+ def __contains__(self, name: str) -> bool:
return name in self.contents
- def __getitem__(self, name):
+ def __getitem__(self, name: str) -> ClsRegistryToken:
return self.contents[name]
- def _remove_item(self, name):
+ def _remove_item(self, name: str) -> None:
self.contents.pop(name, None)
if not self.contents and self.parent is not None:
self.parent._remove_item(self.name)
_registries.discard(self)
- def resolve_attr(self, key):
- return getattr(self.mod_ns, key)
+ def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]:
+ return self.mod_ns.__getattr__(key)
- def get_module(self, name):
+ def get_module(self, name: str) -> _ModuleMarker:
if name not in self.contents:
marker = _ModuleMarker(name, self)
self.contents[name] = marker
else:
- marker = self.contents[name]
+ marker = cast(_ModuleMarker, self.contents[name])
return marker
- def add_class(self, name, cls):
+ def add_class(self, name: str, cls: Type[Any]) -> None:
if name in self.contents:
- existing = self.contents[name]
+ existing = cast(_MultipleClassMarker, self.contents[name])
existing.add_item(cls)
else:
existing = self.contents[name] = _MultipleClassMarker(
[cls], on_remove=lambda: self._remove_item(name)
)
- def remove_class(self, name, cls):
+ def remove_class(self, name: str, cls: Type[Any]) -> None:
if name in self.contents:
- existing = self.contents[name]
+ existing = cast(_MultipleClassMarker, self.contents[name])
existing.remove_item(cls)
class _ModNS:
__slots__ = ("__parent",)
- def __init__(self, parent):
+ __parent: _ModuleMarker
+
+ def __init__(self, parent: _ModuleMarker):
self.__parent = parent
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]:
try:
value = self.__parent.contents[key]
except KeyError:
@@ -282,10 +332,12 @@ class _ModNS:
class _GetColumns:
__slots__ = ("cls",)
- def __init__(self, cls):
+ cls: Type[Any]
+
+ def __init__(self, cls: Type[Any]):
self.cls = cls
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Any:
mp = class_mapper(self.cls, configure=False)
if mp:
if key not in mp.all_orm_descriptors:
@@ -296,6 +348,7 @@ class _GetColumns:
desc = mp.all_orm_descriptors[key]
if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION:
+ assert isinstance(desc, attributes.QueryableAttribute)
prop = desc.property
if isinstance(prop, Synonym):
key = prop.name
@@ -316,15 +369,18 @@ inspection._inspects(_GetColumns)(
class _GetTable:
__slots__ = "key", "metadata"
- def __init__(self, key, metadata):
+ key: str
+ metadata: MetaData
+
+ def __init__(self, key: str, metadata: MetaData):
self.key = key
self.metadata = metadata
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Table:
return self.metadata.tables[_get_table_key(key, self.key)]
-def _determine_container(key, value):
+def _determine_container(key: str, value: Any) -> _GetColumns:
if isinstance(value, _MultipleClassMarker):
value = value.attempt_get([], key)
return _GetColumns(value)
@@ -341,7 +397,21 @@ class _class_resolver:
"favor_tables",
)
- def __init__(self, cls, prop, fallback, arg, favor_tables=False):
+ cls: Type[Any]
+ prop: Relationship[Any]
+ fallback: Mapping[str, Any]
+ arg: str
+ favor_tables: bool
+ _resolvers: Tuple[Callable[[str], Any], ...]
+
+ def __init__(
+ self,
+ cls: Type[Any],
+ prop: Relationship[Any],
+ fallback: Mapping[str, Any],
+ arg: str,
+ favor_tables: bool = False,
+ ):
self.cls = cls
self.prop = prop
self.arg = arg
@@ -350,11 +420,12 @@ class _class_resolver:
self._resolvers = ()
self.favor_tables = favor_tables
- def _access_cls(self, key):
+ def _access_cls(self, key: str) -> Any:
cls = self.cls
manager = attributes.manager_of_class(cls)
decl_base = manager.registry
+ assert decl_base is not None
decl_class_registry = decl_base._class_registry
metadata = decl_base.metadata
@@ -362,7 +433,7 @@ class _class_resolver:
if key in metadata.tables:
return metadata.tables[key]
elif key in metadata._schemas:
- return _GetTable(key, cls.metadata)
+ return _GetTable(key, getattr(cls, "metadata", metadata))
if key in decl_class_registry:
return _determine_container(key, decl_class_registry[key])
@@ -371,13 +442,14 @@ class _class_resolver:
if key in metadata.tables:
return metadata.tables[key]
elif key in metadata._schemas:
- return _GetTable(key, cls.metadata)
+ return _GetTable(key, getattr(cls, "metadata", metadata))
- if (
- "_sa_module_registry" in decl_class_registry
- and key in decl_class_registry["_sa_module_registry"]
+ if "_sa_module_registry" in decl_class_registry and key in cast(
+ _ModuleMarker, decl_class_registry["_sa_module_registry"]
):
- registry = decl_class_registry["_sa_module_registry"]
+ registry = cast(
+ _ModuleMarker, decl_class_registry["_sa_module_registry"]
+ )
return registry.resolve_attr(key)
elif self._resolvers:
for resolv in self._resolvers:
@@ -387,7 +459,7 @@ class _class_resolver:
return self.fallback[key]
- def _raise_for_name(self, name, err):
+ def _raise_for_name(self, name: str, err: Exception) -> NoReturn:
generic_match = re.match(r"(.+)\[(.+)\]", name)
if generic_match:
@@ -409,7 +481,7 @@ class _class_resolver:
% (self.prop.parent, self.arg, name, self.cls)
) from err
- def _resolve_name(self):
+ def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]:
name = self.arg
d = self._dict
rval = None
@@ -427,9 +499,11 @@ class _class_resolver:
if isinstance(rval, _GetColumns):
return rval.cls
else:
+ if TYPE_CHECKING:
+ assert isinstance(rval, (type, Table, _ModNS))
return rval
- def __call__(self):
+ def __call__(self) -> Any:
try:
x = eval(self.arg, globals(), self._dict)
@@ -441,10 +515,15 @@ class _class_resolver:
self._raise_for_name(n.args[0], n)
-_fallback_dict = None
+_fallback_dict: Mapping[str, Any] = None # type: ignore
-def _resolver(cls, prop):
+def _resolver(
+ cls: Type[Any], prop: Relationship[Any]
+) -> Tuple[
+ Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
+ Callable[[str, bool], _class_resolver],
+]:
global _fallback_dict
@@ -456,12 +535,14 @@ def _resolver(cls, prop):
{"foreign": foreign, "remote": remote}
)
- def resolve_arg(arg, favor_tables=False):
+ def resolve_arg(arg: str, favor_tables: bool = False) -> _class_resolver:
return _class_resolver(
cls, prop, _fallback_dict, arg, favor_tables=favor_tables
)
- def resolve_name(arg):
+ def resolve_name(
+ arg: str,
+ ) -> Callable[[], Union[Type[Any], Table, _ModNS]]:
return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name
return resolve_name, resolve_arg