summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorINADA Naoki <songofacandy@gmail.com>2010-10-26 01:26:06 +0900
committerINADA Naoki <songofacandy@gmail.com>2010-10-26 01:26:06 +0900
commitfa157082ac8db71e3312ca97fe1ceb7f56546fcb (patch)
treeee9cfaa2f1470e2c6229cb5bf7bfcfaf7404f96f
parent367f15c247bf37beb37e2a81d1d707bc8ef2e085 (diff)
downloadmsgpack-python-fa157082ac8db71e3312ca97fe1ceb7f56546fcb.tar.gz
Add `object_hook` option to unpack and `default` option to pack.
(see simplejson for how to use).
-rw-r--r--msgpack/_msgpack.pyx56
-rw-r--r--msgpack/unpack.h14
-rw-r--r--msgpack/unpack_template.h1
-rw-r--r--test/test_obj.py31
4 files changed, 92 insertions, 10 deletions
diff --git a/msgpack/_msgpack.pyx b/msgpack/_msgpack.pyx
index 66869c8..fb7f0c1 100644
--- a/msgpack/_msgpack.pyx
+++ b/msgpack/_msgpack.pyx
@@ -20,6 +20,9 @@ cdef extern from "Python.h":
cdef bint PyFloat_Check(object o)
cdef bint PyBytes_Check(object o)
cdef bint PyUnicode_Check(object o)
+ cdef bint PyCallable_Check(object o)
+ cdef void Py_INCREF(object o)
+ cdef void Py_DECREF(object o)
cdef extern from "stdlib.h":
void* malloc(size_t)
@@ -60,6 +63,7 @@ cdef class Packer(object):
astream.write(packer.pack(b))
"""
cdef msgpack_packer pk
+ cdef object default
def __cinit__(self):
cdef int buf_size = 1024*1024
@@ -67,6 +71,12 @@ cdef class Packer(object):
self.pk.buf_size = buf_size
self.pk.length = 0
+ def __init__(self, default=None):
+ if default is not None:
+ if not PyCallable_Check(default):
+ raise TypeError("default must be a callable.")
+ self.default = default
+
def __dealloc__(self):
free(self.pk.buf);
@@ -126,9 +136,18 @@ cdef class Packer(object):
for v in o:
ret = self._pack(v)
if ret != 0: break
+ elif self.default is not None:
+ o = self.default(o)
+ d = o
+ ret = msgpack_pack_map(&self.pk, len(d))
+ if ret == 0:
+ for k,v in d.items():
+ ret = self._pack(k)
+ if ret != 0: break
+ ret = self._pack(v)
+ if ret != 0: break
else:
- # TODO: Serialize with defalt() like simplejson.
- raise TypeError, "can't serialize %r" % (o,)
+ raise TypeError("can't serialize %r" % (o,))
return ret
def pack(self, object obj):
@@ -141,14 +160,14 @@ cdef class Packer(object):
return buf
-def pack(object o, object stream):
+def pack(object o, object stream, default=None):
"""pack an object `o` and write it to stream)."""
- packer = Packer()
+ packer = Packer(default)
stream.write(packer.pack(o))
-def packb(object o):
+def packb(object o, default=None):
"""pack o and return packed bytes."""
- packer = Packer()
+ packer = Packer(default=default)
return packer.pack(o)
packs = packb
@@ -156,6 +175,7 @@ packs = packb
cdef extern from "unpack.h":
ctypedef struct msgpack_user:
int use_list
+ PyObject* object_hook
ctypedef struct template_context:
msgpack_user user
@@ -170,7 +190,7 @@ cdef extern from "unpack.h":
object template_data(template_context* ctx)
-def unpackb(bytes packed_bytes):
+def unpackb(bytes packed_bytes, object object_hook=None):
"""Unpack packed_bytes to object. Returns an unpacked object."""
cdef const_char_ptr p = packed_bytes
cdef template_context ctx
@@ -178,7 +198,16 @@ def unpackb(bytes packed_bytes):
cdef int ret
template_init(&ctx)
ctx.user.use_list = 0
+ ctx.user.object_hook = NULL
+ if object_hook is not None:
+ if not PyCallable_Check(object_hook):
+ raise TypeError("object_hook must be a callable.")
+ Py_INCREF(object_hook)
+ ctx.user.object_hook = <PyObject*>object_hook
ret = template_execute(&ctx, p, len(packed_bytes), &off)
+ if object_hook is not None:
+ pass
+ #Py_DECREF(object_hook)
if ret == 1:
return template_data(&ctx)
else:
@@ -186,10 +215,10 @@ def unpackb(bytes packed_bytes):
unpacks = unpackb
-def unpack(object stream):
+def unpack(object stream, object object_hook=None):
"""unpack an object from stream."""
packed = stream.read()
- return unpackb(packed)
+ return unpackb(packed, object_hook=object_hook)
cdef class UnpackIterator(object):
cdef object unpacker
@@ -234,6 +263,7 @@ cdef class Unpacker(object):
cdef int read_size
cdef object waiting_bytes
cdef bint use_list
+ cdef object object_hook
def __cinit__(self):
self.buf = NULL
@@ -242,7 +272,8 @@ cdef class Unpacker(object):
if self.buf:
free(self.buf);
- def __init__(self, file_like=None, int read_size=0, bint use_list=0):
+ def __init__(self, file_like=None, int read_size=0, bint use_list=0,
+ object object_hook=None):
if read_size == 0:
read_size = 1024*1024
self.use_list = use_list
@@ -255,6 +286,11 @@ cdef class Unpacker(object):
self.buf_tail = 0
template_init(&self.ctx)
self.ctx.user.use_list = use_list
+ self.ctx.user.object_hook = <PyObject*>NULL
+ if object_hook is not None:
+ if not PyCallable_Check(object_hook):
+ raise TypeError("object_hook must be a callable.")
+ self.ctx.user.object_hook = <PyObject*>object_hook
def feed(self, bytes next_bytes):
self.waiting_bytes.append(next_bytes)
diff --git a/msgpack/unpack.h b/msgpack/unpack.h
index 9eb8ce7..e4c03bd 100644
--- a/msgpack/unpack.h
+++ b/msgpack/unpack.h
@@ -21,6 +21,7 @@
typedef struct unpack_user {
int use_list;
+ PyObject *object_hook;
} unpack_user;
@@ -172,6 +173,19 @@ static inline int template_callback_map_item(unpack_user* u, msgpack_unpack_obje
return -1;
}
+//static inline int template_callback_map_end(unpack_user* u, msgpack_unpack_object* c)
+int template_callback_map_end(unpack_user* u, msgpack_unpack_object* c)
+{
+ if (u->object_hook) {
+ PyObject *arglist = Py_BuildValue("(O)", *c);
+ Py_INCREF(*c);
+ *c = PyEval_CallObject(u->object_hook, arglist);
+ Py_DECREF(arglist);
+ return 0;
+ }
+ return -1;
+}
+
static inline int template_callback_raw(unpack_user* u, const char* b, const char* p, unsigned int l, msgpack_unpack_object* o)
{
PyObject *py;
diff --git a/msgpack/unpack_template.h b/msgpack/unpack_template.h
index ca6e1f3..1fdedd7 100644
--- a/msgpack/unpack_template.h
+++ b/msgpack/unpack_template.h
@@ -317,6 +317,7 @@ _push:
case CT_MAP_VALUE:
if(msgpack_unpack_callback(_map_item)(user, &c->obj, c->map_key, obj) < 0) { goto _failed; }
if(--c->count == 0) {
+ msgpack_unpack_callback(_map_end)(user, &c->obj);
obj = c->obj;
--top;
/*printf("stack pop %d\n", top);*/
diff --git a/test/test_obj.py b/test/test_obj.py
new file mode 100644
index 0000000..64a6390
--- /dev/null
+++ b/test/test_obj.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+from nose import main
+from nose.tools import *
+
+from msgpack import packs, unpacks
+
+def _decode_complex(obj):
+ if '__complex__' in obj:
+ return complex(obj['real'], obj['imag'])
+ return obj
+
+def _encode_complex(obj):
+ if isinstance(obj, complex):
+ return {'__complex__': True, 'real': 1, 'imag': 2}
+ return obj
+
+def test_encode_hook():
+ packed = packs([3, 1+2j], default=_encode_complex)
+ unpacked = unpacks(packed)
+ eq_(unpacked[1], {'__complex__': True, 'real': 1, 'imag': 2})
+
+def test_decode_hook():
+ packed = packs([3, {'__complex__': True, 'real': 1, 'imag': 2}])
+ unpacked = unpacks(packed, object_hook=_decode_complex)
+ eq_(unpacked[1], 1+2j)
+
+if __name__ == '__main__':
+ #main()
+ test_decode_hook()