crypto/internal/hpke: separate KEM and PublicKey/PrivateKey interfaces

Updates #75300

Change-Id: I87ed26e8f57180d741408bdbda1696d46a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/719560
Reviewed-by: Mark Freeman <markfreeman@google.com>
Reviewed-by: Junyang Shao <shaojunyang@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Filippo Valsorda 2025-11-11 13:10:17 +01:00 committed by Gopher Robot
parent e15800c0ec
commit 7db2f0bb9a
6 changed files with 1743 additions and 792 deletions

View file

@ -132,12 +132,12 @@ func newContext(sharedSecret []byte, kemID uint16, kdf KDF, aead AEAD, info []by
// //
// The returned enc ciphertext can be used to instantiate a matching receiving // The returned enc ciphertext can be used to instantiate a matching receiving
// HPKE context with the corresponding KEM decapsulation key. // HPKE context with the corresponding KEM decapsulation key.
func NewSender(kem KEMSender, kdf KDF, aead AEAD, info []byte) (enc []byte, s *Sender, err error) { func NewSender(pk PublicKey, kdf KDF, aead AEAD, info []byte) (enc []byte, s *Sender, err error) {
sharedSecret, encapsulatedKey, err := kem.encap() sharedSecret, encapsulatedKey, err := pk.encap()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
context, err := newContext(sharedSecret, kem.ID(), kdf, aead, info) context, err := newContext(sharedSecret, pk.KEM().ID(), kdf, aead, info)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -151,12 +151,12 @@ func NewSender(kem KEMSender, kdf KDF, aead AEAD, info []byte) (enc []byte, s *S
// The enc parameter must have been produced by a matching sending HPKE context // The enc parameter must have been produced by a matching sending HPKE context
// with the corresponding KEM encapsulation key. The info parameter is // with the corresponding KEM encapsulation key. The info parameter is
// additional public information that must match between sender and recipient. // additional public information that must match between sender and recipient.
func NewRecipient(enc []byte, kem KEMRecipient, kdf KDF, aead AEAD, info []byte) (*Recipient, error) { func NewRecipient(enc []byte, k PrivateKey, kdf KDF, aead AEAD, info []byte) (*Recipient, error) {
sharedSecret, err := kem.decap(enc) sharedSecret, err := k.decap(enc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
context, err := newContext(sharedSecret, kem.ID(), kdf, aead, info) context, err := newContext(sharedSecret, k.KEM().ID(), kdf, aead, info)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -179,16 +179,17 @@ func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
// Seal instantiates a single-use HPKE sending HPKE context like [NewSender], // Seal instantiates a single-use HPKE sending HPKE context like [NewSender],
// and then encrypts the provided plaintext like [Sender.Seal] (with no aad). // and then encrypts the provided plaintext like [Sender.Seal] (with no aad).
func Seal(kem KEMSender, kdf KDF, aead AEAD, info, plaintext []byte) (enc, ct []byte, err error) { // Seal returns the concatenation of the encapsulated key and the ciphertext.
enc, s, err := NewSender(kem, kdf, aead, info) func Seal(pk PublicKey, kdf KDF, aead AEAD, info, plaintext []byte) ([]byte, error) {
enc, s, err := NewSender(pk, kdf, aead, info)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
ct, err = s.Seal(nil, plaintext) ct, err := s.Seal(nil, plaintext)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
return enc, ct, err return append(enc, ct...), nil
} }
// Export produces a secret value derived from the shared key between sender and // Export produces a secret value derived from the shared key between sender and
@ -219,8 +220,14 @@ func (r *Recipient) Open(aad, ciphertext []byte) ([]byte, error) {
// Open instantiates a single-use HPKE receiving HPKE context like [NewRecipient], // Open instantiates a single-use HPKE receiving HPKE context like [NewRecipient],
// and then decrypts the provided ciphertext like [Recipient.Open] (with no aad). // and then decrypts the provided ciphertext like [Recipient.Open] (with no aad).
func Open(enc []byte, kem KEMRecipient, kdf KDF, aead AEAD, info, ciphertext []byte) ([]byte, error) { // ciphertext must be the concatenation of the encapsulated key and the actual ciphertext.
r, err := NewRecipient(enc, kem, kdf, aead, info) func Open(k PrivateKey, kdf KDF, aead AEAD, info, ciphertext []byte) ([]byte, error) {
encSize := k.KEM().encSize()
if len(ciphertext) < encSize {
return nil, errors.New("ciphertext too short")
}
enc, ciphertext := ciphertext[:encSize], ciphertext[encSize:]
r, err := NewRecipient(enc, k, kdf, aead, info)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -18,6 +18,57 @@ import (
"testing" "testing"
) )
func Example() {
// In this example, we use MLKEM768-X25519 as the KEM, HKDF-SHA256 as the
// KDF, and AES-256-GCM as the AEAD to encrypt a single message from a
// sender to a recipient using the one-shot API.
kem, kdf, aead := MLKEM768X25519(), HKDFSHA256(), AES256GCM()
// Recipient side
var (
recipientPrivateKey PrivateKey
publicKeyBytes []byte
)
{
k, err := kem.GenerateKey()
if err != nil {
panic(err)
}
recipientPrivateKey = k
publicKeyBytes = k.PublicKey().Bytes()
}
// Sender side
var ciphertext []byte
{
publicKey, err := kem.NewPublicKey(publicKeyBytes)
if err != nil {
panic(err)
}
message := []byte("|-()-|")
ct, err := Seal(publicKey, kdf, aead, []byte("example"), message)
if err != nil {
panic(err)
}
ciphertext = ct
}
// Recipient side
{
plaintext, err := Open(recipientPrivateKey, kdf, aead, []byte("example"), ciphertext)
if err != nil {
panic(err)
}
fmt.Printf("Decrypted message: %s\n", plaintext)
}
// Output:
// Decrypted message: |-()-|
}
func mustDecodeHex(t *testing.T, in string) []byte { func mustDecodeHex(t *testing.T, in string) []byte {
t.Helper() t.Helper()
b, err := hex.DecodeString(in) b, err := hex.DecodeString(in)
@ -83,6 +134,12 @@ func testVectors(t *testing.T, name string) {
if vector.KEM == 0x0021 { if vector.KEM == 0x0021 {
t.Skip("KEM 0x0021 (DHKEM(X448)) not supported") t.Skip("KEM 0x0021 (DHKEM(X448)) not supported")
} }
if vector.KEM == 0x0040 {
t.Skip("KEM 0x0040 (ML-KEM-512) not supported")
}
if vector.KDF == 0x0012 || vector.KDF == 0x0013 {
t.Skipf("TurboSHAKE KDF not supported")
}
kdf, err := NewKDF(vector.KDF) kdf, err := NewKDF(vector.KDF)
if err != nil { if err != nil {
@ -100,26 +157,37 @@ func testVectors(t *testing.T, name string) {
t.Errorf("unexpected AEAD ID: got %04x, want %04x", aead.ID(), vector.AEAD) t.Errorf("unexpected AEAD ID: got %04x, want %04x", aead.ID(), vector.AEAD)
} }
pubKeyBytes := mustDecodeHex(t, vector.PkRm) kem, err := NewKEM(vector.KEM)
kemSender, err := NewKEMSender(vector.KEM, pubKeyBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if kemSender.ID() != vector.KEM { if kem.ID() != vector.KEM {
t.Errorf("unexpected KEM ID: got %04x, want %04x", kemSender.ID(), vector.KEM) t.Errorf("unexpected KEM ID: got %04x, want %04x", kem.ID(), vector.KEM)
}
pubKeyBytes := mustDecodeHex(t, vector.PkRm)
kemSender, err := kem.NewPublicKey(pubKeyBytes)
if err != nil {
t.Fatal(err)
}
if kemSender.KEM() != kem {
t.Errorf("unexpected KEM from sender: got %04x, want %04x", kemSender.KEM().ID(), kem.ID())
} }
if !bytes.Equal(kemSender.Bytes(), pubKeyBytes) { if !bytes.Equal(kemSender.Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM bytes: got %x, want %x", kemSender.Bytes(), pubKeyBytes) t.Errorf("unexpected KEM bytes: got %x, want %x", kemSender.Bytes(), pubKeyBytes)
} }
ikmE := mustDecodeHex(t, vector.IkmE) ikmE := mustDecodeHex(t, vector.IkmE)
setupDerandomizedEncap(t, vector.KEM, ikmE, kemSender) setupDerandomizedEncap(t, ikmE, kemSender)
info := mustDecodeHex(t, vector.Info) info := mustDecodeHex(t, vector.Info)
encap, sender, err := NewSender(kemSender, kdf, aead, info) encap, sender, err := NewSender(kemSender, kdf, aead, info)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(encap) != kem.encSize() {
t.Errorf("unexpected encapsulated key size: got %d, want %d", len(encap), kem.encSize())
}
expectedEncap := mustDecodeHex(t, vector.Enc) expectedEncap := mustDecodeHex(t, vector.Enc)
if !bytes.Equal(encap, expectedEncap) { if !bytes.Equal(encap, expectedEncap) {
@ -127,23 +195,23 @@ func testVectors(t *testing.T, name string) {
} }
privKeyBytes := mustDecodeHex(t, vector.SkRm) privKeyBytes := mustDecodeHex(t, vector.SkRm)
kemRecipient, err := NewKEMRecipient(vector.KEM, privKeyBytes) kemRecipient, err := kem.NewPrivateKey(privKeyBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if kemRecipient.ID() != vector.KEM { if kemRecipient.KEM() != kem {
t.Errorf("unexpected KEM ID: got %04x, want %04x", kemRecipient.ID(), vector.KEM) t.Errorf("unexpected KEM from recipient: got %04x, want %04x", kemRecipient.KEM().ID(), kem.ID())
} }
kemRecipientBytes, err := kemRecipient.Bytes() kemRecipientBytes, err := kemRecipient.Bytes()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// X25519 serialized keys must be clamped, so the bytes might not match. // X25519 serialized keys must be clamped, so the bytes might not match.
if !bytes.Equal(kemRecipientBytes, privKeyBytes) && vector.KEM != dhkemX25519 { if !bytes.Equal(kemRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
t.Errorf("unexpected KEM bytes: got %x, want %x", kemRecipientBytes, privKeyBytes) t.Errorf("unexpected KEM bytes: got %x, want %x", kemRecipientBytes, privKeyBytes)
} }
if vector.KEM == dhkemX25519 { if vector.KEM == DHKEM(ecdh.X25519()).ID() {
kem2, err := NewKEMRecipient(vector.KEM, kemRecipientBytes) kem2, err := kem.NewPrivateKey(kemRecipientBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -154,32 +222,28 @@ func testVectors(t *testing.T, name string) {
if !bytes.Equal(kemRecipientBytes2, kemRecipientBytes) { if !bytes.Equal(kemRecipientBytes2, kemRecipientBytes) {
t.Errorf("X25519 re-serialized key differs: got %x, want %x", kemRecipientBytes2, kemRecipientBytes) t.Errorf("X25519 re-serialized key differs: got %x, want %x", kemRecipientBytes2, kemRecipientBytes)
} }
if !bytes.Equal(kem2.KEMSender().Bytes(), pubKeyBytes) { if !bytes.Equal(kem2.PublicKey().Bytes(), pubKeyBytes) {
t.Errorf("X25519 re-derived public key differs: got %x, want %x", kem2.KEMSender().Bytes(), pubKeyBytes) t.Errorf("X25519 re-derived public key differs: got %x, want %x", kem2.PublicKey().Bytes(), pubKeyBytes)
} }
} }
if !bytes.Equal(kemRecipient.KEMSender().Bytes(), pubKeyBytes) { if !bytes.Equal(kemRecipient.PublicKey().Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM sender bytes: got %x, want %x", kemRecipient.KEMSender().Bytes(), pubKeyBytes) t.Errorf("unexpected KEM sender bytes: got %x, want %x", kemRecipient.PublicKey().Bytes(), pubKeyBytes)
} }
// NewKEMRecipientFromSeed is not implemented for the PQ KEMs yet. ikm := mustDecodeHex(t, vector.IkmR)
if vector.KEM != mlkem768 && vector.KEM != mlkem1024 && vector.KEM != mlkem768X25519 && derivRecipient, err := kem.DeriveKeyPair(ikm)
vector.KEM != mlkem768P256 && vector.KEM != mlkem1024P384 { if err != nil {
seed := mustDecodeHex(t, vector.IkmR) t.Fatal(err)
seedRecipient, err := NewKEMRecipientFromSeed(vector.KEM, seed) }
if err != nil { derivRecipientBytes, err := derivRecipient.Bytes()
t.Fatal(err) if err != nil {
} t.Fatal(err)
seedRecipientBytes, err := seedRecipient.Bytes() }
if err != nil { if !bytes.Equal(derivRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
t.Fatal(err) t.Errorf("unexpected KEM bytes from seed: got %x, want %x", derivRecipientBytes, privKeyBytes)
} }
if !bytes.Equal(seedRecipientBytes, privKeyBytes) && vector.KEM != dhkemX25519 { if !bytes.Equal(derivRecipient.PublicKey().Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM bytes from seed: got %x, want %x", seedRecipientBytes, privKeyBytes) t.Errorf("unexpected KEM sender bytes from seed: got %x, want %x", derivRecipient.PublicKey().Bytes(), pubKeyBytes)
}
if !bytes.Equal(seedRecipient.KEMSender().Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM sender bytes from seed: got %x, want %x", seedRecipient.KEMSender().Bytes(), pubKeyBytes)
}
} }
recipient, err := NewRecipient(encap, kemRecipient, kdf, aead, info) recipient, err := NewRecipient(encap, kemRecipient, kdf, aead, info)
@ -304,22 +368,22 @@ func drawRandomInput(t *testing.T, r io.Reader) []byte {
return b return b
} }
func setupDerandomizedEncap(t *testing.T, kemID uint16, randBytes []byte, kem KEMSender) { func setupDerandomizedEncap(t *testing.T, randBytes []byte, pk PublicKey) {
t.Cleanup(func() { t.Cleanup(func() {
testingOnlyGenerateKey = nil testingOnlyGenerateKey = nil
testingOnlyEncapsulate = nil testingOnlyEncapsulate = nil
}) })
switch kemID { switch pk.KEM() {
case dhkemP256, dhkemP384, dhkemP521, dhkemX25519: case DHKEM(ecdh.P256()), DHKEM(ecdh.P384()), DHKEM(ecdh.P521()), DHKEM(ecdh.X25519()):
r, err := NewKEMRecipientFromSeed(kemID, randBytes) r, err := pk.KEM().DeriveKeyPair(randBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
testingOnlyGenerateKey = func() *ecdh.PrivateKey { testingOnlyGenerateKey = func() *ecdh.PrivateKey {
return r.(*dhKEMRecipient).priv.(*ecdh.PrivateKey) return r.(*dhKEMPrivateKey).priv.(*ecdh.PrivateKey)
} }
case mlkem768: case mlkem768:
pq := kem.(*mlkemSender).pq.(*mlkem.EncapsulationKey768) pq := pk.(*mlkemPublicKey).pq.(*mlkem.EncapsulationKey768)
testingOnlyEncapsulate = func() ([]byte, []byte) { testingOnlyEncapsulate = func() ([]byte, []byte) {
ss, ct, err := mlkemtest.Encapsulate768(pq, randBytes) ss, ct, err := mlkemtest.Encapsulate768(pq, randBytes)
if err != nil { if err != nil {
@ -328,7 +392,7 @@ func setupDerandomizedEncap(t *testing.T, kemID uint16, randBytes []byte, kem KE
return ss, ct return ss, ct
} }
case mlkem1024: case mlkem1024:
pq := kem.(*mlkemSender).pq.(*mlkem.EncapsulationKey1024) pq := pk.(*mlkemPublicKey).pq.(*mlkem.EncapsulationKey1024)
testingOnlyEncapsulate = func() ([]byte, []byte) { testingOnlyEncapsulate = func() ([]byte, []byte) {
ss, ct, err := mlkemtest.Encapsulate1024(pq, randBytes) ss, ct, err := mlkemtest.Encapsulate1024(pq, randBytes)
if err != nil { if err != nil {
@ -338,7 +402,7 @@ func setupDerandomizedEncap(t *testing.T, kemID uint16, randBytes []byte, kem KE
} }
case mlkem768X25519: case mlkem768X25519:
pqRand, tRand := randBytes[:32], randBytes[32:] pqRand, tRand := randBytes[:32], randBytes[32:]
pq := kem.(*hybridSender).pq.(*mlkem.EncapsulationKey768) pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey768)
k, err := ecdh.X25519().NewPrivateKey(tRand) k, err := ecdh.X25519().NewPrivateKey(tRand)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -357,7 +421,7 @@ func setupDerandomizedEncap(t *testing.T, kemID uint16, randBytes []byte, kem KE
// The rest of randBytes are the following candidates for rejection // The rest of randBytes are the following candidates for rejection
// sampling, but they are never reached. // sampling, but they are never reached.
pqRand, tRand := randBytes[:32], randBytes[32:64] pqRand, tRand := randBytes[:32], randBytes[32:64]
pq := kem.(*hybridSender).pq.(*mlkem.EncapsulationKey768) pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey768)
k, err := ecdh.P256().NewPrivateKey(tRand) k, err := ecdh.P256().NewPrivateKey(tRand)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -374,7 +438,7 @@ func setupDerandomizedEncap(t *testing.T, kemID uint16, randBytes []byte, kem KE
} }
case mlkem1024P384: case mlkem1024P384:
pqRand, tRand := randBytes[:32], randBytes[32:] pqRand, tRand := randBytes[:32], randBytes[32:]
pq := kem.(*hybridSender).pq.(*mlkem.EncapsulationKey1024) pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey1024)
k, err := ecdh.P384().NewPrivateKey(tRand) k, err := ecdh.P384().NewPrivateKey(tRand)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -390,7 +454,7 @@ func setupDerandomizedEncap(t *testing.T, kemID uint16, randBytes []byte, kem KE
return ss, ct return ss, ct
} }
default: default:
t.Fatalf("unsupported KEM %04x", kemID) t.Fatalf("unsupported KEM %04x", pk.KEM().ID())
} }
} }
@ -416,4 +480,31 @@ func TestSingletons(t *testing.T) {
if ExportOnly() != ExportOnly() { if ExportOnly() != ExportOnly() {
t.Error("ExportOnly() != ExportOnly()") t.Error("ExportOnly() != ExportOnly()")
} }
if DHKEM(ecdh.P256()) != DHKEM(ecdh.P256()) {
t.Error("DHKEM(P-256) != DHKEM(P-256)")
}
if DHKEM(ecdh.P384()) != DHKEM(ecdh.P384()) {
t.Error("DHKEM(P-384) != DHKEM(P-384)")
}
if DHKEM(ecdh.P521()) != DHKEM(ecdh.P521()) {
t.Error("DHKEM(P-521) != DHKEM(P-521)")
}
if DHKEM(ecdh.X25519()) != DHKEM(ecdh.X25519()) {
t.Error("DHKEM(X25519) != DHKEM(X25519)")
}
if MLKEM768() != MLKEM768() {
t.Error("MLKEM768() != MLKEM768()")
}
if MLKEM1024() != MLKEM1024() {
t.Error("MLKEM1024() != MLKEM1024()")
}
if MLKEM768X25519() != MLKEM768X25519() {
t.Error("MLKEM768X25519() != MLKEM768X25519()")
}
if MLKEM768P256() != MLKEM768P256() {
t.Error("MLKEM768P256() != MLKEM768P256()")
}
if MLKEM1024P384() != MLKEM1024P384() {
t.Error("MLKEM1024P384() != MLKEM1024P384()")
}
} }

View file

@ -6,129 +6,90 @@ package hpke
import ( import (
"crypto/ecdh" "crypto/ecdh"
"crypto/mlkem"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
) )
const ( // A KEM is a Key Encapsulation Mechanism, one of the three components of an
dhkemP256 = 0x0010 // DHKEM(P-256, HKDF-SHA256) // HPKE ciphersuite.
dhkemP384 = 0x0011 // DHKEM(P-384, HKDF-SHA384) type KEM interface {
dhkemP521 = 0x0012 // DHKEM(P-521, HKDF-SHA512)
dhkemX25519 = 0x0020 // DHKEM(X25519, HKDF-SHA256)
)
// A KEMSender is an instantiation of a KEM (one of the three components of an
// HPKE ciphersuite) with an encapsulation key (i.e. the public key).
type KEMSender interface {
// ID returns the HPKE KEM identifier. // ID returns the HPKE KEM identifier.
ID() uint16 ID() uint16
// GenerateKey generates a new key pair.
GenerateKey() (PrivateKey, error)
// NewPublicKey deserializes a public key from bytes.
//
// It implements DeserializePublicKey, as defined in RFC 9180.
NewPublicKey([]byte) (PublicKey, error)
// NewPrivateKey deserializes a private key from bytes.
//
// It implements DeserializePrivateKey, as defined in RFC 9180.
NewPrivateKey([]byte) (PrivateKey, error)
// DeriveKeyPair derives a key pair from the given input keying material.
//
// It implements DeriveKeyPair, as defined in RFC 9180.
DeriveKeyPair(ikm []byte) (PrivateKey, error)
encSize() int
}
// NewKEM returns the KEM implementation for the given KEM ID.
//
// Applications are encouraged to use specific implementations like [DHKEM] or
// [MLKEM768X25519] instead, unless runtime agility is required.
func NewKEM(id uint16) (KEM, error) {
switch id {
case 0x0010: // DHKEM(P-256, HKDF-SHA256)
return DHKEM(ecdh.P256()), nil
case 0x0011: // DHKEM(P-384, HKDF-SHA384)
return DHKEM(ecdh.P384()), nil
case 0x0012: // DHKEM(P-521, HKDF-SHA512)
return DHKEM(ecdh.P521()), nil
case 0x0020: // DHKEM(X25519, HKDF-SHA256)
return DHKEM(ecdh.X25519()), nil
case 0x0041: // ML-KEM-768
return MLKEM768(), nil
case 0x0042: // ML-KEM-1024
return MLKEM1024(), nil
case 0x647a: // MLKEM768-X25519
return MLKEM768X25519(), nil
case 0x0050: // MLKEM768-P256
return MLKEM768P256(), nil
case 0x0051: // MLKEM1024-P384
return MLKEM1024P384(), nil
default:
return nil, errors.New("unsupported KEM")
}
}
// A PublicKey is an instantiation of a KEM (one of the three components of an
// HPKE ciphersuite) with an encapsulation key (i.e. the public key).
//
// A PublicKey is usually obtained from a method of the corresponding [KEM] or
// [PrivateKey], such as [KEM.NewPublicKey] or [PrivateKey.PublicKey].
type PublicKey interface {
// KEM returns the instantiated KEM.
KEM() KEM
// Bytes returns the public key as the output of SerializePublicKey. // Bytes returns the public key as the output of SerializePublicKey.
Bytes() []byte Bytes() []byte
encap() (sharedSecret, enc []byte, err error) encap() (sharedSecret, enc []byte, err error)
} }
// NewKEMSender implements DeserializePublicKey and returns a KEMSender // A PrivateKey is an instantiation of a KEM (one of the three components of
// for the given KEM ID and public key bytes.
//
// Applications are encouraged to use [ecdh.Curve.NewPublicKey] with
// [NewECDHSender] instead, unless runtime agility is required.
func NewKEMSender(id uint16, pub []byte) (KEMSender, error) {
switch id {
case dhkemP256:
k, err := ecdh.P256().NewPublicKey(pub)
if err != nil {
return nil, err
}
return NewECDHSender(k)
case dhkemP384:
k, err := ecdh.P384().NewPublicKey(pub)
if err != nil {
return nil, err
}
return NewECDHSender(k)
case dhkemP521:
k, err := ecdh.P521().NewPublicKey(pub)
if err != nil {
return nil, err
}
return NewECDHSender(k)
case dhkemX25519:
k, err := ecdh.X25519().NewPublicKey(pub)
if err != nil {
return nil, err
}
return NewECDHSender(k)
case mlkem768:
if len(pub) != mlkem.EncapsulationKeySize768 {
return nil, errors.New("invalid public key size")
}
pq, err := mlkem.NewEncapsulationKey768(pub)
if err != nil {
return nil, err
}
return NewMLKEMSender(pq)
case mlkem1024:
if len(pub) != mlkem.EncapsulationKeySize1024 {
return nil, errors.New("invalid public key size")
}
pq, err := mlkem.NewEncapsulationKey1024(pub)
if err != nil {
return nil, err
}
return NewMLKEMSender(pq)
case mlkem768X25519:
if len(pub) != mlkem.EncapsulationKeySize768+32 {
return nil, errors.New("invalid public key size")
}
pq, err := mlkem.NewEncapsulationKey768(pub[:mlkem.EncapsulationKeySize768])
if err != nil {
return nil, err
}
k, err := ecdh.X25519().NewPublicKey(pub[mlkem.EncapsulationKeySize768:])
if err != nil {
return nil, err
}
return NewHybridSender(k, pq)
case mlkem768P256:
if len(pub) != mlkem.EncapsulationKeySize768+65 {
return nil, errors.New("invalid public key size")
}
pq, err := mlkem.NewEncapsulationKey768(pub[:mlkem.EncapsulationKeySize768])
if err != nil {
return nil, err
}
k, err := ecdh.P256().NewPublicKey(pub[mlkem.EncapsulationKeySize768:])
if err != nil {
return nil, err
}
return NewHybridSender(k, pq)
case mlkem1024P384:
if len(pub) != mlkem.EncapsulationKeySize1024+97 {
return nil, errors.New("invalid public key size")
}
pq, err := mlkem.NewEncapsulationKey1024(pub[:mlkem.EncapsulationKeySize1024])
if err != nil {
return nil, err
}
k, err := ecdh.P384().NewPublicKey(pub[mlkem.EncapsulationKeySize1024:])
if err != nil {
return nil, err
}
return NewHybridSender(k, pq)
default:
return nil, errors.New("unsupported KEM")
}
}
// A KEMRecipient is an instantiation of a KEM (one of the three components of
// an HPKE ciphersuite) with a decapsulation key (i.e. the secret key). // an HPKE ciphersuite) with a decapsulation key (i.e. the secret key).
type KEMRecipient interface { //
// ID returns the HPKE KEM identifier. // A PrivateKey is usually obtained from a method of the corresponding [KEM],
ID() uint16 // such as [KEM.GenerateKey] or [KEM.NewPrivateKey].
type PrivateKey interface {
// KEM returns the instantiated KEM.
KEM() KEM
// Bytes returns the private key as the output of SerializePrivateKey, as // Bytes returns the private key as the output of SerializePrivateKey, as
// defined in RFC 9180. // defined in RFC 9180.
@ -137,117 +98,232 @@ type KEMRecipient interface {
// This is a requirement of RFC 9180, Section 7.1.2. // This is a requirement of RFC 9180, Section 7.1.2.
Bytes() ([]byte, error) Bytes() ([]byte, error)
// KEMSender returns the corresponding KEMSender for this recipient. // PublicKey returns the corresponding PublicKey.
KEMSender() KEMSender PublicKey() PublicKey
decap(enc []byte) (sharedSecret []byte, err error) decap(enc []byte) (sharedSecret []byte, err error)
} }
// NewKEMRecipient implements DeserializePrivateKey, as defined in RFC 9180, and type dhKEM struct {
// returns a KEMRecipient for the given KEM ID and private key bytes. kdf KDF
// id uint16
// Applications are encouraged to use [ecdh.Curve.NewPrivateKey] with curve ecdh.Curve
// [NewECDHRecipient] instead, unless runtime agility is required. Nsecret uint16
func NewKEMRecipient(id uint16, priv []byte) (KEMRecipient, error) { Nsk uint16
switch id { Nenc int
case dhkemP256:
k, err := ecdh.P256().NewPrivateKey(priv)
if err != nil {
return nil, err
}
return NewECDHRecipient(k)
case dhkemP384:
k, err := ecdh.P384().NewPrivateKey(priv)
if err != nil {
return nil, err
}
return NewECDHRecipient(k)
case dhkemP521:
k, err := ecdh.P521().NewPrivateKey(priv)
if err != nil {
return nil, err
}
return NewECDHRecipient(k)
case dhkemX25519:
k, err := ecdh.X25519().NewPrivateKey(priv)
if err != nil {
return nil, err
}
return NewECDHRecipient(k)
case mlkem768:
pq, err := mlkem.NewDecapsulationKey768(priv)
if err != nil {
return nil, err
}
return NewMLKEMRecipient(pq)
case mlkem1024:
pq, err := mlkem.NewDecapsulationKey1024(priv)
if err != nil {
return nil, err
}
return NewMLKEMRecipient(pq)
case mlkem768X25519, mlkem768P256, mlkem1024P384:
return newHybridRecipientFromSeed(id, priv)
default:
return nil, errors.New("unsupported KEM")
}
} }
// NewKEMRecipientFromSeed implements DeriveKeyPair, as defined in RFC 9180, and func (kem *dhKEM) extractAndExpand(dhKey, kemContext []byte) ([]byte, error) {
// returns a KEMRecipient for the given KEM ID and private key seed. suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), kem.id)
// eaePRK, err := kem.kdf.labeledExtract(suiteID, nil, "eae_prk", dhKey)
// Currently, it only supports the KEMs based on ECDH (DHKEM).
//
// TODO: rename to something about deriving.
func NewKEMRecipientFromSeed(id uint16, seed []byte) (KEMRecipient, error) {
// DeriveKeyPair from RFC 9180 Section 7.1.3.
var curve ecdh.Curve
var dh dhKEM
var Nsk uint16
switch id {
case dhkemP256:
curve = ecdh.P256()
dh, _ = dhKEMForCurve(curve)
Nsk = 32
case dhkemP384:
curve = ecdh.P384()
dh, _ = dhKEMForCurve(curve)
Nsk = 48
case dhkemP521:
curve = ecdh.P521()
dh, _ = dhKEMForCurve(curve)
Nsk = 66
case dhkemX25519:
curve = ecdh.X25519()
dh, _ = dhKEMForCurve(curve)
Nsk = 32
// Do not implement the PQ KEMs for now, as the seed input is not
// stable yet. See https://github.com/hpkewg/hpke-pq/issues/30.
default:
return nil, errors.New("unsupported KEM")
}
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), dh.id)
prk, err := dh.kdf.labeledExtract(suiteID, nil, "dkp_prk", seed)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if id == dhkemX25519 { return kem.kdf.labeledExpand(suiteID, eaePRK, "shared_secret", kemContext, kem.Nsecret)
s, err := dh.kdf.labeledExpand(suiteID, prk, "sk", nil, Nsk) }
func (kem *dhKEM) ID() uint16 {
return kem.id
}
func (kem *dhKEM) encSize() int {
return kem.Nenc
}
var dhKEMP256 = &dhKEM{HKDFSHA256(), 0x0010, ecdh.P256(), 32, 32, 65}
var dhKEMP384 = &dhKEM{HKDFSHA384(), 0x0011, ecdh.P384(), 48, 48, 97}
var dhKEMP521 = &dhKEM{HKDFSHA512(), 0x0012, ecdh.P521(), 64, 66, 133}
var dhKEMX25519 = &dhKEM{HKDFSHA256(), 0x0020, ecdh.X25519(), 32, 32, 32}
// DHKEM returns a KEM implementing one of
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on curve.
func DHKEM(curve ecdh.Curve) KEM {
switch curve {
case ecdh.P256():
return dhKEMP256
case ecdh.P384():
return dhKEMP384
case ecdh.P521():
return dhKEMP521
case ecdh.X25519():
return dhKEMX25519
default:
// The set of ecdh.Curve implementations is closed, because the
// interface has unexported methods. Therefore, this default case is
// only hit if a new curve is added that DHKEM doesn't support.
return unsupportedCurveKEM{}
}
}
type unsupportedCurveKEM struct{}
func (unsupportedCurveKEM) ID() uint16 {
return 0
}
func (unsupportedCurveKEM) GenerateKey() (PrivateKey, error) {
return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) NewPublicKey([]byte) (PublicKey, error) {
return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) NewPrivateKey([]byte) (PrivateKey, error) {
return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) DeriveKeyPair([]byte) (PrivateKey, error) {
return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) encSize() int {
return 0
}
type dhKEMPublicKey struct {
kem *dhKEM
pub *ecdh.PublicKey
}
// NewDHKEMPublicKey returns a PublicKey implementing
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on the underlying curve of pub ([ecdh.X25519], [ecdh.P256],
// [ecdh.P384], or [ecdh.P521]).
//
// This function is meant for applications that already have an instantiated
// crypto/ecdh public key. Otherwise, applications should use the
// [KEM.NewPublicKey] method of [DHKEM].
func NewDHKEMPublicKey(pub *ecdh.PublicKey) (PublicKey, error) {
kem, ok := DHKEM(pub.Curve()).(*dhKEM)
if !ok {
return nil, errors.New("unsupported curve")
}
return &dhKEMPublicKey{
kem: kem,
pub: pub,
}, nil
}
func (kem *dhKEM) NewPublicKey(data []byte) (PublicKey, error) {
pub, err := kem.curve.NewPublicKey(data)
if err != nil {
return nil, err
}
return NewDHKEMPublicKey(pub)
}
func (pk *dhKEMPublicKey) KEM() KEM {
return pk.kem
}
func (pk *dhKEMPublicKey) Bytes() []byte {
return pk.pub.Bytes()
}
// testingOnlyGenerateKey is only used during testing, to provide
// a fixed test key to use when checking the RFC 9180 vectors.
var testingOnlyGenerateKey func() *ecdh.PrivateKey
func (pk *dhKEMPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
privEph, err := pk.pub.Curve().GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
if testingOnlyGenerateKey != nil {
privEph = testingOnlyGenerateKey()
}
dhVal, err := privEph.ECDH(pk.pub)
if err != nil {
return nil, nil, err
}
encPubEph := privEph.PublicKey().Bytes()
encPubRecip := pk.pub.Bytes()
kemContext := append(encPubEph, encPubRecip...)
sharedSecret, err = pk.kem.extractAndExpand(dhVal, kemContext)
if err != nil {
return nil, nil, err
}
return sharedSecret, encPubEph, nil
}
type dhKEMPrivateKey struct {
kem *dhKEM
priv ecdh.KeyExchanger
}
// NewDHKEMPrivateKey returns a PrivateKey implementing
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on the underlying curve of priv ([ecdh.X25519], [ecdh.P256],
// [ecdh.P384], or [ecdh.P521]).
//
// This function is meant for applications that already have an instantiated
// crypto/ecdh private key, or another implementation of a [ecdh.KeyExchanger]
// (e.g. a hardware key). Otherwise, applications should use the
// [KEM.NewPrivateKey] method of [DHKEM].
func NewDHKEMPrivateKey(priv ecdh.KeyExchanger) (PrivateKey, error) {
kem, ok := DHKEM(priv.Curve()).(*dhKEM)
if !ok {
return nil, errors.New("unsupported curve")
}
return &dhKEMPrivateKey{
kem: kem,
priv: priv,
}, nil
}
func (kem *dhKEM) GenerateKey() (PrivateKey, error) {
priv, err := kem.curve.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
return NewDHKEMPrivateKey(priv)
}
func (kem *dhKEM) NewPrivateKey(ikm []byte) (PrivateKey, error) {
priv, err := kem.curve.NewPrivateKey(ikm)
if err != nil {
return nil, err
}
return NewDHKEMPrivateKey(priv)
}
func (kem *dhKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
// DeriveKeyPair from RFC 9180 Section 7.1.3.
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), kem.id)
prk, err := kem.kdf.labeledExtract(suiteID, nil, "dkp_prk", ikm)
if err != nil {
return nil, err
}
if kem == dhKEMX25519 {
s, err := kem.kdf.labeledExpand(suiteID, prk, "sk", nil, kem.Nsk)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewKEMRecipient(id, s) return kem.NewPrivateKey(s)
} }
var counter uint8 var counter uint8
for counter < 4 { for counter < 4 {
s, err := dh.kdf.labeledExpand(suiteID, prk, "candidate", []byte{counter}, Nsk) s, err := kem.kdf.labeledExpand(suiteID, prk, "candidate", []byte{counter}, kem.Nsk)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if id == dhkemP521 { if kem == dhKEMP521 {
s[0] &= 0x01 s[0] &= 0x01
} }
r, err := NewKEMRecipient(id, s) r, err := kem.NewPrivateKey(s)
if err != nil { if err != nil {
counter++ counter++
continue continue
@ -257,136 +333,11 @@ func NewKEMRecipientFromSeed(id uint16, seed []byte) (KEMRecipient, error) {
panic("chance of four rejections is < 2^-128") panic("chance of four rejections is < 2^-128")
} }
type dhKEM struct { func (k *dhKEMPrivateKey) KEM() KEM {
kdf KDF return k.kem
id uint16
nSecret uint16
} }
func (dh *dhKEM) extractAndExpand(dhKey, kemContext []byte) ([]byte, error) { func (k *dhKEMPrivateKey) Bytes() ([]byte, error) {
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), dh.id)
eaePRK, err := dh.kdf.labeledExtract(suiteID, nil, "eae_prk", dhKey)
if err != nil {
return nil, err
}
return dh.kdf.labeledExpand(suiteID, eaePRK, "shared_secret", kemContext, dh.nSecret)
}
func (dh *dhKEM) ID() uint16 {
return dh.id
}
type dhKEMSender struct {
dhKEM
pub *ecdh.PublicKey
}
// NewECDHSender returns a KEMSender implementing one of
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on the underlying curve of the provided public key.
func NewECDHSender(pub *ecdh.PublicKey) (KEMSender, error) {
dhKEM, err := dhKEMForCurve(pub.Curve())
if err != nil {
return nil, err
}
return &dhKEMSender{
pub: pub,
dhKEM: dhKEM,
}, nil
}
func dhKEMForCurve(curve ecdh.Curve) (dhKEM, error) {
switch curve {
case ecdh.P256():
return dhKEM{
kdf: HKDFSHA256(),
id: dhkemP256,
nSecret: 32,
}, nil
case ecdh.P384():
return dhKEM{
kdf: HKDFSHA384(),
id: dhkemP384,
nSecret: 48,
}, nil
case ecdh.P521():
return dhKEM{
kdf: HKDFSHA512(),
id: dhkemP521,
nSecret: 64,
}, nil
case ecdh.X25519():
return dhKEM{
kdf: HKDFSHA256(),
id: dhkemX25519,
nSecret: 32,
}, nil
default:
return dhKEM{}, errors.New("unsupported curve")
}
}
func (dh *dhKEMSender) Bytes() []byte {
return dh.pub.Bytes()
}
// testingOnlyGenerateKey is only used during testing, to provide
// a fixed test key to use when checking the RFC 9180 vectors.
var testingOnlyGenerateKey func() *ecdh.PrivateKey
func (dh *dhKEMSender) encap() (sharedSecret []byte, encapPub []byte, err error) {
privEph, err := dh.pub.Curve().GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
if testingOnlyGenerateKey != nil {
privEph = testingOnlyGenerateKey()
}
dhVal, err := privEph.ECDH(dh.pub)
if err != nil {
return nil, nil, err
}
encPubEph := privEph.PublicKey().Bytes()
encPubRecip := dh.pub.Bytes()
kemContext := append(encPubEph, encPubRecip...)
sharedSecret, err = dh.extractAndExpand(dhVal, kemContext)
if err != nil {
return nil, nil, err
}
return sharedSecret, encPubEph, nil
}
type dhKEMRecipient struct {
dhKEM
priv ecdh.KeyExchanger
}
// NewECDHRecipient returns a KEMRecipient implementing one of
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on the underlying curve of the provided private key.
func NewECDHRecipient(priv ecdh.KeyExchanger) (KEMRecipient, error) {
dhKEM, err := dhKEMForCurve(priv.Curve())
if err != nil {
return nil, err
}
return &dhKEMRecipient{
priv: priv,
dhKEM: dhKEM,
}, nil
}
func (dh *dhKEMRecipient) Bytes() ([]byte, error) {
// Bizarrely, RFC 9180, Section 7.1.2 says SerializePrivateKey MUST clamp // Bizarrely, RFC 9180, Section 7.1.2 says SerializePrivateKey MUST clamp
// the output, which I thought we all agreed to instead do as part of the DH // the output, which I thought we all agreed to instead do as part of the DH
// function, letting private keys be random bytes. // function, letting private keys be random bytes.
@ -396,11 +347,11 @@ func (dh *dhKEMRecipient) Bytes() ([]byte, error) {
// necessarily match the NewPrivateKey input. // necessarily match the NewPrivateKey input.
// //
// I'm sure this will not lead to any unexpected behavior or interop issue. // I'm sure this will not lead to any unexpected behavior or interop issue.
priv, ok := dh.priv.(*ecdh.PrivateKey) priv, ok := k.priv.(*ecdh.PrivateKey)
if !ok { if !ok {
return nil, errors.New("ecdh: private key does not support Bytes") return nil, errors.New("ecdh: private key does not support Bytes")
} }
if dh.id == dhkemX25519 { if k.kem == dhKEMX25519 {
b := priv.Bytes() b := priv.Bytes()
b[0] &= 248 b[0] &= 248
b[31] &= 127 b[31] &= 127
@ -410,22 +361,22 @@ func (dh *dhKEMRecipient) Bytes() ([]byte, error) {
return priv.Bytes(), nil return priv.Bytes(), nil
} }
func (dh *dhKEMRecipient) KEMSender() KEMSender { func (k *dhKEMPrivateKey) PublicKey() PublicKey {
return &dhKEMSender{ return &dhKEMPublicKey{
pub: dh.priv.PublicKey(), kem: k.kem,
dhKEM: dh.dhKEM, pub: k.priv.PublicKey(),
} }
} }
func (dh *dhKEMRecipient) decap(encPubEph []byte) ([]byte, error) { func (k *dhKEMPrivateKey) decap(encPubEph []byte) ([]byte, error) {
pubEph, err := dh.priv.Curve().NewPublicKey(encPubEph) pubEph, err := k.priv.Curve().NewPublicKey(encPubEph)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dhVal, err := dh.priv.ECDH(pubEph) dhVal, err := k.priv.ECDH(pubEph)
if err != nil { if err != nil {
return nil, err return nil, err
} }
kemContext := append(encPubEph, dh.priv.PublicKey().Bytes()...) kemContext := append(encPubEph, k.priv.PublicKey().Bytes()...)
return dh.extractAndExpand(dhVal, kemContext) return k.kem.extractAndExpand(dhVal, kemContext)
} }

View file

@ -11,127 +11,221 @@ import (
"crypto/mlkem" "crypto/mlkem"
"crypto/rand" "crypto/rand"
"crypto/sha3" "crypto/sha3"
"encoding/binary"
"errors" "errors"
) )
const ( var mlkem768X25519 = &hybridKEM{
mlkem768 = 0x0041 // ML-KEM-768 id: 0x647a,
mlkem1024 = 0x0042 // ML-KEM-1024
mlkem768X25519 = 0x647a // MLKEM768-X25519
mlkem768P256 = 0x0050 // MLKEM768-P256
mlkem1024P384 = 0x0051 // MLKEM1024-P384
)
var mlkem768X25519Hybrid = hybrid{
id: mlkem768X25519,
label: /**/ `\./` + label: /**/ `\./` +
/* */ `/^\`, /* */ `/^\`,
curve: ecdh.X25519(),
curveSeedSize: 32,
curvePointSize: 32,
pqEncapsKeySize: mlkem.EncapsulationKeySize768,
pqCiphertextSize: mlkem.CiphertextSize768,
pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey768(data)
},
pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey768(data)
},
pqGenerateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey768()
},
} }
var mlkem768P256Hybrid = hybrid{ // MLKEM768X25519 returns a KEM implementing MLKEM768-X25519 (a.k.a. X-Wing)
id: mlkem768P256, // from draft-ietf-hpke-pq.
func MLKEM768X25519() KEM {
return mlkem768X25519
}
var mlkem768P256 = &hybridKEM{
id: 0x0050,
label: "MLKEM768-P256", label: "MLKEM768-P256",
curve: ecdh.P256(),
curveSeedSize: 32,
curvePointSize: 65,
pqEncapsKeySize: mlkem.EncapsulationKeySize768,
pqCiphertextSize: mlkem.CiphertextSize768,
pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey768(data)
},
pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey768(data)
},
pqGenerateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey768()
},
} }
var mlkem1024P384Hybrid = hybrid{ // MLKEM768P256 returns a KEM implementing MLKEM768-P256 from draft-ietf-hpke-pq.
id: mlkem1024P384, func MLKEM768P256() KEM {
return mlkem768P256
}
var mlkem1024P384 = &hybridKEM{
id: 0x0051,
label: "MLKEM1024-P384", label: "MLKEM1024-P384",
curve: ecdh.P384(),
curveSeedSize: 48,
curvePointSize: 97,
pqEncapsKeySize: mlkem.EncapsulationKeySize1024,
pqCiphertextSize: mlkem.CiphertextSize1024,
pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey1024(data)
},
pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey1024(data)
},
pqGenerateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey1024()
},
} }
type hybrid struct { // MLKEM1024P384 returns a KEM implementing MLKEM1024-P384 from draft-ietf-hpke-pq.
func MLKEM1024P384() KEM {
return mlkem1024P384
}
type hybridKEM struct {
id uint16 id uint16
label string label string
curve ecdh.Curve
curveSeedSize int
curvePointSize int
pqEncapsKeySize int
pqCiphertextSize int
pqNewPublicKey func(data []byte) (crypto.Encapsulator, error)
pqNewPrivateKey func(data []byte) (crypto.Decapsulator, error)
pqGenerateKey func() (crypto.Decapsulator, error)
} }
func (x *hybrid) ID() uint16 { func (kem *hybridKEM) ID() uint16 {
return x.id return kem.id
} }
func (x *hybrid) sharedSecret(ssPQ, ssT, ctT, ekT []byte) []byte { func (kem *hybridKEM) encSize() int {
return kem.pqCiphertextSize + kem.curvePointSize
}
func (kem *hybridKEM) sharedSecret(ssPQ, ssT, ctT, ekT []byte) []byte {
h := sha3.New256() h := sha3.New256()
h.Write(ssPQ) h.Write(ssPQ)
h.Write(ssT) h.Write(ssT)
h.Write(ctT) h.Write(ctT)
h.Write(ekT) h.Write(ekT)
h.Write([]byte(x.label)) h.Write([]byte(kem.label))
return h.Sum(nil) return h.Sum(nil)
} }
type hybridSender struct { type hybridPublicKey struct {
hybrid kem *hybridKEM
t *ecdh.PublicKey t *ecdh.PublicKey
pq crypto.Encapsulator pq crypto.Encapsulator
} }
// NewHybridSender returns a KEMSender implementing one of // NewHybridPublicKey returns a PublicKey implementing one of
// //
// - MLKEM768-X25519 (a.k.a. X-Wing) // - MLKEM768-X25519 (a.k.a. X-Wing)
// - MLKEM768-P256 // - MLKEM768-P256
// - MLKEM1024-P384 // - MLKEM1024-P384
// //
// from draft-ietf-hpke-pq, depending on the underlying curve of t // from draft-ietf-hpke-pq, depending on the underlying curve of t
// ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq // ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq (either
// (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]). // *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
func NewHybridSender(t *ecdh.PublicKey, pq crypto.Encapsulator) (KEMSender, error) { //
// This function is meant for applications that already have instantiated
// crypto/ecdh and crypto/mlkem public keys. Otherwise, applications should use
// the [KEM.NewPublicKey] method of e.g. [MLKEM768X25519].
func NewHybridPublicKey(pq crypto.Encapsulator, t *ecdh.PublicKey) (PublicKey, error) {
switch t.Curve() { switch t.Curve() {
case ecdh.X25519(): case ecdh.X25519():
if _, ok := pq.(*mlkem.EncapsulationKey768); !ok { if _, ok := pq.(*mlkem.EncapsulationKey768); !ok {
return nil, errors.New("invalid PQ KEM for X25519 hybrid") return nil, errors.New("invalid PQ KEM for X25519 hybrid")
} }
return &hybridSender{mlkem768X25519Hybrid, t, pq}, nil return &hybridPublicKey{mlkem768X25519, t, pq}, nil
case ecdh.P256(): case ecdh.P256():
if _, ok := pq.(*mlkem.EncapsulationKey768); !ok { if _, ok := pq.(*mlkem.EncapsulationKey768); !ok {
return nil, errors.New("invalid PQ KEM for P-256 hybrid") return nil, errors.New("invalid PQ KEM for P-256 hybrid")
} }
return &hybridSender{mlkem768P256Hybrid, t, pq}, nil return &hybridPublicKey{mlkem768P256, t, pq}, nil
case ecdh.P384(): case ecdh.P384():
if _, ok := pq.(*mlkem.EncapsulationKey1024); !ok { if _, ok := pq.(*mlkem.EncapsulationKey1024); !ok {
return nil, errors.New("invalid PQ KEM for P-384 hybrid") return nil, errors.New("invalid PQ KEM for P-384 hybrid")
} }
return &hybridSender{mlkem1024P384Hybrid, t, pq}, nil return &hybridPublicKey{mlkem1024P384, t, pq}, nil
default: default:
return nil, errors.New("unsupported curve") return nil, errors.New("unsupported curve")
} }
} }
func (s *hybridSender) Bytes() []byte { func (kem *hybridKEM) NewPublicKey(data []byte) (PublicKey, error) {
return append(s.pq.Bytes(), s.t.Bytes()...) if len(data) != kem.pqEncapsKeySize+kem.curvePointSize {
return nil, errors.New("invalid public key size")
}
pq, err := kem.pqNewPublicKey(data[:kem.pqEncapsKeySize])
if err != nil {
return nil, err
}
k, err := kem.curve.NewPublicKey(data[kem.pqEncapsKeySize:])
if err != nil {
return nil, err
}
return NewHybridPublicKey(pq, k)
}
func (pk *hybridPublicKey) KEM() KEM {
return pk.kem
}
func (pk *hybridPublicKey) Bytes() []byte {
return append(pk.pq.Bytes(), pk.t.Bytes()...)
} }
var testingOnlyEncapsulate func() (ss, ct []byte) var testingOnlyEncapsulate func() (ss, ct []byte)
func (s *hybridSender) encap() (sharedSecret []byte, encapPub []byte, err error) { func (pk *hybridPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
skE, err := s.t.Curve().GenerateKey(rand.Reader) skE, err := pk.t.Curve().GenerateKey(rand.Reader)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if testingOnlyGenerateKey != nil { if testingOnlyGenerateKey != nil {
skE = testingOnlyGenerateKey() skE = testingOnlyGenerateKey()
} }
ssT, err := skE.ECDH(s.t) ssT, err := skE.ECDH(pk.t)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
ctT := skE.PublicKey().Bytes() ctT := skE.PublicKey().Bytes()
ssPQ, ctPQ := s.pq.Encapsulate() ssPQ, ctPQ := pk.pq.Encapsulate()
if testingOnlyEncapsulate != nil { if testingOnlyEncapsulate != nil {
ssPQ, ctPQ = testingOnlyEncapsulate() ssPQ, ctPQ = testingOnlyEncapsulate()
} }
ss := s.sharedSecret(ssPQ, ssT, ctT, s.t.Bytes()) ss := pk.kem.sharedSecret(ssPQ, ssT, ctT, pk.t.Bytes())
ct := append(ctPQ, ctT...) ct := append(ctPQ, ctT...)
return ss, ct, nil return ss, ct, nil
} }
type hybridRecipient struct { type hybridPrivateKey struct {
hybrid kem *hybridKEM
seed []byte // can be nil seed []byte // can be nil
t ecdh.KeyExchanger t ecdh.KeyExchanger
pq crypto.Decapsulator pq crypto.Decapsulator
} }
// NewHybridRecipient returns a KEMRecipient implementing // NewHybridPrivateKey returns a PrivateKey implementing
// //
// - MLKEM768-X25519 (a.k.a. X-Wing) // - MLKEM768-X25519 (a.k.a. X-Wing)
// - MLKEM768-P256 // - MLKEM768-P256
@ -140,11 +234,23 @@ type hybridRecipient struct {
// from draft-ietf-hpke-pq, depending on the underlying curve of t // from draft-ietf-hpke-pq, depending on the underlying curve of t
// ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq.Encapsulator() // ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq.Encapsulator()
// (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]). // (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
func NewHybridRecipient(t ecdh.KeyExchanger, pq crypto.Decapsulator) (KEMRecipient, error) { //
return newHybridRecipient(t, pq, nil) // This function is meant for applications that already have instantiated
// crypto/ecdh and crypto/mlkem private keys, or another implementation of a
// [ecdh.KeyExchanger] and [crypto.Decapsulator] (e.g. a hardware key).
// Otherwise, applications should use the [KEM.NewPrivateKey] method of e.g.
// [MLKEM768X25519].
func NewHybridPrivateKey(pq crypto.Decapsulator, t ecdh.KeyExchanger) (PrivateKey, error) {
return newHybridPrivateKey(pq, t, nil)
} }
func newHybridRecipientFromSeed(id uint16, priv []byte) (KEMRecipient, error) { func (kem *hybridKEM) GenerateKey() (PrivateKey, error) {
seed := make([]byte, 32)
rand.Read(seed)
return kem.NewPrivateKey(seed)
}
func (kem *hybridKEM) NewPrivateKey(priv []byte) (PrivateKey, error) {
if len(priv) != 32 { if len(priv) != 32 {
return nil, errors.New("hpke: invalid hybrid KEM secret length") return nil, errors.New("hpke: invalid hybrid KEM secret length")
} }
@ -154,200 +260,256 @@ func newHybridRecipientFromSeed(id uint16, priv []byte) (KEMRecipient, error) {
seedPQ := make([]byte, mlkem.SeedSize) seedPQ := make([]byte, mlkem.SeedSize)
s.Read(seedPQ) s.Read(seedPQ)
pq, err := kem.pqNewPrivateKey(seedPQ)
var pq crypto.Decapsulator if err != nil {
switch id { return nil, err
case mlkem768X25519, mlkem768P256:
sk, err := mlkem.NewDecapsulationKey768(seedPQ)
if err != nil {
return nil, err
}
pq = sk
case mlkem1024P384:
sk, err := mlkem.NewDecapsulationKey1024(seedPQ)
if err != nil {
return nil, err
}
pq = sk
default:
return nil, errors.New("hpke: invalid hybrid KEM ID")
}
var seedT []byte
var curve ecdh.Curve
switch id {
case mlkem768X25519:
seedT = make([]byte, 32)
curve = ecdh.X25519()
case mlkem768P256:
seedT = make([]byte, 32)
curve = ecdh.P256()
case mlkem1024P384:
seedT = make([]byte, 48)
curve = ecdh.P384()
default:
return nil, errors.New("hpke: invalid hybrid KEM ID")
} }
seedT := make([]byte, kem.curveSeedSize)
for { for {
s.Read(seedT) s.Read(seedT)
k, err := curve.NewPrivateKey(seedT) k, err := kem.curve.NewPrivateKey(seedT)
if err != nil { if err != nil {
continue continue
} }
return newHybridRecipient(k, pq, priv) return newHybridPrivateKey(pq, k, priv)
} }
} }
func newHybridRecipient(t ecdh.KeyExchanger, pq crypto.Decapsulator, seed []byte) (KEMRecipient, error) { func newHybridPrivateKey(pq crypto.Decapsulator, t ecdh.KeyExchanger, seed []byte) (PrivateKey, error) {
switch t.Curve() { switch t.Curve() {
case ecdh.X25519(): case ecdh.X25519():
if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok { if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok {
return nil, errors.New("invalid PQ KEM for X25519 hybrid") return nil, errors.New("invalid PQ KEM for X25519 hybrid")
} }
return &hybridRecipient{mlkem768X25519Hybrid, bytes.Clone(seed), t, pq}, nil return &hybridPrivateKey{mlkem768X25519, bytes.Clone(seed), t, pq}, nil
case ecdh.P256(): case ecdh.P256():
if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok { if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok {
return nil, errors.New("invalid PQ KEM for P-256 hybrid") return nil, errors.New("invalid PQ KEM for P-256 hybrid")
} }
return &hybridRecipient{mlkem768P256Hybrid, bytes.Clone(seed), t, pq}, nil return &hybridPrivateKey{mlkem768P256, bytes.Clone(seed), t, pq}, nil
case ecdh.P384(): case ecdh.P384():
if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey1024); !ok { if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey1024); !ok {
return nil, errors.New("invalid PQ KEM for P-384 hybrid") return nil, errors.New("invalid PQ KEM for P-384 hybrid")
} }
return &hybridRecipient{mlkem1024P384Hybrid, bytes.Clone(seed), t, pq}, nil return &hybridPrivateKey{mlkem1024P384, bytes.Clone(seed), t, pq}, nil
default: default:
return nil, errors.New("unsupported curve") return nil, errors.New("unsupported curve")
} }
} }
func (r *hybridRecipient) Bytes() ([]byte, error) { func (kem *hybridKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
if r.seed == nil { suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), kem.id)
dk, err := SHAKE256().labeledDerive(suiteID, ikm, "DeriveKeyPair", nil, 32)
if err != nil {
return nil, err
}
return kem.NewPrivateKey(dk)
}
func (k *hybridPrivateKey) KEM() KEM {
return k.kem
}
func (k *hybridPrivateKey) Bytes() ([]byte, error) {
if k.seed == nil {
return nil, errors.New("private key seed not available") return nil, errors.New("private key seed not available")
} }
return r.seed, nil return k.seed, nil
} }
func (r *hybridRecipient) KEMSender() KEMSender { func (k *hybridPrivateKey) PublicKey() PublicKey {
return &hybridSender{ return &hybridPublicKey{
hybrid: r.hybrid, kem: k.kem,
t: r.t.PublicKey(), t: k.t.PublicKey(),
pq: r.pq.Encapsulator(), pq: k.pq.Encapsulator(),
} }
} }
func (r *hybridRecipient) decap(enc []byte) ([]byte, error) { func (k *hybridPrivateKey) decap(enc []byte) ([]byte, error) {
var ctPQ, ctT []byte if len(enc) != k.kem.pqCiphertextSize+k.kem.curvePointSize {
switch r.id { return nil, errors.New("invalid encapsulated key size")
case mlkem768X25519:
if len(enc) != mlkem.CiphertextSize768+32 {
return nil, errors.New("invalid encapsulated key size")
}
ctPQ, ctT = enc[:mlkem.CiphertextSize768], enc[mlkem.CiphertextSize768:]
case mlkem768P256:
if len(enc) != mlkem.CiphertextSize768+65 {
return nil, errors.New("invalid encapsulated key size")
}
ctPQ, ctT = enc[:mlkem.CiphertextSize768], enc[mlkem.CiphertextSize768:]
case mlkem1024P384:
if len(enc) != mlkem.CiphertextSize1024+97 {
return nil, errors.New("invalid encapsulated key size")
}
ctPQ, ctT = enc[:mlkem.CiphertextSize1024], enc[mlkem.CiphertextSize1024:]
default:
return nil, errors.New("internal error: unsupported KEM")
} }
ssPQ, err := r.pq.Decapsulate(ctPQ) ctPQ, ctT := enc[:k.kem.pqCiphertextSize], enc[k.kem.pqCiphertextSize:]
ssPQ, err := k.pq.Decapsulate(ctPQ)
if err != nil { if err != nil {
return nil, err return nil, err
} }
pub, err := r.t.Curve().NewPublicKey(ctT) pub, err := k.t.Curve().NewPublicKey(ctT)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ssT, err := r.t.ECDH(pub) ssT, err := k.t.ECDH(pub)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ss := r.sharedSecret(ssPQ, ssT, ctT, r.t.PublicKey().Bytes()) ss := k.kem.sharedSecret(ssPQ, ssT, ctT, k.t.PublicKey().Bytes())
return ss, nil return ss, nil
} }
type mlkemSender struct { var mlkem768 = &mlkemKEM{
id uint16 id: 0x0041,
pq interface { ciphertextSize: mlkem.CiphertextSize768,
Bytes() []byte newPublicKey: func(data []byte) (crypto.Encapsulator, error) {
Encapsulate() (sharedKey []byte, ciphertext []byte) return mlkem.NewEncapsulationKey768(data)
} },
newPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey768(data)
},
generateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey768()
},
} }
// NewMLKEMSender returns a KEMSender implementing ML-KEM-768 or ML-KEM-1024 from // MLKEM768 returns a KEM implementing ML-KEM-768 from draft-ietf-hpke-pq.
// draft-ietf-hpke-pq. pub must be either a *[mlkem.EncapsulationKey768] or a func MLKEM768() KEM {
// *[mlkem.EncapsulationKey1024]. return mlkem768
func NewMLKEMSender(pub crypto.Encapsulator) (KEMSender, error) { }
var mlkem1024 = &mlkemKEM{
id: 0x0042,
ciphertextSize: mlkem.CiphertextSize1024,
newPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey1024(data)
},
newPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey1024(data)
},
generateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey1024()
},
}
// MLKEM1024 returns a KEM implementing ML-KEM-1024 from draft-ietf-hpke-pq.
func MLKEM1024() KEM {
return mlkem1024
}
type mlkemKEM struct {
id uint16
ciphertextSize int
newPublicKey func(data []byte) (crypto.Encapsulator, error)
newPrivateKey func(data []byte) (crypto.Decapsulator, error)
generateKey func() (crypto.Decapsulator, error)
}
func (kem *mlkemKEM) ID() uint16 {
return kem.id
}
func (kem *mlkemKEM) encSize() int {
return kem.ciphertextSize
}
type mlkemPublicKey struct {
kem *mlkemKEM
pq crypto.Encapsulator
}
// NewMLKEMPublicKey returns a KEMPublicKey implementing
//
// - ML-KEM-768
// - ML-KEM-1024
//
// from draft-ietf-hpke-pq, depending on the type of pub
// (*[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
//
// This function is meant for applications that already have an instantiated
// crypto/mlkem public key. Otherwise, applications should use the
// [KEM.NewPublicKey] method of e.g. [MLKEM768].
func NewMLKEMPublicKey(pub crypto.Encapsulator) (PublicKey, error) {
switch pub.(type) { switch pub.(type) {
case *mlkem.EncapsulationKey768: case *mlkem.EncapsulationKey768:
return &mlkemSender{ return &mlkemPublicKey{mlkem768, pub}, nil
id: mlkem768,
pq: pub,
}, nil
case *mlkem.EncapsulationKey1024: case *mlkem.EncapsulationKey1024:
return &mlkemSender{ return &mlkemPublicKey{mlkem1024, pub}, nil
id: mlkem1024,
pq: pub,
}, nil
default: default:
return nil, errors.New("unsupported public key type") return nil, errors.New("unsupported public key type")
} }
} }
func (s *mlkemSender) ID() uint16 { func (kem *mlkemKEM) NewPublicKey(data []byte) (PublicKey, error) {
return s.id pq, err := kem.newPublicKey(data)
if err != nil {
return nil, err
}
return NewMLKEMPublicKey(pq)
} }
func (s *mlkemSender) Bytes() []byte { func (pk *mlkemPublicKey) KEM() KEM {
return s.pq.Bytes() return pk.kem
} }
func (s *mlkemSender) encap() (sharedSecret []byte, encapPub []byte, err error) { func (pk *mlkemPublicKey) Bytes() []byte {
ss, ct := s.pq.Encapsulate() return pk.pq.Bytes()
}
func (pk *mlkemPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
ss, ct := pk.pq.Encapsulate()
if testingOnlyEncapsulate != nil { if testingOnlyEncapsulate != nil {
ss, ct = testingOnlyEncapsulate() ss, ct = testingOnlyEncapsulate()
} }
return ss, ct, nil return ss, ct, nil
} }
type mlkemRecipient struct { type mlkemPrivateKey struct {
id uint16 kem *mlkemKEM
pq crypto.Decapsulator pq crypto.Decapsulator
} }
// NewMLKEMRecipient returns a KEMRecipient implementing ML-KEM-768 or ML-KEM-1024 // NewMLKEMPrivateKey returns a KEMPrivateKey implementing
// from draft-ietf-hpke-pq. priv.Encapsulator() must return either a //
// *[mlkem.EncapsulationKey768] or a *[mlkem.EncapsulationKey1024]. // - ML-KEM-768
func NewMLKEMRecipient(priv crypto.Decapsulator) (KEMRecipient, error) { // - ML-KEM-1024
//
// from draft-ietf-hpke-pq, depending on the type of priv.Encapsulator()
// (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
//
// This function is meant for applications that already have an instantiated
// crypto/mlkem private key. Otherwise, applications should use the
// [KEM.NewPrivateKey] method of e.g. [MLKEM768].
func NewMLKEMPrivateKey(priv crypto.Decapsulator) (PrivateKey, error) {
switch priv.Encapsulator().(type) { switch priv.Encapsulator().(type) {
case *mlkem.EncapsulationKey768: case *mlkem.EncapsulationKey768:
return &mlkemRecipient{ return &mlkemPrivateKey{mlkem768, priv}, nil
id: mlkem768,
pq: priv,
}, nil
case *mlkem.EncapsulationKey1024: case *mlkem.EncapsulationKey1024:
return &mlkemRecipient{ return &mlkemPrivateKey{mlkem1024, priv}, nil
id: mlkem1024,
pq: priv,
}, nil
default: default:
return nil, errors.New("unsupported public key type") return nil, errors.New("unsupported public key type")
} }
} }
func (r *mlkemRecipient) ID() uint16 { func (kem *mlkemKEM) GenerateKey() (PrivateKey, error) {
return r.id pq, err := kem.generateKey()
if err != nil {
return nil, err
}
return NewMLKEMPrivateKey(pq)
} }
func (r *mlkemRecipient) Bytes() ([]byte, error) { func (kem *mlkemKEM) NewPrivateKey(priv []byte) (PrivateKey, error) {
pq, ok := r.pq.(interface { pq, err := kem.newPrivateKey(priv)
if err != nil {
return nil, err
}
return NewMLKEMPrivateKey(pq)
}
func (kem *mlkemKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), kem.id)
dk, err := SHAKE256().labeledDerive(suiteID, ikm, "DeriveKeyPair", nil, 64)
if err != nil {
return nil, err
}
return kem.NewPrivateKey(dk)
}
func (k *mlkemPrivateKey) KEM() KEM {
return k.kem
}
func (k *mlkemPrivateKey) Bytes() ([]byte, error) {
pq, ok := k.pq.(interface {
Bytes() []byte Bytes() []byte
}) })
if !ok { if !ok {
@ -356,26 +518,13 @@ func (r *mlkemRecipient) Bytes() ([]byte, error) {
return pq.Bytes(), nil return pq.Bytes(), nil
} }
func (r *mlkemRecipient) KEMSender() KEMSender { func (k *mlkemPrivateKey) PublicKey() PublicKey {
s := &mlkemSender{ return &mlkemPublicKey{
id: r.id, kem: k.kem,
pq: r.pq.Encapsulator(), pq: k.pq.Encapsulator(),
} }
return s
} }
func (r *mlkemRecipient) decap(enc []byte) ([]byte, error) { func (k *mlkemPrivateKey) decap(enc []byte) ([]byte, error) {
switch r.id { return k.pq.Decapsulate(enc)
case mlkem768:
if len(enc) != mlkem.CiphertextSize768 {
return nil, errors.New("invalid encapsulated key size")
}
case mlkem1024:
if len(enc) != mlkem.CiphertextSize1024 {
return nil, errors.New("invalid encapsulated key size")
}
default:
return nil, errors.New("internal error: unsupported KEM")
}
return r.pq.Decapsulate(enc)
} }

File diff suppressed because it is too large Load diff

View file

@ -149,7 +149,7 @@ func parseECHConfigList(data []byte) ([]echConfig, error) {
return configs, nil return configs, nil
} }
func pickECHConfig(list []echConfig) (*echConfig, hpke.KEMSender, hpke.KDF, hpke.AEAD) { func pickECHConfig(list []echConfig) (*echConfig, hpke.PublicKey, hpke.KDF, hpke.AEAD) {
for _, ec := range list { for _, ec := range list {
if !validDNSName(string(ec.PublicName)) { if !validDNSName(string(ec.PublicName)) {
continue continue
@ -166,10 +166,16 @@ func pickECHConfig(list []echConfig) (*echConfig, hpke.KEMSender, hpke.KDF, hpke
if unsupportedExt { if unsupportedExt {
continue continue
} }
s, err := hpke.NewKEMSender(ec.KemID, ec.PublicKey) kem, err := hpke.NewKEM(ec.KemID)
if err != nil { if err != nil {
continue continue
} }
pub, err := kem.NewPublicKey(ec.PublicKey)
if err != nil {
// This is an error in the config, but killing the connection feels
// excessive.
continue
}
for _, cs := range ec.SymmetricCipherSuite { for _, cs := range ec.SymmetricCipherSuite {
// All of the supported AEADs and KDFs are fine, rather than // All of the supported AEADs and KDFs are fine, rather than
// imposing some sort of preference here, we just pick the first // imposing some sort of preference here, we just pick the first
@ -182,7 +188,7 @@ func pickECHConfig(list []echConfig) (*echConfig, hpke.KEMSender, hpke.KDF, hpke
if err != nil { if err != nil {
continue continue
} }
return &ec, s, kdf, aead return &ec, pub, kdf, aead
} }
} }
return nil, nil, nil, nil return nil, nil, nil, nil
@ -568,7 +574,12 @@ func (c *Conn) processECHClientHello(outer *clientHelloMsg, echKeys []EncryptedC
if skip { if skip {
continue continue
} }
echPriv, err := hpke.NewKEMRecipient(config.KemID, echKey.PrivateKey) kem, err := hpke.NewKEM(config.KemID)
if err != nil {
c.sendAlert(alertInternalError)
return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey Config KEM: %s", err)
}
echPriv, err := kem.NewPrivateKey(echKey.PrivateKey)
if err != nil { if err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey PrivateKey: %s", err) return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey PrivateKey: %s", err)