mirror of
				https://github.com/python/cpython.git
				synced 2025-10-30 21:21:22 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			377 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			377 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import annotations
 | |
| 
 | |
| import pkgutil
 | |
| import sys
 | |
| import tokenize
 | |
| from io import StringIO
 | |
| from contextlib import contextmanager
 | |
| from dataclasses import dataclass
 | |
| from itertools import chain
 | |
| from tokenize import TokenInfo
 | |
| 
 | |
| TYPE_CHECKING = False
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from typing import Any, Iterable, Iterator, Mapping
 | |
| 
 | |
| 
 | |
def make_default_module_completer() -> ModuleCompleter:
    """Return a ModuleCompleter whose namespace has '__package__' set to '_pyrepl'."""
    # Inside pyrepl, __package__ is set to '_pyrepl'
    pyrepl_namespace = {'__package__': '_pyrepl'}
    return ModuleCompleter(namespace=pyrepl_namespace)
 | |
| 
 | |
| 
 | |
class ModuleCompleter:
    """A completer for Python import statements.

    Examples:
        - import <tab>
        - import foo<tab>
        - import foo.<tab>
        - import foo as bar, baz<tab>

        - from <tab>
        - from foo<tab>
        - from foo import <tab>
        - from foo import bar<tab>
        - from foo import (bar as baz, qux<tab>
    """

    def __init__(self, namespace: Mapping[str, Any] | None = None) -> None:
        """Create a completer.

        'namespace' is consulted for its '__package__' entry when a relative
        import has to be resolved; an empty dict is used when it is omitted.
        """
        self.namespace = namespace or {}
        # Filled on first use by the 'global_cache' property.
        self._global_cache: list[pkgutil.ModuleInfo] = []
        # Copy of sys.path used to detect when the cache must be rebuilt.
        self._curr_sys_path: list[str] = sys.path[:]

    def get_completions(self, line: str) -> list[str]:
        """Return the next possible import completions for 'line'."""
        result = ImportParser(line).parse()
        if not result:
            return []
        try:
            return self.complete(*result)
        except Exception:
            # Some unexpected error occurred, make it look like
            # no completions are available
            return []

    def complete(self, from_name: str | None, name: str | None) -> list[str]:
        """Return completions for a parsed (from_name, name) pair.

        With both parts present (`from x.y import z<tab>`) the bare module
        names found under 'from_name' are returned.  Otherwise the single
        dotted name that is present is completed and each result is
        returned as a full dotted name.
        """
        if from_name is not None and name is not None:
            # from x.y import z<tab>
            return self.find_modules(from_name, name)

        # import x.y.z<tab>  or  from x.y.z<tab>
        dotted_name = name if from_name is None else from_name
        assert dotted_name is not None
        path, prefix = self.get_path_and_prefix(dotted_name)
        return [self.format_completion(path, module)
                for module in self.find_modules(path, prefix)]

    def find_modules(self, path: str, prefix: str) -> list[str]:
        """Find all modules under 'path' that start with 'prefix'."""
        modules = self._find_modules(path, prefix)
        # Filter out invalid module names
        # (for example those containing dashes that cannot be imported with 'import')
        return [mod for mod in modules if mod.isidentifier()]

    def _find_modules(self, path: str, prefix: str) -> list[str]:
        """Return the unfiltered module names under 'path' starting with 'prefix'.

        An empty 'path' means the top level.  A 'path' starting with '.' is
        resolved against the namespace's '__package__'; if that cannot be
        done the result is an empty list.
        """
        if not path:
            # Top-level import (e.g. `import foo<tab>` or `from foo<tab>`)
            return [name for _, name, _ in self.global_cache
                    if name.startswith(prefix)]

        if path.startswith('.'):
            # Convert relative path to absolute path.  '__package__' may be
            # absent or explicitly None; both are treated as "no package".
            package = self.namespace.get('__package__') or ''
            resolved = self.resolve_relative_name(path, package)
            if resolved is None:
                return []
            path = resolved

        # Walk down one package level per dotted segment.
        modules: Iterable[pkgutil.ModuleInfo] = self.global_cache
        for segment in path.split('.'):
            modules = [mod_info for mod_info in modules
                       if mod_info.ispkg and mod_info.name == segment]
            modules = self.iter_submodules(modules)
        return [module.name for module in modules
                if module.name.startswith(prefix)]

    def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]:
        """Iterate over all submodules of the given parent modules."""
        specs = [info.module_finder.find_spec(info.name, None)
                 for info in parent_modules if info.ispkg]
        # The attribute may be missing or None; treat both as "no locations".
        search_locations = set(chain.from_iterable(
            getattr(spec, 'submodule_search_locations', None) or ()
            for spec in specs if spec
        ))
        return pkgutil.iter_modules(search_locations)

    def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]:
        """
        Split a dotted name into an import path and a
        final prefix that is to be completed.

        Examples:
            'foo.bar' -> 'foo', 'bar'
            'foo.' -> 'foo', ''
            '.foo' -> '.', 'foo'
        """
        # Leading dots (relative import level) always stay with the path.
        stripped = dotted_name.lstrip('.')
        dots = dotted_name[:len(dotted_name) - len(stripped)]
        path, _, prefix = stripped.rpartition('.')
        return dots + path, prefix

    def format_completion(self, path: str, module: str) -> str:
        """Join 'path' and 'module' into the dotted name shown to the user."""
        if path == '' or path.endswith('.'):
            # Nothing to separate, or the path already ends with a dot.
            return f'{path}{module}'
        return f'{path}.{module}'

    def resolve_relative_name(self, name: str, package: str) -> str | None:
        """Resolve a relative module name to an absolute name.

        Return None when 'package' is empty or when 'name' has more
        leading dots than 'package' has levels.

        Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo'
        """
        if not package:
            # Without a package there is nothing to resolve against.
            return None
        # adapted from importlib._bootstrap
        level = len(name) - len(name.lstrip('.'))
        bits = package.rsplit('.', level - 1)
        if len(bits) < level:
            return None
        base = bits[0]
        name = name[level:]
        return f'{base}.{name}' if name else base

    @property
    def global_cache(self) -> list[pkgutil.ModuleInfo]:
        """Global module cache"""
        # Rebuild when the cache is empty or sys.path differs from the
        # snapshot taken when the cache was last built.
        if not self._global_cache or self._curr_sys_path != sys.path:
            self._curr_sys_path = sys.path[:]
            self._global_cache = list(pkgutil.iter_modules())
        return self._global_cache
 | |
| 
 | |
| 
 | |
class ImportParser:
    """
    Parses incomplete import statements that are
    suitable for autocomplete suggestions.

    Examples:
        - import foo          -> Result(from_name=None, name='foo')
        - import foo.         -> Result(from_name=None, name='foo.')
        - from foo            -> Result(from_name='foo', name=None)
        - from foo import bar -> Result(from_name='foo', name='bar')
        - from .foo import (  -> Result(from_name='.foo', name='')

    Note that the parser works in reverse order, starting from the
    last token in the input string. This makes the parser more robust
    when parsing multiple statements.
    """
    _ignored_tokens = {
        tokenize.INDENT, tokenize.DEDENT, tokenize.COMMENT,
        tokenize.NL, tokenize.NEWLINE, tokenize.ENDMARKER
    }
    _keywords = {'import', 'from', 'as'}

    def __init__(self, code: str) -> None:
        """Tokenize 'code' and queue its significant tokens in reverse order."""
        self.code = code
        tokens = []
        try:
            # Kept as an explicit loop: tokens collected before a
            # TokenError must survive for the 'unexpected EOF' case.
            for t in tokenize.generate_tokens(StringIO(code).readline):
                if t.type not in self._ignored_tokens:
                    tokens.append(t)
        except tokenize.TokenError as e:
            if 'unexpected EOF' not in str(e):
                # unexpected EOF is fine, since we're parsing an
                # incomplete statement, but other errors are not
                # because we may not have all the tokens so it's
                # safer to bail out
                tokens = []
        except SyntaxError:
            tokens = []
        self.tokens = TokenQueue(tokens[::-1])

    def parse(self) -> tuple[str | None, str | None] | None:
        """Return (from_name, name) for the statement, or None if it is not an import."""
        if not (res := self._parse()):
            return None
        return res.from_name, res.name

    def _parse(self) -> Result | None:
        # save_state() swallows ParseError and rewinds the queue, so
        # control falls through to the next alternative on failure.
        with self.tokens.save_state():
            return self.parse_from_import()
        with self.tokens.save_state():
            return self.parse_import()
        return None

    def _ends_with_bare_keyword(self, keyword: str) -> bool:
        """Return True if the input ends with 'keyword' followed by a space.

        The raw text check alone would also accept identifiers that merely
        end with the keyword (e.g. 'foo_import '), so the last token must
        be the keyword itself as well.
        """
        return (self.code.endswith(' ')
                and self.code.rstrip().endswith(keyword)
                and self.tokens.peek_string(keyword))

    def _peek_plain_name(self) -> bool:
        """Return True if the next token is a NAME other than import/from/as."""
        if not self.tokens.peek_name():
            return False
        tok = self.tokens.peek()
        return bool(tok) and tok.string not in self._keywords

    def parse_import(self) -> Result:
        """Parse the tail of an `import ...` statement.

        Raises ParseError if the tokens do not form one.
        """
        if self._ends_with_bare_keyword('import'):
            return Result(name='')
        if self.tokens.peek_string(','):
            name = ''
        else:
            if self.code.endswith(' '):
                # The cursor is past a finished name: nothing to complete.
                raise ParseError('parse_import')
            name = self.parse_dotted_name()
        if name.startswith('.'):
            # A plain `import` cannot name a relative module.
            raise ParseError('parse_import')
        while self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_dotted_as_name()
        if self.tokens.peek_string('import'):
            return Result(name=name)
        raise ParseError('parse_import')

    def parse_from_import(self) -> Result:
        """Parse the tail of a `from ... [import ...]` statement.

        Raises ParseError if the tokens do not form one.
        """
        if self._ends_with_bare_keyword('import'):
            return Result(from_name=self.parse_empty_from_import(), name='')
        if self._ends_with_bare_keyword('from'):
            return Result(from_name='')
        if self.tokens.peek_string('(') or self.tokens.peek_string(','):
            return Result(from_name=self.parse_empty_from_import(), name='')
        if self.code.endswith(' '):
            # The cursor is past a finished name: nothing to complete.
            raise ParseError('parse_from_import')
        name = self.parse_dotted_name()
        if '.' in name:
            # A dotted name can only be the module after `from`.
            self.tokens.pop_string('from')
            return Result(from_name=name)
        if self.tokens.peek_string('from'):
            return Result(from_name=name)
        from_name = self.parse_empty_from_import()
        return Result(from_name=from_name, name=name)

    def parse_empty_from_import(self) -> str:
        """Consume `from <module> import [(] [names ,]` and return <module>.

        Tokens are consumed in reverse, i.e. starting from the trailing comma.
        """
        if self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_as_names()
        if self.tokens.peek_string('('):
            self.tokens.pop()
        self.tokens.pop_string('import')
        return self.parse_from()

    def parse_from(self) -> str:
        """Consume `from <module>` (in reverse) and return <module>."""
        from_name = self.parse_dotted_name()
        self.tokens.pop_string('from')
        return from_name

    def parse_dotted_as_name(self) -> str:
        self.tokens.pop_name()
        if self.tokens.peek_string('as'):
            self.tokens.pop()
        # A further dotted name is optional here; on ParseError the
        # queue is rewound by save_state().
        with self.tokens.save_state():
            return self.parse_dotted_name()

    def parse_dotted_name(self) -> str:
        """Consume a dotted name (in reverse) and return it in source order.

        Raises ParseError if the next token is neither '.' nor a
        non-keyword NAME.
        """
        name = []
        if self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
        if self._peek_plain_name():
            name.append(self.tokens.pop_name())
        if not name:
            raise ParseError('parse_dotted_name')
        while self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
            if self._peek_plain_name():
                name.append(self.tokens.pop_name())
            else:
                break

        # Remaining dots are the leading dots of a relative name.
        while self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
        return ''.join(name[::-1])

    def parse_as_names(self) -> None:
        """Consume a comma-separated list of `name [as alias]` items."""
        self.parse_as_name()
        while self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_as_name()

    def parse_as_name(self) -> None:
        """Consume one `name [as alias]` item."""
        self.tokens.pop_name()
        if self.tokens.peek_string('as'):
            self.tokens.pop()
            self.tokens.pop_name()
 | |
| 
 | |
| 
 | |
class ParseError(Exception):
    """Raised when the tokens do not match the import form being parsed."""
 | |
| 
 | |
| 
 | |
@dataclass(frozen=True)
class Result:
    """Immutable outcome of parsing an import statement."""

    # Module named after `from`; None when not set.
    from_name: str | None = None
    # Name being imported or completed; None when not set.
    name: str | None = None
 | |
| 
 | |
| 
 | |
class TokenQueue:
    """Provides helper functions for working with a sequence of tokens."""

    def __init__(self, tokens: list[TokenInfo]) -> None:
        self.tokens: list[TokenInfo] = tokens
        # Position of the token that peek()/pop() will return next.
        self.index: int = 0
        # Positions saved by the currently active save_state() blocks.
        self.stack: list[int] = []

    @contextmanager
    def save_state(self) -> Any:
        """Context manager that rewinds the queue if the body raises ParseError.

        The ParseError is swallowed and 'index' is reset to its value on
        entry.  Any other exception propagates with 'index' untouched.
        The saved position is always removed from 'stack' on exit.
        """
        self.stack.append(self.index)
        try:
            yield
        except ParseError:
            self.index = self.stack[-1]
        finally:
            # Pop unconditionally so an unrelated exception cannot leave
            # a stale entry behind.
            self.stack.pop()

    def __bool__(self) -> bool:
        """Return True while there are tokens left to consume."""
        return self.index < len(self.tokens)

    def peek(self) -> TokenInfo | None:
        """Return the next token without consuming it, or None if exhausted."""
        if not self:
            return None
        return self.tokens[self.index]

    def peek_name(self) -> bool:
        """Return True if the next token exists and is a NAME."""
        if not (tok := self.peek()):
            return False
        return tok.type == tokenize.NAME

    def pop_name(self) -> str:
        """Consume the next token and return its text.

        Raises ParseError if the queue is exhausted or the token is not a NAME.
        """
        tok = self.pop()
        if tok.type != tokenize.NAME:
            raise ParseError('pop_name')
        return tok.string

    def peek_string(self, string: str) -> bool:
        """Return True if the next token exists and its text equals 'string'."""
        if not (tok := self.peek()):
            return False
        return tok.string == string

    def pop_string(self, string: str) -> str:
        """Consume the next token and return its text.

        Raises ParseError if the queue is exhausted or the text is not 'string'.
        """
        tok = self.pop()
        if tok.string != string:
            raise ParseError('pop_string')
        return tok.string

    def pop(self) -> TokenInfo:
        """Consume and return the next token; raise ParseError if exhausted."""
        if not self:
            raise ParseError('pop')
        tok = self.tokens[self.index]
        self.index += 1
        return tok
 | 
