summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2014-07-04 18:54:17 +0200
committerJulian Taylor <jtaylor.debian@googlemail.com>2014-07-04 19:00:59 +0200
commitd6c7a16af4de5fed3aba4dd0370d797afad1b8b1 (patch)
tree2e4a256633c2c9ff25ae1a1020be51a516e40f09 /numpy
parente8d13740980189a255c3ca31ee33b4e390c2ed75 (diff)
downloadnumpy-d6c7a16af4de5fed3aba4dd0370d797afad1b8b1.tar.gz
BUG: wrong selection for orders falling into equal ranges
When orders are selected where the kth element falls into an equal range, the last stored pivot was not the kth element. This leads to losing the ordering of smaller orders, as the following selection steps can start at index 0 again instead of at the offset of the last selection. Closes gh-4836
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/npysort/selection.c.src10
-rw-r--r--numpy/core/tests/test_multiarray.py18
2 files changed, 25 insertions, 3 deletions
diff --git a/numpy/core/src/npysort/selection.c.src b/numpy/core/src/npysort/selection.c.src
index 920c07ec6..4167b2694 100644
--- a/numpy/core/src/npysort/selection.c.src
+++ b/numpy/core/src/npysort/selection.c.src
@@ -390,7 +390,10 @@ int
/* move pivot into position */
SWAP(SORTEE(low), SORTEE(hh));
- store_pivot(hh, kth, pivots, npiv);
+ /* kth pivot stored later */
+ if (hh != kth) {
+ store_pivot(hh, kth, pivots, npiv);
+ }
if (hh >= kth)
high = hh - 1;
@@ -400,10 +403,11 @@ int
/* two elements */
if (high == low + 1) {
- if (@TYPE@_LT(v[IDX(high)], v[IDX(low)]))
+ if (@TYPE@_LT(v[IDX(high)], v[IDX(low)])) {
SWAP(SORTEE(high), SORTEE(low))
- store_pivot(low, kth, pivots, npiv);
+ }
}
+ store_pivot(kth, kth, pivots, npiv);
return 0;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 2e40a2b7c..cb5c0095c 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1356,6 +1356,12 @@ class TestMethods(TestCase):
d[i:].partition(0, kind=k)
assert_array_equal(d, tgt)
+ d = np.array([0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 9])
+ kth = [0, 3, 19, 20]
+ assert_equal(np.partition(d, kth, kind=k)[kth], (0, 3, 7, 7))
+ assert_equal(d[np.argpartition(d, kth, kind=k)][kth], (0, 3, 7, 7))
+
d = np.array([2, 1])
d.partition(0, kind=k)
assert_raises(ValueError, d.partition, 2)
@@ -1551,6 +1557,18 @@ class TestMethods(TestCase):
assert_raises(ValueError, d.partition, 2, kind=k)
assert_raises(ValueError, d.argpartition, 2, kind=k)
+ def test_partition_fuzz(self):
+ # a few rounds of random data testing
+ for j in range(10, 30):
+ for i in range(1, j - 2):
+ d = np.arange(j)
+ np.random.shuffle(d)
+ d = d % np.random.randint(2, 30)
+ idx = np.random.randint(d.size)
+ kth = [0, idx, i, i + 1]
+ tgt = np.sort(d)[kth]
+ assert_array_equal(np.partition(d, kth)[kth], tgt,
+ err_msg="data: %r\n kth: %r" % (d, kth))
def test_flatten(self):
x0 = np.array([[1, 2, 3], [4, 5, 6]], np.int32)