summaryrefslogtreecommitdiff
path: root/numpy/array_api/_manipulation_functions.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-08-04 20:15:06 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-08-04 20:15:06 -0600
commit310929d12967cb0e8e6615466ff9b9f62fc899b6 (patch)
treeaac3c8c08dbd97ef4bfc98120ae8d6114f1410fd /numpy/array_api/_manipulation_functions.py
parent6789a74312cda391b81ca803d38919555213a38f (diff)
downloadnumpy-310929d12967cb0e8e6615466ff9b9f62fc899b6.tar.gz
Fix casting for the array API concat() and stack()
Diffstat (limited to 'numpy/array_api/_manipulation_functions.py')
-rw-r--r--numpy/array_api/_manipulation_functions.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py
index fa6344beb..e68dc6fcf 100644
--- a/numpy/array_api/_manipulation_functions.py
+++ b/numpy/array_api/_manipulation_functions.py
@@ -14,10 +14,11 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i
See its docstring for more information.
"""
+ # Note: Casting rules here are different from the np.concatenate default
+ # (no for scalars with axis=None, no cross-kind casting)
+ dtype = result_type(*arrays)
arrays = tuple(a._array for a in arrays)
- # Call result type here just to raise on disallowed type combinations
- result_type(*arrays)
- return Array._new(np.concatenate(arrays, axis=axis))
+ return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype))
def expand_dims(x: Array, /, *, axis: int) -> Array:
"""
@@ -65,7 +66,7 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->
See its docstring for more information.
"""
- arrays = tuple(a._array for a in arrays)
# Call result type here just to raise on disallowed type combinations
result_type(*arrays)
+ arrays = tuple(a._array for a in arrays)
return Array._new(np.stack(arrays, axis=axis))