crypto/internal/fips140/mldsa: unroll NTT and inverseNTT

fips140: off
goos: darwin
goarch: arm64
pkg: crypto/internal/fips140test
cpu: Apple M2
                      │ bade4ade59  │          bade4ade59-dirty          │
                      │   sec/op    │   sec/op     vs base               │
MLDSASign/ML-DSA-44-8   264.8µ ± 0%   244.5µ ± 0%  -7.68% (p=0.000 n=20)

fips140: off
goos: linux
goarch: amd64
pkg: crypto/internal/fips140test
cpu: AMD EPYC 7443P 24-Core Processor
                       │ bade4ade59  │          bade4ade59-dirty          │
                       │   sec/op    │   sec/op     vs base               │
MLDSASign/ML-DSA-44-48   408.7µ ± 3%   386.5µ ± 1%  -5.41% (p=0.000 n=20)

Change-Id: I04d38a48d5105cbcd625cba9398711b26a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/723020
Reviewed-by: Junyang Shao <shaojunyang@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Mark Freeman <markfreeman@google.com>
This commit is contained in:
Filippo Valsorda 2025-11-21 19:24:37 +01:00 committed by Gopher Robot
parent f821fc46c5
commit 9962d95fed

View file

@ -146,19 +146,69 @@ var zetas = [256]fieldElement{4193792, 25847, 5771523, 7861508, 237124, 7602457,
// It implements NTT, according to FIPS 203, Algorithm 9.
func ntt(f ringElement) nttElement {
var m uint8
for len := 128; len >= 1; len /= 2 {
for len := 128; len >= 8; len /= 2 {
for start := 0; start < 256; start += 2 * len {
m++
zeta := zetas[m]
// Bounds check elimination hint.
f, flen := f[start:start+len], f[start+len:start+len+len]
for j := 0; j < len; j++ {
for j := 0; j < len; j += 2 {
t := fieldMontgomeryMul(zeta, flen[j])
flen[j] = fieldSub(f[j], t)
f[j] = fieldAdd(f[j], t)
// Unroll by 2 for performance.
t = fieldMontgomeryMul(zeta, flen[j+1])
flen[j+1] = fieldSub(f[j+1], t)
f[j+1] = fieldAdd(f[j+1], t)
}
}
}
// Unroll len = 4, 2, and 1.
for start := 0; start < 256; start += 8 {
m++
zeta := zetas[m]
t := fieldMontgomeryMul(zeta, f[start+4])
f[start+4] = fieldSub(f[start], t)
f[start] = fieldAdd(f[start], t)
t = fieldMontgomeryMul(zeta, f[start+5])
f[start+5] = fieldSub(f[start+1], t)
f[start+1] = fieldAdd(f[start+1], t)
t = fieldMontgomeryMul(zeta, f[start+6])
f[start+6] = fieldSub(f[start+2], t)
f[start+2] = fieldAdd(f[start+2], t)
t = fieldMontgomeryMul(zeta, f[start+7])
f[start+7] = fieldSub(f[start+3], t)
f[start+3] = fieldAdd(f[start+3], t)
}
for start := 0; start < 256; start += 4 {
m++
zeta := zetas[m]
t := fieldMontgomeryMul(zeta, f[start+2])
f[start+2] = fieldSub(f[start], t)
f[start] = fieldAdd(f[start], t)
t = fieldMontgomeryMul(zeta, f[start+3])
f[start+3] = fieldSub(f[start+1], t)
f[start+1] = fieldAdd(f[start+1], t)
}
for start := 0; start < 256; start += 2 {
m++
zeta := zetas[m]
t := fieldMontgomeryMul(zeta, f[start+1])
f[start+1] = fieldSub(f[start], t)
f[start] = fieldAdd(f[start], t)
}
return nttElement(f)
}
@ -167,20 +217,70 @@ func ntt(f ringElement) nttElement {
// It implements NTT⁻¹, according to FIPS 203, Algorithm 10.
func inverseNTT(f nttElement) ringElement {
var m uint8 = 255
for len := 1; len < 256; len *= 2 {
// Unroll len = 1, 2, and 4.
for start := 0; start < 256; start += 2 {
zeta := zetas[m]
m--
t := f[start]
f[start] = fieldAdd(t, f[start+1])
f[start+1] = fieldMontgomeryMulSub(zeta, f[start+1], t)
}
for start := 0; start < 256; start += 4 {
zeta := zetas[m]
m--
t := f[start]
f[start] = fieldAdd(t, f[start+2])
f[start+2] = fieldMontgomeryMulSub(zeta, f[start+2], t)
t = f[start+1]
f[start+1] = fieldAdd(t, f[start+3])
f[start+3] = fieldMontgomeryMulSub(zeta, f[start+3], t)
}
for start := 0; start < 256; start += 8 {
zeta := zetas[m]
m--
t := f[start]
f[start] = fieldAdd(t, f[start+4])
f[start+4] = fieldMontgomeryMulSub(zeta, f[start+4], t)
t = f[start+1]
f[start+1] = fieldAdd(t, f[start+5])
f[start+5] = fieldMontgomeryMulSub(zeta, f[start+5], t)
t = f[start+2]
f[start+2] = fieldAdd(t, f[start+6])
f[start+6] = fieldMontgomeryMulSub(zeta, f[start+6], t)
t = f[start+3]
f[start+3] = fieldAdd(t, f[start+7])
f[start+7] = fieldMontgomeryMulSub(zeta, f[start+7], t)
}
for len := 8; len < 256; len *= 2 {
for start := 0; start < 256; start += 2 * len {
zeta := zetas[m]
m--
// Bounds check elimination hint.
f, flen := f[start:start+len], f[start+len:start+len+len]
for j := 0; j < len; j++ {
for j := 0; j < len; j += 2 {
t := f[j]
f[j] = fieldAdd(t, flen[j])
// -z * (t - flen[j]) = z * (flen[j] - t)
flen[j] = fieldMontgomeryMulSub(zeta, flen[j], t)
// Unroll by 2 for performance.
t = f[j+1]
f[j+1] = fieldAdd(t, flen[j+1])
flen[j+1] = fieldMontgomeryMulSub(zeta, flen[j+1], t)
}
}
}
for i := range f {
f[i] = fieldMontgomeryMul(f[i], 16382) // 16382 = 256⁻¹ * R mod q
}