summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/state_changes.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/state_changes.py')
-rw-r--r--lib/sqlalchemy/orm/state_changes.py19
1 files changed, 14 insertions, 5 deletions
diff --git a/lib/sqlalchemy/orm/state_changes.py b/lib/sqlalchemy/orm/state_changes.py
index b7bf96558..764b5dfa6 100644
--- a/lib/sqlalchemy/orm/state_changes.py
+++ b/lib/sqlalchemy/orm/state_changes.py
@@ -4,6 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
"""State tracking utilities used by :class:`_orm.Session`.
"""
@@ -14,6 +15,9 @@ import contextlib
from enum import Enum
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Iterator
+from typing import NoReturn
from typing import Optional
from typing import Tuple
from typing import TypeVar
@@ -48,9 +52,11 @@ class _StateChange:
_next_state: _StateChangeState = _StateChangeStates.ANY
_state: _StateChangeState = _StateChangeStates.NO_CHANGE
- _current_fn: Optional[Callable] = None
+ _current_fn: Optional[Callable[..., Any]] = None
- def _raise_for_prerequisite_state(self, operation_name, state):
+ def _raise_for_prerequisite_state(
+ self, operation_name: str, state: _StateChangeState
+ ) -> NoReturn:
raise sa_exc.IllegalStateChangeError(
f"Can't run operation '{operation_name}()' when Session "
f"is in state {state!r}"
@@ -80,16 +86,19 @@ class _StateChange:
prerequisite_states is not _StateChangeStates.ANY
)
+ prerequisite_state_collection = cast(
+ "Tuple[_StateChangeState, ...]", prerequisite_states
+ )
expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE
@util.decorator
- def _go(fn, self, *arg, **kw):
+ def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any:
current_state = self._state
if (
has_prerequisite_states
- and current_state not in prerequisite_states
+ and current_state not in prerequisite_state_collection
):
self._raise_for_prerequisite_state(fn.__name__, current_state)
@@ -159,7 +168,7 @@ class _StateChange:
return _go
@contextlib.contextmanager
- def _expect_state(self, expected: _StateChangeState):
+ def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]:
"""called within a method that changes states.
method must also use the ``@declare_states()`` decorator.