Keep reference of anchor's value (#660)

* keep reference of anchor's value

* add test case
This commit is contained in:
Masaaki Goshima 2025-02-16 16:44:00 +09:00 committed by GitHub
parent 89a66008de
commit b46780d4c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 157 additions and 88 deletions

175
decode.go
View file

@ -29,7 +29,6 @@ type Decoder struct {
reader io.Reader
referenceReaders []io.Reader
anchorNodeMap map[string]ast.Node
aliasValueMap map[*ast.AliasNode]any
anchorValueMap map[string]reflect.Value
customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error
commentMaps []CommentMap
@ -54,7 +53,6 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
return &Decoder{
reader: r,
anchorNodeMap: map[string]ast.Node{},
aliasValueMap: make(map[*ast.AliasNode]any),
anchorValueMap: map[string]reflect.Value{},
customUnmarshalerMap: map[reflect.Type]func(interface{}, []byte) error{},
opts: opts,
@ -117,8 +115,8 @@ func (d *Decoder) castToFloat(v interface{}) interface{} {
return 0
}
func (d *Decoder) mapKeyNodeToString(node ast.MapKeyNode) (string, error) {
key, err := d.nodeToValue(node)
func (d *Decoder) mapKeyNodeToString(ctx context.Context, node ast.MapKeyNode) (string, error) {
key, err := d.nodeToValue(ctx, node)
if err != nil {
return "", err
}
@ -131,7 +129,7 @@ func (d *Decoder) mapKeyNodeToString(node ast.MapKeyNode) (string, error) {
return fmt.Sprint(key), nil
}
func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) error {
func (d *Decoder) setToMapValue(ctx context.Context, node ast.Node, m map[string]interface{}) error {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
@ -148,16 +146,16 @@ func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) error {
}
iter := value.MapRange()
for iter.Next() {
if err := d.setToMapValue(iter.KeyValue(), m); err != nil {
if err := d.setToMapValue(ctx, iter.KeyValue(), m); err != nil {
return err
}
}
} else {
key, err := d.mapKeyNodeToString(n.Key)
key, err := d.mapKeyNodeToString(ctx, n.Key)
if err != nil {
return err
}
v, err := d.nodeToValue(n.Value)
v, err := d.nodeToValue(ctx, n.Value)
if err != nil {
return err
}
@ -165,7 +163,7 @@ func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) error {
}
case *ast.MappingNode:
for _, value := range n.Values {
if err := d.setToMapValue(value, m); err != nil {
if err := d.setToMapValue(ctx, value, m); err != nil {
return err
}
}
@ -176,7 +174,7 @@ func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) error {
return nil
}
func (d *Decoder) setToOrderedMapValue(node ast.Node, m *MapSlice) error {
func (d *Decoder) setToOrderedMapValue(ctx context.Context, node ast.Node, m *MapSlice) error {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
@ -193,16 +191,16 @@ func (d *Decoder) setToOrderedMapValue(node ast.Node, m *MapSlice) error {
}
iter := value.MapRange()
for iter.Next() {
if err := d.setToOrderedMapValue(iter.KeyValue(), m); err != nil {
if err := d.setToOrderedMapValue(ctx, iter.KeyValue(), m); err != nil {
return err
}
}
} else {
key, err := d.mapKeyNodeToString(n.Key)
key, err := d.mapKeyNodeToString(ctx, n.Key)
if err != nil {
return err
}
value, err := d.nodeToValue(n.Value)
value, err := d.nodeToValue(ctx, n.Value)
if err != nil {
return err
}
@ -210,7 +208,7 @@ func (d *Decoder) setToOrderedMapValue(node ast.Node, m *MapSlice) error {
}
case *ast.MappingNode:
for _, value := range n.Values {
if err := d.setToOrderedMapValue(value, m); err != nil {
if err := d.setToOrderedMapValue(ctx, value, m); err != nil {
return err
}
}
@ -341,7 +339,7 @@ func (d *Decoder) addCommentToMap(path string, comment *Comment) {
})
}
func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
func (d *Decoder) nodeToValue(ctx context.Context, node ast.Node) (any, error) {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
@ -366,7 +364,7 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
return n.GetValue(), nil
case *ast.TagNode:
if n.Directive != nil {
v, err := d.nodeToValue(n.Value)
v, err := d.nodeToValue(ctx, n.Value)
if err != nil {
return nil, err
}
@ -377,17 +375,17 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
}
switch token.ReservedTagKeyword(n.Start.Value) {
case token.TimestampTag:
t, _ := d.castToTime(n.Value)
t, _ := d.castToTime(ctx, n.Value)
return t, nil
case token.IntegerTag:
v, err := d.nodeToValue(n.Value)
v, err := d.nodeToValue(ctx, n.Value)
if err != nil {
return nil, err
}
i, _ := strconv.Atoi(fmt.Sprint(v))
return i, nil
case token.FloatTag:
v, err := d.nodeToValue(n.Value)
v, err := d.nodeToValue(ctx, n.Value)
if err != nil {
return nil, err
}
@ -395,7 +393,7 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
case token.NullTag:
return nil, nil
case token.BinaryTag:
v, err := d.nodeToValue(n.Value)
v, err := d.nodeToValue(ctx, n.Value)
if err != nil {
return nil, err
}
@ -409,7 +407,7 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
b, _ := base64.StdEncoding.DecodeString(str)
return b, nil
case token.BooleanTag:
v, err := d.nodeToValue(n.Value)
v, err := d.nodeToValue(ctx, n.Value)
if err != nil {
return nil, err
}
@ -426,7 +424,7 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
}
return nil, errors.ErrSyntax(fmt.Sprintf("cannot convert %q to boolean", fmt.Sprint(v)), n.Value.GetToken())
case token.StringTag:
v, err := d.nodeToValue(n.Value)
v, err := d.nodeToValue(ctx, n.Value)
if err != nil {
return nil, err
}
@ -435,45 +433,41 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
}
return fmt.Sprint(v), nil
case token.MappingTag:
return d.nodeToValue(n.Value)
return d.nodeToValue(ctx, n.Value)
default:
return d.nodeToValue(n.Value)
return d.nodeToValue(ctx, n.Value)
}
case *ast.AnchorNode:
anchorName := n.Name.GetToken().Value
// To handle the case where alias is processed recursively, the result of alias can be set to nil in advance.
d.anchorNodeMap[anchorName] = nil
anchorValue, err := d.nodeToValue(n.Value)
anchorValue, err := d.nodeToValue(withAnchor(ctx, anchorName), n.Value)
if err != nil {
delete(d.anchorNodeMap, anchorName)
return nil, err
}
d.anchorNodeMap[anchorName] = n.Value
d.anchorValueMap[anchorName] = reflect.ValueOf(anchorValue)
return anchorValue, nil
case *ast.AliasNode:
if v, exists := d.aliasValueMap[n]; exists {
return v, nil
text := n.Value.String()
if _, exists := getAnchorMap(ctx)[text]; exists {
// self recursion.
return nil, nil
}
if v, exists := d.anchorValueMap[text]; exists {
if !v.IsValid() {
return nil, nil
}
return v.Interface(), nil
}
// To handle the case where alias is processed recursively, the result of alias can be set to nil in advance.
d.aliasValueMap[n] = nil
aliasName := n.Value.GetToken().Value
node, exists := d.anchorNodeMap[aliasName]
if !exists {
return nil, errors.ErrSyntax(fmt.Sprintf("could not find alias %q", aliasName), n.Value.GetToken())
}
aliasValue, err := d.nodeToValue(node)
if err != nil {
return nil, err
}
// once the correct alias value is obtained, overwrite with that value.
d.aliasValueMap[n] = aliasValue
return aliasValue, nil
return nil, errors.ErrSyntax(fmt.Sprintf("could not find alias %q", aliasName), n.Value.GetToken())
case *ast.LiteralNode:
return n.Value.GetValue(), nil
case *ast.MappingKeyNode:
return d.nodeToValue(n.Value)
return d.nodeToValue(ctx, n.Value)
case *ast.MappingValueNode:
if n.Key.IsMergeKey() {
value, err := d.getMapNode(n.Value, true)
@ -484,7 +478,7 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
if d.useOrderedMap {
m := MapSlice{}
for iter.Next() {
if err := d.setToOrderedMapValue(iter.KeyValue(), &m); err != nil {
if err := d.setToOrderedMapValue(ctx, iter.KeyValue(), &m); err != nil {
return nil, err
}
}
@ -492,24 +486,24 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
}
m := make(map[string]any)
for iter.Next() {
if err := d.setToMapValue(iter.KeyValue(), m); err != nil {
if err := d.setToMapValue(ctx, iter.KeyValue(), m); err != nil {
return nil, err
}
}
return m, nil
}
key, err := d.mapKeyNodeToString(n.Key)
key, err := d.mapKeyNodeToString(ctx, n.Key)
if err != nil {
return nil, err
}
if d.useOrderedMap {
v, err := d.nodeToValue(n.Value)
v, err := d.nodeToValue(ctx, n.Value)
if err != nil {
return nil, err
}
return MapSlice{{Key: key, Value: v}}, nil
}
v, err := d.nodeToValue(n.Value)
v, err := d.nodeToValue(ctx, n.Value)
if err != nil {
return nil, err
}
@ -518,7 +512,7 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
if d.useOrderedMap {
m := make(MapSlice, 0, len(n.Values))
for _, value := range n.Values {
if err := d.setToOrderedMapValue(value, &m); err != nil {
if err := d.setToOrderedMapValue(ctx, value, &m); err != nil {
return nil, err
}
}
@ -526,7 +520,7 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
}
m := make(map[string]interface{}, len(n.Values))
for _, value := range n.Values {
if err := d.setToMapValue(value, m); err != nil {
if err := d.setToMapValue(ctx, value, m); err != nil {
return nil, err
}
}
@ -534,7 +528,7 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
case *ast.SequenceNode:
v := make([]interface{}, 0, len(n.Values))
for _, value := range n.Values {
vv, err := d.nodeToValue(value)
vv, err := d.nodeToValue(ctx, value)
if err != nil {
return nil, err
}
@ -878,10 +872,13 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
}
if src.Type() == ast.AnchorType {
anchorName := src.(*ast.AnchorNode).Name.GetToken().Value
if _, exists := d.anchorValueMap[anchorName]; !exists {
d.anchorValueMap[anchorName] = dst
anchor, _ := src.(*ast.AnchorNode)
anchorName := anchor.Name.GetToken().Value
if err := d.decodeValue(withAnchor(ctx, anchorName), dst, anchor.Value); err != nil {
return err
}
d.anchorValueMap[anchorName] = dst
return nil
}
if d.canDecodeByUnmarshaler(dst) {
if err := d.decodeByUnmarshaler(ctx, dst, src); err != nil {
@ -914,7 +911,7 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
dst.Set(reflect.ValueOf(src))
return nil
}
srcVal, err := d.nodeToValue(src)
srcVal, err := d.nodeToValue(ctx, src)
if err != nil {
return err
}
@ -937,7 +934,7 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
}
return d.decodeStruct(ctx, dst, src)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
v, err := d.nodeToValue(src)
v, err := d.nodeToValue(ctx, src)
if err != nil {
return err
}
@ -971,7 +968,7 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
}
return errors.ErrOverflow(valueType, fmt.Sprint(v), src.GetToken())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
v, err := d.nodeToValue(src)
v, err := d.nodeToValue(ctx, src)
if err != nil {
return err
}
@ -1006,7 +1003,7 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
}
return errors.ErrOverflow(valueType, fmt.Sprint(v), src.GetToken())
}
srcVal, err := d.nodeToValue(src)
srcVal, err := d.nodeToValue(ctx, src)
if err != nil {
return err
}
@ -1046,6 +1043,9 @@ func (d *Decoder) castToAssignableValue(value reflect.Value, target reflect.Type
if value.Type().AssignableTo(target) {
break
}
if !value.CanAddr() {
break
}
value = value.Addr()
}
if !value.Type().AssignableTo(target) {
@ -1091,7 +1091,7 @@ func (d *Decoder) createDecodedNewValue(
return d.castToAssignableValue(newValue, typ, node)
}
func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValueNode func(*ast.MapNodeIter) ast.Node) (map[string]ast.Node, error) {
func (d *Decoder) keyToNodeMap(ctx context.Context, node ast.Node, ignoreMergeKey bool, getKeyOrValueNode func(*ast.MapNodeIter) ast.Node) (map[string]ast.Node, error) {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
@ -1111,7 +1111,7 @@ func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValue
if ignoreMergeKey {
continue
}
mergeMap, err := d.keyToNodeMap(mapIter.Value(), ignoreMergeKey, getKeyOrValueNode)
mergeMap, err := d.keyToNodeMap(ctx, mapIter.Value(), ignoreMergeKey, getKeyOrValueNode)
if err != nil {
return nil, err
}
@ -1122,7 +1122,7 @@ func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValue
keyToNodeMap[k] = v
}
} else {
keyVal, err := d.nodeToValue(keyNode)
keyVal, err := d.nodeToValue(ctx, keyNode)
if err != nil {
return nil, err
}
@ -1139,16 +1139,16 @@ func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValue
return keyToNodeMap, nil
}
func (d *Decoder) keyToKeyNodeMap(node ast.Node, ignoreMergeKey bool) (map[string]ast.Node, error) {
m, err := d.keyToNodeMap(node, ignoreMergeKey, func(nodeMap *ast.MapNodeIter) ast.Node { return nodeMap.Key() })
func (d *Decoder) keyToKeyNodeMap(ctx context.Context, node ast.Node, ignoreMergeKey bool) (map[string]ast.Node, error) {
m, err := d.keyToNodeMap(ctx, node, ignoreMergeKey, func(nodeMap *ast.MapNodeIter) ast.Node { return nodeMap.Key() })
if err != nil {
return nil, err
}
return m, nil
}
func (d *Decoder) keyToValueNodeMap(node ast.Node, ignoreMergeKey bool) (map[string]ast.Node, error) {
m, err := d.keyToNodeMap(node, ignoreMergeKey, func(nodeMap *ast.MapNodeIter) ast.Node { return nodeMap.Value() })
func (d *Decoder) keyToValueNodeMap(ctx context.Context, node ast.Node, ignoreMergeKey bool) (map[string]ast.Node, error) {
m, err := d.keyToNodeMap(ctx, node, ignoreMergeKey, func(nodeMap *ast.MapNodeIter) ast.Node { return nodeMap.Value() })
if err != nil {
return nil, err
}
@ -1194,11 +1194,11 @@ var allowedTimestampFormats = []string{
"2006-1-2", // date only
}
func (d *Decoder) castToTime(src ast.Node) (time.Time, error) {
func (d *Decoder) castToTime(ctx context.Context, src ast.Node) (time.Time, error) {
if src == nil {
return time.Time{}, nil
}
v, err := d.nodeToValue(src)
v, err := d.nodeToValue(ctx, src)
if err != nil {
return time.Time{}, err
}
@ -1221,7 +1221,7 @@ func (d *Decoder) castToTime(src ast.Node) (time.Time, error) {
}
func (d *Decoder) decodeTime(ctx context.Context, dst reflect.Value, src ast.Node) error {
t, err := d.castToTime(src)
t, err := d.castToTime(ctx, src)
if err != nil {
return err
}
@ -1229,11 +1229,11 @@ func (d *Decoder) decodeTime(ctx context.Context, dst reflect.Value, src ast.Nod
return nil
}
func (d *Decoder) castToDuration(src ast.Node) (time.Duration, error) {
func (d *Decoder) castToDuration(ctx context.Context, src ast.Node) (time.Duration, error) {
if src == nil {
return 0, nil
}
v, err := d.nodeToValue(src)
v, err := d.nodeToValue(ctx, src)
if err != nil {
return 0, err
}
@ -1252,7 +1252,7 @@ func (d *Decoder) castToDuration(src ast.Node) (time.Duration, error) {
}
func (d *Decoder) decodeDuration(ctx context.Context, dst reflect.Value, src ast.Node) error {
t, err := d.castToDuration(src)
t, err := d.castToDuration(ctx, src)
if err != nil {
return err
}
@ -1304,13 +1304,13 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N
return err
}
ignoreMergeKey := structFieldMap.hasMergeProperty()
keyToNodeMap, err := d.keyToValueNodeMap(src, ignoreMergeKey)
keyToNodeMap, err := d.keyToValueNodeMap(ctx, src, ignoreMergeKey)
if err != nil {
return err
}
var unknownFields map[string]ast.Node
if d.disallowUnknownField {
unknownFields, err = d.keyToKeyNodeMap(src, ignoreMergeKey)
unknownFields, err = d.keyToKeyNodeMap(ctx, src, ignoreMergeKey)
if err != nil {
return err
}
@ -1566,11 +1566,11 @@ func (d *Decoder) decodeMapItem(ctx context.Context, dst *MapItem, src ast.Node)
}
return nil
}
k, err := d.nodeToValue(key)
k, err := d.nodeToValue(ctx, key)
if err != nil {
return err
}
v, err := d.nodeToValue(value)
v, err := d.nodeToValue(ctx, value)
if err != nil {
return err
}
@ -1622,14 +1622,14 @@ func (d *Decoder) decodeMapSlice(ctx context.Context, dst *MapSlice, src ast.Nod
}
continue
}
k, err := d.nodeToValue(key)
k, err := d.nodeToValue(ctx, key)
if err != nil {
return err
}
if err := d.validateDuplicateKey(keyMap, k, key); err != nil {
return err
}
v, err := d.nodeToValue(value)
v, err := d.nodeToValue(ctx, value)
if err != nil {
return err
}
@ -1680,7 +1680,7 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node
return err
}
} else {
keyVal, err := d.nodeToValue(key)
keyVal, err := d.nodeToValue(ctx, key)
if err != nil {
return err
}
@ -1783,7 +1783,7 @@ func (d *Decoder) readersUnderDirRecursive(dir string) ([]io.Reader, error) {
return readers, nil
}
func (d *Decoder) resolveReference() error {
func (d *Decoder) resolveReference(ctx context.Context) error {
for _, opt := range d.opts {
if err := opt(d); err != nil {
return err
@ -1818,7 +1818,7 @@ func (d *Decoder) resolveReference() error {
}
// assign new anchor definition to anchorMap
if _, err := d.parse(bytes); err != nil {
if _, err := d.parse(ctx, bytes); err != nil {
return err
}
}
@ -1826,7 +1826,7 @@ func (d *Decoder) resolveReference() error {
return nil
}
func (d *Decoder) parse(bytes []byte) (*ast.File, error) {
func (d *Decoder) parse(ctx context.Context, bytes []byte) (*ast.File, error) {
var parseMode parser.Mode
if d.toCommentMap != nil {
parseMode = parser.ParseComments
@ -1842,7 +1842,7 @@ func (d *Decoder) parse(bytes []byte) (*ast.File, error) {
normalizedFile := &ast.File{}
for _, doc := range f.Docs {
// try to decode ast.Node to value and map anchor value to anchorMap
v, err := d.nodeToValue(doc.Body)
v, err := d.nodeToValue(ctx, doc.Body)
if err != nil {
return nil, err
}
@ -1863,9 +1863,9 @@ func (d *Decoder) isInitialized() bool {
return d.parsedFile != nil
}
func (d *Decoder) decodeInit() error {
func (d *Decoder) decodeInit(ctx context.Context) error {
if !d.isResolvedReference {
if err := d.resolveReference(); err != nil {
if err := d.resolveReference(ctx); err != nil {
return err
}
}
@ -1873,7 +1873,7 @@ func (d *Decoder) decodeInit() error {
if _, err := io.Copy(&buf, d.reader); err != nil {
return err
}
file, err := d.parse(buf.Bytes())
file, err := d.parse(ctx, buf.Bytes())
if err != nil {
return err
}
@ -1883,6 +1883,7 @@ func (d *Decoder) decodeInit() error {
func (d *Decoder) decode(ctx context.Context, v reflect.Value) error {
d.decodeDepth = 0
d.anchorValueMap = make(map[string]reflect.Value)
if len(d.parsedFile.Docs) <= d.streamIndex {
return io.EOF
}
@ -1925,7 +1926,7 @@ func (d *Decoder) DecodeContext(ctx context.Context, v interface{}) error {
}
return nil
}
if err := d.decodeInit(); err != nil {
if err := d.decodeInit(ctx); err != nil {
return err
}
if err := d.decode(ctx, rv); err != nil {
@ -1949,12 +1950,12 @@ func (d *Decoder) DecodeFromNodeContext(ctx context.Context, node ast.Node, v in
return ErrDecodeRequiredPointerType
}
if !d.isInitialized() {
if err := d.decodeInit(); err != nil {
if err := d.decodeInit(ctx); err != nil {
return err
}
}
// resolve references to the anchor on the same file
if _, err := d.nodeToValue(node); err != nil {
if _, err := d.nodeToValue(ctx, node); err != nil {
return err
}
if err := d.decodeValue(ctx, rv.Elem(), node); err != nil {