diff options
Diffstat (limited to 'numpy/_array_api/_array_object.py')
-rw-r--r-- | numpy/_array_api/_array_object.py | 305 |
1 files changed, 199 insertions, 106 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index f8bad0b59..f6371fbf4 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -18,7 +18,8 @@ from __future__ import annotations import operator from enum import IntEnum from ._creation_functions import asarray -from ._dtypes import _all_dtypes, _boolean_dtypes, _integer_dtypes, _floating_dtypes +from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes, + _integer_or_boolean_dtypes, _floating_dtypes, _numeric_dtypes) from typing import TYPE_CHECKING, Any, Optional, Tuple, Union if TYPE_CHECKING: @@ -83,6 +84,52 @@ class Array: """ return f"Array({np.array2string(self._array, separator=', ')}, dtype={self.dtype.name})" + # These are various helper functions to make the array behavior match the + # spec in places where it either deviates from or is more strict than + # NumPy behavior + + def _check_allowed_dtypes(self, other, dtype_category, op): + """ + Helper function for operators to only allow specific input dtypes + + Use like + + other = self._check_allowed_dtypes(other, 'numeric', '__add__') + if other is NotImplemented: + return other + """ + from ._dtypes import _result_type + + _dtypes = { + 'all': _all_dtypes, + 'numeric': _numeric_dtypes, + 'integer': _integer_dtypes, + 'integer or boolean': _integer_or_boolean_dtypes, + 'boolean': _boolean_dtypes, + 'floating-point': _floating_dtypes, + } + + if self.dtype not in _dtypes[dtype_category]: + raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) + elif isinstance(other, Array): + if other.dtype not in _dtypes[dtype_category]: + raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + else: + return NotImplemented + + res_dtype = _result_type(self.dtype, other.dtype) + if op.startswith('__i'): + # Note: NumPy will allow in-place operators in some cases where the type promoted operator does not match the left-hand side operand. For example, + + # >>> a = np.array(1, dtype=np.int8) + # >>> a += np.array(1, dtype=np.int16) + if res_dtype != self.dtype: + raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}") + + return other + # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): """ @@ -270,8 +317,9 @@ class Array: """ Performs the operation __add__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__add__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__add__(other._array) return self.__class__._new(res) @@ -280,8 +328,9 @@ class Array: """ Performs the operation __and__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__and__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__and__(other._array) return self.__class__._new(res) @@ -321,6 +370,11 @@ class Array: """ Performs the operation __eq__. """ + # Even though "all" dtypes are allowed, we still require them to be + # promotable with each other. + other = self._check_allowed_dtypes(other, 'all', '__eq__') + if other is NotImplemented: + return other if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) self, other = self._normalize_two_args(self, other) @@ -341,8 +395,9 @@ class Array: """ Performs the operation __floordiv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__floordiv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__floordiv__(other._array) return self.__class__._new(res) @@ -351,8 +406,9 @@ class Array: """ Performs the operation __ge__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__ge__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__ge__(other._array) return self.__class__._new(res) @@ -371,8 +427,9 @@ class Array: """ Performs the operation __gt__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__gt__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__gt__(other._array) return self.__class__._new(res) @@ -391,6 +448,8 @@ class Array: """ Performs the operation __invert__. """ + if self.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in __invert__') res = self._array.__invert__() return self.__class__._new(res) @@ -398,8 +457,9 @@ class Array: """ Performs the operation __le__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__le__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__le__(other._array) return self.__class__._new(res) @@ -416,8 +476,9 @@ class Array: """ Performs the operation __lshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__lshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__lshift__(other._array) return self.__class__._new(res) @@ -426,8 +487,9 @@ class Array: """ Performs the operation __lt__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__lt__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__lt__(other._array) return self.__class__._new(res) @@ -436,10 +498,11 @@ class Array: """ Performs the operation __matmul__. """ - if isinstance(other, (int, float, bool)): - # matmul is not defined for scalars, but without this, we may get - # the wrong error message from asarray. - other = self._promote_scalar(other) + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__matmul__') + if other is NotImplemented: + return other res = self._array.__matmul__(other._array) return self.__class__._new(res) @@ -447,8 +510,9 @@ class Array: """ Performs the operation __mod__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__mod__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__mod__(other._array) return self.__class__._new(res) @@ -457,8 +521,9 @@ class Array: """ Performs the operation __mul__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__mul__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__mul__(other._array) return self.__class__._new(res) @@ -467,6 +532,9 @@ class Array: """ Performs the operation __ne__. """ + other = self._check_allowed_dtypes(other, 'all', '__ne__') + if other is NotImplemented: + return other if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) self, other = self._normalize_two_args(self, other) @@ -477,6 +545,8 @@ class Array: """ Performs the operation __neg__. """ + if self.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in __neg__') res = self._array.__neg__() return self.__class__._new(res) @@ -484,8 +554,9 @@ class Array: """ Performs the operation __or__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__or__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__or__(other._array) return self.__class__._new(res) @@ -494,6 +565,8 @@ class Array: """ Performs the operation __pos__. """ + if self.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in __pos__') res = self._array.__pos__() return self.__class__._new(res) @@ -505,10 +578,9 @@ class Array: """ from ._elementwise_functions import pow - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __pow__') + other = self._check_allowed_dtypes(other, 'floating-point', '__pow__') + if other is NotImplemented: + return other # Note: NumPy's __pow__ does not follow type promotion rules for 0-d # arrays, so we use pow() here instead. return pow(self, other) @@ -517,8 +589,9 @@ class Array: """ Performs the operation __rshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__rshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rshift__(other._array) return self.__class__._new(res) @@ -537,8 +610,9 @@ class Array: """ Performs the operation __sub__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__sub__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__sub__(other._array) return self.__class__._new(res) @@ -549,10 +623,9 @@ class Array: """ Performs the operation __truediv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __truediv__') + other = self._check_allowed_dtypes(other, 'floating-point', '__truediv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__truediv__(other._array) return self.__class__._new(res) @@ -561,8 +634,9 @@ class Array: """ Performs the operation __xor__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__xor__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__xor__(other._array) return self.__class__._new(res) @@ -571,8 +645,9 @@ class Array: """ Performs the operation __iadd__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__iadd__') + if other is NotImplemented: + return other self._array.__iadd__(other._array) return self @@ -580,8 +655,9 @@ class Array: """ Performs the operation __radd__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__radd__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__radd__(other._array) return self.__class__._new(res) @@ -590,8 +666,9 @@ class Array: """ Performs the operation __iand__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__iand__') + if other is NotImplemented: + return other self._array.__iand__(other._array) return self @@ -599,8 +676,9 @@ class Array: """ Performs the operation __rand__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__rand__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rand__(other._array) return self.__class__._new(res) @@ -609,8 +687,9 @@ class Array: """ Performs the operation __ifloordiv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__ifloordiv__') + if other is NotImplemented: + return other self._array.__ifloordiv__(other._array) return self @@ -618,8 +697,9 @@ class Array: """ Performs the operation __rfloordiv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rfloordiv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rfloordiv__(other._array) return self.__class__._new(res) @@ -628,8 +708,9 @@ class Array: """ Performs the operation __ilshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__ilshift__') + if other is NotImplemented: + return other self._array.__ilshift__(other._array) return self @@ -637,8 +718,9 @@ class Array: """ Performs the operation __rlshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__rlshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rlshift__(other._array) return self.__class__._new(res) @@ -649,15 +731,17 @@ class Array: """ # Note: NumPy does not implement __imatmul__. - if isinstance(other, (int, float, bool)): - # matmul is not defined for scalars, but without this, we may get - # the wrong error message from asarray. - other = self._promote_scalar(other) + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__imatmul__') + if other is NotImplemented: + return other + # __imatmul__ can only be allowed when it would not change the shape # of self. other_shape = other.shape if self.shape == () or other_shape == (): - raise ValueError("@= requires at least one dimension") + raise TypeError("@= requires at least one dimension") if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]: raise ValueError("@= cannot change the shape of the input array") self._array[:] = self._array.__matmul__(other._array) @@ -667,10 +751,11 @@ class Array: """ Performs the operation __rmatmul__. """ - if isinstance(other, (int, float, bool)): - # matmul is not defined for scalars, but without this, we may get - # the wrong error message from asarray. - other = self._promote_scalar(other) + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__rmatmul__') + if other is NotImplemented: + return other res = self._array.__rmatmul__(other._array) return self.__class__._new(res) @@ -678,8 +763,9 @@ class Array: """ Performs the operation __imod__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__imod__') + if other is NotImplemented: + return other self._array.__imod__(other._array) return self @@ -687,8 +773,9 @@ class Array: """ Performs the operation __rmod__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rmod__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rmod__(other._array) return self.__class__._new(res) @@ -697,8 +784,9 @@ class Array: """ Performs the operation __imul__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__imul__') + if other is NotImplemented: + return other self._array.__imul__(other._array) return self @@ -706,8 +794,9 @@ class Array: """ Performs the operation __rmul__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rmul__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rmul__(other._array) return self.__class__._new(res) @@ -716,8 +805,9 @@ class Array: """ Performs the operation __ior__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ior__') + if other is NotImplemented: + return other self._array.__ior__(other._array) return self @@ -725,8 +815,9 @@ class Array: """ Performs the operation __ror__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ror__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__ror__(other._array) return self.__class__._new(res) @@ -735,10 +826,9 @@ class Array: """ Performs the operation __ipow__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __pow__') + other = self._check_allowed_dtypes(other, 'floating-point', '__ipow__') + if other is NotImplemented: + return other self._array.__ipow__(other._array) return self @@ -748,10 +838,9 @@ class Array: """ from ._elementwise_functions import pow - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __pow__') + other = self._check_allowed_dtypes(other, 'floating-point', '__rpow__') + if other is NotImplemented: + return other # Note: NumPy's __pow__ does not follow the spec type promotion rules # for 0-d arrays, so we use pow() here instead. return pow(other, self) @@ -760,8 +849,9 @@ class Array: """ Performs the operation __irshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__irshift__') + if other is NotImplemented: + return other self._array.__irshift__(other._array) return self @@ -769,8 +859,9 @@ class Array: """ Performs the operation __rrshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__rrshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rrshift__(other._array) return self.__class__._new(res) @@ -779,8 +870,9 @@ class Array: """ Performs the operation __isub__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__isub__') + if other is NotImplemented: + return other self._array.__isub__(other._array) return self @@ -788,8 +880,9 @@ class Array: """ Performs the operation __rsub__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rsub__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rsub__(other._array) return self.__class__._new(res) @@ -798,10 +891,9 @@ class Array: """ Performs the operation __itruediv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __truediv__') + other = self._check_allowed_dtypes(other, 'floating-point', '__itruediv__') + if other is NotImplemented: + return other self._array.__itruediv__(other._array) return self @@ -809,10 +901,9 @@ class Array: """ Performs the operation __rtruediv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __truediv__') + other = self._check_allowed_dtypes(other, 'floating-point', '__rtruediv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rtruediv__(other._array) return self.__class__._new(res) @@ -821,8 +912,9 @@ class Array: """ Performs the operation __ixor__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ixor__') + if other is NotImplemented: + return other self._array.__ixor__(other._array) return self @@ -830,8 +922,9 @@ class Array: """ Performs the operation __rxor__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__rxor__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rxor__(other._array) return self.__class__._new(res) |