mirror of
				https://github.com/python/cpython.git
				synced 2025-10-26 19:24:34 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			612 lines
		
	
	
	
		
			19 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			612 lines
		
	
	
	
		
			19 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import collections.abc
 | |
| import contextlib
 | |
| import errno
 | |
| import os
 | |
| import re
 | |
| import stat
 | |
| import sys
 | |
| import time
 | |
| import unittest
 | |
| import warnings
 | |
| 
 | |
| 
 | |
# Filename used for testing
if os.name == 'java':
    # Jython disallows @ in module names
    TESTFN_ASCII = '$test'
else:
    TESTFN_ASCII = '@test'

# Disambiguate TESTFN for parallel testing, while letting it remain a valid
# module name.
TESTFN_ASCII = "{}_{}_tmp".format(TESTFN_ASCII, os.getpid())

# TESTFN_UNICODE is a non-ascii filename
TESTFN_UNICODE = TESTFN_ASCII + "-\xe0\xf2\u0258\u0141\u011f"
if sys.platform == 'darwin':
    # In Mac OS X's VFS API file names are, by definition, canonically
    # decomposed Unicode, encoded using UTF-8. See QA1173:
    # http://developer.apple.com/mac/library/qa/qa2001/qa1173.html
    import unicodedata
    TESTFN_UNICODE = unicodedata.normalize('NFD', TESTFN_UNICODE)

# TESTFN_UNENCODABLE is a filename (str type) that should *not* be able to be
# encoded by the filesystem encoding (in strict mode). It can be None if we
# cannot generate such filename.
TESTFN_UNENCODABLE = None
if os.name == 'nt':
    # skip win32s (0) or Windows 9x/ME (1)
    if sys.getwindowsversion().platform >= 2:
        # Different kinds of characters from various languages to minimize the
        # probability that the whole name is encodable to MBCS (issue #9819)
        TESTFN_UNENCODABLE = TESTFN_ASCII + "-\u5171\u0141\u2661\u0363\uDC80"
        try:
            TESTFN_UNENCODABLE.encode(sys.getfilesystemencoding())
        except UnicodeEncodeError:
            pass
        else:
            print('WARNING: The filename %r CAN be encoded by the filesystem '
                  'encoding (%s). Unicode filename tests may not be effective'
                  % (TESTFN_UNENCODABLE, sys.getfilesystemencoding()))
            TESTFN_UNENCODABLE = None
# Mac OS X denies unencodable filenames (invalid utf-8)
elif sys.platform != 'darwin':
    try:
        # ascii and utf-8 cannot encode the byte 0xff
        b'\xff'.decode(sys.getfilesystemencoding())
    except UnicodeDecodeError:
        # 0xff will be encoded using the surrogate character u+DCFF
        TESTFN_UNENCODABLE = TESTFN_ASCII \
            + b'-\xff'.decode(sys.getfilesystemencoding(), 'surrogateescape')
    else:
        # File system encoding (eg. ISO-8859-* encodings) can encode
        # the byte 0xff. Skip some unicode filename tests.
        pass

# FS_NONASCII: non-ASCII character encodable by os.fsencode(),
# or an empty string if there is no such character.
FS_NONASCII = ''
for character in (
    # First try printable and common characters to have a readable filename.
    # For each character, the encoding list are just example of encodings able
    # to encode the character (the list is not exhaustive).

    # U+00E6 (Latin Small Letter Ae): cp1252, iso-8859-1
    '\u00E6',
    # U+0130 (Latin Capital Letter I With Dot Above): cp1254, iso8859_3
    '\u0130',
    # U+0141 (Latin Capital Letter L With Stroke): cp1250, cp1257
    '\u0141',
    # U+03C6 (Greek Small Letter Phi): cp1253
    '\u03C6',
    # U+041A (Cyrillic Capital Letter Ka): cp1251
    '\u041A',
    # U+05D0 (Hebrew Letter Alef): Encodable to cp424
    '\u05D0',
    # U+060C (Arabic Comma): cp864, cp1006, iso8859_6, mac_arabic
    '\u060C',
    # U+062A (Arabic Letter Teh): cp720
    '\u062A',
    # U+0E01 (Thai Character Ko Kai): cp874
    '\u0E01',

    # Then try more "special" characters. "special" because they may be
    # interpreted or displayed differently depending on the exact locale
    # encoding and the font.

    # U+00A0 (No-Break Space)
    '\u00A0',
    # U+20AC (Euro Sign)
    '\u20AC',
):
    try:
        # If Python is set up to use the legacy 'mbcs' in Windows,
        # 'replace' error mode is used, and encode() returns b'?'
        # for characters missing in the ANSI codepage
        if os.fsdecode(os.fsencode(character)) != character:
            raise UnicodeError
    except UnicodeError:
        pass
    else:
        FS_NONASCII = character
        break

# Save the initial cwd
SAVEDCWD = os.getcwd()

# TESTFN_UNDECODABLE is a filename (bytes type) that should *not* be able to be
# decoded from the filesystem encoding (in strict mode). It can be None if we
# cannot generate such filename (ex: the latin1 encoding can decode any byte
# sequence). On UNIX, TESTFN_UNDECODABLE can be decoded by os.fsdecode() thanks
# to the surrogateescape error handler (PEP 383), but not from the filesystem
# encoding in strict mode.
TESTFN_UNDECODABLE = None
for name in (
    # b'\xff' is not decodable by os.fsdecode() with code page 932. Windows
    # accepts it to create a file or a directory, or don't accept to enter to
    # such directory (when the bytes name is used). So test b'\xe7' first:
    # it is not decodable from cp932.
    b'\xe7w\xf0',
    # undecodable from ASCII, UTF-8
    b'\xff',
    # undecodable from iso8859-3, iso8859-6, iso8859-7, cp424, iso8859-8, cp856
    # and cp857
    # (The trailing comma matters: without it, implicit literal concatenation
    # would merge this candidate with the next one into a single entry.)
    b'\xae\xd5',
    # undecodable from UTF-8 (UNIX and Mac OS X)
    b'\xed\xb2\x80', b'\xed\xb4\x80',
    # undecodable from shift_jis, cp869, cp874, cp932, cp1250, cp1251, cp1252,
    # cp1253, cp1254, cp1255, cp1257, cp1258
    b'\x81\x98',
):
    try:
        name.decode(sys.getfilesystemencoding())
    except UnicodeDecodeError:
        # First candidate the filesystem encoding rejects wins.
        TESTFN_UNDECODABLE = os.fsencode(TESTFN_ASCII) + name
        break

# TESTFN_NONASCII is TESTFN_ASCII plus one non-ASCII character that survives
# an os.fsencode()/os.fsdecode() round trip, or None if no such character.
if FS_NONASCII:
    TESTFN_NONASCII = TESTFN_ASCII + FS_NONASCII
else:
    TESTFN_NONASCII = None
# Default test filename: prefer the non-ASCII variant when one is available.
TESTFN = TESTFN_NONASCII or TESTFN_ASCII
 | |
| 
 | |
| 
 | |
def make_bad_fd():
    """Return the number of a file descriptor that has just been closed.

    A file named TESTFN is opened for writing, its descriptor number is
    recorded, and the file is then closed and removed.  The returned number
    is stale, at least until something else opens a descriptor that happens
    to reuse it.
    """
    fp = open(TESTFN, "wb")
    try:
        fd = fp.fileno()
    finally:
        fp.close()
        unlink(TESTFN)
    return fd
 | |
| 
 | |
| 
 | |
_can_symlink = None


def can_symlink():
    """Return True if os.symlink() works in the current directory.

    The probe creates a symlink named TESTFN + "can_symlink" that points at
    TESTFN and removes it again on success.  The answer is cached in the
    module-level ``_can_symlink``, so later calls return without probing.
    """
    global _can_symlink
    if _can_symlink is None:
        link_name = TESTFN + "can_symlink"
        try:
            os.symlink(TESTFN, link_name)
        except (OSError, NotImplementedError, AttributeError):
            _can_symlink = False
        else:
            os.remove(link_name)
            _can_symlink = True
    return _can_symlink
 | |
| 
 | |
| 
 | |
def skip_unless_symlink(test):
    """Decorator: skip *test* unless can_symlink() reports working symlinks."""
    if can_symlink():
        return test
    return unittest.skip("Requires functional symlink implementation")(test)
 | |
| 
 | |
| 
 | |
_can_xattr = None


def can_xattr():
    """Return True if extended attributes can be set on test files.

    If os.setxattr is missing, the answer is False.  Otherwise the probe sets
    attributes on a tempfile.mkstemp() file (via its descriptor and its
    name) and on TESTFN, because the two may live on different file systems.
    The probe also rejects Linux 2.6 kernels older than 2.6.39.  Any OSError
    during the probe means False.  The answer is cached in the module-level
    ``_can_xattr``.
    """
    import tempfile
    global _can_xattr
    if _can_xattr is not None:
        return _can_xattr
    if not hasattr(os, "setxattr"):
        can = False
    else:
        import platform
        tmp_dir = tempfile.mkdtemp()
        tmp_fp, tmp_name = tempfile.mkstemp(dir=tmp_dir)
        try:
            with open(TESTFN, "wb") as fp:
                try:
                    # TESTFN & tempfile may use different file systems with
                    # different capabilities
                    os.setxattr(tmp_fp, b"user.test", b"")
                    os.setxattr(tmp_name, b"trusted.foo", b"42")
                    os.setxattr(fp.fileno(), b"user.test", b"")
                    # Kernels < 2.6.39 don't respect setxattr flags.
                    kernel_version = platform.release()
                    # Dots are escaped so only a literal "2.6." prefix matches.
                    m = re.match(r"2\.6\.(\d{1,2})", kernel_version)
                    can = m is None or int(m.group(1)) >= 39
                except OSError:
                    can = False
        finally:
            # mkstemp() hands back an open descriptor; close it so the probe
            # does not leak a file descriptor.
            os.close(tmp_fp)
            unlink(TESTFN)
            unlink(tmp_name)
            rmdir(tmp_dir)
    _can_xattr = can
    return can
 | |
| 
 | |
| 
 | |
def skip_unless_xattr(test):
    """Decorator: skip *test* unless can_xattr() reports usable xattrs."""
    if can_xattr():
        return test
    return unittest.skip("no non-broken extended attribute support")(test)
 | |
| 
 | |
| 
 | |
def unlink(filename):
    """Remove *filename*, ignoring it if it (or a parent) does not exist."""
    with contextlib.suppress(FileNotFoundError, NotADirectoryError):
        _unlink(filename)
 | |
| 
 | |
| 
 | |
if sys.platform.startswith("win"):
    def _waitfor(func, pathname, waitall=False):
        """Call func(pathname), then poll until the removal is visible.

        With waitall=False the wait is for the last path component to vanish
        from its parent directory's listing.  With waitall=True, *pathname*
        is itself a directory and the wait is for its listing to become
        empty.  Polling uses exponential backoff for roughly one second; if
        the entry is still listed, a RuntimeWarning is emitted.
        """
        # Perform the operation
        func(pathname)
        # Now setup the wait loop
        if waitall:
            dirname = pathname
        else:
            dirname, name = os.path.split(pathname)
            dirname = dirname or '.'
        # Check for `pathname` to be removed from the filesystem.
        # The exponential backoff of the timeout amounts to a total
        # of ~1 second after which the deletion is probably an error
        # anyway.
        # Testing on an i7@4.3GHz shows that usually only 1 iteration is
        # required when contention occurs.
        timeout = 0.001
        while timeout < 1.0:
            # Note we are only testing for the existence of the file(s) in
            # the contents of the directory regardless of any security or
            # access rights.  If we have made it this far, we have sufficient
            # permissions to do that much using Python's equivalent of the
            # Windows API FindFirstFile.
            # Other Windows APIs can fail or give incorrect results when
            # dealing with files that are pending deletion.
            L = os.listdir(dirname)
            if not (L if waitall else name in L):
                return
            # Increase the timeout and try again
            time.sleep(timeout)
            timeout *= 2
        warnings.warn('tests may fail, delete still pending for ' + pathname,
                      RuntimeWarning, stacklevel=4)

    def _unlink(filename):
        """Remove a file and wait until it disappears from its directory."""
        _waitfor(os.unlink, filename)

    def _rmdir(dirname):
        """Remove an empty directory and wait until it is no longer listed."""
        _waitfor(os.rmdir, dirname)

    def _rmtree(path):
        """Recursively remove *path*, waiting for each deletion to settle.

        Every listdir/rmdir/unlink goes through test.support._force_run
        (defined outside this module).
        """
        from test.support import _force_run

        def _rmtree_inner(path):
            for name in _force_run(path, os.listdir, path):
                fullname = os.path.join(path, name)
                try:
                    mode = os.lstat(fullname).st_mode
                except OSError as exc:
                    print("support.rmtree(): os.lstat(%r) failed with %s"
                          % (fullname, exc),
                          file=sys.__stderr__)
                    mode = 0
                if stat.S_ISDIR(mode):
                    _waitfor(_rmtree_inner, fullname, waitall=True)
                    _force_run(fullname, os.rmdir, fullname)
                else:
                    _force_run(fullname, os.unlink, fullname)
        _waitfor(_rmtree_inner, path, waitall=True)
        _waitfor(lambda p: _force_run(p, os.rmdir, p), path)

    def _longpath(path):
        """Expand 8.3 short names in *path* via GetLongPathNameW.

        *path* is returned unchanged if ctypes is unavailable or the call
        fails.
        """
        try:
            import ctypes
        except ImportError:
            # No ctypes means we can't expand paths.
            pass
        else:
            GetLongPathNameW = ctypes.windll.kernel32.GetLongPathNameW
            buffer = ctypes.create_unicode_buffer(len(path) * 2)
            length = GetLongPathNameW(path, buffer, len(buffer))
            if length >= len(buffer):
                # When the buffer is too small, the API returns the required
                # size (including the terminating NUL) instead of the length
                # copied, and the buffer contents are not the result.
                # Retry once with the requested size.
                buffer = ctypes.create_unicode_buffer(length)
                length = GetLongPathNameW(path, buffer, len(buffer))
            # On success the return value excludes the NUL, so it must be
            # strictly smaller than the buffer; anything else is a failure.
            if 0 < length < len(buffer):
                return buffer[:length]
        return path
else:
    _unlink = os.unlink
    _rmdir = os.rmdir

    def _rmtree(path):
        """Remove the tree at *path*.

        shutil.rmtree() is tried first.  If it raises OSError, a manual walk
        routes each listdir/rmdir/unlink through test.support._force_run
        (defined outside this module), then *path* itself is removed.
        """
        import shutil
        try:
            shutil.rmtree(path)
            return
        except OSError:
            pass

        def _rmtree_inner(path):
            from test.support import _force_run
            for name in _force_run(path, os.listdir, path):
                fullname = os.path.join(path, name)
                try:
                    mode = os.lstat(fullname).st_mode
                except OSError:
                    mode = 0
                if stat.S_ISDIR(mode):
                    _rmtree_inner(fullname)
                    _force_run(path, os.rmdir, fullname)
                else:
                    _force_run(path, os.unlink, fullname)
        _rmtree_inner(path)
        os.rmdir(path)

    def _longpath(path):
        """No short-name expansion outside Windows: return *path* as-is."""
        return path
 | |
| 
 | |
| 
 | |
def rmdir(dirname):
    """Remove the empty directory *dirname*; a missing one is not an error."""
    with contextlib.suppress(FileNotFoundError):
        _rmdir(dirname)
 | |
| 
 | |
| 
 | |
def rmtree(path):
    """Remove the directory tree at *path*; a missing *path* is not an error."""
    with contextlib.suppress(FileNotFoundError):
        _rmtree(path)
 | |
| 
 | |
| 
 | |
@contextlib.contextmanager
def temp_dir(path=None, quiet=False):
    """Context manager that creates a directory and removes it on exit.

    Arguments:

      path: the directory to create.  When omitted or None, a fresh
        directory is made with tempfile.mkdtemp() and its os.path.realpath()
        is yielded.

      quiet: when False (the default), failing to create *path* raises.
        When True, a RuntimeWarning is issued instead and *path* is still
        yielded.

    The directory is removed on exit only if this call created it, and only
    by the process that created it.
    """
    import tempfile
    if path is None:
        path = os.path.realpath(tempfile.mkdtemp())
        created = True
    else:
        created = False
        try:
            os.mkdir(path)
            created = True
        except OSError as exc:
            if not quiet:
                raise
            warnings.warn(f'tests may fail, unable to create '
                          f'temporary directory {path!r}: {exc}',
                          RuntimeWarning, stacklevel=3)
    owner_pid = os.getpid() if created else None
    try:
        yield path
    finally:
        # A forked child has a different pid, so only the creating process
        # cleans up (bpo-30028).
        if created and owner_pid == os.getpid():
            rmtree(path)
 | |
| 
 | |
| 
 | |
@contextlib.contextmanager
def change_cwd(path, quiet=False):
    """Context manager that temporarily switches the working directory.

    Arguments:

      path: the directory to chdir into (after os.path.realpath()).

      quiet: when False (the default), a failed chdir raises.  When True,
        a RuntimeWarning is issued and the working directory is left as is.

    The value bound by ``as`` is os.getcwd() after the switch attempt.  The
    original working directory is always restored on exit.
    """
    original_cwd = os.getcwd()
    try:
        os.chdir(os.path.realpath(path))
    except OSError as exc:
        if quiet:
            warnings.warn(f'tests may fail, unable to change the current working '
                          f'directory to {path!r}: {exc}',
                          RuntimeWarning, stacklevel=3)
        else:
            raise
    try:
        yield os.getcwd()
    finally:
        os.chdir(original_cwd)
 | |
| 
 | |
| 
 | |
@contextlib.contextmanager
def temp_cwd(name='tempcwd', quiet=False):
    """Context manager: create a temporary directory and chdir into it.

    The directory is created via temp_dir(path=name) (so name=None means a
    tempfile.mkdtemp() directory).  The working directory is then switched
    with change_cwd().  Both are undone on exit.

    With quiet=False (default), failure to create or enter the directory
    raises.  With quiet=True, only a warning is issued and the original
    working directory stays in effect.
    """
    with temp_dir(path=name, quiet=quiet) as temp_path, \
            change_cwd(temp_path, quiet=quiet) as cwd_dir:
        yield cwd_dir
 | |
| 
 | |
| 
 | |
def create_empty_file(filename):
    """Create *filename* as an empty file, truncating it if it exists."""
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    os.close(os.open(filename, flags))
 | |
| 
 | |
| 
 | |
def fs_is_case_insensitive(directory):
    """Detect whether the file system holding *directory* is case-insensitive.

    A temporary file is created in *directory*.  The function then checks
    whether the same file name with its letter case changed refers to the
    same file.  Only the file name is re-cased; the path leading to
    *directory* keeps its original spelling.  This way a case-insensitive
    directory below a case-sensitive parent is still detected.

    Returns True if the re-cased name resolves to the temporary file,
    otherwise False.
    """
    import tempfile
    with tempfile.NamedTemporaryFile(dir=directory) as base:
        base_path = base.name
        parent, filename = os.path.split(base_path)
        case_name = filename.upper()
        if case_name == filename:
            case_name = filename.lower()
        case_path = os.path.join(parent, case_name)
        try:
            return os.path.samefile(base_path, case_path)
        except FileNotFoundError:
            return False
 | |
| 
 | |
| 
 | |
class FakePath:
    """Minimal object implementing the os.PathLike (__fspath__) protocol.

    __fspath__() returns whatever was passed to the constructor.  The
    exception is an exception instance or exception class: that is raised
    instead, which lets tests simulate a failing __fspath__().
    """

    def __init__(self, path):
        self.path = path

    def __repr__(self):
        return f'<FakePath {self.path!r}>'

    def __fspath__(self):
        value = self.path
        is_exc_instance = isinstance(value, BaseException)
        is_exc_class = (isinstance(value, type)
                        and issubclass(value, BaseException))
        if is_exc_instance or is_exc_class:
            raise value
        return value
 | |
| 
 | |
| 
 | |
def fd_count():
    """Count the number of open file descriptors.

    On Linux and FreeBSD the entries of /proc/self/fd are counted when that
    directory exists.  Otherwise every descriptor number below the
    SC_OPEN_MAX limit (256 if the limit is unavailable or indeterminate) is
    probed with os.dup().  A successful dup() counts as open, and EBADF
    means closed; any other OSError is re-raised.
    """
    if sys.platform.startswith(('linux', 'freebsd')):
        try:
            names = os.listdir("/proc/self/fd")
            # Subtract one because listdir() internally opens a file
            # descriptor to list the content of the /proc/self/fd/ directory.
            return len(names) - 1
        except FileNotFoundError:
            pass

    MAXFD = 256
    if hasattr(os, 'sysconf'):
        try:
            limit = os.sysconf("SC_OPEN_MAX")
        except OSError:
            pass
        else:
            # sysconf() reports -1 when the limit is indeterminate; scanning
            # range(-1) would silently return 0, so keep the default then.
            if limit > 0:
                MAXFD = limit

    old_modes = None
    if sys.platform == 'win32':
        # bpo-25306, bpo-31009: Call CrtSetReportMode() to not kill the process
        # on invalid file descriptor if Python is compiled in debug mode
        try:
            import msvcrt
            msvcrt.CrtSetReportMode
        except (AttributeError, ImportError):
            # no msvcrt or a release build
            pass
        else:
            old_modes = {}
            for report_type in (msvcrt.CRT_WARN,
                                msvcrt.CRT_ERROR,
                                msvcrt.CRT_ASSERT):
                old_modes[report_type] = msvcrt.CrtSetReportMode(report_type,
                                                                 0)

    try:
        count = 0
        for fd in range(MAXFD):
            try:
                # Prefer dup() over fstat(). fstat() can require input/output
                # whereas dup() doesn't.
                fd2 = os.dup(fd)
            except OSError as e:
                if e.errno != errno.EBADF:
                    raise
            else:
                os.close(fd2)
                count += 1
    finally:
        # Restore the CRT report modes saved above (Windows debug builds).
        if old_modes is not None:
            for report_type in (msvcrt.CRT_WARN,
                                msvcrt.CRT_ERROR,
                                msvcrt.CRT_ASSERT):
                msvcrt.CrtSetReportMode(report_type, old_modes[report_type])

    return count
 | |
| 
 | |
| 
 | |
if hasattr(os, "umask"):
    @contextlib.contextmanager
    def temp_umask(umask):
        """Context manager: set the process umask to *umask* for the block.

        The previous umask is restored on exit, even if the block raises.
        """
        saved_mask = os.umask(umask)
        try:
            yield
        finally:
            os.umask(saved_mask)
 | |
| 
 | |
| 
 | |
class EnvironmentVarGuard(collections.abc.MutableMapping):

    """Mutable-mapping view of os.environ that can undo its own changes.

    The value each variable had before it was first set or deleted through
    the guard is remembered (None if it was unset).  When used as a context
    manager, __exit__ puts those original values back.
    """

    def __init__(self):
        self._environ = os.environ
        # envvar -> value before the first change made through this guard;
        # None means the variable did not exist.
        self._changed = {}

    def __getitem__(self, envvar):
        return self._environ[envvar]

    def __setitem__(self, envvar, value):
        # setdefault keeps only the value seen before the first change.
        self._changed.setdefault(envvar, self._environ.get(envvar))
        self._environ[envvar] = value

    def __delitem__(self, envvar):
        self._changed.setdefault(envvar, self._environ.get(envvar))
        # Deleting a variable that is not set is a silent no-op.
        self._environ.pop(envvar, None)

    def keys(self):
        return self._environ.keys()

    def __iter__(self):
        return iter(self._environ)

    def __len__(self):
        return len(self._environ)

    def set(self, envvar, value):
        self[envvar] = value

    def unset(self, envvar):
        del self[envvar]

    def __enter__(self):
        return self

    def __exit__(self, *ignore_exc):
        for envvar, original in self._changed.items():
            if original is not None:
                self._environ[envvar] = original
            elif envvar in self._environ:
                del self._environ[envvar]
        os.environ = self._environ
 | 
