Add output parameter for SIV mode

This commit is contained in:
Helder Eijs 2018-10-27 22:30:18 +02:00
parent 5d1459ce55
commit 9276aa561c
4 changed files with 176 additions and 28 deletions

View file

@ -330,7 +330,7 @@ class GcmMode(object):
if len_cache > 0: if len_cache > 0:
self._update(b'\x00' * (16 - len_cache)) self._update(b'\x00' * (16 - len_cache))
def encrypt(self, plaintext): def encrypt(self, plaintext, output=None):
"""Encrypt data with the key and the parameters set at initialization. """Encrypt data with the key and the parameters set at initialization.
A cipher object is stateful: once you have encrypted a message A cipher object is stateful: once you have encrypted a message
@ -354,9 +354,13 @@ class GcmMode(object):
plaintext : bytes/bytearray/memoryview plaintext : bytes/bytearray/memoryview
The piece of data to encrypt. The piece of data to encrypt.
It can be of any length. It can be of any length.
:Keywords:
output : bytearray/memoryview
The location where the ciphertext must be written to.
If ``None``, the ciphertext is returned.
:Return: :Return:
the encrypted data, as a byte string. If ``output`` is ``None``, the ciphertext as ``bytes``.
It is as long as *plaintext*. Otherwise, ``None``.
""" """
if self.encrypt not in self._next: if self.encrypt not in self._next:
@ -364,13 +368,13 @@ class GcmMode(object):
" initialization or an update()") " initialization or an update()")
self._next = [self.encrypt, self.digest] self._next = [self.encrypt, self.digest]
ciphertext = self._cipher.encrypt(plaintext) ciphertext = self._cipher.encrypt(plaintext, output=output)
if self._status == MacStatus.PROCESSING_AUTH_DATA: if self._status == MacStatus.PROCESSING_AUTH_DATA:
self._pad_cache_and_update() self._pad_cache_and_update()
self._status = MacStatus.PROCESSING_CIPHERTEXT self._status = MacStatus.PROCESSING_CIPHERTEXT
self._update(ciphertext) self._update(ciphertext if output is None else output)
self._msg_len += len(plaintext) self._msg_len += len(plaintext)
# See NIST SP 800 38D, 5.2.1.1 # See NIST SP 800 38D, 5.2.1.1
@ -379,7 +383,7 @@ class GcmMode(object):
return ciphertext return ciphertext
def decrypt(self, ciphertext): def decrypt(self, ciphertext, output=None):
"""Decrypt data with the key and the parameters set at initialization. """Decrypt data with the key and the parameters set at initialization.
A cipher object is stateful: once you have decrypted a message A cipher object is stateful: once you have decrypted a message
@ -403,8 +407,13 @@ class GcmMode(object):
ciphertext : bytes/bytearray/memoryview ciphertext : bytes/bytearray/memoryview
The piece of data to decrypt. The piece of data to decrypt.
It can be of any length. It can be of any length.
:Keywords:
:Return: the decrypted data (byte string). output : bytearray/memoryview
The location where the plaintext must be written to.
If ``None``, the plaintext is returned.
:Return:
If ``output`` is ``None``, the plaintext as ``bytes``.
Otherwise, ``None``.
""" """
if self.decrypt not in self._next: if self.decrypt not in self._next:
@ -419,7 +428,7 @@ class GcmMode(object):
self._update(ciphertext) self._update(ciphertext)
self._msg_len += len(ciphertext) self._msg_len += len(ciphertext)
return self._cipher.decrypt(ciphertext) return self._cipher.decrypt(ciphertext, output=output)
def digest(self): def digest(self):
"""Compute the *binary* MAC tag in an AEAD mode. """Compute the *binary* MAC tag in an AEAD mode.
@ -512,22 +521,29 @@ class GcmMode(object):
self.verify(unhexlify(hex_mac_tag)) self.verify(unhexlify(hex_mac_tag))
def encrypt_and_digest(self, plaintext): def encrypt_and_digest(self, plaintext, output=None):
"""Perform encrypt() and digest() in one step. """Perform encrypt() and digest() in one step.
:Parameters: :Parameters:
plaintext : bytes/bytearray/memoryview plaintext : bytes/bytearray/memoryview
The piece of data to encrypt. The piece of data to encrypt.
:Keywords:
output : bytearray/memoryview
The location where the ciphertext must be written to.
If ``None``, the ciphertext is returned.
:Return: :Return:
a tuple with two byte strings: a tuple with two items:
- the encrypted data - the ciphertext, as ``bytes``
- the MAC - the MAC tag, as ``bytes``
The first item becomes ``None`` when the ``output`` parameter
specified a location for the result.
""" """
return self.encrypt(plaintext), self.digest() return self.encrypt(plaintext, output=output), self.digest()
def decrypt_and_verify(self, ciphertext, received_mac_tag): def decrypt_and_verify(self, ciphertext, received_mac_tag, output=None):
"""Perform decrypt() and verify() in one step. """Perform decrypt() and verify() in one step.
:Parameters: :Parameters:
@ -535,14 +551,18 @@ class GcmMode(object):
The piece of data to decrypt. The piece of data to decrypt.
received_mac_tag : byte string received_mac_tag : byte string
This is the *binary* MAC, as received from the sender. This is the *binary* MAC, as received from the sender.
:Keywords:
:Return: the decrypted data (byte string). output : bytearray/memoryview
The location where the plaintext must be written to.
If ``None``, the plaintext is returned.
:Return: the plaintext as ``bytes`` or ``None`` when the ``output``
parameter specified a location for the result.
:Raises ValueError: :Raises ValueError:
if the MAC does not match. The message has been tampered with if the MAC does not match. The message has been tampered with
or the key is incorrect. or the key is incorrect.
""" """
plaintext = self.decrypt(ciphertext) plaintext = self.decrypt(ciphertext, output=output)
self.verify(received_mac_tag) self.verify(received_mac_tag)
return plaintext return plaintext

View file

@ -269,17 +269,24 @@ class SivMode(object):
self.verify(unhexlify(hex_mac_tag)) self.verify(unhexlify(hex_mac_tag))
def encrypt_and_digest(self, plaintext): def encrypt_and_digest(self, plaintext, output=None):
"""Perform encrypt() and digest() in one step. """Perform encrypt() and digest() in one step.
:Parameters: :Parameters:
plaintext : bytes/bytearray/memoryview plaintext : bytes/bytearray/memoryview
The piece of data to encrypt. The piece of data to encrypt.
:Keywords:
output : bytearray/memoryview
The location where the ciphertext must be written to.
If ``None``, the ciphertext is returned.
:Return: :Return:
a tuple with two byte strings: a tuple with two items:
- the encrypted data - the ciphertext, as ``bytes``
- the MAC - the MAC tag, as ``bytes``
The first item becomes ``None`` when the ``output`` parameter
specified a location for the result.
""" """
if self.encrypt not in self._next: if self.encrypt not in self._next:
@ -296,9 +303,9 @@ class SivMode(object):
cipher = self._create_ctr_cipher(self._mac_tag) cipher = self._create_ctr_cipher(self._mac_tag)
return cipher.encrypt(plaintext), self._mac_tag return cipher.encrypt(plaintext, output=output), self._mac_tag
def decrypt_and_verify(self, ciphertext, mac_tag): def decrypt_and_verify(self, ciphertext, mac_tag, output=None):
"""Perform decryption and verification in one step. """Perform decryption and verification in one step.
A cipher object is stateful: once you have decrypted a message A cipher object is stateful: once you have decrypted a message
@ -316,8 +323,12 @@ class SivMode(object):
It can be of any length. It can be of any length.
mac_tag : bytes/bytearray/memoryview mac_tag : bytes/bytearray/memoryview
This is the *binary* MAC, as received from the sender. This is the *binary* MAC, as received from the sender.
:Keywords:
:Return: the decrypted data (byte string). output : bytearray/memoryview
The location where the plaintext must be written to.
If ``None``, the plaintext is returned.
:Return: the plaintext as ``bytes`` or ``None`` when the ``output``
parameter specified a location for the result.
:Raises ValueError: :Raises ValueError:
if the MAC does not match. The message has been tampered with if the MAC does not match. The message has been tampered with
or the key is incorrect. or the key is incorrect.
@ -331,11 +342,11 @@ class SivMode(object):
# Take the MAC and start the cipher for decryption # Take the MAC and start the cipher for decryption
self._cipher = self._create_ctr_cipher(mac_tag) self._cipher = self._create_ctr_cipher(mac_tag)
plaintext = self._cipher.decrypt(ciphertext) plaintext = self._cipher.decrypt(ciphertext, output=output)
if hasattr(self, 'nonce'): if hasattr(self, 'nonce'):
self._kdf.update(self.nonce) self._kdf.update(self.nonce)
self._kdf.update(plaintext) self._kdf.update(plaintext if output is None else output)
self.verify(mac_tag) self.verify(mac_tag)
return plaintext return plaintext

View file

@ -306,10 +306,74 @@ class GcmTests(unittest.TestCase):
pt_test = cipher4.decrypt_and_verify(memoryview(ct_test), memoryview(tag_test)) pt_test = cipher4.decrypt_and_verify(memoryview(ct_test), memoryview(tag_test))
self.assertEqual(self.data_128, pt_test) self.assertEqual(self.data_128, pt_test)
def test_output_param(self):
pt = b'5' * 16
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
ct = cipher.encrypt(pt)
tag = cipher.digest()
output = bytearray(16)
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
res = cipher.encrypt(pt, output=output)
self.assertEqual(ct, output)
self.assertEqual(res, None)
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
res = cipher.decrypt(ct, output=output)
self.assertEqual(pt, output)
self.assertEqual(res, None)
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
res, tag_out = cipher.encrypt_and_digest(pt, output=output)
self.assertEqual(ct, output)
self.assertEqual(res, None)
self.assertEqual(tag, tag_out)
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
res = cipher.decrypt_and_verify(ct, tag, output=output)
self.assertEqual(pt, output)
self.assertEqual(res, None)
def test_output_param_memoryview(self):
pt = b'5' * 16
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
ct = cipher.encrypt(pt)
output = memoryview(bytearray(16))
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
cipher.encrypt(pt, output=output)
self.assertEqual(ct, output)
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
cipher.decrypt(ct, output=output)
self.assertEqual(pt, output)
def test_output_param_neg(self):
pt = b'5' * 16
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
ct = cipher.encrypt(pt)
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0'*16)
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0'*16)
shorter_output = bytearray(15)
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
import sys import sys
if sys.version[:3] == "2.6": if sys.version[:3] == "2.6":
del test_memoryview del test_memoryview
del test_output_param_memoryview
class GcmFSMTests(unittest.TestCase): class GcmFSMTests(unittest.TestCase):

View file

@ -240,10 +240,63 @@ class SivTests(unittest.TestCase):
pt_test = cipher3.decrypt_and_verify(ct_ba, tag_ba) pt_test = cipher3.decrypt_and_verify(ct_ba, tag_ba)
self.assertEqual(self.data_128, pt_test) self.assertEqual(self.data_128, pt_test)
def test_output_param(self):
pt = b'5' * 16
cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
ct, tag = cipher.encrypt_and_digest(pt)
output = bytearray(16)
cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
res, tag_out = cipher.encrypt_and_digest(pt, output=output)
self.assertEqual(ct, output)
self.assertEqual(res, None)
self.assertEqual(tag, tag_out)
cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
res = cipher.decrypt_and_verify(ct, tag, output=output)
self.assertEqual(pt, output)
self.assertEqual(res, None)
def test_output_param_memoryview(self):
pt = b'5' * 16
cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
ct, tag = cipher.encrypt_and_digest(pt)
output = memoryview(bytearray(16))
cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
cipher.encrypt_and_digest(pt, output=output)
self.assertEqual(ct, output)
cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
cipher.decrypt_and_verify(ct, tag, output=output)
self.assertEqual(pt, output)
def test_output_param_neg(self):
pt = b'5' * 16
cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
ct, tag = cipher.encrypt_and_digest(pt)
cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
self.assertRaises(TypeError, cipher.encrypt_and_digest, pt, output=b'0'*16)
cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
self.assertRaises(TypeError, cipher.decrypt_and_verify, ct, tag, output=b'0'*16)
shorter_output = bytearray(15)
cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
self.assertRaises(ValueError, cipher.encrypt_and_digest, pt, output=shorter_output)
cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
self.assertRaises(ValueError, cipher.decrypt_and_verify, ct, tag, output=shorter_output)
import sys import sys
if sys.version[:3] == "2.6": if sys.version[:3] == "2.6":
del test_memoryview del test_memoryview
del test_output_param_memoryview
class SivFSMTests(unittest.TestCase): class SivFSMTests(unittest.TestCase):