diff options
Diffstat (limited to 'libgo/go/net/http/server.go')
-rw-r--r-- | libgo/go/net/http/server.go | 719 |
1 files changed, 565 insertions, 154 deletions
diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index 89574a8b36e..96236489bd9 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -40,7 +40,9 @@ var ( // ErrHijacked is returned by ResponseWriter.Write calls when // the underlying connection has been hijacked using the - // Hijacker interfaced. + // Hijacker interface. A zero-byte write on a hijacked + // connection will return ErrHijacked without any other side + // effects. ErrHijacked = errors.New("http: connection has been hijacked") // ErrContentLength is returned by ResponseWriter.Write calls @@ -73,7 +75,9 @@ var ( // If ServeHTTP panics, the server (the caller of ServeHTTP) assumes // that the effect of the panic was isolated to the active request. // It recovers the panic, logs a stack trace to the server error log, -// and hangs up the connection. +// and hangs up the connection. To abort a handler so the client sees +// an interrupted response but the server doesn't log an error, panic +// with the value ErrAbortHandler. type Handler interface { ServeHTTP(ResponseWriter, *Request) } @@ -85,11 +89,25 @@ type Handler interface { // has returned. type ResponseWriter interface { // Header returns the header map that will be sent by - // WriteHeader. Changing the header after a call to - // WriteHeader (or Write) has no effect unless the modified - // headers were declared as trailers by setting the - // "Trailer" header before the call to WriteHeader (see example). - // To suppress implicit response headers, set their value to nil. + // WriteHeader. The Header map also is the mechanism with which + // Handlers can set HTTP trailers. + // + // Changing the header map after a call to WriteHeader (or + // Write) has no effect unless the modified headers are + // trailers. + // + // There are two ways to set Trailers. The preferred way is to + // predeclare in the headers which trailers you will later + // send by setting the "Trailer" header to the names of the + // trailer keys which will come later. In this case, those + // keys of the Header map are treated as if they were + // trailers. See the example. The second way, for trailer + // keys not known to the Handler until after the first Write, + // is to prefix the Header map keys with the TrailerPrefix + // constant value. See TrailerPrefix. + // + // To suppress implicit response headers (such as "Date"), set + // their value to nil. Header() Header // Write writes the data to the connection as part of an HTTP reply. @@ -206,6 +224,9 @@ type conn struct { // Immutable; never nil. server *Server + // cancelCtx cancels the connection-level context. + cancelCtx context.CancelFunc + // rwc is the underlying network connection. // This is never wrapped by other types and is the value given out // to CloseNotifier callers. It is usually of type *net.TCPConn or @@ -232,7 +253,6 @@ type conn struct { r *connReader // bufr reads from r. - // Users of bufr must hold mu. bufr *bufio.Reader // bufw writes to checkConnErrorWriter{c}, which populates werr on error. @@ -242,7 +262,11 @@ type conn struct { // on this connection, if any. lastMethod string - // mu guards hijackedv, use of bufr, (*response).closeNotifyCh. + curReq atomic.Value // of *response (which has a Request in it) + + curState atomic.Value // of ConnState + + // mu guards hijackedv mu sync.Mutex // hijackedv is whether this connection has been hijacked @@ -262,8 +286,12 @@ func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if c.hijackedv { return nil, nil, ErrHijacked } + c.r.abortPendingRead() + c.hijackedv = true rwc = c.rwc + rwc.SetDeadline(time.Time{}) + buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc)) c.setState(rwc, StateHijacked) return @@ -346,13 +374,7 @@ func (cw *chunkWriter) close() { bw := cw.res.conn.bufw // conn's bufio writer // zero chunk to mark EOF bw.WriteString("0\r\n") - if len(cw.res.trailers) > 0 { - trailers := make(Header) - for _, h := range cw.res.trailers { - if vv := cw.res.handlerHeader[h]; len(vv) > 0 { - trailers[h] = vv - } - } + if trailers := cw.res.finalTrailers(); trailers != nil { trailers.Write(bw) // the writer handles noting errors } // final blank line after the trailers (whether @@ -413,9 +435,48 @@ type response struct { dateBuf [len(TimeFormat)]byte clenBuf [10]byte - // closeNotifyCh is non-nil once CloseNotify is called. - // Guarded by conn.mu - closeNotifyCh <-chan bool + // closeNotifyCh is the channel returned by CloseNotify. + // TODO(bradfitz): this is currently (for Go 1.8) always + // non-nil. Make this lazily-created again as it used to be? + closeNotifyCh chan bool + didCloseNotify int32 // atomic (only 0->1 winner should send) +} + +// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys +// that, if present, signals that the map entry is actually for +// the response trailers, and not the response headers. The prefix +// is stripped after the ServeHTTP call finishes and the values are +// sent in the trailers. +// +// This mechanism is intended only for trailers that are not known +// prior to the headers being written. If the set of trailers is fixed +// or known before the header is written, the normal Go trailers mechanism +// is preferred: +// https://golang.org/pkg/net/http/#ResponseWriter +// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers +const TrailerPrefix = "Trailer:" + +// finalTrailers is called after the Handler exits and returns a non-nil +// value if the Handler set any trailers. +func (w *response) finalTrailers() Header { + var t Header + for k, vv := range w.handlerHeader { + if strings.HasPrefix(k, TrailerPrefix) { + if t == nil { + t = make(Header) + } + t[strings.TrimPrefix(k, TrailerPrefix)] = vv + } + } + for _, k := range w.trailers { + if t == nil { + t = make(Header) + } + for _, v := range w.handlerHeader[k] { + t.Add(k, v) + } + } + return t } type atomicBool int32 @@ -548,60 +609,152 @@ type readResult struct { // call blocked in a background goroutine to wait for activity and // trigger a CloseNotifier channel. type connReader struct { - r io.Reader - remain int64 // bytes remaining + conn *conn - // ch is non-nil if a background read is in progress. - // It is guarded by conn.mu. - ch chan readResult + mu sync.Mutex // guards following + hasByte bool + byteBuf [1]byte + bgErr error // non-nil means error happened on background read + cond *sync.Cond + inRead bool + aborted bool // set true before conn.rwc deadline is set to past + remain int64 // bytes remaining +} + +func (cr *connReader) lock() { + cr.mu.Lock() + if cr.cond == nil { + cr.cond = sync.NewCond(&cr.mu) + } +} + +func (cr *connReader) unlock() { cr.mu.Unlock() } + +func (cr *connReader) startBackgroundRead() { + cr.lock() + defer cr.unlock() + if cr.inRead { + panic("invalid concurrent Body.Read call") + } + if cr.hasByte { + return + } + cr.inRead = true + cr.conn.rwc.SetReadDeadline(time.Time{}) + go cr.backgroundRead() +} + +func (cr *connReader) backgroundRead() { + n, err := cr.conn.rwc.Read(cr.byteBuf[:]) + cr.lock() + if n == 1 { + cr.hasByte = true + // We were at EOF already (since we wouldn't be in a + // background read otherwise), so this is a pipelined + // HTTP request. + cr.closeNotifyFromPipelinedRequest() + } + if ne, ok := err.(net.Error); ok && cr.aborted && ne.Timeout() { + // Ignore this error. It's the expected error from + // another goroutine calling abortPendingRead. + } else if err != nil { + cr.handleReadError(err) + } + cr.aborted = false + cr.inRead = false + cr.unlock() + cr.cond.Broadcast() +} + +func (cr *connReader) abortPendingRead() { + cr.lock() + defer cr.unlock() + if !cr.inRead { + return + } + cr.aborted = true + cr.conn.rwc.SetReadDeadline(aLongTimeAgo) + for cr.inRead { + cr.cond.Wait() + } + cr.conn.rwc.SetReadDeadline(time.Time{}) } func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 } func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } +// may be called from multiple goroutines. +func (cr *connReader) handleReadError(err error) { + cr.conn.cancelCtx() + cr.closeNotify() +} + +// closeNotifyFromPipelinedRequest simply calls closeNotify. +// +// This method wrapper is here for documentation. The callers are the +// cases where we send on the closenotify channel because of a +// pipelined HTTP request, per the previous Go behavior and +// documentation (that this "MAY" happen). +// +// TODO: consider changing this behavior and making context +// cancelation and closenotify work the same. +func (cr *connReader) closeNotifyFromPipelinedRequest() { + cr.closeNotify() +} + +// may be called from multiple goroutines. +func (cr *connReader) closeNotify() { + res, _ := cr.conn.curReq.Load().(*response) + if res != nil { + if atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) { + res.closeNotifyCh <- true + } + } +} + func (cr *connReader) Read(p []byte) (n int, err error) { + cr.lock() + if cr.inRead { + cr.unlock() + panic("invalid concurrent Body.Read call") + } if cr.hitReadLimit() { + cr.unlock() return 0, io.EOF } + if cr.bgErr != nil { + err = cr.bgErr + cr.unlock() + return 0, err + } if len(p) == 0 { - return + cr.unlock() + return 0, nil } if int64(len(p)) > cr.remain { p = p[:cr.remain] } - - // Is a background read (started by CloseNotifier) already in - // flight? If so, wait for it and use its result. - ch := cr.ch - if ch != nil { - cr.ch = nil - res := <-ch - if res.n == 1 { - p[0] = res.b - cr.remain -= 1 - } - return res.n, res.err + if cr.hasByte { + p[0] = cr.byteBuf[0] + cr.hasByte = false + cr.unlock() + return 1, nil } - n, err = cr.r.Read(p) - cr.remain -= int64(n) - return -} + cr.inRead = true + cr.unlock() + n, err = cr.conn.rwc.Read(p) -func (cr *connReader) startBackgroundRead(onReadComplete func()) { - if cr.ch != nil { - // Background read already started. - return + cr.lock() + cr.inRead = false + if err != nil { + cr.handleReadError(err) } - cr.ch = make(chan readResult, 1) - go cr.closeNotifyAwaitActivityRead(cr.ch, onReadComplete) -} + cr.remain -= int64(n) + cr.unlock() -func (cr *connReader) closeNotifyAwaitActivityRead(ch chan<- readResult, onReadComplete func()) { - var buf [1]byte - n, err := cr.r.Read(buf[:1]) - onReadComplete() - ch <- readResult{n, err, buf[0]} + cr.cond.Broadcast() + return n, err } var ( @@ -633,7 +786,7 @@ func newBufioReader(r io.Reader) *bufio.Reader { br.Reset(r) return br } - // Note: if this reader size is every changed, update + // Note: if this reader size is ever changed, update // TestHandlerBodyClose's assumptions. return bufio.NewReader(r) } @@ -746,9 +899,18 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { return nil, ErrHijacked } + var ( + wholeReqDeadline time.Time // or zero if none + hdrDeadline time.Time // or zero if none + ) + t0 := time.Now() + if d := c.server.readHeaderTimeout(); d != 0 { + hdrDeadline = t0.Add(d) + } if d := c.server.ReadTimeout; d != 0 { - c.rwc.SetReadDeadline(time.Now().Add(d)) + wholeReqDeadline = t0.Add(d) } + c.rwc.SetReadDeadline(hdrDeadline) if d := c.server.WriteTimeout; d != 0 { defer func() { c.rwc.SetWriteDeadline(time.Now().Add(d)) @@ -756,14 +918,12 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { } c.r.setReadLimit(c.server.initialReadLimitSize()) - c.mu.Lock() // while using bufr if c.lastMethod == "POST" { // RFC 2616 section 4.1 tolerance for old buggy clients. peek, _ := c.bufr.Peek(4) // ReadRequest will get err below c.bufr.Discard(numLeadingCRorLF(peek)) } req, err := readRequest(c.bufr, keepHostHeader) - c.mu.Unlock() if err != nil { if c.r.hitReadLimit() { return nil, errTooLarge @@ -809,6 +969,11 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { body.doEarlyClose = true } + // Adjust the read deadline if necessary. + if !hdrDeadline.Equal(wholeReqDeadline) { + c.rwc.SetReadDeadline(wholeReqDeadline) + } + w = &response{ conn: c, cancelCtx: cancelCtx, @@ -816,6 +981,7 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { reqBody: req.Body, handlerHeader: make(Header), contentLength: -1, + closeNotifyCh: make(chan bool, 1), // We populate these ahead of time so we're not // reading from req.Header after their Handler starts @@ -990,7 +1156,17 @@ func (cw *chunkWriter) writeHeader(p []byte) { } var setHeader extraHeader + // Don't write out the fake "Trailer:foo" keys. See TrailerPrefix. trailers := false + for k := range cw.header { + if strings.HasPrefix(k, TrailerPrefix) { + if excludeHeader == nil { + excludeHeader = make(map[string]bool) + } + excludeHeader[k] = true + trailers = true + } + } for _, v := range cw.header["Trailer"] { trailers = true foreachHeaderElement(v, cw.res.declareTrailer) @@ -1318,7 +1494,9 @@ func (w *response) WriteString(data string) (n int, err error) { // either dataB or dataS is non-zero. func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) { if w.conn.hijacked() { - w.conn.server.logf("http: response.Write on hijacked connection") + if lenData > 0 { + w.conn.server.logf("http: response.Write on hijacked connection") + } return 0, ErrHijacked } if !w.wroteHeader { @@ -1354,6 +1532,8 @@ func (w *response) finishRequest() { w.cw.close() w.conn.bufw.Flush() + w.conn.r.abortPendingRead() + // Close the body (regardless of w.closeAfterReply) so we can // re-use its bufio.Reader later safely. w.reqBody.Close() @@ -1469,11 +1649,30 @@ func validNPN(proto string) bool { } func (c *conn) setState(nc net.Conn, state ConnState) { - if hook := c.server.ConnState; hook != nil { + srv := c.server + switch state { + case StateNew: + srv.trackConn(c, true) + case StateHijacked, StateClosed: + srv.trackConn(c, false) + } + c.curState.Store(connStateInterface[state]) + if hook := srv.ConnState; hook != nil { hook(nc, state) } } +// connStateInterface is an array of the interface{} versions of +// ConnState values, so we can use them in atomic.Values later without +// paying the cost of shoving their integers in an interface{}. +var connStateInterface = [...]interface{}{ + StateNew: StateNew, + StateActive: StateActive, + StateIdle: StateIdle, + StateHijacked: StateHijacked, + StateClosed: StateClosed, +} + // badRequestError is a literal string (used by in the server in HTML, // unescaped) to tell the user why their request was bad. It should // be plain text without user info or other embedded errors. @@ -1481,11 +1680,34 @@ type badRequestError string func (e badRequestError) Error() string { return "Bad Request: " + string(e) } +// ErrAbortHandler is a sentinel panic value to abort a handler. +// While any panic from ServeHTTP aborts the response to the client, +// panicking with ErrAbortHandler also suppresses logging of a stack +// trace to the server's error log. +var ErrAbortHandler = errors.New("net/http: abort Handler") + +// isCommonNetReadError reports whether err is a common error +// encountered during reading a request off the network when the +// client has gone away or had its read fail somehow. This is used to +// determine which logs are interesting enough to log about. +func isCommonNetReadError(err error) bool { + if err == io.EOF { + return true + } + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + return true + } + if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { + return true + } + return false +} + // Serve a new connection. func (c *conn) serve(ctx context.Context) { c.remoteAddr = c.rwc.RemoteAddr().String() defer func() { - if err := recover(); err != nil { + if err := recover(); err != nil && err != ErrAbortHandler { const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] @@ -1521,13 +1743,14 @@ func (c *conn) serve(ctx context.Context) { // HTTP/1.x from here on. - c.r = &connReader{r: c.rwc} - c.bufr = newBufioReader(c.r) - c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) - ctx, cancelCtx := context.WithCancel(ctx) + c.cancelCtx = cancelCtx defer cancelCtx() + c.r = &connReader{conn: c} + c.bufr = newBufioReader(c.r) + c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) + for { w, err := c.readRequest(ctx) if c.r.remain != c.server.initialReadLimitSize() { @@ -1535,27 +1758,29 @@ func (c *conn) serve(ctx context.Context) { c.setState(c.rwc, StateActive) } if err != nil { + const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n" + if err == errTooLarge { // Their HTTP client may or may not be // able to read this if we're // responding to them and hanging up // while they're still writing their // request. Undefined behavior. - io.WriteString(c.rwc, "HTTP/1.1 431 Request Header Fields Too Large\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n431 Request Header Fields Too Large") + const publicErr = "431 Request Header Fields Too Large" + fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) c.closeWriteAndWait() return } - if err == io.EOF { - return // don't reply - } - if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + if isCommonNetReadError(err) { return // don't reply } - var publicErr string + + publicErr := "400 Bad Request" if v, ok := err.(badRequestError); ok { - publicErr = ": " + string(v) + publicErr = publicErr + ": " + string(v) } - io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request"+publicErr) + + fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) return } @@ -1571,11 +1796,24 @@ func (c *conn) serve(ctx context.Context) { return } + c.curReq.Store(w) + + if requestBodyRemains(req.Body) { + registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead) + } else { + if w.conn.bufr.Buffered() > 0 { + w.conn.r.closeNotifyFromPipelinedRequest() + } + w.conn.r.startBackgroundRead() + } + // HTTP cannot have multiple simultaneous active requests.[*] // Until the server replies to this request, it can't read another, // so we might as well run the handler in this goroutine. // [*] Not strictly true: HTTP pipelining. We could let them all process // in parallel even if their responses need to be serialized. + // But we're not going to implement HTTP pipelining because it + // was never deployed in the wild and the answer is HTTP/2. serverHandler{c.server}.ServeHTTP(w, w.req) w.cancelCtx() if c.hijacked() { @@ -1589,6 +1827,23 @@ func (c *conn) serve(ctx context.Context) { return } c.setState(c.rwc, StateIdle) + c.curReq.Store((*response)(nil)) + + if !w.conn.server.doKeepAlives() { + // We're in shutdown mode. We might've replied + // to the user without "Connection: close" and + // they might think they can send another + // request, but such is life with HTTP/1.1. + return + } + + if d := c.server.idleTimeout(); d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + if _, err := c.bufr.Peek(4); err != nil { + return + } + } + c.rwc.SetReadDeadline(time.Time{}) } } @@ -1624,10 +1879,6 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { c.mu.Lock() defer c.mu.Unlock() - if w.closeNotifyCh != nil { - return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier in same ServeHTTP call") - } - // Release the bufioWriter that writes to the chunk writer, it is not // used after a connection has been hijacked. rwc, buf, err = c.hijackLocked() @@ -1642,50 +1893,7 @@ func (w *response) CloseNotify() <-chan bool { if w.handlerDone.isSet() { panic("net/http: CloseNotify called after ServeHTTP finished") } - c := w.conn - c.mu.Lock() - defer c.mu.Unlock() - - if w.closeNotifyCh != nil { - return w.closeNotifyCh - } - ch := make(chan bool, 1) - w.closeNotifyCh = ch - - if w.conn.hijackedv { - // CloseNotify is undefined after a hijack, but we have - // no place to return an error, so just return a channel, - // even though it'll never receive a value. - return ch - } - - var once sync.Once - notify := func() { once.Do(func() { ch <- true }) } - - if requestBodyRemains(w.reqBody) { - // They're still consuming the request body, so we - // shouldn't notify yet. - registerOnHitEOF(w.reqBody, func() { - c.mu.Lock() - defer c.mu.Unlock() - startCloseNotifyBackgroundRead(c, notify) - }) - } else { - startCloseNotifyBackgroundRead(c, notify) - } - return ch -} - -// c.mu must be held. -func startCloseNotifyBackgroundRead(c *conn, notify func()) { - if c.bufr.Buffered() > 0 { - // They've consumed the request body, so anything - // remaining is a pipelined request, which we - // document as firing on. - notify() - } else { - c.r.startBackgroundRead(notify) - } + return w.closeNotifyCh } func registerOnHitEOF(rc io.ReadCloser, fn func()) { @@ -1702,7 +1910,7 @@ func registerOnHitEOF(rc io.ReadCloser, fn func()) { // requestBodyRemains reports whether future calls to Read // on rc might yield more data. func requestBodyRemains(rc io.ReadCloser) bool { - if rc == eofReader { + if rc == NoBody { return false } switch v := rc.(type) { @@ -1816,7 +2024,7 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { } } - w.Header().Set("Location", urlStr) + w.Header().Set("Location", hexEscapeNonASCII(urlStr)) w.WriteHeader(code) // RFC 2616 recommends that a short note "SHOULD" be included in the @@ -2094,11 +2302,36 @@ func Serve(l net.Listener, handler Handler) error { // A Server defines parameters for running an HTTP server. // The zero value for Server is a valid configuration. type Server struct { - Addr string // TCP address to listen on, ":http" if empty - Handler Handler // handler to invoke, http.DefaultServeMux if nil - ReadTimeout time.Duration // maximum duration before timing out read of the request - WriteTimeout time.Duration // maximum duration before timing out write of the response - TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + Addr string // TCP address to listen on, ":http" if empty + Handler Handler // handler to invoke, http.DefaultServeMux if nil + TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + + // ReadTimeout is the maximum duration for reading the entire + // request, including the body. + // + // Because ReadTimeout does not let Handlers make per-request + // decisions on each request body's acceptable deadline or + // upload rate, most users will prefer to use + // ReadHeaderTimeout. It is valid to use them both. + ReadTimeout time.Duration + + // ReadHeaderTimeout is the amount of time allowed to read + // request headers. The connection's read deadline is reset + // after reading the headers and the Handler can decide what + // is considered too slow for the body. + ReadHeaderTimeout time.Duration + + // WriteTimeout is the maximum duration before timing out + // writes of the response. It is reset whenever a new + // request's header is read. Like ReadTimeout, it does not + // let Handlers make decisions on a per-request basis. + WriteTimeout time.Duration + + // IdleTimeout is the maximum amount of time to wait for the + // next request when keep-alives are enabled. If IdleTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, there is no timeout. + IdleTimeout time.Duration // MaxHeaderBytes controls the maximum number of bytes the // server will read parsing the request header's keys and @@ -2114,7 +2347,8 @@ type Server struct { // handle HTTP requests and will initialize the Request's TLS // and RemoteAddr if not already set. The connection is // automatically closed when the function returns. - // If TLSNextProto is nil, HTTP/2 support is enabled automatically. + // If TLSNextProto is not nil, HTTP/2 support is not enabled + // automatically. TLSNextProto map[string]func(*Server, *tls.Conn, Handler) // ConnState specifies an optional callback function that is @@ -2129,8 +2363,132 @@ type Server struct { ErrorLog *log.Logger disableKeepAlives int32 // accessed atomically. + inShutdown int32 // accessed atomically (non-zero means we're in Shutdown) nextProtoOnce sync.Once // guards setupHTTP2_* init nextProtoErr error // result of http2.ConfigureServer if used + + mu sync.Mutex + listeners map[net.Listener]struct{} + activeConn map[*conn]struct{} + doneChan chan struct{} +} + +func (s *Server) getDoneChan() <-chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + return s.getDoneChanLocked() +} + +func (s *Server) getDoneChanLocked() chan struct{} { + if s.doneChan == nil { + s.doneChan = make(chan struct{}) + } + return s.doneChan +} + +func (s *Server) closeDoneChanLocked() { + ch := s.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again. + default: + // Safe to close here. We're the only closer, guarded + // by s.mu. + close(ch) + } +} + +// Close immediately closes all active net.Listeners and any +// connections in state StateNew, StateActive, or StateIdle. For a +// graceful shutdown, use Shutdown. +// +// Close does not attempt to close (and does not even know about) +// any hijacked connections, such as WebSockets. +// +// Close returns any error returned from closing the Server's +// underlying Listener(s). +func (srv *Server) Close() error { + srv.mu.Lock() + defer srv.mu.Unlock() + srv.closeDoneChanLocked() + err := srv.closeListenersLocked() + for c := range srv.activeConn { + c.rwc.Close() + delete(srv.activeConn, c) + } + return err +} + +// shutdownPollInterval is how often we poll for quiescence +// during Server.Shutdown. This is lower during tests, to +// speed up tests. +// Ideally we could find a solution that doesn't involve polling, +// but which also doesn't have a high runtime cost (and doesn't +// involve any contentious mutexes), but that is left as an +// exercise for the reader. +var shutdownPollInterval = 500 * time.Millisecond + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners, then closing all idle connections, and then waiting +// indefinitely for connections to return to idle and then shut down. +// If the provided context expires before the shutdown is complete, +// then the context's error is returned. +// +// Shutdown does not attempt to close nor wait for hijacked +// connections such as WebSockets. The caller of Shutdown should +// separately notify such long-lived connections of shutdown and wait +// for them to close, if desired. +func (srv *Server) Shutdown(ctx context.Context) error { + atomic.AddInt32(&srv.inShutdown, 1) + defer atomic.AddInt32(&srv.inShutdown, -1) + + srv.mu.Lock() + lnerr := srv.closeListenersLocked() + srv.closeDoneChanLocked() + srv.mu.Unlock() + + ticker := time.NewTicker(shutdownPollInterval) + defer ticker.Stop() + for { + if srv.closeIdleConns() { + return lnerr + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// closeIdleConns closes all idle connections and reports whether the +// server is quiescent. +func (s *Server) closeIdleConns() bool { + s.mu.Lock() + defer s.mu.Unlock() + quiescent := true + for c := range s.activeConn { + st, ok := c.curState.Load().(ConnState) + if !ok || st != StateIdle { + quiescent = false + continue + } + c.rwc.Close() + delete(s.activeConn, c) + } + return quiescent +} + +func (s *Server) closeListenersLocked() error { + var err error + for ln := range s.listeners { + if cerr := ln.Close(); cerr != nil && err == nil { + err = cerr + } + delete(s.listeners, ln) + } + return err } // A ConnState represents the state of a client connection to a server. @@ -2243,6 +2601,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool { return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS) } +var ErrServerClosed = errors.New("http: Server closed") + // Serve accepts incoming connections on the Listener l, creating a // new service goroutine for each. The service goroutines read requests and // then call srv.Handler to reply to them. @@ -2252,7 +2612,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool { // srv.TLSConfig is non-nil and doesn't include the string "h2" in // Config.NextProtos, HTTP/2 support is not enabled. // -// Serve always returns a non-nil error. +// Serve always returns a non-nil error. After Shutdown or Close, the +// returned error is ErrServerClosed. func (srv *Server) Serve(l net.Listener) error { defer l.Close() if fn := testHookServerServe; fn != nil { @@ -2264,14 +2625,20 @@ func (srv *Server) Serve(l net.Listener) error { return err } - // TODO: allow changing base context? can't imagine concrete - // use cases yet. - baseCtx := context.Background() + srv.trackListener(l, true) + defer srv.trackListener(l, false) + + baseCtx := context.Background() // base is always background, per Issue 16220 ctx := context.WithValue(baseCtx, ServerContextKey, srv) ctx = context.WithValue(ctx, LocalAddrContextKey, l.Addr()) for { rw, e := l.Accept() if e != nil { + select { + case <-srv.getDoneChan(): + return ErrServerClosed + default: + } if ne, ok := e.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { tempDelay = 5 * time.Millisecond @@ -2294,8 +2661,57 @@ func (srv *Server) Serve(l net.Listener) error { } } +func (s *Server) trackListener(ln net.Listener, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.listeners == nil { + s.listeners = make(map[net.Listener]struct{}) + } + if add { + // If the *Server is being reused after a previous + // Close or Shutdown, reset its doneChan: + if len(s.listeners) == 0 && len(s.activeConn) == 0 { + s.doneChan = nil + } + s.listeners[ln] = struct{}{} + } else { + delete(s.listeners, ln) + } +} + +func (s *Server) trackConn(c *conn, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.activeConn == nil { + s.activeConn = make(map[*conn]struct{}) + } + if add { + s.activeConn[c] = struct{}{} + } else { + delete(s.activeConn, c) + } +} + +func (s *Server) idleTimeout() time.Duration { + if s.IdleTimeout != 0 { + return s.IdleTimeout + } + return s.ReadTimeout +} + +func (s *Server) readHeaderTimeout() time.Duration { + if s.ReadHeaderTimeout != 0 { + return s.ReadHeaderTimeout + } + return s.ReadTimeout +} + func (s *Server) doKeepAlives() bool { - return atomic.LoadInt32(&s.disableKeepAlives) == 0 + return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown() +} + +func (s *Server) shuttingDown() bool { + return atomic.LoadInt32(&s.inShutdown) != 0 } // SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled. @@ -2305,9 +2721,21 @@ func (s *Server) doKeepAlives() bool { func (srv *Server) SetKeepAlivesEnabled(v bool) { if v { atomic.StoreInt32(&srv.disableKeepAlives, 0) - } else { - atomic.StoreInt32(&srv.disableKeepAlives, 1) + return } + atomic.StoreInt32(&srv.disableKeepAlives, 1) + + // Close idle HTTP/1 conns: + srv.closeIdleConns() + + // Close HTTP/2 conns, as soon as they become idle, but reset + // the chan so future conns (if the listener is still active) + // still work and don't get a GOAWAY immediately, before their + // first request: + srv.mu.Lock() + defer srv.mu.Unlock() + srv.closeDoneChanLocked() // closes http2 conns + srv.doneChan = nil } func (s *Server) logf(format string, args ...interface{}) { @@ -2630,24 +3058,6 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { } } -type eofReaderWithWriteTo struct{} - -func (eofReaderWithWriteTo) WriteTo(io.Writer) (int64, error) { return 0, nil } -func (eofReaderWithWriteTo) Read([]byte) (int, error) { return 0, io.EOF } - -// eofReader is a non-nil io.ReadCloser that always returns EOF. -// It has a WriteTo method so io.Copy won't need a buffer. -var eofReader = &struct { - eofReaderWithWriteTo - io.Closer -}{ - eofReaderWithWriteTo{}, - ioutil.NopCloser(nil), -} - -// Verify that an io.Copy from an eofReader won't require a buffer. -var _ io.WriterTo = eofReader - // initNPNRequest is an HTTP handler that initializes certain // uninitialized fields in its *Request. Such partially-initialized // Requests come from NPN protocol handlers. @@ -2662,7 +3072,7 @@ func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) { *req.TLS = h.c.ConnectionState() } if req.Body == nil { - req.Body = eofReader + req.Body = NoBody } if req.RemoteAddr == "" { req.RemoteAddr = h.c.RemoteAddr().String() @@ -2723,6 +3133,7 @@ func (w checkConnErrorWriter) Write(p []byte) (n int, err error) { n, err = w.c.rwc.Write(p) if err != nil && w.c.werr == nil { w.c.werr = err + w.c.cancelCtx() } return } |