summaryrefslogtreecommitdiff
path: root/numpy/_array_api/_elementwise_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/_array_api/_elementwise_functions.py')
-rw-r--r--numpy/_array_api/_elementwise_functions.py114
1 files changed, 114 insertions, 0 deletions
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py
index abb7ef4dd..2357b337c 100644
--- a/numpy/_array_api/_elementwise_functions.py
+++ b/numpy/_array_api/_elementwise_functions.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+from ._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes,
+ _integer_dtypes, _integer_or_boolean_dtypes, _numeric_dtypes)
from ._types import array
from ._array_object import ndarray
@@ -11,6 +13,8 @@ def abs(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in abs')
return ndarray._new(np.abs(x._array))
def acos(x: array, /) -> array:
@@ -19,6 +23,8 @@ def acos(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in acos')
# Note: the function name is different here
return ndarray._new(np.arccos(x._array))
@@ -28,6 +34,8 @@ def acosh(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in acosh')
# Note: the function name is different here
return ndarray._new(np.arccosh(x._array))
@@ -37,6 +45,8 @@ def add(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in add')
return ndarray._new(np.add(x1._array, x2._array))
def asin(x: array, /) -> array:
@@ -45,6 +55,8 @@ def asin(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in asin')
# Note: the function name is different here
return ndarray._new(np.arcsin(x._array))
@@ -54,6 +66,8 @@ def asinh(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in asinh')
# Note: the function name is different here
return ndarray._new(np.arcsinh(x._array))
@@ -63,6 +77,8 @@ def atan(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in atan')
# Note: the function name is different here
return ndarray._new(np.arctan(x._array))
@@ -72,6 +88,8 @@ def atan2(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in atan2')
# Note: the function name is different here
return ndarray._new(np.arctan2(x1._array, x2._array))
@@ -81,6 +99,8 @@ def atanh(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in atanh')
# Note: the function name is different here
return ndarray._new(np.arctanh(x._array))
@@ -90,6 +110,8 @@ def bitwise_and(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
+ raise TypeError('Only integer_or_boolean dtypes are allowed in bitwise_and')
return ndarray._new(np.bitwise_and(x1._array, x2._array))
def bitwise_left_shift(x1: array, x2: array, /) -> array:
@@ -98,6 +120,8 @@ def bitwise_left_shift(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
+ raise TypeError('Only integer dtypes are allowed in bitwise_left_shift')
# Note: the function name is different here
return ndarray._new(np.left_shift(x1._array, x2._array))
@@ -107,6 +131,8 @@ def bitwise_invert(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _integer_or_boolean_dtypes:
+ raise TypeError('Only integer or boolean dtypes are allowed in bitwise_invert')
# Note: the function name is different here
return ndarray._new(np.invert(x._array))
@@ -116,6 +142,8 @@ def bitwise_or(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
+ raise TypeError('Only integer or boolean dtypes are allowed in bitwise_or')
return ndarray._new(np.bitwise_or(x1._array, x2._array))
def bitwise_right_shift(x1: array, x2: array, /) -> array:
@@ -124,6 +152,8 @@ def bitwise_right_shift(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
+ raise TypeError('Only integer dtypes are allowed in bitwise_right_shift')
# Note: the function name is different here
return ndarray._new(np.right_shift(x1._array, x2._array))
@@ -133,6 +163,8 @@ def bitwise_xor(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
+ raise TypeError('Only integer or boolean dtypes are allowed in bitwise_xor')
return ndarray._new(np.bitwise_xor(x1._array, x2._array))
def ceil(x: array, /) -> array:
@@ -141,6 +173,8 @@ def ceil(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in ceil')
return ndarray._new(np.ceil(x._array))
def cos(x: array, /) -> array:
@@ -149,6 +183,8 @@ def cos(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in cos')
return ndarray._new(np.cos(x._array))
def cosh(x: array, /) -> array:
@@ -157,6 +193,8 @@ def cosh(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in cosh')
return ndarray._new(np.cosh(x._array))
def divide(x1: array, x2: array, /) -> array:
@@ -165,6 +203,8 @@ def divide(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in divide')
return ndarray._new(np.divide(x1._array, x2._array))
def equal(x1: array, x2: array, /) -> array:
@@ -173,6 +213,8 @@ def equal(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _all_dtypes or x2.dtype not in _all_dtypes:
+ raise TypeError('Only array API spec dtypes are allowed in equal')
return ndarray._new(np.equal(x1._array, x2._array))
def exp(x: array, /) -> array:
@@ -181,6 +223,8 @@ def exp(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in exp')
return ndarray._new(np.exp(x._array))
def expm1(x: array, /) -> array:
@@ -189,6 +233,8 @@ def expm1(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in expm1')
return ndarray._new(np.expm1(x._array))
def floor(x: array, /) -> array:
@@ -197,6 +243,8 @@ def floor(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in floor')
return ndarray._new(np.floor(x._array))
def floor_divide(x1: array, x2: array, /) -> array:
@@ -205,6 +253,8 @@ def floor_divide(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in floor_divide')
return ndarray._new(np.floor_divide(x1._array, x2._array))
def greater(x1: array, x2: array, /) -> array:
@@ -213,6 +263,8 @@ def greater(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in greater')
return ndarray._new(np.greater(x1._array, x2._array))
def greater_equal(x1: array, x2: array, /) -> array:
@@ -221,6 +273,8 @@ def greater_equal(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in greater_equal')
return ndarray._new(np.greater_equal(x1._array, x2._array))
def isfinite(x: array, /) -> array:
@@ -229,6 +283,8 @@ def isfinite(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in isfinite')
return ndarray._new(np.isfinite(x._array))
def isinf(x: array, /) -> array:
@@ -237,6 +293,8 @@ def isinf(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in isinf')
return ndarray._new(np.isinf(x._array))
def isnan(x: array, /) -> array:
@@ -245,6 +303,8 @@ def isnan(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in isnan')
return ndarray._new(np.isnan(x._array))
def less(x1: array, x2: array, /) -> array:
@@ -253,6 +313,8 @@ def less(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in less')
return ndarray._new(np.less(x1._array, x2._array))
def less_equal(x1: array, x2: array, /) -> array:
@@ -261,6 +323,8 @@ def less_equal(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in less_equal')
return ndarray._new(np.less_equal(x1._array, x2._array))
def log(x: array, /) -> array:
@@ -269,6 +333,8 @@ def log(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in log')
return ndarray._new(np.log(x._array))
def log1p(x: array, /) -> array:
@@ -277,6 +343,8 @@ def log1p(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in log1p')
return ndarray._new(np.log1p(x._array))
def log2(x: array, /) -> array:
@@ -285,6 +353,8 @@ def log2(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in log2')
return ndarray._new(np.log2(x._array))
def log10(x: array, /) -> array:
@@ -293,6 +363,8 @@ def log10(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in log10')
return ndarray._new(np.log10(x._array))
def logaddexp(x1: array, x2: array) -> array:
@@ -301,6 +373,8 @@ def logaddexp(x1: array, x2: array) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in logaddexp')
return ndarray._new(np.logaddexp(x1._array, x2._array))
def logical_and(x1: array, x2: array, /) -> array:
@@ -309,6 +383,8 @@ def logical_and(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
+ raise TypeError('Only boolean dtypes are allowed in logical_and')
return ndarray._new(np.logical_and(x1._array, x2._array))
def logical_not(x: array, /) -> array:
@@ -317,6 +393,8 @@ def logical_not(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _boolean_dtypes:
+ raise TypeError('Only boolean dtypes are allowed in logical_not')
return ndarray._new(np.logical_not(x._array))
def logical_or(x1: array, x2: array, /) -> array:
@@ -325,6 +403,8 @@ def logical_or(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
+ raise TypeError('Only boolean dtypes are allowed in logical_or')
return ndarray._new(np.logical_or(x1._array, x2._array))
def logical_xor(x1: array, x2: array, /) -> array:
@@ -333,6 +413,8 @@ def logical_xor(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
+ raise TypeError('Only boolean dtypes are allowed in logical_xor')
return ndarray._new(np.logical_xor(x1._array, x2._array))
def multiply(x1: array, x2: array, /) -> array:
@@ -341,6 +423,8 @@ def multiply(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in multiply')
return ndarray._new(np.multiply(x1._array, x2._array))
def negative(x: array, /) -> array:
@@ -349,6 +433,8 @@ def negative(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in negative')
return ndarray._new(np.negative(x._array))
def not_equal(x1: array, x2: array, /) -> array:
@@ -357,6 +443,8 @@ def not_equal(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _all_dtypes or x2.dtype not in _all_dtypes:
+ raise TypeError('Only array API spec dtypes are allowed in not_equal')
return ndarray._new(np.not_equal(x1._array, x2._array))
def positive(x: array, /) -> array:
@@ -365,6 +453,8 @@ def positive(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in positive')
return ndarray._new(np.positive(x._array))
def pow(x1: array, x2: array, /) -> array:
@@ -373,6 +463,8 @@ def pow(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in pow')
# Note: the function name is different here
return ndarray._new(np.power(x1._array, x2._array))
@@ -382,6 +474,8 @@ def remainder(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in remainder')
return ndarray._new(np.remainder(x1._array, x2._array))
def round(x: array, /) -> array:
@@ -390,6 +484,8 @@ def round(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in round')
return ndarray._new(np.round._implementation(x._array))
def sign(x: array, /) -> array:
@@ -398,6 +494,8 @@ def sign(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in sign')
return ndarray._new(np.sign(x._array))
def sin(x: array, /) -> array:
@@ -406,6 +504,8 @@ def sin(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in sin')
return ndarray._new(np.sin(x._array))
def sinh(x: array, /) -> array:
@@ -414,6 +514,8 @@ def sinh(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in sinh')
return ndarray._new(np.sinh(x._array))
def square(x: array, /) -> array:
@@ -422,6 +524,8 @@ def square(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in square')
return ndarray._new(np.square(x._array))
def sqrt(x: array, /) -> array:
@@ -430,6 +534,8 @@ def sqrt(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in sqrt')
return ndarray._new(np.sqrt(x._array))
def subtract(x1: array, x2: array, /) -> array:
@@ -438,6 +544,8 @@ def subtract(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in subtract')
return ndarray._new(np.subtract(x1._array, x2._array))
def tan(x: array, /) -> array:
@@ -446,6 +554,8 @@ def tan(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in tan')
return ndarray._new(np.tan(x._array))
def tanh(x: array, /) -> array:
@@ -454,6 +564,8 @@ def tanh(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in tanh')
return ndarray._new(np.tanh(x._array))
def trunc(x: array, /) -> array:
@@ -462,4 +574,6 @@ def trunc(x: array, /) -> array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in trunc')
return ndarray._new(np.trunc(x._array))