diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-07-09 16:24:45 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-07-09 16:24:45 -0600 |
commit | 5febef530e055572fd5eac18807675ee451c81b0 (patch) | |
tree | 395c332253db4a091d6fdeba64d93f06d12fef3d /numpy/_array_api/_array_object.py | |
parent | 6379138a6da6ebf73bfc4bc4e019a21d8a99be0a (diff) | |
download | numpy-5febef530e055572fd5eac18807675ee451c81b0.tar.gz |
Only allow floating-point dtypes in the array API __pow__ and __truediv__
See https://github.com/data-apis/array-api/pull/221.
Diffstat (limited to 'numpy/_array_api/_array_object.py')
-rw-r--r-- | numpy/_array_api/_array_object.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index 2377bffe3..797f9ea4f 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -503,6 +503,8 @@ class Array: if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __pow__') # Note: NumPy's __pow__ does not follow type promotion rules for 0-d # arrays, so we use pow() here instead. return pow(self, other) @@ -548,6 +550,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __truediv__') self, other = self._normalize_two_args(self, other) res = self._array.__truediv__(other._array) return self.__class__._new(res) @@ -744,6 +748,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __pow__') self._array.__ipow__(other._array) return self @@ -756,6 +762,8 @@ class Array: if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __pow__') # Note: NumPy's __pow__ does not follow the spec type promotion rules # for 0-d arrays, so we use pow() here instead. return pow(other, self) @@ -810,6 +818,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __truediv__') self._array.__itruediv__(other._array) return self @@ -820,6 +830,8 @@ class Array: """ if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) + if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in __truediv__') self, other = self._normalize_two_args(self, other) res = self._array.__rtruediv__(other._array) return self.__class__._new(res) |