summaryrefslogtreecommitdiff
path: root/numpy/random/tests
diff options
context:
space:
mode:
authorWarren Weckesser <warren.weckesser@gmail.com>2020-09-02 13:36:37 -0400
committerGitHub <noreply@github.com>2020-09-02 20:36:37 +0300
commitf1d0378d3735aa97a86af3b87b8a9d9ee4575af7 (patch)
tree4c5b6fb284fb9e26ac05202ccf67da6bb6e7dea7 /numpy/random/tests
parentf7cb42d52f121656b0a96ad5a9bc5c44506d2cce (diff)
downloadnumpy-f1d0378d3735aa97a86af3b87b8a9d9ee4575af7.tar.gz
ENH: random: Add the method `permuted` to Generator. (#15121)
* ENH: random: Make _shuffle_raw and _shuffle_int standalone functions. * ENH: random: Add the method `permuted` to Generator. The method permuted(x, axis=None, out=None) shuffles an array. Unlike the existing shuffle method, it shuffles the slices along the given axis independently. Closes gh-5173.
Diffstat (limited to 'numpy/random/tests')
-rw-r--r--numpy/random/tests/test_generator_mt19937.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index bb6d25ef1..6be7d852b 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -1039,6 +1039,56 @@ class TestRandomDist:
assert_raises(np.AxisError, random.permutation, arr, 3)
assert_raises(TypeError, random.permutation, arr, slice(1, 2, None))
+ @pytest.mark.parametrize("dtype", [int, object])
+ @pytest.mark.parametrize("axis, expected",
+ [(None, np.array([[3, 7, 0, 9, 10, 11],
+ [8, 4, 2, 5, 1, 6]])),
+ (0, np.array([[6, 1, 2, 9, 10, 11],
+ [0, 7, 8, 3, 4, 5]])),
+ (1, np.array([[ 5, 3, 4, 0, 2, 1],
+ [11, 9, 10, 6, 8, 7]]))])
+ def test_permuted(self, dtype, axis, expected):
+ random = Generator(MT19937(self.seed))
+ x = np.arange(12).reshape(2, 6).astype(dtype)
+ random.permuted(x, axis=axis, out=x)
+ assert_array_equal(x, expected)
+
+ random = Generator(MT19937(self.seed))
+ x = np.arange(12).reshape(2, 6).astype(dtype)
+ y = random.permuted(x, axis=axis)
+ assert y.dtype == dtype
+ assert_array_equal(y, expected)
+
+ def test_permuted_with_strides(self):
+ random = Generator(MT19937(self.seed))
+ x0 = np.arange(22).reshape(2, 11)
+ x1 = x0.copy()
+ x = x0[:, ::3]
+ y = random.permuted(x, axis=1, out=x)
+ expected = np.array([[0, 9, 3, 6],
+ [14, 20, 11, 17]])
+ assert_array_equal(y, expected)
+ x1[:, ::3] = expected
+ # Verify that the original x0 was modified in-place as expected.
+ assert_array_equal(x1, x0)
+
+ def test_permuted_empty(self):
+ y = random.permuted([])
+ assert_array_equal(y, [])
+
+ @pytest.mark.parametrize('outshape', [(2, 3), 5])
+ def test_permuted_out_with_wrong_shape(self, outshape):
+ a = np.array([1, 2, 3])
+ out = np.zeros(outshape, dtype=a.dtype)
+ with pytest.raises(ValueError, match='same shape'):
+ random.permuted(a, out=out)
+
+ def test_permuted_out_with_wrong_type(self):
+ out = np.zeros((3, 5), dtype=np.int32)
+ x = np.ones((3, 5))
+ with pytest.raises(TypeError, match='Cannot cast'):
+ random.permuted(x, axis=1, out=out)
+
def test_beta(self):
random = Generator(MT19937(self.seed))
actual = random.beta(.1, .9, size=(3, 2))