diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index b5b6ffb1c5..98be7a873d 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -125,7 +125,7 @@ func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, // Calling Dial here is scary -- we have to be sure not to // dial a name that will require a DNS lookup, or Dial will // call back here to translate it. The DNS config parser has - // already checked that all the cfg.servers[i] are IP + // already checked that all the cfg.servers are IP // addresses, which Dial will use without a DNS lookup. c, err := d.DialContext(ctx, network, server) if err != nil { @@ -182,13 +182,14 @@ func exchange(ctx context.Context, server, name string, qtype uint16, timeout ti // Do a lookup for a single name, which must be rooted // (otherwise answer will not find the answers). func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { - if len(cfg.servers) == 0 { - return "", nil, &DNSError{Err: "no DNS servers", Name: name} - } - var lastErr error + serverOffset := cfg.serverOffset() + sLen := uint32(len(cfg.servers)) + for i := 0; i < cfg.attempts; i++ { - for _, server := range cfg.servers { + for j := uint32(0); j < sLen; j++ { + server := cfg.servers[(serverOffset+j)%sLen] + msg, err := exchange(ctx, server, name, qtype, cfg.timeout) if err != nil { lastErr = &DNSError{ diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index f185642feb..8ee64d407c 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -797,3 +797,73 @@ func TestRetryTimeout(t *testing.T) { t.Error("deadline0 still zero", deadline0) } } + +func TestRotate(t *testing.T) { + // without rotation, always uses the first server + testRotate(t, false, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.1:53", "192.0.2.1:53"}) + + // with rotation, rotates through back to first + testRotate(t, true, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.2:53", "192.0.2.1:53"}) +} + +func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { + origTestHookDNSDialer := testHookDNSDialer + defer func() { testHookDNSDialer = origTestHookDNSDialer }() + + conf, err := newResolvConfTest() + if err != nil { + t.Fatal(err) + } + defer conf.teardown() + + var confLines []string + for _, ns := range nameservers { + confLines = append(confLines, "nameserver "+ns) + } + if rotate { + confLines = append(confLines, "options rotate") + } + + if err := conf.writeAndUpdate(confLines); err != nil { + t.Fatal(err) + } + + d := &fakeDNSDialer{} + testHookDNSDialer = func() dnsDialer { return d } + + var usedServers []string + d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { + usedServers = append(usedServers, s) + + r := &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + recursion_available: true, + }, + question: q.question, + answer: []dnsRR{ + &dnsRR_CNAME{ + Hdr: dnsRR_Header{ + Name: q.question[0].Name, + Rrtype: dnsTypeCNAME, + Class: dnsClassINET, + }, + Cname: "golang.org", + }, + }, + } + return r, nil + } + + // len(nameservers) + 1 to allow rotation to get back to start + for i := 0; i < len(nameservers)+1; i++ { + if _, err := goLookupCNAME(context.Background(), "www.golang.org"); err != nil { + t.Fatal(err) + } + } + + if !reflect.DeepEqual(usedServers, wantServers) { + t.Fatalf("rotate=%t got used servers:\n%v\nwant:\n%v", rotate, usedServers, wantServers) + } +} diff --git a/src/net/dnsconfig_unix.go b/src/net/dnsconfig_unix.go index 683ae71812..9c8108d11c 100644 --- a/src/net/dnsconfig_unix.go +++ b/src/net/dnsconfig_unix.go @@ -10,6 +10,7 @@ package net import ( "os" + "sync/atomic" "time" ) @@ -29,6 +30,7 @@ type dnsConfig struct { lookup []string // OpenBSD top-level database "lookup" order err error // any error that occurs during open of resolv.conf mtime time.Time // time of resolv.conf modification + soffset uint32 // used by serverOffset } // See resolv.conf(5) on a Linux machine. @@ -136,6 +138,17 @@ func dnsReadConfig(filename string) *dnsConfig { return conf } +// serverOffset returns an offset that can be used to determine +// indices of servers in c.servers when making queries. +// When the rotate option is enabled, this offset increases. +// Otherwise it is always 0. +func (c *dnsConfig) serverOffset() uint32 { + if c.rotate { + return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start + } + return 0 +} + func dnsDefaultSearch() []string { hn, err := getHostname() if err != nil {