[dev.regabi] cmd/compile: separate CommStmt from CaseStmt

Like go/ast and cmd/compile/internal/syntax before it, package ir now
has separate concrete representations for switch-case clauses and
select-communication clauses.

Passes toolstash -cmp.

Change-Id: I32667cbae251fe7881be0f434388478433b2414f
Reviewed-on: https://go-review.googlesource.com/c/go/+/280443
Trust: Matthew Dempsky <mdempsky@google.com>
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Cuong Manh Le <cuong.manhle.vn@gmail.com>
This commit is contained in:
Matthew Dempsky 2020-12-26 22:42:17 -08:00
parent ed9772e130
commit 2ecf52b841
7 changed files with 109 additions and 26 deletions

View file

@ -38,6 +38,7 @@ func main() {
ntypeType := lookup("Ntype") ntypeType := lookup("Ntype")
nodesType := lookup("Nodes") nodesType := lookup("Nodes")
slicePtrCaseStmtType := types.NewSlice(types.NewPointer(lookup("CaseStmt"))) slicePtrCaseStmtType := types.NewSlice(types.NewPointer(lookup("CaseStmt")))
slicePtrCommStmtType := types.NewSlice(types.NewPointer(lookup("CommStmt")))
ptrFieldType := types.NewPointer(lookup("Field")) ptrFieldType := types.NewPointer(lookup("Field"))
slicePtrFieldType := types.NewSlice(ptrFieldType) slicePtrFieldType := types.NewSlice(ptrFieldType)
ptrIdentType := types.NewPointer(lookup("Ident")) ptrIdentType := types.NewPointer(lookup("Ident"))
@ -79,6 +80,8 @@ func main() {
fmt.Fprintf(&buf, "c.%s = c.%s.Copy()\n", name, name) fmt.Fprintf(&buf, "c.%s = c.%s.Copy()\n", name, name)
case is(slicePtrCaseStmtType): case is(slicePtrCaseStmtType):
fmt.Fprintf(&buf, "c.%s = copyCases(c.%s)\n", name, name) fmt.Fprintf(&buf, "c.%s = copyCases(c.%s)\n", name, name)
case is(slicePtrCommStmtType):
fmt.Fprintf(&buf, "c.%s = copyComms(c.%s)\n", name, name)
case is(ptrFieldType): case is(ptrFieldType):
fmt.Fprintf(&buf, "if c.%s != nil { c.%s = c.%s.copy() }\n", name, name, name) fmt.Fprintf(&buf, "if c.%s != nil { c.%s = c.%s.copy() }\n", name, name, name)
case is(slicePtrFieldType): case is(slicePtrFieldType):
@ -99,6 +102,8 @@ func main() {
fmt.Fprintf(&buf, "err = maybeDoList(n.%s, err, do)\n", name) fmt.Fprintf(&buf, "err = maybeDoList(n.%s, err, do)\n", name)
case is(slicePtrCaseStmtType): case is(slicePtrCaseStmtType):
fmt.Fprintf(&buf, "err = maybeDoCases(n.%s, err, do)\n", name) fmt.Fprintf(&buf, "err = maybeDoCases(n.%s, err, do)\n", name)
case is(slicePtrCommStmtType):
fmt.Fprintf(&buf, "err = maybeDoComms(n.%s, err, do)\n", name)
case is(ptrFieldType): case is(ptrFieldType):
fmt.Fprintf(&buf, "err = maybeDoField(n.%s, err, do)\n", name) fmt.Fprintf(&buf, "err = maybeDoField(n.%s, err, do)\n", name)
case is(slicePtrFieldType): case is(slicePtrFieldType):
@ -120,6 +125,8 @@ func main() {
fmt.Fprintf(&buf, "editList(n.%s, edit)\n", name) fmt.Fprintf(&buf, "editList(n.%s, edit)\n", name)
case is(slicePtrCaseStmtType): case is(slicePtrCaseStmtType):
fmt.Fprintf(&buf, "editCases(n.%s, edit)\n", name) fmt.Fprintf(&buf, "editCases(n.%s, edit)\n", name)
case is(slicePtrCommStmtType):
fmt.Fprintf(&buf, "editComms(n.%s, edit)\n", name)
case is(ptrFieldType): case is(ptrFieldType):
fmt.Fprintf(&buf, "editField(n.%s, edit)\n", name) fmt.Fprintf(&buf, "editField(n.%s, edit)\n", name)
case is(slicePtrFieldType): case is(slicePtrFieldType):

View file

@ -239,7 +239,6 @@ func (n *CaseStmt) doChildren(do func(Node) error) error {
err = maybeDoList(n.init, err, do) err = maybeDoList(n.init, err, do)
err = maybeDo(n.Var, err, do) err = maybeDo(n.Var, err, do)
err = maybeDoList(n.List, err, do) err = maybeDoList(n.List, err, do)
err = maybeDo(n.Comm, err, do)
err = maybeDoList(n.Body, err, do) err = maybeDoList(n.Body, err, do)
return err return err
} }
@ -247,7 +246,6 @@ func (n *CaseStmt) editChildren(edit func(Node) Node) {
editList(n.init, edit) editList(n.init, edit)
n.Var = maybeEdit(n.Var, edit) n.Var = maybeEdit(n.Var, edit)
editList(n.List, edit) editList(n.List, edit)
n.Comm = maybeEdit(n.Comm, edit)
editList(n.Body, edit) editList(n.Body, edit)
} }
@ -295,6 +293,29 @@ func (n *ClosureReadExpr) editChildren(edit func(Node) Node) {
editList(n.init, edit) editList(n.init, edit)
} }
func (n *CommStmt) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) }
func (n *CommStmt) copy() Node {
c := *n
c.init = c.init.Copy()
c.List = c.List.Copy()
c.Body = c.Body.Copy()
return &c
}
func (n *CommStmt) doChildren(do func(Node) error) error {
var err error
err = maybeDoList(n.init, err, do)
err = maybeDoList(n.List, err, do)
err = maybeDo(n.Comm, err, do)
err = maybeDoList(n.Body, err, do)
return err
}
func (n *CommStmt) editChildren(edit func(Node) Node) {
editList(n.init, edit)
editList(n.List, edit)
n.Comm = maybeEdit(n.Comm, edit)
editList(n.Body, edit)
}
func (n *CompLitExpr) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) } func (n *CompLitExpr) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) }
func (n *CompLitExpr) copy() Node { func (n *CompLitExpr) copy() Node {
c := *n c := *n
@ -781,20 +802,20 @@ func (n *SelectStmt) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) }
func (n *SelectStmt) copy() Node { func (n *SelectStmt) copy() Node {
c := *n c := *n
c.init = c.init.Copy() c.init = c.init.Copy()
c.Cases = copyCases(c.Cases) c.Cases = copyComms(c.Cases)
c.Compiled = c.Compiled.Copy() c.Compiled = c.Compiled.Copy()
return &c return &c
} }
func (n *SelectStmt) doChildren(do func(Node) error) error { func (n *SelectStmt) doChildren(do func(Node) error) error {
var err error var err error
err = maybeDoList(n.init, err, do) err = maybeDoList(n.init, err, do)
err = maybeDoCases(n.Cases, err, do) err = maybeDoComms(n.Cases, err, do)
err = maybeDoList(n.Compiled, err, do) err = maybeDoList(n.Compiled, err, do)
return err return err
} }
func (n *SelectStmt) editChildren(edit func(Node) Node) { func (n *SelectStmt) editChildren(edit func(Node) Node) {
editList(n.init, edit) editList(n.init, edit)
editCases(n.Cases, edit) editComms(n.Cases, edit)
editList(n.Compiled, edit) editList(n.Compiled, edit)
} }

View file

@ -178,19 +178,17 @@ type CaseStmt struct {
miniStmt miniStmt
Var Node // declared variable for this case in type switch Var Node // declared variable for this case in type switch
List Nodes // list of expressions for switch, early select List Nodes // list of expressions for switch, early select
Comm Node // communication case (Exprs[0]) after select is type-checked
Body Nodes Body Nodes
} }
func NewCaseStmt(pos src.XPos, list, body []Node) *CaseStmt { func NewCaseStmt(pos src.XPos, list, body []Node) *CaseStmt {
n := &CaseStmt{} n := &CaseStmt{List: list, Body: body}
n.pos = pos n.pos = pos
n.op = OCASE n.op = OCASE
n.List.Set(list)
n.Body.Set(body)
return n return n
} }
// TODO(mdempsky): Generate these with mknode.go.
func copyCases(list []*CaseStmt) []*CaseStmt { func copyCases(list []*CaseStmt) []*CaseStmt {
if list == nil { if list == nil {
return nil return nil
@ -199,7 +197,6 @@ func copyCases(list []*CaseStmt) []*CaseStmt {
copy(c, list) copy(c, list)
return c return c
} }
func maybeDoCases(list []*CaseStmt, err error, do func(Node) error) error { func maybeDoCases(list []*CaseStmt, err error, do func(Node) error) error {
if err != nil { if err != nil {
return err return err
@ -213,7 +210,6 @@ func maybeDoCases(list []*CaseStmt, err error, do func(Node) error) error {
} }
return nil return nil
} }
func editCases(list []*CaseStmt, edit func(Node) Node) { func editCases(list []*CaseStmt, edit func(Node) Node) {
for i, x := range list { for i, x := range list {
if x != nil { if x != nil {
@ -222,6 +218,50 @@ func editCases(list []*CaseStmt, edit func(Node) Node) {
} }
} }
type CommStmt struct {
miniStmt
List Nodes // list of expressions for switch, early select
Comm Node // communication case (Exprs[0]) after select is type-checked
Body Nodes
}
func NewCommStmt(pos src.XPos, list, body []Node) *CommStmt {
n := &CommStmt{List: list, Body: body}
n.pos = pos
n.op = OCASE
return n
}
// TODO(mdempsky): Generate these with mknode.go.
func copyComms(list []*CommStmt) []*CommStmt {
if list == nil {
return nil
}
c := make([]*CommStmt, len(list))
copy(c, list)
return c
}
func maybeDoComms(list []*CommStmt, err error, do func(Node) error) error {
if err != nil {
return err
}
for _, x := range list {
if x != nil {
if err := do(x); err != nil {
return err
}
}
}
return nil
}
func editComms(list []*CommStmt, edit func(Node) Node) {
for i, x := range list {
if x != nil {
list[i] = edit(x).(*CommStmt)
}
}
}
// A ForStmt is a non-range for loop: for Init; Cond; Post { Body } // A ForStmt is a non-range for loop: for Init; Cond; Post { Body }
// Op can be OFOR or OFORUNTIL (!Cond). // Op can be OFOR or OFORUNTIL (!Cond).
type ForStmt struct { type ForStmt struct {
@ -365,18 +405,17 @@ func (n *ReturnStmt) SetOrig(x Node) { n.orig = x }
type SelectStmt struct { type SelectStmt struct {
miniStmt miniStmt
Label *types.Sym Label *types.Sym
Cases []*CaseStmt Cases []*CommStmt
HasBreak bool HasBreak bool
// TODO(rsc): Instead of recording here, replace with a block? // TODO(rsc): Instead of recording here, replace with a block?
Compiled Nodes // compiled form, after walkswitch Compiled Nodes // compiled form, after walkswitch
} }
func NewSelectStmt(pos src.XPos, cases []*CaseStmt) *SelectStmt { func NewSelectStmt(pos src.XPos, cases []*CommStmt) *SelectStmt {
n := &SelectStmt{} n := &SelectStmt{Cases: cases}
n.pos = pos n.pos = pos
n.op = OSELECT n.op = OSELECT
n.Cases = cases
return n return n
} }
@ -407,10 +446,9 @@ type SwitchStmt struct {
} }
func NewSwitchStmt(pos src.XPos, tag Node, cases []*CaseStmt) *SwitchStmt { func NewSwitchStmt(pos src.XPos, tag Node, cases []*CaseStmt) *SwitchStmt {
n := &SwitchStmt{Tag: tag} n := &SwitchStmt{Tag: tag, Cases: cases}
n.pos = pos n.pos = pos
n.op = OSWITCH n.op = OSWITCH
n.Cases = cases
return n return n
} }

View file

@ -1266,8 +1266,8 @@ func (p *noder) simpleStmt(stmt syntax.SimpleStmt) []ir.Node {
return []ir.Node{p.stmt(stmt)} return []ir.Node{p.stmt(stmt)}
} }
func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []*ir.CaseStmt { func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []*ir.CommStmt {
nodes := make([]*ir.CaseStmt, len(clauses)) nodes := make([]*ir.CommStmt, len(clauses))
for i, clause := range clauses { for i, clause := range clauses {
p.setlineno(clause) p.setlineno(clause)
if i > 0 { if i > 0 {
@ -1275,7 +1275,7 @@ func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []*
} }
p.openScope(clause.Pos()) p.openScope(clause.Pos())
nodes[i] = ir.NewCaseStmt(p.pos(clause), p.simpleStmt(clause.Comm), p.stmts(clause.Body)) nodes[i] = ir.NewCommStmt(p.pos(clause), p.simpleStmt(clause.Comm), p.stmts(clause.Body))
} }
if len(clauses) > 0 { if len(clauses) > 0 {
p.closeScope(rbrace) p.closeScope(rbrace)

View file

@ -1144,7 +1144,7 @@ func (w *exportWriter) stmt(n ir.Node) {
w.op(n.Op()) w.op(n.Op())
w.pos(n.Pos()) w.pos(n.Pos())
w.stmtList(n.Init()) w.stmtList(n.Init())
w.caseList(n.Cases, false) w.commList(n.Cases)
case ir.OSWITCH: case ir.OSWITCH:
n := n.(*ir.SwitchStmt) n := n.(*ir.SwitchStmt)
@ -1193,6 +1193,15 @@ func (w *exportWriter) caseList(cases []*ir.CaseStmt, namedTypeSwitch bool) {
} }
} }
func (w *exportWriter) commList(cases []*ir.CommStmt) {
w.uint64(uint64(len(cases)))
for _, cas := range cases {
w.pos(cas.Pos())
w.stmtList(cas.List)
w.stmtList(cas.Body)
}
}
func (w *exportWriter) exprList(list ir.Nodes) { func (w *exportWriter) exprList(list ir.Nodes) {
for _, n := range list { for _, n := range list {
w.expr(n) w.expr(n)

View file

@ -789,6 +789,14 @@ func (r *importReader) caseList(switchExpr ir.Node) []*ir.CaseStmt {
return cases return cases
} }
func (r *importReader) commList() []*ir.CommStmt {
cases := make([]*ir.CommStmt, r.uint64())
for i := range cases {
cases[i] = ir.NewCommStmt(r.pos(), r.stmtList(), r.stmtList())
}
return cases
}
func (r *importReader) exprList() []ir.Node { func (r *importReader) exprList() []ir.Node {
var list []ir.Node var list []ir.Node
for { for {
@ -1035,7 +1043,7 @@ func (r *importReader) node() ir.Node {
case ir.OSELECT: case ir.OSELECT:
pos := r.pos() pos := r.pos()
init := r.stmtList() init := r.stmtList()
n := ir.NewSelectStmt(pos, r.caseList(nil)) n := ir.NewSelectStmt(pos, r.commList())
n.PtrInit().Set(init) n.PtrInit().Set(init)
return n return n

View file

@ -29,7 +29,7 @@ func walkSelect(sel *ir.SelectStmt) {
base.Pos = lno base.Pos = lno
} }
func walkSelectCases(cases []*ir.CaseStmt) []ir.Node { func walkSelectCases(cases []*ir.CommStmt) []ir.Node {
ncas := len(cases) ncas := len(cases)
sellineno := base.Pos sellineno := base.Pos
@ -73,7 +73,7 @@ func walkSelectCases(cases []*ir.CaseStmt) []ir.Node {
// convert case value arguments to addresses. // convert case value arguments to addresses.
// this rewrite is used by both the general code and the next optimization. // this rewrite is used by both the general code and the next optimization.
var dflt *ir.CaseStmt var dflt *ir.CommStmt
for _, cas := range cases { for _, cas := range cases {
ir.SetPos(cas) ir.SetPos(cas)
n := cas.Comm n := cas.Comm
@ -146,7 +146,7 @@ func walkSelectCases(cases []*ir.CaseStmt) []ir.Node {
if dflt != nil { if dflt != nil {
ncas-- ncas--
} }
casorder := make([]*ir.CaseStmt, ncas) casorder := make([]*ir.CommStmt, ncas)
nsends, nrecvs := 0, 0 nsends, nrecvs := 0, 0
var init []ir.Node var init []ir.Node
@ -242,7 +242,7 @@ func walkSelectCases(cases []*ir.CaseStmt) []ir.Node {
} }
// dispatch cases // dispatch cases
dispatch := func(cond ir.Node, cas *ir.CaseStmt) { dispatch := func(cond ir.Node, cas *ir.CommStmt) {
cond = typecheck.Expr(cond) cond = typecheck.Expr(cond)
cond = typecheck.DefaultLit(cond, nil) cond = typecheck.DefaultLit(cond, nil)