mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
gob: make recursive map and slice types work.
Before this fix, types such as
type T map[string]T
caused infinite recursion in the gob implementation.
Now they just work.
Fixes #1518.
R=rsc
CC=golang-dev
https://golang.org/cl/4230045
This commit is contained in:
parent
895631770a
commit
c54b5d032f
5 changed files with 156 additions and 76 deletions
|
|
@ -671,7 +671,7 @@ func (dec *Decoder) ignoreInterface(state *decodeState) {
|
|||
}
|
||||
|
||||
// Index by Go types.
|
||||
var decOpMap = []decOp{
|
||||
var decOpTable = [...]decOp{
|
||||
reflect.Bool: decBool,
|
||||
reflect.Int8: decInt8,
|
||||
reflect.Int16: decInt16,
|
||||
|
|
@ -701,37 +701,43 @@ var decIgnoreOpMap = map[typeId]decOp{
|
|||
|
||||
// Return the decoding op for the base type under rt and
|
||||
// the indirection count to reach it.
|
||||
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int) {
|
||||
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) {
|
||||
ut := userType(rt)
|
||||
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
|
||||
// Return the pointer to the op we're already building.
|
||||
if opPtr := inProgress[rt]; opPtr != nil {
|
||||
return opPtr, ut.indir
|
||||
}
|
||||
typ := ut.base
|
||||
indir := ut.indir
|
||||
var op decOp
|
||||
k := typ.Kind()
|
||||
if int(k) < len(decOpMap) {
|
||||
op = decOpMap[k]
|
||||
if int(k) < len(decOpTable) {
|
||||
op = decOpTable[k]
|
||||
}
|
||||
if op == nil {
|
||||
inProgress[rt] = &op
|
||||
// Special cases
|
||||
switch t := typ.(type) {
|
||||
case *reflect.ArrayType:
|
||||
name = "element of " + name
|
||||
elemId := dec.wireType[wireId].ArrayT.Elem
|
||||
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
|
||||
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
|
||||
ovfl := overflow(name)
|
||||
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
|
||||
state.dec.decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
|
||||
state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
|
||||
}
|
||||
|
||||
case *reflect.MapType:
|
||||
name = "element of " + name
|
||||
keyId := dec.wireType[wireId].MapT.Key
|
||||
elemId := dec.wireType[wireId].MapT.Elem
|
||||
keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name)
|
||||
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
|
||||
keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress)
|
||||
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
|
||||
ovfl := overflow(name)
|
||||
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
|
||||
up := unsafe.Pointer(p)
|
||||
state.dec.decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl)
|
||||
state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl)
|
||||
}
|
||||
|
||||
case *reflect.SliceType:
|
||||
|
|
@ -746,10 +752,10 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
|
|||
} else {
|
||||
elemId = dec.wireType[wireId].SliceT.Elem
|
||||
}
|
||||
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name)
|
||||
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
|
||||
ovfl := overflow(name)
|
||||
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
|
||||
state.dec.decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
|
||||
state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
|
||||
}
|
||||
|
||||
case *reflect.StructType:
|
||||
|
|
@ -774,7 +780,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp
|
|||
if op == nil {
|
||||
errorf("gob: decode can't handle type %s", rt.String())
|
||||
}
|
||||
return op, indir
|
||||
return &op, indir
|
||||
}
|
||||
|
||||
// Return the decoding op for a field that has no destination.
|
||||
|
|
@ -838,11 +844,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
|
|||
// Are these two gob Types compatible?
|
||||
// Answers the question for basic types, arrays, and slices.
|
||||
// Structs are considered ok; fields will be checked later.
|
||||
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
|
||||
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool {
|
||||
if rhs, ok := inProgress[fr]; ok {
|
||||
return rhs == fw
|
||||
}
|
||||
inProgress[fr] = fw
|
||||
fr = userType(fr).base
|
||||
switch t := fr.(type) {
|
||||
default:
|
||||
// map, chan, etc: cannot handle.
|
||||
// chan, etc: cannot handle.
|
||||
return false
|
||||
case *reflect.BoolType:
|
||||
return fw == tBool
|
||||
|
|
@ -864,14 +874,14 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
|
|||
return false
|
||||
}
|
||||
array := wire.ArrayT
|
||||
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem)
|
||||
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress)
|
||||
case *reflect.MapType:
|
||||
wire, ok := dec.wireType[fw]
|
||||
if !ok || wire.MapT == nil {
|
||||
return false
|
||||
}
|
||||
MapType := wire.MapT
|
||||
return dec.compatibleType(t.Key(), MapType.Key) && dec.compatibleType(t.Elem(), MapType.Elem)
|
||||
return dec.compatibleType(t.Key(), MapType.Key, inProgress) && dec.compatibleType(t.Elem(), MapType.Elem, inProgress)
|
||||
case *reflect.SliceType:
|
||||
// Is it an array of bytes?
|
||||
if t.Elem().Kind() == reflect.Uint8 {
|
||||
|
|
@ -885,7 +895,7 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
|
|||
sw = dec.wireType[fw].SliceT
|
||||
}
|
||||
elem := userType(t.Elem()).base
|
||||
return sw != nil && dec.compatibleType(elem, sw.Elem)
|
||||
return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress)
|
||||
case *reflect.StructType:
|
||||
return true
|
||||
}
|
||||
|
|
@ -906,12 +916,12 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec
|
|||
engine = new(decEngine)
|
||||
engine.instr = make([]decInstr, 1) // one item
|
||||
name := rt.String() // best we can do
|
||||
if !dec.compatibleType(rt, remoteId) {
|
||||
if !dec.compatibleType(rt, remoteId, make(map[reflect.Type]typeId)) {
|
||||
return nil, os.ErrorString("gob: wrong type received for local value " + name + ": " + dec.typeString(remoteId))
|
||||
}
|
||||
op, indir := dec.decOpFor(remoteId, rt, name)
|
||||
op, indir := dec.decOpFor(remoteId, rt, name, make(map[reflect.Type]*decOp))
|
||||
ovfl := os.ErrorString(`value for "` + name + `" out of range`)
|
||||
engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl}
|
||||
engine.instr[singletonField] = decInstr{*op, singletonField, indir, 0, ovfl}
|
||||
engine.numInstr = 1
|
||||
return
|
||||
}
|
||||
|
|
@ -954,6 +964,7 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
|
|||
}
|
||||
engine = new(decEngine)
|
||||
engine.instr = make([]decInstr, len(wireStruct.Field))
|
||||
seen := make(map[reflect.Type]*decOp)
|
||||
// Loop over the fields of the wire type.
|
||||
for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ {
|
||||
wireField := wireStruct.Field[fieldnum]
|
||||
|
|
@ -969,11 +980,11 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
|
|||
engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl}
|
||||
continue
|
||||
}
|
||||
if !dec.compatibleType(localField.Type, wireField.Id) {
|
||||
if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) {
|
||||
errorf("gob: wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name)
|
||||
}
|
||||
op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name)
|
||||
engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset), ovfl}
|
||||
op, indir := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen)
|
||||
engine.instr[fieldnum] = decInstr{*op, fieldnum, indir, uintptr(localField.Offset), ovfl}
|
||||
engine.numInstr++
|
||||
}
|
||||
return
|
||||
|
|
@ -1070,8 +1081,8 @@ func init() {
|
|||
default:
|
||||
panic("gob: unknown size of int/uint")
|
||||
}
|
||||
decOpMap[reflect.Int] = iop
|
||||
decOpMap[reflect.Uint] = uop
|
||||
decOpTable[reflect.Int] = iop
|
||||
decOpTable[reflect.Uint] = uop
|
||||
|
||||
// Finally uintptr
|
||||
switch reflect.Typeof(uintptr(0)).Bits() {
|
||||
|
|
@ -1082,5 +1093,5 @@ func init() {
|
|||
default:
|
||||
panic("gob: unknown size of uintptr")
|
||||
}
|
||||
decOpMap[reflect.Uintptr] = uop
|
||||
decOpTable[reflect.Uintptr] = uop
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue