diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-08-04 20:15:06 -0600 |
---|---|---|
committer | Aaron Meurer <asmeurer@gmail.com> | 2021-08-04 20:15:06 -0600 |
commit | 310929d12967cb0e8e6615466ff9b9f62fc899b6 (patch) | |
tree | aac3c8c08dbd97ef4bfc98120ae8d6114f1410fd /numpy/array_api/_manipulation_functions.py | |
parent | 6789a74312cda391b81ca803d38919555213a38f (diff) | |
download | numpy-310929d12967cb0e8e6615466ff9b9f62fc899b6.tar.gz |
Fix casting for the array API concat() and stack()
Diffstat (limited to 'numpy/array_api/_manipulation_functions.py')
-rw-r--r-- | numpy/array_api/_manipulation_functions.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py index fa6344beb..e68dc6fcf 100644 --- a/numpy/array_api/_manipulation_functions.py +++ b/numpy/array_api/_manipulation_functions.py @@ -14,10 +14,11 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i See its docstring for more information. """ + # Note: Casting rules here are different from the np.concatenate default + # (no for scalars with axis=None, no cross-kind casting) + dtype = result_type(*arrays) arrays = tuple(a._array for a in arrays) - # Call result type here just to raise on disallowed type combinations - result_type(*arrays) - return Array._new(np.concatenate(arrays, axis=axis)) + return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype)) def expand_dims(x: Array, /, *, axis: int) -> Array: """ @@ -65,7 +66,7 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> See its docstring for more information. """ - arrays = tuple(a._array for a in arrays) # Call result type here just to raise on disallowed type combinations result_type(*arrays) + arrays = tuple(a._array for a in arrays) return Array._new(np.stack(arrays, axis=axis)) |