gob: make recursive map and slice types work.

Before this fix, types such as
        type T map[string]T
caused infinite recursion in the gob implementation.
Now they just work.

Fixes #1518.

R=rsc
CC=golang-dev
https://golang.org/cl/4230045
This commit is contained in:
Rob Pike 2011-02-25 09:45:06 -08:00
parent 895631770a
commit c54b5d032f
5 changed files with 156 additions and 76 deletions

View file

@ -414,7 +414,7 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
}
}
var encOpMap = []encOp{
var encOpTable = [...]encOp{
reflect.Bool: encBool,
reflect.Int: encInt,
reflect.Int8: encInt8,
@ -434,18 +434,24 @@ var encOpMap = []encOp{
reflect.String: encString,
}
// Return the encoding op for the base type under rt and
// Return (a pointer to) the encoding op for the base type under rt and
// the indirection count to reach it.
func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) {
ut := userType(rt)
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil {
return opPtr, ut.indir
}
typ := ut.base
indir := ut.indir
var op encOp
k := typ.Kind()
if int(k) < len(encOpMap) {
op = encOpMap[k]
var op encOp
if int(k) < len(encOpTable) {
op = encOpTable[k]
}
if op == nil {
inProgress[rt] = &op
// Special cases
switch t := typ.(type) {
case *reflect.SliceType:
@ -454,25 +460,25 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
break
}
// Slices have a header; we decode it to find the underlying array.
elemOp, indir := enc.encOpFor(t.Elem())
elemOp, indir := enc.encOpFor(t.Elem(), inProgress)
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
slice := (*reflect.SliceHeader)(p)
if !state.sendZero && slice.Len == 0 {
return
}
state.update(i)
state.enc.encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len))
state.enc.encodeArray(state.b, slice.Data, *elemOp, t.Elem().Size(), indir, int(slice.Len))
}
case *reflect.ArrayType:
// True arrays have size in the type.
elemOp, indir := enc.encOpFor(t.Elem())
elemOp, indir := enc.encOpFor(t.Elem(), inProgress)
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.update(i)
state.enc.encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len())
state.enc.encodeArray(state.b, uintptr(p), *elemOp, t.Elem().Size(), indir, t.Len())
}
case *reflect.MapType:
keyOp, keyIndir := enc.encOpFor(t.Key())
elemOp, elemIndir := enc.encOpFor(t.Elem())
keyOp, keyIndir := enc.encOpFor(t.Key(), inProgress)
elemOp, elemIndir := enc.encOpFor(t.Elem(), inProgress)
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
// Maps cannot be accessed by moving addresses around the way
// that slices etc. can. We must recover a full reflection value for
@ -483,7 +489,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
return
}
state.update(i)
state.enc.encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir)
state.enc.encodeMap(state.b, mv, *keyOp, *elemOp, keyIndir, elemIndir)
}
case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type.
@ -511,21 +517,22 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int) {
if op == nil {
errorf("gob enc: can't happen: encode type %s", rt.String())
}
return op, indir
return &op, indir
}
// The local Type was compiled from the actual value, so we know it's compatible.
func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
srt, isStruct := rt.(*reflect.StructType)
engine := new(encEngine)
seen := make(map[reflect.Type]*encOp)
if isStruct {
for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ {
f := srt.Field(fieldNum)
if !isExported(f.Name) {
continue
}
op, indir := enc.encOpFor(f.Type)
engine.instr = append(engine.instr, encInstr{op, fieldNum, indir, uintptr(f.Offset)})
op, indir := enc.encOpFor(f.Type, seen)
engine.instr = append(engine.instr, encInstr{*op, fieldNum, indir, uintptr(f.Offset)})
}
if srt.NumField() > 0 && len(engine.instr) == 0 {
errorf("type %s has no exported fields", rt)
@ -533,8 +540,8 @@ func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, 0, 0})
} else {
engine.instr = make([]encInstr, 1)
op, indir := enc.encOpFor(rt)
engine.instr[0] = encInstr{op, singletonField, indir, 0} // offset is zero
op, indir := enc.encOpFor(rt, seen)
engine.instr[0] = encInstr{*op, singletonField, indir, 0} // offset is zero
}
return engine
}