diff --git a/lib/Crypto/Hash/CMAC.py b/lib/Crypto/Hash/CMAC.py index 7b87ffcb..152d88a0 100644 --- a/lib/Crypto/Hash/CMAC.py +++ b/lib/Crypto/Hash/CMAC.py @@ -20,7 +20,9 @@ # SOFTWARE. # =================================================================== -from Crypto.Util.py3compat import b, bchr, bord, tobytes +import sys + +from Crypto.Util.py3compat import bord, tobytes from binascii import unhexlify @@ -33,6 +35,12 @@ from Crypto.Random import get_random_bytes digest_size = None +if sys.version_info[0] == 2 and sys.version_info[1] < 7: + _memoryview = None +else: + _memoryview = memoryview + + def _shift_bytes(bs, xor_lsb=0): num = (bytes_to_long(bs) << 1) ^ xor_lsb return long_to_bytes(num, len(bs))[-len(bs):] @@ -77,7 +85,7 @@ class CMAC(object): self._mac_tag = None # Compute sub-keys - zero_block = bchr(0) * ciphermod.block_size + zero_block = b'\x00' * ciphermod.block_size cipher = ciphermod.new(key, ciphermod.MODE_ECB, **self._cipher_params) @@ -98,7 +106,7 @@ class CMAC(object): **self._cipher_params) # Cache for outstanding data to authenticate - self._cache = b("") + self._cache = b"" # Last two pieces of ciphertext produced self._last_ct = self._last_pt = zero_block @@ -128,16 +136,19 @@ class CMAC(object): msg = msg[filler:] self._update(self._cache) - self._cache = b("") + self._cache = b"" update_len, remain = divmod(len(msg), self.digest_size) update_len *= self.digest_size if remain > 0: self._update(msg[:update_len]) - self._cache = msg[update_len:] + if isinstance(msg, _memoryview): + self._cache = msg[update_len:].tobytes() + else: + self._cache = msg[update_len:] else: self._update(msg) - self._cache = b("") + self._cache = b"" return self def _update(self, data_block): @@ -201,8 +212,8 @@ class CMAC(object): pt = strxor(strxor(self._before_last_ct, self._k1), self._last_pt) else: # Last block is partial (or message length is zero) - ext = self._cache + bchr(0x80) +\ - bchr(0) * (self.digest_size - len(self._cache) - 1) + ext = self._cache + b'\x80' +\ + b'\x00' * (self.digest_size - len(self._cache) - 1) pt = strxor(strxor(self._last_ct, self._k2), ext) cipher = self._factory.new(self._key, diff --git a/lib/Crypto/SelfTest/Hash/test_CMAC.py b/lib/Crypto/SelfTest/Hash/test_CMAC.py index fe320759..a785b313 100644 --- a/lib/Crypto/SelfTest/Hash/test_CMAC.py +++ b/lib/Crypto/SelfTest/Hash/test_CMAC.py @@ -33,6 +33,7 @@ """Self-test suite for Crypto.Hash.CMAC""" +import sys import unittest from Crypto.Util.py3compat import tobytes @@ -266,6 +267,54 @@ class MultipleUpdates(unittest.TestCase): self.assertEqual(ref_mac, mac.digest()) +class ByteArrayTests(unittest.TestCase): + + def runTest(self): + + key = b"0" * 16 + data = b"\x00\x01\x02" + + key_ba = bytearray(key) + data_ba = bytearray(data) + + # Data and key can be a bytearray (during initialization) + h1 = CMAC.new(key, data, ciphermod=AES) + h2 = CMAC.new(key_ba, data_ba, ciphermod=AES) + self.assertEqual(h1.digest(), h2.digest()) + + # Data can be a bytearray (during operation) + h1 = CMAC.new(key, ciphermod=AES) + h2 = CMAC.new(key, ciphermod=AES) + h1.update(data) + h2.update(data_ba) + self.assertEqual(h1.digest(), h2.digest()) + + +class MemoryViewTests(unittest.TestCase): + + def runTest(self): + + key = b"0" * 16 + data = b"\x00\x01\x02" + + mv_ro = [ memoryview(x) for x in (key, data) ] + mv_rw = [ memoryview(bytearray(x)) for x in (key, data) ] + + for key_mv, data_mv in (mv_ro, mv_rw): + + # Data and key can be a memoryview (during initialization) + h1 = CMAC.new(key, data, ciphermod=AES) + h2 = CMAC.new(key_mv, data_mv, ciphermod=AES) + self.assertEqual(h1.digest(), h2.digest()) + + # Data can be a memoryview (during operation) + h1 = CMAC.new(key, ciphermod=AES) + h2 = CMAC.new(key, ciphermod=AES) + h1.update(data) + h2.update(data_mv) + self.assertEqual(h1.digest(), h2.digest()) + + def get_tests(config={}): global test_data from common import make_mac_tests @@ -279,6 +328,9 @@ def get_tests(config={}): tests = make_mac_tests(CMAC, "CMAC", params_test_data) tests.append(MultipleUpdates()) + tests.append(ByteArrayTests()) + if not (sys.version_info[0] == 2 and sys.version_info[1] < 7): + tests.append(MemoryViewTests()) return tests