mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 13:41:24 +00:00 
			
		
		
		
	gh-132775: Add _PyPickle_GetXIData() (gh-133107)
There's some extra complexity due to making sure we we get things right when handling functions and classes defined in the __main__ module. This is also reflected in the tests, including the addition of extra functions in test.support.import_helper.
This commit is contained in:
		
							parent
							
								
									6c522debc2
								
							
						
					
					
						commit
						cb35c11d82
					
				
					 5 changed files with 1056 additions and 55 deletions
				
			
		|  | @ -171,6 +171,13 @@ PyAPI_FUNC(_PyBytes_data_t *) _PyBytes_GetXIDataWrapped( | |||
|         xid_newobjfunc, | ||||
|         _PyXIData_t *); | ||||
| 
 | ||||
| // _PyObject_GetXIData() for pickle
 | ||||
| PyAPI_DATA(PyObject *) _PyPickle_LoadFromXIData(_PyXIData_t *); | ||||
| PyAPI_FUNC(int) _PyPickle_GetXIData( | ||||
|         PyThreadState *, | ||||
|         PyObject *, | ||||
|         _PyXIData_t *); | ||||
| 
 | ||||
| // _PyObject_GetXIData() for marshal
 | ||||
| PyAPI_FUNC(PyObject *) _PyMarshal_ReadObjectFromXIData(_PyXIData_t *); | ||||
| PyAPI_FUNC(int) _PyMarshal_GetXIData( | ||||
|  |  | |||
|  | @ -1,6 +1,7 @@ | |||
| import contextlib | ||||
| import _imp | ||||
| import importlib | ||||
| import importlib.machinery | ||||
| import importlib.util | ||||
| import os | ||||
| import shutil | ||||
|  | @ -332,3 +333,110 @@ def ensure_lazy_imports(imported_module, modules_to_block): | |||
|     ) | ||||
|     from .script_helper import assert_python_ok | ||||
|     assert_python_ok("-S", "-c", script) | ||||
| 
 | ||||
| 
 | ||||
| @contextlib.contextmanager | ||||
| def module_restored(name): | ||||
|     """A context manager that restores a module to the original state.""" | ||||
|     missing = object() | ||||
|     orig = sys.modules.get(name, missing) | ||||
|     if orig is None: | ||||
|         mod = importlib.import_module(name) | ||||
|     else: | ||||
|         mod = type(sys)(name) | ||||
|         mod.__dict__.update(orig.__dict__) | ||||
|         sys.modules[name] = mod | ||||
|     try: | ||||
|         yield mod | ||||
|     finally: | ||||
|         if orig is missing: | ||||
|             sys.modules.pop(name, None) | ||||
|         else: | ||||
|             sys.modules[name] = orig | ||||
| 
 | ||||
| 
 | ||||
| def create_module(name, loader=None, *, ispkg=False): | ||||
|     """Return a new, empty module.""" | ||||
|     spec = importlib.machinery.ModuleSpec( | ||||
|         name, | ||||
|         loader, | ||||
|         origin='<import_helper>', | ||||
|         is_package=ispkg, | ||||
|     ) | ||||
|     return importlib.util.module_from_spec(spec) | ||||
| 
 | ||||
| 
 | ||||
| def _ensure_module(name, ispkg, addparent, clearnone): | ||||
|     try: | ||||
|         mod = orig = sys.modules[name] | ||||
|     except KeyError: | ||||
|         mod = orig = None | ||||
|         missing = True | ||||
|     else: | ||||
|         missing = False | ||||
|         if mod is not None: | ||||
|             # It was already imported. | ||||
|             return mod, orig, missing | ||||
|         # Otherwise, None means it was explicitly disabled. | ||||
| 
 | ||||
|     assert name != '__main__' | ||||
|     if not missing: | ||||
|         assert orig is None, (name, sys.modules[name]) | ||||
|         if not clearnone: | ||||
|             raise ModuleNotFoundError(name) | ||||
|         del sys.modules[name] | ||||
|     # Try normal import, then fall back to adding the module. | ||||
|     try: | ||||
|         mod = importlib.import_module(name) | ||||
|     except ModuleNotFoundError: | ||||
|         if addparent and not clearnone: | ||||
|             addparent = None | ||||
|         mod = _add_module(name, ispkg, addparent) | ||||
|     return mod, orig, missing | ||||
| 
 | ||||
| 
 | ||||
| def _add_module(spec, ispkg, addparent): | ||||
|     if isinstance(spec, str): | ||||
|         name = spec | ||||
|         mod = create_module(name, ispkg=ispkg) | ||||
|         spec = mod.__spec__ | ||||
|     else: | ||||
|         name = spec.name | ||||
|         mod = importlib.util.module_from_spec(spec) | ||||
|     sys.modules[name] = mod | ||||
|     if addparent is not False and spec.parent: | ||||
|         _ensure_module(spec.parent, True, addparent, bool(addparent)) | ||||
|     return mod | ||||
| 
 | ||||
| 
 | ||||
| def add_module(spec, *, parents=True): | ||||
|     """Return the module after creating it and adding it to sys.modules. | ||||
| 
 | ||||
|     If parents is True then also create any missing parents. | ||||
|     """ | ||||
|     return _add_module(spec, False, parents) | ||||
| 
 | ||||
| 
 | ||||
| def add_package(spec, *, parents=True): | ||||
|     """Return the module after creating it and adding it to sys.modules. | ||||
| 
 | ||||
|     If parents is True then also create any missing parents. | ||||
|     """ | ||||
|     return _add_module(spec, True, parents) | ||||
| 
 | ||||
| 
 | ||||
| def ensure_module_imported(name, *, clearnone=True): | ||||
|     """Return the corresponding module. | ||||
| 
 | ||||
|     If it was already imported then return that.  Otherwise, try | ||||
|     importing it (optionally clear it first if None).  If that fails | ||||
|     then create a new empty module. | ||||
| 
 | ||||
|     It can be helpful to combine this with ready_to_import() and/or | ||||
|     isolated_modules(). | ||||
|     """ | ||||
|     if sys.modules.get(name) is not None: | ||||
|         mod = sys.modules[name] | ||||
|     else: | ||||
|         mod, _, _ = _force_import(name, False, True, clearnone) | ||||
|     return mod | ||||
|  |  | |||
|  | @ -1,3 +1,6 @@ | |||
| import contextlib | ||||
| import importlib | ||||
| import importlib.util | ||||
| import itertools | ||||
| import sys | ||||
| import types | ||||
|  | @ -9,7 +12,7 @@ | |||
| _interpreters = import_helper.import_module('_interpreters') | ||||
| from _interpreters import NotShareableError | ||||
| 
 | ||||
| 
 | ||||
| from test import _code_definitions as code_defs | ||||
| from test import _crossinterp_definitions as defs | ||||
| 
 | ||||
| 
 | ||||
|  | @ -21,6 +24,88 @@ | |||
|                if (isinstance(o, type) and | ||||
|                   n not in ('DynamicClassAttribute', '_GeneratorWrapper'))] | ||||
| 
 | ||||
| DEFS = defs | ||||
| with open(code_defs.__file__) as infile: | ||||
|     _code_defs_text = infile.read() | ||||
| with open(DEFS.__file__) as infile: | ||||
|     _defs_text = infile.read() | ||||
|     _defs_text = _defs_text.replace('from ', '# from ') | ||||
| DEFS_TEXT = f""" | ||||
| ####################################### | ||||
| # from {code_defs.__file__} | ||||
| 
 | ||||
| {_code_defs_text} | ||||
| 
 | ||||
| ####################################### | ||||
| # from {defs.__file__} | ||||
| 
 | ||||
| {_defs_text} | ||||
| """ | ||||
| del infile, _code_defs_text, _defs_text | ||||
| 
 | ||||
| 
 | ||||
| def load_defs(module=None): | ||||
|     """Return a new copy of the test._crossinterp_definitions module. | ||||
| 
 | ||||
|     The module's __name__ matches the "module" arg, which is either | ||||
|     a str or a module. | ||||
| 
 | ||||
|     If the "module" arg is a module then the just-loaded defs are also | ||||
|     copied into that module. | ||||
| 
 | ||||
|     Note that the new module is not added to sys.modules. | ||||
|     """ | ||||
|     if module is None: | ||||
|         modname = DEFS.__name__ | ||||
|     elif isinstance(module, str): | ||||
|         modname = module | ||||
|         module = None | ||||
|     else: | ||||
|         modname = module.__name__ | ||||
|     # Create the new module and populate it. | ||||
|     defs = import_helper.create_module(modname) | ||||
|     defs.__file__ = DEFS.__file__ | ||||
|     exec(DEFS_TEXT, defs.__dict__) | ||||
|     # Copy the defs into the module arg, if any. | ||||
|     if module is not None: | ||||
|         for name, value in defs.__dict__.items(): | ||||
|             if name.startswith('_'): | ||||
|                 continue | ||||
|             assert not hasattr(module, name), (name, getattr(module, name)) | ||||
|             setattr(module, name, value) | ||||
|     return defs | ||||
| 
 | ||||
| 
 | ||||
| @contextlib.contextmanager | ||||
| def using___main__(): | ||||
|     """Make sure __main__ module exists (and clean up after).""" | ||||
|     modname = '__main__' | ||||
|     if modname not in sys.modules: | ||||
|         with import_helper.isolated_modules(): | ||||
|             yield import_helper.add_module(modname) | ||||
|     else: | ||||
|         with import_helper.module_restored(modname) as mod: | ||||
|             yield mod | ||||
| 
 | ||||
| 
 | ||||
| @contextlib.contextmanager | ||||
| def temp_module(modname): | ||||
|     """Create the module and add to sys.modules, then remove it after.""" | ||||
|     assert modname not in sys.modules, (modname,) | ||||
|     with import_helper.isolated_modules(): | ||||
|         yield import_helper.add_module(modname) | ||||
| 
 | ||||
| 
 | ||||
| @contextlib.contextmanager | ||||
| def missing_defs_module(modname, *, prep=False): | ||||
|     assert modname not in sys.modules, (modname,) | ||||
|     if prep: | ||||
|         with import_helper.ready_to_import(modname, DEFS_TEXT): | ||||
|             yield modname | ||||
|     else: | ||||
|         with import_helper.isolated_modules(): | ||||
|             yield modname | ||||
| 
 | ||||
| 
 | ||||
| class _GetXIDataTests(unittest.TestCase): | ||||
| 
 | ||||
|  | @ -32,52 +117,49 @@ def get_xidata(self, obj, *, mode=None): | |||
| 
 | ||||
|     def get_roundtrip(self, obj, *, mode=None): | ||||
|         mode = self._resolve_mode(mode) | ||||
|         xid =_testinternalcapi.get_crossinterp_data(obj, mode) | ||||
|         return self._get_roundtrip(obj, mode) | ||||
| 
 | ||||
|     def _get_roundtrip(self, obj, mode): | ||||
|         xid = _testinternalcapi.get_crossinterp_data(obj, mode) | ||||
|         return _testinternalcapi.restore_crossinterp_data(xid) | ||||
| 
 | ||||
|     def iter_roundtrip_values(self, values, *, mode=None): | ||||
|     def assert_roundtrip_identical(self, values, *, mode=None): | ||||
|         mode = self._resolve_mode(mode) | ||||
|         for obj in values: | ||||
|             with self.subTest(obj): | ||||
|                 xid = _testinternalcapi.get_crossinterp_data(obj, mode) | ||||
|                 got = _testinternalcapi.restore_crossinterp_data(xid) | ||||
|                 yield obj, got | ||||
| 
 | ||||
|     def assert_roundtrip_identical(self, values, *, mode=None): | ||||
|         for obj, got in self.iter_roundtrip_values(values, mode=mode): | ||||
|             # XXX What about between interpreters? | ||||
|             self.assertIs(got, obj) | ||||
|                 got = self._get_roundtrip(obj, mode) | ||||
|                 self.assertIs(got, obj) | ||||
| 
 | ||||
|     def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None): | ||||
|         for obj, got in self.iter_roundtrip_values(values, mode=mode): | ||||
|             self.assertEqual(got, obj) | ||||
|             self.assertIs(type(got), | ||||
|                           type(obj) if expecttype is None else expecttype) | ||||
|         mode = self._resolve_mode(mode) | ||||
|         for obj in values: | ||||
|             with self.subTest(obj): | ||||
|                 got = self._get_roundtrip(obj, mode) | ||||
|                 self.assertEqual(got, obj) | ||||
|                 self.assertIs(type(got), | ||||
|                               type(obj) if expecttype is None else expecttype) | ||||
| 
 | ||||
| #    def assert_roundtrip_equal_not_identical(self, values, *, | ||||
| #                                            mode=None, expecttype=None): | ||||
| #        mode = self._resolve_mode(mode) | ||||
| #        for obj in values: | ||||
| #            cls = type(obj) | ||||
| #            with self.subTest(obj): | ||||
| #                got = self._get_roundtrip(obj, mode) | ||||
| #                self.assertIsNot(got, obj) | ||||
| #                self.assertIs(type(got), type(obj)) | ||||
| #                self.assertEqual(got, obj) | ||||
| #                self.assertIs(type(got), | ||||
| #                              cls if expecttype is None else expecttype) | ||||
| # | ||||
| #    def assert_roundtrip_not_equal(self, values, *, mode=None, expecttype=None): | ||||
| #        mode = self._resolve_mode(mode) | ||||
| #        for obj in values: | ||||
| #            cls = type(obj) | ||||
| #            with self.subTest(obj): | ||||
| #                got = self._get_roundtrip(obj, mode) | ||||
| #                self.assertIsNot(got, obj) | ||||
| #                self.assertIs(type(got), type(obj)) | ||||
| #                self.assertNotEqual(got, obj) | ||||
| #                self.assertIs(type(got), | ||||
| #                              cls if expecttype is None else expecttype) | ||||
|     def assert_roundtrip_equal_not_identical(self, values, *, | ||||
|                                              mode=None, expecttype=None): | ||||
|         mode = self._resolve_mode(mode) | ||||
|         for obj in values: | ||||
|             with self.subTest(obj): | ||||
|                 got = self._get_roundtrip(obj, mode) | ||||
|                 self.assertIsNot(got, obj) | ||||
|                 self.assertIs(type(got), | ||||
|                               type(obj) if expecttype is None else expecttype) | ||||
|                 self.assertEqual(got, obj) | ||||
| 
 | ||||
|     def assert_roundtrip_not_equal(self, values, *, | ||||
|                                    mode=None, expecttype=None): | ||||
|         mode = self._resolve_mode(mode) | ||||
|         for obj in values: | ||||
|             with self.subTest(obj): | ||||
|                 got = self._get_roundtrip(obj, mode) | ||||
|                 self.assertIsNot(got, obj) | ||||
|                 self.assertIs(type(got), | ||||
|                               type(obj) if expecttype is None else expecttype) | ||||
|                 self.assertNotEqual(got, obj) | ||||
| 
 | ||||
|     def assert_not_shareable(self, values, exctype=None, *, mode=None): | ||||
|         mode = self._resolve_mode(mode) | ||||
|  | @ -95,6 +177,363 @@ def _resolve_mode(self, mode): | |||
|         return mode | ||||
| 
 | ||||
| 
 | ||||
| class PickleTests(_GetXIDataTests): | ||||
| 
 | ||||
|     MODE = 'pickle' | ||||
| 
 | ||||
|     def test_shareable(self): | ||||
|         self.assert_roundtrip_equal([ | ||||
|             # singletons | ||||
|             None, | ||||
|             True, | ||||
|             False, | ||||
|             # bytes | ||||
|             *(i.to_bytes(2, 'little', signed=True) | ||||
|               for i in range(-1, 258)), | ||||
|             # str | ||||
|             'hello world', | ||||
|             '你好世界', | ||||
|             '', | ||||
|             # int | ||||
|             sys.maxsize, | ||||
|             -sys.maxsize - 1, | ||||
|             *range(-1, 258), | ||||
|             # float | ||||
|             0.0, | ||||
|             1.1, | ||||
|             -1.0, | ||||
|             0.12345678, | ||||
|             -0.12345678, | ||||
|             # tuple | ||||
|             (), | ||||
|             (1,), | ||||
|             ("hello", "world", ), | ||||
|             (1, True, "hello"), | ||||
|             ((1,),), | ||||
|             ((1, 2), (3, 4)), | ||||
|             ((1, 2), (3, 4), (5, 6)), | ||||
|         ]) | ||||
|         # not shareable using xidata | ||||
|         self.assert_roundtrip_equal([ | ||||
|             # int | ||||
|             sys.maxsize + 1, | ||||
|             -sys.maxsize - 2, | ||||
|             2**1000, | ||||
|             # tuple | ||||
|             (0, 1.0, []), | ||||
|             (0, 1.0, {}), | ||||
|             (0, 1.0, ([],)), | ||||
|             (0, 1.0, ({},)), | ||||
|         ]) | ||||
| 
 | ||||
|     def test_list(self): | ||||
|         self.assert_roundtrip_equal_not_identical([ | ||||
|             [], | ||||
|             [1, 2, 3], | ||||
|             [[1], (2,), {3: 4}], | ||||
|         ]) | ||||
| 
 | ||||
|     def test_dict(self): | ||||
|         self.assert_roundtrip_equal_not_identical([ | ||||
|             {}, | ||||
|             {1: 7, 2: 8, 3: 9}, | ||||
|             {1: [1], 2: (2,), 3: {3: 4}}, | ||||
|         ]) | ||||
| 
 | ||||
|     def test_set(self): | ||||
|         self.assert_roundtrip_equal_not_identical([ | ||||
|             set(), | ||||
|             {1, 2, 3}, | ||||
|             {frozenset({1}), (2,)}, | ||||
|         ]) | ||||
| 
 | ||||
|     # classes | ||||
| 
 | ||||
|     def assert_class_defs_same(self, defs): | ||||
|         # Unpickle relative to the unchanged original module. | ||||
|         self.assert_roundtrip_identical(defs.TOP_CLASSES) | ||||
| 
 | ||||
|         instances = [] | ||||
|         for cls, args in defs.TOP_CLASSES.items(): | ||||
|             if cls in defs.CLASSES_WITHOUT_EQUALITY: | ||||
|                 continue | ||||
|             instances.append(cls(*args)) | ||||
|         self.assert_roundtrip_equal_not_identical(instances) | ||||
| 
 | ||||
|         # these don't compare equal | ||||
|         instances = [] | ||||
|         for cls, args in defs.TOP_CLASSES.items(): | ||||
|             if cls not in defs.CLASSES_WITHOUT_EQUALITY: | ||||
|                 continue | ||||
|             instances.append(cls(*args)) | ||||
|         self.assert_roundtrip_not_equal(instances) | ||||
| 
 | ||||
|     def assert_class_defs_other_pickle(self, defs, mod): | ||||
|         # Pickle relative to a different module than the original. | ||||
|         for cls in defs.TOP_CLASSES: | ||||
|             assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__)) | ||||
|         self.assert_not_shareable(defs.TOP_CLASSES) | ||||
| 
 | ||||
|         instances = [] | ||||
|         for cls, args in defs.TOP_CLASSES.items(): | ||||
|             instances.append(cls(*args)) | ||||
|         self.assert_not_shareable(instances) | ||||
| 
 | ||||
|     def assert_class_defs_other_unpickle(self, defs, mod, *, fail=False): | ||||
|         # Unpickle relative to a different module than the original. | ||||
|         for cls in defs.TOP_CLASSES: | ||||
|             assert not hasattr(mod, cls.__name__), (cls, getattr(mod, cls.__name__)) | ||||
| 
 | ||||
|         instances = [] | ||||
|         for cls, args in defs.TOP_CLASSES.items(): | ||||
|             with self.subTest(cls): | ||||
|                 setattr(mod, cls.__name__, cls) | ||||
|                 xid = self.get_xidata(cls) | ||||
|                 inst = cls(*args) | ||||
|                 instxid = self.get_xidata(inst) | ||||
|                 instances.append( | ||||
|                         (cls, xid, inst, instxid)) | ||||
| 
 | ||||
|         for cls, xid, inst, instxid in instances: | ||||
|             with self.subTest(cls): | ||||
|                 delattr(mod, cls.__name__) | ||||
|                 if fail: | ||||
|                     with self.assertRaises(NotShareableError): | ||||
|                         _testinternalcapi.restore_crossinterp_data(xid) | ||||
|                     continue | ||||
|                 got = _testinternalcapi.restore_crossinterp_data(xid) | ||||
|                 self.assertIsNot(got, cls) | ||||
|                 self.assertNotEqual(got, cls) | ||||
| 
 | ||||
|                 gotcls = got | ||||
|                 got = _testinternalcapi.restore_crossinterp_data(instxid) | ||||
|                 self.assertIsNot(got, inst) | ||||
|                 self.assertIs(type(got), gotcls) | ||||
|                 if cls in defs.CLASSES_WITHOUT_EQUALITY: | ||||
|                     self.assertNotEqual(got, inst) | ||||
|                 elif cls in defs.BUILTIN_SUBCLASSES: | ||||
|                     self.assertEqual(got, inst) | ||||
|                 else: | ||||
|                     self.assertNotEqual(got, inst) | ||||
| 
 | ||||
|     def assert_class_defs_not_shareable(self, defs): | ||||
|         self.assert_not_shareable(defs.TOP_CLASSES) | ||||
| 
 | ||||
|         instances = [] | ||||
|         for cls, args in defs.TOP_CLASSES.items(): | ||||
|             instances.append(cls(*args)) | ||||
|         self.assert_not_shareable(instances) | ||||
| 
 | ||||
|     def test_user_class_normal(self): | ||||
|         self.assert_class_defs_same(defs) | ||||
| 
 | ||||
|     def test_user_class_in___main__(self): | ||||
|         with using___main__() as mod: | ||||
|             defs = load_defs(mod) | ||||
|             self.assert_class_defs_same(defs) | ||||
| 
 | ||||
|     def test_user_class_not_in___main___with_filename(self): | ||||
|         with using___main__() as mod: | ||||
|             defs = load_defs('__main__') | ||||
|             assert defs.__file__ | ||||
|             mod.__file__ = defs.__file__ | ||||
|             self.assert_class_defs_not_shareable(defs) | ||||
| 
 | ||||
|     def test_user_class_not_in___main___without_filename(self): | ||||
|         with using___main__() as mod: | ||||
|             defs = load_defs('__main__') | ||||
|             defs.__file__ = None | ||||
|             mod.__file__ = None | ||||
|             self.assert_class_defs_not_shareable(defs) | ||||
| 
 | ||||
|     def test_user_class_not_in___main___unpickle_with_filename(self): | ||||
|         with using___main__() as mod: | ||||
|             defs = load_defs('__main__') | ||||
|             assert defs.__file__ | ||||
|             mod.__file__ = defs.__file__ | ||||
|             self.assert_class_defs_other_unpickle(defs, mod) | ||||
| 
 | ||||
|     def test_user_class_not_in___main___unpickle_without_filename(self): | ||||
|         with using___main__() as mod: | ||||
|             defs = load_defs('__main__') | ||||
|             defs.__file__ = None | ||||
|             mod.__file__ = None | ||||
|             self.assert_class_defs_other_unpickle(defs, mod, fail=True) | ||||
| 
 | ||||
|     def test_user_class_in_module(self): | ||||
|         with temp_module('__spam__') as mod: | ||||
|             defs = load_defs(mod) | ||||
|             self.assert_class_defs_same(defs) | ||||
| 
 | ||||
|     def test_user_class_not_in_module_with_filename(self): | ||||
|         with temp_module('__spam__') as mod: | ||||
|             defs = load_defs(mod.__name__) | ||||
|             assert defs.__file__ | ||||
|             # For now, we only address this case for __main__. | ||||
|             self.assert_class_defs_not_shareable(defs) | ||||
| 
 | ||||
|     def test_user_class_not_in_module_without_filename(self): | ||||
|         with temp_module('__spam__') as mod: | ||||
|             defs = load_defs(mod.__name__) | ||||
|             defs.__file__ = None | ||||
|             self.assert_class_defs_not_shareable(defs) | ||||
| 
 | ||||
|     def test_user_class_module_missing_then_imported(self): | ||||
|         with missing_defs_module('__spam__', prep=True) as modname: | ||||
|             defs = load_defs(modname) | ||||
|             # For now, we only address this case for __main__. | ||||
|             self.assert_class_defs_not_shareable(defs) | ||||
| 
 | ||||
|     def test_user_class_module_missing_not_available(self): | ||||
|         with missing_defs_module('__spam__') as modname: | ||||
|             defs = load_defs(modname) | ||||
|             self.assert_class_defs_not_shareable(defs) | ||||
| 
 | ||||
|     def test_nested_class(self): | ||||
|         eggs = defs.EggsNested() | ||||
|         with self.assertRaises(NotShareableError): | ||||
|             self.get_roundtrip(eggs) | ||||
| 
 | ||||
|     # functions | ||||
| 
 | ||||
|     def assert_func_defs_same(self, defs): | ||||
|         # Unpickle relative to the unchanged original module. | ||||
|         self.assert_roundtrip_identical(defs.TOP_FUNCTIONS) | ||||
| 
 | ||||
|     def assert_func_defs_other_pickle(self, defs, mod): | ||||
|         # Pickle relative to a different module than the original. | ||||
|         for func in defs.TOP_FUNCTIONS: | ||||
|             assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__)) | ||||
|         self.assert_not_shareable(defs.TOP_FUNCTIONS) | ||||
| 
 | ||||
|     def assert_func_defs_other_unpickle(self, defs, mod, *, fail=False): | ||||
|         # Unpickle relative to a different module than the original. | ||||
|         for func in defs.TOP_FUNCTIONS: | ||||
|             assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__)) | ||||
| 
 | ||||
|         captured = [] | ||||
|         for func in defs.TOP_FUNCTIONS: | ||||
|             with self.subTest(func): | ||||
|                 setattr(mod, func.__name__, func) | ||||
|                 xid = self.get_xidata(func) | ||||
|                 captured.append( | ||||
|                         (func, xid)) | ||||
| 
 | ||||
|         for func, xid in captured: | ||||
|             with self.subTest(func): | ||||
|                 delattr(mod, func.__name__) | ||||
|                 if fail: | ||||
|                     with self.assertRaises(NotShareableError): | ||||
|                         _testinternalcapi.restore_crossinterp_data(xid) | ||||
|                     continue | ||||
|                 got = _testinternalcapi.restore_crossinterp_data(xid) | ||||
|                 self.assertIsNot(got, func) | ||||
|                 self.assertNotEqual(got, func) | ||||
| 
 | ||||
|     def assert_func_defs_not_shareable(self, defs): | ||||
|         self.assert_not_shareable(defs.TOP_FUNCTIONS) | ||||
| 
 | ||||
|     def test_user_function_normal(self): | ||||
| #        self.assert_roundtrip_equal(defs.TOP_FUNCTIONS) | ||||
|         self.assert_func_defs_same(defs) | ||||
| 
 | ||||
|     def test_user_func_in___main__(self): | ||||
|         with using___main__() as mod: | ||||
|             defs = load_defs(mod) | ||||
|             self.assert_func_defs_same(defs) | ||||
| 
 | ||||
|     def test_user_func_not_in___main___with_filename(self): | ||||
|         with using___main__() as mod: | ||||
|             defs = load_defs('__main__') | ||||
|             assert defs.__file__ | ||||
|             mod.__file__ = defs.__file__ | ||||
|             self.assert_func_defs_not_shareable(defs) | ||||
| 
 | ||||
|     def test_user_func_not_in___main___without_filename(self): | ||||
|         with using___main__() as mod: | ||||
|             defs = load_defs('__main__') | ||||
|             defs.__file__ = None | ||||
|             mod.__file__ = None | ||||
|             self.assert_func_defs_not_shareable(defs) | ||||
| 
 | ||||
|     def test_user_func_not_in___main___unpickle_with_filename(self): | ||||
|         with using___main__() as mod: | ||||
|             defs = load_defs('__main__') | ||||
|             assert defs.__file__ | ||||
|             mod.__file__ = defs.__file__ | ||||
|             self.assert_func_defs_other_unpickle(defs, mod) | ||||
| 
 | ||||
|     def test_user_func_not_in___main___unpickle_without_filename(self): | ||||
|         with using___main__() as mod: | ||||
|             defs = load_defs('__main__') | ||||
|             defs.__file__ = None | ||||
|             mod.__file__ = None | ||||
|             self.assert_func_defs_other_unpickle(defs, mod, fail=True) | ||||
| 
 | ||||
|     def test_user_func_in_module(self): | ||||
|         with temp_module('__spam__') as mod: | ||||
|             defs = load_defs(mod) | ||||
|             self.assert_func_defs_same(defs) | ||||
| 
 | ||||
|     def test_user_func_not_in_module_with_filename(self): | ||||
|         with temp_module('__spam__') as mod: | ||||
|             defs = load_defs(mod.__name__) | ||||
|             assert defs.__file__ | ||||
|             # For now, we only address this case for __main__. | ||||
|             self.assert_func_defs_not_shareable(defs) | ||||
| 
 | ||||
|     def test_user_func_not_in_module_without_filename(self): | ||||
|         with temp_module('__spam__') as mod: | ||||
|             defs = load_defs(mod.__name__) | ||||
|             defs.__file__ = None | ||||
|             self.assert_func_defs_not_shareable(defs) | ||||
| 
 | ||||
|     def test_user_func_module_missing_then_imported(self): | ||||
|         with missing_defs_module('__spam__', prep=True) as modname: | ||||
|             defs = load_defs(modname) | ||||
|             # For now, we only address this case for __main__. | ||||
|             self.assert_func_defs_not_shareable(defs) | ||||
| 
 | ||||
|     def test_user_func_module_missing_not_available(self): | ||||
|         with missing_defs_module('__spam__') as modname: | ||||
|             defs = load_defs(modname) | ||||
|             self.assert_func_defs_not_shareable(defs) | ||||
| 
 | ||||
|     def test_nested_function(self): | ||||
|         self.assert_not_shareable(defs.NESTED_FUNCTIONS) | ||||
| 
 | ||||
|     # exceptions | ||||
| 
 | ||||
|     def test_user_exception_normal(self): | ||||
|         self.assert_roundtrip_not_equal([ | ||||
|             defs.MimimalError('error!'), | ||||
|         ]) | ||||
|         self.assert_roundtrip_equal_not_identical([ | ||||
|             defs.RichError('error!', 42), | ||||
|         ]) | ||||
| 
 | ||||
|     def test_builtin_exception(self): | ||||
|         msg = 'error!' | ||||
|         try: | ||||
|             raise Exception | ||||
|         except Exception as exc: | ||||
|             caught = exc | ||||
|         special = { | ||||
|             BaseExceptionGroup: (msg, [caught]), | ||||
|             ExceptionGroup: (msg, [caught]), | ||||
| #            UnicodeError: (None, msg, None, None, None), | ||||
|             UnicodeEncodeError: ('utf-8', '', 1, 3, msg), | ||||
|             UnicodeDecodeError: ('utf-8', b'', 1, 3, msg), | ||||
|             UnicodeTranslateError: ('', 1, 3, msg), | ||||
|         } | ||||
|         exceptions = [] | ||||
|         for cls in EXCEPTION_TYPES: | ||||
|             args = special.get(cls) or (msg,) | ||||
|             exceptions.append(cls(*args)) | ||||
| 
 | ||||
|         self.assert_roundtrip_not_equal(exceptions) | ||||
| 
 | ||||
| 
 | ||||
| class MarshalTests(_GetXIDataTests): | ||||
| 
 | ||||
|     MODE = 'marshal' | ||||
|  | @ -444,22 +883,12 @@ def test_module(self): | |||
|         ]) | ||||
| 
 | ||||
|     def test_class(self): | ||||
|         self.assert_not_shareable([ | ||||
|             defs.Spam, | ||||
|             defs.SpamOkay, | ||||
|             defs.SpamFull, | ||||
|             defs.SubSpamFull, | ||||
|             defs.SubTuple, | ||||
|             defs.EggsNested, | ||||
|         ]) | ||||
|         self.assert_not_shareable([ | ||||
|             defs.Spam(), | ||||
|             defs.SpamOkay(), | ||||
|             defs.SpamFull(1, 2, 3), | ||||
|             defs.SubSpamFull(1, 2, 3), | ||||
|             defs.SubTuple([1, 2, 3]), | ||||
|             defs.EggsNested(), | ||||
|         ]) | ||||
|         self.assert_not_shareable(defs.CLASSES) | ||||
| 
 | ||||
|         instances = [] | ||||
|         for cls, args in defs.CLASSES.items(): | ||||
|             instances.append(cls(*args)) | ||||
|         self.assert_not_shareable(instances) | ||||
| 
 | ||||
|     def test_builtin_type(self): | ||||
|         self.assert_not_shareable([ | ||||
|  |  | |||
|  | @ -1939,6 +1939,11 @@ get_crossinterp_data(PyObject *self, PyObject *args, PyObject *kwargs) | |||
|             goto error; | ||||
|         } | ||||
|     } | ||||
|     else if (strcmp(mode, "pickle") == 0) { | ||||
|         if (_PyPickle_GetXIData(tstate, obj, xidata) != 0) { | ||||
|             goto error; | ||||
|         } | ||||
|     } | ||||
|     else if (strcmp(mode, "marshal") == 0) { | ||||
|         if (_PyMarshal_GetXIData(tstate, obj, xidata) != 0) { | ||||
|             goto error; | ||||
|  |  | |||
|  | @ -3,6 +3,7 @@ | |||
| 
 | ||||
| #include "Python.h" | ||||
| #include "marshal.h"              // PyMarshal_WriteObjectToString() | ||||
| #include "osdefs.h"               // MAXPATHLEN | ||||
| #include "pycore_ceval.h"         // _Py_simple_func | ||||
| #include "pycore_crossinterp.h"   // _PyXIData_t | ||||
| #include "pycore_initconfig.h"    // _PyStatus_OK() | ||||
|  | @ -10,6 +11,155 @@ | |||
| #include "pycore_typeobject.h"    // _PyStaticType_InitBuiltin() | ||||
| 
 | ||||
| 
 | ||||
| static Py_ssize_t | ||||
| _Py_GetMainfile(char *buffer, size_t maxlen) | ||||
| { | ||||
|     // We don't expect subinterpreters to have the __main__ module's
 | ||||
|     // __name__ set, but proceed just in case.
 | ||||
|     PyThreadState *tstate = _PyThreadState_GET(); | ||||
|     PyObject *module = _Py_GetMainModule(tstate); | ||||
|     if (_Py_CheckMainModule(module) < 0) { | ||||
|         return -1; | ||||
|     } | ||||
|     Py_ssize_t size = _PyModule_GetFilenameUTF8(module, buffer, maxlen); | ||||
|     Py_DECREF(module); | ||||
|     return size; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| static PyObject * | ||||
| import_get_module(PyThreadState *tstate, const char *modname) | ||||
| { | ||||
|     PyObject *module = NULL; | ||||
|     if (strcmp(modname, "__main__") == 0) { | ||||
|         module = _Py_GetMainModule(tstate); | ||||
|         if (_Py_CheckMainModule(module) < 0) { | ||||
|             assert(_PyErr_Occurred(tstate)); | ||||
|             return NULL; | ||||
|         } | ||||
|     } | ||||
|     else { | ||||
|         module = PyImport_ImportModule(modname); | ||||
|         if (module == NULL) { | ||||
|             return NULL; | ||||
|         } | ||||
|     } | ||||
|     return module; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| static PyObject * | ||||
| runpy_run_path(const char *filename, const char *modname) | ||||
| { | ||||
|     PyObject *run_path = PyImport_ImportModuleAttrString("runpy", "run_path"); | ||||
|     if (run_path == NULL) { | ||||
|         return NULL; | ||||
|     } | ||||
|     PyObject *args = Py_BuildValue("(sOs)", filename, Py_None, modname); | ||||
|     if (args == NULL) { | ||||
|         Py_DECREF(run_path); | ||||
|         return NULL; | ||||
|     } | ||||
|     PyObject *ns = PyObject_Call(run_path, args, NULL); | ||||
|     Py_DECREF(run_path); | ||||
|     Py_DECREF(args); | ||||
|     return ns; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| static PyObject * | ||||
| pyerr_get_message(PyObject *exc) | ||||
| { | ||||
|     assert(!PyErr_Occurred()); | ||||
|     PyObject *args = PyException_GetArgs(exc); | ||||
|     if (args == NULL || args == Py_None || PyObject_Size(args) < 1) { | ||||
|         return NULL; | ||||
|     } | ||||
|     if (PyUnicode_Check(args)) { | ||||
|         return args; | ||||
|     } | ||||
|     PyObject *msg = PySequence_GetItem(args, 0); | ||||
|     Py_DECREF(args); | ||||
|     if (msg == NULL) { | ||||
|         PyErr_Clear(); | ||||
|         return NULL; | ||||
|     } | ||||
|     if (!PyUnicode_Check(msg)) { | ||||
|         Py_DECREF(msg); | ||||
|         return NULL; | ||||
|     } | ||||
|     return msg; | ||||
| } | ||||
| 
 | ||||
| #define MAX_MODNAME (255) | ||||
| #define MAX_ATTRNAME (255) | ||||
| 
 | ||||
| struct attributeerror_info { | ||||
|     char modname[MAX_MODNAME+1]; | ||||
|     char attrname[MAX_ATTRNAME+1]; | ||||
| }; | ||||
| 
 | ||||
| static int | ||||
| _parse_attributeerror(PyObject *exc, struct attributeerror_info *info) | ||||
| { | ||||
|     assert(exc != NULL); | ||||
|     assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError)); | ||||
|     int res = -1; | ||||
| 
 | ||||
|     PyObject *msgobj = pyerr_get_message(exc); | ||||
|     if (msgobj == NULL) { | ||||
|         return -1; | ||||
|     } | ||||
|     const char *err = PyUnicode_AsUTF8(msgobj); | ||||
| 
 | ||||
|     if (strncmp(err, "module '", 8) != 0) { | ||||
|         goto finally; | ||||
|     } | ||||
|     err += 8; | ||||
| 
 | ||||
|     const char *matched = strchr(err, '\''); | ||||
|     if (matched == NULL) { | ||||
|         goto finally; | ||||
|     } | ||||
|     Py_ssize_t len = matched - err; | ||||
|     if (len > MAX_MODNAME) { | ||||
|         goto finally; | ||||
|     } | ||||
|     (void)strncpy(info->modname, err, len); | ||||
|     info->modname[len] = '\0'; | ||||
|     err = matched; | ||||
| 
 | ||||
|     if (strncmp(err, "' has no attribute '", 20) != 0) { | ||||
|         goto finally; | ||||
|     } | ||||
|     err += 20; | ||||
| 
 | ||||
|     matched = strchr(err, '\''); | ||||
|     if (matched == NULL) { | ||||
|         goto finally; | ||||
|     } | ||||
|     len = matched - err; | ||||
|     if (len > MAX_ATTRNAME) { | ||||
|         goto finally; | ||||
|     } | ||||
|     (void)strncpy(info->attrname, err, len); | ||||
|     info->attrname[len] = '\0'; | ||||
|     err = matched + 1; | ||||
| 
 | ||||
|     if (strlen(err) > 0) { | ||||
|         goto finally; | ||||
|     } | ||||
|     res = 0; | ||||
| 
 | ||||
| finally: | ||||
|     Py_DECREF(msgobj); | ||||
|     return res; | ||||
| } | ||||
| 
 | ||||
| #undef MAX_MODNAME | ||||
| #undef MAX_ATTRNAME | ||||
| 
 | ||||
| 
 | ||||
| /**************/ | ||||
| /* exceptions */ | ||||
| /**************/ | ||||
|  | @ -287,6 +437,308 @@ _PyObject_GetXIData(PyThreadState *tstate, | |||
| } | ||||
| 
 | ||||
| 
 | ||||
| /* pickle C-API */ | ||||
| 
 | ||||
| struct _pickle_context { | ||||
|     PyThreadState *tstate; | ||||
| }; | ||||
| 
 | ||||
| static PyObject * | ||||
| _PyPickle_Dumps(struct _pickle_context *ctx, PyObject *obj) | ||||
| { | ||||
|     PyObject *dumps = PyImport_ImportModuleAttrString("pickle", "dumps"); | ||||
|     if (dumps == NULL) { | ||||
|         return NULL; | ||||
|     } | ||||
|     PyObject *bytes = PyObject_CallOneArg(dumps, obj); | ||||
|     Py_DECREF(dumps); | ||||
|     return bytes; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| struct sync_module_result { | ||||
|     PyObject *module; | ||||
|     PyObject *loaded; | ||||
|     PyObject *failed; | ||||
| }; | ||||
| 
 | ||||
| struct sync_module { | ||||
|     const char *filename; | ||||
|     char _filename[MAXPATHLEN+1]; | ||||
|     struct sync_module_result cached; | ||||
| }; | ||||
| 
 | ||||
| static void | ||||
| sync_module_clear(struct sync_module *data) | ||||
| { | ||||
|     data->filename = NULL; | ||||
|     Py_CLEAR(data->cached.module); | ||||
|     Py_CLEAR(data->cached.loaded); | ||||
|     Py_CLEAR(data->cached.failed); | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| struct _unpickle_context { | ||||
|     PyThreadState *tstate; | ||||
|     // We only special-case the __main__ module,
 | ||||
|     // since other modules behave consistently.
 | ||||
|     struct sync_module main; | ||||
| }; | ||||
| 
 | ||||
| static void | ||||
| _unpickle_context_clear(struct _unpickle_context *ctx) | ||||
| { | ||||
|     sync_module_clear(&ctx->main); | ||||
| } | ||||
| 
 | ||||
| static struct sync_module_result | ||||
| _unpickle_context_get_module(struct _unpickle_context *ctx, | ||||
|                              const char *modname) | ||||
| { | ||||
|     if (strcmp(modname, "__main__") == 0) { | ||||
|         return ctx->main.cached; | ||||
|     } | ||||
|     else { | ||||
|         return (struct sync_module_result){ | ||||
|             .failed = PyExc_NotImplementedError, | ||||
|         }; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| static struct sync_module_result | ||||
| _unpickle_context_set_module(struct _unpickle_context *ctx, | ||||
|                              const char *modname) | ||||
| { | ||||
|     struct sync_module_result res = {0}; | ||||
|     struct sync_module_result *cached = NULL; | ||||
|     const char *filename = NULL; | ||||
|     if (strcmp(modname, "__main__") == 0) { | ||||
|         cached = &ctx->main.cached; | ||||
|         filename = ctx->main.filename; | ||||
|     } | ||||
|     else { | ||||
|         res.failed = PyExc_NotImplementedError; | ||||
|         goto finally; | ||||
|     } | ||||
| 
 | ||||
|     res.module = import_get_module(ctx->tstate, modname); | ||||
|     if (res.module == NULL) { | ||||
|         res.failed = _PyErr_GetRaisedException(ctx->tstate); | ||||
|         assert(res.failed != NULL); | ||||
|         goto finally; | ||||
|     } | ||||
| 
 | ||||
|     if (filename == NULL) { | ||||
|         Py_CLEAR(res.module); | ||||
|         res.failed = PyExc_NotImplementedError; | ||||
|         goto finally; | ||||
|     } | ||||
|     res.loaded = runpy_run_path(filename, modname); | ||||
|     if (res.loaded == NULL) { | ||||
|         Py_CLEAR(res.module); | ||||
|         res.failed = _PyErr_GetRaisedException(ctx->tstate); | ||||
|         assert(res.failed != NULL); | ||||
|         goto finally; | ||||
|     } | ||||
| 
 | ||||
| finally: | ||||
|     if (cached != NULL) { | ||||
|         assert(cached->module == NULL); | ||||
|         assert(cached->loaded == NULL); | ||||
|         assert(cached->failed == NULL); | ||||
|         *cached = res; | ||||
|     } | ||||
|     return res; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| static int | ||||
| _handle_unpickle_missing_attr(struct _unpickle_context *ctx, PyObject *exc) | ||||
| { | ||||
|     // The caller must check if an exception is set or not when -1 is returned.
 | ||||
|     assert(!_PyErr_Occurred(ctx->tstate)); | ||||
|     assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError)); | ||||
|     struct attributeerror_info info; | ||||
|     if (_parse_attributeerror(exc, &info) < 0) { | ||||
|         return -1; | ||||
|     } | ||||
| 
 | ||||
|     // Get the module.
 | ||||
|     struct sync_module_result mod = _unpickle_context_get_module(ctx, info.modname); | ||||
|     if (mod.failed != NULL) { | ||||
|         // It must have failed previously.
 | ||||
|         return -1; | ||||
|     } | ||||
|     if (mod.module == NULL) { | ||||
|         mod = _unpickle_context_set_module(ctx, info.modname); | ||||
|         if (mod.failed != NULL) { | ||||
|             return -1; | ||||
|         } | ||||
|         assert(mod.module != NULL); | ||||
|     } | ||||
| 
 | ||||
|     // Bail out if it is unexpectedly set already.
 | ||||
|     if (PyObject_HasAttrString(mod.module, info.attrname)) { | ||||
|         return -1; | ||||
|     } | ||||
| 
 | ||||
|     // Try setting the attribute.
 | ||||
|     PyObject *value = NULL; | ||||
|     if (PyDict_GetItemStringRef(mod.loaded, info.attrname, &value) <= 0) { | ||||
|         return -1; | ||||
|     } | ||||
|     assert(value != NULL); | ||||
|     int res = PyObject_SetAttrString(mod.module, info.attrname, value); | ||||
|     Py_DECREF(value); | ||||
|     if (res < 0) { | ||||
|         return -1; | ||||
|     } | ||||
| 
 | ||||
|     return 0; | ||||
| } | ||||
| 
 | ||||
| static PyObject * | ||||
| _PyPickle_Loads(struct _unpickle_context *ctx, PyObject *pickled) | ||||
| { | ||||
|     PyObject *loads = PyImport_ImportModuleAttrString("pickle", "loads"); | ||||
|     if (loads == NULL) { | ||||
|         return NULL; | ||||
|     } | ||||
|     PyObject *obj = PyObject_CallOneArg(loads, pickled); | ||||
|     if (ctx != NULL) { | ||||
|         while (obj == NULL) { | ||||
|             assert(_PyErr_Occurred(ctx->tstate)); | ||||
|             if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { | ||||
|                 // We leave other failures unhandled.
 | ||||
|                 break; | ||||
|             } | ||||
|             // Try setting the attr if not set.
 | ||||
|             PyObject *exc = _PyErr_GetRaisedException(ctx->tstate); | ||||
|             if (_handle_unpickle_missing_attr(ctx, exc) < 0) { | ||||
|                 // Any resulting exceptions are ignored
 | ||||
|                 // in favor of the original.
 | ||||
|                 _PyErr_SetRaisedException(ctx->tstate, exc); | ||||
|                 break; | ||||
|             } | ||||
|             Py_CLEAR(exc); | ||||
|             // Retry with the attribute set.
 | ||||
|             obj = PyObject_CallOneArg(loads, pickled); | ||||
|         } | ||||
|     } | ||||
|     Py_DECREF(loads); | ||||
|     return obj; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| /* pickle wrapper */ | ||||
| 
 | ||||
| struct _pickle_xid_context { | ||||
|     // __main__.__file__
 | ||||
|     struct { | ||||
|         const char *utf8; | ||||
|         size_t len; | ||||
|         char _utf8[MAXPATHLEN+1]; | ||||
|     } mainfile; | ||||
| }; | ||||
| 
 | ||||
| static int | ||||
| _set_pickle_xid_context(PyThreadState *tstate, struct _pickle_xid_context *ctx) | ||||
| { | ||||
|     // Set mainfile if possible.
 | ||||
|     Py_ssize_t len = _Py_GetMainfile(ctx->mainfile._utf8, MAXPATHLEN); | ||||
|     if (len < 0) { | ||||
|         // For now we ignore any exceptions.
 | ||||
|         PyErr_Clear(); | ||||
|     } | ||||
|     else if (len > 0) { | ||||
|         ctx->mainfile.utf8 = ctx->mainfile._utf8; | ||||
|         ctx->mainfile.len = (size_t)len; | ||||
|     } | ||||
| 
 | ||||
|     return 0; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| struct _shared_pickle_data { | ||||
|     _PyBytes_data_t pickled;  // Must be first if we use _PyBytes_FromXIData().
 | ||||
|     struct _pickle_xid_context ctx; | ||||
| }; | ||||
| 
 | ||||
| PyObject * | ||||
| _PyPickle_LoadFromXIData(_PyXIData_t *xidata) | ||||
| { | ||||
|     PyThreadState *tstate = _PyThreadState_GET(); | ||||
|     struct _shared_pickle_data *shared = | ||||
|                             (struct _shared_pickle_data *)xidata->data; | ||||
|     // We avoid copying the pickled data by wrapping it in a memoryview.
 | ||||
|     // The alternative is to get a bytes object using _PyBytes_FromXIData().
 | ||||
|     PyObject *pickled = PyMemoryView_FromMemory( | ||||
|             (char *)shared->pickled.bytes, shared->pickled.len, PyBUF_READ); | ||||
|     if (pickled == NULL) { | ||||
|         return NULL; | ||||
|     } | ||||
| 
 | ||||
|     // Unpickle the object.
 | ||||
|     struct _unpickle_context ctx = { | ||||
|         .tstate = tstate, | ||||
|         .main = { | ||||
|             .filename = shared->ctx.mainfile.utf8, | ||||
|         }, | ||||
|     }; | ||||
|     PyObject *obj = _PyPickle_Loads(&ctx, pickled); | ||||
|     Py_DECREF(pickled); | ||||
|     _unpickle_context_clear(&ctx); | ||||
|     if (obj == NULL) { | ||||
|         PyObject *cause = _PyErr_GetRaisedException(tstate); | ||||
|         assert(cause != NULL); | ||||
|         _set_xid_lookup_failure( | ||||
|                     tstate, NULL, "object could not be unpickled", cause); | ||||
|         Py_DECREF(cause); | ||||
|     } | ||||
|     return obj; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| int | ||||
| _PyPickle_GetXIData(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata) | ||||
| { | ||||
|     // Pickle the object.
 | ||||
|     struct _pickle_context ctx = { | ||||
|         .tstate = tstate, | ||||
|     }; | ||||
|     PyObject *bytes = _PyPickle_Dumps(&ctx, obj); | ||||
|     if (bytes == NULL) { | ||||
|         PyObject *cause = _PyErr_GetRaisedException(tstate); | ||||
|         assert(cause != NULL); | ||||
|         _set_xid_lookup_failure( | ||||
|                     tstate, NULL, "object could not be pickled", cause); | ||||
|         Py_DECREF(cause); | ||||
|         return -1; | ||||
|     } | ||||
| 
 | ||||
|     // If we had an "unwrapper" mechnanism, we could call
 | ||||
|     // _PyObject_GetXIData() on the bytes object directly and add
 | ||||
|     // a simple unwrapper to call pickle.loads() on the bytes.
 | ||||
|     size_t size = sizeof(struct _shared_pickle_data); | ||||
|     struct _shared_pickle_data *shared = | ||||
|             (struct _shared_pickle_data *)_PyBytes_GetXIDataWrapped( | ||||
|                     tstate, bytes, size, _PyPickle_LoadFromXIData, xidata); | ||||
|     Py_DECREF(bytes); | ||||
|     if (shared == NULL) { | ||||
|         return -1; | ||||
|     } | ||||
| 
 | ||||
|     // If it mattered, we could skip getting __main__.__file__
 | ||||
|     // when "__main__" doesn't show up in the pickle bytes.
 | ||||
|     if (_set_pickle_xid_context(tstate, &shared->ctx) < 0) { | ||||
|         _xidata_clear(xidata); | ||||
|         return -1; | ||||
|     } | ||||
| 
 | ||||
|     return 0; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| /* marshal wrapper */ | ||||
| 
 | ||||
| PyObject * | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Eric Snow
						Eric Snow