GH#238: fix for incorrect CMAC after copying object

This commit is contained in:
Helder Eijs 2018-11-17 08:44:12 +01:00
parent 48b6a40be0
commit dfb0b5840e
2 changed files with 82 additions and 81 deletions

View file

@ -20,8 +20,6 @@
# SOFTWARE.
# ===================================================================
import sys
from Crypto.Util.py3compat import bord, tobytes, _copy_bytes
from binascii import unhexlify
@ -52,48 +50,35 @@ class CMAC(object):
def __init__(self, key, msg, ciphermod, cipher_params, mac_len):
if ciphermod is None:
raise TypeError("ciphermod must be specified (try AES)")
self.digest_size = mac_len
self._key = _copy_bytes(None, None, key)
self._factory = ciphermod
if cipher_params is None:
self._cipher_params = {}
else:
self._cipher_params = dict(cipher_params)
self._mac_len = mac_len or ciphermod.block_size
if self._mac_len < 4:
raise ValueError("MAC tag length must be at least 4 bytes long")
if self._mac_len > ciphermod.block_size:
raise ValueError("MAC tag length cannot be larger than a cipher block (%d) bytes" % ciphermod.block_size)
self._cipher_params = cipher_params
self._block_size = bs = ciphermod.block_size
self._mac_tag = None
# Section 5.3 of NIST SP 800 38B and Appendix B
if ciphermod.block_size == 8:
if bs == 8:
const_Rb = 0x1B
self._max_size = 8 * (2 ** 21)
elif ciphermod.block_size == 16:
elif bs == 16:
const_Rb = 0x87
self._max_size = 16 * (2 ** 48)
else:
raise TypeError("CMAC requires a cipher with a block size"
"of 8 or 16 bytes, not %d" %
(ciphermod.block_size,))
# Size of the final MAC tag, in bytes
self.digest_size = ciphermod.block_size
self._mac_tag = None
" of 8 or 16 bytes, not %d" % bs)
# Compute sub-keys
zero_block = b'\x00' * ciphermod.block_size
cipher = ciphermod.new(key,
ciphermod.MODE_ECB,
**self._cipher_params)
l = cipher.encrypt(zero_block)
if bord(l[0]) & 0x80:
self._k1 = _shift_bytes(l, const_Rb)
zero_block = b'\x00' * bs
self._ecb = ciphermod.new(key,
ciphermod.MODE_ECB,
**self._cipher_params)
L = self._ecb.encrypt(zero_block)
if bord(L[0]) & 0x80:
self._k1 = _shift_bytes(L, const_Rb)
else:
self._k1 = _shift_bytes(l)
self._k1 = _shift_bytes(L)
if bord(self._k1[0]) & 0x80:
self._k2 = _shift_bytes(self._k1, const_Rb)
else:
@ -106,11 +91,14 @@ class CMAC(object):
**self._cipher_params)
# Cache for outstanding data to authenticate
self._cache = b""
self._cache = bytearray(bs)
self._cache_n = 0
# Last two pieces of ciphertext produced
self._last_ct = self._last_pt = zero_block
self._before_last_ct = None
# Last piece of ciphertext produced
self._last_ct = zero_block
# Last block that was encrypted with AES
self._last_pt = None
# Counter for total message size
self._data_size = 0
@ -125,49 +113,48 @@ class CMAC(object):
data (byte string/byte array/memoryview): The next chunk of data
"""
# Mutable values must be copied if cached
self._data_size += len(msg)
bs = self._block_size
if len(self._cache) > 0:
filler = min(self.digest_size - len(self._cache), len(msg))
self._cache += msg[:filler]
if self._cache_n > 0:
filler = min(bs - self._cache_n, len(msg))
self._cache[self._cache_n:self._cache_n+filler] = msg[:filler]
self._cache_n += filler
if len(self._cache) < self.digest_size:
if self._cache_n < bs:
return self
msg = msg[filler:]
msg = memoryview(msg)[filler:]
self._update(self._cache)
self._cache = b""
self._cache_n = 0
update_len, remain = divmod(len(msg), self.digest_size)
update_len *= self.digest_size
remain = len(msg) % bs
if remain > 0:
self._update(msg[:update_len])
self._cache = _copy_bytes(update_len, None, msg)
self._update(msg[:-remain])
self._cache[:remain] = msg[-remain:]
else:
self._update(msg)
self._cache = b""
self._cache_n = remain
return self
def _update(self, data_block):
"""Update a block aligned to the block boundary"""
bs = self._block_size
assert len(data_block) % bs == 0
if len(data_block) == 0:
return
assert len(data_block) % self.digest_size == 0
ct = self._cbc.encrypt(data_block)
if len(data_block) == self.digest_size:
self._before_last_ct = self._last_ct
if len(data_block) == bs:
second_last = self._last_ct
assert len(second_last) == bs
else:
self._before_last_ct = ct[-self.digest_size * 2:-self.digest_size]
self._last_ct = ct[-self.digest_size:]
# data_block can mutable
self._last_pt = _copy_bytes(-self.digest_size, None, data_block)
second_last = ct[-bs*2:-bs]
assert len(second_last) == bs
self._last_ct = ct[-bs:]
self._last_pt = strxor(second_last, data_block[-bs:])
def copy(self):
"""Return a copy ("clone") of the CMAC object.
@ -180,19 +167,14 @@ class CMAC(object):
:return: An :class:`CMAC`
"""
obj = CMAC(self._key,
b"",
ciphermod=self._factory,
cipher_params=self._cipher_params,
mac_len=self._mac_len)
obj = self.__new__(CMAC)
obj.__dict__ = self.__dict__.copy()
obj._cbc = self._factory.new(self._key,
self._factory.MODE_CBC,
self._last_ct,
**self._cipher_params)
for m in ['_mac_tag', '_last_ct', '_before_last_ct', '_cache',
'_data_size', '_max_size']:
setattr(obj, m, getattr(self, m))
obj._cache = self._cache[:]
obj._last_ct = self._last_ct[:]
return obj
def digest(self):
@ -204,27 +186,25 @@ class CMAC(object):
:rtype: byte string
"""
bs = self._block_size
if self._mac_tag is not None:
return self._mac_tag[:self._mac_len]
return self._mac_tag
if self._data_size > self._max_size:
raise ValueError("MAC is unsafe for this message")
if len(self._cache) == 0 and self._before_last_ct is not None:
if self._cache_n == 0 and self._data_size > 0:
# Last block was full
pt = strxor(strxor(self._before_last_ct, self._k1), self._last_pt)
pt = strxor(self._last_pt, self._k1)
else:
# Last block is partial (or message length is zero)
ext = self._cache + b'\x80' +\
b'\x00' * (self.digest_size - len(self._cache) - 1)
pt = strxor(strxor(self._last_ct, self._k2), ext)
self._cache[self._cache_n:] = b'\x80' + b'\x00' * (bs - self._cache_n - 1)
pt = strxor(strxor(self._last_ct, self._cache), self._k2)
cipher = self._factory.new(self._key,
self._factory.MODE_ECB,
**self._cipher_params)
self._mac_tag = cipher.encrypt(pt)
self._mac_tag = self._ecb.encrypt(pt)[:self.digest_size]
return self._mac_tag[:self._mac_len]
return self._mac_tag
def hexdigest(self):
"""Return the **printable** MAC tag of the message authenticated so far.
@ -296,4 +276,18 @@ def new(key, msg=None, ciphermod=None, cipher_params=None, mac_len=None):
A :class:`CMAC` object
"""
if ciphermod is None:
raise TypeError("ciphermod must be specified (try AES)")
cipher_params = {} if cipher_params is None else dict(cipher_params)
if mac_len is None:
mac_len = ciphermod.block_size
if mac_len < 4:
raise ValueError("MAC tag length must be at least 4 bytes long")
if mac_len > ciphermod.block_size:
raise ValueError("MAC tag length cannot be larger than a cipher block (%d) bytes" % ciphermod.block_size)
return CMAC(key, msg, ciphermod, cipher_params, mac_len)