summaryrefslogtreecommitdiff
path: root/numpy/_array_api/_array_object.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-03-31 16:34:23 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-03-31 16:34:23 -0600
commit9fe4fc7ff7f477fc4aaad850f8d1841beb2924bc (patch)
treea5c49d08d8f3be89cd506762f5ef752fdf432e25 /numpy/_array_api/_array_object.py
parent7ce435c610fcd7fee01da9d9e7ff5c1ab4ae6ef6 (diff)
downloadnumpy-9fe4fc7ff7f477fc4aaad850f8d1841beb2924bc.tar.gz
Make the array API follow the spec Python scalar promotion rules
Diffstat (limited to 'numpy/_array_api/_array_object.py')
-rw-r--r--numpy/_array_api/_array_object.py129
1 files changed, 128 insertions, 1 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py
index ad0cbc71e..c5acb5d1d 100644
--- a/numpy/_array_api/_array_object.py
+++ b/numpy/_array_api/_array_object.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import operator
from enum import IntEnum
from ._creation_functions import asarray
-from ._dtypes import _boolean_dtypes, _integer_dtypes
+from ._dtypes import _boolean_dtypes, _integer_dtypes, _floating_dtypes
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -83,6 +83,37 @@ class ndarray:
"""
return self._array.__repr__().replace('array', 'ndarray')
+ # Helper function to match the type promotion rules in the spec
+ def _promote_scalar(self, scalar):
+ """
+ Returns a promoted version of a Python scalar appropiate for use with
+ operations on self.
+
+ This may raise an OverflowError in cases where the scalar is an
+ integer that is too large to fit in a NumPy integer dtype, or
+ TypeError when the scalar type is incompatible with the dtype of self.
+
+ Note: this helper function returns a NumPy array (NOT a NumPy array
+ API ndarray).
+ """
+ if isinstance(scalar, bool):
+ if self.dtype not in _boolean_dtypes:
+ raise TypeError("Python bool scalars can only be promoted with bool arrays")
+ elif isinstance(scalar, int):
+ if self.dtype in _boolean_dtypes:
+ raise TypeError("Python int scalars cannot be promoted with bool arrays")
+ elif isinstance(scalar, float):
+ if self.dtype not in _floating_dtypes:
+ raise TypeError("Python float scalars can only be promoted with floating-point arrays.")
+ else:
+ raise TypeError("'scalar' must be a Python scalar")
+
+ # Note: the spec only specifies integer-dtype/int promotion
+ # behavior for integers within the bounds of the integer dtype.
+ # Outside of those bounds we use the default NumPy behavior (either
+ # cast or raise OverflowError).
+ return np.array(scalar, self.dtype)
+
# Everything below this is required by the spec.
def __abs__(self: array, /) -> array:
@@ -96,6 +127,8 @@ class ndarray:
"""
Performs the operation __add__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__add__(asarray(other)._array)
return self.__class__._new(res)
@@ -103,6 +136,8 @@ class ndarray:
"""
Performs the operation __and__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__and__(asarray(other)._array)
return self.__class__._new(res)
@@ -140,6 +175,8 @@ class ndarray:
"""
Performs the operation __eq__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__eq__(asarray(other)._array)
return self.__class__._new(res)
@@ -157,6 +194,8 @@ class ndarray:
"""
Performs the operation __floordiv__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__floordiv__(asarray(other)._array)
return self.__class__._new(res)
@@ -164,6 +203,8 @@ class ndarray:
"""
Performs the operation __ge__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__ge__(asarray(other)._array)
return self.__class__._new(res)
@@ -288,6 +329,8 @@ class ndarray:
"""
Performs the operation __gt__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__gt__(asarray(other)._array)
return self.__class__._new(res)
@@ -312,6 +355,8 @@ class ndarray:
"""
Performs the operation __le__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__le__(asarray(other)._array)
return self.__class__._new(res)
@@ -326,6 +371,8 @@ class ndarray:
"""
Performs the operation __lshift__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__lshift__(asarray(other)._array)
return self.__class__._new(res)
@@ -333,6 +380,8 @@ class ndarray:
"""
Performs the operation __lt__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__lt__(asarray(other)._array)
return self.__class__._new(res)
@@ -340,6 +389,10 @@ class ndarray:
"""
Performs the operation __matmul__.
"""
+ if isinstance(other, (int, float, bool)):
+ # matmul is not defined for scalars, but without this, we may get
+ # the wrong error message from asarray.
+ other = self._promote_scalar(other)
res = self._array.__matmul__(asarray(other)._array)
return self.__class__._new(res)
@@ -347,6 +400,8 @@ class ndarray:
"""
Performs the operation __mod__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__mod__(asarray(other)._array)
return self.__class__._new(res)
@@ -354,6 +409,8 @@ class ndarray:
"""
Performs the operation __mul__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__mul__(asarray(other)._array)
return self.__class__._new(res)
@@ -361,6 +418,8 @@ class ndarray:
"""
Performs the operation __ne__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__ne__(asarray(other)._array)
return self.__class__._new(res)
@@ -375,6 +434,8 @@ class ndarray:
"""
Performs the operation __or__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__or__(asarray(other)._array)
return self.__class__._new(res)
@@ -389,6 +450,8 @@ class ndarray:
"""
Performs the operation __pow__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__pow__(asarray(other)._array)
return self.__class__._new(res)
@@ -396,6 +459,8 @@ class ndarray:
"""
Performs the operation __rshift__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__rshift__(asarray(other)._array)
return self.__class__._new(res)
@@ -413,6 +478,8 @@ class ndarray:
"""
Performs the operation __sub__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__sub__(asarray(other)._array)
return self.__class__._new(res)
@@ -420,6 +487,8 @@ class ndarray:
"""
Performs the operation __truediv__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__truediv__(asarray(other)._array)
return self.__class__._new(res)
@@ -427,6 +496,8 @@ class ndarray:
"""
Performs the operation __xor__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__xor__(asarray(other)._array)
return self.__class__._new(res)
@@ -434,6 +505,8 @@ class ndarray:
"""
Performs the operation __iadd__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__iadd__(asarray(other)._array)
return self.__class__._new(res)
@@ -441,6 +514,8 @@ class ndarray:
"""
Performs the operation __radd__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__radd__(asarray(other)._array)
return self.__class__._new(res)
@@ -448,6 +523,8 @@ class ndarray:
"""
Performs the operation __iand__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__iand__(asarray(other)._array)
return self.__class__._new(res)
@@ -455,6 +532,8 @@ class ndarray:
"""
Performs the operation __rand__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__rand__(asarray(other)._array)
return self.__class__._new(res)
@@ -462,6 +541,8 @@ class ndarray:
"""
Performs the operation __ifloordiv__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__ifloordiv__(asarray(other)._array)
return self.__class__._new(res)
@@ -469,6 +550,8 @@ class ndarray:
"""
Performs the operation __rfloordiv__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__rfloordiv__(asarray(other)._array)
return self.__class__._new(res)
@@ -476,6 +559,8 @@ class ndarray:
"""
Performs the operation __ilshift__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__ilshift__(asarray(other)._array)
return self.__class__._new(res)
@@ -483,6 +568,8 @@ class ndarray:
"""
Performs the operation __rlshift__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__rlshift__(asarray(other)._array)
return self.__class__._new(res)
@@ -490,6 +577,10 @@ class ndarray:
"""
Performs the operation __imatmul__.
"""
+ if isinstance(other, (int, float, bool)):
+ # matmul is not defined for scalars, but without this, we may get
+ # the wrong error message from asarray.
+ other = self._promote_scalar(other)
res = self._array.__imatmul__(asarray(other)._array)
return self.__class__._new(res)
@@ -497,6 +588,10 @@ class ndarray:
"""
Performs the operation __rmatmul__.
"""
+ if isinstance(other, (int, float, bool)):
+ # matmul is not defined for scalars, but without this, we may get
+ # the wrong error message from asarray.
+ other = self._promote_scalar(other)
res = self._array.__rmatmul__(asarray(other)._array)
return self.__class__._new(res)
@@ -504,6 +599,8 @@ class ndarray:
"""
Performs the operation __imod__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__imod__(asarray(other)._array)
return self.__class__._new(res)
@@ -511,6 +608,8 @@ class ndarray:
"""
Performs the operation __rmod__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__rmod__(asarray(other)._array)
return self.__class__._new(res)
@@ -518,6 +617,8 @@ class ndarray:
"""
Performs the operation __imul__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__imul__(asarray(other)._array)
return self.__class__._new(res)
@@ -525,6 +626,8 @@ class ndarray:
"""
Performs the operation __rmul__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__rmul__(asarray(other)._array)
return self.__class__._new(res)
@@ -532,6 +635,8 @@ class ndarray:
"""
Performs the operation __ior__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__ior__(asarray(other)._array)
return self.__class__._new(res)
@@ -539,6 +644,8 @@ class ndarray:
"""
Performs the operation __ror__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__ror__(asarray(other)._array)
return self.__class__._new(res)
@@ -546,6 +653,8 @@ class ndarray:
"""
Performs the operation __ipow__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__ipow__(asarray(other)._array)
return self.__class__._new(res)
@@ -553,6 +662,8 @@ class ndarray:
"""
Performs the operation __rpow__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__rpow__(asarray(other)._array)
return self.__class__._new(res)
@@ -560,6 +671,8 @@ class ndarray:
"""
Performs the operation __irshift__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__irshift__(asarray(other)._array)
return self.__class__._new(res)
@@ -567,6 +680,8 @@ class ndarray:
"""
Performs the operation __rrshift__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__rrshift__(asarray(other)._array)
return self.__class__._new(res)
@@ -574,6 +689,8 @@ class ndarray:
"""
Performs the operation __isub__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__isub__(asarray(other)._array)
return self.__class__._new(res)
@@ -581,6 +698,8 @@ class ndarray:
"""
Performs the operation __rsub__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__rsub__(asarray(other)._array)
return self.__class__._new(res)
@@ -588,6 +707,8 @@ class ndarray:
"""
Performs the operation __itruediv__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__itruediv__(asarray(other)._array)
return self.__class__._new(res)
@@ -595,6 +716,8 @@ class ndarray:
"""
Performs the operation __rtruediv__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__rtruediv__(asarray(other)._array)
return self.__class__._new(res)
@@ -602,6 +725,8 @@ class ndarray:
"""
Performs the operation __ixor__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__ixor__(asarray(other)._array)
return self.__class__._new(res)
@@ -609,6 +734,8 @@ class ndarray:
"""
Performs the operation __rxor__.
"""
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
res = self._array.__rxor__(asarray(other)._array)
return self.__class__._new(res)