crypto/mlkem: avoid a few unnecessary inverse NTT calls

We were mistakenly doing NTT⁻¹ inside the inner loop, on the components
of the inner product intead of the sum, leading to k² = 9 inverse NTT
calls instead of k = 3 inverse NTT.

Surprisingly large speedup as a result.

fips140: off
goos: darwin
goarch: arm64
pkg: crypto/mlkem
cpu: Apple M2
                  │ 4c285e0988  │          4c285e0988-dirty          │
                  │   sec/op    │   sec/op     vs base               │
KeyGen-2            28.95µ ± 3%   28.64µ ± 4%        ~ (p=0.699 n=6)
Encaps-2            43.13µ ± 3%   35.02µ ± 1%  -18.81% (p=0.002 n=6)
Decaps-2            43.80µ ± 1%   35.49µ ± 1%  -18.97% (p=0.002 n=6)
RoundTrip/Alice-2   77.27µ ± 7%   69.12µ ± 3%  -10.55% (p=0.002 n=6)
RoundTrip/Bob-2     43.08µ ± 2%   35.14µ ± 3%  -18.44% (p=0.002 n=6)
geomean             44.88µ        38.67µ       -13.84%

Change-Id: I6a6a69649c1378411c9aca75d473fd5b9984a609
Reviewed-on: https://go-review.googlesource.com/c/go/+/715381
Reviewed-by: Junyang Shao <shaojunyang@google.com>
Reviewed-by: Mark Freeman <markfreeman@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
This commit is contained in:
Filippo Valsorda 2025-10-27 18:58:52 +01:00 committed by Gopher Robot
parent 590cf18daf
commit 592775ec7d
2 changed files with 6 additions and 4 deletions

View file

@ -369,11 +369,12 @@ func pkeEncrypt1024(cc *[CiphertextSize1024]byte, ex *encryptionKey1024, m *[mes
u := make([]ringElement, k1024) // NTT⁻¹(AT ◦ r) + e1 u := make([]ringElement, k1024) // NTT⁻¹(AT ◦ r) + e1
for i := range u { for i := range u {
u[i] = e1[i] var uHat nttElement
for j := range r { for j := range r {
// Note that i and j are inverted, as we need the transposed of A. // Note that i and j are inverted, as we need the transposed of A.
u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k1024+i], r[j]))) uHat = polyAdd(uHat, nttMul(ex.a[j*k1024+i], r[j]))
} }
u[i] = polyAdd(e1[i], inverseNTT(uHat))
} }
μ := ringDecodeAndDecompress1(m) μ := ringDecodeAndDecompress1(m)

View file

@ -428,11 +428,12 @@ func pkeEncrypt(cc *[CiphertextSize768]byte, ex *encryptionKey, m *[messageSize]
u := make([]ringElement, k) // NTT⁻¹(AT ◦ r) + e1 u := make([]ringElement, k) // NTT⁻¹(AT ◦ r) + e1
for i := range u { for i := range u {
u[i] = e1[i] var uHat nttElement
for j := range r { for j := range r {
// Note that i and j are inverted, as we need the transposed of A. // Note that i and j are inverted, as we need the transposed of A.
u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k+i], r[j]))) uHat = polyAdd(uHat, nttMul(ex.a[j*k+i], r[j]))
} }
u[i] = polyAdd(e1[i], inverseNTT(uHat))
} }
μ := ringDecodeAndDecompress1(m) μ := ringDecodeAndDecompress1(m)