diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go index d8e1eaab65b..78c7ddeabe3 100644 --- a/src/cmd/compile/internal/noder/stencil.go +++ b/src/cmd/compile/internal/noder/stencil.go @@ -461,9 +461,11 @@ func (subst *subster) list(l []ir.Node) []ir.Node { } // tstruct substitutes type params in types of the fields of a structure type. For -// each field, if Nname is set, tstruct also translates the Nname using subst.vars, if -// Nname is in subst.vars. -func (subst *subster) tstruct(t *types.Type) *types.Type { +// each field, if Nname is set, tstruct also translates the Nname using +// subst.vars, if Nname is in subst.vars. To always force the creation of a new +// (top-level) struct, regardless of whether anything changed with the types or +// names of the struct's fields, set force to true. +func (subst *subster) tstruct(t *types.Type, force bool) *types.Type { if t.NumFields() == 0 { if t.HasTParam() { // For an empty struct, we need to return a new type, @@ -474,6 +476,9 @@ func (subst *subster) tstruct(t *types.Type) *types.Type { return t } var newfields []*types.Field + if force { + newfields = make([]*types.Field, t.NumFields()) + } for i, f := range t.Fields().Slice() { t2 := subst.typ(f.Type) if (t2 != f.Type || f.Nname != nil) && newfields == nil { @@ -650,20 +655,33 @@ func (subst *subster) typ(t *types.Type) *types.Type { } case types.TSTRUCT: - newt = subst.tstruct(t) + newt = subst.tstruct(t, false) if newt == t { newt = nil } case types.TFUNC: - newrecvs := subst.tstruct(t.Recvs()) - newparams := subst.tstruct(t.Params()) - newresults := subst.tstruct(t.Results()) + newrecvs := subst.tstruct(t.Recvs(), false) + newparams := subst.tstruct(t.Params(), false) + newresults := subst.tstruct(t.Results(), false) if newrecvs != t.Recvs() || newparams != t.Params() || newresults != t.Results() { + // If any types have changed, then the all the fields of + // of recv, params, and results must be copied, because they have + // offset fields that are dependent, and so must have an + // independent copy for each new signature. var newrecv *types.Field if newrecvs.NumFields() > 0 { + if newrecvs == t.Recvs() { + newrecvs = subst.tstruct(t.Recvs(), true) + } newrecv = newrecvs.Field(0) } + if newparams == t.Params() { + newparams = subst.tstruct(t.Params(), true) + } + if newresults == t.Results() { + newresults = subst.tstruct(t.Results(), true) + } newt = types.NewSignature(t.Pkg(), newrecv, t.TParams().FieldSlice(), newparams.FieldSlice(), newresults.FieldSlice()) } @@ -673,6 +691,13 @@ func (subst *subster) typ(t *types.Type) *types.Type { newt = nil } + case types.TMAP: + newkey := subst.typ(t.Key()) + newval := subst.typ(t.Elem()) + if newkey != t.Key() || newval != t.Elem() { + newt = types.NewMap(newkey, newval) + } + case types.TCHAN: elem := t.Elem() newelem := subst.typ(elem) @@ -684,8 +709,6 @@ func (subst *subster) typ(t *types.Type) *types.Type { types.CheckSize(newt) } } - - // TODO: case TMAP } if newt == nil { // Even though there were typeparams in the type, there may be no diff --git a/test/typeparam/maps.go b/test/typeparam/maps.go new file mode 100644 index 00000000000..d18dd59aed4 --- /dev/null +++ b/test/typeparam/maps.go @@ -0,0 +1,260 @@ +// run -gcflags=-G=3 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "math" + "sort" +) + +// _Equal reports whether two slices are equal: the same length and all +// elements equal. All floating point NaNs are considered equal. +func _SliceEqual[Elem comparable](s1, s2 []Elem) bool { + if len(s1) != len(s2) { + return false + } + for i, v1 := range s1 { + v2 := s2[i] + if v1 != v2 { + isNaN := func(f Elem) bool { return f != f } + if !isNaN(v1) || !isNaN(v2) { + return false + } + } + } + return true +} + +// _Keys returns the keys of the map m. +// The keys will be an indeterminate order. +func _Keys[K comparable, V any](m map[K]V) []K { + r := make([]K, 0, len(m)) + for k := range m { + r = append(r, k) + } + return r +} + +// _Values returns the values of the map m. +// The values will be in an indeterminate order. +func _Values[K comparable, V any](m map[K]V) []V { + r := make([]V, 0, len(m)) + for _, v := range m { + r = append(r, v) + } + return r +} + +// _Equal reports whether two maps contain the same key/value pairs. +// _Values are compared using ==. +func _Equal[K, V comparable](m1, m2 map[K]V) bool { + if len(m1) != len(m2) { + return false + } + for k, v1 := range m1 { + if v2, ok := m2[k]; !ok || v1 != v2 { + return false + } + } + return true +} + +// _Copy returns a copy of m. +func _Copy[K comparable, V any](m map[K]V) map[K]V { + r := make(map[K]V, len(m)) + for k, v := range m { + r[k] = v + } + return r +} + +// _Add adds all key/value pairs in m2 to m1. _Keys in m2 that are already +// present in m1 will be overwritten with the value in m2. +func _Add[K comparable, V any](m1, m2 map[K]V) { + for k, v := range m2 { + m1[k] = v + } +} + +// _Sub removes all keys in m2 from m1. _Keys in m2 that are not present +// in m1 are ignored. The values in m2 are ignored. +func _Sub[K comparable, V any](m1, m2 map[K]V) { + for k := range m2 { + delete(m1, k) + } +} + +// _Intersect removes all keys from m1 that are not present in m2. +// _Keys in m2 that are not in m1 are ignored. The values in m2 are ignored. +func _Intersect[K comparable, V any](m1, m2 map[K]V) { + for k := range m1 { + if _, ok := m2[k]; !ok { + delete(m1, k) + } + } +} + +// _Filter deletes any key/value pairs from m for which f returns false. +func _Filter[K comparable, V any](m map[K]V, f func(K, V) bool) { + for k, v := range m { + if !f(k, v) { + delete(m, k) + } + } +} + +// _TransformValues applies f to each value in m. The keys remain unchanged. +func _TransformValues[K comparable, V any](m map[K]V, f func(V) V) { + for k, v := range m { + m[k] = f(v) + } +} + +var m1 = map[int]int{1: 2, 2: 4, 4: 8, 8: 16} +var m2 = map[int]string{1: "2", 2: "4", 4: "8", 8: "16"} + +func TestKeys() { + want := []int{1, 2, 4, 8} + + got1 := _Keys(m1) + sort.Ints(got1) + if !_SliceEqual(got1, want) { + panic(fmt.Sprintf("_Keys(%v) = %v, want %v", m1, got1, want)) + } + + got2 := _Keys(m2) + sort.Ints(got2) + if !_SliceEqual(got2, want) { + panic(fmt.Sprintf("_Keys(%v) = %v, want %v", m2, got2, want)) + } +} + +func TestValues() { + got1 := _Values(m1) + want1 := []int{2, 4, 8, 16} + sort.Ints(got1) + if !_SliceEqual(got1, want1) { + panic(fmt.Sprintf("_Values(%v) = %v, want %v", m1, got1, want1)) + } + + got2 := _Values(m2) + want2 := []string{"16", "2", "4", "8"} + sort.Strings(got2) + if !_SliceEqual(got2, want2) { + panic(fmt.Sprintf("_Values(%v) = %v, want %v", m2, got2, want2)) + } +} + +func TestEqual() { + if !_Equal(m1, m1) { + panic(fmt.Sprintf("_Equal(%v, %v) = false, want true", m1, m1)) + } + if _Equal(m1, nil) { + panic(fmt.Sprintf("_Equal(%v, nil) = true, want false", m1)) + } + if _Equal(nil, m1) { + panic(fmt.Sprintf("_Equal(nil, %v) = true, want false", m1)) + } + if !_Equal[int, int](nil, nil) { + panic("_Equal(nil, nil) = false, want true") + } + if ms := map[int]int{1: 2}; _Equal(m1, ms) { + panic(fmt.Sprintf("_Equal(%v, %v) = true, want false", m1, ms)) + } + + // Comparing NaN for equality is expected to fail. + mf := map[int]float64{1: 0, 2: math.NaN()} + if _Equal(mf, mf) { + panic(fmt.Sprintf("_Equal(%v, %v) = true, want false", mf, mf)) + } +} + +func TestCopy() { + m2 := _Copy(m1) + if !_Equal(m1, m2) { + panic(fmt.Sprintf("_Copy(%v) = %v, want %v", m1, m2, m1)) + } + m2[16] = 32 + if _Equal(m1, m2) { + panic(fmt.Sprintf("_Equal(%v, %v) = true, want false", m1, m2)) + } +} + +func TestAdd() { + mc := _Copy(m1) + _Add(mc, mc) + if !_Equal(mc, m1) { + panic(fmt.Sprintf("_Add(%v, %v) = %v, want %v", m1, m1, mc, m1)) + } + _Add(mc, map[int]int{16: 32}) + want := map[int]int{1: 2, 2: 4, 4: 8, 8: 16, 16: 32} + if !_Equal(mc, want) { + panic(fmt.Sprintf("_Add result = %v, want %v", mc, want)) + } +} + +func TestSub() { + mc := _Copy(m1) + _Sub(mc, mc) + if len(mc) > 0 { + panic(fmt.Sprintf("_Sub(%v, %v) = %v, want empty map", m1, m1, mc)) + } + mc = _Copy(m1) + _Sub(mc, map[int]int{1: 0}) + want := map[int]int{2: 4, 4: 8, 8: 16} + if !_Equal(mc, want) { + panic(fmt.Sprintf("_Sub result = %v, want %v", mc, want)) + } +} + +func TestIntersect() { + mc := _Copy(m1) + _Intersect(mc, mc) + if !_Equal(mc, m1) { + panic(fmt.Sprintf("_Intersect(%v, %v) = %v, want %v", m1, m1, mc, m1)) + } + _Intersect(mc, map[int]int{1: 0, 2: 0}) + want := map[int]int{1: 2, 2: 4} + if !_Equal(mc, want) { + panic(fmt.Sprintf("_Intersect result = %v, want %v", mc, want)) + } +} + +func TestFilter() { + mc := _Copy(m1) + _Filter(mc, func(int, int) bool { return true }) + if !_Equal(mc, m1) { + panic(fmt.Sprintf("_Filter(%v, true) = %v, want %v", m1, mc, m1)) + } + _Filter(mc, func(k, v int) bool { return k < 3 }) + want := map[int]int{1: 2, 2: 4} + if !_Equal(mc, want) { + panic(fmt.Sprintf("_Filter result = %v, want %v", mc, want)) + } +} + +func TestTransformValues() { + mc := _Copy(m1) + _TransformValues(mc, func(i int) int { return i / 2 }) + want := map[int]int{1: 1, 2: 2, 4: 4, 8: 8} + if !_Equal(mc, want) { + panic(fmt.Sprintf("_TransformValues result = %v, want %v", mc, want)) + } +} + +func main() { + TestKeys() + TestValues() + TestEqual() + TestCopy() + TestAdd() + TestSub() + TestIntersect() + TestFilter() + TestTransformValues() +} diff --git a/test/typeparam/sets.go b/test/typeparam/sets.go new file mode 100644 index 00000000000..258514489e1 --- /dev/null +++ b/test/typeparam/sets.go @@ -0,0 +1,280 @@ +// run -gcflags=-G=3 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "sort" +) + +// _Equal reports whether two slices are equal: the same length and all +// elements equal. All floating point NaNs are considered equal. +func _SliceEqual[Elem comparable](s1, s2 []Elem) bool { + if len(s1) != len(s2) { + return false + } + for i, v1 := range s1 { + v2 := s2[i] + if v1 != v2 { + isNaN := func(f Elem) bool { return f != f } + if !isNaN(v1) || !isNaN(v2) { + return false + } + } + } + return true +} + +// A _Set is a set of elements of some type. +type _Set[Elem comparable] struct { + m map[Elem]struct{} +} + +// _Make makes a new set. +func _Make[Elem comparable]() _Set[Elem] { + return _Set[Elem]{m: make(map[Elem]struct{})} +} + +// Add adds an element to a set. +func (s _Set[Elem]) Add(v Elem) { + s.m[v] = struct{}{} +} + +// Delete removes an element from a set. If the element is not present +// in the set, this does nothing. +func (s _Set[Elem]) Delete(v Elem) { + delete(s.m, v) +} + +// Contains reports whether v is in the set. +func (s _Set[Elem]) Contains(v Elem) bool { + _, ok := s.m[v] + return ok +} + +// Len returns the number of elements in the set. +func (s _Set[Elem]) Len() int { + return len(s.m) +} + +// Values returns the values in the set. +// The values will be in an indeterminate order. +func (s _Set[Elem]) Values() []Elem { + r := make([]Elem, 0, len(s.m)) + for v := range s.m { + r = append(r, v) + } + return r +} + +// _Equal reports whether two sets contain the same elements. +func _Equal[Elem comparable](s1, s2 _Set[Elem]) bool { + if len(s1.m) != len(s2.m) { + return false + } + for v1 := range s1.m { + if !s2.Contains(v1) { + return false + } + } + return true +} + +// Copy returns a copy of s. +func (s _Set[Elem]) Copy() _Set[Elem] { + r := _Set[Elem]{m: make(map[Elem]struct{}, len(s.m))} + for v := range s.m { + r.m[v] = struct{}{} + } + return r +} + +// AddSet adds all the elements of s2 to s. +func (s _Set[Elem]) AddSet(s2 _Set[Elem]) { + for v := range s2.m { + s.m[v] = struct{}{} + } +} + +// SubSet removes all elements in s2 from s. +// Values in s2 that are not in s are ignored. +func (s _Set[Elem]) SubSet(s2 _Set[Elem]) { + for v := range s2.m { + delete(s.m, v) + } +} + +// Intersect removes all elements from s that are not present in s2. +// Values in s2 that are not in s are ignored. +func (s _Set[Elem]) Intersect(s2 _Set[Elem]) { + for v := range s.m { + if !s2.Contains(v) { + delete(s.m, v) + } + } +} + +// Iterate calls f on every element in the set. +func (s _Set[Elem]) Iterate(f func(Elem)) { + for v := range s.m { + f(v) + } +} + +// Filter deletes any elements from s for which f returns false. +func (s _Set[Elem]) Filter(f func(Elem) bool) { + for v := range s.m { + if !f(v) { + delete(s.m, v) + } + } +} + +func TestSet() { + s1 := _Make[int]() + if got := s1.Len(); got != 0 { + panic(fmt.Sprintf("Len of empty set = %d, want 0", got)) + } + s1.Add(1) + s1.Add(1) + s1.Add(1) + if got := s1.Len(); got != 1 { + panic(fmt.Sprintf("(%v).Len() == %d, want 1", s1, got)) + } + s1.Add(2) + s1.Add(3) + s1.Add(4) + if got := s1.Len(); got != 4 { + panic(fmt.Sprintf("(%v).Len() == %d, want 4", s1, got)) + } + if !s1.Contains(1) { + panic(fmt.Sprintf("(%v).Contains(1) == false, want true", s1)) + } + if s1.Contains(5) { + panic(fmt.Sprintf("(%v).Contains(5) == true, want false", s1)) + } + vals := s1.Values() + sort.Ints(vals) + w1 := []int{1, 2, 3, 4} + if !_SliceEqual(vals, w1) { + panic(fmt.Sprintf("(%v).Values() == %v, want %v", s1, vals, w1)) + } +} + +func TestEqual() { + s1 := _Make[string]() + s2 := _Make[string]() + if !_Equal(s1, s2) { + panic(fmt.Sprintf("_Equal(%v, %v) = false, want true", s1, s2)) + } + s1.Add("hello") + s1.Add("world") + if got := s1.Len(); got != 2 { + panic(fmt.Sprintf("(%v).Len() == %d, want 2", s1, got)) + } + if _Equal(s1, s2) { + panic(fmt.Sprintf("_Equal(%v, %v) = true, want false", s1, s2)) + } +} + +func TestCopy() { + s1 := _Make[float64]() + s1.Add(0) + s2 := s1.Copy() + if !_Equal(s1, s2) { + panic(fmt.Sprintf("_Equal(%v, %v) = false, want true", s1, s2)) + } + s1.Add(1) + if _Equal(s1, s2) { + panic(fmt.Sprintf("_Equal(%v, %v) = true, want false", s1, s2)) + } +} + +func TestAddSet() { + s1 := _Make[int]() + s1.Add(1) + s1.Add(2) + s2 := _Make[int]() + s2.Add(2) + s2.Add(3) + s1.AddSet(s2) + if got := s1.Len(); got != 3 { + panic(fmt.Sprintf("(%v).Len() == %d, want 3", s1, got)) + } + s2.Add(1) + if !_Equal(s1, s2) { + panic(fmt.Sprintf("_Equal(%v, %v) = false, want true", s1, s2)) + } +} + +func TestSubSet() { + s1 := _Make[int]() + s1.Add(1) + s1.Add(2) + s2 := _Make[int]() + s2.Add(2) + s2.Add(3) + s1.SubSet(s2) + if got := s1.Len(); got != 1 { + panic(fmt.Sprintf("(%v).Len() == %d, want 1", s1, got)) + } + if vals, want := s1.Values(), []int{1}; !_SliceEqual(vals, want) { + panic(fmt.Sprintf("after SubSet got %v, want %v", vals, want)) + } +} + +func TestIntersect() { + s1 := _Make[int]() + s1.Add(1) + s1.Add(2) + s2 := _Make[int]() + s2.Add(2) + s2.Add(3) + s1.Intersect(s2) + if got := s1.Len(); got != 1 { + panic(fmt.Sprintf("(%v).Len() == %d, want 1", s1, got)) + } + if vals, want := s1.Values(), []int{2}; !_SliceEqual(vals, want) { + panic(fmt.Sprintf("after Intersect got %v, want %v", vals, want)) + } +} + +func TestIterate() { + s1 := _Make[int]() + s1.Add(1) + s1.Add(2) + s1.Add(3) + s1.Add(4) + tot := 0 + s1.Iterate(func(i int) { tot += i }) + if tot != 10 { + panic(fmt.Sprintf("total of %v == %d, want 10", s1, tot)) + } +} + +func TestFilter() { + s1 := _Make[int]() + s1.Add(1) + s1.Add(2) + s1.Add(3) + s1.Filter(func(v int) bool { return v%2 == 0 }) + if vals, want := s1.Values(), []int{2}; !_SliceEqual(vals, want) { + panic(fmt.Sprintf("after Filter got %v, want %v", vals, want)) + } + +} + +func main() { + TestSet() + TestEqual() + TestCopy() + TestAddSet() + TestSubSet() + TestIntersect() + TestIterate() + TestFilter() +}