pycryptodome/lib/Crypto/Protocol/HPKE.py
2025-03-09 08:15:38 +01:00

468 lines
16 KiB
Python

import struct
from enum import IntEnum
from types import ModuleType
from typing import Optional
from .KDF import _HKDF_extract, _HKDF_expand
from .DH import key_agreement, import_x25519_public_key, import_x448_public_key
from Crypto.Util.strxor import strxor
from Crypto.Util.number import long_to_bytes
from Crypto.PublicKey import ECC
from Crypto.PublicKey.ECC import EccKey
from Crypto.Hash import SHA256, SHA384, SHA512
from Crypto.Cipher import AES, ChaCha20_Poly1305
class MODE(IntEnum):
BASE = 0x00
PSK = 0x01
AUTH = 0x02
AUTH_PSK = 0x03
class AEAD(IntEnum):
AES128_GCM = 0x0001
AES256_GCM = 0x0002
CHACHA20_POLY1305 = 0x0003
class DeserializeError(ValueError):
pass
# CURVE to (KEM ID, KDF ID, HASH)
_Curve_Config = {
"NIST P-256": (0x0010, 0x0001, SHA256),
"NIST P-384": (0x0011, 0x0002, SHA384),
"NIST P-521": (0x0012, 0x0003, SHA512),
"Curve25519": (0x0020, 0x0001, SHA256),
"Curve448": (0x0021, 0x0003, SHA512),
}
def labeled_extract(salt: bytes,
label: bytes,
ikm: bytes,
suite_id: bytes,
hashmod: ModuleType):
labeled_ikm = b"HPKE-v1" + suite_id + label + ikm
return _HKDF_extract(salt, labeled_ikm, hashmod)
def labeled_expand(prk: bytes,
label: bytes,
info: bytes,
L: int,
suite_id: bytes,
hashmod: ModuleType):
labeled_info = struct.pack('>H', L) + b"HPKE-v1" + suite_id + \
label + info
return _HKDF_expand(prk, labeled_info, L, hashmod)
def extract_and_expand(dh: bytes,
kem_context: bytes,
suite_id: bytes,
hashmod: ModuleType):
Nsecret = hashmod.digest_size
eae_prk = labeled_extract(b"",
b"eae_prk",
dh,
suite_id,
hashmod)
shared_secret = labeled_expand(eae_prk,
b"shared_secret",
kem_context,
Nsecret,
suite_id,
hashmod)
return shared_secret
class HPKE_Cipher:
def __init__(self,
receiver_key: EccKey,
enc: Optional[bytes],
sender_key: Optional[EccKey],
psk_pair: tuple[bytes, bytes],
info: bytes,
aead_id: AEAD,
mode: MODE):
self.enc = enc
self._verify_psk_inputs(mode, psk_pair)
self._curve = receiver_key.curve
self._aead_id = aead_id
self._mode = mode
try:
self._kem_id, \
self._kdf_id, \
self._hashmod = _Curve_Config[self._curve]
except KeyError as ke:
raise ValueError("Curve {} is not supported by HPKE".format(self._curve)) from ke
self._Nk = 16 if self._aead_id == AEAD.AES128_GCM else 32
self._Nn = 12
self._Nt = 16
self._Nh = self._hashmod.digest_size
self._encrypt = not receiver_key.has_private()
if self._encrypt:
# SetupBaseS (encryption)
if enc is not None:
raise ValueError("Parameter 'enc' cannot be an input when sealing")
shared_secret, self.enc = self._encap(receiver_key,
self._kem_id,
self._hashmod,
sender_key)
else:
# SetupBaseR (decryption)
if enc is None:
raise ValueError("Parameter 'enc' required when unsealing")
shared_secret = self._decap(enc,
receiver_key,
self._kem_id,
self._hashmod,
sender_key)
self._sequence = 0
self._key, \
self._base_nonce, \
self._export_secret = self._key_schedule(shared_secret,
info,
*psk_pair)
@staticmethod
def _encap(receiver_key: EccKey,
kem_id: int,
hashmod: ModuleType,
sender_key: Optional[EccKey] = None,
eph_key: Optional[EccKey] = None):
assert (sender_key is None) or sender_key.has_private()
assert (eph_key is None) or eph_key.has_private()
if eph_key is None:
eph_key = ECC.generate(curve=receiver_key.curve)
enc = eph_key.public_key().export_key(format='raw')
pkRm = receiver_key.public_key().export_key(format='raw')
kem_context = enc + pkRm
extra_param = {}
if sender_key:
kem_context += sender_key.public_key().export_key(format='raw')
extra_param = {'static_priv': sender_key}
suite_id = b"KEM" + struct.pack('>H', kem_id)
def kdf(dh,
kem_context=kem_context,
suite_id=suite_id,
hashmod=hashmod):
return extract_and_expand(dh, kem_context, suite_id, hashmod)
shared_secret = key_agreement(eph_priv=eph_key,
static_pub=receiver_key,
kdf=kdf,
**extra_param)
return shared_secret, enc
@staticmethod
def _decap(enc: bytes,
receiver_key: EccKey,
kem_id: int,
hashmod: ModuleType,
sender_key: Optional[EccKey] = None):
assert receiver_key.has_private()
try:
if receiver_key.curve == 'Curve25519':
pkE = import_x25519_public_key(enc)
elif receiver_key.curve == 'Curve448':
pkE = import_x448_public_key(enc)
else:
pkE = ECC.import_key(enc, curve_name=receiver_key.curve)
except ValueError as ve:
raise DeserializeError("'enc' is not a valid encapsulated HPKE key") from ve
pkRm = receiver_key.public_key().export_key(format='raw')
kem_context = enc + pkRm
extra_param = {}
if sender_key:
kem_context += sender_key.public_key().export_key(format='raw')
extra_param = {'static_pub': sender_key}
suite_id = b"KEM" + struct.pack('>H', kem_id)
def kdf(dh,
kem_context=kem_context,
suite_id=suite_id,
hashmod=hashmod):
return extract_and_expand(dh, kem_context, suite_id, hashmod)
shared_secret = key_agreement(eph_pub=pkE,
static_priv=receiver_key,
kdf=kdf,
**extra_param)
return shared_secret
@staticmethod
def _verify_psk_inputs(mode: MODE, psk_pair: tuple[bytes, bytes]):
psk_id, psk = psk_pair
if (psk == b'') ^ (psk_id == b''):
raise ValueError("Inconsistent PSK inputs")
if (psk == b''):
if mode in (MODE.PSK, MODE.AUTH_PSK):
raise ValueError(f"PSK is required with mode {mode.name}")
else:
if len(psk) < 32:
raise ValueError("PSK must be at least 32 byte long")
if mode in (MODE.BASE, MODE.AUTH):
raise ValueError("PSK is not compatible with this mode")
def _key_schedule(self,
shared_secret: bytes,
info: bytes,
psk_id: bytes,
psk: bytes):
suite_id = b"HPKE" + struct.pack('>HHH',
self._kem_id,
self._kdf_id,
self._aead_id)
psk_id_hash = labeled_extract(b'',
b'psk_id_hash',
psk_id,
suite_id,
self._hashmod)
info_hash = labeled_extract(b'',
b'info_hash',
info,
suite_id,
self._hashmod)
key_schedule_context = self._mode.to_bytes() + psk_id_hash + info_hash
secret = labeled_extract(shared_secret,
b'secret',
psk,
suite_id,
self._hashmod)
key = labeled_expand(secret,
b'key',
key_schedule_context,
self._Nk,
suite_id,
self._hashmod)
base_nonce = labeled_expand(secret,
b'base_nonce',
key_schedule_context,
self._Nn,
suite_id,
self._hashmod)
exporter_secret = labeled_expand(secret,
b'exp',
key_schedule_context,
self._Nh,
suite_id,
self._hashmod)
return key, base_nonce, exporter_secret
def _new_cipher(self):
nonce = strxor(self._base_nonce, long_to_bytes(self._sequence, self._Nn))
if self._aead_id in (AEAD.AES128_GCM, AEAD.AES256_GCM):
cipher = AES.new(self._key, AES.MODE_GCM, nonce=nonce, mac_len=self._Nt)
elif self._aead_id == AEAD.CHACHA20_POLY1305:
cipher = ChaCha20_Poly1305.new(key=self._key, nonce=nonce)
else:
raise ValueError(f"Unknown AEAD cipher ID {self._aead_id:#x}")
self._sequence += 1
return cipher
def seal(self, plaintext: bytes, auth_data: Optional[bytes] = None):
"""Encrypt and authenticate a message.
This method can be invoked multiple times
to seal an ordered sequence of messages.
Arguments:
plaintext: bytes
The message to seal.
auth_data: bytes
Optional. Additional Authenticated data (AAD) that is not encrypted
but that will be also covered by the authentication tag.
Returns:
The ciphertext concatenated with the authentication tag.
"""
if not self._encrypt:
raise ValueError("This cipher can only be used to seal")
cipher = self._new_cipher()
if auth_data:
cipher.update(auth_data)
ct, tag = cipher.encrypt_and_digest(plaintext)
return ct + tag
def unseal(self, ciphertext: bytes, auth_data: Optional[bytes] = None):
"""Decrypt a message and validate its authenticity.
This method can be invoked multiple times
to unseal an ordered sequence of messages.
Arguments:
cipertext: bytes
The message to unseal.
auth_data: bytes
Optional. Additional Authenticated data (AAD) that
was also covered by the authentication tag.
Returns:
The original plaintext.
Raises: ValueError
If the ciphertext (in combination with the AAD) is not valid.
But if it is the first time you call ``unseal()`` this
exception may also mean that any of the parameters or keys
used to establish the session is wrong or that one is missing.
"""
if self._encrypt:
raise ValueError("This cipher can only be used to unseal")
if len(ciphertext) < self._Nt:
raise ValueError("Ciphertext is too small")
cipher = self._new_cipher()
if auth_data:
cipher.update(auth_data)
try:
pt = cipher.decrypt_and_verify(ciphertext[:-self._Nt],
ciphertext[-self._Nt:])
except ValueError:
if self._sequence == 1:
raise ValueError("Incorrect HPKE keys/parameters or invalid message (wrong MAC tag)")
raise ValueError("Invalid message (wrong MAC tag)")
return pt
def new(*, receiver_key: EccKey,
aead_id: AEAD,
enc: Optional[bytes] = None,
sender_key: Optional[EccKey] = None,
psk: Optional[tuple[bytes, bytes]] = None,
info: bytes = b''):
"""Create an HPKE context which can be used
by the sender to seal (encrypt) a message
or by the receiver to unseal (decrypt) it.
As a minimum, the two parties agree on the receiver's asymmetric key
(the sender will only know the public half).
Additionally, for authentication purposes, they may also agree on:
* the sender's asymmetric key (the receiver will only know the public half)
* a shared secret (e.g., a symmetric key derived from a password)
Arguments:
receiver_key: EccKey or EccXKey
The ECC key of the receiver.
It must be on one of the following curves: ``NIST P-256``,
``NIST P-384``, ``NIST P-521``, ``X25519`` or ``X448``.
If this is a **public** key, the HPKE context can only be used to
**encrypt**.
If this is a **private** key, the HPKE context can only be used to
**decrypt**.
aead_id: int
The HPKE identifier of the symmetric cipher.
The possible values are:
* ``HPKE.AEAD.AES128``
* ``HPKE.AEAD.AES256``
* ``HPKE.AEAD.CHACHA20_POLY1305``
Keywords:
enc: byte string
The encapsulated session key (i.e., the KEM shared secret).
The receiver must always specify this parameter.
The sender must always omit this parameter.
sender_key: EccKey
Optional. The ECC key of the sender.
It must be on the same curve as the ``receiver_key``.
If the ``receiver_key`` is a public key, ``sender_key`` must be a
private key, and vice versa.
psk: byte string tuple
Optional. A Pre-Shared Key (PSK) as a 2-tuple of non-empty
byte strings: the identifier and the actual secret value.
Sender and receiver must use the same PSK (or none).
info: byte string
Optional. A non-secret parameter that contributes
to the generation of all session keys.
Sender and receive must use the same **info** parameter (or none).
Returns (HPKE_Cipher):
An object that can be used for sealing (if ``receiver_key`` is a public
key) or unsealing (if ``receiver_key`` is a private key).
When unsealing, correctness of all the keys and parameters will only
be assessed with the first call to ``unseal()``.
.. _HPKE: https://datatracker.ietf.org/doc/rfc9180/
"""
if aead_id not in AEAD:
raise ValueError(f"Unknown AEAD cipher ID {aead_id:#x}")
curve = receiver_key.curve
if curve not in ('NIST P-256', 'NIST P-384', 'NIST P-521',
'Curve25519', 'Curve448'):
raise ValueError(f"Unsupported curve {curve}")
if sender_key:
count_private_keys = int(receiver_key.has_private()) + \
int(sender_key.has_private())
if count_private_keys != 1:
raise ValueError("Exactly 1 private key required")
if sender_key.curve != curve:
raise ValueError("Sender key uses {} but recipient key {}".
format(sender_key.curve, curve))
mode = MODE.AUTH if psk is None else MODE.AUTH_PSK
else:
mode = MODE.BASE if psk is None else MODE.PSK
if psk is None:
psk = b'', b''
return HPKE_Cipher(receiver_key,
enc,
sender_key,
psk,
info,
aead_id,
mode)