mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
crypto/tls: add Config.GetConfigForClient
GetConfigForClient allows the tls.Config to be updated on a per-client basis. Fixes #16066. Fixes #15707. Fixes #15699. Change-Id: I2c675a443d557f969441226729f98502b38901ea Reviewed-on: https://go-review.googlesource.com/30790 Run-TryBot: Adam Langley <agl@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
7e2bf952a9
commit
cff3e7587a
4 changed files with 225 additions and 32 deletions
|
|
@ -1141,6 +1141,141 @@ func TestSNIGivenOnFailure(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
var getConfigForClientTests = []struct {
|
||||
setup func(config *Config)
|
||||
callback func(clientHello *ClientHelloInfo) (*Config, error)
|
||||
errorSubstring string
|
||||
verify func(config *Config) error
|
||||
}{
|
||||
{
|
||||
nil,
|
||||
func(clientHello *ClientHelloInfo) (*Config, error) {
|
||||
return nil, nil
|
||||
},
|
||||
"",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
nil,
|
||||
func(clientHello *ClientHelloInfo) (*Config, error) {
|
||||
return nil, errors.New("should bubble up")
|
||||
},
|
||||
"should bubble up",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
nil,
|
||||
func(clientHello *ClientHelloInfo) (*Config, error) {
|
||||
config := testConfig.Clone()
|
||||
// Setting a maximum version of TLS 1.1 should cause
|
||||
// the handshake to fail.
|
||||
config.MaxVersion = VersionTLS11
|
||||
return config, nil
|
||||
},
|
||||
"version 301 when expecting version 302",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
func(config *Config) {
|
||||
for i := range config.SessionTicketKey {
|
||||
config.SessionTicketKey[i] = byte(i)
|
||||
}
|
||||
config.sessionTicketKeys = nil
|
||||
},
|
||||
func(clientHello *ClientHelloInfo) (*Config, error) {
|
||||
config := testConfig.Clone()
|
||||
for i := range config.SessionTicketKey {
|
||||
config.SessionTicketKey[i] = 0
|
||||
}
|
||||
config.sessionTicketKeys = nil
|
||||
return config, nil
|
||||
},
|
||||
"",
|
||||
func(config *Config) error {
|
||||
// The value of SessionTicketKey should have been
|
||||
// duplicated into the per-connection Config.
|
||||
for i := range config.SessionTicketKey {
|
||||
if b := config.SessionTicketKey[i]; b != byte(i) {
|
||||
return fmt.Errorf("SessionTicketKey was not duplicated from original Config: byte %d has value %d", i, b)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
func(config *Config) {
|
||||
var dummyKey [32]byte
|
||||
for i := range dummyKey {
|
||||
dummyKey[i] = byte(i)
|
||||
}
|
||||
|
||||
config.SetSessionTicketKeys([][32]byte{dummyKey})
|
||||
},
|
||||
func(clientHello *ClientHelloInfo) (*Config, error) {
|
||||
config := testConfig.Clone()
|
||||
config.sessionTicketKeys = nil
|
||||
return config, nil
|
||||
},
|
||||
"",
|
||||
func(config *Config) error {
|
||||
// The session ticket keys should have been duplicated
|
||||
// into the per-connection Config.
|
||||
if l := len(config.sessionTicketKeys); l != 1 {
|
||||
return fmt.Errorf("got len(sessionTicketKeys) == %d, wanted 1", l)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestGetConfigForClient(t *testing.T) {
|
||||
serverConfig := testConfig.Clone()
|
||||
clientConfig := testConfig.Clone()
|
||||
clientConfig.MinVersion = VersionTLS12
|
||||
|
||||
for i, test := range getConfigForClientTests {
|
||||
if test.setup != nil {
|
||||
test.setup(serverConfig)
|
||||
}
|
||||
|
||||
var configReturned *Config
|
||||
serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) {
|
||||
config, err := test.callback(clientHello)
|
||||
configReturned = config
|
||||
return config, err
|
||||
}
|
||||
c, s := net.Pipe()
|
||||
done := make(chan error)
|
||||
|
||||
go func() {
|
||||
defer s.Close()
|
||||
done <- Server(s, serverConfig).Handshake()
|
||||
}()
|
||||
|
||||
clientErr := Client(c, clientConfig).Handshake()
|
||||
c.Close()
|
||||
|
||||
serverErr := <-done
|
||||
|
||||
if len(test.errorSubstring) == 0 {
|
||||
if serverErr != nil || clientErr != nil {
|
||||
t.Errorf("%#d: expected no error but got serverErr: %q, clientErr: %q", i, serverErr, clientErr)
|
||||
}
|
||||
if test.verify != nil {
|
||||
if err := test.verify(configReturned); err != nil {
|
||||
t.Errorf("#%d: verify returned error: %v", i, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if serverErr == nil {
|
||||
t.Errorf("%#d: expected error containing %q but got no error", i, test.errorSubstring)
|
||||
} else if !strings.Contains(serverErr.Error(), test.errorSubstring) {
|
||||
t.Errorf("%#d: expected error to contain %q but it was %q", i, test.errorSubstring, serverErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func bigFromString(s string) *big.Int {
|
||||
ret := new(big.Int)
|
||||
ret.SetString(s, 10)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue