summaryrefslogtreecommitdiff
path: root/numpy/_array_api/_creation_functions.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-02-26 18:24:36 -0700
committerAaron Meurer <asmeurer@gmail.com>2021-02-26 18:24:36 -0700
commitb933ebbe1aee58af38f05a341dc3952fc761d777 (patch)
treef42e743a774097920cc06eebdae655881cdc9b17 /numpy/_array_api/_creation_functions.py
parent2ff635c7cbc8804a3956ddbf8165f536dffc2df5 (diff)
downloadnumpy-b933ebbe1aee58af38f05a341dc3952fc761d777.tar.gz
Allow dimension 0 arrays in the array API namespace full() and full_like()
Diffstat (limited to 'numpy/_array_api/_creation_functions.py')
-rw-r--r--numpy/_array_api/_creation_functions.py17
1 files changed, 15 insertions, 2 deletions
diff --git a/numpy/_array_api/_creation_functions.py b/numpy/_array_api/_creation_functions.py
index 4be482199..197960211 100644
--- a/numpy/_array_api/_creation_functions.py
+++ b/numpy/_array_api/_creation_functions.py
@@ -2,6 +2,7 @@ from __future__ import annotations
from ._types import (Optional, SupportsDLPack, SupportsBufferProtocol, Tuple,
Union, array, device, dtype)
+from ._dtypes import _all_dtypes
import numpy as np
@@ -88,7 +89,14 @@ def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], /, *
if device is not None:
# Note: Device support is not yet implemented on ndarray
raise NotImplementedError("Device support is not yet implemented")
- return ndarray._new(np.full(shape, fill_value, dtype=dtype))
+ if isinstance(fill_value, ndarray) and fill_value.ndim == 0:
+ fill_value = fill_value._array[...]
+ res = np.full(shape, fill_value, dtype=dtype)
+ if res.dtype not in _all_dtypes:
+ # This will happen if the fill value is not something that NumPy
+ # coerces to one of the acceptable dtypes.
+ raise TypeError("Invalid input to full")
+ return ndarray._new(res)
def full_like(x: array, fill_value: Union[int, float], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
"""
@@ -100,7 +108,12 @@ def full_like(x: array, fill_value: Union[int, float], /, *, dtype: Optional[dty
if device is not None:
# Note: Device support is not yet implemented on ndarray
raise NotImplementedError("Device support is not yet implemented")
- return ndarray._new(np.full_like._implementation(x._array, fill_value, dtype=dtype))
+ res = np.full_like._implementation(x._array, fill_value, dtype=dtype)
+ if res.dtype not in _all_dtypes:
+ # This will happen if the fill value is not something that NumPy
+ # coerces to one of the acceptable dtypes.
+ raise TypeError("Invalid input to full_like")
+ return ndarray._new(res)
def linspace(start: Union[int, float], stop: Union[int, float], num: int, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None, endpoint: bool = True) -> array:
"""