diff options
-rw-r--r-- | msgpack/_packer.pyx | 31 | ||||
-rw-r--r-- | msgpack/fallback.py | 44 | ||||
-rw-r--r-- | test/test_stricttype.py | 15 |
3 files changed, 69 insertions, 21 deletions
diff --git a/msgpack/_packer.pyx b/msgpack/_packer.pyx index 6392655..c3ef1a4 100644 --- a/msgpack/_packer.pyx +++ b/msgpack/_packer.pyx @@ -63,6 +63,13 @@ cdef class Packer(object): :param bool use_bin_type: Use bin type introduced in msgpack spec 2.0 for bytes. It also enable str8 type for unicode. + :param bool strict_types: + If set to true, types will be checked to be exact. Derived classes + from serializable types will not be serialized and will be + treated as an unsupported type and forwarded to default. + Additionally tuples will not be serialized as lists. + This is useful when trying to implement accurate serialization + for python types. """ cdef msgpack_packer pk cdef object _default @@ -70,6 +77,7 @@ cdef class Packer(object): cdef object _berrors cdef char *encoding cdef char *unicode_errors + cdef bint strict_types cdef bool use_float cdef bint autoreset @@ -82,10 +90,12 @@ cdef class Packer(object): self.pk.length = 0 def __init__(self, default=None, encoding='utf-8', unicode_errors='strict', - use_single_float=False, bint autoreset=1, bint use_bin_type=0): + use_single_float=False, bint autoreset=1, bint use_bin_type=0, + bint strict_types=0): """ """ self.use_float = use_single_float + self.strict_types = strict_types self.autoreset = autoreset self.pk.use_bin_type = use_bin_type if default is not None: @@ -121,6 +131,7 @@ cdef class Packer(object): cdef dict d cdef size_t L cdef int default_used = 0 + cdef bint strict_types = self.strict_types if nest_limit < 0: raise PackValueError("recursion limit exceeded.") @@ -128,12 +139,12 @@ cdef class Packer(object): while True: if o is None: ret = msgpack_pack_nil(&self.pk) - elif isinstance(o, bool): + elif PyBool_Check(o) if strict_types else isinstance(o, bool): if o: ret = msgpack_pack_true(&self.pk) else: ret = msgpack_pack_false(&self.pk) - elif PyLong_Check(o): + elif PyLong_CheckExact(o) if strict_types else PyLong_Check(o): # PyInt_Check(long) is True for Python 3. 
# So we should test long before int. try: @@ -150,17 +161,17 @@ cdef class Packer(object): continue else: raise - elif PyInt_Check(o): + elif PyInt_CheckExact(o) if strict_types else PyInt_Check(o): longval = o ret = msgpack_pack_long(&self.pk, longval) - elif PyFloat_Check(o): + elif PyFloat_CheckExact(o) if strict_types else PyFloat_Check(o): if self.use_float: fval = o ret = msgpack_pack_float(&self.pk, fval) else: dval = o ret = msgpack_pack_double(&self.pk, dval) - elif PyBytes_Check(o): + elif PyBytes_CheckExact(o) if strict_types else PyBytes_Check(o): L = len(o) if L > (2**32)-1: raise ValueError("bytes is too large") @@ -168,7 +179,7 @@ cdef class Packer(object): ret = msgpack_pack_bin(&self.pk, L) if ret == 0: ret = msgpack_pack_raw_body(&self.pk, rawval, L) - elif PyUnicode_Check(o): + elif PyUnicode_CheckExact(o) if strict_types else PyUnicode_Check(o): if not self.encoding: raise TypeError("Can't encode unicode string: no encoding is specified") o = PyUnicode_AsEncodedString(o, self.encoding, self.unicode_errors) @@ -191,7 +202,7 @@ cdef class Packer(object): if ret != 0: break ret = self._pack(v, nest_limit-1) if ret != 0: break - elif PyDict_Check(o): + elif not strict_types and PyDict_Check(o): L = len(o) if L > (2**32)-1: raise ValueError("dict is too large") @@ -202,7 +213,7 @@ cdef class Packer(object): if ret != 0: break ret = self._pack(v, nest_limit-1) if ret != 0: break - elif isinstance(o, ExtType): + elif type(o) is ExtType if strict_types else isinstance(o, ExtType): # This should be before Tuple because ExtType is namedtuple. 
longval = o.code rawval = o.data @@ -211,7 +222,7 @@ cdef class Packer(object): raise ValueError("EXT data is too large") ret = msgpack_pack_ext(&self.pk, longval, L) ret = msgpack_pack_raw_body(&self.pk, rawval, L) - elif PyTuple_Check(o) or PyList_Check(o): + elif PyList_CheckExact(o) if strict_types else (PyTuple_Check(o) or PyList_Check(o)): L = len(o) if L > (2**32)-1: raise ValueError("list is too large") diff --git a/msgpack/fallback.py b/msgpack/fallback.py index f682611..40c54a8 100644 --- a/msgpack/fallback.py +++ b/msgpack/fallback.py @@ -69,6 +69,13 @@ TYPE_EXT = 5 DEFAULT_RECURSE_LIMIT = 511 +def _check_type_strict(obj, t, type=type, tuple=tuple): + if type(t) is tuple: + return type(obj) in t + else: + return type(obj) is t + + def unpack(stream, **kwargs): """ Unpack an object from `stream`. @@ -609,9 +616,18 @@ class Packer(object): :param bool use_bin_type: Use bin type introduced in msgpack spec 2.0 for bytes. It also enable str8 type for unicode. + :param bool strict_types: + If set to true, types will be checked to be exact. Derived classes + from serializable types will not be serialized and will be + treated as an unsupported type and forwarded to default. + Additionally tuples will not be serialized as lists. + This is useful when trying to implement accurate serialization + for python types. 
""" def __init__(self, default=None, encoding='utf-8', unicode_errors='strict', - use_single_float=False, autoreset=True, use_bin_type=False): + use_single_float=False, autoreset=True, use_bin_type=False, + strict_types=False): + self._strict_types = strict_types self._use_float = use_single_float self._autoreset = autoreset self._use_bin_type = use_bin_type @@ -623,18 +639,24 @@ class Packer(object): raise TypeError("default must be callable") self._default = default - def _pack(self, obj, nest_limit=DEFAULT_RECURSE_LIMIT, isinstance=isinstance): + def _pack(self, obj, nest_limit=DEFAULT_RECURSE_LIMIT, + check=isinstance, check_type_strict=_check_type_strict): default_used = False + if self._strict_types: + check = check_type_strict + list_types = list + else: + list_types = (list, tuple) while True: if nest_limit < 0: raise PackValueError("recursion limit exceeded") if obj is None: return self._buffer.write(b"\xc0") - if isinstance(obj, bool): + if check(obj, bool): if obj: return self._buffer.write(b"\xc3") return self._buffer.write(b"\xc2") - if isinstance(obj, int_types): + if check(obj, int_types): if 0 <= obj < 0x80: return self._buffer.write(struct.pack("B", obj)) if -0x20 <= obj < 0: @@ -660,7 +682,7 @@ class Packer(object): default_used = True continue raise PackValueError("Integer value out of range") - if self._use_bin_type and isinstance(obj, bytes): + if self._use_bin_type and check(obj, bytes): n = len(obj) if n <= 0xff: self._buffer.write(struct.pack('>BB', 0xc4, n)) @@ -671,8 +693,8 @@ class Packer(object): else: raise PackValueError("Bytes is too large") return self._buffer.write(obj) - if isinstance(obj, (Unicode, bytes)): - if isinstance(obj, Unicode): + if check(obj, (Unicode, bytes)): + if check(obj, Unicode): if self._encoding is None: raise TypeError( "Can't encode unicode string: " @@ -690,11 +712,11 @@ class Packer(object): else: raise PackValueError("String is too large") return self._buffer.write(obj) - if isinstance(obj, float): + if 
check(obj, float): if self._use_float: return self._buffer.write(struct.pack(">Bf", 0xca, obj)) return self._buffer.write(struct.pack(">Bd", 0xcb, obj)) - if isinstance(obj, ExtType): + if check(obj, ExtType): code = obj.code data = obj.data assert isinstance(code, int) @@ -719,13 +741,13 @@ class Packer(object): self._buffer.write(struct.pack("b", code)) self._buffer.write(data) return - if isinstance(obj, (list, tuple)): + if check(obj, list_types): n = len(obj) self._fb_pack_array_header(n) for i in xrange(n): self._pack(obj[i], nest_limit - 1) return - if isinstance(obj, dict): + if check(obj, dict): return self._fb_pack_map_pairs(len(obj), dict_iteritems(obj), nest_limit - 1) if not default_used and self._default is not None: diff --git a/test/test_stricttype.py b/test/test_stricttype.py new file mode 100644 index 0000000..a20b5eb --- /dev/null +++ b/test/test_stricttype.py @@ -0,0 +1,15 @@ +# coding: utf-8 + +from collections import namedtuple +from msgpack import packb, unpackb + + +def test_namedtuple(): + T = namedtuple('T', "foo bar") + def default(o): + if isinstance(o, T): + return dict(o._asdict()) + raise TypeError('Unsupported type %s' % (type(o),)) + packed = packb(T(1, 42), strict_types=True, use_bin_type=True, default=default) + unpacked = unpackb(packed, encoding='utf-8') + assert unpacked == {'foo': 1, 'bar': 42} |