Completed str/unicode unification.

All tests pass, but maybe some tests have become unnecessary now.
Removed PaxUnicodeTest, added MiscTest.

TarFile.extractfile() returns a binary file object which can be used
with a TextIOWrapper for text I/O.
This commit is contained in:
Lars Gustäbel 2007-08-07 18:36:16 +00:00
parent cd869d8d41
commit b506dc32c1
2 changed files with 171 additions and 212 deletions

View file

@ -72,33 +72,33 @@
#--------------------------------------------------------- #---------------------------------------------------------
# tar constants # tar constants
#--------------------------------------------------------- #---------------------------------------------------------
NUL = "\0" # the null character NUL = b"\0" # the null character
BLOCKSIZE = 512 # length of processing blocks BLOCKSIZE = 512 # length of processing blocks
RECORDSIZE = BLOCKSIZE * 20 # length of records RECORDSIZE = BLOCKSIZE * 20 # length of records
GNU_MAGIC = "ustar \0" # magic gnu tar string GNU_MAGIC = b"ustar \0" # magic gnu tar string
POSIX_MAGIC = "ustar\x0000" # magic posix tar string POSIX_MAGIC = b"ustar\x0000" # magic posix tar string
LENGTH_NAME = 100 # maximum length of a filename LENGTH_NAME = 100 # maximum length of a filename
LENGTH_LINK = 100 # maximum length of a linkname LENGTH_LINK = 100 # maximum length of a linkname
LENGTH_PREFIX = 155 # maximum length of the prefix field LENGTH_PREFIX = 155 # maximum length of the prefix field
REGTYPE = "0" # regular file REGTYPE = b"0" # regular file
AREGTYPE = "\0" # regular file AREGTYPE = b"\0" # regular file
LNKTYPE = "1" # link (inside tarfile) LNKTYPE = b"1" # link (inside tarfile)
SYMTYPE = "2" # symbolic link SYMTYPE = b"2" # symbolic link
CHRTYPE = "3" # character special device CHRTYPE = b"3" # character special device
BLKTYPE = "4" # block special device BLKTYPE = b"4" # block special device
DIRTYPE = "5" # directory DIRTYPE = b"5" # directory
FIFOTYPE = "6" # fifo special device FIFOTYPE = b"6" # fifo special device
CONTTYPE = "7" # contiguous file CONTTYPE = b"7" # contiguous file
GNUTYPE_LONGNAME = "L" # GNU tar longname GNUTYPE_LONGNAME = b"L" # GNU tar longname
GNUTYPE_LONGLINK = "K" # GNU tar longlink GNUTYPE_LONGLINK = b"K" # GNU tar longlink
GNUTYPE_SPARSE = "S" # GNU tar sparse file GNUTYPE_SPARSE = b"S" # GNU tar sparse file
XHDTYPE = "x" # POSIX.1-2001 extended header XHDTYPE = b"x" # POSIX.1-2001 extended header
XGLTYPE = "g" # POSIX.1-2001 global header XGLTYPE = b"g" # POSIX.1-2001 global header
SOLARIS_XHDTYPE = "X" # Solaris extended header SOLARIS_XHDTYPE = b"X" # Solaris extended header
USTAR_FORMAT = 0 # POSIX.1-1988 (ustar) format USTAR_FORMAT = 0 # POSIX.1-1988 (ustar) format
GNU_FORMAT = 1 # GNU tar format GNU_FORMAT = 1 # GNU tar format
@ -173,19 +173,19 @@
# Some useful functions # Some useful functions
#--------------------------------------------------------- #---------------------------------------------------------
def stn(s, length): def stn(s, length, encoding, errors):
"""Convert a python string to a null-terminated string buffer. """Convert a string to a null-terminated bytes object.
""" """
s = s.encode(encoding, errors)
return s[:length] + (length - len(s)) * NUL return s[:length] + (length - len(s)) * NUL
def nts(s): def nts(s, encoding, errors):
"""Convert a null-terminated string field to a python string. """Convert a null-terminated bytes object to a string.
""" """
# Use the string up to the first null char. p = s.find(b"\0")
p = s.find("\0") if p != -1:
if p == -1: s = s[:p]
return s return s.decode(encoding, errors)
return s[:p]
def nti(s): def nti(s):
"""Convert a number field to a python number. """Convert a number field to a python number.
@ -194,7 +194,7 @@ def nti(s):
# itn() below. # itn() below.
if s[0] != chr(0o200): if s[0] != chr(0o200):
try: try:
n = int(nts(s) or "0", 8) n = int(nts(s, "ascii", "strict") or "0", 8)
except ValueError: except ValueError:
raise HeaderError("invalid header") raise HeaderError("invalid header")
else: else:
@ -214,7 +214,7 @@ def itn(n, digits=8, format=DEFAULT_FORMAT):
# encoding, the following digits-1 bytes are a big-endian # encoding, the following digits-1 bytes are a big-endian
# representation. This allows values up to (256**(digits-1))-1. # representation. This allows values up to (256**(digits-1))-1.
if 0 <= n < 8 ** (digits - 1): if 0 <= n < 8 ** (digits - 1):
s = "%0*o" % (digits - 1, n) + NUL s = bytes("%0*o" % (digits - 1, n)) + NUL
else: else:
if format != GNU_FORMAT or n >= 256 ** (digits - 1): if format != GNU_FORMAT or n >= 256 ** (digits - 1):
raise ValueError("overflow in number field") raise ValueError("overflow in number field")
@ -224,33 +224,13 @@ def itn(n, digits=8, format=DEFAULT_FORMAT):
# this could raise OverflowError. # this could raise OverflowError.
n = struct.unpack("L", struct.pack("l", n))[0] n = struct.unpack("L", struct.pack("l", n))[0]
s = "" s = b""
for i in range(digits - 1): for i in range(digits - 1):
s = chr(n & 0o377) + s s.insert(0, n & 0o377)
n >>= 8 n >>= 8
s = chr(0o200) + s s.insert(0, 0o200)
return s return s
def uts(s, encoding, errors):
"""Convert a unicode object to a string.
"""
if errors == "utf-8":
# An extra error handler similar to the -o invalid=UTF-8 option
# in POSIX.1-2001. Replace untranslatable characters with their
# UTF-8 representation.
try:
return s.encode(encoding, "strict")
except UnicodeEncodeError:
x = []
for c in s:
try:
x.append(c.encode(encoding, "strict"))
except UnicodeEncodeError:
x.append(c.encode("utf8"))
return "".join(x)
else:
return s.encode(encoding, errors)
def calc_chksums(buf): def calc_chksums(buf):
"""Calculate the checksum for a member's header by summing up all """Calculate the checksum for a member's header by summing up all
characters except for the chksum field which is treated as if characters except for the chksum field which is treated as if
@ -412,7 +392,7 @@ def __init__(self, name, mode, comptype, fileobj, bufsize):
self.comptype = comptype self.comptype = comptype
self.fileobj = fileobj self.fileobj = fileobj
self.bufsize = bufsize self.bufsize = bufsize
self.buf = "" self.buf = b""
self.pos = 0 self.pos = 0
self.closed = False self.closed = False
@ -434,7 +414,7 @@ def __init__(self, name, mode, comptype, fileobj, bufsize):
except ImportError: except ImportError:
raise CompressionError("bz2 module is not available") raise CompressionError("bz2 module is not available")
if mode == "r": if mode == "r":
self.dbuf = "" self.dbuf = b""
self.cmp = bz2.BZ2Decompressor() self.cmp = bz2.BZ2Decompressor()
else: else:
self.cmp = bz2.BZ2Compressor() self.cmp = bz2.BZ2Compressor()
@ -451,10 +431,11 @@ def _init_write_gz(self):
self.zlib.DEF_MEM_LEVEL, self.zlib.DEF_MEM_LEVEL,
0) 0)
timestamp = struct.pack("<L", int(time.time())) timestamp = struct.pack("<L", int(time.time()))
self.__write("\037\213\010\010%s\002\377" % timestamp) self.__write(b"\037\213\010\010" + timestamp + b"\002\377")
if self.name.endswith(".gz"): if self.name.endswith(".gz"):
self.name = self.name[:-3] self.name = self.name[:-3]
self.__write(self.name + NUL) # RFC1952 says we must use ISO-8859-1 for the FNAME field.
self.__write(self.name.encode("iso-8859-1", "replace") + NUL)
def write(self, s): def write(self, s):
"""Write string s to the stream. """Write string s to the stream.
@ -487,7 +468,7 @@ def close(self):
if self.mode == "w" and self.buf: if self.mode == "w" and self.buf:
self.fileobj.write(self.buf) self.fileobj.write(self.buf)
self.buf = "" self.buf = b""
if self.comptype == "gz": if self.comptype == "gz":
# The native zlib crc is an unsigned 32-bit integer, but # The native zlib crc is an unsigned 32-bit integer, but
# the Python wrapper implicitly casts that to a signed C # the Python wrapper implicitly casts that to a signed C
@ -507,12 +488,12 @@ def _init_read_gz(self):
"""Initialize for reading a gzip compressed fileobj. """Initialize for reading a gzip compressed fileobj.
""" """
self.cmp = self.zlib.decompressobj(-self.zlib.MAX_WBITS) self.cmp = self.zlib.decompressobj(-self.zlib.MAX_WBITS)
self.dbuf = "" self.dbuf = b""
# taken from gzip.GzipFile with some alterations # taken from gzip.GzipFile with some alterations
if self.__read(2) != "\037\213": if self.__read(2) != b"\037\213":
raise ReadError("not a gzip file") raise ReadError("not a gzip file")
if self.__read(1) != "\010": if self.__read(1) != b"\010":
raise CompressionError("unsupported compression method") raise CompressionError("unsupported compression method")
flag = ord(self.__read(1)) flag = ord(self.__read(1))
@ -577,7 +558,6 @@ def _read(self, size):
return self.__read(size) return self.__read(size)
c = len(self.dbuf) c = len(self.dbuf)
t = [self.dbuf]
while c < size: while c < size:
buf = self.__read(self.bufsize) buf = self.__read(self.bufsize)
if not buf: if not buf:
@ -586,27 +566,26 @@ def _read(self, size):
buf = self.cmp.decompress(buf) buf = self.cmp.decompress(buf)
except IOError: except IOError:
raise ReadError("invalid compressed data") raise ReadError("invalid compressed data")
t.append(buf) self.dbuf += buf
c += len(buf) c += len(buf)
t = "".join(t) buf = self.dbuf[:size]
self.dbuf = t[size:] self.dbuf = self.dbuf[size:]
return t[:size] return buf
def __read(self, size): def __read(self, size):
"""Return size bytes from stream. If internal buffer is empty, """Return size bytes from stream. If internal buffer is empty,
read another block from the stream. read another block from the stream.
""" """
c = len(self.buf) c = len(self.buf)
t = [self.buf]
while c < size: while c < size:
buf = self.fileobj.read(self.bufsize) buf = self.fileobj.read(self.bufsize)
if not buf: if not buf:
break break
t.append(buf) self.buf += buf
c += len(buf) c += len(buf)
t = "".join(t) buf = self.buf[:size]
self.buf = t[size:] self.buf = self.buf[size:]
return t[:size] return buf
# class _Stream # class _Stream
class _StreamProxy(object): class _StreamProxy(object):
@ -623,7 +602,7 @@ def read(self, size):
return self.buf return self.buf
def getcomptype(self): def getcomptype(self):
if self.buf.startswith("\037\213\010"): if self.buf.startswith(b"\037\213\010"):
return "gz" return "gz"
if self.buf.startswith("BZh91"): if self.buf.startswith("BZh91"):
return "bz2" return "bz2"
@ -655,22 +634,20 @@ def init(self):
if self.mode == "r": if self.mode == "r":
self.bz2obj = bz2.BZ2Decompressor() self.bz2obj = bz2.BZ2Decompressor()
self.fileobj.seek(0) self.fileobj.seek(0)
self.buf = "" self.buf = b""
else: else:
self.bz2obj = bz2.BZ2Compressor() self.bz2obj = bz2.BZ2Compressor()
def read(self, size): def read(self, size):
b = [self.buf]
x = len(self.buf) x = len(self.buf)
while x < size: while x < size:
try: try:
raw = self.fileobj.read(self.blocksize) raw = self.fileobj.read(self.blocksize)
data = self.bz2obj.decompress(raw) data = self.bz2obj.decompress(raw)
b.append(data) self.buf += data
except EOFError: except EOFError:
break break
x += len(data) x += len(data)
self.buf = "".join(b)
buf = self.buf[:size] buf = self.buf[:size]
self.buf = self.buf[size:] self.buf = self.buf[size:]
@ -713,6 +690,12 @@ def __init__(self, fileobj, offset, size, sparse=None):
self.sparse = sparse self.sparse = sparse
self.position = 0 self.position = 0
def seekable(self):
if not hasattr(self.fileobj, "seekable"):
# XXX gzip.GzipFile and bz2.BZ2File
return True
return self.fileobj.seekable()
def tell(self): def tell(self):
"""Return the current file position. """Return the current file position.
""" """
@ -746,14 +729,14 @@ def readnormal(self, size):
def readsparse(self, size): def readsparse(self, size):
"""Read operation for sparse files. """Read operation for sparse files.
""" """
data = [] data = b""
while size > 0: while size > 0:
buf = self.readsparsesection(size) buf = self.readsparsesection(size)
if not buf: if not buf:
break break
size -= len(buf) size -= len(buf)
data.append(buf) data += buf
return "".join(data) return data
def readsparsesection(self, size): def readsparsesection(self, size):
"""Read a single section of a sparse file. """Read a single section of a sparse file.
@ -761,7 +744,7 @@ def readsparsesection(self, size):
section = self.sparse.find(self.position) section = self.sparse.find(self.position)
if section is None: if section is None:
return "" return b""
size = min(size, section.offset + section.size - self.position) size = min(size, section.offset + section.size - self.position)
@ -793,7 +776,16 @@ def __init__(self, tarfile, tarinfo):
self.size = tarinfo.size self.size = tarinfo.size
self.position = 0 self.position = 0
self.buffer = "" self.buffer = b""
def readable(self):
return True
def writable(self):
return False
def seekable(self):
return self.fileobj.seekable()
def read(self, size=None): def read(self, size=None):
"""Read at most size bytes from the file. If size is not """Read at most size bytes from the file. If size is not
@ -802,11 +794,11 @@ def read(self, size=None):
if self.closed: if self.closed:
raise ValueError("I/O operation on closed file") raise ValueError("I/O operation on closed file")
buf = "" buf = b""
if self.buffer: if self.buffer:
if size is None: if size is None:
buf = self.buffer buf = self.buffer
self.buffer = "" self.buffer = b""
else: else:
buf = self.buffer[:size] buf = self.buffer[:size]
self.buffer = self.buffer[size:] self.buffer = self.buffer[size:]
@ -819,6 +811,9 @@ def read(self, size=None):
self.position += len(buf) self.position += len(buf)
return buf return buf
# XXX TextIOWrapper uses the read1() method.
read1 = read
def readline(self, size=-1): def readline(self, size=-1):
"""Read one entire line from the file. If size is present """Read one entire line from the file. If size is present
and non-negative, return a string with at most that and non-negative, return a string with at most that
@ -827,16 +822,14 @@ def readline(self, size=-1):
if self.closed: if self.closed:
raise ValueError("I/O operation on closed file") raise ValueError("I/O operation on closed file")
if "\n" in self.buffer: pos = self.buffer.find(b"\n") + 1
pos = self.buffer.find("\n") + 1 if pos == 0:
else: # no newline found.
buffers = [self.buffer]
while True: while True:
buf = self.fileobj.read(self.blocksize) buf = self.fileobj.read(self.blocksize)
buffers.append(buf) self.buffer += buf
if not buf or "\n" in buf: if not buf or b"\n" in buf:
self.buffer = "".join(buffers) pos = self.buffer.find(b"\n") + 1
pos = self.buffer.find("\n") + 1
if pos == 0: if pos == 0:
# no newline found. # no newline found.
pos = len(self.buffer) pos = len(self.buffer)
@ -886,7 +879,7 @@ def seek(self, pos, whence=os.SEEK_SET):
else: else:
raise ValueError("Invalid argument") raise ValueError("Invalid argument")
self.buffer = "" self.buffer = b""
self.fileobj.seek(self.position) self.fileobj.seek(self.position)
def close(self): def close(self):
@ -955,7 +948,7 @@ def _setlinkpath(self, linkname):
def __repr__(self): def __repr__(self):
return "<%s %r at %#x>" % (self.__class__.__name__,self.name,id(self)) return "<%s %r at %#x>" % (self.__class__.__name__,self.name,id(self))
def get_info(self, encoding, errors): def get_info(self):
"""Return the TarInfo's attributes as a dictionary. """Return the TarInfo's attributes as a dictionary.
""" """
info = { info = {
@ -977,27 +970,23 @@ def get_info(self, encoding, errors):
if info["type"] == DIRTYPE and not info["name"].endswith("/"): if info["type"] == DIRTYPE and not info["name"].endswith("/"):
info["name"] += "/" info["name"] += "/"
for key in ("name", "linkname", "uname", "gname"):
if isinstance(info[key], str):
info[key] = info[key].encode(encoding, errors)
return info return info
def tobuf(self, format=DEFAULT_FORMAT, encoding=ENCODING, errors="strict"): def tobuf(self, format=DEFAULT_FORMAT, encoding=ENCODING, errors="strict"):
"""Return a tar header as a string of 512 byte blocks. """Return a tar header as a string of 512 byte blocks.
""" """
info = self.get_info(encoding, errors) info = self.get_info()
if format == USTAR_FORMAT: if format == USTAR_FORMAT:
return self.create_ustar_header(info) return self.create_ustar_header(info, encoding, errors)
elif format == GNU_FORMAT: elif format == GNU_FORMAT:
return self.create_gnu_header(info) return self.create_gnu_header(info, encoding, errors)
elif format == PAX_FORMAT: elif format == PAX_FORMAT:
return self.create_pax_header(info, encoding, errors) return self.create_pax_header(info, encoding, errors)
else: else:
raise ValueError("invalid format") raise ValueError("invalid format")
def create_ustar_header(self, info): def create_ustar_header(self, info, encoding, errors):
"""Return the object as a ustar header block. """Return the object as a ustar header block.
""" """
info["magic"] = POSIX_MAGIC info["magic"] = POSIX_MAGIC
@ -1008,21 +997,21 @@ def create_ustar_header(self, info):
if len(info["name"]) > LENGTH_NAME: if len(info["name"]) > LENGTH_NAME:
info["prefix"], info["name"] = self._posix_split_name(info["name"]) info["prefix"], info["name"] = self._posix_split_name(info["name"])
return self._create_header(info, USTAR_FORMAT) return self._create_header(info, USTAR_FORMAT, encoding, errors)
def create_gnu_header(self, info): def create_gnu_header(self, info, encoding, errors):
"""Return the object as a GNU header block sequence. """Return the object as a GNU header block sequence.
""" """
info["magic"] = GNU_MAGIC info["magic"] = GNU_MAGIC
buf = "" buf = b""
if len(info["linkname"]) > LENGTH_LINK: if len(info["linkname"]) > LENGTH_LINK:
buf += self._create_gnu_long_header(info["linkname"], GNUTYPE_LONGLINK) buf += self._create_gnu_long_header(info["linkname"], GNUTYPE_LONGLINK, encoding, errors)
if len(info["name"]) > LENGTH_NAME: if len(info["name"]) > LENGTH_NAME:
buf += self._create_gnu_long_header(info["name"], GNUTYPE_LONGNAME) buf += self._create_gnu_long_header(info["name"], GNUTYPE_LONGNAME, encoding, errors)
return buf + self._create_header(info, GNU_FORMAT) return buf + self._create_header(info, GNU_FORMAT, encoding, errors)
def create_pax_header(self, info, encoding, errors): def create_pax_header(self, info, encoding, errors):
"""Return the object as a ustar header block. If it cannot be """Return the object as a ustar header block. If it cannot be
@ -1042,17 +1031,15 @@ def create_pax_header(self, info, encoding, errors):
# The pax header has priority. # The pax header has priority.
continue continue
val = info[name].decode(encoding, errors)
# Try to encode the string as ASCII. # Try to encode the string as ASCII.
try: try:
val.encode("ascii") info[name].encode("ascii", "strict")
except UnicodeEncodeError: except UnicodeEncodeError:
pax_headers[hname] = val pax_headers[hname] = info[name]
continue continue
if len(info[name]) > length: if len(info[name]) > length:
pax_headers[hname] = val pax_headers[hname] = info[name]
# Test number fields for values that exceed the field limit or values # Test number fields for values that exceed the field limit or values
# that like to be stored as float. # that like to be stored as float.
@ -1069,17 +1056,17 @@ def create_pax_header(self, info, encoding, errors):
# Create a pax extended header if necessary. # Create a pax extended header if necessary.
if pax_headers: if pax_headers:
buf = self._create_pax_generic_header(pax_headers) buf = self._create_pax_generic_header(pax_headers, XHDTYPE, encoding, errors)
else: else:
buf = "" buf = b""
return buf + self._create_header(info, USTAR_FORMAT) return buf + self._create_header(info, USTAR_FORMAT, encoding, errors)
@classmethod @classmethod
def create_pax_global_header(cls, pax_headers): def create_pax_global_header(cls, pax_headers, encoding, errors):
"""Return the object as a pax global header block sequence. """Return the object as a pax global header block sequence.
""" """
return cls._create_pax_generic_header(pax_headers, type=XGLTYPE) return cls._create_pax_generic_header(pax_headers, XGLTYPE, encoding, errors)
def _posix_split_name(self, name): def _posix_split_name(self, name):
"""Split a name longer than 100 chars into a prefix """Split a name longer than 100 chars into a prefix
@ -1097,31 +1084,31 @@ def _posix_split_name(self, name):
return prefix, name return prefix, name
@staticmethod @staticmethod
def _create_header(info, format): def _create_header(info, format, encoding, errors):
"""Return a header block. info is a dictionary with file """Return a header block. info is a dictionary with file
information, format must be one of the *_FORMAT constants. information, format must be one of the *_FORMAT constants.
""" """
parts = [ parts = [
stn(info.get("name", ""), 100), stn(info.get("name", ""), 100, encoding, errors),
itn(info.get("mode", 0) & 0o7777, 8, format), itn(info.get("mode", 0) & 0o7777, 8, format),
itn(info.get("uid", 0), 8, format), itn(info.get("uid", 0), 8, format),
itn(info.get("gid", 0), 8, format), itn(info.get("gid", 0), 8, format),
itn(info.get("size", 0), 12, format), itn(info.get("size", 0), 12, format),
itn(info.get("mtime", 0), 12, format), itn(info.get("mtime", 0), 12, format),
" ", # checksum field b" ", # checksum field
info.get("type", REGTYPE), info.get("type", REGTYPE),
stn(info.get("linkname", ""), 100), stn(info.get("linkname", ""), 100, encoding, errors),
stn(info.get("magic", POSIX_MAGIC), 8), info.get("magic", POSIX_MAGIC),
stn(info.get("uname", "root"), 32), stn(info.get("uname", "root"), 32, encoding, errors),
stn(info.get("gname", "root"), 32), stn(info.get("gname", "root"), 32, encoding, errors),
itn(info.get("devmajor", 0), 8, format), itn(info.get("devmajor", 0), 8, format),
itn(info.get("devminor", 0), 8, format), itn(info.get("devminor", 0), 8, format),
stn(info.get("prefix", ""), 155) stn(info.get("prefix", ""), 155, encoding, errors)
] ]
buf = struct.pack("%ds" % BLOCKSIZE, "".join(parts)) buf = struct.pack("%ds" % BLOCKSIZE, b"".join(parts))
chksum = calc_chksums(buf[-BLOCKSIZE:])[0] chksum = calc_chksums(buf[-BLOCKSIZE:])[0]
buf = buf[:-364] + "%06o\0" % chksum + buf[-357:] buf = buf[:-364] + bytes("%06o\0" % chksum) + buf[-357:]
return buf return buf
@staticmethod @staticmethod
@ -1135,11 +1122,11 @@ def _create_payload(payload):
return payload return payload
@classmethod @classmethod
def _create_gnu_long_header(cls, name, type): def _create_gnu_long_header(cls, name, type, encoding, errors):
"""Return a GNUTYPE_LONGNAME or GNUTYPE_LONGLINK sequence """Return a GNUTYPE_LONGNAME or GNUTYPE_LONGLINK sequence
for name. for name.
""" """
name += NUL name = name.encode(encoding, errors) + NUL
info = {} info = {}
info["name"] = "././@LongLink" info["name"] = "././@LongLink"
@ -1148,16 +1135,16 @@ def _create_gnu_long_header(cls, name, type):
info["magic"] = GNU_MAGIC info["magic"] = GNU_MAGIC
# create extended header + name blocks. # create extended header + name blocks.
return cls._create_header(info, USTAR_FORMAT) + \ return cls._create_header(info, USTAR_FORMAT, encoding, errors) + \
cls._create_payload(name) cls._create_payload(name)
@classmethod @classmethod
def _create_pax_generic_header(cls, pax_headers, type=XHDTYPE): def _create_pax_generic_header(cls, pax_headers, type, encoding, errors):
"""Return a POSIX.1-2001 extended or global header sequence """Return a POSIX.1-2001 extended or global header sequence
that contains a list of keyword, value pairs. The values that contains a list of keyword, value pairs. The values
must be unicode objects. must be strings.
""" """
records = [] records = b""
for keyword, value in pax_headers.items(): for keyword, value in pax_headers.items():
keyword = keyword.encode("utf8") keyword = keyword.encode("utf8")
value = value.encode("utf8") value = value.encode("utf8")
@ -1168,8 +1155,7 @@ def _create_pax_generic_header(cls, pax_headers, type=XHDTYPE):
if n == p: if n == p:
break break
p = n p = n
records.append("%d %s=%s\n" % (p, keyword, value)) records += bytes(str(p)) + b" " + keyword + b"=" + value + b"\n"
records = "".join(records)
# We use a hardcoded "././@PaxHeader" name like star does # We use a hardcoded "././@PaxHeader" name like star does
# instead of the one that POSIX recommends. # instead of the one that POSIX recommends.
@ -1180,12 +1166,12 @@ def _create_pax_generic_header(cls, pax_headers, type=XHDTYPE):
info["magic"] = POSIX_MAGIC info["magic"] = POSIX_MAGIC
# Create pax header + record blocks. # Create pax header + record blocks.
return cls._create_header(info, USTAR_FORMAT) + \ return cls._create_header(info, USTAR_FORMAT, encoding, errors) + \
cls._create_payload(records) cls._create_payload(records)
@classmethod @classmethod
def frombuf(cls, buf): def frombuf(cls, buf, encoding, errors):
"""Construct a TarInfo object from a 512 byte string buffer. """Construct a TarInfo object from a 512 byte bytes object.
""" """
if len(buf) != BLOCKSIZE: if len(buf) != BLOCKSIZE:
raise HeaderError("truncated header") raise HeaderError("truncated header")
@ -1198,7 +1184,7 @@ def frombuf(cls, buf):
obj = cls() obj = cls()
obj.buf = buf obj.buf = buf
obj.name = nts(buf[0:100]) obj.name = nts(buf[0:100], encoding, errors)
obj.mode = nti(buf[100:108]) obj.mode = nti(buf[100:108])
obj.uid = nti(buf[108:116]) obj.uid = nti(buf[108:116])
obj.gid = nti(buf[116:124]) obj.gid = nti(buf[116:124])
@ -1206,12 +1192,12 @@ def frombuf(cls, buf):
obj.mtime = nti(buf[136:148]) obj.mtime = nti(buf[136:148])
obj.chksum = chksum obj.chksum = chksum
obj.type = buf[156:157] obj.type = buf[156:157]
obj.linkname = nts(buf[157:257]) obj.linkname = nts(buf[157:257], encoding, errors)
obj.uname = nts(buf[265:297]) obj.uname = nts(buf[265:297], encoding, errors)
obj.gname = nts(buf[297:329]) obj.gname = nts(buf[297:329], encoding, errors)
obj.devmajor = nti(buf[329:337]) obj.devmajor = nti(buf[329:337])
obj.devminor = nti(buf[337:345]) obj.devminor = nti(buf[337:345])
prefix = nts(buf[345:500]) prefix = nts(buf[345:500], encoding, errors)
# Old V7 tar format represents a directory as a regular # Old V7 tar format represents a directory as a regular
# file with a trailing slash. # file with a trailing slash.
@ -1235,7 +1221,7 @@ def fromtarfile(cls, tarfile):
buf = tarfile.fileobj.read(BLOCKSIZE) buf = tarfile.fileobj.read(BLOCKSIZE)
if not buf: if not buf:
return return
obj = cls.frombuf(buf) obj = cls.frombuf(buf, tarfile.encoding, tarfile.errors)
obj.offset = tarfile.fileobj.tell() - BLOCKSIZE obj.offset = tarfile.fileobj.tell() - BLOCKSIZE
return obj._proc_member(tarfile) return obj._proc_member(tarfile)
@ -1295,9 +1281,9 @@ def _proc_gnulong(self, tarfile):
# the longname information. # the longname information.
next.offset = self.offset next.offset = self.offset
if self.type == GNUTYPE_LONGNAME: if self.type == GNUTYPE_LONGNAME:
next.name = nts(buf) next.name = nts(buf, tarfile.encoding, tarfile.errors)
elif self.type == GNUTYPE_LONGLINK: elif self.type == GNUTYPE_LONGLINK:
next.linkname = nts(buf) next.linkname = nts(buf, tarfile.encoding, tarfile.errors)
return next return next
@ -1324,12 +1310,12 @@ def _proc_sparse(self, tarfile):
lastpos = offset + numbytes lastpos = offset + numbytes
pos += 24 pos += 24
isextended = ord(buf[482]) isextended = bool(buf[482])
origsize = nti(buf[483:495]) origsize = nti(buf[483:495])
# If the isextended flag is given, # If the isextended flag is given,
# there are extra headers to process. # there are extra headers to process.
while isextended == 1: while isextended:
buf = tarfile.fileobj.read(BLOCKSIZE) buf = tarfile.fileobj.read(BLOCKSIZE)
pos = 0 pos = 0
for i in range(21): for i in range(21):
@ -1344,7 +1330,7 @@ def _proc_sparse(self, tarfile):
realpos += numbytes realpos += numbytes
lastpos = offset + numbytes lastpos = offset + numbytes
pos += 24 pos += 24
isextended = ord(buf[504]) isextended = bool(buf[504])
if lastpos < origsize: if lastpos < origsize:
sp.append(_hole(lastpos, origsize - lastpos)) sp.append(_hole(lastpos, origsize - lastpos))
@ -1431,8 +1417,6 @@ def _apply_pax_info(self, pax_headers, encoding, errors):
value = PAX_NUMBER_FIELDS[keyword](value) value = PAX_NUMBER_FIELDS[keyword](value)
except ValueError: except ValueError:
value = 0 value = 0
else:
value = uts(value, encoding, errors)
setattr(self, keyword, value) setattr(self, keyword, value)
@ -1542,7 +1526,7 @@ def __init__(self, name=None, mode="r", fileobj=None, format=None,
if errors is not None: if errors is not None:
self.errors = errors self.errors = errors
elif mode == "r": elif mode == "r":
self.errors = "utf-8" self.errors = "replace"
else: else:
self.errors = "strict" self.errors = "strict"
@ -1575,14 +1559,15 @@ def __init__(self, name=None, mode="r", fileobj=None, format=None,
while True: while True:
if self.next() is None: if self.next() is None:
if self.offset > 0: if self.offset > 0:
self.fileobj.seek(- BLOCKSIZE, 1) self.fileobj.seek(self.fileobj.tell() - BLOCKSIZE)
break break
if self.mode in "aw": if self.mode in "aw":
self._loaded = True self._loaded = True
if self.pax_headers: if self.pax_headers:
buf = self.tarinfo.create_pax_global_header(self.pax_headers.copy()) buf = self.tarinfo.create_pax_global_header(
self.pax_headers.copy(), self.encoding, self.errors)
self.fileobj.write(buf) self.fileobj.write(buf)
self.offset += len(buf) self.offset += len(buf)

View file

@ -2,6 +2,7 @@
import sys import sys
import os import os
import io
import shutil import shutil
import tempfile import tempfile
import StringIO import StringIO
@ -64,8 +65,8 @@ def test_fileobj_regular_file(self):
def test_fileobj_readlines(self): def test_fileobj_readlines(self):
self.tar.extract("ustar/regtype", TEMPDIR) self.tar.extract("ustar/regtype", TEMPDIR)
tarinfo = self.tar.getmember("ustar/regtype") tarinfo = self.tar.getmember("ustar/regtype")
fobj1 = open(os.path.join(TEMPDIR, "ustar/regtype"), "rU") fobj1 = open(os.path.join(TEMPDIR, "ustar/regtype"), "r")
fobj2 = self.tar.extractfile(tarinfo) fobj2 = io.TextIOWrapper(self.tar.extractfile(tarinfo))
lines1 = fobj1.readlines() lines1 = fobj1.readlines()
lines2 = fobj2.readlines() lines2 = fobj2.readlines()
@ -83,7 +84,7 @@ def test_fileobj_iter(self):
fobj1 = open(os.path.join(TEMPDIR, "ustar/regtype"), "rU") fobj1 = open(os.path.join(TEMPDIR, "ustar/regtype"), "rU")
fobj2 = self.tar.extractfile(tarinfo) fobj2 = self.tar.extractfile(tarinfo)
lines1 = fobj1.readlines() lines1 = fobj1.readlines()
lines2 = [line for line in fobj2] lines2 = list(io.TextIOWrapper(fobj2))
self.assert_(lines1 == lines2, self.assert_(lines1 == lines2,
"fileobj.__iter__() failed") "fileobj.__iter__() failed")
@ -115,11 +116,11 @@ def test_fileobj_seek(self):
fobj.seek(0, 2) fobj.seek(0, 2)
self.assertEqual(tarinfo.size, fobj.tell(), self.assertEqual(tarinfo.size, fobj.tell(),
"seek() to file's end failed") "seek() to file's end failed")
self.assert_(fobj.read() == "", self.assert_(fobj.read() == b"",
"read() at file's end did not return empty string") "read() at file's end did not return empty string")
fobj.seek(-tarinfo.size, 2) fobj.seek(-tarinfo.size, 2)
self.assertEqual(0, fobj.tell(), self.assertEqual(0, fobj.tell(),
"relative seek() to file's start failed") "relative seek() to file's end failed")
fobj.seek(512) fobj.seek(512)
s1 = fobj.readlines() s1 = fobj.readlines()
fobj.seek(512) fobj.seek(512)
@ -245,13 +246,13 @@ class DetectReadTest(unittest.TestCase):
def _testfunc_file(self, name, mode): def _testfunc_file(self, name, mode):
try: try:
tarfile.open(name, mode) tarfile.open(name, mode)
except tarfile.ReadError: except tarfile.ReadError as e:
self.fail() self.fail()
def _testfunc_fileobj(self, name, mode): def _testfunc_fileobj(self, name, mode):
try: try:
tarfile.open(name, mode, fileobj=open(name, "rb")) tarfile.open(name, mode, fileobj=open(name, "rb"))
except tarfile.ReadError: except tarfile.ReadError as e:
self.fail() self.fail()
def _test_modes(self, testfunc): def _test_modes(self, testfunc):
@ -393,7 +394,7 @@ def test_truncated_longname(self):
tarinfo = self.tar.getmember(longname) tarinfo = self.tar.getmember(longname)
offset = tarinfo.offset offset = tarinfo.offset
self.tar.fileobj.seek(offset) self.tar.fileobj.seek(offset)
fobj = StringIO.StringIO(self.tar.fileobj.read(3 * 512)) fobj = io.BytesIO(self.tar.fileobj.read(3 * 512))
self.assertRaises(tarfile.ReadError, tarfile.open, name="foo.tar", fileobj=fobj) self.assertRaises(tarfile.ReadError, tarfile.open, name="foo.tar", fileobj=fobj)
def test_header_offset(self): def test_header_offset(self):
@ -401,9 +402,9 @@ def test_header_offset(self):
# the preceding extended header. # the preceding extended header.
longname = self.subdir + "/" + "123/" * 125 + "longname" longname = self.subdir + "/" + "123/" * 125 + "longname"
offset = self.tar.getmember(longname).offset offset = self.tar.getmember(longname).offset
fobj = open(tarname) fobj = open(tarname, "rb")
fobj.seek(offset) fobj.seek(offset)
tarinfo = tarfile.TarInfo.frombuf(fobj.read(512)) tarinfo = tarfile.TarInfo.frombuf(fobj.read(512), "iso8859-1", "strict")
self.assertEqual(tarinfo.type, self.longnametype) self.assertEqual(tarinfo.type, self.longnametype)
@ -764,10 +765,10 @@ def test_pax_global_header(self):
self.assertEqual(tar.pax_headers, pax_headers) self.assertEqual(tar.pax_headers, pax_headers)
self.assertEqual(tar.getmembers()[0].pax_headers, pax_headers) self.assertEqual(tar.getmembers()[0].pax_headers, pax_headers)
# Test if all the fields are unicode. # Test if all the fields are strings.
for key, val in tar.pax_headers.items(): for key, val in tar.pax_headers.items():
self.assert_(type(key) is unicode) self.assert_(type(key) is not bytes)
self.assert_(type(val) is unicode) self.assert_(type(val) is not bytes)
if key in tarfile.PAX_NUMBER_FIELDS: if key in tarfile.PAX_NUMBER_FIELDS:
try: try:
tarfile.PAX_NUMBER_FIELDS[key](val) tarfile.PAX_NUMBER_FIELDS[key](val)
@ -815,20 +816,14 @@ def _test_unicode_filename(self, encoding):
tar.close() tar.close()
tar = tarfile.open(tmpname, encoding=encoding) tar = tarfile.open(tmpname, encoding=encoding)
self.assert_(type(tar.getnames()[0]) is not unicode) self.assert_(type(tar.getnames()[0]) is not bytes)
self.assertEqual(tar.getmembers()[0].name, name.encode(encoding)) self.assertEqual(tar.getmembers()[0].name, name)
tar.close() tar.close()
def test_unicode_filename_error(self): def test_unicode_filename_error(self):
tar = tarfile.open(tmpname, "w", format=self.format, encoding="ascii", errors="strict") tar = tarfile.open(tmpname, "w", format=self.format, encoding="ascii", errors="strict")
tarinfo = tarfile.TarInfo() tarinfo = tarfile.TarInfo()
tarinfo.name = "äöü"
if self.format == tarfile.PAX_FORMAT:
self.assertRaises(UnicodeError, tar.addfile, tarinfo)
else:
tar.addfile(tarinfo)
tarinfo.name = "äöü" tarinfo.name = "äöü"
self.assertRaises(UnicodeError, tar.addfile, tarinfo) self.assertRaises(UnicodeError, tar.addfile, tarinfo)
@ -851,7 +846,7 @@ def test_uname_unicode(self):
t.uname = name t.uname = name
t.gname = name t.gname = name
fobj = StringIO.StringIO() fobj = io.BytesIO()
tar = tarfile.open("foo.tar", mode="w", fileobj=fobj, format=self.format, encoding="iso8859-1") tar = tarfile.open("foo.tar", mode="w", fileobj=fobj, format=self.format, encoding="iso8859-1")
tar.addfile(t) tar.addfile(t)
tar.close() tar.close()
@ -862,46 +857,12 @@ def test_uname_unicode(self):
self.assertEqual(t.uname, "äöü") self.assertEqual(t.uname, "äöü")
self.assertEqual(t.gname, "äöü") self.assertEqual(t.gname, "äöü")
class GNUUnicodeTest(UstarUnicodeTest): class GNUUnicodeTest(UstarUnicodeTest):
format = tarfile.GNU_FORMAT format = tarfile.GNU_FORMAT
class PaxUnicodeTest(UstarUnicodeTest):
format = tarfile.PAX_FORMAT
def _create_unicode_name(self, name):
tar = tarfile.open(tmpname, "w", format=self.format)
t = tarfile.TarInfo()
t.pax_headers["path"] = name
tar.addfile(t)
tar.close()
def test_error_handlers(self):
# Test if the unicode error handlers work correctly for characters
# that cannot be expressed in a given encoding.
self._create_unicode_name("äöü")
for handler, name in (("utf-8", "äöü".encode("utf8")),
("replace", "???"), ("ignore", "")):
tar = tarfile.open(tmpname, format=self.format, encoding="ascii",
errors=handler)
self.assertEqual(tar.getnames()[0], name)
self.assertRaises(UnicodeError, tarfile.open, tmpname,
encoding="ascii", errors="strict")
def test_error_handler_utf8(self):
# Create a pathname that has one component representable using
# iso8859-1 and the other only in iso8859-15.
self._create_unicode_name("äöü/¤")
tar = tarfile.open(tmpname, format=self.format, encoding="iso8859-1",
errors="utf-8")
self.assertEqual(tar.getnames()[0], "äöü/" + "¤".encode("utf8"))
class AppendTest(unittest.TestCase): class AppendTest(unittest.TestCase):
# Test append mode (cp. patch #1652681). # Test append mode (cp. patch #1652681).
@ -1028,6 +989,19 @@ def test_pax_limits(self):
tarinfo.tobuf(tarfile.PAX_FORMAT) tarinfo.tobuf(tarfile.PAX_FORMAT)
class MiscTest(unittest.TestCase):
def test_char_fields(self):
self.assertEqual(tarfile.stn("foo", 8, "ascii", "strict"), b"foo\0\0\0\0\0")
self.assertEqual(tarfile.stn("foobar", 3, "ascii", "strict"), b"foo")
self.assertEqual(tarfile.nts(b"foo\0\0\0\0\0", "ascii", "strict"), "foo")
self.assertEqual(tarfile.nts(b"foo\0bar\0", "ascii", "strict"), "foo")
def test_number_fields(self):
self.assertEqual(tarfile.itn(1), b"0000001\x00")
self.assertEqual(tarfile.itn(0xffffffff), b"\x80\x00\x00\x00\xff\xff\xff\xff")
class GzipMiscReadTest(MiscReadTest): class GzipMiscReadTest(MiscReadTest):
tarname = gzipname tarname = gzipname
mode = "r:gz" mode = "r:gz"
@ -1075,9 +1049,9 @@ def test_main():
PaxWriteTest, PaxWriteTest,
UstarUnicodeTest, UstarUnicodeTest,
GNUUnicodeTest, GNUUnicodeTest,
PaxUnicodeTest,
AppendTest, AppendTest,
LimitsTest, LimitsTest,
MiscTest,
] ]
if hasattr(os, "link"): if hasattr(os, "link"):