diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2022-05-16 02:32:44 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-05-16 02:32:44 +0000 |
commit | 5d080d17464712d33c0215d12513e529d848ee8c (patch) | |
tree | eec56f3138a48f55f2585a64f01b4fd9c14451b7 /lib/sqlalchemy/orm/identity.py | |
parent | c4dad3695f4ab9fef3a4cb05893492afbec811f7 (diff) | |
parent | 18a73fb1d1c267842ead5dacd05a49f4344d8b22 (diff) | |
download | sqlalchemy-5d080d17464712d33c0215d12513e529d848ee8c.tar.gz |
Merge "revenge of pep 484" into main
Diffstat (limited to 'lib/sqlalchemy/orm/identity.py')
-rw-r--r-- | lib/sqlalchemy/orm/identity.py | 40 |
1 files changed, 26 insertions, 14 deletions
diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index d13265c56..63b131a78 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -8,6 +8,7 @@ from __future__ import annotations from typing import Any +from typing import cast from typing import Dict from typing import Iterable from typing import Iterator @@ -15,6 +16,7 @@ from typing import List from typing import NoReturn from typing import Optional from typing import Set +from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar import weakref @@ -66,7 +68,7 @@ class IdentityMap: ) -> Optional[_O]: raise NotImplementedError() - def keys(self): + def keys(self) -> Iterable[_IdentityKeyType[Any]]: return self._dict.keys() def values(self) -> Iterable[object]: @@ -117,10 +119,10 @@ class IdentityMap: class WeakInstanceDict(IdentityMap): - _dict: Dict[Optional[_IdentityKeyType[Any]], InstanceState[Any]] + _dict: Dict[_IdentityKeyType[Any], InstanceState[Any]] def __getitem__(self, key: _IdentityKeyType[_O]) -> _O: - state = self._dict[key] + state = cast("InstanceState[_O]", self._dict[key]) o = state.obj() if o is None: raise KeyError(key) @@ -140,6 +142,8 @@ class WeakInstanceDict(IdentityMap): def contains_state(self, state: InstanceState[Any]) -> bool: if state.key in self._dict: + if TYPE_CHECKING: + assert state.key is not None try: return self._dict[state.key] is state except KeyError: @@ -150,15 +154,16 @@ class WeakInstanceDict(IdentityMap): def replace( self, state: InstanceState[Any] ) -> Optional[InstanceState[Any]]: + assert state.key is not None if state.key in self._dict: try: - existing = self._dict[state.key] + existing = existing_non_none = self._dict[state.key] except KeyError: # catch gc removed the key after we just checked for it existing = None else: - if existing is not state: - self._manage_removed_state(existing) + if existing_non_none is not state: + self._manage_removed_state(existing_non_none) else: return None else: @@ -170,6 +175,7 @@ class WeakInstanceDict(IdentityMap): def add(self, state: InstanceState[Any]) -> bool: key = state.key + assert key is not None # inline of self.__contains__ if key in self._dict: try: @@ -206,7 +212,7 @@ class WeakInstanceDict(IdentityMap): if key not in self._dict: return default try: - state = self._dict[key] + state = cast("InstanceState[_O]", self._dict[key]) except KeyError: # catch gc removed the key after we just checked for it return default @@ -216,13 +222,15 @@ class WeakInstanceDict(IdentityMap): return default return o - def items(self) -> List[InstanceState[Any]]: + def items(self) -> List[Tuple[_IdentityKeyType[Any], InstanceState[Any]]]: values = self.all_states() result = [] for state in values: value = state.obj() + key = state.key + assert key is not None if value is not None: - result.append((state.key, value)) + result.append((key, value)) return result def values(self) -> List[object]: @@ -244,28 +252,32 @@ class WeakInstanceDict(IdentityMap): def _fast_discard(self, state: InstanceState[Any]) -> None: # used by InstanceState for state being # GC'ed, inlines _managed_removed_state + key = state.key + assert key is not None try: - st = self._dict[state.key] + st = self._dict[key] except KeyError: # catch gc removed the key after we just checked for it pass else: if st is state: - self._dict.pop(state.key, None) + self._dict.pop(key, None) def discard(self, state: InstanceState[Any]) -> None: self.safe_discard(state) def safe_discard(self, state: InstanceState[Any]) -> None: - if state.key in self._dict: + key = state.key + if key in self._dict: + assert key is not None try: - st = self._dict[state.key] + st = self._dict[key] except KeyError: # catch gc removed the key after we just checked for it pass else: if st is state: - self._dict.pop(state.key, None) + self._dict.pop(key, None) self._manage_removed_state(state) |