cmd/compile: add generic rules to remove bool → int → bool roundtrips

Change-Id: I8b0a3b64c89fe167d304f901a5d38470f35400ab
Reviewed-on: https://go-review.googlesource.com/c/go/+/715200
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Keith Randall <khr@google.com>
Auto-Submit: Jorropo <jorropo.pgm@gmail.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
This commit is contained in:
Jorropo 2025-10-27 13:05:41 +01:00
parent 1662d55247
commit 73d7635fae
3 changed files with 592 additions and 0 deletions

View file

@ -2984,3 +2984,17 @@
// if b { x >>= 1 } => x >>= b // if b { x >>= 1 } => x >>= b
(CondSelect (Rsh(64|32|16|8)x64 x (Const64 [1])) x bool) => (Rsh(64|32|16|8)x8 [true] x (CvtBoolToUint8 <types.Types[types.TUINT8]> bool)) (CondSelect (Rsh(64|32|16|8)x64 x (Const64 [1])) x bool) => (Rsh(64|32|16|8)x8 [true] x (CvtBoolToUint8 <types.Types[types.TUINT8]> bool))
(CondSelect (Rsh(64|32|16|8)Ux64 x (Const64 [1])) x bool) => (Rsh(64|32|16|8)Ux8 [true] x (CvtBoolToUint8 <types.Types[types.TUINT8]> bool)) (CondSelect (Rsh(64|32|16|8)Ux64 x (Const64 [1])) x bool) => (Rsh(64|32|16|8)Ux8 [true] x (CvtBoolToUint8 <types.Types[types.TUINT8]> bool))
// bool(int(x)) => x
(Neq8 (CvtBoolToUint8 x) (Const8 [0])) => x
(Neq8 (CvtBoolToUint8 x) (Const8 [1])) => (Not x)
(Eq8 (CvtBoolToUint8 x) (Const8 [1])) => x
(Eq8 (CvtBoolToUint8 x) (Const8 [0])) => (Not x)
(Neq(64|32|16) (ZeroExt8to(64|32|16) (CvtBoolToUint8 x)) (Const(64|32|16) [0])) => x
(Neq(64|32|16) (ZeroExt8to(64|32|16) (CvtBoolToUint8 x)) (Const(64|32|16) [1])) => (Not x)
(Eq(64|32|16) (ZeroExt8to(64|32|16) (CvtBoolToUint8 x)) (Const(64|32|16) [1])) => x
(Eq(64|32|16) (ZeroExt8to(64|32|16) (CvtBoolToUint8 x)) (Const(64|32|16) [0])) => (Not x)
(Neq(64|32|16) (SignExt8to(64|32|16) (CvtBoolToUint8 x)) (Const(64|32|16) [0])) => x
(Neq(64|32|16) (SignExt8to(64|32|16) (CvtBoolToUint8 x)) (Const(64|32|16) [1])) => (Not x)
(Eq(64|32|16) (SignExt8to(64|32|16) (CvtBoolToUint8 x)) (Const(64|32|16) [1])) => x
(Eq(64|32|16) (SignExt8to(64|32|16) (CvtBoolToUint8 x)) (Const(64|32|16) [0])) => (Not x)

View file

@ -8847,6 +8847,88 @@ func rewriteValuegeneric_OpEq16(v *Value) bool {
} }
break break
} }
// match: (Eq16 (ZeroExt8to16 (CvtBoolToUint8 x)) (Const16 [1]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpZeroExt8to16 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst16 || auxIntToInt16(v_1.AuxInt) != 1 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Eq16 (ZeroExt8to16 (CvtBoolToUint8 x)) (Const16 [0]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpZeroExt8to16 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst16 || auxIntToInt16(v_1.AuxInt) != 0 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
// match: (Eq16 (SignExt8to16 (CvtBoolToUint8 x)) (Const16 [1]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpSignExt8to16 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst16 || auxIntToInt16(v_1.AuxInt) != 1 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Eq16 (SignExt8to16 (CvtBoolToUint8 x)) (Const16 [0]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpSignExt8to16 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst16 || auxIntToInt16(v_1.AuxInt) != 0 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
return false return false
} }
func rewriteValuegeneric_OpEq32(v *Value) bool { func rewriteValuegeneric_OpEq32(v *Value) bool {
@ -9711,6 +9793,88 @@ func rewriteValuegeneric_OpEq32(v *Value) bool {
} }
break break
} }
// match: (Eq32 (ZeroExt8to32 (CvtBoolToUint8 x)) (Const32 [1]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpZeroExt8to32 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst32 || auxIntToInt32(v_1.AuxInt) != 1 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Eq32 (ZeroExt8to32 (CvtBoolToUint8 x)) (Const32 [0]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpZeroExt8to32 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst32 || auxIntToInt32(v_1.AuxInt) != 0 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
// match: (Eq32 (SignExt8to32 (CvtBoolToUint8 x)) (Const32 [1]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpSignExt8to32 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst32 || auxIntToInt32(v_1.AuxInt) != 1 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Eq32 (SignExt8to32 (CvtBoolToUint8 x)) (Const32 [0]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpSignExt8to32 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst32 || auxIntToInt32(v_1.AuxInt) != 0 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
return false return false
} }
func rewriteValuegeneric_OpEq32F(v *Value) bool { func rewriteValuegeneric_OpEq32F(v *Value) bool {
@ -10292,6 +10456,88 @@ func rewriteValuegeneric_OpEq64(v *Value) bool {
} }
break break
} }
// match: (Eq64 (ZeroExt8to64 (CvtBoolToUint8 x)) (Const64 [1]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpZeroExt8to64 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst64 || auxIntToInt64(v_1.AuxInt) != 1 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Eq64 (ZeroExt8to64 (CvtBoolToUint8 x)) (Const64 [0]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpZeroExt8to64 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst64 || auxIntToInt64(v_1.AuxInt) != 0 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
// match: (Eq64 (SignExt8to64 (CvtBoolToUint8 x)) (Const64 [1]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpSignExt8to64 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst64 || auxIntToInt64(v_1.AuxInt) != 1 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Eq64 (SignExt8to64 (CvtBoolToUint8 x)) (Const64 [0]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpSignExt8to64 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst64 || auxIntToInt64(v_1.AuxInt) != 0 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
return false return false
} }
func rewriteValuegeneric_OpEq64F(v *Value) bool { func rewriteValuegeneric_OpEq64F(v *Value) bool {
@ -10714,6 +10960,39 @@ func rewriteValuegeneric_OpEq8(v *Value) bool {
} }
break break
} }
// match: (Eq8 (CvtBoolToUint8 x) (Const8 [1]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0.Args[0]
if v_1.Op != OpConst8 || auxIntToInt8(v_1.AuxInt) != 1 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Eq8 (CvtBoolToUint8 x) (Const8 [0]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0.Args[0]
if v_1.Op != OpConst8 || auxIntToInt8(v_1.AuxInt) != 0 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
return false return false
} }
func rewriteValuegeneric_OpEqB(v *Value) bool { func rewriteValuegeneric_OpEqB(v *Value) bool {
@ -20214,6 +20493,88 @@ func rewriteValuegeneric_OpNeq16(v *Value) bool {
} }
break break
} }
// match: (Neq16 (ZeroExt8to16 (CvtBoolToUint8 x)) (Const16 [0]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpZeroExt8to16 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst16 || auxIntToInt16(v_1.AuxInt) != 0 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Neq16 (ZeroExt8to16 (CvtBoolToUint8 x)) (Const16 [1]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpZeroExt8to16 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst16 || auxIntToInt16(v_1.AuxInt) != 1 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
// match: (Neq16 (SignExt8to16 (CvtBoolToUint8 x)) (Const16 [0]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpSignExt8to16 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst16 || auxIntToInt16(v_1.AuxInt) != 0 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Neq16 (SignExt8to16 (CvtBoolToUint8 x)) (Const16 [1]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpSignExt8to16 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst16 || auxIntToInt16(v_1.AuxInt) != 1 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
return false return false
} }
func rewriteValuegeneric_OpNeq32(v *Value) bool { func rewriteValuegeneric_OpNeq32(v *Value) bool {
@ -20401,6 +20762,88 @@ func rewriteValuegeneric_OpNeq32(v *Value) bool {
} }
break break
} }
// match: (Neq32 (ZeroExt8to32 (CvtBoolToUint8 x)) (Const32 [0]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpZeroExt8to32 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst32 || auxIntToInt32(v_1.AuxInt) != 0 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Neq32 (ZeroExt8to32 (CvtBoolToUint8 x)) (Const32 [1]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpZeroExt8to32 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst32 || auxIntToInt32(v_1.AuxInt) != 1 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
// match: (Neq32 (SignExt8to32 (CvtBoolToUint8 x)) (Const32 [0]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpSignExt8to32 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst32 || auxIntToInt32(v_1.AuxInt) != 0 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Neq32 (SignExt8to32 (CvtBoolToUint8 x)) (Const32 [1]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpSignExt8to32 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst32 || auxIntToInt32(v_1.AuxInt) != 1 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
return false return false
} }
func rewriteValuegeneric_OpNeq32F(v *Value) bool { func rewriteValuegeneric_OpNeq32F(v *Value) bool {
@ -20611,6 +21054,88 @@ func rewriteValuegeneric_OpNeq64(v *Value) bool {
} }
break break
} }
// match: (Neq64 (ZeroExt8to64 (CvtBoolToUint8 x)) (Const64 [0]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpZeroExt8to64 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst64 || auxIntToInt64(v_1.AuxInt) != 0 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Neq64 (ZeroExt8to64 (CvtBoolToUint8 x)) (Const64 [1]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpZeroExt8to64 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst64 || auxIntToInt64(v_1.AuxInt) != 1 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
// match: (Neq64 (SignExt8to64 (CvtBoolToUint8 x)) (Const64 [0]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpSignExt8to64 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst64 || auxIntToInt64(v_1.AuxInt) != 0 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Neq64 (SignExt8to64 (CvtBoolToUint8 x)) (Const64 [1]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpSignExt8to64 {
continue
}
v_0_0 := v_0.Args[0]
if v_0_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0_0.Args[0]
if v_1.Op != OpConst64 || auxIntToInt64(v_1.AuxInt) != 1 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
return false return false
} }
func rewriteValuegeneric_OpNeq64F(v *Value) bool { func rewriteValuegeneric_OpNeq64F(v *Value) bool {
@ -20821,6 +21346,39 @@ func rewriteValuegeneric_OpNeq8(v *Value) bool {
} }
break break
} }
// match: (Neq8 (CvtBoolToUint8 x) (Const8 [0]))
// result: x
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0.Args[0]
if v_1.Op != OpConst8 || auxIntToInt8(v_1.AuxInt) != 0 {
continue
}
v.copyOf(x)
return true
}
break
}
// match: (Neq8 (CvtBoolToUint8 x) (Const8 [1]))
// result: (Not x)
for {
for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
if v_0.Op != OpCvtBoolToUint8 {
continue
}
x := v_0.Args[0]
if v_1.Op != OpConst8 || auxIntToInt8(v_1.AuxInt) != 1 {
continue
}
v.reset(OpNot)
v.AddArg(x)
return true
}
break
}
return false return false
} }
func rewriteValuegeneric_OpNeqB(v *Value) bool { func rewriteValuegeneric_OpNeqB(v *Value) bool {

View file

@ -507,3 +507,23 @@ func cmovmathhalveu(a uint, b bool) uint {
// wasm:"I64ShrU", -"Select" // wasm:"I64ShrU", -"Select"
return a return a
} }
func branchlessBoolToUint8(b bool) (r uint8) {
if b {
r = 1
}
return
}
func cmovFromMulFromFlags64(x uint64, b bool) uint64 {
// amd64:-"MOVB.ZX"
r := uint64(branchlessBoolToUint8(b))
// amd64:"CMOV",-"MOVB.ZX",-"MUL"
return x * r
}
func cmovFromMulFromFlags64sext(x int64, b bool) int64 {
// amd64:-"MOVB.ZX"
r := int64(int8(branchlessBoolToUint8(b)))
// amd64:"CMOV",-"MOVB.ZX",-"MUL"
return x * r
}