This commit is contained in:
Pea Soft 2026-01-03 06:15:21 +00:00 committed by GitHub
commit d765319b91
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 81 additions and 63 deletions

View file

@ -1,5 +1,7 @@
# ruff: noqa: F401
# pyright: reportUnusedImport = none
import os
import typing as t
from .exceptions import * # noqa: F403
from .ext import ExtType, Timestamp
@ -8,7 +10,7 @@ version = (1, 1, 2)
__version__ = "1.1.2"
if os.environ.get("MSGPACK_PUREPYTHON"):
if os.environ.get("MSGPACK_PUREPYTHON") or t.TYPE_CHECKING:
from .fallback import Packer, Unpacker, unpackb
else:
try:
@ -17,26 +19,26 @@ else:
from .fallback import Packer, Unpacker, unpackb
def pack(o: t.Any, stream: t.BinaryIO, **kwargs: t.Any):
    """
    Pack object `o` and write it to `stream`

    See :class:`Packer` for options.
    """
    # autoreset=True makes Packer.pack() return the packed bytes directly
    # instead of accumulating them in the Packer's internal buffer.
    packer = Packer(autoreset=True, **kwargs)  # type: ignore
    stream.write(t.cast(bytes, packer.pack(o)))
def packb(o: t.Any, **kwargs: t.Any) -> bytes:
    """
    Pack object `o` and return packed bytes

    See :class:`Packer` for options.
    """
    # autoreset=True makes Packer.pack() return the packed bytes directly.
    return Packer(autoreset=True, **kwargs).pack(o)  # type: ignore
def unpack(stream, **kwargs):
def unpack(stream: t.BinaryIO, **kwargs: t.Any):
"""
Unpack an object from `stream`.

View file

@ -1,19 +1,23 @@
import datetime
import struct
import typing as t
from collections import namedtuple
class ExtType(namedtuple("ExtType", "code data")):
    """ExtType represents ext type in msgpack.

    An ext value pairs an application-specific type `code` (0..127)
    with an opaque `data` payload.
    """

    code: int
    data: bytes

    def __new__(cls, code: int, data: bytes):
        # Validate eagerly so a malformed ext value fails at construction
        # time rather than later at pack time.
        if not isinstance(code, int):
            raise TypeError("code must be int")
        if not isinstance(data, bytes):
            raise TypeError("data must be bytes")
        # Codes 128..255 are reserved by the msgpack spec (e.g. -1 timestamp
        # is encoded separately), so application codes are limited to 0..127.
        if not 0 <= code <= 127:
            raise ValueError("code must be 0~127")
        return super().__new__(cls, code, data)  # type: ignore
class Timestamp:
@ -28,7 +32,7 @@ class Timestamp:
__slots__ = ["seconds", "nanoseconds"]
def __init__(self, seconds, nanoseconds=0):
def __init__(self, seconds: int, nanoseconds=0):
"""Initialize a Timestamp object.
:param int seconds:
@ -54,13 +58,13 @@ class Timestamp:
"""String representation of Timestamp."""
return f"Timestamp(seconds={self.seconds}, nanoseconds={self.nanoseconds})"
def __eq__(self, other: t.Any):
    """Check for equality with another Timestamp object"""
    # Exact-class comparison (not isinstance) so subclasses never compare
    # equal to a plain Timestamp.
    if type(other) is self.__class__:
        return self.seconds == other.seconds and self.nanoseconds == other.nanoseconds
    return False
def __ne__(self, other: t.Any):
    """not-equals method (see :func:`__eq__()`)"""
    return not self.__eq__(other)
@ -68,7 +72,7 @@ class Timestamp:
return hash((self.seconds, self.nanoseconds))
@staticmethod
def from_bytes(b):
def from_bytes(b: bytes):
"""Unpack bytes into a `Timestamp` object.
Used for pure-Python msgpack unpacking.
@ -116,7 +120,7 @@ class Timestamp:
return data
@staticmethod
def from_unix(unix_sec):
def from_unix(unix_sec: int | float):
"""Create a Timestamp from posix timestamp in seconds.
:param unix_float: Posix timestamp in seconds.
@ -135,7 +139,7 @@ class Timestamp:
return self.seconds + self.nanoseconds / 1e9
@staticmethod
def from_unix_nano(unix_ns):
def from_unix_nano(unix_ns: int):
"""Create a Timestamp from posix timestamp in nanoseconds.
:param int unix_ns: Posix timestamp in nanoseconds.
@ -162,7 +166,7 @@ class Timestamp:
)
@staticmethod
def from_datetime(dt):
def from_datetime(dt: datetime.datetime):
"""Create a Timestamp from datetime with tzinfo.
:rtype: Timestamp

View file

@ -2,9 +2,18 @@
import struct
import sys
import types
import typing as t
from collections.abc import Sequence
from datetime import datetime as _DateTime
if hasattr(sys, "pypy_version_info"):
_ClassInfo: t.TypeAlias = type | types.UnionType | tuple["_ClassInfo", ...]
_Pair = tuple[t.Any, t.Any]
_Pairs = t.Iterable[_Pair]
_SizeFmt = tuple[int, str]
_SizeFmtTyp = tuple[int, str, int]
if not t.TYPE_CHECKING and hasattr(sys, "pypy_version_info"):
from __pypy__ import newlist_hint
from __pypy__.builders import BytesBuilder
@ -33,7 +42,7 @@ else:
_USING_STRINGBUILDER = False
def newlist_hint(size: int) -> list[t.Any]:
    # Pure-CPython stand-in for PyPy's __pypy__.newlist_hint; the size
    # hint is accepted for API compatibility but ignored.
    return []
@ -55,21 +64,21 @@ TYPE_EXT = 5
DEFAULT_RECURSE_LIMIT = 511
def _check_type_strict(obj, t, type=type, tuple=tuple):
def _check_type_strict(obj: object, t: _ClassInfo, type=type, tuple=tuple):
if type(t) is tuple:
return type(obj) in t
else:
return type(obj) is t
def _get_data_from_buffer(obj):
def _get_data_from_buffer(obj: bytes):
view = memoryview(obj)
if view.itemsize != 1:
raise ValueError("cannot unpack from multi-byte object")
return view
def unpackb(packed, **kwargs):
def unpackb(packed: bytes, **kwargs: t.Any):
"""
Unpack an object from `packed`.
@ -81,7 +90,7 @@ def unpackb(packed, **kwargs):
See :class:`Unpacker` for options.
"""
unpacker = Unpacker(None, max_buffer_size=len(packed), **kwargs)
unpacker = Unpacker(None, max_buffer_size=len(packed), **kwargs) # type: ignore
unpacker.feed(packed)
try:
ret = unpacker._unpack()
@ -95,7 +104,7 @@ def unpackb(packed, **kwargs):
_NO_FORMAT_USED = ""
_MSGPACK_HEADERS = {
_MSGPACK_HEADERS: dict[int, _SizeFmt | _SizeFmtTyp] = {
0xC4: (1, _NO_FORMAT_USED, TYPE_BIN),
0xC5: (2, ">H", TYPE_BIN),
0xC6: (4, ">I", TYPE_BIN),
@ -225,17 +234,17 @@ class Unpacker:
def __init__(
self,
file_like=None,
file_like: t.BinaryIO | None = None,
*,
read_size=0,
use_list=True,
raw=False,
timestamp=0,
strict_map_key=True,
object_hook=None,
object_pairs_hook=None,
list_hook=None,
unicode_errors=None,
object_hook: t.Callable[[dict[t.Any, t.Any]], object] | None = None,
object_pairs_hook: t.Callable[[_Pairs], dict[t.Any, t.Any]] | None = None,
list_hook: t.Callable[[list[t.Any]], list[t.Any]] | None = None,
unicode_errors: str | None = None,
max_buffer_size=100 * 1024 * 1024,
ext_hook=ExtType,
max_str_len=-1,
@ -315,7 +324,7 @@ class Unpacker:
if not callable(ext_hook):
raise TypeError("`ext_hook` is not callable")
def feed(self, next_bytes):
def feed(self, next_bytes: bytes):
assert self._feeding
view = _get_data_from_buffer(next_bytes)
if len(self._buffer) - self._buff_i + len(view) > self._max_buffer_size:
@ -342,12 +351,12 @@ class Unpacker:
def _get_extradata(self):
return self._buffer[self._buff_i :]
def read_bytes(self, n: int):
    """Read up to *n* bytes from the internal buffer and consume them.

    With raise_outofdata=False a short read returns whatever is
    available instead of raising OutOfData.
    """
    ret = self._read(n, raise_outofdata=False)
    self._consume()
    return ret
def _read(self, n, raise_outofdata=True):
def _read(self, n: int, raise_outofdata=True):
# (int) -> bytearray
self._reserve(n, raise_outofdata=raise_outofdata)
i = self._buff_i
@ -355,7 +364,7 @@ class Unpacker:
self._buff_i = i + len(ret)
return ret
def _reserve(self, n, raise_outofdata=True):
def _reserve(self, n: int, raise_outofdata=True):
remain_bytes = len(self._buffer) - self._buff_i - n
# Fast path: buffer has n bytes already
@ -423,7 +432,7 @@ class Unpacker:
elif b == 0xC3:
obj = True
elif 0xC4 <= b <= 0xC6:
size, fmt, typ = _MSGPACK_HEADERS[b]
size, fmt, typ = t.cast(_SizeFmtTyp, _MSGPACK_HEADERS[b])
self._reserve(size)
if len(fmt) > 0:
n = struct.unpack_from(fmt, self._buffer, self._buff_i)[0]
@ -434,7 +443,7 @@ class Unpacker:
raise ValueError(f"{n} exceeds max_bin_len({self._max_bin_len})")
obj = self._read(n)
elif 0xC7 <= b <= 0xC9:
size, fmt, typ = _MSGPACK_HEADERS[b]
size, fmt, typ = t.cast(_SizeFmtTyp, _MSGPACK_HEADERS[b])
self._reserve(size)
L, n = struct.unpack_from(fmt, self._buffer, self._buff_i)
self._buff_i += size
@ -442,7 +451,7 @@ class Unpacker:
raise ValueError(f"{L} exceeds max_ext_len({self._max_ext_len})")
obj = self._read(L)
elif 0xCA <= b <= 0xD3:
size, fmt = _MSGPACK_HEADERS[b]
size, fmt = t.cast(_SizeFmt, _MSGPACK_HEADERS[b])
self._reserve(size)
if len(fmt) > 0:
obj = struct.unpack_from(fmt, self._buffer, self._buff_i)[0]
@ -450,14 +459,14 @@ class Unpacker:
obj = self._buffer[self._buff_i]
self._buff_i += size
elif 0xD4 <= b <= 0xD8:
size, fmt, typ = _MSGPACK_HEADERS[b]
size, fmt, typ = t.cast(_SizeFmtTyp, _MSGPACK_HEADERS[b])
if self._max_ext_len < size:
raise ValueError(f"{size} exceeds max_ext_len({self._max_ext_len})")
self._reserve(size + 1)
n, obj = struct.unpack_from(fmt, self._buffer, self._buff_i)
self._buff_i += size + 1
elif 0xD9 <= b <= 0xDB:
size, fmt, typ = _MSGPACK_HEADERS[b]
size, fmt, typ = t.cast(_SizeFmtTyp, _MSGPACK_HEADERS[b])
self._reserve(size)
if len(fmt) > 0:
(n,) = struct.unpack_from(fmt, self._buffer, self._buff_i)
@ -468,14 +477,14 @@ class Unpacker:
raise ValueError(f"{n} exceeds max_str_len({self._max_str_len})")
obj = self._read(n)
elif 0xDC <= b <= 0xDD:
size, fmt, typ = _MSGPACK_HEADERS[b]
size, fmt, typ = t.cast(_SizeFmtTyp, _MSGPACK_HEADERS[b])
self._reserve(size)
(n,) = struct.unpack_from(fmt, self._buffer, self._buff_i)
self._buff_i += size
if n > self._max_array_len:
raise ValueError(f"{n} exceeds max_array_len({self._max_array_len})")
elif 0xDE <= b <= 0xDF:
size, fmt, typ = _MSGPACK_HEADERS[b]
size, fmt, typ = t.cast(_SizeFmtTyp, _MSGPACK_HEADERS[b])
self._reserve(size)
(n,) = struct.unpack_from(fmt, self._buffer, self._buff_i)
self._buff_i += size
@ -483,7 +492,7 @@ class Unpacker:
raise ValueError(f"{n} exceeds max_map_len({self._max_map_len})")
else:
raise FormatError("Unknown header: 0x%x" % b)
return typ, n, obj
return typ, t.cast(int, n), obj
def _unpack(self, execute=EX_CONSTRUCT):
typ, n, obj = self._read_header()
@ -519,31 +528,34 @@ class Unpacker:
return
if self._object_pairs_hook is not None:
ret = self._object_pairs_hook(
(self._unpack(EX_CONSTRUCT), self._unpack(EX_CONSTRUCT)) for _ in range(n)
(self._unpack(EX_CONSTRUCT), self._unpack(EX_CONSTRUCT))
for _ in range(n) # type: ignore
)
else:
ret = {}
for _ in range(n):
key = self._unpack(EX_CONSTRUCT)
if self._strict_map_key and type(key) not in (str, bytes):
raise ValueError("%s is not allowed for map key" % str(type(key)))
raise ValueError("%s is not allowed for map key" % str(type(key))) # type: ignore
if isinstance(key, str):
key = sys.intern(key)
ret[key] = self._unpack(EX_CONSTRUCT)
if self._object_hook is not None:
ret = self._object_hook(ret)
ret = self._object_hook(ret) # type: ignore
return ret
if execute == EX_SKIP:
return
if typ == TYPE_RAW:
obj = t.cast(bytearray, obj)
if self._raw:
obj = bytes(obj)
else:
obj = obj.decode("utf_8", self._unicode_errors)
return obj
if typ == TYPE_BIN:
return bytes(obj)
return bytes(t.cast(bytearray, obj))
if typ == TYPE_EXT:
obj = t.cast(bytearray, obj)
if n == -1: # timestamp
ts = Timestamp.from_bytes(bytes(obj))
if self._timestamp == 1:
@ -653,14 +665,14 @@ class Packer:
def __init__(
self,
*,
default=None,
default: t.Callable[[t.Any], t.Any] | None = None,
use_single_float=False,
autoreset=True,
use_bin_type=True,
strict_types=False,
datetime=False,
unicode_errors=None,
buf_size=None,
unicode_errors: str | None = None,
buf_size: int | None = None,
):
self._strict_types = strict_types
self._use_float = use_single_float
@ -675,10 +687,10 @@ class Packer:
def _pack(
self,
obj,
obj: t.Any,
nest_limit=DEFAULT_RECURSE_LIMIT,
check=isinstance,
check_type_strict=_check_type_strict,
check: t.Callable[[object, _ClassInfo], bool] = isinstance,
check_type_strict: t.Callable[[object, _ClassInfo], bool] = _check_type_strict,
):
default_used = False
if self._strict_types:
@ -747,10 +759,10 @@ class Packer:
if check(obj, (ExtType, Timestamp)):
if check(obj, Timestamp):
code = -1
data = obj.to_bytes()
data = t.cast(Timestamp, obj).to_bytes()
else:
code = obj.code
data = obj.data
code = t.cast(ExtType, obj).code
data = t.cast(ExtType, obj).data
assert isinstance(code, int)
assert isinstance(data, bytes)
L = len(data)
@ -797,7 +809,7 @@ class Packer:
raise TypeError(f"Cannot serialize {obj!r}")
def pack(self, obj):
def pack(self, obj: t.Any):
try:
self._pack(obj)
except:
@ -808,14 +820,14 @@ class Packer:
self._buffer = BytesIO()
return ret
def pack_map_pairs(self, pairs: Sequence[_Pair]):
    """Pack *pairs* as a msgpack map, preserving pair order.

    Returns the packed bytes when autoreset is on; otherwise the data
    stays in the internal buffer and None is returned.
    """
    self._pack_map_pairs(len(pairs), pairs)
    if self._autoreset:
        ret = self._buffer.getvalue()
        self._buffer = BytesIO()
        return ret
def pack_array_header(self, n):
def pack_array_header(self, n: int):
if n >= 2**32:
raise ValueError
self._pack_array_header(n)
@ -824,7 +836,7 @@ class Packer:
self._buffer = BytesIO()
return ret
def pack_map_header(self, n):
def pack_map_header(self, n: int):
if n >= 2**32:
raise ValueError
self._pack_map_header(n)
@ -833,7 +845,7 @@ class Packer:
self._buffer = BytesIO()
return ret
def pack_ext_type(self, typecode, data):
def pack_ext_type(self, typecode: int, data: bytes):
if not isinstance(typecode, int):
raise TypeError("typecode must have int type.")
if not 0 <= typecode <= 127:
@ -862,7 +874,7 @@ class Packer:
self._buffer.write(struct.pack("B", typecode))
self._buffer.write(data)
def _pack_array_header(self, n):
def _pack_array_header(self, n: int):
if n <= 0x0F:
return self._buffer.write(struct.pack("B", 0x90 + n))
if n <= 0xFFFF:
@ -871,7 +883,7 @@ class Packer:
return self._buffer.write(struct.pack(">BI", 0xDD, n))
raise ValueError("Array is too large")
def _pack_map_header(self, n):
def _pack_map_header(self, n: int):
if n <= 0x0F:
return self._buffer.write(struct.pack("B", 0x80 + n))
if n <= 0xFFFF:
@ -880,13 +892,13 @@ class Packer:
return self._buffer.write(struct.pack(">BI", 0xDF, n))
raise ValueError("Dict is too large")
def _pack_map_pairs(self, n, pairs, nest_limit=DEFAULT_RECURSE_LIMIT):
def _pack_map_pairs(self, n: int, pairs: _Pairs, nest_limit=DEFAULT_RECURSE_LIMIT):
self._pack_map_header(n)
for k, v in pairs:
self._pack(k, nest_limit - 1)
self._pack(v, nest_limit - 1)
def _pack_raw_header(self, n):
def _pack_raw_header(self, n: int):
if n <= 0x1F:
self._buffer.write(struct.pack("B", 0xA0 + n))
elif self._use_bin_type and n <= 0xFF:
@ -898,7 +910,7 @@ class Packer:
else:
raise ValueError("Raw is too large")
def _pack_bin_header(self, n):
def _pack_bin_header(self, n: int):
if not self._use_bin_type:
return self._pack_raw_header(n)
elif n <= 0xFF: