crypto/tls: extensions and Next Protocol Negotiation

Add support for TLS extensions in general and Next Protocol
Negotiation in particular.

R=rsc
CC=golang-dev
https://golang.org/cl/181045
This commit is contained in:
Adam Langley 2009-12-23 11:13:09 -08:00
parent 7c9111434a
commit 9ebb59634e
9 changed files with 379 additions and 28 deletions

View file

@ -4,6 +4,8 @@
package tls
import "strings"
type clientHelloMsg struct {
raw []byte
major, minor uint8
@ -11,6 +13,8 @@ type clientHelloMsg struct {
sessionId []byte
cipherSuites []uint16
compressionMethods []uint8
nextProtoNeg bool
serverName string
}
func (m *clientHelloMsg) marshal() []byte {
@ -19,6 +23,20 @@ func (m *clientHelloMsg) marshal() []byte {
}
length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
numExtensions := 0
extensionsLength := 0
if m.nextProtoNeg {
numExtensions++
}
if len(m.serverName) > 0 {
extensionsLength += 5 + len(m.serverName)
numExtensions++
}
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
}
x := make([]byte, 4+length)
x[0] = typeClientHello
x[1] = uint8(length >> 16)
@ -39,6 +57,53 @@ func (m *clientHelloMsg) marshal() []byte {
z := y[2+len(m.cipherSuites)*2:]
z[0] = uint8(len(m.compressionMethods))
copy(z[1:], m.compressionMethods)
z = z[1+len(m.compressionMethods):]
if numExtensions > 0 {
z[0] = byte(extensionsLength >> 8)
z[1] = byte(extensionsLength)
z = z[2:]
}
if m.nextProtoNeg {
z[0] = byte(extensionNextProtoNeg >> 8)
z[1] = byte(extensionNextProtoNeg)
// The length is always 0
z = z[4:]
}
if len(m.serverName) > 0 {
z[0] = byte(extensionServerName >> 8)
z[1] = byte(extensionServerName)
l := len(m.serverName) + 5
z[2] = byte(l >> 8)
z[3] = byte(l)
z = z[4:]
// RFC 3546, section 3.1
//
// struct {
// NameType name_type;
// select (name_type) {
// case host_name: HostName;
// } name;
// } ServerName;
//
// enum {
// host_name(0), (255)
// } NameType;
//
// opaque HostName<1..2^16-1>;
//
// struct {
// ServerName server_name_list<1..2^16-1>
// } ServerNameList;
z[1] = 1
z[3] = byte(len(m.serverName) >> 8)
z[4] = byte(len(m.serverName))
copy(z[5:], strings.Bytes(m.serverName))
z = z[l:]
}
m.raw = x
return x
@ -82,7 +147,68 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
}
m.compressionMethods = data[1 : 1+compressionMethodsLen]
// A ClientHello may be following by trailing data: RFC 4346 section 7.4.1.2
data = data[1+compressionMethodsLen:]
m.nextProtoNeg = false
m.serverName = ""
if len(data) == 0 {
// ClientHello is optionally followed by extension data
return true
}
if len(data) < 2 {
return false
}
extensionsLength := int(data[0])<<8 | int(data[1])
data = data[2:]
if extensionsLength != len(data) {
return false
}
for len(data) != 0 {
if len(data) < 4 {
return false
}
extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
return false
}
switch extension {
case extensionServerName:
if length < 2 {
return false
}
numNames := int(data[0])<<8 | int(data[1])
d := data[2:]
for i := 0; i < numNames; i++ {
if len(d) < 3 {
return false
}
nameType := d[0]
nameLen := int(d[1])<<8 | int(d[2])
d = d[3:]
if len(d) < nameLen {
return false
}
if nameType == 0 {
m.serverName = string(d[0:nameLen])
break
}
d = d[nameLen:]
}
case extensionNextProtoNeg:
if length > 0 {
return false
}
m.nextProtoNeg = true
}
data = data[length:]
}
return true
}
@ -93,6 +219,8 @@ type serverHelloMsg struct {
sessionId []byte
cipherSuite uint16
compressionMethod uint8
nextProtoNeg bool
nextProtos []string
}
func (m *serverHelloMsg) marshal() []byte {
@ -101,6 +229,23 @@ func (m *serverHelloMsg) marshal() []byte {
}
length := 38 + len(m.sessionId)
numExtensions := 0
extensionsLength := 0
nextProtoLen := 0
if m.nextProtoNeg {
numExtensions++
for _, v := range m.nextProtos {
nextProtoLen += len(v)
}
nextProtoLen += len(m.nextProtos)
extensionsLength += nextProtoLen
}
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
}
x := make([]byte, 4+length)
x[0] = typeServerHello
x[1] = uint8(length >> 16)
@ -115,11 +260,49 @@ func (m *serverHelloMsg) marshal() []byte {
z[0] = uint8(m.cipherSuite >> 8)
z[1] = uint8(m.cipherSuite)
z[2] = uint8(m.compressionMethod)
z = z[3:]
if numExtensions > 0 {
z[0] = byte(extensionsLength >> 8)
z[1] = byte(extensionsLength)
z = z[2:]
}
if m.nextProtoNeg {
z[0] = byte(extensionNextProtoNeg >> 8)
z[1] = byte(extensionNextProtoNeg)
z[2] = byte(nextProtoLen >> 8)
z[3] = byte(nextProtoLen)
z = z[4:]
for _, v := range m.nextProtos {
l := len(v)
if l > 255 {
l = 255
}
z[0] = byte(l)
copy(z[1:], strings.Bytes(v[0:l]))
z = z[1+l:]
}
}
m.raw = x
return x
}
func append(slice []string, elem string) []string {
if len(slice) < cap(slice) {
slice = slice[0 : len(slice)+1]
slice[len(slice)-1] = elem
return slice
}
fresh := make([]string, len(slice)+1, cap(slice)*2+1)
copy(fresh, slice)
fresh[len(slice)] = elem
return fresh
}
func (m *serverHelloMsg) unmarshal(data []byte) bool {
if len(data) < 42 {
return false
@ -139,8 +322,53 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
}
m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
m.compressionMethod = data[2]
data = data[3:]
m.nextProtoNeg = false
m.nextProtos = nil
if len(data) == 0 {
// ServerHello is optionally followed by extension data
return true
}
if len(data) < 2 {
return false
}
extensionsLength := int(data[0])<<8 | int(data[1])
data = data[2:]
if len(data) != extensionsLength {
return false
}
for len(data) != 0 {
if len(data) < 4 {
return false
}
extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
return false
}
switch extension {
case extensionNextProtoNeg:
m.nextProtoNeg = true
d := data
for len(d) > 0 {
l := int(d[0])
d = d[1:]
if l == 0 || l > len(d) {
return false
}
m.nextProtos = append(m.nextProtos, string(d[0:l]))
d = d[l:]
}
}
data = data[length:]
}
// Trailing data is allowed because extensions may be present.
return true
}
@ -295,3 +523,63 @@ func (m *finishedMsg) unmarshal(data []byte) bool {
m.verifyData = data[4:]
return true
}
type nextProtoMsg struct {
raw []byte
proto string
}
func (m *nextProtoMsg) marshal() []byte {
if m.raw != nil {
return m.raw
}
l := len(m.proto)
if l > 255 {
l = 255
}
padding := 32 - (l+2)%32
length := l + padding + 2
x := make([]byte, length+4)
x[0] = typeNextProtocol
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
y := x[4:]
y[0] = byte(l)
copy(y[1:], strings.Bytes(m.proto[0:l]))
y = y[1+l:]
y[0] = byte(padding)
m.raw = x
return x
}
func (m *nextProtoMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 5 {
return false
}
data = data[4:]
protoLen := int(data[0])
data = data[1:]
if len(data) < protoLen {
return false
}
m.proto = string(data[0:protoLen])
data = data[protoLen:]
if len(data) < 1 {
return false
}
paddingLen := int(data[0])
data = data[1:]
if len(data) != paddingLen {
return false
}
return true
}