// go-yaml/decode.go — 831 lines, 22 KiB (Go).
package yaml
import (
"encoding/base64"
"fmt"
2019-10-18 00:00:58 +09:00
"io"
"io/ioutil"
2019-10-29 19:46:55 +09:00
"math"
"os"
"path/filepath"
2019-10-18 00:00:58 +09:00
"reflect"
2019-10-31 17:38:53 +09:00
"strconv"
2019-11-01 14:09:33 +09:00
"time"
2019-10-18 00:00:58 +09:00
"github.com/goccy/go-yaml/ast"
2019-10-24 11:07:40 +09:00
"github.com/goccy/go-yaml/internal/errors"
2019-10-18 00:00:58 +09:00
"github.com/goccy/go-yaml/parser"
"github.com/goccy/go-yaml/token"
"golang.org/x/xerrors"
2019-10-18 00:00:58 +09:00
)
2019-10-21 12:53:30 +09:00
// Decoder reads and decodes YAML values from an input stream.
2019-10-18 00:00:58 +09:00
type Decoder struct {
2019-12-10 18:41:23 +09:00
reader io.Reader
referenceReaders []io.Reader
anchorMap map[string]ast.Node
opts []DecodeOption
referenceFiles []string
referenceDirs []string
isRecursiveDir bool
isResolvedReference bool
validator StructValidator
disallowUnknownField bool
2019-10-18 00:00:58 +09:00
}
2019-10-21 12:53:30 +09:00
// NewDecoder returns a new decoder that reads from r.
func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
2019-10-18 14:27:49 +09:00
return &Decoder{
2019-12-10 18:41:23 +09:00
reader: r,
anchorMap: map[string]ast.Node{},
opts: opts,
referenceReaders: []io.Reader{},
referenceFiles: []string{},
referenceDirs: []string{},
isRecursiveDir: false,
isResolvedReference: false,
disallowUnknownField: false,
2019-10-18 14:27:49 +09:00
}
2019-10-18 00:00:58 +09:00
}
// castToFloat converts v to a float64 for the float tag.
//
// Every Go integer and float type is widened to float64. A string is parsed
// with strconv.ParseFloat, and its error is ignored: a malformed string yields
// 0. Any other type, including nil, yields float64(0). The returned
// interface{} therefore always holds a float64.
func (d *Decoder) castToFloat(v interface{}) interface{} {
	switch vv := v.(type) {
	case int:
		return float64(vv)
	case int8:
		return float64(vv)
	case int16:
		return float64(vv)
	case int32:
		return float64(vv)
	case int64:
		return float64(vv)
	case uint:
		return float64(vv)
	case uint8:
		return float64(vv)
	case uint16:
		return float64(vv)
	case uint32:
		return float64(vv)
	case uint64:
		return float64(vv)
	case float32:
		return float64(vv)
	case float64:
		return vv
	case string:
		// if error occurred, return zero value
		f, _ := strconv.ParseFloat(vv, 64)
		return f
	}
	// An untyped 0 would be boxed as int; keep the result type uniform.
	return float64(0)
}
// nodeToValue converts an AST node into a plain Go value.
//
// Scalars are returned as produced by the node's GetValue. Mappings become
// map[string]interface{} and sequences become []interface{}. A NullNode, a
// nil node, or an alias whose anchor is unknown yields nil.
//
// Converting an AnchorNode records the anchored node in d.anchorMap, so that
// later AliasNodes resolve to it.
//
// Tags:
//   - token.TimestampTag, token.FloatTag, token.NullTag and token.BinaryTag
//     are interpreted.
//   - Any other tag converts the tagged value as if it were untagged.
func (d *Decoder) nodeToValue(node ast.Node) interface{} {
	switch n := node.(type) {
	case *ast.NullNode:
		return nil
	case *ast.StringNode:
		return n.GetValue()
	case *ast.IntegerNode:
		return n.GetValue()
	case *ast.FloatNode:
		return n.GetValue()
	case *ast.BoolNode:
		return n.GetValue()
	case *ast.InfinityNode:
		return n.GetValue()
	case *ast.NanNode:
		return n.GetValue()
	case *ast.TagNode:
		switch n.Start.Value {
		case token.TimestampTag:
			// an unparseable timestamp becomes the zero time.Time
			t, _ := d.castToTime(n.Value)
			return t
		case token.FloatTag:
			return d.castToFloat(d.nodeToValue(n.Value))
		case token.NullTag:
			return nil
		case token.BinaryTag:
			s, ok := d.nodeToValue(n.Value).(string)
			if !ok {
				// only a string can carry base64 data
				return nil
			}
			// base64 decoding errors are deliberately ignored
			b, _ := base64.StdEncoding.DecodeString(s)
			return b
		default:
			// Other tags do not change how the value is converted.
			return d.nodeToValue(n.Value)
		}
	case *ast.AnchorNode:
		anchorName := n.Name.GetToken().Value
		anchorValue := d.nodeToValue(n.Value)
		d.anchorMap[anchorName] = n.Value
		return anchorValue
	case *ast.AliasNode:
		aliasName := n.Value.GetToken().Value
		return d.nodeToValue(d.anchorMap[aliasName])
	case *ast.LiteralNode:
		return n.Value.GetValue()
	case *ast.MappingValueNode:
		m := map[string]interface{}{}
		if n.Key.Type() != ast.MergeKeyType {
			key := n.Key.GetToken().Value
			m[key] = d.nodeToValue(n.Value)
			return m
		}
		switch merged := d.nodeToValue(n.Value).(type) {
		case map[string]interface{}:
			for k, v := range merged {
				m[k] = v
			}
		case []interface{}:
			// A sequence of mappings is merged in order. Per the YAML
			// merge-key spec, keys from earlier mappings take precedence
			// over keys from later ones.
			for _, elem := range merged {
				sub, ok := elem.(map[string]interface{})
				if !ok {
					continue
				}
				for k, v := range sub {
					if _, exists := m[k]; !exists {
						m[k] = v
					}
				}
			}
		}
		// Any other merge value contributes nothing, rather than panicking.
		return m
	case *ast.MappingNode:
		m := map[string]interface{}{}
		for _, value := range n.Values {
			subMap := d.nodeToValue(value).(map[string]interface{})
			for k, v := range subMap {
				m[k] = v
			}
		}
		return m
	case *ast.SequenceNode:
		v := []interface{}{}
		for _, value := range n.Values {
			v = append(v, d.nodeToValue(value))
		}
		return v
	}
	return nil
}
func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) {
2019-11-01 12:42:41 +09:00
if _, ok := node.(*ast.NullNode); ok {
return nil, nil
}
if anchor, ok := node.(*ast.AnchorNode); ok {
mapNode, ok := anchor.Value.(ast.MapNode)
if ok {
return mapNode, nil
}
return nil, xerrors.Errorf("%s node doesn't MapNode", anchor.Value.Type())
}
if alias, ok := node.(*ast.AliasNode); ok {
aliasName := alias.Value.GetToken().Value
anchorNode := d.anchorMap[aliasName]
mapNode, ok := anchorNode.(ast.MapNode)
if ok {
return mapNode, nil
}
return nil, xerrors.Errorf("%s node doesn't MapNode", anchorNode.Type())
}
mapNode, ok := node.(ast.MapNode)
if !ok {
return nil, xerrors.Errorf("%s node doesn't MapNode", node.Type())
}
return mapNode, nil
}
func (d *Decoder) getArrayNode(node ast.Node) (ast.ArrayNode, error) {
2019-11-01 12:42:41 +09:00
if _, ok := node.(*ast.NullNode); ok {
return nil, nil
}
if anchor, ok := node.(*ast.AnchorNode); ok {
arrayNode, ok := anchor.Value.(ast.ArrayNode)
if ok {
return arrayNode, nil
}
return nil, xerrors.Errorf("%s node doesn't ArrayNode", anchor.Value.Type())
}
if alias, ok := node.(*ast.AliasNode); ok {
aliasName := alias.Value.GetToken().Value
anchorNode := d.anchorMap[aliasName]
arrayNode, ok := anchorNode.(ast.ArrayNode)
if ok {
return arrayNode, nil
}
return nil, xerrors.Errorf("%s node doesn't ArrayNode", anchorNode.Type())
}
arrayNode, ok := node.(ast.ArrayNode)
if !ok {
return nil, xerrors.Errorf("%s node doesn't ArrayNode", node.Type())
}
return arrayNode, nil
}
2019-11-05 17:02:55 +09:00
func (d *Decoder) fileToNode(f *ast.File) ast.Node {
for _, doc := range f.Docs {
if v := d.nodeToValue(doc.Body); v != nil {
return doc.Body
2019-10-18 13:05:36 +09:00
}
2019-10-18 00:00:58 +09:00
}
return nil
}
func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type) reflect.Value {
if typ.Kind() != reflect.String {
return v.Convert(typ)
}
// cast value to string
switch v.Type().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return reflect.ValueOf(fmt.Sprint(v.Int()))
case reflect.Float32, reflect.Float64:
return reflect.ValueOf(fmt.Sprint(v.Float()))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return reflect.ValueOf(fmt.Sprint(v.Uint()))
case reflect.Bool:
return reflect.ValueOf(fmt.Sprint(v.Bool()))
}
return v.Convert(typ)
}
2019-10-29 19:46:55 +09:00
var (
errOverflowNumber = xerrors.New("overflow number")
2019-11-01 12:43:36 +09:00
errTypeMismatch = xerrors.New("type mismatch")
2019-10-29 19:46:55 +09:00
)
// decodeValue decodes src into dst. dst must be addressable, because dst.Addr()
// is taken to look for unmarshaler implementations.
//
// Dispatch order:
//  1. A BytesUnmarshaler or InterfaceUnmarshaler on *dst takes priority.
//  2. Otherwise dispatch is by dst's kind: pointers, interfaces, maps, arrays,
//     slices, structs (time.Time specially), then signed and unsigned
//     integers.
//  3. Every remaining kind goes through convertValue.
//
// Errors:
//   - A value whose type cannot be stored in dst returns errTypeMismatch.
//   - A number that does not fit returns errOverflowNumber.
//
// A src that converts to nil leaves dst unchanged for interface and scalar
// destinations.
func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error {
	valueType := dst.Type()
	if unmarshaler, ok := dst.Addr().Interface().(BytesUnmarshaler); ok {
		b := fmt.Sprintf("%v", src)
		if err := unmarshaler.UnmarshalYAML([]byte(b)); err != nil {
			return errors.Wrapf(err, "failed to UnmarshalYAML")
		}
		return nil
	}
	if unmarshaler, ok := dst.Addr().Interface().(InterfaceUnmarshaler); ok {
		if err := unmarshaler.UnmarshalYAML(func(v interface{}) error {
			rv := reflect.ValueOf(v)
			// nil and nil pointers have nowhere to decode into (rv.Elem() would be invalid)
			if !rv.IsValid() || rv.Type().Kind() != reflect.Ptr || rv.IsNil() {
				return errors.ErrDecodeRequiredPointerType
			}
			if err := d.decodeValue(rv.Elem(), src); err != nil {
				return errors.Wrapf(err, "failed to decode value")
			}
			return nil
		}); err != nil {
			return errors.Wrapf(err, "failed to UnmarshalYAML")
		}
		return nil
	}
	switch valueType.Kind() {
	case reflect.Ptr:
		if dst.IsNil() {
			return nil
		}
		if src.Type() == ast.NullType {
			// set nil value to pointer
			dst.Set(reflect.Zero(valueType))
			return nil
		}
		v := d.createDecodableValue(valueType)
		if err := d.decodeValue(v, src); err != nil {
			return errors.Wrapf(err, "failed to decode ptr value")
		}
		dst.Set(d.castToAssignableValue(v, valueType))
		// Return here: the scalar fallback below would try to Convert the
		// pointee's value to the pointer type and panic.
		return nil
	case reflect.Interface:
		v := reflect.ValueOf(d.nodeToValue(src))
		if !v.IsValid() {
			return nil
		}
		if !v.Type().AssignableTo(valueType) {
			// e.g. a non-empty interface the decoded value does not implement
			return errTypeMismatch
		}
		dst.Set(v)
		return nil
	case reflect.Map:
		return d.decodeMap(dst, src)
	case reflect.Array:
		return d.decodeArray(dst, src)
	case reflect.Slice:
		return d.decodeSlice(dst, src)
	case reflect.Struct:
		if _, ok := dst.Addr().Interface().(*time.Time); ok {
			return d.decodeTime(dst, src)
		}
		return d.decodeStruct(dst, src)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		switch vv := d.nodeToValue(src).(type) {
		case int64:
			if !dst.OverflowInt(vv) {
				dst.SetInt(vv)
				return nil
			}
		case uint64:
			if vv <= math.MaxInt64 && !dst.OverflowInt(int64(vv)) {
				dst.SetInt(int64(vv))
				return nil
			}
		case float64:
			// float64(math.MaxInt64) rounds up to 2^63, so the upper bound must
			// be strict; the lower bound keeps int64(vv) well-defined.
			if vv >= math.MinInt64 && vv < math.MaxInt64 && !dst.OverflowInt(int64(vv)) {
				dst.SetInt(int64(vv))
				return nil
			}
		default:
			return errTypeMismatch
		}
		return errOverflowNumber
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		switch vv := d.nodeToValue(src).(type) {
		case int64:
			if 0 <= vv && !dst.OverflowUint(uint64(vv)) {
				dst.SetUint(uint64(vv))
				return nil
			}
		case uint64:
			if !dst.OverflowUint(vv) {
				dst.SetUint(vv)
				return nil
			}
		case float64:
			// float64(math.MaxUint64) rounds up to 2^64; keep the bound strict.
			if 0 <= vv && vv < math.MaxUint64 && !dst.OverflowUint(uint64(vv)) {
				dst.SetUint(uint64(vv))
				return nil
			}
		default:
			return errTypeMismatch
		}
		return errOverflowNumber
	}
	v := reflect.ValueOf(d.nodeToValue(src))
	if !v.IsValid() {
		return nil
	}
	if !v.Type().ConvertibleTo(valueType) {
		// convertValue can additionally stringify floats and bools; any
		// other inconvertible pair would make reflect.Value.Convert panic.
		if valueType.Kind() != reflect.String {
			return errTypeMismatch
		}
		switch v.Kind() {
		case reflect.Bool, reflect.Float32, reflect.Float64:
		default:
			return errTypeMismatch
		}
	}
	dst.Set(d.convertValue(v, valueType))
	return nil
}
2019-10-24 17:18:03 +09:00
func (d *Decoder) createDecodableValue(typ reflect.Type) reflect.Value {
for {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
continue
}
break
}
return reflect.New(typ).Elem()
}
func (d *Decoder) castToAssignableValue(value reflect.Value, target reflect.Type) reflect.Value {
if target.Kind() != reflect.Ptr {
return value
}
maxTryCount := 5
tryCount := 0
for {
if tryCount > maxTryCount {
return value
}
if value.Type().AssignableTo(target) {
break
}
value = value.Addr()
tryCount++
}
return value
}
2019-12-10 18:41:23 +09:00
func (d *Decoder) keyToNodeMap(node ast.Node, filter func(*ast.MapNodeIter) ast.Node) (map[string]ast.Node, error) {
mapNode, err := d.getMapNode(node)
if err != nil {
return nil, errors.Wrapf(err, "failed to get map node")
}
keyToNodeMap := map[string]ast.Node{}
2019-11-01 12:42:41 +09:00
if mapNode == nil {
return keyToNodeMap, nil
}
mapIter := mapNode.MapRange()
for mapIter.Next() {
keyNode := mapIter.Key()
if keyNode.Type() == ast.MergeKeyType {
2019-12-10 18:41:23 +09:00
mergeMap, err := d.keyToNodeMap(mapIter.Value(), filter)
if err != nil {
return nil, errors.Wrapf(err, "failed to get keyToNodeMap by MergeKey node")
}
for k, v := range mergeMap {
keyToNodeMap[k] = v
}
} else {
key, ok := d.nodeToValue(keyNode).(string)
if !ok {
return nil, errors.Wrapf(err, "failed to decode map key")
}
2019-12-10 18:41:23 +09:00
keyToNodeMap[key] = filter(mapIter)
}
}
return keyToNodeMap, nil
}
// setDefaultValueIfConflicted resets fields of the struct value v whose
// render name fieldMap also reports via isIncludedRenderName. It is used on
// inline structs, where fieldMap comes from the enclosing struct, so that the
// enclosing struct's field keeps the key.
//
// A non-struct v is left alone. Fields that cannot be set are skipped.
func (d *Decoder) setDefaultValueIfConflicted(v reflect.Value, fieldMap StructFieldMap) error {
	typ := v.Type()
	if typ.Kind() != reflect.Struct {
		return nil
	}
	ownFields, err := structFieldMap(typ)
	if err != nil {
		return errors.Wrapf(err, "failed to get struct field map by embedded type")
	}
	for i, n := 0, typ.NumField(); i < n; i++ {
		field := typ.Field(i)
		if isIgnoredStructField(field) {
			continue
		}
		if !fieldMap.isIncludedRenderName(ownFields[field.Name].RenderName) {
			continue
		}
		// same key declared on both sides: drop this field's decoded value
		if target := v.Field(i); target.CanSet() {
			target.Set(reflect.Zero(target.Type()))
		}
	}
	return nil
}
// allowedTimestampFormats lists the time.Parse layouts tried, in order, when a
// string is decoded as a timestamp. They cover a subset of the forms allowed
// by the regular expression defined at http://yaml.org/type/timestamp.html.
var allowedTimestampFormats = []string{
	"2006-1-2T15:4:5.999999999Z07:00", // RFC3339Nano with short date fields
	"2006-1-2t15:4:5.999999999Z07:00", // same, with a lower-case "t"
	"2006-1-2 15:4:5.999999999",       // space separated, no time zone
	"2006-1-2",                        // date only
}
2019-11-01 14:22:06 +09:00
func (d *Decoder) castToTime(src ast.Node) (time.Time, error) {
2019-11-01 14:09:33 +09:00
if src == nil {
2019-11-01 14:22:06 +09:00
return time.Time{}, nil
2019-11-01 14:09:33 +09:00
}
v := d.nodeToValue(src)
2019-11-01 14:22:06 +09:00
if t, ok := v.(time.Time); ok {
return t, nil
}
2019-11-01 14:09:33 +09:00
s, ok := v.(string)
if !ok {
2019-11-01 14:22:06 +09:00
return time.Time{}, errTypeMismatch
2019-11-01 14:09:33 +09:00
}
for _, format := range allowedTimestampFormats {
t, err := time.Parse(format, s)
if err != nil {
// invalid format
continue
}
2019-11-01 14:22:06 +09:00
return t, nil
}
return time.Time{}, nil
}
func (d *Decoder) decodeTime(dst reflect.Value, src ast.Node) error {
t, err := d.castToTime(src)
if err != nil {
return err
2019-11-01 14:09:33 +09:00
}
2019-11-01 14:22:06 +09:00
dst.Set(reflect.ValueOf(t))
2019-11-01 14:09:33 +09:00
return nil
}
// decodeStruct decodes the mapping node src into the struct value dst.
//
// Field matching:
//   - Each non-ignored field is matched against src's keys by its RenderName
//     from structFieldMap; merge keys in src are expanded by keyToNodeMap.
//   - Inline fields are decoded from the whole of src.
//   - An explicit null leaves a pointer field nil.
//   - A field whose value fails with errTypeMismatch or errOverflowNumber is
//     skipped and keeps its zero value.
//
// After all fields are decoded:
//   - The configured validator, if any, runs on the result.
//   - When unknown fields are disallowed, a key no field consumed is reported.
//
// dst is assigned only after all of these checks pass. A nil src leaves dst
// untouched.
func (d *Decoder) decodeStruct(dst reflect.Value, src ast.Node) error {
	if src == nil {
		return nil
	}
	structType := dst.Type()
	structValue := reflect.New(structType)
	fieldMap, err := structFieldMap(structType)
	if err != nil {
		return errors.Wrapf(err, "failed to create struct field map")
	}
	valueNodes, err := d.keyToNodeMap(src, func(iter *ast.MapNodeIter) ast.Node {
		return iter.Value()
	})
	if err != nil {
		return errors.Wrapf(err, "failed to get keyToNodeMap")
	}
	// unknownFields holds the key node of every key no field has claimed yet.
	var unknownFields map[string]ast.Node
	if d.disallowUnknownField {
		unknownFields, err = d.keyToNodeMap(src, func(iter *ast.MapNodeIter) ast.Node {
			return iter.Key()
		})
		if err != nil {
			return errors.Wrapf(err, "failed to get keyToKeyNodeMap")
		}
	}
	for i := 0; i < structType.NumField(); i++ {
		field := structType.Field(i)
		if isIgnoredStructField(field) {
			continue
		}
		structField := fieldMap[field.Name]
		fieldValue := structValue.Elem().Field(i)
		if structField.IsInline {
			if !fieldValue.CanSet() {
				return xerrors.Errorf("cannot set embedded type as unexported field %s.%s", field.PkgPath, field.Name)
			}
			if fieldValue.Type().Kind() == reflect.Ptr && src.Type() == ast.NullType {
				// set nil value to pointer
				fieldValue.Set(reflect.Zero(fieldValue.Type()))
				continue
			}
			newFieldValue := d.createDecodableValue(fieldValue.Type())
			if err := d.decodeValue(newFieldValue, src); err != nil {
				if xerrors.Is(err, errTypeMismatch) || xerrors.Is(err, errOverflowNumber) {
					// skip decoding if an error occurs
					continue
				}
				return errors.Wrapf(err, "failed to decode value")
			}
			if err := d.setDefaultValueIfConflicted(newFieldValue, fieldMap); err != nil {
				return errors.Wrapf(err, "failed to set default value")
			}
			fieldValue.Set(d.castToAssignableValue(newFieldValue, fieldValue.Type()))
			continue
		}
		node, exists := valueNodes[structField.RenderName]
		if !exists {
			continue
		}
		delete(unknownFields, structField.RenderName)
		// Test the field's own node: src is the enclosing mapping and is
		// never null here, so checking it let `field: null` allocate.
		if fieldValue.Type().Kind() == reflect.Ptr && node.Type() == ast.NullType {
			// set nil value to pointer
			fieldValue.Set(reflect.Zero(fieldValue.Type()))
			continue
		}
		newFieldValue := d.createDecodableValue(fieldValue.Type())
		if err := d.decodeValue(newFieldValue, node); err != nil {
			if xerrors.Is(err, errTypeMismatch) || xerrors.Is(err, errOverflowNumber) {
				// skip decoding if an error occurs
				continue
			}
			return errors.Wrapf(err, "failed to decode value")
		}
		fieldValue.Set(d.castToAssignableValue(newFieldValue, fieldValue.Type()))
	}
	if d.validator != nil {
		if err := d.validator.Struct(structValue.Interface()); err != nil {
			// When the validator reports per-field errors, point the error at
			// the first failing field whose key appears in src.
			ev := reflect.ValueOf(err)
			if ev.Type().Kind() == reflect.Slice {
				for i := 0; i < ev.Len(); i++ {
					fieldErr, ok := ev.Index(i).Interface().(FieldError)
					if !ok {
						continue
					}
					structField, exists := fieldMap[fieldErr.StructField()]
					if !exists {
						continue
					}
					if node, exists := valueNodes[structField.RenderName]; exists {
						// TODO: to make FieldError message customizable
						return errors.ErrSyntax(fmt.Sprintf("%s", err), node.GetToken())
					}
				}
			}
			// No key to attach the error to (e.g. a missing required field);
			// the validation failure must still be reported.
			return err
		}
	}
	if d.disallowUnknownField && len(unknownFields) != 0 {
		// Report the lexically smallest key so the error does not depend on
		// map iteration order.
		var first string
		found := false
		for key := range unknownFields {
			if !found || key < first {
				first, found = key, true
			}
		}
		return errors.ErrSyntax(fmt.Sprintf("unknown field \"%s\"", first), unknownFields[first].GetToken())
	}
	dst.Set(structValue.Elem())
	return nil
}
2019-11-01 12:43:57 +09:00
func (d *Decoder) decodeArray(dst reflect.Value, src ast.Node) error {
arrayNode, err := d.getArrayNode(src)
if err != nil {
return errors.Wrapf(err, "failed to get array node")
}
if arrayNode == nil {
return nil
}
iter := arrayNode.ArrayRange()
arrayValue := reflect.New(dst.Type()).Elem()
arrayType := dst.Type()
elemType := arrayType.Elem()
idx := 0
for iter.Next() {
v := iter.Value()
if elemType.Kind() == reflect.Ptr && v.Type() == ast.NullType {
// set nil value to pointer
arrayValue.Index(idx).Set(reflect.Zero(elemType))
} else {
dstValue := d.createDecodableValue(elemType)
if err := d.decodeValue(dstValue, v); err != nil {
if xerrors.Is(err, errTypeMismatch) || xerrors.Is(err, errOverflowNumber) {
// skip decoding if an error occurs
}
return errors.Wrapf(err, "failed to decode value")
} else {
arrayValue.Index(idx).Set(d.castToAssignableValue(dstValue, elemType))
}
}
idx++
}
dst.Set(arrayValue)
return nil
}
// decodeSlice decodes the sequence node src into the slice dst, replacing its
// previous contents.
//
// Behavior:
//   - A null element becomes a nil pointer when the element type is a
//     pointer.
//   - Elements that fail with errTypeMismatch or errOverflowNumber are
//     dropped, so the result can be shorter than the input.
//   - A null src leaves dst untouched.
func (d *Decoder) decodeSlice(dst reflect.Value, src ast.Node) error {
	arrayNode, err := d.getArrayNode(src)
	if err != nil {
		return errors.Wrapf(err, "failed to get array node")
	}
	if arrayNode == nil {
		return nil
	}
	iter := arrayNode.ArrayRange()
	sliceType := dst.Type()
	elemType := sliceType.Elem()
	elems := reflect.MakeSlice(sliceType, 0, iter.Len())
	for iter.Next() {
		node := iter.Value()
		var elem reflect.Value
		if elemType.Kind() == reflect.Ptr && node.Type() == ast.NullType {
			elem = reflect.Zero(elemType)
		} else {
			decoded := d.createDecodableValue(elemType)
			if err := d.decodeValue(decoded, node); err != nil {
				if !xerrors.Is(err, errTypeMismatch) && !xerrors.Is(err, errOverflowNumber) {
					return errors.Wrapf(err, "failed to decode value")
				}
				// value does not fit the element type: leave it out
				continue
			}
			elem = d.castToAssignableValue(decoded, elemType)
		}
		elems = reflect.Append(elems, elem)
	}
	dst.Set(elems)
	return nil
}
func (d *Decoder) decodeMap(dst reflect.Value, src ast.Node) error {
mapNode, err := d.getMapNode(src)
if err != nil {
return errors.Wrapf(err, "failed to get map node")
}
2019-11-01 12:42:41 +09:00
if mapNode == nil {
return nil
}
2019-10-24 17:18:03 +09:00
mapType := dst.Type()
2019-10-18 00:00:58 +09:00
mapValue := reflect.MakeMap(mapType)
keyType := mapValue.Type().Key()
valueType := mapValue.Type().Elem()
mapIter := mapNode.MapRange()
for mapIter.Next() {
key := mapIter.Key()
value := mapIter.Value()
2019-10-30 16:57:59 +09:00
k := reflect.ValueOf(d.nodeToValue(key))
if k.IsValid() && k.Type().ConvertibleTo(keyType) {
k = k.Convert(keyType)
}
if valueType.Kind() == reflect.Ptr && value.Type() == ast.NullType {
// set nil value to pointer
mapValue.SetMapIndex(k, reflect.Zero(valueType))
continue
}
2019-10-24 17:18:03 +09:00
dstValue := d.createDecodableValue(valueType)
if err := d.decodeValue(dstValue, value); err != nil {
2019-11-01 12:43:36 +09:00
if xerrors.Is(err, errTypeMismatch) || xerrors.Is(err, errOverflowNumber) {
// skip decoding if an error occurs
2019-10-29 19:46:55 +09:00
continue
}
2019-10-24 17:18:03 +09:00
return errors.Wrapf(err, "failed to decode value")
2019-10-18 12:55:38 +09:00
}
2019-10-30 16:57:59 +09:00
if !k.IsValid() {
// expect nil key
mapValue.SetMapIndex(d.createDecodableValue(keyType), d.castToAssignableValue(dstValue, valueType))
continue
}
mapValue.SetMapIndex(k, d.castToAssignableValue(dstValue, valueType))
2019-10-18 00:00:58 +09:00
}
2019-10-24 17:18:03 +09:00
dst.Set(mapValue)
return nil
2019-10-18 00:00:58 +09:00
}
func (d *Decoder) fileToReader(file string) (io.Reader, error) {
reader, err := os.Open(file)
if err != nil {
return nil, errors.Wrapf(err, "failed to open file")
2019-10-18 00:00:58 +09:00
}
return reader, nil
}
// isYAMLFile reports whether file's extension is ".yml" or ".yaml". The
// comparison is case-sensitive.
func (d *Decoder) isYAMLFile(file string) bool {
	switch filepath.Ext(file) {
	case ".yml", ".yaml":
		return true
	default:
		return false
	}
}
func (d *Decoder) readersUnderDir(dir string) ([]io.Reader, error) {
pattern := fmt.Sprintf("%s/*", dir)
matches, err := filepath.Glob(pattern)
2019-10-18 00:00:58 +09:00
if err != nil {
return nil, errors.Wrapf(err, "failed to get files by %s", pattern)
}
readers := []io.Reader{}
for _, match := range matches {
if !d.isYAMLFile(match) {
continue
}
reader, err := d.fileToReader(match)
if err != nil {
return nil, errors.Wrapf(err, "failed to get reader")
}
readers = append(readers, reader)
}
return readers, nil
}
// readersUnderDirRecursive walks dir with filepath.Walk and returns a reader
// for every regular ".yml"/".yaml" file found, in walk order.
//
// Error handling:
//   - An error reported by the walk itself (e.g. an unreadable directory)
//     stops the walk and is returned.
//   - Directories whose names end in a YAML extension are not opened.
func (d *Decoder) readersUnderDirRecursive(dir string) ([]io.Reader, error) {
	readers := []io.Reader{}
	if err := filepath.Walk(dir, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			// info may be nil here; do not inspect it
			return walkErr
		}
		if info.IsDir() || !d.isYAMLFile(path) {
			return nil
		}
		reader, err := d.fileToReader(path)
		if err != nil {
			return errors.Wrapf(err, "failed to get reader")
		}
		readers = append(readers, reader)
		return nil
	}); err != nil {
		return nil, errors.Wrapf(err, "interrupt walk in %s", dir)
	}
	return readers, nil
}
func (d *Decoder) resolveReference() error {
for _, opt := range d.opts {
if err := opt(d); err != nil {
return errors.Wrapf(err, "failed to exec option")
}
}
for _, file := range d.referenceFiles {
reader, err := d.fileToReader(file)
if err != nil {
return errors.Wrapf(err, "failed to get reader")
}
d.referenceReaders = append(d.referenceReaders, reader)
}
for _, dir := range d.referenceDirs {
if !d.isRecursiveDir {
readers, err := d.readersUnderDir(dir)
if err != nil {
return errors.Wrapf(err, "failed to get readers from under the %s", dir)
}
d.referenceReaders = append(d.referenceReaders, readers...)
} else {
readers, err := d.readersUnderDirRecursive(dir)
if err != nil {
return errors.Wrapf(err, "failed to get readers from under the %s", dir)
}
d.referenceReaders = append(d.referenceReaders, readers...)
}
2019-10-18 00:00:58 +09:00
}
for _, reader := range d.referenceReaders {
bytes, err := ioutil.ReadAll(reader)
if err != nil {
return errors.Wrapf(err, "failed to read buffer")
}
// assign new anchor definition to anchorMap
if _, err := d.decode(bytes); err != nil {
return errors.Wrapf(err, "failed to decode")
}
}
d.isResolvedReference = true
return nil
}
func (d *Decoder) decode(bytes []byte) (ast.Node, error) {
2019-11-05 17:02:55 +09:00
f, err := parser.ParseBytes(bytes, 0)
2019-10-18 00:00:58 +09:00
if err != nil {
return nil, errors.Wrapf(err, "failed to parse yaml")
}
2019-11-05 17:02:55 +09:00
return d.fileToNode(f), nil
}
2019-10-21 12:53:30 +09:00
// Decode reads the next YAML-encoded value from its input
// and stores it in the value pointed to by v.
//
// See the documentation for Unmarshal for details about the
// conversion of YAML into a Go value.
func (d *Decoder) Decode(v interface{}) error {
if !d.isResolvedReference {
if err := d.resolveReference(); err != nil {
return errors.Wrapf(err, "failed to resolve reference")
}
}
rv := reflect.ValueOf(v)
if rv.Type().Kind() != reflect.Ptr {
return errors.ErrDecodeRequiredPointerType
}
bytes, err := ioutil.ReadAll(d.reader)
if err != nil {
return errors.Wrapf(err, "failed to read buffer")
}
node, err := d.decode(bytes)
if err != nil {
return errors.Wrapf(err, "failed to decode")
2019-10-18 00:00:58 +09:00
}
if node == nil {
2019-10-18 00:00:58 +09:00
return nil
}
if err := d.decodeValue(rv.Elem(), node); err != nil {
return errors.Wrapf(err, "failed to decode value")
2019-10-18 12:55:38 +09:00
}
2019-10-18 00:00:58 +09:00
return nil
}