diff options
author | Ian Lance Taylor <ian@gcc.gnu.org> | 2016-07-22 18:15:38 +0000 |
---|---|---|
committer | Ian Lance Taylor <ian@gcc.gnu.org> | 2016-07-22 18:15:38 +0000 |
commit | 22b955cca564a9a3a5b8c9d9dd1e295b7943c128 (patch) | |
tree | abdbd898676e1f853fca2d7e031d105d7ebcf676 /libgo/go/net/http/h2_bundle.go | |
parent | 9d04a3af4c6491536badf6bde9707c907e4d196b (diff) | |
download | gcc-22b955cca564a9a3a5b8c9d9dd1e295b7943c128.tar.gz |
libgo: update to go1.7rc3
Reviewed-on: https://go-review.googlesource.com/25150
From-SVN: r238662
Diffstat (limited to 'libgo/go/net/http/h2_bundle.go')
-rw-r--r-- | libgo/go/net/http/h2_bundle.go | 1729 |
1 files changed, 1060 insertions, 669 deletions
diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index 4e19b3e71f7..db774554b2c 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -1,5 +1,5 @@ // Code generated by golang.org/x/tools/cmd/bundle. -//go:generate bundle -o h2_bundle.go -prefix http2 -import golang.org/x/net/http2/hpack=internal/golang.org/x/net/http2/hpack golang.org/x/net/http2 +//go:generate bundle -o h2_bundle.go -prefix http2 golang.org/x/net/http2 // Package http2 implements the HTTP/2 protocol. // @@ -20,15 +20,16 @@ import ( "bufio" "bytes" "compress/gzip" + "context" "crypto/tls" "encoding/binary" "errors" "fmt" - "internal/golang.org/x/net/http2/hpack" "io" "io/ioutil" "log" "net" + "net/http/httptrace" "net/textproto" "net/url" "os" @@ -39,6 +40,9 @@ import ( "strings" "sync" "time" + + "golang_org/x/net/http2/hpack" + "golang_org/x/net/lex/httplex" ) // ClientConnPool manages a pool of HTTP/2 client connections. @@ -47,6 +51,18 @@ type http2ClientConnPool interface { MarkDead(*http2ClientConn) } +// clientConnPoolIdleCloser is the interface implemented by ClientConnPool +// implementations which can close their idle connections. +type http2clientConnPoolIdleCloser interface { + http2ClientConnPool + closeIdleConnections() +} + +var ( + _ http2clientConnPoolIdleCloser = (*http2clientConnPool)(nil) + _ http2clientConnPoolIdleCloser = http2noDialClientConnPool{} +) + // TODO: use singleflight for dialing and addConnCalls? type http2clientConnPool struct { t *http2Transport @@ -247,6 +263,15 @@ func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) [ return out } +// noDialClientConnPool is an implementation of http2.ClientConnPool +// which never dials. We let the HTTP/1.1 client dial and use its TLS +// connection instead. +type http2noDialClientConnPool struct{ *http2clientConnPool } + +func (p http2noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { + return p.getClientConn(req, addr, http2noDialOnMiss) +} + func http2configureTransport(t1 *Transport) (*http2Transport, error) { connPool := new(http2clientConnPool) t2 := &http2Transport{ @@ -267,7 +292,7 @@ func http2configureTransport(t1 *Transport) (*http2Transport, error) { t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") } upgradeFn := func(authority string, c *tls.Conn) RoundTripper { - addr := http2authorityAddr(authority) + addr := http2authorityAddr("https", authority) if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { go c.Close() return http2erringRoundTripper{err} @@ -299,15 +324,6 @@ func http2registerHTTPSProtocol(t *Transport, rt RoundTripper) (err error) { return nil } -// noDialClientConnPool is an implementation of http2.ClientConnPool -// which never dials. We let the HTTP/1.1 client dial and use its TLS -// connection instead. -type http2noDialClientConnPool struct{ *http2clientConnPool } - -func (p http2noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { - return p.getClientConn(req, addr, http2noDialOnMiss) -} - // noDialH2RoundTripper is a RoundTripper which only tries to complete the request // if there's already has a cached connection to the host. type http2noDialH2RoundTripper struct{ t *http2Transport } @@ -403,6 +419,35 @@ func (e http2connError) Error() string { return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason) } +type http2pseudoHeaderError string + +func (e http2pseudoHeaderError) Error() string { + return fmt.Sprintf("invalid pseudo-header %q", string(e)) +} + +type http2duplicatePseudoHeaderError string + +func (e http2duplicatePseudoHeaderError) Error() string { + return fmt.Sprintf("duplicate pseudo-header %q", string(e)) +} + +type http2headerFieldNameError string + +func (e http2headerFieldNameError) Error() string { + return fmt.Sprintf("invalid header field name %q", string(e)) +} + +type http2headerFieldValueError string + +func (e http2headerFieldValueError) Error() string { + return fmt.Sprintf("invalid header field value %q", string(e)) +} + +var ( + http2errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers") + http2errPseudoAfterRegular = errors.New("pseudo header field after regular") +) + // fixedBuffer is an io.ReadWriter backed by a fixed size buffer. // It never allocates, but moves old data as new data is written. type http2fixedBuffer struct { @@ -743,7 +788,7 @@ type http2Frame interface { type http2Framer struct { r io.Reader lastFrame http2Frame - errReason string + errDetail error // lastHeaderStream is non-zero if the last frame was an // unfinished HEADERS/CONTINUATION. @@ -775,14 +820,33 @@ type http2Framer struct { // to return non-compliant frames or frame orders. // This is for testing and permits using the Framer to test // other HTTP/2 implementations' conformance to the spec. + // It is not compatible with ReadMetaHeaders. AllowIllegalReads bool + // ReadMetaHeaders if non-nil causes ReadFrame to merge + // HEADERS and CONTINUATION frames together and return + // MetaHeadersFrame instead. + ReadMetaHeaders *hpack.Decoder + + // MaxHeaderListSize is the http2 MAX_HEADER_LIST_SIZE. + // It's used only if ReadMetaHeaders is set; 0 means a sane default + // (currently 16MB) + // If the limit is hit, MetaHeadersFrame.Truncated is set true. + MaxHeaderListSize uint32 + logReads bool debugFramer *http2Framer // only use for logging written writes debugFramerBuf *bytes.Buffer } +func (fr *http2Framer) maxHeaderListSize() uint32 { + if fr.MaxHeaderListSize == 0 { + return 16 << 20 + } + return fr.MaxHeaderListSize +} + func (f *http2Framer) startWrite(ftype http2FrameType, flags http2Flags, streamID uint32) { f.wbuf = append(f.wbuf[:0], @@ -879,6 +943,17 @@ func (fr *http2Framer) SetMaxReadFrameSize(v uint32) { fr.maxReadSize = v } +// ErrorDetail returns a more detailed error of the last error +// returned by Framer.ReadFrame. For instance, if ReadFrame +// returns a StreamError with code PROTOCOL_ERROR, ErrorDetail +// will say exactly what was invalid. ErrorDetail is not guaranteed +// to return a non-nil value and like the rest of the http2 package, +// its return value is not protected by an API compatibility promise. +// ErrorDetail is reset after the next call to ReadFrame. +func (fr *http2Framer) ErrorDetail() error { + return fr.errDetail +} + // ErrFrameTooLarge is returned from Framer.ReadFrame when the peer // sends a frame that is larger than declared with SetMaxReadFrameSize. var http2ErrFrameTooLarge = errors.New("http2: frame too large") @@ -897,9 +972,10 @@ func http2terminalReadFrameError(err error) bool { // // If the frame is larger than previously set with SetMaxReadFrameSize, the // returned error is ErrFrameTooLarge. Other errors may be of type -// ConnectionError, StreamError, or anything else from from the underlying +// ConnectionError, StreamError, or anything else from the underlying // reader. func (fr *http2Framer) ReadFrame() (http2Frame, error) { + fr.errDetail = nil if fr.lastFrame != nil { fr.lastFrame.invalidate() } @@ -927,6 +1003,9 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { if fr.logReads { log.Printf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) } + if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { + return fr.readMetaFrame(f.(*http2HeadersFrame)) + } return f, nil } @@ -935,7 +1014,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { // to the peer before hanging up on them. This might help others debug // their implementations. func (fr *http2Framer) connError(code http2ErrCode, reason string) error { - fr.errReason = reason + fr.errDetail = errors.New(reason) return http2ConnectionError(code) } @@ -1023,7 +1102,14 @@ func http2parseDataFrame(fh http2FrameHeader, payload []byte) (http2Frame, error return f, nil } -var http2errStreamID = errors.New("invalid streamid") +var ( + http2errStreamID = errors.New("invalid stream ID") + http2errDepStreamID = errors.New("invalid dependent stream ID") +) + +func http2validStreamIDOrZero(streamID uint32) bool { + return streamID&(1<<31) == 0 +} func http2validStreamID(streamID uint32) bool { return streamID != 0 && streamID&(1<<31) == 0 @@ -1389,8 +1475,8 @@ func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { } if !p.Priority.IsZero() { v := p.Priority.StreamDep - if !http2validStreamID(v) && !f.AllowIllegalWrites { - return errors.New("invalid dependent stream id") + if !http2validStreamIDOrZero(v) && !f.AllowIllegalWrites { + return http2errDepStreamID } if p.Priority.Exclusive { v |= 1 << 31 @@ -1458,6 +1544,9 @@ func (f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error if !http2validStreamID(streamID) && !f.AllowIllegalWrites { return http2errStreamID } + if !http2validStreamIDOrZero(p.StreamDep) { + return http2errDepStreamID + } f.startWrite(http2FramePriority, 0, streamID) v := p.StreamDep if p.Exclusive { @@ -1669,6 +1758,193 @@ type http2headersEnder interface { HeadersEnded() bool } +type http2headersOrContinuation interface { + http2headersEnder + HeaderBlockFragment() []byte +} + +// A MetaHeadersFrame is the representation of one HEADERS frame and +// zero or more contiguous CONTINUATION frames and the decoding of +// their HPACK-encoded contents. +// +// This type of frame does not appear on the wire and is only returned +// by the Framer when Framer.ReadMetaHeaders is set. +type http2MetaHeadersFrame struct { + *http2HeadersFrame + + // Fields are the fields contained in the HEADERS and + // CONTINUATION frames. The underlying slice is owned by the + // Framer and must not be retained after the next call to + // ReadFrame. + // + // Fields are guaranteed to be in the correct http2 order and + // not have unknown pseudo header fields or invalid header + // field names or values. Required pseudo header fields may be + // missing, however. Use the MetaHeadersFrame.Pseudo accessor + // method access pseudo headers. + Fields []hpack.HeaderField + + // Truncated is whether the max header list size limit was hit + // and Fields is incomplete. The hpack decoder state is still + // valid, however. + Truncated bool +} + +// PseudoValue returns the given pseudo header field's value. +// The provided pseudo field should not contain the leading colon. +func (mh *http2MetaHeadersFrame) PseudoValue(pseudo string) string { + for _, hf := range mh.Fields { + if !hf.IsPseudo() { + return "" + } + if hf.Name[1:] == pseudo { + return hf.Value + } + } + return "" +} + +// RegularFields returns the regular (non-pseudo) header fields of mh. +// The caller does not own the returned slice. +func (mh *http2MetaHeadersFrame) RegularFields() []hpack.HeaderField { + for i, hf := range mh.Fields { + if !hf.IsPseudo() { + return mh.Fields[i:] + } + } + return nil +} + +// PseudoFields returns the pseudo header fields of mh. +// The caller does not own the returned slice. +func (mh *http2MetaHeadersFrame) PseudoFields() []hpack.HeaderField { + for i, hf := range mh.Fields { + if !hf.IsPseudo() { + return mh.Fields[:i] + } + } + return mh.Fields +} + +func (mh *http2MetaHeadersFrame) checkPseudos() error { + var isRequest, isResponse bool + pf := mh.PseudoFields() + for i, hf := range pf { + switch hf.Name { + case ":method", ":path", ":scheme", ":authority": + isRequest = true + case ":status": + isResponse = true + default: + return http2pseudoHeaderError(hf.Name) + } + + for _, hf2 := range pf[:i] { + if hf.Name == hf2.Name { + return http2duplicatePseudoHeaderError(hf.Name) + } + } + } + if isRequest && isResponse { + return http2errMixPseudoHeaderTypes + } + return nil +} + +func (fr *http2Framer) maxHeaderStringLen() int { + v := fr.maxHeaderListSize() + if uint32(int(v)) == v { + return int(v) + } + + return 0 +} + +// readMetaFrame returns 0 or more CONTINUATION frames from fr and +// merge them into into the provided hf and returns a MetaHeadersFrame +// with the decoded hpack values. +func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFrame, error) { + if fr.AllowIllegalReads { + return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders") + } + mh := &http2MetaHeadersFrame{ + http2HeadersFrame: hf, + } + var remainSize = fr.maxHeaderListSize() + var sawRegular bool + + var invalid error // pseudo header field errors + hdec := fr.ReadMetaHeaders + hdec.SetEmitEnabled(true) + hdec.SetMaxStringLength(fr.maxHeaderStringLen()) + hdec.SetEmitFunc(func(hf hpack.HeaderField) { + if !httplex.ValidHeaderFieldValue(hf.Value) { + invalid = http2headerFieldValueError(hf.Value) + } + isPseudo := strings.HasPrefix(hf.Name, ":") + if isPseudo { + if sawRegular { + invalid = http2errPseudoAfterRegular + } + } else { + sawRegular = true + if !http2validWireHeaderFieldName(hf.Name) { + invalid = http2headerFieldNameError(hf.Name) + } + } + + if invalid != nil { + hdec.SetEmitEnabled(false) + return + } + + size := hf.Size() + if size > remainSize { + hdec.SetEmitEnabled(false) + mh.Truncated = true + return + } + remainSize -= size + + mh.Fields = append(mh.Fields, hf) + }) + + defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {}) + + var hc http2headersOrContinuation = hf + for { + frag := hc.HeaderBlockFragment() + if _, err := hdec.Write(frag); err != nil { + return nil, http2ConnectionError(http2ErrCodeCompression) + } + + if hc.HeadersEnded() { + break + } + if f, err := fr.ReadFrame(); err != nil { + return nil, err + } else { + hc = f.(*http2ContinuationFrame) + } + } + + mh.http2HeadersFrame.headerFragBuf = nil + mh.http2HeadersFrame.invalidate() + + if err := hdec.Close(); err != nil { + return nil, http2ConnectionError(http2ErrCodeCompression) + } + if invalid != nil { + fr.errDetail = invalid + return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol} + } + if err := mh.checkPseudos(); err != nil { + fr.errDetail = err + return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol} + } + return mh, nil +} + func http2summarizeFrame(f http2Frame) string { var buf bytes.Buffer f.Header().writeDebug(&buf) @@ -1712,7 +1988,111 @@ func http2summarizeFrame(f http2Frame) string { return buf.String() } -func http2requestCancel(req *Request) <-chan struct{} { return req.Cancel } +func http2transportExpectContinueTimeout(t1 *Transport) time.Duration { + return t1.ExpectContinueTimeout +} + +// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. +func http2isBadCipher(cipher uint16) bool { + switch cipher { + case tls.TLS_RSA_WITH_RC4_128_SHA, + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: + + return true + default: + return false + } +} + +type http2contextContext interface { + context.Context +} + +func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx http2contextContext, cancel func()) { + ctx, cancel = context.WithCancel(context.Background()) + ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr()) + if hs := opts.baseConfig(); hs != nil { + ctx = context.WithValue(ctx, ServerContextKey, hs) + } + return +} + +func http2contextWithCancel(ctx http2contextContext) (_ http2contextContext, cancel func()) { + return context.WithCancel(ctx) +} + +func http2requestWithContext(req *Request, ctx http2contextContext) *Request { + return req.WithContext(ctx) +} + +type http2clientTrace httptrace.ClientTrace + +func http2reqContext(r *Request) context.Context { return r.Context() } + +func http2setResponseUncompressed(res *Response) { res.Uncompressed = true } + +func http2traceGotConn(req *Request, cc *http2ClientConn) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.GotConn == nil { + return + } + ci := httptrace.GotConnInfo{Conn: cc.tconn} + cc.mu.Lock() + ci.Reused = cc.nextStreamID > 1 + ci.WasIdle = len(cc.streams) == 0 && ci.Reused + if ci.WasIdle && !cc.lastActive.IsZero() { + ci.IdleTime = time.Now().Sub(cc.lastActive) + } + cc.mu.Unlock() + + trace.GotConn(ci) +} + +func http2traceWroteHeaders(trace *http2clientTrace) { + if trace != nil && trace.WroteHeaders != nil { + trace.WroteHeaders() + } +} + +func http2traceGot100Continue(trace *http2clientTrace) { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() + } +} + +func http2traceWait100Continue(trace *http2clientTrace) { + if trace != nil && trace.Wait100Continue != nil { + trace.Wait100Continue() + } +} + +func http2traceWroteRequest(trace *http2clientTrace, err error) { + if trace != nil && trace.WroteRequest != nil { + trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) + } +} + +func http2traceFirstResponseByte(trace *http2clientTrace) { + if trace != nil && trace.GotFirstResponseByte != nil { + trace.GotFirstResponseByte() + } +} + +func http2requestTrace(req *Request) *http2clientTrace { + trace := httptrace.ContextClientTrace(req.Context()) + return (*http2clientTrace)(trace) +} var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" @@ -2070,57 +2450,23 @@ var ( http2errInvalidHeaderFieldValue = errors.New("http2: invalid header field value") ) -// validHeaderFieldName reports whether v is a valid header field name (key). -// RFC 7230 says: -// header-field = field-name ":" OWS field-value OWS -// field-name = token -// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / -// "^" / "_" / " +// validWireHeaderFieldName reports whether v is a valid header field +// name (key). See httplex.ValidHeaderName for the base rules. +// // Further, http2 says: // "Just as in HTTP/1.x, header field names are strings of ASCII // characters that are compared in a case-insensitive // fashion. However, header field names MUST be converted to // lowercase prior to their encoding in HTTP/2. " -func http2validHeaderFieldName(v string) bool { +func http2validWireHeaderFieldName(v string) bool { if len(v) == 0 { return false } for _, r := range v { - if int(r) >= len(http2isTokenTable) || ('A' <= r && r <= 'Z') { - return false - } - if !http2isTokenTable[byte(r)] { + if !httplex.IsTokenRune(r) { return false } - } - return true -} - -// validHeaderFieldValue reports whether v is a valid header field value. -// -// RFC 7230 says: -// field-value = *( field-content / obs-fold ) -// obj-fold = N/A to http2, and deprecated -// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] -// field-vchar = VCHAR / obs-text -// obs-text = %x80-FF -// VCHAR = "any visible [USASCII] character" -// -// http2 further says: "Similarly, HTTP/2 allows header field values -// that are not valid. While most of the values that can be encoded -// will not alter header field parsing, carriage return (CR, ASCII -// 0xd), line feed (LF, ASCII 0xa), and the zero character (NUL, ASCII -// 0x0) might be exploited by an attacker if they are translated -// verbatim. Any request or response that contains a character not -// permitted in a header field value MUST be treated as malformed -// (Section 8.1.2.6). Valid characters are defined by the -// field-content ABNF rule in Section 3.2 of [RFC7230]." -// -// This function does not (yet?) properly handle the rejection of -// strings that begin or end with SP or HTAB. -func http2validHeaderFieldValue(v string) bool { - for i := 0; i < len(v); i++ { - if b := v[i]; b < ' ' && b != '\t' || b == 0x7f { + if 'A' <= r && r <= 'Z' { return false } } @@ -2225,7 +2571,7 @@ func http2mustUint31(v int32) uint32 { } // bodyAllowedForStatus reports whether a given response status code -// permits a body. See RFC2616, section 4.4. +// permits a body. See RFC 2616, section 4.4. func http2bodyAllowedForStatus(status int) bool { switch { case status >= 100 && status <= 199: @@ -2251,90 +2597,44 @@ func (e *http2httpError) Temporary() bool { return true } var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true} -var http2isTokenTable = [127]bool{ - '!': true, - '#': true, - '$': true, - '%': true, - '&': true, - '\'': true, - '*': true, - '+': true, - '-': true, - '.': true, - '0': true, - '1': true, - '2': true, - '3': true, - '4': true, - '5': true, - '6': true, - '7': true, - '8': true, - '9': true, - 'A': true, - 'B': true, - 'C': true, - 'D': true, - 'E': true, - 'F': true, - 'G': true, - 'H': true, - 'I': true, - 'J': true, - 'K': true, - 'L': true, - 'M': true, - 'N': true, - 'O': true, - 'P': true, - 'Q': true, - 'R': true, - 'S': true, - 'T': true, - 'U': true, - 'W': true, - 'V': true, - 'X': true, - 'Y': true, - 'Z': true, - '^': true, - '_': true, - '`': true, - 'a': true, - 'b': true, - 'c': true, - 'd': true, - 'e': true, - 'f': true, - 'g': true, - 'h': true, - 'i': true, - 'j': true, - 'k': true, - 'l': true, - 'm': true, - 'n': true, - 'o': true, - 'p': true, - 'q': true, - 'r': true, - 's': true, - 't': true, - 'u': true, - 'v': true, - 'w': true, - 'x': true, - 'y': true, - 'z': true, - '|': true, - '~': true, -} - type http2connectionStater interface { ConnectionState() tls.ConnectionState } +var http2sorterPool = sync.Pool{New: func() interface{} { return new(http2sorter) }} + +type http2sorter struct { + v []string // owned by sorter +} + +func (s *http2sorter) Len() int { return len(s.v) } + +func (s *http2sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] } + +func (s *http2sorter) Less(i, j int) bool { return s.v[i] < s.v[j] } + +// Keys returns the sorted keys of h. +// +// The returned slice is only valid until s used again or returned to +// its pool. +func (s *http2sorter) Keys(h Header) []string { + keys := s.v[:0] + for k := range h { + keys = append(keys, k) + } + s.v = keys + sort.Sort(s) + return keys +} + +func (s *http2sorter) SortStrings(ss []string) { + + save := s.v + s.v = ss + sort.Sort(s) + s.v = save +} + // pipe is a goroutine-safe io.Reader/io.Writer pair. It's like // io.Pipe except there are no PipeReader/PipeWriter halves, and the // underlying buffer is an interface. (io.Pipe is always unbuffered) @@ -2354,6 +2654,12 @@ type http2pipeBuffer interface { io.Reader } +func (p *http2pipe) Len() int { + p.mu.Lock() + defer p.mu.Unlock() + return p.b.Len() +} + // Read waits until data is available and copies bytes // from the buffer into p. func (p *http2pipe) Read(d []byte) (n int, err error) { @@ -2653,10 +2959,14 @@ func (o *http2ServeConnOpts) handler() Handler { // // The opts parameter is optional. If nil, default values are used. func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { + baseCtx, cancel := http2serverConnBaseContext(c, opts) + defer cancel() + sc := &http2serverConn{ srv: s, hs: opts.baseConfig(), conn: c, + baseCtx: baseCtx, remoteAddrStr: c.RemoteAddr().String(), bw: http2newBufferedWriter(c), handler: opts.handler(), @@ -2675,13 +2985,14 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { serveG: http2newGoroutineLock(), pushEnabled: true, } + sc.flow.add(http2initialWindowSize) sc.inflow.add(http2initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) - sc.hpackDecoder = hpack.NewDecoder(http2initialHeaderTableSize, nil) - sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen()) fr := http2NewFramer(sc.bw, c) + fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + fr.MaxHeaderListSize = sc.maxHeaderListSize() fr.SetMaxReadFrameSize(s.maxReadFrameSize()) sc.framer = fr @@ -2711,27 +3022,6 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { sc.serve() } -// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. -func http2isBadCipher(cipher uint16) bool { - switch cipher { - case tls.TLS_RSA_WITH_RC4_128_SHA, - tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, - tls.TLS_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, - tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: - - return true - default: - return false - } -} - func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) { sc.vlogf("http2: server rejecting conn: %v, %s", err, debug) @@ -2747,8 +3037,8 @@ type http2serverConn struct { conn net.Conn bw *http2bufferedWriter // writing to conn handler Handler + baseCtx http2contextContext framer *http2Framer - hpackDecoder *hpack.Decoder doneServing chan struct{} // closed when serverConn.serve ends readFrameCh chan http2readFrameResult // written by serverConn.readFrames wantWriteFrameCh chan http2frameWriteMsg // from handlers -> serve @@ -2775,7 +3065,6 @@ type http2serverConn struct { headerTableSize uint32 peerMaxHeaderListSize uint32 // zero means unknown (default) canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case - req http2requestParam // non-zero while reading request headers writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh needsFrameFlush bool // last frame write wasn't a flush writeSched http2writeScheduler @@ -2784,21 +3073,13 @@ type http2serverConn struct { goAwayCode http2ErrCode shutdownTimerCh <-chan time.Time // nil until used shutdownTimer *time.Timer // nil until used + freeRequestBodyBuf []byte // if non-nil, a free initialWindowSize buffer for getRequestBodyBuf // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer hpackEncoder *hpack.Encoder } -func (sc *http2serverConn) maxHeaderStringLen() int { - v := sc.maxHeaderListSize() - if uint32(int(v)) == v { - return int(v) - } - - return 0 -} - func (sc *http2serverConn) maxHeaderListSize() uint32 { n := sc.hs.MaxHeaderBytes if n <= 0 { @@ -2811,21 +3092,6 @@ func (sc *http2serverConn) maxHeaderListSize() uint32 { return uint32(n + typicalHeaders*perFieldOverhead) } -// requestParam is the state of the next request, initialized over -// potentially several frames HEADERS + zero or more CONTINUATION -// frames. -type http2requestParam struct { - // stream is non-nil if we're reading (HEADER or CONTINUATION) - // frames for a request (but not DATA). - stream *http2stream - header Header - method, path string - scheme, authority string - sawRegularHeader bool // saw a non-pseudo header already - invalidHeader bool // an invalid header was seen - headerListSize int64 // actually uint32, but easier math this way -} - // stream represents a stream. This is the minimal metadata needed by // the serve goroutine. Most of the actual stream state is owned by // the http.Handler's goroutine in the responseWriter. Because the @@ -2835,10 +3101,12 @@ type http2requestParam struct { // responseWriter's state field. type http2stream struct { // immutable: - sc *http2serverConn - id uint32 - body *http2pipe // non-nil if expecting DATA frames - cw http2closeWaiter // closed wait stream transitions to closed state + sc *http2serverConn + id uint32 + body *http2pipe // non-nil if expecting DATA frames + cw http2closeWaiter // closed wait stream transitions to closed state + ctx http2contextContext + cancelCtx func() // owned by serverConn's serve loop: bodyBytes int64 // body bytes seen so far @@ -2852,6 +3120,8 @@ type http2stream struct { sentReset bool // only true once detached from streams map gotReset bool // only true once detacted from streams map gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + reqBuf []byte trailer Header // accumulated trailers reqTrailer Header // handler's Request.Trailer @@ -2952,83 +3222,6 @@ func (sc *http2serverConn) condlogf(err error, format string, args ...interface{ } } -func (sc *http2serverConn) onNewHeaderField(f hpack.HeaderField) { - sc.serveG.check() - if http2VerboseLogs { - sc.vlogf("http2: server decoded %v", f) - } - switch { - case !http2validHeaderFieldValue(f.Value): - sc.req.invalidHeader = true - case strings.HasPrefix(f.Name, ":"): - if sc.req.sawRegularHeader { - sc.logf("pseudo-header after regular header") - sc.req.invalidHeader = true - return - } - var dst *string - switch f.Name { - case ":method": - dst = &sc.req.method - case ":path": - dst = &sc.req.path - case ":scheme": - dst = &sc.req.scheme - case ":authority": - dst = &sc.req.authority - default: - - sc.logf("invalid pseudo-header %q", f.Name) - sc.req.invalidHeader = true - return - } - if *dst != "" { - sc.logf("duplicate pseudo-header %q sent", f.Name) - sc.req.invalidHeader = true - return - } - *dst = f.Value - case !http2validHeaderFieldName(f.Name): - sc.req.invalidHeader = true - default: - sc.req.sawRegularHeader = true - sc.req.header.Add(sc.canonicalHeader(f.Name), f.Value) - const headerFieldOverhead = 32 // per spec - sc.req.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead - if sc.req.headerListSize > int64(sc.maxHeaderListSize()) { - sc.hpackDecoder.SetEmitEnabled(false) - } - } -} - -func (st *http2stream) onNewTrailerField(f hpack.HeaderField) { - sc := st.sc - sc.serveG.check() - if http2VerboseLogs { - sc.vlogf("http2: server decoded trailer %v", f) - } - switch { - case strings.HasPrefix(f.Name, ":"): - sc.req.invalidHeader = true - return - case !http2validHeaderFieldName(f.Name) || !http2validHeaderFieldValue(f.Value): - sc.req.invalidHeader = true - return - default: - key := sc.canonicalHeader(f.Name) - if st.trailer != nil { - vv := append(st.trailer[key], f.Value) - st.trailer[key] = vv - - // arbitrary; TODO: read spec about header list size limits wrt trailers - const tooBig = 1000 - if len(vv) >= tooBig { - sc.hpackDecoder.SetEmitEnabled(false) - } - } - } -} - func (sc *http2serverConn) canonicalHeader(v string) string { sc.serveG.check() cv, ok := http2commonCanonHeader[v] @@ -3063,10 +3256,11 @@ type http2readFrameResult struct { // It's run on its own goroutine. func (sc *http2serverConn) readFrames() { gate := make(http2gate) + gateDone := gate.Done for { f, err := sc.framer.ReadFrame() select { - case sc.readFrameCh <- http2readFrameResult{f, err, gate.Done}: + case sc.readFrameCh <- http2readFrameResult{f, err, gateDone}: case <-sc.doneServing: return } @@ -3290,7 +3484,21 @@ func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error { // If you're not on the serve goroutine, use writeFrameFromHandler instead. func (sc *http2serverConn) writeFrame(wm http2frameWriteMsg) { sc.serveG.check() - sc.writeSched.add(wm) + + var ignoreWrite bool + + switch wm.write.(type) { + case *http2writeResHeaders: + wm.stream.wroteHeaders = true + case http2write100ContinueHeadersFrame: + if wm.stream.wroteHeaders { + ignoreWrite = true + } + } + + if !ignoreWrite { + sc.writeSched.add(wm) + } sc.scheduleFrameWrite() } @@ -3511,10 +3719,8 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { switch f := f.(type) { case *http2SettingsFrame: return sc.processSettings(f) - case *http2HeadersFrame: + case *http2MetaHeadersFrame: return sc.processHeaders(f) - case *http2ContinuationFrame: - return sc.processContinuation(f) case *http2WindowUpdateFrame: return sc.processWindowUpdate(f) case *http2PingFrame: @@ -3579,6 +3785,7 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { } if st != nil { st.gotReset = true + st.cancelCtx() sc.closeStream(st, http2StreamError{f.StreamID, f.ErrCode}) } return nil @@ -3600,6 +3807,10 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { } st.cw.Close() sc.writeSched.forgetStream(st.id) + if st.reqBuf != nil { + + sc.freeRequestBodyBuf = st.reqBuf + } } func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { @@ -3732,7 +3943,7 @@ func (st *http2stream) copyTrailersToHandlerRequest() { } } -func (sc *http2serverConn) processHeaders(f *http2HeadersFrame) error { +func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { sc.serveG.check() id := f.Header().StreamID if sc.inGoAway { @@ -3749,17 +3960,18 @@ func (sc *http2serverConn) processHeaders(f *http2HeadersFrame) error { return st.processTrailerHeaders(f) } - if id <= sc.maxStreamID || sc.req.stream != nil { + if id <= sc.maxStreamID { return http2ConnectionError(http2ErrCodeProtocol) } + sc.maxStreamID = id - if id > sc.maxStreamID { - sc.maxStreamID = id - } + ctx, cancelCtx := http2contextWithCancel(sc.baseCtx) st = &http2stream{ - sc: sc, - id: id, - state: http2stateOpen, + sc: sc, + id: id, + state: http2stateOpen, + ctx: ctx, + cancelCtx: cancelCtx, } if f.StreamEnded() { st.state = http2stateHalfClosedRemote @@ -3779,50 +3991,6 @@ func (sc *http2serverConn) processHeaders(f *http2HeadersFrame) error { if sc.curOpenStreams == 1 { sc.setConnState(StateActive) } - sc.req = http2requestParam{ - stream: st, - header: make(Header), - } - sc.hpackDecoder.SetEmitFunc(sc.onNewHeaderField) - sc.hpackDecoder.SetEmitEnabled(true) - return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded()) -} - -func (st *http2stream) processTrailerHeaders(f *http2HeadersFrame) error { - sc := st.sc - sc.serveG.check() - if st.gotTrailerHeader { - return http2ConnectionError(http2ErrCodeProtocol) - } - st.gotTrailerHeader = true - if !f.StreamEnded() { - return http2StreamError{st.id, http2ErrCodeProtocol} - } - sc.resetPendingRequest() - return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded()) -} - -func (sc *http2serverConn) processContinuation(f *http2ContinuationFrame) error { - sc.serveG.check() - st := sc.streams[f.Header().StreamID] - if st.gotTrailerHeader { - return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded()) - } - return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded()) -} - -func (sc *http2serverConn) processHeaderBlockFragment(st *http2stream, frag []byte, end bool) error { - sc.serveG.check() - if _, err := sc.hpackDecoder.Write(frag); err != nil { - return http2ConnectionError(http2ErrCodeCompression) - } - if !end { - return nil - } - if err := sc.hpackDecoder.Close(); err != nil { - return http2ConnectionError(http2ErrCodeCompression) - } - defer sc.resetPendingRequest() if sc.curOpenStreams > sc.advMaxStreams { if sc.unackedSettings == 0 { @@ -3833,7 +4001,7 @@ func (sc *http2serverConn) processHeaderBlockFragment(st *http2stream, frag []by return http2StreamError{st.id, http2ErrCodeRefusedStream} } - rw, req, err := sc.newWriterAndRequest() + rw, req, err := sc.newWriterAndRequest(st, f) if err != nil { return err } @@ -3845,36 +4013,42 @@ func (sc *http2serverConn) processHeaderBlockFragment(st *http2stream, frag []by st.declBodyBytes = req.ContentLength handler := sc.handler.ServeHTTP - if !sc.hpackDecoder.EmitEnabled() { + if f.Truncated { handler = http2handleHeaderListTooLong + } else if err := http2checkValidHTTP2Request(req); err != nil { + handler = http2new400Handler(err) } go sc.runHandler(rw, req, handler) return nil } -func (st *http2stream) processTrailerHeaderBlockFragment(frag []byte, end bool) error { +func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { sc := st.sc sc.serveG.check() - sc.hpackDecoder.SetEmitFunc(st.onNewTrailerField) - if _, err := sc.hpackDecoder.Write(frag); err != nil { - return http2ConnectionError(http2ErrCodeCompression) + if st.gotTrailerHeader { + return http2ConnectionError(http2ErrCodeProtocol) } - if !end { - return nil + st.gotTrailerHeader = true + if !f.StreamEnded() { + return http2StreamError{st.id, http2ErrCodeProtocol} } - rp := &sc.req - if rp.invalidHeader { - return http2StreamError{rp.stream.id, http2ErrCodeProtocol} + if len(f.PseudoFields()) > 0 { + return http2StreamError{st.id, http2ErrCodeProtocol} } + if st.trailer != nil { + for _, hf := range f.RegularFields() { + key := sc.canonicalHeader(hf.Name) + if !http2ValidTrailerHeader(key) { - err := sc.hpackDecoder.Close() - st.endStream() - if err != nil { - return http2ConnectionError(http2ErrCodeCompression) + return http2StreamError{st.id, http2ErrCodeProtocol} + } + st.trailer[key] = append(st.trailer[key], hf.Value) + } } + st.endStream() return nil } @@ -3912,59 +4086,56 @@ func http2adjustStreamPriority(streams map[uint32]*http2stream, streamID uint32, } } -// resetPendingRequest zeros out all state related to a HEADERS frame -// and its zero or more CONTINUATION frames sent to start a new -// request. -func (sc *http2serverConn) resetPendingRequest() { +func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *Request, error) { sc.serveG.check() - sc.req = http2requestParam{} -} -func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request, error) { - sc.serveG.check() - rp := &sc.req - - if rp.invalidHeader { - return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol} - } + method := f.PseudoValue("method") + path := f.PseudoValue("path") + scheme := f.PseudoValue("scheme") + authority := f.PseudoValue("authority") - isConnect := rp.method == "CONNECT" + isConnect := method == "CONNECT" if isConnect { - if rp.path != "" || rp.scheme != "" || rp.authority == "" { - return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol} + if path != "" || scheme != "" || authority == "" { + return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol} } - } else if rp.method == "" || rp.path == "" || - (rp.scheme != "https" && rp.scheme != "http") { + } else if method == "" || path == "" || + (scheme != "https" && scheme != "http") { - return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol} + return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol} } - bodyOpen := rp.stream.state == http2stateOpen - if rp.method == "HEAD" && bodyOpen { + bodyOpen := !f.StreamEnded() + if method == "HEAD" && bodyOpen { - return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol} + return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol} } var tlsState *tls.ConnectionState // nil if not scheme https - if rp.scheme == "https" { + if scheme == "https" { tlsState = sc.tlsState } - authority := rp.authority + + header := make(Header) + for _, hf := range f.RegularFields() { + header.Add(sc.canonicalHeader(hf.Name), hf.Value) + } + if authority == "" { - authority = rp.header.Get("Host") + authority = header.Get("Host") } - needsContinue := rp.header.Get("Expect") == "100-continue" + needsContinue := header.Get("Expect") == "100-continue" if needsContinue { - rp.header.Del("Expect") + header.Del("Expect") } - if cookies := rp.header["Cookie"]; len(cookies) > 1 { - rp.header.Set("Cookie", strings.Join(cookies, "; ")) + if cookies := header["Cookie"]; len(cookies) > 1 { + header.Set("Cookie", strings.Join(cookies, "; ")) } // Setup Trailers var trailer Header - for _, v := range rp.header["Trailer"] { + for _, v := range header["Trailer"] { for _, key := range strings.Split(v, ",") { key = CanonicalHeaderKey(strings.TrimSpace(key)) switch key { @@ -3978,31 +4149,31 @@ func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request } } } - delete(rp.header, "Trailer") + delete(header, "Trailer") body := &http2requestBody{ conn: sc, - stream: rp.stream, + stream: st, needsContinue: needsContinue, } var url_ *url.URL var requestURI string if isConnect { - url_ = &url.URL{Host: rp.authority} - requestURI = rp.authority + url_ = &url.URL{Host: authority} + requestURI = authority } else { var err error - url_, err = url.ParseRequestURI(rp.path) + url_, err = url.ParseRequestURI(path) if err != nil { - return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol} + return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol} } - requestURI = rp.path + requestURI = path } req := &Request{ - Method: rp.method, + Method: method, URL: url_, RemoteAddr: sc.remoteAddrStr, - Header: rp.header, + Header: header, RequestURI: requestURI, Proto: "HTTP/2.0", ProtoMajor: 2, @@ -4012,12 +4183,16 @@ func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request Body: body, Trailer: trailer, } + req = http2requestWithContext(req, st.ctx) if bodyOpen { + + buf := make([]byte, http2initialWindowSize) + body.pipe = &http2pipe{ - b: &http2fixedBuffer{buf: make([]byte, http2initialWindowSize)}, + b: &http2fixedBuffer{buf: buf}, } - if vv, ok := rp.header["Content-Length"]; ok { + if vv, ok := header["Content-Length"]; ok { req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) } else { req.ContentLength = -1 @@ -4030,7 +4205,7 @@ func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request rws.conn = sc rws.bw = bwSave rws.bw.Reset(http2chunkWriter{rws}) - rws.stream = rp.stream + rws.stream = st rws.req = req rws.body = body @@ -4038,10 +4213,20 @@ func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request return rw, req, nil } +func (sc *http2serverConn) getRequestBodyBuf() []byte { + sc.serveG.check() + if buf := sc.freeRequestBodyBuf; buf != nil { + sc.freeRequestBodyBuf = nil + return buf + } + return make([]byte, http2initialWindowSize) +} + // Run on its own goroutine. func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) { didPanic := true defer func() { + rw.rws.stream.cancelCtx() if didPanic { e := recover() // Same as net/http: @@ -4190,7 +4375,7 @@ type http2requestBody struct { func (b *http2requestBody) Close() error { if b.pipe != nil { - b.pipe.CloseWithError(http2errClosedBody) + b.pipe.BreakWithError(http2errClosedBody) } b.closed = true return nil @@ -4265,9 +4450,9 @@ func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailer // written in the trailers at the end of the response. func (rws *http2responseWriterState) declareTrailer(k string) { k = CanonicalHeaderKey(k) - switch k { - case "Transfer-Encoding", "Content-Length", "Trailer": + if !http2ValidTrailerHeader(k) { + rws.conn.logf("ignoring invalid trailer %q", k) return } if !http2strSliceContains(rws.trailers, k) { @@ -4408,7 +4593,12 @@ func (rws *http2responseWriterState) promoteUndeclaredTrailers() { rws.declareTrailer(trailerKey) rws.handlerHeader[CanonicalHeaderKey(trailerKey)] = vv } - sort.Strings(rws.trailers) + + if len(rws.trailers) > 1 { + sorter := http2sorterPool.Get().(*http2sorter) + sorter.SortStrings(rws.trailers) + http2sorterPool.Put(sorter) + } } func (w *http2responseWriter) Flush() { @@ -4552,6 +4742,72 @@ func http2foreachHeaderElement(v string, fn func(string)) { } } +// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2 +var http2connHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Connection", + "Transfer-Encoding", + "Upgrade", +} + +// checkValidHTTP2Request checks whether req is a valid HTTP/2 request, +// per RFC 7540 Section 8.1.2.2. +// The returned error is reported to users. +func http2checkValidHTTP2Request(req *Request) error { + for _, h := range http2connHeaders { + if _, ok := req.Header[h]; ok { + return fmt.Errorf("request header %q is not valid in HTTP/2", h) + } + } + te := req.Header["Te"] + if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { + return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) + } + return nil +} + +func http2new400Handler(err error) HandlerFunc { + return func(w ResponseWriter, r *Request) { + Error(w, err.Error(), StatusBadRequest) + } +} + +// ValidTrailerHeader reports whether name is a valid header field name to appear +// in trailers. +// See: http://tools.ietf.org/html/rfc7230#section-4.1.2 +func http2ValidTrailerHeader(name string) bool { + name = CanonicalHeaderKey(name) + if strings.HasPrefix(name, "If-") || http2badTrailer[name] { + return false + } + return true +} + +var http2badTrailer = map[string]bool{ + "Authorization": true, + "Cache-Control": true, + "Connection": true, + "Content-Encoding": true, + "Content-Length": true, + "Content-Range": true, + "Content-Type": true, + "Expect": true, + "Host": true, + "Keep-Alive": true, + "Max-Forwards": true, + "Pragma": true, + "Proxy-Authenticate": true, + "Proxy-Authorization": true, + "Proxy-Connection": true, + "Range": true, + "Realm": true, + "Te": true, + "Trailer": true, + "Transfer-Encoding": true, + "Www-Authenticate": true, +} + const ( // transportDefaultConnFlow is how many connection-level flow control // tokens we give the server at start-up, past the default 64k. @@ -4601,6 +4857,10 @@ type http2Transport struct { // uncompressed. DisableCompression bool + // AllowHTTP, if true, permits HTTP/2 requests using the insecure, + // plain-text "http" scheme. Note that this does not enable h2c support. + AllowHTTP bool + // MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to // send in the initial settings frame. It is how many bytes // of response headers are allow. Unlike the http2 spec, zero here @@ -4673,11 +4933,14 @@ type http2ClientConn struct { inflow http2flow // peer's conn-level flow control closed bool goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received + goAwayDebug string // goAway frame's debug data, retained as a string streams map[uint32]*http2clientStream // client-initiated nextStreamID uint32 bw *bufio.Writer br *bufio.Reader fr *http2Framer + lastActive time.Time + // Settings from peer: maxFrameSize uint32 maxConcurrentStreams uint32 @@ -4695,10 +4958,12 @@ type http2ClientConn struct { type http2clientStream struct { cc *http2ClientConn req *Request + trace *http2clientTrace // or nil ID uint32 resc chan http2resAndError bufPipe http2pipe // buffered pipe with the flow-controlled response payload requestedGzip bool + on100 func() // optional code to run if get a 100 continue response flow http2flow // guarded by cc.mu inflow http2flow // guarded by cc.mu @@ -4712,36 +4977,43 @@ type http2clientStream struct { done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu // owned by clientConnReadLoop: - pastHeaders bool // got HEADERS w/ END_HEADERS - pastTrailers bool // got second HEADERS frame w/ END_HEADERS + firstByte bool // got the first response byte + pastHeaders bool // got first MetaHeadersFrame (actual headers) + pastTrailers bool // got optional second MetaHeadersFrame (trailers) trailer Header // accumulated trailers resTrailer *Header // client's Response.Trailer } // awaitRequestCancel runs in its own goroutine and waits for the user -// to either cancel a RoundTrip request (using the provided -// Request.Cancel channel), or for the request to be done (any way it -// might be removed from the cc.streams map: peer reset, successful -// completion, TCP connection breakage, etc) -func (cs *http2clientStream) awaitRequestCancel(cancel <-chan struct{}) { - if cancel == nil { +// to cancel a RoundTrip request, its context to expire, or for the +// request to be done (any way it might be removed from the cc.streams +// map: peer reset, successful completion, TCP connection breakage, +// etc) +func (cs *http2clientStream) awaitRequestCancel(req *Request) { + ctx := http2reqContext(req) + if req.Cancel == nil && ctx.Done() == nil { return } select { - case <-cancel: + case <-req.Cancel: cs.bufPipe.CloseWithError(http2errRequestCanceled) cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + case <-ctx.Done(): + cs.bufPipe.CloseWithError(ctx.Err()) + cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) case <-cs.done: } } -// checkReset reports any error sent in a RST_STREAM frame by the -// server. -func (cs *http2clientStream) checkReset() error { +// checkResetOrDone reports any error sent in a RST_STREAM frame by the +// server, or errStreamClosed if the stream is complete. +func (cs *http2clientStream) checkResetOrDone() error { select { case <-cs.peerReset: return cs.resetErr + case <-cs.done: + return http2errStreamClosed default: return nil } @@ -4789,26 +5061,31 @@ func (t *http2Transport) RoundTrip(req *Request) (*Response, error) { // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) // and returns a host:port. The port 443 is added if needed. -func http2authorityAddr(authority string) (addr string) { +func http2authorityAddr(scheme string, authority string) (addr string) { if _, _, err := net.SplitHostPort(authority); err == nil { return authority } - return net.JoinHostPort(authority, "443") + port := "443" + if scheme == "http" { + port = "80" + } + return net.JoinHostPort(authority, port) } // RoundTripOpt is like RoundTrip, but takes options. func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) { - if req.URL.Scheme != "https" { + if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { return nil, errors.New("http2: unsupported scheme") } - addr := http2authorityAddr(req.URL.Host) + addr := http2authorityAddr(req.URL.Scheme, req.URL.Host) for { cc, err := t.connPool().GetClientConn(req, addr) if err != nil { t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err) return nil, err } + http2traceGotConn(req, cc) res, err := cc.RoundTrip(req) if http2shouldRetryRequest(req, err) { continue @@ -4825,7 +5102,7 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res // connected from previous requests but are now sitting idle. // It does not interrupt any connections currently in use. func (t *http2Transport) CloseIdleConnections() { - if cp, ok := t.connPool().(*http2clientConnPool); ok { + if cp, ok := t.connPool().(http2clientConnPoolIdleCloser); ok { cp.closeIdleConnections() } } @@ -4857,8 +5134,12 @@ func (t *http2Transport) newTLSConfig(host string) *tls.Config { if t.TLSClientConfig != nil { *cfg = *t.TLSClientConfig } - cfg.NextProtos = []string{http2NextProtoTLS} - cfg.ServerName = host + if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { + cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...) + } + if cfg.ServerName == "" { + cfg.ServerName = host + } return cfg } @@ -4898,6 +5179,13 @@ func (t *http2Transport) disableKeepAlives() bool { return t.t1 != nil && t.t1.DisableKeepAlives } +func (t *http2Transport) expectContinueTimeout() time.Duration { + if t.t1 == nil { + return 0 + } + return http2transportExpectContinueTimeout(t.t1) +} + func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { if http2VerboseLogs { t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr()) @@ -4923,6 +5211,8 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr}) cc.br = bufio.NewReader(c) cc.fr = http2NewFramer(cc.bw, cc.br) + cc.fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + cc.fr.MaxHeaderListSize = t.maxHeaderListSize() cc.henc = hpack.NewEncoder(&cc.hbuf) @@ -4932,8 +5222,8 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { } initialSettings := []http2Setting{ - http2Setting{ID: http2SettingEnablePush, Val: 0}, - http2Setting{ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, + {ID: http2SettingEnablePush, Val: 0}, + {ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, } if max := t.maxHeaderListSize(); max != 0 { initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max}) @@ -4979,7 +5269,16 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { cc.mu.Lock() defer cc.mu.Unlock() + + old := cc.goAway cc.goAway = f + + if cc.goAwayDebug == "" { + cc.goAwayDebug = string(f.DebugData()) + } + if old != nil && old.ErrCode != http2ErrCodeNo { + cc.goAway.ErrCode = old.ErrCode + } } func (cc *http2ClientConn) CanTakeNewRequest() bool { @@ -5093,6 +5392,30 @@ func http2checkConnHeaders(req *Request) error { return nil } +func http2bodyAndLength(req *Request) (body io.Reader, contentLen int64) { + body = req.Body + if body == nil { + return nil, 0 + } + if req.ContentLength != 0 { + return req.Body, req.ContentLength + } + + // We have a body but a zero content length. Test to see if + // it's actually zero or just unset. + var buf [1]byte + n, rerr := io.ReadFull(body, buf[:]) + if rerr != nil && rerr != io.EOF { + return http2errorReader{rerr}, -1 + } + if n == 1 { + + return io.MultiReader(bytes.NewReader(buf[:]), body), -1 + } + + return nil, 0 +} + func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { if err := http2checkConnHeaders(req); err != nil { return nil, err @@ -5104,67 +5427,62 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { } hasTrailers := trailers != "" - var body io.Reader = req.Body - contentLen := req.ContentLength - if req.Body != nil && contentLen == 0 { - // Test to see if it's actually zero or just unset. - var buf [1]byte - n, rerr := io.ReadFull(body, buf[:]) - if rerr != nil && rerr != io.EOF { - contentLen = -1 - body = http2errorReader{rerr} - } else if n == 1 { - - contentLen = -1 - body = io.MultiReader(bytes.NewReader(buf[:]), body) - } else { - - body = nil - } - } + body, contentLen := http2bodyAndLength(req) + hasBody := body != nil cc.mu.Lock() + cc.lastActive = time.Now() if cc.closed || !cc.canTakeNewRequestLocked() { cc.mu.Unlock() return nil, http2errClientConnUnusable } - cs := cc.newStream() - cs.req = req - hasBody := body != nil - + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? + var requestedGzip bool if !cc.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { - cs.requestedGzip = true + requestedGzip = true } - hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen) + hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen) + if err != nil { + cc.mu.Unlock() + return nil, err + } + + cs := cc.newStream() + cs.req = req + cs.trace = http2requestTrace(req) + cs.requestedGzip = requestedGzip + bodyWriter := cc.t.getBodyWriterState(cs, body) + cs.on100 = bodyWriter.on100 + cc.wmu.Lock() endStream := !hasBody && !hasTrailers werr := cc.writeHeaders(cs.ID, endStream, hdrs) cc.wmu.Unlock() + http2traceWroteHeaders(cs.trace) cc.mu.Unlock() if werr != nil { if hasBody { req.Body.Close() + bodyWriter.cancel() } cc.forgetStreamID(cs.ID) + http2traceWroteRequest(cs.trace, werr) return nil, werr } var respHeaderTimer <-chan time.Time - var bodyCopyErrc chan error // result of body copy if hasBody { - bodyCopyErrc = make(chan error, 1) - go func() { - bodyCopyErrc <- cs.writeRequestBody(body, req.Body) - }() + bodyWriter.scheduleBodyWrite() } else { + http2traceWroteRequest(cs.trace, nil) if d := cc.responseHeaderTimeout(); d != 0 { timer := time.NewTimer(d) defer timer.Stop() @@ -5173,44 +5491,78 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { } readLoopResCh := cs.resc - requestCanceledCh := http2requestCancel(req) bodyWritten := false + ctx := http2reqContext(req) + + reFunc := func(re http2resAndError) (*Response, error) { + res := re.res + if re.err != nil || res.StatusCode > 299 { + bodyWriter.cancel() + cs.abortRequestBodyWrite(http2errStopReqBodyWrite) + } + if re.err != nil { + cc.forgetStreamID(cs.ID) + return nil, re.err + } + res.Request = req + res.TLS = cc.tlsState + return res, nil + } for { select { case re := <-readLoopResCh: - res := re.res - if re.err != nil || res.StatusCode > 299 { - - cs.abortRequestBodyWrite(http2errStopReqBodyWrite) - } - if re.err != nil { - cc.forgetStreamID(cs.ID) - return nil, re.err - } - res.Request = req - res.TLS = cc.tlsState - return res, nil + return reFunc(re) case <-respHeaderTimer: cc.forgetStreamID(cs.ID) if !hasBody || bodyWritten { cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) } else { + bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) } return nil, http2errTimeout - case <-requestCanceledCh: + case <-ctx.Done(): + select { + case re := <-readLoopResCh: + return reFunc(re) + default: + } + cc.forgetStreamID(cs.ID) + if !hasBody || bodyWritten { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + } else { + bodyWriter.cancel() + cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) + } + return nil, ctx.Err() + case <-req.Cancel: + select { + case re := <-readLoopResCh: + return reFunc(re) + default: + } cc.forgetStreamID(cs.ID) if !hasBody || bodyWritten { cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) } else { + bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) } return nil, http2errRequestCanceled case <-cs.peerReset: - + select { + case re := <-readLoopResCh: + return reFunc(re) + default: + } return nil, cs.resetErr - case err := <-bodyCopyErrc: + case err := <-bodyWriter.resc: + select { + case re := <-readLoopResCh: + return reFunc(re) + default: + } if err != nil { return nil, err } @@ -5268,6 +5620,7 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos defer cc.putFrameScratchBuffer(buf) defer func() { + http2traceWroteRequest(cs.trace, err) cerr := bodyCloser.Close() if err == nil { @@ -5355,7 +5708,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er if cs.stopReqBody != nil { return 0, cs.stopReqBody } - if err := cs.checkReset(); err != nil { + if err := cs.checkResetOrDone(); err != nil { return 0, err } if a := cs.flow.available(); a > 0 { @@ -5382,7 +5735,7 @@ type http2badStringError struct { func (e *http2badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) } // requires cc.mu be held. -func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) []byte { +func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { cc.hbuf.Reset() host := req.Host @@ -5390,6 +5743,17 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail host = req.URL.Host } + for k, vv := range req.Header { + if !httplex.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("invalid HTTP header name %q", k) + } + for _, v := range vv { + if !httplex.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) + } + } + } + cc.writeHeader(":authority", host) cc.writeHeader(":method", req.Method) if req.Method != "CONNECT" { @@ -5407,7 +5771,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail case "host", "content-length": continue - case "connection", "proxy-connection", "transfer-encoding", "upgrade": + case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive": continue case "user-agent": @@ -5434,7 +5798,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail if !didUA { cc.writeHeader("user-agent", http2defaultUserAgent) } - return cc.hbuf.Bytes() + return cc.hbuf.Bytes(), nil } // shouldSendReqContentLength reports whether the http2.Transport should send @@ -5510,8 +5874,10 @@ func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStr defer cc.mu.Unlock() cs := cc.streams[id] if andRemove && cs != nil && !cc.closed { + cc.lastActive = time.Now() delete(cc.streams, id) close(cs.done) + cc.cond.Broadcast() } return cs } @@ -5521,15 +5887,6 @@ type http2clientConnReadLoop struct { cc *http2ClientConn activeRes map[uint32]*http2clientStream // keyed by streamID closeWhenIdle bool - - hdec *hpack.Decoder - - // Fields reset on each HEADERS: - nextRes *Response - sawRegHeader bool // saw non-pseudo header - reqMalformed error // non-nil once known to be malformed - lastHeaderEndsStream bool - headerListSize int64 // actually uint32, but easier math this way } // readLoop runs in its own goroutine and reads and dispatches frames. @@ -5538,7 +5895,6 @@ func (cc *http2ClientConn) readLoop() { cc: cc, activeRes: make(map[uint32]*http2clientStream), } - rl.hdec = hpack.NewDecoder(http2initialHeaderTableSize, rl.onNewHeaderField) defer rl.cleanup() cc.readerErr = rl.run() @@ -5549,6 +5905,19 @@ func (cc *http2ClientConn) readLoop() { } } +// GoAwayError is returned by the Transport when the server closes the +// TCP connection after sending a GOAWAY frame. +type http2GoAwayError struct { + LastStreamID uint32 + ErrCode http2ErrCode + DebugData string +} + +func (e http2GoAwayError) Error() string { + return fmt.Sprintf("http2: server sent GOAWAY and closed the connection; LastStreamID=%v, ErrCode=%v, debug=%q", + e.LastStreamID, e.ErrCode, e.DebugData) +} + func (rl *http2clientConnReadLoop) cleanup() { cc := rl.cc defer cc.tconn.Close() @@ -5556,10 +5925,18 @@ func (rl *http2clientConnReadLoop) cleanup() { defer close(cc.readerDone) err := cc.readerErr + cc.mu.Lock() if err == io.EOF { - err = io.ErrUnexpectedEOF + if cc.goAway != nil { + err = http2GoAwayError{ + LastStreamID: cc.goAway.LastStreamID, + ErrCode: cc.goAway.ErrCode, + DebugData: cc.goAwayDebug, + } + } else { + err = io.ErrUnexpectedEOF + } } - cc.mu.Lock() for _, cs := range rl.activeRes { cs.bufPipe.CloseWithError(err) } @@ -5585,8 +5962,10 @@ func (rl *http2clientConnReadLoop) run() error { cc.vlogf("Transport readFrame error: (%T) %v", err, err) } if se, ok := err.(http2StreamError); ok { - - return se + if cs := cc.streamByID(se.StreamID, true); cs != nil { + rl.endStreamError(cs, cc.fr.errDetail) + } + continue } else if err != nil { return err } @@ -5596,13 +5975,10 @@ func (rl *http2clientConnReadLoop) run() error { maybeIdle := false switch f := f.(type) { - case *http2HeadersFrame: + case *http2MetaHeadersFrame: err = rl.processHeaders(f) maybeIdle = true gotReply = true - case *http2ContinuationFrame: - err = rl.processContinuation(f) - maybeIdle = true case *http2DataFrame: err = rl.processData(f) maybeIdle = true @@ -5632,83 +6008,105 @@ func (rl *http2clientConnReadLoop) run() error { } } -func (rl *http2clientConnReadLoop) processHeaders(f *http2HeadersFrame) error { - rl.sawRegHeader = false - rl.reqMalformed = nil - rl.lastHeaderEndsStream = f.StreamEnded() - rl.headerListSize = 0 - rl.nextRes = &Response{ - Proto: "HTTP/2.0", - ProtoMajor: 2, - Header: make(Header), - } - rl.hdec.SetEmitEnabled(true) - return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded()) -} - -func (rl *http2clientConnReadLoop) processContinuation(f *http2ContinuationFrame) error { - return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded()) -} - -func (rl *http2clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID uint32, finalFrag bool) error { +func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error { cc := rl.cc - streamEnded := rl.lastHeaderEndsStream - cs := cc.streamByID(streamID, streamEnded && finalFrag) + cs := cc.streamByID(f.StreamID, f.StreamEnded()) if cs == nil { return nil } - if cs.pastHeaders { - rl.hdec.SetEmitFunc(func(f hpack.HeaderField) { rl.onNewTrailerField(cs, f) }) - } else { - rl.hdec.SetEmitFunc(rl.onNewHeaderField) - } - _, err := rl.hdec.Write(frag) - if err != nil { - return http2ConnectionError(http2ErrCodeCompression) - } - if finalFrag { - if err := rl.hdec.Close(); err != nil { - return http2ConnectionError(http2ErrCodeCompression) - } - } + if !cs.firstByte { + if cs.trace != nil { - if !finalFrag { - return nil + http2traceFirstResponseByte(cs.trace) + } + cs.firstByte = true } - if !cs.pastHeaders { cs.pastHeaders = true } else { + return rl.processTrailers(cs, f) + } - if cs.pastTrailers { - - return http2ConnectionError(http2ErrCodeProtocol) + res, err := rl.handleResponse(cs, f) + if err != nil { + if _, ok := err.(http2ConnectionError); ok { + return err } - cs.pastTrailers = true - if !streamEnded { - return http2ConnectionError(http2ErrCodeProtocol) - } - rl.endStream(cs) + cs.cc.writeStreamReset(f.StreamID, http2ErrCodeProtocol, err) + cs.resc <- http2resAndError{err: err} return nil } + if res == nil { - if rl.reqMalformed != nil { - cs.resc <- http2resAndError{err: rl.reqMalformed} - rl.cc.writeStreamReset(cs.ID, http2ErrCodeProtocol, rl.reqMalformed) return nil } + if res.Body != http2noBody { + rl.activeRes[cs.ID] = cs + } + cs.resTrailer = &res.Trailer + cs.resc <- http2resAndError{res: res} + return nil +} - res := rl.nextRes +// may return error types nil, or ConnectionError. Any other error value +// is a StreamError of type ErrCodeProtocol. The returned error in that case +// is the detail. +// +// As a special case, handleResponse may return (nil, nil) to skip the +// frame (currently only used for 100 expect continue). This special +// case is going away after Issue 13851 is fixed. +func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http2MetaHeadersFrame) (*Response, error) { + if f.Truncated { + return nil, http2errResponseHeaderListSize + } - if res.StatusCode == 100 { + status := f.PseudoValue("status") + if status == "" { + return nil, errors.New("missing status pseudo header") + } + statusCode, err := strconv.Atoi(status) + if err != nil { + return nil, errors.New("malformed non-numeric status pseudo header") + } + if statusCode == 100 { + http2traceGot100Continue(cs.trace) + if cs.on100 != nil { + cs.on100() + } cs.pastHeaders = false - return nil + return nil, nil + } + + header := make(Header) + res := &Response{ + Proto: "HTTP/2.0", + ProtoMajor: 2, + Header: header, + StatusCode: statusCode, + Status: status + " " + StatusText(statusCode), + } + for _, hf := range f.RegularFields() { + key := CanonicalHeaderKey(hf.Name) + if key == "Trailer" { + t := res.Trailer + if t == nil { + t = make(Header) + res.Trailer = t + } + http2foreachHeaderElement(hf.Value, func(v string) { + t[CanonicalHeaderKey(v)] = nil + }) + } else { + header[key] = append(header[key], hf.Value) + } } - if !streamEnded || cs.req.Method == "HEAD" { + streamEnded := f.StreamEnded() + isHead := cs.req.Method == "HEAD" + if !streamEnded || isHead { res.ContentLength = -1 if clens := res.Header["Content-Length"]; len(clens) == 1 { if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { @@ -5721,27 +6119,50 @@ func (rl *http2clientConnReadLoop) processHeaderBlockFragment(frag []byte, strea } } - if streamEnded { + if streamEnded || isHead { res.Body = http2noBody - } else { - buf := new(bytes.Buffer) - cs.bufPipe = http2pipe{b: buf} - cs.bytesRemain = res.ContentLength - res.Body = http2transportResponseBody{cs} - go cs.awaitRequestCancel(http2requestCancel(cs.req)) - - if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { - res.Header.Del("Content-Encoding") - res.Header.Del("Content-Length") - res.ContentLength = -1 - res.Body = &http2gzipReader{body: res.Body} - } - rl.activeRes[cs.ID] = cs + return res, nil } - cs.resTrailer = &res.Trailer - cs.resc <- http2resAndError{res: res} - rl.nextRes = nil + buf := new(bytes.Buffer) + cs.bufPipe = http2pipe{b: buf} + cs.bytesRemain = res.ContentLength + res.Body = http2transportResponseBody{cs} + go cs.awaitRequestCancel(cs.req) + + if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = &http2gzipReader{body: res.Body} + http2setResponseUncompressed(res) + } + return res, nil +} + +func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *http2MetaHeadersFrame) error { + if cs.pastTrailers { + + return http2ConnectionError(http2ErrCodeProtocol) + } + cs.pastTrailers = true + if !f.StreamEnded() { + + return http2ConnectionError(http2ErrCodeProtocol) + } + if len(f.PseudoFields()) > 0 { + + return http2ConnectionError(http2ErrCodeProtocol) + } + + trailer := make(Header) + for _, hf := range f.RegularFields() { + key := CanonicalHeaderKey(hf.Name) + trailer[key] = append(trailer[key], hf.Value) + } + cs.trailer = trailer + + rl.endStream(cs) return nil } @@ -5792,8 +6213,10 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { cc.inflow.add(connAdd) } if err == nil { - if v := cs.inflow.available(); v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { - streamAdd = http2transportDefaultStreamFlow - v + + v := int(cs.inflow.available()) + cs.bufPipe.Len() + if v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { + streamAdd = int32(http2transportDefaultStreamFlow - v) cs.inflow.add(streamAdd) } } @@ -5855,6 +6278,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { cc.mu.Unlock() if _, err := cs.bufPipe.Write(data); err != nil { + rl.endStreamError(cs, err) return err } } @@ -5869,11 +6293,14 @@ var http2errInvalidTrailers = errors.New("http2: invalid trailers") func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { - err := io.EOF - code := cs.copyTrailers - if rl.reqMalformed != nil { - err = rl.reqMalformed - code = nil + rl.endStreamError(cs, nil) +} + +func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err error) { + var code func() + if err == nil { + err = io.EOF + code = cs.copyTrailers } cs.bufPipe.closeWithErrorAndCode(err, code) delete(rl.activeRes, cs.ID) @@ -5997,113 +6424,6 @@ var ( http2errPseudoTrailers = errors.New("http2: invalid pseudo header in trailers") ) -func (rl *http2clientConnReadLoop) checkHeaderField(f hpack.HeaderField) bool { - if rl.reqMalformed != nil { - return false - } - - const headerFieldOverhead = 32 // per spec - rl.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead - if max := rl.cc.t.maxHeaderListSize(); max != 0 && rl.headerListSize > int64(max) { - rl.hdec.SetEmitEnabled(false) - rl.reqMalformed = http2errResponseHeaderListSize - return false - } - - if !http2validHeaderFieldValue(f.Value) { - rl.reqMalformed = http2errInvalidHeaderFieldValue - return false - } - - isPseudo := strings.HasPrefix(f.Name, ":") - if isPseudo { - if rl.sawRegHeader { - rl.reqMalformed = errors.New("http2: invalid pseudo header after regular header") - return false - } - } else { - if !http2validHeaderFieldName(f.Name) { - rl.reqMalformed = http2errInvalidHeaderFieldName - return false - } - rl.sawRegHeader = true - } - - return true -} - -// onNewHeaderField runs on the readLoop goroutine whenever a new -// hpack header field is decoded. -func (rl *http2clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) { - cc := rl.cc - if http2VerboseLogs { - cc.logf("http2: Transport decoded %v", f) - } - - if !rl.checkHeaderField(f) { - return - } - - isPseudo := strings.HasPrefix(f.Name, ":") - if isPseudo { - switch f.Name { - case ":status": - code, err := strconv.Atoi(f.Value) - if err != nil { - rl.reqMalformed = errors.New("http2: invalid :status") - return - } - rl.nextRes.Status = f.Value + " " + StatusText(code) - rl.nextRes.StatusCode = code - default: - - rl.reqMalformed = fmt.Errorf("http2: unknown response pseudo header %q", f.Name) - } - return - } - - key := CanonicalHeaderKey(f.Name) - if key == "Trailer" { - t := rl.nextRes.Trailer - if t == nil { - t = make(Header) - rl.nextRes.Trailer = t - } - http2foreachHeaderElement(f.Value, func(v string) { - t[CanonicalHeaderKey(v)] = nil - }) - } else { - rl.nextRes.Header.Add(key, f.Value) - } -} - -func (rl *http2clientConnReadLoop) onNewTrailerField(cs *http2clientStream, f hpack.HeaderField) { - if http2VerboseLogs { - rl.cc.logf("http2: Transport decoded trailer %v", f) - } - if !rl.checkHeaderField(f) { - return - } - if strings.HasPrefix(f.Name, ":") { - - rl.reqMalformed = http2errPseudoTrailers - return - } - - key := CanonicalHeaderKey(f.Name) - - // The spec says one must predeclare their trailers but in practice - // popular users (which is to say the only user we found) do not so we - // violate the spec and accept all of them. - const acceptAllTrailers = true - if _, ok := (*cs.resTrailer)[key]; ok || acceptAllTrailers { - if cs.trailer == nil { - cs.trailer = make(Header) - } - cs.trailer[key] = append(cs.trailer[key], f.Value) - } -} - func (cc *http2ClientConn) logf(format string, args ...interface{}) { cc.t.logf(format, args...) } @@ -6167,6 +6487,79 @@ type http2errorReader struct{ err error } func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err } +// bodyWriterState encapsulates various state around the Transport's writing +// of the request body, particularly regarding doing delayed writes of the body +// when the request contains "Expect: 100-continue". +type http2bodyWriterState struct { + cs *http2clientStream + timer *time.Timer // if non-nil, we're doing a delayed write + fnonce *sync.Once // to call fn with + fn func() // the code to run in the goroutine, writing the body + resc chan error // result of fn's execution + delay time.Duration // how long we should delay a delayed write for +} + +func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reader) (s http2bodyWriterState) { + s.cs = cs + if body == nil { + return + } + resc := make(chan error, 1) + s.resc = resc + s.fn = func() { + resc <- cs.writeRequestBody(body, cs.req.Body) + } + s.delay = t.expectContinueTimeout() + if s.delay == 0 || + !httplex.HeaderValuesContainsToken( + cs.req.Header["Expect"], + "100-continue") { + return + } + s.fnonce = new(sync.Once) + + // Arm the timer with a very large duration, which we'll + // intentionally lower later. It has to be large now because + // we need a handle to it before writing the headers, but the + // s.delay value is defined to not start until after the + // request headers were written. + const hugeDuration = 365 * 24 * time.Hour + s.timer = time.AfterFunc(hugeDuration, func() { + s.fnonce.Do(s.fn) + }) + return +} + +func (s http2bodyWriterState) cancel() { + if s.timer != nil { + s.timer.Stop() + } +} + +func (s http2bodyWriterState) on100() { + if s.timer == nil { + + return + } + s.timer.Stop() + go func() { s.fnonce.Do(s.fn) }() +} + +// scheduleBodyWrite starts writing the body, either immediately (in +// the common case) or after the delay timeout. It should not be +// called until after the headers have been written. +func (s http2bodyWriterState) scheduleBodyWrite() { + if s.timer == nil { + + go s.fn() + return + } + http2traceWait100Continue(s.cs.trace) + if s.timer.Stop() { + s.timer.Reset(s.delay) + } +} + // writeFramer is implemented by any type that is used to write frames. type http2writeFramer interface { writeFrame(http2writeContext) error @@ -6380,24 +6773,22 @@ func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { } func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { - if keys == nil { - keys = make([]string, 0, len(h)) - for k := range h { - keys = append(keys, k) - } - sort.Strings(keys) + sorter := http2sorterPool.Get().(*http2sorter) + + defer http2sorterPool.Put(sorter) + keys = sorter.Keys(h) } for _, k := range keys { vv := h[k] k = http2lowerHeader(k) - if !http2validHeaderFieldName(k) { + if !http2validWireHeaderFieldName(k) { continue } isTE := k == "transfer-encoding" for _, v := range vv { - if !http2validHeaderFieldValue(v) { + if !httplex.ValidHeaderFieldValue(v) { continue } |