diff --git a/src/cmd/compile/internal/ssa/_gen/generic.rules b/src/cmd/compile/internal/ssa/_gen/generic.rules index 6fdea7cc7a..048d9958dc 100644 --- a/src/cmd/compile/internal/ssa/_gen/generic.rules +++ b/src/cmd/compile/internal/ssa/_gen/generic.rules @@ -347,6 +347,22 @@ (OrB ((Less|Leq)16U (Const16 [c]) x) (Leq16U x (Const16 [d]))) && uint16(c) >= uint16(d+1) && uint16(d+1) > uint16(d) => ((Less|Leq)16U (Const16 [c-d-1]) (Sub16 x (Const16 [d+1]))) (OrB ((Less|Leq)8U (Const8 [c]) x) (Leq8U x (Const8 [d]))) && uint8(c) >= uint8(d+1) && uint8(d+1) > uint8(d) => ((Less|Leq)8U (Const8 [c-d-1]) (Sub8 x (Const8 [d+1]))) +// NaN check: ( x != x || x (>|>=|<|<=) c ) -> ( !(c (>=|>|<=|<) x) ) +(OrB (Neq64F x x) ((Less|Leq)64F x y:(Const64F [c]))) => (Not ((Leq|Less)64F y x)) +(OrB (Neq64F x x) ((Less|Leq)64F y:(Const64F [c]) x)) => (Not ((Leq|Less)64F x y)) +(OrB (Neq32F x x) ((Less|Leq)32F x y:(Const32F [c]))) => (Not ((Leq|Less)32F y x)) +(OrB (Neq32F x x) ((Less|Leq)32F y:(Const32F [c]) x)) => (Not ((Leq|Less)32F x y)) + +// NaN check: ( x != x || Abs(x) (>|>=|<|<=) c ) -> ( !(c (>=|>|<=|<) Abs(x)) ) +(OrB (Neq64F x x) ((Less|Leq)64F abs:(Abs x) y:(Const64F [c]))) => (Not ((Leq|Less)64F y abs)) +(OrB (Neq64F x x) ((Less|Leq)64F y:(Const64F [c]) abs:(Abs x))) => (Not ((Leq|Less)64F abs y)) + +// NaN check: ( x != x || -x (>|>=|<|<=) c ) -> ( !(c (>=|>|<=|<) -x) ) +(OrB (Neq64F x x) ((Less|Leq)64F neg:(Neg64F x) y:(Const64F [c]))) => (Not ((Leq|Less)64F y neg)) +(OrB (Neq64F x x) ((Less|Leq)64F y:(Const64F [c]) neg:(Neg64F x))) => (Not ((Leq|Less)64F neg y)) +(OrB (Neq32F x x) ((Less|Leq)32F neg:(Neg32F x) y:(Const32F [c]))) => (Not ((Leq|Less)32F y neg)) +(OrB (Neq32F x x) ((Less|Leq)32F y:(Const32F [c]) neg:(Neg32F x))) => (Not ((Leq|Less)32F neg y)) + // Canonicalize x-const to x+(-const) (Sub64 x (Const64 [c])) && x.Op != OpConst64 => (Add64 (Const64 [-c]) x) (Sub32 x (Const32 [c])) && x.Op != OpConst32 => (Add32 (Const32 [-c]) x) diff --git a/src/cmd/compile/internal/ssa/fuse.go b/src/cmd/compile/internal/ssa/fuse.go index 68defde7b4..0cee91b532 100644 --- a/src/cmd/compile/internal/ssa/fuse.go +++ b/src/cmd/compile/internal/ssa/fuse.go @@ -9,8 +9,8 @@ import ( "fmt" ) -// fuseEarly runs fuse(f, fuseTypePlain|fuseTypeIntInRange). -func fuseEarly(f *Func) { fuse(f, fuseTypePlain|fuseTypeIntInRange) } +// fuseEarly runs fuse(f, fuseTypePlain|fuseTypeIntInRange|fuseTypeNanCheck). +func fuseEarly(f *Func) { fuse(f, fuseTypePlain|fuseTypeIntInRange|fuseTypeNanCheck) } // fuseLate runs fuse(f, fuseTypePlain|fuseTypeIf|fuseTypeBranchRedirect). 
func fuseLate(f *Func) { fuse(f, fuseTypePlain|fuseTypeIf|fuseTypeBranchRedirect) } @@ -21,6 +21,7 @@ const ( fuseTypePlain fuseType = 1 << iota fuseTypeIf fuseTypeIntInRange + fuseTypeNanCheck fuseTypeBranchRedirect fuseTypeShortCircuit ) @@ -38,7 +39,10 @@ func fuse(f *Func, typ fuseType) { changed = fuseBlockIf(b) || changed } if typ&fuseTypeIntInRange != 0 { - changed = fuseIntegerComparisons(b) || changed + changed = fuseIntInRange(b) || changed + } + if typ&fuseTypeNanCheck != 0 { + changed = fuseNanCheck(b) || changed } if typ&fuseTypePlain != 0 { changed = fuseBlockPlain(b) || changed diff --git a/src/cmd/compile/internal/ssa/fuse_comparisons.go b/src/cmd/compile/internal/ssa/fuse_comparisons.go index f5fb84b0d7..b6eb8fcb90 100644 --- a/src/cmd/compile/internal/ssa/fuse_comparisons.go +++ b/src/cmd/compile/internal/ssa/fuse_comparisons.go @@ -4,21 +4,36 @@ package ssa -// fuseIntegerComparisons optimizes inequalities such as '1 <= x && x < 5', -// which can be optimized to 'unsigned(x-1) < 4'. +// fuseIntInRange transforms integer range checks to remove the short-circuit operator. For example, +// it would convert `if 1 <= x && x < 5 { ... }` into `if (1 <= x) & (x < 5) { ... }`. Rewrite rules +// can then optimize these into unsigned range checks, `if unsigned(x-1) < 4 { ... }` in this case. +func fuseIntInRange(b *Block) bool { + return fuseComparisons(b, canOptIntInRange) +} + +// fuseNanCheck replaces the short-circuit operators between NaN checks and comparisons with +// constants. For example, it would transform `if x != x || x > 1.0 { ... }` into +// `if (x != x) | (x > 1.0) { ... }`. Rewrite rules can then merge the NaN check with the comparison, +// in this case generating `if !(x <= 1.0) { ... }`. +func fuseNanCheck(b *Block) bool { + return fuseComparisons(b, canOptNanCheck) +} + +// fuseComparisons looks for control graphs that match this pattern: // -// Look for branch structure like: -// -// p +// p - predecessor // |\ -// | b +// | b - block // |/ \ -// s0 s1 +// s0 s1 - successors // -// In our example, p has control '1 <= x', b has control 'x < 5', -// and s0 and s1 are the if and else results of the comparison. +// This pattern is typical for if statements such as `if x || y { ... }` and `if x && y { ... }`. // -// This will be optimized into: +// If canOptControls returns true when passed the control values for p and b then fuseComparisons +// will try to convert p into a plain block with only one successor (b) and modify b's control +// value to include p's control value (effectively causing b to be speculatively executed). +// +// This transformation results in a control graph that will now look like this: // // p // \ @@ -26,9 +41,12 @@ package ssa // / \ // s0 s1 // -// where b has the combined control value 'unsigned(x-1) < 4'. // Later passes will then fuse p and b. -func fuseIntegerComparisons(b *Block) bool { +// +// In other words `if x || y { ... }` will become `if x | y { ... }` and `if x && y { ... }` will +// become `if x & y { ... }`. This is a useful transformation because we can then use rewrite +// rules to optimize `x | y` and `x & y`. +func fuseComparisons(b *Block, canOptControls func(a, b *Value, op Op) bool) bool { if len(b.Preds) != 1 { return false } @@ -45,14 +63,6 @@ func fuseIntegerComparisons(b *Block) bool { return false } - // Check if the control values combine to make an integer inequality that - // can be further optimized later. 
- bc := b.Controls[0] - pc := p.Controls[0] - if !areMergeableInequalities(bc, pc) { - return false - } - // If the first (true) successors match then we have a disjunction (||). // If the second (false) successors match then we have a conjunction (&&). for i, op := range [2]Op{OpOrB, OpAndB} { @@ -60,6 +70,13 @@ func fuseIntegerComparisons(b *Block) bool { continue } + // Check if the control values can be usefully combined. + bc := b.Controls[0] + pc := p.Controls[0] + if !canOptControls(bc, pc, op) { + return false + } + // TODO(mundaym): should we also check the cost of executing b? // Currently we might speculatively execute b even if b contains // a lot of instructions. We could just check that len(b.Values) @@ -125,7 +142,7 @@ func isUnsignedInequality(v *Value) bool { return false } -func areMergeableInequalities(x, y *Value) bool { +func canOptIntInRange(x, y *Value, op Op) bool { // We need both inequalities to be either in the signed or unsigned domain. // TODO(mundaym): it would also be good to merge when we have an Eq op that // could be transformed into a Less/Leq. For example in the unsigned @@ -155,3 +172,60 @@ func areMergeableInequalities(x, y *Value) bool { } return false } + +// canOptNanCheck reports whether one of the arguments is a NaN check and the other +// is a comparison with a constant that it can be combined with. +// +// Examples (c must be a constant): +// +// v != v || v < c => !(c <= v) +// v != v || v <= c => !(c < v) +// v != v || c < v => !(v <= c) +// v != v || c <= v => !(v < c) +func canOptNanCheck(x, y *Value, op Op) bool { + if op != OpOrB { + return false + } + + for i := 0; i <= 1; i, x, y = i+1, y, x { + if len(x.Args) != 2 || x.Args[0] != x.Args[1] { + continue + } + v := x.Args[0] + switch x.Op { + case OpNeq64F: + if y.Op != OpLess64F && y.Op != OpLeq64F { + return false + } + for j := 0; j <= 1; j++ { + a, b := y.Args[j], y.Args[j^1] + if a.Op != OpConst64F { + continue + } + // Sign bit operations do not affect NaN check results. This special case allows us + // to optimize statements like `if v != v || Abs(v) > c { ... }`. + if (b.Op == OpAbs || b.Op == OpNeg64F) && b.Args[0] == v { + return true + } + return b == v + } + case OpNeq32F: + if y.Op != OpLess32F && y.Op != OpLeq32F { + return false + } + for j := 0; j <= 1; j++ { + a, b := y.Args[j], y.Args[j^1] + if a.Op != OpConst32F { + continue + } + // Sign bit operations do not affect NaN check results. This special case allows us + // to optimize statements like `if v != v || -v > c { ... }`. 
+ if b.Op == OpNeg32F && b.Args[0] == v { + return true + } + return b == v + } + } + } + return false +} diff --git a/src/cmd/compile/internal/ssa/rewritegeneric.go b/src/cmd/compile/internal/ssa/rewritegeneric.go index 5720063f34..37ba324d86 100644 --- a/src/cmd/compile/internal/ssa/rewritegeneric.go +++ b/src/cmd/compile/internal/ssa/rewritegeneric.go @@ -23957,6 +23957,7 @@ func rewriteValuegeneric_OpOrB(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] b := v.Block + typ := &b.Func.Config.Types // match: (OrB (Less64 (Const64 [c]) x) (Less64 x (Const64 [d]))) // cond: c >= d // result: (Less64U (Const64 [c-d]) (Sub64 x (Const64 [d]))) @@ -25269,6 +25270,558 @@ func rewriteValuegeneric_OpOrB(v *Value) bool { } break } + // match: (OrB (Neq64F x x) (Less64F x y:(Const64F [c]))) + // result: (Not (Leq64F y x)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLess64F { + continue + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + continue + } + y := v_1.Args[1] + if y.Op != OpConst64F { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLeq64F, typ.Bool) + v0.AddArg2(y, x) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq64F x x) (Leq64F x y:(Const64F [c]))) + // result: (Not (Less64F y x)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLeq64F { + continue + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + continue + } + y := v_1.Args[1] + if y.Op != OpConst64F { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLess64F, typ.Bool) + v0.AddArg2(y, x) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq64F x x) (Less64F y:(Const64F [c]) x)) + // result: (Not (Leq64F x y)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLess64F { + continue + } + _ = v_1.Args[1] + y := v_1.Args[0] + if y.Op != OpConst64F { + continue + } + if x != v_1.Args[1] { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLeq64F, typ.Bool) + v0.AddArg2(x, y) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq64F x x) (Leq64F y:(Const64F [c]) x)) + // result: (Not (Less64F x y)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLeq64F { + continue + } + _ = v_1.Args[1] + y := v_1.Args[0] + if y.Op != OpConst64F { + continue + } + if x != v_1.Args[1] { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLess64F, typ.Bool) + v0.AddArg2(x, y) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq32F x x) (Less32F x y:(Const32F [c]))) + // result: (Not (Leq32F y x)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq32F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLess32F { + continue + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + continue + } + y := v_1.Args[1] + if y.Op != OpConst32F { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLeq32F, typ.Bool) + v0.AddArg2(y, x) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq32F x x) (Leq32F x y:(Const32F [c]))) + // result: (Not (Less32F y x)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq32F { 
+ continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLeq32F { + continue + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + continue + } + y := v_1.Args[1] + if y.Op != OpConst32F { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLess32F, typ.Bool) + v0.AddArg2(y, x) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq32F x x) (Less32F y:(Const32F [c]) x)) + // result: (Not (Leq32F x y)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq32F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLess32F { + continue + } + _ = v_1.Args[1] + y := v_1.Args[0] + if y.Op != OpConst32F { + continue + } + if x != v_1.Args[1] { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLeq32F, typ.Bool) + v0.AddArg2(x, y) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq32F x x) (Leq32F y:(Const32F [c]) x)) + // result: (Not (Less32F x y)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq32F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLeq32F { + continue + } + _ = v_1.Args[1] + y := v_1.Args[0] + if y.Op != OpConst32F { + continue + } + if x != v_1.Args[1] { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLess32F, typ.Bool) + v0.AddArg2(x, y) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq64F x x) (Less64F abs:(Abs x) y:(Const64F [c]))) + // result: (Not (Leq64F y abs)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLess64F { + continue + } + _ = v_1.Args[1] + abs := v_1.Args[0] + if abs.Op != OpAbs || x != abs.Args[0] { + continue + } + y := v_1.Args[1] + if y.Op != OpConst64F { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLeq64F, typ.Bool) + v0.AddArg2(y, abs) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq64F x x) (Leq64F abs:(Abs x) y:(Const64F [c]))) + // result: (Not (Less64F y abs)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLeq64F { + continue + } + _ = v_1.Args[1] + abs := v_1.Args[0] + if abs.Op != OpAbs || x != abs.Args[0] { + continue + } + y := v_1.Args[1] + if y.Op != OpConst64F { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLess64F, typ.Bool) + v0.AddArg2(y, abs) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq64F x x) (Less64F y:(Const64F [c]) abs:(Abs x))) + // result: (Not (Leq64F abs y)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLess64F { + continue + } + _ = v_1.Args[1] + y := v_1.Args[0] + if y.Op != OpConst64F { + continue + } + abs := v_1.Args[1] + if abs.Op != OpAbs || x != abs.Args[0] { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLeq64F, typ.Bool) + v0.AddArg2(abs, y) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq64F x x) (Leq64F y:(Const64F [c]) abs:(Abs x))) + // result: (Not (Less64F abs y)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLeq64F { + continue + } + _ = v_1.Args[1] + y := v_1.Args[0] + if y.Op != OpConst64F { + continue + } + abs := v_1.Args[1] + if abs.Op != OpAbs || 
x != abs.Args[0] { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLess64F, typ.Bool) + v0.AddArg2(abs, y) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq64F x x) (Less64F neg:(Neg64F x) y:(Const64F [c]))) + // result: (Not (Leq64F y neg)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLess64F { + continue + } + _ = v_1.Args[1] + neg := v_1.Args[0] + if neg.Op != OpNeg64F || x != neg.Args[0] { + continue + } + y := v_1.Args[1] + if y.Op != OpConst64F { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLeq64F, typ.Bool) + v0.AddArg2(y, neg) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq64F x x) (Leq64F neg:(Neg64F x) y:(Const64F [c]))) + // result: (Not (Less64F y neg)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLeq64F { + continue + } + _ = v_1.Args[1] + neg := v_1.Args[0] + if neg.Op != OpNeg64F || x != neg.Args[0] { + continue + } + y := v_1.Args[1] + if y.Op != OpConst64F { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLess64F, typ.Bool) + v0.AddArg2(y, neg) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq64F x x) (Less64F y:(Const64F [c]) neg:(Neg64F x))) + // result: (Not (Leq64F neg y)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLess64F { + continue + } + _ = v_1.Args[1] + y := v_1.Args[0] + if y.Op != OpConst64F { + continue + } + neg := v_1.Args[1] + if neg.Op != OpNeg64F || x != neg.Args[0] { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLeq64F, typ.Bool) + v0.AddArg2(neg, y) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq64F x x) (Leq64F y:(Const64F [c]) neg:(Neg64F x))) + // result: (Not (Less64F neg y)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq64F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLeq64F { + continue + } + _ = v_1.Args[1] + y := v_1.Args[0] + if y.Op != OpConst64F { + continue + } + neg := v_1.Args[1] + if neg.Op != OpNeg64F || x != neg.Args[0] { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLess64F, typ.Bool) + v0.AddArg2(neg, y) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq32F x x) (Less32F neg:(Neg32F x) y:(Const32F [c]))) + // result: (Not (Leq32F y neg)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq32F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLess32F { + continue + } + _ = v_1.Args[1] + neg := v_1.Args[0] + if neg.Op != OpNeg32F || x != neg.Args[0] { + continue + } + y := v_1.Args[1] + if y.Op != OpConst32F { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLeq32F, typ.Bool) + v0.AddArg2(y, neg) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq32F x x) (Leq32F neg:(Neg32F x) y:(Const32F [c]))) + // result: (Not (Less32F y neg)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq32F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLeq32F { + continue + } + _ = v_1.Args[1] + neg := v_1.Args[0] + if neg.Op != OpNeg32F || x != neg.Args[0] { + continue + } + y := v_1.Args[1] + if y.Op != OpConst32F { + continue + 
} + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLess32F, typ.Bool) + v0.AddArg2(y, neg) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq32F x x) (Less32F y:(Const32F [c]) neg:(Neg32F x))) + // result: (Not (Leq32F neg y)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq32F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLess32F { + continue + } + _ = v_1.Args[1] + y := v_1.Args[0] + if y.Op != OpConst32F { + continue + } + neg := v_1.Args[1] + if neg.Op != OpNeg32F || x != neg.Args[0] { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLeq32F, typ.Bool) + v0.AddArg2(neg, y) + v.AddArg(v0) + return true + } + break + } + // match: (OrB (Neq32F x x) (Leq32F y:(Const32F [c]) neg:(Neg32F x))) + // result: (Not (Less32F neg y)) + for { + for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 { + if v_0.Op != OpNeq32F { + continue + } + x := v_0.Args[1] + if x != v_0.Args[0] || v_1.Op != OpLeq32F { + continue + } + _ = v_1.Args[1] + y := v_1.Args[0] + if y.Op != OpConst32F { + continue + } + neg := v_1.Args[1] + if neg.Op != OpNeg32F || x != neg.Args[0] { + continue + } + v.reset(OpNot) + v0 := b.NewValue0(v.Pos, OpLess32F, typ.Bool) + v0.AddArg2(neg, y) + v.AddArg(v0) + return true + } + break + } return false } func rewriteValuegeneric_OpPhi(v *Value) bool { diff --git a/src/cmd/compile/internal/test/float_test.go b/src/cmd/compile/internal/test/float_test.go index 9e61148c52..7a5e27870f 100644 --- a/src/cmd/compile/internal/test/float_test.go +++ b/src/cmd/compile/internal/test/float_test.go @@ -623,6 +623,110 @@ func TestInf(t *testing.T) { } } +//go:noinline +func isNaNOrGtZero64(x float64) bool { + return math.IsNaN(x) || x > 0 +} + +//go:noinline +func isNaNOrGteZero64(x float64) bool { + return x >= 0 || math.IsNaN(x) +} + +//go:noinline +func isNaNOrLtZero64(x float64) bool { + return x < 0 || math.IsNaN(x) +} + +//go:noinline +func isNaNOrLteZero64(x float64) bool { + return math.IsNaN(x) || x <= 0 +} + +func TestFusedNaNChecks64(t *testing.T) { + tests := []struct { + value float64 + isZero bool + isGreaterThanZero bool + isLessThanZero bool + isNaN bool + }{ + {value: 0.0, isZero: true}, + {value: math.Copysign(0, -1), isZero: true}, + {value: 1.0, isGreaterThanZero: true}, + {value: -1.0, isLessThanZero: true}, + {value: math.Inf(1), isGreaterThanZero: true}, + {value: math.Inf(-1), isLessThanZero: true}, + {value: math.NaN(), isNaN: true}, + } + + check := func(name string, f func(x float64) bool, value float64, want bool) { + got := f(value) + if got != want { + t.Errorf("%v(%g): want %v, got %v", name, value, want, got) + } + } + + for _, test := range tests { + check("isNaNOrGtZero64", isNaNOrGtZero64, test.value, test.isNaN || test.isGreaterThanZero) + check("isNaNOrGteZero64", isNaNOrGteZero64, test.value, test.isNaN || test.isGreaterThanZero || test.isZero) + check("isNaNOrLtZero64", isNaNOrLtZero64, test.value, test.isNaN || test.isLessThanZero) + check("isNaNOrLteZero64", isNaNOrLteZero64, test.value, test.isNaN || test.isLessThanZero || test.isZero) + } +} + +//go:noinline +func isNaNOrGtZero32(x float32) bool { + return x > 0 || x != x +} + +//go:noinline +func isNaNOrGteZero32(x float32) bool { + return x != x || x >= 0 +} + +//go:noinline +func isNaNOrLtZero32(x float32) bool { + return x != x || x < 0 +} + +//go:noinline +func isNaNOrLteZero32(x float32) bool { + return x <= 0 || x != x +} + +func TestFusedNaNChecks32(t *testing.T) { + tests := []struct { + value float32 + 
isZero bool + isGreaterThanZero bool + isLessThanZero bool + isNaN bool + }{ + {value: 0.0, isZero: true}, + {value: float32(math.Copysign(0, -1)), isZero: true}, + {value: 1.0, isGreaterThanZero: true}, + {value: -1.0, isLessThanZero: true}, + {value: float32(math.Inf(1)), isGreaterThanZero: true}, + {value: float32(math.Inf(-1)), isLessThanZero: true}, + {value: float32(math.NaN()), isNaN: true}, + } + + check := func(name string, f func(x float32) bool, value float32, want bool) { + got := f(value) + if got != want { + t.Errorf("%v(%g): want %v, got %v", name, value, want, got) + } + } + + for _, test := range tests { + check("isNaNOrGtZero32", isNaNOrGtZero32, test.value, test.isNaN || test.isGreaterThanZero) + check("isNaNOrGteZero32", isNaNOrGteZero32, test.value, test.isNaN || test.isGreaterThanZero || test.isZero) + check("isNaNOrLtZero32", isNaNOrLtZero32, test.value, test.isNaN || test.isLessThanZero) + check("isNaNOrLteZero32", isNaNOrLteZero32, test.value, test.isNaN || test.isLessThanZero || test.isZero) + } +} + var sinkFloat float64 func BenchmarkMul2(b *testing.B) { diff --git a/test/codegen/fuse.go b/test/codegen/fuse.go index 8d6ea3c5c7..561bac7224 100644 --- a/test/codegen/fuse.go +++ b/test/codegen/fuse.go @@ -6,6 +6,8 @@ package codegen +import "math" + // Notes: // - these examples use channels to provide a source of // unknown values that cannot be optimized away @@ -196,6 +198,84 @@ func ui4d(c <-chan uint8) { } } +// -------------------------------------// +// merge NaN checks // +// ------------------------------------ // + +func f64NaNOrPosInf(c <-chan float64) { + // This test assumes IsInf(x, 1) is implemented as x > MaxFloat rather than x == Inf(1). + + // amd64:"JCS",-"JNE",-"JPS",-"JPC" + // riscv64:"FCLASSD",-"FLED",-"FLTD",-"FNED",-"FEQD" + for x := <-c; math.IsNaN(x) || math.IsInf(x, 1); x = <-c { + } +} + +func f64NaNOrNegInf(c <-chan float64) { + // This test assumes IsInf(x, -1) is implemented as x < -MaxFloat rather than x == Inf(-1). 
+ + // amd64:"JCS",-"JNE",-"JPS",-"JPC" + // riscv64:"FCLASSD",-"FLED",-"FLTD",-"FNED",-"FEQD" + for x := <-c; math.IsNaN(x) || math.IsInf(x, -1); x = <-c { + } +} + +func f64NaNOrLtOne(c <-chan float64) { + // amd64:"JCS",-"JNE",-"JPS",-"JPC" + // riscv64:"FLED",-"FLTD",-"FNED",-"FEQD" + for x := <-c; math.IsNaN(x) || x < 1; x = <-c { + } +} + +func f64NaNOrLteOne(c <-chan float64) { + // amd64:"JLS",-"JNE",-"JPS",-"JPC" + // riscv64:"FLTD",-"FLED",-"FNED",-"FEQD" + for x := <-c; x <= 1 || math.IsNaN(x); x = <-c { + } +} + +func f64NaNOrGtOne(c <-chan float64) { + // amd64:"JCS",-"JNE",-"JPS",-"JPC" + // riscv64:"FLED",-"FLTD",-"FNED",-"FEQD" + for x := <-c; math.IsNaN(x) || x > 1; x = <-c { + } +} + +func f64NaNOrGteOne(c <-chan float64) { + // amd64:"JLS",-"JNE",-"JPS",-"JPC" + // riscv64:"FLTD",-"FLED",-"FNED",-"FEQD" + for x := <-c; x >= 1 || math.IsNaN(x); x = <-c { + } +} + +func f32NaNOrLtOne(c <-chan float32) { + // amd64:"JCS",-"JNE",-"JPS",-"JPC" + // riscv64:"FLES",-"FLTS",-"FNES",-"FEQS" + for x := <-c; x < 1 || x != x; x = <-c { + } +} + +func f32NaNOrLteOne(c <-chan float32) { + // amd64:"JLS",-"JNE",-"JPS",-"JPC" + // riscv64:"FLTS",-"FLES",-"FNES",-"FEQS" + for x := <-c; x != x || x <= 1; x = <-c { + } +} + +func f32NaNOrGtOne(c <-chan float32) { + // amd64:"JCS",-"JNE",-"JPS",-"JPC" + // riscv64:"FLES",-"FLTS",-"FNES",-"FEQS" + for x := <-c; x > 1 || x != x; x = <-c { + } +} + +func f32NaNOrGteOne(c <-chan float32) { + // amd64:"JLS",-"JNE",-"JPS",-"JPC" + // riscv64:"FLTS",-"FLES",-"FNES",-"FEQS" + for x := <-c; x != x || x >= 1; x = <-c { + } +} + // ------------------------------------ // // regressions // // ------------------------------------ //
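Note on correctness: the new rules rely on the IEEE 754 property that every ordered comparison involving a NaN evaluates to false, so `x != x || x > c` and `!(x <= c)` agree for all inputs. The sketch below is illustrative only and is not part of this change; the function names `original` and `fused` are made up for the example. It spot-checks that equivalence for float64 and c = 1:

package main

import (
	"fmt"
	"math"
)

// original is the form a user would write: a NaN check short-circuited
// with an ordinary comparison against a constant.
func original(x float64) bool { return math.IsNaN(x) || x > 1.0 }

// fused is, conceptually, the single negated comparison the rewrite rules
// produce: a NaN compares false with everything, so !(x <= 1) is true for
// NaN as well as for any x greater than 1.
func fused(x float64) bool { return !(x <= 1.0) }

func main() {
	inputs := []float64{math.NaN(), math.Inf(1), math.Inf(-1), -1, 0, 1, 1.5}
	for _, x := range inputs {
		// The two columns should match for every input, including NaN and the infinities.
		fmt.Printf("x=%v original=%v fused=%v\n", x, original(x), fused(x))
	}
}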