summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/shape_base.py60
-rw-r--r--numpy/core/shape_base.pyi43
-rw-r--r--numpy/core/tests/test_shape_base.py62
3 files changed, 147 insertions, 18 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py
index 1a4198c5f..c5e0ad475 100644
--- a/numpy/core/shape_base.py
+++ b/numpy/core/shape_base.py
@@ -215,12 +215,13 @@ def _arrays_for_stack_dispatcher(arrays, stacklevel=4):
return arrays
-def _vhstack_dispatcher(tup):
+def _vhstack_dispatcher(tup, *,
+ dtype=None, casting=None):
return _arrays_for_stack_dispatcher(tup)
@array_function_dispatch(_vhstack_dispatcher)
-def vstack(tup):
+def vstack(tup, *, dtype=None, casting="same_kind"):
"""
Stack arrays in sequence vertically (row wise).
@@ -239,6 +240,17 @@ def vstack(tup):
The arrays must have the same shape along all but the first axis.
1-D arrays must have the same length.
+ dtype : str or dtype
+ If provided, the destination array will have this dtype. Cannot be
+ provided together with `out`.
+
+ .. versionadded:: 1.24
+
+ casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
+ Controls what kind of data casting may occur. Defaults to 'same_kind'.
+
+ .. versionadded:: 1.24
+
Returns
-------
stacked : ndarray
@@ -279,11 +291,11 @@ def vstack(tup):
arrs = atleast_2d(*tup)
if not isinstance(arrs, list):
arrs = [arrs]
- return _nx.concatenate(arrs, 0)
+ return _nx.concatenate(arrs, 0, dtype=dtype, casting=casting)
@array_function_dispatch(_vhstack_dispatcher)
-def hstack(tup):
+def hstack(tup, *, dtype=None, casting="same_kind"):
"""
Stack arrays in sequence horizontally (column wise).
@@ -302,6 +314,17 @@ def hstack(tup):
The arrays must have the same shape along all but the second axis,
except 1-D arrays which can be any length.
+ dtype : str or dtype
+ If provided, the destination array will have this dtype. Cannot be
+ provided together with `out`.
+
+ .. versionadded:: 1.24
+
+ casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
+ Controls what kind of data casting may occur. Defaults to 'same_kind'.
+
+ .. versionadded:: 1.24
+
Returns
-------
stacked : ndarray
@@ -340,12 +363,13 @@ def hstack(tup):
arrs = [arrs]
# As a special case, dimension 0 of 1-dimensional arrays is "horizontal"
if arrs and arrs[0].ndim == 1:
- return _nx.concatenate(arrs, 0)
+ return _nx.concatenate(arrs, 0, dtype=dtype, casting=casting)
else:
- return _nx.concatenate(arrs, 1)
+ return _nx.concatenate(arrs, 1, dtype=dtype, casting=casting)
-def _stack_dispatcher(arrays, axis=None, out=None):
+def _stack_dispatcher(arrays, axis=None, out=None, *,
+ dtype=None, casting=None):
arrays = _arrays_for_stack_dispatcher(arrays, stacklevel=6)
if out is not None:
# optimize for the typical case where only arrays is provided
@@ -355,7 +379,7 @@ def _stack_dispatcher(arrays, axis=None, out=None):
@array_function_dispatch(_stack_dispatcher)
-def stack(arrays, axis=0, out=None):
+def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
"""
Join a sequence of arrays along a new axis.
@@ -378,6 +402,18 @@ def stack(arrays, axis=0, out=None):
correct, matching that of what stack would have returned if no
out argument were specified.
+ dtype : str or dtype
+ If provided, the destination array will have this dtype. Cannot be
+ provided together with `out`.
+
+ .. versionadded:: 1.24
+
+ casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
+ Controls what kind of data casting may occur. Defaults to 'same_kind'.
+
+ .. versionadded:: 1.24
+
+
Returns
-------
stacked : ndarray
@@ -430,7 +466,8 @@ def stack(arrays, axis=0, out=None):
sl = (slice(None),) * axis + (_nx.newaxis,)
expanded_arrays = [arr[sl] for arr in arrays]
- return _nx.concatenate(expanded_arrays, axis=axis, out=out)
+ return _nx.concatenate(expanded_arrays, axis=axis, out=out,
+ dtype=dtype, casting=casting)
# Internal functions to eliminate the overhead of repeated dispatch in one of
@@ -438,7 +475,8 @@ def stack(arrays, axis=0, out=None):
# Use getattr to protect against __array_function__ being disabled.
_size = getattr(_from_nx.size, '__wrapped__', _from_nx.size)
_ndim = getattr(_from_nx.ndim, '__wrapped__', _from_nx.ndim)
-_concatenate = getattr(_from_nx.concatenate, '__wrapped__', _from_nx.concatenate)
+_concatenate = getattr(_from_nx.concatenate,
+ '__wrapped__', _from_nx.concatenate)
def _block_format_index(index):
@@ -539,7 +577,7 @@ def _concatenate_shapes(shapes, axis):
"""Given array shapes, return the resulting shape and slices prefixes.
These help in nested concatenation.
-
+
Returns
-------
shape: tuple of int
diff --git a/numpy/core/shape_base.pyi b/numpy/core/shape_base.pyi
index cea355d44..82541b55b 100644
--- a/numpy/core/shape_base.pyi
+++ b/numpy/core/shape_base.pyi
@@ -1,8 +1,8 @@
from collections.abc import Sequence
from typing import TypeVar, overload, Any, SupportsIndex
-from numpy import generic
-from numpy._typing import ArrayLike, NDArray, _ArrayLike
+from numpy import generic, _CastingKind
+from numpy._typing import ArrayLike, NDArray, _ArrayLike, DTypeLike
_SCT = TypeVar("_SCT", bound=generic)
_ArrayType = TypeVar("_ArrayType", bound=NDArray[Any])
@@ -31,32 +31,61 @@ def atleast_3d(arys: ArrayLike, /) -> NDArray[Any]: ...
def atleast_3d(*arys: ArrayLike) -> list[NDArray[Any]]: ...
@overload
-def vstack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
+def vstack(
+ tup: Sequence[_ArrayLike[_SCT]],
+ *,
+ dtype: DTypeLike = ...,
+ casting: _CastingKind = ...
+) -> NDArray[_SCT]: ...
@overload
-def vstack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
+def vstack(
+ tup: Sequence[ArrayLike],
+ *,
+ dtype: DTypeLike = ...,
+ casting: _CastingKind = ...
+) -> NDArray[Any]: ...
@overload
-def hstack(tup: Sequence[_ArrayLike[_SCT]]) -> NDArray[_SCT]: ...
+def hstack(
+ tup: Sequence[_ArrayLike[_SCT]],
+ *,
+ dtype: DTypeLike = ...,
+ casting: _CastingKind = ...
+) -> NDArray[_SCT]: ...
@overload
-def hstack(tup: Sequence[ArrayLike]) -> NDArray[Any]: ...
+def hstack(
+ tup: Sequence[ArrayLike],
+ *,
+ dtype: DTypeLike = ...,
+ casting: _CastingKind = ...
+) -> NDArray[Any]: ...
@overload
def stack(
arrays: Sequence[_ArrayLike[_SCT]],
axis: SupportsIndex = ...,
- out: None = ...,
+ out: None = ...,
+ *,
+ dtype: DTypeLike = ...,
+ casting: _CastingKind = ...
) -> NDArray[_SCT]: ...
@overload
def stack(
arrays: Sequence[ArrayLike],
axis: SupportsIndex = ...,
out: None = ...,
+ *,
+ dtype: DTypeLike = ...,
+ casting: _CastingKind = ...
) -> NDArray[Any]: ...
@overload
def stack(
arrays: Sequence[ArrayLike],
axis: SupportsIndex = ...,
out: _ArrayType = ...,
+ *,
+ dtype: DTypeLike = ...,
+ casting: _CastingKind = ...
) -> _ArrayType: ...
@overload
diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py
index 679e3c036..c8dbb144a 100644
--- a/numpy/core/tests/test_shape_base.py
+++ b/numpy/core/tests/test_shape_base.py
@@ -157,6 +157,19 @@ class TestHstack:
with assert_warns(FutureWarning):
hstack(map(lambda x: x, np.ones((3, 2))))
+ def test_casting_and_dtype(self):
+ a = np.array([1, 2, 3])
+ b = np.array([2.5, 3.5, 4.5])
+ res = np.hstack((a, b), casting="unsafe", dtype=np.int64)
+ expected_res = np.array([1, 2, 3, 2, 3, 4])
+ assert_array_equal(res, expected_res)
+
+ def test_casting_and_dtype_type_error(self):
+ a = np.array([1, 2, 3])
+ b = np.array([2.5, 3.5, 4.5])
+ with pytest.raises(TypeError):
+ hstack((a, b), casting="safe", dtype=np.int64)
+
class TestVstack:
def test_non_iterable(self):
@@ -197,6 +210,20 @@ class TestVstack:
with assert_warns(FutureWarning):
vstack((np.arange(3) for _ in range(2)))
+ def test_casting_and_dtype(self):
+ a = np.array([1, 2, 3])
+ b = np.array([2.5, 3.5, 4.5])
+ res = np.vstack((a, b), casting="unsafe", dtype=np.int64)
+ expected_res = np.array([[1, 2, 3], [2, 3, 4]])
+ assert_array_equal(res, expected_res)
+
+ def test_casting_and_dtype_type_error(self):
+ a = np.array([1, 2, 3])
+ b = np.array([2.5, 3.5, 4.5])
+ with pytest.raises(TypeError):
+ vstack((a, b), casting="safe", dtype=np.int64)
+
+
class TestConcatenate:
def test_returns_copy(self):
@@ -449,6 +476,41 @@ def test_stack():
with assert_warns(FutureWarning):
result = stack((x for x in range(3)))
assert_array_equal(result, np.array([0, 1, 2]))
+ #casting and dtype test
+ a = np.array([1, 2, 3])
+ b = np.array([2.5, 3.5, 4.5])
+ res = np.stack((a, b), axis=1, casting="unsafe", dtype=np.int64)
+ expected_res = np.array([[1, 2], [2, 3], [3, 4]])
+ assert_array_equal(res, expected_res)
+ #casting and dtype with TypeError
+ with assert_raises(TypeError):
+ stack((a, b), dtype=np.int64, axis=1, casting="safe")
+
+
+@pytest.mark.parametrize("axis", [0])
+@pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8"])
+@pytest.mark.parametrize("casting",
+ ['no', 'equiv', 'safe', 'same_kind', 'unsafe'])
+def test_stack_out_and_dtype(axis, out_dtype, casting):
+ to_concat = (array([1, 2]), array([3, 4]))
+ res = array([[1, 2], [3, 4]])
+ out = np.zeros_like(res)
+
+ if not np.can_cast(to_concat[0], out_dtype, casting=casting):
+ with assert_raises(TypeError):
+ stack(to_concat, dtype=out_dtype,
+ axis=axis, casting=casting)
+ else:
+ res_out = stack(to_concat, out=out,
+ axis=axis, casting=casting)
+ res_dtype = stack(to_concat, dtype=out_dtype,
+ axis=axis, casting=casting)
+ assert res_out is out
+ assert_array_equal(out, res_dtype)
+ assert res_dtype.dtype == out_dtype
+
+ with assert_raises(TypeError):
+ stack(to_concat, out=out, dtype=out_dtype, axis=axis)
class TestBlock: