mirror of
https://github.com/Legrandin/pycryptodome.git
synced 2025-12-08 05:19:46 +00:00
Support for bytearrays and memoryviews in CMAC
This commit is contained in:
parent
198bc33333
commit
e2f83b71ad
2 changed files with 71 additions and 8 deletions
|
|
@ -20,7 +20,9 @@
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
# ===================================================================
|
# ===================================================================
|
||||||
|
|
||||||
from Crypto.Util.py3compat import b, bchr, bord, tobytes
|
import sys
|
||||||
|
|
||||||
|
from Crypto.Util.py3compat import bord, tobytes
|
||||||
|
|
||||||
from binascii import unhexlify
|
from binascii import unhexlify
|
||||||
|
|
||||||
|
|
@ -33,6 +35,12 @@ from Crypto.Random import get_random_bytes
|
||||||
digest_size = None
|
digest_size = None
|
||||||
|
|
||||||
|
|
||||||
|
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
|
||||||
|
_memoryview = None
|
||||||
|
else:
|
||||||
|
_memoryview = memoryview
|
||||||
|
|
||||||
|
|
||||||
def _shift_bytes(bs, xor_lsb=0):
|
def _shift_bytes(bs, xor_lsb=0):
|
||||||
num = (bytes_to_long(bs) << 1) ^ xor_lsb
|
num = (bytes_to_long(bs) << 1) ^ xor_lsb
|
||||||
return long_to_bytes(num, len(bs))[-len(bs):]
|
return long_to_bytes(num, len(bs))[-len(bs):]
|
||||||
|
|
@ -77,7 +85,7 @@ class CMAC(object):
|
||||||
self._mac_tag = None
|
self._mac_tag = None
|
||||||
|
|
||||||
# Compute sub-keys
|
# Compute sub-keys
|
||||||
zero_block = bchr(0) * ciphermod.block_size
|
zero_block = b'\x00' * ciphermod.block_size
|
||||||
cipher = ciphermod.new(key,
|
cipher = ciphermod.new(key,
|
||||||
ciphermod.MODE_ECB,
|
ciphermod.MODE_ECB,
|
||||||
**self._cipher_params)
|
**self._cipher_params)
|
||||||
|
|
@ -98,7 +106,7 @@ class CMAC(object):
|
||||||
**self._cipher_params)
|
**self._cipher_params)
|
||||||
|
|
||||||
# Cache for outstanding data to authenticate
|
# Cache for outstanding data to authenticate
|
||||||
self._cache = b("")
|
self._cache = b""
|
||||||
|
|
||||||
# Last two pieces of ciphertext produced
|
# Last two pieces of ciphertext produced
|
||||||
self._last_ct = self._last_pt = zero_block
|
self._last_ct = self._last_pt = zero_block
|
||||||
|
|
@ -128,16 +136,19 @@ class CMAC(object):
|
||||||
|
|
||||||
msg = msg[filler:]
|
msg = msg[filler:]
|
||||||
self._update(self._cache)
|
self._update(self._cache)
|
||||||
self._cache = b("")
|
self._cache = b""
|
||||||
|
|
||||||
update_len, remain = divmod(len(msg), self.digest_size)
|
update_len, remain = divmod(len(msg), self.digest_size)
|
||||||
update_len *= self.digest_size
|
update_len *= self.digest_size
|
||||||
if remain > 0:
|
if remain > 0:
|
||||||
self._update(msg[:update_len])
|
self._update(msg[:update_len])
|
||||||
|
if isinstance(msg, _memoryview):
|
||||||
|
self._cache = msg[update_len:].tobytes()
|
||||||
|
else:
|
||||||
self._cache = msg[update_len:]
|
self._cache = msg[update_len:]
|
||||||
else:
|
else:
|
||||||
self._update(msg)
|
self._update(msg)
|
||||||
self._cache = b("")
|
self._cache = b""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _update(self, data_block):
|
def _update(self, data_block):
|
||||||
|
|
@ -201,8 +212,8 @@ class CMAC(object):
|
||||||
pt = strxor(strxor(self._before_last_ct, self._k1), self._last_pt)
|
pt = strxor(strxor(self._before_last_ct, self._k1), self._last_pt)
|
||||||
else:
|
else:
|
||||||
# Last block is partial (or message length is zero)
|
# Last block is partial (or message length is zero)
|
||||||
ext = self._cache + bchr(0x80) +\
|
ext = self._cache + b'\x80' +\
|
||||||
bchr(0) * (self.digest_size - len(self._cache) - 1)
|
b'\x00' * (self.digest_size - len(self._cache) - 1)
|
||||||
pt = strxor(strxor(self._last_ct, self._k2), ext)
|
pt = strxor(strxor(self._last_ct, self._k2), ext)
|
||||||
|
|
||||||
cipher = self._factory.new(self._key,
|
cipher = self._factory.new(self._key,
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@
|
||||||
|
|
||||||
"""Self-test suite for Crypto.Hash.CMAC"""
|
"""Self-test suite for Crypto.Hash.CMAC"""
|
||||||
|
|
||||||
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from Crypto.Util.py3compat import tobytes
|
from Crypto.Util.py3compat import tobytes
|
||||||
|
|
@ -266,6 +267,54 @@ class MultipleUpdates(unittest.TestCase):
|
||||||
self.assertEqual(ref_mac, mac.digest())
|
self.assertEqual(ref_mac, mac.digest())
|
||||||
|
|
||||||
|
|
||||||
|
class ByteArrayTests(unittest.TestCase):
|
||||||
|
|
||||||
|
def runTest(self):
|
||||||
|
|
||||||
|
key = b"0" * 16
|
||||||
|
data = b"\x00\x01\x02"
|
||||||
|
|
||||||
|
key_ba = bytearray(key)
|
||||||
|
data_ba = bytearray(data)
|
||||||
|
|
||||||
|
# Data and key can be a bytearray (during initialization)
|
||||||
|
h1 = CMAC.new(key, data, ciphermod=AES)
|
||||||
|
h2 = CMAC.new(key_ba, data_ba, ciphermod=AES)
|
||||||
|
self.assertEqual(h1.digest(), h2.digest())
|
||||||
|
|
||||||
|
# Data can be a bytearray (during operation)
|
||||||
|
h1 = CMAC.new(key, ciphermod=AES)
|
||||||
|
h2 = CMAC.new(key, ciphermod=AES)
|
||||||
|
h1.update(data)
|
||||||
|
h2.update(data_ba)
|
||||||
|
self.assertEqual(h1.digest(), h2.digest())
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryViewTests(unittest.TestCase):
|
||||||
|
|
||||||
|
def runTest(self):
|
||||||
|
|
||||||
|
key = b"0" * 16
|
||||||
|
data = b"\x00\x01\x02"
|
||||||
|
|
||||||
|
mv_ro = [ memoryview(x) for x in (key, data) ]
|
||||||
|
mv_rw = [ memoryview(bytearray(x)) for x in (key, data) ]
|
||||||
|
|
||||||
|
for key_mv, data_mv in (mv_ro, mv_rw):
|
||||||
|
|
||||||
|
# Data and key can be a memoryview (during initialization)
|
||||||
|
h1 = CMAC.new(key, data, ciphermod=AES)
|
||||||
|
h2 = CMAC.new(key_mv, data_mv, ciphermod=AES)
|
||||||
|
self.assertEqual(h1.digest(), h2.digest())
|
||||||
|
|
||||||
|
# Data can be a memoryview (during operation)
|
||||||
|
h1 = CMAC.new(key, ciphermod=AES)
|
||||||
|
h2 = CMAC.new(key, ciphermod=AES)
|
||||||
|
h1.update(data)
|
||||||
|
h2.update(data_mv)
|
||||||
|
self.assertEqual(h1.digest(), h2.digest())
|
||||||
|
|
||||||
|
|
||||||
def get_tests(config={}):
|
def get_tests(config={}):
|
||||||
global test_data
|
global test_data
|
||||||
from common import make_mac_tests
|
from common import make_mac_tests
|
||||||
|
|
@ -279,6 +328,9 @@ def get_tests(config={}):
|
||||||
|
|
||||||
tests = make_mac_tests(CMAC, "CMAC", params_test_data)
|
tests = make_mac_tests(CMAC, "CMAC", params_test_data)
|
||||||
tests.append(MultipleUpdates())
|
tests.append(MultipleUpdates())
|
||||||
|
tests.append(ByteArrayTests())
|
||||||
|
if not (sys.version_info[0] == 2 and sys.version_info[1] < 7):
|
||||||
|
tests.append(MemoryViewTests())
|
||||||
return tests
|
return tests
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue