From 74a3ee7a8b75bf6dc271c9a1a4b55d2ad9758420 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 6 Dec 2021 13:59:08 -0700 Subject: ENH: Add __array__ to the array_api Array object This is *NOT* part of the array API spec (so it should not be relied on for portable code). However, without this, np.asarray(np.array_api.Array) produces an object array instead of doing the conversion to a NumPy array as expected. This would work once np.asarray() implements dlpack support, but until then, it seems reasonable to make the conversion work. Note that the reverse, calling np.array_api.asarray(np.array), already works because np.array_api.asarray() is just a wrapper for np.asarray(). --- numpy/array_api/_array_object.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'numpy/array_api/_array_object.py') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index ead061882..d322e6ca6 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -108,6 +108,17 @@ class Array: mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) return prefix + mid + suffix + # This function is not required by the spec, but we implement it here for + # convenience so that np.asarray(np.array_api.Array) will work. + def __array__(self, dtype=None): + """ + Warning: this method is NOT part of the array API spec. Implementers + of other libraries need not include it, and users should not assume it + will be present in other implementations. + + """ + return np.asarray(self._array, dtype=dtype) + # These are various helper functions to make the array behavior match the # spec in places where it either deviates from or is more strict than # NumPy behavior -- cgit v1.2.1 From 5f21063cc317d92a866c7259a9509f5e5d6189c2 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 6 Dec 2021 17:19:55 -0700 Subject: Add type hints to the numpy.array_api.Array.__array__ signature Thanks @BvB93 --- numpy/array_api/_array_object.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'numpy/array_api/_array_object.py') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index d322e6ca6..75baf34b0 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -33,6 +33,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union, Any if TYPE_CHECKING: from ._typing import Any, PyCapsule, Device, Dtype + import numpy.typing as npt import numpy as np @@ -110,7 +111,7 @@ class Array: # This function is not required by the spec, but we implement it here for # convenience so that np.asarray(np.array_api.Array) will work. - def __array__(self, dtype=None): + def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]: """ Warning: this method is NOT part of the array API spec. Implementers of other libraries need not include it, and users should not assume it -- cgit v1.2.1