bpo-44439: BZ2File.write() / LZMAFile.write() handle buffer protocol correctly (GH-26764) (GH-26845)

No longer use len() to get the length of the input data. For some buffer protocol objects,
the length obtained by using len() is wrong.
(cherry picked from commit bc6c12c72a)

Co-authored-by: Ma Lin <animalize@users.noreply.github.com>
This commit is contained in:
Miss Islington (bot) 2021-06-22 06:59:53 -07:00 committed by GitHub
parent cf739332bd
commit 01858fbe31
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 55 additions and 9 deletions

View file

@ -219,14 +219,22 @@ def write(self, data):
"""Write a byte string to the file. """Write a byte string to the file.
Returns the number of uncompressed bytes written, which is Returns the number of uncompressed bytes written, which is
always len(data). Note that due to buffering, the file on disk always the length of data in bytes. Note that due to buffering,
may not reflect the data written until close() is called. the file on disk may not reflect the data written until close()
is called.
""" """
self._check_can_write() self._check_can_write()
if isinstance(data, (bytes, bytearray)):
length = len(data)
else:
# accept any data that supports the buffer protocol
data = memoryview(data)
length = data.nbytes
compressed = self._compressor.compress(data) compressed = self._compressor.compress(data)
self._fp.write(compressed) self._fp.write(compressed)
self._pos += len(data) self._pos += length
return len(data) return length
def writelines(self, seq): def writelines(self, seq):
"""Write a sequence of byte strings to the file. """Write a sequence of byte strings to the file.

View file

@ -278,7 +278,7 @@ def write(self,data):
if self.fileobj is None: if self.fileobj is None:
raise ValueError("write() on closed GzipFile object") raise ValueError("write() on closed GzipFile object")
if isinstance(data, bytes): if isinstance(data, (bytes, bytearray)):
length = len(data) length = len(data)
else: else:
# accept any data that supports the buffer protocol # accept any data that supports the buffer protocol

View file

@ -229,14 +229,22 @@ def write(self, data):
"""Write a bytes object to the file. """Write a bytes object to the file.
Returns the number of uncompressed bytes written, which is Returns the number of uncompressed bytes written, which is
always len(data). Note that due to buffering, the file on disk always the length of data in bytes. Note that due to buffering,
may not reflect the data written until close() is called. the file on disk may not reflect the data written until close()
is called.
""" """
self._check_can_write() self._check_can_write()
if isinstance(data, (bytes, bytearray)):
length = len(data)
else:
# accept any data that supports the buffer protocol
data = memoryview(data)
length = data.nbytes
compressed = self._compressor.compress(data) compressed = self._compressor.compress(data)
self._fp.write(compressed) self._fp.write(compressed)
self._pos += len(data) self._pos += length
return len(data) return length
def seek(self, offset, whence=io.SEEK_SET): def seek(self, offset, whence=io.SEEK_SET):
"""Change the file position. """Change the file position.

View file

@ -1,6 +1,7 @@
from test import support from test import support
from test.support import bigmemtest, _4G from test.support import bigmemtest, _4G
import array
import unittest import unittest
from io import BytesIO, DEFAULT_BUFFER_SIZE from io import BytesIO, DEFAULT_BUFFER_SIZE
import os import os
@ -620,6 +621,14 @@ def test_read_truncated(self):
with BZ2File(BytesIO(truncated[:i])) as f: with BZ2File(BytesIO(truncated[:i])) as f:
self.assertRaises(EOFError, f.read, 1) self.assertRaises(EOFError, f.read, 1)
def test_issue44439(self):
q = array.array('Q', [1, 2, 3, 4, 5])
LENGTH = len(q) * q.itemsize
with BZ2File(BytesIO(), 'w') as f:
self.assertEqual(f.write(q), LENGTH)
self.assertEqual(f.tell(), LENGTH)
class BZ2CompressorTest(BaseTest): class BZ2CompressorTest(BaseTest):
def testCompress(self): def testCompress(self):

View file

@ -592,6 +592,15 @@ def test_prepend_error(self):
with gzip.open(self.filename, "rb") as f: with gzip.open(self.filename, "rb") as f:
f._buffer.raw._fp.prepend() f._buffer.raw._fp.prepend()
def test_issue44439(self):
q = array.array('Q', [1, 2, 3, 4, 5])
LENGTH = len(q) * q.itemsize
with gzip.GzipFile(fileobj=io.BytesIO(), mode='w') as f:
self.assertEqual(f.write(q), LENGTH)
self.assertEqual(f.tell(), LENGTH)
class TestOpen(BaseTest): class TestOpen(BaseTest):
def test_binary_modes(self): def test_binary_modes(self):
uncompressed = data1 * 50 uncompressed = data1 * 50

View file

@ -1,4 +1,5 @@
import _compression import _compression
import array
from io import BytesIO, UnsupportedOperation, DEFAULT_BUFFER_SIZE from io import BytesIO, UnsupportedOperation, DEFAULT_BUFFER_SIZE
import os import os
import pathlib import pathlib
@ -1231,6 +1232,14 @@ def test_issue21872(self):
self.assertTrue(d2.eof) self.assertTrue(d2.eof)
self.assertEqual(out1 + out2, entire) self.assertEqual(out1 + out2, entire)
def test_issue44439(self):
q = array.array('Q', [1, 2, 3, 4, 5])
LENGTH = len(q) * q.itemsize
with LZMAFile(BytesIO(), 'w') as f:
self.assertEqual(f.write(q), LENGTH)
self.assertEqual(f.tell(), LENGTH)
class OpenTestCase(unittest.TestCase): class OpenTestCase(unittest.TestCase):

View file

@ -0,0 +1,3 @@
Fix in :meth:`bz2.BZ2File.write` / :meth:`lzma.LZMAFile.write` methods, when
the input data is an object that supports the buffer protocol, the file length
may be wrong.