From d225e247cc000901b7e5a523c827b7a6e192f74f Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Mon, 3 Feb 2025 12:54:17 +0900 Subject: [PATCH] fix comment map (#635) --- decode.go | 14 +++++++--- yaml_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 3 deletions(-) create mode 100644 yaml_test.go diff --git a/decode.go b/decode.go index 779ecb5..0191a39 100644 --- a/decode.go +++ b/decode.go @@ -7,6 +7,7 @@ import ( "encoding/base64" "fmt" "io" + "maps" "math" "os" "path/filepath" @@ -30,6 +31,7 @@ type Decoder struct { aliasValueMap map[*ast.AliasNode]any anchorValueMap map[string]reflect.Value customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error + commentMaps []CommentMap toCommentMap CommentMap opts []DecodeOption referenceFiles []string @@ -1957,6 +1959,12 @@ func (d *Decoder) parse(bytes []byte) (*ast.File, error) { if v != nil { normalizedFile.Docs = append(normalizedFile.Docs, doc) } + cm := CommentMap{} + maps.Copy(cm, d.toCommentMap) + d.commentMaps = append(d.commentMaps, cm) + for k := range d.toCommentMap { + delete(d.toCommentMap, k) + } } return normalizedFile, nil } @@ -1980,9 +1988,6 @@ func (d *Decoder) decodeInit() error { return err } d.parsedFile = file - for k := range d.toCommentMap { - delete(d.toCommentMap, k) - } return nil } @@ -1995,6 +2000,9 @@ func (d *Decoder) decode(ctx context.Context, v reflect.Value) error { if body == nil { return nil } + if len(d.commentMaps) > d.streamIndex { + maps.Copy(d.toCommentMap, d.commentMaps[d.streamIndex]) + } if err := d.decodeValue(ctx, v.Elem(), body); err != nil { return err } diff --git a/yaml_test.go b/yaml_test.go new file mode 100644 index 0000000..560a93c --- /dev/null +++ b/yaml_test.go @@ -0,0 +1,74 @@ +package yaml_test + +import ( + "io" + "reflect" + "strings" + "testing" + + "github.com/goccy/go-yaml" +) + +func TestRoundTripWithComment(t *testing.T) { + yml := ` +# head comment +key: value # line comment +` + var v struct { + Key string + } + comments := yaml.CommentMap{} + + if err := yaml.UnmarshalWithOptions([]byte(yml), &v, yaml.Strict(), yaml.CommentToMap(comments)); err != nil { + t.Fatal(err) + } + out, err := yaml.MarshalWithOptions(v, yaml.WithComment(comments)) + if err != nil { + t.Fatal(err) + } + got := "\n" + string(out) + if yml != got { + t.Fatalf("failed to get round tripped yaml: %s", got) + } +} + +func TestStreamDecodingWithComment(t *testing.T) { + yml := ` +a: + b: + c: # comment +--- +foo: bar # comment +--- +- a +- b +- c # comment +` + cm := yaml.CommentMap{} + dec := yaml.NewDecoder(strings.NewReader(yml), yaml.CommentToMap(cm)) + var commentPathsWithDocIndex [][]string + for { + var v any + if err := dec.Decode(&v); err != nil { + if err == io.EOF { + break + } + t.Fatal(err) + } + paths := make([]string, 0, len(cm)) + for k := range cm { + paths = append(paths, k) + } + commentPathsWithDocIndex = append(commentPathsWithDocIndex, paths) + for k := range cm { + delete(cm, k) + } + } + if !reflect.DeepEqual(commentPathsWithDocIndex, [][]string{ + {"$.a.b.c"}, + {"$.foo"}, + {"$[2]"}, + }) { + t.Fatalf("failed to get comment: %v", commentPathsWithDocIndex) + } +}