Support for bytearrays and memoryviews in CMAC

This commit is contained in:
Helder Eijs 2018-03-30 23:12:57 +02:00
parent 198bc33333
commit e2f83b71ad
2 changed files with 71 additions and 8 deletions

View file

@ -20,7 +20,9 @@
# SOFTWARE.
# ===================================================================
from Crypto.Util.py3compat import b, bchr, bord, tobytes
import sys
from Crypto.Util.py3compat import bord, tobytes
from binascii import unhexlify
@ -33,6 +35,12 @@ from Crypto.Random import get_random_bytes
digest_size = None
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
_memoryview = None
else:
_memoryview = memoryview
def _shift_bytes(bs, xor_lsb=0):
num = (bytes_to_long(bs) << 1) ^ xor_lsb
return long_to_bytes(num, len(bs))[-len(bs):]
@ -77,7 +85,7 @@ class CMAC(object):
self._mac_tag = None
# Compute sub-keys
zero_block = bchr(0) * ciphermod.block_size
zero_block = b'\x00' * ciphermod.block_size
cipher = ciphermod.new(key,
ciphermod.MODE_ECB,
**self._cipher_params)
@ -98,7 +106,7 @@ class CMAC(object):
**self._cipher_params)
# Cache for outstanding data to authenticate
self._cache = b("")
self._cache = b""
# Last two pieces of ciphertext produced
self._last_ct = self._last_pt = zero_block
@ -128,16 +136,19 @@ class CMAC(object):
msg = msg[filler:]
self._update(self._cache)
self._cache = b("")
self._cache = b""
update_len, remain = divmod(len(msg), self.digest_size)
update_len *= self.digest_size
if remain > 0:
self._update(msg[:update_len])
if isinstance(msg, _memoryview):
self._cache = msg[update_len:].tobytes()
else:
self._cache = msg[update_len:]
else:
self._update(msg)
self._cache = b("")
self._cache = b""
return self
def _update(self, data_block):
@ -201,8 +212,8 @@ class CMAC(object):
pt = strxor(strxor(self._before_last_ct, self._k1), self._last_pt)
else:
# Last block is partial (or message length is zero)
ext = self._cache + bchr(0x80) +\
bchr(0) * (self.digest_size - len(self._cache) - 1)
ext = self._cache + b'\x80' +\
b'\x00' * (self.digest_size - len(self._cache) - 1)
pt = strxor(strxor(self._last_ct, self._k2), ext)
cipher = self._factory.new(self._key,

View file

@ -33,6 +33,7 @@
"""Self-test suite for Crypto.Hash.CMAC"""
import sys
import unittest
from Crypto.Util.py3compat import tobytes
@ -266,6 +267,54 @@ class MultipleUpdates(unittest.TestCase):
self.assertEqual(ref_mac, mac.digest())
class ByteArrayTests(unittest.TestCase):
def runTest(self):
key = b"0" * 16
data = b"\x00\x01\x02"
key_ba = bytearray(key)
data_ba = bytearray(data)
# Data and key can be a bytearray (during initialization)
h1 = CMAC.new(key, data, ciphermod=AES)
h2 = CMAC.new(key_ba, data_ba, ciphermod=AES)
self.assertEqual(h1.digest(), h2.digest())
# Data can be a bytearray (during operation)
h1 = CMAC.new(key, ciphermod=AES)
h2 = CMAC.new(key, ciphermod=AES)
h1.update(data)
h2.update(data_ba)
self.assertEqual(h1.digest(), h2.digest())
class MemoryViewTests(unittest.TestCase):
def runTest(self):
key = b"0" * 16
data = b"\x00\x01\x02"
mv_ro = [ memoryview(x) for x in (key, data) ]
mv_rw = [ memoryview(bytearray(x)) for x in (key, data) ]
for key_mv, data_mv in (mv_ro, mv_rw):
# Data and key can be a memoryview (during initialization)
h1 = CMAC.new(key, data, ciphermod=AES)
h2 = CMAC.new(key_mv, data_mv, ciphermod=AES)
self.assertEqual(h1.digest(), h2.digest())
# Data can be a memoryview (during operation)
h1 = CMAC.new(key, ciphermod=AES)
h2 = CMAC.new(key, ciphermod=AES)
h1.update(data)
h2.update(data_mv)
self.assertEqual(h1.digest(), h2.digest())
def get_tests(config={}):
global test_data
from common import make_mac_tests
@ -279,6 +328,9 @@ def get_tests(config={}):
tests = make_mac_tests(CMAC, "CMAC", params_test_data)
tests.append(MultipleUpdates())
tests.append(ByteArrayTests())
if not (sys.version_info[0] == 2 and sys.version_info[1] < 7):
tests.append(MemoryViewTests())
return tests