mirror of
https://github.com/Legrandin/pycryptodome.git
synced 2025-12-08 05:19:46 +00:00
Support for memoryview in CBC/OFB/CFB mode
This commit is contained in:
parent
0914cc686d
commit
1a56f87afe
6 changed files with 167 additions and 25 deletions
|
|
@ -34,7 +34,7 @@ Ciphertext Block Chaining (CBC) mode.
|
|||
|
||||
__all__ = ['CbcMode']
|
||||
|
||||
from Crypto.Util.py3compat import bstr
|
||||
from Crypto.Util.py3compat import _copy_bytes
|
||||
from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
|
||||
create_string_buffer, get_raw_buffer,
|
||||
SmartPointer, c_size_t, c_uint8_ptr)
|
||||
|
|
@ -112,7 +112,7 @@ class CbcMode(object):
|
|||
self.block_size = len(iv)
|
||||
"""The block size of the underlying cipher, in bytes."""
|
||||
|
||||
self.iv = bstr(iv)
|
||||
self.iv = _copy_bytes(0, iv)
|
||||
"""The Initialization Vector originally used to create the object.
|
||||
The value does not change."""
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ Counter Feedback (CFB) mode.
|
|||
|
||||
__all__ = ['CfbMode']
|
||||
|
||||
from Crypto.Util.py3compat import bstr
|
||||
from Crypto.Util.py3compat import _copy_bytes
|
||||
from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
|
||||
create_string_buffer, get_raw_buffer,
|
||||
SmartPointer, c_size_t, c_uint8_ptr)
|
||||
|
|
@ -111,7 +111,7 @@ class CfbMode(object):
|
|||
self.block_size = len(iv)
|
||||
"""The block size of the underlying cipher, in bytes."""
|
||||
|
||||
self.iv = bstr(iv)
|
||||
self.iv = _copy_bytes(0, iv)
|
||||
"""The Initialization Vector originally used to create the object.
|
||||
The value does not change."""
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ Output Feedback (CFB) mode.
|
|||
|
||||
__all__ = ['OfbMode']
|
||||
|
||||
from Crypto.Util.py3compat import bstr
|
||||
from Crypto.Util.py3compat import _copy_bytes
|
||||
from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
|
||||
create_string_buffer, get_raw_buffer,
|
||||
SmartPointer, c_size_t, c_uint8_ptr)
|
||||
|
|
@ -108,7 +108,7 @@ class OfbMode(object):
|
|||
self.block_size = len(iv)
|
||||
"""The block size of the underlying cipher, in bytes."""
|
||||
|
||||
self.iv = bstr(iv)
|
||||
self.iv = _copy_bytes(0, iv)
|
||||
"""The Initialization Vector originally used to create the object.
|
||||
The value does not change."""
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@
|
|||
import unittest
|
||||
from binascii import a2b_hex, b2a_hex, hexlify
|
||||
|
||||
from Crypto.Util.py3compat import *
|
||||
from Crypto.Util.py3compat import b, _memoryview
|
||||
from Crypto.Util.strxor import strxor_c
|
||||
|
||||
class _NoDefault: pass # sentinel object
|
||||
|
|
@ -305,6 +305,86 @@ class ByteArrayTest(unittest.TestCase):
|
|||
decipher.verify(bytearray(a2b_hex(self.mac)))
|
||||
|
||||
|
||||
class MemoryviewTest(unittest.TestCase):
|
||||
"""Verify we can use memoryviews for encrypting and decrypting"""
|
||||
|
||||
def __init__(self, module, params):
|
||||
unittest.TestCase.__init__(self)
|
||||
self.module = module
|
||||
|
||||
# Extract the parameters
|
||||
params = params.copy()
|
||||
self.description = _extract(params, 'description')
|
||||
self.key = b(_extract(params, 'key'))
|
||||
self.plaintext = b(_extract(params, 'plaintext'))
|
||||
self.ciphertext = b(_extract(params, 'ciphertext'))
|
||||
self.module_name = _extract(params, 'module_name', None)
|
||||
self.assoc_data = _extract(params, 'assoc_data', None)
|
||||
self.mac = _extract(params, 'mac', None)
|
||||
if self.assoc_data:
|
||||
self.mac = b(self.mac)
|
||||
|
||||
mode = _extract(params, 'mode', None)
|
||||
self.mode_name = str(mode)
|
||||
|
||||
if mode is not None:
|
||||
# Block cipher
|
||||
self.mode = getattr(self.module, "MODE_" + mode)
|
||||
|
||||
self.iv = _extract(params, 'iv', None)
|
||||
if self.iv is None:
|
||||
self.iv = _extract(params, 'nonce', None)
|
||||
if self.iv is not None:
|
||||
self.iv = b(self.iv)
|
||||
else:
|
||||
# Stream cipher
|
||||
self.mode = None
|
||||
self.iv = _extract(params, 'iv', None)
|
||||
if self.iv is not None:
|
||||
self.iv = b(self.iv)
|
||||
|
||||
self.extra_params = params
|
||||
|
||||
def _new(self):
|
||||
params = self.extra_params.copy()
|
||||
key = a2b_hex(self.key)
|
||||
|
||||
old_style = []
|
||||
if self.mode is not None:
|
||||
old_style = [ self.mode ]
|
||||
if self.iv is not None:
|
||||
old_style += [ a2b_hex(self.iv) ]
|
||||
|
||||
return self.module.new(key, *old_style, **params)
|
||||
|
||||
def runTest(self):
|
||||
|
||||
plaintext = a2b_hex(self.plaintext)
|
||||
ciphertext = a2b_hex(self.ciphertext)
|
||||
assoc_data = []
|
||||
if self.assoc_data:
|
||||
assoc_data = [ memoryview(a2b_hex(b(x))) for x in self.assoc_data]
|
||||
|
||||
cipher = self._new()
|
||||
decipher = self._new()
|
||||
|
||||
# Only AEAD modes
|
||||
for comp in assoc_data:
|
||||
cipher.update(comp)
|
||||
decipher.update(comp)
|
||||
|
||||
ct = b2a_hex(cipher.encrypt(memoryview(plaintext)))
|
||||
pt = b2a_hex(decipher.decrypt(memoryview(ciphertext)))
|
||||
|
||||
self.assertEqual(self.ciphertext, ct) # encrypt
|
||||
self.assertEqual(self.plaintext, pt) # decrypt
|
||||
|
||||
if self.mac:
|
||||
mac = b2a_hex(cipher.digest())
|
||||
self.assertEqual(self.mac, mac)
|
||||
decipher.verify(memoryview(a2b_hex(self.mac)))
|
||||
|
||||
|
||||
def make_block_tests(module, module_name, test_data, additional_params=dict()):
|
||||
tests = []
|
||||
extra_tests_added = False
|
||||
|
|
@ -406,6 +486,9 @@ def make_stream_tests(module, module_name, test_data):
|
|||
tests += [
|
||||
ByteArrayTest(module, params),
|
||||
]
|
||||
import types
|
||||
if _memoryview is not types.NoneType:
|
||||
tests.append(MemoryviewTest(module, params))
|
||||
extra_tests_added = True
|
||||
|
||||
# Add the test to the test suite
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ import unittest
|
|||
|
||||
from Crypto.SelfTest.loader import load_tests
|
||||
from Crypto.SelfTest.st_common import list_test_cases
|
||||
from Crypto.Util.py3compat import tobytes, b, unhexlify
|
||||
from Crypto.Util.py3compat import tobytes, unhexlify, _memoryview
|
||||
from Crypto.Cipher import AES, DES3, DES
|
||||
from Crypto.Hash import SHAKE128
|
||||
|
||||
|
|
@ -95,11 +95,11 @@ class BlockChainingTests(unittest.TestCase):
|
|||
|
||||
def test_iv_with_matching_length(self):
|
||||
self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode,
|
||||
b(""))
|
||||
b"")
|
||||
self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode,
|
||||
self.iv_128[:15])
|
||||
self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode,
|
||||
self.iv_128 + b("0"))
|
||||
self.iv_128 + b"0")
|
||||
|
||||
def test_block_size_128(self):
|
||||
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
|
||||
|
|
@ -112,20 +112,20 @@ class BlockChainingTests(unittest.TestCase):
|
|||
def test_unaligned_data_128(self):
|
||||
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
|
||||
for wrong_length in xrange(1,16):
|
||||
self.assertRaises(ValueError, cipher.encrypt, b("5") * wrong_length)
|
||||
self.assertRaises(ValueError, cipher.encrypt, b"5" * wrong_length)
|
||||
|
||||
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
|
||||
for wrong_length in xrange(1,16):
|
||||
self.assertRaises(ValueError, cipher.decrypt, b("5") * wrong_length)
|
||||
self.assertRaises(ValueError, cipher.decrypt, b"5" * wrong_length)
|
||||
|
||||
def test_unaligned_data_64(self):
|
||||
cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
|
||||
for wrong_length in xrange(1,8):
|
||||
self.assertRaises(ValueError, cipher.encrypt, b("5") * wrong_length)
|
||||
self.assertRaises(ValueError, cipher.encrypt, b"5" * wrong_length)
|
||||
|
||||
cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
|
||||
for wrong_length in xrange(1,8):
|
||||
self.assertRaises(ValueError, cipher.decrypt, b("5") * wrong_length)
|
||||
self.assertRaises(ValueError, cipher.decrypt, b"5" * wrong_length)
|
||||
|
||||
def test_IV_iv_attributes(self):
|
||||
data = get_tag_random("data", 16 * 100)
|
||||
|
|
@ -146,17 +146,17 @@ class BlockChainingTests(unittest.TestCase):
|
|||
def test_null_encryption_decryption(self):
|
||||
for func in "encrypt", "decrypt":
|
||||
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
|
||||
result = getattr(cipher, func)(b(""))
|
||||
self.assertEqual(result, b(""))
|
||||
result = getattr(cipher, func)(b"")
|
||||
self.assertEqual(result, b"")
|
||||
|
||||
def test_either_encrypt_or_decrypt(self):
|
||||
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
|
||||
cipher.encrypt(b(""))
|
||||
self.assertRaises(TypeError, cipher.decrypt, b(""))
|
||||
cipher.encrypt(b"")
|
||||
self.assertRaises(TypeError, cipher.decrypt, b"")
|
||||
|
||||
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
|
||||
cipher.decrypt(b(""))
|
||||
self.assertRaises(TypeError, cipher.encrypt, b(""))
|
||||
cipher.decrypt(b"")
|
||||
self.assertRaises(TypeError, cipher.encrypt, b"")
|
||||
|
||||
def test_data_must_be_bytes(self):
|
||||
cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
|
||||
|
|
@ -166,27 +166,75 @@ class BlockChainingTests(unittest.TestCase):
|
|||
self.assertRaises(TypeError, cipher.decrypt, u'test1234567890-*')
|
||||
|
||||
def test_bytearray(self):
|
||||
data = b("1") * 16
|
||||
data = b"1" * 16
|
||||
data_ba = bytearray(data)
|
||||
|
||||
# Encrypt
|
||||
key_ba = bytearray(self.key_128)
|
||||
iv_ba = bytearray(self.iv_128)
|
||||
|
||||
cipher1 = AES.new(self.key_128, self.aes_mode, self.iv_128)
|
||||
ref1 = cipher1.encrypt(data)
|
||||
|
||||
cipher2 = AES.new(bytearray(self.key_128), self.aes_mode, bytearray(self.iv_128))
|
||||
ref2 = cipher2.encrypt(bytearray(data))
|
||||
cipher2 = AES.new(key_ba, self.aes_mode, iv_ba)
|
||||
key_ba[:3] = b'\xFF\xFF\xFF'
|
||||
iv_ba[:3] = b'\xFF\xFF\xFF'
|
||||
ref2 = cipher2.encrypt(data_ba)
|
||||
|
||||
self.assertEqual(ref1, ref2)
|
||||
self.assertEqual(cipher1.iv, cipher2.iv)
|
||||
|
||||
# Decrypt
|
||||
key_ba = bytearray(self.key_128)
|
||||
iv_ba = bytearray(self.iv_128)
|
||||
|
||||
cipher3 = AES.new(self.key_128, self.aes_mode, self.iv_128)
|
||||
ref3 = cipher3.decrypt(data)
|
||||
|
||||
cipher4 = AES.new(bytearray(self.key_128), self.aes_mode, bytearray(self.iv_128))
|
||||
ref4 = cipher4.decrypt(bytearray(data))
|
||||
cipher4 = AES.new(key_ba, self.aes_mode, iv_ba)
|
||||
key_ba[:3] = b'\xFF\xFF\xFF'
|
||||
iv_ba[:3] = b'\xFF\xFF\xFF'
|
||||
ref4 = cipher4.decrypt(data_ba)
|
||||
|
||||
self.assertEqual(ref3, ref4)
|
||||
|
||||
def test_memoryview(self):
|
||||
data = b"1" * 16
|
||||
data_mv = memoryview(bytearray(data))
|
||||
|
||||
# Encrypt
|
||||
key_mv = memoryview(bytearray(self.key_128))
|
||||
iv_mv = memoryview(bytearray(self.iv_128))
|
||||
|
||||
cipher1 = AES.new(self.key_128, self.aes_mode, self.iv_128)
|
||||
ref1 = cipher1.encrypt(data)
|
||||
|
||||
cipher2 = AES.new(key_mv, self.aes_mode, iv_mv)
|
||||
key_mv[:3] = b'\xFF\xFF\xFF'
|
||||
iv_mv[:3] = b'\xFF\xFF\xFF'
|
||||
ref2 = cipher2.encrypt(data_mv)
|
||||
|
||||
self.assertEqual(ref1, ref2)
|
||||
self.assertEqual(cipher1.iv, cipher2.iv)
|
||||
|
||||
# Decrypt
|
||||
key_mv = memoryview(bytearray(self.key_128))
|
||||
iv_mv = memoryview(bytearray(self.iv_128))
|
||||
|
||||
cipher3 = AES.new(self.key_128, self.aes_mode, self.iv_128)
|
||||
ref3 = cipher3.decrypt(data)
|
||||
|
||||
cipher4 = AES.new(key_mv, self.aes_mode, iv_mv)
|
||||
key_mv[:3] = b'\xFF\xFF\xFF'
|
||||
iv_mv[:3] = b'\xFF\xFF\xFF'
|
||||
ref4 = cipher4.decrypt(data_mv)
|
||||
|
||||
self.assertEqual(ref3, ref4)
|
||||
|
||||
import types
|
||||
if _memoryview is types.NoneType:
|
||||
del test_memoryview
|
||||
|
||||
|
||||
class CbcTests(BlockChainingTests):
|
||||
aes_mode = AES.MODE_CBC
|
||||
|
|
|
|||
|
|
@ -133,4 +133,15 @@ else:
|
|||
|
||||
_memoryview = memoryview
|
||||
|
||||
|
||||
def _copy_bytes(start, seq):
|
||||
"""Return a copy of a sequence (byte string, byte array, memoryview)
|
||||
starting from a certain index"""
|
||||
|
||||
if isinstance(seq, _memoryview):
|
||||
return seq[start:].tobytes()
|
||||
else:
|
||||
return seq[start:]
|
||||
|
||||
|
||||
del sys
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue