mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
crypto/tls: Client side support for TLS session resumption.
Adam (agl@) had already done an initial review of this CL in a branch. Added ClientSessionState to Config which now allows clients to keep state required to resume a TLS session with a server. A client handshake will try and use the SessionTicket/MasterSecret in this cached state if the server acknowledged resumption. We also added support to cache ClientSessionState object in Config that will be looked up by server remote address during the handshake. R=golang-codereviews, agl, rsc, agl, agl, bradfitz, mikioh.mikioh CC=golang-codereviews https://golang.org/cl/15680043
This commit is contained in:
parent
021c11683c
commit
988ffc0fe2
6 changed files with 435 additions and 70 deletions
|
|
@ -13,9 +13,20 @@ import (
|
|||
"encoding/asn1"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type clientHandshakeState struct {
|
||||
c *Conn
|
||||
serverHello *serverHelloMsg
|
||||
hello *clientHelloMsg
|
||||
suite *cipherSuite
|
||||
finishedHash finishedHash
|
||||
masterSecret []byte
|
||||
session *ClientSessionState
|
||||
}
|
||||
|
||||
func (c *Conn) clientHandshake() error {
|
||||
if c.config == nil {
|
||||
c.config = defaultConfig()
|
||||
|
|
@ -60,13 +71,58 @@ NextCipherSuite:
|
|||
_, err := io.ReadFull(c.config.rand(), hello.random[4:])
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return errors.New("short read from Rand")
|
||||
return errors.New("tls: short read from Rand: " + err.Error())
|
||||
}
|
||||
|
||||
if hello.vers >= VersionTLS12 {
|
||||
hello.signatureAndHashes = supportedSKXSignatureAlgorithms
|
||||
}
|
||||
|
||||
var session *ClientSessionState
|
||||
var cacheKey string
|
||||
sessionCache := c.config.ClientSessionCache
|
||||
if c.config.SessionTicketsDisabled {
|
||||
sessionCache = nil
|
||||
}
|
||||
|
||||
if sessionCache != nil {
|
||||
hello.ticketSupported = true
|
||||
|
||||
// Try to resume a previously negotiated TLS session, if
|
||||
// available.
|
||||
cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
|
||||
candidateSession, ok := sessionCache.Get(cacheKey)
|
||||
if ok {
|
||||
// Check that the ciphersuite/version used for the
|
||||
// previous session are still valid.
|
||||
cipherSuiteOk := false
|
||||
for _, id := range hello.cipherSuites {
|
||||
if id == candidateSession.cipherSuite {
|
||||
cipherSuiteOk = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
versOk := candidateSession.vers >= c.config.minVersion() &&
|
||||
candidateSession.vers <= c.config.maxVersion()
|
||||
if versOk && cipherSuiteOk {
|
||||
session = candidateSession
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
hello.sessionTicket = session.sessionTicket
|
||||
// A random session ID is used to detect when the
|
||||
// server accepted the ticket and is resuming a session
|
||||
// (see RFC 5077).
|
||||
hello.sessionId = make([]byte, 16)
|
||||
if _, err := io.ReadFull(c.config.rand(), hello.sessionId); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return errors.New("tls: short read from Rand: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
c.writeRecord(recordTypeHandshake, hello.marshal())
|
||||
|
||||
msg, err := c.readHandshake()
|
||||
|
|
@ -86,25 +142,73 @@ NextCipherSuite:
|
|||
c.vers = vers
|
||||
c.haveVers = true
|
||||
|
||||
finishedHash := newFinishedHash(c.vers)
|
||||
finishedHash.Write(hello.marshal())
|
||||
finishedHash.Write(serverHello.marshal())
|
||||
|
||||
if serverHello.compressionMethod != compressionNone {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
|
||||
if !hello.nextProtoNeg && serverHello.nextProtoNeg {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return errors.New("server advertised unrequested NPN")
|
||||
}
|
||||
|
||||
suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
|
||||
if suite == nil {
|
||||
return c.sendAlert(alertHandshakeFailure)
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake()
|
||||
hs := &clientHandshakeState{
|
||||
c: c,
|
||||
serverHello: serverHello,
|
||||
hello: hello,
|
||||
suite: suite,
|
||||
finishedHash: newFinishedHash(c.vers),
|
||||
session: session,
|
||||
}
|
||||
|
||||
hs.finishedHash.Write(hs.hello.marshal())
|
||||
hs.finishedHash.Write(hs.serverHello.marshal())
|
||||
|
||||
isResume, err := hs.processServerHello()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isResume {
|
||||
if err := hs.establishKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readSessionTicket(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := hs.doFullHandshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.establishKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readSessionTicket(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if sessionCache != nil && hs.session != nil && session != hs.session {
|
||||
sessionCache.Put(cacheKey, hs.session)
|
||||
}
|
||||
|
||||
c.didResume = isResume
|
||||
c.handshakeComplete = true
|
||||
c.cipherSuite = suite.id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) doFullHandshake() error {
|
||||
c := hs.c
|
||||
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -112,7 +216,7 @@ NextCipherSuite:
|
|||
if !ok || len(certMsg.certificates) == 0 {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
finishedHash.Write(certMsg.marshal())
|
||||
hs.finishedHash.Write(certMsg.marshal())
|
||||
|
||||
certs := make([]*x509.Certificate, len(certMsg.certificates))
|
||||
for i, asn1Data := range certMsg.certificates {
|
||||
|
|
@ -154,7 +258,7 @@ NextCipherSuite:
|
|||
|
||||
c.peerCertificates = certs
|
||||
|
||||
if serverHello.ocspStapling {
|
||||
if hs.serverHello.ocspStapling {
|
||||
msg, err = c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -163,7 +267,7 @@ NextCipherSuite:
|
|||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
finishedHash.Write(cs.marshal())
|
||||
hs.finishedHash.Write(cs.marshal())
|
||||
|
||||
if cs.statusType == statusTypeOCSP {
|
||||
c.ocspResponse = cs.response
|
||||
|
|
@ -175,12 +279,12 @@ NextCipherSuite:
|
|||
return err
|
||||
}
|
||||
|
||||
keyAgreement := suite.ka(c.vers)
|
||||
keyAgreement := hs.suite.ka(c.vers)
|
||||
|
||||
skx, ok := msg.(*serverKeyExchangeMsg)
|
||||
if ok {
|
||||
finishedHash.Write(skx.marshal())
|
||||
err = keyAgreement.processServerKeyExchange(c.config, hello, serverHello, certs[0], skx)
|
||||
hs.finishedHash.Write(skx.marshal())
|
||||
err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, certs[0], skx)
|
||||
if err != nil {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return err
|
||||
|
|
@ -209,7 +313,7 @@ NextCipherSuite:
|
|||
// ClientCertificateType, unless there is some external
|
||||
// arrangement to the contrary.
|
||||
|
||||
finishedHash.Write(certReq.marshal())
|
||||
hs.finishedHash.Write(certReq.marshal())
|
||||
|
||||
var rsaAvail, ecdsaAvail bool
|
||||
for _, certType := range certReq.certificateTypes {
|
||||
|
|
@ -274,7 +378,7 @@ NextCipherSuite:
|
|||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
finishedHash.Write(shd.marshal())
|
||||
hs.finishedHash.Write(shd.marshal())
|
||||
|
||||
// If the server requested a certificate then we have to send a
|
||||
// Certificate message, even if it's empty because we don't have a
|
||||
|
|
@ -284,17 +388,17 @@ NextCipherSuite:
|
|||
if chainToSend != nil {
|
||||
certMsg.certificates = chainToSend.Certificate
|
||||
}
|
||||
finishedHash.Write(certMsg.marshal())
|
||||
hs.finishedHash.Write(certMsg.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certMsg.marshal())
|
||||
}
|
||||
|
||||
preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hello, certs[0])
|
||||
preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0])
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
if ckx != nil {
|
||||
finishedHash.Write(ckx.marshal())
|
||||
hs.finishedHash.Write(ckx.marshal())
|
||||
c.writeRecord(recordTypeHandshake, ckx.marshal())
|
||||
}
|
||||
|
||||
|
|
@ -306,7 +410,7 @@ NextCipherSuite:
|
|||
|
||||
switch key := c.config.Certificates[0].PrivateKey.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
digest, _, hashId := finishedHash.hashForClientCertificate(signatureECDSA)
|
||||
digest, _, hashId := hs.finishedHash.hashForClientCertificate(signatureECDSA)
|
||||
r, s, err := ecdsa.Sign(c.config.rand(), key, digest)
|
||||
if err == nil {
|
||||
signed, err = asn1.Marshal(ecdsaSignature{r, s})
|
||||
|
|
@ -314,7 +418,7 @@ NextCipherSuite:
|
|||
certVerify.signatureAndHash.signature = signatureECDSA
|
||||
certVerify.signatureAndHash.hash = hashId
|
||||
case *rsa.PrivateKey:
|
||||
digest, hashFunc, hashId := finishedHash.hashForClientCertificate(signatureRSA)
|
||||
digest, hashFunc, hashId := hs.finishedHash.hashForClientCertificate(signatureRSA)
|
||||
signed, err = rsa.SignPKCS1v15(c.config.rand(), key, hashFunc, digest)
|
||||
certVerify.signatureAndHash.signature = signatureRSA
|
||||
certVerify.signatureAndHash.hash = hashId
|
||||
|
|
@ -326,56 +430,73 @@ NextCipherSuite:
|
|||
}
|
||||
certVerify.signature = signed
|
||||
|
||||
finishedHash.Write(certVerify.marshal())
|
||||
hs.finishedHash.Write(certVerify.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certVerify.marshal())
|
||||
}
|
||||
|
||||
masterSecret := masterFromPreMasterSecret(c.vers, preMasterSecret, hello.random, serverHello.random)
|
||||
hs.masterSecret = masterFromPreMasterSecret(c.vers, preMasterSecret, hs.hello.random, hs.serverHello.random)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) establishKeys() error {
|
||||
c := hs.c
|
||||
|
||||
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
|
||||
keysFromMasterSecret(c.vers, masterSecret, hello.random, serverHello.random, suite.macLen, suite.keyLen, suite.ivLen)
|
||||
|
||||
var clientCipher interface{}
|
||||
var clientHash macFunction
|
||||
if suite.cipher != nil {
|
||||
clientCipher = suite.cipher(clientKey, clientIV, false /* not for reading */)
|
||||
clientHash = suite.mac(c.vers, clientMAC)
|
||||
keysFromMasterSecret(c.vers, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
|
||||
var clientCipher, serverCipher interface{}
|
||||
var clientHash, serverHash macFunction
|
||||
if hs.suite.cipher != nil {
|
||||
clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */)
|
||||
clientHash = hs.suite.mac(c.vers, clientMAC)
|
||||
serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */)
|
||||
serverHash = hs.suite.mac(c.vers, serverMAC)
|
||||
} else {
|
||||
clientCipher = suite.aead(clientKey, clientIV)
|
||||
}
|
||||
c.out.prepareCipherSpec(c.vers, clientCipher, clientHash)
|
||||
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
|
||||
|
||||
if serverHello.nextProtoNeg {
|
||||
nextProto := new(nextProtoMsg)
|
||||
proto, fallback := mutualProtocol(c.config.NextProtos, serverHello.nextProtos)
|
||||
nextProto.proto = proto
|
||||
c.clientProtocol = proto
|
||||
c.clientProtocolFallback = fallback
|
||||
|
||||
finishedHash.Write(nextProto.marshal())
|
||||
c.writeRecord(recordTypeHandshake, nextProto.marshal())
|
||||
clientCipher = hs.suite.aead(clientKey, clientIV)
|
||||
serverCipher = hs.suite.aead(serverKey, serverIV)
|
||||
}
|
||||
|
||||
finished := new(finishedMsg)
|
||||
finished.verifyData = finishedHash.clientSum(masterSecret)
|
||||
finishedHash.Write(finished.marshal())
|
||||
c.writeRecord(recordTypeHandshake, finished.marshal())
|
||||
|
||||
var serverCipher interface{}
|
||||
var serverHash macFunction
|
||||
if suite.cipher != nil {
|
||||
serverCipher = suite.cipher(serverKey, serverIV, true /* for reading */)
|
||||
serverHash = suite.mac(c.vers, serverMAC)
|
||||
} else {
|
||||
serverCipher = suite.aead(serverKey, serverIV)
|
||||
}
|
||||
c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
|
||||
c.out.prepareCipherSpec(c.vers, clientCipher, clientHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) serverResumedSession() bool {
|
||||
// If the server responded with the same sessionId then it means the
|
||||
// sessionTicket is being used to resume a TLS session.
|
||||
return hs.session != nil && hs.hello.sessionId != nil &&
|
||||
bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId)
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) processServerHello() (bool, error) {
|
||||
c := hs.c
|
||||
|
||||
if hs.serverHello.compressionMethod != compressionNone {
|
||||
return false, c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
|
||||
if !hs.hello.nextProtoNeg && hs.serverHello.nextProtoNeg {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return false, errors.New("server advertised unrequested NPN")
|
||||
}
|
||||
|
||||
if hs.serverResumedSession() {
|
||||
// Restore masterSecret and peerCerts from previous state
|
||||
hs.masterSecret = hs.session.masterSecret
|
||||
c.peerCertificates = hs.session.serverCertificates
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) readFinished() error {
|
||||
c := hs.c
|
||||
|
||||
c.readRecord(recordTypeChangeCipherSpec)
|
||||
if err := c.error(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake()
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -384,17 +505,73 @@ NextCipherSuite:
|
|||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
|
||||
verify := finishedHash.serverSum(masterSecret)
|
||||
verify := hs.finishedHash.serverSum(hs.masterSecret)
|
||||
if len(verify) != len(serverFinished.verifyData) ||
|
||||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
|
||||
return c.sendAlert(alertHandshakeFailure)
|
||||
}
|
||||
|
||||
c.handshakeComplete = true
|
||||
c.cipherSuite = suite.id
|
||||
hs.finishedHash.Write(serverFinished.marshal())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) readSessionTicket() error {
|
||||
if !hs.serverHello.ticketSupported {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := hs.c
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
|
||||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
hs.finishedHash.Write(sessionTicketMsg.marshal())
|
||||
|
||||
hs.session = &ClientSessionState{
|
||||
sessionTicket: sessionTicketMsg.ticket,
|
||||
vers: c.vers,
|
||||
cipherSuite: hs.suite.id,
|
||||
masterSecret: hs.masterSecret,
|
||||
serverCertificates: c.peerCertificates,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) sendFinished() error {
|
||||
c := hs.c
|
||||
|
||||
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
|
||||
if hs.serverHello.nextProtoNeg {
|
||||
nextProto := new(nextProtoMsg)
|
||||
proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos)
|
||||
nextProto.proto = proto
|
||||
c.clientProtocol = proto
|
||||
c.clientProtocolFallback = fallback
|
||||
|
||||
hs.finishedHash.Write(nextProto.marshal())
|
||||
c.writeRecord(recordTypeHandshake, nextProto.marshal())
|
||||
}
|
||||
|
||||
finished := new(finishedMsg)
|
||||
finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
|
||||
hs.finishedHash.Write(finished.marshal())
|
||||
c.writeRecord(recordTypeHandshake, finished.marshal())
|
||||
return nil
|
||||
}
|
||||
|
||||
// clientSessionCacheKey returns a key used to cache sessionTickets that could
|
||||
// be used to resume previously negotiated TLS sessions with a server.
|
||||
func clientSessionCacheKey(serverAddr net.Addr, config *Config) string {
|
||||
if len(config.ServerName) > 0 {
|
||||
return config.ServerName
|
||||
}
|
||||
return serverAddr.String()
|
||||
}
|
||||
|
||||
// mutualProtocol finds the mutual Next Protocol Negotiation protocol given the
|
||||
// set of client and server supported protocols. The set of client supported
|
||||
// protocols must not be empty. It returns the resulting protocol and flag
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue