summaryrefslogtreecommitdiff
path: root/numpy/array_api/_manipulation_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api/_manipulation_functions.py')
-rw-r--r--numpy/array_api/_manipulation_functions.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py
index fa6344beb..e68dc6fcf 100644
--- a/numpy/array_api/_manipulation_functions.py
+++ b/numpy/array_api/_manipulation_functions.py
@@ -14,10 +14,11 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i
See its docstring for more information.
"""
+ # Note: Casting rules here are different from the np.concatenate default
+ # (no for scalars with axis=None, no cross-kind casting)
+ dtype = result_type(*arrays)
arrays = tuple(a._array for a in arrays)
- # Call result type here just to raise on disallowed type combinations
- result_type(*arrays)
- return Array._new(np.concatenate(arrays, axis=axis))
+ return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype))
def expand_dims(x: Array, /, *, axis: int) -> Array:
"""
@@ -65,7 +66,7 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->
See its docstring for more information.
"""
- arrays = tuple(a._array for a in arrays)
# Call result type here just to raise on disallowed type combinations
result_type(*arrays)
+ arrays = tuple(a._array for a in arrays)
return Array._new(np.stack(arrays, axis=axis))