summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/__init__.py4
-rw-r--r--numpy/array_api/_array_object.py16
-rw-r--r--numpy/array_api/_indexing_functions.py18
-rw-r--r--numpy/array_api/_manipulation_functions.py1
-rw-r--r--numpy/array_api/_typing.py31
-rw-r--r--numpy/array_api/tests/test_indexing_functions.py24
6 files changed, 64 insertions, 30 deletions
diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py
index 5e58ee0a8..e154b9952 100644
--- a/numpy/array_api/__init__.py
+++ b/numpy/array_api/__init__.py
@@ -333,6 +333,10 @@ __all__ += [
"trunc",
]
+from ._indexing_functions import take
+
+__all__ += ["take"]
+
# linalg is an extension in the array API spec, which is a sub-namespace. Only
# a subset of functions in it are imported into the top-level namespace.
from . import linalg
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index c4746fad9..a949b5977 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -56,7 +56,7 @@ class Array:
functions, such as asarray().
"""
- _array: np.ndarray
+ _array: np.ndarray[Any, Any]
# Use a custom constructor instead of __init__, as manually initializing
# this class is not supported API.
@@ -850,23 +850,13 @@ class Array:
"""
Performs the operation __imatmul__.
"""
- # Note: NumPy does not implement __imatmul__.
-
# matmul is not defined for scalars, but without this, we may get
# the wrong error message from asarray.
other = self._check_allowed_dtypes(other, "numeric", "__imatmul__")
if other is NotImplemented:
return other
-
- # __imatmul__ can only be allowed when it would not change the shape
- # of self.
- other_shape = other.shape
- if self.shape == () or other_shape == ():
- raise ValueError("@= requires at least one dimension")
- if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]:
- raise ValueError("@= cannot change the shape of the input array")
- self._array[:] = self._array.__matmul__(other._array)
- return self
+ res = self._array.__imatmul__(other._array)
+ return self.__class__._new(res)
def __rmatmul__(self: Array, other: Array, /) -> Array:
"""
diff --git a/numpy/array_api/_indexing_functions.py b/numpy/array_api/_indexing_functions.py
new file mode 100644
index 000000000..ba56bcd6f
--- /dev/null
+++ b/numpy/array_api/_indexing_functions.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+from ._array_object import Array
+from ._dtypes import _integer_dtypes
+
+import numpy as np
+
+def take(x: Array, indices: Array, /, *, axis: int) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.take <numpy.take>`.
+
+ See its docstring for more information.
+ """
+ if indices.dtype not in _integer_dtypes:
+ raise TypeError("Only integer dtypes are allowed in indexing")
+ if indices.ndim != 1:
+ raise ValueError("Only 1-dim indices array is supported")
+ return Array._new(np.take(x._array, indices._array, axis=axis))
diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py
index 4f2114ff5..7991f46a2 100644
--- a/numpy/array_api/_manipulation_functions.py
+++ b/numpy/array_api/_manipulation_functions.py
@@ -52,6 +52,7 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
return Array._new(np.transpose(x._array, axes))
+# Note: the optional argument is called 'shape', not 'newshape'
def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array:
"""
Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py
index dfa87b358..3f9b7186a 100644
--- a/numpy/array_api/_typing.py
+++ b/numpy/array_api/_typing.py
@@ -17,14 +17,12 @@ __all__ = [
"PyCapsule",
]
-import sys
from typing import (
Any,
Literal,
Sequence,
Type,
Union,
- TYPE_CHECKING,
TypeVar,
Protocol,
)
@@ -51,21 +49,20 @@ class NestedSequence(Protocol[_T_co]):
def __len__(self, /) -> int: ...
Device = Literal["cpu"]
-if TYPE_CHECKING or sys.version_info >= (3, 9):
- Dtype = dtype[Union[
- int8,
- int16,
- int32,
- int64,
- uint8,
- uint16,
- uint32,
- uint64,
- float32,
- float64,
- ]]
-else:
- Dtype = dtype
+
+Dtype = dtype[Union[
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+]]
+
SupportsBufferProtocol = Any
PyCapsule = Any
diff --git a/numpy/array_api/tests/test_indexing_functions.py b/numpy/array_api/tests/test_indexing_functions.py
new file mode 100644
index 000000000..9e05c6386
--- /dev/null
+++ b/numpy/array_api/tests/test_indexing_functions.py
@@ -0,0 +1,24 @@
+import pytest
+
+from numpy import array_api as xp
+
+
+@pytest.mark.parametrize(
+ "x, indices, axis, expected",
+ [
+ ([2, 3], [1, 1, 0], 0, [3, 3, 2]),
+ ([2, 3], [1, 1, 0], -1, [3, 3, 2]),
+ ([[2, 3]], [1], -1, [[3]]),
+ ([[2, 3]], [0, 0], 0, [[2, 3], [2, 3]]),
+ ],
+)
+def test_take_function(x, indices, axis, expected):
+ """
+ Indices respect relative order of a descending stable-sort
+
+ See https://github.com/numpy/numpy/issues/20778
+ """
+ x = xp.asarray(x)
+ indices = xp.asarray(indices)
+ out = xp.take(x, indices, axis=axis)
+ assert xp.all(out == xp.asarray(expected))