From 0040ab4161153c5b812ae51984d3967ad3dbd068 Mon Sep 17 00:00:00 2001 From: Shuhei Kitagawa Date: Sat, 29 Nov 2025 03:40:13 +0100 Subject: [PATCH] Skip directive in path operations (#758) --- path.go | 10 +++++- path_test.go | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/path.go b/path.go index 4d2cbba..568c4b4 100644 --- a/path.go +++ b/path.go @@ -258,6 +258,10 @@ func (p *Path) Filter(target, v interface{}) error { // FilterFile filter from ast.File by YAMLPath. func (p *Path) FilterFile(f *ast.File) (ast.Node, error) { for _, doc := range f.Docs { + // For simplicity, directives cannot be the target of operations + if doc.Body != nil && doc.Body.Type() == ast.DirectiveType { + continue + } node, err := p.FilterNode(doc.Body) if err != nil { return nil, err @@ -352,6 +356,10 @@ func (p *Path) ReplaceWithFile(dst *ast.File, src *ast.File) error { // ReplaceNode replace ast.File with ast.Node. func (p *Path) ReplaceWithNode(dst *ast.File, node ast.Node) error { for _, doc := range dst.Docs { + // For simplicity, directives cannot be the target of operations + if doc.Body != nil && doc.Body.Type() == ast.DirectiveType { + continue + } if node.Type() == ast.DocumentType { node = node.(*ast.DocumentNode).Body } @@ -364,7 +372,7 @@ func (p *Path) ReplaceWithNode(dst *ast.File, node ast.Node) error { // AnnotateSource add annotation to passed source ( see section 5.1 in README.md ). func (p *Path) AnnotateSource(source []byte, colored bool) ([]byte, error) { - file, err := parser.ParseBytes([]byte(source), 0) + file, err := parser.ParseBytes(source, 0) if err != nil { return nil, err } diff --git a/path_test.go b/path_test.go index efdee3d..0e29306 100644 --- a/path_test.go +++ b/path_test.go @@ -214,6 +214,82 @@ store: }) } +func TestPath_FilterFile(t *testing.T) { + tests := []struct { + name string + path string + src string + expected any + expectedErr string + }{ + { + name: "simple key", + path: "$.key", + src: `key: value`, + expected: "value", + }, + { + name: "with directive", + path: "$.key", + src: `%YAML 1.2 +--- +key: value`, + expected: "value", + }, + { + name: "multiple docs", + path: "$.key2", + src: `key1: value1 +--- +key2: value2 +`, + expected: "value2", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + path, err := yaml.PathString(test.path) + if err != nil { + t.Fatalf("unexpected error during path parsing: %+v", err) + } + + file, err := parser.ParseBytes([]byte(test.src), 0) + if err != nil { + t.Fatalf("failed to parse YAML: %+v", err) + } + + node, err := path.FilterFile(file) + if test.expectedErr != "" { + if err == nil { + t.Fatal("expected error but got none") + } + if !strings.Contains(err.Error(), test.expectedErr) { + t.Fatalf("expected error containing %q but got %q", test.expectedErr, err.Error()) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %+v", err) + } + + if node == nil { + t.Fatal("expected node but got nil") + } + + var actual any + if err := yaml.Unmarshal([]byte(node.String()), &actual); err != nil { + t.Fatalf("failed to unmarshal result: %+v", err) + } + + if !reflect.DeepEqual(test.expected, actual) { + t.Fatalf("expected %v(%T) but got %v(%T)", test.expected, test.expected, actual, actual) + } + }) + } +} + func TestPath_ReservedKeyword(t *testing.T) { tests := []struct { name string @@ -605,6 +681,22 @@ b: 2 expected: ` a: 3 b: 2 +`, + }, + { + path: "$.a", + dst: ` +%YAML 1.2 +--- +a: 1 +b: 2 +`, + src: `3`, + expected: ` +%YAML 1.2 +--- +a: 3 +b: 2 `, }, {