cmd/compile: optimize comparisons with single bit difference

Optimize pairs of equality comparisons of the same value against constants
that differ in exactly one bit (i.e. whose XOR is a power of 2). For example:

    x == 4 || x == 6 -> x|2 == 6
    x != 1 && x != 5 -> x|4 != 5

Change-Id: Ic61719e5118446d21cf15652d9da22f7d95b2a15
Reviewed-on: https://go-review.googlesource.com/c/go/+/719420
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Junyang Shao <shaojunyang@google.com>
Auto-Submit: Keith Randall <khr@golang.org>
Reviewed-by: Keith Randall <khr@golang.org>
Reviewed-by: Keith Randall <khr@google.com>
This commit is contained in:
Michael Munday 2025-08-26 21:17:36 +01:00 committed by Gopher Robot
parent 1e5e6663e9
commit 0a569528ea
5 changed files with 530 additions and 1 deletions

View file

@ -337,6 +337,12 @@
(OrB ((Less|Leq)16U (Const16 [c]) x) (Leq16U x (Const16 [d]))) && uint16(c) >= uint16(d+1) && uint16(d+1) > uint16(d) => ((Less|Leq)16U (Const16 <x.Type> [c-d-1]) (Sub16 <x.Type> x (Const16 <x.Type> [d+1])))
(OrB ((Less|Leq)8U (Const8 [c]) x) (Leq8U x (Const8 [d]))) && uint8(c) >= uint8(d+1) && uint8(d+1) > uint8(d) => ((Less|Leq)8U (Const8 <x.Type> [c-d-1]) (Sub8 <x.Type> x (Const8 <x.Type> [d+1])))
// single bit difference: ( x != c && x != d ) -> ( x|(c^d) != c )
(AndB (Neq(64|32|16|8) x cv:(Const(64|32|16|8) [c])) (Neq(64|32|16|8) x (Const(64|32|16|8) [d]))) && c|d == c && oneBit(c^d) => (Neq(64|32|16|8) (Or(64|32|16|8) <x.Type> x (Const(64|32|16|8) <x.Type> [c^d])) cv)
// single bit difference: ( x == c || x == d ) -> ( x|(c^d) == c )
(OrB (Eq(64|32|16|8) x cv:(Const(64|32|16|8) [c])) (Eq(64|32|16|8) x (Const(64|32|16|8) [d]))) && c|d == c && oneBit(c^d) => (Eq(64|32|16|8) (Or(64|32|16|8) <x.Type> x (Const(64|32|16|8) <x.Type> [c^d])) cv)
// NaN check: ( x != x || x (>|>=|<|<=) c ) -> ( !(c (>=|>|<=|<) x) )
(OrB (Neq64F x x) ((Less|Leq)64F x y:(Const64F [c]))) => (Not ((Leq|Less)64F y x))
(OrB (Neq64F x x) ((Less|Leq)64F y:(Const64F [c]) x)) => (Not ((Leq|Less)64F x y))

View file

@ -10,7 +10,9 @@ import (
)
// fuseEarly runs fuse(f, fuseTypePlain|fuseTypeIntInRange|fuseTypeSingleBitDifference|fuseTypeNanCheck).
func fuseEarly(f *Func) { fuse(f, fuseTypePlain|fuseTypeIntInRange|fuseTypeNanCheck) }
func fuseEarly(f *Func) {
	// Fusion kinds requested for the early pass: plain, integer range,
	// single-bit difference and NaN check.
	const early = fuseTypePlain | fuseTypeIntInRange | fuseTypeSingleBitDifference | fuseTypeNanCheck
	fuse(f, early)
}
// fuseLate invokes fuse on f with the plain, if and branch-redirect fusion
// types enabled.
func fuseLate(f *Func) {
	fuse(f, fuseTypePlain|fuseTypeIf|fuseTypeBranchRedirect)
}
@ -21,6 +23,7 @@ const (
fuseTypePlain fuseType = 1 << iota
fuseTypeIf
fuseTypeIntInRange
fuseTypeSingleBitDifference
fuseTypeNanCheck
fuseTypeBranchRedirect
fuseTypeShortCircuit
@ -41,6 +44,9 @@ func fuse(f *Func, typ fuseType) {
if typ&fuseTypeIntInRange != 0 {
changed = fuseIntInRange(b) || changed
}
if typ&fuseTypeSingleBitDifference != 0 {
changed = fuseSingleBitDifference(b) || changed
}
if typ&fuseTypeNanCheck != 0 {
changed = fuseNanCheck(b) || changed
}

View file

@ -19,6 +19,14 @@ func fuseNanCheck(b *Block) bool {
return fuseComparisons(b, canOptNanCheck)
}
// fuseSingleBitDifference turns short-circuit pairs of equality tests against
// constants that differ in exactly one bit into their non-short-circuit form.
// For instance `if x == 4 || x == 6 { ... }` becomes `if (x == 4) | (x == 6) { ... }`.
// The rewrite rules can subsequently collapse that into a single bitwise test,
// here `if x|2 == 6 { ... }`.
//
// The result is whatever fuseComparisons reports for b.
func fuseSingleBitDifference(b *Block) bool {
	changed := fuseComparisons(b, canOptSingleBitDifference)
	return changed
}
// fuseComparisons looks for control graphs that match this pattern:
//
// p - predecessor
@ -229,3 +237,40 @@ func canOptNanCheck(x, y *Value, op Op) bool {
}
return false
}
// canOptSingleBitDifference reports whether combining the comparisons x and y
// with op has one of these shapes:
//
//	v == c || v == d
//	v != c && v != d
//
// where c and d are constants whose values differ in exactly one bit.
func canOptSingleBitDifference(x, y *Value, op Op) bool {
	// Determine which boolean operator is acceptable for this kind of
	// comparison: equalities must be joined by ||, inequalities by &&.
	var want Op
	switch x.Op {
	case OpEq64, OpEq32, OpEq16, OpEq8:
		want = OpOrB
	case OpNeq64, OpNeq32, OpNeq16, OpNeq8:
		want = OpAndB
	default:
		return false
	}
	if y.Op != x.Op || op != want {
		return false
	}

	// Locate the constant operand of each comparison.
	cx := getConstIntArgIndex(x)
	if cx < 0 {
		return false
	}
	cy := getConstIntArgIndex(y)
	if cy < 0 {
		return false
	}

	// idx^1 flips between 0 and 1, selecting the operand that is not the
	// constant. Both comparisons must test the very same value.
	if x.Args[cx^1] != y.Args[cy^1] {
		return false
	}

	// NOTE(review): the XOR is taken over the raw AuxInt values. If narrow
	// constants are stored sign-extended, a pair differing only in the top bit
	// of an 8/16/32-bit value would not pass oneBit here — confirm whether
	// that case is intended to be skipped.
	diff := x.Args[cx].AuxInt ^ y.Args[cy].AuxInt
	return oneBit(diff)
}

View file

@ -5332,6 +5332,182 @@ func rewriteValuegeneric_OpAndB(v *Value) bool {
}
break
}
// match: (AndB (Neq64 x cv:(Const64 [c])) (Neq64 x (Const64 [d])))
// cond: c|d == c && oneBit(c^d)
// result: (Neq64 (Or64 <x.Type> x (Const64 <x.Type> [c^d])) cv)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpNeq64 {
continue
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
x := v_0_0
cv := v_0_1
if cv.Op != OpConst64 {
continue
}
c := auxIntToInt64(cv.AuxInt)
if v_1.Op != OpNeq64 {
continue
}
_ = v_1.Args[1]
v_1_0 := v_1.Args[0]
v_1_1 := v_1.Args[1]
for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
if x != v_1_0 || v_1_1.Op != OpConst64 {
continue
}
d := auxIntToInt64(v_1_1.AuxInt)
if !(c|d == c && oneBit(c^d)) {
continue
}
v.reset(OpNeq64)
v0 := b.NewValue0(v.Pos, OpOr64, x.Type)
v1 := b.NewValue0(v.Pos, OpConst64, x.Type)
v1.AuxInt = int64ToAuxInt(c ^ d)
v0.AddArg2(x, v1)
v.AddArg2(v0, cv)
return true
}
}
}
break
}
// match: (AndB (Neq32 x cv:(Const32 [c])) (Neq32 x (Const32 [d])))
// cond: c|d == c && oneBit(c^d)
// result: (Neq32 (Or32 <x.Type> x (Const32 <x.Type> [c^d])) cv)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpNeq32 {
continue
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
x := v_0_0
cv := v_0_1
if cv.Op != OpConst32 {
continue
}
c := auxIntToInt32(cv.AuxInt)
if v_1.Op != OpNeq32 {
continue
}
_ = v_1.Args[1]
v_1_0 := v_1.Args[0]
v_1_1 := v_1.Args[1]
for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
if x != v_1_0 || v_1_1.Op != OpConst32 {
continue
}
d := auxIntToInt32(v_1_1.AuxInt)
if !(c|d == c && oneBit(c^d)) {
continue
}
v.reset(OpNeq32)
v0 := b.NewValue0(v.Pos, OpOr32, x.Type)
v1 := b.NewValue0(v.Pos, OpConst32, x.Type)
v1.AuxInt = int32ToAuxInt(c ^ d)
v0.AddArg2(x, v1)
v.AddArg2(v0, cv)
return true
}
}
}
break
}
// match: (AndB (Neq16 x cv:(Const16 [c])) (Neq16 x (Const16 [d])))
// cond: c|d == c && oneBit(c^d)
// result: (Neq16 (Or16 <x.Type> x (Const16 <x.Type> [c^d])) cv)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpNeq16 {
continue
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
x := v_0_0
cv := v_0_1
if cv.Op != OpConst16 {
continue
}
c := auxIntToInt16(cv.AuxInt)
if v_1.Op != OpNeq16 {
continue
}
_ = v_1.Args[1]
v_1_0 := v_1.Args[0]
v_1_1 := v_1.Args[1]
for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
if x != v_1_0 || v_1_1.Op != OpConst16 {
continue
}
d := auxIntToInt16(v_1_1.AuxInt)
if !(c|d == c && oneBit(c^d)) {
continue
}
v.reset(OpNeq16)
v0 := b.NewValue0(v.Pos, OpOr16, x.Type)
v1 := b.NewValue0(v.Pos, OpConst16, x.Type)
v1.AuxInt = int16ToAuxInt(c ^ d)
v0.AddArg2(x, v1)
v.AddArg2(v0, cv)
return true
}
}
}
break
}
// match: (AndB (Neq8 x cv:(Const8 [c])) (Neq8 x (Const8 [d])))
// cond: c|d == c && oneBit(c^d)
// result: (Neq8 (Or8 <x.Type> x (Const8 <x.Type> [c^d])) cv)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpNeq8 {
continue
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
x := v_0_0
cv := v_0_1
if cv.Op != OpConst8 {
continue
}
c := auxIntToInt8(cv.AuxInt)
if v_1.Op != OpNeq8 {
continue
}
_ = v_1.Args[1]
v_1_0 := v_1.Args[0]
v_1_1 := v_1.Args[1]
for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
if x != v_1_0 || v_1_1.Op != OpConst8 {
continue
}
d := auxIntToInt8(v_1_1.AuxInt)
if !(c|d == c && oneBit(c^d)) {
continue
}
v.reset(OpNeq8)
v0 := b.NewValue0(v.Pos, OpOr8, x.Type)
v1 := b.NewValue0(v.Pos, OpConst8, x.Type)
v1.AuxInt = int8ToAuxInt(c ^ d)
v0.AddArg2(x, v1)
v.AddArg2(v0, cv)
return true
}
}
}
break
}
return false
}
func rewriteValuegeneric_OpArraySelect(v *Value) bool {
@ -23242,6 +23418,182 @@ func rewriteValuegeneric_OpOrB(v *Value) bool {
}
break
}
// match: (OrB (Eq64 x cv:(Const64 [c])) (Eq64 x (Const64 [d])))
// cond: c|d == c && oneBit(c^d)
// result: (Eq64 (Or64 <x.Type> x (Const64 <x.Type> [c^d])) cv)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpEq64 {
continue
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
x := v_0_0
cv := v_0_1
if cv.Op != OpConst64 {
continue
}
c := auxIntToInt64(cv.AuxInt)
if v_1.Op != OpEq64 {
continue
}
_ = v_1.Args[1]
v_1_0 := v_1.Args[0]
v_1_1 := v_1.Args[1]
for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
if x != v_1_0 || v_1_1.Op != OpConst64 {
continue
}
d := auxIntToInt64(v_1_1.AuxInt)
if !(c|d == c && oneBit(c^d)) {
continue
}
v.reset(OpEq64)
v0 := b.NewValue0(v.Pos, OpOr64, x.Type)
v1 := b.NewValue0(v.Pos, OpConst64, x.Type)
v1.AuxInt = int64ToAuxInt(c ^ d)
v0.AddArg2(x, v1)
v.AddArg2(v0, cv)
return true
}
}
}
break
}
// match: (OrB (Eq32 x cv:(Const32 [c])) (Eq32 x (Const32 [d])))
// cond: c|d == c && oneBit(c^d)
// result: (Eq32 (Or32 <x.Type> x (Const32 <x.Type> [c^d])) cv)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpEq32 {
continue
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
x := v_0_0
cv := v_0_1
if cv.Op != OpConst32 {
continue
}
c := auxIntToInt32(cv.AuxInt)
if v_1.Op != OpEq32 {
continue
}
_ = v_1.Args[1]
v_1_0 := v_1.Args[0]
v_1_1 := v_1.Args[1]
for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
if x != v_1_0 || v_1_1.Op != OpConst32 {
continue
}
d := auxIntToInt32(v_1_1.AuxInt)
if !(c|d == c && oneBit(c^d)) {
continue
}
v.reset(OpEq32)
v0 := b.NewValue0(v.Pos, OpOr32, x.Type)
v1 := b.NewValue0(v.Pos, OpConst32, x.Type)
v1.AuxInt = int32ToAuxInt(c ^ d)
v0.AddArg2(x, v1)
v.AddArg2(v0, cv)
return true
}
}
}
break
}
// match: (OrB (Eq16 x cv:(Const16 [c])) (Eq16 x (Const16 [d])))
// cond: c|d == c && oneBit(c^d)
// result: (Eq16 (Or16 <x.Type> x (Const16 <x.Type> [c^d])) cv)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpEq16 {
continue
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
x := v_0_0
cv := v_0_1
if cv.Op != OpConst16 {
continue
}
c := auxIntToInt16(cv.AuxInt)
if v_1.Op != OpEq16 {
continue
}
_ = v_1.Args[1]
v_1_0 := v_1.Args[0]
v_1_1 := v_1.Args[1]
for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
if x != v_1_0 || v_1_1.Op != OpConst16 {
continue
}
d := auxIntToInt16(v_1_1.AuxInt)
if !(c|d == c && oneBit(c^d)) {
continue
}
v.reset(OpEq16)
v0 := b.NewValue0(v.Pos, OpOr16, x.Type)
v1 := b.NewValue0(v.Pos, OpConst16, x.Type)
v1.AuxInt = int16ToAuxInt(c ^ d)
v0.AddArg2(x, v1)
v.AddArg2(v0, cv)
return true
}
}
}
break
}
// match: (OrB (Eq8 x cv:(Const8 [c])) (Eq8 x (Const8 [d])))
// cond: c|d == c && oneBit(c^d)
// result: (Eq8 (Or8 <x.Type> x (Const8 <x.Type> [c^d])) cv)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpEq8 {
continue
}
_ = v_0.Args[1]
v_0_0 := v_0.Args[0]
v_0_1 := v_0.Args[1]
for _i1 := 0; _i1 <= 1; _i1, v_0_0, v_0_1 = _i1+1, v_0_1, v_0_0 {
x := v_0_0
cv := v_0_1
if cv.Op != OpConst8 {
continue
}
c := auxIntToInt8(cv.AuxInt)
if v_1.Op != OpEq8 {
continue
}
_ = v_1.Args[1]
v_1_0 := v_1.Args[0]
v_1_1 := v_1.Args[1]
for _i2 := 0; _i2 <= 1; _i2, v_1_0, v_1_1 = _i2+1, v_1_1, v_1_0 {
if x != v_1_0 || v_1_1.Op != OpConst8 {
continue
}
d := auxIntToInt8(v_1_1.AuxInt)
if !(c|d == c && oneBit(c^d)) {
continue
}
v.reset(OpEq8)
v0 := b.NewValue0(v.Pos, OpOr8, x.Type)
v1 := b.NewValue0(v.Pos, OpConst8, x.Type)
v1.AuxInt = int8ToAuxInt(c ^ d)
v0.AddArg2(x, v1)
v.AddArg2(v0, cv)
return true
}
}
}
break
}
// match: (OrB (Neq64F x x) (Less64F x y:(Const64F [c])))
// result: (Not (Leq64F y x))
for {

View file

@ -198,6 +198,126 @@ func ui4d(c <-chan uint8) {
}
}
// ------------------------------------ //
// single bit difference (conjunction) //
// ------------------------------------ //
// sisbc64 covers the conjunction form on int64: x != 4 && x != 6, whose
// constants differ only in the bit with value 2. The arch-tagged comments name
// the OR-immediate instruction expected in the generated code.
func sisbc64(c <-chan int64) {
// amd64: "ORQ [$]2,"
// riscv64: "ORI [$]2,"
for x := <-c; x != 4 && x != 6; x = <-c {
}
}
// sisbc32 covers the conjunction form on int32 with negative constants:
// x != -1 && x != -5, which differ only in the bit with value 4. The
// arch-tagged comments name the OR-immediate instruction expected.
func sisbc32(c <-chan int32) {
// amd64: "ORL [$]4,"
// riscv64: "ORI [$]4,"
for x := <-c; x != -1 && x != -5; x = <-c {
}
}
// sisbc16 covers the conjunction form on int16: x != 16 && x != 48, whose
// constants differ only in the bit with value 32. The arch-tagged comments
// name the OR-immediate instruction expected.
func sisbc16(c <-chan int16) {
// amd64: "ORL [$]32,"
// riscv64: "ORI [$]32,"
for x := <-c; x != 16 && x != 48; x = <-c {
}
}
// sisbc8 covers the conjunction form on int8 with negative constants:
// x != -15 && x != -31, which differ only in the bit with value 16. The
// arch-tagged comments name the OR-immediate instruction expected.
func sisbc8(c <-chan int8) {
// amd64: "ORL [$]16,"
// riscv64: "ORI [$]16,"
for x := <-c; x != -15 && x != -31; x = <-c {
}
}
// uisbc64 covers the conjunction form on uint64: x != 1 && x != 5, whose
// constants differ only in the bit with value 4. The arch-tagged comments name
// the OR-immediate instruction expected.
func uisbc64(c <-chan uint64) {
// amd64: "ORQ [$]4,"
// riscv64: "ORI [$]4,"
for x := <-c; x != 1 && x != 5; x = <-c {
}
}
// uisbc32 covers the conjunction form on uint32: x != 2 && x != 6, whose
// constants differ only in the bit with value 4. The arch-tagged comments name
// the OR-immediate instruction expected.
func uisbc32(c <-chan uint32) {
// amd64: "ORL [$]4,"
// riscv64: "ORI [$]4,"
for x := <-c; x != 2 && x != 6; x = <-c {
}
}
// uisbc16 covers the conjunction form on uint16: x != 16 && x != 48, whose
// constants differ only in the bit with value 32. The arch-tagged comments
// name the OR-immediate instruction expected.
func uisbc16(c <-chan uint16) {
// amd64: "ORL [$]32,"
// riscv64: "ORI [$]32,"
for x := <-c; x != 16 && x != 48; x = <-c {
}
}
// uisbc8 covers the conjunction form on uint8 where one constant is zero:
// x != 64 && x != 0, which differ only in the bit with value 64. The
// arch-tagged comments name the OR-immediate instruction expected.
func uisbc8(c <-chan uint8) {
// amd64: "ORL [$]64,"
// riscv64: "ORI [$]64,"
for x := <-c; x != 64 && x != 0; x = <-c {
}
}
// ------------------------------------ //
// single bit difference (disjunction) //
// ------------------------------------ //
// sisbd64 covers the disjunction form on int64: x == 4 || x == 6, whose
// constants differ only in the bit with value 2. The arch-tagged comments name
// the OR-immediate instruction expected in the generated code.
func sisbd64(c <-chan int64) {
// amd64: "ORQ [$]2,"
// riscv64: "ORI [$]2,"
for x := <-c; x == 4 || x == 6; x = <-c {
}
}
// sisbd32 covers the disjunction form on int32 with negative constants:
// x == -1 || x == -5, which differ only in the bit with value 4. The
// arch-tagged comments name the OR-immediate instruction expected.
func sisbd32(c <-chan int32) {
// amd64: "ORL [$]4,"
// riscv64: "ORI [$]4,"
for x := <-c; x == -1 || x == -5; x = <-c {
}
}
// sisbd16 covers the disjunction form on int16: x == 16 || x == 48, whose
// constants differ only in the bit with value 32. The arch-tagged comments
// name the OR-immediate instruction expected.
func sisbd16(c <-chan int16) {
// amd64: "ORL [$]32,"
// riscv64: "ORI [$]32,"
for x := <-c; x == 16 || x == 48; x = <-c {
}
}
// sisbd8 covers the disjunction form on int8 with negative constants:
// x == -15 || x == -31, which differ only in the bit with value 16. The
// arch-tagged comments name the OR-immediate instruction expected.
func sisbd8(c <-chan int8) {
// amd64: "ORL [$]16,"
// riscv64: "ORI [$]16,"
for x := <-c; x == -15 || x == -31; x = <-c {
}
}
// uisbd64 covers the disjunction form on uint64: x == 1 || x == 5, whose
// constants differ only in the bit with value 4. The arch-tagged comments name
// the OR-immediate instruction expected.
func uisbd64(c <-chan uint64) {
// amd64: "ORQ [$]4,"
// riscv64: "ORI [$]4,"
for x := <-c; x == 1 || x == 5; x = <-c {
}
}
// uisbd32 covers the disjunction form on uint32: x == 2 || x == 6, whose
// constants differ only in the bit with value 4. The arch-tagged comments name
// the OR-immediate instruction expected.
func uisbd32(c <-chan uint32) {
// amd64: "ORL [$]4,"
// riscv64: "ORI [$]4,"
for x := <-c; x == 2 || x == 6; x = <-c {
}
}
// uisbd16 covers the disjunction form on uint16: x == 16 || x == 48, whose
// constants differ only in the bit with value 32. The arch-tagged comments
// name the OR-immediate instruction expected.
func uisbd16(c <-chan uint16) {
// amd64: "ORL [$]32,"
// riscv64: "ORI [$]32,"
for x := <-c; x == 16 || x == 48; x = <-c {
}
}
// uisbd8 covers the disjunction form on uint8 where one constant is zero:
// x == 64 || x == 0, which differ only in the bit with value 64. The
// arch-tagged comments name the OR-immediate instruction expected.
func uisbd8(c <-chan uint8) {
// amd64: "ORL [$]64,"
// riscv64: "ORI [$]64,"
for x := <-c; x == 64 || x == 0; x = <-c {
}
}
// -------------------------------------//
// merge NaN checks //
// ------------------------------------ //