summaryrefslogtreecommitdiff
path: root/numpy/lib/npyio.py
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@gmail.com>2018-10-22 17:40:08 -0700
committerGitHub <noreply@github.com>2018-10-22 17:40:08 -0700
commit73151451437fa6ce0d8b5f033c1e005885f63cf8 (patch)
tree558bbcfb4ef9caef87edca554ebbf4f090b6cc17 /numpy/lib/npyio.py
parent2bdb732eb7d04a6016f5e18efe02365f47a5c6b1 (diff)
downloadnumpy-73151451437fa6ce0d8b5f033c1e005885f63cf8.tar.gz
ENH: __array_function__ support for np.lib, part 2/2 (#12119)
* ENH: __array_function__ support for np.lib, part 2 xref GH12028 np.lib.npyio through np.lib.ufunclike * Fix failures in numpy/core/tests/test_overrides.py * CLN: handle depreaction in dispatchers for np.lib.ufunclike * CLN: fewer dispatchers in lib.twodim_base * CLN: fewer dispatchers in lib.shape_base * CLN: more dispatcher consolidation * BUG: fix test failure * Use all method instead of function in assert_equal * DOC: indicate n is array_like in scimath.logn * MAINT: updates per review * MAINT: more conservative changes in assert_array_equal * MAINT: add back in comment * MAINT: casting tweaks in assert_array_equal * MAINT: fixes and tests for assert_array_equal on subclasses
Diffstat (limited to 'numpy/lib/npyio.py')
-rw-r--r--numpy/lib/npyio.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 62fc9c5b3..733795671 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -12,6 +12,7 @@ import numpy as np
from . import format
from ._datasource import DataSource
from numpy.core.multiarray import packbits, unpackbits
+from numpy.core.overrides import array_function_dispatch
from numpy.core._internal import recursive
from ._iotools import (
LineSplitter, NameValidator, StringConverter, ConverterError,
@@ -447,6 +448,11 @@ def load(file, mmap_mode=None, allow_pickle=True, fix_imports=True,
fid.close()
+def _save_dispatcher(file, arr, allow_pickle=None, fix_imports=None):
+ return (arr,)
+
+
+@array_function_dispatch(_save_dispatcher)
def save(file, arr, allow_pickle=True, fix_imports=True):
"""
Save an array to a binary file in NumPy ``.npy`` format.
@@ -525,6 +531,14 @@ def save(file, arr, allow_pickle=True, fix_imports=True):
fid.close()
+def _savez_dispatcher(file, *args, **kwds):
+ for a in args:
+ yield a
+ for v in kwds.values():
+ yield v
+
+
+@array_function_dispatch(_savez_dispatcher)
def savez(file, *args, **kwds):
"""
Save several arrays into a single file in uncompressed ``.npz`` format.
@@ -604,6 +618,14 @@ def savez(file, *args, **kwds):
_savez(file, args, kwds, False)
+def _savez_compressed_dispatcher(file, *args, **kwds):
+ for a in args:
+ yield a
+ for v in kwds.values():
+ yield v
+
+
+@array_function_dispatch(_savez_compressed_dispatcher)
def savez_compressed(file, *args, **kwds):
"""
Save several arrays into a single file in compressed ``.npz`` format.
@@ -1154,6 +1176,13 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None,
return X
+def _savetxt_dispatcher(fname, X, fmt=None, delimiter=None, newline=None,
+ header=None, footer=None, comments=None,
+ encoding=None):
+ return (X,)
+
+
+@array_function_dispatch(_savetxt_dispatcher)
def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='',
footer='', comments='# ', encoding=None):
"""