diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/tests/test_type_check.py | 18 | ||||
-rw-r--r-- | numpy/lib/type_check.py | 21 |
2 files changed, 30 insertions, 9 deletions
diff --git a/numpy/lib/tests/test_type_check.py b/numpy/lib/tests/test_type_check.py index 8945b61ea..ce8ef2f15 100644 --- a/numpy/lib/tests/test_type_check.py +++ b/numpy/lib/tests/test_type_check.py @@ -359,6 +359,7 @@ class TestNanToNum(object): assert_all(vals[0] < -1e10) and assert_all(np.isfinite(vals[0])) assert_(vals[1] == 0) assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2])) + assert_equal(type(vals), np.ndarray) # perform the same test but in-place with np.errstate(divide='ignore', invalid='ignore'): @@ -369,16 +370,27 @@ class TestNanToNum(object): assert_all(vals[0] < -1e10) and assert_all(np.isfinite(vals[0])) assert_(vals[1] == 0) assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2])) + assert_equal(type(vals), np.ndarray) + + def test_array(self): + vals = nan_to_num([1]) + assert_array_equal(vals, np.array([1], int)) + assert_equal(type(vals), np.ndarray) def test_integer(self): vals = nan_to_num(1) assert_all(vals == 1) - vals = nan_to_num([1]) - assert_array_equal(vals, np.array([1], int)) + assert_equal(type(vals), np.int_) + + def test_float(self): + vals = nan_to_num(1.0) + assert_all(vals == 1.0) + assert_equal(type(vals), np.float_) def test_complex_good(self): vals = nan_to_num(1+1j) assert_all(vals == 1+1j) + assert_equal(type(vals), np.complex_) def test_complex_bad(self): with np.errstate(divide='ignore', invalid='ignore'): @@ -387,6 +399,7 @@ class TestNanToNum(object): vals = nan_to_num(v) # !! This is actually (unexpectedly) zero assert_all(np.isfinite(vals)) + assert_equal(type(vals), np.complex_) def test_complex_bad2(self): with np.errstate(divide='ignore', invalid='ignore'): @@ -394,6 +407,7 @@ class TestNanToNum(object): v += np.array(-1+1.j)/0. vals = nan_to_num(v) assert_all(np.isfinite(vals)) + assert_equal(type(vals), np.complex_) # Fixme #assert_all(vals.imag > 1e10) and assert_all(np.isfinite(vals)) # !! This is actually (unexpectedly) positive diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py index bfb5963f2..1664e6ebb 100644 --- a/numpy/lib/type_check.py +++ b/numpy/lib/type_check.py @@ -330,7 +330,7 @@ def _getmaxmin(t): def nan_to_num(x, copy=True): """ - Replace nan with zero and inf with large finite numbers. + Replace NaN with zero and infinity with large finite numbers. If `x` is inexact, NaN is replaced by zero, and infinity and -infinity replaced by the respectively largest and most negative finite floating @@ -343,7 +343,7 @@ def nan_to_num(x, copy=True): Parameters ---------- - x : array_like + x : scalar or array_like Input data. copy : bool, optional Whether to create a copy of `x` (True) or to replace values @@ -374,6 +374,12 @@ def nan_to_num(x, copy=True): Examples -------- + >>> np.nan_to_num(np.inf) + 1.7976931348623157e+308 + >>> np.nan_to_num(-np.inf) + -1.7976931348623157e+308 + >>> np.nan_to_num(np.nan) + 0.0 >>> x = np.array([np.inf, -np.inf, np.nan, -128, 128]) >>> np.nan_to_num(x) array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, @@ -386,20 +392,21 @@ def nan_to_num(x, copy=True): """ x = _nx.array(x, subok=True, copy=copy) xtype = x.dtype.type + + isscalar = (x.ndim == 0) + if not issubclass(xtype, _nx.inexact): - return x + return x[()] if isscalar else x iscomplex = issubclass(xtype, _nx.complexfloating) - isscalar = (x.ndim == 0) - x = x[None] if isscalar else x dest = (x.real, x.imag) if iscomplex else (x,) maxf, minf = _getmaxmin(x.real.dtype) for d in dest: _nx.copyto(d, 0.0, where=isnan(d)) _nx.copyto(d, maxf, where=isposinf(d)) _nx.copyto(d, minf, where=isneginf(d)) - return x[0] if isscalar else x + return x[()] if isscalar else x #----------------------------------------------------------------------------- @@ -579,7 +586,7 @@ def common_type(*arrays): an integer array, the minimum precision type that is returned is a 64-bit floating point dtype. - All input arrays except int64 and uint64 can be safely cast to the + All input arrays except int64 and uint64 can be safely cast to the returned dtype without loss of information. Parameters |