diff options
Diffstat (limited to 'libgo/go/net')
130 files changed, 4985 insertions, 2425 deletions
diff --git a/libgo/go/net/cgo_bsd.go b/libgo/go/net/cgo_bsd.go index 27c3e9acb94..3852fc22987 100644 --- a/libgo/go/net/cgo_bsd.go +++ b/libgo/go/net/cgo_bsd.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd +// +build !netgo +// +build darwin dragonfly freebsd package net diff --git a/libgo/go/net/cgo_linux.go b/libgo/go/net/cgo_linux.go index 650575cce6c..77522f9141b 100644 --- a/libgo/go/net/cgo_linux.go +++ b/libgo/go/net/cgo_linux.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// +build cgo,!netgo + package net /* diff --git a/libgo/go/net/cgo_netbsd.go b/libgo/go/net/cgo_netbsd.go index 27334af641a..3c13103831f 100644 --- a/libgo/go/net/cgo_netbsd.go +++ b/libgo/go/net/cgo_netbsd.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// +build cgo,!netgo + package net /* diff --git a/libgo/go/net/cgo_openbsd.go b/libgo/go/net/cgo_openbsd.go index aeaf8e568ad..09c5ad2d9fd 100644 --- a/libgo/go/net/cgo_openbsd.go +++ b/libgo/go/net/cgo_openbsd.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// +build cgo,!netgo + package net /* diff --git a/libgo/go/net/cgo_stub.go b/libgo/go/net/cgo_stub.go index 52e57d7400e..f533c14212f 100644 --- a/libgo/go/net/cgo_stub.go +++ b/libgo/go/net/cgo_stub.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build !cgo +// +build !cgo netgo // Stub cgo routines for systems that do not use cgo to do network lookups. diff --git a/libgo/go/net/cgo_unix.go b/libgo/go/net/cgo_unix.go index ce54d827c8e..0abf43410e1 100644 --- a/libgo/go/net/cgo_unix.go +++ b/libgo/go/net/cgo_unix.go @@ -2,7 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd +// +build !netgo +// +build darwin dragonfly freebsd linux netbsd openbsd package net @@ -50,6 +51,9 @@ func cgoLookupHost(name string) (addrs []string, err error, completed bool) { } func cgoLookupPort(net, service string) (port int, err error, completed bool) { + acquireThread() + defer releaseThread() + var res *syscall.Addrinfo var hints syscall.Addrinfo @@ -99,6 +103,9 @@ func cgoLookupPort(net, service string) (port int, err error, completed bool) { } func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err error, completed bool) { + acquireThread() + defer releaseThread() + var res *syscall.Addrinfo var hints syscall.Addrinfo @@ -114,7 +121,18 @@ func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err error, complet if gerrno == syscall.EAI_NONAME { str = noSuchHost } else if gerrno == syscall.EAI_SYSTEM { - str = syscall.GetErrno().Error() + errno := syscall.GetErrno() + if errno == 0 { + // err should not be nil, but sometimes getaddrinfo returns + // gerrno == C.EAI_SYSTEM with err == nil on Linux. + // The report claims that it happens when we have too many + // open files, so use syscall.EMFILE (too many open files in system). + // Most system calls would return ENFILE (too many open files), + // so at the least EMFILE should be easy to recognize if this + // comes up again. golang.org/issue/6232. + errno = syscall.EMFILE + } + str = errno.Error() } else { str = bytePtrToString(libc_gai_strerror(gerrno)) } @@ -160,6 +178,9 @@ func cgoLookupCNAME(name string) (cname string, err error, completed bool) { } func copyIP(x IP) IP { + if len(x) < 16 { + return x.To16() + } y := make(IP, len(x)) copy(y, x) return y diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go index b18d283626c..6304818bf14 100644 --- a/libgo/go/net/dial.go +++ b/libgo/go/net/dial.go @@ -37,6 +37,13 @@ type Dialer struct { // network being dialed. // If nil, a local address is automatically chosen. LocalAddr Addr + + // DualStack allows a single dial to attempt to establish + // multiple IPv4 and IPv6 connections and to return the first + // established connection when the network is "tcp" and the + // destination is a host name that has multiple address family + // DNS records. + DualStack bool } // Return either now+Timeout or Deadline, whichever comes first. @@ -82,13 +89,13 @@ func parseNetwork(net string) (afnet string, proto int, err error) { return "", 0, UnknownNetworkError(net) } -func resolveAddr(op, net, addr string, deadline time.Time) (Addr, error) { +func resolveAddr(op, net, addr string, deadline time.Time) (netaddr, error) { afnet, _, err := parseNetwork(net) if err != nil { - return nil, &OpError{op, net, nil, err} + return nil, err } if op == "dial" && addr == "" { - return nil, &OpError{op, net, nil, errMissingAddress} + return nil, errMissingAddress } switch afnet { case "unix", "unixgram", "unixpacket": @@ -143,12 +150,74 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) { // See func Dial for a description of the network and address // parameters. func (d *Dialer) Dial(network, address string) (Conn, error) { - return resolveAndDial(network, address, d.LocalAddr, d.deadline()) + ra, err := resolveAddr("dial", network, address, d.deadline()) + if err != nil { + return nil, &OpError{Op: "dial", Net: network, Addr: nil, Err: err} + } + dialer := func(deadline time.Time) (Conn, error) { + return dialSingle(network, address, d.LocalAddr, ra.toAddr(), deadline) + } + if ras, ok := ra.(addrList); ok && d.DualStack && network == "tcp" { + dialer = func(deadline time.Time) (Conn, error) { + return dialMulti(network, address, d.LocalAddr, ras, deadline) + } + } + return dial(network, ra.toAddr(), dialer, d.deadline()) } -func dial(net, addr string, la, ra Addr, deadline time.Time) (c Conn, err error) { +// dialMulti attempts to establish connections to each destination of +// the list of addresses. It will return the first established +// connection and close the other connections. Otherwise it returns +// error on the last attempt. +func dialMulti(net, addr string, la Addr, ras addrList, deadline time.Time) (Conn, error) { + type racer struct { + Conn + Addr + error + } + // Sig controls the flow of dial results on lane. It passes a + // token to the next racer and also indicates the end of flow + // by using closed channel. + sig := make(chan bool, 1) + lane := make(chan racer, 1) + for _, ra := range ras { + go func(ra Addr) { + c, err := dialSingle(net, addr, la, ra, deadline) + if _, ok := <-sig; ok { + lane <- racer{c, ra, err} + } else if err == nil { + // We have to return the resources + // that belong to the other + // connections here for avoiding + // unnecessary resource starvation. + c.Close() + } + }(ra.toAddr()) + } + defer close(sig) + var failAddr Addr + lastErr := errTimeout + nracers := len(ras) + for nracers > 0 { + sig <- true + select { + case racer := <-lane: + if racer.error == nil { + return racer.Conn, nil + } + failAddr = racer.Addr + lastErr = racer.error + nracers-- + } + } + return nil, &OpError{Op: "dial", Net: net, Addr: failAddr, Err: lastErr} +} + +// dialSingle attempts to establish and returns a single connection to +// the destination address. +func dialSingle(net, addr string, la, ra Addr, deadline time.Time) (c Conn, err error) { if la != nil && la.Network() != ra.Network() { - return nil, &OpError{"dial", net, ra, errors.New("mismatched local addr type " + la.Network())} + return nil, &OpError{Op: "dial", Net: net, Addr: ra, Err: errors.New("mismatched local address type " + la.Network())} } switch ra := ra.(type) { case *TCPAddr: @@ -164,21 +233,14 @@ func dial(net, addr string, la, ra Addr, deadline time.Time) (c Conn, err error) la, _ := la.(*UnixAddr) c, err = dialUnix(net, la, ra, deadline) default: - err = &OpError{"dial", net + " " + addr, ra, UnknownNetworkError(net)} + return nil, &OpError{Op: "dial", Net: net, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: addr}} } if err != nil { - return nil, err + return nil, err // c is non-nil interface containing nil pointer } - return -} - -type stringAddr struct { - net, addr string + return c, nil } -func (a stringAddr) Network() string { return a.net } -func (a stringAddr) String() string { return a.addr } - // Listen announces on the local network address laddr. // The network net must be a stream-oriented network: "tcp", "tcp4", // "tcp6", "unix" or "unixpacket". @@ -186,15 +248,21 @@ func (a stringAddr) String() string { return a.addr } func Listen(net, laddr string) (Listener, error) { la, err := resolveAddr("listen", net, laddr, noDeadline) if err != nil { - return nil, err + return nil, &OpError{Op: "listen", Net: net, Addr: nil, Err: err} } - switch la := la.(type) { + var l Listener + switch la := la.toAddr().(type) { case *TCPAddr: - return ListenTCP(net, la) + l, err = ListenTCP(net, la) case *UnixAddr: - return ListenUnix(net, la) + l, err = ListenUnix(net, la) + default: + return nil, &OpError{Op: "listen", Net: net, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: laddr}} } - return nil, UnknownNetworkError(net) + if err != nil { + return nil, err // l is non-nil interface containing nil pointer + } + return l, nil } // ListenPacket announces on the local network address laddr. @@ -204,15 +272,21 @@ func Listen(net, laddr string) (Listener, error) { func ListenPacket(net, laddr string) (PacketConn, error) { la, err := resolveAddr("listen", net, laddr, noDeadline) if err != nil { - return nil, err + return nil, &OpError{Op: "listen", Net: net, Addr: nil, Err: err} } - switch la := la.(type) { + var l PacketConn + switch la := la.toAddr().(type) { case *UDPAddr: - return ListenUDP(net, la) + l, err = ListenUDP(net, la) case *IPAddr: - return ListenIP(net, la) + l, err = ListenIP(net, la) case *UnixAddr: - return ListenUnixgram(net, la) + l, err = ListenUnixgram(net, la) + default: + return nil, &OpError{Op: "listen", Net: net, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: laddr}} + } + if err != nil { + return nil, err // l is non-nil interface containing nil pointer } - return nil, UnknownNetworkError(net) + return l, nil } diff --git a/libgo/go/net/dial_gen.go b/libgo/go/net/dial_gen.go index 19f86816821..ada6233003f 100644 --- a/libgo/go/net/dial_gen.go +++ b/libgo/go/net/dial_gen.go @@ -12,62 +12,35 @@ import ( var testingIssue5349 bool // used during tests -// resolveAndDialChannel is the simple pure-Go implementation of -// resolveAndDial, still used on operating systems where the deadline -// hasn't been pushed down into the pollserver. (Plan 9 and some old -// versions of Windows) -func resolveAndDialChannel(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { +// dialChannel is the simple pure-Go implementation of dial, still +// used on operating systems where the deadline hasn't been pushed +// down into the pollserver. (Plan 9 and some old versions of Windows) +func dialChannel(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { var timeout time.Duration if !deadline.IsZero() { timeout = deadline.Sub(time.Now()) } if timeout <= 0 { - ra, err := resolveAddr("dial", net, addr, noDeadline) - if err != nil { - return nil, err - } - return dial(net, addr, localAddr, ra, noDeadline) + return dialer(noDeadline) } t := time.NewTimer(timeout) defer t.Stop() - type pair struct { + type racer struct { Conn error } - ch := make(chan pair, 1) - resolvedAddr := make(chan Addr, 1) + ch := make(chan racer, 1) go func() { if testingIssue5349 { time.Sleep(time.Millisecond) } - ra, err := resolveAddr("dial", net, addr, noDeadline) - if err != nil { - ch <- pair{nil, err} - return - } - resolvedAddr <- ra // in case we need it for OpError - c, err := dial(net, addr, localAddr, ra, noDeadline) - ch <- pair{c, err} + c, err := dialer(noDeadline) + ch <- racer{c, err} }() select { case <-t.C: - // Try to use the real Addr in our OpError, if we resolved it - // before the timeout. Otherwise we just use stringAddr. - var ra Addr - select { - case a := <-resolvedAddr: - ra = a - default: - ra = &stringAddr{net, addr} - } - err := &OpError{ - Op: "dial", - Net: net, - Addr: ra, - Err: &timeoutError{}, - } - return nil, err - case p := <-ch: - return p.Conn, p.error + return nil, &OpError{Op: "dial", Net: net, Addr: ra, Err: errTimeout} + case racer := <-ch: + return racer.Conn, racer.error } } diff --git a/libgo/go/net/dial_test.go b/libgo/go/net/dial_test.go index 03a0bad7a5b..c7ffdd3d9c8 100644 --- a/libgo/go/net/dial_test.go +++ b/libgo/go/net/dial_test.go @@ -5,13 +5,17 @@ package net import ( + "bytes" "flag" "fmt" "io" "os" + "os/exec" "reflect" "regexp" "runtime" + "strconv" + "sync" "testing" "time" ) @@ -137,7 +141,7 @@ func TestSelfConnect(t *testing.T) { n = 1000 } switch runtime.GOOS { - case "darwin", "freebsd", "netbsd", "openbsd", "plan9", "solaris", "windows": + case "darwin", "dragonfly", "freebsd", "netbsd", "openbsd", "plan9", "solaris", "windows": // Non-Linux systems take a long time to figure // out that there is nothing listening on localhost. n = 100 @@ -314,6 +318,96 @@ func TestDialTimeoutFDLeak(t *testing.T) { } } +func numTCP() (ntcp, nopen, nclose int, err error) { + lsof, err := exec.Command("lsof", "-n", "-p", strconv.Itoa(os.Getpid())).Output() + if err != nil { + return 0, 0, 0, err + } + ntcp += bytes.Count(lsof, []byte("TCP")) + for _, state := range []string{"LISTEN", "SYN_SENT", "SYN_RECEIVED", "ESTABLISHED"} { + nopen += bytes.Count(lsof, []byte(state)) + } + for _, state := range []string{"CLOSED", "CLOSE_WAIT", "LAST_ACK", "FIN_WAIT_1", "FIN_WAIT_2", "CLOSING", "TIME_WAIT"} { + nclose += bytes.Count(lsof, []byte(state)) + } + return ntcp, nopen, nclose, nil +} + +func TestDialMultiFDLeak(t *testing.T) { + if !supportsIPv4 || !supportsIPv6 { + t.Skip("neither ipv4 nor ipv6 is supported") + } + + halfDeadServer := func(dss *dualStackServer, ln Listener) { + for { + if c, err := ln.Accept(); err != nil { + return + } else { + // It just keeps established + // connections like a half-dead server + // does. + dss.putConn(c) + } + } + } + dss, err := newDualStackServer([]streamListener{ + {net: "tcp4", addr: "127.0.0.1"}, + {net: "tcp6", addr: "[::1]"}, + }) + if err != nil { + t.Fatalf("newDualStackServer failed: %v", err) + } + defer dss.teardown() + if err := dss.buildup(halfDeadServer); err != nil { + t.Fatalf("dualStackServer.buildup failed: %v", err) + } + + _, before, _, err := numTCP() + if err != nil { + t.Skipf("skipping test; error finding or running lsof: %v", err) + } + + var wg sync.WaitGroup + portnum, _, _ := dtoi(dss.port, 0) + ras := addrList{ + // Losers that will fail to connect, see RFC 6890. + &TCPAddr{IP: IPv4(198, 18, 0, 254), Port: portnum}, + &TCPAddr{IP: ParseIP("2001:2::254"), Port: portnum}, + + // Winner candidates of this race. + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: portnum}, + &TCPAddr{IP: IPv6loopback, Port: portnum}, + + // Losers that will have established connections. + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: portnum}, + &TCPAddr{IP: IPv6loopback, Port: portnum}, + } + const T1 = 10 * time.Millisecond + const T2 = 2 * T1 + const N = 10 + for i := 0; i < N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if c, err := dialMulti("tcp", "fast failover test", nil, ras, time.Now().Add(T1)); err == nil { + c.Close() + } + }() + } + wg.Wait() + time.Sleep(T2) + + ntcp, after, nclose, err := numTCP() + if err != nil { + t.Skipf("skipping test; error finding or running lsof: %v", err) + } + t.Logf("tcp sessions: %v, open sessions: %v, closing sessions: %v", ntcp, after, nclose) + + if after != before { + t.Fatalf("got %v open sessions; expected %v", after, before) + } +} + func numFD() int { if runtime.GOOS == "linux" { f, err := os.Open("/proc/self/fd") @@ -331,17 +425,22 @@ func numFD() int { panic("numFDs not implemented on " + runtime.GOOS) } -var testPoller = flag.Bool("poller", false, "platform supports runtime-integrated poller") - // Assert that a failed Dial attempt does not leak // runtime.PollDesc structures func TestDialFailPDLeak(t *testing.T) { - if !*testPoller { - t.Skip("test disabled; use -poller to enable") + if testing.Short() { + t.Skip("skipping test in short mode") + } + if runtime.GOOS == "windows" && runtime.GOARCH == "386" { + // Just skip the test because it takes too long. + t.Skipf("skipping test on %q/%q", runtime.GOOS, runtime.GOARCH) } - const loops = 10 - const count = 20000 + maxprocs := runtime.GOMAXPROCS(0) + loops := 10 + maxprocs + // 500 is enough to turn over the chunk of pollcache. + // See allocPollDesc in runtime/netpoll.goc. + const count = 500 var old runtime.MemStats // used by sysdelta runtime.ReadMemStats(&old) sysdelta := func() uint64 { @@ -354,19 +453,26 @@ func TestDialFailPDLeak(t *testing.T) { d := &Dialer{Timeout: time.Nanosecond} // don't bother TCP with handshaking failcount := 0 for i := 0; i < loops; i++ { + var wg sync.WaitGroup for i := 0; i < count; i++ { - conn, err := d.Dial("tcp", "127.0.0.1:1") - if err == nil { - t.Error("dial should not succeed") - conn.Close() - t.FailNow() - } + wg.Add(1) + go func() { + defer wg.Done() + if c, err := d.Dial("tcp", "127.0.0.1:1"); err == nil { + t.Error("dial should not succeed") + c.Close() + } + }() + } + wg.Wait() + if t.Failed() { + t.FailNow() } if delta := sysdelta(); delta > 0 { failcount++ } // there are always some allocations on the first loop - if failcount > 3 { + if failcount > maxprocs+2 { t.Error("detected possible memory leak in runtime") t.FailNow() } @@ -381,7 +487,6 @@ func TestDialer(t *testing.T) { defer ln.Close() ch := make(chan error, 1) go func() { - var err error c, err := ln.Accept() if err != nil { ch <- fmt.Errorf("Accept failed: %v", err) @@ -407,3 +512,46 @@ func TestDialer(t *testing.T) { t.Error(err) } } + +func TestDialDualStackLocalhost(t *testing.T) { + if ips, err := LookupIP("localhost"); err != nil { + t.Fatalf("LookupIP failed: %v", err) + } else if len(ips) < 2 || !supportsIPv4 || !supportsIPv6 { + t.Skip("localhost doesn't have a pair of different address family IP addresses") + } + + touchAndByeServer := func(dss *dualStackServer, ln Listener) { + for { + if c, err := ln.Accept(); err != nil { + return + } else { + c.Close() + } + } + } + dss, err := newDualStackServer([]streamListener{ + {net: "tcp4", addr: "127.0.0.1"}, + {net: "tcp6", addr: "[::1]"}, + }) + if err != nil { + t.Fatalf("newDualStackServer failed: %v", err) + } + defer dss.teardown() + if err := dss.buildup(touchAndByeServer); err != nil { + t.Fatalf("dualStackServer.buildup failed: %v", err) + } + + d := &Dialer{DualStack: true} + for _ = range dss.lns { + if c, err := d.Dial("tcp", "localhost:"+dss.port); err != nil { + t.Errorf("Dial failed: %v", err) + } else { + if addr := c.LocalAddr().(*TCPAddr); addr.IP.To4() != nil { + dss.teardownNetwork("tcp4") + } else if addr.IP.To16() != nil && addr.IP.To4() == nil { + dss.teardownNetwork("tcp6") + } + c.Close() + } + } +} diff --git a/libgo/go/net/dialgoogle_test.go b/libgo/go/net/dialgoogle_test.go index 73a94f5bf1c..b4ebad0e0dc 100644 --- a/libgo/go/net/dialgoogle_test.go +++ b/libgo/go/net/dialgoogle_test.go @@ -16,6 +16,59 @@ import ( // If an IPv6 tunnel is running, we can try dialing a real IPv6 address. var testIPv6 = flag.Bool("ipv6", false, "assume ipv6 tunnel is present") +func TestResolveGoogle(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + + for _, network := range []string{"tcp", "tcp4", "tcp6"} { + addr, err := ResolveTCPAddr(network, "www.google.com:http") + if err != nil { + if (network == "tcp" || network == "tcp4") && !supportsIPv4 { + t.Logf("ipv4 is not supported: %v", err) + } else if network == "tcp6" && !supportsIPv6 { + t.Logf("ipv6 is not supported: %v", err) + } else { + t.Errorf("ResolveTCPAddr failed: %v", err) + } + continue + } + if (network == "tcp" || network == "tcp4") && addr.IP.To4() == nil { + t.Errorf("got %v; expected an IPv4 address on %v", addr, network) + } else if network == "tcp6" && (addr.IP.To16() == nil || addr.IP.To4() != nil) { + t.Errorf("got %v; expected an IPv6 address on %v", addr, network) + } + } +} + +func TestDialGoogle(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + + d := &Dialer{DualStack: true} + for _, network := range []string{"tcp", "tcp4", "tcp6"} { + if network == "tcp" && !supportsIPv4 && !supportsIPv6 { + t.Logf("skipping test; both ipv4 and ipv6 are not supported") + continue + } else if network == "tcp4" && !supportsIPv4 { + t.Logf("skipping test; ipv4 is not supported") + continue + } else if network == "tcp6" && !supportsIPv6 { + t.Logf("skipping test; ipv6 is not supported") + continue + } else if network == "tcp6" && !*testIPv6 { + t.Logf("test disabled; use -ipv6 to enable") + continue + } + if c, err := d.Dial(network, "www.google.com:http"); err != nil { + t.Errorf("Dial failed: %v", err) + } else { + c.Close() + } + } +} + // fd is already connected to the destination, port 80. // Run an HTTP request to fetch the appropriate page. func fetchGoogle(t *testing.T, fd Conn, network, addr string) { @@ -54,6 +107,30 @@ var googleaddrsipv4 = []string{ "[0:0:0:0:0:ffff::%d.%d.%d.%d]:80", } +func TestDNSThreadLimit(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + + const N = 10000 + c := make(chan int, N) + for i := 0; i < N; i++ { + go func(i int) { + LookupIP(fmt.Sprintf("%d.net-test.golang.org", i)) + c <- 1 + }(i) + } + // Don't bother waiting for the stragglers; stop at 0.9 N. + for i := 0; i < N*9/10; i++ { + if i%100 == 0 { + //println("TestDNSThreadLimit:", i) + } + <-c + } + + // If we're still here, it worked. +} + func TestDialGoogleIPv4(t *testing.T) { if testing.Short() || !*testExternal { t.Skip("skipping test to avoid external network") diff --git a/libgo/go/net/dnsclient.go b/libgo/go/net/dnsclient.go index 76b192645aa..01db4372945 100644 --- a/libgo/go/net/dnsclient.go +++ b/libgo/go/net/dnsclient.go @@ -122,12 +122,9 @@ func isDomainName(s string) bool { if len(s) > 255 { return false } - if s[len(s)-1] != '.' { // simplify checking loop: make name end in dot - s += "." - } last := byte('.') - ok := false // ok once we've seen a letter + ok := false // Ok once we've seen a letter. partlen := 0 for i := 0; i < len(s); i++ { c := s[i] @@ -141,13 +138,13 @@ func isDomainName(s string) bool { // fine partlen++ case c == '-': - // byte before dash cannot be dot + // Byte before dash cannot be dot. if last == '.' { return false } partlen++ case c == '.': - // byte before dot cannot be dot, dash + // Byte before dot cannot be dot, dash. if last == '.' || last == '-' { return false } @@ -158,6 +155,9 @@ func isDomainName(s string) bool { } last = c } + if last == '-' || partlen > 63 { + return false + } return ok } diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go index 9e21bb4a0f6..16cf420dcdb 100644 --- a/libgo/go/net/dnsclient_unix.go +++ b/libgo/go/net/dnsclient_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux netbsd openbsd // DNS client: see RFC 1035. // Has to be linked into package net for Dial. @@ -17,6 +17,7 @@ package net import ( + "io" "math/rand" "sync" "time" @@ -25,6 +26,7 @@ import ( // Send a request on the connection and hope for a reply. // Up to cfg.attempts attempts. func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, error) { + _, useTCP := c.(*TCPConn) if len(name) >= 256 { return nil, &DNSError{Err: "name too long", Name: name} } @@ -38,7 +40,10 @@ func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, error if !ok { return nil, &DNSError{Err: "internal error - cannot pack message", Name: name} } - + if useTCP { + mlen := uint16(len(msg)) + msg = append([]byte{byte(mlen >> 8), byte(mlen)}, msg...) + } for attempt := 0; attempt < cfg.attempts; attempt++ { n, err := c.Write(msg) if err != nil { @@ -46,20 +51,33 @@ func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, error } if cfg.timeout == 0 { - c.SetReadDeadline(time.Time{}) + c.SetReadDeadline(noDeadline) } else { c.SetReadDeadline(time.Now().Add(time.Duration(cfg.timeout) * time.Second)) } - - buf := make([]byte, 2000) // More than enough. - n, err = c.Read(buf) + buf := make([]byte, 2000) + if useTCP { + n, err = io.ReadFull(c, buf[:2]) + if err != nil { + if e, ok := err.(Error); ok && e.Timeout() { + continue + } + } + mlen := int(buf[0])<<8 | int(buf[1]) + if mlen > len(buf) { + buf = make([]byte, mlen) + } + n, err = io.ReadFull(c, buf[:mlen]) + } else { + n, err = c.Read(buf) + } if err != nil { if e, ok := err.(Error); ok && e.Timeout() { continue } return nil, err } - buf = buf[0:n] + buf = buf[:n] in := new(dnsMsg) if !in.Unpack(buf) || in.id != out.id { continue @@ -98,6 +116,19 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs err = merr continue } + if msg.truncated { // see RFC 5966 + c, cerr = Dial("tcp", server) + if cerr != nil { + err = cerr + continue + } + msg, merr = exchange(cfg, c, name, qtype) + c.Close() + if merr != nil { + err = merr + continue + } + } cname, addrs, err = answer(name, server, msg, qtype) if err == nil || err.(*DNSError).Err == noSuchHost { break @@ -180,6 +211,12 @@ func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err error) if err == nil { return } + if e, ok := err.(*DNSError); ok { + // Show original name passed to lookup, not suffixed one. + // In general we might have tried many suffixes; showing + // just one is misleading. See also golang.org/issue/6324. + e.Name = name + } return } diff --git a/libgo/go/net/dnsclient_unix_test.go b/libgo/go/net/dnsclient_unix_test.go new file mode 100644 index 00000000000..47dcb563bc5 --- /dev/null +++ b/libgo/go/net/dnsclient_unix_test.go @@ -0,0 +1,27 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd + +package net + +import ( + "testing" +) + +func TestTCPLookup(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + c, err := Dial("tcp", "8.8.8.8:53") + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer c.Close() + cfg := &dnsConfig{timeout: 10, attempts: 3} + _, err = exchange(cfg, c, "com.", dnsTypeALL) + if err != nil { + t.Fatalf("exchange failed: %v", err) + } +} diff --git a/libgo/go/net/dnsconfig_unix.go b/libgo/go/net/dnsconfig_unix.go index bb46cc9007c..2f0f6c031f1 100644 --- a/libgo/go/net/dnsconfig_unix.go +++ b/libgo/go/net/dnsconfig_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux netbsd openbsd // Read system DNS config from /etc/resolv.conf diff --git a/libgo/go/net/dnsname_test.go b/libgo/go/net/dnsname_test.go index 70df693f789..57dd25fe4c6 100644 --- a/libgo/go/net/dnsname_test.go +++ b/libgo/go/net/dnsname_test.go @@ -5,6 +5,7 @@ package net import ( + "strings" "testing" ) @@ -16,7 +17,6 @@ type testCase struct { var tests = []testCase{ // RFC2181, section 11. {"_xmpp-server._tcp.google.com", true}, - {"_xmpp-server._tcp.google.com", true}, {"foo.com", true}, {"1foo.com", true}, {"26.0.0.73.com", true}, @@ -24,6 +24,10 @@ var tests = []testCase{ {"fo1o.com", true}, {"foo1.com", true}, {"a.b..com", false}, + {"a.b-.com", false}, + {"a.b.com-", false}, + {"a.b..", false}, + {"b.com.", true}, } func getTestCases(ch chan<- testCase) { @@ -63,3 +67,17 @@ func TestDNSNames(t *testing.T) { } } } + +func BenchmarkDNSNames(b *testing.B) { + benchmarks := append(tests, []testCase{ + {strings.Repeat("a", 63), true}, + {strings.Repeat("a", 64), false}, + }...) + for n := 0; n < b.N; n++ { + for _, tc := range benchmarks { + if isDomainName(tc.name) != tc.result { + b.Errorf("isDomainName(%q) = %v; want %v", tc.name, !tc.result, tc.result) + } + } + } +} diff --git a/libgo/go/net/fd_bsd.go b/libgo/go/net/fd_bsd.go deleted file mode 100644 index 8bb1ae53847..00000000000 --- a/libgo/go/net/fd_bsd.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build freebsd netbsd openbsd - -// Waiting for FDs via kqueue/kevent. - -package net - -import ( - "os" - "syscall" -) - -type pollster struct { - kq int - eventbuf [10]syscall.Kevent_t - events []syscall.Kevent_t - - // An event buffer for AddFD/DelFD. - // Must hold pollServer lock. - kbuf [1]syscall.Kevent_t -} - -func newpollster() (p *pollster, err error) { - p = new(pollster) - if p.kq, err = syscall.Kqueue(); err != nil { - return nil, os.NewSyscallError("kqueue", err) - } - syscall.CloseOnExec(p.kq) - p.events = p.eventbuf[0:0] - return p, nil -} - -// First return value is whether the pollServer should be woken up. -// This version always returns false. -func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { - // pollServer is locked. - - var kmode int - if mode == 'r' { - kmode = syscall.EVFILT_READ - } else { - kmode = syscall.EVFILT_WRITE - } - ev := &p.kbuf[0] - // EV_ADD - add event to kqueue list - // EV_ONESHOT - delete the event the first time it triggers - flags := syscall.EV_ADD - if !repeat { - flags |= syscall.EV_ONESHOT - } - syscall.SetKevent(ev, fd, kmode, flags) - - n, err := syscall.Kevent(p.kq, p.kbuf[:], nil, nil) - if err != nil { - return false, os.NewSyscallError("kevent", err) - } - if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode { - return false, os.NewSyscallError("kqueue phase error", err) - } - if ev.Data != 0 { - return false, syscall.Errno(int(ev.Data)) - } - return false, nil -} - -// Return value is whether the pollServer should be woken up. -// This version always returns false. -func (p *pollster) DelFD(fd int, mode int) bool { - // pollServer is locked. - - var kmode int - if mode == 'r' { - kmode = syscall.EVFILT_READ - } else { - kmode = syscall.EVFILT_WRITE - } - ev := &p.kbuf[0] - // EV_DELETE - delete event from kqueue list - syscall.SetKevent(ev, fd, kmode, syscall.EV_DELETE) - syscall.Kevent(p.kq, p.kbuf[:], nil, nil) - return false -} - -func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { - var t *syscall.Timespec - for len(p.events) == 0 { - if nsec > 0 { - if t == nil { - t = new(syscall.Timespec) - } - *t = syscall.NsecToTimespec(nsec) - } - - s.Unlock() - n, err := syscall.Kevent(p.kq, nil, p.eventbuf[:], t) - s.Lock() - - if err != nil { - if err == syscall.EINTR { - continue - } - return -1, 0, os.NewSyscallError("kevent", err) - } - if n == 0 { - return -1, 0, nil - } - p.events = p.eventbuf[:n] - } - ev := &p.events[0] - p.events = p.events[1:] - fd = int(ev.Ident) - if ev.Filter == syscall.EVFILT_READ { - mode = 'r' - } else { - mode = 'w' - } - return fd, mode, nil -} - -func (p *pollster) Close() error { return os.NewSyscallError("close", syscall.Close(p.kq)) } diff --git a/libgo/go/net/fd_mutex.go b/libgo/go/net/fd_mutex.go new file mode 100644 index 00000000000..6d5509d7f2a --- /dev/null +++ b/libgo/go/net/fd_mutex.go @@ -0,0 +1,184 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import "sync/atomic" + +// fdMutex is a specialized synchronization primitive +// that manages lifetime of an fd and serializes access +// to Read and Write methods on netFD. +type fdMutex struct { + state uint64 + rsema uint32 + wsema uint32 +} + +// fdMutex.state is organized as follows: +// 1 bit - whether netFD is closed, if set all subsequent lock operations will fail. +// 1 bit - lock for read operations. +// 1 bit - lock for write operations. +// 20 bits - total number of references (read+write+misc). +// 20 bits - number of outstanding read waiters. +// 20 bits - number of outstanding write waiters. +const ( + mutexClosed = 1 << 0 + mutexRLock = 1 << 1 + mutexWLock = 1 << 2 + mutexRef = 1 << 3 + mutexRefMask = (1<<20 - 1) << 3 + mutexRWait = 1 << 23 + mutexRMask = (1<<20 - 1) << 23 + mutexWWait = 1 << 43 + mutexWMask = (1<<20 - 1) << 43 +) + +// Read operations must do RWLock(true)/RWUnlock(true). +// Write operations must do RWLock(false)/RWUnlock(false). +// Misc operations must do Incref/Decref. Misc operations include functions like +// setsockopt and setDeadline. They need to use Incref/Decref to ensure that +// they operate on the correct fd in presence of a concurrent Close call +// (otherwise fd can be closed under their feet). +// Close operation must do IncrefAndClose/Decref. + +// RWLock/Incref return whether fd is open. +// RWUnlock/Decref return whether fd is closed and there are no remaining references. + +func (mu *fdMutex) Incref() bool { + for { + old := atomic.LoadUint64(&mu.state) + if old&mutexClosed != 0 { + return false + } + new := old + mutexRef + if new&mutexRefMask == 0 { + panic("net: inconsistent fdMutex") + } + if atomic.CompareAndSwapUint64(&mu.state, old, new) { + return true + } + } +} + +func (mu *fdMutex) IncrefAndClose() bool { + for { + old := atomic.LoadUint64(&mu.state) + if old&mutexClosed != 0 { + return false + } + // Mark as closed and acquire a reference. + new := (old | mutexClosed) + mutexRef + if new&mutexRefMask == 0 { + panic("net: inconsistent fdMutex") + } + // Remove all read and write waiters. + new &^= mutexRMask | mutexWMask + if atomic.CompareAndSwapUint64(&mu.state, old, new) { + // Wake all read and write waiters, + // they will observe closed flag after wakeup. + for old&mutexRMask != 0 { + old -= mutexRWait + runtime_Semrelease(&mu.rsema) + } + for old&mutexWMask != 0 { + old -= mutexWWait + runtime_Semrelease(&mu.wsema) + } + return true + } + } +} + +func (mu *fdMutex) Decref() bool { + for { + old := atomic.LoadUint64(&mu.state) + if old&mutexRefMask == 0 { + panic("net: inconsistent fdMutex") + } + new := old - mutexRef + if atomic.CompareAndSwapUint64(&mu.state, old, new) { + return new&(mutexClosed|mutexRefMask) == mutexClosed + } + } +} + +func (mu *fdMutex) RWLock(read bool) bool { + var mutexBit, mutexWait, mutexMask uint64 + var mutexSema *uint32 + if read { + mutexBit = mutexRLock + mutexWait = mutexRWait + mutexMask = mutexRMask + mutexSema = &mu.rsema + } else { + mutexBit = mutexWLock + mutexWait = mutexWWait + mutexMask = mutexWMask + mutexSema = &mu.wsema + } + for { + old := atomic.LoadUint64(&mu.state) + if old&mutexClosed != 0 { + return false + } + var new uint64 + if old&mutexBit == 0 { + // Lock is free, acquire it. + new = (old | mutexBit) + mutexRef + if new&mutexRefMask == 0 { + panic("net: inconsistent fdMutex") + } + } else { + // Wait for lock. + new = old + mutexWait + if new&mutexMask == 0 { + panic("net: inconsistent fdMutex") + } + } + if atomic.CompareAndSwapUint64(&mu.state, old, new) { + if old&mutexBit == 0 { + return true + } + runtime_Semacquire(mutexSema) + // The signaller has subtracted mutexWait. + } + } +} + +func (mu *fdMutex) RWUnlock(read bool) bool { + var mutexBit, mutexWait, mutexMask uint64 + var mutexSema *uint32 + if read { + mutexBit = mutexRLock + mutexWait = mutexRWait + mutexMask = mutexRMask + mutexSema = &mu.rsema + } else { + mutexBit = mutexWLock + mutexWait = mutexWWait + mutexMask = mutexWMask + mutexSema = &mu.wsema + } + for { + old := atomic.LoadUint64(&mu.state) + if old&mutexBit == 0 || old&mutexRefMask == 0 { + panic("net: inconsistent fdMutex") + } + // Drop lock, drop reference and wake read waiter if present. + new := (old &^ mutexBit) - mutexRef + if old&mutexMask != 0 { + new -= mutexWait + } + if atomic.CompareAndSwapUint64(&mu.state, old, new) { + if old&mutexMask != 0 { + runtime_Semrelease(mutexSema) + } + return new&(mutexClosed|mutexRefMask) == mutexClosed + } + } +} + +// Implemented in runtime package. +func runtime_Semacquire(sema *uint32) +func runtime_Semrelease(sema *uint32) diff --git a/libgo/go/net/fd_mutex_test.go b/libgo/go/net/fd_mutex_test.go new file mode 100644 index 00000000000..8383084b7a2 --- /dev/null +++ b/libgo/go/net/fd_mutex_test.go @@ -0,0 +1,186 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "math/rand" + "runtime" + "testing" + "time" +) + +func TestMutexLock(t *testing.T) { + var mu fdMutex + + if !mu.Incref() { + t.Fatal("broken") + } + if mu.Decref() { + t.Fatal("broken") + } + + if !mu.RWLock(true) { + t.Fatal("broken") + } + if mu.RWUnlock(true) { + t.Fatal("broken") + } + + if !mu.RWLock(false) { + t.Fatal("broken") + } + if mu.RWUnlock(false) { + t.Fatal("broken") + } +} + +func TestMutexClose(t *testing.T) { + var mu fdMutex + if !mu.IncrefAndClose() { + t.Fatal("broken") + } + + if mu.Incref() { + t.Fatal("broken") + } + if mu.RWLock(true) { + t.Fatal("broken") + } + if mu.RWLock(false) { + t.Fatal("broken") + } + if mu.IncrefAndClose() { + t.Fatal("broken") + } +} + +func TestMutexCloseUnblock(t *testing.T) { + c := make(chan bool) + var mu fdMutex + mu.RWLock(true) + for i := 0; i < 4; i++ { + go func() { + if mu.RWLock(true) { + t.Fatal("broken") + } + c <- true + }() + } + // Concurrent goroutines must not be able to read lock the mutex. + time.Sleep(time.Millisecond) + select { + case <-c: + t.Fatal("broken") + default: + } + mu.IncrefAndClose() // Must unblock the readers. + for i := 0; i < 4; i++ { + select { + case <-c: + case <-time.After(10 * time.Second): + t.Fatal("broken") + } + } + if mu.Decref() { + t.Fatal("broken") + } + if !mu.RWUnlock(true) { + t.Fatal("broken") + } +} + +func TestMutexPanic(t *testing.T) { + ensurePanics := func(f func()) { + defer func() { + if recover() == nil { + t.Fatal("does not panic") + } + }() + f() + } + + var mu fdMutex + ensurePanics(func() { mu.Decref() }) + ensurePanics(func() { mu.RWUnlock(true) }) + ensurePanics(func() { mu.RWUnlock(false) }) + + ensurePanics(func() { mu.Incref(); mu.Decref(); mu.Decref() }) + ensurePanics(func() { mu.RWLock(true); mu.RWUnlock(true); mu.RWUnlock(true) }) + ensurePanics(func() { mu.RWLock(false); mu.RWUnlock(false); mu.RWUnlock(false) }) + + // ensure that it's still not broken + mu.Incref() + mu.Decref() + mu.RWLock(true) + mu.RWUnlock(true) + mu.RWLock(false) + mu.RWUnlock(false) +} + +func TestMutexStress(t *testing.T) { + P := 8 + N := int(1e6) + if testing.Short() { + P = 4 + N = 1e4 + } + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(P)) + done := make(chan bool) + var mu fdMutex + var readState [2]uint64 + var writeState [2]uint64 + for p := 0; p < P; p++ { + go func() { + r := rand.New(rand.NewSource(rand.Int63())) + for i := 0; i < N; i++ { + switch r.Intn(3) { + case 0: + if !mu.Incref() { + t.Fatal("broken") + } + if mu.Decref() { + t.Fatal("broken") + } + case 1: + if !mu.RWLock(true) { + t.Fatal("broken") + } + // Ensure that it provides mutual exclusion for readers. + if readState[0] != readState[1] { + t.Fatal("broken") + } + readState[0]++ + readState[1]++ + if mu.RWUnlock(true) { + t.Fatal("broken") + } + case 2: + if !mu.RWLock(false) { + t.Fatal("broken") + } + // Ensure that it provides mutual exclusion for writers. + if writeState[0] != writeState[1] { + t.Fatal("broken") + } + writeState[0]++ + writeState[1]++ + if mu.RWUnlock(false) { + t.Fatal("broken") + } + } + } + done <- true + }() + } + for p := 0; p < P; p++ { + <-done + } + if !mu.IncrefAndClose() { + t.Fatal("broken") + } + if !mu.Decref() { + t.Fatal("broken") + } +} diff --git a/libgo/go/net/fd_plan9.go b/libgo/go/net/fd_plan9.go index e9527a3743b..acc82940217 100644 --- a/libgo/go/net/fd_plan9.go +++ b/libgo/go/net/fd_plan9.go @@ -18,15 +18,13 @@ type netFD struct { laddr, raddr Addr } -var canCancelIO = true // used for testing current package - func sysInit() { } -func resolveAndDial(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { +func dial(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { // On plan9, use the relatively inefficient // goroutine-racing implementation. - return resolveAndDialChannel(net, addr, localAddr, deadline) + return dialChannel(net, ra, dialer, deadline) } func newFD(proto, name string, ctl, data *os.File, laddr, raddr Addr) *netFD { @@ -108,15 +106,15 @@ func (fd *netFD) file(f *os.File, s string) (*os.File, error) { return os.NewFile(uintptr(dfd), s), nil } -func setDeadline(fd *netFD, t time.Time) error { +func (fd *netFD) setDeadline(t time.Time) error { return syscall.EPLAN9 } -func setReadDeadline(fd *netFD, t time.Time) error { +func (fd *netFD) setReadDeadline(t time.Time) error { return syscall.EPLAN9 } -func setWriteDeadline(fd *netFD, t time.Time) error { +func (fd *netFD) setWriteDeadline(t time.Time) error { return syscall.EPLAN9 } @@ -127,3 +125,7 @@ func setReadBuffer(fd *netFD, bytes int) error { func setWriteBuffer(fd *netFD, bytes int) error { return syscall.EPLAN9 } + +func skipRawSocketTests() (skip bool, skipmsg string, err error) { + return true, "skipping test on plan9", nil +} diff --git a/libgo/go/net/fd_poll_runtime.go b/libgo/go/net/fd_poll_runtime.go index e3b4f7e4648..e2b2768864a 100644 --- a/libgo/go/net/fd_poll_runtime.go +++ b/libgo/go/net/fd_poll_runtime.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin linux +// +build darwin dragonfly freebsd linux netbsd openbsd windows package net @@ -13,27 +13,23 @@ import ( ) func runtime_pollServerInit() -func runtime_pollOpen(fd int) (uintptr, int) +func runtime_pollOpen(fd uintptr) (uintptr, int) func runtime_pollClose(ctx uintptr) func runtime_pollWait(ctx uintptr, mode int) int +func runtime_pollWaitCanceled(ctx uintptr, mode int) int func runtime_pollReset(ctx uintptr, mode int) int func runtime_pollSetDeadline(ctx uintptr, d int64, mode int) func runtime_pollUnblock(ctx uintptr) -var canCancelIO = true // used for testing current package - type pollDesc struct { runtimeCtx uintptr } var serverInit sync.Once -func sysInit() { -} - func (pd *pollDesc) Init(fd *netFD) error { serverInit.Do(runtime_pollServerInit) - ctx, errno := runtime_pollOpen(fd.sysfd) + ctx, errno := runtime_pollOpen(uintptr(fd.sysfd)) if errno != 0 { return syscall.Errno(errno) } @@ -42,7 +38,11 @@ func (pd *pollDesc) Init(fd *netFD) error { } func (pd *pollDesc) Close() { + if pd.runtimeCtx == 0 { + return + } runtime_pollClose(pd.runtimeCtx) + pd.runtimeCtx = 0 } func (pd *pollDesc) Lock() { @@ -57,28 +57,49 @@ func (pd *pollDesc) Wakeup() { // Evict evicts fd from the pending list, unblocking any I/O running on fd. // Return value is whether the pollServer should be woken up. func (pd *pollDesc) Evict() bool { + if pd.runtimeCtx == 0 { + return false + } runtime_pollUnblock(pd.runtimeCtx) return false } -func (pd *pollDesc) PrepareRead() error { - res := runtime_pollReset(pd.runtimeCtx, 'r') +func (pd *pollDesc) Prepare(mode int) error { + res := runtime_pollReset(pd.runtimeCtx, mode) return convertErr(res) } +func (pd *pollDesc) PrepareRead() error { + return pd.Prepare('r') +} + func (pd *pollDesc) PrepareWrite() error { - res := runtime_pollReset(pd.runtimeCtx, 'w') + return pd.Prepare('w') +} + +func (pd *pollDesc) Wait(mode int) error { + res := runtime_pollWait(pd.runtimeCtx, mode) return convertErr(res) } func (pd *pollDesc) WaitRead() error { - res := runtime_pollWait(pd.runtimeCtx, 'r') - return convertErr(res) + return pd.Wait('r') } func (pd *pollDesc) WaitWrite() error { - res := runtime_pollWait(pd.runtimeCtx, 'w') - return convertErr(res) + return pd.Wait('w') +} + +func (pd *pollDesc) WaitCanceled(mode int) { + runtime_pollWaitCanceled(pd.runtimeCtx, mode) +} + +func (pd *pollDesc) WaitCanceledRead() { + pd.WaitCanceled('r') +} + +func (pd *pollDesc) WaitCanceledWrite() { + pd.WaitCanceled('w') } func convertErr(res int) error { @@ -90,19 +111,20 @@ func convertErr(res int) error { case 2: return errTimeout } + println("unreachable: ", res) panic("unreachable") } -func setReadDeadline(fd *netFD, t time.Time) error { - return setDeadlineImpl(fd, t, 'r') +func (fd *netFD) setDeadline(t time.Time) error { + return setDeadlineImpl(fd, t, 'r'+'w') } -func setWriteDeadline(fd *netFD, t time.Time) error { - return setDeadlineImpl(fd, t, 'w') +func (fd *netFD) setReadDeadline(t time.Time) error { + return setDeadlineImpl(fd, t, 'r') } -func setDeadline(fd *netFD, t time.Time) error { - return setDeadlineImpl(fd, t, 'r'+'w') +func (fd *netFD) setWriteDeadline(t time.Time) error { + return setDeadlineImpl(fd, t, 'w') } func setDeadlineImpl(fd *netFD, t time.Time, mode int) error { @@ -110,7 +132,7 @@ func setDeadlineImpl(fd *netFD, t time.Time, mode int) error { if t.IsZero() { d = 0 } - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } runtime_pollSetDeadline(fd.pd.runtimeCtx, d, mode) diff --git a/libgo/go/net/fd_poll_unix.go b/libgo/go/net/fd_poll_unix.go deleted file mode 100644 index 307e577e999..00000000000 --- a/libgo/go/net/fd_poll_unix.go +++ /dev/null @@ -1,360 +0,0 @@ -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build freebsd netbsd openbsd - -package net - -import ( - "os" - "runtime" - "sync" - "syscall" - "time" -) - -// A pollServer helps FDs determine when to retry a non-blocking -// read or write after they get EAGAIN. When an FD needs to wait, -// call s.WaitRead() or s.WaitWrite() to pass the request to the poll server. -// When the pollServer finds that i/o on FD should be possible -// again, it will send on fd.cr/fd.cw to wake any waiting goroutines. -// -// To avoid races in closing, all fd operations are locked and -// refcounted. when netFD.Close() is called, it calls syscall.Shutdown -// and sets a closing flag. Only when the last reference is removed -// will the fd be closed. - -type pollServer struct { - pr, pw *os.File - poll *pollster // low-level OS hooks - sync.Mutex // controls pending and deadline - pending map[int]*pollDesc - deadline int64 // next deadline (nsec since 1970) -} - -// A pollDesc contains netFD state related to pollServer. -type pollDesc struct { - // immutable after Init() - pollServer *pollServer - sysfd int - cr, cw chan error - - // mutable, protected by pollServer mutex - closing bool - ncr, ncw int - - // mutable, safe for concurrent access - rdeadline, wdeadline deadline -} - -func newPollServer() (s *pollServer, err error) { - s = new(pollServer) - if s.pr, s.pw, err = os.Pipe(); err != nil { - return nil, err - } - if err = syscall.SetNonblock(int(s.pr.Fd()), true); err != nil { - goto Errno - } - if err = syscall.SetNonblock(int(s.pw.Fd()), true); err != nil { - goto Errno - } - if s.poll, err = newpollster(); err != nil { - goto Error - } - if _, err = s.poll.AddFD(int(s.pr.Fd()), 'r', true); err != nil { - s.poll.Close() - goto Error - } - s.pending = make(map[int]*pollDesc) - go s.Run() - return s, nil - -Errno: - err = &os.PathError{ - Op: "setnonblock", - Path: s.pr.Name(), - Err: err, - } -Error: - s.pr.Close() - s.pw.Close() - return nil, err -} - -func (s *pollServer) AddFD(pd *pollDesc, mode int) error { - s.Lock() - intfd := pd.sysfd - if intfd < 0 || pd.closing { - // fd closed underfoot - s.Unlock() - return errClosing - } - - var t int64 - key := intfd << 1 - if mode == 'r' { - pd.ncr++ - t = pd.rdeadline.value() - } else { - pd.ncw++ - key++ - t = pd.wdeadline.value() - } - s.pending[key] = pd - doWakeup := false - if t > 0 && (s.deadline == 0 || t < s.deadline) { - s.deadline = t - doWakeup = true - } - - wake, err := s.poll.AddFD(intfd, mode, false) - s.Unlock() - if err != nil { - return err - } - if wake || doWakeup { - s.Wakeup() - } - return nil -} - -// Evict evicts pd from the pending list, unblocking -// any I/O running on pd. The caller must have locked -// pollserver. -// Return value is whether the pollServer should be woken up. -func (s *pollServer) Evict(pd *pollDesc) bool { - pd.closing = true - doWakeup := false - if s.pending[pd.sysfd<<1] == pd { - s.WakeFD(pd, 'r', errClosing) - if s.poll.DelFD(pd.sysfd, 'r') { - doWakeup = true - } - delete(s.pending, pd.sysfd<<1) - } - if s.pending[pd.sysfd<<1|1] == pd { - s.WakeFD(pd, 'w', errClosing) - if s.poll.DelFD(pd.sysfd, 'w') { - doWakeup = true - } - delete(s.pending, pd.sysfd<<1|1) - } - return doWakeup -} - -var wakeupbuf [1]byte - -func (s *pollServer) Wakeup() { s.pw.Write(wakeupbuf[0:]) } - -func (s *pollServer) LookupFD(fd int, mode int) *pollDesc { - key := fd << 1 - if mode == 'w' { - key++ - } - netfd, ok := s.pending[key] - if !ok { - return nil - } - delete(s.pending, key) - return netfd -} - -func (s *pollServer) WakeFD(pd *pollDesc, mode int, err error) { - if mode == 'r' { - for pd.ncr > 0 { - pd.ncr-- - pd.cr <- err - } - } else { - for pd.ncw > 0 { - pd.ncw-- - pd.cw <- err - } - } -} - -func (s *pollServer) CheckDeadlines() { - now := time.Now().UnixNano() - // TODO(rsc): This will need to be handled more efficiently, - // probably with a heap indexed by wakeup time. - - var nextDeadline int64 - for key, pd := range s.pending { - var t int64 - var mode int - if key&1 == 0 { - mode = 'r' - } else { - mode = 'w' - } - if mode == 'r' { - t = pd.rdeadline.value() - } else { - t = pd.wdeadline.value() - } - if t > 0 { - if t <= now { - delete(s.pending, key) - s.poll.DelFD(pd.sysfd, mode) - s.WakeFD(pd, mode, errTimeout) - } else if nextDeadline == 0 || t < nextDeadline { - nextDeadline = t - } - } - } - s.deadline = nextDeadline -} - -func (s *pollServer) Run() { - var scratch [100]byte - s.Lock() - defer s.Unlock() - for { - var timeout int64 // nsec to wait for or 0 for none - if s.deadline > 0 { - timeout = s.deadline - time.Now().UnixNano() - if timeout <= 0 { - s.CheckDeadlines() - continue - } - } - fd, mode, err := s.poll.WaitFD(s, timeout) - if err != nil { - print("pollServer WaitFD: ", err.Error(), "\n") - return - } - if fd < 0 { - // Timeout happened. - s.CheckDeadlines() - continue - } - if fd == int(s.pr.Fd()) { - // Drain our wakeup pipe (we could loop here, - // but it's unlikely that there are more than - // len(scratch) wakeup calls). - s.pr.Read(scratch[0:]) - s.CheckDeadlines() - } else { - pd := s.LookupFD(fd, mode) - if pd == nil { - // This can happen because the WaitFD runs without - // holding s's lock, so there might be a pending wakeup - // for an fd that has been evicted. No harm done. - continue - } - s.WakeFD(pd, mode, nil) - } - } -} - -func (pd *pollDesc) Close() { -} - -func (pd *pollDesc) Lock() { - pd.pollServer.Lock() -} - -func (pd *pollDesc) Unlock() { - pd.pollServer.Unlock() -} - -func (pd *pollDesc) Wakeup() { - pd.pollServer.Wakeup() -} - -func (pd *pollDesc) PrepareRead() error { - if pd.rdeadline.expired() { - return errTimeout - } - return nil -} - -func (pd *pollDesc) PrepareWrite() error { - if pd.wdeadline.expired() { - return errTimeout - } - return nil -} - -func (pd *pollDesc) WaitRead() error { - err := pd.pollServer.AddFD(pd, 'r') - if err == nil { - err = <-pd.cr - } - return err -} - -func (pd *pollDesc) WaitWrite() error { - err := pd.pollServer.AddFD(pd, 'w') - if err == nil { - err = <-pd.cw - } - return err -} - -func (pd *pollDesc) Evict() bool { - return pd.pollServer.Evict(pd) -} - -// Spread network FDs over several pollServers. - -var pollMaxN int -var pollservers []*pollServer -var startServersOnce []func() - -var canCancelIO = true // used for testing current package - -func sysInit() { - pollMaxN = runtime.NumCPU() - if pollMaxN > 8 { - pollMaxN = 8 // No improvement then. - } - pollservers = make([]*pollServer, pollMaxN) - startServersOnce = make([]func(), pollMaxN) - for i := 0; i < pollMaxN; i++ { - k := i - once := new(sync.Once) - startServersOnce[i] = func() { once.Do(func() { startServer(k) }) } - } -} - -func startServer(k int) { - p, err := newPollServer() - if err != nil { - panic(err) - } - pollservers[k] = p -} - -func (pd *pollDesc) Init(fd *netFD) error { - pollN := runtime.GOMAXPROCS(0) - if pollN > pollMaxN { - pollN = pollMaxN - } - k := fd.sysfd % pollN - startServersOnce[k]() - pd.sysfd = fd.sysfd - pd.pollServer = pollservers[k] - pd.cr = make(chan error, 1) - pd.cw = make(chan error, 1) - return nil -} - -// TODO(dfc) these unused error returns could be removed - -func setReadDeadline(fd *netFD, t time.Time) error { - fd.pd.rdeadline.setTime(t) - return nil -} - -func setWriteDeadline(fd *netFD, t time.Time) error { - fd.pd.wdeadline.setTime(t) - return nil -} - -func setDeadline(fd *netFD, t time.Time) error { - setReadDeadline(fd, t) - setWriteDeadline(fd, t) - return nil -} diff --git a/libgo/go/net/fd_posix_test.go b/libgo/go/net/fd_posix_test.go deleted file mode 100644 index 8be0335d61c..00000000000 --- a/libgo/go/net/fd_posix_test.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build darwin freebsd linux netbsd openbsd windows - -package net - -import ( - "testing" - "time" -) - -var deadlineSetTimeTests = []struct { - input time.Time - expected int64 -}{ - {time.Time{}, 0}, - {time.Date(2009, 11, 10, 23, 00, 00, 00, time.UTC), 1257894000000000000}, // 2009-11-10 23:00:00 +0000 UTC -} - -func TestDeadlineSetTime(t *testing.T) { - for _, tt := range deadlineSetTimeTests { - var d deadline - d.setTime(tt.input) - actual := d.value() - expected := int64(0) - if !tt.input.IsZero() { - expected = tt.input.UnixNano() - } - if actual != expected { - t.Errorf("set/value failed: expected %v, actual %v", expected, actual) - } - } -} - -var deadlineExpiredTests = []struct { - deadline time.Time - expired bool -}{ - // note, times are relative to the start of the test run, not - // the start of TestDeadlineExpired - {time.Now().Add(5 * time.Minute), false}, - {time.Now().Add(-5 * time.Minute), true}, - {time.Time{}, false}, // no deadline set -} - -func TestDeadlineExpired(t *testing.T) { - for _, tt := range deadlineExpiredTests { - var d deadline - d.set(tt.deadline.UnixNano()) - expired := d.expired() - if expired != tt.expired { - t.Errorf("expire failed: expected %v, actual %v", tt.expired, expired) - } - } -} diff --git a/libgo/go/net/fd_unix.go b/libgo/go/net/fd_unix.go index 8c59bff989c..9ed4f753649 100644 --- a/libgo/go/net/fd_unix.go +++ b/libgo/go/net/fd_unix.go @@ -2,70 +2,59 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux netbsd openbsd package net import ( "io" "os" - "sync" + "runtime" + "sync/atomic" "syscall" "time" ) // Network file descriptor. type netFD struct { - // locking/lifetime of sysfd - sysmu sync.Mutex - sysref int - - // must lock both sysmu and pollDesc to write - // can lock either to read - closing bool + // locking/lifetime of sysfd + serialize access to Read and Write methods + fdmu fdMutex // immutable until Close sysfd int family int sotype int isConnected bool - sysfile *os.File net string laddr Addr raddr Addr - // serialize access to Read and Write methods - rio, wio sync.Mutex - // wait server pd pollDesc } -func resolveAndDial(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { - ra, err := resolveAddr("dial", net, addr, deadline) - if err != nil { - return nil, err - } - return dial(net, addr, localAddr, ra, deadline) +func sysInit() { } -func newFD(fd, family, sotype int, net string) (*netFD, error) { - netfd := &netFD{ - sysfd: fd, - family: family, - sotype: sotype, - net: net, - } - if err := netfd.pd.Init(netfd); err != nil { - return nil, err +func dial(network string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { + return dialer(deadline) +} + +func newFD(sysfd, family, sotype int, net string) (*netFD, error) { + return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil +} + +func (fd *netFD) init() error { + if err := fd.pd.Init(fd); err != nil { + return err } - return netfd, nil + return nil } func (fd *netFD) setAddr(laddr, raddr Addr) { fd.laddr = laddr fd.raddr = raddr - fd.sysfile = os.NewFile(uintptr(fd.sysfd), fd.net) + runtime.SetFinalizer(fd, (*netFD).Close) } func (fd *netFD) name() string { @@ -80,8 +69,9 @@ func (fd *netFD) name() string { } func (fd *netFD) connect(la, ra syscall.Sockaddr) error { - fd.wio.Lock() - defer fd.wio.Unlock() + // Do not need to call fd.writeLock here, + // because fd is not yet accessible to user, + // so no concurrent operations are possible. if err := fd.pd.PrepareWrite(); err != nil { return err } @@ -100,48 +90,69 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr) error { return nil } +func (fd *netFD) destroy() { + // Poller may want to unregister fd in readiness notification mechanism, + // so this must be executed before closesocket. + fd.pd.Close() + closesocket(fd.sysfd) + fd.sysfd = -1 + runtime.SetFinalizer(fd, nil) +} + // Add a reference to this fd. -// If closing==true, pollDesc must be locked; mark the fd as closing. // Returns an error if the fd cannot be used. -func (fd *netFD) incref(closing bool) error { - fd.sysmu.Lock() - if fd.closing { - fd.sysmu.Unlock() +func (fd *netFD) incref() error { + if !fd.fdmu.Incref() { return errClosing } - fd.sysref++ - if closing { - fd.closing = true - } - fd.sysmu.Unlock() return nil } -// Remove a reference to this FD and close if we've been asked to do so (and -// there are no references left. +// Remove a reference to this FD and close if we've been asked to do so +// (and there are no references left). func (fd *netFD) decref() { - fd.sysmu.Lock() - fd.sysref-- - if fd.closing && fd.sysref == 0 { - // Poller may want to unregister fd in readiness notification mechanism, - // so this must be executed before sysfile.Close(). - fd.pd.Close() - if fd.sysfile != nil { - fd.sysfile.Close() - fd.sysfile = nil - } else { - closesocket(fd.sysfd) - } - fd.sysfd = -1 + if fd.fdmu.Decref() { + fd.destroy() + } +} + +// Add a reference to this fd and lock for reading. +// Returns an error if the fd cannot be used. +func (fd *netFD) readLock() error { + if !fd.fdmu.RWLock(true) { + return errClosing + } + return nil +} + +// Unlock for reading and remove a reference to this FD. +func (fd *netFD) readUnlock() { + if fd.fdmu.RWUnlock(true) { + fd.destroy() + } +} + +// Add a reference to this fd and lock for writing. +// Returns an error if the fd cannot be used. +func (fd *netFD) writeLock() error { + if !fd.fdmu.RWLock(false) { + return errClosing + } + return nil +} + +// Unlock for writing and remove a reference to this FD. +func (fd *netFD) writeUnlock() { + if fd.fdmu.RWUnlock(false) { + fd.destroy() } - fd.sysmu.Unlock() } func (fd *netFD) Close() error { fd.pd.Lock() // needed for both fd.incref(true) and pollDesc.Evict - if err := fd.incref(true); err != nil { + if !fd.fdmu.IncrefAndClose() { fd.pd.Unlock() - return err + return errClosing } // Unblock any I/O. Once it all unblocks and returns, // so that it cannot be referring to fd.sysfd anymore, @@ -158,7 +169,7 @@ func (fd *netFD) Close() error { } func (fd *netFD) shutdown(how int) error { - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() @@ -178,12 +189,10 @@ func (fd *netFD) CloseWrite() error { } func (fd *netFD) Read(p []byte) (n int, err error) { - fd.rio.Lock() - defer fd.rio.Unlock() - if err := fd.incref(false); err != nil { + if err := fd.readLock(); err != nil { return 0, err } - defer fd.decref() + defer fd.readUnlock() if err := fd.pd.PrepareRead(); err != nil { return 0, &OpError{"read", fd.net, fd.raddr, err} } @@ -207,12 +216,10 @@ func (fd *netFD) Read(p []byte) (n int, err error) { } func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { - fd.rio.Lock() - defer fd.rio.Unlock() - if err := fd.incref(false); err != nil { + if err := fd.readLock(); err != nil { return 0, nil, err } - defer fd.decref() + defer fd.readUnlock() if err := fd.pd.PrepareRead(); err != nil { return 0, nil, &OpError{"read", fd.net, fd.laddr, err} } @@ -236,12 +243,10 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { } func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { - fd.rio.Lock() - defer fd.rio.Unlock() - if err := fd.incref(false); err != nil { + if err := fd.readLock(); err != nil { return 0, 0, 0, nil, err } - defer fd.decref() + defer fd.readUnlock() if err := fd.pd.PrepareRead(); err != nil { return 0, 0, 0, nil, &OpError{"read", fd.net, fd.laddr, err} } @@ -272,12 +277,10 @@ func chkReadErr(n int, err error, fd *netFD) error { } func (fd *netFD) Write(p []byte) (nn int, err error) { - fd.wio.Lock() - defer fd.wio.Unlock() - if err := fd.incref(false); err != nil { + if err := fd.writeLock(); err != nil { return 0, err } - defer fd.decref() + defer fd.writeUnlock() if err := fd.pd.PrepareWrite(); err != nil { return 0, &OpError{"write", fd.net, fd.raddr, err} } @@ -311,12 +314,10 @@ func (fd *netFD) Write(p []byte) (nn int, err error) { } func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) { - fd.wio.Lock() - defer fd.wio.Unlock() - if err := fd.incref(false); err != nil { + if err := fd.writeLock(); err != nil { return 0, err } - defer fd.decref() + defer fd.writeUnlock() if err := fd.pd.PrepareWrite(); err != nil { return 0, &OpError{"write", fd.net, fd.raddr, err} } @@ -338,12 +339,10 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) { } func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { - fd.wio.Lock() - defer fd.wio.Unlock() - if err := fd.incref(false); err != nil { + if err := fd.writeLock(); err != nil { return 0, 0, err } - defer fd.decref() + defer fd.writeUnlock() if err := fd.pd.PrepareWrite(); err != nil { return 0, 0, &OpError{"write", fd.net, fd.raddr, err} } @@ -366,12 +365,10 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob } func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err error) { - fd.rio.Lock() - defer fd.rio.Unlock() - if err := fd.incref(false); err != nil { + if err := fd.readLock(); err != nil { return nil, err } - defer fd.decref() + defer fd.readUnlock() var s int var rsa syscall.Sockaddr @@ -399,20 +396,68 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e closesocket(s) return nil, err } + if err = netfd.init(); err != nil { + fd.Close() + return nil, err + } lsa, _ := syscall.Getsockname(netfd.sysfd) netfd.setAddr(toAddr(lsa), toAddr(rsa)) return netfd, nil } -func (fd *netFD) dup() (f *os.File, err error) { +// tryDupCloexec indicates whether F_DUPFD_CLOEXEC should be used. +// If the kernel doesn't support it, this is set to 0. +var tryDupCloexec = int32(1) + +func dupCloseOnExec(fd int) (newfd int, err error) { + if atomic.LoadInt32(&tryDupCloexec) == 1 { + r0, _, e1 := syscall.Syscall(syscall.SYS_FCNTL, uintptr(fd), syscall.F_DUPFD_CLOEXEC, 0) + if runtime.GOOS == "darwin" && e1 == syscall.EBADF { + // On OS X 10.6 and below (but we only support + // >= 10.6), F_DUPFD_CLOEXEC is unsupported + // and fcntl there falls back (undocumented) + // to doing an ioctl instead, returning EBADF + // in this case because fd is not of the + // expected device fd type. Treat it as + // EINVAL instead, so we fall back to the + // normal dup path. + // TODO: only do this on 10.6 if we can detect 10.6 + // cheaply. + e1 = syscall.EINVAL + } + switch e1 { + case 0: + return int(r0), nil + case syscall.EINVAL: + // Old kernel. Fall back to the portable way + // from now on. + atomic.StoreInt32(&tryDupCloexec, 0) + default: + return -1, e1 + } + } + return dupCloseOnExecOld(fd) +} + +// dupCloseOnExecUnixOld is the traditional way to dup an fd and +// set its O_CLOEXEC bit, using two system calls. +func dupCloseOnExecOld(fd int) (newfd int, err error) { syscall.ForkLock.RLock() - ns, err := syscall.Dup(fd.sysfd) + defer syscall.ForkLock.RUnlock() + newfd, err = syscall.Dup(fd) + if err != nil { + return -1, err + } + syscall.CloseOnExec(newfd) + return +} + +func (fd *netFD) dup() (f *os.File, err error) { + ns, err := dupCloseOnExec(fd.sysfd) if err != nil { syscall.ForkLock.RUnlock() return nil, &OpError{"dup", fd.net, fd.laddr, err} } - syscall.CloseOnExec(ns) - syscall.ForkLock.RUnlock() // We want blocking mode for the new fd, hence the double negative. // This also puts the old fd into blocking mode, meaning that @@ -428,3 +473,10 @@ func (fd *netFD) dup() (f *os.File, err error) { func closesocket(s int) error { return syscall.Close(s) } + +func skipRawSocketTests() (skip bool, skipmsg string, err error) { + if os.Getuid() != 0 { + return true, "skipping test; must be root", nil + } + return false, "", nil +} diff --git a/libgo/go/net/fd_unix_test.go b/libgo/go/net/fd_unix_test.go index 664ef1bf19d..65d3e69a764 100644 --- a/libgo/go/net/fd_unix_test.go +++ b/libgo/go/net/fd_unix_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux netbsd openbsd package net diff --git a/libgo/go/net/fd_windows.go b/libgo/go/net/fd_windows.go index fefd174bafa..64d56c73e06 100644 --- a/libgo/go/net/fd_windows.go +++ b/libgo/go/net/fd_windows.go @@ -15,7 +15,10 @@ import ( "unsafe" ) -var initErr error +var ( + initErr error + ioSync uint64 +) // CancelIo Windows API cancels all outstanding IO for a particular // socket on current thread. To overcome that limitation, we run @@ -27,7 +30,11 @@ var initErr error // package uses CancelIoEx API, if present, otherwise it fallback // to CancelIo. -var canCancelIO bool // determines if CancelIoEx API is present +var ( + canCancelIO bool // determines if CancelIoEx API is present + skipSyncNotif bool + hasLoadSetFileCompletionNotificationModes bool +) func sysInit() { var d syscall.WSAData @@ -40,6 +47,27 @@ func sysInit() { lookupPort = newLookupPort lookupIP = newLookupIP } + + hasLoadSetFileCompletionNotificationModes = syscall.LoadSetFileCompletionNotificationModes() == nil + if hasLoadSetFileCompletionNotificationModes { + // It's not safe to use FILE_SKIP_COMPLETION_PORT_ON_SUCCESS if non IFS providers are installed: + // http://support.microsoft.com/kb/2568167 + skipSyncNotif = true + protos := [2]int32{syscall.IPPROTO_TCP, 0} + var buf [32]syscall.WSAProtocolInfo + len := uint32(unsafe.Sizeof(buf)) + n, err := syscall.WSAEnumProtocols(&protos[0], &buf[0], &len) + if err != nil { + skipSyncNotif = false + } else { + for i := int32(0); i < n; i++ { + if buf[i].ServiceFlags1&syscall.XP1_IFS_HANDLES == 0 { + skipSyncNotif = false + break + } + } + } + } } func closesocket(s syscall.Handle) error { @@ -47,128 +75,62 @@ func closesocket(s syscall.Handle) error { } func canUseConnectEx(net string) bool { - if net == "udp" || net == "udp4" || net == "udp6" { + switch net { + case "udp", "udp4", "udp6", "ip", "ip4", "ip6": // ConnectEx windows API does not support connectionless sockets. return false } return syscall.LoadConnectEx() == nil } -func resolveAndDial(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { +func dial(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { if !canUseConnectEx(net) { // Use the relatively inefficient goroutine-racing // implementation of DialTimeout. - return resolveAndDialChannel(net, addr, localAddr, deadline) - } - ra, err := resolveAddr("dial", net, addr, deadline) - if err != nil { - return nil, err + return dialChannel(net, ra, dialer, deadline) } - return dial(net, addr, localAddr, ra, deadline) + return dialer(deadline) } -// Interface for all IO operations. -type anOpIface interface { - Op() *anOp - Name() string - Submit() error -} - -// IO completion result parameters. -type ioResult struct { - qty uint32 - err error -} - -// anOp implements functionality common to all IO operations. -type anOp struct { +// operation contains superset of data necessary to perform all async IO. +type operation struct { // Used by IOCP interface, it must be first field // of the struct, as our code rely on it. o syscall.Overlapped - resultc chan ioResult - errnoc chan error - fd *netFD -} + // fields used by runtime.netpoll + runtimeCtx uintptr + mode int32 + errno int32 + qty uint32 -func (o *anOp) Init(fd *netFD, mode int) { - o.fd = fd - var i int - if mode == 'r' { - i = 0 - } else { - i = 1 - } - if fd.resultc[i] == nil { - fd.resultc[i] = make(chan ioResult, 1) - } - o.resultc = fd.resultc[i] - if fd.errnoc[i] == nil { - fd.errnoc[i] = make(chan error) - } - o.errnoc = fd.errnoc[i] + // fields used only by net package + fd *netFD + errc chan error + buf syscall.WSABuf + sa syscall.Sockaddr + rsa *syscall.RawSockaddrAny + rsan int32 + handle syscall.Handle + flags uint32 } -func (o *anOp) Op() *anOp { - return o -} - -// bufOp is used by IO operations that read / write -// data from / to client buffer. -type bufOp struct { - anOp - buf syscall.WSABuf -} - -func (o *bufOp) Init(fd *netFD, buf []byte, mode int) { - o.anOp.Init(fd, mode) +func (o *operation) InitBuf(buf []byte) { o.buf.Len = uint32(len(buf)) - if len(buf) == 0 { - o.buf.Buf = nil - } else { + o.buf.Buf = nil + if len(buf) != 0 { o.buf.Buf = (*byte)(unsafe.Pointer(&buf[0])) } } -// resultSrv will retrieve all IO completion results from -// iocp and send them to the correspondent waiting client -// goroutine via channel supplied in the request. -type resultSrv struct { - iocp syscall.Handle -} - -func runtime_blockingSyscallHint() - -func (s *resultSrv) Run() { - var o *syscall.Overlapped - var key uint32 - var r ioResult - for { - r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, 0) - if r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil { - runtime_blockingSyscallHint() - r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE) - } - switch { - case r.err == nil: - // Dequeued successfully completed IO packet. - case r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil: - // Wait has timed out (should not happen now, but might be used in the future). - panic("GetQueuedCompletionStatus timed out") - case o == nil: - // Failed to dequeue anything -> report the error. - panic("GetQueuedCompletionStatus failed " + r.err.Error()) - default: - // Dequeued failed IO packet. - } - (*anOp)(unsafe.Pointer(o)).resultc <- r - } -} - // ioSrv executes net IO requests. type ioSrv struct { - submchan chan anOpIface // submit IO requests - canchan chan anOpIface // cancel IO requests + req chan ioSrvReq +} + +type ioSrvReq struct { + o *operation + submit func(o *operation) error // if nil, cancel the operation } // ProcessRemoteIO will execute submit IO requests on behalf @@ -179,192 +141,182 @@ type ioSrv struct { func (s *ioSrv) ProcessRemoteIO() { runtime.LockOSThread() defer runtime.UnlockOSThread() - for { - select { - case o := <-s.submchan: - o.Op().errnoc <- o.Submit() - case o := <-s.canchan: - o.Op().errnoc <- syscall.CancelIo(syscall.Handle(o.Op().fd.sysfd)) + for r := range s.req { + if r.submit != nil { + r.o.errc <- r.submit(r.o) + } else { + r.o.errc <- syscall.CancelIo(r.o.fd.sysfd) } } } -// ExecIO executes a single IO operation oi. It submits and cancels +// ExecIO executes a single IO operation o. It submits and cancels // IO in the current thread for systems where Windows CancelIoEx API // is available. Alternatively, it passes the request onto -// a special goroutine and waits for completion or cancels request. -// deadline is unix nanos. -func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (int, error) { - var err error - o := oi.Op() - // Calculate timeout delta. - var delta int64 - if deadline != 0 { - delta = deadline - time.Now().UnixNano() - if delta <= 0 { - return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, errTimeout} - } +// runtime netpoll and waits for completion or cancels request. +func (s *ioSrv) ExecIO(o *operation, name string, submit func(o *operation) error) (int, error) { + fd := o.fd + // Notify runtime netpoll about starting IO. + err := fd.pd.Prepare(int(o.mode)) + if err != nil { + return 0, &OpError{name, fd.net, fd.laddr, err} } // Start IO. if canCancelIO { - err = oi.Submit() + err = submit(o) } else { // Send request to a special dedicated thread, // so it can stop the IO with CancelIO later. - s.submchan <- oi - err = <-o.errnoc + s.req <- ioSrvReq{o, submit} + err = <-o.errc } switch err { case nil: - // IO completed immediately, but we need to get our completion message anyway. + // IO completed immediately + if o.fd.skipSyncNotif { + // No completion message will follow, so return immediately. + return int(o.qty), nil + } + // Need to get our completion message anyway. case syscall.ERROR_IO_PENDING: // IO started, and we have to wait for its completion. err = nil default: - return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, err} - } - // Setup timer, if deadline is given. - var timer <-chan time.Time - if delta > 0 { - t := time.NewTimer(time.Duration(delta) * time.Nanosecond) - defer t.Stop() - timer = t.C + return 0, &OpError{name, fd.net, fd.laddr, err} } // Wait for our request to complete. - var r ioResult - var cancelled, timeout bool - select { - case r = <-o.resultc: - case <-timer: - cancelled = true - timeout = true - case <-o.fd.closec: - cancelled = true - } - if cancelled { - // Cancel it. - if canCancelIO { - err := syscall.CancelIoEx(syscall.Handle(o.Op().fd.sysfd), &o.o) - // Assuming ERROR_NOT_FOUND is returned, if IO is completed. - if err != nil && err != syscall.ERROR_NOT_FOUND { - // TODO(brainman): maybe do something else, but panic. - panic(err) - } - } else { - s.canchan <- oi - <-o.errnoc - } - // Wait for IO to be canceled or complete successfully. - r = <-o.resultc - if r.err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled - if timeout { - r.err = errTimeout - } else { - r.err = errClosing - } + err = fd.pd.Wait(int(o.mode)) + if err == nil { + // All is good. Extract our IO results and return. + if o.errno != 0 { + err = syscall.Errno(o.errno) + return 0, &OpError{name, fd.net, fd.laddr, err} } + return int(o.qty), nil + } + // IO is interrupted by "close" or "timeout" + netpollErr := err + switch netpollErr { + case errClosing, errTimeout: + // will deal with those. + default: + panic("net: unexpected runtime.netpoll error: " + netpollErr.Error()) } - if r.err != nil { - err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, r.err} + // Cancel our request. + if canCancelIO { + err := syscall.CancelIoEx(fd.sysfd, &o.o) + // Assuming ERROR_NOT_FOUND is returned, if IO is completed. + if err != nil && err != syscall.ERROR_NOT_FOUND { + // TODO(brainman): maybe do something else, but panic. + panic(err) + } + } else { + s.req <- ioSrvReq{o, nil} + <-o.errc + } + // Wait for cancellation to complete. + fd.pd.WaitCanceled(int(o.mode)) + if o.errno != 0 { + err = syscall.Errno(o.errno) + if err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled + err = netpollErr + } + return 0, &OpError{name, fd.net, fd.laddr, err} } - return int(r.qty), err + // We issued cancellation request. But, it seems, IO operation succeeded + // before cancellation request run. We need to treat IO operation as + // succeeded (the bytes are actually sent/recv from network). + return int(o.qty), nil } // Start helper goroutines. -var resultsrv *resultSrv -var iosrv *ioSrv +var rsrv, wsrv *ioSrv var onceStartServer sync.Once func startServer() { - resultsrv = new(resultSrv) - var err error - resultsrv.iocp, err = syscall.CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 1) - if err != nil { - panic("CreateIoCompletionPort: " + err.Error()) - } - go resultsrv.Run() - - iosrv = new(ioSrv) + rsrv = new(ioSrv) + wsrv = new(ioSrv) if !canCancelIO { - // Only CancelIo API is available. Lets start special goroutine - // locked to an OS thread, that both starts and cancels IO. - iosrv.submchan = make(chan anOpIface) - iosrv.canchan = make(chan anOpIface) - go iosrv.ProcessRemoteIO() + // Only CancelIo API is available. Lets start two special goroutines + // locked to an OS thread, that both starts and cancels IO. One will + // process read requests, while other will do writes. + rsrv.req = make(chan ioSrvReq) + go rsrv.ProcessRemoteIO() + wsrv.req = make(chan ioSrvReq) + go wsrv.ProcessRemoteIO() } } // Network file descriptor. type netFD struct { - // locking/lifetime of sysfd - sysmu sync.Mutex - sysref int - closing bool + // locking/lifetime of sysfd + serialize access to Read and Write methods + fdmu fdMutex // immutable until Close - sysfd syscall.Handle - family int - sotype int - isConnected bool - net string - laddr Addr - raddr Addr - resultc [2]chan ioResult // read/write completion results - errnoc [2]chan error // read/write submit or cancel operation errors - closec chan bool // used by Close to cancel pending IO + sysfd syscall.Handle + family int + sotype int + isConnected bool + skipSyncNotif bool + net string + laddr Addr + raddr Addr - // serialize access to Read and Write methods - rio, wio sync.Mutex + rop operation // read operation + wop operation // write operation - // read and write deadlines - rdeadline, wdeadline deadline + // wait server + pd pollDesc } -func allocFD(fd syscall.Handle, family, sotype int, net string) *netFD { - netfd := &netFD{ - sysfd: fd, - family: family, - sotype: sotype, - net: net, - closec: make(chan bool), - } - return netfd -} - -func newFD(fd syscall.Handle, family, proto int, net string) (*netFD, error) { +func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error) { if initErr != nil { return nil, initErr } onceStartServer.Do(startServer) - // Associate our socket with resultsrv.iocp. - if _, err := syscall.CreateIoCompletionPort(syscall.Handle(fd), resultsrv.iocp, 0, 0); err != nil { - return nil, err + return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil +} + +func (fd *netFD) init() error { + if err := fd.pd.Init(fd); err != nil { + return err + } + if hasLoadSetFileCompletionNotificationModes { + // We do not use events, so we can skip them always. + flags := uint8(syscall.FILE_SKIP_SET_EVENT_ON_HANDLE) + // It's not safe to skip completion notifications for UDP: + // http://blogs.technet.com/b/winserverperformance/archive/2008/06/26/designing-applications-for-high-performance-part-iii.aspx + if skipSyncNotif && fd.net == "tcp" { + flags |= syscall.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS + } + err := syscall.SetFileCompletionNotificationModes(fd.sysfd, flags) + if err == nil && flags&syscall.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS != 0 { + fd.skipSyncNotif = true + } } - return allocFD(fd, family, proto, net), nil + fd.rop.mode = 'r' + fd.wop.mode = 'w' + fd.rop.fd = fd + fd.wop.fd = fd + fd.rop.runtimeCtx = fd.pd.runtimeCtx + fd.wop.runtimeCtx = fd.pd.runtimeCtx + if !canCancelIO { + fd.rop.errc = make(chan error) + fd.wop.errc = make(chan error) + } + return nil } func (fd *netFD) setAddr(laddr, raddr Addr) { fd.laddr = laddr fd.raddr = raddr - runtime.SetFinalizer(fd, (*netFD).closesocket) -} - -// Make new connection. - -type connectOp struct { - anOp - ra syscall.Sockaddr -} - -func (o *connectOp) Submit() error { - return syscall.ConnectEx(o.fd.sysfd, o.ra, nil, 0, nil, &o.o) -} - -func (o *connectOp) Name() string { - return "ConnectEx" + runtime.SetFinalizer(fd, (*netFD).Close) } func (fd *netFD) connect(la, ra syscall.Sockaddr) error { + // Do not need to call fd.writeLock here, + // because fd is not yet accessible to user, + // so no concurrent operations are possible. if !canUseConnectEx(fd.net) { return syscall.Connect(fd.sysfd, ra) } @@ -383,10 +335,11 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr) error { } } // Call ConnectEx API. - var o connectOp - o.Init(fd, 'w') - o.ra = ra - _, err := iosrv.ExecIO(&o, fd.wdeadline.value()) + o := &fd.wop + o.sa = ra + _, err := wsrv.ExecIO(o, "ConnectEx", func(o *operation) error { + return syscall.ConnectEx(o.fd.sysfd, o.sa, nil, 0, nil, &o.o) + }) if err != nil { return err } @@ -394,61 +347,80 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr) error { return syscall.Setsockopt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd))) } +func (fd *netFD) destroy() { + if fd.sysfd == syscall.InvalidHandle { + return + } + // Poller may want to unregister fd in readiness notification mechanism, + // so this must be executed before closesocket. + fd.pd.Close() + closesocket(fd.sysfd) + fd.sysfd = syscall.InvalidHandle + // no need for a finalizer anymore + runtime.SetFinalizer(fd, nil) +} + // Add a reference to this fd. -// If closing==true, mark the fd as closing. // Returns an error if the fd cannot be used. -func (fd *netFD) incref(closing bool) error { - if fd == nil { +func (fd *netFD) incref() error { + if !fd.fdmu.Incref() { return errClosing } - fd.sysmu.Lock() - if fd.closing { - fd.sysmu.Unlock() - return errClosing + return nil +} + +// Remove a reference to this FD and close if we've been asked to do so +// (and there are no references left). +func (fd *netFD) decref() { + if fd.fdmu.Decref() { + fd.destroy() } - fd.sysref++ - if closing { - fd.closing = true +} + +// Add a reference to this fd and lock for reading. +// Returns an error if the fd cannot be used. +func (fd *netFD) readLock() error { + if !fd.fdmu.RWLock(true) { + return errClosing } - closing = fd.closing - fd.sysmu.Unlock() return nil } -// Remove a reference to this FD and close if we've been asked to do so (and -// there are no references left. -func (fd *netFD) decref() { - if fd == nil { - return +// Unlock for reading and remove a reference to this FD. +func (fd *netFD) readUnlock() { + if fd.fdmu.RWUnlock(true) { + fd.destroy() + } +} + +// Add a reference to this fd and lock for writing. +// Returns an error if the fd cannot be used. +func (fd *netFD) writeLock() error { + if !fd.fdmu.RWLock(false) { + return errClosing } - fd.sysmu.Lock() - fd.sysref-- - if fd.closing && fd.sysref == 0 && fd.sysfd != syscall.InvalidHandle { - closesocket(fd.sysfd) - fd.sysfd = syscall.InvalidHandle - // no need for a finalizer anymore - runtime.SetFinalizer(fd, nil) + return nil +} + +// Unlock for writing and remove a reference to this FD. +func (fd *netFD) writeUnlock() { + if fd.fdmu.RWUnlock(false) { + fd.destroy() } - fd.sysmu.Unlock() } func (fd *netFD) Close() error { - if err := fd.incref(true); err != nil { - return err + if !fd.fdmu.IncrefAndClose() { + return errClosing } - defer fd.decref() // unblock pending reader and writer - close(fd.closec) - // wait for both reader and writer to exit - fd.rio.Lock() - defer fd.rio.Unlock() - fd.wio.Lock() - defer fd.wio.Unlock() + fd.pd.Evict() + fd.decref() return nil } func (fd *netFD) shutdown(how int) error { - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() @@ -467,72 +439,42 @@ func (fd *netFD) CloseWrite() error { return fd.shutdown(syscall.SHUT_WR) } -func (fd *netFD) closesocket() error { - return closesocket(fd.sysfd) -} - -// Read from network. - -type readOp struct { - bufOp -} - -func (o *readOp) Submit() error { - var d, f uint32 - return syscall.WSARecv(syscall.Handle(o.fd.sysfd), &o.buf, 1, &d, &f, &o.o, nil) -} - -func (o *readOp) Name() string { - return "WSARecv" -} - func (fd *netFD) Read(buf []byte) (int, error) { - if err := fd.incref(false); err != nil { + if err := fd.readLock(); err != nil { return 0, err } - defer fd.decref() - fd.rio.Lock() - defer fd.rio.Unlock() - var o readOp - o.Init(fd, buf, 'r') - n, err := iosrv.ExecIO(&o, fd.rdeadline.value()) + defer fd.readUnlock() + o := &fd.rop + o.InitBuf(buf) + n, err := rsrv.ExecIO(o, "WSARecv", func(o *operation) error { + return syscall.WSARecv(o.fd.sysfd, &o.buf, 1, &o.qty, &o.flags, &o.o, nil) + }) if err == nil && n == 0 { err = io.EOF } + if raceenabled { + raceAcquire(unsafe.Pointer(&ioSync)) + } return n, err } -// ReadFrom from network. - -type readFromOp struct { - bufOp - rsa syscall.RawSockaddrAny - rsan int32 -} - -func (o *readFromOp) Submit() error { - var d, f uint32 - return syscall.WSARecvFrom(o.fd.sysfd, &o.buf, 1, &d, &f, &o.rsa, &o.rsan, &o.o, nil) -} - -func (o *readFromOp) Name() string { - return "WSARecvFrom" -} - func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err error) { if len(buf) == 0 { return 0, nil, nil } - if err := fd.incref(false); err != nil { + if err := fd.readLock(); err != nil { return 0, nil, err } - defer fd.decref() - fd.rio.Lock() - defer fd.rio.Unlock() - var o readFromOp - o.Init(fd, buf, 'r') - o.rsan = int32(unsafe.Sizeof(o.rsa)) - n, err = iosrv.ExecIO(&o, fd.rdeadline.value()) + defer fd.readUnlock() + o := &fd.rop + o.InitBuf(buf) + n, err = rsrv.ExecIO(o, "WSARecvFrom", func(o *operation) error { + if o.rsa == nil { + o.rsa = new(syscall.RawSockaddrAny) + } + o.rsan = int32(unsafe.Sizeof(*o.rsa)) + return syscall.WSARecvFrom(o.fd.sysfd, &o.buf, 1, &o.qty, &o.flags, o.rsa, &o.rsan, &o.o, nil) + }) if err != nil { return 0, nil, err } @@ -540,89 +482,42 @@ func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err error) { return } -// Write to network. - -type writeOp struct { - bufOp -} - -func (o *writeOp) Submit() error { - var d uint32 - return syscall.WSASend(o.fd.sysfd, &o.buf, 1, &d, 0, &o.o, nil) -} - -func (o *writeOp) Name() string { - return "WSASend" -} - func (fd *netFD) Write(buf []byte) (int, error) { - if err := fd.incref(false); err != nil { + if err := fd.writeLock(); err != nil { return 0, err } - defer fd.decref() - fd.wio.Lock() - defer fd.wio.Unlock() - var o writeOp - o.Init(fd, buf, 'w') - return iosrv.ExecIO(&o, fd.wdeadline.value()) -} - -// WriteTo to network. - -type writeToOp struct { - bufOp - sa syscall.Sockaddr -} - -func (o *writeToOp) Submit() error { - var d uint32 - return syscall.WSASendto(o.fd.sysfd, &o.buf, 1, &d, 0, o.sa, &o.o, nil) -} - -func (o *writeToOp) Name() string { - return "WSASendto" + defer fd.writeUnlock() + if raceenabled { + raceReleaseMerge(unsafe.Pointer(&ioSync)) + } + o := &fd.wop + o.InitBuf(buf) + return wsrv.ExecIO(o, "WSASend", func(o *operation) error { + return syscall.WSASend(o.fd.sysfd, &o.buf, 1, &o.qty, 0, &o.o, nil) + }) } func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (int, error) { if len(buf) == 0 { return 0, nil } - if err := fd.incref(false); err != nil { + if err := fd.writeLock(); err != nil { return 0, err } - defer fd.decref() - fd.wio.Lock() - defer fd.wio.Unlock() - var o writeToOp - o.Init(fd, buf, 'w') + defer fd.writeUnlock() + o := &fd.wop + o.InitBuf(buf) o.sa = sa - return iosrv.ExecIO(&o, fd.wdeadline.value()) -} - -// Accept new network connections. - -type acceptOp struct { - anOp - newsock syscall.Handle - attrs [2]syscall.RawSockaddrAny // space for local and remote address only -} - -func (o *acceptOp) Submit() error { - var d uint32 - l := uint32(unsafe.Sizeof(o.attrs[0])) - return syscall.AcceptEx(o.fd.sysfd, o.newsock, - (*byte)(unsafe.Pointer(&o.attrs[0])), 0, l, l, &d, &o.o) -} - -func (o *acceptOp) Name() string { - return "AcceptEx" + return wsrv.ExecIO(o, "WSASendto", func(o *operation) error { + return syscall.WSASendto(o.fd.sysfd, &o.buf, 1, &o.qty, 0, o.sa, &o.o, nil) + }) } func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) { - if err := fd.incref(false); err != nil { + if err := fd.readLock(); err != nil { return nil, err } - defer fd.decref() + defer fd.readUnlock() // Get new socket. s, err := sysSocket(fd.family, fd.sotype, 0) @@ -631,43 +526,67 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) { } // Associate our new socket with IOCP. - onceStartServer.Do(startServer) - if _, err := syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); err != nil { + netfd, err := newFD(s, fd.family, fd.sotype, fd.net) + if err != nil { closesocket(s) - return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, err} + return nil, &OpError{"accept", fd.net, fd.laddr, err} + } + if err := netfd.init(); err != nil { + fd.Close() + return nil, err } // Submit accept request. - var o acceptOp - o.Init(fd, 'r') - o.newsock = s - _, err = iosrv.ExecIO(&o, fd.rdeadline.value()) + o := &fd.rop + o.handle = s + var rawsa [2]syscall.RawSockaddrAny + o.rsan = int32(unsafe.Sizeof(rawsa[0])) + _, err = rsrv.ExecIO(o, "AcceptEx", func(o *operation) error { + return syscall.AcceptEx(o.fd.sysfd, o.handle, (*byte)(unsafe.Pointer(&rawsa[0])), 0, uint32(o.rsan), uint32(o.rsan), &o.qty, &o.o) + }) if err != nil { - closesocket(s) + netfd.Close() return nil, err } // Inherit properties of the listening socket. err = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd))) if err != nil { - closesocket(s) + netfd.Close() return nil, &OpError{"Setsockopt", fd.net, fd.laddr, err} } // Get local and peer addr out of AcceptEx buffer. var lrsa, rrsa *syscall.RawSockaddrAny var llen, rlen int32 - l := uint32(unsafe.Sizeof(*lrsa)) - syscall.GetAcceptExSockaddrs((*byte)(unsafe.Pointer(&o.attrs[0])), - 0, l, l, &lrsa, &llen, &rrsa, &rlen) + syscall.GetAcceptExSockaddrs((*byte)(unsafe.Pointer(&rawsa[0])), + 0, uint32(o.rsan), uint32(o.rsan), &lrsa, &llen, &rrsa, &rlen) lsa, _ := lrsa.Sockaddr() rsa, _ := rrsa.Sockaddr() - netfd := allocFD(s, fd.family, fd.sotype, fd.net) netfd.setAddr(toAddr(lsa), toAddr(rsa)) return netfd, nil } +func skipRawSocketTests() (skip bool, skipmsg string, err error) { + // From http://msdn.microsoft.com/en-us/library/windows/desktop/ms740548.aspx: + // Note: To use a socket of type SOCK_RAW requires administrative privileges. + // Users running Winsock applications that use raw sockets must be a member of + // the Administrators group on the local computer, otherwise raw socket calls + // will fail with an error code of WSAEACCES. On Windows Vista and later, access + // for raw sockets is enforced at socket creation. In earlier versions of Windows, + // access for raw sockets is enforced during other socket operations. + s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, 0) + if err == syscall.WSAEACCES { + return true, "skipping test; no access to raw socket allowed", nil + } + if err != nil { + return true, "", err + } + defer syscall.Closesocket(s) + return false, "", nil +} + // Unimplemented functions. func (fd *netFD) dup() (*os.File, error) { diff --git a/libgo/go/net/file_unix.go b/libgo/go/net/file_unix.go index 4c8403e4063..8fe1b0eb035 100644 --- a/libgo/go/net/file_unix.go +++ b/libgo/go/net/file_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux netbsd openbsd package net @@ -12,14 +12,11 @@ import ( ) func newFileFD(f *os.File) (*netFD, error) { - syscall.ForkLock.RLock() - fd, err := syscall.Dup(int(f.Fd())) + fd, err := dupCloseOnExec(int(f.Fd())) if err != nil { - syscall.ForkLock.RUnlock() return nil, os.NewSyscallError("dup", err) } - syscall.CloseOnExec(fd) - syscall.ForkLock.RUnlock() + if err = syscall.SetNonblock(fd, true); err != nil { closesocket(fd) return nil, err @@ -70,6 +67,10 @@ func newFileFD(f *os.File) (*netFD, error) { closesocket(fd) return nil, err } + if err := netfd.init(); err != nil { + netfd.Close() + return nil, err + } netfd.setAddr(laddr, raddr) return netfd, nil } diff --git a/libgo/go/net/http/cgi/child.go b/libgo/go/net/http/cgi/child.go index 100b8b77760..45fc2e57cd7 100644 --- a/libgo/go/net/http/cgi/child.go +++ b/libgo/go/net/http/cgi/child.go @@ -100,10 +100,21 @@ func RequestFromMap(params map[string]string) (*http.Request, error) { uriStr += "?" + s } } + + // There's apparently a de-facto standard for this. + // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 + if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" { + r.TLS = &tls.ConnectionState{HandshakeComplete: true} + } + if r.Host != "" { - // Hostname is provided, so we can reasonably construct a URL, - // even if we have to assume 'http' for the scheme. - rawurl := "http://" + r.Host + uriStr + // Hostname is provided, so we can reasonably construct a URL. + rawurl := r.Host + uriStr + if r.TLS == nil { + rawurl = "http://" + rawurl + } else { + rawurl = "https://" + rawurl + } url, err := url.Parse(rawurl) if err != nil { return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl) @@ -120,12 +131,6 @@ func RequestFromMap(params map[string]string) (*http.Request, error) { r.URL = url } - // There's apparently a de-facto standard for this. - // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 - if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" { - r.TLS = &tls.ConnectionState{HandshakeComplete: true} - } - // Request.RemoteAddr has its port set by Go's standard http // server, so we do here too. We don't have one, though, so we // use a dummy one. diff --git a/libgo/go/net/http/cgi/child_test.go b/libgo/go/net/http/cgi/child_test.go index 74e068014bb..075d8411bcf 100644 --- a/libgo/go/net/http/cgi/child_test.go +++ b/libgo/go/net/http/cgi/child_test.go @@ -21,7 +21,6 @@ func TestRequest(t *testing.T) { "REQUEST_URI": "/path?a=b", "CONTENT_LENGTH": "123", "CONTENT_TYPE": "text/xml", - "HTTPS": "1", "REMOTE_ADDR": "5.6.7.8", } req, err := RequestFromMap(env) @@ -58,14 +57,37 @@ func TestRequest(t *testing.T) { if req.Trailer == nil { t.Errorf("unexpected nil Trailer") } - if req.TLS == nil { - t.Errorf("expected non-nil TLS") + if req.TLS != nil { + t.Errorf("expected nil TLS") } if e, g := "5.6.7.8:0", req.RemoteAddr; e != g { t.Errorf("RemoteAddr: got %q; want %q", g, e) } } +func TestRequestWithTLS(t *testing.T) { + env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", + "REQUEST_METHOD": "GET", + "HTTP_HOST": "example.com", + "HTTP_REFERER": "elsewhere", + "REQUEST_URI": "/path?a=b", + "CONTENT_TYPE": "text/xml", + "HTTPS": "1", + "REMOTE_ADDR": "5.6.7.8", + } + req, err := RequestFromMap(env) + if err != nil { + t.Fatalf("RequestFromMap: %v", err) + } + if g, e := req.URL.String(), "https://example.com/path?a=b"; e != g { + t.Errorf("expected URL %q; got %q", e, g) + } + if req.TLS == nil { + t.Errorf("expected non-nil TLS") + } +} + func TestRequestWithoutHost(t *testing.T) { env := map[string]string{ "SERVER_PROTOCOL": "HTTP/1.1", diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index a34d47be1fa..22f2e865cf7 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -74,8 +74,8 @@ type RoundTripper interface { // authentication, or cookies. // // RoundTrip should not modify the request, except for - // consuming the Body. The request's URL and Header fields - // are guaranteed to be initialized. + // consuming and closing the Body. The request's URL and + // Header fields are guaranteed to be initialized. RoundTrip(*Request) (*Response, error) } @@ -161,7 +161,9 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) { } if u := req.URL.User; u != nil { - req.Header.Set("Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(u.String()))) + username := u.Username() + password, _ := u.Password() + req.Header.Set("Authorization", "Basic "+basicAuth(username, password)) } resp, err = t.RoundTrip(req) if err != nil { @@ -173,6 +175,16 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) { return resp, nil } +// See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt +// "To receive authorization, the client sends the userid and password, +// separated by a single colon (":") character, within a base64 +// encoded string in the credentials." +// It is not meant to be urlencoded. +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + // True if the specified HTTP status code is one for which the Get utility should // automatically redirect. func shouldRedirectGet(statusCode int) bool { @@ -335,6 +347,9 @@ func Post(url string, bodyType string, body io.Reader) (resp *Response, err erro // Post issues a POST to the specified URL. // // Caller should close resp.Body when done reading from it. +// +// If the provided body is also an io.Closer, it is closed after the +// body is successfully written to the server. func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { req, err := NewRequest("POST", url, body) if err != nil { diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 73f1fe3c10a..997d04151c2 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -10,6 +10,7 @@ import ( "bytes" "crypto/tls" "crypto/x509" + "encoding/base64" "errors" "fmt" "io" @@ -665,6 +666,36 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { } } +// Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName +// when not empty. +// +// tls.Config.ServerName (non-empty, set to "example.com") takes +// precedence over "some-other-host.tld" which previously incorrectly +// took precedence. We don't actually connect to (or even resolve) +// "some-other-host.tld", though, because of the Transport.Dial hook. +// +// The httptest.Server has a cert with "example.com" as its name. +func TestTransportUsesTLSConfigServerName(t *testing.T) { + defer afterTest(t) + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("Hello")) + })) + defer ts.Close() + + tr := newTLSTransport(t, ts) + tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names + tr.Dial = func(netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get("https://some-other-host.tld/") + if err != nil { + t.Fatal(err) + } + res.Body.Close() +} + // Verify Response.ContentLength is populated. http://golang.org/issue/4126 func TestClientHeadContentLength(t *testing.T) { defer afterTest(t) @@ -700,3 +731,71 @@ func TestClientHeadContentLength(t *testing.T) { } } } + +func TestEmptyPasswordAuth(t *testing.T) { + defer afterTest(t) + gopher := "gopher" + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Basic ") { + encoded := auth[6:] + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + t.Fatal(err) + } + expected := gopher + ":" + s := string(decoded) + if expected != s { + t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected) + } + } else { + t.Errorf("Invalid auth %q", auth) + } + })) + defer ts.Close() + c := &Client{} + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + req.URL.User = url.User(gopher) + resp, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() +} + +func TestBasicAuth(t *testing.T) { + defer afterTest(t) + tr := &recordingTransport{} + client := &Client{Transport: tr} + + url := "http://My%20User:My%20Pass@dummy.faketld/" + expected := "My User:My Pass" + client.Get(url) + + if tr.req.Method != "GET" { + t.Errorf("got method %q, want %q", tr.req.Method, "GET") + } + if tr.req.URL.String() != url { + t.Errorf("got URL %q, want %q", tr.req.URL.String(), url) + } + if tr.req.Header == nil { + t.Fatalf("expected non-nil request Header") + } + auth := tr.req.Header.Get("Authorization") + if strings.HasPrefix(auth, "Basic ") { + encoded := auth[6:] + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + t.Fatal(err) + } + s := string(decoded) + if expected != s { + t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected) + } + } else { + t.Errorf("Invalid auth %q", auth) + } +} diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go index 155b09223e4..8b01c508eb1 100644 --- a/libgo/go/net/http/cookie.go +++ b/libgo/go/net/http/cookie.go @@ -7,6 +7,8 @@ package http import ( "bytes" "fmt" + "log" + "net" "strconv" "strings" "time" @@ -139,12 +141,25 @@ func SetCookie(w ResponseWriter, cookie *Cookie) { // header (if other fields are set). func (c *Cookie) String() string { var b bytes.Buffer - fmt.Fprintf(&b, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) + fmt.Fprintf(&b, "%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value)) if len(c.Path) > 0 { - fmt.Fprintf(&b, "; Path=%s", sanitizeValue(c.Path)) + fmt.Fprintf(&b, "; Path=%s", sanitizeCookiePath(c.Path)) } if len(c.Domain) > 0 { - fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(c.Domain)) + if validCookieDomain(c.Domain) { + // A c.Domain containing illegal characters is not + // sanitized but simply dropped which turns the cookie + // into a host-only cookie. A leading dot is okay + // but won't be sent. + d := c.Domain + if d[0] == '.' { + d = d[1:] + } + fmt.Fprintf(&b, "; Domain=%s", d) + } else { + log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute", + c.Domain) + } } if c.Expires.Unix() > 0 { fmt.Fprintf(&b, "; Expires=%s", c.Expires.UTC().Format(time.RFC1123)) @@ -207,16 +222,122 @@ func readCookies(h Header, filter string) []*Cookie { return cookies } +// validCookieDomain returns wheter v is a valid cookie domain-value. +func validCookieDomain(v string) bool { + if isCookieDomainName(v) { + return true + } + if net.ParseIP(v) != nil && !strings.Contains(v, ":") { + return true + } + return false +} + +// isCookieDomainName returns whether s is a valid domain name or a valid +// domain name with a leading dot '.'. It is almost a direct copy of +// package net's isDomainName. +func isCookieDomainName(s string) bool { + if len(s) == 0 { + return false + } + if len(s) > 255 { + return false + } + + if s[0] == '.' { + // A cookie a domain attribute may start with a leading dot. + s = s[1:] + } + last := byte('.') + ok := false // Ok once we've seen a letter. + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': + // No '_' allowed here (in contrast to package net). + ok = true + partlen++ + case '0' <= c && c <= '9': + // fine + partlen++ + case c == '-': + // Byte before dash cannot be dot. + if last == '.' { + return false + } + partlen++ + case c == '.': + // Byte before dot cannot be dot, dash. + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + + return ok +} + var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") -func sanitizeName(n string) string { +func sanitizeCookieName(n string) string { return cookieNameSanitizer.Replace(n) } -var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ") +// http://tools.ietf.org/html/rfc6265#section-4.1.1 +// cookie-value = *cookie-octet / ( DQUOTE *cookie-octet DQUOTE ) +// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E +// ; US-ASCII characters excluding CTLs, +// ; whitespace DQUOTE, comma, semicolon, +// ; and backslash +func sanitizeCookieValue(v string) string { + return sanitizeOrWarn("Cookie.Value", validCookieValueByte, v) +} + +func validCookieValueByte(b byte) bool { + return 0x20 < b && b < 0x7f && b != '"' && b != ',' && b != ';' && b != '\\' +} + +// path-av = "Path=" path-value +// path-value = <any CHAR except CTLs or ";"> +func sanitizeCookiePath(v string) string { + return sanitizeOrWarn("Cookie.Path", validCookiePathByte, v) +} -func sanitizeValue(v string) string { - return cookieValueSanitizer.Replace(v) +func validCookiePathByte(b byte) bool { + return 0x20 <= b && b < 0x7f && b != ';' +} + +func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string { + ok := true + for i := 0; i < len(v); i++ { + if valid(v[i]) { + continue + } + log.Printf("net/http: invalid byte %q in %s; dropping invalid bytes", v[i], fieldName) + ok = false + break + } + if ok { + return v + } + buf := make([]byte, 0, len(v)) + for i := 0; i < len(v); i++ { + if b := v[i]; valid(b) { + buf = append(buf, b) + } + } + return string(buf) } func unquoteCookieValue(v string) string { diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go index f84f73936c7..11b01cc5713 100644 --- a/libgo/go/net/http/cookie_test.go +++ b/libgo/go/net/http/cookie_test.go @@ -26,12 +26,28 @@ var writeSetCookiesTests = []struct { }, { &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"}, - "cookie-3=three; Domain=.example.com", + "cookie-3=three; Domain=example.com", }, { &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"}, "cookie-4=four; Path=/restricted/", }, + { + &Cookie{Name: "cookie-5", Value: "five", Domain: "wrong;bad.abc"}, + "cookie-5=five", + }, + { + &Cookie{Name: "cookie-6", Value: "six", Domain: "bad-.abc"}, + "cookie-6=six", + }, + { + &Cookie{Name: "cookie-7", Value: "seven", Domain: "127.0.0.1"}, + "cookie-7=seven; Domain=127.0.0.1", + }, + { + &Cookie{Name: "cookie-8", Value: "eight", Domain: "::1"}, + "cookie-8=eight", + }, } func TestWriteSetCookies(t *testing.T) { @@ -226,3 +242,34 @@ func TestReadCookies(t *testing.T) { } } } + +func TestCookieSanitizeValue(t *testing.T) { + tests := []struct { + in, want string + }{ + {"foo", "foo"}, + {"foo bar", "foobar"}, + {"\x00\x7e\x7f\x80", "\x7e"}, + {`"withquotes"`, "withquotes"}, + } + for _, tt := range tests { + if got := sanitizeCookieValue(tt.in); got != tt.want { + t.Errorf("sanitizeCookieValue(%q) = %q; want %q", tt.in, got, tt.want) + } + } +} + +func TestCookieSanitizePath(t *testing.T) { + tests := []struct { + in, want string + }{ + {"/path", "/path"}, + {"/path with space/", "/path with space/"}, + {"/just;no;semicolon\x00orstuff/", "/justnosemicolonorstuff/"}, + } + for _, tt := range tests { + if got := sanitizeCookiePath(tt.in); got != tt.want { + t.Errorf("sanitizeCookiePath(%q) = %q; want %q", tt.in, got, tt.want) + } + } +} diff --git a/libgo/go/net/http/cookiejar/jar.go b/libgo/go/net/http/cookiejar/jar.go index 5977d48b631..389ab58e418 100644 --- a/libgo/go/net/http/cookiejar/jar.go +++ b/libgo/go/net/http/cookiejar/jar.go @@ -142,7 +142,7 @@ func (e *entry) pathMatch(requestPath string) bool { return false } -// hasDotSuffix returns whether s ends in "."+suffix. +// hasDotSuffix reports whether s ends in "."+suffix. func hasDotSuffix(s, suffix string) bool { return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix } @@ -316,7 +316,7 @@ func canonicalHost(host string) (string, error) { return toASCII(host) } -// hasPort returns whether host contains a port number. host may be a host +// hasPort reports whether host contains a port number. host may be a host // name, an IPv4 or an IPv6 address. func hasPort(host string) bool { colons := strings.Count(host, ":") @@ -357,7 +357,7 @@ func jarKey(host string, psl PublicSuffixList) string { return host[prevDot+1:] } -// isIP returns whether host is an IP address. +// isIP reports whether host is an IP address. func isIP(host string) bool { return net.ParseIP(host) != nil } @@ -380,7 +380,7 @@ func defaultPath(path string) string { // is compared to c.Expires to determine deletion of c. defPath and host are the // default-path and the canonical host name of the URL c was received from. // -// remove is whether the jar should delete this cookie, as it has already +// remove records whether the jar should delete this cookie, as it has already // expired with respect to now. In this case, e may be incomplete, but it will // be valid to call e.id (which depends on e's Name, Domain and Path). // diff --git a/libgo/go/net/http/doc.go b/libgo/go/net/http/doc.go index b6ae8b87a2f..b1216e8dafa 100644 --- a/libgo/go/net/http/doc.go +++ b/libgo/go/net/http/doc.go @@ -5,7 +5,7 @@ /* Package http provides HTTP client and server implementations. -Get, Head, Post, and PostForm make HTTP requests: +Get, Head, Post, and PostForm make HTTP (or HTTPS) requests: resp, err := http.Get("http://example.com/") ... diff --git a/libgo/go/net/http/example_test.go b/libgo/go/net/http/example_test.go index bc60df7f2b5..88b97d9e3d7 100644 --- a/libgo/go/net/http/example_test.go +++ b/libgo/go/net/http/example_test.go @@ -68,3 +68,21 @@ func ExampleStripPrefix() { // URL's path before the FileServer sees it: http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp")))) } + +type apiHandler struct{} + +func (apiHandler) ServeHTTP(http.ResponseWriter, *http.Request) {} + +func ExampleServeMux_Handle() { + mux := http.NewServeMux() + mux.Handle("/api/", apiHandler{}) + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + // The "/" pattern matches everything, so we need to check + // that we're at the root here. + if req.URL.Path != "/" { + http.NotFound(w, req) + return + } + fmt.Fprintf(w, "Welcome to the home page!") + }) +} diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index 3fc24532676..22b7f279689 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -16,6 +16,8 @@ func NewLoggingConn(baseName string, c net.Conn) net.Conn { return newLoggingConn(baseName, c) } +var ExportAppendTime = appendTime + func (t *Transport) NumPendingRequestsForTesting() int { t.reqMu.Lock() defer t.reqMu.Unlock() @@ -48,6 +50,12 @@ func (t *Transport) IdleConnCountForTesting(cacheKey string) int { return len(conns) } +func (t *Transport) IdleConnChMapSizeForTesting() int { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return len(t.idleConnCh) +} + func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { f := func() <-chan time.Time { return ch diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index b6bea0dfaad..8b32ca1d0ea 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -105,23 +105,31 @@ func dirList(w ResponseWriter, f File) { // // Note that *os.File implements the io.ReadSeeker interface. func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) { - size, err := content.Seek(0, os.SEEK_END) - if err != nil { - Error(w, "seeker can't seek", StatusInternalServerError) - return - } - _, err = content.Seek(0, os.SEEK_SET) - if err != nil { - Error(w, "seeker can't seek", StatusInternalServerError) - return + sizeFunc := func() (int64, error) { + size, err := content.Seek(0, os.SEEK_END) + if err != nil { + return 0, errSeeker + } + _, err = content.Seek(0, os.SEEK_SET) + if err != nil { + return 0, errSeeker + } + return size, nil } - serveContent(w, req, name, modtime, size, content) + serveContent(w, req, name, modtime, sizeFunc, content) } +// errSeeker is returned by ServeContent's sizeFunc when the content +// doesn't seek properly. The underlying Seeker's error text isn't +// included in the sizeFunc reply so it's not sent over HTTP to end +// users. +var errSeeker = errors.New("seeker can't seek") + // if name is empty, filename is unknown. (used for mime type, before sniffing) // if modtime.IsZero(), modtime is unknown. // content must be seeked to the beginning of the file. -func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, size int64, content io.ReadSeeker) { +// The sizeFunc is called at most once. Its error, if any, is sent in the HTTP response. +func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, sizeFunc func() (int64, error), content io.ReadSeeker) { if checkLastModified(w, r, modtime) { return } @@ -132,16 +140,17 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, code := StatusOK - // If Content-Type isn't set, use the file's extension to find it. - ctype := w.Header().Get("Content-Type") - if ctype == "" { + // If Content-Type isn't set, use the file's extension to find it, but + // if the Content-Type is unset explicitly, do not sniff the type. + ctypes, haveType := w.Header()["Content-Type"] + var ctype string + if !haveType { ctype = mime.TypeByExtension(filepath.Ext(name)) if ctype == "" { // read a chunk to decide between utf-8 text and binary - var buf [1024]byte + var buf [sniffLen]byte n, _ := io.ReadFull(content, buf[:]) - b := buf[:n] - ctype = DetectContentType(b) + ctype = DetectContentType(buf[:n]) _, err := content.Seek(0, os.SEEK_SET) // rewind to output whole file if err != nil { Error(w, "seeker can't seek", StatusInternalServerError) @@ -149,6 +158,14 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, } } w.Header().Set("Content-Type", ctype) + } else if len(ctypes) > 0 { + ctype = ctypes[0] + } + + size, err := sizeFunc() + if err != nil { + Error(w, err.Error(), StatusInternalServerError) + return } // handle Content-Range header. @@ -160,7 +177,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) return } - if sumRangesSize(ranges) >= size { + if sumRangesSize(ranges) > size { // The total number of bytes in all the ranges // is larger than the size of the file by // itself, so this is probably an attack, or a @@ -378,7 +395,8 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec } // serverContent will check modification time - serveContent(w, r, d.Name(), d.ModTime(), d.Size(), f) + sizeFunc := func() (int64, error) { return d.Size(), nil } + serveContent(w, r, d.Name(), d.ModTime(), sizeFunc, f) } // localRedirect gives a Moved Permanently response. diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index d38966764b1..dd3e9fefeac 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -20,8 +20,10 @@ import ( "os/exec" "path" "path/filepath" + "reflect" "regexp" "runtime" + "strconv" "strings" "testing" "time" @@ -36,6 +38,8 @@ type wantRange struct { start, end int64 // range [start,end) } +var itoa = strconv.Itoa + var ServeFileRangeTests = []struct { r string code int @@ -50,7 +54,11 @@ var ServeFileRangeTests = []struct { {r: "bytes=0-0,-2", code: StatusPartialContent, ranges: []wantRange{{0, 1}, {testFileLen - 2, testFileLen}}}, {r: "bytes=0-1,5-8", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, 9}}}, {r: "bytes=0-1,5-", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, testFileLen}}}, + {r: "bytes=5-1000", code: StatusPartialContent, ranges: []wantRange{{5, testFileLen}}}, {r: "bytes=0-,1-,2-,3-,4-", code: StatusOK}, // ignore wasteful range request + {r: "bytes=0-" + itoa(testFileLen-2), code: StatusPartialContent, ranges: []wantRange{{0, testFileLen - 1}}}, + {r: "bytes=0-" + itoa(testFileLen-1), code: StatusPartialContent, ranges: []wantRange{{0, testFileLen}}}, + {r: "bytes=0-" + itoa(testFileLen), code: StatusPartialContent, ranges: []wantRange{{0, testFileLen}}}, } func TestServeFile(t *testing.T) { @@ -259,6 +267,9 @@ func TestFileServerImplicitLeadingSlash(t *testing.T) { } func TestDirJoin(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping test on windows") + } wfi, err := os.Stat("/etc/hosts") if err != nil { t.Skip("skipping test; no /etc/hosts file") @@ -309,24 +320,29 @@ func TestServeFileContentType(t *testing.T) { defer afterTest(t) const ctype = "icecream/chocolate" ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - if r.FormValue("override") == "1" { + switch r.FormValue("override") { + case "1": w.Header().Set("Content-Type", ctype) + case "2": + // Explicitly inhibit sniffing. + w.Header()["Content-Type"] = []string{} } ServeFile(w, r, "testdata/file") })) defer ts.Close() - get := func(override, want string) { + get := func(override string, want []string) { resp, err := Get(ts.URL + "?override=" + override) if err != nil { t.Fatal(err) } - if h := resp.Header.Get("Content-Type"); h != want { - t.Errorf("Content-Type mismatch: got %q, want %q", h, want) + if h := resp.Header["Content-Type"]; !reflect.DeepEqual(h, want) { + t.Errorf("Content-Type mismatch: got %v, want %v", h, want) } resp.Body.Close() } - get("0", "text/plain; charset=utf-8") - get("1", ctype) + get("0", []string{"text/plain; charset=utf-8"}) + get("1", []string{ctype}) + get("2", nil) } func TestServeFileMimeType(t *testing.T) { @@ -567,7 +583,10 @@ func TestServeContent(t *testing.T) { defer ts.Close() type testCase struct { - file string + // One of file or content must be set: + file string + content io.ReadSeeker + modtime time.Time serveETag string // optional serveContentType string // optional @@ -615,6 +634,14 @@ func TestServeContent(t *testing.T) { }, wantStatus: 304, }, + "not_modified_etag_no_seek": { + content: panicOnSeek{nil}, // should never be called + serveETag: `"foo"`, + reqHeader: map[string]string{ + "If-None-Match": `"foo"`, + }, + wantStatus: 304, + }, "range_good": { file: "testdata/style.css", serveETag: `"A"`, @@ -638,15 +665,21 @@ func TestServeContent(t *testing.T) { }, } for testName, tt := range tests { - f, err := os.Open(tt.file) - if err != nil { - t.Fatalf("test %q: %v", testName, err) + var content io.ReadSeeker + if tt.file != "" { + f, err := os.Open(tt.file) + if err != nil { + t.Fatalf("test %q: %v", testName, err) + } + defer f.Close() + content = f + } else { + content = tt.content } - defer f.Close() servec <- serveParam{ name: filepath.Base(tt.file), - content: f, + content: content, modtime: tt.modtime, etag: tt.serveETag, contentType: tt.serveContentType, @@ -768,3 +801,5 @@ func TestLinuxSendfileChild(*testing.T) { panic(err) } } + +type panicOnSeek struct{ io.ReadSeeker } diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index 6374237fba1..ca1ae07c25d 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -173,7 +173,7 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { // canonical key for "accept-encoding" is "Accept-Encoding". func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } -// hasToken returns whether token appears with v, ASCII +// hasToken reports whether token appears with v, ASCII // case-insensitive, with space or comma boundaries. // token must be all lowercase. // v may contain mixed cased. diff --git a/libgo/go/net/http/header_test.go b/libgo/go/net/http/header_test.go index 584f1005440..2c896c5ad23 100644 --- a/libgo/go/net/http/header_test.go +++ b/libgo/go/net/http/header_test.go @@ -193,6 +193,9 @@ func BenchmarkHeaderWriteSubset(b *testing.B) { } func TestHeaderWriteSubsetMallocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping malloc count in short mode") + } t.Skip("Skipping alloc count test on gccgo") if runtime.GOMAXPROCS(0) > 1 { t.Skip("skipping; GOMAXPROCS>1") @@ -202,6 +205,6 @@ func TestHeaderWriteSubsetMallocs(t *testing.T) { testHeader.WriteSubset(&buf, nil) }) if n > 0 { - t.Errorf("mallocs = %d; want 0", n) + t.Errorf("mallocs = %g; want 0", n) } } diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index 6d4569146fd..57b5d094847 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -10,7 +10,6 @@ import ( "bufio" "bytes" "crypto/tls" - "encoding/base64" "errors" "fmt" "io" @@ -106,7 +105,16 @@ type Request struct { // following a hyphen uppercase and the rest lowercase. Header Header - // The message body. + // Body is the request's body. + // + // For client requests, a nil body means the request has no + // body, such as a GET request. The HTTP Client's Transport + // is responsible for calling the Close method. + // + // For server requests, the Request Body is always non-nil + // but will return EOF immediately when no body is present. + // The Server will close the request body. The ServeHTTP + // Handler does not need to. Body io.ReadCloser // ContentLength records the length of the associated content. @@ -183,7 +191,7 @@ type Request struct { TLS *tls.ConnectionState } -// ProtoAtLeast returns whether the HTTP protocol used +// ProtoAtLeast reports whether the HTTP protocol used // in the request is at least major.minor. func (r *Request) ProtoAtLeast(major, minor int) bool { return r.ProtoMajor > major || @@ -216,7 +224,7 @@ func (r *Request) Cookie(name string) (*Cookie, error) { // means all cookies, if any, are written into the same line, // separated by semicolon. func (r *Request) AddCookie(c *Cookie) { - s := fmt.Sprintf("%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) + s := fmt.Sprintf("%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value)) if c := r.Header.Get("Cookie"); c != "" { r.Header.Set("Cookie", c+"; "+s) } else { @@ -283,6 +291,11 @@ func valueOrDefault(value, def string) string { return def } +// NOTE: This is not intended to reflect the actual Go version being used. +// It was changed from "Go http package" to "Go 1.1 package http" at the +// time of the Go 1.1 release because the former User-Agent had ended up +// on a blacklist for some intrusion detection systems. +// See https://codereview.appspot.com/7532043. const defaultUserAgent = "Go 1.1 package http" // Write writes an HTTP/1.1 request -- header and body -- in wire format. @@ -424,6 +437,10 @@ func ParseHTTPVersion(vers string) (major, minor int, ok bool) { } // NewRequest returns a new Request given a method, URL, and optional body. +// +// If the provided body is also an io.Closer, the returned +// Request.Body is set to body and will be closed by the Client +// methods Do, Post, and PostForm, and Transport.RoundTrip. func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { u, err := url.Parse(urlStr) if err != nil { @@ -463,8 +480,7 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { // With HTTP Basic Authentication the provided username and password // are not encrypted. func (r *Request) SetBasicAuth(username, password string) { - s := username + ":" + password - r.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(s))) + r.Header.Set("Authorization", "Basic "+basicAuth(username, password)) } // parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index 692485c49d9..89303c33602 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -332,7 +332,7 @@ func TestRequestWriteBufferedWriter(t *testing.T) { func testMissingFile(t *testing.T, req *Request) { f, fh, err := req.FormFile("missing") if f != nil { - t.Errorf("FormFile file = %q, want nil", f) + t.Errorf("FormFile file = %v, want nil", f) } if fh != nil { t.Errorf("FormFile file header = %q, want nil", fh) diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index 9a7e4e319b0..35d0ba3bb15 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -32,7 +32,7 @@ type Response struct { ProtoMinor int // e.g. 0 // Header maps header keys to values. If the response had multiple - // headers with the same key, they will be concatenated, with comma + // headers with the same key, they may be concatenated, with comma // delimiters. (Section 4.2 of RFC 2616 requires that multiple headers // be semantically equivalent to a comma-delimited sequence.) Values // duplicated by other fields in this struct (e.g., ContentLength) are @@ -98,18 +98,17 @@ func (r *Response) Location() (*url.URL, error) { return url.Parse(lv) } -// ReadResponse reads and returns an HTTP response from r. The -// req parameter specifies the Request that corresponds to -// this Response. Clients must call resp.Body.Close when finished -// reading resp.Body. After that call, clients can inspect -// resp.Trailer to find key/value pairs included in the response -// trailer. -func ReadResponse(r *bufio.Reader, req *Request) (resp *Response, err error) { - +// ReadResponse reads and returns an HTTP response from r. +// The req parameter optionally specifies the Request that corresponds +// to this Response. If nil, a GET request is assumed. +// Clients must call resp.Body.Close when finished reading resp.Body. +// After that call, clients can inspect resp.Trailer to find key/value +// pairs included in the response trailer. +func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { tp := textproto.NewReader(r) - resp = new(Response) - - resp.Request = req + resp := &Response{ + Request: req, + } // Parse the first line of the response. line, err := tp.ReadLine() @@ -168,7 +167,7 @@ func fixPragmaCacheControl(header Header) { } } -// ProtoAtLeast returns whether the HTTP protocol used +// ProtoAtLeast reports whether the HTTP protocol used // in the response is at least major.minor. func (r *Response) ProtoAtLeast(major, minor int) bool { return r.ProtoMajor > major || diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index 02796e88b4c..5044306a876 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -348,6 +348,29 @@ some body`, "some body", }, + + // Unchunked response without Content-Length, Request is nil + { + "HTTP/1.0 200 OK\r\n" + + "Connection: close\r\n" + + "\r\n" + + "Body here\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Header: Header{ + "Connection": {"close"}, // TODO(rsc): Delete? + }, + Close: true, + ContentLength: -1, + }, + + "Body here\n", + }, } func TestReadResponse(t *testing.T) { @@ -565,3 +588,42 @@ func TestResponseStatusStutter(t *testing.T) { t.Errorf("stutter in status: %s", buf.String()) } } + +func TestResponseContentLengthShortBody(t *testing.T) { + const shortBody = "Short body, not 123 bytes." + br := bufio.NewReader(strings.NewReader("HTTP/1.1 200 OK\r\n" + + "Content-Length: 123\r\n" + + "\r\n" + + shortBody)) + res, err := ReadResponse(br, &Request{Method: "GET"}) + if err != nil { + t.Fatal(err) + } + if res.ContentLength != 123 { + t.Fatalf("Content-Length = %d; want 123", res.ContentLength) + } + var buf bytes.Buffer + n, err := io.Copy(&buf, res.Body) + if n != int64(len(shortBody)) { + t.Errorf("Copied %d bytes; want %d, len(%q)", n, len(shortBody), shortBody) + } + if buf.String() != shortBody { + t.Errorf("Read body %q; want %q", buf.String(), shortBody) + } + if err != io.ErrUnexpectedEOF { + t.Errorf("io.Copy error = %#v; want io.ErrUnexpectedEOF", err) + } +} + +func TestNeedsSniff(t *testing.T) { + // needsSniff returns true with an empty response. + r := &response{} + if got, want := r.needsSniff(), true; got != want { + t.Errorf("needsSniff = %t; want %t", got, want) + } + // needsSniff returns false when Content-Type = nil. + r.handlerHeader = Header{"Content-Type": nil} + if got, want := r.needsSniff(), false; got != want { + t.Errorf("needsSniff empty Content-Type = %t; want %t", got, want) + } +} diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index d7b321597c4..8961cf491f8 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -116,6 +116,34 @@ func (c *testConn) Close() error { return nil } +// reqBytes treats req as a request (with \n delimiters) and returns it with \r\n delimiters, +// ending in \r\n\r\n +func reqBytes(req string) []byte { + return []byte(strings.Replace(strings.TrimSpace(req), "\n", "\r\n", -1) + "\r\n\r\n") +} + +type handlerTest struct { + handler Handler +} + +func newHandlerTest(h Handler) handlerTest { + return handlerTest{h} +} + +func (ht handlerTest) rawResponse(req string) string { + reqb := reqBytes(req) + var output bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(reqb), + Writer: &output, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + go Serve(ln, ht.handler) + <-conn.closec + return output.String() +} + func TestConsumingBodyOnNextConn(t *testing.T) { conn := new(testConn) for i := 0; i < 2; i++ { @@ -241,6 +269,152 @@ func TestHostHandlers(t *testing.T) { } } +var serveMuxRegister = []struct { + pattern string + h Handler +}{ + {"/dir/", serve(200)}, + {"/search", serve(201)}, + {"codesearch.google.com/search", serve(202)}, + {"codesearch.google.com/", serve(203)}, + {"example.com/", HandlerFunc(checkQueryStringHandler)}, +} + +// serve returns a handler that sends a response with the given code. +func serve(code int) HandlerFunc { + return func(w ResponseWriter, r *Request) { + w.WriteHeader(code) + } +} + +// checkQueryStringHandler checks if r.URL.RawQuery has the same value +// as the URL excluding the scheme and the query string and sends 200 +// response code if it is, 500 otherwise. +func checkQueryStringHandler(w ResponseWriter, r *Request) { + u := *r.URL + u.Scheme = "http" + u.Host = r.Host + u.RawQuery = "" + if "http://"+r.URL.RawQuery == u.String() { + w.WriteHeader(200) + } else { + w.WriteHeader(500) + } +} + +var serveMuxTests = []struct { + method string + host string + path string + code int + pattern string +}{ + {"GET", "google.com", "/", 404, ""}, + {"GET", "google.com", "/dir", 301, "/dir/"}, + {"GET", "google.com", "/dir/", 200, "/dir/"}, + {"GET", "google.com", "/dir/file", 200, "/dir/"}, + {"GET", "google.com", "/search", 201, "/search"}, + {"GET", "google.com", "/search/", 404, ""}, + {"GET", "google.com", "/search/foo", 404, ""}, + {"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"}, + {"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"}, + {"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"}, + {"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"}, + {"GET", "images.google.com", "/search", 201, "/search"}, + {"GET", "images.google.com", "/search/", 404, ""}, + {"GET", "images.google.com", "/search/foo", 404, ""}, + {"GET", "google.com", "/../search", 301, "/search"}, + {"GET", "google.com", "/dir/..", 301, ""}, + {"GET", "google.com", "/dir/..", 301, ""}, + {"GET", "google.com", "/dir/./file", 301, "/dir/"}, + + // The /foo -> /foo/ redirect applies to CONNECT requests + // but the path canonicalization does not. + {"CONNECT", "google.com", "/dir", 301, "/dir/"}, + {"CONNECT", "google.com", "/../search", 404, ""}, + {"CONNECT", "google.com", "/dir/..", 200, "/dir/"}, + {"CONNECT", "google.com", "/dir/..", 200, "/dir/"}, + {"CONNECT", "google.com", "/dir/./file", 200, "/dir/"}, +} + +func TestServeMuxHandler(t *testing.T) { + mux := NewServeMux() + for _, e := range serveMuxRegister { + mux.Handle(e.pattern, e.h) + } + + for _, tt := range serveMuxTests { + r := &Request{ + Method: tt.method, + Host: tt.host, + URL: &url.URL{ + Path: tt.path, + }, + } + h, pattern := mux.Handler(r) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, r) + if pattern != tt.pattern || rr.Code != tt.code { + t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern) + } + } +} + +var serveMuxTests2 = []struct { + method string + host string + url string + code int + redirOk bool +}{ + {"GET", "google.com", "/", 404, false}, + {"GET", "example.com", "/test/?example.com/test/", 200, false}, + {"GET", "example.com", "test/?example.com/test/", 200, true}, +} + +// TestServeMuxHandlerRedirects tests that automatic redirects generated by +// mux.Handler() shouldn't clear the request's query string. +func TestServeMuxHandlerRedirects(t *testing.T) { + mux := NewServeMux() + for _, e := range serveMuxRegister { + mux.Handle(e.pattern, e.h) + } + + for _, tt := range serveMuxTests2 { + tries := 1 + turl := tt.url + for tries > 0 { + u, e := url.Parse(turl) + if e != nil { + t.Fatal(e) + } + r := &Request{ + Method: tt.method, + Host: tt.host, + URL: u, + } + h, _ := mux.Handler(r) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, r) + if rr.Code != 301 { + if rr.Code != tt.code { + t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code) + } + break + } + if !tt.redirOk { + t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url) + break + } + turl = rr.HeaderMap.Get("Location") + tries-- + } + if tries < 0 { + t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url) + } + } +} + // Tests for http://code.google.com/p/go/issues/detail?id=900 func TestMuxRedirectLeadingSlashes(t *testing.T) { paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"} @@ -626,22 +800,20 @@ func Test304Responses(t *testing.T) { } } -// TestHeadResponses verifies that responses to HEAD requests don't -// declare that they're chunking in their response headers, aren't -// allowed to produce output, and don't set a Content-Type since -// the real type of the body data cannot be inferred. +// TestHeadResponses verifies that all MIME type sniffing and Content-Length +// counting of GET requests also happens on HEAD requests. func TestHeadResponses(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - _, err := w.Write([]byte("Ignored body")) - if err != ErrBodyNotAllowed { - t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) + _, err := w.Write([]byte("<html>")) + if err != nil { + t.Errorf("ResponseWriter.Write: %v", err) } // Also exercise the ReaderFrom path - _, err = io.Copy(w, strings.NewReader("Ignored body")) - if err != ErrBodyNotAllowed { - t.Errorf("on Copy, expected ErrBodyNotAllowed, got %v", err) + _, err = io.Copy(w, strings.NewReader("789a")) + if err != nil { + t.Errorf("Copy(ResponseWriter, ...): %v", err) } })) defer ts.Close() @@ -652,9 +824,11 @@ func TestHeadResponses(t *testing.T) { if len(res.TransferEncoding) > 0 { t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) } - ct := res.Header.Get("Content-Type") - if ct != "" { - t.Errorf("expected no Content-Type; got %s", ct) + if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" { + t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct) + } + if v := res.ContentLength; v != 10 { + t.Errorf("Content-Length: %d; want 10", v) } body, err := ioutil.ReadAll(res.Body) if err != nil { @@ -975,6 +1149,23 @@ func TestRedirectMunging(t *testing.T) { } } +func TestRedirectBadPath(t *testing.T) { + // This used to crash. It's not valid input (bad path), but it + // shouldn't crash. + rr := httptest.NewRecorder() + req := &Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Path: "not-empty-but-no-leading-slash", // bogus + }, + } + Redirect(rr, req, "", 304) + if rr.Code != 304 { + t.Errorf("Code = %d; want 304", rr.Code) + } +} + // TestZeroLengthPostAndResponse exercises an optimization done by the Transport: // when there is no body (either because the method doesn't permit a body, or an // explicit Content-Length of zero is present), then the transport can re-use the @@ -1408,10 +1599,7 @@ For: func TestCloseNotifierChanLeak(t *testing.T) { defer afterTest(t) - req := []byte(strings.Replace(`GET / HTTP/1.0 -Host: golang.org - -`, "\n", "\r\n", -1)) + req := reqBytes("GET / HTTP/1.0\nHost: golang.org") for i := 0; i < 20; i++ { var output bytes.Buffer conn := &rwTestConn{ @@ -1493,11 +1681,6 @@ func TestOptions(t *testing.T) { // ones, even if the handler modifies them (~erroneously) after the // first Write. func TestHeaderToWire(t *testing.T) { - req := []byte(strings.Replace(`GET / HTTP/1.1 -Host: golang.org - -`, "\n", "\r\n", -1)) - tests := []struct { name string handler func(ResponseWriter, *Request) @@ -1660,17 +1843,10 @@ Host: golang.org }, } for _, tc := range tests { - var output bytes.Buffer - conn := &rwTestConn{ - Reader: bytes.NewReader(req), - Writer: &output, - closec: make(chan bool, 1), - } - ln := &oneConnListener{conn: conn} - go Serve(ln, HandlerFunc(tc.handler)) - <-conn.closec - if err := tc.check(output.String()); err != nil { - t.Errorf("%s: %v\nGot response:\n%s", tc.name, err, output.Bytes()) + ht := newHandlerTest(HandlerFunc(tc.handler)) + got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org") + if err := tc.check(got); err != nil { + t.Errorf("%s: %v\nGot response:\n%s", tc.name, err, got) } } } @@ -1726,7 +1902,200 @@ func TestAcceptMaxFds(t *testing.T) { } } +func TestWriteAfterHijack(t *testing.T) { + req := reqBytes("GET / HTTP/1.1\nHost: golang.org") + var buf bytes.Buffer + wrotec := make(chan bool, 1) + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &buf, + closec: make(chan bool, 1), + } + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + conn, bufrw, err := rw.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + go func() { + bufrw.Write([]byte("[hijack-to-bufw]")) + bufrw.Flush() + conn.Write([]byte("[hijack-to-conn]")) + conn.Close() + wrotec <- true + }() + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + <-wrotec + if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w { + t.Errorf("wrote %q; want %q", g, w) + } +} + +// http://code.google.com/p/go/issues/detail?id=5955 +// Note that this does not test the "request too large" +// exit path from the http server. This is intentional; +// not sending Connection: close is just a minor wire +// optimization and is pointless if dealing with a +// badly behaved client. +func TestHTTP10ConnectionHeader(t *testing.T) { + defer afterTest(t) + + mux := NewServeMux() + mux.Handle("/", HandlerFunc(func(resp ResponseWriter, req *Request) {})) + ts := httptest.NewServer(mux) + defer ts.Close() + + // net/http uses HTTP/1.1 for requests, so write requests manually + tests := []struct { + req string // raw http request + expect []string // expected Connection header(s) + }{ + { + req: "GET / HTTP/1.0\r\n\r\n", + expect: nil, + }, + { + req: "OPTIONS * HTTP/1.0\r\n\r\n", + expect: nil, + }, + { + req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", + expect: []string{"keep-alive"}, + }, + } + + for _, tt := range tests { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal("dial err:", err) + } + + _, err = fmt.Fprint(conn, tt.req) + if err != nil { + t.Fatal("conn write err:", err) + } + + resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"}) + if err != nil { + t.Fatal("ReadResponse err:", err) + } + conn.Close() + resp.Body.Close() + + got := resp.Header["Connection"] + if !reflect.DeepEqual(got, tt.expect) { + t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect) + } + } +} + +// See golang.org/issue/5660 +func TestServerReaderFromOrder(t *testing.T) { + defer afterTest(t) + pr, pw := io.Pipe() + const size = 3 << 20 + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + rw.Header().Set("Content-Type", "text/plain") // prevent sniffing path + done := make(chan bool) + go func() { + io.Copy(rw, pr) + close(done) + }() + time.Sleep(25 * time.Millisecond) // give Copy a chance to break things + n, err := io.Copy(ioutil.Discard, req.Body) + if err != nil { + t.Errorf("handler Copy: %v", err) + return + } + if n != size { + t.Errorf("handler Copy = %d; want %d", n, size) + } + pw.Write([]byte("hi")) + pw.Close() + <-done + })) + defer ts.Close() + + req, err := NewRequest("POST", ts.URL, io.LimitReader(neverEnding('a'), size)) + if err != nil { + t.Fatal(err) + } + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if string(all) != "hi" { + t.Errorf("Body = %q; want hi", all) + } +} + +// Issue 6157 +func TestNoContentTypeOnNotModified(t *testing.T) { + ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/header" { + w.Header().Set("Content-Length", "123") + } + w.WriteHeader(StatusNotModified) + if r.URL.Path == "/more" { + w.Write([]byte("stuff")) + } + })) + for _, req := range []string{ + "GET / HTTP/1.0", + "GET /header HTTP/1.0", + "GET /more HTTP/1.0", + "GET / HTTP/1.1", + "GET /header HTTP/1.1", + "GET /more HTTP/1.1", + } { + got := ht.rawResponse(req) + if !strings.Contains(got, "304 Not Modified") { + t.Errorf("Non-304 Not Modified for %q: %s", req, got) + } else if strings.Contains(got, "Content-Length") { + t.Errorf("Got a Content-Length from %q: %s", req, got) + } + } +} + +func TestResponseWriterWriteStringAllocs(t *testing.T) { + t.Skip("allocs test unreliable with gccgo") + ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/s" { + io.WriteString(w, "Hello world") + } else { + w.Write([]byte("Hello world")) + } + })) + before := testing.AllocsPerRun(25, func() { ht.rawResponse("GET / HTTP/1.0") }) + after := testing.AllocsPerRun(25, func() { ht.rawResponse("GET /s HTTP/1.0") }) + if int(after) >= int(before) { + t.Errorf("WriteString allocs of %v >= Write allocs of %v", after, before) + } +} + +func TestAppendTime(t *testing.T) { + var b [len(TimeFormat)]byte + t1 := time.Date(2013, 9, 21, 15, 41, 0, 0, time.FixedZone("CEST", 2*60*60)) + res := ExportAppendTime(b[:0], t1) + t2, err := ParseTime(string(res)) + if err != nil { + t.Fatalf("Error parsing time: %s", err) + } + if !t1.Equal(t2) { + t.Fatalf("Times differ; expected: %v, got %v (%s)", t1, t2, string(res)) + } +} + func BenchmarkClientServer(b *testing.B) { + b.ReportAllocs() b.StopTimer() ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { fmt.Fprintf(rw, "Hello world.\n") @@ -1761,6 +2130,7 @@ func BenchmarkClientServerParallel64(b *testing.B) { } func benchmarkClientServerParallel(b *testing.B, conc int) { + b.ReportAllocs() b.StopTimer() ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { fmt.Fprintf(rw, "Hello world.\n") @@ -1805,6 +2175,7 @@ func benchmarkClientServerParallel(b *testing.B, conc int) { // $ go tool pprof http.test http.prof // (pprof) web func BenchmarkServer(b *testing.B) { + b.ReportAllocs() // Child process mode; if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" { n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N")) @@ -1851,15 +2222,14 @@ func BenchmarkServer(b *testing.B) { func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) { b.ReportAllocs() - req := []byte(strings.Replace(`GET / HTTP/1.0 + req := reqBytes(`GET / HTTP/1.0 Host: golang.org Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 Accept-Encoding: gzip,deflate,sdch Accept-Language: en-US,en;q=0.8 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 - -`, "\n", "\r\n", -1)) +`) res := []byte("Hello world!\n") conn := &testConn{ @@ -1905,15 +2275,14 @@ func (r *repeatReader) Read(p []byte) (n int, err error) { func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) { b.ReportAllocs() - req := []byte(strings.Replace(`GET / HTTP/1.1 + req := reqBytes(`GET / HTTP/1.1 Host: golang.org Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 Accept-Encoding: gzip,deflate,sdch Accept-Language: en-US,en;q=0.8 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 - -`, "\n", "\r\n", -1)) +`) res := []byte("Hello world!\n") conn := &rwTestConn{ @@ -1940,10 +2309,9 @@ Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) { b.ReportAllocs() - req := []byte(strings.Replace(`GET / HTTP/1.1 + req := reqBytes(`GET / HTTP/1.1 Host: golang.org - -`, "\n", "\r\n", -1)) +`) res := []byte("Hello world!\n") conn := &rwTestConn{ @@ -2003,10 +2371,9 @@ func BenchmarkServerHandlerNoHeader(b *testing.B) { func benchmarkHandler(b *testing.B, h Handler) { b.ReportAllocs() - req := []byte(strings.Replace(`GET / HTTP/1.1 + req := reqBytes(`GET / HTTP/1.1 Host: golang.org - -`, "\n", "\r\n", -1)) +`) conn := &rwTestConn{ Reader: &repeatReader{content: req, count: b.N}, Writer: ioutil.Discard, diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index b2596070500..0e46863d5ae 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -16,6 +16,7 @@ import ( "log" "net" "net/url" + "os" "path" "runtime" "strconv" @@ -109,8 +110,6 @@ type conn struct { sr liveSwitchReader // where the LimitReader reads from; usually the rwc lr *io.LimitedReader // io.LimitReader(sr) buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc - bufswr *switchReader // the *switchReader io.Reader source of buf - bufsww *switchWriter // the *switchWriter io.Writer dest of buf tlsState *tls.ConnectionState // or nil when not using TLS mu sync.Mutex // guards the following @@ -246,6 +245,10 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { if !cw.wroteHeader { cw.writeHeader(p) } + if cw.res.req.Method == "HEAD" { + // Eat writes. + return len(p), nil + } if cw.chunking { _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p)) if err != nil { @@ -278,7 +281,7 @@ func (cw *chunkWriter) close() { // zero EOF chunk, trailer key/value pairs (currently // unsupported in Go's server), followed by a blank // line. - io.WriteString(cw.res.conn.buf, "0\r\n\r\n") + cw.res.conn.buf.WriteString("0\r\n\r\n") } } @@ -320,6 +323,10 @@ type response struct { requestBodyLimitHit bool handlerDone bool // set true when the handler exits + + // Buffers for Date and Content-Length + dateBuf [len(TimeFormat)]byte + clenBuf [10]byte } // requestTooLarge is called by maxBytesReader when too much input has @@ -332,16 +339,50 @@ func (w *response) requestTooLarge() { } } -// needsSniff returns whether a Content-Type still needs to be sniffed. +// needsSniff reports whether a Content-Type still needs to be sniffed. func (w *response) needsSniff() bool { - return !w.cw.wroteHeader && w.handlerHeader.Get("Content-Type") == "" && w.written < sniffLen + _, haveType := w.handlerHeader["Content-Type"] + return !w.cw.wroteHeader && !haveType && w.written < sniffLen } +// writerOnly hides an io.Writer value's optional ReadFrom method +// from io.Copy. type writerOnly struct { io.Writer } +func srcIsRegularFile(src io.Reader) (isRegular bool, err error) { + switch v := src.(type) { + case *os.File: + fi, err := v.Stat() + if err != nil { + return false, err + } + return fi.Mode().IsRegular(), nil + case *io.LimitedReader: + return srcIsRegularFile(v.R) + default: + return + } +} + +// ReadFrom is here to optimize copying from an *os.File regular file +// to a *net.TCPConn with sendfile. func (w *response) ReadFrom(src io.Reader) (n int64, err error) { + // Our underlying w.conn.rwc is usually a *TCPConn (with its + // own ReadFrom method). If not, or if our src isn't a regular + // file, just fall back to the normal copy method. + rf, ok := w.conn.rwc.(io.ReaderFrom) + regFile, err := srcIsRegularFile(src) + if err != nil { + return 0, err + } + if !ok || !regFile { + return io.Copy(writerOnly{w}, src) + } + + // sendfile path: + if !w.wroteHeader { w.WriteHeader(StatusOK) } @@ -359,16 +400,12 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) { // Now that cw has been flushed, its chunking field is guaranteed initialized. if !w.cw.chunking && w.bodyAllowed() { - if rf, ok := w.conn.rwc.(io.ReaderFrom); ok { - n0, err := rf.ReadFrom(src) - n += n0 - w.written += n0 - return n, err - } + n0, err := rf.ReadFrom(src) + n += n0 + w.written += n0 + return n, err } - // Fall back to default io.Copy implementation. - // Use wrapper to hide w.ReadFrom from io.Copy. n0, err := io.Copy(writerOnly{w}, src) n += n0 return n, err @@ -392,34 +429,20 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { } c.sr = liveSwitchReader{r: c.rwc} c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader) - br, sr := newBufioReader(c.lr) - bw, sw := newBufioWriterSize(c.rwc, 4<<10) + br := newBufioReader(c.lr) + bw := newBufioWriterSize(c.rwc, 4<<10) c.buf = bufio.NewReadWriter(br, bw) - c.bufswr = sr - c.bufsww = sw return c, nil } -// TODO: remove this, if issue 5100 is fixed -type bufioReaderPair struct { - br *bufio.Reader - sr *switchReader // from which the bufio.Reader is reading -} - -// TODO: remove this, if issue 5100 is fixed -type bufioWriterPair struct { - bw *bufio.Writer - sw *switchWriter // to which the bufio.Writer is writing -} - // TODO: use a sync.Cache instead var ( - bufioReaderCache = make(chan bufioReaderPair, 4) - bufioWriterCache2k = make(chan bufioWriterPair, 4) - bufioWriterCache4k = make(chan bufioWriterPair, 4) + bufioReaderCache = make(chan *bufio.Reader, 4) + bufioWriterCache2k = make(chan *bufio.Writer, 4) + bufioWriterCache4k = make(chan *bufio.Writer, 4) ) -func bufioWriterCache(size int) chan bufioWriterPair { +func bufioWriterCache(size int) chan *bufio.Writer { switch size { case 2 << 10: return bufioWriterCache2k @@ -429,55 +452,38 @@ func bufioWriterCache(size int) chan bufioWriterPair { return nil } -func newBufioReader(r io.Reader) (*bufio.Reader, *switchReader) { +func newBufioReader(r io.Reader) *bufio.Reader { select { case p := <-bufioReaderCache: - p.sr.Reader = r - return p.br, p.sr + p.Reset(r) + return p default: - sr := &switchReader{r} - return bufio.NewReader(sr), sr + return bufio.NewReader(r) } } -func putBufioReader(br *bufio.Reader, sr *switchReader) { - if n := br.Buffered(); n > 0 { - io.CopyN(ioutil.Discard, br, int64(n)) - } - br.Read(nil) // clears br.err - sr.Reader = nil +func putBufioReader(br *bufio.Reader) { + br.Reset(nil) select { - case bufioReaderCache <- bufioReaderPair{br, sr}: + case bufioReaderCache <- br: default: } } -func newBufioWriterSize(w io.Writer, size int) (*bufio.Writer, *switchWriter) { +func newBufioWriterSize(w io.Writer, size int) *bufio.Writer { select { case p := <-bufioWriterCache(size): - p.sw.Writer = w - return p.bw, p.sw + p.Reset(w) + return p default: - sw := &switchWriter{w} - return bufio.NewWriterSize(sw, size), sw + return bufio.NewWriterSize(w, size) } } -func putBufioWriter(bw *bufio.Writer, sw *switchWriter) { - if bw.Buffered() > 0 { - // It must have failed to flush to its target - // earlier. We can't reuse this bufio.Writer. - return - } - if err := bw.Flush(); err != nil { - // Its sticky error field is set, which is returned by - // Flush even when there's no data buffered. This - // bufio Writer is dead to us. Don't reuse it. - return - } - sw.Writer = nil +func putBufioWriter(bw *bufio.Writer) { + bw.Reset(nil) select { - case bufioWriterCache(bw.Available()) <- bufioWriterPair{bw, sw}: + case bufioWriterCache(bw.Available()) <- bw: default: } } @@ -508,7 +514,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { } if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() { ecr.resp.wroteContinue = true - io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n") + ecr.resp.conn.buf.WriteString("HTTP/1.1 100 Continue\r\n\r\n") ecr.resp.conn.buf.Flush() } return ecr.readCloser.Read(p) @@ -525,6 +531,28 @@ func (ecr *expectContinueReader) Close() error { // It is like time.RFC1123 but hard codes GMT as the time zone. const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT" +// appendTime is a non-allocating version of []byte(t.UTC().Format(TimeFormat)) +func appendTime(b []byte, t time.Time) []byte { + const days = "SunMonTueWedThuFriSat" + const months = "JanFebMarAprMayJunJulAugSepOctNovDec" + + t = t.UTC() + yy, mm, dd := t.Date() + hh, mn, ss := t.Clock() + day := days[3*t.Weekday():] + mon := months[3*(mm-1):] + + return append(b, + day[0], day[1], day[2], ',', ' ', + byte('0'+dd/10), byte('0'+dd%10), ' ', + mon[0], mon[1], mon[2], ' ', + byte('0'+yy/1000), byte('0'+(yy/100)%10), byte('0'+(yy/10)%10), byte('0'+yy%10), ' ', + byte('0'+hh/10), byte('0'+hh%10), ':', + byte('0'+mn/10), byte('0'+mn%10), ':', + byte('0'+ss/10), byte('0'+ss%10), ' ', + 'G', 'M', 'T') +} + var errTooLarge = errors.New("http: request too large") // Read next request from connection. @@ -562,7 +590,7 @@ func (c *conn) readRequest() (w *response, err error) { contentLength: -1, } w.cw.res = w - w.w, w.sw = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize) + w.w = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize) return w, nil } @@ -620,27 +648,45 @@ func (w *response) WriteHeader(code int) { // the response Header map and all its 1-element slices. type extraHeader struct { contentType string - contentLength string connection string - date string transferEncoding string + date []byte // written if not nil + contentLength []byte // written if not nil } // Sorted the same as extraHeader.Write's loop. var extraHeaderKeys = [][]byte{ - []byte("Content-Type"), []byte("Content-Length"), - []byte("Connection"), []byte("Date"), []byte("Transfer-Encoding"), + []byte("Content-Type"), + []byte("Connection"), + []byte("Transfer-Encoding"), } -// The value receiver, despite copying 5 strings to the stack, -// prevents an extra allocation. The escape analysis isn't smart -// enough to realize this doesn't mutate h. -func (h extraHeader) Write(w io.Writer) { - for i, v := range []string{h.contentType, h.contentLength, h.connection, h.date, h.transferEncoding} { +var ( + headerContentLength = []byte("Content-Length: ") + headerDate = []byte("Date: ") +) + +// Write writes the headers described in h to w. +// +// This method has a value receiver, despite the somewhat large size +// of h, because it prevents an allocation. The escape analysis isn't +// smart enough to realize this function doesn't mutate h. +func (h extraHeader) Write(w *bufio.Writer) { + if h.date != nil { + w.Write(headerDate) + w.Write(h.date) + w.Write(crlf) + } + if h.contentLength != nil { + w.Write(headerContentLength) + w.Write(h.contentLength) + w.Write(crlf) + } + for i, v := range []string{h.contentType, h.connection, h.transferEncoding} { if v != "" { w.Write(extraHeaderKeys[i]) w.Write(colonSpace) - io.WriteString(w, v) + w.WriteString(v) w.Write(crlf) } } @@ -661,6 +707,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { cw.wroteHeader = true w := cw.res + isHEAD := w.req.Method == "HEAD" // header is written out to w.conn.buf below. Depending on the // state of the handler, we either own the map or not. If we @@ -692,9 +739,17 @@ func (cw *chunkWriter) writeHeader(p []byte) { // response header and this is our first (and last) write, set // it, even to zero. This helps HTTP/1.0 clients keep their // "keep-alive" connections alive. - if w.handlerDone && header.get("Content-Length") == "" && w.req.Method != "HEAD" { + // Exceptions: 304 responses never get Content-Length, and if + // it was a HEAD request, we don't know the difference between + // 0 actual bytes and 0 bytes because the handler noticed it + // was a HEAD request and chose not to write anything. So for + // HEAD, the handler should either write the Content-Length or + // write non-zero bytes. If it's actually 0 bytes and the + // handler never looked at the Request.Method, we just don't + // send a Content-Length header. + if w.handlerDone && w.status != StatusNotModified && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { w.contentLength = int64(len(p)) - setHeader.contentLength = strconv.Itoa(len(p)) + setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10) } // If this was an HTTP/1.0 request with keep-alive and we sent a @@ -709,7 +764,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { // Check for a explicit (and valid) Content-Length header. hasCL := w.contentLength != -1 - if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { + if w.req.wantsHttp10KeepAlive() && (isHEAD || hasCL) { _, connectionHeaderSet := header["Connection"] if !connectionHeaderSet { setHeader.connection = "keep-alive" @@ -749,13 +804,14 @@ func (cw *chunkWriter) writeHeader(p []byte) { } } else { // If no content type, apply sniffing algorithm to body. - if header.get("Content-Type") == "" && w.req.Method != "HEAD" { + _, haveType := header["Content-Type"] + if !haveType { setHeader.contentType = DetectContentType(p) } } if _, ok := header["Date"]; !ok { - setHeader.date = time.Now().UTC().Format(TimeFormat) + setHeader.date = appendTime(cw.res.dateBuf[:0], time.Now()) } te := header.get("Transfer-Encoding") @@ -801,12 +857,14 @@ func (cw *chunkWriter) writeHeader(p []byte) { if w.closeAfterReply && !hasToken(cw.header.get("Connection"), "close") { delHeader("Connection") - setHeader.connection = "close" + if w.req.ProtoAtLeast(1, 1) { + setHeader.connection = "close" + } } - io.WriteString(w.conn.buf, statusLine(w.req, code)) + w.conn.buf.WriteString(statusLine(w.req, code)) cw.header.WriteSubset(w.conn.buf, excludeHeader) - setHeader.Write(w.conn.buf) + setHeader.Write(w.conn.buf.Writer) w.conn.buf.Write(crlf) } @@ -861,7 +919,7 @@ func (w *response) bodyAllowed() bool { if !w.wroteHeader { panic("") } - return w.status != StatusNotModified && w.req.Method != "HEAD" + return w.status != StatusNotModified } // The Life Of A Write is like this: @@ -897,6 +955,15 @@ func (w *response) bodyAllowed() bool { // bufferBeforeChunkingSize smaller and having bufio's fast-paths deal // with this instead. func (w *response) Write(data []byte) (n int, err error) { + return w.write(len(data), data, "") +} + +func (w *response) WriteString(data string) (n int, err error) { + return w.write(len(data), nil, data) +} + +// either dataB or dataS is non-zero. +func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) { if w.conn.hijacked() { log.Print("http: response.Write on hijacked connection") return 0, ErrHijacked @@ -904,18 +971,22 @@ func (w *response) Write(data []byte) (n int, err error) { if !w.wroteHeader { w.WriteHeader(StatusOK) } - if len(data) == 0 { + if lenData == 0 { return 0, nil } if !w.bodyAllowed() { return 0, ErrBodyNotAllowed } - w.written += int64(len(data)) // ignoring errors, for errorKludge + w.written += int64(lenData) // ignoring errors, for errorKludge if w.contentLength != -1 && w.written > w.contentLength { return 0, ErrContentLength } - return w.w.Write(data) + if dataB != nil { + return w.w.Write(dataB) + } else { + return w.w.WriteString(dataS) + } } func (w *response) finishRequest() { @@ -926,7 +997,7 @@ func (w *response) finishRequest() { } w.w.Flush() - putBufioWriter(w.w, w.sw) + putBufioWriter(w.w) w.cw.close() w.conn.buf.Flush() @@ -939,7 +1010,7 @@ func (w *response) finishRequest() { w.req.MultipartForm.RemoveAll() } - if w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written { + if w.req.Method != "HEAD" && w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written { // Did not write enough. Avoid getting out of sync. w.closeAfterReply = true } @@ -959,11 +1030,11 @@ func (c *conn) finalFlush() { // Steal the bufio.Reader (~4KB worth of memory) and its associated // reader for a future connection. - putBufioReader(c.buf.Reader, c.bufswr) + putBufioReader(c.buf.Reader) // Steal the bufio.Writer (~4KB worth of memory) and its associated // writer for a future connection. - putBufioWriter(c.buf.Writer, c.bufsww) + putBufioWriter(c.buf.Writer) c.buf = nil } @@ -1001,7 +1072,7 @@ func (c *conn) closeWriteAndWait() { time.Sleep(rstAvoidanceDelay) } -// validNPN returns whether the proto is not a blacklisted Next +// validNPN reports whether the proto is not a blacklisted Next // Protocol Negotiation protocol. Empty and built-in protocol types // are blacklisted and can't be overridden with alternate // implementations. @@ -1152,6 +1223,7 @@ func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { // Helper handlers // Error replies to the request with the specified error message and HTTP code. +// The error message should be plain text. func Error(w ResponseWriter, error string, code int) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.WriteHeader(code) @@ -1288,6 +1360,10 @@ func RedirectHandler(url string, code int) Handler { // former will receive requests for any other paths in the // "/images/" subtree. // +// Note that since a pattern ending in a slash names a rooted subtree, +// the pattern "/" matches all paths not matched by other registered +// patterns, not just the URL with Path == "/". +// // Patterns may optionally begin with a host name, restricting matches to // URLs on that host only. Host-specific patterns take precedence over // general patterns, so that a handler might register for the two patterns @@ -1378,7 +1454,9 @@ func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { if r.Method != "CONNECT" { if p := cleanPath(r.URL.Path); p != r.URL.Path { _, pattern = mux.handler(r.Host, p) - return RedirectHandler(p, StatusMovedPermanently), pattern + url := *r.URL + url.Path = p + return RedirectHandler(url.String(), StatusMovedPermanently), pattern } } @@ -1408,7 +1486,9 @@ func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) { // pattern most closely matches the request URL. func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { if r.RequestURI == "*" { - w.Header().Set("Connection", "close") + if r.ProtoAtLeast(1, 1) { + w.Header().Set("Connection", "close") + } w.WriteHeader(StatusBadRequest) return } @@ -1771,7 +1851,15 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { } // eofReader is a non-nil io.ReadCloser that always returns EOF. -var eofReader = ioutil.NopCloser(strings.NewReader("")) +// It embeds a *strings.Reader so it still has a WriteTo method +// and io.Copy won't need a buffer. +var eofReader = &struct { + *strings.Reader + io.Closer +}{ + strings.NewReader(""), + ioutil.NopCloser(nil), +} // initNPNRequest is an HTTP handler that initializes certain // uninitialized fields in its *Request. Such partially-initialized diff --git a/libgo/go/net/http/server_test.go b/libgo/go/net/http/server_test.go deleted file mode 100644 index e8b69f76cce..00000000000 --- a/libgo/go/net/http/server_test.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http_test - -import ( - . "net/http" - "net/http/httptest" - "net/url" - "testing" -) - -var serveMuxRegister = []struct { - pattern string - h Handler -}{ - {"/dir/", serve(200)}, - {"/search", serve(201)}, - {"codesearch.google.com/search", serve(202)}, - {"codesearch.google.com/", serve(203)}, -} - -// serve returns a handler that sends a response with the given code. -func serve(code int) HandlerFunc { - return func(w ResponseWriter, r *Request) { - w.WriteHeader(code) - } -} - -var serveMuxTests = []struct { - method string - host string - path string - code int - pattern string -}{ - {"GET", "google.com", "/", 404, ""}, - {"GET", "google.com", "/dir", 301, "/dir/"}, - {"GET", "google.com", "/dir/", 200, "/dir/"}, - {"GET", "google.com", "/dir/file", 200, "/dir/"}, - {"GET", "google.com", "/search", 201, "/search"}, - {"GET", "google.com", "/search/", 404, ""}, - {"GET", "google.com", "/search/foo", 404, ""}, - {"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"}, - {"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"}, - {"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"}, - {"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"}, - {"GET", "images.google.com", "/search", 201, "/search"}, - {"GET", "images.google.com", "/search/", 404, ""}, - {"GET", "images.google.com", "/search/foo", 404, ""}, - {"GET", "google.com", "/../search", 301, "/search"}, - {"GET", "google.com", "/dir/..", 301, ""}, - {"GET", "google.com", "/dir/..", 301, ""}, - {"GET", "google.com", "/dir/./file", 301, "/dir/"}, - - // The /foo -> /foo/ redirect applies to CONNECT requests - // but the path canonicalization does not. - {"CONNECT", "google.com", "/dir", 301, "/dir/"}, - {"CONNECT", "google.com", "/../search", 404, ""}, - {"CONNECT", "google.com", "/dir/..", 200, "/dir/"}, - {"CONNECT", "google.com", "/dir/..", 200, "/dir/"}, - {"CONNECT", "google.com", "/dir/./file", 200, "/dir/"}, -} - -func TestServeMuxHandler(t *testing.T) { - mux := NewServeMux() - for _, e := range serveMuxRegister { - mux.Handle(e.pattern, e.h) - } - - for _, tt := range serveMuxTests { - r := &Request{ - Method: tt.method, - Host: tt.host, - URL: &url.URL{ - Path: tt.path, - }, - } - h, pattern := mux.Handler(r) - rr := httptest.NewRecorder() - h.ServeHTTP(rr, r) - if pattern != tt.pattern || rr.Code != tt.code { - t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern) - } - } -} - -func TestServerRedirect(t *testing.T) { - // This used to crash. It's not valid input (bad path), but it - // shouldn't crash. - rr := httptest.NewRecorder() - req := &Request{ - Method: "GET", - URL: &url.URL{ - Scheme: "http", - Path: "not-empty-but-no-leading-slash", // bogus - }, - } - Redirect(rr, req, "", 304) - if rr.Code != 304 { - t.Errorf("Code = %d; want 304", rr.Code) - } -} diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go index 106d94ec1cb..24ca27afc16 100644 --- a/libgo/go/net/http/sniff_test.go +++ b/libgo/go/net/http/sniff_test.go @@ -12,6 +12,7 @@ import ( "log" . "net/http" "net/http/httptest" + "reflect" "strconv" "strings" "testing" @@ -84,6 +85,29 @@ func TestServerContentType(t *testing.T) { } } +// Issue 5953: shouldn't sniff if the handler set a Content-Type header, +// even if it's the empty string. +func TestServerIssue5953(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header()["Content-Type"] = []string{""} + fmt.Fprintf(w, "<html><head></head><body>hi</body></html>") + })) + defer ts.Close() + + resp, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + got := resp.Header["Content-Type"] + want := []string{""} + if !reflect.DeepEqual(got, want) { + t.Errorf("Content-Type = %q; want %q", got, want) + } + resp.Body.Close() +} + func TestContentTypeWithCopy(t *testing.T) { defer afterTest(t) diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index 53569bcc2fc..bacd83732de 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -238,7 +238,7 @@ type transferReader struct { Trailer Header } -// bodyAllowedForStatus returns whether a given response status code +// bodyAllowedForStatus reports whether a given response status code // permits a body. See RFC2616, section 4.4. func bodyAllowedForStatus(status int) bool { switch { @@ -254,7 +254,7 @@ func bodyAllowedForStatus(status int) bool { // msg is *Request or *Response. func readTransfer(msg interface{}, r *bufio.Reader) (err error) { - t := &transferReader{} + t := &transferReader{RequestMethod: "GET"} // Unify input isResponse := false @@ -262,11 +262,13 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { case *Response: t.Header = rr.Header t.StatusCode = rr.StatusCode - t.RequestMethod = rr.Request.Method t.ProtoMajor = rr.ProtoMajor t.ProtoMinor = rr.ProtoMinor t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header) isResponse = true + if rr.Request != nil { + t.RequestMethod = rr.Request.Method + } case *Request: t.Header = rr.Header t.ProtoMajor = rr.ProtoMajor @@ -274,7 +276,6 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // Transfer semantics for Requests are exactly like those for // Responses with status code 200, responding to a GET method t.StatusCode = 200 - t.RequestMethod = "GET" default: panic("unexpected type") } @@ -328,12 +329,12 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { switch { case chunked(t.TransferEncoding): if noBodyExpected(t.RequestMethod) { - t.Body = &body{Reader: eofReader, closing: t.Close} + t.Body = eofReader } else { t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} } case realLength == 0: - t.Body = &body{Reader: eofReader, closing: t.Close} + t.Body = eofReader case realLength > 0: t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close} default: @@ -343,7 +344,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { t.Body = &body{Reader: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) - t.Body = &body{Reader: eofReader, closing: t.Close} + t.Body = eofReader } } @@ -518,8 +519,6 @@ type body struct { r *bufio.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? closed bool - - res *response // response writer for server requests, else nil } // ErrBodyReadAfterClose is returned when reading a Request or Response @@ -534,13 +533,22 @@ func (b *body) Read(p []byte) (n int, err error) { } n, err = b.Reader.Read(p) - // Read the final trailer once we hit EOF. - if err == io.EOF && b.hdr != nil { - if e := b.readTrailer(); e != nil { - err = e + if err == io.EOF { + // Chunked case. Read the trailer. + if b.hdr != nil { + if e := b.readTrailer(); e != nil { + err = e + } + b.hdr = nil + } else { + // If the server declared the Content-Length, our body is a LimitedReader + // and we need to check whether this EOF arrived early. + if lr, ok := b.Reader.(*io.LimitedReader); ok && lr.N > 0 { + err = io.ErrUnexpectedEOF + } } - b.hdr = nil } + return n, err } @@ -618,14 +626,6 @@ func (b *body) Close() error { case b.hdr == nil && b.closing: // no trailer and closing the connection next. // no point in reading to EOF. - case b.res != nil && b.res.requestBodyLimitHit: - // In a server request, don't continue reading from the client - // if we've already hit the maximum body size set by the - // handler. If this is set, that also means the TCP connection - // is about to be closed, so getting to the next HTTP request - // in the stream is not necessary. - case b.Reader == eofReader: - // Nothing to read. No need to io.Copy from it. default: // Fully consume the body, which will also lead to us reading // the trailer headers after the body, if present. diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 4cd0533ffc2..f6871afacd7 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -13,7 +13,6 @@ import ( "bufio" "compress/gzip" "crypto/tls" - "encoding/base64" "errors" "fmt" "io" @@ -109,9 +108,11 @@ func ProxyFromEnvironment(req *Request) (*url.URL, error) { } proxyURL, err := url.Parse(proxy) if err != nil || !strings.HasPrefix(proxyURL.Scheme, "http") { - if u, err := url.Parse("http://" + proxy); err == nil { - proxyURL = u - err = nil + // proxy was bogus. Try prepending "http://" to it and + // see if that parses correctly. If not, we fall + // through and complain about the original one. + if proxyURL, err := url.Parse("http://" + proxy); err == nil { + return proxyURL, nil } } if err != nil { @@ -215,6 +216,7 @@ func (t *Transport) CloseIdleConnections() { t.idleMu.Lock() m := t.idleConn t.idleConn = nil + t.idleConnCh = nil t.idleMu.Unlock() if m == nil { return @@ -270,7 +272,9 @@ func (cm *connectMethod) proxyAuth() string { return "" } if u := cm.proxyURL.User; u != nil { - return "Basic " + base64.URLEncoding.EncodeToString([]byte(u.String())) + username := u.Username() + password, _ := u.Password() + return "Basic " + basicAuth(username, password) } return "" } @@ -293,8 +297,10 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { max = DefaultMaxIdleConnsPerHost } t.idleMu.Lock() + + waitingDialer := t.idleConnCh[key] select { - case t.idleConnCh[key] <- pconn: + case waitingDialer <- pconn: // We're done with this pconn and somebody else is // currently waiting for a conn of this type (they're // actively dialing, but this conn is ready @@ -303,6 +309,11 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { t.idleMu.Unlock() return true default: + if waitingDialer != nil { + // They had populated this, but their dial won + // first, so we can clean up this map entry. + delete(t.idleConnCh, key) + } } if t.idleConn == nil { t.idleConn = make(map[string][]*persistConn) @@ -322,7 +333,13 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { return true } +// getIdleConnCh returns a channel to receive and return idle +// persistent connection for the given connectMethod. +// It may return nil, if persistent connections are not being used. func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn { + if t.DisableKeepAlives { + return nil + } key := cm.key() t.idleMu.Lock() defer t.idleMu.Unlock() @@ -498,8 +515,8 @@ func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) { if err = conn.(*tls.Conn).Handshake(); err != nil { return nil, err } - if t.TLSClientConfig == nil || !t.TLSClientConfig.InsecureSkipVerify { - if err = conn.(*tls.Conn).VerifyHostname(cm.tlsHost()); err != nil { + if !cfg.InsecureSkipVerify { + if err = conn.(*tls.Conn).VerifyHostname(cfg.ServerName); err != nil { return nil, err } } @@ -831,10 +848,15 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // uncompress the gzip stream if we were the layer that // requested it. requestedGzip := false - if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" { + if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Method != "HEAD" { // Request gzip only, not deflate. Deflate is ambiguous and // not as universally supported anyway. // See: http://www.gzip.org/zlib/zlib_faq.html#faq38 + // + // Note that we don't request this for HEAD requests, + // due to a bug in nginx: + // http://trac.nginx.org/nginx/ticket/358 + // http://golang.org/issue/5522 requestedGzip = true req.extraHeaders().Set("Accept-Encoding", "gzip") } diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index 9f64a6e4b5f..e4df30a98de 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -15,6 +15,7 @@ import ( "io" "io/ioutil" "net" + "net/http" . "net/http" "net/http/httptest" "net/url" @@ -469,6 +470,7 @@ func TestTransportHeadResponses(t *testing.T) { res, err := c.Head(ts.URL) if err != nil { t.Errorf("error on loop %d: %v", i, err) + continue } if e, g := "123", res.Header.Get("Content-Length"); e != g { t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) @@ -476,6 +478,11 @@ func TestTransportHeadResponses(t *testing.T) { if e, g := int64(123), res.ContentLength; e != g { t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) } + if all, err := ioutil.ReadAll(res.Body); err != nil { + t.Errorf("loop %d: Body ReadAll: %v", i, err) + } else if len(all) != 0 { + t.Errorf("Bogus body %q", all) + } } } @@ -553,12 +560,13 @@ func TestRoundTripGzip(t *testing.T) { res, err := DefaultTransport.RoundTrip(req) var body []byte if test.compressed { - gzip, err := gzip.NewReader(res.Body) + var r *gzip.Reader + r, err = gzip.NewReader(res.Body) if err != nil { t.Errorf("%d. gzip NewReader: %v", i, err) continue } - body, err = ioutil.ReadAll(gzip) + body, err = ioutil.ReadAll(r) res.Body.Close() } else { body, err = ioutil.ReadAll(res.Body) @@ -585,13 +593,16 @@ func TestTransportGzip(t *testing.T) { const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + if req.Method == "HEAD" { + if g := req.Header.Get("Accept-Encoding"); g != "" { + t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) + } + return + } if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { t.Errorf("Accept-Encoding = %q, want %q", g, e) } rw.Header().Set("Content-Encoding", "gzip") - if req.Method == "HEAD" { - return - } var w io.Writer = rw var buf bytes.Buffer @@ -819,7 +830,7 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) { } nhigh := runtime.NumGoroutine() tr.CloseIdleConnections() - time.Sleep(50 * time.Millisecond) + time.Sleep(400 * time.Millisecond) runtime.GC() nfinal := runtime.NumGoroutine() @@ -1571,6 +1582,77 @@ func TestProxyFromEnvironment(t *testing.T) { } } +func TestIdleConnChannelLeak(t *testing.T) { + var mu sync.Mutex + var n int + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + mu.Lock() + n++ + mu.Unlock() + })) + defer ts.Close() + + tr := &Transport{ + Dial: func(netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) + }, + } + defer tr.CloseIdleConnections() + + c := &Client{Transport: tr} + + // First, without keep-alives. + for _, disableKeep := range []bool{true, false} { + tr.DisableKeepAlives = disableKeep + for i := 0; i < 5; i++ { + _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i)) + if err != nil { + t.Fatal(err) + } + } + if got := tr.IdleConnChMapSizeForTesting(); got != 0 { + t.Fatalf("ForDisableKeepAlives = %v, map size = %d; want 0", disableKeep, got) + } + } +} + +// Verify the status quo: that the Client.Post function coerces its +// body into a ReadCloser if it's a Closer, and that the Transport +// then closes it. +func TestTransportClosesRequestBody(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(http.HandlerFunc(func(w ResponseWriter, r *Request) { + io.Copy(ioutil.Discard, r.Body) + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + cl := &Client{Transport: tr} + + closes := 0 + + res, err := cl.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if closes != 1 { + t.Errorf("closes = %d; want 1", closes) + } +} + +type countCloseReader struct { + n *int + io.Reader +} + +func (cr countCloseReader) Close() error { + (*cr.n)++ + return nil +} + // rgz is a gzip quine that uncompresses to itself. var rgz = []byte{ 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, diff --git a/libgo/go/net/http/z_last_test.go b/libgo/go/net/http/z_last_test.go index 2161db7365d..5a0cc119849 100644 --- a/libgo/go/net/http/z_last_test.go +++ b/libgo/go/net/http/z_last_test.go @@ -23,7 +23,6 @@ func interestingGoroutines() (gs []string) { } stack := strings.TrimSpace(sl[1]) if stack == "" || - strings.Contains(stack, "created by net.newPollServer") || strings.Contains(stack, "created by net.startServer") || strings.Contains(stack, "created by testing.RunTests") || strings.Contains(stack, "closeWriteAndWait") || diff --git a/libgo/go/net/interface_bsd.go b/libgo/go/net/interface_bsd.go index 716b60a97f4..16775579d05 100644 --- a/libgo/go/net/interface_bsd.go +++ b/libgo/go/net/interface_bsd.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd netbsd openbsd +// +build darwin dragonfly freebsd netbsd openbsd package net diff --git a/libgo/go/net/interface_dragonfly.go b/libgo/go/net/interface_dragonfly.go new file mode 100644 index 00000000000..c9ce5a7ac15 --- /dev/null +++ b/libgo/go/net/interface_dragonfly.go @@ -0,0 +1,12 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + // TODO(mikio): Implement this like other platforms. + return nil, nil +} diff --git a/libgo/go/net/interface_test.go b/libgo/go/net/interface_test.go index e31894abf73..efabb5f3c25 100644 --- a/libgo/go/net/interface_test.go +++ b/libgo/go/net/interface_test.go @@ -108,12 +108,23 @@ func testInterfaceMulticastAddrs(t *testing.T, ifi *Interface) { func testAddrs(t *testing.T, ifat []Addr) { for _, ifa := range ifat { switch ifa := ifa.(type) { - case *IPAddr, *IPNet: - if ifa == nil { - t.Errorf("\tunexpected value: %v", ifa) + case *IPAddr: + if ifa == nil || ifa.IP == nil { + t.Errorf("\tunexpected value: %v, %v", ifa, ifa.IP) } else { t.Logf("\tinterface address %q", ifa.String()) } + case *IPNet: + if ifa == nil || ifa.IP == nil || ifa.Mask == nil { + t.Errorf("\tunexpected value: %v, %v, %v", ifa, ifa.IP, ifa.Mask) + } else { + _, prefixLen := ifa.Mask.Size() + if ifa.IP.To4() != nil && prefixLen != 8*IPv4len || ifa.IP.To16() != nil && ifa.IP.To4() == nil && prefixLen != 8*IPv6len { + t.Errorf("\tunexpected value: %v, %v, %v, %v", ifa, ifa.IP, ifa.Mask, prefixLen) + } else { + t.Logf("\tinterface address %q", ifa.String()) + } + } default: t.Errorf("\tunexpected type: %T", ifa) } diff --git a/libgo/go/net/ip.go b/libgo/go/net/ip.go index 0e42da21683..fd6a7d4ee8b 100644 --- a/libgo/go/net/ip.go +++ b/libgo/go/net/ip.go @@ -12,6 +12,8 @@ package net +import "errors" + // IP address lengths (bytes). const ( IPv4len = 4 @@ -310,6 +312,43 @@ func (ip IP) String() string { return s } +// ipEmptyString is like ip.String except that it returns +// an empty string when ip is unset. +func ipEmptyString(ip IP) string { + if len(ip) == 0 { + return "" + } + return ip.String() +} + +// MarshalText implements the encoding.TextMarshaler interface. +// The encoding is the same as returned by String. +func (ip IP) MarshalText() ([]byte, error) { + if len(ip) == 0 { + return []byte(""), nil + } + if len(ip) != IPv4len && len(ip) != IPv6len { + return nil, errors.New("invalid IP address") + } + return []byte(ip.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// The IP address is expected in a form accepted by ParseIP. +func (ip *IP) UnmarshalText(text []byte) error { + if len(text) == 0 { + *ip = nil + return nil + } + s := string(text) + x := ParseIP(s) + if x == nil { + return &ParseError{"IP address", s} + } + *ip = x + return nil +} + // Equal returns true if ip and x are the same IP address. // An IPv4 address and that same address in IPv6 form are // considered to be equal. diff --git a/libgo/go/net/ip_test.go b/libgo/go/net/ip_test.go index 16f30d446b5..26b53729b85 100644 --- a/libgo/go/net/ip_test.go +++ b/libgo/go/net/ip_test.go @@ -32,6 +32,32 @@ func TestParseIP(t *testing.T) { if out := ParseIP(tt.in); !reflect.DeepEqual(out, tt.out) { t.Errorf("ParseIP(%q) = %v, want %v", tt.in, out, tt.out) } + if tt.in == "" { + // Tested in TestMarshalEmptyIP below. + continue + } + var out IP + if err := out.UnmarshalText([]byte(tt.in)); !reflect.DeepEqual(out, tt.out) || (tt.out == nil) != (err != nil) { + t.Errorf("IP.UnmarshalText(%q) = %v, %v, want %v", tt.in, out, err, tt.out) + } + } +} + +// Issue 6339 +func TestMarshalEmptyIP(t *testing.T) { + for _, in := range [][]byte{nil, []byte("")} { + var out = IP{1, 2, 3, 4} + if err := out.UnmarshalText(in); err != nil || out != nil { + t.Errorf("UnmarshalText(%v) = %v, %v; want nil, nil", in, out, err) + } + } + var ip IP + got, err := ip.MarshalText() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, []byte("")) { + t.Errorf(`got %#v, want []byte("")`, got) } } @@ -47,13 +73,19 @@ var ipStringTests = []struct { {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0x1, 0, 0, 0, 0, 0, 0}, "2001:db8:0:0:1::"}, {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0x1, 0, 0, 0, 0, 0, 0x1}, "2001:db8::1:0:0:1"}, {IP{0x20, 0x1, 0xD, 0xB8, 0, 0, 0, 0, 0, 0xA, 0, 0xB, 0, 0xC, 0, 0xD}, "2001:db8::a:b:c:d"}, - {nil, "<nil>"}, + {IPv4(192, 168, 0, 1), "192.168.0.1"}, + {nil, ""}, } func TestIPString(t *testing.T) { for _, tt := range ipStringTests { - if out := tt.in.String(); out != tt.out { - t.Errorf("IP.String(%v) = %q, want %q", tt.in, out, tt.out) + if tt.in != nil { + if out := tt.in.String(); out != tt.out { + t.Errorf("IP.String(%v) = %q, want %q", tt.in, out, tt.out) + } + } + if out, err := tt.in.MarshalText(); string(out) != tt.out || err != nil { + t.Errorf("IP.MarshalText(%v) = %q, %v, want %q, nil", tt.in, out, err, tt.out) } } } diff --git a/libgo/go/net/ipraw_test.go b/libgo/go/net/ipraw_test.go index 12c199d1cf4..ea183f1d3eb 100644 --- a/libgo/go/net/ipraw_test.go +++ b/libgo/go/net/ipraw_test.go @@ -6,19 +6,19 @@ package net import ( "bytes" - "errors" "fmt" "os" "reflect" + "runtime" "testing" "time" ) type resolveIPAddrTest struct { - net string - litAddr string - addr *IPAddr - err error + net string + litAddrOrName string + addr *IPAddr + err error } var resolveIPAddrTests = []resolveIPAddrTest{ @@ -29,6 +29,7 @@ var resolveIPAddrTests = []resolveIPAddrTest{ {"ip", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, {"ip6", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, {"ip6:ipv6-icmp", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, + {"ip6:IPv6-ICMP", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, {"ip", "::1%en0", &IPAddr{IP: ParseIP("::1"), Zone: "en0"}, nil}, {"ip6", "::1%911", &IPAddr{IP: ParseIP("::1"), Zone: "911"}, nil}, @@ -49,13 +50,28 @@ func init() { {"ip6", "fe80::1%" + index, &IPAddr{IP: ParseIP("fe80::1"), Zone: index}, nil}, }...) } + if ips, err := LookupIP("localhost"); err == nil && len(ips) > 1 && supportsIPv4 && supportsIPv6 { + resolveIPAddrTests = append(resolveIPAddrTests, []resolveIPAddrTest{ + {"ip", "localhost", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, + {"ip4", "localhost", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, + {"ip6", "localhost", &IPAddr{IP: IPv6loopback}, nil}, + }...) + } +} + +func skipRawSocketTest(t *testing.T) (skip bool, skipmsg string) { + skip, skipmsg, err := skipRawSocketTests() + if err != nil { + t.Fatal(err) + } + return skip, skipmsg } func TestResolveIPAddr(t *testing.T) { for _, tt := range resolveIPAddrTests { - addr, err := ResolveIPAddr(tt.net, tt.litAddr) + addr, err := ResolveIPAddr(tt.net, tt.litAddrOrName) if err != tt.err { - condFatalf(t, "ResolveIPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err) + t.Fatalf("ResolveIPAddr(%v, %v) failed: %v", tt.net, tt.litAddrOrName, err) } else if !reflect.DeepEqual(addr, tt.addr) { t.Fatalf("got %#v; expected %#v", addr, tt.addr) } @@ -72,8 +88,8 @@ var icmpEchoTests = []struct { } func TestConnICMPEcho(t *testing.T) { - if os.Getuid() != 0 { - t.Skip("skipping test; must be root") + if skip, skipmsg := skipRawSocketTest(t); skip { + t.Skip(skipmsg) } for i, tt := range icmpEchoTests { @@ -97,7 +113,7 @@ func TestConnICMPEcho(t *testing.T) { typ = icmpv6EchoRequest } xid, xseq := os.Getpid()&0xffff, i+1 - b, err := (&icmpMessage{ + wb, err := (&icmpMessage{ Type: typ, Code: 0, Body: &icmpEcho{ ID: xid, Seq: xseq, @@ -107,18 +123,19 @@ func TestConnICMPEcho(t *testing.T) { if err != nil { t.Fatalf("icmpMessage.Marshal failed: %v", err) } - if _, err := c.Write(b); err != nil { + if _, err := c.Write(wb); err != nil { t.Fatalf("Conn.Write failed: %v", err) } var m *icmpMessage + rb := make([]byte, 20+len(wb)) for { - if _, err := c.Read(b); err != nil { + if _, err := c.Read(rb); err != nil { t.Fatalf("Conn.Read failed: %v", err) } if net == "ip4" { - b = ipv4Payload(b) + rb = ipv4Payload(rb) } - if m, err = parseICMPMessage(b); err != nil { + if m, err = parseICMPMessage(rb); err != nil { t.Fatalf("parseICMPMessage failed: %v", err) } switch m.Type { @@ -139,8 +156,8 @@ func TestConnICMPEcho(t *testing.T) { } func TestPacketConnICMPEcho(t *testing.T) { - if os.Getuid() != 0 { - t.Skip("skipping test; must be root") + if skip, skipmsg := skipRawSocketTest(t); skip { + t.Skip(skipmsg) } for i, tt := range icmpEchoTests { @@ -168,7 +185,7 @@ func TestPacketConnICMPEcho(t *testing.T) { typ = icmpv6EchoRequest } xid, xseq := os.Getpid()&0xffff, i+1 - b, err := (&icmpMessage{ + wb, err := (&icmpMessage{ Type: typ, Code: 0, Body: &icmpEcho{ ID: xid, Seq: xseq, @@ -178,19 +195,20 @@ func TestPacketConnICMPEcho(t *testing.T) { if err != nil { t.Fatalf("icmpMessage.Marshal failed: %v", err) } - if _, err := c.WriteTo(b, ra); err != nil { + if _, err := c.WriteTo(wb, ra); err != nil { t.Fatalf("PacketConn.WriteTo failed: %v", err) } var m *icmpMessage + rb := make([]byte, 20+len(wb)) for { - if _, _, err := c.ReadFrom(b); err != nil { + if _, _, err := c.ReadFrom(rb); err != nil { t.Fatalf("PacketConn.ReadFrom failed: %v", err) } - // TODO: fix issue 3944 + // See BUG section. //if net == "ip4" { - // b = ipv4Payload(b) + // rb = ipv4Payload(rb) //} - if m, err = parseICMPMessage(b); err != nil { + if m, err = parseICMPMessage(rb); err != nil { t.Fatalf("parseICMPMessage failed: %v", err) } switch m.Type { @@ -218,115 +236,6 @@ func ipv4Payload(b []byte) []byte { return b[hdrlen:] } -const ( - icmpv4EchoRequest = 8 - icmpv4EchoReply = 0 - icmpv6EchoRequest = 128 - icmpv6EchoReply = 129 -) - -// icmpMessage represents an ICMP message. -type icmpMessage struct { - Type int // type - Code int // code - Checksum int // checksum - Body icmpMessageBody // body -} - -// icmpMessageBody represents an ICMP message body. -type icmpMessageBody interface { - Len() int - Marshal() ([]byte, error) -} - -// Marshal returns the binary enconding of the ICMP echo request or -// reply message m. -func (m *icmpMessage) Marshal() ([]byte, error) { - b := []byte{byte(m.Type), byte(m.Code), 0, 0} - if m.Body != nil && m.Body.Len() != 0 { - mb, err := m.Body.Marshal() - if err != nil { - return nil, err - } - b = append(b, mb...) - } - switch m.Type { - case icmpv6EchoRequest, icmpv6EchoReply: - return b, nil - } - csumcv := len(b) - 1 // checksum coverage - s := uint32(0) - for i := 0; i < csumcv; i += 2 { - s += uint32(b[i+1])<<8 | uint32(b[i]) - } - if csumcv&1 == 0 { - s += uint32(b[csumcv]) - } - s = s>>16 + s&0xffff - s = s + s>>16 - // Place checksum back in header; using ^= avoids the - // assumption the checksum bytes are zero. - b[2] ^= byte(^s & 0xff) - b[3] ^= byte(^s >> 8) - return b, nil -} - -// parseICMPMessage parses b as an ICMP message. -func parseICMPMessage(b []byte) (*icmpMessage, error) { - msglen := len(b) - if msglen < 4 { - return nil, errors.New("message too short") - } - m := &icmpMessage{Type: int(b[0]), Code: int(b[1]), Checksum: int(b[2])<<8 | int(b[3])} - if msglen > 4 { - var err error - switch m.Type { - case icmpv4EchoRequest, icmpv4EchoReply, icmpv6EchoRequest, icmpv6EchoReply: - m.Body, err = parseICMPEcho(b[4:]) - if err != nil { - return nil, err - } - } - } - return m, nil -} - -// imcpEcho represenets an ICMP echo request or reply message body. -type icmpEcho struct { - ID int // identifier - Seq int // sequence number - Data []byte // data -} - -func (p *icmpEcho) Len() int { - if p == nil { - return 0 - } - return 4 + len(p.Data) -} - -// Marshal returns the binary enconding of the ICMP echo request or -// reply message body p. -func (p *icmpEcho) Marshal() ([]byte, error) { - b := make([]byte, 4+len(p.Data)) - b[0], b[1] = byte(p.ID>>8), byte(p.ID&0xff) - b[2], b[3] = byte(p.Seq>>8), byte(p.Seq&0xff) - copy(b[4:], p.Data) - return b, nil -} - -// parseICMPEcho parses b as an ICMP echo request or reply message -// body. -func parseICMPEcho(b []byte) (*icmpEcho, error) { - bodylen := len(b) - p := &icmpEcho{ID: int(b[0])<<8 | int(b[1]), Seq: int(b[2])<<8 | int(b[3])} - if bodylen > 4 { - p.Data = make([]byte, bodylen-4) - copy(p.Data, b[4:]) - } - return p, nil -} - var ipConnLocalNameTests = []struct { net string laddr *IPAddr @@ -337,8 +246,13 @@ var ipConnLocalNameTests = []struct { } func TestIPConnLocalName(t *testing.T) { - if os.Getuid() != 0 { - t.Skip("skipping test; must be root") + switch runtime.GOOS { + case "plan9", "windows": + t.Skipf("skipping test on %q", runtime.GOOS) + default: + if os.Getuid() != 0 { + t.Skip("skipping test; must be root") + } } for _, tt := range ipConnLocalNameTests { @@ -354,8 +268,13 @@ func TestIPConnLocalName(t *testing.T) { } func TestIPConnRemoteName(t *testing.T) { - if os.Getuid() != 0 { - t.Skip("skipping test; must be root") + switch runtime.GOOS { + case "plan9", "windows": + t.Skipf("skipping test on %q", runtime.GOOS) + default: + if os.Getuid() != 0 { + t.Skip("skipping test; must be root") + } } raddr := &IPAddr{IP: IPv4(127, 0, 0, 10).To4()} diff --git a/libgo/go/net/iprawsock.go b/libgo/go/net/iprawsock.go index 0be94eb70eb..5cc361390ff 100644 --- a/libgo/go/net/iprawsock.go +++ b/libgo/go/net/iprawsock.go @@ -23,6 +23,13 @@ func (a *IPAddr) String() string { return a.IP.String() } +func (a *IPAddr) toAddr() Addr { + if a == nil { + return nil + } + return a +} + // ResolveIPAddr parses addr as an IP address of the form "host" or // "ipv6-host%zone" and resolves the domain name on the network net, // which must be "ip", "ip4" or "ip6". @@ -43,5 +50,5 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) { if err != nil { return nil, err } - return a.(*IPAddr), nil + return a.toAddr().(*IPAddr), nil } diff --git a/libgo/go/net/iprawsock_posix.go b/libgo/go/net/iprawsock_posix.go index caeeb465383..72285325761 100644 --- a/libgo/go/net/iprawsock_posix.go +++ b/libgo/go/net/iprawsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux netbsd openbsd windows package net @@ -11,6 +11,18 @@ import ( "time" ) +// BUG(mikio): On every POSIX platform, reads from the "ip4" network +// using the ReadFrom or ReadFromIP method might not return a complete +// IPv4 packet, including its header, even if there is space +// available. This can occur even in cases where Read or ReadMsgIP +// could return a complete packet. For this reason, it is recommended +// that you do not uses these methods if it is important to receive a +// full packet. +// +// The Go 1 compatibliity guidelines make it impossible for us to +// change the behavior of these methods; use Read or ReadMsgIP +// instead. + func sockaddrToIP(sa syscall.Sockaddr) Addr { switch sa := sa.(type) { case *syscall.SockaddrInet4: @@ -39,14 +51,10 @@ func (a *IPAddr) isWildcard() bool { } func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, error) { - return ipToSockaddr(family, a.IP, 0, a.Zone) -} - -func (a *IPAddr) toAddr() sockaddr { - if a == nil { // nil *IPAddr - return nil // nil interface + if a == nil { + return nil, nil } - return a + return ipToSockaddr(family, a.IP, 0, a.Zone) } // IPConn is the implementation of the Conn and PacketConn interfaces @@ -125,6 +133,9 @@ func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) { if !c.ok() { return 0, syscall.EINVAL } + if addr == nil { + return 0, &OpError{Op: "write", Net: c.fd.net, Addr: nil, Err: errMissingAddress} + } sa, err := addr.sockaddr(c.fd.family) if err != nil { return 0, &OpError{"write", c.fd.net, addr, err} @@ -151,6 +162,9 @@ func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error if !c.ok() { return 0, 0, syscall.EINVAL } + if addr == nil { + return 0, 0, &OpError{Op: "write", Net: c.fd.net, Addr: nil, Err: errMissingAddress} + } sa, err := addr.sockaddr(c.fd.family) if err != nil { return 0, 0, &OpError{"write", c.fd.net, addr, err} @@ -168,19 +182,19 @@ func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) { net, proto, err := parseNetwork(netProto) if err != nil { - return nil, err + return nil, &OpError{Op: "dial", Net: netProto, Addr: raddr, Err: err} } switch net { case "ip", "ip4", "ip6": default: - return nil, UnknownNetworkError(netProto) + return nil, &OpError{Op: "dial", Net: netProto, Addr: raddr, Err: UnknownNetworkError(netProto)} } if raddr == nil { - return nil, &OpError{"dial", netProto, nil, errMissingAddress} + return nil, &OpError{Op: "dial", Net: netProto, Addr: nil, Err: errMissingAddress} } - fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), deadline, syscall.SOCK_RAW, proto, "dial", sockaddrToIP) + fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_RAW, proto, "dial", sockaddrToIP) if err != nil { - return nil, err + return nil, &OpError{Op: "dial", Net: netProto, Addr: raddr, Err: err} } return newIPConn(fd), nil } @@ -192,16 +206,16 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) { net, proto, err := parseNetwork(netProto) if err != nil { - return nil, err + return nil, &OpError{Op: "dial", Net: netProto, Addr: laddr, Err: err} } switch net { case "ip", "ip4", "ip6": default: - return nil, UnknownNetworkError(netProto) + return nil, &OpError{Op: "listen", Net: netProto, Addr: laddr, Err: UnknownNetworkError(netProto)} } - fd, err := internetSocket(net, laddr.toAddr(), nil, noDeadline, syscall.SOCK_RAW, proto, "listen", sockaddrToIP) + fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_RAW, proto, "listen", sockaddrToIP) if err != nil { - return nil, err + return nil, &OpError{Op: "listen", Net: netProto, Addr: laddr, Err: err} } return newIPConn(fd), nil } diff --git a/libgo/go/net/ipsock.go b/libgo/go/net/ipsock.go index d930595879c..8b586ef7c3e 100644 --- a/libgo/go/net/ipsock.go +++ b/libgo/go/net/ipsock.go @@ -6,68 +6,135 @@ package net -import "time" +import ( + "errors" + "time" +) -var supportsIPv6, supportsIPv4map bool +var ( + // supportsIPv4 reports whether the platform supports IPv4 + // networking functionality. + supportsIPv4 bool + + // supportsIPv6 reports whether the platfrom supports IPv6 + // networking functionality. + supportsIPv6 bool + + // supportsIPv4map reports whether the platform supports + // mapping an IPv4 address inside an IPv6 address at transport + // layer protocols. See RFC 4291, RFC 4038 and RFC 3493. + supportsIPv4map bool +) func init() { sysInit() + supportsIPv4 = probeIPv4Stack() supportsIPv6, supportsIPv4map = probeIPv6Stack() } -func firstFavoriteAddr(filter func(IP) IP, addrs []string) (addr IP) { - if filter == nil { - // We'll take any IP address, but since the dialing code - // does not yet try multiple addresses, prefer to use - // an IPv4 address if possible. This is especially relevant - // if localhost resolves to [ipv6-localhost, ipv4-localhost]. - // Too much code assumes localhost == ipv4-localhost. - addr = firstSupportedAddr(ipv4only, addrs) - if addr == nil { - addr = firstSupportedAddr(anyaddr, addrs) - } - } else { - addr = firstSupportedAddr(filter, addrs) +// A netaddr represents a network endpoint address or a list of +// network endpoint addresses. +type netaddr interface { + // toAddr returns the address represented in Addr interface. + // It returns a nil interface when the address is nil. + toAddr() Addr +} + +// An addrList represents a list of network endpoint addresses. +type addrList []netaddr + +func (al addrList) toAddr() Addr { + switch len(al) { + case 0: + return nil + case 1: + return al[0].toAddr() + default: + // For now, we'll roughly pick first one without + // considering dealing with any preferences such as + // DNS TTL, transport path quality, network routing + // information. + return al[0].toAddr() } - return } -func firstSupportedAddr(filter func(IP) IP, addrs []string) IP { - for _, s := range addrs { - if addr := filter(ParseIP(s)); addr != nil { - return addr +var errNoSuitableAddress = errors.New("no suitable address found") + +// firstFavoriteAddr returns an address or a list of addresses that +// implement the netaddr interface. Known filters are nil, ipv4only +// and ipv6only. It returns any address when filter is nil. The result +// contains at least one address when error is nil. +func firstFavoriteAddr(filter func(IP) IP, ips []IP, inetaddr func(IP) netaddr) (netaddr, error) { + if filter != nil { + return firstSupportedAddr(filter, ips, inetaddr) + } + var ( + ipv4, ipv6, swap bool + list addrList + ) + for _, ip := range ips { + // We'll take any IP address, but since the dialing + // code does not yet try multiple addresses + // effectively, prefer to use an IPv4 address if + // possible. This is especially relevant if localhost + // resolves to [ipv6-localhost, ipv4-localhost]. Too + // much code assumes localhost == ipv4-localhost. + if ip4 := ipv4only(ip); ip4 != nil && !ipv4 { + list = append(list, inetaddr(ip4)) + ipv4 = true + if ipv6 { + swap = true + } + } else if ip6 := ipv6only(ip); ip6 != nil && !ipv6 { + list = append(list, inetaddr(ip6)) + ipv6 = true + } + if ipv4 && ipv6 { + if swap { + list[0], list[1] = list[1], list[0] + } + break } } - return nil + switch len(list) { + case 0: + return nil, errNoSuitableAddress + case 1: + return list[0], nil + default: + return list, nil + } } -func anyaddr(x IP) IP { - if x4 := x.To4(); x4 != nil { - return x4 +func firstSupportedAddr(filter func(IP) IP, ips []IP, inetaddr func(IP) netaddr) (netaddr, error) { + for _, ip := range ips { + if ip := filter(ip); ip != nil { + return inetaddr(ip), nil + } } - if supportsIPv6 { - return x + return nil, errNoSuitableAddress +} + +// ipv4only returns IPv4 addresses that we can use with the kernel's +// IPv4 addressing modes. If ip is an IPv4 address, ipv4only returns ip. +// Otherwise it returns nil. +func ipv4only(ip IP) IP { + if supportsIPv4 && ip.To4() != nil { + return ip } return nil } -func ipv4only(x IP) IP { return x.To4() } - -func ipv6only(x IP) IP { - // Only return addresses that we can use - // with the kernel's IPv6 addressing modes. - if len(x) == IPv6len && x.To4() == nil && supportsIPv6 { - return x +// ipv6only returns IPv6 addresses that we can use with the kernel's +// IPv6 addressing modes. It returns IPv4-mapped IPv6 addresses as +// nils and returns other IPv6 address types as IPv6 addresses. +func ipv6only(ip IP) IP { + if supportsIPv6 && len(ip) == IPv6len && ip.To4() == nil { + return ip } return nil } -type InvalidAddrError string - -func (e InvalidAddrError) Error() string { return string(e) } -func (e InvalidAddrError) Timeout() bool { return false } -func (e InvalidAddrError) Temporary() bool { return false } - // SplitHostPort splits a network address of the form "host:port", // "[host]:port" or "[ipv6-host%zone]:port" into host or // ipv6-host%zone and port. A literal address or host name for IPv6 @@ -161,7 +228,13 @@ func JoinHostPort(host, port string) string { return host + ":" + port } -func resolveInternetAddr(net, addr string, deadline time.Time) (Addr, error) { +// resolveInternetAddr resolves addr that is either a literal IP +// address or a DNS name and returns an internet protocol family +// address. It returns a list that contains a pair of different +// address family addresses when addr is a DNS name and the name has +// mutiple address family records. The result contains at least one +// address when error is nil. +func resolveInternetAddr(net, addr string, deadline time.Time) (netaddr, error) { var ( err error host, port, zone string @@ -184,30 +257,32 @@ func resolveInternetAddr(net, addr string, deadline time.Time) (Addr, error) { default: return nil, UnknownNetworkError(net) } - inetaddr := func(net string, ip IP, port int, zone string) Addr { + inetaddr := func(ip IP) netaddr { switch net { case "tcp", "tcp4", "tcp6": - return &TCPAddr{IP: ip, Port: port, Zone: zone} + return &TCPAddr{IP: ip, Port: portnum, Zone: zone} case "udp", "udp4", "udp6": - return &UDPAddr{IP: ip, Port: port, Zone: zone} + return &UDPAddr{IP: ip, Port: portnum, Zone: zone} case "ip", "ip4", "ip6": return &IPAddr{IP: ip, Zone: zone} + default: + panic("unexpected network: " + net) } - return nil } if host == "" { - return inetaddr(net, nil, portnum, zone), nil + return inetaddr(nil), nil } - // Try as an IP address. - if ip := parseIPv4(host); ip != nil { - return inetaddr(net, ip, portnum, zone), nil + // Try as a literal IP address. + var ip IP + if ip = parseIPv4(host); ip != nil { + return inetaddr(ip), nil } - if ip, zone := parseIPv6(host, true); ip != nil { - return inetaddr(net, ip, portnum, zone), nil + if ip, zone = parseIPv6(host, true); ip != nil { + return inetaddr(ip), nil } - // Try as a domain name. + // Try as a DNS name. host, zone = splitHostZone(host) - addrs, err := lookupHostDeadline(host, deadline) + ips, err := lookupIPDeadline(host, deadline) if err != nil { return nil, err } @@ -218,12 +293,7 @@ func resolveInternetAddr(net, addr string, deadline time.Time) (Addr, error) { if net != "" && net[len(net)-1] == '6' || zone != "" { filter = ipv6only } - ip := firstFavoriteAddr(filter, addrs) - if ip == nil { - // should not happen - return nil, &AddrError{"LookupHost returned no suitable address", addrs[0]} - } - return inetaddr(net, ip, portnum, zone), nil + return firstFavoriteAddr(filter, ips, inetaddr) } func zoneToString(zone int) string { diff --git a/libgo/go/net/ipsock_plan9.go b/libgo/go/net/ipsock_plan9.go index c7d542dabc6..fcec4164f4c 100644 --- a/libgo/go/net/ipsock_plan9.go +++ b/libgo/go/net/ipsock_plan9.go @@ -12,13 +12,18 @@ import ( "syscall" ) -// /sys/include/ape/sys/socket.h:/SOMAXCONN -var listenerBacklog = 5 +func probeIPv4Stack() bool { + // TODO(mikio): implement this when Plan 9 supports IPv6-only + // kernel. + return true +} // probeIPv6Stack returns two boolean values. If the first boolean // value is true, kernel supports basic IPv6 functionality. If the // second boolean value is true, kernel supports IPv6 IPv4-mapping. func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { + // TODO(mikio): implement this once Plan 9 gets an IPv6 + // protocol stack implementation. return false, false } diff --git a/libgo/go/net/ipsock_posix.go b/libgo/go/net/ipsock_posix.go index 4c37616ecf8..a83e5256174 100644 --- a/libgo/go/net/ipsock_posix.go +++ b/libgo/go/net/ipsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux netbsd openbsd windows // Internet protocol family sockets for POSIX @@ -13,6 +13,17 @@ import ( "time" ) +func probeIPv4Stack() bool { + s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) + switch err { + case syscall.EAFNOSUPPORT, syscall.EPROTONOSUPPORT: + return false + case nil: + closesocket(s) + } + return true +} + // Should we try to use the IPv4 socket interface if we're // only dealing with IPv4 sockets? As long as the host system // understands IPv6, it's okay to pass IPv4 addresses to the IPv6 @@ -28,8 +39,8 @@ import ( // boolean value is true, kernel supports IPv6 IPv4-mapping. func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { var probes = []struct { - la TCPAddr - ok bool + laddr TCPAddr + ok bool }{ // IPv6 communication capability {TCPAddr{IP: ParseIP("::1")}, false}, @@ -44,12 +55,11 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { } defer closesocket(s) syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) - sa, err := probes[i].la.toAddr().sockaddr(syscall.AF_INET6) + sa, err := probes[i].laddr.sockaddr(syscall.AF_INET6) if err != nil { continue } - err = syscall.Bind(s, sa) - if err != nil { + if err := syscall.Bind(s, sa); err != nil { continue } probes[i].ok = true @@ -121,40 +131,9 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family // Internet sockets (TCP, UDP, IP) -// A sockaddr represents a TCP, UDP or IP network address that can -// be converted into a syscall.Sockaddr. -type sockaddr interface { - Addr - family() int - isWildcard() bool - sockaddr(family int) (syscall.Sockaddr, error) -} - func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { - var la, ra syscall.Sockaddr family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode) - if laddr != nil { - if la, err = laddr.sockaddr(family); err != nil { - goto Error - } - } - if raddr != nil { - if ra, err = raddr.sockaddr(family); err != nil { - goto Error - } - } - fd, err = socket(net, family, sotype, proto, ipv6only, la, ra, deadline, toAddr) - if err != nil { - goto Error - } - return fd, nil - -Error: - addr := raddr - if mode == "listen" { - addr = laddr - } - return nil, &OpError{mode, net, addr, err} + return socket(net, family, sotype, proto, ipv6only, laddr, raddr, deadline, toAddr) } func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) { diff --git a/libgo/go/net/ipsock_test.go b/libgo/go/net/ipsock_test.go new file mode 100644 index 00000000000..9ecaaec69f6 --- /dev/null +++ b/libgo/go/net/ipsock_test.go @@ -0,0 +1,193 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "reflect" + "testing" +) + +var testInetaddr = func(ip IP) netaddr { return &TCPAddr{IP: ip, Port: 5682} } + +var firstFavoriteAddrTests = []struct { + filter func(IP) IP + ips []IP + inetaddr func(IP) netaddr + addr netaddr + err error +}{ + { + nil, + []IP{ + IPv4(127, 0, 0, 1), + IPv6loopback, + }, + testInetaddr, + addrList{ + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}, + &TCPAddr{IP: IPv6loopback, Port: 5682}, + }, + nil, + }, + { + nil, + []IP{ + IPv6loopback, + IPv4(127, 0, 0, 1), + }, + testInetaddr, + addrList{ + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}, + &TCPAddr{IP: IPv6loopback, Port: 5682}, + }, + nil, + }, + { + nil, + []IP{ + IPv4(127, 0, 0, 1), + IPv4(192, 168, 0, 1), + }, + testInetaddr, + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}, + nil, + }, + { + nil, + []IP{ + IPv6loopback, + ParseIP("fe80::1"), + }, + testInetaddr, + &TCPAddr{IP: IPv6loopback, Port: 5682}, + nil, + }, + { + nil, + []IP{ + IPv4(127, 0, 0, 1), + IPv4(192, 168, 0, 1), + IPv6loopback, + ParseIP("fe80::1"), + }, + testInetaddr, + addrList{ + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}, + &TCPAddr{IP: IPv6loopback, Port: 5682}, + }, + nil, + }, + { + nil, + []IP{ + IPv6loopback, + ParseIP("fe80::1"), + IPv4(127, 0, 0, 1), + IPv4(192, 168, 0, 1), + }, + testInetaddr, + addrList{ + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}, + &TCPAddr{IP: IPv6loopback, Port: 5682}, + }, + nil, + }, + { + nil, + []IP{ + IPv4(127, 0, 0, 1), + IPv6loopback, + IPv4(192, 168, 0, 1), + ParseIP("fe80::1"), + }, + testInetaddr, + addrList{ + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}, + &TCPAddr{IP: IPv6loopback, Port: 5682}, + }, + nil, + }, + { + nil, + []IP{ + IPv6loopback, + IPv4(127, 0, 0, 1), + ParseIP("fe80::1"), + IPv4(192, 168, 0, 1), + }, + testInetaddr, + addrList{ + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}, + &TCPAddr{IP: IPv6loopback, Port: 5682}, + }, + nil, + }, + + { + ipv4only, + []IP{ + IPv4(127, 0, 0, 1), + IPv6loopback, + }, + testInetaddr, + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}, + nil, + }, + { + ipv4only, + []IP{ + IPv6loopback, + IPv4(127, 0, 0, 1), + }, + testInetaddr, + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}, + nil, + }, + + { + ipv6only, + []IP{ + IPv4(127, 0, 0, 1), + IPv6loopback, + }, + testInetaddr, + &TCPAddr{IP: IPv6loopback, Port: 5682}, + nil, + }, + { + ipv6only, + []IP{ + IPv6loopback, + IPv4(127, 0, 0, 1), + }, + testInetaddr, + &TCPAddr{IP: IPv6loopback, Port: 5682}, + nil, + }, + + {nil, nil, testInetaddr, nil, errNoSuitableAddress}, + + {ipv4only, nil, testInetaddr, nil, errNoSuitableAddress}, + {ipv4only, []IP{IPv6loopback}, testInetaddr, nil, errNoSuitableAddress}, + + {ipv6only, nil, testInetaddr, nil, errNoSuitableAddress}, + {ipv6only, []IP{IPv4(127, 0, 0, 1)}, testInetaddr, nil, errNoSuitableAddress}, +} + +func TestFirstFavoriteAddr(t *testing.T) { + if !supportsIPv4 || !supportsIPv6 { + t.Skip("ipv4 or ipv6 is not supported") + } + + for i, tt := range firstFavoriteAddrTests { + addr, err := firstFavoriteAddr(tt.filter, tt.ips, tt.inetaddr) + if err != tt.err { + t.Errorf("#%v: got %v; expected %v", i, err, tt.err) + } + if !reflect.DeepEqual(addr, tt.addr) { + t.Errorf("#%v: got %v; expected %v", i, addr, tt.addr) + } + } +} diff --git a/libgo/go/net/lookup.go b/libgo/go/net/lookup.go index bec93ec08cd..20f20578cde 100644 --- a/libgo/go/net/lookup.go +++ b/libgo/go/net/lookup.go @@ -4,9 +4,20 @@ package net -import ( - "time" -) +import "time" + +// protocols contains minimal mappings between internet protocol +// names and numbers for platforms that don't have a complete list of +// protocol numbers. +// +// See http://www.iana.org/assignments/protocol-numbers +var protocols = map[string]int{ + "icmp": 1, "ICMP": 1, + "igmp": 2, "IGMP": 2, + "tcp": 6, "TCP": 6, + "udp": 17, "UDP": 17, + "ipv6-icmp": 58, "IPV6-ICMP": 58, "IPv6-ICMP": 58, +} // LookupHost looks up the given host using the local resolver. // It returns an array of that host's addresses. @@ -14,9 +25,36 @@ func LookupHost(host string) (addrs []string, err error) { return lookupHost(host) } -func lookupHostDeadline(host string, deadline time.Time) (addrs []string, err error) { +// LookupIP looks up host using the local resolver. +// It returns an array of that host's IPv4 and IPv6 addresses. +func LookupIP(host string) (addrs []IP, err error) { + return lookupIPMerge(host) +} + +var lookupGroup singleflight + +// lookupIPMerge wraps lookupIP, but makes sure that for any given +// host, only one lookup is in-flight at a time. The returned memory +// is always owned by the caller. +func lookupIPMerge(host string) (addrs []IP, err error) { + addrsi, err, shared := lookupGroup.Do(host, func() (interface{}, error) { + return lookupIP(host) + }) + if err != nil { + return nil, err + } + addrs = addrsi.([]IP) + if shared { + clone := make([]IP, len(addrs)) + copy(clone, addrs) + addrs = clone + } + return addrs, nil +} + +func lookupIPDeadline(host string, deadline time.Time) (addrs []IP, err error) { if deadline.IsZero() { - return lookupHost(host) + return lookupIPMerge(host) } // TODO(bradfitz): consider pushing the deadline down into the @@ -34,12 +72,12 @@ func lookupHostDeadline(host string, deadline time.Time) (addrs []string, err er t := time.NewTimer(timeout) defer t.Stop() type res struct { - addrs []string + addrs []IP err error } resc := make(chan res, 1) go func() { - a, err := lookupHost(host) + a, err := lookupIPMerge(host) resc <- res{a, err} }() select { @@ -51,12 +89,6 @@ func lookupHostDeadline(host string, deadline time.Time) (addrs []string, err er return } -// LookupIP looks up host using the local resolver. -// It returns an array of that host's IPv4 and IPv6 addresses. -func LookupIP(host string) (addrs []IP, err error) { - return lookupIP(host) -} - // LookupPort looks up the port for the given network and service. func LookupPort(network, service string) (port int, err error) { return lookupPort(network, service) diff --git a/libgo/go/net/lookup_plan9.go b/libgo/go/net/lookup_plan9.go index 94c55332869..f1204a99f7b 100644 --- a/libgo/go/net/lookup_plan9.go +++ b/libgo/go/net/lookup_plan9.go @@ -186,9 +186,9 @@ func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err if len(f) < 6 { continue } - port, _, portOk := dtoi(f[2], 0) + port, _, portOk := dtoi(f[4], 0) priority, _, priorityOk := dtoi(f[3], 0) - weight, _, weightOk := dtoi(f[4], 0) + weight, _, weightOk := dtoi(f[2], 0) if !(portOk && priorityOk && weightOk) { continue } @@ -224,10 +224,10 @@ func lookupNS(name string) (ns []*NS, err error) { } for _, line := range lines { f := getFields(line) - if len(f) < 4 { + if len(f) < 3 { continue } - ns = append(ns, &NS{f[3]}) + ns = append(ns, &NS{f[2]}) } return } diff --git a/libgo/go/net/lookup_unix.go b/libgo/go/net/lookup_unix.go index fa98eed5f26..59e9f63210c 100644 --- a/libgo/go/net/lookup_unix.go +++ b/libgo/go/net/lookup_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux netbsd openbsd package net @@ -11,15 +11,11 @@ import ( "sync" ) -var ( - protocols map[string]int - onceReadProtocols sync.Once -) +var onceReadProtocols sync.Once // readProtocols loads contents of /etc/protocols into protocols map // for quick access. func readProtocols() { - protocols = make(map[string]int) if file, err := open("/etc/protocols"); err == nil { for line, ok := file.readLine(); ok; line, ok = file.readLine() { // tcp 6 TCP # transmission control protocol @@ -31,9 +27,13 @@ func readProtocols() { continue } if proto, _, ok := dtoi(f[1], 0); ok { - protocols[f[0]] = proto + if _, ok := protocols[f[0]]; !ok { + protocols[f[0]] = proto + } for _, alias := range f[2:] { - protocols[alias] = proto + if _, ok := protocols[alias]; !ok { + protocols[alias] = proto + } } } } diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go index 3b29724f27a..130364231d4 100644 --- a/libgo/go/net/lookup_windows.go +++ b/libgo/go/net/lookup_windows.go @@ -34,12 +34,19 @@ func lookupProtocol(name string) (proto int, err error) { } ch := make(chan result) go func() { + acquireThread() + defer releaseThread() runtime.LockOSThread() defer runtime.UnlockOSThread() proto, err := getprotobyname(name) ch <- result{proto: proto, err: err} }() r := <-ch + if r.err != nil { + if proto, ok := protocols[name]; ok { + return proto, nil + } + } return r.proto, r.err } @@ -56,6 +63,7 @@ func lookupHost(name string) (addrs []string, err error) { } func gethostbyname(name string) (addrs []IP, err error) { + // caller already acquired thread h, err := syscall.GetHostByName(name) if err != nil { return nil, os.NewSyscallError("GetHostByName", err) @@ -83,6 +91,8 @@ func oldLookupIP(name string) (addrs []IP, err error) { } ch := make(chan result) go func() { + acquireThread() + defer releaseThread() runtime.LockOSThread() defer runtime.UnlockOSThread() addrs, err := gethostbyname(name) @@ -93,6 +103,8 @@ func oldLookupIP(name string) (addrs []IP, err error) { } func newLookupIP(name string) (addrs []IP, err error) { + acquireThread() + defer releaseThread() hints := syscall.AddrinfoW{ Family: syscall.AF_UNSPEC, Socktype: syscall.SOCK_STREAM, @@ -122,6 +134,8 @@ func newLookupIP(name string) (addrs []IP, err error) { } func getservbyname(network, service string) (port int, err error) { + acquireThread() + defer releaseThread() switch network { case "tcp4", "tcp6": network = "tcp" @@ -144,6 +158,8 @@ func oldLookupPort(network, service string) (port int, err error) { } ch := make(chan result) go func() { + acquireThread() + defer releaseThread() runtime.LockOSThread() defer runtime.UnlockOSThread() port, err := getservbyname(network, service) @@ -154,6 +170,8 @@ func oldLookupPort(network, service string) (port int, err error) { } func newLookupPort(network, service string) (port int, err error) { + acquireThread() + defer releaseThread() var stype int32 switch network { case "tcp4", "tcp6": @@ -188,6 +206,8 @@ func newLookupPort(network, service string) (port int, err error) { } func lookupCNAME(name string) (cname string, err error) { + acquireThread() + defer releaseThread() var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) if e != nil { @@ -202,6 +222,8 @@ func lookupCNAME(name string) (cname string, err error) { } func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) { + acquireThread() + defer releaseThread() var target string if service == "" && proto == "" { target = name @@ -224,6 +246,8 @@ func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err } func lookupMX(name string) (mx []*MX, err error) { + acquireThread() + defer releaseThread() var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil) if e != nil { @@ -240,6 +264,8 @@ func lookupMX(name string) (mx []*MX, err error) { } func lookupNS(name string) (ns []*NS, err error) { + acquireThread() + defer releaseThread() var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil) if e != nil { @@ -255,6 +281,8 @@ func lookupNS(name string) (ns []*NS, err error) { } func lookupTXT(name string) (txt []string, err error) { + acquireThread() + defer releaseThread() var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil) if e != nil { @@ -273,6 +301,8 @@ func lookupTXT(name string) (txt []string, err error) { } func lookupAddr(addr string) (name []string, err error) { + acquireThread() + defer releaseThread() arpa, err := reverseaddr(addr) if err != nil { return nil, err diff --git a/libgo/go/net/mail/message.go b/libgo/go/net/mail/message.go index 96c796e7804..dc2ab44dab2 100644 --- a/libgo/go/net/mail/message.go +++ b/libgo/go/net/mail/message.go @@ -342,7 +342,9 @@ func (p *addrParser) consumePhrase() (phrase string, err error) { word, err = p.consumeQuotedString() } else { // atom - word, err = p.consumeAtom(false) + // We actually parse dot-atom here to be more permissive + // than what RFC 5322 specifies. + word, err = p.consumeAtom(true) } // RFC 2047 encoded-word starts with =?, ends with ?=, and has two other ?s. @@ -519,7 +521,7 @@ func isAtext(c byte, dot bool) bool { return bytes.IndexByte(atextChars, c) >= 0 } -// isQtext returns true if c is an RFC 5322 qtest character. +// isQtext returns true if c is an RFC 5322 qtext character. func isQtext(c byte) bool { // Printable US-ASCII, excluding backslash or quote. if c == '\\' || c == '"' { diff --git a/libgo/go/net/mail/message_test.go b/libgo/go/net/mail/message_test.go index 2e746f4a722..3c037f38385 100644 --- a/libgo/go/net/mail/message_test.go +++ b/libgo/go/net/mail/message_test.go @@ -225,6 +225,16 @@ func TestAddressParsing(t *testing.T) { }, }, }, + // Custom example with "." in name. For issue 4938 + { + `Asem H. <noreply@example.com>`, + []*Address{ + { + Name: `Asem H.`, + Address: "noreply@example.com", + }, + }, + }, } for _, test := range tests { if len(test.exp) == 1 { diff --git a/libgo/go/net/mockicmp_test.go b/libgo/go/net/mockicmp_test.go new file mode 100644 index 00000000000..e742365ea03 --- /dev/null +++ b/libgo/go/net/mockicmp_test.go @@ -0,0 +1,116 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import "errors" + +const ( + icmpv4EchoRequest = 8 + icmpv4EchoReply = 0 + icmpv6EchoRequest = 128 + icmpv6EchoReply = 129 +) + +// icmpMessage represents an ICMP message. +type icmpMessage struct { + Type int // type + Code int // code + Checksum int // checksum + Body icmpMessageBody // body +} + +// icmpMessageBody represents an ICMP message body. +type icmpMessageBody interface { + Len() int + Marshal() ([]byte, error) +} + +// Marshal returns the binary enconding of the ICMP echo request or +// reply message m. +func (m *icmpMessage) Marshal() ([]byte, error) { + b := []byte{byte(m.Type), byte(m.Code), 0, 0} + if m.Body != nil && m.Body.Len() != 0 { + mb, err := m.Body.Marshal() + if err != nil { + return nil, err + } + b = append(b, mb...) + } + switch m.Type { + case icmpv6EchoRequest, icmpv6EchoReply: + return b, nil + } + csumcv := len(b) - 1 // checksum coverage + s := uint32(0) + for i := 0; i < csumcv; i += 2 { + s += uint32(b[i+1])<<8 | uint32(b[i]) + } + if csumcv&1 == 0 { + s += uint32(b[csumcv]) + } + s = s>>16 + s&0xffff + s = s + s>>16 + // Place checksum back in header; using ^= avoids the + // assumption the checksum bytes are zero. + b[2] ^= byte(^s) + b[3] ^= byte(^s >> 8) + return b, nil +} + +// parseICMPMessage parses b as an ICMP message. +func parseICMPMessage(b []byte) (*icmpMessage, error) { + msglen := len(b) + if msglen < 4 { + return nil, errors.New("message too short") + } + m := &icmpMessage{Type: int(b[0]), Code: int(b[1]), Checksum: int(b[2])<<8 | int(b[3])} + if msglen > 4 { + var err error + switch m.Type { + case icmpv4EchoRequest, icmpv4EchoReply, icmpv6EchoRequest, icmpv6EchoReply: + m.Body, err = parseICMPEcho(b[4:]) + if err != nil { + return nil, err + } + } + } + return m, nil +} + +// imcpEcho represenets an ICMP echo request or reply message body. +type icmpEcho struct { + ID int // identifier + Seq int // sequence number + Data []byte // data +} + +func (p *icmpEcho) Len() int { + if p == nil { + return 0 + } + return 4 + len(p.Data) +} + +// Marshal returns the binary enconding of the ICMP echo request or +// reply message body p. +func (p *icmpEcho) Marshal() ([]byte, error) { + b := make([]byte, 4+len(p.Data)) + b[0], b[1] = byte(p.ID>>8), byte(p.ID) + b[2], b[3] = byte(p.Seq>>8), byte(p.Seq) + copy(b[4:], p.Data) + return b, nil +} + +// parseICMPEcho parses b as an ICMP echo request or reply message +// body. +func parseICMPEcho(b []byte) (*icmpEcho, error) { + bodylen := len(b) + p := &icmpEcho{ID: int(b[0])<<8 | int(b[1]), Seq: int(b[2])<<8 | int(b[3])} + if bodylen > 4 { + p.Data = make([]byte, bodylen-4) + copy(p.Data, b[4:]) + } + return p, nil +} diff --git a/libgo/go/net/mockserver_test.go b/libgo/go/net/mockserver_test.go new file mode 100644 index 00000000000..68ded5d7577 --- /dev/null +++ b/libgo/go/net/mockserver_test.go @@ -0,0 +1,82 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import "sync" + +type streamListener struct { + net, addr string + ln Listener +} + +type dualStackServer struct { + lnmu sync.RWMutex + lns []streamListener + port string + + cmu sync.RWMutex + cs []Conn // established connections at the passive open side +} + +func (dss *dualStackServer) buildup(server func(*dualStackServer, Listener)) error { + for i := range dss.lns { + go server(dss, dss.lns[i].ln) + } + return nil +} + +func (dss *dualStackServer) putConn(c Conn) error { + dss.cmu.Lock() + dss.cs = append(dss.cs, c) + dss.cmu.Unlock() + return nil +} + +func (dss *dualStackServer) teardownNetwork(net string) error { + dss.lnmu.Lock() + for i := range dss.lns { + if net == dss.lns[i].net && dss.lns[i].ln != nil { + dss.lns[i].ln.Close() + dss.lns[i].ln = nil + } + } + dss.lnmu.Unlock() + return nil +} + +func (dss *dualStackServer) teardown() error { + dss.lnmu.Lock() + for i := range dss.lns { + if dss.lns[i].ln != nil { + dss.lns[i].ln.Close() + } + } + dss.lnmu.Unlock() + dss.cmu.Lock() + for _, c := range dss.cs { + c.Close() + } + dss.cmu.Unlock() + return nil +} + +func newDualStackServer(lns []streamListener) (*dualStackServer, error) { + dss := &dualStackServer{lns: lns, port: "0"} + for i := range dss.lns { + ln, err := Listen(dss.lns[i].net, dss.lns[i].addr+":"+dss.port) + if err != nil { + dss.teardown() + return nil, err + } + dss.lns[i].ln = ln + if dss.port == "0" { + if _, dss.port, err = SplitHostPort(ln.Addr().String()); err != nil { + dss.teardown() + return nil, err + } + } + } + return dss, nil +} diff --git a/libgo/go/net/multicast_test.go b/libgo/go/net/multicast_test.go index 8ff02a3c933..5660fd42f8c 100644 --- a/libgo/go/net/multicast_test.go +++ b/libgo/go/net/multicast_test.go @@ -158,7 +158,7 @@ func checkMulticastListener(c *UDPConn, ip IP) error { func multicastRIBContains(ip IP) (bool, error) { switch runtime.GOOS { - case "netbsd", "openbsd", "plan9", "solaris", "windows": + case "dragonfly", "netbsd", "openbsd", "plan9", "solaris", "windows": return true, nil // not implemented yet case "linux": if runtime.GOARCH == "arm" || runtime.GOARCH == "alpha" { diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go index 72b2b646c48..2e6db555143 100644 --- a/libgo/go/net/net.go +++ b/libgo/go/net/net.go @@ -46,7 +46,6 @@ import ( "errors" "io" "os" - "sync" "syscall" "time" ) @@ -160,7 +159,7 @@ func (c *conn) SetDeadline(t time.Time) error { if !c.ok() { return syscall.EINVAL } - return setDeadline(c.fd, t) + return c.fd.setDeadline(t) } // SetReadDeadline implements the Conn SetReadDeadline method. @@ -168,7 +167,7 @@ func (c *conn) SetReadDeadline(t time.Time) error { if !c.ok() { return syscall.EINVAL } - return setReadDeadline(c.fd, t) + return c.fd.setReadDeadline(t) } // SetWriteDeadline implements the Conn SetWriteDeadline method. @@ -176,7 +175,7 @@ func (c *conn) SetWriteDeadline(t time.Time) error { if !c.ok() { return syscall.EINVAL } - return setWriteDeadline(c.fd, t) + return c.fd.setWriteDeadline(t) } // SetReadBuffer sets the size of the operating system's @@ -259,6 +258,8 @@ type PacketConn interface { SetWriteDeadline(t time.Time) error } +var listenerBacklog = maxListenerBacklog() + // A Listener is a generic network listener for stream-oriented protocols. // // Multiple goroutines may invoke methods on a Listener simultaneously. @@ -370,6 +371,12 @@ func (e UnknownNetworkError) Error() string { return "unknown network " + stri func (e UnknownNetworkError) Temporary() bool { return false } func (e UnknownNetworkError) Timeout() bool { return false } +type InvalidAddrError string + +func (e InvalidAddrError) Error() string { return string(e) } +func (e InvalidAddrError) Timeout() bool { return false } +func (e InvalidAddrError) Temporary() bool { return false } + // DNSConfigError represents an error reading the machine's DNS configuration. type DNSConfigError struct { Err error @@ -393,35 +400,22 @@ func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) { return io.Copy(writerOnly{w}, r) } -// deadline is an atomically-accessed number of nanoseconds since 1970 -// or 0, if no deadline is set. -type deadline struct { - sync.Mutex - val int64 -} +// Limit the number of concurrent cgo-using goroutines, because +// each will block an entire operating system thread. The usual culprit +// is resolving many DNS names in separate goroutines but the DNS +// server is not responding. Then the many lookups each use a different +// thread, and the system or the program runs out of threads. -func (d *deadline) expired() bool { - t := d.value() - return t > 0 && time.Now().UnixNano() >= t -} +var threadLimit = make(chan struct{}, 500) -func (d *deadline) value() (v int64) { - d.Lock() - v = d.val - d.Unlock() - return -} +// Using send for acquire is fine here because we are not using this +// to protect any memory. All we care about is the number of goroutines +// making calls at a time. -func (d *deadline) set(v int64) { - d.Lock() - d.val = v - d.Unlock() +func acquireThread() { + threadLimit <- struct{}{} } -func (d *deadline) setTime(t time.Time) { - if t.IsZero() { - d.set(0) - } else { - d.set(t.UnixNano()) - } +func releaseThread() { + <-threadLimit } diff --git a/libgo/go/net/net_test.go b/libgo/go/net/net_test.go index 1a512a5b110..1320096df8f 100644 --- a/libgo/go/net/net_test.go +++ b/libgo/go/net/net_test.go @@ -25,6 +25,7 @@ func TestShutdown(t *testing.T) { } go func() { + defer ln.Close() c, err := ln.Accept() if err != nil { t.Fatalf("Accept: %v", err) @@ -75,7 +76,10 @@ func TestShutdownUnix(t *testing.T) { if err != nil { t.Fatalf("ListenUnix on %s: %s", tmpname, err) } - defer os.Remove(tmpname) + defer func() { + ln.Close() + os.Remove(tmpname) + }() go func() { c, err := ln.Accept() @@ -214,3 +218,41 @@ func TestTCPClose(t *testing.T) { t.Fatal(err) } } + +func TestErrorNil(t *testing.T) { + c, err := Dial("tcp", "127.0.0.1:65535") + if err == nil { + t.Fatal("Dial 127.0.0.1:65535 succeeded") + } + if c != nil { + t.Fatalf("Dial returned non-nil interface %T(%v) with err != nil", c, c) + } + + // Make Listen fail by relistening on the same address. + l, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("Listen 127.0.0.1:0: %v", err) + } + defer l.Close() + l1, err := Listen("tcp", l.Addr().String()) + if err == nil { + t.Fatal("second Listen %v: %v", l.Addr(), err) + } + if l1 != nil { + t.Fatalf("Listen returned non-nil interface %T(%v) with err != nil", l1, l1) + } + + // Make ListenPacket fail by relistening on the same address. + lp, err := ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal("Listen 127.0.0.1:0: %v", err) + } + defer lp.Close() + lp1, err := ListenPacket("udp", lp.LocalAddr().String()) + if err == nil { + t.Fatal("second Listen %v: %v", lp.LocalAddr(), err) + } + if lp1 != nil { + t.Fatalf("ListenPacket returned non-nil interface %T(%v) with err != nil", lp1, lp1) + } +} diff --git a/libgo/go/net/packetconn_test.go b/libgo/go/net/packetconn_test.go index ec5dd710f55..945003f67ad 100644 --- a/libgo/go/net/packetconn_test.go +++ b/libgo/go/net/packetconn_test.go @@ -21,6 +21,45 @@ func strfunc(s string) func() string { } } +func packetConnTestData(t *testing.T, net string, i int) ([]byte, func()) { + switch net { + case "udp": + return []byte("UDP PACKETCONN TEST"), nil + case "ip": + if skip, skipmsg := skipRawSocketTest(t); skip { + return nil, func() { + t.Logf(skipmsg) + } + } + b, err := (&icmpMessage{ + Type: icmpv4EchoRequest, Code: 0, + Body: &icmpEcho{ + ID: os.Getpid() & 0xffff, Seq: i + 1, + Data: []byte("IP PACKETCONN TEST"), + }, + }).Marshal() + if err != nil { + return nil, func() { + t.Fatalf("icmpMessage.Marshal failed: %v", err) + } + } + return b, nil + case "unixgram": + switch runtime.GOOS { + case "plan9", "windows": + return nil, func() { + t.Logf("skipping %q test on %q", net, runtime.GOOS) + } + default: + return []byte("UNIXGRAM PACKETCONN TEST"), nil + } + default: + return nil, func() { + t.Logf("skipping %q test", net) + } + } +} + var packetConnTests = []struct { net string addr1 func() string @@ -42,37 +81,10 @@ func TestPacketConn(t *testing.T) { } for i, tt := range packetConnTests { - var wb []byte netstr := strings.Split(tt.net, ":") - switch netstr[0] { - case "udp": - wb = []byte("UDP PACKETCONN TEST") - case "ip": - switch runtime.GOOS { - case "plan9": - continue - } - if os.Getuid() != 0 { - continue - } - var err error - wb, err = (&icmpMessage{ - Type: icmpv4EchoRequest, Code: 0, - Body: &icmpEcho{ - ID: os.Getpid() & 0xffff, Seq: i + 1, - Data: []byte("IP PACKETCONN TEST"), - }, - }).Marshal() - if err != nil { - t.Fatalf("icmpMessage.Marshal failed: %v", err) - } - case "unixgram": - switch runtime.GOOS { - case "plan9", "windows": - continue - } - wb = []byte("UNIXGRAM PACKETCONN TEST") - default: + wb, skipOrFatalFn := packetConnTestData(t, netstr[0], i) + if skipOrFatalFn != nil { + skipOrFatalFn() continue } @@ -127,35 +139,9 @@ func TestConnAndPacketConn(t *testing.T) { for i, tt := range packetConnTests { var wb []byte netstr := strings.Split(tt.net, ":") - switch netstr[0] { - case "udp": - wb = []byte("UDP PACKETCONN TEST") - case "ip": - switch runtime.GOOS { - case "plan9": - continue - } - if os.Getuid() != 0 { - continue - } - var err error - wb, err = (&icmpMessage{ - Type: icmpv4EchoRequest, Code: 0, - Body: &icmpEcho{ - ID: os.Getpid() & 0xffff, Seq: i + 1, - Data: []byte("IP PACKETCONN TEST"), - }, - }).Marshal() - if err != nil { - t.Fatalf("icmpMessage.Marshal failed: %v", err) - } - case "unixgram": - switch runtime.GOOS { - case "plan9", "windows": - continue - } - wb = []byte("UNIXGRAM PACKETCONN TEST") - default: + wb, skipOrFatalFn := packetConnTestData(t, netstr[0], i) + if skipOrFatalFn != nil { + skipOrFatalFn() continue } @@ -186,7 +172,7 @@ func TestConnAndPacketConn(t *testing.T) { } rb1 := make([]byte, 128) if _, _, err := c1.ReadFrom(rb1); err != nil { - t.Fatalf("PacetConn.ReadFrom failed: %v", err) + t.Fatalf("PacketConn.ReadFrom failed: %v", err) } var dst Addr switch netstr[0] { diff --git a/libgo/go/net/parse_test.go b/libgo/go/net/parse_test.go index 9df0c534b33..b86bc32884b 100644 --- a/libgo/go/net/parse_test.go +++ b/libgo/go/net/parse_test.go @@ -23,12 +23,14 @@ func TestReadLine(t *testing.T) { if err != nil { t.Fatalf("open %s: %v", filename, err) } + defer fd.Close() br := bufio.NewReader(fd) file, err := open(filename) if file == nil { t.Fatalf("net.open(%s) = nil", filename) } + defer file.close() lineno := 1 byteno := 0 diff --git a/libgo/go/net/port_unix.go b/libgo/go/net/port_unix.go index 16780da1160..3cd9ca2aa71 100644 --- a/libgo/go/net/port_unix.go +++ b/libgo/go/net/port_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd linux netbsd openbsd // Read system port mappings from /etc/services diff --git a/libgo/go/net/protoconn_test.go b/libgo/go/net/protoconn_test.go index b59925e01c1..5a8958b0866 100644 --- a/libgo/go/net/protoconn_test.go +++ b/libgo/go/net/protoconn_test.go @@ -103,6 +103,7 @@ func TestTCPConnSpecificMethods(t *testing.T) { } defer c.Close() c.SetKeepAlive(false) + c.SetKeepAlivePeriod(3 * time.Second) c.SetLinger(0) c.SetNoDelay(false) c.LocalAddr() @@ -160,15 +161,20 @@ func TestUDPConnSpecificMethods(t *testing.T) { } else { f.Close() } + + defer func() { + if p := recover(); p != nil { + t.Fatalf("UDPConn.WriteToUDP or WriteMsgUDP panicked: %v", p) + } + }() + + c.WriteToUDP(wb, nil) + c.WriteMsgUDP(wb, nil, nil) } func TestIPConnSpecificMethods(t *testing.T) { - switch runtime.GOOS { - case "plan9": - t.Skipf("skipping test on %q", runtime.GOOS) - } - if os.Getuid() != 0 { - t.Skipf("skipping test; must be root") + if skip, skipmsg := skipRawSocketTest(t); skip { + t.Skip(skipmsg) } la, err := ResolveIPAddr("ip4", "127.0.0.1") @@ -198,7 +204,7 @@ func TestIPConnSpecificMethods(t *testing.T) { if err != nil { t.Fatalf("icmpMessage.Marshal failed: %v", err) } - rb := make([]byte, 20+128) + rb := make([]byte, 20+len(wb)) if _, err := c.WriteToIP(wb, c.LocalAddr().(*IPAddr)); err != nil { t.Fatalf("IPConn.WriteToIP failed: %v", err) } @@ -217,6 +223,15 @@ func TestIPConnSpecificMethods(t *testing.T) { } else { f.Close() } + + defer func() { + if p := recover(); p != nil { + t.Fatalf("IPConn.WriteToIP or WriteMsgIP panicked: %v", p) + } + }() + + c.WriteToIP(wb, nil) + c.WriteMsgIP(wb, nil, nil) } func TestUnixListenerSpecificMethods(t *testing.T) { @@ -357,4 +372,15 @@ func TestUnixConnSpecificMethods(t *testing.T) { } else { f.Close() } + + defer func() { + if p := recover(); p != nil { + t.Fatalf("UnixConn.WriteToUnix or WriteMsgUnix panicked: %v", p) + } + }() + + c1.WriteToUnix(wb, nil) + c1.WriteMsgUnix(wb, nil, nil) + c3.WriteToUnix(wb, nil) + c3.WriteMsgUnix(wb, nil, nil) } diff --git a/libgo/go/net/race.go b/libgo/go/net/race.go new file mode 100644 index 00000000000..2f02a6c226b --- /dev/null +++ b/libgo/go/net/race.go @@ -0,0 +1,31 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build race +// +build windows + +package net + +import ( + "runtime" + "unsafe" +) + +const raceenabled = true + +func raceAcquire(addr unsafe.Pointer) { + runtime.RaceAcquire(addr) +} + +func raceReleaseMerge(addr unsafe.Pointer) { + runtime.RaceReleaseMerge(addr) +} + +func raceReadRange(addr unsafe.Pointer, len int) { + runtime.RaceReadRange(addr, len) +} + +func raceWriteRange(addr unsafe.Pointer, len int) { + runtime.RaceWriteRange(addr, len) +} diff --git a/libgo/go/net/race0.go b/libgo/go/net/race0.go new file mode 100644 index 00000000000..f5042977931 --- /dev/null +++ b/libgo/go/net/race0.go @@ -0,0 +1,26 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !race +// +build windows + +package net + +import ( + "unsafe" +) + +const raceenabled = false + +func raceAcquire(addr unsafe.Pointer) { +} + +func raceReleaseMerge(addr unsafe.Pointer) { +} + +func raceReadRange(addr unsafe.Pointer, len int) { +} + +func raceWriteRange(addr unsafe.Pointer, len int) { +} diff --git a/libgo/go/net/rpc/client.go b/libgo/go/net/rpc/client.go index 4b0c9c3bba2..c524d0a0a2d 100644 --- a/libgo/go/net/rpc/client.go +++ b/libgo/go/net/rpc/client.go @@ -58,6 +58,7 @@ type Client struct { // argument to force the body of the response to be read and then // discarded. type ClientCodec interface { + // WriteRequest must be safe for concurrent use by multiple goroutines. WriteRequest(*Request, interface{}) error ReadResponseHeader(*Response) error ReadResponseBody(interface{}) error @@ -160,7 +161,7 @@ func (client *Client) input() { } client.mutex.Unlock() client.sending.Unlock() - if err != io.EOF && !closing { + if debugLog && err != io.EOF && !closing { log.Println("rpc: client protocol error:", err) } } @@ -172,7 +173,9 @@ func (call *Call) done() { default: // We don't want to block here. It is the caller's responsibility to make // sure the channel has enough buffer space. See comment in Go(). - log.Println("rpc: discarding Call reply due to insufficient Done chan capacity") + if debugLog { + log.Println("rpc: discarding Call reply due to insufficient Done chan capacity") + } } } diff --git a/libgo/go/net/rpc/debug.go b/libgo/go/net/rpc/debug.go index 663663fe941..926466d6255 100644 --- a/libgo/go/net/rpc/debug.go +++ b/libgo/go/net/rpc/debug.go @@ -38,6 +38,9 @@ const debugText = `<html> var debug = template.Must(template.New("RPC debug").Parse(debugText)) +// If set, print log statements for internal and I/O errors. +var debugLog = false + type debugMethod struct { Type *methodType Name string diff --git a/libgo/go/net/rpc/jsonrpc/server.go b/libgo/go/net/rpc/jsonrpc/server.go index 5bc05fd0a71..16ec0fe9ad5 100644 --- a/libgo/go/net/rpc/jsonrpc/server.go +++ b/libgo/go/net/rpc/jsonrpc/server.go @@ -20,8 +20,7 @@ type serverCodec struct { c io.Closer // temporary work space - req serverRequest - resp serverResponse + req serverRequest // JSON-RPC clients can use arbitrary json values as request IDs. // Package rpc expects uint64 request IDs. diff --git a/libgo/go/net/rpc/server.go b/libgo/go/net/rpc/server.go index e71b6fb1a43..7eb2dcf5a9f 100644 --- a/libgo/go/net/rpc/server.go +++ b/libgo/go/net/rpc/server.go @@ -247,10 +247,12 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro sname = name } if sname == "" { - log.Fatal("rpc: no service name for type", s.typ.String()) + s := "rpc.Register: no service name for type " + s.typ.String() + log.Print(s) + return errors.New(s) } if !isExported(sname) && !useName { - s := "rpc Register: type " + sname + " is not exported" + s := "rpc.Register: type " + sname + " is not exported" log.Print(s) return errors.New(s) } @@ -258,13 +260,13 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro return errors.New("rpc: service already defined: " + sname) } s.name = sname - s.method = make(map[string]*methodType) // Install the methods s.method = suitableMethods(s.typ, true) if len(s.method) == 0 { str := "" + // To help the user, see if a pointer receiver would work. method := suitableMethods(reflect.PtrTo(s.typ), false) if len(method) != 0 { @@ -356,7 +358,7 @@ func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply inte resp.Seq = req.Seq sending.Lock() err := codec.WriteResponse(resp, reply) - if err != nil { + if debugLog && err != nil { log.Println("rpc: writing response:", err) } sending.Unlock() @@ -434,7 +436,7 @@ func (server *Server) ServeCodec(codec ServerCodec) { for { service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) if err != nil { - if err != io.EOF { + if debugLog && err != io.EOF { log.Println("rpc:", err) } if !keepReading { @@ -560,20 +562,23 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt // we can still recover and move on to the next request. keepReading = true - serviceMethod := strings.Split(req.ServiceMethod, ".") - if len(serviceMethod) != 2 { + dot := strings.LastIndex(req.ServiceMethod, ".") + if dot < 0 { err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod) return } + serviceName := req.ServiceMethod[:dot] + methodName := req.ServiceMethod[dot+1:] + // Look up the request. server.mu.RLock() - service = server.serviceMap[serviceMethod[0]] + service = server.serviceMap[serviceName] server.mu.RUnlock() if service == nil { err = errors.New("rpc: can't find service " + req.ServiceMethod) return } - mtype = service.method[serviceMethod[1]] + mtype = service.method[methodName] if mtype == nil { err = errors.New("rpc: can't find method " + req.ServiceMethod) } @@ -612,6 +617,7 @@ func RegisterName(name string, rcvr interface{}) error { type ServerCodec interface { ReadRequestHeader(*Request) error ReadRequestBody(interface{}) error + // WriteResponse must be safe for concurrent use by multiple goroutines. WriteResponse(*Response, interface{}) error Close() error diff --git a/libgo/go/net/rpc/server_test.go b/libgo/go/net/rpc/server_test.go index eb17210abc9..3b9a88380cf 100644 --- a/libgo/go/net/rpc/server_test.go +++ b/libgo/go/net/rpc/server_test.go @@ -84,6 +84,7 @@ func listenTCP() (net.Listener, string) { func startServer() { Register(new(Arith)) + RegisterName("net.rpc.Arith", new(Arith)) var l net.Listener l, serverAddr = listenTCP() @@ -97,11 +98,13 @@ func startServer() { func startNewServer() { newServer = NewServer() newServer.Register(new(Arith)) + newServer.RegisterName("net.rpc.Arith", new(Arith)) + newServer.RegisterName("newServer.Arith", new(Arith)) var l net.Listener l, newServerAddr = listenTCP() log.Println("NewServer test RPC server listening on", newServerAddr) - go Accept(l) + go newServer.Accept(l) newServer.HandleHTTP(newHttpPath, "/bar") httpOnce.Do(startHttpServer) @@ -118,6 +121,7 @@ func TestRPC(t *testing.T) { testRPC(t, serverAddr) newOnce.Do(startNewServer) testRPC(t, newServerAddr) + testNewServerRPC(t, newServerAddr) } func testRPC(t *testing.T, addr string) { @@ -125,6 +129,7 @@ func testRPC(t *testing.T, addr string) { if err != nil { t.Fatal("dialing", err) } + defer client.Close() // Synchronous calls args := &Args{7, 8} @@ -233,6 +238,36 @@ func testRPC(t *testing.T, addr string) { if reply.C != args.A*args.B { t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) } + + // ServiceName contain "." character + args = &Args{7, 8} + reply = new(Reply) + err = client.Call("net.rpc.Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } +} + +func testNewServerRPC(t *testing.T, addr string) { + client, err := Dial("tcp", addr) + if err != nil { + t.Fatal("dialing", err) + } + defer client.Close() + + // Synchronous calls + args := &Args{7, 8} + reply := new(Reply) + err = client.Call("newServer.Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } } func TestHTTP(t *testing.T) { @@ -253,6 +288,7 @@ func testHTTPRPC(t *testing.T, path string) { if err != nil { t.Fatal("dialing", err) } + defer client.Close() // Synchronous calls args := &Args{7, 8} @@ -329,6 +365,7 @@ func TestServeRequest(t *testing.T) { func testServeRequest(t *testing.T, server *Server) { client := CodecEmulator{server: server} + defer client.Close() args := &Args{7, 8} reply := new(Reply) @@ -411,6 +448,7 @@ func (WriteFailCodec) Close() error { func TestSendDeadlock(t *testing.T) { client := NewClientWithCodec(WriteFailCodec(0)) + defer client.Close() done := make(chan bool) go func() { @@ -449,6 +487,8 @@ func countMallocs(dial func() (*Client, error), t *testing.T) float64 { if err != nil { t.Fatal("error dialing", err) } + defer client.Close() + args := &Args{7, 8} reply := new(Reply) return testing.AllocsPerRun(100, func() { @@ -463,6 +503,9 @@ func countMallocs(dial func() (*Client, error), t *testing.T) float64 { } func TestCountMallocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping malloc count in short mode") + } if runtime.GOMAXPROCS(0) > 1 { t.Skip("skipping; GOMAXPROCS>1") } @@ -470,6 +513,9 @@ func TestCountMallocs(t *testing.T) { } func TestCountMallocsOverHTTP(t *testing.T) { + if testing.Short() { + t.Skip("skipping malloc count in short mode") + } if runtime.GOMAXPROCS(0) > 1 { t.Skip("skipping; GOMAXPROCS>1") } @@ -496,6 +542,8 @@ func (writeCrasher) Write(p []byte) (int, error) { func TestClientWriteError(t *testing.T) { w := &writeCrasher{done: make(chan bool)} c := NewClient(w) + defer c.Close() + res := false err := c.Call("foo", 1, &res) if err == nil { @@ -552,6 +600,7 @@ func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { if err != nil { b.Fatal("error dialing:", err) } + defer client.Close() // Synchronous calls args := &Args{7, 8} @@ -587,6 +636,7 @@ func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) { if err != nil { b.Fatal("error dialing:", err) } + defer client.Close() // Asynchronous calls args := &Args{7, 8} diff --git a/libgo/go/net/sendfile_dragonfly.go b/libgo/go/net/sendfile_dragonfly.go new file mode 100644 index 00000000000..a2219c16337 --- /dev/null +++ b/libgo/go/net/sendfile_dragonfly.go @@ -0,0 +1,103 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "io" + "os" + "syscall" +) + +// maxSendfileSize is the largest chunk size we ask the kernel to copy +// at a time. +const maxSendfileSize int = 4 << 20 + +// sendFile copies the contents of r to c using the sendfile +// system call to minimize copies. +// +// if handled == true, sendFile returns the number of bytes copied and any +// non-EOF error. +// +// if handled == false, sendFile performed no work. +func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { + // DragonFly uses 0 as the "until EOF" value. If you pass in more bytes than the + // file contains, it will loop back to the beginning ad nauseum until it's sent + // exactly the number of bytes told to. As such, we need to know exactly how many + // bytes to send. + var remain int64 = 0 + + lr, ok := r.(*io.LimitedReader) + if ok { + remain, r = lr.N, lr.R + if remain <= 0 { + return 0, nil, true + } + } + f, ok := r.(*os.File) + if !ok { + return 0, nil, false + } + + if remain == 0 { + fi, err := f.Stat() + if err != nil { + return 0, err, false + } + + remain = fi.Size() + } + + // The other quirk with DragonFly's sendfile implementation is that it doesn't + // use the current position of the file -- if you pass it offset 0, it starts + // from offset 0. There's no way to tell it "start from current position", so + // we have to manage that explicitly. + pos, err := f.Seek(0, os.SEEK_CUR) + if err != nil { + return 0, err, false + } + + if err := c.writeLock(); err != nil { + return 0, err, true + } + defer c.writeUnlock() + + dst := c.sysfd + src := int(f.Fd()) + for remain > 0 { + n := maxSendfileSize + if int64(n) > remain { + n = int(remain) + } + pos1 := pos + n, err1 := syscall.Sendfile(dst, src, &pos1, n) + if n > 0 { + pos += int64(n) + written += int64(n) + remain -= int64(n) + } + if n == 0 && err1 == nil { + break + } + if err1 == syscall.EAGAIN { + if err1 = c.pd.WaitWrite(); err1 == nil { + continue + } + } + if err1 == syscall.EINTR { + continue + } + if err1 != nil { + // This includes syscall.ENOSYS (no kernel + // support) and syscall.EINVAL (fd types which + // don't implement sendfile together) + err = &OpError{"sendfile", c.net, c.raddr, err1} + break + } + } + if lr != nil { + lr.N = remain + } + return written, err, written > 0 +} diff --git a/libgo/go/net/sendfile_freebsd.go b/libgo/go/net/sendfile_freebsd.go index dc5b767557b..42fe799efbd 100644 --- a/libgo/go/net/sendfile_freebsd.go +++ b/libgo/go/net/sendfile_freebsd.go @@ -58,12 +58,10 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { return 0, err, false } - c.wio.Lock() - defer c.wio.Unlock() - if err := c.incref(false); err != nil { + if err := c.writeLock(); err != nil { return 0, err, true } - defer c.decref() + defer c.writeUnlock() dst := c.sysfd src := int(f.Fd()) diff --git a/libgo/go/net/sendfile_linux.go b/libgo/go/net/sendfile_linux.go index 6f1323b3dcd..5e117636a80 100644 --- a/libgo/go/net/sendfile_linux.go +++ b/libgo/go/net/sendfile_linux.go @@ -36,12 +36,10 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { return 0, nil, false } - c.wio.Lock() - defer c.wio.Unlock() - if err := c.incref(false); err != nil { + if err := c.writeLock(); err != nil { return 0, err, true } - defer c.decref() + defer c.writeUnlock() dst := c.sysfd src := int(f.Fd()) diff --git a/libgo/go/net/sendfile_windows.go b/libgo/go/net/sendfile_windows.go index 2d64f2f5bff..b128ba27b00 100644 --- a/libgo/go/net/sendfile_windows.go +++ b/libgo/go/net/sendfile_windows.go @@ -10,20 +10,6 @@ import ( "syscall" ) -type sendfileOp struct { - anOp - src syscall.Handle // source - n uint32 -} - -func (o *sendfileOp) Submit() (err error) { - return syscall.TransmitFile(o.fd.sysfd, o.src, o.n, 0, &o.o, nil, syscall.TF_WRITE_BEHIND) -} - -func (o *sendfileOp) Name() string { - return "TransmitFile" -} - // sendFile copies the contents of r to c using the TransmitFile // system call to minimize copies. // @@ -33,7 +19,7 @@ func (o *sendfileOp) Name() string { // if handled == false, sendFile performed no work. // // Note that sendfile for windows does not suppport >2GB file. -func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { +func sendFile(fd *netFD, r io.Reader) (written int64, err error, handled bool) { var n int64 = 0 // by default, copy until EOF lr, ok := r.(*io.LimitedReader) @@ -48,18 +34,17 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { return 0, nil, false } - if err := c.incref(false); err != nil { + if err := fd.writeLock(); err != nil { return 0, err, true } - defer c.decref() - c.wio.Lock() - defer c.wio.Unlock() - - var o sendfileOp - o.Init(c, 'w') - o.n = uint32(n) - o.src = syscall.Handle(f.Fd()) - done, err := iosrv.ExecIO(&o, 0) + defer fd.writeUnlock() + + o := &fd.wop + o.qty = uint32(n) + o.handle = syscall.Handle(f.Fd()) + done, err := wsrv.ExecIO(o, "TransmitFile", func(o *operation) error { + return syscall.TransmitFile(o.fd.sysfd, o.handle, o.qty, 0, &o.o, nil, syscall.TF_WRITE_BEHIND) + }) if err != nil { return 0, err, false } diff --git a/libgo/go/net/singleflight.go b/libgo/go/net/singleflight.go new file mode 100644 index 00000000000..dc58affdaac --- /dev/null +++ b/libgo/go/net/singleflight.go @@ -0,0 +1,53 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import "sync" + +// call is an in-flight or completed singleflight.Do call +type call struct { + wg sync.WaitGroup + val interface{} + err error + dups int +} + +// singleflight represents a class of work and forms a namespace in +// which units of work can be executed with duplicate suppression. +type singleflight struct { + mu sync.Mutex // protects m + m map[string]*call // lazily initialized +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +// The return value shared indicates whether v was given to multiple callers. +func (g *singleflight) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + c.dups++ + g.mu.Unlock() + c.wg.Wait() + return c.val, c.err, true + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + c.val, c.err = fn() + c.wg.Done() + + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() + + return c.val, c.err, c.dups > 0 +} diff --git a/libgo/go/net/smtp/smtp.go b/libgo/go/net/smtp/smtp.go index 4b917787701..a0a478a8524 100644 --- a/libgo/go/net/smtp/smtp.go +++ b/libgo/go/net/smtp/smtp.go @@ -41,12 +41,13 @@ type Client struct { } // Dial returns a new Client connected to an SMTP server at addr. +// The addr must include a port number. func Dial(addr string) (*Client, error) { conn, err := net.Dial("tcp", addr) if err != nil { return nil, err } - host := addr[:strings.Index(addr, ":")] + host, _, _ := net.SplitHostPort(addr) return NewClient(conn, host) } @@ -63,6 +64,11 @@ func NewClient(conn net.Conn, host string) (*Client, error) { return c, nil } +// Close closes the connection. +func (c *Client) Close() error { + return c.Text.Close() +} + // hello runs a hello exchange if needed. func (c *Client) hello() error { if !c.didHello { @@ -190,7 +196,9 @@ func (c *Client) Auth(a Auth) error { default: err = &textproto.Error{Code: code, Msg: msg64} } - resp, err = a.Next(msg, code == 334) + if err == nil { + resp, err = a.Next(msg, code == 334) + } if err != nil { // abort the AUTH c.cmd(501, "*") @@ -256,15 +264,17 @@ func (c *Client) Data() (io.WriteCloser, error) { return &dataCloser{c, c.Text.DotWriter()}, nil } -// SendMail connects to the server at addr, switches to TLS if possible, -// authenticates with mechanism a if possible, and then sends an email from -// address from, to addresses to, with message msg. +// SendMail connects to the server at addr, switches to TLS if +// possible, authenticates with the optional mechanism a if possible, +// and then sends an email from address from, to addresses to, with +// message msg. func SendMail(addr string, a Auth, from string, to []string, msg []byte) error { c, err := Dial(addr) if err != nil { return err } - if err := c.hello(); err != nil { + defer c.Close() + if err = c.hello(); err != nil { return err } if ok, _ := c.Extension("STARTTLS"); ok { diff --git a/libgo/go/net/smtp/smtp_test.go b/libgo/go/net/smtp/smtp_test.go index c190b32c054..2133dc7c7ba 100644 --- a/libgo/go/net/smtp/smtp_test.go +++ b/libgo/go/net/smtp/smtp_test.go @@ -238,6 +238,7 @@ func TestNewClient(t *testing.T) { if err != nil { t.Fatalf("NewClient: %v\n(after %v)", err, out()) } + defer c.Close() if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" { t.Fatalf("Expected AUTH supported") } @@ -278,6 +279,7 @@ func TestNewClient2(t *testing.T) { if err != nil { t.Fatalf("NewClient: %v", err) } + defer c.Close() if ok, _ := c.Extension("DSN"); ok { t.Fatalf("Shouldn't support DSN") } @@ -323,6 +325,7 @@ func TestHello(t *testing.T) { if err != nil { t.Fatalf("NewClient: %v", err) } + defer c.Close() c.localName = "customhost" err = nil @@ -501,3 +504,47 @@ SendMail is working for me. . QUIT ` + +func TestAuthFailed(t *testing.T) { + server := strings.Join(strings.Split(authFailedServer, "\n"), "\r\n") + client := strings.Join(strings.Split(authFailedClient, "\n"), "\r\n") + var cmdbuf bytes.Buffer + bcmdbuf := bufio.NewWriter(&cmdbuf) + var fake faker + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) + c, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("NewClient: %v", err) + } + defer c.Close() + + c.tls = true + c.serverName = "smtp.google.com" + err = c.Auth(PlainAuth("", "user", "pass", "smtp.google.com")) + + if err == nil { + t.Error("Auth: expected error; got none") + } else if err.Error() != "535 Invalid credentials\nplease see www.example.com" { + t.Errorf("Auth: got error: %v, want: %s", err, "535 Invalid credentials\nplease see www.example.com") + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + if client != actualcmds { + t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } +} + +var authFailedServer = `220 hello world +250-mx.google.com at your service +250 AUTH LOGIN PLAIN +535-Invalid credentials +535 please see www.example.com +221 Goodbye +` + +var authFailedClient = `EHLO localhost +AUTH PLAIN AHVzZXIAcGFzcw== +* +QUIT +` diff --git a/libgo/go/net/sock_bsd.go b/libgo/go/net/sock_bsd.go index d99349265eb..6c37109f5e4 100644 --- a/libgo/go/net/sock_bsd.go +++ b/libgo/go/net/sock_bsd.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd netbsd openbsd +// +build darwin dragonfly freebsd netbsd openbsd package net diff --git a/libgo/go/net/sock_plan9.go b/libgo/go/net/sock_plan9.go new file mode 100644 index 00000000000..88d9ed15cf1 --- /dev/null +++ b/libgo/go/net/sock_plan9.go @@ -0,0 +1,10 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +func maxListenerBacklog() int { + // /sys/include/ape/sys/socket.h:/SOMAXCONN + return 5 +} diff --git a/libgo/go/net/sock_posix.go b/libgo/go/net/sock_posix.go index be89c26db2a..c2d343c5858 100644 --- a/libgo/go/net/sock_posix.go +++ b/libgo/go/net/sock_posix.go @@ -2,78 +2,197 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux netbsd openbsd windows package net import ( + "os" "syscall" "time" ) -var listenerBacklog = maxListenerBacklog() +// A sockaddr represents a TCP, UDP, IP or Unix network endpoint +// address that can be converted into a syscall.Sockaddr. +type sockaddr interface { + Addr -// Generic POSIX socket creation. -func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, deadline time.Time, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { - s, err := sysSocket(f, t, p) + netaddr + + // family returns the platform-dependent address family + // identifier. + family() int + + // isWildcard reports whether the address is a wildcard + // address. + isWildcard() bool + + // sockaddr returns the address converted into a syscall + // sockaddr type that implements syscall.Sockaddr + // interface. It returns a nil interface when the address is + // nil. + sockaddr(family int) (syscall.Sockaddr, error) +} + +// socket returns a network file descriptor that is ready for +// asynchronous I/O using the network poller. +func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, deadline time.Time, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { + s, err := sysSocket(family, sotype, proto) if err != nil { return nil, err } - - if err = setDefaultSockopts(s, f, t, ipv6only); err != nil { + if err = setDefaultSockopts(s, family, sotype, ipv6only); err != nil { closesocket(s) return nil, err } - - // This socket is used by a listener. - if ulsa != nil && ursa == nil { - // We provide a socket that listens to a wildcard - // address with reusable UDP port when the given ulsa - // is an appropriate UDP multicast address prefix. - // This makes it possible for a single UDP listener - // to join multiple different group addresses, for - // multiple UDP listeners that listen on the same UDP - // port to join the same group address. - if ulsa, err = listenerSockaddr(s, f, ulsa, toAddr); err != nil { - closesocket(s) - return nil, err - } + if fd, err = newFD(s, family, sotype, net); err != nil { + closesocket(s) + return nil, err } - if ulsa != nil { - if err = syscall.Bind(s, ulsa); err != nil { - closesocket(s) - return nil, err + // This function makes a network file descriptor for the + // following applications: + // + // - An endpoint holder that opens a passive stream + // connenction, known as a stream listener + // + // - An endpoint holder that opens a destination-unspecific + // datagram connection, known as a datagram listener + // + // - An endpoint holder that opens an active stream or a + // destination-specific datagram connection, known as a + // dialer + // + // - An endpoint holder that opens the other connection, such + // as talking to the protocol stack inside the kernel + // + // For stream and datagram listeners, they will only require + // named sockets, so we can assume that it's just a request + // from stream or datagram listeners when laddr is not nil but + // raddr is nil. Otherwise we assume it's just for dialers or + // the other connection holders. + + if laddr != nil && raddr == nil { + switch sotype { + case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET: + if err := fd.listenStream(laddr, listenerBacklog, toAddr); err != nil { + fd.Close() + return nil, err + } + return fd, nil + case syscall.SOCK_DGRAM: + if err := fd.listenDatagram(laddr, toAddr); err != nil { + fd.Close() + return nil, err + } + return fd, nil } } - - if fd, err = newFD(s, f, t, net); err != nil { - closesocket(s) + if err := fd.dial(laddr, raddr, deadline, toAddr); err != nil { + fd.Close() return nil, err } + return fd, nil +} - // This socket is used by a dialer. - if ursa != nil { - if !deadline.IsZero() { - setWriteDeadline(fd, deadline) +func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, toAddr func(syscall.Sockaddr) Addr) error { + var err error + var lsa syscall.Sockaddr + if laddr != nil { + if lsa, err = laddr.sockaddr(fd.family); err != nil { + return err + } else if lsa != nil { + if err := syscall.Bind(fd.sysfd, lsa); err != nil { + return os.NewSyscallError("bind", err) + } } - if err = fd.connect(ulsa, ursa); err != nil { - fd.Close() - return nil, err + } + if err := fd.init(); err != nil { + return err + } + var rsa syscall.Sockaddr + if raddr != nil { + if rsa, err = raddr.sockaddr(fd.family); err != nil { + return err + } else if rsa != nil { + if !deadline.IsZero() { + fd.setWriteDeadline(deadline) + } + if err := fd.connect(lsa, rsa); err != nil { + return err + } + fd.isConnected = true + if !deadline.IsZero() { + fd.setWriteDeadline(noDeadline) + } } - fd.isConnected = true - if !deadline.IsZero() { - setWriteDeadline(fd, time.Time{}) + } + lsa, _ = syscall.Getsockname(fd.sysfd) + if rsa, _ = syscall.Getpeername(fd.sysfd); rsa != nil { + fd.setAddr(toAddr(lsa), toAddr(rsa)) + } else { + fd.setAddr(toAddr(lsa), raddr) + } + return nil +} + +func (fd *netFD) listenStream(laddr sockaddr, backlog int, toAddr func(syscall.Sockaddr) Addr) error { + if err := setDefaultListenerSockopts(fd.sysfd); err != nil { + return err + } + if lsa, err := laddr.sockaddr(fd.family); err != nil { + return err + } else if lsa != nil { + if err := syscall.Bind(fd.sysfd, lsa); err != nil { + return os.NewSyscallError("bind", err) } } + if err := syscall.Listen(fd.sysfd, backlog); err != nil { + return os.NewSyscallError("listen", err) + } + if err := fd.init(); err != nil { + return err + } + lsa, _ := syscall.Getsockname(fd.sysfd) + fd.setAddr(toAddr(lsa), nil) + return nil +} - lsa, _ := syscall.Getsockname(s) - laddr := toAddr(lsa) - rsa, _ := syscall.Getpeername(s) - if rsa == nil { - rsa = ursa - } - raddr := toAddr(rsa) - fd.setAddr(laddr, raddr) - return fd, nil +func (fd *netFD) listenDatagram(laddr sockaddr, toAddr func(syscall.Sockaddr) Addr) error { + switch addr := laddr.(type) { + case *UDPAddr: + // We provide a socket that listens to a wildcard + // address with reusable UDP port when the given laddr + // is an appropriate UDP multicast address prefix. + // This makes it possible for a single UDP listener to + // join multiple different group addresses, for + // multiple UDP listeners that listen on the same UDP + // port to join the same group address. + if addr.IP != nil && addr.IP.IsMulticast() { + if err := setDefaultMulticastSockopts(fd.sysfd); err != nil { + return err + } + addr := *addr + switch fd.family { + case syscall.AF_INET: + addr.IP = IPv4zero + case syscall.AF_INET6: + addr.IP = IPv6unspecified + } + laddr = &addr + } + } + if lsa, err := laddr.sockaddr(fd.family); err != nil { + return err + } else if lsa != nil { + if err := syscall.Bind(fd.sysfd, lsa); err != nil { + return os.NewSyscallError("bind", err) + } + } + if err := fd.init(); err != nil { + return err + } + lsa, _ := syscall.Getsockname(fd.sysfd) + fd.setAddr(toAddr(lsa), nil) + return nil } diff --git a/libgo/go/net/sock_unix.go b/libgo/go/net/sock_unix.go deleted file mode 100644 index b0d6d4900f2..00000000000 --- a/libgo/go/net/sock_unix.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build darwin freebsd linux netbsd openbsd - -package net - -import "syscall" - -func listenerSockaddr(s, f int, la syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (syscall.Sockaddr, error) { - a := toAddr(la) - if a == nil { - return la, nil - } - switch a := a.(type) { - case *TCPAddr, *UnixAddr: - if err := setDefaultListenerSockopts(s); err != nil { - return nil, err - } - case *UDPAddr: - if a.IP.IsMulticast() { - if err := setDefaultMulticastSockopts(s); err != nil { - return nil, err - } - switch f { - case syscall.AF_INET: - a.IP = IPv4zero - case syscall.AF_INET6: - a.IP = IPv6unspecified - } - return a.sockaddr(f) - } - } - return la, nil -} diff --git a/libgo/go/net/sock_windows.go b/libgo/go/net/sock_windows.go index 41368d39e81..6ccde3a24b9 100644 --- a/libgo/go/net/sock_windows.go +++ b/libgo/go/net/sock_windows.go @@ -12,33 +12,6 @@ func maxListenerBacklog() int { return syscall.SOMAXCONN } -func listenerSockaddr(s syscall.Handle, f int, la syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (syscall.Sockaddr, error) { - a := toAddr(la) - if a == nil { - return la, nil - } - switch a := a.(type) { - case *TCPAddr, *UnixAddr: - if err := setDefaultListenerSockopts(s); err != nil { - return nil, err - } - case *UDPAddr: - if a.IP.IsMulticast() { - if err := setDefaultMulticastSockopts(s); err != nil { - return nil, err - } - switch f { - case syscall.AF_INET: - a.IP = IPv4zero - case syscall.AF_INET6: - a.IP = IPv6unspecified - } - return a.sockaddr(f) - } - } - return la, nil -} - func sysSocket(f, t, p int) (syscall.Handle, error) { // See ../syscall/exec_unix.go for description of ForkLock. syscall.ForkLock.RLock() diff --git a/libgo/go/net/sockopt_bsd.go b/libgo/go/net/sockopt_bsd.go index af88814b4b9..4b9c2f9afbe 100644 --- a/libgo/go/net/sockopt_bsd.go +++ b/libgo/go/net/sockopt_bsd.go @@ -2,9 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd netbsd openbsd - -// Socket options for BSD variants +// +build darwin dragonfly freebsd netbsd openbsd package net @@ -13,40 +11,26 @@ import ( "syscall" ) -func setDefaultSockopts(s, f, t int, ipv6only bool) error { - switch f { - case syscall.AF_INET6: - if ipv6only { - syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 1) - } else { - // Allow both IP versions even if the OS default - // is otherwise. Note that some operating systems - // never admit this option. - syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) - } +func setDefaultSockopts(s, family, sotype int, ipv6only bool) error { + if family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW { + // Allow both IP versions even if the OS default + // is otherwise. Note that some operating systems + // never admit this option. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only)) } // Allow broadcast. - err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)) } func setDefaultListenerSockopts(s int) error { // Allow reuse of recently-used addresses. - err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)) } func setDefaultMulticastSockopts(s int) error { // Allow multicast UDP and raw IP datagram sockets to listen // concurrently across multiple listeners. - err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - if err != nil { + if err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { return os.NewSyscallError("setsockopt", err) } // Allow reuse of recently-used ports. @@ -54,10 +38,7 @@ func setDefaultMulticastSockopts(s int) error { // to make an effective multicast application that requires // quick draw possible. if syscall.SO_REUSEPORT != 0 { - err = syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1)) } return nil } diff --git a/libgo/go/net/sockopt_linux.go b/libgo/go/net/sockopt_linux.go index 0f47538c541..54c20b1409b 100644 --- a/libgo/go/net/sockopt_linux.go +++ b/libgo/go/net/sockopt_linux.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Socket options for Linux - package net import ( @@ -11,41 +9,24 @@ import ( "syscall" ) -func setDefaultSockopts(s, f, t int, ipv6only bool) error { - switch f { - case syscall.AF_INET6: - if ipv6only { - syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 1) - } else { - // Allow both IP versions even if the OS default - // is otherwise. Note that some operating systems - // never admit this option. - syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) - } +func setDefaultSockopts(s, family, sotype int, ipv6only bool) error { + if family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW { + // Allow both IP versions even if the OS default + // is otherwise. Note that some operating systems + // never admit this option. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only)) } // Allow broadcast. - err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)) } func setDefaultListenerSockopts(s int) error { // Allow reuse of recently-used addresses. - err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)) } func setDefaultMulticastSockopts(s int) error { // Allow multicast UDP and raw IP datagram sockets to listen // concurrently across multiple listeners. - err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)) } diff --git a/libgo/go/net/sockopt_posix.go b/libgo/go/net/sockopt_posix.go index 1590f4e98de..ff3bc689940 100644 --- a/libgo/go/net/sockopt_posix.go +++ b/libgo/go/net/sockopt_posix.go @@ -2,9 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd windows - -// Socket options +// +build darwin dragonfly freebsd linux netbsd openbsd windows package net @@ -103,7 +101,7 @@ done: } func setReadBuffer(fd *netFD, bytes int) error { - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() @@ -111,7 +109,7 @@ func setReadBuffer(fd *netFD, bytes int) error { } func setWriteBuffer(fd *netFD, bytes int) error { - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() @@ -119,21 +117,13 @@ func setWriteBuffer(fd *netFD, bytes int) error { } func setKeepAlive(fd *netFD, keepalive bool) error { - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive))) } -func setNoDelay(fd *netFD, noDelay bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay))) -} - func setLinger(fd *netFD, sec int) error { var l syscall.Linger if sec >= 0 { @@ -143,7 +133,7 @@ func setLinger(fd *netFD, sec int) error { l.Onoff = 0 l.Linger = 0 } - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() diff --git a/libgo/go/net/sockopt_windows.go b/libgo/go/net/sockopt_windows.go index 0861fe8f4bf..cb64a40c695 100644 --- a/libgo/go/net/sockopt_windows.go +++ b/libgo/go/net/sockopt_windows.go @@ -2,27 +2,19 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Socket options for Windows - package net import ( "os" "syscall" - "time" ) -func setDefaultSockopts(s syscall.Handle, f, t int, ipv6only bool) error { - switch f { - case syscall.AF_INET6: - if ipv6only { - syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 1) - } else { - // Allow both IP versions even if the OS default - // is otherwise. Note that some operating systems - // never admit this option. - syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) - } +func setDefaultSockopts(s syscall.Handle, family, sotype int, ipv6only bool) error { + if family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW { + // Allow both IP versions even if the OS default + // is otherwise. Note that some operating systems + // never admit this option. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only)) } // Allow broadcast. syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) @@ -42,27 +34,5 @@ func setDefaultListenerSockopts(s syscall.Handle) error { func setDefaultMulticastSockopts(s syscall.Handle) error { // Allow multicast UDP and raw IP datagram sockets to listen // concurrently across multiple listeners. - err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil -} - -// TODO(dfc) these unused error returns could be removed - -func setReadDeadline(fd *netFD, t time.Time) error { - fd.rdeadline.setTime(t) - return nil -} - -func setWriteDeadline(fd *netFD, t time.Time) error { - fd.wdeadline.setTime(t) - return nil -} - -func setDeadline(fd *netFD, t time.Time) error { - setReadDeadline(fd, t) - setWriteDeadline(fd, t) - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)) } diff --git a/libgo/go/net/sockoptip_bsd.go b/libgo/go/net/sockoptip_bsd.go index 263f8552176..2199e480d42 100644 --- a/libgo/go/net/sockoptip_bsd.go +++ b/libgo/go/net/sockoptip_bsd.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd netbsd openbsd +// +build darwin dragonfly freebsd netbsd openbsd package net @@ -18,25 +18,17 @@ func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { } var a [4]byte copy(a[:], ip.To4()) - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() - err = syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, a) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, a)) } func setIPv4MulticastLoopback(fd *netFD, v bool) error { - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() - err := syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, byte(boolint(v))) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, byte(boolint(v)))) } diff --git a/libgo/go/net/sockoptip_linux.go b/libgo/go/net/sockoptip_linux.go index 225fb0c4c6c..a69b778e4d1 100644 --- a/libgo/go/net/sockoptip_linux.go +++ b/libgo/go/net/sockoptip_linux.go @@ -15,25 +15,17 @@ func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { v = int32(ifi.Index) } mreq := &syscall.IPMreqn{Ifindex: v} - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() - err := syscall.SetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq)) } func setIPv4MulticastLoopback(fd *netFD, v bool) error { - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v)) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))) } diff --git a/libgo/go/net/sockoptip_posix.go b/libgo/go/net/sockoptip_posix.go index e4c56a0e4b2..c2579be9114 100644 --- a/libgo/go/net/sockoptip_posix.go +++ b/libgo/go/net/sockoptip_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux netbsd openbsd windows package net @@ -16,15 +16,11 @@ func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error { if err := setIPv4MreqToInterface(mreq, ifi); err != nil { return err } - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() - err := syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq)) } func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error { @@ -32,27 +28,19 @@ func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error { if ifi != nil { v = ifi.Index } - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v)) } func setIPv6MulticastLoopback(fd *netFD, v bool) error { - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() - err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v)) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v))) } func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error { @@ -61,13 +49,9 @@ func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error { if ifi != nil { mreq.Interface = uint32(ifi.Index) } - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() - err := syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq)) } diff --git a/libgo/go/net/sockoptip_windows.go b/libgo/go/net/sockoptip_windows.go index 3e248441ab3..7b11f207aaf 100644 --- a/libgo/go/net/sockoptip_windows.go +++ b/libgo/go/net/sockoptip_windows.go @@ -17,26 +17,17 @@ func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { } var a [4]byte copy(a[:], ip.To4()) - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() - err = syscall.Setsockopt(fd.sysfd, int32(syscall.IPPROTO_IP), int32(syscall.IP_MULTICAST_IF), (*byte)(unsafe.Pointer(&a[0])), 4) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.Setsockopt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, (*byte)(unsafe.Pointer(&a[0])), 4)) } func setIPv4MulticastLoopback(fd *netFD, v bool) error { - if err := fd.incref(false); err != nil { + if err := fd.incref(); err != nil { return err } defer fd.decref() - vv := int32(boolint(v)) - err := syscall.Setsockopt(fd.sysfd, int32(syscall.IPPROTO_IP), int32(syscall.IP_MULTICAST_LOOP), (*byte)(unsafe.Pointer(&vv)), 4) - if err != nil { - return os.NewSyscallError("setsockopt", err) - } - return nil + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))) } diff --git a/libgo/go/net/sys_cloexec.go b/libgo/go/net/sys_cloexec.go index 17e8749087d..bbfcc1a4fc4 100644 --- a/libgo/go/net/sys_cloexec.go +++ b/libgo/go/net/sys_cloexec.go @@ -5,7 +5,7 @@ // This file implements sysSocket and accept for platforms that do not // provide a fast path for setting SetNonblock and CloseOnExec. -// +build darwin freebsd netbsd openbsd +// +build darwin dragonfly freebsd netbsd openbsd package net diff --git a/libgo/go/net/tcp_test.go b/libgo/go/net/tcp_test.go index a71b02b4774..62fd99f5c0b 100644 --- a/libgo/go/net/tcp_test.go +++ b/libgo/go/net/tcp_test.go @@ -6,8 +6,10 @@ package net import ( "fmt" + "io" "reflect" "runtime" + "sync" "testing" "time" ) @@ -59,7 +61,7 @@ func BenchmarkTCP6PersistentTimeout(b *testing.B) { func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) { const msgLen = 512 conns := b.N - numConcurrent := runtime.GOMAXPROCS(-1) * 16 + numConcurrent := runtime.GOMAXPROCS(-1) * 2 msgs := 1 if persistent { conns = numConcurrent @@ -147,11 +149,134 @@ func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) { } } +func BenchmarkTCP4ConcurrentReadWrite(b *testing.B) { + benchmarkTCPConcurrentReadWrite(b, "127.0.0.1:0") +} + +func BenchmarkTCP6ConcurrentReadWrite(b *testing.B) { + if !supportsIPv6 { + b.Skip("ipv6 is not supported") + } + benchmarkTCPConcurrentReadWrite(b, "[::1]:0") +} + +func benchmarkTCPConcurrentReadWrite(b *testing.B, laddr string) { + // The benchmark creates GOMAXPROCS client/server pairs. + // Each pair creates 4 goroutines: client reader/writer and server reader/writer. + // The benchmark stresses concurrent reading and writing to the same connection. + // Such pattern is used in net/http and net/rpc. + + b.StopTimer() + + P := runtime.GOMAXPROCS(0) + N := b.N / P + W := 1000 + + // Setup P client/server connections. + clients := make([]Conn, P) + servers := make([]Conn, P) + ln, err := Listen("tcp", laddr) + if err != nil { + b.Fatalf("Listen failed: %v", err) + } + defer ln.Close() + done := make(chan bool) + go func() { + for p := 0; p < P; p++ { + s, err := ln.Accept() + if err != nil { + b.Fatalf("Accept failed: %v", err) + } + servers[p] = s + } + done <- true + }() + for p := 0; p < P; p++ { + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + b.Fatalf("Dial failed: %v", err) + } + clients[p] = c + } + <-done + + b.StartTimer() + + var wg sync.WaitGroup + wg.Add(4 * P) + for p := 0; p < P; p++ { + // Client writer. + go func(c Conn) { + defer wg.Done() + var buf [1]byte + for i := 0; i < N; i++ { + v := byte(i) + for w := 0; w < W; w++ { + v *= v + } + buf[0] = v + _, err := c.Write(buf[:]) + if err != nil { + b.Fatalf("Write failed: %v", err) + } + } + }(clients[p]) + + // Pipe between server reader and server writer. + pipe := make(chan byte, 128) + + // Server reader. + go func(s Conn) { + defer wg.Done() + var buf [1]byte + for i := 0; i < N; i++ { + _, err := s.Read(buf[:]) + if err != nil { + b.Fatalf("Read failed: %v", err) + } + pipe <- buf[0] + } + }(servers[p]) + + // Server writer. + go func(s Conn) { + defer wg.Done() + var buf [1]byte + for i := 0; i < N; i++ { + v := <-pipe + for w := 0; w < W; w++ { + v *= v + } + buf[0] = v + _, err := s.Write(buf[:]) + if err != nil { + b.Fatalf("Write failed: %v", err) + } + } + s.Close() + }(servers[p]) + + // Client reader. + go func(c Conn) { + defer wg.Done() + var buf [1]byte + for i := 0; i < N; i++ { + _, err := c.Read(buf[:]) + if err != nil { + b.Fatalf("Read failed: %v", err) + } + } + c.Close() + }(clients[p]) + } + wg.Wait() +} + type resolveTCPAddrTest struct { - net string - litAddr string - addr *TCPAddr - err error + net string + litAddrOrName string + addr *TCPAddr + err error } var resolveTCPAddrTests = []resolveTCPAddrTest{ @@ -167,6 +292,8 @@ var resolveTCPAddrTests = []resolveTCPAddrTest{ {"", "127.0.0.1:0", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, // Go 1.0 behavior {"", "[::1]:0", &TCPAddr{IP: ParseIP("::1"), Port: 0}, nil}, // Go 1.0 behavior + {"tcp", ":12345", &TCPAddr{Port: 12345}, nil}, + {"http", "127.0.0.1:0", nil, UnknownNetworkError("http")}, } @@ -178,16 +305,33 @@ func init() { {"tcp6", "[fe80::1%" + index + "]:4", &TCPAddr{IP: ParseIP("fe80::1"), Port: 4, Zone: index}, nil}, }...) } + if ips, err := LookupIP("localhost"); err == nil && len(ips) > 1 && supportsIPv4 && supportsIPv6 { + resolveTCPAddrTests = append(resolveTCPAddrTests, []resolveTCPAddrTest{ + {"tcp", "localhost:5", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5}, nil}, + {"tcp4", "localhost:6", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 6}, nil}, + {"tcp6", "localhost:7", &TCPAddr{IP: IPv6loopback, Port: 7}, nil}, + }...) + } } func TestResolveTCPAddr(t *testing.T) { for _, tt := range resolveTCPAddrTests { - addr, err := ResolveTCPAddr(tt.net, tt.litAddr) + addr, err := ResolveTCPAddr(tt.net, tt.litAddrOrName) if err != tt.err { - t.Fatalf("ResolveTCPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err) + t.Fatalf("ResolveTCPAddr(%q, %q) failed: %v", tt.net, tt.litAddrOrName, err) } if !reflect.DeepEqual(addr, tt.addr) { - t.Fatalf("got %#v; expected %#v", addr, tt.addr) + t.Fatalf("ResolveTCPAddr(%q, %q) = %#v, want %#v", tt.net, tt.litAddrOrName, addr, tt.addr) + } + if err == nil { + str := addr.String() + addr1, err := ResolveTCPAddr(tt.net, str) + if err != nil { + t.Fatalf("ResolveTCPAddr(%q, %q) [from %q]: %v", tt.net, str, tt.litAddrOrName, err) + } + if !reflect.DeepEqual(addr1, addr) { + t.Fatalf("ResolveTCPAddr(%q, %q) [from %q] = %#v, want %#v", tt.net, str, tt.litAddrOrName, addr1, addr) + } } } } @@ -294,3 +438,153 @@ func TestIPv6LinkLocalUnicastTCP(t *testing.T) { <-done } } + +func TestTCPConcurrentAccept(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4)) + ln, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + const N = 10 + var wg sync.WaitGroup + wg.Add(N) + for i := 0; i < N; i++ { + go func() { + for { + c, err := ln.Accept() + if err != nil { + break + } + c.Close() + } + wg.Done() + }() + } + for i := 0; i < 10*N; i++ { + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + c.Close() + } + ln.Close() + wg.Wait() +} + +func TestTCPReadWriteMallocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping malloc count in short mode") + } + ln, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + defer ln.Close() + var server Conn + errc := make(chan error) + go func() { + var err error + server, err = ln.Accept() + errc <- err + }() + client, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + if err := <-errc; err != nil { + t.Fatalf("Accept failed: %v", err) + } + defer server.Close() + var buf [128]byte + mallocs := testing.AllocsPerRun(1000, func() { + _, err := server.Write(buf[:]) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + _, err = io.ReadFull(client, buf[:]) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + }) + if mallocs > 0 { + t.Fatalf("Got %v allocs, want 0", mallocs) + } +} + +func TestTCPStress(t *testing.T) { + const conns = 2 + const msgLen = 512 + msgs := int(1e4) + if testing.Short() { + msgs = 1e2 + } + + sendMsg := func(c Conn, buf []byte) bool { + n, err := c.Write(buf) + if n != len(buf) || err != nil { + t.Logf("Write failed: %v", err) + return false + } + return true + } + recvMsg := func(c Conn, buf []byte) bool { + for read := 0; read != len(buf); { + n, err := c.Read(buf) + read += n + if err != nil { + t.Logf("Read failed: %v", err) + return false + } + } + return true + } + + ln, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + defer ln.Close() + // Acceptor. + go func() { + for { + c, err := ln.Accept() + if err != nil { + break + } + // Server connection. + go func(c Conn) { + defer c.Close() + var buf [msgLen]byte + for m := 0; m < msgs; m++ { + if !recvMsg(c, buf[:]) || !sendMsg(c, buf[:]) { + break + } + } + }(c) + } + }() + done := make(chan bool) + for i := 0; i < conns; i++ { + // Client connection. + go func() { + defer func() { + done <- true + }() + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Logf("Dial failed: %v", err) + return + } + defer c.Close() + var buf [msgLen]byte + for m := 0; m < msgs; m++ { + if !sendMsg(c, buf[:]) || !recvMsg(c, buf[:]) { + break + } + } + }() + } + for i := 0; i < conns; i++ { + <-done + } +} diff --git a/libgo/go/net/tcpsock.go b/libgo/go/net/tcpsock.go index 4d9ebd214e0..f3dfbd23d34 100644 --- a/libgo/go/net/tcpsock.go +++ b/libgo/go/net/tcpsock.go @@ -18,10 +18,18 @@ func (a *TCPAddr) String() string { if a == nil { return "<nil>" } + ip := ipEmptyString(a.IP) if a.Zone != "" { - return JoinHostPort(a.IP.String()+"%"+a.Zone, itoa(a.Port)) + return JoinHostPort(ip+"%"+a.Zone, itoa(a.Port)) } - return JoinHostPort(a.IP.String(), itoa(a.Port)) + return JoinHostPort(ip, itoa(a.Port)) +} + +func (a *TCPAddr) toAddr() Addr { + if a == nil { + return nil + } + return a } // ResolveTCPAddr parses addr as a TCP address of the form "host:port" @@ -42,5 +50,5 @@ func ResolveTCPAddr(net, addr string) (*TCPAddr, error) { if err != nil { return nil, err } - return a.(*TCPAddr), nil + return a.toAddr().(*TCPAddr), nil } diff --git a/libgo/go/net/tcpsock_plan9.go b/libgo/go/net/tcpsock_plan9.go index 48334fed7e4..cf9c0f89047 100644 --- a/libgo/go/net/tcpsock_plan9.go +++ b/libgo/go/net/tcpsock_plan9.go @@ -65,6 +65,11 @@ func (c *TCPConn) SetKeepAlive(keepalive bool) error { return syscall.EPLAN9 } +// SetKeepAlivePeriod sets period between keep alives. +func (c *TCPConn) SetKeepAlivePeriod(d time.Duration) error { + return syscall.EPLAN9 +} + // SetNoDelay controls whether the operating system should delay // packet transmission in hopes of sending fewer packets (Nagle's // algorithm). The default is true (no delay), meaning that data is @@ -106,7 +111,7 @@ type TCPListener struct { } // AcceptTCP accepts the next incoming call and returns the new -// connection and the remote address. +// connection. func (l *TCPListener) AcceptTCP() (*TCPConn, error) { if l == nil || l.fd == nil || l.fd.ctl == nil { return nil, syscall.EINVAL @@ -153,7 +158,7 @@ func (l *TCPListener) SetDeadline(t time.Time) error { if l == nil || l.fd == nil || l.fd.ctl == nil { return syscall.EINVAL } - return setDeadline(l.fd, t) + return l.fd.setDeadline(t) } // File returns a copy of the underlying os.File, set to blocking diff --git a/libgo/go/net/tcpsock_posix.go b/libgo/go/net/tcpsock_posix.go index 876edb101ca..00c692e4233 100644 --- a/libgo/go/net/tcpsock_posix.go +++ b/libgo/go/net/tcpsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux netbsd openbsd windows package net @@ -46,14 +46,10 @@ func (a *TCPAddr) isWildcard() bool { } func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, error) { - return ipToSockaddr(family, a.IP, a.Port, a.Zone) -} - -func (a *TCPAddr) toAddr() sockaddr { - if a == nil { // nil *TCPAddr - return nil // nil interface + if a == nil { + return nil, nil } - return a + return ipToSockaddr(family, a.IP, a.Port, a.Zone) } // TCPConn is an implementation of the Conn interface for TCP network @@ -121,6 +117,14 @@ func (c *TCPConn) SetKeepAlive(keepalive bool) error { return setKeepAlive(c.fd, keepalive) } +// SetKeepAlivePeriod sets period between keep alives. +func (c *TCPConn) SetKeepAlivePeriod(d time.Duration) error { + if !c.ok() { + return syscall.EINVAL + } + return setKeepAlivePeriod(c.fd, d) +} + // SetNoDelay controls whether the operating system should delay // packet transmission in hopes of sending fewer packets (Nagle's // algorithm). The default is true (no delay), meaning that data is @@ -139,16 +143,16 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) { switch net { case "tcp", "tcp4", "tcp6": default: - return nil, UnknownNetworkError(net) + return nil, &OpError{Op: "dial", Net: net, Addr: raddr, Err: UnknownNetworkError(net)} } if raddr == nil { - return nil, &OpError{"dial", net, nil, errMissingAddress} + return nil, &OpError{Op: "dial", Net: net, Addr: nil, Err: errMissingAddress} } return dialTCP(net, laddr, raddr, noDeadline) } func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) { - fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), deadline, syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP) + fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP) // TCP has a rarely used mechanism called a 'simultaneous connection' in // which Dial("tcp", addr1, addr2) run on the machine at addr1 can @@ -178,11 +182,11 @@ func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, e if err == nil { fd.Close() } - fd, err = internetSocket(net, laddr.toAddr(), raddr.toAddr(), deadline, syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP) + fd, err = internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP) } if err != nil { - return nil, err + return nil, &OpError{Op: "dial", Net: net, Addr: raddr, Err: err} } return newTCPConn(fd), nil } @@ -221,7 +225,7 @@ type TCPListener struct { } // AcceptTCP accepts the next incoming call and returns the new -// connection and the remote address. +// connection. func (l *TCPListener) AcceptTCP() (*TCPConn, error) { if l == nil || l.fd == nil { return nil, syscall.EINVAL @@ -261,7 +265,7 @@ func (l *TCPListener) SetDeadline(t time.Time) error { if l == nil || l.fd == nil { return syscall.EINVAL } - return setDeadline(l.fd, t) + return l.fd.setDeadline(t) } // File returns a copy of the underlying os.File, set to blocking @@ -281,19 +285,14 @@ func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { switch net { case "tcp", "tcp4", "tcp6": default: - return nil, UnknownNetworkError(net) + return nil, &OpError{Op: "listen", Net: net, Addr: laddr, Err: UnknownNetworkError(net)} } if laddr == nil { laddr = &TCPAddr{} } - fd, err := internetSocket(net, laddr.toAddr(), nil, noDeadline, syscall.SOCK_STREAM, 0, "listen", sockaddrToTCP) - if err != nil { - return nil, err - } - err = syscall.Listen(fd.sysfd, listenerBacklog) + fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_STREAM, 0, "listen", sockaddrToTCP) if err != nil { - fd.Close() - return nil, &OpError{"listen", net, laddr, err} + return nil, &OpError{Op: "listen", Net: net, Addr: laddr, Err: err} } return &TCPListener{fd}, nil } diff --git a/libgo/go/net/tcpsockopt_darwin.go b/libgo/go/net/tcpsockopt_darwin.go new file mode 100644 index 00000000000..33140849c95 --- /dev/null +++ b/libgo/go/net/tcpsockopt_darwin.go @@ -0,0 +1,27 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TCP socket options for darwin + +package net + +import ( + "os" + "syscall" + "time" +) + +// Set keep alive period. +func setKeepAlivePeriod(fd *netFD, d time.Duration) error { + if err := fd.incref(); err != nil { + return err + } + defer fd.decref() + + // The kernel expects seconds so round to next highest second. + d += (time.Second - time.Nanosecond) + secs := int(d.Seconds()) + + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE, secs)) +} diff --git a/libgo/go/net/tcpsockopt_openbsd.go b/libgo/go/net/tcpsockopt_openbsd.go new file mode 100644 index 00000000000..3480f932c80 --- /dev/null +++ b/libgo/go/net/tcpsockopt_openbsd.go @@ -0,0 +1,27 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TCP socket options for openbsd + +package net + +import ( + "os" + "syscall" + "time" +) + +// Set keep alive period. +func setKeepAlivePeriod(fd *netFD, d time.Duration) error { + if err := fd.incref(); err != nil { + return err + } + defer fd.decref() + + // The kernel expects seconds so round to next highest second. + d += (time.Second - time.Nanosecond) + secs := int(d.Seconds()) + + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.SO_KEEPALIVE, secs)) +} diff --git a/libgo/go/net/tcpsockopt_posix.go b/libgo/go/net/tcpsockopt_posix.go new file mode 100644 index 00000000000..e03476ac634 --- /dev/null +++ b/libgo/go/net/tcpsockopt_posix.go @@ -0,0 +1,20 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd windows + +package net + +import ( + "os" + "syscall" +) + +func setNoDelay(fd *netFD, noDelay bool) error { + if err := fd.incref(); err != nil { + return err + } + defer fd.decref() + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay))) +} diff --git a/libgo/go/net/tcpsockopt_unix.go b/libgo/go/net/tcpsockopt_unix.go new file mode 100644 index 00000000000..89d9143b52e --- /dev/null +++ b/libgo/go/net/tcpsockopt_unix.go @@ -0,0 +1,31 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build dragonfly freebsd linux netbsd + +package net + +import ( + "os" + "syscall" + "time" +) + +// Set keep alive period. +func setKeepAlivePeriod(fd *netFD, d time.Duration) error { + if err := fd.incref(); err != nil { + return err + } + defer fd.decref() + + // The kernel expects seconds so round to next highest second. + d += (time.Second - time.Nanosecond) + secs := int(d.Seconds()) + + err := os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, secs)) + if err != nil { + return err + } + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, secs)) +} diff --git a/libgo/go/net/tcpsockopt_windows.go b/libgo/go/net/tcpsockopt_windows.go new file mode 100644 index 00000000000..0bf4312f248 --- /dev/null +++ b/libgo/go/net/tcpsockopt_windows.go @@ -0,0 +1,21 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TCP socket options for windows + +package net + +import ( + "time" +) + +func setKeepAlivePeriod(fd *netFD, d time.Duration) error { + if err := fd.incref(); err != nil { + return err + } + defer fd.decref() + + // We can't actually set this per connection. Act as a noop rather than an error. + return nil +} diff --git a/libgo/go/net/textproto/reader.go b/libgo/go/net/textproto/reader.go index 5bd26ac8d61..56ece5b087c 100644 --- a/libgo/go/net/textproto/reader.go +++ b/libgo/go/net/textproto/reader.go @@ -203,7 +203,7 @@ func parseCodeLine(line string, expectCode int) (code int, continued bool, messa // ReadCodeLine reads a response code line of the form // code message -// where code is a 3-digit status code and the message +// where code is a three-digit status code and the message // extends to the rest of the line. An example of such a line is: // 220 plan9.bell-labs.com ESMTP // @@ -231,7 +231,7 @@ func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err err // ... // code message line n // -// where code is a 3-digit status code. The first line starts with the +// where code is a three-digit status code. The first line starts with the // code and a hyphen. The response is terminated by a line that starts // with the same code followed by a space. Each line in message is // separated by a newline (\n). @@ -456,7 +456,16 @@ func (r *Reader) ReadDotLines() ([]string, error) { // } // func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { - m := make(MIMEHeader, 4) + // Avoid lots of small slice allocations later by allocating one + // large one ahead of time which we'll cut up into smaller + // slices. If this isn't big enough later, we allocate small ones. + var strs []string + hint := r.upcomingHeaderNewlines() + if hint > 0 { + strs = make([]string, hint) + } + + m := make(MIMEHeader, hint) for { kv, err := r.readContinuedLineSlice() if len(kv) == 0 { @@ -483,7 +492,18 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { } value := string(kv[i:]) - m[key] = append(m[key], value) + vv := m[key] + if vv == nil && len(strs) > 0 { + // More than likely this will be a single-element key. + // Most headers aren't multi-valued. + // Set the capacity on strs[0] to 1, so any future append + // won't extend the slice into the other strings. + vv, strs = strs[:1:1], strs[1:] + vv[0] = value + m[key] = vv + } else { + m[key] = append(vv, value) + } if err != nil { return m, err @@ -491,6 +511,29 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { } } +// upcomingHeaderNewlines returns an approximation of the number of newlines +// that will be in this header. If it gets confused, it returns 0. +func (r *Reader) upcomingHeaderNewlines() (n int) { + // Try to determine the 'hint' size. + r.R.Peek(1) // force a buffer load if empty + s := r.R.Buffered() + if s == 0 { + return + } + peek, _ := r.R.Peek(s) + for len(peek) > 0 { + i := bytes.IndexByte(peek, '\n') + if i < 3 { + // Not present (-1) or found within the next few bytes, + // implying we're at the end ("\r\n\r\n" or "\n\n") + return + } + n++ + peek = peek[i+1:] + } + return +} + // CanonicalMIMEHeaderKey returns the canonical format of the // MIME header key s. The canonicalization converts the first // letter and any letter following a hyphen to upper case; diff --git a/libgo/go/net/textproto/textproto.go b/libgo/go/net/textproto/textproto.go index eb6ced1c52e..026eb026b1d 100644 --- a/libgo/go/net/textproto/textproto.go +++ b/libgo/go/net/textproto/textproto.go @@ -105,7 +105,7 @@ func Dial(network, addr string) (*Conn, error) { // if _, _, err = c.ReadCodeLine(110); err != nil { // return nil, err // } -// text, err := c.ReadDotAll() +// text, err := c.ReadDotBytes() // if err != nil { // return nil, err // } diff --git a/libgo/go/net/timeout_test.go b/libgo/go/net/timeout_test.go index 2e92147b8e3..35d427a69c0 100644 --- a/libgo/go/net/timeout_test.go +++ b/libgo/go/net/timeout_test.go @@ -325,9 +325,6 @@ func TestReadWriteDeadline(t *testing.T) { t.Skipf("skipping test on %q", runtime.GOOS) } - if !canCancelIO { - t.Skip("skipping test on this system") - } const ( readTimeout = 50 * time.Millisecond writeTimeout = 250 * time.Millisecond @@ -496,7 +493,10 @@ func testVariousDeadlines(t *testing.T, maxProcs int) { clientc <- copyRes{n, err, d} }() - const tooLong = 2000 * time.Millisecond + tooLong := 2 * time.Second + if runtime.GOOS == "windows" { + tooLong = 5 * time.Second + } select { case res := <-clientc: if isTimeout(res.err) { @@ -549,7 +549,7 @@ func TestReadDeadlineDataAvailable(t *testing.T) { } defer c.Close() if res := <-servec; res.err != nil || res.n != int64(len(msg)) { - t.Fatalf("unexpected server Write: n=%d, err=%d; want n=%d, err=nil", res.n, res.err, len(msg)) + t.Fatalf("unexpected server Write: n=%d, err=%v; want n=%d, err=nil", res.n, res.err, len(msg)) } c.SetReadDeadline(time.Now().Add(-5 * time.Second)) // in the psat. buf := make([]byte, len(msg)/2) @@ -703,3 +703,40 @@ func TestProlongTimeout(t *testing.T) { c.Write(buf[:]) } } + +func TestDeadlineRace(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + N := 1000 + if testing.Short() { + N = 50 + } + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4)) + ln := newLocalListener(t) + defer ln.Close() + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + done := make(chan bool) + go func() { + t := time.NewTicker(2 * time.Microsecond).C + for i := 0; i < N; i++ { + if err := c.SetDeadline(time.Now().Add(2 * time.Microsecond)); err != nil { + break + } + <-t + } + done <- true + }() + var buf [1]byte + for i := 0; i < N; i++ { + c.Read(buf[:]) // ignore possible timeout errors + } + c.Close() + <-done +} diff --git a/libgo/go/net/udp_test.go b/libgo/go/net/udp_test.go index 4278f6dd4bc..6f4d2152c3c 100644 --- a/libgo/go/net/udp_test.go +++ b/libgo/go/net/udp_test.go @@ -5,53 +5,31 @@ package net import ( - "fmt" "reflect" "runtime" + "strings" "testing" ) -type resolveUDPAddrTest struct { - net string - litAddr string - addr *UDPAddr - err error -} - -var resolveUDPAddrTests = []resolveUDPAddrTest{ - {"udp", "127.0.0.1:0", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, - {"udp4", "127.0.0.1:65535", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 65535}, nil}, - - {"udp", "[::1]:1", &UDPAddr{IP: ParseIP("::1"), Port: 1}, nil}, - {"udp6", "[::1]:65534", &UDPAddr{IP: ParseIP("::1"), Port: 65534}, nil}, - - {"udp", "[::1%en0]:1", &UDPAddr{IP: ParseIP("::1"), Port: 1, Zone: "en0"}, nil}, - {"udp6", "[::1%911]:2", &UDPAddr{IP: ParseIP("::1"), Port: 2, Zone: "911"}, nil}, - - {"", "127.0.0.1:0", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, // Go 1.0 behavior - {"", "[::1]:0", &UDPAddr{IP: ParseIP("::1"), Port: 0}, nil}, // Go 1.0 behavior - - {"sip", "127.0.0.1:0", nil, UnknownNetworkError("sip")}, -} - -func init() { - if ifi := loopbackInterface(); ifi != nil { - index := fmt.Sprintf("%v", ifi.Index) - resolveUDPAddrTests = append(resolveUDPAddrTests, []resolveUDPAddrTest{ - {"udp6", "[fe80::1%" + ifi.Name + "]:3", &UDPAddr{IP: ParseIP("fe80::1"), Port: 3, Zone: zoneToString(ifi.Index)}, nil}, - {"udp6", "[fe80::1%" + index + "]:4", &UDPAddr{IP: ParseIP("fe80::1"), Port: 4, Zone: index}, nil}, - }...) - } -} - func TestResolveUDPAddr(t *testing.T) { - for _, tt := range resolveUDPAddrTests { - addr, err := ResolveUDPAddr(tt.net, tt.litAddr) + for _, tt := range resolveTCPAddrTests { + net := strings.Replace(tt.net, "tcp", "udp", -1) + addr, err := ResolveUDPAddr(net, tt.litAddrOrName) if err != tt.err { - t.Fatalf("ResolveUDPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err) + t.Fatalf("ResolveUDPAddr(%q, %q) failed: %v", net, tt.litAddrOrName, err) } - if !reflect.DeepEqual(addr, tt.addr) { - t.Fatalf("got %#v; expected %#v", addr, tt.addr) + if !reflect.DeepEqual(addr, (*UDPAddr)(tt.addr)) { + t.Fatalf("ResolveUDPAddr(%q, %q) = %#v, want %#v", net, tt.litAddrOrName, addr, tt.addr) + } + if err == nil { + str := addr.String() + addr1, err := ResolveUDPAddr(net, str) + if err != nil { + t.Fatalf("ResolveUDPAddr(%q, %q) [from %q]: %v", net, str, tt.litAddrOrName, err) + } + if !reflect.DeepEqual(addr1, addr) { + t.Fatalf("ResolveUDPAddr(%q, %q) [from %q] = %#v, want %#v", net, str, tt.litAddrOrName, addr1, addr) + } } } } @@ -224,7 +202,7 @@ func TestIPv6LinkLocalUnicastUDP(t *testing.T) { {"udp6", "[" + laddr + "%" + ifi.Name + "]:0", false}, } switch runtime.GOOS { - case "darwin", "freebsd", "openbsd", "netbsd": + case "darwin", "dragonfly", "freebsd", "openbsd", "netbsd": tests = append(tests, []test{ {"udp", "[localhost%" + ifi.Name + "]:0", true}, {"udp6", "[localhost%" + ifi.Name + "]:0", true}, diff --git a/libgo/go/net/udpsock.go b/libgo/go/net/udpsock.go index 5ce7d6bea0f..0dd0dbd7114 100644 --- a/libgo/go/net/udpsock.go +++ b/libgo/go/net/udpsock.go @@ -22,10 +22,18 @@ func (a *UDPAddr) String() string { if a == nil { return "<nil>" } + ip := ipEmptyString(a.IP) if a.Zone != "" { - return JoinHostPort(a.IP.String()+"%"+a.Zone, itoa(a.Port)) + return JoinHostPort(ip+"%"+a.Zone, itoa(a.Port)) } - return JoinHostPort(a.IP.String(), itoa(a.Port)) + return JoinHostPort(ip, itoa(a.Port)) +} + +func (a *UDPAddr) toAddr() Addr { + if a == nil { + return nil + } + return a } // ResolveUDPAddr parses addr as a UDP address of the form "host:port" @@ -46,5 +54,5 @@ func ResolveUDPAddr(net, addr string) (*UDPAddr, error) { if err != nil { return nil, err } - return a.(*UDPAddr), nil + return a.toAddr().(*UDPAddr), nil } diff --git a/libgo/go/net/udpsock_plan9.go b/libgo/go/net/udpsock_plan9.go index 12a34839905..73621706d5c 100644 --- a/libgo/go/net/udpsock_plan9.go +++ b/libgo/go/net/udpsock_plan9.go @@ -73,6 +73,9 @@ func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) { if !c.ok() || c.fd.data == nil { return 0, syscall.EINVAL } + if addr == nil { + return 0, &OpError{Op: "write", Net: c.fd.dir, Addr: nil, Err: errMissingAddress} + } h := new(udpHeader) h.raddr = addr.IP.To16() h.laddr = c.fd.laddr.(*UDPAddr).IP.To16() diff --git a/libgo/go/net/udpsock_posix.go b/libgo/go/net/udpsock_posix.go index b90cb030d81..142da8186f1 100644 --- a/libgo/go/net/udpsock_posix.go +++ b/libgo/go/net/udpsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux netbsd openbsd windows package net @@ -39,14 +39,10 @@ func (a *UDPAddr) isWildcard() bool { } func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, error) { - return ipToSockaddr(family, a.IP, a.Port, a.Zone) -} - -func (a *UDPAddr) toAddr() sockaddr { - if a == nil { // nil *UDPAddr - return nil // nil interface + if a == nil { + return nil, nil } - return a + return ipToSockaddr(family, a.IP, a.Port, a.Zone) } // UDPConn is the implementation of the Conn and PacketConn interfaces @@ -121,6 +117,9 @@ func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) { if c.fd.isConnected { return 0, &OpError{"write", c.fd.net, addr, ErrWriteToConnected} } + if addr == nil { + return 0, &OpError{Op: "write", Net: c.fd.net, Addr: nil, Err: errMissingAddress} + } sa, err := addr.sockaddr(c.fd.family) if err != nil { return 0, &OpError{"write", c.fd.net, addr, err} @@ -150,6 +149,9 @@ func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err er if c.fd.isConnected { return 0, 0, &OpError{"write", c.fd.net, addr, ErrWriteToConnected} } + if addr == nil { + return 0, 0, &OpError{Op: "write", Net: c.fd.net, Addr: nil, Err: errMissingAddress} + } sa, err := addr.sockaddr(c.fd.family) if err != nil { return 0, 0, &OpError{"write", c.fd.net, addr, err} @@ -161,21 +163,21 @@ func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err er // which must be "udp", "udp4", or "udp6". If laddr is not nil, it is // used as the local address for the connection. func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) { - return dialUDP(net, laddr, raddr, noDeadline) -} - -func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) { switch net { case "udp", "udp4", "udp6": default: - return nil, UnknownNetworkError(net) + return nil, &OpError{Op: "dial", Net: net, Addr: raddr, Err: UnknownNetworkError(net)} } if raddr == nil { - return nil, &OpError{"dial", net, nil, errMissingAddress} + return nil, &OpError{Op: "dial", Net: net, Addr: nil, Err: errMissingAddress} } - fd, err := internetSocket(net, laddr.toAddr(), raddr.toAddr(), deadline, syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP) + return dialUDP(net, laddr, raddr, noDeadline) +} + +func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) { + fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP) if err != nil { - return nil, err + return nil, &OpError{Op: "dial", Net: net, Addr: raddr, Err: err} } return newUDPConn(fd), nil } @@ -191,14 +193,14 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { switch net { case "udp", "udp4", "udp6": default: - return nil, UnknownNetworkError(net) + return nil, &OpError{Op: "listen", Net: net, Addr: laddr, Err: UnknownNetworkError(net)} } if laddr == nil { laddr = &UDPAddr{} } - fd, err := internetSocket(net, laddr.toAddr(), nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP) + fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP) if err != nil { - return nil, err + return nil, &OpError{Op: "listen", Net: net, Addr: laddr, Err: err} } return newUDPConn(fd), nil } @@ -211,25 +213,25 @@ func ListenMulticastUDP(net string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, e switch net { case "udp", "udp4", "udp6": default: - return nil, UnknownNetworkError(net) + return nil, &OpError{Op: "listen", Net: net, Addr: gaddr, Err: UnknownNetworkError(net)} } if gaddr == nil || gaddr.IP == nil { - return nil, &OpError{"listen", net, nil, errMissingAddress} + return nil, &OpError{Op: "listen", Net: net, Addr: nil, Err: errMissingAddress} } - fd, err := internetSocket(net, gaddr.toAddr(), nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP) + fd, err := internetSocket(net, gaddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", sockaddrToUDP) if err != nil { - return nil, err + return nil, &OpError{Op: "listen", Net: net, Addr: gaddr, Err: err} } c := newUDPConn(fd) if ip4 := gaddr.IP.To4(); ip4 != nil { if err := listenIPv4MulticastUDP(c, ifi, ip4); err != nil { c.Close() - return nil, &OpError{"listen", net, &IPAddr{IP: ip4}, err} + return nil, &OpError{Op: "listen", Net: net, Addr: &IPAddr{IP: ip4}, Err: err} } } else { if err := listenIPv6MulticastUDP(c, ifi, gaddr.IP); err != nil { c.Close() - return nil, &OpError{"listen", net, &IPAddr{IP: gaddr.IP}, err} + return nil, &OpError{Op: "listen", Net: net, Addr: &IPAddr{IP: gaddr.IP}, Err: err} } } return c, nil diff --git a/libgo/go/net/unicast_posix_test.go b/libgo/go/net/unicast_posix_test.go index b0588f4e529..5deb8f47c6c 100644 --- a/libgo/go/net/unicast_posix_test.go +++ b/libgo/go/net/unicast_posix_test.go @@ -349,12 +349,16 @@ func checkDualStackSecondListener(t *testing.T, net, laddr string, xerr, err err if xerr == nil && err != nil || xerr != nil && err == nil { t.Fatalf("Second Listen(%q, %q) returns %v, expected %v", net, laddr, err, xerr) } - l.(*TCPListener).Close() + if err == nil { + l.(*TCPListener).Close() + } case "udp", "udp4", "udp6": if xerr == nil && err != nil || xerr != nil && err == nil { t.Fatalf("Second ListenPacket(%q, %q) returns %v, expected %v", net, laddr, err, xerr) } - l.(*UDPConn).Close() + if err == nil { + l.(*UDPConn).Close() + } default: t.Fatalf("Unexpected network: %q", net) } @@ -436,8 +440,8 @@ func TestWildWildcardListener(t *testing.T) { } defer func() { - if recover() != nil { - t.Fatalf("panicked") + if p := recover(); p != nil { + t.Fatalf("Listen, ListenPacket or protocol-specific Listen panicked: %v", p) } }() diff --git a/libgo/go/net/unix_test.go b/libgo/go/net/unix_test.go index 5e63e9d9dec..91df3ff8876 100644 --- a/libgo/go/net/unix_test.go +++ b/libgo/go/net/unix_test.go @@ -107,7 +107,7 @@ func TestReadUnixgramWithZeroBytesBuffer(t *testing.T) { } } -func TestUnixAutobind(t *testing.T) { +func TestUnixgramAutobind(t *testing.T) { if runtime.GOOS != "linux" { t.Skip("skipping: autobind is linux only") } @@ -139,8 +139,21 @@ func TestUnixAutobind(t *testing.T) { } } +func TestUnixAutobindClose(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("skipping: autobind is linux only") + } + laddr := &UnixAddr{Name: "", Net: "unix"} + ln, err := ListenUnix("unix", laddr) + if err != nil { + t.Fatalf("ListenUnix failed: %v", err) + } + ln.Close() +} + func TestUnixConnLocalAndRemoteNames(t *testing.T) { for _, laddr := range []string{"", testUnixAddr()} { + laddr := laddr taddr := testUnixAddr() ta, err := ResolveUnixAddr("unix", taddr) if err != nil { @@ -196,6 +209,7 @@ func TestUnixConnLocalAndRemoteNames(t *testing.T) { func TestUnixgramConnLocalAndRemoteNames(t *testing.T) { for _, laddr := range []string{"", testUnixAddr()} { + laddr := laddr taddr := testUnixAddr() ta, err := ResolveUnixAddr("unixgram", taddr) if err != nil { @@ -212,7 +226,6 @@ func TestUnixgramConnLocalAndRemoteNames(t *testing.T) { var la *UnixAddr if laddr != "" { - var err error if la, err = ResolveUnixAddr("unixgram", laddr); err != nil { t.Fatalf("ResolveUnixAddr failed: %v", err) } diff --git a/libgo/go/net/unixsock.go b/libgo/go/net/unixsock.go index 21a19eca2c0..85955845b80 100644 --- a/libgo/go/net/unixsock.go +++ b/libgo/go/net/unixsock.go @@ -24,8 +24,8 @@ func (a *UnixAddr) String() string { } func (a *UnixAddr) toAddr() Addr { - if a == nil { // nil *UnixAddr - return nil // nil interface + if a == nil { + return nil } return a } diff --git a/libgo/go/net/unixsock_plan9.go b/libgo/go/net/unixsock_plan9.go index 8a1281fb1a4..c60c1d83bb3 100644 --- a/libgo/go/net/unixsock_plan9.go +++ b/libgo/go/net/unixsock_plan9.go @@ -97,7 +97,7 @@ func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { } // AcceptUnix accepts the next incoming call and returns the new -// connection and the remote address. +// connection. func (l *UnixListener) AcceptUnix() (*UnixConn, error) { return nil, syscall.EPLAN9 } diff --git a/libgo/go/net/unixsock_posix.go b/libgo/go/net/unixsock_posix.go index 5db30df95fc..b82f3cee0b5 100644 --- a/libgo/go/net/unixsock_posix.go +++ b/libgo/go/net/unixsock_posix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin freebsd linux netbsd openbsd windows +// +build darwin dragonfly freebsd linux netbsd openbsd windows package net @@ -13,14 +13,7 @@ import ( "time" ) -func (a *UnixAddr) isUnnamed() bool { - if a == nil || a.Name == "" { - return true - } - return false -} - -func unixSocket(net string, laddr, raddr *UnixAddr, mode string, deadline time.Time) (*netFD, error) { +func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Time) (*netFD, error) { var sotype int switch net { case "unix": @@ -33,19 +26,18 @@ func unixSocket(net string, laddr, raddr *UnixAddr, mode string, deadline time.T return nil, UnknownNetworkError(net) } - var la, ra syscall.Sockaddr switch mode { case "dial": - if !laddr.isUnnamed() { - la = &syscall.SockaddrUnix{Name: laddr.Name} + if laddr != nil && laddr.isWildcard() { + laddr = nil } - if raddr != nil { - ra = &syscall.SockaddrUnix{Name: raddr.Name} - } else if sotype != syscall.SOCK_DGRAM || laddr.isUnnamed() { - return nil, &OpError{Op: mode, Net: net, Err: errMissingAddress} + if raddr != nil && raddr.isWildcard() { + raddr = nil + } + if raddr == nil && (sotype != syscall.SOCK_DGRAM || laddr == nil) { + return nil, errMissingAddress } case "listen": - la = &syscall.SockaddrUnix{Name: laddr.Name} default: return nil, errors.New("unknown mode: " + mode) } @@ -57,19 +49,11 @@ func unixSocket(net string, laddr, raddr *UnixAddr, mode string, deadline time.T f = sockaddrToUnixpacket } - fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, la, ra, deadline, f) + fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, deadline, f) if err != nil { - goto error + return nil, err } return fd, nil - -error: - addr := raddr - switch mode { - case "listen": - addr = laddr - } - return nil, &OpError{Op: mode, Net: net, Addr: addr, Err: err} } func sockaddrToUnix(sa syscall.Sockaddr) Addr { @@ -106,6 +90,21 @@ func sotypeToNet(sotype int) string { } } +func (a *UnixAddr) family() int { + return syscall.AF_UNIX +} + +func (a *UnixAddr) isWildcard() bool { + return a == nil || a.Name == "" +} + +func (a *UnixAddr) sockaddr(family int) (syscall.Sockaddr, error) { + if a == nil { + return nil, nil + } + return &syscall.SockaddrUnix{Name: a.Name}, nil +} + // UnixConn is an implementation of the Conn interface for connections // to Unix domain sockets. type UnixConn struct { @@ -172,6 +171,9 @@ func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err error) { if !c.ok() { return 0, syscall.EINVAL } + if addr == nil { + return 0, &OpError{Op: "write", Net: c.fd.net, Addr: nil, Err: errMissingAddress} + } if addr.Net != sotypeToNet(c.fd.sotype) { return 0, syscall.EAFNOSUPPORT } @@ -230,18 +232,18 @@ func (c *UnixConn) CloseWrite() error { // which must be "unix", "unixgram" or "unixpacket". If laddr is not // nil, it is used as the local address for the connection. func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) { - return dialUnix(net, laddr, raddr, noDeadline) -} - -func dialUnix(net string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) { switch net { case "unix", "unixgram", "unixpacket": default: - return nil, UnknownNetworkError(net) + return nil, &OpError{Op: "dial", Net: net, Addr: raddr, Err: UnknownNetworkError(net)} } + return dialUnix(net, laddr, raddr, noDeadline) +} + +func dialUnix(net string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn, error) { fd, err := unixSocket(net, laddr, raddr, "dial", deadline) if err != nil { - return nil, err + return nil, &OpError{Op: "dial", Net: net, Addr: raddr, Err: err} } return newUnixConn(fd), nil } @@ -260,25 +262,20 @@ func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { switch net { case "unix", "unixpacket": default: - return nil, UnknownNetworkError(net) + return nil, &OpError{Op: "listen", Net: net, Addr: laddr, Err: UnknownNetworkError(net)} } if laddr == nil { - return nil, &OpError{"listen", net, nil, errMissingAddress} + return nil, &OpError{Op: "listen", Net: net, Addr: nil, Err: errMissingAddress} } fd, err := unixSocket(net, laddr, nil, "listen", noDeadline) if err != nil { - return nil, err - } - err = syscall.Listen(fd.sysfd, listenerBacklog) - if err != nil { - fd.Close() return nil, &OpError{Op: "listen", Net: net, Addr: laddr, Err: err} } - return &UnixListener{fd, laddr.Name}, nil + return &UnixListener{fd, fd.laddr.String()}, nil } // AcceptUnix accepts the next incoming call and returns the new -// connection and the remote address. +// connection. func (l *UnixListener) AcceptUnix() (*UnixConn, error) { if l == nil || l.fd == nil { return nil, syscall.EINVAL @@ -333,7 +330,7 @@ func (l *UnixListener) SetDeadline(t time.Time) (err error) { if l == nil || l.fd == nil { return syscall.EINVAL } - return setDeadline(l.fd, t) + return l.fd.setDeadline(t) } // File returns a copy of the underlying os.File, set to blocking @@ -353,14 +350,14 @@ func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) { switch net { case "unixgram": default: - return nil, UnknownNetworkError(net) + return nil, &OpError{Op: "listen", Net: net, Addr: laddr, Err: UnknownNetworkError(net)} } if laddr == nil { - return nil, &OpError{"listen", net, nil, errMissingAddress} + return nil, &OpError{Op: "listen", Net: net, Addr: nil, Err: errMissingAddress} } fd, err := unixSocket(net, laddr, nil, "listen", noDeadline) if err != nil { - return nil, err + return nil, &OpError{Op: "listen", Net: net, Addr: laddr, Err: err} } return newUnixConn(fd), nil } diff --git a/libgo/go/net/url/url.go b/libgo/go/net/url/url.go index 459dc473ceb..597cb51c883 100644 --- a/libgo/go/net/url/url.go +++ b/libgo/go/net/url/url.go @@ -451,14 +451,17 @@ func (u *URL) String() string { } else { if u.Scheme != "" || u.Host != "" || u.User != nil { buf.WriteString("//") - if u := u.User; u != nil { - buf.WriteString(u.String()) + if ui := u.User; ui != nil { + buf.WriteString(ui.String()) buf.WriteByte('@') } if h := u.Host; h != "" { buf.WriteString(h) } } + if u.Path != "" && u.Path[0] != '/' && u.Host != "" { + buf.WriteByte('/') + } buf.WriteString(escape(u.Path, encodePath)) } if u.RawQuery != "" { diff --git a/libgo/go/net/url/url_test.go b/libgo/go/net/url/url_test.go index 9d81289ceba..7578eb15b90 100644 --- a/libgo/go/net/url/url_test.go +++ b/libgo/go/net/url/url_test.go @@ -260,6 +260,14 @@ var urltests = []URLTest{ }, "mailto:webmaster@golang.org", }, + // Relative path + { + "a/b/c", + &URL{ + Path: "a/b/c", + }, + "a/b/c", + }, } // more useful string for debugging than fmt's struct printer @@ -372,6 +380,22 @@ func DoTestString(t *testing.T, parse func(string) (*URL, error), name string, t func TestURLString(t *testing.T) { DoTestString(t, Parse, "Parse", urltests) + + // no leading slash on path should prepend + // slash on String() call + noslash := URLTest{ + "http://www.google.com/search", + &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "search", + }, + "", + } + s := noslash.out.String() + if s != noslash.in { + t.Errorf("Expected %s; go %s", noslash.in, s) + } } type EscapeTest struct { |