diff options
Diffstat (limited to 'numpy/_array_api')
-rw-r--r-- | numpy/_array_api/_array_object.py | 15 | ||||
-rw-r--r-- | numpy/_array_api/_creation_functions.py | 2 | ||||
-rw-r--r-- | numpy/_array_api/_data_type_functions.py | 37 | ||||
-rw-r--r-- | numpy/_array_api/_manipulation_functions.py | 4 |
4 files changed, 46 insertions, 12 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index 9ea0eef18..43d8a8961 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -396,7 +396,8 @@ class Array: res = self._array.__le__(other._array) return self.__class__._new(res) - def __len__(self, /): + # Note: __len__ may end up being removed from the array API spec. + def __len__(self, /) -> int: """ Performs the operation __len__. """ @@ -843,7 +844,7 @@ class Array: return self.__class__._new(res) @property - def dtype(self): + def dtype(self) -> Dtype: """ Array API compatible wrapper for :py:meth:`np.ndaray.dtype <numpy.ndarray.dtype>`. @@ -852,7 +853,7 @@ class Array: return self._array.dtype @property - def device(self): + def device(self) -> Device: """ Array API compatible wrapper for :py:meth:`np.ndaray.device <numpy.ndarray.device>`. @@ -862,7 +863,7 @@ class Array: raise NotImplementedError("The device attribute is not yet implemented") @property - def ndim(self): + def ndim(self) -> int: """ Array API compatible wrapper for :py:meth:`np.ndaray.ndim <numpy.ndarray.ndim>`. @@ -871,7 +872,7 @@ class Array: return self._array.ndim @property - def shape(self): + def shape(self) -> Tuple[int, ...]: """ Array API compatible wrapper for :py:meth:`np.ndaray.shape <numpy.ndarray.shape>`. @@ -880,7 +881,7 @@ class Array: return self._array.shape @property - def size(self): + def size(self) -> int: """ Array API compatible wrapper for :py:meth:`np.ndaray.size <numpy.ndarray.size>`. @@ -889,7 +890,7 @@ class Array: return self._array.size @property - def T(self): + def T(self) -> Array: """ Array API compatible wrapper for :py:meth:`np.ndaray.T <numpy.ndarray.T>`. diff --git a/numpy/_array_api/_creation_functions.py b/numpy/_array_api/_creation_functions.py index 9e9722a55..517c2bfdd 100644 --- a/numpy/_array_api/_creation_functions.py +++ b/numpy/_array_api/_creation_functions.py @@ -10,7 +10,7 @@ from ._dtypes import _all_dtypes import numpy as np -def asarray(obj: Union[float, NestedSequence[bool|int|float], SupportsDLPack, SupportsBufferProtocol], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[bool] = None) -> Array: +def asarray(obj: Union[Array, float, NestedSequence[bool|int|float], SupportsDLPack, SupportsBufferProtocol], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[bool] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`. diff --git a/numpy/_array_api/_data_type_functions.py b/numpy/_array_api/_data_type_functions.py index 2f304bf49..693ceae84 100644 --- a/numpy/_array_api/_data_type_functions.py +++ b/numpy/_array_api/_data_type_functions.py @@ -2,6 +2,7 @@ from __future__ import annotations from ._array_object import Array +from dataclasses import dataclass from typing import TYPE_CHECKING if TYPE_CHECKING: from ._types import List, Tuple, Union, Dtype @@ -38,13 +39,44 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: from_ = from_._array return np.can_cast(from_, to) +# These are internal objects for the return types of finfo and iinfo, since +# the NumPy versions contain extra data that isn't part of the spec. +@dataclass +class finfo_object: + bits: int + # Note: The types of the float data here are float, whereas in NumPy they + # are scalars of the corresponding float dtype. + eps: float + max: float + min: float + # Note: smallest_normal is part of the array API spec, but cannot be used + # until https://github.com/numpy/numpy/pull/18536 is merged. + + # smallest_normal: float + +@dataclass +class iinfo_object: + bits: int + max: int + min: int + def finfo(type: Union[Dtype, Array], /) -> finfo_object: """ Array API compatible wrapper for :py:func:`np.finfo <numpy.finfo>`. See its docstring for more information. """ - return np.finfo(type) + fi = np.finfo(type) + # Note: The types of the float data here are float, whereas in NumPy they + # are scalars of the corresponding float dtype. + return finfo_object( + fi.bits, + float(fi.eps), + float(fi.max), + float(fi.min), + # TODO: Uncomment this when #18536 is merged. + # float(fi.smallest_normal), + ) def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: """ @@ -52,7 +84,8 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: See its docstring for more information. """ - return np.iinfo(type) + ii = np.iinfo(type) + return iinfo_object(ii.bits, ii.max, ii.min) def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype: """ diff --git a/numpy/_array_api/_manipulation_functions.py b/numpy/_array_api/_manipulation_functions.py index fa0c08d7b..6308bfc26 100644 --- a/numpy/_array_api/_manipulation_functions.py +++ b/numpy/_array_api/_manipulation_functions.py @@ -7,7 +7,7 @@ from typing import List, Optional, Tuple, Union import numpy as np # Note: the function name is different here -def concat(arrays: Tuple[Array, ...], /, *, axis: Optional[int] = 0) -> Array: +def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array: """ Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`. @@ -56,7 +56,7 @@ def squeeze(x: Array, /, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> """ return Array._new(np.squeeze(x._array, axis=axis)) -def stack(arrays: Tuple[Array, ...], /, *, axis: int = 0) -> Array: +def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: """ Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`. |