diff options
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r-- | numpy/testing/_private/utils.py | 39 |
1 files changed, 14 insertions, 25 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 7aa5ef033..8a31fcf15 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -21,7 +21,6 @@ import pprint from numpy.core import( intp, float32, empty, arange, array_repr, ndarray, isnat, array) -from numpy.lib.utils import deprecate if sys.version_info[0] >= 3: from io import StringIO @@ -33,7 +32,7 @@ __all__ = [ 'assert_array_equal', 'assert_array_less', 'assert_string_equal', 'assert_array_almost_equal', 'assert_raises', 'build_err_msg', 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal', - 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure', + 'raises', 'rundocs', 'runstring', 'verbose', 'measure', 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex', 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings', 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings', @@ -154,22 +153,6 @@ def gisinf(x): return st -@deprecate(message="numpy.testing.rand is deprecated in numpy 1.11. " - "Use numpy.random.rand instead.") -def rand(*args): - """Returns an array of random numbers with the given shape. - - This only uses the standard library, so it is useful for testing purposes. - """ - import random - from numpy.core import zeros, float64 - results = zeros(args, float64) - f = results.flat - for i in range(len(f)): - f[i] = random.random() - return results - - if os.name == 'nt': # Code "stolen" from enthought/debug/memusage.py def GetPerformanceAttributes(object, counter, instance=None, @@ -703,7 +686,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', precision=6, equal_nan=True, equal_inf=True): __tracebackhide__ = True # Hide traceback for py.test - from numpy.core import array, array2string, isnan, inf, bool_, errstate + from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_ x = array(x, copy=False, subok=True) y = array(y, copy=False, subok=True) @@ -805,17 +788,18 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, # np.ma.masked, which is falsy). if cond != True: n_mismatch = reduced.size - reduced.sum(dtype=intp) - percent_mismatch = 100 * n_mismatch / ox.size + n_elements = flagged.size if flagged.ndim != 0 else reduced.size + percent_mismatch = 100 * n_mismatch / n_elements remarks = [ 'Mismatched elements: {} / {} ({:.3g}%)'.format( - n_mismatch, ox.size, percent_mismatch)] + n_mismatch, n_elements, percent_mismatch)] with errstate(invalid='ignore', divide='ignore'): # ignore errors for non-numeric types with contextlib.suppress(TypeError): error = abs(x - y) - max_abs_error = error.max() - if error.dtype == 'object': + max_abs_error = max(error) + if getattr(error, 'dtype', object_) == object_: remarks.append('Max absolute difference: ' + str(max_abs_error)) else: @@ -824,8 +808,13 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, # note: this definition of relative error matches that one # used by assert_allclose (found in np.isclose) - max_rel_error = (error / abs(y)).max() - if error.dtype == 'object': + # Filter values where the divisor would be zero + nonzero = bool_(y != 0) + if all(~nonzero): + max_rel_error = array(inf) + else: + max_rel_error = max(error[nonzero] / abs(y[nonzero])) + if getattr(error, 'dtype', object_) == object_: remarks.append('Max relative difference: ' + str(max_rel_error)) else: |