summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorWarren Weckesser <warren.weckesser@gmail.com>2019-10-08 11:05:26 -0400
committerWarren Weckesser <warren.weckesser@gmail.com>2019-10-08 11:05:26 -0400
commit5e9b5ec6352bee4b96cd4bbedfa17413111462b3 (patch)
tree11c1c8308e4a337b696bda257fb7996a3f9b4721 /numpy
parentd0b0c609cc614f3bc82a7cfcb98e34e939a3e8de (diff)
downloadnumpy-5e9b5ec6352bee4b96cd4bbedfa17413111462b3.tar.gz
BUG: random: Use correct length when axis is given to shuffle.
When an axis argument was given, shuffle was using the original length of the array instead of the length of the given axis. This meant that, for example, if an array with shape (2, 10) was shuffled with axis=1, only the first two columns were shuffled.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/random/generator.pyx6
-rw-r--r--numpy/random/tests/test_generator_mt19937.py9
2 files changed, 12 insertions, 3 deletions
diff --git a/numpy/random/generator.pyx b/numpy/random/generator.pyx
index 37ac57c06..df7485a97 100644
--- a/numpy/random/generator.pyx
+++ b/numpy/random/generator.pyx
@@ -3786,7 +3786,7 @@ cdef class Generator:
# Shuffling and permutations:
def shuffle(self, object x, axis=0):
"""
- shuffle(x)
+ shuffle(x, axis=0)
Modify a sequence in-place by shuffling its contents.
@@ -3858,7 +3858,7 @@ cdef class Generator:
x = np.swapaxes(x, 0, axis)
buf = np.empty_like(x[0, ...])
with self.lock:
- for i in reversed(range(1, n)):
+ for i in reversed(range(1, len(x))):
j = random_interval(&self._bitgen, i)
if i == j:
# i == j is not needed and memcpy is undefined.
@@ -3928,7 +3928,7 @@ cdef class Generator:
def permutation(self, object x, axis=0):
"""
- permutation(x)
+ permutation(x, axis=0)
Randomly permute a sequence, or return a permuted range.
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 20bc10cd0..391c33c1a 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -746,6 +746,15 @@ class TestRandomDist(object):
random.shuffle(actual, axis=-1)
assert_array_equal(actual, desired)
+ def test_shuffle_axis_nonsquare(self):
+ y1 = np.arange(20).reshape(2, 10)
+ y2 = y1.copy()
+ random = Generator(MT19937(self.seed))
+ random.shuffle(y1, axis=1)
+ random = Generator(MT19937(self.seed))
+ random.shuffle(y2.T)
+ assert_array_equal(y1, y2)
+
def test_shuffle_masked(self):
# gh-3263
a = np.ma.masked_values(np.reshape(range(20), (5, 4)) % 3 - 1, -1)