diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-07-21 15:48:06 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-07-21 15:48:06 -0600 |
commit | 63a9a87360ef492c46c37416b8270563e73a6349 (patch) | |
tree | 7d04d62253f2402d589803539d98875c60222fbc | |
parent | 9d5d0ec2264c86a19714cf185a5a183df14cbb94 (diff) | |
download | numpy-63a9a87360ef492c46c37416b8270563e73a6349.tar.gz |
Restrict the array API namespace array operator type promotions
Only those type promotions that are required by the spec are allowed. In
particular, promotions across kinds, like integer + floating-point, are not
allowed, except for the case of Python scalars.
Tests are added for this.
This commit additionally makes the operators return NotImplemented on
unexpected input types rather than directly giving a TypeError. This is not
strictly required by the array API spec, but it is generally considered a best
practice for operator methods in Python.
This same thing will be implemented for the various functions in the array API
namespace in a later commit.
-rw-r--r-- | numpy/_array_api/_array_object.py | 305 | ||||
-rw-r--r-- | numpy/_array_api/tests/test_array_object.py | 178 |
2 files changed, 376 insertions, 107 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index f8bad0b59..f6371fbf4 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -18,7 +18,8 @@ from __future__ import annotations import operator from enum import IntEnum from ._creation_functions import asarray -from ._dtypes import _all_dtypes, _boolean_dtypes, _integer_dtypes, _floating_dtypes +from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes, + _integer_or_boolean_dtypes, _floating_dtypes, _numeric_dtypes) from typing import TYPE_CHECKING, Any, Optional, Tuple, Union if TYPE_CHECKING: @@ -83,6 +84,52 @@ class Array: """ return f"Array({np.array2string(self._array, separator=', ')}, dtype={self.dtype.name})" + # These are various helper functions to make the array behavior match the + # spec in places where it either deviates from or is more strict than + # NumPy behavior + + def _check_allowed_dtypes(self, other, dtype_category, op): + """ + Helper function for operators to only allow specific input dtypes + + Use like + + other = self._check_allowed_dtypes(other, 'numeric', '__add__') + if other is NotImplemented: + return other + """ + from ._dtypes import _result_type + + _dtypes = { + 'all': _all_dtypes, + 'numeric': _numeric_dtypes, + 'integer': _integer_dtypes, + 'integer or boolean': _integer_or_boolean_dtypes, + 'boolean': _boolean_dtypes, + 'floating-point': _floating_dtypes, + } + + if self.dtype not in _dtypes[dtype_category]: + raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) + elif isinstance(other, Array): + if other.dtype not in _dtypes[dtype_category]: + raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + else: + return NotImplemented + + res_dtype = _result_type(self.dtype, other.dtype) + if op.startswith('__i'): + # Note: NumPy will allow in-place operators in some cases where the type promoted operator does not match the left-hand side operand. For example, + + # >>> a = np.array(1, dtype=np.int8) + # >>> a += np.array(1, dtype=np.int16) + if res_dtype != self.dtype: + raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}") + + return other + # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): """ @@ -270,8 +317,9 @@ class Array: """ Performs the operation __add__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__add__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__add__(other._array) return self.__class__._new(res) @@ -280,8 +328,9 @@ class Array: """ Performs the operation __and__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__and__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__and__(other._array) return self.__class__._new(res) @@ -321,6 +370,11 @@ class Array: """ Performs the operation __eq__. """ + # Even though "all" dtypes are allowed, we still require them to be + # promotable with each other. + other = self._check_allowed_dtypes(other, 'all', '__eq__') + if other is NotImplemented: + return other if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) self, other = self._normalize_two_args(self, other) @@ -341,8 +395,9 @@ class Array: """ Performs the operation __floordiv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__floordiv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__floordiv__(other._array) return self.__class__._new(res) @@ -351,8 +406,9 @@ class Array: """ Performs the operation __ge__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__ge__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__ge__(other._array) return self.__class__._new(res) @@ -371,8 +427,9 @@ class Array: """ Performs the operation __gt__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__gt__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__gt__(other._array) return self.__class__._new(res) @@ -391,6 +448,8 @@ class Array: """ Performs the operation __invert__. """ + if self.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in __invert__') res = self._array.__invert__() return self.__class__._new(res) @@ -398,8 +457,9 @@ class Array: """ Performs the operation __le__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__le__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__le__(other._array) return self.__class__._new(res) @@ -416,8 +476,9 @@ class Array: """ Performs the operation __lshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__lshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__lshift__(other._array) return self.__class__._new(res) @@ -426,8 +487,9 @@ class Array: """ Performs the operation __lt__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__lt__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__lt__(other._array) return self.__class__._new(res) @@ -436,10 +498,11 @@ class Array: """ Performs the operation __matmul__. """ - if isinstance(other, (int, float, bool)): - # matmul is not defined for scalars, but without this, we may get - # the wrong error message from asarray. - other = self._promote_scalar(other) + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__matmul__') + if other is NotImplemented: + return other res = self._array.__matmul__(other._array) return self.__class__._new(res) @@ -447,8 +510,9 @@ class Array: """ Performs the operation __mod__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__mod__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__mod__(other._array) return self.__class__._new(res) @@ -457,8 +521,9 @@ class Array: """ Performs the operation __mul__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__mul__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__mul__(other._array) return self.__class__._new(res) @@ -467,6 +532,9 @@ class Array: """ Performs the operation __ne__. """ + other = self._check_allowed_dtypes(other, 'all', '__ne__') + if other is NotImplemented: + return other if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) self, other = self._normalize_two_args(self, other) @@ -477,6 +545,8 @@ class Array: """ Performs the operation __neg__. """ + if self.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in __neg__') res = self._array.__neg__() return self.__class__._new(res) @@ -484,8 +554,9 @@ class Array: """ Performs the operation __or__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__or__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__or__(other._array) return self.__class__._new(res) @@ -494,6 +565,8 @@ class Array: """ Performs the operation __pos__. """ + if self.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in __pos__') res = self._array.__pos__() return self.__class__._new(res) @@ -505,10 +578,9 @@ class Array: """ from ._elementwise_functions import pow - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __pow__') + other = self._check_allowed_dtypes(other, 'floating-point', '__pow__') + if other is NotImplemented: + return other # Note: NumPy's __pow__ does not follow type promotion rules for 0-d # arrays, so we use pow() here instead. return pow(self, other) @@ -517,8 +589,9 @@ class Array: """ Performs the operation __rshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__rshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rshift__(other._array) return self.__class__._new(res) @@ -537,8 +610,9 @@ class Array: """ Performs the operation __sub__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__sub__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__sub__(other._array) return self.__class__._new(res) @@ -549,10 +623,9 @@ class Array: """ Performs the operation __truediv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __truediv__') + other = self._check_allowed_dtypes(other, 'floating-point', '__truediv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__truediv__(other._array) return self.__class__._new(res) @@ -561,8 +634,9 @@ class Array: """ Performs the operation __xor__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__xor__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__xor__(other._array) return self.__class__._new(res) @@ -571,8 +645,9 @@ class Array: """ Performs the operation __iadd__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__iadd__') + if other is NotImplemented: + return other self._array.__iadd__(other._array) return self @@ -580,8 +655,9 @@ class Array: """ Performs the operation __radd__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__radd__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__radd__(other._array) return self.__class__._new(res) @@ -590,8 +666,9 @@ class Array: """ Performs the operation __iand__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__iand__') + if other is NotImplemented: + return other self._array.__iand__(other._array) return self @@ -599,8 +676,9 @@ class Array: """ Performs the operation __rand__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__rand__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rand__(other._array) return self.__class__._new(res) @@ -609,8 +687,9 @@ class Array: """ Performs the operation __ifloordiv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__ifloordiv__') + if other is NotImplemented: + return other self._array.__ifloordiv__(other._array) return self @@ -618,8 +697,9 @@ class Array: """ Performs the operation __rfloordiv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rfloordiv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rfloordiv__(other._array) return self.__class__._new(res) @@ -628,8 +708,9 @@ class Array: """ Performs the operation __ilshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__ilshift__') + if other is NotImplemented: + return other self._array.__ilshift__(other._array) return self @@ -637,8 +718,9 @@ class Array: """ Performs the operation __rlshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__rlshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rlshift__(other._array) return self.__class__._new(res) @@ -649,15 +731,17 @@ class Array: """ # Note: NumPy does not implement __imatmul__. - if isinstance(other, (int, float, bool)): - # matmul is not defined for scalars, but without this, we may get - # the wrong error message from asarray. - other = self._promote_scalar(other) + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__imatmul__') + if other is NotImplemented: + return other + # __imatmul__ can only be allowed when it would not change the shape # of self. other_shape = other.shape if self.shape == () or other_shape == (): - raise ValueError("@= requires at least one dimension") + raise TypeError("@= requires at least one dimension") if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]: raise ValueError("@= cannot change the shape of the input array") self._array[:] = self._array.__matmul__(other._array) @@ -667,10 +751,11 @@ class Array: """ Performs the operation __rmatmul__. """ - if isinstance(other, (int, float, bool)): - # matmul is not defined for scalars, but without this, we may get - # the wrong error message from asarray. - other = self._promote_scalar(other) + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__rmatmul__') + if other is NotImplemented: + return other res = self._array.__rmatmul__(other._array) return self.__class__._new(res) @@ -678,8 +763,9 @@ class Array: """ Performs the operation __imod__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__imod__') + if other is NotImplemented: + return other self._array.__imod__(other._array) return self @@ -687,8 +773,9 @@ class Array: """ Performs the operation __rmod__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rmod__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rmod__(other._array) return self.__class__._new(res) @@ -697,8 +784,9 @@ class Array: """ Performs the operation __imul__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__imul__') + if other is NotImplemented: + return other self._array.__imul__(other._array) return self @@ -706,8 +794,9 @@ class Array: """ Performs the operation __rmul__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rmul__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rmul__(other._array) return self.__class__._new(res) @@ -716,8 +805,9 @@ class Array: """ Performs the operation __ior__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ior__') + if other is NotImplemented: + return other self._array.__ior__(other._array) return self @@ -725,8 +815,9 @@ class Array: """ Performs the operation __ror__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ror__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__ror__(other._array) return self.__class__._new(res) @@ -735,10 +826,9 @@ class Array: """ Performs the operation __ipow__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __pow__') + other = self._check_allowed_dtypes(other, 'floating-point', '__ipow__') + if other is NotImplemented: + return other self._array.__ipow__(other._array) return self @@ -748,10 +838,9 @@ class Array: """ from ._elementwise_functions import pow - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __pow__') + other = self._check_allowed_dtypes(other, 'floating-point', '__rpow__') + if other is NotImplemented: + return other # Note: NumPy's __pow__ does not follow the spec type promotion rules # for 0-d arrays, so we use pow() here instead. return pow(other, self) @@ -760,8 +849,9 @@ class Array: """ Performs the operation __irshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__irshift__') + if other is NotImplemented: + return other self._array.__irshift__(other._array) return self @@ -769,8 +859,9 @@ class Array: """ Performs the operation __rrshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__rrshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rrshift__(other._array) return self.__class__._new(res) @@ -779,8 +870,9 @@ class Array: """ Performs the operation __isub__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__isub__') + if other is NotImplemented: + return other self._array.__isub__(other._array) return self @@ -788,8 +880,9 @@ class Array: """ Performs the operation __rsub__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rsub__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rsub__(other._array) return self.__class__._new(res) @@ -798,10 +891,9 @@ class Array: """ Performs the operation __itruediv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __truediv__') + other = self._check_allowed_dtypes(other, 'floating-point', '__itruediv__') + if other is NotImplemented: + return other self._array.__itruediv__(other._array) return self @@ -809,10 +901,9 @@ class Array: """ Performs the operation __rtruediv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __truediv__') + other = self._check_allowed_dtypes(other, 'floating-point', '__rtruediv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rtruediv__(other._array) return self.__class__._new(res) @@ -821,8 +912,9 @@ class Array: """ Performs the operation __ixor__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ixor__') + if other is NotImplemented: + return other self._array.__ixor__(other._array) return self @@ -830,8 +922,9 @@ class Array: """ Performs the operation __rxor__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__rxor__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rxor__(other._array) return self.__class__._new(res) diff --git a/numpy/_array_api/tests/test_array_object.py b/numpy/_array_api/tests/test_array_object.py index 49ec3b37b..5aba2b23c 100644 --- a/numpy/_array_api/tests/test_array_object.py +++ b/numpy/_array_api/tests/test_array_object.py @@ -1,7 +1,10 @@ from numpy.testing import assert_raises import numpy as np -from .. import ones, asarray +from .. import ones, asarray, result_type +from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, + _integer_dtypes, _integer_or_boolean_dtypes, + _numeric_dtypes, int8, int16, int32, int64, uint64) def test_validate_index(): # The indexing tests in the official array API test suite test that the @@ -57,3 +60,176 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[None]) assert_raises(IndexError, lambda: a[None, ...]) assert_raises(IndexError, lambda: a[..., None]) + +def test_operators(): + # For every operator, we test that it works for the required type + # combinations and raises TypeError otherwise + binary_op_dtypes ={ + '__add__': 'numeric', + '__and__': 'integer_or_boolean', + '__eq__': 'all', + '__floordiv__': 'numeric', + '__ge__': 'numeric', + '__gt__': 'numeric', + '__le__': 'numeric', + '__lshift__': 'integer', + '__lt__': 'numeric', + '__mod__': 'numeric', + '__mul__': 'numeric', + '__ne__': 'all', + '__or__': 'integer_or_boolean', + '__pow__': 'floating', + '__rshift__': 'integer', + '__sub__': 'numeric', + '__truediv__': 'floating', + '__xor__': 'integer_or_boolean', + } + + # Recompute each time because of in-place ops + def _array_vals(): + for d in _integer_dtypes: + yield asarray(1, dtype=d) + for d in _boolean_dtypes: + yield asarray(False, dtype=d) + for d in _floating_dtypes: + yield asarray(1., dtype=d) + + for op, dtypes in binary_op_dtypes.items(): + ops = [op] + if op not in ['__eq__', '__ne__', '__le__', '__ge__', '__lt__', '__gt__']: + rop = '__r' + op[2:] + iop = '__i' + op[2:] + ops += [rop, iop] + for s in [1, 1., False]: + for _op in ops: + for a in _array_vals(): + # Test array op scalar. From the spec, the following combinations + # are supported: + + # - Python bool for a bool array dtype, + # - a Python int within the bounds of the given dtype for integer array dtypes, + # - a Python int or float for floating-point array dtypes + + # We do not do bounds checking for int scalars, but rather use the default + # NumPy behavior for casting in that case. + + if ((dtypes == "all" + or dtypes == "numeric" and a.dtype in _numeric_dtypes + or dtypes == "integer" and a.dtype in _integer_dtypes + or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes + or dtypes == "boolean" and a.dtype in _boolean_dtypes + or dtypes == "floating" and a.dtype in _floating_dtypes + ) + # bool is a subtype of int, which is why we avoid + # isinstance here. + and (a.dtype in _boolean_dtypes and type(s) == bool + or a.dtype in _integer_dtypes and type(s) == int + or a.dtype in _floating_dtypes and type(s) in [float, int] + )): + # Only test for no error + getattr(a, _op)(s) + else: + assert_raises(TypeError, lambda: getattr(a, _op)(s)) + + # Test array op array. + for _op in ops: + for x in _array_vals(): + for y in _array_vals(): + # See the promotion table in NEP 47 or the array + # API spec page on type promotion. Mixed kind + # promotion is not defined. + if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes + or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + ): + assert_raises(TypeError, lambda: getattr(x, _op)(y)) + # Ensure in-place operators only promote to the same dtype as the left operand. + elif _op.startswith('__i') and result_type(x.dtype, y.dtype) != x.dtype: + assert_raises(TypeError, lambda: getattr(x, _op)(y)) + # Ensure only those dtypes that are required for every operator are allowed. + elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes + or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) + or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) + or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _numeric_dtypes + or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes) + or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes + or dtypes == "floating" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes + ): + getattr(x, _op)(y) + else: + assert_raises(TypeError, lambda: getattr(x, _op)(y)) + + unary_op_dtypes ={ + '__invert__': 'integer_or_boolean', + '__neg__': 'numeric', + '__pos__': 'numeric', + } + for op, dtypes in unary_op_dtypes.items(): + for a in _array_vals(): + if (dtypes == "numeric" and a.dtype in _numeric_dtypes + or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes + ): + # Only test for no error + getattr(a, op)() + else: + assert_raises(TypeError, lambda: getattr(a, op)()) + + # Finally, matmul() must be tested separately, because it works a bit + # different from the other operations. + def _matmul_array_vals(): + for a in _array_vals(): + yield a + for d in _all_dtypes: + yield ones((3, 4), dtype=d) + yield ones((4, 2), dtype=d) + yield ones((4, 4), dtype=d) + + # Scalars always error + for _op in ['__matmul__', '__rmatmul__', '__imatmul__']: + for s in [1, 1., False]: + for a in _matmul_array_vals(): + if (type(s) in [float, int] and a.dtype in _floating_dtypes + or type(s) == int and a.dtype in _integer_dtypes): + # Type promotion is valid, but @ is not allowed on 0-D + # inputs, so the error is a ValueError + assert_raises(ValueError, lambda: getattr(a, _op)(s)) + else: + assert_raises(TypeError, lambda: getattr(a, _op)(s)) + + for x in _matmul_array_vals(): + for y in _matmul_array_vals(): + if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + or x.dtype in _boolean_dtypes + or y.dtype in _boolean_dtypes + ): + assert_raises(TypeError, lambda: x.__matmul__(y)) + assert_raises(TypeError, lambda: y.__rmatmul__(x)) + assert_raises(TypeError, lambda: x.__imatmul__(y)) + elif x.shape == () or y.shape == () or x.shape[1] != y.shape[0]: + assert_raises(ValueError, lambda: x.__matmul__(y)) + assert_raises(ValueError, lambda: y.__rmatmul__(x)) + if result_type(x.dtype, y.dtype) != x.dtype: + assert_raises(TypeError, lambda: x.__imatmul__(y)) + else: + assert_raises(ValueError, lambda: x.__imatmul__(y)) + else: + x.__matmul__(y) + y.__rmatmul__(x) + if result_type(x.dtype, y.dtype) != x.dtype: + assert_raises(TypeError, lambda: x.__imatmul__(y)) + elif y.shape[0] != y.shape[1]: + # This one fails because x @ y has a different shape from x + assert_raises(ValueError, lambda: x.__imatmul__(y)) + else: + x.__imatmul__(y) |