summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/_array_object.py15
-rw-r--r--numpy/_array_api/_creation_functions.py2
-rw-r--r--numpy/_array_api/_data_type_functions.py37
-rw-r--r--numpy/_array_api/_manipulation_functions.py4
4 files changed, 46 insertions, 12 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py
index 9ea0eef18..43d8a8961 100644
--- a/numpy/_array_api/_array_object.py
+++ b/numpy/_array_api/_array_object.py
@@ -396,7 +396,8 @@ class Array:
res = self._array.__le__(other._array)
return self.__class__._new(res)
- def __len__(self, /):
+ # Note: __len__ may end up being removed from the array API spec.
+ def __len__(self, /) -> int:
"""
Performs the operation __len__.
"""
@@ -843,7 +844,7 @@ class Array:
return self.__class__._new(res)
@property
- def dtype(self):
+ def dtype(self) -> Dtype:
"""
Array API compatible wrapper for :py:meth:`np.ndaray.dtype <numpy.ndarray.dtype>`.
@@ -852,7 +853,7 @@ class Array:
return self._array.dtype
@property
- def device(self):
+ def device(self) -> Device:
"""
Array API compatible wrapper for :py:meth:`np.ndaray.device <numpy.ndarray.device>`.
@@ -862,7 +863,7 @@ class Array:
raise NotImplementedError("The device attribute is not yet implemented")
@property
- def ndim(self):
+ def ndim(self) -> int:
"""
Array API compatible wrapper for :py:meth:`np.ndaray.ndim <numpy.ndarray.ndim>`.
@@ -871,7 +872,7 @@ class Array:
return self._array.ndim
@property
- def shape(self):
+ def shape(self) -> Tuple[int, ...]:
"""
Array API compatible wrapper for :py:meth:`np.ndaray.shape <numpy.ndarray.shape>`.
@@ -880,7 +881,7 @@ class Array:
return self._array.shape
@property
- def size(self):
+ def size(self) -> int:
"""
Array API compatible wrapper for :py:meth:`np.ndaray.size <numpy.ndarray.size>`.
@@ -889,7 +890,7 @@ class Array:
return self._array.size
@property
- def T(self):
+ def T(self) -> Array:
"""
Array API compatible wrapper for :py:meth:`np.ndaray.T <numpy.ndarray.T>`.
diff --git a/numpy/_array_api/_creation_functions.py b/numpy/_array_api/_creation_functions.py
index 9e9722a55..517c2bfdd 100644
--- a/numpy/_array_api/_creation_functions.py
+++ b/numpy/_array_api/_creation_functions.py
@@ -10,7 +10,7 @@ from ._dtypes import _all_dtypes
import numpy as np
-def asarray(obj: Union[float, NestedSequence[bool|int|float], SupportsDLPack, SupportsBufferProtocol], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[bool] = None) -> Array:
+def asarray(obj: Union[Array, float, NestedSequence[bool|int|float], SupportsDLPack, SupportsBufferProtocol], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[bool] = None) -> Array:
"""
Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
diff --git a/numpy/_array_api/_data_type_functions.py b/numpy/_array_api/_data_type_functions.py
index 2f304bf49..693ceae84 100644
--- a/numpy/_array_api/_data_type_functions.py
+++ b/numpy/_array_api/_data_type_functions.py
@@ -2,6 +2,7 @@ from __future__ import annotations
from ._array_object import Array
+from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ._types import List, Tuple, Union, Dtype
@@ -38,13 +39,44 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool:
from_ = from_._array
return np.can_cast(from_, to)
+# These are internal objects for the return types of finfo and iinfo, since
+# the NumPy versions contain extra data that isn't part of the spec.
+@dataclass
+class finfo_object:
+ bits: int
+ # Note: The types of the float data here are float, whereas in NumPy they
+ # are scalars of the corresponding float dtype.
+ eps: float
+ max: float
+ min: float
+ # Note: smallest_normal is part of the array API spec, but cannot be used
+ # until https://github.com/numpy/numpy/pull/18536 is merged.
+
+ # smallest_normal: float
+
+@dataclass
+class iinfo_object:
+ bits: int
+ max: int
+ min: int
+
def finfo(type: Union[Dtype, Array], /) -> finfo_object:
"""
Array API compatible wrapper for :py:func:`np.finfo <numpy.finfo>`.
See its docstring for more information.
"""
- return np.finfo(type)
+ fi = np.finfo(type)
+ # Note: The types of the float data here are float, whereas in NumPy they
+ # are scalars of the corresponding float dtype.
+ return finfo_object(
+ fi.bits,
+ float(fi.eps),
+ float(fi.max),
+ float(fi.min),
+ # TODO: Uncomment this when #18536 is merged.
+ # float(fi.smallest_normal),
+ )
def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
"""
@@ -52,7 +84,8 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
See its docstring for more information.
"""
- return np.iinfo(type)
+ ii = np.iinfo(type)
+ return iinfo_object(ii.bits, ii.max, ii.min)
def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype:
"""
diff --git a/numpy/_array_api/_manipulation_functions.py b/numpy/_array_api/_manipulation_functions.py
index fa0c08d7b..6308bfc26 100644
--- a/numpy/_array_api/_manipulation_functions.py
+++ b/numpy/_array_api/_manipulation_functions.py
@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple, Union
import numpy as np
# Note: the function name is different here
-def concat(arrays: Tuple[Array, ...], /, *, axis: Optional[int] = 0) -> Array:
+def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array:
"""
Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`.
@@ -56,7 +56,7 @@ def squeeze(x: Array, /, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
"""
return Array._new(np.squeeze(x._array, axis=axis))
-def stack(arrays: Tuple[Array, ...], /, *, axis: int = 0) -> Array:
+def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
"""
Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`.