diff options
Diffstat (limited to 'numpy/testing')
-rw-r--r-- | numpy/testing/_private/utils.py | 17 | ||||
-rw-r--r-- | numpy/testing/tests/test_utils.py | 1 |
2 files changed, 14 insertions, 4 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index c22348103..ee8eac9e8 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -20,7 +20,7 @@ from warnings import WarningMessage import pprint from numpy.core import( - float32, empty, arange, array_repr, ndarray, isnat, array) + intp, float32, empty, arange, array_repr, ndarray, isnat, array) from numpy.lib.utils import deprecate if sys.version_info[0] >= 3: @@ -301,6 +301,11 @@ def assert_equal(actual, desired, err_msg='', verbose=True): check that all elements of these objects are equal. An exception is raised at the first conflicting values. + This function handles NaN comparisons as if NaN was a "normal" number. + That is, no assertion is raised if both objects have NaNs in the same + positions. This is in contrast to the IEEE standard on NaNs, which says + that NaN compared to anything must return False. + Parameters ---------- actual : array_like @@ -328,6 +333,11 @@ def assert_equal(actual, desired, err_msg='', verbose=True): ACTUAL: 5 DESIRED: 6 + The following comparison does not raise an exception. There are NaNs + in the inputs, but they are in the same positions. + + >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan]) + """ __tracebackhide__ = True # Hide traceback for py.test if isinstance(desired, dict): @@ -784,18 +794,17 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, if isinstance(val, bool): cond = val - reduced = [0] + reduced = array([val]) else: reduced = val.ravel() cond = reduced.all() - reduced = reduced.tolist() # The below comparison is a hack to ensure that fully masked # results, for which val.ravel().all() returns np.ma.masked, # do not trigger a failure (np.ma.masked != True evaluates as # np.ma.masked, which is falsy). if cond != True: - mismatch = 100.0 * reduced.count(0) / ox.size + mismatch = 100. * (reduced.size - reduced.sum(dtype=intp)) / ox.size remarks = ['Mismatch: {:.3g}%'.format(mismatch)] with errstate(invalid='ignore', divide='ignore'): diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 9081f3d6e..643d143ee 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -1498,6 +1498,7 @@ class TestAssertNoGcCycles(object): with assert_raises(AssertionError): assert_no_gc_cycles(make_cycle) + @pytest.mark.slow def test_fails(self): """ Test that in cases where the garbage cannot be collected, we raise an |