diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2022-01-24 12:31:20 +0100 |
---|---|---|
committer | Bas van Beek <b.f.van.beek@vu.nl> | 2022-01-24 14:00:32 +0100 |
commit | 11bc3d314038a92671c97ca33eee650f046a3029 (patch) | |
tree | 9241a63a04257f2536d00b650cda4c5bdee6fbc1 /numpy/testing/_private | |
parent | f40b105d627096ef04ca7a5ccbcd6ff53d810be6 (diff) | |
download | numpy-11bc3d314038a92671c97ca33eee650f046a3029.tar.gz |
ENH: Improve typing with the help of `ParamSpec`
Diffstat (limited to 'numpy/testing/_private')
-rw-r--r-- | numpy/testing/_private/utils.pyi | 32 |
1 files changed, 17 insertions, 15 deletions
diff --git a/numpy/testing/_private/utils.pyi b/numpy/testing/_private/utils.pyi index 8117f18ae..f4b22834d 100644 --- a/numpy/testing/_private/utils.pyi +++ b/numpy/testing/_private/utils.pyi @@ -20,6 +20,7 @@ from typing import ( Final, SupportsIndex, ) +from typing_extensions import ParamSpec from numpy import generic, dtype, number, object_, bool_, _FloatValue from numpy.typing import ( @@ -36,6 +37,7 @@ from unittest.case import ( SkipTest as SkipTest, ) +_P = ParamSpec("_P") _T = TypeVar("_T") _ET = TypeVar("_ET", bound=BaseException) _FT = TypeVar("_FT", bound=Callable[..., Any]) @@ -254,10 +256,10 @@ def raises(*args: type[BaseException]) -> Callable[[_FT], _FT]: ... @overload def assert_raises( # type: ignore expected_exception: type[BaseException] | tuple[type[BaseException], ...], - callable: Callable[..., Any], + callable: Callable[_P, Any], /, - *args: Any, - **kwargs: Any, + *args: _P.args, + **kwargs: _P.kwargs, ) -> None: ... @overload def assert_raises( @@ -270,10 +272,10 @@ def assert_raises( def assert_raises_regex( expected_exception: type[BaseException] | tuple[type[BaseException], ...], expected_regex: str | bytes | Pattern[Any], - callable: Callable[..., Any], + callable: Callable[_P, Any], /, - *args: Any, - **kwargs: Any, + *args: _P.args, + **kwargs: _P.kwargs, ) -> None: ... @overload def assert_raises_regex( @@ -336,20 +338,20 @@ def assert_warns( @overload def assert_warns( warning_class: type[Warning], - func: Callable[..., _T], + func: Callable[_P, _T], /, - *args: Any, - **kwargs: Any, + *args: _P.args, + **kwargs: _P.kwargs, ) -> _T: ... @overload def assert_no_warnings() -> contextlib._GeneratorContextManager[None]: ... @overload def assert_no_warnings( - func: Callable[..., _T], + func: Callable[_P, _T], /, - *args: Any, - **kwargs: Any, + *args: _P.args, + **kwargs: _P.kwargs, ) -> _T: ... @overload @@ -384,10 +386,10 @@ def temppath( def assert_no_gc_cycles() -> contextlib._GeneratorContextManager[None]: ... @overload def assert_no_gc_cycles( - func: Callable[..., Any], + func: Callable[_P, Any], /, - *args: Any, - **kwargs: Any, + *args: _P.args, + **kwargs: _P.kwargs, ) -> None: ... def break_cycles() -> None: ... |