gob: make recursive map and slice types work.

Before this fix, types such as
        type T map[string]T
caused infinite recursion in the gob implementation.
Now they just work.

Fixes #1518.

R=rsc
CC=golang-dev
https://golang.org/cl/4230045
This commit is contained in:
Rob Pike 2011-02-25 09:45:06 -08:00
parent 895631770a
commit c54b5d032f
5 changed files with 156 additions and 76 deletions

View file

@ -671,7 +671,7 @@ func (dec *Decoder) ignoreInterface(state *decodeState) {
}
// Index by Go types.
var decOpMap = []decOp{
var decOpTable = [...]decOp{
reflect.Bool: decBool,
reflect.Int8: decInt8,
reflect.Int16: decInt16,
@ -701,37 +701,43 @@ var decIgnoreOpMap = map[typeId]decOp{
// Return the decoding op for the base type under rt and
// the indirection count to reach it.
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) {
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) {
ut := userType(rt)
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil {
return opPtr, ut.indir
}
typ := ut.base
indir := ut.indir
var op decOp
k := typ.Kind()
if int(k) < len(decOpMap) {
op = decOpMap[k]
if int(k) < len(decOpTable) {
op = decOpTable[k]
}
if op == nil {
inProgress[rt] = &op
// Special cases
switch t := typ.(type) {
case *reflect.ArrayType:
name = "element of " + name
elemId := dec.wireType[wireId].ArrayT.Elem
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.dec.decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
}
case *reflect.MapType:
name = "element of " + name
keyId := dec.wireType[wireId].MapT.Key
elemId := dec.wireType[wireId].MapT.Elem
keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name)
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress)
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
up := unsafe.Pointer(p)
state.dec.decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl)
state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl)
}
case *reflect.SliceType:
@ -746,10 +752,10 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
} else {
elemId = dec.wireType[wireId].SliceT.Elem
}
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.dec.decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
}
case *reflect.StructType:
@ -774,7 +780,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
if op == nil {
errorf("gob: decode can't handle type %s", rt.String())
}
return op, indir
return &op, indir
}
// Return the decoding op for a field that has no destination.
@ -838,11 +844,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
// Are these two gob Types compatible?
// Answers the question for basic types, arrays, and slices.
// Structs are considered ok; fields will be checked later.
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool {
if rhs, ok := inProgress[fr]; ok {
return rhs == fw
}
inProgress[fr] = fw
fr = userType(fr).base
switch t := fr.(type) {
default:
// map, chan, etc: cannot handle.
// chan, etc: cannot handle.
return false
case *reflect.BoolType:
return fw == tBool
@ -864,14 +874,14 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
return false
}
array := wire.ArrayT
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem)
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress)
case *reflect.MapType:
wire, ok := dec.wireType[fw]
if !ok || wire.MapT == nil {
return false
}
MapType := wire.MapT
return dec.compatibleType(t.Key(), MapType.Key) && dec.compatibleType(t.Elem(), MapType.Elem)
return dec.compatibleType(t.Key(), MapType.Key, inProgress) && dec.compatibleType(t.Elem(), MapType.Elem, inProgress)
case *reflect.SliceType:
// Is it an array of bytes?
if t.Elem().Kind() == reflect.Uint8 {
@ -885,7 +895,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
sw = dec.wireType[fw].SliceT
}
elem := userType(t.Elem()).base
return sw != nil && dec.compatibleType(elem, sw.Elem)
return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress)
case *reflect.StructType:
return true
}
@ -906,12 +916,12 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec
engine = new(decEngine)
engine.instr = make([]decInstr, 1) // one item
name := rt.String() // best we can do
if !dec.compatibleType(rt, remoteId) {
if !dec.compatibleType(rt, remoteId, make(map[reflect.Type]typeId)) {
return nil, os.ErrorString("gob: wrong type received for local value " + name + ": " + dec.typeString(remoteId))
}
op, indir := dec.decOpFor(remoteId, rt, name)
op, indir := dec.decOpFor(remoteId, rt, name, make(map[reflect.Type]*decOp))
ovfl := os.ErrorString(`value for "` + name + `" out of range`)
engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl}
engine.instr[singletonField] = decInstr{*op, singletonField, indir, 0, ovfl}
engine.numInstr = 1
return
}
@ -954,6 +964,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
}
engine = new(decEngine)
engine.instr = make([]decInstr, len(wireStruct.Field))
seen := make(map[reflect.Type]*decOp)
// Loop over the fields of the wire type.
for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ {
wireField := wireStruct.Field[fieldnum]
@ -969,11 +980,11 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl}
continue
}
if !dec.compatibleType(localField.Type, wireField.Id) {
if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) {
errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name)
}
op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name)
engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset), ovfl}
op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen)
engine.instr[fieldnum] = decInstr{*op, fieldnum, indir, uintptr(localField.Offset), ovfl}
engine.numInstr++
}
return
@ -1070,8 +1081,8 @@ func init() {
default:
panic("gob: unknown size of int/uint")
}
decOpMap[reflect.Int] = iop
decOpMap[reflect.Uint] = uop
decOpTable[reflect.Int] = iop
decOpTable[reflect.Uint] = uop
// Finally uintptr
switch reflect.Typeof(uintptr(0)).Bits() {
@ -1082,5 +1093,5 @@ func init() {
default:
panic("gob: unknown size of uintptr")
}
decOpMap[reflect.Uintptr] = uop
decOpTable[reflect.Uintptr] = uop
}