crypto/internal/hpke: separate KEM and PublicKey/PrivateKey interfaces

Updates #75300

Change-Id: I87ed26e8f57180d741408bdbda1696d46a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/719560
Reviewed-by: Mark Freeman <markfreeman@google.com>
Reviewed-by: Junyang Shao <shaojunyang@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Filippo Valsorda 2025-11-11 13:10:17 +01:00 committed by Gopher Robot
parent e15800c0ec
commit 7db2f0bb9a
6 changed files with 1743 additions and 792 deletions

View file

@ -132,12 +132,12 @@ func newContext(sharedSecret []byte, kemID uint16, kdf KDF, aead AEAD, info []by
//
// The returned enc ciphertext can be used to instantiate a matching receiving
// HPKE context with the corresponding KEM decapsulation key.
func NewSender(kem KEMSender, kdf KDF, aead AEAD, info []byte) (enc []byte, s *Sender, err error) {
sharedSecret, encapsulatedKey, err := kem.encap()
func NewSender(pk PublicKey, kdf KDF, aead AEAD, info []byte) (enc []byte, s *Sender, err error) {
sharedSecret, encapsulatedKey, err := pk.encap()
if err != nil {
return nil, nil, err
}
context, err := newContext(sharedSecret, kem.ID(), kdf, aead, info)
context, err := newContext(sharedSecret, pk.KEM().ID(), kdf, aead, info)
if err != nil {
return nil, nil, err
}
@ -151,12 +151,12 @@ func NewSender(kem KEMSender, kdf KDF, aead AEAD, info []byte) (enc []byte, s *S
// The enc parameter must have been produced by a matching sending HPKE context
// with the corresponding KEM encapsulation key. The info parameter is
// additional public information that must match between sender and recipient.
func NewRecipient(enc []byte, kem KEMRecipient, kdf KDF, aead AEAD, info []byte) (*Recipient, error) {
sharedSecret, err := kem.decap(enc)
func NewRecipient(enc []byte, k PrivateKey, kdf KDF, aead AEAD, info []byte) (*Recipient, error) {
sharedSecret, err := k.decap(enc)
if err != nil {
return nil, err
}
context, err := newContext(sharedSecret, kem.ID(), kdf, aead, info)
context, err := newContext(sharedSecret, k.KEM().ID(), kdf, aead, info)
if err != nil {
return nil, err
}
@ -179,16 +179,17 @@ func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
// Seal instantiates a single-use HPKE sending HPKE context like [NewSender],
// and then encrypts the provided plaintext like [Sender.Seal] (with no aad).
func Seal(kem KEMSender, kdf KDF, aead AEAD, info, plaintext []byte) (enc, ct []byte, err error) {
enc, s, err := NewSender(kem, kdf, aead, info)
// Seal returns the concatenation of the encapsulated key and the ciphertext.
func Seal(pk PublicKey, kdf KDF, aead AEAD, info, plaintext []byte) ([]byte, error) {
enc, s, err := NewSender(pk, kdf, aead, info)
if err != nil {
return nil, nil, err
return nil, err
}
ct, err = s.Seal(nil, plaintext)
ct, err := s.Seal(nil, plaintext)
if err != nil {
return nil, nil, err
return nil, err
}
return enc, ct, err
return append(enc, ct...), nil
}
// Export produces a secret value derived from the shared key between sender and
@ -219,8 +220,14 @@ func (r *Recipient) Open(aad, ciphertext []byte) ([]byte, error) {
// Open instantiates a single-use HPKE receiving HPKE context like [NewRecipient],
// and then decrypts the provided ciphertext like [Recipient.Open] (with no aad).
func Open(enc []byte, kem KEMRecipient, kdf KDF, aead AEAD, info, ciphertext []byte) ([]byte, error) {
r, err := NewRecipient(enc, kem, kdf, aead, info)
// ciphertext must be the concatenation of the encapsulated key and the actual ciphertext.
func Open(k PrivateKey, kdf KDF, aead AEAD, info, ciphertext []byte) ([]byte, error) {
encSize := k.KEM().encSize()
if len(ciphertext) < encSize {
return nil, errors.New("ciphertext too short")
}
enc, ciphertext := ciphertext[:encSize], ciphertext[encSize:]
r, err := NewRecipient(enc, k, kdf, aead, info)
if err != nil {
return nil, err
}

View file

@ -18,6 +18,57 @@ import (
"testing"
)
func Example() {
// In this example, we use MLKEM768-X25519 as the KEM, HKDF-SHA256 as the
// KDF, and AES-256-GCM as the AEAD to encrypt a single message from a
// sender to a recipient using the one-shot API.
kem, kdf, aead := MLKEM768X25519(), HKDFSHA256(), AES256GCM()
// Recipient side
var (
recipientPrivateKey PrivateKey
publicKeyBytes []byte
)
{
k, err := kem.GenerateKey()
if err != nil {
panic(err)
}
recipientPrivateKey = k
publicKeyBytes = k.PublicKey().Bytes()
}
// Sender side
var ciphertext []byte
{
publicKey, err := kem.NewPublicKey(publicKeyBytes)
if err != nil {
panic(err)
}
message := []byte("|-()-|")
ct, err := Seal(publicKey, kdf, aead, []byte("example"), message)
if err != nil {
panic(err)
}
ciphertext = ct
}
// Recipient side
{
plaintext, err := Open(recipientPrivateKey, kdf, aead, []byte("example"), ciphertext)
if err != nil {
panic(err)
}
fmt.Printf("Decrypted message: %s\n", plaintext)
}
// Output:
// Decrypted message: |-()-|
}
func mustDecodeHex(t *testing.T, in string) []byte {
t.Helper()
b, err := hex.DecodeString(in)
@ -83,6 +134,12 @@ func testVectors(t *testing.T, name string) {
if vector.KEM == 0x0021 {
t.Skip("KEM 0x0021 (DHKEM(X448)) not supported")
}
if vector.KEM == 0x0040 {
t.Skip("KEM 0x0040 (ML-KEM-512) not supported")
}
if vector.KDF == 0x0012 || vector.KDF == 0x0013 {
t.Skipf("TurboSHAKE KDF not supported")
}
kdf, err := NewKDF(vector.KDF)
if err != nil {
@ -100,26 +157,37 @@ func testVectors(t *testing.T, name string) {
t.Errorf("unexpected AEAD ID: got %04x, want %04x", aead.ID(), vector.AEAD)
}
pubKeyBytes := mustDecodeHex(t, vector.PkRm)
kemSender, err := NewKEMSender(vector.KEM, pubKeyBytes)
kem, err := NewKEM(vector.KEM)
if err != nil {
t.Fatal(err)
}
if kemSender.ID() != vector.KEM {
t.Errorf("unexpected KEM ID: got %04x, want %04x", kemSender.ID(), vector.KEM)
if kem.ID() != vector.KEM {
t.Errorf("unexpected KEM ID: got %04x, want %04x", kem.ID(), vector.KEM)
}
pubKeyBytes := mustDecodeHex(t, vector.PkRm)
kemSender, err := kem.NewPublicKey(pubKeyBytes)
if err != nil {
t.Fatal(err)
}
if kemSender.KEM() != kem {
t.Errorf("unexpected KEM from sender: got %04x, want %04x", kemSender.KEM().ID(), kem.ID())
}
if !bytes.Equal(kemSender.Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM bytes: got %x, want %x", kemSender.Bytes(), pubKeyBytes)
}
ikmE := mustDecodeHex(t, vector.IkmE)
setupDerandomizedEncap(t, vector.KEM, ikmE, kemSender)
setupDerandomizedEncap(t, ikmE, kemSender)
info := mustDecodeHex(t, vector.Info)
encap, sender, err := NewSender(kemSender, kdf, aead, info)
if err != nil {
t.Fatal(err)
}
if len(encap) != kem.encSize() {
t.Errorf("unexpected encapsulated key size: got %d, want %d", len(encap), kem.encSize())
}
expectedEncap := mustDecodeHex(t, vector.Enc)
if !bytes.Equal(encap, expectedEncap) {
@ -127,23 +195,23 @@ func testVectors(t *testing.T, name string) {
}
privKeyBytes := mustDecodeHex(t, vector.SkRm)
kemRecipient, err := NewKEMRecipient(vector.KEM, privKeyBytes)
kemRecipient, err := kem.NewPrivateKey(privKeyBytes)
if err != nil {
t.Fatal(err)
}
if kemRecipient.ID() != vector.KEM {
t.Errorf("unexpected KEM ID: got %04x, want %04x", kemRecipient.ID(), vector.KEM)
if kemRecipient.KEM() != kem {
t.Errorf("unexpected KEM from recipient: got %04x, want %04x", kemRecipient.KEM().ID(), kem.ID())
}
kemRecipientBytes, err := kemRecipient.Bytes()
if err != nil {
t.Fatal(err)
}
// X25519 serialized keys must be clamped, so the bytes might not match.
if !bytes.Equal(kemRecipientBytes, privKeyBytes) && vector.KEM != dhkemX25519 {
if !bytes.Equal(kemRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
t.Errorf("unexpected KEM bytes: got %x, want %x", kemRecipientBytes, privKeyBytes)
}
if vector.KEM == dhkemX25519 {
kem2, err := NewKEMRecipient(vector.KEM, kemRecipientBytes)
if vector.KEM == DHKEM(ecdh.X25519()).ID() {
kem2, err := kem.NewPrivateKey(kemRecipientBytes)
if err != nil {
t.Fatal(err)
}
@ -154,32 +222,28 @@ func testVectors(t *testing.T, name string) {
if !bytes.Equal(kemRecipientBytes2, kemRecipientBytes) {
t.Errorf("X25519 re-serialized key differs: got %x, want %x", kemRecipientBytes2, kemRecipientBytes)
}
if !bytes.Equal(kem2.KEMSender().Bytes(), pubKeyBytes) {
t.Errorf("X25519 re-derived public key differs: got %x, want %x", kem2.KEMSender().Bytes(), pubKeyBytes)
if !bytes.Equal(kem2.PublicKey().Bytes(), pubKeyBytes) {
t.Errorf("X25519 re-derived public key differs: got %x, want %x", kem2.PublicKey().Bytes(), pubKeyBytes)
}
}
if !bytes.Equal(kemRecipient.KEMSender().Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM sender bytes: got %x, want %x", kemRecipient.KEMSender().Bytes(), pubKeyBytes)
if !bytes.Equal(kemRecipient.PublicKey().Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM sender bytes: got %x, want %x", kemRecipient.PublicKey().Bytes(), pubKeyBytes)
}
// NewKEMRecipientFromSeed is not implemented for the PQ KEMs yet.
if vector.KEM != mlkem768 && vector.KEM != mlkem1024 && vector.KEM != mlkem768X25519 &&
vector.KEM != mlkem768P256 && vector.KEM != mlkem1024P384 {
seed := mustDecodeHex(t, vector.IkmR)
seedRecipient, err := NewKEMRecipientFromSeed(vector.KEM, seed)
if err != nil {
t.Fatal(err)
}
seedRecipientBytes, err := seedRecipient.Bytes()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(seedRecipientBytes, privKeyBytes) && vector.KEM != dhkemX25519 {
t.Errorf("unexpected KEM bytes from seed: got %x, want %x", seedRecipientBytes, privKeyBytes)
}
if !bytes.Equal(seedRecipient.KEMSender().Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM sender bytes from seed: got %x, want %x", seedRecipient.KEMSender().Bytes(), pubKeyBytes)
}
ikm := mustDecodeHex(t, vector.IkmR)
derivRecipient, err := kem.DeriveKeyPair(ikm)
if err != nil {
t.Fatal(err)
}
derivRecipientBytes, err := derivRecipient.Bytes()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(derivRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
t.Errorf("unexpected KEM bytes from seed: got %x, want %x", derivRecipientBytes, privKeyBytes)
}
if !bytes.Equal(derivRecipient.PublicKey().Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM sender bytes from seed: got %x, want %x", derivRecipient.PublicKey().Bytes(), pubKeyBytes)
}
recipient, err := NewRecipient(encap, kemRecipient, kdf, aead, info)
@ -304,22 +368,22 @@ func drawRandomInput(t *testing.T, r io.Reader) []byte {
return b
}
func setupDerandomizedEncap(t *testing.T, kemID uint16, randBytes []byte, kem KEMSender) {
func setupDerandomizedEncap(t *testing.T, randBytes []byte, pk PublicKey) {
t.Cleanup(func() {
testingOnlyGenerateKey = nil
testingOnlyEncapsulate = nil
})
switch kemID {
case dhkemP256, dhkemP384, dhkemP521, dhkemX25519:
r, err := NewKEMRecipientFromSeed(kemID, randBytes)
switch pk.KEM() {
case DHKEM(ecdh.P256()), DHKEM(ecdh.P384()), DHKEM(ecdh.P521()), DHKEM(ecdh.X25519()):
r, err := pk.KEM().DeriveKeyPair(randBytes)
if err != nil {
t.Fatal(err)
}
testingOnlyGenerateKey = func() *ecdh.PrivateKey {
return r.(*dhKEMRecipient).priv.(*ecdh.PrivateKey)
return r.(*dhKEMPrivateKey).priv.(*ecdh.PrivateKey)
}
case mlkem768:
pq := kem.(*mlkemSender).pq.(*mlkem.EncapsulationKey768)
pq := pk.(*mlkemPublicKey).pq.(*mlkem.EncapsulationKey768)
testingOnlyEncapsulate = func() ([]byte, []byte) {
ss, ct, err := mlkemtest.Encapsulate768(pq, randBytes)
if err != nil {
@ -328,7 +392,7 @@ func setupDerandomizedEncap(t *testing.T, kemID uint16, randBytes []byte, kem KE
return ss, ct
}
case mlkem1024:
pq := kem.(*mlkemSender).pq.(*mlkem.EncapsulationKey1024)
pq := pk.(*mlkemPublicKey).pq.(*mlkem.EncapsulationKey1024)
testingOnlyEncapsulate = func() ([]byte, []byte) {
ss, ct, err := mlkemtest.Encapsulate1024(pq, randBytes)
if err != nil {
@ -338,7 +402,7 @@ func setupDerandomizedEncap(t *testing.T, kemID uint16, randBytes []byte, kem KE
}
case mlkem768X25519:
pqRand, tRand := randBytes[:32], randBytes[32:]
pq := kem.(*hybridSender).pq.(*mlkem.EncapsulationKey768)
pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey768)
k, err := ecdh.X25519().NewPrivateKey(tRand)
if err != nil {
t.Fatal(err)
@ -357,7 +421,7 @@ func setupDerandomizedEncap(t *testing.T, kemID uint16, randBytes []byte, kem KE
// The rest of randBytes are the following candidates for rejection
// sampling, but they are never reached.
pqRand, tRand := randBytes[:32], randBytes[32:64]
pq := kem.(*hybridSender).pq.(*mlkem.EncapsulationKey768)
pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey768)
k, err := ecdh.P256().NewPrivateKey(tRand)
if err != nil {
t.Fatal(err)
@ -374,7 +438,7 @@ func setupDerandomizedEncap(t *testing.T, kemID uint16, randBytes []byte, kem KE
}
case mlkem1024P384:
pqRand, tRand := randBytes[:32], randBytes[32:]
pq := kem.(*hybridSender).pq.(*mlkem.EncapsulationKey1024)
pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey1024)
k, err := ecdh.P384().NewPrivateKey(tRand)
if err != nil {
t.Fatal(err)
@ -390,7 +454,7 @@ func setupDerandomizedEncap(t *testing.T, kemID uint16, randBytes []byte, kem KE
return ss, ct
}
default:
t.Fatalf("unsupported KEM %04x", kemID)
t.Fatalf("unsupported KEM %04x", pk.KEM().ID())
}
}
@ -416,4 +480,31 @@ func TestSingletons(t *testing.T) {
if ExportOnly() != ExportOnly() {
t.Error("ExportOnly() != ExportOnly()")
}
if DHKEM(ecdh.P256()) != DHKEM(ecdh.P256()) {
t.Error("DHKEM(P-256) != DHKEM(P-256)")
}
if DHKEM(ecdh.P384()) != DHKEM(ecdh.P384()) {
t.Error("DHKEM(P-384) != DHKEM(P-384)")
}
if DHKEM(ecdh.P521()) != DHKEM(ecdh.P521()) {
t.Error("DHKEM(P-521) != DHKEM(P-521)")
}
if DHKEM(ecdh.X25519()) != DHKEM(ecdh.X25519()) {
t.Error("DHKEM(X25519) != DHKEM(X25519)")
}
if MLKEM768() != MLKEM768() {
t.Error("MLKEM768() != MLKEM768()")
}
if MLKEM1024() != MLKEM1024() {
t.Error("MLKEM1024() != MLKEM1024()")
}
if MLKEM768X25519() != MLKEM768X25519() {
t.Error("MLKEM768X25519() != MLKEM768X25519()")
}
if MLKEM768P256() != MLKEM768P256() {
t.Error("MLKEM768P256() != MLKEM768P256()")
}
if MLKEM1024P384() != MLKEM1024P384() {
t.Error("MLKEM1024P384() != MLKEM1024P384()")
}
}

View file

@ -6,129 +6,90 @@ package hpke
import (
"crypto/ecdh"
"crypto/mlkem"
"crypto/rand"
"encoding/binary"
"errors"
)
const (
dhkemP256 = 0x0010 // DHKEM(P-256, HKDF-SHA256)
dhkemP384 = 0x0011 // DHKEM(P-384, HKDF-SHA384)
dhkemP521 = 0x0012 // DHKEM(P-521, HKDF-SHA512)
dhkemX25519 = 0x0020 // DHKEM(X25519, HKDF-SHA256)
)
// A KEMSender is an instantiation of a KEM (one of the three components of an
// HPKE ciphersuite) with an encapsulation key (i.e. the public key).
type KEMSender interface {
// A KEM is a Key Encapsulation Mechanism, one of the three components of an
// HPKE ciphersuite.
type KEM interface {
// ID returns the HPKE KEM identifier.
ID() uint16
// GenerateKey generates a new key pair.
GenerateKey() (PrivateKey, error)
// NewPublicKey deserializes a public key from bytes.
//
// It implements DeserializePublicKey, as defined in RFC 9180.
NewPublicKey([]byte) (PublicKey, error)
// NewPrivateKey deserializes a private key from bytes.
//
// It implements DeserializePrivateKey, as defined in RFC 9180.
NewPrivateKey([]byte) (PrivateKey, error)
// DeriveKeyPair derives a key pair from the given input keying material.
//
// It implements DeriveKeyPair, as defined in RFC 9180.
DeriveKeyPair(ikm []byte) (PrivateKey, error)
encSize() int
}
// NewKEM returns the KEM implementation for the given KEM ID.
//
// Applications are encouraged to use specific implementations like [DHKEM] or
// [MLKEM768X25519] instead, unless runtime agility is required.
func NewKEM(id uint16) (KEM, error) {
switch id {
case 0x0010: // DHKEM(P-256, HKDF-SHA256)
return DHKEM(ecdh.P256()), nil
case 0x0011: // DHKEM(P-384, HKDF-SHA384)
return DHKEM(ecdh.P384()), nil
case 0x0012: // DHKEM(P-521, HKDF-SHA512)
return DHKEM(ecdh.P521()), nil
case 0x0020: // DHKEM(X25519, HKDF-SHA256)
return DHKEM(ecdh.X25519()), nil
case 0x0041: // ML-KEM-768
return MLKEM768(), nil
case 0x0042: // ML-KEM-1024
return MLKEM1024(), nil
case 0x647a: // MLKEM768-X25519
return MLKEM768X25519(), nil
case 0x0050: // MLKEM768-P256
return MLKEM768P256(), nil
case 0x0051: // MLKEM1024-P384
return MLKEM1024P384(), nil
default:
return nil, errors.New("unsupported KEM")
}
}
// A PublicKey is an instantiation of a KEM (one of the three components of an
// HPKE ciphersuite) with an encapsulation key (i.e. the public key).
//
// A PublicKey is usually obtained from a method of the corresponding [KEM] or
// [PrivateKey], such as [KEM.NewPublicKey] or [PrivateKey.PublicKey].
type PublicKey interface {
// KEM returns the instantiated KEM.
KEM() KEM
// Bytes returns the public key as the output of SerializePublicKey.
Bytes() []byte
encap() (sharedSecret, enc []byte, err error)
}
// NewKEMSender implements DeserializePublicKey and returns a KEMSender
// for the given KEM ID and public key bytes.
//
// Applications are encouraged to use [ecdh.Curve.NewPublicKey] with
// [NewECDHSender] instead, unless runtime agility is required.
func NewKEMSender(id uint16, pub []byte) (KEMSender, error) {
switch id {
case dhkemP256:
k, err := ecdh.P256().NewPublicKey(pub)
if err != nil {
return nil, err
}
return NewECDHSender(k)
case dhkemP384:
k, err := ecdh.P384().NewPublicKey(pub)
if err != nil {
return nil, err
}
return NewECDHSender(k)
case dhkemP521:
k, err := ecdh.P521().NewPublicKey(pub)
if err != nil {
return nil, err
}
return NewECDHSender(k)
case dhkemX25519:
k, err := ecdh.X25519().NewPublicKey(pub)
if err != nil {
return nil, err
}
return NewECDHSender(k)
case mlkem768:
if len(pub) != mlkem.EncapsulationKeySize768 {
return nil, errors.New("invalid public key size")
}
pq, err := mlkem.NewEncapsulationKey768(pub)
if err != nil {
return nil, err
}
return NewMLKEMSender(pq)
case mlkem1024:
if len(pub) != mlkem.EncapsulationKeySize1024 {
return nil, errors.New("invalid public key size")
}
pq, err := mlkem.NewEncapsulationKey1024(pub)
if err != nil {
return nil, err
}
return NewMLKEMSender(pq)
case mlkem768X25519:
if len(pub) != mlkem.EncapsulationKeySize768+32 {
return nil, errors.New("invalid public key size")
}
pq, err := mlkem.NewEncapsulationKey768(pub[:mlkem.EncapsulationKeySize768])
if err != nil {
return nil, err
}
k, err := ecdh.X25519().NewPublicKey(pub[mlkem.EncapsulationKeySize768:])
if err != nil {
return nil, err
}
return NewHybridSender(k, pq)
case mlkem768P256:
if len(pub) != mlkem.EncapsulationKeySize768+65 {
return nil, errors.New("invalid public key size")
}
pq, err := mlkem.NewEncapsulationKey768(pub[:mlkem.EncapsulationKeySize768])
if err != nil {
return nil, err
}
k, err := ecdh.P256().NewPublicKey(pub[mlkem.EncapsulationKeySize768:])
if err != nil {
return nil, err
}
return NewHybridSender(k, pq)
case mlkem1024P384:
if len(pub) != mlkem.EncapsulationKeySize1024+97 {
return nil, errors.New("invalid public key size")
}
pq, err := mlkem.NewEncapsulationKey1024(pub[:mlkem.EncapsulationKeySize1024])
if err != nil {
return nil, err
}
k, err := ecdh.P384().NewPublicKey(pub[mlkem.EncapsulationKeySize1024:])
if err != nil {
return nil, err
}
return NewHybridSender(k, pq)
default:
return nil, errors.New("unsupported KEM")
}
}
// A KEMRecipient is an instantiation of a KEM (one of the three components of
// A PrivateKey is an instantiation of a KEM (one of the three components of
// an HPKE ciphersuite) with a decapsulation key (i.e. the secret key).
type KEMRecipient interface {
// ID returns the HPKE KEM identifier.
ID() uint16
//
// A PrivateKey is usually obtained from a method of the corresponding [KEM],
// such as [KEM.GenerateKey] or [KEM.NewPrivateKey].
type PrivateKey interface {
// KEM returns the instantiated KEM.
KEM() KEM
// Bytes returns the private key as the output of SerializePrivateKey, as
// defined in RFC 9180.
@ -137,117 +98,232 @@ type KEMRecipient interface {
// This is a requirement of RFC 9180, Section 7.1.2.
Bytes() ([]byte, error)
// KEMSender returns the corresponding KEMSender for this recipient.
KEMSender() KEMSender
// PublicKey returns the corresponding PublicKey.
PublicKey() PublicKey
decap(enc []byte) (sharedSecret []byte, err error)
}
// NewKEMRecipient implements DeserializePrivateKey, as defined in RFC 9180, and
// returns a KEMRecipient for the given KEM ID and private key bytes.
//
// Applications are encouraged to use [ecdh.Curve.NewPrivateKey] with
// [NewECDHRecipient] instead, unless runtime agility is required.
func NewKEMRecipient(id uint16, priv []byte) (KEMRecipient, error) {
switch id {
case dhkemP256:
k, err := ecdh.P256().NewPrivateKey(priv)
if err != nil {
return nil, err
}
return NewECDHRecipient(k)
case dhkemP384:
k, err := ecdh.P384().NewPrivateKey(priv)
if err != nil {
return nil, err
}
return NewECDHRecipient(k)
case dhkemP521:
k, err := ecdh.P521().NewPrivateKey(priv)
if err != nil {
return nil, err
}
return NewECDHRecipient(k)
case dhkemX25519:
k, err := ecdh.X25519().NewPrivateKey(priv)
if err != nil {
return nil, err
}
return NewECDHRecipient(k)
case mlkem768:
pq, err := mlkem.NewDecapsulationKey768(priv)
if err != nil {
return nil, err
}
return NewMLKEMRecipient(pq)
case mlkem1024:
pq, err := mlkem.NewDecapsulationKey1024(priv)
if err != nil {
return nil, err
}
return NewMLKEMRecipient(pq)
case mlkem768X25519, mlkem768P256, mlkem1024P384:
return newHybridRecipientFromSeed(id, priv)
default:
return nil, errors.New("unsupported KEM")
}
type dhKEM struct {
kdf KDF
id uint16
curve ecdh.Curve
Nsecret uint16
Nsk uint16
Nenc int
}
// NewKEMRecipientFromSeed implements DeriveKeyPair, as defined in RFC 9180, and
// returns a KEMRecipient for the given KEM ID and private key seed.
//
// Currently, it only supports the KEMs based on ECDH (DHKEM).
//
// TODO: rename to something about deriving.
func NewKEMRecipientFromSeed(id uint16, seed []byte) (KEMRecipient, error) {
// DeriveKeyPair from RFC 9180 Section 7.1.3.
var curve ecdh.Curve
var dh dhKEM
var Nsk uint16
switch id {
case dhkemP256:
curve = ecdh.P256()
dh, _ = dhKEMForCurve(curve)
Nsk = 32
case dhkemP384:
curve = ecdh.P384()
dh, _ = dhKEMForCurve(curve)
Nsk = 48
case dhkemP521:
curve = ecdh.P521()
dh, _ = dhKEMForCurve(curve)
Nsk = 66
case dhkemX25519:
curve = ecdh.X25519()
dh, _ = dhKEMForCurve(curve)
Nsk = 32
// Do not implement the PQ KEMs for now, as the seed input is not
// stable yet. See https://github.com/hpkewg/hpke-pq/issues/30.
default:
return nil, errors.New("unsupported KEM")
}
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), dh.id)
prk, err := dh.kdf.labeledExtract(suiteID, nil, "dkp_prk", seed)
func (kem *dhKEM) extractAndExpand(dhKey, kemContext []byte) ([]byte, error) {
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), kem.id)
eaePRK, err := kem.kdf.labeledExtract(suiteID, nil, "eae_prk", dhKey)
if err != nil {
return nil, err
}
if id == dhkemX25519 {
s, err := dh.kdf.labeledExpand(suiteID, prk, "sk", nil, Nsk)
return kem.kdf.labeledExpand(suiteID, eaePRK, "shared_secret", kemContext, kem.Nsecret)
}
func (kem *dhKEM) ID() uint16 {
return kem.id
}
func (kem *dhKEM) encSize() int {
return kem.Nenc
}
var dhKEMP256 = &dhKEM{HKDFSHA256(), 0x0010, ecdh.P256(), 32, 32, 65}
var dhKEMP384 = &dhKEM{HKDFSHA384(), 0x0011, ecdh.P384(), 48, 48, 97}
var dhKEMP521 = &dhKEM{HKDFSHA512(), 0x0012, ecdh.P521(), 64, 66, 133}
var dhKEMX25519 = &dhKEM{HKDFSHA256(), 0x0020, ecdh.X25519(), 32, 32, 32}
// DHKEM returns a KEM implementing one of
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on curve.
func DHKEM(curve ecdh.Curve) KEM {
switch curve {
case ecdh.P256():
return dhKEMP256
case ecdh.P384():
return dhKEMP384
case ecdh.P521():
return dhKEMP521
case ecdh.X25519():
return dhKEMX25519
default:
// The set of ecdh.Curve implementations is closed, because the
// interface has unexported methods. Therefore, this default case is
// only hit if a new curve is added that DHKEM doesn't support.
return unsupportedCurveKEM{}
}
}
type unsupportedCurveKEM struct{}
func (unsupportedCurveKEM) ID() uint16 {
return 0
}
func (unsupportedCurveKEM) GenerateKey() (PrivateKey, error) {
return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) NewPublicKey([]byte) (PublicKey, error) {
return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) NewPrivateKey([]byte) (PrivateKey, error) {
return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) DeriveKeyPair([]byte) (PrivateKey, error) {
return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) encSize() int {
return 0
}
type dhKEMPublicKey struct {
kem *dhKEM
pub *ecdh.PublicKey
}
// NewDHKEMPublicKey returns a PublicKey implementing
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on the underlying curve of pub ([ecdh.X25519], [ecdh.P256],
// [ecdh.P384], or [ecdh.P521]).
//
// This function is meant for applications that already have an instantiated
// crypto/ecdh public key. Otherwise, applications should use the
// [KEM.NewPublicKey] method of [DHKEM].
func NewDHKEMPublicKey(pub *ecdh.PublicKey) (PublicKey, error) {
kem, ok := DHKEM(pub.Curve()).(*dhKEM)
if !ok {
return nil, errors.New("unsupported curve")
}
return &dhKEMPublicKey{
kem: kem,
pub: pub,
}, nil
}
func (kem *dhKEM) NewPublicKey(data []byte) (PublicKey, error) {
pub, err := kem.curve.NewPublicKey(data)
if err != nil {
return nil, err
}
return NewDHKEMPublicKey(pub)
}
func (pk *dhKEMPublicKey) KEM() KEM {
return pk.kem
}
func (pk *dhKEMPublicKey) Bytes() []byte {
return pk.pub.Bytes()
}
// testingOnlyGenerateKey is only used during testing, to provide
// a fixed test key to use when checking the RFC 9180 vectors.
var testingOnlyGenerateKey func() *ecdh.PrivateKey
func (pk *dhKEMPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
privEph, err := pk.pub.Curve().GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
if testingOnlyGenerateKey != nil {
privEph = testingOnlyGenerateKey()
}
dhVal, err := privEph.ECDH(pk.pub)
if err != nil {
return nil, nil, err
}
encPubEph := privEph.PublicKey().Bytes()
encPubRecip := pk.pub.Bytes()
kemContext := append(encPubEph, encPubRecip...)
sharedSecret, err = pk.kem.extractAndExpand(dhVal, kemContext)
if err != nil {
return nil, nil, err
}
return sharedSecret, encPubEph, nil
}
type dhKEMPrivateKey struct {
kem *dhKEM
priv ecdh.KeyExchanger
}
// NewDHKEMPrivateKey returns a PrivateKey implementing
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on the underlying curve of priv ([ecdh.X25519], [ecdh.P256],
// [ecdh.P384], or [ecdh.P521]).
//
// This function is meant for applications that already have an instantiated
// crypto/ecdh private key, or another implementation of a [ecdh.KeyExchanger]
// (e.g. a hardware key). Otherwise, applications should use the
// [KEM.NewPrivateKey] method of [DHKEM].
func NewDHKEMPrivateKey(priv ecdh.KeyExchanger) (PrivateKey, error) {
kem, ok := DHKEM(priv.Curve()).(*dhKEM)
if !ok {
return nil, errors.New("unsupported curve")
}
return &dhKEMPrivateKey{
kem: kem,
priv: priv,
}, nil
}
func (kem *dhKEM) GenerateKey() (PrivateKey, error) {
priv, err := kem.curve.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
return NewDHKEMPrivateKey(priv)
}
func (kem *dhKEM) NewPrivateKey(ikm []byte) (PrivateKey, error) {
priv, err := kem.curve.NewPrivateKey(ikm)
if err != nil {
return nil, err
}
return NewDHKEMPrivateKey(priv)
}
func (kem *dhKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
// DeriveKeyPair from RFC 9180 Section 7.1.3.
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), kem.id)
prk, err := kem.kdf.labeledExtract(suiteID, nil, "dkp_prk", ikm)
if err != nil {
return nil, err
}
if kem == dhKEMX25519 {
s, err := kem.kdf.labeledExpand(suiteID, prk, "sk", nil, kem.Nsk)
if err != nil {
return nil, err
}
return NewKEMRecipient(id, s)
return kem.NewPrivateKey(s)
}
var counter uint8
for counter < 4 {
s, err := dh.kdf.labeledExpand(suiteID, prk, "candidate", []byte{counter}, Nsk)
s, err := kem.kdf.labeledExpand(suiteID, prk, "candidate", []byte{counter}, kem.Nsk)
if err != nil {
return nil, err
}
if id == dhkemP521 {
if kem == dhKEMP521 {
s[0] &= 0x01
}
r, err := NewKEMRecipient(id, s)
r, err := kem.NewPrivateKey(s)
if err != nil {
counter++
continue
@ -257,136 +333,11 @@ func NewKEMRecipientFromSeed(id uint16, seed []byte) (KEMRecipient, error) {
panic("chance of four rejections is < 2^-128")
}
type dhKEM struct {
kdf KDF
id uint16
nSecret uint16
func (k *dhKEMPrivateKey) KEM() KEM {
return k.kem
}
func (dh *dhKEM) extractAndExpand(dhKey, kemContext []byte) ([]byte, error) {
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), dh.id)
eaePRK, err := dh.kdf.labeledExtract(suiteID, nil, "eae_prk", dhKey)
if err != nil {
return nil, err
}
return dh.kdf.labeledExpand(suiteID, eaePRK, "shared_secret", kemContext, dh.nSecret)
}
func (dh *dhKEM) ID() uint16 {
return dh.id
}
type dhKEMSender struct {
dhKEM
pub *ecdh.PublicKey
}
// NewECDHSender returns a KEMSender implementing one of
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on the underlying curve of the provided public key.
func NewECDHSender(pub *ecdh.PublicKey) (KEMSender, error) {
dhKEM, err := dhKEMForCurve(pub.Curve())
if err != nil {
return nil, err
}
return &dhKEMSender{
pub: pub,
dhKEM: dhKEM,
}, nil
}
func dhKEMForCurve(curve ecdh.Curve) (dhKEM, error) {
switch curve {
case ecdh.P256():
return dhKEM{
kdf: HKDFSHA256(),
id: dhkemP256,
nSecret: 32,
}, nil
case ecdh.P384():
return dhKEM{
kdf: HKDFSHA384(),
id: dhkemP384,
nSecret: 48,
}, nil
case ecdh.P521():
return dhKEM{
kdf: HKDFSHA512(),
id: dhkemP521,
nSecret: 64,
}, nil
case ecdh.X25519():
return dhKEM{
kdf: HKDFSHA256(),
id: dhkemX25519,
nSecret: 32,
}, nil
default:
return dhKEM{}, errors.New("unsupported curve")
}
}
func (dh *dhKEMSender) Bytes() []byte {
return dh.pub.Bytes()
}
// testingOnlyGenerateKey is only used during testing, to provide
// a fixed test key to use when checking the RFC 9180 vectors.
var testingOnlyGenerateKey func() *ecdh.PrivateKey
func (dh *dhKEMSender) encap() (sharedSecret []byte, encapPub []byte, err error) {
privEph, err := dh.pub.Curve().GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
if testingOnlyGenerateKey != nil {
privEph = testingOnlyGenerateKey()
}
dhVal, err := privEph.ECDH(dh.pub)
if err != nil {
return nil, nil, err
}
encPubEph := privEph.PublicKey().Bytes()
encPubRecip := dh.pub.Bytes()
kemContext := append(encPubEph, encPubRecip...)
sharedSecret, err = dh.extractAndExpand(dhVal, kemContext)
if err != nil {
return nil, nil, err
}
return sharedSecret, encPubEph, nil
}
type dhKEMRecipient struct {
dhKEM
priv ecdh.KeyExchanger
}
// NewECDHRecipient returns a KEMRecipient implementing one of
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on the underlying curve of the provided private key.
func NewECDHRecipient(priv ecdh.KeyExchanger) (KEMRecipient, error) {
dhKEM, err := dhKEMForCurve(priv.Curve())
if err != nil {
return nil, err
}
return &dhKEMRecipient{
priv: priv,
dhKEM: dhKEM,
}, nil
}
func (dh *dhKEMRecipient) Bytes() ([]byte, error) {
func (k *dhKEMPrivateKey) Bytes() ([]byte, error) {
// Bizarrely, RFC 9180, Section 7.1.2 says SerializePrivateKey MUST clamp
// the output, which I thought we all agreed to instead do as part of the DH
// function, letting private keys be random bytes.
@ -396,11 +347,11 @@ func (dh *dhKEMRecipient) Bytes() ([]byte, error) {
// necessarily match the NewPrivateKey input.
//
// I'm sure this will not lead to any unexpected behavior or interop issue.
priv, ok := dh.priv.(*ecdh.PrivateKey)
priv, ok := k.priv.(*ecdh.PrivateKey)
if !ok {
return nil, errors.New("ecdh: private key does not support Bytes")
}
if dh.id == dhkemX25519 {
if k.kem == dhKEMX25519 {
b := priv.Bytes()
b[0] &= 248
b[31] &= 127
@ -410,22 +361,22 @@ func (dh *dhKEMRecipient) Bytes() ([]byte, error) {
return priv.Bytes(), nil
}
func (dh *dhKEMRecipient) KEMSender() KEMSender {
return &dhKEMSender{
pub: dh.priv.PublicKey(),
dhKEM: dh.dhKEM,
func (k *dhKEMPrivateKey) PublicKey() PublicKey {
return &dhKEMPublicKey{
kem: k.kem,
pub: k.priv.PublicKey(),
}
}
func (dh *dhKEMRecipient) decap(encPubEph []byte) ([]byte, error) {
pubEph, err := dh.priv.Curve().NewPublicKey(encPubEph)
func (k *dhKEMPrivateKey) decap(encPubEph []byte) ([]byte, error) {
pubEph, err := k.priv.Curve().NewPublicKey(encPubEph)
if err != nil {
return nil, err
}
dhVal, err := dh.priv.ECDH(pubEph)
dhVal, err := k.priv.ECDH(pubEph)
if err != nil {
return nil, err
}
kemContext := append(encPubEph, dh.priv.PublicKey().Bytes()...)
return dh.extractAndExpand(dhVal, kemContext)
kemContext := append(encPubEph, k.priv.PublicKey().Bytes()...)
return k.kem.extractAndExpand(dhVal, kemContext)
}

View file

@ -11,127 +11,221 @@ import (
"crypto/mlkem"
"crypto/rand"
"crypto/sha3"
"encoding/binary"
"errors"
)
const (
mlkem768 = 0x0041 // ML-KEM-768
mlkem1024 = 0x0042 // ML-KEM-1024
mlkem768X25519 = 0x647a // MLKEM768-X25519
mlkem768P256 = 0x0050 // MLKEM768-P256
mlkem1024P384 = 0x0051 // MLKEM1024-P384
)
var mlkem768X25519Hybrid = hybrid{
id: mlkem768X25519,
var mlkem768X25519 = &hybridKEM{
id: 0x647a,
label: /**/ `\./` +
/* */ `/^\`,
curve: ecdh.X25519(),
curveSeedSize: 32,
curvePointSize: 32,
pqEncapsKeySize: mlkem.EncapsulationKeySize768,
pqCiphertextSize: mlkem.CiphertextSize768,
pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey768(data)
},
pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey768(data)
},
pqGenerateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey768()
},
}
var mlkem768P256Hybrid = hybrid{
id: mlkem768P256,
// MLKEM768X25519 returns a KEM implementing MLKEM768-X25519 (a.k.a. X-Wing)
// from draft-ietf-hpke-pq.
func MLKEM768X25519() KEM {
return mlkem768X25519
}
var mlkem768P256 = &hybridKEM{
id: 0x0050,
label: "MLKEM768-P256",
curve: ecdh.P256(),
curveSeedSize: 32,
curvePointSize: 65,
pqEncapsKeySize: mlkem.EncapsulationKeySize768,
pqCiphertextSize: mlkem.CiphertextSize768,
pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey768(data)
},
pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey768(data)
},
pqGenerateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey768()
},
}
var mlkem1024P384Hybrid = hybrid{
id: mlkem1024P384,
// MLKEM768P256 returns a KEM implementing MLKEM768-P256 from draft-ietf-hpke-pq.
func MLKEM768P256() KEM {
return mlkem768P256
}
var mlkem1024P384 = &hybridKEM{
id: 0x0051,
label: "MLKEM1024-P384",
curve: ecdh.P384(),
curveSeedSize: 48,
curvePointSize: 97,
pqEncapsKeySize: mlkem.EncapsulationKeySize1024,
pqCiphertextSize: mlkem.CiphertextSize1024,
pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey1024(data)
},
pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey1024(data)
},
pqGenerateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey1024()
},
}
type hybrid struct {
// MLKEM1024P384 returns a KEM implementing MLKEM1024-P384 from draft-ietf-hpke-pq.
func MLKEM1024P384() KEM {
return mlkem1024P384
}
type hybridKEM struct {
id uint16
label string
curve ecdh.Curve
curveSeedSize int
curvePointSize int
pqEncapsKeySize int
pqCiphertextSize int
pqNewPublicKey func(data []byte) (crypto.Encapsulator, error)
pqNewPrivateKey func(data []byte) (crypto.Decapsulator, error)
pqGenerateKey func() (crypto.Decapsulator, error)
}
func (x *hybrid) ID() uint16 {
return x.id
func (kem *hybridKEM) ID() uint16 {
return kem.id
}
func (x *hybrid) sharedSecret(ssPQ, ssT, ctT, ekT []byte) []byte {
func (kem *hybridKEM) encSize() int {
return kem.pqCiphertextSize + kem.curvePointSize
}
func (kem *hybridKEM) sharedSecret(ssPQ, ssT, ctT, ekT []byte) []byte {
h := sha3.New256()
h.Write(ssPQ)
h.Write(ssT)
h.Write(ctT)
h.Write(ekT)
h.Write([]byte(x.label))
h.Write([]byte(kem.label))
return h.Sum(nil)
}
type hybridSender struct {
hybrid
t *ecdh.PublicKey
pq crypto.Encapsulator
type hybridPublicKey struct {
kem *hybridKEM
t *ecdh.PublicKey
pq crypto.Encapsulator
}
// NewHybridSender returns a KEMSender implementing one of
// NewHybridPublicKey returns a PublicKey implementing one of
//
// - MLKEM768-X25519 (a.k.a. X-Wing)
// - MLKEM768-P256
// - MLKEM1024-P384
//
// from draft-ietf-hpke-pq, depending on the underlying curve of t
// ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq
// (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
func NewHybridSender(t *ecdh.PublicKey, pq crypto.Encapsulator) (KEMSender, error) {
// ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq (either
// *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
//
// This function is meant for applications that already have instantiated
// crypto/ecdh and crypto/mlkem public keys. Otherwise, applications should use
// the [KEM.NewPublicKey] method of e.g. [MLKEM768X25519].
func NewHybridPublicKey(pq crypto.Encapsulator, t *ecdh.PublicKey) (PublicKey, error) {
switch t.Curve() {
case ecdh.X25519():
if _, ok := pq.(*mlkem.EncapsulationKey768); !ok {
return nil, errors.New("invalid PQ KEM for X25519 hybrid")
}
return &hybridSender{mlkem768X25519Hybrid, t, pq}, nil
return &hybridPublicKey{mlkem768X25519, t, pq}, nil
case ecdh.P256():
if _, ok := pq.(*mlkem.EncapsulationKey768); !ok {
return nil, errors.New("invalid PQ KEM for P-256 hybrid")
}
return &hybridSender{mlkem768P256Hybrid, t, pq}, nil
return &hybridPublicKey{mlkem768P256, t, pq}, nil
case ecdh.P384():
if _, ok := pq.(*mlkem.EncapsulationKey1024); !ok {
return nil, errors.New("invalid PQ KEM for P-384 hybrid")
}
return &hybridSender{mlkem1024P384Hybrid, t, pq}, nil
return &hybridPublicKey{mlkem1024P384, t, pq}, nil
default:
return nil, errors.New("unsupported curve")
}
}
func (s *hybridSender) Bytes() []byte {
return append(s.pq.Bytes(), s.t.Bytes()...)
func (kem *hybridKEM) NewPublicKey(data []byte) (PublicKey, error) {
if len(data) != kem.pqEncapsKeySize+kem.curvePointSize {
return nil, errors.New("invalid public key size")
}
pq, err := kem.pqNewPublicKey(data[:kem.pqEncapsKeySize])
if err != nil {
return nil, err
}
k, err := kem.curve.NewPublicKey(data[kem.pqEncapsKeySize:])
if err != nil {
return nil, err
}
return NewHybridPublicKey(pq, k)
}
func (pk *hybridPublicKey) KEM() KEM {
return pk.kem
}
func (pk *hybridPublicKey) Bytes() []byte {
return append(pk.pq.Bytes(), pk.t.Bytes()...)
}
var testingOnlyEncapsulate func() (ss, ct []byte)
func (s *hybridSender) encap() (sharedSecret []byte, encapPub []byte, err error) {
skE, err := s.t.Curve().GenerateKey(rand.Reader)
func (pk *hybridPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
skE, err := pk.t.Curve().GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
if testingOnlyGenerateKey != nil {
skE = testingOnlyGenerateKey()
}
ssT, err := skE.ECDH(s.t)
ssT, err := skE.ECDH(pk.t)
if err != nil {
return nil, nil, err
}
ctT := skE.PublicKey().Bytes()
ssPQ, ctPQ := s.pq.Encapsulate()
ssPQ, ctPQ := pk.pq.Encapsulate()
if testingOnlyEncapsulate != nil {
ssPQ, ctPQ = testingOnlyEncapsulate()
}
ss := s.sharedSecret(ssPQ, ssT, ctT, s.t.Bytes())
ss := pk.kem.sharedSecret(ssPQ, ssT, ctT, pk.t.Bytes())
ct := append(ctPQ, ctT...)
return ss, ct, nil
}
type hybridRecipient struct {
hybrid
type hybridPrivateKey struct {
kem *hybridKEM
seed []byte // can be nil
t ecdh.KeyExchanger
pq crypto.Decapsulator
}
// NewHybridRecipient returns a KEMRecipient implementing
// NewHybridPrivateKey returns a PrivateKey implementing
//
// - MLKEM768-X25519 (a.k.a. X-Wing)
// - MLKEM768-P256
@ -140,11 +234,23 @@ type hybridRecipient struct {
// from draft-ietf-hpke-pq, depending on the underlying curve of t
// ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq.Encapsulator()
// (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
func NewHybridRecipient(t ecdh.KeyExchanger, pq crypto.Decapsulator) (KEMRecipient, error) {
return newHybridRecipient(t, pq, nil)
//
// This function is meant for applications that already have instantiated
// crypto/ecdh and crypto/mlkem private keys, or another implementation of a
// [ecdh.KeyExchanger] and [crypto.Decapsulator] (e.g. a hardware key).
// Otherwise, applications should use the [KEM.NewPrivateKey] method of e.g.
// [MLKEM768X25519].
func NewHybridPrivateKey(pq crypto.Decapsulator, t ecdh.KeyExchanger) (PrivateKey, error) {
return newHybridPrivateKey(pq, t, nil)
}
func newHybridRecipientFromSeed(id uint16, priv []byte) (KEMRecipient, error) {
func (kem *hybridKEM) GenerateKey() (PrivateKey, error) {
seed := make([]byte, 32)
rand.Read(seed)
return kem.NewPrivateKey(seed)
}
func (kem *hybridKEM) NewPrivateKey(priv []byte) (PrivateKey, error) {
if len(priv) != 32 {
return nil, errors.New("hpke: invalid hybrid KEM secret length")
}
@ -154,200 +260,256 @@ func newHybridRecipientFromSeed(id uint16, priv []byte) (KEMRecipient, error) {
seedPQ := make([]byte, mlkem.SeedSize)
s.Read(seedPQ)
var pq crypto.Decapsulator
switch id {
case mlkem768X25519, mlkem768P256:
sk, err := mlkem.NewDecapsulationKey768(seedPQ)
if err != nil {
return nil, err
}
pq = sk
case mlkem1024P384:
sk, err := mlkem.NewDecapsulationKey1024(seedPQ)
if err != nil {
return nil, err
}
pq = sk
default:
return nil, errors.New("hpke: invalid hybrid KEM ID")
}
var seedT []byte
var curve ecdh.Curve
switch id {
case mlkem768X25519:
seedT = make([]byte, 32)
curve = ecdh.X25519()
case mlkem768P256:
seedT = make([]byte, 32)
curve = ecdh.P256()
case mlkem1024P384:
seedT = make([]byte, 48)
curve = ecdh.P384()
default:
return nil, errors.New("hpke: invalid hybrid KEM ID")
pq, err := kem.pqNewPrivateKey(seedPQ)
if err != nil {
return nil, err
}
seedT := make([]byte, kem.curveSeedSize)
for {
s.Read(seedT)
k, err := curve.NewPrivateKey(seedT)
k, err := kem.curve.NewPrivateKey(seedT)
if err != nil {
continue
}
return newHybridRecipient(k, pq, priv)
return newHybridPrivateKey(pq, k, priv)
}
}
func newHybridRecipient(t ecdh.KeyExchanger, pq crypto.Decapsulator, seed []byte) (KEMRecipient, error) {
func newHybridPrivateKey(pq crypto.Decapsulator, t ecdh.KeyExchanger, seed []byte) (PrivateKey, error) {
switch t.Curve() {
case ecdh.X25519():
if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok {
return nil, errors.New("invalid PQ KEM for X25519 hybrid")
}
return &hybridRecipient{mlkem768X25519Hybrid, bytes.Clone(seed), t, pq}, nil
return &hybridPrivateKey{mlkem768X25519, bytes.Clone(seed), t, pq}, nil
case ecdh.P256():
if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok {
return nil, errors.New("invalid PQ KEM for P-256 hybrid")
}
return &hybridRecipient{mlkem768P256Hybrid, bytes.Clone(seed), t, pq}, nil
return &hybridPrivateKey{mlkem768P256, bytes.Clone(seed), t, pq}, nil
case ecdh.P384():
if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey1024); !ok {
return nil, errors.New("invalid PQ KEM for P-384 hybrid")
}
return &hybridRecipient{mlkem1024P384Hybrid, bytes.Clone(seed), t, pq}, nil
return &hybridPrivateKey{mlkem1024P384, bytes.Clone(seed), t, pq}, nil
default:
return nil, errors.New("unsupported curve")
}
}
func (r *hybridRecipient) Bytes() ([]byte, error) {
if r.seed == nil {
func (kem *hybridKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), kem.id)
dk, err := SHAKE256().labeledDerive(suiteID, ikm, "DeriveKeyPair", nil, 32)
if err != nil {
return nil, err
}
return kem.NewPrivateKey(dk)
}
func (k *hybridPrivateKey) KEM() KEM {
return k.kem
}
func (k *hybridPrivateKey) Bytes() ([]byte, error) {
if k.seed == nil {
return nil, errors.New("private key seed not available")
}
return r.seed, nil
return k.seed, nil
}
func (r *hybridRecipient) KEMSender() KEMSender {
return &hybridSender{
hybrid: r.hybrid,
t: r.t.PublicKey(),
pq: r.pq.Encapsulator(),
func (k *hybridPrivateKey) PublicKey() PublicKey {
return &hybridPublicKey{
kem: k.kem,
t: k.t.PublicKey(),
pq: k.pq.Encapsulator(),
}
}
func (r *hybridRecipient) decap(enc []byte) ([]byte, error) {
var ctPQ, ctT []byte
switch r.id {
case mlkem768X25519:
if len(enc) != mlkem.CiphertextSize768+32 {
return nil, errors.New("invalid encapsulated key size")
}
ctPQ, ctT = enc[:mlkem.CiphertextSize768], enc[mlkem.CiphertextSize768:]
case mlkem768P256:
if len(enc) != mlkem.CiphertextSize768+65 {
return nil, errors.New("invalid encapsulated key size")
}
ctPQ, ctT = enc[:mlkem.CiphertextSize768], enc[mlkem.CiphertextSize768:]
case mlkem1024P384:
if len(enc) != mlkem.CiphertextSize1024+97 {
return nil, errors.New("invalid encapsulated key size")
}
ctPQ, ctT = enc[:mlkem.CiphertextSize1024], enc[mlkem.CiphertextSize1024:]
default:
return nil, errors.New("internal error: unsupported KEM")
func (k *hybridPrivateKey) decap(enc []byte) ([]byte, error) {
if len(enc) != k.kem.pqCiphertextSize+k.kem.curvePointSize {
return nil, errors.New("invalid encapsulated key size")
}
ssPQ, err := r.pq.Decapsulate(ctPQ)
ctPQ, ctT := enc[:k.kem.pqCiphertextSize], enc[k.kem.pqCiphertextSize:]
ssPQ, err := k.pq.Decapsulate(ctPQ)
if err != nil {
return nil, err
}
pub, err := r.t.Curve().NewPublicKey(ctT)
pub, err := k.t.Curve().NewPublicKey(ctT)
if err != nil {
return nil, err
}
ssT, err := r.t.ECDH(pub)
ssT, err := k.t.ECDH(pub)
if err != nil {
return nil, err
}
ss := r.sharedSecret(ssPQ, ssT, ctT, r.t.PublicKey().Bytes())
ss := k.kem.sharedSecret(ssPQ, ssT, ctT, k.t.PublicKey().Bytes())
return ss, nil
}
type mlkemSender struct {
id uint16
pq interface {
Bytes() []byte
Encapsulate() (sharedKey []byte, ciphertext []byte)
}
var mlkem768 = &mlkemKEM{
id: 0x0041,
ciphertextSize: mlkem.CiphertextSize768,
newPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey768(data)
},
newPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey768(data)
},
generateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey768()
},
}
// NewMLKEMSender returns a KEMSender implementing ML-KEM-768 or ML-KEM-1024 from
// draft-ietf-hpke-pq. pub must be either a *[mlkem.EncapsulationKey768] or a
// *[mlkem.EncapsulationKey1024].
func NewMLKEMSender(pub crypto.Encapsulator) (KEMSender, error) {
// MLKEM768 returns a KEM implementing ML-KEM-768 from draft-ietf-hpke-pq.
func MLKEM768() KEM {
return mlkem768
}
var mlkem1024 = &mlkemKEM{
id: 0x0042,
ciphertextSize: mlkem.CiphertextSize1024,
newPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey1024(data)
},
newPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey1024(data)
},
generateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey1024()
},
}
// MLKEM1024 returns a KEM implementing ML-KEM-1024 from draft-ietf-hpke-pq.
func MLKEM1024() KEM {
return mlkem1024
}
type mlkemKEM struct {
id uint16
ciphertextSize int
newPublicKey func(data []byte) (crypto.Encapsulator, error)
newPrivateKey func(data []byte) (crypto.Decapsulator, error)
generateKey func() (crypto.Decapsulator, error)
}
func (kem *mlkemKEM) ID() uint16 {
return kem.id
}
func (kem *mlkemKEM) encSize() int {
return kem.ciphertextSize
}
type mlkemPublicKey struct {
kem *mlkemKEM
pq crypto.Encapsulator
}
// NewMLKEMPublicKey returns a KEMPublicKey implementing
//
// - ML-KEM-768
// - ML-KEM-1024
//
// from draft-ietf-hpke-pq, depending on the type of pub
// (*[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
//
// This function is meant for applications that already have an instantiated
// crypto/mlkem public key. Otherwise, applications should use the
// [KEM.NewPublicKey] method of e.g. [MLKEM768].
func NewMLKEMPublicKey(pub crypto.Encapsulator) (PublicKey, error) {
switch pub.(type) {
case *mlkem.EncapsulationKey768:
return &mlkemSender{
id: mlkem768,
pq: pub,
}, nil
return &mlkemPublicKey{mlkem768, pub}, nil
case *mlkem.EncapsulationKey1024:
return &mlkemSender{
id: mlkem1024,
pq: pub,
}, nil
return &mlkemPublicKey{mlkem1024, pub}, nil
default:
return nil, errors.New("unsupported public key type")
}
}
func (s *mlkemSender) ID() uint16 {
return s.id
func (kem *mlkemKEM) NewPublicKey(data []byte) (PublicKey, error) {
pq, err := kem.newPublicKey(data)
if err != nil {
return nil, err
}
return NewMLKEMPublicKey(pq)
}
func (s *mlkemSender) Bytes() []byte {
return s.pq.Bytes()
func (pk *mlkemPublicKey) KEM() KEM {
return pk.kem
}
func (s *mlkemSender) encap() (sharedSecret []byte, encapPub []byte, err error) {
ss, ct := s.pq.Encapsulate()
func (pk *mlkemPublicKey) Bytes() []byte {
return pk.pq.Bytes()
}
func (pk *mlkemPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
ss, ct := pk.pq.Encapsulate()
if testingOnlyEncapsulate != nil {
ss, ct = testingOnlyEncapsulate()
}
return ss, ct, nil
}
type mlkemRecipient struct {
id uint16
pq crypto.Decapsulator
type mlkemPrivateKey struct {
kem *mlkemKEM
pq crypto.Decapsulator
}
// NewMLKEMRecipient returns a KEMRecipient implementing ML-KEM-768 or ML-KEM-1024
// from draft-ietf-hpke-pq. priv.Encapsulator() must return either a
// *[mlkem.EncapsulationKey768] or a *[mlkem.EncapsulationKey1024].
func NewMLKEMRecipient(priv crypto.Decapsulator) (KEMRecipient, error) {
// NewMLKEMPrivateKey returns a KEMPrivateKey implementing
//
// - ML-KEM-768
// - ML-KEM-1024
//
// from draft-ietf-hpke-pq, depending on the type of priv.Encapsulator()
// (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
//
// This function is meant for applications that already have an instantiated
// crypto/mlkem private key. Otherwise, applications should use the
// [KEM.NewPrivateKey] method of e.g. [MLKEM768].
func NewMLKEMPrivateKey(priv crypto.Decapsulator) (PrivateKey, error) {
switch priv.Encapsulator().(type) {
case *mlkem.EncapsulationKey768:
return &mlkemRecipient{
id: mlkem768,
pq: priv,
}, nil
return &mlkemPrivateKey{mlkem768, priv}, nil
case *mlkem.EncapsulationKey1024:
return &mlkemRecipient{
id: mlkem1024,
pq: priv,
}, nil
return &mlkemPrivateKey{mlkem1024, priv}, nil
default:
return nil, errors.New("unsupported public key type")
}
}
func (r *mlkemRecipient) ID() uint16 {
return r.id
func (kem *mlkemKEM) GenerateKey() (PrivateKey, error) {
pq, err := kem.generateKey()
if err != nil {
return nil, err
}
return NewMLKEMPrivateKey(pq)
}
func (r *mlkemRecipient) Bytes() ([]byte, error) {
pq, ok := r.pq.(interface {
func (kem *mlkemKEM) NewPrivateKey(priv []byte) (PrivateKey, error) {
pq, err := kem.newPrivateKey(priv)
if err != nil {
return nil, err
}
return NewMLKEMPrivateKey(pq)
}
func (kem *mlkemKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), kem.id)
dk, err := SHAKE256().labeledDerive(suiteID, ikm, "DeriveKeyPair", nil, 64)
if err != nil {
return nil, err
}
return kem.NewPrivateKey(dk)
}
func (k *mlkemPrivateKey) KEM() KEM {
return k.kem
}
func (k *mlkemPrivateKey) Bytes() ([]byte, error) {
pq, ok := k.pq.(interface {
Bytes() []byte
})
if !ok {
@ -356,26 +518,13 @@ func (r *mlkemRecipient) Bytes() ([]byte, error) {
return pq.Bytes(), nil
}
func (r *mlkemRecipient) KEMSender() KEMSender {
s := &mlkemSender{
id: r.id,
pq: r.pq.Encapsulator(),
func (k *mlkemPrivateKey) PublicKey() PublicKey {
return &mlkemPublicKey{
kem: k.kem,
pq: k.pq.Encapsulator(),
}
return s
}
func (r *mlkemRecipient) decap(enc []byte) ([]byte, error) {
switch r.id {
case mlkem768:
if len(enc) != mlkem.CiphertextSize768 {
return nil, errors.New("invalid encapsulated key size")
}
case mlkem1024:
if len(enc) != mlkem.CiphertextSize1024 {
return nil, errors.New("invalid encapsulated key size")
}
default:
return nil, errors.New("internal error: unsupported KEM")
}
return r.pq.Decapsulate(enc)
func (k *mlkemPrivateKey) decap(enc []byte) ([]byte, error) {
return k.pq.Decapsulate(enc)
}

File diff suppressed because it is too large Load diff

View file

@ -149,7 +149,7 @@ func parseECHConfigList(data []byte) ([]echConfig, error) {
return configs, nil
}
func pickECHConfig(list []echConfig) (*echConfig, hpke.KEMSender, hpke.KDF, hpke.AEAD) {
func pickECHConfig(list []echConfig) (*echConfig, hpke.PublicKey, hpke.KDF, hpke.AEAD) {
for _, ec := range list {
if !validDNSName(string(ec.PublicName)) {
continue
@ -166,10 +166,16 @@ func pickECHConfig(list []echConfig) (*echConfig, hpke.KEMSender, hpke.KDF, hpke
if unsupportedExt {
continue
}
s, err := hpke.NewKEMSender(ec.KemID, ec.PublicKey)
kem, err := hpke.NewKEM(ec.KemID)
if err != nil {
continue
}
pub, err := kem.NewPublicKey(ec.PublicKey)
if err != nil {
// This is an error in the config, but killing the connection feels
// excessive.
continue
}
for _, cs := range ec.SymmetricCipherSuite {
// All of the supported AEADs and KDFs are fine, rather than
// imposing some sort of preference here, we just pick the first
@ -182,7 +188,7 @@ func pickECHConfig(list []echConfig) (*echConfig, hpke.KEMSender, hpke.KDF, hpke
if err != nil {
continue
}
return &ec, s, kdf, aead
return &ec, pub, kdf, aead
}
}
return nil, nil, nil, nil
@ -568,7 +574,12 @@ func (c *Conn) processECHClientHello(outer *clientHelloMsg, echKeys []EncryptedC
if skip {
continue
}
echPriv, err := hpke.NewKEMRecipient(config.KemID, echKey.PrivateKey)
kem, err := hpke.NewKEM(config.KemID)
if err != nil {
c.sendAlert(alertInternalError)
return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey Config KEM: %s", err)
}
echPriv, err := kem.NewPrivateKey(echKey.PrivateKey)
if err != nil {
c.sendAlert(alertInternalError)
return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey PrivateKey: %s", err)