diff options
Diffstat (limited to 'numpy/_array_api/_searching_functions.py')
-rw-r--r-- | numpy/_array_api/_searching_functions.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/numpy/_array_api/_searching_functions.py b/numpy/_array_api/_searching_functions.py index 4eed66c48..3b37167af 100644 --- a/numpy/_array_api/_searching_functions.py +++ b/numpy/_array_api/_searching_functions.py @@ -1,6 +1,10 @@ +from __future__ import annotations + +from ._types import Tuple, array + import numpy as np -def argmax(x, /, *, axis=None, keepdims=False): +def argmax(x: array, /, *, axis: int = None, keepdims: bool = False) -> array: """ Array API compatible wrapper for :py:func:`np.argmax <numpy.argmax>`. @@ -8,7 +12,7 @@ def argmax(x, /, *, axis=None, keepdims=False): """ return np.argmax(x, axis=axis, keepdims=keepdims) -def argmin(x, /, *, axis=None, keepdims=False): +def argmin(x: array, /, *, axis: int = None, keepdims: bool = False) -> array: """ Array API compatible wrapper for :py:func:`np.argmin <numpy.argmin>`. @@ -16,7 +20,7 @@ def argmin(x, /, *, axis=None, keepdims=False): """ return np.argmin(x, axis=axis, keepdims=keepdims) -def nonzero(x, /): +def nonzero(x: array, /) -> Tuple[array, ...]: """ Array API compatible wrapper for :py:func:`np.nonzero <numpy.nonzero>`. @@ -24,7 +28,7 @@ def nonzero(x, /): """ return np.nonzero(x) -def where(condition, x1, x2, /): +def where(condition: array, x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.where <numpy.where>`. |