summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikio Hara <mikioh.mikioh@gmail.com>2014-08-06 09:58:47 +0900
committerMikio Hara <mikioh.mikioh@gmail.com>2014-08-06 09:58:47 +0900
commit7846f72f54f04574afa96507a895c24d39257ff5 (patch)
treee4468a1da05efd347bdce77457e6e80c247aacf3
parent53f6f0a3f78b218b3aba77e39acb7c601c223957 (diff)
downloadgo-7846f72f54f04574afa96507a895c24d39257ff5.tar.gz
net: separate DNS transport from DNS query-response interaction
Before fixing issue 6579 this CL separates DNS transport from DNS message interaction to make it easier to add builtin DNS resolver control logic. Update issue 6579 LGTM=alex, kevlar R=golang-codereviews, alex, gobot, iant, minux, kevlar CC=golang-codereviews https://codereview.appspot.com/101220044
-rw-r--r--src/pkg/net/dnsclient_unix.go236
-rw-r--r--src/pkg/net/dnsclient_unix_test.go78
2 files changed, 222 insertions, 92 deletions
diff --git a/src/pkg/net/dnsclient_unix.go b/src/pkg/net/dnsclient_unix.go
index 3713efd0e..eb4b5900d 100644
--- a/src/pkg/net/dnsclient_unix.go
+++ b/src/pkg/net/dnsclient_unix.go
@@ -16,6 +16,7 @@
package net
import (
+ "errors"
"io"
"math/rand"
"os"
@@ -23,118 +24,187 @@ import (
"time"
)
-// Send a request on the connection and hope for a reply.
-// Up to cfg.attempts attempts.
-func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, error) {
- _, useTCP := c.(*TCPConn)
- if len(name) >= 256 {
- return nil, &DNSError{Err: "name too long", Name: name}
+// A dnsConn represents a DNS transport endpoint.
+type dnsConn interface {
+ Conn
+
+ // readDNSResponse reads a DNS response message from the DNS
+ // transport endpoint and returns the received DNS response
+ // message.
+ readDNSResponse() (*dnsMsg, error)
+
+ // writeDNSQuery writes a DNS query message to the DNS
+ // connection endpoint.
+ writeDNSQuery(*dnsMsg) error
+}
+
+func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
+ b := make([]byte, 512) // see RFC 1035
+ n, err := c.Read(b)
+ if err != nil {
+ return nil, err
+ }
+ msg := &dnsMsg{}
+ if !msg.Unpack(b[:n]) {
+ return nil, errors.New("cannot unmarshal DNS message")
+ }
+ return msg, nil
+}
+
+func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
+ b, ok := msg.Pack()
+ if !ok {
+ return errors.New("cannot marshal DNS message")
+ }
+ if _, err := c.Write(b); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
+ b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
+ if _, err := io.ReadFull(c, b[:2]); err != nil {
+ return nil, err
+ }
+ l := int(b[0])<<8 | int(b[1])
+ if l > len(b) {
+ b = make([]byte, l)
+ }
+ n, err := io.ReadFull(c, b[:l])
+ if err != nil {
+ return nil, err
}
- out := new(dnsMsg)
- out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
- out.question = []dnsQuestion{
- {name, qtype, dnsClassINET},
+ msg := &dnsMsg{}
+ if !msg.Unpack(b[:n]) {
+ return nil, errors.New("cannot unmarshal DNS message")
}
- out.recursion_desired = true
- msg, ok := out.Pack()
+ return msg, nil
+}
+
+func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
+ b, ok := msg.Pack()
if !ok {
- return nil, &DNSError{Err: "internal error - cannot pack message", Name: name}
+ return errors.New("cannot marshal DNS message")
+ }
+ l := uint16(len(b))
+ b = append([]byte{byte(l >> 8), byte(l)}, b...)
+ if _, err := c.Write(b); err != nil {
+ return err
}
- if useTCP {
- mlen := uint16(len(msg))
- msg = append([]byte{byte(mlen >> 8), byte(mlen)}, msg...)
+ return nil
+}
+
+func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
+ switch network {
+ case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
+ default:
+ return nil, UnknownNetworkError(network)
}
- for attempt := 0; attempt < cfg.attempts; attempt++ {
- n, err := c.Write(msg)
+ // Calling Dial here is scary -- we have to be sure not to
+ // dial a name that will require a DNS lookup, or Dial will
+ // call back here to translate it. The DNS config parser has
+ // already checked that all the cfg.servers[i] are IP
+ // addresses, which Dial will use without a DNS lookup.
+ c, err := d.Dial(network, server)
+ if err != nil {
+ return nil, err
+ }
+ switch network {
+ case "tcp", "tcp4", "tcp6":
+ return c.(*TCPConn), nil
+ case "udp", "udp4", "udp6":
+ return c.(*UDPConn), nil
+ }
+ panic("unreachable")
+}
+
+// exchange sends a query on the connection and hopes for a response.
+func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
+ d := Dialer{Timeout: timeout}
+ out := dnsMsg{
+ dnsMsgHdr: dnsMsgHdr{
+ recursion_desired: true,
+ },
+ question: []dnsQuestion{
+ {name, qtype, dnsClassINET},
+ },
+ }
+ for _, network := range []string{"udp", "tcp"} {
+ c, err := d.dialDNS(network, server)
if err != nil {
return nil, err
}
-
- if cfg.timeout == 0 {
- c.SetReadDeadline(noDeadline)
- } else {
- c.SetReadDeadline(time.Now().Add(time.Duration(cfg.timeout) * time.Second))
+ defer c.Close()
+ if timeout > 0 {
+ c.SetDeadline(time.Now().Add(timeout))
}
- buf := make([]byte, 2000)
- if useTCP {
- n, err = io.ReadFull(c, buf[:2])
- if err != nil {
- if e, ok := err.(Error); ok && e.Timeout() {
- continue
- }
- }
- mlen := int(buf[0])<<8 | int(buf[1])
- if mlen > len(buf) {
- buf = make([]byte, mlen)
- }
- n, err = io.ReadFull(c, buf[:mlen])
- } else {
- n, err = c.Read(buf)
+ out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
+ if err := c.writeDNSQuery(&out); err != nil {
+ return nil, err
}
+ in, err := c.readDNSResponse()
if err != nil {
- if e, ok := err.(Error); ok && e.Timeout() {
- continue
- }
return nil, err
}
- buf = buf[:n]
- in := new(dnsMsg)
- if !in.Unpack(buf) || in.id != out.id {
+ if in.id != out.id {
+ return nil, errors.New("DNS message ID mismatch")
+ }
+ if in.truncated { // see RFC 5966
continue
}
return in, nil
}
- var server string
- if a := c.RemoteAddr(); a != nil {
- server = a.String()
- }
- return nil, &DNSError{Err: "no answer from server", Name: name, Server: server, IsTimeout: true}
+ return nil, errors.New("no answer from DNS server")
}
// Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers).
-func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs []dnsRR, err error) {
+func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
if len(cfg.servers) == 0 {
return "", nil, &DNSError{Err: "no DNS servers", Name: name}
}
- for i := 0; i < len(cfg.servers); i++ {
- // Calling Dial here is scary -- we have to be sure
- // not to dial a name that will require a DNS lookup,
- // or Dial will call back here to translate it.
- // The DNS config parser has already checked that
- // all the cfg.servers[i] are IP addresses, which
- // Dial will use without a DNS lookup.
- server := cfg.servers[i] + ":53"
- c, cerr := Dial("udp", server)
- if cerr != nil {
- err = cerr
- continue
- }
- msg, merr := exchange(cfg, c, name, qtype)
- c.Close()
- if merr != nil {
- err = merr
- continue
+ if len(name) >= 256 {
+ return "", nil, &DNSError{Err: "DNS name too long", Name: name}
+ }
+ timeout := time.Duration(cfg.timeout) * time.Second
+ var lastErr error
+ for _, server := range cfg.servers {
+ server += ":53"
+ lastErr = &DNSError{
+ Err: "no answer from DNS server",
+ Name: name,
+ Server: server,
+ IsTimeout: true,
}
- if msg.truncated { // see RFC 5966
- c, cerr = Dial("tcp", server)
- if cerr != nil {
- err = cerr
- continue
+ for i := 0; i < cfg.attempts; i++ {
+ msg, err := exchange(server, name, qtype, timeout)
+ if err != nil {
+ if nerr, ok := err.(Error); ok && nerr.Timeout() {
+ lastErr = &DNSError{
+ Err: nerr.Error(),
+ Name: name,
+ Server: server,
+ IsTimeout: true,
+ }
+ continue
+
+ }
+ lastErr = &DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server,
+ }
+ break
}
- msg, merr = exchange(cfg, c, name, qtype)
- c.Close()
- if merr != nil {
- err = merr
- continue
+ cname, addrs, err := answer(name, server, msg, qtype)
+ if err == nil || err.(*DNSError).Err == noSuchHost {
+ return cname, addrs, err
}
- }
- cname, addrs, err = answer(name, server, msg, qtype)
- if err == nil || err.(*DNSError).Err == noSuchHost {
- break
+ lastErr = err
}
}
- return
+ return "", nil, lastErr
}
func convertRR_A(records []dnsRR) []IP {
diff --git a/src/pkg/net/dnsclient_unix_test.go b/src/pkg/net/dnsclient_unix_test.go
index 2350142d6..39d82d996 100644
--- a/src/pkg/net/dnsclient_unix_test.go
+++ b/src/pkg/net/dnsclient_unix_test.go
@@ -16,19 +16,79 @@ import (
"time"
)
-func TestTCPLookup(t *testing.T) {
+var dnsTransportFallbackTests = []struct {
+ server string
+ name string
+ qtype uint16
+ timeout int
+ rcode int
+}{
+ // Querying "com." with qtype=255 usually makes an answer
+ // which requires more than 512 bytes.
+ {"8.8.8.8:53", "com.", dnsTypeALL, 2, dnsRcodeSuccess},
+ {"8.8.4.4:53", "com.", dnsTypeALL, 4, dnsRcodeSuccess},
+}
+
+func TestDNSTransportFallback(t *testing.T) {
if testing.Short() || !*testExternal {
t.Skip("skipping test to avoid external network")
}
- c, err := Dial("tcp", "8.8.8.8:53")
- if err != nil {
- t.Fatalf("Dial failed: %v", err)
+
+ for _, tt := range dnsTransportFallbackTests {
+ timeout := time.Duration(tt.timeout) * time.Second
+ msg, err := exchange(tt.server, tt.name, tt.qtype, timeout)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ switch msg.rcode {
+ case tt.rcode, dnsRcodeServerFailure:
+ default:
+ t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode)
+ continue
+ }
}
- defer c.Close()
- cfg := &dnsConfig{timeout: 10, attempts: 3}
- _, err = exchange(cfg, c, "com.", dnsTypeALL)
- if err != nil {
- t.Fatalf("exchange failed: %v", err)
+}
+
+// See RFC 6761 for further information about the reserved, pseudo
+// domain names.
+var specialDomainNameTests = []struct {
+ name string
+ qtype uint16
+ rcode int
+}{
+ // Name resoltion APIs and libraries should not recongnize the
+ // followings as special.
+ {"1.0.168.192.in-addr.arpa.", dnsTypePTR, dnsRcodeNameError},
+ {"test.", dnsTypeALL, dnsRcodeNameError},
+ {"example.com.", dnsTypeALL, dnsRcodeSuccess},
+
+ // Name resoltion APIs and libraries should recongnize the
+ // followings as special and should not send any queries.
+ // Though, we test those names here for verifying nagative
+ // answers at DNS query-response interaction level.
+ {"localhost.", dnsTypeALL, dnsRcodeNameError},
+ {"invalid.", dnsTypeALL, dnsRcodeNameError},
+}
+
+func TestSpecialDomainName(t *testing.T) {
+ if testing.Short() || !*testExternal {
+ t.Skip("skipping test to avoid external network")
+ }
+
+ server := "8.8.8.8:53"
+ for _, tt := range specialDomainNameTests {
+ msg, err := exchange(server, tt.name, tt.qtype, 0)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ switch msg.rcode {
+ case tt.rcode, dnsRcodeServerFailure:
+ default:
+ t.Errorf("got %v from %v; want %v", msg.rcode, server, tt.rcode)
+ continue
+ }
}
}