diff options
author | Stefan van der Walt <stefan@sun.ac.za> | 2007-08-20 13:46:55 +0000 |
---|---|---|
committer | Stefan van der Walt <stefan@sun.ac.za> | 2007-08-20 13:46:55 +0000 |
commit | 30c6bcab17dc43b9683ff79bca99f1b37b0f70e1 (patch) | |
tree | 20b6e0649ad14983cce91fe4c6cb806f4dbbd9c5 | |
parent | 03064002eff18e65516e3f8885d886c45e65045b (diff) | |
download | numpy-30c6bcab17dc43b9683ff79bca99f1b37b0f70e1.tar.gz |
Fix allclose and add tests (based on a patch by Matthew Brett).
-rw-r--r-- | numpy/core/numeric.py | 21 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 57 |
2 files changed, 64 insertions, 14 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index b4709b392..4c3047720 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -835,21 +835,16 @@ def allclose(a, b, rtol=1.e-5, atol=1.e-8): """ x = array(a, copy=False) y = array(b, copy=False) - d1 = less_equal(absolute(x-y), atol + rtol * absolute(y)) xinf = isinf(x) - yinf = isinf(y) - if (not xinf.any() and not yinf.any()): - return d1.all() - d3 = (x[xinf] == y[yinf]) - d4 = (~xinf & ~yinf) - if d3.size < 2: - if d3.size==0: - return False - return d3 - if d3.all(): - return d1[d4].all() - else: + if not all(xinf == isinf(y)): + return False + if not any(xinf): + return all(less_equal(absolute(x-y), atol + rtol * absolute(y))) + if not all(x[xinf] == y[xinf]): return False + x = x[~xinf] + y = y[~xinf] + return all(less_equal(absolute(x-y), atol + rtol * absolute(y))) def array_equal(a1, a2): try: diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index f4c4431b6..3e06b4747 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -668,7 +668,62 @@ class test_clip(NumpyTestCase): self.clip(a, m, M, ac) assert_array_strict_equal(a, ac) - +class test_allclose_inf(ParametricTestCase): + rtol = 1e-5 + atol = 1e-8 + + def tst_allclose(self,x,y): + assert allclose(x,y), "%s and %s not close" % (x,y) + + def tst_not_allclose(self,x,y): + assert not allclose(x,y), "%s and %s shouldn't be close" % (x,y) + + def testip_allclose(self): + """Parametric test factory.""" + arr = array([100,1000]) + aran = arange(125).reshape((5,5,5)) + + atol = self.atol + rtol = self.rtol + + data = [([1,0], [1,0]), + ([atol], [0]), + ([1], [1+rtol+atol]), + (arr, arr + arr*rtol), + (arr, arr + arr*rtol + atol*2), + (aran, aran + aran*rtol),] + + for (x,y) in data: + yield (self.tst_allclose,x,y) + + def testip_not_allclose(self): + """Parametric test factory.""" + aran = arange(125).reshape((5,5,5)) + + atol = self.atol + rtol = self.rtol + + data = [([inf,0], [1,inf]), + ([inf,0], [1,0]), + ([inf,inf], [1,inf]), + ([inf,inf], [1,0]), + ([-inf, 0], [inf, 0]), + ([nan,0], [nan,0]), + ([atol*2], [0]), + ([1], [1+rtol+atol*2]), + (aran, aran + aran*atol + atol*2), + (array([inf,1]), array([0,inf]))] + + for (x,y) in data: + yield (self.tst_not_allclose,x,y) + + def test_no_parameter_modification(self): + x = array([inf,1]) + y = array([0,inf]) + allclose(x,y) + assert_array_equal(x,array([inf,1])) + assert_array_equal(y,array([0,inf])) + import sys if sys.version_info[:2] >= (2, 5): set_local_path() |