gob: make recursive map and slice types work.

Before this fix, types such as
        type T map[string]T
caused infinite recursion in the gob implementation.
Now they just work.

Fixes #1518.

R=rsc
CC=golang-dev
https://golang.org/cl/4230045
This commit is contained in:
Rob Pike 2011-02-25 09:45:06 -08:00
parent 895631770a
commit c54b5d032f
5 changed files with 156 additions and 76 deletions

View file

@ -342,7 +342,7 @@ func TestScalarDecInstructions(t *testing.T) {
var data struct { var data struct {
a int a int
} }
instr := &decInstr{decOpMap[reflect.Int], 6, 0, 0, ovfl} instr := &decInstr{decOpTable[reflect.Int], 6, 0, 0, ovfl}
state := newDecodeStateFromData(signedResult) state := newDecodeStateFromData(signedResult)
execDec("int", instr, state, t, unsafe.Pointer(&data)) execDec("int", instr, state, t, unsafe.Pointer(&data))
if data.a != 17 { if data.a != 17 {
@ -355,7 +355,7 @@ func TestScalarDecInstructions(t *testing.T) {
var data struct { var data struct {
a uint a uint
} }
instr := &decInstr{decOpMap[reflect.Uint], 6, 0, 0, ovfl} instr := &decInstr{decOpTable[reflect.Uint], 6, 0, 0, ovfl}
state := newDecodeStateFromData(unsignedResult) state := newDecodeStateFromData(unsignedResult)
execDec("uint", instr, state, t, unsafe.Pointer(&data)) execDec("uint", instr, state, t, unsafe.Pointer(&data))
if data.a != 17 { if data.a != 17 {
@ -446,7 +446,7 @@ func TestScalarDecInstructions(t *testing.T) {
var data struct { var data struct {
a uintptr a uintptr
} }
instr := &decInstr{decOpMap[reflect.Uintptr], 6, 0, 0, ovfl} instr := &decInstr{decOpTable[reflect.Uintptr], 6, 0, 0, ovfl}
state := newDecodeStateFromData(unsignedResult) state := newDecodeStateFromData(unsignedResult)
execDec("uintptr", instr, state, t, unsafe.Pointer(&data)) execDec("uintptr", instr, state, t, unsafe.Pointer(&data))
if data.a != 17 { if data.a != 17 {
@ -511,7 +511,7 @@ func TestScalarDecInstructions(t *testing.T) {
var data struct { var data struct {
a complex64 a complex64
} }
instr := &decInstr{decOpMap[reflect.Complex64], 6, 0, 0, ovfl} instr := &decInstr{decOpTable[reflect.Complex64], 6, 0, 0, ovfl}
state := newDecodeStateFromData(complexResult) state := newDecodeStateFromData(complexResult)
execDec("complex", instr, state, t, unsafe.Pointer(&data)) execDec("complex", instr, state, t, unsafe.Pointer(&data))
if data.a != 17+19i { if data.a != 17+19i {
@ -524,7 +524,7 @@ func TestScalarDecInstructions(t *testing.T) {
var data struct { var data struct {
a complex128 a complex128
} }
instr := &decInstr{decOpMap[reflect.Complex128], 6, 0, 0, ovfl} instr := &decInstr{decOpTable[reflect.Complex128], 6, 0, 0, ovfl}
state := newDecodeStateFromData(complexResult) state := newDecodeStateFromData(complexResult)
execDec("complex", instr, state, t, unsafe.Pointer(&data)) execDec("complex", instr, state, t, unsafe.Pointer(&data))
if data.a != 17+19i { if data.a != 17+19i {

View file

@ -671,7 +671,7 @@ func (dec *Decoder) ignoreInterface(state *decodeState) {
} }
// Index by Go types. // Index by Go types.
var decOpMap = []decOp{ var decOpTable = [...]decOp{
reflect.Bool: decBool, reflect.Bool: decBool,
reflect.Int8: decInt8, reflect.Int8: decInt8,
reflect.Int16: decInt16, reflect.Int16: decInt16,
@ -701,37 +701,43 @@ var decIgnoreOpMap = map[typeId]decOp{
// Return the decoding op for the base type under rt and // Return the decoding op for the base type under rt and
// the indirection count to reach it. // the indirection count to reach it.
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) { func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) {
ut := userType(rt) ut := userType(rt)
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil {
return opPtr, ut.indir
}
typ := ut.base typ := ut.base
indir := ut.indir indir := ut.indir
var op decOp var op decOp
k := typ.Kind() k := typ.Kind()
if int(k) < len(decOpMap) { if int(k) < len(decOpTable) {
op = decOpMap[k] op = decOpTable[k]
} }
if op == nil { if op == nil {
inProgress[rt] = &op
// Special cases // Special cases
switch t := typ.(type) { switch t := typ.(type) {
case *reflect.ArrayType: case *reflect.ArrayType:
name = "element of " + name name = "element of " + name
elemId := dec.wireType[wireId].ArrayT.Elem elemId := dec.wireType[wireId].ArrayT.Elem
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name) ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.dec.decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
} }
case *reflect.MapType: case *reflect.MapType:
name = "element of " + name name = "element of " + name
keyId := dec.wireType[wireId].MapT.Key keyId := dec.wireType[wireId].MapT.Key
elemId := dec.wireType[wireId].MapT.Elem elemId := dec.wireType[wireId].MapT.Elem
keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name) keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress)
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name) ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
up := unsafe.Pointer(p) up := unsafe.Pointer(p)
state.dec.decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl) state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl)
} }
case *reflect.SliceType: case *reflect.SliceType:
@ -746,10 +752,10 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
} else { } else {
elemId = dec.wireType[wireId].SliceT.Elem elemId = dec.wireType[wireId].SliceT.Elem
} }
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name) elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name) ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.dec.decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
} }
case *reflect.StructType: case *reflect.StructType:
@ -774,7 +780,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
if op == nil { if op == nil {
errorf("gob: decode can't handle type %s", rt.String()) errorf("gob: decode can't handle type %s", rt.String())
} }
return op, indir return &op, indir
} }
// Return the decoding op for a field that has no destination. // Return the decoding op for a field that has no destination.
@ -838,11 +844,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
// Are these two gob Types compatible? // Are these two gob Types compatible?
// Answers the question for basic types, arrays, and slices. // Answers the question for basic types, arrays, and slices.
// Structs are considered ok; fields will be checked later. // Structs are considered ok; fields will be checked later.
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool {
if rhs, ok := inProgress[fr]; ok {
return rhs == fw
}
inProgress[fr] = fw
fr = userType(fr).base fr = userType(fr).base
switch t := fr.(type) { switch t := fr.(type) {
default: default:
// map, chan, etc: cannot handle. // chan, etc: cannot handle.
return false return false
case *reflect.BoolType: case *reflect.BoolType:
return fw == tBool return fw == tBool
@ -864,14 +874,14 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
return false return false
} }
array := wire.ArrayT array := wire.ArrayT
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem) return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress)
case *reflect.MapType: case *reflect.MapType:
wire, ok := dec.wireType[fw] wire, ok := dec.wireType[fw]
if !ok || wire.MapT == nil { if !ok || wire.MapT == nil {
return false return false
} }
MapType := wire.MapT MapType := wire.MapT
return dec.compatibleType(t.Key(), MapType.Key) && dec.compatibleType(t.Elem(), MapType.Elem) return dec.compatibleType(t.Key(), MapType.Key, inProgress) && dec.compatibleType(t.Elem(), MapType.Elem, inProgress)
case *reflect.SliceType: case *reflect.SliceType:
// Is it an array of bytes? // Is it an array of bytes?
if t.Elem().Kind() == reflect.Uint8 { if t.Elem().Kind() == reflect.Uint8 {
@ -885,7 +895,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
sw = dec.wireType[fw].SliceT sw = dec.wireType[fw].SliceT
} }
elem := userType(t.Elem()).base elem := userType(t.Elem()).base
return sw != nil && dec.compatibleType(elem, sw.Elem) return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress)
case *reflect.StructType: case *reflect.StructType:
return true return true
} }
@ -906,12 +916,12 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec
engine = new(decEngine) engine = new(decEngine)
engine.instr = make([]decInstr, 1) // one item engine.instr = make([]decInstr, 1) // one item
name := rt.String() // best we can do name := rt.String() // best we can do
if !dec.compatibleType(rt, remoteId) { if !dec.compatibleType(rt, remoteId, make(map[reflect.Type]typeId)) {
return nil, os.ErrorString("gob: wrong type received for local value " + name + ": " + dec.typeString(remoteId)) return nil, os.ErrorString("gob: wrong type received for local value " + name + ": " + dec.typeString(remoteId))
} }
op, indir := dec.decOpFor(remoteId, rt, name) op, indir := dec.decOpFor(remoteId, rt, name, make(map[reflect.Type]*decOp))
ovfl := os.ErrorString(`value for "` + name + `" out of range`) ovfl := os.ErrorString(`value for "` + name + `" out of range`)
engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl} engine.instr[singletonField] = decInstr{*op, singletonField, indir, 0, ovfl}
engine.numInstr = 1 engine.numInstr = 1
return return
} }
@ -954,6 +964,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
} }
engine = new(decEngine) engine = new(decEngine)
engine.instr = make([]decInstr, len(wireStruct.Field)) engine.instr = make([]decInstr, len(wireStruct.Field))
seen := make(map[reflect.Type]*decOp)
// Loop over the fields of the wire type. // Loop over the fields of the wire type.
for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ { for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ {
wireField := wireStruct.Field[fieldnum] wireField := wireStruct.Field[fieldnum]
@ -969,11 +980,11 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl} engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl}
continue continue
} }
if !dec.compatibleType(localField.Type, wireField.Id) { if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) {
errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name) errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name)
} }
op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name) op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen)
engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset), ovfl} engine.instr[fieldnum] = decInstr{*op, fieldnum, indir, uintptr(localField.Offset), ovfl}
engine.numInstr++ engine.numInstr++
} }
return return
@ -1070,8 +1081,8 @@ func init() {
default: default:
panic("gob: unknown size of int/uint") panic("gob: unknown size of int/uint")
} }
decOpMap[reflect.Int] = iop decOpTable[reflect.Int] = iop
decOpMap[reflect.Uint] = uop decOpTable[reflect.Uint] = uop
// Finally uintptr // Finally uintptr
switch reflect.Typeof(uintptr(0)).Bits() { switch reflect.Typeof(uintptr(0)).Bits() {
@ -1082,5 +1093,5 @@ func init() {
default: default:
panic("gob: unknown size of uintptr") panic("gob: unknown size of uintptr")
} }
decOpMap[reflect.Uintptr] = uop decOpTable[reflect.Uintptr] = uop
} }

View file

@ -414,7 +414,7 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
} }
} }
var encOpMap = []encOp{ var encOpTable = [...]encOp{
reflect.Bool: encBool, reflect.Bool: encBool,
reflect.Int: encInt, reflect.Int: encInt,
reflect.Int8: encInt8, reflect.Int8: encInt8,
@ -434,18 +434,24 @@ var encOpMap = []encOp{
reflect.String: encString, reflect.String: encString,
} }
// Return the encoding op for the base type under rt and // Return (a pointer to) the encoding op for the base type under rt and
// the indirection count to reach it. // the indirection count to reach it.
func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) { func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) {
ut := userType(rt) ut := userType(rt)
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil {
return opPtr, ut.indir
}
typ := ut.base typ := ut.base
indir := ut.indir indir := ut.indir
var op encOp
k := typ.Kind() k := typ.Kind()
if int(k) < len(encOpMap) { var op encOp
op = encOpMap[k] if int(k) < len(encOpTable) {
op = encOpTable[k]
} }
if op == nil { if op == nil {
inProgress[rt] = &op
// Special cases // Special cases
switch t := typ.(type) { switch t := typ.(type) {
case *reflect.SliceType: case *reflect.SliceType:
@ -454,25 +460,25 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
break break
} }
// Slices have a header; we decode it to find the underlying array. // Slices have a header; we decode it to find the underlying array.
elemOp, indir := enc.encOpFor(t.Elem()) elemOp, indir := enc.encOpFor(t.Elem(), inProgress)
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
slice := (*reflect.SliceHeader)(p) slice := (*reflect.SliceHeader)(p)
if !state.sendZero && slice.Len == 0 { if !state.sendZero && slice.Len == 0 {
return return
} }
state.update(i) state.update(i)
state.enc.encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len)) state.enc.encodeArray(state.b, slice.Data, *elemOp, t.Elem().Size(), indir, int(slice.Len))
} }
case *reflect.ArrayType: case *reflect.ArrayType:
// True arrays have size in the type. // True arrays have size in the type.
elemOp, indir := enc.encOpFor(t.Elem()) elemOp, indir := enc.encOpFor(t.Elem(), inProgress)
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.update(i) state.update(i)
state.enc.encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len()) state.enc.encodeArray(state.b, uintptr(p), *elemOp, t.Elem().Size(), indir, t.Len())
} }
case *reflect.MapType: case *reflect.MapType:
keyOp, keyIndir := enc.encOpFor(t.Key()) keyOp, keyIndir := enc.encOpFor(t.Key(), inProgress)
elemOp, elemIndir := enc.encOpFor(t.Elem()) elemOp, elemIndir := enc.encOpFor(t.Elem(), inProgress)
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
// Maps cannot be accessed by moving addresses around the way // Maps cannot be accessed by moving addresses around the way
// that slices etc. can. We must recover a full reflection value for // that slices etc. can. We must recover a full reflection value for
@ -483,7 +489,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
return return
} }
state.update(i) state.update(i)
state.enc.encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir) state.enc.encodeMap(state.b, mv, *keyOp, *elemOp, keyIndir, elemIndir)
} }
case *reflect.StructType: case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type. // Generate a closure that calls out to the engine for the nested type.
@ -511,21 +517,22 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
if op == nil { if op == nil {
errorf("gob enc: can't happen: encode type %s", rt.String()) errorf("gob enc: can't happen: encode type %s", rt.String())
} }
return op, indir return &op, indir
} }
// The local Type was compiled from the actual value, so we know it's compatible. // The local Type was compiled from the actual value, so we know it's compatible.
func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine { func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
srt, isStruct := rt.(*reflect.StructType) srt, isStruct := rt.(*reflect.StructType)
engine := new(encEngine) engine := new(encEngine)
seen := make(map[reflect.Type]*encOp)
if isStruct { if isStruct {
for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ { for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ {
f := srt.Field(fieldNum) f := srt.Field(fieldNum)
if !isExported(f.Name) { if !isExported(f.Name) {
continue continue
} }
op, indir := enc.encOpFor(f.Type) op, indir := enc.encOpFor(f.Type, seen)
engine.instr = append(engine.instr, encInstr{op, fieldNum, indir, uintptr(f.Offset)}) engine.instr = append(engine.instr, encInstr{*op, fieldNum, indir, uintptr(f.Offset)})
} }
if srt.NumField() > 0 && len(engine.instr) == 0 { if srt.NumField() > 0 && len(engine.instr) == 0 {
errorf("type %s has no exported fields", rt) errorf("type %s has no exported fields", rt)
@ -533,8 +540,8 @@ func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, 0, 0}) engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, 0, 0})
} else { } else {
engine.instr = make([]encInstr, 1) engine.instr = make([]encInstr, 1)
op, indir := enc.encOpFor(rt) op, indir := enc.encOpFor(rt, seen)
engine.instr[0] = encInstr{op, singletonField, indir, 0} // offset is zero engine.instr[0] = encInstr{*op, singletonField, indir, 0} // offset is zero
} }
return engine return engine
} }

View file

@ -249,6 +249,24 @@ func TestArray(t *testing.T) {
} }
} }
func TestRecursiveMapType(t *testing.T) {
type recursiveMap map[string]recursiveMap
r1 := recursiveMap{"A": recursiveMap{"B": nil, "C": nil}, "D": nil}
r2 := make(recursiveMap)
if err := encAndDec(r1, &r2); err != nil {
t.Error(err)
}
}
func TestRecursiveSliceType(t *testing.T) {
type recursiveSlice []recursiveSlice
r1 := recursiveSlice{0: recursiveSlice{0: nil}, 1: nil}
r2 := make(recursiveSlice, 0)
if err := encAndDec(r1, &r2); err != nil {
t.Error(err)
}
}
// Regression test for bug: must send zero values inside arrays // Regression test for bug: must send zero values inside arrays
func TestDefaultsInArray(t *testing.T) { func TestDefaultsInArray(t *testing.T) {
type Type7 struct { type Type7 struct {

View file

@ -52,9 +52,6 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
// cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6, // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6,
// pp 539-540. As we step through indirections, run another type at // pp 539-540. As we step through indirections, run another type at
// half speed. If they meet up, there's a cycle. // half speed. If they meet up, there's a cycle.
// TODO: still need to deal with self-referential non-structs such
// as type T map[string]T but that is a larger undertaking - and can
// be useful, not always erroneous.
slowpoke := ut.base // walks half as fast as ut.base slowpoke := ut.base // walks half as fast as ut.base
for { for {
pt, ok := ut.base.(*reflect.PtrType) pt, ok := ut.base.(*reflect.PtrType)
@ -210,12 +207,18 @@ type arrayType struct {
Len int Len int
} }
func newArrayType(name string, elem gobType, length int) *arrayType { func newArrayType(name string) *arrayType {
a := &arrayType{CommonType{Name: name}, elem.id(), length} a := &arrayType{CommonType{Name: name}, 0, 0}
setTypeId(a)
return a return a
} }
func (a *arrayType) init(elem gobType, len int) {
// Set our type id before evaluating the element's, in case it's our own.
setTypeId(a)
a.Elem = elem.id()
a.Len = len
}
func (a *arrayType) safeString(seen map[typeId]bool) string { func (a *arrayType) safeString(seen map[typeId]bool) string {
if seen[a.Id] { if seen[a.Id] {
return a.Name return a.Name
@ -233,12 +236,18 @@ type mapType struct {
Elem typeId Elem typeId
} }
func newMapType(name string, key, elem gobType) *mapType { func newMapType(name string) *mapType {
m := &mapType{CommonType{Name: name}, key.id(), elem.id()} m := &mapType{CommonType{Name: name}, 0, 0}
setTypeId(m)
return m return m
} }
func (m *mapType) init(key, elem gobType) {
// Set our type id before evaluating the element's, in case it's our own.
setTypeId(m)
m.Key = key.id()
m.Elem = elem.id()
}
func (m *mapType) safeString(seen map[typeId]bool) string { func (m *mapType) safeString(seen map[typeId]bool) string {
if seen[m.Id] { if seen[m.Id] {
return m.Name return m.Name
@ -257,12 +266,17 @@ type sliceType struct {
Elem typeId Elem typeId
} }
func newSliceType(name string, elem gobType) *sliceType { func newSliceType(name string) *sliceType {
s := &sliceType{CommonType{Name: name}, elem.id()} s := &sliceType{CommonType{Name: name}, 0}
setTypeId(s)
return s return s
} }
func (s *sliceType) init(elem gobType) {
// Set our type id before evaluating the element's, in case it's our own.
setTypeId(s)
s.Elem = elem.id()
}
func (s *sliceType) safeString(seen map[typeId]bool) string { func (s *sliceType) safeString(seen map[typeId]bool) string {
if seen[s.Id] { if seen[s.Id] {
return s.Name return s.Name
@ -304,11 +318,26 @@ func (s *structType) string() string { return s.safeString(make(map[typeId]bool)
func newStructType(name string) *structType { func newStructType(name string) *structType {
s := &structType{CommonType{Name: name}, nil} s := &structType{CommonType{Name: name}, nil}
// For historical reasons we set the id here rather than init.
// Se the comment in newTypeObject for details.
setTypeId(s) setTypeId(s)
return s return s
} }
func (s *structType) init(field []*fieldType) {
s.Field = field
}
func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
var err os.Error
var type0, type1 gobType
defer func() {
if err != nil {
types[rt] = nil, false
}
}()
// Install the top-level type before the subtypes (e.g. struct before
// fields) so recursive types can be constructed safely.
switch t := rt.(type) { switch t := rt.(type) {
// All basic types are easy: they are predefined. // All basic types are easy: they are predefined.
case *reflect.BoolType: case *reflect.BoolType:
@ -333,40 +362,55 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
return tInterface.gobType(), nil return tInterface.gobType(), nil
case *reflect.ArrayType: case *reflect.ArrayType:
gt, err := getType("", t.Elem()) at := newArrayType(name)
types[rt] = at
type0, err = getType("", t.Elem())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newArrayType(name, gt, t.Len()), nil // Historical aside:
// For arrays, maps, and slices, we set the type id after the elements
// are constructed. This is to retain the order of type id allocation after
// a fix made to handle recursive types, which changed the order in
// which types are built. Delaying the setting in this way preserves
// type ids while allowing recursive types to be described. Structs,
// done below, were already handling recursion correctly so they
// assign the top-level id before those of the field.
at.init(type0, t.Len())
return at, nil
case *reflect.MapType: case *reflect.MapType:
kt, err := getType("", t.Key()) mt := newMapType(name)
types[rt] = mt
type0, err = getType("", t.Key())
if err != nil { if err != nil {
return nil, err return nil, err
} }
vt, err := getType("", t.Elem()) type1, err = getType("", t.Elem())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newMapType(name, kt, vt), nil mt.init(type0, type1)
return mt, nil
case *reflect.SliceType: case *reflect.SliceType:
// []byte == []uint8 is a special case // []byte == []uint8 is a special case
if t.Elem().Kind() == reflect.Uint8 { if t.Elem().Kind() == reflect.Uint8 {
return tBytes.gobType(), nil return tBytes.gobType(), nil
} }
gt, err := getType(t.Elem().Name(), t.Elem()) st := newSliceType(name)
types[rt] = st
type0, err = getType(t.Elem().Name(), t.Elem())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newSliceType(name, gt), nil st.init(type0)
return st, nil
case *reflect.StructType: case *reflect.StructType:
// Install the struct type itself before the fields so recursive st := newStructType(name)
// structures can be constructed safely. types[rt] = st
strType := newStructType(name) idToType[st.id()] = st
types[rt] = strType
idToType[strType.id()] = strType
field := make([]*fieldType, t.NumField()) field := make([]*fieldType, t.NumField())
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
f := t.Field(i) f := t.Field(i)
@ -382,8 +426,8 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
} }
field[i] = &fieldType{f.Name, gt.id()} field[i] = &fieldType{f.Name, gt.id()}
} }
strType.Field = field st.init(field)
return strType, nil return st, nil
default: default:
return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String()) return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String())
@ -435,7 +479,7 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId {
// For bootstrapping purposes, we assume that the recipient knows how // For bootstrapping purposes, we assume that the recipient knows how
// to decode a wireType; it is exactly the wireType struct here, interpreted // to decode a wireType; it is exactly the wireType struct here, interpreted
// using the gob rules for sending a structure, except that we assume the // using the gob rules for sending a structure, except that we assume the
// ids for wireType and structType are known. The relevant pieces // ids for wireType and structType etc. are known. The relevant pieces
// are built in encode.go's init() function. // are built in encode.go's init() function.
// To maintain binary compatibility, if you extend this type, always put // To maintain binary compatibility, if you extend this type, always put
// the new fields last. // the new fields last.