diff options
author | seberg <sebastian@sipsolutions.net> | 2013-11-26 10:25:00 -0800 |
---|---|---|
committer | seberg <sebastian@sipsolutions.net> | 2013-11-26 10:25:00 -0800 |
commit | f4749b7b2514db9f12978438a2131df981dc14d6 (patch) | |
tree | 847583e7f87558cec6a91090c9fd154103f92e7b | |
parent | 78e29a323316642899f8ff85e538b785f0d5e31f (diff) | |
parent | 7a497ffdecceec2d8574674a2b8b04f7927f75d4 (diff) | |
download | numpy-f4749b7b2514db9f12978438a2131df981dc14d6.tar.gz |
Merge pull request #4077 from ogrisel/streaming-numpy-save
Streaming numpy.save to arbitrary file objects
-rw-r--r-- | numpy/lib/format.py | 17 | ||||
-rw-r--r-- | numpy/lib/tests/test_io.py | 42 |
2 files changed, 49 insertions, 10 deletions
diff --git a/numpy/lib/format.py b/numpy/lib/format.py index c411715d8..4ac1427b4 100644 --- a/numpy/lib/format.py +++ b/numpy/lib/format.py @@ -399,6 +399,10 @@ def write_array(fp, array, version=(1, 0)): raise ValueError(msg % (version,)) fp.write(magic(*version)) write_array_header_1_0(fp, header_data_from_array_1_0(array)) + + # Set buffer size to 16 MiB to hide the Python loop overhead. + buffersize = max(16 * 1024 ** 2 // array.itemsize, 1) + if array.dtype.hasobject: # We contain Python objects so we cannot write out the data directly. # Instead, we will pickle it out with version 2 of the pickle protocol. @@ -407,14 +411,19 @@ def write_array(fp, array, version=(1, 0)): if isfileobj(fp): array.T.tofile(fp) else: - fp.write(array.T.tostring('C')) + for chunk in numpy.nditer( + array, flags=['external_loop', 'buffered', 'zerosize_ok'], + buffersize=buffersize, order='F'): + fp.write(chunk.tostring('C')) else: if isfileobj(fp): array.tofile(fp) else: - # XXX: We could probably chunk this using something like - # arrayterator. - fp.write(array.tostring('C')) + for chunk in numpy.nditer( + array, flags=['external_loop', 'buffered', 'zerosize_ok'], + buffersize=buffersize, order='C'): + fp.write(chunk.tostring('C')) + def read_array(fp): """ diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index 66310a509..e3ccb391c 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -104,18 +104,40 @@ class RoundtripTest(object): self.arr = arr self.arr_reloaded = arr_reloaded + def check_roundtrips(self, a): + self.roundtrip(a) + self.roundtrip(a, file_on_disk=True) + self.roundtrip(np.asfortranarray(a)) + self.roundtrip(np.asfortranarray(a), file_on_disk=True) + if a.shape[0] > 1: + # neither C nor Fortran contiguous for 2D arrays or more + self.roundtrip(np.asfortranarray(a)[1:]) + self.roundtrip(np.asfortranarray(a)[1:], file_on_disk=True) + def test_array(self): + a = np.array([], float) + self.check_roundtrips(a) + a = np.array([[1, 2], [3, 4]], float) - self.roundtrip(a) + self.check_roundtrips(a) a = np.array([[1, 2], [3, 4]], int) - self.roundtrip(a) + self.check_roundtrips(a) a = np.array([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]], dtype=np.csingle) - self.roundtrip(a) + self.check_roundtrips(a) a = np.array([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]], dtype=np.cdouble) - self.roundtrip(a) + self.check_roundtrips(a) + + def test_array_object(self): + if sys.version_info[:2] >= (2, 7): + a = np.array([], object) + self.check_roundtrips(a) + + a = np.array([[1, 2], [3, 4]], object) + self.check_roundtrips(a) + # Fails with UnpicklingError: could not find MARK on Python 2.6 def test_1D(self): a = np.array([1, 2, 3, 4], int) @@ -126,22 +148,30 @@ class RoundtripTest(object): a = np.array([[1, 2.5], [4, 7.3]]) self.roundtrip(a, file_on_disk=True, load_kwds={'mmap_mode': 'r'}) + a = np.asfortranarray([[1, 2.5], [4, 7.3]]) + self.roundtrip(a, file_on_disk=True, load_kwds={'mmap_mode': 'r'}) + def test_record(self): a = np.array([(1, 2), (3, 4)], dtype=[('x', 'i4'), ('y', 'i4')]) - self.roundtrip(a) + self.check_roundtrips(a) class TestSaveLoad(RoundtripTest, TestCase): def roundtrip(self, *args, **kwargs): RoundtripTest.roundtrip(self, np.save, *args, **kwargs) assert_equal(self.arr[0], self.arr_reloaded) + assert_equal(self.arr[0].dtype, self.arr_reloaded.dtype) + assert_equal(self.arr[0].flags.fnc, self.arr_reloaded.flags.fnc) class TestSavezLoad(RoundtripTest, TestCase): def roundtrip(self, *args, **kwargs): RoundtripTest.roundtrip(self, np.savez, *args, **kwargs) for n, arr in enumerate(self.arr): - assert_equal(arr, self.arr_reloaded['arr_%d' % n]) + reloaded = self.arr_reloaded['arr_%d' % n] + assert_equal(arr, reloaded) + assert_equal(arr.dtype, reloaded.dtype) + assert_equal(arr.flags.fnc, reloaded.flags.fnc) @np.testing.dec.skipif(not IS_64BIT, "Works only with 64bit systems") @np.testing.dec.slow |