diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2014-02-12 11:57:27 +0100 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2014-02-16 00:00:22 +0100 |
commit | ab04e1ae0e8eca717bc7e42f3b0a60c9ff764289 (patch) | |
tree | 49ea02f820c4ee3eb484578abd0078f543ef4898 /numpy | |
parent | 58e9e27c0c110f9be1558a53fb547dc1abc76fa4 (diff) | |
download | numpy-ab04e1ae0e8eca717bc7e42f3b0a60c9ff764289.tar.gz |
BUG: Force allclose logic to use inexact type
Casting y to an inexact type fixes problems such as
abs(MIN_INT) < 0, and generally makes sense since the allclose
logic is inherently for float types.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/numeric.py | 12 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 7 | ||||
-rw-r--r-- | numpy/ma/core.py | 25 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 4 | ||||
-rw-r--r-- | numpy/testing/tests/test_utils.py | 8 | ||||
-rw-r--r-- | numpy/testing/utils.py | 14 |
6 files changed, 47 insertions, 23 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 6b078ae31..3b8e52e71 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -2139,6 +2139,11 @@ def allclose(a, b, rtol=1.e-5, atol=1.e-8): x = array(a, copy=False, ndmin=1) y = array(b, copy=False, ndmin=1) + # make sure y is an inexact type to avoid abs(MIN_INT); will cause + # casting of x later. + dtype = multiarray.result_type(y, 1.) + y = array(y, dtype=dtype, copy=False) + xinf = isinf(x) yinf = isinf(y) if any(xinf) or any(yinf): @@ -2154,12 +2159,7 @@ def allclose(a, b, rtol=1.e-5, atol=1.e-8): # ignore invalid fpe's with errstate(invalid='ignore'): - if not x.dtype.kind == 'b' and not y.dtype.kind == 'b': - diff = abs(x - y) - else: - diff = x ^ y - - r = all(less_equal(diff, atol + rtol * abs(y))) + r = all(less_equal(abs(x - y), atol + rtol * abs(y))) return r diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index ac341468c..12a39a522 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -1420,6 +1420,13 @@ class TestAllclose(object): assert_array_equal(y, array([0, inf])) + def test_min_int(self): + # Could make problems because of abs(min_int) == min_int + min_int = np.iinfo(np.int_).min + a = np.array([min_int], dtype=np.int_) + assert_(allclose(a, a)) + + class TestIsclose(object): rtol = 1e-5 atol = 1e-8 diff --git a/numpy/ma/core.py b/numpy/ma/core.py index c62e55c45..16df1ea76 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -6916,6 +6916,13 @@ def allclose (a, b, masked_equal=True, rtol=1e-5, atol=1e-8): """ x = masked_array(a, copy=False) y = masked_array(b, copy=False) + + # make sure y is an inexact type to avoid abs(MIN_INT); will cause + # casting of x later. + dtype = np.result_type(y, 1.) + if y.dtype != dtype: + y = masked_array(y, dtype=dtype, copy=False) + m = mask_or(getmask(x), getmask(y)) xinf = np.isinf(masked_array(x, copy=False, mask=m)).filled(False) # If we have some infs, they should fall at the same place. @@ -6923,26 +6930,20 @@ def allclose (a, b, masked_equal=True, rtol=1e-5, atol=1e-8): return False # No infs at all if not np.any(xinf): - if not x.dtype.kind == 'b' and not y.dtype.kind == 'b': - diff = umath.absolute(x - y) - else: - diff = x ^ y - - d = filled(umath.less_equal(diff, atol + rtol * umath.absolute(y)), + d = filled(umath.less_equal(umath.absolute(x - y), + atol + rtol * umath.absolute(y)), masked_equal) return np.all(d) + if not np.all(filled(x[xinf] == y[xinf], masked_equal)): return False x = x[~xinf] y = y[~xinf] - if not x.dtype.kind == 'b' and not y.dtype.kind == 'b': - diff = umath.absolute(x - y) - else: - diff = x ^ y - - d = filled(umath.less_equal(diff, atol + rtol * umath.absolute(y)), + d = filled(umath.less_equal(umath.absolute(x - y), + atol + rtol * umath.absolute(y)), masked_equal) + return np.all(d) #.............................................................................. diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 8d8e1c947..19f13a8c4 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1995,6 +1995,10 @@ class TestMaskedArrayMethods(TestCase): a[0] = 0 self.assertTrue(allclose(a, 0, masked_equal=True)) + # Test that the function works for MIN_INT integer typed arrays + a = masked_array([np.iinfo(np.int_).min], dtype=np.int_) + self.assertTrue(allclose(a, a)) + def test_allany(self): # Checks the any/all methods/functions. x = np.array([[0.13, 0.26, 0.90], diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 94fc4d655..5956a4294 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -53,6 +53,9 @@ class _GenericTest(object): a = np.array([1, 1], dtype=np.object) self._test_equal(a, 1) + def test_array_likes(self): + self._test_equal([1, 2, 3], (1, 2, 3)) + class TestArrayEqual(_GenericTest, unittest.TestCase): def setUp(self): self._assert_func = assert_array_equal @@ -373,6 +376,11 @@ class TestAssertAllclose(unittest.TestCase): assert_allclose(6, 10, rtol=0.5) self.assertRaises(AssertionError, assert_allclose, 10, 6, rtol=0.5) + def test_min_int(self): + a = np.array([np.iinfo(np.int_).min], dtype=np.int_) + # Should not raise: + assert_allclose(a, a) + class TestArrayAlmostEqualNulp(unittest.TestCase): @dec.knownfailureif(True, "Github issue #347") diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 82aa1e39c..97908c7e8 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -793,7 +793,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): y: array([ 1. , 2.33333, 5. ]) """ - from numpy.core import around, number, float_ + from numpy.core import around, number, float_, result_type, array from numpy.core.numerictypes import issubdtype from numpy.core.fromnumeric import any as npany def compare(x, y): @@ -811,17 +811,21 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): except (TypeError, NotImplementedError): pass - if x.dtype.kind == 'b' and y.dtype.kind == 'b': - z = x ^ y - else: - z = abs(x-y) + # make sure y is an inexact type to avoid abs(MIN_INT); will cause + # casting of x later. + dtype = result_type(y, 1.) + y = array(y, dtype=dtype, copy=False) + z = abs(x-y) if not issubdtype(z.dtype, number): z = z.astype(float_) # handle object arrays + return around(z, decimal) <= 10.0**(-decimal) + assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, header=('Arrays are not almost equal to %d decimals' % decimal)) + def assert_array_less(x, y, err_msg='', verbose=True): """ Raise an assertion if two array_like objects are not ordered by less than. |