| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  | import os | 
					
						
							|  |  |  | import pathlib | 
					
						
							|  |  |  | import tempfile | 
					
						
							|  |  |  | import functools | 
					
						
							|  |  |  | import contextlib | 
					
						
							| 
									
										
										
										
											2020-06-07 21:00:51 -04:00
										 |  |  | import types | 
					
						
							|  |  |  | import importlib | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-12-18 21:28:49 -05:00
										 |  |  | from typing import Union, Optional | 
					
						
							| 
									
										
										
										
											2021-05-21 13:00:40 -04:00
										 |  |  | from .abc import ResourceReader, Traversable | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-04 13:43:00 -05:00
										 |  |  | from ._adapters import wrap_spec | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-07 21:00:51 -04:00
# A package argument may be either an already-imported module object or the
# dotted name of one (see resolve(), which imports str values).
Package = Union[types.ModuleType, str]
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def files(package): | 
					
						
							| 
									
										
										
										
											2021-05-21 13:00:40 -04:00
										 |  |  |     # type: (Package) -> Traversable | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2020-06-07 21:00:51 -04:00
										 |  |  |     Get a Traversable resource from a package | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     return from_package(get_package(package)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-07 21:00:51 -04:00
										 |  |  | def get_resource_reader(package): | 
					
						
							|  |  |  |     # type: (types.ModuleType) -> Optional[ResourceReader] | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2020-06-07 21:00:51 -04:00
										 |  |  |     Return the package's loader if it's a ResourceReader. | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2020-06-07 21:00:51 -04:00
										 |  |  |     # We can't use | 
					
						
							|  |  |  |     # a issubclass() check here because apparently abc.'s __subclasscheck__() | 
					
						
							|  |  |  |     # hook wants to create a weak reference to the object, but | 
					
						
							|  |  |  |     # zipimport.zipimporter does not support weak references, resulting in a | 
					
						
							|  |  |  |     # TypeError.  That seems terrible. | 
					
						
							|  |  |  |     spec = package.__spec__ | 
					
						
							| 
									
										
										
										
											2021-03-04 13:43:00 -05:00
										 |  |  |     reader = getattr(spec.loader, 'get_resource_reader', None)  # type: ignore | 
					
						
							| 
									
										
										
										
											2020-06-07 21:00:51 -04:00
										 |  |  |     if reader is None: | 
					
						
							|  |  |  |         return None | 
					
						
							| 
									
										
										
										
											2021-03-04 13:43:00 -05:00
										 |  |  |     return reader(spec.name)  # type: ignore | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-07 21:00:51 -04:00
										 |  |  | def resolve(cand): | 
					
						
							|  |  |  |     # type: (Package) -> types.ModuleType | 
					
						
							| 
									
										
										
										
											2021-03-04 13:43:00 -05:00
										 |  |  |     return cand if isinstance(cand, types.ModuleType) else importlib.import_module(cand) | 
					
						
							| 
									
										
										
										
											2020-06-07 21:00:51 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_package(package): | 
					
						
							|  |  |  |     # type: (Package) -> types.ModuleType | 
					
						
							|  |  |  |     """Take a package name or module object and return the module.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Raise an exception if the resolved module is not a package. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     resolved = resolve(package) | 
					
						
							| 
									
										
										
										
											2021-03-04 13:43:00 -05:00
										 |  |  |     if wrap_spec(resolved).submodule_search_locations is None: | 
					
						
							| 
									
										
										
										
											2021-05-26 16:16:11 -04:00
										 |  |  |         raise TypeError(f'{package!r} is not a package') | 
					
						
							| 
									
										
										
										
											2020-06-07 21:00:51 -04:00
										 |  |  |     return resolved | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def from_package(package): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Return a Traversable object for the given package. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2021-03-04 13:43:00 -05:00
										 |  |  |     spec = wrap_spec(package) | 
					
						
							| 
									
										
										
										
											2020-06-07 21:00:51 -04:00
										 |  |  |     reader = spec.loader.get_resource_reader(spec.name) | 
					
						
							|  |  |  |     return reader.files() | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @contextlib.contextmanager | 
					
						
							| 
									
										
										
										
											2022-10-16 15:00:39 -04:00
										 |  |  | def _tempfile( | 
					
						
							|  |  |  |     reader, | 
					
						
							|  |  |  |     suffix='', | 
					
						
							|  |  |  |     # gh-93353: Keep a reference to call os.remove() in late Python | 
					
						
							|  |  |  |     # finalization. | 
					
						
							|  |  |  |     *, | 
					
						
							|  |  |  |     _os_remove=os.remove, | 
					
						
							|  |  |  | ): | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  |     # Not using tempfile.NamedTemporaryFile as it leads to deeper 'try' | 
					
						
							|  |  |  |     # blocks due to the need to close the temporary file to work on Windows | 
					
						
							|  |  |  |     # properly. | 
					
						
							|  |  |  |     fd, raw_path = tempfile.mkstemp(suffix=suffix) | 
					
						
							|  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2021-07-30 20:37:09 -04:00
										 |  |  |         try: | 
					
						
							|  |  |  |             os.write(fd, reader()) | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             os.close(fd) | 
					
						
							| 
									
										
										
										
											2020-10-25 14:21:46 -04:00
										 |  |  |         del reader | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  |         yield pathlib.Path(raw_path) | 
					
						
							|  |  |  |     finally: | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2022-06-13 19:24:00 +02:00
										 |  |  |             _os_remove(raw_path) | 
					
						
							| 
									
										
										
										
											2021-07-30 20:37:09 -04:00
										 |  |  |         except FileNotFoundError: | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  |             pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-16 15:00:39 -04:00
										 |  |  | def _temp_file(path): | 
					
						
							|  |  |  |     return _tempfile(path.read_bytes, suffix=path.name) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _is_present_dir(path: Traversable) -> bool: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Some Traversables implement ``is_dir()`` to raise an | 
					
						
							|  |  |  |     exception (i.e. ``FileNotFoundError``) when the | 
					
						
							|  |  |  |     directory doesn't exist. This function wraps that call | 
					
						
							|  |  |  |     to always return a boolean and only return True | 
					
						
							|  |  |  |     if there's a dir and it exists. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     with contextlib.suppress(FileNotFoundError): | 
					
						
							|  |  |  |         return path.is_dir() | 
					
						
							|  |  |  |     return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  | @functools.singledispatch | 
					
						
							|  |  |  | def as_file(path): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Given a Traversable object, return that object as a | 
					
						
							|  |  |  |     path on the local file system in a context manager. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2022-10-16 15:00:39 -04:00
										 |  |  |     return _temp_dir(path) if _is_present_dir(path) else _temp_file(path) | 
					
						
							| 
									
										
										
										
											2020-05-08 19:20:26 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @as_file.register(pathlib.Path) | 
					
						
							|  |  |  | @contextlib.contextmanager | 
					
						
							|  |  |  | def _(path): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Degenerate behavior for pathlib.Path objects. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     yield path | 
					
						
							| 
									
										
										
										
											2022-10-16 15:00:39 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @contextlib.contextmanager | 
					
						
							|  |  |  | def _temp_path(dir: tempfile.TemporaryDirectory): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Wrap tempfile.TemporyDirectory to return a pathlib object. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     with dir as result: | 
					
						
							|  |  |  |         yield pathlib.Path(result) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @contextlib.contextmanager | 
					
						
							|  |  |  | def _temp_dir(path): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Given a traversable dir, recursively replicate the whole tree | 
					
						
							|  |  |  |     to the file system in a context manager. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     assert path.is_dir() | 
					
						
							|  |  |  |     with _temp_path(tempfile.TemporaryDirectory()) as temp_dir: | 
					
						
							|  |  |  |         yield _write_contents(temp_dir, path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _write_contents(target, source): | 
					
						
							|  |  |  |     child = target.joinpath(source.name) | 
					
						
							|  |  |  |     if source.is_dir(): | 
					
						
							|  |  |  |         child.mkdir() | 
					
						
							|  |  |  |         for item in source.iterdir(): | 
					
						
							|  |  |  |             _write_contents(child, item) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         child.open('wb').write(source.read_bytes()) | 
					
						
							|  |  |  |     return child |