diff options
Diffstat (limited to 'numpy/core/shape_base.py')
-rw-r--r-- | numpy/core/shape_base.py | 60 |
1 files changed, 49 insertions, 11 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 1a4198c5f..c5e0ad475 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -215,12 +215,13 @@ def _arrays_for_stack_dispatcher(arrays, stacklevel=4): return arrays -def _vhstack_dispatcher(tup): +def _vhstack_dispatcher(tup, *, + dtype=None, casting=None): return _arrays_for_stack_dispatcher(tup) @array_function_dispatch(_vhstack_dispatcher) -def vstack(tup): +def vstack(tup, *, dtype=None, casting="same_kind"): """ Stack arrays in sequence vertically (row wise). @@ -239,6 +240,17 @@ def vstack(tup): The arrays must have the same shape along all but the first axis. 1-D arrays must have the same length. + dtype : str or dtype + If provided, the destination array will have this dtype. Cannot be + provided together with `out`. + + .. versionadded:: 1.24 + + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. Defaults to 'same_kind'. + + .. versionadded:: 1.24 + Returns ------- stacked : ndarray @@ -279,11 +291,11 @@ def vstack(tup): arrs = atleast_2d(*tup) if not isinstance(arrs, list): arrs = [arrs] - return _nx.concatenate(arrs, 0) + return _nx.concatenate(arrs, 0, dtype=dtype, casting=casting) @array_function_dispatch(_vhstack_dispatcher) -def hstack(tup): +def hstack(tup, *, dtype=None, casting="same_kind"): """ Stack arrays in sequence horizontally (column wise). @@ -302,6 +314,17 @@ def hstack(tup): The arrays must have the same shape along all but the second axis, except 1-D arrays which can be any length. + dtype : str or dtype + If provided, the destination array will have this dtype. Cannot be + provided together with `out`. + + .. versionadded:: 1.24 + + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. Defaults to 'same_kind'. + + .. versionadded:: 1.24 + Returns ------- stacked : ndarray @@ -340,12 +363,13 @@ def hstack(tup): arrs = [arrs] # As a special case, dimension 0 of 1-dimensional arrays is "horizontal" if arrs and arrs[0].ndim == 1: - return _nx.concatenate(arrs, 0) + return _nx.concatenate(arrs, 0, dtype=dtype, casting=casting) else: - return _nx.concatenate(arrs, 1) + return _nx.concatenate(arrs, 1, dtype=dtype, casting=casting) -def _stack_dispatcher(arrays, axis=None, out=None): +def _stack_dispatcher(arrays, axis=None, out=None, *, + dtype=None, casting=None): arrays = _arrays_for_stack_dispatcher(arrays, stacklevel=6) if out is not None: # optimize for the typical case where only arrays is provided @@ -355,7 +379,7 @@ def _stack_dispatcher(arrays, axis=None, out=None): @array_function_dispatch(_stack_dispatcher) -def stack(arrays, axis=0, out=None): +def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"): """ Join a sequence of arrays along a new axis. @@ -378,6 +402,18 @@ def stack(arrays, axis=0, out=None): correct, matching that of what stack would have returned if no out argument were specified. + dtype : str or dtype + If provided, the destination array will have this dtype. Cannot be + provided together with `out`. + + .. versionadded:: 1.24 + + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. Defaults to 'same_kind'. + + .. versionadded:: 1.24 + + Returns ------- stacked : ndarray @@ -430,7 +466,8 @@ def stack(arrays, axis=0, out=None): sl = (slice(None),) * axis + (_nx.newaxis,) expanded_arrays = [arr[sl] for arr in arrays] - return _nx.concatenate(expanded_arrays, axis=axis, out=out) + return _nx.concatenate(expanded_arrays, axis=axis, out=out, + dtype=dtype, casting=casting) # Internal functions to eliminate the overhead of repeated dispatch in one of @@ -438,7 +475,8 @@ def stack(arrays, axis=0, out=None): # Use getattr to protect against __array_function__ being disabled. _size = getattr(_from_nx.size, '__wrapped__', _from_nx.size) _ndim = getattr(_from_nx.ndim, '__wrapped__', _from_nx.ndim) -_concatenate = getattr(_from_nx.concatenate, '__wrapped__', _from_nx.concatenate) +_concatenate = getattr(_from_nx.concatenate, + '__wrapped__', _from_nx.concatenate) def _block_format_index(index): @@ -539,7 +577,7 @@ def _concatenate_shapes(shapes, axis): """Given array shapes, return the resulting shape and slices prefixes. These help in nested concatenation. - + Returns ------- shape: tuple of int |