diff options
Diffstat (limited to 'numpy/_array_api')
-rw-r--r-- | numpy/_array_api/_array_object.py | 305 | ||||
-rw-r--r-- | numpy/_array_api/tests/test_array_object.py | 178 |
2 files changed, 376 insertions, 107 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index f8bad0b59..f6371fbf4 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -18,7 +18,8 @@ from __future__ import annotations import operator from enum import IntEnum from ._creation_functions import asarray -from ._dtypes import _all_dtypes, _boolean_dtypes, _integer_dtypes, _floating_dtypes +from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes, + _integer_or_boolean_dtypes, _floating_dtypes, _numeric_dtypes) from typing import TYPE_CHECKING, Any, Optional, Tuple, Union if TYPE_CHECKING: @@ -83,6 +84,52 @@ class Array: """ return f"Array({np.array2string(self._array, separator=', ')}, dtype={self.dtype.name})" + # These are various helper functions to make the array behavior match the + # spec in places where it either deviates from or is more strict than + # NumPy behavior + + def _check_allowed_dtypes(self, other, dtype_category, op): + """ + Helper function for operators to only allow specific input dtypes + + Use like + + other = self._check_allowed_dtypes(other, 'numeric', '__add__') + if other is NotImplemented: + return other + """ + from ._dtypes import _result_type + + _dtypes = { + 'all': _all_dtypes, + 'numeric': _numeric_dtypes, + 'integer': _integer_dtypes, + 'integer or boolean': _integer_or_boolean_dtypes, + 'boolean': _boolean_dtypes, + 'floating-point': _floating_dtypes, + } + + if self.dtype not in _dtypes[dtype_category]: + raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) + elif isinstance(other, Array): + if other.dtype not in _dtypes[dtype_category]: + raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + else: + return NotImplemented + + res_dtype = _result_type(self.dtype, other.dtype) + if op.startswith('__i'): + # Note: NumPy will allow in-place operators in some cases where the type promoted operator does not match the left-hand side operand. For example, + + # >>> a = np.array(1, dtype=np.int8) + # >>> a += np.array(1, dtype=np.int16) + if res_dtype != self.dtype: + raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}") + + return other + # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): """ @@ -270,8 +317,9 @@ class Array: """ Performs the operation __add__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__add__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__add__(other._array) return self.__class__._new(res) @@ -280,8 +328,9 @@ class Array: """ Performs the operation __and__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__and__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__and__(other._array) return self.__class__._new(res) @@ -321,6 +370,11 @@ class Array: """ Performs the operation __eq__. """ + # Even though "all" dtypes are allowed, we still require them to be + # promotable with each other. + other = self._check_allowed_dtypes(other, 'all', '__eq__') + if other is NotImplemented: + return other if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) self, other = self._normalize_two_args(self, other) @@ -341,8 +395,9 @@ class Array: """ Performs the operation __floordiv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__floordiv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__floordiv__(other._array) return self.__class__._new(res) @@ -351,8 +406,9 @@ class Array: """ Performs the operation __ge__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__ge__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__ge__(other._array) return self.__class__._new(res) @@ -371,8 +427,9 @@ class Array: """ Performs the operation __gt__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__gt__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__gt__(other._array) return self.__class__._new(res) @@ -391,6 +448,8 @@ class Array: """ Performs the operation __invert__. """ + if self.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in __invert__') res = self._array.__invert__() return self.__class__._new(res) @@ -398,8 +457,9 @@ class Array: """ Performs the operation __le__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__le__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__le__(other._array) return self.__class__._new(res) @@ -416,8 +476,9 @@ class Array: """ Performs the operation __lshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__lshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__lshift__(other._array) return self.__class__._new(res) @@ -426,8 +487,9 @@ class Array: """ Performs the operation __lt__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__lt__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__lt__(other._array) return self.__class__._new(res) @@ -436,10 +498,11 @@ class Array: """ Performs the operation __matmul__. """ - if isinstance(other, (int, float, bool)): - # matmul is not defined for scalars, but without this, we may get - # the wrong error message from asarray. - other = self._promote_scalar(other) + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__matmul__') + if other is NotImplemented: + return other res = self._array.__matmul__(other._array) return self.__class__._new(res) @@ -447,8 +510,9 @@ class Array: """ Performs the operation __mod__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__mod__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__mod__(other._array) return self.__class__._new(res) @@ -457,8 +521,9 @@ class Array: """ Performs the operation __mul__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__mul__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__mul__(other._array) return self.__class__._new(res) @@ -467,6 +532,9 @@ class Array: """ Performs the operation __ne__. """ + other = self._check_allowed_dtypes(other, 'all', '__ne__') + if other is NotImplemented: + return other if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) self, other = self._normalize_two_args(self, other) @@ -477,6 +545,8 @@ class Array: """ Performs the operation __neg__. """ + if self.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in __neg__') res = self._array.__neg__() return self.__class__._new(res) @@ -484,8 +554,9 @@ class Array: """ Performs the operation __or__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__or__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__or__(other._array) return self.__class__._new(res) @@ -494,6 +565,8 @@ class Array: """ Performs the operation __pos__. """ + if self.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in __pos__') res = self._array.__pos__() return self.__class__._new(res) @@ -505,10 +578,9 @@ class Array: """ from ._elementwise_functions import pow - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __pow__') + other = self._check_allowed_dtypes(other, 'floating-point', '__pow__') + if other is NotImplemented: + return other # Note: NumPy's __pow__ does not follow type promotion rules for 0-d # arrays, so we use pow() here instead. return pow(self, other) @@ -517,8 +589,9 @@ class Array: """ Performs the operation __rshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__rshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rshift__(other._array) return self.__class__._new(res) @@ -537,8 +610,9 @@ class Array: """ Performs the operation __sub__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__sub__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__sub__(other._array) return self.__class__._new(res) @@ -549,10 +623,9 @@ class Array: """ Performs the operation __truediv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __truediv__') + other = self._check_allowed_dtypes(other, 'floating-point', '__truediv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__truediv__(other._array) return self.__class__._new(res) @@ -561,8 +634,9 @@ class Array: """ Performs the operation __xor__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__xor__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__xor__(other._array) return self.__class__._new(res) @@ -571,8 +645,9 @@ class Array: """ Performs the operation __iadd__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__iadd__') + if other is NotImplemented: + return other self._array.__iadd__(other._array) return self @@ -580,8 +655,9 @@ class Array: """ Performs the operation __radd__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__radd__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__radd__(other._array) return self.__class__._new(res) @@ -590,8 +666,9 @@ class Array: """ Performs the operation __iand__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__iand__') + if other is NotImplemented: + return other self._array.__iand__(other._array) return self @@ -599,8 +676,9 @@ class Array: """ Performs the operation __rand__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__rand__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rand__(other._array) return self.__class__._new(res) @@ -609,8 +687,9 @@ class Array: """ Performs the operation __ifloordiv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__ifloordiv__') + if other is NotImplemented: + return other self._array.__ifloordiv__(other._array) return self @@ -618,8 +697,9 @@ class Array: """ Performs the operation __rfloordiv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rfloordiv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rfloordiv__(other._array) return self.__class__._new(res) @@ -628,8 +708,9 @@ class Array: """ Performs the operation __ilshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__ilshift__') + if other is NotImplemented: + return other self._array.__ilshift__(other._array) return self @@ -637,8 +718,9 @@ class Array: """ Performs the operation __rlshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__rlshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rlshift__(other._array) return self.__class__._new(res) @@ -649,15 +731,17 @@ class Array: """ # Note: NumPy does not implement __imatmul__. - if isinstance(other, (int, float, bool)): - # matmul is not defined for scalars, but without this, we may get - # the wrong error message from asarray. - other = self._promote_scalar(other) + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__imatmul__') + if other is NotImplemented: + return other + # __imatmul__ can only be allowed when it would not change the shape # of self. other_shape = other.shape if self.shape == () or other_shape == (): - raise ValueError("@= requires at least one dimension") + raise TypeError("@= requires at least one dimension") if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]: raise ValueError("@= cannot change the shape of the input array") self._array[:] = self._array.__matmul__(other._array) @@ -667,10 +751,11 @@ class Array: """ Performs the operation __rmatmul__. """ - if isinstance(other, (int, float, bool)): - # matmul is not defined for scalars, but without this, we may get - # the wrong error message from asarray. - other = self._promote_scalar(other) + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__rmatmul__') + if other is NotImplemented: + return other res = self._array.__rmatmul__(other._array) return self.__class__._new(res) @@ -678,8 +763,9 @@ class Array: """ Performs the operation __imod__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__imod__') + if other is NotImplemented: + return other self._array.__imod__(other._array) return self @@ -687,8 +773,9 @@ class Array: """ Performs the operation __rmod__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rmod__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rmod__(other._array) return self.__class__._new(res) @@ -697,8 +784,9 @@ class Array: """ Performs the operation __imul__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__imul__') + if other is NotImplemented: + return other self._array.__imul__(other._array) return self @@ -706,8 +794,9 @@ class Array: """ Performs the operation __rmul__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rmul__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rmul__(other._array) return self.__class__._new(res) @@ -716,8 +805,9 @@ class Array: """ Performs the operation __ior__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ior__') + if other is NotImplemented: + return other self._array.__ior__(other._array) return self @@ -725,8 +815,9 @@ class Array: """ Performs the operation __ror__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ror__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__ror__(other._array) return self.__class__._new(res) @@ -735,10 +826,9 @@ class Array: """ Performs the operation __ipow__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __pow__') + other = self._check_allowed_dtypes(other, 'floating-point', '__ipow__') + if other is NotImplemented: + return other self._array.__ipow__(other._array) return self @@ -748,10 +838,9 @@ class Array: """ from ._elementwise_functions import pow - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __pow__') + other = self._check_allowed_dtypes(other, 'floating-point', '__rpow__') + if other is NotImplemented: + return other # Note: NumPy's __pow__ does not follow the spec type promotion rules # for 0-d arrays, so we use pow() here instead. return pow(other, self) @@ -760,8 +849,9 @@ class Array: """ Performs the operation __irshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__irshift__') + if other is NotImplemented: + return other self._array.__irshift__(other._array) return self @@ -769,8 +859,9 @@ class Array: """ Performs the operation __rrshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__rrshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rrshift__(other._array) return self.__class__._new(res) @@ -779,8 +870,9 @@ class Array: """ Performs the operation __isub__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__isub__') + if other is NotImplemented: + return other self._array.__isub__(other._array) return self @@ -788,8 +880,9 @@ class Array: """ Performs the operation __rsub__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rsub__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rsub__(other._array) return self.__class__._new(res) @@ -798,10 +891,9 @@ class Array: """ Performs the operation __itruediv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __truediv__') + other = self._check_allowed_dtypes(other, 'floating-point', '__itruediv__') + if other is NotImplemented: + return other self._array.__itruediv__(other._array) return self @@ -809,10 +901,9 @@ class Array: """ Performs the operation __rtruediv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __truediv__') + other = self._check_allowed_dtypes(other, 'floating-point', '__rtruediv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rtruediv__(other._array) return self.__class__._new(res) @@ -821,8 +912,9 @@ class Array: """ Performs the operation __ixor__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ixor__') + if other is NotImplemented: + return other self._array.__ixor__(other._array) return self @@ -830,8 +922,9 @@ class Array: """ Performs the operation __rxor__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__rxor__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rxor__(other._array) return self.__class__._new(res) diff --git a/numpy/_array_api/tests/test_array_object.py b/numpy/_array_api/tests/test_array_object.py index 49ec3b37b..5aba2b23c 100644 --- a/numpy/_array_api/tests/test_array_object.py +++ b/numpy/_array_api/tests/test_array_object.py @@ -1,7 +1,10 @@ from numpy.testing import assert_raises import numpy as np -from .. import ones, asarray +from .. import ones, asarray, result_type +from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, + _integer_dtypes, _integer_or_boolean_dtypes, + _numeric_dtypes, int8, int16, int32, int64, uint64) def test_validate_index(): # The indexing tests in the official array API test suite test that the @@ -57,3 +60,176 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[None]) assert_raises(IndexError, lambda: a[None, ...]) assert_raises(IndexError, lambda: a[..., None]) + +def test_operators(): + # For every operator, we test that it works for the required type + # combinations and raises TypeError otherwise + binary_op_dtypes ={ + '__add__': 'numeric', + '__and__': 'integer_or_boolean', + '__eq__': 'all', + '__floordiv__': 'numeric', + '__ge__': 'numeric', + '__gt__': 'numeric', + '__le__': 'numeric', + '__lshift__': 'integer', + '__lt__': 'numeric', + '__mod__': 'numeric', + '__mul__': 'numeric', + '__ne__': 'all', + '__or__': 'integer_or_boolean', + '__pow__': 'floating', + '__rshift__': 'integer', + '__sub__': 'numeric', + '__truediv__': 'floating', + '__xor__': 'integer_or_boolean', + } + + # Recompute each time because of in-place ops + def _array_vals(): + for d in _integer_dtypes: + yield asarray(1, dtype=d) + for d in _boolean_dtypes: + yield asarray(False, dtype=d) + for d in _floating_dtypes: + yield asarray(1., dtype=d) + + for op, dtypes in binary_op_dtypes.items(): + ops = [op] + if op not in ['__eq__', '__ne__', '__le__', '__ge__', '__lt__', '__gt__']: + rop = '__r' + op[2:] + iop = '__i' + op[2:] + ops += [rop, iop] + for s in [1, 1., False]: + for _op in ops: + for a in _array_vals(): + # Test array op scalar. From the spec, the following combinations + # are supported: + + # - Python bool for a bool array dtype, + # - a Python int within the bounds of the given dtype for integer array dtypes, + # - a Python int or float for floating-point array dtypes + + # We do not do bounds checking for int scalars, but rather use the default + # NumPy behavior for casting in that case. + + if ((dtypes == "all" + or dtypes == "numeric" and a.dtype in _numeric_dtypes + or dtypes == "integer" and a.dtype in _integer_dtypes + or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes + or dtypes == "boolean" and a.dtype in _boolean_dtypes + or dtypes == "floating" and a.dtype in _floating_dtypes + ) + # bool is a subtype of int, which is why we avoid + # isinstance here. + and (a.dtype in _boolean_dtypes and type(s) == bool + or a.dtype in _integer_dtypes and type(s) == int + or a.dtype in _floating_dtypes and type(s) in [float, int] + )): + # Only test for no error + getattr(a, _op)(s) + else: + assert_raises(TypeError, lambda: getattr(a, _op)(s)) + + # Test array op array. + for _op in ops: + for x in _array_vals(): + for y in _array_vals(): + # See the promotion table in NEP 47 or the array + # API spec page on type promotion. Mixed kind + # promotion is not defined. + if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes + or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + ): + assert_raises(TypeError, lambda: getattr(x, _op)(y)) + # Ensure in-place operators only promote to the same dtype as the left operand. + elif _op.startswith('__i') and result_type(x.dtype, y.dtype) != x.dtype: + assert_raises(TypeError, lambda: getattr(x, _op)(y)) + # Ensure only those dtypes that are required for every operator are allowed. + elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes + or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) + or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) + or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _numeric_dtypes + or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes) + or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes + or dtypes == "floating" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes + ): + getattr(x, _op)(y) + else: + assert_raises(TypeError, lambda: getattr(x, _op)(y)) + + unary_op_dtypes ={ + '__invert__': 'integer_or_boolean', + '__neg__': 'numeric', + '__pos__': 'numeric', + } + for op, dtypes in unary_op_dtypes.items(): + for a in _array_vals(): + if (dtypes == "numeric" and a.dtype in _numeric_dtypes + or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes + ): + # Only test for no error + getattr(a, op)() + else: + assert_raises(TypeError, lambda: getattr(a, op)()) + + # Finally, matmul() must be tested separately, because it works a bit + # different from the other operations. + def _matmul_array_vals(): + for a in _array_vals(): + yield a + for d in _all_dtypes: + yield ones((3, 4), dtype=d) + yield ones((4, 2), dtype=d) + yield ones((4, 4), dtype=d) + + # Scalars always error + for _op in ['__matmul__', '__rmatmul__', '__imatmul__']: + for s in [1, 1., False]: + for a in _matmul_array_vals(): + if (type(s) in [float, int] and a.dtype in _floating_dtypes + or type(s) == int and a.dtype in _integer_dtypes): + # Type promotion is valid, but @ is not allowed on 0-D + # inputs, so the error is a ValueError + assert_raises(ValueError, lambda: getattr(a, _op)(s)) + else: + assert_raises(TypeError, lambda: getattr(a, _op)(s)) + + for x in _matmul_array_vals(): + for y in _matmul_array_vals(): + if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + or x.dtype in _boolean_dtypes + or y.dtype in _boolean_dtypes + ): + assert_raises(TypeError, lambda: x.__matmul__(y)) + assert_raises(TypeError, lambda: y.__rmatmul__(x)) + assert_raises(TypeError, lambda: x.__imatmul__(y)) + elif x.shape == () or y.shape == () or x.shape[1] != y.shape[0]: + assert_raises(ValueError, lambda: x.__matmul__(y)) + assert_raises(ValueError, lambda: y.__rmatmul__(x)) + if result_type(x.dtype, y.dtype) != x.dtype: + assert_raises(TypeError, lambda: x.__imatmul__(y)) + else: + assert_raises(ValueError, lambda: x.__imatmul__(y)) + else: + x.__matmul__(y) + y.__rmatmul__(x) + if result_type(x.dtype, y.dtype) != x.dtype: + assert_raises(TypeError, lambda: x.__imatmul__(y)) + elif y.shape[0] != y.shape[1]: + # This one fails because x @ y has a different shape from x + assert_raises(ValueError, lambda: x.__imatmul__(y)) + else: + x.__imatmul__(y) |