summaryrefslogtreecommitdiff
path: root/numpy/_array_api/_array_object.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-07-09 16:24:45 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-07-09 16:24:45 -0600
commit5febef530e055572fd5eac18807675ee451c81b0 (patch)
tree395c332253db4a091d6fdeba64d93f06d12fef3d /numpy/_array_api/_array_object.py
parent6379138a6da6ebf73bfc4bc4e019a21d8a99be0a (diff)
downloadnumpy-5febef530e055572fd5eac18807675ee451c81b0.tar.gz
Only allow floating-point dtypes in the array API __pow__ and __truediv__
See https://github.com/data-apis/array-api/pull/221.
Diffstat (limited to 'numpy/_array_api/_array_object.py')
-rw-r--r--numpy/_array_api/_array_object.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py
index 2377bffe3..797f9ea4f 100644
--- a/numpy/_array_api/_array_object.py
+++ b/numpy/_array_api/_array_object.py
@@ -503,6 +503,8 @@ class Array:
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
+ if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in __pow__')
# Note: NumPy's __pow__ does not follow type promotion rules for 0-d
# arrays, so we use pow() here instead.
return pow(self, other)
@@ -548,6 +550,8 @@ class Array:
"""
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
+ if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in __truediv__')
self, other = self._normalize_two_args(self, other)
res = self._array.__truediv__(other._array)
return self.__class__._new(res)
@@ -744,6 +748,8 @@ class Array:
"""
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
+ if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in __pow__')
self._array.__ipow__(other._array)
return self
@@ -756,6 +762,8 @@ class Array:
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
+ if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in __pow__')
# Note: NumPy's __pow__ does not follow the spec type promotion rules
# for 0-d arrays, so we use pow() here instead.
return pow(other, self)
@@ -810,6 +818,8 @@ class Array:
"""
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
+ if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in __truediv__')
self._array.__itruediv__(other._array)
return self
@@ -820,6 +830,8 @@ class Array:
"""
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
+ if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in __truediv__')
self, other = self._normalize_two_args(self, other)
res = self._array.__rtruediv__(other._array)
return self.__class__._new(res)