- clean up code creating keys for type maps

- derive int, uint, float, uintptr decoders based on their size
- add overflow checks in decode

R=rsc
DELTA=407  (281 added, 44 deleted, 82 changed)
OCL=32286
CL=32290
This commit is contained in:
Rob Pike 2009-07-28 12:59:39 -07:00
parent 08b5b4843b
commit 483e4fc409
5 changed files with 362 additions and 125 deletions

View file

@ -29,6 +29,10 @@ type decodeState struct {
fieldnum int; // the last field number read.
}
func overflow(name string) os.ErrorString {
return os.ErrorString(`value for "` + name + `" out of range`);
}
// decodeUintReader reads an encoded unsigned integer from an io.Reader.
// Used only by the Decoder to read the message length.
func decodeUintReader(r io.Reader, oneByte []byte) (x uint64, err os.Error) {
@ -50,6 +54,7 @@ func decodeUintReader(r io.Reader, oneByte []byte) (x uint64, err os.Error) {
// decodeUint reads an encoded unsigned integer from state.r.
// Sets state.err. If state.err is already non-nil, it does nothing.
// Does not check for overflow.
func decodeUint(state *decodeState) (x uint64) {
if state.err != nil {
return
@ -71,6 +76,7 @@ func decodeUint(state *decodeState) (x uint64) {
// decodeInt reads an encoded signed integer from state.r.
// Sets state.err. If state.err is already non-nil, it does nothing.
// Does not check for overflow.
func decodeInt(state *decodeState) int64 {
x := decodeUint(state);
if state.err != nil {
@ -91,6 +97,7 @@ type decInstr struct {
field int; // field number of the wire type
indir int; // how many pointer indirections to reach the value in the struct
offset uintptr; // offset in the structure of the field to encode
ovfl os.ErrorString; // error message for overflow/underflow (for arrays, of the elements)
}
// Since the encoder writes no zeros, if we arrive at a decoder we have
@ -126,26 +133,6 @@ func decBool(i *decInstr, state *decodeState, p unsafe.Pointer) {
*(*bool)(p) = decodeInt(state) != 0;
}
func decInt(i *decInstr, state *decodeState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(int));
}
p = *(*unsafe.Pointer)(p);
}
*(*int)(p) = int(decodeInt(state));
}
func decUint(i *decInstr, state *decodeState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint));
}
p = *(*unsafe.Pointer)(p);
}
*(*uint)(p) = uint(decodeUint(state));
}
func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
@ -153,7 +140,12 @@ func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
p = *(*unsafe.Pointer)(p);
}
*(*int8)(p) = int8(decodeInt(state));
v := decodeInt(state);
if v < math.MinInt8 || math.MaxInt8 < v {
state.err = i.ovfl
} else {
*(*int8)(p) = int8(v)
}
}
func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) {
@ -163,7 +155,12 @@ func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
p = *(*unsafe.Pointer)(p);
}
*(*uint8)(p) = uint8(decodeUint(state));
v := decodeUint(state);
if math.MaxUint8 < v {
state.err = i.ovfl
} else {
*(*uint8)(p) = uint8(v)
}
}
func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) {
@ -173,7 +170,12 @@ func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
p = *(*unsafe.Pointer)(p);
}
*(*int16)(p) = int16(decodeInt(state));
v := decodeInt(state);
if v < math.MinInt16 || math.MaxInt16 < v {
state.err = i.ovfl
} else {
*(*int16)(p) = int16(v)
}
}
func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) {
@ -183,7 +185,12 @@ func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
p = *(*unsafe.Pointer)(p);
}
*(*uint16)(p) = uint16(decodeUint(state));
v := decodeUint(state);
if math.MaxUint16 < v {
state.err = i.ovfl
} else {
*(*uint16)(p) = uint16(v)
}
}
func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) {
@ -193,7 +200,12 @@ func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
p = *(*unsafe.Pointer)(p);
}
*(*int32)(p) = int32(decodeInt(state));
v := decodeInt(state);
if v < math.MinInt32 || math.MaxInt32 < v {
state.err = i.ovfl
} else {
*(*int32)(p) = int32(v)
}
}
func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) {
@ -203,7 +215,12 @@ func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
p = *(*unsafe.Pointer)(p);
}
*(*uint32)(p) = uint32(decodeUint(state));
v := decodeUint(state);
if math.MaxUint32 < v {
state.err = i.ovfl
} else {
*(*uint32)(p) = uint32(v)
}
}
func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) {
@ -226,16 +243,6 @@ func decUint64(i *decInstr, state *decodeState, p unsafe.Pointer) {
*(*uint64)(p) = uint64(decodeUint(state));
}
func decUintptr(i *decInstr, state *decodeState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(uintptr));
}
p = *(*unsafe.Pointer)(p);
}
*(*uintptr)(p) = uintptr(decodeUint(state));
}
// Floating-point numbers are transmitted as uint64s holding the bits
// of the underlying representation. They are sent byte-reversed, with
// the exponent end coming out first, so integer floating point numbers
@ -251,16 +258,6 @@ func floatFromBits(u uint64) float64 {
return math.Float64frombits(v);
}
func decFloat(i *decInstr, state *decodeState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(float));
}
p = *(*unsafe.Pointer)(p);
}
*(*float)(p) = float(floatFromBits(uint64(decodeUint(state))));
}
func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
@ -268,7 +265,16 @@ func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
p = *(*unsafe.Pointer)(p);
}
*(*float32)(p) = float32(floatFromBits(uint64(decodeUint(state))));
v := floatFromBits(decodeUint(state));
av := v;
if av < 0 {
av = -av
}
if math.MaxFloat32 < av { // underflow is OK
state.err = i.ovfl
} else {
*(*float32)(p) = float32(v)
}
}
func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) {
@ -386,8 +392,8 @@ func ignoreStruct(engine *decEngine, b *bytes.Buffer) os.Error {
return state.err
}
func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int) os.Error {
instr := &decInstr{elemOp, 0, elemIndir, 0};
func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) os.Error {
instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl};
for i := 0; i < length && state.err == nil; i++ {
up := unsafe.Pointer(p);
if elemIndir > 1 {
@ -399,7 +405,7 @@ func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uint
return state.err
}
func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int) os.Error {
func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) os.Error {
if indir > 0 {
up := unsafe.Pointer(p);
if *(*unsafe.Pointer)(up) == nil {
@ -413,11 +419,11 @@ func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp
if n := decodeUint(state); n != uint64(length) {
return os.ErrorString("gob: length mismatch in decodeArray");
}
return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir);
return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl);
}
func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) os.Error {
instr := &decInstr{elemOp, 0, 0, 0};
instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")};
for i := 0; i < length && state.err == nil; i++ {
elemOp(instr, state, nil);
}
@ -431,7 +437,7 @@ func ignoreArray(state *decodeState, elemOp decOp, length int) os.Error {
return ignoreArrayHelper(state, elemOp, length);
}
func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int) os.Error {
func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) os.Error {
length := uintptr(decodeUint(state));
if indir > 0 {
up := unsafe.Pointer(p);
@ -448,7 +454,7 @@ func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp
hdrp.Data = uintptr(unsafe.Pointer(&data[0]));
hdrp.Len = int(length);
hdrp.Cap = int(length);
return decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, int(length), elemIndir);
return decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, int(length), elemIndir, ovfl);
}
func ignoreSlice(state *decodeState, elemOp decOp) os.Error {
@ -456,22 +462,18 @@ func ignoreSlice(state *decodeState, elemOp decOp) os.Error {
}
var decOpMap = map[reflect.Type] decOp {
reflect.Typeof((*reflect.BoolType)(nil)): decBool,
reflect.Typeof((*reflect.IntType)(nil)): decInt,
reflect.Typeof((*reflect.Int8Type)(nil)): decInt8,
reflect.Typeof((*reflect.Int16Type)(nil)): decInt16,
reflect.Typeof((*reflect.Int32Type)(nil)): decInt32,
reflect.Typeof((*reflect.Int64Type)(nil)): decInt64,
reflect.Typeof((*reflect.UintType)(nil)): decUint,
reflect.Typeof((*reflect.Uint8Type)(nil)): decUint8,
reflect.Typeof((*reflect.Uint16Type)(nil)): decUint16,
reflect.Typeof((*reflect.Uint32Type)(nil)): decUint32,
reflect.Typeof((*reflect.Uint64Type)(nil)): decUint64,
reflect.Typeof((*reflect.UintptrType)(nil)): decUintptr,
reflect.Typeof((*reflect.FloatType)(nil)): decFloat,
reflect.Typeof((*reflect.Float32Type)(nil)): decFloat32,
reflect.Typeof((*reflect.Float64Type)(nil)): decFloat64,
reflect.Typeof((*reflect.StringType)(nil)): decString,
valueKind(false): decBool,
valueKind(int8(0)): decInt8,
valueKind(int16(0)): decInt16,
valueKind(int32(0)): decInt32,
valueKind(int64(0)): decInt64,
valueKind(uint8(0)): decUint8,
valueKind(uint16(0)): decUint16,
valueKind(uint32(0)): decUint32,
valueKind(uint64(0)): decUint64,
valueKind(float32(0)): decFloat32,
valueKind(float64(0)): decFloat64,
valueKind("x"): decString,
}
var decIgnoreOpMap = map[typeId] decOp {
@ -488,34 +490,38 @@ func getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error)
// Return the decoding op for the base type under rt and
// the indirection count to reach it.
func decOpFor(wireId typeId, rt reflect.Type) (decOp, int, os.Error) {
func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) {
typ, indir := indirect(rt);
op, ok := decOpMap[reflect.Typeof(typ)];
if !ok {
// Special cases
switch t := typ.(type) {
case *reflect.SliceType:
name = "element of " + name;
if _, ok := t.Elem().(*reflect.Uint8Type); ok {
op = decUint8Array;
break;
}
elemId := wireId.gobType().(*sliceType).Elem;
elemOp, elemIndir, err := decOpFor(elemId, t.Elem());
elemOp, elemIndir, err := decOpFor(elemId, t.Elem(), name);
if err != nil {
return nil, 0, err
}
ovfl := overflow(name);
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir);
state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl);
};
case *reflect.ArrayType:
name = "element of " + name;
elemId := wireId.gobType().(*arrayType).Elem;
elemOp, elemIndir, err := decOpFor(elemId, t.Elem());
elemOp, elemIndir, err := decOpFor(elemId, t.Elem(), name);
if err != nil {
return nil, 0, err
}
ovfl := overflow(name);
op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir);
state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl);
};
case *reflect.StructType:
@ -658,23 +664,25 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
wireField := wireStruct.field[fieldnum];
// Find the field of the local type with the same name.
localField, present := srt.FieldByName(wireField.name);
ovfl := overflow(wireField.name);
// TODO(r): anonymous names
if !present || localField.Anonymous {
op, err := decIgnoreOpFor(wireField.id);
if err != nil {
return nil, err
}
engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0};
engine.instr[fieldnum] = decInstr{op, fieldnum, 0, 0, ovfl};
continue;
}
if !compatibleType(localField.Type, wireField.id) {
return nil, os.ErrorString("gob: wrong type for field " + wireField.name + " in type " + wireId.Name());
details := " (" + wireField.id.String() + " incompatible with " + localField.Type.String() + ") in type " + wireId.Name();
return nil, os.ErrorString("gob: wrong type for field " + wireField.name + details);
}
op, indir, err := decOpFor(wireField.id, localField.Type);
op, indir, err := decOpFor(wireField.id, localField.Type, localField.Name);
if err != nil {
return nil, err
}
engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset)};
engine.instr[fieldnum] = decInstr{op, fieldnum, indir, uintptr(localField.Offset), ovfl};
engine.numInstr++;
}
return;
@ -746,3 +754,44 @@ func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error {
}
return decodeStruct(engine, rt.(*reflect.StructType), b, uintptr(v.Addr()), 0);
}
func init() {
// We assume that the size of float is sufficient to tell us whether it is
// equivalent to float32 or to float64. This is very unlikely to be wrong.
var op decOp;
switch unsafe.Sizeof(float(0)) {
case unsafe.Sizeof(float32(0)):
op = decFloat32;
case unsafe.Sizeof(float64(0)):
op = decFloat64;
default:
panic("gob: unknown size of float", unsafe.Sizeof(float(0)));
}
decOpMap[valueKind(float(0))] = op;
// A similar assumption about int and uint. Also assume int and uint have the same size.
var uop decOp;
switch unsafe.Sizeof(int(0)) {
case unsafe.Sizeof(int32(0)):
op = decInt32;
uop = decUint32;
case unsafe.Sizeof(int64(0)):
op = decInt64;
uop = decUint64;
default:
panic("gob: unknown size of int/uint", unsafe.Sizeof(int(0)));
}
decOpMap[valueKind(int(0))] = op;
decOpMap[valueKind(uint(0))] = uop;
// Finally uintptr
switch unsafe.Sizeof(uintptr(0)) {
case unsafe.Sizeof(uint32(0)):
uop = decUint32;
case unsafe.Sizeof(uint64(0)):
uop = decUint64;
default:
panic("gob: unknown size of uintptr", unsafe.Sizeof(uintptr(0)));
}
decOpMap[valueKind(uintptr(0))] = uop;
}