diff options
Diffstat (limited to 'numpy/_array_api/_elementwise_functions.py')
-rw-r--r-- | numpy/_array_api/_elementwise_functions.py | 114 |
1 files changed, 114 insertions, 0 deletions
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py index abb7ef4dd..2357b337c 100644 --- a/numpy/_array_api/_elementwise_functions.py +++ b/numpy/_array_api/_elementwise_functions.py @@ -1,5 +1,7 @@ from __future__ import annotations +from ._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, + _integer_dtypes, _integer_or_boolean_dtypes, _numeric_dtypes) from ._types import array from ._array_object import ndarray @@ -11,6 +13,8 @@ def abs(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in abs') return ndarray._new(np.abs(x._array)) def acos(x: array, /) -> array: @@ -19,6 +23,8 @@ def acos(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in acos') # Note: the function name is different here return ndarray._new(np.arccos(x._array)) @@ -28,6 +34,8 @@ def acosh(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in acosh') # Note: the function name is different here return ndarray._new(np.arccosh(x._array)) @@ -37,6 +45,8 @@ def add(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in add') return ndarray._new(np.add(x1._array, x2._array)) def asin(x: array, /) -> array: @@ -45,6 +55,8 @@ def asin(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in asin') # Note: the function name is different here return ndarray._new(np.arcsin(x._array)) @@ -54,6 +66,8 @@ def asinh(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in asinh') # Note: the function name is different here return ndarray._new(np.arcsinh(x._array)) @@ -63,6 +77,8 @@ def atan(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in atan') # Note: the function name is different here return ndarray._new(np.arctan(x._array)) @@ -72,6 +88,8 @@ def atan2(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in atan2') # Note: the function name is different here return ndarray._new(np.arctan2(x1._array, x2._array)) @@ -81,6 +99,8 @@ def atanh(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in atanh') # Note: the function name is different here return ndarray._new(np.arctanh(x._array)) @@ -90,6 +110,8 @@ def bitwise_and(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer_or_boolean dtypes are allowed in bitwise_and') return ndarray._new(np.bitwise_and(x1._array, x2._array)) def bitwise_left_shift(x1: array, x2: array, /) -> array: @@ -98,6 +120,8 @@ def bitwise_left_shift(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: + raise TypeError('Only integer dtypes are allowed in bitwise_left_shift') # Note: the function name is different here return ndarray._new(np.left_shift(x1._array, x2._array)) @@ -107,6 +131,8 @@ def bitwise_invert(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in bitwise_invert') # Note: the function name is different here return ndarray._new(np.invert(x._array)) @@ -116,6 +142,8 @@ def bitwise_or(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in bitwise_or') return ndarray._new(np.bitwise_or(x1._array, x2._array)) def bitwise_right_shift(x1: array, x2: array, /) -> array: @@ -124,6 +152,8 @@ def bitwise_right_shift(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: + raise TypeError('Only integer dtypes are allowed in bitwise_right_shift') # Note: the function name is different here return ndarray._new(np.right_shift(x1._array, x2._array)) @@ -133,6 +163,8 @@ def bitwise_xor(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in bitwise_xor') return ndarray._new(np.bitwise_xor(x1._array, x2._array)) def ceil(x: array, /) -> array: @@ -141,6 +173,8 @@ def ceil(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in ceil') return ndarray._new(np.ceil(x._array)) def cos(x: array, /) -> array: @@ -149,6 +183,8 @@ def cos(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in cos') return ndarray._new(np.cos(x._array)) def cosh(x: array, /) -> array: @@ -157,6 +193,8 @@ def cosh(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in cosh') return ndarray._new(np.cosh(x._array)) def divide(x1: array, x2: array, /) -> array: @@ -165,6 +203,8 @@ def divide(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in divide') return ndarray._new(np.divide(x1._array, x2._array)) def equal(x1: array, x2: array, /) -> array: @@ -173,6 +213,8 @@ def equal(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _all_dtypes or x2.dtype not in _all_dtypes: + raise TypeError('Only array API spec dtypes are allowed in equal') return ndarray._new(np.equal(x1._array, x2._array)) def exp(x: array, /) -> array: @@ -181,6 +223,8 @@ def exp(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in exp') return ndarray._new(np.exp(x._array)) def expm1(x: array, /) -> array: @@ -189,6 +233,8 @@ def expm1(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in expm1') return ndarray._new(np.expm1(x._array)) def floor(x: array, /) -> array: @@ -197,6 +243,8 @@ def floor(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in floor') return ndarray._new(np.floor(x._array)) def floor_divide(x1: array, x2: array, /) -> array: @@ -205,6 +253,8 @@ def floor_divide(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in floor_divide') return ndarray._new(np.floor_divide(x1._array, x2._array)) def greater(x1: array, x2: array, /) -> array: @@ -213,6 +263,8 @@ def greater(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in greater') return ndarray._new(np.greater(x1._array, x2._array)) def greater_equal(x1: array, x2: array, /) -> array: @@ -221,6 +273,8 @@ def greater_equal(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in greater_equal') return ndarray._new(np.greater_equal(x1._array, x2._array)) def isfinite(x: array, /) -> array: @@ -229,6 +283,8 @@ def isfinite(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in isfinite') return ndarray._new(np.isfinite(x._array)) def isinf(x: array, /) -> array: @@ -237,6 +293,8 @@ def isinf(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in isinf') return ndarray._new(np.isinf(x._array)) def isnan(x: array, /) -> array: @@ -245,6 +303,8 @@ def isnan(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in isnan') return ndarray._new(np.isnan(x._array)) def less(x1: array, x2: array, /) -> array: @@ -253,6 +313,8 @@ def less(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in less') return ndarray._new(np.less(x1._array, x2._array)) def less_equal(x1: array, x2: array, /) -> array: @@ -261,6 +323,8 @@ def less_equal(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in less_equal') return ndarray._new(np.less_equal(x1._array, x2._array)) def log(x: array, /) -> array: @@ -269,6 +333,8 @@ def log(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in log') return ndarray._new(np.log(x._array)) def log1p(x: array, /) -> array: @@ -277,6 +343,8 @@ def log1p(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in log1p') return ndarray._new(np.log1p(x._array)) def log2(x: array, /) -> array: @@ -285,6 +353,8 @@ def log2(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in log2') return ndarray._new(np.log2(x._array)) def log10(x: array, /) -> array: @@ -293,6 +363,8 @@ def log10(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in log10') return ndarray._new(np.log10(x._array)) def logaddexp(x1: array, x2: array) -> array: @@ -301,6 +373,8 @@ def logaddexp(x1: array, x2: array) -> array: See its docstring for more information. """ + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in logaddexp') return ndarray._new(np.logaddexp(x1._array, x2._array)) def logical_and(x1: array, x2: array, /) -> array: @@ -309,6 +383,8 @@ def logical_and(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: + raise TypeError('Only boolean dtypes are allowed in logical_and') return ndarray._new(np.logical_and(x1._array, x2._array)) def logical_not(x: array, /) -> array: @@ -317,6 +393,8 @@ def logical_not(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _boolean_dtypes: + raise TypeError('Only boolean dtypes are allowed in logical_not') return ndarray._new(np.logical_not(x._array)) def logical_or(x1: array, x2: array, /) -> array: @@ -325,6 +403,8 @@ def logical_or(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: + raise TypeError('Only boolean dtypes are allowed in logical_or') return ndarray._new(np.logical_or(x1._array, x2._array)) def logical_xor(x1: array, x2: array, /) -> array: @@ -333,6 +413,8 @@ def logical_xor(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: + raise TypeError('Only boolean dtypes are allowed in logical_xor') return ndarray._new(np.logical_xor(x1._array, x2._array)) def multiply(x1: array, x2: array, /) -> array: @@ -341,6 +423,8 @@ def multiply(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in multiply') return ndarray._new(np.multiply(x1._array, x2._array)) def negative(x: array, /) -> array: @@ -349,6 +433,8 @@ def negative(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in negative') return ndarray._new(np.negative(x._array)) def not_equal(x1: array, x2: array, /) -> array: @@ -357,6 +443,8 @@ def not_equal(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _all_dtypes or x2.dtype not in _all_dtypes: + raise TypeError('Only array API spec dtypes are allowed in not_equal') return ndarray._new(np.not_equal(x1._array, x2._array)) def positive(x: array, /) -> array: @@ -365,6 +453,8 @@ def positive(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in positive') return ndarray._new(np.positive(x._array)) def pow(x1: array, x2: array, /) -> array: @@ -373,6 +463,8 @@ def pow(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in pow') # Note: the function name is different here return ndarray._new(np.power(x1._array, x2._array)) @@ -382,6 +474,8 @@ def remainder(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in remainder') return ndarray._new(np.remainder(x1._array, x2._array)) def round(x: array, /) -> array: @@ -390,6 +484,8 @@ def round(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in round') return ndarray._new(np.round._implementation(x._array)) def sign(x: array, /) -> array: @@ -398,6 +494,8 @@ def sign(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in sign') return ndarray._new(np.sign(x._array)) def sin(x: array, /) -> array: @@ -406,6 +504,8 @@ def sin(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in sin') return ndarray._new(np.sin(x._array)) def sinh(x: array, /) -> array: @@ -414,6 +514,8 @@ def sinh(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in sinh') return ndarray._new(np.sinh(x._array)) def square(x: array, /) -> array: @@ -422,6 +524,8 @@ def square(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in square') return ndarray._new(np.square(x._array)) def sqrt(x: array, /) -> array: @@ -430,6 +534,8 @@ def sqrt(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in sqrt') return ndarray._new(np.sqrt(x._array)) def subtract(x1: array, x2: array, /) -> array: @@ -438,6 +544,8 @@ def subtract(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in subtract') return ndarray._new(np.subtract(x1._array, x2._array)) def tan(x: array, /) -> array: @@ -446,6 +554,8 @@ def tan(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in tan') return ndarray._new(np.tan(x._array)) def tanh(x: array, /) -> array: @@ -454,6 +564,8 @@ def tanh(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in tanh') return ndarray._new(np.tanh(x._array)) def trunc(x: array, /) -> array: @@ -462,4 +574,6 @@ def trunc(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in trunc') return ndarray._new(np.trunc(x._array)) |