summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_function_base.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-12-07 13:22:52 -0700
committerCharles Harris <charlesr.harris@gmail.com>2015-12-07 13:42:37 -0700
commit7b137ab0d52d95d5846f903a64207f5d6c0df17e (patch)
treee962afff0300b2f141fbee8c12da639247a2f407 /numpy/lib/tests/test_function_base.py
parentdb6300726ac9f73282ce8fda5588d958c0f327b4 (diff)
downloadnumpy-7b137ab0d52d95d5846f903a64207f5d6c0df17e.tar.gz
BUG: Quick and dirty fix for interp.
The original had incorrect comparisons involving <=, <, and also failed when the number of data points was 2. This fixes the use of the comparisons and uses linear search for fewer than 5 data points. The whole routine needs a simplified rewrite, but this takes care of the bug. Closes #6468.
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r--numpy/lib/tests/test_function_base.py40
1 files changed, 36 insertions, 4 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 88c932692..a5ac78e33 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1974,10 +1974,42 @@ class TestInterp(TestCase):
assert_almost_equal(np.interp(x0, x, y), x0)
def test_right_left_behavior(self):
- assert_equal(interp([-1, 0, 1], [0], [1]), [1, 1, 1])
- assert_equal(interp([-1, 0, 1], [0], [1], left=0), [0, 1, 1])
- assert_equal(interp([-1, 0, 1], [0], [1], right=0), [1, 1, 0])
- assert_equal(interp([-1, 0, 1], [0], [1], left=0, right=0), [0, 1, 0])
+ # Needs range of sizes to test different code paths.
+ # size ==1 is special cased, 1 < size < 5 is linear search, and
+ # size >= 5 goes through local search and possibly binary search.
+ for size in range(1, 10):
+ xp = np.arange(size, dtype=np.double)
+ yp = np.ones(size, dtype=np.double)
+ incpts = np.array([-1, 0, size - 1, size], dtype=np.double)
+ decpts = incpts[::-1]
+
+ incres = interp(incpts, xp, yp)
+ decres = interp(decpts, xp, yp)
+ inctgt = np.array([1, 1, 1, 1], dtype=np.float)
+ dectgt = inctgt[::-1]
+ assert_equal(incres, inctgt)
+ assert_equal(decres, dectgt)
+
+ incres = interp(incpts, xp, yp, left=0)
+ decres = interp(decpts, xp, yp, left=0)
+ inctgt = np.array([0, 1, 1, 1], dtype=np.float)
+ dectgt = inctgt[::-1]
+ assert_equal(incres, inctgt)
+ assert_equal(decres, dectgt)
+
+ incres = interp(incpts, xp, yp, right=2)
+ decres = interp(decpts, xp, yp, right=2)
+ inctgt = np.array([1, 1, 1, 2], dtype=np.float)
+ dectgt = inctgt[::-1]
+ assert_equal(incres, inctgt)
+ assert_equal(decres, dectgt)
+
+ incres = interp(incpts, xp, yp, left=0, right=2)
+ decres = interp(decpts, xp, yp, left=0, right=2)
+ inctgt = np.array([0, 1, 1, 2], dtype=np.float)
+ dectgt = inctgt[::-1]
+ assert_equal(incres, inctgt)
+ assert_equal(decres, dectgt)
def test_scalar_interpolation_point(self):
x = np.linspace(0, 1, 5)