diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-07-22 16:39:54 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-07-22 16:39:54 -0600 |
commit | 626567645b180179159fa1807e72b26d58ce20dd (patch) | |
tree | 4897551ac751bf3c7c60a1ee32635b8199b040f3 /numpy/_array_api | |
parent | 776b1171aa76cc912abafb8434850bc9d37bd482 (diff) | |
download | numpy-626567645b180179159fa1807e72b26d58ce20dd.tar.gz |
Prevent unwanted type promotions everywhere in the array API namespace
Diffstat (limited to 'numpy/_array_api')
-rw-r--r-- | numpy/_array_api/_elementwise_functions.py | 48 | ||||
-rw-r--r-- | numpy/_array_api/_linear_algebra_functions.py | 6 | ||||
-rw-r--r-- | numpy/_array_api/_manipulation_functions.py | 5 | ||||
-rw-r--r-- | numpy/_array_api/_searching_functions.py | 3 |
4 files changed, 60 insertions, 2 deletions
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py index c07c32de7..67fb7034d 100644 --- a/numpy/_array_api/_elementwise_functions.py +++ b/numpy/_array_api/_elementwise_functions.py @@ -2,7 +2,7 @@ from __future__ import annotations from ._dtypes import (_boolean_dtypes, _floating_dtypes, _integer_dtypes, _integer_or_boolean_dtypes, - _numeric_dtypes) + _numeric_dtypes, _result_type) from ._array_object import Array import numpy as np @@ -47,6 +47,8 @@ def add(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in add') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.add(x1._array, x2._array)) @@ -92,6 +94,8 @@ def atan2(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in atan2') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.arctan2(x1._array, x2._array)) @@ -114,6 +118,8 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: raise TypeError('Only integer or boolean dtypes are allowed in bitwise_and') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.bitwise_and(x1._array, x2._array)) @@ -126,6 +132,8 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError('Only integer dtypes are allowed in bitwise_left_shift') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) # Note: bitwise_left_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): @@ -151,6 +159,8 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: raise TypeError('Only integer or boolean dtypes are allowed in bitwise_or') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.bitwise_or(x1._array, x2._array)) @@ -163,6 +173,8 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError('Only integer dtypes are allowed in bitwise_right_shift') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) # Note: bitwise_right_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): @@ -177,6 +189,8 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: raise TypeError('Only integer or boolean dtypes are allowed in bitwise_xor') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.bitwise_xor(x1._array, x2._array)) @@ -221,6 +235,8 @@ def divide(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in divide') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.divide(x1._array, x2._array)) @@ -230,6 +246,8 @@ def equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.equal(x1._array, x2._array)) @@ -274,6 +292,8 @@ def floor_divide(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in floor_divide') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.floor_divide(x1._array, x2._array)) @@ -285,6 +305,8 @@ def greater(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in greater') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater(x1._array, x2._array)) @@ -296,6 +318,8 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in greater_equal') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater_equal(x1._array, x2._array)) @@ -337,6 +361,8 @@ def less(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in less') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less(x1._array, x2._array)) @@ -348,6 +374,8 @@ def less_equal(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in less_equal') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less_equal(x1._array, x2._array)) @@ -399,6 +427,8 @@ def logaddexp(x1: Array, x2: Array) -> Array: """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in logaddexp') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logaddexp(x1._array, x2._array)) @@ -410,6 +440,8 @@ def logical_and(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError('Only boolean dtypes are allowed in logical_and') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logical_and(x1._array, x2._array)) @@ -431,6 +463,8 @@ def logical_or(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError('Only boolean dtypes are allowed in logical_or') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logical_or(x1._array, x2._array)) @@ -442,6 +476,8 @@ def logical_xor(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError('Only boolean dtypes are allowed in logical_xor') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logical_xor(x1._array, x2._array)) @@ -453,6 +489,8 @@ def multiply(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in multiply') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.multiply(x1._array, x2._array)) @@ -472,6 +510,8 @@ def not_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.not_equal(x1._array, x2._array)) @@ -494,6 +534,8 @@ def pow(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in pow') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.power(x1._array, x2._array)) @@ -505,6 +547,8 @@ def remainder(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in remainder') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.remainder(x1._array, x2._array)) @@ -576,6 +620,8 @@ def subtract(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in subtract') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.subtract(x1._array, x2._array)) diff --git a/numpy/_array_api/_linear_algebra_functions.py b/numpy/_array_api/_linear_algebra_functions.py index b4b2af134..f13f9c541 100644 --- a/numpy/_array_api/_linear_algebra_functions.py +++ b/numpy/_array_api/_linear_algebra_functions.py @@ -1,7 +1,7 @@ from __future__ import annotations from ._array_object import Array -from ._dtypes import _numeric_dtypes +from ._dtypes import _numeric_dtypes, _result_type from typing import Optional, Sequence, Tuple, Union @@ -27,6 +27,8 @@ def matmul(x1: Array, x2: Array, /) -> Array: # np.matmul. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in matmul') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) return Array._new(np.matmul(x1._array, x2._array)) @@ -36,6 +38,8 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], # np.tensordot. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in tensordot') + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) diff --git a/numpy/_array_api/_manipulation_functions.py b/numpy/_array_api/_manipulation_functions.py index 6308bfc26..fa6344beb 100644 --- a/numpy/_array_api/_manipulation_functions.py +++ b/numpy/_array_api/_manipulation_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._array_object import Array +from ._data_type_functions import result_type from typing import List, Optional, Tuple, Union @@ -14,6 +15,8 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i See its docstring for more information. """ arrays = tuple(a._array for a in arrays) + # Call result type here just to raise on disallowed type combinations + result_type(*arrays) return Array._new(np.concatenate(arrays, axis=axis)) def expand_dims(x: Array, /, *, axis: int) -> Array: @@ -63,4 +66,6 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> See its docstring for more information. """ arrays = tuple(a._array for a in arrays) + # Call result type here just to raise on disallowed type combinations + result_type(*arrays) return Array._new(np.stack(arrays, axis=axis)) diff --git a/numpy/_array_api/_searching_functions.py b/numpy/_array_api/_searching_functions.py index 4764992a1..d80720850 100644 --- a/numpy/_array_api/_searching_functions.py +++ b/numpy/_array_api/_searching_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._array_object import Array +from ._dtypes import _result_type from typing import Optional, Tuple @@ -38,4 +39,6 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) return Array._new(np.where(condition._array, x1._array, x2._array)) |