summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/_array_api/_array_object.py115
1 files changed, 115 insertions, 0 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py
index b78405860..64ce740f0 100644
--- a/numpy/_array_api/_array_object.py
+++ b/numpy/_array_api/_array_object.py
@@ -15,9 +15,11 @@ of ndarray.
from __future__ import annotations
+import operator
from enum import IntEnum
from ._types import Optional, PyCapsule, Tuple, Union, array
from ._creation_functions import asarray
+from ._dtypes import _boolean_dtypes, _integer_dtypes
import numpy as np
@@ -140,10 +142,120 @@ class ndarray:
res = x1._array.__ge__(asarray(x2)._array)
return x1.__class__._new(res)
+ # Note: A large fraction of allowed indices are disallowed here (see the
+ # docstring below)
+ @staticmethod
+ def _validate_index(key, shape):
+ """
+ Validate an index according to the array API.
+
+ The array API specification only requires a subset of indices that are
+ supported by NumPy. This function will reject any index that is
+ allowed by NumPy but not required by the array API specification. We
+ always raise ``IndexError`` on such indices (the spec does not require
+ any specific behavior on them, but this makes the NumPy array API
+ namespace a minimal implementation of the spec).
+
+ This function either raises IndexError if the index ``key`` is
+ invalid, or a new key to be used in place of ``key`` in indexing. It
+ only raises ``IndexError`` on indices that are not already rejected by
+ NumPy, as NumPy will already raise the appropriate error on such
+ indices. ``shape`` may be None, in which case, only cases that are
+ independent of the array shape are checked.
+
+ The following cases are allowed by NumPy, but not specified by the array
+ API specification:
+
+ - The start and stop of a slice may not be out of bounds. In
+ particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
+ following are allowed:
+
+ - ``i`` or ``j`` omitted (``None``).
+ - ``-n <= i <= max(0, n - 1)``.
+ - For ``k > 0`` or ``k`` omitted (``None``), ``-n <= j <= n``.
+ - For ``k < 0``, ``-n - 1 <= j <= max(0, n - 1)``.
+
+ - Boolean array indices are not allowed as part of a larger tuple
+ index.
+
+ - Integer array indices are not allowed (with the exception of shape
+ () arrays, which are treated the same as scalars).
+
+ Additionally, it should be noted that indices that would return a
+ scalar in NumPy will return a shape () array. Array scalars are not allowed
+ in the specification, only shape () arrays. This is done in the
+ ``ndarray._new`` constructor, not this function.
+
+ """
+ if isinstance(key, slice):
+ if shape is None:
+ return key
+ if shape == ():
+ return key
+ size = shape[0]
+ # Ensure invalid slice entries are passed through.
+ if key.start is not None:
+ try:
+ operator.index(key.start)
+ except TypeError:
+ return key
+ if not (-size <= key.start <= max(0, size - 1)):
+ raise IndexError("Slices with out-of-bounds start are not allowed in the array API namespace")
+ if key.stop is not None:
+ try:
+ operator.index(key.stop)
+ except TypeError:
+ return key
+ step = 1 if key.step is None else key.step
+ if (step > 0 and not (-size <= key.stop <= size)
+ or step < 0 and not (-size - 1 <= key.stop <= max(0, size - 1))):
+ raise IndexError("Slices with out-of-bounds stop are not allowed in the array API namespace")
+ return key
+
+ elif isinstance(key, tuple):
+ key = tuple(ndarray._validate_index(idx, None) for idx in key)
+
+ for idx in key:
+ if isinstance(idx, np.ndarray) and idx.dtype in _boolean_dtypes or isinstance(idx, (bool, np.bool_)):
+ if len(key) == 1:
+ return key
+ raise IndexError("Boolean array indices combined with other indices are not allowed in the array API namespace")
+
+ if shape is None:
+ return key
+ n_ellipsis = key.count(...)
+ if n_ellipsis > 1:
+ return key
+ ellipsis_i = key.index(...) if n_ellipsis else len(key)
+
+ for idx, size in list(zip(key[:ellipsis_i], shape)) + list(zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])):
+ ndarray._validate_index(idx, (size,))
+ return key
+ elif isinstance(key, bool):
+ return key
+ elif isinstance(key, ndarray):
+ if key.dtype in _integer_dtypes:
+ if key.shape != ():
+ raise IndexError("Integer array indices with shape != () are not allowed in the array API namespace")
+ return key._array
+ elif key is Ellipsis:
+ return key
+ elif key is None:
+ raise IndexError("newaxis indices are not allowed in the array API namespace")
+ try:
+ return operator.index(key)
+ except TypeError:
+ # Note: This also omits boolean arrays that are not already in
+ # ndarray() form, like a list of booleans.
+ raise IndexError("Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace")
+
def __getitem__(x: array, key: Union[int, slice, Tuple[Union[int, slice], ...], array], /) -> array:
"""
Performs the operation __getitem__.
"""
+ # Note: Only indices required by the spec are allowed. See the
+ # docstring of _validate_index
+ key = x._validate_index(key, x.shape)
res = x._array.__getitem__(key)
return x.__class__._new(res)
@@ -266,6 +378,9 @@ class ndarray:
"""
Performs the operation __setitem__.
"""
+ # Note: Only indices required by the spec are allowed. See the
+ # docstring of _validate_index
+ key = x._validate_index(key, x.shape)
res = x._array.__setitem__(key, asarray(value)._array)
return x.__class__._new(res)