mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
encoding/xml: make sure Encoder.Encode reports Write errors.
Fixes #4112. R=remyoudompheng, daniel.morsing, dave, rsc CC=golang-dev https://golang.org/cl/7085053
This commit is contained in:
parent
8eb80914ca
commit
afde71cfbd
4 changed files with 61 additions and 14 deletions
|
|
@ -193,7 +193,9 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
|
||||||
if xmlns != "" {
|
if xmlns != "" {
|
||||||
p.WriteString(` xmlns="`)
|
p.WriteString(` xmlns="`)
|
||||||
// TODO: EscapeString, to avoid the allocation.
|
// TODO: EscapeString, to avoid the allocation.
|
||||||
Escape(p, []byte(xmlns))
|
if err := EscapeText(p, []byte(xmlns)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
p.WriteByte('"')
|
p.WriteByte('"')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -252,19 +254,22 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error {
|
||||||
p.WriteString(strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits()))
|
p.WriteString(strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits()))
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
// TODO: Add EscapeString.
|
// TODO: Add EscapeString.
|
||||||
Escape(p, []byte(val.String()))
|
EscapeText(p, []byte(val.String()))
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
p.WriteString(strconv.FormatBool(val.Bool()))
|
p.WriteString(strconv.FormatBool(val.Bool()))
|
||||||
case reflect.Array:
|
case reflect.Array:
|
||||||
// will be [...]byte
|
// will be [...]byte
|
||||||
bytes := make([]byte, val.Len())
|
var bytes []byte
|
||||||
for i := range bytes {
|
if val.CanAddr() {
|
||||||
bytes[i] = val.Index(i).Interface().(byte)
|
bytes = val.Slice(0, val.Len()).Bytes()
|
||||||
|
} else {
|
||||||
|
bytes = make([]byte, val.Len())
|
||||||
|
reflect.Copy(reflect.ValueOf(bytes), val)
|
||||||
}
|
}
|
||||||
Escape(p, bytes)
|
EscapeText(p, bytes)
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
// will be []byte
|
// will be []byte
|
||||||
Escape(p, val.Bytes())
|
EscapeText(p, val.Bytes())
|
||||||
default:
|
default:
|
||||||
return &UnsupportedTypeError{typ}
|
return &UnsupportedTypeError{typ}
|
||||||
}
|
}
|
||||||
|
|
@ -298,10 +303,14 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
Escape(p, strconv.AppendBool(scratch[:0], vf.Bool()))
|
Escape(p, strconv.AppendBool(scratch[:0], vf.Bool()))
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
Escape(p, []byte(vf.String()))
|
if err := EscapeText(p, []byte(vf.String())); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
if elem, ok := vf.Interface().([]byte); ok {
|
if elem, ok := vf.Interface().([]byte); ok {
|
||||||
Escape(p, elem)
|
if err := EscapeText(p, elem); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if vf.Type() == timeType {
|
if vf.Type() == timeType {
|
||||||
|
|
|
||||||
|
|
@ -965,6 +965,16 @@ func TestMarshalWriteErrors(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMarshalWriteIOErrors(t *testing.T) {
|
||||||
|
enc := NewEncoder(errWriter{})
|
||||||
|
|
||||||
|
expectErr := "unwritable"
|
||||||
|
err := enc.Encode(&Passenger{})
|
||||||
|
if err == nil || err.Error() != expectErr {
|
||||||
|
t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkMarshal(b *testing.B) {
|
func BenchmarkMarshal(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
Marshal(atomValue)
|
Marshal(atomValue)
|
||||||
|
|
|
||||||
|
|
@ -1720,9 +1720,9 @@ var (
|
||||||
esc_cr = []byte("
")
|
esc_cr = []byte("
")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Escape writes to w the properly escaped XML equivalent
|
// EscapeText writes to w the properly escaped XML equivalent
|
||||||
// of the plain text data s.
|
// of the plain text data s.
|
||||||
func Escape(w io.Writer, s []byte) {
|
func EscapeText(w io.Writer, s []byte) error {
|
||||||
var esc []byte
|
var esc []byte
|
||||||
last := 0
|
last := 0
|
||||||
for i, c := range s {
|
for i, c := range s {
|
||||||
|
|
@ -1746,11 +1746,25 @@ func Escape(w io.Writer, s []byte) {
|
||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
w.Write(s[last:i])
|
if _, err := w.Write(s[last:i]); err != nil {
|
||||||
w.Write(esc)
|
return err
|
||||||
|
}
|
||||||
|
if _, err := w.Write(esc); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
last = i + 1
|
last = i + 1
|
||||||
}
|
}
|
||||||
w.Write(s[last:])
|
if _, err := w.Write(s[last:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Escape is like EscapeText but omits the error return value.
|
||||||
|
// It is provided for backwards compatibility with Go 1.0.
|
||||||
|
// Code targeting Go 1.1 or later should use EscapeText.
|
||||||
|
func Escape(w io.Writer, s []byte) {
|
||||||
|
EscapeText(w, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// procInstEncoding parses the `encoding="..."` or `encoding='...'`
|
// procInstEncoding parses the `encoding="..."` or `encoding='...'`
|
||||||
|
|
|
||||||
|
|
@ -689,3 +689,17 @@ func TestDirectivesWithComments(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Writer whose Write method always returns an error.
|
||||||
|
type errWriter struct{}
|
||||||
|
|
||||||
|
func (errWriter) Write(p []byte) (n int, err error) { return 0, fmt.Errorf("unwritable") }
|
||||||
|
|
||||||
|
func TestEscapeTextIOErrors(t *testing.T) {
|
||||||
|
expectErr := "unwritable"
|
||||||
|
err := EscapeText(errWriter{}, []byte{'A'})
|
||||||
|
|
||||||
|
if err == nil || err.Error() != expectErr {
|
||||||
|
t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue