from __future__ import annotations from ._types import Tuple, array import numpy as np def argmax(x: array, /, *, axis: int = None, keepdims: bool = False) -> array: """ Array API compatible wrapper for :py:func:`np.argmax `. See its docstring for more information. """ # Note: this currently fails as np.argmax does not implement keepdims return np.asarray(np.argmax._implementation(x, axis=axis, keepdims=keepdims)) def argmin(x: array, /, *, axis: int = None, keepdims: bool = False) -> array: """ Array API compatible wrapper for :py:func:`np.argmin `. See its docstring for more information. """ # Note: this currently fails as np.argmin does not implement keepdims return np.asarray(np.argmin._implementation(x, axis=axis, keepdims=keepdims)) def nonzero(x: array, /) -> Tuple[array, ...]: """ Array API compatible wrapper for :py:func:`np.nonzero `. See its docstring for more information. """ return np.nonzero._implementation(x) def where(condition: array, x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.where `. See its docstring for more information. """ return np.where._implementation(condition, x1, x2)