diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-04-15 15:39:14 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-04-15 15:39:14 -0600 |
commit | 6c196f540429aa0869e8fb66917a5e76447d2c02 (patch) | |
tree | 593b0e4602865f0c07519c061eee6f5681316b8c /numpy/_array_api/_elementwise_functions.py | |
parent | 9af1cc60edd4fdbb7e9c18d124e639a44ce420c7 (diff) | |
download | numpy-6c196f540429aa0869e8fb66917a5e76447d2c02.tar.gz |
Fix type promotion consistency for the array API elementwise functions and operators
NumPy's type promotion behavior deviates from the spec, which says that type
promotion should work independently of shapes or values, in cases where one
array is 0-d and the other is not. A helper function is added that works
around this issue by adding a dimension to the 0-d array before passing it to
the NumPy function. This function is used in elementwise functions and
operators. It may still need to be applied to other functions in the
namespace.
Additionally, this fixes:
- The shift operators (<< and >>) should always return the same dtype as the
first argument.
- NumPy's __pow__ does not type promote the two arguments, so we use the array
API pow() in ndarray.__pow__, which does.
- The internal _promote_scalar helper function was changed to return an array
API ndarray object, as this is simpler with the inclusion of the new
_normalize_two_args helper in the operators.
Diffstat (limited to 'numpy/_array_api/_elementwise_functions.py')
-rw-r--r-- | numpy/_array_api/_elementwise_functions.py | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py index 5ee33f60f..cb855da12 100644 --- a/numpy/_array_api/_elementwise_functions.py +++ b/numpy/_array_api/_elementwise_functions.py @@ -1,7 +1,8 @@ from __future__ import annotations from ._dtypes import (_boolean_dtypes, _floating_dtypes, - _integer_dtypes, _integer_or_boolean_dtypes, _numeric_dtypes) + _integer_dtypes, _integer_or_boolean_dtypes, + _numeric_dtypes) from ._array_object import ndarray from typing import TYPE_CHECKING @@ -50,6 +51,7 @@ def add(x1: array, x2: array, /) -> array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in add') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.add(x1._array, x2._array)) # Note: the function name is different here @@ -94,6 +96,7 @@ def atan2(x1: array, x2: array, /) -> array: """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in atan2') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.arctan2(x1._array, x2._array)) # Note: the function name is different here @@ -115,6 +118,7 @@ def bitwise_and(x1: array, x2: array, /) -> array: """ if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: raise TypeError('Only integer_or_boolean dtypes are allowed in bitwise_and') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.bitwise_and(x1._array, x2._array)) # Note: the function name is different here @@ -126,6 +130,7 @@ def bitwise_left_shift(x1: array, x2: array, /) -> array: """ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError('Only integer dtypes are allowed in bitwise_left_shift') + x1, x2 = ndarray._normalize_two_args(x1, x2) # Note: bitwise_left_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): raise ValueError('bitwise_left_shift(x1, x2) is only defined for x2 >= 0') @@ -153,6 +158,7 @@ def bitwise_or(x1: array, x2: array, /) -> array: """ if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: raise TypeError('Only integer or boolean dtypes are allowed in bitwise_or') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.bitwise_or(x1._array, x2._array)) # Note: the function name is different here @@ -164,6 +170,7 @@ def bitwise_right_shift(x1: array, x2: array, /) -> array: """ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError('Only integer dtypes are allowed in bitwise_right_shift') + x1, x2 = ndarray._normalize_two_args(x1, x2) # Note: bitwise_right_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): raise ValueError('bitwise_right_shift(x1, x2) is only defined for x2 >= 0') @@ -180,6 +187,7 @@ def bitwise_xor(x1: array, x2: array, /) -> array: """ if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: raise TypeError('Only integer or boolean dtypes are allowed in bitwise_xor') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.bitwise_xor(x1._array, x2._array)) def ceil(x: array, /) -> array: @@ -223,6 +231,7 @@ def divide(x1: array, x2: array, /) -> array: """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in divide') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.divide(x1._array, x2._array)) def equal(x1: array, x2: array, /) -> array: @@ -231,6 +240,7 @@ def equal(x1: array, x2: array, /) -> array: See its docstring for more information. """ + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.equal(x1._array, x2._array)) def exp(x: array, /) -> array: @@ -274,6 +284,7 @@ def floor_divide(x1: array, x2: array, /) -> array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in floor_divide') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.floor_divide(x1._array, x2._array)) def greater(x1: array, x2: array, /) -> array: @@ -284,6 +295,7 @@ def greater(x1: array, x2: array, /) -> array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in greater') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.greater(x1._array, x2._array)) def greater_equal(x1: array, x2: array, /) -> array: @@ -294,6 +306,7 @@ def greater_equal(x1: array, x2: array, /) -> array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in greater_equal') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.greater_equal(x1._array, x2._array)) def isfinite(x: array, /) -> array: @@ -334,6 +347,7 @@ def less(x1: array, x2: array, /) -> array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in less') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.less(x1._array, x2._array)) def less_equal(x1: array, x2: array, /) -> array: @@ -344,6 +358,7 @@ def less_equal(x1: array, x2: array, /) -> array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in less_equal') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.less_equal(x1._array, x2._array)) def log(x: array, /) -> array: @@ -394,6 +409,7 @@ def logaddexp(x1: array, x2: array) -> array: """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in logaddexp') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.logaddexp(x1._array, x2._array)) def logical_and(x1: array, x2: array, /) -> array: @@ -404,6 +420,7 @@ def logical_and(x1: array, x2: array, /) -> array: """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError('Only boolean dtypes are allowed in logical_and') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.logical_and(x1._array, x2._array)) def logical_not(x: array, /) -> array: @@ -424,6 +441,7 @@ def logical_or(x1: array, x2: array, /) -> array: """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError('Only boolean dtypes are allowed in logical_or') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.logical_or(x1._array, x2._array)) def logical_xor(x1: array, x2: array, /) -> array: @@ -434,6 +452,7 @@ def logical_xor(x1: array, x2: array, /) -> array: """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError('Only boolean dtypes are allowed in logical_xor') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.logical_xor(x1._array, x2._array)) def multiply(x1: array, x2: array, /) -> array: @@ -444,6 +463,7 @@ def multiply(x1: array, x2: array, /) -> array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in multiply') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.multiply(x1._array, x2._array)) def negative(x: array, /) -> array: @@ -462,6 +482,7 @@ def not_equal(x1: array, x2: array, /) -> array: See its docstring for more information. """ + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.not_equal(x1._array, x2._array)) def positive(x: array, /) -> array: @@ -483,6 +504,7 @@ def pow(x1: array, x2: array, /) -> array: """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in pow') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.power(x1._array, x2._array)) def remainder(x1: array, x2: array, /) -> array: @@ -493,6 +515,7 @@ def remainder(x1: array, x2: array, /) -> array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in remainder') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.remainder(x1._array, x2._array)) def round(x: array, /) -> array: @@ -563,6 +586,7 @@ def subtract(x1: array, x2: array, /) -> array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in subtract') + x1, x2 = ndarray._normalize_two_args(x1, x2) return ndarray._new(np.subtract(x1._array, x2._array)) def tan(x: array, /) -> array: |