crypto/mlkem: swap order of return values of Encapsulate

Per FIPS 203 (https://csrc.nist.gov/pubs/fips/203/final), the order of return values should be sharedKey, ciphertext. This commit simply swaps those return values and updates any consumers of the Encapsulate() method to respect the new order.

Fixes #70950

Change-Id: I2a0d605e3baf7fe69510d60d3d35bbac18f883c9
Reviewed-on: https://go-review.googlesource.com/c/go/+/638376
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Austin Clements <austin@google.com>
Auto-Submit: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
This commit is contained in:
Alec Bakholdin 2024-12-22 20:36:59 -05:00 committed by Gopher Robot
parent 772f024c61
commit cce75da30b
7 changed files with 25 additions and 25 deletions

View file

@ -40,7 +40,7 @@ func init() {
dk := &DecapsulationKey768{} dk := &DecapsulationKey768{}
kemKeyGen(dk, d, z) kemKeyGen(dk, d, z)
ek := dk.EncapsulationKey() ek := dk.EncapsulationKey()
c, Ke := ek.EncapsulateInternal(m) Ke, c := ek.EncapsulateInternal(m)
Kd, err := dk.Decapsulate(c) Kd, err := dk.Decapsulate(c)
if err != nil { if err != nil {
return err return err

View file

@ -189,7 +189,7 @@ func kemKeyGen1024(dk *DecapsulationKey1024, d, z *[32]byte) {
// the first operational use (if not exported before the first use)." // the first operational use (if not exported before the first use)."
func kemPCT1024(dk *DecapsulationKey1024) error { func kemPCT1024(dk *DecapsulationKey1024) error {
ek := dk.EncapsulationKey() ek := dk.EncapsulationKey()
c, K := ek.Encapsulate() K, c := ek.Encapsulate()
K1, err := dk.Decapsulate(c) K1, err := dk.Decapsulate(c)
if err != nil { if err != nil {
return err return err
@ -204,13 +204,13 @@ func kemPCT1024(dk *DecapsulationKey1024) error {
// encapsulation key, drawing random bytes from a DRBG. // encapsulation key, drawing random bytes from a DRBG.
// //
// The shared key must be kept secret. // The shared key must be kept secret.
func (ek *EncapsulationKey1024) Encapsulate() (ciphertext, sharedKey []byte) { func (ek *EncapsulationKey1024) Encapsulate() (sharedKey, ciphertext []byte) {
// The actual logic is in a separate function to outline this allocation. // The actual logic is in a separate function to outline this allocation.
var cc [CiphertextSize1024]byte var cc [CiphertextSize1024]byte
return ek.encapsulate(&cc) return ek.encapsulate(&cc)
} }
func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (ciphertext, sharedKey []byte) { func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (sharedKey, ciphertext []byte) {
var m [messageSize]byte var m [messageSize]byte
drbg.Read(m[:]) drbg.Read(m[:])
// Note that the modulus check (step 2 of the encapsulation key check from // Note that the modulus check (step 2 of the encapsulation key check from
@ -221,7 +221,7 @@ func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (ciphe
// EncapsulateInternal is a derandomized version of Encapsulate, exclusively for // EncapsulateInternal is a derandomized version of Encapsulate, exclusively for
// use in tests. // use in tests.
func (ek *EncapsulationKey1024) EncapsulateInternal(m *[32]byte) (ciphertext, sharedKey []byte) { func (ek *EncapsulationKey1024) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
cc := &[CiphertextSize1024]byte{} cc := &[CiphertextSize1024]byte{}
return kemEncaps1024(cc, ek, m) return kemEncaps1024(cc, ek, m)
} }
@ -229,14 +229,14 @@ func (ek *EncapsulationKey1024) EncapsulateInternal(m *[32]byte) (ciphertext, sh
// kemEncaps1024 generates a shared key and an associated ciphertext. // kemEncaps1024 generates a shared key and an associated ciphertext.
// //
// It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17. // It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (c, K []byte) { func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (K, c []byte) {
g := sha3.New512() g := sha3.New512()
g.Write(m[:]) g.Write(m[:])
g.Write(ek.h[:]) g.Write(ek.h[:])
G := g.Sum(nil) G := g.Sum(nil)
K, r := G[:SharedKeySize], G[SharedKeySize:] K, r := G[:SharedKeySize], G[SharedKeySize:]
c = pkeEncrypt1024(cc, &ek.encryptionKey1024, m, r) c = pkeEncrypt1024(cc, &ek.encryptionKey1024, m, r)
return c, K return K, c
} }
// NewEncapsulationKey1024 parses an encapsulation key from its encoded form. // NewEncapsulationKey1024 parses an encapsulation key from its encoded form.

View file

@ -246,7 +246,7 @@ func kemKeyGen(dk *DecapsulationKey768, d, z *[32]byte) {
// the first operational use (if not exported before the first use)." // the first operational use (if not exported before the first use)."
func kemPCT(dk *DecapsulationKey768) error { func kemPCT(dk *DecapsulationKey768) error {
ek := dk.EncapsulationKey() ek := dk.EncapsulationKey()
c, K := ek.Encapsulate() K, c := ek.Encapsulate()
K1, err := dk.Decapsulate(c) K1, err := dk.Decapsulate(c)
if err != nil { if err != nil {
return err return err
@ -261,13 +261,13 @@ func kemPCT(dk *DecapsulationKey768) error {
// encapsulation key, drawing random bytes from a DRBG. // encapsulation key, drawing random bytes from a DRBG.
// //
// The shared key must be kept secret. // The shared key must be kept secret.
func (ek *EncapsulationKey768) Encapsulate() (ciphertext, sharedKey []byte) { func (ek *EncapsulationKey768) Encapsulate() (sharedKey, ciphertext []byte) {
// The actual logic is in a separate function to outline this allocation. // The actual logic is in a separate function to outline this allocation.
var cc [CiphertextSize768]byte var cc [CiphertextSize768]byte
return ek.encapsulate(&cc) return ek.encapsulate(&cc)
} }
func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (ciphertext, sharedKey []byte) { func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (sharedKey, ciphertext []byte) {
var m [messageSize]byte var m [messageSize]byte
drbg.Read(m[:]) drbg.Read(m[:])
// Note that the modulus check (step 2 of the encapsulation key check from // Note that the modulus check (step 2 of the encapsulation key check from
@ -278,7 +278,7 @@ func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (ciphert
// EncapsulateInternal is a derandomized version of Encapsulate, exclusively for // EncapsulateInternal is a derandomized version of Encapsulate, exclusively for
// use in tests. // use in tests.
func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (ciphertext, sharedKey []byte) { func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
cc := &[CiphertextSize768]byte{} cc := &[CiphertextSize768]byte{}
return kemEncaps(cc, ek, m) return kemEncaps(cc, ek, m)
} }
@ -286,14 +286,14 @@ func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (ciphertext, sha
// kemEncaps generates a shared key and an associated ciphertext. // kemEncaps generates a shared key and an associated ciphertext.
// //
// It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17. // It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
func kemEncaps(cc *[CiphertextSize768]byte, ek *EncapsulationKey768, m *[messageSize]byte) (c, K []byte) { func kemEncaps(cc *[CiphertextSize768]byte, ek *EncapsulationKey768, m *[messageSize]byte) (K, c []byte) {
g := sha3.New512() g := sha3.New512()
g.Write(m[:]) g.Write(m[:])
g.Write(ek.h[:]) g.Write(ek.h[:])
G := g.Sum(nil) G := g.Sum(nil)
K, r := G[:SharedKeySize], G[SharedKeySize:] K, r := G[:SharedKeySize], G[SharedKeySize:]
c = pkeEncrypt(cc, &ek.encryptionKey, m, r) c = pkeEncrypt(cc, &ek.encryptionKey, m, r)
return c, K return K, c
} }
// NewEncapsulationKey768 parses an encapsulation key from its encoded form. // NewEncapsulationKey768 parses an encapsulation key from its encoded form.

View file

@ -91,6 +91,6 @@ func (ek *EncapsulationKey1024) Bytes() []byte {
// encapsulation key, drawing random bytes from crypto/rand. // encapsulation key, drawing random bytes from crypto/rand.
// //
// The shared key must be kept secret. // The shared key must be kept secret.
func (ek *EncapsulationKey1024) Encapsulate() (ciphertext, sharedKey []byte) { func (ek *EncapsulationKey1024) Encapsulate() (sharedKey, ciphertext []byte) {
return ek.key.Encapsulate() return ek.key.Encapsulate()
} }

View file

@ -101,6 +101,6 @@ func (ek *EncapsulationKey768) Bytes() []byte {
// encapsulation key, drawing random bytes from crypto/rand. // encapsulation key, drawing random bytes from crypto/rand.
// //
// The shared key must be kept secret. // The shared key must be kept secret.
func (ek *EncapsulationKey768) Encapsulate() (ciphertext, sharedKey []byte) { func (ek *EncapsulationKey768) Encapsulate() (sharedKey, ciphertext []byte) {
return ek.key.Encapsulate() return ek.key.Encapsulate()
} }

View file

@ -43,7 +43,7 @@ func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
t.Fatal(err) t.Fatal(err)
} }
ek := dk.EncapsulationKey() ek := dk.EncapsulationKey()
c, Ke := ek.Encapsulate() Ke, c := ek.Encapsulate()
Kd, err := dk.Decapsulate(c) Kd, err := dk.Decapsulate(c)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -66,7 +66,7 @@ func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
if !bytes.Equal(dk.Bytes(), dk1.Bytes()) { if !bytes.Equal(dk.Bytes(), dk1.Bytes()) {
t.Fail() t.Fail()
} }
c1, Ke1 := ek1.Encapsulate() Ke1, c1 := ek1.Encapsulate()
Kd1, err := dk1.Decapsulate(c1) Kd1, err := dk1.Decapsulate(c1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -86,7 +86,7 @@ func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
t.Fail() t.Fail()
} }
c2, Ke2 := dk.EncapsulationKey().Encapsulate() Ke2, c2 := dk.EncapsulationKey().Encapsulate()
if bytes.Equal(c, c2) { if bytes.Equal(c, c2) {
t.Fail() t.Fail()
} }
@ -115,7 +115,7 @@ func testBadLengths[E encapsulationKey, D decapsulationKey[E]](
} }
ek := dk.EncapsulationKey() ek := dk.EncapsulationKey()
ekBytes := dk.EncapsulationKey().Bytes() ekBytes := dk.EncapsulationKey().Bytes()
c, _ := ek.Encapsulate() _, c := ek.Encapsulate()
for i := 0; i < len(dkBytes)-1; i++ { for i := 0; i < len(dkBytes)-1; i++ {
if _, err := newDecapsulationKey(dkBytes[:i]); err == nil { if _, err := newDecapsulationKey(dkBytes[:i]); err == nil {
@ -189,7 +189,7 @@ func TestAccumulated(t *testing.T) {
o.Write(ek.Bytes()) o.Write(ek.Bytes())
s.Read(msg[:]) s.Read(msg[:])
ct, k := ek.key.EncapsulateInternal(&msg) k, ct := ek.key.EncapsulateInternal(&msg)
o.Write(ct) o.Write(ct)
o.Write(k) o.Write(k)
@ -244,7 +244,7 @@ func BenchmarkEncaps(b *testing.B) {
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
c, K := ek.key.EncapsulateInternal(&m) K, c := ek.key.EncapsulateInternal(&m)
sink ^= c[0] ^ K[0] sink ^= c[0] ^ K[0]
} }
} }
@ -255,7 +255,7 @@ func BenchmarkDecaps(b *testing.B) {
b.Fatal(err) b.Fatal(err)
} }
ek := dk.EncapsulationKey() ek := dk.EncapsulationKey()
c, _ := ek.Encapsulate() _, c := ek.Encapsulate()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
K, _ := dk.Decapsulate(c) K, _ := dk.Decapsulate(c)
@ -270,7 +270,7 @@ func BenchmarkRoundTrip(b *testing.B) {
} }
ek := dk.EncapsulationKey() ek := dk.EncapsulationKey()
ekBytes := ek.Bytes() ekBytes := ek.Bytes()
c, _ := ek.Encapsulate() _, c := ek.Encapsulate()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -296,7 +296,7 @@ func BenchmarkRoundTrip(b *testing.B) {
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
cS, Ks := ek.Encapsulate() Ks, cS := ek.Encapsulate()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View file

@ -280,7 +280,7 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error {
c.sendAlert(alertIllegalParameter) c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid X25519MLKEM768 client key share") return errors.New("tls: invalid X25519MLKEM768 client key share")
} }
ciphertext, mlkemSharedSecret := k.Encapsulate() mlkemSharedSecret, ciphertext := k.Encapsulate()
// draft-kwiatkowski-tls-ecdhe-mlkem-02, Section 3.1.3: "For // draft-kwiatkowski-tls-ecdhe-mlkem-02, Section 3.1.3: "For
// X25519MLKEM768, the shared secret is the concatenation of the ML-KEM // X25519MLKEM768, the shared secret is the concatenation of the ML-KEM
// shared secret and the X25519 shared secret. The shared secret is 64 // shared secret and the X25519 shared secret. The shared secret is 64