Support for bytearrays and memoryviews in CMAC

This commit is contained in:
Helder Eijs 2018-03-30 23:12:57 +02:00
parent 198bc33333
commit e2f83b71ad
2 changed files with 71 additions and 8 deletions

View file

@ -20,7 +20,9 @@
# SOFTWARE.
# ===================================================================
from Crypto.Util.py3compat import b, bchr, bord, tobytes
import sys
from Crypto.Util.py3compat import bord, tobytes
from binascii import unhexlify
@ -33,6 +35,12 @@ from Crypto.Random import get_random_bytes
digest_size = None
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
_memoryview = None
else:
_memoryview = memoryview
def _shift_bytes(bs, xor_lsb=0):
num = (bytes_to_long(bs) << 1) ^ xor_lsb
return long_to_bytes(num, len(bs))[-len(bs):]
@ -77,7 +85,7 @@ class CMAC(object):
self._mac_tag = None
# Compute sub-keys
zero_block = bchr(0) * ciphermod.block_size
zero_block = b'\x00' * ciphermod.block_size
cipher = ciphermod.new(key,
ciphermod.MODE_ECB,
**self._cipher_params)
@ -98,7 +106,7 @@ class CMAC(object):
**self._cipher_params)
# Cache for outstanding data to authenticate
self._cache = b("")
self._cache = b""
# Last two pieces of ciphertext produced
self._last_ct = self._last_pt = zero_block
@ -128,16 +136,19 @@ class CMAC(object):
msg = msg[filler:]
self._update(self._cache)
self._cache = b("")
self._cache = b""
update_len, remain = divmod(len(msg), self.digest_size)
update_len *= self.digest_size
if remain > 0:
self._update(msg[:update_len])
self._cache = msg[update_len:]
if isinstance(msg, _memoryview):
self._cache = msg[update_len:].tobytes()
else:
self._cache = msg[update_len:]
else:
self._update(msg)
self._cache = b("")
self._cache = b""
return self
def _update(self, data_block):
@ -201,8 +212,8 @@ class CMAC(object):
pt = strxor(strxor(self._before_last_ct, self._k1), self._last_pt)
else:
# Last block is partial (or message length is zero)
ext = self._cache + bchr(0x80) +\
bchr(0) * (self.digest_size - len(self._cache) - 1)
ext = self._cache + b'\x80' +\
b'\x00' * (self.digest_size - len(self._cache) - 1)
pt = strxor(strxor(self._last_ct, self._k2), ext)
cipher = self._factory.new(self._key,