summaryrefslogtreecommitdiff
path: root/libgo/go/net/dnsclient_unix.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/net/dnsclient_unix.go')
-rw-r--r--libgo/go/net/dnsclient_unix.go288
1 files changed, 163 insertions, 125 deletions
diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go
index 17188f0024..4dd4e16b0f 100644
--- a/libgo/go/net/dnsclient_unix.go
+++ b/libgo/go/net/dnsclient_unix.go
@@ -16,6 +16,7 @@
package net
import (
+ "context"
"errors"
"io"
"math/rand"
@@ -26,10 +27,10 @@ import (
// A dnsDialer provides dialing suitable for DNS queries.
type dnsDialer interface {
- dialDNS(string, string) (dnsConn, error)
+ dialDNS(ctx context.Context, network, addr string) (dnsConn, error)
}
-var testHookDNSDialer = func(d time.Duration) dnsDialer { return &Dialer{Timeout: d} }
+var testHookDNSDialer = func() dnsDialer { return &Dialer{} }
// A dnsConn represents a DNS transport endpoint.
type dnsConn interface {
@@ -37,46 +38,67 @@ type dnsConn interface {
SetDeadline(time.Time) error
- // readDNSResponse reads a DNS response message from the DNS
- // transport endpoint and returns the received DNS response
- // message.
- readDNSResponse() (*dnsMsg, error)
+ // dnsRoundTrip executes a single DNS transaction, returning a
+ // DNS response message for the provided DNS query message.
+ dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
+}
- // writeDNSQuery writes a DNS query message to the DNS
- // connection endpoint.
- writeDNSQuery(*dnsMsg) error
+func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
+ return dnsRoundTripUDP(c, query)
}
-func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
- b := make([]byte, 512) // see RFC 1035
- n, err := c.Read(b)
- if err != nil {
+// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
+// "UDP usage" transport mechanism. c should be a packet-oriented connection,
+// such as a *UDPConn.
+func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
+ b, ok := query.Pack()
+ if !ok {
+ return nil, errors.New("cannot marshal DNS message")
+ }
+ if _, err := c.Write(b); err != nil {
return nil, err
}
- msg := &dnsMsg{}
- if !msg.Unpack(b[:n]) {
- return nil, errors.New("cannot unmarshal DNS message")
+
+ b = make([]byte, 512) // see RFC 1035
+ for {
+ n, err := c.Read(b)
+ if err != nil {
+ return nil, err
+ }
+ resp := &dnsMsg{}
+ if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) {
+ // Ignore invalid responses as they may be malicious
+ // forgery attempts. Instead continue waiting until
+ // timeout. See golang.org/issue/13281.
+ continue
+ }
+ return resp, nil
}
- return msg, nil
}
-func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
- b, ok := msg.Pack()
+func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
+ return dnsRoundTripTCP(c, out)
+}
+
+// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
+// "TCP usage" transport mechanism. c should be a stream-oriented connection,
+// such as a *TCPConn.
+func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
+ b, ok := query.Pack()
if !ok {
- return errors.New("cannot marshal DNS message")
+ return nil, errors.New("cannot marshal DNS message")
}
+ l := len(b)
+ b = append([]byte{byte(l >> 8), byte(l)}, b...)
if _, err := c.Write(b); err != nil {
- return err
+ return nil, err
}
- return nil
-}
-func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
- b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
+ b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
if _, err := io.ReadFull(c, b[:2]); err != nil {
return nil, err
}
- l := int(b[0])<<8 | int(b[1])
+ l = int(b[0])<<8 | int(b[1])
if l > len(b) {
b = make([]byte, l)
}
@@ -84,27 +106,17 @@ func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
if err != nil {
return nil, err
}
- msg := &dnsMsg{}
- if !msg.Unpack(b[:n]) {
+ resp := &dnsMsg{}
+ if !resp.Unpack(b[:n]) {
return nil, errors.New("cannot unmarshal DNS message")
}
- return msg, nil
-}
-
-func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
- b, ok := msg.Pack()
- if !ok {
- return errors.New("cannot marshal DNS message")
- }
- l := uint16(len(b))
- b = append([]byte{byte(l >> 8), byte(l)}, b...)
- if _, err := c.Write(b); err != nil {
- return err
+ if !resp.IsResponseTo(query) {
+ return nil, errors.New("invalid DNS response")
}
- return nil
+ return resp, nil
}
-func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
+func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) {
switch network {
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
default:
@@ -113,11 +125,11 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
// Calling Dial here is scary -- we have to be sure not to
// dial a name that will require a DNS lookup, or Dial will
// call back here to translate it. The DNS config parser has
- // already checked that all the cfg.servers[i] are IP
+ // already checked that all the cfg.servers are IP
// addresses, which Dial will use without a DNS lookup.
- c, err := d.Dial(network, server)
+ c, err := d.DialContext(ctx, network, server)
if err != nil {
- return nil, err
+ return nil, mapErr(err)
}
switch network {
case "tcp", "tcp4", "tcp6":
@@ -129,8 +141,8 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
}
// exchange sends a query on the connection and hopes for a response.
-func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
- d := testHookDNSDialer(timeout)
+func exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
+ d := testHookDNSDialer()
out := dnsMsg{
dnsMsgHdr: dnsMsgHdr{
recursion_desired: true,
@@ -140,24 +152,24 @@ func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg
},
}
for _, network := range []string{"udp", "tcp"} {
- c, err := d.dialDNS(network, server)
+ // TODO(mdempsky): Refactor so defers from UDP-based
+ // exchanges happen before TCP-based exchange.
+
+ ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
+ defer cancel()
+
+ c, err := d.dialDNS(ctx, network, server)
if err != nil {
return nil, err
}
defer c.Close()
- if timeout > 0 {
- c.SetDeadline(time.Now().Add(timeout))
+ if d, ok := ctx.Deadline(); ok && !d.IsZero() {
+ c.SetDeadline(d)
}
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
- if err := c.writeDNSQuery(&out); err != nil {
- return nil, err
- }
- in, err := c.readDNSResponse()
+ in, err := c.dnsRoundTrip(&out)
if err != nil {
- return nil, err
- }
- if in.id != out.id {
- return nil, errors.New("DNS message ID mismatch")
+ return nil, mapErr(err)
}
if in.truncated { // see RFC 5966
continue
@@ -169,16 +181,16 @@ func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg
// Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers).
-func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
- if len(cfg.servers) == 0 {
- return "", nil, &DNSError{Err: "no DNS servers", Name: name}
- }
- timeout := time.Duration(cfg.timeout) * time.Second
+func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
var lastErr error
+ serverOffset := cfg.serverOffset()
+ sLen := uint32(len(cfg.servers))
+
for i := 0; i < cfg.attempts; i++ {
- for _, server := range cfg.servers {
- server = JoinHostPort(server, "53")
- msg, err := exchange(server, name, qtype, timeout)
+ for j := uint32(0); j < sLen; j++ {
+ server := cfg.servers[(serverOffset+j)%sLen]
+
+ msg, err := exchange(ctx, server, name, qtype, cfg.timeout)
if err != nil {
lastErr = &DNSError{
Err: err.Error(),
@@ -190,6 +202,12 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, err
}
continue
}
+ // libresolv continues to the next server when it receives
+ // an invalid referral response. See golang.org/issue/15434.
+ if msg.rcode == dnsRcodeSuccess && !msg.authoritative && !msg.recursion_available && len(msg.answer) == 0 && len(msg.extra) == 0 {
+ lastErr = &DNSError{Err: "lame referral", Name: name, Server: server}
+ continue
+ }
cname, rrs, err := answer(name, server, msg, qtype)
// If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError,
// it means the response in msg was not useful and trying another
@@ -229,7 +247,6 @@ type resolverConfig struct {
// time to recheck resolv.conf.
ch chan struct{} // guards lastChecked and modTime
lastChecked time.Time // last time resolv.conf was checked
- modTime time.Time // time of resolv.conf modification
mu sync.RWMutex // protects dnsConfig
dnsConfig *dnsConfig // parsed resolv.conf structure used in lookups
@@ -239,16 +256,12 @@ var resolvConf resolverConfig
// init initializes conf and is only called via conf.initOnce.
func (conf *resolverConfig) init() {
- // Set dnsConfig, modTime, and lastChecked so we don't parse
+ // Set dnsConfig and lastChecked so we don't parse
// resolv.conf twice the first time.
conf.dnsConfig = systemConf().resolv
if conf.dnsConfig == nil {
conf.dnsConfig = dnsReadConfig("/etc/resolv.conf")
}
-
- if fi, err := os.Stat("/etc/resolv.conf"); err == nil {
- conf.modTime = fi.ModTime()
- }
conf.lastChecked = time.Now()
// Prepare ch so that only one update of resolverConfig may
@@ -274,17 +287,12 @@ func (conf *resolverConfig) tryUpdate(name string) {
}
conf.lastChecked = now
+ var mtime time.Time
if fi, err := os.Stat(name); err == nil {
- if fi.ModTime().Equal(conf.modTime) {
- return
- }
- conf.modTime = fi.ModTime()
- } else {
- // If modTime wasn't set prior, assume nothing has changed.
- if conf.modTime.IsZero() {
- return
- }
- conf.modTime = time.Time{}
+ mtime = fi.ModTime()
+ }
+ if mtime.Equal(conf.dnsConfig.mtime) {
+ return
}
dnsConf := dnsReadConfig(name)
@@ -306,16 +314,21 @@ func (conf *resolverConfig) releaseSema() {
<-conf.ch
}
-func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
+func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
if !isDomainName(name) {
- return "", nil, &DNSError{Err: "invalid domain name", Name: name}
+ // We used to use "invalid domain name" as the error,
+ // but that is a detail of the specific lookup mechanism.
+ // Other lookups might allow broader name syntax
+ // (for example Multicast DNS allows UTF-8; see RFC 6762).
+ // For consistency with libc resolvers, report no such host.
+ return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name}
}
resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock()
conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock()
for _, fqdn := range conf.nameList(name) {
- cname, rrs, err = tryOneName(conf, fqdn, qtype)
+ cname, rrs, err = tryOneName(ctx, conf, fqdn, qtype)
if err == nil {
break
}
@@ -329,30 +342,57 @@ func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
return
}
+// avoidDNS reports whether this is a hostname for which we should not
+// use DNS. Currently this includes only .onion, per RFC 7686. See
+// golang.org/issue/13705. Does not cover .local names (RFC 6762),
+// see golang.org/issue/16739.
+func avoidDNS(name string) bool {
+ if name == "" {
+ return true
+ }
+ if name[len(name)-1] == '.' {
+ name = name[:len(name)-1]
+ }
+ return stringsHasSuffixFold(name, ".onion")
+}
+
// nameList returns a list of names for sequential DNS queries.
func (conf *dnsConfig) nameList(name string) []string {
+ if avoidDNS(name) {
+ return nil
+ }
+
+ // Check name length (see isDomainName).
+ l := len(name)
+ rooted := l > 0 && name[l-1] == '.'
+ if l > 254 || l == 254 && rooted {
+ return nil
+ }
+
// If name is rooted (trailing dot), try only that name.
- rooted := len(name) > 0 && name[len(name)-1] == '.'
if rooted {
return []string{name}
}
+
+ hasNdots := count(name, '.') >= conf.ndots
+ name += "."
+ l++
+
// Build list of search choices.
names := make([]string, 0, 1+len(conf.search))
// If name has enough dots, try unsuffixed first.
- if count(name, '.') >= conf.ndots {
- names = append(names, name+".")
+ if hasNdots {
+ names = append(names, name)
}
- // Try suffixes.
+ // Try suffixes that are not too long (see isDomainName).
for _, suffix := range conf.search {
- suffixed := name + "." + suffix
- if suffixed[len(suffixed)-1] != '.' {
- suffixed += "."
+ if l+len(suffix) <= 254 {
+ names = append(names, name+suffix)
}
- names = append(names, suffixed)
}
// Try unsuffixed, if not tried first above.
- if count(name, '.') < conf.ndots {
- names = append(names, name+".")
+ if !hasNdots {
+ names = append(names, name)
}
return names
}
@@ -392,11 +432,11 @@ func (o hostLookupOrder) String() string {
// Normally we let cgo use the C library resolver instead of
// depending on our lookup code, so that Go and C get the same
// answers.
-func goLookupHost(name string) (addrs []string, err error) {
- return goLookupHostOrder(name, hostLookupFilesDNS)
+func goLookupHost(ctx context.Context, name string) (addrs []string, err error) {
+ return goLookupHostOrder(ctx, name, hostLookupFilesDNS)
}
-func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err error) {
+func goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles {
// Use entries from /etc/hosts if they match.
addrs = lookupStaticHost(name)
@@ -404,7 +444,7 @@ func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err
return
}
}
- ips, err := goLookupIPOrder(name, order)
+ ips, _, err := goLookupIPCNAMEOrder(ctx, name, order)
if err != nil {
return
}
@@ -430,27 +470,30 @@ func goLookupIPFiles(name string) (addrs []IPAddr) {
// goLookupIP is the native Go implementation of LookupIP.
// The libc versions are in cgo_*.go.
-func goLookupIP(name string) (addrs []IPAddr, err error) {
- return goLookupIPOrder(name, hostLookupFilesDNS)
+func goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+ order := systemConf().hostLookupOrder(host)
+ addrs, _, err = goLookupIPCNAMEOrder(ctx, host, order)
+ return
}
-func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err error) {
+func goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname string, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles {
addrs = goLookupIPFiles(name)
if len(addrs) > 0 || order == hostLookupFiles {
- return addrs, nil
+ return addrs, name, nil
}
}
if !isDomainName(name) {
- return nil, &DNSError{Err: "invalid domain name", Name: name}
+ // See comment in func lookup above about use of errNoSuchHost.
+ return nil, "", &DNSError{Err: errNoSuchHost.Error(), Name: name}
}
resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock()
conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock()
type racer struct {
- fqdn string
- rrs []dnsRR
+ cname string
+ rrs []dnsRR
error
}
lane := make(chan racer, 1)
@@ -459,20 +502,23 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er
for _, fqdn := range conf.nameList(name) {
for _, qtype := range qtypes {
go func(qtype uint16) {
- _, rrs, err := tryOneName(conf, fqdn, qtype)
- lane <- racer{fqdn, rrs, err}
+ cname, rrs, err := tryOneName(ctx, conf, fqdn, qtype)
+ lane <- racer{cname, rrs, err}
}(qtype)
}
for range qtypes {
racer := <-lane
if racer.error != nil {
// Prefer error for original name.
- if lastErr == nil || racer.fqdn == name+"." {
+ if lastErr == nil || fqdn == name+"." {
lastErr = racer.error
}
continue
}
addrs = append(addrs, addrRecordList(racer.rrs)...)
+ if cname == "" {
+ cname = racer.cname
+ }
}
if len(addrs) > 0 {
break
@@ -490,24 +536,16 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er
addrs = goLookupIPFiles(name)
}
if len(addrs) == 0 && lastErr != nil {
- return nil, lastErr
+ return nil, "", lastErr
}
}
- return addrs, nil
+ return addrs, cname, nil
}
-// goLookupCNAME is the native Go implementation of LookupCNAME.
-// Used only if cgoLookupCNAME refuses to handle the request
-// (that is, only if cgoLookupCNAME is the stub in cgo_stub.go).
-// Normally we let cgo use the C library resolver instead of
-// depending on our lookup code, so that Go and C get the same
-// answers.
-func goLookupCNAME(name string) (cname string, err error) {
- _, rrs, err := lookup(name, dnsTypeCNAME)
- if err != nil {
- return
- }
- cname = rrs[0].(*dnsRR_CNAME).Cname
+// goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME.
+func goLookupCNAME(ctx context.Context, host string) (cname string, err error) {
+ order := systemConf().hostLookupOrder(host)
+ _, cname, err = goLookupIPCNAMEOrder(ctx, host, order)
return
}
@@ -516,7 +554,7 @@ func goLookupCNAME(name string) (cname string, err error) {
// only if cgoLookupPTR is the stub in cgo_stub.go).
// Normally we let cgo use the C library resolver instead of depending
// on our lookup code, so that Go and C get the same answers.
-func goLookupPTR(addr string) ([]string, error) {
+func goLookupPTR(ctx context.Context, addr string) ([]string, error) {
names := lookupStaticAddr(addr)
if len(names) > 0 {
return names, nil
@@ -525,7 +563,7 @@ func goLookupPTR(addr string) ([]string, error) {
if err != nil {
return nil, err
}
- _, rrs, err := lookup(arpa, dnsTypePTR)
+ _, rrs, err := lookup(ctx, arpa, dnsTypePTR)
if err != nil {
return nil, err
}