Add strict_map_key option to unpacker

This commit is contained in:
Inada Naoki 2018-11-29 22:29:38 +09:00
parent 3c9c6edbc8
commit e9086a34e4
3 changed files with 28 additions and 5 deletions

View file

@ -27,6 +27,7 @@ cdef extern from "unpack.h":
bint use_list bint use_list
bint raw bint raw
bint has_pairs_hook # call object_hook with k-v pairs bint has_pairs_hook # call object_hook with k-v pairs
bint strict_map_key
PyObject* object_hook PyObject* object_hook
PyObject* list_hook PyObject* list_hook
PyObject* ext_hook PyObject* ext_hook
@ -56,7 +57,7 @@ cdef extern from "unpack.h":
cdef inline init_ctx(unpack_context *ctx, cdef inline init_ctx(unpack_context *ctx,
object object_hook, object object_pairs_hook, object object_hook, object object_pairs_hook,
object list_hook, object ext_hook, object list_hook, object ext_hook,
bint use_list, bint raw, bint use_list, bint raw, bint strict_map_key,
const char* encoding, const char* unicode_errors, const char* encoding, const char* unicode_errors,
Py_ssize_t max_str_len, Py_ssize_t max_bin_len, Py_ssize_t max_str_len, Py_ssize_t max_bin_len,
Py_ssize_t max_array_len, Py_ssize_t max_map_len, Py_ssize_t max_array_len, Py_ssize_t max_map_len,
@ -64,6 +65,7 @@ cdef inline init_ctx(unpack_context *ctx,
unpack_init(ctx) unpack_init(ctx)
ctx.user.use_list = use_list ctx.user.use_list = use_list
ctx.user.raw = raw ctx.user.raw = raw
ctx.user.strict_map_key = strict_map_key
ctx.user.object_hook = ctx.user.list_hook = <PyObject*>NULL ctx.user.object_hook = ctx.user.list_hook = <PyObject*>NULL
ctx.user.max_str_len = max_str_len ctx.user.max_str_len = max_str_len
ctx.user.max_bin_len = max_bin_len ctx.user.max_bin_len = max_bin_len
@ -140,7 +142,7 @@ cdef inline int get_data_from_buffer(object obj,
return 1 return 1
def unpackb(object packed, object object_hook=None, object list_hook=None, def unpackb(object packed, object object_hook=None, object list_hook=None,
bint use_list=True, bint raw=True, bint use_list=True, bint raw=True, bint strict_map_key=False,
encoding=None, unicode_errors=None, encoding=None, unicode_errors=None,
object_pairs_hook=None, ext_hook=ExtType, object_pairs_hook=None, ext_hook=ExtType,
Py_ssize_t max_str_len=1024*1024, Py_ssize_t max_str_len=1024*1024,
@ -180,7 +182,7 @@ def unpackb(object packed, object object_hook=None, object list_hook=None,
get_data_from_buffer(packed, &view, &buf, &buf_len, &new_protocol) get_data_from_buffer(packed, &view, &buf, &buf_len, &new_protocol)
try: try:
init_ctx(&ctx, object_hook, object_pairs_hook, list_hook, ext_hook, init_ctx(&ctx, object_hook, object_pairs_hook, list_hook, ext_hook,
use_list, raw, cenc, cerr, use_list, raw, strict_map_key, cenc, cerr,
max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len) max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len)
ret = unpack_construct(&ctx, buf, buf_len, &off) ret = unpack_construct(&ctx, buf, buf_len, &off)
finally: finally:
@ -236,6 +238,11 @@ cdef class Unpacker(object):
*encoding* option which is deprecated overrides this option. *encoding* option which is deprecated overrides this option.
:param bool strict_map_key:
If true, only str or bytes are accepted for map (dict) keys.
It's False by default for backward compatibility,
but the default will become True in msgpack 1.0.
:param callable object_hook: :param callable object_hook:
When specified, it should be callable. When specified, it should be callable.
Unpacker calls it with a dict argument after unpacking msgpack map. Unpacker calls it with a dict argument after unpacking msgpack map.
@ -318,7 +325,7 @@ cdef class Unpacker(object):
self.buf = NULL self.buf = NULL
def __init__(self, file_like=None, Py_ssize_t read_size=0, def __init__(self, file_like=None, Py_ssize_t read_size=0,
bint use_list=True, bint raw=True, bint use_list=True, bint raw=True, bint strict_map_key=False,
object object_hook=None, object object_pairs_hook=None, object list_hook=None, object object_hook=None, object object_pairs_hook=None, object list_hook=None,
encoding=None, unicode_errors=None, Py_ssize_t max_buffer_size=0, encoding=None, unicode_errors=None, Py_ssize_t max_buffer_size=0,
object ext_hook=ExtType, object ext_hook=ExtType,
@ -366,7 +373,7 @@ cdef class Unpacker(object):
cerr = unicode_errors cerr = unicode_errors
init_ctx(&self.ctx, object_hook, object_pairs_hook, list_hook, init_ctx(&self.ctx, object_hook, object_pairs_hook, list_hook,
ext_hook, use_list, raw, cenc, cerr, ext_hook, use_list, raw, strict_map_key, cenc, cerr,
max_str_len, max_bin_len, max_array_len, max_str_len, max_bin_len, max_array_len,
max_map_len, max_ext_len) max_map_len, max_ext_len)

View file

@ -23,6 +23,7 @@ typedef struct unpack_user {
bool use_list; bool use_list;
bool raw; bool raw;
bool has_pairs_hook; bool has_pairs_hook;
bool strict_map_key;
PyObject *object_hook; PyObject *object_hook;
PyObject *list_hook; PyObject *list_hook;
PyObject *ext_hook; PyObject *ext_hook;
@ -188,6 +189,10 @@ static inline int unpack_callback_map(unpack_user* u, unsigned int n, msgpack_un
static inline int unpack_callback_map_item(unpack_user* u, unsigned int current, msgpack_unpack_object* c, msgpack_unpack_object k, msgpack_unpack_object v) static inline int unpack_callback_map_item(unpack_user* u, unsigned int current, msgpack_unpack_object* c, msgpack_unpack_object k, msgpack_unpack_object v)
{ {
if (u->strict_map_key && !PyUnicode_CheckExact(k) && !PyBytes_CheckExact(k)) {
PyErr_Format(PyExc_ValueError, "%.100s is not allowed for map key", Py_TYPE(k)->tp_name);
return -1;
}
if (u->has_pairs_hook) { if (u->has_pairs_hook) {
msgpack_unpack_object item = PyTuple_Pack(2, k, v); msgpack_unpack_object item = PyTuple_Pack(2, k, v);
if (!item) if (!item)

View file

@ -50,3 +50,14 @@ def test_invalidvalue():
with raises(StackError): with raises(StackError):
unpackb(b"\x91" * 3000) # nested fixarray(len=1) unpackb(b"\x91" * 3000) # nested fixarray(len=1)
def test_strict_map_key():
    """Check that strict_map_key=True accepts str/bytes keys and rejects others."""
    valid = {u"unicode": 1, b"bytes": 2}
    packed = packb(valid, use_bin_type=True)
    # strict_map_key defaults to False, so it must be enabled explicitly for
    # the key check to run. raw=False keeps the str key as unicode so the
    # round-tripped dict compares equal to the original.
    assert valid == unpackb(packed, raw=False, strict_map_key=True)

    invalid = {42: 1}
    packed = packb(invalid, use_bin_type=True)
    # An integer key is neither str nor bytes, so strict mode must reject it.
    with raises(ValueError):
        unpackb(packed, raw=False, strict_map_key=True)