summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/tests/test_type_check.py18
-rw-r--r--numpy/lib/type_check.py21
2 files changed, 30 insertions, 9 deletions
diff --git a/numpy/lib/tests/test_type_check.py b/numpy/lib/tests/test_type_check.py
index 8945b61ea..ce8ef2f15 100644
--- a/numpy/lib/tests/test_type_check.py
+++ b/numpy/lib/tests/test_type_check.py
@@ -359,6 +359,7 @@ class TestNanToNum(object):
assert_all(vals[0] < -1e10) and assert_all(np.isfinite(vals[0]))
assert_(vals[1] == 0)
assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2]))
+ assert_equal(type(vals), np.ndarray)
# perform the same test but in-place
with np.errstate(divide='ignore', invalid='ignore'):
@@ -369,16 +370,27 @@ class TestNanToNum(object):
assert_all(vals[0] < -1e10) and assert_all(np.isfinite(vals[0]))
assert_(vals[1] == 0)
assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2]))
+ assert_equal(type(vals), np.ndarray)
+
+ def test_array(self):
+ vals = nan_to_num([1])
+ assert_array_equal(vals, np.array([1], int))
+ assert_equal(type(vals), np.ndarray)
def test_integer(self):
vals = nan_to_num(1)
assert_all(vals == 1)
- vals = nan_to_num([1])
- assert_array_equal(vals, np.array([1], int))
+ assert_equal(type(vals), np.int_)
+
+ def test_float(self):
+ vals = nan_to_num(1.0)
+ assert_all(vals == 1.0)
+ assert_equal(type(vals), np.float_)
def test_complex_good(self):
vals = nan_to_num(1+1j)
assert_all(vals == 1+1j)
+ assert_equal(type(vals), np.complex_)
def test_complex_bad(self):
with np.errstate(divide='ignore', invalid='ignore'):
@@ -387,6 +399,7 @@ class TestNanToNum(object):
vals = nan_to_num(v)
# !! This is actually (unexpectedly) zero
assert_all(np.isfinite(vals))
+ assert_equal(type(vals), np.complex_)
def test_complex_bad2(self):
with np.errstate(divide='ignore', invalid='ignore'):
@@ -394,6 +407,7 @@ class TestNanToNum(object):
v += np.array(-1+1.j)/0.
vals = nan_to_num(v)
assert_all(np.isfinite(vals))
+ assert_equal(type(vals), np.complex_)
# Fixme
#assert_all(vals.imag > 1e10) and assert_all(np.isfinite(vals))
# !! This is actually (unexpectedly) positive
diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py
index bfb5963f2..1664e6ebb 100644
--- a/numpy/lib/type_check.py
+++ b/numpy/lib/type_check.py
@@ -330,7 +330,7 @@ def _getmaxmin(t):
def nan_to_num(x, copy=True):
"""
- Replace nan with zero and inf with large finite numbers.
+ Replace NaN with zero and infinity with large finite numbers.
If `x` is inexact, NaN is replaced by zero, and infinity and -infinity
replaced by the respectively largest and most negative finite floating
@@ -343,7 +343,7 @@ def nan_to_num(x, copy=True):
Parameters
----------
- x : array_like
+ x : scalar or array_like
Input data.
copy : bool, optional
Whether to create a copy of `x` (True) or to replace values
@@ -374,6 +374,12 @@ def nan_to_num(x, copy=True):
Examples
--------
+ >>> np.nan_to_num(np.inf)
+ 1.7976931348623157e+308
+ >>> np.nan_to_num(-np.inf)
+ -1.7976931348623157e+308
+ >>> np.nan_to_num(np.nan)
+ 0.0
>>> x = np.array([np.inf, -np.inf, np.nan, -128, 128])
>>> np.nan_to_num(x)
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000,
@@ -386,20 +392,21 @@ def nan_to_num(x, copy=True):
"""
x = _nx.array(x, subok=True, copy=copy)
xtype = x.dtype.type
+
+ isscalar = (x.ndim == 0)
+
if not issubclass(xtype, _nx.inexact):
- return x
+ return x[()] if isscalar else x
iscomplex = issubclass(xtype, _nx.complexfloating)
- isscalar = (x.ndim == 0)
- x = x[None] if isscalar else x
dest = (x.real, x.imag) if iscomplex else (x,)
maxf, minf = _getmaxmin(x.real.dtype)
for d in dest:
_nx.copyto(d, 0.0, where=isnan(d))
_nx.copyto(d, maxf, where=isposinf(d))
_nx.copyto(d, minf, where=isneginf(d))
- return x[0] if isscalar else x
+ return x[()] if isscalar else x
#-----------------------------------------------------------------------------
@@ -579,7 +586,7 @@ def common_type(*arrays):
an integer array, the minimum precision type that is returned is a
64-bit floating point dtype.
- All input arrays except int64 and uint64 can be safely cast to the
+ All input arrays except int64 and uint64 can be safely cast to the
returned dtype without loss of information.
Parameters