cmd/compile: make computeKnownBitsForShift iteration faster

I didn't even tried benchmarking this, I doubt it can be measured.
I needed this to remove O(n²) behavior for code that calls
computeKnownBitsForShift as part of computing known bits for
add and sub.

Change-Id: I6bab20cd6b65fb389e345e5745d17c364fb3d233
Reviewed-on: https://go-review.googlesource.com/c/go/+/773840
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
LUCI-TryBot-Result: golang-scoped@luci-project-accounts.iam.gserviceaccount.com <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Jorropo <jorropo.pgm@gmail.com>
Reviewed-by: Keith Randall <khr@golang.org>
Reviewed-by: Junyang Shao <shaojunyang@google.com>
This commit is contained in:
Jorropo 2026-05-04 18:52:28 +02:00 committed by Gopher Robot
parent fabaedcbe8
commit 880ef11ecf
2 changed files with 100 additions and 6 deletions

View file

@ -289,13 +289,11 @@ func (kb *knownBitsState) computeKnownBitsForShift(v *Value, doShiftByAConst fun
value, known = doShiftByAConst(x, xk, xSize, 64)
set = true
}
yk &= xSize - 1
for i := range xSize {
if i&yk != y {
continue
}
a, k := doShiftByAConst(x, xk, xSize, int64(i))
yk |= ^(xSize - 1)
for i := range allPossibleValues(y, yk) {
a, k := doShiftByAConst(x, xk, xSize, i)
if !set {
value, known = a, k
set = true
@ -310,3 +308,33 @@ func (kb *knownBitsState) computeKnownBitsForShift(v *Value, doShiftByAConst fun
return value & known, known
}
// allPossibleValues iterates over all values that could exist.
// It scales exponentially with the number of unknown bits,
// the exact number of iterations will be uint128(1)<<bits.OnesCount64(^known)
// thus be careful with what values you pass to it.
func allPossibleValues(value, known int64) func(yield func(v int64) bool) {
unknown := ^known
return func(yield func(v int64) bool) {
// This finds the next valid value for the variable bits.
// It is equivalent to (s|known + 1) & unknown.
// The s|known step creates blocks of 1s in all the known bits.
// +1 finds the next possible value, the blocks of 1s set in the previous step allows it to skip over blocks of known bits.
// & unknown clears garbage generated by the blocks of ones and overflow.
//
// You can transform (s|known + 1) & unknown into (s - unknown) & unknown through:
// (s + known + 1) & unknown: s | known → s + known (since s & known == 0)
// (s + ^unknown + 1) & unknown: known → ^unknown (definition of unknown)
// (s + -unknown) & unknown: ^unknown + 1 → -unknown (two's complement negation)
// (s - unknown) & unknown: s + -unknown → s - unknown (arithmetic)
for s := int64(0); ; s = (s - unknown) & unknown {
// fixed bits | current variable bits gives the current iteration
if !yield(value | s) {
return
}
if s == unknown {
break
}
}
}
}

View file

@ -0,0 +1,66 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssa
import (
"fmt"
"iter"
"math/bits"
"testing"
)
// allPossibleValuesRejection has identical behavior to allPossibleValues
// but it is implemented with an obviously correct rejection based algorithm.
// We use it to test that allPossibleValues.
func allPossibleValuesRejection(value, known, max int64) func(yield func(v int64) bool) {
return func(yield func(v int64) bool) {
for i := int64(0); i <= max; {
if i&known == value {
if !yield(i) {
return
}
}
next, overflow := bits.Add64(uint64(i), 1, 0)
if overflow != 0 {
// exit condition in case the 64th bit is unknown.
break
}
i = int64(next)
}
}
}
func TestAllPossibleValues(t *testing.T) {
// We can't test too much since it scales exponentially with the number of unknown bits.
const tryMask = int64(0b0111_1111)
for i := int64(0); uint64(i) <= uint64(tryMask); i++ {
unknown := ^i
known := i | ^tryMask
for value := range allPossibleValuesRejection(0, unknown, tryMask) { // don't use allPossibleValues since it's what we are about to test.
t.Run(fmt.Sprintf("known=%b,value=%b", uint64(known), uint64(value)), func(t *testing.T) {
truth, truthStop := iter.Pull(allPossibleValuesRejection(value, known, tryMask))
defer truthStop()
dut, dutStop := iter.Pull(allPossibleValues(value, known))
defer dutStop()
for i := int64(0); ; i++ {
want, wantOk := truth()
got, gotOk := dut()
if wantOk != gotOk {
t.Fatalf("unexpected ok at iteration %d: got %v, want %v", i, gotOk, wantOk)
}
if !gotOk {
break
}
if got != want {
t.Errorf("unexpected value at iteration %d: got %b, want %b", i, uint64(got), uint64(want))
}
}
})
}
}
}