summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorJoshua Loyal <joshua.d.loyal@gmail.com>2017-02-19 12:04:27 -0500
committerJoshua Loyal <joshua.d.loyal@gmail.com>2017-02-20 11:51:39 -0500
commit89944e80e6fc9bc47d1666b9a7572827138f90e3 (patch)
treef96157cea8d789a13b65e7a3f8e31e03b3b5d555 /numpy/lib
parenteda7009cf14a9b8e9b03ddd5a8ec369646c8525d (diff)
downloadnumpy-89944e80e6fc9bc47d1666b9a7572827138f90e3.tar.gz
ENH: Allow for an in-place nan_to_num conversion. Fixes #8634
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/tests/test_type_check.py10
-rw-r--r--numpy/lib/type_check.py11
2 files changed, 19 insertions, 2 deletions
diff --git a/numpy/lib/tests/test_type_check.py b/numpy/lib/tests/test_type_check.py
index 4523e3f24..473b558be 100644
--- a/numpy/lib/tests/test_type_check.py
+++ b/numpy/lib/tests/test_type_check.py
@@ -320,6 +320,16 @@ class TestNanToNum(TestCase):
assert_(vals[1] == 0)
assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2]))
+ # perform the same test but in-place
+ with np.errstate(divide='ignore', invalid='ignore'):
+ vals = np.array((-1., 0, 1))/0.
+ result = nan_to_num(vals, copy=False)
+
+ assert_(result is vals)
+ assert_all(vals[0] < -1e10) and assert_all(np.isfinite(vals[0]))
+ assert_(vals[1] == 0)
+ assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2]))
+
def test_integer(self):
vals = nan_to_num(1)
assert_all(vals == 1)
diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py
index 3bbee0258..a59fe3cc4 100644
--- a/numpy/lib/type_check.py
+++ b/numpy/lib/type_check.py
@@ -315,7 +315,7 @@ def _getmaxmin(t):
f = getlimits.finfo(t)
return f.max, f.min
-def nan_to_num(x):
+def nan_to_num(x, copy=True):
"""
Replace nan with zero and inf with finite numbers.
@@ -327,6 +327,13 @@ def nan_to_num(x):
----------
x : array_like
Input data.
+ copy : bool, optional
+ Whether to create a copy of `x` (True) or to replace values
+ in-place (False). The in-place operation only occurs if
+ casting to an array does not require a copy.
+ Default is True.
+
+ .. versionadded:: 1.13
Returns
-------
@@ -361,7 +368,7 @@ def nan_to_num(x):
-1.28000000e+002, 1.28000000e+002])
"""
- x = _nx.array(x, subok=True)
+ x = _nx.array(x, subok=True, copy=copy)
xtype = x.dtype.type
if not issubclass(xtype, _nx.inexact):
return x