crypto/tls: add Config.GetConfigForClient

GetConfigForClient allows the tls.Config to be updated on a per-client
basis.

Fixes #16066.
Fixes #15707.
Fixes #15699.

Change-Id: I2c675a443d557f969441226729f98502b38901ea
Reviewed-on: https://go-review.googlesource.com/30790
Run-TryBot: Adam Langley <agl@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Adam Langley 2016-10-10 15:27:34 -07:00 committed by Brad Fitzpatrick
parent 7e2bf952a9
commit cff3e7587a
4 changed files with 225 additions and 32 deletions

View file

@ -1141,6 +1141,141 @@ func TestSNIGivenOnFailure(t *testing.T) {
}
}
var getConfigForClientTests = []struct {
setup func(config *Config)
callback func(clientHello *ClientHelloInfo) (*Config, error)
errorSubstring string
verify func(config *Config) error
}{
{
nil,
func(clientHello *ClientHelloInfo) (*Config, error) {
return nil, nil
},
"",
nil,
},
{
nil,
func(clientHello *ClientHelloInfo) (*Config, error) {
return nil, errors.New("should bubble up")
},
"should bubble up",
nil,
},
{
nil,
func(clientHello *ClientHelloInfo) (*Config, error) {
config := testConfig.Clone()
// Setting a maximum version of TLS 1.1 should cause
// the handshake to fail.
config.MaxVersion = VersionTLS11
return config, nil
},
"version 301 when expecting version 302",
nil,
},
{
func(config *Config) {
for i := range config.SessionTicketKey {
config.SessionTicketKey[i] = byte(i)
}
config.sessionTicketKeys = nil
},
func(clientHello *ClientHelloInfo) (*Config, error) {
config := testConfig.Clone()
for i := range config.SessionTicketKey {
config.SessionTicketKey[i] = 0
}
config.sessionTicketKeys = nil
return config, nil
},
"",
func(config *Config) error {
// The value of SessionTicketKey should have been
// duplicated into the per-connection Config.
for i := range config.SessionTicketKey {
if b := config.SessionTicketKey[i]; b != byte(i) {
return fmt.Errorf("SessionTicketKey was not duplicated from original Config: byte %d has value %d", i, b)
}
}
return nil
},
},
{
func(config *Config) {
var dummyKey [32]byte
for i := range dummyKey {
dummyKey[i] = byte(i)
}
config.SetSessionTicketKeys([][32]byte{dummyKey})
},
func(clientHello *ClientHelloInfo) (*Config, error) {
config := testConfig.Clone()
config.sessionTicketKeys = nil
return config, nil
},
"",
func(config *Config) error {
// The session ticket keys should have been duplicated
// into the per-connection Config.
if l := len(config.sessionTicketKeys); l != 1 {
return fmt.Errorf("got len(sessionTicketKeys) == %d, wanted 1", l)
}
return nil
},
},
}
func TestGetConfigForClient(t *testing.T) {
serverConfig := testConfig.Clone()
clientConfig := testConfig.Clone()
clientConfig.MinVersion = VersionTLS12
for i, test := range getConfigForClientTests {
if test.setup != nil {
test.setup(serverConfig)
}
var configReturned *Config
serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) {
config, err := test.callback(clientHello)
configReturned = config
return config, err
}
c, s := net.Pipe()
done := make(chan error)
go func() {
defer s.Close()
done <- Server(s, serverConfig).Handshake()
}()
clientErr := Client(c, clientConfig).Handshake()
c.Close()
serverErr := <-done
if len(test.errorSubstring) == 0 {
if serverErr != nil || clientErr != nil {
t.Errorf("%#d: expected no error but got serverErr: %q, clientErr: %q", i, serverErr, clientErr)
}
if test.verify != nil {
if err := test.verify(configReturned); err != nil {
t.Errorf("#%d: verify returned error: %v", i, err)
}
}
} else {
if serverErr == nil {
t.Errorf("%#d: expected error containing %q but got no error", i, test.errorSubstring)
} else if !strings.Contains(serverErr.Error(), test.errorSubstring) {
t.Errorf("%#d: expected error to contain %q but it was %q", i, test.errorSubstring, serverErr)
}
}
}
}
func bigFromString(s string) *big.Int {
ret := new(big.Int)
ret.SetString(s, 10)