summaryrefslogtreecommitdiff
path: root/numpy/_array_api/_manipulation_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/_array_api/_manipulation_functions.py')
-rw-r--r--numpy/_array_api/_manipulation_functions.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/numpy/_array_api/_manipulation_functions.py b/numpy/_array_api/_manipulation_functions.py
index fa0c08d7b..6308bfc26 100644
--- a/numpy/_array_api/_manipulation_functions.py
+++ b/numpy/_array_api/_manipulation_functions.py
@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple, Union
import numpy as np
# Note: the function name is different here
-def concat(arrays: Tuple[Array, ...], /, *, axis: Optional[int] = 0) -> Array:
+def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array:
"""
Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`.
@@ -56,7 +56,7 @@ def squeeze(x: Array, /, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
"""
return Array._new(np.squeeze(x._array, axis=axis))
-def stack(arrays: Tuple[Array, ...], /, *, axis: int = 0) -> Array:
+def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
"""
Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`.