diff options
author | Warren Weckesser <warren.weckesser@gmail.com> | 2019-10-08 11:05:26 -0400 |
---|---|---|
committer | Warren Weckesser <warren.weckesser@gmail.com> | 2019-10-08 11:05:26 -0400 |
commit | 5e9b5ec6352bee4b96cd4bbedfa17413111462b3 (patch) | |
tree | 11c1c8308e4a337b696bda257fb7996a3f9b4721 /numpy | |
parent | d0b0c609cc614f3bc82a7cfcb98e34e939a3e8de (diff) | |
download | numpy-5e9b5ec6352bee4b96cd4bbedfa17413111462b3.tar.gz |
BUG: random: Use correct length when axis is given to shuffle.
When an axis argument was given, shuffle was using the original length of
the array instead of the length of the given axis. This meant that, for
example, if an array with shape (2, 10) was shuffled with axis=1, only the
first two columns were shuffled.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/random/generator.pyx | 6 | ||||
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 9 |
2 files changed, 12 insertions, 3 deletions
diff --git a/numpy/random/generator.pyx b/numpy/random/generator.pyx index 37ac57c06..df7485a97 100644 --- a/numpy/random/generator.pyx +++ b/numpy/random/generator.pyx @@ -3786,7 +3786,7 @@ cdef class Generator: # Shuffling and permutations: def shuffle(self, object x, axis=0): """ - shuffle(x) + shuffle(x, axis=0) Modify a sequence in-place by shuffling its contents. @@ -3858,7 +3858,7 @@ cdef class Generator: x = np.swapaxes(x, 0, axis) buf = np.empty_like(x[0, ...]) with self.lock: - for i in reversed(range(1, n)): + for i in reversed(range(1, len(x))): j = random_interval(&self._bitgen, i) if i == j: # i == j is not needed and memcpy is undefined. @@ -3928,7 +3928,7 @@ cdef class Generator: def permutation(self, object x, axis=0): """ - permutation(x) + permutation(x, axis=0) Randomly permute a sequence, or return a permuted range. diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index 20bc10cd0..391c33c1a 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -746,6 +746,15 @@ class TestRandomDist(object): random.shuffle(actual, axis=-1) assert_array_equal(actual, desired) + def test_shuffle_axis_nonsquare(self): + y1 = np.arange(20).reshape(2, 10) + y2 = y1.copy() + random = Generator(MT19937(self.seed)) + random.shuffle(y1, axis=1) + random = Generator(MT19937(self.seed)) + random.shuffle(y2.T) + assert_array_equal(y1, y2) + def test_shuffle_masked(self): # gh-3263 a = np.ma.masked_values(np.reshape(range(20), (5, 4)) % 3 - 1, -1) |