mirror of
https://github.com/Legrandin/pycryptodome.git
synced 2025-12-08 05:19:46 +00:00
GH#238: fix for incorrect CMAC after copying object
This commit is contained in:
parent
48b6a40be0
commit
dfb0b5840e
2 changed files with 82 additions and 81 deletions
|
|
@ -20,8 +20,6 @@
|
|||
# SOFTWARE.
|
||||
# ===================================================================
|
||||
|
||||
import sys
|
||||
|
||||
from Crypto.Util.py3compat import bord, tobytes, _copy_bytes
|
||||
|
||||
from binascii import unhexlify
|
||||
|
|
@ -52,48 +50,35 @@ class CMAC(object):
|
|||
|
||||
def __init__(self, key, msg, ciphermod, cipher_params, mac_len):
|
||||
|
||||
if ciphermod is None:
|
||||
raise TypeError("ciphermod must be specified (try AES)")
|
||||
self.digest_size = mac_len
|
||||
|
||||
self._key = _copy_bytes(None, None, key)
|
||||
self._factory = ciphermod
|
||||
if cipher_params is None:
|
||||
self._cipher_params = {}
|
||||
else:
|
||||
self._cipher_params = dict(cipher_params)
|
||||
self._mac_len = mac_len or ciphermod.block_size
|
||||
|
||||
if self._mac_len < 4:
|
||||
raise ValueError("MAC tag length must be at least 4 bytes long")
|
||||
if self._mac_len > ciphermod.block_size:
|
||||
raise ValueError("MAC tag length cannot be larger than a cipher block (%d) bytes" % ciphermod.block_size)
|
||||
self._cipher_params = cipher_params
|
||||
self._block_size = bs = ciphermod.block_size
|
||||
self._mac_tag = None
|
||||
|
||||
# Section 5.3 of NIST SP 800 38B and Appendix B
|
||||
if ciphermod.block_size == 8:
|
||||
if bs == 8:
|
||||
const_Rb = 0x1B
|
||||
self._max_size = 8 * (2 ** 21)
|
||||
elif ciphermod.block_size == 16:
|
||||
elif bs == 16:
|
||||
const_Rb = 0x87
|
||||
self._max_size = 16 * (2 ** 48)
|
||||
else:
|
||||
raise TypeError("CMAC requires a cipher with a block size"
|
||||
"of 8 or 16 bytes, not %d" %
|
||||
(ciphermod.block_size,))
|
||||
|
||||
# Size of the final MAC tag, in bytes
|
||||
self.digest_size = ciphermod.block_size
|
||||
self._mac_tag = None
|
||||
" of 8 or 16 bytes, not %d" % bs)
|
||||
|
||||
# Compute sub-keys
|
||||
zero_block = b'\x00' * ciphermod.block_size
|
||||
cipher = ciphermod.new(key,
|
||||
ciphermod.MODE_ECB,
|
||||
**self._cipher_params)
|
||||
l = cipher.encrypt(zero_block)
|
||||
if bord(l[0]) & 0x80:
|
||||
self._k1 = _shift_bytes(l, const_Rb)
|
||||
zero_block = b'\x00' * bs
|
||||
self._ecb = ciphermod.new(key,
|
||||
ciphermod.MODE_ECB,
|
||||
**self._cipher_params)
|
||||
L = self._ecb.encrypt(zero_block)
|
||||
if bord(L[0]) & 0x80:
|
||||
self._k1 = _shift_bytes(L, const_Rb)
|
||||
else:
|
||||
self._k1 = _shift_bytes(l)
|
||||
self._k1 = _shift_bytes(L)
|
||||
if bord(self._k1[0]) & 0x80:
|
||||
self._k2 = _shift_bytes(self._k1, const_Rb)
|
||||
else:
|
||||
|
|
@ -106,11 +91,14 @@ class CMAC(object):
|
|||
**self._cipher_params)
|
||||
|
||||
# Cache for outstanding data to authenticate
|
||||
self._cache = b""
|
||||
self._cache = bytearray(bs)
|
||||
self._cache_n = 0
|
||||
|
||||
# Last two pieces of ciphertext produced
|
||||
self._last_ct = self._last_pt = zero_block
|
||||
self._before_last_ct = None
|
||||
# Last piece of ciphertext produced
|
||||
self._last_ct = zero_block
|
||||
|
||||
# Last block that was encrypted with AES
|
||||
self._last_pt = None
|
||||
|
||||
# Counter for total message size
|
||||
self._data_size = 0
|
||||
|
|
@ -125,49 +113,48 @@ class CMAC(object):
|
|||
data (byte string/byte array/memoryview): The next chunk of data
|
||||
"""
|
||||
|
||||
# Mutable values must be copied if cached
|
||||
|
||||
self._data_size += len(msg)
|
||||
bs = self._block_size
|
||||
|
||||
if len(self._cache) > 0:
|
||||
filler = min(self.digest_size - len(self._cache), len(msg))
|
||||
self._cache += msg[:filler]
|
||||
if self._cache_n > 0:
|
||||
filler = min(bs - self._cache_n, len(msg))
|
||||
self._cache[self._cache_n:self._cache_n+filler] = msg[:filler]
|
||||
self._cache_n += filler
|
||||
|
||||
if len(self._cache) < self.digest_size:
|
||||
if self._cache_n < bs:
|
||||
return self
|
||||
|
||||
msg = msg[filler:]
|
||||
msg = memoryview(msg)[filler:]
|
||||
self._update(self._cache)
|
||||
self._cache = b""
|
||||
self._cache_n = 0
|
||||
|
||||
update_len, remain = divmod(len(msg), self.digest_size)
|
||||
update_len *= self.digest_size
|
||||
remain = len(msg) % bs
|
||||
if remain > 0:
|
||||
self._update(msg[:update_len])
|
||||
self._cache = _copy_bytes(update_len, None, msg)
|
||||
self._update(msg[:-remain])
|
||||
self._cache[:remain] = msg[-remain:]
|
||||
else:
|
||||
self._update(msg)
|
||||
self._cache = b""
|
||||
self._cache_n = remain
|
||||
return self
|
||||
|
||||
def _update(self, data_block):
|
||||
"""Update a block aligned to the block boundary"""
|
||||
|
||||
bs = self._block_size
|
||||
assert len(data_block) % bs == 0
|
||||
|
||||
if len(data_block) == 0:
|
||||
return
|
||||
|
||||
assert len(data_block) % self.digest_size == 0
|
||||
|
||||
ct = self._cbc.encrypt(data_block)
|
||||
|
||||
if len(data_block) == self.digest_size:
|
||||
self._before_last_ct = self._last_ct
|
||||
if len(data_block) == bs:
|
||||
second_last = self._last_ct
|
||||
assert len(second_last) == bs
|
||||
else:
|
||||
self._before_last_ct = ct[-self.digest_size * 2:-self.digest_size]
|
||||
self._last_ct = ct[-self.digest_size:]
|
||||
|
||||
# data_block can mutable
|
||||
self._last_pt = _copy_bytes(-self.digest_size, None, data_block)
|
||||
second_last = ct[-bs*2:-bs]
|
||||
assert len(second_last) == bs
|
||||
self._last_ct = ct[-bs:]
|
||||
self._last_pt = strxor(second_last, data_block[-bs:])
|
||||
|
||||
def copy(self):
|
||||
"""Return a copy ("clone") of the CMAC object.
|
||||
|
|
@ -180,19 +167,14 @@ class CMAC(object):
|
|||
:return: An :class:`CMAC`
|
||||
"""
|
||||
|
||||
obj = CMAC(self._key,
|
||||
b"",
|
||||
ciphermod=self._factory,
|
||||
cipher_params=self._cipher_params,
|
||||
mac_len=self._mac_len)
|
||||
|
||||
obj = self.__new__(CMAC)
|
||||
obj.__dict__ = self.__dict__.copy()
|
||||
obj._cbc = self._factory.new(self._key,
|
||||
self._factory.MODE_CBC,
|
||||
self._last_ct,
|
||||
**self._cipher_params)
|
||||
for m in ['_mac_tag', '_last_ct', '_before_last_ct', '_cache',
|
||||
'_data_size', '_max_size']:
|
||||
setattr(obj, m, getattr(self, m))
|
||||
obj._cache = self._cache[:]
|
||||
obj._last_ct = self._last_ct[:]
|
||||
return obj
|
||||
|
||||
def digest(self):
|
||||
|
|
@ -204,27 +186,25 @@ class CMAC(object):
|
|||
:rtype: byte string
|
||||
"""
|
||||
|
||||
bs = self._block_size
|
||||
|
||||
if self._mac_tag is not None:
|
||||
return self._mac_tag[:self._mac_len]
|
||||
return self._mac_tag
|
||||
|
||||
if self._data_size > self._max_size:
|
||||
raise ValueError("MAC is unsafe for this message")
|
||||
|
||||
if len(self._cache) == 0 and self._before_last_ct is not None:
|
||||
if self._cache_n == 0 and self._data_size > 0:
|
||||
# Last block was full
|
||||
pt = strxor(strxor(self._before_last_ct, self._k1), self._last_pt)
|
||||
pt = strxor(self._last_pt, self._k1)
|
||||
else:
|
||||
# Last block is partial (or message length is zero)
|
||||
ext = self._cache + b'\x80' +\
|
||||
b'\x00' * (self.digest_size - len(self._cache) - 1)
|
||||
pt = strxor(strxor(self._last_ct, self._k2), ext)
|
||||
self._cache[self._cache_n:] = b'\x80' + b'\x00' * (bs - self._cache_n - 1)
|
||||
pt = strxor(strxor(self._last_ct, self._cache), self._k2)
|
||||
|
||||
cipher = self._factory.new(self._key,
|
||||
self._factory.MODE_ECB,
|
||||
**self._cipher_params)
|
||||
self._mac_tag = cipher.encrypt(pt)
|
||||
self._mac_tag = self._ecb.encrypt(pt)[:self.digest_size]
|
||||
|
||||
return self._mac_tag[:self._mac_len]
|
||||
return self._mac_tag
|
||||
|
||||
def hexdigest(self):
|
||||
"""Return the **printable** MAC tag of the message authenticated so far.
|
||||
|
|
@ -296,4 +276,18 @@ def new(key, msg=None, ciphermod=None, cipher_params=None, mac_len=None):
|
|||
A :class:`CMAC` object
|
||||
"""
|
||||
|
||||
if ciphermod is None:
|
||||
raise TypeError("ciphermod must be specified (try AES)")
|
||||
|
||||
cipher_params = {} if cipher_params is None else dict(cipher_params)
|
||||
|
||||
if mac_len is None:
|
||||
mac_len = ciphermod.block_size
|
||||
|
||||
if mac_len < 4:
|
||||
raise ValueError("MAC tag length must be at least 4 bytes long")
|
||||
|
||||
if mac_len > ciphermod.block_size:
|
||||
raise ValueError("MAC tag length cannot be larger than a cipher block (%d) bytes" % ciphermod.block_size)
|
||||
|
||||
return CMAC(key, msg, ciphermod, cipher_params, mac_len)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue