summaryrefslogtreecommitdiff
path: root/numpy/_array_api/_elementwise_functions.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-01-20 16:11:17 -0700
committerAaron Meurer <asmeurer@gmail.com>2021-01-20 16:11:17 -0700
commitbe1b1932f73fb5946b4867337ba2fd2d31964d11 (patch)
tree8a05934e452a13c9b6f93ffa238809f90104c777 /numpy/_array_api/_elementwise_functions.py
parentdf698f80732508af50b24ecc1b4bd34c470aaba8 (diff)
downloadnumpy-be1b1932f73fb5946b4867337ba2fd2d31964d11.tar.gz
Add type annotations to the array api submodule function definitions
Some stubs still need to be modified to properly pass mypy type checking. Also, 'device' is just left as a TypeVar() for now.
Diffstat (limited to 'numpy/_array_api/_elementwise_functions.py')
-rw-r--r--numpy/_array_api/_elementwise_functions.py114
1 files changed, 59 insertions, 55 deletions
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py
index 7ec01b2e1..9c013d8b4 100644
--- a/numpy/_array_api/_elementwise_functions.py
+++ b/numpy/_array_api/_elementwise_functions.py
@@ -1,6 +1,10 @@
+from __future__ import annotations
+
+from ._types import array
+
import numpy as np
-def abs(x, /):
+def abs(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.abs <numpy.abs>`.
@@ -8,7 +12,7 @@ def abs(x, /):
"""
return np.abs(x)
-def acos(x, /):
+def acos(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arccos <numpy.arccos>`.
@@ -17,7 +21,7 @@ def acos(x, /):
# Note: the function name is different here
return np.arccos(x)
-def acosh(x, /):
+def acosh(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arccosh <numpy.arccosh>`.
@@ -26,7 +30,7 @@ def acosh(x, /):
# Note: the function name is different here
return np.arccosh(x)
-def add(x1, x2, /):
+def add(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.add <numpy.add>`.
@@ -34,7 +38,7 @@ def add(x1, x2, /):
"""
return np.add(x1, x2)
-def asin(x, /):
+def asin(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arcsin <numpy.arcsin>`.
@@ -43,7 +47,7 @@ def asin(x, /):
# Note: the function name is different here
return np.arcsin(x)
-def asinh(x, /):
+def asinh(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arcsinh <numpy.arcsinh>`.
@@ -52,7 +56,7 @@ def asinh(x, /):
# Note: the function name is different here
return np.arcsinh(x)
-def atan(x, /):
+def atan(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arctan <numpy.arctan>`.
@@ -61,7 +65,7 @@ def atan(x, /):
# Note: the function name is different here
return np.arctan(x)
-def atan2(x1, x2, /):
+def atan2(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arctan2 <numpy.arctan2>`.
@@ -70,7 +74,7 @@ def atan2(x1, x2, /):
# Note: the function name is different here
return np.arctan2(x1, x2)
-def atanh(x, /):
+def atanh(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arctanh <numpy.arctanh>`.
@@ -79,7 +83,7 @@ def atanh(x, /):
# Note: the function name is different here
return np.arctanh(x)
-def bitwise_and(x1, x2, /):
+def bitwise_and(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.bitwise_and <numpy.bitwise_and>`.
@@ -87,7 +91,7 @@ def bitwise_and(x1, x2, /):
"""
return np.bitwise_and(x1, x2)
-def bitwise_left_shift(x1, x2, /):
+def bitwise_left_shift(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.left_shift <numpy.left_shift>`.
@@ -96,7 +100,7 @@ def bitwise_left_shift(x1, x2, /):
# Note: the function name is different here
return np.left_shift(x1, x2)
-def bitwise_invert(x, /):
+def bitwise_invert(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.invert <numpy.invert>`.
@@ -105,7 +109,7 @@ def bitwise_invert(x, /):
# Note: the function name is different here
return np.invert(x)
-def bitwise_or(x1, x2, /):
+def bitwise_or(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.bitwise_or <numpy.bitwise_or>`.
@@ -113,7 +117,7 @@ def bitwise_or(x1, x2, /):
"""
return np.bitwise_or(x1, x2)
-def bitwise_right_shift(x1, x2, /):
+def bitwise_right_shift(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.right_shift <numpy.right_shift>`.
@@ -122,7 +126,7 @@ def bitwise_right_shift(x1, x2, /):
# Note: the function name is different here
return np.right_shift(x1, x2)
-def bitwise_xor(x1, x2, /):
+def bitwise_xor(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.bitwise_xor <numpy.bitwise_xor>`.
@@ -130,7 +134,7 @@ def bitwise_xor(x1, x2, /):
"""
return np.bitwise_xor(x1, x2)
-def ceil(x, /):
+def ceil(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.ceil <numpy.ceil>`.
@@ -138,7 +142,7 @@ def ceil(x, /):
"""
return np.ceil(x)
-def cos(x, /):
+def cos(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.cos <numpy.cos>`.
@@ -146,7 +150,7 @@ def cos(x, /):
"""
return np.cos(x)
-def cosh(x, /):
+def cosh(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.cosh <numpy.cosh>`.
@@ -154,7 +158,7 @@ def cosh(x, /):
"""
return np.cosh(x)
-def divide(x1, x2, /):
+def divide(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.divide <numpy.divide>`.
@@ -162,7 +166,7 @@ def divide(x1, x2, /):
"""
return np.divide(x1, x2)
-def equal(x1, x2, /):
+def equal(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.equal <numpy.equal>`.
@@ -170,7 +174,7 @@ def equal(x1, x2, /):
"""
return np.equal(x1, x2)
-def exp(x, /):
+def exp(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.exp <numpy.exp>`.
@@ -178,7 +182,7 @@ def exp(x, /):
"""
return np.exp(x)
-def expm1(x, /):
+def expm1(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.expm1 <numpy.expm1>`.
@@ -186,7 +190,7 @@ def expm1(x, /):
"""
return np.expm1(x)
-def floor(x, /):
+def floor(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.floor <numpy.floor>`.
@@ -194,7 +198,7 @@ def floor(x, /):
"""
return np.floor(x)
-def floor_divide(x1, x2, /):
+def floor_divide(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.floor_divide <numpy.floor_divide>`.
@@ -202,7 +206,7 @@ def floor_divide(x1, x2, /):
"""
return np.floor_divide(x1, x2)
-def greater(x1, x2, /):
+def greater(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.greater <numpy.greater>`.
@@ -210,7 +214,7 @@ def greater(x1, x2, /):
"""
return np.greater(x1, x2)
-def greater_equal(x1, x2, /):
+def greater_equal(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.greater_equal <numpy.greater_equal>`.
@@ -218,7 +222,7 @@ def greater_equal(x1, x2, /):
"""
return np.greater_equal(x1, x2)
-def isfinite(x, /):
+def isfinite(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.isfinite <numpy.isfinite>`.
@@ -226,7 +230,7 @@ def isfinite(x, /):
"""
return np.isfinite(x)
-def isinf(x, /):
+def isinf(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.isinf <numpy.isinf>`.
@@ -234,7 +238,7 @@ def isinf(x, /):
"""
return np.isinf(x)
-def isnan(x, /):
+def isnan(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.isnan <numpy.isnan>`.
@@ -242,7 +246,7 @@ def isnan(x, /):
"""
return np.isnan(x)
-def less(x1, x2, /):
+def less(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.less <numpy.less>`.
@@ -250,7 +254,7 @@ def less(x1, x2, /):
"""
return np.less(x1, x2)
-def less_equal(x1, x2, /):
+def less_equal(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.less_equal <numpy.less_equal>`.
@@ -258,7 +262,7 @@ def less_equal(x1, x2, /):
"""
return np.less_equal(x1, x2)
-def log(x, /):
+def log(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.log <numpy.log>`.
@@ -266,7 +270,7 @@ def log(x, /):
"""
return np.log(x)
-def log1p(x, /):
+def log1p(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.log1p <numpy.log1p>`.
@@ -274,7 +278,7 @@ def log1p(x, /):
"""
return np.log1p(x)
-def log2(x, /):
+def log2(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.log2 <numpy.log2>`.
@@ -282,7 +286,7 @@ def log2(x, /):
"""
return np.log2(x)
-def log10(x, /):
+def log10(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.log10 <numpy.log10>`.
@@ -290,7 +294,7 @@ def log10(x, /):
"""
return np.log10(x)
-def logical_and(x1, x2, /):
+def logical_and(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.logical_and <numpy.logical_and>`.
@@ -298,7 +302,7 @@ def logical_and(x1, x2, /):
"""
return np.logical_and(x1, x2)
-def logical_not(x, /):
+def logical_not(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.logical_not <numpy.logical_not>`.
@@ -306,7 +310,7 @@ def logical_not(x, /):
"""
return np.logical_not(x)
-def logical_or(x1, x2, /):
+def logical_or(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.logical_or <numpy.logical_or>`.
@@ -314,7 +318,7 @@ def logical_or(x1, x2, /):
"""
return np.logical_or(x1, x2)
-def logical_xor(x1, x2, /):
+def logical_xor(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.logical_xor <numpy.logical_xor>`.
@@ -322,7 +326,7 @@ def logical_xor(x1, x2, /):
"""
return np.logical_xor(x1, x2)
-def multiply(x1, x2, /):
+def multiply(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.multiply <numpy.multiply>`.
@@ -330,7 +334,7 @@ def multiply(x1, x2, /):
"""
return np.multiply(x1, x2)
-def negative(x, /):
+def negative(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.negative <numpy.negative>`.
@@ -338,7 +342,7 @@ def negative(x, /):
"""
return np.negative(x)
-def not_equal(x1, x2, /):
+def not_equal(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`.
@@ -346,7 +350,7 @@ def not_equal(x1, x2, /):
"""
return np.not_equal(x1, x2)
-def positive(x, /):
+def positive(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.positive <numpy.positive>`.
@@ -354,7 +358,7 @@ def positive(x, /):
"""
return np.positive(x)
-def pow(x1, x2, /):
+def pow(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.power <numpy.power>`.
@@ -363,7 +367,7 @@ def pow(x1, x2, /):
# Note: the function name is different here
return np.power(x1, x2)
-def remainder(x1, x2, /):
+def remainder(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`.
@@ -371,7 +375,7 @@ def remainder(x1, x2, /):
"""
return np.remainder(x1, x2)
-def round(x, /):
+def round(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.round <numpy.round>`.
@@ -379,7 +383,7 @@ def round(x, /):
"""
return np.round(x)
-def sign(x, /):
+def sign(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.sign <numpy.sign>`.
@@ -387,7 +391,7 @@ def sign(x, /):
"""
return np.sign(x)
-def sin(x, /):
+def sin(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.sin <numpy.sin>`.
@@ -395,7 +399,7 @@ def sin(x, /):
"""
return np.sin(x)
-def sinh(x, /):
+def sinh(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.sinh <numpy.sinh>`.
@@ -403,7 +407,7 @@ def sinh(x, /):
"""
return np.sinh(x)
-def square(x, /):
+def square(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.square <numpy.square>`.
@@ -411,7 +415,7 @@ def square(x, /):
"""
return np.square(x)
-def sqrt(x, /):
+def sqrt(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.sqrt <numpy.sqrt>`.
@@ -419,7 +423,7 @@ def sqrt(x, /):
"""
return np.sqrt(x)
-def subtract(x1, x2, /):
+def subtract(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.subtract <numpy.subtract>`.
@@ -427,7 +431,7 @@ def subtract(x1, x2, /):
"""
return np.subtract(x1, x2)
-def tan(x, /):
+def tan(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.tan <numpy.tan>`.
@@ -435,7 +439,7 @@ def tan(x, /):
"""
return np.tan(x)
-def tanh(x, /):
+def tanh(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.tanh <numpy.tanh>`.
@@ -443,7 +447,7 @@ def tanh(x, /):
"""
return np.tanh(x)
-def trunc(x, /):
+def trunc(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.trunc <numpy.trunc>`.