pycryptodome/lib/Crypto/Protocol/HPKE.py
2025-04-27 19:36:13 +02:00

483 lines
17 KiB
Python

import struct
from enum import IntEnum
from types import ModuleType
from typing import Optional
from .KDF import _HKDF_extract, _HKDF_expand
from .DH import key_agreement, import_x25519_public_key, import_x448_public_key
from Crypto.Util.strxor import strxor
from Crypto.PublicKey import ECC
from Crypto.PublicKey.ECC import EccKey
from Crypto.Hash import SHA256, SHA384, SHA512
from Crypto.Cipher import AES, ChaCha20_Poly1305
class MODE(IntEnum):
    """Operating modes of HPKE (RFC 9180, section 5).

    The numeric value of each member is the one-byte mode identifier that
    is prepended to the key schedule context.
    """
    BASE = 0        # neither a sender key nor a pre-shared key
    PSK = 1         # pre-shared key only
    AUTH = 2        # sender's asymmetric key only
    AUTH_PSK = 3    # sender's asymmetric key and a pre-shared key
class AEAD(IntEnum):
    """Identifiers of the AEAD algorithms supported for HPKE.

    The numeric value is the two-byte ``aead_id`` that is part of the HPKE
    suite identifier.
    """
    AES128_GCM = 1           # 16-byte key
    AES256_GCM = 2           # 32-byte key
    CHACHA20_POLY1305 = 3    # 32-byte key
class DeserializeError(ValueError):
    """Raised when the encapsulated key ``enc`` cannot be decoded as a public key."""
class MessageLimitReachedError(ValueError):
    """Raised when an HPKE context has exhausted its per-message sequence numbers."""
# Supported curves, keyed by the name reported by ``EccKey.curve``.
# Each value is the triple (KEM ID, KDF ID, hash module); the same hash
# module drives both the DHKEM and the HPKE key schedule.
_Curve_Config = {
    # DHKEM(P-256, HKDF-SHA256) with HKDF-SHA256
    "NIST P-256": (0x10, 0x01, SHA256),
    # DHKEM(P-384, HKDF-SHA384) with HKDF-SHA384
    "NIST P-384": (0x11, 0x02, SHA384),
    # DHKEM(P-521, HKDF-SHA512) with HKDF-SHA512
    "NIST P-521": (0x12, 0x03, SHA512),
    # DHKEM(X25519, HKDF-SHA256) with HKDF-SHA256
    "Curve25519": (0x20, 0x01, SHA256),
    # DHKEM(X448, HKDF-SHA512) with HKDF-SHA512
    "Curve448": (0x21, 0x03, SHA512),
}
def _labeled_extract(salt: bytes,
                     label: bytes,
                     ikm: bytes,
                     suite_id: bytes,
                     hashmod: ModuleType):
    """HPKE ``LabeledExtract`` (RFC 9180, section 4).

    Prefixes the input keying material with the ``"HPKE-v1"`` version tag,
    the suite identifier and the label, then runs HKDF-Extract with *salt*
    using *hashmod*.
    """
    prefixed_ikm = b"".join((b"HPKE-v1", suite_id, label, ikm))
    return _HKDF_extract(salt, prefixed_ikm, hashmod)
def _labeled_expand(prk: bytes,
                    label: bytes,
                    info: bytes,
                    L: int,
                    suite_id: bytes,
                    hashmod: ModuleType):
    """HPKE ``LabeledExpand`` (RFC 9180, section 4).

    Builds ``I2OSP(L, 2) || "HPKE-v1" || suite_id || label || info`` and
    uses it as the HKDF-Expand info string to derive *L* bytes from *prk*.
    """
    # The 2-byte big-endian length prefix binds the output size into the KDF.
    pieces = (struct.pack('>H', L), b"HPKE-v1", suite_id, label, info)
    return _HKDF_expand(prk, b"".join(pieces), L, hashmod)
def _extract_and_expand(dh: bytes,
                        kem_context: bytes,
                        suite_id: bytes,
                        hashmod: ModuleType):
    """DHKEM ``ExtractAndExpand`` (RFC 9180, section 4.1).

    Converts the raw Diffie-Hellman output *dh* into the KEM shared secret,
    bound to *kem_context*. The secret is as long as a *hashmod* digest.
    """
    secret_len = hashmod.digest_size
    prk = _labeled_extract(b"", b"eae_prk", dh, suite_id, hashmod)
    return _labeled_expand(prk, b"shared_secret", kem_context,
                           secret_len, suite_id, hashmod)
class HPKE_Cipher:
    """An HPKE context bound to one receiver key and one direction.

    A context whose receiver key is a **public** key can only :meth:`seal`.
    A context whose receiver key is a **private** key can only
    :meth:`unseal`. Messages are processed in order. Each one uses a nonce
    derived from the base nonce and an internal sequence number.
    """

    def __init__(self,
                 receiver_key: EccKey,
                 enc: Optional[bytes],
                 sender_key: Optional[EccKey],
                 psk_pair: tuple[bytes, bytes],
                 info: bytes,
                 aead_id: AEAD,
                 mode: MODE):
        """Run the KEM and the key schedule for the given parameters.

        Args:
            receiver_key: receiver's key; public for sealing, private for unsealing.
            enc: encapsulated key; must be ``None`` when sealing, required when unsealing.
            sender_key: optional sender key for the AUTH modes.
            psk_pair: ``(psk_id, psk)``, both empty when no PSK is used.
            info: application-supplied context string.
            aead_id: AEAD algorithm identifier.
            mode: HPKE mode matching the supplied keys/PSK.

        Raises:
            ValueError: on inconsistent inputs or an unsupported curve.
            DeserializeError: if ``enc`` is not a valid public key (unsealing).
        """

        self.enc: bytes = b'' if enc is None else enc
        """The encapsulated session key."""

        self._verify_psk_inputs(mode, psk_pair)

        self._curve = receiver_key.curve
        self._aead_id = aead_id
        self._mode = mode

        try:
            self._kem_id, self._kdf_id, self._hashmod = _Curve_Config[self._curve]
        except KeyError as ke:
            raise ValueError("Curve {} is not supported by HPKE".format(self._curve)) from ke

        # Nk: AEAD key length, Nn: nonce length, Nt: tag length,
        # Nh: length of the exporter secret (the hash output size).
        self._Nk = 16 if self._aead_id == AEAD.AES128_GCM else 32
        self._Nn = 12
        self._Nt = 16
        self._Nh = self._hashmod.digest_size

        # Only the sender lacks the receiver's private key.
        self._encrypt = not receiver_key.has_private()

        if self._encrypt:
            # Sender side (SetupBaseS / SetupPSKS / SetupAuthS / SetupAuthPSKS).
            if enc is not None:
                raise ValueError("Parameter 'enc' cannot be an input when sealing")
            shared_secret, self.enc = self._encap(receiver_key,
                                                  self._kem_id,
                                                  self._hashmod,
                                                  sender_key)
        else:
            # Receiver side (SetupBaseR / SetupPSKR / SetupAuthR / SetupAuthPSKR).
            if enc is None:
                raise ValueError("Parameter 'enc' required when unsealing")
            shared_secret = self._decap(enc,
                                        receiver_key,
                                        self._kem_id,
                                        self._hashmod,
                                        sender_key)

        # Sequence number of the next message. It must stay below
        # 2^(8*Nn) - 1 so that no nonce is ever reused.
        self._sequence = 0
        self._max_sequence = (1 << (8 * self._Nn)) - 1

        self._key, self._base_nonce, self._export_secret = \
            self._key_schedule(shared_secret, info, *psk_pair)

    @staticmethod
    def _encap(receiver_key: EccKey,
               kem_id: int,
               hashmod: ModuleType,
               sender_key: Optional[EccKey] = None,
               eph_key: Optional[EccKey] = None):
        """KEM encapsulation (``Encap`` / ``AuthEncap``).

        Uses *eph_key* as the ephemeral key if given (presumably only for
        deterministic test vectors), otherwise generates a fresh one on the
        receiver's curve. When *sender_key* (a private key) is given, its
        public half is appended to the KEM context and it takes part in the
        key agreement.

        Returns:
            ``(shared_secret, enc)`` where ``enc`` is the raw ephemeral public key.
        """
        assert (sender_key is None) or sender_key.has_private()
        assert (eph_key is None) or eph_key.has_private()

        if eph_key is None:
            eph_key = ECC.generate(curve=receiver_key.curve)
        enc = eph_key.public_key().export_key(format='raw')

        pkRm = receiver_key.public_key().export_key(format='raw')
        kem_context = enc + pkRm
        extra_param = {}
        if sender_key is not None:
            kem_context += sender_key.public_key().export_key(format='raw')
            extra_param = {'static_priv': sender_key}

        suite_id = b"KEM" + struct.pack('>H', kem_id)

        def kdf(dh,
                kem_context=kem_context,
                suite_id=suite_id,
                hashmod=hashmod):
            return _extract_and_expand(dh, kem_context, suite_id, hashmod)

        shared_secret = key_agreement(eph_priv=eph_key,
                                      static_pub=receiver_key,
                                      kdf=kdf,
                                      **extra_param)
        return shared_secret, enc

    @staticmethod
    def _decap(enc: bytes,
               receiver_key: EccKey,
               kem_id: int,
               hashmod: ModuleType,
               sender_key: Optional[EccKey] = None):
        """KEM decapsulation (``Decap`` / ``AuthDecap``).

        Parses *enc* as a public key on the receiver's curve and runs the key
        agreement with the receiver's private key (and, if given, the
        sender's public key).

        Returns:
            The KEM shared secret.

        Raises:
            DeserializeError: if *enc* is not a valid public key for the curve.
        """
        assert receiver_key.has_private()

        try:
            if receiver_key.curve == 'Curve25519':
                pkE = import_x25519_public_key(enc)
            elif receiver_key.curve == 'Curve448':
                pkE = import_x448_public_key(enc)
            else:
                pkE = ECC.import_key(enc, curve_name=receiver_key.curve)
        except ValueError as ve:
            raise DeserializeError("'enc' is not a valid encapsulated HPKE key") from ve

        pkRm = receiver_key.public_key().export_key(format='raw')
        kem_context = enc + pkRm
        extra_param = {}
        if sender_key is not None:
            kem_context += sender_key.public_key().export_key(format='raw')
            extra_param = {'static_pub': sender_key}

        suite_id = b"KEM" + struct.pack('>H', kem_id)

        def kdf(dh,
                kem_context=kem_context,
                suite_id=suite_id,
                hashmod=hashmod):
            return _extract_and_expand(dh, kem_context, suite_id, hashmod)

        shared_secret = key_agreement(eph_pub=pkE,
                                      static_priv=receiver_key,
                                      kdf=kdf,
                                      **extra_param)
        return shared_secret

    @staticmethod
    def _verify_psk_inputs(mode: MODE, psk_pair: tuple[bytes, bytes]):
        """Check that the PSK pair is consistent with itself and with *mode*.

        Raises:
            ValueError: if only one of id/secret is empty, if the secret is
                shorter than 32 bytes, or if PSK presence does not match *mode*.
        """
        psk_id, psk = psk_pair
        if (psk == b'') ^ (psk_id == b''):
            raise ValueError("Inconsistent PSK inputs")

        if (psk == b''):
            if mode in (MODE.PSK, MODE.AUTH_PSK):
                raise ValueError(f"PSK is required with mode {mode.name}")
        else:
            if len(psk) < 32:
                raise ValueError("PSK must be at least 32 byte long")
            if mode in (MODE.BASE, MODE.AUTH):
                raise ValueError("PSK is not compatible with this mode")

    def _key_schedule(self,
                      shared_secret: bytes,
                      info: bytes,
                      psk_id: bytes,
                      psk: bytes):
        """HPKE key schedule.

        Returns:
            ``(key, base_nonce, exporter_secret)`` of lengths Nk, Nn and Nh.
        """
        suite_id = b"HPKE" + struct.pack('>HHH',
                                         self._kem_id,
                                         self._kdf_id,
                                         self._aead_id)

        psk_id_hash = _labeled_extract(b'',
                                       b'psk_id_hash',
                                       psk_id,
                                       suite_id,
                                       self._hashmod)
        info_hash = _labeled_extract(b'',
                                     b'info_hash',
                                     info,
                                     suite_id,
                                     self._hashmod)
        key_schedule_context = self._mode.to_bytes(1, 'big') + psk_id_hash + info_hash

        secret = _labeled_extract(shared_secret,
                                  b'secret',
                                  psk,
                                  suite_id,
                                  self._hashmod)

        key = _labeled_expand(secret,
                              b'key',
                              key_schedule_context,
                              self._Nk,
                              suite_id,
                              self._hashmod)
        base_nonce = _labeled_expand(secret,
                                     b'base_nonce',
                                     key_schedule_context,
                                     self._Nn,
                                     suite_id,
                                     self._hashmod)
        exporter_secret = _labeled_expand(secret,
                                          b'exp',
                                          key_schedule_context,
                                          self._Nh,
                                          suite_id,
                                          self._hashmod)
        return key, base_nonce, exporter_secret

    def _new_cipher(self):
        """Return an AEAD cipher keyed for the *current* sequence number.

        The sequence number is not advanced here. Callers advance it
        themselves: ``seal`` always does, and ``unseal`` does so only after a
        successful open (RFC 9180, section 5.2).

        Raises:
            MessageLimitReachedError: if the sequence number space is exhausted.
            ValueError: if the AEAD identifier is unknown.
        """
        if self._sequence >= self._max_sequence:
            raise MessageLimitReachedError()

        nonce = strxor(self._base_nonce, self._sequence.to_bytes(self._Nn, 'big'))
        if self._aead_id in (AEAD.AES128_GCM, AEAD.AES256_GCM):
            return AES.new(self._key, AES.MODE_GCM, nonce=nonce, mac_len=self._Nt)
        if self._aead_id == AEAD.CHACHA20_POLY1305:
            return ChaCha20_Poly1305.new(key=self._key, nonce=nonce)
        raise ValueError(f"Unknown AEAD cipher ID {self._aead_id:#x}")

    def seal(self, plaintext: bytes, auth_data: Optional[bytes] = None):
        """Encrypt and authenticate a message.

        This method can be invoked multiple times
        to seal an ordered sequence of messages.

        Arguments:
          plaintext: bytes
            The message to seal.
          auth_data: bytes
            Optional. Additional Authenticated data (AAD) that is not encrypted
            but that will be also covered by the authentication tag.

        Returns:
           The ciphertext concatenated with the authentication tag.

        Raises:
           ValueError: if this context was created for unsealing.
           MessageLimitReachedError: if no more messages can be sealed.
        """

        if not self._encrypt:
            raise ValueError("This cipher can only be used to unseal")

        cipher = self._new_cipher()
        # Advance before encrypting so that a nonce is never reused for
        # sealing, even if encryption fails part-way.
        self._sequence += 1
        if auth_data:
            cipher.update(auth_data)
        ct, tag = cipher.encrypt_and_digest(plaintext)
        return ct + tag

    def unseal(self, ciphertext: bytes, auth_data: Optional[bytes] = None):
        """Decrypt a message and validate its authenticity.

        This method can be invoked multiple times
        to unseal an ordered sequence of messages.

        Arguments:
          ciphertext: bytes
            The message to unseal.
          auth_data: bytes
            Optional. Additional Authenticated data (AAD) that
            was also covered by the authentication tag.

        Returns:
           The original plaintext.

        Raises: ValueError
           If the ciphertext (in combination with the AAD) is not valid,
           or if this context was created for sealing.
           A rejected message does not consume a sequence number, so the
           next valid message can still be unsealed.
           If no message has been successfully unsealed yet, this
           exception may also mean that any of the parameters or keys
           used to establish the session is wrong or that one is missing.
        """

        if self._encrypt:
            raise ValueError("This cipher can only be used to seal")

        if len(ciphertext) < self._Nt:
            raise ValueError("Ciphertext is too small")

        cipher = self._new_cipher()
        if auth_data:
            cipher.update(auth_data)

        try:
            pt = cipher.decrypt_and_verify(ciphertext[:-self._Nt],
                                           ciphertext[-self._Nt:])
        except ValueError:
            # Sequence 0 means nothing has been opened successfully yet, so
            # a bad session setup is as likely as a corrupted message.
            if self._sequence == 0:
                raise ValueError("Incorrect HPKE keys/parameters or invalid message (wrong MAC tag)")
            raise ValueError("Invalid message (wrong MAC tag)")

        # RFC 9180: the sequence number advances only after a successful open.
        self._sequence += 1
        return pt
def new(*, receiver_key: EccKey,
        aead_id: AEAD,
        enc: Optional[bytes] = None,
        sender_key: Optional[EccKey] = None,
        psk: Optional[tuple[bytes, bytes]] = None,
        info: Optional[bytes] = None) -> HPKE_Cipher:
    """Create an HPKE context which can be used:

    - by the sender to seal (encrypt) a message or

    - by the receiver to unseal (decrypt) it.

    As a minimum, the two parties agree on the receiver's asymmetric key
    (of which the sender will only know the public half).

    Additionally, for authentication purposes, they may also agree on:

    * the sender's asymmetric key (of which the receiver will only know the public half)

    * a shared secret (e.g., a symmetric key derived from a password)

    Args:
      receiver_key:
        The ECC key of the receiver.
        It must be on one of the following curves: ``NIST P-256``,
        ``NIST P-384``, ``NIST P-521``, ``X25519`` or ``X448``.

        If this is a **public** key, the HPKE context can only be used to
        **seal** (**encrypt**).

        If this is a **private** key, the HPKE context can only be used to
        **unseal** (**decrypt**).

      aead_id:
        The HPKE identifier of the symmetric cipher.
        The possible values are:

        * ``HPKE.AEAD.AES128_GCM``
        * ``HPKE.AEAD.AES256_GCM``
        * ``HPKE.AEAD.CHACHA20_POLY1305``

        The equivalent plain integer identifiers are also accepted.

      enc:
        The encapsulated session key (i.e., the KEM shared secret).

        The receiver must always specify this parameter.

        The sender must always omit this parameter.

      sender_key:
        The ECC key of the sender.
        It must be on the same curve as the ``receiver_key``.
        If the ``receiver_key`` is a public key, ``sender_key`` must be a
        private key, and vice versa.

      psk:
        A Pre-Shared Key (PSK) as a 2-tuple of non-empty
        byte strings: the identifier and the actual secret value.
        Sender and receiver must use the same PSK (or none).

        The secret value must be at least 32 bytes long,
        but it must not be a low-entropy password
        (use a KDF like PBKDF2 or scrypt to derive a secret
        from a password).

      info:
        A non-secret parameter that contributes
        to the generation of all session keys.
        Sender and receiver must use the same **info** parameter (or none).

    Returns:
        An object that can be used for
        sealing (if ``receiver_key`` is a public key) or
        unsealing (if ``receiver_key`` is a private key).
        In the latter case,
        correctness of all the keys and parameters will only
        be assessed with the first call to ``unseal()``.

    Raises:
        ValueError: if the AEAD identifier or curve is unsupported, or the
            combination of keys/PSK is inconsistent.
    """

    # Look the identifier up by value instead of using ``aead_id in AEAD``:
    # on Python < 3.12 that containment test raises TypeError for anything
    # that is not already an AEAD member (including valid plain integers).
    try:
        aead_id = AEAD(aead_id)
    except ValueError:
        shown = f"{aead_id:#x}" if isinstance(aead_id, int) else repr(aead_id)
        raise ValueError(f"Unknown AEAD cipher ID {shown}") from None

    curve = receiver_key.curve
    # The supported curves are exactly the keys of the suite table.
    if curve not in _Curve_Config:
        raise ValueError(f"Unsupported curve {curve}")

    if sender_key is not None:
        # Exactly one side of the pair may hold a private key: the sender's
        # when sealing, the receiver's when unsealing.
        count_private_keys = int(receiver_key.has_private()) + \
            int(sender_key.has_private())
        if count_private_keys != 1:
            raise ValueError("Exactly 1 private key required")
        if sender_key.curve != curve:
            raise ValueError("Sender key uses {} but recipient key {}".
                             format(sender_key.curve, curve))
        mode = MODE.AUTH if psk is None else MODE.AUTH_PSK
    else:
        mode = MODE.BASE if psk is None else MODE.PSK

    if psk is None:
        psk = b'', b''

    if info is None:
        info = b''

    return HPKE_Cipher(receiver_key,
                       enc,
                       sender_key,
                       psk,
                       info,
                       aead_id,
                       mode)