diff options
-rw-r--r-- | numpy/_array_api/_creation_functions.py | 26 | ||||
-rw-r--r-- | numpy/_array_api/_elementwise_functions.py | 114 | ||||
-rw-r--r-- | numpy/_array_api/_linear_algebra_functions.py | 20 | ||||
-rw-r--r-- | numpy/_array_api/_manipulation_functions.py | 18 | ||||
-rw-r--r-- | numpy/_array_api/_searching_functions.py | 12 | ||||
-rw-r--r-- | numpy/_array_api/_set_functions.py | 6 | ||||
-rw-r--r-- | numpy/_array_api/_sorting_functions.py | 8 | ||||
-rw-r--r-- | numpy/_array_api/_statistical_functions.py | 18 | ||||
-rw-r--r-- | numpy/_array_api/_types.py | 18 | ||||
-rw-r--r-- | numpy/_array_api/_utility_functions.py | 8 |
10 files changed, 151 insertions, 97 deletions
diff --git a/numpy/_array_api/_creation_functions.py b/numpy/_array_api/_creation_functions.py index b6c0c22cc..1aeaffb71 100644 --- a/numpy/_array_api/_creation_functions.py +++ b/numpy/_array_api/_creation_functions.py @@ -1,6 +1,10 @@ +from __future__ import annotations + +from ._types import Optional, Tuple, Union, array, device, dtype + import numpy as np -def arange(start, /, *, stop=None, step=1, dtype=None, device=None): +def arange(start: Union[int, float], /, *, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: """ Array API compatible wrapper for :py:func:`np.arange <numpy.arange>`. @@ -11,7 +15,7 @@ def arange(start, /, *, stop=None, step=1, dtype=None, device=None): raise NotImplementedError("Device support is not yet implemented") return np.arange(start, stop=stop, step=step, dtype=dtype) -def empty(shape, /, *, dtype=None, device=None): +def empty(shape: Union[int, Tuple[int, ...]], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: """ Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`. @@ -22,7 +26,7 @@ def empty(shape, /, *, dtype=None, device=None): raise NotImplementedError("Device support is not yet implemented") return np.empty(shape, dtype=dtype) -def empty_like(x, /, *, dtype=None, device=None): +def empty_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: """ Array API compatible wrapper for :py:func:`np.empty_like <numpy.empty_like>`. @@ -33,7 +37,7 @@ def empty_like(x, /, *, dtype=None, device=None): raise NotImplementedError("Device support is not yet implemented") return np.empty_like(x, dtype=dtype) -def eye(N, /, *, M=None, k=0, dtype=None, device=None): +def eye(N: int, /, *, M: Optional[int] = None, k: Optional[int] = 0, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: """ Array API compatible wrapper for :py:func:`np.eye <numpy.eye>`. @@ -44,7 +48,7 @@ def eye(N, /, *, M=None, k=0, dtype=None, device=None): raise NotImplementedError("Device support is not yet implemented") return np.eye(N, M=M, k=k, dtype=dtype) -def full(shape, fill_value, /, *, dtype=None, device=None): +def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: """ Array API compatible wrapper for :py:func:`np.full <numpy.full>`. @@ -55,7 +59,7 @@ def full(shape, fill_value, /, *, dtype=None, device=None): raise NotImplementedError("Device support is not yet implemented") return np.full(shape, fill_value, dtype=dtype) -def full_like(x, fill_value, /, *, dtype=None, device=None): +def full_like(x: array, fill_value: Union[int, float], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: """ Array API compatible wrapper for :py:func:`np.full_like <numpy.full_like>`. @@ -66,7 +70,7 @@ def full_like(x, fill_value, /, *, dtype=None, device=None): raise NotImplementedError("Device support is not yet implemented") return np.full_like(x, fill_value, dtype=dtype) -def linspace(start, stop, num, /, *, dtype=None, device=None, endpoint=True): +def linspace(start: Union[int, float], stop: Union[int, float], num: int, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None, endpoint: Optional[bool] = True) -> array: """ Array API compatible wrapper for :py:func:`np.linspace <numpy.linspace>`. @@ -77,7 +81,7 @@ def linspace(start, stop, num, /, *, dtype=None, device=None, endpoint=True): raise NotImplementedError("Device support is not yet implemented") return np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint) -def ones(shape, /, *, dtype=None, device=None): +def ones(shape: Union[int, Tuple[int, ...]], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: """ Array API compatible wrapper for :py:func:`np.ones <numpy.ones>`. @@ -88,7 +92,7 @@ def ones(shape, /, *, dtype=None, device=None): raise NotImplementedError("Device support is not yet implemented") return np.ones(shape, dtype=dtype) -def ones_like(x, /, *, dtype=None, device=None): +def ones_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: """ Array API compatible wrapper for :py:func:`np.ones_like <numpy.ones_like>`. @@ -99,7 +103,7 @@ def ones_like(x, /, *, dtype=None, device=None): raise NotImplementedError("Device support is not yet implemented") return np.ones_like(x, dtype=dtype) -def zeros(shape, /, *, dtype=None, device=None): +def zeros(shape: Union[int, Tuple[int, ...]], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: """ Array API compatible wrapper for :py:func:`np.zeros <numpy.zeros>`. @@ -110,7 +114,7 @@ def zeros(shape, /, *, dtype=None, device=None): raise NotImplementedError("Device support is not yet implemented") return np.zeros(shape, dtype=dtype) -def zeros_like(x, /, *, dtype=None, device=None): +def zeros_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: """ Array API compatible wrapper for :py:func:`np.zeros_like <numpy.zeros_like>`. diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py index 7ec01b2e1..9c013d8b4 100644 --- a/numpy/_array_api/_elementwise_functions.py +++ b/numpy/_array_api/_elementwise_functions.py @@ -1,6 +1,10 @@ +from __future__ import annotations + +from ._types import array + import numpy as np -def abs(x, /): +def abs(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.abs <numpy.abs>`. @@ -8,7 +12,7 @@ def abs(x, /): """ return np.abs(x) -def acos(x, /): +def acos(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.arccos <numpy.arccos>`. @@ -17,7 +21,7 @@ def acos(x, /): # Note: the function name is different here return np.arccos(x) -def acosh(x, /): +def acosh(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.arccosh <numpy.arccosh>`. @@ -26,7 +30,7 @@ def acosh(x, /): # Note: the function name is different here return np.arccosh(x) -def add(x1, x2, /): +def add(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.add <numpy.add>`. @@ -34,7 +38,7 @@ def add(x1, x2, /): """ return np.add(x1, x2) -def asin(x, /): +def asin(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.arcsin <numpy.arcsin>`. @@ -43,7 +47,7 @@ def asin(x, /): # Note: the function name is different here return np.arcsin(x) -def asinh(x, /): +def asinh(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.arcsinh <numpy.arcsinh>`. @@ -52,7 +56,7 @@ def asinh(x, /): # Note: the function name is different here return np.arcsinh(x) -def atan(x, /): +def atan(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.arctan <numpy.arctan>`. @@ -61,7 +65,7 @@ def atan(x, /): # Note: the function name is different here return np.arctan(x) -def atan2(x1, x2, /): +def atan2(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.arctan2 <numpy.arctan2>`. @@ -70,7 +74,7 @@ def atan2(x1, x2, /): # Note: the function name is different here return np.arctan2(x1, x2) -def atanh(x, /): +def atanh(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.arctanh <numpy.arctanh>`. @@ -79,7 +83,7 @@ def atanh(x, /): # Note: the function name is different here return np.arctanh(x) -def bitwise_and(x1, x2, /): +def bitwise_and(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.bitwise_and <numpy.bitwise_and>`. @@ -87,7 +91,7 @@ def bitwise_and(x1, x2, /): """ return np.bitwise_and(x1, x2) -def bitwise_left_shift(x1, x2, /): +def bitwise_left_shift(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.left_shift <numpy.left_shift>`. @@ -96,7 +100,7 @@ def bitwise_left_shift(x1, x2, /): # Note: the function name is different here return np.left_shift(x1, x2) -def bitwise_invert(x, /): +def bitwise_invert(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.invert <numpy.invert>`. @@ -105,7 +109,7 @@ def bitwise_invert(x, /): # Note: the function name is different here return np.invert(x) -def bitwise_or(x1, x2, /): +def bitwise_or(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.bitwise_or <numpy.bitwise_or>`. @@ -113,7 +117,7 @@ def bitwise_or(x1, x2, /): """ return np.bitwise_or(x1, x2) -def bitwise_right_shift(x1, x2, /): +def bitwise_right_shift(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.right_shift <numpy.right_shift>`. @@ -122,7 +126,7 @@ def bitwise_right_shift(x1, x2, /): # Note: the function name is different here return np.right_shift(x1, x2) -def bitwise_xor(x1, x2, /): +def bitwise_xor(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.bitwise_xor <numpy.bitwise_xor>`. @@ -130,7 +134,7 @@ def bitwise_xor(x1, x2, /): """ return np.bitwise_xor(x1, x2) -def ceil(x, /): +def ceil(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.ceil <numpy.ceil>`. @@ -138,7 +142,7 @@ def ceil(x, /): """ return np.ceil(x) -def cos(x, /): +def cos(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.cos <numpy.cos>`. @@ -146,7 +150,7 @@ def cos(x, /): """ return np.cos(x) -def cosh(x, /): +def cosh(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.cosh <numpy.cosh>`. @@ -154,7 +158,7 @@ def cosh(x, /): """ return np.cosh(x) -def divide(x1, x2, /): +def divide(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.divide <numpy.divide>`. @@ -162,7 +166,7 @@ def divide(x1, x2, /): """ return np.divide(x1, x2) -def equal(x1, x2, /): +def equal(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.equal <numpy.equal>`. @@ -170,7 +174,7 @@ def equal(x1, x2, /): """ return np.equal(x1, x2) -def exp(x, /): +def exp(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.exp <numpy.exp>`. @@ -178,7 +182,7 @@ def exp(x, /): """ return np.exp(x) -def expm1(x, /): +def expm1(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.expm1 <numpy.expm1>`. @@ -186,7 +190,7 @@ def expm1(x, /): """ return np.expm1(x) -def floor(x, /): +def floor(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.floor <numpy.floor>`. @@ -194,7 +198,7 @@ def floor(x, /): """ return np.floor(x) -def floor_divide(x1, x2, /): +def floor_divide(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.floor_divide <numpy.floor_divide>`. @@ -202,7 +206,7 @@ def floor_divide(x1, x2, /): """ return np.floor_divide(x1, x2) -def greater(x1, x2, /): +def greater(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.greater <numpy.greater>`. @@ -210,7 +214,7 @@ def greater(x1, x2, /): """ return np.greater(x1, x2) -def greater_equal(x1, x2, /): +def greater_equal(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.greater_equal <numpy.greater_equal>`. @@ -218,7 +222,7 @@ def greater_equal(x1, x2, /): """ return np.greater_equal(x1, x2) -def isfinite(x, /): +def isfinite(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.isfinite <numpy.isfinite>`. @@ -226,7 +230,7 @@ def isfinite(x, /): """ return np.isfinite(x) -def isinf(x, /): +def isinf(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.isinf <numpy.isinf>`. @@ -234,7 +238,7 @@ def isinf(x, /): """ return np.isinf(x) -def isnan(x, /): +def isnan(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.isnan <numpy.isnan>`. @@ -242,7 +246,7 @@ def isnan(x, /): """ return np.isnan(x) -def less(x1, x2, /): +def less(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.less <numpy.less>`. @@ -250,7 +254,7 @@ def less(x1, x2, /): """ return np.less(x1, x2) -def less_equal(x1, x2, /): +def less_equal(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.less_equal <numpy.less_equal>`. @@ -258,7 +262,7 @@ def less_equal(x1, x2, /): """ return np.less_equal(x1, x2) -def log(x, /): +def log(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.log <numpy.log>`. @@ -266,7 +270,7 @@ def log(x, /): """ return np.log(x) -def log1p(x, /): +def log1p(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.log1p <numpy.log1p>`. @@ -274,7 +278,7 @@ def log1p(x, /): """ return np.log1p(x) -def log2(x, /): +def log2(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.log2 <numpy.log2>`. @@ -282,7 +286,7 @@ def log2(x, /): """ return np.log2(x) -def log10(x, /): +def log10(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.log10 <numpy.log10>`. @@ -290,7 +294,7 @@ def log10(x, /): """ return np.log10(x) -def logical_and(x1, x2, /): +def logical_and(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.logical_and <numpy.logical_and>`. @@ -298,7 +302,7 @@ def logical_and(x1, x2, /): """ return np.logical_and(x1, x2) -def logical_not(x, /): +def logical_not(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.logical_not <numpy.logical_not>`. @@ -306,7 +310,7 @@ def logical_not(x, /): """ return np.logical_not(x) -def logical_or(x1, x2, /): +def logical_or(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.logical_or <numpy.logical_or>`. @@ -314,7 +318,7 @@ def logical_or(x1, x2, /): """ return np.logical_or(x1, x2) -def logical_xor(x1, x2, /): +def logical_xor(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.logical_xor <numpy.logical_xor>`. @@ -322,7 +326,7 @@ def logical_xor(x1, x2, /): """ return np.logical_xor(x1, x2) -def multiply(x1, x2, /): +def multiply(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.multiply <numpy.multiply>`. @@ -330,7 +334,7 @@ def multiply(x1, x2, /): """ return np.multiply(x1, x2) -def negative(x, /): +def negative(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.negative <numpy.negative>`. @@ -338,7 +342,7 @@ def negative(x, /): """ return np.negative(x) -def not_equal(x1, x2, /): +def not_equal(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`. @@ -346,7 +350,7 @@ def not_equal(x1, x2, /): """ return np.not_equal(x1, x2) -def positive(x, /): +def positive(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.positive <numpy.positive>`. @@ -354,7 +358,7 @@ def positive(x, /): """ return np.positive(x) -def pow(x1, x2, /): +def pow(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.power <numpy.power>`. @@ -363,7 +367,7 @@ def pow(x1, x2, /): # Note: the function name is different here return np.power(x1, x2) -def remainder(x1, x2, /): +def remainder(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`. @@ -371,7 +375,7 @@ def remainder(x1, x2, /): """ return np.remainder(x1, x2) -def round(x, /): +def round(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.round <numpy.round>`. @@ -379,7 +383,7 @@ def round(x, /): """ return np.round(x) -def sign(x, /): +def sign(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.sign <numpy.sign>`. @@ -387,7 +391,7 @@ def sign(x, /): """ return np.sign(x) -def sin(x, /): +def sin(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.sin <numpy.sin>`. @@ -395,7 +399,7 @@ def sin(x, /): """ return np.sin(x) -def sinh(x, /): +def sinh(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.sinh <numpy.sinh>`. @@ -403,7 +407,7 @@ def sinh(x, /): """ return np.sinh(x) -def square(x, /): +def square(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.square <numpy.square>`. @@ -411,7 +415,7 @@ def square(x, /): """ return np.square(x) -def sqrt(x, /): +def sqrt(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.sqrt <numpy.sqrt>`. @@ -419,7 +423,7 @@ def sqrt(x, /): """ return np.sqrt(x) -def subtract(x1, x2, /): +def subtract(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.subtract <numpy.subtract>`. @@ -427,7 +431,7 @@ def subtract(x1, x2, /): """ return np.subtract(x1, x2) -def tan(x, /): +def tan(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.tan <numpy.tan>`. @@ -435,7 +439,7 @@ def tan(x, /): """ return np.tan(x) -def tanh(x, /): +def tanh(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.tanh <numpy.tanh>`. @@ -443,7 +447,7 @@ def tanh(x, /): """ return np.tanh(x) -def trunc(x, /): +def trunc(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.trunc <numpy.trunc>`. diff --git a/numpy/_array_api/_linear_algebra_functions.py b/numpy/_array_api/_linear_algebra_functions.py index cfb184e8d..addbaeccb 100644 --- a/numpy/_array_api/_linear_algebra_functions.py +++ b/numpy/_array_api/_linear_algebra_functions.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from ._types import Literal, Optional, Tuple, Union, array + import numpy as np # def cholesky(): @@ -8,7 +12,7 @@ import numpy as np # """ # return np.cholesky() -def cross(x1, x2, /, *, axis=-1): +def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: """ Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`. @@ -16,7 +20,7 @@ def cross(x1, x2, /, *, axis=-1): """ return np.cross(x1, x2, axis=axis) -def det(x, /): +def det(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`. @@ -25,7 +29,7 @@ def det(x, /): # Note: this function is being imported from a nondefault namespace return np.linalg.det(x) -def diagonal(x, /, *, axis1=0, axis2=1, offset=0): +def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array: """ Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`. @@ -65,7 +69,7 @@ def diagonal(x, /, *, axis1=0, axis2=1, offset=0): # """ # return np.einsum() -def inv(x): +def inv(x: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`. @@ -106,7 +110,7 @@ def inv(x): # """ # return np.matrix_rank() -def norm(x, /, *, axis=None, keepdims=False, ord=None): +def norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float, Literal[np.inf, -np.inf, 'fro', 'nuc']]] = None) -> array: """ Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`. @@ -118,7 +122,7 @@ def norm(x, /, *, axis=None, keepdims=False, ord=None): # Note: this function is being imported from a nondefault namespace return np.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord) -def outer(x1, x2, /): +def outer(x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`. @@ -166,7 +170,7 @@ def outer(x1, x2, /): # """ # return np.svd() -def trace(x, /, *, axis1=0, axis2=1, offset=0): +def trace(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array: """ Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`. @@ -174,7 +178,7 @@ def trace(x, /, *, axis1=0, axis2=1, offset=0): """ return np.trace(x, axis1=axis1, axis2=axis2, offset=offset) -def transpose(x, /, *, axes=None): +def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array: """ Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`. diff --git a/numpy/_array_api/_manipulation_functions.py b/numpy/_array_api/_manipulation_functions.py index 834aa2f8f..f79ef1f9c 100644 --- a/numpy/_array_api/_manipulation_functions.py +++ b/numpy/_array_api/_manipulation_functions.py @@ -1,6 +1,10 @@ +from __future__ import annotations + +from ._types import Optional, Tuple, Union, array + import numpy as np -def concat(arrays, /, *, axis=0): +def concat(arrays: Tuple[array], /, *, axis: Optional[int] = 0) -> array: """ Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`. @@ -9,7 +13,7 @@ def concat(arrays, /, *, axis=0): # Note: the function name is different here return np.concatenate(arrays, axis=axis) -def expand_dims(x, axis, /): +def expand_dims(x: array, axis: int, /) -> array: """ Array API compatible wrapper for :py:func:`np.expand_dims <numpy.expand_dims>`. @@ -17,7 +21,7 @@ def expand_dims(x, axis, /): """ return np.expand_dims(x, axis) -def flip(x, /, *, axis=None): +def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: """ Array API compatible wrapper for :py:func:`np.flip <numpy.flip>`. @@ -25,7 +29,7 @@ def flip(x, /, *, axis=None): """ return np.flip(x, axis=axis) -def reshape(x, shape, /): +def reshape(x: array, shape: Tuple[int, ...], /) -> array: """ Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`. @@ -33,7 +37,7 @@ def reshape(x, shape, /): """ return np.reshape(x, shape) -def roll(x, shift, /, *, axis=None): +def roll(x: array, shift: Union[int, Tuple[int, ...]], /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: """ Array API compatible wrapper for :py:func:`np.roll <numpy.roll>`. @@ -41,7 +45,7 @@ def roll(x, shift, /, *, axis=None): """ return np.roll(x, shift, axis=axis) -def squeeze(x, /, *, axis=None): +def squeeze(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: """ Array API compatible wrapper for :py:func:`np.squeeze <numpy.squeeze>`. @@ -49,7 +53,7 @@ def squeeze(x, /, *, axis=None): """ return np.squeeze(x, axis=axis) -def stack(arrays, /, *, axis=0): +def stack(arrays: Tuple[array], /, *, axis: Optional[int] = 0) -> array: """ Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`. diff --git a/numpy/_array_api/_searching_functions.py b/numpy/_array_api/_searching_functions.py index 4eed66c48..3b37167af 100644 --- a/numpy/_array_api/_searching_functions.py +++ b/numpy/_array_api/_searching_functions.py @@ -1,6 +1,10 @@ +from __future__ import annotations + +from ._types import Tuple, array + import numpy as np -def argmax(x, /, *, axis=None, keepdims=False): +def argmax(x: array, /, *, axis: int = None, keepdims: bool = False) -> array: """ Array API compatible wrapper for :py:func:`np.argmax <numpy.argmax>`. @@ -8,7 +12,7 @@ def argmax(x, /, *, axis=None, keepdims=False): """ return np.argmax(x, axis=axis, keepdims=keepdims) -def argmin(x, /, *, axis=None, keepdims=False): +def argmin(x: array, /, *, axis: int = None, keepdims: bool = False) -> array: """ Array API compatible wrapper for :py:func:`np.argmin <numpy.argmin>`. @@ -16,7 +20,7 @@ def argmin(x, /, *, axis=None, keepdims=False): """ return np.argmin(x, axis=axis, keepdims=keepdims) -def nonzero(x, /): +def nonzero(x: array, /) -> Tuple[array, ...]: """ Array API compatible wrapper for :py:func:`np.nonzero <numpy.nonzero>`. @@ -24,7 +28,7 @@ def nonzero(x, /): """ return np.nonzero(x) -def where(condition, x1, x2, /): +def where(condition: array, x1: array, x2: array, /) -> array: """ Array API compatible wrapper for :py:func:`np.where <numpy.where>`. diff --git a/numpy/_array_api/_set_functions.py b/numpy/_array_api/_set_functions.py index fd1438be5..80288c57d 100644 --- a/numpy/_array_api/_set_functions.py +++ b/numpy/_array_api/_set_functions.py @@ -1,6 +1,10 @@ +from __future__ import annotations + +from ._types import Tuple, Union, array + import numpy as np -def unique(x, /, *, return_counts=False, return_index=False, return_inverse=False, sorted=True): +def unique(x: array, /, *, return_counts: bool = False, return_index: bool = False, return_inverse: bool = False, sorted: bool = True) -> Union[array, Tuple[array, ...]]: """ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`. diff --git a/numpy/_array_api/_sorting_functions.py b/numpy/_array_api/_sorting_functions.py index 5ffe6c8f9..cddfd1598 100644 --- a/numpy/_array_api/_sorting_functions.py +++ b/numpy/_array_api/_sorting_functions.py @@ -1,6 +1,10 @@ +from __future__ import annotations + +from ._types import array + import numpy as np -def argsort(x, /, *, axis=-1, descending=False, stable=True): +def argsort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> array: """ Array API compatible wrapper for :py:func:`np.argsort <numpy.argsort>`. @@ -13,7 +17,7 @@ def argsort(x, /, *, axis=-1, descending=False, stable=True): res = np.flip(res, axis=axis) return res -def sort(x, /, *, axis=-1, descending=False, stable=True): +def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> array: """ Array API compatible wrapper for :py:func:`np.sort <numpy.sort>`. diff --git a/numpy/_array_api/_statistical_functions.py b/numpy/_array_api/_statistical_functions.py index 833c47f66..020053896 100644 --- a/numpy/_array_api/_statistical_functions.py +++ b/numpy/_array_api/_statistical_functions.py @@ -1,24 +1,28 @@ +from __future__ import annotations + +from ._types import Optional, Tuple, Union, array + import numpy as np -def max(x, /, *, axis=None, keepdims=False): +def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: return np.max(x, axis=axis, keepdims=keepdims) -def mean(x, /, *, axis=None, keepdims=False): +def mean(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: return np.mean(x, axis=axis, keepdims=keepdims) -def min(x, /, *, axis=None, keepdims=False): +def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: return np.min(x, axis=axis, keepdims=keepdims) -def prod(x, /, *, axis=None, keepdims=False): +def prod(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: return np.prod(x, axis=axis, keepdims=keepdims) -def std(x, /, *, axis=None, correction=0.0, keepdims=False): +def std(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array: # Note: the keyword argument correction is different here return np.std(x, axis=axis, ddof=correction, keepdims=keepdims) -def sum(x, /, *, axis=None, keepdims=False): +def sum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: return np.sum(x, axis=axis, keepdims=keepdims) -def var(x, /, *, axis=None, correction=0.0, keepdims=False): +def var(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array: # Note: the keyword argument correction is different here return np.var(x, axis=axis, ddof=correction, keepdims=keepdims) diff --git a/numpy/_array_api/_types.py b/numpy/_array_api/_types.py new file mode 100644 index 000000000..e8867a29b --- /dev/null +++ b/numpy/_array_api/_types.py @@ -0,0 +1,18 @@ +""" +This file defines the types for type annotations. + +These names aren't part of the module namespace, but they are used in the +annotations in the function signatures. The functions in the module are only +valid for inputs that match the given type annotations. +""" + +__all__ = ['Literal', 'Optional', 'Tuple', 'Union', 'array', 'device', 'dtype'] + +from typing import Literal, Optional, Tuple, Union, TypeVar + +import numpy as np + +array = np.ndarray +device = TypeVar('device') +dtype = Literal[np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, + np.uint32, np.uint64, np.float32, np.float64] diff --git a/numpy/_array_api/_utility_functions.py b/numpy/_array_api/_utility_functions.py index 19743d15c..69e17e0e5 100644 --- a/numpy/_array_api/_utility_functions.py +++ b/numpy/_array_api/_utility_functions.py @@ -1,6 +1,10 @@ +from __future__ import annotations + +from ._types import Optional, Tuple, Union, array + import numpy as np -def all(x, /, *, axis=None, keepdims=False): +def all(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: """ Array API compatible wrapper for :py:func:`np.all <numpy.all>`. @@ -8,7 +12,7 @@ def all(x, /, *, axis=None, keepdims=False): """ return np.all(x, axis=axis, keepdims=keepdims) -def any(x, /, *, axis=None, keepdims=False): +def any(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: """ Array API compatible wrapper for :py:func:`np.any <numpy.any>`. |