summaryrefslogtreecommitdiff
path: root/libgo/go/net/dial.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/net/dial.go')
-rw-r--r--libgo/go/net/dial.go398
1 files changed, 277 insertions, 121 deletions
diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go
index 193776fe41..50bba5a49e 100644
--- a/libgo/go/net/dial.go
+++ b/libgo/go/net/dial.go
@@ -1,11 +1,12 @@
-// Copyright 2010 The Go Authors. All rights reserved.
+// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
- "errors"
+ "context"
+ "internal/nettrace"
"time"
)
@@ -58,24 +59,47 @@ type Dialer struct {
// that do not support keep-alives ignore this field.
KeepAlive time.Duration
+ // Resolver optionally specifies an alternate resolver to use.
+ Resolver *Resolver
+
// Cancel is an optional channel whose closure indicates that
// the dial should be canceled. Not all types of dials support
// cancelation.
+ //
+ // Deprecated: Use DialContext instead.
Cancel <-chan struct{}
}
-// Return either now+Timeout or Deadline, whichever comes first.
-// Or zero, if neither is set.
-func (d *Dialer) deadline(now time.Time) time.Time {
- if d.Timeout == 0 {
- return d.Deadline
+func minNonzeroTime(a, b time.Time) time.Time {
+ if a.IsZero() {
+ return b
}
- timeoutDeadline := now.Add(d.Timeout)
- if d.Deadline.IsZero() || timeoutDeadline.Before(d.Deadline) {
- return timeoutDeadline
- } else {
- return d.Deadline
+ if b.IsZero() || a.Before(b) {
+ return a
+ }
+ return b
+}
+
+// deadline returns the earliest of:
+// - now+Timeout
+// - d.Deadline
+// - the context's deadline
+// Or zero, if none of Timeout, Deadline, or context's deadline is set.
+func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
+ if d.Timeout != 0 { // including negative, for historical reasons
+ earliest = now.Add(d.Timeout)
+ }
+ if d, ok := ctx.Deadline(); ok {
+ earliest = minNonzeroTime(earliest, d)
+ }
+ return minNonzeroTime(earliest, d.Deadline)
+}
+
+func (d *Dialer) resolver() *Resolver {
+ if d.Resolver != nil {
+ return d.Resolver
}
+ return DefaultResolver
}
// partialDeadline returns the deadline to use for a single address,
@@ -110,7 +134,7 @@ func (d *Dialer) fallbackDelay() time.Duration {
}
}
-func parseNetwork(net string) (afnet string, proto int, err error) {
+func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err error) {
i := last(net, ':')
if i < 0 { // no colon
switch net {
@@ -127,9 +151,9 @@ func parseNetwork(net string) (afnet string, proto int, err error) {
switch afnet {
case "ip", "ip4", "ip6":
protostr := net[i+1:]
- proto, i, ok := dtoi(protostr, 0)
+ proto, i, ok := dtoi(protostr)
if !ok || i != len(protostr) {
- proto, err = lookupProtocol(protostr)
+ proto, err = lookupProtocol(ctx, protostr)
if err != nil {
return "", 0, err
}
@@ -139,8 +163,11 @@ func parseNetwork(net string) (afnet string, proto int, err error) {
return "", 0, UnknownNetworkError(net)
}
-func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error) {
- afnet, _, err := parseNetwork(net)
+// resolveAddrList resolves addr using hint and returns a list of
+// addresses. The result contains at least one address when error is
+// nil.
+func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
+ afnet, _, err := parseNetwork(ctx, network)
if err != nil {
return nil, err
}
@@ -153,9 +180,59 @@ func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error)
if err != nil {
return nil, err
}
+ if op == "dial" && hint != nil && addr.Network() != hint.Network() {
+ return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
+ }
return addrList{addr}, nil
}
- return internetAddrList(afnet, addr, deadline)
+ addrs, err := r.internetAddrList(ctx, afnet, addr)
+ if err != nil || op != "dial" || hint == nil {
+ return addrs, err
+ }
+ var (
+ tcp *TCPAddr
+ udp *UDPAddr
+ ip *IPAddr
+ wildcard bool
+ )
+ switch hint := hint.(type) {
+ case *TCPAddr:
+ tcp = hint
+ wildcard = tcp.isWildcard()
+ case *UDPAddr:
+ udp = hint
+ wildcard = udp.isWildcard()
+ case *IPAddr:
+ ip = hint
+ wildcard = ip.isWildcard()
+ }
+ naddrs := addrs[:0]
+ for _, addr := range addrs {
+ if addr.Network() != hint.Network() {
+ return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
+ }
+ switch addr := addr.(type) {
+ case *TCPAddr:
+ if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
+ continue
+ }
+ naddrs = append(naddrs, addr)
+ case *UDPAddr:
+ if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
+ continue
+ }
+ naddrs = append(naddrs, addr)
+ case *IPAddr:
+ if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
+ continue
+ }
+ naddrs = append(naddrs, addr)
+ }
+ }
+ if len(naddrs) == 0 {
+ return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: hint.String()}
+ }
+ return naddrs, nil
}
// Dial connects to the address on the named network.
@@ -173,8 +250,8 @@ func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error)
// If the host is empty, as in ":80", the local system is assumed.
//
// Examples:
-// Dial("tcp", "12.34.56.78:80")
-// Dial("tcp", "google.com:http")
+// Dial("tcp", "192.0.2.1:80")
+// Dial("tcp", "golang.org:http")
// Dial("tcp", "[2001:db8::1]:http")
// Dial("tcp", "[fe80::1%lo0]:80")
// Dial("tcp", ":80")
@@ -184,10 +261,13 @@ func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error)
// literal IP address.
//
// Examples:
-// Dial("ip4:1", "127.0.0.1")
-// Dial("ip6:ospf", "::1")
+// Dial("ip4:1", "192.0.2.1")
+// Dial("ip6:ipv6-icmp", "2001:db8::1")
//
// For Unix networks, the address must be a file system path.
+//
+// If the host is resolved to multiple addresses,
+// Dial will try each address in order until one succeeds.
func Dial(network, address string) (Conn, error) {
var d Dialer
return d.Dial(network, address)
@@ -200,11 +280,10 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
return d.Dial(network, address)
}
-// dialContext holds common state for all dial operations.
-type dialContext struct {
+// dialParam contains a Dial's parameters and configuration.
+type dialParam struct {
Dialer
network, address string
- finalDeadline time.Time
}
// Dial connects to the address on the named network.
@@ -212,17 +291,70 @@ type dialContext struct {
// See func Dial for a description of the network and address
// parameters.
func (d *Dialer) Dial(network, address string) (Conn, error) {
- finalDeadline := d.deadline(time.Now())
- addrs, err := resolveAddrList("dial", network, address, finalDeadline)
+ return d.DialContext(context.Background(), network, address)
+}
+
+// DialContext connects to the address on the named network using
+// the provided context.
+//
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+//
+// When using TCP, and the host in the address parameter resolves to multiple
+// network addresses, any dial timeout (from d.Timeout or ctx) is spread
+// over each consecutive dial, such that each is given an appropriate
+// fraction of the time to connect.
+// For example, if a host has 4 IP addresses and the timeout is 1 minute,
+// the connect to each single address will be given 15 seconds to complete
+// before trying the next one.
+//
+// See func Dial for a description of the network and address
+// parameters.
+func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
+ if ctx == nil {
+ panic("nil context")
+ }
+ deadline := d.deadline(ctx, time.Now())
+ if !deadline.IsZero() {
+ if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
+ subCtx, cancel := context.WithDeadline(ctx, deadline)
+ defer cancel()
+ ctx = subCtx
+ }
+ }
+ if oldCancel := d.Cancel; oldCancel != nil {
+ subCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ go func() {
+ select {
+ case <-oldCancel:
+ cancel()
+ case <-subCtx.Done():
+ }
+ }()
+ ctx = subCtx
+ }
+
+ // Shadow the nettrace (if any) during resolve so Connect events don't fire for DNS lookups.
+ resolveCtx := ctx
+ if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil {
+ shadow := *trace
+ shadow.ConnectStart = nil
+ shadow.ConnectDone = nil
+ resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
+ }
+
+ addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
}
- ctx := &dialContext{
- Dialer: *d,
- network: network,
- address: address,
- finalDeadline: finalDeadline,
+ dp := &dialParam{
+ Dialer: *d,
+ network: network,
+ address: address,
}
var primaries, fallbacks addrList
@@ -233,116 +365,128 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
}
var c Conn
- if len(fallbacks) == 0 {
- // dialParallel can accept an empty fallbacks list,
- // but this shortcut avoids the goroutine/channel overhead.
- c, err = dialSerial(ctx, primaries, nil)
+ if len(fallbacks) > 0 {
+ c, err = dialParallel(ctx, dp, primaries, fallbacks)
} else {
- c, err = dialParallel(ctx, primaries, fallbacks)
+ c, err = dialSerial(ctx, dp, primaries)
+ }
+ if err != nil {
+ return nil, err
}
- if d.KeepAlive > 0 && err == nil {
- if tc, ok := c.(*TCPConn); ok {
- setKeepAlive(tc.fd, true)
- setKeepAlivePeriod(tc.fd, d.KeepAlive)
- testHookSetKeepAlive()
- }
+ if tc, ok := c.(*TCPConn); ok && d.KeepAlive > 0 {
+ setKeepAlive(tc.fd, true)
+ setKeepAlivePeriod(tc.fd, d.KeepAlive)
+ testHookSetKeepAlive()
}
- return c, err
+ return c, nil
}
// dialParallel races two copies of dialSerial, giving the first a
// head start. It returns the first established connection and
// closes the others. Otherwise it returns an error from the first
// primary address.
-func dialParallel(ctx *dialContext, primaries, fallbacks addrList) (Conn, error) {
- results := make(chan dialResult) // unbuffered, so dialSerialAsync can detect race loss & cleanup
- cancel := make(chan struct{})
- defer close(cancel)
-
- // Spawn the primary racer.
- go dialSerialAsync(ctx, primaries, nil, cancel, results)
-
- // Spawn the fallback racer.
- fallbackTimer := time.NewTimer(ctx.fallbackDelay())
- go dialSerialAsync(ctx, fallbacks, fallbackTimer, cancel, results)
-
- var primaryErr error
- for nracers := 2; nracers > 0; nracers-- {
- res := <-results
- // If we're still waiting for a connection, then hasten the delay.
- // Otherwise, disable the Timer and let cancel take over.
- if fallbackTimer.Stop() && res.error != nil {
- fallbackTimer.Reset(0)
- }
- if res.error == nil {
- return res.Conn, nil
- }
- if res.primary {
- primaryErr = res.error
- }
+func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrList) (Conn, error) {
+ if len(fallbacks) == 0 {
+ return dialSerial(ctx, dp, primaries)
}
- return nil, primaryErr
-}
-type dialResult struct {
- Conn
- error
- primary bool
-}
+ returned := make(chan struct{})
+ defer close(returned)
+
+ type dialResult struct {
+ Conn
+ error
+ primary bool
+ done bool
+ }
+ results := make(chan dialResult) // unbuffered
-// dialSerialAsync runs dialSerial after some delay, and returns the
-// resulting connection through a channel. When racing two connections,
-// the primary goroutine uses a nil timer to omit the delay.
-func dialSerialAsync(ctx *dialContext, ras addrList, timer *time.Timer, cancel <-chan struct{}, results chan<- dialResult) {
- if timer != nil {
- // We're in the fallback goroutine; sleep before connecting.
+ startRacer := func(ctx context.Context, primary bool) {
+ ras := primaries
+ if !primary {
+ ras = fallbacks
+ }
+ c, err := dialSerial(ctx, dp, ras)
select {
- case <-timer.C:
- case <-cancel:
- return
+ case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
+ case <-returned:
+ if c != nil {
+ c.Close()
+ }
}
}
- c, err := dialSerial(ctx, ras, cancel)
- select {
- case results <- dialResult{c, err, timer == nil}:
- // We won the race.
- case <-cancel:
- // The other goroutine won the race.
- if c != nil {
- c.Close()
+
+ var primary, fallback dialResult
+
+ // Start the main racer.
+ primaryCtx, primaryCancel := context.WithCancel(ctx)
+ defer primaryCancel()
+ go startRacer(primaryCtx, true)
+
+ // Start the timer for the fallback racer.
+ fallbackTimer := time.NewTimer(dp.fallbackDelay())
+ defer fallbackTimer.Stop()
+
+ for {
+ select {
+ case <-fallbackTimer.C:
+ fallbackCtx, fallbackCancel := context.WithCancel(ctx)
+ defer fallbackCancel()
+ go startRacer(fallbackCtx, false)
+
+ case res := <-results:
+ if res.error == nil {
+ return res.Conn, nil
+ }
+ if res.primary {
+ primary = res
+ } else {
+ fallback = res
+ }
+ if primary.done && fallback.done {
+ return nil, primary.error
+ }
+ if res.primary && fallbackTimer.Stop() {
+ // If we were able to stop the timer, that means it
+ // was running (hadn't yet started the fallback), but
+ // we just got an error on the primary path, so start
+ // the fallback immediately (in 0 nanoseconds).
+ fallbackTimer.Reset(0)
+ }
}
}
}
// dialSerial connects to a list of addresses in sequence, returning
// either the first successful connection, or the first error.
-func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, error) {
+func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error) {
var firstErr error // The error from the first address is most relevant.
for i, ra := range ras {
select {
- case <-cancel:
- return nil, &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: errCanceled}
+ case <-ctx.Done():
+ return nil, &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
default:
}
- partialDeadline, err := partialDeadline(time.Now(), ctx.finalDeadline, len(ras)-i)
+ deadline, _ := ctx.Deadline()
+ partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
if err != nil {
// Ran out of time.
if firstErr == nil {
- firstErr = &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: err}
+ firstErr = &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: err}
}
break
}
-
- // dialTCP does not support cancelation (see golang.org/issue/11225),
- // so if cancel fires, we'll continue trying to connect until the next
- // timeout, or return a spurious connection for the caller to close.
- dialer := func(d time.Time) (Conn, error) {
- return dialSingle(ctx, ra, d)
+ dialCtx := ctx
+ if partialDeadline.Before(deadline) {
+ var cancel context.CancelFunc
+ dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
+ defer cancel()
}
- c, err := dial(ctx.network, ra, dialer, partialDeadline)
+
+ c, err := dialSingle(dialCtx, dp, ra)
if err == nil {
return c, nil
}
@@ -352,37 +496,43 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e
}
if firstErr == nil {
- firstErr = &OpError{Op: "dial", Net: ctx.network, Source: nil, Addr: nil, Err: errMissingAddress}
+ firstErr = &OpError{Op: "dial", Net: dp.network, Source: nil, Addr: nil, Err: errMissingAddress}
}
return nil, firstErr
}
// dialSingle attempts to establish and returns a single connection to
-// the destination address. This must be called through the OS-specific
-// dial function, because some OSes don't implement the deadline feature.
-func dialSingle(ctx *dialContext, ra Addr, deadline time.Time) (c Conn, err error) {
- la := ctx.LocalAddr
- if la != nil && la.Network() != ra.Network() {
- return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: errors.New("mismatched local address type " + la.Network())}
+// the destination address.
+func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) {
+ trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
+ if trace != nil {
+ raStr := ra.String()
+ if trace.ConnectStart != nil {
+ trace.ConnectStart(dp.network, raStr)
+ }
+ if trace.ConnectDone != nil {
+ defer func() { trace.ConnectDone(dp.network, raStr, err) }()
+ }
}
+ la := dp.LocalAddr
switch ra := ra.(type) {
case *TCPAddr:
la, _ := la.(*TCPAddr)
- c, err = testHookDialTCP(ctx.network, la, ra, deadline, ctx.Cancel)
+ c, err = dialTCP(ctx, dp.network, la, ra)
case *UDPAddr:
la, _ := la.(*UDPAddr)
- c, err = dialUDP(ctx.network, la, ra, deadline)
+ c, err = dialUDP(ctx, dp.network, la, ra)
case *IPAddr:
la, _ := la.(*IPAddr)
- c, err = dialIP(ctx.network, la, ra, deadline)
+ c, err = dialIP(ctx, dp.network, la, ra)
case *UnixAddr:
la, _ := la.(*UnixAddr)
- c, err = dialUnix(ctx.network, la, ra, deadline)
+ c, err = dialUnix(ctx, dp.network, la, ra)
default:
- return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: ctx.address}}
+ return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: dp.address}}
}
if err != nil {
- return nil, err // c is non-nil interface containing nil pointer
+ return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
}
return c, nil
}
@@ -394,8 +544,11 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time) (c Conn, err erro
// If host is omitted, as in ":8080", Listen listens on all available interfaces
// instead of just the interface with the given host address.
// See Dial for more details about address syntax.
+//
+// Listening on a hostname is not recommended because this creates a socket
+// for at most one of its IP addresses.
func Listen(net, laddr string) (Listener, error) {
- addrs, err := resolveAddrList("listen", net, laddr, noDeadline)
+ addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
}
@@ -421,8 +574,11 @@ func Listen(net, laddr string) (Listener, error) {
// If host is omitted, as in ":8080", ListenPacket listens on all available interfaces
// instead of just the interface with the given host address.
// See Dial for the syntax of laddr.
+//
+// Listening on a hostname is not recommended because this creates a socket
+// for at most one of its IP addresses.
func ListenPacket(net, laddr string) (PacketConn, error) {
- addrs, err := resolveAddrList("listen", net, laddr, noDeadline)
+ addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
}