diff options
Diffstat (limited to 'lib/sqlalchemy/orm/state_changes.py')
-rw-r--r-- | lib/sqlalchemy/orm/state_changes.py | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/lib/sqlalchemy/orm/state_changes.py b/lib/sqlalchemy/orm/state_changes.py index b7bf96558..764b5dfa6 100644 --- a/lib/sqlalchemy/orm/state_changes.py +++ b/lib/sqlalchemy/orm/state_changes.py @@ -4,6 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php + """State tracking utilities used by :class:`_orm.Session`. """ @@ -14,6 +15,9 @@ import contextlib from enum import Enum from typing import Any from typing import Callable +from typing import cast +from typing import Iterator +from typing import NoReturn from typing import Optional from typing import Tuple from typing import TypeVar @@ -48,9 +52,11 @@ class _StateChange: _next_state: _StateChangeState = _StateChangeStates.ANY _state: _StateChangeState = _StateChangeStates.NO_CHANGE - _current_fn: Optional[Callable] = None + _current_fn: Optional[Callable[..., Any]] = None - def _raise_for_prerequisite_state(self, operation_name, state): + def _raise_for_prerequisite_state( + self, operation_name: str, state: _StateChangeState + ) -> NoReturn: raise sa_exc.IllegalStateChangeError( f"Can't run operation '{operation_name}()' when Session " f"is in state {state!r}" @@ -80,16 +86,19 @@ class _StateChange: prerequisite_states is not _StateChangeStates.ANY ) + prerequisite_state_collection = cast( + "Tuple[_StateChangeState, ...]", prerequisite_states + ) expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE @util.decorator - def _go(fn, self, *arg, **kw): + def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any: current_state = self._state if ( has_prerequisite_states - and current_state not in prerequisite_states + and current_state not in prerequisite_state_collection ): self._raise_for_prerequisite_state(fn.__name__, current_state) @@ -159,7 +168,7 @@ class _StateChange: return _go @contextlib.contextmanager - def _expect_state(self, expected: _StateChangeState): + def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]: """called within a method that changes states. method must also use the ``@declare_states()`` decorator. |