summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/_array_api/_creation_functions.py26
-rw-r--r--numpy/_array_api/_elementwise_functions.py114
-rw-r--r--numpy/_array_api/_linear_algebra_functions.py20
-rw-r--r--numpy/_array_api/_manipulation_functions.py18
-rw-r--r--numpy/_array_api/_searching_functions.py12
-rw-r--r--numpy/_array_api/_set_functions.py6
-rw-r--r--numpy/_array_api/_sorting_functions.py8
-rw-r--r--numpy/_array_api/_statistical_functions.py18
-rw-r--r--numpy/_array_api/_types.py18
-rw-r--r--numpy/_array_api/_utility_functions.py8
10 files changed, 151 insertions, 97 deletions
diff --git a/numpy/_array_api/_creation_functions.py b/numpy/_array_api/_creation_functions.py
index b6c0c22cc..1aeaffb71 100644
--- a/numpy/_array_api/_creation_functions.py
+++ b/numpy/_array_api/_creation_functions.py
@@ -1,6 +1,10 @@
+from __future__ import annotations
+
+from ._types import Optional, Tuple, Union, array, device, dtype
+
import numpy as np
-def arange(start, /, *, stop=None, step=1, dtype=None, device=None):
+def arange(start: Union[int, float], /, *, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.arange <numpy.arange>`.
@@ -11,7 +15,7 @@ def arange(start, /, *, stop=None, step=1, dtype=None, device=None):
raise NotImplementedError("Device support is not yet implemented")
return np.arange(start, stop=stop, step=step, dtype=dtype)
-def empty(shape, /, *, dtype=None, device=None):
+def empty(shape: Union[int, Tuple[int, ...]], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`.
@@ -22,7 +26,7 @@ def empty(shape, /, *, dtype=None, device=None):
raise NotImplementedError("Device support is not yet implemented")
return np.empty(shape, dtype=dtype)
-def empty_like(x, /, *, dtype=None, device=None):
+def empty_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.empty_like <numpy.empty_like>`.
@@ -33,7 +37,7 @@ def empty_like(x, /, *, dtype=None, device=None):
raise NotImplementedError("Device support is not yet implemented")
return np.empty_like(x, dtype=dtype)
-def eye(N, /, *, M=None, k=0, dtype=None, device=None):
+def eye(N: int, /, *, M: Optional[int] = None, k: Optional[int] = 0, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.eye <numpy.eye>`.
@@ -44,7 +48,7 @@ def eye(N, /, *, M=None, k=0, dtype=None, device=None):
raise NotImplementedError("Device support is not yet implemented")
return np.eye(N, M=M, k=k, dtype=dtype)
-def full(shape, fill_value, /, *, dtype=None, device=None):
+def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.full <numpy.full>`.
@@ -55,7 +59,7 @@ def full(shape, fill_value, /, *, dtype=None, device=None):
raise NotImplementedError("Device support is not yet implemented")
return np.full(shape, fill_value, dtype=dtype)
-def full_like(x, fill_value, /, *, dtype=None, device=None):
+def full_like(x: array, fill_value: Union[int, float], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.full_like <numpy.full_like>`.
@@ -66,7 +70,7 @@ def full_like(x, fill_value, /, *, dtype=None, device=None):
raise NotImplementedError("Device support is not yet implemented")
return np.full_like(x, fill_value, dtype=dtype)
-def linspace(start, stop, num, /, *, dtype=None, device=None, endpoint=True):
+def linspace(start: Union[int, float], stop: Union[int, float], num: int, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None, endpoint: Optional[bool] = True) -> array:
"""
Array API compatible wrapper for :py:func:`np.linspace <numpy.linspace>`.
@@ -77,7 +81,7 @@ def linspace(start, stop, num, /, *, dtype=None, device=None, endpoint=True):
raise NotImplementedError("Device support is not yet implemented")
return np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
-def ones(shape, /, *, dtype=None, device=None):
+def ones(shape: Union[int, Tuple[int, ...]], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.ones <numpy.ones>`.
@@ -88,7 +92,7 @@ def ones(shape, /, *, dtype=None, device=None):
raise NotImplementedError("Device support is not yet implemented")
return np.ones(shape, dtype=dtype)
-def ones_like(x, /, *, dtype=None, device=None):
+def ones_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.ones_like <numpy.ones_like>`.
@@ -99,7 +103,7 @@ def ones_like(x, /, *, dtype=None, device=None):
raise NotImplementedError("Device support is not yet implemented")
return np.ones_like(x, dtype=dtype)
-def zeros(shape, /, *, dtype=None, device=None):
+def zeros(shape: Union[int, Tuple[int, ...]], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.zeros <numpy.zeros>`.
@@ -110,7 +114,7 @@ def zeros(shape, /, *, dtype=None, device=None):
raise NotImplementedError("Device support is not yet implemented")
return np.zeros(shape, dtype=dtype)
-def zeros_like(x, /, *, dtype=None, device=None):
+def zeros_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.zeros_like <numpy.zeros_like>`.
diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py
index 7ec01b2e1..9c013d8b4 100644
--- a/numpy/_array_api/_elementwise_functions.py
+++ b/numpy/_array_api/_elementwise_functions.py
@@ -1,6 +1,10 @@
+from __future__ import annotations
+
+from ._types import array
+
import numpy as np
-def abs(x, /):
+def abs(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.abs <numpy.abs>`.
@@ -8,7 +12,7 @@ def abs(x, /):
"""
return np.abs(x)
-def acos(x, /):
+def acos(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arccos <numpy.arccos>`.
@@ -17,7 +21,7 @@ def acos(x, /):
# Note: the function name is different here
return np.arccos(x)
-def acosh(x, /):
+def acosh(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arccosh <numpy.arccosh>`.
@@ -26,7 +30,7 @@ def acosh(x, /):
# Note: the function name is different here
return np.arccosh(x)
-def add(x1, x2, /):
+def add(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.add <numpy.add>`.
@@ -34,7 +38,7 @@ def add(x1, x2, /):
"""
return np.add(x1, x2)
-def asin(x, /):
+def asin(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arcsin <numpy.arcsin>`.
@@ -43,7 +47,7 @@ def asin(x, /):
# Note: the function name is different here
return np.arcsin(x)
-def asinh(x, /):
+def asinh(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arcsinh <numpy.arcsinh>`.
@@ -52,7 +56,7 @@ def asinh(x, /):
# Note: the function name is different here
return np.arcsinh(x)
-def atan(x, /):
+def atan(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arctan <numpy.arctan>`.
@@ -61,7 +65,7 @@ def atan(x, /):
# Note: the function name is different here
return np.arctan(x)
-def atan2(x1, x2, /):
+def atan2(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arctan2 <numpy.arctan2>`.
@@ -70,7 +74,7 @@ def atan2(x1, x2, /):
# Note: the function name is different here
return np.arctan2(x1, x2)
-def atanh(x, /):
+def atanh(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.arctanh <numpy.arctanh>`.
@@ -79,7 +83,7 @@ def atanh(x, /):
# Note: the function name is different here
return np.arctanh(x)
-def bitwise_and(x1, x2, /):
+def bitwise_and(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.bitwise_and <numpy.bitwise_and>`.
@@ -87,7 +91,7 @@ def bitwise_and(x1, x2, /):
"""
return np.bitwise_and(x1, x2)
-def bitwise_left_shift(x1, x2, /):
+def bitwise_left_shift(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.left_shift <numpy.left_shift>`.
@@ -96,7 +100,7 @@ def bitwise_left_shift(x1, x2, /):
# Note: the function name is different here
return np.left_shift(x1, x2)
-def bitwise_invert(x, /):
+def bitwise_invert(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.invert <numpy.invert>`.
@@ -105,7 +109,7 @@ def bitwise_invert(x, /):
# Note: the function name is different here
return np.invert(x)
-def bitwise_or(x1, x2, /):
+def bitwise_or(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.bitwise_or <numpy.bitwise_or>`.
@@ -113,7 +117,7 @@ def bitwise_or(x1, x2, /):
"""
return np.bitwise_or(x1, x2)
-def bitwise_right_shift(x1, x2, /):
+def bitwise_right_shift(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.right_shift <numpy.right_shift>`.
@@ -122,7 +126,7 @@ def bitwise_right_shift(x1, x2, /):
# Note: the function name is different here
return np.right_shift(x1, x2)
-def bitwise_xor(x1, x2, /):
+def bitwise_xor(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.bitwise_xor <numpy.bitwise_xor>`.
@@ -130,7 +134,7 @@ def bitwise_xor(x1, x2, /):
"""
return np.bitwise_xor(x1, x2)
-def ceil(x, /):
+def ceil(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.ceil <numpy.ceil>`.
@@ -138,7 +142,7 @@ def ceil(x, /):
"""
return np.ceil(x)
-def cos(x, /):
+def cos(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.cos <numpy.cos>`.
@@ -146,7 +150,7 @@ def cos(x, /):
"""
return np.cos(x)
-def cosh(x, /):
+def cosh(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.cosh <numpy.cosh>`.
@@ -154,7 +158,7 @@ def cosh(x, /):
"""
return np.cosh(x)
-def divide(x1, x2, /):
+def divide(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.divide <numpy.divide>`.
@@ -162,7 +166,7 @@ def divide(x1, x2, /):
"""
return np.divide(x1, x2)
-def equal(x1, x2, /):
+def equal(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.equal <numpy.equal>`.
@@ -170,7 +174,7 @@ def equal(x1, x2, /):
"""
return np.equal(x1, x2)
-def exp(x, /):
+def exp(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.exp <numpy.exp>`.
@@ -178,7 +182,7 @@ def exp(x, /):
"""
return np.exp(x)
-def expm1(x, /):
+def expm1(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.expm1 <numpy.expm1>`.
@@ -186,7 +190,7 @@ def expm1(x, /):
"""
return np.expm1(x)
-def floor(x, /):
+def floor(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.floor <numpy.floor>`.
@@ -194,7 +198,7 @@ def floor(x, /):
"""
return np.floor(x)
-def floor_divide(x1, x2, /):
+def floor_divide(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.floor_divide <numpy.floor_divide>`.
@@ -202,7 +206,7 @@ def floor_divide(x1, x2, /):
"""
return np.floor_divide(x1, x2)
-def greater(x1, x2, /):
+def greater(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.greater <numpy.greater>`.
@@ -210,7 +214,7 @@ def greater(x1, x2, /):
"""
return np.greater(x1, x2)
-def greater_equal(x1, x2, /):
+def greater_equal(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.greater_equal <numpy.greater_equal>`.
@@ -218,7 +222,7 @@ def greater_equal(x1, x2, /):
"""
return np.greater_equal(x1, x2)
-def isfinite(x, /):
+def isfinite(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.isfinite <numpy.isfinite>`.
@@ -226,7 +230,7 @@ def isfinite(x, /):
"""
return np.isfinite(x)
-def isinf(x, /):
+def isinf(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.isinf <numpy.isinf>`.
@@ -234,7 +238,7 @@ def isinf(x, /):
"""
return np.isinf(x)
-def isnan(x, /):
+def isnan(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.isnan <numpy.isnan>`.
@@ -242,7 +246,7 @@ def isnan(x, /):
"""
return np.isnan(x)
-def less(x1, x2, /):
+def less(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.less <numpy.less>`.
@@ -250,7 +254,7 @@ def less(x1, x2, /):
"""
return np.less(x1, x2)
-def less_equal(x1, x2, /):
+def less_equal(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.less_equal <numpy.less_equal>`.
@@ -258,7 +262,7 @@ def less_equal(x1, x2, /):
"""
return np.less_equal(x1, x2)
-def log(x, /):
+def log(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.log <numpy.log>`.
@@ -266,7 +270,7 @@ def log(x, /):
"""
return np.log(x)
-def log1p(x, /):
+def log1p(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.log1p <numpy.log1p>`.
@@ -274,7 +278,7 @@ def log1p(x, /):
"""
return np.log1p(x)
-def log2(x, /):
+def log2(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.log2 <numpy.log2>`.
@@ -282,7 +286,7 @@ def log2(x, /):
"""
return np.log2(x)
-def log10(x, /):
+def log10(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.log10 <numpy.log10>`.
@@ -290,7 +294,7 @@ def log10(x, /):
"""
return np.log10(x)
-def logical_and(x1, x2, /):
+def logical_and(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.logical_and <numpy.logical_and>`.
@@ -298,7 +302,7 @@ def logical_and(x1, x2, /):
"""
return np.logical_and(x1, x2)
-def logical_not(x, /):
+def logical_not(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.logical_not <numpy.logical_not>`.
@@ -306,7 +310,7 @@ def logical_not(x, /):
"""
return np.logical_not(x)
-def logical_or(x1, x2, /):
+def logical_or(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.logical_or <numpy.logical_or>`.
@@ -314,7 +318,7 @@ def logical_or(x1, x2, /):
"""
return np.logical_or(x1, x2)
-def logical_xor(x1, x2, /):
+def logical_xor(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.logical_xor <numpy.logical_xor>`.
@@ -322,7 +326,7 @@ def logical_xor(x1, x2, /):
"""
return np.logical_xor(x1, x2)
-def multiply(x1, x2, /):
+def multiply(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.multiply <numpy.multiply>`.
@@ -330,7 +334,7 @@ def multiply(x1, x2, /):
"""
return np.multiply(x1, x2)
-def negative(x, /):
+def negative(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.negative <numpy.negative>`.
@@ -338,7 +342,7 @@ def negative(x, /):
"""
return np.negative(x)
-def not_equal(x1, x2, /):
+def not_equal(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`.
@@ -346,7 +350,7 @@ def not_equal(x1, x2, /):
"""
return np.not_equal(x1, x2)
-def positive(x, /):
+def positive(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.positive <numpy.positive>`.
@@ -354,7 +358,7 @@ def positive(x, /):
"""
return np.positive(x)
-def pow(x1, x2, /):
+def pow(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.power <numpy.power>`.
@@ -363,7 +367,7 @@ def pow(x1, x2, /):
# Note: the function name is different here
return np.power(x1, x2)
-def remainder(x1, x2, /):
+def remainder(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`.
@@ -371,7 +375,7 @@ def remainder(x1, x2, /):
"""
return np.remainder(x1, x2)
-def round(x, /):
+def round(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.round <numpy.round>`.
@@ -379,7 +383,7 @@ def round(x, /):
"""
return np.round(x)
-def sign(x, /):
+def sign(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.sign <numpy.sign>`.
@@ -387,7 +391,7 @@ def sign(x, /):
"""
return np.sign(x)
-def sin(x, /):
+def sin(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.sin <numpy.sin>`.
@@ -395,7 +399,7 @@ def sin(x, /):
"""
return np.sin(x)
-def sinh(x, /):
+def sinh(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.sinh <numpy.sinh>`.
@@ -403,7 +407,7 @@ def sinh(x, /):
"""
return np.sinh(x)
-def square(x, /):
+def square(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.square <numpy.square>`.
@@ -411,7 +415,7 @@ def square(x, /):
"""
return np.square(x)
-def sqrt(x, /):
+def sqrt(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.sqrt <numpy.sqrt>`.
@@ -419,7 +423,7 @@ def sqrt(x, /):
"""
return np.sqrt(x)
-def subtract(x1, x2, /):
+def subtract(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.subtract <numpy.subtract>`.
@@ -427,7 +431,7 @@ def subtract(x1, x2, /):
"""
return np.subtract(x1, x2)
-def tan(x, /):
+def tan(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.tan <numpy.tan>`.
@@ -435,7 +439,7 @@ def tan(x, /):
"""
return np.tan(x)
-def tanh(x, /):
+def tanh(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.tanh <numpy.tanh>`.
@@ -443,7 +447,7 @@ def tanh(x, /):
"""
return np.tanh(x)
-def trunc(x, /):
+def trunc(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.trunc <numpy.trunc>`.
diff --git a/numpy/_array_api/_linear_algebra_functions.py b/numpy/_array_api/_linear_algebra_functions.py
index cfb184e8d..addbaeccb 100644
--- a/numpy/_array_api/_linear_algebra_functions.py
+++ b/numpy/_array_api/_linear_algebra_functions.py
@@ -1,3 +1,7 @@
+from __future__ import annotations
+
+from ._types import Literal, Optional, Tuple, Union, array
+
import numpy as np
# def cholesky():
@@ -8,7 +12,7 @@ import numpy as np
# """
# return np.cholesky()
-def cross(x1, x2, /, *, axis=-1):
+def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
"""
Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.
@@ -16,7 +20,7 @@ def cross(x1, x2, /, *, axis=-1):
"""
return np.cross(x1, x2, axis=axis)
-def det(x, /):
+def det(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`.
@@ -25,7 +29,7 @@ def det(x, /):
# Note: this function is being imported from a nondefault namespace
return np.linalg.det(x)
-def diagonal(x, /, *, axis1=0, axis2=1, offset=0):
+def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
"""
Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.
@@ -65,7 +69,7 @@ def diagonal(x, /, *, axis1=0, axis2=1, offset=0):
# """
# return np.einsum()
-def inv(x):
+def inv(x: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`.
@@ -106,7 +110,7 @@ def inv(x):
# """
# return np.matrix_rank()
-def norm(x, /, *, axis=None, keepdims=False, ord=None):
+def norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float, Literal[np.inf, -np.inf, 'fro', 'nuc']]] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
@@ -118,7 +122,7 @@ def norm(x, /, *, axis=None, keepdims=False, ord=None):
# Note: this function is being imported from a nondefault namespace
return np.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord)
-def outer(x1, x2, /):
+def outer(x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.
@@ -166,7 +170,7 @@ def outer(x1, x2, /):
# """
# return np.svd()
-def trace(x, /, *, axis1=0, axis2=1, offset=0):
+def trace(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
"""
Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
@@ -174,7 +178,7 @@ def trace(x, /, *, axis1=0, axis2=1, offset=0):
"""
return np.trace(x, axis1=axis1, axis2=axis2, offset=offset)
-def transpose(x, /, *, axes=None):
+def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.
diff --git a/numpy/_array_api/_manipulation_functions.py b/numpy/_array_api/_manipulation_functions.py
index 834aa2f8f..f79ef1f9c 100644
--- a/numpy/_array_api/_manipulation_functions.py
+++ b/numpy/_array_api/_manipulation_functions.py
@@ -1,6 +1,10 @@
+from __future__ import annotations
+
+from ._types import Optional, Tuple, Union, array
+
import numpy as np
-def concat(arrays, /, *, axis=0):
+def concat(arrays: Tuple[array], /, *, axis: Optional[int] = 0) -> array:
"""
Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`.
@@ -9,7 +13,7 @@ def concat(arrays, /, *, axis=0):
# Note: the function name is different here
return np.concatenate(arrays, axis=axis)
-def expand_dims(x, axis, /):
+def expand_dims(x: array, axis: int, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.expand_dims <numpy.expand_dims>`.
@@ -17,7 +21,7 @@ def expand_dims(x, axis, /):
"""
return np.expand_dims(x, axis)
-def flip(x, /, *, axis=None):
+def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.flip <numpy.flip>`.
@@ -25,7 +29,7 @@ def flip(x, /, *, axis=None):
"""
return np.flip(x, axis=axis)
-def reshape(x, shape, /):
+def reshape(x: array, shape: Tuple[int, ...], /) -> array:
"""
Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
@@ -33,7 +37,7 @@ def reshape(x, shape, /):
"""
return np.reshape(x, shape)
-def roll(x, shift, /, *, axis=None):
+def roll(x: array, shift: Union[int, Tuple[int, ...]], /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.roll <numpy.roll>`.
@@ -41,7 +45,7 @@ def roll(x, shift, /, *, axis=None):
"""
return np.roll(x, shift, axis=axis)
-def squeeze(x, /, *, axis=None):
+def squeeze(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.squeeze <numpy.squeeze>`.
@@ -49,7 +53,7 @@ def squeeze(x, /, *, axis=None):
"""
return np.squeeze(x, axis=axis)
-def stack(arrays, /, *, axis=0):
+def stack(arrays: Tuple[array], /, *, axis: Optional[int] = 0) -> array:
"""
Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`.
diff --git a/numpy/_array_api/_searching_functions.py b/numpy/_array_api/_searching_functions.py
index 4eed66c48..3b37167af 100644
--- a/numpy/_array_api/_searching_functions.py
+++ b/numpy/_array_api/_searching_functions.py
@@ -1,6 +1,10 @@
+from __future__ import annotations
+
+from ._types import Tuple, array
+
import numpy as np
-def argmax(x, /, *, axis=None, keepdims=False):
+def argmax(x: array, /, *, axis: int = None, keepdims: bool = False) -> array:
"""
Array API compatible wrapper for :py:func:`np.argmax <numpy.argmax>`.
@@ -8,7 +12,7 @@ def argmax(x, /, *, axis=None, keepdims=False):
"""
return np.argmax(x, axis=axis, keepdims=keepdims)
-def argmin(x, /, *, axis=None, keepdims=False):
+def argmin(x: array, /, *, axis: int = None, keepdims: bool = False) -> array:
"""
Array API compatible wrapper for :py:func:`np.argmin <numpy.argmin>`.
@@ -16,7 +20,7 @@ def argmin(x, /, *, axis=None, keepdims=False):
"""
return np.argmin(x, axis=axis, keepdims=keepdims)
-def nonzero(x, /):
+def nonzero(x: array, /) -> Tuple[array, ...]:
"""
Array API compatible wrapper for :py:func:`np.nonzero <numpy.nonzero>`.
@@ -24,7 +28,7 @@ def nonzero(x, /):
"""
return np.nonzero(x)
-def where(condition, x1, x2, /):
+def where(condition: array, x1: array, x2: array, /) -> array:
"""
Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
diff --git a/numpy/_array_api/_set_functions.py b/numpy/_array_api/_set_functions.py
index fd1438be5..80288c57d 100644
--- a/numpy/_array_api/_set_functions.py
+++ b/numpy/_array_api/_set_functions.py
@@ -1,6 +1,10 @@
+from __future__ import annotations
+
+from ._types import Tuple, Union, array
+
import numpy as np
-def unique(x, /, *, return_counts=False, return_index=False, return_inverse=False, sorted=True):
+def unique(x: array, /, *, return_counts: bool = False, return_index: bool = False, return_inverse: bool = False, sorted: bool = True) -> Union[array, Tuple[array, ...]]:
"""
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
diff --git a/numpy/_array_api/_sorting_functions.py b/numpy/_array_api/_sorting_functions.py
index 5ffe6c8f9..cddfd1598 100644
--- a/numpy/_array_api/_sorting_functions.py
+++ b/numpy/_array_api/_sorting_functions.py
@@ -1,6 +1,10 @@
+from __future__ import annotations
+
+from ._types import array
+
import numpy as np
-def argsort(x, /, *, axis=-1, descending=False, stable=True):
+def argsort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> array:
"""
Array API compatible wrapper for :py:func:`np.argsort <numpy.argsort>`.
@@ -13,7 +17,7 @@ def argsort(x, /, *, axis=-1, descending=False, stable=True):
res = np.flip(res, axis=axis)
return res
-def sort(x, /, *, axis=-1, descending=False, stable=True):
+def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> array:
"""
Array API compatible wrapper for :py:func:`np.sort <numpy.sort>`.
diff --git a/numpy/_array_api/_statistical_functions.py b/numpy/_array_api/_statistical_functions.py
index 833c47f66..020053896 100644
--- a/numpy/_array_api/_statistical_functions.py
+++ b/numpy/_array_api/_statistical_functions.py
@@ -1,24 +1,28 @@
+from __future__ import annotations
+
+from ._types import Optional, Tuple, Union, array
+
import numpy as np
-def max(x, /, *, axis=None, keepdims=False):
+def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
return np.max(x, axis=axis, keepdims=keepdims)
-def mean(x, /, *, axis=None, keepdims=False):
+def mean(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
return np.mean(x, axis=axis, keepdims=keepdims)
-def min(x, /, *, axis=None, keepdims=False):
+def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
return np.min(x, axis=axis, keepdims=keepdims)
-def prod(x, /, *, axis=None, keepdims=False):
+def prod(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
return np.prod(x, axis=axis, keepdims=keepdims)
-def std(x, /, *, axis=None, correction=0.0, keepdims=False):
+def std(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array:
# Note: the keyword argument correction is different here
return np.std(x, axis=axis, ddof=correction, keepdims=keepdims)
-def sum(x, /, *, axis=None, keepdims=False):
+def sum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
return np.sum(x, axis=axis, keepdims=keepdims)
-def var(x, /, *, axis=None, correction=0.0, keepdims=False):
+def var(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array:
# Note: the keyword argument correction is different here
return np.var(x, axis=axis, ddof=correction, keepdims=keepdims)
diff --git a/numpy/_array_api/_types.py b/numpy/_array_api/_types.py
new file mode 100644
index 000000000..e8867a29b
--- /dev/null
+++ b/numpy/_array_api/_types.py
@@ -0,0 +1,18 @@
+"""
+This file defines the types for type annotations.
+
+These names aren't part of the module namespace, but they are used in the
+annotations in the function signatures. The functions in the module are only
+valid for inputs that match the given type annotations.
+"""
+
+__all__ = ['Literal', 'Optional', 'Tuple', 'Union', 'array', 'device', 'dtype']
+
+from typing import Literal, Optional, Tuple, Union, TypeVar
+
+import numpy as np
+
+array = np.ndarray
+device = TypeVar('device')
+dtype = Literal[np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16,
+ np.uint32, np.uint64, np.float32, np.float64]
diff --git a/numpy/_array_api/_utility_functions.py b/numpy/_array_api/_utility_functions.py
index 19743d15c..69e17e0e5 100644
--- a/numpy/_array_api/_utility_functions.py
+++ b/numpy/_array_api/_utility_functions.py
@@ -1,6 +1,10 @@
+from __future__ import annotations
+
+from ._types import Optional, Tuple, Union, array
+
import numpy as np
-def all(x, /, *, axis=None, keepdims=False):
+def all(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
"""
Array API compatible wrapper for :py:func:`np.all <numpy.all>`.
@@ -8,7 +12,7 @@ def all(x, /, *, axis=None, keepdims=False):
"""
return np.all(x, axis=axis, keepdims=keepdims)
-def any(x, /, *, axis=None, keepdims=False):
+def any(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
"""
Array API compatible wrapper for :py:func:`np.any <numpy.any>`.