net: use golang.org/x/net/dns/dnsmessage for DNS resolution

Vendors golang.org/x/net/dns/dnsmessage from x/net git rev
892bf7b0c6e2f93b51166bf3882e50277fa5afc6

Updates #16218
Updates #21160

Change-Id: Ic4e8f3c3d83c2936354ec14c5be93b0d2b42dd91
Reviewed-on: https://go-review.googlesource.com/37879
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Ian Gudger 2017-11-22 17:12:30 -08:00 committed by Brad Fitzpatrick
parent c830e05a20
commit 672729ebbd
11 changed files with 4167 additions and 1874 deletions

View file

@ -23,142 +23,231 @@ import (
"os"
"sync"
"time"
"golang_org/x/net/dns/dnsmessage"
)
// A dnsConn represents a DNS transport endpoint.
type dnsConn interface {
io.Closer
SetDeadline(time.Time) error
// dnsRoundTrip executes a single DNS transaction, returning a
// DNS response message for the provided DNS query message.
dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
}
// dnsPacketConn implements the dnsConn interface for RFC 1035's
// "UDP usage" transport mechanism. Conn is a packet-oriented connection,
// such as a *UDPConn.
type dnsPacketConn struct {
Conn
}
func (c *dnsPacketConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack()
if !ok {
return nil, errors.New("cannot marshal DNS message")
func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
b.EnableCompression()
if err := b.StartQuestions(); err != nil {
return 0, nil, nil, err
}
if err := b.Question(q); err != nil {
return 0, nil, nil, err
}
tcpReq, err = b.Finish()
udpReq = tcpReq[2:]
l := len(tcpReq) - 2
tcpReq[0] = byte(l >> 8)
tcpReq[1] = byte(l)
return id, udpReq, tcpReq, err
}
func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
if !respHdr.Response {
return false
}
if reqID != respHdr.ID {
return false
}
if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
return false
}
return true
}
func dnsPacketRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
if _, err := c.Write(b); err != nil {
return nil, err
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
b = make([]byte, 512) // see RFC 1035
for {
n, err := c.Read(b)
if err != nil {
return nil, err
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
resp := &dnsMsg{}
if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) {
// Ignore invalid responses as they may be malicious
// forgery attempts. Instead continue waiting until
// timeout. See golang.org/issue/13281.
var p dnsmessage.Parser
// Ignore invalid responses as they may be malicious
// forgery attempts. Instead continue waiting until
// timeout. See golang.org/issue/13281.
h, err := p.Start(b[:n])
if err != nil {
continue
}
return resp, nil
q, err := p.Question()
if err != nil || !checkResponse(id, query, h, q) {
continue
}
return p, h, nil
}
}
// dnsStreamConn implements the dnsConn interface for RFC 1035's
// "TCP usage" transport mechanism. Conn is a stream-oriented connection,
// such as a *TCPConn.
type dnsStreamConn struct {
Conn
}
func (c *dnsStreamConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack()
if !ok {
return nil, errors.New("cannot marshal DNS message")
}
l := len(b)
b = append([]byte{byte(l >> 8), byte(l)}, b...)
func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
if _, err := c.Write(b); err != nil {
return nil, err
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
if _, err := io.ReadFull(c, b[:2]); err != nil {
return nil, err
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
l = int(b[0])<<8 | int(b[1])
l := int(b[0])<<8 | int(b[1])
if l > len(b) {
b = make([]byte, l)
}
n, err := io.ReadFull(c, b[:l])
if err != nil {
return nil, err
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
resp := &dnsMsg{}
if !resp.Unpack(b[:n]) {
return nil, errors.New("cannot unmarshal DNS message")
var p dnsmessage.Parser
h, err := p.Start(b[:n])
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot unmarshal DNS message")
}
if !resp.IsResponseTo(query) {
return nil, errors.New("invalid DNS response")
q, err := p.Question()
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot unmarshal DNS message")
}
return resp, nil
if !checkResponse(id, query, h, q) {
return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("invalid DNS response")
}
return p, h, nil
}
// exchange sends a query on the connection and hopes for a response.
func (r *Resolver) exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
out := dnsMsg{
dnsMsgHdr: dnsMsgHdr{
recursion_desired: true,
},
question: []dnsQuestion{
{name, qtype, dnsClassINET},
},
func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET
id, udpReq, tcpReq, err := newRequest(q)
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot marshal DNS message")
}
for _, network := range []string{"udp", "tcp"} {
// TODO(mdempsky): Refactor so defers from UDP-based
// exchanges happen before TCP-based exchange.
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel()
c, err := r.dial(ctx, network, server)
if err != nil {
return nil, err
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
defer c.Close()
if d, ok := ctx.Deadline(); ok && !d.IsZero() {
c.SetDeadline(d)
}
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
in, err := c.dnsRoundTrip(&out)
if err != nil {
return nil, mapErr(err)
var p dnsmessage.Parser
var h dnsmessage.Header
if network == "tcp" {
p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
} else {
p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
}
if in.truncated { // see RFC 5966
c.Close()
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, mapErr(err)
}
if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("invalid DNS response")
}
if h.Truncated { // see RFC 5966
continue
}
return in, nil
return p, h, nil
}
return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("no answer from DNS server")
}
func checkHeaders(p *dnsmessage.Parser, h dnsmessage.Header, name, server string) error {
_, err := p.AnswerHeader()
if err != nil && err != dnsmessage.ErrSectionDone {
return &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
// libresolv continues to the next server when it receives
// an invalid referral response. See golang.org/issue/15434.
if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
return &DNSError{Err: "lame referral", Name: name, Server: server}
}
// If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError,
// it means the response in msg was not useful and trying another
// server probably won't help. Return now in those cases.
// TODO: indicate this in a more obvious way, such as a field on DNSError?
if h.RCode == dnsmessage.RCodeNameError {
return &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
}
if h.RCode != dnsmessage.RCodeSuccess {
// None of the error codes make sense
// for the query we sent. If we didn't get
// a name error and we didn't get success,
// the server is behaving incorrectly or
// having temporary trouble.
err := &DNSError{Err: "server misbehaving", Name: name, Server: server}
if h.RCode == dnsmessage.RCodeServerFailure {
err.IsTemporary = true
}
return err
}
return nil
}
func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type, name, server string) error {
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
return &DNSError{
Err: errNoSuchHost.Error(),
Name: name,
Server: server,
}
}
if err != nil {
return &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type == qtype {
return nil
}
if err := p.SkipAnswer(); err != nil {
return &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
}
return nil, errors.New("no answer from DNS server")
}
// Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers).
func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
var lastErr error
serverOffset := cfg.serverOffset()
sLen := uint32(len(cfg.servers))
n, err := dnsmessage.NewName(name)
if err != nil {
return dnsmessage.Parser{}, "", errors.New("cannot marshal DNS message")
}
q := dnsmessage.Question{
Name: n,
Type: qtype,
Class: dnsmessage.ClassINET,
}
for i := 0; i < cfg.attempts; i++ {
for j := uint32(0); j < sLen; j++ {
server := cfg.servers[(serverOffset+j)%sLen]
msg, err := r.exchange(ctx, server, name, qtype, cfg.timeout)
p, h, err := r.exchange(ctx, server, q, cfg.timeout)
if err != nil {
lastErr = &DNSError{
Err: err.Error(),
@ -175,41 +264,19 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string,
}
continue
}
// libresolv continues to the next server when it receives
// an invalid referral response. See golang.org/issue/15434.
if msg.rcode == dnsRcodeSuccess && !msg.authoritative && !msg.recursion_available && len(msg.answer) == 0 && len(msg.extra) == 0 {
lastErr = &DNSError{Err: "lame referral", Name: name, Server: server}
lastErr = checkHeaders(&p, h, name, server)
if lastErr != nil {
continue
}
cname, rrs, err := answer(name, server, msg, qtype)
// If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError,
// it means the response in msg was not useful and trying another
// server probably won't help. Return now in those cases.
// TODO: indicate this in a more obvious way, such as a field on DNSError?
if err == nil || msg.rcode == dnsRcodeSuccess || msg.rcode == dnsRcodeNameError {
return cname, rrs, err
}
lastErr = err
}
}
return "", nil, lastErr
}
// addrRecordList converts and returns a list of IP addresses from DNS
// address records (both A and AAAA). Other record types are ignored.
func addrRecordList(rrs []dnsRR) []IPAddr {
addrs := make([]IPAddr, 0, 4)
for _, rr := range rrs {
switch rr := rr.(type) {
case *dnsRR_A:
addrs = append(addrs, IPAddr{IP: IPv4(byte(rr.A>>24), byte(rr.A>>16), byte(rr.A>>8), byte(rr.A))})
case *dnsRR_AAAA:
ip := make(IP, IPv6len)
copy(ip, rr.AAAA[:])
addrs = append(addrs, IPAddr{IP: ip})
lastErr = skipToAnswer(&p, qtype, name, server)
if lastErr == nil {
return p, server, nil
}
}
}
return addrs
return dnsmessage.Parser{}, "", lastErr
}
// A resolverConfig represents a DNS stub resolver configuration.
@ -287,21 +354,26 @@ func (conf *resolverConfig) releaseSema() {
<-conf.ch
}
func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
if !isDomainName(name) {
// We used to use "invalid domain name" as the error,
// but that is a detail of the specific lookup mechanism.
// Other lookups might allow broader name syntax
// (for example Multicast DNS allows UTF-8; see RFC 6762).
// For consistency with libc resolvers, report no such host.
return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name}
return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name}
}
resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock()
conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock()
var (
p dnsmessage.Parser
server string
err error
)
for _, fqdn := range conf.nameList(name) {
cname, rrs, err = r.tryOneName(ctx, conf, fqdn, qtype)
p, server, err = r.tryOneName(ctx, conf, fqdn, qtype)
if err == nil {
break
}
@ -311,13 +383,16 @@ func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname
break
}
}
if err == nil {
return p, server, nil
}
if err, ok := err.(*DNSError); ok {
// Show original name passed to lookup, not suffixed one.
// In general we might have tried many suffixes; showing
// just one is misleading. See also golang.org/issue/6324.
err.Name = name
}
return
return dnsmessage.Parser{}, "", err
}
// avoidDNS reports whether this is a hostname for which we should not
@ -454,36 +529,36 @@ func (r *Resolver) goLookupIP(ctx context.Context, host string) (addrs []IPAddr,
return
}
func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname string, err error) {
func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname dnsmessage.Name, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles {
addrs = goLookupIPFiles(name)
if len(addrs) > 0 || order == hostLookupFiles {
return addrs, name, nil
return addrs, dnsmessage.Name{}, nil
}
}
if !isDomainName(name) {
// See comment in func lookup above about use of errNoSuchHost.
return nil, "", &DNSError{Err: errNoSuchHost.Error(), Name: name}
return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name}
}
resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock()
conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock()
type racer struct {
cname string
rrs []dnsRR
p dnsmessage.Parser
server string
error
}
lane := make(chan racer, 1)
qtypes := [...]uint16{dnsTypeA, dnsTypeAAAA}
qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
var lastErr error
for _, fqdn := range conf.nameList(name) {
for _, qtype := range qtypes {
dnsWaitGroup.Add(1)
go func(qtype uint16) {
defer dnsWaitGroup.Done()
cname, rrs, err := r.tryOneName(ctx, conf, fqdn, qtype)
lane <- racer{cname, rrs, err}
go func(qtype dnsmessage.Type) {
p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
lane <- racer{p, server, err}
dnsWaitGroup.Done()
}(qtype)
}
hitStrictError := false
@ -500,9 +575,74 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
}
continue
}
addrs = append(addrs, addrRecordList(racer.rrs)...)
if cname == "" {
cname = racer.cname
// Presotto says it's okay to assume that servers listed in
// /etc/resolv.conf are recursive resolvers.
//
// We asked for recursion, so it should have included all the
// answers we need in this one packet.
//
// Further, RFC 1035 section 4.3.1 says that "the recursive
// response to a query will be... The answer to the query,
// possibly preface by one or more CNAME RRs that specify
// aliases encountered on the way to an answer."
//
// Therefore, we should be able to assume that we can ignore
// CNAMEs and that the A and AAAA records we requested are
// for the canonical name.
loop:
for {
h, err := racer.p.AnswerHeader()
if err != nil && err != dnsmessage.ErrSectionDone {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: racer.server,
}
}
if err != nil {
break
}
switch h.Type {
case dnsmessage.TypeA:
a, err := racer.p.AResource()
if err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: racer.server,
}
break loop
}
addrs = append(addrs, IPAddr{IP: IP(a.A[:])})
case dnsmessage.TypeAAAA:
aaaa, err := racer.p.AAAAResource()
if err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: racer.server,
}
break loop
}
addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])})
default:
if err := racer.p.SkipAnswer(); err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: racer.server,
}
break loop
}
continue
}
if cname.Length == 0 && h.Name.Length != 0 {
cname = h.Name
}
}
}
if hitStrictError {
@ -528,17 +668,17 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
addrs = goLookupIPFiles(name)
}
if len(addrs) == 0 && lastErr != nil {
return nil, "", lastErr
return nil, dnsmessage.Name{}, lastErr
}
}
return addrs, cname, nil
}
// goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME.
func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (cname string, err error) {
func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (string, error) {
order := systemConf().hostLookupOrder(host)
_, cname, err = r.goLookupIPCNAMEOrder(ctx, host, order)
return
_, cname, err := r.goLookupIPCNAMEOrder(ctx, host, order)
return cname.String(), err
}
// goLookupPTR is the native Go implementation of LookupAddr.
@ -555,13 +695,36 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string) ([]string, erro
if err != nil {
return nil, err
}
_, rrs, err := r.lookup(ctx, arpa, dnsTypePTR)
p, server, err := r.lookup(ctx, arpa, dnsmessage.TypePTR)
if err != nil {
return nil, err
}
ptrs := make([]string, len(rrs))
for i, rr := range rrs {
ptrs[i] = rr.(*dnsRR_PTR).Ptr
var ptrs []string
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot marshal DNS message",
Name: addr,
Server: server,
}
}
if h.Type != dnsmessage.TypePTR {
continue
}
ptr, err := p.PTRResource()
if err != nil {
return nil, &DNSError{
Err: "cannot marshal DNS message",
Name: addr,
Server: server,
}
}
ptrs = append(ptrs, ptr.PTR.String())
}
return ptrs, nil
}