diff options
author | Brad Fitzpatrick <bradfitz@golang.org> | 2014-10-15 17:51:12 +0200 |
---|---|---|
committer | Brad Fitzpatrick <bradfitz@golang.org> | 2014-10-15 17:51:12 +0200 |
commit | 713940ccabd51693d5a5bbdfa3acfae9e7235e35 (patch) | |
tree | 5d2a21297281b469dc42a16fac2553ee595f0957 /src/net | |
parent | c5799fafe77a6561653f8faccfa7ff5128c968ca (diff) | |
download | go-713940ccabd51693d5a5bbdfa3acfae9e7235e35.tar.gz |
net/http: don't reuse a server connection after any Write errors
Fixes Issue 8534
LGTM=adg
R=adg
CC=golang-codereviews
https://codereview.appspot.com/149340044
Diffstat (limited to 'src/net')
-rw-r--r-- | src/net/http/serve_test.go | 97 | ||||
-rw-r--r-- | src/net/http/server.go | 32 |
2 files changed, 126 insertions, 3 deletions
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 702bffdc1..bb44ac853 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -2659,6 +2659,103 @@ func TestCloseWrite(t *testing.T) { } } +// This verifies that a handler can Flush and then Hijack. +// +// An similar test crashed once during development, but it was only +// testing this tangentially and temporarily until another TODO was +// fixed. +// +// So add an explicit test for this. +func TestServerFlushAndHijack(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, "Hello, ") + w.(Flusher).Flush() + conn, buf, _ := w.(Hijacker).Hijack() + buf.WriteString("6\r\nworld!\r\n0\r\n\r\n") + if err := buf.Flush(); err != nil { + t.Error(err) + } + if err := conn.Close(); err != nil { + t.Error(err) + } + })) + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + all, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if want := "Hello, world!"; string(all) != want { + t.Errorf("Got %q; want %q", all, want) + } +} + +// golang.org/issue/8534 -- the Server shouldn't reuse a connection +// for keep-alive after it's seen any Write error (e.g. a timeout) on +// that net.Conn. +// +// To test, verify we don't timeout or see fewer unique client +// addresses (== unique connections) than requests. +func TestServerKeepAliveAfterWriteError(t *testing.T) { + if testing.Short() { + t.Skip("skipping in -short mode") + } + defer afterTest(t) + const numReq = 3 + addrc := make(chan string, numReq) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + addrc <- r.RemoteAddr + time.Sleep(500 * time.Millisecond) + w.(Flusher).Flush() + })) + ts.Config.WriteTimeout = 250 * time.Millisecond + ts.Start() + defer ts.Close() + + errc := make(chan error, numReq) + go func() { + defer close(errc) + for i := 0; i < numReq; i++ { + res, err := Get(ts.URL) + if res != nil { + res.Body.Close() + } + errc <- err + } + }() + + timeout := time.NewTimer(numReq * 2 * time.Second) // 4x overkill + defer timeout.Stop() + addrSeen := map[string]bool{} + numOkay := 0 + for { + select { + case v := <-addrc: + addrSeen[v] = true + case err, ok := <-errc: + if !ok { + if len(addrSeen) != numReq { + t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq) + } + if numOkay != 0 { + t.Errorf("got %d successful client requests; want 0", numOkay) + } + return + } + if err == nil { + numOkay++ + } + case <-timeout.C: + t.Fatal("timeout waiting for requests to complete") + } + } +} + func BenchmarkClientServer(b *testing.B) { b.ReportAllocs() b.StopTimer() diff --git a/src/net/http/server.go b/src/net/http/server.go index b5959f732..008d5aa7a 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -114,6 +114,8 @@ type conn struct { remoteAddr string // network address of remote side server *Server // the Server on which the connection arrived rwc net.Conn // i/o connection + w io.Writer // checkConnErrorWriter's copy of wrc, not zeroed on Hijack + werr error // any errors writing to w sr liveSwitchReader // where the LimitReader reads from; usually the rwc lr *io.LimitedReader // io.LimitReader(sr) buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc @@ -432,13 +434,14 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { c.remoteAddr = rwc.RemoteAddr().String() c.server = srv c.rwc = rwc + c.w = rwc if debugServerConnections { c.rwc = newLoggingConn("server", c.rwc) } c.sr = liveSwitchReader{r: c.rwc} c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader) br := newBufioReader(c.lr) - bw := newBufioWriterSize(c.rwc, 4<<10) + bw := newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) c.buf = bufio.NewReadWriter(br, bw) return c, nil } @@ -956,8 +959,10 @@ func (w *response) bodyAllowed() bool { // 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes // 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type) // and which writes the chunk headers, if needed. -// 4. conn.buf, a bufio.Writer of default (4kB) bytes -// 5. the rwc, the net.Conn. +// 4. conn.buf, a bufio.Writer of default (4kB) bytes, writing to -> +// 5. checkConnErrorWriter{c}, which notes any non-nil error on Write +// and populates c.werr with it if so. but otherwise writes to: +// 6. the rwc, the net.Conn. // // TODO(bradfitz): short-circuit some of the buffering when the // initial header contains both a Content-Type and Content-Length. @@ -1027,6 +1032,12 @@ func (w *response) finishRequest() { // Did not write enough. Avoid getting out of sync. w.closeAfterReply = true } + + // There was some error writing to the underlying connection + // during the request, so don't re-use this conn. + if w.conn.werr != nil { + w.closeAfterReply = true + } } func (w *response) Flush() { @@ -2068,3 +2079,18 @@ func (c *loggingConn) Close() (err error) { log.Printf("%s.Close() = %v", c.name, err) return } + +// checkConnErrorWriter writes to c.rwc and records any write errors to c.werr. +// It only contains one field (and a pointer field at that), so it +// fits in an interface value without an extra allocation. +type checkConnErrorWriter struct { + c *conn +} + +func (w checkConnErrorWriter) Write(p []byte) (n int, err error) { + n, err = w.c.w.Write(p) // c.w == c.rwc, except after a hijack, when rwc is nil. + if err != nil && w.c.werr == nil { + w.c.werr = err + } + return +} |