summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-07-22 16:39:54 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-07-22 16:39:54 -0600
commit626567645b180179159fa1807e72b26d58ce20dd (patch)
tree4897551ac751bf3c7c60a1ee32635b8199b040f3 /numpy/_array_api
parent776b1171aa76cc912abafb8434850bc9d37bd482 (diff)
downloadnumpy-626567645b180179159fa1807e72b26d58ce20dd.tar.gz
Prevent unwanted type promotions everywhere in the array API namespace
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/_elementwise_functions.py48
-rw-r--r--numpy/_array_api/_linear_algebra_functions.py6
-rw-r--r--numpy/_array_api/_manipulation_functions.py5
-rw-r--r--numpy/_array_api/_searching_functions.py3
4 files changed, 60 insertions, 2 deletions
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py
index c07c32de7..67fb7034d 100644
--- a/numpy/_array_api/_elementwise_functions.py
+++ b/numpy/_array_api/_elementwise_functions.py
@@ -2,7 +2,7 @@ from __future__ import annotations
from ._dtypes import (_boolean_dtypes, _floating_dtypes,
_integer_dtypes, _integer_or_boolean_dtypes,
- _numeric_dtypes)
+ _numeric_dtypes, _result_type)
from ._array_object import Array
import numpy as np
@@ -47,6 +47,8 @@ def add(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in add')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.add(x1._array, x2._array))
@@ -92,6 +94,8 @@ def atan2(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
raise TypeError('Only floating-point dtypes are allowed in atan2')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.arctan2(x1._array, x2._array))
@@ -114,6 +118,8 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
raise TypeError('Only integer or boolean dtypes are allowed in bitwise_and')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.bitwise_and(x1._array, x2._array))
@@ -126,6 +132,8 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
raise TypeError('Only integer dtypes are allowed in bitwise_left_shift')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
# Note: bitwise_left_shift is only defined for x2 nonnegative.
if np.any(x2._array < 0):
@@ -151,6 +159,8 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
raise TypeError('Only integer or boolean dtypes are allowed in bitwise_or')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.bitwise_or(x1._array, x2._array))
@@ -163,6 +173,8 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
raise TypeError('Only integer dtypes are allowed in bitwise_right_shift')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
# Note: bitwise_right_shift is only defined for x2 nonnegative.
if np.any(x2._array < 0):
@@ -177,6 +189,8 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
raise TypeError('Only integer or boolean dtypes are allowed in bitwise_xor')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.bitwise_xor(x1._array, x2._array))
@@ -221,6 +235,8 @@ def divide(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
raise TypeError('Only floating-point dtypes are allowed in divide')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.divide(x1._array, x2._array))
@@ -230,6 +246,8 @@ def equal(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.equal(x1._array, x2._array))
@@ -274,6 +292,8 @@ def floor_divide(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in floor_divide')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.floor_divide(x1._array, x2._array))
@@ -285,6 +305,8 @@ def greater(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in greater')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.greater(x1._array, x2._array))
@@ -296,6 +318,8 @@ def greater_equal(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in greater_equal')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.greater_equal(x1._array, x2._array))
@@ -337,6 +361,8 @@ def less(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in less')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.less(x1._array, x2._array))
@@ -348,6 +374,8 @@ def less_equal(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in less_equal')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.less_equal(x1._array, x2._array))
@@ -399,6 +427,8 @@ def logaddexp(x1: Array, x2: Array) -> Array:
"""
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
raise TypeError('Only floating-point dtypes are allowed in logaddexp')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.logaddexp(x1._array, x2._array))
@@ -410,6 +440,8 @@ def logical_and(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
raise TypeError('Only boolean dtypes are allowed in logical_and')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.logical_and(x1._array, x2._array))
@@ -431,6 +463,8 @@ def logical_or(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
raise TypeError('Only boolean dtypes are allowed in logical_or')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.logical_or(x1._array, x2._array))
@@ -442,6 +476,8 @@ def logical_xor(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
raise TypeError('Only boolean dtypes are allowed in logical_xor')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.logical_xor(x1._array, x2._array))
@@ -453,6 +489,8 @@ def multiply(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in multiply')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.multiply(x1._array, x2._array))
@@ -472,6 +510,8 @@ def not_equal(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.not_equal(x1._array, x2._array))
@@ -494,6 +534,8 @@ def pow(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
raise TypeError('Only floating-point dtypes are allowed in pow')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.power(x1._array, x2._array))
@@ -505,6 +547,8 @@ def remainder(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in remainder')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.remainder(x1._array, x2._array))
@@ -576,6 +620,8 @@ def subtract(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in subtract')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.subtract(x1._array, x2._array))
diff --git a/numpy/_array_api/_linear_algebra_functions.py b/numpy/_array_api/_linear_algebra_functions.py
index b4b2af134..f13f9c541 100644
--- a/numpy/_array_api/_linear_algebra_functions.py
+++ b/numpy/_array_api/_linear_algebra_functions.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from ._array_object import Array
-from ._dtypes import _numeric_dtypes
+from ._dtypes import _numeric_dtypes, _result_type
from typing import Optional, Sequence, Tuple, Union
@@ -27,6 +27,8 @@ def matmul(x1: Array, x2: Array, /) -> Array:
# np.matmul.
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in matmul')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
return Array._new(np.matmul(x1._array, x2._array))
@@ -36,6 +38,8 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int],
# np.tensordot.
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in tensordot')
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
diff --git a/numpy/_array_api/_manipulation_functions.py b/numpy/_array_api/_manipulation_functions.py
index 6308bfc26..fa6344beb 100644
--- a/numpy/_array_api/_manipulation_functions.py
+++ b/numpy/_array_api/_manipulation_functions.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from ._array_object import Array
+from ._data_type_functions import result_type
from typing import List, Optional, Tuple, Union
@@ -14,6 +15,8 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i
See its docstring for more information.
"""
arrays = tuple(a._array for a in arrays)
+ # Call result type here just to raise on disallowed type combinations
+ result_type(*arrays)
return Array._new(np.concatenate(arrays, axis=axis))
def expand_dims(x: Array, /, *, axis: int) -> Array:
@@ -63,4 +66,6 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->
See its docstring for more information.
"""
arrays = tuple(a._array for a in arrays)
+ # Call result type here just to raise on disallowed type combinations
+ result_type(*arrays)
return Array._new(np.stack(arrays, axis=axis))
diff --git a/numpy/_array_api/_searching_functions.py b/numpy/_array_api/_searching_functions.py
index 4764992a1..d80720850 100644
--- a/numpy/_array_api/_searching_functions.py
+++ b/numpy/_array_api/_searching_functions.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from ._array_object import Array
+from ._dtypes import _result_type
from typing import Optional, Tuple
@@ -38,4 +39,6 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
return Array._new(np.where(condition._array, x1._array, x2._array))