diff options
Diffstat (limited to 'lib/sqlalchemy/orm/base.py')
-rw-r--r-- | lib/sqlalchemy/orm/base.py | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index c79592625..31dceb065 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -10,11 +10,13 @@ """ import operator +import typing from . import exc from .. import exc as sa_exc from .. import inspection from .. import util +from ..util import typing as compat_typing PASSIVE_NO_RESULT = util.symbol( @@ -221,13 +223,25 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE") -def _assertions(*assertions): +_Fn = typing.TypeVar("_Fn", bound=typing.Callable) +_Args = compat_typing.ParamSpec("_Args") +_Self = typing.TypeVar("_Self") + + +def _assertions( + *assertions, +) -> typing.Callable[ + [typing.Callable[compat_typing.Concatenate[_Fn, _Args], _Self]], + typing.Callable[compat_typing.Concatenate[_Fn, _Args], _Self], +]: @util.decorator - def generate(fn, *args, **kw): - self = args[0] + def generate( + fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs + ) -> _Self: for assertion in assertions: assertion(self, fn.__name__) - fn(self, *args[1:], **kw) + fn(self, *args, **kw) + return self return generate |