diff options
author | Joshua Loyal <joshua.d.loyal@gmail.com> | 2017-02-19 12:04:27 -0500 |
---|---|---|
committer | Joshua Loyal <joshua.d.loyal@gmail.com> | 2017-02-20 11:51:39 -0500 |
commit | 89944e80e6fc9bc47d1666b9a7572827138f90e3 (patch) | |
tree | f96157cea8d789a13b65e7a3f8e31e03b3b5d555 /numpy/lib | |
parent | eda7009cf14a9b8e9b03ddd5a8ec369646c8525d (diff) | |
download | numpy-89944e80e6fc9bc47d1666b9a7572827138f90e3.tar.gz |
ENH: Allow for an in-place nan_to_num conversion. Fixes #8634
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/tests/test_type_check.py | 10 | ||||
-rw-r--r-- | numpy/lib/type_check.py | 11 |
2 files changed, 19 insertions, 2 deletions
diff --git a/numpy/lib/tests/test_type_check.py b/numpy/lib/tests/test_type_check.py index 4523e3f24..473b558be 100644 --- a/numpy/lib/tests/test_type_check.py +++ b/numpy/lib/tests/test_type_check.py @@ -320,6 +320,16 @@ class TestNanToNum(TestCase): assert_(vals[1] == 0) assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2])) + # perform the same test but in-place + with np.errstate(divide='ignore', invalid='ignore'): + vals = np.array((-1., 0, 1))/0. + result = nan_to_num(vals, copy=False) + + assert_(result is vals) + assert_all(vals[0] < -1e10) and assert_all(np.isfinite(vals[0])) + assert_(vals[1] == 0) + assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2])) + def test_integer(self): vals = nan_to_num(1) assert_all(vals == 1) diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py index 3bbee0258..a59fe3cc4 100644 --- a/numpy/lib/type_check.py +++ b/numpy/lib/type_check.py @@ -315,7 +315,7 @@ def _getmaxmin(t): f = getlimits.finfo(t) return f.max, f.min -def nan_to_num(x): +def nan_to_num(x, copy=True): """ Replace nan with zero and inf with finite numbers. @@ -327,6 +327,13 @@ def nan_to_num(x): ---------- x : array_like Input data. + copy : bool, optional + Whether to create a copy of `x` (True) or to replace values + in-place (False). The in-place operation only occurs if + casting to an array does not require a copy. + Default is True. + + .. versionadded:: 1.13 Returns ------- @@ -361,7 +368,7 @@ def nan_to_num(x): -1.28000000e+002, 1.28000000e+002]) """ - x = _nx.array(x, subok=True) + x = _nx.array(x, subok=True, copy=copy) xtype = x.dtype.type if not issubclass(xtype, _nx.inexact): return x |