Add support for RawMessage, similar to json.RawMessage (#790)

* Add support for RawMessage, similar to json.RawMessage

This adds a type that users can use to refer to raw data in the yaml for
deferred decoding.

Signed-off-by: Brian Goff <cpuguy83@gmail.com>

* test: add suggested cases from #668

---------

Signed-off-by: Brian Goff <cpuguy83@gmail.com>
Co-authored-by: Brian Goff <cpuguy83@gmail.com>
This commit is contained in:
Thane Thomson 2025-11-28 21:59:22 -05:00 committed by GitHub
parent 07c09c0287
commit a7b4bfbcf4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 279 additions and 2 deletions

31
yaml.go
View file

@ -324,3 +324,34 @@ func RegisterCustomUnmarshalerContext[T any](unmarshaler func(context.Context, *
return unmarshaler(ctx, v.(*T), b)
}
}
// RawMessage is a raw encoded YAML value. It implements [BytesMarshaler] and
// [BytesUnmarshaler] and can be used to delay YAML decoding or precompute a YAML
// encoding.
// It also implements [json.Marshaler] and [json.Unmarshaler].
//
// This is similar to [json.RawMessage] in the stdlib.
type RawMessage []byte
func (m RawMessage) MarshalYAML() ([]byte, error) {
if m == nil {
return []byte("null"), nil
}
return m, nil
}
func (m *RawMessage) UnmarshalYAML(dt []byte) error {
if m == nil {
return errors.New("yaml.RawMessage: UnmarshalYAML on nil pointer")
}
*m = append((*m)[0:0], dt...)
return nil
}
func (m *RawMessage) UnmarshalJSON(b []byte) error {
return m.UnmarshalYAML(b)
}
func (m RawMessage) MarshalJSON() ([]byte, error) {
return YAMLToJSON(m)
}

View file

@ -1,6 +1,7 @@
package yaml_test
import (
"encoding/json"
"fmt"
"io"
"reflect"
@ -78,7 +79,7 @@ foo: bar # comment
}
func TestDecodeKeepAddress(t *testing.T) {
var data = `
data := `
a: &a [_]
b: &b [*a,*a]
c: &c [*b,*b]
@ -103,7 +104,7 @@ d: &d [*c,*c]
}
func TestSmartAnchor(t *testing.T) {
var data = `
data := `
a: &a [_,_,_,_,_,_,_,_,_,_,_,_,_,_,_]
b: &b [*a,*a,*a,*a,*a,*a,*a,*a,*a,*a]
c: &c [*b,*b,*b,*b,*b,*b,*b,*b,*b,*b]
@ -263,3 +264,248 @@ foo: 2
}
}
}
func checkRawValue[T any](t *testing.T, v yaml.RawMessage, expected T) {
t.Helper()
var actual T
if err := yaml.Unmarshal(v, &actual); err != nil {
t.Errorf("failed to unmarshal: %v", err)
return
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected %v, got %v", expected, actual)
}
}
func checkJSONRawValue[T any](t *testing.T, v json.RawMessage, expected T) {
t.Helper()
var actual T
if err := json.Unmarshal(v, &actual); err != nil {
t.Errorf("failed to unmarshal: %v", err)
return
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected %v, got %v", expected, actual)
}
checkRawValue(t, yaml.RawMessage(v), expected)
}
func TestRawMessage(t *testing.T) {
data := []byte(`
a: 1
b: "asdf"
c:
foo: bar
`)
var m map[string]yaml.RawMessage
if err := yaml.Unmarshal(data, &m); err != nil {
t.Fatal(err)
}
if len(m) != 3 {
t.Fatalf("failed to decode: %d", len(m))
}
checkRawValue(t, m["a"], 1)
checkRawValue(t, m["b"], "asdf")
checkRawValue(t, m["c"], map[string]string{"foo": "bar"})
dt, err := yaml.Marshal(m)
if err != nil {
t.Fatal(err)
}
var m2 map[string]yaml.RawMessage
if err := yaml.Unmarshal(dt, &m2); err != nil {
t.Fatal(err)
}
checkRawValue(t, m2["a"], 1)
checkRawValue(t, m2["b"], "asdf")
checkRawValue(t, m2["c"], map[string]string{"foo": "bar"})
dt, err = json.Marshal(m2)
if err != nil {
t.Fatal(err)
}
var m3 map[string]yaml.RawMessage
if err := yaml.Unmarshal(dt, &m3); err != nil {
t.Fatal(err)
}
checkRawValue(t, m3["a"], 1)
checkRawValue(t, m3["b"], "asdf")
checkRawValue(t, m3["c"], map[string]string{"foo": "bar"})
var m4 map[string]json.RawMessage
if err := json.Unmarshal(dt, &m4); err != nil {
t.Fatal(err)
}
checkJSONRawValue(t, m4["a"], 1)
checkJSONRawValue(t, m4["b"], "asdf")
checkJSONRawValue(t, m4["c"], map[string]string{"foo": "bar"})
}
type rawYAMLWrapper struct {
StaticField string `json:"staticField" yaml:"staticField"`
DynamicField yaml.RawMessage `json:"dynamicField" yaml:"dynamicField"`
}
type rawJSONWrapper struct {
StaticField string `json:"staticField" yaml:"staticField"`
DynamicField json.RawMessage `json:"dynamicField" yaml:"dynamicField"`
}
func (w rawJSONWrapper) Equals(o *rawJSONWrapper) bool {
if w.StaticField != o.StaticField {
return false
}
return reflect.DeepEqual(w.DynamicField, o.DynamicField)
}
type dynamicField struct {
A int `json:"a" yaml:"a"`
B string `json:"b" yaml:"b"`
C map[string]string `json:"c" yaml:"c"`
}
func (t dynamicField) Equals(o *dynamicField) bool {
if t.A != o.A {
return false
}
if t.B != o.B {
return false
}
if len(t.C) != len(o.C) {
return false
}
for k, v := range t.C {
ov, exists := o.C[k]
if !exists {
return false
}
if v != ov {
return false
}
}
return true
}
func TestRawMessageJSONCompatibility(t *testing.T) {
rawData := []byte(`staticField: value
dynamicField:
a: 1
b: abcd
c:
foo: bar
something: else
`)
expectedDynamicFieldValue := &dynamicField{
A: 1,
B: "abcd",
C: map[string]string{
"foo": "bar",
"something": "else",
},
}
t.Run("UseJSONUnmarshaler and json.RawMessage", func(t *testing.T) {
var wrapper rawJSONWrapper
if err := yaml.UnmarshalWithOptions(rawData, &wrapper, yaml.UseJSONUnmarshaler()); err != nil {
t.Fatal(err)
}
if wrapper.StaticField != "value" {
t.Fatalf("unexpected wrapper static field value: %s", wrapper.StaticField)
}
var dynamicFieldValue dynamicField
if err := yaml.Unmarshal(wrapper.DynamicField, &dynamicFieldValue); err != nil {
t.Fatal(err)
}
if !dynamicFieldValue.Equals(expectedDynamicFieldValue) {
t.Fatalf("unexpected dynamic field value: %v", dynamicFieldValue)
}
})
t.Run("UseJSONUnmarshaler and yaml.RawMessage", func(t *testing.T) {
var wrapper rawYAMLWrapper
if err := yaml.UnmarshalWithOptions(rawData, &wrapper, yaml.UseJSONUnmarshaler()); err != nil {
t.Fatal(err)
}
if wrapper.StaticField != "value" {
t.Fatalf("unexpected wrapper static field value: %s", wrapper.StaticField)
}
var dynamicFieldValue dynamicField
if err := yaml.Unmarshal(wrapper.DynamicField, &dynamicFieldValue); err != nil {
t.Fatal(err)
}
if !dynamicFieldValue.Equals(expectedDynamicFieldValue) {
t.Fatalf("unexpected dynamic field value: %v", dynamicFieldValue)
}
})
t.Run("UseJSONMarshaler and json.RawMessage", func(t *testing.T) {
dynamicFieldBytes, err := yaml.Marshal(expectedDynamicFieldValue)
if err != nil {
t.Fatal(err)
}
wrapper := rawJSONWrapper{
StaticField: "value",
DynamicField: json.RawMessage(dynamicFieldBytes),
}
wrapperBytes, err := yaml.MarshalWithOptions(&wrapper, yaml.UseJSONMarshaler())
if err != nil {
t.Fatal(err)
}
var unmarshaledWrapper rawJSONWrapper
if err := yaml.UnmarshalWithOptions(wrapperBytes, &unmarshaledWrapper, yaml.UseJSONUnmarshaler()); err != nil {
t.Fatal(err)
}
if unmarshaledWrapper.StaticField != wrapper.StaticField {
t.Fatalf("unexpected unmarshaled static field value: %s", unmarshaledWrapper.StaticField)
}
var unmarshaledDynamicFieldValue dynamicField
if err := yaml.UnmarshalWithOptions(unmarshaledWrapper.DynamicField, &unmarshaledDynamicFieldValue, yaml.UseJSONUnmarshaler()); err != nil {
t.Fatal(err)
}
if !unmarshaledDynamicFieldValue.Equals(expectedDynamicFieldValue) {
t.Fatalf("unexpected unmarshaled dynamic field value: %v", unmarshaledDynamicFieldValue)
}
})
t.Run("UseJSONMarshaler and yaml.RawMessage", func(t *testing.T) {
dynamicFieldBytes, err := yaml.Marshal(expectedDynamicFieldValue)
if err != nil {
t.Fatal(err)
}
wrapper := rawYAMLWrapper{
StaticField: "value",
DynamicField: yaml.RawMessage(dynamicFieldBytes),
}
wrapperBytes, err := yaml.MarshalWithOptions(&wrapper, yaml.UseJSONMarshaler())
if err != nil {
t.Fatal(err)
}
var unmarshaledWrapper rawYAMLWrapper
if err := yaml.UnmarshalWithOptions(wrapperBytes, &unmarshaledWrapper, yaml.UseJSONUnmarshaler()); err != nil {
t.Fatal(err)
}
if unmarshaledWrapper.StaticField != wrapper.StaticField {
t.Fatalf("unexpected unmarshaled static field value: %s", unmarshaledWrapper.StaticField)
}
var unmarshaledDynamicFieldValue dynamicField
if err := yaml.UnmarshalWithOptions(unmarshaledWrapper.DynamicField, &unmarshaledDynamicFieldValue, yaml.UseJSONUnmarshaler()); err != nil {
t.Fatal(err)
}
if !unmarshaledDynamicFieldValue.Equals(expectedDynamicFieldValue) {
t.Fatalf("unexpected unmarshaled dynamic field value: %v", unmarshaledDynamicFieldValue)
}
})
}