summaryrefslogtreecommitdiff
path: root/numpy/testing/_private
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2022-01-24 12:31:20 +0100
committerBas van Beek <b.f.van.beek@vu.nl>2022-01-24 14:00:32 +0100
commit11bc3d314038a92671c97ca33eee650f046a3029 (patch)
tree9241a63a04257f2536d00b650cda4c5bdee6fbc1 /numpy/testing/_private
parentf40b105d627096ef04ca7a5ccbcd6ff53d810be6 (diff)
downloadnumpy-11bc3d314038a92671c97ca33eee650f046a3029.tar.gz
ENH: Improve typing with the help of `ParamSpec`
Diffstat (limited to 'numpy/testing/_private')
-rw-r--r--numpy/testing/_private/utils.pyi32
1 files changed, 17 insertions, 15 deletions
diff --git a/numpy/testing/_private/utils.pyi b/numpy/testing/_private/utils.pyi
index 8117f18ae..f4b22834d 100644
--- a/numpy/testing/_private/utils.pyi
+++ b/numpy/testing/_private/utils.pyi
@@ -20,6 +20,7 @@ from typing import (
Final,
SupportsIndex,
)
+from typing_extensions import ParamSpec
from numpy import generic, dtype, number, object_, bool_, _FloatValue
from numpy.typing import (
@@ -36,6 +37,7 @@ from unittest.case import (
SkipTest as SkipTest,
)
+_P = ParamSpec("_P")
_T = TypeVar("_T")
_ET = TypeVar("_ET", bound=BaseException)
_FT = TypeVar("_FT", bound=Callable[..., Any])
@@ -254,10 +256,10 @@ def raises(*args: type[BaseException]) -> Callable[[_FT], _FT]: ...
@overload
def assert_raises( # type: ignore
expected_exception: type[BaseException] | tuple[type[BaseException], ...],
- callable: Callable[..., Any],
+ callable: Callable[_P, Any],
/,
- *args: Any,
- **kwargs: Any,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
) -> None: ...
@overload
def assert_raises(
@@ -270,10 +272,10 @@ def assert_raises(
def assert_raises_regex(
expected_exception: type[BaseException] | tuple[type[BaseException], ...],
expected_regex: str | bytes | Pattern[Any],
- callable: Callable[..., Any],
+ callable: Callable[_P, Any],
/,
- *args: Any,
- **kwargs: Any,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
) -> None: ...
@overload
def assert_raises_regex(
@@ -336,20 +338,20 @@ def assert_warns(
@overload
def assert_warns(
warning_class: type[Warning],
- func: Callable[..., _T],
+ func: Callable[_P, _T],
/,
- *args: Any,
- **kwargs: Any,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
) -> _T: ...
@overload
def assert_no_warnings() -> contextlib._GeneratorContextManager[None]: ...
@overload
def assert_no_warnings(
- func: Callable[..., _T],
+ func: Callable[_P, _T],
/,
- *args: Any,
- **kwargs: Any,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
) -> _T: ...
@overload
@@ -384,10 +386,10 @@ def temppath(
def assert_no_gc_cycles() -> contextlib._GeneratorContextManager[None]: ...
@overload
def assert_no_gc_cycles(
- func: Callable[..., Any],
+ func: Callable[_P, Any],
/,
- *args: Any,
- **kwargs: Any,
+ *args: _P.args,
+ **kwargs: _P.kwargs,
) -> None: ...
def break_cycles() -> None: ...