diff options
author | INADA Naoki <songofacandy@gmail.com> | 2010-10-26 01:49:00 +0900 |
---|---|---|
committer | INADA Naoki <songofacandy@gmail.com> | 2010-10-26 01:49:00 +0900 |
commit | 0076d42a0ddeb3601f1a0c806c7874da2110a986 (patch) | |
tree | d6fe47f2273a8c527796fe81c67f8610e71a72b2 | |
parent | 3980d381f7ad44b14483dc90c0f6f0a36804c290 (diff) | |
download | msgpack-python-0076d42a0ddeb3601f1a0c806c7874da2110a986.tar.gz |
Add check for recursion limit and default hook result.
-rw-r--r-- | msgpack/_msgpack.pyx | 24 | ||||
-rw-r--r-- | test/test_obj.py | 8 |
2 files changed, 17 insertions, 15 deletions
diff --git a/msgpack/_msgpack.pyx b/msgpack/_msgpack.pyx index 24e4f8b..0abdd51 100644 --- a/msgpack/_msgpack.pyx +++ b/msgpack/_msgpack.pyx @@ -80,7 +80,7 @@ cdef class Packer(object): def __dealloc__(self): free(self.pk.buf); - cdef int _pack(self, object o) except -1: + cdef int _pack(self, object o, int nest_limit=511, default=None) except -1: cdef long long llval cdef unsigned long long ullval cdef long longval @@ -89,6 +89,9 @@ cdef class Packer(object): cdef int ret cdef dict d + if nest_limit < 0: + raise ValueError("Too deep.") + if o is None: ret = msgpack_pack_nil(&self.pk) #elif PyBool_Check(o): @@ -126,33 +129,26 @@ cdef class Packer(object): ret = msgpack_pack_map(&self.pk, len(d)) if ret == 0: for k,v in d.items(): - ret = self._pack(k) + ret = self._pack(k, nest_limit-1, default) if ret != 0: break - ret = self._pack(v) + ret = self._pack(v, nest_limit-1, default) if ret != 0: break elif PySequence_Check(o): ret = msgpack_pack_array(&self.pk, len(o)) if ret == 0: for v in o: - ret = self._pack(v) + ret = self._pack(v, nest_limit-1, default) if ret != 0: break - elif self.default is not None: + elif default is not None: o = self.default(o) - d = o - ret = msgpack_pack_map(&self.pk, len(d)) - if ret == 0: - for k,v in d.items(): - ret = self._pack(k) - if ret != 0: break - ret = self._pack(v) - if ret != 0: break + ret = self._pack(o, nest_limit) else: raise TypeError("can't serialize %r" % (o,)) return ret def pack(self, object obj): cdef int ret - ret = self._pack(obj) + ret = self._pack(obj, self.default) if ret: raise TypeError buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length) diff --git a/test/test_obj.py b/test/test_obj.py index 64a6390..28edacb 100644 --- a/test/test_obj.py +++ b/test/test_obj.py @@ -26,6 +26,12 @@ def test_decode_hook(): unpacked = unpacks(packed, object_hook=_decode_complex) eq_(unpacked[1], 1+2j) +@raises(TypeError) +def test_bad_hook(): + packed = packs([3, 1+2j], default=lambda o: o) + unpacked = unpacks(packed) + if __name__ == '__main__': - #main() test_decode_hook() + test_encode_hook() + test_bad_hook() |