crypto/internal/fips140/mldsa: add accumulated field functions test

Change-Id: I283841df8be3ab94ca5e4a867673bd336a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/762940
LUCI-TryBot-Result: golang-scoped@luci-project-accounts.iam.gserviceaccount.com <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
Reviewed-by: Michael Pratt <mpratt@google.com>
This commit is contained in:
Filippo Valsorda 2026-04-05 23:59:49 +02:00 committed by Gopher Robot
parent e22e20a1e5
commit d876fda088

View file

@ -5,6 +5,9 @@
package mldsa
import (
"crypto/internal/fips140/sha3"
"encoding/hex"
"fmt"
"math/big"
"testing"
)
@ -368,3 +371,74 @@ func TestZetas(t *testing.T) {
}
}
}
// TestAccumulated computes the hash of the following 12 values, as ASCII
// decimals with an optional leading - sign and separated by newlines, for all
// elements r in q from 0 to q-1:
//
// - r mod± q
// - ‖r‖∞ = |r mod± q|
// - r1, r0 = Power2Round(r)
//
// For ML-DSA-44 (γ₂ = (q - 1) / 88):
// - HighBits(r) = UseHint(0, r)
// - UseHint(1, r)
// - LowBits(r)
// - ‖LowBits(r)‖∞ = |LowBits(r)|
//
// For ML-DSA-65 and ML-DSA-87 (γ₂ = (q - 1) / 32):
// - HighBits(r) = UseHint(0, r)
// - UseHint(1, r)
// - LowBits(r)
// - ‖LowBits(r)‖∞ = |LowBits(r)|
//
// Note that HighBits(r), LowBits(r) = Decompose(r).
func TestAccumulated(t *testing.T) {
if testing.Short() {
t.Skip("skipping accumulated test in short mode")
}
o := sha3.NewShake128()
for x := range uint32(q) {
r, _ := fieldToMontgomery(x)
fmt.Fprintf(o, "%d\n", fieldCenteredMod(r))
fmt.Fprintf(o, "%d\n", fieldInfinityNorm(r))
hi, lo := power2Round(r)
fmt.Fprintf(o, "%d\n", hi)
fmt.Fprintf(o, "%d\n", fieldFromMontgomery(lo))
r1, r0 := decompose88(r)
if r1x := highBits88(fieldFromMontgomery(r)); r1x != r1 {
t.Fatalf("highBits88(%d) = %d, expected %d", x, r1x, r1)
}
if r1h0 := useHint88(r, 0); r1h0 != r1 {
t.Fatalf("useHint88(%d, 0) = %d, expected %d", x, r1h0, r1)
}
fmt.Fprintf(o, "%d\n", r1)
fmt.Fprintf(o, "%d\n", useHint88(r, 1))
fmt.Fprintf(o, "%d\n", r0)
fmt.Fprintf(o, "%d\n", constantTimeAbs(r0))
r1, r0 = decompose32(r)
if r1x := highBits32(fieldFromMontgomery(r)); r1x != r1 {
t.Fatalf("highBits32(%d) = %d, expected %d", x, r1x, r1)
}
if r1h0 := useHint32(r, 0); r1h0 != r1 {
t.Fatalf("useHint32(%d, 0) = %d, expected %d", x, r1h0, r1)
}
fmt.Fprintf(o, "%d\n", r1)
fmt.Fprintf(o, "%d\n", useHint32(r, 1))
fmt.Fprintf(o, "%d\n", r0)
fmt.Fprintf(o, "%d\n", constantTimeAbs(r0))
}
// The expected value is documented at https://c2sp.org/CCTV/ML-DSA, and
// tested against https://github.com/FiloSottile/mldsa-py.
expected := "f930663417278156ab05d940294a77210a809c924d8ab63ec72f4526247602c7"
if got := hex.EncodeToString(o.Sum(nil)); got != expected {
t.Errorf("got %s, expected %s", got, expected)
}
}