Add Packer.pack_pairs.

This commit is contained in:
INADA Naoki 2012-12-10 21:26:41 +09:00
parent 0c7ab7c344
commit 537a2ab3f2
2 changed files with 44 additions and 6 deletions

View file

@ -186,18 +186,49 @@ cdef class Packer(object):
self.pk.length = 0 self.pk.length = 0
return buf return buf
cpdef pack_array_header(self, size_t size): def pack_array_header(self, size_t size):
msgpack_pack_array(&self.pk, size) cdef int ret = msgpack_pack_array(&self.pk, size)
if ret == -1:
raise MemoryError
elif ret: # should not happen
raise TypeError
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length) buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
self.pk.length = 0 self.pk.length = 0
return buf return buf
cpdef pack_map_header(self, size_t size): def pack_map_header(self, size_t size):
msgpack_pack_map(&self.pk, size) cdef int ret = msgpack_pack_map(&self.pk, size)
if ret == -1:
raise MemoryError
elif ret: # should not happen
raise TypeError
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length) buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
self.pk.length = 0 self.pk.length = 0
return buf return buf
def pack_map_pairs(self, object pairs):
"""
Pack *pairs* as msgpack map type.
*pairs* should sequence of pair.
(`len(pairs)` and `for k, v in *pairs*:` should be supported.)
"""
cdef int ret = msgpack_pack_map(&self.pk, len(pairs))
if ret == 0:
for k, v in pairs:
ret = self._pack(k)
if ret != 0: break
ret = self._pack(v)
if ret != 0: break
if ret == -1:
raise MemoryError
elif ret: # should not happen
raise TypeError
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
self.pk.length = 0
return buf
def pack(object o, object stream, default=None, encoding='utf-8', unicode_errors='strict'): def pack(object o, object stream, default=None, encoding='utf-8', unicode_errors='strict'):
""" """
pack an object `o` and write it to stream).""" pack an object `o` and write it to stream)."""

View file

@ -118,8 +118,6 @@ def testMapSize(sizes=[0, 5, 50, 1000]):
assert unpacker.unpack() == dict((i, i * 2) for i in range(size)) assert unpacker.unpack() == dict((i, i * 2) for i in range(size))
class odict(dict): class odict(dict):
'''Reimplement OrderedDict to run test on Python 2.6''' '''Reimplement OrderedDict to run test on Python 2.6'''
def __init__(self, seq): def __init__(self, seq):
@ -144,5 +142,14 @@ def test_odict():
assert_equal(unpackb(packb(od), object_pairs_hook=pair_hook, use_list=1), seq) assert_equal(unpackb(packb(od), object_pairs_hook=pair_hook, use_list=1), seq)
def test_pairlist():
pairlist = [(b'a', 1), (2, b'b'), (b'foo', b'bar')]
packer = Packer()
packed = packer.pack_map_pairs(pairlist)
unpacked = unpackb(packed, object_pairs_hook=list)
assert pairlist == unpacked
if __name__ == '__main__': if __name__ == '__main__':
main() main()