diff options
Diffstat (limited to 'libgo/go/net/dnsclient_unix.go')
-rw-r--r-- | libgo/go/net/dnsclient_unix.go | 288 |
1 files changed, 163 insertions, 125 deletions
diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go index 17188f0024..4dd4e16b0f 100644 --- a/libgo/go/net/dnsclient_unix.go +++ b/libgo/go/net/dnsclient_unix.go @@ -16,6 +16,7 @@ package net import ( + "context" "errors" "io" "math/rand" @@ -26,10 +27,10 @@ import ( // A dnsDialer provides dialing suitable for DNS queries. type dnsDialer interface { - dialDNS(string, string) (dnsConn, error) + dialDNS(ctx context.Context, network, addr string) (dnsConn, error) } -var testHookDNSDialer = func(d time.Duration) dnsDialer { return &Dialer{Timeout: d} } +var testHookDNSDialer = func() dnsDialer { return &Dialer{} } // A dnsConn represents a DNS transport endpoint. type dnsConn interface { @@ -37,46 +38,67 @@ type dnsConn interface { SetDeadline(time.Time) error - // readDNSResponse reads a DNS response message from the DNS - // transport endpoint and returns the received DNS response - // message. - readDNSResponse() (*dnsMsg, error) + // dnsRoundTrip executes a single DNS transaction, returning a + // DNS response message for the provided DNS query message. + dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) +} - // writeDNSQuery writes a DNS query message to the DNS - // connection endpoint. - writeDNSQuery(*dnsMsg) error +func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { + return dnsRoundTripUDP(c, query) } -func (c *UDPConn) readDNSResponse() (*dnsMsg, error) { - b := make([]byte, 512) // see RFC 1035 - n, err := c.Read(b) - if err != nil { +// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's +// "UDP usage" transport mechanism. c should be a packet-oriented connection, +// such as a *UDPConn. +func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { + b, ok := query.Pack() + if !ok { + return nil, errors.New("cannot marshal DNS message") + } + if _, err := c.Write(b); err != nil { return nil, err } - msg := &dnsMsg{} - if !msg.Unpack(b[:n]) { - return nil, errors.New("cannot unmarshal DNS message") + + b = make([]byte, 512) // see RFC 1035 + for { + n, err := c.Read(b) + if err != nil { + return nil, err + } + resp := &dnsMsg{} + if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) { + // Ignore invalid responses as they may be malicious + // forgery attempts. Instead continue waiting until + // timeout. See golang.org/issue/13281. + continue + } + return resp, nil } - return msg, nil } -func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error { - b, ok := msg.Pack() +func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) { + return dnsRoundTripTCP(c, out) +} + +// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's +// "TCP usage" transport mechanism. c should be a stream-oriented connection, +// such as a *TCPConn. +func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { + b, ok := query.Pack() if !ok { - return errors.New("cannot marshal DNS message") + return nil, errors.New("cannot marshal DNS message") } + l := len(b) + b = append([]byte{byte(l >> 8), byte(l)}, b...) if _, err := c.Write(b); err != nil { - return err + return nil, err } - return nil -} -func (c *TCPConn) readDNSResponse() (*dnsMsg, error) { - b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035 + b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035 if _, err := io.ReadFull(c, b[:2]); err != nil { return nil, err } - l := int(b[0])<<8 | int(b[1]) + l = int(b[0])<<8 | int(b[1]) if l > len(b) { b = make([]byte, l) } @@ -84,27 +106,17 @@ func (c *TCPConn) readDNSResponse() (*dnsMsg, error) { if err != nil { return nil, err } - msg := &dnsMsg{} - if !msg.Unpack(b[:n]) { + resp := &dnsMsg{} + if !resp.Unpack(b[:n]) { return nil, errors.New("cannot unmarshal DNS message") } - return msg, nil -} - -func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error { - b, ok := msg.Pack() - if !ok { - return errors.New("cannot marshal DNS message") - } - l := uint16(len(b)) - b = append([]byte{byte(l >> 8), byte(l)}, b...) - if _, err := c.Write(b); err != nil { - return err + if !resp.IsResponseTo(query) { + return nil, errors.New("invalid DNS response") } - return nil + return resp, nil } -func (d *Dialer) dialDNS(network, server string) (dnsConn, error) { +func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) { switch network { case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": default: @@ -113,11 +125,11 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) { // Calling Dial here is scary -- we have to be sure not to // dial a name that will require a DNS lookup, or Dial will // call back here to translate it. The DNS config parser has - // already checked that all the cfg.servers[i] are IP + // already checked that all the cfg.servers are IP // addresses, which Dial will use without a DNS lookup. - c, err := d.Dial(network, server) + c, err := d.DialContext(ctx, network, server) if err != nil { - return nil, err + return nil, mapErr(err) } switch network { case "tcp", "tcp4", "tcp6": @@ -129,8 +141,8 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) { } // exchange sends a query on the connection and hopes for a response. -func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) { - d := testHookDNSDialer(timeout) +func exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) { + d := testHookDNSDialer() out := dnsMsg{ dnsMsgHdr: dnsMsgHdr{ recursion_desired: true, @@ -140,24 +152,24 @@ func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg }, } for _, network := range []string{"udp", "tcp"} { - c, err := d.dialDNS(network, server) + // TODO(mdempsky): Refactor so defers from UDP-based + // exchanges happen before TCP-based exchange. + + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + + c, err := d.dialDNS(ctx, network, server) if err != nil { return nil, err } defer c.Close() - if timeout > 0 { - c.SetDeadline(time.Now().Add(timeout)) + if d, ok := ctx.Deadline(); ok && !d.IsZero() { + c.SetDeadline(d) } out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano()) - if err := c.writeDNSQuery(&out); err != nil { - return nil, err - } - in, err := c.readDNSResponse() + in, err := c.dnsRoundTrip(&out) if err != nil { - return nil, err - } - if in.id != out.id { - return nil, errors.New("DNS message ID mismatch") + return nil, mapErr(err) } if in.truncated { // see RFC 5966 continue @@ -169,16 +181,16 @@ func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg // Do a lookup for a single name, which must be rooted // (otherwise answer will not find the answers). -func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { - if len(cfg.servers) == 0 { - return "", nil, &DNSError{Err: "no DNS servers", Name: name} - } - timeout := time.Duration(cfg.timeout) * time.Second +func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { var lastErr error + serverOffset := cfg.serverOffset() + sLen := uint32(len(cfg.servers)) + for i := 0; i < cfg.attempts; i++ { - for _, server := range cfg.servers { - server = JoinHostPort(server, "53") - msg, err := exchange(server, name, qtype, timeout) + for j := uint32(0); j < sLen; j++ { + server := cfg.servers[(serverOffset+j)%sLen] + + msg, err := exchange(ctx, server, name, qtype, cfg.timeout) if err != nil { lastErr = &DNSError{ Err: err.Error(), @@ -190,6 +202,12 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, err } continue } + // libresolv continues to the next server when it receives + // an invalid referral response. See golang.org/issue/15434. + if msg.rcode == dnsRcodeSuccess && !msg.authoritative && !msg.recursion_available && len(msg.answer) == 0 && len(msg.extra) == 0 { + lastErr = &DNSError{Err: "lame referral", Name: name, Server: server} + continue + } cname, rrs, err := answer(name, server, msg, qtype) // If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError, // it means the response in msg was not useful and trying another @@ -229,7 +247,6 @@ type resolverConfig struct { // time to recheck resolv.conf. ch chan struct{} // guards lastChecked and modTime lastChecked time.Time // last time resolv.conf was checked - modTime time.Time // time of resolv.conf modification mu sync.RWMutex // protects dnsConfig dnsConfig *dnsConfig // parsed resolv.conf structure used in lookups @@ -239,16 +256,12 @@ var resolvConf resolverConfig // init initializes conf and is only called via conf.initOnce. func (conf *resolverConfig) init() { - // Set dnsConfig, modTime, and lastChecked so we don't parse + // Set dnsConfig and lastChecked so we don't parse // resolv.conf twice the first time. conf.dnsConfig = systemConf().resolv if conf.dnsConfig == nil { conf.dnsConfig = dnsReadConfig("/etc/resolv.conf") } - - if fi, err := os.Stat("/etc/resolv.conf"); err == nil { - conf.modTime = fi.ModTime() - } conf.lastChecked = time.Now() // Prepare ch so that only one update of resolverConfig may @@ -274,17 +287,12 @@ func (conf *resolverConfig) tryUpdate(name string) { } conf.lastChecked = now + var mtime time.Time if fi, err := os.Stat(name); err == nil { - if fi.ModTime().Equal(conf.modTime) { - return - } - conf.modTime = fi.ModTime() - } else { - // If modTime wasn't set prior, assume nothing has changed. - if conf.modTime.IsZero() { - return - } - conf.modTime = time.Time{} + mtime = fi.ModTime() + } + if mtime.Equal(conf.dnsConfig.mtime) { + return } dnsConf := dnsReadConfig(name) @@ -306,16 +314,21 @@ func (conf *resolverConfig) releaseSema() { <-conf.ch } -func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) { +func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) { if !isDomainName(name) { - return "", nil, &DNSError{Err: "invalid domain name", Name: name} + // We used to use "invalid domain name" as the error, + // but that is a detail of the specific lookup mechanism. + // Other lookups might allow broader name syntax + // (for example Multicast DNS allows UTF-8; see RFC 6762). + // For consistency with libc resolvers, report no such host. + return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name} } resolvConf.tryUpdate("/etc/resolv.conf") resolvConf.mu.RLock() conf := resolvConf.dnsConfig resolvConf.mu.RUnlock() for _, fqdn := range conf.nameList(name) { - cname, rrs, err = tryOneName(conf, fqdn, qtype) + cname, rrs, err = tryOneName(ctx, conf, fqdn, qtype) if err == nil { break } @@ -329,30 +342,57 @@ func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) { return } +// avoidDNS reports whether this is a hostname for which we should not +// use DNS. Currently this includes only .onion, per RFC 7686. See +// golang.org/issue/13705. Does not cover .local names (RFC 6762), +// see golang.org/issue/16739. +func avoidDNS(name string) bool { + if name == "" { + return true + } + if name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return stringsHasSuffixFold(name, ".onion") +} + // nameList returns a list of names for sequential DNS queries. func (conf *dnsConfig) nameList(name string) []string { + if avoidDNS(name) { + return nil + } + + // Check name length (see isDomainName). + l := len(name) + rooted := l > 0 && name[l-1] == '.' + if l > 254 || l == 254 && rooted { + return nil + } + // If name is rooted (trailing dot), try only that name. - rooted := len(name) > 0 && name[len(name)-1] == '.' if rooted { return []string{name} } + + hasNdots := count(name, '.') >= conf.ndots + name += "." + l++ + // Build list of search choices. names := make([]string, 0, 1+len(conf.search)) // If name has enough dots, try unsuffixed first. - if count(name, '.') >= conf.ndots { - names = append(names, name+".") + if hasNdots { + names = append(names, name) } - // Try suffixes. + // Try suffixes that are not too long (see isDomainName). for _, suffix := range conf.search { - suffixed := name + "." + suffix - if suffixed[len(suffixed)-1] != '.' { - suffixed += "." + if l+len(suffix) <= 254 { + names = append(names, name+suffix) } - names = append(names, suffixed) } // Try unsuffixed, if not tried first above. - if count(name, '.') < conf.ndots { - names = append(names, name+".") + if !hasNdots { + names = append(names, name) } return names } @@ -392,11 +432,11 @@ func (o hostLookupOrder) String() string { // Normally we let cgo use the C library resolver instead of // depending on our lookup code, so that Go and C get the same // answers. -func goLookupHost(name string) (addrs []string, err error) { - return goLookupHostOrder(name, hostLookupFilesDNS) +func goLookupHost(ctx context.Context, name string) (addrs []string, err error) { + return goLookupHostOrder(ctx, name, hostLookupFilesDNS) } -func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err error) { +func goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) { if order == hostLookupFilesDNS || order == hostLookupFiles { // Use entries from /etc/hosts if they match. addrs = lookupStaticHost(name) @@ -404,7 +444,7 @@ func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err return } } - ips, err := goLookupIPOrder(name, order) + ips, _, err := goLookupIPCNAMEOrder(ctx, name, order) if err != nil { return } @@ -430,27 +470,30 @@ func goLookupIPFiles(name string) (addrs []IPAddr) { // goLookupIP is the native Go implementation of LookupIP. // The libc versions are in cgo_*.go. -func goLookupIP(name string) (addrs []IPAddr, err error) { - return goLookupIPOrder(name, hostLookupFilesDNS) +func goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { + order := systemConf().hostLookupOrder(host) + addrs, _, err = goLookupIPCNAMEOrder(ctx, host, order) + return } -func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err error) { +func goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname string, err error) { if order == hostLookupFilesDNS || order == hostLookupFiles { addrs = goLookupIPFiles(name) if len(addrs) > 0 || order == hostLookupFiles { - return addrs, nil + return addrs, name, nil } } if !isDomainName(name) { - return nil, &DNSError{Err: "invalid domain name", Name: name} + // See comment in func lookup above about use of errNoSuchHost. + return nil, "", &DNSError{Err: errNoSuchHost.Error(), Name: name} } resolvConf.tryUpdate("/etc/resolv.conf") resolvConf.mu.RLock() conf := resolvConf.dnsConfig resolvConf.mu.RUnlock() type racer struct { - fqdn string - rrs []dnsRR + cname string + rrs []dnsRR error } lane := make(chan racer, 1) @@ -459,20 +502,23 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er for _, fqdn := range conf.nameList(name) { for _, qtype := range qtypes { go func(qtype uint16) { - _, rrs, err := tryOneName(conf, fqdn, qtype) - lane <- racer{fqdn, rrs, err} + cname, rrs, err := tryOneName(ctx, conf, fqdn, qtype) + lane <- racer{cname, rrs, err} }(qtype) } for range qtypes { racer := <-lane if racer.error != nil { // Prefer error for original name. - if lastErr == nil || racer.fqdn == name+"." { + if lastErr == nil || fqdn == name+"." { lastErr = racer.error } continue } addrs = append(addrs, addrRecordList(racer.rrs)...) + if cname == "" { + cname = racer.cname + } } if len(addrs) > 0 { break @@ -490,24 +536,16 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er addrs = goLookupIPFiles(name) } if len(addrs) == 0 && lastErr != nil { - return nil, lastErr + return nil, "", lastErr } } - return addrs, nil + return addrs, cname, nil } -// goLookupCNAME is the native Go implementation of LookupCNAME. -// Used only if cgoLookupCNAME refuses to handle the request -// (that is, only if cgoLookupCNAME is the stub in cgo_stub.go). -// Normally we let cgo use the C library resolver instead of -// depending on our lookup code, so that Go and C get the same -// answers. -func goLookupCNAME(name string) (cname string, err error) { - _, rrs, err := lookup(name, dnsTypeCNAME) - if err != nil { - return - } - cname = rrs[0].(*dnsRR_CNAME).Cname +// goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME. +func goLookupCNAME(ctx context.Context, host string) (cname string, err error) { + order := systemConf().hostLookupOrder(host) + _, cname, err = goLookupIPCNAMEOrder(ctx, host, order) return } @@ -516,7 +554,7 @@ func goLookupCNAME(name string) (cname string, err error) { // only if cgoLookupPTR is the stub in cgo_stub.go). // Normally we let cgo use the C library resolver instead of depending // on our lookup code, so that Go and C get the same answers. -func goLookupPTR(addr string) ([]string, error) { +func goLookupPTR(ctx context.Context, addr string) ([]string, error) { names := lookupStaticAddr(addr) if len(names) > 0 { return names, nil @@ -525,7 +563,7 @@ func goLookupPTR(addr string) ([]string, error) { if err != nil { return nil, err } - _, rrs, err := lookup(arpa, dnsTypePTR) + _, rrs, err := lookup(ctx, arpa, dnsTypePTR) if err != nil { return nil, err } |