diff options
Diffstat (limited to 'libgo/go/net/http')
42 files changed, 11652 insertions, 1172 deletions
diff --git a/libgo/go/net/http/cgi/host.go b/libgo/go/net/http/cgi/host.go index 4efbe7abeec..9b4d8754183 100644 --- a/libgo/go/net/http/cgi/host.go +++ b/libgo/go/net/http/cgi/host.go @@ -77,15 +77,15 @@ type Handler struct { // Env: []string{"SCRIPT_FILENAME=foo.php"}, // } func removeLeadingDuplicates(env []string) (ret []string) { - n := len(env) - for i := 0; i < n; i++ { - e := env[i] - s := strings.SplitN(e, "=", 2)[0] + for i, e := range env { found := false - for j := i + 1; j < n; j++ { - if s == strings.SplitN(env[j], "=", 2)[0] { - found = true - break + if eq := strings.IndexByte(e, '='); eq != -1 { + keq := e[:eq+1] // "key=" + for _, e2 := range env[i+1:] { + if strings.HasPrefix(e2, keq) { + found = true + break + } } } if !found { @@ -159,10 +159,6 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { env = append(env, "CONTENT_TYPE="+ctype) } - if h.Env != nil { - env = append(env, h.Env...) - } - envPath := os.Getenv("PATH") if envPath == "" { envPath = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin" @@ -181,6 +177,10 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } + if h.Env != nil { + env = append(env, h.Env...) 
+ } + env = removeLeadingDuplicates(env) var cwd, path string diff --git a/libgo/go/net/http/cgi/host_test.go b/libgo/go/net/http/cgi/host_test.go index f3411105ca9..fb7d66adb9f 100644 --- a/libgo/go/net/http/cgi/host_test.go +++ b/libgo/go/net/http/cgi/host_test.go @@ -16,6 +16,7 @@ import ( "os" "os/exec" "path/filepath" + "reflect" "runtime" "strconv" "strings" @@ -488,12 +489,36 @@ func TestEnvOverride(t *testing.T) { Args: []string{cgifile}, Env: []string{ "SCRIPT_FILENAME=" + cgifile, - "REQUEST_URI=/foo/bar"}, + "REQUEST_URI=/foo/bar", + "PATH=/wibble"}, } expectedMap := map[string]string{ "cwd": cwd, "env-SCRIPT_FILENAME": cgifile, "env-REQUEST_URI": "/foo/bar", + "env-PATH": "/wibble", } runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) } + +func TestRemoveLeadingDuplicates(t *testing.T) { + tests := []struct { + env []string + want []string + }{ + { + env: []string{"a=b", "b=c", "a=b2"}, + want: []string{"b=c", "a=b2"}, + }, + { + env: []string{"a=b", "b=c", "d", "e=f"}, + want: []string{"a=b", "b=c", "d", "e=f"}, + }, + } + for _, tt := range tests { + got := removeLeadingDuplicates(tt.env) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("removeLeadingDuplicates(%q) = %q; want %q", tt.env, got, tt.want) + } + } +} diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 7f2fbb4678e..3106d229da6 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -10,6 +10,7 @@ package http import ( + "crypto/tls" "encoding/base64" "errors" "fmt" @@ -19,7 +20,6 @@ import ( "net/url" "strings" "sync" - "sync/atomic" "time" ) @@ -65,10 +65,15 @@ type Client struct { // // A Timeout of zero means no timeout. // - // The Client's Transport must support the CancelRequest - // method or Client will return errors when attempting to make - // a request with Get, Head, Post, or Do. Client's default - // Transport (DefaultTransport) supports CancelRequest. 
+ // The Client cancels requests to the underlying Transport + // using the Request.Cancel mechanism. Requests passed + // to Client.Do may still set Request.Cancel; both will + // cancel the request. + // + // For compatibility, the Client will also use the deprecated + // CancelRequest method on Transport if found. New + // RoundTripper implementations should use Request.Cancel + // instead of implementing CancelRequest. Timeout time.Duration } @@ -82,19 +87,26 @@ var DefaultClient = &Client{} // goroutines. type RoundTripper interface { // RoundTrip executes a single HTTP transaction, returning - // the Response for the request req. RoundTrip should not - // attempt to interpret the response. In particular, - // RoundTrip must return err == nil if it obtained a response, - // regardless of the response's HTTP status code. A non-nil - // err should be reserved for failure to obtain a response. - // Similarly, RoundTrip should not attempt to handle - // higher-level protocol details such as redirects, + // a Response for the provided Request. + // + // RoundTrip should not attempt to interpret the response. In + // particular, RoundTrip must return err == nil if it obtained + // a response, regardless of the response's HTTP status code. + // A non-nil err should be reserved for failure to obtain a + // response. Similarly, RoundTrip should not attempt to + // handle higher-level protocol details such as redirects, // authentication, or cookies. // // RoundTrip should not modify the request, except for - // consuming and closing the Body, including on errors. The - // request's URL and Header fields are guaranteed to be - // initialized. + // consuming and closing the Request's Body. + // + // RoundTrip must always close the body, including on errors, + // but depending on the implementation may do so in a separate + // goroutine even after RoundTrip returns. 
This means that + // callers wanting to reuse the body for subsequent requests + // must arrange to wait for the Close call before doing so. + // + // The Request's URL and Header fields must be initialized. RoundTrip(*Request) (*Response, error) } @@ -134,13 +146,13 @@ type readClose struct { io.Closer } -func (c *Client) send(req *Request) (*Response, error) { +func (c *Client) send(req *Request, deadline time.Time) (*Response, error) { if c.Jar != nil { for _, cookie := range c.Jar.Cookies(req.URL) { req.AddCookie(cookie) } } - resp, err := send(req, c.transport()) + resp, err := send(req, c.transport(), deadline) if err != nil { return nil, err } @@ -171,13 +183,21 @@ func (c *Client) send(req *Request) (*Response, error) { // // Generally Get, Post, or PostForm will be used instead of Do. func (c *Client) Do(req *Request) (resp *Response, err error) { - if req.Method == "GET" || req.Method == "HEAD" { + method := valueOrDefault(req.Method, "GET") + if method == "GET" || method == "HEAD" { return c.doFollowingRedirects(req, shouldRedirectGet) } - if req.Method == "POST" || req.Method == "PUT" { + if method == "POST" || method == "PUT" { return c.doFollowingRedirects(req, shouldRedirectPost) } - return c.send(req) + return c.send(req, c.deadline()) +} + +func (c *Client) deadline() time.Time { + if c.Timeout > 0 { + return time.Now().Add(c.Timeout) + } + return time.Time{} } func (c *Client) transport() RoundTripper { @@ -189,8 +209,10 @@ func (c *Client) transport() RoundTripper { // send issues an HTTP request. // Caller should close resp.Body when done reading from it. 
-func send(req *Request, t RoundTripper) (resp *Response, err error) { - if t == nil { +func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error) { + req := ireq // req is either the original request, or a modified fork + + if rt == nil { req.closeBody() return nil, errors.New("http: no Client.Transport or DefaultTransport") } @@ -205,28 +227,122 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) { return nil, errors.New("http: Request.RequestURI can't be set in client requests.") } + // forkReq forks req into a shallow clone of ireq the first + // time it's called. + forkReq := func() { + if ireq == req { + req = new(Request) + *req = *ireq // shallow clone + } + } + // Most the callers of send (Get, Post, et al) don't need // Headers, leaving it uninitialized. We guarantee to the // Transport that this has been initialized, though. if req.Header == nil { + forkReq() req.Header = make(Header) } if u := req.URL.User; u != nil && req.Header.Get("Authorization") == "" { username := u.Username() password, _ := u.Password() + forkReq() + req.Header = cloneHeader(ireq.Header) req.Header.Set("Authorization", "Basic "+basicAuth(username, password)) } - resp, err = t.RoundTrip(req) + + if !deadline.IsZero() { + forkReq() + } + stopTimer, wasCanceled := setRequestCancel(req, rt, deadline) + + resp, err := rt.RoundTrip(req) if err != nil { + stopTimer() if resp != nil { log.Printf("RoundTripper returned a response & error; ignoring response") } + if tlsErr, ok := err.(tls.RecordHeaderError); ok { + // If we get a bad TLS record header, check to see if the + // response looks like HTTP and give a more helpful error. + // See golang.org/issue/11111. 
+ if string(tlsErr.RecordHeader[:]) == "HTTP/" { + err = errors.New("http: server gave HTTP response to HTTPS client") + } + } return nil, err } + if !deadline.IsZero() { + resp.Body = &cancelTimerBody{ + stop: stopTimer, + rc: resp.Body, + reqWasCanceled: wasCanceled, + } + } return resp, nil } +// setRequestCancel sets the Cancel field of req, if deadline is +// non-zero. The RoundTripper's type is used to determine whether the legacy +// CancelRequest behavior should be used. +func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), wasCanceled func() bool) { + if deadline.IsZero() { + return nop, alwaysFalse + } + + initialReqCancel := req.Cancel // the user's original Request.Cancel, if any + + cancel := make(chan struct{}) + req.Cancel = cancel + + wasCanceled = func() bool { + select { + case <-cancel: + return true + default: + return false + } + } + + doCancel := func() { + // The new way: + close(cancel) + + // The legacy compatibility way, used only + // for RoundTripper implementations written + // before Go 1.5 or Go 1.6. + type canceler interface { + CancelRequest(*Request) + } + switch v := rt.(type) { + case *Transport, *http2Transport: + // Do nothing. 
The net/http package's transports + // support the new Request.Cancel channel + case canceler: + v.CancelRequest(req) + } + } + + stopTimerCh := make(chan struct{}) + var once sync.Once + stopTimer = func() { once.Do(func() { close(stopTimerCh) }) } + + timer := time.NewTimer(deadline.Sub(time.Now())) + go func() { + select { + case <-initialReqCancel: + doCancel() + case <-timer.C: + doCancel() + case <-stopTimerCh: + timer.Stop() + } + }() + + return stopTimer, wasCanceled +} + // See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt // "To receive authorization, the client sends the userid and password, // separated by a single colon (":") character, within a base64 @@ -321,34 +437,15 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo return nil, errors.New("http: nil Request.URL") } - var reqmu sync.Mutex // guards req req := ireq - - var timer *time.Timer - var atomicWasCanceled int32 // atomic bool (1 or 0) - var wasCanceled = alwaysFalse - if c.Timeout > 0 { - wasCanceled = func() bool { return atomic.LoadInt32(&atomicWasCanceled) != 0 } - type canceler interface { - CancelRequest(*Request) - } - tr, ok := c.transport().(canceler) - if !ok { - return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport()) - } - timer = time.AfterFunc(c.Timeout, func() { - atomic.StoreInt32(&atomicWasCanceled, 1) - reqmu.Lock() - defer reqmu.Unlock() - tr.CancelRequest(req) - }) - } + deadline := c.deadline() urlStr := "" // next relative or absolute URL to fetch (after first request) redirectFailed := false for redirect := 0; ; redirect++ { if redirect != 0 { nreq := new(Request) + nreq.Cancel = ireq.Cancel nreq.Method = ireq.Method if ireq.Method == "POST" || ireq.Method == "PUT" { nreq.Method = "GET" @@ -371,14 +468,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo break } } - reqmu.Lock() req = nreq - reqmu.Unlock() } urlStr = 
req.URL.String() - if resp, err = c.send(req); err != nil { - if wasCanceled() { + if resp, err = c.send(req, deadline); err != nil { + if !deadline.IsZero() && !time.Now().Before(deadline) { err = &httpError{ err: err.Error() + " (Client.Timeout exceeded while awaiting headers)", timeout: true, @@ -403,19 +498,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo via = append(via, req) continue } - if timer != nil { - resp.Body = &cancelTimerBody{ - t: timer, - rc: resp.Body, - reqWasCanceled: wasCanceled, - } - } return resp, nil } - method := ireq.Method + method := valueOrDefault(ireq.Method, "GET") urlErr := &url.Error{ - Op: method[0:1] + strings.ToLower(method[1:]), + Op: method[:1] + strings.ToLower(method[1:]), URL: urlStr, Err: err, } @@ -528,30 +616,35 @@ func (c *Client) Head(url string) (resp *Response, err error) { } // cancelTimerBody is an io.ReadCloser that wraps rc with two features: -// 1) on Read EOF or Close, the timer t is Stopped, +// 1) on Read error or close, the stop func is called. // 2) On Read failure, if reqWasCanceled is true, the error is wrapped and // marked as net.Error that hit its timeout. 
type cancelTimerBody struct { - t *time.Timer + stop func() // stops the time.Timer waiting to cancel the request rc io.ReadCloser reqWasCanceled func() bool } func (b *cancelTimerBody) Read(p []byte) (n int, err error) { n, err = b.rc.Read(p) + if err == nil { + return n, nil + } + b.stop() if err == io.EOF { - b.t.Stop() - } else if err != nil && b.reqWasCanceled() { - return n, &httpError{ + return n, err + } + if b.reqWasCanceled() { + err = &httpError{ err: err.Error() + " (Client.Timeout exceeded while reading body)", timeout: true, } } - return + return n, err } func (b *cancelTimerBody) Close() error { err := b.rc.Close() - b.t.Stop() + b.stop() return err } diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 7b524d381bc..8939dc8baf9 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -20,8 +20,6 @@ import ( . "net/http" "net/http/httptest" "net/url" - "reflect" - "sort" "strconv" "strings" "sync" @@ -83,12 +81,15 @@ func TestClient(t *testing.T) { } } -func TestClientHead(t *testing.T) { +func TestClientHead_h1(t *testing.T) { testClientHead(t, h1Mode) } +func TestClientHead_h2(t *testing.T) { testClientHead(t, h2Mode) } + +func testClientHead(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(robotsTxtHandler) - defer ts.Close() + cst := newClientServerTest(t, h2, robotsTxtHandler) + defer cst.close() - r, err := Head(ts.URL) + r, err := cst.c.Head(cst.ts.URL) if err != nil { t.Fatal(err) } @@ -230,9 +231,18 @@ func TestClientRedirects(t *testing.T) { t.Errorf("with default client Do, expected error %q, got %q", e, g) } + // Requests with an empty Method should also redirect (Issue 12705) + greq.Method = "" + _, err = c.Do(greq) + if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client Do and empty Method, expected error %q, got %q", e, g) + } + var checkErr error var lastVia []*Request - c = 
&Client{CheckRedirect: func(_ *Request, via []*Request) error { + var lastReq *Request + c = &Client{CheckRedirect: func(req *Request, via []*Request) error { + lastReq = req lastVia = via return checkErr }} @@ -252,6 +262,20 @@ func TestClientRedirects(t *testing.T) { t.Errorf("expected lastVia to have contained %d elements; got %d", e, g) } + // Test that Request.Cancel is propagated between requests (Issue 14053) + creq, _ := NewRequest("HEAD", ts.URL, nil) + cancel := make(chan struct{}) + creq.Cancel = cancel + if _, err := c.Do(creq); err != nil { + t.Fatal(err) + } + if lastReq == nil { + t.Fatal("didn't see redirect") + } + if lastReq.Cancel != cancel { + t.Errorf("expected lastReq to have the cancel channel set on the inital req") + } + checkErr = errors.New("no redirects allowed") res, err = c.Get(ts.URL) if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr { @@ -486,20 +510,23 @@ func (j *RecordingJar) logf(format string, args ...interface{}) { fmt.Fprintf(&j.log, format, args...) 
} -func TestStreamingGet(t *testing.T) { +func TestStreamingGet_h1(t *testing.T) { testStreamingGet(t, h1Mode) } +func TestStreamingGet_h2(t *testing.T) { testStreamingGet(t, h2Mode) } + +func testStreamingGet(t *testing.T, h2 bool) { defer afterTest(t) say := make(chan string) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() for str := range say { w.Write([]byte(str)) w.(Flusher).Flush() } })) - defer ts.Close() + defer cst.close() - c := &Client{} - res, err := c.Get(ts.URL) + c := cst.c + res, err := c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } @@ -642,14 +669,18 @@ func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport { func TestClientWithCorrectTLSServerName(t *testing.T) { defer afterTest(t) + + const serverName = "example.com" ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { - if r.TLS.ServerName != "127.0.0.1" { - t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName) + if r.TLS.ServerName != serverName { + t.Errorf("expected client to set ServerName %q, got: %q", serverName, r.TLS.ServerName) } })) defer ts.Close() - c := &Client{Transport: newTLSTransport(t, ts)} + trans := newTLSTransport(t, ts) + trans.TLSClientConfig.ServerName = serverName + c := &Client{Transport: trans} if _, err := c.Get(ts.URL); err != nil { t.Fatalf("expected successful TLS connection, got error: %v", err) } @@ -739,15 +770,37 @@ func TestResponseSetsTLSConnectionState(t *testing.T) { } } +// Check that an HTTPS client can interpret a particular TLS error +// to determine that the server is speaking HTTP. +// See golang.org/issue/11111. 
+func TestHTTPSClientDetectsHTTPServer(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() + + _, err := Get(strings.Replace(ts.URL, "http", "https", 1)) + if got := err.Error(); !strings.Contains(got, "HTTP response to HTTPS client") { + t.Fatalf("error = %q; want error indicating HTTP response to HTTPS request", got) + } +} + // Verify Response.ContentLength is populated. https://golang.org/issue/4126 -func TestClientHeadContentLength(t *testing.T) { +func TestClientHeadContentLength_h1(t *testing.T) { + testClientHeadContentLength(t, h1Mode) +} + +func TestClientHeadContentLength_h2(t *testing.T) { + testClientHeadContentLength(t, h2Mode) +} + +func testClientHeadContentLength(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { if v := r.FormValue("cl"); v != "" { w.Header().Set("Content-Length", v) } })) - defer ts.Close() + defer cst.close() tests := []struct { suffix string want int64 @@ -757,8 +810,8 @@ func TestClientHeadContentLength(t *testing.T) { {"", -1}, } for _, tt := range tests { - req, _ := NewRequest("HEAD", ts.URL+tt.suffix, nil) - res, err := DefaultClient.Do(req) + req, _ := NewRequest("HEAD", cst.ts.URL+tt.suffix, nil) + res, err := cst.c.Do(req) if err != nil { t.Fatal(err) } @@ -884,14 +937,17 @@ func TestBasicAuthHeadersPreserved(t *testing.T) { } -func TestClientTimeout(t *testing.T) { +func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) } +func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) } + +func testClientTimeout(t *testing.T, h2 bool) { if testing.Short() { t.Skip("skipping in short mode") } defer afterTest(t) sawRoot := make(chan bool, 1) sawSlow := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := 
newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/" { sawRoot <- true Redirect(w, r, "/slow", StatusFound) @@ -905,13 +961,11 @@ func TestClientTimeout(t *testing.T) { return } })) - defer ts.Close() + defer cst.close() const timeout = 500 * time.Millisecond - c := &Client{ - Timeout: timeout, - } + cst.c.Timeout = timeout - res, err := c.Get(ts.URL) + res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } @@ -957,17 +1011,20 @@ func TestClientTimeout(t *testing.T) { } } +func TestClientTimeout_Headers_h1(t *testing.T) { testClientTimeout_Headers(t, h1Mode) } +func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h2Mode) } + // Client.Timeout firing before getting to the body -func TestClientTimeout_Headers(t *testing.T) { +func testClientTimeout_Headers(t *testing.T, h2 bool) { if testing.Short() { t.Skip("skipping in short mode") } defer afterTest(t) donec := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { <-donec })) - defer ts.Close() + defer cst.close() // Note that we use a channel send here and not a close. // The race detector doesn't know that we're waiting for a timeout // and thinks that the waitgroup inside httptest.Server is added to concurrently @@ -977,19 +1034,17 @@ func TestClientTimeout_Headers(t *testing.T) { // doesn't know this, so synchronize explicitly. 
defer func() { donec <- true }() - c := &Client{Timeout: 500 * time.Millisecond} - - _, err := c.Get(ts.URL) + cst.c.Timeout = 500 * time.Millisecond + _, err := cst.c.Get(cst.ts.URL) if err == nil { t.Fatal("got response from Get; expected error") } - ue, ok := err.(*url.Error) - if !ok { + if _, ok := err.(*url.Error); !ok { t.Fatalf("Got error of type %T; want *url.Error", err) } - ne, ok := ue.Err.(net.Error) + ne, ok := err.(net.Error) if !ok { - t.Fatalf("Got url.Error.Err of type %T; want some net.Error", err) + t.Fatalf("Got error of type %T; want some net.Error", err) } if !ne.Timeout() { t.Error("net.Error.Timeout = false; want true") @@ -999,18 +1054,20 @@ func TestClientTimeout_Headers(t *testing.T) { } } -func TestClientRedirectEatsBody(t *testing.T) { +func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) } +func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) } +func testClientRedirectEatsBody(t *testing.T, h2 bool) { defer afterTest(t) saw := make(chan string, 2) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { saw <- r.RemoteAddr if r.URL.Path == "/" { Redirect(w, r, "/foo", StatusFound) // which includes a body } })) - defer ts.Close() + defer cst.close() - res, err := Get(ts.URL) + res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } @@ -1047,76 +1104,6 @@ func (f eofReaderFunc) Read(p []byte) (n int, err error) { return 0, io.EOF } -func TestClientTrailers(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - w.Header().Set("Connection", "close") - w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") - w.Header().Add("Trailer", "Server-Trailer-C") - - var decl []string - for k := range r.Trailer { - decl = append(decl, k) - } - sort.Strings(decl) - - slurp, err := 
ioutil.ReadAll(r.Body) - if err != nil { - t.Errorf("Server reading request body: %v", err) - } - if string(slurp) != "foo" { - t.Errorf("Server read request body %q; want foo", slurp) - } - if r.Trailer == nil { - io.WriteString(w, "nil Trailer") - } else { - fmt.Fprintf(w, "decl: %v, vals: %s, %s", - decl, - r.Trailer.Get("Client-Trailer-A"), - r.Trailer.Get("Client-Trailer-B")) - } - - // How handlers set Trailers: declare it ahead of time - // with the Trailer header, and then mutate the - // Header() of those values later, after the response - // has been written (we wrote to w above). - w.Header().Set("Server-Trailer-A", "valuea") - w.Header().Set("Server-Trailer-C", "valuec") // skipping B - })) - defer ts.Close() - - var req *Request - req, _ = NewRequest("POST", ts.URL, io.MultiReader( - eofReaderFunc(func() { - req.Trailer["Client-Trailer-A"] = []string{"valuea"} - }), - strings.NewReader("foo"), - eofReaderFunc(func() { - req.Trailer["Client-Trailer-B"] = []string{"valueb"} - }), - )) - req.Trailer = Header{ - "Client-Trailer-A": nil, // to be set later - "Client-Trailer-B": nil, // to be set later - } - req.ContentLength = -1 - res, err := DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil { - t.Error(err) - } - want := Header{ - "Server-Trailer-A": []string{"valuea"}, - "Server-Trailer-B": nil, - "Server-Trailer-C": []string{"valuec"}, - } - if !reflect.DeepEqual(res.Trailer, want) { - t.Errorf("Response trailers = %#v; want %#v", res.Trailer, want) - } -} - func TestReferer(t *testing.T) { tests := []struct { lastReq, newReq string // from -> to URLs diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go new file mode 100644 index 00000000000..3c87fd0cf83 --- /dev/null +++ b/libgo/go/net/http/clientserver_test.go @@ -0,0 +1,1056 @@ +// Copyright 2015 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode. + +package http_test + +import ( + "bytes" + "compress/gzip" + "crypto/tls" + "fmt" + "io" + "io/ioutil" + "log" + "net" + . "net/http" + "net/http/httptest" + "net/url" + "os" + "reflect" + "runtime" + "sort" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +type clientServerTest struct { + t *testing.T + h2 bool + h Handler + ts *httptest.Server + tr *Transport + c *Client +} + +func (t *clientServerTest) close() { + t.tr.CloseIdleConnections() + t.ts.Close() +} + +const ( + h1Mode = false + h2Mode = true +) + +func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest { + cst := &clientServerTest{ + t: t, + h2: h2, + h: h, + tr: &Transport{}, + } + cst.c = &Client{Transport: cst.tr} + + for _, opt := range opts { + switch opt := opt.(type) { + case func(*Transport): + opt(cst.tr) + default: + t.Fatalf("unhandled option type %T", opt) + } + } + + if !h2 { + cst.ts = httptest.NewServer(h) + return cst + } + cst.ts = httptest.NewUnstartedServer(h) + ExportHttp2ConfigureServer(cst.ts.Config, nil) + cst.ts.TLS = cst.ts.Config.TLSConfig + cst.ts.StartTLS() + + cst.tr.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { + t.Fatal(err) + } + return cst +} + +// Testing the newClientServerTest helper itself. 
+func TestNewClientServerTest(t *testing.T) { + var got struct { + sync.Mutex + log []string + } + h := HandlerFunc(func(w ResponseWriter, r *Request) { + got.Lock() + defer got.Unlock() + got.log = append(got.log, r.Proto) + }) + for _, v := range [2]bool{false, true} { + cst := newClientServerTest(t, v, h) + if _, err := cst.c.Head(cst.ts.URL); err != nil { + t.Fatal(err) + } + cst.close() + } + got.Lock() // no need to unlock + if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) { + t.Errorf("got %q; want %q", got.log, want) + } +} + +func TestChunkedResponseHeaders_h1(t *testing.T) { testChunkedResponseHeaders(t, h1Mode) } +func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, h2Mode) } + +func testChunkedResponseHeaders(t *testing.T, h2 bool) { + defer afterTest(t) + log.SetOutput(ioutil.Discard) // is noisy otherwise + defer log.SetOutput(os.Stderr) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted + w.(Flusher).Flush() + fmt.Fprintf(w, "I am a chunked response.") + })) + defer cst.close() + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + defer res.Body.Close() + if g, e := res.ContentLength, int64(-1); g != e { + t.Errorf("expected ContentLength of %d; got %d", e, g) + } + wantTE := []string{"chunked"} + if h2 { + wantTE = nil + } + if !reflect.DeepEqual(res.TransferEncoding, wantTE) { + t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE) + } + if got, haveCL := res.Header["Content-Length"]; haveCL { + t.Errorf("Unexpected Content-Length: %q", got) + } +} + +type reqFunc func(c *Client, url string) (*Response, error) + +// h12Compare is a test that compares HTTP/1 and HTTP/2 behavior +// against each other. 
+type h12Compare struct { + Handler func(ResponseWriter, *Request) // required + ReqFunc reqFunc // optional + CheckResponse func(proto string, res *Response) // optional + Opts []interface{} +} + +func (tt h12Compare) reqFunc() reqFunc { + if tt.ReqFunc == nil { + return (*Client).Get + } + return tt.ReqFunc +} + +func (tt h12Compare) run(t *testing.T) { + cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...) + defer cst1.close() + cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...) + defer cst2.close() + + res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) + if err != nil { + t.Errorf("HTTP/1 request: %v", err) + return + } + res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL) + if err != nil { + t.Errorf("HTTP/2 request: %v", err) + return + } + tt.normalizeRes(t, res1, "HTTP/1.1") + tt.normalizeRes(t, res2, "HTTP/2.0") + res1body, res2body := res1.Body, res2.Body + + eres1 := mostlyCopy(res1) + eres2 := mostlyCopy(res2) + if !reflect.DeepEqual(eres1, eres2) { + t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v", + cst1.ts.URL, eres1, cst2.ts.URL, eres2) + } + if !reflect.DeepEqual(res1body, res2body) { + t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body) + } + if fn := tt.CheckResponse; fn != nil { + res1.Body, res2.Body = res1body, res2body + fn("HTTP/1.1", res1) + fn("HTTP/2.0", res2) + } +} + +func mostlyCopy(r *Response) *Response { + c := *r + c.Body = nil + c.TransferEncoding = nil + c.TLS = nil + c.Request = nil + return &c +} + +type slurpResult struct { + io.ReadCloser + body []byte + err error +} + +func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) } + +func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) { + if res.Proto == wantProto { + res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0 + } else { + t.Errorf("got %q response; want %q", res.Proto, wantProto) + } 
+ slurp, err := ioutil.ReadAll(res.Body) + res.Body.Close() + res.Body = slurpResult{ + ReadCloser: ioutil.NopCloser(bytes.NewReader(slurp)), + body: slurp, + err: err, + } + for i, v := range res.Header["Date"] { + res.Header["Date"][i] = strings.Repeat("x", len(v)) + } + if res.Request == nil { + t.Errorf("for %s, no request", wantProto) + } + if (res.TLS != nil) != (wantProto == "HTTP/2.0") { + t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil) + } +} + +// Issue 13532 +func TestH12_HeadContentLengthNoBody(t *testing.T) { + h12Compare{ + ReqFunc: (*Client).Head, + Handler: func(w ResponseWriter, r *Request) { + }, + }.run(t) +} + +func TestH12_HeadContentLengthSmallBody(t *testing.T) { + h12Compare{ + ReqFunc: (*Client).Head, + Handler: func(w ResponseWriter, r *Request) { + io.WriteString(w, "small") + }, + }.run(t) +} + +func TestH12_HeadContentLengthLargeBody(t *testing.T) { + h12Compare{ + ReqFunc: (*Client).Head, + Handler: func(w ResponseWriter, r *Request) { + chunk := strings.Repeat("x", 512<<10) + for i := 0; i < 10; i++ { + io.WriteString(w, chunk) + } + }, + }.run(t) +} + +func TestH12_200NoBody(t *testing.T) { + h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t) +} + +func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) } +func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) } +func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) } + +func testH12_noBody(t *testing.T, status int) { + h12Compare{Handler: func(w ResponseWriter, r *Request) { + w.WriteHeader(status) + }}.run(t) +} + +func TestH12_SmallBody(t *testing.T) { + h12Compare{Handler: func(w ResponseWriter, r *Request) { + io.WriteString(w, "small body") + }}.run(t) +} + +func TestH12_ExplicitContentLength(t *testing.T) { + h12Compare{Handler: func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "3") + io.WriteString(w, "foo") + }}.run(t) +} + +func TestH12_FlushBeforeBody(t *testing.T) { + h12Compare{Handler: func(w 
ResponseWriter, r *Request) { + w.(Flusher).Flush() + io.WriteString(w, "foo") + }}.run(t) +} + +func TestH12_FlushMidBody(t *testing.T) { + h12Compare{Handler: func(w ResponseWriter, r *Request) { + io.WriteString(w, "foo") + w.(Flusher).Flush() + io.WriteString(w, "bar") + }}.run(t) +} + +func TestH12_Head_ExplicitLen(t *testing.T) { + h12Compare{ + ReqFunc: (*Client).Head, + Handler: func(w ResponseWriter, r *Request) { + if r.Method != "HEAD" { + t.Errorf("unexpected method %q", r.Method) + } + w.Header().Set("Content-Length", "1235") + }, + }.run(t) +} + +func TestH12_Head_ImplicitLen(t *testing.T) { + h12Compare{ + ReqFunc: (*Client).Head, + Handler: func(w ResponseWriter, r *Request) { + if r.Method != "HEAD" { + t.Errorf("unexpected method %q", r.Method) + } + io.WriteString(w, "foo") + }, + }.run(t) +} + +func TestH12_HandlerWritesTooLittle(t *testing.T) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "3") + io.WriteString(w, "12") // one byte short + }, + CheckResponse: func(proto string, res *Response) { + sr, ok := res.Body.(slurpResult) + if !ok { + t.Errorf("%s body is %T; want slurpResult", proto, res.Body) + return + } + if sr.err != io.ErrUnexpectedEOF { + t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err) + } + if string(sr.body) != "12" { + t.Errorf("%s body = %q; want %q", proto, sr.body, "12") + } + }, + }.run(t) +} + +// Tests that the HTTP/1 and HTTP/2 servers prevent handlers from +// writing more than they declared. This test does not test whether +// the transport deals with too much data, though, since the server +// doesn't make it possible to send bogus data. For those tests, see +// transport_test.go (for HTTP/1) or x/net/http2/transport_test.go +// (for HTTP/2). 
+func TestH12_HandlerWritesTooMuch(t *testing.T) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "3") + w.(Flusher).Flush() + io.WriteString(w, "123") + w.(Flusher).Flush() + n, err := io.WriteString(w, "x") // too many + if n > 0 || err == nil { + t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err) + } + }, + }.run(t) +} + +// Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip. +// Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298 +func TestH12_AutoGzip(t *testing.T) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" { + t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae) + } + w.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(w) + io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.") + gz.Close() + }, + }.run(t) +} + +func TestH12_AutoGzip_Disabled(t *testing.T) { + h12Compare{ + Opts: []interface{}{ + func(tr *Transport) { tr.DisableCompression = true }, + }, + Handler: func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"]) + if ae := r.Header.Get("Accept-Encoding"); ae != "" { + t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae) + } + }, + }.run(t) +} + +// Test304Responses verifies that 304s don't declare that they're +// chunking in their response headers and aren't allowed to produce +// output. 
+func Test304Responses_h1(t *testing.T) { test304Responses(t, h1Mode) } +func Test304Responses_h2(t *testing.T) { test304Responses(t, h2Mode) } + +func test304Responses(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNotModified) + _, err := w.Write([]byte("illegal body")) + if err != ErrBodyNotAllowed { + t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) + } + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + if len(res.TransferEncoding) > 0 { + t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if len(body) > 0 { + t.Errorf("got unexpected body %q", string(body)) + } +} + +func TestH12_ServerEmptyContentLength(t *testing.T) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + w.Header()["Content-Type"] = []string{""} + io.WriteString(w, "<html><body>hi</body></html>") + }, + }.run(t) +} + +func TestH12_RequestContentLength_Known_NonZero(t *testing.T) { + h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4) +} + +func TestH12_RequestContentLength_Known_Zero(t *testing.T) { + h12requestContentLength(t, func() io.Reader { return strings.NewReader("") }, 0) +} + +func TestH12_RequestContentLength_Unknown(t *testing.T) { + h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1) +} + +func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength)) + fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength) + }, + ReqFunc: func(c *Client, url string) (*Response, error) { + return c.Post(url, "text/plain", bodyfn()) + }, + CheckResponse: func(proto string, res *Response) { 
+ if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want { + t.Errorf("Proto %q got length %q; want %q", proto, got, want) + } + }, + }.run(t) +} + +// Tests that closing the Request.Cancel channel also while still +// reading the response body. Issue 13159. +func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) } +func TestCancelRequestMidBody_h2(t *testing.T) { testCancelRequestMidBody(t, h2Mode) } +func testCancelRequestMidBody(t *testing.T, h2 bool) { + defer afterTest(t) + unblock := make(chan bool) + didFlush := make(chan bool, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, "Hello") + w.(Flusher).Flush() + didFlush <- true + <-unblock + io.WriteString(w, ", world.") + })) + defer cst.close() + defer close(unblock) + + req, _ := NewRequest("GET", cst.ts.URL, nil) + cancel := make(chan struct{}) + req.Cancel = cancel + + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + <-didFlush + + // Read a bit before we cancel. (Issue 13626) + // We should have "Hello" at least sitting there. + firstRead := make([]byte, 10) + n, err := res.Body.Read(firstRead) + if err != nil { + t.Fatal(err) + } + firstRead = firstRead[:n] + + close(cancel) + + rest, err := ioutil.ReadAll(res.Body) + all := string(firstRead) + string(rest) + if all != "Hello" { + t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest) + } + if !reflect.DeepEqual(err, ExportErrRequestCanceled) { + t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled) + } +} + +// Tests that clients can send trailers to a server and that the server can read them. 
+func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) } +func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) } + +func testTrailersClientToServer(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + var decl []string + for k := range r.Trailer { + decl = append(decl, k) + } + sort.Strings(decl) + + slurp, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Errorf("Server reading request body: %v", err) + } + if string(slurp) != "foo" { + t.Errorf("Server read request body %q; want foo", slurp) + } + if r.Trailer == nil { + io.WriteString(w, "nil Trailer") + } else { + fmt.Fprintf(w, "decl: %v, vals: %s, %s", + decl, + r.Trailer.Get("Client-Trailer-A"), + r.Trailer.Get("Client-Trailer-B")) + } + })) + defer cst.close() + + var req *Request + req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader( + eofReaderFunc(func() { + req.Trailer["Client-Trailer-A"] = []string{"valuea"} + }), + strings.NewReader("foo"), + eofReaderFunc(func() { + req.Trailer["Client-Trailer-B"] = []string{"valueb"} + }), + )) + req.Trailer = Header{ + "Client-Trailer-A": nil, // to be set later + "Client-Trailer-B": nil, // to be set later + } + req.ContentLength = -1 + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil { + t.Error(err) + } +} + +// Tests that servers send trailers to a client and that the client can read them. 
+func TestTrailersServerToClient_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, false) } +func TestTrailersServerToClient_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, false) } +func TestTrailersServerToClient_Flush_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, true) } +func TestTrailersServerToClient_Flush_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, true) } + +func testTrailersServerToClient(t *testing.T, h2, flush bool) { + defer afterTest(t) + const body = "Some body" + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") + w.Header().Add("Trailer", "Server-Trailer-C") + + io.WriteString(w, body) + if flush { + w.(Flusher).Flush() + } + + // How handlers set Trailers: declare it ahead of time + // with the Trailer header, and then mutate the + // Header() of those values later, after the response + // has been written (we wrote to w above). + w.Header().Set("Server-Trailer-A", "valuea") + w.Header().Set("Server-Trailer-C", "valuec") // skipping B + w.Header().Set("Server-Trailer-NotDeclared", "should be omitted") + })) + defer cst.close() + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + + wantHeader := Header{ + "Content-Type": {"text/plain; charset=utf-8"}, + } + wantLen := -1 + if h2 && !flush { + // In HTTP/1.1, any use of trailers forces HTTP/1.1 + // chunking and a flush at the first write. That's + // unnecessary with HTTP/2's framing, so the server + // is able to calculate the length while still sending + // trailers afterwards. 
+ wantLen = len(body) + wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)} + } + if res.ContentLength != int64(wantLen) { + t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen) + } + + delete(res.Header, "Date") // irrelevant for test + if !reflect.DeepEqual(res.Header, wantHeader) { + t.Errorf("Header = %v; want %v", res.Header, wantHeader) + } + + if got, want := res.Trailer, (Header{ + "Server-Trailer-A": nil, + "Server-Trailer-B": nil, + "Server-Trailer-C": nil, + }); !reflect.DeepEqual(got, want) { + t.Errorf("Trailer before body read = %v; want %v", got, want) + } + + if err := wantBody(res, nil, body); err != nil { + t.Fatal(err) + } + + if got, want := res.Trailer, (Header{ + "Server-Trailer-A": {"valuea"}, + "Server-Trailer-B": nil, + "Server-Trailer-C": {"valuec"}, + }); !reflect.DeepEqual(got, want) { + t.Errorf("Trailer after body read = %v; want %v", got, want) + } +} + +// Don't allow a Body.Read after Body.Close. Issue 13648. +func TestResponseBodyReadAfterClose_h1(t *testing.T) { testResponseBodyReadAfterClose(t, h1Mode) } +func TestResponseBodyReadAfterClose_h2(t *testing.T) { testResponseBodyReadAfterClose(t, h2Mode) } + +func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { + defer afterTest(t) + const body = "Some body" + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, body) + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + data, err := ioutil.ReadAll(res.Body) + if len(data) != 0 || err == nil { + t.Fatalf("ReadAll returned %q, %v; want error", data, err) + } +} + +func TestConcurrentReadWriteReqBody_h1(t *testing.T) { testConcurrentReadWriteReqBody(t, h1Mode) } +func TestConcurrentReadWriteReqBody_h2(t *testing.T) { testConcurrentReadWriteReqBody(t, h2Mode) } +func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { + defer afterTest(t) + const reqBody = "some request body" + const 
resBody = "some response body" + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + var wg sync.WaitGroup + wg.Add(2) + didRead := make(chan bool, 1) + // Read in one goroutine. + go func() { + defer wg.Done() + data, err := ioutil.ReadAll(r.Body) + if string(data) != reqBody { + t.Errorf("Handler read %q; want %q", data, reqBody) + } + if err != nil { + t.Errorf("Handler Read: %v", err) + } + didRead <- true + }() + // Write in another goroutine. + go func() { + defer wg.Done() + if !h2 { + // our HTTP/1 implementation intentionally + // doesn't permit writes during read (mostly + // due to it being undefined); if that is ever + // relaxed, change this. + <-didRead + } + io.WriteString(w, resBody) + }() + wg.Wait() + })) + defer cst.close() + req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody)) + req.Header.Add("Expect", "100-continue") // just to complicate things + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + data, err := ioutil.ReadAll(res.Body) + defer res.Body.Close() + if err != nil { + t.Fatal(err) + } + if string(data) != resBody { + t.Errorf("read %q; want %q", data, resBody) + } +} + +func TestConnectRequest_h1(t *testing.T) { testConnectRequest(t, h1Mode) } +func TestConnectRequest_h2(t *testing.T) { testConnectRequest(t, h2Mode) } +func testConnectRequest(t *testing.T, h2 bool) { + defer afterTest(t) + gotc := make(chan *Request, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + gotc <- r + })) + defer cst.close() + + u, err := url.Parse(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + req *Request + want string + }{ + { + req: &Request{ + Method: "CONNECT", + Header: Header{}, + URL: u, + }, + want: u.Host, + }, + { + req: &Request{ + Method: "CONNECT", + Header: Header{}, + URL: u, + Host: "example.com:123", + }, + want: "example.com:123", + }, + } + + for i, tt := range tests { + res, err := cst.c.Do(tt.req) + if 
err != nil { + t.Errorf("%d. RoundTrip = %v", i, err) + continue + } + res.Body.Close() + req := <-gotc + if req.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", req.Method) + } + if req.Host != tt.want { + t.Errorf("Host = %q; want %q", req.Host, tt.want) + } + if req.URL.Host != tt.want { + t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want) + } + } +} + +func TestTransportUserAgent_h1(t *testing.T) { testTransportUserAgent(t, h1Mode) } +func TestTransportUserAgent_h2(t *testing.T) { testTransportUserAgent(t, h2Mode) } +func testTransportUserAgent(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%q", r.Header["User-Agent"]) + })) + defer cst.close() + + either := func(a, b string) string { + if h2 { + return b + } + return a + } + + tests := []struct { + setup func(*Request) + want string + }{ + { + func(r *Request) {}, + either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`), + }, + { + func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") }, + `["foo/1.2.3"]`, + }, + { + func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} }, + `["single"]`, + }, + { + func(r *Request) { r.Header.Set("User-Agent", "") }, + `[]`, + }, + { + func(r *Request) { r.Header["User-Agent"] = nil }, + `[]`, + }, + } + for i, tt := range tests { + req, _ := NewRequest("GET", cst.ts.URL, nil) + tt.setup(req) + res, err := cst.c.Do(req) + if err != nil { + t.Errorf("%d. RoundTrip = %v", i, err) + continue + } + slurp, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("%d. read body = %v", i, err) + continue + } + if string(slurp) != tt.want { + t.Errorf("%d. 
body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want) + } + } +} + +func TestStarRequestFoo_h1(t *testing.T) { testStarRequest(t, "FOO", h1Mode) } +func TestStarRequestFoo_h2(t *testing.T) { testStarRequest(t, "FOO", h2Mode) } +func TestStarRequestOptions_h1(t *testing.T) { testStarRequest(t, "OPTIONS", h1Mode) } +func TestStarRequestOptions_h2(t *testing.T) { testStarRequest(t, "OPTIONS", h2Mode) } +func testStarRequest(t *testing.T, method string, h2 bool) { + defer afterTest(t) + gotc := make(chan *Request, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("foo", "bar") + gotc <- r + w.(Flusher).Flush() + })) + defer cst.close() + + u, err := url.Parse(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + u.Path = "*" + + req := &Request{ + Method: method, + Header: Header{}, + URL: u, + } + + res, err := cst.c.Do(req) + if err != nil { + t.Fatalf("RoundTrip = %v", err) + } + res.Body.Close() + + wantFoo := "bar" + wantLen := int64(-1) + if method == "OPTIONS" { + wantFoo = "" + wantLen = 0 + } + if res.StatusCode != 200 { + t.Errorf("status code = %v; want %d", res.Status, 200) + } + if res.ContentLength != wantLen { + t.Errorf("content length = %v; want %d", res.ContentLength, wantLen) + } + if got := res.Header.Get("foo"); got != wantFoo { + t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo) + } + select { + case req = <-gotc: + default: + req = nil + } + if req == nil { + if method != "OPTIONS" { + t.Fatalf("handler never got request") + } + return + } + if req.Method != method { + t.Errorf("method = %q; want %q", req.Method, method) + } + if req.URL.Path != "*" { + t.Errorf("URL.Path = %q; want *", req.URL.Path) + } + if req.RequestURI != "*" { + t.Errorf("RequestURI = %q; want *", req.RequestURI) + } +} + +// Issue 13957 +func TestTransportDiscardsUnneededConns(t *testing.T) { + defer afterTest(t) + cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + 
fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) + })) + defer cst.close() + + var numOpen, numClose int32 // atomic + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + tr := &Transport{ + TLSClientConfig: tlsConfig, + DialTLS: func(_, addr string) (net.Conn, error) { + time.Sleep(10 * time.Millisecond) + rc, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + atomic.AddInt32(&numOpen, 1) + c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }} + return tls.Client(c, tlsConfig), nil + }, + } + if err := ExportHttp2ConfigureTransport(tr); err != nil { + t.Fatal(err) + } + defer tr.CloseIdleConnections() + + c := &Client{Transport: tr} + + const N = 10 + gotBody := make(chan string, N) + var wg sync.WaitGroup + for i := 0; i < N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := c.Get(cst.ts.URL) + if err != nil { + t.Errorf("Get: %v", err) + return + } + defer resp.Body.Close() + slurp, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Error(err) + } + gotBody <- string(slurp) + }() + } + wg.Wait() + close(gotBody) + + var last string + for got := range gotBody { + if last == "" { + last = got + continue + } + if got != last { + t.Errorf("Response body changed: %q -> %q", last, got) + } + } + + var open, close int32 + for i := 0; i < 150; i++ { + open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose) + if open < 1 { + t.Fatalf("open = %d; want at least", open) + } + if close == open-1 { + // Success + return + } + time.Sleep(10 * time.Millisecond) + } + t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1) +} + +// tests that Transport doesn't retain a pointer to the provided request. 
+func TestTransportGCRequest_h1(t *testing.T) { testTransportGCRequest(t, h1Mode) } +func TestTransportGCRequest_h2(t *testing.T) { testTransportGCRequest(t, h2Mode) } +func testTransportGCRequest(t *testing.T, h2 bool) { + if runtime.Compiler == "gccgo" { + t.Skip("skipping on gccgo because conservative GC means that finalizer may never run") + } + + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + ioutil.ReadAll(r.Body) + io.WriteString(w, "Hello.") + })) + defer cst.close() + + didGC := make(chan struct{}) + (func() { + body := strings.NewReader("some body") + req, _ := NewRequest("POST", cst.ts.URL, body) + runtime.SetFinalizer(req, func(*Request) { close(didGC) }) + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + if _, err := ioutil.ReadAll(res.Body); err != nil { + t.Fatal(err) + } + if err := res.Body.Close(); err != nil { + t.Fatal(err) + } + })() + timeout := time.NewTimer(5 * time.Second) + defer timeout.Stop() + for { + select { + case <-didGC: + return + case <-time.After(100 * time.Millisecond): + runtime.GC() + case <-timeout.C: + t.Fatal("never saw GC of request") + } + } +} + +type noteCloseConn struct { + net.Conn + closeFunc func() +} + +func (x noteCloseConn) Close() error { + x.closeFunc() + return x.Conn.Close() +} diff --git a/libgo/go/net/http/doc.go b/libgo/go/net/http/doc.go index b1216e8dafa..4ec8272f628 100644 --- a/libgo/go/net/http/doc.go +++ b/libgo/go/net/http/doc.go @@ -76,5 +76,20 @@ custom Server: MaxHeaderBytes: 1 << 20, } log.Fatal(s.ListenAndServe()) + +The http package has transparent support for the HTTP/2 protocol when +using HTTPS. Programs that must disable HTTP/2 can do so by setting +Transport.TLSNextProto (for clients) or Server.TLSNextProto (for +servers) to a non-nil, empty map. 
Alternatively, the following GODEBUG +environment variables are currently supported: + + GODEBUG=http2client=0 # disable HTTP/2 client support + GODEBUG=http2server=0 # disable HTTP/2 server support + GODEBUG=http2debug=1 # enable verbose HTTP/2 debug logs + GODEBUG=http2debug=2 # ... even more verbose, with frame dumps + +The GODEBUG variables are not covered by Go's API compatibility promise. +HTTP/2 support was added in Go 1.6. Please report any issues instead of +disabling HTTP/2 support: https://golang.org/s/http2bug */ package http diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index 0457be50da6..52bccbdce31 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -9,11 +9,24 @@ package http import ( "net" - "net/url" "sync" "time" ) +var ( + DefaultUserAgent = defaultUserAgent + NewLoggingConn = newLoggingConn + ExportAppendTime = appendTime + ExportRefererForURL = refererForURL + ExportServerNewConn = (*Server).newConn + ExportCloseWriteAndWait = (*conn).closeWriteAndWait + ExportErrRequestCanceled = errRequestCanceled + ExportErrRequestCanceledConn = errRequestCanceledConn + ExportServeFile = serveFile + ExportHttp2ConfigureTransport = http2ConfigureTransport + ExportHttp2ConfigureServer = http2ConfigureServer +) + func init() { // We only want to pay for this cost during testing. 
// When not under test, these values are always nil @@ -21,11 +34,42 @@ func init() { testHookMu = new(sync.Mutex) } -func NewLoggingConn(baseName string, c net.Conn) net.Conn { - return newLoggingConn(baseName, c) +var ( + SetEnterRoundTripHook = hookSetter(&testHookEnterRoundTrip) + SetTestHookWaitResLoop = hookSetter(&testHookWaitResLoop) + SetRoundTripRetried = hookSetter(&testHookRoundTripRetried) +) + +func SetReadLoopBeforeNextReadHook(f func()) { + testHookMu.Lock() + defer testHookMu.Unlock() + unnilTestHook(&f) + testHookReadLoopBeforeNextRead = f +} + +// SetPendingDialHooks sets the hooks that run before and after handling +// pending dials. +func SetPendingDialHooks(before, after func()) { + unnilTestHook(&before) + unnilTestHook(&after) + testHookPrePendingDial, testHookPostPendingDial = before, after } -var ExportAppendTime = appendTime +func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn } + +func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { + return &timeoutHandler{ + handler: handler, + timeout: func() <-chan time.Time { return ch }, + // (no body and nil cancelTimer) + } +} + +func ResetCachedEnvironment() { + httpProxyEnv.reset() + httpsProxyEnv.reset() + noProxyEnv.reset() +} func (t *Transport) NumPendingRequestsForTesting() int { t.reqMu.Lock() @@ -78,55 +122,25 @@ func (t *Transport) RequestIdleConnChForTesting() { func (t *Transport) PutIdleTestConn() bool { c, _ := net.Pipe() - return t.putIdleConn(&persistConn{ + return t.tryPutIdleConn(&persistConn{ t: t, conn: c, // dummy closech: make(chan struct{}), // so it can be closed cacheKey: connectMethodKey{"", "http", "example.com"}, - }) -} - -func SetInstallConnClosedHook(f func()) { - testHookPersistConnClosedGotRes = f + }) == nil } -func SetEnterRoundTripHook(f func()) { - testHookEnterRoundTrip = f -} - -func SetReadLoopBeforeNextReadHook(f func()) { - testHookMu.Lock() - defer testHookMu.Unlock() - 
testHookReadLoopBeforeNextRead = f -} - -func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { - f := func() <-chan time.Time { - return ch +// All test hooks must be non-nil so they can be called directly, +// but the tests use nil to mean hook disabled. +func unnilTestHook(f *func()) { + if *f == nil { + *f = nop } - return &timeoutHandler{handler, f, ""} } -func ResetCachedEnvironment() { - httpProxyEnv.reset() - httpsProxyEnv.reset() - noProxyEnv.reset() -} - -var DefaultUserAgent = defaultUserAgent - -func ExportRefererForURL(lastReq, newReq *url.URL) string { - return refererForURL(lastReq, newReq) -} - -// SetPendingDialHooks sets the hooks that run before and after handling -// pending dials. -func SetPendingDialHooks(before, after func()) { - prePendingDial, postPendingDial = before, after +func hookSetter(dst *func()) func(func()) { + return func(fn func()) { + unnilTestHook(&fn) + *dst = fn + } } - -var ExportServerNewConn = (*Server).newConn - -var ExportCloseWriteAndWait = (*conn).closeWriteAndWait - -var ExportErrRequestCanceled = errRequestCanceled diff --git a/libgo/go/net/http/fcgi/child.go b/libgo/go/net/http/fcgi/child.go index da824ed717e..88704245db8 100644 --- a/libgo/go/net/http/fcgi/child.go +++ b/libgo/go/net/http/fcgi/child.go @@ -56,6 +56,9 @@ func (r *request) parseParams() { return } text = text[n:] + if int(keyLen)+int(valLen) > len(text) { + return + } key := readString(text, keyLen) text = text[keyLen:] val := readString(text, valLen) diff --git a/libgo/go/net/http/fcgi/fcgi_test.go b/libgo/go/net/http/fcgi/fcgi_test.go index de0f7f831f6..b6013bfdd51 100644 --- a/libgo/go/net/http/fcgi/fcgi_test.go +++ b/libgo/go/net/http/fcgi/fcgi_test.go @@ -254,3 +254,27 @@ func TestChildServeCleansUp(t *testing.T) { <-done } } + +type rwNopCloser struct { + io.Reader + io.Writer +} + +func (rwNopCloser) Close() error { + return nil +} + +// Verifies it doesn't crash. Issue 11824. 
+func TestMalformedParams(t *testing.T) { + input := []byte{ + // beginRequest, requestId=1, contentLength=8, role=1, keepConn=1 + 1, 1, 0, 1, 0, 8, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, + // params, requestId=1, contentLength=10, k1Len=50, v1Len=50 (malformed, wrong length) + 1, 4, 0, 1, 0, 10, 0, 0, 50, 50, 3, 4, 5, 6, 7, 8, 9, 10, + // end of params + 1, 4, 0, 1, 0, 0, 0, 0, + } + rw := rwNopCloser{bytes.NewReader(input), ioutil.Discard} + c := newChild(rw, http.DefaultServeMux) + c.serve() +} diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index 75720234c25..f61c138c1d9 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -17,6 +17,7 @@ import ( "os" "path" "path/filepath" + "sort" "strconv" "strings" "time" @@ -62,30 +63,34 @@ type FileSystem interface { type File interface { io.Closer io.Reader + io.Seeker Readdir(count int) ([]os.FileInfo, error) - Seek(offset int64, whence int) (int64, error) Stat() (os.FileInfo, error) } func dirList(w ResponseWriter, f File) { + dirs, err := f.Readdir(-1) + if err != nil { + // TODO: log err.Error() to the Server.ErrorLog, once it's possible + // for a handler to get at its Server via the ResponseWriter. See + // Issue 12438. + Error(w, "Error reading directory", StatusInternalServerError) + return + } + sort.Sort(byName(dirs)) + w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprintf(w, "<pre>\n") - for { - dirs, err := f.Readdir(100) - if err != nil || len(dirs) == 0 { - break - } - for _, d := range dirs { - name := d.Name() - if d.IsDir() { - name += "/" - } - // name may contain '?' or '#', which must be escaped to remain - // part of the URL path, and not indicate the start of a query - // string or fragment. - url := url.URL{Path: name} - fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name)) + for _, d := range dirs { + name := d.Name() + if d.IsDir() { + name += "/" } + // name may contain '?' 
or '#', which must be escaped to remain + // part of the URL path, and not indicate the start of a query + // string or fragment. + url := url.URL{Path: name} + fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name)) } fmt.Fprintf(w, "</pre>\n") } @@ -364,8 +369,8 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec } defer f.Close() - d, err1 := f.Stat() - if err1 != nil { + d, err := f.Stat() + if err != nil { msg, code := toHTTPError(err) Error(w, msg, code) return @@ -446,15 +451,44 @@ func localRedirect(w ResponseWriter, r *Request, newPath string) { // ServeFile replies to the request with the contents of the named // file or directory. // +// If the provided file or directory name is a relative path, it is +// interpreted relative to the current directory and may ascend to parent +// directories. If the provided name is constructed from user input, it +// should be sanitized before calling ServeFile. As a precaution, ServeFile +// will reject requests where r.URL.Path contains a ".." path element. +// // As a special case, ServeFile redirects any request where r.URL.Path // ends in "/index.html" to the same path, without the final // "index.html". To avoid such redirects either modify the path or // use ServeContent. func ServeFile(w ResponseWriter, r *Request, name string) { + if containsDotDot(r.URL.Path) { + // Too many programs use r.URL.Path to construct the argument to + // serveFile. Reject the request under the assumption that happened + // here and ".." may not be wanted. + // Note that name might not contain "..", for example if code (still + // incorrectly) used filepath.Join(myDir, r.URL.Path). 
+ Error(w, "invalid URL path", StatusBadRequest) + return + } dir, file := filepath.Split(name) serveFile(w, r, Dir(dir), file, false) } +func containsDotDot(v string) bool { + if !strings.Contains(v, "..") { + return false + } + for _, ent := range strings.FieldsFunc(v, isSlashRune) { + if ent == ".." { + return true + } + } + return false +} + +func isSlashRune(r rune) bool { return r == '/' || r == '\\' } + type fileHandler struct { root FileSystem } @@ -585,3 +619,9 @@ func sumRangesSize(ranges []httpRange) (size int64) { } return } + +type byName []os.FileInfo + +func (s byName) Len() int { return len(s) } +func (s byName) Less(i, j int) bool { return s[i].Name() < s[j].Name() } +func (s byName) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index 538f34d7201..cf5b63c9f75 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -5,6 +5,7 @@ package http_test import ( + "bufio" "bytes" "errors" "fmt" @@ -177,6 +178,36 @@ Cases: } } +func TestServeFile_DotDot(t *testing.T) { + tests := []struct { + req string + wantStatus int + }{ + {"/testdata/file", 200}, + {"/../file", 400}, + {"/..", 400}, + {"/../", 400}, + {"/../foo", 400}, + {"/..\\foo", 400}, + {"/file/a", 200}, + {"/file/a..", 200}, + {"/file/a/..", 400}, + {"/file/a\\..", 400}, + } + for _, tt := range tests { + req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + tt.req + " HTTP/1.1\r\nHost: foo\r\n\r\n"))) + if err != nil { + t.Errorf("bad request %q: %v", tt.req, err) + continue + } + rec := httptest.NewRecorder() + ServeFile(rec, req, "testdata/file") + if rec.Code != tt.wantStatus { + t.Errorf("for request %q, status = %d; want %d", tt.req, rec.Code, tt.wantStatus) + } + } +} + var fsRedirectTestData = []struct { original, redirect string }{ @@ -283,6 +314,49 @@ func TestFileServerEscapesNames(t *testing.T) { } } +func TestFileServerSortsNames(t *testing.T) { + defer afterTest(t) + const 
contents = "I am a fake file" + dirMod := time.Unix(123, 0).UTC() + fileMod := time.Unix(1000000000, 0).UTC() + fs := fakeFS{ + "/": &fakeFileInfo{ + dir: true, + modtime: dirMod, + ents: []*fakeFileInfo{ + { + basename: "b", + modtime: fileMod, + contents: contents, + }, + { + basename: "a", + modtime: fileMod, + contents: contents, + }, + }, + }, + } + + ts := httptest.NewServer(FileServer(&fs)) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("read Body: %v", err) + } + s := string(b) + if !strings.Contains(s, "<a href=\"a\">a</a>\n<a href=\"b\">b</a>") { + t.Errorf("output appears to be unsorted:\n%s", s) + } +} + func mustRemoveAll(dir string) { err := os.RemoveAll(dir) if err != nil { @@ -434,14 +508,27 @@ func TestServeFileFromCWD(t *testing.T) { } } -func TestServeFileWithContentEncoding(t *testing.T) { +// Tests that ServeFile doesn't add a Content-Length if a Content-Encoding is +// specified. +func TestServeFileWithContentEncoding_h1(t *testing.T) { testServeFileWithContentEncoding(t, h1Mode) } +func TestServeFileWithContentEncoding_h2(t *testing.T) { testServeFileWithContentEncoding(t, h2Mode) } +func testServeFileWithContentEncoding(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "foo") ServeFile(w, r, "testdata/file") + + // Because the testdata is so small, it would fit in + // both the h1 and h2 Server's write buffers. For h1, + // sendfile is used, though, forcing a header flush at + // the io.Copy. http2 doesn't do a header flush so + // buffers all 11 bytes and then adds its own + // Content-Length. To prevent the Server's + // Content-Length and test ServeFile only, flush here. 
+ w.(Flusher).Flush() })) - defer ts.Close() - resp, err := Get(ts.URL) + defer cst.close() + resp, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } @@ -807,6 +894,28 @@ func TestServeContent(t *testing.T) { } } +// Issue 12991 +func TestServerFileStatError(t *testing.T) { + rec := httptest.NewRecorder() + r, _ := NewRequest("GET", "http://foo/", nil) + redirect := false + name := "file.txt" + fs := issue12991FS{} + ExportServeFile(rec, r, fs, name, redirect) + if body := rec.Body.String(); !strings.Contains(body, "403") || !strings.Contains(body, "Forbidden") { + t.Errorf("wanted 403 forbidden message; got: %s", body) + } +} + +type issue12991FS struct{} + +func (issue12991FS) Open(string) (File, error) { return issue12991File{}, nil } + +type issue12991File struct{ File } + +func (issue12991File) Stat() (os.FileInfo, error) { return nil, os.ErrPermission } +func (issue12991File) Close() error { return nil } + func TestServeContentErrorMessages(t *testing.T) { defer afterTest(t) fs := fakeFS{ @@ -852,13 +961,16 @@ func TestLinuxSendfile(t *testing.T) { } defer ln.Close() - trace := "trace=sendfile" - if runtime.GOARCH != "alpha" { - trace = trace + ",sendfile64" + syscalls := "sendfile,sendfile64" + switch runtime.GOARCH { + case "mips64", "mips64le", "alpha": + // mips64 strace doesn't support sendfile64 and will error out + // if we specify that with `-e trace='. + syscalls = "sendfile" } var buf bytes.Buffer - child := exec.Command("strace", "-f", "-q", "-e", trace, os.Args[0], "-test.run=TestLinuxSendfileChild") + child := exec.Command("strace", "-f", "-q", "-e", "trace="+syscalls, os.Args[0], "-test.run=TestLinuxSendfileChild") child.ExtraFiles = append(child.ExtraFiles, lnf) child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...) child.Stdout = &buf @@ -878,7 +990,7 @@ func TestLinuxSendfile(t *testing.T) { res.Body.Close() // Force child to exit cleanly. 
- Get(fmt.Sprintf("http://%s/quit", ln.Addr())) + Post(fmt.Sprintf("http://%s/quit", ln.Addr()), "", nil) child.Wait() rx := regexp.MustCompile(`sendfile(64)?\(\d+,\s*\d+,\s*NULL,\s*\d+\)\s*=\s*\d+\s*\n`) diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go new file mode 100644 index 00000000000..e7236299e22 --- /dev/null +++ b/libgo/go/net/http/h2_bundle.go @@ -0,0 +1,6530 @@ +// Code generated by golang.org/x/tools/cmd/bundle command: +// $ bundle golang.org/x/net/http2 net/http http2 + +// Package http2 implements the HTTP/2 protocol. +// +// This package is low-level and intended to be used directly by very +// few people. Most users will use it indirectly through the automatic +// use by the net/http package (from Go 1.6 and later). +// For use in earlier Go versions see ConfigureServer. (Transport support +// requires Go 1.6 or later) +// +// See https://http2.github.io/ for more information on HTTP/2. +// +// See https://http2.golang.org/ for a test server running this code. +// + +package http + +import ( + "bufio" + "bytes" + "compress/gzip" + "crypto/tls" + "encoding/binary" + "errors" + "fmt" + "internal/golang.org/x/net/http2/hpack" + "io" + "io/ioutil" + "log" + "net" + "net/textproto" + "net/url" + "os" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +// ClientConnPool manages a pool of HTTP/2 client connections. +type http2ClientConnPool interface { + GetClientConn(req *Request, addr string) (*http2ClientConn, error) + MarkDead(*http2ClientConn) +} + +// TODO: use singleflight for dialing and addConnCalls? +type http2clientConnPool struct { + t *http2Transport + + mu sync.Mutex // TODO: maybe switch to RWMutex + // TODO: add support for sharing conns based on cert names + // (e.g. 
share conn for googleapis.com and appspot.com) + conns map[string][]*http2ClientConn // key is host:port + dialing map[string]*http2dialCall // currently in-flight dials + keys map[*http2ClientConn][]string + addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeede calls +} + +func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { + return p.getClientConn(req, addr, http2dialOnMiss) +} + +const ( + http2dialOnMiss = true + http2noDialOnMiss = false +) + +func (p *http2clientConnPool) getClientConn(_ *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { + p.mu.Lock() + for _, cc := range p.conns[addr] { + if cc.CanTakeNewRequest() { + p.mu.Unlock() + return cc, nil + } + } + if !dialOnMiss { + p.mu.Unlock() + return nil, http2ErrNoCachedConn + } + call := p.getStartDialLocked(addr) + p.mu.Unlock() + <-call.done + return call.res, call.err +} + +// dialCall is an in-flight Transport dial call to a host. +type http2dialCall struct { + p *http2clientConnPool + done chan struct{} // closed when done + res *http2ClientConn // valid after done is closed + err error // valid after done is closed +} + +// requires p.mu is held. +func (p *http2clientConnPool) getStartDialLocked(addr string) *http2dialCall { + if call, ok := p.dialing[addr]; ok { + + return call + } + call := &http2dialCall{p: p, done: make(chan struct{})} + if p.dialing == nil { + p.dialing = make(map[string]*http2dialCall) + } + p.dialing[addr] = call + go call.dial(addr) + return call +} + +// run in its own goroutine. +func (c *http2dialCall) dial(addr string) { + c.res, c.err = c.p.t.dialClientConn(addr) + close(c.done) + + c.p.mu.Lock() + delete(c.p.dialing, addr) + if c.err == nil { + c.p.addConnLocked(addr, c.res) + } + c.p.mu.Unlock() +} + +// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't +// already exist. It coalesces concurrent calls with the same key. 
+// This is used by the http1 Transport code when it creates a new connection. Because +// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know +// the protocol), it can get into a situation where it has multiple TLS connections. +// This code decides which ones live or die. +// The return value used is whether c was used. +// c is never closed. +func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c *tls.Conn) (used bool, err error) { + p.mu.Lock() + for _, cc := range p.conns[key] { + if cc.CanTakeNewRequest() { + p.mu.Unlock() + return false, nil + } + } + call, dup := p.addConnCalls[key] + if !dup { + if p.addConnCalls == nil { + p.addConnCalls = make(map[string]*http2addConnCall) + } + call = &http2addConnCall{ + p: p, + done: make(chan struct{}), + } + p.addConnCalls[key] = call + go call.run(t, key, c) + } + p.mu.Unlock() + + <-call.done + if call.err != nil { + return false, call.err + } + return !dup, nil +} + +type http2addConnCall struct { + p *http2clientConnPool + done chan struct{} // closed when done + err error +} + +func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) { + cc, err := t.NewClientConn(tc) + + p := c.p + p.mu.Lock() + if err != nil { + c.err = err + } else { + p.addConnLocked(key, cc) + } + delete(p.addConnCalls, key) + p.mu.Unlock() + close(c.done) +} + +func (p *http2clientConnPool) addConn(key string, cc *http2ClientConn) { + p.mu.Lock() + p.addConnLocked(key, cc) + p.mu.Unlock() +} + +// p.mu must be held +func (p *http2clientConnPool) addConnLocked(key string, cc *http2ClientConn) { + for _, v := range p.conns[key] { + if v == cc { + return + } + } + if p.conns == nil { + p.conns = make(map[string][]*http2ClientConn) + } + if p.keys == nil { + p.keys = make(map[*http2ClientConn][]string) + } + p.conns[key] = append(p.conns[key], cc) + p.keys[cc] = append(p.keys[cc], key) +} + +func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) { + 
p.mu.Lock() + defer p.mu.Unlock() + for _, key := range p.keys[cc] { + vv, ok := p.conns[key] + if !ok { + continue + } + newList := http2filterOutClientConn(vv, cc) + if len(newList) > 0 { + p.conns[key] = newList + } else { + delete(p.conns, key) + } + } + delete(p.keys, cc) +} + +func (p *http2clientConnPool) closeIdleConnections() { + p.mu.Lock() + defer p.mu.Unlock() + + for _, vv := range p.conns { + for _, cc := range vv { + cc.closeIfIdle() + } + } +} + +func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn { + out := in[:0] + for _, v := range in { + if v != exclude { + out = append(out, v) + } + } + + if len(in) != len(out) { + in[len(in)-1] = nil + } + return out +} + +func http2configureTransport(t1 *Transport) (*http2Transport, error) { + connPool := new(http2clientConnPool) + t2 := &http2Transport{ + ConnPool: http2noDialClientConnPool{connPool}, + t1: t1, + } + connPool.t = t2 + if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil { + return nil, err + } + if t1.TLSClientConfig == nil { + t1.TLSClientConfig = new(tls.Config) + } + if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { + t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) 
+ } + if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { + t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") + } + upgradeFn := func(authority string, c *tls.Conn) RoundTripper { + addr := http2authorityAddr(authority) + if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { + go c.Close() + return http2erringRoundTripper{err} + } else if !used { + + go c.Close() + } + return t2 + } + if m := t1.TLSNextProto; len(m) == 0 { + t1.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{ + "h2": upgradeFn, + } + } else { + m["h2"] = upgradeFn + } + return t2, nil +} + +// registerHTTPSProtocol calls Transport.RegisterProtocol but +// convering panics into errors. +func http2registerHTTPSProtocol(t *Transport, rt RoundTripper) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("%v", e) + } + }() + t.RegisterProtocol("https", rt) + return nil +} + +// noDialClientConnPool is an implementation of http2.ClientConnPool +// which never dials. We let the HTTP/1.1 client dial and use its TLS +// connection instead. +type http2noDialClientConnPool struct{ *http2clientConnPool } + +func (p http2noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { + return p.getClientConn(req, addr, http2noDialOnMiss) +} + +// noDialH2RoundTripper is a RoundTripper which only tries to complete the request +// if there's already has a cached connection to the host. +type http2noDialH2RoundTripper struct{ t *http2Transport } + +func (rt http2noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) { + res, err := rt.t.RoundTrip(req) + if err == http2ErrNoCachedConn { + return nil, ErrSkipAltProtocol + } + return res, err +} + +// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. 
+type http2ErrCode uint32 + +const ( + http2ErrCodeNo http2ErrCode = 0x0 + http2ErrCodeProtocol http2ErrCode = 0x1 + http2ErrCodeInternal http2ErrCode = 0x2 + http2ErrCodeFlowControl http2ErrCode = 0x3 + http2ErrCodeSettingsTimeout http2ErrCode = 0x4 + http2ErrCodeStreamClosed http2ErrCode = 0x5 + http2ErrCodeFrameSize http2ErrCode = 0x6 + http2ErrCodeRefusedStream http2ErrCode = 0x7 + http2ErrCodeCancel http2ErrCode = 0x8 + http2ErrCodeCompression http2ErrCode = 0x9 + http2ErrCodeConnect http2ErrCode = 0xa + http2ErrCodeEnhanceYourCalm http2ErrCode = 0xb + http2ErrCodeInadequateSecurity http2ErrCode = 0xc + http2ErrCodeHTTP11Required http2ErrCode = 0xd +) + +var http2errCodeName = map[http2ErrCode]string{ + http2ErrCodeNo: "NO_ERROR", + http2ErrCodeProtocol: "PROTOCOL_ERROR", + http2ErrCodeInternal: "INTERNAL_ERROR", + http2ErrCodeFlowControl: "FLOW_CONTROL_ERROR", + http2ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT", + http2ErrCodeStreamClosed: "STREAM_CLOSED", + http2ErrCodeFrameSize: "FRAME_SIZE_ERROR", + http2ErrCodeRefusedStream: "REFUSED_STREAM", + http2ErrCodeCancel: "CANCEL", + http2ErrCodeCompression: "COMPRESSION_ERROR", + http2ErrCodeConnect: "CONNECT_ERROR", + http2ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM", + http2ErrCodeInadequateSecurity: "INADEQUATE_SECURITY", + http2ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED", +} + +func (e http2ErrCode) String() string { + if s, ok := http2errCodeName[e]; ok { + return s + } + return fmt.Sprintf("unknown error code 0x%x", uint32(e)) +} + +// ConnectionError is an error that results in the termination of the +// entire connection. +type http2ConnectionError http2ErrCode + +func (e http2ConnectionError) Error() string { + return fmt.Sprintf("connection error: %s", http2ErrCode(e)) +} + +// StreamError is an error that only affects one stream within an +// HTTP/2 connection. 
+type http2StreamError struct { + StreamID uint32 + Code http2ErrCode +} + +func (e http2StreamError) Error() string { + return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code) +} + +// 6.9.1 The Flow Control Window +// "If a sender receives a WINDOW_UPDATE that causes a flow control +// window to exceed this maximum it MUST terminate either the stream +// or the connection, as appropriate. For streams, [...]; for the +// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code." +type http2goAwayFlowError struct{} + +func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" } + +// Errors of this type are only returned by the frame parser functions +// and converted into ConnectionError(ErrCodeProtocol). +type http2connError struct { + Code http2ErrCode + Reason string +} + +func (e http2connError) Error() string { + return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason) +} + +// fixedBuffer is an io.ReadWriter backed by a fixed size buffer. +// It never allocates, but moves old data as new data is written. +type http2fixedBuffer struct { + buf []byte + r, w int +} + +var ( + http2errReadEmpty = errors.New("read from empty fixedBuffer") + http2errWriteFull = errors.New("write on full fixedBuffer") +) + +// Read copies bytes from the buffer into p. +// It is an error to read when no data is available. +func (b *http2fixedBuffer) Read(p []byte) (n int, err error) { + if b.r == b.w { + return 0, http2errReadEmpty + } + n = copy(p, b.buf[b.r:b.w]) + b.r += n + if b.r == b.w { + b.r = 0 + b.w = 0 + } + return n, nil +} + +// Len returns the number of bytes of the unread portion of the buffer. +func (b *http2fixedBuffer) Len() int { + return b.w - b.r +} + +// Write copies bytes from p into the buffer. +// It is an error to write more data than the buffer can hold. 
+func (b *http2fixedBuffer) Write(p []byte) (n int, err error) { + + if b.r > 0 && len(p) > len(b.buf)-b.w { + copy(b.buf, b.buf[b.r:b.w]) + b.w -= b.r + b.r = 0 + } + + n = copy(b.buf[b.w:], p) + b.w += n + if n < len(p) { + err = http2errWriteFull + } + return n, err +} + +// flow is the flow control window's size. +type http2flow struct { + // n is the number of DATA bytes we're allowed to send. + // A flow is kept both on a conn and a per-stream. + n int32 + + // conn points to the shared connection-level flow that is + // shared by all streams on that conn. It is nil for the flow + // that's on the conn directly. + conn *http2flow +} + +func (f *http2flow) setConnFlow(cf *http2flow) { f.conn = cf } + +func (f *http2flow) available() int32 { + n := f.n + if f.conn != nil && f.conn.n < n { + n = f.conn.n + } + return n +} + +func (f *http2flow) take(n int32) { + if n > f.available() { + panic("internal error: took too much") + } + f.n -= n + if f.conn != nil { + f.conn.n -= n + } +} + +// add adds n bytes (positive or negative) to the flow control window. +// It returns false if the sum would exceed 2^31-1. 
+func (f *http2flow) add(n int32) bool { + remain := (1<<31 - 1) - f.n + if n > remain { + return false + } + f.n += n + return true +} + +const http2frameHeaderLen = 9 + +var http2padZeros = make([]byte, 255) // zeros for padding + +// A FrameType is a registered frame type as defined in +// http://http2.github.io/http2-spec/#rfc.section.11.2 +type http2FrameType uint8 + +const ( + http2FrameData http2FrameType = 0x0 + http2FrameHeaders http2FrameType = 0x1 + http2FramePriority http2FrameType = 0x2 + http2FrameRSTStream http2FrameType = 0x3 + http2FrameSettings http2FrameType = 0x4 + http2FramePushPromise http2FrameType = 0x5 + http2FramePing http2FrameType = 0x6 + http2FrameGoAway http2FrameType = 0x7 + http2FrameWindowUpdate http2FrameType = 0x8 + http2FrameContinuation http2FrameType = 0x9 +) + +var http2frameName = map[http2FrameType]string{ + http2FrameData: "DATA", + http2FrameHeaders: "HEADERS", + http2FramePriority: "PRIORITY", + http2FrameRSTStream: "RST_STREAM", + http2FrameSettings: "SETTINGS", + http2FramePushPromise: "PUSH_PROMISE", + http2FramePing: "PING", + http2FrameGoAway: "GOAWAY", + http2FrameWindowUpdate: "WINDOW_UPDATE", + http2FrameContinuation: "CONTINUATION", +} + +func (t http2FrameType) String() string { + if s, ok := http2frameName[t]; ok { + return s + } + return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t)) +} + +// Flags is a bitmask of HTTP/2 flags. +// The meaning of flags varies depending on the frame type. +type http2Flags uint8 + +// Has reports whether f contains all (0 or more) flags in v. +func (f http2Flags) Has(v http2Flags) bool { + return (f & v) == v +} + +// Frame-specific FrameHeader flag bits. 
+const ( + // Data Frame + http2FlagDataEndStream http2Flags = 0x1 + http2FlagDataPadded http2Flags = 0x8 + + // Headers Frame + http2FlagHeadersEndStream http2Flags = 0x1 + http2FlagHeadersEndHeaders http2Flags = 0x4 + http2FlagHeadersPadded http2Flags = 0x8 + http2FlagHeadersPriority http2Flags = 0x20 + + // Settings Frame + http2FlagSettingsAck http2Flags = 0x1 + + // Ping Frame + http2FlagPingAck http2Flags = 0x1 + + // Continuation Frame + http2FlagContinuationEndHeaders http2Flags = 0x4 + + http2FlagPushPromiseEndHeaders http2Flags = 0x4 + http2FlagPushPromisePadded http2Flags = 0x8 +) + +var http2flagName = map[http2FrameType]map[http2Flags]string{ + http2FrameData: { + http2FlagDataEndStream: "END_STREAM", + http2FlagDataPadded: "PADDED", + }, + http2FrameHeaders: { + http2FlagHeadersEndStream: "END_STREAM", + http2FlagHeadersEndHeaders: "END_HEADERS", + http2FlagHeadersPadded: "PADDED", + http2FlagHeadersPriority: "PRIORITY", + }, + http2FrameSettings: { + http2FlagSettingsAck: "ACK", + }, + http2FramePing: { + http2FlagPingAck: "ACK", + }, + http2FrameContinuation: { + http2FlagContinuationEndHeaders: "END_HEADERS", + }, + http2FramePushPromise: { + http2FlagPushPromiseEndHeaders: "END_HEADERS", + http2FlagPushPromisePadded: "PADDED", + }, +} + +// a frameParser parses a frame given its FrameHeader and payload +// bytes. The length of payload will always equal fh.Length (which +// might be 0). 
+type http2frameParser func(fh http2FrameHeader, payload []byte) (http2Frame, error) + +var http2frameParsers = map[http2FrameType]http2frameParser{ + http2FrameData: http2parseDataFrame, + http2FrameHeaders: http2parseHeadersFrame, + http2FramePriority: http2parsePriorityFrame, + http2FrameRSTStream: http2parseRSTStreamFrame, + http2FrameSettings: http2parseSettingsFrame, + http2FramePushPromise: http2parsePushPromise, + http2FramePing: http2parsePingFrame, + http2FrameGoAway: http2parseGoAwayFrame, + http2FrameWindowUpdate: http2parseWindowUpdateFrame, + http2FrameContinuation: http2parseContinuationFrame, +} + +func http2typeFrameParser(t http2FrameType) http2frameParser { + if f := http2frameParsers[t]; f != nil { + return f + } + return http2parseUnknownFrame +} + +// A FrameHeader is the 9 byte header of all HTTP/2 frames. +// +// See http://http2.github.io/http2-spec/#FrameHeader +type http2FrameHeader struct { + valid bool // caller can access []byte fields in the Frame + + // Type is the 1 byte frame type. There are ten standard frame + // types, but extension frame types may be written by WriteRawFrame + // and will be returned by ReadFrame (as UnknownFrame). + Type http2FrameType + + // Flags are the 1 byte of 8 potential bit flags per frame. + // They are specific to the frame type. + Flags http2Flags + + // Length is the length of the frame, not including the 9 byte header. + // The maximum size is one byte less than 16MB (uint24), but only + // frames up to 16KB are allowed without peer agreement. + Length uint32 + + // StreamID is which stream this frame is for. Certain frames + // are not stream-specific, in which case this field is 0. + StreamID uint32 +} + +// Header returns h. It exists so FrameHeaders can be embedded in other +// specific frame types and implement the Frame interface. 
+func (h http2FrameHeader) Header() http2FrameHeader { return h } + +func (h http2FrameHeader) String() string { + var buf bytes.Buffer + buf.WriteString("[FrameHeader ") + h.writeDebug(&buf) + buf.WriteByte(']') + return buf.String() +} + +func (h http2FrameHeader) writeDebug(buf *bytes.Buffer) { + buf.WriteString(h.Type.String()) + if h.Flags != 0 { + buf.WriteString(" flags=") + set := 0 + for i := uint8(0); i < 8; i++ { + if h.Flags&(1<<i) == 0 { + continue + } + set++ + if set > 1 { + buf.WriteByte('|') + } + name := http2flagName[h.Type][http2Flags(1<<i)] + if name != "" { + buf.WriteString(name) + } else { + fmt.Fprintf(buf, "0x%x", 1<<i) + } + } + } + if h.StreamID != 0 { + fmt.Fprintf(buf, " stream=%d", h.StreamID) + } + fmt.Fprintf(buf, " len=%d", h.Length) +} + +func (h *http2FrameHeader) checkValid() { + if !h.valid { + panic("Frame accessor called on non-owned Frame") + } +} + +func (h *http2FrameHeader) invalidate() { h.valid = false } + +// frame header bytes. +// Used only by ReadFrameHeader. +var http2fhBytes = sync.Pool{ + New: func() interface{} { + buf := make([]byte, http2frameHeaderLen) + return &buf + }, +} + +// ReadFrameHeader reads 9 bytes from r and returns a FrameHeader. +// Most users should use Framer.ReadFrame instead. +func http2ReadFrameHeader(r io.Reader) (http2FrameHeader, error) { + bufp := http2fhBytes.Get().(*[]byte) + defer http2fhBytes.Put(bufp) + return http2readFrameHeader(*bufp, r) +} + +func http2readFrameHeader(buf []byte, r io.Reader) (http2FrameHeader, error) { + _, err := io.ReadFull(r, buf[:http2frameHeaderLen]) + if err != nil { + return http2FrameHeader{}, err + } + return http2FrameHeader{ + Length: (uint32(buf[0])<<16 | uint32(buf[1])<<8 | uint32(buf[2])), + Type: http2FrameType(buf[3]), + Flags: http2Flags(buf[4]), + StreamID: binary.BigEndian.Uint32(buf[5:]) & (1<<31 - 1), + valid: true, + }, nil +} + +// A Frame is the base interface implemented by all frame types. 
+// Callers will generally type-assert the specific frame type: +// *HeadersFrame, *SettingsFrame, *WindowUpdateFrame, etc. +// +// Frames are only valid until the next call to Framer.ReadFrame. +type http2Frame interface { + Header() http2FrameHeader + + // invalidate is called by Framer.ReadFrame to make this + // frame's buffers as being invalid, since the subsequent + // frame will reuse them. + invalidate() +} + +// A Framer reads and writes Frames. +type http2Framer struct { + r io.Reader + lastFrame http2Frame + errReason string + + // lastHeaderStream is non-zero if the last frame was an + // unfinished HEADERS/CONTINUATION. + lastHeaderStream uint32 + + maxReadSize uint32 + headerBuf [http2frameHeaderLen]byte + + // TODO: let getReadBuf be configurable, and use a less memory-pinning + // allocator in server.go to minimize memory pinned for many idle conns. + // Will probably also need to make frame invalidation have a hook too. + getReadBuf func(size uint32) []byte + readBuf []byte // cache for default getReadBuf + + maxWriteSize uint32 // zero means unlimited; TODO: implement + + w io.Writer + wbuf []byte + + // AllowIllegalWrites permits the Framer's Write methods to + // write frames that do not conform to the HTTP/2 spec. This + // permits using the Framer to test other HTTP/2 + // implementations' conformance to the spec. + // If false, the Write methods will prefer to return an error + // rather than comply. + AllowIllegalWrites bool + + // AllowIllegalReads permits the Framer's ReadFrame method + // to return non-compliant frames or frame orders. + // This is for testing and permits using the Framer to test + // other HTTP/2 implementations' conformance to the spec. 
+ AllowIllegalReads bool + + logReads bool + + debugFramer *http2Framer // only use for logging written writes + debugFramerBuf *bytes.Buffer +} + +func (f *http2Framer) startWrite(ftype http2FrameType, flags http2Flags, streamID uint32) { + + f.wbuf = append(f.wbuf[:0], + 0, + 0, + 0, + byte(ftype), + byte(flags), + byte(streamID>>24), + byte(streamID>>16), + byte(streamID>>8), + byte(streamID)) +} + +func (f *http2Framer) endWrite() error { + + length := len(f.wbuf) - http2frameHeaderLen + if length >= (1 << 24) { + return http2ErrFrameTooLarge + } + _ = append(f.wbuf[:0], + byte(length>>16), + byte(length>>8), + byte(length)) + if http2logFrameWrites { + f.logWrite() + } + + n, err := f.w.Write(f.wbuf) + if err == nil && n != len(f.wbuf) { + err = io.ErrShortWrite + } + return err +} + +func (f *http2Framer) logWrite() { + if f.debugFramer == nil { + f.debugFramerBuf = new(bytes.Buffer) + f.debugFramer = http2NewFramer(nil, f.debugFramerBuf) + f.debugFramer.logReads = false + + f.debugFramer.AllowIllegalReads = true + } + f.debugFramerBuf.Write(f.wbuf) + fr, err := f.debugFramer.ReadFrame() + if err != nil { + log.Printf("http2: Framer %p: failed to decode just-written frame", f) + return + } + log.Printf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr)) +} + +func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) } + +func (f *http2Framer) writeBytes(v []byte) { f.wbuf = append(f.wbuf, v...) } + +func (f *http2Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) } + +func (f *http2Framer) writeUint32(v uint32) { + f.wbuf = append(f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +const ( + http2minMaxFrameSize = 1 << 14 + http2maxFrameSize = 1<<24 - 1 +) + +// NewFramer returns a Framer that writes frames to w and reads them from r. 
+func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { + fr := &http2Framer{ + w: w, + r: r, + logReads: http2logFrameReads, + } + fr.getReadBuf = func(size uint32) []byte { + if cap(fr.readBuf) >= int(size) { + return fr.readBuf[:size] + } + fr.readBuf = make([]byte, size) + return fr.readBuf + } + fr.SetMaxReadFrameSize(http2maxFrameSize) + return fr +} + +// SetMaxReadFrameSize sets the maximum size of a frame +// that will be read by a subsequent call to ReadFrame. +// It is the caller's responsibility to advertise this +// limit with a SETTINGS frame. +func (fr *http2Framer) SetMaxReadFrameSize(v uint32) { + if v > http2maxFrameSize { + v = http2maxFrameSize + } + fr.maxReadSize = v +} + +// ErrFrameTooLarge is returned from Framer.ReadFrame when the peer +// sends a frame that is larger than declared with SetMaxReadFrameSize. +var http2ErrFrameTooLarge = errors.New("http2: frame too large") + +// terminalReadFrameError reports whether err is an unrecoverable +// error from ReadFrame and no other frames should be read. +func http2terminalReadFrameError(err error) bool { + if _, ok := err.(http2StreamError); ok { + return false + } + return err != nil +} + +// ReadFrame reads a single frame. The returned Frame is only valid +// until the next call to ReadFrame. +// +// If the frame is larger than previously set with SetMaxReadFrameSize, the +// returned error is ErrFrameTooLarge. Other errors may be of type +// ConnectionError, StreamError, or anything else from from the underlying +// reader. 
+func (fr *http2Framer) ReadFrame() (http2Frame, error) { + if fr.lastFrame != nil { + fr.lastFrame.invalidate() + } + fh, err := http2readFrameHeader(fr.headerBuf[:], fr.r) + if err != nil { + return nil, err + } + if fh.Length > fr.maxReadSize { + return nil, http2ErrFrameTooLarge + } + payload := fr.getReadBuf(fh.Length) + if _, err := io.ReadFull(fr.r, payload); err != nil { + return nil, err + } + f, err := http2typeFrameParser(fh.Type)(fh, payload) + if err != nil { + if ce, ok := err.(http2connError); ok { + return nil, fr.connError(ce.Code, ce.Reason) + } + return nil, err + } + if err := fr.checkFrameOrder(f); err != nil { + return nil, err + } + if fr.logReads { + log.Printf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) + } + return f, nil +} + +// connError returns ConnectionError(code) but first +// stashes away a public reason to the caller can optionally relay it +// to the peer before hanging up on them. This might help others debug +// their implementations. +func (fr *http2Framer) connError(code http2ErrCode, reason string) error { + fr.errReason = reason + return http2ConnectionError(code) +} + +// checkFrameOrder reports an error if f is an invalid frame to return +// next from ReadFrame. Mostly it checks whether HEADERS and +// CONTINUATION frames are contiguous. 
+func (fr *http2Framer) checkFrameOrder(f http2Frame) error { + last := fr.lastFrame + fr.lastFrame = f + if fr.AllowIllegalReads { + return nil + } + + fh := f.Header() + if fr.lastHeaderStream != 0 { + if fh.Type != http2FrameContinuation { + return fr.connError(http2ErrCodeProtocol, + fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d", + fh.Type, fh.StreamID, + last.Header().Type, fr.lastHeaderStream)) + } + if fh.StreamID != fr.lastHeaderStream { + return fr.connError(http2ErrCodeProtocol, + fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d", + fh.StreamID, fr.lastHeaderStream)) + } + } else if fh.Type == http2FrameContinuation { + return fr.connError(http2ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID)) + } + + switch fh.Type { + case http2FrameHeaders, http2FrameContinuation: + if fh.Flags.Has(http2FlagHeadersEndHeaders) { + fr.lastHeaderStream = 0 + } else { + fr.lastHeaderStream = fh.StreamID + } + } + + return nil +} + +// A DataFrame conveys arbitrary, variable-length sequences of octets +// associated with a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.1 +type http2DataFrame struct { + http2FrameHeader + data []byte +} + +func (f *http2DataFrame) StreamEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagDataEndStream) +} + +// Data returns the frame's data octets, not including any padding +// size byte or padding suffix bytes. +// The caller must not retain the returned memory past the next +// call to ReadFrame. 
+func (f *http2DataFrame) Data() []byte { + f.checkValid() + return f.data +} + +func http2parseDataFrame(fh http2FrameHeader, payload []byte) (http2Frame, error) { + if fh.StreamID == 0 { + + return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"} + } + f := &http2DataFrame{ + http2FrameHeader: fh, + } + var padSize byte + if fh.Flags.Has(http2FlagDataPadded) { + var err error + payload, padSize, err = http2readByte(payload) + if err != nil { + return nil, err + } + } + if int(padSize) > len(payload) { + + return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"} + } + f.data = payload[:len(payload)-int(padSize)] + return f, nil +} + +var http2errStreamID = errors.New("invalid streamid") + +func http2validStreamID(streamID uint32) bool { + return streamID != 0 && streamID&(1<<31) == 0 +} + +// WriteData writes a DATA frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error { + + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + var flags http2Flags + if endStream { + flags |= http2FlagDataEndStream + } + f.startWrite(http2FrameData, flags, streamID) + f.wbuf = append(f.wbuf, data...) + return f.endWrite() +} + +// A SettingsFrame conveys configuration parameters that affect how +// endpoints communicate, such as preferences and constraints on peer +// behavior. 
+// +// See http://http2.github.io/http2-spec/#SETTINGS +type http2SettingsFrame struct { + http2FrameHeader + p []byte +} + +func http2parseSettingsFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { + if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 { + + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + if fh.StreamID != 0 { + + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + if len(p)%6 != 0 { + + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + f := &http2SettingsFrame{http2FrameHeader: fh, p: p} + if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 { + + return nil, http2ConnectionError(http2ErrCodeFlowControl) + } + return f, nil +} + +func (f *http2SettingsFrame) IsAck() bool { + return f.http2FrameHeader.Flags.Has(http2FlagSettingsAck) +} + +func (f *http2SettingsFrame) Value(s http2SettingID) (v uint32, ok bool) { + f.checkValid() + buf := f.p + for len(buf) > 0 { + settingID := http2SettingID(binary.BigEndian.Uint16(buf[:2])) + if settingID == s { + return binary.BigEndian.Uint32(buf[2:6]), true + } + buf = buf[6:] + } + return 0, false +} + +// ForeachSetting runs fn for each setting. +// It stops and returns the first error. +func (f *http2SettingsFrame) ForeachSetting(fn func(http2Setting) error) error { + f.checkValid() + buf := f.p + for len(buf) > 0 { + if err := fn(http2Setting{ + http2SettingID(binary.BigEndian.Uint16(buf[:2])), + binary.BigEndian.Uint32(buf[2:6]), + }); err != nil { + return err + } + buf = buf[6:] + } + return nil +} + +// WriteSettings writes a SETTINGS frame with zero or more settings +// specified and the ACK bit not set. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. 
+func (f *http2Framer) WriteSettings(settings ...http2Setting) error {
+	f.startWrite(http2FrameSettings, 0, 0)
+	for _, s := range settings {
+		f.writeUint16(uint16(s.ID))
+		f.writeUint32(s.Val)
+	}
+	return f.endWrite()
+}
+
+// WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WriteSettingsAck() error {
+	f.startWrite(http2FrameSettings, http2FlagSettingsAck, 0)
+	return f.endWrite()
+}
+
+// A PingFrame is a mechanism for measuring a minimal round trip time
+// from the sender, as well as determining whether an idle connection
+// is still functional.
+// See http://http2.github.io/http2-spec/#rfc.section.6.7
+type http2PingFrame struct {
+	http2FrameHeader
+	Data [8]byte
+}
+
+func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) }
+
+func http2parsePingFrame(fh http2FrameHeader, payload []byte) (http2Frame, error) {
+	if len(payload) != 8 {
+		return nil, http2ConnectionError(http2ErrCodeFrameSize)
+	}
+	if fh.StreamID != 0 {
+		return nil, http2ConnectionError(http2ErrCodeProtocol)
+	}
+	f := &http2PingFrame{http2FrameHeader: fh}
+	copy(f.Data[:], payload)
+	return f, nil
+}
+
+func (f *http2Framer) WritePing(ack bool, data [8]byte) error {
+	var flags http2Flags
+	if ack {
+		flags = http2FlagPingAck
+	}
+	f.startWrite(http2FramePing, flags, 0)
+	f.writeBytes(data[:])
+	return f.endWrite()
+}
+
+// A GoAwayFrame informs the remote peer to stop creating streams on this connection.
+// See http://http2.github.io/http2-spec/#rfc.section.6.8
+type http2GoAwayFrame struct {
+	http2FrameHeader
+	LastStreamID uint32
+	ErrCode      http2ErrCode
+	debugData    []byte
+}
+
+// DebugData returns any debug data in the GOAWAY frame. Its contents
+// are not defined.
+// The caller must not retain the returned memory past the next
+// call to ReadFrame.
+func (f *http2GoAwayFrame) DebugData() []byte { + f.checkValid() + return f.debugData +} + +func http2parseGoAwayFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { + if fh.StreamID != 0 { + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + if len(p) < 8 { + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + return &http2GoAwayFrame{ + http2FrameHeader: fh, + LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1), + ErrCode: http2ErrCode(binary.BigEndian.Uint32(p[4:8])), + debugData: p[8:], + }, nil +} + +func (f *http2Framer) WriteGoAway(maxStreamID uint32, code http2ErrCode, debugData []byte) error { + f.startWrite(http2FrameGoAway, 0, 0) + f.writeUint32(maxStreamID & (1<<31 - 1)) + f.writeUint32(uint32(code)) + f.writeBytes(debugData) + return f.endWrite() +} + +// An UnknownFrame is the frame type returned when the frame type is unknown +// or no specific frame type parser exists. +type http2UnknownFrame struct { + http2FrameHeader + p []byte +} + +// Payload returns the frame's payload (after the header). It is not +// valid to call this method after a subsequent call to +// Framer.ReadFrame, nor is it valid to retain the returned slice. +// The memory is owned by the Framer and is invalidated when the next +// frame is read. +func (f *http2UnknownFrame) Payload() []byte { + f.checkValid() + return f.p +} + +func http2parseUnknownFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { + return &http2UnknownFrame{fh, p}, nil +} + +// A WindowUpdateFrame is used to implement flow control. 
+// See http://http2.github.io/http2-spec/#rfc.section.6.9 +type http2WindowUpdateFrame struct { + http2FrameHeader + Increment uint32 // never read with high bit set +} + +func http2parseWindowUpdateFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { + if len(p) != 4 { + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff + if inc == 0 { + + if fh.StreamID == 0 { + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + return nil, http2StreamError{fh.StreamID, http2ErrCodeProtocol} + } + return &http2WindowUpdateFrame{ + http2FrameHeader: fh, + Increment: inc, + }, nil +} + +// WriteWindowUpdate writes a WINDOW_UPDATE frame. +// The increment value must be between 1 and 2,147,483,647, inclusive. +// If the Stream ID is zero, the window update applies to the +// connection as a whole. +func (f *http2Framer) WriteWindowUpdate(streamID, incr uint32) error { + + if (incr < 1 || incr > 2147483647) && !f.AllowIllegalWrites { + return errors.New("illegal window increment value") + } + f.startWrite(http2FrameWindowUpdate, 0, streamID) + f.writeUint32(incr) + return f.endWrite() +} + +// A HeadersFrame is used to open a stream and additionally carries a +// header block fragment. +type http2HeadersFrame struct { + http2FrameHeader + + // Priority is set if FlagHeadersPriority is set in the FrameHeader. 
+ Priority http2PriorityParam + + headerFragBuf []byte // not owned +} + +func (f *http2HeadersFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *http2HeadersFrame) HeadersEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndHeaders) +} + +func (f *http2HeadersFrame) StreamEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndStream) +} + +func (f *http2HeadersFrame) HasPriority() bool { + return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority) +} + +func http2parseHeadersFrame(fh http2FrameHeader, p []byte) (_ http2Frame, err error) { + hf := &http2HeadersFrame{ + http2FrameHeader: fh, + } + if fh.StreamID == 0 { + + return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"} + } + var padLength uint8 + if fh.Flags.Has(http2FlagHeadersPadded) { + if p, padLength, err = http2readByte(p); err != nil { + return + } + } + if fh.Flags.Has(http2FlagHeadersPriority) { + var v uint32 + p, v, err = http2readUint32(p) + if err != nil { + return nil, err + } + hf.Priority.StreamDep = v & 0x7fffffff + hf.Priority.Exclusive = (v != hf.Priority.StreamDep) + p, hf.Priority.Weight, err = http2readByte(p) + if err != nil { + return nil, err + } + } + if len(p)-int(padLength) <= 0 { + return nil, http2StreamError{fh.StreamID, http2ErrCodeProtocol} + } + hf.headerFragBuf = p[:len(p)-int(padLength)] + return hf, nil +} + +// HeadersFrameParam are the parameters for writing a HEADERS frame. +type http2HeadersFrameParam struct { + // StreamID is the required Stream ID to initiate. + StreamID uint32 + // BlockFragment is part (or all) of a Header Block. + BlockFragment []byte + + // EndStream indicates that the header block is the last that + // the endpoint will send for the identified stream. Setting + // this flag causes the stream to enter one of "half closed" + // states. 
+ EndStream bool + + // EndHeaders indicates that this frame contains an entire + // header block and is not followed by any + // CONTINUATION frames. + EndHeaders bool + + // PadLength is the optional number of bytes of zeros to add + // to this frame. + PadLength uint8 + + // Priority, if non-zero, includes stream priority information + // in the HEADER frame. + Priority http2PriorityParam +} + +// WriteHeaders writes a single HEADERS frame. +// +// This is a low-level header writing method. Encoding headers and +// splitting them into any necessary CONTINUATION frames is handled +// elsewhere. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { + if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + var flags http2Flags + if p.PadLength != 0 { + flags |= http2FlagHeadersPadded + } + if p.EndStream { + flags |= http2FlagHeadersEndStream + } + if p.EndHeaders { + flags |= http2FlagHeadersEndHeaders + } + if !p.Priority.IsZero() { + flags |= http2FlagHeadersPriority + } + f.startWrite(http2FrameHeaders, flags, p.StreamID) + if p.PadLength != 0 { + f.writeByte(p.PadLength) + } + if !p.Priority.IsZero() { + v := p.Priority.StreamDep + if !http2validStreamID(v) && !f.AllowIllegalWrites { + return errors.New("invalid dependent stream id") + } + if p.Priority.Exclusive { + v |= 1 << 31 + } + f.writeUint32(v) + f.writeByte(p.Priority.Weight) + } + f.wbuf = append(f.wbuf, p.BlockFragment...) + f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...) + return f.endWrite() +} + +// A PriorityFrame specifies the sender-advised priority of a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.3 +type http2PriorityFrame struct { + http2FrameHeader + http2PriorityParam +} + +// PriorityParam are the stream prioritzation parameters. 
+type http2PriorityParam struct { + // StreamDep is a 31-bit stream identifier for the + // stream that this stream depends on. Zero means no + // dependency. + StreamDep uint32 + + // Exclusive is whether the dependency is exclusive. + Exclusive bool + + // Weight is the stream's zero-indexed weight. It should be + // set together with StreamDep, or neither should be set. Per + // the spec, "Add one to the value to obtain a weight between + // 1 and 256." + Weight uint8 +} + +func (p http2PriorityParam) IsZero() bool { + return p == http2PriorityParam{} +} + +func http2parsePriorityFrame(fh http2FrameHeader, payload []byte) (http2Frame, error) { + if fh.StreamID == 0 { + return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"} + } + if len(payload) != 5 { + return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))} + } + v := binary.BigEndian.Uint32(payload[:4]) + streamID := v & 0x7fffffff + return &http2PriorityFrame{ + http2FrameHeader: fh, + http2PriorityParam: http2PriorityParam{ + Weight: payload[4], + StreamDep: streamID, + Exclusive: streamID != v, + }, + }, nil +} + +// WritePriority writes a PRIORITY frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + f.startWrite(http2FramePriority, 0, streamID) + v := p.StreamDep + if p.Exclusive { + v |= 1 << 31 + } + f.writeUint32(v) + f.writeByte(p.Weight) + return f.endWrite() +} + +// A RSTStreamFrame allows for abnormal termination of a stream. 
+// See http://http2.github.io/http2-spec/#rfc.section.6.4 +type http2RSTStreamFrame struct { + http2FrameHeader + ErrCode http2ErrCode +} + +func http2parseRSTStreamFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { + if len(p) != 4 { + return nil, http2ConnectionError(http2ErrCodeFrameSize) + } + if fh.StreamID == 0 { + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + return &http2RSTStreamFrame{fh, http2ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil +} + +// WriteRSTStream writes a RST_STREAM frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + f.startWrite(http2FrameRSTStream, 0, streamID) + f.writeUint32(uint32(code)) + return f.endWrite() +} + +// A ContinuationFrame is used to continue a sequence of header block fragments. +// See http://http2.github.io/http2-spec/#rfc.section.6.10 +type http2ContinuationFrame struct { + http2FrameHeader + headerFragBuf []byte +} + +func http2parseContinuationFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { + if fh.StreamID == 0 { + return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"} + } + return &http2ContinuationFrame{fh, p}, nil +} + +func (f *http2ContinuationFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *http2ContinuationFrame) HeadersEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagContinuationEndHeaders) +} + +// WriteContinuation writes a CONTINUATION frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. 
+func (f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error { + if !http2validStreamID(streamID) && !f.AllowIllegalWrites { + return http2errStreamID + } + var flags http2Flags + if endHeaders { + flags |= http2FlagContinuationEndHeaders + } + f.startWrite(http2FrameContinuation, flags, streamID) + f.wbuf = append(f.wbuf, headerBlockFragment...) + return f.endWrite() +} + +// A PushPromiseFrame is used to initiate a server stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.6 +type http2PushPromiseFrame struct { + http2FrameHeader + PromiseID uint32 + headerFragBuf []byte // not owned +} + +func (f *http2PushPromiseFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *http2PushPromiseFrame) HeadersEnded() bool { + return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders) +} + +func http2parsePushPromise(fh http2FrameHeader, p []byte) (_ http2Frame, err error) { + pp := &http2PushPromiseFrame{ + http2FrameHeader: fh, + } + if pp.StreamID == 0 { + + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + // The PUSH_PROMISE frame includes optional padding. + // Padding fields and flags are identical to those defined for DATA frames + var padLength uint8 + if fh.Flags.Has(http2FlagPushPromisePadded) { + if p, padLength, err = http2readByte(p); err != nil { + return + } + } + + p, pp.PromiseID, err = http2readUint32(p) + if err != nil { + return + } + pp.PromiseID = pp.PromiseID & (1<<31 - 1) + + if int(padLength) > len(p) { + + return nil, http2ConnectionError(http2ErrCodeProtocol) + } + pp.headerFragBuf = p[:len(p)-int(padLength)] + return pp, nil +} + +// PushPromiseParam are the parameters for writing a PUSH_PROMISE frame. +type http2PushPromiseParam struct { + // StreamID is the required Stream ID to initiate. 
+	StreamID uint32
+
+	// PromiseID is the required Stream ID that this
+	// Push Promise is for.
+	PromiseID uint32
+
+	// BlockFragment is part (or all) of a Header Block.
+	BlockFragment []byte
+
+	// EndHeaders indicates that this frame contains an entire
+	// header block and is not followed by any
+	// CONTINUATION frames.
+	EndHeaders bool
+
+	// PadLength is the optional number of bytes of zeros to add
+	// to this frame.
+	PadLength uint8
+}
+
+// WritePushPromise writes a single PushPromise Frame.
+//
+// As with Header Frames, this is the low-level call for writing
+// individual frames. Continuation frames are handled elsewhere.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WritePushPromise(p http2PushPromiseParam) error {
+	if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites {
+		return http2errStreamID
+	}
+	var flags http2Flags
+	if p.PadLength != 0 {
+		flags |= http2FlagPushPromisePadded
+	}
+	if p.EndHeaders {
+		flags |= http2FlagPushPromiseEndHeaders
+	}
+	f.startWrite(http2FramePushPromise, flags, p.StreamID)
+	if p.PadLength != 0 {
+		f.writeByte(p.PadLength)
+	}
+	if !http2validStreamID(p.PromiseID) && !f.AllowIllegalWrites {
+		return http2errStreamID
+	}
+	f.writeUint32(p.PromiseID)
+	f.wbuf = append(f.wbuf, p.BlockFragment...)
+	f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...)
+	return f.endWrite()
+}
+
+// WriteRawFrame writes a raw frame. This can be used to write
+// extension frames unknown to this package.
+func (f *http2Framer) WriteRawFrame(t http2FrameType, flags http2Flags, streamID uint32, payload []byte) error { + f.startWrite(t, flags, streamID) + f.writeBytes(payload) + return f.endWrite() +} + +func http2readByte(p []byte) (remain []byte, b byte, err error) { + if len(p) == 0 { + return nil, 0, io.ErrUnexpectedEOF + } + return p[1:], p[0], nil +} + +func http2readUint32(p []byte) (remain []byte, v uint32, err error) { + if len(p) < 4 { + return nil, 0, io.ErrUnexpectedEOF + } + return p[4:], binary.BigEndian.Uint32(p[:4]), nil +} + +type http2streamEnder interface { + StreamEnded() bool +} + +type http2headersEnder interface { + HeadersEnded() bool +} + +func http2summarizeFrame(f http2Frame) string { + var buf bytes.Buffer + f.Header().writeDebug(&buf) + switch f := f.(type) { + case *http2SettingsFrame: + n := 0 + f.ForeachSetting(func(s http2Setting) error { + n++ + if n == 1 { + buf.WriteString(", settings:") + } + fmt.Fprintf(&buf, " %v=%v,", s.ID, s.Val) + return nil + }) + if n > 0 { + buf.Truncate(buf.Len() - 1) + } + case *http2DataFrame: + data := f.Data() + const max = 256 + if len(data) > max { + data = data[:max] + } + fmt.Fprintf(&buf, " data=%q", data) + if len(f.Data()) > max { + fmt.Fprintf(&buf, " (%d bytes omitted)", len(f.Data())-max) + } + case *http2WindowUpdateFrame: + if f.StreamID == 0 { + buf.WriteString(" (conn)") + } + fmt.Fprintf(&buf, " incr=%v", f.Increment) + case *http2PingFrame: + fmt.Fprintf(&buf, " ping=%q", f.Data[:]) + case *http2GoAwayFrame: + fmt.Fprintf(&buf, " LastStreamID=%v ErrCode=%v Debug=%q", + f.LastStreamID, f.ErrCode, f.debugData) + case *http2RSTStreamFrame: + fmt.Fprintf(&buf, " ErrCode=%v", f.ErrCode) + } + return buf.String() +} + +func http2requestCancel(req *Request) <-chan struct{} { return req.Cancel } + +var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" + +type http2goroutineLock uint64 + +func http2newGoroutineLock() http2goroutineLock { + if !http2DebugGoroutines { + return 0 + 
} + return http2goroutineLock(http2curGoroutineID()) +} + +func (g http2goroutineLock) check() { + if !http2DebugGoroutines { + return + } + if http2curGoroutineID() != uint64(g) { + panic("running on the wrong goroutine") + } +} + +func (g http2goroutineLock) checkNotOn() { + if !http2DebugGoroutines { + return + } + if http2curGoroutineID() == uint64(g) { + panic("running on the wrong goroutine") + } +} + +var http2goroutineSpace = []byte("goroutine ") + +func http2curGoroutineID() uint64 { + bp := http2littleBuf.Get().(*[]byte) + defer http2littleBuf.Put(bp) + b := *bp + b = b[:runtime.Stack(b, false)] + + b = bytes.TrimPrefix(b, http2goroutineSpace) + i := bytes.IndexByte(b, ' ') + if i < 0 { + panic(fmt.Sprintf("No space found in %q", b)) + } + b = b[:i] + n, err := http2parseUintBytes(b, 10, 64) + if err != nil { + panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err)) + } + return n +} + +var http2littleBuf = sync.Pool{ + New: func() interface{} { + buf := make([]byte, 64) + return &buf + }, +} + +// parseUintBytes is like strconv.ParseUint, but using a []byte. 
+func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) { + var cutoff, maxVal uint64 + + if bitSize == 0 { + bitSize = int(strconv.IntSize) + } + + s0 := s + switch { + case len(s) < 1: + err = strconv.ErrSyntax + goto Error + + case 2 <= base && base <= 36: + + case base == 0: + + switch { + case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'): + base = 16 + s = s[2:] + if len(s) < 1 { + err = strconv.ErrSyntax + goto Error + } + case s[0] == '0': + base = 8 + default: + base = 10 + } + + default: + err = errors.New("invalid base " + strconv.Itoa(base)) + goto Error + } + + n = 0 + cutoff = http2cutoff64(base) + maxVal = 1<<uint(bitSize) - 1 + + for i := 0; i < len(s); i++ { + var v byte + d := s[i] + switch { + case '0' <= d && d <= '9': + v = d - '0' + case 'a' <= d && d <= 'z': + v = d - 'a' + 10 + case 'A' <= d && d <= 'Z': + v = d - 'A' + 10 + default: + n = 0 + err = strconv.ErrSyntax + goto Error + } + if int(v) >= base { + n = 0 + err = strconv.ErrSyntax + goto Error + } + + if n >= cutoff { + + n = 1<<64 - 1 + err = strconv.ErrRange + goto Error + } + n *= uint64(base) + + n1 := n + uint64(v) + if n1 < n || n1 > maxVal { + + n = 1<<64 - 1 + err = strconv.ErrRange + goto Error + } + n = n1 + } + + return n, nil + +Error: + return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err} +} + +// Return the first number n such that n*base >= 1<<64. 
+func http2cutoff64(base int) uint64 { + if base < 2 { + return 0 + } + return (1<<64-1)/uint64(base) + 1 +} + +var ( + http2commonLowerHeader = map[string]string{} // Go-Canonical-Case -> lower-case + http2commonCanonHeader = map[string]string{} // lower-case -> Go-Canonical-Case +) + +func init() { + for _, v := range []string{ + "accept", + "accept-charset", + "accept-encoding", + "accept-language", + "accept-ranges", + "age", + "access-control-allow-origin", + "allow", + "authorization", + "cache-control", + "content-disposition", + "content-encoding", + "content-language", + "content-length", + "content-location", + "content-range", + "content-type", + "cookie", + "date", + "etag", + "expect", + "expires", + "from", + "host", + "if-match", + "if-modified-since", + "if-none-match", + "if-unmodified-since", + "last-modified", + "link", + "location", + "max-forwards", + "proxy-authenticate", + "proxy-authorization", + "range", + "referer", + "refresh", + "retry-after", + "server", + "set-cookie", + "strict-transport-security", + "trailer", + "transfer-encoding", + "user-agent", + "vary", + "via", + "www-authenticate", + } { + chk := CanonicalHeaderKey(v) + http2commonLowerHeader[chk] = v + http2commonCanonHeader[v] = chk + } +} + +func http2lowerHeader(v string) string { + if s, ok := http2commonLowerHeader[v]; ok { + return s + } + return strings.ToLower(v) +} + +var ( + http2VerboseLogs bool + http2logFrameWrites bool + http2logFrameReads bool +) + +func init() { + e := os.Getenv("GODEBUG") + if strings.Contains(e, "http2debug=1") { + http2VerboseLogs = true + } + if strings.Contains(e, "http2debug=2") { + http2VerboseLogs = true + http2logFrameWrites = true + http2logFrameReads = true + } +} + +const ( + // ClientPreface is the string that must be sent by new + // connections from clients. 
+ http2ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + + // SETTINGS_MAX_FRAME_SIZE default + // http://http2.github.io/http2-spec/#rfc.section.6.5.2 + http2initialMaxFrameSize = 16384 + + // NextProtoTLS is the NPN/ALPN protocol negotiated during + // HTTP/2's TLS setup. + http2NextProtoTLS = "h2" + + // http://http2.github.io/http2-spec/#SettingValues + http2initialHeaderTableSize = 4096 + + http2initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size + + http2defaultMaxReadFrameSize = 1 << 20 +) + +var ( + http2clientPreface = []byte(http2ClientPreface) +) + +type http2streamState int + +const ( + http2stateIdle http2streamState = iota + http2stateOpen + http2stateHalfClosedLocal + http2stateHalfClosedRemote + http2stateResvLocal + http2stateResvRemote + http2stateClosed +) + +var http2stateName = [...]string{ + http2stateIdle: "Idle", + http2stateOpen: "Open", + http2stateHalfClosedLocal: "HalfClosedLocal", + http2stateHalfClosedRemote: "HalfClosedRemote", + http2stateResvLocal: "ResvLocal", + http2stateResvRemote: "ResvRemote", + http2stateClosed: "Closed", +} + +func (st http2streamState) String() string { + return http2stateName[st] +} + +// Setting is a setting parameter: which setting it is, and its value. +type http2Setting struct { + // ID is which setting is being set. + // See http://http2.github.io/http2-spec/#SettingValues + ID http2SettingID + + // Val is the value. + Val uint32 +} + +func (s http2Setting) String() string { + return fmt.Sprintf("[%v = %d]", s.ID, s.Val) +} + +// Valid reports whether the setting is valid. 
+func (s http2Setting) Valid() error { + + switch s.ID { + case http2SettingEnablePush: + if s.Val != 1 && s.Val != 0 { + return http2ConnectionError(http2ErrCodeProtocol) + } + case http2SettingInitialWindowSize: + if s.Val > 1<<31-1 { + return http2ConnectionError(http2ErrCodeFlowControl) + } + case http2SettingMaxFrameSize: + if s.Val < 16384 || s.Val > 1<<24-1 { + return http2ConnectionError(http2ErrCodeProtocol) + } + } + return nil +} + +// A SettingID is an HTTP/2 setting as defined in +// http://http2.github.io/http2-spec/#iana-settings +type http2SettingID uint16 + +const ( + http2SettingHeaderTableSize http2SettingID = 0x1 + http2SettingEnablePush http2SettingID = 0x2 + http2SettingMaxConcurrentStreams http2SettingID = 0x3 + http2SettingInitialWindowSize http2SettingID = 0x4 + http2SettingMaxFrameSize http2SettingID = 0x5 + http2SettingMaxHeaderListSize http2SettingID = 0x6 +) + +var http2settingName = map[http2SettingID]string{ + http2SettingHeaderTableSize: "HEADER_TABLE_SIZE", + http2SettingEnablePush: "ENABLE_PUSH", + http2SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", + http2SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", + http2SettingMaxFrameSize: "MAX_FRAME_SIZE", + http2SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", +} + +func (s http2SettingID) String() string { + if v, ok := http2settingName[s]; ok { + return v + } + return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s)) +} + +var ( + http2errInvalidHeaderFieldName = errors.New("http2: invalid header field name") + http2errInvalidHeaderFieldValue = errors.New("http2: invalid header field value") +) + +// validHeaderFieldName reports whether v is a valid header field name (key). +// RFC 7230 says: +// header-field = field-name ":" OWS field-value OWS +// field-name = token +// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." 
/ +// "^" / "_" / " +// Further, http2 says: +// "Just as in HTTP/1.x, header field names are strings of ASCII +// characters that are compared in a case-insensitive +// fashion. However, header field names MUST be converted to +// lowercase prior to their encoding in HTTP/2. " +func http2validHeaderFieldName(v string) bool { + if len(v) == 0 { + return false + } + for _, r := range v { + if int(r) >= len(http2isTokenTable) || ('A' <= r && r <= 'Z') { + return false + } + if !http2isTokenTable[byte(r)] { + return false + } + } + return true +} + +// validHeaderFieldValue reports whether v is a valid header field value. +// +// RFC 7230 says: +// field-value = *( field-content / obs-fold ) +// obj-fold = N/A to http2, and deprecated +// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +// field-vchar = VCHAR / obs-text +// obs-text = %x80-FF +// VCHAR = "any visible [USASCII] character" +// +// http2 further says: "Similarly, HTTP/2 allows header field values +// that are not valid. While most of the values that can be encoded +// will not alter header field parsing, carriage return (CR, ASCII +// 0xd), line feed (LF, ASCII 0xa), and the zero character (NUL, ASCII +// 0x0) might be exploited by an attacker if they are translated +// verbatim. Any request or response that contains a character not +// permitted in a header field value MUST be treated as malformed +// (Section 8.1.2.6). Valid characters are defined by the +// field-content ABNF rule in Section 3.2 of [RFC7230]." +// +// This function does not (yet?) properly handle the rejection of +// strings that begin or end with SP or HTAB. 
+func http2validHeaderFieldValue(v string) bool { + for i := 0; i < len(v); i++ { + if b := v[i]; b < ' ' && b != '\t' || b == 0x7f { + return false + } + } + return true +} + +var http2httpCodeStringCommon = map[int]string{} // n -> strconv.Itoa(n) + +func init() { + for i := 100; i <= 999; i++ { + if v := StatusText(i); v != "" { + http2httpCodeStringCommon[i] = strconv.Itoa(i) + } + } +} + +func http2httpCodeString(code int) string { + if s, ok := http2httpCodeStringCommon[code]; ok { + return s + } + return strconv.Itoa(code) +} + +// from pkg io +type http2stringWriter interface { + WriteString(s string) (n int, err error) +} + +// A gate lets two goroutines coordinate their activities. +type http2gate chan struct{} + +func (g http2gate) Done() { g <- struct{}{} } + +func (g http2gate) Wait() { <-g } + +// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed). +type http2closeWaiter chan struct{} + +// Init makes a closeWaiter usable. +// It exists because so a closeWaiter value can be placed inside a +// larger struct and have the Mutex and Cond's memory in the same +// allocation. +func (cw *http2closeWaiter) Init() { + *cw = make(chan struct{}) +} + +// Close marks the closeWaiter as closed and unblocks any waiters. +func (cw http2closeWaiter) Close() { + close(cw) +} + +// Wait waits for the closeWaiter to become closed. +func (cw http2closeWaiter) Wait() { + <-cw +} + +// bufferedWriter is a buffered writer that writes to w. +// Its buffered writer is lazily allocated as needed, to minimize +// idle memory usage with many connections. 
type http2bufferedWriter struct {
	w  io.Writer     // immutable
	bw *bufio.Writer // non-nil when data is buffered
}

func http2newBufferedWriter(w io.Writer) *http2bufferedWriter {
	return &http2bufferedWriter{w: w}
}

// bufWriterPool recycles 4 KB bufio.Writers between connections;
// a writer is taken on first Write and returned on Flush.
var http2bufWriterPool = sync.Pool{
	New: func() interface{} {

		return bufio.NewWriterSize(nil, 4<<10)
	},
}

// Write lazily acquires a pooled bufio.Writer on first use and
// buffers p through it.
func (w *http2bufferedWriter) Write(p []byte) (n int, err error) {
	if w.bw == nil {
		bw := http2bufWriterPool.Get().(*bufio.Writer)
		bw.Reset(w.w)
		w.bw = bw
	}
	return w.bw.Write(p)
}

// Flush writes any buffered data to the underlying writer and
// returns the bufio.Writer to the pool. A nil-bw (nothing buffered)
// Flush is a no-op.
func (w *http2bufferedWriter) Flush() error {
	bw := w.bw
	if bw == nil {
		return nil
	}
	err := bw.Flush()
	bw.Reset(nil) // detach from w.w so the pool doesn't retain the conn
	http2bufWriterPool.Put(bw)
	w.bw = nil
	return err
}

// mustUint31 converts v to a uint32, panicking if it is negative or
// exceeds 2^31-1 (the HTTP/2 31-bit stream/window value range).
func http2mustUint31(v int32) uint32 {
	if v < 0 || v > 2147483647 {
		panic("out of range")
	}
	return uint32(v)
}

// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC2616, section 4.4.
func http2bodyAllowedForStatus(status int) bool {
	switch {
	case status >= 100 && status <= 199:
		return false
	case status == 204:
		return false
	case status == 304:
		return false
	}
	return true
}

// httpError is a timeout-capable error satisfying net.Error.
type http2httpError struct {
	msg     string
	timeout bool
}

func (e *http2httpError) Error() string { return e.msg }

func (e *http2httpError) Timeout() bool { return e.timeout }

func (e *http2httpError) Temporary() bool { return true }

var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true}

// isTokenTable marks the bytes valid in an HTTP token (RFC 7230 §3.2.6
// tchar), indexed by byte value; indices >= 127 are out of range and
// must be checked by the caller.
var http2isTokenTable = [127]bool{
	'!':  true,
	'#':  true,
	'$':  true,
	'%':  true,
	'&':  true,
	'\'': true,
	'*':  true,
	'+':  true,
	'-':  true,
	'.':  true,
	'0':  true,
	'1':  true,
	'2':  true,
	'3':  true,
	'4':  true,
	'5':  true,
	'6':  true,
	'7':  true,
	'8':  true,
	'9':  true,
	'A':  true,
	'B':  true,
	'C':  true,
	'D':  true,
	'E':  true,
	'F':  true,
	'G':  true,
	'H':  true,
	'I':  true,
	'J':  true,
	'K':  true,
	'L':  true,
	'M':  true,
	'N':  true,
	'O':  true,
	'P':  true,
	'Q':  true,
	'R':  true,
	'S':  true,
	'T':  true,
	'U':  true,
	'W':  true,
	'V':  true,
	'X':  true,
	'Y':  true,
	'Z':  true,
	'^':  true,
	'_':  true,
	'`':  true,
	'a':  true,
	'b':  true,
	'c':  true,
	'd':  true,
	'e':  true,
	'f':  true,
	'g':  true,
	'h':  true,
	'i':  true,
	'j':  true,
	'k':  true,
	'l':  true,
	'm':  true,
	'n':  true,
	'o':  true,
	'p':  true,
	'q':  true,
	'r':  true,
	's':  true,
	't':  true,
	'u':  true,
	'v':  true,
	'w':  true,
	'x':  true,
	'y':  true,
	'z':  true,
	'|':  true,
	'~':  true,
}

// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like
// io.Pipe except there are no PipeReader/PipeWriter halves, and the
// underlying buffer is an interface. (io.Pipe is always unbuffered)
type http2pipe struct {
	mu       sync.Mutex
	c        sync.Cond // c.L lazily initialized to &p.mu
	b        http2pipeBuffer
	err      error         // read error once empty. non-nil means closed.
	breakErr error         // immediate read error (caller doesn't see rest of b)
	donec    chan struct{} // closed on error
	readFn   func()        // optional code to run in Read before error
}

type http2pipeBuffer interface {
	Len() int
	io.Writer
	io.Reader
}

// Read waits until data is available and copies bytes
// from the buffer into p.
func (p *http2pipe) Read(d []byte) (n int, err error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.c.L == nil {
		p.c.L = &p.mu
	}
	for {
		// breakErr takes priority over buffered data: the caller
		// must not see the rest of b.
		if p.breakErr != nil {
			return 0, p.breakErr
		}
		if p.b.Len() > 0 {
			return p.b.Read(d)
		}
		if p.err != nil {
			// Drained and closed: run the one-shot readFn (if any)
			// before surfacing the close error.
			if p.readFn != nil {
				p.readFn()
				p.readFn = nil
			}
			return 0, p.err
		}
		p.c.Wait()
	}
}

var http2errClosedPipeWrite = errors.New("write on closed buffer")

// Write copies bytes from p into the buffer and wakes a reader.
// It is an error to write more data than the buffer can hold.
func (p *http2pipe) Write(d []byte) (n int, err error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.c.L == nil {
		p.c.L = &p.mu
	}
	// Signal after unlock so a blocked Read can make progress.
	defer p.c.Signal()
	if p.err != nil {
		return 0, http2errClosedPipeWrite
	}
	return p.b.Write(d)
}

// CloseWithError causes the next Read (waking up a current blocked
// Read if needed) to return the provided err after all data has been
// read.
//
// The error must be non-nil.
func (p *http2pipe) CloseWithError(err error) { p.closeWithError(&p.err, err, nil) }

// BreakWithError causes the next Read (waking up a current blocked
// Read if needed) to return the provided err immediately, without
// waiting for unread data.
func (p *http2pipe) BreakWithError(err error) { p.closeWithError(&p.breakErr, err, nil) }

// closeWithErrorAndCode is like CloseWithError but also sets some code to run
// in the caller's goroutine before returning the error.
func (p *http2pipe) closeWithErrorAndCode(err error, fn func()) { p.closeWithError(&p.err, err, fn) }

// closeWithError stores err into *dst (either p.err or p.breakErr),
// records the optional readFn, wakes any blocked Read, and closes the
// done channel. Only the first close wins; later calls are no-ops.
func (p *http2pipe) closeWithError(dst *error, err error, fn func()) {
	if err == nil {
		panic("err must be non-nil")
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.c.L == nil {
		p.c.L = &p.mu
	}
	defer p.c.Signal()
	if *dst != nil {
		// Already closed with an error; keep the first one.
		return
	}
	p.readFn = fn
	*dst = err
	p.closeDoneLocked()
}

// closeDoneLocked closes p.donec (if allocated) exactly once.
// requires p.mu be held.
func (p *http2pipe) closeDoneLocked() {
	if p.donec == nil {
		return
	}
	// Non-blocking check: donec may already be closed if both
	// CloseWithError and BreakWithError were called.
	select {
	case <-p.donec:
	default:
		close(p.donec)
	}
}

// Err returns the error (if any) first set by BreakWithError or CloseWithError.
func (p *http2pipe) Err() error {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.breakErr != nil {
		return p.breakErr
	}
	return p.err
}

// Done returns a channel which is closed if and when this pipe is closed
// with CloseWithError.
+func (p *http2pipe) Done() <-chan struct{} { + p.mu.Lock() + defer p.mu.Unlock() + if p.donec == nil { + p.donec = make(chan struct{}) + if p.err != nil || p.breakErr != nil { + + p.closeDoneLocked() + } + } + return p.donec +} + +const ( + http2prefaceTimeout = 10 * time.Second + http2firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway + http2handlerChunkWriteSize = 4 << 10 + http2defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? +) + +var ( + http2errClientDisconnected = errors.New("client disconnected") + http2errClosedBody = errors.New("body closed by handler") + http2errHandlerComplete = errors.New("http2: request body closed due to handler exiting") + http2errStreamClosed = errors.New("http2: stream closed") +) + +var http2responseWriterStatePool = sync.Pool{ + New: func() interface{} { + rws := &http2responseWriterState{} + rws.bw = bufio.NewWriterSize(http2chunkWriter{rws}, http2handlerChunkWriteSize) + return rws + }, +} + +// Test hooks. +var ( + http2testHookOnConn func() + http2testHookGetServerConn func(*http2serverConn) + http2testHookOnPanicMu *sync.Mutex // nil except in tests + http2testHookOnPanic func(sc *http2serverConn, panicVal interface{}) (rePanic bool) +) + +// Server is an HTTP/2 server. +type http2Server struct { + // MaxHandlers limits the number of http.Handler ServeHTTP goroutines + // which may run at a time over all connections. + // Negative or zero no limit. + // TODO: implement + MaxHandlers int + + // MaxConcurrentStreams optionally specifies the number of + // concurrent streams that each client may have open at a + // time. This is unrelated to the number of http.Handler goroutines + // which may be active globally, which is MaxHandlers. + // If zero, MaxConcurrentStreams defaults to at least 100, per + // the HTTP/2 spec's recommendations. + MaxConcurrentStreams uint32 + + // MaxReadFrameSize optionally specifies the largest frame + // this server is willing to read. 
A valid value is between + // 16k and 16M, inclusive. If zero or otherwise invalid, a + // default value is used. + MaxReadFrameSize uint32 + + // PermitProhibitedCipherSuites, if true, permits the use of + // cipher suites prohibited by the HTTP/2 spec. + PermitProhibitedCipherSuites bool +} + +func (s *http2Server) maxReadFrameSize() uint32 { + if v := s.MaxReadFrameSize; v >= http2minMaxFrameSize && v <= http2maxFrameSize { + return v + } + return http2defaultMaxReadFrameSize +} + +func (s *http2Server) maxConcurrentStreams() uint32 { + if v := s.MaxConcurrentStreams; v > 0 { + return v + } + return http2defaultMaxStreams +} + +// ConfigureServer adds HTTP/2 support to a net/http Server. +// +// The configuration conf may be nil. +// +// ConfigureServer must be called before s begins serving. +func http2ConfigureServer(s *Server, conf *http2Server) error { + if conf == nil { + conf = new(http2Server) + } + + if s.TLSConfig == nil { + s.TLSConfig = new(tls.Config) + } else if s.TLSConfig.CipherSuites != nil { + // If they already provided a CipherSuite list, return + // an error if it has a bad order or is missing + // ECDHE_RSA_WITH_AES_128_GCM_SHA256. + const requiredCipher = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + haveRequired := false + sawBad := false + for i, cs := range s.TLSConfig.CipherSuites { + if cs == requiredCipher { + haveRequired = true + } + if http2isBadCipher(cs) { + sawBad = true + } else if sawBad { + return fmt.Errorf("http2: TLSConfig.CipherSuites index %d contains an HTTP/2-approved cipher suite (%#04x), but it comes after unapproved cipher suites. 
With this configuration, clients that don't support previous, approved cipher suites may be given an unapproved one and reject the connection.", i, cs) + } + } + if !haveRequired { + return fmt.Errorf("http2: TLSConfig.CipherSuites is missing HTTP/2-required TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256") + } + } + + s.TLSConfig.PreferServerCipherSuites = true + + haveNPN := false + for _, p := range s.TLSConfig.NextProtos { + if p == http2NextProtoTLS { + haveNPN = true + break + } + } + if !haveNPN { + s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS) + } + + s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2-14") + + if s.TLSNextProto == nil { + s.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){} + } + protoHandler := func(hs *Server, c *tls.Conn, h Handler) { + if http2testHookOnConn != nil { + http2testHookOnConn() + } + conf.handleConn(hs, c, h) + } + s.TLSNextProto[http2NextProtoTLS] = protoHandler + s.TLSNextProto["h2-14"] = protoHandler + return nil +} + +func (srv *http2Server) handleConn(hs *Server, c net.Conn, h Handler) { + sc := &http2serverConn{ + srv: srv, + hs: hs, + conn: c, + remoteAddrStr: c.RemoteAddr().String(), + bw: http2newBufferedWriter(c), + handler: h, + streams: make(map[uint32]*http2stream), + readFrameCh: make(chan http2readFrameResult), + wantWriteFrameCh: make(chan http2frameWriteMsg, 8), + wroteFrameCh: make(chan http2frameWriteResult, 1), + bodyReadCh: make(chan http2bodyReadMsg), + doneServing: make(chan struct{}), + advMaxStreams: srv.maxConcurrentStreams(), + writeSched: http2writeScheduler{ + maxFrameSize: http2initialMaxFrameSize, + }, + initialWindowSize: http2initialWindowSize, + headerTableSize: http2initialHeaderTableSize, + serveG: http2newGoroutineLock(), + pushEnabled: true, + } + sc.flow.add(http2initialWindowSize) + sc.inflow.add(http2initialWindowSize) + sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) + sc.hpackDecoder = hpack.NewDecoder(http2initialHeaderTableSize, nil) 
+ sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen()) + + fr := http2NewFramer(sc.bw, c) + fr.SetMaxReadFrameSize(srv.maxReadFrameSize()) + sc.framer = fr + + if tc, ok := c.(*tls.Conn); ok { + sc.tlsState = new(tls.ConnectionState) + *sc.tlsState = tc.ConnectionState() + + if sc.tlsState.Version < tls.VersionTLS12 { + sc.rejectConn(http2ErrCodeInadequateSecurity, "TLS version too low") + return + } + + if sc.tlsState.ServerName == "" { + + } + + if !srv.PermitProhibitedCipherSuites && http2isBadCipher(sc.tlsState.CipherSuite) { + + sc.rejectConn(http2ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite)) + return + } + } + + if hook := http2testHookGetServerConn; hook != nil { + hook(sc) + } + sc.serve() +} + +// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. +func http2isBadCipher(cipher uint16) bool { + switch cipher { + case tls.TLS_RSA_WITH_RC4_128_SHA, + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: + + return true + default: + return false + } +} + +func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) { + sc.vlogf("http2: server rejecting conn: %v, %s", err, debug) + + sc.framer.WriteGoAway(0, err, []byte(debug)) + sc.bw.Flush() + sc.conn.Close() +} + +type http2serverConn struct { + // Immutable: + srv *http2Server + hs *Server + conn net.Conn + bw *http2bufferedWriter // writing to conn + handler Handler + framer *http2Framer + hpackDecoder *hpack.Decoder + doneServing chan struct{} // closed when serverConn.serve ends + readFrameCh chan http2readFrameResult // written by serverConn.readFrames + 
wantWriteFrameCh chan http2frameWriteMsg // from handlers -> serve + wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes + bodyReadCh chan http2bodyReadMsg // from handlers -> serve + testHookCh chan func(int) // code to run on the serve loop + flow http2flow // conn-wide (not stream-specific) outbound flow control + inflow http2flow // conn-wide inbound flow control + tlsState *tls.ConnectionState // shared by all handlers, like net/http + remoteAddrStr string + + // Everything following is owned by the serve loop; use serveG.check(): + serveG http2goroutineLock // used to verify funcs are on serve() + pushEnabled bool + sawFirstSettings bool // got the initial SETTINGS frame after the preface + needToSendSettingsAck bool + unackedSettings int // how many SETTINGS have we sent without ACKs? + clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) + advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client + curOpenStreams uint32 // client's number of open streams + maxStreamID uint32 // max ever seen + streams map[uint32]*http2stream + initialWindowSize int32 + headerTableSize uint32 + peerMaxHeaderListSize uint32 // zero means unknown (default) + canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case + req http2requestParam // non-zero while reading request headers + writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh + needsFrameFlush bool // last frame write wasn't a flush + writeSched http2writeScheduler + inGoAway bool // we've started to or sent GOAWAY + needToSendGoAway bool // we need to schedule a GOAWAY frame write + goAwayCode http2ErrCode + shutdownTimerCh <-chan time.Time // nil until used + shutdownTimer *time.Timer // nil until used + + // Owned by the writeFrameAsync goroutine: + headerWriteBuf bytes.Buffer + hpackEncoder *hpack.Encoder +} + +func (sc *http2serverConn) maxHeaderStringLen() 
int { + v := sc.maxHeaderListSize() + if uint32(int(v)) == v { + return int(v) + } + + return 0 +} + +func (sc *http2serverConn) maxHeaderListSize() uint32 { + n := sc.hs.MaxHeaderBytes + if n <= 0 { + n = DefaultMaxHeaderBytes + } + // http2's count is in a slightly different unit and includes 32 bytes per pair. + // So, take the net/http.Server value and pad it up a bit, assuming 10 headers. + const perFieldOverhead = 32 // per http2 spec + const typicalHeaders = 10 // conservative + return uint32(n + typicalHeaders*perFieldOverhead) +} + +// requestParam is the state of the next request, initialized over +// potentially several frames HEADERS + zero or more CONTINUATION +// frames. +type http2requestParam struct { + // stream is non-nil if we're reading (HEADER or CONTINUATION) + // frames for a request (but not DATA). + stream *http2stream + header Header + method, path string + scheme, authority string + sawRegularHeader bool // saw a non-pseudo header already + invalidHeader bool // an invalid header was seen + headerListSize int64 // actually uint32, but easier math this way +} + +// stream represents a stream. This is the minimal metadata needed by +// the serve goroutine. Most of the actual stream state is owned by +// the http.Handler's goroutine in the responseWriter. Because the +// responseWriter's responseWriterState is recycled at the end of a +// handler, this struct intentionally has no pointer to the +// *responseWriter{,State} itself, as the Handler ending nils out the +// responseWriter's state field. 
+type http2stream struct { + // immutable: + sc *http2serverConn + id uint32 + body *http2pipe // non-nil if expecting DATA frames + cw http2closeWaiter // closed wait stream transitions to closed state + + // owned by serverConn's serve loop: + bodyBytes int64 // body bytes seen so far + declBodyBytes int64 // or -1 if undeclared + flow http2flow // limits writing from Handler to client + inflow http2flow // what the client is allowed to POST/etc to us + parent *http2stream // or nil + numTrailerValues int64 + weight uint8 + state http2streamState + sentReset bool // only true once detached from streams map + gotReset bool // only true once detacted from streams map + gotTrailerHeader bool // HEADER frame for trailers was seen + + trailer Header // accumulated trailers + reqTrailer Header // handler's Request.Trailer +} + +func (sc *http2serverConn) Framer() *http2Framer { return sc.framer } + +func (sc *http2serverConn) CloseConn() error { return sc.conn.Close() } + +func (sc *http2serverConn) Flush() error { return sc.bw.Flush() } + +func (sc *http2serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) { + return sc.hpackEncoder, &sc.headerWriteBuf +} + +func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2stream) { + sc.serveG.check() + + if st, ok := sc.streams[streamID]; ok { + return st.state, st + } + + if streamID <= sc.maxStreamID { + return http2stateClosed, nil + } + return http2stateIdle, nil +} + +// setConnState calls the net/http ConnState hook for this connection, if configured. +// Note that the net/http package does StateNew and StateClosed for us. +// There is currently no plan for StateHijacked or hijacking HTTP/2 connections. +func (sc *http2serverConn) setConnState(state ConnState) { + if sc.hs.ConnState != nil { + sc.hs.ConnState(sc.conn, state) + } +} + +func (sc *http2serverConn) vlogf(format string, args ...interface{}) { + if http2VerboseLogs { + sc.logf(format, args...) 
+ } +} + +func (sc *http2serverConn) logf(format string, args ...interface{}) { + if lg := sc.hs.ErrorLog; lg != nil { + lg.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +var http2uintptrType = reflect.TypeOf(uintptr(0)) + +// errno returns v's underlying uintptr, else 0. +// +// TODO: remove this helper function once http2 can use build +// tags. See comment in isClosedConnError. +func http2errno(v error) uintptr { + if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr { + return uintptr(rv.Uint()) + } + return 0 +} + +// isClosedConnError reports whether err is an error from use of a closed +// network connection. +func http2isClosedConnError(err error) bool { + if err == nil { + return false + } + + str := err.Error() + if strings.Contains(str, "use of closed network connection") { + return true + } + + if runtime.GOOS == "windows" { + if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { + if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" { + const WSAECONNABORTED = 10053 + const WSAECONNRESET = 10054 + if n := http2errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED { + return true + } + } + } + } + return false +} + +func (sc *http2serverConn) condlogf(err error, format string, args ...interface{}) { + if err == nil { + return + } + if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) { + + sc.vlogf(format, args...) + } else { + sc.logf(format, args...) 
+ } +} + +func (sc *http2serverConn) onNewHeaderField(f hpack.HeaderField) { + sc.serveG.check() + if http2VerboseLogs { + sc.vlogf("http2: server decoded %v", f) + } + switch { + case !http2validHeaderFieldValue(f.Value): + sc.req.invalidHeader = true + case strings.HasPrefix(f.Name, ":"): + if sc.req.sawRegularHeader { + sc.logf("pseudo-header after regular header") + sc.req.invalidHeader = true + return + } + var dst *string + switch f.Name { + case ":method": + dst = &sc.req.method + case ":path": + dst = &sc.req.path + case ":scheme": + dst = &sc.req.scheme + case ":authority": + dst = &sc.req.authority + default: + + sc.logf("invalid pseudo-header %q", f.Name) + sc.req.invalidHeader = true + return + } + if *dst != "" { + sc.logf("duplicate pseudo-header %q sent", f.Name) + sc.req.invalidHeader = true + return + } + *dst = f.Value + case !http2validHeaderFieldName(f.Name): + sc.req.invalidHeader = true + default: + sc.req.sawRegularHeader = true + sc.req.header.Add(sc.canonicalHeader(f.Name), f.Value) + const headerFieldOverhead = 32 // per spec + sc.req.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead + if sc.req.headerListSize > int64(sc.maxHeaderListSize()) { + sc.hpackDecoder.SetEmitEnabled(false) + } + } +} + +func (st *http2stream) onNewTrailerField(f hpack.HeaderField) { + sc := st.sc + sc.serveG.check() + if http2VerboseLogs { + sc.vlogf("http2: server decoded trailer %v", f) + } + switch { + case strings.HasPrefix(f.Name, ":"): + sc.req.invalidHeader = true + return + case !http2validHeaderFieldName(f.Name) || !http2validHeaderFieldValue(f.Value): + sc.req.invalidHeader = true + return + default: + key := sc.canonicalHeader(f.Name) + if st.trailer != nil { + vv := append(st.trailer[key], f.Value) + st.trailer[key] = vv + + // arbitrary; TODO: read spec about header list size limits wrt trailers + const tooBig = 1000 + if len(vv) >= tooBig { + sc.hpackDecoder.SetEmitEnabled(false) + } + } + } +} + +func (sc 
*http2serverConn) canonicalHeader(v string) string { + sc.serveG.check() + cv, ok := http2commonCanonHeader[v] + if ok { + return cv + } + cv, ok = sc.canonHeader[v] + if ok { + return cv + } + if sc.canonHeader == nil { + sc.canonHeader = make(map[string]string) + } + cv = CanonicalHeaderKey(v) + sc.canonHeader[v] = cv + return cv +} + +type http2readFrameResult struct { + f http2Frame // valid until readMore is called + err error + + // readMore should be called once the consumer no longer needs or + // retains f. After readMore, f is invalid and more frames can be + // read. + readMore func() +} + +// readFrames is the loop that reads incoming frames. +// It takes care to only read one frame at a time, blocking until the +// consumer is done with the frame. +// It's run on its own goroutine. +func (sc *http2serverConn) readFrames() { + gate := make(http2gate) + for { + f, err := sc.framer.ReadFrame() + select { + case sc.readFrameCh <- http2readFrameResult{f, err, gate.Done}: + case <-sc.doneServing: + return + } + select { + case <-gate: + case <-sc.doneServing: + return + } + if http2terminalReadFrameError(err) { + return + } + } +} + +// frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. +type http2frameWriteResult struct { + wm http2frameWriteMsg // what was written (or attempted) + err error // result of the writeFrame call +} + +// writeFrameAsync runs in its own goroutine and writes a single frame +// and then reports when it's done. +// At most one goroutine can be running writeFrameAsync at a time per +// serverConn. 
+func (sc *http2serverConn) writeFrameAsync(wm http2frameWriteMsg) { + err := wm.write.writeFrame(sc) + sc.wroteFrameCh <- http2frameWriteResult{wm, err} +} + +func (sc *http2serverConn) closeAllStreamsOnConnClose() { + sc.serveG.check() + for _, st := range sc.streams { + sc.closeStream(st, http2errClientDisconnected) + } +} + +func (sc *http2serverConn) stopShutdownTimer() { + sc.serveG.check() + if t := sc.shutdownTimer; t != nil { + t.Stop() + } +} + +func (sc *http2serverConn) notePanic() { + + if http2testHookOnPanicMu != nil { + http2testHookOnPanicMu.Lock() + defer http2testHookOnPanicMu.Unlock() + } + if http2testHookOnPanic != nil { + if e := recover(); e != nil { + if http2testHookOnPanic(sc, e) { + panic(e) + } + } + } +} + +func (sc *http2serverConn) serve() { + sc.serveG.check() + defer sc.notePanic() + defer sc.conn.Close() + defer sc.closeAllStreamsOnConnClose() + defer sc.stopShutdownTimer() + defer close(sc.doneServing) + + if http2VerboseLogs { + sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) + } + + sc.writeFrame(http2frameWriteMsg{ + write: http2writeSettings{ + {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, + {http2SettingMaxConcurrentStreams, sc.advMaxStreams}, + {http2SettingMaxHeaderListSize, sc.maxHeaderListSize()}, + }, + }) + sc.unackedSettings++ + + if err := sc.readPreface(); err != nil { + sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err) + return + } + + sc.setConnState(StateActive) + sc.setConnState(StateIdle) + + go sc.readFrames() + + settingsTimer := time.NewTimer(http2firstSettingsTimeout) + loopNum := 0 + for { + loopNum++ + select { + case wm := <-sc.wantWriteFrameCh: + sc.writeFrame(wm) + case res := <-sc.wroteFrameCh: + sc.wroteFrame(res) + case res := <-sc.readFrameCh: + if !sc.processFrameFromReader(res) { + return + } + res.readMore() + if settingsTimer.C != nil { + settingsTimer.Stop() + settingsTimer.C = nil + } + case m := 
<-sc.bodyReadCh: + sc.noteBodyRead(m.st, m.n) + case <-settingsTimer.C: + sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) + return + case <-sc.shutdownTimerCh: + sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) + return + case fn := <-sc.testHookCh: + fn(loopNum) + } + } +} + +// readPreface reads the ClientPreface greeting from the peer +// or returns an error on timeout or an invalid greeting. +func (sc *http2serverConn) readPreface() error { + errc := make(chan error, 1) + go func() { + + buf := make([]byte, len(http2ClientPreface)) + if _, err := io.ReadFull(sc.conn, buf); err != nil { + errc <- err + } else if !bytes.Equal(buf, http2clientPreface) { + errc <- fmt.Errorf("bogus greeting %q", buf) + } else { + errc <- nil + } + }() + timer := time.NewTimer(http2prefaceTimeout) + defer timer.Stop() + select { + case <-timer.C: + return errors.New("timeout waiting for client preface") + case err := <-errc: + if err == nil { + if http2VerboseLogs { + sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr()) + } + } + return err + } +} + +var http2errChanPool = sync.Pool{ + New: func() interface{} { return make(chan error, 1) }, +} + +var http2writeDataPool = sync.Pool{ + New: func() interface{} { return new(http2writeData) }, +} + +// writeDataFromHandler writes DATA response frames from a handler on +// the given stream. 
+func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte, endStream bool) error { + ch := http2errChanPool.Get().(chan error) + writeArg := http2writeDataPool.Get().(*http2writeData) + *writeArg = http2writeData{stream.id, data, endStream} + err := sc.writeFrameFromHandler(http2frameWriteMsg{ + write: writeArg, + stream: stream, + done: ch, + }) + if err != nil { + return err + } + var frameWriteDone bool // the frame write is done (successfully or not) + select { + case err = <-ch: + frameWriteDone = true + case <-sc.doneServing: + return http2errClientDisconnected + case <-stream.cw: + + select { + case err = <-ch: + frameWriteDone = true + default: + return http2errStreamClosed + } + } + http2errChanPool.Put(ch) + if frameWriteDone { + http2writeDataPool.Put(writeArg) + } + return err +} + +// writeFrameFromHandler sends wm to sc.wantWriteFrameCh, but aborts +// if the connection has gone away. +// +// This must not be run from the serve goroutine itself, else it might +// deadlock writing to sc.wantWriteFrameCh (which is only mildly +// buffered and is read by serve itself). If you're on the serve +// goroutine, call writeFrame instead. +func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error { + sc.serveG.checkNotOn() + select { + case sc.wantWriteFrameCh <- wm: + return nil + case <-sc.doneServing: + + return http2errClientDisconnected + } +} + +// writeFrame schedules a frame to write and sends it if there's nothing +// already being written. +// +// There is no pushback here (the serve goroutine never blocks). It's +// the http.Handlers that block, waiting for their previous frames to +// make it onto the wire +// +// If you're not on the serve goroutine, use writeFrameFromHandler instead. 
+func (sc *http2serverConn) writeFrame(wm http2frameWriteMsg) { + sc.serveG.check() + sc.writeSched.add(wm) + sc.scheduleFrameWrite() +} + +// startFrameWrite starts a goroutine to write wm (in a separate +// goroutine since that might block on the network), and updates the +// serve goroutine's state about the world, updated from info in wm. +func (sc *http2serverConn) startFrameWrite(wm http2frameWriteMsg) { + sc.serveG.check() + if sc.writingFrame { + panic("internal error: can only be writing one frame at a time") + } + + st := wm.stream + if st != nil { + switch st.state { + case http2stateHalfClosedLocal: + panic("internal error: attempt to send frame on half-closed-local stream") + case http2stateClosed: + if st.sentReset || st.gotReset { + + sc.scheduleFrameWrite() + return + } + panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wm)) + } + } + + sc.writingFrame = true + sc.needsFrameFlush = true + go sc.writeFrameAsync(wm) +} + +// errHandlerPanicked is the error given to any callers blocked in a read from +// Request.Body when the main goroutine panics. Since most handlers read in the +// the main ServeHTTP goroutine, this will show up rarely. +var http2errHandlerPanicked = errors.New("http2: handler panicked") + +// wroteFrame is called on the serve goroutine with the result of +// whatever happened on writeFrameAsync. 
+func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { + sc.serveG.check() + if !sc.writingFrame { + panic("internal error: expected to be already writing a frame") + } + sc.writingFrame = false + + wm := res.wm + st := wm.stream + + closeStream := http2endsStream(wm.write) + + if _, ok := wm.write.(http2handlerPanicRST); ok { + sc.closeStream(st, http2errHandlerPanicked) + } + + if ch := wm.done; ch != nil { + select { + case ch <- res.err: + default: + panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wm.write)) + } + } + wm.write = nil + + if closeStream { + if st == nil { + panic("internal error: expecting non-nil stream") + } + switch st.state { + case http2stateOpen: + + st.state = http2stateHalfClosedLocal + errCancel := http2StreamError{st.id, http2ErrCodeCancel} + sc.resetStream(errCancel) + case http2stateHalfClosedRemote: + sc.closeStream(st, http2errHandlerComplete) + } + } + + sc.scheduleFrameWrite() +} + +// scheduleFrameWrite tickles the frame writing scheduler. +// +// If a frame is already being written, nothing happens. This will be called again +// when the frame is done being written. +// +// If a frame isn't being written we need to send one, the best frame +// to send is selected, preferring first things that aren't +// stream-specific (e.g. ACKing settings), and then finding the +// highest priority stream. +// +// If a frame isn't being written and there's nothing else to send, we +// flush the write buffer. 
+func (sc *http2serverConn) scheduleFrameWrite() { + sc.serveG.check() + if sc.writingFrame { + return + } + if sc.needToSendGoAway { + sc.needToSendGoAway = false + sc.startFrameWrite(http2frameWriteMsg{ + write: &http2writeGoAway{ + maxStreamID: sc.maxStreamID, + code: sc.goAwayCode, + }, + }) + return + } + if sc.needToSendSettingsAck { + sc.needToSendSettingsAck = false + sc.startFrameWrite(http2frameWriteMsg{write: http2writeSettingsAck{}}) + return + } + if !sc.inGoAway { + if wm, ok := sc.writeSched.take(); ok { + sc.startFrameWrite(wm) + return + } + } + if sc.needsFrameFlush { + sc.startFrameWrite(http2frameWriteMsg{write: http2flushFrameWriter{}}) + sc.needsFrameFlush = false + return + } +} + +func (sc *http2serverConn) goAway(code http2ErrCode) { + sc.serveG.check() + if sc.inGoAway { + return + } + if code != http2ErrCodeNo { + sc.shutDownIn(250 * time.Millisecond) + } else { + + sc.shutDownIn(1 * time.Second) + } + sc.inGoAway = true + sc.needToSendGoAway = true + sc.goAwayCode = code + sc.scheduleFrameWrite() +} + +func (sc *http2serverConn) shutDownIn(d time.Duration) { + sc.serveG.check() + sc.shutdownTimer = time.NewTimer(d) + sc.shutdownTimerCh = sc.shutdownTimer.C +} + +func (sc *http2serverConn) resetStream(se http2StreamError) { + sc.serveG.check() + sc.writeFrame(http2frameWriteMsg{write: se}) + if st, ok := sc.streams[se.StreamID]; ok { + st.sentReset = true + sc.closeStream(st, se) + } +} + +// processFrameFromReader processes the serve loop's read from readFrameCh from the +// frame-reading goroutine. +// processFrameFromReader returns whether the connection should be kept open. 
func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool {
	sc.serveG.check()
	err := res.err
	if err != nil {
		if err == http2ErrFrameTooLarge {
			sc.goAway(http2ErrCodeFrameSize)
			return true
		}
		clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err)
		if clientGone {
			// The client closed the connection; nothing left to do.
			return false
		}
	} else {
		f := res.f
		if http2VerboseLogs {
			sc.vlogf("http2: server read frame %v", http2summarizeFrame(f))
		}
		err = sc.processFrame(f)
		if err == nil {
			return true
		}
	}

	// Map the processing/read error to the appropriate protocol reaction.
	switch ev := err.(type) {
	case http2StreamError:
		sc.resetStream(ev)
		return true
	case http2goAwayFlowError:
		sc.goAway(http2ErrCodeFlowControl)
		return true
	case http2ConnectionError:
		sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev)
		sc.goAway(http2ErrCode(ev))
		return true
	default:
		if res.err != nil {
			sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err)
		} else {
			sc.logf("http2: server closing client connection: %v", err)
		}
		return false
	}
}

// processFrame dispatches a successfully read frame to its type-specific
// handler. The first frame from a client must be SETTINGS.
func (sc *http2serverConn) processFrame(f http2Frame) error {
	sc.serveG.check()

	if !sc.sawFirstSettings {
		if _, ok := f.(*http2SettingsFrame); !ok {
			return http2ConnectionError(http2ErrCodeProtocol)
		}
		sc.sawFirstSettings = true
	}

	switch f := f.(type) {
	case *http2SettingsFrame:
		return sc.processSettings(f)
	case *http2HeadersFrame:
		return sc.processHeaders(f)
	case *http2ContinuationFrame:
		return sc.processContinuation(f)
	case *http2WindowUpdateFrame:
		return sc.processWindowUpdate(f)
	case *http2PingFrame:
		return sc.processPing(f)
	case *http2DataFrame:
		return sc.processData(f)
	case *http2RSTStreamFrame:
		return sc.processResetStream(f)
	case *http2PriorityFrame:
		return sc.processPriority(f)
	case *http2PushPromiseFrame:
		// Clients must not send PUSH_PROMISE.
		return http2ConnectionError(http2ErrCodeProtocol)
	default:
		sc.vlogf("http2: server ignoring frame: %v", f.Header())
		return nil
	}
}

// processPing answers a client PING with a PING ACK. PING ACKs from the
// client are ignored; PINGs on a non-zero stream are a protocol error.
func (sc *http2serverConn) processPing(f *http2PingFrame) error {
	sc.serveG.check()
	if f.IsAck() {
		// An endpoint must not respond to PING frames containing the ACK flag.
		return nil
	}
	if f.StreamID != 0 {
		// PING frames are not associated with any individual stream.
		return http2ConnectionError(http2ErrCodeProtocol)
	}
	sc.writeFrame(http2frameWriteMsg{write: http2writePingAck{f}})
	return nil
}

// processWindowUpdate grows either a stream's or the connection's
// outbound flow-control window.
func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error {
	sc.serveG.check()
	switch {
	case f.StreamID != 0:
		st := sc.streams[f.StreamID]
		if st == nil {
			// Updates for unknown (e.g. already-closed) streams are ignored.
			return nil
		}
		if !st.flow.add(int32(f.Increment)) {
			return http2StreamError{f.StreamID, http2ErrCodeFlowControl}
		}
	default:
		if !sc.flow.add(int32(f.Increment)) {
			return http2goAwayFlowError{}
		}
	}
	sc.scheduleFrameWrite()
	return nil
}

// processResetStream handles a client RST_STREAM by closing the stream,
// if it exists. RST_STREAM on an idle stream is a connection error.
func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error {
	sc.serveG.check()

	state, st := sc.state(f.StreamID)
	if state == http2stateIdle {
		// RST_STREAM frames must not be sent for a stream in the "idle" state.
		return http2ConnectionError(http2ErrCodeProtocol)
	}
	if st != nil {
		st.gotReset = true
		sc.closeStream(st, http2StreamError{f.StreamID, f.ErrCode})
	}
	return nil
}

// closeStream finalizes a stream: marks it closed, updates the open
// count and connection state, unblocks the request body, and forgets
// any queued frames for it.
func (sc *http2serverConn) closeStream(st *http2stream, err error) {
	sc.serveG.check()
	if st.state == http2stateIdle || st.state == http2stateClosed {
		panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state))
	}
	st.state = http2stateClosed
	sc.curOpenStreams--
	if sc.curOpenStreams == 0 {
		sc.setConnState(StateIdle)
	}
	delete(sc.streams, st.id)
	if p := st.body; p != nil {
		p.CloseWithError(err)
	}
	st.cw.Close()
	sc.writeSched.forgetStream(st.id)
}

// processSettings applies a client SETTINGS frame, or accounts for a
// SETTINGS ACK of one of ours.
func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error {
	sc.serveG.check()
	if f.IsAck() {
		sc.unackedSettings--
		if sc.unackedSettings < 0 {
			// The peer ACKed settings we never sent.
			return http2ConnectionError(http2ErrCodeProtocol)
		}
		return nil
	}
	if err := f.ForeachSetting(sc.processSetting); err != nil {
		return err
	}
	sc.needToSendSettingsAck = true
	sc.scheduleFrameWrite()
	return nil
}

// processSetting validates and applies a single setting from the peer.
// Unknown settings are ignored, as the spec requires.
func (sc *http2serverConn) processSetting(s http2Setting) error {
	sc.serveG.check()
	if err := s.Valid(); err != nil {
		return err
	}
	if http2VerboseLogs {
		sc.vlogf("http2: server processing setting %v", s)
	}
	switch s.ID {
	case http2SettingHeaderTableSize:
		sc.headerTableSize = s.Val
		sc.hpackEncoder.SetMaxDynamicTableSize(s.Val)
	case http2SettingEnablePush:
		sc.pushEnabled = s.Val != 0
	case http2SettingMaxConcurrentStreams:
		sc.clientMaxStreams = s.Val
	case http2SettingInitialWindowSize:
		return sc.processSettingInitialWindowSize(s.Val)
	case http2SettingMaxFrameSize:
		sc.writeSched.maxFrameSize = s.Val
	case http2SettingMaxHeaderListSize:
		sc.peerMaxHeaderListSize = s.Val
	default:
		// An endpoint that receives a SETTINGS frame with any unknown
		// or unsupported identifier must ignore that setting.
		if http2VerboseLogs {
			sc.vlogf("http2: server ignoring unknown setting %v", s)
		}
	}
	return nil
}

// processSettingInitialWindowSize adjusts every open stream's send
// window by the delta between the new and old initial window size.
// The growth may be negative.
func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error {
	sc.serveG.check()

	old := sc.initialWindowSize
	sc.initialWindowSize = int32(val)
	growth := sc.initialWindowSize - old
	for _, st := range sc.streams {
		if !st.flow.add(growth) {
			// An endpoint must treat a change that causes any flow
			// control window to exceed the maximum as a connection error.
			return http2ConnectionError(http2ErrCodeFlowControl)
		}
	}
	return nil
}

// processData delivers a DATA frame's payload to the stream's request
// body pipe, enforcing stream state, declared Content-Length, and
// inbound flow control.
func (sc *http2serverConn) processData(f *http2DataFrame) error {
	sc.serveG.check()

	id := f.Header().StreamID
	st, ok := sc.streams[id]
	if !ok || st.state != http2stateOpen || st.gotTrailerHeader {
		// DATA on a closed/half-closed-remote stream, or after trailers.
		return http2StreamError{id, http2ErrCodeStreamClosed}
	}
	if st.body == nil {
		panic("internal error: should have a body in this state")
	}
	data := f.Data()

	// Sender sending more than they'd declared?
	if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
		st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
		return http2StreamError{id, http2ErrCodeStreamClosed}
	}
	if len(data) > 0 {
		// Check whether the client has flow control quota.
		if int(st.inflow.available()) < len(data) {
			return http2StreamError{id, http2ErrCodeFlowControl}
		}
		st.inflow.take(int32(len(data)))
		wrote, err := st.body.Write(data)
		if err != nil {
			return http2StreamError{id, http2ErrCodeStreamClosed}
		}
		if wrote != len(data) {
			panic("internal error: bad Writer")
		}
		st.bodyBytes += int64(len(data))
	}
	if f.StreamEnded() {
		st.endStream()
	}
	return nil
}

// endStream closes a Request.Body's pipe. It is called when a DATA
// frame says a request body is over (or after trailers).
func (st *http2stream) endStream() {
	sc := st.sc
	sc.serveG.check()

	if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
		st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
			st.declBodyBytes, st.bodyBytes))
	} else {
		st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest)
		// NOTE(review): this second close looks redundant — the pipe's
		// error is already set by the call above — confirm against upstream.
		st.body.CloseWithError(io.EOF)
	}
	st.state = http2stateHalfClosedRemote
}

// copyTrailersToHandlerRequest is run in the Handler's goroutine in
// its Request.Body.Read just before it gets io.EOF.
func (st *http2stream) copyTrailersToHandlerRequest() {
	for k, vv := range st.trailer {
		if _, ok := st.reqTrailer[k]; ok {
			// Only copy trailers the request pre-declared in its
			// Trailer header; others are dropped.
			st.reqTrailer[k] = vv
		}
	}
}

// processHeaders handles a HEADERS frame: either trailers for an
// existing stream, or the start of a new client-initiated stream.
func (sc *http2serverConn) processHeaders(f *http2HeadersFrame) error {
	sc.serveG.check()
	id := f.Header().StreamID
	if sc.inGoAway {
		// Ignore new streams once we're draining the connection.
		return nil
	}

	// Client-initiated streams must have odd IDs.
	if id%2 != 1 {
		return http2ConnectionError(http2ErrCodeProtocol)
	}

	st := sc.streams[f.Header().StreamID]
	if st != nil {
		return st.processTrailerHeaders(f)
	}

	// New stream IDs must be strictly increasing, and a previous
	// HEADERS must have been fully continued before a new one starts.
	if id <= sc.maxStreamID || sc.req.stream != nil {
		return http2ConnectionError(http2ErrCodeProtocol)
	}

	// NOTE(review): this condition is always true after the check
	// above; kept as-is to match upstream.
	if id > sc.maxStreamID {
		sc.maxStreamID = id
	}
	st = &http2stream{
		sc:    sc,
		id:    id,
		state: http2stateOpen,
	}
	if f.StreamEnded() {
		st.state = http2stateHalfClosedRemote
	}
	st.cw.Init()

	// Wire up per-stream flow control to the connection-level windows.
	st.flow.conn = &sc.flow
	st.flow.add(sc.initialWindowSize)
	st.inflow.conn = &sc.inflow
	st.inflow.add(http2initialWindowSize)

	sc.streams[id] = st
	if f.HasPriority() {
		http2adjustStreamPriority(sc.streams, st.id, f.Priority)
	}
	sc.curOpenStreams++
	if sc.curOpenStreams == 1 {
		sc.setConnState(StateActive)
	}
	sc.req = http2requestParam{
		stream: st,
		header: make(Header),
	}
	sc.hpackDecoder.SetEmitFunc(sc.onNewHeaderField)
	sc.hpackDecoder.SetEmitEnabled(true)
	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
}

// processTrailerHeaders handles a second HEADERS frame on a stream,
// which carries trailers and must end the stream.
func (st *http2stream) processTrailerHeaders(f *http2HeadersFrame) error {
	sc := st.sc
	sc.serveG.check()
	if st.gotTrailerHeader {
		return http2ConnectionError(http2ErrCodeProtocol)
	}
	st.gotTrailerHeader = true
	if !f.StreamEnded() {
		return http2StreamError{st.id, http2ErrCodeProtocol}
	}
	sc.resetPendingRequest()
	return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
}

// processContinuation feeds a CONTINUATION frame's fragment to the
// header-block decoder for the stream, regular headers or trailers.
func (sc *http2serverConn) processContinuation(f *http2ContinuationFrame) error {
	sc.serveG.check()
	st := sc.streams[f.Header().StreamID]
	if st.gotTrailerHeader {
		return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
	}
	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
}

// processHeaderBlockFragment decodes a request header fragment; when
// the block is complete it builds the Request/ResponseWriter pair and
// starts the handler goroutine.
func (sc *http2serverConn) processHeaderBlockFragment(st *http2stream, frag []byte, end bool) error {
	sc.serveG.check()
	if _, err := sc.hpackDecoder.Write(frag); err != nil {
		return http2ConnectionError(http2ErrCodeCompression)
	}
	if !end {
		return nil
	}
	if err := sc.hpackDecoder.Close(); err != nil {
		return http2ConnectionError(http2ErrCodeCompression)
	}
	defer sc.resetPendingRequest()
	if sc.curOpenStreams > sc.advMaxStreams {
		// The client exceeded our advertised stream limit.
		if sc.unackedSettings == 0 {
			// They should know better; treat as a protocol violation.
			return http2StreamError{st.id, http2ErrCodeProtocol}
		}
		// Our SETTINGS may still be in flight; refuse politely so the
		// client can retry.
		return http2StreamError{st.id, http2ErrCodeRefusedStream}
	}

	rw, req, err := sc.newWriterAndRequest()
	if err != nil {
		return err
	}
	st.reqTrailer = req.Trailer
	if st.reqTrailer != nil {
		st.trailer = make(Header)
	}
	st.body = req.Body.(*http2requestBody).pipe
	st.declBodyBytes = req.ContentLength

	handler := sc.handler.ServeHTTP
	if !sc.hpackDecoder.EmitEnabled() {
		// The header list was too large; serve a 431 instead.
		handler = http2handleHeaderListTooLong
	}

	go sc.runHandler(rw, req, handler)
	return nil
}

// processTrailerHeaderBlockFragment decodes a trailer header fragment;
// when complete it ends the stream.
func (st *http2stream) processTrailerHeaderBlockFragment(frag []byte, end bool) error {
	sc := st.sc
	sc.serveG.check()
	sc.hpackDecoder.SetEmitFunc(st.onNewTrailerField)
	if _, err := sc.hpackDecoder.Write(frag); err != nil {
		return http2ConnectionError(http2ErrCodeCompression)
	}
	if !end {
		return nil
	}

	rp := &sc.req
	if rp.invalidHeader {
		return http2StreamError{rp.stream.id, http2ErrCodeProtocol}
	}

	err := sc.hpackDecoder.Close()
	st.endStream()
	if err != nil {
		return http2ConnectionError(http2ErrCodeCompression)
	}
	return nil
}

// processPriority applies a PRIORITY frame to the stream tree.
func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error {
	http2adjustStreamPriority(sc.streams, f.StreamID, f.http2PriorityParam)
	return nil
}

// http2adjustStreamPriority re-parents streamID in the dependency tree
// per the given priority parameters, breaking any cycle the new parent
// link would create, and handling exclusive dependencies.
func http2adjustStreamPriority(streams map[uint32]*http2stream, streamID uint32, priority http2PriorityParam) {
	st, ok := streams[streamID]
	if !ok {
		// Priority for an unknown stream; nothing to do.
		return
	}
	st.weight = priority.Weight
	parent := streams[priority.StreamDep]
	if parent == st {
		// A stream cannot depend on itself; ignore.
		return
	}

	// If the new parent is a descendant of st, splice it out first to
	// avoid creating a cycle.
	for piter := parent; piter != nil; piter = piter.parent {
		if piter == st {
			parent.parent = st.parent
			break
		}
	}
	st.parent = parent
	if priority.Exclusive && (st.parent != nil || priority.StreamDep == 0) {
		for _, openStream := range streams {
			if openStream != st && openStream.parent == st.parent {
				openStream.parent = st
			}
		}
	}
}

// resetPendingRequest zeros out all state related to a HEADERS frame
// and its zero or more CONTINUATION frames sent to start a new
// request.
func (sc *http2serverConn) resetPendingRequest() {
	sc.serveG.check()
	sc.req = http2requestParam{}
}

// newWriterAndRequest validates the accumulated pseudo-headers and
// builds the *Request and *responseWriter for the handler.
func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request, error) {
	sc.serveG.check()
	rp := &sc.req

	if rp.invalidHeader {
		return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol}
	}

	isConnect := rp.method == "CONNECT"
	if isConnect {
		// CONNECT carries only :authority.
		if rp.path != "" || rp.scheme != "" || rp.authority == "" {
			return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol}
		}
	} else if rp.method == "" || rp.path == "" ||
		(rp.scheme != "https" && rp.scheme != "http") {
		// Mandatory pseudo-headers missing or malformed.
		return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol}
	}

	bodyOpen := rp.stream.state == http2stateOpen
	if rp.method == "HEAD" && bodyOpen {
		// HEAD requests must not have a body.
		return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol}
	}
	var tlsState *tls.ConnectionState // nil if not scheme https

	if rp.scheme == "https" {
		tlsState = sc.tlsState
	}
	authority := rp.authority
	if authority == "" {
		authority = rp.header.Get("Host")
	}
	needsContinue := rp.header.Get("Expect") == "100-continue"
	if needsContinue {
		rp.header.Del("Expect")
	}
	// Merge multiple Cookie headers into one, as net/http expects.
	if cookies := rp.header["Cookie"]; len(cookies) > 1 {
		rp.header.Set("Cookie", strings.Join(cookies, "; "))
	}

	// Setup Trailers
	var trailer Header
	for _, v := range rp.header["Trailer"] {
		for _, key := range strings.Split(v, ",") {
			key = CanonicalHeaderKey(strings.TrimSpace(key))
			switch key {
			case "Transfer-Encoding", "Trailer", "Content-Length":
				// Forbidden as trailers; drop.
			default:
				if trailer == nil {
					trailer = make(Header)
				}
				trailer[key] = nil
			}
		}
	}
	delete(rp.header, "Trailer")

	body := &http2requestBody{
		conn:          sc,
		stream:        rp.stream,
		needsContinue: needsContinue,
	}
	var url_ *url.URL
	var requestURI string
	if isConnect {
		url_ = &url.URL{Host: rp.authority}
		requestURI = rp.authority
	} else {
		var err error
		url_, err = url.ParseRequestURI(rp.path)
		if err != nil {
			return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol}
		}
		requestURI = rp.path
	}
	req := &Request{
		Method:     rp.method,
		URL:        url_,
		RemoteAddr: sc.remoteAddrStr,
		Header:     rp.header,
		RequestURI: requestURI,
		Proto:      "HTTP/2.0",
		ProtoMajor: 2,
		ProtoMinor: 0,
		TLS:        tlsState,
		Host:       authority,
		Body:       body,
		Trailer:    trailer,
	}
	if bodyOpen {
		body.pipe = &http2pipe{
			b: &http2fixedBuffer{buf: make([]byte, http2initialWindowSize)},
		}

		if vv, ok := rp.header["Content-Length"]; ok {
			req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
		} else {
			req.ContentLength = -1
		}
	}

	// Reuse a responseWriterState from the pool, keeping its buffered
	// writer's backing storage.
	rws := http2responseWriterStatePool.Get().(*http2responseWriterState)
	bwSave := rws.bw
	*rws = http2responseWriterState{}
	rws.conn = sc
	rws.bw = bwSave
	rws.bw.Reset(http2chunkWriter{rws})
	rws.stream = rp.stream
	rws.req = req
	rws.body = body

	rw := &http2responseWriter{rws: rws}
	return rw, req, nil
}

// Run on its own goroutine.
func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) {
	didPanic := true
	defer func() {
		if didPanic {
			e := recover()
			// Same as net/http:
			const size = 64 << 10
			buf := make([]byte, size)
			buf = buf[:runtime.Stack(buf, false)]
			// Tell the serve loop to RST_STREAM the panicked stream.
			sc.writeFrameFromHandler(http2frameWriteMsg{
				write:  http2handlerPanicRST{rw.rws.stream.id},
				stream: rw.rws.stream,
			})
			sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf)
			return
		}
		rw.handlerDone()
	}()
	handler(rw, req)
	didPanic = false
}

// http2handleHeaderListTooLong replaces the user's handler when the
// request's header list exceeded our advertised limit.
func http2handleHeaderListTooLong(w ResponseWriter, r *Request) {
	// 10.5.1 Limits on Header Block Size:
	// .. "A server that receives a larger header block than it is
	// willing to handle can send an HTTP 431 (Request Header Fields Too
	// Large) status code"
	const statusRequestHeaderFieldsTooLarge = 431 // only in Go 1.6+
	w.WriteHeader(statusRequestHeaderFieldsTooLarge)
	io.WriteString(w, "<h1>HTTP Error 431</h1><p>Request Header Field(s) Too Large</p>")
}

// writeHeaders queues a HEADERS frame for the stream and, when header
// data is present, blocks until the serve loop reports the write
// result (or the stream/connection goes away).
// called from handler goroutines.
// h may be nil.
func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeResHeaders) error {
	sc.serveG.checkNotOn()
	var errc chan error
	if headerData.h != nil {
		// Waiting for the result matters only when there's actual
		// header data to encode; reuse a pooled channel.
		errc = http2errChanPool.Get().(chan error)
	}
	if err := sc.writeFrameFromHandler(http2frameWriteMsg{
		write:  headerData,
		stream: st,
		done:   errc,
	}); err != nil {
		return err
	}
	if errc != nil {
		select {
		case err := <-errc:
			http2errChanPool.Put(errc)
			return err
		case <-sc.doneServing:
			return http2errClientDisconnected
		case <-st.cw:
			return http2errStreamClosed
		}
	}
	return nil
}

// called from handler goroutines.
func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) {
	sc.writeFrameFromHandler(http2frameWriteMsg{
		write:  http2write100ContinueHeadersFrame{st.id},
		stream: st,
	})
}

// A bodyReadMsg tells the server loop that the http.Handler read n
// bytes of the DATA from the client on the given stream.
type http2bodyReadMsg struct {
	st *http2stream
	n  int
}

// called from handler goroutines.
// Notes that the handler for the given stream ID read n bytes of its body
// and schedules flow control tokens to be sent.
func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int) {
	sc.serveG.checkNotOn()
	select {
	case sc.bodyReadCh <- http2bodyReadMsg{st, n}:
	case <-sc.doneServing:
	}
}

// noteBodyRead (on the serve goroutine) returns flow-control tokens to
// the peer for bytes the handler consumed.
func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) {
	sc.serveG.check()
	sc.sendWindowUpdate(nil, n)
	if st.state != http2stateHalfClosedRemote && st.state != http2stateClosed {
		// The peer can still send on this stream; refresh its window too.
		sc.sendWindowUpdate(st, n)
	}
}

// st may be nil for conn-level
func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) {
	sc.serveG.check()
	// "The legal range for the increment to the flow control
	// window is 1 to 2^31-1 (2,147,483,647) octets."
	// A Go Read call on 64-bit machines could in theory read
	// a larger Read than this. Very unlikely, but we handle it here
	// rather than elsewhere for now.
	const maxUint31 = 1<<31 - 1
	for n >= maxUint31 {
		sc.sendWindowUpdate32(st, maxUint31)
		n -= maxUint31
	}
	sc.sendWindowUpdate32(st, int32(n))
}

// st may be nil for conn-level
func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) {
	sc.serveG.check()
	if n == 0 {
		return
	}
	if n < 0 {
		panic("negative update")
	}
	var streamID uint32
	if st != nil {
		streamID = st.id
	}
	sc.writeFrame(http2frameWriteMsg{
		write:  http2writeWindowUpdate{streamID: streamID, n: uint32(n)},
		stream: st,
	})
	// Mirror the update in our own inbound accounting.
	var ok bool
	if st == nil {
		ok = sc.inflow.add(n)
	} else {
		ok = st.inflow.add(n)
	}
	if !ok {
		panic("internal error; sent too many window updates without decrements?")
	}
}

// http2requestBody is the Request.Body for an HTTP/2 request.
type http2requestBody struct {
	stream        *http2stream
	conn          *http2serverConn
	closed        bool
	pipe          *http2pipe // non-nil if we have a HTTP entity message body
	needsContinue bool       // need to send a 100-continue
}

func (b *http2requestBody) Close() error {
	if b.pipe != nil {
		b.pipe.CloseWithError(http2errClosedBody)
	}
	b.closed = true
	return nil
}

func (b *http2requestBody) Read(p []byte) (n int, err error) {
	// Sending 100-continue is deferred until the handler first reads
	// the body.
	if b.needsContinue {
		b.needsContinue = false
		b.conn.write100ContinueHeaders(b.stream)
	}
	if b.pipe == nil {
		return 0, io.EOF
	}
	n, err = b.pipe.Read(p)
	if n > 0 {
		b.conn.noteBodyReadFromHandler(b.stream, n)
	}
	return
}

// responseWriter is the http.ResponseWriter implementation. It's
// intentionally small (1 pointer wide) to minimize garbage. The
// responseWriterState pointer inside is zeroed at the end of a
// request (in handlerDone) and calls on the responseWriter thereafter
// simply crash (caller's mistake), but the much larger responseWriterState
// and buffers are reused between multiple requests.
type http2responseWriter struct {
	rws *http2responseWriterState
}

// Optional http.ResponseWriter interfaces implemented.
var (
	_ CloseNotifier     = (*http2responseWriter)(nil)
	_ Flusher           = (*http2responseWriter)(nil)
	_ http2stringWriter = (*http2responseWriter)(nil)
)

type http2responseWriterState struct {
	// immutable within a request:
	stream *http2stream
	req    *Request
	body   *http2requestBody // to close at end of request, if DATA frames didn't
	conn   *http2serverConn

	// TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc
	bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState}

	// mutated by http.Handler goroutine:
	handlerHeader Header   // nil until called
	snapHeader    Header   // snapshot of handlerHeader at WriteHeader time
	trailers      []string // set in writeChunk
	status        int      // status code passed to WriteHeader
	wroteHeader   bool     // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
	sentHeader    bool     // have we sent the header frame?
	handlerDone   bool     // handler has finished

	sentContentLen int64 // non-zero if handler set a Content-Length header
	wroteBytes     int64

	closeNotifierMu sync.Mutex // guards closeNotifierCh
	closeNotifierCh chan bool  // nil until first used
}

type http2chunkWriter struct{ rws *http2responseWriterState }

func (cw http2chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) }

func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) != 0 }

// declareTrailer is called for each Trailer header when the
// response header is written. It notes that a header will need to be
// written in the trailers at the end of the response.
func (rws *http2responseWriterState) declareTrailer(k string) {
	k = CanonicalHeaderKey(k)
	switch k {
	case "Transfer-Encoding", "Content-Length", "Trailer":
		// Forbidden by RFC 7230 4.1.2.
		return
	}
	rws.trailers = append(rws.trailers, k)
}

// writeChunk writes chunks from the bufio.Writer. But because
// bufio.Writer may bypass its chunking, sometimes p may be
// arbitrarily large.
//
// writeChunk is also responsible (on the first chunk) for sending the
// HEADER response.
func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) {
	if !rws.wroteHeader {
		rws.writeHeader(200)
	}

	isHeadResp := rws.req.Method == "HEAD"
	if !rws.sentHeader {
		rws.sentHeader = true
		var ctype, clen string
		if clen = rws.snapHeader.Get("Content-Length"); clen != "" {
			rws.snapHeader.Del("Content-Length")
			clen64, err := strconv.ParseInt(clen, 10, 64)
			if err == nil && clen64 >= 0 {
				rws.sentContentLen = clen64
			} else {
				clen = ""
			}
		}
		// If the handler is done and never set Content-Length, we know
		// the full body is p; declare its length.
		if clen == "" && rws.handlerDone && http2bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) {
			clen = strconv.Itoa(len(p))
		}
		_, hasContentType := rws.snapHeader["Content-Type"]
		if !hasContentType && http2bodyAllowedForStatus(rws.status) {
			ctype = DetectContentType(p)
		}
		var date string
		if _, ok := rws.snapHeader["Date"]; !ok {
			// TODO(bradfitz): be faster here, like net/http? measure.
			date = time.Now().UTC().Format(TimeFormat)
		}

		for _, v := range rws.snapHeader["Trailer"] {
			http2foreachHeaderElement(v, rws.declareTrailer)
		}

		endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp
		err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
			streamID:      rws.stream.id,
			httpResCode:   rws.status,
			h:             rws.snapHeader,
			endStream:     endStream,
			contentType:   ctype,
			contentLength: clen,
			date:          date,
		})
		if err != nil {
			return 0, err
		}
		if endStream {
			return 0, nil
		}
	}
	if isHeadResp {
		return len(p), nil
	}
	if len(p) == 0 && !rws.handlerDone {
		return 0, nil
	}

	endStream := rws.handlerDone && !rws.hasTrailers()
	if len(p) > 0 || endStream {
		// Only send a 0-byte DATA frame if we're ending the stream.
		if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
			return 0, err
		}
	}

	if rws.handlerDone && rws.hasTrailers() {
		err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
			streamID:  rws.stream.id,
			h:         rws.handlerHeader,
			trailers:  rws.trailers,
			endStream: true,
		})
		return len(p), err
	}
	return len(p), nil
}

// Flush sends any buffered data; with nothing buffered it still forces
// the response headers out via a nil chunk.
func (w *http2responseWriter) Flush() {
	rws := w.rws
	if rws == nil {
		panic("Header called after Handler finished")
	}
	if rws.bw.Buffered() > 0 {
		if err := rws.bw.Flush(); err != nil {
			// Ignored; the error will resurface on the next Write.
			return
		}
	} else {
		// The bufio.Writer won't call chunkWriter.Write (writeChunk
		// with zero bytes), so we have to do it ourselves to force the
		// HTTP response header and/or final DATA frame (with END_STREAM) out.
		rws.writeChunk(nil)
	}
}

func (w *http2responseWriter) CloseNotify() <-chan bool {
	rws := w.rws
	if rws == nil {
		panic("CloseNotify called after Handler finished")
	}
	rws.closeNotifierMu.Lock()
	ch := rws.closeNotifierCh
	if ch == nil {
		ch = make(chan bool, 1)
		rws.closeNotifierCh = ch
		go func() {
			rws.stream.cw.Wait()
			ch <- true
		}()
	}
	rws.closeNotifierMu.Unlock()
	return ch
}

func (w *http2responseWriter) Header() Header {
	rws := w.rws
	if rws == nil {
		panic("Header called after Handler finished")
	}
	if rws.handlerHeader == nil {
		rws.handlerHeader = make(Header)
	}
	return rws.handlerHeader
}

func (w *http2responseWriter) WriteHeader(code int) {
	rws := w.rws
	if rws == nil {
		panic("WriteHeader called after Handler finished")
	}
	rws.writeHeader(code)
}

func (rws *http2responseWriterState) writeHeader(code int) {
	if !rws.wroteHeader {
		rws.wroteHeader = true
		rws.status = code
		if len(rws.handlerHeader) > 0 {
			// Snapshot so later handler mutations don't affect the
			// headers actually sent.
			rws.snapHeader = http2cloneHeader(rws.handlerHeader)
		}
	}
}

// http2cloneHeader returns a deep copy of h.
func http2cloneHeader(h Header) Header {
	h2 := make(Header, len(h))
	for k, vv := range h {
		vv2 := make([]string, len(vv))
		copy(vv2, vv)
		h2[k] = vv2
	}
	return h2
}

// The Life Of A Write is like this:
//
// * Handler calls w.Write or w.WriteString ->
// * -> rws.bw (*bufio.Writer) ->
// * (Handler might call Flush)
// * -> chunkWriter{rws}
// * -> responseWriterState.writeChunk(p []byte)
// * -> responseWriterState.writeChunk (most of the magic; see comment there)
func (w *http2responseWriter) Write(p []byte) (n int, err error) {
	return w.write(len(p), p, "")
}

func (w *http2responseWriter) WriteString(s string) (n int, err error) {
	return w.write(len(s), nil, s)
}

// either dataB or dataS is non-zero.
func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) {
	rws := w.rws
	if rws == nil {
		panic("Write called after Handler finished")
	}
	if !rws.wroteHeader {
		w.WriteHeader(200)
	}
	if !http2bodyAllowedForStatus(rws.status) {
		return 0, ErrBodyNotAllowed
	}
	rws.wroteBytes += int64(len(dataB)) + int64(len(dataS))
	if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen {
		// TODO: send a RST_STREAM? arguably the write was partial.
		return 0, errors.New("http2: handler wrote more than declared Content-Length")
	}

	if dataB != nil {
		return rws.bw.Write(dataB)
	} else {
		return rws.bw.WriteString(dataS)
	}
}

// handlerDone flushes the remaining response, detaches the
// responseWriterState, and returns it to the pool for reuse.
func (w *http2responseWriter) handlerDone() {
	rws := w.rws
	rws.handlerDone = true
	w.Flush()
	w.rws = nil
	http2responseWriterStatePool.Put(rws)
}

// foreachHeaderElement splits v according to the "#rule" construction
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
func http2foreachHeaderElement(v string, fn func(string)) {
	v = textproto.TrimString(v)
	if v == "" {
		return
	}
	if !strings.Contains(v, ",") {
		fn(v)
		return
	}
	for _, f := range strings.Split(v, ",") {
		if f = textproto.TrimString(f); f != "" {
			fn(f)
		}
	}
}

const (
	// transportDefaultConnFlow is how many connection-level flow control
	// tokens we give the server at start-up, past the default 64k.
	http2transportDefaultConnFlow = 1 << 30

	// transportDefaultStreamFlow is how many stream-level flow
	// control tokens we announce to the peer, and how many bytes
	// we buffer per stream.
	http2transportDefaultStreamFlow = 4 << 20

	// transportDefaultStreamMinRefresh is the minimum number of bytes we'll send
	// a stream-level WINDOW_UPDATE for at a time.
	http2transportDefaultStreamMinRefresh = 4 << 10

	http2defaultUserAgent = "Go-http-client/2.0"
)

// Transport is an HTTP/2 Transport.
//
// A Transport internally caches connections to servers. It is safe
// for concurrent use by multiple goroutines.
type http2Transport struct {
	// DialTLS specifies an optional dial function for creating
	// TLS connections for requests.
	//
	// If DialTLS is nil, tls.Dial is used.
	//
	// If the returned net.Conn has a ConnectionState method like tls.Conn,
	// it will be used to set http.Response.TLS.
	DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error)

	// TLSClientConfig specifies the TLS configuration to use with
	// tls.Client. If nil, the default configuration is used.
	TLSClientConfig *tls.Config

	// ConnPool optionally specifies an alternate connection pool to use.
	// If nil, the default is used.
	ConnPool http2ClientConnPool

	// DisableCompression, if true, prevents the Transport from
	// requesting compression with an "Accept-Encoding: gzip"
	// request header when the Request contains no existing
	// Accept-Encoding value. If the Transport requests gzip on
	// its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body. However, if the user
	// explicitly requested gzip it is not automatically
	// uncompressed.
	DisableCompression bool

	// MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to
	// send in the initial settings frame. It is how many bytes
	// of response headers are allowed. Unlike the http2 spec, zero here
	// means to use a default limit (currently 10MB). If you actually
	// want to advertise an unlimited value to the peer, Transport
	// interprets the highest possible value here (0xffffffff or 1<<32-1)
	// to mean no limit.
	MaxHeaderListSize uint32

	// t1, if non-nil, is the standard library Transport using
	// this transport. Its settings are used (but not its
	// RoundTrip method, etc).
	t1 *Transport

	connPoolOnce  sync.Once
	connPoolOrDef http2ClientConnPool // non-nil version of ConnPool
}

// maxHeaderListSize maps the exported field's sentinel values:
// 0 means the 10MB default, 0xffffffff means unlimited (0 on the wire).
func (t *http2Transport) maxHeaderListSize() uint32 {
	if t.MaxHeaderListSize == 0 {
		return 10 << 20
	}
	if t.MaxHeaderListSize == 0xffffffff {
		return 0
	}
	return t.MaxHeaderListSize
}

func (t *http2Transport) disableCompression() bool {
	return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
}

var http2errTransportVersion = errors.New("http2: ConfigureTransport is only supported starting at Go 1.6")

// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
// It requires Go 1.6 or later and returns an error if the net/http package is too old
// or if t1 has already been HTTP/2-enabled.
func http2ConfigureTransport(t1 *Transport) error {
	_, err := http2configureTransport(t1)
	return err
}

func (t *http2Transport) connPool() http2ClientConnPool {
	t.connPoolOnce.Do(t.initConnPool)
	return t.connPoolOrDef
}

func (t *http2Transport) initConnPool() {
	if t.ConnPool != nil {
		t.connPoolOrDef = t.ConnPool
	} else {
		t.connPoolOrDef = &http2clientConnPool{t: t}
	}
}

// ClientConn is the state of a single HTTP/2 client connection to an
// HTTP/2 server.
type http2ClientConn struct {
	t        *http2Transport
	tconn    net.Conn             // usually *tls.Conn, except specialized impls
	tlsState *tls.ConnectionState // nil only for specialized impls

	// readLoop goroutine fields:
	readerDone chan struct{} // closed on error
	readerErr  error         // set before readerDone is closed

	mu           sync.Mutex // guards following
	cond         *sync.Cond // hold mu; broadcast on flow/closed changes
	flow         http2flow  // our conn-level flow control quota (cs.flow is per stream)
	inflow       http2flow  // peer's conn-level flow control
	closed       bool
	goAway       *http2GoAwayFrame             // if non-nil, the GoAwayFrame we received
	streams      map[uint32]*http2clientStream // client-initiated
	nextStreamID uint32
	bw           *bufio.Writer
	br           *bufio.Reader
	fr           *http2Framer
	// Settings from peer:
	maxFrameSize         uint32
	maxConcurrentStreams uint32
	initialWindowSize    uint32
	hbuf                 bytes.Buffer // HPACK encoder writes into this
	henc                 *hpack.Encoder
	freeBuf              [][]byte

	wmu  sync.Mutex // held while writing; acquire AFTER mu if holding both
	werr error      // first write error that has occurred
}

// clientStream is the state for a single HTTP/2 stream. One of these
// is created for each Transport.RoundTrip call.
+type http2clientStream struct { + cc *http2ClientConn + req *Request + ID uint32 + resc chan http2resAndError + bufPipe http2pipe // buffered pipe with the flow-controlled response payload + requestedGzip bool + + flow http2flow // guarded by cc.mu + inflow http2flow // guarded by cc.mu + bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read + readErr error // sticky read error; owned by transportResponseBody.Read + stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu + + peerReset chan struct{} // closed on peer reset + resetErr error // populated before peerReset is closed + + done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu + + // owned by clientConnReadLoop: + pastHeaders bool // got HEADERS w/ END_HEADERS + pastTrailers bool // got second HEADERS frame w/ END_HEADERS + + trailer Header // accumulated trailers + resTrailer *Header // client's Response.Trailer +} + +// awaitRequestCancel runs in its own goroutine and waits for the user +// to either cancel a RoundTrip request (using the provided +// Request.Cancel channel), or for the request to be done (any way it +// might be removed from the cc.streams map: peer reset, successful +// completion, TCP connection breakage, etc) +func (cs *http2clientStream) awaitRequestCancel(cancel <-chan struct{}) { + if cancel == nil { + return + } + select { + case <-cancel: + cs.bufPipe.CloseWithError(http2errRequestCanceled) + cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + case <-cs.done: + } +} + +// checkReset reports any error sent in a RST_STREAM frame by the +// server. 
+func (cs *http2clientStream) checkReset() error { + select { + case <-cs.peerReset: + return cs.resetErr + default: + return nil + } +} + +func (cs *http2clientStream) abortRequestBodyWrite(err error) { + if err == nil { + panic("nil error") + } + cc := cs.cc + cc.mu.Lock() + cs.stopReqBody = err + cc.cond.Broadcast() + cc.mu.Unlock() +} + +type http2stickyErrWriter struct { + w io.Writer + err *error +} + +func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) { + if *sew.err != nil { + return 0, *sew.err + } + n, err = sew.w.Write(p) + *sew.err = err + return +} + +var http2ErrNoCachedConn = errors.New("http2: no cached connection was available") + +// RoundTripOpt are options for the Transport.RoundTripOpt method. +type http2RoundTripOpt struct { + // OnlyCachedConn controls whether RoundTripOpt may + // create a new TCP connection. If set true and + // no cached connection is available, RoundTripOpt + // will return ErrNoCachedConn. + OnlyCachedConn bool +} + +func (t *http2Transport) RoundTrip(req *Request) (*Response, error) { + return t.RoundTripOpt(req, http2RoundTripOpt{}) +} + +// authorityAddr returns a given authority (a host/IP, or host:port / ip:port) +// and returns a host:port. The port 443 is added if needed. +func http2authorityAddr(authority string) (addr string) { + if _, _, err := net.SplitHostPort(authority); err == nil { + return authority + } + return net.JoinHostPort(authority, "443") +} + +// RoundTripOpt is like RoundTrip, but takes options. 
+func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) { + if req.URL.Scheme != "https" { + return nil, errors.New("http2: unsupported scheme") + } + + addr := http2authorityAddr(req.URL.Host) + for { + cc, err := t.connPool().GetClientConn(req, addr) + if err != nil { + t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err) + return nil, err + } + res, err := cc.RoundTrip(req) + if http2shouldRetryRequest(req, err) { + continue + } + if err != nil { + t.vlogf("RoundTrip failure: %v", err) + return nil, err + } + return res, nil + } +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle. +// It does not interrupt any connections currently in use. +func (t *http2Transport) CloseIdleConnections() { + if cp, ok := t.connPool().(*http2clientConnPool); ok { + cp.closeIdleConnections() + } +} + +var ( + http2errClientConnClosed = errors.New("http2: client conn is closed") + http2errClientConnUnusable = errors.New("http2: client conn not usable") +) + +func http2shouldRetryRequest(req *Request, err error) bool { + + return err == http2errClientConnUnusable +} + +func (t *http2Transport) dialClientConn(addr string) (*http2ClientConn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host)) + if err != nil { + return nil, err + } + return t.NewClientConn(tconn) +} + +func (t *http2Transport) newTLSConfig(host string) *tls.Config { + cfg := new(tls.Config) + if t.TLSClientConfig != nil { + *cfg = *t.TLSClientConfig + } + cfg.NextProtos = []string{http2NextProtoTLS} + cfg.ServerName = host + return cfg +} + +func (t *http2Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) { + if t.DialTLS != nil { + return t.DialTLS + } + return t.dialTLSDefault +} + +func (t *http2Transport) dialTLSDefault(network, addr string, cfg 
*tls.Config) (net.Conn, error) { + cn, err := tls.Dial(network, addr, cfg) + if err != nil { + return nil, err + } + if err := cn.Handshake(); err != nil { + return nil, err + } + if !cfg.InsecureSkipVerify { + if err := cn.VerifyHostname(cfg.ServerName); err != nil { + return nil, err + } + } + state := cn.ConnectionState() + if p := state.NegotiatedProtocol; p != http2NextProtoTLS { + return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2NextProtoTLS) + } + if !state.NegotiatedProtocolIsMutual { + return nil, errors.New("http2: could not negotiate protocol mutually") + } + return cn, nil +} + +// disableKeepAlives reports whether connections should be closed as +// soon as possible after handling the first request. +func (t *http2Transport) disableKeepAlives() bool { + return t.t1 != nil && t.t1.DisableKeepAlives +} + +func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { + if http2VerboseLogs { + t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr()) + } + if _, err := c.Write(http2clientPreface); err != nil { + t.vlogf("client preface write error: %v", err) + return nil, err + } + + cc := &http2ClientConn{ + t: t, + tconn: c, + readerDone: make(chan struct{}), + nextStreamID: 1, + maxFrameSize: 16 << 10, + initialWindowSize: 65535, + maxConcurrentStreams: 1000, + streams: make(map[uint32]*http2clientStream), + } + cc.cond = sync.NewCond(&cc.mu) + cc.flow.add(int32(http2initialWindowSize)) + + cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr}) + cc.br = bufio.NewReader(c) + cc.fr = http2NewFramer(cc.bw, cc.br) + + cc.henc = hpack.NewEncoder(&cc.hbuf) + + type connectionStater interface { + ConnectionState() tls.ConnectionState + } + if cs, ok := c.(connectionStater); ok { + state := cs.ConnectionState() + cc.tlsState = &state + } + + initialSettings := []http2Setting{ + http2Setting{ID: http2SettingEnablePush, Val: 0}, + http2Setting{ID: http2SettingInitialWindowSize, Val: 
http2transportDefaultStreamFlow}, + } + if max := t.maxHeaderListSize(); max != 0 { + initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max}) + } + cc.fr.WriteSettings(initialSettings...) + cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow) + cc.inflow.add(http2transportDefaultConnFlow + http2initialWindowSize) + cc.bw.Flush() + if cc.werr != nil { + return nil, cc.werr + } + + f, err := cc.fr.ReadFrame() + if err != nil { + return nil, err + } + sf, ok := f.(*http2SettingsFrame) + if !ok { + return nil, fmt.Errorf("expected settings frame, got: %T", f) + } + cc.fr.WriteSettingsAck() + cc.bw.Flush() + + sf.ForeachSetting(func(s http2Setting) error { + switch s.ID { + case http2SettingMaxFrameSize: + cc.maxFrameSize = s.Val + case http2SettingMaxConcurrentStreams: + cc.maxConcurrentStreams = s.Val + case http2SettingInitialWindowSize: + cc.initialWindowSize = s.Val + default: + + t.vlogf("Unhandled Setting: %v", s) + } + return nil + }) + + go cc.readLoop() + return cc, nil +} + +func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.goAway = f +} + +func (cc *http2ClientConn) CanTakeNewRequest() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.canTakeNewRequestLocked() +} + +func (cc *http2ClientConn) canTakeNewRequestLocked() bool { + return cc.goAway == nil && !cc.closed && + int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) && + cc.nextStreamID < 2147483647 +} + +func (cc *http2ClientConn) closeIfIdle() { + cc.mu.Lock() + if len(cc.streams) > 0 { + cc.mu.Unlock() + return + } + cc.closed = true + + cc.mu.Unlock() + + cc.tconn.Close() +} + +const http2maxAllocFrameSize = 512 << 10 + +// frameBuffer returns a scratch buffer suitable for writing DATA frames. +// They're capped at the min of the peer's max frame size or 512KB +// (kinda arbitrarily), but definitely capped so we don't allocate 4GB +// bufers. 
+func (cc *http2ClientConn) frameScratchBuffer() []byte { + cc.mu.Lock() + size := cc.maxFrameSize + if size > http2maxAllocFrameSize { + size = http2maxAllocFrameSize + } + for i, buf := range cc.freeBuf { + if len(buf) >= int(size) { + cc.freeBuf[i] = nil + cc.mu.Unlock() + return buf[:size] + } + } + cc.mu.Unlock() + return make([]byte, size) +} + +func (cc *http2ClientConn) putFrameScratchBuffer(buf []byte) { + cc.mu.Lock() + defer cc.mu.Unlock() + const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate. + if len(cc.freeBuf) < maxBufs { + cc.freeBuf = append(cc.freeBuf, buf) + return + } + for i, old := range cc.freeBuf { + if old == nil { + cc.freeBuf[i] = buf + return + } + } + +} + +// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not +// exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests. +var http2errRequestCanceled = errors.New("net/http: request canceled") + +func http2commaSeparatedTrailers(req *Request) (string, error) { + keys := make([]string, 0, len(req.Trailer)) + for k := range req.Trailer { + k = CanonicalHeaderKey(k) + switch k { + case "Transfer-Encoding", "Trailer", "Content-Length": + return "", &http2badStringError{"invalid Trailer key", k} + } + keys = append(keys, k) + } + if len(keys) > 0 { + sort.Strings(keys) + + return strings.Join(keys, ","), nil + } + return "", nil +} + +func (cc *http2ClientConn) responseHeaderTimeout() time.Duration { + if cc.t.t1 != nil { + return cc.t.t1.ResponseHeaderTimeout + } + + return 0 +} + +func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { + trailers, err := http2commaSeparatedTrailers(req) + if err != nil { + return nil, err + } + hasTrailers := trailers != "" + + var body io.Reader = req.Body + contentLen := req.ContentLength + if req.Body != nil && contentLen == 0 { + // Test to see if it's actually zero or just unset. 
+ var buf [1]byte + n, rerr := io.ReadFull(body, buf[:]) + if rerr != nil && rerr != io.EOF { + contentLen = -1 + body = http2errorReader{rerr} + } else if n == 1 { + + contentLen = -1 + body = io.MultiReader(bytes.NewReader(buf[:]), body) + } else { + + body = nil + } + } + + cc.mu.Lock() + if cc.closed || !cc.canTakeNewRequestLocked() { + cc.mu.Unlock() + return nil, http2errClientConnUnusable + } + + cs := cc.newStream() + cs.req = req + hasBody := body != nil + + if !cc.t.disableCompression() && + req.Header.Get("Accept-Encoding") == "" && + req.Header.Get("Range") == "" && + req.Method != "HEAD" { + + cs.requestedGzip = true + } + + hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen) + cc.wmu.Lock() + endStream := !hasBody && !hasTrailers + werr := cc.writeHeaders(cs.ID, endStream, hdrs) + cc.wmu.Unlock() + cc.mu.Unlock() + + if werr != nil { + if hasBody { + req.Body.Close() + } + cc.forgetStreamID(cs.ID) + + return nil, werr + } + + var respHeaderTimer <-chan time.Time + var bodyCopyErrc chan error // result of body copy + if hasBody { + bodyCopyErrc = make(chan error, 1) + go func() { + bodyCopyErrc <- cs.writeRequestBody(body, req.Body) + }() + } else { + if d := cc.responseHeaderTimeout(); d != 0 { + timer := time.NewTimer(d) + defer timer.Stop() + respHeaderTimer = timer.C + } + } + + readLoopResCh := cs.resc + requestCanceledCh := http2requestCancel(req) + bodyWritten := false + + for { + select { + case re := <-readLoopResCh: + res := re.res + if re.err != nil || res.StatusCode > 299 { + + cs.abortRequestBodyWrite(http2errStopReqBodyWrite) + } + if re.err != nil { + cc.forgetStreamID(cs.ID) + return nil, re.err + } + res.Request = req + res.TLS = cc.tlsState + return res, nil + case <-respHeaderTimer: + cc.forgetStreamID(cs.ID) + if !hasBody || bodyWritten { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + } else { + cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) + } + return nil, http2errTimeout + case 
<-requestCanceledCh: + cc.forgetStreamID(cs.ID) + if !hasBody || bodyWritten { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + } else { + cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) + } + return nil, http2errRequestCanceled + case <-cs.peerReset: + + return nil, cs.resetErr + case err := <-bodyCopyErrc: + if err != nil { + return nil, err + } + bodyWritten = true + if d := cc.responseHeaderTimeout(); d != 0 { + timer := time.NewTimer(d) + defer timer.Stop() + respHeaderTimer = timer.C + } + } + } +} + +// requires cc.wmu be held +func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []byte) error { + first := true + frameSize := int(cc.maxFrameSize) + for len(hdrs) > 0 && cc.werr == nil { + chunk := hdrs + if len(chunk) > frameSize { + chunk = chunk[:frameSize] + } + hdrs = hdrs[len(chunk):] + endHeaders := len(hdrs) == 0 + if first { + cc.fr.WriteHeaders(http2HeadersFrameParam{ + StreamID: streamID, + BlockFragment: chunk, + EndStream: endStream, + EndHeaders: endHeaders, + }) + first = false + } else { + cc.fr.WriteContinuation(streamID, endHeaders, chunk) + } + } + + cc.bw.Flush() + return cc.werr +} + +// internal error values; they don't escape to callers +var ( + // abort request body write; don't send cancel + http2errStopReqBodyWrite = errors.New("http2: aborting request body write") + + // abort request body write, but send stream reset of cancel. 
+ http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request") +) + +func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) { + cc := cs.cc + sentEnd := false + buf := cc.frameScratchBuffer() + defer cc.putFrameScratchBuffer(buf) + + defer func() { + + cerr := bodyCloser.Close() + if err == nil { + err = cerr + } + }() + + req := cs.req + hasTrailers := req.Trailer != nil + + var sawEOF bool + for !sawEOF { + n, err := body.Read(buf) + if err == io.EOF { + sawEOF = true + err = nil + } else if err != nil { + return err + } + + remain := buf[:n] + for len(remain) > 0 && err == nil { + var allowed int32 + allowed, err = cs.awaitFlowControl(len(remain)) + switch { + case err == http2errStopReqBodyWrite: + return err + case err == http2errStopReqBodyWriteAndCancel: + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + return err + case err != nil: + return err + } + cc.wmu.Lock() + data := remain[:allowed] + remain = remain[allowed:] + sentEnd = sawEOF && len(remain) == 0 && !hasTrailers + err = cc.fr.WriteData(cs.ID, sentEnd, data) + if err == nil { + + err = cc.bw.Flush() + } + cc.wmu.Unlock() + } + if err != nil { + return err + } + } + + cc.wmu.Lock() + if !sentEnd { + var trls []byte + if hasTrailers { + cc.mu.Lock() + trls = cc.encodeTrailers(req) + cc.mu.Unlock() + } + + if len(trls) > 0 { + err = cc.writeHeaders(cs.ID, true, trls) + } else { + err = cc.fr.WriteData(cs.ID, true, nil) + } + } + if ferr := cc.bw.Flush(); ferr != nil && err == nil { + err = ferr + } + cc.wmu.Unlock() + + return err +} + +// awaitFlowControl waits for [1, min(maxBytes, cc.cs.maxFrameSize)] flow +// control tokens from the server. +// It returns either the non-zero number of tokens taken or an error +// if the stream is dead. 
+func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) { + cc := cs.cc + cc.mu.Lock() + defer cc.mu.Unlock() + for { + if cc.closed { + return 0, http2errClientConnClosed + } + if cs.stopReqBody != nil { + return 0, cs.stopReqBody + } + if err := cs.checkReset(); err != nil { + return 0, err + } + if a := cs.flow.available(); a > 0 { + take := a + if int(take) > maxBytes { + + take = int32(maxBytes) + } + if take > int32(cc.maxFrameSize) { + take = int32(cc.maxFrameSize) + } + cs.flow.take(take) + return take, nil + } + cc.cond.Wait() + } +} + +type http2badStringError struct { + what string + str string +} + +func (e *http2badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) } + +// requires cc.mu be held. +func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) []byte { + cc.hbuf.Reset() + + host := req.Host + if host == "" { + host = req.URL.Host + } + + cc.writeHeader(":authority", host) + cc.writeHeader(":method", req.Method) + if req.Method != "CONNECT" { + cc.writeHeader(":path", req.URL.RequestURI()) + cc.writeHeader(":scheme", "https") + } + if trailers != "" { + cc.writeHeader("trailer", trailers) + } + + var didUA bool + for k, vv := range req.Header { + lowKey := strings.ToLower(k) + if lowKey == "host" || lowKey == "content-length" { + continue + } + if lowKey == "user-agent" { + + didUA = true + if len(vv) < 1 { + continue + } + vv = vv[:1] + if vv[0] == "" { + continue + } + } + for _, v := range vv { + cc.writeHeader(lowKey, v) + } + } + if http2shouldSendReqContentLength(req.Method, contentLength) { + cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10)) + } + if addGzipHeader { + cc.writeHeader("accept-encoding", "gzip") + } + if !didUA { + cc.writeHeader("user-agent", http2defaultUserAgent) + } + return cc.hbuf.Bytes() +} + +// shouldSendReqContentLength reports whether the http2.Transport should send +// a 
"content-length" request header. This logic is basically a copy of the net/http +// transferWriter.shouldSendContentLength. +// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). +// -1 means unknown. +func http2shouldSendReqContentLength(method string, contentLength int64) bool { + if contentLength > 0 { + return true + } + if contentLength < 0 { + return false + } + + switch method { + case "POST", "PUT", "PATCH": + return true + default: + return false + } +} + +// requires cc.mu be held. +func (cc *http2ClientConn) encodeTrailers(req *Request) []byte { + cc.hbuf.Reset() + for k, vv := range req.Trailer { + + lowKey := strings.ToLower(k) + for _, v := range vv { + cc.writeHeader(lowKey, v) + } + } + return cc.hbuf.Bytes() +} + +func (cc *http2ClientConn) writeHeader(name, value string) { + if http2VerboseLogs { + log.Printf("http2: Transport encoding header %q = %q", name, value) + } + cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) +} + +type http2resAndError struct { + res *Response + err error +} + +// requires cc.mu be held. +func (cc *http2ClientConn) newStream() *http2clientStream { + cs := &http2clientStream{ + cc: cc, + ID: cc.nextStreamID, + resc: make(chan http2resAndError, 1), + peerReset: make(chan struct{}), + done: make(chan struct{}), + } + cs.flow.add(int32(cc.initialWindowSize)) + cs.flow.setConnFlow(&cc.flow) + cs.inflow.add(http2transportDefaultStreamFlow) + cs.inflow.setConnFlow(&cc.inflow) + cc.nextStreamID += 2 + cc.streams[cs.ID] = cs + return cs +} + +func (cc *http2ClientConn) forgetStreamID(id uint32) { + cc.streamByID(id, true) +} + +func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStream { + cc.mu.Lock() + defer cc.mu.Unlock() + cs := cc.streams[id] + if andRemove && cs != nil && !cc.closed { + delete(cc.streams, id) + close(cs.done) + } + return cs +} + +// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. 
+type http2clientConnReadLoop struct { + cc *http2ClientConn + activeRes map[uint32]*http2clientStream // keyed by streamID + + hdec *hpack.Decoder + + // Fields reset on each HEADERS: + nextRes *Response + sawRegHeader bool // saw non-pseudo header + reqMalformed error // non-nil once known to be malformed + lastHeaderEndsStream bool + headerListSize int64 // actually uint32, but easier math this way +} + +// readLoop runs in its own goroutine and reads and dispatches frames. +func (cc *http2ClientConn) readLoop() { + rl := &http2clientConnReadLoop{ + cc: cc, + activeRes: make(map[uint32]*http2clientStream), + } + rl.hdec = hpack.NewDecoder(http2initialHeaderTableSize, rl.onNewHeaderField) + + defer rl.cleanup() + cc.readerErr = rl.run() + if ce, ok := cc.readerErr.(http2ConnectionError); ok { + cc.wmu.Lock() + cc.fr.WriteGoAway(0, http2ErrCode(ce), nil) + cc.wmu.Unlock() + } +} + +func (rl *http2clientConnReadLoop) cleanup() { + cc := rl.cc + defer cc.tconn.Close() + defer cc.t.connPool().MarkDead(cc) + defer close(cc.readerDone) + + err := cc.readerErr + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + cc.mu.Lock() + for _, cs := range rl.activeRes { + cs.bufPipe.CloseWithError(err) + } + for _, cs := range cc.streams { + select { + case cs.resc <- http2resAndError{err: err}: + default: + } + close(cs.done) + } + cc.closed = true + cc.cond.Broadcast() + cc.mu.Unlock() +} + +func (rl *http2clientConnReadLoop) run() error { + cc := rl.cc + closeWhenIdle := cc.t.disableKeepAlives() + gotReply := false + for { + f, err := cc.fr.ReadFrame() + if err != nil { + cc.vlogf("Transport readFrame error: (%T) %v", err, err) + } + if se, ok := err.(http2StreamError); ok { + + return se + } else if err != nil { + return err + } + if http2VerboseLogs { + cc.vlogf("http2: Transport received %s", http2summarizeFrame(f)) + } + maybeIdle := false + + switch f := f.(type) { + case *http2HeadersFrame: + err = rl.processHeaders(f) + maybeIdle = true + gotReply = true + case 
*http2ContinuationFrame: + err = rl.processContinuation(f) + maybeIdle = true + case *http2DataFrame: + err = rl.processData(f) + maybeIdle = true + case *http2GoAwayFrame: + err = rl.processGoAway(f) + maybeIdle = true + case *http2RSTStreamFrame: + err = rl.processResetStream(f) + maybeIdle = true + case *http2SettingsFrame: + err = rl.processSettings(f) + case *http2PushPromiseFrame: + err = rl.processPushPromise(f) + case *http2WindowUpdateFrame: + err = rl.processWindowUpdate(f) + case *http2PingFrame: + err = rl.processPing(f) + default: + cc.logf("Transport: unhandled response frame type %T", f) + } + if err != nil { + return err + } + if closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 { + cc.closeIfIdle() + } + } +} + +func (rl *http2clientConnReadLoop) processHeaders(f *http2HeadersFrame) error { + rl.sawRegHeader = false + rl.reqMalformed = nil + rl.lastHeaderEndsStream = f.StreamEnded() + rl.headerListSize = 0 + rl.nextRes = &Response{ + Proto: "HTTP/2.0", + ProtoMajor: 2, + Header: make(Header), + } + rl.hdec.SetEmitEnabled(true) + return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded()) +} + +func (rl *http2clientConnReadLoop) processContinuation(f *http2ContinuationFrame) error { + return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded()) +} + +func (rl *http2clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID uint32, finalFrag bool) error { + cc := rl.cc + streamEnded := rl.lastHeaderEndsStream + cs := cc.streamByID(streamID, streamEnded && finalFrag) + if cs == nil { + + return nil + } + if cs.pastHeaders { + rl.hdec.SetEmitFunc(func(f hpack.HeaderField) { rl.onNewTrailerField(cs, f) }) + } else { + rl.hdec.SetEmitFunc(rl.onNewHeaderField) + } + _, err := rl.hdec.Write(frag) + if err != nil { + return http2ConnectionError(http2ErrCodeCompression) + } + if finalFrag { + if err := rl.hdec.Close(); err != nil { + return 
http2ConnectionError(http2ErrCodeCompression) + } + } + + if !finalFrag { + return nil + } + + if !cs.pastHeaders { + cs.pastHeaders = true + } else { + + if cs.pastTrailers { + + return http2ConnectionError(http2ErrCodeProtocol) + } + cs.pastTrailers = true + if !streamEnded { + + return http2ConnectionError(http2ErrCodeProtocol) + } + rl.endStream(cs) + return nil + } + + if rl.reqMalformed != nil { + cs.resc <- http2resAndError{err: rl.reqMalformed} + rl.cc.writeStreamReset(cs.ID, http2ErrCodeProtocol, rl.reqMalformed) + return nil + } + + res := rl.nextRes + + if res.StatusCode == 100 { + + cs.pastHeaders = false + return nil + } + + if !streamEnded || cs.req.Method == "HEAD" { + res.ContentLength = -1 + if clens := res.Header["Content-Length"]; len(clens) == 1 { + if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { + res.ContentLength = clen64 + } else { + + } + } else if len(clens) > 1 { + + } + } + + if streamEnded { + res.Body = http2noBody + } else { + buf := new(bytes.Buffer) + cs.bufPipe = http2pipe{b: buf} + cs.bytesRemain = res.ContentLength + res.Body = http2transportResponseBody{cs} + go cs.awaitRequestCancel(http2requestCancel(cs.req)) + + if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = &http2gzipReader{body: res.Body} + } + } + + cs.resTrailer = &res.Trailer + rl.activeRes[cs.ID] = cs + cs.resc <- http2resAndError{res: res} + rl.nextRes = nil + return nil +} + +// transportResponseBody is the concrete type of Transport.RoundTrip's +// Response.Body. It is an io.ReadCloser. On Read, it reads from cs.body. +// On Close it sends RST_STREAM if EOF wasn't already seen. 
+type http2transportResponseBody struct { + cs *http2clientStream +} + +func (b http2transportResponseBody) Read(p []byte) (n int, err error) { + cs := b.cs + cc := cs.cc + + if cs.readErr != nil { + return 0, cs.readErr + } + n, err = b.cs.bufPipe.Read(p) + if cs.bytesRemain != -1 { + if int64(n) > cs.bytesRemain { + n = int(cs.bytesRemain) + if err == nil { + err = errors.New("net/http: server replied with more than declared Content-Length; truncated") + cc.writeStreamReset(cs.ID, http2ErrCodeProtocol, err) + } + cs.readErr = err + return int(cs.bytesRemain), err + } + cs.bytesRemain -= int64(n) + if err == io.EOF && cs.bytesRemain > 0 { + err = io.ErrUnexpectedEOF + cs.readErr = err + return n, err + } + } + if n == 0 { + + return + } + + cc.mu.Lock() + defer cc.mu.Unlock() + + var connAdd, streamAdd int32 + + if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 { + connAdd = http2transportDefaultConnFlow - v + cc.inflow.add(connAdd) + } + if err == nil { + if v := cs.inflow.available(); v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { + streamAdd = http2transportDefaultStreamFlow - v + cs.inflow.add(streamAdd) + } + } + if connAdd != 0 || streamAdd != 0 { + cc.wmu.Lock() + defer cc.wmu.Unlock() + if connAdd != 0 { + cc.fr.WriteWindowUpdate(0, http2mustUint31(connAdd)) + } + if streamAdd != 0 { + cc.fr.WriteWindowUpdate(cs.ID, http2mustUint31(streamAdd)) + } + cc.bw.Flush() + } + return +} + +var http2errClosedResponseBody = errors.New("http2: response body closed") + +func (b http2transportResponseBody) Close() error { + cs := b.cs + if cs.bufPipe.Err() != io.EOF { + + cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + } + cs.bufPipe.BreakWithError(http2errClosedResponseBody) + return nil +} + +func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { + cc := rl.cc + cs := cc.streamByID(f.StreamID, f.StreamEnded()) + if cs == nil { + cc.mu.Lock() + neverSent := cc.nextStreamID + cc.mu.Unlock() + if 
f.StreamID >= neverSent { + + cc.logf("http2: Transport received unsolicited DATA frame; closing connection") + return http2ConnectionError(http2ErrCodeProtocol) + } + + return nil + } + if data := f.Data(); len(data) > 0 { + if cs.bufPipe.b == nil { + + cc.logf("http2: Transport received DATA frame for closed stream; closing connection") + return http2ConnectionError(http2ErrCodeProtocol) + } + + cc.mu.Lock() + if cs.inflow.available() >= int32(len(data)) { + cs.inflow.take(int32(len(data))) + } else { + cc.mu.Unlock() + return http2ConnectionError(http2ErrCodeFlowControl) + } + cc.mu.Unlock() + + if _, err := cs.bufPipe.Write(data); err != nil { + return err + } + } + + if f.StreamEnded() { + rl.endStream(cs) + } + return nil +} + +var http2errInvalidTrailers = errors.New("http2: invalid trailers") + +func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { + + err := io.EOF + code := cs.copyTrailers + if rl.reqMalformed != nil { + err = rl.reqMalformed + code = nil + } + cs.bufPipe.closeWithErrorAndCode(err, code) + delete(rl.activeRes, cs.ID) +} + +func (cs *http2clientStream) copyTrailers() { + for k, vv := range cs.trailer { + t := cs.resTrailer + if *t == nil { + *t = make(Header) + } + (*t)[k] = vv + } +} + +func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { + cc := rl.cc + cc.t.connPool().MarkDead(cc) + if f.ErrCode != 0 { + + cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode) + } + cc.setGoAway(f) + return nil +} + +func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error { + cc := rl.cc + cc.mu.Lock() + defer cc.mu.Unlock() + return f.ForeachSetting(func(s http2Setting) error { + switch s.ID { + case http2SettingMaxFrameSize: + cc.maxFrameSize = s.Val + case http2SettingMaxConcurrentStreams: + cc.maxConcurrentStreams = s.Val + case http2SettingInitialWindowSize: + + cc.initialWindowSize = s.Val + default: + + cc.vlogf("Unhandled Setting: %v", s) + } + return nil + }) +} + +func 
(rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error { + cc := rl.cc + cs := cc.streamByID(f.StreamID, false) + if f.StreamID != 0 && cs == nil { + return nil + } + + cc.mu.Lock() + defer cc.mu.Unlock() + + fl := &cc.flow + if cs != nil { + fl = &cs.flow + } + if !fl.add(int32(f.Increment)) { + return http2ConnectionError(http2ErrCodeFlowControl) + } + cc.cond.Broadcast() + return nil +} + +func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error { + cs := rl.cc.streamByID(f.StreamID, true) + if cs == nil { + + return nil + } + select { + case <-cs.peerReset: + + default: + err := http2StreamError{cs.ID, f.ErrCode} + cs.resetErr = err + close(cs.peerReset) + cs.bufPipe.CloseWithError(err) + cs.cc.cond.Broadcast() + } + delete(rl.activeRes, cs.ID) + return nil +} + +func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error { + if f.IsAck() { + + return nil + } + cc := rl.cc + cc.wmu.Lock() + defer cc.wmu.Unlock() + if err := cc.fr.WritePing(true, f.Data); err != nil { + return err + } + return cc.bw.Flush() +} + +func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) error { + + return http2ConnectionError(http2ErrCodeProtocol) +} + +func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) { + + cc.wmu.Lock() + cc.fr.WriteRSTStream(streamID, code) + cc.bw.Flush() + cc.wmu.Unlock() +} + +var ( + http2errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") + http2errPseudoTrailers = errors.New("http2: invalid pseudo header in trailers") +) + +func (rl *http2clientConnReadLoop) checkHeaderField(f hpack.HeaderField) bool { + if rl.reqMalformed != nil { + return false + } + + const headerFieldOverhead = 32 // per spec + rl.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead + if max := rl.cc.t.maxHeaderListSize(); max != 0 && rl.headerListSize > int64(max) { + 
rl.hdec.SetEmitEnabled(false) + rl.reqMalformed = http2errResponseHeaderListSize + return false + } + + if !http2validHeaderFieldValue(f.Value) { + rl.reqMalformed = http2errInvalidHeaderFieldValue + return false + } + + isPseudo := strings.HasPrefix(f.Name, ":") + if isPseudo { + if rl.sawRegHeader { + rl.reqMalformed = errors.New("http2: invalid pseudo header after regular header") + return false + } + } else { + if !http2validHeaderFieldName(f.Name) { + rl.reqMalformed = http2errInvalidHeaderFieldName + return false + } + rl.sawRegHeader = true + } + + return true +} + +// onNewHeaderField runs on the readLoop goroutine whenever a new +// hpack header field is decoded. +func (rl *http2clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) { + cc := rl.cc + if http2VerboseLogs { + cc.logf("http2: Transport decoded %v", f) + } + + if !rl.checkHeaderField(f) { + return + } + + isPseudo := strings.HasPrefix(f.Name, ":") + if isPseudo { + switch f.Name { + case ":status": + code, err := strconv.Atoi(f.Value) + if err != nil { + rl.reqMalformed = errors.New("http2: invalid :status") + return + } + rl.nextRes.Status = f.Value + " " + StatusText(code) + rl.nextRes.StatusCode = code + default: + + rl.reqMalformed = fmt.Errorf("http2: unknown response pseudo header %q", f.Name) + } + return + } + + key := CanonicalHeaderKey(f.Name) + if key == "Trailer" { + t := rl.nextRes.Trailer + if t == nil { + t = make(Header) + rl.nextRes.Trailer = t + } + http2foreachHeaderElement(f.Value, func(v string) { + t[CanonicalHeaderKey(v)] = nil + }) + } else { + rl.nextRes.Header.Add(key, f.Value) + } +} + +func (rl *http2clientConnReadLoop) onNewTrailerField(cs *http2clientStream, f hpack.HeaderField) { + if http2VerboseLogs { + rl.cc.logf("http2: Transport decoded trailer %v", f) + } + if !rl.checkHeaderField(f) { + return + } + if strings.HasPrefix(f.Name, ":") { + + rl.reqMalformed = http2errPseudoTrailers + return + } + + key := CanonicalHeaderKey(f.Name) + + // The spec says one 
must predeclare their trailers but in practice + // popular users (which is to say the only user we found) do not so we + // violate the spec and accept all of them. + const acceptAllTrailers = true + if _, ok := (*cs.resTrailer)[key]; ok || acceptAllTrailers { + if cs.trailer == nil { + cs.trailer = make(Header) + } + cs.trailer[key] = append(cs.trailer[key], f.Value) + } +} + +func (cc *http2ClientConn) logf(format string, args ...interface{}) { + cc.t.logf(format, args...) +} + +func (cc *http2ClientConn) vlogf(format string, args ...interface{}) { + cc.t.vlogf(format, args...) +} + +func (t *http2Transport) vlogf(format string, args ...interface{}) { + if http2VerboseLogs { + t.logf(format, args...) + } +} + +func (t *http2Transport) logf(format string, args ...interface{}) { + log.Printf(format, args...) +} + +var http2noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) + +func http2strSliceContains(ss []string, s string) bool { + for _, v := range ss { + if v == s { + return true + } + } + return false +} + +type http2erringRoundTripper struct{ err error } + +func (rt http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err } + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type http2gzipReader struct { + body io.ReadCloser // underlying Response.Body + zr io.Reader // lazily-initialized gzip reader +} + +func (gz *http2gzipReader) Read(p []byte) (n int, err error) { + if gz.zr == nil { + gz.zr, err = gzip.NewReader(gz.body) + if err != nil { + return 0, err + } + } + return gz.zr.Read(p) +} + +func (gz *http2gzipReader) Close() error { + return gz.body.Close() +} + +type http2errorReader struct{ err error } + +func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err } + +// writeFramer is implemented by any type that is used to write frames. 
+type http2writeFramer interface { + writeFrame(http2writeContext) error +} + +// writeContext is the interface needed by the various frame writer +// types below. All the writeFrame methods below are scheduled via the +// frame writing scheduler (see writeScheduler in writesched.go). +// +// This interface is implemented by *serverConn. +// +// TODO: decide whether to a) use this in the client code (which didn't +// end up using this yet, because it has a simpler design, not +// currently implementing priorities), or b) delete this and +// make the server code a bit more concrete. +type http2writeContext interface { + Framer() *http2Framer + Flush() error + CloseConn() error + // HeaderEncoder returns an HPACK encoder that writes to the + // returned buffer. + HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) +} + +// endsStream reports whether the given frame writer w will locally +// close the stream. +func http2endsStream(w http2writeFramer) bool { + switch v := w.(type) { + case *http2writeData: + return v.endStream + case *http2writeResHeaders: + return v.endStream + case nil: + + panic("endsStream called on nil writeFramer") + } + return false +} + +type http2flushFrameWriter struct{} + +func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error { + return ctx.Flush() +} + +type http2writeSettings []http2Setting + +func (s http2writeSettings) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteSettings([]http2Setting(s)...) 
+} + +type http2writeGoAway struct { + maxStreamID uint32 + code http2ErrCode +} + +func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error { + err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil) + if p.code != 0 { + ctx.Flush() + time.Sleep(50 * time.Millisecond) + ctx.CloseConn() + } + return err +} + +type http2writeData struct { + streamID uint32 + p []byte + endStream bool +} + +func (w *http2writeData) String() string { + return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream) +} + +func (w *http2writeData) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) +} + +// handlerPanicRST is the message sent from handler goroutines when +// the handler panics. +type http2handlerPanicRST struct { + StreamID uint32 +} + +func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal) +} + +func (se http2StreamError) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) +} + +type http2writePingAck struct{ pf *http2PingFrame } + +func (w http2writePingAck) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WritePing(true, w.pf.Data) +} + +type http2writeSettingsAck struct{} + +func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteSettingsAck() +} + +// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames +// for HTTP response headers or trailers from a server handler. +type http2writeResHeaders struct { + streamID uint32 + httpResCode int // 0 means no ":status" line + h Header // may be nil + trailers []string // if non-nil, which keys of h to write. nil means all. 
+ endStream bool + + date string + contentType string + contentLength string +} + +func http2encKV(enc *hpack.Encoder, k, v string) { + if http2VerboseLogs { + log.Printf("http2: server encoding header %q = %q", k, v) + } + enc.WriteField(hpack.HeaderField{Name: k, Value: v}) +} + +func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + + if w.httpResCode != 0 { + http2encKV(enc, ":status", http2httpCodeString(w.httpResCode)) + } + + http2encodeHeaders(enc, w.h, w.trailers) + + if w.contentType != "" { + http2encKV(enc, "content-type", w.contentType) + } + if w.contentLength != "" { + http2encKV(enc, "content-length", w.contentLength) + } + if w.date != "" { + http2encKV(enc, "date", w.date) + } + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 && w.trailers == nil { + panic("unexpected empty hpack") + } + + // For now we're lazy and just pick the minimum MAX_FRAME_SIZE + // that all peers must support (16KB). Later we could care + // more and send larger frames if the peer advertised it, but + // there's little point. Most headers are small anyway (so we + // generally won't have CONTINUATION frames), and extra frames + // only waste 9 bytes anyway. 
+ const maxFrameSize = 16384 + + first := true + for len(headerBlock) > 0 { + frag := headerBlock + if len(frag) > maxFrameSize { + frag = frag[:maxFrameSize] + } + headerBlock = headerBlock[len(frag):] + endHeaders := len(headerBlock) == 0 + var err error + if first { + first = false + err = ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: frag, + EndStream: w.endStream, + EndHeaders: endHeaders, + }) + } else { + err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag) + } + if err != nil { + return err + } + } + return nil +} + +type http2write100ContinueHeadersFrame struct { + streamID uint32 +} + +func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + http2encKV(enc, ":status", "100") + return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: buf.Bytes(), + EndStream: false, + EndHeaders: true, + }) +} + +type http2writeWindowUpdate struct { + streamID uint32 // or 0 for conn-level + n uint32 +} + +func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { + return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) +} + +func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { + + if keys == nil { + keys = make([]string, 0, len(h)) + for k := range h { + keys = append(keys, k) + } + sort.Strings(keys) + } + for _, k := range keys { + vv := h[k] + k = http2lowerHeader(k) + isTE := k == "transfer-encoding" + for _, v := range vv { + + if isTE && v != "trailers" { + continue + } + http2encKV(enc, k, v) + } + } +} + +// frameWriteMsg is a request to write a frame. +type http2frameWriteMsg struct { + // write is the interface value that does the writing, once the + // writeScheduler (below) has decided to select this frame + // to write. The write functions are all defined in write.go. + write http2writeFramer + + stream *http2stream // used for prioritization. 
nil for non-stream frames. + + // done, if non-nil, must be a buffered channel with space for + // 1 message and is sent the return value from write (or an + // earlier error) when the frame has been written. + done chan error +} + +// for debugging only: +func (wm http2frameWriteMsg) String() string { + var streamID uint32 + if wm.stream != nil { + streamID = wm.stream.id + } + var des string + if s, ok := wm.write.(fmt.Stringer); ok { + des = s.String() + } else { + des = fmt.Sprintf("%T", wm.write) + } + return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des) +} + +// writeScheduler tracks pending frames to write, priorities, and decides +// the next one to use. It is not thread-safe. +type http2writeScheduler struct { + // zero are frames not associated with a specific stream. + // They're sent before any stream-specific frames. + zero http2writeQueue + + // maxFrameSize is the maximum size of a DATA frame + // we'll write. Must be non-zero and between 16K-16M. + maxFrameSize uint32 + + // sq contains the stream-specific queues, keyed by stream ID. + // when a stream is idle, it's deleted from the map. + sq map[uint32]*http2writeQueue + + // canSend is a slice of memory that's reused between frame + // scheduling decisions to hold the list of writeQueues (from sq) + // which have enough flow control data to send. After canSend is + // built, the best is selected. + canSend []*http2writeQueue + + // pool of empty queues for reuse. 
+ queuePool []*http2writeQueue +} + +func (ws *http2writeScheduler) putEmptyQueue(q *http2writeQueue) { + if len(q.s) != 0 { + panic("queue must be empty") + } + ws.queuePool = append(ws.queuePool, q) +} + +func (ws *http2writeScheduler) getEmptyQueue() *http2writeQueue { + ln := len(ws.queuePool) + if ln == 0 { + return new(http2writeQueue) + } + q := ws.queuePool[ln-1] + ws.queuePool = ws.queuePool[:ln-1] + return q +} + +func (ws *http2writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 } + +func (ws *http2writeScheduler) add(wm http2frameWriteMsg) { + st := wm.stream + if st == nil { + ws.zero.push(wm) + } else { + ws.streamQueue(st.id).push(wm) + } +} + +func (ws *http2writeScheduler) streamQueue(streamID uint32) *http2writeQueue { + if q, ok := ws.sq[streamID]; ok { + return q + } + if ws.sq == nil { + ws.sq = make(map[uint32]*http2writeQueue) + } + q := ws.getEmptyQueue() + ws.sq[streamID] = q + return q +} + +// take returns the most important frame to write and removes it from the scheduler. +// It is illegal to call this if the scheduler is empty or if there are no connection-level +// flow control bytes available. +func (ws *http2writeScheduler) take() (wm http2frameWriteMsg, ok bool) { + if ws.maxFrameSize == 0 { + panic("internal error: ws.maxFrameSize not initialized or invalid") + } + + if !ws.zero.empty() { + return ws.zero.shift(), true + } + if len(ws.sq) == 0 { + return + } + + for id, q := range ws.sq { + if q.firstIsNoCost() { + return ws.takeFrom(id, q) + } + } + + if len(ws.canSend) != 0 { + panic("should be empty") + } + for _, q := range ws.sq { + if n := ws.streamWritableBytes(q); n > 0 { + ws.canSend = append(ws.canSend, q) + } + } + if len(ws.canSend) == 0 { + return + } + defer ws.zeroCanSend() + + q := ws.canSend[0] + + return ws.takeFrom(q.streamID(), q) +} + +// zeroCanSend is defered from take. 
+func (ws *http2writeScheduler) zeroCanSend() { + for i := range ws.canSend { + ws.canSend[i] = nil + } + ws.canSend = ws.canSend[:0] +} + +// streamWritableBytes returns the number of DATA bytes we could write +// from the given queue's stream, if this stream/queue were +// selected. It is an error to call this if q's head isn't a +// *writeData. +func (ws *http2writeScheduler) streamWritableBytes(q *http2writeQueue) int32 { + wm := q.head() + ret := wm.stream.flow.available() + if ret == 0 { + return 0 + } + if int32(ws.maxFrameSize) < ret { + ret = int32(ws.maxFrameSize) + } + if ret == 0 { + panic("internal error: ws.maxFrameSize not initialized or invalid") + } + wd := wm.write.(*http2writeData) + if len(wd.p) < int(ret) { + ret = int32(len(wd.p)) + } + return ret +} + +func (ws *http2writeScheduler) takeFrom(id uint32, q *http2writeQueue) (wm http2frameWriteMsg, ok bool) { + wm = q.head() + + if wd, ok := wm.write.(*http2writeData); ok && len(wd.p) > 0 { + allowed := wm.stream.flow.available() + if allowed == 0 { + + return http2frameWriteMsg{}, false + } + if int32(ws.maxFrameSize) < allowed { + allowed = int32(ws.maxFrameSize) + } + + if len(wd.p) > int(allowed) { + wm.stream.flow.take(allowed) + chunk := wd.p[:allowed] + wd.p = wd.p[allowed:] + + return http2frameWriteMsg{ + stream: wm.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: chunk, + + endStream: false, + }, + + done: nil, + }, true + } + wm.stream.flow.take(int32(len(wd.p))) + } + + q.shift() + if q.empty() { + ws.putEmptyQueue(q) + delete(ws.sq, id) + } + return wm, true +} + +func (ws *http2writeScheduler) forgetStream(id uint32) { + q, ok := ws.sq[id] + if !ok { + return + } + delete(ws.sq, id) + + for i := range q.s { + q.s[i] = http2frameWriteMsg{} + } + q.s = q.s[:0] + ws.putEmptyQueue(q) +} + +type http2writeQueue struct { + s []http2frameWriteMsg +} + +// streamID returns the stream ID for a non-empty stream-specific queue. 
+func (q *http2writeQueue) streamID() uint32 { return q.s[0].stream.id } + +func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } + +func (q *http2writeQueue) push(wm http2frameWriteMsg) { + q.s = append(q.s, wm) +} + +// head returns the next item that would be removed by shift. +func (q *http2writeQueue) head() http2frameWriteMsg { + if len(q.s) == 0 { + panic("invalid use of queue") + } + return q.s[0] +} + +func (q *http2writeQueue) shift() http2frameWriteMsg { + if len(q.s) == 0 { + panic("invalid use of queue") + } + wm := q.s[0] + + copy(q.s, q.s[1:]) + q.s[len(q.s)-1] = http2frameWriteMsg{} + q.s = q.s[:len(q.s)-1] + return wm +} + +func (q *http2writeQueue) firstIsNoCost() bool { + if df, ok := q.s[0].write.(*http2writeData); ok { + return len(df.p) == 0 + } + return true +} diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index d847b131184..049f32f27dc 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -211,3 +211,13 @@ func hasToken(v, token string) bool { func isTokenBoundary(b byte) bool { return b == ' ' || b == ',' || b == '\t' } + +func cloneHeader(h Header) Header { + h2 := make(Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go index 5451f54234c..7c51af1867a 100644 --- a/libgo/go/net/http/httptest/recorder.go +++ b/libgo/go/net/http/httptest/recorder.go @@ -44,23 +44,60 @@ func (rw *ResponseRecorder) Header() http.Header { return m } +// writeHeader writes a header if it was not written yet and +// detects Content-Type if needed. +// +// bytes or str are the beginning of the response body. +// We pass both to avoid unnecessarily generate garbage +// in rw.WriteString which was created for performance reasons. +// Non-nil bytes win. 
+func (rw *ResponseRecorder) writeHeader(b []byte, str string) { + if rw.wroteHeader { + return + } + if len(str) > 512 { + str = str[:512] + } + + _, hasType := rw.HeaderMap["Content-Type"] + hasTE := rw.HeaderMap.Get("Transfer-Encoding") != "" + if !hasType && !hasTE { + if b == nil { + b = []byte(str) + } + if rw.HeaderMap == nil { + rw.HeaderMap = make(http.Header) + } + rw.HeaderMap.Set("Content-Type", http.DetectContentType(b)) + } + + rw.WriteHeader(200) +} + // Write always succeeds and writes to rw.Body, if not nil. func (rw *ResponseRecorder) Write(buf []byte) (int, error) { - if !rw.wroteHeader { - rw.WriteHeader(200) - } + rw.writeHeader(buf, "") if rw.Body != nil { rw.Body.Write(buf) } return len(buf), nil } +// WriteString always succeeds and writes to rw.Body, if not nil. +func (rw *ResponseRecorder) WriteString(str string) (int, error) { + rw.writeHeader(nil, str) + if rw.Body != nil { + rw.Body.WriteString(str) + } + return len(str), nil +} + // WriteHeader sets rw.Code. func (rw *ResponseRecorder) WriteHeader(code int) { if !rw.wroteHeader { rw.Code = code + rw.wroteHeader = true } - rw.wroteHeader = true } // Flush sets rw.Flushed to true. 
diff --git a/libgo/go/net/http/httptest/recorder_test.go b/libgo/go/net/http/httptest/recorder_test.go index 2b563260c76..c29b6d4cf91 100644 --- a/libgo/go/net/http/httptest/recorder_test.go +++ b/libgo/go/net/http/httptest/recorder_test.go @@ -6,6 +6,7 @@ package httptest import ( "fmt" + "io" "net/http" "testing" ) @@ -38,6 +39,14 @@ func TestRecorder(t *testing.T) { return nil } } + hasHeader := func(key, want string) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.HeaderMap.Get(key); got != want { + return fmt.Errorf("header %s = %q; want %q", key, got, want) + } + return nil + } + } tests := []struct { name string @@ -68,6 +77,18 @@ func TestRecorder(t *testing.T) { check(hasStatus(200), hasContents("hi first"), hasFlush(false)), }, { + "write string", + func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "hi first") + }, + check( + hasStatus(200), + hasContents("hi first"), + hasFlush(false), + hasHeader("Content-Type", "text/plain; charset=utf-8"), + ), + }, + { "flush", func(w http.ResponseWriter, r *http.Request) { w.(http.Flusher).Flush() // also sends a 200 @@ -75,6 +96,40 @@ func TestRecorder(t *testing.T) { }, check(hasStatus(200), hasFlush(true)), }, + { + "Content-Type detection", + func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "<html>") + }, + check(hasHeader("Content-Type", "text/html; charset=utf-8")), + }, + { + "no Content-Type detection with Transfer-Encoding", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Transfer-Encoding", "some encoding") + io.WriteString(w, "<html>") + }, + check(hasHeader("Content-Type", "")), // no header + }, + { + "no Content-Type detection if set explicitly", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "some/type") + io.WriteString(w, "<html>") + }, + check(hasHeader("Content-Type", "some/type")), + }, + { + "Content-Type detection doesn't crash if HeaderMap is nil", + func(w http.ResponseWriter, r 
*http.Request) { + // Act as if the user wrote new(httptest.ResponseRecorder) + // rather than using NewRecorder (which initializes + // HeaderMap) + w.(*ResponseRecorder).HeaderMap = nil + io.WriteString(w, "<html>") + }, + check(hasHeader("Content-Type", "text/html; charset=utf-8")), + }, } r, _ := http.NewRequest("GET", "http://foo.com/", nil) for _, tt := range tests { diff --git a/libgo/go/net/http/httptest/server.go b/libgo/go/net/http/httptest/server.go index 96eb0ef6d2f..5c19c0ca340 100644 --- a/libgo/go/net/http/httptest/server.go +++ b/libgo/go/net/http/httptest/server.go @@ -7,13 +7,18 @@ package httptest import ( + "bytes" "crypto/tls" "flag" "fmt" + "log" "net" "net/http" + "net/http/internal" "os" + "runtime" "sync" + "time" ) // A Server is an HTTP server listening on a system-chosen port on the @@ -34,24 +39,10 @@ type Server struct { // wg counts the number of outstanding HTTP requests on this server. // Close blocks until all requests are finished. wg sync.WaitGroup -} - -// historyListener keeps track of all connections that it's ever -// accepted. 
-type historyListener struct { - net.Listener - sync.Mutex // protects history - history []net.Conn -} -func (hs *historyListener) Accept() (c net.Conn, err error) { - c, err = hs.Listener.Accept() - if err == nil { - hs.Lock() - hs.history = append(hs.history, c) - hs.Unlock() - } - return + mu sync.Mutex // guards closed and conns + closed bool + conns map[net.Conn]http.ConnState // except terminal states } func newLocalListener() net.Listener { @@ -103,10 +94,9 @@ func (s *Server) Start() { if s.URL != "" { panic("Server already started") } - s.Listener = &historyListener{Listener: s.Listener} s.URL = "http://" + s.Listener.Addr().String() - s.wrapHandler() - go s.Config.Serve(s.Listener) + s.wrap() + s.goServe() if *serve != "" { fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL) select {} @@ -118,7 +108,7 @@ func (s *Server) StartTLS() { if s.URL != "" { panic("Server already started") } - cert, err := tls.X509KeyPair(localhostCert, localhostKey) + cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) if err != nil { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) } @@ -134,23 +124,10 @@ func (s *Server) StartTLS() { if len(s.TLS.Certificates) == 0 { s.TLS.Certificates = []tls.Certificate{cert} } - tlsListener := tls.NewListener(s.Listener, s.TLS) - - s.Listener = &historyListener{Listener: tlsListener} + s.Listener = tls.NewListener(s.Listener, s.TLS) s.URL = "https://" + s.Listener.Addr().String() - s.wrapHandler() - go s.Config.Serve(s.Listener) -} - -func (s *Server) wrapHandler() { - h := s.Config.Handler - if h == nil { - h = http.DefaultServeMux - } - s.Config.Handler = &waitGroupHandler{ - s: s, - h: h, - } + s.wrap() + s.goServe() } // NewTLSServer starts and returns a new Server using TLS. 
@@ -161,78 +138,155 @@ func NewTLSServer(handler http.Handler) *Server { return ts } +type closeIdleTransport interface { + CloseIdleConnections() +} + // Close shuts down the server and blocks until all outstanding // requests on this server have completed. func (s *Server) Close() { - s.Listener.Close() - s.wg.Wait() - s.CloseClientConnections() - if t, ok := http.DefaultTransport.(*http.Transport); ok { + s.mu.Lock() + if !s.closed { + s.closed = true + s.Listener.Close() + s.Config.SetKeepAlivesEnabled(false) + for c, st := range s.conns { + // Force-close any idle connections (those between + // requests) and new connections (those which connected + // but never sent a request). StateNew connections are + // super rare and have only been seen (in + // previously-flaky tests) in the case of + // socket-late-binding races from the http Client + // dialing this server and then getting an idle + // connection before the dial completed. There is thus + // a connected connection in StateNew with no + // associated Request. We only close StateIdle and + // StateNew because they're not doing anything. It's + // possible StateNew is about to do something in a few + // milliseconds, but a previous CL to check again in a + // few milliseconds wasn't liked (early versions of + // https://golang.org/cl/15151) so now we just + // forcefully close StateNew. The docs for Server.Close say + // we wait for "outstanding requests", so we don't close things + // in StateActive. + if st == http.StateIdle || st == http.StateNew { + s.closeConn(c) + } + } + // If this server doesn't shut down in 20 seconds, tell the user why. + t := time.AfterFunc(20*time.Second, s.logCloseHangDebugInfo) + defer t.Stop() + } + s.mu.Unlock() + + // Not part of httptest.Server's correctness, but assume most + // users of httptest.Server will be using the standard + // transport, so help them out and close any idle connections for them. 
+ if t, ok := http.DefaultTransport.(closeIdleTransport); ok { t.CloseIdleConnections() } + + s.wg.Wait() } -// CloseClientConnections closes any currently open HTTP connections -// to the test Server. -func (s *Server) CloseClientConnections() { - hl, ok := s.Listener.(*historyListener) - if !ok { - return +func (s *Server) logCloseHangDebugInfo() { + s.mu.Lock() + defer s.mu.Unlock() + var buf bytes.Buffer + buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n") + for c, st := range s.conns { + fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st) } - hl.Lock() - for _, conn := range hl.history { - conn.Close() + log.Print(buf.String()) +} + +// CloseClientConnections closes any open HTTP connections to the test Server. +func (s *Server) CloseClientConnections() { + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.conns { + s.closeConn(c) } - hl.Unlock() } -// waitGroupHandler wraps a handler, incrementing and decrementing a -// sync.WaitGroup on each request, to enable Server.Close to block -// until outstanding requests are finished. -type waitGroupHandler struct { - s *Server - h http.Handler // non-nil +func (s *Server) goServe() { + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.Config.Serve(s.Listener) + }() +} + +// wrap installs the connection state-tracking hook to know which +// connections are idle. +func (s *Server) wrap() { + oldHook := s.Config.ConnState + s.Config.ConnState = func(c net.Conn, cs http.ConnState) { + s.mu.Lock() + defer s.mu.Unlock() + switch cs { + case http.StateNew: + s.wg.Add(1) + if _, exists := s.conns[c]; exists { + panic("invalid state transition") + } + if s.conns == nil { + s.conns = make(map[net.Conn]http.ConnState) + } + s.conns[c] = cs + if s.closed { + // Probably just a socket-late-binding dial from + // the default transport that lost the race (and + // thus this connection is now idle and will + // never be used). 
+ s.closeConn(c) + } + case http.StateActive: + if oldState, ok := s.conns[c]; ok { + if oldState != http.StateNew && oldState != http.StateIdle { + panic("invalid state transition") + } + s.conns[c] = cs + } + case http.StateIdle: + if oldState, ok := s.conns[c]; ok { + if oldState != http.StateActive { + panic("invalid state transition") + } + s.conns[c] = cs + } + if s.closed { + s.closeConn(c) + } + case http.StateHijacked, http.StateClosed: + s.forgetConn(c) + } + if oldHook != nil { + oldHook(c, cs) + } + } } -func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - h.s.wg.Add(1) - defer h.s.wg.Done() // a defer, in case ServeHTTP below panics - h.h.ServeHTTP(w, r) +// closeConn closes c. Except on plan9, which is special. See comment below. +// s.mu must be held. +func (s *Server) closeConn(c net.Conn) { + if runtime.GOOS == "plan9" { + // Go's Plan 9 net package isn't great at unblocking reads when + // their underlying TCP connections are closed. Don't trust + // that the ConnState state machine will get to + // StateClosed. Instead, just go there directly. Plan 9 may leak + // resources if the syscall doesn't end up returning. Oh well. + s.forgetConn(c) + } + go c.Close() } -// localhostCert is a PEM-encoded TLS cert with SAN IPs -// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end -// of ASN.1 time). 
-// generated from src/crypto/tls: -// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h -var localhostCert = []byte(`-----BEGIN CERTIFICATE----- -MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS -MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw -MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB -iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4 -iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul -rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO -BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw -AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA -AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9 -tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs -h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM -fblo6RBxUQ== ------END CERTIFICATE-----`) - -// localhostKey is the private key for localhostCert. 
-var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- -MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9 -SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB -l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB -AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet -3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb -uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H -qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp -jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY -fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U -fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU -y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX -qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo -f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA== ------END RSA PRIVATE KEY-----`) +// forgetConn removes c from the set of tracked conns and decrements it from the +// waitgroup, unless it was previously removed. +// s.mu must be held. 
+func (s *Server) forgetConn(c net.Conn) { + if _, ok := s.conns[c]; ok { + delete(s.conns, c) + s.wg.Done() + } +} diff --git a/libgo/go/net/http/httptest/server_test.go b/libgo/go/net/http/httptest/server_test.go index 500a9f0b800..6ffc671e575 100644 --- a/libgo/go/net/http/httptest/server_test.go +++ b/libgo/go/net/http/httptest/server_test.go @@ -5,7 +5,9 @@ package httptest import ( + "bufio" "io/ioutil" + "net" "net/http" "testing" ) @@ -27,3 +29,58 @@ func TestServer(t *testing.T) { t.Errorf("got %q, want hello", string(got)) } } + +// Issue 12781 +func TestGetAfterClose(t *testing.T) { + ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + + res, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + got, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(got) != "hello" { + t.Fatalf("got %q, want hello", string(got)) + } + + ts.Close() + + res, err = http.Get(ts.URL) + if err == nil { + body, _ := ioutil.ReadAll(res.Body) + t.Fatalf("Unexected response after close: %v, %v, %s", res.Status, res.Header, body) + } +} + +func TestServerCloseBlocking(t *testing.T) { + ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + dial := func() net.Conn { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + return c + } + + // Keep one connection in StateNew (connected, but not sending anything) + cnew := dial() + defer cnew.Close() + + // Keep one connection in StateIdle (idle after a request) + cidle := dial() + defer cidle.Close() + cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n")) + _, err := http.ReadResponse(bufio.NewReader(cidle), nil) + if err != nil { + t.Fatal(err) + } + + ts.Close() // test we don't hang here forever. 
+} diff --git a/libgo/go/net/http/httputil/dump.go b/libgo/go/net/http/httputil/dump.go index ca2d1cde924..e22cc66dbfc 100644 --- a/libgo/go/net/http/httputil/dump.go +++ b/libgo/go/net/http/httputil/dump.go @@ -25,10 +25,10 @@ import ( func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { var buf bytes.Buffer if _, err = buf.ReadFrom(b); err != nil { - return nil, nil, err + return nil, b, err } if err = b.Close(); err != nil { - return nil, nil, err + return nil, b, err } return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil } @@ -55,9 +55,9 @@ func (b neverEnding) Read(p []byte) (n int, err error) { return len(p), nil } -// DumpRequestOut is like DumpRequest but includes -// headers that the standard http.Transport adds, -// such as User-Agent. +// DumpRequestOut is like DumpRequest but for outgoing client requests. It +// includes any headers that the standard http.Transport adds, such as +// User-Agent. func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { save := req.Body dummyBody := false @@ -175,13 +175,22 @@ func dumpAsReceived(req *http.Request, w io.Writer) error { return nil } -// DumpRequest returns the as-received wire representation of req, -// optionally including the request body, for debugging. -// DumpRequest is semantically a no-op, but in order to -// dump the body, it reads the body data into memory and -// changes req.Body to refer to the in-memory copy. +// DumpRequest returns the given request in its HTTP/1.x wire +// representation. It should only be used by servers to debug client +// requests. The returned representation is an approximation only; +// some details of the initial request are lost while parsing it into +// an http.Request. In particular, the order and case of header field +// names are lost. The order of values in multi-valued headers is kept +// intact. HTTP/2 requests are dumped in HTTP/1.x form, not in their +// original binary representations. 
+// +// If body is true, DumpRequest also returns the body. To do so, it +// consumes req.Body and then replaces it with a new io.ReadCloser +// that yields the same bytes. If DumpRequest returns an error, +// the state of req is undefined. +// // The documentation for http.Request.Write details which fields -// of req are used. +// of req are included in the dump. func DumpRequest(req *http.Request, body bool) (dump []byte, err error) { save := req.Body if !body || req.Body == nil { @@ -189,21 +198,35 @@ func DumpRequest(req *http.Request, body bool) (dump []byte, err error) { } else { save, req.Body, err = drainBody(req.Body) if err != nil { - return + return nil, err } } var b bytes.Buffer + // By default, print out the unmodified req.RequestURI, which + // is always set for incoming server requests. But because we + // previously used req.URL.RequestURI and the docs weren't + // always so clear about when to use DumpRequest vs + // DumpRequestOut, fall back to the old way if the caller + // provides a non-server Request. 
+ reqURI := req.RequestURI + if reqURI == "" { + reqURI = req.URL.RequestURI() + } + fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"), - req.URL.RequestURI(), req.ProtoMajor, req.ProtoMinor) + reqURI, req.ProtoMajor, req.ProtoMinor) - host := req.Host - if host == "" && req.URL != nil { - host = req.URL.Host - } - if host != "" { - fmt.Fprintf(&b, "Host: %s\r\n", host) + absRequestURI := strings.HasPrefix(req.RequestURI, "http://") || strings.HasPrefix(req.RequestURI, "https://") + if !absRequestURI { + host := req.Host + if host == "" && req.URL != nil { + host = req.URL.Host + } + if host != "" { + fmt.Fprintf(&b, "Host: %s\r\n", host) + } } chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" @@ -269,7 +292,7 @@ func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) { } else { save, resp.Body, err = drainBody(resp.Body) if err != nil { - return + return nil, err } } err = resp.Write(&b) diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go index ae67e983ae9..46bf521723a 100644 --- a/libgo/go/net/http/httputil/dump_test.go +++ b/libgo/go/net/http/httputil/dump_test.go @@ -5,6 +5,7 @@ package httputil import ( + "bufio" "bytes" "fmt" "io" @@ -135,6 +136,14 @@ var dumpTests = []dumpTest{ "Accept-Encoding: gzip\r\n\r\n" + strings.Repeat("a", 8193), }, + + { + Req: *mustReadRequest("GET http://foo.com/ HTTP/1.1\r\n" + + "User-Agent: blah\r\n\r\n"), + NoBody: true, + WantDump: "GET http://foo.com/ HTTP/1.1\r\n" + + "User-Agent: blah\r\n\r\n", + }, } func TestDumpRequest(t *testing.T) { @@ -211,6 +220,14 @@ func mustNewRequest(method, url string, body io.Reader) *http.Request { return req } +func mustReadRequest(s string) *http.Request { + req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(s))) + if err != nil { + panic(err) + } + return req +} + var dumpResTests = []struct { res *http.Response body bool diff --git 
a/libgo/go/net/http/httputil/example_test.go b/libgo/go/net/http/httputil/example_test.go new file mode 100644 index 00000000000..8fb1a2d2792 --- /dev/null +++ b/libgo/go/net/http/httputil/example_test.go @@ -0,0 +1,125 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +package httputil_test + +import ( + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strings" +) + +func ExampleDumpRequest() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dump, err := httputil.DumpRequest(r, true) + if err != nil { + http.Error(w, fmt.Sprint(err), http.StatusInternalServerError) + return + } + + fmt.Fprintf(w, "%q", dump) + })) + defer ts.Close() + + const body = "Go is a general-purpose language designed with systems programming in mind." + req, err := http.NewRequest("POST", ts.URL, strings.NewReader(body)) + if err != nil { + log.Fatal(err) + } + req.Host = "www.example.org" + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", b) + + // Output: + // "POST / HTTP/1.1\r\nHost: www.example.org\r\nAccept-Encoding: gzip\r\nUser-Agent: Go-http-client/1.1\r\n\r\nGo is a general-purpose language designed with systems programming in mind." +} + +func ExampleDumpRequestOut() { + const body = "Go is a general-purpose language designed with systems programming in mind." 
+ req, err := http.NewRequest("PUT", "http://www.example.org", strings.NewReader(body)) + if err != nil { + log.Fatal(err) + } + + dump, err := httputil.DumpRequestOut(req, true) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%q", dump) + + // Output: + // "PUT / HTTP/1.1\r\nHost: www.example.org\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 75\r\nAccept-Encoding: gzip\r\n\r\nGo is a general-purpose language designed with systems programming in mind." +} + +func ExampleDumpResponse() { + const body = "Go is a general-purpose language designed with systems programming in mind." + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Date", "Wed, 19 Jul 1972 19:00:00 GMT") + fmt.Fprintln(w, body) + })) + defer ts.Close() + + resp, err := http.Get(ts.URL) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%q", dump) + + // Output: + // "HTTP/1.1 200 OK\r\nContent-Length: 76\r\nContent-Type: text/plain; charset=utf-8\r\nDate: Wed, 19 Jul 1972 19:00:00 GMT\r\n\r\nGo is a general-purpose language designed with systems programming in mind.\n" +} + +func ExampleReverseProxy() { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "this call was relayed by the reverse proxy") + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + log.Fatal(err) + } + frontendProxy := httptest.NewServer(httputil.NewSingleHostReverseProxy(rpURL)) + defer frontendProxy.Close() + + resp, err := http.Get(frontendProxy.URL) + if err != nil { + log.Fatal(err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", b) + + // Output: + // this call was relayed by the reverse proxy +} diff --git a/libgo/go/net/http/httputil/reverseproxy.go 
b/libgo/go/net/http/httputil/reverseproxy.go index c8e113221c4..4dba352a4fa 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -46,6 +46,18 @@ type ReverseProxy struct { // If nil, logging goes to os.Stderr via the log package's // standard logger. ErrorLog *log.Logger + + // BufferPool optionally specifies a buffer pool to + // get byte slices for use by io.CopyBuffer when + // copying HTTP response bodies. + BufferPool BufferPool +} + +// A BufferPool is an interface for getting and returning temporary +// byte slices for use by io.CopyBuffer. +type BufferPool interface { + Get() []byte + Put([]byte) } func singleJoiningSlash(a, b string) string { @@ -60,10 +72,13 @@ func singleJoiningSlash(a, b string) string { return a + b } -// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites +// NewSingleHostReverseProxy returns a new ReverseProxy that routes // URLs to the scheme, host, and base path provided in target. If the // target's path is "/base" and the incoming request was for "/dir", // the target request will be for /base/dir. +// NewSingleHostReverseProxy does not rewrite the Host header. +// To rewrite Host headers, use ReverseProxy directly with a custom +// Director policy. 
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { targetQuery := target.RawQuery director := func(req *http.Request) { @@ -242,7 +257,14 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { } } - io.Copy(dst, src) + var buf []byte + if p.BufferPool != nil { + buf = p.BufferPool.Get() + } + io.CopyBuffer(dst, src, buf) + if p.BufferPool != nil { + p.BufferPool.Put(buf) + } } func (p *ReverseProxy) logf(format string, args ...interface{}) { diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index 80a26abe414..7f203d878f5 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -8,14 +8,17 @@ package httputil import ( "bufio" + "bytes" + "io" "io/ioutil" "log" "net/http" "net/http/httptest" "net/url" "reflect" - "runtime" + "strconv" "strings" + "sync" "testing" "time" ) @@ -102,7 +105,6 @@ func TestReverseProxy(t *testing.T) { if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) } - } func TestXForwardedFor(t *testing.T) { @@ -225,10 +227,7 @@ func TestReverseProxyFlushInterval(t *testing.T) { } } -func TestReverseProxyCancellation(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/9554") - } +func TestReverseProxyCancelation(t *testing.T) { const backendResponse = "I am the backend" reqInFlight := make(chan struct{}) @@ -320,3 +319,108 @@ func TestNilBody(t *testing.T) { t.Errorf("Got %q; want %q", slurp, "hi") } } + +type bufferPool struct { + get func() []byte + put func([]byte) +} + +func (bp bufferPool) Get() []byte { return bp.get() } +func (bp bufferPool) Put(v []byte) { bp.put(v) } + +func TestReverseProxyGetPutBuffer(t *testing.T) { + const msg = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, msg) + })) + defer 
backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + var ( + mu sync.Mutex + log []string + ) + addLog := func(event string) { + mu.Lock() + defer mu.Unlock() + log = append(log, event) + } + rp := NewSingleHostReverseProxy(backendURL) + const size = 1234 + rp.BufferPool = bufferPool{ + get: func() []byte { + addLog("getBuf") + return make([]byte, size) + }, + put: func(p []byte) { + addLog("putBuf-" + strconv.Itoa(len(p))) + }, + } + frontend := httptest.NewServer(rp) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + slurp, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatalf("reading body: %v", err) + } + if string(slurp) != msg { + t.Errorf("msg = %q; want %q", slurp, msg) + } + wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} + mu.Lock() + defer mu.Unlock() + if !reflect.DeepEqual(log, wantLog) { + t.Errorf("Log events = %q; want %q", log, wantLog) + } +} + +func TestReverseProxy_Post(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 200 + var requestBody = bytes.Repeat([]byte("a"), 1<<20) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + slurp, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Errorf("Backend body read = %v", err) + } + if len(slurp) != len(requestBody) { + t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) + } + if !bytes.Equal(slurp, requestBody) { + t.Error("Backend read wrong request body.") // 1MB; omitting details + } + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + 
postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) + res, err := http.DefaultClient.Do(postReq) + if err != nil { + t.Fatalf("Do: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} diff --git a/libgo/go/net/http/internal/chunked.go b/libgo/go/net/http/internal/chunked.go index 6d7c69874d9..2e62c00d5db 100644 --- a/libgo/go/net/http/internal/chunked.go +++ b/libgo/go/net/http/internal/chunked.go @@ -44,7 +44,7 @@ type chunkedReader struct { func (cr *chunkedReader) beginChunk() { // chunk-size CRLF var line []byte - line, cr.err = readLine(cr.r) + line, cr.err = readChunkLine(cr.r) if cr.err != nil { return } @@ -104,10 +104,11 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { // Read a line of bytes (up to \n) from b. // Give up if the line exceeds maxLineLength. -// The returned bytes are a pointer into storage in -// the bufio, so they are only valid until the next bufio read. -func readLine(b *bufio.Reader) (p []byte, err error) { - if p, err = b.ReadSlice('\n'); err != nil { +// The returned bytes are owned by the bufio.Reader +// so they are only valid until the next bufio read. +func readChunkLine(b *bufio.Reader) ([]byte, error) { + p, err := b.ReadSlice('\n') + if err != nil { // We always know when EOF is coming. // If the caller asked for a line, there should be a line. 
if err == io.EOF { @@ -120,7 +121,12 @@ func readLine(b *bufio.Reader) (p []byte, err error) { if len(p) >= maxLineLength { return nil, ErrLineTooLong } - return trimTrailingWhitespace(p), nil + p = trimTrailingWhitespace(p) + p, err = removeChunkExtension(p) + if err != nil { + return nil, err + } + return p, nil } func trimTrailingWhitespace(b []byte) []byte { @@ -134,6 +140,23 @@ func isASCIISpace(b byte) bool { return b == ' ' || b == '\t' || b == '\n' || b == '\r' } +// removeChunkExtension removes any chunk-extension from p. +// For example, +// "0" => "0" +// "0;token" => "0" +// "0;token=val" => "0" +// `0;token="quoted string"` => "0" +func removeChunkExtension(p []byte) ([]byte, error) { + semi := bytes.IndexByte(p, ';') + if semi == -1 { + return p, nil + } + // TODO: care about exact syntax of chunk extensions? We're + // ignoring and stripping them anyway. For now just never + // return an error. + return p[:semi], nil +} + // NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP // "chunked" format before writing them to w. Closing the returned chunkedWriter // sends the final 0-length chunk that marks the end of the stream. 
@@ -197,8 +220,7 @@ type FlushAfterChunkWriter struct { } func parseHexUint(v []byte) (n uint64, err error) { - for _, b := range v { - n <<= 4 + for i, b := range v { switch { case '0' <= b && b <= '9': b = b - '0' @@ -209,6 +231,10 @@ func parseHexUint(v []byte) (n uint64, err error) { default: return 0, errors.New("invalid byte in chunk length") } + if i == 16 { + return 0, errors.New("http chunk length too large") + } + n <<= 4 n |= uint64(b) } return diff --git a/libgo/go/net/http/internal/chunked_test.go b/libgo/go/net/http/internal/chunked_test.go index ebc626ea9d0..a136dc99a65 100644 --- a/libgo/go/net/http/internal/chunked_test.go +++ b/libgo/go/net/http/internal/chunked_test.go @@ -139,18 +139,49 @@ func TestChunkReaderAllocs(t *testing.T) { } func TestParseHexUint(t *testing.T) { + type testCase struct { + in string + want uint64 + wantErr string + } + tests := []testCase{ + {"x", 0, "invalid byte in chunk length"}, + {"0000000000000000", 0, ""}, + {"0000000000000001", 1, ""}, + {"ffffffffffffffff", 1<<64 - 1, ""}, + {"000000000000bogus", 0, "invalid byte in chunk length"}, + {"00000000000000000", 0, "http chunk length too large"}, // could accept if we wanted + {"10000000000000000", 0, "http chunk length too large"}, + {"00000000000000001", 0, "http chunk length too large"}, // could accept if we wanted + } for i := uint64(0); i <= 1234; i++ { - line := []byte(fmt.Sprintf("%x", i)) - got, err := parseHexUint(line) - if err != nil { - t.Fatalf("on %d: %v", i, err) - } - if got != i { - t.Errorf("for input %q = %d; want %d", line, got, i) + tests = append(tests, testCase{in: fmt.Sprintf("%x", i), want: i}) + } + for _, tt := range tests { + got, err := parseHexUint([]byte(tt.in)) + if tt.wantErr != "" { + if !strings.Contains(fmt.Sprint(err), tt.wantErr) { + t.Errorf("parseHexUint(%q) = %v, %v; want error %q", tt.in, got, err, tt.wantErr) + } + } else { + if err != nil || got != tt.want { + t.Errorf("parseHexUint(%q) = %v, %v; want %v", tt.in, got, err, 
tt.want) + } } } - _, err := parseHexUint([]byte("bogus")) - if err == nil { - t.Error("expected error on bogus input") +} + +func TestChunkReadingIgnoresExtensions(t *testing.T) { + in := "7;ext=\"some quoted string\"\r\n" + // token=quoted string + "hello, \r\n" + + "17;someext\r\n" + // token without value + "world! 0123456789abcdef\r\n" + + "0;someextension=sometoken\r\n" // token=token + data, err := ioutil.ReadAll(NewChunkedReader(strings.NewReader(in))) + if err != nil { + t.Fatalf("ReadAll = %q, %v", data, err) + } + if g, e := string(data), "hello, world! 0123456789abcdef"; g != e { + t.Errorf("read %q; want %q", g, e) } } diff --git a/libgo/go/net/http/internal/testcert.go b/libgo/go/net/http/internal/testcert.go new file mode 100644 index 00000000000..407890920fa --- /dev/null +++ b/libgo/go/net/http/internal/testcert.go @@ -0,0 +1,41 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +// LocalhostCert is a PEM-encoded TLS cert with SAN IPs +// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. 
+// generated from src/crypto/tls: +// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var LocalhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw +MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB +iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4 +iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul +rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO +BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw +AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA +AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9 +tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs +h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM +fblo6RBxUQ== +-----END CERTIFICATE-----`) + +// LocalhostKey is the private key for localhostCert. 
+var LocalhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9 +SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB +l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB +AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet +3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb +uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H +qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp +jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY +fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U +fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU +y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX +qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo +f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA== +-----END RSA PRIVATE KEY-----`) diff --git a/libgo/go/net/http/lex.go b/libgo/go/net/http/lex.go index 50b14f8b325..52b6481c14e 100644 --- a/libgo/go/net/http/lex.go +++ b/libgo/go/net/http/lex.go @@ -167,3 +167,17 @@ func tokenEqual(t1, t2 string) bool { } return true } + +// isLWS reports whether b is linear white space, according +// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 +// LWS = [CRLF] 1*( SP | HT ) +func isLWS(b byte) bool { return b == ' ' || b == '\t' } + +// isCTL reports whether b is a control byte, according +// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 +// CTL = <any US-ASCII control character +// (octets 0 - 31) and DEL (127)> +func isCTL(b byte) bool { + const del = 0x7f // a CTL + return b < ' ' || b == del +} diff --git a/libgo/go/net/http/main_test.go b/libgo/go/net/http/main_test.go index 12eea6f0e11..299cd7b2d2f 100644 --- a/libgo/go/net/http/main_test.go +++ b/libgo/go/net/http/main_test.go @@ -5,6 +5,7 @@ package http_test import ( + "flag" "fmt" "net/http" "os" @@ -15,6 +16,8 @@ import ( "time" ) +var flaky = 
flag.Bool("flaky", false, "run known-flaky tests too") + func TestMain(m *testing.M) { v := m.Run() if v == 0 && goroutineLeaked() { @@ -79,6 +82,21 @@ func goroutineLeaked() bool { return true } +// setParallel marks t as a parallel test if we're in short mode +// (all.bash), but as a serial test otherwise. Using t.Parallel isn't +// compatible with the afterTest func in non-short mode. +func setParallel(t *testing.T) { + if testing.Short() { + t.Parallel() + } +} + +func setFlaky(t *testing.T, issue int) { + if !*flaky { + t.Skipf("skipping known flaky test; see golang.org/issue/%d", issue) + } +} + func afterTest(t testing.TB) { http.DefaultTransport.(*http.Transport).CloseIdleConnections() if testing.Short() { diff --git a/libgo/go/net/http/method.go b/libgo/go/net/http/method.go new file mode 100644 index 00000000000..b74f9604d34 --- /dev/null +++ b/libgo/go/net/http/method.go @@ -0,0 +1,20 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +// Common HTTP methods. +// +// Unless otherwise noted, these are defined in RFC 7231 section 4.3. 
+const ( + MethodGet = "GET" + MethodHead = "HEAD" + MethodPost = "POST" + MethodPut = "PUT" + MethodPatch = "PATCH" // RFC 5741 + MethodDelete = "DELETE" + MethodConnect = "CONNECT" + MethodOptions = "OPTIONS" + MethodTrace = "TRACE" +) diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go index 8994392b1e4..7262c6c1016 100644 --- a/libgo/go/net/http/pprof/pprof.go +++ b/libgo/go/net/http/pprof/pprof.go @@ -79,6 +79,17 @@ func Cmdline(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, strings.Join(os.Args, "\x00")) } +func sleep(w http.ResponseWriter, d time.Duration) { + var clientGone <-chan bool + if cn, ok := w.(http.CloseNotifier); ok { + clientGone = cn.CloseNotify() + } + select { + case <-time.After(d): + case <-clientGone: + } +} + // Profile responds with the pprof-formatted cpu profile. // The package initialization registers it as /debug/pprof/profile. func Profile(w http.ResponseWriter, r *http.Request) { @@ -99,7 +110,7 @@ func Profile(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) return } - time.Sleep(time.Duration(sec) * time.Second) + sleep(w, time.Duration(sec)*time.Second) pprof.StopCPUProfile() } @@ -125,7 +136,7 @@ func Trace(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "Could not enable tracing: %s\n", err) return } - time.Sleep(time.Duration(sec) * time.Second) + sleep(w, time.Duration(sec)*time.Second) trace.Stop() */ } diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index 31fe45a4edb..16c5bb43ac4 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -90,8 +90,11 @@ type Request struct { // request. URL *url.URL - // The protocol version for incoming requests. - // Client requests always use HTTP/1.1. + // The protocol version for incoming server requests. + // + // For client requests these fields are ignored. The HTTP + // client code always uses either HTTP/1.1 or HTTP/2. 
+ // See the docs on Transport for details. Proto string // "HTTP/1.0" ProtoMajor int // 1 ProtoMinor int // 0 @@ -354,7 +357,7 @@ const defaultUserAgent = "Go-http-client/1.1" // hasn't been set to "identity", Write adds "Transfer-Encoding: // chunked" to the header. Body is closed after it is sent. func (r *Request) Write(w io.Writer) error { - return r.write(w, false, nil) + return r.write(w, false, nil, nil) } // WriteProxy is like Write but writes the request in the form @@ -364,11 +367,16 @@ func (r *Request) Write(w io.Writer) error { // In either case, WriteProxy also writes a Host header, using // either r.Host or r.URL.Host. func (r *Request) WriteProxy(w io.Writer) error { - return r.write(w, true, nil) + return r.write(w, true, nil, nil) } +// errMissingHost is returned by Write when there is no Host or URL present in +// the Request. +var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set") + // extraHeaders may be nil -func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) error { +// waitForContinue may be nil +func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) error { // Find the target host. Prefer the Host: header, but if that // is not given, use the host from the request URL. // @@ -376,7 +384,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err host := cleanHost(req.Host) if host == "" { if req.URL == nil { - return errors.New("http: Request.Write on Request with no Host or URL set") + return errMissingHost } host = cleanHost(req.URL.Host) } @@ -419,10 +427,8 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err // Use the defaultUserAgent unless the Header contains one, which // may be blank to not send the header. 
userAgent := defaultUserAgent - if req.Header != nil { - if ua := req.Header["User-Agent"]; len(ua) > 0 { - userAgent = ua[0] - } + if _, ok := req.Header["User-Agent"]; ok { + userAgent = req.Header.Get("User-Agent") } if userAgent != "" { _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) @@ -458,6 +464,21 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err return err } + // Flush and wait for 100-continue if expected. + if waitForContinue != nil { + if bw, ok := w.(*bufio.Writer); ok { + err = bw.Flush() + if err != nil { + return err + } + } + + if !waitForContinue() { + req.closeBody() + return nil + } + } + // Write body and trailer err = tw.WriteBody(w) if err != nil { @@ -531,6 +552,23 @@ func ParseHTTPVersion(vers string) (major, minor int, ok bool) { return major, minor, true } +func validMethod(method string) bool { + /* + Method = "OPTIONS" ; Section 9.2 + | "GET" ; Section 9.3 + | "HEAD" ; Section 9.4 + | "POST" ; Section 9.5 + | "PUT" ; Section 9.6 + | "DELETE" ; Section 9.7 + | "TRACE" ; Section 9.8 + | "CONNECT" ; Section 9.9 + | extension-method + extension-method = token + token = 1*<any CHAR except CTLs or separators> + */ + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} + // NewRequest returns a new Request given a method, URL, and optional body. // // If the provided body is also an io.Closer, the returned @@ -544,6 +582,15 @@ func ParseHTTPVersion(vers string) (major, minor int, ok bool) { // type's documentation for the difference between inbound and outbound // request fields. func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { + if method == "" { + // We document that "" means "GET" for Request.Method, and people have + // relied on that from NewRequest, so keep that working. + // We still enforce validMethod for non-empty methods. 
+ method = "GET" + } + if !validMethod(method) { + return nil, fmt.Errorf("net/http: invalid method %q", method) + } u, err := url.Parse(urlStr) if err != nil { return nil, err @@ -643,8 +690,15 @@ func putTextprotoReader(r *textproto.Reader) { } // ReadRequest reads and parses an incoming request from b. -func ReadRequest(b *bufio.Reader) (req *Request, err error) { +func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, deleteHostHeader) } + +// Constants for readRequest's deleteHostHeader parameter. +const ( + deleteHostHeader = true + keepHostHeader = false +) +func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err error) { tp := newTextprotoReader(b) req = new(Request) @@ -711,7 +765,9 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) { if req.Host == "" { req.Host = req.Header.get("Host") } - delete(req.Header, "Host") + if deleteHostHeader { + delete(req.Header, "Host") + } fixPragmaCacheControl(req.Header) @@ -1006,3 +1062,102 @@ func (r *Request) closeBody() { r.Body.Close() } } + +func (r *Request) isReplayable() bool { + if r.Body == nil { + switch valueOrDefault(r.Method, "GET") { + case "GET", "HEAD", "OPTIONS", "TRACE": + return true + } + } + return false +} + +func validHostHeader(h string) bool { + // The latests spec is actually this: + // + // http://tools.ietf.org/html/rfc7230#section-5.4 + // Host = uri-host [ ":" port ] + // + // Where uri-host is: + // http://tools.ietf.org/html/rfc3986#section-3.2.2 + // + // But we're going to be much more lenient for now and just + // search for any byte that's not a valid byte in any of those + // expressions. + for i := 0; i < len(h); i++ { + if !validHostByte[h[i]] { + return false + } + } + return true +} + +// See the validHostHeader comment. 
+var validHostByte = [256]bool{ + '0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true, + '8': true, '9': true, + + 'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true, + 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true, + 'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true, + 'y': true, 'z': true, + + 'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true, + 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true, + 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true, + 'Y': true, 'Z': true, + + '!': true, // sub-delims + '$': true, // sub-delims + '%': true, // pct-encoded (and used in IPv6 zones) + '&': true, // sub-delims + '(': true, // sub-delims + ')': true, // sub-delims + '*': true, // sub-delims + '+': true, // sub-delims + ',': true, // sub-delims + '-': true, // unreserved + '.': true, // unreserved + ':': true, // IPv6address + Host expression's optional port + ';': true, // sub-delims + '=': true, // sub-delims + '[': true, + '\'': true, // sub-delims + ']': true, + '_': true, // unreserved + '~': true, // unreserved +} + +func validHeaderName(v string) bool { + if len(v) == 0 { + return false + } + return strings.IndexFunc(v, isNotToken) == -1 +} + +// validHeaderValue reports whether v is a valid "field-value" according to +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 : +// +// message-header = field-name ":" [ field-value ] +// field-value = *( field-content | LWS ) +// field-content = <the OCTETs making up the field-value +// and consisting of either *TEXT or combinations +// of token, separators, and quoted-string> +// +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 : +// +// TEXT = <any OCTET except CTLs, +// but including LWS> +// LWS = [CRLF] 1*( SP | HT ) +// CTL = <any 
US-ASCII control character +// (octets 0 - 31) and DEL (127)> +func validHeaderValue(v string) bool { + for i := 0; i < len(v); i++ { + b := v[i] + if isCTL(b) && !isLWS(b) { + return false + } + } + return true +} diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index 627620c0c41..0ecdf85a563 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -13,7 +13,6 @@ import ( "io/ioutil" "mime/multipart" . "net/http" - "net/http/httptest" "net/url" "os" "reflect" @@ -177,9 +176,11 @@ func TestParseMultipartForm(t *testing.T) { } } -func TestRedirect(t *testing.T) { +func TestRedirect_h1(t *testing.T) { testRedirect(t, h1Mode) } +func TestRedirect_h2(t *testing.T) { testRedirect(t, h2Mode) } +func testRedirect(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { switch r.URL.Path { case "/": w.Header().Set("Location", "/foo/") @@ -190,10 +191,10 @@ func TestRedirect(t *testing.T) { w.WriteHeader(StatusBadRequest) } })) - defer ts.Close() + defer cst.close() var end = regexp.MustCompile("/foo/$") - r, err := Get(ts.URL) + r, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } @@ -355,6 +356,29 @@ func TestNewRequestHost(t *testing.T) { } } +func TestRequestInvalidMethod(t *testing.T) { + _, err := NewRequest("bad method", "http://foo.com/", nil) + if err == nil { + t.Error("expected error from NewRequest with invalid method") + } + req, err := NewRequest("GET", "http://foo.example/", nil) + if err != nil { + t.Fatal(err) + } + req.Method = "bad method" + _, err = DefaultClient.Do(req) + if err == nil || !strings.Contains(err.Error(), "invalid method") { + t.Errorf("Transport error = %v; want invalid method", err) + } + + req, err = NewRequest("", "http://foo.com/", nil) + if err != nil { + t.Errorf("NewRequest(empty method) = %v; want nil", err) + 
} else if req.Method != "GET" { + t.Errorf("NewRequest(empty method) has method %q; want GET", req.Method) + } +} + func TestNewRequestContentLength(t *testing.T) { readByte := func(r io.Reader) io.Reader { var b [1]byte @@ -515,10 +539,12 @@ func TestRequestWriteBufferedWriter(t *testing.T) { func TestRequestBadHost(t *testing.T) { got := []string{} - req, err := NewRequest("GET", "http://foo.com with spaces/after", nil) + req, err := NewRequest("GET", "http://foo/after", nil) if err != nil { t.Fatal(err) } + req.Host = "foo.com with spaces" + req.URL.Host = "foo.com with spaces" req.Write(logWrites{t, &got}) want := []string{ "GET /after HTTP/1.1\r\n", diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index 76b85385244..c424f61cd00 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -72,8 +72,18 @@ type Response struct { // ReadResponse nor Response.Write ever closes a connection. Close bool - // Trailer maps trailer keys to values, in the same - // format as the header. + // Trailer maps trailer keys to values in the same + // format as Header. + // + // The Trailer initially contains only nil values, one for + // each key specified in the server's "Trailer" header + // value. Those values are not added to Header. + // + // Trailer must not be accessed concurrently with Read calls + // on the Body. + // + // After Body.Read has returned io.EOF, Trailer will contain + // any trailer values sent by the server. Trailer Header // The Request that was sent to obtain this Response. 
@@ -140,12 +150,14 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { if len(f) > 2 { reasonPhrase = f[2] } - resp.Status = f[1] + " " + reasonPhrase + if len(f[1]) != 3 { + return nil, &badStringError{"malformed HTTP status code", f[1]} + } resp.StatusCode, err = strconv.Atoi(f[1]) - if err != nil { + if err != nil || resp.StatusCode < 0 { return nil, &badStringError{"malformed HTTP status code", f[1]} } - + resp.Status = f[1] + " " + reasonPhrase resp.Proto = f[0] var ok bool if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok { diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index 421cf55f491..d8a53400cf2 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -456,8 +456,59 @@ some body`, "", }, + + // Issue 12785: HTTP/1.0 response with bogus (to be ignored) Transfer-Encoding. + // Without a Content-Length. + { + "HTTP/1.0 200 OK\r\n" + + "Transfer-Encoding: bogus\r\n" + + "\r\n" + + "Body here\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Close: true, + ContentLength: -1, + }, + + "Body here\n", + }, + + // Issue 12785: HTTP/1.0 response with bogus (to be ignored) Transfer-Encoding. + // With a Content-Length. + { + "HTTP/1.0 200 OK\r\n" + + "Transfer-Encoding: bogus\r\n" + + "Content-Length: 10\r\n" + + "\r\n" + + "Body here\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{ + "Content-Length": {"10"}, + }, + Close: true, + ContentLength: 10, + }, + + "Body here\n", + }, } +// tests successful calls to ReadResponse, and inspects the returned Response. +// For error cases, see TestReadResponseErrors below. 
func TestReadResponse(t *testing.T) { for i, tt := range respTests { resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request) @@ -624,6 +675,7 @@ var responseLocationTests = []responseLocationTest{ {"/foo", "http://bar.com/baz", "http://bar.com/foo", nil}, {"http://foo.com/", "http://bar.com/baz", "http://foo.com/", nil}, {"", "http://bar.com/baz", "", ErrNoLocation}, + {"/bar", "", "/bar", nil}, } func TestLocationResponse(t *testing.T) { @@ -702,13 +754,106 @@ func TestResponseContentLengthShortBody(t *testing.T) { } } -func TestReadResponseUnexpectedEOF(t *testing.T) { - br := bufio.NewReader(strings.NewReader("HTTP/1.1 301 Moved Permanently\r\n" + - "Location: http://example.com")) - _, err := ReadResponse(br, nil) - if err != io.ErrUnexpectedEOF { - t.Errorf("ReadResponse = %v; want io.ErrUnexpectedEOF", err) +// Test various ReadResponse error cases. (also tests success cases, but mostly +// it's about errors). This does not test anything involving the bodies. Only +// the return value from ReadResponse itself. 
+func TestReadResponseErrors(t *testing.T) { + type testCase struct { + name string // optional, defaults to in + in string + wantErr interface{} // nil, err value, or string substring + } + + status := func(s string, wantErr interface{}) testCase { + if wantErr == true { + wantErr = "malformed HTTP status code" + } + return testCase{ + name: fmt.Sprintf("status %q", s), + in: "HTTP/1.1 " + s + "\r\nFoo: bar\r\n\r\n", + wantErr: wantErr, + } + } + + version := func(s string, wantErr interface{}) testCase { + if wantErr == true { + wantErr = "malformed HTTP version" + } + return testCase{ + name: fmt.Sprintf("version %q", s), + in: s + " 200 OK\r\n\r\n", + wantErr: wantErr, + } + } + + tests := []testCase{ + {"", "", io.ErrUnexpectedEOF}, + {"", "HTTP/1.1 301 Moved Permanently\r\nFoo: bar", io.ErrUnexpectedEOF}, + {"", "HTTP/1.1", "malformed HTTP response"}, + {"", "HTTP/2.0", "malformed HTTP response"}, + status("20X Unknown", true), + status("abcd Unknown", true), + status("二百/两百 OK", true), + status(" Unknown", true), + status("c8 OK", true), + status("0x12d Moved Permanently", true), + status("200 OK", nil), + status("000 OK", nil), + status("001 OK", nil), + status("404 NOTFOUND", nil), + status("20 OK", true), + status("00 OK", true), + status("-10 OK", true), + status("1000 OK", true), + status("999 Done", nil), + status("-1 OK", true), + status("-200 OK", true), + version("HTTP/1.2", nil), + version("HTTP/2.0", nil), + version("HTTP/1.100000000002", true), + version("HTTP/1.-1", true), + version("HTTP/A.B", true), + version("HTTP/1", true), + version("http/1.1", true), + } + for i, tt := range tests { + br := bufio.NewReader(strings.NewReader(tt.in)) + _, rerr := ReadResponse(br, nil) + if err := matchErr(rerr, tt.wantErr); err != nil { + name := tt.name + if name == "" { + name = fmt.Sprintf("%d. 
input %q", i, tt.in) + } + t.Errorf("%s: %v", name, err) + } + } +} + +// wantErr can be nil, an error value to match exactly, or type string to +// match a substring. +func matchErr(err error, wantErr interface{}) error { + if err == nil { + if wantErr == nil { + return nil + } + if sub, ok := wantErr.(string); ok { + return fmt.Errorf("unexpected success; want error with substring %q", sub) + } + return fmt.Errorf("unexpected success; want error %v", wantErr) + } + if wantErr == nil { + return fmt.Errorf("%v; want success", err) + } + if sub, ok := wantErr.(string); ok { + if strings.Contains(err.Error(), sub) { + return nil + } + return fmt.Errorf("error = %v; want an error with substring %q", err, sub) + } + if err == wantErr { + return nil } + return fmt.Errorf("%v; want %v", err, wantErr) } func TestNeedsSniff(t *testing.T) { diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index d51417eb4a0..f8cad802d49 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -12,6 +12,7 @@ import ( "crypto/tls" "errors" "fmt" + "internal/testenv" "io" "io/ioutil" "log" @@ -26,6 +27,8 @@ import ( "os/exec" "reflect" "runtime" + "runtime/debug" + "sort" "strconv" "strings" "sync" @@ -96,6 +99,7 @@ func (c *rwTestConn) Close() error { } type testConn struct { + readMu sync.Mutex // for TestHandlerBodyClose readBuf bytes.Buffer writeBuf bytes.Buffer closec chan bool // if non-nil, send value to it on close @@ -103,6 +107,8 @@ type testConn struct { } func (c *testConn) Read(b []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() return c.readBuf.Read(b) } @@ -450,6 +456,7 @@ func TestServerTimeouts(t *testing.T) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/7237") } + setParallel(t) defer afterTest(t) reqNum := 0 ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { @@ -734,14 +741,17 @@ func TestHandlersCanSetConnectionClose10(t 
*testing.T) { })) } -func TestSetsRemoteAddr(t *testing.T) { +func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) } +func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) } + +func testSetsRemoteAddr(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) })) - defer ts.Close() + defer cst.close() - res, err := Get(ts.URL) + res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatalf("Get error: %v", err) } @@ -755,34 +765,106 @@ func TestSetsRemoteAddr(t *testing.T) { } } -func TestChunkedResponseHeaders(t *testing.T) { - defer afterTest(t) - log.SetOutput(ioutil.Discard) // is noisy otherwise - defer log.SetOutput(os.Stderr) +type blockingRemoteAddrListener struct { + net.Listener + conns chan<- net.Conn +} - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted - w.(Flusher).Flush() - fmt.Fprintf(w, "I am a chunked response.") +func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + brac := &blockingRemoteAddrConn{ + Conn: c, + addrs: make(chan net.Addr, 1), + } + l.conns <- brac + return brac, nil +} + +type blockingRemoteAddrConn struct { + net.Conn + addrs chan net.Addr +} + +func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr { + return <-c.addrs +} + +// Issue 12943 +func TestServerAllowsBlockingRemoteAddr(t *testing.T) { + defer afterTest(t) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "RA:%s", r.RemoteAddr) })) + conns := make(chan net.Conn) + ts.Listener = &blockingRemoteAddrListener{ + Listener: ts.Listener, + conns: conns, + } + ts.Start() defer ts.Close() - res, err := Get(ts.URL) 
- if err != nil { - t.Fatalf("Get error: %v", err) + tr := &Transport{DisableKeepAlives: true} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr, Timeout: time.Second} + + fetch := func(response chan string) { + resp, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + response <- "" + return + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Error(err) + response <- "" + return + } + response <- string(body) } - defer res.Body.Close() - if g, e := res.ContentLength, int64(-1); g != e { - t.Errorf("expected ContentLength of %d; got %d", e, g) + + // Start a request. The server will block on getting conn.RemoteAddr. + response1c := make(chan string, 1) + go fetch(response1c) + + // Wait for the server to accept it; grab the connection. + conn1 := <-conns + + // Start another request and grab its connection + response2c := make(chan string, 1) + go fetch(response2c) + var conn2 net.Conn + + select { + case conn2 = <-conns: + case <-time.After(time.Second): + t.Fatal("Second Accept didn't happen") } - if g, e := res.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(g, e) { - t.Errorf("expected TransferEncoding of %v; got %v", e, g) + + // Send a response on connection 2. + conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{ + IP: net.ParseIP("12.12.12.12"), Port: 12} + + // ... and see it + response2 := <-response2c + if g, e := response2, "RA:12.12.12.12:12"; g != e { + t.Fatalf("response 2 addr = %q; want %q", g, e) } - if _, haveCL := res.Header["Content-Length"]; haveCL { - t.Errorf("Unexpected Content-Length") + + // Finish the first response. + conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{ + IP: net.ParseIP("21.21.21.21"), Port: 21} + + // ... 
and see it + response1 := <-response1c + if g, e := response1, "RA:21.21.21.21:21"; g != e { + t.Fatalf("response 1 addr = %q; want %q", g, e) } } - func TestIdentityResponseHeaders(t *testing.T) { defer afterTest(t) log.SetOutput(ioutil.Discard) // is noisy otherwise @@ -812,40 +894,14 @@ func TestIdentityResponseHeaders(t *testing.T) { } } -// Test304Responses verifies that 304s don't declare that they're -// chunking in their response headers and aren't allowed to produce -// output. -func Test304Responses(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - w.WriteHeader(StatusNotModified) - _, err := w.Write([]byte("illegal body")) - if err != ErrBodyNotAllowed { - t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) - } - })) - defer ts.Close() - res, err := Get(ts.URL) - if err != nil { - t.Error(err) - } - if len(res.TransferEncoding) > 0 { - t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) - } - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Error(err) - } - if len(body) > 0 { - t.Errorf("got unexpected body %q", string(body)) - } -} - // TestHeadResponses verifies that all MIME type sniffing and Content-Length // counting of GET requests also happens on HEAD requests. 
-func TestHeadResponses(t *testing.T) { +func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) } +func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) } + +func testHeadResponses(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("<html>")) if err != nil { t.Errorf("ResponseWriter.Write: %v", err) @@ -857,8 +913,8 @@ func TestHeadResponses(t *testing.T) { t.Errorf("Copy(ResponseWriter, ...): %v", err) } })) - defer ts.Close() - res, err := Head(ts.URL) + defer cst.close() + res, err := cst.c.Head(cst.ts.URL) if err != nil { t.Error(err) } @@ -884,6 +940,7 @@ func TestTLSHandshakeTimeout(t *testing.T) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/7237") } + setParallel(t) defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) errc := make(chanWriter, 10) // but only expecting 1 @@ -967,6 +1024,79 @@ func TestTLSServer(t *testing.T) { }) } +func TestAutomaticHTTP2_Serve(t *testing.T) { + defer afterTest(t) + ln := newLocalListener(t) + ln.Close() // immediately (not a defer!) 
+ var s Server + if err := s.Serve(ln); err == nil { + t.Fatal("expected an error") + } + on := s.TLSNextProto["h2"] != nil + if !on { + t.Errorf("http2 wasn't automatically enabled") + } +} + +func TestAutomaticHTTP2_ListenAndServe(t *testing.T) { + defer afterTest(t) + defer SetTestHookServerServe(nil) + cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) + if err != nil { + t.Fatal(err) + } + var ok bool + var s *Server + const maxTries = 5 + var ln net.Listener +Try: + for try := 0; try < maxTries; try++ { + ln = newLocalListener(t) + addr := ln.Addr().String() + ln.Close() + t.Logf("Got %v", addr) + lnc := make(chan net.Listener, 1) + SetTestHookServerServe(func(s *Server, ln net.Listener) { + lnc <- ln + }) + s = &Server{ + Addr: addr, + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, + } + errc := make(chan error, 1) + go func() { errc <- s.ListenAndServeTLS("", "") }() + select { + case err := <-errc: + t.Logf("On try #%v: %v", try+1, err) + continue + case ln = <-lnc: + ok = true + t.Logf("Listening on %v", ln.Addr().String()) + break Try + } + } + if !ok { + t.Fatalf("Failed to start up after %d tries", maxTries) + } + defer ln.Close() + c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2", "http/1.1"}, + }) + if err != nil { + t.Fatal(err) + } + defer c.Close() + if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want { + t.Errorf("NegotiatedProtocol = %q; want %q", got, want) + } + if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want { + t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want) + } +} + type serverExpectTest struct { contentLength int // of request body chunked bool @@ -1016,6 +1146,7 @@ var serverExpectTests = []serverExpectTest{ // Tests that the server responds to the "Expect" request header // correctly. 
+// http2 test: TestServer_Response_Automatic100Continue func TestServerExpect(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1122,15 +1253,21 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) { done := make(chan bool) + readBufLen := func() int { + conn.readMu.Lock() + defer conn.readMu.Unlock() + return conn.readBuf.Len() + } + ls := &oneConnListener{conn} go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { defer close(done) - if conn.readBuf.Len() < len(body)/2 { - t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len()) + if bufLen := readBufLen(); bufLen < len(body)/2 { + t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen) } rw.WriteHeader(200) rw.(Flusher).Flush() - if g, e := conn.readBuf.Len(), 0; g != e { + if g, e := readBufLen(), 0; g != e { t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e) } if c := rw.Header().Get("Connection"); c != "" { @@ -1144,6 +1281,9 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) { // should ignore client request bodies that a handler didn't read // and close the connection. 
func TestServerUnreadRequestBodyLarge(t *testing.T) { + if testing.Short() && testenv.Builder() == "" { + t.Log("skipping in short mode") + } conn := new(testConn) body := strings.Repeat("x", 1<<20) conn.readBuf.Write([]byte(fmt.Sprintf( @@ -1274,6 +1414,9 @@ var handlerBodyCloseTests = [...]handlerBodyCloseTest{ } func TestHandlerBodyClose(t *testing.T) { + if testing.Short() && testenv.Builder() == "" { + t.Skip("skipping in -short mode") + } for i, tt := range handlerBodyCloseTests { testHandlerBodyClose(t, i, tt) } @@ -1306,15 +1449,21 @@ func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) { } conn.closec = make(chan bool, 1) + readBufLen := func() int { + conn.readMu.Lock() + defer conn.readMu.Unlock() + return conn.readBuf.Len() + } + ls := &oneConnListener{conn} var numReqs int var size0, size1 int go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { numReqs++ if numReqs == 1 { - size0 = conn.readBuf.Len() + size0 = readBufLen() req.Body.Close() - size1 = conn.readBuf.Len() + size1 = readBufLen() } })) <-conn.closec @@ -1414,7 +1563,9 @@ type slowTestConn struct { // over multiple calls to Read, time.Durations are slept, strings are read. 
script []interface{} closec chan bool - rd, wd time.Time // read, write deadline + + mu sync.Mutex // guards rd/wd + rd, wd time.Time // read, write deadline noopConn } @@ -1425,16 +1576,22 @@ func (c *slowTestConn) SetDeadline(t time.Time) error { } func (c *slowTestConn) SetReadDeadline(t time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() c.rd = t return nil } func (c *slowTestConn) SetWriteDeadline(t time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() c.wd = t return nil } func (c *slowTestConn) Read(b []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() restart: if !c.rd.IsZero() && time.Now().After(c.rd) { return 0, syscall.ETIMEDOUT @@ -1531,7 +1688,9 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) { } } -func TestTimeoutHandler(t *testing.T) { +func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } +func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } +func testTimeoutHandler(t *testing.T, h2 bool) { defer afterTest(t) sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) @@ -1541,12 +1700,12 @@ func TestTimeoutHandler(t *testing.T) { writeErrors <- werr }) timeout := make(chan time.Time, 1) // write to this to force timeouts - ts := httptest.NewServer(NewTestTimeoutHandler(sayHi, timeout)) - defer ts.Close() + cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout)) + defer cst.close() // Succeed without timing out: sendHi <- true - res, err := Get(ts.URL) + res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Error(err) } @@ -1563,7 +1722,7 @@ func TestTimeoutHandler(t *testing.T) { // Times out: timeout <- time.Time{} - res, err = Get(ts.URL) + res, err = cst.c.Get(cst.ts.URL) if err != nil { t.Error(err) } @@ -1659,6 +1818,60 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { wg.Wait() } +// Issue 9162 +func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { + defer afterTest(t) + sendHi := make(chan bool, 1) + writeErrors := make(chan error, 1) + sayHi 
:= HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/plain") + <-sendHi + _, werr := w.Write([]byte("hi")) + writeErrors <- werr + }) + timeout := make(chan time.Time, 1) // write to this to force timeouts + cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout)) + defer cst.close() + + // Succeed without timing out: + sendHi <- true + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusOK; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ := ioutil.ReadAll(res.Body) + if g, e := string(body), "hi"; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g := <-writeErrors; g != nil { + t.Errorf("got unexpected Write error on first request: %v", g) + } + + // Times out: + timeout <- time.Time{} + res, err = cst.c.Get(cst.ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusServiceUnavailable; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ = ioutil.ReadAll(res.Body) + if !strings.Contains(string(body), "<title>Timeout</title>") { + t.Errorf("expected timeout body; got %q", string(body)) + } + + // Now make the previously-timed out handler speak again, + // which verifies the panic is handled: + sendHi <- true + if g, e := <-writeErrors, ErrHandlerTimeout; g != e { + t.Errorf("expected Write error of %v; got %v", e, g) + } +} + // Verifies we don't path.Clean() on the wrong parts in redirects. 
func TestRedirectMunging(t *testing.T) { req, _ := NewRequest("GET", "http://example.com/", nil) @@ -1693,15 +1906,57 @@ func TestRedirectBadPath(t *testing.T) { } } +// Test different URL formats and schemes +func TestRedirectURLFormat(t *testing.T) { + req, _ := NewRequest("GET", "http://example.com/qux/", nil) + + var tests = []struct { + in string + want string + }{ + // normal http + {"http://foobar.com/baz", "http://foobar.com/baz"}, + // normal https + {"https://foobar.com/baz", "https://foobar.com/baz"}, + // custom scheme + {"test://foobar.com/baz", "test://foobar.com/baz"}, + // schemeless + {"//foobar.com/baz", "//foobar.com/baz"}, + // relative to the root + {"/foobar.com/baz", "/foobar.com/baz"}, + // relative to the current path + {"foobar.com/baz", "/qux/foobar.com/baz"}, + // relative to the current path (+ going upwards) + {"../quux/foobar.com/baz", "/quux/foobar.com/baz"}, + // incorrect number of slashes + {"///foobar.com/baz", "/foobar.com/baz"}, + } + + for _, tt := range tests { + rec := httptest.NewRecorder() + Redirect(rec, req, tt.in, 302) + if got := rec.Header().Get("Location"); got != tt.want { + t.Errorf("Redirect(%q) generated Location header %q; want %q", tt.in, got, tt.want) + } + } +} + // TestZeroLengthPostAndResponse exercises an optimization done by the Transport: // when there is no body (either because the method doesn't permit a body, or an // explicit Content-Length of zero is present), then the transport can re-use the // connection immediately. But when it re-uses the connection, it typically closes // the previous request's body, which is not optimal for zero-lengthed bodies, // as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. 
-func TestZeroLengthPostAndResponse(t *testing.T) { +func TestZeroLengthPostAndResponse_h1(t *testing.T) { + testZeroLengthPostAndResponse(t, h1Mode) +} +func TestZeroLengthPostAndResponse_h2(t *testing.T) { + testZeroLengthPostAndResponse(t, h2Mode) +} + +func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := ioutil.ReadAll(r.Body) if err != nil { t.Fatalf("handler ReadAll: %v", err) @@ -1711,9 +1966,9 @@ func TestZeroLengthPostAndResponse(t *testing.T) { } rw.Header().Set("Content-Length", "0") })) - defer ts.Close() + defer cst.close() - req, err := NewRequest("POST", ts.URL, strings.NewReader("")) + req, err := NewRequest("POST", cst.ts.URL, strings.NewReader("")) if err != nil { t.Fatal(err) } @@ -1721,7 +1976,7 @@ func TestZeroLengthPostAndResponse(t *testing.T) { var resp [5]*Response for i := range resp { - resp[i], err = DefaultClient.Do(req) + resp[i], err = cst.c.Do(req) if err != nil { t.Fatalf("client post #%d: %v", i, err) } @@ -1738,19 +1993,22 @@ func TestZeroLengthPostAndResponse(t *testing.T) { } } -func TestHandlerPanicNil(t *testing.T) { - testHandlerPanic(t, false, nil) -} +func TestHandlerPanicNil_h1(t *testing.T) { testHandlerPanic(t, false, h1Mode, nil) } +func TestHandlerPanicNil_h2(t *testing.T) { testHandlerPanic(t, false, h2Mode, nil) } -func TestHandlerPanic(t *testing.T) { - testHandlerPanic(t, false, "intentional death for testing") +func TestHandlerPanic_h1(t *testing.T) { + testHandlerPanic(t, false, h1Mode, "intentional death for testing") +} +func TestHandlerPanic_h2(t *testing.T) { + testHandlerPanic(t, false, h2Mode, "intentional death for testing") } func TestHandlerPanicWithHijack(t *testing.T) { - testHandlerPanic(t, true, "intentional death for testing") + // Only testing HTTP/1, and our http2 server doesn't support hijacking. 
+ testHandlerPanic(t, true, h1Mode, "intentional death for testing") } -func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { +func testHandlerPanic(t *testing.T, withHijack, h2 bool, panicValue interface{}) { defer afterTest(t) // Unlike the other tests that set the log output to ioutil.Discard // to quiet the output, this test uses a pipe. The pipe serves three @@ -1773,7 +2031,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { defer log.SetOutput(os.Stderr) defer pw.Close() - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { if withHijack { rwc, _, err := w.(Hijacker).Hijack() if err != nil { @@ -1783,7 +2041,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { } panic(panicValue) })) - defer ts.Close() + defer cst.close() // Do a blocking read on the log output pipe so its logging // doesn't bleed into the next test. 
But wait only 5 seconds @@ -1799,7 +2057,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { done <- true }() - _, err := Get(ts.URL) + _, err := cst.c.Get(cst.ts.URL) if err == nil { t.Logf("expected an error") } @@ -1816,17 +2074,19 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { } } -func TestServerNoDate(t *testing.T) { testServerNoHeader(t, "Date") } -func TestServerNoContentType(t *testing.T) { testServerNoHeader(t, "Content-Type") } +func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") } +func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") } +func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") } +func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") } -func testServerNoHeader(t *testing.T, header string) { +func testServerNoHeader(t *testing.T, h2 bool, header string) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()[header] = nil io.WriteString(w, "<html>foo</html>") // non-empty })) - defer ts.Close() - res, err := Get(ts.URL) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } @@ -1863,18 +2123,20 @@ func TestStripPrefix(t *testing.T) { res.Body.Close() } -func TestRequestLimit(t *testing.T) { +func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) } +func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) } +func testRequestLimit(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") })) - defer ts.Close() - req, _ := NewRequest("GET", ts.URL, nil) + defer 
cst.close() + req, _ := NewRequest("GET", cst.ts.URL, nil) var bytesPerHeader = len("header12345: val12345\r\n") for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ { req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i)) } - res, err := DefaultClient.Do(req) + res, err := cst.c.Do(req) if err != nil { // Some HTTP clients may fail on this undefined behavior (server replying and // closing the connection while the request is still being written), but @@ -1882,8 +2144,8 @@ func TestRequestLimit(t *testing.T) { t.Fatalf("Do: %v", err) } defer res.Body.Close() - if res.StatusCode != 413 { - t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status) + if res.StatusCode != 431 { + t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status) } } @@ -1907,10 +2169,12 @@ func (cr countReader) Read(p []byte) (n int, err error) { return } -func TestRequestBodyLimit(t *testing.T) { +func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) } +func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) } +func testRequestBodyLimit(t *testing.T, h2 bool) { defer afterTest(t) const limit = 1 << 20 - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) n, err := io.Copy(ioutil.Discard, r.Body) if err == nil { @@ -1920,10 +2184,10 @@ func TestRequestBodyLimit(t *testing.T) { t.Errorf("io.Copy = %d, want %d", n, limit) } })) - defer ts.Close() + defer cst.close() nWritten := new(int64) - req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) + req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) // Send the POST, but don't care it succeeds or not. 
The // remote side is going to reply and then close the TCP @@ -1934,7 +2198,7 @@ func TestRequestBodyLimit(t *testing.T) { // // But that's okay, since what we're really testing is that // the remote side hung up on us before we wrote too much. - _, _ = DefaultClient.Do(req) + _, _ = cst.c.Do(req) if atomic.LoadInt64(nWritten) > limit*100 { t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d", @@ -1982,7 +2246,7 @@ func TestClientWriteShutdown(t *testing.T) { // buffered before chunk headers are added, not after chunk headers. func TestServerBufferedChunking(t *testing.T) { conn := new(testConn) - conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n")) + conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")) conn.closec = make(chan bool, 1) ls := &oneConnListener{conn} go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -2045,20 +2309,23 @@ func TestServerGracefulClose(t *testing.T) { <-writeErr } -func TestCaseSensitiveMethod(t *testing.T) { +func TestCaseSensitiveMethod_h1(t *testing.T) { testCaseSensitiveMethod(t, h1Mode) } +func TestCaseSensitiveMethod_h2(t *testing.T) { testCaseSensitiveMethod(t, h2Mode) } +func testCaseSensitiveMethod(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "get" { t.Errorf(`Got method %q; want "get"`, r.Method) } })) - defer ts.Close() - req, _ := NewRequest("get", ts.URL, nil) - res, err := DefaultClient.Do(req) + defer cst.close() + req, _ := NewRequest("get", cst.ts.URL, nil) + res, err := cst.c.Do(req) if err != nil { t.Error(err) return } + res.Body.Close() } @@ -2131,6 +2398,49 @@ For: ts.Close() } +// Tests that a pipelined request causes the first request's Handler's CloseNotify +// channel to fire. Previously it deadlocked. 
+// +// Issue 13165 +func TestCloseNotifierPipelined(t *testing.T) { + defer afterTest(t) + gotReq := make(chan bool, 2) + sawClose := make(chan bool, 2) + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + gotReq <- true + cc := rw.(CloseNotifier).CloseNotify() + <-cc + sawClose <- true + })) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + diec := make(chan bool, 2) + go func() { + const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n" + _, err = io.WriteString(conn, req+req) // two requests + if err != nil { + t.Fatal(err) + } + <-diec + conn.Close() + }() +For: + for { + select { + case <-gotReq: + diec <- true + case <-sawClose: + break For + case <-time.After(5 * time.Second): + ts.CloseClientConnections() + t.Fatal("timeout") + } + } + ts.Close() +} + func TestCloseNotifierChanLeak(t *testing.T) { defer afterTest(t) req := reqBytes("GET / HTTP/1.0\nHost: golang.org") @@ -2153,6 +2463,114 @@ func TestCloseNotifierChanLeak(t *testing.T) { } } +// Tests that we can use CloseNotifier in one request, and later call Hijack +// on a second request on the same connection. +// +// It also tests that the connReader stitches together its background +// 1-byte read for CloseNotifier when CloseNotifier doesn't fire with +// the rest of the second HTTP later. +// +// Issue 9763. +// HTTP/1-only test. 
(http2 doesn't have Hijack) +func TestHijackAfterCloseNotifier(t *testing.T) { + defer afterTest(t) + script := make(chan string, 2) + script <- "closenotify" + script <- "hijack" + close(script) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + plan := <-script + switch plan { + default: + panic("bogus plan; too many requests") + case "closenotify": + w.(CloseNotifier).CloseNotify() // discard result + w.Header().Set("X-Addr", r.RemoteAddr) + case "hijack": + c, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack in Handler: %v", err) + return + } + if _, ok := c.(*net.TCPConn); !ok { + // Verify it's not wrapped in some type. + // Not strictly a go1 compat issue, but in practice it probably is. + t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c) + } + fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr) + c.Close() + return + } + })) + defer ts.Close() + res1, err := Get(ts.URL) + if err != nil { + log.Fatal(err) + } + res2, err := Get(ts.URL) + if err != nil { + log.Fatal(err) + } + addr1 := res1.Header.Get("X-Addr") + addr2 := res2.Header.Get("X-Addr") + if addr1 == "" || addr1 != addr2 { + t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2) + } +} + +func TestHijackBeforeRequestBodyRead(t *testing.T) { + defer afterTest(t) + var requestBody = bytes.Repeat([]byte("a"), 1<<20) + bodyOkay := make(chan bool, 1) + gotCloseNotify := make(chan bool, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + defer close(bodyOkay) // caller will read false if nothing else + + reqBody := r.Body + r.Body = nil // to test that server.go doesn't use this value. 
+ + gone := w.(CloseNotifier).CloseNotify() + slurp, err := ioutil.ReadAll(reqBody) + if err != nil { + t.Errorf("Body read: %v", err) + return + } + if len(slurp) != len(requestBody) { + t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) + return + } + if !bytes.Equal(slurp, requestBody) { + t.Error("Backend read wrong request body.") // 1MB; omitting details + return + } + bodyOkay <- true + select { + case <-gone: + gotCloseNotify <- true + case <-time.After(5 * time.Second): + gotCloseNotify <- false + } + })) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s", + len(requestBody), requestBody) + if !<-bodyOkay { + // already failed. + return + } + conn.Close() + if !<-gotCloseNotify { + t.Error("timeout waiting for CloseNotify") + } +} + func TestOptions(t *testing.T) { uric := make(chan string, 2) // only expect 1, but leave space for 2 mux := NewServeMux() @@ -2230,7 +2648,7 @@ func TestHeaderToWire(t *testing.T) { return errors.New("no content-length") } if !strings.Contains(got, "Content-Type: text/plain") { - return errors.New("no content-length") + return errors.New("no content-type") } return nil }, @@ -2302,7 +2720,7 @@ func TestHeaderToWire(t *testing.T) { return errors.New("header appeared from after WriteHeader") } if !strings.Contains(got, "Content-Type: some/type") { - return errors.New("wrong content-length") + return errors.New("wrong content-type") } return nil }, @@ -2315,7 +2733,7 @@ func TestHeaderToWire(t *testing.T) { }, check: func(got string) error { if !strings.Contains(got, "Content-Type: text/html") { - return errors.New("wrong content-length; want html") + return errors.New("wrong content-type; want html") } return nil }, @@ -2328,7 +2746,7 @@ func TestHeaderToWire(t *testing.T) { }, check: func(got string) error { if 
!strings.Contains(got, "Content-Type: some/type") { - return errors.New("wrong content-length; want html") + return errors.New("wrong content-type; want html") } return nil }, @@ -2339,7 +2757,7 @@ func TestHeaderToWire(t *testing.T) { }, check: func(got string) error { if !strings.Contains(got, "Content-Type: text/plain") { - return errors.New("wrong content-length; want text/plain") + return errors.New("wrong content-type; want text/plain") } if !strings.Contains(got, "Content-Length: 0") { return errors.New("want 0 content-length") @@ -2369,7 +2787,7 @@ func TestHeaderToWire(t *testing.T) { if !strings.Contains(got, "404") { return errors.New("wrong status") } - if strings.Contains(got, "Some-Header") { + if strings.Contains(got, "Too-Late") { return errors.New("shouldn't have seen Too-Late") } return nil @@ -2503,7 +2921,7 @@ func TestHTTP10ConnectionHeader(t *testing.T) { defer afterTest(t) mux := NewServeMux() - mux.Handle("/", HandlerFunc(func(resp ResponseWriter, req *Request) {})) + mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {})) ts := httptest.NewServer(mux) defer ts.Close() @@ -2552,11 +2970,13 @@ func TestHTTP10ConnectionHeader(t *testing.T) { } // See golang.org/issue/5660 -func TestServerReaderFromOrder(t *testing.T) { +func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) } +func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) } +func testServerReaderFromOrder(t *testing.T, h2 bool) { defer afterTest(t) pr, pw := io.Pipe() const size = 3 << 20 - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Type", "text/plain") // prevent sniffing path done := make(chan bool) go func() { @@ -2576,13 +2996,13 @@ func TestServerReaderFromOrder(t *testing.T) { pw.Close() <-done })) - defer ts.Close() + defer cst.close() - req, err := 
NewRequest("POST", ts.URL, io.LimitReader(neverEnding('a'), size)) + req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size)) if err != nil { t.Fatal(err) } - res, err := DefaultClient.Do(req) + res, err := cst.c.Do(req) if err != nil { t.Fatal(err) } @@ -2612,9 +3032,9 @@ func TestCodesPreventingContentTypeAndBody(t *testing.T) { "GET / HTTP/1.0", "GET /header HTTP/1.0", "GET /more HTTP/1.0", - "GET / HTTP/1.1", - "GET /header HTTP/1.1", - "GET /more HTTP/1.1", + "GET / HTTP/1.1\nHost: foo", + "GET /header HTTP/1.1\nHost: foo", + "GET /more HTTP/1.1\nHost: foo", } { got := ht.rawResponse(req) wantStatus := fmt.Sprintf("%d %s", code, StatusText(code)) @@ -2635,7 +3055,7 @@ func TestContentTypeOkayOn204(t *testing.T) { w.Header().Set("Content-Type", "foo/bar") w.WriteHeader(204) })) - got := ht.rawResponse("GET / HTTP/1.1") + got := ht.rawResponse("GET / HTTP/1.1\nHost: foo") if !strings.Contains(got, "Content-Type: foo/bar") { t.Errorf("Response = %q; want Content-Type: foo/bar", got) } @@ -2650,45 +3070,101 @@ func TestContentTypeOkayOn204(t *testing.T) { // proxy). So then two people own that Request.Body (both the server // and the http client), and both think they can close it on failure. // Therefore, all incoming server requests Bodies need to be thread-safe. -func TestTransportAndServerSharedBodyRace(t *testing.T) { +func TestTransportAndServerSharedBodyRace_h1(t *testing.T) { + testTransportAndServerSharedBodyRace(t, h1Mode) +} +func TestTransportAndServerSharedBodyRace_h2(t *testing.T) { + testTransportAndServerSharedBodyRace(t, h2Mode) +} +func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { defer afterTest(t) const bodySize = 1 << 20 + // errorf is like t.Errorf, but also writes to println. When + // this test fails, it hangs. This helps debugging and I've + // added this enough times "temporarily". It now gets added + // full time. 
+ errorf := func(format string, args ...interface{}) { + v := fmt.Sprintf(format, args...) + println(v) + t.Error(v) + } + unblockBackend := make(chan bool) - backend := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { - io.CopyN(rw, req.Body, bodySize) + backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + gone := rw.(CloseNotifier).CloseNotify() + didCopy := make(chan interface{}) + go func() { + n, err := io.CopyN(rw, req.Body, bodySize) + didCopy <- []interface{}{n, err} + }() + isGone := false + Loop: + for { + select { + case <-didCopy: + break Loop + case <-gone: + isGone = true + case <-time.After(time.Second): + println("1 second passes in backend, proxygone=", isGone) + } + } <-unblockBackend })) - defer backend.Close() + var quitTimer *time.Timer + defer func() { quitTimer.Stop() }() + defer backend.close() backendRespc := make(chan *Response, 1) - proxy := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { - req2, _ := NewRequest("POST", backend.URL, req.Body) + var proxy *clientServerTest + proxy = newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + req2, _ := NewRequest("POST", backend.ts.URL, req.Body) req2.ContentLength = bodySize + cancel := make(chan struct{}) + req2.Cancel = cancel - bresp, err := DefaultClient.Do(req2) + bresp, err := proxy.c.Do(req2) if err != nil { - t.Errorf("Proxy outbound request: %v", err) + errorf("Proxy outbound request: %v", err) return } _, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/2) if err != nil { - t.Errorf("Proxy copy error: %v", err) + errorf("Proxy copy error: %v", err) return } backendRespc <- bresp // to close later - // Try to cause a race: Both the DefaultTransport and the proxy handler's Server + // Try to cause a race: Both the Transport and the proxy handler's Server // will try to read/close req.Body (aka req2.Body) - DefaultTransport.(*Transport).CancelRequest(req2) + if h2 { + 
close(cancel) + } else { + proxy.c.Transport.(*Transport).CancelRequest(req2) + } rw.Write([]byte("OK")) })) - defer proxy.Close() + defer proxy.close() + defer func() { + // Before we shut down our two httptest.Servers, start a timer. + // We choose 7 seconds because httptest.Server starts logging + // warnings to stderr at 5 seconds. If we don't disarm this bomb + // in 7 seconds (after the two httptest.Server.Close calls above), + // then we explode with stacks. + quitTimer = time.AfterFunc(7*time.Second, func() { + debug.SetTraceback("ALL") + stacks := make([]byte, 1<<20) + stacks = stacks[:runtime.Stack(stacks, true)] + fmt.Fprintf(os.Stderr, "%s", stacks) + log.Fatalf("Timeout.") + }) + }() defer close(unblockBackend) - req, _ := NewRequest("POST", proxy.URL, io.LimitReader(neverEnding('a'), bodySize)) - res, err := DefaultClient.Do(req) + req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize)) + res, err := proxy.c.Do(req) if err != nil { t.Fatalf("Original request: %v", err) } @@ -2699,7 +3175,7 @@ func TestTransportAndServerSharedBodyRace(t *testing.T) { case res := <-backendRespc: res.Body.Close() default: - // We failed earlier. (e.g. on DefaultClient.Do(req2)) + // We failed earlier. (e.g. 
on proxy.c.Do(req2)) } } @@ -2863,6 +3339,7 @@ func TestServerConnState(t *testing.T) { if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil { t.Fatal(err) } + c.Read(make([]byte, 1)) // block until server hangs up on us c.Close() } @@ -2896,9 +3373,14 @@ func TestServerConnState(t *testing.T) { } logString := func(m map[int][]ConnState) string { var b bytes.Buffer - for id, l := range m { + var keys []int + for id := range m { + keys = append(keys, id) + } + sort.Ints(keys) + for _, id := range keys { fmt.Fprintf(&b, "Conn %d: ", id) - for _, s := range l { + for _, s := range m[id] { fmt.Fprintf(&b, "%s ", s) } b.WriteString("\n") @@ -2959,20 +3441,22 @@ func TestServerKeepAlivesEnabled(t *testing.T) { } // golang.org/issue/7856 -func TestServerEmptyBodyRace(t *testing.T) { +func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) } +func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) } +func testServerEmptyBodyRace(t *testing.T, h2 bool) { defer afterTest(t) var n int32 - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { atomic.AddInt32(&n, 1) })) - defer ts.Close() + defer cst.close() var wg sync.WaitGroup const reqs = 20 for i := 0; i < reqs; i++ { wg.Add(1) go func() { defer wg.Done() - res, err := Get(ts.URL) + res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Error(err) return @@ -3025,10 +3509,7 @@ func (c *closeWriteTestConn) CloseWrite() error { func TestCloseWrite(t *testing.T) { var srv Server var testConn closeWriteTestConn - c, err := ExportServerNewConn(&srv, &testConn) - if err != nil { - t.Fatal(err) - } + c := ExportServerNewConn(&srv, &testConn) ExportCloseWriteAndWait(c) if !testConn.didCloseWrite { t.Error("didn't see CloseWrite call") @@ -3193,6 +3674,33 @@ func TestTolerateCRLFBeforeRequestLine(t *testing.T) { } } +func TestIssue13893_Expect100(t 
*testing.T) { + // test that the Server doesn't filter out Expect headers. + req := reqBytes(`PUT /readbody HTTP/1.1 +User-Agent: PycURL/7.22.0 +Host: 127.0.0.1:9000 +Accept: */* +Expect: 100-continue +Content-Length: 10 + +HelloWorld + +`) + var buf bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &buf, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) { + if _, ok := r.Header["Expect"]; !ok { + t.Error("Expect header should not be filtered out") + } + })) + <-conn.closec +} + func TestIssue11549_Expect100(t *testing.T) { req := reqBytes(`PUT /readbody HTTP/1.1 User-Agent: PycURL/7.22.0 @@ -3260,6 +3768,122 @@ func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) { } } +func TestHandlerSetsBodyNil_h1(t *testing.T) { testHandlerSetsBodyNil(t, h1Mode) } +func TestHandlerSetsBodyNil_h2(t *testing.T) { testHandlerSetsBodyNil(t, h2Mode) } +func testHandlerSetsBodyNil(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + r.Body = nil + fmt.Fprintf(w, "%v", r.RemoteAddr) + })) + defer cst.close() + get := func() string { + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + return string(slurp) + } + a, b := get(), get() + if a != b { + t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b) + } +} + +// Test that we validate the Host header. 
+// Issue 11206 (invalid bytes in Host) and 13624 (Host present in HTTP/1.1) +func TestServerValidatesHostHeader(t *testing.T) { + tests := []struct { + proto string + host string + want int + }{ + {"HTTP/1.1", "", 400}, + {"HTTP/1.1", "Host: \r\n", 200}, + {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200}, + {"HTTP/1.1", "Host: foo.com\r\n", 200}, + {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200}, + {"HTTP/1.1", "Host: foo.com:80\r\n", 200}, + {"HTTP/1.1", "Host: ::1\r\n", 200}, + {"HTTP/1.1", "Host: [::1]\r\n", 200}, // questionable without port, but accept it + {"HTTP/1.1", "Host: [::1]:80\r\n", 200}, + {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200}, + {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200}, + {"HTTP/1.1", "Host: \x06\r\n", 400}, + {"HTTP/1.1", "Host: \xff\r\n", 400}, + {"HTTP/1.1", "Host: {\r\n", 400}, + {"HTTP/1.1", "Host: }\r\n", 400}, + {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400}, + + // HTTP/1.0 can lack a host header, but if present + // must play by the rules too: + {"HTTP/1.0", "", 200}, + {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400}, + {"HTTP/1.0", "Host: \xff\r\n", 400}, + } + for _, tt := range tests { + conn := &testConn{closec: make(chan bool, 1)} + io.WriteString(&conn.readBuf, "GET / "+tt.proto+"\r\n"+tt.host+"\r\n") + + ln := &oneConnListener{conn} + go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {})) + <-conn.closec + res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil) + if err != nil { + t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res) + continue + } + if res.StatusCode != tt.want { + t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want) + } + } +} + +// Test that we validate the valid bytes in HTTP/1 headers. +// Issue 11207. 
+func TestServerValidatesHeaders(t *testing.T) { + tests := []struct { + header string + want int + }{ + {"", 200}, + {"Foo: bar\r\n", 200}, + {"X-Foo: bar\r\n", 200}, + {"Foo: a space\r\n", 200}, + + {"A space: foo\r\n", 400}, // space in header + {"foo\xffbar: foo\r\n", 400}, // binary in header + {"foo\x00bar: foo\r\n", 400}, // binary in header + + {"foo: foo foo\r\n", 200}, // LWS space is okay + {"foo: foo\tfoo\r\n", 200}, // LWS tab is okay + {"foo: foo\x00foo\r\n", 400}, // CTL 0x00 in value is bad + {"foo: foo\x7ffoo\r\n", 400}, // CTL 0x7f in value is bad + {"foo: foo\xfffoo\r\n", 200}, // non-ASCII high octets in value are fine + } + for _, tt := range tests { + conn := &testConn{closec: make(chan bool, 1)} + io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n") + + ln := &oneConnListener{conn} + go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {})) + <-conn.closec + res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil) + if err != nil { + t.Errorf("For %q, ReadResponse: %v", tt.header, res) + continue + } + if res.StatusCode != tt.want { + t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want) + } + } +} + func BenchmarkClientServer(b *testing.B) { b.ReportAllocs() b.StopTimer() @@ -3685,3 +4309,35 @@ Host: golang.org <-conn.closec } } + +func BenchmarkCloseNotifier(b *testing.B) { + b.ReportAllocs() + b.StopTimer() + sawClose := make(chan bool) + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + <-rw.(CloseNotifier).CloseNotify() + sawClose <- true + })) + defer ts.Close() + tot := time.NewTimer(5 * time.Second) + defer tot.Stop() + b.StartTimer() + for i := 0; i < b.N; i++ { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + b.Fatalf("error dialing: %v", err) + } + _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n") + if err != nil { + b.Fatal(err) + } + conn.Close() + tot.Reset(5 * 
time.Second) + select { + case <-sawClose: + case <-tot.C: + b.Fatal("timeout") + } + } + b.StopTimer() +} diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index a3e43555bb3..004a1f92fc4 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -8,6 +8,7 @@ package http import ( "bufio" + "bytes" "crypto/tls" "errors" "fmt" @@ -35,26 +36,33 @@ var ( ErrContentLength = errors.New("Conn.Write wrote more than the declared Content-Length") ) -// Objects implementing the Handler interface can be -// registered to serve a particular path or subtree -// in the HTTP server. +// A Handler responds to an HTTP request. // // ServeHTTP should write reply headers and data to the ResponseWriter -// and then return. Returning signals that the request is finished -// and that the HTTP server can move on to the next request on -// the connection. +// and then return. Returning signals that the request is finished; it +// is not valid to use the ResponseWriter or read from the +// Request.Body after or concurrently with the completion of the +// ServeHTTP call. +// +// Depending on the HTTP client software, HTTP protocol version, and +// any intermediaries between the client and the Go server, it may not +// be possible to read from the Request.Body after writing to the +// ResponseWriter. Cautious handlers should read the Request.Body +// first, and then reply. // // If ServeHTTP panics, the server (the caller of ServeHTTP) assumes // that the effect of the panic was isolated to the active request. // It recovers the panic, logs a stack trace to the server error log, // and hangs up the connection. -// type Handler interface { ServeHTTP(ResponseWriter, *Request) } // A ResponseWriter interface is used by an HTTP handler to // construct an HTTP response. +// +// A ResponseWriter may not be used after the Handler.ServeHTTP method +// has returned. 
type ResponseWriter interface { // Header returns the header map that will be sent by // WriteHeader. Changing the header after a call to @@ -114,28 +122,76 @@ type Hijacker interface { // This mechanism can be used to cancel long operations on the server // if the client has disconnected before the response is ready. type CloseNotifier interface { - // CloseNotify returns a channel that receives a single value - // when the client connection has gone away. + // CloseNotify returns a channel that receives at most a + // single value (true) when the client connection has gone + // away. + // + // CloseNotify may wait to notify until Request.Body has been + // fully read. + // + // After the Handler has returned, there is no guarantee + // that the channel receives a value. + // + // If the protocol is HTTP/1.1 and CloseNotify is called while + // processing an idempotent request (such a GET) while + // HTTP/1.1 pipelining is in use, the arrival of a subsequent + // pipelined request may cause a value to be sent on the + // returned channel. In practice HTTP/1.1 pipelining is not + // enabled in browsers and not seen often in the wild. If this + // is a problem, use HTTP/2 or only use CloseNotify on methods + // such as POST. CloseNotify() <-chan bool } // A conn represents the server side of an HTTP connection. 
type conn struct { - remoteAddr string // network address of remote side - server *Server // the Server on which the connection arrived - rwc net.Conn // i/o connection - w io.Writer // checkConnErrorWriter's copy of wrc, not zeroed on Hijack - werr error // any errors writing to w - sr liveSwitchReader // where the LimitReader reads from; usually the rwc - lr *io.LimitedReader // io.LimitReader(sr) - buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc - tlsState *tls.ConnectionState // or nil when not using TLS - lastMethod string // method of previous request, or "" - - mu sync.Mutex // guards the following - clientGone bool // if client has disconnected mid-request - closeNotifyc chan bool // made lazily - hijackedv bool // connection has been hijacked by handler + // server is the server on which the connection arrived. + // Immutable; never nil. + server *Server + + // rwc is the underlying network connection. + // This is never wrapped by other types and is the value given out + // to CloseNotifier callers. It is usually of type *net.TCPConn or + // *tls.Conn. + rwc net.Conn + + // remoteAddr is rwc.RemoteAddr().String(). It is not populated synchronously + // inside the Listener's Accept goroutine, as some implementations block. + // It is populated immediately inside the (*conn).serve goroutine. + // This is the value of a Handler's (*Request).RemoteAddr. + remoteAddr string + + // tlsState is the TLS connection state when using TLS. + // nil means not TLS. + tlsState *tls.ConnectionState + + // werr is set to the first write error to rwc. + // It is set via checkConnErrorWriter{w}, where bufw writes. + werr error + + // r is bufr's read source. It's a wrapper around rwc that provides + // io.LimitedReader-style limiting (while reading request headers) + // and functionality to support CloseNotifier. See *connReader docs. + r *connReader + + // bufr reads from r. + // Users of bufr must hold mu. 
+ bufr *bufio.Reader + + // bufw writes to checkConnErrorWriter{c}, which populates werr on error. + bufw *bufio.Writer + + // lastMethod is the method of the most recent request + // on this connection, if any. + lastMethod string + + // mu guards hijackedv, use of bufr, (*response).closeNotifyCh. + mu sync.Mutex + + // hijackedv is whether this connection has been hijacked + // by a Handler with the Hijacker interface. + // It is guarded by mu. + hijackedv bool } func (c *conn) hijacked() bool { @@ -144,81 +200,18 @@ func (c *conn) hijacked() bool { return c.hijackedv } -func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { - c.mu.Lock() - defer c.mu.Unlock() +// c.mu must be held. +func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if c.hijackedv { return nil, nil, ErrHijacked } - if c.closeNotifyc != nil { - return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier") - } c.hijackedv = true rwc = c.rwc - buf = c.buf - c.rwc = nil - c.buf = nil + buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc)) c.setState(rwc, StateHijacked) return } -func (c *conn) closeNotify() <-chan bool { - c.mu.Lock() - defer c.mu.Unlock() - if c.closeNotifyc == nil { - c.closeNotifyc = make(chan bool, 1) - if c.hijackedv { - // to obey the function signature, even though - // it'll never receive a value. - return c.closeNotifyc - } - pr, pw := io.Pipe() - - readSource := c.sr.r - c.sr.Lock() - c.sr.r = pr - c.sr.Unlock() - go func() { - _, err := io.Copy(pw, readSource) - if err == nil { - err = io.EOF - } - pw.CloseWithError(err) - c.noteClientGone() - }() - } - return c.closeNotifyc -} - -func (c *conn) noteClientGone() { - c.mu.Lock() - defer c.mu.Unlock() - if c.closeNotifyc != nil && !c.clientGone { - c.closeNotifyc <- true - } - c.clientGone = true -} - -// A switchWriter can have its Writer changed at runtime. -// It's not safe for concurrent Writes and switches. 
-type switchWriter struct { - io.Writer -} - -// A liveSwitchReader can have its Reader changed at runtime. It's -// safe for concurrent reads and switches, if its mutex is held. -type liveSwitchReader struct { - sync.Mutex - r io.Reader -} - -func (sr *liveSwitchReader) Read(p []byte) (n int, err error) { - sr.Lock() - r := sr.r - sr.Unlock() - return r.Read(p) -} - // This should be >= 512 bytes for DetectContentType, // but otherwise it's somewhat arbitrary. const bufferBeforeChunkingSize = 2048 @@ -265,15 +258,15 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { return len(p), nil } if cw.chunking { - _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p)) + _, err = fmt.Fprintf(cw.res.conn.bufw, "%x\r\n", len(p)) if err != nil { cw.res.conn.rwc.Close() return } } - n, err = cw.res.conn.buf.Write(p) + n, err = cw.res.conn.bufw.Write(p) if cw.chunking && err == nil { - _, err = cw.res.conn.buf.Write(crlf) + _, err = cw.res.conn.bufw.Write(crlf) } if err != nil { cw.res.conn.rwc.Close() @@ -285,7 +278,7 @@ func (cw *chunkWriter) flush() { if !cw.wroteHeader { cw.writeHeader(nil) } - cw.res.conn.buf.Flush() + cw.res.conn.bufw.Flush() } func (cw *chunkWriter) close() { @@ -293,7 +286,7 @@ func (cw *chunkWriter) close() { cw.writeHeader(nil) } if cw.chunking { - bw := cw.res.conn.buf // conn's bufio writer + bw := cw.res.conn.bufw // conn's bufio writer // zero chunk to mark EOF bw.WriteString("0\r\n") if len(cw.res.trailers) > 0 { @@ -315,12 +308,12 @@ func (cw *chunkWriter) close() { type response struct { conn *conn req *Request // request for this response - wroteHeader bool // reply header has been (logically) written - wroteContinue bool // 100 Continue response was written + reqBody io.ReadCloser + wroteHeader bool // reply header has been (logically) written + wroteContinue bool // 100 Continue response was written w *bufio.Writer // buffers output in chunks to chunkWriter cw chunkWriter - sw *switchWriter // of the bufio.Writer, for return to 
putBufioWriter // handlerHeader is the Header that Handlers get access to, // which may be retained and mutated even after WriteHeader. @@ -354,13 +347,22 @@ type response struct { // written. trailers []string - handlerDone bool // set true when the handler exits + handlerDone atomicBool // set true when the handler exits // Buffers for Date and Content-Length dateBuf [len(TimeFormat)]byte clenBuf [10]byte + + // closeNotifyCh is non-nil once CloseNotify is called. + // Guarded by conn.mu + closeNotifyCh <-chan bool } +type atomicBool int32 + +func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } +func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } + // declareTrailer is called for each Trailer header when the // response header is written. It notes that a header will need to be // written in the trailers at the end of the response. @@ -423,7 +425,9 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) { return 0, err } if !ok || !regFile { - return io.Copy(writerOnly{w}, src) + bufp := copyBufPool.Get().(*[]byte) + defer copyBufPool.Put(bufp) + return io.CopyBuffer(writerOnly{w}, src, *bufp) } // sendfile path: @@ -456,29 +460,88 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) { return n, err } -// noLimit is an effective infinite upper bound for io.LimitedReader -const noLimit int64 = (1 << 63) - 1 - // debugServerConnections controls whether all server connections are wrapped // with a verbose logging wrapper. const debugServerConnections = false // Create new connection from rwc. 
-func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { - c = new(conn) - c.remoteAddr = rwc.RemoteAddr().String() - c.server = srv - c.rwc = rwc - c.w = rwc +func (srv *Server) newConn(rwc net.Conn) *conn { + c := &conn{ + server: srv, + rwc: rwc, + } if debugServerConnections { c.rwc = newLoggingConn("server", c.rwc) } - c.sr.r = c.rwc - c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader) - br := newBufioReader(c.lr) - bw := newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) - c.buf = bufio.NewReadWriter(br, bw) - return c, nil + return c +} + +type readResult struct { + n int + err error + b byte // byte read, if n == 1 +} + +// connReader is the io.Reader wrapper used by *conn. It combines a +// selectively-activated io.LimitedReader (to bound request header +// read sizes) with support for selectively keeping an io.Reader.Read +// call blocked in a background goroutine to wait for activity and +// trigger a CloseNotifier channel. +type connReader struct { + r io.Reader + remain int64 // bytes remaining + + // ch is non-nil if a background read is in progress. + // It is guarded by conn.mu. + ch chan readResult +} + +func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } +func (cr *connReader) setInfiniteReadLimit() { cr.remain = 1<<63 - 1 } +func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } + +func (cr *connReader) Read(p []byte) (n int, err error) { + if cr.hitReadLimit() { + return 0, io.EOF + } + if len(p) == 0 { + return + } + if int64(len(p)) > cr.remain { + p = p[:cr.remain] + } + + // Is a background read (started by CloseNotifier) already in + // flight? If so, wait for it and use its result. 
+ ch := cr.ch + if ch != nil { + cr.ch = nil + res := <-ch + if res.n == 1 { + p[0] = res.b + cr.remain -= 1 + } + return res.n, res.err + } + n, err = cr.r.Read(p) + cr.remain -= int64(n) + return +} + +func (cr *connReader) startBackgroundRead(onReadComplete func()) { + if cr.ch != nil { + // Background read already started. + return + } + cr.ch = make(chan readResult, 1) + go cr.closeNotifyAwaitActivityRead(cr.ch, onReadComplete) +} + +func (cr *connReader) closeNotifyAwaitActivityRead(ch chan<- readResult, onReadComplete func()) { + var buf [1]byte + n, err := cr.r.Read(buf[:1]) + onReadComplete() + ch <- readResult{n, err, buf[0]} } var ( @@ -487,6 +550,13 @@ var ( bufioWriter4kPool sync.Pool ) +var copyBufPool = sync.Pool{ + New: func() interface{} { + b := make([]byte, 32*1024) + return &b + }, +} + func bufioWriterPool(size int) *sync.Pool { switch size { case 2 << 10: @@ -544,7 +614,7 @@ func (srv *Server) maxHeaderBytes() int { return DefaultMaxHeaderBytes } -func (srv *Server) initialLimitedReaderSize() int64 { +func (srv *Server) initialReadLimitSize() int64 { return int64(srv.maxHeaderBytes()) + 4096 // bufio slop } @@ -563,8 +633,8 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { } if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() { ecr.resp.wroteContinue = true - ecr.resp.conn.buf.WriteString("HTTP/1.1 100 Continue\r\n\r\n") - ecr.resp.conn.buf.Flush() + ecr.resp.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") + ecr.resp.conn.bufw.Flush() } n, err = ecr.readCloser.Read(p) if err == io.EOF { @@ -578,10 +648,12 @@ func (ecr *expectContinueReader) Close() error { return ecr.readCloser.Close() } -// TimeFormat is the time format to use with -// time.Parse and time.Time.Format when parsing -// or generating times in HTTP headers. -// It is like time.RFC1123 but hard codes GMT as the time zone. +// TimeFormat is the time format to use when generating times in HTTP +// headers. 
It is like time.RFC1123 but hard-codes GMT as the time +// zone. The time being formatted must be in UTC for Format to +// generate the correct format. +// +// For parsing this time format, see ParseTime. const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT" // appendTime is a non-allocating version of []byte(t.UTC().Format(TimeFormat)) @@ -623,21 +695,45 @@ func (c *conn) readRequest() (w *response, err error) { }() } - c.lr.N = c.server.initialLimitedReaderSize() + c.r.setReadLimit(c.server.initialReadLimitSize()) + c.mu.Lock() // while using bufr if c.lastMethod == "POST" { // RFC 2616 section 4.1 tolerance for old buggy clients. - peek, _ := c.buf.Reader.Peek(4) // ReadRequest will get err below - c.buf.Reader.Discard(numLeadingCRorLF(peek)) + peek, _ := c.bufr.Peek(4) // ReadRequest will get err below + c.bufr.Discard(numLeadingCRorLF(peek)) } - var req *Request - if req, err = ReadRequest(c.buf.Reader); err != nil { - if c.lr.N == 0 { + req, err := readRequest(c.bufr, keepHostHeader) + c.mu.Unlock() + if err != nil { + if c.r.hitReadLimit() { return nil, errTooLarge } return nil, err } - c.lr.N = noLimit c.lastMethod = req.Method + c.r.setInfiniteReadLimit() + + hosts, haveHost := req.Header["Host"] + if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) { + return nil, badRequestError("missing required Host header") + } + if len(hosts) > 1 { + return nil, badRequestError("too many Host headers") + } + if len(hosts) == 1 && !validHostHeader(hosts[0]) { + return nil, badRequestError("malformed Host header") + } + for k, vv := range req.Header { + if !validHeaderName(k) { + return nil, badRequestError("invalid header name") + } + for _, v := range vv { + if !validHeaderValue(v) { + return nil, badRequestError("invalid header value") + } + } + } + delete(req.Header, "Host") req.RemoteAddr = c.remoteAddr req.TLS = c.tlsState @@ -648,6 +744,7 @@ func (c *conn) readRequest() (w *response, err error) { w = &response{ conn: c, req: req, + reqBody: req.Body, 
handlerHeader: make(Header), contentLength: -1, } @@ -755,7 +852,7 @@ func (h extraHeader) Write(w *bufio.Writer) { } // writeHeader finalizes the header sent to the client and writes it -// to cw.res.conn.buf. +// to cw.res.conn.bufw. // // p is not written by writeHeader, but is the first chunk of the body // that will be written. It is sniffed for a Content-Type if none is @@ -821,7 +918,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { // send a Content-Length header. // Further, we don't send an automatic Content-Length if they // set a Transfer-Encoding, because they're generally incompatible. - if w.handlerDone && !trailers && !hasTE && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { + if w.handlerDone.isSet() && !trailers && !hasTE && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { w.contentLength = int64(len(p)) setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10) } @@ -898,7 +995,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { } if discard { - _, err := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1) + _, err := io.CopyN(ioutil.Discard, w.reqBody, maxPostHandlerReadBytes+1) switch err { case nil: // There must be even more data left over. @@ -907,7 +1004,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { // Body was already consumed and closed. case io.EOF: // The remaining body was just consumed, close it. 
- err = w.req.Body.Close() + err = w.reqBody.Close() if err != nil { w.closeAfterReply = true } @@ -996,10 +1093,10 @@ func (cw *chunkWriter) writeHeader(p []byte) { } } - w.conn.buf.WriteString(statusLine(w.req, code)) - cw.header.WriteSubset(w.conn.buf, excludeHeader) - setHeader.Write(w.conn.buf.Writer) - w.conn.buf.Write(crlf) + w.conn.bufw.WriteString(statusLine(w.req, code)) + cw.header.WriteSubset(w.conn.bufw, excludeHeader) + setHeader.Write(w.conn.bufw) + w.conn.bufw.Write(crlf) } // foreachHeaderElement splits v according to the "#rule" construction @@ -1144,7 +1241,7 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er } func (w *response) finishRequest() { - w.handlerDone = true + w.handlerDone.setTrue() if !w.wroteHeader { w.WriteHeader(StatusOK) @@ -1153,11 +1250,11 @@ func (w *response) finishRequest() { w.w.Flush() putBufioWriter(w.w) w.cw.close() - w.conn.buf.Flush() + w.conn.bufw.Flush() // Close the body (regardless of w.closeAfterReply) so we can // re-use its bufio.Reader later safely. - w.req.Body.Close() + w.reqBody.Close() if w.req.MultipartForm != nil { w.req.MultipartForm.RemoveAll() @@ -1206,28 +1303,26 @@ func (w *response) Flush() { } func (c *conn) finalFlush() { - if c.buf != nil { - c.buf.Flush() - + if c.bufr != nil { // Steal the bufio.Reader (~4KB worth of memory) and its associated // reader for a future connection. - putBufioReader(c.buf.Reader) + putBufioReader(c.bufr) + c.bufr = nil + } + if c.bufw != nil { + c.bufw.Flush() // Steal the bufio.Writer (~4KB worth of memory) and its associated // writer for a future connection. - putBufioWriter(c.buf.Writer) - - c.buf = nil + putBufioWriter(c.bufw) + c.bufw = nil } } // Close the connection. 
func (c *conn) close() { c.finalFlush() - if c.rwc != nil { - c.rwc.Close() - c.rwc = nil - } + c.rwc.Close() } // rstAvoidanceDelay is the amount of time we sleep after closing the @@ -1277,9 +1372,16 @@ func (c *conn) setState(nc net.Conn, state ConnState) { } } +// badRequestError is a literal string (used by in the server in HTML, +// unescaped) to tell the user why their request was bad. It should +// be plain text without user info or other embeddded errors. +type badRequestError string + +func (e badRequestError) Error() string { return "Bad Request: " + string(e) } + // Serve a new connection. func (c *conn) serve() { - origConn := c.rwc // copy it before it's set nil on Close or Hijack + c.remoteAddr = c.rwc.RemoteAddr().String() defer func() { if err := recover(); err != nil { const size = 64 << 10 @@ -1289,7 +1391,7 @@ func (c *conn) serve() { } if !c.hijacked() { c.close() - c.setState(origConn, StateClosed) + c.setState(c.rwc, StateClosed) } }() @@ -1315,9 +1417,13 @@ func (c *conn) serve() { } } + c.r = &connReader{r: c.rwc} + c.bufr = newBufioReader(c.r) + c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) + for { w, err := c.readRequest() - if c.lr.N != c.server.initialLimitedReaderSize() { + if c.r.remain != c.server.initialReadLimitSize() { // If we read any bytes off the wire, we're active. c.setState(c.rwc, StateActive) } @@ -1328,16 +1434,22 @@ func (c *conn) serve() { // responding to them and hanging up // while they're still writing their // request. Undefined behavior. 
- io.WriteString(c.rwc, "HTTP/1.1 413 Request Entity Too Large\r\n\r\n") + io.WriteString(c.rwc, "HTTP/1.1 431 Request Header Fields Too Large\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n431 Request Header Fields Too Large") c.closeWriteAndWait() - break - } else if err == io.EOF { - break // Don't reply - } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() { - break // Don't reply + return } - io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\n\r\n") - break + if err == io.EOF { + return // don't reply + } + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + return // don't reply + } + var publicErr string + if v, ok := err.(badRequestError); ok { + publicErr = ": " + string(v) + } + io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request"+publicErr) + return } // Expect 100 Continue support @@ -1347,10 +1459,9 @@ func (c *conn) serve() { // Wrap the Body reader with one that replies on the connection req.Body = &expectContinueReader{readCloser: req.Body, resp: w} } - req.Header.Del("Expect") } else if req.Header.get("Expect") != "" { w.sendExpectationFailed() - break + return } // HTTP cannot have multiple simultaneous active requests.[*] @@ -1367,7 +1478,7 @@ func (c *conn) serve() { if w.requestBodyLimitHit || w.closedRequestBodyEarly() { c.closeWriteAndWait() } - break + return } c.setState(c.rwc, StateIdle) } @@ -1394,12 +1505,24 @@ func (w *response) sendExpectationFailed() { // Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter // and a Hijacker. 
func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { + if w.handlerDone.isSet() { + panic("net/http: Hijack called after ServeHTTP finished") + } if w.wroteHeader { w.cw.flush() } + + c := w.conn + c.mu.Lock() + defer c.mu.Unlock() + + if w.closeNotifyCh != nil { + return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier in same ServeHTTP call") + } + // Release the bufioWriter that writes to the chunk writer, it is not // used after a connection has been hijacked. - rwc, buf, err = w.conn.hijack() + rwc, buf, err = c.hijackLocked() if err == nil { putBufioWriter(w.w) w.w = nil @@ -1408,13 +1531,86 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { } func (w *response) CloseNotify() <-chan bool { - return w.conn.closeNotify() + if w.handlerDone.isSet() { + panic("net/http: CloseNotify called after ServeHTTP finished") + } + c := w.conn + c.mu.Lock() + defer c.mu.Unlock() + + if w.closeNotifyCh != nil { + return w.closeNotifyCh + } + ch := make(chan bool, 1) + w.closeNotifyCh = ch + + if w.conn.hijackedv { + // CloseNotify is undefined after a hijack, but we have + // no place to return an error, so just return a channel, + // even though it'll never receive a value. + return ch + } + + var once sync.Once + notify := func() { once.Do(func() { ch <- true }) } + + if requestBodyRemains(w.reqBody) { + // They're still consuming the request body, so we + // shouldn't notify yet. + registerOnHitEOF(w.reqBody, func() { + c.mu.Lock() + defer c.mu.Unlock() + startCloseNotifyBackgroundRead(c, notify) + }) + } else { + startCloseNotifyBackgroundRead(c, notify) + } + return ch +} + +// c.mu must be held. +func startCloseNotifyBackgroundRead(c *conn, notify func()) { + if c.bufr.Buffered() > 0 { + // They've consumed the request body, so anything + // remaining is a pipelined request, which we + // document as firing on. 
+ notify() + } else { + c.r.startBackgroundRead(notify) + } +} + +func registerOnHitEOF(rc io.ReadCloser, fn func()) { + switch v := rc.(type) { + case *expectContinueReader: + registerOnHitEOF(v.readCloser, fn) + case *body: + v.registerOnHitEOF(fn) + default: + panic("unexpected type " + fmt.Sprintf("%T", rc)) + } +} + +// requestBodyRemains reports whether future calls to Read +// on rc might yield more data. +func requestBodyRemains(rc io.ReadCloser) bool { + if rc == eofReader { + return false + } + switch v := rc.(type) { + case *expectContinueReader: + return requestBodyRemains(v.readCloser) + case *body: + return v.bodyRemains() + default: + panic("unexpected type " + fmt.Sprintf("%T", rc)) + } } // The HandlerFunc type is an adapter to allow the use of // ordinary functions as HTTP handlers. If f is a function // with the appropriate signature, HandlerFunc(f) is a -// Handler object that calls f. +// Handler that calls f. type HandlerFunc func(ResponseWriter, *Request) // ServeHTTP calls f(w, r). @@ -1461,6 +1657,9 @@ func StripPrefix(prefix string, h Handler) Handler { // Redirect replies to the request with a redirect to url, // which may be a path relative to the request path. +// +// The provided code should be in the 3xx range and is usually +// StatusMovedPermanently, StatusFound or StatusSeeOther. func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { if u, err := url.Parse(urlStr); err == nil { // If url was relative, make absolute by @@ -1479,11 +1678,12 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { // Because of this problem, no one pays attention // to the RFC; they all send back just a new path. // So do we. 
- oldpath := r.URL.Path - if oldpath == "" { // should not happen, but avoid a crash if it does - oldpath = "/" - } - if u.Scheme == "" { + if u.Scheme == "" && u.Host == "" { + oldpath := r.URL.Path + if oldpath == "" { // should not happen, but avoid a crash if it does + oldpath = "/" + } + // no leading http://server if urlStr == "" || urlStr[0] != '/' { // make relative path absolute @@ -1545,6 +1745,9 @@ func (rh *redirectHandler) ServeHTTP(w ResponseWriter, r *Request) { // RedirectHandler returns a request handler that redirects // each request it receives to the given url using the given // status code. +// +// The provided code should be in the 3xx range and is usually +// StatusMovedPermanently, StatusFound or StatusSeeOther. func RedirectHandler(url string, code int) Handler { return &redirectHandler{url, code} } @@ -1567,6 +1770,14 @@ func RedirectHandler(url string, code int) Handler { // the pattern "/" matches all paths not matched by other registered // patterns, not just the URL with Path == "/". // +// If a subtree has been registered and a request is received naming the +// subtree root without its trailing slash, ServeMux redirects that +// request to the subtree root (adding the trailing slash). This behavior can +// be overridden with a separate registration for the path without +// the trailing slash. For example, registering "/images/" causes ServeMux +// to redirect a request for "/images" to "/images/", unless "/images" has +// been registered separately. +// // Patterns may optionally begin with a host name, restricting matches to // URLs on that host only. Host-specific patterns take precedence over // general patterns, so that a handler might register for the two patterns @@ -1574,8 +1785,8 @@ func RedirectHandler(url string, code int) Handler { // requests for "http://www.google.com/". // // ServeMux also takes care of sanitizing the URL request path, -// redirecting any request containing . or .. 
elements to an -// equivalent .- and ..-free URL. +// redirecting any request containing . or .. elements or repeated slashes +// to an equivalent, cleaner URL. type ServeMux struct { mu sync.RWMutex m map[string]muxEntry @@ -1782,6 +1993,7 @@ type Server struct { // handle HTTP requests and will initialize the Request's TLS // and RemoteAddr if not already set. The connection is // automatically closed when the function returns. + // If TLSNextProto is nil, HTTP/2 support is enabled automatically. TLSNextProto map[string]func(*Server, *tls.Conn, Handler) // ConnState specifies an optional callback function that is @@ -1795,7 +2007,9 @@ type Server struct { // standard logger. ErrorLog *log.Logger - disableKeepAlives int32 // accessed atomically. + disableKeepAlives int32 // accessed atomically. + nextProtoOnce sync.Once // guards initialization of TLSNextProto in Serve + nextProtoErr error } // A ConnState represents the state of a client connection to a server. @@ -1815,6 +2029,11 @@ const ( // and doesn't fire again until the request has been // handled. After the request is handled, the state // transitions to StateClosed, StateHijacked, or StateIdle. + // For HTTP/2, StateActive fires on the transition from zero + // to one active request, and only transitions away once all + // active requests are complete. That means that ConnState + // can not be used to do per-request work; ConnState only notes + // the overall state of the connection. StateActive // StateIdle represents a connection that has finished @@ -1863,8 +2082,10 @@ func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) { } // ListenAndServe listens on the TCP network address srv.Addr and then -// calls Serve to handle requests on incoming connections. If -// srv.Addr is blank, ":http" is used. +// calls Serve to handle requests on incoming connections. +// Accepted connections are configured to enable TCP keep-alives. +// If srv.Addr is blank, ":http" is used. 
+// ListenAndServe always returns a non-nil error. func (srv *Server) ListenAndServe() error { addr := srv.Addr if addr == "" { @@ -1877,12 +2098,21 @@ func (srv *Server) ListenAndServe() error { return srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}) } +var testHookServerServe func(*Server, net.Listener) // used if non-nil + // Serve accepts incoming connections on the Listener l, creating a -// new service goroutine for each. The service goroutines read requests and +// new service goroutine for each. The service goroutines read requests and // then call srv.Handler to reply to them. +// Serve always returns a non-nil error. func (srv *Server) Serve(l net.Listener) error { defer l.Close() + if fn := testHookServerServe; fn != nil { + fn(srv, l) + } var tempDelay time.Duration // how long to sleep on accept failure + if err := srv.setupHTTP2(); err != nil { + return err + } for { rw, e := l.Accept() if e != nil { @@ -1902,10 +2132,7 @@ func (srv *Server) Serve(l net.Listener) error { return e } tempDelay = 0 - c, err := srv.newConn(rw) - if err != nil { - continue - } + c := srv.newConn(rw) c.setState(c.rwc, StateNew) // before Serve can return go c.serve() } @@ -1937,8 +2164,10 @@ func (s *Server) logf(format string, args ...interface{}) { // ListenAndServe listens on the TCP network address addr // and then calls Serve with handler to handle requests -// on incoming connections. Handler is typically nil, -// in which case the DefaultServeMux is used. +// on incoming connections. +// Accepted connections are configured to enable TCP keep-alives. +// Handler is typically nil, in which case the DefaultServeMux is +// used. 
// // A trivial example server is: // @@ -1957,11 +2186,10 @@ func (s *Server) logf(format string, args ...interface{}) { // // func main() { // http.HandleFunc("/hello", HelloServer) -// err := http.ListenAndServe(":12345", nil) -// if err != nil { -// log.Fatal("ListenAndServe: ", err) -// } +// log.Fatal(http.ListenAndServe(":12345", nil)) // } +// +// ListenAndServe always returns a non-nil error. func ListenAndServe(addr string, handler Handler) error { server := &Server{Addr: addr, Handler: handler} return server.ListenAndServe() @@ -1989,19 +2217,20 @@ func ListenAndServe(addr string, handler Handler) error { // http.HandleFunc("/", handler) // log.Printf("About to listen on 10443. Go to https://127.0.0.1:10443/") // err := http.ListenAndServeTLS(":10443", "cert.pem", "key.pem", nil) -// if err != nil { -// log.Fatal(err) -// } +// log.Fatal(err) // } // // One can use generate_cert.go in crypto/tls to generate cert.pem and key.pem. -func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Handler) error { +// +// ListenAndServeTLS always returns a non-nil error. +func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error { server := &Server{Addr: addr, Handler: handler} return server.ListenAndServeTLS(certFile, keyFile) } // ListenAndServeTLS listens on the TCP network address srv.Addr and // then calls Serve to handle requests on incoming TLS connections. +// Accepted connections are configured to enable TCP keep-alives. // // Filenames containing a certificate and matching private key for the // server must be provided if the Server's TLSConfig.Certificates is @@ -2010,14 +2239,23 @@ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Han // certificate, any intermediates, and the CA's certificate. // // If srv.Addr is blank, ":https" is used. +// +// ListenAndServeTLS always returns a non-nil error. 
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { addr := srv.Addr if addr == "" { addr = ":https" } + + // Setup HTTP/2 before srv.Serve, to initialize srv.TLSConfig + // before we clone it and create the TLS Listener. + if err := srv.setupHTTP2(); err != nil { + return err + } + config := cloneTLSConfig(srv.TLSConfig) - if config.NextProtos == nil { - config.NextProtos = []string{"http/1.1"} + if !strSliceContains(config.NextProtos, "http/1.1") { + config.NextProtos = append(config.NextProtos, "http/1.1") } if len(config.Certificates) == 0 || certFile != "" || keyFile != "" { @@ -2038,6 +2276,25 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { return srv.Serve(tlsListener) } +func (srv *Server) setupHTTP2() error { + srv.nextProtoOnce.Do(srv.onceSetNextProtoDefaults) + return srv.nextProtoErr +} + +// onceSetNextProtoDefaults configures HTTP/2, if the user hasn't +// configured otherwise. (by setting srv.TLSNextProto non-nil) +// It must only be called via srv.nextProtoOnce (use srv.setupHTTP2). +func (srv *Server) onceSetNextProtoDefaults() { + if strings.Contains(os.Getenv("GODEBUG"), "http2server=0") { + return + } + // Enable HTTP/2 by default if the user hasn't otherwise + // configured their TLSNextProto map. + if srv.TLSNextProto == nil { + srv.nextProtoErr = http2ConfigureServer(srv, nil) + } +} + // TimeoutHandler returns a Handler that runs h with the given time limit. // // The new Handler calls h.ServeHTTP to handle each request, but if a @@ -2046,11 +2303,20 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { // (If msg is empty, a suitable default message will be sent.) // After such a timeout, writes by h to its ResponseWriter will return // ErrHandlerTimeout. +// +// TimeoutHandler buffers all Handler writes to memory and does not +// support the Hijacker or Flusher interfaces. 
func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler { - f := func() <-chan time.Time { - return time.After(dt) + t := time.NewTimer(dt) + return &timeoutHandler{ + handler: h, + body: msg, + + // Effectively storing a *time.Timer, but decomposed + // for testing: + timeout: func() <-chan time.Time { return t.C }, + cancelTimer: t.Stop, } - return &timeoutHandler{h, f, msg} } // ErrHandlerTimeout is returned on ResponseWriter Write calls @@ -2059,8 +2325,13 @@ var ErrHandlerTimeout = errors.New("http: Handler timeout") type timeoutHandler struct { handler Handler - timeout func() <-chan time.Time // returns channel producing a timeout body string + + // timeout returns the channel of a *time.Timer and + // cancelTimer cancels it. They're stored separately for + // testing purposes. + timeout func() <-chan time.Time // returns channel producing a timeout + cancelTimer func() bool // optional } func (h *timeoutHandler) errorBody() string { @@ -2071,46 +2342,61 @@ func (h *timeoutHandler) errorBody() string { } func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { - done := make(chan bool, 1) - tw := &timeoutWriter{w: w} + done := make(chan struct{}) + tw := &timeoutWriter{ + w: w, + h: make(Header), + } go func() { h.handler.ServeHTTP(tw, r) - done <- true + close(done) }() select { case <-done: - return - case <-h.timeout(): tw.mu.Lock() defer tw.mu.Unlock() - if !tw.wroteHeader { - tw.w.WriteHeader(StatusServiceUnavailable) - tw.w.Write([]byte(h.errorBody())) + dst := w.Header() + for k, vv := range tw.h { + dst[k] = vv + } + w.WriteHeader(tw.code) + w.Write(tw.wbuf.Bytes()) + if h.cancelTimer != nil { + h.cancelTimer() } + case <-h.timeout(): + tw.mu.Lock() + defer tw.mu.Unlock() + w.WriteHeader(StatusServiceUnavailable) + io.WriteString(w, h.errorBody()) tw.timedOut = true + return } } type timeoutWriter struct { - w ResponseWriter + w ResponseWriter + h Header + wbuf bytes.Buffer mu sync.Mutex timedOut bool wroteHeader bool + code 
int } -func (tw *timeoutWriter) Header() Header { - return tw.w.Header() -} +func (tw *timeoutWriter) Header() Header { return tw.h } func (tw *timeoutWriter) Write(p []byte) (int, error) { tw.mu.Lock() defer tw.mu.Unlock() - tw.wroteHeader = true // implicitly at least if tw.timedOut { return 0, ErrHandlerTimeout } - return tw.w.Write(p) + if !tw.wroteHeader { + tw.writeHeader(StatusOK) + } + return tw.wbuf.Write(p) } func (tw *timeoutWriter) WriteHeader(code int) { @@ -2119,8 +2405,12 @@ func (tw *timeoutWriter) WriteHeader(code int) { if tw.timedOut || tw.wroteHeader { return } + tw.writeHeader(code) +} + +func (tw *timeoutWriter) writeHeader(code int) { tw.wroteHeader = true - tw.w.WriteHeader(code) + tw.code = code } // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted @@ -2247,7 +2537,7 @@ type checkConnErrorWriter struct { } func (w checkConnErrorWriter) Write(p []byte) (n int, err error) { - n, err = w.c.w.Write(p) // c.w == c.rwc, except after a hijack, when rwc is nil. + n, err = w.c.rwc.Write(p) if err != nil && w.c.werr == nil { w.c.werr = err } @@ -2265,3 +2555,12 @@ func numLeadingCRorLF(v []byte) (n int) { return } + +func strSliceContains(ss []string, s string) bool { + for _, v := range ss { + if v == s { + return true + } + } + return false +} diff --git a/libgo/go/net/http/sniff.go b/libgo/go/net/http/sniff.go index 3be8c865d3b..18810bad068 100644 --- a/libgo/go/net/http/sniff.go +++ b/libgo/go/net/http/sniff.go @@ -102,10 +102,9 @@ var sniffSignatures = []sniffSig{ &exactSig{[]byte("\x50\x4B\x03\x04"), "application/zip"}, &exactSig{[]byte("\x1F\x8B\x08"), "application/x-gzip"}, - // TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4. 
- //mp4Sig(0), + mp4Sig{}, - textSig(0), // should be last + textSig{}, // should be last } type exactSig struct { @@ -166,12 +165,14 @@ func (h htmlSig) match(data []byte, firstNonWS int) string { } var mp4ftype = []byte("ftyp") +var mp4 = []byte("mp4") -type mp4Sig int +type mp4Sig struct{} func (mp4Sig) match(data []byte, firstNonWS int) string { - // c.f. section 6.1. - if len(data) < 8 { + // https://mimesniff.spec.whatwg.org/#signature-for-mp4 + // c.f. section 6.2.1 + if len(data) < 12 { return "" } boxSize := int(binary.BigEndian.Uint32(data[:4])) @@ -186,30 +187,20 @@ func (mp4Sig) match(data []byte, firstNonWS int) string { // minor version number continue } - seg := string(data[st : st+3]) - switch seg { - case "mp4", "iso", "M4V", "M4P", "M4B": + if bytes.Equal(data[st:st+3], mp4) { return "video/mp4" - /* The remainder are not in the spec. - case "M4A": - return "audio/mp4" - case "3gp": - return "video/3gpp" - case "jp2": - return "image/jp2" // JPEG 2000 - */ } } return "" } -type textSig int +type textSig struct{} func (textSig) match(data []byte, firstNonWS int) string { // c.f. section 5, step 4. for _, b := range data[firstNonWS:] { switch { - case 0x00 <= b && b <= 0x08, + case b <= 0x08, b == 0x0B, 0x0E <= b && b <= 0x1A, 0x1C <= b && b <= 0x1F: diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go index 24ca27afc16..e0085516da3 100644 --- a/libgo/go/net/http/sniff_test.go +++ b/libgo/go/net/http/sniff_test.go @@ -11,7 +11,6 @@ import ( "io/ioutil" "log" . "net/http" - "net/http/httptest" "reflect" "strconv" "strings" @@ -40,9 +39,7 @@ var sniffTests = []struct { {"GIF 87a", []byte(`GIF87a`), "image/gif"}, {"GIF 89a", []byte(`GIF89a...`), "image/gif"}, - // TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4. 
- //{"MP4 video", []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat"), "video/mp4"}, - //{"MP4 audio", []byte("\x00\x00\x00\x20ftypM4A \x00\x00\x00\x00M4A mp42isom\x00\x00\x00\x00"), "audio/mp4"}, + {"MP4 video", []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat"), "video/mp4"}, } func TestDetectContentType(t *testing.T) { @@ -54,9 +51,12 @@ func TestDetectContentType(t *testing.T) { } } -func TestServerContentType(t *testing.T) { +func TestServerContentType_h1(t *testing.T) { testServerContentType(t, h1Mode) } +func TestServerContentType_h2(t *testing.T) { testServerContentType(t, h2Mode) } + +func testServerContentType(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { i, _ := strconv.Atoi(r.FormValue("i")) tt := sniffTests[i] n, err := w.Write(tt.data) @@ -64,10 +64,10 @@ func TestServerContentType(t *testing.T) { log.Fatalf("%v: Write(%q) = %v, %v want %d, nil", tt.desc, tt.data, n, err, len(tt.data)) } })) - defer ts.Close() + defer cst.close() for i, tt := range sniffTests { - resp, err := Get(ts.URL + "/?i=" + strconv.Itoa(i)) + resp, err := cst.c.Get(cst.ts.URL + "/?i=" + strconv.Itoa(i)) if err != nil { t.Errorf("%v: %v", tt.desc, err) continue @@ -87,15 +87,17 @@ func TestServerContentType(t *testing.T) { // Issue 5953: shouldn't sniff if the handler set a Content-Type header, // even if it's the empty string. 
-func TestServerIssue5953(t *testing.T) { +func TestServerIssue5953_h1(t *testing.T) { testServerIssue5953(t, h1Mode) } +func TestServerIssue5953_h2(t *testing.T) { testServerIssue5953(t, h2Mode) } +func testServerIssue5953(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()["Content-Type"] = []string{""} fmt.Fprintf(w, "<html><head></head><body>hi</body></html>") })) - defer ts.Close() + defer cst.close() - resp, err := Get(ts.URL) + resp, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } @@ -108,7 +110,9 @@ func TestServerIssue5953(t *testing.T) { resp.Body.Close() } -func TestContentTypeWithCopy(t *testing.T) { +func TestContentTypeWithCopy_h1(t *testing.T) { testContentTypeWithCopy(t, h1Mode) } +func TestContentTypeWithCopy_h2(t *testing.T) { testContentTypeWithCopy(t, h2Mode) } +func testContentTypeWithCopy(t *testing.T, h2 bool) { defer afterTest(t) const ( @@ -116,7 +120,7 @@ func TestContentTypeWithCopy(t *testing.T) { expected = "text/html; charset=utf-8" ) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { // Use io.Copy from a bytes.Buffer to trigger ReadFrom. 
buf := bytes.NewBuffer([]byte(input)) n, err := io.Copy(w, buf) @@ -124,9 +128,9 @@ func TestContentTypeWithCopy(t *testing.T) { t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input)) } })) - defer ts.Close() + defer cst.close() - resp, err := Get(ts.URL) + resp, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatalf("Get: %v", err) } @@ -142,9 +146,11 @@ func TestContentTypeWithCopy(t *testing.T) { resp.Body.Close() } -func TestSniffWriteSize(t *testing.T) { +func TestSniffWriteSize_h1(t *testing.T) { testSniffWriteSize(t, h1Mode) } +func TestSniffWriteSize_h2(t *testing.T) { testSniffWriteSize(t, h2Mode) } +func testSniffWriteSize(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { size, _ := strconv.Atoi(r.FormValue("size")) written, err := io.WriteString(w, strings.Repeat("a", size)) if err != nil { @@ -155,9 +161,9 @@ func TestSniffWriteSize(t *testing.T) { t.Errorf("write of %d bytes wrote %d bytes", size, written) } })) - defer ts.Close() + defer cst.close() for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} { - res, err := Get(fmt.Sprintf("%s/?size=%d", ts.URL, size)) + res, err := cst.c.Get(fmt.Sprintf("%s/?size=%d", cst.ts.URL, size)) if err != nil { t.Fatalf("size %d: %v", size, err) } diff --git a/libgo/go/net/http/status.go b/libgo/go/net/http/status.go index d253bd5cb54..f3dacab6a92 100644 --- a/libgo/go/net/http/status.go +++ b/libgo/go/net/http/status.go @@ -44,20 +44,18 @@ const ( StatusRequestedRangeNotSatisfiable = 416 StatusExpectationFailed = 417 StatusTeapot = 418 + StatusPreconditionRequired = 428 + StatusTooManyRequests = 429 + StatusRequestHeaderFieldsTooLarge = 431 + StatusUnavailableForLegalReasons = 451 - StatusInternalServerError = 500 - StatusNotImplemented = 501 - StatusBadGateway = 502 - StatusServiceUnavailable = 503 - 
StatusGatewayTimeout = 504 - StatusHTTPVersionNotSupported = 505 - - // New HTTP status codes from RFC 6585. Not exported yet in Go 1.1. - // See discussion at https://codereview.appspot.com/7678043/ - statusPreconditionRequired = 428 - statusTooManyRequests = 429 - statusRequestHeaderFieldsTooLarge = 431 - statusNetworkAuthenticationRequired = 511 + StatusInternalServerError = 500 + StatusNotImplemented = 501 + StatusBadGateway = 502 + StatusServiceUnavailable = 503 + StatusGatewayTimeout = 504 + StatusHTTPVersionNotSupported = 505 + StatusNetworkAuthenticationRequired = 511 ) var statusText = map[int]string{ @@ -99,18 +97,18 @@ var statusText = map[int]string{ StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable", StatusExpectationFailed: "Expectation Failed", StatusTeapot: "I'm a teapot", + StatusPreconditionRequired: "Precondition Required", + StatusTooManyRequests: "Too Many Requests", + StatusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large", + StatusUnavailableForLegalReasons: "Unavailable For Legal Reasons", - StatusInternalServerError: "Internal Server Error", - StatusNotImplemented: "Not Implemented", - StatusBadGateway: "Bad Gateway", - StatusServiceUnavailable: "Service Unavailable", - StatusGatewayTimeout: "Gateway Timeout", - StatusHTTPVersionNotSupported: "HTTP Version Not Supported", - - statusPreconditionRequired: "Precondition Required", - statusTooManyRequests: "Too Many Requests", - statusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large", - statusNetworkAuthenticationRequired: "Network Authentication Required", + StatusInternalServerError: "Internal Server Error", + StatusNotImplemented: "Not Implemented", + StatusBadGateway: "Bad Gateway", + StatusServiceUnavailable: "Service Unavailable", + StatusGatewayTimeout: "Gateway Timeout", + StatusHTTPVersionNotSupported: "HTTP Version Not Supported", + StatusNetworkAuthenticationRequired: "Network Authentication Required", } // StatusText returns a text for 
the HTTP status code. It returns the empty diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index a8736b28e16..6e59af8f6f4 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -56,7 +56,7 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { if rr.ContentLength != 0 && rr.Body == nil { return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength) } - t.Method = rr.Method + t.Method = valueOrDefault(rr.Method, "GET") t.Body = rr.Body t.BodyCloser = rr.Body t.ContentLength = rr.ContentLength @@ -271,6 +271,10 @@ type transferReader struct { Trailer Header } +func (t *transferReader) protoAtLeast(m, n int) bool { + return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) +} + // bodyAllowedForStatus reports whether a given response status code // permits a body. See RFC2616, section 4.4. func bodyAllowedForStatus(status int) bool { @@ -337,7 +341,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { } // Transfer encoding, content length - t.TransferEncoding, err = fixTransferEncoding(isResponse, t.RequestMethod, t.Header) + err = t.fixTransferEncoding() if err != nil { return err } @@ -424,13 +428,18 @@ func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } // Checks whether the encoding is explicitly "identity". func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" } -// Sanitize transfer encoding -func fixTransferEncoding(isResponse bool, requestMethod string, header Header) ([]string, error) { - raw, present := header["Transfer-Encoding"] +// fixTransferEncoding sanitizes t.TransferEncoding, if needed. +func (t *transferReader) fixTransferEncoding() error { + raw, present := t.Header["Transfer-Encoding"] if !present { - return nil, nil + return nil + } + delete(t.Header, "Transfer-Encoding") + + // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. 
+ if !t.protoAtLeast(1, 1) { + return nil } - delete(header, "Transfer-Encoding") encodings := strings.Split(raw[0], ",") te := make([]string, 0, len(encodings)) @@ -445,13 +454,13 @@ func fixTransferEncoding(isResponse bool, requestMethod string, header Header) ( break } if encoding != "chunked" { - return nil, &badStringError{"unsupported transfer encoding", encoding} + return &badStringError{"unsupported transfer encoding", encoding} } te = te[0 : len(te)+1] te[len(te)-1] = encoding } if len(te) > 1 { - return nil, &badStringError{"too many transfer encodings", strings.Join(te, ",")} + return &badStringError{"too many transfer encodings", strings.Join(te, ",")} } if len(te) > 0 { // RFC 7230 3.3.2 says "A sender MUST NOT send a @@ -470,11 +479,12 @@ func fixTransferEncoding(isResponse bool, requestMethod string, header Header) ( // such a message downstream." // // Reportedly, these appear in the wild. - delete(header, "Content-Length") - return te, nil + delete(t.Header, "Content-Length") + t.TransferEncoding = te + return nil } - return nil, nil + return nil } // Determine the expected body length, using RFC 2616 Section 4.4. 
This @@ -567,21 +577,29 @@ func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { // Parse the trailer header func fixTrailer(header Header, te []string) (Header, error) { - raw := header.get("Trailer") - if raw == "" { + vv, ok := header["Trailer"] + if !ok { return nil, nil } - header.Del("Trailer") + trailer := make(Header) - keys := strings.Split(raw, ",") - for _, key := range keys { - key = CanonicalHeaderKey(strings.TrimSpace(key)) - switch key { - case "Transfer-Encoding", "Trailer", "Content-Length": - return nil, &badStringError{"bad trailer key", key} - } - trailer[key] = nil + var err error + for _, v := range vv { + foreachHeaderElement(v, func(key string) { + key = CanonicalHeaderKey(key) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + if err == nil { + err = &badStringError{"bad trailer key", key} + return + } + } + trailer[key] = nil + }) + } + if err != nil { + return nil, err } if len(trailer) == 0 { return nil, nil @@ -603,10 +621,11 @@ type body struct { closing bool // is the connection to be closed after reading body? doEarlyClose bool // whether Close should stop early - mu sync.Mutex // guards closed, and calls to Read and Close + mu sync.Mutex // guards following, and calls to Read and Close sawEOF bool closed bool - earlyClose bool // Close called and we didn't read to the end of src + earlyClose bool // Close called and we didn't read to the end of src + onHitEOF func() // if non-nil, func to call when EOF is Read } // ErrBodyReadAfterClose is returned when reading a Request or Response @@ -666,6 +685,10 @@ func (b *body) readLocked(p []byte) (n int, err error) { } } + if b.sawEOF && b.onHitEOF != nil { + b.onHitEOF() + } + return n, err } @@ -800,6 +823,20 @@ func (b *body) didEarlyClose() bool { return b.earlyClose } +// bodyRemains reports whether future Read calls might +// yield data. 
+func (b *body) bodyRemains() bool { + b.mu.Lock() + defer b.mu.Unlock() + return !b.sawEOF +} + +func (b *body) registerOnHitEOF(fn func()) { + b.mu.Lock() + defer b.mu.Unlock() + b.onHitEOF = fn +} + // bodyLocked is a io.Reader reading from a *body when its mutex is // already held. type bodyLocked struct { diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 70d18646059..41df906cf2d 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -36,7 +36,8 @@ var DefaultTransport RoundTripper = &Transport{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).Dial, - TLSHandshakeTimeout: 10 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, } // DefaultMaxIdleConnsPerHost is the default value of Transport's @@ -45,7 +46,21 @@ const DefaultMaxIdleConnsPerHost = 2 // Transport is an implementation of RoundTripper that supports HTTP, // HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT). -// Transport can also cache connections for future re-use. +// +// By default, Transport caches connections for future re-use. +// This may leave many open connections when accessing many hosts. +// This behavior can be managed using Transport's CloseIdleConnections method +// and the MaxIdleConnsPerHost and DisableKeepAlives fields. +// +// Transports should be reused instead of created as needed. +// Transports are safe for concurrent use by multiple goroutines. +// +// A Transport is a low-level primitive for making HTTP and HTTPS requests. +// For high-level functionality, such as cookies and redirects, see Client. +// +// Transport uses HTTP/1.1 for HTTP URLs and either HTTP/1.1 or HTTP/2 +// for HTTPS URLs, depending on whether the server supports HTTP/2. +// See the package docs for more about HTTP/2. 
type Transport struct { idleMu sync.Mutex wantIdle bool // user has requested to close all idle conns @@ -113,8 +128,49 @@ type Transport struct { // time does not include the time to read the response body. ResponseHeaderTimeout time.Duration + // ExpectContinueTimeout, if non-zero, specifies the amount of + // time to wait for a server's first response headers after fully + // writing the request headers if the request has an + // "Expect: 100-continue" header. Zero means no timeout. + // This time does not include the time to send the request header. + ExpectContinueTimeout time.Duration + + // TLSNextProto specifies how the Transport switches to an + // alternate protocol (such as HTTP/2) after a TLS NPN/ALPN + // protocol negotiation. If Transport dials an TLS connection + // with a non-empty protocol name and TLSNextProto contains a + // map entry for that key (such as "h2"), then the func is + // called with the request's authority (such as "example.com" + // or "example.com:1234") and the TLS connection. The function + // must return a RoundTripper that then handles the request. + // If TLSNextProto is nil, HTTP/2 support is enabled automatically. + TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper + + // nextProtoOnce guards initialization of TLSNextProto and + // h2transport (via onceSetNextProtoDefaults) + nextProtoOnce sync.Once + h2transport *http2Transport // non-nil if http2 wired up + // TODO: tunable on global max cached connections // TODO: tunable on timeout on cached connections + // TODO: tunable on max per-host TCP dials in flight (Issue 13957) +} + +// onceSetNextProtoDefaults initializes TLSNextProto. +// It must be called via t.nextProtoOnce.Do. 
+func (t *Transport) onceSetNextProtoDefaults() { + if strings.Contains(os.Getenv("GODEBUG"), "http2client=0") { + return + } + if t.TLSNextProto != nil { + return + } + t2, err := http2configureTransport(t) + if err != nil { + log.Printf("Error enabling Transport HTTP/2 support: %v", err) + } else { + t.h2transport = t2 + } } // ProxyFromEnvironment returns the URL of the proxy to use for a @@ -188,7 +244,8 @@ func (tr *transportRequest) extraHeaders() Header { // // For higher-level HTTP client support (such as handling of cookies // and redirects), see Get, Post, and the Client type. -func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { +func (t *Transport) RoundTrip(req *Request) (*Response, error) { + t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) if req.URL == nil { req.closeBody() return nil, errors.New("http: nil Request.URL") @@ -197,54 +254,114 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { req.closeBody() return nil, errors.New("http: nil Request.Header") } - if req.URL.Scheme != "http" && req.URL.Scheme != "https" { - t.altMu.RLock() - var rt RoundTripper - if t.altProto != nil { - rt = t.altProto[req.URL.Scheme] - } - t.altMu.RUnlock() - if rt == nil { - req.closeBody() - return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} + // TODO(bradfitz): switch to atomic.Value for this map instead of RWMutex + t.altMu.RLock() + altRT := t.altProto[req.URL.Scheme] + t.altMu.RUnlock() + if altRT != nil { + if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { + return resp, err } - return rt.RoundTrip(req) } - if req.URL.Host == "" { + if s := req.URL.Scheme; s != "http" && s != "https" { req.closeBody() - return nil, errors.New("http: no Host in request URL") + return nil, &badStringError{"unsupported protocol scheme", s} } - treq := &transportRequest{Request: req} - cm, err := t.connectMethodForRequest(treq) - if err != nil { + if req.Method != "" && !validMethod(req.Method) { + 
return nil, fmt.Errorf("net/http: invalid method %q", req.Method) + } + if req.URL.Host == "" { req.closeBody() - return nil, err + return nil, errors.New("http: no Host in request URL") } - // Get the cached or newly-created connection to either the - // host (for http or https), the http proxy, or the http proxy - // pre-CONNECTed to https server. In any case, we'll be ready - // to send it requests. - pconn, err := t.getConn(req, cm) - if err != nil { - t.setReqCanceler(req, nil) - req.closeBody() - return nil, err + for { + // treq gets modified by roundTrip, so we need to recreate for each retry. + treq := &transportRequest{Request: req} + cm, err := t.connectMethodForRequest(treq) + if err != nil { + req.closeBody() + return nil, err + } + + // Get the cached or newly-created connection to either the + // host (for http or https), the http proxy, or the http proxy + // pre-CONNECTed to https server. In any case, we'll be ready + // to send it requests. + pconn, err := t.getConn(req, cm) + if err != nil { + t.setReqCanceler(req, nil) + req.closeBody() + return nil, err + } + + var resp *Response + if pconn.alt != nil { + // HTTP/2 path. + t.setReqCanceler(req, nil) // not cancelable with CancelRequest + resp, err = pconn.alt.RoundTrip(req) + } else { + resp, err = pconn.roundTrip(treq) + } + if err == nil { + return resp, nil + } + if err := checkTransportResend(err, req, pconn); err != nil { + return nil, err + } + testHookRoundTripRetried() } +} - return pconn.roundTrip(treq) +// checkTransportResend checks whether a failed HTTP request can be +// resent on a new connection. The non-nil input error is the error from +// roundTrip, which might be wrapped in a beforeRespHeaderError error. +// +// The return value is err or the unwrapped error inside a +// beforeRespHeaderError. 
+func checkTransportResend(err error, req *Request, pconn *persistConn) error { + brhErr, ok := err.(beforeRespHeaderError) + if !ok { + return err + } + err = brhErr.error // unwrap the custom error in case we return it + if err != errMissingHost && pconn.isReused() && req.isReplayable() { + // If we try to reuse a connection that the server is in the process of + // closing, we may end up successfully writing out our request (or a + // portion of our request) only to find a connection error when we try to + // read from (or finish writing to) the socket. + + // There can be a race between the socket pool checking whether a socket + // is still connected, receiving the FIN, and sending/reading data on a + // reused socket. If we receive the FIN between the connectedness check + // and writing/reading from the socket, we may first learn the socket is + // disconnected when we get a ERR_SOCKET_NOT_CONNECTED. This will most + // likely happen when trying to retrieve its IP address. See + // http://crbug.com/105824 for more details. + + // We resend a request only if we reused a keep-alive connection and did + // not yet receive any header data. This automatically prevents an + // infinite resend loop because we'll run out of the cached keep-alive + // connections eventually. + return nil + } + return err } +// ErrSkipAltProtocol is a sentinel error value defined by Transport.RegisterProtocol. +var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol") + // RegisterProtocol registers a new protocol with scheme. // The Transport will pass requests using the given scheme to rt. // It is rt's responsibility to simulate HTTP request semantics. // // RegisterProtocol can be used by other packages to provide // implementations of protocol schemes like "ftp" or "file". +// +// If rt.RoundTrip returns ErrSkipAltProtocol, the Transport will +// handle the RoundTrip itself for that one request, as if the +// protocol were not registered. 
func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { - if scheme == "http" || scheme == "https" { - panic("protocol " + scheme + " already registered") - } t.altMu.Lock() defer t.altMu.Unlock() if t.altProto == nil { @@ -261,6 +378,7 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { // a "keep-alive" state. It does not interrupt any connections currently // in use. func (t *Transport) CloseIdleConnections() { + t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) t.idleMu.Lock() m := t.idleConn t.idleConn = nil @@ -269,13 +387,19 @@ func (t *Transport) CloseIdleConnections() { t.idleMu.Unlock() for _, conns := range m { for _, pconn := range conns { - pconn.close() + pconn.close(errCloseIdleConns) } } + if t2 := t.h2transport; t2 != nil { + t2.CloseIdleConnections() + } } // CancelRequest cancels an in-flight request by closing its connection. // CancelRequest should only be called after RoundTrip has returned. +// +// Deprecated: Use Request.Cancel instead. CancelRequest can not cancel +// HTTP/2 requests. func (t *Transport) CancelRequest(req *Request) { t.reqMu.Lock() cancel := t.reqCanceler[req] @@ -354,23 +478,41 @@ func (cm *connectMethod) proxyAuth() string { return "" } -// putIdleConn adds pconn to the list of idle persistent connections awaiting +// error values for debugging and testing, not seen by users. 
+var ( + errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled") + errConnBroken = errors.New("http: putIdleConn: connection is in bad state") + errWantIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") + errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") + errCloseIdleConns = errors.New("http: CloseIdleConnections called") + errReadLoopExiting = errors.New("http: persistConn.readLoop exiting") + errServerClosedIdle = errors.New("http: server closed idle conn") +) + +func (t *Transport) putOrCloseIdleConn(pconn *persistConn) { + if err := t.tryPutIdleConn(pconn); err != nil { + pconn.close(err) + } +} + +// tryPutIdleConn adds pconn to the list of idle persistent connections awaiting // a new request. -// If pconn is no longer needed or not in a good state, putIdleConn -// returns false. -func (t *Transport) putIdleConn(pconn *persistConn) bool { +// If pconn is no longer needed or not in a good state, tryPutIdleConn returns +// an error explaining why it wasn't registered. +// tryPutIdleConn does not close pconn. Use putOrCloseIdleConn instead for that. +func (t *Transport) tryPutIdleConn(pconn *persistConn) error { if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 { - pconn.close() - return false + return errKeepAlivesDisabled } if pconn.isBroken() { - return false + return errConnBroken } key := pconn.cacheKey max := t.MaxIdleConnsPerHost if max == 0 { max = DefaultMaxIdleConnsPerHost } + pconn.markReused() t.idleMu.Lock() waitingDialer := t.idleConnCh[key] @@ -382,7 +524,7 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { // first). Chrome calls this socket late binding. 
See // https://insouciant.org/tech/connection-management-in-chromium/ t.idleMu.Unlock() - return true + return nil default: if waitingDialer != nil { // They had populated this, but their dial won @@ -392,16 +534,14 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { } if t.wantIdle { t.idleMu.Unlock() - pconn.close() - return false + return errWantIdle } if t.idleConn == nil { t.idleConn = make(map[connectMethodKey][]*persistConn) } if len(t.idleConn[key]) >= max { t.idleMu.Unlock() - pconn.close() - return false + return errTooManyIdle } for _, exist := range t.idleConn[key] { if exist == pconn { @@ -410,7 +550,7 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { } t.idleConn[key] = append(t.idleConn[key], pconn) t.idleMu.Unlock() - return true + return nil } // getIdleConnCh returns a channel to receive and return idle @@ -494,16 +634,17 @@ func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool { return true } -func (t *Transport) dial(network, addr string) (c net.Conn, err error) { +func (t *Transport) dial(network, addr string) (net.Conn, error) { if t.Dial != nil { - return t.Dial(network, addr) + c, err := t.Dial(network, addr) + if c == nil && err == nil { + err = errors.New("net/http: Transport.Dial hook returned (nil, nil)") + } + return c, err } return net.Dial(network, addr) } -// Testing hooks: -var prePendingDial, postPendingDial func() - // getConn dials and creates a new persistConn to the target as // specified in the connectMethod. This includes doing a proxy CONNECT // and/or setting up TLS. If this doesn't return an error, the persistConn @@ -525,20 +666,16 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error // Copy these hooks so we don't race on the postPendingDial in // the goroutine we launch. Issue 11136. 
- prePendingDial := prePendingDial - postPendingDial := postPendingDial + testHookPrePendingDial := testHookPrePendingDial + testHookPostPendingDial := testHookPostPendingDial handlePendingDial := func() { - if prePendingDial != nil { - prePendingDial() - } + testHookPrePendingDial() go func() { if v := <-dialc; v.err == nil { - t.putIdleConn(v.pc) - } - if postPendingDial != nil { - postPendingDial() + t.putOrCloseIdleConn(v.pc) } + testHookPostPendingDial() }() } @@ -565,10 +702,10 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error return pc, nil case <-req.Cancel: handlePendingDial() - return nil, errors.New("net/http: request canceled while waiting for connection") + return nil, errRequestCanceledConn case <-cancelc: handlePendingDial() - return nil, errors.New("net/http: request canceled while waiting for connection") + return nil, errRequestCanceledConn } } @@ -588,7 +725,16 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { if err != nil { return nil, err } + if pconn.conn == nil { + return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)") + } if tc, ok := pconn.conn.(*tls.Conn); ok { + // Handshake here, in case DialTLS didn't. TLSNextProto below + // depends on it for knowing the connection state. 
+ if err := tc.Handshake(); err != nil { + go pconn.conn.Close() + return nil, err + } cs := tc.ConnectionState() pconn.tlsState = &cs } @@ -680,6 +826,12 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { pconn.conn = tlsConn } + if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { + if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { + return &persistConn{alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil + } + } + pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF}) pconn.bw = bufio.NewWriter(pconn.conn) go pconn.readLoop() @@ -809,6 +961,11 @@ func (k connectMethodKey) String() string { // persistConn wraps a connection, usually a persistent one // (but may be used for non-keep-alive requests as well) type persistConn struct { + // alt optionally specifies the TLS NextProto RoundTripper. + // This is used for HTTP/2 today and future protocols later. + // If it's non-nil, the rest of the fields are unused. + alt RoundTripper + t *Transport cacheKey connectMethodKey conn net.Conn @@ -828,9 +985,10 @@ type persistConn struct { lk sync.Mutex // guards following fields numExpectedResponses int - closed bool // whether conn has been closed - broken bool // an error has happened on this connection; marked broken so it's not reused. - canceled bool // whether this conn was broken due a CancelRequest + closed error // set non-nil when conn is closed, before closech is closed + broken bool // an error has happened on this connection; marked broken so it's not reused. + canceled bool // whether this conn was broken due a CancelRequest + reused bool // whether conn has had successful request/response and is being reused. // mutateHeaderFunc is an optional func to modify extra // headers on each outbound request before it's written. 
(the // original Request given to RoundTrip is not modified) @@ -852,15 +1010,34 @@ func (pc *persistConn) isCanceled() bool { return pc.canceled } +// isReused reports whether this connection is in a known broken state. +func (pc *persistConn) isReused() bool { + pc.lk.Lock() + r := pc.reused + pc.lk.Unlock() + return r +} + func (pc *persistConn) cancelRequest() { pc.lk.Lock() defer pc.lk.Unlock() pc.canceled = true - pc.closeLocked() + pc.closeLocked(errRequestCanceled) } func (pc *persistConn) readLoop() { - // eofc is used to block http.Handler goroutines reading from Response.Body + closeErr := errReadLoopExiting // default value, if not changed below + defer func() { pc.close(closeErr) }() + + tryPutIdleConn := func() bool { + if err := pc.t.tryPutIdleConn(pc); err != nil { + closeErr = err + return false + } + return true + } + + // eofc is used to block caller goroutines reading from Response.Body // at EOF until this goroutines has (potentially) added the connection // back to the idle pool. eofc := make(chan struct{}) @@ -873,17 +1050,14 @@ func (pc *persistConn) readLoop() { alive := true for alive { - pb, err := pc.br.Peek(1) + _, err := pc.br.Peek(1) + if err != nil { + err = beforeRespHeaderError{err} + } pc.lk.Lock() if pc.numExpectedResponses == 0 { - if !pc.closed { - pc.closeLocked() - if len(pb) > 0 { - log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", - string(pb), err) - } - } + pc.readLoopPeekFailLocked(err) pc.lk.Unlock() return } @@ -893,115 +1067,189 @@ func (pc *persistConn) readLoop() { var resp *Response if err == nil { - resp, err = ReadResponse(pc.br, rc.req) - if err == nil && resp.StatusCode == 100 { - // Skip any 100-continue for now. - // TODO(bradfitz): if rc.req had "Expect: 100-continue", - // actually block the request body write and signal the - // writeLoop now to begin sending it. (Issue 2184) For now we - // eat it, since we're never expecting one. 
- resp, err = ReadResponse(pc.br, rc.req) - } - } - - if resp != nil { - resp.TLS = pc.tlsState + resp, err = pc.readResponse(rc) } - hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0 - if err != nil { - pc.close() - } else { - if rc.addedGzip && hasBody && resp.Header.Get("Content-Encoding") == "gzip" { - resp.Header.Del("Content-Encoding") - resp.Header.Del("Content-Length") - resp.ContentLength = -1 - resp.Body = &gzipReader{body: resp.Body} + // If we won't be able to retry this request later (from the + // roundTrip goroutine), mark it as done now. + // BEFORE the send on rc.ch, as the client might re-use the + // same *Request pointer, and we don't want to set call + // t.setReqCanceler from this persistConn while the Transport + // potentially spins up a different persistConn for the + // caller's subsequent request. + if checkTransportResend(err, rc.req, pc) != nil { + pc.t.setReqCanceler(rc.req, nil) + } + select { + case rc.ch <- responseAndError{err: err}: + case <-rc.callerGone: + return } - resp.Body = &bodyEOFSignal{body: resp.Body} + return } - if err != nil || resp.Close || rc.req.Close || resp.StatusCode <= 199 { + pc.lk.Lock() + pc.numExpectedResponses-- + pc.lk.Unlock() + + hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 + + if resp.Close || rc.req.Close || resp.StatusCode <= 199 { // Don't do keep-alive on error if either party requested a close // or we get an unexpected informational (1xx) response. // StatusCode 100 is already handled above. 
alive = false } - var waitForBodyRead chan bool // channel is nil when there's no body - if hasBody { - waitForBodyRead = make(chan bool, 2) - resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error { - waitForBodyRead <- false - return nil - } - resp.Body.(*bodyEOFSignal).fn = func(err error) error { - isEOF := err == io.EOF - waitForBodyRead <- isEOF - if isEOF { - <-eofc // see comment at top - } else if err != nil && pc.isCanceled() { - return errRequestCanceled - } - return err - } - } else { - // Before send on rc.ch, as client might re-use the - // same *Request pointer, and we don't want to set this - // on t from this persistConn while the Transport - // potentially spins up a different persistConn for the - // caller's subsequent request. + if !hasBody { pc.t.setReqCanceler(rc.req, nil) - } - pc.lk.Lock() - pc.numExpectedResponses-- - pc.lk.Unlock() + // Put the idle conn back into the pool before we send the response + // so if they process it quickly and make another request, they'll + // get this same conn. But we use the unbuffered channel 'rc' + // to guarantee that persistConn.roundTrip got out of its select + // potentially waiting for this persistConn to close. + // but after + alive = alive && + !pc.sawEOF && + pc.wroteRequest() && + tryPutIdleConn() - // The connection might be going away when we put the - // idleConn below. When that happens, we close the response channel to signal - // to roundTrip that the connection is gone. roundTrip waits for - // both closing and a response in a select, so it might choose - // the close channel, rather than the response. - // We send the response first so that roundTrip can check - // if there is a pending one with a non-blocking select - // on the response channel before erroring out. - rc.ch <- responseAndError{resp, err} - - if hasBody { - // To avoid a race, wait for the just-returned - // response body to be fully consumed before peek on - // the underlying bufio reader. 
select { - case <-rc.req.Cancel: - alive = false - pc.t.CancelRequest(rc.req) - case bodyEOF := <-waitForBodyRead: - pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool - alive = alive && - bodyEOF && - !pc.sawEOF && - pc.wroteRequest() && - pc.t.putIdleConn(pc) - if bodyEOF { - eofc <- struct{}{} - } - case <-pc.closech: - alive = false + case rc.ch <- responseAndError{res: resp}: + case <-rc.callerGone: + return } - } else { + + // Now that they've read from the unbuffered channel, they're safely + // out of the select that also waits on this goroutine to die, so + // we're allowed to exit now if needed (if alive is false) + testHookReadLoopBeforeNextRead() + continue + } + + if rc.addedGzip { + maybeUngzipResponse(resp) + } + resp.Body = &bodyEOFSignal{body: resp.Body} + + waitForBodyRead := make(chan bool, 2) + resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error { + waitForBodyRead <- false + return nil + } + resp.Body.(*bodyEOFSignal).fn = func(err error) error { + isEOF := err == io.EOF + waitForBodyRead <- isEOF + if isEOF { + <-eofc // see comment above eofc declaration + } else if err != nil && pc.isCanceled() { + return errRequestCanceled + } + return err + } + + select { + case rc.ch <- responseAndError{res: resp}: + case <-rc.callerGone: + return + } + + // Before looping back to the top of this function and peeking on + // the bufio.Reader, wait for the caller goroutine to finish + // reading the response body. 
(or for cancelation or death) + select { + case bodyEOF := <-waitForBodyRead: + pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool alive = alive && + bodyEOF && !pc.sawEOF && pc.wroteRequest() && - pc.t.putIdleConn(pc) + tryPutIdleConn() + if bodyEOF { + eofc <- struct{}{} + } + case <-rc.req.Cancel: + alive = false + pc.t.CancelRequest(rc.req) + case <-pc.closech: + alive = false + } + + testHookReadLoopBeforeNextRead() + } +} + +func maybeUngzipResponse(resp *Response) { + if resp.Header.Get("Content-Encoding") == "gzip" { + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + resp.Body = &gzipReader{body: resp.Body} + } +} + +func (pc *persistConn) readLoopPeekFailLocked(peekErr error) { + if pc.closed != nil { + return + } + if n := pc.br.Buffered(); n > 0 { + buf, _ := pc.br.Peek(n) + log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", buf, peekErr) + } + if peekErr == io.EOF { + // common case. + pc.closeLocked(errServerClosedIdle) + } else { + pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %v", peekErr)) + } +} + +// readResponse reads an HTTP response (or two, in the case of "Expect: +// 100-continue") from the server. It returns the final non-100 one. +func (pc *persistConn) readResponse(rc requestAndChan) (resp *Response, err error) { + resp, err = ReadResponse(pc.br, rc.req) + if err != nil { + return + } + if rc.continueCh != nil { + if resp.StatusCode == 100 { + rc.continueCh <- struct{}{} + } else { + close(rc.continueCh) + } + } + if resp.StatusCode == 100 { + resp, err = ReadResponse(pc.br, rc.req) + if err != nil { + return } + } + resp.TLS = pc.tlsState + return +} + +// waitForContinue returns the function to block until +// any response, timeout or connection close. After any of them, +// the function returns a bool which indicates if the body should be sent. 
+func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool { + if continueCh == nil { + return nil + } + return func() bool { + timer := time.NewTimer(pc.t.ExpectContinueTimeout) + defer timer.Stop() - if hook := testHookReadLoopBeforeNextRead; hook != nil { - hook() + select { + case _, ok := <-continueCh: + return ok + case <-timer.C: + return true + case <-pc.closech: + return false } } - pc.close() } func (pc *persistConn) writeLoop() { @@ -1012,7 +1260,7 @@ func (pc *persistConn) writeLoop() { wr.ch <- errors.New("http: can't write HTTP request on broken connection") continue } - err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra) + err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh)) if err == nil { err = pc.bw.Flush() } @@ -1056,19 +1304,29 @@ func (pc *persistConn) wroteRequest() bool { } } +// responseAndError is how the goroutine reading from an HTTP/1 server +// communicates with the goroutine doing the RoundTrip. type responseAndError struct { - res *Response + res *Response // else use this response (see res method) err error } type requestAndChan struct { req *Request - ch chan responseAndError + ch chan responseAndError // unbuffered; always send in select on callerGone // did the Transport (as opposed to the client code) add an // Accept-Encoding gzip header? only if it we set it do // we transparently decode the gzip. addedGzip bool + + // Optional blocking chan for Expect: 100-continue (for send). + // If the request has an "Expect: 100-continue" header and + // the server responds 100 Continue, readLoop send a value + // to writeLoop via this chan. 
+ continueCh chan<- struct{} + + callerGone <-chan struct{} // closed when roundTrip caller has returned } // A writeRequest is sent by the readLoop's goroutine to the @@ -1078,6 +1336,11 @@ type requestAndChan struct { type writeRequest struct { req *transportRequest ch chan<- error + + // Optional blocking chan for Expect: 100-continue (for recieve). + // If not nil, writeLoop blocks sending request body until + // it receives from this chan. + continueCh <-chan struct{} } type httpError struct { @@ -1090,23 +1353,34 @@ func (e *httpError) Timeout() bool { return e.timeout } func (e *httpError) Temporary() bool { return true } var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} -var errClosed error = &httpError{err: "net/http: transport closed before response was received"} +var errClosed error = &httpError{err: "net/http: server closed connection before response was received"} var errRequestCanceled = errors.New("net/http: request canceled") +var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify? + +func nop() {} -// nil except for tests +// testHooks. Always non-nil. var ( - testHookPersistConnClosedGotRes func() - testHookEnterRoundTrip func() - testHookMu sync.Locker = fakeLocker{} // guards following - testHookReadLoopBeforeNextRead func() + testHookEnterRoundTrip = nop + testHookWaitResLoop = nop + testHookRoundTripRetried = nop + testHookPrePendingDial = nop + testHookPostPendingDial = nop + + testHookMu sync.Locker = fakeLocker{} // guards following + testHookReadLoopBeforeNextRead = nop ) +// beforeRespHeaderError is used to indicate when an IO error has occurred before +// any header data was received. 
+type beforeRespHeaderError struct { + error +} + func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { - if hook := testHookEnterRoundTrip; hook != nil { - hook() - } + testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) { - pc.t.putIdleConn(pc) + pc.t.putOrCloseIdleConn(pc) return nil, errRequestCanceled } pc.lk.Lock() @@ -1143,42 +1417,47 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err req.extraHeaders().Set("Accept-Encoding", "gzip") } + var continueCh chan struct{} + if req.ProtoAtLeast(1, 1) && req.Body != nil && req.expectsContinue() { + continueCh = make(chan struct{}, 1) + } + if pc.t.DisableKeepAlives { req.extraHeaders().Set("Connection", "close") } + gone := make(chan struct{}) + defer close(gone) + // Write the request concurrently with waiting for a response, // in case the server decides to reply before reading our full // request body. writeErrCh := make(chan error, 1) - pc.writech <- writeRequest{req, writeErrCh} + pc.writech <- writeRequest{req, writeErrCh, continueCh} - resc := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{req.Request, resc, requestedGzip} + resc := make(chan responseAndError) + pc.reqch <- requestAndChan{ + req: req.Request, + ch: resc, + addedGzip: requestedGzip, + continueCh: continueCh, + callerGone: gone, + } var re responseAndError var respHeaderTimer <-chan time.Time cancelChan := req.Request.Cancel WaitResponse: for { + testHookWaitResLoop() select { case err := <-writeErrCh: - if isNetWriteError(err) { - // Issue 11745. If we failed to write the request - // body, it's possible the server just heard enough - // and already wrote to us. Prioritize the server's - // response over returning a body write error. - select { - case re = <-resc: - pc.close() - break WaitResponse - case <-time.After(50 * time.Millisecond): - // Fall through. 
- } - } if err != nil { - re = responseAndError{nil, err} - pc.close() + if pc.isCanceled() { + err = errRequestCanceled + } + re = responseAndError{err: beforeRespHeaderError{err}} + pc.close(fmt.Errorf("write error: %v", err)) break WaitResponse } if d := pc.t.ResponseHeaderTimeout; d > 0 { @@ -1187,33 +1466,22 @@ WaitResponse: respHeaderTimer = timer.C } case <-pc.closech: - // The persist connection is dead. This shouldn't - // usually happen (only with Connection: close responses - // with no response bodies), but if it does happen it - // means either a) the remote server hung up on us - // prematurely, or b) the readLoop sent us a response & - // closed its closech at roughly the same time, and we - // selected this case first. If we got a response, readLoop makes sure - // to send it before it puts the conn and closes the channel. - // That way, we can fetch the response, if there is one, - // with a non-blocking receive. - select { - case re = <-resc: - if fn := testHookPersistConnClosedGotRes; fn != nil { - fn() - } - default: - re = responseAndError{err: errClosed} - if pc.isCanceled() { - re = responseAndError{err: errRequestCanceled} - } + var err error + if pc.isCanceled() { + err = errRequestCanceled + } else { + err = beforeRespHeaderError{fmt.Errorf("net/http: HTTP/1 transport connection broken: %v", pc.closed)} } + re = responseAndError{err: err} break WaitResponse case <-respHeaderTimer: - pc.close() + pc.close(errTimeout) re = responseAndError{err: errTimeout} break WaitResponse case re = <-resc: + if re.err != nil && pc.isCanceled() { + re.err = errRequestCanceled + } break WaitResponse case <-cancelChan: pc.t.CancelRequest(req.Request) @@ -1224,6 +1492,9 @@ WaitResponse: if re.err != nil { pc.t.setReqCanceler(req.Request, nil) } + if (re.res == nil) == (re.err == nil) { + panic("internal error: exactly one of res or err should be set") + } return re.res, re.err } @@ -1236,18 +1507,44 @@ func (pc *persistConn) markBroken() { pc.broken = true } 
-func (pc *persistConn) close() { +// markReused marks this connection as having been successfully used for a +// request and response. +func (pc *persistConn) markReused() { + pc.lk.Lock() + pc.reused = true + pc.lk.Unlock() +} + +// close closes the underlying TCP connection and closes +// the pc.closech channel. +// +// The provided err is only for testing and debugging; in normal +// circumstances it should never be seen by users. +func (pc *persistConn) close(err error) { pc.lk.Lock() defer pc.lk.Unlock() - pc.closeLocked() + pc.closeLocked(err) } -func (pc *persistConn) closeLocked() { +func (pc *persistConn) closeLocked(err error) { + if err == nil { + panic("nil error") + } pc.broken = true - if !pc.closed { - pc.conn.Close() - pc.closed = true - close(pc.closech) + if pc.closed == nil { + pc.closed = err + if pc.alt != nil { + // Do nothing; can only get here via getConn's + // handlePendingDial's putOrCloseIdleConn when + // it turns out the abandoned connection in + // flight ended up negotiating an alternate + // protocol. We don't use the connection + // freelist for http2. That's done by the + // alternate protocol's RoundTripper. + } else { + pc.conn.Close() + close(pc.closech) + } } pc.mutateHeaderFunc = nil } diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index c21d4afa87f..3b2a5f978e2 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -2,7 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Tests for transport.go +// Tests for transport.go. +// +// More tests are in clientserver_test.go (for things testing both client & server for both +// HTTP/1 and HTTP/2). This package http_test @@ -20,6 +23,8 @@ import ( "net" . 
"net/http" "net/http/httptest" + "net/http/httputil" + "net/http/internal" "net/url" "os" "reflect" @@ -256,6 +261,7 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { // if the Transport's DisableKeepAlives is set, all requests should // send Connection: close. +// HTTP/1-only (Connection: close doesn't exist in h2) func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(hostPortHandler) @@ -431,6 +437,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } func TestTransportServerClosingUnexpectedly(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -597,6 +604,7 @@ func TestTransportHeadChunkedResponse(t *testing.T) { tr := &Transport{DisableKeepAlives: false} c := &Client{Transport: tr} + defer tr.CloseIdleConnections() // Ensure that we wait for the readLoop to complete before // calling Head again @@ -790,6 +798,94 @@ func TestTransportGzip(t *testing.T) { } } +// If a request has Expect:100-continue header, the request blocks sending body until the first response. +// Premature consumption of the request body should not be occurred. +func TestTransportExpect100Continue(t *testing.T) { + defer afterTest(t) + + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + switch req.URL.Path { + case "/100": + // This endpoint implicitly responds 100 Continue and reads body. + if _, err := io.Copy(ioutil.Discard, req.Body); err != nil { + t.Error("Failed to read Body", err) + } + rw.WriteHeader(StatusOK) + case "/200": + // Go 1.5 adds Connection: close header if the client expect + // continue but not entire request body is consumed. + rw.WriteHeader(StatusOK) + case "/500": + rw.WriteHeader(StatusInternalServerError) + case "/keepalive": + // This hijacked endpoint responds error without Connection:close. 
+ _, bufrw, err := rw.(Hijacker).Hijack() + if err != nil { + log.Fatal(err) + } + bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n") + bufrw.WriteString("Content-Length: 0\r\n\r\n") + bufrw.Flush() + case "/timeout": + // This endpoint tries to read body without 100 (Continue) response. + // After ExpectContinueTimeout, the reading will be started. + conn, bufrw, err := rw.(Hijacker).Hijack() + if err != nil { + log.Fatal(err) + } + if _, err := io.CopyN(ioutil.Discard, bufrw, req.ContentLength); err != nil { + t.Error("Failed to read Body", err) + } + bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") + bufrw.Flush() + conn.Close() + } + + })) + defer ts.Close() + + tests := []struct { + path string + body []byte + sent int + status int + }{ + {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent. + {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent. + {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent. + {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Althogh without Connection:close, body isn't sent. + {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. + } + + for i, v := range tests { + tr := &Transport{ExpectContinueTimeout: 2 * time.Second} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + body := bytes.NewReader(v.body) + req, err := NewRequest("PUT", ts.URL+v.path, body) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Expect", "100-continue") + req.ContentLength = int64(len(v.body)) + + resp, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + + sent := len(v.body) - body.Len() + if v.status != resp.StatusCode { + t.Errorf("test %d: status code should be %d but got %d. 
(%s)", i, v.status, resp.StatusCode, v.path) + } + if v.sent != sent { + t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path) + } + } +} + func TestTransportProxy(t *testing.T) { defer afterTest(t) ch := make(chan string, 1) @@ -874,9 +970,7 @@ func TestTransportGzipShort(t *testing.T) { // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7237") - } + setParallel(t) defer afterTest(t) gotReqCh := make(chan bool) unblockCh := make(chan bool) @@ -943,9 +1037,7 @@ func TestTransportPersistConnLeak(t *testing.T) { // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7237") - } + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) @@ -1286,6 +1378,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { } func TestTransportResponseHeaderTimeout(t *testing.T) { + setParallel(t) defer afterTest(t) if testing.Short() { t.Skip("skipping timeout test in -short mode") @@ -1357,6 +1450,7 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { } func TestTransportCancelRequest(t *testing.T) { + setParallel(t) defer afterTest(t) if testing.Short() { t.Skip("skipping test in -short mode") @@ -1466,6 +1560,7 @@ Get = Get http://something.no-network.tld/: net/http: request canceled while wai } func TestCancelRequestWithChannel(t *testing.T) { + setParallel(t) defer afterTest(t) if testing.Short() { t.Skip("skipping test in -short mode") @@ -1523,6 +1618,7 @@ func TestCancelRequestWithChannel(t *testing.T) { } func TestCancelRequestWithChannelBeforeDo(t *testing.T) { + setParallel(t) defer afterTest(t) unblockc := make(chan bool) 
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1554,7 +1650,6 @@ func TestCancelRequestWithChannelBeforeDo(t *testing.T) { // Issue 11020. The returned error message should be errRequestCanceled func TestTransportCancelBeforeResponseHeaders(t *testing.T) { - t.Skip("Skipping flaky test; see Issue 11894") defer afterTest(t) serverConnCh := make(chan net.Conn, 1) @@ -1704,6 +1799,19 @@ func TestTransportNoHost(t *testing.T) { } } +// Issue 13311 +func TestTransportEmptyMethod(t *testing.T) { + req, _ := NewRequest("GET", "http://foo.com/", nil) + req.Method = "" // docs say "For client requests an empty string means GET" + got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(got), "GET ") { + t.Fatalf("expected substring 'GET '; got: %s", got) + } +} + func TestTransportSocketLateBinding(t *testing.T) { defer afterTest(t) @@ -2291,15 +2399,103 @@ type errorReader struct { func (e errorReader) Read(p []byte) (int, error) { return 0, e.err } +type plan9SleepReader struct{} + +func (plan9SleepReader) Read(p []byte) (int, error) { + if runtime.GOOS == "plan9" { + // After the fix to unblock TCP Reads in + // https://golang.org/cl/15941, this sleep is required + // on plan9 to make sure TCP Writes before an + // immediate TCP close go out on the wire. On Plan 9, + // it seems that a hangup of a TCP connection with + // queued data doesn't send the queued data first. + // https://golang.org/issue/9554 + time.Sleep(50 * time.Millisecond) + } + return 0, io.EOF +} + type closerFunc func() error func (f closerFunc) Close() error { return f() } +// Issue 4677. If we try to reuse a connection that the server is in the +// process of closing, we may end up successfully writing out our request (or a +// portion of our request) only to find a connection error when we try to read +// from (or finish writing to) the socket. 
+// +// NOTE: we resend a request only if the request is idempotent, we reused a +// keep-alive connection, and we haven't yet received any header data. This +// automatically prevents an infinite resend loop because we'll run out of the +// cached keep-alive connections eventually. +func TestRetryIdempotentRequestsOnError(t *testing.T) { + defer afterTest(t) + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + })) + defer ts.Close() + + tr := &Transport{} + c := &Client{Transport: tr} + + const N = 2 + retryc := make(chan struct{}, N) + SetRoundTripRetried(func() { + retryc <- struct{}{} + }) + defer SetRoundTripRetried(nil) + + for n := 0; n < 100; n++ { + // open 2 conns + errc := make(chan error, N) + for i := 0; i < N; i++ { + // start goroutines, send on errc + go func() { + res, err := c.Get(ts.URL) + if err == nil { + res.Body.Close() + } + errc <- err + }() + } + for i := 0; i < N; i++ { + if err := <-errc; err != nil { + t.Fatal(err) + } + } + + ts.CloseClientConnections() + for i := 0; i < N; i++ { + go func() { + res, err := c.Get(ts.URL) + if err == nil { + res.Body.Close() + } + errc <- err + }() + } + + for i := 0; i < N; i++ { + if err := <-errc; err != nil { + t.Fatal(err) + } + } + for i := 0; i < N; i++ { + select { + case <-retryc: + // we triggered a retry, test was successful + t.Logf("finished after %d runs\n", n) + return + default: + } + } + } + t.Fatal("did not trigger any retries") +} + // Issue 6981 func TestTransportClosesBodyOnError(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7782") - } + setParallel(t) defer afterTest(t) readBody := make(chan error, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2313,7 +2509,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { io.Reader io.Closer }{ - io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), errorReader{fakeErr}), + io.MultiReader(io.LimitReader(neverEnding('x'), 
1<<20), plan9SleepReader{}, errorReader{fakeErr}), closerFunc(func() error { select { case didClose <- true: @@ -2474,52 +2670,6 @@ func TestTransportRangeAndGzip(t *testing.T) { res.Body.Close() } -// Previously, we used to handle a logical race within RoundTrip by waiting for 100ms -// in the case of an error. Changing the order of the channel operations got rid of this -// race. -// -// In order to test that the channel op reordering works, we install a hook into the -// roundTrip function which gets called if we saw the connection go away and -// we subsequently received a response. -func TestTransportResponseCloseRace(t *testing.T) { - if testing.Short() { - t.Skip("skipping in short mode") - } - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - })) - defer ts.Close() - sawRace := false - SetInstallConnClosedHook(func() { - sawRace = true - }) - defer SetInstallConnClosedHook(nil) - tr := &Transport{ - DisableKeepAlives: true, - } - req, err := NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal(err) - } - // selects are not deterministic, so do this a bunch - // and see if we handle the logical race at least once. 
- for i := 0; i < 10000; i++ { - resp, err := tr.RoundTrip(req) - if err != nil { - t.Fatalf("unexpected error: %s", err) - continue - } - resp.Body.Close() - if sawRace { - break - } - } - if !sawRace { - t.Errorf("didn't see response/connection going away race") - } -} - // Test for issue 10474 func TestTransportResponseCancelRace(t *testing.T) { defer afterTest(t) @@ -2645,7 +2795,7 @@ func TestTransportFlushesBodyChunks(t *testing.T) { req.Header.Set("User-Agent", "x") // known value for test res, err := tr.RoundTrip(req) if err != nil { - t.Error("RoundTrip: %v", err) + t.Errorf("RoundTrip: %v", err) close(resc) return } @@ -2735,6 +2885,153 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) { } } +func TestTransportAutomaticHTTP2(t *testing.T) { + tr := &Transport{} + _, err := tr.RoundTrip(new(Request)) + if err == nil { + t.Error("expected error from RoundTrip") + } + if tr.TLSNextProto["h2"] == nil { + t.Errorf("HTTP/2 not registered.") + } + + // Now with TLSNextProto set: + tr = &Transport{TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper)} + _, err = tr.RoundTrip(new(Request)) + if err == nil { + t.Error("expected error from RoundTrip") + } + if tr.TLSNextProto["h2"] != nil { + t.Errorf("HTTP/2 registered, despite non-nil TLSNextProto field") + } +} + +// Issue 13633: there was a race where we returned bodyless responses +// to callers before recycling the persistent connection, which meant +// a client doing two subsequent requests could end up on different +// connections. It's somewhat harmless but enough tests assume it's +// not true in order to test other things that it's worth fixing. +// Plus it's nice to be consistent and not have timing-dependent +// behavior. +func TestTransportReuseConnEmptyResponseBody(t *testing.T) { + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("X-Addr", r.RemoteAddr) + // Empty response body. 
+ })) + defer cst.close() + n := 100 + if testing.Short() { + n = 10 + } + var firstAddr string + for i := 0; i < n; i++ { + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + log.Fatal(err) + } + addr := res.Header.Get("X-Addr") + if i == 0 { + firstAddr = addr + } else if addr != firstAddr { + t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr) + } + res.Body.Close() + } +} + +// Issue 13839 +func TestNoCrashReturningTransportAltConn(t *testing.T) { + cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) + if err != nil { + t.Fatal(err) + } + ln := newLocalListener(t) + defer ln.Close() + + handledPendingDial := make(chan bool, 1) + SetPendingDialHooks(nil, func() { handledPendingDial <- true }) + defer SetPendingDialHooks(nil, nil) + + testDone := make(chan struct{}) + defer close(testDone) + go func() { + tln := tls.NewListener(ln, &tls.Config{ + NextProtos: []string{"foo"}, + Certificates: []tls.Certificate{cert}, + }) + sc, err := tln.Accept() + if err != nil { + t.Error(err) + return + } + if err := sc.(*tls.Conn).Handshake(); err != nil { + t.Error(err) + return + } + <-testDone + sc.Close() + }() + + addr := ln.Addr().String() + + req, _ := NewRequest("GET", "https://fake.tld/", nil) + cancel := make(chan struct{}) + req.Cancel = cancel + + doReturned := make(chan bool, 1) + madeRoundTripper := make(chan bool, 1) + + tr := &Transport{ + DisableKeepAlives: true, + TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ + "foo": func(authority string, c *tls.Conn) RoundTripper { + madeRoundTripper <- true + return funcRoundTripper(func() { + t.Error("foo RoundTripper should not be called") + }) + }, + }, + Dial: func(_, _ string) (net.Conn, error) { + panic("shouldn't be called") + }, + DialTLS: func(_, _ string) (net.Conn, error) { + tc, err := tls.Dial("tcp", addr, &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"foo"}, + }) + if err != nil { + return nil, err + } + if err := 
tc.Handshake(); err != nil { + return nil, err + } + close(cancel) + <-doReturned + return tc, nil + }, + } + c := &Client{Transport: tr} + + _, err = c.Do(req) + if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn { + t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err) + } + + doReturned <- true + <-madeRoundTripper + <-handledPendingDial +} + +var errFakeRoundTrip = errors.New("fake roundtrip") + +type funcRoundTripper func() + +func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) { + fn() + return nil, errFakeRoundTrip +} + func wantBody(res *Response, err error, want string) error { if err != nil { return err diff --git a/libgo/go/net/http/triv.go b/libgo/go/net/http/triv.go index 232d6508906..cfbc5778c1c 100644 --- a/libgo/go/net/http/triv.go +++ b/libgo/go/net/http/triv.go @@ -134,8 +134,5 @@ func main() { http.HandleFunc("/args", ArgServer) http.HandleFunc("/go/hello", HelloServer) http.HandleFunc("/date", DateServer) - err := http.ListenAndServe(":12345", nil) - if err != nil { - log.Panicln("ListenAndServe:", err) - } + log.Fatal(http.ListenAndServe(":12345", nil)) } |