better error checks (#607)

* check buffer exports
* add error messages
This commit is contained in:
Inada Naoki 2024-05-06 03:33:48 +09:00 committed by GitHub
parent e1068087e0
commit e0f0e145f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 57 additions and 15 deletions

View file

@ -106,6 +106,7 @@ cdef class Packer:
cdef object _default
cdef object _berrors
cdef const char *unicode_errors
cdef size_t exports # number of exported buffers
cdef bint strict_types
cdef bint use_float
cdef bint autoreset
@ -117,10 +118,16 @@ cdef class Packer:
raise MemoryError("Unable to allocate internal buffer.")
self.pk.buf_size = buf_size
self.pk.length = 0
self.exports = 0
def __dealloc__(self):
PyMem_Free(self.pk.buf)
self.pk.buf = NULL
assert self.exports == 0
cdef _check_exports(self):
if self.exports > 0:
raise BufferError("Existing exports of data: Packer cannot be changed")
def __init__(self, *, default=None,
bint use_single_float=False, bint autoreset=True, bint use_bin_type=True,
@ -149,8 +156,8 @@ cdef class Packer:
cdef unsigned long ulval
cdef const char* rawval
cdef Py_ssize_t L
cdef bool strict_types = self.strict_types
cdef Py_buffer view
cdef bint strict = self.strict_types
if o is None:
msgpack_pack_nil(&self.pk)
@ -158,7 +165,7 @@ cdef class Packer:
msgpack_pack_true(&self.pk)
elif o is False:
msgpack_pack_false(&self.pk)
elif PyLong_CheckExact(o) if strict_types else PyLong_Check(o):
elif PyLong_CheckExact(o) if strict else PyLong_Check(o):
try:
if o > 0:
ullval = o
@ -171,19 +178,19 @@ cdef class Packer:
return -2
else:
raise OverflowError("Integer value out of range")
elif PyFloat_CheckExact(o) if strict_types else PyFloat_Check(o):
elif PyFloat_CheckExact(o) if strict else PyFloat_Check(o):
if self.use_float:
msgpack_pack_float(&self.pk, <float>o)
else:
msgpack_pack_double(&self.pk, <double>o)
elif PyBytesLike_CheckExact(o) if strict_types else PyBytesLike_Check(o):
elif PyBytesLike_CheckExact(o) if strict else PyBytesLike_Check(o):
L = Py_SIZE(o)
if L > ITEM_LIMIT:
PyErr_Format(ValueError, b"%.200s object is too large", Py_TYPE(o).tp_name)
rawval = o
msgpack_pack_bin(&self.pk, L)
msgpack_pack_raw_body(&self.pk, rawval, L)
elif PyUnicode_CheckExact(o) if strict_types else PyUnicode_Check(o):
elif PyUnicode_CheckExact(o) if strict else PyUnicode_Check(o):
if self.unicode_errors == NULL:
rawval = PyUnicode_AsUTF8AndSize(o, &L)
if L >ITEM_LIMIT:
@ -196,7 +203,7 @@ cdef class Packer:
rawval = o
msgpack_pack_raw(&self.pk, L)
msgpack_pack_raw_body(&self.pk, rawval, L)
elif PyDict_CheckExact(o) if strict_types else PyDict_Check(o):
elif PyDict_CheckExact(o) if strict else PyDict_Check(o):
L = len(o)
if L > ITEM_LIMIT:
raise ValueError("dict is too large")
@ -204,7 +211,7 @@ cdef class Packer:
for k, v in o.items():
self._pack(k, nest_limit)
self._pack(v, nest_limit)
elif type(o) is ExtType if strict_types else isinstance(o, ExtType):
elif type(o) is ExtType if strict else isinstance(o, ExtType):
# This should be before Tuple because ExtType is namedtuple.
rawval = o.data
L = len(o.data)
@ -216,7 +223,7 @@ cdef class Packer:
llval = o.seconds
ulval = o.nanoseconds
msgpack_pack_timestamp(&self.pk, llval, ulval)
elif PyList_CheckExact(o) if strict_types else (PyTuple_Check(o) or PyList_Check(o)):
elif PyList_CheckExact(o) if strict else (PyTuple_Check(o) or PyList_Check(o)):
L = Py_SIZE(o)
if L > ITEM_LIMIT:
raise ValueError("list is too large")
@ -264,6 +271,7 @@ cdef class Packer:
def pack(self, object obj):
cdef int ret
self._check_exports()
try:
ret = self._pack(obj, DEFAULT_RECURSE_LIMIT)
except:
@ -277,12 +285,16 @@ cdef class Packer:
return buf
def pack_ext_type(self, typecode, data):
self._check_exports()
if len(data) > ITEM_LIMIT:
raise ValueError("ext data too large")
msgpack_pack_ext(&self.pk, typecode, len(data))
msgpack_pack_raw_body(&self.pk, data, len(data))
def pack_array_header(self, long long size):
self._check_exports()
if size > ITEM_LIMIT:
raise ValueError
raise ValueError("array too large")
msgpack_pack_array(&self.pk, size)
if self.autoreset:
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
@ -290,8 +302,9 @@ cdef class Packer:
return buf
def pack_map_header(self, long long size):
self._check_exports()
if size > ITEM_LIMIT:
raise ValueError
raise ValueError("map too learge")
msgpack_pack_map(&self.pk, size)
if self.autoreset:
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
@ -305,7 +318,11 @@ cdef class Packer:
*pairs* should be a sequence of pairs.
(`len(pairs)` and `for k, v in pairs:` should be supported.)
"""
msgpack_pack_map(&self.pk, len(pairs))
self._check_exports()
size = len(pairs)
if size > ITEM_LIMIT:
raise ValueError("map too large")
msgpack_pack_map(&self.pk, size)
for k, v in pairs:
self._pack(k)
self._pack(v)
@ -319,6 +336,7 @@ cdef class Packer:
This method is useful only when autoreset=False.
"""
self._check_exports()
self.pk.length = 0
def bytes(self):
@ -326,11 +344,15 @@ cdef class Packer:
return PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
def getbuffer(self):
"""Return view of internal buffer."""
"""Return memoryview of internal buffer.
Note: Packer now supports buffer protocol. You can use memoryview(packer).
"""
return memoryview(self)
def __getbuffer__(self, Py_buffer *buffer, int flags):
PyBuffer_FillInfo(buffer, self, self.pk.buf, self.pk.length, 1, flags)
self.exports += 1
def __releasebuffer__(self, Py_buffer *buffer):
pass
self.exports -= 1

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python
from pytest import raises
from msgpack import packb, unpackb
from msgpack import packb, unpackb, Packer
def test_unpack_buffer():
@ -27,3 +27,23 @@ def test_unpack_memoryview():
assert [b"foo", b"bar"] == obj
expected_type = bytes
assert all(type(s) == expected_type for s in obj)
def test_packer_getbuffer():
packer = Packer(autoreset=False)
packer.pack_array_header(2)
packer.pack(42)
packer.pack("hello")
buffer = packer.getbuffer()
assert isinstance(buffer, memoryview)
assert bytes(buffer) == b"\x92*\xa5hello"
if Packer.__module__ == "msgpack._cmsgpack": # only for Cython
# cython Packer supports buffer protocol directly
assert bytes(packer) == b"\x92*\xa5hello"
with raises(BufferError):
packer.pack(42)
buffer.release()
packer.pack(42)
assert bytes(packer) == b"\x92*\xa5hello*"