cmd/compile: fold == != with a const and a bijective operation into the const

This extends a pattern we already match for Add* to
- Sub
- Sub (with swapped arguments)
- Xor
- Com
- Neg
- Mul

This more or less equates to constant folding and is particularly hard to
benchmark objectively for the same reasons.

It is 1 or 3 (for mul) cycles faster in a microbenchmark.

However it may require constants that are harder to materialize.

We currently do not consider these drawbacks in generic.rules.

I didn't originally thought the o.Uses == 1 was required however
certain arches like PPC64 are able to merge the CMP into the operation
in limited conditions which are broken by this CL.

Also if o.Uses == 1 we aren't removing a user, we could extand the
liveness of o's argument, without removing o increasing register pressure.

The latency gains should be invisible on branches, maybe not if used by
CondSelect or CvtBoolToUint8, but don't bother with theses unproven
dices.

Change-Id: I4fe6b5149576d2549e1157e5cc891af9edb79d55
Reviewed-on: https://go-review.googlesource.com/c/go/+/750181
Reviewed-by: Keith Randall <khr@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Jorropo <jorropo.pgm@gmail.com>
LUCI-TryBot-Result: golang-scoped@luci-project-accounts.iam.gserviceaccount.com <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Junyang Shao <shaojunyang@google.com>
This commit is contained in:
Jorropo 2026-02-28 09:34:09 +01:00 committed by Gopher Robot
parent 75560e67c9
commit fabaedcbe8
5 changed files with 1691 additions and 82 deletions

View file

@ -289,20 +289,34 @@
(NeqB (ConstBool [true]) x) => (Not x)
(NeqB (Not x) y) => (EqB x y)
(Eq64 (Const64 <t> [c]) (Add64 (Const64 <t> [d]) x)) => (Eq64 (Const64 <t> [c-d]) x)
(Eq32 (Const32 <t> [c]) (Add32 (Const32 <t> [d]) x)) => (Eq32 (Const32 <t> [c-d]) x)
(Eq16 (Const16 <t> [c]) (Add16 (Const16 <t> [d]) x)) => (Eq16 (Const16 <t> [c-d]) x)
(Eq8 (Const8 <t> [c]) (Add8 (Const8 <t> [d]) x)) => (Eq8 (Const8 <t> [c-d]) x)
(Neq64 (Const64 <t> [c]) (Add64 (Const64 <t> [d]) x)) => (Neq64 (Const64 <t> [c-d]) x)
(Neq32 (Const32 <t> [c]) (Add32 (Const32 <t> [d]) x)) => (Neq32 (Const32 <t> [c-d]) x)
(Neq16 (Const16 <t> [c]) (Add16 (Const16 <t> [d]) x)) => (Neq16 (Const16 <t> [c-d]) x)
(Neq8 (Const8 <t> [c]) (Add8 (Const8 <t> [d]) x)) => (Neq8 (Const8 <t> [c-d]) x)
(CondSelect x _ (ConstBool [true ])) => x
(CondSelect _ y (ConstBool [false])) => y
(CondSelect x x _) => x
// fold eq / neq between a constant and a compile time bijective operation into the constant.
(Eq(64|32|16|8) (Const(64|32|16|8) <t> [c]) o:(Add(64|32|16|8) (Const(64|32|16|8) [d]) x)) && o.Uses == 1 => (Eq(64|32|16|8) (Const(64|32|16|8) <t> [c-d]) x)
(Neq(64|32|16|8) (Const(64|32|16|8) <t> [c]) o:(Add(64|32|16|8) (Const(64|32|16|8) [d]) x)) && o.Uses == 1 => (Neq(64|32|16|8) (Const(64|32|16|8) <t> [c-d]) x)
(Eq(64|32|16|8) (Const(64|32|16|8) <t> [c]) o:(Sub(64|32|16|8) x (Const(64|32|16|8) [d]))) && o.Uses == 1 => (Eq(64|32|16|8) (Const(64|32|16|8) <t> [c+d]) x)
(Neq(64|32|16|8) (Const(64|32|16|8) <t> [c]) o:(Sub(64|32|16|8) x (Const(64|32|16|8) [d]))) && o.Uses == 1 => (Neq(64|32|16|8) (Const(64|32|16|8) <t> [c+d]) x)
(Eq(64|32|16|8) (Const(64|32|16|8) <t> [c]) o:(Sub(64|32|16|8) (Const(64|32|16|8) [d]) x)) && o.Uses == 1 => (Eq(64|32|16|8) (Const(64|32|16|8) <t> [d-c]) x)
(Neq(64|32|16|8) (Const(64|32|16|8) <t> [c]) o:(Sub(64|32|16|8) (Const(64|32|16|8) [d]) x)) && o.Uses == 1 => (Neq(64|32|16|8) (Const(64|32|16|8) <t> [d-c]) x)
(Eq(64|32|16|8) (Const(64|32|16|8) <t> [c]) o:(Xor(64|32|16|8) (Const(64|32|16|8) [d]) x)) && o.Uses == 1 => (Eq(64|32|16|8) (Const(64|32|16|8) <t> [d^c]) x)
(Neq(64|32|16|8) (Const(64|32|16|8) <t> [c]) o:(Xor(64|32|16|8) (Const(64|32|16|8) [d]) x)) && o.Uses == 1 => (Neq(64|32|16|8) (Const(64|32|16|8) <t> [d^c]) x)
(Eq(64|32|16|8) (Const(64|32|16|8) <t> [c]) o:(Com(64|32|16|8) x)) && o.Uses == 1 => (Eq(64|32|16|8) (Const(64|32|16|8) <t> [^c]) x)
(Neq(64|32|16|8) (Const(64|32|16|8) <t> [c]) o:(Com(64|32|16|8) x)) && o.Uses == 1 => (Neq(64|32|16|8) (Const(64|32|16|8) <t> [^c]) x)
(Eq(64|32|16|8) (Const(64|32|16|8) <t> [c]) o:(Neg(64|32|16|8) x)) && o.Uses == 1 => (Eq(64|32|16|8) (Const(64|32|16|8) <t> [-c]) x)
(Neq(64|32|16|8) (Const(64|32|16|8) <t> [c]) o:(Neg(64|32|16|8) x)) && o.Uses == 1 => (Neq(64|32|16|8) (Const(64|32|16|8) <t> [-c]) x)
((Eq|Neq)64 (Const64 <t> [c]) o:(Mul64 (Const64 [d]) x)) && uint64(d)%2 == 1 && o.Uses == 1 => ((Eq|Neq)64 (Const64 <t> [int64(uint64(c) * modularMultiplicativeInverse(uint64(d))) ]) x)
((Eq|Neq)32 (Const32 <t> [c]) o:(Mul32 (Const32 [d]) x)) && uint32(d)%2 == 1 && o.Uses == 1 => ((Eq|Neq)32 (Const32 <t> [int32(uint32(c) * uint32(modularMultiplicativeInverse(uint64(d))))]) x)
((Eq|Neq)16 (Const16 <t> [c]) o:(Mul16 (Const16 [d]) x)) && uint16(d)%2 == 1 && o.Uses == 1 => ((Eq|Neq)16 (Const16 <t> [int16(uint16(c) * uint16(modularMultiplicativeInverse(uint64(d))))]) x)
((Eq|Neq)8 (Const8 <t> [c]) o:(Mul8 (Const8 [d]) x)) && uint8( d)%2 == 1 && o.Uses == 1 => ((Eq|Neq)8 (Const8 <t> [int8( uint8( c) * uint8( modularMultiplicativeInverse(uint64(d))))]) x)
// signed integer range: ( c <= x && x (<|<=) d ) -> ( unsigned(x-c) (<|<=) unsigned(d-c) )
(AndB (Leq64 (Const64 [c]) x) ((Less|Leq)64 x (Const64 [d]))) && d >= c => ((Less|Leq)64U (Sub64 <x.Type> x (Const64 <x.Type> [c])) (Const64 <x.Type> [d-c]))
(AndB (Leq32 (Const32 [c]) x) ((Less|Leq)32 x (Const32 [d]))) && d >= c => ((Less|Leq)32U (Sub32 <x.Type> x (Const32 <x.Type> [c])) (Const32 <x.Type> [d-c]))

View file

@ -2855,3 +2855,19 @@ func addToSub(op Op) Op {
panic(fmt.Sprintf("unexpected op %v", op))
}
}
func modularMultiplicativeInverse(x uint64) (y uint64) {
if x%2 != 1 {
panic("even numbers in a power-of-two modulus do not have a multiplicative inverse")
}
// we start with 3 bits of precision because each odd number is its own multiplicative inverse mod 8
y = x // 3 bits
// now use the Newton-Raphson method to double the number of correct bits in each iteration.
y *= 2 - x*y // 6 bits
y *= 2 - x*y // 12 bits
y *= 2 - x*y // 24 bits
y *= 2 - x*y // 48 bits
y *= 2 - x*y // 96 bits; good enough
return
}

View file

@ -359,3 +359,20 @@ func TestDisjointTypesRun(t *testing.T) {
t.Errorf("disjointTypes gives an incorrect answer that leads to an incorrect optimization.")
}
}
func TestModularMultiplicativeInverse(t *testing.T) {
t.Parallel()
// We've got 63 bits of phase space for the Multiplier
// Needless to say this is too much to bruteforce here.
// I've randomly picked a range of 1<<24 because it runs in 0.03s on my machine which isn't too slow.
// We test both sides of the wrapping point (0 and math.MaxUint64) since we need to test something and it's a usual place to have bugs.
const halfRange = 1 << 23
for i := -int64(halfRange) - 1; i < halfRange; i += 2 { // odd only, a bit after to a bit before the wrapping point
mmi := modularMultiplicativeInverse(uint64(i))
if uint64(i)*mmi != 1 {
t.Errorf("%d * modularMultiplicativeInverse(%d) != 1; modularMultiplicativeInverse(%d) == %d", i, i, i, mmi)
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -930,3 +930,45 @@ func cmpstring2(x, y string) int {
//amd64:-`MOVQ .*\(SP\)`
return cmp.Compare(x, y)
}
func bijectiveAdd(x uint) bool {
// amd64: -"ADD"
// arm64: -"ADD"
return x+1337 == 42
}
func bijectiveSub1(x uint) bool {
// amd64: -"SUB"
// arm64: -"SUB"
return x-1337 == 42
}
func bijectiveSub2(x uint) bool {
// amd64: -"SUB"
// arm64: -"SUB"
return 1337-x == 42
}
func bijectiveXor(x uint) bool {
// amd64: -"XOR"
// arm64: -"EOR"
return x^1337 == 42
}
func bijectiveCom(x uint) bool {
// amd64: -"NOT"
// arm64: -"MVN"
return ^x == 42
}
func bijectiveNeg(x int) bool {
// amd64: -"NEG"
// arm64: -"NEG"
return -x == 42
}
func bijectiveMul(x uint) bool {
// amd64: -"MUL"
// arm64: -"MUL"
return x*1337 == 42
}