Support for memoryview in CBC/OFB/CFB mode

This commit is contained in:
Helder Eijs 2018-04-01 22:17:07 +02:00
parent 0914cc686d
commit 1a56f87afe
6 changed files with 167 additions and 25 deletions

View file

@ -34,7 +34,7 @@ Ciphertext Block Chaining (CBC) mode.
__all__ = ['CbcMode'] __all__ = ['CbcMode']
from Crypto.Util.py3compat import bstr from Crypto.Util.py3compat import _copy_bytes
from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer, from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
create_string_buffer, get_raw_buffer, create_string_buffer, get_raw_buffer,
SmartPointer, c_size_t, c_uint8_ptr) SmartPointer, c_size_t, c_uint8_ptr)
@ -112,7 +112,7 @@ class CbcMode(object):
self.block_size = len(iv) self.block_size = len(iv)
"""The block size of the underlying cipher, in bytes.""" """The block size of the underlying cipher, in bytes."""
self.iv = bstr(iv) self.iv = _copy_bytes(0, iv)
"""The Initialization Vector originally used to create the object. """The Initialization Vector originally used to create the object.
The value does not change.""" The value does not change."""

View file

@ -26,7 +26,7 @@ Counter Feedback (CFB) mode.
__all__ = ['CfbMode'] __all__ = ['CfbMode']
from Crypto.Util.py3compat import bstr from Crypto.Util.py3compat import _copy_bytes
from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer, from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
create_string_buffer, get_raw_buffer, create_string_buffer, get_raw_buffer,
SmartPointer, c_size_t, c_uint8_ptr) SmartPointer, c_size_t, c_uint8_ptr)
@ -111,7 +111,7 @@ class CfbMode(object):
self.block_size = len(iv) self.block_size = len(iv)
"""The block size of the underlying cipher, in bytes.""" """The block size of the underlying cipher, in bytes."""
self.iv = bstr(iv) self.iv = _copy_bytes(0, iv)
"""The Initialization Vector originally used to create the object. """The Initialization Vector originally used to create the object.
The value does not change.""" The value does not change."""

View file

@ -26,7 +26,7 @@ Output Feedback (CFB) mode.
__all__ = ['OfbMode'] __all__ = ['OfbMode']
from Crypto.Util.py3compat import bstr from Crypto.Util.py3compat import _copy_bytes
from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer, from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
create_string_buffer, get_raw_buffer, create_string_buffer, get_raw_buffer,
SmartPointer, c_size_t, c_uint8_ptr) SmartPointer, c_size_t, c_uint8_ptr)
@ -108,7 +108,7 @@ class OfbMode(object):
self.block_size = len(iv) self.block_size = len(iv)
"""The block size of the underlying cipher, in bytes.""" """The block size of the underlying cipher, in bytes."""
self.iv = bstr(iv) self.iv = _copy_bytes(0, iv)
"""The Initialization Vector originally used to create the object. """The Initialization Vector originally used to create the object.
The value does not change.""" The value does not change."""

View file

@ -27,7 +27,7 @@
import unittest import unittest
from binascii import a2b_hex, b2a_hex, hexlify from binascii import a2b_hex, b2a_hex, hexlify
from Crypto.Util.py3compat import * from Crypto.Util.py3compat import b, _memoryview
from Crypto.Util.strxor import strxor_c from Crypto.Util.strxor import strxor_c
class _NoDefault: pass # sentinel object class _NoDefault: pass # sentinel object
@ -305,6 +305,86 @@ class ByteArrayTest(unittest.TestCase):
decipher.verify(bytearray(a2b_hex(self.mac))) decipher.verify(bytearray(a2b_hex(self.mac)))
class MemoryviewTest(unittest.TestCase):
"""Verify we can use memoryviews for encrypting and decrypting"""
def __init__(self, module, params):
unittest.TestCase.__init__(self)
self.module = module
# Extract the parameters
params = params.copy()
self.description = _extract(params, 'description')
self.key = b(_extract(params, 'key'))
self.plaintext = b(_extract(params, 'plaintext'))
self.ciphertext = b(_extract(params, 'ciphertext'))
self.module_name = _extract(params, 'module_name', None)
self.assoc_data = _extract(params, 'assoc_data', None)
self.mac = _extract(params, 'mac', None)
if self.assoc_data:
self.mac = b(self.mac)
mode = _extract(params, 'mode', None)
self.mode_name = str(mode)
if mode is not None:
# Block cipher
self.mode = getattr(self.module, "MODE_" + mode)
self.iv = _extract(params, 'iv', None)
if self.iv is None:
self.iv = _extract(params, 'nonce', None)
if self.iv is not None:
self.iv = b(self.iv)
else:
# Stream cipher
self.mode = None
self.iv = _extract(params, 'iv', None)
if self.iv is not None:
self.iv = b(self.iv)
self.extra_params = params
def _new(self):
params = self.extra_params.copy()
key = a2b_hex(self.key)
old_style = []
if self.mode is not None:
old_style = [ self.mode ]
if self.iv is not None:
old_style += [ a2b_hex(self.iv) ]
return self.module.new(key, *old_style, **params)
def runTest(self):
plaintext = a2b_hex(self.plaintext)
ciphertext = a2b_hex(self.ciphertext)
assoc_data = []
if self.assoc_data:
assoc_data = [ memoryview(a2b_hex(b(x))) for x in self.assoc_data]
cipher = self._new()
decipher = self._new()
# Only AEAD modes
for comp in assoc_data:
cipher.update(comp)
decipher.update(comp)
ct = b2a_hex(cipher.encrypt(memoryview(plaintext)))
pt = b2a_hex(decipher.decrypt(memoryview(ciphertext)))
self.assertEqual(self.ciphertext, ct) # encrypt
self.assertEqual(self.plaintext, pt) # decrypt
if self.mac:
mac = b2a_hex(cipher.digest())
self.assertEqual(self.mac, mac)
decipher.verify(memoryview(a2b_hex(self.mac)))
def make_block_tests(module, module_name, test_data, additional_params=dict()): def make_block_tests(module, module_name, test_data, additional_params=dict()):
tests = [] tests = []
extra_tests_added = False extra_tests_added = False
@ -406,6 +486,9 @@ def make_stream_tests(module, module_name, test_data):
tests += [ tests += [
ByteArrayTest(module, params), ByteArrayTest(module, params),
] ]
import types
if _memoryview is not types.NoneType:
tests.append(MemoryviewTest(module, params))
extra_tests_added = True extra_tests_added = True
# Add the test to the test suite # Add the test to the test suite

View file

@ -32,7 +32,7 @@ import unittest
from Crypto.SelfTest.loader import load_tests from Crypto.SelfTest.loader import load_tests
from Crypto.SelfTest.st_common import list_test_cases from Crypto.SelfTest.st_common import list_test_cases
from Crypto.Util.py3compat import tobytes, b, unhexlify from Crypto.Util.py3compat import tobytes, unhexlify, _memoryview
from Crypto.Cipher import AES, DES3, DES from Crypto.Cipher import AES, DES3, DES
from Crypto.Hash import SHAKE128 from Crypto.Hash import SHAKE128
@ -95,11 +95,11 @@ class BlockChainingTests(unittest.TestCase):
def test_iv_with_matching_length(self): def test_iv_with_matching_length(self):
self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode, self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode,
b("")) b"")
self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode, self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode,
self.iv_128[:15]) self.iv_128[:15])
self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode, self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode,
self.iv_128 + b("0")) self.iv_128 + b"0")
def test_block_size_128(self): def test_block_size_128(self):
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128) cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
@ -112,20 +112,20 @@ class BlockChainingTests(unittest.TestCase):
def test_unaligned_data_128(self): def test_unaligned_data_128(self):
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128) cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
for wrong_length in xrange(1,16): for wrong_length in xrange(1,16):
self.assertRaises(ValueError, cipher.encrypt, b("5") * wrong_length) self.assertRaises(ValueError, cipher.encrypt, b"5" * wrong_length)
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128) cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
for wrong_length in xrange(1,16): for wrong_length in xrange(1,16):
self.assertRaises(ValueError, cipher.decrypt, b("5") * wrong_length) self.assertRaises(ValueError, cipher.decrypt, b"5" * wrong_length)
def test_unaligned_data_64(self): def test_unaligned_data_64(self):
cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64) cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
for wrong_length in xrange(1,8): for wrong_length in xrange(1,8):
self.assertRaises(ValueError, cipher.encrypt, b("5") * wrong_length) self.assertRaises(ValueError, cipher.encrypt, b"5" * wrong_length)
cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64) cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
for wrong_length in xrange(1,8): for wrong_length in xrange(1,8):
self.assertRaises(ValueError, cipher.decrypt, b("5") * wrong_length) self.assertRaises(ValueError, cipher.decrypt, b"5" * wrong_length)
def test_IV_iv_attributes(self): def test_IV_iv_attributes(self):
data = get_tag_random("data", 16 * 100) data = get_tag_random("data", 16 * 100)
@ -146,17 +146,17 @@ class BlockChainingTests(unittest.TestCase):
def test_null_encryption_decryption(self): def test_null_encryption_decryption(self):
for func in "encrypt", "decrypt": for func in "encrypt", "decrypt":
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128) cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
result = getattr(cipher, func)(b("")) result = getattr(cipher, func)(b"")
self.assertEqual(result, b("")) self.assertEqual(result, b"")
def test_either_encrypt_or_decrypt(self): def test_either_encrypt_or_decrypt(self):
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128) cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
cipher.encrypt(b("")) cipher.encrypt(b"")
self.assertRaises(TypeError, cipher.decrypt, b("")) self.assertRaises(TypeError, cipher.decrypt, b"")
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128) cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
cipher.decrypt(b("")) cipher.decrypt(b"")
self.assertRaises(TypeError, cipher.encrypt, b("")) self.assertRaises(TypeError, cipher.encrypt, b"")
def test_data_must_be_bytes(self): def test_data_must_be_bytes(self):
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128) cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
@ -166,27 +166,75 @@ class BlockChainingTests(unittest.TestCase):
self.assertRaises(TypeError, cipher.decrypt, u'test1234567890-*') self.assertRaises(TypeError, cipher.decrypt, u'test1234567890-*')
def test_bytearray(self): def test_bytearray(self):
data = b("1") * 16 data = b"1" * 16
data_ba = bytearray(data)
# Encrypt # Encrypt
key_ba = bytearray(self.key_128)
iv_ba = bytearray(self.iv_128)
cipher1 = AES.new(self.key_128, self.aes_mode, self.iv_128) cipher1 = AES.new(self.key_128, self.aes_mode, self.iv_128)
ref1 = cipher1.encrypt(data) ref1 = cipher1.encrypt(data)
cipher2 = AES.new(bytearray(self.key_128), self.aes_mode, bytearray(self.iv_128)) cipher2 = AES.new(key_ba, self.aes_mode, iv_ba)
ref2 = cipher2.encrypt(bytearray(data)) key_ba[:3] = b'\xFF\xFF\xFF'
iv_ba[:3] = b'\xFF\xFF\xFF'
ref2 = cipher2.encrypt(data_ba)
self.assertEqual(ref1, ref2) self.assertEqual(ref1, ref2)
self.assertEqual(cipher1.iv, cipher2.iv) self.assertEqual(cipher1.iv, cipher2.iv)
# Decrypt # Decrypt
key_ba = bytearray(self.key_128)
iv_ba = bytearray(self.iv_128)
cipher3 = AES.new(self.key_128, self.aes_mode, self.iv_128) cipher3 = AES.new(self.key_128, self.aes_mode, self.iv_128)
ref3 = cipher3.decrypt(data) ref3 = cipher3.decrypt(data)
cipher4 = AES.new(bytearray(self.key_128), self.aes_mode, bytearray(self.iv_128)) cipher4 = AES.new(key_ba, self.aes_mode, iv_ba)
ref4 = cipher4.decrypt(bytearray(data)) key_ba[:3] = b'\xFF\xFF\xFF'
iv_ba[:3] = b'\xFF\xFF\xFF'
ref4 = cipher4.decrypt(data_ba)
self.assertEqual(ref3, ref4) self.assertEqual(ref3, ref4)
def test_memoryview(self):
data = b"1" * 16
data_mv = memoryview(bytearray(data))
# Encrypt
key_mv = memoryview(bytearray(self.key_128))
iv_mv = memoryview(bytearray(self.iv_128))
cipher1 = AES.new(self.key_128, self.aes_mode, self.iv_128)
ref1 = cipher1.encrypt(data)
cipher2 = AES.new(key_mv, self.aes_mode, iv_mv)
key_mv[:3] = b'\xFF\xFF\xFF'
iv_mv[:3] = b'\xFF\xFF\xFF'
ref2 = cipher2.encrypt(data_mv)
self.assertEqual(ref1, ref2)
self.assertEqual(cipher1.iv, cipher2.iv)
# Decrypt
key_mv = memoryview(bytearray(self.key_128))
iv_mv = memoryview(bytearray(self.iv_128))
cipher3 = AES.new(self.key_128, self.aes_mode, self.iv_128)
ref3 = cipher3.decrypt(data)
cipher4 = AES.new(key_mv, self.aes_mode, iv_mv)
key_mv[:3] = b'\xFF\xFF\xFF'
iv_mv[:3] = b'\xFF\xFF\xFF'
ref4 = cipher4.decrypt(data_mv)
self.assertEqual(ref3, ref4)
import types
if _memoryview is types.NoneType:
del test_memoryview
class CbcTests(BlockChainingTests): class CbcTests(BlockChainingTests):
aes_mode = AES.MODE_CBC aes_mode = AES.MODE_CBC

View file

@ -133,4 +133,15 @@ else:
_memoryview = memoryview _memoryview = memoryview
def _copy_bytes(start, seq):
"""Return a copy of a sequence (byte string, byte array, memoryview)
starting from a certain index"""
if isinstance(seq, _memoryview):
return seq[start:].tobytes()
else:
return seq[start:]
del sys del sys