net: keep waiting for valid DNS response until timeout

Prevents denial of service attacks from bogus UDP packets.

Fixes #13281.

Change-Id: Ifb51b17a1b0807bfd27b144d6037431701184e7b
Reviewed-on: https://go-review.googlesource.com/22126
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Matthew Dempsky 2016-04-15 19:19:58 -07:00
parent 9f1ccd647f
commit 3411d63219
4 changed files with 263 additions and 63 deletions

View file

@ -38,46 +38,67 @@ type dnsConn interface {
SetDeadline(time.Time) error
// readDNSResponse reads a DNS response message from the DNS
// transport endpoint and returns the received DNS response
// message.
readDNSResponse() (*dnsMsg, error)
// writeDNSQuery writes a DNS query message to the DNS
// connection endpoint.
writeDNSQuery(*dnsMsg) error
// dnsRoundTrip executes a single DNS transaction, returning a
// DNS response message for the provided DNS query message.
dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
}
func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
b := make([]byte, 512) // see RFC 1035
n, err := c.Read(b)
if err != nil {
return nil, err
}
msg := &dnsMsg{}
if !msg.Unpack(b[:n]) {
return nil, errors.New("cannot unmarshal DNS message")
}
return msg, nil
func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
return dnsRoundTripUDP(c, query)
}
func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
b, ok := msg.Pack()
// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
// "UDP usage" transport mechanism. c should be a packet-oriented connection,
// such as a *UDPConn.
func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack()
if !ok {
return errors.New("cannot marshal DNS message")
return nil, errors.New("cannot marshal DNS message")
}
if _, err := c.Write(b); err != nil {
return err
return nil, err
}
b = make([]byte, 512) // see RFC 1035
for {
n, err := c.Read(b)
if err != nil {
return nil, err
}
resp := &dnsMsg{}
if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) {
// Ignore invalid responses as they may be malicious
// forgery attempts. Instead continue waiting until
// timeout. See golang.org/issue/13281.
continue
}
return resp, nil
}
return nil
}
func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
return dnsRoundTripTCP(c, out)
}
// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
// "TCP usage" transport mechanism. c should be a stream-oriented connection,
// such as a *TCPConn.
func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack()
if !ok {
return nil, errors.New("cannot marshal DNS message")
}
l := len(b)
b = append([]byte{byte(l >> 8), byte(l)}, b...)
if _, err := c.Write(b); err != nil {
return nil, err
}
b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
if _, err := io.ReadFull(c, b[:2]); err != nil {
return nil, err
}
l := int(b[0])<<8 | int(b[1])
l = int(b[0])<<8 | int(b[1])
if l > len(b) {
b = make([]byte, l)
}
@ -85,24 +106,14 @@ func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
if err != nil {
return nil, err
}
msg := &dnsMsg{}
if !msg.Unpack(b[:n]) {
resp := &dnsMsg{}
if !resp.Unpack(b[:n]) {
return nil, errors.New("cannot unmarshal DNS message")
}
return msg, nil
}
func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
b, ok := msg.Pack()
if !ok {
return errors.New("cannot marshal DNS message")
if !resp.IsResponseTo(query) {
return nil, errors.New("invalid DNS response")
}
l := uint16(len(b))
b = append([]byte{byte(l >> 8), byte(l)}, b...)
if _, err := c.Write(b); err != nil {
return err
}
return nil
return resp, nil
}
func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) {
@ -150,16 +161,10 @@ func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg,
c.SetDeadline(d)
}
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
if err := c.writeDNSQuery(&out); err != nil {
return nil, mapErr(err)
}
in, err := c.readDNSResponse()
in, err := c.dnsRoundTrip(&out)
if err != nil {
return nil, mapErr(err)
}
if in.id != out.id {
return nil, errors.New("DNS message ID mismatch")
}
if in.truncated { // see RFC 5966
continue
}