mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
cmd/compile: try to rewrite loops to count down
Fixes #61629 This reduce the pressure on regalloc because then the loop only keep alive one value (the iterator) instead of the iterator and the upper bound since the comparison now acts against an immediate, often zero which can be skipped. This optimize things like: for i := 0; i < n; i++ { Or a range over a slice where the index is not used: for _, v := range someSlice { Or the new range over int from #61405: for range n { It is hit in 975 unique places while doing ./make.bash. Change-Id: I5facff8b267a0b60ea3c1b9a58c4d74cdb38f03f Reviewed-on: https://go-review.googlesource.com/c/go/+/512935 TryBot-Result: Gopher Robot <gobot@golang.org> Run-TryBot: Jorropo <jorropo.pgm@gmail.com> Reviewed-by: Keith Randall <khr@google.com> Reviewed-by: David Chase <drchase@google.com> Reviewed-by: Keith Randall <khr@golang.org> Auto-Submit: Keith Randall <khr@golang.org>
This commit is contained in:
parent
8613ef81e6
commit
bac4e2f241
4 changed files with 188 additions and 21 deletions
|
|
@ -798,6 +798,166 @@ func (ft *factsTable) cleanup(f *Func) {
|
|||
// its negation. If either leads to a contradiction, it can trim that
|
||||
// successor.
|
||||
func prove(f *Func) {
|
||||
// Find induction variables. Currently, findIndVars
|
||||
// is limited to one induction variable per block.
|
||||
var indVars map[*Block]indVar
|
||||
for _, v := range findIndVar(f) {
|
||||
ind := v.ind
|
||||
if len(ind.Args) != 2 {
|
||||
// the rewrite code assumes there is only ever two parents to loops
|
||||
panic("unexpected induction with too many parents")
|
||||
}
|
||||
|
||||
nxt := v.nxt
|
||||
if !(ind.Uses == 2 && // 2 used by comparison and next
|
||||
nxt.Uses == 1) { // 1 used by induction
|
||||
// ind or nxt is used inside the loop, add it for the facts table
|
||||
if indVars == nil {
|
||||
indVars = make(map[*Block]indVar)
|
||||
}
|
||||
indVars[v.entry] = v
|
||||
continue
|
||||
} else {
|
||||
// Since this induction variable is not used for anything but counting the iterations,
|
||||
// no point in putting it into the facts table.
|
||||
}
|
||||
|
||||
// try to rewrite to a downward counting loop checking against start if the
|
||||
// loop body does not depends on ind or nxt and end is known before the loop.
|
||||
// This reduce pressure on the register allocator because this do not need
|
||||
// to use end on each iteration anymore. We compare against the start constant instead.
|
||||
// That means this code:
|
||||
//
|
||||
// loop:
|
||||
// ind = (Phi (Const [x]) nxt),
|
||||
// if ind < end
|
||||
// then goto enter_loop
|
||||
// else goto exit_loop
|
||||
//
|
||||
// enter_loop:
|
||||
// do something without using ind nor nxt
|
||||
// nxt = inc + ind
|
||||
// goto loop
|
||||
//
|
||||
// exit_loop:
|
||||
//
|
||||
// is rewritten to:
|
||||
//
|
||||
// loop:
|
||||
// ind = (Phi end nxt)
|
||||
// if (Const [x]) < ind
|
||||
// then goto enter_loop
|
||||
// else goto exit_loop
|
||||
//
|
||||
// enter_loop:
|
||||
// do something without using ind nor nxt
|
||||
// nxt = ind - inc
|
||||
// goto loop
|
||||
//
|
||||
// exit_loop:
|
||||
//
|
||||
// this is better because it only require to keep ind then nxt alive while looping,
|
||||
// while the original form keeps ind then nxt and end alive
|
||||
start, end := v.min, v.max
|
||||
if v.flags&indVarCountDown != 0 {
|
||||
start, end = end, start
|
||||
}
|
||||
|
||||
if !(start.Op == OpConst8 || start.Op == OpConst16 || start.Op == OpConst32 || start.Op == OpConst64) {
|
||||
// if start is not a constant we would be winning nothing from inverting the loop
|
||||
continue
|
||||
}
|
||||
if end.Op == OpConst8 || end.Op == OpConst16 || end.Op == OpConst32 || end.Op == OpConst64 {
|
||||
// TODO: if both start and end are constants we should rewrite such that the comparison
|
||||
// is against zero and nxt is ++ or -- operation
|
||||
// That means:
|
||||
// for i := 2; i < 11; i += 2 {
|
||||
// should be rewritten to:
|
||||
// for i := 5; 0 < i; i-- {
|
||||
continue
|
||||
}
|
||||
|
||||
header := ind.Block
|
||||
check := header.Controls[0]
|
||||
if check == nil {
|
||||
// we don't know how to rewrite a loop that not simple comparison
|
||||
continue
|
||||
}
|
||||
switch check.Op {
|
||||
case OpLeq64, OpLeq32, OpLeq16, OpLeq8,
|
||||
OpLess64, OpLess32, OpLess16, OpLess8:
|
||||
default:
|
||||
// we don't know how to rewrite a loop that not simple comparison
|
||||
continue
|
||||
}
|
||||
if !((check.Args[0] == ind && check.Args[1] == end) ||
|
||||
(check.Args[1] == ind && check.Args[0] == end)) {
|
||||
// we don't know how to rewrite a loop that not simple comparison
|
||||
continue
|
||||
}
|
||||
if end.Block == ind.Block {
|
||||
// we can't rewrite loops where the condition depends on the loop body
|
||||
// this simple check is forced to work because if this is true a Phi in ind.Block must exists
|
||||
continue
|
||||
}
|
||||
|
||||
// invert the check
|
||||
check.Args[0], check.Args[1] = check.Args[1], check.Args[0]
|
||||
|
||||
// invert start and end in the loop
|
||||
for i, v := range check.Args {
|
||||
if v != end {
|
||||
continue
|
||||
}
|
||||
|
||||
check.SetArg(i, start)
|
||||
goto replacedEnd
|
||||
}
|
||||
panic(fmt.Sprintf("unreachable, ind: %v, start: %v, end: %v", ind, start, end))
|
||||
replacedEnd:
|
||||
|
||||
for i, v := range ind.Args {
|
||||
if v != start {
|
||||
continue
|
||||
}
|
||||
|
||||
ind.SetArg(i, end)
|
||||
goto replacedStart
|
||||
}
|
||||
panic(fmt.Sprintf("unreachable, ind: %v, start: %v, end: %v", ind, start, end))
|
||||
replacedStart:
|
||||
|
||||
if nxt.Args[0] != ind {
|
||||
// unlike additions subtractions are not commutative so be sure we get it right
|
||||
nxt.Args[0], nxt.Args[1] = nxt.Args[1], nxt.Args[0]
|
||||
}
|
||||
|
||||
switch nxt.Op {
|
||||
case OpAdd8:
|
||||
nxt.Op = OpSub8
|
||||
case OpAdd16:
|
||||
nxt.Op = OpSub16
|
||||
case OpAdd32:
|
||||
nxt.Op = OpSub32
|
||||
case OpAdd64:
|
||||
nxt.Op = OpSub64
|
||||
case OpSub8:
|
||||
nxt.Op = OpAdd8
|
||||
case OpSub16:
|
||||
nxt.Op = OpAdd16
|
||||
case OpSub32:
|
||||
nxt.Op = OpAdd32
|
||||
case OpSub64:
|
||||
nxt.Op = OpAdd64
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
if f.pass.debug > 0 {
|
||||
f.Warnl(ind.Pos, "Inverted loop iteration")
|
||||
}
|
||||
}
|
||||
|
||||
ft := newFactsTable(f)
|
||||
ft.checkpoint()
|
||||
|
||||
|
|
@ -933,15 +1093,6 @@ func prove(f *Func) {
|
|||
}
|
||||
}
|
||||
}
|
||||
// Find induction variables. Currently, findIndVars
|
||||
// is limited to one induction variable per block.
|
||||
var indVars map[*Block]indVar
|
||||
for _, v := range findIndVar(f) {
|
||||
if indVars == nil {
|
||||
indVars = make(map[*Block]indVar)
|
||||
}
|
||||
indVars[v.entry] = v
|
||||
}
|
||||
|
||||
// current node state
|
||||
type walkState int
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue