summary | refs | log | tree | commit | diff
path: root/numpy/lib
diff options
context:
space:
mode:
author: seberg <sebastian@sipsolutions.net> 2013-11-26 10:25:00 -0800
committer: seberg <sebastian@sipsolutions.net> 2013-11-26 10:25:00 -0800
commit: f4749b7b2514db9f12978438a2131df981dc14d6 (patch)
tree: 847583e7f87558cec6a91090c9fd154103f92e7b /numpy/lib
parent: 78e29a323316642899f8ff85e538b785f0d5e31f (diff)
parent: 7a497ffdecceec2d8574674a2b8b04f7927f75d4 (diff)
download: numpy-f4749b7b2514db9f12978438a2131df981dc14d6.tar.gz
Merge pull request #4077 from ogrisel/streaming-numpy-save
Streaming numpy.save to arbitrary file objects
Diffstat (limited to 'numpy/lib')
-rw-r--r-- numpy/lib/format.py 17
-rw-r--r-- numpy/lib/tests/test_io.py 42
2 files changed, 49 insertions, 10 deletions
diff --git a/numpy/lib/format.py b/numpy/lib/format.py
index c411715d8..4ac1427b4 100644
--- a/numpy/lib/format.py
+++ b/numpy/lib/format.py
@@ -399,6 +399,10 @@ def write_array(fp, array, version=(1, 0)):
raise ValueError(msg % (version,))
fp.write(magic(*version))
write_array_header_1_0(fp, header_data_from_array_1_0(array))
+
+ # Set buffer size to 16 MiB to hide the Python loop overhead.
+ buffersize = max(16 * 1024 ** 2 // array.itemsize, 1)
+
if array.dtype.hasobject:
# We contain Python objects so we cannot write out the data directly.
# Instead, we will pickle it out with version 2 of the pickle protocol.
@@ -407,14 +411,19 @@ def write_array(fp, array, version=(1, 0)):
if isfileobj(fp):
array.T.tofile(fp)
else:
- fp.write(array.T.tostring('C'))
+ for chunk in numpy.nditer(
+ array, flags=['external_loop', 'buffered', 'zerosize_ok'],
+ buffersize=buffersize, order='F'):
+ fp.write(chunk.tostring('C'))
else:
if isfileobj(fp):
array.tofile(fp)
else:
- # XXX: We could probably chunk this using something like
- # arrayterator.
- fp.write(array.tostring('C'))
+ for chunk in numpy.nditer(
+ array, flags=['external_loop', 'buffered', 'zerosize_ok'],
+ buffersize=buffersize, order='C'):
+ fp.write(chunk.tostring('C'))
+
def read_array(fp):
"""
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index 66310a509..e3ccb391c 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -104,18 +104,40 @@ class RoundtripTest(object):
self.arr = arr
self.arr_reloaded = arr_reloaded
+ def check_roundtrips(self, a):
+ self.roundtrip(a)
+ self.roundtrip(a, file_on_disk=True)
+ self.roundtrip(np.asfortranarray(a))
+ self.roundtrip(np.asfortranarray(a), file_on_disk=True)
+ if a.shape[0] > 1:
+ # neither C nor Fortran contiguous for 2D arrays or more
+ self.roundtrip(np.asfortranarray(a)[1:])
+ self.roundtrip(np.asfortranarray(a)[1:], file_on_disk=True)
+
def test_array(self):
+ a = np.array([], float)
+ self.check_roundtrips(a)
+
a = np.array([[1, 2], [3, 4]], float)
- self.roundtrip(a)
+ self.check_roundtrips(a)
a = np.array([[1, 2], [3, 4]], int)
- self.roundtrip(a)
+ self.check_roundtrips(a)
a = np.array([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]], dtype=np.csingle)
- self.roundtrip(a)
+ self.check_roundtrips(a)
a = np.array([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]], dtype=np.cdouble)
- self.roundtrip(a)
+ self.check_roundtrips(a)
+
+ def test_array_object(self):
+ if sys.version_info[:2] >= (2, 7):
+ a = np.array([], object)
+ self.check_roundtrips(a)
+
+ a = np.array([[1, 2], [3, 4]], object)
+ self.check_roundtrips(a)
+ # Fails with UnpicklingError: could not find MARK on Python 2.6
def test_1D(self):
a = np.array([1, 2, 3, 4], int)
@@ -126,22 +148,30 @@ class RoundtripTest(object):
a = np.array([[1, 2.5], [4, 7.3]])
self.roundtrip(a, file_on_disk=True, load_kwds={'mmap_mode': 'r'})
+ a = np.asfortranarray([[1, 2.5], [4, 7.3]])
+ self.roundtrip(a, file_on_disk=True, load_kwds={'mmap_mode': 'r'})
+
def test_record(self):
a = np.array([(1, 2), (3, 4)], dtype=[('x', 'i4'), ('y', 'i4')])
- self.roundtrip(a)
+ self.check_roundtrips(a)
class TestSaveLoad(RoundtripTest, TestCase):
def roundtrip(self, *args, **kwargs):
RoundtripTest.roundtrip(self, np.save, *args, **kwargs)
assert_equal(self.arr[0], self.arr_reloaded)
+ assert_equal(self.arr[0].dtype, self.arr_reloaded.dtype)
+ assert_equal(self.arr[0].flags.fnc, self.arr_reloaded.flags.fnc)
class TestSavezLoad(RoundtripTest, TestCase):
def roundtrip(self, *args, **kwargs):
RoundtripTest.roundtrip(self, np.savez, *args, **kwargs)
for n, arr in enumerate(self.arr):
- assert_equal(arr, self.arr_reloaded['arr_%d' % n])
+ reloaded = self.arr_reloaded['arr_%d' % n]
+ assert_equal(arr, reloaded)
+ assert_equal(arr.dtype, reloaded.dtype)
+ assert_equal(arr.flags.fnc, reloaded.flags.fnc)
@np.testing.dec.skipif(not IS_64BIT, "Works only with 64bit systems")
@np.testing.dec.slow