[3.9] gh-102950: Implement PEP 706 – Filter for tarfile.extractall (GH-102953) (#104382)

Backport of c8c3956d90
This commit is contained in:
Petr Viktorin 2023-05-15 18:53:58 +02:00 committed by GitHub
parent 7cb3a44747
commit 98016f7c92
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 1761 additions and 78 deletions

View file

@ -46,6 +46,7 @@
import struct
import copy
import re
import warnings
try:
import pwd
@ -71,6 +72,7 @@
"ENCODING", "USTAR_FORMAT", "GNU_FORMAT", "PAX_FORMAT",
"DEFAULT_FORMAT", "open"]
#---------------------------------------------------------
# tar constants
#---------------------------------------------------------
@ -158,6 +160,8 @@
def stn(s, length, encoding, errors):
    """Encode a string into a fixed-width, NUL-padded bytes field.

    `s' is encoded using `encoding' and `errors'.  The result is cut to
    at most `length' bytes and, when shorter, padded with NUL bytes up
    to `length'.  Raises ValueError if `s' is None.
    """
    if s is None:
        raise ValueError("metadata cannot contain None")
    encoded = s.encode(encoding, errors)
    padding = (length - len(encoded)) * NUL
    return encoded[:length] + padding
@ -708,9 +712,127 @@ def __init__(self, tarfile, tarinfo):
super().__init__(fileobj)
#class ExFileObject
#-----------------------------
# extraction filters (PEP 706)
#-----------------------------
class FilterError(TarError):
    """Base class for the errors raised by the extraction filters."""
class AbsolutePathError(FilterError):
    """Filter error for a member whose name is an absolute path.

    The offending member is available as the `tarinfo' attribute.
    """

    def __init__(self, tarinfo):
        self.tarinfo = tarinfo
        message = f'member {tarinfo.name!r} has an absolute path'
        super().__init__(message)
class OutsideDestinationError(FilterError):
    """Filter error for a member that would land outside the destination.

    The offending member is available as the `tarinfo' attribute; the
    resolved path is kept privately and only shown in the message.
    """

    def __init__(self, tarinfo, path):
        self.tarinfo = tarinfo
        self._path = path
        message = (f'{tarinfo.name!r} would be extracted to {path!r}, '
                   'which is outside the destination')
        super().__init__(message)
class SpecialFileError(FilterError):
    """Filter error for a member that is a special file.

    The offending member is available as the `tarinfo' attribute.
    """

    def __init__(self, tarinfo):
        self.tarinfo = tarinfo
        message = f'{tarinfo.name!r} is a special file'
        super().__init__(message)
class AbsoluteLinkError(FilterError):
    """Filter error for a link member whose target is an absolute path.

    The offending member is available as the `tarinfo' attribute.
    """

    def __init__(self, tarinfo):
        self.tarinfo = tarinfo
        message = f'{tarinfo.name!r} is a symlink to an absolute path'
        super().__init__(message)
class LinkOutsideDestinationError(FilterError):
    """Filter error for a link whose target lies outside the destination.

    The offending member is available as the `tarinfo' attribute; the
    resolved target path is kept privately and only shown in the message.
    """

    def __init__(self, tarinfo, path):
        self.tarinfo = tarinfo
        self._path = path
        message = (f'{tarinfo.name!r} would link to {path!r}, '
                   'which is outside the destination')
        super().__init__(message)
def _get_filtered_attrs(member, dest_path, for_data=True):
    """Return the attribute overrides an extraction filter applies to member.

    The result is a dict of keyword arguments for TarInfo.replace(); it is
    empty when the member needs no changes.  With `for_data' false only the
    name and mode are sanitized.  With `for_data' true the mode is also
    normalized per file type, ownership information is cleared, special
    files are rejected and link targets are checked.

    Raises AbsolutePathError or OutsideDestinationError for unsafe member
    names and, with `for_data' true, SpecialFileError, AbsoluteLinkError
    or LinkOutsideDestinationError.
    """
    new_attrs = {}
    name = member.name
    dest_path = os.path.realpath(dest_path)
    # Strip leading / (tar's directory separator) from filenames.
    # Include os.sep (target OS directory separator) as well.
    if name.startswith(('/', os.sep)):
        name = new_attrs['name'] = member.path.lstrip('/' + os.sep)
    if os.path.isabs(name):
        # Path is absolute even after stripping.
        # For example, 'C:/foo' on Windows.
        raise AbsolutePathError(member)
    # Ensure we stay in the destination
    target_path = os.path.realpath(os.path.join(dest_path, name))
    if os.path.commonpath([target_path, dest_path]) != dest_path:
        raise OutsideDestinationError(member, target_path)
    # Limit permissions (no high bits, and go-w)
    mode = member.mode
    if mode is not None:
        # Strip high bits & group/other write bits
        mode = mode & 0o755
        if for_data:
            # For data, handle permissions & file types
            if member.isreg() or member.islnk():
                if not mode & 0o100:
                    # Clear executable bits if not executable by user
                    mode &= ~0o111
                # Ensure owner can read & write
                mode |= 0o600
            elif member.isdir() or member.issym():
                # Ignore mode for directories & symlinks
                mode = None
            else:
                # Reject special files
                raise SpecialFileError(member)
        if mode != member.mode:
            new_attrs['mode'] = mode
    if for_data:
        # Ignore ownership for 'data'
        if member.uid is not None:
            new_attrs['uid'] = None
        if member.gid is not None:
            new_attrs['gid'] = None
        if member.uname is not None:
            new_attrs['uname'] = None
        if member.gname is not None:
            new_attrs['gname'] = None
        # Check link destination for 'data'
        if member.islnk() or member.issym():
            if os.path.isabs(member.linkname):
                raise AbsoluteLinkError(member)
            if member.issym():
                # A symlink's target is interpreted relative to the
                # directory that contains the link itself.
                link_base = os.path.join(dest_path, os.path.dirname(name))
            else:
                # A hard link's target is joined to the extraction path.
                link_base = dest_path
            target_path = os.path.realpath(
                os.path.join(link_base, member.linkname))
            if os.path.commonpath([target_path, dest_path]) != dest_path:
                raise LinkOutsideDestinationError(member, target_path)
    return new_attrs
def fully_trusted_filter(member, dest_path):
    """Extraction filter that accepts every member as-is.

    `dest_path' is not used; `member' is returned unchanged.
    """
    return member
def tar_filter(member, dest_path):
    """Extraction filter applying the checks done with for_data=False.

    Returns `member' itself when nothing has to change, otherwise a
    shallow copy of it carrying the adjusted attributes.
    """
    changes = _get_filtered_attrs(member, dest_path, False)
    if not changes:
        return member
    return member.replace(**changes, deep=False)
def data_filter(member, dest_path):
    """Extraction filter applying the checks done with for_data=True.

    Returns `member' itself when nothing has to change, otherwise a
    shallow copy of it carrying the adjusted attributes.
    """
    changes = _get_filtered_attrs(member, dest_path, True)
    if not changes:
        return member
    return member.replace(**changes, deep=False)
# Filter functions that can be selected by their string name.
_NAMED_FILTERS = dict(
    fully_trusted=fully_trusted_filter,
    tar=tar_filter,
    data=data_filter,
)
#------------------
# Exported Classes
#------------------
# Sentinel for replace() defaults, meaning "don't change the attribute".
# A dedicated object is needed because None is itself a value that the
# filters assign to attributes (e.g. to clear uid/gid).
_KEEP = object()
class TarInfo(object):
"""Informational class which holds the details about an
archive member given by a tar header block.
@ -791,12 +913,44 @@ def linkpath(self, linkname):
def __repr__(self):
    """Return a debug representation: class name, member name, object id."""
    class_name = self.__class__.__name__
    return "<%s %r at %#x>" % (class_name, self.name, id(self))
def replace(self, *,
            name=_KEEP, mtime=_KEEP, mode=_KEEP, linkname=_KEEP,
            uid=_KEEP, gid=_KEEP, uname=_KEEP, gname=_KEEP,
            deep=True, _KEEP=_KEEP):
    """Return a copy of self with the given attributes replaced.

    The copy is deep by default and shallow when `deep' is false.
    Attributes whose argument is left at its default keep the value
    they have on self.
    """
    duplicate = copy.deepcopy(self) if deep else copy.copy(self)
    requested = (
        ("name", name),
        ("mtime", mtime),
        ("mode", mode),
        ("linkname", linkname),
        ("uid", uid),
        ("gid", gid),
        ("uname", uname),
        ("gname", gname),
    )
    for attribute, value in requested:
        if value is not _KEEP:
            setattr(duplicate, attribute, value)
    return duplicate
def get_info(self):
"""Return the TarInfo's attributes as a dictionary.
"""
if self.mode is None:
mode = None
else:
mode = self.mode & 0o7777
info = {
"name": self.name,
"mode": self.mode & 0o7777,
"mode": mode,
"uid": self.uid,
"gid": self.gid,
"size": self.size,
@ -819,6 +973,9 @@ def tobuf(self, format=DEFAULT_FORMAT, encoding=ENCODING, errors="surrogateescap
"""Return a tar header as a string of 512 byte blocks.
"""
info = self.get_info()
for name, value in info.items():
if value is None:
raise ValueError("%s may not be None" % name)
if format == USTAR_FORMAT:
return self.create_ustar_header(info, encoding, errors)
@ -949,6 +1106,12 @@ def _create_header(info, format, encoding, errors):
devmajor = stn("", 8, encoding, errors)
devminor = stn("", 8, encoding, errors)
# None values in metadata should cause ValueError.
# itn()/stn() do this for all fields except type.
filetype = info.get("type", REGTYPE)
if filetype is None:
raise ValueError("TarInfo.type must not be None")
parts = [
stn(info.get("name", ""), 100, encoding, errors),
itn(info.get("mode", 0) & 0o7777, 8, format),
@ -957,7 +1120,7 @@ def _create_header(info, format, encoding, errors):
itn(info.get("size", 0), 12, format),
itn(info.get("mtime", 0), 12, format),
b" ", # checksum field
info.get("type", REGTYPE),
filetype,
stn(info.get("linkname", ""), 100, encoding, errors),
info.get("magic", POSIX_MAGIC),
stn(info.get("uname", ""), 32, encoding, errors),
@ -1457,6 +1620,8 @@ class TarFile(object):
fileobject = ExFileObject # The file-object for extractfile().
extraction_filter = None # The default filter for extraction.
def __init__(self, name=None, mode="r", fileobj=None, format=None,
tarinfo=None, dereference=None, ignore_zeros=None, encoding=None,
errors="surrogateescape", pax_headers=None, debug=None,
@ -1926,7 +2091,10 @@ def list(self, verbose=True, *, members=None):
members = self
for tarinfo in members:
if verbose:
_safe_print(stat.filemode(tarinfo.mode))
if tarinfo.mode is None:
_safe_print("??????????")
else:
_safe_print(stat.filemode(tarinfo.mode))
_safe_print("%s/%s" % (tarinfo.uname or tarinfo.uid,
tarinfo.gname or tarinfo.gid))
if tarinfo.ischr() or tarinfo.isblk():
@ -1934,8 +2102,11 @@ def list(self, verbose=True, *, members=None):
("%d,%d" % (tarinfo.devmajor, tarinfo.devminor)))
else:
_safe_print("%10d" % tarinfo.size)
_safe_print("%d-%02d-%02d %02d:%02d:%02d" \
% time.localtime(tarinfo.mtime)[:6])
if tarinfo.mtime is None:
_safe_print("????-??-?? ??:??:??")
else:
_safe_print("%d-%02d-%02d %02d:%02d:%02d" \
% time.localtime(tarinfo.mtime)[:6])
_safe_print(tarinfo.name + ("/" if tarinfo.isdir() else ""))
@ -2022,32 +2193,58 @@ def addfile(self, tarinfo, fileobj=None):
self.members.append(tarinfo)
def extractall(self, path=".", members=None, *, numeric_owner=False):
def _get_filter_function(self, filter):
    """Resolve a `filter' argument to the filter function to call.

    An explicit `filter' may be a callable, which is returned as-is, or
    the name of a filter in _NAMED_FILTERS; an unknown name raises
    ValueError.  When `filter' is None, self.extraction_filter is used:
    None selects fully_trusted_filter, a string raises TypeError, and
    any other value is returned unchanged.
    """
    if filter is not None:
        if callable(filter):
            return filter
        try:
            return _NAMED_FILTERS[filter]
        except KeyError:
            raise ValueError(f"filter {filter!r} not found") from None

    default = self.extraction_filter
    if default is None:
        return fully_trusted_filter
    if isinstance(default, str):
        raise TypeError(
            'String names are not supported for '
            + 'TarFile.extraction_filter. Use a function such as '
            + 'tarfile.data_filter directly.')
    return default
def extractall(self, path=".", members=None, *, numeric_owner=False,
filter=None):
"""Extract all members from the archive to the current working
directory and set owner, modification time and permissions on
directories afterwards. `path' specifies a different directory
to extract to. `members' is optional and must be a subset of the
list returned by getmembers(). If `numeric_owner` is True, only
the numbers for user/group names are used and not the names.
The `filter` function will be called on each member just
before extraction.
It can return a changed TarInfo or None to skip the member.
String names of common filters are accepted.
"""
directories = []
filter_function = self._get_filter_function(filter)
if members is None:
members = self
for tarinfo in members:
for member in members:
tarinfo = self._get_extract_tarinfo(member, filter_function, path)
if tarinfo is None:
continue
if tarinfo.isdir():
# Extract directories with a safe mode.
# For directories, delay setting attributes until later,
# since permissions can interfere with extraction and
# extracting contents can reset mtime.
directories.append(tarinfo)
tarinfo = copy.copy(tarinfo)
tarinfo.mode = 0o700
# Do not set_attrs directories, as we will do that further down
self.extract(tarinfo, path, set_attrs=not tarinfo.isdir(),
numeric_owner=numeric_owner)
self._extract_one(tarinfo, path, set_attrs=not tarinfo.isdir(),
numeric_owner=numeric_owner)
# Reverse sort directories.
directories.sort(key=lambda a: a.name)
directories.reverse()
directories.sort(key=lambda a: a.name, reverse=True)
# Set correct owner, mtime and filemode on directories.
for tarinfo in directories:
@ -2057,12 +2254,10 @@ def extractall(self, path=".", members=None, *, numeric_owner=False):
self.utime(tarinfo, dirpath)
self.chmod(tarinfo, dirpath)
except ExtractError as e:
if self.errorlevel > 1:
raise
else:
self._dbg(1, "tarfile: %s" % e)
self._handle_nonfatal_error(e)
def extract(self, member, path="", set_attrs=True, *, numeric_owner=False):
def extract(self, member, path="", set_attrs=True, *, numeric_owner=False,
filter=None):
"""Extract a member from the archive to the current working directory,
using its full name. Its file information is extracted as accurately
as possible. `member' may be a filename or a TarInfo object. You can
@ -2070,35 +2265,70 @@ def extract(self, member, path="", set_attrs=True, *, numeric_owner=False):
mtime, mode) are set unless `set_attrs' is False. If `numeric_owner`
is True, only the numbers for user/group names are used and not
the names.
"""
self._check("r")
The `filter` function will be called before extraction.
It can return a changed TarInfo or None to skip the member.
String names of common filters are accepted.
"""
filter_function = self._get_filter_function(filter)
tarinfo = self._get_extract_tarinfo(member, filter_function, path)
if tarinfo is not None:
self._extract_one(tarinfo, path, set_attrs, numeric_owner)
def _get_extract_tarinfo(self, member, filter_function, path):
    """Get filtered TarInfo (or None) from member, which might be a str.

    `member' is looked up with getmember() when given by name, then
    passed to `filter_function' together with `path'.  Returns None when
    the member is excluded: either the filter returned None, or it
    raised an error that the errorlevel-aware handlers only logged.
    For hard links, a copy with `_link_target' set is returned.
    """
    if isinstance(member, str):
        unfiltered = self.getmember(member)
    else:
        unfiltered = member

    # Start from "excluded": if the filter raises and the handler merely
    # logs the error, the member must be skipped, never extracted in its
    # unfiltered form.
    filtered = None
    try:
        filtered = filter_function(unfiltered, path)
    except (OSError, FilterError) as e:
        self._handle_fatal_error(e)
    except ExtractError as e:
        self._handle_nonfatal_error(e)
    if filtered is None:
        self._dbg(2, "tarfile: Excluded %r" % unfiltered.name)
        return None
    # Prepare the link target for makelink().
    if filtered.islnk():
        filtered = copy.copy(filtered)
        filtered._link_target = os.path.join(path, filtered.linkname)
    return filtered
def _extract_one(self, tarinfo, path, set_attrs, numeric_owner):
    """Extract from filtered tarinfo to disk.

    The member is extracted to `path' joined with its name.  An OSError
    is passed to _handle_fatal_error() and an ExtractError to
    _handle_nonfatal_error(); those helpers alone decide, based on
    errorlevel, whether the error is re-raised or logged.
    """
    self._check("r")

    try:
        self._extract_member(tarinfo, os.path.join(path, tarinfo.name),
                             set_attrs=set_attrs,
                             numeric_owner=numeric_owner)
    except OSError as e:
        self._handle_fatal_error(e)
    except ExtractError as e:
        self._handle_nonfatal_error(e)
def _handle_nonfatal_error(self, e):
    """Handle non-fatal error (ExtractError) according to errorlevel.

    With errorlevel at most 1 the error is written as a level-1 debug
    message; otherwise the exception currently being handled is
    re-raised, so this must be called from within an except block.
    """
    if self.errorlevel <= 1:
        self._dbg(1, "tarfile: %s" % e)
        return
    raise
def _handle_fatal_error(self, e):
    """Handle "fatal" error according to self.errorlevel.

    With errorlevel above 0 the exception currently being handled is
    re-raised, so this must be called from within an except block.
    Otherwise exactly one level-1 debug message is written: strerror
    (plus the filename, when set) for an OSError, or the exception's
    type name and text for anything else.
    """
    if self.errorlevel > 0:
        raise
    elif isinstance(e, OSError):
        if e.filename is None:
            self._dbg(1, "tarfile: %s" % e.strerror)
        else:
            self._dbg(1, "tarfile: %s %r" % (e.strerror, e.filename))
    else:
        self._dbg(1, "tarfile: %s %s" % (type(e).__name__, e))
def extractfile(self, member):
"""Extract a member from the archive as a file object. `member' may be
@ -2185,9 +2415,13 @@ def makedir(self, tarinfo, targetpath):
"""Make a directory called targetpath.
"""
try:
# Use a safe mode for the directory, the real mode is set
# later in _extract_member().
os.mkdir(targetpath, 0o700)
if tarinfo.mode is None:
# Use the system's default mode
os.mkdir(targetpath)
else:
# Use a safe mode for the directory, the real mode is set
# later in _extract_member().
os.mkdir(targetpath, 0o700)
except FileExistsError:
pass
@ -2230,6 +2464,9 @@ def makedev(self, tarinfo, targetpath):
raise ExtractError("special devices not supported by system")
mode = tarinfo.mode
if mode is None:
# Use mknod's default
mode = 0o600
if tarinfo.isblk():
mode |= stat.S_IFBLK
else:
@ -2251,7 +2488,6 @@ def makelink(self, tarinfo, targetpath):
os.unlink(targetpath)
os.symlink(tarinfo.linkname, targetpath)
else:
# See extract().
if os.path.exists(tarinfo._link_target):
os.link(tarinfo._link_target, targetpath)
else:
@ -2276,15 +2512,19 @@ def chown(self, tarinfo, targetpath, numeric_owner):
u = tarinfo.uid
if not numeric_owner:
try:
if grp:
if grp and tarinfo.gname:
g = grp.getgrnam(tarinfo.gname)[2]
except KeyError:
pass
try:
if pwd:
if pwd and tarinfo.uname:
u = pwd.getpwnam(tarinfo.uname)[2]
except KeyError:
pass
if g is None:
g = -1
if u is None:
u = -1
try:
if tarinfo.issym() and hasattr(os, "lchown"):
os.lchown(targetpath, u, g)
@ -2296,6 +2536,8 @@ def chown(self, tarinfo, targetpath, numeric_owner):
def chmod(self, tarinfo, targetpath):
"""Set file permissions of targetpath according to tarinfo.
"""
if tarinfo.mode is None:
return
try:
os.chmod(targetpath, tarinfo.mode)
except OSError:
@ -2304,10 +2546,13 @@ def chmod(self, tarinfo, targetpath):
def utime(self, tarinfo, targetpath):
    """Set modification time of targetpath according to tarinfo.

    Does nothing when tarinfo.mtime is None or the platform has no
    os.utime.  Access and modification time are both set to the
    member's mtime.  Raises ExtractError, chained to the underlying
    OSError, if the time cannot be changed.
    """
    mtime = tarinfo.mtime
    if mtime is None:
        return
    if not hasattr(os, 'utime'):
        return
    try:
        os.utime(targetpath, (mtime, mtime))
    except OSError as exc:
        raise ExtractError("could not change modification time") from exc
@ -2383,13 +2628,26 @@ def _getmember(self, name, tarinfo=None, normalize=False):
members = self.getmembers()
# Limit the member search list up to tarinfo.
skipping = False
if tarinfo is not None:
members = members[:members.index(tarinfo)]
try:
index = members.index(tarinfo)
except ValueError:
# The given starting point might be a (modified) copy.
# We'll later skip members until we find an equivalent.
skipping = True
else:
# Happy fast path
members = members[:index]
if normalize:
name = os.path.normpath(name)
for member in reversed(members):
if skipping:
if tarinfo.offset == member.offset:
skipping = False
continue
if normalize:
member_name = os.path.normpath(member.name)
else:
@ -2398,6 +2656,10 @@ def _getmember(self, name, tarinfo=None, normalize=False):
if name == member_name:
return member
if skipping:
# Starting point was not found
raise ValueError(tarinfo)
def _load(self):
"""Read through the entire archive file and look for readable
members.
@ -2490,6 +2752,7 @@ def __exit__(self, type, value, traceback):
#--------------------
# exported functions
#--------------------
def is_tarfile(name):
"""Return True if name points to a tar archive that we
are able to handle, else return False.
@ -2516,6 +2779,10 @@ def main():
parser = argparse.ArgumentParser(description=description)
parser.add_argument('-v', '--verbose', action='store_true', default=False,
help='Verbose output')
parser.add_argument('--filter', metavar='<filtername>',
choices=_NAMED_FILTERS,
help='Filter for extraction')
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('-l', '--list', metavar='<tarfile>',
help='Show listing of a tarfile')
@ -2527,8 +2794,12 @@ def main():
help='Create tarfile from sources')
group.add_argument('-t', '--test', metavar='<tarfile>',
help='Test if a tarfile is valid')
args = parser.parse_args()
if args.filter and args.extract is None:
parser.exit(1, '--filter is only valid for extraction\n')
if args.test is not None:
src = args.test
if is_tarfile(src):
@ -2559,7 +2830,7 @@ def main():
if is_tarfile(src):
with TarFile.open(src, 'r:*') as tf:
tf.extractall(path=curdir)
tf.extractall(path=curdir, filter=args.filter)
if args.verbose:
if curdir == '.':
msg = '{!r} file is extracted.'.format(src)