diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2021-02-06 12:24:16 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-06 12:24:16 -0700 |
commit | cba30dbc4012f0968e25d4da3b48f3d9d745aa00 (patch) | |
tree | 4a8dcb99e6092709f3128d5cc20938b150317a06 | |
parent | 32b564ade7ef22439b5f2b9c11aa4c63f0ecd6fd (diff) | |
parent | b765975ca89457fea91a63b94075c9e66465e2ea (diff) | |
download | numpy-cba30dbc4012f0968e25d4da3b48f3d9d745aa00.tar.gz |
Merge pull request #18332 from seberg/issue-18325
BUG: Allow pickling all relevant DType types/classes
-rw-r--r-- | numpy/core/__init__.py | 22 | ||||
-rw-r--r-- | numpy/core/tests/test_dtype.py | 16 |
2 files changed, 36 insertions, 2 deletions
diff --git a/numpy/core/__init__.py b/numpy/core/__init__.py index f22c86f59..dad9293e1 100644 --- a/numpy/core/__init__.py +++ b/numpy/core/__init__.py @@ -125,6 +125,7 @@ def _ufunc_reconstruct(module, name): mod = __import__(module, fromlist=[name]) return getattr(mod, name) + def _ufunc_reduce(func): # Report the `__name__`. pickle will try to find the module. Note that # pickle supports for this `__name__` to be a `__qualname__`. It may @@ -134,12 +135,31 @@ def _ufunc_reduce(func): return func.__name__ +def _DType_reconstruct(scalar_type): + # This is a work-around to pickle type(np.dtype(np.float64)), etc. + # and it should eventually be replaced with a better solution, e.g. when + # DTypes become HeapTypes. + return type(dtype(scalar_type)) + + +def _DType_reduce(DType): + # To pickle a DType without having to add top-level names, pickle the + # scalar type for now (and assume that reconstruction will be possible). + if DType is dtype: + return "dtype" # must pickle `np.dtype` as a singleton. + scalar_type = DType.type # pickle the scalar type for reconstruction + return _DType_reconstruct, (scalar_type,) + + import copyreg copyreg.pickle(ufunc, _ufunc_reduce) -# Unclutter namespace (must keep _ufunc_reconstruct for unpickling) +copyreg.pickle(type(dtype), _DType_reduce, _DType_reconstruct) + +# Unclutter namespace (must keep _*_reconstruct for unpickling) del copyreg del _ufunc_reduce +del _DType_reduce from numpy._pytesttester import PytestTester test = PytestTester(__name__) diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index 03e0e172a..528486a05 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -1019,7 +1019,12 @@ class TestPickling: def check_pickling(self, dtype): for proto in range(pickle.HIGHEST_PROTOCOL + 1): - pickled = pickle.loads(pickle.dumps(dtype, proto)) + buf = pickle.dumps(dtype, proto) + # The dtype pickling itself pickles `np.dtype` if it is pickled + # as a singleton `dtype` should be stored in the buffer: + assert b"_DType_reconstruct" not in buf + assert b"dtype" in buf + pickled = pickle.loads(buf) assert_equal(pickled, dtype) assert_equal(pickled.descr, dtype.descr) if dtype.metadata is not None: @@ -1075,6 +1080,15 @@ class TestPickling: dt = np.dtype(int, metadata={'datum': 1}) self.check_pickling(dt) + @pytest.mark.parametrize("DType", + [type(np.dtype(t)) for t in np.typecodes['All']] + + [np.dtype(rational), np.dtype]) + def test_pickle_types(self, DType): + # Check that DTypes (the classes/types) roundtrip when pickling + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + roundtrip_DType = pickle.loads(pickle.dumps(DType, proto)) + assert roundtrip_DType is DType + def test_rational_dtype(): # test for bug gh-5719 |