summaryrefslogtreecommitdiff
path: root/numpy/_array_api/_creation_functions.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-04-26 16:55:00 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-04-26 16:55:00 -0600
commitb0b2539208a650ef5651fdfb9c16d57c8412d1c7 (patch)
treedf4bd67dc8bcc7d9071ec713e9d48ffd0de587ac /numpy/_array_api/_creation_functions.py
parent6c196f540429aa0869e8fb66917a5e76447d2c02 (diff)
downloadnumpy-b0b2539208a650ef5651fdfb9c16d57c8412d1c7.tar.gz
Add meshgrid(), broadcast_arrays(), broadcast_to(), and can_cast() to the array API namespace
Diffstat (limited to 'numpy/_array_api/_creation_functions.py')
-rw-r--r--numpy/_array_api/_creation_functions.py15
1 files changed, 13 insertions, 2 deletions
diff --git a/numpy/_array_api/_creation_functions.py b/numpy/_array_api/_creation_functions.py
index 003b10afb..c6db3cb7b 100644
--- a/numpy/_array_api/_creation_functions.py
+++ b/numpy/_array_api/_creation_functions.py
@@ -3,8 +3,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- from ._types import (Optional, SupportsDLPack, SupportsBufferProtocol, Tuple,
- Union, array, device, dtype)
+ from ._types import (List, Optional, SupportsDLPack,
+ SupportsBufferProtocol, Tuple, Union, array, device,
+ dtype)
+ from collections.abc import Sequence
from ._dtypes import _all_dtypes
import numpy as np
@@ -135,6 +137,15 @@ def linspace(start: Union[int, float], stop: Union[int, float], num: int, /, *,
raise NotImplementedError("Device support is not yet implemented")
return ndarray._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
+def meshgrid(*arrays: Sequence[array], indexing: str = 'xy') -> List[array, ...]:
+ """
+ Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import ndarray
+ return [ndarray._new(array) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)]
+
def ones(shape: Union[int, Tuple[int, ...]], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
"""
Array API compatible wrapper for :py:func:`np.ones <numpy.ones>`.