diff --git a/src/cmd/compile/internal/loopvar/loopvar.go b/src/cmd/compile/internal/loopvar/loopvar.go index 0ecb70570f7..c92b9d61ea3 100644 --- a/src/cmd/compile/internal/loopvar/loopvar.go +++ b/src/cmd/compile/internal/loopvar/loopvar.go @@ -50,6 +50,10 @@ func ForCapture(fn *ir.Func) []*ir.Name { // will be transformed. possiblyLeaked := make(map[*ir.Name]bool) + // these enable an optimization of "escape" under return statements + loopDepth := 0 + returnInLoopDepth := 0 + // noteMayLeak is called for candidate variables in for range/3-clause, and // adds them (mapped to false) to possiblyLeaked. noteMayLeak := func(x ir.Node) { @@ -95,6 +99,11 @@ func ForCapture(fn *ir.Func) []*ir.Name { scanChildrenThenTransform = func(n ir.Node) bool { switch x := n.(type) { case *ir.ClosureExpr: + if returnInLoopDepth >= loopDepth { + // This expression is a child of a return, which escapes all loops above + // the return, but not those between this expression and the return. + break + } for _, cv := range x.Func.ClosureVars { v := cv.Canonical() if _, ok := possiblyLeaked[v]; ok { @@ -103,6 +112,11 @@ func ForCapture(fn *ir.Func) []*ir.Name { } case *ir.AddrExpr: + if returnInLoopDepth >= loopDepth { + // This expression is a child of a return, which escapes all loops above + // the return, but not those between this expression and the return. + break + } // Explicitly note address-taken so that return-statements can be excluded y := ir.OuterValue(x.X) if y.Op() != ir.ONAME { @@ -119,6 +133,11 @@ func ForCapture(fn *ir.Func) []*ir.Name { } } + case *ir.ReturnStmt: + savedRILD := returnInLoopDepth + returnInLoopDepth = loopDepth + defer func() { returnInLoopDepth = savedRILD }() + case *ir.RangeStmt: if !(x.Def && x.DistinctVars) { // range loop must define its iteration variables AND have distinctVars. @@ -127,7 +146,9 @@ func ForCapture(fn *ir.Func) []*ir.Name { } noteMayLeak(x.Key) noteMayLeak(x.Value) + loopDepth++ ir.DoChildren(n, scanChildrenThenTransform) + loopDepth-- x.Key = maybeReplaceVar(x.Key, x) x.Value = maybeReplaceVar(x.Value, x) x.DistinctVars = false @@ -138,7 +159,9 @@ func ForCapture(fn *ir.Func) []*ir.Name { break } forAllDefInInit(x, noteMayLeak) + loopDepth++ ir.DoChildren(n, scanChildrenThenTransform) + loopDepth-- var leaked []*ir.Name // Collect the leaking variables for the much-more-complex transformation. forAllDefInInit(x, func(z ir.Node) { diff --git a/src/cmd/compile/internal/loopvar/loopvar_test.go b/src/cmd/compile/internal/loopvar/loopvar_test.go index 6f4e73bb271..729c240ef54 100644 --- a/src/cmd/compile/internal/loopvar/loopvar_test.go +++ b/src/cmd/compile/internal/loopvar/loopvar_test.go @@ -206,3 +206,41 @@ func TestLoopVarHashes(t *testing.T) { t.Errorf("Did not see expected value of m run") } } + +func TestLoopVarOpt(t *testing.T) { + switch runtime.GOOS { + case "linux", "darwin": + default: + t.Skipf("Slow test, usually avoid it, os=%s not linux or darwin", runtime.GOOS) + } + switch runtime.GOARCH { + case "amd64", "arm64": + default: + t.Skipf("Slow test, usually avoid it, arch=%s not amd64 or arm64", runtime.GOARCH) + } + + testenv.MustHaveGoBuild(t) + gocmd := testenv.GoToolPath(t) + + cmd := testenv.Command(t, gocmd, "run", "-gcflags=-d=loopvar=2", "opt.go") + cmd.Dir = filepath.Join("testdata") + + b, err := cmd.CombinedOutput() + m := string(b) + + t.Logf(m) + + yCount := strings.Count(m, "opt.go:16:6: transformed loop variable private escapes (loop inlined into ./opt.go:30)") + nCount := strings.Count(m, "shared") + + if yCount != 1 { + t.Errorf("yCount=%d != 1", yCount) + } + if nCount > 0 { + t.Errorf("nCount=%d > 0", nCount) + } + if err != nil { + t.Errorf("err=%v != nil", err) + } + +} diff --git a/src/cmd/compile/internal/loopvar/testdata/opt.go b/src/cmd/compile/internal/loopvar/testdata/opt.go new file mode 100644 index 00000000000..1bcd73614d8 --- /dev/null +++ b/src/cmd/compile/internal/loopvar/testdata/opt.go @@ -0,0 +1,42 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "os" +) + +var is []func() int + +func inline(j, k int) []*int { + var a []*int + for private := j; private < k; private++ { + a = append(a, &private) + } + return a + +} + +//go:noinline +func notinline(j, k int) ([]*int, *int) { + for shared := j; shared < k; shared++ { + if shared == k/2 { + // want the call inlined, want "private" in that inline to be transformed, + // (believe it ends up on init node of the return). + // but do not want "shared" transformed, + return inline(j, k), &shared + } + } + return nil, &j +} + +func main() { + a, p := notinline(2, 9) + fmt.Printf("a[0]=%d,*p=%d\n", *a[0], *p) + if *a[0] != 2 { + os.Exit(1) + } +}