summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOhad Ravid <ohad.rv@gmail.com>2021-03-24 15:45:16 +0200
committerGitHub <noreply@github.com>2021-03-24 08:45:16 -0500
commitbadbf70324274bdb4299d8c64d3d83a26be2d4c0 (patch)
treee83ec5e020d4d667cf1e95c9e9817c39705676bc
parent2832caab3bea702bdabe944ec21487bc93e94d59 (diff)
downloadnumpy-badbf70324274bdb4299d8c64d3d83a26be2d4c0.tar.gz
ENH: Improve performance of `np.save` for small arrays (gh-18657)
* ENH: Remove call to `_filter_header` from `_write_array_header`

  Improve performance of `np.save` by removing the call when writing the header, as it is known to be done in Python 3.

* ENH: Only call `_filter_header` from `_read_array_header` for old versions

  Improve performance of `np.load` for arrays with version >= (3,0) by removing the call, as it is known to be done in Python 3.

* ENH: Use a set of keys when checking `read_array`

  Improve performance of `np.load`.

* DOC: Improve performance of `np.{save,load}` for small arrays
-rw-r--r--doc/release/upcoming_changes/18657.performance.rst10
-rw-r--r--numpy/lib/format.py16
2 files changed, 20 insertions, 6 deletions
diff --git a/doc/release/upcoming_changes/18657.performance.rst b/doc/release/upcoming_changes/18657.performance.rst
new file mode 100644
index 000000000..b9d436725
--- /dev/null
+++ b/doc/release/upcoming_changes/18657.performance.rst
@@ -0,0 +1,10 @@
+Improve performance of ``np.save`` and ``np.load`` for small arrays
+-------------------------------------------------------------------
+``np.save`` is now a lot faster for small arrays.
+
+``np.load`` is also faster for small arrays,
+but only when the file was serialized with a format version >= ``(3, 0)``.
+
+Both speedups are achieved by removing checks that are only relevant
+for Python 2, while still maintaining compatibility with arrays
+that might have been created by Python 2.
diff --git a/numpy/lib/format.py b/numpy/lib/format.py
index 904c32cc7..ead6a0420 100644
--- a/numpy/lib/format.py
+++ b/numpy/lib/format.py
@@ -173,6 +173,7 @@ from numpy.compat import (
__all__ = []
+EXPECTED_KEYS = {'descr', 'fortran_order', 'shape'}
MAGIC_PREFIX = b'\x93NUMPY'
MAGIC_LEN = len(MAGIC_PREFIX) + 2
ARRAY_ALIGN = 64 # plausible values are powers of 2 between 16 and 4096
@@ -432,7 +433,6 @@ def _write_array_header(fp, d, version=None):
header.append("'%s': %s, " % (key, repr(value)))
header.append("}")
header = "".join(header)
- header = _filter_header(header)
if version is None:
header = _wrap_header_guess_version(header)
else:
@@ -590,7 +590,10 @@ def _read_array_header(fp, version):
# "shape" : tuple of int
# "fortran_order" : bool
# "descr" : dtype.descr
- header = _filter_header(header)
+ # Versions (2, 0) and (1, 0) could have been created by a Python 2
+ # implementation before header filtering was implemented.
+ if version <= (2, 0):
+ header = _filter_header(header)
try:
d = safe_eval(header)
except SyntaxError as e:
@@ -599,14 +602,15 @@ def _read_array_header(fp, version):
if not isinstance(d, dict):
msg = "Header is not a dictionary: {!r}"
raise ValueError(msg.format(d))
- keys = sorted(d.keys())
- if keys != ['descr', 'fortran_order', 'shape']:
+
+ if EXPECTED_KEYS != d.keys():
+ keys = sorted(d.keys())
msg = "Header does not contain the correct keys: {!r}"
- raise ValueError(msg.format(keys))
+ raise ValueError(msg.format(d.keys()))
# Sanity-check the values.
if (not isinstance(d['shape'], tuple) or
- not numpy.all([isinstance(x, int) for x in d['shape']])):
+ not all(isinstance(x, int) for x in d['shape'])):
msg = "shape is not valid: {!r}"
raise ValueError(msg.format(d['shape']))
if not isinstance(d['fortran_order'], bool):