diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-07-21 15:48:06 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-07-21 15:48:06 -0600 |
commit | 63a9a87360ef492c46c37416b8270563e73a6349 (patch) | |
tree | 7d04d62253f2402d589803539d98875c60222fbc /numpy/_array_api/_array_object.py | |
parent | 9d5d0ec2264c86a19714cf185a5a183df14cbb94 (diff) | |
download | numpy-63a9a87360ef492c46c37416b8270563e73a6349.tar.gz |
Restrict the array API namespace array operator type promotions
Only those type promotions that are required by the spec are allowed. In
particular, promotions across kinds, like integer + floating-point, are not
allowed, except for the case of Python scalars.
Tests are added for this.
This commit additionally makes the operators return NotImplemented on
unexpected input types rather than directly giving a TypeError. This is not
strictly required by the array API spec, but it is generally considered a best
practice for operator methods in Python.
This same thing will be implemented for the various functions in the array API
namespace in a later commit.
Diffstat (limited to 'numpy/_array_api/_array_object.py')
-rw-r--r-- | numpy/_array_api/_array_object.py | 305 |
1 files changed, 199 insertions, 106 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index f8bad0b59..f6371fbf4 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -18,7 +18,8 @@ from __future__ import annotations import operator from enum import IntEnum from ._creation_functions import asarray -from ._dtypes import _all_dtypes, _boolean_dtypes, _integer_dtypes, _floating_dtypes +from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes, + _integer_or_boolean_dtypes, _floating_dtypes, _numeric_dtypes) from typing import TYPE_CHECKING, Any, Optional, Tuple, Union if TYPE_CHECKING: @@ -83,6 +84,52 @@ class Array: """ return f"Array({np.array2string(self._array, separator=', ')}, dtype={self.dtype.name})" + # These are various helper functions to make the array behavior match the + # spec in places where it either deviates from or is more strict than + # NumPy behavior + + def _check_allowed_dtypes(self, other, dtype_category, op): + """ + Helper function for operators to only allow specific input dtypes + + Use like + + other = self._check_allowed_dtypes(other, 'numeric', '__add__') + if other is NotImplemented: + return other + """ + from ._dtypes import _result_type + + _dtypes = { + 'all': _all_dtypes, + 'numeric': _numeric_dtypes, + 'integer': _integer_dtypes, + 'integer or boolean': _integer_or_boolean_dtypes, + 'boolean': _boolean_dtypes, + 'floating-point': _floating_dtypes, + } + + if self.dtype not in _dtypes[dtype_category]: + raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) + elif isinstance(other, Array): + if other.dtype not in _dtypes[dtype_category]: + raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') + else: + return NotImplemented + + res_dtype = _result_type(self.dtype, other.dtype) + if op.startswith('__i'): + # Note: NumPy will allow in-place operators in some cases where the type promoted operator does not match the left-hand side operand. For example, + + # >>> a = np.array(1, dtype=np.int8) + # >>> a += np.array(1, dtype=np.int16) + if res_dtype != self.dtype: + raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}") + + return other + # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): """ @@ -270,8 +317,9 @@ class Array: """ Performs the operation __add__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__add__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__add__(other._array) return self.__class__._new(res) @@ -280,8 +328,9 @@ class Array: """ Performs the operation __and__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__and__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__and__(other._array) return self.__class__._new(res) @@ -321,6 +370,11 @@ class Array: """ Performs the operation __eq__. """ + # Even though "all" dtypes are allowed, we still require them to be + # promotable with each other. + other = self._check_allowed_dtypes(other, 'all', '__eq__') + if other is NotImplemented: + return other if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) self, other = self._normalize_two_args(self, other) @@ -341,8 +395,9 @@ class Array: """ Performs the operation __floordiv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__floordiv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__floordiv__(other._array) return self.__class__._new(res) @@ -351,8 +406,9 @@ class Array: """ Performs the operation __ge__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__ge__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__ge__(other._array) return self.__class__._new(res) @@ -371,8 +427,9 @@ class Array: """ Performs the operation __gt__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__gt__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__gt__(other._array) return self.__class__._new(res) @@ -391,6 +448,8 @@ class Array: """ Performs the operation __invert__. """ + if self.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in __invert__') res = self._array.__invert__() return self.__class__._new(res) @@ -398,8 +457,9 @@ class Array: """ Performs the operation __le__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__le__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__le__(other._array) return self.__class__._new(res) @@ -416,8 +476,9 @@ class Array: """ Performs the operation __lshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__lshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__lshift__(other._array) return self.__class__._new(res) @@ -426,8 +487,9 @@ class Array: """ Performs the operation __lt__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__lt__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__lt__(other._array) return self.__class__._new(res) @@ -436,10 +498,11 @@ class Array: """ Performs the operation __matmul__. """ - if isinstance(other, (int, float, bool)): - # matmul is not defined for scalars, but without this, we may get - # the wrong error message from asarray. - other = self._promote_scalar(other) + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__matmul__') + if other is NotImplemented: + return other res = self._array.__matmul__(other._array) return self.__class__._new(res) @@ -447,8 +510,9 @@ class Array: """ Performs the operation __mod__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__mod__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__mod__(other._array) return self.__class__._new(res) @@ -457,8 +521,9 @@ class Array: """ Performs the operation __mul__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__mul__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__mul__(other._array) return self.__class__._new(res) @@ -467,6 +532,9 @@ class Array: """ Performs the operation __ne__. """ + other = self._check_allowed_dtypes(other, 'all', '__ne__') + if other is NotImplemented: + return other if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) self, other = self._normalize_two_args(self, other) @@ -477,6 +545,8 @@ class Array: """ Performs the operation __neg__. """ + if self.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in __neg__') res = self._array.__neg__() return self.__class__._new(res) @@ -484,8 +554,9 @@ class Array: """ Performs the operation __or__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__or__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__or__(other._array) return self.__class__._new(res) @@ -494,6 +565,8 @@ class Array: """ Performs the operation __pos__. """ + if self.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in __pos__') res = self._array.__pos__() return self.__class__._new(res) @@ -505,10 +578,9 @@ class Array: """ from ._elementwise_functions import pow - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __pow__') + other = self._check_allowed_dtypes(other, 'floating-point', '__pow__') + if other is NotImplemented: + return other # Note: NumPy's __pow__ does not follow type promotion rules for 0-d # arrays, so we use pow() here instead. return pow(self, other) @@ -517,8 +589,9 @@ class Array: """ Performs the operation __rshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__rshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rshift__(other._array) return self.__class__._new(res) @@ -537,8 +610,9 @@ class Array: """ Performs the operation __sub__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__sub__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__sub__(other._array) return self.__class__._new(res) @@ -549,10 +623,9 @@ class Array: """ Performs the operation __truediv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __truediv__') + other = self._check_allowed_dtypes(other, 'floating-point', '__truediv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__truediv__(other._array) return self.__class__._new(res) @@ -561,8 +634,9 @@ class Array: """ Performs the operation __xor__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__xor__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__xor__(other._array) return self.__class__._new(res) @@ -571,8 +645,9 @@ class Array: """ Performs the operation __iadd__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__iadd__') + if other is NotImplemented: + return other self._array.__iadd__(other._array) return self @@ -580,8 +655,9 @@ class Array: """ Performs the operation __radd__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__radd__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__radd__(other._array) return self.__class__._new(res) @@ -590,8 +666,9 @@ class Array: """ Performs the operation __iand__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__iand__') + if other is NotImplemented: + return other self._array.__iand__(other._array) return self @@ -599,8 +676,9 @@ class Array: """ Performs the operation __rand__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__rand__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rand__(other._array) return self.__class__._new(res) @@ -609,8 +687,9 @@ class Array: """ Performs the operation __ifloordiv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__ifloordiv__') + if other is NotImplemented: + return other self._array.__ifloordiv__(other._array) return self @@ -618,8 +697,9 @@ class Array: """ Performs the operation __rfloordiv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rfloordiv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rfloordiv__(other._array) return self.__class__._new(res) @@ -628,8 +708,9 @@ class Array: """ Performs the operation __ilshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__ilshift__') + if other is NotImplemented: + return other self._array.__ilshift__(other._array) return self @@ -637,8 +718,9 @@ class Array: """ Performs the operation __rlshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__rlshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rlshift__(other._array) return self.__class__._new(res) @@ -649,15 +731,17 @@ class Array: """ # Note: NumPy does not implement __imatmul__. - if isinstance(other, (int, float, bool)): - # matmul is not defined for scalars, but without this, we may get - # the wrong error message from asarray. - other = self._promote_scalar(other) + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__imatmul__') + if other is NotImplemented: + return other + # __imatmul__ can only be allowed when it would not change the shape # of self. other_shape = other.shape if self.shape == () or other_shape == (): - raise ValueError("@= requires at least one dimension") + raise TypeError("@= requires at least one dimension") if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]: raise ValueError("@= cannot change the shape of the input array") self._array[:] = self._array.__matmul__(other._array) @@ -667,10 +751,11 @@ class Array: """ Performs the operation __rmatmul__. """ - if isinstance(other, (int, float, bool)): - # matmul is not defined for scalars, but without this, we may get - # the wrong error message from asarray. - other = self._promote_scalar(other) + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, 'numeric', '__rmatmul__') + if other is NotImplemented: + return other res = self._array.__rmatmul__(other._array) return self.__class__._new(res) @@ -678,8 +763,9 @@ class Array: """ Performs the operation __imod__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__imod__') + if other is NotImplemented: + return other self._array.__imod__(other._array) return self @@ -687,8 +773,9 @@ class Array: """ Performs the operation __rmod__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rmod__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rmod__(other._array) return self.__class__._new(res) @@ -697,8 +784,9 @@ class Array: """ Performs the operation __imul__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__imul__') + if other is NotImplemented: + return other self._array.__imul__(other._array) return self @@ -706,8 +794,9 @@ class Array: """ Performs the operation __rmul__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rmul__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rmul__(other._array) return self.__class__._new(res) @@ -716,8 +805,9 @@ class Array: """ Performs the operation __ior__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ior__') + if other is NotImplemented: + return other self._array.__ior__(other._array) return self @@ -725,8 +815,9 @@ class Array: """ Performs the operation __ror__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ror__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__ror__(other._array) return self.__class__._new(res) @@ -735,10 +826,9 @@ class Array: """ Performs the operation __ipow__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __pow__') + other = self._check_allowed_dtypes(other, 'floating-point', '__ipow__') + if other is NotImplemented: + return other self._array.__ipow__(other._array) return self @@ -748,10 +838,9 @@ class Array: """ from ._elementwise_functions import pow - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __pow__') + other = self._check_allowed_dtypes(other, 'floating-point', '__rpow__') + if other is NotImplemented: + return other # Note: NumPy's __pow__ does not follow the spec type promotion rules # for 0-d arrays, so we use pow() here instead. return pow(other, self) @@ -760,8 +849,9 @@ class Array: """ Performs the operation __irshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__irshift__') + if other is NotImplemented: + return other self._array.__irshift__(other._array) return self @@ -769,8 +859,9 @@ class Array: """ Performs the operation __rrshift__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer', '__rrshift__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rrshift__(other._array) return self.__class__._new(res) @@ -779,8 +870,9 @@ class Array: """ Performs the operation __isub__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__isub__') + if other is NotImplemented: + return other self._array.__isub__(other._array) return self @@ -788,8 +880,9 @@ class Array: """ Performs the operation __rsub__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'numeric', '__rsub__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rsub__(other._array) return self.__class__._new(res) @@ -798,10 +891,9 @@ class Array: """ Performs the operation __itruediv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __truediv__') + other = self._check_allowed_dtypes(other, 'floating-point', '__itruediv__') + if other is NotImplemented: + return other self._array.__itruediv__(other._array) return self @@ -809,10 +901,9 @@ class Array: """ Performs the operation __rtruediv__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) - if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: - raise TypeError('Only floating-point dtypes are allowed in __truediv__') + other = self._check_allowed_dtypes(other, 'floating-point', '__rtruediv__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rtruediv__(other._array) return self.__class__._new(res) @@ -821,8 +912,9 @@ class Array: """ Performs the operation __ixor__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__ixor__') + if other is NotImplemented: + return other self._array.__ixor__(other._array) return self @@ -830,8 +922,9 @@ class Array: """ Performs the operation __rxor__. """ - if isinstance(other, (int, float, bool)): - other = self._promote_scalar(other) + other = self._check_allowed_dtypes(other, 'integer or boolean', '__rxor__') + if other is NotImplemented: + return other self, other = self._normalize_two_args(self, other) res = self._array.__rxor__(other._array) return self.__class__._new(res) |