[dev.simd] all: merge master (8dd5b13) into dev.simd

Merge List:

+ 2025-11-24 8dd5b13abc cmd/compile: relax stmtline_test on amd64
+ 2025-11-23 feae743bdb cmd/compile: use 32x32->64 multiplies on loong64
+ 2025-11-23 e88be8a128 runtime: fix stale comment for mheap/malloc
+ 2025-11-23 a318843a2a cmd/internal/obj/loong64: optimize duplicate optab entries
+ 2025-11-23 a18294bb6a cmd/internal/obj/arm64, image/gif, runtime, sort: use math/bits to calculate log2
+ 2025-11-23 437323ef7b slices: fix incorrect comment in slices.Insert function documentation
+ 2025-11-23 1993dca400 doc/next: pre-announce end of support for macOS 12 in Go 1.27
+ 2025-11-22 337f7b1f5d cmd/go: update default go directive in mod or work init
+ 2025-11-21 3c26aef8fb cmd/internal/obj/riscv: improve large branch/call/jump tests
+ 2025-11-21 31aa9f800b crypto/tls: use inner hello for earlyData when using QUIC and ECH
+ 2025-11-21 d68aec8db1 runtime: replace trace seqlock with write flag
+ 2025-11-21 8d9906cd34 runtime/trace: add Log benchmark
+ 2025-11-21 6aeacdff38 cmd/go: support sha1 repos when git default is sha256
+ 2025-11-21 9570036ca5 crypto/sha3: make the zero value of SHAKE useable
+ 2025-11-21 155efbbeeb crypto/sha3: make the zero value of SHA3 useable
+ 2025-11-21 6f16669e34 database/sql: don't ignore ColumnConverter for unknown input count
+ 2025-11-21 121bc3e464 runtime/pprof: remove hard-coded sleep in CPU profile reader
+ 2025-11-21 b604148c4e runtime: fix double wakeup in CPU profile buffer
+ 2025-11-21 22f24f90b5 cmd/compile: change testing.B.Loop keep alive semantic
+ 2025-11-21 cfb9d2eb73 net: remove unused linknames
+ 2025-11-21 65ef314f89 net/http: remove unused linknames
+ 2025-11-21 0f32fbc631 net/http: populate Response.Request when using NewFileTransport
+ 2025-11-21 3e0a8e7867 net/http: preserve original path encoding in redirects
+ 2025-11-21 831af61120 net/http: use HTTP 307 redirects in ServeMux
+ 2025-11-21 87269224cb net/http: update Response.Request.URL after redirects on GOOS=js
+ 2025-11-21 7aa9ca729f net/http/cookiejar: treat localhost as secure origin
+ 2025-11-21 f870a1d398 net/url: warn that JoinPath arguments should be escaped
+ 2025-11-21 9962d95fed crypto/internal/fips140/mldsa: unroll NTT and inverseNTT
+ 2025-11-21 f821fc46c5 crypto/internal/fips140test: update acvptool, test data
+ 2025-11-21 b59efc38a0 crypto/internal/fips140/mldsa: new package
+ 2025-11-21 62741480b8 runtime: remove linkname for gopanic
+ 2025-11-21 7db2f0bb9a crypto/internal/hpke: separate KEM and PublicKey/PrivateKey interfaces
+ 2025-11-21 e15800c0ec crypto/internal/hpke: add ML-KEM and hybrid KEMs, and SHAKE KDFs
+ 2025-11-21 7c985a2df4 crypto/internal/hpke: modularize API and support more ciphersuites
+ 2025-11-21 e7d47ac33d cmd/compile: simplify negative on multiplication
+ 2025-11-21 35d2712b32 net/http: fix typo in Transport docs
+ 2025-11-21 90c970cd0f net: remove unnecessary loop variable copies in tests
+ 2025-11-21 9772d3a690 cmd/cgo: strip top-level const qualifier from argument frame struct
+ 2025-11-21 1903782ade errors: add examples for custom Is/As matching
+ 2025-11-21 ec92bc6d63 cmd/compile: rewrite Rsh to RshU if arguments are proved positive
+ 2025-11-21 3820f94c1d cmd/compile: propagate unsigned relations for Rsh if arguments are positive
+ 2025-11-21 d474f1fd21 cmd/compile: make dse track multiple shadowed ranges
+ 2025-11-21 d0d0a72980 cmd/compile/internal/ssa: correct type of ARM64 conditional instructions
+ 2025-11-21 a9704f89ea internal/runtime/gc/scan: add AVX512 impl of filterNil.
+ 2025-11-21 ccd389036a cmd/internal/objabi: remove -V=goexperiment internal special case
+ 2025-11-21 e7787b9eca runtime: go fmt
+ 2025-11-21 17b3b98796 internal/strconv: go fmt
+ 2025-11-21 c851827c68 internal/trace: go fmt
+ 2025-11-21 f87aaec53d cmd/compile: fix integer overflow in prove pass
+ 2025-11-21 dbd2ab9992 cmd/compile/internal: fix typos
+ 2025-11-21 b9d86baae3 cmd/compile/internal/devirtualize: fix typos
+ 2025-11-20 4b0e3cc1d6 cmd/link: support loading R_LARCH_PCREL20_S2 and R_LARCH_CALL36 relocs
+ 2025-11-20 cdba82c7d6 cmd/internal/obj/loong64: add {,X}VSLT.{B/H/W/V}{,U} instructions support
+ 2025-11-20 bd2b117c2c crypto/tls: add QUICErrorEvent
+ 2025-11-20 3ad2e113fc net/http/httputil: wrap ReverseProxy's outbound request body so Close is a noop
+ 2025-11-20 d58b733646 runtime: track goroutine location until actual STW
+ 2025-11-20 1bc54868d4 cmd/vendor: update to x/tools@68724af
+ 2025-11-20 8c3195973b runtime: disable stack allocation tests on sanitizers
+ 2025-11-20 ff654ea100 net/url: permit colons in the host of postgresql:// URLs
+ 2025-11-20 a662badab9 encoding/json: remove linknames
+ 2025-11-20 5afe237d65 mime: add missing path for mime types in godoc
+ 2025-11-20 c1b7112af8 os/signal: make NotifyContext cancel the context with a cause

Change-Id: Ib93ef643be610dfbdd83ff45095a7b1ca2537b8b
commit 220d73cc44
Cherry Mui, 2025-11-24 11:03:06 -05:00
161 changed files with 11170 additions and 1379 deletions

api/next/75108.txt (new file)

@ -0,0 +1,3 @@
pkg crypto/tls, const QUICErrorEvent = 10 #75108
pkg crypto/tls, const QUICErrorEvent QUICEventKind #75108
pkg crypto/tls, type QUICEvent struct, Err error #75108


@ -0,0 +1,2 @@
The [QUICConn] type used by QUIC implementations includes a new event
for reporting TLS handshake errors.
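
A minimal sketch (not taken from the change itself) of how a QUIC implementation might consume the new event through the existing QUICConn.NextEvent loop; the QUICErrorEvent kind and the QUICEvent.Err field are the additions described above:

	package p

	import "crypto/tls"

	// drainTLSEvents processes pending TLS events, assuming the new
	// QUICErrorEvent kind and QUICEvent.Err field added in this change.
	func drainTLSEvents(conn *tls.QUICConn) error {
		for {
			switch ev := conn.NextEvent(); ev.Kind {
			case tls.QUICNoEvent:
				return nil // nothing more to process right now
			case tls.QUICErrorEvent:
				return ev.Err // the TLS handshake failed; Err carries the cause
			default:
				// handle the other event kinds (handshake data, keys, ...) here
			}
		}
	}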


@ -0,0 +1,2 @@
[NotifyContext] now cancels the returned context with [context.CancelCauseFunc]
and an error indicating which signal was received.
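
A small sketch of observing the new cancellation cause; signal.NotifyContext and context.Cause are existing APIs, and the signal-specific error is the behavior described above:

	package main

	import (
		"context"
		"fmt"
		"os"
		"os/signal"
	)

	func main() {
		ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
		defer stop()

		<-ctx.Done()
		// With this change, context.Cause(ctx) reports which signal canceled
		// the context instead of the plain context.Canceled error.
		fmt.Println("shutting down:", context.Cause(ctx))
	}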


@ -1,6 +1,13 @@
## Ports {#ports} ## Ports {#ports}
### Darwin
<!-- go.dev/issue/75836 -->
Go 1.26 is the last release that will run on macOS 12 Monterey. Go 1.27 will require macOS 13 Ventura or later.
### Windows ### Windows
<!-- go.dev/issue/71671 --> <!-- go.dev/issue/71671 -->
As [announced](/doc/go1.25#windows) in the Go 1.25 release notes, the [broken](/doc/go1.24#windows) 32-bit windows/arm port (`GOOS=windows` `GOARCH=arm`) is removed. As [announced](/doc/go1.25#windows) in the Go 1.25 release notes, the [broken](/doc/go1.24#windows) 32-bit windows/arm port (`GOOS=windows` `GOARCH=arm`) is removed.


@ -597,6 +597,42 @@ lable2:
XVSEQV $15, X2, X4 // 44bc8176 XVSEQV $15, X2, X4 // 44bc8176
XVSEQV $-15, X2, X4 // 44c48176 XVSEQV $-15, X2, X4 // 44c48176
// VSLT{B,H,W,V}, XVSLT{B,H,W,V} instruction
VSLTB V1, V2, V3 // 43040670
VSLTH V1, V2, V3 // 43840670
VSLTW V1, V2, V3 // 43040770
VSLTV V1, V2, V3 // 43840770
XVSLTB X1, X2, X3 // 43040674
XVSLTH X1, X2, X3 // 43840674
XVSLTW X1, X2, X3 // 43040774
XVSLTV X1, X2, X3 // 43840774
VSLTB $1, V2, V3 // 43048672
VSLTH $16, V2, V3 // 43c08672
VSLTW $-16, V2, V3 // 43408772
VSLTV $-15, V2, V3 // 43c48772
XVSLTB $1, X2, X3 // 43048676
XVSLTH $16, X2, X3 // 43c08676
XVSLTW $-16, X2, X3 // 43408776
XVSLTV $-16, X2, X3 // 43c08776
// VSLT{B,H,W,V}U, XVSLT{B,H,W,V}U instruction
VSLTBU V1, V2, V3 // 43040870
VSLTHU V1, V2, V3 // 43840870
VSLTWU V1, V2, V3 // 43040970
VSLTVU V1, V2, V3 // 43840970
XVSLTBU X1, X2, X3 // 43040874
XVSLTHU X1, X2, X3 // 43840874
XVSLTWU X1, X2, X3 // 43040974
XVSLTVU X1, X2, X3 // 43840974
VSLTBU $0, V2, V3 // 43008872
VSLTHU $31, V2, V3 // 43fc8872
VSLTWU $16, V2, V3 // 43408972
VSLTVU $1, V2, V3 // 43848972
XVSLTBU $0, X2, X3 // 43008876
XVSLTHU $31, X2, X3 // 43fc8876
XVSLTWU $8, X2, X3 // 43208976
XVSLTVU $0, X2, X3 // 43808976
// VPCNT{B,H,W,V}, XVPCNT{B,H,W,V} instruction // VPCNT{B,H,W,V}, XVPCNT{B,H,W,V} instruction
VPCNTB V1, V2 // 22209c72 VPCNTB V1, V2 // 22209c72
VPCNTH V1, V2 // 22249c72 VPCNTH V1, V2 // 22249c72


@ -953,6 +953,12 @@ typedef struct {
} issue69086struct; } issue69086struct;
static int issue690861(issue69086struct* p) { p->b = 1234; return p->c; } static int issue690861(issue69086struct* p) { p->b = 1234; return p->c; }
static int issue690862(unsigned long ul1, unsigned long ul2, unsigned int u, issue69086struct s) { return (int)(s.b); } static int issue690862(unsigned long ul1, unsigned long ul2, unsigned int u, issue69086struct s) { return (int)(s.b); }
char issue75751v = 1;
char * const issue75751p = &issue75751v;
#define issue75751m issue75751p
char * const volatile issue75751p2 = &issue75751v;
#define issue75751m2 issue75751p2
*/ */
import "C" import "C"
@ -2396,3 +2402,8 @@ func test69086(t *testing.T) {
t.Errorf("call: got %d, want 1234", got) t.Errorf("call: got %d, want 1234", got)
} }
} }
// Issue 75751: no runtime test, just make sure it compiles.
func test75751() int {
return int(*C.issue75751m) + int(*C.issue75751m2)
}


@ -457,6 +457,36 @@ func checkImportSymName(s string) {
// Also assumes that gc convention is to word-align the // Also assumes that gc convention is to word-align the
// input and output parameters. // input and output parameters.
func (p *Package) structType(n *Name) (string, int64) { func (p *Package) structType(n *Name) (string, int64) {
// It's possible for us to see a type with a top-level const here,
// which will give us an unusable struct type. See #75751.
// The top-level const will always appear as a final qualifier,
// constructed by typeConv.loadType in the dwarf.QualType case.
// The top-level const is meaningless here and can simply be removed.
stripConst := func(s string) string {
i := strings.LastIndex(s, "const")
if i == -1 {
return s
}
// A top-level const can only be followed by other qualifiers.
if r, ok := strings.CutSuffix(s, "const"); ok {
return strings.TrimSpace(r)
}
var nonConst []string
for _, f := range strings.Fields(s[i:]) {
switch f {
case "const":
case "restrict", "volatile":
nonConst = append(nonConst, f)
default:
return s
}
}
return strings.TrimSpace(s[:i]) + " " + strings.Join(nonConst, " ")
}
var buf strings.Builder var buf strings.Builder
fmt.Fprint(&buf, "struct {\n") fmt.Fprint(&buf, "struct {\n")
off := int64(0) off := int64(0)
@ -468,7 +498,7 @@ func (p *Package) structType(n *Name) (string, int64) {
} }
c := t.Typedef c := t.Typedef
if c == "" { if c == "" {
c = t.C.String() c = stripConst(t.C.String())
} }
fmt.Fprintf(&buf, "\t\t%s p%d;\n", c, i) fmt.Fprintf(&buf, "\t\t%s p%d;\n", c, i)
off += t.Size off += t.Size
@ -484,7 +514,7 @@ func (p *Package) structType(n *Name) (string, int64) {
fmt.Fprintf(&buf, "\t\tchar __pad%d[%d];\n", off, pad) fmt.Fprintf(&buf, "\t\tchar __pad%d[%d];\n", off, pad)
off += pad off += pad
} }
fmt.Fprintf(&buf, "\t\t%s r;\n", t.C) fmt.Fprintf(&buf, "\t\t%s r;\n", stripConst(t.C.String()))
off += t.Size off += t.Size
} }
if off%p.PtrSize != 0 { if off%p.PtrSize != 0 {
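
For illustration (the inputs below are made-up examples, not taken from the change), the new stripConst helper behaves roughly like this:

	stripConst("char * const")          -> "char *"            // trailing top-level const removed
	stripConst("char * const volatile") -> "char * volatile"   // const dropped, other qualifiers kept
	stripConst("const char *")          -> "const char *"      // const is not top-level, left unchanged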


@ -0,0 +1,313 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bloop
// This file contains support routines for keeping statements
// alive inside testing.B.Loop loops, for example:
//
// for b.Loop() {
// var a, b int
// a = 5
// b = 6
// f(a, b)
// }
//
// The results of a, b and f(a, b) will be kept alive.
//
// Formally, the lhs (if they are [ir.Name]-s) of
// [ir.AssignStmt], [ir.AssignListStmt],
// [ir.AssignOpStmt], and the results of [ir.CallExpr]
// or its args if it doesn't return a value will be kept
// alive.
//
// The keep-alive logic is implemented by wrapping a
// runtime.KeepAlive call around the Name.
//
// TODO: currently this is implemented with KeepAlive
// because it will prevent DSE and DCE which is probably
// what we want right now. And KeepAlive takes an ssa
// value instead of a symbol, which is easier to manage.
// But since KeepAlive's context was mainly in the runtime
// and GC, should we implement a new intrinsic that lowers
// to OpVarLive? Peeling out the symbols is a bit tricky
// and also VarLive seems to assume that there exists a
// VarDef on the same symbol that dominates it.
import (
"cmd/compile/internal/base"
"cmd/compile/internal/ir"
"cmd/compile/internal/reflectdata"
"cmd/compile/internal/typecheck"
"cmd/compile/internal/types"
"fmt"
)
// getNameFromNode tries to iteratively peel down the node to
// get the name.
func getNameFromNode(n ir.Node) *ir.Name {
var ret *ir.Name
if n.Op() == ir.ONAME {
ret = n.(*ir.Name)
} else {
// avoid infinite recursion on circular referencing nodes.
seen := map[ir.Node]bool{n: true}
var findName func(ir.Node) bool
findName = func(a ir.Node) bool {
if a.Op() == ir.ONAME {
ret = a.(*ir.Name)
return true
}
if !seen[a] {
seen[a] = true
return ir.DoChildren(a, findName)
}
return false
}
ir.DoChildren(n, findName)
}
return ret
}
// keepAliveAt returns a statement that is either curNode, or a
// block containing curNode followed by a call to runtime.keepAlive for each
// ONAME in ns. These calls ensure that names in ns will be live until
// after curNode's execution.
func keepAliveAt(ns []*ir.Name, curNode ir.Node) ir.Node {
if len(ns) == 0 {
return curNode
}
pos := curNode.Pos()
calls := []ir.Node{curNode}
for _, n := range ns {
if n == nil {
continue
}
if n.Sym() == nil {
continue
}
if n.Sym().IsBlank() {
continue
}
arg := ir.NewConvExpr(pos, ir.OCONV, types.Types[types.TINTER], n)
if !n.Type().IsInterface() {
srcRType0 := reflectdata.TypePtrAt(pos, n.Type())
arg.TypeWord = srcRType0
arg.SrcRType = srcRType0
}
callExpr := typecheck.Call(pos,
typecheck.LookupRuntime("KeepAlive"),
[]ir.Node{arg}, false).(*ir.CallExpr)
callExpr.IsCompilerVarLive = true
callExpr.NoInline = true
calls = append(calls, callExpr)
}
return ir.NewBlockStmt(pos, calls)
}
func debugName(name *ir.Name, line string) {
if base.Flag.LowerM > 0 {
if name.Linksym() != nil {
fmt.Printf("%v: %s will be kept alive\n", line, name.Linksym().Name)
} else {
fmt.Printf("%v: expr will be kept alive\n", line)
}
}
}
// preserveStmt transforms stmt so that any names defined/assigned within it
// are used after stmt's execution, preventing their dead code elimination
// and dead store elimination. The return value is the transformed statement.
func preserveStmt(curFn *ir.Func, stmt ir.Node) (ret ir.Node) {
ret = stmt
switch n := stmt.(type) {
case *ir.AssignStmt:
// Peel down struct and slice indexing to get the names
name := getNameFromNode(n.X)
if name != nil {
debugName(name, ir.Line(stmt))
ret = keepAliveAt([]*ir.Name{name}, n)
}
case *ir.AssignListStmt:
names := []*ir.Name{}
for _, lhs := range n.Lhs {
name := getNameFromNode(lhs)
if name != nil {
debugName(name, ir.Line(stmt))
names = append(names, name)
}
}
ret = keepAliveAt(names, n)
case *ir.AssignOpStmt:
name := getNameFromNode(n.X)
if name != nil {
debugName(name, ir.Line(stmt))
ret = keepAliveAt([]*ir.Name{name}, n)
}
case *ir.CallExpr:
names := []*ir.Name{}
curNode := stmt
if n.Fun != nil && n.Fun.Type() != nil && n.Fun.Type().NumResults() != 0 {
// This function's results are not assigned; assign them to
// auto tmps and then keepAliveAt these autos.
// Note: this assumes the context in which it is called - the CallExpr is
// not within another OAS2, which is guaranteed by the case above.
results := n.Fun.Type().Results()
lhs := make([]ir.Node, len(results))
for i, res := range results {
tmp := typecheck.TempAt(n.Pos(), curFn, res.Type)
lhs[i] = tmp
names = append(names, tmp)
}
// Create an assignment statement.
assign := typecheck.AssignExpr(
ir.NewAssignListStmt(n.Pos(), ir.OAS2, lhs,
[]ir.Node{n})).(*ir.AssignListStmt)
assign.Def = true
curNode = assign
plural := ""
if len(results) > 1 {
plural = "s"
}
if base.Flag.LowerM > 0 {
fmt.Printf("%v: function result%s will be kept alive\n", ir.Line(stmt), plural)
}
} else {
// This function probably doesn't return anything, keep its args alive.
argTmps := []ir.Node{}
for i, a := range n.Args {
if name := getNameFromNode(a); name != nil {
// If they are names, keep them alive directly.
debugName(name, ir.Line(stmt))
names = append(names, name)
} else if a.Op() == ir.OSLICELIT {
// variadic args are encoded as slice literal.
s := a.(*ir.CompLitExpr)
ns := []*ir.Name{}
for i, n := range s.List {
if name := getNameFromNode(n); name != nil {
debugName(name, ir.Line(a))
ns = append(ns, name)
} else {
// We need a temporary to save this arg.
tmp := typecheck.TempAt(n.Pos(), curFn, n.Type())
argTmps = append(argTmps, typecheck.AssignExpr(ir.NewAssignStmt(n.Pos(), tmp, n)))
names = append(names, tmp)
s.List[i] = tmp
if base.Flag.LowerM > 0 {
fmt.Printf("%v: function arg will be kept alive\n", ir.Line(n))
}
}
}
names = append(names, ns...)
} else {
// For other expressions, we need to assign them to temps and change
// the original arg to reference them.
tmp := typecheck.TempAt(n.Pos(), curFn, a.Type())
argTmps = append(argTmps, typecheck.AssignExpr(ir.NewAssignStmt(n.Pos(), tmp, a)))
names = append(names, tmp)
n.Args[i] = tmp
if base.Flag.LowerM > 0 {
fmt.Printf("%v: function arg will be kept alive\n", ir.Line(stmt))
}
}
}
if len(argTmps) > 0 {
argTmps = append(argTmps, n)
curNode = ir.NewBlockStmt(n.Pos(), argTmps)
}
}
ret = keepAliveAt(names, curNode)
}
return
}
func preserveStmts(curFn *ir.Func, list ir.Nodes) {
for i := range list {
list[i] = preserveStmt(curFn, list[i])
}
}
// isTestingBLoop returns true if it matches the node as a
// testing.(*B).Loop. See issue #61515.
func isTestingBLoop(t ir.Node) bool {
if t.Op() != ir.OFOR {
return false
}
nFor, ok := t.(*ir.ForStmt)
if !ok || nFor.Cond == nil || nFor.Cond.Op() != ir.OCALLFUNC {
return false
}
n, ok := nFor.Cond.(*ir.CallExpr)
if !ok || n.Fun == nil || n.Fun.Op() != ir.OMETHEXPR {
return false
}
name := ir.MethodExprName(n.Fun)
if name == nil {
return false
}
if fSym := name.Sym(); fSym != nil && name.Class == ir.PFUNC && fSym.Pkg != nil &&
fSym.Name == "(*B).Loop" && fSym.Pkg.Path == "testing" {
// Attempting to match a function call to testing.(*B).Loop
return true
}
return false
}
type editor struct {
inBloop bool
curFn *ir.Func
}
func (e editor) edit(n ir.Node) ir.Node {
e.inBloop = isTestingBLoop(n) || e.inBloop
// If we're inside a b.Loop, mark the statements in this node's body lists.
ir.EditChildren(n, e.edit)
if e.inBloop {
switch n := n.(type) {
case *ir.ForStmt:
preserveStmts(e.curFn, n.Body)
case *ir.IfStmt:
preserveStmts(e.curFn, n.Body)
preserveStmts(e.curFn, n.Else)
case *ir.BlockStmt:
preserveStmts(e.curFn, n.List)
case *ir.CaseClause:
preserveStmts(e.curFn, n.List)
preserveStmts(e.curFn, n.Body)
case *ir.CommClause:
preserveStmts(e.curFn, n.Body)
}
}
return n
}
// BloopWalk walks all functions in the package, if the package imports
// testing, and wraps the results of all qualified statements within a
//
//	for b.Loop() {...}
//
// loop's body in a runtime.KeepAlive intrinsic call. See the package
// doc for more details.
func BloopWalk(pkg *ir.Package) {
hasTesting := false
for _, i := range pkg.Imports {
if i.Path == "testing" {
hasTesting = true
break
}
}
if !hasTesting {
return
}
for _, fn := range pkg.Funcs {
e := editor{false, fn}
ir.EditChildren(fn, e.edit)
}
}
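
A minimal sketch of the kind of benchmark this pass targets (parse is a hypothetical stand-in for the code under test); with this change the otherwise-unused call result inside b.Loop is wrapped in runtime.KeepAlive rather than being dead-code-eliminated:

	package p

	import "testing"

	// parse is a hypothetical stand-in for the code being measured.
	func parse(s string) int { return len(s) }

	func BenchmarkParse(b *testing.B) {
		for b.Loop() {
			// The compiler keeps this call's result alive, so no explicit
			// sink variable is needed to prevent elimination.
			parse("some fixed input")
		}
	}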


@ -131,7 +131,7 @@ func StaticCall(s *State, call *ir.CallExpr) {
// type assertion that we make here would also have failed, but with a different // type assertion that we make here would also have failed, but with a different
// panic "pkg.Iface is nil, not *pkg.Impl", where previously we would get a nil panic. // panic "pkg.Iface is nil, not *pkg.Impl", where previously we would get a nil panic.
// We fix this, by introducing an additional nilcheck on the itab. // We fix this, by introducing an additional nilcheck on the itab.
// Calling a method on an nil interface (in most cases) is a bug in a program, so it is fine // Calling a method on a nil interface (in most cases) is a bug in a program, so it is fine
// to devirtualize and further (possibly) inline them, even though we would never reach // to devirtualize and further (possibly) inline them, even though we would never reach
// the called function. // the called function.
dt.UseNilPanic = true dt.UseNilPanic = true
@ -197,7 +197,7 @@ var noType types.Type
// concreteType1 analyzes the node n and returns its concrete type if it is statically known. // concreteType1 analyzes the node n and returns its concrete type if it is statically known.
// Otherwise, it returns a nil Type, indicating that a concrete type was not determined. // Otherwise, it returns a nil Type, indicating that a concrete type was not determined.
// When n is known to be statically nil or a self-assignment is detected, in returns a sentinel [noType] type instead. // When n is known to be statically nil or a self-assignment is detected, it returns a sentinel [noType] type instead.
func concreteType1(s *State, n ir.Node, seen map[*ir.Name]struct{}) (outT *types.Type) { func concreteType1(s *State, n ir.Node, seen map[*ir.Name]struct{}) (outT *types.Type) {
nn := n // for debug messages nn := n // for debug messages
@ -310,7 +310,7 @@ func concreteType1(s *State, n ir.Node, seen map[*ir.Name]struct{}) (outT *types
// assignment can be one of: // assignment can be one of:
// - nil - assignment from an interface type. // - nil - assignment from an interface type.
// - *types.Type - assignment from a concrete type (non-interface). // - *types.Type - assignment from a concrete type (non-interface).
// - ir.Node - assignment from a ir.Node. // - ir.Node - assignment from an ir.Node.
// //
// In most cases assignment should be an [ir.Node], but in cases where we // In most cases assignment should be an [ir.Node], but in cases where we
// do not follow the data-flow, we return either a concrete type (*types.Type) or a nil. // do not follow the data-flow, we return either a concrete type (*types.Type) or a nil.
@ -560,8 +560,8 @@ func (s *State) analyze(nodes ir.Nodes) {
assign(n.Key, nil) assign(n.Key, nil)
assign(n.Value, nil) assign(n.Value, nil)
} else { } else {
// We will not reach here in case of an range-over-func, as it is // We will not reach here in case of a range-over-func, as it is
// rewrtten to function calls in the noder package. // rewritten to function calls in the noder package.
base.FatalfAt(n.Pos(), "range over unexpected type %v", n.X.Type()) base.FatalfAt(n.Pos(), "range over unexpected type %v", n.X.Type())
} }
case ir.OSWITCH: case ir.OSWITCH:


@ -45,6 +45,20 @@ func (e *escape) call(ks []hole, call ir.Node) {
fn = ir.StaticCalleeName(v) fn = ir.StaticCalleeName(v)
} }
// argumentParam handles escape analysis of assigning a call
// argument to its corresponding parameter.
argumentParam := func(param *types.Field, arg ir.Node) {
e.rewriteArgument(arg, call, fn)
argument(e.tagHole(ks, fn, param), arg)
}
if call.IsCompilerVarLive {
// Don't escape compiler-inserted KeepAlive.
argumentParam = func(param *types.Field, arg ir.Node) {
argument(e.discardHole(), arg)
}
}
fntype := call.Fun.Type() fntype := call.Fun.Type()
if fn != nil { if fn != nil {
fntype = fn.Type() fntype = fn.Type()
@ -77,13 +91,6 @@ func (e *escape) call(ks []hole, call ir.Node) {
recvArg = call.Fun.(*ir.SelectorExpr).X recvArg = call.Fun.(*ir.SelectorExpr).X
} }
// argumentParam handles escape analysis of assigning a call
// argument to its corresponding parameter.
argumentParam := func(param *types.Field, arg ir.Node) {
e.rewriteArgument(arg, call, fn)
argument(e.tagHole(ks, fn, param), arg)
}
// internal/abi.EscapeNonString forces its argument to be on // internal/abi.EscapeNonString forces its argument to be on
// the heap, if it contains a non-string pointer. // the heap, if it contains a non-string pointer.
// This is used in hash/maphash.Comparable, where we cannot // This is used in hash/maphash.Comparable, where we cannot


@ -8,6 +8,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"cmd/compile/internal/base" "cmd/compile/internal/base"
"cmd/compile/internal/bloop"
"cmd/compile/internal/coverage" "cmd/compile/internal/coverage"
"cmd/compile/internal/deadlocals" "cmd/compile/internal/deadlocals"
"cmd/compile/internal/dwarfgen" "cmd/compile/internal/dwarfgen"
@ -239,6 +240,9 @@ func Main(archInit func(*ssagen.ArchInfo)) {
} }
} }
// Apply bloop markings.
bloop.BloopWalk(typecheck.Target)
// Interleaved devirtualization and inlining. // Interleaved devirtualization and inlining.
base.Timer.Start("fe", "devirtualize-and-inline") base.Timer.Start("fe", "devirtualize-and-inline")
interleaved.DevirtualizeAndInlinePackage(typecheck.Target, profile) interleaved.DevirtualizeAndInlinePackage(typecheck.Target, profile)


@ -254,28 +254,6 @@ func (s *inlClosureState) mark(n ir.Node) ir.Node {
return n // already visited n.X before wrapping return n // already visited n.X before wrapping
} }
if isTestingBLoop(n) {
// No inlining nor devirtualization performed on b.Loop body
if base.Flag.LowerM > 0 {
fmt.Printf("%v: skip inlining within testing.B.loop for %v\n", ir.Line(n), n)
}
// We still want to explore inlining opportunities in other parts of ForStmt.
nFor, _ := n.(*ir.ForStmt)
nForInit := nFor.Init()
for i, x := range nForInit {
if x != nil {
nForInit[i] = s.mark(x)
}
}
if nFor.Cond != nil {
nFor.Cond = s.mark(nFor.Cond)
}
if nFor.Post != nil {
nFor.Post = s.mark(nFor.Post)
}
return n
}
if p != nil { if p != nil {
n = p.X // in this case p was copied in from a (marked) inlined function, this is a new unvisited node. n = p.X // in this case p was copied in from a (marked) inlined function, this is a new unvisited node.
} }
@ -371,29 +349,3 @@ func match(n ir.Node) bool {
} }
return false return false
} }
// isTestingBLoop returns true if it matches the node as a
// testing.(*B).Loop. See issue #61515.
func isTestingBLoop(t ir.Node) bool {
if t.Op() != ir.OFOR {
return false
}
nFor, ok := t.(*ir.ForStmt)
if !ok || nFor.Cond == nil || nFor.Cond.Op() != ir.OCALLFUNC {
return false
}
n, ok := nFor.Cond.(*ir.CallExpr)
if !ok || n.Fun == nil || n.Fun.Op() != ir.OMETHEXPR {
return false
}
name := ir.MethodExprName(n.Fun)
if name == nil {
return false
}
if fSym := name.Sym(); fSym != nil && name.Class == ir.PFUNC && fSym.Pkg != nil &&
fSym.Name == "(*B).Loop" && fSym.Pkg.Path == "testing" {
// Attempting to match a function call to testing.(*B).Loop
return true
}
return false
}


@ -193,6 +193,9 @@ type CallExpr struct {
GoDefer bool // whether this call is part of a go or defer statement GoDefer bool // whether this call is part of a go or defer statement
NoInline bool // whether this call must not be inlined NoInline bool // whether this call must not be inlined
UseBuf bool // use stack buffer for backing store (OAPPEND only) UseBuf bool // use stack buffer for backing store (OAPPEND only)
// whether it's a runtime.KeepAlive call the compiler generates to
// keep a variable alive. See #73137.
IsCompilerVarLive bool
} }
func NewCallExpr(pos src.XPos, op Op, fun Node, args []Node) *CallExpr { func NewCallExpr(pos src.XPos, op Op, fun Node, args []Node) *CallExpr {
@ -681,7 +684,7 @@ type TypeAssertExpr struct {
// When set to true, if this assert would panic, then use a nil pointer panic // When set to true, if this assert would panic, then use a nil pointer panic
// instead of an interface conversion panic. // instead of an interface conversion panic.
// It must not be set for type asserts using the commaok form. // It must not be set for type assertions using the commaok form.
UseNilPanic bool UseNilPanic bool
} }


@ -187,6 +187,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
ssa.OpLOONG64DIVD, ssa.OpLOONG64DIVD,
ssa.OpLOONG64MULV, ssa.OpLOONG64MULHV, ssa.OpLOONG64MULHVU, ssa.OpLOONG64MULH, ssa.OpLOONG64MULHU, ssa.OpLOONG64MULV, ssa.OpLOONG64MULHV, ssa.OpLOONG64MULHVU, ssa.OpLOONG64MULH, ssa.OpLOONG64MULHU,
ssa.OpLOONG64DIVV, ssa.OpLOONG64REMV, ssa.OpLOONG64DIVVU, ssa.OpLOONG64REMVU, ssa.OpLOONG64DIVV, ssa.OpLOONG64REMV, ssa.OpLOONG64DIVVU, ssa.OpLOONG64REMVU,
ssa.OpLOONG64MULWVW, ssa.OpLOONG64MULWVWU,
ssa.OpLOONG64FCOPYSGD: ssa.OpLOONG64FCOPYSGD:
p := s.Prog(v.Op.Asm()) p := s.Prog(v.Op.Asm())
p.From.Type = obj.TYPE_REG p.From.Type = obj.TYPE_REG


@ -517,15 +517,15 @@ func init() {
// If the condition 'Cond' evaluates to true against current flags, // If the condition 'Cond' evaluates to true against current flags,
// flags are set to the result of the comparison operation. // flags are set to the result of the comparison operation.
// Otherwise, flags are set to the fallback value 'Nzcv'. // Otherwise, flags are set to the fallback value 'Nzcv'.
{name: "CCMP", argLength: 3, reg: gp2flagsflags, asm: "CCMP", aux: "ARM64ConditionalParams", typ: "Flag"}, // If Cond then flags = CMP arg0 arg1 else flags = Nzcv {name: "CCMP", argLength: 3, reg: gp2flagsflags, asm: "CCMP", aux: "ARM64ConditionalParams", typ: "Flags"}, // If Cond then flags = CMP arg0 arg1 else flags = Nzcv
{name: "CCMN", argLength: 3, reg: gp2flagsflags, asm: "CCMN", aux: "ARM64ConditionalParams", typ: "Flag"}, // If Cond then flags = CMN arg0 arg1 else flags = Nzcv {name: "CCMN", argLength: 3, reg: gp2flagsflags, asm: "CCMN", aux: "ARM64ConditionalParams", typ: "Flags"}, // If Cond then flags = CMN arg0 arg1 else flags = Nzcv
{name: "CCMPconst", argLength: 2, reg: gp1flagsflags, asm: "CCMP", aux: "ARM64ConditionalParams", typ: "Flag"}, // If Cond then flags = CMPconst [ConstValue] arg0 else flags = Nzcv {name: "CCMPconst", argLength: 2, reg: gp1flagsflags, asm: "CCMP", aux: "ARM64ConditionalParams", typ: "Flags"}, // If Cond then flags = CMPconst [ConstValue] arg0 else flags = Nzcv
{name: "CCMNconst", argLength: 2, reg: gp1flagsflags, asm: "CCMN", aux: "ARM64ConditionalParams", typ: "Flag"}, // If Cond then flags = CMNconst [ConstValue] arg0 else flags = Nzcv {name: "CCMNconst", argLength: 2, reg: gp1flagsflags, asm: "CCMN", aux: "ARM64ConditionalParams", typ: "Flags"}, // If Cond then flags = CMNconst [ConstValue] arg0 else flags = Nzcv
{name: "CCMPW", argLength: 3, reg: gp2flagsflags, asm: "CCMPW", aux: "ARM64ConditionalParams", typ: "Flag"}, // If Cond then flags = CMPW arg0 arg1 else flags = Nzcv {name: "CCMPW", argLength: 3, reg: gp2flagsflags, asm: "CCMPW", aux: "ARM64ConditionalParams", typ: "Flags"}, // If Cond then flags = CMPW arg0 arg1 else flags = Nzcv
{name: "CCMNW", argLength: 3, reg: gp2flagsflags, asm: "CCMNW", aux: "ARM64ConditionalParams", typ: "Flag"}, // If Cond then flags = CMNW arg0 arg1 else flags = Nzcv {name: "CCMNW", argLength: 3, reg: gp2flagsflags, asm: "CCMNW", aux: "ARM64ConditionalParams", typ: "Flags"}, // If Cond then flags = CMNW arg0 arg1 else flags = Nzcv
{name: "CCMPWconst", argLength: 2, reg: gp1flagsflags, asm: "CCMPW", aux: "ARM64ConditionalParams", typ: "Flag"}, // If Cond then flags = CCMPWconst [ConstValue] arg0 else flags = Nzcv {name: "CCMPWconst", argLength: 2, reg: gp1flagsflags, asm: "CCMPW", aux: "ARM64ConditionalParams", typ: "Flags"}, // If Cond then flags = CCMPWconst [ConstValue] arg0 else flags = Nzcv
{name: "CCMNWconst", argLength: 2, reg: gp1flagsflags, asm: "CCMNW", aux: "ARM64ConditionalParams", typ: "Flag"}, // If Cond then flags = CCMNWconst [ConstValue] arg0 else flags = Nzcv {name: "CCMNWconst", argLength: 2, reg: gp1flagsflags, asm: "CCMNW", aux: "ARM64ConditionalParams", typ: "Flags"}, // If Cond then flags = CCMNWconst [ConstValue] arg0 else flags = Nzcv
// function calls // function calls
{name: "CALLstatic", argLength: -1, reg: regInfo{clobbers: callerSave}, aux: "CallOff", clobberFlags: true, call: true}, // call static function aux.(*obj.LSym). last arg=mem, auxint=argsize, returns mem {name: "CALLstatic", argLength: -1, reg: regInfo{clobbers: callerSave}, aux: "CallOff", clobberFlags: true, call: true}, // call static function aux.(*obj.LSym). last arg=mem, auxint=argsize, returns mem


@ -15,6 +15,10 @@
(Select0 (Mul64uover x y)) => (MULV x y) (Select0 (Mul64uover x y)) => (MULV x y)
(Select1 (Mul64uover x y)) => (SGTU <typ.Bool> (MULHVU x y) (MOVVconst <typ.UInt64> [0])) (Select1 (Mul64uover x y)) => (SGTU <typ.Bool> (MULHVU x y) (MOVVconst <typ.UInt64> [0]))
// 32 mul 32 -> 64
(MULV r:(MOVWUreg x) s:(MOVWUreg y)) && r.Uses == 1 && s.Uses == 1 => (MULWVWU x y)
(MULV r:(MOVWreg x) s:(MOVWreg y)) && r.Uses == 1 && s.Uses == 1 => (MULWVW x y)
(Hmul64 ...) => (MULHV ...) (Hmul64 ...) => (MULHV ...)
(Hmul64u ...) => (MULHVU ...) (Hmul64u ...) => (MULHVU ...)
(Hmul32 ...) => (MULH ...) (Hmul32 ...) => (MULH ...)
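
A small Go example (not from the change) of the pattern these rules target; on loong64 the two 32-bit extensions feeding the 64-bit multiply are now fused into a single MULWVW (or MULWVWU for the analogous uint32 version):

	package p

	// mul32x32 multiplies two 32-bit values into a 64-bit product.
	// With this change, loong64 compiles the body to MULWVW instead of
	// two sign extensions followed by MULV.
	func mul32x32(a, b int32) int64 {
		return int64(a) * int64(b)
	}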


@ -205,6 +205,8 @@ func init() {
{name: "DIVVU", argLength: 2, reg: gp21, asm: "DIVVU", typ: "UInt64"}, // arg0 / arg1, unsigned {name: "DIVVU", argLength: 2, reg: gp21, asm: "DIVVU", typ: "UInt64"}, // arg0 / arg1, unsigned
{name: "REMV", argLength: 2, reg: gp21, asm: "REMV", typ: "Int64"}, // arg0 / arg1, signed {name: "REMV", argLength: 2, reg: gp21, asm: "REMV", typ: "Int64"}, // arg0 / arg1, signed
{name: "REMVU", argLength: 2, reg: gp21, asm: "REMVU", typ: "UInt64"}, // arg0 / arg1, unsigned {name: "REMVU", argLength: 2, reg: gp21, asm: "REMVU", typ: "UInt64"}, // arg0 / arg1, unsigned
{name: "MULWVW", argLength: 2, reg: gp21, asm: "MULWVW", commutative: true}, // arg0 * arg1, signed, 32-bit mult results in 64-bit
{name: "MULWVWU", argLength: 2, reg: gp21, asm: "MULWVWU", commutative: true}, // arg0 * arg1, unsigned, 32-bit mult results in 64-bit
{name: "ADDF", argLength: 2, reg: fp21, asm: "ADDF", commutative: true}, // arg0 + arg1 {name: "ADDF", argLength: 2, reg: fp21, asm: "ADDF", commutative: true}, // arg0 + arg1
{name: "ADDD", argLength: 2, reg: fp21, asm: "ADDD", commutative: true}, // arg0 + arg1 {name: "ADDD", argLength: 2, reg: fp21, asm: "ADDD", commutative: true}, // arg0 + arg1


@ -200,6 +200,10 @@
(Mul(8|16|32|64) (Neg(8|16|32|64) x) (Neg(8|16|32|64) y)) => (Mul(8|16|32|64) x y) (Mul(8|16|32|64) (Neg(8|16|32|64) x) (Neg(8|16|32|64) y)) => (Mul(8|16|32|64) x y)
// simplify negative on mul if possible
(Neg(8|16|32|64) (Mul(8|16|32|64) x (Const(8|16|32|64) <t> [c]))) => (Mul(8|16|32|64) x (Const(8|16|32|64) <t> [-c]))
(Neg(8|16|32|64) (Mul(8|16|32|64) x (Neg(8|16|32|64) y))) => (Mul(8|16|32|64) x y)
// DeMorgan's Laws // DeMorgan's Laws
(And(8|16|32|64) <t> (Com(8|16|32|64) x) (Com(8|16|32|64) y)) => (Com(8|16|32|64) (Or(8|16|32|64) <t> x y)) (And(8|16|32|64) <t> (Com(8|16|32|64) x) (Com(8|16|32|64) y)) => (Com(8|16|32|64) (Or(8|16|32|64) <t> x y))
(Or(8|16|32|64) <t> (Com(8|16|32|64) x) (Com(8|16|32|64) y)) => (Com(8|16|32|64) (And(8|16|32|64) <t> x y)) (Or(8|16|32|64) <t> (Com(8|16|32|64) x) (Com(8|16|32|64) y)) => (Com(8|16|32|64) (And(8|16|32|64) <t> x y))
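
A hedged Go-level illustration (not part of the change) of what the two new negation rules above do:

	package p

	// With the rules above, the negation is folded into the multiply:
	// negMulConst becomes x * (-3), and negMulNeg becomes x * y.
	func negMulConst(x int64) int64  { return -(x * 3) }
	func negMulNeg(x, y int64) int64 { return -(x * -y) }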


@ -10,6 +10,10 @@ import (
"cmd/internal/obj" "cmd/internal/obj"
) )
// maxShadowRanges bounds the number of disjoint byte intervals
// we track per pointer to avoid quadratic behaviour.
const maxShadowRanges = 64
// dse does dead-store elimination on the Function. // dse does dead-store elimination on the Function.
// Dead stores are those which are unconditionally followed by // Dead stores are those which are unconditionally followed by
// another store to the same location, with no intervening load. // another store to the same location, with no intervening load.
@ -24,6 +28,10 @@ func dse(f *Func) {
defer f.retSparseMap(shadowed) defer f.retSparseMap(shadowed)
// localAddrs maps from a local variable (the Aux field of a LocalAddr value) to an instance of a LocalAddr value for that variable in the current block. // localAddrs maps from a local variable (the Aux field of a LocalAddr value) to an instance of a LocalAddr value for that variable in the current block.
localAddrs := map[any]*Value{} localAddrs := map[any]*Value{}
// shadowedRanges stores the actual range data. The 'shadowed' sparseMap stores a 1-based index into this slice.
var shadowedRanges []*shadowRanges
for _, b := range f.Blocks { for _, b := range f.Blocks {
// Find all the stores in this block. Categorize their uses: // Find all the stores in this block. Categorize their uses:
// loadUse contains stores which are used by a subsequent load. // loadUse contains stores which are used by a subsequent load.
@ -89,10 +97,11 @@ func dse(f *Func) {
// Walk backwards looking for dead stores. Keep track of shadowed addresses. // Walk backwards looking for dead stores. Keep track of shadowed addresses.
// A "shadowed address" is a pointer, offset, and size describing a memory region that // A "shadowed address" is a pointer, offset, and size describing a memory region that
// is known to be written. We keep track of shadowed addresses in the shadowed map, // is known to be written. We keep track of shadowed addresses in the shadowed map,
// mapping the ID of the address to a shadowRange where future writes will happen. // mapping the ID of the address to a shadowRanges where future writes will happen.
// Since we're walking backwards, writes to a shadowed region are useless, // Since we're walking backwards, writes to a shadowed region are useless,
// as they will be immediately overwritten. // as they will be immediately overwritten.
shadowed.clear() shadowed.clear()
shadowedRanges = shadowedRanges[:0]
v := last v := last
walkloop: walkloop:
@ -100,6 +109,7 @@ func dse(f *Func) {
// Someone might be reading this memory state. // Someone might be reading this memory state.
// Clear all shadowed addresses. // Clear all shadowed addresses.
shadowed.clear() shadowed.clear()
shadowedRanges = shadowedRanges[:0]
} }
if v.Op == OpStore || v.Op == OpZero { if v.Op == OpStore || v.Op == OpZero {
ptr := v.Args[0] ptr := v.Args[0]
@ -119,9 +129,14 @@ func dse(f *Func) {
ptr = la ptr = la
} }
} }
srNum, _ := shadowed.get(ptr.ID) var si *shadowRanges
sr := shadowRange(srNum) idx, ok := shadowed.get(ptr.ID)
if sr.contains(off, off+sz) { if ok {
// The sparseMap stores a 1-based index, so we subtract 1.
si = shadowedRanges[idx-1]
}
if si != nil && si.contains(off, off+sz) {
// Modify the store/zero into a copy of the memory state, // Modify the store/zero into a copy of the memory state,
// effectively eliding the store operation. // effectively eliding the store operation.
if v.Op == OpStore { if v.Op == OpStore {
@ -136,7 +151,13 @@ func dse(f *Func) {
v.Op = OpCopy v.Op = OpCopy
} else { } else {
// Extend shadowed region. // Extend shadowed region.
shadowed.set(ptr.ID, int32(sr.merge(off, off+sz))) if si == nil {
si = &shadowRanges{}
shadowedRanges = append(shadowedRanges, si)
// Store a 1-based index in the sparseMap.
shadowed.set(ptr.ID, int32(len(shadowedRanges)))
}
si.add(off, off+sz)
} }
} }
// walk to previous store // walk to previous store
@ -156,46 +177,51 @@ func dse(f *Func) {
} }
} }
// A shadowRange encodes a set of byte offsets [lo():hi()] from // shadowRange represents a single byte range [lo,hi] that will be written.
// a given pointer that will be written to later in the block. type shadowRange struct {
// A zero shadowRange encodes an empty shadowed range. lo, hi uint16
type shadowRange int32
func (sr shadowRange) lo() int64 {
return int64(sr & 0xffff)
} }
func (sr shadowRange) hi() int64 { // shadowRanges stores an unordered collection of disjoint byte ranges.
return int64((sr >> 16) & 0xffff) type shadowRanges struct {
ranges []shadowRange
} }
// contains reports whether [lo:hi] is completely within sr. // contains reports whether [lo:hi] is completely within sr.
func (sr shadowRange) contains(lo, hi int64) bool { func (sr *shadowRanges) contains(lo, hi int64) bool {
return lo >= sr.lo() && hi <= sr.hi() for _, r := range sr.ranges {
if lo >= int64(r.lo) && hi <= int64(r.hi) {
return true
}
}
return false
} }
// merge returns the union of sr and [lo:hi]. func (sr *shadowRanges) add(lo, hi int64) {
// merge is allowed to return something smaller than the union. // Ignore the store if:
func (sr shadowRange) merge(lo, hi int64) shadowRange { // - the range doesn't fit in 16 bits, or
if lo < 0 || hi > 0xffff { // - we already track maxShadowRanges intervals.
// Ignore offsets that are too large or small. // The cap prevents a theoretical O(n^2) blow-up.
return sr if lo < 0 || hi > 0xffff || len(sr.ranges) >= maxShadowRanges {
return
} }
if sr.lo() == sr.hi() { nlo := lo
// Old range is empty - use new one. nhi := hi
return shadowRange(lo + hi<<16) out := sr.ranges[:0]
for _, r := range sr.ranges {
if nhi < int64(r.lo) || nlo > int64(r.hi) {
out = append(out, r)
continue
} }
if hi < sr.lo() || lo > sr.hi() { if int64(r.lo) < nlo {
// The two regions don't overlap or abut, so we would nlo = int64(r.lo)
// have to keep track of multiple disjoint ranges.
// Because we can only keep one, keep the larger one.
if sr.hi()-sr.lo() >= hi-lo {
return sr
} }
return shadowRange(lo + hi<<16) if int64(r.hi) > nhi {
nhi = int64(r.hi)
} }
// Regions overlap or abut - compute the union. }
return shadowRange(min(lo, sr.lo()) + max(hi, sr.hi())<<16) sr.ranges = append(out, shadowRange{uint16(nlo), uint16(nhi)})
} }
// elimDeadAutosGeneric deletes autos that are never accessed. To achieve this // elimDeadAutosGeneric deletes autos that are never accessed. To achieve this


@ -7,6 +7,8 @@ package ssa
import ( import (
"cmd/compile/internal/types" "cmd/compile/internal/types"
"cmd/internal/src" "cmd/internal/src"
"fmt"
"sort"
"testing" "testing"
) )
@ -172,3 +174,335 @@ func TestDeadStoreSmallStructInit(t *testing.T) {
t.Errorf("dead store not removed") t.Errorf("dead store not removed")
} }
} }
func TestDeadStoreArrayGap(t *testing.T) {
c := testConfig(t)
ptr := c.config.Types.BytePtr
i64 := c.config.Types.Int64
typ := types.NewArray(i64, 5)
tmp := c.Temp(typ)
fun := c.Fun("entry",
Bloc("entry",
Valu("start", OpInitMem, types.TypeMem, 0, nil),
Valu("sp", OpSP, c.config.Types.Uintptr, 0, nil),
Valu("base", OpLocalAddr, ptr, 0, tmp, "sp", "start"),
Valu("p0", OpOffPtr, ptr, 0, nil, "base"),
Valu("p1", OpOffPtr, ptr, 8, nil, "base"),
Valu("p2", OpOffPtr, ptr, 16, nil, "base"),
Valu("p3", OpOffPtr, ptr, 24, nil, "base"),
Valu("p4", OpOffPtr, ptr, 32, nil, "base"),
Valu("one", OpConst64, i64, 1, nil),
Valu("seven", OpConst64, i64, 7, nil),
Valu("zero", OpConst64, i64, 0, nil),
Valu("mem0", OpZero, types.TypeMem, 40, typ, "base", "start"),
Valu("s0", OpStore, types.TypeMem, 0, i64, "p0", "one", "mem0"),
Valu("s1", OpStore, types.TypeMem, 0, i64, "p1", "seven", "s0"),
Valu("s2", OpStore, types.TypeMem, 0, i64, "p3", "one", "s1"),
Valu("s3", OpStore, types.TypeMem, 0, i64, "p4", "one", "s2"),
Valu("s4", OpStore, types.TypeMem, 0, i64, "p2", "zero", "s3"),
Goto("exit")),
Bloc("exit",
Exit("s4")))
CheckFunc(fun.f)
dse(fun.f)
CheckFunc(fun.f)
if op := fun.values["mem0"].Op; op != OpCopy {
t.Fatalf("dead Zero not removed: got %s, want OpCopy", op)
}
}
func TestShadowRanges(t *testing.T) {
t.Run("simple insert & contains", func(t *testing.T) {
var sr shadowRanges
sr.add(10, 20)
wantRanges(t, sr.ranges, [][2]uint16{{10, 20}})
if !sr.contains(12, 18) || !sr.contains(10, 20) {
t.Fatalf("contains failed after simple add")
}
if sr.contains(9, 11) || sr.contains(11, 21) {
t.Fatalf("contains erroneously true for non-contained range")
}
})
t.Run("merge overlapping", func(t *testing.T) {
var sr shadowRanges
sr.add(10, 20)
sr.add(15, 25)
wantRanges(t, sr.ranges, [][2]uint16{{10, 25}})
if !sr.contains(13, 24) {
t.Fatalf("contains should be true after merge")
}
})
t.Run("merge touching boundary", func(t *testing.T) {
var sr shadowRanges
sr.add(100, 150)
// touches at 150 - should coalesce
sr.add(150, 180)
wantRanges(t, sr.ranges, [][2]uint16{{100, 180}})
})
t.Run("union across several ranges", func(t *testing.T) {
var sr shadowRanges
sr.add(10, 20)
sr.add(30, 40)
// bridges second, not first
sr.add(25, 35)
wantRanges(t, sr.ranges, [][2]uint16{{10, 20}, {25, 40}})
// envelops everything
sr.add(5, 50)
wantRanges(t, sr.ranges, [][2]uint16{{5, 50}})
})
t.Run("disjoint intervals stay separate", func(t *testing.T) {
var sr shadowRanges
sr.add(10, 20)
sr.add(22, 30)
wantRanges(t, sr.ranges, [][2]uint16{{10, 20}, {22, 30}})
// spans both
if sr.contains(15, 25) {
t.Fatalf("contains across two disjoint ranges should be false")
}
})
t.Run("large uint16 offsets still work", func(t *testing.T) {
var sr shadowRanges
sr.add(40000, 45000)
if !sr.contains(42000, 43000) {
t.Fatalf("contains failed for large uint16 values")
}
})
t.Run("out-of-bounds inserts ignored", func(t *testing.T) {
var sr shadowRanges
sr.add(10, 20)
sr.add(-5, 5)
sr.add(70000, 70010)
wantRanges(t, sr.ranges, [][2]uint16{{10, 20}})
})
}
// canonicalise order for comparisons
func sortRanges(r []shadowRange) {
sort.Slice(r, func(i, j int) bool { return r[i].lo < r[j].lo })
}
// compare actual slice with expected pairs
func wantRanges(t *testing.T, got []shadowRange, want [][2]uint16) {
t.Helper()
sortRanges(got)
if len(got) != len(want) {
t.Fatalf("len(ranges)=%d, want %d (got=%v)", len(got), len(want), got)
}
for i, w := range want {
if got[i].lo != w[0] || got[i].hi != w[1] {
t.Fatalf("range %d = [%d,%d], want [%d,%d] (full=%v)",
i, got[i].lo, got[i].hi, w[0], w[1], got)
}
}
}
func BenchmarkDeadStore(b *testing.B) {
cfg := testConfig(b)
ptr := cfg.config.Types.BytePtr
f := cfg.Fun("entry",
Bloc("entry",
Valu("start", OpInitMem, types.TypeMem, 0, nil),
Valu("sb", OpSB, cfg.config.Types.Uintptr, 0, nil),
Valu("v", OpConstBool, cfg.config.Types.Bool, 1, nil),
Valu("a1", OpAddr, ptr, 0, nil, "sb"),
Valu("a2", OpAddr, ptr, 0, nil, "sb"),
Valu("a3", OpAddr, ptr, 0, nil, "sb"),
Valu("z1", OpZero, types.TypeMem, 1, cfg.config.Types.Bool, "a3", "start"),
Valu("s1", OpStore, types.TypeMem, 0, cfg.config.Types.Bool, "a1", "v", "z1"),
Valu("s2", OpStore, types.TypeMem, 0, cfg.config.Types.Bool, "a2", "v", "s1"),
Valu("s3", OpStore, types.TypeMem, 0, cfg.config.Types.Bool, "a1", "v", "s2"),
Valu("s4", OpStore, types.TypeMem, 0, cfg.config.Types.Bool, "a3", "v", "s3"),
Goto("exit")),
Bloc("exit",
Exit("s3")))
runBench(b, func() {
dse(f.f)
})
}
func BenchmarkDeadStorePhi(b *testing.B) {
cfg := testConfig(b)
ptr := cfg.config.Types.BytePtr
f := cfg.Fun("entry",
Bloc("entry",
Valu("start", OpInitMem, types.TypeMem, 0, nil),
Valu("sb", OpSB, cfg.config.Types.Uintptr, 0, nil),
Valu("v", OpConstBool, cfg.config.Types.Bool, 1, nil),
Valu("addr", OpAddr, ptr, 0, nil, "sb"),
Goto("loop")),
Bloc("loop",
Valu("phi", OpPhi, types.TypeMem, 0, nil, "start", "store"),
Valu("store", OpStore, types.TypeMem, 0, cfg.config.Types.Bool, "addr", "v", "phi"),
If("v", "loop", "exit")),
Bloc("exit",
Exit("store")))
runBench(b, func() {
dse(f.f)
})
}
func BenchmarkDeadStoreTypes(b *testing.B) {
cfg := testConfig(b)
t1 := cfg.config.Types.UInt64.PtrTo()
t2 := cfg.config.Types.UInt32.PtrTo()
f := cfg.Fun("entry",
Bloc("entry",
Valu("start", OpInitMem, types.TypeMem, 0, nil),
Valu("sb", OpSB, cfg.config.Types.Uintptr, 0, nil),
Valu("v", OpConstBool, cfg.config.Types.Bool, 1, nil),
Valu("a1", OpAddr, t1, 0, nil, "sb"),
Valu("a2", OpAddr, t2, 0, nil, "sb"),
Valu("s1", OpStore, types.TypeMem, 0, cfg.config.Types.Bool, "a1", "v", "start"),
Valu("s2", OpStore, types.TypeMem, 0, cfg.config.Types.Bool, "a2", "v", "s1"),
Goto("exit")),
Bloc("exit",
Exit("s2")))
cse(f.f)
runBench(b, func() {
dse(f.f)
})
}
func BenchmarkDeadStoreUnsafe(b *testing.B) {
cfg := testConfig(b)
ptr := cfg.config.Types.UInt64.PtrTo()
f := cfg.Fun("entry",
Bloc("entry",
Valu("start", OpInitMem, types.TypeMem, 0, nil),
Valu("sb", OpSB, cfg.config.Types.Uintptr, 0, nil),
Valu("v", OpConstBool, cfg.config.Types.Bool, 1, nil),
Valu("a1", OpAddr, ptr, 0, nil, "sb"),
Valu("s1", OpStore, types.TypeMem, 0, cfg.config.Types.Int64, "a1", "v", "start"),
Valu("s2", OpStore, types.TypeMem, 0, cfg.config.Types.Bool, "a1", "v", "s1"),
Goto("exit")),
Bloc("exit",
Exit("s2")))
cse(f.f)
runBench(b, func() {
dse(f.f)
})
}
func BenchmarkDeadStoreSmallStructInit(b *testing.B) {
cfg := testConfig(b)
ptr := cfg.config.Types.BytePtr
typ := types.NewStruct([]*types.Field{
types.NewField(src.NoXPos, &types.Sym{Name: "A"}, cfg.config.Types.Int),
types.NewField(src.NoXPos, &types.Sym{Name: "B"}, cfg.config.Types.Int),
})
tmp := cfg.Temp(typ)
f := cfg.Fun("entry",
Bloc("entry",
Valu("start", OpInitMem, types.TypeMem, 0, nil),
Valu("sp", OpSP, cfg.config.Types.Uintptr, 0, nil),
Valu("zero", OpConst64, cfg.config.Types.Int, 0, nil),
Valu("v6", OpLocalAddr, ptr, 0, tmp, "sp", "start"),
Valu("v3", OpOffPtr, ptr, 8, nil, "v6"),
Valu("v22", OpOffPtr, ptr, 0, nil, "v6"),
Valu("s1", OpStore, types.TypeMem, 0, cfg.config.Types.Int, "v22", "zero", "start"),
Valu("s2", OpStore, types.TypeMem, 0, cfg.config.Types.Int, "v3", "zero", "s1"),
Valu("v8", OpLocalAddr, ptr, 0, tmp, "sp", "s2"),
Valu("v23", OpOffPtr, ptr, 8, nil, "v8"),
Valu("v25", OpOffPtr, ptr, 0, nil, "v8"),
Valu("s3", OpStore, types.TypeMem, 0, cfg.config.Types.Int, "v25", "zero", "s2"),
Valu("s4", OpStore, types.TypeMem, 0, cfg.config.Types.Int, "v23", "zero", "s3"),
Goto("exit")),
Bloc("exit",
Exit("s4")))
cse(f.f)
runBench(b, func() {
dse(f.f)
})
}
func BenchmarkDeadStoreLargeBlock(b *testing.B) {
// create a very large block with many shadowed stores
const (
addrCount = 128
// first 7 are dead
storesPerAddr = 8
)
cfg := testConfig(b)
ptrType := cfg.config.Types.BytePtr
boolType := cfg.config.Types.Bool
items := []interface{}{
Valu("start", OpInitMem, types.TypeMem, 0, nil),
Valu("sb", OpSB, cfg.config.Types.Uintptr, 0, nil),
Valu("v", OpConstBool, boolType, 1, nil),
}
for i := 0; i < addrCount; i++ {
items = append(items,
Valu(fmt.Sprintf("addr%d", i), OpAddr, ptrType, 0, nil, "sb"),
)
}
prev := "start"
for round := 0; round < storesPerAddr; round++ {
for i := 0; i < addrCount; i++ {
store := fmt.Sprintf("s_%03d_%d", i, round)
addr := fmt.Sprintf("addr%d", i)
items = append(items,
Valu(store, OpStore, types.TypeMem, 0, boolType, addr, "v", prev),
)
prev = store
}
}
items = append(items, Goto("exit"))
entryBlk := Bloc("entry", items...)
exitBlk := Bloc("exit", Exit(prev))
f := cfg.Fun("stress", entryBlk, exitBlk)
runBench(b, func() {
dse(f.f)
})
}
func runBench(b *testing.B, build func()) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
build()
}
}


@ -4306,6 +4306,8 @@ const (
OpLOONG64DIVVU OpLOONG64DIVVU
OpLOONG64REMV OpLOONG64REMV
OpLOONG64REMVU OpLOONG64REMVU
OpLOONG64MULWVW
OpLOONG64MULWVWU
OpLOONG64ADDF OpLOONG64ADDF
OpLOONG64ADDD OpLOONG64ADDD
OpLOONG64SUBF OpLOONG64SUBF
@ -66871,6 +66873,36 @@ var opcodeTable = [...]opInfo{
}, },
}, },
}, },
{
name: "MULWVW",
argLen: 2,
commutative: true,
asm: loong64.AMULWVW,
reg: regInfo{
inputs: []inputInfo{
{0, 1073741816}, // R4 R5 R6 R7 R8 R9 R10 R11 R12 R13 R14 R15 R16 R17 R18 R19 R20 R21 g R23 R24 R25 R26 R27 R28 R29 R31
{1, 1073741817}, // ZERO R4 R5 R6 R7 R8 R9 R10 R11 R12 R13 R14 R15 R16 R17 R18 R19 R20 R21 g R23 R24 R25 R26 R27 R28 R29 R31
},
outputs: []outputInfo{
{0, 1071644664}, // R4 R5 R6 R7 R8 R9 R10 R11 R12 R13 R14 R15 R16 R17 R18 R19 R20 R21 R23 R24 R25 R26 R27 R28 R29 R31
},
},
},
{
name: "MULWVWU",
argLen: 2,
commutative: true,
asm: loong64.AMULWVWU,
reg: regInfo{
inputs: []inputInfo{
{0, 1073741816}, // R4 R5 R6 R7 R8 R9 R10 R11 R12 R13 R14 R15 R16 R17 R18 R19 R20 R21 g R23 R24 R25 R26 R27 R28 R29 R31
{1, 1073741817}, // ZERO R4 R5 R6 R7 R8 R9 R10 R11 R12 R13 R14 R15 R16 R17 R18 R19 R20 R21 g R23 R24 R25 R26 R27 R28 R29 R31
},
outputs: []outputInfo{
{0, 1071644664}, // R4 R5 R6 R7 R8 R9 R10 R11 R12 R13 R14 R15 R16 R17 R18 R19 R20 R21 R23 R24 R25 R26 R27 R28 R29 R31
},
},
},
{ {
name: "ADDF", name: "ADDF",
argLen: 2, argLen: 2,


@ -2051,8 +2051,11 @@ func (ft *factsTable) detectSliceLenRelation(v *Value) {
return return
} }
slice := v.Args[0].Args[0]
index := v.Args[1] index := v.Args[1]
if !ft.isNonNegative(index) {
return
}
slice := v.Args[0].Args[0]
for o := ft.orderings[index.ID]; o != nil; o = o.next { for o := ft.orderings[index.ID]; o != nil; o = o.next {
if o.d != signed { if o.d != signed {
@ -2471,9 +2474,18 @@ func addLocalFacts(ft *factsTable, b *Block) {
//ft.update(b, v, v.Args[0], unsigned, gt|eq) //ft.update(b, v, v.Args[0], unsigned, gt|eq)
//ft.update(b, v, v.Args[1], unsigned, gt|eq) //ft.update(b, v, v.Args[1], unsigned, gt|eq)
case OpDiv64, OpDiv32, OpDiv16, OpDiv8: case OpDiv64, OpDiv32, OpDiv16, OpDiv8:
if ft.isNonNegative(v.Args[0]) && ft.isNonNegative(v.Args[1]) { if !ft.isNonNegative(v.Args[1]) {
ft.update(b, v, v.Args[0], unsigned, lt|eq) break
} }
fallthrough
case OpRsh8x64, OpRsh8x32, OpRsh8x16, OpRsh8x8,
OpRsh16x64, OpRsh16x32, OpRsh16x16, OpRsh16x8,
OpRsh32x64, OpRsh32x32, OpRsh32x16, OpRsh32x8,
OpRsh64x64, OpRsh64x32, OpRsh64x16, OpRsh64x8:
if !ft.isNonNegative(v.Args[0]) {
break
}
fallthrough
case OpDiv64u, OpDiv32u, OpDiv16u, OpDiv8u, case OpDiv64u, OpDiv32u, OpDiv16u, OpDiv8u,
OpRsh8Ux64, OpRsh8Ux32, OpRsh8Ux16, OpRsh8Ux8, OpRsh8Ux64, OpRsh8Ux32, OpRsh8Ux16, OpRsh8Ux8,
OpRsh16Ux64, OpRsh16Ux32, OpRsh16Ux16, OpRsh16Ux8, OpRsh16Ux64, OpRsh16Ux32, OpRsh16Ux16, OpRsh16Ux8,
@ -2488,12 +2500,17 @@ func addLocalFacts(ft *factsTable, b *Block) {
zl := ft.limits[z.ID] zl := ft.limits[z.ID]
var uminDivisor uint64 var uminDivisor uint64
switch v.Op { switch v.Op {
case OpDiv64u, OpDiv32u, OpDiv16u, OpDiv8u: case OpDiv64u, OpDiv32u, OpDiv16u, OpDiv8u,
OpDiv64, OpDiv32, OpDiv16, OpDiv8:
uminDivisor = zl.umin uminDivisor = zl.umin
case OpRsh8Ux64, OpRsh8Ux32, OpRsh8Ux16, OpRsh8Ux8, case OpRsh8Ux64, OpRsh8Ux32, OpRsh8Ux16, OpRsh8Ux8,
OpRsh16Ux64, OpRsh16Ux32, OpRsh16Ux16, OpRsh16Ux8, OpRsh16Ux64, OpRsh16Ux32, OpRsh16Ux16, OpRsh16Ux8,
OpRsh32Ux64, OpRsh32Ux32, OpRsh32Ux16, OpRsh32Ux8, OpRsh32Ux64, OpRsh32Ux32, OpRsh32Ux16, OpRsh32Ux8,
OpRsh64Ux64, OpRsh64Ux32, OpRsh64Ux16, OpRsh64Ux8: OpRsh64Ux64, OpRsh64Ux32, OpRsh64Ux16, OpRsh64Ux8,
OpRsh8x64, OpRsh8x32, OpRsh8x16, OpRsh8x8,
OpRsh16x64, OpRsh16x32, OpRsh16x16, OpRsh16x8,
OpRsh32x64, OpRsh32x32, OpRsh32x16, OpRsh32x8,
OpRsh64x64, OpRsh64x32, OpRsh64x16, OpRsh64x8:
uminDivisor = 1 << zl.umin uminDivisor = 1 << zl.umin
default: default:
panic("unreachable") panic("unreachable")
@ -2655,6 +2672,22 @@ var unsignedOp = map[Op]Op{
OpMod16: OpMod16u, OpMod16: OpMod16u,
OpMod32: OpMod32u, OpMod32: OpMod32u,
OpMod64: OpMod64u, OpMod64: OpMod64u,
OpRsh8x8: OpRsh8Ux8,
OpRsh8x16: OpRsh8Ux16,
OpRsh8x32: OpRsh8Ux32,
OpRsh8x64: OpRsh8Ux64,
OpRsh16x8: OpRsh16Ux8,
OpRsh16x16: OpRsh16Ux16,
OpRsh16x32: OpRsh16Ux32,
OpRsh16x64: OpRsh16Ux64,
OpRsh32x8: OpRsh32Ux8,
OpRsh32x16: OpRsh32Ux16,
OpRsh32x32: OpRsh32Ux32,
OpRsh32x64: OpRsh32Ux64,
OpRsh64x8: OpRsh64Ux8,
OpRsh64x16: OpRsh64Ux16,
OpRsh64x32: OpRsh64Ux32,
OpRsh64x64: OpRsh64Ux64,
} }
var bytesizeToConst = [...]Op{ var bytesizeToConst = [...]Op{
@ -2741,8 +2774,15 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) {
case OpRsh8x8, OpRsh8x16, OpRsh8x32, OpRsh8x64, case OpRsh8x8, OpRsh8x16, OpRsh8x32, OpRsh8x64,
OpRsh16x8, OpRsh16x16, OpRsh16x32, OpRsh16x64, OpRsh16x8, OpRsh16x16, OpRsh16x32, OpRsh16x64,
OpRsh32x8, OpRsh32x16, OpRsh32x32, OpRsh32x64, OpRsh32x8, OpRsh32x16, OpRsh32x32, OpRsh32x64,
OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64, OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64:
OpLsh8x8, OpLsh8x16, OpLsh8x32, OpLsh8x64, if ft.isNonNegative(v.Args[0]) {
if b.Func.pass.debug > 0 {
b.Func.Warnl(v.Pos, "Proved %v is unsigned", v.Op)
}
v.Op = unsignedOp[v.Op]
}
fallthrough
case OpLsh8x8, OpLsh8x16, OpLsh8x32, OpLsh8x64,
OpLsh16x8, OpLsh16x16, OpLsh16x32, OpLsh16x64, OpLsh16x8, OpLsh16x16, OpLsh16x32, OpLsh16x64,
OpLsh32x8, OpLsh32x16, OpLsh32x32, OpLsh32x64, OpLsh32x8, OpLsh32x16, OpLsh32x32, OpLsh32x64,
OpLsh64x8, OpLsh64x16, OpLsh64x32, OpLsh64x64, OpLsh64x8, OpLsh64x16, OpLsh64x32, OpLsh64x64,
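
A rough Go-level illustration (not part of the change) of the new prove behavior: once x is known to be non-negative, the signed shift is rewritten to its unsigned form and the pass can relate the result to x:

	package p

	// eighth returns x/8 for non-negative x. With this change the prove
	// pass rewrites the signed x >> 3 into an unsigned shift and can
	// derive facts such as result <= x, which later checks can use.
	func eighth(x int64) int64 {
		if x < 0 {
			return 0
		}
		return x >> 3
	}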


@ -5866,6 +5866,54 @@ func rewriteValueLOONG64_OpLOONG64MULV(v *Value) bool {
v_0 := v.Args[0] v_0 := v.Args[0]
b := v.Block b := v.Block
config := b.Func.Config config := b.Func.Config
// match: (MULV r:(MOVWUreg x) s:(MOVWUreg y))
// cond: r.Uses == 1 && s.Uses == 1
// result: (MULWVWU x y)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
r := v_0
if r.Op != OpLOONG64MOVWUreg {
continue
}
x := r.Args[0]
s := v_1
if s.Op != OpLOONG64MOVWUreg {
continue
}
y := s.Args[0]
if !(r.Uses == 1 && s.Uses == 1) {
continue
}
v.reset(OpLOONG64MULWVWU)
v.AddArg2(x, y)
return true
}
break
}
// match: (MULV r:(MOVWreg x) s:(MOVWreg y))
// cond: r.Uses == 1 && s.Uses == 1
// result: (MULWVW x y)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
r := v_0
if r.Op != OpLOONG64MOVWreg {
continue
}
x := r.Args[0]
s := v_1
if s.Op != OpLOONG64MOVWreg {
continue
}
y := s.Args[0]
if !(r.Uses == 1 && s.Uses == 1) {
continue
}
v.reset(OpLOONG64MULWVW)
v.AddArg2(x, y)
return true
}
break
}
// match: (MULV _ (MOVVconst [0])) // match: (MULV _ (MOVVconst [0]))
// result: (MOVVconst [0]) // result: (MOVVconst [0])
for { for {


@ -18069,6 +18069,51 @@ func rewriteValuegeneric_OpNeg16(v *Value) bool {
v.AuxInt = int16ToAuxInt(-c) v.AuxInt = int16ToAuxInt(-c)
return true return true
} }
// match: (Neg16 (Mul16 x (Const16 <t> [c])))
// result: (Mul16 x (Const16 <t> [-c]))
for {
if v_0.Op != OpMul16 {
break
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
x := v_0_0
if v_0_1.Op != OpConst16 {
continue
}
t := v_0_1.Type
c := auxIntToInt16(v_0_1.AuxInt)
v.reset(OpMul16)
v0 := b.NewValue0(v.Pos, OpConst16, t)
v0.AuxInt = int16ToAuxInt(-c)
v.AddArg2(x, v0)
return true
}
break
}
// match: (Neg16 (Mul16 x (Neg16 y)))
// result: (Mul16 x y)
for {
if v_0.Op != OpMul16 {
break
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
x := v_0_0
if v_0_1.Op != OpNeg16 {
continue
}
y := v_0_1.Args[0]
v.reset(OpMul16)
v.AddArg2(x, y)
return true
}
break
}
// match: (Neg16 (Sub16 x y))
// result: (Sub16 y x)
for {
@ -18121,6 +18166,51 @@ func rewriteValuegeneric_OpNeg32(v *Value) bool {
v.AuxInt = int32ToAuxInt(-c)
return true
}
// match: (Neg32 (Mul32 x (Const32 <t> [c])))
// result: (Mul32 x (Const32 <t> [-c]))
for {
if v_0.Op != OpMul32 {
break
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
x := v_0_0
if v_0_1.Op != OpConst32 {
continue
}
t := v_0_1.Type
c := auxIntToInt32(v_0_1.AuxInt)
v.reset(OpMul32)
v0 := b.NewValue0(v.Pos, OpConst32, t)
v0.AuxInt = int32ToAuxInt(-c)
v.AddArg2(x, v0)
return true
}
break
}
// match: (Neg32 (Mul32 x (Neg32 y)))
// result: (Mul32 x y)
for {
if v_0.Op != OpMul32 {
break
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
x := v_0_0
if v_0_1.Op != OpNeg32 {
continue
}
y := v_0_1.Args[0]
v.reset(OpMul32)
v.AddArg2(x, y)
return true
}
break
}
// match: (Neg32 (Sub32 x y))
// result: (Sub32 y x)
for {
@ -18192,6 +18282,51 @@ func rewriteValuegeneric_OpNeg64(v *Value) bool {
v.AuxInt = int64ToAuxInt(-c)
return true
}
// match: (Neg64 (Mul64 x (Const64 <t> [c])))
// result: (Mul64 x (Const64 <t> [-c]))
for {
if v_0.Op != OpMul64 {
break
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
x := v_0_0
if v_0_1.Op != OpConst64 {
continue
}
t := v_0_1.Type
c := auxIntToInt64(v_0_1.AuxInt)
v.reset(OpMul64)
v0 := b.NewValue0(v.Pos, OpConst64, t)
v0.AuxInt = int64ToAuxInt(-c)
v.AddArg2(x, v0)
return true
}
break
}
// match: (Neg64 (Mul64 x (Neg64 y)))
// result: (Mul64 x y)
for {
if v_0.Op != OpMul64 {
break
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
x := v_0_0
if v_0_1.Op != OpNeg64 {
continue
}
y := v_0_1.Args[0]
v.reset(OpMul64)
v.AddArg2(x, y)
return true
}
break
}
// match: (Neg64 (Sub64 x y))
// result: (Sub64 y x)
for {
@ -18263,6 +18398,51 @@ func rewriteValuegeneric_OpNeg8(v *Value) bool {
v.AuxInt = int8ToAuxInt(-c)
return true
}
// match: (Neg8 (Mul8 x (Const8 <t> [c])))
// result: (Mul8 x (Const8 <t> [-c]))
for {
if v_0.Op != OpMul8 {
break
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
x := v_0_0
if v_0_1.Op != OpConst8 {
continue
}
t := v_0_1.Type
c := auxIntToInt8(v_0_1.AuxInt)
v.reset(OpMul8)
v0 := b.NewValue0(v.Pos, OpConst8, t)
v0.AuxInt = int8ToAuxInt(-c)
v.AddArg2(x, v0)
return true
}
break
}
// match: (Neg8 (Mul8 x (Neg8 y)))
// result: (Mul8 x y)
for {
if v_0.Op != OpMul8 {
break
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i0 := 0; _i0 <= 1; _i0, v_0_0, v_0_1 = _i0+1, v_0_1, v_0_0 {
x := v_0_0
if v_0_1.Op != OpNeg8 {
continue
}
y := v_0_1.Args[0]
v.reset(OpMul8)
v.AddArg2(x, y)
return true
}
break
}
// match: (Neg8 (Sub8 x y))
// result: (Sub8 y x)
for {


@ -140,7 +140,7 @@ func TestStmtLines(t *testing.T) {
var m float64
switch runtime.GOARCH {
case "amd64":
m = 0.015 // > 98.5% obtained on amd64, there has been minor backsliding
case "riscv64":
m = 0.03 // XXX temporary update threshold to 97% for regabi
default:


@ -6172,8 +6172,8 @@ func (s *state) dottype(n *ir.TypeAssertExpr, commaok bool) (res, resok *ssa.Val
base.Fatalf("unexpected *ir.TypeAssertExpr with UseNilPanic == true && commaok == true")
}
if n.Type().IsInterface() {
// Currently we do not expect the compiler to emit type assertions with UseNilPanic that assert to an interface type.
// If needed, this can be relaxed in the future, but for now we can assert that.
base.Fatalf("unexpected *ir.TypeAssertExpr with UseNilPanic == true && Type().IsInterface() == true")
}
typs := s.f.Config.Types


@ -145,3 +145,32 @@ func BenchmarkMul2Neg(b *testing.B) {
globl = s
}
}
func BenchmarkSimplifyNegMul(b *testing.B) {
x := make([]int64, 1024)
y := make([]int64, 1024)
b.ResetTimer()
for i := 0; i < b.N; i++ {
var s int64
for i := range x {
s = -(-x[i] * y[i])
}
globl = s
}
}
func BenchmarkSimplifyNegDiv(b *testing.B) {
x := make([]int64, 1024)
y := make([]int64, 1024)
for i := range y {
y[i] = 42
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
var s int64
for i := range x {
s = -(-x[i] / y[i])
}
globl = s
}
}


@ -304,3 +304,6 @@ var loong64HasLSX bool
var riscv64HasZbb bool
func asanregisterglobals(unsafe.Pointer, uintptr)
// used by testing.B.Loop
func KeepAlive(interface{})


@ -250,6 +250,7 @@ var runtimeDecls = [...]struct {
{"loong64HasLSX", varTag, 6},
{"riscv64HasZbb", varTag, 6},
{"asanregisterglobals", funcTag, 136},
{"KeepAlive", funcTag, 11},
}
func runtimeTypes() []*types.Type {


@ -11,7 +11,7 @@ require (
golang.org/x/sys v0.38.0
golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54
golang.org/x/term v0.34.0
golang.org/x/tools v0.39.1-0.20251120214200-68724afed209
)
require (


@ -22,7 +22,7 @@ golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/tools v0.39.1-0.20251120214200-68724afed209 h1:BGuEUnbWU1H+VhF4Z52lwCvzRT8Q/Z7kJC3okSME58w=
golang.org/x/tools v0.39.1-0.20251120214200-68724afed209/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
rsc.io/markdown v0.0.0-20240306144322-0bf8f97ee8ef h1:mqLYrXCXYEZOop9/Dbo6RPX11539nwiCNBb1icVPmw8=
rsc.io/markdown v0.0.0-20240306144322-0bf8f97ee8ef/go.mod h1:8xcPgWmwlZONN1D9bjxtHEjrUtSEa3fakVF8iaewYKQ=


@ -17,6 +17,7 @@ import (
"os"
"os/exec"
"path/filepath"
"regexp"
"runtime"
"slices"
"sort"
@ -91,9 +92,21 @@ func newGitRepo(ctx context.Context, remote string, local bool) (Repo, error) {
break
}
}
gitSupportsSHA256, gitVersErr := gitSupportsSHA256()
if gitVersErr != nil {
return nil, fmt.Errorf("unable to resolve git version: %w", gitVersErr)
}
objFormatFlag := []string{}
// If git is sufficiently recent to support sha256,
// always initialize with an explicit object-format.
if repoSha256Hash {
// We always set --object-format=sha256 if the repo
// we're cloning uses sha256 hashes because if the git
// version is too old, it'll fail either way, so we
// might as well give it one last chance.
objFormatFlag = []string{"--object-format=sha256"}
} else if gitSupportsSHA256 {
objFormatFlag = []string{"--object-format=sha1"}
}
if _, err := Run(ctx, r.dir, "git", "init", "--bare", objFormatFlag); err != nil {
os.RemoveAll(r.dir)
@ -389,7 +402,7 @@ func (r *gitRepo) Latest(ctx context.Context) (*RevInfo, error) {
func (r *gitRepo) checkConfigSHA256(ctx context.Context) bool {
if hashType, sha256CfgErr := r.runGit(ctx, "git", "config", "extensions.objectformat"); sha256CfgErr == nil {
return strings.TrimSpace(string(hashType)) == "sha256"
}
return false
}
@ -972,3 +985,36 @@ func (r *gitRepo) runGit(ctx context.Context, cmdline ...any) ([]byte, error) {
}
return RunWithArgs(ctx, args)
}
// Capture the major, minor and (optionally) patch version, but ignore anything later
var gitVersLineExtract = regexp.MustCompile(`git version\s+(\d+\.\d+(?:\.\d+)?)`)
func gitVersion() (string, error) {
gitOut, runErr := exec.Command("git", "version").CombinedOutput()
if runErr != nil {
return "v0", fmt.Errorf("failed to execute git version: %w", runErr)
}
return extractGitVersion(gitOut)
}
func extractGitVersion(gitOut []byte) (string, error) {
matches := gitVersLineExtract.FindSubmatch(gitOut)
if len(matches) < 2 {
return "v0", fmt.Errorf("git version extraction regexp did not match version line: %q", gitOut)
}
return "v" + string(matches[1]), nil
}
func hasAtLeastGitVersion(minVers string) (bool, error) {
gitVers, gitVersErr := gitVersion()
if gitVersErr != nil {
return false, gitVersErr
}
return semver.Compare(minVers, gitVers) <= 0, nil
}
const minGitSHA256Vers = "v2.29"
func gitSupportsSHA256() (bool, error) {
return hasAtLeastGitVersion(minGitSHA256Vers)
}
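An illustrative worked example of the new helpers (mirroring the test table added below, not new behavior): extractGitVersion([]byte("git version 2.43.0.rc2.23.gc3cc3e1da7")) yields "v2.43.0"; since semver.Compare(minGitSHA256Vers, "v2.43.0") <= 0, hasAtLeastGitVersion(minGitSHA256Vers) — and therefore gitSupportsSHA256 — reports true for such a git.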


@ -16,11 +16,9 @@ import (
"io/fs"
"log"
"os"
"os/exec"
"path"
"path/filepath"
"reflect"
"regexp"
"runtime"
"strings"
"sync"
@ -196,42 +194,56 @@ func testRepo(ctx context.Context, t *testing.T, remote string) (Repo, error) {
return NewRepo(ctx, vcsName, remote, false)
}
func TestExtractGitVersion(t *testing.T) {
t.Parallel()
for _, tbl := range []struct {
in, exp string
}{
{in: "git version 2.52.0.rc2", exp: "v2.52.0"},
{in: "git version 2.52.0.38.g5e6e4854e0", exp: "v2.52.0"},
{in: "git version 2.51.2", exp: "v2.51.2"},
{in: "git version 1.5.0.5.GIT", exp: "v1.5.0"},
{in: "git version 1.5.1-rc3.GIT", exp: "v1.5.1"},
{in: "git version 1.5.2.GIT", exp: "v1.5.2"},
{in: "git version 2.43.0.rc2.23.gc3cc3e1da7", exp: "v2.43.0"},
} {
t.Run(tbl.exp, func(t *testing.T) {
out, extrErr := extractGitVersion([]byte(tbl.in))
if extrErr != nil {
t.Errorf("failed to extract git version from %q: %s", tbl.in, extrErr)
}
if out != tbl.exp {
t.Errorf("unexpected git version extractGitVersion(%q) = %q; want %q", tbl.in, out, tbl.exp)
}
})
}
}
const minGitSHA256Vers = "v2.29"
func TestTags(t *testing.T) {
t.Parallel()
gitVers, gitVersErr := gitVersion()
if gitVersErr != nil {
t.Logf("git version check failed: %s", gitVersErr)
}
type tagsTest struct {
repo string
prefix string
tags []Tag
// Override the git default hash for a few cases to make sure
// we handle all 3 reasonable states.
defGitHash string
}
runTest := func(tt tagsTest) func(*testing.T) {
return func(t *testing.T) {
t.Parallel()
if tt.repo == gitsha256repo && semver.Compare(gitVers, minGitSHA256Vers) < 0 {
t.Skipf("git version is too old (%+v); skipping git sha256 test", gitVers)
}
ctx := testContext(t)
if tt.defGitHash != "" {
t.Setenv("GIT_DEFAULT_HASH", tt.defGitHash)
}
r, err := testRepo(ctx, t, tt.repo)
if err != nil {
@ -248,27 +260,37 @@ func TestTags(t *testing.T) {
}
for _, tt := range []tagsTest{
{gitrepo1, "xxx", []Tag{}, ""},
{gitrepo1, "xxx", []Tag{}, "sha256"},
{gitrepo1, "xxx", []Tag{}, "sha1"},
{gitrepo1, "", []Tag{
{"v1.2.3", "ede458df7cd0fdca520df19a33158086a8a68e81"},
{"v1.2.4-annotated", "ede458df7cd0fdca520df19a33158086a8a68e81"},
{"v2.0.1", "76a00fb249b7f93091bc2c89a789dab1fc1bc26f"},
{"v2.0.2", "9d02800338b8a55be062c838d1f02e0c5780b9eb"},
{"v2.3", "76a00fb249b7f93091bc2c89a789dab1fc1bc26f"},
}, ""},
{gitrepo1, "v", []Tag{
{"v1.2.3", "ede458df7cd0fdca520df19a33158086a8a68e81"},
{"v1.2.4-annotated", "ede458df7cd0fdca520df19a33158086a8a68e81"},
{"v2.0.1", "76a00fb249b7f93091bc2c89a789dab1fc1bc26f"},
{"v2.0.2", "9d02800338b8a55be062c838d1f02e0c5780b9eb"},
{"v2.3", "76a00fb249b7f93091bc2c89a789dab1fc1bc26f"},
}, ""},
{gitrepo1, "v1", []Tag{
{"v1.2.3", "ede458df7cd0fdca520df19a33158086a8a68e81"},
{"v1.2.4-annotated", "ede458df7cd0fdca520df19a33158086a8a68e81"},
}, ""},
{gitrepo1, "v1", []Tag{
{"v1.2.3", "ede458df7cd0fdca520df19a33158086a8a68e81"},
{"v1.2.4-annotated", "ede458df7cd0fdca520df19a33158086a8a68e81"},
}, "sha256"},
{gitrepo1, "v1", []Tag{
{"v1.2.3", "ede458df7cd0fdca520df19a33158086a8a68e81"},
{"v1.2.4-annotated", "ede458df7cd0fdca520df19a33158086a8a68e81"},
}, "sha1"},
{gitrepo1, "2", []Tag{}, ""},
{gitsha256repo, "xxx", []Tag{}, ""},
{gitsha256repo, "", []Tag{
{"v1.2.3", "47b8b51b2a2d9d5caa3d460096c4e01f05700ce3a9390deb54400bd508214c5c"},
{"v1.2.4-annotated", "47b8b51b2a2d9d5caa3d460096c4e01f05700ce3a9390deb54400bd508214c5c"},
@ -276,7 +298,7 @@ func TestTags(t *testing.T) {
{"v2.0.1", "b7550fd9d2129c724c39ae0536e8b2fae4364d8c82bb8b0880c9b71f67295d09"},
{"v2.0.2", "1401e4e1fdb4169b51d44a1ff62af63ccc708bf5c12d15051268b51bbb6cbd82"},
{"v2.3", "b7550fd9d2129c724c39ae0536e8b2fae4364d8c82bb8b0880c9b71f67295d09"},
}, ""},
{gitsha256repo, "v", []Tag{
{"v1.2.3", "47b8b51b2a2d9d5caa3d460096c4e01f05700ce3a9390deb54400bd508214c5c"},
{"v1.2.4-annotated", "47b8b51b2a2d9d5caa3d460096c4e01f05700ce3a9390deb54400bd508214c5c"},
@ -284,13 +306,23 @@ func TestTags(t *testing.T) {
{"v2.0.1", "b7550fd9d2129c724c39ae0536e8b2fae4364d8c82bb8b0880c9b71f67295d09"},
{"v2.0.2", "1401e4e1fdb4169b51d44a1ff62af63ccc708bf5c12d15051268b51bbb6cbd82"},
{"v2.3", "b7550fd9d2129c724c39ae0536e8b2fae4364d8c82bb8b0880c9b71f67295d09"},
}, ""},
{gitsha256repo, "v1", []Tag{
{"v1.2.3", "47b8b51b2a2d9d5caa3d460096c4e01f05700ce3a9390deb54400bd508214c5c"},
{"v1.2.4-annotated", "47b8b51b2a2d9d5caa3d460096c4e01f05700ce3a9390deb54400bd508214c5c"},
{"v1.3.0", "a9157cad2aa6dc2f78aa31fced5887f04e758afa8703f04d0178702ebf04ee17"},
}, ""},
{gitsha256repo, "v1", []Tag{
{"v1.2.3", "47b8b51b2a2d9d5caa3d460096c4e01f05700ce3a9390deb54400bd508214c5c"},
{"v1.2.4-annotated", "47b8b51b2a2d9d5caa3d460096c4e01f05700ce3a9390deb54400bd508214c5c"},
{"v1.3.0", "a9157cad2aa6dc2f78aa31fced5887f04e758afa8703f04d0178702ebf04ee17"},
}, "sha1"},
{gitsha256repo, "v1", []Tag{
{"v1.2.3", "47b8b51b2a2d9d5caa3d460096c4e01f05700ce3a9390deb54400bd508214c5c"},
{"v1.2.4-annotated", "47b8b51b2a2d9d5caa3d460096c4e01f05700ce3a9390deb54400bd508214c5c"},
{"v1.3.0", "a9157cad2aa6dc2f78aa31fced5887f04e758afa8703f04d0178702ebf04ee17"},
}, "sha256"},
{gitsha256repo, "2", []Tag{}, ""},
} {
t.Run(path.Base(tt.repo)+"/"+tt.prefix, runTest(tt))
if tt.repo == gitrepo1 {
@ -315,7 +347,10 @@ func TestTags(t *testing.T) {
func TestLatest(t *testing.T) {
t.Parallel()
gitVers, gitVersErr := gitVersion()
if gitVersErr != nil {
t.Logf("git version check failed: %s", gitVersErr)
}
type latestTest struct {
repo string
@ -409,7 +444,10 @@ func TestLatest(t *testing.T) {
func TestReadFile(t *testing.T) {
t.Parallel()
gitVers, gitVersErr := gitVersion()
if gitVersErr != nil {
t.Logf("git version check failed: %s", gitVersErr)
}
type readFileTest struct {
repo string
@ -508,7 +546,10 @@ type zipFile struct {
func TestReadZip(t *testing.T) {
t.Parallel()
gitVers, gitVersErr := gitVersion()
if gitVersErr != nil {
t.Logf("git version check failed: %s", gitVersErr)
}
type readZipTest struct {
repo string
@ -798,7 +839,10 @@ var hgmap = map[string]string{
func TestStat(t *testing.T) {
t.Parallel()
gitVers, gitVersErr := gitVersion()
if gitVersErr != nil {
t.Logf("git version check failed: %s", gitVersErr)
}
type statTest struct {
repo string


@ -29,6 +29,7 @@ import (
"cmd/go/internal/lockedfile"
"cmd/go/internal/modfetch"
"cmd/go/internal/search"
igover "internal/gover"
"golang.org/x/mod/modfile"
"golang.org/x/mod/module"
@ -826,7 +827,7 @@ func WriteWorkFile(path string, wf *modfile.WorkFile) error {
wf.Cleanup()
out := modfile.Format(wf.Syntax)
return os.WriteFile(path, out, 0o666)
}
// UpdateWorkGoVersion updates the go line in wf to be at least goVers,
@ -1200,7 +1201,7 @@ func CreateModFile(loaderstate *State, ctx context.Context, modPath string) {
modFile := new(modfile.File)
modFile.AddModuleStmt(modPath)
loaderstate.MainModules = makeMainModules(loaderstate, []module.Version{modFile.Module.Mod}, []string{modRoot}, []*modfile.File{modFile}, []*modFileIndex{nil}, nil)
addGoStmt(modFile, modFile.Module.Mod, DefaultModInitGoVersion()) // Add the go directive before converted module requirements.
rs := requirementsFromModFiles(loaderstate, ctx, nil, []*modfile.File{modFile}, nil)
rs, err := updateRoots(loaderstate, ctx, rs.direct, rs, nil, nil, false)
@ -1811,9 +1812,7 @@ Run 'go help mod init' for more information.
return "", fmt.Errorf(msg, dir, reason)
}
var importCommentRE = lazyregexp.New(`(?m)^package[ \t]+[^ \t\r\n/]+[ \t]+//[ \t]+import[ \t]+(\"[^"]+\")[ \t]*\r?\n`)
func findImportComment(file string) string {
data, err := os.ReadFile(file)
@ -2252,3 +2251,29 @@ func CheckGodebug(verb, k, v string) error {
}
return fmt.Errorf("unknown %s %q", verb, k)
}
// DefaultModInitGoVersion returns the appropriate go version to include in a
// newly initialized module or work file.
//
// If the current toolchain version is a stable version of Go 1.N.M, default to
// go 1.(N-1).0
//
// If the current toolchain version is a pre-release version of Go 1.N (Release
// Candidate M) or a development version of Go 1.N, default to go 1.(N-2).0
func DefaultModInitGoVersion() string {
v := gover.Local()
if isPrereleaseOrDevelVersion(v) {
v = gover.Prev(gover.Prev(v))
} else {
v = gover.Prev(v)
}
if strings.Count(v, ".") < 2 {
v += ".0"
}
return v
}
func isPrereleaseOrDevelVersion(s string) bool {
v := igover.Parse(s)
return v.Kind != "" || v.Patch == ""
}
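Worked examples of DefaultModInitGoVersion (matching the new testdata below): a released go1.26.0 toolchain writes "go 1.25.0" into a freshly initialized file, while development and pre-release toolchains step back one version further, so go1.28-devel yields "go 1.26.0" and go1.23rc3 yields "go 1.21.0".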


@ -398,7 +398,8 @@ func hasWorkingBzr() bool {
return err == nil
}
// Capture the major, minor and (optionally) patch version, but ignore anything later
var gitVersLineExtract = regexp.MustCompile(`git version\s+(\d+\.\d+(?:\.\d+)?)`)
func gitVersion() (string, error) {
gitOut, runErr := exec.Command("git", "version").CombinedOutput()


@ -12,7 +12,6 @@ import (
"cmd/go/internal/base"
"cmd/go/internal/fsys"
"cmd/go/internal/gover"
"cmd/go/internal/modload"
"golang.org/x/mod/modfile"
@ -58,10 +57,9 @@ func runInit(ctx context.Context, cmd *base.Command, args []string) {
base.Fatalf("go: %s already exists", gowork)
}
goV := gover.Local() // Use current Go version by default
wf := new(modfile.WorkFile)
wf.Syntax = new(modfile.FileSyntax)
wf.AddGoStmt(modload.DefaultModInitGoVersion())
workUse(ctx, moduleLoaderState, gowork, wf, args)
modload.WriteWorkFile(gowork, wf)
}


@ -157,7 +157,8 @@ func hasWorkingGit() bool {
return err == nil
}
// Capture the major, minor and (optionally) patch version, but ignore anything later
var gitVersLineExtract = regexp.MustCompile(`git version\s+(\d+\.\d+(?:\.\d+)?)`)
func gitVersion() (string, error) {
gitOut, runErr := exec.Command("git", "version").CombinedOutput()


@ -1,5 +1,8 @@
env GO111MODULE=on
# Set go version so that we can test produced mod files for equality.
env TESTGO_VERSION=go1.26.0
# Test that go mod edits and related mod flags work.
# Also test that they can use a dummy name that isn't resolvable. golang.org/issue/24100
@ -10,16 +13,16 @@ stderr 'cannot determine module path'
go mod init x.x/y/z
stderr 'creating new go.mod: module x.x/y/z'
cmp go.mod $WORK/go.mod.init
! go mod init
cmp go.mod $WORK/go.mod.init
# go mod edits
go mod edit -droprequire=x.1 -require=x.1@v1.0.0 -require=x.2@v1.1.0 -droprequire=x.2 -exclude='x.1 @ v1.2.0' -exclude=x.1@v1.2.1 -exclude=x.1@v2.0.0+incompatible -replace=x.1@v1.3.0=y.1@v1.4.0 -replace='x.1@v1.4.0 = ../z' -retract=v1.6.0 -retract=[v1.1.0,v1.2.0] -retract=[v1.3.0,v1.4.0] -retract=v1.0.0
cmp go.mod $WORK/go.mod.edit1
go mod edit -droprequire=x.1 -dropexclude=x.1@v1.2.1 -dropexclude=x.1@v2.0.0+incompatible -dropreplace=x.1@v1.3.0 -require=x.3@v1.99.0 -dropretract=v1.0.0 -dropretract=[v1.1.0,v1.2.0]
cmp go.mod $WORK/go.mod.edit2
# -exclude and -retract reject invalid versions.
! go mod edit -exclude=example.com/m@bad
@ -36,11 +39,11 @@ stderr '^go: -exclude=example.com/m/v2@v1\.0\.0: version "v1\.0\.0" invalid: sho
! go mod edit -exclude=gopkg.in/example.v1@v2.0.0
stderr '^go: -exclude=gopkg\.in/example\.v1@v2\.0\.0: version "v2\.0\.0" invalid: should be v1, not v2$'
cmp go.mod $WORK/go.mod.edit2
# go mod edit -json
go mod edit -json
cmp stdout $WORK/go.mod.json
# go mod edit -json (retractions with rationales)
go mod edit -json $WORK/go.mod.retractrationale
@ -56,66 +59,66 @@ cmp stdout $WORK/go.mod.empty.json
# go mod edit -replace
go mod edit -replace=x.1@v1.3.0=y.1/v2@v2.3.5 -replace=x.1@v1.4.0=y.1/v2@v2.3.5
cmp go.mod $WORK/go.mod.edit3
go mod edit -replace=x.1=y.1/v2@v2.3.6
cmp go.mod $WORK/go.mod.edit4
go mod edit -dropreplace=x.1
cmp go.mod $WORK/go.mod.edit5
go mod edit -replace=x.1=../y.1/@v2
cmp go.mod $WORK/go.mod.edit6
! go mod edit -replace=x.1=y.1/@v2
stderr '^go: -replace=x.1=y.1/@v2: invalid new path: malformed import path "y.1/": trailing slash$'
# go mod edit -fmt
cp $WORK/go.mod.badfmt go.mod
go mod edit -fmt -print # -print should avoid writing file
cmp stdout $WORK/go.mod.goodfmt
cmp go.mod $WORK/go.mod.badfmt
go mod edit -fmt # without -print, should write file (and nothing to stdout)
! stdout .
cmp go.mod $WORK/go.mod.goodfmt
# go mod edit -module
cd $WORK/m
go mod init a.a/b/c
go mod edit -module x.x/y/z
cmp go.mod go.mod.edit
# golang.org/issue/30513: don't require go-gettable module paths.
cd $WORK/local
go mod init foo
go mod edit -module local-only -require=other-local@v1.0.0 -replace other-local@v1.0.0=./other
cmp go.mod go.mod.edit
# go mod edit -godebug
cd $WORK/g
cp go.mod.start go.mod
go mod edit -godebug key=value
cmp go.mod go.mod.edit
go mod edit -dropgodebug key2
cmp go.mod go.mod.edit
go mod edit -dropgodebug key
cmp go.mod go.mod.start
# go mod edit -tool
cd $WORK/h
cp go.mod.start go.mod
go mod edit -tool example.com/tool
cmp go.mod go.mod.edit
go mod edit -droptool example.com/tool2
cmp go.mod go.mod.edit
go mod edit -droptool example.com/tool
cmp go.mod go.mod.start
# go mod edit -ignore
cd $WORK/i
cp go.mod.start go.mod
go mod edit -ignore example.com/ignore
cmp go.mod go.mod.edit
go mod edit -dropignore example.com/ignore2
cmp go.mod go.mod.edit
go mod edit -dropignore example.com/ignore
cmp go.mod go.mod.start
-- x.go --
package x
@ -126,11 +129,11 @@ package w
-- $WORK/go.mod.init --
module x.x/y/z
go 1.25.0
-- $WORK/go.mod.edit1 --
module x.x/y/z
go 1.25.0
require x.1 v1.0.0
@ -154,7 +157,7 @@ retract (
-- $WORK/go.mod.edit2 --
module x.x/y/z
go 1.25.0
exclude x.1 v1.2.0
@ -171,7 +174,7 @@ require x.3 v1.99.0
"Module": {
"Path": "x.x/y/z"
},
"Go": "1.25.0",
"Require": [
{
"Path": "x.3",
@ -211,7 +214,7 @@ require x.3 v1.99.0
-- $WORK/go.mod.edit3 --
module x.x/y/z
go 1.25.0
exclude x.1 v1.2.0
@ -229,7 +232,7 @@ require x.3 v1.99.0
-- $WORK/go.mod.edit4 --
module x.x/y/z
go 1.25.0
exclude x.1 v1.2.0
@ -244,7 +247,7 @@ require x.3 v1.99.0
-- $WORK/go.mod.edit5 --
module x.x/y/z
go 1.25.0
exclude x.1 v1.2.0
@ -257,7 +260,7 @@ require x.3 v1.99.0
-- $WORK/go.mod.edit6 --
module x.x/y/z
go 1.25.0
exclude x.1 v1.2.0
@ -272,7 +275,7 @@ replace x.1 => ../y.1/@v2
-- $WORK/local/go.mod.edit --
module local-only
go 1.25.0
require other-local v1.0.0
@ -304,7 +307,7 @@ retract [v1.8.1, v1.8.2]
-- $WORK/m/go.mod.edit --
module x.x/y/z
go 1.25.0
-- $WORK/go.mod.retractrationale --
module x.x/y/z


@ -0,0 +1,47 @@
env TESTGO_VERSION=go1.28-devel
go mod init example.com
cmp go.mod go.mod.want-1.26.0
rm go.mod
env TESTGO_VERSION=go1.26.0
go mod init example.com
cmp go.mod go.mod.want-1.25.0
rm go.mod
env TESTGO_VERSION=go1.22.2
go mod init example.com
cmp go.mod go.mod.want-1.21.0
rm go.mod
env TESTGO_VERSION=go1.25.0-xyzzy
go mod init example.com
cmp go.mod go.mod.want-1.24.0
rm go.mod
env TESTGO_VERSION=go1.23rc3
go mod init example.com
cmp go.mod go.mod.want-1.21.0
rm go.mod
env TESTGO_VERSION=go1.18beta2
go mod init example.com
cmp go.mod go.mod.want-1.16.0
-- go.mod.want-1.26.0 --
module example.com
go 1.26.0
-- go.mod.want-1.25.0 --
module example.com
go 1.25.0
-- go.mod.want-1.24.0 --
module example.com
go 1.24.0
-- go.mod.want-1.22.0 --
module example.com
go 1.22.0
-- go.mod.want-1.21.0 --
module example.com
go 1.21.0
-- go.mod.want-1.16.0 --
module example.com
go 1.16.0


@ -1,4 +1,5 @@
[short] skip 'runs go run'
env TESTGO_VERSION=go1.26.0
! go work init doesnotexist
stderr 'go: directory doesnotexist does not exist'
@ -74,7 +75,7 @@ use (
../src/a
)
-- go.work.want --
go 1.25.0
use (
./a


@ -1,4 +1,5 @@
# Test editing go.work files.
env TESTGO_VERSION=go1.26.0
go work init m
cmpenv go.work go.work.want_initial
@ -54,11 +55,11 @@ module m
go 1.18
-- go.work.want_initial --
go 1.25.0
use ./m
-- go.work.want_use_n --
go 1.25.0
use (
./m


@ -2,7 +2,7 @@
# 'go work init . .. foo/bar' should produce a go.work file
# with the same paths as 'go work init; go work use -r ..',
# and it should have 'use .' rather than 'use ./.' inside.
env TESTGO_VERSION=go1.23
cd dir
go work init . .. foo/bar
@ -12,19 +12,19 @@ go work init
go work use -r ..
cmp go.work go.work.init
cmp go.work $WORK/go.work.want
-- go.mod --
module example
go 1.18
-- dir/go.mod --
module example
go 1.21.0
-- dir/foo/bar/go.mod --
module example
go 1.21.0
-- $WORK/go.work.want --
go 1.21.0
use (
.


@ -8,13 +8,13 @@ go mod edit -C m1_22_0 -go=1.22.0 -toolchain=go1.99.0
# work init writes the current Go version to the go line
go work init
grep '^go 1.48.0$' go.work
! grep toolchain go.work
# work init with older modules should leave go 1.48.0 in the go.work.
rm go.work
go work init ./m1_22_0
grep '^go 1.48.0$' go.work
! grep toolchain go.work
# work init with newer modules should bump go,
@ -31,5 +31,5 @@ env GOTOOLCHAIN=auto
go work init ./m1_22_0
stderr '^go: m1_22_0'${/}'go.mod requires go >= 1.22.0; switching to go1.22.9$'
cat go.work
grep '^go 1.22.0$' go.work
! grep toolchain go.work


@ -0,0 +1,35 @@
env TESTGO_VERSION=go1.28-devel
go work init
cmp go.work go.work.want-1.26.0
rm go.work
env TESTGO_VERSION=go1.26.0
go work init
cmp go.work go.work.want-1.25.0
rm go.work
env TESTGO_VERSION=go1.22.2
go work init
cmp go.work go.work.want-1.21.0
rm go.work
env TESTGO_VERSION=go1.25.0-xyzzy
go work init
cmp go.work go.work.want-1.24.0
rm go.work
env TESTGO_VERSION=go1.24rc3
go work init
cmp go.work go.work.want-1.22.0
rm go.work
env TESTGO_VERSION=go1.18beta2
go work init
cmp go.work go.work.want-1.16.0
-- go.work.want-1.26.0 --
go 1.26.0
-- go.work.want-1.25.0 --
go 1.25.0
-- go.work.want-1.24.0 --
go 1.24.0
-- go.work.want-1.22.0 --
go 1.22.0
-- go.work.want-1.21.0 --
go 1.21.0
-- go.work.want-1.16.0 --
go 1.16.0


@ -11,13 +11,14 @@ go mod init -C m1_24_rc0
go mod edit -C m1_24_rc0 -go=1.24rc0 -toolchain=go1.99.2
go work init ./m1_22_0 ./m1_22_1
cat go.work
grep '^go 1.48.0$' go.work
! grep toolchain go.work
# work sync with older modules should leave go 1.48.0 in the go.work.
go work sync
cat go.work
grep '^go 1.48.0$' go.work
! grep toolchain go.work
# work sync with newer modules should update go 1.21 -> 1.22.1 and toolchain -> go1.22.9 in go.work


@ -11,12 +11,12 @@ go mod init -C m1_24_rc0
go mod edit -C m1_24_rc0 -go=1.24rc0 -toolchain=go1.99.2
go work init
grep '^go 1.48.0$' go.work
! grep toolchain go.work
# work use with older modules should leave go 1.48.0 in the go.work.
go work use ./m1_22_0
grep '^go 1.48.0$' go.work
! grep toolchain go.work
# work use with newer modules should bump go and toolchain,


@ -1674,32 +1674,7 @@ func log2(x uint64) uint32 {
if x == 0 {
panic("log2 of 0")
}
return uint32(bits.Len64(x) - 1)
}
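A quick sanity check of the bits.Len64 form (illustrative, not from the CL): bits.Len64 returns the number of bits needed to represent x, so bits.Len64(x)-1 is floor(log2(x)). For x = 1<<20 it gives 21-1 = 20, and for x = 48 (0b110000) it gives 6-1 = 5, matching the previous shift-and-count loop.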
func autoclass(l int64) int {


@ -356,6 +356,8 @@ func (s *LSym) checkFIPSReloc(ctxt *Link, rel Reloc) {
objabi.R_GOTPCREL,
objabi.R_LOONG64_ADDR_LO, // used with PC-relative load
objabi.R_LOONG64_ADDR_HI, // used with PC-relative load
objabi.R_LOONG64_ADDR_PCREL20_S2, // used with PC-relative load
objabi.R_LOONG64_CALL36,
objabi.R_LOONG64_TLS_LE_HI,
objabi.R_LOONG64_TLS_LE_LO,
objabi.R_LOONG64_TLS_IE_HI,


@ -912,6 +912,23 @@ const (
AVSEQV
AXVSEQV
AVSLTB
AVSLTH
AVSLTW
AVSLTV
AVSLTBU
AVSLTHU
AVSLTWU
AVSLTVU
AXVSLTB
AXVSLTH
AXVSLTW
AXVSLTV
AXVSLTBU
AXVSLTHU
AXVSLTWU
AXVSLTVU
// LSX and LASX integer div and mod instructions
AVDIVB
AVDIVH


@ -400,6 +400,22 @@ var Anames = []string{
"XVSEQW",
"VSEQV",
"XVSEQV",
"VSLTB",
"VSLTH",
"VSLTW",
"VSLTV",
"VSLTBU",
"VSLTHU",
"VSLTWU",
"VSLTVU",
"XVSLTB",
"XVSLTH",
"XVSLTW",
"XVSLTV",
"XVSLTBU",
"XVSLTHU",
"XVSLTWU",
"XVSLTVU",
"VDIVB",
"VDIVH",
"VDIVW",


@ -65,26 +65,19 @@ var optab = []Optab{
{AMOVWU, C_REG, C_NONE, C_NONE, C_REG, C_NONE, 12, 4, 0, 0},
{ASUB, C_REG, C_REG, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{ASUBV, C_REG, C_REG, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{AADD, C_REG, C_REG, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{AADDV, C_REG, C_REG, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{AAND, C_REG, C_REG, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{ASUB, C_REG, C_NONE, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{ASUBV, C_REG, C_NONE, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{AADD, C_REG, C_NONE, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{AADDV, C_REG, C_NONE, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{AAND, C_REG, C_NONE, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{ANEGW, C_REG, C_NONE, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{ANEGV, C_REG, C_NONE, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{AMASKEQZ, C_REG, C_REG, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{ASLL, C_REG, C_NONE, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{ASLL, C_REG, C_REG, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{ASLLV, C_REG, C_NONE, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{ASLLV, C_REG, C_REG, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{AMUL, C_REG, C_NONE, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{AMUL, C_REG, C_REG, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{AMULV, C_REG, C_NONE, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{AMULV, C_REG, C_REG, C_NONE, C_REG, C_NONE, 2, 4, 0, 0},
{AADDF, C_FREG, C_NONE, C_NONE, C_FREG, C_NONE, 2, 4, 0, 0},
{AADDF, C_FREG, C_FREG, C_NONE, C_FREG, C_NONE, 2, 4, 0, 0},
{ACMPEQF, C_FREG, C_FREG, C_NONE, C_FCCREG, C_NONE, 2, 4, 0, 0},
@ -93,6 +86,14 @@ var optab = []Optab{
{AXVSEQB, C_XREG, C_XREG, C_NONE, C_XREG, C_NONE, 2, 4, 0, 0},
{AVSEQB, C_S5CON, C_VREG, C_NONE, C_VREG, C_NONE, 22, 4, 0, 0},
{AXVSEQB, C_S5CON, C_XREG, C_NONE, C_XREG, C_NONE, 22, 4, 0, 0},
{AVSLTB, C_VREG, C_VREG, C_NONE, C_VREG, C_NONE, 2, 4, 0, 0},
{AXVSLTB, C_XREG, C_XREG, C_NONE, C_XREG, C_NONE, 2, 4, 0, 0},
{AVSLTB, C_S5CON, C_VREG, C_NONE, C_VREG, C_NONE, 22, 4, 0, 0},
{AXVSLTB, C_S5CON, C_XREG, C_NONE, C_XREG, C_NONE, 22, 4, 0, 0},
{AVSLTB, C_U5CON, C_VREG, C_NONE, C_VREG, C_NONE, 31, 4, 0, 0},
{AXVSLTB, C_U5CON, C_XREG, C_NONE, C_XREG, C_NONE, 31, 4, 0, 0},
{AVANDV, C_VREG, C_VREG, C_NONE, C_VREG, C_NONE, 2, 4, 0, 0},
{AVANDV, C_VREG, C_NONE, C_NONE, C_VREG, C_NONE, 2, 4, 0, 0},
{AXVANDV, C_XREG, C_XREG, C_NONE, C_XREG, C_NONE, 2, 4, 0, 0},
@ -1495,26 +1496,6 @@ func buildop(ctxt *obj.Link) {
opset(ALL, r0)
opset(ALLV, r0)
case AMUL:
opset(AMULU, r0)
opset(AMULH, r0)
opset(AMULHU, r0)
opset(AREM, r0)
opset(AREMU, r0)
opset(ADIV, r0)
opset(ADIVU, r0)
opset(AMULWVW, r0)
opset(AMULWVWU, r0)
case AMULV:
opset(AMULVU, r0)
opset(AMULHV, r0)
opset(AMULHVU, r0)
opset(AREMV, r0)
opset(AREMVU, r0)
opset(ADIVV, r0)
opset(ADIVVU, r0)
case ASLL:
opset(ASRL, r0)
opset(ASRA, r0)
@ -1533,9 +1514,26 @@ func buildop(ctxt *obj.Link) {
case ASUB:
opset(ASUBU, r0)
opset(ANOR, r0)
opset(ASUBV, r0)
case ASUBV:
opset(ASUBVU, r0)
opset(AMUL, r0)
opset(AMULU, r0)
opset(AMULH, r0)
opset(AMULHU, r0)
opset(AREM, r0)
opset(AREMU, r0)
opset(ADIV, r0)
opset(ADIVU, r0)
opset(AMULV, r0)
opset(AMULVU, r0)
opset(AMULHV, r0)
opset(AMULHVU, r0)
opset(AREMV, r0)
opset(AREMVU, r0)
opset(ADIVV, r0)
opset(ADIVVU, r0)
opset(AMULWVW, r0)
opset(AMULWVWU, r0)
case ASYSCALL:
opset(ADBAR, r0)
@ -1555,6 +1553,9 @@ func buildop(ctxt *obj.Link) {
opset(AALSLW, r0)
opset(AALSLWU, r0)
case ANEGW:
opset(ANEGV, r0)
case AMOVW,
AMOVD,
AMOVF,
@ -1567,8 +1568,6 @@ func buildop(ctxt *obj.Link) {
AXVMOVQ,
AVSHUFB,
AXVSHUFB,
ANEGW,
ANEGV,
AWORD,
APRELD,
APRELDX,
@ -1784,6 +1783,24 @@ func buildop(ctxt *obj.Link) {
opset(AXVSHUFW, r0)
opset(AXVSHUFV, r0)
case AVSLTB:
opset(AVSLTH, r0)
opset(AVSLTW, r0)
opset(AVSLTV, r0)
opset(AVSLTBU, r0)
opset(AVSLTHU, r0)
opset(AVSLTWU, r0)
opset(AVSLTVU, r0)
case AXVSLTB:
opset(AXVSLTH, r0)
opset(AXVSLTW, r0)
opset(AXVSLTV, r0)
opset(AXVSLTBU, r0)
opset(AXVSLTHU, r0)
opset(AXVSLTWU, r0)
opset(AXVSLTVU, r0)
case AVANDB:
opset(AVORB, r0)
opset(AVXORB, r0)
@ -3379,6 +3396,38 @@ func (c *ctxt0) oprrr(a obj.As) uint32 {
return 0x0e003 << 15 // vseq.d
case AXVSEQV:
return 0x0e803 << 15 // xvseq.d
case AVSLTB:
return 0x0E00C << 15 // vslt.b
case AVSLTH:
return 0x0E00D << 15 // vslt.h
case AVSLTW:
return 0x0E00E << 15 // vslt.w
case AVSLTV:
return 0x0E00F << 15 // vslt.d
case AVSLTBU:
return 0x0E010 << 15 // vslt.bu
case AVSLTHU:
return 0x0E011 << 15 // vslt.hu
case AVSLTWU:
return 0x0E012 << 15 // vslt.wu
case AVSLTVU:
return 0x0E013 << 15 // vslt.du
case AXVSLTB:
return 0x0E80C << 15 // xvslt.b
case AXVSLTH:
return 0x0E80D << 15 // xvslt.h
case AXVSLTW:
return 0x0E80E << 15 // xvslt.w
case AXVSLTV:
return 0x0E80F << 15 // xvslt.d
case AXVSLTBU:
return 0x0E810 << 15 // xvslt.bu
case AXVSLTHU:
return 0x0E811 << 15 // xvslt.hu
case AXVSLTWU:
return 0x0E812 << 15 // xvslt.wu
case AXVSLTVU:
return 0x0E813 << 15 // xvslt.du
case AVANDV:
return 0x0E24C << 15 // vand.v
case AVORV:
@ -4399,6 +4448,38 @@ func (c *ctxt0) opirr(a obj.As) uint32 {
return 0x0ED02 << 15 // xvseqi.w
case AXVSEQV:
return 0x0ED03 << 15 // xvseqi.d
case AVSLTB:
return 0x0E50C << 15 // vslti.b
case AVSLTH:
return 0x0E50D << 15 // vslti.h
case AVSLTW:
return 0x0E50E << 15 // vslti.w
case AVSLTV:
return 0x0E50F << 15 // vslti.d
case AVSLTBU:
return 0x0E510 << 15 // vslti.bu
case AVSLTHU:
return 0x0E511 << 15 // vslti.hu
case AVSLTWU:
return 0x0E512 << 15 // vslti.wu
case AVSLTVU:
return 0x0E513 << 15 // vslti.du
case AXVSLTB:
return 0x0ED0C << 15 // xvslti.b
case AXVSLTH:
return 0x0ED0D << 15 // xvslti.h
case AXVSLTW:
return 0x0ED0E << 15 // xvslti.w
case AXVSLTV:
return 0x0ED0F << 15 // xvslti.d
case AXVSLTBU:
return 0x0ED10 << 15 // xvslti.bu
case AXVSLTHU:
return 0x0ED11 << 15 // xvslti.hu
case AXVSLTWU:
return 0x0ED12 << 15 // xvslti.wu
case AXVSLTVU:
return 0x0ED13 << 15 // xvslti.du
case AVROTRB:
return 0x1ca8<<18 | 0x1<<13 // vrotri.b
case AVROTRH:


@ -17,7 +17,8 @@ import (
)
// TestLargeBranch generates a large function with a very far conditional
// branch, in order to ensure that it assembles correctly. This requires
// inverting the branch and using a jump to reach the target.
func TestLargeBranch(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode")
@ -26,6 +27,23 @@ func TestLargeBranch(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module largecall"), 0644); err != nil {
t.Fatalf("Failed to write file: %v\n", err)
}
main := `package main
import "fmt"
func main() {
fmt.Print(x())
}
func x() uint64
`
if err := os.WriteFile(filepath.Join(dir, "x.go"), []byte(main), 0644); err != nil {
t.Fatalf("failed to write main: %v\n", err)
}
// Generate a very large function.
buf := bytes.NewBuffer(make([]byte, 0, 7000000))
genLargeBranch(buf)
@ -36,27 +54,62 @@ func TestLargeBranch(t *testing.T) {
}
// Assemble generated file.
cmd := exec.Command(testenv.GoToolPath(t), "tool", "asm", "-o", filepath.Join(dir, "x.o"), "-S", tmpfile)
cmd.Env = append(os.Environ(), "GOARCH=riscv64", "GOOS=linux")
out, err := cmd.CombinedOutput()
if err != nil {
t.Errorf("Failed to assemble: %v\n%s", err, out)
}
// The expected instruction sequence for the long branch is:
// BNEZ
// AUIPC $..., X31
// JALR X0, $..., X31
want := regexp.MustCompile(`\sBNEZ\s.*\s.*\n.*\n.*AUIPC\s\$\d+, X31.*\n.*JALR\sX0, \$\d+, ?X31`)
if !want.Match(out) {
t.Error("Missing assembly instructions")
}
// Build generated files.
cmd = testenv.Command(t, testenv.GoToolPath(t), "build", "-o", "x.exe", "-ldflags=-linkmode=internal")
cmd.Dir = dir
cmd.Env = append(os.Environ(), "GOARCH=riscv64", "GOOS=linux")
out, err = cmd.CombinedOutput()
if err != nil {
t.Errorf("Build failed: %v, output: %s", err, out)
}
if runtime.GOARCH == "riscv64" && runtime.GOOS == "linux" {
cmd = testenv.Command(t, filepath.Join(dir, "x.exe"))
out, err = cmd.CombinedOutput()
if err != nil {
t.Errorf("Failed to run test binary: %v", err)
}
if string(out) != "1" {
t.Errorf(`Got test output %q, want "1"`, string(out))
}
}
}
func genLargeBranch(buf *bytes.Buffer) {
fmt.Fprintln(buf, "TEXT ·x(SB),0,$0-8")
fmt.Fprintln(buf, "MOV X0, X10")
fmt.Fprintln(buf, "BEQZ X10, label")
for i := 0; i < 1<<18; i++ {
// Use a non-compressible instruction.
fmt.Fprintln(buf, "ADD $0, X5, X0")
}
fmt.Fprintln(buf, "ADD $1, X10, X10")
fmt.Fprintln(buf, "label:")
fmt.Fprintln(buf, "ADD $1, X10, X10")
fmt.Fprintln(buf, "MOV X10, r+0(FP)")
fmt.Fprintln(buf, "RET")
}
// TestLargeCall generates a large function (>1MB of text) with a call to
// a following function, in order to ensure that it assembles and links
// correctly. This requires the use of AUIPC+JALR instruction sequences,
// which are fixed up by the linker.
func TestLargeCall(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode")
@ -69,12 +122,15 @@ func TestLargeCall(t *testing.T) {
t.Fatalf("Failed to write file: %v\n", err)
}
main := `package main
import "fmt"
func main() {
fmt.Print(x())
}
func x() uint64
func y() uint64
`
if err := os.WriteFile(filepath.Join(dir, "x.go"), []byte(main), 0644); err != nil {
t.Fatalf("failed to write main: %v\n", err)
@ -84,12 +140,49 @@ func y()
buf := bytes.NewBuffer(make([]byte, 0, 7000000))
genLargeCall(buf)
tmpfile := filepath.Join(dir, "x.s")
if err := os.WriteFile(tmpfile, buf.Bytes(), 0644); err != nil {
t.Fatalf("Failed to write file: %v\n", err)
}
// Assemble generated file.
cmd := exec.Command(testenv.GoToolPath(t), "tool", "asm", "-o", filepath.Join(dir, "x.o"), "-S", tmpfile)
cmd.Env = append(os.Environ(), "GOARCH=riscv64", "GOOS=linux")
out, err := cmd.CombinedOutput()
if err != nil {
t.Errorf("Failed to assemble: %v\n%s", err, out)
}
// The expected instruction sequence for the long call is:
// AUIPC $0, $0, X31
// JALR X.., X31
want := regexp.MustCompile(`\sAUIPC\s\$0, \$0, X31.*\n.*\sJALR\sX.*, X31`)
if !want.Match(out) {
t.Error("Missing assembly instructions")
}
// Build generated files.
cmd = testenv.Command(t, testenv.GoToolPath(t), "build", "-o", "x.exe", "-ldflags=-linkmode=internal")
cmd.Dir = dir
cmd.Env = append(os.Environ(), "GOARCH=riscv64", "GOOS=linux")
out, err = cmd.CombinedOutput()
if err != nil {
t.Errorf("Build failed: %v, output: %s", err, out)
}
if runtime.GOARCH == "riscv64" && runtime.GOOS == "linux" {
cmd = testenv.Command(t, filepath.Join(dir, "x.exe"))
out, err = cmd.CombinedOutput()
if err != nil {
t.Errorf("Failed to run test binary: %v", err)
}
if string(out) != "2" {
t.Errorf(`Got test output %q, want "2"`, string(out))
}
}
if runtime.GOARCH == "riscv64" && testenv.HasCGO() {
cmd := testenv.Command(t, testenv.GoToolPath(t), "build", "-o", "x.exe", "-ldflags=-linkmode=external")
cmd.Dir = dir
cmd.Env = append(os.Environ(), "GOARCH=riscv64", "GOOS=linux")
out, err := cmd.CombinedOutput()
@ -97,38 +190,44 @@ func y()
t.Errorf("Build failed: %v, output: %s", err, out) t.Errorf("Build failed: %v, output: %s", err, out)
} }
if runtime.GOARCH == "riscv64" && testenv.HasCGO() { if runtime.GOARCH == "riscv64" && runtime.GOOS == "linux" {
cmd := testenv.Command(t, testenv.GoToolPath(t), "build", "-ldflags=-linkmode=external") cmd = testenv.Command(t, filepath.Join(dir, "x.exe"))
cmd.Dir = dir out, err = cmd.CombinedOutput()
cmd.Env = append(os.Environ(), "GOARCH=riscv64", "GOOS=linux")
out, err := cmd.CombinedOutput()
if err != nil { if err != nil {
t.Errorf("Build failed: %v, output: %s", err, out) t.Errorf("Failed to run test binary: %v", err)
}
if string(out) != "2" {
t.Errorf(`Got test output %q, want "2"`, string(out))
}
} }
} }
} }
func genLargeCall(buf *bytes.Buffer) { func genLargeCall(buf *bytes.Buffer) {
fmt.Fprintln(buf, "TEXT ·x(SB),0,$0-0") fmt.Fprintln(buf, "TEXT ·x(SB),0,$0-8")
fmt.Fprintln(buf, "MOV X0, X10")
fmt.Fprintln(buf, "CALL ·y(SB)") fmt.Fprintln(buf, "CALL ·y(SB)")
for i := 0; i < 1<<19; i++ { fmt.Fprintln(buf, "ADD $1, X10, X10")
fmt.Fprintln(buf, "MOV X10, r+0(FP)")
fmt.Fprintln(buf, "RET")
for i := 0; i < 1<<18; i++ {
// Use a non-compressible instruction.
fmt.Fprintln(buf, "ADD $0, X5, X0") fmt.Fprintln(buf, "ADD $0, X5, X0")
} }
fmt.Fprintln(buf, "ADD $1, X10, X10")
fmt.Fprintln(buf, "RET") fmt.Fprintln(buf, "RET")
fmt.Fprintln(buf, "TEXT ·y(SB),0,$0-0") fmt.Fprintln(buf, "TEXT ·y(SB),0,$0-0")
fmt.Fprintln(buf, "ADD $0, X5, X0") fmt.Fprintln(buf, "ADD $1, X10, X10")
fmt.Fprintln(buf, "RET") fmt.Fprintln(buf, "RET")
} }
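// For reference (not part of the CL): the x.s that the updated genLargeCall
// emits now has roughly this shape; x returns 2 because y increments X10 once
// and x increments it again before storing the result at r+0(FP):
//
//	TEXT ·x(SB),0,$0-8
//	MOV X0, X10
//	CALL ·y(SB)
//	ADD $1, X10, X10
//	MOV X10, r+0(FP)
//	RET
//	ADD $0, X5, X0   // repeated 1<<18 times as non-compressible padding
//	RET
//	TEXT ·y(SB),0,$0-0
//	ADD $1, X10, X10
//	RET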
// TestLargeJump generates a large jump (>1MB of text) with a JMP to the // TestLargeJump generates a large jump (>1MB of text) with a JMP to the
// end of the function, in order to ensure that it assembles correctly. // end of the function, in order to ensure that it assembles correctly.
// This requires the use of AUIPC+JALR instruction sequences.
func TestLargeJump(t *testing.T) { func TestLargeJump(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("Skipping test in short mode") t.Skip("Skipping test in short mode")
} }
if runtime.GOARCH != "riscv64" {
t.Skip("Require riscv64 to run")
}
testenv.MustHaveGoBuild(t) testenv.MustHaveGoBuild(t)
dir := t.TempDir() dir := t.TempDir()
@ -154,32 +253,58 @@ func x() uint64
buf := bytes.NewBuffer(make([]byte, 0, 7000000)) buf := bytes.NewBuffer(make([]byte, 0, 7000000))
genLargeJump(buf) genLargeJump(buf)
if err := os.WriteFile(filepath.Join(dir, "x.s"), buf.Bytes(), 0644); err != nil { tmpfile := filepath.Join(dir, "x.s")
if err := os.WriteFile(tmpfile, buf.Bytes(), 0644); err != nil {
t.Fatalf("Failed to write file: %v\n", err) t.Fatalf("Failed to write file: %v\n", err)
} }
// Build generated files. // Assemble generated file.
cmd := testenv.Command(t, testenv.GoToolPath(t), "build", "-o", "x.exe") cmd := exec.Command(testenv.GoToolPath(t), "tool", "asm", "-o", filepath.Join(dir, "x.o"), "-S", tmpfile)
cmd.Dir = dir cmd.Env = append(os.Environ(), "GOARCH=riscv64", "GOOS=linux")
out, err := cmd.CombinedOutput() out, err := cmd.CombinedOutput()
if err != nil {
t.Errorf("Failed to assemble: %v\n%s", err, out)
}
// The expected instruction sequence for the long jump is:
// AUIPC $..., X31
// JALR X0, $.., X31
want := regexp.MustCompile(`\sAUIPC\s\$\d+, X31.*\n.*\sJALR\sX0, \$\d+, ?X31`)
if !want.Match(out) {
t.Error("Missing assembly instructions")
t.Errorf("%s", out)
}
// Build generated files.
cmd = testenv.Command(t, testenv.GoToolPath(t), "build", "-o", "x.exe")
cmd.Dir = dir
cmd.Env = append(os.Environ(), "GOARCH=riscv64", "GOOS=linux")
out, err = cmd.CombinedOutput()
if err != nil { if err != nil {
t.Errorf("Build failed: %v, output: %s", err, out) t.Errorf("Build failed: %v, output: %s", err, out)
} }
if runtime.GOARCH == "riscv64" && runtime.GOOS == "linux" {
cmd = testenv.Command(t, filepath.Join(dir, "x.exe")) cmd = testenv.Command(t, filepath.Join(dir, "x.exe"))
out, err = cmd.CombinedOutput() out, err = cmd.CombinedOutput()
if err != nil {
t.Errorf("Failed to run test binary: %v", err)
}
if string(out) != "1" { if string(out) != "1" {
t.Errorf(`Got test output %q, want "1"`, string(out)) t.Errorf(`Got test output %q, want "1"`, string(out))
} }
} }
}
func genLargeJump(buf *bytes.Buffer) { func genLargeJump(buf *bytes.Buffer) {
fmt.Fprintln(buf, "TEXT ·x(SB),0,$0-8") fmt.Fprintln(buf, "TEXT ·x(SB),0,$0-8")
fmt.Fprintln(buf, "MOV X0, X10") fmt.Fprintln(buf, "MOV X0, X10")
fmt.Fprintln(buf, "JMP end") fmt.Fprintln(buf, "JMP end")
 	for i := 0; i < 1<<18; i++ {
-		fmt.Fprintln(buf, "ADD $1, X10, X10")
+		// Use a non-compressible instruction.
+		fmt.Fprintln(buf, "ADD $0, X5, X0")
 	}
+	fmt.Fprintln(buf, "ADD $1, X10, X10")
 	fmt.Fprintln(buf, "end:")
 	fmt.Fprintln(buf, "ADD $1, X10, X10")
 	fmt.Fprintln(buf, "MOV X10, r+0(FP)")


@ -95,17 +95,11 @@ func (versionFlag) Set(s string) error {
p := "" p := ""
if s == "goexperiment" {
// test/run.go uses this to discover the full set of
// experiment tags. Report everything.
p = " X:" + strings.Join(buildcfg.Experiment.All(), ",")
} else {
// If the enabled experiments differ from the baseline, // If the enabled experiments differ from the baseline,
// include that difference. // include that difference.
if goexperiment := buildcfg.Experiment.String(); goexperiment != "" { if goexperiment := buildcfg.Experiment.String(); goexperiment != "" {
p = " X:" + goexperiment p = " X:" + goexperiment
} }
}
// The go command invokes -V=full to get a unique identifier // The go command invokes -V=full to get a unique identifier
// for this tool. It is assumed that the release version is sufficient // for this tool. It is assumed that the release version is sufficient


@ -334,6 +334,10 @@ const (
 	R_LOONG64_ADDR_HI
 	R_LOONG64_ADDR_LO
+	// R_LOONG64_ADDR_PCREL20_S2 resolves to the 22-bit, 4-byte aligned offset of an
+	// external address, by encoding it into a PCADDI instruction.
+	R_LOONG64_ADDR_PCREL20_S2
 	// R_LOONG64_TLS_LE_HI resolves to the high 20 bits of a TLS address (offset from
 	// thread pointer), by encoding it into the instruction.
 	// R_LOONG64_TLS_LE_LO resolves to the low 12 bits of a TLS address (offset from
@ -341,10 +345,14 @@ const (
 	R_LOONG64_TLS_LE_HI
 	R_LOONG64_TLS_LE_LO
-	// R_CALLLOONG64 resolves to non-PC-relative target address of a CALL (BL/JIRL)
-	// instruction, by encoding the address into the instruction.
+	// R_CALLLOONG64 resolves to the 28-bit 4-byte aligned PC-relative target
+	// address of a BL instruction, by encoding it into the instruction.
 	R_CALLLOONG64
+	// R_LOONG64_CALL36 resolves to the 38-bit 4-byte aligned PC-relative target
+	// address of a PCADDU18I + JIRL pair, by encoding it into the instructions.
+	R_LOONG64_CALL36
 	// R_LOONG64_TLS_IE_HI and R_LOONG64_TLS_IE_LO relocates a pcalau12i, ld.d
 	// pair to compute the address of the GOT slot of the tls symbol.
 	R_LOONG64_TLS_IE_HI
@ -360,14 +368,17 @@ const (
 	// 64-bit in-place subtraction.
 	R_LOONG64_SUB64
-	// R_JMP16LOONG64 resolves to 18-bit PC-relative target address of a JMP instructions.
+	// R_JMP16LOONG64 resolves to the 18-bit 4-byte aligned PC-relative target
+	// address of a BEQ/BNE/BLT/BGE/BLTU/BGEU instruction, by encoding it into
+	// the instruction.
 	R_JMP16LOONG64
-	// R_JMP21LOONG64 resolves to 23-bit PC-relative target address of a JMP instructions.
+	// R_JMP21LOONG64 resolves to the 23-bit 4-byte aligned PC-relative target
+	// address of a BEQZ/BNEZ instruction, by encoding it into the instruction.
 	R_JMP21LOONG64
-	// R_JMPLOONG64 resolves to non-PC-relative target address of a JMP instruction,
-	// by encoding the address into the instruction.
+	// R_JMPLOONG64 resolves to the 28-bit 4-byte aligned PC-relative target
+	// address of a B instruction, by encoding it into the instruction.
 	R_JMPLOONG64
 	// R_ADDRMIPSU (only used on mips/mips64) resolves to the sign-adjusted "upper" 16
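// Illustration only (hypothetical helper, not part of the patch): the bit
// widths quoted above are the total reach including the two implied zero
// bits from 4-byte alignment, so a range check looks like this:
func fitsLoongBranch(offset int64, widthBits uint) bool {
	if offset&3 != 0 {
		return false // branch targets must be 4-byte aligned
	}
	limit := int64(1) << (widthBits - 1)
	return offset >= -limit && offset < limit
}

// e.g. widthBits=18 for R_JMP16LOONG64 (BEQ/BNE/...),
//      widthBits=28 for R_CALLLOONG64 (BL) and R_JMPLOONG64 (B),
//      widthBits=38 for R_LOONG64_CALL36 (PCADDU18I+JIRL).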


@ -86,34 +86,36 @@ func _() {
_ = x[R_PCRELDBL-76] _ = x[R_PCRELDBL-76]
_ = x[R_LOONG64_ADDR_HI-77] _ = x[R_LOONG64_ADDR_HI-77]
_ = x[R_LOONG64_ADDR_LO-78] _ = x[R_LOONG64_ADDR_LO-78]
_ = x[R_LOONG64_TLS_LE_HI-79] _ = x[R_LOONG64_ADDR_PCREL20_S2-79]
_ = x[R_LOONG64_TLS_LE_LO-80] _ = x[R_LOONG64_TLS_LE_HI-80]
_ = x[R_CALLLOONG64-81] _ = x[R_LOONG64_TLS_LE_LO-81]
_ = x[R_LOONG64_TLS_IE_HI-82] _ = x[R_CALLLOONG64-82]
_ = x[R_LOONG64_TLS_IE_LO-83] _ = x[R_LOONG64_CALL36-83]
_ = x[R_LOONG64_GOT_HI-84] _ = x[R_LOONG64_TLS_IE_HI-84]
_ = x[R_LOONG64_GOT_LO-85] _ = x[R_LOONG64_TLS_IE_LO-85]
_ = x[R_LOONG64_ADD64-86] _ = x[R_LOONG64_GOT_HI-86]
_ = x[R_LOONG64_SUB64-87] _ = x[R_LOONG64_GOT_LO-87]
_ = x[R_JMP16LOONG64-88] _ = x[R_LOONG64_ADD64-88]
_ = x[R_JMP21LOONG64-89] _ = x[R_LOONG64_SUB64-89]
_ = x[R_JMPLOONG64-90] _ = x[R_JMP16LOONG64-90]
_ = x[R_ADDRMIPSU-91] _ = x[R_JMP21LOONG64-91]
_ = x[R_ADDRMIPSTLS-92] _ = x[R_JMPLOONG64-92]
_ = x[R_ADDRCUOFF-93] _ = x[R_ADDRMIPSU-93]
_ = x[R_WASMIMPORT-94] _ = x[R_ADDRMIPSTLS-94]
_ = x[R_XCOFFREF-95] _ = x[R_ADDRCUOFF-95]
_ = x[R_PEIMAGEOFF-96] _ = x[R_WASMIMPORT-96]
_ = x[R_INITORDER-97] _ = x[R_XCOFFREF-97]
_ = x[R_DWTXTADDR_U1-98] _ = x[R_PEIMAGEOFF-98]
_ = x[R_DWTXTADDR_U2-99] _ = x[R_INITORDER-99]
_ = x[R_DWTXTADDR_U3-100] _ = x[R_DWTXTADDR_U1-100]
_ = x[R_DWTXTADDR_U4-101] _ = x[R_DWTXTADDR_U2-101]
_ = x[R_DWTXTADDR_U3-102]
_ = x[R_DWTXTADDR_U4-103]
} }
const _RelocType_name = "R_ADDRR_ADDRPOWERR_ADDRARM64R_ADDRMIPSR_ADDROFFR_SIZER_CALLR_CALLARMR_CALLARM64R_CALLINDR_CALLPOWERR_CALLMIPSR_CONSTR_PCRELR_TLS_LER_TLS_IER_GOTOFFR_PLT0R_PLT1R_PLT2R_USEFIELDR_USETYPER_USEIFACER_USEIFACEMETHODR_USENAMEDMETHODR_METHODOFFR_KEEPR_POWER_TOCR_GOTPCRELR_JMPMIPSR_DWARFSECREFR_ARM64_TLS_LER_ARM64_TLS_IER_ARM64_GOTPCRELR_ARM64_GOTR_ARM64_PCRELR_ARM64_PCREL_LDST8R_ARM64_PCREL_LDST16R_ARM64_PCREL_LDST32R_ARM64_PCREL_LDST64R_ARM64_LDST8R_ARM64_LDST16R_ARM64_LDST32R_ARM64_LDST64R_ARM64_LDST128R_POWER_TLS_LER_POWER_TLS_IER_POWER_TLSR_POWER_TLS_IE_PCREL34R_POWER_TLS_LE_TPREL34R_ADDRPOWER_DSR_ADDRPOWER_GOTR_ADDRPOWER_GOT_PCREL34R_ADDRPOWER_PCRELR_ADDRPOWER_TOCRELR_ADDRPOWER_TOCREL_DSR_ADDRPOWER_D34R_ADDRPOWER_PCREL34R_RISCV_JALR_RISCV_JAL_TRAMPR_RISCV_CALLR_RISCV_PCREL_ITYPER_RISCV_PCREL_STYPER_RISCV_TLS_IER_RISCV_TLS_LER_RISCV_GOT_HI20R_RISCV_GOT_PCREL_ITYPER_RISCV_PCREL_HI20R_RISCV_PCREL_LO12_IR_RISCV_PCREL_LO12_SR_RISCV_BRANCHR_RISCV_ADD32R_RISCV_SUB32R_RISCV_RVC_BRANCHR_RISCV_RVC_JUMPR_PCRELDBLR_LOONG64_ADDR_HIR_LOONG64_ADDR_LOR_LOONG64_TLS_LE_HIR_LOONG64_TLS_LE_LOR_CALLLOONG64R_LOONG64_TLS_IE_HIR_LOONG64_TLS_IE_LOR_LOONG64_GOT_HIR_LOONG64_GOT_LOR_LOONG64_ADD64R_LOONG64_SUB64R_JMP16LOONG64R_JMP21LOONG64R_JMPLOONG64R_ADDRMIPSUR_ADDRMIPSTLSR_ADDRCUOFFR_WASMIMPORTR_XCOFFREFR_PEIMAGEOFFR_INITORDERR_DWTXTADDR_U1R_DWTXTADDR_U2R_DWTXTADDR_U3R_DWTXTADDR_U4" const _RelocType_name = "R_ADDRR_ADDRPOWERR_ADDRARM64R_ADDRMIPSR_ADDROFFR_SIZER_CALLR_CALLARMR_CALLARM64R_CALLINDR_CALLPOWERR_CALLMIPSR_CONSTR_PCRELR_TLS_LER_TLS_IER_GOTOFFR_PLT0R_PLT1R_PLT2R_USEFIELDR_USETYPER_USEIFACER_USEIFACEMETHODR_USENAMEDMETHODR_METHODOFFR_KEEPR_POWER_TOCR_GOTPCRELR_JMPMIPSR_DWARFSECREFR_ARM64_TLS_LER_ARM64_TLS_IER_ARM64_GOTPCRELR_ARM64_GOTR_ARM64_PCRELR_ARM64_PCREL_LDST8R_ARM64_PCREL_LDST16R_ARM64_PCREL_LDST32R_ARM64_PCREL_LDST64R_ARM64_LDST8R_ARM64_LDST16R_ARM64_LDST32R_ARM64_LDST64R_ARM64_LDST128R_POWER_TLS_LER_POWER_TLS_IER_POWER_TLSR_POWER_TLS_IE_PCREL34R_POWER_TLS_LE_TPREL34R_ADDRPOWER_DSR_ADDRPOWER_GOTR_ADDRPOWER_GOT_PCREL34R_ADDRPOWER_PCRELR_ADDRPOWER_TOCRELR_ADDRPOWER_TOCREL_DSR_ADDRPOWER_D34R_ADDRPOWER_PCREL34R_RISCV_JALR_RISCV_JAL_TRAMPR_RISCV_CALLR_RISCV_PCREL_ITYPER_RISCV_PCREL_STYPER_RISCV_TLS_IER_RISCV_TLS_LER_RISCV_GOT_HI20R_RISCV_GOT_PCREL_ITYPER_RISCV_PCREL_HI20R_RISCV_PCREL_LO12_IR_RISCV_PCREL_LO12_SR_RISCV_BRANCHR_RISCV_ADD32R_RISCV_SUB32R_RISCV_RVC_BRANCHR_RISCV_RVC_JUMPR_PCRELDBLR_LOONG64_ADDR_HIR_LOONG64_ADDR_LOR_LOONG64_ADDR_PCREL20_S2R_LOONG64_TLS_LE_HIR_LOONG64_TLS_LE_LOR_CALLLOONG64R_LOONG64_CALL36R_LOONG64_TLS_IE_HIR_LOONG64_TLS_IE_LOR_LOONG64_GOT_HIR_LOONG64_GOT_LOR_LOONG64_ADD64R_LOONG64_SUB64R_JMP16LOONG64R_JMP21LOONG64R_JMPLOONG64R_ADDRMIPSUR_ADDRMIPSTLSR_ADDRCUOFFR_WASMIMPORTR_XCOFFREFR_PEIMAGEOFFR_INITORDERR_DWTXTADDR_U1R_DWTXTADDR_U2R_DWTXTADDR_U3R_DWTXTADDR_U4"
var _RelocType_index = [...]uint16{0, 6, 17, 28, 38, 47, 53, 59, 68, 79, 88, 99, 109, 116, 123, 131, 139, 147, 153, 159, 165, 175, 184, 194, 210, 226, 237, 243, 254, 264, 273, 286, 300, 314, 330, 341, 354, 373, 393, 413, 433, 446, 460, 474, 488, 503, 517, 531, 542, 564, 586, 600, 615, 638, 655, 673, 694, 709, 728, 739, 756, 768, 787, 806, 820, 834, 850, 873, 891, 911, 931, 945, 958, 971, 989, 1005, 1015, 1032, 1049, 1068, 1087, 1100, 1119, 1138, 1154, 1170, 1185, 1200, 1214, 1228, 1240, 1251, 1264, 1275, 1287, 1297, 1309, 1320, 1334, 1348, 1362, 1376} var _RelocType_index = [...]uint16{0, 6, 17, 28, 38, 47, 53, 59, 68, 79, 88, 99, 109, 116, 123, 131, 139, 147, 153, 159, 165, 175, 184, 194, 210, 226, 237, 243, 254, 264, 273, 286, 300, 314, 330, 341, 354, 373, 393, 413, 433, 446, 460, 474, 488, 503, 517, 531, 542, 564, 586, 600, 615, 638, 655, 673, 694, 709, 728, 739, 756, 768, 787, 806, 820, 834, 850, 873, 891, 911, 931, 945, 958, 971, 989, 1005, 1015, 1032, 1049, 1074, 1093, 1112, 1125, 1141, 1160, 1179, 1195, 1211, 1226, 1241, 1255, 1269, 1281, 1292, 1305, 1316, 1328, 1338, 1350, 1361, 1375, 1389, 1403, 1417}
func (i RelocType) String() string { func (i RelocType) String() string {
i -= 1 i -= 1


@ -1065,13 +1065,15 @@ func relSize(arch *sys.Arch, pn string, elftype uint32) (uint8, uint8, error) {
 		LOONG64 | uint32(elf.R_LARCH_PCALA_LO12)<<16,
 		LOONG64 | uint32(elf.R_LARCH_GOT_PC_HI20)<<16,
 		LOONG64 | uint32(elf.R_LARCH_GOT_PC_LO12)<<16,
-		LOONG64 | uint32(elf.R_LARCH_32_PCREL)<<16:
+		LOONG64 | uint32(elf.R_LARCH_32_PCREL)<<16,
+		LOONG64 | uint32(elf.R_LARCH_PCREL20_S2)<<16:
 		return 4, 4, nil
 	case LOONG64 | uint32(elf.R_LARCH_64)<<16,
 		LOONG64 | uint32(elf.R_LARCH_ADD64)<<16,
 		LOONG64 | uint32(elf.R_LARCH_SUB64)<<16,
-		LOONG64 | uint32(elf.R_LARCH_64_PCREL)<<16:
+		LOONG64 | uint32(elf.R_LARCH_64_PCREL)<<16,
+		LOONG64 | uint32(elf.R_LARCH_CALL36)<<16:
 		return 8, 8, nil
 	case S390X | uint32(elf.R_390_8)<<16:


@ -85,7 +85,8 @@ func adddynrel(target *ld.Target, ldr *loader.Loader, syms *ld.ArchSyms, s loade
} }
return true return true
case objabi.ElfRelocOffset + objabi.RelocType(elf.R_LARCH_B26): case objabi.ElfRelocOffset + objabi.RelocType(elf.R_LARCH_B26),
objabi.ElfRelocOffset + objabi.RelocType(elf.R_LARCH_CALL36):
if targType == sym.SDYNIMPORT { if targType == sym.SDYNIMPORT {
addpltsym(target, ldr, syms, targ) addpltsym(target, ldr, syms, targ)
su := ldr.MakeSymbolUpdater(s) su := ldr.MakeSymbolUpdater(s)
@ -95,8 +96,12 @@ func adddynrel(target *ld.Target, ldr *loader.Loader, syms *ld.ArchSyms, s loade
if targType == 0 || targType == sym.SXREF { if targType == 0 || targType == sym.SXREF {
ldr.Errorf(s, "unknown symbol %s in callloong64", ldr.SymName(targ)) ldr.Errorf(s, "unknown symbol %s in callloong64", ldr.SymName(targ))
} }
relocType := objabi.R_CALLLOONG64
if r.Type() == objabi.ElfRelocOffset+objabi.RelocType(elf.R_LARCH_CALL36) {
relocType = objabi.R_LOONG64_CALL36
}
su := ldr.MakeSymbolUpdater(s) su := ldr.MakeSymbolUpdater(s)
su.SetRelocType(rIdx, objabi.R_CALLLOONG64) su.SetRelocType(rIdx, relocType)
return true return true
case objabi.ElfRelocOffset + objabi.RelocType(elf.R_LARCH_GOT_PC_HI20), case objabi.ElfRelocOffset + objabi.RelocType(elf.R_LARCH_GOT_PC_HI20),
@ -117,7 +122,8 @@ func adddynrel(target *ld.Target, ldr *loader.Loader, syms *ld.ArchSyms, s loade
return true return true
case objabi.ElfRelocOffset + objabi.RelocType(elf.R_LARCH_PCALA_HI20), case objabi.ElfRelocOffset + objabi.RelocType(elf.R_LARCH_PCALA_HI20),
objabi.ElfRelocOffset + objabi.RelocType(elf.R_LARCH_PCALA_LO12): objabi.ElfRelocOffset + objabi.RelocType(elf.R_LARCH_PCALA_LO12),
objabi.ElfRelocOffset + objabi.RelocType(elf.R_LARCH_PCREL20_S2):
if targType == sym.SDYNIMPORT { if targType == sym.SDYNIMPORT {
ldr.Errorf(s, "unexpected relocation for dynamic symbol %s", ldr.SymName(targ)) ldr.Errorf(s, "unexpected relocation for dynamic symbol %s", ldr.SymName(targ))
} }
@ -125,12 +131,17 @@ func adddynrel(target *ld.Target, ldr *loader.Loader, syms *ld.ArchSyms, s loade
ldr.Errorf(s, "unknown symbol %s", ldr.SymName(targ)) ldr.Errorf(s, "unknown symbol %s", ldr.SymName(targ))
} }
su := ldr.MakeSymbolUpdater(s) var relocType objabi.RelocType
if r.Type() == objabi.ElfRelocOffset+objabi.RelocType(elf.R_LARCH_PCALA_HI20) { switch r.Type() - objabi.ElfRelocOffset {
su.SetRelocType(rIdx, objabi.R_LOONG64_ADDR_HI) case objabi.RelocType(elf.R_LARCH_PCALA_HI20):
} else { relocType = objabi.R_LOONG64_ADDR_HI
su.SetRelocType(rIdx, objabi.R_LOONG64_ADDR_LO) case objabi.RelocType(elf.R_LARCH_PCALA_LO12):
relocType = objabi.R_LOONG64_ADDR_LO
case objabi.RelocType(elf.R_LARCH_PCREL20_S2):
relocType = objabi.R_LOONG64_ADDR_PCREL20_S2
} }
su := ldr.MakeSymbolUpdater(s)
su.SetRelocType(rIdx, relocType)
return true return true
case objabi.ElfRelocOffset + objabi.RelocType(elf.R_LARCH_ADD64), case objabi.ElfRelocOffset + objabi.RelocType(elf.R_LARCH_ADD64),
@ -418,6 +429,11 @@ func elfreloc1(ctxt *ld.Link, out *ld.OutBuf, ldr *loader.Loader, s loader.Sym,
out.Write64(uint64(elf.R_LARCH_B26) | uint64(elfsym)<<32) out.Write64(uint64(elf.R_LARCH_B26) | uint64(elfsym)<<32)
out.Write64(uint64(r.Xadd)) out.Write64(uint64(r.Xadd))
case objabi.R_LOONG64_CALL36:
out.Write64(uint64(sectoff))
out.Write64(uint64(elf.R_LARCH_CALL36) | uint64(elfsym)<<32)
out.Write64(uint64(r.Xadd))
case objabi.R_LOONG64_TLS_IE_HI: case objabi.R_LOONG64_TLS_IE_HI:
out.Write64(uint64(sectoff)) out.Write64(uint64(sectoff))
out.Write64(uint64(elf.R_LARCH_TLS_IE_PC_HI20) | uint64(elfsym)<<32) out.Write64(uint64(elf.R_LARCH_TLS_IE_PC_HI20) | uint64(elfsym)<<32)
@ -438,6 +454,11 @@ func elfreloc1(ctxt *ld.Link, out *ld.OutBuf, ldr *loader.Loader, s loader.Sym,
out.Write64(uint64(elf.R_LARCH_PCALA_HI20) | uint64(elfsym)<<32) out.Write64(uint64(elf.R_LARCH_PCALA_HI20) | uint64(elfsym)<<32)
out.Write64(uint64(r.Xadd)) out.Write64(uint64(r.Xadd))
case objabi.R_LOONG64_ADDR_PCREL20_S2:
out.Write64(uint64(sectoff))
out.Write64(uint64(elf.R_LARCH_PCREL20_S2) | uint64(elfsym)<<32)
out.Write64(uint64(r.Xadd))
case objabi.R_LOONG64_GOT_HI: case objabi.R_LOONG64_GOT_HI:
out.Write64(uint64(sectoff)) out.Write64(uint64(sectoff))
out.Write64(uint64(elf.R_LARCH_GOT_PC_HI20) | uint64(elfsym)<<32) out.Write64(uint64(elf.R_LARCH_GOT_PC_HI20) | uint64(elfsym)<<32)
@ -463,7 +484,8 @@ func archreloc(target *ld.Target, ldr *loader.Loader, syms *ld.ArchSyms, r loade
default: default:
return val, 0, false return val, 0, false
case objabi.R_LOONG64_ADDR_HI, case objabi.R_LOONG64_ADDR_HI,
objabi.R_LOONG64_ADDR_LO: objabi.R_LOONG64_ADDR_LO,
objabi.R_LOONG64_ADDR_PCREL20_S2:
// set up addend for eventual relocation via outer symbol. // set up addend for eventual relocation via outer symbol.
rs, _ := ld.FoldSubSymbolOffset(ldr, rs) rs, _ := ld.FoldSubSymbolOffset(ldr, rs)
rst := ldr.SymType(rs) rst := ldr.SymType(rs)
@ -474,6 +496,7 @@ func archreloc(target *ld.Target, ldr *loader.Loader, syms *ld.ArchSyms, r loade
case objabi.R_LOONG64_TLS_LE_HI, case objabi.R_LOONG64_TLS_LE_HI,
objabi.R_LOONG64_TLS_LE_LO, objabi.R_LOONG64_TLS_LE_LO,
objabi.R_CALLLOONG64, objabi.R_CALLLOONG64,
objabi.R_LOONG64_CALL36,
objabi.R_JMPLOONG64, objabi.R_JMPLOONG64,
objabi.R_LOONG64_TLS_IE_HI, objabi.R_LOONG64_TLS_IE_HI,
objabi.R_LOONG64_TLS_IE_LO, objabi.R_LOONG64_TLS_IE_LO,
@ -499,6 +522,10 @@ func archreloc(target *ld.Target, ldr *loader.Loader, syms *ld.ArchSyms, r loade
return val&0xffc003ff | (t << 10), noExtReloc, isOk return val&0xffc003ff | (t << 10), noExtReloc, isOk
} }
return val&0xfe00001f | (t << 5), noExtReloc, isOk return val&0xfe00001f | (t << 5), noExtReloc, isOk
case objabi.R_LOONG64_ADDR_PCREL20_S2:
pc := ldr.SymValue(s) + int64(r.Off())
t := (ldr.SymAddr(rs) + r.Add() - pc) >> 2
return val&0xfe00001f | ((t & 0xfffff) << 5), noExtReloc, isOk
case objabi.R_LOONG64_TLS_LE_HI, case objabi.R_LOONG64_TLS_LE_HI,
objabi.R_LOONG64_TLS_LE_LO: objabi.R_LOONG64_TLS_LE_LO:
t := ldr.SymAddr(rs) + r.Add() t := ldr.SymAddr(rs) + r.Add()
@ -512,6 +539,14 @@ func archreloc(target *ld.Target, ldr *loader.Loader, syms *ld.ArchSyms, r loade
t := ldr.SymAddr(rs) + r.Add() - pc t := ldr.SymAddr(rs) + r.Add() - pc
return val&0xfc000000 | (((t >> 2) & 0xffff) << 10) | (((t >> 2) & 0x3ff0000) >> 16), noExtReloc, isOk return val&0xfc000000 | (((t >> 2) & 0xffff) << 10) | (((t >> 2) & 0x3ff0000) >> 16), noExtReloc, isOk
case objabi.R_LOONG64_CALL36:
pc := ldr.SymValue(s) + int64(r.Off())
t := (ldr.SymAddr(rs) + r.Add() - pc) >> 2
// val is pcaddu18i (lower half) + jirl (upper half)
pcaddu18i := (val & 0xfe00001f) | (((t + 0x8000) >> 16) << 5)
jirl := ((val >> 32) & 0xfc0003ff) | ((t & 0xffff) << 10)
return pcaddu18i | (jirl << 32), noExtReloc, isOk
case objabi.R_JMP16LOONG64, case objabi.R_JMP16LOONG64,
objabi.R_JMP21LOONG64: objabi.R_JMP21LOONG64:
pc := ldr.SymValue(s) + int64(r.Off()) pc := ldr.SymValue(s) + int64(r.Off())
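// Hypothetical cross-check (not part of the CL) of the CALL36 arithmetic
// used above: PCADDU18I contributes hi20<<18 bytes and JIRL adds a
// sign-extended 16-bit word offset, so the two fields reassemble the
// original 4-byte aligned byte offset.
func call36Parts(byteOff int64) (hi20, lo16 int64) {
	t := byteOff >> 2               // word offset, as in archreloc
	hi20 = (t + 0x8000) >> 16       // rounded upper 20 bits
	lo16 = int64(int16(t & 0xffff)) // sign-extended low 16 bits
	return hi20, lo16               // hi20<<18 + lo16<<2 == byteOff
}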


@ -21,6 +21,7 @@ import (
"golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/go/types/typeutil"
"golang.org/x/tools/internal/analysis/analyzerutil" "golang.org/x/tools/internal/analysis/analyzerutil"
typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
"golang.org/x/tools/internal/astutil" "golang.org/x/tools/internal/astutil"
"golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/diff"
"golang.org/x/tools/internal/moreiters" "golang.org/x/tools/internal/moreiters"
@ -28,6 +29,7 @@ import (
"golang.org/x/tools/internal/refactor" "golang.org/x/tools/internal/refactor"
"golang.org/x/tools/internal/refactor/inline" "golang.org/x/tools/internal/refactor/inline"
"golang.org/x/tools/internal/typesinternal" "golang.org/x/tools/internal/typesinternal"
"golang.org/x/tools/internal/typesinternal/typeindex"
) )
//go:embed doc.go //go:embed doc.go
@ -43,20 +45,29 @@ var Analyzer = &analysis.Analyzer{
(*goFixInlineConstFact)(nil), (*goFixInlineConstFact)(nil),
(*goFixInlineAliasFact)(nil), (*goFixInlineAliasFact)(nil),
}, },
Requires: []*analysis.Analyzer{inspect.Analyzer}, Requires: []*analysis.Analyzer{
inspect.Analyzer,
typeindexanalyzer.Analyzer,
},
} }
var allowBindingDecl bool var (
allowBindingDecl bool
lazyEdits bool
)
func init() { func init() {
Analyzer.Flags.BoolVar(&allowBindingDecl, "allow_binding_decl", false, Analyzer.Flags.BoolVar(&allowBindingDecl, "allow_binding_decl", false,
"permit inlinings that require a 'var params = args' declaration") "permit inlinings that require a 'var params = args' declaration")
Analyzer.Flags.BoolVar(&lazyEdits, "lazy_edits", false,
"compute edits lazily (only meaningful to gopls driver)")
} }
// analyzer holds the state for this analysis. // analyzer holds the state for this analysis.
type analyzer struct { type analyzer struct {
pass *analysis.Pass pass *analysis.Pass
root inspector.Cursor root inspector.Cursor
index *typeindex.Index
// memoization of repeated calls for same file. // memoization of repeated calls for same file.
fileContent map[string][]byte fileContent map[string][]byte
// memoization of fact imports (nil => no fact) // memoization of fact imports (nil => no fact)
@ -69,6 +80,7 @@ func run(pass *analysis.Pass) (any, error) {
a := &analyzer{ a := &analyzer{
pass: pass, pass: pass,
root: pass.ResultOf[inspect.Analyzer].(*inspector.Inspector).Root(), root: pass.ResultOf[inspect.Analyzer].(*inspector.Inspector).Root(),
index: pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index),
fileContent: make(map[string][]byte), fileContent: make(map[string][]byte),
inlinableFuncs: make(map[*types.Func]*inline.Callee), inlinableFuncs: make(map[*types.Func]*inline.Callee),
inlinableConsts: make(map[*types.Const]*goFixInlineConstFact), inlinableConsts: make(map[*types.Const]*goFixInlineConstFact),
@ -170,6 +182,27 @@ func (a *analyzer) inlineCall(call *ast.CallExpr, cur inspector.Cursor) {
return // don't inline a function from within its own test return // don't inline a function from within its own test
} }
// Compute the edits.
//
// Ordinarily the analyzer reports a fix containing
// edits. However, the algorithm is somewhat expensive
// (unnecessarily so: see go.dev/issue/75773) so
// to reduce costs in gopls, we omit the edits,
// meaning that gopls must compute them on demand
// (based on the Diagnostic.Category) when they are
// requested via a code action.
//
// This does mean that the following categories of
// caller-dependent obstacles to inlining will be
// reported when the gopls user requests the fix,
// rather than by quietly suppressing the diagnostic:
// - shadowing problems
// - callee imports inaccessible "internal" packages
// - callee refers to nonexported symbols
// - callee uses too-new Go features
// - inlining call from a cgo file
var edits []analysis.TextEdit
if !lazyEdits {
// Inline the call. // Inline the call.
content, err := a.readFile(call) content, err := a.readFile(call)
if err != nil { if err != nil {
@ -184,6 +217,9 @@ func (a *analyzer) inlineCall(call *ast.CallExpr, cur inspector.Cursor) {
File: curFile, File: curFile,
Call: call, Call: call,
Content: content, Content: content,
CountUses: func(pkgname *types.PkgName) int {
return moreiters.Len(a.index.Uses(pkgname))
},
} }
res, err := inline.Inline(caller, callee, &inline.Options{Logf: discard}) res, err := inline.Inline(caller, callee, &inline.Options{Logf: discard})
if err != nil { if err != nil {
@ -211,22 +247,23 @@ func (a *analyzer) inlineCall(call *ast.CallExpr, cur inspector.Cursor) {
} }
got := res.Content got := res.Content
// Suggest the "fix".
var textEdits []analysis.TextEdit
for _, edit := range diff.Bytes(content, got) { for _, edit := range diff.Bytes(content, got) {
textEdits = append(textEdits, analysis.TextEdit{ edits = append(edits, analysis.TextEdit{
Pos: curFile.FileStart + token.Pos(edit.Start), Pos: curFile.FileStart + token.Pos(edit.Start),
End: curFile.FileStart + token.Pos(edit.End), End: curFile.FileStart + token.Pos(edit.End),
NewText: []byte(edit.New), NewText: []byte(edit.New),
}) })
} }
}
a.pass.Report(analysis.Diagnostic{ a.pass.Report(analysis.Diagnostic{
Pos: call.Pos(), Pos: call.Pos(),
End: call.End(), End: call.End(),
Message: fmt.Sprintf("Call of %v should be inlined", callee), Message: fmt.Sprintf("Call of %v should be inlined", callee),
Category: "inline_call", // keep consistent with gopls/internal/golang.fixInlineCall
SuggestedFixes: []analysis.SuggestedFix{{ SuggestedFixes: []analysis.SuggestedFix{{
Message: fmt.Sprintf("Inline call of %v", callee), Message: fmt.Sprintf("Inline call of %v", callee),
TextEdits: textEdits, TextEdits: edits, // within gopls, this is nil => compute fix's edits lazily
}}, }},
}) })
} }


@ -233,17 +233,20 @@ func mapsloop(pass *analysis.Pass) (any, error) {
 		assign := rng.Body.List[0].(*ast.AssignStmt)
 		if index, ok := assign.Lhs[0].(*ast.IndexExpr); ok &&
 			astutil.EqualSyntax(rng.Key, index.Index) &&
-			astutil.EqualSyntax(rng.Value, assign.Rhs[0]) &&
-			is[*types.Map](typeparams.CoreType(info.TypeOf(index.X))) &&
-			types.Identical(info.TypeOf(index), info.TypeOf(rng.Value)) { // m[k], v
+			astutil.EqualSyntax(rng.Value, assign.Rhs[0]) {
+			if tmap, ok := typeparams.CoreType(info.TypeOf(index.X)).(*types.Map); ok &&
+				types.Identical(info.TypeOf(index), info.TypeOf(rng.Value)) && // m[k], v
+				types.Identical(tmap.Key(), info.TypeOf(rng.Key)) {
 				// Have: for k, v := range x { m[k] = v }
-				// where there is no implicit conversion.
+				// where there is no implicit conversion
+				// of either key or value.
 				check(file, curRange, assign, index.X, rng.X)
+			}
 			}
 		}
 	}
 }
 	return nil, nil
 }


@ -489,7 +489,7 @@ func isNegativeConst(info *types.Info, expr ast.Expr) bool {
 	return false
 }
-// isNoneNegativeConst returns true if the expr is a const int with value >= zero.
+// isNonNegativeConst returns true if the expr is a const int with value >= zero.
 func isNonNegativeConst(info *types.Info, expr ast.Expr) bool {
 	if tv, ok := info.Types[expr]; ok && tv.Value != nil && tv.Value.Kind() == constant.Int {
 		if v, ok := constant.Int64Val(tv.Value); ok {


@ -92,6 +92,36 @@
// } // }
// logf("%s", 123) // logf format %s has arg 123 of wrong type int // logf("%s", 123) // logf format %s has arg 123 of wrong type int
// //
// Interface methods may also be analyzed as printf wrappers, if
// within the interface's package there is an assignment from an
// implementation type whose corresponding method is a printf wrapper.
//
// For example, the var declaration below causes a *myLoggerImpl value
// to be assigned to a Logger variable:
//
// type Logger interface {
// Logf(format string, args ...any)
// }
//
// type myLoggerImpl struct{ ... }
//
// var _ Logger = (*myLoggerImpl)(nil)
//
// func (*myLoggerImpl) Logf(format string, args ...any) {
// println(fmt.Sprintf(format, args...))
// }
//
// Since myLoggerImpl's Logf method is a printf wrapper, this
// establishes that Logger.Logf is a printf wrapper too, causing
// dynamic calls through the interface to be checked:
//
// func f(log Logger) {
// log.Logf("%s", 123) // Logger.Logf format %s has arg 123 of wrong type int
// }
//
// This feature applies only to interface methods declared in files
// using at least Go 1.26.
//
// # Specifying printf wrappers by flag // # Specifying printf wrappers by flag
// //
// The -funcs flag specifies a comma-separated list of names of // The -funcs flag specifies a comma-separated list of names of


@ -27,6 +27,7 @@ import (
"golang.org/x/tools/internal/typeparams" "golang.org/x/tools/internal/typeparams"
"golang.org/x/tools/internal/typesinternal" "golang.org/x/tools/internal/typesinternal"
"golang.org/x/tools/internal/versions" "golang.org/x/tools/internal/versions"
"golang.org/x/tools/refactor/satisfy"
) )
func init() { func init() {
@ -65,7 +66,7 @@ func (kind Kind) String() string {
 	case KindErrorf:
 		return "errorf"
 	}
-	return ""
+	return "(none)"
 }
// Result is the printf analyzer's result type. Clients may query the result // Result is the printf analyzer's result type. Clients may query the result
@ -138,7 +139,7 @@ type wrapper struct {
 type printfCaller struct {
 	w    *wrapper
-	call *ast.CallExpr
+	call *ast.CallExpr // forwarding call (nil for implicit interface method -> impl calls)
 }
// formatArgsParams returns the "format string" and "args ...any" // formatArgsParams returns the "format string" and "args ...any"
@ -183,60 +184,12 @@ func findPrintLike(pass *analysis.Pass, res *Result) {
wrappers []*wrapper wrappers []*wrapper
byObj = make(map[types.Object]*wrapper) byObj = make(map[types.Object]*wrapper)
) )
for cur := range inspect.Root().Preorder((*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)) { for cur := range inspect.Root().Preorder((*ast.FuncDecl)(nil), (*ast.FuncLit)(nil), (*ast.InterfaceType)(nil)) {
var (
curBody inspector.Cursor // for *ast.BlockStmt
sig *types.Signature
obj types.Object
)
switch f := cur.Node().(type) {
case *ast.FuncDecl:
// named function or method:
//
// func wrapf(format string, args ...any) {...}
if f.Body != nil {
curBody = cur.ChildAt(edge.FuncDecl_Body, -1)
obj = info.Defs[f.Name]
sig = obj.Type().(*types.Signature)
}
case *ast.FuncLit: // addWrapper records that a func (or var representing
// anonymous function directly assigned to a variable: // a FuncLit) is a potential print{,f} wrapper.
// // curBody is its *ast.BlockStmt, if any.
// var wrapf = func(format string, args ...any) {...} addWrapper := func(obj types.Object, sig *types.Signature, curBody inspector.Cursor) *wrapper {
// wrapf := func(format string, args ...any) {...}
// wrapf = func(format string, args ...any) {...}
//
// The LHS may also be a struct field x.wrapf or
// an imported var pkg.Wrapf.
//
sig = info.TypeOf(f).(*types.Signature)
curBody = cur.ChildAt(edge.FuncLit_Body, -1)
var lhs ast.Expr
switch ek, idx := cur.ParentEdge(); ek {
case edge.ValueSpec_Values:
curName := cur.Parent().ChildAt(edge.ValueSpec_Names, idx)
lhs = curName.Node().(*ast.Ident)
case edge.AssignStmt_Rhs:
curLhs := cur.Parent().ChildAt(edge.AssignStmt_Lhs, idx)
lhs = curLhs.Node().(ast.Expr)
}
switch lhs := lhs.(type) {
case *ast.Ident:
// variable: wrapf = func(...)
obj = info.ObjectOf(lhs).(*types.Var)
case *ast.SelectorExpr:
if sel, ok := info.Selections[lhs]; ok {
// struct field: x.wrapf = func(...)
obj = sel.Obj().(*types.Var)
} else {
// imported var: pkg.Wrapf = func(...)
obj = info.Uses[lhs.Sel].(*types.Var)
}
}
}
if obj != nil {
format, args := formatArgsParams(sig) format, args := formatArgsParams(sig)
if args != nil { if args != nil {
// obj (the symbol for a function/method, or variable // obj (the symbol for a function/method, or variable
@ -254,17 +207,124 @@ func findPrintLike(pass *analysis.Pass, res *Result) {
} }
byObj[w.obj] = w byObj[w.obj] = w
wrappers = append(wrappers, w) wrappers = append(wrappers, w)
return w
}
return nil
}
switch f := cur.Node().(type) {
case *ast.FuncDecl:
// named function or method:
//
// func wrapf(format string, args ...any) {...}
if f.Body != nil {
fn := info.Defs[f.Name].(*types.Func)
addWrapper(fn, fn.Signature(), cur.ChildAt(edge.FuncDecl_Body, -1))
}
case *ast.FuncLit:
// anonymous function directly assigned to a variable:
//
// var wrapf = func(format string, args ...any) {...}
// wrapf := func(format string, args ...any) {...}
// wrapf = func(format string, args ...any) {...}
//
// The LHS may also be a struct field x.wrapf or
// an imported var pkg.Wrapf.
//
var lhs ast.Expr
switch ek, idx := cur.ParentEdge(); ek {
case edge.ValueSpec_Values:
curName := cur.Parent().ChildAt(edge.ValueSpec_Names, idx)
lhs = curName.Node().(*ast.Ident)
case edge.AssignStmt_Rhs:
curLhs := cur.Parent().ChildAt(edge.AssignStmt_Lhs, idx)
lhs = curLhs.Node().(ast.Expr)
}
var v *types.Var
switch lhs := lhs.(type) {
case *ast.Ident:
// variable: wrapf = func(...)
v = info.ObjectOf(lhs).(*types.Var)
case *ast.SelectorExpr:
if sel, ok := info.Selections[lhs]; ok {
// struct field: x.wrapf = func(...)
v = sel.Obj().(*types.Var)
} else {
// imported var: pkg.Wrapf = func(...)
v = info.Uses[lhs.Sel].(*types.Var)
}
}
if v != nil {
sig := info.TypeOf(f).(*types.Signature)
curBody := cur.ChildAt(edge.FuncLit_Body, -1)
addWrapper(v, sig, curBody)
}
case *ast.InterfaceType:
// Induction through interface methods is gated as
// if it were a go1.26 language feature, to avoid
// surprises when go test's vet suite gets stricter.
if analyzerutil.FileUsesGoVersion(pass, astutil.EnclosingFile(cur), versions.Go1_26) {
for imeth := range info.TypeOf(f).(*types.Interface).Methods() {
addWrapper(imeth, imeth.Signature(), inspector.Cursor{})
}
} }
} }
} }
// impls maps abstract methods to implementations.
//
// Interface methods are modelled as if they have a body
// that calls each implementing method.
//
// In the code below, impls maps Logger.Logf to
// [myLogger.Logf], and if myLogger.Logf is discovered to be
// printf-like, then so will be Logger.Logf.
//
// type Logger interface {
// Logf(format string, args ...any)
// }
// type myLogger struct{ ... }
// func (myLogger) Logf(format string, args ...any) {...}
// var _ Logger = myLogger{}
impls := methodImplementations(pass)
// Pass 2: scan the body of each wrapper function // Pass 2: scan the body of each wrapper function
// for calls to other printf-like functions. // for calls to other printf-like functions.
//
// Also, reject tricky cases where the parameters
// are potentially mutated by AssignStmt or UnaryExpr.
// TODO: Relax these checks; issue 26555.
for _, w := range wrappers { for _, w := range wrappers {
// doCall records a call from one wrapper to another.
doCall := func(callee types.Object, call *ast.CallExpr) {
// Call from one wrapper candidate to another?
// Record the edge so that if callee is found to be
// a true wrapper, w will be too.
if w2, ok := byObj[callee]; ok {
w2.callers = append(w2.callers, printfCaller{w, call})
}
// Is the candidate a true wrapper, because it calls
// a known print{,f}-like function from the allowlist
// or an imported fact, or another wrapper found
// to be a true wrapper?
// If so, convert all w's callers to kind.
kind := callKind(pass, callee, res)
if kind != KindNone {
checkForward(pass, w, call, kind, res)
}
}
// An interface method has no body, but acts
// like an implicit call to each implementing method.
if w.curBody.Inspector() == nil {
for impl := range impls[w.obj.(*types.Func)] {
doCall(impl, nil)
}
continue // (no body)
}
// Process all calls in the wrapper function's body.
scan: scan:
for cur := range w.curBody.Preorder( for cur := range w.curBody.Preorder(
(*ast.AssignStmt)(nil), (*ast.AssignStmt)(nil),
@ -272,6 +332,12 @@ func findPrintLike(pass *analysis.Pass, res *Result) {
(*ast.CallExpr)(nil), (*ast.CallExpr)(nil),
) { ) {
switch n := cur.Node().(type) { switch n := cur.Node().(type) {
// Reject tricky cases where the parameters
// are potentially mutated by AssignStmt or UnaryExpr.
// (This logic checks for mutation only before the call.)
// TODO: Relax these checks; issue 26555.
case *ast.AssignStmt: case *ast.AssignStmt:
// If the wrapper updates format or args // If the wrapper updates format or args
// it is not a simple wrapper. // it is not a simple wrapper.
@ -294,28 +360,53 @@ func findPrintLike(pass *analysis.Pass, res *Result) {
case *ast.CallExpr: case *ast.CallExpr:
if len(n.Args) > 0 && match(info, n.Args[len(n.Args)-1], w.args) { if len(n.Args) > 0 && match(info, n.Args[len(n.Args)-1], w.args) {
if callee := typeutil.Callee(pass.TypesInfo, n); callee != nil { if callee := typeutil.Callee(pass.TypesInfo, n); callee != nil {
doCall(callee, n)
// Call from one wrapper candidate to another? }
// Record the edge so that if callee is found to be }
// a true wrapper, w will be too. }
if w2, ok := byObj[callee]; ok { }
w2.callers = append(w2.callers, printfCaller{w, n}) }
} }
// Is the candidate a true wrapper, because it calls // methodImplementations returns the mapping from interface methods
// a known print{,f}-like function from the allowlist // declared in this package to their corresponding implementing
// or an imported fact, or another wrapper found // methods (which may also be interface methods), according to the set
// to be a true wrapper? // of assignments to interface types that appear within this package.
// If so, convert all w's callers to kind. func methodImplementations(pass *analysis.Pass) map[*types.Func]map[*types.Func]bool {
kind := callKind(pass, callee, res) impls := make(map[*types.Func]map[*types.Func]bool)
if kind != KindNone {
checkForward(pass, w, n, kind, res) // To find interface/implementation relations,
} // we use the 'satisfy' pass, but proposal #70638
} // provides a better way.
} //
} // This pass over the syntax could be factored out as
// a separate analysis pass if it is needed by other
// analyzers.
var f satisfy.Finder
f.Find(pass.TypesInfo, pass.Files)
for assign := range f.Result {
// Have: LHS = RHS, where LHS is an interface type.
for imeth := range assign.LHS.Underlying().(*types.Interface).Methods() {
// Limit to interface methods of current package.
if imeth.Pkg() != pass.Pkg {
continue
}
if _, args := formatArgsParams(imeth.Signature()); args == nil {
continue // not print{,f}-like
}
// Add implementing method to the set.
impl, _, _ := types.LookupFieldOrMethod(assign.RHS, false, pass.Pkg, imeth.Name()) // can't fail
set, ok := impls[imeth]
if !ok {
set = make(map[*types.Func]bool)
impls[imeth] = set
}
set[impl.(*types.Func)] = true
} }
} }
return impls
} }
func match(info *types.Info, arg ast.Expr, param *types.Var) bool { func match(info *types.Info, arg ast.Expr, param *types.Var) bool {
@ -323,9 +414,16 @@ func match(info *types.Info, arg ast.Expr, param *types.Var) bool {
return ok && info.ObjectOf(id) == param return ok && info.ObjectOf(id) == param
} }
// checkForward checks that a forwarding wrapper is forwarding correctly. // checkForward checks whether a forwarding wrapper is forwarding correctly.
// It diagnoses writing fmt.Printf(format, args) instead of fmt.Printf(format, args...). // If so, it propagates changes in wrapper kind information backwards
// through the wrapper.callers graph of forwarding calls.
//
// If not, it reports a diagnostic that the user wrote
// fmt.Printf(format, args) instead of fmt.Printf(format, args...).
func checkForward(pass *analysis.Pass, w *wrapper, call *ast.CallExpr, kind Kind, res *Result) { func checkForward(pass *analysis.Pass, w *wrapper, call *ast.CallExpr, kind Kind, res *Result) {
// Check correct call forwarding.
// (Interface methods forward correctly by construction.)
if call != nil {
matched := kind == KindPrint || matched := kind == KindPrint ||
kind != KindNone && len(call.Args) >= 2 && match(pass.TypesInfo, call.Args[len(call.Args)-2], w.format) kind != KindNone && len(call.Args) >= 2 && match(pass.TypesInfo, call.Args[len(call.Args)-2], w.format)
if !matched { if !matched {
@ -354,6 +452,7 @@ func checkForward(pass *analysis.Pass, w *wrapper, call *ast.CallExpr, kind Kind
pass.ReportRangef(call, "missing ... in args forwarded to %s-like function", desc) pass.ReportRangef(call, "missing ... in args forwarded to %s-like function", desc)
return return
} }
}
// If the candidate's print{,f} status becomes known, // If the candidate's print{,f} status becomes known,
// propagate it back to all its so-far known callers. // propagate it back to all its so-far known callers.
@ -444,8 +543,6 @@ var isPrint = stringSet{
"(*testing.common).Logf": true, "(*testing.common).Logf": true,
"(*testing.common).Skip": true, "(*testing.common).Skip": true,
"(*testing.common).Skipf": true, "(*testing.common).Skipf": true,
// *testing.T and B are detected by induction, but testing.TB is
// an interface and the inference can't follow dynamic calls.
"(testing.TB).Error": true, "(testing.TB).Error": true,
"(testing.TB).Errorf": true, "(testing.TB).Errorf": true,
"(testing.TB).Fatal": true, "(testing.TB).Fatal": true,


@ -64,7 +64,7 @@ func NodeContains(n ast.Node, rng Range) bool {
 	return NodeRange(n).Contains(rng)
 }
-// NodeContainPos reports whether the Pos/End range of node n encloses
+// NodeContainsPos reports whether the Pos/End range of node n encloses
 // the given pos.
 //
 // Like [NodeRange], it treats the range of an [ast.File] as the


@ -45,3 +45,11 @@ func Any[T any](seq iter.Seq[T], pred func(T) bool) bool {
} }
return false return false
} }
// Len returns the number of elements in the sequence (by iterating).
func Len[T any](seq iter.Seq[T]) (n int) {
for range seq {
n++
}
return
}
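// Example use (hypothetical caller): Len simply consumes the sequence.
//
//	seq := slices.Values([]string{"a", "b", "c"})
//	n := moreiters.Len(seq) // n == 3
//
// This is the shape of the CountUses callback wired up in the gofix and
// inline changes above, where the sequence comes from a typeindex.Index.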


@ -40,7 +40,11 @@ type Caller struct {
 	Info    *types.Info
 	File    *ast.File
 	Call    *ast.CallExpr
-	Content []byte // source of file containing
+	Content []byte // source of file containing (TODO(adonovan): see comment at Result.Content)
+
+	// CountUses is an optional optimized computation of
+	// the number of times pkgname appears in Info.Uses.
+	CountUses func(pkgname *types.PkgName) int
 
 	path          []ast.Node    // path from call to root of file syntax tree
 	enclosingFunc *ast.FuncDecl // top-level function/method enclosing the call, if any
@ -57,6 +61,18 @@ type Options struct {
// Result holds the result of code transformation. // Result holds the result of code transformation.
type Result struct { type Result struct {
// TODO(adonovan): the only textual results that should be
// needed are (1) an edit in the vicinity of the call (either
// to the CallExpr or one of its ancestors), and optionally
// (2) an edit to the import declaration.
// Change the inliner API to return a list of edits,
// and not to accept a Caller.Content, as it is only
// temptation to use such algorithmically expensive
// operations as reformatting the entire file, which is
// a significant source of non-linear dynamic behavior;
// see https://go.dev/issue/75773.
// This will require a sequence of changes to the tests
// and the inliner algorithm itself.
Content []byte // formatted, transformed content of caller file Content []byte // formatted, transformed content of caller file
Literalized bool // chosen strategy replaced callee() with func(){...}() Literalized bool // chosen strategy replaced callee() with func(){...}()
BindingDecl bool // transformation added "var params = args" declaration BindingDecl bool // transformation added "var params = args" declaration
@ -432,27 +448,19 @@ func newImportState(logf func(string, ...any), caller *Caller, callee *gobCallee
importMap: make(map[string][]string), importMap: make(map[string][]string),
} }
// Build an index of used-once PkgNames. // Provide an inefficient default implementation of CountUses.
type pkgNameUse struct { // (Ideally clients amortize this for the entire package.)
count int countUses := caller.CountUses
id *ast.Ident // an arbitrary use if countUses == nil {
} uses := make(map[*types.PkgName]int)
pkgNameUses := make(map[*types.PkgName]pkgNameUse) for _, obj := range caller.Info.Uses {
for id, obj := range caller.Info.Uses {
if pkgname, ok := obj.(*types.PkgName); ok { if pkgname, ok := obj.(*types.PkgName); ok {
u := pkgNameUses[pkgname] uses[pkgname]++
u.id = id
u.count++
pkgNameUses[pkgname] = u
} }
} }
// soleUse returns the ident that refers to pkgname, if there is exactly one. countUses = func(pkgname *types.PkgName) int {
soleUse := func(pkgname *types.PkgName) *ast.Ident { return uses[pkgname]
u := pkgNameUses[pkgname]
if u.count == 1 {
return u.id
} }
return nil
} }
for _, imp := range caller.File.Imports { for _, imp := range caller.File.Imports {
@ -472,8 +480,10 @@ func newImportState(logf func(string, ...any), caller *Caller, callee *gobCallee
// If that is the case, proactively check if any of the callee FreeObjs // If that is the case, proactively check if any of the callee FreeObjs
// need this import. Doing so eagerly simplifies the resulting logic. // need this import. Doing so eagerly simplifies the resulting logic.
needed := true needed := true
sel, ok := ast.Unparen(caller.Call.Fun).(*ast.SelectorExpr) if sel, ok := ast.Unparen(caller.Call.Fun).(*ast.SelectorExpr); ok &&
if ok && soleUse(pkgName) == sel.X { is[*ast.Ident](sel.X) &&
caller.Info.Uses[sel.X.(*ast.Ident)] == pkgName &&
countUses(pkgName) == 1 {
needed = false // no longer needed by caller needed = false // no longer needed by caller
// Check to see if any of the inlined free objects need this package. // Check to see if any of the inlined free objects need this package.
for _, obj := range callee.FreeObjs { for _, obj := range callee.FreeObjs {


@ -0,0 +1,727 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package satisfy inspects the type-checked ASTs of Go packages and
// reports the set of discovered type constraints of the form (lhs, rhs
// Type) where lhs is a non-trivial interface, rhs satisfies this
// interface, and this fact is necessary for the package to be
// well-typed.
//
// THIS PACKAGE IS EXPERIMENTAL AND MAY CHANGE AT ANY TIME.
//
// It is provided only for the gopls tool. It requires well-typed inputs.
package satisfy // import "golang.org/x/tools/refactor/satisfy"
// NOTES:
//
// We don't care about numeric conversions, so we don't descend into
// types or constant expressions. This is unsound because
// constant expressions can contain arbitrary statements, e.g.
// const x = len([1]func(){func() {
// ...
// }})
//
// Assignability conversions are possible in the following places:
// - in assignments y = x, y := x, var y = x.
// - from call argument types to formal parameter types
// - in append and delete calls
// - from return operands to result parameter types
// - in composite literal T{k:v}, from k and v to T's field/element/key type
// - in map[key] from key to the map's key type
// - in comparisons x==y and switch x { case y: }.
// - in explicit conversions T(x)
// - in sends ch <- x, from x to the channel element type
// - in type assertions x.(T) and switch x.(type) { case T: }
//
// The results of this pass provide information equivalent to the
// ssa.MakeInterface and ssa.ChangeInterface instructions.
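// Illustration (not part of this package): each statement below is an
// assignability site of the kinds listed above and induces a Constraint
// of the form (interface, concrete):
//
//	var w io.Writer = new(bytes.Buffer) // assignment:     (io.Writer, *bytes.Buffer)
//	fmt.Fprintln(w, "hi")               // call argument:  identical types, not recorded
//	_ = w.(io.ReadWriter)               // type assertion: (io.Writer, io.ReadWriter)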
import (
"fmt"
"go/ast"
"go/token"
"go/types"
"golang.org/x/tools/go/types/typeutil"
"golang.org/x/tools/internal/typeparams"
)
// A Constraint records the fact that the RHS type does and must
// satisfy the LHS type, which is an interface.
// The names are suggestive of an assignment statement LHS = RHS.
//
// The constraint is implicitly universally quantified over any type
// parameters appearing within the two types.
type Constraint struct {
LHS, RHS types.Type
}
// A Finder inspects the type-checked ASTs of Go packages and
// accumulates the set of type constraints (x, y) such that x is
// assignable to y, y is an interface, and both x and y have methods.
//
// In other words, it returns the subset of the "implements" relation
// that is checked during compilation of a package. Refactoring tools
// will need to preserve at least this part of the relation to ensure
// continued compilation.
type Finder struct {
Result map[Constraint]bool
msetcache typeutil.MethodSetCache
// per-Find state
info *types.Info
sig *types.Signature
}
// Find inspects a single package, populating Result with its pairs of
// constrained types.
//
// The result is non-canonical and thus may contain duplicates (but this
// tends to preserve names of interface types better).
//
// The package must be free of type errors, and
// info.{Defs,Uses,Selections,Types} must have been populated by the
// type-checker.
func (f *Finder) Find(info *types.Info, files []*ast.File) {
if info.Defs == nil || info.Uses == nil || info.Selections == nil || info.Types == nil {
panic("Finder.Find: one of info.{Defs,Uses,Selections.Types} is not populated")
}
if f.Result == nil {
f.Result = make(map[Constraint]bool)
}
f.info = info
for _, file := range files {
for _, d := range file.Decls {
switch d := d.(type) {
case *ast.GenDecl:
if d.Tok == token.VAR { // ignore consts
for _, spec := range d.Specs {
f.valueSpec(spec.(*ast.ValueSpec))
}
}
case *ast.FuncDecl:
if d.Body != nil {
f.sig = f.info.Defs[d.Name].Type().(*types.Signature)
f.stmt(d.Body)
f.sig = nil
}
}
}
}
f.info = nil
}
var (
tInvalid = types.Typ[types.Invalid]
tUntypedBool = types.Typ[types.UntypedBool]
tUntypedNil = types.Typ[types.UntypedNil]
)
// exprN visits an expression in a multi-value context.
func (f *Finder) exprN(e ast.Expr) types.Type {
typ := f.info.Types[e].Type.(*types.Tuple)
switch e := e.(type) {
case *ast.ParenExpr:
return f.exprN(e.X)
case *ast.CallExpr:
// x, err := f(args)
sig := typeparams.CoreType(f.expr(e.Fun)).(*types.Signature)
f.call(sig, e.Args)
case *ast.IndexExpr:
// y, ok := x[i]
x := f.expr(e.X)
f.assign(f.expr(e.Index), typeparams.CoreType(x).(*types.Map).Key())
case *ast.TypeAssertExpr:
// y, ok := x.(T)
f.typeAssert(f.expr(e.X), typ.At(0).Type())
case *ast.UnaryExpr: // must be receive <-
// y, ok := <-x
f.expr(e.X)
default:
panic(e)
}
return typ
}
func (f *Finder) call(sig *types.Signature, args []ast.Expr) {
if len(args) == 0 {
return
}
// Ellipsis call? e.g. f(x, y, z...)
if _, ok := args[len(args)-1].(*ast.Ellipsis); ok {
for i, arg := range args {
// The final arg is a slice, and so is the final param.
f.assign(sig.Params().At(i).Type(), f.expr(arg))
}
return
}
var argtypes []types.Type
// Gather the effective actual parameter types.
if tuple, ok := f.info.Types[args[0]].Type.(*types.Tuple); ok {
// f(g()) call where g has multiple results?
f.expr(args[0])
// unpack the tuple
for v := range tuple.Variables() {
argtypes = append(argtypes, v.Type())
}
} else {
for _, arg := range args {
argtypes = append(argtypes, f.expr(arg))
}
}
// Assign the actuals to the formals.
if !sig.Variadic() {
for i, argtype := range argtypes {
f.assign(sig.Params().At(i).Type(), argtype)
}
} else {
// The first n-1 parameters are assigned normally.
nnormals := sig.Params().Len() - 1
for i, argtype := range argtypes[:nnormals] {
f.assign(sig.Params().At(i).Type(), argtype)
}
// Remaining args are assigned to elements of varargs slice.
tElem := sig.Params().At(nnormals).Type().(*types.Slice).Elem()
for i := nnormals; i < len(argtypes); i++ {
f.assign(tElem, argtypes[i])
}
}
}
// builtin visits the arguments of a call to a builtin function with signature sig.
func (f *Finder) builtin(obj *types.Builtin, sig *types.Signature, args []ast.Expr) {
switch obj.Name() {
case "make", "new":
for i, arg := range args {
if i == 0 && f.info.Types[arg].IsType() {
continue // skip the type operand
}
f.expr(arg)
}
case "append":
s := f.expr(args[0])
if _, ok := args[len(args)-1].(*ast.Ellipsis); ok && len(args) == 2 {
// append(x, y...) including append([]byte, "foo"...)
f.expr(args[1])
} else {
// append(x, y, z)
tElem := typeparams.CoreType(s).(*types.Slice).Elem()
for _, arg := range args[1:] {
f.assign(tElem, f.expr(arg))
}
}
case "delete":
m := f.expr(args[0])
k := f.expr(args[1])
f.assign(typeparams.CoreType(m).(*types.Map).Key(), k)
default:
// ordinary call
f.call(sig, args)
}
}
func (f *Finder) extract(tuple types.Type, i int) types.Type {
if tuple, ok := tuple.(*types.Tuple); ok && i < tuple.Len() {
return tuple.At(i).Type()
}
return tInvalid
}
func (f *Finder) valueSpec(spec *ast.ValueSpec) {
var T types.Type
if spec.Type != nil {
T = f.info.Types[spec.Type].Type
}
switch len(spec.Values) {
case len(spec.Names): // e.g. var x, y = f(), g()
for _, value := range spec.Values {
v := f.expr(value)
if T != nil {
f.assign(T, v)
}
}
case 1: // e.g. var x, y = f()
tuple := f.exprN(spec.Values[0])
for i := range spec.Names {
if T != nil {
f.assign(T, f.extract(tuple, i))
}
}
}
}
// assign records pairs of distinct types that are related by
// assignability, where the left-hand side is an interface and both
// sides have methods.
//
// It should be called for all assignability checks, type assertions,
// explicit conversions and comparisons between two types, unless the
// types are uninteresting (e.g. lhs is a concrete type, or the empty
// interface; rhs has no methods).
func (f *Finder) assign(lhs, rhs types.Type) {
if types.Identical(lhs, rhs) {
return
}
if !types.IsInterface(lhs) {
return
}
if f.msetcache.MethodSet(lhs).Len() == 0 {
return
}
if f.msetcache.MethodSet(rhs).Len() == 0 {
return
}
// record the pair
f.Result[Constraint{lhs, rhs}] = true
}
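// Example use (hypothetical driver, assuming a type-checked
// *packages.Package pkg with Syntax and TypesInfo populated):
//
//	var f satisfy.Finder
//	f.Find(pkg.TypesInfo, pkg.Syntax)
//	for c := range f.Result {
//		fmt.Printf("%v must keep satisfying %v\n", c.RHS, c.LHS)
//	}
//
// This mirrors how methodImplementations in the printf analyzer (earlier
// in this CL) discovers interface/implementation pairs.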
// typeAssert must be called for each type assertion x.(T) where x has
// interface type I.
func (f *Finder) typeAssert(I, T types.Type) {
// Type assertions are slightly subtle, because they are allowed
// to be "impossible", e.g.
//
// var x interface{f()}
// _ = x.(interface{f()int}) // legal
//
// (In hindsight, the language spec should probably not have
// allowed this, but it's too late to fix now.)
//
// This means that a type assert from I to T isn't exactly a
// constraint that T is assignable to I, but for a refactoring
// tool it is a conditional constraint that, if T is assignable
// to I before a refactoring, it should remain so after.
if types.AssignableTo(T, I) {
f.assign(I, T)
}
}
// compare must be called for each comparison x==y.
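// For example, comparing an error against a *os.PathError records the pair
// {error, *os.PathError}, since the latter is assignable to the former.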
func (f *Finder) compare(x, y types.Type) {
if types.AssignableTo(x, y) {
f.assign(y, x)
} else if types.AssignableTo(y, x) {
f.assign(x, y)
}
}
// expr visits a true expression (not a type or defining ident)
// and returns its type.
func (f *Finder) expr(e ast.Expr) types.Type {
tv := f.info.Types[e]
if tv.Value != nil {
return tv.Type // prune the descent for constants
}
// tv.Type may be nil for an ast.Ident.
switch e := e.(type) {
case *ast.BadExpr, *ast.BasicLit:
// no-op
case *ast.Ident:
// (referring idents only)
if obj, ok := f.info.Uses[e]; ok {
return obj.Type()
}
if e.Name == "_" { // e.g. "for _ = range x"
return tInvalid
}
panic("undefined ident: " + e.Name)
case *ast.Ellipsis:
if e.Elt != nil {
f.expr(e.Elt)
}
case *ast.FuncLit:
saved := f.sig
f.sig = tv.Type.(*types.Signature)
f.stmt(e.Body)
f.sig = saved
case *ast.CompositeLit:
switch T := typeparams.CoreType(typeparams.Deref(tv.Type)).(type) {
case *types.Struct:
for i, elem := range e.Elts {
if kv, ok := elem.(*ast.KeyValueExpr); ok {
f.assign(f.info.Uses[kv.Key.(*ast.Ident)].Type(), f.expr(kv.Value))
} else {
f.assign(T.Field(i).Type(), f.expr(elem))
}
}
case *types.Map:
for _, elem := range e.Elts {
elem := elem.(*ast.KeyValueExpr)
f.assign(T.Key(), f.expr(elem.Key))
f.assign(T.Elem(), f.expr(elem.Value))
}
case *types.Array, *types.Slice:
tElem := T.(interface {
Elem() types.Type
}).Elem()
for _, elem := range e.Elts {
if kv, ok := elem.(*ast.KeyValueExpr); ok {
// ignore the key
f.assign(tElem, f.expr(kv.Value))
} else {
f.assign(tElem, f.expr(elem))
}
}
default:
panic(fmt.Sprintf("unexpected composite literal type %T: %v", tv.Type, tv.Type.String()))
}
case *ast.ParenExpr:
f.expr(e.X)
case *ast.SelectorExpr:
if _, ok := f.info.Selections[e]; ok {
f.expr(e.X) // selection
} else {
return f.info.Uses[e.Sel].Type() // qualified identifier
}
case *ast.IndexExpr:
if instance(f.info, e.X) {
// f[T] or C[T] -- generic instantiation
} else {
// x[i] or m[k] -- index or lookup operation
x := f.expr(e.X)
i := f.expr(e.Index)
if ux, ok := typeparams.CoreType(x).(*types.Map); ok {
f.assign(ux.Key(), i)
}
}
case *ast.IndexListExpr:
// f[X, Y] -- generic instantiation
case *ast.SliceExpr:
f.expr(e.X)
if e.Low != nil {
f.expr(e.Low)
}
if e.High != nil {
f.expr(e.High)
}
if e.Max != nil {
f.expr(e.Max)
}
case *ast.TypeAssertExpr:
x := f.expr(e.X)
f.typeAssert(x, f.info.Types[e.Type].Type)
case *ast.CallExpr:
if tvFun := f.info.Types[e.Fun]; tvFun.IsType() {
// conversion
arg0 := f.expr(e.Args[0])
f.assign(tvFun.Type, arg0)
} else {
// function call
// unsafe call. Treat calls to functions in unsafe like ordinary calls,
// except that their signature cannot be determined by their func obj.
// Without this special handling, f.expr(e.Fun) would fail below.
if s, ok := ast.Unparen(e.Fun).(*ast.SelectorExpr); ok {
if obj, ok := f.info.Uses[s.Sel].(*types.Builtin); ok && obj.Pkg().Path() == "unsafe" {
sig := f.info.Types[e.Fun].Type.(*types.Signature)
f.call(sig, e.Args)
return tv.Type
}
}
// builtin call
if id, ok := ast.Unparen(e.Fun).(*ast.Ident); ok {
if obj, ok := f.info.Uses[id].(*types.Builtin); ok {
sig := f.info.Types[id].Type.(*types.Signature)
f.builtin(obj, sig, e.Args)
return tv.Type
}
}
// ordinary call
f.call(typeparams.CoreType(f.expr(e.Fun)).(*types.Signature), e.Args)
}
case *ast.StarExpr:
f.expr(e.X)
case *ast.UnaryExpr:
f.expr(e.X)
case *ast.BinaryExpr:
x := f.expr(e.X)
y := f.expr(e.Y)
if e.Op == token.EQL || e.Op == token.NEQ {
f.compare(x, y)
}
case *ast.KeyValueExpr:
f.expr(e.Key)
f.expr(e.Value)
case *ast.ArrayType,
*ast.StructType,
*ast.FuncType,
*ast.InterfaceType,
*ast.MapType,
*ast.ChanType:
panic(e)
}
if tv.Type == nil {
panic(fmt.Sprintf("no type for %T", e))
}
return tv.Type
}
func (f *Finder) stmt(s ast.Stmt) {
switch s := s.(type) {
case *ast.BadStmt,
*ast.EmptyStmt,
*ast.BranchStmt:
// no-op
case *ast.DeclStmt:
d := s.Decl.(*ast.GenDecl)
if d.Tok == token.VAR { // ignore consts
for _, spec := range d.Specs {
f.valueSpec(spec.(*ast.ValueSpec))
}
}
case *ast.LabeledStmt:
f.stmt(s.Stmt)
case *ast.ExprStmt:
f.expr(s.X)
case *ast.SendStmt:
ch := f.expr(s.Chan)
val := f.expr(s.Value)
f.assign(typeparams.CoreType(ch).(*types.Chan).Elem(), val)
case *ast.IncDecStmt:
f.expr(s.X)
case *ast.AssignStmt:
switch s.Tok {
case token.ASSIGN, token.DEFINE:
// y := x or y = x
var rhsTuple types.Type
if len(s.Lhs) != len(s.Rhs) {
rhsTuple = f.exprN(s.Rhs[0])
}
for i := range s.Lhs {
var lhs, rhs types.Type
if rhsTuple == nil {
rhs = f.expr(s.Rhs[i]) // 1:1 assignment
} else {
rhs = f.extract(rhsTuple, i) // n:1 assignment
}
if id, ok := s.Lhs[i].(*ast.Ident); ok {
if id.Name != "_" {
if obj, ok := f.info.Defs[id]; ok {
lhs = obj.Type() // definition
}
}
}
if lhs == nil {
lhs = f.expr(s.Lhs[i]) // assignment
}
f.assign(lhs, rhs)
}
default:
// y op= x
f.expr(s.Lhs[0])
f.expr(s.Rhs[0])
}
case *ast.GoStmt:
f.expr(s.Call)
case *ast.DeferStmt:
f.expr(s.Call)
case *ast.ReturnStmt:
formals := f.sig.Results()
switch len(s.Results) {
case formals.Len(): // 1:1
for i, result := range s.Results {
f.assign(formals.At(i).Type(), f.expr(result))
}
case 1: // n:1
tuple := f.exprN(s.Results[0])
for i := 0; i < formals.Len(); i++ {
f.assign(formals.At(i).Type(), f.extract(tuple, i))
}
}
case *ast.SelectStmt:
f.stmt(s.Body)
case *ast.BlockStmt:
for _, s := range s.List {
f.stmt(s)
}
case *ast.IfStmt:
if s.Init != nil {
f.stmt(s.Init)
}
f.expr(s.Cond)
f.stmt(s.Body)
if s.Else != nil {
f.stmt(s.Else)
}
case *ast.SwitchStmt:
if s.Init != nil {
f.stmt(s.Init)
}
var tag types.Type = tUntypedBool
if s.Tag != nil {
tag = f.expr(s.Tag)
}
for _, cc := range s.Body.List {
cc := cc.(*ast.CaseClause)
for _, cond := range cc.List {
f.compare(tag, f.info.Types[cond].Type)
}
for _, s := range cc.Body {
f.stmt(s)
}
}
case *ast.TypeSwitchStmt:
if s.Init != nil {
f.stmt(s.Init)
}
var I types.Type
switch ass := s.Assign.(type) {
case *ast.ExprStmt: // x.(type)
I = f.expr(ast.Unparen(ass.X).(*ast.TypeAssertExpr).X)
case *ast.AssignStmt: // y := x.(type)
I = f.expr(ast.Unparen(ass.Rhs[0]).(*ast.TypeAssertExpr).X)
}
for _, cc := range s.Body.List {
cc := cc.(*ast.CaseClause)
for _, cond := range cc.List {
tCase := f.info.Types[cond].Type
if tCase != tUntypedNil {
f.typeAssert(I, tCase)
}
}
for _, s := range cc.Body {
f.stmt(s)
}
}
case *ast.CommClause:
if s.Comm != nil {
f.stmt(s.Comm)
}
for _, s := range s.Body {
f.stmt(s)
}
case *ast.ForStmt:
if s.Init != nil {
f.stmt(s.Init)
}
if s.Cond != nil {
f.expr(s.Cond)
}
if s.Post != nil {
f.stmt(s.Post)
}
f.stmt(s.Body)
case *ast.RangeStmt:
x := f.expr(s.X)
// No conversions are involved when Tok==DEFINE.
if s.Tok == token.ASSIGN {
if s.Key != nil {
k := f.expr(s.Key)
var xelem types.Type
// Keys of array, *array, slice, string aren't interesting
// since the RHS key type is just an int.
switch ux := typeparams.CoreType(x).(type) {
case *types.Chan:
xelem = ux.Elem()
case *types.Map:
xelem = ux.Key()
}
if xelem != nil {
f.assign(k, xelem)
}
}
if s.Value != nil {
val := f.expr(s.Value)
var xelem types.Type
// Values of type strings aren't interesting because
// the RHS value type is just a rune.
switch ux := typeparams.CoreType(x).(type) {
case *types.Array:
xelem = ux.Elem()
case *types.Map:
xelem = ux.Elem()
case *types.Pointer: // *array
xelem = typeparams.CoreType(typeparams.Deref(ux)).(*types.Array).Elem()
case *types.Slice:
xelem = ux.Elem()
}
if xelem != nil {
f.assign(val, xelem)
}
}
}
f.stmt(s.Body)
default:
panic(s)
}
}
// -- Plundered from golang.org/x/tools/go/ssa -----------------
func instance(info *types.Info, expr ast.Expr) bool {
var id *ast.Ident
switch x := expr.(type) {
case *ast.Ident:
id = x
case *ast.SelectorExpr:
id = x.Sel
default:
return false
}
_, ok := info.Instances[id]
return ok
}

View file

@@ -73,7 +73,7 @@ golang.org/x/text/internal/tag
 golang.org/x/text/language
 golang.org/x/text/transform
 golang.org/x/text/unicode/norm
-# golang.org/x/tools v0.39.1-0.20251114194111-59ff18ce4883
+# golang.org/x/tools v0.39.1-0.20251120214200-68724afed209
 ## explicit; go 1.24.0
 golang.org/x/tools/cmd/bisect
 golang.org/x/tools/cover
@@ -149,6 +149,7 @@ golang.org/x/tools/internal/typeparams
 golang.org/x/tools/internal/typesinternal
 golang.org/x/tools/internal/typesinternal/typeindex
 golang.org/x/tools/internal/versions
+golang.org/x/tools/refactor/satisfy
 # rsc.io/markdown v0.0.0-20240306144322-0bf8f97ee8ef
 ## explicit; go 1.20
 rsc.io/markdown

View file

@@ -0,0 +1,82 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mldsa
import (
"bytes"
"crypto/internal/fips140"
_ "crypto/internal/fips140/check"
"crypto/internal/fips140/sha256"
"errors"
"sync"
)
func fipsPCT(priv *PrivateKey) {
fips140.PCT("ML-DSA sign and verify PCT", func() error {
μ := make([]byte, 64)
sig, err := SignExternalMuDeterministic(priv, μ)
if err != nil {
return err
}
return VerifyExternalMu(priv.PublicKey(), μ, sig)
})
}
var fipsSelfTest = sync.OnceFunc(func() {
fips140.CAST("ML-DSA-44", fips140CAST)
})
// fips140CAST covers all rejection sampling paths, as recommended by IG 10.3.A,
// and as tested by TestCASTRejectionPaths. It tests only one parameter set, as
// allowed by Note 26. It tests the modified versions of Algorithms 7 and 8 with
// a fixed mu/μ, as allowed by IG 10.3.A, Resolution 15. It compares sk and not
// pk, because H(pk) is part of sk, as allowed by the same Resolution. It
// compares the results with hashes instead of values, to avoid embedding several
// kilobytes of test vectors in every binary, as allowed by General Note 7.
func fips140CAST() error {
// From https://pages.nist.gov/ACVP/draft-celi-acvp-ml-dsa.html#table-1.
var seed = &[32]byte{
0x5c, 0x62, 0x4f, 0xcc, 0x18, 0x62, 0x45, 0x24,
0x52, 0xd0, 0xc6, 0x65, 0x84, 0x0d, 0x82, 0x37,
0xf4, 0x31, 0x08, 0xe5, 0x49, 0x9e, 0xdc, 0xdc,
0x10, 0x8f, 0xbc, 0x49, 0xd5, 0x96, 0xe4, 0xb7,
}
var μ = &[64]byte{
0x2a, 0xd1, 0xc7, 0x2b, 0xb0, 0xfc, 0xbe, 0x28,
0x09, 0x9c, 0xe8, 0xbd, 0x2e, 0xd8, 0x36, 0xdf,
0xeb, 0xe5, 0x20, 0xaa, 0xd3, 0x8f, 0xba, 0xc6,
0x6e, 0xf7, 0x85, 0xa3, 0xcf, 0xb1, 0x0f, 0xb4,
0x19, 0x32, 0x7f, 0xa5, 0x78, 0x18, 0xee, 0x4e,
0x37, 0x18, 0xda, 0x4b, 0xe4, 0x8d, 0x24, 0xb5,
0x9a, 0x20, 0x8f, 0x88, 0x07, 0x27, 0x1f, 0xdb,
0x7e, 0xda, 0x6e, 0x60, 0x14, 0x1b, 0xd2, 0x63,
}
var skHash = []byte{
0x29, 0x37, 0x49, 0x51, 0xcb, 0x2b, 0xc3, 0xcd,
0xa7, 0x31, 0x5c, 0xe7, 0xf0, 0xab, 0x99, 0xc7,
0xd2, 0xd6, 0x52, 0x92, 0xe6, 0xc5, 0x15, 0x6e,
0x8a, 0xa6, 0x2a, 0xc1, 0x4b, 0x14, 0x12, 0xaf,
}
var sigHash = []byte{
0xdc, 0xc7, 0x1a, 0x42, 0x1b, 0xc6, 0xff, 0xaf,
0xb7, 0xdf, 0x0c, 0x7f, 0x6d, 0x01, 0x8a, 0x19,
0xad, 0xa1, 0x54, 0xd1, 0xe2, 0xee, 0x36, 0x0e,
0xd5, 0x33, 0xce, 0xcd, 0x5d, 0xc9, 0x80, 0xad,
}
priv := newPrivateKey(seed, params44)
H := sha256.New()
H.Write(TestingOnlyPrivateKeySemiExpandedBytes(priv))
if !bytes.Equal(H.Sum(nil), skHash) {
return errors.New("unexpected private key hash")
}
var random [32]byte
sig := signInternal(priv, μ, &random)
H.Reset()
H.Write(sig)
if !bytes.Equal(H.Sum(nil), sigHash) {
return errors.New("unexpected signature hash")
}
return verifyInternal(priv.PublicKey(), μ, sig)
}

View file

@@ -0,0 +1,781 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mldsa
import (
"crypto/internal/constanttime"
"crypto/internal/fips140/sha3"
"errors"
"math/bits"
)
const (
q = 8380417 // 2²³ - 2¹³ + 1
R = 4294967296 // 2³²
RR = 2365951 // R² mod q, aka R in the Montgomery domain
qNegInv = 4236238847 // -q⁻¹ mod R (q * qNegInv ≡ -1 mod R)
one = 4193792 // R mod q, aka 1 in the Montgomery domain
minusOne = 4186625 // (q - 1) * R mod q, aka -1 in the Montgomery domain
)
// fieldElement is an element n of ℤ_q in the Montgomery domain, represented as
// an integer x in [0, q) such that x ≡ n * R (mod q) where R = 2³².
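// For example, the integer 1 is represented as R mod q = 4193792 (the one
// constant above), and q - 1 as 4186625 (minusOne).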
type fieldElement uint32
var errUnreducedFieldElement = errors.New("mldsa: unreduced field element")
// fieldToMontgomery checks that a value a is < q, and converts it to
// Montgomery form.
func fieldToMontgomery(a uint32) (fieldElement, error) {
if a >= q {
return 0, errUnreducedFieldElement
}
// a * R² * R⁻¹ ≡ a * R (mod q)
return fieldMontgomeryMul(fieldElement(a), RR), nil
}
// fieldSubToMontgomery converts a difference a - b to Montgomery form.
// a and b must be < q. (This bound can probably be relaxed.)
func fieldSubToMontgomery(a, b uint32) fieldElement {
x := a - b + q
return fieldMontgomeryMul(fieldElement(x), RR)
}
// fieldFromMontgomery converts a value a in Montgomery form back to
// standard representation.
func fieldFromMontgomery(a fieldElement) uint32 {
// (a * R) * 1 * R⁻¹ ≡ a (mod q)
return uint32(fieldMontgomeryReduce(uint64(a)))
}
// fieldCenteredMod returns r mod± q, the value r reduced to the range
// [-(q-1)/2, (q-1)/2].
func fieldCenteredMod(r fieldElement) int32 {
x := int32(fieldFromMontgomery(r))
// x <= q / 2 ? x : x - q
return constantTimeSelectLessOrEqual(x, q/2, x, x-q)
}
// fieldInfinityNorm returns the infinity norm ||r||∞ of r, or the absolute
// value of r centered around 0.
func fieldInfinityNorm(r fieldElement) uint32 {
x := int32(fieldFromMontgomery(r))
// x <= q / 2 ? x : |x - q|
// |x - q| = -(x - q) = q - x because x < q => x - q < 0
return uint32(constantTimeSelectLessOrEqual(x, q/2, x, q-x))
}
// fieldReduceOnce reduces a value a < 2q.
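// It does so in constant time: bits.Sub64 yields a borrow of 1 exactly when
// a < q, in which case adding b*q back cancels the subtraction.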
func fieldReduceOnce(a uint32) fieldElement {
x, b := bits.Sub64(uint64(a), uint64(q), 0)
return fieldElement(x + b*q)
}
// fieldAdd returns a + b mod q.
func fieldAdd(a, b fieldElement) fieldElement {
x := uint32(a + b)
return fieldReduceOnce(x)
}
// fieldSub returns a - b mod q.
func fieldSub(a, b fieldElement) fieldElement {
x := uint32(a - b + q)
return fieldReduceOnce(x)
}
// fieldMontgomeryMul returns a * b * R⁻¹ mod q.
func fieldMontgomeryMul(a, b fieldElement) fieldElement {
x := uint64(a) * uint64(b)
return fieldMontgomeryReduce(x)
}
// fieldMontgomeryReduce returns x * R⁻¹ mod q for x < q * R.
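//
// t is chosen so that x + t*q ≡ 0 (mod 2³²), making the division by 2³² below
// exact. Since x + t*q ≡ x (mod q), the quotient u is congruent to x * R⁻¹
// mod q, and u < 2q whenever x < q * R, so one conditional subtraction is enough.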
func fieldMontgomeryReduce(x uint64) fieldElement {
t := uint32(x) * qNegInv
u := (x + uint64(t)*q) >> 32
return fieldReduceOnce(uint32(u))
}
// fieldMontgomeryMulSub returns a * (b - c). This operation is fused to save a
// fieldReduceOnce after the subtraction.
func fieldMontgomeryMulSub(a, b, c fieldElement) fieldElement {
x := uint64(a) * uint64(b-c+q)
return fieldMontgomeryReduce(x)
}
// fieldMontgomeryAddMul returns a * b + c * d. This operation is fused to save
// a fieldReduceOnce and a fieldReduce.
func fieldMontgomeryAddMul(a, b, c, d fieldElement) fieldElement {
x := uint64(a) * uint64(b)
x += uint64(c) * uint64(d)
return fieldMontgomeryReduce(x)
}
const n = 256
// ringElement is a polynomial, an element of R_q.
type ringElement [n]fieldElement
// polyAdd adds two ringElements or nttElements.
func polyAdd[T ~[n]fieldElement](a, b T) (s T) {
for i := range s {
s[i] = fieldAdd(a[i], b[i])
}
return s
}
// polySub subtracts two ringElements or nttElements.
func polySub[T ~[n]fieldElement](a, b T) (s T) {
for i := range s {
s[i] = fieldSub(a[i], b[i])
}
return s
}
// nttElement is an NTT representation, an element of T_q.
type nttElement [n]fieldElement
// zetas are the values ζ^BitRev₈(k) mod q for each index k, converted to the
// Montgomery domain.
var zetas = [256]fieldElement{4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468, 1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103, 2725464, 1024112, 7300517, 3585928, 7830929, 7260833, 2619752, 6271868, 6262231, 4520680, 6980856, 5102745, 1757237, 8360995, 4010497, 280005, 2706023, 95776, 3077325, 3530437, 6718724, 4788269, 5842901, 3915439, 4519302, 5336701, 3574422, 5512770, 3539968, 8079950, 2348700, 7841118, 6681150, 6736599, 3505694, 4558682, 3507263, 6239768, 6779997, 3699596, 811944, 531354, 954230, 3881043, 3900724, 5823537, 2071892, 5582638, 4450022, 6851714, 4702672, 5339162, 6927966, 3475950, 2176455, 6795196, 7122806, 1939314, 4296819, 7380215, 5190273, 5223087, 4747489, 126922, 3412210, 7396998, 2147896, 2715295, 5412772, 4686924, 7969390, 5903370, 7709315, 7151892, 8357436, 7072248, 7998430, 1349076, 1852771, 6949987, 5037034, 264944, 508951, 3097992, 44288, 7280319, 904516, 3958618, 4656075, 8371839, 1653064, 5130689, 2389356, 8169440, 759969, 7063561, 189548, 4827145, 3159746, 6529015, 5971092, 8202977, 1315589, 1341330, 1285669, 6795489, 7567685, 6940675, 5361315, 4499357, 4751448, 3839961, 2091667, 3407706, 2316500, 3817976, 5037939, 2244091, 5933984, 4817955, 266997, 2434439, 7144689, 3513181, 4860065, 4621053, 7183191, 5187039, 900702, 1859098, 909542, 819034, 495491, 6767243, 8337157, 7857917, 7725090, 5257975, 2031748, 3207046, 4823422, 7855319, 7611795, 4784579, 342297, 286988, 5942594, 4108315, 3437287, 5038140, 1735879, 203044, 2842341, 2691481, 5790267, 1265009, 4055324, 1247620, 2486353, 1595974, 4613401, 1250494, 2635921, 4832145, 5386378, 1869119, 1903435, 7329447, 7047359, 1237275, 5062207, 6950192, 7929317, 1312455, 3306115, 6417775, 7100756, 1917081, 5834105, 7005614, 1500165, 777191, 2235880, 3406031, 7838005, 5548557, 6709241, 6533464, 5796124, 4656147, 594136, 4603424, 6366809, 2432395, 2454455, 8215696, 1957272, 3369112, 185531, 7173032, 5196991, 162844, 1616392, 3014001, 810149, 1652634, 4686184, 6581310, 5341501, 3523897, 3866901, 269760, 2213111, 7404533, 1717735, 472078, 7953734, 1723600, 6577327, 1910376, 6712985, 7276084, 8119771, 4546524, 5441381, 6144432, 7959518, 6094090, 183443, 7403526, 1612842, 4834730, 7826001, 3919660, 8332111, 7018208, 3937738, 1400424, 7534263, 1976782}
// ntt maps a ringElement to its nttElement representation.
//
// It implements NTT, according to FIPS 204, Algorithm 41.
func ntt(f ringElement) nttElement {
var m uint8
for len := 128; len >= 8; len /= 2 {
for start := 0; start < 256; start += 2 * len {
m++
zeta := zetas[m]
// Bounds check elimination hint.
f, flen := f[start:start+len], f[start+len:start+len+len]
for j := 0; j < len; j += 2 {
t := fieldMontgomeryMul(zeta, flen[j])
flen[j] = fieldSub(f[j], t)
f[j] = fieldAdd(f[j], t)
// Unroll by 2 for performance.
t = fieldMontgomeryMul(zeta, flen[j+1])
flen[j+1] = fieldSub(f[j+1], t)
f[j+1] = fieldAdd(f[j+1], t)
}
}
}
// Unroll len = 4, 2, and 1.
for start := 0; start < 256; start += 8 {
m++
zeta := zetas[m]
t := fieldMontgomeryMul(zeta, f[start+4])
f[start+4] = fieldSub(f[start], t)
f[start] = fieldAdd(f[start], t)
t = fieldMontgomeryMul(zeta, f[start+5])
f[start+5] = fieldSub(f[start+1], t)
f[start+1] = fieldAdd(f[start+1], t)
t = fieldMontgomeryMul(zeta, f[start+6])
f[start+6] = fieldSub(f[start+2], t)
f[start+2] = fieldAdd(f[start+2], t)
t = fieldMontgomeryMul(zeta, f[start+7])
f[start+7] = fieldSub(f[start+3], t)
f[start+3] = fieldAdd(f[start+3], t)
}
for start := 0; start < 256; start += 4 {
m++
zeta := zetas[m]
t := fieldMontgomeryMul(zeta, f[start+2])
f[start+2] = fieldSub(f[start], t)
f[start] = fieldAdd(f[start], t)
t = fieldMontgomeryMul(zeta, f[start+3])
f[start+3] = fieldSub(f[start+1], t)
f[start+1] = fieldAdd(f[start+1], t)
}
for start := 0; start < 256; start += 2 {
m++
zeta := zetas[m]
t := fieldMontgomeryMul(zeta, f[start+1])
f[start+1] = fieldSub(f[start], t)
f[start] = fieldAdd(f[start], t)
}
return nttElement(f)
}
// inverseNTT maps a nttElement back to the ringElement it represents.
//
// It implements NTT⁻¹, according to FIPS 204, Algorithm 42.
func inverseNTT(f nttElement) ringElement {
var m uint8 = 255
// Unroll len = 1, 2, and 4.
for start := 0; start < 256; start += 2 {
zeta := zetas[m]
m--
t := f[start]
f[start] = fieldAdd(t, f[start+1])
f[start+1] = fieldMontgomeryMulSub(zeta, f[start+1], t)
}
for start := 0; start < 256; start += 4 {
zeta := zetas[m]
m--
t := f[start]
f[start] = fieldAdd(t, f[start+2])
f[start+2] = fieldMontgomeryMulSub(zeta, f[start+2], t)
t = f[start+1]
f[start+1] = fieldAdd(t, f[start+3])
f[start+3] = fieldMontgomeryMulSub(zeta, f[start+3], t)
}
for start := 0; start < 256; start += 8 {
zeta := zetas[m]
m--
t := f[start]
f[start] = fieldAdd(t, f[start+4])
f[start+4] = fieldMontgomeryMulSub(zeta, f[start+4], t)
t = f[start+1]
f[start+1] = fieldAdd(t, f[start+5])
f[start+5] = fieldMontgomeryMulSub(zeta, f[start+5], t)
t = f[start+2]
f[start+2] = fieldAdd(t, f[start+6])
f[start+6] = fieldMontgomeryMulSub(zeta, f[start+6], t)
t = f[start+3]
f[start+3] = fieldAdd(t, f[start+7])
f[start+7] = fieldMontgomeryMulSub(zeta, f[start+7], t)
}
for len := 8; len < 256; len *= 2 {
for start := 0; start < 256; start += 2 * len {
zeta := zetas[m]
m--
// Bounds check elimination hint.
f, flen := f[start:start+len], f[start+len:start+len+len]
for j := 0; j < len; j += 2 {
t := f[j]
f[j] = fieldAdd(t, flen[j])
// -z * (t - flen[j]) = z * (flen[j] - t)
flen[j] = fieldMontgomeryMulSub(zeta, flen[j], t)
// Unroll by 2 for performance.
t = f[j+1]
f[j+1] = fieldAdd(t, flen[j+1])
flen[j+1] = fieldMontgomeryMulSub(zeta, flen[j+1], t)
}
}
}
for i := range f {
f[i] = fieldMontgomeryMul(f[i], 16382) // 16382 = 256⁻¹ * R mod q
}
return ringElement(f)
}
// nttMul multiplies two nttElements.
func nttMul(a, b nttElement) (p nttElement) {
for i := range p {
p[i] = fieldMontgomeryMul(a[i], b[i])
}
return p
}
// sampleNTT samples an nttElement uniformly at random from the seed rho and the
// indices s and r. It implements Step 3 of ExpandA, RejNTTPoly, and
// CoeffFromThreeBytes from FIPS 204, passing in ρ, s, and r instead of ρ'.
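// Each 3-byte candidate is accepted only if its low 23 bits are below q, which
// happens with probability q/2²³ ≈ 0.999, so rejections are rare.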
func sampleNTT(rho []byte, s, r byte) nttElement {
G := sha3.NewShake128()
G.Write(rho)
G.Write([]byte{s, r})
var a nttElement
var j int // index into a
var buf [168]byte // buffered reads from B, matching the rate of SHAKE-128
off := len(buf) // index into buf, starts in a "buffer fully consumed" state
for j < n {
if off >= len(buf) {
G.Read(buf[:])
off = 0
}
v := uint32(buf[off]) | uint32(buf[off+1])<<8 | uint32(buf[off+2])<<16
off += 3
f, err := fieldToMontgomery(v & 0b01111111_11111111_11111111) // 23 bits
if err != nil {
continue
}
a[j] = f
j++
}
return a
}
// sampleBoundedPoly samples a ringElement with coefficients in [−η, η] from the
// seed rho and the index r. It implements RejBoundedPoly and CoeffFromHalfByte
// from FIPS 204, passing in ρ and r separately from ExpandS.
func sampleBoundedPoly(rho []byte, r byte, p parameters) ringElement {
H := sha3.NewShake256()
H.Write(rho)
H.Write([]byte{r, 0}) // IntegerToBytes(r, 2)
var a ringElement
var j int
var buf [136]byte // buffered reads from H, matching the rate of SHAKE-256
off := len(buf) // index into buf, starts in a "buffer fully consumed" state
for {
if off >= len(buf) {
H.Read(buf[:])
off = 0
}
z0 := buf[off] & 0x0F
z1 := buf[off] >> 4
off++
coeff, ok := coeffFromHalfByte(z0, p)
if ok {
a[j] = coeff
j++
}
if j >= len(a) {
break
}
coeff, ok = coeffFromHalfByte(z1, p)
if ok {
a[j] = coeff
j++
}
if j >= len(a) {
break
}
}
return a
}
// sampleInBall samples a ringElement with coefficients in {-1, 0, 1}, and τ
// non-zero coefficients. It is not constant-time.
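//
// The loop below is the insertion step of a Fisher-Yates shuffle: position i
// takes the old value of a uniformly chosen j ≤ i, and position j is set to
// ±1, so exactly τ coefficients end up non-zero.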
func sampleInBall(rho []byte, p parameters) ringElement {
H := sha3.NewShake256()
H.Write(rho)
s := make([]byte, 8)
H.Read(s)
var c ringElement
for i := 256 - p.τ; i < 256; i++ {
j := make([]byte, 1)
H.Read(j)
for j[0] > byte(i) {
H.Read(j)
}
c[i] = c[j[0]]
// c[j] = (-1) ^ h[i+τ-256], where h are the bits in s in little-endian.
// That is, -1⁰ = 1 if the bit is 0, -1¹ = -1 if it is 1.
bitIdx := i + p.τ - 256
bit := (s[bitIdx/8] >> (bitIdx % 8)) & 1
if bit == 0 {
c[j[0]] = one
} else {
c[j[0]] = minusOne
}
}
return c
}
// coeffFromHalfByte implements CoeffFromHalfByte from FIPS 204.
//
// It maps a value in [0, 15] to a coefficient in [−η, η].
func coeffFromHalfByte(b byte, p parameters) (fieldElement, bool) {
if b > 15 {
panic("internal error: half-byte out of range")
}
switch p.η {
case 2:
// Return z = 2 - (b mod 5), which maps from
//
// b = ( 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 )
//
// to
//
// b%5 = ( 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0 )
//
// to
//
// z = ( -2, -1, 0, 1, 2, -2, -1, 0, 1, 2, -2, -1, 0, 1, 2 )
//
if b > 14 {
return 0, false
}
// Calculate b % 5 with Barrett reduction, to avoid a potentially
// variable-time division.
const barrettMultiplier = 0x3334 // ⌈2¹⁶ / 5⌉
const barrettShift = 16 // log₂(2¹⁶)
quotient := (uint32(b) * barrettMultiplier) >> barrettShift
remainder := uint32(b) - quotient*5
return fieldSubToMontgomery(2, remainder), true
case 4:
// Return z = 4 - b, which maps from
//
// b = ( 8, 7, 6, 5, 4, 3, 2, 1, 0 )
//
// to
//
// z = ( -4, -3, -2, -1, 0, 1, 2, 3, 4 )
//
if b > 8 {
return 0, false
}
return fieldSubToMontgomery(4, uint32(b)), true
default:
panic("internal error: unsupported η")
}
}
// power2Round implements Power2Round from FIPS 204.
//
// It separates the bottom d = 13 bits of each 23-bit coefficient, rounding the
// high part based on the low part, and correcting the low part accordingly.
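//
// The high part fits in 10 bits (it is at most 2¹⁰ - 1), and the low part
// encodes, mod q, a value in the range [-2¹²+1, 2¹²].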
func power2Round(r fieldElement) (hi uint16, lo fieldElement) {
rr := fieldFromMontgomery(r)
// Add 2¹² - 1 to round up r1 by one if r0 > 2¹².
// r is at most 2²³ - 2¹³ + 1, so rr + (2¹² - 1) won't overflow 23 bits.
r1 := rr + 1<<12 - 1
r1 >>= 13
// r1 <= 2¹⁰ - 1
// r1 * 2¹³ <= (2¹⁰ - 1) * 2¹³ = 2²³ - 2¹³ < q
r0 := fieldSubToMontgomery(rr, r1<<13)
return uint16(r1), r0
}
// highBits implements HighBits from FIPS 204.
func highBits(r ringElement, p parameters) [n]byte {
var w [n]byte
switch p.γ2 {
case 32:
for i := range n {
w[i] = highBits32(fieldFromMontgomery(r[i]))
}
case 88:
for i := range n {
w[i] = highBits88(fieldFromMontgomery(r[i]))
}
default:
panic("mldsa: internal error: unsupported γ2")
}
return w
}
// useHint implements UseHint from FIPS 204.
//
// It is not constant-time.
func useHint(r ringElement, h [n]byte, p parameters) [n]byte {
var w [n]byte
switch p.γ2 {
case 32:
for i := range n {
w[i] = useHint32(r[i], h[i])
}
case 88:
for i := range n {
w[i] = useHint88(r[i], h[i])
}
default:
panic("mldsa: internal error: unsupported γ2")
}
return w
}
// makeHint implements MakeHint from FIPS 204.
func makeHint(ct0, w, cs2 ringElement, p parameters) (h [n]byte, count1s int) {
switch p.γ2 {
case 32:
for i := range n {
h[i] = makeHint32(ct0[i], w[i], cs2[i])
count1s += int(h[i])
}
case 88:
for i := range n {
h[i] = makeHint88(ct0[i], w[i], cs2[i])
count1s += int(h[i])
}
default:
panic("mldsa: internal error: unsupported γ2")
}
return h, count1s
}
// highBits32 implements HighBits from FIPS 204 for γ2 = (q - 1) / 32.
func highBits32(x uint32) byte {
// The implementation is based on the reference implementation and on
// BoringSSL. There are exhaustive tests in TestDecompose that compare it to
// a straightforward implementation of Decompose from the spec, so for our
// purposes it only has to work and be constant-time.
r1 := (x + 127) >> 7
r1 = (r1*1025 + (1 << 21)) >> 22
r1 &= 0b1111
return byte(r1)
}
// decompose32 implements Decompose from FIPS 204 for γ2 = (q - 1) / 32.
//
// r1 is in [0, 15].
func decompose32(r fieldElement) (r1 byte, r0 int32) {
x := fieldFromMontgomery(r)
r1 = highBits32(x)
// r - r1 * (2 * γ2) mod± q
r0 = int32(x) - int32(r1)*2*(q-1)/32
r0 = constantTimeSelectLessOrEqual(q/2+1, r0, r0-q, r0)
return r1, r0
}
// useHint32 implements UseHint from FIPS 204 for γ2 = (q - 1) / 32.
func useHint32(r fieldElement, hint byte) byte {
const m = 16 // (q - 1) / (2 * γ2)
r1, r0 := decompose32(r)
if hint == 1 {
if r0 > 0 {
r1 = (r1 + 1) % m
} else {
// Underflow is safe, because it operates modulo 256 (since the type
// is byte), which is a multiple of m.
r1 = (r1 - 1) % m
}
}
return r1
}
// makeHint32 implements MakeHint from FIPS 204 for γ2 = (q - 1) / 32.
func makeHint32(ct0, w, cs2 fieldElement) byte {
// v1 = HighBits(r + z) = HighBits(w - cs2 + ct0 - ct0) = HighBits(w - cs2)
rPlusZ := fieldSub(w, cs2)
v1 := highBits32(fieldFromMontgomery(rPlusZ))
// r1 = HighBits(r) = HighBits(w - cs2 + ct0)
r1 := highBits32(fieldFromMontgomery(fieldAdd(rPlusZ, ct0)))
return byte(constanttime.ByteEq(v1, r1) ^ 1)
}
// highBits88 implements HighBits from FIPS 204 for γ2 = (q - 1) / 88.
func highBits88(x uint32) byte {
// Like highBits32, this is exhaustively tested in TestDecompose.
r1 := (x + 127) >> 7
r1 = (r1*11275 + (1 << 23)) >> 24
r1 = constantTimeSelectEqual(r1, 44, 0, r1)
return byte(r1)
}
// decompose88 implements Decompose from FIPS 204 for γ2 = (q - 1) / 88.
//
// r1 is in [0, 43].
func decompose88(r fieldElement) (r1 byte, r0 int32) {
x := fieldFromMontgomery(r)
r1 = highBits88(x)
// r - r1 * (2 * γ2) mod± q
r0 = int32(x) - int32(r1)*2*(q-1)/88
r0 = constantTimeSelectLessOrEqual(q/2+1, r0, r0-q, r0)
return r1, r0
}
// useHint88 implements UseHint from FIPS 204 for γ2 = (q - 1) / 88.
func useHint88(r fieldElement, hint byte) byte {
const m = 44 // (q - 1) / (2 * γ2)
r1, r0 := decompose88(r)
if hint == 1 {
if r0 > 0 {
// (r1 + 1) mod m, for r1 in [0, m-1]
if r1 == m-1 {
r1 = 0
} else {
r1++
}
} else {
// (r1 - 1) % m, for r1 in [0, m-1]
if r1 == 0 {
r1 = m - 1
} else {
r1--
}
}
}
return r1
}
// makeHint88 implements MakeHint from FIPS 204 for γ2 = (q - 1) / 88.
func makeHint88(ct0, w, cs2 fieldElement) byte {
// Same as makeHint32 above.
rPlusZ := fieldSub(w, cs2)
v1 := highBits88(fieldFromMontgomery(rPlusZ))
r1 := highBits88(fieldFromMontgomery(fieldAdd(rPlusZ, ct0)))
return byte(constanttime.ByteEq(v1, r1) ^ 1)
}
// bitPack implements BitPack(r mod± q, γ₁-1, γ₁), which packs the centered
// coefficients of r into little-endian γ1+1-bit chunks. It appends to buf.
//
// It must only be applied to r with coefficients in [−γ₁+1, γ₁], as
// guaranteed by the rejection conditions in Sign.
func bitPack(b []byte, r ringElement, p parameters) []byte {
switch p.γ1 {
case 17:
return bitPack18(b, r)
case 19:
return bitPack20(b, r)
default:
panic("mldsa: internal error: unsupported γ1")
}
}
// bitPack18 implements BitPack(r mod± q, 2¹⁷-1, 2¹⁷), which packs the centered
// coefficients of r into little-endian 18-bit chunks. It appends to buf.
//
// It must only be applied to r with coefficients in [-2¹⁷+1, 2¹⁷], as
// guaranteed by the rejection conditions in Sign.
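// Four 18-bit chunks pack exactly into 9 bytes, so each loop iteration consumes
// four coefficients and advances the output by 4*18/8 = 9 bytes.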
func bitPack18(buf []byte, r ringElement) []byte {
out, v := sliceForAppend(buf, 18*n/8)
const b = 1 << 17
for i := 0; i < n; i += 4 {
// b - [-2¹⁷+1, 2¹⁷] = [0, 2¹⁸-1]
w0 := b - fieldCenteredMod(r[i])
v[0] = byte(w0 << 0)
v[1] = byte(w0 >> 8)
v[2] = byte(w0 >> 16)
w1 := b - fieldCenteredMod(r[i+1])
v[2] |= byte(w1 << 2)
v[3] = byte(w1 >> 6)
v[4] = byte(w1 >> 14)
w2 := b - fieldCenteredMod(r[i+2])
v[4] |= byte(w2 << 4)
v[5] = byte(w2 >> 4)
v[6] = byte(w2 >> 12)
w3 := b - fieldCenteredMod(r[i+3])
v[6] |= byte(w3 << 6)
v[7] = byte(w3 >> 2)
v[8] = byte(w3 >> 10)
v = v[4*18/8:]
}
return out
}
// bitPack20 implements BitPack(r mod± q, 2¹⁹-1, 2¹⁹), which packs the centered
// coefficients of r into little-endian 20-bit chunks. It appends to buf.
//
// It must only be applied to r with coefficients in [-2¹⁹+1, 2¹⁹], as
// guaranteed by the rejection conditions in Sign.
func bitPack20(buf []byte, r ringElement) []byte {
out, v := sliceForAppend(buf, 20*n/8)
const b = 1 << 19
for i := 0; i < n; i += 2 {
// b - [-2¹⁹+1, 2¹⁹] = [0, 2²⁰-1]
w0 := b - fieldCenteredMod(r[i])
v[0] = byte(w0 << 0)
v[1] = byte(w0 >> 8)
v[2] = byte(w0 >> 16)
w1 := b - fieldCenteredMod(r[i+1])
v[2] |= byte(w1 << 4)
v[3] = byte(w1 >> 4)
v[4] = byte(w1 >> 12)
v = v[2*20/8:]
}
return out
}
// bitUnpack implements BitUnpack(v, 2^γ1-1, 2^γ1), which unpacks each γ1+1 bits
// in little-endian into a coefficient in [-2^γ1+1, 2^γ1].
func bitUnpack(v []byte, p parameters) ringElement {
switch p.γ1 {
case 17:
return bitUnpack18(v)
case 19:
return bitUnpack20(v)
default:
panic("mldsa: internal error: unsupported γ1")
}
}
// bitUnpack18 implements BitUnpack(v, 2¹⁷-1, 2¹⁷), which unpacks each 18 bits
// in little-endian into a coefficient in [-2¹⁷+1, 2¹⁷].
func bitUnpack18(v []byte) ringElement {
if len(v) != 18*n/8 {
panic("mldsa: internal error: invalid bitUnpack18 input length")
}
const b = 1 << 17
const mask18 = 1<<18 - 1
var r ringElement
for i := 0; i < n; i += 4 {
w0 := uint32(v[0]) | uint32(v[1])<<8 | uint32(v[2])<<16
r[i+0] = fieldSubToMontgomery(b, w0&mask18)
w1 := uint32(v[2])>>2 | uint32(v[3])<<6 | uint32(v[4])<<14
r[i+1] = fieldSubToMontgomery(b, w1&mask18)
w2 := uint32(v[4])>>4 | uint32(v[5])<<4 | uint32(v[6])<<12
r[i+2] = fieldSubToMontgomery(b, w2&mask18)
w3 := uint32(v[6])>>6 | uint32(v[7])<<2 | uint32(v[8])<<10
r[i+3] = fieldSubToMontgomery(b, w3&mask18)
v = v[4*18/8:]
}
return r
}
// bitUnpack20 implements BitUnpack(v, 2¹⁹-1, 2¹⁹), which unpacks each 20 bits
// in little-endian into a coefficient in [-2¹⁹+1, 2¹⁹].
func bitUnpack20(v []byte) ringElement {
if len(v) != 20*n/8 {
panic("mldsa: internal error: invalid bitUnpack20 input length")
}
const b = 1 << 19
const mask20 = 1<<20 - 1
var r ringElement
for i := 0; i < n; i += 2 {
w0 := uint32(v[0]) | uint32(v[1])<<8 | uint32(v[2])<<16
r[i+0] = fieldSubToMontgomery(b, w0&mask20)
w1 := uint32(v[2])>>4 | uint32(v[3])<<4 | uint32(v[4])<<12
r[i+1] = fieldSubToMontgomery(b, w1&mask20)
v = v[2*20/8:]
}
return r
}
// sliceForAppend takes a slice and a requested number of bytes. It returns a
// slice with the contents of the given slice followed by that many bytes and a
// second slice that aliases into it and contains only the extra bytes. If the
// original slice has sufficient capacity then no allocation is performed.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return
}
// constantTimeSelectLessOrEqual returns yes if a <= b, no otherwise, in constant time.
func constantTimeSelectLessOrEqual(a, b, yes, no int32) int32 {
return int32(constanttime.Select(constanttime.LessOrEq(int(a), int(b)), int(yes), int(no)))
}
// constantTimeSelectEqual returns yes if a == b, no otherwise, in constant time.
func constantTimeSelectEqual(a, b, yes, no uint32) uint32 {
return uint32(constanttime.Select(constanttime.Eq(int32(a), int32(b)), int(yes), int(no)))
}
// constantTimeAbs returns the absolute value of x in constant time.
func constantTimeAbs(x int32) uint32 {
return uint32(constantTimeSelectLessOrEqual(0, x, x, -x))
}

View file

@@ -0,0 +1,370 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mldsa
import (
"math/big"
"testing"
)
type interestingValue struct {
v uint32
m fieldElement
}
// q is large enough that we can't exhaustively test all q × q inputs, so when
// we have two inputs we test [0, q) on one side and a set of interesting
// values on the other side.
func interestingValues() []interestingValue {
if testing.Short() {
return []interestingValue{{v: q - 1, m: minusOne}}
}
var values []interestingValue
for _, v := range []uint32{
0,
1,
2,
3,
q - 3,
q - 2,
q - 1,
q / 2,
(q + 1) / 2,
} {
m, _ := fieldToMontgomery(v)
values = append(values, interestingValue{v: v, m: m})
// Also test values that have an interesting Montgomery representation.
values = append(values, interestingValue{
v: fieldFromMontgomery(fieldElement(v)), m: fieldElement(v)})
}
return values
}
func TestToFromMontgomery(t *testing.T) {
for a := range uint32(q) {
m, err := fieldToMontgomery(a)
if err != nil {
t.Fatalf("fieldToMontgomery(%d) returned error: %v", a, err)
}
exp := fieldElement((uint64(a) * R) % q)
if m != exp {
t.Fatalf("fieldToMontgomery(%d) = %d, expected %d", a, m, exp)
}
got := fieldFromMontgomery(m)
if got != a {
t.Fatalf("fieldFromMontgomery(fieldToMontgomery(%d)) = %d, expected %d", a, got, a)
}
}
}
func TestFieldAdd(t *testing.T) {
t.Parallel()
for _, a := range interestingValues() {
for b := range fieldElement(q) {
got := fieldAdd(a.m, b)
exp := (a.m + b) % q
if got != exp {
t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp)
}
}
}
}
func TestFieldSub(t *testing.T) {
t.Parallel()
for _, a := range interestingValues() {
for b := range fieldElement(q) {
got := fieldSub(a.m, b)
exp := (a.m + q - b) % q
if got != exp {
t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp)
}
}
}
}
func TestFieldSubToMontgomery(t *testing.T) {
t.Parallel()
for _, a := range interestingValues() {
for b := range uint32(q) {
got := fieldSubToMontgomery(a.v, b)
diff := (a.v + q - b) % q
exp := fieldElement((uint64(diff) * R) % q)
if got != exp {
t.Fatalf("fieldSubToMontgomery(%d, %d) = %d, expected %d", a.v, b, got, exp)
}
}
}
}
func TestFieldReduceOnce(t *testing.T) {
t.Parallel()
for a := range uint32(2 * q) {
got := fieldReduceOnce(a)
var exp uint32
if a < q {
exp = a
} else {
exp = a - q
}
if uint32(got) != exp {
t.Fatalf("fieldReduceOnce(%d) = %d, expected %d", a, got, exp)
}
}
}
func TestFieldMul(t *testing.T) {
t.Parallel()
for _, a := range interestingValues() {
for b := range fieldElement(q) {
got := fieldFromMontgomery(fieldMontgomeryMul(a.m, b))
exp := uint32((uint64(a.v) * uint64(fieldFromMontgomery(b))) % q)
if got != exp {
t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp)
}
}
}
}
func TestFieldToMontgomeryOverflow(t *testing.T) {
// fieldToMontgomery should reject inputs ≥ q.
inputs := []uint32{
q,
q + 1,
q + 2,
1<<23 - 1,
1 << 23,
q + 1<<23,
q + 1<<31,
^uint32(0),
}
for _, in := range inputs {
if _, err := fieldToMontgomery(in); err == nil {
t.Fatalf("fieldToMontgomery(%d) did not return an error", in)
}
}
}
func TestFieldMulSub(t *testing.T) {
for _, a := range interestingValues() {
for _, b := range interestingValues() {
for _, c := range interestingValues() {
got := fieldFromMontgomery(fieldMontgomeryMulSub(a.m, b.m, c.m))
exp := uint32((uint64(a.v) * (uint64(b.v) + q - uint64(c.v))) % q)
if got != exp {
t.Fatalf("%d * (%d - %d) = %d, expected %d", a.v, b.v, c.v, got, exp)
}
}
}
}
}
func TestFieldAddMul(t *testing.T) {
for _, a := range interestingValues() {
for _, b := range interestingValues() {
for _, c := range interestingValues() {
for _, d := range interestingValues() {
got := fieldFromMontgomery(fieldMontgomeryAddMul(a.m, b.m, c.m, d.m))
exp := uint32((uint64(a.v)*uint64(b.v) + uint64(c.v)*uint64(d.v)) % q)
if got != exp {
t.Fatalf("%d + %d * %d = %d, expected %d", a.v, b.v, c.v, got, exp)
}
}
}
}
}
}
func BitRev8(n uint8) uint8 {
var r uint8
r |= n >> 7 & 0b0000_0001
r |= n >> 5 & 0b0000_0010
r |= n >> 3 & 0b0000_0100
r |= n >> 1 & 0b0000_1000
r |= n << 1 & 0b0001_0000
r |= n << 3 & 0b0010_0000
r |= n << 5 & 0b0100_0000
r |= n << 7 & 0b1000_0000
return r
}
func CenteredMod(x, m uint32) int32 {
x = x % m
if x > m/2 {
return int32(x) - int32(m)
}
return int32(x)
}
func reduceModQ(x int32) uint32 {
x %= q
if x < 0 {
return uint32(x + q)
}
return uint32(x)
}
func TestCenteredMod(t *testing.T) {
for x := range uint32(q * 2) {
got := CenteredMod(uint32(x), q)
if reduceModQ(got) != (x % q) {
t.Fatalf("CenteredMod(%d) = %d, which is not congruent to %d mod %d", x, got, x, q)
}
}
for x := range uint32(q) {
r, _ := fieldToMontgomery(x)
got := fieldCenteredMod(r)
exp := CenteredMod(x, q)
if got != exp {
t.Fatalf("fieldCenteredMod(%d) = %d, expected %d", x, got, exp)
}
}
}
func TestInfinityNorm(t *testing.T) {
for x := range uint32(q) {
r, _ := fieldToMontgomery(x)
got := fieldInfinityNorm(r)
exp := CenteredMod(x, q)
if exp < 0 {
exp = -exp
}
if got != uint32(exp) {
t.Fatalf("fieldInfinityNorm(%d) = %d, expected %d", x, got, exp)
}
}
}
func TestConstants(t *testing.T) {
if fieldFromMontgomery(one) != 1 {
t.Errorf("one constant incorrect")
}
if fieldFromMontgomery(minusOne) != q-1 {
t.Errorf("minusOne constant incorrect")
}
if fieldInfinityNorm(one) != 1 {
t.Errorf("one infinity norm incorrect")
}
if fieldInfinityNorm(minusOne) != 1 {
t.Errorf("minusOne infinity norm incorrect")
}
if PublicKeySize44 != pubKeySize(params44) {
t.Errorf("PublicKeySize44 constant incorrect")
}
if PublicKeySize65 != pubKeySize(params65) {
t.Errorf("PublicKeySize65 constant incorrect")
}
if PublicKeySize87 != pubKeySize(params87) {
t.Errorf("PublicKeySize87 constant incorrect")
}
if SignatureSize44 != sigSize(params44) {
t.Errorf("SignatureSize44 constant incorrect")
}
if SignatureSize65 != sigSize(params65) {
t.Errorf("SignatureSize65 constant incorrect")
}
if SignatureSize87 != sigSize(params87) {
t.Errorf("SignatureSize87 constant incorrect")
}
}
func TestPower2Round(t *testing.T) {
t.Parallel()
for x := range uint32(q) {
rr, _ := fieldToMontgomery(x)
t1, t0 := power2Round(rr)
hi, err := fieldToMontgomery(uint32(t1) << 13)
if err != nil {
t.Fatalf("power2Round(%d): failed to convert high part to Montgomery: %v", x, err)
}
if r := fieldFromMontgomery(fieldAdd(hi, t0)); r != x {
t.Fatalf("power2Round(%d) = (%d, %d), which reconstructs to %d, expected %d", x, t1, t0, r, x)
}
}
}
func SpecDecompose(rr fieldElement, p parameters) (R1 uint32, R0 int32) {
r := fieldFromMontgomery(rr)
if (q-1)%p.γ2 != 0 {
panic("mldsa: internal error: unsupported denγ2")
}
γ2 := (q - 1) / uint32(p.γ2)
r0 := CenteredMod(r, 2*γ2)
diff := int32(r) - r0
if diff == q-1 {
r0 = r0 - 1
return 0, r0
} else {
if diff < 0 || uint32(diff)%γ2 != 0 {
panic("mldsa: internal error: invalid decomposition")
}
r1 := uint32(diff) / (2 * γ2)
return r1, r0
}
}
func TestDecompose(t *testing.T) {
t.Run("ML-DSA-44", func(t *testing.T) {
testDecompose(t, params44)
})
t.Run("ML-DSA-65,87", func(t *testing.T) {
testDecompose(t, params65)
})
}
func testDecompose(t *testing.T, p parameters) {
t.Parallel()
for x := range uint32(q) {
rr, _ := fieldToMontgomery(x)
r1, r0 := SpecDecompose(rr, p)
// Check that SpecDecompose is correct.
// r ≡ r1 * (2 * γ2) + r0 mod q
γ2 := (q - 1) / uint32(p.γ2)
reconstructed := reduceModQ(int32(r1*2*γ2) + r0)
if reconstructed != x {
t.Fatalf("SpecDecompose(%d) = (%d, %d), which reconstructs to %d, expected %d", x, r1, r0, reconstructed, x)
}
var gotR1 byte
var gotR0 int32
switch p.γ2 {
case 88:
gotR1, gotR0 = decompose88(rr)
if gotR1 > 43 {
t.Fatalf("decompose88(%d) returned r1 = %d, which is out of range", x, gotR1)
}
case 32:
gotR1, gotR0 = decompose32(rr)
if gotR1 > 15 {
t.Fatalf("decompose32(%d) returned r1 = %d, which is out of range", x, gotR1)
}
default:
t.Fatalf("unsupported denγ2: %d", p.γ2)
}
if uint32(gotR1) != r1 {
t.Fatalf("highBits(%d) = %d, expected %d", x, gotR1, r1)
}
if gotR0 != r0 {
t.Fatalf("lowBits(%d) = %d, expected %d", x, gotR0, r0)
}
}
}
func TestZetas(t *testing.T) {
ζ := big.NewInt(1753)
q := big.NewInt(q)
for k, zeta := range zetas {
// ζ^BitRev₈(k) mod q
exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev8(uint8(k)))), q)
got := fieldFromMontgomery(zeta)
if big.NewInt(int64(got)).Cmp(exp) != 0 {
t.Errorf("zetas[%d] = %v, expected %v", k, got, exp)
}
}
}

View file

@@ -0,0 +1,782 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mldsa
import (
"bytes"
"crypto/internal/fips140"
"crypto/internal/fips140/drbg"
"crypto/internal/fips140/sha3"
"crypto/internal/fips140/subtle"
"crypto/internal/fips140deps/byteorder"
"errors"
)
type parameters struct {
k, l int // dimensions of A
η int // bound for secret coefficients
γ1 int // log₂(γ₁), where [-γ₁+1, γ₁] is the bound of y
γ2 int // denominator of γ₂ = (q - 1) / γ2
λ int // collision strength
τ int // number of non-zero coefficients in challenge
ω int // max number of hints in MakeHint
}
var (
params44 = parameters{k: 4, l: 4, η: 2, γ1: 17, γ2: 88, λ: 128, τ: 39, ω: 80}
params65 = parameters{k: 6, l: 5, η: 4, γ1: 19, γ2: 32, λ: 192, τ: 49, ω: 55}
params87 = parameters{k: 8, l: 7, η: 2, γ1: 19, γ2: 32, λ: 256, τ: 60, ω: 75}
)
func pubKeySize(p parameters) int {
// ρ + k × n × 10-bit coefficients of t₁
return 32 + p.k*n*10/8
}
func sigSize(p parameters) int {
// challenge + l × n × (γ₁+1)-bit coefficients of z + hint
return (p.λ / 4) + p.l*n*(p.γ1+1)/8 + p.ω + p.k
}
const (
PrivateKeySize = 32
PublicKeySize44 = 32 + 4*n*10/8
PublicKeySize65 = 32 + 6*n*10/8
PublicKeySize87 = 32 + 8*n*10/8
SignatureSize44 = 128/4 + 4*n*(17+1)/8 + 80 + 4
SignatureSize65 = 192/4 + 5*n*(19+1)/8 + 55 + 6
SignatureSize87 = 256/4 + 7*n*(19+1)/8 + 75 + 8
)
const maxK, maxL, maxλ, maxγ1 = 8, 7, 256, 19
const maxPubKeySize = PublicKeySize87
type PrivateKey struct {
seed [32]byte
pub PublicKey
s1 [maxL]nttElement
s2 [maxK]nttElement
t0 [maxK]nttElement
k [32]byte
}
func (priv *PrivateKey) Equal(x *PrivateKey) bool {
return priv.pub.p == x.pub.p && subtle.ConstantTimeCompare(priv.seed[:], x.seed[:]) == 1
}
func (priv *PrivateKey) Bytes() []byte {
seed := priv.seed
return seed[:]
}
func (priv *PrivateKey) PublicKey() *PublicKey {
// Note that this is likely to keep the entire PrivateKey reachable for
// the lifetime of the PublicKey, which may be undesirable.
return &priv.pub
}
type PublicKey struct {
raw [maxPubKeySize]byte
p parameters
a [maxK * maxL]nttElement
t1 [maxK]nttElement // NTT(t₁ ⋅ 2ᵈ)
tr [64]byte // public key hash
}
func (pub *PublicKey) Equal(x *PublicKey) bool {
size := pubKeySize(pub.p)
return pub.p == x.p && subtle.ConstantTimeCompare(pub.raw[:size], x.raw[:size]) == 1
}
func (pub *PublicKey) Bytes() []byte {
size := pubKeySize(pub.p)
return bytes.Clone(pub.raw[:size])
}
func (pub *PublicKey) Parameters() string {
switch pub.p {
case params44:
return "ML-DSA-44"
case params65:
return "ML-DSA-65"
case params87:
return "ML-DSA-87"
default:
panic("mldsa: internal error: unknown parameters")
}
}
func GenerateKey44() *PrivateKey {
fipsSelfTest()
fips140.RecordApproved()
var seed [32]byte
drbg.Read(seed[:])
priv := newPrivateKey(&seed, params44)
fipsPCT(priv)
return priv
}
func GenerateKey65() *PrivateKey {
fipsSelfTest()
fips140.RecordApproved()
var seed [32]byte
drbg.Read(seed[:])
priv := newPrivateKey(&seed, params65)
fipsPCT(priv)
return priv
}
func GenerateKey87() *PrivateKey {
fipsSelfTest()
fips140.RecordApproved()
var seed [32]byte
drbg.Read(seed[:])
priv := newPrivateKey(&seed, params87)
fipsPCT(priv)
return priv
}
var errInvalidSeedLength = errors.New("mldsa: invalid seed length")
func NewPrivateKey44(seed []byte) (*PrivateKey, error) {
fipsSelfTest()
fips140.RecordApproved()
if len(seed) != 32 {
return nil, errInvalidSeedLength
}
return newPrivateKey((*[32]byte)(seed), params44), nil
}
func NewPrivateKey65(seed []byte) (*PrivateKey, error) {
fipsSelfTest()
fips140.RecordApproved()
if len(seed) != 32 {
return nil, errInvalidSeedLength
}
return newPrivateKey((*[32]byte)(seed), params65), nil
}
func NewPrivateKey87(seed []byte) (*PrivateKey, error) {
fipsSelfTest()
fips140.RecordApproved()
if len(seed) != 32 {
return nil, errInvalidSeedLength
}
return newPrivateKey((*[32]byte)(seed), params87), nil
}
func newPrivateKey(seed *[32]byte, p parameters) *PrivateKey {
k, l := p.k, p.l
priv := &PrivateKey{pub: PublicKey{p: p}}
priv.seed = *seed
ξ := sha3.NewShake256()
ξ.Write(seed[:])
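// Domain-separate the seed expansion by the parameter set's dimensions (k, l),
// as FIPS 204 key generation requires.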
ξ.Write([]byte{byte(k), byte(l)})
ρ, ρs := make([]byte, 32), make([]byte, 64)
ξ.Read(ρ)
ξ.Read(ρs)
ξ.Read(priv.k[:])
A := priv.pub.a[:k*l]
computeMatrixA(A, ρ, p)
s1 := priv.s1[:l]
for r := range l {
s1[r] = ntt(sampleBoundedPoly(ρs, byte(r), p))
}
s2 := priv.s2[:k]
for r := range k {
s2[r] = ntt(sampleBoundedPoly(ρs, byte(l+r), p))
}
// ˆt = Â ∘ ŝ₁ + ŝ₂
tHat := make([]nttElement, k, maxK)
for i := range tHat {
tHat[i] = s2[i]
for j := range s1 {
tHat[i] = polyAdd(tHat[i], nttMul(A[i*l+j], s1[j]))
}
}
// t = NTT⁻¹(ˆt)
t := make([]ringElement, k, maxK)
for i := range tHat {
t[i] = inverseNTT(tHat[i])
}
// (t₁, _) = Power2Round(t)
// (_, ˆt₀) = NTT(Power2Round(t))
t1, t0 := make([][n]uint16, k, maxK), priv.t0[:k]
for i := range t {
var w ringElement
for j := range t[i] {
t1[i][j], w[j] = power2Round(t[i][j])
}
t0[i] = ntt(w)
}
// The computations below (and their storage in the PrivateKey struct) are
// not strictly necessary and could be deferred to PrivateKey.PublicKey().
// That would require keeping or re-deriving ρ and t/t1, though.
pk := pkEncode(priv.pub.raw[:0], ρ, t1, p)
priv.pub.tr = computePublicKeyHash(pk)
computeT1Hat(priv.pub.t1[:k], t1) // NTT(t₁ ⋅ 2ᵈ)
return priv
}
func computeMatrixA(A []nttElement, ρ []byte, p parameters) {
k, l := p.k, p.l
for r := range k {
for s := range l {
A[r*l+s] = sampleNTT(ρ, byte(s), byte(r))
}
}
}
func computePublicKeyHash(pk []byte) [64]byte {
H := sha3.NewShake256()
H.Write(pk)
var tr [64]byte
H.Read(tr[:])
return tr
}
func computeT1Hat(t1Hat []nttElement, t1 [][n]uint16) {
for i := range t1 {
var w ringElement
for j := range t1[i] {
// t₁ <= 2¹⁰ - 1
// t₁ ⋅ 2ᵈ <= 2ᵈ(2¹⁰ - 1) = 2²³ - 2¹³ < q = 2²³ - 2¹³ + 1
z, _ := fieldToMontgomery(uint32(t1[i][j]) << 13)
w[j] = z
}
t1Hat[i] = ntt(w)
}
}
func pkEncode(buf []byte, ρ []byte, t1 [][n]uint16, p parameters) []byte {
pk := append(buf, ρ...)
for _, w := range t1[:p.k] {
// Encode four at a time into 4 * 10 bits = 5 bytes.
for i := 0; i < n; i += 4 {
c0 := w[i]
c1 := w[i+1]
c2 := w[i+2]
c3 := w[i+3]
b0 := byte(c0 >> 0)
b1 := byte((c0 >> 8) | (c1 << 2))
b2 := byte((c1 >> 6) | (c2 << 4))
b3 := byte((c2 >> 4) | (c3 << 6))
b4 := byte(c3 >> 2)
pk = append(pk, b0, b1, b2, b3, b4)
}
}
return pk
}
func pkDecode(pk []byte, t1 [][n]uint16, p parameters) (ρ []byte, err error) {
if len(pk) != pubKeySize(p) {
return nil, errInvalidPublicKeyLength
}
ρ, pk = pk[:32], pk[32:]
for r := range t1 {
// Decode four at a time from 4 * 10 bits = 5 bytes.
for i := 0; i < n; i += 4 {
b0, b1, b2, b3, b4 := pk[0], pk[1], pk[2], pk[3], pk[4]
t1[r][i+0] = uint16(b0>>0) | uint16(b1&0b0000_0011)<<8
t1[r][i+1] = uint16(b1>>2) | uint16(b2&0b0000_1111)<<6
t1[r][i+2] = uint16(b2>>4) | uint16(b3&0b0011_1111)<<4
t1[r][i+3] = uint16(b3>>6) | uint16(b4&0b1111_1111)<<2
pk = pk[5:]
}
}
return ρ, nil
}
var errInvalidPublicKeyLength = errors.New("mldsa: invalid public key length")
func NewPublicKey44(pk []byte) (*PublicKey, error) {
return newPublicKey(pk, params44)
}
func NewPublicKey65(pk []byte) (*PublicKey, error) {
return newPublicKey(pk, params65)
}
func NewPublicKey87(pk []byte) (*PublicKey, error) {
return newPublicKey(pk, params87)
}
func newPublicKey(pk []byte, p parameters) (*PublicKey, error) {
k, l := p.k, p.l
t1 := make([][n]uint16, k, maxK)
ρ, err := pkDecode(pk, t1, p)
if err != nil {
return nil, err
}
pub := &PublicKey{p: p}
copy(pub.raw[:], pk)
computeMatrixA(pub.a[:k*l], ρ, p)
pub.tr = computePublicKeyHash(pk)
computeT1Hat(pub.t1[:k], t1) // NTT(t₁ ⋅ 2ᵈ)
return pub, nil
}
var (
errContextTooLong = errors.New("mldsa: context too long")
errMessageHashLength = errors.New("mldsa: invalid message hash length")
errRandomLength = errors.New("mldsa: invalid random length")
)
func Sign(priv *PrivateKey, msg []byte, context string) ([]byte, error) {
fipsSelfTest()
fips140.RecordApproved()
var random [32]byte
drbg.Read(random[:])
μ, err := computeMessageHash(priv.pub.tr[:], msg, context)
if err != nil {
return nil, err
}
return signInternal(priv, &μ, &random), nil
}
func SignDeterministic(priv *PrivateKey, msg []byte, context string) ([]byte, error) {
fipsSelfTest()
fips140.RecordApproved()
var random [32]byte
μ, err := computeMessageHash(priv.pub.tr[:], msg, context)
if err != nil {
return nil, err
}
return signInternal(priv, &μ, &random), nil
}
func TestingOnlySignWithRandom(priv *PrivateKey, msg []byte, context string, random []byte) ([]byte, error) {
fipsSelfTest()
fips140.RecordApproved()
μ, err := computeMessageHash(priv.pub.tr[:], msg, context)
if err != nil {
return nil, err
}
if len(random) != 32 {
return nil, errRandomLength
}
return signInternal(priv, &μ, (*[32]byte)(random)), nil
}
func SignExternalMu(priv *PrivateKey, μ []byte) ([]byte, error) {
fipsSelfTest()
fips140.RecordApproved()
var random [32]byte
drbg.Read(random[:])
if len(μ) != 64 {
return nil, errMessageHashLength
}
return signInternal(priv, (*[64]byte)(μ), &random), nil
}
func SignExternalMuDeterministic(priv *PrivateKey, μ []byte) ([]byte, error) {
fipsSelfTest()
fips140.RecordApproved()
var random [32]byte
if len(μ) != 64 {
return nil, errMessageHashLength
}
return signInternal(priv, (*[64]byte)(μ), &random), nil
}
func TestingOnlySignExternalMuWithRandom(priv *PrivateKey, μ []byte, random []byte) ([]byte, error) {
fipsSelfTest()
fips140.RecordApproved()
if len(μ) != 64 {
return nil, errMessageHashLength
}
if len(random) != 32 {
return nil, errRandomLength
}
return signInternal(priv, (*[64]byte)(μ), (*[32]byte)(random)), nil
}
func computeMessageHash(tr []byte, msg []byte, context string) ([64]byte, error) {
if len(context) > 255 {
return [64]byte{}, errContextTooLong
}
H := sha3.NewShake256()
H.Write(tr)
H.Write([]byte{0}) // ML-DSA / HashML-DSA domain separator
H.Write([]byte{byte(len(context))})
H.Write([]byte(context))
H.Write(msg)
var μ [64]byte
H.Read(μ[:])
return μ, nil
}
func signInternal(priv *PrivateKey, μ *[64]byte, random *[32]byte) []byte {
p, k, l := priv.pub.p, priv.pub.p.k, priv.pub.p.l
A, s1, s2, t0 := priv.pub.a[:k*l], priv.s1[:l], priv.s2[:k], priv.t0[:k]
β := p.τ * p.η
γ1 := uint32(1 << p.γ1)
γ1MinusBeta := γ1 - uint32(β) // rejection bound for ||z||∞
γ2 := (q - 1) / uint32(p.γ2)
γ2MinusBeta := γ2 - uint32(β) // rejection bound for ||LowBits(r0)||∞
H := sha3.NewShake256()
H.Write(priv.k[:])
H.Write(random[:])
H.Write(μ[:])
nonce := make([]byte, 64)
H.Read(nonce)
κ := 0
sign:
for {
// Main rejection sampling loop. Note that leaking rejected signatures
// leaks information about the private key. However, as explained in
// https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf
// Section 5.5, we are free to leak rejected ch values, as well as which
// check causes the rejection and which coefficient failed the check
// (but not the value or sign of the coefficient).
y := make([]ringElement, l, maxL)
for r := range y {
counter := make([]byte, 2)
byteorder.LEPutUint16(counter, uint16(κ))
κ++
H.Reset()
H.Write(nonce)
H.Write(counter)
v := make([]byte, (p.γ1+1)*n/8, (maxγ1+1)*n/8)
H.Read(v)
y[r] = bitUnpack(v, p)
}
// w = NTT⁻¹(Â ∘ NTT(y))
yHat := make([]nttElement, l, maxL)
for i := range y {
yHat[i] = ntt(y[i])
}
w := make([]ringElement, k, maxK)
for i := range w {
var wHat nttElement
for j := range l {
wHat = polyAdd(wHat, nttMul(A[i*l+j], yHat[j]))
}
w[i] = inverseNTT(wHat)
}
H.Reset()
H.Write(μ[:])
for i := range w {
w1Encode(H, highBits(w[i], p), p)
}
ch := make([]byte, p.λ/4, maxλ/4)
H.Read(ch)
// sampleInBall is not constant time, but see comment above about
// leaking rejected ch values being acceptable.
c := ntt(sampleInBall(ch, p))
cs1 := make([]ringElement, l, maxL)
for i := range cs1 {
cs1[i] = inverseNTT(nttMul(c, s1[i]))
}
cs2 := make([]ringElement, k, maxK)
for i := range cs2 {
cs2[i] = inverseNTT(nttMul(c, s2[i]))
}
z := make([]ringElement, l, maxL)
for i := range y {
z[i] = polyAdd(y[i], cs1[i])
// Reject if ||z||∞ ≥ γ1 - β
if coefficientsExceedBound(z[i], γ1MinusBeta) {
if testingOnlyRejectionReason != nil {
testingOnlyRejectionReason("z")
}
continue sign
}
}
for i := range w {
r0 := polySub(w[i], cs2[i])
// Reject if ||LowBits(r0)||∞ ≥ γ2 - β
if lowBitsExceedBound(r0, γ2MinusBeta, p) {
if testingOnlyRejectionReason != nil {
testingOnlyRejectionReason("r0")
}
continue sign
}
}
ct0 := make([]ringElement, k, maxK)
for i := range ct0 {
ct0[i] = inverseNTT(nttMul(c, t0[i]))
// Reject if ||ct0||∞ ≥ γ2
if coefficientsExceedBound(ct0[i], γ2) {
if testingOnlyRejectionReason != nil {
testingOnlyRejectionReason("ct0")
}
continue sign
}
}
count1s := 0
h := make([][n]byte, k, maxK)
for i := range w {
var count int
h[i], count = makeHint(ct0[i], w[i], cs2[i], p)
count1s += count
}
// Reject if number of hints > ω
if count1s > p.ω {
if testingOnlyRejectionReason != nil {
testingOnlyRejectionReason("h")
}
continue sign
}
return sigEncode(ch, z, h, p)
}
}
// testingOnlyRejectionReason is set in tests, to ensure that all rejection
// paths are covered. If not nil, it is called with a string describing the
// reason for rejection: "z", "r0", "ct0", or "h".
var testingOnlyRejectionReason func(reason string)
// w1Encode implements w1Encode from FIPS 204, writing directly into H.
func w1Encode(H *sha3.SHAKE, w [n]byte, p parameters) {
switch p.γ2 {
case 32:
// Coefficients are <= (q - 1)/(2γ2) - 1 = 15, four bits each.
buf := make([]byte, 4*n/8)
for i := 0; i < n; i += 2 {
b0 := w[i]
b1 := w[i+1]
buf[i/2] = b0 | b1<<4
}
H.Write(buf)
case 88:
// Coefficients are <= (q - 1)/(2γ2) - 1 = 43, six bits each.
buf := make([]byte, 6*n/8)
for i := 0; i < n; i += 4 {
b0 := w[i]
b1 := w[i+1]
b2 := w[i+2]
b3 := w[i+3]
buf[3*i/4+0] = (b0 >> 0) | (b1 << 6)
buf[3*i/4+1] = (b1 >> 2) | (b2 << 4)
buf[3*i/4+2] = (b2 >> 4) | (b3 << 2)
}
H.Write(buf)
default:
panic("mldsa: internal error: unsupported γ2")
}
}
func coefficientsExceedBound(w ringElement, bound uint32) bool {
// If this function appears in profiles, it might be possible to deduplicate
// the work of fieldFromMontgomery inside fieldInfinityNorm with the
// subsequent encoding of w.
for i := range w {
if fieldInfinityNorm(w[i]) >= bound {
return true
}
}
return false
}
func lowBitsExceedBound(w ringElement, bound uint32, p parameters) bool {
switch p.γ2 {
case 32:
for i := range w {
_, r0 := decompose32(w[i])
if constantTimeAbs(r0) >= bound {
return true
}
}
case 88:
for i := range w {
_, r0 := decompose88(w[i])
if constantTimeAbs(r0) >= bound {
return true
}
}
default:
panic("mldsa: internal error: unsupported γ2")
}
return false
}
var (
errInvalidSignatureLength = errors.New("mldsa: invalid signature length")
errInvalidSignatureCoeffBounds = errors.New("mldsa: invalid signature")
errInvalidSignatureChallenge = errors.New("mldsa: invalid signature")
errInvalidSignatureHintLimits = errors.New("mldsa: invalid signature encoding")
errInvalidSignatureHintIndexOrder = errors.New("mldsa: invalid signature encoding")
errInvalidSignatureHintExtraIndices = errors.New("mldsa: invalid signature encoding")
)
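// Verify checks that sig is a valid signature of msg with the given context
// under the public key pub. It returns an error if the signature is invalid.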
func Verify(pub *PublicKey, msg, sig []byte, context string) error {
fipsSelfTest()
fips140.RecordApproved()
μ, err := computeMessageHash(pub.tr[:], msg, context)
if err != nil {
return err
}
return verifyInternal(pub, &μ, sig)
}
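// VerifyExternalMu is like Verify, but takes the externally computed 64-byte
// message representative μ instead of the message and context.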
func VerifyExternalMu(pub *PublicKey, μ []byte, sig []byte) error {
fipsSelfTest()
fips140.RecordApproved()
if len(μ) != 64 {
return errMessageHashLength
}
return verifyInternal(pub, (*[64]byte)(μ), sig)
}
func verifyInternal(pub *PublicKey, μ *[64]byte, sig []byte) error {
p, k, l := pub.p, pub.p.k, pub.p.l
t1, A := pub.t1[:k], pub.a[:k*l]
β := p.τ * p.η
γ1 := uint32(1 << p.γ1)
γ := γ1 - uint32(β)
z := make([]ringElement, l, maxL)
h := make([][n]byte, k, maxK)
ch, err := sigDecode(sig, z, h, p)
if err != nil {
return err
}
c := ntt(sampleInBall(ch, p))
// w = Â ∘ NTT(z) - NTT(c) ∘ NTT(t₁ ⋅ 2ᵈ)
zHat := make([]nttElement, l, maxL)
for i := range zHat {
zHat[i] = ntt(z[i])
}
w := make([]ringElement, k, maxK)
for i := range w {
var wHat nttElement
for j := range l {
wHat = polyAdd(wHat, nttMul(A[i*l+j], zHat[j]))
}
wHat = polySub(wHat, nttMul(c, t1[i]))
w[i] = inverseNTT(wHat)
}
// Use hints h to compute w₁ from w(approx).
w1 := make([][n]byte, k, maxK)
for i := range w {
w1[i] = useHint(w[i], h[i], p)
}
H := sha3.NewShake256()
H.Write(μ[:])
for i := range w {
w1Encode(H, w1[i], p)
}
computedCH := make([]byte, p.λ/4, maxλ/4)
H.Read(computedCH)
for i := range z {
if coefficientsExceedBound(z[i], γ) {
return errInvalidSignatureCoeffBounds
}
}
if !bytes.Equal(ch, computedCH) {
return errInvalidSignatureChallenge
}
return nil
}
func sigEncode(ch []byte, z []ringElement, h [][n]byte, p parameters) []byte {
sig := make([]byte, 0, sigSize(p))
sig = append(sig, ch...)
for i := range z {
sig = bitPack(sig, z[i], p)
}
sig = hintEncode(sig, h, p)
return sig
}
func sigDecode(sig []byte, z []ringElement, h [][n]byte, p parameters) (ch []byte, err error) {
if len(sig) != sigSize(p) {
return nil, errInvalidSignatureLength
}
ch, sig = sig[:p.λ/4], sig[p.λ/4:]
for i := range z {
length := (p.γ1 + 1) * n / 8
z[i] = bitUnpack(sig[:length], p)
sig = sig[length:]
}
if err := hintDecode(sig, h, p); err != nil {
return nil, err
}
return ch, nil
}
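// hintEncode implements HintBitPack from FIPS 204: the first ω bytes hold the
// indices of the nonzero hint coefficients (at most ω across all polynomials),
// and the final k bytes record the running index count after each polynomial.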
func hintEncode(buf []byte, h [][n]byte, p parameters) []byte {
ω, k := p.ω, p.k
out, y := sliceForAppend(buf, ω+k)
var idx byte
for i := range k {
for j := range n {
if h[i][j] != 0 {
y[idx] = byte(j)
idx++
}
}
y[ω+i] = idx
}
return out
}
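// hintDecode implements HintBitUnpack from FIPS 204, rejecting non-canonical
// encodings (counts out of range, out-of-order indices, or nonzero padding) so
// that every hint vector has a single valid encoding.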
func hintDecode(y []byte, h [][n]byte, p parameters) error {
ω, k := p.ω, p.k
if len(y) != ω+k {
return errors.New("mldsa: internal error: invalid signature hint length")
}
var idx byte
for i := range k {
limit := y[ω+i]
if limit < idx || limit > byte(ω) {
return errInvalidSignatureHintLimits
}
first := idx
for idx < limit {
if idx > first && y[idx-1] >= y[idx] {
return errInvalidSignatureHintIndexOrder
}
h[i][y[idx]] = 1
idx++
}
}
for i := idx; i < byte(ω); i++ {
if y[i] != 0 {
return errInvalidSignatureHintExtraIndices
}
}
return nil
}

View file

@ -0,0 +1,431 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mldsa
import (
"bytes"
"crypto/internal/fips140/sha3"
"crypto/sha256"
"encoding/hex"
"strings"
"testing"
)
// Most tests are in crypto/internal/fips140test/mldsa_test.go, so they can
// apply to all FIPS 140-3 module versions. This file contains only tests that
// need access to the unexported symbol testingOnlyRejectionReason.
func TestACVPRejectionKATs(t *testing.T) {
testCases := []struct {
name string
seed string // input to ML-DSA.KeyGen_internal
keyHash string // SHA2-256(pk || sk)
msg string // M' input to ML-DSA.Sign_internal
sigHash string // SHA2-256(sig)
newPrivateKey func([]byte) (*PrivateKey, error)
newPublicKey func([]byte) (*PublicKey, error)
}{
// https://pages.nist.gov/ACVP/draft-celi-acvp-ml-dsa.html#table-1
// ML-DSA Algorithm 7 ML-DSA.Sign_internal() Known Answer Tests for Rejection Cases
{
"Path/ML-DSA-44/1",
"5C624FCC1862452452D0C665840D8237F43108E5499EDCDC108FBC49D596E4B7",
"AC825C59D8A4C453A2C4EFEA8395741CA404F3000E28D56B25D03BB402E5CB2F",
"951FDF5473A4CBA6D9E5B5DB7E79FB8173921BA5B13E9271401B8F907B8B7D5B",
"DCC71A421BC6FFAFB7DF0C7F6D018A19ADA154D1E2EE360ED533CECD5DC980AD",
NewPrivateKey44, NewPublicKey44,
},
{
"Path/ML-DSA-44/2",
"836EABEDB4D2CD9BE6A4D957CF5EE6BF489304136864C55C2C5F01DA5047D18B",
"E1FF40D96E3552FAB531D1715084B7E38CCDBACC0A8AF94C30959FB4C7F5A445",
"199A0AB735E9004163DD02D319A61CFE81638E3BF47BB1E90E90D6E3EA545247",
"A2608BC27E60541D27B6A14F460D54A48C0298DCC3F45999F29047A3135C4941",
NewPrivateKey44, NewPublicKey44,
},
{
"Path/ML-DSA-44/3",
"CA5A01E1EA6552CB5C9803462B94C2F1DC9D13BB17A6ACE510D157056A2C6114",
"A4652DC4A271095268DD84A5B0744DFDBE2E642E4D41FBC4329C2FBA534C0E13",
"8C8CACA88FFF52B9330510537B3701B3993F3726136A650F48F8604551550832",
"B4B142209137397DAD504CAED01D390ADAF49973D8D2414FC3457FB7AF775189",
NewPrivateKey44, NewPublicKey44,
},
{
"Path/ML-DSA-44/4",
"9C005F1550B4F31855C6B92F978736733F37791CB39DD182D7BA5732BDC2483E",
"2485AA99345F1B334D4D94B610FBFFCCB626CBFD4E9FF0E1F6FC35093C423544",
"B744343F30F7FEE088998BA574E799F1BF3939C06C29BF9AC10F3588A57E21E2",
"5B80A60BAA480B9D0C7D2C05B50928C4BF6808DDA693642058A3EB77EAA768FC",
NewPrivateKey44, NewPublicKey44,
},
{
"Path/ML-DSA-44/5",
"4FAB5485B009399E8AE6FC3D3EEFBFE8E09796E4477AABD5EB1CC908FA734DE3",
"CB56909A7CF3008A662DC635EDCB79DC151CA7ACBAE17B544384ABD91BBBC1E9",
"7CAB0FDCF4BEA5F039137478AA45C9C48EF96D906FC49F6E2F138111BF1B4A4E",
"6CC38D73D639682ABC556DC6DCF436DE24033091F34004F410FABC6887F77AB0",
NewPrivateKey44, NewPublicKey44,
},
{
"Path/ML-DSA-65/1",
"464756A985E5DF03739D95DD309C1ED9C5B04254CC294E7E7EB9B9365EE15117",
"AE95EA0DAA80199E7B4A74EB5A1B1DC6C3805BD01D2FA78D7C4FBA8C255AA13D",
"491101BBA044DE6E44A63796C33CDA051BB05A60725B87AF4BA9DB940C03AC09",
"8E08EA0C8DB941685B9905A73B0B57BAD3500B1F73490480B24375B41230CC04",
NewPrivateKey65, NewPublicKey65,
},
{
"Path/ML-DSA-65/2",
"235A48DB4CA7916B884F424A8586EFD517E87C64AECEC0FCE9A3CC212BA1522E",
"1AC58A909DB4D7BC2473AB5E24AF768279C76F86A82D448258E24EEA4EA6B713",
"F8CE85CB2EC474FFBF5A3FFAE029CE6F4526B8D597655067F97F438B81071E9B",
"AE9531A01738615B6D33C77B3FF618A86E101FDC4C8504681F0EDFA64511AD63",
NewPrivateKey65, NewPublicKey65,
},
{
"Path/ML-DSA-65/3",
"E13131B705A760305FEFFEBFE99082E2691A444BBEFCC3EDF67D909886200207",
"B422093F95CC489C52F4FA2B8973A2FDDD44426D1D04D1AAEEFC8715D417181F",
"CD365512C7E61BBAA130800B37F3BB46AAF1BEEF3742EA8A9010A6DD4576ED0B",
"3C55E604DECA7B89A99305D7A391C35F66A17C1923F467675EC951C0948D21C9",
NewPrivateKey65, NewPublicKey65,
},
{
"Path/ML-DSA-65/4",
"0A4793E040A4BC0D0F37643D12C1EA1F10648724609936C76E0EC83E37209E92",
"622D26D536D4D66CD94956B33A74E2E830ED265D25C34FF7C3E5243403146ADF",
"6D9C7A795E48D80A892CBF4D4558429787277E3806EB5D0BCE1640EEBBBF9AEC",
"3B141110B9F56540B2D49AACDE6399974A4EAC40621E367E68D4504F294DB21B",
NewPrivateKey65, NewPublicKey65,
},
{
"Path/ML-DSA-65/5",
"F865B889E5022D54BABC81CA67E7EB39F1AC42F92CF5295C3DA5C9667DB1B924",
"45BC8EDD1A620C46E973E346844270721824D97888BC174281852D98B7E8F4A3",
"047AFAADBE020ED2D766DA85317DEDE80BE550545F0B21E3F555A990F8004258",
"56308A3578360C41356BA9C97D3240E01767FA76BBBA9FD0CC6CFA9ADD088DB9",
NewPrivateKey65, NewPublicKey65,
},
{
"Path/ML-DSA-87/1",
"0D58219132746BE077DFE821E9F8FD87857B28AB91D6A567E312A73E2636032C",
"4D261270341A7AC6B66900DDC2B8AB34AB483C897410DDF3B2C072BDDA416434",
"3AA49EF72D010AEC19383BA1E83EC2DD3DCC207A96FFCEB9FFA269E3E3D66400",
"5049DC39045618B903C71595B3A3E07A731F95D37304623ACC98BCEF4258B4CA",
NewPrivateKey87, NewPublicKey87,
},
{
"Path/ML-DSA-87/2",
"146C47AB9F88408EB76A813294D533B29D7E0FDA75DA5A4E7C69EB61EFEEBB78",
"05194438AF855B79DB8CCCCB647D6BA5C7AAF901BBD09D3B29395F0EA431D164",
"82C44F998A8D24F056084D0E80ECFD8434493385A284C69974923C270D397782",
"CFFC5988A351E14A3EE1282F042A143679C4503814296B27993949A7FF966F57",
NewPrivateKey87, NewPublicKey87,
},
{
"Path/ML-DSA-87/3",
"049D9B0B646A2AC7F50B63CE5E4BFE44C9B87634F4FF6C14C513E388B8A1F808",
"AC8FE6B2FE26591B129EA536A9A001C785D8ACBDD9489F6E51469A156E9E635D",
"FEBC9F8AE159002BE1A11D395959DD7FC20718135690CDAA2BCFB5801C02AB89",
"FF4006089BDF7337E868F86DDF48F239D2A52EA1D0F686E0103BF19C3B571DB1",
NewPrivateKey87, NewPublicKey87,
},
{
"Path/ML-DSA-87/4",
"9823DDDE446A8EA883DAD3AC6477F79839FDC2D2DEF2416BE0A8B71CFBC3F5C6",
"525010E307C4EA7667D54EE27007C219B01F4CF88DC3AB2DE8E9AAA59440A884",
"F7592C97C1A96A2F4053588F5CDAD4C50BF7C3752709854FA27779B445DD2BA2",
"FD7757602B83B0A67A314CD5BCC880E7AE47ACDF4D6AF98269028EFB486838F7",
NewPrivateKey87, NewPublicKey87,
},
{
"Path/ML-DSA-87/5",
"AE213FE8589B414F53780D8B9B6837179967E13CB474C5AD365C043778D2BC90",
"D4988E91064E5DF6D867434D1DED16DCD8533E39E420DC2B4EB9E40A84146F7D",
"19C1913BA76FF04596BB7CC80FD825A5AEDEF5D5AD61CEDB5203E6D7EDB18877",
"23FE743EDD101970D499E7EB57A7AA245BAF417E851B260C55DD525A445F08DA",
NewPrivateKey87, NewPublicKey87,
},
// https://pages.nist.gov/ACVP/draft-celi-acvp-ml-dsa.html#table-2
// ML-DSA Algorithm 7 ML-DSA.Sign_internal() Known Answer Tests for Number of Rejection Cases
{
"Count/ML-DSA-44/77",
"090D97C1F4166EB32CA67C5FB564ACBE0735DB4AF4B8DB3A7C2CE7402357CA44",
"26D79E4068040E996BC9EB5034C20489C0AD38DC2FEC1918D0760C8621872408",
"E3838364B37F47EDFCA2B577B20B80C3CB51B9F56E0E4CDB7DF002C874039252",
"CD91150C610FF02DE1DD7049C309EFE800CE5C1BC2E5A32D752AB62C5BF5E16F",
NewPrivateKey44, NewPublicKey44,
},
{
"Count/ML-DSA-44/100",
"CFC73D07A883543A804F770070861825143A62F2F97D05FCE00FD8B25D29A43F",
"89142AB26D6EB6C01FA3F189A9C877597740D685983F29BBDD3596648266AE0E",
"0960C13E9BA467A938450120CC96FF6F04B7E557C99A838619A48F9A38738AB8",
"B6296FFF0C1F23DE4906D58144B00A2DB13AD25E49B4B8573A62EFEECB544DD7",
NewPrivateKey44, NewPublicKey44,
},
{
"Count/ML-DSA-65/64",
"26B605C78AC762FA1634C6F91DD117C4FBFF7F3A7E7781F0CC83B6281F04AD7F",
"5DA13E571DF80867A8F27E0FF81BE7252A1ABF89B3D6A03D4036AF643EFBB04B",
"C9B07E7DDC0274468F312F5C692A54AC73D1E34D8638E20A2CD3C788F27D4355",
"12A4637E3A833A5A2A46F6A991399E544B62A230B7AA82F7366840FF6A88DE61",
NewPrivateKey65, NewPublicKey65,
},
{
"Count/ML-DSA-65/73",
"9191CF381BEE17475C011986EFB6AFB1EFA6997442FD33427353F1DA1AA39FC0",
"7930D4E52BA03B61DAA57743B39E291D824DC156356C6B1A8232574D5C8BDD08",
"E616E36E81AA1EC39262109421AE0DDDA5E3B5A8F4A252BCA27AE882538DF618",
"3D758ACE312433D780403B3D4273171FB93D008B395352142C6DC5173E517310",
NewPrivateKey65, NewPublicKey65,
},
{
"Count/ML-DSA-65/66",
"516912C7B90A3DBE009B7478DBCAF0F5C5C9ED9699A20D0CA56CC516E5A444CD",
"0FD15951B93A4D19446B48D47D32D2CA2253FF43BB8CCCB34C07E5F1A3181B7A",
"9247CA75F9456226A0C783DABCC33FF5B4B489575ADED543E74B29B45F9C8EF2",
"E5CE267800EDF33588451050F9B4A5BF97030D045132A7E3ED9210E74028D23B",
NewPrivateKey65, NewPublicKey65,
},
{
"Count/ML-DSA-65/65",
"D4B841F882D50AB9E590066BAFABA0F0D04D32641C0B978E54CCAA69A6E8D2C4",
"0039C128DDE6923EA08FF14F5C5C66DCB282B471FD1917DBEBE07C8C45B73F8A",
"175231657B0F3C7065947999467C342064F29BFAEB553E97561407D5560E3AEB",
"8830EA254AF2854BF67C2B907E2321C94FD6EFB2FDAA77669FC3A5C4426C57C9",
NewPrivateKey65, NewPublicKey65,
},
{
"Count/ML-DSA-65/64",
"5492EB8D811072C030A30CC66B23A173059EBA0D4868CCB92FBE2510B4A5915F",
"573DCD99C86DAE81F6F80CB00AF40846028EA8F9FE63102FE4A78238BC7B660E",
"33D2753ED87D0003B44C1AF5F72EB931F559C6B4931AF7E249F65D3FA7613295",
"84D4AF50933D6E13D4332B86AF0692A66F5030AB01C2EAC4131A5EEBF78CE9E5",
NewPrivateKey65, NewPublicKey65,
},
{
"Count/ML-DSA-87/64",
"B5C07ECEFE9E7C3B885FDEF032BDF9F807B4011E2DFE6806C088D2081631C8EB",
"5D22F4C40F6EEB96BB891DB15884ED4B0009EA02A24D9D1E9ADFC81C7A42EA7F",
"D1D5C2D167D6E62906790A5FEDF5A0A754CFAF47E6A11AEB93FB8C41934C31F8",
"54F0A9CB26F98B394A35918ECA6760EBD10753FC5CDBA8BE508873AD83538131",
NewPrivateKey87, NewPublicKey87,
},
{
"Count/ML-DSA-87/65",
"E8FC3C9FAD711DDA2946334FBBD331468D6E9AB48EB86DCD03F300A17AEBC5E5",
"B6C4DC9B20CE5D0F445931EE316CF0676E806D1A6A98868881D060EA27CEB139",
"3B435F7A2CE431C7AB8EAE0991C5DAC610827C99D27803046FBC6C567D6B71F2",
"E337495F08773F14FB26A3E229B9B26D086644C7FDC300267F9DCDD5D78DB849",
NewPrivateKey87, NewPublicKey87,
},
{
"Count/ML-DSA-87/64",
"151F80886D6CE8C3B428964FE02C40CA0C8EFFA100EE089E54D785344FCCF719",
"127972C33323FEFBF6B69C19E0C86F41558D9AB2B1A8AD6F39BD0A0245DC8D7E",
"C628CE94D2AA99AA50CF15B147D4F9A9C62A3D4612152DE0A502C377F472D614",
"99B552B21432544248BFF47AC8F24CB78DBB25C9683F3ADCB75614BED58A0358",
NewPrivateKey87, NewPublicKey87,
},
{
"Count/ML-DSA-87/64",
"48BEFFB4C97E59E474E1906F39888BE5AE62F6A011C05EF6A6B8D1E54F2171B7",
"72DA77CF563CBB530129F60129AF989CA4036BA1058267BFBA34A2C70BE803C4",
"D2756A8FB4E47F796AF704ED0FC8C6E573D42DFAB443B329F00F8DB2FF12C465",
"E643914B8556D05360C65EB3E7A06BE7C398B82D49973EEFDC711E65B11EB5E8",
NewPrivateKey87, NewPublicKey87,
},
{
"Count/ML-DSA-87/69",
"FE2DA9DD93A077FCB6452AC88D0A5762EB896BAAAC6CE7D01CB1370BA8322390",
"7422DBE3F476FFE41A4EFB33F3DDFD8B328029BA3050603866C36CFBC2EE4B87",
"A86B29ADF2300D2636E21D4A350CD18E55A254379C3659A7A95D8734CEC1F005",
"8D25818DD972FFF5B9E9B4CC534A95100A1340C1C81D1486A68939D340E0A58B",
NewPrivateKey87, NewPublicKey87,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
seed := fromHex(tc.seed)
priv, err := tc.newPrivateKey(seed)
if err != nil {
t.Fatalf("NewPrivateKey: %v", err)
}
if strings.Contains(t.Name(), "/Path/") {
// For path coverage tests, check that we hit all rejection paths.
reached := map[string]bool{"z": false, "r0": false, "ct0": false, "h": false}
// The ct0 rejection is only reachable for ML-DSA-44.
if priv.PublicKey().Parameters() != "ML-DSA-44" {
delete(reached, "ct0")
}
testingOnlyRejectionReason = func(reason string) {
t.Log(reason, "rejection")
reached[reason] = true
}
t.Cleanup(func() {
testingOnlyRejectionReason = nil
})
defer func() {
for reason, hit := range reached {
if !hit {
t.Errorf("Rejection path %q not hit", reason)
}
}
}()
}
pk := priv.PublicKey().Bytes()
sk := TestingOnlyPrivateKeySemiExpandedBytes(priv)
keyHashGot := sha256.Sum256(append(pk, sk...))
keyHashWant := fromHex(tc.keyHash)
if !bytes.Equal(keyHashGot[:], keyHashWant) {
t.Errorf("Key hash mismatch:\n got: %X\n want: %X", keyHashGot, keyHashWant)
}
pub, err := tc.newPublicKey(pk)
if err != nil {
t.Fatalf("NewPublicKey: %v", err)
}
if !pub.Equal(priv.PublicKey()) {
t.Errorf("Parsed public key not equal to original")
}
if *pub != *priv.PublicKey() {
t.Errorf("Parsed public key not identical to original")
}
// The table provides a Sign_internal input (not actually formatted
// like one), which is part of the pre-image of μ.
M := fromHex(tc.msg)
H := sha3.NewShake256()
tr := computePublicKeyHash(pk)
H.Write(tr[:])
H.Write(M)
μ := make([]byte, 64)
H.Read(μ)
t.Logf("Computed μ: %x", μ)
sig, err := SignExternalMuDeterministic(priv, μ)
if err != nil {
t.Fatalf("SignExternalMuDeterministic: %v", err)
}
sigHashGot := sha256.Sum256(sig)
sigHashWant := fromHex(tc.sigHash)
if !bytes.Equal(sigHashGot[:], sigHashWant) {
t.Errorf("Signature hash mismatch:\n got: %X\n want: %X", sigHashGot, sigHashWant)
}
if err := VerifyExternalMu(priv.PublicKey(), μ, sig); err != nil {
t.Errorf("Verify: %v", err)
}
wrong := make([]byte, len(μ))
if err := VerifyExternalMu(priv.PublicKey(), wrong, sig); err == nil {
t.Errorf("Verify passed on wrong message")
}
})
}
}
func TestCASTRejectionPaths(t *testing.T) {
reached := map[string]bool{"z": false, "r0": false, "ct0": false, "h": false}
testingOnlyRejectionReason = func(reason string) {
t.Log(reason, "rejection")
reached[reason] = true
}
t.Cleanup(func() {
testingOnlyRejectionReason = nil
})
fips140CAST()
for reason, hit := range reached {
if !hit {
t.Errorf("Rejection path %q not hit", reason)
}
}
}
func BenchmarkCAST(b *testing.B) {
// IG 10.3.A says "ML-DSA digital signature generation CASTs should cover
// all applicable rejection sampling loop paths". For ML-DSA-44, there are
// four paths. For ML-DSA-65 and ML-DSA-87, only three. This benchmark helps
// us figure out which is faster: four rejections of ML-DSA-44, or three of
// ML-DSA-65. (It's the former, but only barely.)
b.Run("ML-DSA-44", func(b *testing.B) {
// Same as TestACVPRejectionKATs/Path/ML-DSA-44/1.
seed := fromHex("5C624FCC1862452452D0C665840D8237F43108E5499EDCDC108FBC49D596E4B7")
μ := fromHex("2ad1c72bb0fcbe28099ce8bd2ed836dfebe520aad38fbac66ef785a3cfb10fb4" +
"19327fa57818ee4e3718da4be48d24b59a208f8807271fdb7eda6e60141bd263")
skHash := fromHex("29374951cb2bc3cda7315ce7f0ab99c7d2d65292e6c5156e8aa62ac14b1412af")
sigHash := fromHex("dcc71a421bc6ffafb7df0c7f6d018a19ada154d1e2ee360ed533cecd5dc980ad")
for b.Loop() {
priv, err := NewPrivateKey44(seed)
if err != nil {
b.Fatalf("NewPrivateKey: %v", err)
}
sk := TestingOnlyPrivateKeySemiExpandedBytes(priv)
if sha256.Sum256(sk) != ([32]byte)(skHash) {
b.Fatalf("sk hash mismatch, got %x", sha256.Sum256(sk))
}
sig, err := SignExternalMuDeterministic(priv, μ)
if err != nil {
b.Fatalf("SignExternalMuDeterministic: %v", err)
}
if sha256.Sum256(sig) != ([32]byte)(sigHash) {
b.Fatalf("sig hash mismatch, got %x", sha256.Sum256(sig))
}
if err := VerifyExternalMu(priv.PublicKey(), μ, sig); err != nil {
b.Fatalf("Verify: %v", err)
}
}
})
b.Run("ML-DSA-65", func(b *testing.B) {
// Same as TestACVPRejectionKATs/Path/ML-DSA-65/4, which is the only one
// actually covering all three rejection paths, despite IG 10.3.A
// pointing explicitly at these vectors for this check. See
// https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/6U34L4ISYzk/m/hel75x07AQAJ
seed := fromHex("F215BA2280D86F142012FC05FFC04F2C7D22FF5DD7D69AA0EFB081E3A53E9318")
μ := fromHex("35cdb7dddbed44af4641bac659f46598ed769ea9693fd4ed2152b84c45811d2e" +
"66eded1eb20cde1c1f4b82642a330d8e86ac432a2aefaa56cd9b2b5f4affd450")
skHash := fromHex("2e6f5ff659310b8ca1457a65d8b448b297a905dc08e06c1246a97daad0af6f7d")
sigHash := fromHex("c027d21b21fa75abe7f35cd84a54e2e83bd352140bc8c49eab2c45004e7268a7")
for b.Loop() {
priv, err := NewPrivateKey65(seed)
if err != nil {
b.Fatalf("NewPrivateKey: %v", err)
}
sk := TestingOnlyPrivateKeySemiExpandedBytes(priv)
if sha256.Sum256(sk) != ([32]byte)(skHash) {
b.Fatalf("sk hash mismatch, got %x", sha256.Sum256(sk))
}
sig, err := SignExternalMuDeterministic(priv, μ)
if err != nil {
b.Fatalf("SignExternalMuDeterministic: %v", err)
}
if sha256.Sum256(sig) != ([32]byte)(sigHash) {
b.Fatalf("sig hash mismatch, got %x", sha256.Sum256(sig))
}
if err := VerifyExternalMu(priv.PublicKey(), μ, sig); err != nil {
b.Fatalf("Verify: %v", err)
}
}
})
}
func fromHex(s string) []byte {
b, err := hex.DecodeString(s)
if err != nil {
panic(err)
}
return b
}

View file

@ -0,0 +1,244 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mldsa
import (
"crypto/internal/fips140/drbg"
"errors"
"math/bits"
)
// FIPS 204 defines a needless semi-expanded format for private keys. This is
// not a good format for key storage and exchange, because it is large and
// requires careful parsing to reject malformed keys. Seeds instead are just 32
// bytes, are always valid, and always expand to valid keys in memory. It is
// *also* a poor in-memory format, because it defers computing the NTT of s1,
// s2, and t0 and the expansion of A until signing time, which is inefficient.
// For a hot second, it looked like we could have all agreed to only use seeds,
// but unfortunately OpenSSL and BouncyCastle lobbied hard against that during
// the WGLC of the LAMPS IETF working group. Also, ACVP tests provide and expect
// semi-expanded keys, so we implement them here for testing purposes.
func semiExpandedPrivKeySize(p parameters) int {
k, l := p.k, p.l
ηBitlen := bits.Len(uint(p.η)) + 1
// ρ + K + tr + l × n × η-bit coefficients of s₁ +
// k × n × η-bit coefficients of s₂ + k × n × 13-bit coefficients of t₀
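// For ML-DSA-44 (k = 4, l = 4, η = 2, so ηBitlen = 3), for example, this is
// 32 + 32 + 64 + 384 + 384 + 1664 = 2560 bytes, the FIPS 204 private key size.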
return 32 + 32 + 64 + l*n*ηBitlen/8 + k*n*ηBitlen/8 + k*n*13/8
}
// TestingOnlyNewPrivateKeyFromSemiExpanded creates a PrivateKey from a
// semi-expanded private key encoding, for testing purposes. It rejects
// inconsistent keys.
//
// [PrivateKey.Bytes] must NOT be called on the resulting key, as it will
// produce a random value.
func TestingOnlyNewPrivateKeyFromSemiExpanded(sk []byte) (*PrivateKey, error) {
var p parameters
switch len(sk) {
case semiExpandedPrivKeySize(params44):
p = params44
case semiExpandedPrivKeySize(params65):
p = params65
case semiExpandedPrivKeySize(params87):
p = params87
default:
return nil, errors.New("mldsa: invalid semi-expanded private key size")
}
k, l := p.k, p.l
ρ, K, tr, s1, s2, t0, err := skDecode(sk, p)
if err != nil {
return nil, err
}
priv := &PrivateKey{pub: PublicKey{p: p}}
priv.k = K
priv.pub.tr = tr
A := priv.pub.a[:k*l]
computeMatrixA(A, ρ[:], p)
for r := range l {
priv.s1[r] = ntt(s1[r])
}
for r := range k {
priv.s2[r] = ntt(s2[r])
}
for r := range k {
priv.t0[r] = ntt(t0[r])
}
// We need to put something in priv.seed, and putting random bytes feels
// safer than putting anything predictable.
drbg.Read(priv.seed[:])
// Making this format *even more* annoying, we need to recompute t1 from ρ,
// s1, and s2 if we want to generate the public key. This is essentially as
// much work as regenerating everything from seed.
//
// You might also notice that the semi-expanded format also stores t0 and a
// hash of the public key, though. How are we supposed to check they are
// consistent without regenerating the public key? Do we even need to check?
// Who knows! FIPS 204 says
//
// > Note that there exist malformed inputs that can cause skDecode to
// > return values that are not in the correct range. Hence, skDecode
// > should only be run on inputs that come from trusted sources.
//
// so it sounds like it doesn't even want us to check the coefficients are
// within bounds, but especially if using this format for key exchange, that
// sounds like a bad idea. So we check everything.
t1 := make([][n]uint16, k, maxK)
for i := range k {
tHat := priv.s2[i]
for j := range l {
tHat = polyAdd(tHat, nttMul(A[i*l+j], priv.s1[j]))
}
t := inverseNTT(tHat)
for j := range n {
r1, r0 := power2Round(t[j])
t1[i][j] = r1
if r0 != t0[i][j] {
return nil, errors.New("mldsa: semi-expanded private key inconsistent with t0")
}
}
}
pk := pkEncode(priv.pub.raw[:0], ρ[:], t1, p)
if computePublicKeyHash(pk) != tr {
return nil, errors.New("mldsa: semi-expanded private key inconsistent with public key hash")
}
computeT1Hat(priv.pub.t1[:k], t1) // NTT(t₁ ⋅ 2ᵈ)
return priv, nil
}
func TestingOnlyPrivateKeySemiExpandedBytes(priv *PrivateKey) []byte {
k, l, η := priv.pub.p.k, priv.pub.p.l, priv.pub.p.η
sk := make([]byte, 0, semiExpandedPrivKeySize(priv.pub.p))
sk = append(sk, priv.pub.raw[:32]...) // ρ
sk = append(sk, priv.k[:]...) // K
sk = append(sk, priv.pub.tr[:]...) // tr
for i := range l {
sk = bitPackSlow(sk, inverseNTT(priv.s1[i]), η, η)
}
for i := range k {
sk = bitPackSlow(sk, inverseNTT(priv.s2[i]), η, η)
}
const bound = 1 << (13 - 1) // 2^(d-1)
for i := range k {
sk = bitPackSlow(sk, inverseNTT(priv.t0[i]), bound-1, bound)
}
return sk
}
func skDecode(sk []byte, p parameters) (ρ, K [32]byte, tr [64]byte, s1, s2, t0 []ringElement, err error) {
k, l, η := p.k, p.l, p.η
if len(sk) != semiExpandedPrivKeySize(p) {
err = errors.New("mldsa: invalid semi-expanded private key size")
return
}
copy(ρ[:], sk[:32])
sk = sk[32:]
copy(K[:], sk[:32])
sk = sk[32:]
copy(tr[:], sk[:64])
sk = sk[64:]
s1 = make([]ringElement, l)
for i := range l {
length := n * bits.Len(uint(η)*2) / 8
s1[i], err = bitUnpackSlow(sk[:length], η, η)
if err != nil {
return
}
sk = sk[length:]
}
s2 = make([]ringElement, k)
for i := range k {
length := n * bits.Len(uint(η)*2) / 8
s2[i], err = bitUnpackSlow(sk[:length], η, η)
if err != nil {
return
}
sk = sk[length:]
}
const bound = 1 << (13 - 1) // 2^(d-1)
t0 = make([]ringElement, k)
for i := range k {
length := n * 13 / 8
t0[i], err = bitUnpackSlow(sk[:length], bound-1, bound)
if err != nil {
return
}
sk = sk[length:]
}
return
}
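// bitPackSlow implements BitPack from FIPS 204: each coefficient of r, with
// centered representative w in [-a, b], is stored as b - w in little-endian
// fields of bitlen(a+b) bits.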
func bitPackSlow(buf []byte, r ringElement, a, b int) []byte {
bitlen := bits.Len(uint(a + b))
if bitlen <= 0 || bitlen > 16 {
panic("mldsa: internal error: invalid bitlen")
}
out, v := sliceForAppend(buf, n*bitlen/8)
var acc uint32
var accBits uint
for i := range r {
w := int32(b) - fieldCenteredMod(r[i])
acc |= uint32(w) << accBits
accBits += uint(bitlen)
for accBits >= 8 {
v[0] = byte(acc)
v = v[1:]
acc >>= 8
accBits -= 8
}
}
if accBits > 0 {
v[0] = byte(acc)
}
return out
}
func bitUnpackSlow(v []byte, a, b int) (ringElement, error) {
bitlen := bits.Len(uint(a + b))
if bitlen <= 0 || bitlen > 16 {
panic("mldsa: internal error: invalid bitlen")
}
if len(v) != n*bitlen/8 {
return ringElement{}, errors.New("mldsa: invalid input length for bitUnpackSlow")
}
mask := uint32((1 << bitlen) - 1)
maxValue := uint32(a + b)
var r ringElement
var acc uint32
var accBits uint
vIdx := 0
for i := range r {
for accBits < uint(bitlen) {
if vIdx < len(v) {
acc |= uint32(v[vIdx]) << accBits
vIdx++
accBits += 8
}
}
w := acc & mask
if w > maxValue {
return ringElement{}, errors.New("mldsa: coefficient out of range")
}
r[i] = fieldSubToMontgomery(uint32(b), w)
acc >>= bitlen
accBits -= uint(bitlen)
}
return r, nil
}

View file

@ -4,9 +4,7 @@
package byteorder
-import (
-"internal/byteorder"
-)
+import "internal/byteorder"
func LEUint16(b []byte) uint16 {
return byteorder.LEUint16(b)
@ -36,6 +34,10 @@ func BEPutUint64(b []byte, v uint64) {
byteorder.BEPutUint64(b, v)
}
+func LEPutUint16(b []byte, v uint16) {
+byteorder.LEPutUint16(b, v)
+}
func LEPutUint64(b []byte, v uint64) {
byteorder.LEPutUint64(b, v)
}

View file

@ -16,17 +16,17 @@
{"algorithm":"cSHAKE-128","hexCustomization":false,"outputLen":[{"min":16,"max":65536,"increment":8}],"msgLen":[{"min":0,"max":65536,"increment":8}],"revision":"1.0"}, {"algorithm":"cSHAKE-128","hexCustomization":false,"outputLen":[{"min":16,"max":65536,"increment":8}],"msgLen":[{"min":0,"max":65536,"increment":8}],"revision":"1.0"},
{"algorithm":"cSHAKE-256","hexCustomization":false,"outputLen":[{"min":16,"max":65536,"increment":8}],"msgLen":[{"min":0,"max":65536,"increment":8}],"revision":"1.0"}, {"algorithm":"cSHAKE-256","hexCustomization":false,"outputLen":[{"min":16,"max":65536,"increment":8}],"msgLen":[{"min":0,"max":65536,"increment":8}],"revision":"1.0"},
{"algorithm":"HMAC-SHA2-224","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[{"increment":8,"max":224,"min":32}],"revision":"1.0"}, {"algorithm":"HMAC-SHA2-224","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[224],"revision":"1.0"},
{"algorithm":"HMAC-SHA2-256","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[{"increment":8,"max":256,"min":32}],"revision":"1.0"}, {"algorithm":"HMAC-SHA2-256","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[256],"revision":"1.0"},
{"algorithm":"HMAC-SHA2-384","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[{"increment":8,"max":384,"min":32}],"revision":"1.0"}, {"algorithm":"HMAC-SHA2-384","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[384],"revision":"1.0"},
{"algorithm":"HMAC-SHA2-512","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[{"increment":8,"max":512,"min":32}],"revision":"1.0"}, {"algorithm":"HMAC-SHA2-512","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[512],"revision":"1.0"},
{"algorithm":"HMAC-SHA2-512/224","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[{"increment":8,"max":224,"min":32}],"revision":"1.0"}, {"algorithm":"HMAC-SHA2-512/224","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[224],"revision":"1.0"},
{"algorithm":"HMAC-SHA2-512/256","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[{"increment":8,"max":256,"min":32}],"revision":"1.0"}, {"algorithm":"HMAC-SHA2-512/256","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[256],"revision":"1.0"},
{"algorithm":"HMAC-SHA3-224","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[{"increment":8,"max":224,"min":32}],"revision":"1.0"}, {"algorithm":"HMAC-SHA3-224","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[224],"revision":"1.0"},
{"algorithm":"HMAC-SHA3-256","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[{"increment":8,"max":256,"min":32}],"revision":"1.0"}, {"algorithm":"HMAC-SHA3-256","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[256],"revision":"1.0"},
{"algorithm":"HMAC-SHA3-384","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[{"increment":8,"max":384,"min":32}],"revision":"1.0"}, {"algorithm":"HMAC-SHA3-384","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[384],"revision":"1.0"},
{"algorithm":"HMAC-SHA3-512","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[{"increment":8,"max":512,"min":32}],"revision":"1.0"}, {"algorithm":"HMAC-SHA3-512","keyLen":[{"increment":8,"max":524288,"min":8}],"macLen":[512],"revision":"1.0"},
{"algorithm":"KDA","mode":"HKDF","revision":"Sp800-56Cr1","fixedInfoPattern":"uPartyInfo||vPartyInfo","encoding":["concatenation"],"hmacAlg":["SHA2-224","SHA2-256","SHA2-384","SHA2-512","SHA2-512/224","SHA2-512/256","SHA3-224","SHA3-256","SHA3-384","SHA3-512"],"macSaltMethods":["default","random"],"l":2048,"z":[{"min":224,"max":65336,"increment":8}]}, {"algorithm":"KDA","mode":"HKDF","revision":"Sp800-56Cr1","fixedInfoPattern":"uPartyInfo||vPartyInfo","encoding":["concatenation"],"hmacAlg":["SHA2-224","SHA2-256","SHA2-384","SHA2-512","SHA2-512/224","SHA2-512/256","SHA3-224","SHA3-256","SHA3-384","SHA3-512"],"macSaltMethods":["default","random"],"l":2048,"z":[{"min":224,"max":65336,"increment":8}]},
{"algorithm":"KDA","mode":"OneStepNoCounter","revision":"Sp800-56Cr2","auxFunctions":[{"auxFunctionName":"HMAC-SHA2-224","l":224,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA2-256","l":256,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA2-384","l":384,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA2-512","l":512,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA2-512/224","l":224,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA2-512/256","l":256,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA3-224","l":224,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA3-256","l":256,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA3-384","l":384,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA3-512","l":512,"macSaltMethods":["default","random"]}],"fixedInfoPattern":"uPartyInfo||vPartyInfo","encoding":["concatenation"],"z":[{"min":224,"max":65336,"increment":8}]}, {"algorithm":"KDA","mode":"OneStepNoCounter","revision":"Sp800-56Cr2","auxFunctions":[{"auxFunctionName":"HMAC-SHA2-224","l":224,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA2-256","l":256,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA2-384","l":384,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA2-512","l":512,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA2-512/224","l":224,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA2-512/256","l":256,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA3-224","l":224,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA3-256","l":256,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA3-384","l":384,"macSaltMethods":["default","random"]},{"auxFunctionName":"HMAC-SHA3-512","l":512,"macSaltMethods":["default","random"]}],"fixedInfoPattern":"uPartyInfo||vPartyInfo","encoding":["concatenation"],"z":[{"min":224,"max":65336,"increment":8}]},
@ -64,7 +64,7 @@
{"algorithm":"ACVP-AES-CTR","direction":["encrypt","decrypt"],"keyLen":[128,192,256],"payloadLen":[{"min":8,"max":128,"increment":8}],"incrementalCounter":true,"overflowCounter":true,"performCounterTests":true,"revision":"1.0"}, {"algorithm":"ACVP-AES-CTR","direction":["encrypt","decrypt"],"keyLen":[128,192,256],"payloadLen":[{"min":8,"max":128,"increment":8}],"incrementalCounter":true,"overflowCounter":true,"performCounterTests":true,"revision":"1.0"},
{"algorithm":"ACVP-AES-GCM","direction":["encrypt","decrypt"],"keyLen":[128,192,256],"payloadLen":[{"min":0,"max":65536,"increment":8}],"aadLen":[{"min":0,"max":65536,"increment":8}],"tagLen":[96,104,112,120,128],"ivLen":[96],"ivGen":"external","revision":"1.0"}, {"algorithm":"ACVP-AES-GCM","direction":["encrypt","decrypt"],"keyLen":[128,192,256],"payloadLen":[{"min":0,"max":65536,"increment":8}],"aadLen":[{"min":0,"max":65536,"increment":8}],"tagLen":[96,104,112,120,128],"ivLen":[96],"ivGen":"external","revision":"1.0"},
{"algorithm":"ACVP-AES-GCM","direction":["encrypt","decrypt"],"keyLen":[128,192,256],"payloadLen":[{"min":0,"max":65536,"increment":8}],"aadLen":[{"min":0,"max":65536,"increment":8}],"tagLen":[128],"ivLen":[96],"ivGen":"internal","ivGenMode":"8.2.2","revision":"1.0"}, {"algorithm":"ACVP-AES-GCM","direction":["encrypt","decrypt"],"keyLen":[128,192,256],"payloadLen":[{"min":0,"max":65536,"increment":8}],"aadLen":[{"min":0,"max":65536,"increment":8}],"tagLen":[128],"ivLen":[96],"ivGen":"internal","ivGenMode":"8.2.2","revision":"1.0"},
{"algorithm":"CMAC-AES","capabilities":[{"direction":["gen","ver"],"msgLen":[{"min":0,"max":524288,"increment":8}],"keyLen":[128,256],"macLen":[{"min":8,"max":128,"increment":8}]}],"revision":"1.0"}, {"algorithm":"CMAC-AES","capabilities":[{"direction":["gen","ver"],"msgLen":[{"min":0,"max":524288,"increment":8}],"keyLen":[128,256],"macLen":[128]}],"revision":"1.0"},
{"algorithm":"TLS-v1.2","mode":"KDF","revision":"RFC7627","hashAlg":["SHA2-256","SHA2-384","SHA2-512"]}, {"algorithm":"TLS-v1.2","mode":"KDF","revision":"RFC7627","hashAlg":["SHA2-256","SHA2-384","SHA2-512"]},
{"algorithm":"TLS-v1.3","mode":"KDF","revision":"RFC8446","hmacAlg":["SHA2-256","SHA2-384"],"runningMode":["DHE","PSK","PSK-DHE"]}, {"algorithm":"TLS-v1.3","mode":"KDF","revision":"RFC8446","hmacAlg":["SHA2-256","SHA2-384"],"runningMode":["DHE","PSK","PSK-DHE"]},

View file

@ -2147,9 +2147,9 @@ func TestACVP(t *testing.T) {
const (
bsslModule = "boringssl.googlesource.com/boringssl.git"
-bsslVersion = "v0.0.0-20250207174145-0bb19f6126cb"
+bsslVersion = "v0.0.0-20251111011041-baaf868e6e8f"
goAcvpModule = "github.com/cpu/go-acvp"
-goAcvpVersion = "v0.0.0-20250126154732-de1ba727a0be"
+goAcvpVersion = "v0.0.0-20251111204335-5c8bf7f5cac1"
)
// In crypto/tls/bogo_shim_test.go the test is skipped if run on a builder with runtime.GOOS == "windows"

View file

@ -0,0 +1,9 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build fips140v1.0
package fipstest
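// fips140v2Conditionals is a no-op under the frozen FIPS 140-3 v1.0 module,
// which predates ML-DSA.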
func fips140v2Conditionals() {}

View file

@ -0,0 +1,16 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !fips140v1.0
package fipstest
import "crypto/internal/fips140/mldsa"
func fips140v2Conditionals() {
// ML-DSA sign and verify PCT
kMLDSA := mldsa.GenerateKey44()
// ML-DSA-44
mldsa.SignDeterministic(kMLDSA, make([]byte, 32), "")
}

View file

@ -6,6 +6,7 @@ package fipstest
import (
"crypto"
+"crypto/internal/fips140"
"crypto/rand"
"fmt"
"internal/testenv"
@ -48,6 +49,8 @@ var allCASTs = []string{
"HKDF-SHA2-256",
"HMAC-SHA2-256",
"KAS-ECC-SSC P-256",
+"ML-DSA sign and verify PCT",
+"ML-DSA-44",
"ML-KEM PCT", // -768
"ML-KEM PCT", // -1024
"ML-KEM-768",
@ -61,6 +64,14 @@ var allCASTs = []string{
"cSHAKE128",
}
+func init() {
+if fips140.Version() == "v1.0.0" {
+allCASTs = slices.DeleteFunc(allCASTs, func(s string) bool {
+return strings.HasPrefix(s, "ML-DSA")
+})
+}
+}
func TestAllCASTs(t *testing.T) {
testenv.MustHaveSource(t)
@ -104,6 +115,7 @@ func TestAllCASTs(t *testing.T) {
// TestConditionals causes the conditional CASTs and PCTs to be invoked.
func TestConditionals(t *testing.T) {
+fips140v2Conditionals()
// ML-KEM PCT
kMLKEM, err := mlkem.GenerateKey768()
if err != nil {
View file

@ -0,0 +1,728 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !fips140v1.0
package fipstest
import (
"crypto/internal/cryptotest"
"crypto/internal/fips140"
. "crypto/internal/fips140/mldsa"
"crypto/internal/fips140/sha3"
"encoding/hex"
"flag"
"math/rand"
"testing"
)
var sixtyMillionFlag = flag.Bool("60million", false, "run 60M-iterations accumulated test")
// TestMLDSAAccumulated accumulates 10k (or 100, or 60M) random vectors and checks
// the hash of the result, to avoid checking in megabytes of test vectors.
//
// 60M in particular is enough to give a 99.9% chance of hitting every value in
// the base field.
//
// 1-((q-1)/q)^60000000 ~= 0.9992
//
// If setting -60million, remember to also set -timeout 0.
func TestMLDSAAccumulated(t *testing.T) {
t.Run("ML-DSA-44/100", func(t *testing.T) {
testMLDSAAccumulated(t, NewPrivateKey44, NewPublicKey44, 100,
"d51148e1f9f4fa1a723a6cf42e25f2a99eb5c1b378b3d2dbbd561b1203beeae4")
})
t.Run("ML-DSA-65/100", func(t *testing.T) {
testMLDSAAccumulated(t, NewPrivateKey65, NewPublicKey65, 100,
"8358a1843220194417cadbc2651295cd8fc65125b5a5c1a239a16dc8b57ca199")
})
t.Run("ML-DSA-87/100", func(t *testing.T) {
testMLDSAAccumulated(t, NewPrivateKey87, NewPublicKey87, 100,
"8c3ad714777622b8f21ce31bb35f71394f23bc0fcf3c78ace5d608990f3b061b")
})
if !testing.Short() {
t.Run("ML-DSA-44/10k", func(t *testing.T) {
t.Parallel()
testMLDSAAccumulated(t, NewPrivateKey44, NewPublicKey44, 10000,
"e7fd21f6a59bcba60d65adc44404bb29a7c00e5d8d3ec06a732c00a306a7d143")
})
t.Run("ML-DSA-65/10k", func(t *testing.T) {
t.Parallel()
testMLDSAAccumulated(t, NewPrivateKey65, NewPublicKey65, 10000,
"5ff5e196f0b830c3b10a9eb5358e7c98a3a20136cb677f3ae3b90175c3ace329")
})
t.Run("ML-DSA-87/10k", func(t *testing.T) {
t.Parallel()
testMLDSAAccumulated(t, NewPrivateKey87, NewPublicKey87, 10000,
"80a8cf39317f7d0be0e24972c51ac152bd2a3e09bc0c32ce29dd82c4e7385e60")
})
}
if *sixtyMillionFlag {
t.Run("ML-DSA-44/60M", func(t *testing.T) {
t.Parallel()
testMLDSAAccumulated(t, NewPrivateKey44, NewPublicKey44, 60000000,
"080b48049257f5cd30dee17d6aa393d6c42fe52a29099df84a460ebaf4b02330")
})
t.Run("ML-DSA-65/60M", func(t *testing.T) {
t.Parallel()
testMLDSAAccumulated(t, NewPrivateKey65, NewPublicKey65, 60000000,
"0af0165db2b180f7a83dbecad1ccb758b9c2d834b7f801fc49dd572a9d4b1e83")
})
t.Run("ML-DSA-87/60M", func(t *testing.T) {
t.Parallel()
testMLDSAAccumulated(t, NewPrivateKey87, NewPublicKey87, 60000000,
"011166e9d5032c9bdc5c9bbb5dbb6c86df1c3d9bf3570b65ebae942dd9830057")
})
}
}
func testMLDSAAccumulated(t *testing.T, newPrivateKey func([]byte) (*PrivateKey, error), newPublicKey func([]byte) (*PublicKey, error), n int, expected string) {
s := sha3.NewShake128()
o := sha3.NewShake128()
seed := make([]byte, PrivateKeySize)
msg := make([]byte, 0)
for i := 0; i < n; i++ {
s.Read(seed)
dk, err := newPrivateKey(seed)
if err != nil {
t.Fatalf("NewPrivateKey: %v", err)
}
pk := dk.PublicKey().Bytes()
o.Write(pk)
sig, err := SignDeterministic(dk, msg, "")
if err != nil {
t.Fatalf("SignDeterministic: %v", err)
}
o.Write(sig)
pub, err := newPublicKey(pk)
if err != nil {
t.Fatalf("NewPublicKey: %v", err)
}
if *pub != *dk.PublicKey() {
t.Fatalf("public key mismatch")
}
if err := Verify(dk.PublicKey(), msg, sig, ""); err != nil {
t.Fatalf("Verify: %v", err)
}
}
got := hex.EncodeToString(o.Sum(nil))
if got != expected {
t.Errorf("got %s, expected %s", got, expected)
}
}
func TestMLDSAGenerateKey(t *testing.T) {
t.Run("ML-DSA-44", func(t *testing.T) {
testMLDSAGenerateKey(t, GenerateKey44, NewPrivateKey44)
})
t.Run("ML-DSA-65", func(t *testing.T) {
testMLDSAGenerateKey(t, GenerateKey65, NewPrivateKey65)
})
t.Run("ML-DSA-87", func(t *testing.T) {
testMLDSAGenerateKey(t, GenerateKey87, NewPrivateKey87)
})
}
func testMLDSAGenerateKey(t *testing.T, generateKey func() *PrivateKey, newPrivateKey func([]byte) (*PrivateKey, error)) {
k1 := generateKey()
k2 := generateKey()
if k1.Equal(k2) {
t.Errorf("two generated keys are equal")
}
k1x, err := newPrivateKey(k1.Bytes())
if err != nil {
t.Fatalf("NewPrivateKey: %v", err)
}
if !k1.Equal(k1x) {
t.Errorf("generated key and re-parsed key are not equal")
}
}
func TestMLDSAAllocations(t *testing.T) {
// We allocate the PrivateKey (k and kk) and PublicKey (pk) structs and the
// public key (pkBytes) and signature (sig) byte slices on the heap. They
// are all large and for the byte slices variable-length. Still, check we
// are not slipping more allocations in.
var expected float64 = 5
if fips140.Enabled {
// The PCT does a sign/verify cycle, which allocates a signature slice.
expected += 1
}
cryptotest.SkipTestAllocations(t)
if allocs := testing.AllocsPerRun(100, func() {
k := GenerateKey44()
seed := k.Bytes()
kk, err := NewPrivateKey44(seed)
if err != nil {
t.Fatalf("NewPrivateKey44: %v", err)
}
if !k.Equal(kk) {
t.Fatalf("keys not equal")
}
pkBytes := k.PublicKey().Bytes()
pk, err := NewPublicKey44(pkBytes)
if err != nil {
t.Fatalf("NewPublicKey44: %v", err)
}
message := []byte("Hello, world!")
context := "test"
sig, err := Sign(k, message, context)
if err != nil {
t.Fatalf("Sign: %v", err)
}
if err := Verify(pk, message, sig, context); err != nil {
t.Fatalf("Verify: %v", err)
}
}); allocs > expected {
t.Errorf("expected %0.0f allocations, got %0.1f", expected, allocs)
}
}
func BenchmarkMLDSASign(b *testing.B) {
// Signing works by rejection sampling, which introduces massive variance in
// individual signing times. To get stable but correct results, we benchmark
// a series of representative operations, engineered to have the same
// distribution of rejection counts and reasons as the average case. See also
// https://words.filippo.io/rsa-keygen-bench/ for a similar approach.
b.Run("ML-DSA-44", func(b *testing.B) {
benchmarkMLDSASign(b, NewPrivateKey44, benchmarkMessagesMLDSA44)
})
b.Run("ML-DSA-65", func(b *testing.B) {
benchmarkMLDSASign(b, NewPrivateKey65, benchmarkMessagesMLDSA65)
})
b.Run("ML-DSA-87", func(b *testing.B) {
benchmarkMLDSASign(b, NewPrivateKey87, benchmarkMessagesMLDSA87)
})
}
func benchmarkMLDSASign(b *testing.B, newPrivateKey func([]byte) (*PrivateKey, error), messages []string) {
seed := make([]byte, 32)
priv, err := newPrivateKey(seed)
if err != nil {
b.Fatalf("NewPrivateKey: %v", err)
}
rand.Shuffle(len(messages), func(i, j int) {
messages[i], messages[j] = messages[j], messages[i]
})
i := 0
for b.Loop() {
msg := messages[i]
if i++; i >= len(messages) {
i = 0
}
SignDeterministic(priv, []byte(msg), "")
}
}
// BenchmarkMLDSAVerify runs both public key parsing and signature verification,
// since pre-computation can be easily moved between the two, but in practice
// most uses of verification are for fresh public keys (unlike signing).
func BenchmarkMLDSAVerify(b *testing.B) {
b.Run("ML-DSA-44", func(b *testing.B) {
benchmarkMLDSAVerify(b, GenerateKey44, NewPublicKey44)
})
b.Run("ML-DSA-65", func(b *testing.B) {
benchmarkMLDSAVerify(b, GenerateKey65, NewPublicKey65)
})
b.Run("ML-DSA-87", func(b *testing.B) {
benchmarkMLDSAVerify(b, GenerateKey87, NewPublicKey87)
})
}
func benchmarkMLDSAVerify(b *testing.B, generateKey func() *PrivateKey, newPublicKey func([]byte) (*PublicKey, error)) {
priv := generateKey()
msg := make([]byte, 128)
sig, err := SignDeterministic(priv, msg, "context")
if err != nil {
b.Fatalf("SignDeterministic: %v", err)
}
pub := priv.PublicKey().Bytes()
for b.Loop() {
pk, err := newPublicKey(pub)
if err != nil {
b.Fatalf("NewPublicKey: %v", err)
}
if err := Verify(pk, msg, sig, "context"); err != nil {
b.Fatalf("Verify: %v", err)
}
}
}
func BenchmarkMLDSAKeygen(b *testing.B) {
b.Run("ML-DSA-44", func(b *testing.B) {
for b.Loop() {
NewPrivateKey44(make([]byte, 32))
}
})
b.Run("ML-DSA-65", func(b *testing.B) {
for b.Loop() {
NewPrivateKey65(make([]byte, 32))
}
})
b.Run("ML-DSA-87", func(b *testing.B) {
for b.Loop() {
NewPrivateKey87(make([]byte, 32))
}
})
}
var benchmarkMessagesMLDSA44 = []string{
"BUS7IAZWYOZ4JHJQYDWRTJL4V7",
"MK5HFFNP4TB5S6FM4KUFZSIXPD",
"DBFETUV4O56J57FXTXTIVCDIAR",
"I4FCMZ7UNLYAE2VVPKTE5ETXKL",
"56U76XRPOVFX3AU7MB2JHAP6JX",
"3ER6UPKIIDGCXLGLPU7KI3ODTN",
"JPQDX2IL3W5CYAFRZ4XUJOHQ3G",
"6AJOEI33Z3MLEBVC2Q67AYWK5L",
"WE3U36HYOPJ72RN3C74F6IOTTJ",
"NMPF5I3B2BKQG5RK26LMPQECCX",
"JRGAN2FA6IY7ESFGZ7PVI2RGWA",
"UIKLF6KNSIUHIIVNRKNUFRNR4W",
"HA252APFYUWHSZZFKP7CWGIBRY",
"JFY774TXRITQ6CIR56P2ZOTOL6",
"ZASYLW5Y3RAOC5NDZ2NCH5A4UY",
"42X4JXNPXMFRCFAE5AKR7XTFO7",
"YAHQUWUH534MUI2TYEKQR7VR3A",
"HBP7FGEXGSOZ5HNOVRGXZJU2KG",
"HG4O7DCRMYMQXASFLMYQ6NMIXK",
"2KPQMDZKS65CLJU4DHTMVV5WI3",
"G6YSUTEX4HHL44ISK2JVVK45BV",
"PUJGPEQUBQM3IK2EXDQFJ2WGBG",
"PNS6HMQAWA3RORSMSNEUAINMIR",
"L35MZS4XYIJK453OFXCZG4WHIK",
"CRY54YZMFRF6JTB3FPNNBWPUOG",
"Y25TSZBWGU4HJCRMWZHAWXQ2DN",
"23W64TW3AKZPKCM4HMKEHFI6VQ",
"PWQAOZ24B4VLNEQR4XKN7LZHDI",
"YINPDR3ZSAKPPXP6J6VAXHIPYO",
"JDBB52ZRAB3PYBPNE7P4COY5PJ",
"4DYU52LQLVG3LTREOTLBCJK3XC",
"AB45MV6RKUGPCW4EUK7DX23MJX",
"HEJSITE5K7J6YJ74OEATVTCERV",
"ZKI5QCFCGM26UK7F5KYTENXKD2",
"VH5G3ZLF5XC22QAEJ6JDGOBE5Y",
"HYGXFHH3JW5SENG26MXLL54IGV",
"MJUCRL36JZ757UYHBFPCJBPZRH",
"IBH3T6NAVLCJQBYSVHAQFUITYA",
"VMWCS7JMIMFQB6TPRAMOUXIKWD",
"SXRPGPNNW2MMBKQS3HJURIQ3XV",
"YPPYMJZW6WYXPSCZIPI57NTP5L",
"N3SH6DUH6UOPU7YMQ6BJJEQSPI",
"Q243DGA6VC6CW66FFUAB5V3VLB",
"OUUBXEU4NJBRN5XZJ7YQUPIZLA",
"H5TWHVGC7FXG6MCKJQURD3RNWG",
"OONG2ZZ7H3P5BREEEURNJHBBQG",
"HWROSSRTBCQOAIQAY5S4EQG4FX",
"AJW6PW62JQNU72VKGIQMPBX64C",
"OXECVUVAWBBBXGGQGQBTYVEP4S",
"M5XN6V2LQJDEIN3G4Z6WJO6AVT",
"NHGJUX3WGRTEIRPFWC2I467ST4",
"SEOADTJDKAYYLDSC4VAES2CRDJ",
"J5AT674S577ZFGEURNIAGYOHKW",
"VJQVNMGHG4ITFX2XSPSDEWVZWD",
"ZWY3KJPXTAVWWVHNAJDUXZ52TG",
"HY46PBUGP4EMH34C6Q56MO7CJP",
"MQTUO7CF6R6CRJPVV6F673M6VW",
"35Z2Z5KV2RBJPQ7OZ24ZJE6BKR",
"OVUEVXBLCU2BBY25QP5WJACDIX",
"LNJX7PCLYL35WYJBW6CTXENPUU",
"IH7E766LCENOQ5ZKZVCMLEPACU",
"T2HZFGDDSFQ6YADB52NIFLBFEV",
"RHQUJMN4MB5SYY4FP4ARZH52QJ",
"W7GZC5ZM63UF2EJ7OC4WJM3OTH",
"T2NHNFVOMICY33AQZSR53HXFQ6",
"7ZVB4Y4K4Y2VAM5NC7HHAJNZIB",
"UX2I4VF62XJGP2XTNN6LDKXTOH",
"HJAMJR5RQTQW7JMW7ZLPRBZE7E",
"HKWSKX7MB5346PHYNWNBAYDSYK",
"BVWSB75HFLLE45MWA6EPHPTCFR",
"YDH2J6NMM7UINHGUOPIUI7PSSR",
"SYQPZLK52HMUAQFMVHGRJYKBEY",
"7AA6UQFGSPBGNUDPLWXSGNKKPP",
"AYXRJGRWZ5S3QOEDVWYHHCICHV",
"KFJYAWO7IATSBCSTDUAA5EPFAN",
"3JABTLB6T2ICHGVT3HXZZ3OAIT",
"WCM3IBOCQJ36WSG627CCNK3QA7",
"5FB5H3BZN2J4RGR2DUW7M37NKZ",
"VKDDAD3BVOMPSNEDGIRHKX5S6R",
"LFH5HVUR726OSFD3YVYM3ZHEIH",
"Y4ETQB2KZVFB4M7SALLCTHX2FB",
"E6SAU3C25MO2WBBVBKCKP2N4ZE",
"3JA54Q3NEKURB5EAPL2FOFIESD",
"FZPBW7BIQIW3FTKQD4TLKNWLMD",
"LY5W6XFA2ZRI53FTUJYGWZ5RX6",
"QID236JY3ICR55O5YRED33O7YT",
"HDRU3L6MFEBCBQFNLF5IRPMOAL",
"232ANKJBDBG4TSKQ7GJMWTHT23",
"CDWE3CELZM5AOJGYEFHMUNSP5O",
"7LNJRBOKN6W7RXUU34MDJ2SNKL",
"S3IZOADTW2A6E5IGRO5WKX7FVH",
"ZAISTLXC55EBMTN6KZ6QX5S7OS",
"4Z5ZIVCMFR2PY2PY4Z47T4YPYA",
"NE36L53Z6AMYQU7Q5REFUF76MK",
"WND5UP5M6KWPBRFP5WIWTOWV3I",
"7OC54DLFWMADJEMKEJ3Y2FMMZS",
"BWJVZHGEN43ULNIOZCPZOB64HG",
"VDFPQSR7RE54A75GT4JDZY5JK2",
"HFCD5EPBZBSVMXIDA47DZ6MRD6",
"RNBVFIUUJUM7EHRE3VNWSTORGO",
"VO5NLQJBR22CRRYUETGTU6JLMR",
"RZOMNFHBTL6HMGWH4PEEDASK7U",
"QL73UBTOLK5O2TW43YWAIKS6T3",
"NE3QVSMWS5G3W5C3BMKTJNMI2L",
"YHI6EYQ4GZMB2QPGHPUG2ZUOEL",
"6MBATW7MFNRUQBFD3GM35B7YPM",
"AIYRY6P5T4XU44CGVPEV6W43FR",
"MIAQ2FHXMAPY5NXSS45VRDPRMG",
"2SNLHQYKK2K6NSWOF6KPGZ3CPC",
"RVBHIQO5LH77ZWEAO3SVL72M2V",
"XXTGJCJNRSNLE7ARAH2UU6LVKR",
"DQMGILY5IDMWN5OYQYYXH26ZGR",
"627VTXXMM455KMTFNUUTKNFXPY",
"HC7IBFGLZCWGUR4K7REPMPW6W4",
"CHL6JRQUS7D4NML3PFT37PPZAA",
"Y767HXJAGJ75KE3JLO4DTLQIXC",
"NTIODXI5I7TF2KXXWXOAYGT7G4",
"PKZYEK2WAI4D4HEYYZH6H5IOMP",
"FG6J6G7HZDEDF4JQBQOTC7RQGZ",
"3VHM2VZU77Y25E3UUYZJLB2QLA",
"WRZQJQW7ARH4DXYHVLCJ4HRTTB",
"LQXKV5HD2AZHENSJ2VFLJ5YU5L",
"MF6Q4OA2EN6TG6BUDK7RWCQNPU",
"3USKYKPC5CB3EC4ZRMZVE3R2UO",
"3WICO2GVS3IRBFUHNDLNKWVP7N",
"P6ZR2UZZOVUZKT4KUS5WICW5XE",
"PYPZUU76RYVOUZGUUX33HLDKYA",
"2FTSURHV34VYTVIUU7W6V5C3NK",
"YABDYMGXS2MD2CYF3S4ALG4FLG",
"MHIBDH25RRPWV3P4VAWT6SAX3I",
"OINSMWJQ2UTOOKZ3X6ICXXBQR7",
"PFTQS7JNU2Q3Q6L4CGBXVLOYNE",
"A4MZ7CCVYQUDJ2AFHNXBBQ3D24",
"CPUB5R3ORTCMSMCLUQURE6AN5O",
"NF5E7U3DFTXWFFXXHUXTEP4VZQ",
"AWB5WDFERWSSJG53YGJMDORQKR",
"U5JQUILKD6SEL6LXAMNFZP6VSW",
"M45NLOAFLO74EJKG5EXNET6J5Y",
"P2KTEUMZ5DZZMYSPOHDR2WJXAN",
"KVO7AXZNFBUBPYLOTZQQ42TFNS",
"WGJJ7SAEV6SBBWWYS4BTLD63WM",
"Y6GURVDV4ESRBPWSTV25T4PE4K",
"ESK7MPFPUZ5ZAQ52RP4SQIYCCC",
"623M3CIABZ3RANERQ2IREXAVYO",
"OQ4CQCFO42RS4BMMSGSDLUTOQO",
"AMFHRDVGM6G2TIR3TKIFGFSDVM",
"7VVSGGCVC53PLOYG7YHPFUJM5X",
"Z3HMESVL7EZUSZNZ33WXEBHA2N",
"AWWVRQD5W7IBSQPS26XOJVDV5H",
"OQBZ5ZST3U3NZYHSIWRNROIG6L",
"II573BW7DJLBYJSPSYIABQWDZD",
"MOKXOQFOCUCLQQH4UKH2DPE7VN",
"XR54NGUOU6BBUUTINNWBPJ35HX",
"DNK36COZGFXI6DY7WLCNUETIRT",
"R5M2PV7E3EHEM3TLGRCL3HSFMC",
"ITKENZQYDQMZFCUPOT7VF3BMU7",
"5GDCB74PPPHEP5N5G3DVRCYT7R",
"ZMKXVRPLI5PY5BDVEPOA3NQZGN",
"GBLIALWTHTUDTOMDERQFVB77CS",
"VKRTTXUTFOK4PJAQQZCCT7TV3T",
"ZJBUJJ4SW62BXOID3XO2W2M2PF",
"SKWT5T6QJTCD3FCINIK22KMVBJ",
"EHINNU6L33HRLOOJ3A2XFJSYQL",
"N4HRQJEFPAT5SU3YPO74WSMQIR",
"TGPTZ3ENMFWB5CZKJFR5WHIRI4",
"O4HNFTAUJJ2LZPQXPXRAXOVABA",
"4JVB5STP2YG5GYOXDWIF4KCKFB",
"MY554X3YZHBECLHNNZ7A3SPJTU",
"ASCJMAH7VCQAD2QJSWXPSVSM3H",
"NBNGL5DZ623KCG2JNZFGZMZ7KD",
"KGMZSW35AEQOJ6FA7IR7BHZI52",
"Q7QUHHS4OJFMJ4I3FY6TDKSMZQ",
"MZAE7TOEXAS76T7KIC73FEYRU4",
"2BVESR3REAWADCGYOYM7T646RG",
"EK3L2ORP4LT3HU3EMXDSQWFOKJ",
"3X4A6VMGMIDLVK72FZSDHSERWY",
"I3UHWI6M6HQFRBSQ6W2SABUNUP",
"REKPXW4DIB4MTKMPHN3RBVHVME",
"W37FNFZE35NX65Z7CVQ7L5U4L5",
"4AGYK6U2KP6RAOADCBUDDCBECV",
"IXM4SFQUDW2NOTXZIPWTNGET3F",
"6YE4G3VELF27MN3Z5B4VIQ3XYK",
"LPOZCPZAG3MD47MIWGR4FIOCDH",
"WGREKUL2LD7C7SYGKH7APIY2A6",
"WWW277FKTKUXQMP4BECSRHLWJI",
"UYE4IQPMSTXVQG7EJALKWWEGDN",
"TIV2L5Z6K7SNGNUVWSNKTAF4UE",
"I3FQOAW3PINUK26P62HCX657FO",
}
var benchmarkMessagesMLDSA65 = []string{
"NDGEUBUDWGRJJ3A4UNZZQOEKNL",
"ACGYQUXN4POOFUENCLNCIPHFAZ",
"Z3XETEYKROVJH7SIHOIAYCTO42",
"DXWCVCEFULV7XHRWHJWSEXWES7",
"BCR2D5PNLGFYX6B3QFQFV23JZP",
"2DVP5HNG54ES64QK4D37PWUYTJ",
"UJM4ADPJLURAIQH4XA6QYUGNJ6",
"B5WRCIPK5IVZW52R6TJOKNPKZH",
"7QNL6JTSP62IGX6RCM2NHRMTKK",
"EJSZQYLM7G7AJCGIEVBV2UW7NN",
"UFNA2NKJ3QFWNHHL5CXZ4R5H46",
"QZAXRTT3E4DOGVTJCOTBG3WXQV",
"KH2ETOYZO5UHIHIKATWJMUVG27",
"V5HVVQTOWRXZ2PB4XWXSEKXUN5",
"5LA7NAFI2LESMH533XY45QVCQW",
"SMF4TWPTMJA2Z4F4OVETTLVRAY",
"FWZ5OJAFMLTQRREPYF4VDRPPGI",
"OK3QMNO3OZSKSR6Q4BFVOVRWTH",
"NQOVN6F6AOBOEGMJTVMF67KTIJ",
"CCLC4Y6YT3AQ3HGT2QNSYAUGNV",
"CAZJHCHBUYQ6OKZ7DMWMDDLIZQ",
"LVW5XDTHPKOW5D452SYD7AFO6Q",
"EYA6O6FTYPC6TRKZPRPX5N2KQ4",
"Z6SGAEZ2SAAZHPQO7GL7CUMBAG",
"FKUCKW6JQVF4WQYXUSXYZQMAVY",
"LN2KDF4DANPE4SC4GKJ4BES3IZ",
"AVCRTWB6ALOQHY34XI7NTMP2JH",
"A5WHIS6CBWPCYIEC6N2MBAOEZ6",
"JC2BH476BXUQFIDA6UCR5V4G4F",
"NU6XH6VLSSFHVSRZCYXPFYKYCD",
"GSUXVZBDDYSZYFGXNP6AZW3PTC",
"XJPRNJ26XP4MIYH2Q7M7MPZ73M",
"INUTUP3IRFWIIT23DNFTIYKCFY",
"T4KH7HKLEYGXHBIRFGFCRUZCC4",
"GGQX4JFVWZHE5Y73YTLMSSOXNS",
"BUA4Q3TQZGLVHMMJU62GQOSHLV",
"WXW3SJXLSZO2MYF4YFIMXL2IQP",
"Q32XBVVGFQTSXAIDJE6XSEPRZG",
"6TEXT6SA7INRCTDSCSVZJEQ2YG",
"ZBN4UL43C3SJIG4HYR236PXCVS",
"TVWPLLC7NROBREWOM75VA3XCR3",
"CCDGL2FURLBABQ4IJBYCB75JFR",
"XBZGCOVTZHCPAARBTMAKPIE6GJ",
"TPRAENJ7I54XRIVH6LL6FDIA3I",
"RKOM3PHFILPIIQZL4ILQWGRYWI",
"CEEZIZ2WUXHQQFATYYGQ3ZDBTI",
"SLKOVAP6WLIVJBVU7VZG3ZGEOW",
"TWMCLJJSWEEQQPQGGDKEJ5SU2R",
"IFMUXXCD2LC7IGQLZ2QEK5UOQ2",
"C7IWFEBHW2CXN4XBJS7VLWH3VK",
"7KJYUEW3F264727TM4LE6RMGDO",
"BPG2XAPBMBTA4VMPUM7IZVZPK3",
"Y5X577BWRZNPLNUHJVSKGMUXYB",
"ZCKMKM23E4IUPTNQDFN2LTLZVX",
"4RKK223JNBDAP4G5DOAHHZ3VNO",
"5UZ3TQZHZT22ISTB4WJEVO6MC4",
"YMVS4HFSJ32CRZRL23PXZUEJFJ",
"UQEUJUTPSZLZARNBXWMCTMHPFF",
"CZAAZ5WK7EIPMW7NA3EZNNBF45",
"227PBHH23WM7F2QLEZSPFYXVW4",
"YUYS2J5CRFXZ4J4KJT2ZKIZVW3",
"MFLHZJOZV44SN4AH6OJ3QZWM2O",
"H2B3CRBCXYN7QWDGYUPHQZP23A",
"T4L6YWQUQ3CTACENAJ5WUXZWFH",
"N723H6MUGPZSRZ72C635OD4BP7",
"NI4TUMVA6LQPQV2TXPN4QOIGBZ",
"CQI3S4LSTQASSJJVZXEFPOVW7K",
"ANPY4HJ64LLSB3GK2R4C6WDBS3",
"RGWQCZKQLMT5FZRDE4B3VMASVK",
"Q3WCCF2HA3CA4WWRJBMGBW7WI7",
"2AKJRXFHXLUQPOXPTLSZN5PW4A",
"IJWOOTI4N7RWXJIHAPXN6KEWEN",
"4D53T6N6ATOVTD4LKSTAAWBJMU",
"B4G5HDD6RITG6NIH6FXCRZDYZM",
"TJCDFKMRUY2OG6KRSMNVCGQFUP",
"PB33IHQKALAY6H6GVBVLI6ZRXK",
"SCCWGW2J5S4WL4FTTMQ435F6DB",
"ZVJH2HSMTLHGXMGPMXLJCKCLLE",
"62LG37U6JXR77YRZQQCDSBHVCS",
"BU4CBWOXQ352TEOKIXO245ID4O",
"UEZOH7KEIODSEVRUF6GMWGA2RB",
"IPJWROME4GM66CGLUWP5BJ4SX6",
"355GDC7TG64AZJ7IJX6K62KZCZ",
"AHTFKX3V7XUB3EWOMQVCGZYGUE",
"N4RV2GKXJ4SPHHJ52Z7K5EGLER",
"ZY7V7NE5F66XHDHWM6YNFEWZA6",
"DIKFO5KAVT4WAP7BOEFM56ZUSR",
"4TDFOFKDAPIOM3MU5GD7NPXNWQ",
"AD7YZO756HDK6YWFILAKW3JWA7",
"NUA53JS2ZK2BGHH3A7BJTJZYW7",
"QLCNC3AQNKLRMSYR62WQSQP5VI",
"SJ7OBS7ZYXSGXOYXPE5KW2XKN6",
"44HBMOGMIMJS63CEXQU7FCXE2E",
"KCK3J7ZL6QF4SLHHSWTJURK7PG",
"HLH4CLUGBSOOBSS3BPO62N5MC3",
"3FNS4GITO6OEUBAVDDXK4WOBTD",
"IAC3K3I4AQGY3G6UHG7PL2N6TE",
"KUKLNH74POJI5DYAEWUD7RABTQ",
"ETM6N7VU3GBSQ7P5MCD6UF3E3S",
"IZITM5NYBGJZLSI3BI4VEMW43U",
"46OPQU4LL6N3Z2U7KYPKUMBAGI",
"EV7YZ5DMAV7VKYJQUFSRD37GPP",
"AV7W2PGYDJIAKLFVEBL6BXQSGC",
"M2FOX5QZEZKV4QXKPI5XUZDHEM",
"R4IFPLVMOVYCHRTR6LXAUGP3LL",
"JGH6XJUMP4DRVAM27P2JNOKXVO",
"D2XN3ZLLU6VFPMDYM7NBHSQEOI",
"2PO3BYENOMQK6SHQDCFSRPJQI3",
"IBVQ7U3QEUC6PQRE4PV53JTZTK",
"ZBCOX4P7NG2IXXFB2R43MG2SLV",
"5NJDPQVVDO7ADNZ2CV7L6QBNGZ",
"V7ASFIIYUMXFGW4B7ZM6LOGUTE",
"PX5IJZ7W2LUPKM6YN4PMZ43ZLM",
"AYK7SZ23DHC7Q56MWAJXBG76LB",
"UYCAPXJM4HNGKLIDSZ4NCEDJLN",
"UWMDZ3C2ODLACKGJPGETNQ3TA4",
"Q6OI6R3WYYJ4CCZCDJBQMCRCZR",
"LCMJHLP7354APCEGPKE7HHWTWB",
"N7T7ZKOYPAMEYTTDOWZNCN6PRD",
"UZADPU4UNHAF7L7LQDMTKA2EQH",
"DC2OEPQDECVLRVNNCS6BMH4CRA",
"37IZ427XHUMZ66EJ62U2YEZDAC",
"6BCZDQZDPZLS5OGESKNUBPSSFV",
"ST2LEMJ4OLQ32TJTLH2WCWT4WA",
"GA2TL4SFLEW4G2B5PQMIKJT5XG",
"L7PPBIET26EH7LQTLEFC4I4EIA",
"6YSM7MC2W4DEV6ULAHMX27LH56",
"QL26Z5KZ4YRRG2BXXGDRRLV357",
"677TWRAJ5NSNHCE243POQPEG7K",
"66MEBQJLGAGVXDX3KZ2YFTTVJM",
"6D4VUWAQD6R65ICSDLFAATC67V",
"7GXLD5CNU3TDUQSSW42SHL7B5D",
"RQETUMEBG2ZM2NF2EZAQHGHWWE",
"DCRX5ANWDMXZFIDVAXYLQZYMRN",
"5SDWT7YAF7L4WWANAGYINZAYXH",
"PZILRV7I2S6WKUSHKYRLA2JQY3",
"2G66TK2PZ5MOTAZDN7BFS3LAIH",
"QOLJ3WGJ6JS3FMMXBNTNAIKXVK",
"FMAL67YTHDCCYVZ5CRMN2XJPDN",
"UOTZDXTJKQ3YAIRKHTYNX6G55P",
"X3DLNPJ3V62LRHGEY4DTT35H3R",
"DKU7CHNXPB5QRZVGIQZW46XCKC",
"RAKBD4LQKEDTVDSK3DVTRWG23B",
"INTRA7BWHLVQMBRKBJNUSMF7MU",
"AUYRBNVCOYYHOHUYOOFIZ2FWMD",
"22EJVDEQ7PASLBAMTVKXOQP5RJ",
"3S6NATWA57SFTZEW7UZUOUYAEU",
}
var benchmarkMessagesMLDSA87 = []string{
"LQQPGPNUME6QDNDTQTS4BA7I7M",
"PTYEEJ7RMI6MXNN6PZH222Y6QI",
"R6DTHAADKNMEADDK5ECPNOTOAT",
"S2QM7VDC6UKRQNRETZMNAZ6SJT",
"EYULPTSJORQJCNYNYVHDFN4N3F",
"YETZNHZ75SXFU672VQ5WXYEPV2",
"KTSND3JGA4AN3PCMG4455JEXGR",
"JGE6HK37O6XMWZQZCHFUPNUEXP",
"CRYB2FZD2BYNANBFFO2HRZEHGZ",
"7MLNDZJ7OIEPBJZOMULOMQH2BA",
"4WQCNTIFVSX2DNALMWUKZRA6CI",
"Y5NK4OBDSDWC5WLL27CEEXYYOT",
"C4SSWSPBVCDAWJXH2CDMXR36LH",
"THDBKXRTKWJUGJMAAYTWTFMX7Z",
"NWXPUD4DAA6QOREW4AFFYQYQNG",
"3RQIJXMO7WYHBEBL3G6EOLNZNQ",
"R7JEOHFP2C7O4AVPRPRELXWOMM",
"LU6MWR7SZXVIKS54BY62X67NPA",
"FG2FFM4F2ECKHCSJ75KXK632JP",
"BF76ZDSVVUSYS5KK4FFD22YPS7",
"HCLBWZRLHEMYZLFWHLAN2BKCZ7",
"HGFVS4QC7AWXYPVRSWAK77KTQF",
"LUZ3C53PUUHBWCDJ7WAHK2UT3K",
"Y3WR6SMDUBW34N3MUT7EQYIJCV",
"F2X35AQTXVZBMPXTWNAAH4ZX2W",
"6MKFFDYWD6ZAKS3C6GRCRLZLRF",
"AFMZYYFRHKMQRNKU5UTSKQ74H6",
"TDTN7J3O367OVPWLESRNPLN4M2",
"WYMLD2X6N4CZ2RDOKF5CFTSYTG",
"UNPTSBLJ6HZRNR72T2VEEHCFX2",
"SNCM4R2P27AJOXBS67RMCARS3U",
"OU7QBE5QOXO7CIYTBJR3KOW2WK",
"2NNQOBQKZ2OD4ZAXI3SNEURYUP",
"YQTUPOYBT67XPCHIGKSGSKC3BZ",
"HGB4ZM3G76IXYWWCMVT3HONRIS",
"WZC6QUKRZZ2TOVA277JYKQITEW",
"XO2WT46A5HYL6CUJF7SGJ6YWOG",
"4QJA35PMYQIDRZ7ZHG7RLZJVGF",
"BMJZELWZ4I2UWXESU3NR6ATC4M",
"XWLFB7FN6D5PRY6YUXC5JUIBFM",
"WRAFFF27AVTIOYIBYA2IPTXI3R",
"VOXUTYTN2XZ362OJFO2R53UCUF",
"UHN73ARJ737WUJ6QYEI7U46OPO",
"3Y3K5E2A4ML3VYVNAFWEEIXTSN",
"QMU4322NKPRLE7JBGYFGS36H2S",
"NJAQTNCXPVDICTDVUKTPRCD2AX",
"OC373ZFBNV2H46T6OY3XRPSUHG",
"UBLAS6CDWE3A662MLKP7QDEOCC",
"BKFDLAL2RTPMERYVW3B7UJ5W3H",
"QFKFGXKGW5SAKLBAWQXUWW77OS",
"EJNUQHTLLOVB4ARETOGLY4WUTJ",
"N243OCMVLLAO6I2XLCYOIMQYGY",
"YRRFLWK7ZASUKYX7ZLQMW2PJ6X",
"3DGVPBWD2BIK6KQE65K72DNJNM",
"TJRYMNOAIW33VIHKLJG4GXAVUK",
"6DSRINAYXL34U54U355U7IVFGS",
"6CHA4MX7LVS77XKRWG7IYC3XVL",
"GM2CEGBEPBOHAPIOBUWJ4MJNTG",
"VJKHGBY33VUIJFEQLX3JVUNQBD",
"DTOHAD5M2KL46IZHE4TPLJWHTI",
"IYFG3UDN7ROOY2ZFSLM2BU2LMQ",
"A5OGJHPOE4PW6QSZYHZ5TKPGIC",
"FX4BCN67AEGCLUTLFPNDL3SQU5",
"MWIZQVOZOHTTBUXC3BEX62MNI5",
"BYHVJHBLK4O6LFSKEIQ3CAAKU7",
"QJU7P6KWSSKAA5GVA6RH4OV7MX",
"I3T3XM5Z5TAJHAYDQHFA2ZV7PU",
"L46MQCHV3TJ6FYIQQ2FCJXES74",
"QXZRQIYAJMXYR6PU3VDYGCIT5W",
"MFS53RR2XEYS22NYOJLGTHVTTM",
"FRWIWJRP4AQMXWX4WJ4WYVKM3E",
"X6GK6IGVLJWYSHLKHGXSW3TJDP",
"L5LPJ2HIWA4UY6G6FMZXGDEDAM",
"GD6FYOYUGDHXEQ5S2KLJEGNSN7",
"ODAL7ZRKXSPAAN5DVRBWJQCFQX",
"CV3QFBDXBPT3SCPJGUYSMDN6ZS",
"IGSLSACRZ6XID466KQIB4YNGYO",
"WZ2EACBN26RAML2S52YXRYP2OF",
"LB76VEVNOBYFMKFZ7SDFCBCHQE",
"TLFA7EU3JJFAP6EMUKNV2ZXRBM",
"SIIJF6OXAKRP25CBUYFBRCDDVP",
"TEPNI7TJ7HASJWIQMBS4VFLRQC",
"VK2JINYWEDV7IQFWH4OTAD4W5O",
"GILUH5AMVE4TM7EKPXJBZGT6EJ",
"DV7ALFRAW3TI4WMQQLDTO6RNHN",
"CAIB5G3NXC5ASPLFIWAFPVHS5B",
"MLFJXZUOAGN7EGPMXOOVTB2CL4",
"6MZYT3ANWHBOS67WGHZI3QPEAP",
"LVJDQB52C2PERSSQJRMRCJ4UBF",
"QY4VKAZAYQIZOX2L2VO2QHAQVC",
"UAA5SST2XA76JPKM3XOZ5RUHFI",
"VLZWF53JSQ6SCRUFDKVPXWAS4L",
"NX2DZIKMJIYXUNSAHFP23FHTBU",
"F5OAKDDDA34A2RPIKDPM5CYPMZ",
"E5PEP3ANIK2L4VLOST4NIYNKBD",
"IPBGFLHSMP4UFXF6XJX42T6CAL",
"XHPU7DBFTZB2TX5K34AD6DJTK3",
"2ZU7EJN2DG2UMT6HX5KGS2RFT6",
"SD5S7U34WSE4GBPKVDUDZLBIEH",
"WZFFL3BTQAV4VQMSAGCS45SGG3",
"QE7ZT2LI4CA5DLSVMHV6CP3E3V",
"YIWMS6AS72Z5N2ALZNFGCYC5QL",
"A4QJ5FNY54THAKBOB65K2JBIV7",
"6LORQGA3QO7TNADHEIINQZEE26",
"5V45M6RAKOZDMONYY4DIH3ZBL2",
"SVP7UYIZ5RTLWRKFLCWHAQV3Y2",
"C2UYQL2BBE4VLUJ3IFNFMHAN7O",
"P4DS44LGP2ERZB3OB7JISQKBXA",
"A6B4O5MWALOEHLILSVDOIXHQ4Z",
"DKQJTW5QF7KDZA3IR4X5R5F3CG",
"H6QFQX2C2QTH3YKEOO57SQS23J",
"DIF373ML2RWZMEOIVUHFXKUG7O",
"Z5PPIA3GJ74QXFFCOSUAQMN5YN",
"PM6XIDECSS5S77UXMB55VZHZSE",
}


@@ -0,0 +1,130 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package hpke
import (
"crypto/aes"
"crypto/cipher"
"errors"
"fmt"
"golang.org/x/crypto/chacha20poly1305"
)
// The AEAD is one of the three components of an HPKE ciphersuite, implementing
// symmetric encryption.
type AEAD interface {
ID() uint16
keySize() int
nonceSize() int
aead(key []byte) (cipher.AEAD, error)
}
// NewAEAD returns the AEAD implementation for the given AEAD ID.
//
// Applications are encouraged to use specific implementations like [AES128GCM]
// or [ChaCha20Poly1305] instead, unless runtime agility is required.
func NewAEAD(id uint16) (AEAD, error) {
switch id {
case 0x0001: // AES-128-GCM
return AES128GCM(), nil
case 0x0002: // AES-256-GCM
return AES256GCM(), nil
case 0x0003: // ChaCha20Poly1305
return ChaCha20Poly1305(), nil
case 0xFFFF: // Export-only
return ExportOnly(), nil
default:
return nil, fmt.Errorf("unsupported AEAD %04x", id)
}
}
// AES128GCM returns an AES-128-GCM AEAD implementation.
func AES128GCM() AEAD { return aes128GCM }
// AES256GCM returns an AES-256-GCM AEAD implementation.
func AES256GCM() AEAD { return aes256GCM }
// ChaCha20Poly1305 returns a ChaCha20Poly1305 AEAD implementation.
func ChaCha20Poly1305() AEAD { return chacha20poly1305AEAD }
// ExportOnly returns a placeholder AEAD implementation that cannot encrypt or
// decrypt, but only export secrets with [Sender.Export] or [Recipient.Export].
//
// When this is used, [Sender.Seal] and [Recipient.Open] return errors.
func ExportOnly() AEAD { return exportOnlyAEAD{} }
type aead struct {
nK int
nN int
new func([]byte) (cipher.AEAD, error)
id uint16
}
var aes128GCM = &aead{
nK: 128 / 8,
nN: 96 / 8,
new: newAESGCM,
id: 0x0001,
}
var aes256GCM = &aead{
nK: 256 / 8,
nN: 96 / 8,
new: newAESGCM,
id: 0x0002,
}
var chacha20poly1305AEAD = &aead{
nK: chacha20poly1305.KeySize,
nN: chacha20poly1305.NonceSize,
new: chacha20poly1305.New,
id: 0x0003,
}
func newAESGCM(key []byte) (cipher.AEAD, error) {
b, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewGCM(b)
}
func (a *aead) ID() uint16 {
return a.id
}
func (a *aead) aead(key []byte) (cipher.AEAD, error) {
if len(key) != a.nK {
return nil, errors.New("invalid key size")
}
return a.new(key)
}
func (a *aead) keySize() int {
return a.nK
}
func (a *aead) nonceSize() int {
return a.nN
}
type exportOnlyAEAD struct{}
func (exportOnlyAEAD) ID() uint16 {
return 0xFFFF
}
func (exportOnlyAEAD) aead(key []byte) (cipher.AEAD, error) {
return nil, nil
}
func (exportOnlyAEAD) keySize() int {
return 0
}
func (exportOnlyAEAD) nonceSize() int {
return 0
}
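The constructors above return package-level singletons, so callers can compare AEAD values with ==. What follows is a minimal usage sketch, not part of this change, assuming it sits in a file next to the package hpke tests so it can call the exported API directly.

package hpke

import "fmt"

func ExampleNewAEAD() {
	// Sketch only, not part of the CL.
	// Prefer the concrete constructors when the ciphersuite is fixed at
	// compile time.
	fixed := AES128GCM()
	// NewAEAD is for runtime agility, e.g. when the AEAD ID arrives off the
	// wire; 0x0003 is ChaCha20Poly1305 per RFC 9180, Section 7.3.
	negotiated, err := NewAEAD(0x0003)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%04x %04x\n", fixed.ID(), negotiated.ID())
	// Output: 0001 0003
}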


@@ -2,354 +2,261 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// Package hpke implements Hybrid Public Key Encryption (HPKE) as defined in
+// [RFC 9180].
+//
+// [RFC 9180]: https://www.rfc-editor.org/rfc/rfc9180.html
package hpke
import (
-"crypto"
-"crypto/aes"
"crypto/cipher"
-"crypto/ecdh"
-"crypto/hkdf"
-"crypto/rand"
+"encoding/binary"
"errors"
-"internal/byteorder"
-"math/bits"
-"golang.org/x/crypto/chacha20poly1305"
)
// testingOnlyGenerateKey is only used during testing, to provide
// a fixed test key to use when checking the RFC 9180 vectors.
var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error)
type hkdfKDF struct {
hash crypto.Hash
}
func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) ([]byte, error) {
labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey))
labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
labeledIKM = append(labeledIKM, sid...)
labeledIKM = append(labeledIKM, label...)
labeledIKM = append(labeledIKM, inputKey...)
return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
}
func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) ([]byte, error) {
labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info))
labeledInfo = byteorder.BEAppendUint16(labeledInfo, length)
labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
labeledInfo = append(labeledInfo, suiteID...)
labeledInfo = append(labeledInfo, label...)
labeledInfo = append(labeledInfo, info...)
return hkdf.Expand(kdf.hash.New, randomKey, string(labeledInfo), int(length))
}
// dhKEM implements the KEM specified in RFC 9180, Section 4.1.
type dhKEM struct {
dh ecdh.Curve
kdf hkdfKDF
suiteID []byte
nSecret uint16
}
type KemID uint16
const DHKEM_X25519_HKDF_SHA256 = 0x0020
var SupportedKEMs = map[uint16]struct {
curve ecdh.Curve
hash crypto.Hash
nSecret uint16
}{
// RFC 9180 Section 7.1
DHKEM_X25519_HKDF_SHA256: {ecdh.X25519(), crypto.SHA256, 32},
}
func newDHKem(kemID uint16) (*dhKEM, error) {
suite, ok := SupportedKEMs[kemID]
if !ok {
return nil, errors.New("unsupported suite ID")
}
return &dhKEM{
dh: suite.curve,
kdf: hkdfKDF{suite.hash},
suiteID: byteorder.BEAppendUint16([]byte("KEM"), kemID),
nSecret: suite.nSecret,
}, nil
}
func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) ([]byte, error) {
eaePRK, err := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey)
if err != nil {
return nil, err
}
return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret)
}
func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) {
var privEph *ecdh.PrivateKey
if testingOnlyGenerateKey != nil {
privEph, err = testingOnlyGenerateKey()
} else {
privEph, err = dh.dh.GenerateKey(rand.Reader)
}
if err != nil {
return nil, nil, err
}
dhVal, err := privEph.ECDH(pubRecipient)
if err != nil {
return nil, nil, err
}
encPubEph := privEph.PublicKey().Bytes()
encPubRecip := pubRecipient.Bytes()
kemContext := append(encPubEph, encPubRecip...)
sharedSecret, err = dh.ExtractAndExpand(dhVal, kemContext)
if err != nil {
return nil, nil, err
}
return sharedSecret, encPubEph, nil
}
func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) {
pubEph, err := dh.dh.NewPublicKey(encPubEph)
if err != nil {
return nil, err
}
dhVal, err := secRecipient.ECDH(pubEph)
if err != nil {
return nil, err
}
kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...)
return dh.ExtractAndExpand(dhVal, kemContext)
}
type context struct {
-aead cipher.AEAD
-sharedSecret []byte
suiteID []byte
-key []byte
-baseNonce []byte
-exporterSecret []byte
-seqNum uint128
+export func(string, uint16) ([]byte, error)
+aead cipher.AEAD
+baseNonce []byte
+// seqNum starts at zero and is incremented for each Seal/Open call.
+// 64 bits are enough not to overflow for 500 years at 1ns per operation.
+seqNum uint64
}
+// Sender is a sending HPKE context. It is instantiated with a specific KEM
+// encapsulation key (i.e. the public key), and it is stateful, incrementing the
+// nonce counter for each [Sender.Seal] call.
type Sender struct {
*context
}
+// Recipient is a receiving HPKE context. It is instantiated with a specific KEM
+// decapsulation key (i.e. the secret key), and it is stateful, incrementing the
+// nonce counter for each successful [Recipient.Open] call.
type Recipient struct {
*context
}
var aesGCMNew = func(key []byte) (cipher.AEAD, error) { func newContext(sharedSecret []byte, kemID uint16, kdf KDF, aead AEAD, info []byte) (*context, error) {
block, err := aes.NewCipher(key) sid := suiteID(kemID, kdf.ID(), aead.ID())
if kdf.oneStage() {
secrets := make([]byte, 0, 2+2+len(sharedSecret))
secrets = binary.BigEndian.AppendUint16(secrets, 0) // empty psk
secrets = binary.BigEndian.AppendUint16(secrets, uint16(len(sharedSecret)))
secrets = append(secrets, sharedSecret...)
ksContext := make([]byte, 0, 1+2+2+len(info))
ksContext = append(ksContext, 0) // mode 0
ksContext = binary.BigEndian.AppendUint16(ksContext, 0) // empty psk_id
ksContext = binary.BigEndian.AppendUint16(ksContext, uint16(len(info)))
ksContext = append(ksContext, info...)
secret, err := kdf.labeledDerive(sid, secrets, "secret", ksContext,
uint16(aead.keySize()+aead.nonceSize()+kdf.size()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return cipher.NewGCM(block) key := secret[:aead.keySize()]
} baseNonce := secret[aead.keySize() : aead.keySize()+aead.nonceSize()]
expSecret := secret[aead.keySize()+aead.nonceSize():]
type AEADID uint16 a, err := aead.aead(key)
const (
AEAD_AES_128_GCM = 0x0001
AEAD_AES_256_GCM = 0x0002
AEAD_ChaCha20Poly1305 = 0x0003
)
var SupportedAEADs = map[uint16]struct {
keySize int
nonceSize int
aead func([]byte) (cipher.AEAD, error)
}{
// RFC 9180, Section 7.3
AEAD_AES_128_GCM: {keySize: 16, nonceSize: 12, aead: aesGCMNew},
AEAD_AES_256_GCM: {keySize: 32, nonceSize: 12, aead: aesGCMNew},
AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
}
type KDFID uint16
const KDF_HKDF_SHA256 = 0x0001
var SupportedKDFs = map[uint16]func() *hkdfKDF{
// RFC 9180, Section 7.2
KDF_HKDF_SHA256: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
}
func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (*context, error) {
sid := suiteID(kemID, kdfID, aeadID)
kdfInit, ok := SupportedKDFs[kdfID]
if !ok {
return nil, errors.New("unsupported KDF id")
}
kdf := kdfInit()
aeadInfo, ok := SupportedAEADs[aeadID]
if !ok {
return nil, errors.New("unsupported AEAD id")
}
pskIDHash, err := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
infoHash, err := kdf.LabeledExtract(sid, nil, "info_hash", info) export := func(exporterContext string, length uint16) ([]byte, error) {
return kdf.labeledDerive(sid, expSecret, "sec", []byte(exporterContext), length)
}
return &context{
aead: a,
suiteID: sid,
export: export,
baseNonce: baseNonce,
}, nil
}
pskIDHash, err := kdf.labeledExtract(sid, nil, "psk_id_hash", nil)
if err != nil {
return nil, err
}
infoHash, err := kdf.labeledExtract(sid, nil, "info_hash", info)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ksContext := append([]byte{0}, pskIDHash...) ksContext := append([]byte{0}, pskIDHash...)
ksContext = append(ksContext, infoHash...) ksContext = append(ksContext, infoHash...)
secret, err := kdf.LabeledExtract(sid, sharedSecret, "secret", nil) secret, err := kdf.labeledExtract(sid, sharedSecret, "secret", nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key, err := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */) key, err := kdf.labeledExpand(sid, secret, "key", ksContext, uint16(aead.keySize()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
baseNonce, err := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */) a, err := aead.aead(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
exporterSecret, err := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/) baseNonce, err := kdf.labeledExpand(sid, secret, "base_nonce", ksContext, uint16(aead.nonceSize()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
expSecret, err := kdf.labeledExpand(sid, secret, "exp", ksContext, uint16(kdf.size()))
aead, err := aeadInfo.aead(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
export := func(exporterContext string, length uint16) ([]byte, error) {
return kdf.labeledExpand(sid, expSecret, "sec", []byte(exporterContext), length)
}
return &context{ return &context{
aead: aead, aead: a,
sharedSecret: sharedSecret,
suiteID: sid, suiteID: sid,
key: key, export: export,
baseNonce: baseNonce, baseNonce: baseNonce,
exporterSecret: exporterSecret,
}, nil }, nil
} }
func SetupSender(kemID, kdfID, aeadID uint16, pub *ecdh.PublicKey, info []byte) ([]byte, *Sender, error) { // NewSender returns a sending HPKE context for the provided KEM encapsulation
kem, err := newDHKem(kemID) // key (i.e. the public key), and using the ciphersuite defined by the
// combination of KEM, KDF, and AEAD.
//
// The info parameter is additional public information that must match between
// sender and recipient.
//
// The returned enc ciphertext can be used to instantiate a matching receiving
// HPKE context with the corresponding KEM decapsulation key.
func NewSender(pk PublicKey, kdf KDF, aead AEAD, info []byte) (enc []byte, s *Sender, err error) {
sharedSecret, encapsulatedKey, err := pk.encap()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
sharedSecret, encapsulatedKey, err := kem.Encap(pub) context, err := newContext(sharedSecret, pk.KEM().ID(), kdf, aead, info)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
if err != nil {
return nil, nil, err
}
return encapsulatedKey, &Sender{context}, nil return encapsulatedKey, &Sender{context}, nil
} }
func SetupRecipient(kemID, kdfID, aeadID uint16, priv *ecdh.PrivateKey, info, encPubEph []byte) (*Recipient, error) { // NewRecipient returns a receiving HPKE context for the provided KEM
kem, err := newDHKem(kemID) // decapsulation key (i.e. the secret key), and using the ciphersuite defined by
// the combination of KEM, KDF, and AEAD.
//
// The enc parameter must have been produced by a matching sending HPKE context
// with the corresponding KEM encapsulation key. The info parameter is
// additional public information that must match between sender and recipient.
func NewRecipient(enc []byte, k PrivateKey, kdf KDF, aead AEAD, info []byte) (*Recipient, error) {
sharedSecret, err := k.decap(enc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sharedSecret, err := kem.Decap(encPubEph, priv) context, err := newContext(sharedSecret, k.KEM().ID(), kdf, aead, info)
if err != nil { if err != nil {
return nil, err return nil, err
} }
context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
if err != nil {
return nil, err
}
return &Recipient{context}, nil return &Recipient{context}, nil
} }
// Seal encrypts the provided plaintext, optionally binding to the additional
// public data aad.
//
// Seal uses incrementing counters for each call, and Open on the receiving side
// must be called in the same order as Seal.
func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
if s.aead == nil {
return nil, errors.New("export-only instantiation")
}
ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
s.seqNum++
return ciphertext, nil
}
// Seal instantiates a single-use sending HPKE context like [NewSender],
// and then encrypts the provided plaintext like [Sender.Seal] (with no aad).
// Seal returns the concatenation of the encapsulated key and the ciphertext.
func Seal(pk PublicKey, kdf KDF, aead AEAD, info, plaintext []byte) ([]byte, error) {
enc, s, err := NewSender(pk, kdf, aead, info)
if err != nil {
return nil, err
}
ct, err := s.Seal(nil, plaintext)
if err != nil {
return nil, err
}
return append(enc, ct...), nil
}
// Export produces a secret value derived from the shared key between sender and
// recipient. length must be at most 65,535.
func (s *Sender) Export(exporterContext string, length int) ([]byte, error) {
if length < 0 || length > 0xFFFF {
return nil, errors.New("invalid length")
}
return s.export(exporterContext, uint16(length))
}
// Open decrypts the provided ciphertext, optionally binding to the additional
// public data aad, or returns an error if decryption fails.
//
// Open uses incrementing counters for each successful call, and must be called
// in the same order as Seal on the sending side.
func (r *Recipient) Open(aad, ciphertext []byte) ([]byte, error) {
if r.aead == nil {
return nil, errors.New("export-only instantiation")
}
plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad)
if err != nil {
return nil, err
}
r.seqNum++
return plaintext, nil
}
// Open instantiates a single-use receiving HPKE context like [NewRecipient],
// and then decrypts the provided ciphertext like [Recipient.Open] (with no aad).
// ciphertext must be the concatenation of the encapsulated key and the actual ciphertext.
func Open(k PrivateKey, kdf KDF, aead AEAD, info, ciphertext []byte) ([]byte, error) {
encSize := k.KEM().encSize()
if len(ciphertext) < encSize {
return nil, errors.New("ciphertext too short")
}
enc, ciphertext := ciphertext[:encSize], ciphertext[encSize:]
r, err := NewRecipient(enc, k, kdf, aead, info)
if err != nil {
return nil, err
}
return r.Open(nil, ciphertext)
}
// Export produces a secret value derived from the shared key between sender and
// recipient. length must be at most 65,535.
func (r *Recipient) Export(exporterContext string, length int) ([]byte, error) {
if length < 0 || length > 0xFFFF {
return nil, errors.New("invalid length")
}
return r.export(exporterContext, uint16(length))
}
func (ctx *context) nextNonce() []byte {
-nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():]
+nonce := make([]byte, ctx.aead.NonceSize())
+binary.BigEndian.PutUint64(nonce[len(nonce)-8:], ctx.seqNum)
for i := range ctx.baseNonce {
nonce[i] ^= ctx.baseNonce[i]
}
return nonce
}
func (ctx *context) incrementNonce() {
// Message limit is, according to the RFC, 2^95+1, which
// is somewhat confusing, but we do as we're told.
if ctx.seqNum.bitLen() >= (ctx.aead.NonceSize()*8)-1 {
panic("message limit reached")
}
ctx.seqNum = ctx.seqNum.addOne()
}
func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
s.incrementNonce()
return ciphertext, nil
}
func (r *Recipient) Open(aad, ciphertext []byte) ([]byte, error) {
plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad)
if err != nil {
return nil, err
}
r.incrementNonce()
return plaintext, nil
}
func suiteID(kemID, kdfID, aeadID uint16) []byte {
suiteID := make([]byte, 0, 4+2+2+2)
suiteID = append(suiteID, []byte("HPKE")...)
-suiteID = byteorder.BEAppendUint16(suiteID, kemID)
-suiteID = byteorder.BEAppendUint16(suiteID, kdfID)
-suiteID = byteorder.BEAppendUint16(suiteID, aeadID)
+suiteID = binary.BigEndian.AppendUint16(suiteID, kemID)
+suiteID = binary.BigEndian.AppendUint16(suiteID, kdfID)
+suiteID = binary.BigEndian.AppendUint16(suiteID, aeadID)
return suiteID
}
func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) {
kemInfo, ok := SupportedKEMs[kemID]
if !ok {
return nil, errors.New("unsupported KEM id")
}
return kemInfo.curve.NewPublicKey(bytes)
}
func ParseHPKEPrivateKey(kemID uint16, bytes []byte) (*ecdh.PrivateKey, error) {
kemInfo, ok := SupportedKEMs[kemID]
if !ok {
return nil, errors.New("unsupported KEM id")
}
return kemInfo.curve.NewPrivateKey(bytes)
}
type uint128 struct {
hi, lo uint64
}
func (u uint128) addOne() uint128 {
lo, carry := bits.Add64(u.lo, 1, 0)
return uint128{u.hi + carry, lo}
}
func (u uint128) bitLen() int {
return bits.Len64(u.hi) + bits.Len64(u.lo)
}
func (u uint128) bytes() []byte {
b := make([]byte, 16)
byteorder.BEPutUint64(b[0:], u.hi)
byteorder.BEPutUint64(b[8:], u.lo)
return b
}
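The stateful API above pairs one Sender with one Recipient and advances a shared nonce counter on each call, so messages must be opened in the order they were sealed. A short sketch, not part of this change, assuming it lives next to the package hpke tests:

package hpke

import (
	"crypto/ecdh"
	"fmt"
)

func ExampleNewSender() {
	// Sketch only, not part of the CL.
	kem, kdf, aead := DHKEM(ecdh.X25519()), HKDFSHA256(), AES128GCM()
	priv, err := kem.GenerateKey()
	if err != nil {
		panic(err)
	}
	// The sender only needs the recipient's public (encapsulation) key.
	enc, sender, err := NewSender(priv.PublicKey(), kdf, aead, []byte("app info"))
	if err != nil {
		panic(err)
	}
	recipient, err := NewRecipient(enc, priv, kdf, aead, []byte("app info"))
	if err != nil {
		panic(err)
	}
	// Seal and Open must be called in the same order on both sides.
	for _, msg := range []string{"first", "second"} {
		ct, err := sender.Seal([]byte("aad"), []byte(msg))
		if err != nil {
			panic(err)
		}
		pt, err := recipient.Open([]byte("aad"), ct)
		if err != nil {
			panic(err)
		}
		fmt.Println(string(pt))
	}
	// Both ends can also derive extra keying material from the context.
	if _, err := sender.Export("label", 32); err != nil {
		panic(err)
	}
	// Output:
	// first
	// second
}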


@@ -6,18 +6,69 @@ package hpke
import (
"bytes"
+"crypto/ecdh"
+"crypto/mlkem"
+"crypto/mlkem/mlkemtest"
+"crypto/sha3"
"encoding/hex"
"encoding/json"
+"fmt"
+"io"
"os"
-"strconv"
-"strings"
"testing"
-"crypto/ecdh"
-_ "crypto/sha256"
-_ "crypto/sha512"
)
func Example() {
// In this example, we use MLKEM768-X25519 as the KEM, HKDF-SHA256 as the
// KDF, and AES-256-GCM as the AEAD to encrypt a single message from a
// sender to a recipient using the one-shot API.
kem, kdf, aead := MLKEM768X25519(), HKDFSHA256(), AES256GCM()
// Recipient side
var (
recipientPrivateKey PrivateKey
publicKeyBytes []byte
)
{
k, err := kem.GenerateKey()
if err != nil {
panic(err)
}
recipientPrivateKey = k
publicKeyBytes = k.PublicKey().Bytes()
}
// Sender side
var ciphertext []byte
{
publicKey, err := kem.NewPublicKey(publicKeyBytes)
if err != nil {
panic(err)
}
message := []byte("|-()-|")
ct, err := Seal(publicKey, kdf, aead, []byte("example"), message)
if err != nil {
panic(err)
}
ciphertext = ct
}
// Recipient side
{
plaintext, err := Open(recipientPrivateKey, kdf, aead, []byte("example"), ciphertext)
if err != nil {
panic(err)
}
fmt.Printf("Decrypted message: %s\n", plaintext)
}
// Output:
// Decrypted message: |-()-|
}
func mustDecodeHex(t *testing.T, in string) []byte {
t.Helper()
b, err := hex.DecodeString(in)
@@ -27,169 +78,433 @@ func mustDecodeHex(t *testing.T, in string) []byte {
return b
}
func parseVectorSetup(vector string) map[string]string { func TestVectors(t *testing.T) {
vals := map[string]string{} t.Run("rfc9180", func(t *testing.T) {
for _, l := range strings.Split(vector, "\n") { testVectors(t, "rfc9180")
fields := strings.Split(l, ": ") })
vals[fields[0]] = fields[1] t.Run("hpke-pq", func(t *testing.T) {
} testVectors(t, "hpke-pq")
return vals })
} }
func parseVectorEncryptions(vector string) []map[string]string { func testVectors(t *testing.T, name string) {
vals := []map[string]string{} vectorsJSON, err := os.ReadFile("testdata/" + name + ".json")
for _, section := range strings.Split(vector, "\n\n") {
e := map[string]string{}
for _, l := range strings.Split(section, "\n") {
fields := strings.Split(l, ": ")
e[fields[0]] = fields[1]
}
vals = append(vals, e)
}
return vals
}
func TestRFC9180Vectors(t *testing.T) {
vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var vectors []struct { var vectors []struct {
Name string Mode uint16 `json:"mode"`
Setup string KEM uint16 `json:"kem_id"`
Encryptions string KDF uint16 `json:"kdf_id"`
AEAD uint16 `json:"aead_id"`
Info string `json:"info"`
IkmE string `json:"ikmE"`
IkmR string `json:"ikmR"`
SkRm string `json:"skRm"`
PkRm string `json:"pkRm"`
Enc string `json:"enc"`
Encryptions []struct {
Aad string `json:"aad"`
Ct string `json:"ct"`
Nonce string `json:"nonce"`
Pt string `json:"pt"`
} `json:"encryptions"`
Exports []struct {
Context string `json:"exporter_context"`
L int `json:"L"`
Value string `json:"exported_value"`
} `json:"exports"`
// Instead of checking in a very large rfc9180.json, we computed
// alternative accumulated values.
AccEncryptions string `json:"encryptions_accumulated"`
AccExports string `json:"exports_accumulated"`
} }
if err := json.Unmarshal(vectorsJSON, &vectors); err != nil { if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
t.Fatal(err) t.Fatal(err)
} }
for _, vector := range vectors { for _, vector := range vectors {
t.Run(vector.Name, func(t *testing.T) { name := fmt.Sprintf("mode %04x kem %04x kdf %04x aead %04x",
setup := parseVectorSetup(vector.Setup) vector.Mode, vector.KEM, vector.KDF, vector.AEAD)
t.Run(name, func(t *testing.T) {
if vector.Mode != 0 {
t.Skip("only mode 0 (base) is supported")
}
if vector.KEM == 0x0021 {
t.Skip("KEM 0x0021 (DHKEM(X448)) not supported")
}
if vector.KEM == 0x0040 {
t.Skip("KEM 0x0040 (ML-KEM-512) not supported")
}
if vector.KDF == 0x0012 || vector.KDF == 0x0013 {
t.Skipf("TurboSHAKE KDF not supported")
}
kemID, err := strconv.Atoi(setup["kem_id"]) kdf, err := NewKDF(vector.KDF)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if _, ok := SupportedKEMs[uint16(kemID)]; !ok { if kdf.ID() != vector.KDF {
t.Skip("unsupported KEM") t.Errorf("unexpected KDF ID: got %04x, want %04x", kdf.ID(), vector.KDF)
} }
kdfID, err := strconv.Atoi(setup["kdf_id"])
aead, err := NewAEAD(vector.AEAD)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if _, ok := SupportedKDFs[uint16(kdfID)]; !ok { if aead.ID() != vector.AEAD {
t.Skip("unsupported KDF") t.Errorf("unexpected AEAD ID: got %04x, want %04x", aead.ID(), vector.AEAD)
} }
aeadID, err := strconv.Atoi(setup["aead_id"])
kem, err := NewKEM(vector.KEM)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if _, ok := SupportedAEADs[uint16(aeadID)]; !ok { if kem.ID() != vector.KEM {
t.Skip("unsupported AEAD") t.Errorf("unexpected KEM ID: got %04x, want %04x", kem.ID(), vector.KEM)
} }
info := mustDecodeHex(t, setup["info"]) pubKeyBytes := mustDecodeHex(t, vector.PkRm)
pubKeyBytes := mustDecodeHex(t, setup["pkRm"]) kemSender, err := kem.NewPublicKey(pubKeyBytes)
pub, err := ParseHPKEPublicKey(uint16(kemID), pubKeyBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if kemSender.KEM() != kem {
ephemeralPrivKey := mustDecodeHex(t, setup["skEm"]) t.Errorf("unexpected KEM from sender: got %04x, want %04x", kemSender.KEM().ID(), kem.ID())
}
testingOnlyGenerateKey = func() (*ecdh.PrivateKey, error) { if !bytes.Equal(kemSender.Bytes(), pubKeyBytes) {
return SupportedKEMs[uint16(kemID)].curve.NewPrivateKey(ephemeralPrivKey) t.Errorf("unexpected KEM bytes: got %x, want %x", kemSender.Bytes(), pubKeyBytes)
} }
t.Cleanup(func() { testingOnlyGenerateKey = nil })
encap, sender, err := SetupSender( ikmE := mustDecodeHex(t, vector.IkmE)
uint16(kemID), setupDerandomizedEncap(t, ikmE, kemSender)
uint16(kdfID),
uint16(aeadID), info := mustDecodeHex(t, vector.Info)
pub, encap, sender, err := NewSender(kemSender, kdf, aead, info)
info,
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(encap) != kem.encSize() {
t.Errorf("unexpected encapsulated key size: got %d, want %d", len(encap), kem.encSize())
}
expectedEncap := mustDecodeHex(t, setup["enc"]) expectedEncap := mustDecodeHex(t, vector.Enc)
if !bytes.Equal(encap, expectedEncap) { if !bytes.Equal(encap, expectedEncap) {
t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap) t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
} }
privKeyBytes := mustDecodeHex(t, setup["skRm"]) privKeyBytes := mustDecodeHex(t, vector.SkRm)
priv, err := ParseHPKEPrivateKey(uint16(kemID), privKeyBytes) kemRecipient, err := kem.NewPrivateKey(privKeyBytes)
if err != nil {
t.Fatal(err)
}
if kemRecipient.KEM() != kem {
t.Errorf("unexpected KEM from recipient: got %04x, want %04x", kemRecipient.KEM().ID(), kem.ID())
}
kemRecipientBytes, err := kemRecipient.Bytes()
if err != nil {
t.Fatal(err)
}
// X25519 serialized keys must be clamped, so the bytes might not match.
if !bytes.Equal(kemRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
t.Errorf("unexpected KEM bytes: got %x, want %x", kemRecipientBytes, privKeyBytes)
}
if vector.KEM == DHKEM(ecdh.X25519()).ID() {
kem2, err := kem.NewPrivateKey(kemRecipientBytes)
if err != nil {
t.Fatal(err)
}
kemRecipientBytes2, err := kem2.Bytes()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(kemRecipientBytes2, kemRecipientBytes) {
t.Errorf("X25519 re-serialized key differs: got %x, want %x", kemRecipientBytes2, kemRecipientBytes)
}
if !bytes.Equal(kem2.PublicKey().Bytes(), pubKeyBytes) {
t.Errorf("X25519 re-derived public key differs: got %x, want %x", kem2.PublicKey().Bytes(), pubKeyBytes)
}
}
if !bytes.Equal(kemRecipient.PublicKey().Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM sender bytes: got %x, want %x", kemRecipient.PublicKey().Bytes(), pubKeyBytes)
}
ikm := mustDecodeHex(t, vector.IkmR)
derivRecipient, err := kem.DeriveKeyPair(ikm)
if err != nil {
t.Fatal(err)
}
derivRecipientBytes, err := derivRecipient.Bytes()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(derivRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
t.Errorf("unexpected KEM bytes from seed: got %x, want %x", derivRecipientBytes, privKeyBytes)
}
if !bytes.Equal(derivRecipient.PublicKey().Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM sender bytes from seed: got %x, want %x", derivRecipient.PublicKey().Bytes(), pubKeyBytes)
}
recipient, err := NewRecipient(encap, kemRecipient, kdf, aead, info)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
recipient, err := SetupRecipient( if aead != ExportOnly() && len(vector.AccEncryptions) != 0 {
uint16(kemID), source, sink := sha3.NewSHAKE128(), sha3.NewSHAKE128()
uint16(kdfID), for range 1000 {
uint16(aeadID), aad, plaintext := drawRandomInput(t, source), drawRandomInput(t, source)
priv, ciphertext, err := sender.Seal(aad, plaintext)
info,
encap,
)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
sink.Write(ciphertext)
for _, ctx := range []*context{sender.context, recipient.context} { got, err := recipient.Open(aad, ciphertext)
expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
if !bytes.Equal(ctx.sharedSecret, expectedSharedSecret) {
t.Errorf("unexpected shared secret, got: %x, want %x", ctx.sharedSecret, expectedSharedSecret)
}
expectedKey := mustDecodeHex(t, setup["key"])
if !bytes.Equal(ctx.key, expectedKey) {
t.Errorf("unexpected key, got: %x, want %x", ctx.key, expectedKey)
}
expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
if !bytes.Equal(ctx.baseNonce, expectedBaseNonce) {
t.Errorf("unexpected base nonce, got: %x, want %x", ctx.baseNonce, expectedBaseNonce)
}
expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
if !bytes.Equal(ctx.exporterSecret, expectedExporterSecret) {
t.Errorf("unexpected exporter secret, got: %x, want %x", ctx.exporterSecret, expectedExporterSecret)
}
}
for _, enc := range parseVectorEncryptions(vector.Encryptions) {
t.Run("seq num "+enc["sequence number"], func(t *testing.T) {
seqNum, err := strconv.Atoi(enc["sequence number"])
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
sender.seqNum = uint128{lo: uint64(seqNum)} if !bytes.Equal(got, plaintext) {
recipient.seqNum = uint128{lo: uint64(seqNum)} t.Errorf("unexpected plaintext: got %x want %x", got, plaintext)
expectedNonce := mustDecodeHex(t, enc["nonce"])
computedNonce := sender.nextNonce()
if !bytes.Equal(computedNonce, expectedNonce) {
t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce)
} }
}
encryptions := make([]byte, 16)
sink.Read(encryptions)
expectedEncryptions := mustDecodeHex(t, vector.AccEncryptions)
if !bytes.Equal(encryptions, expectedEncryptions) {
t.Errorf("unexpected accumulated encryptions, got: %x, want %x", encryptions, expectedEncryptions)
}
} else if aead != ExportOnly() {
for _, enc := range vector.Encryptions {
aad := mustDecodeHex(t, enc.Aad)
plaintext := mustDecodeHex(t, enc.Pt)
expectedCiphertext := mustDecodeHex(t, enc.Ct)
expectedCiphertext := mustDecodeHex(t, enc["ct"]) ciphertext, err := sender.Seal(aad, plaintext)
ciphertext, err := sender.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(ciphertext, expectedCiphertext) { if !bytes.Equal(ciphertext, expectedCiphertext) {
t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext) t.Errorf("unexpected ciphertext, got: %x, want %x", ciphertext, expectedCiphertext)
} }
expectedPlaintext := mustDecodeHex(t, enc["pt"]) got, err := recipient.Open(aad, ciphertext)
plaintext, err := recipient.Open(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["ct"]))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(plaintext, expectedPlaintext) { if !bytes.Equal(got, plaintext) {
t.Errorf("unexpected plaintext: got %x want %x", plaintext, expectedPlaintext) t.Errorf("unexpected plaintext: got %x want %x", got, plaintext)
}
}
} else {
if _, err := sender.Seal(nil, nil); err == nil {
t.Error("expected error from Seal with export-only AEAD")
}
if _, err := recipient.Open(nil, nil); err == nil {
t.Error("expected error from Open with export-only AEAD")
}
}
if len(vector.AccExports) != 0 {
source, sink := sha3.NewSHAKE128(), sha3.NewSHAKE128()
for l := range 1000 {
context := string(drawRandomInput(t, source))
value, err := sender.Export(context, l)
if err != nil {
t.Fatal(err)
}
sink.Write(value)
got, err := recipient.Export(context, l)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got, value) {
t.Errorf("recipient: unexpected exported secret: got %x want %x", got, value)
}
}
exports := make([]byte, 16)
sink.Read(exports)
expectedExports := mustDecodeHex(t, vector.AccExports)
if !bytes.Equal(exports, expectedExports) {
t.Errorf("unexpected accumulated exports, got: %x, want %x", exports, expectedExports)
}
} else {
for _, exp := range vector.Exports {
context := string(mustDecodeHex(t, exp.Context))
expectedValue := mustDecodeHex(t, exp.Value)
value, err := sender.Export(context, exp.L)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(value, expectedValue) {
t.Errorf("unexpected exported value, got: %x, want %x", value, expectedValue)
}
got, err := recipient.Export(context, exp.L)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got, value) {
t.Errorf("recipient: unexpected exported secret: got %x want %x", got, value)
}
}
} }
}) })
} }
}
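// drawRandomInput reads one length byte from r and then that many bytes, so
// the accumulated-vector loops above get reproducible, variable-length inputs.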
func drawRandomInput(t *testing.T, r io.Reader) []byte {
t.Helper()
l := make([]byte, 1)
if _, err := r.Read(l); err != nil {
t.Fatal(err)
}
n := int(l[0])
b := make([]byte, n)
if _, err := r.Read(b); err != nil {
t.Fatal(err)
}
return b
}
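// setupDerandomizedEncap points the testing-only hooks (testingOnlyGenerateKey,
// testingOnlyEncapsulate) at the vector-supplied randomness so encapsulation is
// deterministic, and restores them via t.Cleanup.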
func setupDerandomizedEncap(t *testing.T, randBytes []byte, pk PublicKey) {
t.Cleanup(func() {
testingOnlyGenerateKey = nil
testingOnlyEncapsulate = nil
}) })
switch pk.KEM() {
case DHKEM(ecdh.P256()), DHKEM(ecdh.P384()), DHKEM(ecdh.P521()), DHKEM(ecdh.X25519()):
r, err := pk.KEM().DeriveKeyPair(randBytes)
if err != nil {
t.Fatal(err)
}
testingOnlyGenerateKey = func() *ecdh.PrivateKey {
return r.(*dhKEMPrivateKey).priv.(*ecdh.PrivateKey)
}
case mlkem768:
pq := pk.(*mlkemPublicKey).pq.(*mlkem.EncapsulationKey768)
testingOnlyEncapsulate = func() ([]byte, []byte) {
ss, ct, err := mlkemtest.Encapsulate768(pq, randBytes)
if err != nil {
t.Fatal(err)
}
return ss, ct
}
case mlkem1024:
pq := pk.(*mlkemPublicKey).pq.(*mlkem.EncapsulationKey1024)
testingOnlyEncapsulate = func() ([]byte, []byte) {
ss, ct, err := mlkemtest.Encapsulate1024(pq, randBytes)
if err != nil {
t.Fatal(err)
}
return ss, ct
}
case mlkem768X25519:
pqRand, tRand := randBytes[:32], randBytes[32:]
pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey768)
k, err := ecdh.X25519().NewPrivateKey(tRand)
if err != nil {
t.Fatal(err)
}
testingOnlyGenerateKey = func() *ecdh.PrivateKey {
return k
}
testingOnlyEncapsulate = func() ([]byte, []byte) {
ss, ct, err := mlkemtest.Encapsulate768(pq, pqRand)
if err != nil {
t.Fatal(err)
}
return ss, ct
}
case mlkem768P256:
// The rest of randBytes are the following candidates for rejection
// sampling, but they are never reached.
pqRand, tRand := randBytes[:32], randBytes[32:64]
pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey768)
k, err := ecdh.P256().NewPrivateKey(tRand)
if err != nil {
t.Fatal(err)
}
testingOnlyGenerateKey = func() *ecdh.PrivateKey {
return k
}
testingOnlyEncapsulate = func() ([]byte, []byte) {
ss, ct, err := mlkemtest.Encapsulate768(pq, pqRand)
if err != nil {
t.Fatal(err)
}
return ss, ct
}
case mlkem1024P384:
pqRand, tRand := randBytes[:32], randBytes[32:]
pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey1024)
k, err := ecdh.P384().NewPrivateKey(tRand)
if err != nil {
t.Fatal(err)
}
testingOnlyGenerateKey = func() *ecdh.PrivateKey {
return k
}
testingOnlyEncapsulate = func() ([]byte, []byte) {
ss, ct, err := mlkemtest.Encapsulate1024(pq, pqRand)
if err != nil {
t.Fatal(err)
}
return ss, ct
}
default:
t.Fatalf("unsupported KEM %04x", pk.KEM().ID())
}
}
func TestSingletons(t *testing.T) {
if HKDFSHA256() != HKDFSHA256() {
t.Error("HKDFSHA256() != HKDFSHA256()")
}
if HKDFSHA384() != HKDFSHA384() {
t.Error("HKDFSHA384() != HKDFSHA384()")
}
if HKDFSHA512() != HKDFSHA512() {
t.Error("HKDFSHA512() != HKDFSHA512()")
}
if AES128GCM() != AES128GCM() {
t.Error("AES128GCM() != AES128GCM()")
}
if AES256GCM() != AES256GCM() {
t.Error("AES256GCM() != AES256GCM()")
}
if ChaCha20Poly1305() != ChaCha20Poly1305() {
t.Error("ChaCha20Poly1305() != ChaCha20Poly1305()")
}
if ExportOnly() != ExportOnly() {
t.Error("ExportOnly() != ExportOnly()")
}
if DHKEM(ecdh.P256()) != DHKEM(ecdh.P256()) {
t.Error("DHKEM(P-256) != DHKEM(P-256)")
}
if DHKEM(ecdh.P384()) != DHKEM(ecdh.P384()) {
t.Error("DHKEM(P-384) != DHKEM(P-384)")
}
if DHKEM(ecdh.P521()) != DHKEM(ecdh.P521()) {
t.Error("DHKEM(P-521) != DHKEM(P-521)")
}
if DHKEM(ecdh.X25519()) != DHKEM(ecdh.X25519()) {
t.Error("DHKEM(X25519) != DHKEM(X25519)")
}
if MLKEM768() != MLKEM768() {
t.Error("MLKEM768() != MLKEM768()")
}
if MLKEM1024() != MLKEM1024() {
t.Error("MLKEM1024() != MLKEM1024()")
}
if MLKEM768X25519() != MLKEM768X25519() {
t.Error("MLKEM768X25519() != MLKEM768X25519()")
}
if MLKEM768P256() != MLKEM768P256() {
t.Error("MLKEM768P256() != MLKEM768P256()")
}
if MLKEM1024P384() != MLKEM1024P384() {
t.Error("MLKEM1024P384() != MLKEM1024P384()")
} }
} }


@@ -0,0 +1,155 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package hpke
import (
"crypto/hkdf"
"crypto/sha256"
"crypto/sha3"
"crypto/sha512"
"encoding/binary"
"errors"
"fmt"
"hash"
)
// The KDF is one of the three components of an HPKE ciphersuite, implementing
// key derivation.
type KDF interface {
ID() uint16
oneStage() bool
size() int // Nh
labeledDerive(suiteID, inputKey []byte, label string, context []byte, length uint16) ([]byte, error)
labeledExtract(suiteID, salt []byte, label string, inputKey []byte) ([]byte, error)
labeledExpand(suiteID, randomKey []byte, label string, info []byte, length uint16) ([]byte, error)
}
// NewKDF returns the KDF implementation for the given KDF ID.
//
// Applications are encouraged to use specific implementations like [HKDFSHA256]
// instead, unless runtime agility is required.
func NewKDF(id uint16) (KDF, error) {
switch id {
case 0x0001: // HKDF-SHA256
return HKDFSHA256(), nil
case 0x0002: // HKDF-SHA384
return HKDFSHA384(), nil
case 0x0003: // HKDF-SHA512
return HKDFSHA512(), nil
case 0x0010: // SHAKE128
return SHAKE128(), nil
case 0x0011: // SHAKE256
return SHAKE256(), nil
default:
return nil, fmt.Errorf("unsupported KDF %04x", id)
}
}
// HKDFSHA256 returns an HKDF-SHA256 KDF implementation.
func HKDFSHA256() KDF { return hkdfSHA256 }
// HKDFSHA384 returns an HKDF-SHA384 KDF implementation.
func HKDFSHA384() KDF { return hkdfSHA384 }
// HKDFSHA512 returns an HKDF-SHA512 KDF implementation.
func HKDFSHA512() KDF { return hkdfSHA512 }
type hkdfKDF struct {
hash func() hash.Hash
id uint16
nH int
}
var hkdfSHA256 = &hkdfKDF{hash: sha256.New, id: 0x0001, nH: sha256.Size}
var hkdfSHA384 = &hkdfKDF{hash: sha512.New384, id: 0x0002, nH: sha512.Size384}
var hkdfSHA512 = &hkdfKDF{hash: sha512.New, id: 0x0003, nH: sha512.Size}
func (kdf *hkdfKDF) ID() uint16 {
return kdf.id
}
func (kdf *hkdfKDF) size() int {
return kdf.nH
}
func (kdf *hkdfKDF) oneStage() bool {
return false
}
func (kdf *hkdfKDF) labeledDerive(_, _ []byte, _ string, _ []byte, _ uint16) ([]byte, error) {
return nil, errors.New("hpke: internal error: labeledDerive called on two-stage KDF")
}
func (kdf *hkdfKDF) labeledExtract(suiteID []byte, salt []byte, label string, inputKey []byte) ([]byte, error) {
labeledIKM := make([]byte, 0, 7+len(suiteID)+len(label)+len(inputKey))
labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
labeledIKM = append(labeledIKM, suiteID...)
labeledIKM = append(labeledIKM, label...)
labeledIKM = append(labeledIKM, inputKey...)
return hkdf.Extract(kdf.hash, labeledIKM, salt)
}
func (kdf *hkdfKDF) labeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) ([]byte, error) {
labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info))
labeledInfo = binary.BigEndian.AppendUint16(labeledInfo, length)
labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
labeledInfo = append(labeledInfo, suiteID...)
labeledInfo = append(labeledInfo, label...)
labeledInfo = append(labeledInfo, info...)
return hkdf.Expand(kdf.hash, randomKey, string(labeledInfo), int(length))
}
// SHAKE128 returns a SHAKE128 KDF implementation.
func SHAKE128() KDF {
return shake128KDF
}
// SHAKE256 returns a SHAKE256 KDF implementation.
func SHAKE256() KDF {
return shake256KDF
}
type shakeKDF struct {
hash func() *sha3.SHAKE
id uint16
nH int
}
var shake128KDF = &shakeKDF{hash: sha3.NewSHAKE128, id: 0x0010, nH: 32}
var shake256KDF = &shakeKDF{hash: sha3.NewSHAKE256, id: 0x0011, nH: 64}
func (kdf *shakeKDF) ID() uint16 {
return kdf.id
}
func (kdf *shakeKDF) size() int {
return kdf.nH
}
func (kdf *shakeKDF) oneStage() bool {
return true
}
func (kdf *shakeKDF) labeledDerive(suiteID, inputKey []byte, label string, context []byte, length uint16) ([]byte, error) {
H := kdf.hash()
H.Write(inputKey)
H.Write([]byte("HPKE-v1"))
H.Write(suiteID)
H.Write([]byte{byte(len(label) >> 8), byte(len(label))})
H.Write([]byte(label))
H.Write([]byte{byte(length >> 8), byte(length)})
H.Write(context)
out := make([]byte, length)
H.Read(out)
return out, nil
}
func (kdf *shakeKDF) labeledExtract(_, _ []byte, _ string, _ []byte) ([]byte, error) {
return nil, errors.New("hpke: internal error: labeledExtract called on one-stage KDF")
}
func (kdf *shakeKDF) labeledExpand(_, _ []byte, _ string, _ []byte, _ uint16) ([]byte, error) {
return nil, errors.New("hpke: internal error: labeledExpand called on one-stage KDF")
}
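Each KDF reports its output size Nh, which newContext uses to size the exporter secret. A quick sketch, not part of this change, assuming it sits next to the package hpke tests (size is unexported, so the file must live in the same package):

package hpke

import "fmt"

func ExampleNewKDF() {
	// Sketch only, not part of the CL.
	for _, id := range []uint16{0x0001, 0x0002, 0x0003, 0x0010, 0x0011} {
		kdf, err := NewKDF(id)
		if err != nil {
			panic(err)
		}
		fmt.Printf("%04x Nh=%d\n", kdf.ID(), kdf.size())
	}
	// Output:
	// 0001 Nh=32
	// 0002 Nh=48
	// 0003 Nh=64
	// 0010 Nh=32
	// 0011 Nh=64
}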


@@ -0,0 +1,382 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package hpke
import (
"crypto/ecdh"
"crypto/rand"
"encoding/binary"
"errors"
)
// A KEM is a Key Encapsulation Mechanism, one of the three components of an
// HPKE ciphersuite.
type KEM interface {
// ID returns the HPKE KEM identifier.
ID() uint16
// GenerateKey generates a new key pair.
GenerateKey() (PrivateKey, error)
// NewPublicKey deserializes a public key from bytes.
//
// It implements DeserializePublicKey, as defined in RFC 9180.
NewPublicKey([]byte) (PublicKey, error)
// NewPrivateKey deserializes a private key from bytes.
//
// It implements DeserializePrivateKey, as defined in RFC 9180.
NewPrivateKey([]byte) (PrivateKey, error)
// DeriveKeyPair derives a key pair from the given input keying material.
//
// It implements DeriveKeyPair, as defined in RFC 9180.
DeriveKeyPair(ikm []byte) (PrivateKey, error)
encSize() int
}
// NewKEM returns the KEM implementation for the given KEM ID.
//
// Applications are encouraged to use specific implementations like [DHKEM] or
// [MLKEM768X25519] instead, unless runtime agility is required.
func NewKEM(id uint16) (KEM, error) {
switch id {
case 0x0010: // DHKEM(P-256, HKDF-SHA256)
return DHKEM(ecdh.P256()), nil
case 0x0011: // DHKEM(P-384, HKDF-SHA384)
return DHKEM(ecdh.P384()), nil
case 0x0012: // DHKEM(P-521, HKDF-SHA512)
return DHKEM(ecdh.P521()), nil
case 0x0020: // DHKEM(X25519, HKDF-SHA256)
return DHKEM(ecdh.X25519()), nil
case 0x0041: // ML-KEM-768
return MLKEM768(), nil
case 0x0042: // ML-KEM-1024
return MLKEM1024(), nil
case 0x647a: // MLKEM768-X25519
return MLKEM768X25519(), nil
case 0x0050: // MLKEM768-P256
return MLKEM768P256(), nil
case 0x0051: // MLKEM1024-P384
return MLKEM1024P384(), nil
default:
return nil, errors.New("unsupported KEM")
}
}
// A PublicKey is an instantiation of a KEM (one of the three components of an
// HPKE ciphersuite) with an encapsulation key (i.e. the public key).
//
// A PublicKey is usually obtained from a method of the corresponding [KEM] or
// [PrivateKey], such as [KEM.NewPublicKey] or [PrivateKey.PublicKey].
type PublicKey interface {
// KEM returns the instantiated KEM.
KEM() KEM
// Bytes returns the public key as the output of SerializePublicKey.
Bytes() []byte
encap() (sharedSecret, enc []byte, err error)
}
// A PrivateKey is an instantiation of a KEM (one of the three components of
// an HPKE ciphersuite) with a decapsulation key (i.e. the secret key).
//
// A PrivateKey is usually obtained from a method of the corresponding [KEM],
// such as [KEM.GenerateKey] or [KEM.NewPrivateKey].
type PrivateKey interface {
// KEM returns the instantiated KEM.
KEM() KEM
// Bytes returns the private key as the output of SerializePrivateKey, as
// defined in RFC 9180.
//
// Note that for X25519 this might not match the input to NewPrivateKey.
// This is a requirement of RFC 9180, Section 7.1.2.
Bytes() ([]byte, error)
// PublicKey returns the corresponding PublicKey.
PublicKey() PublicKey
decap(enc []byte) (sharedSecret []byte, err error)
}
type dhKEM struct {
kdf KDF
id uint16
curve ecdh.Curve
Nsecret uint16
Nsk uint16
Nenc int
}
func (kem *dhKEM) extractAndExpand(dhKey, kemContext []byte) ([]byte, error) {
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), kem.id)
eaePRK, err := kem.kdf.labeledExtract(suiteID, nil, "eae_prk", dhKey)
if err != nil {
return nil, err
}
return kem.kdf.labeledExpand(suiteID, eaePRK, "shared_secret", kemContext, kem.Nsecret)
}
func (kem *dhKEM) ID() uint16 {
return kem.id
}
func (kem *dhKEM) encSize() int {
return kem.Nenc
}
var dhKEMP256 = &dhKEM{HKDFSHA256(), 0x0010, ecdh.P256(), 32, 32, 65}
var dhKEMP384 = &dhKEM{HKDFSHA384(), 0x0011, ecdh.P384(), 48, 48, 97}
var dhKEMP521 = &dhKEM{HKDFSHA512(), 0x0012, ecdh.P521(), 64, 66, 133}
var dhKEMX25519 = &dhKEM{HKDFSHA256(), 0x0020, ecdh.X25519(), 32, 32, 32}
// DHKEM returns a KEM implementing one of
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on curve.
func DHKEM(curve ecdh.Curve) KEM {
switch curve {
case ecdh.P256():
return dhKEMP256
case ecdh.P384():
return dhKEMP384
case ecdh.P521():
return dhKEMP521
case ecdh.X25519():
return dhKEMX25519
default:
// The set of ecdh.Curve implementations is closed, because the
// interface has unexported methods. Therefore, this default case is
// only hit if a new curve is added that DHKEM doesn't support.
return unsupportedCurveKEM{}
}
}
type unsupportedCurveKEM struct{}
func (unsupportedCurveKEM) ID() uint16 {
return 0
}
func (unsupportedCurveKEM) GenerateKey() (PrivateKey, error) {
return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) NewPublicKey([]byte) (PublicKey, error) {
return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) NewPrivateKey([]byte) (PrivateKey, error) {
return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) DeriveKeyPair([]byte) (PrivateKey, error) {
return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) encSize() int {
return 0
}
type dhKEMPublicKey struct {
kem *dhKEM
pub *ecdh.PublicKey
}
// NewDHKEMPublicKey returns a PublicKey implementing
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on the underlying curve of pub ([ecdh.X25519], [ecdh.P256],
// [ecdh.P384], or [ecdh.P521]).
//
// This function is meant for applications that already have an instantiated
// crypto/ecdh public key. Otherwise, applications should use the
// [KEM.NewPublicKey] method of [DHKEM].
func NewDHKEMPublicKey(pub *ecdh.PublicKey) (PublicKey, error) {
kem, ok := DHKEM(pub.Curve()).(*dhKEM)
if !ok {
return nil, errors.New("unsupported curve")
}
return &dhKEMPublicKey{
kem: kem,
pub: pub,
}, nil
}
func (kem *dhKEM) NewPublicKey(data []byte) (PublicKey, error) {
pub, err := kem.curve.NewPublicKey(data)
if err != nil {
return nil, err
}
return NewDHKEMPublicKey(pub)
}
func (pk *dhKEMPublicKey) KEM() KEM {
return pk.kem
}
func (pk *dhKEMPublicKey) Bytes() []byte {
return pk.pub.Bytes()
}
// testingOnlyGenerateKey is only used during testing, to provide
// a fixed test key to use when checking the RFC 9180 vectors.
var testingOnlyGenerateKey func() *ecdh.PrivateKey
func (pk *dhKEMPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
privEph, err := pk.pub.Curve().GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
if testingOnlyGenerateKey != nil {
privEph = testingOnlyGenerateKey()
}
dhVal, err := privEph.ECDH(pk.pub)
if err != nil {
return nil, nil, err
}
encPubEph := privEph.PublicKey().Bytes()
encPubRecip := pk.pub.Bytes()
kemContext := append(encPubEph, encPubRecip...)
sharedSecret, err = pk.kem.extractAndExpand(dhVal, kemContext)
if err != nil {
return nil, nil, err
}
return sharedSecret, encPubEph, nil
}
type dhKEMPrivateKey struct {
kem *dhKEM
priv ecdh.KeyExchanger
}
// NewDHKEMPrivateKey returns a PrivateKey implementing
//
// - DHKEM(P-256, HKDF-SHA256)
// - DHKEM(P-384, HKDF-SHA384)
// - DHKEM(P-521, HKDF-SHA512)
// - DHKEM(X25519, HKDF-SHA256)
//
// depending on the underlying curve of priv ([ecdh.X25519], [ecdh.P256],
// [ecdh.P384], or [ecdh.P521]).
//
// This function is meant for applications that already have an instantiated
// crypto/ecdh private key, or another implementation of an [ecdh.KeyExchanger]
// (e.g. a hardware key). Otherwise, applications should use the
// [KEM.NewPrivateKey] method of [DHKEM].
func NewDHKEMPrivateKey(priv ecdh.KeyExchanger) (PrivateKey, error) {
kem, ok := DHKEM(priv.Curve()).(*dhKEM)
if !ok {
return nil, errors.New("unsupported curve")
}
return &dhKEMPrivateKey{
kem: kem,
priv: priv,
}, nil
}
func (kem *dhKEM) GenerateKey() (PrivateKey, error) {
priv, err := kem.curve.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
return NewDHKEMPrivateKey(priv)
}
func (kem *dhKEM) NewPrivateKey(ikm []byte) (PrivateKey, error) {
priv, err := kem.curve.NewPrivateKey(ikm)
if err != nil {
return nil, err
}
return NewDHKEMPrivateKey(priv)
}
func (kem *dhKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
// DeriveKeyPair from RFC 9180 Section 7.1.3.
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), kem.id)
prk, err := kem.kdf.labeledExtract(suiteID, nil, "dkp_prk", ikm)
if err != nil {
return nil, err
}
if kem == dhKEMX25519 {
s, err := kem.kdf.labeledExpand(suiteID, prk, "sk", nil, kem.Nsk)
if err != nil {
return nil, err
}
return kem.NewPrivateKey(s)
}
var counter uint8
for counter < 4 {
s, err := kem.kdf.labeledExpand(suiteID, prk, "candidate", []byte{counter}, kem.Nsk)
if err != nil {
return nil, err
}
if kem == dhKEMP521 {
s[0] &= 0x01
}
r, err := kem.NewPrivateKey(s)
if err != nil {
counter++
continue
}
return r, nil
}
panic("chance of four rejections is < 2^-128")
}
func (k *dhKEMPrivateKey) KEM() KEM {
return k.kem
}
func (k *dhKEMPrivateKey) Bytes() ([]byte, error) {
// Bizarrely, RFC 9180, Section 7.1.2 says SerializePrivateKey MUST clamp
// the output, which I thought we all agreed to instead do as part of the DH
// function, letting private keys be random bytes.
//
// At the same time, it says DeserializePrivateKey MUST also clamp, implying
// that the input doesn't have to be clamped, so Bytes by spec doesn't
// necessarily match the NewPrivateKey input.
//
// I'm sure this will not lead to any unexpected behavior or interop issue.
priv, ok := k.priv.(*ecdh.PrivateKey)
if !ok {
return nil, errors.New("ecdh: private key does not support Bytes")
}
if k.kem == dhKEMX25519 {
b := priv.Bytes()
b[0] &= 248
b[31] &= 127
b[31] |= 64
return b, nil
}
return priv.Bytes(), nil
}
func (k *dhKEMPrivateKey) PublicKey() PublicKey {
return &dhKEMPublicKey{
kem: k.kem,
pub: k.priv.PublicKey(),
}
}
func (k *dhKEMPrivateKey) decap(encPubEph []byte) ([]byte, error) {
pubEph, err := k.priv.Curve().NewPublicKey(encPubEph)
if err != nil {
return nil, err
}
dhVal, err := k.priv.ECDH(pubEph)
if err != nil {
return nil, err
}
kemContext := append(encPubEph, k.priv.PublicKey().Bytes()...)
return k.kem.extractAndExpand(dhVal, kemContext)
}
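DeriveKeyPair is deterministic: the same input keying material always yields the same key pair (RFC 9180, Section 7.1.3), with rejection sampling for the NIST curves as shown above. A small sketch, not part of this change, assuming it sits next to the package hpke tests:

package hpke

import (
	"bytes"
	"crypto/ecdh"
	"fmt"
)

func ExampleKEM_DeriveKeyPair() {
	// Sketch only, not part of the CL. Use a real secret as ikm in practice.
	kem := DHKEM(ecdh.P256())
	ikm := make([]byte, 32)
	k1, err := kem.DeriveKeyPair(ikm)
	if err != nil {
		panic(err)
	}
	k2, err := kem.DeriveKeyPair(ikm)
	if err != nil {
		panic(err)
	}
	// The same ikm deterministically yields the same key pair.
	fmt.Println(bytes.Equal(k1.PublicKey().Bytes(), k2.PublicKey().Bytes()))
	// Output: true
}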


@@ -0,0 +1,530 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package hpke
import (
"bytes"
"crypto"
"crypto/ecdh"
"crypto/mlkem"
"crypto/rand"
"crypto/sha3"
"encoding/binary"
"errors"
)
var mlkem768X25519 = &hybridKEM{
id: 0x647a,
label: /**/ `\./` +
/* */ `/^\`,
curve: ecdh.X25519(),
curveSeedSize: 32,
curvePointSize: 32,
pqEncapsKeySize: mlkem.EncapsulationKeySize768,
pqCiphertextSize: mlkem.CiphertextSize768,
pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey768(data)
},
pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey768(data)
},
pqGenerateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey768()
},
}
// MLKEM768X25519 returns a KEM implementing MLKEM768-X25519 (a.k.a. X-Wing)
// from draft-ietf-hpke-pq.
func MLKEM768X25519() KEM {
return mlkem768X25519
}
var mlkem768P256 = &hybridKEM{
id: 0x0050,
label: "MLKEM768-P256",
curve: ecdh.P256(),
curveSeedSize: 32,
curvePointSize: 65,
pqEncapsKeySize: mlkem.EncapsulationKeySize768,
pqCiphertextSize: mlkem.CiphertextSize768,
pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey768(data)
},
pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey768(data)
},
pqGenerateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey768()
},
}
// MLKEM768P256 returns a KEM implementing MLKEM768-P256 from draft-ietf-hpke-pq.
func MLKEM768P256() KEM {
return mlkem768P256
}
var mlkem1024P384 = &hybridKEM{
id: 0x0051,
label: "MLKEM1024-P384",
curve: ecdh.P384(),
curveSeedSize: 48,
curvePointSize: 97,
pqEncapsKeySize: mlkem.EncapsulationKeySize1024,
pqCiphertextSize: mlkem.CiphertextSize1024,
pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey1024(data)
},
pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey1024(data)
},
pqGenerateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey1024()
},
}
// MLKEM1024P384 returns a KEM implementing MLKEM1024-P384 from draft-ietf-hpke-pq.
func MLKEM1024P384() KEM {
return mlkem1024P384
}
type hybridKEM struct {
id uint16
label string
curve ecdh.Curve
curveSeedSize int
curvePointSize int
pqEncapsKeySize int
pqCiphertextSize int
pqNewPublicKey func(data []byte) (crypto.Encapsulator, error)
pqNewPrivateKey func(data []byte) (crypto.Decapsulator, error)
pqGenerateKey func() (crypto.Decapsulator, error)
}
func (kem *hybridKEM) ID() uint16 {
return kem.id
}
func (kem *hybridKEM) encSize() int {
return kem.pqCiphertextSize + kem.curvePointSize
}
func (kem *hybridKEM) sharedSecret(ssPQ, ssT, ctT, ekT []byte) []byte {
h := sha3.New256()
h.Write(ssPQ)
h.Write(ssT)
h.Write(ctT)
h.Write(ekT)
h.Write([]byte(kem.label))
return h.Sum(nil)
}
type hybridPublicKey struct {
kem *hybridKEM
t *ecdh.PublicKey
pq crypto.Encapsulator
}
// NewHybridPublicKey returns a PublicKey implementing one of
//
// - MLKEM768-X25519 (a.k.a. X-Wing)
// - MLKEM768-P256
// - MLKEM1024-P384
//
// from draft-ietf-hpke-pq, depending on the underlying curve of t
// ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq (either
// *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
//
// This function is meant for applications that already have instantiated
// crypto/ecdh and crypto/mlkem public keys. Otherwise, applications should use
// the [KEM.NewPublicKey] method of e.g. [MLKEM768X25519].
func NewHybridPublicKey(pq crypto.Encapsulator, t *ecdh.PublicKey) (PublicKey, error) {
switch t.Curve() {
case ecdh.X25519():
if _, ok := pq.(*mlkem.EncapsulationKey768); !ok {
return nil, errors.New("invalid PQ KEM for X25519 hybrid")
}
return &hybridPublicKey{mlkem768X25519, t, pq}, nil
case ecdh.P256():
if _, ok := pq.(*mlkem.EncapsulationKey768); !ok {
return nil, errors.New("invalid PQ KEM for P-256 hybrid")
}
return &hybridPublicKey{mlkem768P256, t, pq}, nil
case ecdh.P384():
if _, ok := pq.(*mlkem.EncapsulationKey1024); !ok {
return nil, errors.New("invalid PQ KEM for P-384 hybrid")
}
return &hybridPublicKey{mlkem1024P384, t, pq}, nil
default:
return nil, errors.New("unsupported curve")
}
}
func (kem *hybridKEM) NewPublicKey(data []byte) (PublicKey, error) {
if len(data) != kem.pqEncapsKeySize+kem.curvePointSize {
return nil, errors.New("invalid public key size")
}
pq, err := kem.pqNewPublicKey(data[:kem.pqEncapsKeySize])
if err != nil {
return nil, err
}
k, err := kem.curve.NewPublicKey(data[kem.pqEncapsKeySize:])
if err != nil {
return nil, err
}
return NewHybridPublicKey(pq, k)
}
func (pk *hybridPublicKey) KEM() KEM {
return pk.kem
}
func (pk *hybridPublicKey) Bytes() []byte {
return append(pk.pq.Bytes(), pk.t.Bytes()...)
}
var testingOnlyEncapsulate func() (ss, ct []byte)
func (pk *hybridPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
skE, err := pk.t.Curve().GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}
if testingOnlyGenerateKey != nil {
skE = testingOnlyGenerateKey()
}
ssT, err := skE.ECDH(pk.t)
if err != nil {
return nil, nil, err
}
ctT := skE.PublicKey().Bytes()
ssPQ, ctPQ := pk.pq.Encapsulate()
if testingOnlyEncapsulate != nil {
ssPQ, ctPQ = testingOnlyEncapsulate()
}
ss := pk.kem.sharedSecret(ssPQ, ssT, ctT, pk.t.Bytes())
ct := append(ctPQ, ctT...)
return ss, ct, nil
}
type hybridPrivateKey struct {
kem *hybridKEM
seed []byte // can be nil
t ecdh.KeyExchanger
pq crypto.Decapsulator
}
// NewHybridPrivateKey returns a PrivateKey implementing one of
//
// - MLKEM768-X25519 (a.k.a. X-Wing)
// - MLKEM768-P256
// - MLKEM1024-P384
//
// from draft-ietf-hpke-pq, depending on the underlying curve of t
// ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq.Encapsulator()
// (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
//
// This function is meant for applications that already have instantiated
// crypto/ecdh and crypto/mlkem private keys, or another implementation of a
// [ecdh.KeyExchanger] and [crypto.Decapsulator] (e.g. a hardware key).
// Otherwise, applications should use the [KEM.NewPrivateKey] method of e.g.
// [MLKEM768X25519].
func NewHybridPrivateKey(pq crypto.Decapsulator, t ecdh.KeyExchanger) (PrivateKey, error) {
return newHybridPrivateKey(pq, t, nil)
}
func (kem *hybridKEM) GenerateKey() (PrivateKey, error) {
seed := make([]byte, 32)
rand.Read(seed)
return kem.NewPrivateKey(seed)
}
func (kem *hybridKEM) NewPrivateKey(priv []byte) (PrivateKey, error) {
if len(priv) != 32 {
return nil, errors.New("hpke: invalid hybrid KEM secret length")
}
s := sha3.NewSHAKE256()
s.Write(priv)
seedPQ := make([]byte, mlkem.SeedSize)
s.Read(seedPQ)
pq, err := kem.pqNewPrivateKey(seedPQ)
if err != nil {
return nil, err
}
seedT := make([]byte, kem.curveSeedSize)
for {
s.Read(seedT)
k, err := kem.curve.NewPrivateKey(seedT)
if err != nil {
continue
}
return newHybridPrivateKey(pq, k, priv)
}
}
func newHybridPrivateKey(pq crypto.Decapsulator, t ecdh.KeyExchanger, seed []byte) (PrivateKey, error) {
switch t.Curve() {
case ecdh.X25519():
if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok {
return nil, errors.New("invalid PQ KEM for X25519 hybrid")
}
return &hybridPrivateKey{mlkem768X25519, bytes.Clone(seed), t, pq}, nil
case ecdh.P256():
if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok {
return nil, errors.New("invalid PQ KEM for P-256 hybrid")
}
return &hybridPrivateKey{mlkem768P256, bytes.Clone(seed), t, pq}, nil
case ecdh.P384():
if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey1024); !ok {
return nil, errors.New("invalid PQ KEM for P-384 hybrid")
}
return &hybridPrivateKey{mlkem1024P384, bytes.Clone(seed), t, pq}, nil
default:
return nil, errors.New("unsupported curve")
}
}
func (kem *hybridKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), kem.id)
dk, err := SHAKE256().labeledDerive(suiteID, ikm, "DeriveKeyPair", nil, 32)
if err != nil {
return nil, err
}
return kem.NewPrivateKey(dk)
}
func (k *hybridPrivateKey) KEM() KEM {
return k.kem
}
func (k *hybridPrivateKey) Bytes() ([]byte, error) {
if k.seed == nil {
return nil, errors.New("private key seed not available")
}
return k.seed, nil
}
func (k *hybridPrivateKey) PublicKey() PublicKey {
return &hybridPublicKey{
kem: k.kem,
t: k.t.PublicKey(),
pq: k.pq.Encapsulator(),
}
}
func (k *hybridPrivateKey) decap(enc []byte) ([]byte, error) {
if len(enc) != k.kem.pqCiphertextSize+k.kem.curvePointSize {
return nil, errors.New("invalid encapsulated key size")
}
ctPQ, ctT := enc[:k.kem.pqCiphertextSize], enc[k.kem.pqCiphertextSize:]
ssPQ, err := k.pq.Decapsulate(ctPQ)
if err != nil {
return nil, err
}
pub, err := k.t.Curve().NewPublicKey(ctT)
if err != nil {
return nil, err
}
ssT, err := k.t.ECDH(pub)
if err != nil {
return nil, err
}
ss := k.kem.sharedSecret(ssPQ, ssT, ctT, k.t.PublicKey().Bytes())
return ss, nil
}
var mlkem768 = &mlkemKEM{
id: 0x0041,
ciphertextSize: mlkem.CiphertextSize768,
newPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey768(data)
},
newPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey768(data)
},
generateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey768()
},
}
// MLKEM768 returns a KEM implementing ML-KEM-768 from draft-ietf-hpke-pq.
func MLKEM768() KEM {
return mlkem768
}
var mlkem1024 = &mlkemKEM{
id: 0x0042,
ciphertextSize: mlkem.CiphertextSize1024,
newPublicKey: func(data []byte) (crypto.Encapsulator, error) {
return mlkem.NewEncapsulationKey1024(data)
},
newPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
return mlkem.NewDecapsulationKey1024(data)
},
generateKey: func() (crypto.Decapsulator, error) {
return mlkem.GenerateKey1024()
},
}
// MLKEM1024 returns a KEM implementing ML-KEM-1024 from draft-ietf-hpke-pq.
func MLKEM1024() KEM {
return mlkem1024
}
type mlkemKEM struct {
id uint16
ciphertextSize int
newPublicKey func(data []byte) (crypto.Encapsulator, error)
newPrivateKey func(data []byte) (crypto.Decapsulator, error)
generateKey func() (crypto.Decapsulator, error)
}
func (kem *mlkemKEM) ID() uint16 {
return kem.id
}
func (kem *mlkemKEM) encSize() int {
return kem.ciphertextSize
}
type mlkemPublicKey struct {
kem *mlkemKEM
pq crypto.Encapsulator
}
// NewMLKEMPublicKey returns a PublicKey implementing one of
//
// - ML-KEM-768
// - ML-KEM-1024
//
// from draft-ietf-hpke-pq, depending on the type of pub
// (*[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
//
// This function is meant for applications that already have an instantiated
// crypto/mlkem public key. Otherwise, applications should use the
// [KEM.NewPublicKey] method of e.g. [MLKEM768].
func NewMLKEMPublicKey(pub crypto.Encapsulator) (PublicKey, error) {
switch pub.(type) {
case *mlkem.EncapsulationKey768:
return &mlkemPublicKey{mlkem768, pub}, nil
case *mlkem.EncapsulationKey1024:
return &mlkemPublicKey{mlkem1024, pub}, nil
default:
return nil, errors.New("unsupported public key type")
}
}
func (kem *mlkemKEM) NewPublicKey(data []byte) (PublicKey, error) {
pq, err := kem.newPublicKey(data)
if err != nil {
return nil, err
}
return NewMLKEMPublicKey(pq)
}
func (pk *mlkemPublicKey) KEM() KEM {
return pk.kem
}
func (pk *mlkemPublicKey) Bytes() []byte {
return pk.pq.Bytes()
}
func (pk *mlkemPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
ss, ct := pk.pq.Encapsulate()
if testingOnlyEncapsulate != nil {
ss, ct = testingOnlyEncapsulate()
}
return ss, ct, nil
}
type mlkemPrivateKey struct {
kem *mlkemKEM
pq crypto.Decapsulator
}
// NewMLKEMPrivateKey returns a PrivateKey implementing one of
//
// - ML-KEM-768
// - ML-KEM-1024
//
// from draft-ietf-hpke-pq, depending on the type of priv.Encapsulator()
// (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
//
// This function is meant for applications that already have an instantiated
// crypto/mlkem private key. Otherwise, applications should use the
// [KEM.NewPrivateKey] method of e.g. [MLKEM768].
func NewMLKEMPrivateKey(priv crypto.Decapsulator) (PrivateKey, error) {
switch priv.Encapsulator().(type) {
case *mlkem.EncapsulationKey768:
return &mlkemPrivateKey{mlkem768, priv}, nil
case *mlkem.EncapsulationKey1024:
return &mlkemPrivateKey{mlkem1024, priv}, nil
default:
return nil, errors.New("unsupported public key type")
}
}
func (kem *mlkemKEM) GenerateKey() (PrivateKey, error) {
pq, err := kem.generateKey()
if err != nil {
return nil, err
}
return NewMLKEMPrivateKey(pq)
}
func (kem *mlkemKEM) NewPrivateKey(priv []byte) (PrivateKey, error) {
pq, err := kem.newPrivateKey(priv)
if err != nil {
return nil, err
}
return NewMLKEMPrivateKey(pq)
}
func (kem *mlkemKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), kem.id)
dk, err := SHAKE256().labeledDerive(suiteID, ikm, "DeriveKeyPair", nil, 64)
if err != nil {
return nil, err
}
return kem.NewPrivateKey(dk)
}
func (k *mlkemPrivateKey) KEM() KEM {
return k.kem
}
func (k *mlkemPrivateKey) Bytes() ([]byte, error) {
pq, ok := k.pq.(interface {
Bytes() []byte
})
if !ok {
return nil, errors.New("private key seed not available")
}
return pq.Bytes(), nil
}
func (k *mlkemPrivateKey) PublicKey() PublicKey {
return &mlkemPublicKey{
kem: k.kem,
pq: k.pq.Encapsulator(),
}
}
func (k *mlkemPrivateKey) decap(enc []byte) ([]byte, error) {
return k.pq.Decapsulate(enc)
}
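
The file above only defines the KEM side of the new hybrid and ML-KEM suites. For orientation, here is a hedged sketch of how a test inside this package might drive one of them end to end; it relies on NewKDF, NewAEAD, NewSender, and NewRecipient as they appear in the crypto/tls changes later in this diff, and it assumes the sender and recipient contexts keep their existing Seal(aad, msg) and Open(aad, ct) methods. Treat it as illustrative only, not as part of this CL.

package hpke

import (
	"bytes"
	"testing"
)

// TestHybridRoundTripSketch is a hypothetical example, not part of this CL.
func TestHybridRoundTripSketch(t *testing.T) {
	kem := MLKEM768X25519()
	priv, err := kem.GenerateKey()
	if err != nil {
		t.Fatal(err)
	}
	kdf, err := NewKDF(0x0001) // HKDF-SHA256
	if err != nil {
		t.Fatal(err)
	}
	aead, err := NewAEAD(0x0001) // AES-128-GCM
	if err != nil {
		t.Fatal(err)
	}
	info := []byte("example application info")

	// Sender side: encapsulate to the recipient's public key and seal.
	enc, sender, err := NewSender(priv.PublicKey(), kdf, aead, info)
	if err != nil {
		t.Fatal(err)
	}
	ct, err := sender.Seal(nil, []byte("hello")) // assumed signature: Seal(aad, plaintext)
	if err != nil {
		t.Fatal(err)
	}

	// Recipient side: decapsulate with the private key and open.
	recipient, err := NewRecipient(enc, priv, kdf, aead, info)
	if err != nil {
		t.Fatal(err)
	}
	pt, err := recipient.Open(nil, ct) // assumed signature: Open(aad, ciphertext)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(pt, []byte("hello")) {
		t.Fatal("hybrid HPKE round trip failed")
	}
}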

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,332 @@
[
{
"mode": 0,
"kem_id": 32,
"kdf_id": 1,
"aead_id": 1,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "7268600d403fce431561aef583ee1613527cff655c1343f29812e66706df3234",
"ikmR": "6db9df30aa07dd42ee5e8181afdb977e538f5e1fec8a06223f33f7013e525037",
"skRm": "4612c550263fc8ad58375df3f557aac531d26850903e55a9f23f21d8534e8ac8",
"pkRm": "3948cfe0ad1ddb695d780e59077195da6c56506b027329794ab02bca80815c4d",
"enc": "37fda3567bdbd628e88668c3c8d7e97d1d1253b6d4ea6d44c150f741f1bf4431",
"encryptions_accumulated": "dcabb32ad8e8acea785275323395abd0",
"exports_accumulated": "45db490fc51c86ba46cca1217f66a75e"
},
{
"mode": 0,
"kem_id": 32,
"kdf_id": 1,
"aead_id": 2,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "2cd7c601cefb3d42a62b04b7a9041494c06c7843818e0ce28a8f704ae7ab20f9",
"ikmR": "dac33b0e9db1b59dbbea58d59a14e7b5896e9bdf98fad6891e99d1686492b9ee",
"skRm": "497b4502664cfea5d5af0b39934dac72242a74f8480451e1aee7d6a53320333d",
"pkRm": "430f4b9859665145a6b1ba274024487bd66f03a2dd577d7753c68d7d7d00c00c",
"enc": "6c93e09869df3402d7bf231bf540fadd35cd56be14f97178f0954db94b7fc256",
"encryptions_accumulated": "1702e73e1e71705faa8241022af1deea",
"exports_accumulated": "5cb678bf1c52afbd9afb58b8f7c1ced3"
},
{
"mode": 0,
"kem_id": 32,
"kdf_id": 1,
"aead_id": 3,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "909a9b35d3dc4713a5e72a4da274b55d3d3821a37e5d099e74a647db583a904b",
"ikmR": "1ac01f181fdf9f352797655161c58b75c656a6cc2716dcb66372da835542e1df",
"skRm": "8057991eef8f1f1af18f4a9491d16a1ce333f695d4db8e38da75975c4478e0fb",
"pkRm": "4310ee97d88cc1f088a5576c77ab0cf5c3ac797f3d95139c6c84b5429c59662a",
"enc": "1afa08d3dec047a643885163f1180476fa7ddb54c6a8029ea33f95796bf2ac4a",
"encryptions_accumulated": "225fb3d35da3bb25e4371bcee4273502",
"exports_accumulated": "54e2189c04100b583c84452f94eb9a4a"
},
{
"mode": 0,
"kem_id": 32,
"kdf_id": 1,
"aead_id": 65535,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "55bc245ee4efda25d38f2d54d5bb6665291b99f8108a8c4b686c2b14893ea5d9",
"ikmR": "683ae0da1d22181e74ed2e503ebf82840deb1d5e872cade20f4b458d99783e31",
"skRm": "33d196c830a12f9ac65d6e565a590d80f04ee9b19c83c87f2c170d972a812848",
"pkRm": "194141ca6c3c3beb4792cd97ba0ea1faff09d98435012345766ee33aae2d7664",
"enc": "e5e8f9bfff6c2f29791fc351d2c25ce1299aa5eaca78a757c0b4fb4bcd830918",
"exports_accumulated": "3fe376e3f9c349bc5eae67bbce867a16"
},
{
"mode": 0,
"kem_id": 32,
"kdf_id": 3,
"aead_id": 1,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "895221ae20f39cbf46871d6ea162d44b84dd7ba9cc7a3c80f16d6ea4242cd6d4",
"ikmR": "59a9b44375a297d452fc18e5bba1a64dec709f23109486fce2d3a5428ed2000a",
"skRm": "ddfbb71d7ea8ebd98fa9cc211aa7b535d258fe9ab4a08bc9896af270e35aad35",
"pkRm": "adf16c696b87995879b27d470d37212f38a58bfe7f84e6d50db638b8f2c22340",
"enc": "8998da4c3d6ade83c53e861a022c046db909f1c31107196ab4c2f4dd37e1a949",
"encryptions_accumulated": "19a0d0fb001f83e7606948507842f913",
"exports_accumulated": "e5d853af841b92602804e7a40c1f2487"
},
{
"mode": 0,
"kem_id": 32,
"kdf_id": 3,
"aead_id": 2,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "e72b39232ee9ef9f6537a72afe28f551dbe632006aa1b300a00518883a3f2dc1",
"ikmR": "a0484936abc95d587acf7034156229f9970e9dfa76773754e40fb30e53c9de16",
"skRm": "bdd8943c1e60191f3ea4e69fc4f322aa1086db9650f1f952fdce88395a4bd1af",
"pkRm": "aa7bddcf5ca0b2c0cf760b5dffc62740a8e761ec572032a809bebc87aaf7575e",
"enc": "c12ba9fb91d7ebb03057d8bea4398688dcc1d1d1ff3b97f09b96b9bf89bd1e4a",
"encryptions_accumulated": "20402e520fdbfee76b2b0af73d810deb",
"exports_accumulated": "80b7f603f0966ca059dd5e8a7cede735"
},
{
"mode": 0,
"kem_id": 32,
"kdf_id": 3,
"aead_id": 3,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "636d1237a5ae674c24caa0c32a980d3218d84f916ba31e16699892d27103a2a9",
"ikmR": "969bb169aa9c24a501ee9d962e96c310226d427fb6eb3fc579d9882dbc708315",
"skRm": "fad15f488c09c167bd18d8f48f282e30d944d624c5676742ad820119de44ea91",
"pkRm": "06aa193a5612d89a1935c33f1fda3109fcdf4b867da4c4507879f184340b0e0e",
"enc": "1d38fc578d4209ea0ef3ee5f1128ac4876a9549d74dc2d2f46e75942a6188244",
"encryptions_accumulated": "c03e64ef58b22065f04be776d77e160c",
"exports_accumulated": "fa84b4458d580b5069a1be60b4785eac"
},
{
"mode": 0,
"kem_id": 32,
"kdf_id": 3,
"aead_id": 65535,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "3cfbc97dece2c497126df8909efbdd3d56b3bbe97ddf6555c99a04ff4402474c",
"ikmR": "dff9a966e02b161472f167c0d4252d400069449e62384beb78111cb596220921",
"skRm": "7596739457c72bbd6758c7021cfcb4d2fcd677d1232896b8f00da223c5519c36",
"pkRm": "9a83674c1bc12909fd59635ba1445592b82a7c01d4dad3ffc8f3975e76c43732",
"enc": "444fbbf83d64fef654dfb2a17997d82ca37cd8aeb8094371da33afb95e0c5b0e",
"exports_accumulated": "7557bdf93eadf06e3682fce3d765277f"
},
{
"mode": 0,
"kem_id": 16,
"kdf_id": 1,
"aead_id": 1,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "4270e54ffd08d79d5928020af4686d8f6b7d35dbe470265f1f5aa22816ce860e",
"ikmR": "668b37171f1072f3cf12ea8a236a45df23fc13b82af3609ad1e354f6ef817550",
"skRm": "f3ce7fdae57e1a310d87f1ebbde6f328be0a99cdbcadf4d6589cf29de4b8ffd2",
"pkRm": "04fe8c19ce0905191ebc298a9245792531f26f0cece2460639e8bc39cb7f706a826a779b4cf969b8a0e539c7f62fb3d30ad6aa8f80e30f1d128aafd68a2ce72ea0",
"enc": "04a92719c6195d5085104f469a8b9814d5838ff72b60501e2c4466e5e67b325ac98536d7b61a1af4b78e5b7f951c0900be863c403ce65c9bfcb9382657222d18c4",
"encryptions_accumulated": "fcb852ae6a1e19e874fbd18a199df3e4",
"exports_accumulated": "655be1f8b189a6b103528ac6d28d3109"
},
{
"mode": 0,
"kem_id": 16,
"kdf_id": 1,
"aead_id": 2,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "a90d3417c3da9cb6c6ae19b4b5dd6cc9529a4cc24efb7ae0ace1f31887a8cd6c",
"ikmR": "a0ce15d49e28bd47a18a97e147582d814b08cbe00109fed5ec27d1b4e9f6f5e3",
"skRm": "317f915db7bc629c48fe765587897e01e282d3e8445f79f27f65d031a88082b2",
"pkRm": "04abc7e49a4c6b3566d77d0304addc6ed0e98512ffccf505e6a8e3eb25c685136f853148544876de76c0f2ef99cdc3a05ccf5ded7860c7c021238f9e2073d2356c",
"enc": "04c06b4f6bebc7bb495cb797ab753f911aff80aefb86fd8b6fcc35525f3ab5f03e0b21bd31a86c6048af3cb2d98e0d3bf01da5cc4c39ff5370d331a4f1f7d5a4e0",
"encryptions_accumulated": "8d3263541fc1695b6e88ff3a1208577c",
"exports_accumulated": "038af0baa5ce3c4c5f371c3823b15217"
},
{
"mode": 0,
"kem_id": 16,
"kdf_id": 1,
"aead_id": 3,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "f1f1a3bc95416871539ecb51c3a8f0cf608afb40fbbe305c0a72819d35c33f1f",
"ikmR": "61092f3f56994dd424405899154a9918353e3e008171517ad576b900ddb275e7",
"skRm": "a4d1c55836aa30f9b3fbb6ac98d338c877c2867dd3a77396d13f68d3ab150d3b",
"pkRm": "04a697bffde9405c992883c5c439d6cc358170b51af72812333b015621dc0f40bad9bb726f68a5c013806a790ec716ab8669f84f6b694596c2987cf35baba2a006",
"enc": "04c07836a0206e04e31d8ae99bfd549380b072a1b1b82e563c935c095827824fc1559eac6fb9e3c70cd3193968994e7fe9781aa103f5b50e934b5b2f387e381291",
"encryptions_accumulated": "702cdecae9ba5c571c8b00ad1f313dbf",
"exports_accumulated": "2e0951156f1e7718a81be3004d606800"
},
{
"mode": 0,
"kem_id": 16,
"kdf_id": 1,
"aead_id": 65535,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "3800bb050bb4882791fc6b2361d7adc2543e4e0abbac367cf00a0c4251844350",
"ikmR": "c6638d8079a235ea4054885355a7caefee67151c6ff2a04f4ba26d099c3a8b02",
"skRm": "62c3868357a464f8461d03aa0182c7cebcde841036aea7230ddc7339f1088346",
"pkRm": "046c6bb9e1976402c692fef72552f4aaeedd83a5e5079de3d7ae732da0f397b15921fb9c52c9866affc8e29c0271a35937023a9245982ec18bab1eb157cf16fc33",
"enc": "04d804370b7e24b94749eb1dc8df6d4d4a5d75f9effad01739ebcad5c54a40d57aaa8b4190fc124dbde2e4f1e1d1b012a3bc4038157dc29b55533a932306d8d38d",
"exports_accumulated": "a6d39296bc2704db6194b7d6180ede8a"
},
{
"mode": 0,
"kem_id": 16,
"kdf_id": 3,
"aead_id": 1,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "4ab11a9dd78c39668f7038f921ffc0993b368171d3ddde8031501ee1e08c4c9a",
"ikmR": "ea9ff7cc5b2705b188841c7ace169290ff312a9cb31467784ca92d7a2e6e1be8",
"skRm": "3ac8530ad1b01885960fab38cf3cdc4f7aef121eaa239f222623614b4079fb38",
"pkRm": "04085aa5b665dc3826f9650ccbcc471be268c8ada866422f739e2d531d4a8818a9466bc6b449357096232919ec4fe9070ccbac4aac30f4a1a53efcf7af90610edd",
"enc": "0493ed86735bdfb978cc055c98b45695ad7ce61ce748f4dd63c525a3b8d53a15565c6897888070070c1579db1f86aaa56deb8297e64db7e8924e72866f9a472580",
"encryptions_accumulated": "3d670fc7760ce5b208454bb678fbc1dd",
"exports_accumulated": "0a3e30b572dafc58b998cd51959924be"
},
{
"mode": 0,
"kem_id": 16,
"kdf_id": 3,
"aead_id": 2,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "0c4b7c8090d9995e298d6fd61c7a0a66bb765a12219af1aacfaac99b4deaf8ad",
"ikmR": "a2f6e7c4d9e108e03be268a64fe73e11a320963c85375a30bfc9ec4a214c6a55",
"skRm": "9648e8711e9b6cb12dc19abf9da350cf61c3669c017b1db17bb36913b54a051d",
"pkRm": "0400f209b1bf3b35b405d750ef577d0b2dc81784005d1c67ff4f6d2860d7640ca379e22ac7fa105d94bc195758f4dfc0b82252098a8350c1bfeda8275ce4dd4262",
"enc": "0404dc39344526dbfa728afba96986d575811b5af199c11f821a0e603a4d191b25544a402f25364964b2c129cb417b3c1dab4dfc0854f3084e843f731654392726",
"encryptions_accumulated": "9da1683aade69d882aa094aa57201481",
"exports_accumulated": "80ab8f941a71d59f566e5032c6e2c675"
},
{
"mode": 0,
"kem_id": 16,
"kdf_id": 3,
"aead_id": 3,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "02bd2bdbb430c0300cea89b37ada706206a9a74e488162671d1ff68b24deeb5f",
"ikmR": "8d283ea65b27585a331687855ab0836a01191d92ab689374f3f8d655e702d82f",
"skRm": "ebedc3ca088ad03dfbbfcd43f438c4bb5486376b8ccaea0dc25fc64b2f7fc0da",
"pkRm": "048fed808e948d46d95f778bd45236ce0c464567a1dc6f148ba71dc5aeff2ad52a43c71851b99a2cdbf1dad68d00baad45007e0af443ff80ad1b55322c658b7372",
"enc": "044415d6537c2e9dd4c8b73f2868b5b9e7e8e3d836990dc2fd5b466d1324c88f2df8436bac7aa2e6ebbfd13bd09eaaa7c57c7495643bacba2121dca2f2040e1c5f",
"encryptions_accumulated": "f025dca38d668cee68e7c434e1b98f9f",
"exports_accumulated": "2efbb7ade3f87133810f507fdd73f874"
},
{
"mode": 0,
"kem_id": 16,
"kdf_id": 3,
"aead_id": 65535,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "497efeca99592461588394f7e9496129ed89e62b58204e076d1b7141e999abda",
"ikmR": "49b7cbfc1756e8ae010dc80330108f5be91268b3636f3e547dbc714d6bcd3d16",
"skRm": "9d34abe85f6da91b286fbbcfbd12c64402de3d7f63819e6c613037746b4eae6b",
"pkRm": "0453a4d1a4333b291e32d50a77ac9157bbc946059941cf9ed5784c15adbc7ad8fe6bf34a504ed81fd9bc1b6bb066a037da30fccd6c0b42d72bf37b9fef43c8e498",
"enc": "04f910248e120076be2a4c93428ac0c8a6b89621cfef19f0f9e113d835cf39d5feabbf6d26444ebbb49c991ec22338ade3a5edff35a929be67c4e5f33dcff96706",
"exports_accumulated": "6df17307eeb20a9180cff75ea183dd60"
},
{
"mode": 0,
"kem_id": 18,
"kdf_id": 1,
"aead_id": 1,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "5040af7a10269b11f78bb884812ad20041866db8bbd749a6a69e3f33e54da7164598f005bce09a9fe190e29c2f42df9e9e3aad040fccc625ddbd7aa99063fc594f40",
"ikmR": "39a28dc317c3e48b908948f99d608059f882d3d09c0541824bc25f94e6dee7aa0df1c644296b06fbb76e84aef5008f8a908e08fbabadf70658538d74753a85f8856a",
"skRm": "009227b4b91cf1eb6eecb6c0c0bae93a272d24e11c63bd4c34a581c49f9c3ca01c16bbd32a0a1fac22784f2ae985c85f183baad103b2d02aee787179dfc1a94fea11",
"pkRm": "0400b81073b1612cf7fdb6db07b35cf4bc17bda5854f3d270ecd9ea99f6c07b46795b8014b66c523ceed6f4829c18bc3886c891b63fa902500ce3ddeb1fbec7e608ac70050b76a0a7fc081dbf1cb30b005981113e635eb501a973aba662d7f16fcc12897dd752d657d37774bb16197c0d9724eecc1ed65349fb6ac1f280749e7669766f8cd",
"enc": "0400bec215e31718cd2eff5ba61d55d062d723527ec2029d7679a9c867d5c68219c9b217a9d7f78562dc0af3242fef35d1d6f4a28ee75f0d4b31bc918937b559b70762004c4fd6ad7373db7e31da8735fbd6171bbdcfa770211420682c760a40a482cc24f4125edbea9cb31fe71d5d796cfe788dc408857697a52fef711fb921fa7c385218",
"encryptions_accumulated": "94209973d36203eef2e56d155ef241d5",
"exports_accumulated": "31f25ea5e192561bce5f2c2822a9432c"
},
{
"mode": 0,
"kem_id": 18,
"kdf_id": 1,
"aead_id": 2,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "9953fbd633be69d984fc4fffc4d7749f007dbf97102d36a647a8108b0bb7c609e826b026aec1cd47b93fc5acb7518fa455ed38d0c29e900c56990635612fd3d220d2",
"ikmR": "17320bc93d9bc1d422ba0c705bf693e9a51a855d6e09c11bddea5687adc1a1122ec81384dc7e47959cae01c420a69e8e39337d9ebf9a9b2f3905cb76a35b0693ac34",
"skRm": "01a27e65890d64a121cfe59b41484b63fd1213c989c00e05a049ac4ede1f5caeec52bf43a59bdc36731cb6f8a0b7d7724b047ff52803c421ee99d61d4ea2e569c825",
"pkRm": "0400eb4010ca82412c044b52bdc218625c4ea797e061236206843e318882b3c1642e7e14e7cc1b4b171a433075ac0c8563043829eee51059a8b68197c8a7f6922465650075f40b6f440fdf525e2512b0c2023709294d912d8c68f94140390bff228097ce2d5f89b2b21f50d4c0892cfb955c380293962d5fe72060913870b61adc8b111953",
"enc": "0401c1cf49cafa9e26e24a9e20d7fa44a50a4e88d27236ef17358e79f3615a97f825899a985b3edb5195cad24a4fb64828701e81fbfd9a7ef673efde508e789509bd7c00fd5bfe053377bbee22e40ae5d64aa6fb47b314b5ab7d71b652db9259962dce742317d54084f0cf62a4b7e3f3caa9e6afb8efd6bf1eb8a2e13a7e73ec9213070d68",
"encryptions_accumulated": "69d16fa7c814cd8be9aa2122fda8768f",
"exports_accumulated": "d295fad3aef8be1f89d785800f83a30b"
},
{
"mode": 0,
"kem_id": 18,
"kdf_id": 1,
"aead_id": 3,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "566568b6cbfd1c6c06d1b0a2dc22d4e4965858bf3d54bf6cba5c018be0fad7a5cd9237937800f3cb57f10fa5691faeecab1685aa6da9b667469224a0989ff82b822b",
"ikmR": "f9f594556282cfe3eb30958ca2ef90ecd2a6ffd2661d41eb39ba184f3dae9f914aad297dd80cc763cb6525437a61ceae448aeeb304de137dc0f28dd007f0d592e137",
"skRm": "0168c8bf969b30bd949e154bf2db1964535e3f230f6604545bc9a33e9cd80fb17f4002170a9c91d55d7dd21db48e687cea83083498768cc008c6adf1e0ca08a309bd",
"pkRm": "040086b1a785a52af34a9a830332999896e99c5df0007a2ec3243ee3676ba040e60fde21bacf8e5f8db26b5acd42a2c81160286d54a2f124ca8816ac697993727431e50002aa5f5ebe70d88ff56445ade400fb979b466c9046123bbf5be72db9d90d1cde0bb7c217cff8ea0484445150eaf60170b039f54a5f6baeb7288bc62b1dedb59a1b",
"enc": "0401f828650ec526a647386324a31dadf75b54550b06707ae3e1fb83874b2633c935bb862bc4f07791ccfafbb08a1f00e18c531a34fec76f2cf3d581e7915fa40bbc3b010ab7c3d9162ea69928e71640ecff08b97f4fa9e8c66dfe563a13bf561cee7635563f91d387e2a38ee674ea28b24c633a988d1a08968b455e96307c64bda3f094b7",
"encryptions_accumulated": "586d5a92612828afbd7fdcea96006892",
"exports_accumulated": "a70389af65de4452a3f3147b66bd5c73"
},
{
"mode": 0,
"kem_id": 18,
"kdf_id": 1,
"aead_id": 65535,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "5dfb76f8b4708970acb4a6efa35ec4f2cebd61a3276a711c2fa42ef0bc9c191ea9dac7c0ac907336d830cea4a8394ab69e9171f344c4817309f93170cb34914987a5",
"ikmR": "9fd2aad24a653787f53df4a0d514c6d19610ca803298d7812bc0460b76c21da99315ebfec2343b4848d34ce526f0d39ce5a8dfddd9544e1c4d4b9a62f4191d096b42",
"skRm": "01ca47cf2f6f36fef46a01a46b393c30672224dd566aa3dd07a229519c49632c83d800e66149c3a7a07b840060549accd0d480ec5c71d2a975f88f6aa2fc0810b393",
"pkRm": "040143b7db23907d3ae1c43ef4882a6cdb142ca05a21c2475985c199807dd143e898136c65faf1ca1b6c6c2e8a92d67a0ab9c24f8c5cff7610cb942a73eb2ec4217c26018d67621cc78a60ec4bd1e23f90eb772adba2cf5a566020ee651f017b280a155c016679bd7e7ebad49e28e7ab679f66765f4ef34eae6b38a99f31bc73ea0f0d694d",
"enc": "040073dda7343ce32926c028c3be28508cccb751e2d4c6187bcc4e9b1de82d3d70c5702c6c866a920d9d9a574f5a4d4a0102db76207d5b3b77da16bb57486c5cc2a95f006b5d2e15efb24e297bdf8f2b6d7b25bf226d1b6efca47627b484d2942c14df6fe018d82ab9fb7306370c248864ea48fe5ca94934993517aacaa3b6bca8f92efc84",
"exports_accumulated": "d8fa94ac5e6829caf5ab4cdd1e05f5e1"
},
{
"mode": 0,
"kem_id": 18,
"kdf_id": 3,
"aead_id": 1,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "018b6bb1b8bbcefbd91e66db4e1300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"ikmR": "7bf9fd92611f2ff4e6c2ab4dd636a320e0397d6a93d014277b025a7533684c3255a02aa1f2a142be5391eebfc60a6a9c729b79c2428b8d78fa36497b1e89e446d402",
"skRm": "019db24a3e8b1f383436cd06997dd864eb091418ff561e3876cee2e4762a0cc0b69688af9a7a4963c90d394b2be579144af97d4933c0e6c2c2d13e7505ea51a06b0d",
"pkRm": "0401e06b350786c48a60dfc50eed324b58ecafc4efba26242c46c14274bd97f0989487a6fae0626188fea971ae1cb53f5d0e87188c1c62af92254f17138bbcebf5acd0018e574ee1d695813ce9dc45b404d2cf9c04f27627c4c55da1f936d813fd39435d0713d4a3cdc5409954a1180eb2672bdfc4e0e79c04eda89f857f625e058742a1c8",
"enc": "0400ac8d1611948105f23cf5e6842b07bd39b352d9d1e7bff2c93ac063731d6372e2661eff2afce604d4a679b49195f15e4fa228432aed971f2d46c1beb51fb3e5812501fe199c3d94c1b199393642500443dd82ce1c01701a1279cc3d74e29773030e26a70d3512f761e1eb0d7882209599eb9acd295f5939311c55e737f11c19988878d6",
"encryptions_accumulated": "207972885962115e69daaa3bc5015151",
"exports_accumulated": "8e9c577501320d86ee84407840188f5f"
},
{
"mode": 0,
"kem_id": 18,
"kdf_id": 3,
"aead_id": 2,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "7f06ab8215105fc46aceeb2e3dc5028b44364f960426eb0d8e4026c2f8b5d7e7a986688f1591abf5ab753c357a5d6f0440414b4ed4ede71317772ac98d9239f70904",
"ikmR": "2ad954bbe39b7122529f7dde780bff626cd97f850d0784a432784e69d86eccaade43b6c10a8ffdb94bf943c6da479db137914ec835a7e715e36e45e29b587bab3bf1",
"skRm": "01462680369ae375e4b3791070a7458ed527842f6a98a79ff5e0d4cbde83c27196a3916956655523a6a2556a7af62c5cadabe2ef9da3760bb21e005202f7b2462847",
"pkRm": "0401b45498c1714e2dce167d3caf162e45e0642afc7ed435df7902ccae0e84ba0f7d373f646b7738bbbdca11ed91bdeae3cdcba3301f2457be452f271fa6837580e661012af49583a62e48d44bed350c7118c0d8dc861c238c72a2bda17f64704f464b57338e7f40b60959480c0e58e6559b190d81663ed816e523b6b6a418f66d2451ec64",
"enc": "040138b385ca16bb0d5fa0c0665fbbd7e69e3ee29f63991d3e9b5fa740aab8900aaeed46ed73a49055758425a0ce36507c54b29cc5b85a5cee6bae0cf1c21f2731ece2013dc3fb7c8d21654bb161b463962ca19e8c654ff24c94dd2898de12051f1ed0692237fb02b2f8d1dc1c73e9b366b529eb436e98a996ee522aef863dd5739d2f29b0",
"encryptions_accumulated": "31769e36bcca13288177eb1c92f616ae",
"exports_accumulated": "fbffd93db9f000f51cf8ab4c1127fbda"
},
{
"mode": 0,
"kem_id": 18,
"kdf_id": 3,
"aead_id": 3,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "f9d540fde009bb1e5e71617c122a079862306b97144c8c4dca45ef6605c2ec9c43527c150800f5608a7e4cff771226579e7c776fb3def4e22e68e9fdc92340e94b6e",
"ikmR": "5273f7762dea7a2408333dbf8db9f6ef2ac4c475ad9e81a3b0b8c8805304adf5c876105d8703b42117ad8ee350df881e3d52926aafcb5c90f649faf94be81952c78a",
"skRm": "015b59f17366a1d4442e5b92d883a8f35fe8d88fea0e5bac6dfac7153c78fd0c6248c618b083899a7d62ba6e00e8a22cdde628dd5399b9a3377bb898792ff6f54ab9",
"pkRm": "040084698a47358f06a92926ee826a6784341285ee45f4b8269de271a8c6f03d5e8e24f628de13f5c37377b7cabfbd67bc98f9e8e758dfbee128b2fe752cd32f0f3ccd0061baec1ed7c6b52b7558bc120f783e5999c8952242d9a20baf421ccfc2a2b87c42d7b5b806fea6d518d5e9cd7bfd6c85beb5adeb72da41ac3d4f27bba83cff24d7",
"enc": "0400edc201c9b32988897a7f7b19104ebb54fc749faa41a67e9931e87ec30677194898074afb9a5f40a97df2972368a0c594e5b60e90d1ff83e9e35f8ff3ad200fd6d70028b5645debe9f1f335dbc1225c066218e85cf82a05fbe361fa477740b906cb3083076e4d17232513d102627597d38e354762cf05b3bd0f33dc4d0fb78531afd3fd",
"encryptions_accumulated": "aa69356025f552372770ef126fa2e59a",
"exports_accumulated": "1fcffb5d8bc1d825daf904a0c6f4a4d3"
},
{
"mode": 0,
"kem_id": 18,
"kdf_id": 3,
"aead_id": 65535,
"info": "4f6465206f6e2061204772656369616e2055726e",
"ikmE": "3018d74c67d0c61b5e4075190621fc192996e928b8859f45b3ad2399af8599df69c34b7a3eefeda7ee49ae73d4579300b85dde1654c0dfc3a3f78143d239a628cf72",
"ikmR": "a243eff510b99140034c72587e9f131809b9bce03a9da3da458771297f535cede0f48167200bf49ac123b52adfd789cf0adfd5cded6be2f146aeb00c34d4e6d234fc",
"skRm": "0045fe00b1d55eb64182d334e301e9ac553d6dbafbf69935e65f5bf89c761b9188c0e4d50a0167de6b98af7bebd05b2627f45f5fca84690cd86a61ba5a612870cf53",
"pkRm": "0401635b3074ad37b752696d5ca311da9cc790a899116030e4c71b83edd06ced92fdd238f6c921132852f20e6a2cbcf2659739232f4a69390f2b14d80667bcf9b71983000a919d29366554f53107a6c4cc7f8b24fa2de97b42433610cbd236d5a2c668e991ff4c4383e9fe0a9e7858fc39064e31fca1964e809a2f898c32fba46ce33575b8",
"enc": "0400932d9ff83ca4b799968bda0dd9dac4d02c9232cdcf133db7c53cfbf3d80a299fd99bc42da38bb78f57976bdb69988819b6e2924fadacdad8c05052997cf50b29110139f000af5b2c599b05fc63537d60a8384ca984821f8cd12621577a974ebadaf98bfdad6d1643dd4316062d7c0bda5ba0f0a2719992e993af615568abf19a256993",
"exports_accumulated": "29c0f6150908f6e0d979172f23f1d57b"
}
]

View file

@ -97,6 +97,7 @@ func sumSHAKE256(out, data []byte, length int) []byte {
}
// SHA3 is an instance of a SHA-3 hash. It implements [hash.Hash].
+// The zero value is a usable SHA3-256 hash.
type SHA3 struct {
s sha3.Digest
}
@ -126,43 +127,57 @@ func New512() *SHA3 {
return &SHA3{*sha3.New512()}
}
+func (s *SHA3) init() {
+if s.s.Size() == 0 {
+*s = *New256()
+}
+}
// Write absorbs more data into the hash's state.
func (s *SHA3) Write(p []byte) (n int, err error) {
+s.init()
return s.s.Write(p)
}
// Sum appends the current hash to b and returns the resulting slice.
func (s *SHA3) Sum(b []byte) []byte {
+s.init()
return s.s.Sum(b)
}
// Reset resets the hash to its initial state.
func (s *SHA3) Reset() {
+s.init()
s.s.Reset()
}
// Size returns the number of bytes Sum will produce.
func (s *SHA3) Size() int {
+s.init()
return s.s.Size()
}
// BlockSize returns the hash's rate.
func (s *SHA3) BlockSize() int {
+s.init()
return s.s.BlockSize()
}
// MarshalBinary implements [encoding.BinaryMarshaler].
func (s *SHA3) MarshalBinary() ([]byte, error) {
+s.init()
return s.s.MarshalBinary()
}
// AppendBinary implements [encoding.BinaryAppender].
func (s *SHA3) AppendBinary(p []byte) ([]byte, error) {
+s.init()
return s.s.AppendBinary(p)
}
// UnmarshalBinary implements [encoding.BinaryUnmarshaler].
func (s *SHA3) UnmarshalBinary(data []byte) error {
+s.init()
return s.s.UnmarshalBinary(data)
}
@ -173,10 +188,17 @@ func (d *SHA3) Clone() (hash.Cloner, error) {
}
// SHAKE is an instance of a SHAKE extendable output function.
+// The zero value is a usable SHAKE256 hash.
type SHAKE struct {
s sha3.SHAKE
}
+func (s *SHAKE) init() {
+if s.s.Size() == 0 {
+*s = *NewSHAKE256()
+}
+}
// NewSHAKE128 creates a new SHAKE128 XOF.
func NewSHAKE128() *SHAKE {
return &SHAKE{*sha3.NewShake128()}
@ -209,6 +231,7 @@ func NewCSHAKE256(N, S []byte) *SHAKE {
//
// It panics if any output has already been read.
func (s *SHAKE) Write(p []byte) (n int, err error) {
+s.init()
return s.s.Write(p)
}
@ -216,30 +239,36 @@ func (s *SHAKE) Write(p []byte) (n int, err error) {
//
// Any call to Write after a call to Read will panic.
func (s *SHAKE) Read(p []byte) (n int, err error) {
+s.init()
return s.s.Read(p)
}
// Reset resets the XOF to its initial state.
func (s *SHAKE) Reset() {
+s.init()
s.s.Reset()
}
// BlockSize returns the rate of the XOF.
func (s *SHAKE) BlockSize() int {
+s.init()
return s.s.BlockSize()
}
// MarshalBinary implements [encoding.BinaryMarshaler].
func (s *SHAKE) MarshalBinary() ([]byte, error) {
+s.init()
return s.s.MarshalBinary()
}
// AppendBinary implements [encoding.BinaryAppender].
func (s *SHAKE) AppendBinary(p []byte) ([]byte, error) {
+s.init()
return s.s.AppendBinary(p)
}
// UnmarshalBinary implements [encoding.BinaryUnmarshaler].
func (s *SHAKE) UnmarshalBinary(data []byte) error {
+s.init()
return s.s.UnmarshalBinary(data)
}
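
The zero-value doc comments added above mean a zero sha3.SHA3 now hashes as SHA3-256 and a zero sha3.SHAKE reads as SHAKE256, so values no longer need a constructor before first use. A small sketch of what that enables (assuming the hunks above are against the public crypto/sha3 package):

package main

import (
	"bytes"
	"crypto/sha3"
	"fmt"
)

func main() {
	// A zero-value SHA3 behaves as SHA3-256.
	var zero sha3.SHA3
	zero.Write([]byte("hello"))

	ref := sha3.New256()
	ref.Write([]byte("hello"))
	fmt.Println(bytes.Equal(zero.Sum(nil), ref.Sum(nil))) // true

	// A zero-value SHAKE behaves as SHAKE256.
	var xof sha3.SHAKE
	xof.Write([]byte("hello"))
	out := make([]byte, 32)
	xof.Read(out)
	fmt.Printf("%x\n", out)
}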

View file

@ -26,6 +26,7 @@ var testDigests = map[string]func() *SHA3{
"SHA3-256": New256,
"SHA3-384": New384,
"SHA3-512": New512,
+"SHA3-Zero": func() *SHA3 { return &SHA3{} },
}
// testShakes contains functions that return *sha3.SHAKE instances for
@ -40,6 +41,7 @@ var testShakes = map[string]struct {
"SHAKE256": {NewCSHAKE256, "", ""},
"cSHAKE128": {NewCSHAKE128, "CSHAKE128", "CustomString"},
"cSHAKE256": {NewCSHAKE256, "CSHAKE256", "CustomString"},
+"SHAKE-Zero": {func(N []byte, S []byte) *SHAKE { return &SHAKE{} }, "", ""},
}
func TestSHA3Hash(t *testing.T) {

View file

@ -889,13 +889,29 @@ type Config struct {
// with a specific ECH config known to a client.
type EncryptedClientHelloKey struct {
// Config should be a marshalled ECHConfig associated with PrivateKey. This
-// must match the config provided to clients byte-for-byte. The config
-// should only specify the DHKEM(X25519, HKDF-SHA256) KEM ID (0x0020), the
-// HKDF-SHA256 KDF ID (0x0001), and a subset of the following AEAD IDs:
-// AES-128-GCM (0x0001), AES-256-GCM (0x0002), ChaCha20Poly1305 (0x0003).
+// must match the config provided to clients byte-for-byte. The config must
+// use as KEM one of
+//
+// - DHKEM(P-256, HKDF-SHA256) (0x0010)
+// - DHKEM(P-384, HKDF-SHA384) (0x0011)
+// - DHKEM(P-521, HKDF-SHA512) (0x0012)
+// - DHKEM(X25519, HKDF-SHA256) (0x0020)
+//
+// and as KDF one of
+//
+// - HKDF-SHA256 (0x0001)
+// - HKDF-SHA384 (0x0002)
+// - HKDF-SHA512 (0x0003)
+//
+// and as AEAD one of
+//
+// - AES-128-GCM (0x0001)
+// - AES-256-GCM (0x0002)
+// - ChaCha20Poly1305 (0x0003)
+//
Config []byte
-// PrivateKey should be a marshalled private key. Currently, we expect
-// this to be the output of [ecdh.PrivateKey.Bytes].
+// PrivateKey should be a marshalled private key, in the format expected by
+// HPKE's DeserializePrivateKey (see RFC 9180), for the KEM used in Config.
PrivateKey []byte
// SendAsRetry indicates if Config should be sent as part of the list of
// retry configs when ECH is requested by the client but rejected by the
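
To make the relaxed PrivateKey requirement concrete: for DHKEM(X25519, HKDF-SHA256) the HPKE serialization is simply the raw 32-byte X25519 key, so the output of [ecdh.PrivateKey.Bytes] still works unchanged. A hedged sketch (the marshalled ECHConfig is assumed to be produced elsewhere and is left as a placeholder):

package main

import (
	"crypto/ecdh"
	"crypto/rand"
	"crypto/tls"
)

func main() {
	// Generate the X25519 key backing the ECH config.
	priv, err := ecdh.X25519().GenerateKey(rand.Reader)
	if err != nil {
		panic(err)
	}

	// echConfig: the marshalled ECHConfig advertised to clients, built with
	// the matching public key (its construction is not shown here).
	var echConfig []byte

	key := tls.EncryptedClientHelloKey{
		Config:      echConfig,
		PrivateKey:  priv.Bytes(), // 32-byte X25519 key, the DHKEM(X25519) serialization
		SendAsRetry: true,
	}
	_ = key
}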

View file

@ -9,24 +9,11 @@ import (
"crypto/internal/hpke"
"errors"
"fmt"
-"slices"
"strings"
"golang.org/x/crypto/cryptobyte"
)
-// sortedSupportedAEADs is just a sorted version of hpke.SupportedAEADS.
-// We need this so that when we insert them into ECHConfigs the ordering
-// is stable.
-var sortedSupportedAEADs []uint16
-func init() {
-for aeadID := range hpke.SupportedAEADs {
-sortedSupportedAEADs = append(sortedSupportedAEADs, aeadID)
-}
-slices.Sort(sortedSupportedAEADs)
-}
type echCipher struct {
KDFID uint16
AEADID uint16
@ -162,25 +149,8 @@ func parseECHConfigList(data []byte) ([]echConfig, error) {
return configs, nil
}
-func pickECHConfig(list []echConfig) *echConfig {
+func pickECHConfig(list []echConfig) (*echConfig, hpke.PublicKey, hpke.KDF, hpke.AEAD) {
for _, ec := range list {
-if _, ok := hpke.SupportedKEMs[ec.KemID]; !ok {
-continue
-}
-var validSCS bool
-for _, cs := range ec.SymmetricCipherSuite {
-if _, ok := hpke.SupportedAEADs[cs.AEADID]; !ok {
-continue
-}
-if _, ok := hpke.SupportedKDFs[cs.KDFID]; !ok {
-continue
-}
-validSCS = true
-break
-}
-if !validSCS {
-continue
-}
if !validDNSName(string(ec.PublicName)) {
continue
}
@ -196,25 +166,32 @@ func pickECHConfig(list []echConfig) *echConfig {
if unsupportedExt {
continue
}
-return &ec
-}
-return nil
-}
-func pickECHCipherSuite(suites []echCipher) (echCipher, error) {
-for _, s := range suites {
-// NOTE: all of the supported AEADs and KDFs are fine, rather than
-// imposing some sort of preference here, we just pick the first valid
-// suite.
-if _, ok := hpke.SupportedAEADs[s.AEADID]; !ok {
-continue
-}
-if _, ok := hpke.SupportedKDFs[s.KDFID]; !ok {
-continue
-}
-return s, nil
-}
-return echCipher{}, errors.New("tls: no supported symmetric ciphersuites for ECH")
+kem, err := hpke.NewKEM(ec.KemID)
+if err != nil {
+continue
+}
+pub, err := kem.NewPublicKey(ec.PublicKey)
+if err != nil {
+// This is an error in the config, but killing the connection feels
+// excessive.
+continue
+}
+for _, cs := range ec.SymmetricCipherSuite {
+// All of the supported AEADs and KDFs are fine, rather than
+// imposing some sort of preference here, we just pick the first
+// valid suite.
+kdf, err := hpke.NewKDF(cs.KDFID)
+if err != nil {
+continue
+}
+aead, err := hpke.NewAEAD(cs.AEADID)
+if err != nil {
+continue
+}
+return &ec, pub, kdf, aead
+}
+}
+return nil, nil, nil, nil
}
func encodeInnerClientHello(inner *clientHelloMsg, maxNameLength int) ([]byte, error) {
@ -592,18 +569,33 @@ func (c *Conn) processECHClientHello(outer *clientHelloMsg, echKeys []EncryptedC
skip, config, err := parseECHConfig(echKey.Config)
if err != nil || skip {
c.sendAlert(alertInternalError)
-return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKeys Config: %s", err)
+return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey Config: %s", err)
}
if skip {
continue
}
-echPriv, err := hpke.ParseHPKEPrivateKey(config.KemID, echKey.PrivateKey)
+kem, err := hpke.NewKEM(config.KemID)
if err != nil {
c.sendAlert(alertInternalError)
-return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKeys PrivateKey: %s", err)
+return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey Config KEM: %s", err)
+}
+echPriv, err := kem.NewPrivateKey(echKey.PrivateKey)
+if err != nil {
+c.sendAlert(alertInternalError)
+return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey PrivateKey: %s", err)
+}
+kdf, err := hpke.NewKDF(echCiphersuite.KDFID)
+if err != nil {
+c.sendAlert(alertInternalError)
+return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey Config KDF: %s", err)
+}
+aead, err := hpke.NewAEAD(echCiphersuite.AEADID)
+if err != nil {
+c.sendAlert(alertInternalError)
+return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey Config AEAD: %s", err)
}
info := append([]byte("tls ech\x00"), echKey.Config...)
-hpkeContext, err := hpke.SetupRecipient(hpke.DHKEM_X25519_HKDF_SHA256, echCiphersuite.KDFID, echCiphersuite.AEADID, echPriv, info, encap)
+hpkeContext, err := hpke.NewRecipient(encap, echPriv, kdf, aead, info)
if err != nil {
// attempt next trial decryption
continue

View file

@ -41,7 +41,7 @@ func TestSkipBadConfigs(t *testing.T) {
if err != nil {
t.Fatal(err)
}
-config := pickECHConfig(configs)
+config, _, _, _ := pickECHConfig(configs)
if config != nil {
t.Fatal("pickECHConfig picked an invalid config")
}

View file

@ -205,11 +205,11 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCli
if err != nil {
return nil, nil, nil, err
}
-echConfig := pickECHConfig(echConfigs)
+echConfig, echPK, kdf, aead := pickECHConfig(echConfigs)
if echConfig == nil {
return nil, nil, nil, errors.New("tls: EncryptedClientHelloConfigList contains no valid configs")
}
-ech = &echClientContext{config: echConfig}
+ech = &echClientContext{config: echConfig, kdfID: kdf.ID(), aeadID: aead.ID()}
hello.encryptedClientHello = []byte{1} // indicate inner hello
// We need to explicitly set these 1.2 fields to nil, as we do not
// marshal them when encoding the inner hello, otherwise transcripts
@ -219,17 +219,8 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCli
hello.secureRenegotiationSupported = false
hello.extendedMasterSecret = false
-echPK, err := hpke.ParseHPKEPublicKey(ech.config.KemID, ech.config.PublicKey)
-if err != nil {
-return nil, nil, nil, err
-}
-suite, err := pickECHCipherSuite(ech.config.SymmetricCipherSuite)
-if err != nil {
-return nil, nil, nil, err
-}
-ech.kdfID, ech.aeadID = suite.KDFID, suite.AEADID
info := append([]byte("tls ech\x00"), ech.config.raw...)
-ech.encapsulatedKey, ech.hpkeContext, err = hpke.SetupSender(ech.config.KemID, suite.KDFID, suite.AEADID, echPK, info)
+ech.encapsulatedKey, ech.hpkeContext, err = hpke.NewSender(echPK, kdf, aead, info)
if err != nil {
return nil, nil, nil, err
}
@ -317,7 +308,11 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
if hello.earlyData {
suite := cipherSuiteTLS13ByID(session.cipherSuite)
transcript := suite.hash.New()
-if err := transcriptMsg(hello, transcript); err != nil {
+transcriptHello := hello
+if ech != nil {
+transcriptHello = ech.innerHello
+}
+if err := transcriptMsg(transcriptHello, transcript); err != nil {
return err
}
earlyTrafficSecret := earlySecret.ClientEarlyTrafficSecret(transcript)

View file

@ -117,6 +117,11 @@ const (
// The application may modify the [SessionState] before storing it.
// This event only occurs on client connections.
QUICStoreSession
+// QUICErrorEvent indicates that a fatal error has occurred.
+// The handshake cannot proceed and the connection must be closed.
+// QUICEvent.Err is set.
+QUICErrorEvent
)
// A QUICEvent is an event occurring on a QUIC connection.
@ -138,6 +143,10 @@ type QUICEvent struct {
// Set for QUICResumeSession and QUICStoreSession.
SessionState *SessionState
+// Set for QUICErrorEvent.
+// The error will wrap AlertError.
+Err error
}
type quicState struct {
@ -157,6 +166,7 @@ type quicState struct {
cancel context.CancelFunc
waitingForDrain bool
+errorReturned bool
// readbuf is shared between HandleData and the handshake goroutine.
// HandshakeCryptoData passes ownership to the handshake goroutine by
@ -229,6 +239,15 @@ func (q *QUICConn) NextEvent() QUICEvent {
<-qs.signalc
<-qs.blockedc
}
+if err := q.conn.handshakeErr; err != nil {
+if qs.errorReturned {
+return QUICEvent{Kind: QUICNoEvent}
+}
+qs.errorReturned = true
+qs.events = nil
+qs.nextEvent = 0
+return QUICEvent{Kind: QUICErrorEvent, Err: q.conn.handshakeErr}
+}
if qs.nextEvent >= len(qs.events) {
qs.events = qs.events[:0]
qs.nextEvent = 0
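
For QUIC implementations consuming these events, the new QUICErrorEvent shows up in the same NextEvent loop they already run; a minimal hypothetical sketch of where it fits:

package quicshim

import "crypto/tls"

// drainEvents is a hypothetical helper, not part of crypto/tls, showing how a
// caller might fold QUICErrorEvent into an existing event loop.
func drainEvents(conn *tls.QUICConn) error {
	for {
		switch e := conn.NextEvent(); e.Kind {
		case tls.QUICNoEvent:
			// Nothing pending; feed more handshake data and come back.
			return nil
		case tls.QUICErrorEvent:
			// Fatal handshake error: e.Err wraps tls.AlertError and the
			// connection must be closed.
			return e.Err
		default:
			// QUICSetReadSecret, QUICWriteData, and the other kinds are
			// handled exactly as before.
		}
	}
}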

View file

@ -8,6 +8,7 @@ import (
"bytes"
"context"
"errors"
+"fmt"
"reflect"
"strings"
"testing"
@ -21,6 +22,7 @@ type testQUICConn struct {
ticketOpts QUICSessionTicketOptions
onResumeSession func(*SessionState)
gotParams []byte
+gotError error
earlyDataRejected bool
complete bool
}
@ -109,6 +111,9 @@ func runTestQUICConnection(ctx context.Context, cli, srv *testQUICConn, onEvent
if onEvent != nil && onEvent(e, a, b) {
continue
}
+if a.gotError != nil && e.Kind != QUICNoEvent {
+return fmt.Errorf("unexpected event %v after QUICErrorEvent", e.Kind)
+}
switch e.Kind {
case QUICNoEvent:
idleCount++
@ -152,6 +157,11 @@ func runTestQUICConnection(ctx context.Context, cli, srv *testQUICConn, onEvent
}
case QUICRejectedEarlyData:
a.earlyDataRejected = true
+case QUICErrorEvent:
+if e.Err == nil {
+return errors.New("unexpected QUICErrorEvent with no Err")
+}
+a.gotError = e.Err
}
if e.Kind != QUICNoEvent {
idleCount = 0
@ -371,6 +381,45 @@ func TestQUICHandshakeError(t *testing.T) {
if _, ok := errors.AsType[*CertificateVerificationError](err); !ok {
t.Errorf("connection handshake terminated with error %q, want CertificateVerificationError", err)
}
+ev := cli.conn.NextEvent()
+if ev.Kind != QUICErrorEvent {
+t.Errorf("client.NextEvent: no QUICErrorEvent, want one")
+}
+if ev.Err != err {
+t.Errorf("client.NextEvent: want same error returned by Start, got %v", ev.Err)
+}
+}
+// Test that we can report an error produced by the GetEncryptedClientHelloKeys function.
+func TestQUICECHKeyError(t *testing.T) {
+getECHKeysError := errors.New("error returned by GetEncryptedClientHelloKeys")
+config := &QUICConfig{TLSConfig: testConfig.Clone()}
+config.TLSConfig.MinVersion = VersionTLS13
+config.TLSConfig.NextProtos = []string{"h3"}
+config.TLSConfig.GetEncryptedClientHelloKeys = func(*ClientHelloInfo) ([]EncryptedClientHelloKey, error) {
+return nil, getECHKeysError
+}
+cli := newTestQUICClient(t, config)
+cli.conn.SetTransportParameters(nil)
+srv := newTestQUICServer(t, config)
+if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
+t.Fatalf("handshake with no client parameters: %v; want errTransportParametersRequired", err)
+}
+srv.conn.SetTransportParameters(nil)
+if err := runTestQUICConnection(context.Background(), cli, srv, nil); err == nil {
+t.Fatalf("handshake with GetEncryptedClientHelloKeys errors: nil, want error")
+}
+if srv.gotError == nil {
+t.Fatalf("after GetEncryptedClientHelloKeys error, server did not see QUICErrorEvent")
+}
+if _, ok := errors.AsType[AlertError](srv.gotError); !ok {
+t.Errorf("connection handshake terminated with error %T, want AlertError", srv.gotError)
+}
+if !errors.Is(srv.gotError, getECHKeysError) {
+t.Errorf("connection handshake terminated with error %v, want error returned by GetEncryptedClientHelloKeys", srv.gotError)
+}
}
// Test that QUICConn.ConnectionState can be used during the handshake,

View file

@ -11,7 +11,6 @@ import (
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
-"crypto/internal/hpke"
"crypto/rand"
"crypto/tls/internal/fips140tls"
"crypto/x509"
@ -2249,15 +2248,13 @@ func TestECH(t *testing.T) {
builder.AddUint16(extensionEncryptedClientHello)
builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) {
builder.AddUint8(id)
-builder.AddUint16(hpke.DHKEM_X25519_HKDF_SHA256) // The only DHKEM we support
+builder.AddUint16(0x0020 /* DHKEM(X25519, HKDF-SHA256) */)
builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) {
builder.AddBytes(pubKey)
})
builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) {
-for _, aeadID := range sortedSupportedAEADs {
-builder.AddUint16(hpke.KDF_HKDF_SHA256) // The only KDF we support
-builder.AddUint16(aeadID)
-}
+builder.AddUint16(0x0001 /* HKDF-SHA256 */)
+builder.AddUint16(0x0001 /* AES-128-GCM */)
})
builder.AddUint8(maxNameLen)
builder.AddUint8LengthPrefixed(func(builder *cryptobyte.Builder) {

View file

@ -55,7 +55,7 @@ func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
// it isn't expecting. The final error will be thrown
// in the argument converter loop.
index := nv.Ordinal - 1
-if c.want <= index {
+if c.want >= 0 && c.want <= index {
return nil
}

View file

@ -5135,3 +5135,50 @@ func TestIssue69728(t *testing.T) {
t.Errorf("not equal; v1 = %v, v2 = %v", v1, v2)
}
}
+func TestColumnConverterWithUnknownInputCount(t *testing.T) {
+db := OpenDB(&unknownInputsConnector{})
+stmt, err := db.Prepare("SELECT ?")
+if err != nil {
+t.Fatal(err)
+}
+_, err = stmt.Exec(1)
+if err != nil {
+t.Fatal(err)
+}
+}
+type unknownInputsConnector struct{}
+func (unknownInputsConnector) Connect(context.Context) (driver.Conn, error) {
+return unknownInputsConn{}, nil
+}
+func (unknownInputsConnector) Driver() driver.Driver { return nil }
+type unknownInputsConn struct{}
+func (unknownInputsConn) Prepare(string) (driver.Stmt, error) { return unknownInputsStmt{}, nil }
+func (unknownInputsConn) Close() error { return nil }
+func (unknownInputsConn) Begin() (driver.Tx, error) { return nil, nil }
+type unknownInputsStmt struct{}
+func (unknownInputsStmt) Close() error { return nil }
+func (unknownInputsStmt) NumInput() int { return -1 }
+func (unknownInputsStmt) Exec(args []driver.Value) (driver.Result, error) {
+if _, ok := args[0].(string); !ok {
+return nil, fmt.Errorf("Expected string, got %T", args[0])
+}
+return nil, nil
+}
+func (unknownInputsStmt) Query([]driver.Value) (driver.Rows, error) { return nil, nil }
+func (unknownInputsStmt) ColumnConverter(idx int) driver.ValueConverter {
+return unknownInputsValueConverter{}
+}
+type unknownInputsValueConverter struct{}
+func (unknownInputsValueConverter) ConvertValue(v any) (driver.Value, error) {
+return "string", nil
+}

Some files were not shown because too many files have changed in this diff