mirror of
https://github.com/Legrandin/pycryptodome.git
synced 2025-12-08 05:19:46 +00:00
Support for memoryviews in CCM
This commit is contained in:
parent
1a56f87afe
commit
ec4084eaaf
7 changed files with 136 additions and 63 deletions
|
|
@ -112,7 +112,7 @@ class CbcMode(object):
|
|||
self.block_size = len(iv)
|
||||
"""The block size of the underlying cipher, in bytes."""
|
||||
|
||||
self.iv = _copy_bytes(0, iv)
|
||||
self.iv = _copy_bytes(None, None, iv)
|
||||
"""The Initialization Vector originally used to create the object.
|
||||
The value does not change."""
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,10 @@ Counter with CBC-MAC (CCM) mode.
|
|||
|
||||
__all__ = ['CcmMode']
|
||||
|
||||
from Crypto.Util.py3compat import byte_string, b, bchr, bord, unhexlify, bstr
|
||||
import struct
|
||||
|
||||
from Crypto.Util.py3compat import (byte_string, bord, unhexlify,
|
||||
_copy_bytes, _is_mutable)
|
||||
|
||||
from Crypto.Util.strxor import strxor
|
||||
from Crypto.Util.number import long_to_bytes
|
||||
|
|
@ -114,11 +117,11 @@ class CcmMode(object):
|
|||
self.block_size = factory.block_size
|
||||
"""The block size of the underlying cipher, in bytes."""
|
||||
|
||||
self.nonce = bstr(nonce)
|
||||
self.nonce = _copy_bytes(None, None, nonce)
|
||||
"""The nonce used for this cipher instance"""
|
||||
|
||||
self._factory = factory
|
||||
self._key = bstr(key)
|
||||
self._key = _copy_bytes(None, None, key)
|
||||
self._mac_len = mac_len
|
||||
self._msg_len = msg_len
|
||||
self._assoc_len = assoc_len
|
||||
|
|
@ -144,7 +147,7 @@ class CcmMode(object):
|
|||
# bytes worth of ciphertext)
|
||||
self._mac = self._factory.new(key,
|
||||
factory.MODE_CBC,
|
||||
iv=bchr(0) * 16,
|
||||
iv=b'\x00' * 16,
|
||||
**cipher_params)
|
||||
self._mac_status = MacStatus.NOT_STARTED
|
||||
self._t = None
|
||||
|
|
@ -158,19 +161,19 @@ class CcmMode(object):
|
|||
self._cumul_msg_len = 0
|
||||
|
||||
# Cache for unaligned associated data/plaintext.
|
||||
# This is a list, but when the MAC starts, it will become a binary
|
||||
# string no longer than the block size.
|
||||
# This is a list with byte strings, but when the MAC starts,
|
||||
# it will become a binary string no longer than the block size.
|
||||
self._cache = []
|
||||
|
||||
# Start CTR cipher, by formatting the counter (A.3)
|
||||
q = 15 - len(nonce) # length of Q, the encoded message length
|
||||
self._cipher = self._factory.new(key,
|
||||
self._factory.MODE_CTR,
|
||||
nonce=bchr(q - 1) + nonce,
|
||||
nonce=struct.pack("B", q - 1) + self.nonce,
|
||||
**cipher_params)
|
||||
|
||||
# S_0, step 6 in 6.1 for j=0
|
||||
self._s_0 = self._cipher.encrypt(bchr(0) * 16)
|
||||
self._s_0 = self._cipher.encrypt(b'\x00' * 16)
|
||||
|
||||
# Try to start the MAC
|
||||
if None not in (assoc_len, msg_len):
|
||||
|
|
@ -186,19 +189,19 @@ class CcmMode(object):
|
|||
q = 15 - len(self.nonce) # length of Q, the encoded message length
|
||||
flags = (64 * (self._assoc_len > 0) + 8 * ((self._mac_len - 2) // 2) +
|
||||
(q - 1))
|
||||
b_0 = bchr(flags) + bstr(self.nonce) + long_to_bytes(self._msg_len, q)
|
||||
b_0 = struct.pack("B", flags) + self.nonce + long_to_bytes(self._msg_len, q)
|
||||
|
||||
# Formatting associated data (A.2.2)
|
||||
# Encoded 'a' is concatenated with the associated data 'A'
|
||||
assoc_len_encoded = b('')
|
||||
assoc_len_encoded = b''
|
||||
if self._assoc_len > 0:
|
||||
if self._assoc_len < (2 ** 16 - 2 ** 8):
|
||||
enc_size = 2
|
||||
elif self._assoc_len < (2L ** 32):
|
||||
assoc_len_encoded = b('\xFF\xFE')
|
||||
assoc_len_encoded = b'\xFF\xFE'
|
||||
enc_size = 4
|
||||
else:
|
||||
assoc_len_encoded = b('\xFF\xFF')
|
||||
assoc_len_encoded = b'\xFF\xFF'
|
||||
enc_size = 8
|
||||
assoc_len_encoded += long_to_bytes(self._assoc_len, enc_size)
|
||||
|
||||
|
|
@ -207,8 +210,8 @@ class CcmMode(object):
|
|||
self._cache.insert(1, assoc_len_encoded)
|
||||
|
||||
# Process all the data cached so far
|
||||
first_data_to_mac = b("").join(self._cache)
|
||||
self._cache = b("")
|
||||
first_data_to_mac = b"".join(self._cache)
|
||||
self._cache = b""
|
||||
self._mac_status = MacStatus.PROCESSING_AUTH_DATA
|
||||
self._update(first_data_to_mac)
|
||||
|
||||
|
|
@ -222,7 +225,7 @@ class CcmMode(object):
|
|||
# the 16 byte boundary (A.2.3)
|
||||
len_cache = len(self._cache)
|
||||
if len_cache > 0:
|
||||
self._update(bchr(0) * (self.block_size - len_cache))
|
||||
self._update(b'\x00' * (self.block_size - len_cache))
|
||||
|
||||
def update(self, assoc_data):
|
||||
"""Protect associated data
|
||||
|
|
@ -262,12 +265,16 @@ class CcmMode(object):
|
|||
self._update(assoc_data)
|
||||
return self
|
||||
|
||||
def _update(self, assoc_data_pt=b("")):
|
||||
def _update(self, assoc_data_pt=b""):
|
||||
"""Update the MAC with associated data or plaintext
|
||||
(without FSM checks)"""
|
||||
|
||||
# If MAC has not started yet, we just park the data into a list.
|
||||
# If the data is mutable, we create a copy and store that instead.
|
||||
if self._mac_status == MacStatus.NOT_STARTED:
|
||||
self._cache.append(bstr(assoc_data_pt))
|
||||
if _is_mutable(assoc_data_pt):
|
||||
assoc_data_pt = _copy_bytes(None, None, assoc_data_pt)
|
||||
self._cache.append(assoc_data_pt)
|
||||
return
|
||||
|
||||
assert(len(self._cache) < self.block_size)
|
||||
|
|
@ -275,18 +282,18 @@ class CcmMode(object):
|
|||
if len(self._cache) > 0:
|
||||
filler = min(self.block_size - len(self._cache),
|
||||
len(assoc_data_pt))
|
||||
self._cache += assoc_data_pt[:filler]
|
||||
assoc_data_pt = assoc_data_pt[filler:]
|
||||
self._cache += _copy_bytes(None, filler, assoc_data_pt)
|
||||
assoc_data_pt = _copy_bytes(filler, None, assoc_data_pt)
|
||||
|
||||
if len(self._cache) < self.block_size:
|
||||
return
|
||||
|
||||
# The cache is exactly one block
|
||||
self._t = self._mac.encrypt(self._cache)
|
||||
self._cache = b("")
|
||||
self._cache = b""
|
||||
|
||||
update_len = len(assoc_data_pt) // self.block_size * self.block_size
|
||||
self._cache = assoc_data_pt[update_len:]
|
||||
self._cache = _copy_bytes(update_len, None, assoc_data_pt)
|
||||
if update_len > 0:
|
||||
self._t = self._mac.encrypt(assoc_data_pt[:update_len])[-16:]
|
||||
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ class CfbMode(object):
|
|||
self.block_size = len(iv)
|
||||
"""The block size of the underlying cipher, in bytes."""
|
||||
|
||||
self.iv = _copy_bytes(0, iv)
|
||||
self.iv = _copy_bytes(None, None, iv)
|
||||
"""The Initialization Vector originally used to create the object.
|
||||
The value does not change."""
|
||||
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ class OfbMode(object):
|
|||
self.block_size = len(iv)
|
||||
"""The block size of the underlying cipher, in bytes."""
|
||||
|
||||
self.iv = _copy_bytes(0, iv)
|
||||
self.iv = _copy_bytes(None, None, iv)
|
||||
"""The Initialization Vector originally used to create the object.
|
||||
The value does not change."""
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@
|
|||
|
||||
import sys
|
||||
|
||||
from Crypto.Util.py3compat import bord, tobytes, _memoryview
|
||||
from Crypto.Util.py3compat import bord, tobytes, _copy_bytes
|
||||
|
||||
from binascii import unhexlify
|
||||
|
||||
|
|
@ -40,16 +40,6 @@ def _shift_bytes(bs, xor_lsb=0):
|
|||
return long_to_bytes(num, len(bs))[-len(bs):]
|
||||
|
||||
|
||||
def _copy_bytes(start, seq):
|
||||
"""Return a copy of a sequence (byte string, byte array, memoryview)
|
||||
starting from a certain index"""
|
||||
|
||||
if isinstance(seq, _memoryview):
|
||||
return seq[start:].tobytes()
|
||||
else:
|
||||
return seq[start:]
|
||||
|
||||
|
||||
class CMAC(object):
|
||||
"""A CMAC hash object.
|
||||
Do not instantiate directly. Use the :func:`new` function.
|
||||
|
|
@ -65,7 +55,7 @@ class CMAC(object):
|
|||
if ciphermod is None:
|
||||
raise TypeError("ciphermod must be specified (try AES)")
|
||||
|
||||
self._key = _copy_bytes(0, key)
|
||||
self._key = _copy_bytes(None, None, key)
|
||||
self._factory = ciphermod
|
||||
if cipher_params is None:
|
||||
self._cipher_params = {}
|
||||
|
|
@ -148,7 +138,7 @@ class CMAC(object):
|
|||
update_len *= self.digest_size
|
||||
if remain > 0:
|
||||
self._update(msg[:update_len])
|
||||
self._cache = _copy_bytes(update_len, msg)
|
||||
self._cache = _copy_bytes(update_len, None, msg)
|
||||
else:
|
||||
self._update(msg)
|
||||
self._cache = b""
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@
|
|||
import unittest
|
||||
|
||||
from Crypto.SelfTest.st_common import list_test_cases
|
||||
from Crypto.Util.py3compat import unhexlify, tobytes, bchr, b
|
||||
from Crypto.Util.py3compat import unhexlify, tobytes, bchr, b, _memoryview
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Hash import SHAKE128
|
||||
|
||||
|
|
@ -275,40 +275,102 @@ class CcmTests(unittest.TestCase):
|
|||
self.assertEquals(cipher.digest(), ref_mac)
|
||||
|
||||
def test_bytearray(self):
|
||||
|
||||
# Encrypt
|
||||
key_ba = bytearray(self.key_128)
|
||||
nonce_ba = bytearray(self.nonce_96)
|
||||
header_ba = bytearray(self.data_128)
|
||||
data_ba = bytearray(self.data_128)
|
||||
|
||||
cipher1 = AES.new(self.key_128,
|
||||
AES.MODE_CCM,
|
||||
nonce=self.nonce_96)
|
||||
cipher1.update(self.data_128)
|
||||
ref1 = cipher1.encrypt(self.data_128)
|
||||
|
||||
cipher2 = AES.new(bytearray(self.key_128),
|
||||
AES.MODE_CCM,
|
||||
nonce=bytearray(self.nonce_96))
|
||||
cipher2.update(bytearray(self.data_128))
|
||||
ref2 = cipher2.encrypt(bytearray(self.data_128))
|
||||
|
||||
self.assertEqual(ref1, ref2)
|
||||
self.assertEqual(cipher1.nonce, cipher2.nonce)
|
||||
|
||||
ct = cipher1.encrypt(self.data_128)
|
||||
tag = cipher1.digest()
|
||||
|
||||
cipher2 = AES.new(key_ba,
|
||||
AES.MODE_CCM,
|
||||
nonce=nonce_ba)
|
||||
key_ba[:3] = b"\xFF\xFF\xFF"
|
||||
nonce_ba[:3] = b"\xFF\xFF\xFF"
|
||||
cipher2.update(header_ba)
|
||||
header_ba[:3] = b"\xFF\xFF\xFF"
|
||||
ct_test = cipher2.encrypt(data_ba)
|
||||
data_ba[:3] = b"\xFF\xFF\xFF"
|
||||
tag_test = cipher2.digest()
|
||||
|
||||
self.assertEqual(ct, ct_test)
|
||||
self.assertEqual(tag, tag_test)
|
||||
self.assertEqual(cipher1.nonce, cipher2.nonce)
|
||||
|
||||
# Decrypt
|
||||
cipher3 = AES.new(self.key_128,
|
||||
key_ba = bytearray(self.key_128)
|
||||
nonce_ba = bytearray(self.nonce_96)
|
||||
header_ba = bytearray(self.data_128)
|
||||
del data_ba
|
||||
|
||||
cipher4 = AES.new(key_ba,
|
||||
AES.MODE_CCM,
|
||||
nonce=nonce_ba)
|
||||
key_ba[:3] = b"\xFF\xFF\xFF"
|
||||
nonce_ba[:3] = b"\xFF\xFF\xFF"
|
||||
cipher4.update(header_ba)
|
||||
header_ba[:3] = b"\xFF\xFF\xFF"
|
||||
pt_test = cipher4.decrypt_and_verify(bytearray(ct_test), bytearray(tag_test))
|
||||
|
||||
self.assertEqual(self.data_128, pt_test)
|
||||
|
||||
def test_memoryview(self):
|
||||
|
||||
# Encrypt
|
||||
key_mv = memoryview(bytearray(self.key_128))
|
||||
nonce_mv = memoryview(bytearray(self.nonce_96))
|
||||
header_mv = memoryview(bytearray(self.data_128))
|
||||
data_mv = memoryview(bytearray(self.data_128))
|
||||
|
||||
cipher1 = AES.new(self.key_128,
|
||||
AES.MODE_CCM,
|
||||
nonce=self.nonce_96)
|
||||
cipher3.update(self.data_128)
|
||||
ref3 = cipher3.decrypt(ref1)
|
||||
cipher1.update(self.data_128)
|
||||
ct = cipher1.encrypt(self.data_128)
|
||||
tag = cipher1.digest()
|
||||
|
||||
cipher4 = AES.new(bytearray(self.key_128),
|
||||
cipher2 = AES.new(key_mv,
|
||||
AES.MODE_CCM,
|
||||
nonce=bytearray(self.nonce_96))
|
||||
cipher4.update(bytearray(self.data_128))
|
||||
ref4 = cipher4.decrypt(bytearray(ref1))
|
||||
nonce=nonce_mv)
|
||||
key_mv[:3] = b"\xFF\xFF\xFF"
|
||||
nonce_mv[:3] = b"\xFF\xFF\xFF"
|
||||
cipher2.update(header_mv)
|
||||
header_mv[:3] = b"\xFF\xFF\xFF"
|
||||
ct_test = cipher2.encrypt(data_mv)
|
||||
data_mv[:3] = b"\xFF\xFF\xFF"
|
||||
tag_test = cipher2.digest()
|
||||
|
||||
self.assertEqual(ref3, ref4)
|
||||
self.assertEqual(ct, ct_test)
|
||||
self.assertEqual(tag, tag_test)
|
||||
self.assertEqual(cipher1.nonce, cipher2.nonce)
|
||||
|
||||
cipher3.verify(bytearray(tag))
|
||||
# Decrypt
|
||||
key_mv = memoryview(bytearray(self.key_128))
|
||||
nonce_mv = memoryview(bytearray(self.nonce_96))
|
||||
header_mv = memoryview(bytearray(self.data_128))
|
||||
del data_mv
|
||||
|
||||
cipher4 = AES.new(key_mv,
|
||||
AES.MODE_CCM,
|
||||
nonce=nonce_mv)
|
||||
key_mv[:3] = b"\xFF\xFF\xFF"
|
||||
nonce_mv[:3] = b"\xFF\xFF\xFF"
|
||||
cipher4.update(header_mv)
|
||||
header_mv[:3] = b"\xFF\xFF\xFF"
|
||||
pt_test = cipher4.decrypt_and_verify(memoryview(ct_test), memoryview(tag_test))
|
||||
|
||||
self.assertEqual(self.data_128, pt_test)
|
||||
|
||||
import types
|
||||
if _memoryview is types.NoneType:
|
||||
del test_memoryview
|
||||
|
||||
|
||||
class CcmFSMTests(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -134,14 +134,28 @@ else:
|
|||
_memoryview = memoryview
|
||||
|
||||
|
||||
def _copy_bytes(start, seq):
|
||||
"""Return a copy of a sequence (byte string, byte array, memoryview)
|
||||
starting from a certain index"""
|
||||
def _copy_bytes(start, end, seq):
|
||||
"""Return an immutable copy of a sequence (byte string, byte array, memoryview)
|
||||
in a certain interval [start:seq]"""
|
||||
|
||||
if isinstance(seq, _memoryview):
|
||||
return seq[start:].tobytes()
|
||||
return seq[start:end].tobytes()
|
||||
elif isinstance(seq, bytearray):
|
||||
return bytes(seq[start:end])
|
||||
else:
|
||||
return seq[start:]
|
||||
return seq[start:end]
|
||||
|
||||
|
||||
def _is_immutable(data):
|
||||
if byte_string(data):
|
||||
return True
|
||||
elif isinstance(data, _memoryview) and data.readonly:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_mutable(data):
|
||||
return not _is_immutable(data)
|
||||
|
||||
|
||||
del sys
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue