mirror of
				https://github.com/python/cpython.git
				synced 2025-11-03 23:21:29 +00:00 
			
		
		
		
	
		
			
	
	
		
			174 lines
		
	
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			174 lines
		
	
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								import contextlib
							 | 
						||
| 
								 | 
							
								import imp
							 | 
						||
| 
								 | 
							
								import importlib
							 | 
						||
| 
								 | 
							
								import sys
							 | 
						||
| 
								 | 
							
								import unittest
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@contextlib.contextmanager
								def uncache(*names):
								    """Temporarily drop the named modules from sys.modules.

								    Every name in *names* is removed from sys.modules when the with-block
								    is entered and removed again when it exits, so any copy imported inside
								    the block is discarded afterwards.  Names not present are skipped.  The
								    modules removed on entry are NOT put back on exit.

								    A basic sanity check refuses to uncache modules that either
								    cannot/shouldn't be uncached ('sys', 'marshal', 'imp'): ValueError is
								    raised on entry.  Because the check is interleaved with removal, names
								    listed before the rejected one have already been removed at that point.
								    """
								    for mod_name in names:
								        if mod_name in ('sys', 'marshal', 'imp'):
								            raise ValueError(
								                "cannot uncache {0} as it will break _importlib".format(mod_name))
								        sys.modules.pop(mod_name, None)
								    try:
								        yield
								    finally:
								        for mod_name in names:
								            sys.modules.pop(mod_name, None)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@contextlib.contextmanager
								def import_state(**kwargs):
								    """Temporarily replace the import-related attributes of the sys module.

								    sys.meta_path, sys.path, sys.path_hooks and sys.path_importer_cache are
								    each set to the value given under the keyword of the same name, or to a
								    fresh empty list/dict when that keyword is omitted.  All four original
								    objects are restored on exit.

								    The 'modules' attribute is not supported as the interpreter state stores
								    a pointer to the dict that the interpreter uses internally; reassigning
								    to sys.modules does not have the desired effect.

								    Any other keyword raises ValueError on entry.  The sys attributes have
								    already been swapped by then, but they are still restored.
								    """
								    managed = (
								        ('meta_path', []),
								        ('path', []),
								        ('path_hooks', []),
								        ('path_importer_cache', {}),
								    )
								    saved = {}
								    try:
								        for attr_name, fallback in managed:
								            saved[attr_name] = getattr(sys, attr_name)
								            # pop() consumes the keyword so leftovers can be reported below.
								            setattr(sys, attr_name, kwargs.pop(attr_name, fallback))
								        if kwargs:
								            raise ValueError(
								                    'unrecognized arguments: {0}'.format(kwargs.keys()))
								        yield
								    finally:
								        for attr_name, original in saved.items():
								            setattr(sys, attr_name, original)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class mock_modules(object):

								    """A mock importer/loader.

								    One object acts as both the finder (find_module) and the loader
								    (load_module) for a fixed set of in-memory modules created up front.
								    Used as a context manager, it removes those module names from
								    sys.modules on entry and again on exit.
								    """

								    def __init__(self, *names):
								        """Build one fake module per entry in *names*.

								        A name ending in '.__init__' stands for a package: the suffix is
								        dropped to form the import name and the module gets a __path__.
								        __package__ is None for an undotted name, the parent name for a
								        dotted non-package name, and the package's own name for a
								        '.__init__' entry.  Every module records the name exactly as given
								        in its ``attr`` attribute.
								        """
								        # imp.new_module() is deprecated (and removed in Python 3.12);
								        # types.ModuleType creates the same kind of module object.
								        import types
								        init_suffix = '.__init__'
								        self.modules = {}
								        for name in names:
								            is_package = name.endswith(init_suffix)
								            if is_package:
								                import_name = name[:-len(init_suffix)]
								            else:
								                import_name = name
								            if '.' not in name:
								                package = None
								            elif is_package:
								                package = import_name
								            else:
								                package = name.rsplit('.', 1)[0]
								            module = types.ModuleType(import_name)
								            module.__loader__ = self
								            module.__file__ = '<mock __file__>'
								            module.__package__ = package
								            module.attr = name
								            if is_package:
								                module.__path__ = ['<mock __path__>']
								            self.modules[import_name] = module

								    def __getitem__(self, name):
								        """Return the fake module registered under import name *name*."""
								        return self.modules[name]

								    def find_module(self, fullname, path=None):
								        """Return self if *fullname* is a mock module, else None.

								        *path* is accepted for the finder protocol but ignored.
								        """
								        if fullname in self.modules:
								            return self
								        return None

								    def load_module(self, fullname):
								        """Store the mock module *fullname* in sys.modules and return it.

								        Raises ImportError if *fullname* is not one of the mock modules.
								        """
								        if fullname not in self.modules:
								            raise ImportError(
								                'no mock module named {0!r}'.format(fullname))
								        module = self.modules[fullname]
								        sys.modules[fullname] = module
								        return module

								    def __enter__(self):
								        # Remove the mock names from sys.modules so a later import of them
								        # is not served from the module cache.
								        self._uncache = uncache(*self.modules.keys())
								        self._uncache.__enter__()
								        return self

								    def __exit__(self, *exc_info):
								        # Drop the mock names from sys.modules again.  Returning None lets
								        # any exception raised in the with-block propagate.
								        self._uncache.__exit__(None, None, None)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class ImportModuleTests(unittest.TestCase):

								    """Tests for importlib.import_module."""

								    def test_module_import(self):
								        # A top-level module is imported by its plain name.
								        with mock_modules('top_level') as finder, \
								                import_state(meta_path=[finder]):
								            imported = importlib.import_module('top_level')
								            self.assertEqual(imported.__name__, 'top_level')

								    def test_absolute_package_import(self):
								        # A submodule of a package is imported by its absolute dotted name.
								        package = 'pkg'
								        package_init = package + '.__init__'
								        submodule = package + '.mod'
								        with mock_modules(package_init, submodule) as finder, \
								                import_state(meta_path=[finder]):
								            imported = importlib.import_module(submodule)
								            self.assertEqual(imported.__name__, submodule)

								    def test_relative_package_import(self):
								        # A submodule is imported by a relative name that is resolved
								        # against the 'package' argument.
								        package = 'pkg'
								        package_init = package + '.__init__'
								        leaf = 'mod'
								        full_name = package + '.' + leaf
								        relative = '.' + leaf
								        with mock_modules(package_init, full_name) as finder, \
								                import_state(meta_path=[finder]):
								            imported = importlib.import_module(relative, package)
								            self.assertEqual(imported.__name__, full_name)

								    def test_absolute_import_with_package(self):
								        # An absolute name still works when the 'package' argument is
								        # also supplied.
								        package = 'pkg'
								        package_init = package + '.__init__'
								        submodule = package + '.mod'
								        with mock_modules(package_init, submodule) as finder, \
								                import_state(meta_path=[finder]):
								            imported = importlib.import_module(submodule, package)
								            self.assertEqual(imported.__name__, submodule)

								    def test_relative_import_wo_package(self):
								        # A relative name cannot be resolved unless the 'package'
								        # argument is given.
								        with self.assertRaises(TypeError):
								            importlib.import_module('.support')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def test_main():
								    """Run the ImportModuleTests suite through test_support.run_unittest."""
								    from test import test_support
								    test_support.run_unittest(ImportModuleTests)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Run the test suite when this file is executed directly as a script.
								if __name__ == '__main__':
								    test_main()
							 |