diff --git a/yaml.go b/yaml.go index f9a9f83..e1b5fbd 100644 --- a/yaml.go +++ b/yaml.go @@ -324,3 +324,34 @@ func RegisterCustomUnmarshalerContext[T any](unmarshaler func(context.Context, * return unmarshaler(ctx, v.(*T), b) } } + +// RawMessage is a raw encoded YAML value. It implements [BytesMarshaler] and +// [BytesUnmarshaler] and can be used to delay YAML decoding or precompute a YAML +// encoding. +// It also implements [json.Marshaler] and [json.Unmarshaler]. +// +// This is similar to [json.RawMessage] in the stdlib. +type RawMessage []byte + +func (m RawMessage) MarshalYAML() ([]byte, error) { + if m == nil { + return []byte("null"), nil + } + return m, nil +} + +func (m *RawMessage) UnmarshalYAML(dt []byte) error { + if m == nil { + return errors.New("yaml.RawMessage: UnmarshalYAML on nil pointer") + } + *m = append((*m)[0:0], dt...) + return nil +} + +func (m *RawMessage) UnmarshalJSON(b []byte) error { + return m.UnmarshalYAML(b) +} + +func (m RawMessage) MarshalJSON() ([]byte, error) { + return YAMLToJSON(m) +} diff --git a/yaml_test.go b/yaml_test.go index f9698e6..d5546b4 100644 --- a/yaml_test.go +++ b/yaml_test.go @@ -1,6 +1,7 @@ package yaml_test import ( + "encoding/json" "fmt" "io" "reflect" @@ -78,7 +79,7 @@ foo: bar # comment } func TestDecodeKeepAddress(t *testing.T) { - var data = ` + data := ` a: &a [_] b: &b [*a,*a] c: &c [*b,*b] @@ -103,7 +104,7 @@ d: &d [*c,*c] } func TestSmartAnchor(t *testing.T) { - var data = ` + data := ` a: &a [_,_,_,_,_,_,_,_,_,_,_,_,_,_,_] b: &b [*a,*a,*a,*a,*a,*a,*a,*a,*a,*a] c: &c [*b,*b,*b,*b,*b,*b,*b,*b,*b,*b] @@ -263,3 +264,248 @@ foo: 2 } } } + +func checkRawValue[T any](t *testing.T, v yaml.RawMessage, expected T) { + t.Helper() + + var actual T + + if err := yaml.Unmarshal(v, &actual); err != nil { + t.Errorf("failed to unmarshal: %v", err) + return + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected %v, got %v", expected, actual) + } +} + +func checkJSONRawValue[T any](t *testing.T, v json.RawMessage, expected T) { + t.Helper() + + var actual T + + if err := json.Unmarshal(v, &actual); err != nil { + t.Errorf("failed to unmarshal: %v", err) + return + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected %v, got %v", expected, actual) + } + + checkRawValue(t, yaml.RawMessage(v), expected) +} + +func TestRawMessage(t *testing.T) { + data := []byte(` +a: 1 +b: "asdf" +c: + foo: bar +`) + + var m map[string]yaml.RawMessage + if err := yaml.Unmarshal(data, &m); err != nil { + t.Fatal(err) + } + + if len(m) != 3 { + t.Fatalf("failed to decode: %d", len(m)) + } + + checkRawValue(t, m["a"], 1) + checkRawValue(t, m["b"], "asdf") + checkRawValue(t, m["c"], map[string]string{"foo": "bar"}) + + dt, err := yaml.Marshal(m) + if err != nil { + t.Fatal(err) + } + var m2 map[string]yaml.RawMessage + if err := yaml.Unmarshal(dt, &m2); err != nil { + t.Fatal(err) + } + + checkRawValue(t, m2["a"], 1) + checkRawValue(t, m2["b"], "asdf") + checkRawValue(t, m2["c"], map[string]string{"foo": "bar"}) + + dt, err = json.Marshal(m2) + if err != nil { + t.Fatal(err) + } + + var m3 map[string]yaml.RawMessage + if err := yaml.Unmarshal(dt, &m3); err != nil { + t.Fatal(err) + } + checkRawValue(t, m3["a"], 1) + checkRawValue(t, m3["b"], "asdf") + checkRawValue(t, m3["c"], map[string]string{"foo": "bar"}) + + var m4 map[string]json.RawMessage + if err := json.Unmarshal(dt, &m4); err != nil { + t.Fatal(err) + } + checkJSONRawValue(t, m4["a"], 1) + checkJSONRawValue(t, m4["b"], "asdf") + checkJSONRawValue(t, m4["c"], map[string]string{"foo": "bar"}) +} + +type rawYAMLWrapper struct { + StaticField string `json:"staticField" yaml:"staticField"` + DynamicField yaml.RawMessage `json:"dynamicField" yaml:"dynamicField"` +} + +type rawJSONWrapper struct { + StaticField string `json:"staticField" yaml:"staticField"` + DynamicField json.RawMessage `json:"dynamicField" yaml:"dynamicField"` +} + +func (w rawJSONWrapper) Equals(o *rawJSONWrapper) bool { + if w.StaticField != o.StaticField { + return false + } + return reflect.DeepEqual(w.DynamicField, o.DynamicField) +} + +type dynamicField struct { + A int `json:"a" yaml:"a"` + B string `json:"b" yaml:"b"` + C map[string]string `json:"c" yaml:"c"` +} + +func (t dynamicField) Equals(o *dynamicField) bool { + if t.A != o.A { + return false + } + if t.B != o.B { + return false + } + if len(t.C) != len(o.C) { + return false + } + for k, v := range t.C { + ov, exists := o.C[k] + if !exists { + return false + } + if v != ov { + return false + } + } + return true +} + +func TestRawMessageJSONCompatibility(t *testing.T) { + rawData := []byte(`staticField: value +dynamicField: + a: 1 + b: abcd + c: + foo: bar + something: else +`) + + expectedDynamicFieldValue := &dynamicField{ + A: 1, + B: "abcd", + C: map[string]string{ + "foo": "bar", + "something": "else", + }, + } + + t.Run("UseJSONUnmarshaler and json.RawMessage", func(t *testing.T) { + var wrapper rawJSONWrapper + if err := yaml.UnmarshalWithOptions(rawData, &wrapper, yaml.UseJSONUnmarshaler()); err != nil { + t.Fatal(err) + } + if wrapper.StaticField != "value" { + t.Fatalf("unexpected wrapper static field value: %s", wrapper.StaticField) + } + var dynamicFieldValue dynamicField + if err := yaml.Unmarshal(wrapper.DynamicField, &dynamicFieldValue); err != nil { + t.Fatal(err) + } + if !dynamicFieldValue.Equals(expectedDynamicFieldValue) { + t.Fatalf("unexpected dynamic field value: %v", dynamicFieldValue) + } + }) + + t.Run("UseJSONUnmarshaler and yaml.RawMessage", func(t *testing.T) { + var wrapper rawYAMLWrapper + if err := yaml.UnmarshalWithOptions(rawData, &wrapper, yaml.UseJSONUnmarshaler()); err != nil { + t.Fatal(err) + } + if wrapper.StaticField != "value" { + t.Fatalf("unexpected wrapper static field value: %s", wrapper.StaticField) + } + var dynamicFieldValue dynamicField + if err := yaml.Unmarshal(wrapper.DynamicField, &dynamicFieldValue); err != nil { + t.Fatal(err) + } + if !dynamicFieldValue.Equals(expectedDynamicFieldValue) { + t.Fatalf("unexpected dynamic field value: %v", dynamicFieldValue) + } + }) + + t.Run("UseJSONMarshaler and json.RawMessage", func(t *testing.T) { + dynamicFieldBytes, err := yaml.Marshal(expectedDynamicFieldValue) + if err != nil { + t.Fatal(err) + } + wrapper := rawJSONWrapper{ + StaticField: "value", + DynamicField: json.RawMessage(dynamicFieldBytes), + } + wrapperBytes, err := yaml.MarshalWithOptions(&wrapper, yaml.UseJSONMarshaler()) + if err != nil { + t.Fatal(err) + } + var unmarshaledWrapper rawJSONWrapper + if err := yaml.UnmarshalWithOptions(wrapperBytes, &unmarshaledWrapper, yaml.UseJSONUnmarshaler()); err != nil { + t.Fatal(err) + } + if unmarshaledWrapper.StaticField != wrapper.StaticField { + t.Fatalf("unexpected unmarshaled static field value: %s", unmarshaledWrapper.StaticField) + } + var unmarshaledDynamicFieldValue dynamicField + if err := yaml.UnmarshalWithOptions(unmarshaledWrapper.DynamicField, &unmarshaledDynamicFieldValue, yaml.UseJSONUnmarshaler()); err != nil { + t.Fatal(err) + } + if !unmarshaledDynamicFieldValue.Equals(expectedDynamicFieldValue) { + t.Fatalf("unexpected unmarshaled dynamic field value: %v", unmarshaledDynamicFieldValue) + } + }) + + t.Run("UseJSONMarshaler and yaml.RawMessage", func(t *testing.T) { + dynamicFieldBytes, err := yaml.Marshal(expectedDynamicFieldValue) + if err != nil { + t.Fatal(err) + } + wrapper := rawYAMLWrapper{ + StaticField: "value", + DynamicField: yaml.RawMessage(dynamicFieldBytes), + } + wrapperBytes, err := yaml.MarshalWithOptions(&wrapper, yaml.UseJSONMarshaler()) + if err != nil { + t.Fatal(err) + } + var unmarshaledWrapper rawYAMLWrapper + if err := yaml.UnmarshalWithOptions(wrapperBytes, &unmarshaledWrapper, yaml.UseJSONUnmarshaler()); err != nil { + t.Fatal(err) + } + if unmarshaledWrapper.StaticField != wrapper.StaticField { + t.Fatalf("unexpected unmarshaled static field value: %s", unmarshaledWrapper.StaticField) + } + var unmarshaledDynamicFieldValue dynamicField + if err := yaml.UnmarshalWithOptions(unmarshaledWrapper.DynamicField, &unmarshaledDynamicFieldValue, yaml.UseJSONUnmarshaler()); err != nil { + t.Fatal(err) + } + if !unmarshaledDynamicFieldValue.Equals(expectedDynamicFieldValue) { + t.Fatalf("unexpected unmarshaled dynamic field value: %v", unmarshaledDynamicFieldValue) + } + }) +}