summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-02-06 12:24:16 -0700
committerGitHub <noreply@github.com>2021-02-06 12:24:16 -0700
commitcba30dbc4012f0968e25d4da3b48f3d9d745aa00 (patch)
tree4a8dcb99e6092709f3128d5cc20938b150317a06
parent32b564ade7ef22439b5f2b9c11aa4c63f0ecd6fd (diff)
parentb765975ca89457fea91a63b94075c9e66465e2ea (diff)
downloadnumpy-cba30dbc4012f0968e25d4da3b48f3d9d745aa00.tar.gz
Merge pull request #18332 from seberg/issue-18325
BUG: Allow pickling all relevant DType types/classes
-rw-r--r--numpy/core/__init__.py22
-rw-r--r--numpy/core/tests/test_dtype.py16
2 files changed, 36 insertions, 2 deletions
diff --git a/numpy/core/__init__.py b/numpy/core/__init__.py
index f22c86f59..dad9293e1 100644
--- a/numpy/core/__init__.py
+++ b/numpy/core/__init__.py
@@ -125,6 +125,7 @@ def _ufunc_reconstruct(module, name):
mod = __import__(module, fromlist=[name])
return getattr(mod, name)
+
def _ufunc_reduce(func):
# Report the `__name__`. pickle will try to find the module. Note that
# pickle supports for this `__name__` to be a `__qualname__`. It may
@@ -134,12 +135,31 @@ def _ufunc_reduce(func):
return func.__name__
+def _DType_reconstruct(scalar_type):
+ # This is a work-around to pickle type(np.dtype(np.float64)), etc.
+ # and it should eventually be replaced with a better solution, e.g. when
+ # DTypes become HeapTypes.
+ return type(dtype(scalar_type))
+
+
+def _DType_reduce(DType):
+ # To pickle a DType without having to add top-level names, pickle the
+ # scalar type for now (and assume that reconstruction will be possible).
+ if DType is dtype:
+ return "dtype" # must pickle `np.dtype` as a singleton.
+ scalar_type = DType.type # pickle the scalar type for reconstruction
+ return _DType_reconstruct, (scalar_type,)
+
+
import copyreg
copyreg.pickle(ufunc, _ufunc_reduce)
-# Unclutter namespace (must keep _ufunc_reconstruct for unpickling)
+copyreg.pickle(type(dtype), _DType_reduce, _DType_reconstruct)
+
+# Unclutter namespace (must keep _*_reconstruct for unpickling)
del copyreg
del _ufunc_reduce
+del _DType_reduce
from numpy._pytesttester import PytestTester
test = PytestTester(__name__)
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 03e0e172a..528486a05 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -1019,7 +1019,12 @@ class TestPickling:
def check_pickling(self, dtype):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- pickled = pickle.loads(pickle.dumps(dtype, proto))
+ buf = pickle.dumps(dtype, proto)
+ # The dtype pickling itself pickles `np.dtype` if it is pickled
+ # as a singleton `dtype` should be stored in the buffer:
+ assert b"_DType_reconstruct" not in buf
+ assert b"dtype" in buf
+ pickled = pickle.loads(buf)
assert_equal(pickled, dtype)
assert_equal(pickled.descr, dtype.descr)
if dtype.metadata is not None:
@@ -1075,6 +1080,15 @@ class TestPickling:
dt = np.dtype(int, metadata={'datum': 1})
self.check_pickling(dt)
+ @pytest.mark.parametrize("DType",
+ [type(np.dtype(t)) for t in np.typecodes['All']] +
+ [np.dtype(rational), np.dtype])
+ def test_pickle_types(self, DType):
+ # Check that DTypes (the classes/types) roundtrip when pickling
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ roundtrip_DType = pickle.loads(pickle.dumps(DType, proto))
+ assert roundtrip_DType is DType
+
def test_rational_dtype():
# test for bug gh-5719