summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/identity.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/identity.py')
-rw-r--r--lib/sqlalchemy/orm/identity.py40
1 files changed, 26 insertions, 14 deletions
diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py
index d13265c56..63b131a78 100644
--- a/lib/sqlalchemy/orm/identity.py
+++ b/lib/sqlalchemy/orm/identity.py
@@ -8,6 +8,7 @@
from __future__ import annotations
from typing import Any
+from typing import cast
from typing import Dict
from typing import Iterable
from typing import Iterator
@@ -15,6 +16,7 @@ from typing import List
from typing import NoReturn
from typing import Optional
from typing import Set
+from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
import weakref
@@ -66,7 +68,7 @@ class IdentityMap:
) -> Optional[_O]:
raise NotImplementedError()
- def keys(self):
+ def keys(self) -> Iterable[_IdentityKeyType[Any]]:
return self._dict.keys()
def values(self) -> Iterable[object]:
@@ -117,10 +119,10 @@ class IdentityMap:
class WeakInstanceDict(IdentityMap):
- _dict: Dict[Optional[_IdentityKeyType[Any]], InstanceState[Any]]
+ _dict: Dict[_IdentityKeyType[Any], InstanceState[Any]]
def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
- state = self._dict[key]
+ state = cast("InstanceState[_O]", self._dict[key])
o = state.obj()
if o is None:
raise KeyError(key)
@@ -140,6 +142,8 @@ class WeakInstanceDict(IdentityMap):
def contains_state(self, state: InstanceState[Any]) -> bool:
if state.key in self._dict:
+ if TYPE_CHECKING:
+ assert state.key is not None
try:
return self._dict[state.key] is state
except KeyError:
@@ -150,15 +154,16 @@ class WeakInstanceDict(IdentityMap):
def replace(
self, state: InstanceState[Any]
) -> Optional[InstanceState[Any]]:
+ assert state.key is not None
if state.key in self._dict:
try:
- existing = self._dict[state.key]
+ existing = existing_non_none = self._dict[state.key]
except KeyError:
# catch gc removed the key after we just checked for it
existing = None
else:
- if existing is not state:
- self._manage_removed_state(existing)
+ if existing_non_none is not state:
+ self._manage_removed_state(existing_non_none)
else:
return None
else:
@@ -170,6 +175,7 @@ class WeakInstanceDict(IdentityMap):
def add(self, state: InstanceState[Any]) -> bool:
key = state.key
+ assert key is not None
# inline of self.__contains__
if key in self._dict:
try:
@@ -206,7 +212,7 @@ class WeakInstanceDict(IdentityMap):
if key not in self._dict:
return default
try:
- state = self._dict[key]
+ state = cast("InstanceState[_O]", self._dict[key])
except KeyError:
# catch gc removed the key after we just checked for it
return default
@@ -216,13 +222,15 @@ class WeakInstanceDict(IdentityMap):
return default
return o
- def items(self) -> List[InstanceState[Any]]:
+ def items(self) -> List[Tuple[_IdentityKeyType[Any], InstanceState[Any]]]:
values = self.all_states()
result = []
for state in values:
value = state.obj()
+ key = state.key
+ assert key is not None
if value is not None:
- result.append((state.key, value))
+ result.append((key, value))
return result
def values(self) -> List[object]:
@@ -244,28 +252,32 @@ class WeakInstanceDict(IdentityMap):
def _fast_discard(self, state: InstanceState[Any]) -> None:
# used by InstanceState for state being
# GC'ed, inlines _managed_removed_state
+ key = state.key
+ assert key is not None
try:
- st = self._dict[state.key]
+ st = self._dict[key]
except KeyError:
# catch gc removed the key after we just checked for it
pass
else:
if st is state:
- self._dict.pop(state.key, None)
+ self._dict.pop(key, None)
def discard(self, state: InstanceState[Any]) -> None:
self.safe_discard(state)
def safe_discard(self, state: InstanceState[Any]) -> None:
- if state.key in self._dict:
+ key = state.key
+ if key in self._dict:
+ assert key is not None
try:
- st = self._dict[state.key]
+ st = self._dict[key]
except KeyError:
# catch gc removed the key after we just checked for it
pass
else:
if st is state:
- self._dict.pop(state.key, None)
+ self._dict.pop(key, None)
self._manage_removed_state(state)