diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-03-31 16:34:23 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-03-31 16:34:23 -0600 |
commit | 9fe4fc7ff7f477fc4aaad850f8d1841beb2924bc (patch) | |
tree | a5c49d08d8f3be89cd506762f5ef752fdf432e25 /numpy/_array_api/_array_object.py | |
parent | 7ce435c610fcd7fee01da9d9e7ff5c1ab4ae6ef6 (diff) | |
download | numpy-9fe4fc7ff7f477fc4aaad850f8d1841beb2924bc.tar.gz |
Make the array API follow the spec Python scalar promotion rules
Diffstat (limited to 'numpy/_array_api/_array_object.py')
-rw-r--r-- | numpy/_array_api/_array_object.py | 129 |
1 files changed, 128 insertions, 1 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index ad0cbc71e..c5acb5d1d 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -18,7 +18,7 @@ from __future__ import annotations import operator from enum import IntEnum from ._creation_functions import asarray -from ._dtypes import _boolean_dtypes, _integer_dtypes +from ._dtypes import _boolean_dtypes, _integer_dtypes, _floating_dtypes from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -83,6 +83,37 @@ class ndarray: """ return self._array.__repr__().replace('array', 'ndarray') + # Helper function to match the type promotion rules in the spec + def _promote_scalar(self, scalar): + """ + Returns a promoted version of a Python scalar appropiate for use with + operations on self. + + This may raise an OverflowError in cases where the scalar is an + integer that is too large to fit in a NumPy integer dtype, or + TypeError when the scalar type is incompatible with the dtype of self. + + Note: this helper function returns a NumPy array (NOT a NumPy array + API ndarray). + """ + if isinstance(scalar, bool): + if self.dtype not in _boolean_dtypes: + raise TypeError("Python bool scalars can only be promoted with bool arrays") + elif isinstance(scalar, int): + if self.dtype in _boolean_dtypes: + raise TypeError("Python int scalars cannot be promoted with bool arrays") + elif isinstance(scalar, float): + if self.dtype not in _floating_dtypes: + raise TypeError("Python float scalars can only be promoted with floating-point arrays.") + else: + raise TypeError("'scalar' must be a Python scalar") + + # Note: the spec only specifies integer-dtype/int promotion + # behavior for integers within the bounds of the integer dtype. + # Outside of those bounds we use the default NumPy behavior (either + # cast or raise OverflowError). + return np.array(scalar, self.dtype) + # Everything below this is required by the spec. def __abs__(self: array, /) -> array: @@ -96,6 +127,8 @@ class ndarray: """ Performs the operation __add__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__add__(asarray(other)._array) return self.__class__._new(res) @@ -103,6 +136,8 @@ class ndarray: """ Performs the operation __and__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__and__(asarray(other)._array) return self.__class__._new(res) @@ -140,6 +175,8 @@ class ndarray: """ Performs the operation __eq__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__eq__(asarray(other)._array) return self.__class__._new(res) @@ -157,6 +194,8 @@ class ndarray: """ Performs the operation __floordiv__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__floordiv__(asarray(other)._array) return self.__class__._new(res) @@ -164,6 +203,8 @@ class ndarray: """ Performs the operation __ge__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__ge__(asarray(other)._array) return self.__class__._new(res) @@ -288,6 +329,8 @@ class ndarray: """ Performs the operation __gt__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__gt__(asarray(other)._array) return self.__class__._new(res) @@ -312,6 +355,8 @@ class ndarray: """ Performs the operation __le__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__le__(asarray(other)._array) return self.__class__._new(res) @@ -326,6 +371,8 @@ class ndarray: """ Performs the operation __lshift__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__lshift__(asarray(other)._array) return self.__class__._new(res) @@ -333,6 +380,8 @@ class ndarray: """ Performs the operation __lt__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__lt__(asarray(other)._array) return self.__class__._new(res) @@ -340,6 +389,10 @@ class ndarray: """ Performs the operation __matmul__. """ + if isinstance(other, (int, float, bool)): + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._promote_scalar(other) res = self._array.__matmul__(asarray(other)._array) return self.__class__._new(res) @@ -347,6 +400,8 @@ class ndarray: """ Performs the operation __mod__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__mod__(asarray(other)._array) return self.__class__._new(res) @@ -354,6 +409,8 @@ class ndarray: """ Performs the operation __mul__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__mul__(asarray(other)._array) return self.__class__._new(res) @@ -361,6 +418,8 @@ class ndarray: """ Performs the operation __ne__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__ne__(asarray(other)._array) return self.__class__._new(res) @@ -375,6 +434,8 @@ class ndarray: """ Performs the operation __or__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__or__(asarray(other)._array) return self.__class__._new(res) @@ -389,6 +450,8 @@ class ndarray: """ Performs the operation __pow__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__pow__(asarray(other)._array) return self.__class__._new(res) @@ -396,6 +459,8 @@ class ndarray: """ Performs the operation __rshift__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__rshift__(asarray(other)._array) return self.__class__._new(res) @@ -413,6 +478,8 @@ class ndarray: """ Performs the operation __sub__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__sub__(asarray(other)._array) return self.__class__._new(res) @@ -420,6 +487,8 @@ class ndarray: """ Performs the operation __truediv__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__truediv__(asarray(other)._array) return self.__class__._new(res) @@ -427,6 +496,8 @@ class ndarray: """ Performs the operation __xor__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__xor__(asarray(other)._array) return self.__class__._new(res) @@ -434,6 +505,8 @@ class ndarray: """ Performs the operation __iadd__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__iadd__(asarray(other)._array) return self.__class__._new(res) @@ -441,6 +514,8 @@ class ndarray: """ Performs the operation __radd__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__radd__(asarray(other)._array) return self.__class__._new(res) @@ -448,6 +523,8 @@ class ndarray: """ Performs the operation __iand__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__iand__(asarray(other)._array) return self.__class__._new(res) @@ -455,6 +532,8 @@ class ndarray: """ Performs the operation __rand__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__rand__(asarray(other)._array) return self.__class__._new(res) @@ -462,6 +541,8 @@ class ndarray: """ Performs the operation __ifloordiv__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__ifloordiv__(asarray(other)._array) return self.__class__._new(res) @@ -469,6 +550,8 @@ class ndarray: """ Performs the operation __rfloordiv__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__rfloordiv__(asarray(other)._array) return self.__class__._new(res) @@ -476,6 +559,8 @@ class ndarray: """ Performs the operation __ilshift__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__ilshift__(asarray(other)._array) return self.__class__._new(res) @@ -483,6 +568,8 @@ class ndarray: """ Performs the operation __rlshift__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__rlshift__(asarray(other)._array) return self.__class__._new(res) @@ -490,6 +577,10 @@ class ndarray: """ Performs the operation __imatmul__. """ + if isinstance(other, (int, float, bool)): + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._promote_scalar(other) res = self._array.__imatmul__(asarray(other)._array) return self.__class__._new(res) @@ -497,6 +588,10 @@ class ndarray: """ Performs the operation __rmatmul__. """ + if isinstance(other, (int, float, bool)): + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._promote_scalar(other) res = self._array.__rmatmul__(asarray(other)._array) return self.__class__._new(res) @@ -504,6 +599,8 @@ class ndarray: """ Performs the operation __imod__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__imod__(asarray(other)._array) return self.__class__._new(res) @@ -511,6 +608,8 @@ class ndarray: """ Performs the operation __rmod__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__rmod__(asarray(other)._array) return self.__class__._new(res) @@ -518,6 +617,8 @@ class ndarray: """ Performs the operation __imul__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__imul__(asarray(other)._array) return self.__class__._new(res) @@ -525,6 +626,8 @@ class ndarray: """ Performs the operation __rmul__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__rmul__(asarray(other)._array) return self.__class__._new(res) @@ -532,6 +635,8 @@ class ndarray: """ Performs the operation __ior__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__ior__(asarray(other)._array) return self.__class__._new(res) @@ -539,6 +644,8 @@ class ndarray: """ Performs the operation __ror__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__ror__(asarray(other)._array) return self.__class__._new(res) @@ -546,6 +653,8 @@ class ndarray: """ Performs the operation __ipow__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__ipow__(asarray(other)._array) return self.__class__._new(res) @@ -553,6 +662,8 @@ class ndarray: """ Performs the operation __rpow__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__rpow__(asarray(other)._array) return self.__class__._new(res) @@ -560,6 +671,8 @@ class ndarray: """ Performs the operation __irshift__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__irshift__(asarray(other)._array) return self.__class__._new(res) @@ -567,6 +680,8 @@ class ndarray: """ Performs the operation __rrshift__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__rrshift__(asarray(other)._array) return self.__class__._new(res) @@ -574,6 +689,8 @@ class ndarray: """ Performs the operation __isub__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__isub__(asarray(other)._array) return self.__class__._new(res) @@ -581,6 +698,8 @@ class ndarray: """ Performs the operation __rsub__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__rsub__(asarray(other)._array) return self.__class__._new(res) @@ -588,6 +707,8 @@ class ndarray: """ Performs the operation __itruediv__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__itruediv__(asarray(other)._array) return self.__class__._new(res) @@ -595,6 +716,8 @@ class ndarray: """ Performs the operation __rtruediv__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__rtruediv__(asarray(other)._array) return self.__class__._new(res) @@ -602,6 +725,8 @@ class ndarray: """ Performs the operation __ixor__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__ixor__(asarray(other)._array) return self.__class__._new(res) @@ -609,6 +734,8 @@ class ndarray: """ Performs the operation __rxor__. """ + if isinstance(other, (int, float, bool)): + other = self._promote_scalar(other) res = self._array.__rxor__(asarray(other)._array) return self.__class__._new(res) |