mirror of
https://github.com/Legrandin/pycryptodome.git
synced 2025-12-08 05:19:46 +00:00
Support for bytearrays and memoryviews in CMAC
This commit is contained in:
parent
198bc33333
commit
e2f83b71ad
2 changed files with 71 additions and 8 deletions
|
|
@ -20,7 +20,9 @@
|
|||
# SOFTWARE.
|
||||
# ===================================================================
|
||||
|
||||
from Crypto.Util.py3compat import b, bchr, bord, tobytes
|
||||
import sys
|
||||
|
||||
from Crypto.Util.py3compat import bord, tobytes
|
||||
|
||||
from binascii import unhexlify
|
||||
|
||||
|
|
@ -33,6 +35,12 @@ from Crypto.Random import get_random_bytes
|
|||
digest_size = None
|
||||
|
||||
|
||||
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
|
||||
_memoryview = None
|
||||
else:
|
||||
_memoryview = memoryview
|
||||
|
||||
|
||||
def _shift_bytes(bs, xor_lsb=0):
|
||||
num = (bytes_to_long(bs) << 1) ^ xor_lsb
|
||||
return long_to_bytes(num, len(bs))[-len(bs):]
|
||||
|
|
@ -77,7 +85,7 @@ class CMAC(object):
|
|||
self._mac_tag = None
|
||||
|
||||
# Compute sub-keys
|
||||
zero_block = bchr(0) * ciphermod.block_size
|
||||
zero_block = b'\x00' * ciphermod.block_size
|
||||
cipher = ciphermod.new(key,
|
||||
ciphermod.MODE_ECB,
|
||||
**self._cipher_params)
|
||||
|
|
@ -98,7 +106,7 @@ class CMAC(object):
|
|||
**self._cipher_params)
|
||||
|
||||
# Cache for outstanding data to authenticate
|
||||
self._cache = b("")
|
||||
self._cache = b""
|
||||
|
||||
# Last two pieces of ciphertext produced
|
||||
self._last_ct = self._last_pt = zero_block
|
||||
|
|
@ -128,16 +136,19 @@ class CMAC(object):
|
|||
|
||||
msg = msg[filler:]
|
||||
self._update(self._cache)
|
||||
self._cache = b("")
|
||||
self._cache = b""
|
||||
|
||||
update_len, remain = divmod(len(msg), self.digest_size)
|
||||
update_len *= self.digest_size
|
||||
if remain > 0:
|
||||
self._update(msg[:update_len])
|
||||
if isinstance(msg, _memoryview):
|
||||
self._cache = msg[update_len:].tobytes()
|
||||
else:
|
||||
self._cache = msg[update_len:]
|
||||
else:
|
||||
self._update(msg)
|
||||
self._cache = b("")
|
||||
self._cache = b""
|
||||
return self
|
||||
|
||||
def _update(self, data_block):
|
||||
|
|
@ -201,8 +212,8 @@ class CMAC(object):
|
|||
pt = strxor(strxor(self._before_last_ct, self._k1), self._last_pt)
|
||||
else:
|
||||
# Last block is partial (or message length is zero)
|
||||
ext = self._cache + bchr(0x80) +\
|
||||
bchr(0) * (self.digest_size - len(self._cache) - 1)
|
||||
ext = self._cache + b'\x80' +\
|
||||
b'\x00' * (self.digest_size - len(self._cache) - 1)
|
||||
pt = strxor(strxor(self._last_ct, self._k2), ext)
|
||||
|
||||
cipher = self._factory.new(self._key,
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@
|
|||
|
||||
"""Self-test suite for Crypto.Hash.CMAC"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from Crypto.Util.py3compat import tobytes
|
||||
|
|
@ -266,6 +267,54 @@ class MultipleUpdates(unittest.TestCase):
|
|||
self.assertEqual(ref_mac, mac.digest())
|
||||
|
||||
|
||||
class ByteArrayTests(unittest.TestCase):
|
||||
|
||||
def runTest(self):
|
||||
|
||||
key = b"0" * 16
|
||||
data = b"\x00\x01\x02"
|
||||
|
||||
key_ba = bytearray(key)
|
||||
data_ba = bytearray(data)
|
||||
|
||||
# Data and key can be a bytearray (during initialization)
|
||||
h1 = CMAC.new(key, data, ciphermod=AES)
|
||||
h2 = CMAC.new(key_ba, data_ba, ciphermod=AES)
|
||||
self.assertEqual(h1.digest(), h2.digest())
|
||||
|
||||
# Data can be a bytearray (during operation)
|
||||
h1 = CMAC.new(key, ciphermod=AES)
|
||||
h2 = CMAC.new(key, ciphermod=AES)
|
||||
h1.update(data)
|
||||
h2.update(data_ba)
|
||||
self.assertEqual(h1.digest(), h2.digest())
|
||||
|
||||
|
||||
class MemoryViewTests(unittest.TestCase):
|
||||
|
||||
def runTest(self):
|
||||
|
||||
key = b"0" * 16
|
||||
data = b"\x00\x01\x02"
|
||||
|
||||
mv_ro = [ memoryview(x) for x in (key, data) ]
|
||||
mv_rw = [ memoryview(bytearray(x)) for x in (key, data) ]
|
||||
|
||||
for key_mv, data_mv in (mv_ro, mv_rw):
|
||||
|
||||
# Data and key can be a memoryview (during initialization)
|
||||
h1 = CMAC.new(key, data, ciphermod=AES)
|
||||
h2 = CMAC.new(key_mv, data_mv, ciphermod=AES)
|
||||
self.assertEqual(h1.digest(), h2.digest())
|
||||
|
||||
# Data can be a memoryview (during operation)
|
||||
h1 = CMAC.new(key, ciphermod=AES)
|
||||
h2 = CMAC.new(key, ciphermod=AES)
|
||||
h1.update(data)
|
||||
h2.update(data_mv)
|
||||
self.assertEqual(h1.digest(), h2.digest())
|
||||
|
||||
|
||||
def get_tests(config={}):
|
||||
global test_data
|
||||
from common import make_mac_tests
|
||||
|
|
@ -279,6 +328,9 @@ def get_tests(config={}):
|
|||
|
||||
tests = make_mac_tests(CMAC, "CMAC", params_test_data)
|
||||
tests.append(MultipleUpdates())
|
||||
tests.append(ByteArrayTests())
|
||||
if not (sys.version_info[0] == 2 and sys.version_info[1] < 7):
|
||||
tests.append(MemoryViewTests())
|
||||
return tests
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue