cmd/compile: try to rewrite loops to count down

Fixes #61629

This reduces the pressure on regalloc because the loop then only keeps alive
one value (the iterator) instead of the iterator and the upper bound, since
the comparison now acts against an immediate, often zero, which can be skipped.

This optimizes things like:
  for i := 0; i < n; i++ {
Or a range over a slice where the index is not used:
  for _, v := range someSlice {
Or the new range over int from #61405:
  for range n {

It is hit in 975 unique places while doing ./make.bash.

Change-Id: I5facff8b267a0b60ea3c1b9a58c4d74cdb38f03f
Reviewed-on: https://go-review.googlesource.com/c/go/+/512935
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Jorropo <jorropo.pgm@gmail.com>
Reviewed-by: Keith Randall <khr@google.com>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
Auto-Submit: Keith Randall <khr@golang.org>
This commit is contained in:
Jorropo 2023-07-25 16:19:10 +02:00 committed by Gopher Robot
parent 8613ef81e6
commit bac4e2f241
4 changed files with 188 additions and 21 deletions

View file

@ -798,6 +798,166 @@ func (ft *factsTable) cleanup(f *Func) {
// its negation. If either leads to a contradiction, it can trim that
// successor.
func prove(f *Func) {
// Find induction variables. Currently, findIndVars
// is limited to one induction variable per block.
var indVars map[*Block]indVar
for _, v := range findIndVar(f) {
ind := v.ind
if len(ind.Args) != 2 {
// the rewrite code assumes there is only ever two parents to loops
panic("unexpected induction with too many parents")
}
nxt := v.nxt
if !(ind.Uses == 2 && // 2 used by comparison and next
nxt.Uses == 1) { // 1 used by induction
// ind or nxt is used inside the loop, add it for the facts table
if indVars == nil {
indVars = make(map[*Block]indVar)
}
indVars[v.entry] = v
continue
} else {
// Since this induction variable is not used for anything but counting the iterations,
// no point in putting it into the facts table.
}
// try to rewrite to a downward counting loop checking against start if the
// loop body does not depend on ind or nxt and end is known before the loop.
// This reduces pressure on the register allocator because it does not need
// to use end on each iteration anymore. We compare against the start constant instead.
// That means this code:
//
// loop:
// ind = (Phi (Const [x]) nxt),
// if ind < end
// then goto enter_loop
// else goto exit_loop
//
// enter_loop:
// do something without using ind nor nxt
// nxt = inc + ind
// goto loop
//
// exit_loop:
//
// is rewritten to:
//
// loop:
// ind = (Phi end nxt)
// if (Const [x]) < ind
// then goto enter_loop
// else goto exit_loop
//
// enter_loop:
// do something without using ind nor nxt
// nxt = ind - inc
// goto loop
//
// exit_loop:
//
// this is better because it only requires keeping ind then nxt alive while looping,
// while the original form keeps ind then nxt and end alive
start, end := v.min, v.max
if v.flags&indVarCountDown != 0 {
start, end = end, start
}
if !(start.Op == OpConst8 || start.Op == OpConst16 || start.Op == OpConst32 || start.Op == OpConst64) {
// if start is not a constant we would be winning nothing from inverting the loop
continue
}
if end.Op == OpConst8 || end.Op == OpConst16 || end.Op == OpConst32 || end.Op == OpConst64 {
// TODO: if both start and end are constants we should rewrite such that the comparison
// is against zero and nxt is ++ or -- operation
// That means:
// for i := 2; i < 11; i += 2 {
// should be rewritten to:
// for i := 5; 0 < i; i-- {
continue
}
header := ind.Block
check := header.Controls[0]
if check == nil {
// we don't know how to rewrite a loop that is not a simple comparison
continue
}
switch check.Op {
case OpLeq64, OpLeq32, OpLeq16, OpLeq8,
OpLess64, OpLess32, OpLess16, OpLess8:
default:
// we don't know how to rewrite a loop that is not a simple comparison
continue
}
if !((check.Args[0] == ind && check.Args[1] == end) ||
(check.Args[1] == ind && check.Args[0] == end)) {
// we don't know how to rewrite a loop that is not a simple comparison
continue
}
if end.Block == ind.Block {
// we can't rewrite loops where the condition depends on the loop body
// this simple check is forced to work because if this is true a Phi in ind.Block must exists
continue
}
// invert the check
check.Args[0], check.Args[1] = check.Args[1], check.Args[0]
// invert start and end in the loop
for i, v := range check.Args {
if v != end {
continue
}
check.SetArg(i, start)
goto replacedEnd
}
panic(fmt.Sprintf("unreachable, ind: %v, start: %v, end: %v", ind, start, end))
replacedEnd:
for i, v := range ind.Args {
if v != start {
continue
}
ind.SetArg(i, end)
goto replacedStart
}
panic(fmt.Sprintf("unreachable, ind: %v, start: %v, end: %v", ind, start, end))
replacedStart:
if nxt.Args[0] != ind {
// unlike additions subtractions are not commutative so be sure we get it right
nxt.Args[0], nxt.Args[1] = nxt.Args[1], nxt.Args[0]
}
switch nxt.Op {
case OpAdd8:
nxt.Op = OpSub8
case OpAdd16:
nxt.Op = OpSub16
case OpAdd32:
nxt.Op = OpSub32
case OpAdd64:
nxt.Op = OpSub64
case OpSub8:
nxt.Op = OpAdd8
case OpSub16:
nxt.Op = OpAdd16
case OpSub32:
nxt.Op = OpAdd32
case OpSub64:
nxt.Op = OpAdd64
default:
panic("unreachable")
}
if f.pass.debug > 0 {
f.Warnl(ind.Pos, "Inverted loop iteration")
}
}
ft := newFactsTable(f)
ft.checkpoint()
@ -933,15 +1093,6 @@ func prove(f *Func) {
}
}
}
// Find induction variables. Currently, findIndVars
// is limited to one induction variable per block.
var indVars map[*Block]indVar
for _, v := range findIndVar(f) {
if indVars == nil {
indVars = make(map[*Block]indVar)
}
indVars[v.entry] = v
}
// current node state
type walkState int