cmd/compile: use binsearch-not-table for simd non-constant immediates when retpoline

The simd api translates non-constant inputs to instructions that
require and immediate operand into switch tables.  This occurs
later in the pipeline than the retpoline guarded code that replaces
table switches with binary search.

Attempting to introduce Go simd code into the runtime triggered a test
that checked for retpoline correctness and found this problem.

The fix is to replace the table switch with a binary search where
the intrinsic is inserted, if retpoline is active.

This CL is adapted from part of a pending dev.simd CL for WASM support
that also needs binary-search-switches.

Change-Id: Ia9d38d09695307fa3955fe3d9348637935022278
Reviewed-on: https://go-review.googlesource.com/c/go/+/773900
Reviewed-by: Keith Randall <khr@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
LUCI-TryBot-Result: golang-scoped@luci-project-accounts.iam.gserviceaccount.com <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
David Chase 2026-05-04 13:57:57 -04:00
parent 07840ceeed
commit d5ebe8100d

View file

@ -1971,6 +1971,11 @@ func opLen4_31(op ssa.Op, t *types.Type) func(s *state, n *ir.CallExpr, args []*
}
func immJumpTable(s *state, idx *ssa.Value, intrinsicCall *ir.CallExpr, genOp func(*state, int)) *ssa.Value {
if base.Ctxt.Retpoline {
// Note spectre=all implies retpoline which requires binary search instead of table switch.
return branchTableImm8(s, idx, intrinsicCall, genOp)
}
// Make blocks we'll need.
bEnd := s.f.NewBlock(ssa.BlockPlain)
@ -1985,10 +1990,7 @@ func immJumpTable(s *state, idx *ssa.Value, intrinsicCall *ir.CallExpr, genOp fu
b := s.curBlock
b.Kind = ssa.BlockJumpTable
b.Pos = intrinsicCall.Pos()
if base.Flag.Cfg.SpectreIndex {
// Potential Spectre vulnerability hardening?
idx = s.newValue2(ssa.OpSpectreSliceIndex, t, idx, s.uintptrConstant(255))
}
b.SetControl(idx)
targets := [256]*ssa.Block{}
for i := range 256 {
@ -2012,6 +2014,77 @@ func immJumpTable(s *state, idx *ssa.Value, intrinsicCall *ir.CallExpr, genOp fu
return ret
}
func branchTableImm8(s *state, idx *ssa.Value, intrinsicCall *ir.CallExpr, genOp func(*state, int)) *ssa.Value {
return branchTableN(s, idx, intrinsicCall, genOp, 256, true)
}
func branchTableN(s *state, idx *ssa.Value, intrinsicCall *ir.CallExpr, genOp func(*state, int), immLimit uint64, preChecked bool) *ssa.Value {
// Make blocks we'll need.
bEnd := s.f.NewBlock(ssa.BlockPlain)
bPanic := s.f.NewBlock(ssa.BlockPlain)
jt := s.f.NewBlock(ssa.BlockPlain)
t := types.Types[types.TUINTPTR]
idx = s.conv(nil, idx, idx.Type, t)
if !preChecked {
// Begin with a bounds check
width := s.uintptrConstant(immLimit)
cmp := s.newValue2(s.ssaOp(ir.OLT, t), types.Types[types.TBOOL], idx, width)
bb := s.endBlock()
bb.Kind = ssa.BlockIf
bb.SetControl(cmp)
bb.AddEdgeTo(jt) // in range - use jump table
bb.AddEdgeTo(bPanic) // out of range - panic
bb.Likely = ssa.BranchLikely // panic is unlikely
s.startBlock(bPanic)
s.rtcall(ir.Syms.PanicSimdImm, false, nil)
}
s.startBlock(jt)
jt.Kind = ssa.BlockPlain
jt.Pos = intrinsicCall.Pos()
branchTableNInner(s, idx, 0, immLimit, genOp, bEnd)
s.startBlock(bEnd)
ret := s.variable(intrinsicCall, intrinsicCall.Type())
return ret
}
func branchTableNInner(s *state, idx *ssa.Value, lowInclusive, len uint64, genOp func(*state, int), bEnd *ssa.Block) {
t := types.Types[types.TUINTPTR]
if len == 0 {
panic("empty branch table")
}
if len == 1 {
genOp(s, int(lowInclusive+len-1))
if s.curBlock != nil { // if genOp was "panic" then curBlock is already ended and nil
if s.curBlock.Kind != ssa.BlockExit {
s.curBlock.AddEdgeTo(bEnd)
}
s.endBlock()
}
return
}
s.curBlock.Kind = ssa.BlockIf
cmp := s.newValue2(s.ssaOp(ir.OLT, t), types.Types[types.TBOOL], idx, s.uintptrConstant(lowInclusive+len/2))
bb := s.endBlock()
bb.Kind = ssa.BlockIf
bb.SetControl(cmp)
bMatch := s.f.NewBlock(ssa.BlockPlain)
bNext := s.f.NewBlock(ssa.BlockPlain)
bb.AddEdgeTo(bMatch)
bb.AddEdgeTo(bNext)
s.startBlock(bMatch)
branchTableNInner(s, idx, lowInclusive, len/2, genOp, bEnd)
s.startBlock(bNext)
branchTableNInner(s, idx, lowInclusive+len/2, len-len/2, genOp, bEnd)
}
func opLen1Imm8(op ssa.Op, t *types.Type, offset int) func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
return func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
if args[1].Op == ssa.OpConst8 {