from __future__ import annotations import pkgutil import sys import tokenize from io import StringIO from contextlib import contextmanager from dataclasses import dataclass from itertools import chain from tokenize import TokenInfo TYPE_CHECKING = False if TYPE_CHECKING: from typing import Any, Iterable, Iterator, Mapping def make_default_module_completer() -> ModuleCompleter: # Inside pyrepl, __package__ is set to '_pyrepl' return ModuleCompleter(namespace={'__package__': '_pyrepl'}) class ModuleCompleter: """A completer for Python import statements. Examples: - import - import foo - import foo. - import foo as bar, baz - from - from foo - from foo import - from foo import bar - from foo import (bar as baz, qux """ def __init__(self, namespace: Mapping[str, Any] | None = None) -> None: self.namespace = namespace or {} self._global_cache: list[pkgutil.ModuleInfo] = [] self._curr_sys_path: list[str] = sys.path[:] def get_completions(self, line: str) -> list[str]: """Return the next possible import completions for 'line'.""" result = ImportParser(line).parse() if not result: return [] try: return self.complete(*result) except Exception: # Some unexpected error occurred, make it look like # no completions are available return [] def complete(self, from_name: str | None, name: str | None) -> list[str]: if from_name is None: # import x.y.z assert name is not None path, prefix = self.get_path_and_prefix(name) modules = self.find_modules(path, prefix) return [self.format_completion(path, module) for module in modules] if name is None: # from x.y.z path, prefix = self.get_path_and_prefix(from_name) modules = self.find_modules(path, prefix) return [self.format_completion(path, module) for module in modules] # from x.y import z return self.find_modules(from_name, name) def find_modules(self, path: str, prefix: str) -> list[str]: """Find all modules under 'path' that start with 'prefix'.""" modules = self._find_modules(path, prefix) # Filter out invalid module names # (for example those containing dashes that cannot be imported with 'import') return [mod for mod in modules if mod.isidentifier()] def _find_modules(self, path: str, prefix: str) -> list[str]: if not path: # Top-level import (e.g. `import foo`` or `from foo`)` return [name for _, name, _ in self.global_cache if name.startswith(prefix)] if path.startswith('.'): # Convert relative path to absolute path package = self.namespace.get('__package__', '') path = self.resolve_relative_name(path, package) # type: ignore[assignment] if path is None: return [] modules: Iterable[pkgutil.ModuleInfo] = self.global_cache for segment in path.split('.'): modules = [mod_info for mod_info in modules if mod_info.ispkg and mod_info.name == segment] modules = self.iter_submodules(modules) return [module.name for module in modules if module.name.startswith(prefix)] def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]: """Iterate over all submodules of the given parent modules.""" specs = [info.module_finder.find_spec(info.name, None) for info in parent_modules if info.ispkg] search_locations = set(chain.from_iterable( getattr(spec, 'submodule_search_locations', []) for spec in specs if spec )) return pkgutil.iter_modules(search_locations) def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]: """ Split a dotted name into an import path and a final prefix that is to be completed. Examples: 'foo.bar' -> 'foo', 'bar' 'foo.' -> 'foo', '' '.foo' -> '.', 'foo' """ if '.' not in dotted_name: return '', dotted_name if dotted_name.startswith('.'): stripped = dotted_name.lstrip('.') dots = '.' * (len(dotted_name) - len(stripped)) if '.' not in stripped: return dots, stripped path, prefix = stripped.rsplit('.', 1) return dots + path, prefix path, prefix = dotted_name.rsplit('.', 1) return path, prefix def format_completion(self, path: str, module: str) -> str: if path == '' or path.endswith('.'): return f'{path}{module}' return f'{path}.{module}' def resolve_relative_name(self, name: str, package: str) -> str | None: """Resolve a relative module name to an absolute name. Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo' """ # taken from importlib._bootstrap level = 0 for character in name: if character != '.': break level += 1 bits = package.rsplit('.', level - 1) if len(bits) < level: return None base = bits[0] name = name[level:] return f'{base}.{name}' if name else base @property def global_cache(self) -> list[pkgutil.ModuleInfo]: """Global module cache""" if not self._global_cache or self._curr_sys_path != sys.path: self._curr_sys_path = sys.path[:] # print('getting packages') self._global_cache = list(pkgutil.iter_modules()) return self._global_cache class ImportParser: """ Parses incomplete import statements that are suitable for autocomplete suggestions. Examples: - import foo -> Result(from_name=None, name='foo') - import foo. -> Result(from_name=None, name='foo.') - from foo -> Result(from_name='foo', name=None) - from foo import bar -> Result(from_name='foo', name='bar') - from .foo import ( -> Result(from_name='.foo', name='') Note that the parser works in reverse order, starting from the last token in the input string. This makes the parser more robust when parsing multiple statements. """ _ignored_tokens = { tokenize.INDENT, tokenize.DEDENT, tokenize.COMMENT, tokenize.NL, tokenize.NEWLINE, tokenize.ENDMARKER } _keywords = {'import', 'from', 'as'} def __init__(self, code: str) -> None: self.code = code tokens = [] try: for t in tokenize.generate_tokens(StringIO(code).readline): if t.type not in self._ignored_tokens: tokens.append(t) except tokenize.TokenError as e: if 'unexpected EOF' not in str(e): # unexpected EOF is fine, since we're parsing an # incomplete statement, but other errors are not # because we may not have all the tokens so it's # safer to bail out tokens = [] except SyntaxError: tokens = [] self.tokens = TokenQueue(tokens[::-1]) def parse(self) -> tuple[str | None, str | None] | None: if not (res := self._parse()): return None return res.from_name, res.name def _parse(self) -> Result | None: with self.tokens.save_state(): return self.parse_from_import() with self.tokens.save_state(): return self.parse_import() def parse_import(self) -> Result: if self.code.rstrip().endswith('import') and self.code.endswith(' '): return Result(name='') if self.tokens.peek_string(','): name = '' else: if self.code.endswith(' '): raise ParseError('parse_import') name = self.parse_dotted_name() if name.startswith('.'): raise ParseError('parse_import') while self.tokens.peek_string(','): self.tokens.pop() self.parse_dotted_as_name() if self.tokens.peek_string('import'): return Result(name=name) raise ParseError('parse_import') def parse_from_import(self) -> Result: stripped = self.code.rstrip() if stripped.endswith('import') and self.code.endswith(' '): return Result(from_name=self.parse_empty_from_import(), name='') if stripped.endswith('from') and self.code.endswith(' '): return Result(from_name='') if self.tokens.peek_string('(') or self.tokens.peek_string(','): return Result(from_name=self.parse_empty_from_import(), name='') if self.code.endswith(' '): raise ParseError('parse_from_import') name = self.parse_dotted_name() if '.' in name: self.tokens.pop_string('from') return Result(from_name=name) if self.tokens.peek_string('from'): return Result(from_name=name) from_name = self.parse_empty_from_import() return Result(from_name=from_name, name=name) def parse_empty_from_import(self) -> str: if self.tokens.peek_string(','): self.tokens.pop() self.parse_as_names() if self.tokens.peek_string('('): self.tokens.pop() self.tokens.pop_string('import') return self.parse_from() def parse_from(self) -> str: from_name = self.parse_dotted_name() self.tokens.pop_string('from') return from_name def parse_dotted_as_name(self) -> str: self.tokens.pop_name() if self.tokens.peek_string('as'): self.tokens.pop() with self.tokens.save_state(): return self.parse_dotted_name() def parse_dotted_name(self) -> str: name = [] if self.tokens.peek_string('.'): name.append('.') self.tokens.pop() if (self.tokens.peek_name() and (tok := self.tokens.peek()) and tok.string not in self._keywords): name.append(self.tokens.pop_name()) if not name: raise ParseError('parse_dotted_name') while self.tokens.peek_string('.'): name.append('.') self.tokens.pop() if (self.tokens.peek_name() and (tok := self.tokens.peek()) and tok.string not in self._keywords): name.append(self.tokens.pop_name()) else: break while self.tokens.peek_string('.'): name.append('.') self.tokens.pop() return ''.join(name[::-1]) def parse_as_names(self) -> None: self.parse_as_name() while self.tokens.peek_string(','): self.tokens.pop() self.parse_as_name() def parse_as_name(self) -> None: self.tokens.pop_name() if self.tokens.peek_string('as'): self.tokens.pop() self.tokens.pop_name() class ParseError(Exception): pass @dataclass(frozen=True) class Result: from_name: str | None = None name: str | None = None class TokenQueue: """Provides helper functions for working with a sequence of tokens.""" def __init__(self, tokens: list[TokenInfo]) -> None: self.tokens: list[TokenInfo] = tokens self.index: int = 0 self.stack: list[int] = [] @contextmanager def save_state(self) -> Any: try: self.stack.append(self.index) yield except ParseError: self.index = self.stack.pop() else: self.stack.pop() def __bool__(self) -> bool: return self.index < len(self.tokens) def peek(self) -> TokenInfo | None: if not self: return None return self.tokens[self.index] def peek_name(self) -> bool: if not (tok := self.peek()): return False return tok.type == tokenize.NAME def pop_name(self) -> str: tok = self.pop() if tok.type != tokenize.NAME: raise ParseError('pop_name') return tok.string def peek_string(self, string: str) -> bool: if not (tok := self.peek()): return False return tok.string == string def pop_string(self, string: str) -> str: tok = self.pop() if tok.string != string: raise ParseError('pop_string') return tok.string def pop(self) -> TokenInfo: if not self: raise ParseError('pop') tok = self.tokens[self.index] self.index += 1 return tok