summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorINADA Naoki <songofacandy@gmail.com>2010-10-26 01:49:00 +0900
committerINADA Naoki <songofacandy@gmail.com>2010-10-26 01:49:00 +0900
commit0076d42a0ddeb3601f1a0c806c7874da2110a986 (patch)
treed6fe47f2273a8c527796fe81c67f8610e71a72b2
parent3980d381f7ad44b14483dc90c0f6f0a36804c290 (diff)
downloadmsgpack-python-0076d42a0ddeb3601f1a0c806c7874da2110a986.tar.gz
Add check for recursion limit and default hook result.
-rw-r--r--msgpack/_msgpack.pyx24
-rw-r--r--test/test_obj.py8
2 files changed, 17 insertions, 15 deletions
diff --git a/msgpack/_msgpack.pyx b/msgpack/_msgpack.pyx
index 24e4f8b..0abdd51 100644
--- a/msgpack/_msgpack.pyx
+++ b/msgpack/_msgpack.pyx
@@ -80,7 +80,7 @@ cdef class Packer(object):
def __dealloc__(self):
free(self.pk.buf);
- cdef int _pack(self, object o) except -1:
+ cdef int _pack(self, object o, int nest_limit=511, default=None) except -1:
cdef long long llval
cdef unsigned long long ullval
cdef long longval
@@ -89,6 +89,9 @@ cdef class Packer(object):
cdef int ret
cdef dict d
+ if nest_limit < 0:
+ raise ValueError("Too deep.")
+
if o is None:
ret = msgpack_pack_nil(&self.pk)
#elif PyBool_Check(o):
@@ -126,33 +129,26 @@ cdef class Packer(object):
ret = msgpack_pack_map(&self.pk, len(d))
if ret == 0:
for k,v in d.items():
- ret = self._pack(k)
+ ret = self._pack(k, nest_limit-1, default)
if ret != 0: break
- ret = self._pack(v)
+ ret = self._pack(v, nest_limit-1, default)
if ret != 0: break
elif PySequence_Check(o):
ret = msgpack_pack_array(&self.pk, len(o))
if ret == 0:
for v in o:
- ret = self._pack(v)
+ ret = self._pack(v, nest_limit-1, default)
if ret != 0: break
- elif self.default is not None:
+ elif default is not None:
o = self.default(o)
- d = o
- ret = msgpack_pack_map(&self.pk, len(d))
- if ret == 0:
- for k,v in d.items():
- ret = self._pack(k)
- if ret != 0: break
- ret = self._pack(v)
- if ret != 0: break
+ ret = self._pack(o, nest_limit)
else:
raise TypeError("can't serialize %r" % (o,))
return ret
def pack(self, object obj):
cdef int ret
- ret = self._pack(obj)
+ ret = self._pack(obj, self.default)
if ret:
raise TypeError
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
diff --git a/test/test_obj.py b/test/test_obj.py
index 64a6390..28edacb 100644
--- a/test/test_obj.py
+++ b/test/test_obj.py
@@ -26,6 +26,12 @@ def test_decode_hook():
unpacked = unpacks(packed, object_hook=_decode_complex)
eq_(unpacked[1], 1+2j)
+@raises(TypeError)
+def test_bad_hook():
+ packed = packs([3, 1+2j], default=lambda o: o)
+ unpacked = unpacks(packed)
+
if __name__ == '__main__':
- #main()
test_decode_hook()
+ test_encode_hook()
+ test_bad_hook()