crypto/tls: support crypto.MessageSigner private keys

Fixes #75656

Change-Id: I6bc71c80973765ef995d17b1450ea2026a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/724820
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Nicholas Husin <husin@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Nicholas Husin <nsh@golang.org>
This commit is contained in:
Filippo Valsorda 2025-11-26 21:11:35 +01:00 committed by Gopher Robot
parent 3fd9cb1895
commit 992ad55e3d
10 changed files with 214 additions and 117 deletions

View file

@ -0,0 +1,2 @@
If [Certificate.PrivateKey] implements [crypto.MessageSigner], its SignMessage
method is used instead of Sign in TLS 1.2 and later.

View file

@ -18,9 +18,13 @@ import (
"slices" "slices"
) )
// verifyHandshakeSignature verifies a signature against pre-hashed // verifyHandshakeSignature verifies a signature against unhashed handshake contents.
// (if required) handshake contents.
func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error { func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error {
if hashFunc != directSigning {
h := hashFunc.New()
h.Write(signed)
signed = h.Sum(nil)
}
switch sigType { switch sigType {
case signatureECDSA: case signatureECDSA:
pubKey, ok := pubkey.(*ecdsa.PublicKey) pubKey, ok := pubkey.(*ecdsa.PublicKey)
@ -61,6 +65,32 @@ func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc c
return nil return nil
} }
// verifyLegacyHandshakeSignature verifies a TLS 1.0 and 1.1 signature against
// pre-hashed handshake contents.
func verifyLegacyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, hashed, sig []byte) error {
switch sigType {
case signatureECDSA:
pubKey, ok := pubkey.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("expected an ECDSA public key, got %T", pubkey)
}
if !ecdsa.VerifyASN1(pubKey, hashed, sig) {
return errors.New("ECDSA verification failure")
}
case signaturePKCS1v15:
pubKey, ok := pubkey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
}
if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, sig); err != nil {
return err
}
default:
return errors.New("internal error: unknown signature type")
}
return nil
}
const ( const (
serverSignatureContext = "TLS 1.3, server CertificateVerify\x00" serverSignatureContext = "TLS 1.3, server CertificateVerify\x00"
clientSignatureContext = "TLS 1.3, client CertificateVerify\x00" clientSignatureContext = "TLS 1.3, client CertificateVerify\x00"
@ -77,21 +107,15 @@ var signaturePadding = []byte{
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
} }
// signedMessage returns the pre-hashed (if necessary) message to be signed by // signedMessage returns the (unhashed) message to be signed by certificate keys
// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3. // in TLS 1.3. See RFC 8446, Section 4.4.3.
func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte { func signedMessage(context string, transcript hash.Hash) []byte {
if sigHash == directSigning { const maxSize = 64 /* signaturePadding */ + len(serverSignatureContext) + 512/8 /* SHA-512 */
b := &bytes.Buffer{} b := bytes.NewBuffer(make([]byte, 0, maxSize))
b.Write(signaturePadding) b.Write(signaturePadding)
io.WriteString(b, context) io.WriteString(b, context)
b.Write(transcript.Sum(nil)) b.Write(transcript.Sum(nil))
return b.Bytes() return b.Bytes()
}
h := sigHash.New()
h.Write(signaturePadding)
io.WriteString(h, context)
h.Write(transcript.Sum(nil))
return h.Sum(nil)
} }
// typeAndHashFromSignatureScheme returns the corresponding signature type and // typeAndHashFromSignatureScheme returns the corresponding signature type and

View file

@ -1597,9 +1597,14 @@ var writerMutex sync.Mutex
type Certificate struct { type Certificate struct {
Certificate [][]byte Certificate [][]byte
// PrivateKey contains the private key corresponding to the public key in // PrivateKey contains the private key corresponding to the public key in
// Leaf. This must implement crypto.Signer with an RSA, ECDSA or Ed25519 PublicKey. // Leaf. This must implement [crypto.Signer] with an RSA, ECDSA or Ed25519
// PublicKey.
//
// For a server up to TLS 1.2, it can also implement crypto.Decrypter with // For a server up to TLS 1.2, it can also implement crypto.Decrypter with
// an RSA PublicKey. // an RSA PublicKey.
//
// If it implements [crypto.MessageSigner], SignMessage will be used instead
// of Sign for TLS 1.2 and later.
PrivateKey crypto.PrivateKey PrivateKey crypto.PrivateKey
// SupportedSignatureAlgorithms is an optional list restricting what // SupportedSignatureAlgorithms is an optional list restricting what
// signature algorithms the PrivateKey can be used for. // signature algorithms the PrivateKey can be used for.

View file

@ -781,15 +781,13 @@ func (hs *clientHandshakeState) doFullHandshake() error {
return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey)
} }
var sigType uint8
var sigHash crypto.Hash
if c.vers >= VersionTLS12 { if c.vers >= VersionTLS12 {
signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms) signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms)
if err != nil { if err != nil {
c.sendAlert(alertHandshakeFailure) c.sendAlert(alertHandshakeFailure)
return err return err
} }
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) sigType, sigHash, err := typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil { if err != nil {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
@ -799,23 +797,31 @@ func (hs *clientHandshakeState) doFullHandshake() error {
tlssha1.Value() // ensure godebug is initialized tlssha1.Value() // ensure godebug is initialized
tlssha1.IncNonDefault() tlssha1.IncNonDefault()
} }
if hs.finishedHash.buffer == nil {
c.sendAlert(alertInternalError)
return errors.New("tls: internal error: did not keep handshake transcript for TLS 1.2")
}
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
certVerify.signature, err = crypto.SignMessage(key, c.config.rand(), hs.finishedHash.buffer, signOpts)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
} else { } else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public()) sigType, sigHash, err := legacyTypeAndHashFromPublicKey(key.Public())
if err != nil { if err != nil {
c.sendAlert(alertIllegalParameter) c.sendAlert(alertIllegalParameter)
return err return err
} }
} signed := hs.finishedHash.hashForClientCertificate(sigType)
certVerify.signature, err = key.Sign(c.config.rand(), signed, sigHash)
signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash) if err != nil {
signOpts := crypto.SignerOpts(sigHash) c.sendAlert(alertInternalError)
if sigType == signatureRSAPSS { return err
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} }
}
certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts)
if err != nil {
c.sendAlert(alertInternalError)
return err
} }
if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil { if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil {

View file

@ -664,7 +664,7 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) signed := signedMessage(serverSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
sigHash, signed, certVerify.signature); err != nil { sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError) c.sendAlert(alertDecryptError)
@ -783,12 +783,12 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) signed := signedMessage(clientSignatureContext, hs.transcript)
signOpts := crypto.SignerOpts(sigHash) signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS { if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
} }
sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) sig, err := crypto.SignMessage(cert.PrivateKey.(crypto.Signer), c.config.rand(), signed, signOpts)
if err != nil { if err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return errors.New("tls: failed to sign handshake: " + err.Error()) return errors.New("tls: failed to sign handshake: " + err.Error())

View file

@ -780,19 +780,27 @@ func (hs *serverHandshakeState) doFullHandshake() error {
tlssha1.Value() // ensure godebug is initialized tlssha1.Value() // ensure godebug is initialized
tlssha1.IncNonDefault() tlssha1.IncNonDefault()
} }
if hs.finishedHash.buffer == nil {
c.sendAlert(alertInternalError)
return errors.New("tls: internal error: did not keep handshake transcript for TLS 1.2")
}
if err := verifyHandshakeSignature(sigType, pub, sigHash, hs.finishedHash.buffer, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
} else { } else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub) sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub)
if err != nil { if err != nil {
c.sendAlert(alertIllegalParameter) c.sendAlert(alertIllegalParameter)
return err return err
} }
signed := hs.finishedHash.hashForClientCertificate(sigType)
if err := verifyLegacyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
} }
signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash)
if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
c.peerSigAlg = certVerify.signatureAlgorithm c.peerSigAlg = certVerify.signatureAlgorithm
if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil { if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil {

View file

@ -845,12 +845,12 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) signed := signedMessage(serverSignatureContext, hs.transcript)
signOpts := crypto.SignerOpts(sigHash) signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS { if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
} }
sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) sig, err := crypto.SignMessage(hs.cert.PrivateKey.(crypto.Signer), c.config.rand(), signed, signOpts)
if err != nil { if err != nil {
public := hs.cert.PrivateKey.(crypto.Signer).Public() public := hs.cert.PrivateKey.(crypto.Signer).Public()
if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS && if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS &&
@ -1081,7 +1081,7 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
return c.sendAlert(alertInternalError) return c.sendAlert(alertInternalError)
} }
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) signed := signedMessage(clientSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
sigHash, signed, certVerify.signature); err != nil { sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError) c.sendAlert(alertDecryptError)

View file

@ -127,25 +127,8 @@ func md5SHA1Hash(slices [][]byte) []byte {
} }
// hashForServerKeyExchange hashes the given slices and returns their digest // hashForServerKeyExchange hashes the given slices and returns their digest
// using the given hash function (for TLS 1.2) or using a default based on // using a hash based on the sigType. It can only be used for TLS 1.0 and 1.1.
// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't func hashForServerKeyExchange(sigType uint8, slices ...[]byte) []byte {
// do pre-hashing, it returns the concatenation of the slices.
func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte {
if sigType == signatureEd25519 {
var signed []byte
for _, slice := range slices {
signed = append(signed, slice...)
}
return signed
}
if version >= VersionTLS12 {
h := hashFunc.New()
for _, slice := range slices {
h.Write(slice)
}
digest := h.Sum(nil)
return digest
}
if sigType == signatureECDSA { if sigType == signatureECDSA {
return sha1Hash(slices) return sha1Hash(slices)
} }
@ -207,14 +190,13 @@ func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Cer
return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey) return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey)
} }
var sigType uint8 var sig []byte
var sigHash crypto.Hash
if ka.version >= VersionTLS12 { if ka.version >= VersionTLS12 {
ka.signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) ka.signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sigType, sigHash, err = typeAndHashFromSignatureScheme(ka.signatureAlgorithm) sigType, sigHash, err := typeAndHashFromSignatureScheme(ka.signatureAlgorithm)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -222,25 +204,31 @@ func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Cer
tlssha1.Value() // ensure godebug is initialized tlssha1.Value() // ensure godebug is initialized
tlssha1.IncNonDefault() tlssha1.IncNonDefault()
} }
signed := slices.Concat(clientHello.random, hello.random, serverECDHEParams)
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
}
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err = crypto.SignMessage(priv, config.rand(), signed, signOpts)
if err != nil {
return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
}
} else { } else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public()) sigType, sigHash, err := legacyTypeAndHashFromPublicKey(priv.Public())
if err != nil { if err != nil {
return nil, err return nil, err
} }
} signed := hashForServerKeyExchange(sigType, clientHello.random, hello.random, serverECDHEParams)
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { if (sigType == signaturePKCS1v15) != ka.isRSA {
return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
} }
sig, err = priv.Sign(config.rand(), signed, sigHash)
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams) if err != nil {
return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
signOpts := crypto.SignerOpts(sigHash) }
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := priv.Sign(config.rand(), signed, signOpts)
if err != nil {
return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
} }
skx := new(serverKeyExchangeMsg) skx := new(serverKeyExchangeMsg)
@ -300,6 +288,18 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell
if len(sig) < 2 { if len(sig) < 2 {
return errServerKeyExchange return errServerKeyExchange
} }
if ka.version >= VersionTLS12 {
ka.signatureAlgorithm = SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
sig = sig[2:]
if len(sig) < 2 {
return errServerKeyExchange
}
}
sigLen := int(sig[0])<<8 | int(sig[1])
if sigLen+2 != len(sig) {
return errServerKeyExchange
}
sig = sig[2:]
if !slices.Contains(clientHello.supportedCurves, ka.curveID) { if !slices.Contains(clientHello.supportedCurves, ka.curveID) {
return errors.New("tls: server selected unoffered curve") return errors.New("tls: server selected unoffered curve")
@ -333,12 +333,6 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell
var sigType uint8 var sigType uint8
var sigHash crypto.Hash var sigHash crypto.Hash
if ka.version >= VersionTLS12 { if ka.version >= VersionTLS12 {
ka.signatureAlgorithm = SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
sig = sig[2:]
if len(sig) < 2 {
return errServerKeyExchange
}
if !isSupportedSignatureAlgorithm(ka.signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { if !isSupportedSignatureAlgorithm(ka.signatureAlgorithm, clientHello.supportedSignatureAlgorithms) {
return errors.New("tls: certificate used with invalid signature algorithm") return errors.New("tls: certificate used with invalid signature algorithm")
} }
@ -350,26 +344,27 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell
tlssha1.Value() // ensure godebug is initialized tlssha1.Value() // ensure godebug is initialized
tlssha1.IncNonDefault() tlssha1.IncNonDefault()
} }
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
return errServerKeyExchange
}
signed := slices.Concat(clientHello.random, serverHello.random, serverECDHEParams)
if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil {
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
} else { } else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey) sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey)
if err != nil { if err != nil {
return err return err
} }
} if (sigType == signaturePKCS1v15) != ka.isRSA {
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { return errServerKeyExchange
return errServerKeyExchange }
signed := hashForServerKeyExchange(sigType, clientHello.random, serverHello.random, serverECDHEParams)
if err := verifyLegacyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil {
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
} }
sigLen := int(sig[0])<<8 | int(sig[1])
if sigLen+2 != len(sig) {
return errServerKeyExchange
}
sig = sig[2:]
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams)
if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil {
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
return nil return nil
} }

View file

@ -221,23 +221,9 @@ func (h finishedHash) serverSum(masterSecret []byte) []byte {
return h.prf(masterSecret, serverFinishedLabel, h.Sum(), finishedVerifyLength) return h.prf(masterSecret, serverFinishedLabel, h.Sum(), finishedVerifyLength)
} }
// hashForClientCertificate returns the handshake messages so far, pre-hashed if // hashForClientCertificate returns the handshake messages so far, pre-hashed,
// necessary, suitable for signing by a TLS client certificate. // suitable for signing by a TLS 1.0 and 1.1 client certificate.
func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash) []byte { func (h finishedHash) hashForClientCertificate(sigType uint8) []byte {
if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil {
panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer")
}
if sigType == signatureEd25519 {
return h.buffer
}
if h.version >= VersionTLS12 {
hash := hashAlg.New()
hash.Write(h.buffer)
return hash.Sum(nil)
}
if sigType == signatureECDSA { if sigType == signatureECDSA {
return h.server.Sum(nil) return h.server.Sum(nil)
} }

View file

@ -2390,3 +2390,74 @@ func TestECH(t *testing.T) {
check() check()
} }
func TestMessageSigner(t *testing.T) {
t.Run("TLSv10", func(t *testing.T) { testMessageSigner(t, VersionTLS10) })
t.Run("TLSv12", func(t *testing.T) { testMessageSigner(t, VersionTLS12) })
t.Run("TLSv13", func(t *testing.T) { testMessageSigner(t, VersionTLS13) })
}
func testMessageSigner(t *testing.T, version uint16) {
clientConfig, serverConfig := testConfig.Clone(), testConfig.Clone()
serverConfig.ClientAuth = RequireAnyClientCert
clientConfig.MinVersion = version
clientConfig.MaxVersion = version
serverConfig.MinVersion = version
serverConfig.MaxVersion = version
clientConfig.Certificates = []Certificate{{
Certificate: [][]byte{testRSACertificate},
PrivateKey: messageOnlySigner{testRSAPrivateKey},
}}
serverConfig.Certificates = []Certificate{{
Certificate: [][]byte{testRSACertificate},
PrivateKey: messageOnlySigner{testRSAPrivateKey},
}}
_, _, err := testHandshake(t, clientConfig, serverConfig)
if version < VersionTLS12 {
if err == nil {
t.Fatal("expected failure for TLS 1.0/1.1")
}
} else {
if err != nil {
t.Fatalf("unexpected failure: %s", err)
}
}
clientConfig.Certificates = []Certificate{{
Certificate: [][]byte{testECDSACertificate},
PrivateKey: messageOnlySigner{testECDSAPrivateKey},
}}
serverConfig.Certificates = []Certificate{{
Certificate: [][]byte{testECDSACertificate},
PrivateKey: messageOnlySigner{testECDSAPrivateKey},
}}
_, _, err = testHandshake(t, clientConfig, serverConfig)
if version < VersionTLS12 {
if err == nil {
t.Fatal("expected failure for TLS 1.0/1.1")
}
} else {
if err != nil {
t.Fatalf("unexpected failure: %s", err)
}
}
}
type messageOnlySigner struct{ crypto.Signer }
func (s messageOnlySigner) Public() crypto.PublicKey {
return s.Signer.Public()
}
func (s messageOnlySigner) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) (signature []byte, err error) {
return nil, errors.New("messageOnlySigner: Sign called")
}
func (s messageOnlySigner) SignMessage(rand io.Reader, msg []byte, opts crypto.SignerOpts) (signature []byte, err error) {
h := opts.HashFunc().New()
h.Write(msg)
digest := h.Sum(nil)
return s.Signer.Sign(rand, digest, opts)
}