crypto/internal/hpke: propagate hkdf error value

The hkdf operations done in hpke are not expected to fail given that
we control the inputs. However, propagating the error instead of
doesn't hurt and makes the code more robust to future changes.

Change-Id: I168854593a40f67e2cc275e0dedc3b24b8f1480e
Reviewed-on: https://go-review.googlesource.com/c/go/+/658475
Reviewed-by: Roland Shoemaker <roland@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
This commit is contained in:
qmuntal 2025-03-17 16:16:53 +01:00 committed by Quim Muntal
parent 44d1d2e5ad
commit eb7ab11aaf

View file

@ -26,31 +26,23 @@ type hkdfKDF struct {
hash crypto.Hash hash crypto.Hash
} }
func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) []byte { func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) ([]byte, error) {
labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey)) labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey))
labeledIKM = append(labeledIKM, []byte("HPKE-v1")...) labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
labeledIKM = append(labeledIKM, sid...) labeledIKM = append(labeledIKM, sid...)
labeledIKM = append(labeledIKM, label...) labeledIKM = append(labeledIKM, label...)
labeledIKM = append(labeledIKM, inputKey...) labeledIKM = append(labeledIKM, inputKey...)
prk, err := hkdf.Extract(kdf.hash.New, labeledIKM, salt) return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
if err != nil {
panic(err)
}
return prk
} }
func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) []byte { func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) ([]byte, error) {
labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info)) labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info))
labeledInfo = byteorder.BEAppendUint16(labeledInfo, length) labeledInfo = byteorder.BEAppendUint16(labeledInfo, length)
labeledInfo = append(labeledInfo, []byte("HPKE-v1")...) labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
labeledInfo = append(labeledInfo, suiteID...) labeledInfo = append(labeledInfo, suiteID...)
labeledInfo = append(labeledInfo, label...) labeledInfo = append(labeledInfo, label...)
labeledInfo = append(labeledInfo, info...) labeledInfo = append(labeledInfo, info...)
key, err := hkdf.Expand(kdf.hash.New, randomKey, string(labeledInfo), int(length)) return hkdf.Expand(kdf.hash.New, randomKey, string(labeledInfo), int(length))
if err != nil {
panic(err)
}
return key
} }
// dhKEM implements the KEM specified in RFC 9180, Section 4.1. // dhKEM implements the KEM specified in RFC 9180, Section 4.1.
@ -88,8 +80,11 @@ func newDHKem(kemID uint16) (*dhKEM, error) {
}, nil }, nil
} }
func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) []byte { func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) ([]byte, error) {
eaePRK := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey) eaePRK, err := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey)
if err != nil {
return nil, err
}
return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret) return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret)
} }
@ -111,8 +106,11 @@ func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encap
encPubRecip := pubRecipient.Bytes() encPubRecip := pubRecipient.Bytes()
kemContext := append(encPubEph, encPubRecip...) kemContext := append(encPubEph, encPubRecip...)
sharedSecret, err = dh.ExtractAndExpand(dhVal, kemContext)
return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil if err != nil {
return nil, nil, err
}
return sharedSecret, encPubEph, nil
} }
func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) { func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) {
@ -125,8 +123,7 @@ func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte,
return nil, err return nil, err
} }
kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...) kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...)
return dh.ExtractAndExpand(dhVal, kemContext)
return dh.ExtractAndExpand(dhVal, kemContext), nil
} }
type context struct { type context struct {
@ -201,16 +198,33 @@ func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (
return nil, errors.New("unsupported AEAD id") return nil, errors.New("unsupported AEAD id")
} }
pskIDHash := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil) pskIDHash, err := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil)
infoHash := kdf.LabeledExtract(sid, nil, "info_hash", info) if err != nil {
return nil, err
}
infoHash, err := kdf.LabeledExtract(sid, nil, "info_hash", info)
if err != nil {
return nil, err
}
ksContext := append([]byte{0}, pskIDHash...) ksContext := append([]byte{0}, pskIDHash...)
ksContext = append(ksContext, infoHash...) ksContext = append(ksContext, infoHash...)
secret := kdf.LabeledExtract(sid, sharedSecret, "secret", nil) secret, err := kdf.LabeledExtract(sid, sharedSecret, "secret", nil)
if err != nil {
key := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */) return nil, err
baseNonce := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */) }
exporterSecret := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/) key, err := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
if err != nil {
return nil, err
}
baseNonce, err := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
if err != nil {
return nil, err
}
exporterSecret, err := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
if err != nil {
return nil, err
}
aead, err := aeadInfo.aead(key) aead, err := aeadInfo.aead(key)
if err != nil { if err != nil {