From b933ebbe1aee58af38f05a341dc3952fc761d777 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 26 Feb 2021 18:24:36 -0700 Subject: Allow dimension 0 arrays in the array API namespace full() and full_like() --- numpy/_array_api/_creation_functions.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) (limited to 'numpy/_array_api/_creation_functions.py') diff --git a/numpy/_array_api/_creation_functions.py b/numpy/_array_api/_creation_functions.py index 4be482199..197960211 100644 --- a/numpy/_array_api/_creation_functions.py +++ b/numpy/_array_api/_creation_functions.py @@ -2,6 +2,7 @@ from __future__ import annotations from ._types import (Optional, SupportsDLPack, SupportsBufferProtocol, Tuple, Union, array, device, dtype) +from ._dtypes import _all_dtypes import numpy as np @@ -88,7 +89,14 @@ def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], /, * if device is not None: # Note: Device support is not yet implemented on ndarray raise NotImplementedError("Device support is not yet implemented") - return ndarray._new(np.full(shape, fill_value, dtype=dtype)) + if isinstance(fill_value, ndarray) and fill_value.ndim == 0: + fill_value = fill_value._array[...] + res = np.full(shape, fill_value, dtype=dtype) + if res.dtype not in _all_dtypes: + # This will happen if the fill value is not something that NumPy + # coerces to one of the acceptable dtypes. + raise TypeError("Invalid input to full") + return ndarray._new(res) def full_like(x: array, fill_value: Union[int, float], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array: """ @@ -100,7 +108,12 @@ def full_like(x: array, fill_value: Union[int, float], /, *, dtype: Optional[dty if device is not None: # Note: Device support is not yet implemented on ndarray raise NotImplementedError("Device support is not yet implemented") - return ndarray._new(np.full_like._implementation(x._array, fill_value, dtype=dtype)) + res = np.full_like._implementation(x._array, fill_value, dtype=dtype) + if res.dtype not in _all_dtypes: + # This will happen if the fill value is not something that NumPy + # coerces to one of the acceptable dtypes. + raise TypeError("Invalid input to full_like") + return ndarray._new(res) def linspace(start: Union[int, float], stop: Union[int, float], num: int, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None, endpoint: bool = True) -> array: """ -- cgit v1.2.1