net: use golang.org/x/net/dns/dnsmessage for DNS resolution

Vendors golang.org/x/net/dns/dnsmessage from x/net git rev
892bf7b0c6e2f93b51166bf3882e50277fa5afc6

Updates #16218
Updates #21160

Change-Id: Ic4e8f3c3d83c2936354ec14c5be93b0d2b42dd91
Reviewed-on: https://go-review.googlesource.com/37879
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Ian Gudger 2017-11-22 17:12:30 -08:00 committed by Brad Fitzpatrick
parent c830e05a20
commit 672729ebbd
11 changed files with 4167 additions and 1874 deletions

View file

@ -19,42 +19,59 @@ import (
"sync"
"testing"
"time"
"golang_org/x/net/dns/dnsmessage"
)
var goResolver = Resolver{PreferGo: true}
// Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation.
const TestAddr uint32 = 0xc0000201
var TestAddr = [4]byte{0xc0, 0x00, 0x02, 0x01}
// Test address from 2001:db8::/32 block, reserved by RFC 3849 for documentation.
var TestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
func mustNewName(name string) dnsmessage.Name {
nn, err := dnsmessage.NewName(name)
if err != nil {
panic(fmt.Sprint("creating name: ", err))
}
return nn
}
func mustQuestion(name string, qtype dnsmessage.Type, class dnsmessage.Class) dnsmessage.Question {
return dnsmessage.Question{
Name: mustNewName(name),
Type: qtype,
Class: class,
}
}
var dnsTransportFallbackTests = []struct {
server string
name string
qtype uint16
timeout int
rcode int
server string
question dnsmessage.Question
timeout int
rcode dnsmessage.RCode
}{
// Querying "com." with qtype=255 usually makes an answer
// which requires more than 512 bytes.
{"8.8.8.8:53", "com.", dnsTypeALL, 2, dnsRcodeSuccess},
{"8.8.4.4:53", "com.", dnsTypeALL, 4, dnsRcodeSuccess},
{"8.8.8.8:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 2, dnsmessage.RCodeSuccess},
{"8.8.4.4:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 4, dnsmessage.RCodeSuccess},
}
func TestDNSTransportFallback(t *testing.T) {
fake := fakeDNSServer{
rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
rcode: dnsRcodeSuccess,
rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.Header.ID,
Response: true,
RCode: dnsmessage.RCodeSuccess,
},
question: q.question,
Questions: q.Questions,
}
if n == "udp" {
r.truncated = true
r.Header.Truncated = true
}
return r, nil
},
@ -63,15 +80,13 @@ func TestDNSTransportFallback(t *testing.T) {
for _, tt := range dnsTransportFallbackTests {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
msg, err := r.exchange(ctx, tt.server, tt.name, tt.qtype, time.Second)
_, h, err := r.exchange(ctx, tt.server, tt.question, time.Second)
if err != nil {
t.Error(err)
continue
}
switch msg.rcode {
case tt.rcode:
default:
t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode)
if h.RCode != tt.rcode {
t.Errorf("got %v from %v; want %v", h.RCode, tt.server, tt.rcode)
continue
}
}
@ -80,39 +95,38 @@ func TestDNSTransportFallback(t *testing.T) {
// See RFC 6761 for further information about the reserved, pseudo
// domain names.
var specialDomainNameTests = []struct {
name string
qtype uint16
rcode int
question dnsmessage.Question
rcode dnsmessage.RCode
}{
// Name resolution APIs and libraries should not recognize the
// followings as special.
{"1.0.168.192.in-addr.arpa.", dnsTypePTR, dnsRcodeNameError},
{"test.", dnsTypeALL, dnsRcodeNameError},
{"example.com.", dnsTypeALL, dnsRcodeSuccess},
{mustQuestion("1.0.168.192.in-addr.arpa.", dnsmessage.TypePTR, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
{mustQuestion("test.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
{mustQuestion("example.com.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeSuccess},
// Name resolution APIs and libraries should recognize the
// followings as special and should not send any queries.
// Though, we test those names here for verifying negative
// answers at DNS query-response interaction level.
{"localhost.", dnsTypeALL, dnsRcodeNameError},
{"invalid.", dnsTypeALL, dnsRcodeNameError},
{mustQuestion("localhost.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
{mustQuestion("invalid.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
}
func TestSpecialDomainName(t *testing.T) {
fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
Response: true,
},
question: q.question,
Questions: q.Questions,
}
switch q.question[0].Name {
switch q.Questions[0].Name.String() {
case "example.com.":
r.rcode = dnsRcodeSuccess
r.Header.RCode = dnsmessage.RCodeSuccess
default:
r.rcode = dnsRcodeNameError
r.Header.RCode = dnsmessage.RCodeNameError
}
return r, nil
@ -122,15 +136,13 @@ func TestSpecialDomainName(t *testing.T) {
for _, tt := range specialDomainNameTests {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
msg, err := r.exchange(ctx, server, tt.name, tt.qtype, 3*time.Second)
_, h, err := r.exchange(ctx, server, tt.question, 3*time.Second)
if err != nil {
t.Error(err)
continue
}
switch msg.rcode {
case tt.rcode, dnsRcodeServerFailure:
default:
t.Errorf("got %v from %v; want %v", msg.rcode, server, tt.rcode)
if h.RCode != tt.rcode {
t.Errorf("got %v from %v; want %v", h.RCode, server, tt.rcode)
continue
}
}
@ -177,24 +189,26 @@ func TestAvoidDNSName(t *testing.T) {
}
}
var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
Response: true,
},
question: q.question,
Questions: q.Questions,
}
if len(q.question) == 1 && q.question[0].Qtype == dnsTypeA {
r.answer = []dnsRR{
&dnsRR_A{
Hdr: dnsRR_Header{
Name: q.question[0].Name,
Rrtype: dnsTypeA,
Class: dnsClassINET,
Rdlength: 4,
if len(q.Questions) == 1 && q.Questions[0].Type == dnsmessage.TypeA {
r.Answers = []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Name: q.Questions[0].Name,
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
Length: 4,
},
Body: &dnsmessage.AResource{
A: TestAddr,
},
A: TestAddr,
},
}
}
@ -459,54 +473,57 @@ var goLookupIPWithResolverConfigTests = []struct {
func TestGoLookupIPWithResolverConfig(t *testing.T) {
defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
switch s {
case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
break
default:
time.Sleep(10 * time.Millisecond)
return nil, poll.ErrTimeout
return dnsmessage.Message{}, poll.ErrTimeout
}
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
Response: true,
},
question: q.question,
Questions: q.Questions,
}
for _, question := range q.question {
switch question.Qtype {
case dnsTypeA:
switch question.Name {
for _, question := range q.Questions {
switch question.Type {
case dnsmessage.TypeA:
switch question.Name.String() {
case "hostname.as112.net.":
break
case "ipv4.google.com.":
r.answer = append(r.answer, &dnsRR_A{
Hdr: dnsRR_Header{
Name: q.question[0].Name,
Rrtype: dnsTypeA,
Class: dnsClassINET,
Rdlength: 4,
r.Answers = append(r.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: q.Questions[0].Name,
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
Length: 4,
},
Body: &dnsmessage.AResource{
A: TestAddr,
},
A: TestAddr,
})
default:
}
case dnsTypeAAAA:
switch question.Name {
case dnsmessage.TypeAAAA:
switch question.Name.String() {
case "hostname.as112.net.":
break
case "ipv6.google.com.":
r.answer = append(r.answer, &dnsRR_AAAA{
Hdr: dnsRR_Header{
Name: q.question[0].Name,
Rrtype: dnsTypeAAAA,
Class: dnsClassINET,
Rdlength: 16,
r.Answers = append(r.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: q.Questions[0].Name,
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
Length: 16,
},
Body: &dnsmessage.AAAAResource{
AAAA: TestAddr6,
},
AAAA: TestAddr6,
})
}
}
@ -554,13 +571,13 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
Response: true,
},
question: q.question,
Questions: q.Questions,
}
return r, nil
}}
@ -624,20 +641,20 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
t.Fatal(err)
}
fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
Response: true,
},
question: q.question,
Questions: q.Questions,
}
switch q.question[0].Name {
switch q.Questions[0].Name.String() {
case fqdn + ".servfail.":
r.rcode = dnsRcodeServerFailure
r.Header.RCode = dnsmessage.RCodeServerFailure
default:
r.rcode = dnsRcodeNameError
r.Header.RCode = dnsmessage.RCodeNameError
}
return r, nil
@ -679,28 +696,30 @@ func TestIgnoreLameReferrals(t *testing.T) {
t.Fatal(err)
}
fake := fakeDNSServer{func(_, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
Response: true,
},
question: q.question,
Questions: q.Questions,
}
if s == "192.0.2.2:53" {
r.recursion_available = true
if q.question[0].Qtype == dnsTypeA {
r.answer = []dnsRR{
&dnsRR_A{
Hdr: dnsRR_Header{
Name: q.question[0].Name,
Rrtype: dnsTypeA,
Class: dnsClassINET,
Rdlength: 4,
r.Header.RecursionAvailable = true
if q.Questions[0].Type == dnsmessage.TypeA {
r.Answers = []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Name: q.Questions[0].Name,
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
Length: 4,
},
Body: &dnsmessage.AResource{
A: TestAddr,
},
A: TestAddr,
},
}
}
@ -766,20 +785,23 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
}
type fakeDNSServer struct {
rh func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error)
rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
}
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil
tcp := n == "tcp" || n == "tcp4" || n == "tcp6"
return &fakeDNSConn{tcp: tcp, server: server, n: n, s: s}, nil
}
type fakeDNSConn struct {
Conn
tcp bool
server *fakeDNSServer
n string
s string
q *dnsMsg
q dnsmessage.Message
t time.Time
buf []byte
}
func (f *fakeDNSConn) Close() error {
@ -787,15 +809,32 @@ func (f *fakeDNSConn) Close() error {
}
func (f *fakeDNSConn) Read(b []byte) (int, error) {
if len(f.buf) > 0 {
n := copy(b, f.buf)
f.buf = f.buf[n:]
return n, nil
}
resp, err := f.server.rh(f.n, f.s, f.q, f.t)
if err != nil {
return 0, err
}
bb, ok := resp.Pack()
if !ok {
return 0, errors.New("cannot marshal DNS message")
bb := make([]byte, 2, 514)
bb, err = resp.AppendPack(bb)
if err != nil {
return 0, fmt.Errorf("cannot marshal DNS message: %v", err)
}
if f.tcp {
l := len(bb) - 2
bb[0] = byte(l >> 8)
bb[1] = byte(l)
f.buf = bb
return f.Read(b)
}
bb = bb[2:]
if len(b) < len(bb) {
return 0, errors.New("read would fragment DNS message")
}
@ -809,9 +848,11 @@ func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
}
func (f *fakeDNSConn) Write(b []byte) (int, error) {
f.q = new(dnsMsg)
if !f.q.Unpack(b) {
return 0, errors.New("cannot unmarshal DNS message")
if f.tcp && len(b) >= 2 {
b = b[2:]
}
if f.q.Unpack(b) != nil {
return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b))
}
return len(b), nil
}
@ -836,64 +877,75 @@ func TestIgnoreDNSForgeries(t *testing.T) {
return
}
msg := &dnsMsg{}
if !msg.Unpack(b[:n]) {
t.Error("invalid DNS query")
var msg dnsmessage.Message
if msg.Unpack(b[:n]) != nil {
t.Error("invalid DNS query:", err)
return
}
s.Write([]byte("garbage DNS response packet"))
msg.response = true
msg.id++ // make invalid ID
b, ok := msg.Pack()
if !ok {
t.Error("failed to pack DNS response")
msg.Header.Response = true
msg.Header.ID++ // make invalid ID
if b, err = msg.Pack(); err != nil {
t.Error("failed to pack DNS response:", err)
return
}
s.Write(b)
msg.id-- // restore original ID
msg.answer = []dnsRR{
&dnsRR_A{
Hdr: dnsRR_Header{
Name: "www.example.com.",
Rrtype: dnsTypeA,
Class: dnsClassINET,
Rdlength: 4,
msg.Header.ID-- // restore original ID
msg.Answers = []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Name: mustNewName("www.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
Length: 4,
},
Body: &dnsmessage.AResource{
A: TestAddr,
},
A: TestAddr,
},
}
b, ok = msg.Pack()
if !ok {
t.Error("failed to pack DNS response")
b, err = msg.Pack()
if err != nil {
t.Error("failed to pack DNS response:", err)
return
}
s.Write(b)
}()
msg := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: 42,
msg := dnsmessage.Message{
Header: dnsmessage.Header{
ID: 42,
},
question: []dnsQuestion{
Questions: []dnsmessage.Question{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
Name: mustNewName("www.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
}
dc := &dnsPacketConn{c}
resp, err := dc.dnsRoundTrip(msg)
b, err := msg.Pack()
if err != nil {
t.Fatalf("dnsRoundTripUDP failed: %v", err)
t.Fatal("Pack failed:", err)
}
if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr {
p, _, err := dnsPacketRoundTrip(c, 42, msg.Questions[0], b)
if err != nil {
t.Fatalf("dnsPacketRoundTrip failed: %v", err)
}
p.SkipAllQuestions()
as, err := p.AllAnswers()
if err != nil {
t.Fatal("AllAnswers failed:", err)
}
if got := as[0].Body.(*dnsmessage.AResource).A; got != TestAddr {
t.Errorf("got address %v, want %v", got, TestAddr)
}
}
@ -918,7 +970,7 @@ func TestRetryTimeout(t *testing.T) {
var deadline0 time.Time
fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q, deadline)
if deadline.IsZero() {
@ -928,7 +980,7 @@ func TestRetryTimeout(t *testing.T) {
if s == "192.0.2.1:53" {
deadline0 = deadline
time.Sleep(10 * time.Millisecond)
return nil, poll.ErrTimeout
return dnsmessage.Message{}, poll.ErrTimeout
}
if deadline.Equal(deadline0) {
@ -979,7 +1031,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
}
var usedServers []string
fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
usedServers = append(usedServers, s)
return mockTXTResponse(q), nil
}}
@ -997,22 +1049,24 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
}
}
func mockTXTResponse(q *dnsMsg) *dnsMsg {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
recursion_available: true,
func mockTXTResponse(q dnsmessage.Message) dnsmessage.Message {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
Response: true,
RecursionAvailable: true,
},
question: q.question,
answer: []dnsRR{
&dnsRR_TXT{
Hdr: dnsRR_Header{
Name: q.question[0].Name,
Rrtype: dnsTypeTXT,
Class: dnsClassINET,
Questions: q.Questions,
Answers: []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Name: q.Questions[0].Name,
Type: dnsmessage.TypeTXT,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.TXTResource{
TXT: []string{"ok"},
},
Txt: "ok",
},
},
}
@ -1080,22 +1134,22 @@ func TestStrictErrorsLookupIP(t *testing.T) {
cases := []struct {
desc string
resolveWhich func(quest *dnsQuestion) resolveWhichEnum
resolveWhich func(quest dnsmessage.Question) resolveWhichEnum
wantStrictErr error
wantLaxErr error
wantIPs []string
}{
{
desc: "No errors",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
return resolveOK
},
wantIPs: []string{ip4, ip6},
},
{
desc: "searchX error fails in strict mode",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
if quest.Name == searchX {
resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
if quest.Name.String() == searchX {
return resolveTimeout
}
return resolveOK
@ -1105,8 +1159,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
},
{
desc: "searchX IPv4-only timeout fails in strict mode",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
if quest.Name == searchX && quest.Qtype == dnsTypeA {
resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeA {
return resolveTimeout
}
return resolveOK
@ -1116,8 +1170,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
},
{
desc: "searchX IPv6-only servfail fails in strict mode",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
if quest.Name == searchX && quest.Qtype == dnsTypeAAAA {
resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeAAAA {
return resolveServfail
}
return resolveOK
@ -1127,8 +1181,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
},
{
desc: "searchY error always fails",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
if quest.Name == searchY {
resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
if quest.Name.String() == searchY {
return resolveTimeout
}
return resolveOK
@ -1138,8 +1192,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
},
{
desc: "searchY IPv4-only socket error fails in strict mode",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
if quest.Name == searchY && quest.Qtype == dnsTypeA {
resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeA {
return resolveOpError
}
return resolveOK
@ -1149,8 +1203,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
},
{
desc: "searchY IPv6-only timeout fails in strict mode",
resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
if quest.Name == searchY && quest.Qtype == dnsTypeAAAA {
resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeAAAA {
return resolveTimeout
}
return resolveOK
@ -1161,80 +1215,84 @@ func TestStrictErrorsLookupIP(t *testing.T) {
}
for i, tt := range cases {
fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
switch tt.resolveWhich(&q.question[0]) {
switch tt.resolveWhich(q.Questions[0]) {
case resolveOK:
// Handle below.
case resolveOpError:
return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
return dnsmessage.Message{}, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
case resolveServfail:
return &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
rcode: dnsRcodeServerFailure,
return dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
Response: true,
RCode: dnsmessage.RCodeServerFailure,
},
question: q.question,
Questions: q.Questions,
}, nil
case resolveTimeout:
return nil, poll.ErrTimeout
return dnsmessage.Message{}, poll.ErrTimeout
default:
t.Fatal("Impossible resolveWhich")
}
switch q.question[0].Name {
switch q.Questions[0].Name.String() {
case searchX, name + ".":
// Return NXDOMAIN to utilize the search list.
return &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
rcode: dnsRcodeNameError,
return dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
Response: true,
RCode: dnsmessage.RCodeNameError,
},
question: q.question,
Questions: q.Questions,
}, nil
case searchY:
// Return records below.
default:
return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name)
return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
}
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: q.id,
response: true,
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
Response: true,
},
question: q.question,
Questions: q.Questions,
}
switch q.question[0].Qtype {
case dnsTypeA:
r.answer = []dnsRR{
&dnsRR_A{
Hdr: dnsRR_Header{
Name: q.question[0].Name,
Rrtype: dnsTypeA,
Class: dnsClassINET,
Rdlength: 4,
switch q.Questions[0].Type {
case dnsmessage.TypeA:
r.Answers = []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Name: q.Questions[0].Name,
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
Length: 4,
},
Body: &dnsmessage.AResource{
A: TestAddr,
},
A: TestAddr,
},
}
case dnsTypeAAAA:
r.answer = []dnsRR{
&dnsRR_AAAA{
Hdr: dnsRR_Header{
Name: q.question[0].Name,
Rrtype: dnsTypeAAAA,
Class: dnsClassINET,
Rdlength: 16,
case dnsmessage.TypeAAAA:
r.Answers = []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Name: q.Questions[0].Name,
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
Length: 16,
},
Body: &dnsmessage.AAAAResource{
AAAA: TestAddr6,
},
AAAA: TestAddr6,
},
}
default:
return nil, fmt.Errorf("Unexpected Qtype: %v", q.question[0].Qtype)
return dnsmessage.Message{}, fmt.Errorf("Unexpected Type: %v", q.Questions[0].Type)
}
return r, nil
}}
@ -1295,22 +1353,22 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
const searchY = "test.y.golang.org."
const txt = "Hello World"
fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
switch q.question[0].Name {
switch q.Questions[0].Name.String() {
case searchX:
return nil, poll.ErrTimeout
return dnsmessage.Message{}, poll.ErrTimeout
case searchY:
return mockTXTResponse(q), nil
default:
return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name)
return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
}
}}
for _, strict := range []bool{true, false} {
r := Resolver{StrictErrors: strict, Dial: fake.DialContext}
_, rrs, err := r.lookup(context.Background(), name, dnsTypeTXT)
p, _, err := r.lookup(context.Background(), name, dnsmessage.TypeTXT)
var wantErr error
var wantRRs int
if strict {
@ -1326,8 +1384,12 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
if !reflect.DeepEqual(err, wantErr) {
t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr)
}
if len(rrs) != wantRRs {
t.Errorf("strict=%v: got %v; want %v", strict, len(rrs), wantRRs)
a, err := p.AllAnswers()
if err != nil {
a = nil
}
if len(a) != wantRRs {
t.Errorf("strict=%v: got %v; want %v", strict, len(a), wantRRs)
}
}
}
@ -1337,9 +1399,9 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
func TestDNSGoroutineRace(t *testing.T) {
defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) {
fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
time.Sleep(10 * time.Microsecond)
return nil, poll.ErrTimeout
return dnsmessage.Message{}, poll.ErrTimeout
}}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
@ -1353,3 +1415,76 @@ func TestDNSGoroutineRace(t *testing.T) {
t.Fatal("fake DNS lookup unexpectedly succeeded")
}
}
// Issue 8434: verify that Temporary returns true on an error when rcode
// is SERVFAIL
func TestIssue8434(t *testing.T) {
msg := dnsmessage.Message{
Header: dnsmessage.Header{
RCode: dnsmessage.RCodeServerFailure,
},
}
b, err := msg.Pack()
if err != nil {
t.Fatal("Pack failed:", err)
}
var p dnsmessage.Parser
h, err := p.Start(b)
if err != nil {
t.Fatal("Start failed:", err)
}
if err := p.SkipAllQuestions(); err != nil {
t.Fatal("SkipAllQuestions failed:", err)
}
err = checkHeaders(&p, h, "golang.org", "foo:53")
if err == nil {
t.Fatal("expected an error")
}
if ne, ok := err.(Error); !ok {
t.Fatalf("err = %#v; wanted something supporting net.Error", err)
} else if !ne.Temporary() {
t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
}
if de, ok := err.(*DNSError); !ok {
t.Fatalf("err = %#v; wanted a *net.DNSError", err)
} else if !de.IsTemporary {
t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
}
}
// Issue 12778: verify that NXDOMAIN without RA bit errors as
// "no such host" and not "server misbehaving"
func TestIssue12778(t *testing.T) {
msg := dnsmessage.Message{
Header: dnsmessage.Header{
RCode: dnsmessage.RCodeNameError,
RecursionAvailable: false,
},
}
b, err := msg.Pack()
if err != nil {
t.Fatal("Pack failed:", err)
}
var p dnsmessage.Parser
h, err := p.Start(b)
if err != nil {
t.Fatal("Start failed:", err)
}
if err := p.SkipAllQuestions(); err != nil {
t.Fatal("SkipAllQuestions failed:", err)
}
err = checkHeaders(&p, h, "golang.org", "foo:53")
if err == nil {
t.Fatal("expected an error")
}
de, ok := err.(*DNSError)
if !ok {
t.Fatalf("err = %#v; wanted a *net.DNSError", err)
}
if de.Err != errNoSuchHost.Error() {
t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
}
}