diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-03-09 15:47:07 -0700 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-03-09 15:49:01 -0700 |
commit | 1ccbe680e24f2d2254ee4cd5053bb10859b22d1d (patch) | |
tree | 8ecb85d1cc8e57031fcc86d28a0282455ebbb3ec /numpy/_array_api/_array_object.py | |
parent | cdd6bbcdf260a4d6947901604dc8dd64c864c8d4 (diff) | |
download | numpy-1ccbe680e24f2d2254ee4cd5053bb10859b22d1d.tar.gz |
Only allow indices that are required by the spec in the array API namespace
The private function _validate_indices describes the cases that are
disallowed. This functionality should be tested (it isn't yet), as the array
API test suite will only test the cases that are allowed, not that
non-required cases are rejected.
Diffstat (limited to 'numpy/_array_api/_array_object.py')
-rw-r--r-- | numpy/_array_api/_array_object.py | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index b78405860..64ce740f0 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -15,9 +15,11 @@ of ndarray. from __future__ import annotations +import operator from enum import IntEnum from ._types import Optional, PyCapsule, Tuple, Union, array from ._creation_functions import asarray +from ._dtypes import _boolean_dtypes, _integer_dtypes import numpy as np @@ -140,10 +142,120 @@ class ndarray: res = x1._array.__ge__(asarray(x2)._array) return x1.__class__._new(res) + # Note: A large fraction of allowed indices are disallowed here (see the + # docstring below) + @staticmethod + def _validate_index(key, shape): + """ + Validate an index according to the array API. + + The array API specification only requires a subset of indices that are + supported by NumPy. This function will reject any index that is + allowed by NumPy but not required by the array API specification. We + always raise ``IndexError`` on such indices (the spec does not require + any specific behavior on them, but this makes the NumPy array API + namespace a minimal implementation of the spec). + + This function either raises IndexError if the index ``key`` is + invalid, or a new key to be used in place of ``key`` in indexing. It + only raises ``IndexError`` on indices that are not already rejected by + NumPy, as NumPy will already raise the appropriate error on such + indices. ``shape`` may be None, in which case, only cases that are + independent of the array shape are checked. + + The following cases are allowed by NumPy, but not specified by the array + API specification: + + - The start and stop of a slice may not be out of bounds. In + particular, for a slice ``i:j:k`` on an axis of size ``n``, only the + following are allowed: + + - ``i`` or ``j`` omitted (``None``). + - ``-n <= i <= max(0, n - 1)``. + - For ``k > 0`` or ``k`` omitted (``None``), ``-n <= j <= n``. + - For ``k < 0``, ``-n - 1 <= j <= max(0, n - 1)``. + + - Boolean array indices are not allowed as part of a larger tuple + index. + + - Integer array indices are not allowed (with the exception of shape + () arrays, which are treated the same as scalars). + + Additionally, it should be noted that indices that would return a + scalar in NumPy will return a shape () array. Array scalars are not allowed + in the specification, only shape () arrays. This is done in the + ``ndarray._new`` constructor, not this function. + + """ + if isinstance(key, slice): + if shape is None: + return key + if shape == (): + return key + size = shape[0] + # Ensure invalid slice entries are passed through. + if key.start is not None: + try: + operator.index(key.start) + except TypeError: + return key + if not (-size <= key.start <= max(0, size - 1)): + raise IndexError("Slices with out-of-bounds start are not allowed in the array API namespace") + if key.stop is not None: + try: + operator.index(key.stop) + except TypeError: + return key + step = 1 if key.step is None else key.step + if (step > 0 and not (-size <= key.stop <= size) + or step < 0 and not (-size - 1 <= key.stop <= max(0, size - 1))): + raise IndexError("Slices with out-of-bounds stop are not allowed in the array API namespace") + return key + + elif isinstance(key, tuple): + key = tuple(ndarray._validate_index(idx, None) for idx in key) + + for idx in key: + if isinstance(idx, np.ndarray) and idx.dtype in _boolean_dtypes or isinstance(idx, (bool, np.bool_)): + if len(key) == 1: + return key + raise IndexError("Boolean array indices combined with other indices are not allowed in the array API namespace") + + if shape is None: + return key + n_ellipsis = key.count(...) + if n_ellipsis > 1: + return key + ellipsis_i = key.index(...) if n_ellipsis else len(key) + + for idx, size in list(zip(key[:ellipsis_i], shape)) + list(zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])): + ndarray._validate_index(idx, (size,)) + return key + elif isinstance(key, bool): + return key + elif isinstance(key, ndarray): + if key.dtype in _integer_dtypes: + if key.shape != (): + raise IndexError("Integer array indices with shape != () are not allowed in the array API namespace") + return key._array + elif key is Ellipsis: + return key + elif key is None: + raise IndexError("newaxis indices are not allowed in the array API namespace") + try: + return operator.index(key) + except TypeError: + # Note: This also omits boolean arrays that are not already in + # ndarray() form, like a list of booleans. + raise IndexError("Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace") + def __getitem__(x: array, key: Union[int, slice, Tuple[Union[int, slice], ...], array], /) -> array: """ Performs the operation __getitem__. """ + # Note: Only indices required by the spec are allowed. See the + # docstring of _validate_index + key = x._validate_index(key, x.shape) res = x._array.__getitem__(key) return x.__class__._new(res) @@ -266,6 +378,9 @@ class ndarray: """ Performs the operation __setitem__. """ + # Note: Only indices required by the spec are allowed. See the + # docstring of _validate_index + key = x._validate_index(key, x.shape) res = x._array.__setitem__(key, asarray(value)._array) return x.__class__._new(res) |