diff options
-rw-r--r-- | numpy/_array_api/_array_object.py | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index b78405860..64ce740f0 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -15,9 +15,11 @@ of ndarray. from __future__ import annotations +import operator from enum import IntEnum from ._types import Optional, PyCapsule, Tuple, Union, array from ._creation_functions import asarray +from ._dtypes import _boolean_dtypes, _integer_dtypes import numpy as np @@ -140,10 +142,120 @@ class ndarray: res = x1._array.__ge__(asarray(x2)._array) return x1.__class__._new(res) + # Note: A large fraction of allowed indices are disallowed here (see the + # docstring below) + @staticmethod + def _validate_index(key, shape): + """ + Validate an index according to the array API. + + The array API specification only requires a subset of indices that are + supported by NumPy. This function will reject any index that is + allowed by NumPy but not required by the array API specification. We + always raise ``IndexError`` on such indices (the spec does not require + any specific behavior on them, but this makes the NumPy array API + namespace a minimal implementation of the spec). + + This function either raises IndexError if the index ``key`` is + invalid, or a new key to be used in place of ``key`` in indexing. It + only raises ``IndexError`` on indices that are not already rejected by + NumPy, as NumPy will already raise the appropriate error on such + indices. ``shape`` may be None, in which case, only cases that are + independent of the array shape are checked. + + The following cases are allowed by NumPy, but not specified by the array + API specification: + + - The start and stop of a slice may not be out of bounds. In + particular, for a slice ``i:j:k`` on an axis of size ``n``, only the + following are allowed: + + - ``i`` or ``j`` omitted (``None``). + - ``-n <= i <= max(0, n - 1)``. + - For ``k > 0`` or ``k`` omitted (``None``), ``-n <= j <= n``. + - For ``k < 0``, ``-n - 1 <= j <= max(0, n - 1)``. + + - Boolean array indices are not allowed as part of a larger tuple + index. + + - Integer array indices are not allowed (with the exception of shape + () arrays, which are treated the same as scalars). + + Additionally, it should be noted that indices that would return a + scalar in NumPy will return a shape () array. Array scalars are not allowed + in the specification, only shape () arrays. This is done in the + ``ndarray._new`` constructor, not this function. + + """ + if isinstance(key, slice): + if shape is None: + return key + if shape == (): + return key + size = shape[0] + # Ensure invalid slice entries are passed through. + if key.start is not None: + try: + operator.index(key.start) + except TypeError: + return key + if not (-size <= key.start <= max(0, size - 1)): + raise IndexError("Slices with out-of-bounds start are not allowed in the array API namespace") + if key.stop is not None: + try: + operator.index(key.stop) + except TypeError: + return key + step = 1 if key.step is None else key.step + if (step > 0 and not (-size <= key.stop <= size) + or step < 0 and not (-size - 1 <= key.stop <= max(0, size - 1))): + raise IndexError("Slices with out-of-bounds stop are not allowed in the array API namespace") + return key + + elif isinstance(key, tuple): + key = tuple(ndarray._validate_index(idx, None) for idx in key) + + for idx in key: + if isinstance(idx, np.ndarray) and idx.dtype in _boolean_dtypes or isinstance(idx, (bool, np.bool_)): + if len(key) == 1: + return key + raise IndexError("Boolean array indices combined with other indices are not allowed in the array API namespace") + + if shape is None: + return key + n_ellipsis = key.count(...) + if n_ellipsis > 1: + return key + ellipsis_i = key.index(...) if n_ellipsis else len(key) + + for idx, size in list(zip(key[:ellipsis_i], shape)) + list(zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])): + ndarray._validate_index(idx, (size,)) + return key + elif isinstance(key, bool): + return key + elif isinstance(key, ndarray): + if key.dtype in _integer_dtypes: + if key.shape != (): + raise IndexError("Integer array indices with shape != () are not allowed in the array API namespace") + return key._array + elif key is Ellipsis: + return key + elif key is None: + raise IndexError("newaxis indices are not allowed in the array API namespace") + try: + return operator.index(key) + except TypeError: + # Note: This also omits boolean arrays that are not already in + # ndarray() form, like a list of booleans. + raise IndexError("Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace") + def __getitem__(x: array, key: Union[int, slice, Tuple[Union[int, slice], ...], array], /) -> array: """ Performs the operation __getitem__. """ + # Note: Only indices required by the spec are allowed. See the + # docstring of _validate_index + key = x._validate_index(key, x.shape) res = x._array.__getitem__(key) return x.__class__._new(res) @@ -266,6 +378,9 @@ class ndarray: """ Performs the operation __setitem__. """ + # Note: Only indices required by the spec are allowed. See the + # docstring of _validate_index + key = x._validate_index(key, x.shape) res = x._array.__setitem__(key, asarray(value)._array) return x.__class__._new(res) |