Support for memoryviews in CCM

This commit is contained in:
Helder Eijs 2018-04-02 15:35:12 +02:00
parent 1a56f87afe
commit ec4084eaaf
7 changed files with 136 additions and 63 deletions

View file

@ -112,7 +112,7 @@ class CbcMode(object):
self.block_size = len(iv)
"""The block size of the underlying cipher, in bytes."""
self.iv = _copy_bytes(0, iv)
self.iv = _copy_bytes(None, None, iv)
"""The Initialization Vector originally used to create the object.
The value does not change."""

View file

@ -34,7 +34,10 @@ Counter with CBC-MAC (CCM) mode.
__all__ = ['CcmMode']
from Crypto.Util.py3compat import byte_string, b, bchr, bord, unhexlify, bstr
import struct
from Crypto.Util.py3compat import (byte_string, bord, unhexlify,
_copy_bytes, _is_mutable)
from Crypto.Util.strxor import strxor
from Crypto.Util.number import long_to_bytes
@ -114,11 +117,11 @@ class CcmMode(object):
self.block_size = factory.block_size
"""The block size of the underlying cipher, in bytes."""
self.nonce = bstr(nonce)
self.nonce = _copy_bytes(None, None, nonce)
"""The nonce used for this cipher instance"""
self._factory = factory
self._key = bstr(key)
self._key = _copy_bytes(None, None, key)
self._mac_len = mac_len
self._msg_len = msg_len
self._assoc_len = assoc_len
@ -144,7 +147,7 @@ class CcmMode(object):
# bytes worth of ciphertext)
self._mac = self._factory.new(key,
factory.MODE_CBC,
iv=bchr(0) * 16,
iv=b'\x00' * 16,
**cipher_params)
self._mac_status = MacStatus.NOT_STARTED
self._t = None
@ -158,19 +161,19 @@ class CcmMode(object):
self._cumul_msg_len = 0
# Cache for unaligned associated data/plaintext.
# This is a list, but when the MAC starts, it will become a binary
# string no longer than the block size.
# This is a list with byte strings, but when the MAC starts,
# it will become a binary string no longer than the block size.
self._cache = []
# Start CTR cipher, by formatting the counter (A.3)
q = 15 - len(nonce) # length of Q, the encoded message length
self._cipher = self._factory.new(key,
self._factory.MODE_CTR,
nonce=bchr(q - 1) + nonce,
nonce=struct.pack("B", q - 1) + self.nonce,
**cipher_params)
# S_0, step 6 in 6.1 for j=0
self._s_0 = self._cipher.encrypt(bchr(0) * 16)
self._s_0 = self._cipher.encrypt(b'\x00' * 16)
# Try to start the MAC
if None not in (assoc_len, msg_len):
@ -186,19 +189,19 @@ class CcmMode(object):
q = 15 - len(self.nonce) # length of Q, the encoded message length
flags = (64 * (self._assoc_len > 0) + 8 * ((self._mac_len - 2) // 2) +
(q - 1))
b_0 = bchr(flags) + bstr(self.nonce) + long_to_bytes(self._msg_len, q)
b_0 = struct.pack("B", flags) + self.nonce + long_to_bytes(self._msg_len, q)
# Formatting associated data (A.2.2)
# Encoded 'a' is concatenated with the associated data 'A'
assoc_len_encoded = b('')
assoc_len_encoded = b''
if self._assoc_len > 0:
if self._assoc_len < (2 ** 16 - 2 ** 8):
enc_size = 2
elif self._assoc_len < (2L ** 32):
assoc_len_encoded = b('\xFF\xFE')
assoc_len_encoded = b'\xFF\xFE'
enc_size = 4
else:
assoc_len_encoded = b('\xFF\xFF')
assoc_len_encoded = b'\xFF\xFF'
enc_size = 8
assoc_len_encoded += long_to_bytes(self._assoc_len, enc_size)
@ -207,8 +210,8 @@ class CcmMode(object):
self._cache.insert(1, assoc_len_encoded)
# Process all the data cached so far
first_data_to_mac = b("").join(self._cache)
self._cache = b("")
first_data_to_mac = b"".join(self._cache)
self._cache = b""
self._mac_status = MacStatus.PROCESSING_AUTH_DATA
self._update(first_data_to_mac)
@ -222,7 +225,7 @@ class CcmMode(object):
# the 16 byte boundary (A.2.3)
len_cache = len(self._cache)
if len_cache > 0:
self._update(bchr(0) * (self.block_size - len_cache))
self._update(b'\x00' * (self.block_size - len_cache))
def update(self, assoc_data):
"""Protect associated data
@ -262,12 +265,16 @@ class CcmMode(object):
self._update(assoc_data)
return self
def _update(self, assoc_data_pt=b("")):
def _update(self, assoc_data_pt=b""):
"""Update the MAC with associated data or plaintext
(without FSM checks)"""
# If MAC has not started yet, we just park the data into a list.
# If the data is mutable, we create a copy and store that instead.
if self._mac_status == MacStatus.NOT_STARTED:
self._cache.append(bstr(assoc_data_pt))
if _is_mutable(assoc_data_pt):
assoc_data_pt = _copy_bytes(None, None, assoc_data_pt)
self._cache.append(assoc_data_pt)
return
assert(len(self._cache) < self.block_size)
@ -275,18 +282,18 @@ class CcmMode(object):
if len(self._cache) > 0:
filler = min(self.block_size - len(self._cache),
len(assoc_data_pt))
self._cache += assoc_data_pt[:filler]
assoc_data_pt = assoc_data_pt[filler:]
self._cache += _copy_bytes(None, filler, assoc_data_pt)
assoc_data_pt = _copy_bytes(filler, None, assoc_data_pt)
if len(self._cache) < self.block_size:
return
# The cache is exactly one block
self._t = self._mac.encrypt(self._cache)
self._cache = b("")
self._cache = b""
update_len = len(assoc_data_pt) // self.block_size * self.block_size
self._cache = assoc_data_pt[update_len:]
self._cache = _copy_bytes(update_len, None, assoc_data_pt)
if update_len > 0:
self._t = self._mac.encrypt(assoc_data_pt[:update_len])[-16:]

View file

@ -111,7 +111,7 @@ class CfbMode(object):
self.block_size = len(iv)
"""The block size of the underlying cipher, in bytes."""
self.iv = _copy_bytes(0, iv)
self.iv = _copy_bytes(None, None, iv)
"""The Initialization Vector originally used to create the object.
The value does not change."""

View file

@ -108,7 +108,7 @@ class OfbMode(object):
self.block_size = len(iv)
"""The block size of the underlying cipher, in bytes."""
self.iv = _copy_bytes(0, iv)
self.iv = _copy_bytes(None, None, iv)
"""The Initialization Vector originally used to create the object.
The value does not change."""

View file

@ -22,7 +22,7 @@
import sys
from Crypto.Util.py3compat import bord, tobytes, _memoryview
from Crypto.Util.py3compat import bord, tobytes, _copy_bytes
from binascii import unhexlify
@ -40,16 +40,6 @@ def _shift_bytes(bs, xor_lsb=0):
return long_to_bytes(num, len(bs))[-len(bs):]
def _copy_bytes(start, seq):
"""Return a copy of a sequence (byte string, byte array, memoryview)
starting from a certain index"""
if isinstance(seq, _memoryview):
return seq[start:].tobytes()
else:
return seq[start:]
class CMAC(object):
"""A CMAC hash object.
Do not instantiate directly. Use the :func:`new` function.
@ -65,7 +55,7 @@ class CMAC(object):
if ciphermod is None:
raise TypeError("ciphermod must be specified (try AES)")
self._key = _copy_bytes(0, key)
self._key = _copy_bytes(None, None, key)
self._factory = ciphermod
if cipher_params is None:
self._cipher_params = {}
@ -148,7 +138,7 @@ class CMAC(object):
update_len *= self.digest_size
if remain > 0:
self._update(msg[:update_len])
self._cache = _copy_bytes(update_len, msg)
self._cache = _copy_bytes(update_len, None, msg)
else:
self._update(msg)
self._cache = b""

View file

@ -31,7 +31,7 @@
import unittest
from Crypto.SelfTest.st_common import list_test_cases
from Crypto.Util.py3compat import unhexlify, tobytes, bchr, b
from Crypto.Util.py3compat import unhexlify, tobytes, bchr, b, _memoryview
from Crypto.Cipher import AES
from Crypto.Hash import SHAKE128
@ -275,40 +275,102 @@ class CcmTests(unittest.TestCase):
self.assertEquals(cipher.digest(), ref_mac)
def test_bytearray(self):
# Encrypt
key_ba = bytearray(self.key_128)
nonce_ba = bytearray(self.nonce_96)
header_ba = bytearray(self.data_128)
data_ba = bytearray(self.data_128)
cipher1 = AES.new(self.key_128,
AES.MODE_CCM,
nonce=self.nonce_96)
cipher1.update(self.data_128)
ref1 = cipher1.encrypt(self.data_128)
cipher2 = AES.new(bytearray(self.key_128),
AES.MODE_CCM,
nonce=bytearray(self.nonce_96))
cipher2.update(bytearray(self.data_128))
ref2 = cipher2.encrypt(bytearray(self.data_128))
self.assertEqual(ref1, ref2)
self.assertEqual(cipher1.nonce, cipher2.nonce)
ct = cipher1.encrypt(self.data_128)
tag = cipher1.digest()
cipher2 = AES.new(key_ba,
AES.MODE_CCM,
nonce=nonce_ba)
key_ba[:3] = b"\xFF\xFF\xFF"
nonce_ba[:3] = b"\xFF\xFF\xFF"
cipher2.update(header_ba)
header_ba[:3] = b"\xFF\xFF\xFF"
ct_test = cipher2.encrypt(data_ba)
data_ba[:3] = b"\xFF\xFF\xFF"
tag_test = cipher2.digest()
self.assertEqual(ct, ct_test)
self.assertEqual(tag, tag_test)
self.assertEqual(cipher1.nonce, cipher2.nonce)
# Decrypt
cipher3 = AES.new(self.key_128,
key_ba = bytearray(self.key_128)
nonce_ba = bytearray(self.nonce_96)
header_ba = bytearray(self.data_128)
del data_ba
cipher4 = AES.new(key_ba,
AES.MODE_CCM,
nonce=nonce_ba)
key_ba[:3] = b"\xFF\xFF\xFF"
nonce_ba[:3] = b"\xFF\xFF\xFF"
cipher4.update(header_ba)
header_ba[:3] = b"\xFF\xFF\xFF"
pt_test = cipher4.decrypt_and_verify(bytearray(ct_test), bytearray(tag_test))
self.assertEqual(self.data_128, pt_test)
def test_memoryview(self):
# Encrypt
key_mv = memoryview(bytearray(self.key_128))
nonce_mv = memoryview(bytearray(self.nonce_96))
header_mv = memoryview(bytearray(self.data_128))
data_mv = memoryview(bytearray(self.data_128))
cipher1 = AES.new(self.key_128,
AES.MODE_CCM,
nonce=self.nonce_96)
cipher3.update(self.data_128)
ref3 = cipher3.decrypt(ref1)
cipher1.update(self.data_128)
ct = cipher1.encrypt(self.data_128)
tag = cipher1.digest()
cipher4 = AES.new(bytearray(self.key_128),
cipher2 = AES.new(key_mv,
AES.MODE_CCM,
nonce=bytearray(self.nonce_96))
cipher4.update(bytearray(self.data_128))
ref4 = cipher4.decrypt(bytearray(ref1))
nonce=nonce_mv)
key_mv[:3] = b"\xFF\xFF\xFF"
nonce_mv[:3] = b"\xFF\xFF\xFF"
cipher2.update(header_mv)
header_mv[:3] = b"\xFF\xFF\xFF"
ct_test = cipher2.encrypt(data_mv)
data_mv[:3] = b"\xFF\xFF\xFF"
tag_test = cipher2.digest()
self.assertEqual(ref3, ref4)
self.assertEqual(ct, ct_test)
self.assertEqual(tag, tag_test)
self.assertEqual(cipher1.nonce, cipher2.nonce)
cipher3.verify(bytearray(tag))
# Decrypt
key_mv = memoryview(bytearray(self.key_128))
nonce_mv = memoryview(bytearray(self.nonce_96))
header_mv = memoryview(bytearray(self.data_128))
del data_mv
cipher4 = AES.new(key_mv,
AES.MODE_CCM,
nonce=nonce_mv)
key_mv[:3] = b"\xFF\xFF\xFF"
nonce_mv[:3] = b"\xFF\xFF\xFF"
cipher4.update(header_mv)
header_mv[:3] = b"\xFF\xFF\xFF"
pt_test = cipher4.decrypt_and_verify(memoryview(ct_test), memoryview(tag_test))
self.assertEqual(self.data_128, pt_test)
import types
if _memoryview is types.NoneType:
del test_memoryview
class CcmFSMTests(unittest.TestCase):

View file

@ -134,14 +134,28 @@ else:
_memoryview = memoryview
def _copy_bytes(start, seq):
"""Return a copy of a sequence (byte string, byte array, memoryview)
starting from a certain index"""
def _copy_bytes(start, end, seq):
"""Return an immutable copy of a sequence (byte string, byte array, memoryview)
in a certain interval [start:seq]"""
if isinstance(seq, _memoryview):
return seq[start:].tobytes()
return seq[start:end].tobytes()
elif isinstance(seq, bytearray):
return bytes(seq[start:end])
else:
return seq[start:]
return seq[start:end]
def _is_immutable(data):
if byte_string(data):
return True
elif isinstance(data, _memoryview) and data.readonly:
return True
return False
def _is_mutable(data):
return not _is_immutable(data)
del sys