diff options
Diffstat (limited to 'numpy/testing/tests/test_utils.py')
-rw-r--r-- | numpy/testing/tests/test_utils.py | 115 |
1 files changed, 63 insertions, 52 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 43afafaa8..c376a3852 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -327,24 +327,22 @@ class TestEqual(TestArrayEqual): self._test_not_equal(x, y) def test_error_message(self): - try: + with pytest.raises(AssertionError) as exc_info: self._assert_func(np.array([1, 2]), np.array([[1, 2]])) - except AssertionError as e: - msg = str(e) - msg2 = msg.replace("shapes (2L,), (1L, 2L)", "shapes (2,), (1, 2)") - msg_reference = textwrap.dedent("""\ + msg = str(exc_info.value) + msg2 = msg.replace("shapes (2L,), (1L, 2L)", "shapes (2,), (1, 2)") + msg_reference = textwrap.dedent("""\ - Arrays are not equal + Arrays are not equal - (shapes (2,), (1, 2) mismatch) - x: array([1, 2]) - y: array([[1, 2]])""") - try: - assert_equal(msg, msg_reference) - except AssertionError: - assert_equal(msg2, msg_reference) - else: - raise AssertionError("Did not raise") + (shapes (2,), (1, 2) mismatch) + x: array([1, 2]) + y: array([[1, 2]])""") + + try: + assert_equal(msg, msg_reference) + except AssertionError: + assert_equal(msg2, msg_reference) class TestArrayAlmostEqual(_GenericTest): @@ -509,38 +507,53 @@ class TestAlmostEqual(_GenericTest): x = np.array([1.00000000001, 2.00000000002, 3.00003]) y = np.array([1.00000000002, 2.00000000003, 3.00004]) - # test with a different amount of decimal digits - # note that we only check for the formatting of the arrays themselves - b = ('x: array([1.00000000001, 2.00000000002, 3.00003 ' - ' ])\n y: array([1.00000000002, 2.00000000003, 3.00004 ])') - try: + # Test with a different amount of decimal digits + with pytest.raises(AssertionError) as exc_info: self._assert_func(x, y, decimal=12) - except AssertionError as e: - # remove anything that's not the array string - assert_equal(str(e).split('%)\n ')[1], b) - - # with the default value of decimal digits, only the 3rd element differs - # note that we only check for the formatting of the arrays themselves - b = ('x: array([1. , 2. , 3.00003])\n y: array([1. , ' - '2. , 3.00004])') - try: + msgs = str(exc_info.value).split('\n') + assert_equal(msgs[3], 'Mismatch: 100%') + assert_equal(msgs[4], 'Max absolute difference: 1.e-05') + assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06') + assert_equal( + msgs[6], + ' x: array([1.00000000001, 2.00000000002, 3.00003 ])') + assert_equal( + msgs[7], + ' y: array([1.00000000002, 2.00000000003, 3.00004 ])') + + # With the default value of decimal digits, only the 3rd element + # differs. Note that we only check for the formatting of the arrays + # themselves. + with pytest.raises(AssertionError) as exc_info: self._assert_func(x, y) - except AssertionError as e: - # remove anything that's not the array string - assert_equal(str(e).split('%)\n ')[1], b) - - # Check the error message when input includes inf or nan + msgs = str(exc_info.value).split('\n') + assert_equal(msgs[3], 'Mismatch: 33.3%') + assert_equal(msgs[4], 'Max absolute difference: 1.e-05') + assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06') + assert_equal(msgs[6], ' x: array([1. , 2. , 3.00003])') + assert_equal(msgs[7], ' y: array([1. , 2. , 3.00004])') + + # Check the error message when input includes inf x = np.array([np.inf, 0]) y = np.array([np.inf, 1]) - try: + with pytest.raises(AssertionError) as exc_info: + self._assert_func(x, y) + msgs = str(exc_info.value).split('\n') + assert_equal(msgs[3], 'Mismatch: 50%') + assert_equal(msgs[4], 'Max absolute difference: 1.') + assert_equal(msgs[5], 'Max relative difference: 1.') + assert_equal(msgs[6], ' x: array([inf, 0.])') + assert_equal(msgs[7], ' y: array([inf, 1.])') + + # Check the error message when dividing by zero + x = np.array([1, 2]) + y = np.array([0, 0]) + with pytest.raises(AssertionError) as exc_info: self._assert_func(x, y) - except AssertionError as e: - msgs = str(e).split('\n') - # assert error percentage is 50% - assert_equal(msgs[3], '(mismatch 50.0%)') - # assert output array contains inf - assert_equal(msgs[4], ' x: array([inf, 0.])') - assert_equal(msgs[5], ' y: array([inf, 1.])') + msgs = str(exc_info.value).split('\n') + assert_equal(msgs[3], 'Mismatch: 100%') + assert_equal(msgs[4], 'Max absolute difference: 2') + assert_equal(msgs[5], 'Max relative difference: inf') def test_subclass_that_cannot_be_bool(self): # While we cannot guarantee testing functions will always work for @@ -829,12 +842,12 @@ class TestAssertAllclose(object): def test_report_fail_percentage(self): a = np.array([1, 1, 1, 1]) b = np.array([1, 1, 1, 2]) - try: + + with pytest.raises(AssertionError) as exc_info: assert_allclose(a, b) - msg = '' - except AssertionError as exc: - msg = exc.args[0] - assert_("mismatch 25.0%" in msg) + msg = str(exc_info.value) + assert_('Mismatch: 25%\nMax absolute difference: 1\n' + 'Max relative difference: 0.5' in msg) def test_equal_nan(self): a = np.array([np.nan]) @@ -1117,12 +1130,10 @@ class TestStringEqual(object): assert_string_equal("hello", "hello") assert_string_equal("hello\nmultiline", "hello\nmultiline") - try: + with pytest.raises(AssertionError) as exc_info: assert_string_equal("foo\nbar", "hello\nbar") - except AssertionError as exc: - assert_equal(str(exc), "Differences in strings:\n- foo\n+ hello") - else: - raise AssertionError("exception not raised") + msg = str(exc_info.value) + assert_equal(msg, "Differences in strings:\n- foo\n+ hello") assert_raises(AssertionError, lambda: assert_string_equal("foo", "hello")) |