| 
									
										
										
										
											2021-11-22 10:09:48 -08:00
# Implementation of marshal.loads() in pure Python
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ast | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-23 08:56:06 -08:00
										 |  |  | from typing import Any, Tuple | 
					
						
							| 
									
										
										
										
											2021-11-22 10:09:48 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Type: | 
					
						
							|  |  |  |     # Adapted from marshal.c | 
					
						
							|  |  |  |     NULL                = ord('0') | 
					
						
							|  |  |  |     NONE                = ord('N') | 
					
						
							|  |  |  |     FALSE               = ord('F') | 
					
						
							|  |  |  |     TRUE                = ord('T') | 
					
						
							|  |  |  |     STOPITER            = ord('S') | 
					
						
							|  |  |  |     ELLIPSIS            = ord('.') | 
					
						
							|  |  |  |     INT                 = ord('i') | 
					
						
							|  |  |  |     INT64               = ord('I') | 
					
						
							|  |  |  |     FLOAT               = ord('f') | 
					
						
							|  |  |  |     BINARY_FLOAT        = ord('g') | 
					
						
							|  |  |  |     COMPLEX             = ord('x') | 
					
						
							|  |  |  |     BINARY_COMPLEX      = ord('y') | 
					
						
							|  |  |  |     LONG                = ord('l') | 
					
						
							|  |  |  |     STRING              = ord('s') | 
					
						
							|  |  |  |     INTERNED            = ord('t') | 
					
						
							|  |  |  |     REF                 = ord('r') | 
					
						
							|  |  |  |     TUPLE               = ord('(') | 
					
						
							|  |  |  |     LIST                = ord('[') | 
					
						
							|  |  |  |     DICT                = ord('{') | 
					
						
							|  |  |  |     CODE                = ord('c') | 
					
						
							|  |  |  |     UNICODE             = ord('u') | 
					
						
							|  |  |  |     UNKNOWN             = ord('?') | 
					
						
							|  |  |  |     SET                 = ord('<') | 
					
						
							|  |  |  |     FROZENSET           = ord('>') | 
					
						
							|  |  |  |     ASCII               = ord('a') | 
					
						
							|  |  |  |     ASCII_INTERNED      = ord('A') | 
					
						
							|  |  |  |     SMALL_TUPLE         = ord(')') | 
					
						
							|  |  |  |     SHORT_ASCII         = ord('z') | 
					
						
							|  |  |  |     SHORT_ASCII_INTERNED = ord('Z') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | FLAG_REF = 0x80  # with a type, add obj to index | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | NULL = object()  # marker | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # Cell kinds | 
					
						
							|  |  |  | CO_FAST_LOCAL = 0x20 | 
					
						
							|  |  |  | CO_FAST_CELL = 0x40 | 
					
						
							|  |  |  | CO_FAST_FREE = 0x80 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Code: | 
					
						
							|  |  |  |     def __init__(self, **kwds: Any): | 
					
						
							|  |  |  |         self.__dict__.update(kwds) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __repr__(self) -> str: | 
					
						
							|  |  |  |         return f"Code(**{self.__dict__})" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-23 08:56:06 -08:00
										 |  |  |     co_localsplusnames: Tuple[str] | 
					
						
							|  |  |  |     co_localspluskinds: Tuple[int] | 
					
						
							| 
									
										
										
										
											2021-11-22 10:09:48 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-23 08:56:06 -08:00
										 |  |  |     def get_localsplus_names(self, select_kind: int) -> Tuple[str, ...]: | 
					
						
							| 
									
										
										
										
											2021-11-22 10:09:48 -08:00
										 |  |  |         varnames: list[str] = [] | 
					
						
							|  |  |  |         for name, kind in zip(self.co_localsplusnames, | 
					
						
							|  |  |  |                               self.co_localspluskinds): | 
					
						
							|  |  |  |             if kind & select_kind: | 
					
						
							|  |  |  |                 varnames.append(name) | 
					
						
							|  |  |  |         return tuple(varnames) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @property | 
					
						
							| 
									
										
										
										
											2021-11-23 08:56:06 -08:00
										 |  |  |     def co_varnames(self) -> Tuple[str, ...]: | 
					
						
							| 
									
										
										
										
											2021-11-22 10:09:48 -08:00
										 |  |  |         return self.get_localsplus_names(CO_FAST_LOCAL) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @property | 
					
						
							| 
									
										
										
										
											2021-11-23 08:56:06 -08:00
										 |  |  |     def co_cellvars(self) -> Tuple[str, ...]: | 
					
						
							| 
									
										
										
										
											2021-11-22 10:09:48 -08:00
										 |  |  |         return self.get_localsplus_names(CO_FAST_CELL) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @property | 
					
						
							| 
									
										
										
										
											2021-11-23 08:56:06 -08:00
										 |  |  |     def co_freevars(self) -> Tuple[str, ...]: | 
					
						
							| 
									
										
										
										
											2021-11-22 10:09:48 -08:00
										 |  |  |         return self.get_localsplus_names(CO_FAST_FREE) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @property | 
					
						
							|  |  |  |     def co_nlocals(self) -> int: | 
					
						
							|  |  |  |         return len(self.co_varnames) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Reader: | 
					
						
							|  |  |  |     # A fairly literal translation of the marshal reader. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, data: bytes): | 
					
						
							|  |  |  |         self.data: bytes = data | 
					
						
							|  |  |  |         self.end: int = len(self.data) | 
					
						
							|  |  |  |         self.pos: int = 0 | 
					
						
							|  |  |  |         self.refs: list[Any] = [] | 
					
						
							|  |  |  |         self.level: int = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def r_string(self, n: int) -> bytes: | 
					
						
							|  |  |  |         assert 0 <= n <= self.end - self.pos | 
					
						
							|  |  |  |         buf = self.data[self.pos : self.pos + n] | 
					
						
							|  |  |  |         self.pos += n | 
					
						
							|  |  |  |         return buf | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def r_byte(self) -> int: | 
					
						
							|  |  |  |         buf = self.r_string(1) | 
					
						
							|  |  |  |         return buf[0] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def r_short(self) -> int: | 
					
						
							|  |  |  |         buf = self.r_string(2) | 
					
						
							|  |  |  |         x = buf[0] | 
					
						
							|  |  |  |         x |= buf[1] << 8 | 
					
						
							|  |  |  |         x |= -(x & (1<<15))  # Sign-extend | 
					
						
							|  |  |  |         return x | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def r_long(self) -> int: | 
					
						
							|  |  |  |         buf = self.r_string(4) | 
					
						
							|  |  |  |         x = buf[0] | 
					
						
							|  |  |  |         x |= buf[1] << 8 | 
					
						
							|  |  |  |         x |= buf[2] << 16 | 
					
						
							|  |  |  |         x |= buf[3] << 24 | 
					
						
							|  |  |  |         x |= -(x & (1<<31))  # Sign-extend | 
					
						
							|  |  |  |         return x | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def r_long64(self) -> int: | 
					
						
							|  |  |  |         buf = self.r_string(8) | 
					
						
							|  |  |  |         x = buf[0] | 
					
						
							|  |  |  |         x |= buf[1] << 8 | 
					
						
							|  |  |  |         x |= buf[2] << 16 | 
					
						
							|  |  |  |         x |= buf[3] << 24 | 
					
						
							|  |  |  |         x |= buf[1] << 32 | 
					
						
							|  |  |  |         x |= buf[1] << 40 | 
					
						
							|  |  |  |         x |= buf[1] << 48 | 
					
						
							|  |  |  |         x |= buf[1] << 56 | 
					
						
							|  |  |  |         x |= -(x & (1<<63))  # Sign-extend | 
					
						
							|  |  |  |         return x | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def r_PyLong(self) -> int: | 
					
						
							|  |  |  |         n = self.r_long() | 
					
						
							|  |  |  |         size = abs(n) | 
					
						
							|  |  |  |         x = 0 | 
					
						
							|  |  |  |         # Pray this is right | 
					
						
							|  |  |  |         for i in range(size): | 
					
						
							|  |  |  |             x |= self.r_short() << i*15 | 
					
						
							|  |  |  |         if n < 0: | 
					
						
							|  |  |  |             x = -x | 
					
						
							|  |  |  |         return x | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def r_float_bin(self) -> float: | 
					
						
							|  |  |  |         buf = self.r_string(8) | 
					
						
							|  |  |  |         import struct  # Lazy import to avoid breaking UNIX build | 
					
						
							|  |  |  |         return struct.unpack("d", buf)[0] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def r_float_str(self) -> float: | 
					
						
							|  |  |  |         n = self.r_byte() | 
					
						
							|  |  |  |         buf = self.r_string(n) | 
					
						
							|  |  |  |         return ast.literal_eval(buf.decode("ascii")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def r_ref_reserve(self, flag: int) -> int: | 
					
						
							|  |  |  |         if flag: | 
					
						
							|  |  |  |             idx = len(self.refs) | 
					
						
							|  |  |  |             self.refs.append(None) | 
					
						
							|  |  |  |             return idx | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def r_ref_insert(self, obj: Any, idx: int, flag: int) -> Any: | 
					
						
							|  |  |  |         if flag: | 
					
						
							|  |  |  |             self.refs[idx] = obj | 
					
						
							|  |  |  |         return obj | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def r_ref(self, obj: Any, flag: int) -> Any: | 
					
						
							|  |  |  |         assert flag & FLAG_REF | 
					
						
							|  |  |  |         self.refs.append(obj) | 
					
						
							|  |  |  |         return obj | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def r_object(self) -> Any: | 
					
						
							|  |  |  |         old_level = self.level | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             return self._r_object() | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             self.level = old_level | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _r_object(self) -> Any: | 
					
						
							|  |  |  |         code = self.r_byte() | 
					
						
							|  |  |  |         flag = code & FLAG_REF | 
					
						
							|  |  |  |         type = code & ~FLAG_REF | 
					
						
							|  |  |  |         # print("  "*self.level + f"{code} {flag} {type} {chr(type)!r}") | 
					
						
							|  |  |  |         self.level += 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def R_REF(obj: Any) -> Any: | 
					
						
							|  |  |  |             if flag: | 
					
						
							|  |  |  |                 obj = self.r_ref(obj, flag) | 
					
						
							|  |  |  |             return obj | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-23 08:56:06 -08:00
										 |  |  |         if type == Type.NULL: | 
					
						
							|  |  |  |             return NULL | 
					
						
							|  |  |  |         elif type == Type.NONE: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  |         elif type == Type.ELLIPSIS: | 
					
						
							|  |  |  |             return Ellipsis | 
					
						
							|  |  |  |         elif type == Type.FALSE: | 
					
						
							|  |  |  |             return False | 
					
						
							|  |  |  |         elif type == Type.TRUE: | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  |         elif type == Type.INT: | 
					
						
							|  |  |  |             return R_REF(self.r_long()) | 
					
						
							|  |  |  |         elif type == Type.INT64: | 
					
						
							|  |  |  |             return R_REF(self.r_long64()) | 
					
						
							|  |  |  |         elif type == Type.LONG: | 
					
						
							|  |  |  |             return R_REF(self.r_PyLong()) | 
					
						
							|  |  |  |         elif type == Type.FLOAT: | 
					
						
							|  |  |  |             return R_REF(self.r_float_str()) | 
					
						
							|  |  |  |         elif type == Type.BINARY_FLOAT: | 
					
						
							|  |  |  |             return R_REF(self.r_float_bin()) | 
					
						
							|  |  |  |         elif type == Type.COMPLEX: | 
					
						
							|  |  |  |             return R_REF(complex(self.r_float_str(), | 
					
						
							|  |  |  |                                     self.r_float_str())) | 
					
						
							|  |  |  |         elif type == Type.BINARY_COMPLEX: | 
					
						
							|  |  |  |             return R_REF(complex(self.r_float_bin(), | 
					
						
							|  |  |  |                                     self.r_float_bin())) | 
					
						
							|  |  |  |         elif type == Type.STRING: | 
					
						
							|  |  |  |             n = self.r_long() | 
					
						
							|  |  |  |             return R_REF(self.r_string(n)) | 
					
						
							|  |  |  |         elif type == Type.ASCII_INTERNED or type == Type.ASCII: | 
					
						
							|  |  |  |             n = self.r_long() | 
					
						
							|  |  |  |             return R_REF(self.r_string(n).decode("ascii")) | 
					
						
							|  |  |  |         elif type == Type.SHORT_ASCII_INTERNED or type == Type.SHORT_ASCII: | 
					
						
							|  |  |  |             n = self.r_byte() | 
					
						
							|  |  |  |             return R_REF(self.r_string(n).decode("ascii")) | 
					
						
							|  |  |  |         elif type == Type.INTERNED or type == Type.UNICODE: | 
					
						
							|  |  |  |             n = self.r_long() | 
					
						
							|  |  |  |             return R_REF(self.r_string(n).decode("utf8", "surrogatepass")) | 
					
						
							|  |  |  |         elif type == Type.SMALL_TUPLE: | 
					
						
							|  |  |  |             n = self.r_byte() | 
					
						
							|  |  |  |             idx = self.r_ref_reserve(flag) | 
					
						
							|  |  |  |             retval: Any = tuple(self.r_object() for _ in range(n)) | 
					
						
							|  |  |  |             self.r_ref_insert(retval, idx, flag) | 
					
						
							|  |  |  |             return retval | 
					
						
							|  |  |  |         elif type == Type.TUPLE: | 
					
						
							|  |  |  |             n = self.r_long() | 
					
						
							|  |  |  |             idx = self.r_ref_reserve(flag) | 
					
						
							|  |  |  |             retval = tuple(self.r_object() for _ in range(n)) | 
					
						
							|  |  |  |             self.r_ref_insert(retval, idx, flag) | 
					
						
							|  |  |  |             return retval | 
					
						
							|  |  |  |         elif type == Type.LIST: | 
					
						
							|  |  |  |             n = self.r_long() | 
					
						
							|  |  |  |             retval = R_REF([]) | 
					
						
							|  |  |  |             for _ in range(n): | 
					
						
							|  |  |  |                 retval.append(self.r_object()) | 
					
						
							|  |  |  |             return retval | 
					
						
							|  |  |  |         elif type == Type.DICT: | 
					
						
							|  |  |  |             retval = R_REF({}) | 
					
						
							|  |  |  |             while True: | 
					
						
							|  |  |  |                 key = self.r_object() | 
					
						
							|  |  |  |                 if key == NULL: | 
					
						
							|  |  |  |                     break | 
					
						
							|  |  |  |                 val = self.r_object() | 
					
						
							|  |  |  |                 retval[key] = val | 
					
						
							|  |  |  |             return retval | 
					
						
							|  |  |  |         elif type == Type.SET: | 
					
						
							|  |  |  |             n = self.r_long() | 
					
						
							|  |  |  |             retval = R_REF(set()) | 
					
						
							|  |  |  |             for _ in range(n): | 
					
						
							|  |  |  |                 v = self.r_object() | 
					
						
							|  |  |  |                 retval.add(v) | 
					
						
							|  |  |  |             return retval | 
					
						
							|  |  |  |         elif type == Type.FROZENSET: | 
					
						
							|  |  |  |             n = self.r_long() | 
					
						
							|  |  |  |             s: set[Any] = set() | 
					
						
							|  |  |  |             idx = self.r_ref_reserve(flag) | 
					
						
							|  |  |  |             for _ in range(n): | 
					
						
							|  |  |  |                 v = self.r_object() | 
					
						
							|  |  |  |                 s.add(v) | 
					
						
							|  |  |  |             retval = frozenset(s) | 
					
						
							|  |  |  |             self.r_ref_insert(retval, idx, flag) | 
					
						
							|  |  |  |             return retval | 
					
						
							|  |  |  |         elif type == Type.CODE: | 
					
						
							|  |  |  |             retval = R_REF(Code()) | 
					
						
							|  |  |  |             retval.co_argcount = self.r_long() | 
					
						
							|  |  |  |             retval.co_posonlyargcount = self.r_long() | 
					
						
							|  |  |  |             retval.co_kwonlyargcount = self.r_long() | 
					
						
							|  |  |  |             retval.co_stacksize = self.r_long() | 
					
						
							|  |  |  |             retval.co_flags = self.r_long() | 
					
						
							|  |  |  |             retval.co_code = self.r_object() | 
					
						
							|  |  |  |             retval.co_consts = self.r_object() | 
					
						
							|  |  |  |             retval.co_names = self.r_object() | 
					
						
							|  |  |  |             retval.co_localsplusnames = self.r_object() | 
					
						
							|  |  |  |             retval.co_localspluskinds = self.r_object() | 
					
						
							|  |  |  |             retval.co_filename = self.r_object() | 
					
						
							|  |  |  |             retval.co_name = self.r_object() | 
					
						
							|  |  |  |             retval.co_qualname = self.r_object() | 
					
						
							|  |  |  |             retval.co_firstlineno = self.r_long() | 
					
						
							|  |  |  |             retval.co_linetable = self.r_object() | 
					
						
							|  |  |  |             retval.co_endlinetable = self.r_object() | 
					
						
							|  |  |  |             retval.co_columntable = self.r_object() | 
					
						
							|  |  |  |             retval.co_exceptiontable = self.r_object() | 
					
						
							|  |  |  |             return retval | 
					
						
							|  |  |  |         elif type == Type.REF: | 
					
						
							|  |  |  |             n = self.r_long() | 
					
						
							|  |  |  |             retval = self.refs[n] | 
					
						
							|  |  |  |             assert retval is not None | 
					
						
							|  |  |  |             return retval | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             breakpoint() | 
					
						
							|  |  |  |             raise AssertionError(f"Unknown type {type} {chr(type)!r}") | 
					
						
							| 
									
										
										
										
											2021-11-22 10:09:48 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def loads(data: bytes) -> Any: | 
					
						
							|  |  |  |     assert isinstance(data, bytes) | 
					
						
							|  |  |  |     r = Reader(data) | 
					
						
							|  |  |  |     return r.r_object() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def main(): | 
					
						
							|  |  |  |     # Test | 
					
						
							|  |  |  |     import marshal, pprint | 
					
						
							|  |  |  |     sample = {'foo': {(42, "bar", 3.14)}} | 
					
						
							|  |  |  |     data = marshal.dumps(sample) | 
					
						
							|  |  |  |     retval = loads(data) | 
					
						
							|  |  |  |     assert retval == sample, retval | 
					
						
							|  |  |  |     sample = main.__code__ | 
					
						
							|  |  |  |     data = marshal.dumps(sample) | 
					
						
							|  |  |  |     retval = loads(data) | 
					
						
							|  |  |  |     assert isinstance(retval, Code), retval | 
					
						
							|  |  |  |     pprint.pprint(retval.__dict__) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     main() |