diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-02-26 17:41:18 -0700 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-02-26 17:46:18 -0700 |
commit | 892b536a36b89f362a845fd50959d6474ec2c5f4 (patch) | |
tree | 9fb32216466242fd04b56d914477f6121be6dc6f /numpy/_array_api/_elementwise_functions.py | |
parent | 587613f056299766be2da00a64b5fa0ac31c84aa (diff) | |
download | numpy-892b536a36b89f362a845fd50959d6474ec2c5f4.tar.gz |
Only allow the spec guaranteed dtypes in the array API elementwise functions
The array API namespace is designed to be only those parts of specification
that are required. So many things that work in NumPy but are not required by
the array API specification will not work in the array_api namespace
functions. For example, transcendental functions will only work with
floating-point dtypes, because those are the only dtypes required to work by
the array API specification.
Diffstat (limited to 'numpy/_array_api/_elementwise_functions.py')
-rw-r--r-- | numpy/_array_api/_elementwise_functions.py | 114 |
1 files changed, 114 insertions, 0 deletions
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py index abb7ef4dd..2357b337c 100644 --- a/numpy/_array_api/_elementwise_functions.py +++ b/numpy/_array_api/_elementwise_functions.py @@ -1,5 +1,7 @@ from __future__ import annotations +from ._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, + _integer_dtypes, _integer_or_boolean_dtypes, _numeric_dtypes) from ._types import array from ._array_object import ndarray @@ -11,6 +13,8 @@ def abs(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in abs') return ndarray._new(np.abs(x._array)) def acos(x: array, /) -> array: @@ -19,6 +23,8 @@ def acos(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in acos') # Note: the function name is different here return ndarray._new(np.arccos(x._array)) @@ -28,6 +34,8 @@ def acosh(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in acosh') # Note: the function name is different here return ndarray._new(np.arccosh(x._array)) @@ -37,6 +45,8 @@ def add(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in add') return ndarray._new(np.add(x1._array, x2._array)) def asin(x: array, /) -> array: @@ -45,6 +55,8 @@ def asin(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in asin') # Note: the function name is different here return ndarray._new(np.arcsin(x._array)) @@ -54,6 +66,8 @@ def asinh(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in asinh') # Note: the function name is different here return ndarray._new(np.arcsinh(x._array)) @@ -63,6 +77,8 @@ def atan(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in atan') # Note: the function name is different here return ndarray._new(np.arctan(x._array)) @@ -72,6 +88,8 @@ def atan2(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in atan2') # Note: the function name is different here return ndarray._new(np.arctan2(x1._array, x2._array)) @@ -81,6 +99,8 @@ def atanh(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in atanh') # Note: the function name is different here return ndarray._new(np.arctanh(x._array)) @@ -90,6 +110,8 @@ def bitwise_and(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer_or_boolean dtypes are allowed in bitwise_and') return ndarray._new(np.bitwise_and(x1._array, x2._array)) def bitwise_left_shift(x1: array, x2: array, /) -> array: @@ -98,6 +120,8 @@ def bitwise_left_shift(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: + raise TypeError('Only integer dtypes are allowed in bitwise_left_shift') # Note: the function name is different here return ndarray._new(np.left_shift(x1._array, x2._array)) @@ -107,6 +131,8 @@ def bitwise_invert(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in bitwise_invert') # Note: the function name is different here return ndarray._new(np.invert(x._array)) @@ -116,6 +142,8 @@ def bitwise_or(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in bitwise_or') return ndarray._new(np.bitwise_or(x1._array, x2._array)) def bitwise_right_shift(x1: array, x2: array, /) -> array: @@ -124,6 +152,8 @@ def bitwise_right_shift(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: + raise TypeError('Only integer dtypes are allowed in bitwise_right_shift') # Note: the function name is different here return ndarray._new(np.right_shift(x1._array, x2._array)) @@ -133,6 +163,8 @@ def bitwise_xor(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: + raise TypeError('Only integer or boolean dtypes are allowed in bitwise_xor') return ndarray._new(np.bitwise_xor(x1._array, x2._array)) def ceil(x: array, /) -> array: @@ -141,6 +173,8 @@ def ceil(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in ceil') return ndarray._new(np.ceil(x._array)) def cos(x: array, /) -> array: @@ -149,6 +183,8 @@ def cos(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in cos') return ndarray._new(np.cos(x._array)) def cosh(x: array, /) -> array: @@ -157,6 +193,8 @@ def cosh(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in cosh') return ndarray._new(np.cosh(x._array)) def divide(x1: array, x2: array, /) -> array: @@ -165,6 +203,8 @@ def divide(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in divide') return ndarray._new(np.divide(x1._array, x2._array)) def equal(x1: array, x2: array, /) -> array: @@ -173,6 +213,8 @@ def equal(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _all_dtypes or x2.dtype not in _all_dtypes: + raise TypeError('Only array API spec dtypes are allowed in equal') return ndarray._new(np.equal(x1._array, x2._array)) def exp(x: array, /) -> array: @@ -181,6 +223,8 @@ def exp(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in exp') return ndarray._new(np.exp(x._array)) def expm1(x: array, /) -> array: @@ -189,6 +233,8 @@ def expm1(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in expm1') return ndarray._new(np.expm1(x._array)) def floor(x: array, /) -> array: @@ -197,6 +243,8 @@ def floor(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in floor') return ndarray._new(np.floor(x._array)) def floor_divide(x1: array, x2: array, /) -> array: @@ -205,6 +253,8 @@ def floor_divide(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in floor_divide') return ndarray._new(np.floor_divide(x1._array, x2._array)) def greater(x1: array, x2: array, /) -> array: @@ -213,6 +263,8 @@ def greater(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in greater') return ndarray._new(np.greater(x1._array, x2._array)) def greater_equal(x1: array, x2: array, /) -> array: @@ -221,6 +273,8 @@ def greater_equal(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in greater_equal') return ndarray._new(np.greater_equal(x1._array, x2._array)) def isfinite(x: array, /) -> array: @@ -229,6 +283,8 @@ def isfinite(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in isfinite') return ndarray._new(np.isfinite(x._array)) def isinf(x: array, /) -> array: @@ -237,6 +293,8 @@ def isinf(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in isinf') return ndarray._new(np.isinf(x._array)) def isnan(x: array, /) -> array: @@ -245,6 +303,8 @@ def isnan(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in isnan') return ndarray._new(np.isnan(x._array)) def less(x1: array, x2: array, /) -> array: @@ -253,6 +313,8 @@ def less(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in less') return ndarray._new(np.less(x1._array, x2._array)) def less_equal(x1: array, x2: array, /) -> array: @@ -261,6 +323,8 @@ def less_equal(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in less_equal') return ndarray._new(np.less_equal(x1._array, x2._array)) def log(x: array, /) -> array: @@ -269,6 +333,8 @@ def log(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in log') return ndarray._new(np.log(x._array)) def log1p(x: array, /) -> array: @@ -277,6 +343,8 @@ def log1p(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in log1p') return ndarray._new(np.log1p(x._array)) def log2(x: array, /) -> array: @@ -285,6 +353,8 @@ def log2(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in log2') return ndarray._new(np.log2(x._array)) def log10(x: array, /) -> array: @@ -293,6 +363,8 @@ def log10(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in log10') return ndarray._new(np.log10(x._array)) def logaddexp(x1: array, x2: array) -> array: @@ -301,6 +373,8 @@ def logaddexp(x1: array, x2: array) -> array: See its docstring for more information. """ + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in logaddexp') return ndarray._new(np.logaddexp(x1._array, x2._array)) def logical_and(x1: array, x2: array, /) -> array: @@ -309,6 +383,8 @@ def logical_and(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: + raise TypeError('Only boolean dtypes are allowed in logical_and') return ndarray._new(np.logical_and(x1._array, x2._array)) def logical_not(x: array, /) -> array: @@ -317,6 +393,8 @@ def logical_not(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _boolean_dtypes: + raise TypeError('Only boolean dtypes are allowed in logical_not') return ndarray._new(np.logical_not(x._array)) def logical_or(x1: array, x2: array, /) -> array: @@ -325,6 +403,8 @@ def logical_or(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: + raise TypeError('Only boolean dtypes are allowed in logical_or') return ndarray._new(np.logical_or(x1._array, x2._array)) def logical_xor(x1: array, x2: array, /) -> array: @@ -333,6 +413,8 @@ def logical_xor(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: + raise TypeError('Only boolean dtypes are allowed in logical_xor') return ndarray._new(np.logical_xor(x1._array, x2._array)) def multiply(x1: array, x2: array, /) -> array: @@ -341,6 +423,8 @@ def multiply(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in multiply') return ndarray._new(np.multiply(x1._array, x2._array)) def negative(x: array, /) -> array: @@ -349,6 +433,8 @@ def negative(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in negative') return ndarray._new(np.negative(x._array)) def not_equal(x1: array, x2: array, /) -> array: @@ -357,6 +443,8 @@ def not_equal(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _all_dtypes or x2.dtype not in _all_dtypes: + raise TypeError('Only array API spec dtypes are allowed in not_equal') return ndarray._new(np.not_equal(x1._array, x2._array)) def positive(x: array, /) -> array: @@ -365,6 +453,8 @@ def positive(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in positive') return ndarray._new(np.positive(x._array)) def pow(x1: array, x2: array, /) -> array: @@ -373,6 +463,8 @@ def pow(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in pow') # Note: the function name is different here return ndarray._new(np.power(x1._array, x2._array)) @@ -382,6 +474,8 @@ def remainder(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in remainder') return ndarray._new(np.remainder(x1._array, x2._array)) def round(x: array, /) -> array: @@ -390,6 +484,8 @@ def round(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in round') return ndarray._new(np.round._implementation(x._array)) def sign(x: array, /) -> array: @@ -398,6 +494,8 @@ def sign(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in sign') return ndarray._new(np.sign(x._array)) def sin(x: array, /) -> array: @@ -406,6 +504,8 @@ def sin(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in sin') return ndarray._new(np.sin(x._array)) def sinh(x: array, /) -> array: @@ -414,6 +514,8 @@ def sinh(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in sinh') return ndarray._new(np.sinh(x._array)) def square(x: array, /) -> array: @@ -422,6 +524,8 @@ def square(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in square') return ndarray._new(np.square(x._array)) def sqrt(x: array, /) -> array: @@ -430,6 +534,8 @@ def sqrt(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in sqrt') return ndarray._new(np.sqrt(x._array)) def subtract(x1: array, x2: array, /) -> array: @@ -438,6 +544,8 @@ def subtract(x1: array, x2: array, /) -> array: See its docstring for more information. """ + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in subtract') return ndarray._new(np.subtract(x1._array, x2._array)) def tan(x: array, /) -> array: @@ -446,6 +554,8 @@ def tan(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in tan') return ndarray._new(np.tan(x._array)) def tanh(x: array, /) -> array: @@ -454,6 +564,8 @@ def tanh(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in tanh') return ndarray._new(np.tanh(x._array)) def trunc(x: array, /) -> array: @@ -462,4 +574,6 @@ def trunc(x: array, /) -> array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in trunc') return ndarray._new(np.trunc(x._array)) |