2023-12-06 16:51:11 +01:00
|
|
|
// Copyright 2023 The Go Authors. All rights reserved.
|
|
|
|
|
// Use of this source code is governed by a BSD-style
|
|
|
|
|
// license that can be found in the LICENSE file.
|
|
|
|
|
|
2024-10-23 11:41:42 +02:00
|
|
|
package mlkem
|
2023-12-06 16:51:11 +01:00
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bytes"
|
2024-10-23 11:41:42 +02:00
|
|
|
"crypto/internal/fips/sha3"
|
2023-12-06 16:51:11 +01:00
|
|
|
"crypto/rand"
|
|
|
|
|
_ "embed"
|
|
|
|
|
"encoding/hex"
|
|
|
|
|
"flag"
|
|
|
|
|
"testing"
|
|
|
|
|
)
|
|
|
|
|
|
2024-10-23 11:36:56 +02:00
|
|
|
// encapsulationKey is the method set shared by the per-parameter-set
// encapsulation (public) key types, letting the generic test helpers in this
// file run against both ML-KEM-768 and ML-KEM-1024.
type encapsulationKey interface {
	// Encapsulate returns a ciphertext followed by the shared key it carries.
	Encapsulate() ([]byte, []byte)
	// Bytes returns the encoded key.
	Bytes() []byte
}

// decapsulationKey is the method set shared by the per-parameter-set
// decapsulation (private) key types. E is the matching encapsulation key type.
type decapsulationKey[E encapsulationKey] interface {
	// EncapsulationKey returns the encapsulation key paired with this key.
	EncapsulationKey() E
	// Decapsulate recovers the shared key from a ciphertext, or reports an error.
	Decapsulate([]byte) ([]byte, error)
	// Bytes returns the encoded key.
	Bytes() []byte
}
|
|
|
|
|
|
2023-12-06 16:51:11 +01:00
|
|
|
func TestRoundTrip(t *testing.T) {
|
2024-10-23 11:36:56 +02:00
|
|
|
t.Run("768", func(t *testing.T) {
|
|
|
|
|
testRoundTrip(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
|
|
|
|
|
})
|
|
|
|
|
t.Run("1024", func(t *testing.T) {
|
|
|
|
|
testRoundTrip(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
|
|
|
|
|
t *testing.T, generateKey func() (D, error),
|
|
|
|
|
newEncapsulationKey func([]byte) (E, error),
|
|
|
|
|
newDecapsulationKey func([]byte) (D, error)) {
|
|
|
|
|
dk, err := generateKey()
|
2023-12-06 16:51:11 +01:00
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
2024-10-23 11:36:56 +02:00
|
|
|
ek := dk.EncapsulationKey()
|
|
|
|
|
c, Ke := ek.Encapsulate()
|
2024-10-16 14:31:44 +02:00
|
|
|
Kd, err := dk.Decapsulate(c)
|
2023-12-06 16:51:11 +01:00
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
if !bytes.Equal(Ke, Kd) {
|
|
|
|
|
t.Fail()
|
|
|
|
|
}
|
|
|
|
|
|
2024-10-23 11:36:56 +02:00
|
|
|
ek1, err := newEncapsulationKey(ek.Bytes())
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
if !bytes.Equal(ek.Bytes(), ek1.Bytes()) {
|
|
|
|
|
t.Fail()
|
|
|
|
|
}
|
|
|
|
|
dk1, err := newDecapsulationKey(dk.Bytes())
|
2023-12-06 16:51:11 +01:00
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
2024-10-23 11:36:56 +02:00
|
|
|
if !bytes.Equal(dk.Bytes(), dk1.Bytes()) {
|
2023-12-06 16:51:11 +01:00
|
|
|
t.Fail()
|
|
|
|
|
}
|
2024-10-23 11:36:56 +02:00
|
|
|
c1, Ke1 := ek1.Encapsulate()
|
|
|
|
|
Kd1, err := dk1.Decapsulate(c1)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
if !bytes.Equal(Ke1, Kd1) {
|
2023-12-06 16:51:11 +01:00
|
|
|
t.Fail()
|
|
|
|
|
}
|
|
|
|
|
|
2024-10-23 11:36:56 +02:00
|
|
|
dk2, err := generateKey()
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
if bytes.Equal(dk.EncapsulationKey().Bytes(), dk2.EncapsulationKey().Bytes()) {
|
|
|
|
|
t.Fail()
|
|
|
|
|
}
|
|
|
|
|
if bytes.Equal(dk.Bytes(), dk2.Bytes()) {
|
2023-12-06 16:51:11 +01:00
|
|
|
t.Fail()
|
|
|
|
|
}
|
2024-10-23 11:36:56 +02:00
|
|
|
|
|
|
|
|
c2, Ke2 := dk.EncapsulationKey().Encapsulate()
|
|
|
|
|
if bytes.Equal(c, c2) {
|
|
|
|
|
t.Fail()
|
|
|
|
|
}
|
|
|
|
|
if bytes.Equal(Ke, Ke2) {
|
2023-12-06 16:51:11 +01:00
|
|
|
t.Fail()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestBadLengths(t *testing.T) {
|
2024-10-23 11:36:56 +02:00
|
|
|
t.Run("768", func(t *testing.T) {
|
|
|
|
|
testBadLengths(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
|
|
|
|
|
})
|
|
|
|
|
t.Run("1024", func(t *testing.T) {
|
|
|
|
|
testBadLengths(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func testBadLengths[E encapsulationKey, D decapsulationKey[E]](
|
|
|
|
|
t *testing.T, generateKey func() (D, error),
|
|
|
|
|
newEncapsulationKey func([]byte) (E, error),
|
|
|
|
|
newDecapsulationKey func([]byte) (D, error)) {
|
|
|
|
|
dk, err := generateKey()
|
|
|
|
|
dkBytes := dk.Bytes()
|
2023-12-06 16:51:11 +01:00
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
2024-04-15 03:56:10 +02:00
|
|
|
ek := dk.EncapsulationKey()
|
2024-10-21 14:30:46 +02:00
|
|
|
ekBytes := dk.EncapsulationKey().Bytes()
|
|
|
|
|
c, _ := ek.Encapsulate()
|
2023-12-06 16:51:11 +01:00
|
|
|
|
2024-10-23 11:36:56 +02:00
|
|
|
for i := 0; i < len(dkBytes)-1; i++ {
|
|
|
|
|
if _, err := newDecapsulationKey(dkBytes[:i]); err == nil {
|
|
|
|
|
t.Errorf("expected error for dk length %d", i)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
dkLong := dkBytes
|
|
|
|
|
for i := 0; i < 100; i++ {
|
|
|
|
|
dkLong = append(dkLong, 0)
|
|
|
|
|
if _, err := newDecapsulationKey(dkLong); err == nil {
|
|
|
|
|
t.Errorf("expected error for dk length %d", len(dkLong))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2024-10-21 14:30:46 +02:00
|
|
|
for i := 0; i < len(ekBytes)-1; i++ {
|
2024-10-23 11:36:56 +02:00
|
|
|
if _, err := newEncapsulationKey(ekBytes[:i]); err == nil {
|
2023-12-06 16:51:11 +01:00
|
|
|
t.Errorf("expected error for ek length %d", i)
|
|
|
|
|
}
|
|
|
|
|
}
|
2024-10-21 14:30:46 +02:00
|
|
|
ekLong := ekBytes
|
2023-12-06 16:51:11 +01:00
|
|
|
for i := 0; i < 100; i++ {
|
|
|
|
|
ekLong = append(ekLong, 0)
|
2024-10-23 11:36:56 +02:00
|
|
|
if _, err := newEncapsulationKey(ekLong); err == nil {
|
2023-12-06 16:51:11 +01:00
|
|
|
t.Errorf("expected error for ek length %d", len(ekLong))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for i := 0; i < len(c)-1; i++ {
|
2024-10-16 14:31:44 +02:00
|
|
|
if _, err := dk.Decapsulate(c[:i]); err == nil {
|
2023-12-06 16:51:11 +01:00
|
|
|
t.Errorf("expected error for c length %d", i)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
cLong := c
|
|
|
|
|
for i := 0; i < 100; i++ {
|
|
|
|
|
cLong = append(cLong, 0)
|
2024-10-16 14:31:44 +02:00
|
|
|
if _, err := dk.Decapsulate(cLong); err == nil {
|
2023-12-06 16:51:11 +01:00
|
|
|
t.Errorf("expected error for c length %d", len(cLong))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var (
	// millionFlag is set with -million. When true, TestAccumulated runs
	// 1,000,000 vectors and checks the matching digest.
	millionFlag = flag.Bool("million", false, "run the million vector test")
)
|
|
|
|
|
|
2024-10-12 20:22:44 +02:00
|
|
|
// TestAccumulated accumulates 10k (or 100, or 1M) random vectors and checks the
|
|
|
|
|
// hash of the result, to avoid checking in 150MB of test vectors.
|
|
|
|
|
func TestAccumulated(t *testing.T) {
|
2023-12-06 16:51:11 +01:00
|
|
|
n := 10000
|
2024-10-12 20:22:44 +02:00
|
|
|
expected := "8a518cc63da366322a8e7a818c7a0d63483cb3528d34a4cf42f35d5ad73f22fc"
|
2023-12-06 16:51:11 +01:00
|
|
|
if testing.Short() {
|
|
|
|
|
n = 100
|
2024-10-12 20:22:44 +02:00
|
|
|
expected = "1114b1b6699ed191734fa339376afa7e285c9e6acf6ff0177d346696ce564415"
|
2023-12-06 16:51:11 +01:00
|
|
|
}
|
|
|
|
|
if *millionFlag {
|
|
|
|
|
n = 1000000
|
2024-10-12 20:22:44 +02:00
|
|
|
expected = "424bf8f0e8ae99b78d788a6e2e8e9cdaf9773fc0c08a6f433507cb559edfd0f0"
|
2023-12-06 16:51:11 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s := sha3.NewShake128()
|
|
|
|
|
o := sha3.NewShake128()
|
2024-10-12 20:22:44 +02:00
|
|
|
seed := make([]byte, SeedSize)
|
|
|
|
|
var msg [messageSize]byte
|
2024-10-21 16:29:23 +02:00
|
|
|
ct1 := make([]byte, CiphertextSize768)
|
2023-12-06 16:51:11 +01:00
|
|
|
|
|
|
|
|
for i := 0; i < n; i++ {
|
2024-10-12 20:22:44 +02:00
|
|
|
s.Read(seed)
|
2024-10-21 16:29:23 +02:00
|
|
|
dk, err := NewDecapsulationKey768(seed)
|
2024-10-12 20:22:44 +02:00
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
ek := dk.EncapsulationKey()
|
2024-10-21 14:30:46 +02:00
|
|
|
o.Write(ek.Bytes())
|
2023-12-06 16:51:11 +01:00
|
|
|
|
2024-10-12 20:22:44 +02:00
|
|
|
s.Read(msg[:])
|
2024-10-21 14:30:46 +02:00
|
|
|
ct, k := kemEncaps(nil, ek, &msg)
|
2023-12-06 16:51:11 +01:00
|
|
|
o.Write(ct)
|
|
|
|
|
o.Write(k)
|
|
|
|
|
|
2024-10-16 14:31:44 +02:00
|
|
|
kk, err := dk.Decapsulate(ct)
|
2023-12-06 16:51:11 +01:00
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
if !bytes.Equal(kk, k) {
|
|
|
|
|
t.Errorf("k: got %x, expected %x", kk, k)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s.Read(ct1)
|
2024-10-16 14:31:44 +02:00
|
|
|
k1, err := dk.Decapsulate(ct1)
|
2023-12-06 16:51:11 +01:00
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
o.Write(k1)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
got := hex.EncodeToString(o.Sum(nil))
|
|
|
|
|
if got != expected {
|
|
|
|
|
t.Errorf("got %s, expected %s", got, expected)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// sink is a package-level target the benchmarks XOR result bytes into, so the
// benchmarked work has an observable effect.
var sink uint8
|
|
|
|
|
|
|
|
|
|
func BenchmarkKeyGen(b *testing.B) {
|
2024-10-21 16:29:23 +02:00
|
|
|
var dk DecapsulationKey768
|
2024-04-15 03:56:10 +02:00
|
|
|
var d, z [32]byte
|
|
|
|
|
rand.Read(d[:])
|
|
|
|
|
rand.Read(z[:])
|
2023-12-06 16:51:11 +01:00
|
|
|
b.ResetTimer()
|
|
|
|
|
for i := 0; i < b.N; i++ {
|
2024-11-10 15:22:00 +01:00
|
|
|
kemKeyGen(&dk, &d, &z)
|
2024-10-21 14:30:46 +02:00
|
|
|
sink ^= dk.EncapsulationKey().Bytes()[0]
|
2023-12-06 16:51:11 +01:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func BenchmarkEncaps(b *testing.B) {
|
2024-10-12 20:22:44 +02:00
|
|
|
seed := make([]byte, SeedSize)
|
|
|
|
|
rand.Read(seed)
|
2024-04-15 03:56:10 +02:00
|
|
|
var m [messageSize]byte
|
|
|
|
|
rand.Read(m[:])
|
2024-10-21 16:29:23 +02:00
|
|
|
dk, err := NewDecapsulationKey768(seed)
|
2024-10-12 20:22:44 +02:00
|
|
|
if err != nil {
|
|
|
|
|
b.Fatal(err)
|
|
|
|
|
}
|
2024-10-21 14:30:46 +02:00
|
|
|
ekBytes := dk.EncapsulationKey().Bytes()
|
2024-10-21 16:29:23 +02:00
|
|
|
var c [CiphertextSize768]byte
|
2023-12-06 16:51:11 +01:00
|
|
|
b.ResetTimer()
|
|
|
|
|
for i := 0; i < b.N; i++ {
|
2024-10-21 16:29:23 +02:00
|
|
|
ek, err := NewEncapsulationKey768(ekBytes)
|
2023-12-06 16:51:11 +01:00
|
|
|
if err != nil {
|
|
|
|
|
b.Fatal(err)
|
|
|
|
|
}
|
2024-10-21 14:30:46 +02:00
|
|
|
c, K := kemEncaps(&c, ek, &m)
|
2023-12-06 16:51:11 +01:00
|
|
|
sink ^= c[0] ^ K[0]
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func BenchmarkDecaps(b *testing.B) {
|
2024-10-21 16:29:23 +02:00
|
|
|
dk, err := GenerateKey768()
|
2024-10-12 20:22:44 +02:00
|
|
|
if err != nil {
|
|
|
|
|
b.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
ek := dk.EncapsulationKey()
|
2024-10-21 14:30:46 +02:00
|
|
|
c, _ := ek.Encapsulate()
|
2023-12-06 16:51:11 +01:00
|
|
|
b.ResetTimer()
|
|
|
|
|
for i := 0; i < b.N; i++ {
|
2024-10-21 16:29:23 +02:00
|
|
|
K := kemDecaps(dk, (*[CiphertextSize768]byte)(c))
|
2023-12-06 16:51:11 +01:00
|
|
|
sink ^= K[0]
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func BenchmarkRoundTrip(b *testing.B) {
|
2024-10-21 16:29:23 +02:00
|
|
|
dk, err := GenerateKey768()
|
2023-12-06 16:51:11 +01:00
|
|
|
if err != nil {
|
|
|
|
|
b.Fatal(err)
|
|
|
|
|
}
|
2024-04-15 03:56:10 +02:00
|
|
|
ek := dk.EncapsulationKey()
|
2024-10-21 14:30:46 +02:00
|
|
|
ekBytes := ek.Bytes()
|
|
|
|
|
c, _ := ek.Encapsulate()
|
2023-12-06 16:51:11 +01:00
|
|
|
if err != nil {
|
|
|
|
|
b.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
b.Run("Alice", func(b *testing.B) {
|
|
|
|
|
for i := 0; i < b.N; i++ {
|
2024-10-21 16:29:23 +02:00
|
|
|
dkS, err := GenerateKey768()
|
2023-12-06 16:51:11 +01:00
|
|
|
if err != nil {
|
|
|
|
|
b.Fatal(err)
|
|
|
|
|
}
|
2024-10-21 14:30:46 +02:00
|
|
|
ekS := dkS.EncapsulationKey().Bytes()
|
2024-04-15 03:56:10 +02:00
|
|
|
sink ^= ekS[0]
|
|
|
|
|
|
2024-10-16 14:31:44 +02:00
|
|
|
Ks, err := dk.Decapsulate(c)
|
2023-12-06 16:51:11 +01:00
|
|
|
if err != nil {
|
|
|
|
|
b.Fatal(err)
|
|
|
|
|
}
|
2024-04-15 03:56:10 +02:00
|
|
|
sink ^= Ks[0]
|
2023-12-06 16:51:11 +01:00
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
b.Run("Bob", func(b *testing.B) {
|
|
|
|
|
for i := 0; i < b.N; i++ {
|
2024-10-21 16:29:23 +02:00
|
|
|
ek, err := NewEncapsulationKey768(ekBytes)
|
2024-10-21 14:30:46 +02:00
|
|
|
if err != nil {
|
|
|
|
|
b.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
cS, Ks := ek.Encapsulate()
|
2023-12-06 16:51:11 +01:00
|
|
|
if err != nil {
|
|
|
|
|
b.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
sink ^= cS[0] ^ Ks[0]
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|