diff options
Diffstat (limited to 'libgo/go/net/http')
37 files changed, 4460 insertions, 405 deletions
diff --git a/libgo/go/net/http/cgi/host_test.go b/libgo/go/net/http/cgi/host_test.go index cb6f1df1f45..b514e10e963 100644 --- a/libgo/go/net/http/cgi/host_test.go +++ b/libgo/go/net/http/cgi/host_test.go @@ -19,7 +19,6 @@ import ( "runtime" "strconv" "strings" - "syscall" "testing" "time" ) @@ -340,11 +339,7 @@ func TestCopyError(t *testing.T) { } childRunning := func() bool { - p, err := os.FindProcess(pid) - if err != nil { - return false - } - return p.Signal(syscall.Signal(0)) == nil + return isProcessRunning(t, pid) } if !childRunning() { diff --git a/libgo/go/net/http/cgi/posix_test.go b/libgo/go/net/http/cgi/posix_test.go new file mode 100644 index 00000000000..5ff9e7d5eb7 --- /dev/null +++ b/libgo/go/net/http/cgi/posix_test.go @@ -0,0 +1,21 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9 + +package cgi + +import ( + "os" + "syscall" + "testing" +) + +func isProcessRunning(t *testing.T, pid int) bool { + p, err := os.FindProcess(pid) + if err != nil { + return false + } + return p.Signal(syscall.Signal(0)) == nil +} diff --git a/libgo/go/net/http/cgi/testdata/test.cgi b/libgo/go/net/http/cgi/testdata/test.cgi index 1b25bc29996..3214df6f004 100644 --- a/libgo/go/net/http/cgi/testdata/test.cgi +++ b/libgo/go/net/http/cgi/testdata/test.cgi @@ -24,7 +24,8 @@ print "X-Test-Header: X-Test-Value\r\n"; print "\r\n"; if ($params->{"bigresponse"}) { - for (1..1024) { + # 17 MB, for OS X: golang.org/issue/4958 + for (1..(17 * 1024)) { print "A" x 1024, "\r\n"; } exit 0; diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 5ee0804c7da..a34d47be1fa 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -19,12 +19,16 @@ import ( "strings" ) -// A Client is an HTTP client. Its zero value (DefaultClient) is a usable client -// that uses DefaultTransport. +// A Client is an HTTP client. Its zero value (DefaultClient) is a +// usable client that uses DefaultTransport. // -// The Client's Transport typically has internal state (cached -// TCP connections), so Clients should be reused instead of created as +// The Client's Transport typically has internal state (cached TCP +// connections), so Clients should be reused instead of created as // needed. Clients are safe for concurrent use by multiple goroutines. +// +// A Client is higher-level than a RoundTripper (such as Transport) +// and additionally handles HTTP details such as cookies and +// redirects. type Client struct { // Transport specifies the mechanism by which individual // HTTP requests are made. diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 9514a4b9610..73f1fe3c10a 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -51,10 +51,10 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) { return b, err } } - panic("unreachable") } func TestClient(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() @@ -72,6 +72,7 @@ func TestClient(t *testing.T) { } func TestClientHead(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() @@ -94,6 +95,7 @@ func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) } func TestGetRequestFormat(t *testing.T) { + defer afterTest(t) tr := &recordingTransport{} client := &Client{Transport: tr} url := "http://dummy.faketld/" @@ -110,6 +112,7 @@ func TestGetRequestFormat(t *testing.T) { } func TestPostRequestFormat(t *testing.T) { + defer afterTest(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -136,6 +139,7 @@ func TestPostRequestFormat(t *testing.T) { } func TestPostFormRequestFormat(t *testing.T) { + defer afterTest(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -176,7 +180,8 @@ func TestPostFormRequestFormat(t *testing.T) { } } -func TestRedirects(t *testing.T) { +func TestClientRedirects(t *testing.T) { + defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { n, _ := strconv.Atoi(r.FormValue("n")) @@ -223,6 +228,7 @@ func TestRedirects(t *testing.T) { if err != nil { t.Fatalf("Get error: %v", err) } + res.Body.Close() finalUrl := res.Request.URL.String() if e, g := "<nil>", fmt.Sprintf("%v", err); e != g { t.Errorf("with custom client, expected error %q, got %q", e, g) @@ -242,12 +248,14 @@ func TestRedirects(t *testing.T) { if res == nil { t.Fatalf("Expected a non-nil Response on CheckRedirect failure (http://golang.org/issue/3795)") } + res.Body.Close() if res.Header.Get("Location") == "" { t.Errorf("no Location header in Response") } } func TestPostRedirects(t *testing.T) { + defer afterTest(t) var log struct { sync.Mutex bytes.Buffer @@ -265,6 +273,7 @@ func TestPostRedirects(t *testing.T) { w.WriteHeader(code) } })) + defer ts.Close() tests := []struct { suffix string want int // response code @@ -364,6 +373,7 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { } func TestRedirectCookiesOnRequest(t *testing.T) { + defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() @@ -381,6 +391,7 @@ func TestRedirectCookiesOnRequest(t *testing.T) { } func TestRedirectCookiesJar(t *testing.T) { + defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() @@ -393,6 +404,7 @@ func TestRedirectCookiesJar(t *testing.T) { if err != nil { t.Fatalf("Get: %v", err) } + resp.Body.Close() matchReturnedCookies(t, expectedCookies, resp.Cookies()) } @@ -416,6 +428,7 @@ func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { } func TestJarCalls(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { pathSuffix := r.RequestURI[1:] if r.RequestURI == "/nosetcookie" { @@ -479,6 +492,7 @@ func (j *RecordingJar) logf(format string, args ...interface{}) { } func TestStreamingGet(t *testing.T) { + defer afterTest(t) say := make(chan string) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() @@ -529,6 +543,7 @@ func (c *writeCountingConn) Write(p []byte) (int, error) { // TestClientWrites verifies that client requests are buffered and we // don't send a TCP packet per line of the http request + body. func TestClientWrites(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) defer ts.Close() @@ -562,6 +577,7 @@ func TestClientWrites(t *testing.T) { } func TestClientInsecureTransport(t *testing.T) { + defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) })) @@ -576,15 +592,20 @@ func TestClientInsecureTransport(t *testing.T) { InsecureSkipVerify: insecure, }, } + defer tr.CloseIdleConnections() c := &Client{Transport: tr} - _, err := c.Get(ts.URL) + res, err := c.Get(ts.URL) if (err == nil) != insecure { t.Errorf("insecure=%v: got unexpected err=%v", insecure, err) } + if res != nil { + res.Body.Close() + } } } func TestClientErrorWithRequestURI(t *testing.T) { + defer afterTest(t) req, _ := NewRequest("GET", "http://localhost:1234/", nil) req.RequestURI = "/this/field/is/illegal/and/should/error/" _, err := DefaultClient.Do(req) @@ -613,6 +634,7 @@ func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport { } func TestClientWithCorrectTLSServerName(t *testing.T) { + defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS.ServerName != "127.0.0.1" { t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName) @@ -627,6 +649,7 @@ func TestClientWithCorrectTLSServerName(t *testing.T) { } func TestClientWithIncorrectTLSServerName(t *testing.T) { + defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() @@ -644,6 +667,7 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { // Verify Response.ContentLength is populated. http://golang.org/issue/4126 func TestClientHeadContentLength(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if v := r.FormValue("cl"); v != "" { w.Header().Set("Content-Length", v) diff --git a/libgo/go/net/http/cookiejar/jar.go b/libgo/go/net/http/cookiejar/jar.go new file mode 100644 index 00000000000..5977d48b631 --- /dev/null +++ b/libgo/go/net/http/cookiejar/jar.go @@ -0,0 +1,497 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cookiejar implements an in-memory RFC 6265-compliant http.CookieJar. +package cookiejar + +import ( + "errors" + "fmt" + "net" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "time" +) + +// PublicSuffixList provides the public suffix of a domain. For example: +// - the public suffix of "example.com" is "com", +// - the public suffix of "foo1.foo2.foo3.co.uk" is "co.uk", and +// - the public suffix of "bar.pvt.k12.ma.us" is "pvt.k12.ma.us". +// +// Implementations of PublicSuffixList must be safe for concurrent use by +// multiple goroutines. +// +// An implementation that always returns "" is valid and may be useful for +// testing but it is not secure: it means that the HTTP server for foo.com can +// set a cookie for bar.com. +// +// A public suffix list implementation is in the package +// code.google.com/p/go.net/publicsuffix. +type PublicSuffixList interface { + // PublicSuffix returns the public suffix of domain. + // + // TODO: specify which of the caller and callee is responsible for IP + // addresses, for leading and trailing dots, for case sensitivity, and + // for IDN/Punycode. + PublicSuffix(domain string) string + + // String returns a description of the source of this public suffix + // list. The description will typically contain something like a time + // stamp or version number. + String() string +} + +// Options are the options for creating a new Jar. +type Options struct { + // PublicSuffixList is the public suffix list that determines whether + // an HTTP server can set a cookie for a domain. + // + // A nil value is valid and may be useful for testing but it is not + // secure: it means that the HTTP server for foo.co.uk can set a cookie + // for bar.co.uk. + PublicSuffixList PublicSuffixList +} + +// Jar implements the http.CookieJar interface from the net/http package. +type Jar struct { + psList PublicSuffixList + + // mu locks the remaining fields. + mu sync.Mutex + + // entries is a set of entries, keyed by their eTLD+1 and subkeyed by + // their name/domain/path. + entries map[string]map[string]entry + + // nextSeqNum is the next sequence number assigned to a new cookie + // created SetCookies. + nextSeqNum uint64 +} + +// New returns a new cookie jar. A nil *Options is equivalent to a zero +// Options. +func New(o *Options) (*Jar, error) { + jar := &Jar{ + entries: make(map[string]map[string]entry), + } + if o != nil { + jar.psList = o.PublicSuffixList + } + return jar, nil +} + +// entry is the internal representation of a cookie. +// +// This struct type is not used outside of this package per se, but the exported +// fields are those of RFC 6265. +type entry struct { + Name string + Value string + Domain string + Path string + Secure bool + HttpOnly bool + Persistent bool + HostOnly bool + Expires time.Time + Creation time.Time + LastAccess time.Time + + // seqNum is a sequence number so that Cookies returns cookies in a + // deterministic order, even for cookies that have equal Path length and + // equal Creation time. This simplifies testing. + seqNum uint64 +} + +// Id returns the domain;path;name triple of e as an id. +func (e *entry) id() string { + return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name) +} + +// shouldSend determines whether e's cookie qualifies to be included in a +// request to host/path. It is the caller's responsibility to check if the +// cookie is expired. +func (e *entry) shouldSend(https bool, host, path string) bool { + return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure) +} + +// domainMatch implements "domain-match" of RFC 6265 section 5.1.3. +func (e *entry) domainMatch(host string) bool { + if e.Domain == host { + return true + } + return !e.HostOnly && hasDotSuffix(host, e.Domain) +} + +// pathMatch implements "path-match" according to RFC 6265 section 5.1.4. +func (e *entry) pathMatch(requestPath string) bool { + if requestPath == e.Path { + return true + } + if strings.HasPrefix(requestPath, e.Path) { + if e.Path[len(e.Path)-1] == '/' { + return true // The "/any/" matches "/any/path" case. + } else if requestPath[len(e.Path)] == '/' { + return true // The "/any" matches "/any/path" case. + } + } + return false +} + +// hasDotSuffix returns whether s ends in "."+suffix. +func hasDotSuffix(s, suffix string) bool { + return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix +} + +// byPathLength is a []entry sort.Interface that sorts according to RFC 6265 +// section 5.4 point 2: by longest path and then by earliest creation time. +type byPathLength []entry + +func (s byPathLength) Len() int { return len(s) } + +func (s byPathLength) Less(i, j int) bool { + if len(s[i].Path) != len(s[j].Path) { + return len(s[i].Path) > len(s[j].Path) + } + if !s[i].Creation.Equal(s[j].Creation) { + return s[i].Creation.Before(s[j].Creation) + } + return s[i].seqNum < s[j].seqNum +} + +func (s byPathLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Cookies implements the Cookies method of the http.CookieJar interface. +// +// It returns an empty slice if the URL's scheme is not HTTP or HTTPS. +func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) { + return j.cookies(u, time.Now()) +} + +// cookies is like Cookies but takes the current time as a parameter. +func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { + if u.Scheme != "http" && u.Scheme != "https" { + return cookies + } + host, err := canonicalHost(u.Host) + if err != nil { + return cookies + } + key := jarKey(host, j.psList) + + j.mu.Lock() + defer j.mu.Unlock() + + submap := j.entries[key] + if submap == nil { + return cookies + } + + https := u.Scheme == "https" + path := u.Path + if path == "" { + path = "/" + } + + modified := false + var selected []entry + for id, e := range submap { + if e.Persistent && !e.Expires.After(now) { + delete(submap, id) + modified = true + continue + } + if !e.shouldSend(https, host, path) { + continue + } + e.LastAccess = now + submap[id] = e + selected = append(selected, e) + modified = true + } + if modified { + if len(submap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = submap + } + } + + sort.Sort(byPathLength(selected)) + for _, e := range selected { + cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value}) + } + + return cookies +} + +// SetCookies implements the SetCookies method of the http.CookieJar interface. +// +// It does nothing if the URL's scheme is not HTTP or HTTPS. +func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) { + j.setCookies(u, cookies, time.Now()) +} + +// setCookies is like SetCookies but takes the current time as parameter. +func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) { + if len(cookies) == 0 { + return + } + if u.Scheme != "http" && u.Scheme != "https" { + return + } + host, err := canonicalHost(u.Host) + if err != nil { + return + } + key := jarKey(host, j.psList) + defPath := defaultPath(u.Path) + + j.mu.Lock() + defer j.mu.Unlock() + + submap := j.entries[key] + + modified := false + for _, cookie := range cookies { + e, remove, err := j.newEntry(cookie, now, defPath, host) + if err != nil { + continue + } + id := e.id() + if remove { + if submap != nil { + if _, ok := submap[id]; ok { + delete(submap, id) + modified = true + } + } + continue + } + if submap == nil { + submap = make(map[string]entry) + } + + if old, ok := submap[id]; ok { + e.Creation = old.Creation + e.seqNum = old.seqNum + } else { + e.Creation = now + e.seqNum = j.nextSeqNum + j.nextSeqNum++ + } + e.LastAccess = now + submap[id] = e + modified = true + } + + if modified { + if len(submap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = submap + } + } +} + +// canonicalHost strips port from host if present and returns the canonicalized +// host name. +func canonicalHost(host string) (string, error) { + var err error + host = strings.ToLower(host) + if hasPort(host) { + host, _, err = net.SplitHostPort(host) + if err != nil { + return "", err + } + } + if strings.HasSuffix(host, ".") { + // Strip trailing dot from fully qualified domain names. + host = host[:len(host)-1] + } + return toASCII(host) +} + +// hasPort returns whether host contains a port number. host may be a host +// name, an IPv4 or an IPv6 address. +func hasPort(host string) bool { + colons := strings.Count(host, ":") + if colons == 0 { + return false + } + if colons == 1 { + return true + } + return host[0] == '[' && strings.Contains(host, "]:") +} + +// jarKey returns the key to use for a jar. +func jarKey(host string, psl PublicSuffixList) string { + if isIP(host) { + return host + } + + var i int + if psl == nil { + i = strings.LastIndex(host, ".") + if i == -1 { + return host + } + } else { + suffix := psl.PublicSuffix(host) + if suffix == host { + return host + } + i = len(host) - len(suffix) + if i <= 0 || host[i-1] != '.' { + // The provided public suffix list psl is broken. + // Storing cookies under host is a safe stopgap. + return host + } + } + prevDot := strings.LastIndex(host[:i-1], ".") + return host[prevDot+1:] +} + +// isIP returns whether host is an IP address. +func isIP(host string) bool { + return net.ParseIP(host) != nil +} + +// defaultPath returns the directory part of an URL's path according to +// RFC 6265 section 5.1.4. +func defaultPath(path string) string { + if len(path) == 0 || path[0] != '/' { + return "/" // Path is empty or malformed. + } + + i := strings.LastIndex(path, "/") // Path starts with "/", so i != -1. + if i == 0 { + return "/" // Path has the form "/abc". + } + return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/". +} + +// newEntry creates an entry from a http.Cookie c. now is the current time and +// is compared to c.Expires to determine deletion of c. defPath and host are the +// default-path and the canonical host name of the URL c was received from. +// +// remove is whether the jar should delete this cookie, as it has already +// expired with respect to now. In this case, e may be incomplete, but it will +// be valid to call e.id (which depends on e's Name, Domain and Path). +// +// A malformed c.Domain will result in an error. +func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) { + e.Name = c.Name + + if c.Path == "" || c.Path[0] != '/' { + e.Path = defPath + } else { + e.Path = c.Path + } + + e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain) + if err != nil { + return e, false, err + } + + // MaxAge takes precedence over Expires. + if c.MaxAge < 0 { + return e, true, nil + } else if c.MaxAge > 0 { + e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second) + e.Persistent = true + } else { + if c.Expires.IsZero() { + e.Expires = endOfTime + e.Persistent = false + } else { + if !c.Expires.After(now) { + return e, true, nil + } + e.Expires = c.Expires + e.Persistent = true + } + } + + e.Value = c.Value + e.Secure = c.Secure + e.HttpOnly = c.HttpOnly + + return e, false, nil +} + +var ( + errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute") + errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute") + errNoHostname = errors.New("cookiejar: no host name available (IP only)") +) + +// endOfTime is the time when session (non-persistent) cookies expire. +// This instant is representable in most date/time formats (not just +// Go's time.Time) and should be far enough in the future. +var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) + +// domainAndType determines the cookie's domain and hostOnly attribute. +func (j *Jar) domainAndType(host, domain string) (string, bool, error) { + if domain == "" { + // No domain attribute in the SetCookie header indicates a + // host cookie. + return host, true, nil + } + + if isIP(host) { + // According to RFC 6265 domain-matching includes not being + // an IP address. + // TODO: This might be relaxed as in common browsers. + return "", false, errNoHostname + } + + // From here on: If the cookie is valid, it is a domain cookie (with + // the one exception of a public suffix below). + // See RFC 6265 section 5.2.3. + if domain[0] == '.' { + domain = domain[1:] + } + + if len(domain) == 0 || domain[0] == '.' { + // Received either "Domain=." or "Domain=..some.thing", + // both are illegal. + return "", false, errMalformedDomain + } + domain = strings.ToLower(domain) + + if domain[len(domain)-1] == '.' { + // We received stuff like "Domain=www.example.com.". + // Browsers do handle such stuff (actually differently) but + // RFC 6265 seems to be clear here (e.g. section 4.1.2.3) in + // requiring a reject. 4.1.2.3 is not normative, but + // "Domain Matching" (5.1.3) and "Canonicalized Host Names" + // (5.1.2) are. + return "", false, errMalformedDomain + } + + // See RFC 6265 section 5.3 #5. + if j.psList != nil { + if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) { + if host == domain { + // This is the one exception in which a cookie + // with a domain attribute is a host cookie. + return host, true, nil + } + return "", false, errIllegalDomain + } + } + + // The domain must domain-match host: www.mycompany.com cannot + // set cookies for .ourcompetitors.com. + if host != domain && !hasDotSuffix(host, domain) { + return "", false, errIllegalDomain + } + + return domain, false, nil +} diff --git a/libgo/go/net/http/cookiejar/jar_test.go b/libgo/go/net/http/cookiejar/jar_test.go new file mode 100644 index 00000000000..3aa601586e3 --- /dev/null +++ b/libgo/go/net/http/cookiejar/jar_test.go @@ -0,0 +1,1267 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +import ( + "fmt" + "net/http" + "net/url" + "sort" + "strings" + "testing" + "time" +) + +// tNow is the synthetic current time used as now during testing. +var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC) + +// testPSL implements PublicSuffixList with just two rules: "co.uk" +// and the default rule "*". +type testPSL struct{} + +func (testPSL) String() string { + return "testPSL" +} +func (testPSL) PublicSuffix(d string) string { + if d == "co.uk" || strings.HasSuffix(d, ".co.uk") { + return "co.uk" + } + return d[strings.LastIndex(d, ".")+1:] +} + +// newTestJar creates an empty Jar with testPSL as the public suffix list. +func newTestJar() *Jar { + jar, err := New(&Options{PublicSuffixList: testPSL{}}) + if err != nil { + panic(err) + } + return jar +} + +var hasDotSuffixTests = [...]struct { + s, suffix string +}{ + {"", ""}, + {"", "."}, + {"", "x"}, + {".", ""}, + {".", "."}, + {".", ".."}, + {".", "x"}, + {".", "x."}, + {".", ".x"}, + {".", ".x."}, + {"x", ""}, + {"x", "."}, + {"x", ".."}, + {"x", "x"}, + {"x", "x."}, + {"x", ".x"}, + {"x", ".x."}, + {".x", ""}, + {".x", "."}, + {".x", ".."}, + {".x", "x"}, + {".x", "x."}, + {".x", ".x"}, + {".x", ".x."}, + {"x.", ""}, + {"x.", "."}, + {"x.", ".."}, + {"x.", "x"}, + {"x.", "x."}, + {"x.", ".x"}, + {"x.", ".x."}, + {"com", ""}, + {"com", "m"}, + {"com", "om"}, + {"com", "com"}, + {"com", ".com"}, + {"com", "x.com"}, + {"com", "xcom"}, + {"com", "xorg"}, + {"com", "org"}, + {"com", "rg"}, + {"foo.com", ""}, + {"foo.com", "m"}, + {"foo.com", "om"}, + {"foo.com", "com"}, + {"foo.com", ".com"}, + {"foo.com", "o.com"}, + {"foo.com", "oo.com"}, + {"foo.com", "foo.com"}, + {"foo.com", ".foo.com"}, + {"foo.com", "x.foo.com"}, + {"foo.com", "xfoo.com"}, + {"foo.com", "xfoo.org"}, + {"foo.com", "foo.org"}, + {"foo.com", "oo.org"}, + {"foo.com", "o.org"}, + {"foo.com", ".org"}, + {"foo.com", "org"}, + {"foo.com", "rg"}, +} + +func TestHasDotSuffix(t *testing.T) { + for _, tc := range hasDotSuffixTests { + got := hasDotSuffix(tc.s, tc.suffix) + want := strings.HasSuffix(tc.s, "."+tc.suffix) + if got != want { + t.Errorf("s=%q, suffix=%q: got %v, want %v", tc.s, tc.suffix, got, want) + } + } +} + +var canonicalHostTests = map[string]string{ + "www.example.com": "www.example.com", + "WWW.EXAMPLE.COM": "www.example.com", + "wWw.eXAmple.CoM": "www.example.com", + "www.example.com:80": "www.example.com", + "192.168.0.10": "192.168.0.10", + "192.168.0.5:8080": "192.168.0.5", + "2001:4860:0:2001::68": "2001:4860:0:2001::68", + "[2001:4860:0:::68]:8080": "2001:4860:0:::68", + "www.bücher.de": "www.xn--bcher-kva.de", + "www.example.com.": "www.example.com", + "[bad.unmatched.bracket:": "error", +} + +func TestCanonicalHost(t *testing.T) { + for h, want := range canonicalHostTests { + got, err := canonicalHost(h) + if want == "error" { + if err == nil { + t.Errorf("%q: got nil error, want non-nil", h) + } + continue + } + if err != nil { + t.Errorf("%q: %v", h, err) + continue + } + if got != want { + t.Errorf("%q: got %q, want %q", h, got, want) + continue + } + } +} + +var hasPortTests = map[string]bool{ + "www.example.com": false, + "www.example.com:80": true, + "127.0.0.1": false, + "127.0.0.1:8080": true, + "2001:4860:0:2001::68": false, + "[2001::0:::68]:80": true, +} + +func TestHasPort(t *testing.T) { + for host, want := range hasPortTests { + if got := hasPort(host); got != want { + t.Errorf("%q: got %t, want %t", host, got, want) + } + } +} + +var jarKeyTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "bbc.co.uk", + "www.bbc.co.uk": "bbc.co.uk", + "bbc.co.uk": "bbc.co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", +} + +func TestJarKey(t *testing.T) { + for host, want := range jarKeyTests { + if got := jarKey(host, testPSL{}); got != want { + t.Errorf("%q: got %q, want %q", host, got, want) + } + } +} + +var jarKeyNilPSLTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "co.uk", + "www.bbc.co.uk": "co.uk", + "bbc.co.uk": "co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", +} + +func TestJarKeyNilPSL(t *testing.T) { + for host, want := range jarKeyNilPSLTests { + if got := jarKey(host, nil); got != want { + t.Errorf("%q: got %q, want %q", host, got, want) + } + } +} + +var isIPTests = map[string]bool{ + "127.0.0.1": true, + "1.2.3.4": true, + "2001:4860:0:2001::68": true, + "example.com": false, + "1.1.1.300": false, + "www.foo.bar.net": false, + "123.foo.bar.net": false, +} + +func TestIsIP(t *testing.T) { + for host, want := range isIPTests { + if got := isIP(host); got != want { + t.Errorf("%q: got %t, want %t", host, got, want) + } + } +} + +var defaultPathTests = map[string]string{ + "/": "/", + "/abc": "/", + "/abc/": "/abc", + "/abc/xyz": "/abc", + "/abc/xyz/": "/abc/xyz", + "/a/b/c.html": "/a/b", + "": "/", + "strange": "/", + "//": "/", + "/a//b": "/a/", + "/a/./b": "/a/.", + "/a/../b": "/a/..", +} + +func TestDefaultPath(t *testing.T) { + for path, want := range defaultPathTests { + if got := defaultPath(path); got != want { + t.Errorf("%q: got %q, want %q", path, got, want) + } + } +} + +var domainAndTypeTests = [...]struct { + host string // host Set-Cookie header was received from + domain string // domain attribute in Set-Cookie header + wantDomain string // expected domain of cookie + wantHostOnly bool // expected host-cookie flag + wantErr error // expected error +}{ + {"www.example.com", "", "www.example.com", true, nil}, + {"127.0.0.1", "", "127.0.0.1", true, nil}, + {"2001:4860:0:2001::68", "", "2001:4860:0:2001::68", true, nil}, + {"www.example.com", "example.com", "example.com", false, nil}, + {"www.example.com", ".example.com", "example.com", false, nil}, + {"www.example.com", "www.example.com", "www.example.com", false, nil}, + {"www.example.com", ".www.example.com", "www.example.com", false, nil}, + {"foo.sso.example.com", "sso.example.com", "sso.example.com", false, nil}, + {"bar.co.uk", "bar.co.uk", "bar.co.uk", false, nil}, + {"foo.bar.co.uk", ".bar.co.uk", "bar.co.uk", false, nil}, + {"127.0.0.1", "127.0.0.1", "", false, errNoHostname}, + {"2001:4860:0:2001::68", "2001:4860:0:2001::68", "2001:4860:0:2001::68", false, errNoHostname}, + {"www.example.com", ".", "", false, errMalformedDomain}, + {"www.example.com", "..", "", false, errMalformedDomain}, + {"www.example.com", "other.com", "", false, errIllegalDomain}, + {"www.example.com", "com", "", false, errIllegalDomain}, + {"www.example.com", ".com", "", false, errIllegalDomain}, + {"foo.bar.co.uk", ".co.uk", "", false, errIllegalDomain}, + {"127.www.0.0.1", "127.0.0.1", "", false, errIllegalDomain}, + {"com", "", "com", true, nil}, + {"com", "com", "com", true, nil}, + {"com", ".com", "com", true, nil}, + {"co.uk", "", "co.uk", true, nil}, + {"co.uk", "co.uk", "co.uk", true, nil}, + {"co.uk", ".co.uk", "co.uk", true, nil}, +} + +func TestDomainAndType(t *testing.T) { + jar := newTestJar() + for _, tc := range domainAndTypeTests { + domain, hostOnly, err := jar.domainAndType(tc.host, tc.domain) + if err != tc.wantErr { + t.Errorf("%q/%q: got %q error, want %q", + tc.host, tc.domain, err, tc.wantErr) + continue + } + if err != nil { + continue + } + if domain != tc.wantDomain || hostOnly != tc.wantHostOnly { + t.Errorf("%q/%q: got %q/%t want %q/%t", + tc.host, tc.domain, domain, hostOnly, + tc.wantDomain, tc.wantHostOnly) + } + } +} + +// expiresIn creates an expires attribute delta seconds from tNow. +func expiresIn(delta int) string { + t := tNow.Add(time.Duration(delta) * time.Second) + return "expires=" + t.Format(time.RFC1123) +} + +// mustParseURL parses s to an URL and panics on error. +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil || u.Scheme == "" || u.Host == "" { + panic(fmt.Sprintf("Unable to parse URL %s.", s)) + } + return u +} + +// jarTest encapsulates the following actions on a jar: +// 1. Perform SetCookies with fromURL and the cookies from setCookies. +// (Done at time tNow + 0 ms.) +// 2. Check that the entries in the jar matches content. +// (Done at time tNow + 1001 ms.) +// 3. For each query in tests: Check that Cookies with toURL yields the +// cookies in want. +// (Query n done at tNow + (n+2)*1001 ms.) +type jarTest struct { + description string // The description of what this test is supposed to test + fromURL string // The full URL of the request from which Set-Cookie headers where received + setCookies []string // All the cookies received from fromURL + content string // The whole (non-expired) content of the jar + queries []query // Queries to test the Jar.Cookies method +} + +// query contains one test of the cookies returned from Jar.Cookies. +type query struct { + toURL string // the URL in the Cookies call + want string // the expected list of cookies (order matters) +} + +// run runs the jarTest. +func (test jarTest) run(t *testing.T, jar *Jar) { + now := tNow + + // Populate jar with cookies. + setCookies := make([]*http.Cookie, len(test.setCookies)) + for i, cs := range test.setCookies { + cookies := (&http.Response{Header: http.Header{"Set-Cookie": {cs}}}).Cookies() + if len(cookies) != 1 { + panic(fmt.Sprintf("Wrong cookie line %q: %#v", cs, cookies)) + } + setCookies[i] = cookies[0] + } + jar.setCookies(mustParseURL(test.fromURL), setCookies, now) + now = now.Add(1001 * time.Millisecond) + + // Serialize non-expired entries in the form "name1=val1 name2=val2". + var cs []string + for _, submap := range jar.entries { + for _, cookie := range submap { + if !cookie.Expires.After(now) { + continue + } + cs = append(cs, cookie.Name+"="+cookie.Value) + } + } + sort.Strings(cs) + got := strings.Join(cs, " ") + + // Make sure jar content matches our expectations. + if got != test.content { + t.Errorf("Test %q Content\ngot %q\nwant %q", + test.description, got, test.content) + } + + // Test different calls to Cookies. + for i, query := range test.queries { + now = now.Add(1001 * time.Millisecond) + var s []string + for _, c := range jar.cookies(mustParseURL(query.toURL), now) { + s = append(s, c.Name+"="+c.Value) + } + if got := strings.Join(s, " "); got != query.want { + t.Errorf("Test %q #%d\ngot %q\nwant %q", test.description, i, got, query.want) + } + } +} + +// basicsTests contains fundamental tests. Each jarTest has to be performed on +// a fresh, empty Jar. +var basicsTests = [...]jarTest{ + { + "Retrieval of a plain host cookie.", + "http://www.host.test/", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", "A=a"}, + {"http://www.host.test/", "A=a"}, + {"http://www.host.test/some/path", "A=a"}, + {"https://www.host.test", "A=a"}, + {"https://www.host.test/", "A=a"}, + {"https://www.host.test/some/path", "A=a"}, + {"ftp://www.host.test", ""}, + {"ftp://www.host.test/", ""}, + {"ftp://www.host.test/some/path", ""}, + {"http://www.other.org", ""}, + {"http://sibling.host.test", ""}, + {"http://deep.www.host.test", ""}, + }, + }, + { + "Secure cookies are not returned to http.", + "http://www.host.test/", + []string{"A=a; secure"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some/path", ""}, + {"https://www.host.test", "A=a"}, + {"https://www.host.test/", "A=a"}, + {"https://www.host.test/some/path", "A=a"}, + }, + }, + { + "Explicit path.", + "http://www.host.test/", + []string{"A=a; path=/some/path"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #1: path is a directory.", + "http://www.host.test/some/path/", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #2: path is not a directory.", + "http://www.host.test/some/path/index.html", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #3: no path in URL at all.", + "http://www.host.test", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", "A=a"}, + {"http://www.host.test/", "A=a"}, + {"http://www.host.test/some/path", "A=a"}, + }, + }, + { + "Cookies are sorted by path length.", + "http://www.host.test/", + []string{ + "A=a; path=/foo/bar", + "B=b; path=/foo/bar/baz/qux", + "C=c; path=/foo/bar/baz", + "D=d; path=/foo"}, + "A=a B=b C=c D=d", + []query{ + {"http://www.host.test/foo/bar/baz/qux", "B=b C=c A=a D=d"}, + {"http://www.host.test/foo/bar/baz/", "C=c A=a D=d"}, + {"http://www.host.test/foo/bar", "A=a D=d"}, + }, + }, + { + "Creation time determines sorting on same length paths.", + "http://www.host.test/", + []string{ + "A=a; path=/foo/bar", + "X=x; path=/foo/bar", + "Y=y; path=/foo/bar/baz/qux", + "B=b; path=/foo/bar/baz/qux", + "C=c; path=/foo/bar/baz", + "W=w; path=/foo/bar/baz", + "Z=z; path=/foo", + "D=d; path=/foo"}, + "A=a B=b C=c D=d W=w X=x Y=y Z=z", + []query{ + {"http://www.host.test/foo/bar/baz/qux", "Y=y B=b C=c W=w A=a X=x Z=z D=d"}, + {"http://www.host.test/foo/bar/baz/", "C=c W=w A=a X=x Z=z D=d"}, + {"http://www.host.test/foo/bar", "A=a X=x Z=z D=d"}, + }, + }, + { + "Sorting of same-name cookies.", + "http://www.host.test/", + []string{ + "A=1; path=/", + "A=2; path=/path", + "A=3; path=/quux", + "A=4; path=/path/foo", + "A=5; domain=.host.test; path=/path", + "A=6; domain=.host.test; path=/quux", + "A=7; domain=.host.test; path=/path/foo", + }, + "A=1 A=2 A=3 A=4 A=5 A=6 A=7", + []query{ + {"http://www.host.test/path", "A=2 A=5 A=1"}, + {"http://www.host.test/path/foo", "A=4 A=7 A=2 A=5 A=1"}, + }, + }, + { + "Disallow domain cookie on public suffix.", + "http://www.bbc.co.uk", + []string{ + "a=1", + "b=2; domain=co.uk", + }, + "a=1", + []query{{"http://www.bbc.co.uk", "a=1"}}, + }, + { + "Host cookie on IP.", + "http://192.168.0.10", + []string{"a=1"}, + "a=1", + []query{{"http://192.168.0.10", "a=1"}}, + }, + { + "Port is ignored #1.", + "http://www.host.test/", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://www.host.test:8080/", "a=1"}, + }, + }, + { + "Port is ignored #2.", + "http://www.host.test:8080/", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://www.host.test:8080/", "a=1"}, + {"http://www.host.test:1234/", "a=1"}, + }, + }, +} + +func TestBasics(t *testing.T) { + for _, test := range basicsTests { + jar := newTestJar() + test.run(t, jar) + } +} + +// updateAndDeleteTests contains jarTests which must be performed on the same +// Jar. +var updateAndDeleteTests = [...]jarTest{ + { + "Set initial cookies.", + "http://www.host.test", + []string{ + "a=1", + "b=2; secure", + "c=3; httponly", + "d=4; secure; httponly"}, + "a=1 b=2 c=3 d=4", + []query{ + {"http://www.host.test", "a=1 c=3"}, + {"https://www.host.test", "a=1 b=2 c=3 d=4"}, + }, + }, + { + "Update value via http.", + "http://www.host.test", + []string{ + "a=w", + "b=x; secure", + "c=y; httponly", + "d=z; secure; httponly"}, + "a=w b=x c=y d=z", + []query{ + {"http://www.host.test", "a=w c=y"}, + {"https://www.host.test", "a=w b=x c=y d=z"}, + }, + }, + { + "Clear Secure flag from a http.", + "http://www.host.test/", + []string{ + "b=xx", + "d=zz; httponly"}, + "a=w b=xx c=y d=zz", + []query{{"http://www.host.test", "a=w b=xx c=y d=zz"}}, + }, + { + "Delete all.", + "http://www.host.test/", + []string{ + "a=1; max-Age=-1", // delete via MaxAge + "b=2; " + expiresIn(-10), // delete via Expires + "c=2; max-age=-1; " + expiresIn(-10), // delete via both + "d=4; max-age=-1; " + expiresIn(10)}, // MaxAge takes precedence + "", + []query{{"http://www.host.test", ""}}, + }, + { + "Refill #1.", + "http://www.host.test", + []string{ + "A=1", + "A=2; path=/foo", + "A=3; domain=.host.test", + "A=4; path=/foo; domain=.host.test"}, + "A=1 A=2 A=3 A=4", + []query{{"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}}, + }, + { + "Refill #2.", + "http://www.google.com", + []string{ + "A=6", + "A=7; path=/foo", + "A=8; domain=.google.com", + "A=9; path=/foo; domain=.google.com"}, + "A=1 A=2 A=3 A=4 A=6 A=7 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}, + {"http://www.google.com/foo", "A=7 A=9 A=6 A=8"}, + }, + }, + { + "Delete A7.", + "http://www.google.com", + []string{"A=; path=/foo; max-age=-1"}, + "A=1 A=2 A=3 A=4 A=6 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=6 A=8"}, + }, + }, + { + "Delete A4.", + "http://www.host.test", + []string{"A=; path=/foo; domain=host.test; max-age=-1"}, + "A=1 A=2 A=3 A=6 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=6 A=8"}, + }, + }, + { + "Delete A6.", + "http://www.google.com", + []string{"A=; max-age=-1"}, + "A=1 A=2 A=3 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "Delete A3.", + "http://www.host.test", + []string{"A=; domain=host.test; max-age=-1"}, + "A=1 A=2 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "No cross-domain delete.", + "http://www.host.test", + []string{ + "A=; domain=google.com; max-age=-1", + "A=; path=/foo; domain=google.com; max-age=-1"}, + "A=1 A=2 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "Delete A8 and A9.", + "http://www.google.com", + []string{ + "A=; domain=google.com; max-age=-1", + "A=; path=/foo; domain=google.com; max-age=-1"}, + "A=1 A=2", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", ""}, + }, + }, +} + +func TestUpdateAndDelete(t *testing.T) { + jar := newTestJar() + for _, test := range updateAndDeleteTests { + test.run(t, jar) + } +} + +func TestExpiration(t *testing.T) { + jar := newTestJar() + jarTest{ + "Expiration.", + "http://www.host.test", + []string{ + "a=1", + "b=2; max-age=3", + "c=3; " + expiresIn(3), + "d=4; max-age=5", + "e=5; " + expiresIn(5), + "f=6; max-age=100", + }, + "a=1 b=2 c=3 d=4 e=5 f=6", // executed at t0 + 1001 ms + []query{ + {"http://www.host.test", "a=1 b=2 c=3 d=4 e=5 f=6"}, // t0 + 2002 ms + {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 3003 ms + {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 4004 ms + {"http://www.host.test", "a=1 f=6"}, // t0 + 5005 ms + {"http://www.host.test", "a=1 f=6"}, // t0 + 6006 ms + }, + }.run(t, jar) +} + +// +// Tests derived from Chromium's cookie_store_unittest.h. +// + +// See http://src.chromium.org/viewvc/chrome/trunk/src/net/cookies/cookie_store_unittest.h?revision=159685&content-type=text/plain +// Some of the original tests are in a bad condition (e.g. +// DomainWithTrailingDotTest) or are not RFC 6265 conforming (e.g. +// TestNonDottedAndTLD #1 and #6) and have not been ported. + +// chromiumBasicsTests contains fundamental tests. Each jarTest has to be +// performed on a fresh, empty Jar. +var chromiumBasicsTests = [...]jarTest{ + { + "DomainWithTrailingDotTest.", + "http://www.google.com/", + []string{ + "a=1; domain=.www.google.com.", + "b=2; domain=.www.google.com.."}, + "", + []query{ + {"http://www.google.com", ""}, + }, + }, + { + "ValidSubdomainTest #1.", + "http://a.b.c.d.com", + []string{ + "a=1; domain=.a.b.c.d.com", + "b=2; domain=.b.c.d.com", + "c=3; domain=.c.d.com", + "d=4; domain=.d.com"}, + "a=1 b=2 c=3 d=4", + []query{ + {"http://a.b.c.d.com", "a=1 b=2 c=3 d=4"}, + {"http://b.c.d.com", "b=2 c=3 d=4"}, + {"http://c.d.com", "c=3 d=4"}, + {"http://d.com", "d=4"}, + }, + }, + { + "ValidSubdomainTest #2.", + "http://a.b.c.d.com", + []string{ + "a=1; domain=.a.b.c.d.com", + "b=2; domain=.b.c.d.com", + "c=3; domain=.c.d.com", + "d=4; domain=.d.com", + "X=bcd; domain=.b.c.d.com", + "X=cd; domain=.c.d.com"}, + "X=bcd X=cd a=1 b=2 c=3 d=4", + []query{ + {"http://b.c.d.com", "b=2 c=3 d=4 X=bcd X=cd"}, + {"http://c.d.com", "c=3 d=4 X=cd"}, + }, + }, + { + "InvalidDomainTest #1.", + "http://foo.bar.com", + []string{ + "a=1; domain=.yo.foo.bar.com", + "b=2; domain=.foo.com", + "c=3; domain=.bar.foo.com", + "d=4; domain=.foo.bar.com.net", + "e=5; domain=ar.com", + "f=6; domain=.", + "g=7; domain=/", + "h=8; domain=http://foo.bar.com", + "i=9; domain=..foo.bar.com", + "j=10; domain=..bar.com", + "k=11; domain=.foo.bar.com?blah", + "l=12; domain=.foo.bar.com/blah", + "m=12; domain=.foo.bar.com:80", + "n=14; domain=.foo.bar.com:", + "o=15; domain=.foo.bar.com#sup", + }, + "", // Jar is empty. + []query{{"http://foo.bar.com", ""}}, + }, + { + "InvalidDomainTest #2.", + "http://foo.com.com", + []string{"a=1; domain=.foo.com.com.com"}, + "", + []query{{"http://foo.bar.com", ""}}, + }, + { + "DomainWithoutLeadingDotTest #1.", + "http://manage.hosted.filefront.com", + []string{"a=1; domain=filefront.com"}, + "a=1", + []query{{"http://www.filefront.com", "a=1"}}, + }, + { + "DomainWithoutLeadingDotTest #2.", + "http://www.google.com", + []string{"a=1; domain=www.google.com"}, + "a=1", + []query{ + {"http://www.google.com", "a=1"}, + {"http://sub.www.google.com", "a=1"}, + {"http://something-else.com", ""}, + }, + }, + { + "CaseInsensitiveDomainTest.", + "http://www.google.com", + []string{ + "a=1; domain=.GOOGLE.COM", + "b=2; domain=.www.gOOgLE.coM"}, + "a=1 b=2", + []query{{"http://www.google.com", "a=1 b=2"}}, + }, + { + "TestIpAddress #1.", + "http://1.2.3.4/foo", + []string{"a=1; path=/"}, + "a=1", + []query{{"http://1.2.3.4/foo", "a=1"}}, + }, + { + "TestIpAddress #2.", + "http://1.2.3.4/foo", + []string{ + "a=1; domain=.1.2.3.4", + "b=2; domain=.3.4"}, + "", + []query{{"http://1.2.3.4/foo", ""}}, + }, + { + "TestIpAddress #3.", + "http://1.2.3.4/foo", + []string{"a=1; domain=1.2.3.4"}, + "", + []query{{"http://1.2.3.4/foo", ""}}, + }, + { + "TestNonDottedAndTLD #2.", + "http://com./index.html", + []string{"a=1"}, + "a=1", + []query{ + {"http://com./index.html", "a=1"}, + {"http://no-cookies.com./index.html", ""}, + }, + }, + { + "TestNonDottedAndTLD #3.", + "http://a.b", + []string{ + "a=1; domain=.b", + "b=2; domain=b"}, + "", + []query{{"http://bar.foo", ""}}, + }, + { + "TestNonDottedAndTLD #4.", + "http://google.com", + []string{ + "a=1; domain=.com", + "b=2; domain=com"}, + "", + []query{{"http://google.com", ""}}, + }, + { + "TestNonDottedAndTLD #5.", + "http://google.co.uk", + []string{ + "a=1; domain=.co.uk", + "b=2; domain=.uk"}, + "", + []query{ + {"http://google.co.uk", ""}, + {"http://else.co.com", ""}, + {"http://else.uk", ""}, + }, + }, + { + "TestHostEndsWithDot.", + "http://www.google.com", + []string{ + "a=1", + "b=2; domain=.www.google.com."}, + "a=1", + []query{{"http://www.google.com", "a=1"}}, + }, + { + "PathTest", + "http://www.google.izzle", + []string{"a=1; path=/wee"}, + "a=1", + []query{ + {"http://www.google.izzle/wee", "a=1"}, + {"http://www.google.izzle/wee/", "a=1"}, + {"http://www.google.izzle/wee/war", "a=1"}, + {"http://www.google.izzle/wee/war/more/more", "a=1"}, + {"http://www.google.izzle/weehee", ""}, + {"http://www.google.izzle/", ""}, + }, + }, +} + +func TestChromiumBasics(t *testing.T) { + for _, test := range chromiumBasicsTests { + jar := newTestJar() + test.run(t, jar) + } +} + +// chromiumDomainTests contains jarTests which must be executed all on the +// same Jar. +var chromiumDomainTests = [...]jarTest{ + { + "Fill #1.", + "http://www.google.izzle", + []string{"A=B"}, + "A=B", + []query{{"http://www.google.izzle", "A=B"}}, + }, + { + "Fill #2.", + "http://www.google.izzle", + []string{"C=D; domain=.google.izzle"}, + "A=B C=D", + []query{{"http://www.google.izzle", "A=B C=D"}}, + }, + { + "Verify A is a host cookie and not accessible from subdomain.", + "http://unused.nil", + []string{}, + "A=B C=D", + []query{{"http://foo.www.google.izzle", "C=D"}}, + }, + { + "Verify domain cookies are found on proper domain.", + "http://www.google.izzle", + []string{"E=F; domain=.www.google.izzle"}, + "A=B C=D E=F", + []query{{"http://www.google.izzle", "A=B C=D E=F"}}, + }, + { + "Leading dots in domain attributes are optional.", + "http://www.google.izzle", + []string{"G=H; domain=www.google.izzle"}, + "A=B C=D E=F G=H", + []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}}, + }, + { + "Verify domain enforcement works #1.", + "http://www.google.izzle", + []string{"K=L; domain=.bar.www.google.izzle"}, + "A=B C=D E=F G=H", + []query{{"http://bar.www.google.izzle", "C=D E=F G=H"}}, + }, + { + "Verify domain enforcement works #2.", + "http://unused.nil", + []string{}, + "A=B C=D E=F G=H", + []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}}, + }, +} + +func TestChromiumDomain(t *testing.T) { + jar := newTestJar() + for _, test := range chromiumDomainTests { + test.run(t, jar) + } + +} + +// chromiumDeletionTests must be performed all on the same Jar. +var chromiumDeletionTests = [...]jarTest{ + { + "Create session cookie a1.", + "http://www.google.com", + []string{"a=1"}, + "a=1", + []query{{"http://www.google.com", "a=1"}}, + }, + { + "Delete sc a1 via MaxAge.", + "http://www.google.com", + []string{"a=1; max-age=-1"}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create session cookie b2.", + "http://www.google.com", + []string{"b=2"}, + "b=2", + []query{{"http://www.google.com", "b=2"}}, + }, + { + "Delete sc b2 via Expires.", + "http://www.google.com", + []string{"b=2; " + expiresIn(-10)}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create persistent cookie c3.", + "http://www.google.com", + []string{"c=3; max-age=3600"}, + "c=3", + []query{{"http://www.google.com", "c=3"}}, + }, + { + "Delete pc c3 via MaxAge.", + "http://www.google.com", + []string{"c=3; max-age=-1"}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create persistent cookie d4.", + "http://www.google.com", + []string{"d=4; max-age=3600"}, + "d=4", + []query{{"http://www.google.com", "d=4"}}, + }, + { + "Delete pc d4 via Expires.", + "http://www.google.com", + []string{"d=4; " + expiresIn(-10)}, + "", + []query{{"http://www.google.com", ""}}, + }, +} + +func TestChromiumDeletion(t *testing.T) { + jar := newTestJar() + for _, test := range chromiumDeletionTests { + test.run(t, jar) + } +} + +// domainHandlingTests tests and documents the rules for domain handling. +// Each test must be performed on an empty new Jar. +var domainHandlingTests = [...]jarTest{ + { + "Host cookie", + "http://www.host.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", ""}, + {"http://bar.host.test", ""}, + {"http://foo.www.host.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie #1", + "http://www.host.test", + []string{"a=1; domain=host.test"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", "a=1"}, + {"http://bar.host.test", "a=1"}, + {"http://foo.www.host.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie #2", + "http://www.host.test", + []string{"a=1; domain=.host.test"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", "a=1"}, + {"http://bar.host.test", "a=1"}, + {"http://foo.www.host.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on IDNA domain #1", + "http://www.bücher.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", ""}, + {"http://xn--bcher-kva.test", ""}, + {"http://bar.bücher.test", ""}, + {"http://bar.xn--bcher-kva.test", ""}, + {"http://foo.www.bücher.test", ""}, + {"http://foo.www.xn--bcher-kva.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on IDNA domain #2", + "http://www.xn--bcher-kva.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", ""}, + {"http://xn--bcher-kva.test", ""}, + {"http://bar.bücher.test", ""}, + {"http://bar.xn--bcher-kva.test", ""}, + {"http://foo.www.bücher.test", ""}, + {"http://foo.www.xn--bcher-kva.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie on IDNA domain #1", + "http://www.bücher.test", + []string{"a=1; domain=xn--bcher-kva.test"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", "a=1"}, + {"http://xn--bcher-kva.test", "a=1"}, + {"http://bar.bücher.test", "a=1"}, + {"http://bar.xn--bcher-kva.test", "a=1"}, + {"http://foo.www.bücher.test", "a=1"}, + {"http://foo.www.xn--bcher-kva.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie on IDNA domain #2", + "http://www.xn--bcher-kva.test", + []string{"a=1; domain=xn--bcher-kva.test"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", "a=1"}, + {"http://xn--bcher-kva.test", "a=1"}, + {"http://bar.bücher.test", "a=1"}, + {"http://bar.xn--bcher-kva.test", "a=1"}, + {"http://foo.www.bücher.test", "a=1"}, + {"http://foo.www.xn--bcher-kva.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on TLD.", + "http://com", + []string{"a=1"}, + "a=1", + []query{ + {"http://com", "a=1"}, + {"http://any.com", ""}, + {"http://any.test", ""}, + }, + }, + { + "Domain cookie on TLD becomes a host cookie.", + "http://com", + []string{"a=1; domain=com"}, + "a=1", + []query{ + {"http://com", "a=1"}, + {"http://any.com", ""}, + {"http://any.test", ""}, + }, + }, + { + "Host cookie on public suffix.", + "http://co.uk", + []string{"a=1"}, + "a=1", + []query{ + {"http://co.uk", "a=1"}, + {"http://uk", ""}, + {"http://some.co.uk", ""}, + {"http://foo.some.co.uk", ""}, + {"http://any.uk", ""}, + }, + }, + { + "Domain cookie on public suffix is ignored.", + "http://some.co.uk", + []string{"a=1; domain=co.uk"}, + "", + []query{ + {"http://co.uk", ""}, + {"http://uk", ""}, + {"http://some.co.uk", ""}, + {"http://foo.some.co.uk", ""}, + {"http://any.uk", ""}, + }, + }, +} + +func TestDomainHandling(t *testing.T) { + for _, test := range domainHandlingTests { + jar := newTestJar() + test.run(t, jar) + } +} diff --git a/libgo/go/net/http/cookiejar/punycode.go b/libgo/go/net/http/cookiejar/punycode.go new file mode 100644 index 00000000000..ea7ceb5ef3f --- /dev/null +++ b/libgo/go/net/http/cookiejar/punycode.go @@ -0,0 +1,159 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +// This file implements the Punycode algorithm from RFC 3492. + +import ( + "fmt" + "strings" + "unicode/utf8" +) + +// These parameter values are specified in section 5. +// +// All computation is done with int32s, so that overflow behavior is identical +// regardless of whether int is 32-bit or 64-bit. +const ( + base int32 = 36 + damp int32 = 700 + initialBias int32 = 72 + initialN int32 = 128 + skew int32 = 38 + tmax int32 = 26 + tmin int32 = 1 +) + +// encode encodes a string as specified in section 6.3 and prepends prefix to +// the result. +// +// The "while h < length(input)" line in the specification becomes "for +// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes. +func encode(prefix, s string) (string, error) { + output := make([]byte, len(prefix), len(prefix)+1+2*len(s)) + copy(output, prefix) + delta, n, bias := int32(0), initialN, initialBias + b, remaining := int32(0), int32(0) + for _, r := range s { + if r < 0x80 { + b++ + output = append(output, byte(r)) + } else { + remaining++ + } + } + h := b + if b > 0 { + output = append(output, '-') + } + for remaining != 0 { + m := int32(0x7fffffff) + for _, r := range s { + if m > r && r >= n { + m = r + } + } + delta += (m - n) * (h + 1) + if delta < 0 { + return "", fmt.Errorf("cookiejar: invalid label %q", s) + } + n = m + for _, r := range s { + if r < n { + delta++ + if delta < 0 { + return "", fmt.Errorf("cookiejar: invalid label %q", s) + } + continue + } + if r > n { + continue + } + q := delta + for k := base; ; k += base { + t := k - bias + if t < tmin { + t = tmin + } else if t > tmax { + t = tmax + } + if q < t { + break + } + output = append(output, encodeDigit(t+(q-t)%(base-t))) + q = (q - t) / (base - t) + } + output = append(output, encodeDigit(q)) + bias = adapt(delta, h+1, h == b) + delta = 0 + h++ + remaining-- + } + delta++ + n++ + } + return string(output), nil +} + +func encodeDigit(digit int32) byte { + switch { + case 0 <= digit && digit < 26: + return byte(digit + 'a') + case 26 <= digit && digit < 36: + return byte(digit + ('0' - 26)) + } + panic("cookiejar: internal error in punycode encoding") +} + +// adapt is the bias adaptation function specified in section 6.1. +func adapt(delta, numPoints int32, firstTime bool) int32 { + if firstTime { + delta /= damp + } else { + delta /= 2 + } + delta += delta / numPoints + k := int32(0) + for delta > ((base-tmin)*tmax)/2 { + delta /= base - tmin + k += base + } + return k + (base-tmin+1)*delta/(delta+skew) +} + +// Strictly speaking, the remaining code below deals with IDNA (RFC 5890 and +// friends) and not Punycode (RFC 3492) per se. + +// acePrefix is the ASCII Compatible Encoding prefix. +const acePrefix = "xn--" + +// toASCII converts a domain or domain label to its ASCII form. For example, +// toASCII("bücher.example.com") is "xn--bcher-kva.example.com", and +// toASCII("golang") is "golang". +func toASCII(s string) (string, error) { + if ascii(s) { + return s, nil + } + labels := strings.Split(s, ".") + for i, label := range labels { + if !ascii(label) { + a, err := encode(acePrefix, label) + if err != nil { + return "", err + } + labels[i] = a + } + } + return strings.Join(labels, "."), nil +} + +func ascii(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + return false + } + } + return true +} diff --git a/libgo/go/net/http/cookiejar/punycode_test.go b/libgo/go/net/http/cookiejar/punycode_test.go new file mode 100644 index 00000000000..0301de14e46 --- /dev/null +++ b/libgo/go/net/http/cookiejar/punycode_test.go @@ -0,0 +1,161 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +import ( + "testing" +) + +var punycodeTestCases = [...]struct { + s, encoded string +}{ + {"", ""}, + {"-", "--"}, + {"-a", "-a-"}, + {"-a-", "-a--"}, + {"a", "a-"}, + {"a-", "a--"}, + {"a-b", "a-b-"}, + {"books", "books-"}, + {"bücher", "bcher-kva"}, + {"Hello世界", "Hello-ck1hg65u"}, + {"ü", "tda"}, + {"üý", "tdac"}, + + // The test cases below come from RFC 3492 section 7.1 with Errata 3026. + { + // (A) Arabic (Egyptian). + "\u0644\u064A\u0647\u0645\u0627\u0628\u062A\u0643\u0644" + + "\u0645\u0648\u0634\u0639\u0631\u0628\u064A\u061F", + "egbpdaj6bu4bxfgehfvwxn", + }, + { + // (B) Chinese (simplified). + "\u4ED6\u4EEC\u4E3A\u4EC0\u4E48\u4E0D\u8BF4\u4E2D\u6587", + "ihqwcrb4cv8a8dqg056pqjye", + }, + { + // (C) Chinese (traditional). + "\u4ED6\u5011\u7232\u4EC0\u9EBD\u4E0D\u8AAA\u4E2D\u6587", + "ihqwctvzc91f659drss3x8bo0yb", + }, + { + // (D) Czech. + "\u0050\u0072\u006F\u010D\u0070\u0072\u006F\u0073\u0074" + + "\u011B\u006E\u0065\u006D\u006C\u0075\u0076\u00ED\u010D" + + "\u0065\u0073\u006B\u0079", + "Proprostnemluvesky-uyb24dma41a", + }, + { + // (E) Hebrew. + "\u05DC\u05DE\u05D4\u05D4\u05DD\u05E4\u05E9\u05D5\u05D8" + + "\u05DC\u05D0\u05DE\u05D3\u05D1\u05E8\u05D9\u05DD\u05E2" + + "\u05D1\u05E8\u05D9\u05EA", + "4dbcagdahymbxekheh6e0a7fei0b", + }, + { + // (F) Hindi (Devanagari). + "\u092F\u0939\u0932\u094B\u0917\u0939\u093F\u0928\u094D" + + "\u0926\u0940\u0915\u094D\u092F\u094B\u0902\u0928\u0939" + + "\u0940\u0902\u092C\u094B\u0932\u0938\u0915\u0924\u0947" + + "\u0939\u0948\u0902", + "i1baa7eci9glrd9b2ae1bj0hfcgg6iyaf8o0a1dig0cd", + }, + { + // (G) Japanese (kanji and hiragana). + "\u306A\u305C\u307F\u3093\u306A\u65E5\u672C\u8A9E\u3092" + + "\u8A71\u3057\u3066\u304F\u308C\u306A\u3044\u306E\u304B", + "n8jok5ay5dzabd5bym9f0cm5685rrjetr6pdxa", + }, + { + // (H) Korean (Hangul syllables). + "\uC138\uACC4\uC758\uBAA8\uB4E0\uC0AC\uB78C\uB4E4\uC774" + + "\uD55C\uAD6D\uC5B4\uB97C\uC774\uD574\uD55C\uB2E4\uBA74" + + "\uC5BC\uB9C8\uB098\uC88B\uC744\uAE4C", + "989aomsvi5e83db1d2a355cv1e0vak1dwrv93d5xbh15a0dt30a5j" + + "psd879ccm6fea98c", + }, + { + // (I) Russian (Cyrillic). + "\u043F\u043E\u0447\u0435\u043C\u0443\u0436\u0435\u043E" + + "\u043D\u0438\u043D\u0435\u0433\u043E\u0432\u043E\u0440" + + "\u044F\u0442\u043F\u043E\u0440\u0443\u0441\u0441\u043A" + + "\u0438", + "b1abfaaepdrnnbgefbadotcwatmq2g4l", + }, + { + // (J) Spanish. + "\u0050\u006F\u0072\u0071\u0075\u00E9\u006E\u006F\u0070" + + "\u0075\u0065\u0064\u0065\u006E\u0073\u0069\u006D\u0070" + + "\u006C\u0065\u006D\u0065\u006E\u0074\u0065\u0068\u0061" + + "\u0062\u006C\u0061\u0072\u0065\u006E\u0045\u0073\u0070" + + "\u0061\u00F1\u006F\u006C", + "PorqunopuedensimplementehablarenEspaol-fmd56a", + }, + { + // (K) Vietnamese. + "\u0054\u1EA1\u0069\u0073\u0061\u006F\u0068\u1ECD\u006B" + + "\u0068\u00F4\u006E\u0067\u0074\u0068\u1EC3\u0063\u0068" + + "\u1EC9\u006E\u00F3\u0069\u0074\u0069\u1EBF\u006E\u0067" + + "\u0056\u0069\u1EC7\u0074", + "TisaohkhngthchnitingVit-kjcr8268qyxafd2f1b9g", + }, + { + // (L) 3<nen>B<gumi><kinpachi><sensei>. + "\u0033\u5E74\u0042\u7D44\u91D1\u516B\u5148\u751F", + "3B-ww4c5e180e575a65lsy2b", + }, + { + // (M) <amuro><namie>-with-SUPER-MONKEYS. + "\u5B89\u5BA4\u5948\u7F8E\u6075\u002D\u0077\u0069\u0074" + + "\u0068\u002D\u0053\u0055\u0050\u0045\u0052\u002D\u004D" + + "\u004F\u004E\u004B\u0045\u0059\u0053", + "-with-SUPER-MONKEYS-pc58ag80a8qai00g7n9n", + }, + { + // (N) Hello-Another-Way-<sorezore><no><basho>. + "\u0048\u0065\u006C\u006C\u006F\u002D\u0041\u006E\u006F" + + "\u0074\u0068\u0065\u0072\u002D\u0057\u0061\u0079\u002D" + + "\u305D\u308C\u305E\u308C\u306E\u5834\u6240", + "Hello-Another-Way--fc4qua05auwb3674vfr0b", + }, + { + // (O) <hitotsu><yane><no><shita>2. + "\u3072\u3068\u3064\u5C4B\u6839\u306E\u4E0B\u0032", + "2-u9tlzr9756bt3uc0v", + }, + { + // (P) Maji<de>Koi<suru>5<byou><mae> + "\u004D\u0061\u006A\u0069\u3067\u004B\u006F\u0069\u3059" + + "\u308B\u0035\u79D2\u524D", + "MajiKoi5-783gue6qz075azm5e", + }, + { + // (Q) <pafii>de<runba> + "\u30D1\u30D5\u30A3\u30FC\u0064\u0065\u30EB\u30F3\u30D0", + "de-jg4avhby1noc0d", + }, + { + // (R) <sono><supiido><de> + "\u305D\u306E\u30B9\u30D4\u30FC\u30C9\u3067", + "d9juau41awczczp", + }, + { + // (S) -> $1.00 <- + "\u002D\u003E\u0020\u0024\u0031\u002E\u0030\u0030\u0020" + + "\u003C\u002D", + "-> $1.00 <--", + }, +} + +func TestPunycode(t *testing.T) { + for _, tc := range punycodeTestCases { + if got, err := encode("", tc.s); err != nil { + t.Errorf(`encode("", %q): %v`, tc.s, err) + } else if got != tc.encoded { + t.Errorf(`encode("", %q): got %q, want %q`, tc.s, got, tc.encoded) + } + } +} diff --git a/libgo/go/net/http/example_test.go b/libgo/go/net/http/example_test.go index 22073eaf7aa..bc60df7f2b5 100644 --- a/libgo/go/net/http/example_test.go +++ b/libgo/go/net/http/example_test.go @@ -51,6 +51,20 @@ func ExampleGet() { } func ExampleFileServer() { - // we use StripPrefix so that /tmpfiles/somefile will access /tmp/somefile + // Simple static webserver: + log.Fatal(http.ListenAndServe(":8080", http.FileServer(http.Dir("/usr/share/doc")))) +} + +func ExampleFileServer_stripPrefix() { + // To serve a directory on disk (/tmp) under an alternate URL + // path (/tmpfiles/), use StripPrefix to modify the request + // URL's path before the FileServer sees it: + http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp")))) +} + +func ExampleStripPrefix() { + // To serve a directory on disk (/tmp) under an alternate URL + // path (/tmpfiles/), use StripPrefix to modify the request + // URL's path before the FileServer sees it: http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp")))) } diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index a7a07852d18..3fc24532676 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -16,10 +16,16 @@ func NewLoggingConn(baseName string, c net.Conn) net.Conn { return newLoggingConn(baseName, c) } +func (t *Transport) NumPendingRequestsForTesting() int { + t.reqMu.Lock() + defer t.reqMu.Unlock() + return len(t.reqConn) +} + func (t *Transport) IdleConnKeysForTesting() (keys []string) { keys = make([]string, 0) - t.idleLk.Lock() - defer t.idleLk.Unlock() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return } @@ -30,8 +36,8 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { } func (t *Transport) IdleConnCountForTesting(cacheKey string) int { - t.idleLk.Lock() - defer t.idleLk.Unlock() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return 0 } @@ -48,3 +54,5 @@ func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { } return &timeoutHandler{handler, f, ""} } + +var DefaultUserAgent = defaultUserAgent diff --git a/libgo/go/net/http/fcgi/child.go b/libgo/go/net/http/fcgi/child.go index c8b9a33c87b..60b794e0775 100644 --- a/libgo/go/net/http/fcgi/child.go +++ b/libgo/go/net/http/fcgi/child.go @@ -10,10 +10,12 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net" "net/http" "net/http/cgi" "os" + "strings" "time" ) @@ -152,20 +154,23 @@ func (c *child) serve() { var errCloseConn = errors.New("fcgi: connection should be closed") +var emptyBody = ioutil.NopCloser(strings.NewReader("")) + func (c *child) handleRecord(rec *record) error { req, ok := c.requests[rec.h.Id] if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { // The spec says to ignore unknown request IDs. return nil } - if ok && rec.h.Type == typeBeginRequest { - // The server is trying to begin a request with the same ID - // as an in-progress request. This is an error. - return errors.New("fcgi: received ID that is already in-flight") - } switch rec.h.Type { case typeBeginRequest: + if req != nil { + // The server is trying to begin a request with the same ID + // as an in-progress request. This is an error. + return errors.New("fcgi: received ID that is already in-flight") + } + var br beginRequest if err := br.read(rec.content()); err != nil { return err @@ -175,6 +180,7 @@ func (c *child) handleRecord(rec *record) error { return nil } c.requests[rec.h.Id] = newRequest(rec.h.Id, br.flags) + return nil case typeParams: // NOTE(eds): Technically a key-value pair can straddle the boundary // between two packets. We buffer until we've received all parameters. @@ -183,6 +189,7 @@ func (c *child) handleRecord(rec *record) error { return nil } req.parseParams() + return nil case typeStdin: content := rec.content() if req.pw == nil { @@ -191,6 +198,8 @@ func (c *child) handleRecord(rec *record) error { // body could be an io.LimitReader, but it shouldn't matter // as long as both sides are behaving. body, req.pw = io.Pipe() + } else { + body = emptyBody } go c.serveRequest(req, body) } @@ -201,24 +210,29 @@ func (c *child) handleRecord(rec *record) error { } else if req.pw != nil { req.pw.Close() } + return nil case typeGetValues: values := map[string]string{"FCGI_MPXS_CONNS": "1"} c.conn.writePairs(typeGetValuesResult, 0, values) + return nil case typeData: // If the filter role is implemented, read the data stream here. + return nil case typeAbortRequest: + println("abort") delete(c.requests, rec.h.Id) c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) if !req.keepConn { // connection will close upon return return errCloseConn } + return nil default: b := make([]byte, 8) b[0] = byte(rec.h.Type) c.conn.writeRecord(typeUnknownType, 0, b) + return nil } - return nil } func (c *child) serveRequest(req *request, body io.ReadCloser) { @@ -232,11 +246,19 @@ func (c *child) serveRequest(req *request, body io.ReadCloser) { httpReq.Body = body c.handler.ServeHTTP(r, httpReq) } - if body != nil { - body.Close() - } r.Close() c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete) + + // Consume the entire body, so the host isn't still writing to + // us when we close the socket below in the !keepConn case, + // otherwise we'd send a RST. (golang.org/issue/4183) + // TODO(bradfitz): also bound this copy in time. Or send + // some sort of abort request to the host, so the host + // can properly cut off the client sending all the data. + // For now just bound it a little and + io.CopyN(ioutil.Discard, body, 100<<20) + body.Close() + if !req.keepConn { c.conn.Close() } @@ -267,5 +289,4 @@ func Serve(l net.Listener, handler http.Handler) error { c := newChild(rw, handler) go c.serve() } - panic("unreachable") } diff --git a/libgo/go/net/http/filetransport_test.go b/libgo/go/net/http/filetransport_test.go index 039926b5382..6f1a537e2ed 100644 --- a/libgo/go/net/http/filetransport_test.go +++ b/libgo/go/net/http/filetransport_test.go @@ -2,11 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http_test +package http import ( "io/ioutil" - "net/http" "os" "path/filepath" "testing" @@ -32,9 +31,9 @@ func TestFileTransport(t *testing.T) { defer os.Remove(dname) defer os.Remove(fname) - tr := &http.Transport{} - tr.RegisterProtocol("file", http.NewFileTransport(http.Dir(dname))) - c := &http.Client{Transport: tr} + tr := &Transport{} + tr.RegisterProtocol("file", NewFileTransport(Dir(dname))) + c := &Client{Transport: tr} fooURLs := []string{"file:///foo.txt", "file://../foo.txt"} for _, urlstr := range fooURLs { @@ -62,4 +61,5 @@ func TestFileTransport(t *testing.T) { if res.StatusCode != 404 { t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode) } + res.Body.Close() } diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index d42014c265b..2c3737653b3 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -54,6 +54,7 @@ var ServeFileRangeTests = []struct { } func TestServeFile(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") })) @@ -169,6 +170,7 @@ var fsRedirectTestData = []struct { } func TestFSRedirect(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir(".")))) defer ts.Close() @@ -193,6 +195,7 @@ func (fs *testFileSystem) Open(name string) (File, error) { } func TestFileServerCleans(t *testing.T) { + defer afterTest(t) ch := make(chan string, 1) fs := FileServer(&testFileSystem{func(name string) (File, error) { ch <- name @@ -224,6 +227,7 @@ func mustRemoveAll(dir string) { } func TestFileServerImplicitLeadingSlash(t *testing.T) { + defer afterTest(t) tempDir, err := ioutil.TempDir("", "") if err != nil { t.Fatalf("TempDir: %v", err) @@ -302,6 +306,7 @@ func TestEmptyDirOpenCWD(t *testing.T) { } func TestServeFileContentType(t *testing.T) { + defer afterTest(t) const ctype = "icecream/chocolate" ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.FormValue("override") == "1" { @@ -318,12 +323,14 @@ func TestServeFileContentType(t *testing.T) { if h := resp.Header.Get("Content-Type"); h != want { t.Errorf("Content-Type mismatch: got %q, want %q", h, want) } + resp.Body.Close() } get("0", "text/plain; charset=utf-8") get("1", ctype) } func TestServeFileMimeType(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/style.css") })) @@ -332,6 +339,7 @@ func TestServeFileMimeType(t *testing.T) { if err != nil { t.Fatal(err) } + resp.Body.Close() want := "text/css; charset=utf-8" if h := resp.Header.Get("Content-Type"); h != want { t.Errorf("Content-Type mismatch: got %q, want %q", h, want) @@ -339,6 +347,7 @@ func TestServeFileMimeType(t *testing.T) { } func TestServeFileFromCWD(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "fs_test.go") })) @@ -354,6 +363,7 @@ func TestServeFileFromCWD(t *testing.T) { } func TestServeFileWithContentEncoding(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "foo") ServeFile(w, r, "testdata/file") @@ -363,12 +373,14 @@ func TestServeFileWithContentEncoding(t *testing.T) { if err != nil { t.Fatal(err) } + resp.Body.Close() if g, e := resp.ContentLength, int64(-1); g != e { t.Errorf("Content-Length mismatch: got %d, want %d", g, e) } } func TestServeIndexHtml(t *testing.T) { + defer afterTest(t) const want = "index.html says hello\n" ts := httptest.NewServer(FileServer(Dir("."))) defer ts.Close() @@ -390,6 +402,7 @@ func TestServeIndexHtml(t *testing.T) { } func TestFileServerZeroByte(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(FileServer(Dir("."))) defer ts.Close() @@ -458,6 +471,7 @@ func (fs fakeFS) Open(name string) (File, error) { } func TestDirectoryIfNotModified(t *testing.T) { + defer afterTest(t) const indexContents = "I am a fake index.html file" fileMod := time.Unix(1000000000, 0).UTC() fileModStr := fileMod.Format(TimeFormat) @@ -531,6 +545,7 @@ func mustStat(t *testing.T, fileName string) os.FileInfo { } func TestServeContent(t *testing.T) { + defer afterTest(t) type serveParam struct { name string modtime time.Time @@ -663,6 +678,7 @@ func TestServeContent(t *testing.T) { // verifies that sendfile is being used on Linux func TestLinuxSendfile(t *testing.T) { + defer afterTest(t) if runtime.GOOS != "linux" { t.Skip("skipping; linux-only test") } @@ -681,7 +697,7 @@ func TestLinuxSendfile(t *testing.T) { defer ln.Close() var buf bytes.Buffer - child := exec.Command("strace", "-f", "-e!sigaltstack", os.Args[0], "-test.run=TestLinuxSendfileChild") + child := exec.Command("strace", "-f", "-q", "-e", "trace=sendfile,sendfile64", os.Args[0], "-test.run=TestLinuxSendfileChild") child.ExtraFiles = append(child.ExtraFiles, lnf) child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...) child.Stdout = &buf diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index f479b7b4eb9..6374237fba1 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -103,21 +103,41 @@ type keyValues struct { values []string } -type byKey []keyValues - -func (s byKey) Len() int { return len(s) } -func (s byKey) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s byKey) Less(i, j int) bool { return s[i].key < s[j].key } - -func (h Header) sortedKeyValues(exclude map[string]bool) []keyValues { - kvs := make([]keyValues, 0, len(h)) +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +// TODO: convert this to a sync.Cache (issue 4720) +var headerSorterCache = make(chan *headerSorter, 8) + +// sortedKeyValues returns h's keys sorted in the returned kvs +// slice. The headerSorter used to sort is also returned, for possible +// return to headerSorterCache. +func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { + select { + case hs = <-headerSorterCache: + default: + hs = new(headerSorter) + } + if cap(hs.kvs) < len(h) { + hs.kvs = make([]keyValues, 0, len(h)) + } + kvs = hs.kvs[:0] for k, vv := range h { if !exclude[k] { kvs = append(kvs, keyValues{k, vv}) } } - sort.Sort(byKey(kvs)) - return kvs + hs.kvs = kvs + sort.Sort(hs) + return kvs, hs } // WriteSubset writes a header in wire format. @@ -127,7 +147,8 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { if !ok { ws = stringWriter{w} } - for _, kv := range h.sortedKeyValues(exclude) { + kvs, sorter := h.sortedKeyValues(exclude) + for _, kv := range kvs { for _, v := range kv.values { v = headerNewlineToSpace.Replace(v) v = textproto.TrimString(v) @@ -138,6 +159,10 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { } } } + select { + case headerSorterCache <- sorter: + default: + } return nil } diff --git a/libgo/go/net/http/header_test.go b/libgo/go/net/http/header_test.go index 01bb4dce000..584f1005440 100644 --- a/libgo/go/net/http/header_test.go +++ b/libgo/go/net/http/header_test.go @@ -175,38 +175,33 @@ func TestHasToken(t *testing.T) { } } -func BenchmarkHeaderWriteSubset(b *testing.B) { - doHeaderWriteSubset(b.N, b) +var testHeader = Header{ + "Content-Length": {"123"}, + "Content-Type": {"text/plain"}, + "Date": {"some date at some time Z"}, + "Server": {DefaultUserAgent}, } -func TestHeaderWriteSubsetMallocs(t *testing.T) { - doHeaderWriteSubset(100, t) -} +var buf bytes.Buffer -type errorfer interface { - Errorf(string, ...interface{}) +func BenchmarkHeaderWriteSubset(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf.Reset() + testHeader.WriteSubset(&buf, nil) + } } -func doHeaderWriteSubset(n int, t errorfer) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) - h := Header(map[string][]string{ - "Content-Length": {"123"}, - "Content-Type": {"text/plain"}, - "Date": {"some date at some time Z"}, - "Server": {"Go http package"}, - }) - var buf bytes.Buffer - var m0 runtime.MemStats - runtime.ReadMemStats(&m0) - for i := 0; i < n; i++ { - buf.Reset() - h.WriteSubset(&buf, nil) +func TestHeaderWriteSubsetMallocs(t *testing.T) { + t.Skip("Skipping alloc count test on gccgo") + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") } - var m1 runtime.MemStats - runtime.ReadMemStats(&m1) - if mallocs := m1.Mallocs - m0.Mallocs; n >= 100 && mallocs >= uint64(n) { - // TODO(bradfitz,rsc): once we can sort without allocating, - // make this an error. See http://golang.org/issue/3761 - // t.Errorf("did %d mallocs (>= %d iterations); should have avoided mallocs", mallocs, n) + n := testing.AllocsPerRun(100, func() { + buf.Reset() + testHeader.WriteSubset(&buf, nil) + }) + if n > 0 { + t.Errorf("mallocs = %d; want 0", n) } } diff --git a/libgo/go/net/http/httptest/server.go b/libgo/go/net/http/httptest/server.go index fc52c9a2efc..7f265552f52 100644 --- a/libgo/go/net/http/httptest/server.go +++ b/libgo/go/net/http/httptest/server.go @@ -21,7 +21,11 @@ import ( type Server struct { URL string // base URL of form http://ipaddr:port with no trailing slash Listener net.Listener - TLS *tls.Config // nil if not using TLS + + // TLS is the optional TLS configuration, populated with a new config + // after TLS is started. If set on an unstarted server before StartTLS + // is called, existing fields are copied into the new config. + TLS *tls.Config // Config may be changed after calling NewUnstartedServer and // before Start or StartTLS. @@ -119,9 +123,16 @@ func (s *Server) StartTLS() { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) } - s.TLS = &tls.Config{ - NextProtos: []string{"http/1.1"}, - Certificates: []tls.Certificate{cert}, + existingConfig := s.TLS + s.TLS = new(tls.Config) + if existingConfig != nil { + *s.TLS = *existingConfig + } + if s.TLS.NextProtos == nil { + s.TLS.NextProtos = []string{"http/1.1"} + } + if len(s.TLS.Certificates) == 0 { + s.TLS.Certificates = []tls.Certificate{cert} } tlsListener := tls.NewListener(s.Listener, s.TLS) @@ -189,28 +200,29 @@ func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.h.ServeHTTP(w, r) } -// localhostCert is a PEM-encoded TLS cert with SAN DNS names +// localhostCert is a PEM-encoded TLS cert with SAN IPs // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end // of ASN.1 time). +// generated from src/pkg/crypto/tls: +// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h var localhostCert = []byte(`-----BEGIN CERTIFICATE----- -MIIBTTCB+qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX -DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7 -qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL -8i1UQF6AzwIDAQABo2MwYTAOBgNVHQ8BAf8EBAMCACQwEgYDVR0TAQH/BAgwBgEB -/wIBATANBgNVHQ4EBgQEAQIDBDAPBgNVHSMECDAGgAQBAgMEMBsGA1UdEQQUMBKC -CTEyNy4wLjAuMYIFWzo6MV0wCwYJKoZIhvcNAQEFA0EAj1Jsn/h2KHy7dgqutZNB -nCGlNN+8vw263Bax9MklR85Ti6a0VWSvp/fDQZUADvmFTDkcXeA24pqmdUxeQDWw -Pg== +MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD +bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj +bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa +IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA +AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud +EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA +AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk +Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA== -----END CERTIFICATE-----`) // localhostKey is the private key for localhostCert. var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- -MIIBPQIBAAJBALLgOZgBTI+kO6qAc3LysyKuJM7k+XqUqdgJHEH8gR5uytd1rO7v -tG+VW/YKk3+XAIiCnK7a11apC/ItVEBegM8CAwEAAQJBAI5sxq7naeR9ahyqRkJi -SIv2iMxLuPEHaezf5CYOPWjSjBPyVhyRevkhtqEjF/WkgL7C2nWpYHsUcBDBQVF0 -3KECIQDtEGB2ulnkZAahl3WuJziXGLB+p8Wgx7wzSM6bHu1c6QIhAMEp++CaS+SJ -/TrU0zwY/fW4SvQeb49BPZUF3oqR8Xz3AiEA1rAJHBzBgdOQKdE3ksMUPcnvNJSN -poCcELmz2clVXtkCIQCLytuLV38XHToTipR4yMl6O+6arzAjZ56uq7m7ZRV0TwIh -AM65XAOw8Dsg9Kq78aYXiOEDc5DL0sbFUu/SlmRcCg93 ------END RSA PRIVATE KEY----- -`) +MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0 +0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV +NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d +AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW +MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD +EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA +1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE= +-----END RSA PRIVATE KEY-----`) diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go index 5afe9ba74e3..3e87c27bc36 100644 --- a/libgo/go/net/http/httputil/dump_test.go +++ b/libgo/go/net/http/httputil/dump_test.go @@ -68,7 +68,7 @@ var dumpTests = []dumpTest{ WantDumpOut: "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Accept-Encoding: gzip\r\n\r\n", }, @@ -80,7 +80,7 @@ var dumpTests = []dumpTest{ WantDumpOut: "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Accept-Encoding: gzip\r\n\r\n", }, } diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 134c452999d..1990f64dbd8 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -81,6 +81,19 @@ func copyHeader(dst, src http.Header) { } } +// Hop-by-hop headers. These are removed when sent to the backend. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { transport := p.Transport if transport == nil { @@ -96,14 +109,21 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.ProtoMinor = 1 outreq.Close = false - // Remove the connection header to the backend. We want a - // persistent connection, regardless of what the client sent - // to us. This is modifying the same underlying map from req - // (shallow copied above) so we only copy it if necessary. - if outreq.Header.Get("Connection") != "" { - outreq.Header = make(http.Header) - copyHeader(outreq.Header, req.Header) - outreq.Header.Del("Connection") + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. This + // is modifying the same underlying map from req (shallow + // copied above) so we only copy it if necessary. + copiedHeaders := false + for _, h := range hopHeaders { + if outreq.Header.Get(h) != "" { + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, req.Header) + copiedHeaders = true + } + outreq.Header.Del(h) + } } if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { @@ -182,7 +202,6 @@ func (m *maxLatencyWriter) flushLoop() { m.lk.Unlock() } } - panic("unreached") } func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index 8639271626f..1c0444ec486 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -29,6 +29,9 @@ func TestReverseProxy(t *testing.T) { if c := r.Header.Get("Connection"); c != "" { t.Errorf("handler got Connection header value %q", c) } + if c := r.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got Upgrade header value %q", c) + } if g, e := r.Host, "some-name"; g != e { t.Errorf("backend got Host header %q, want %q", g, e) } @@ -49,6 +52,7 @@ func TestReverseProxy(t *testing.T) { getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq.Host = "some-name" getReq.Header.Set("Connection", "close") + getReq.Header.Set("Upgrade", "foo") getReq.Close = true res, err := http.DefaultClient.Do(getReq) if err != nil { diff --git a/libgo/go/net/http/jar.go b/libgo/go/net/http/jar.go index 35eee682f9f..5c3de0dad25 100644 --- a/libgo/go/net/http/jar.go +++ b/libgo/go/net/http/jar.go @@ -12,6 +12,8 @@ import ( // // Implementations of CookieJar must be safe for concurrent use by multiple // goroutines. +// +// The net/http/cookiejar package provides a CookieJar implementation. type CookieJar interface { // SetCookies handles the receipt of the cookies in a reply for the // given URL. It may or may not choose to save the cookies, depending diff --git a/libgo/go/net/http/npn_test.go b/libgo/go/net/http/npn_test.go new file mode 100644 index 00000000000..98b8930d064 --- /dev/null +++ b/libgo/go/net/http/npn_test.go @@ -0,0 +1,118 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "bufio" + "crypto/tls" + "fmt" + "io" + "io/ioutil" + . "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestNextProtoUpgrade(t *testing.T) { + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "path=%s,proto=", r.URL.Path) + if r.TLS != nil { + w.Write([]byte(r.TLS.NegotiatedProtocol)) + } + if r.RemoteAddr == "" { + t.Error("request with no RemoteAddr") + } + if r.Body == nil { + t.Errorf("request with nil Body") + } + })) + ts.TLS = &tls.Config{ + NextProtos: []string{"unhandled-proto", "tls-0.9"}, + } + ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){ + "tls-0.9": handleTLSProtocol09, + } + ts.StartTLS() + defer ts.Close() + + tr := newTLSTransport(t, ts) + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + // Normal request, without NPN. + { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if want := "path=/,proto="; string(body) != want { + t.Errorf("plain request = %q; want %q", body, want) + } + } + + // Request to an advertised but unhandled NPN protocol. + // Server will hang up. + { + tr.CloseIdleConnections() + tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"} + _, err := c.Get(ts.URL) + if err == nil { + t.Errorf("expected error on unhandled-proto request") + } + } + + // Request using the "tls-0.9" protocol, which we register here. + // It is HTTP/0.9 over TLS. + { + tlsConfig := newTLSTransport(t, ts).TLSClientConfig + tlsConfig.NextProtos = []string{"tls-0.9"} + conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) + if err != nil { + t.Fatal(err) + } + conn.Write([]byte("GET /foo\n")) + body, err := ioutil.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if want := "path=/foo,proto=tls-0.9"; string(body) != want { + t.Errorf("plain request = %q; want %q", body, want) + } + } +} + +// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the +// TestNextProtoUpgrade test. +func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) { + br := bufio.NewReader(conn) + line, err := br.ReadString('\n') + if err != nil { + return + } + line = strings.TrimSpace(line) + path := strings.TrimPrefix(line, "GET ") + if path == line { + return + } + req, _ := NewRequest("GET", path, nil) + req.Proto = "HTTP/0.9" + req.ProtoMajor = 0 + req.ProtoMinor = 9 + rw := &http09Writer{conn, make(Header)} + h.ServeHTTP(rw, req) +} + +type http09Writer struct { + io.Writer + h Header +} + +func (w http09Writer) Header() Header { return w.h } +func (w http09Writer) WriteHeader(int) {} // no headers diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go index 0c03e5b2b75..0c7548e3ef3 100644 --- a/libgo/go/net/http/pprof/pprof.go +++ b/libgo/go/net/http/pprof/pprof.go @@ -172,7 +172,7 @@ func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // listing the available profiles. func Index(w http.ResponseWriter, r *http.Request) { if strings.HasPrefix(r.URL.Path, "/debug/pprof/") { - name := r.URL.Path[len("/debug/pprof/"):] + name := strings.TrimPrefix(r.URL.Path, "/debug/pprof/") if name != "" { handler(name).ServeHTTP(w, r) return diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index 217f35b4833..6d4569146fd 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -48,7 +48,7 @@ var ( ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"} ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"} ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"} - ErrMissingBoundary = &ProtocolError{"no multipart boundary param Content-Type"} + ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"} ) type badStringError struct { @@ -283,7 +283,7 @@ func valueOrDefault(value, def string) string { return def } -const defaultUserAgent = "Go http package" +const defaultUserAgent = "Go 1.1 package http" // Write writes an HTTP/1.1 request -- header and body -- in wire format. // This method consults the following fields of the request: @@ -467,10 +467,42 @@ func (r *Request) SetBasicAuth(username, password string) { r.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(s))) } +// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. +func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { + s1 := strings.Index(line, " ") + s2 := strings.Index(line[s1+1:], " ") + if s1 < 0 || s2 < 0 { + return + } + s2 += s1 + 1 + return line[:s1], line[s1+1 : s2], line[s2+1:], true +} + +// TODO(bradfitz): use a sync.Cache when available +var textprotoReaderCache = make(chan *textproto.Reader, 4) + +func newTextprotoReader(br *bufio.Reader) *textproto.Reader { + select { + case r := <-textprotoReaderCache: + r.R = br + return r + default: + return textproto.NewReader(br) + } +} + +func putTextprotoReader(r *textproto.Reader) { + r.R = nil + select { + case textprotoReaderCache <- r: + default: + } +} + // ReadRequest reads and parses a request from b. func ReadRequest(b *bufio.Reader) (req *Request, err error) { - tp := textproto.NewReader(b) + tp := newTextprotoReader(b) req = new(Request) // First line: GET /index.html HTTP/1.0 @@ -479,18 +511,18 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) { return nil, err } defer func() { + putTextprotoReader(tp) if err == io.EOF { err = io.ErrUnexpectedEOF } }() - var f []string - if f = strings.SplitN(s, " ", 3); len(f) < 3 { + var ok bool + req.Method, req.RequestURI, req.Proto, ok = parseRequestLine(s) + if !ok { return nil, &badStringError{"malformed HTTP request", s} } - req.Method, req.RequestURI, req.Proto = f[0], f[1], f[2] rawurl := req.RequestURI - var ok bool if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok { return nil, &badStringError{"malformed HTTP version", req.Proto} } diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index bd757920b78..692485c49d9 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -262,7 +262,39 @@ func TestNewRequestContentLength(t *testing.T) { t.Fatal(err) } if req.ContentLength != tt.want { - t.Errorf("ContentLength(%#T) = %d; want %d", tt.r, req.ContentLength, tt.want) + t.Errorf("ContentLength(%T) = %d; want %d", tt.r, req.ContentLength, tt.want) + } + } +} + +var parseHTTPVersionTests = []struct { + vers string + major, minor int + ok bool +}{ + {"HTTP/0.9", 0, 9, true}, + {"HTTP/1.0", 1, 0, true}, + {"HTTP/1.1", 1, 1, true}, + {"HTTP/3.14", 3, 14, true}, + + {"HTTP", 0, 0, false}, + {"HTTP/one.one", 0, 0, false}, + {"HTTP/1.1/", 0, 0, false}, + {"HTTP/-1,0", 0, 0, false}, + {"HTTP/0,-1", 0, 0, false}, + {"HTTP/", 0, 0, false}, + {"HTTP/1,1", 0, 0, false}, +} + +func TestParseHTTPVersion(t *testing.T) { + for _, tt := range parseHTTPVersionTests { + major, minor, ok := ParseHTTPVersion(tt.vers) + if ok != tt.ok || major != tt.major || minor != tt.minor { + type version struct { + major, minor int + ok bool + } + t.Errorf("failed to parse %q, expected: %#v, got %#v", tt.vers, version{tt.major, tt.minor, tt.ok}, version{major, minor, ok}) } } } @@ -289,7 +321,7 @@ func TestRequestWriteBufferedWriter(t *testing.T) { want := []string{ "GET / HTTP/1.1\r\n", "Host: foo.com\r\n", - "User-Agent: Go http package\r\n", + "User-Agent: " + DefaultUserAgent + "\r\n", "\r\n", } if !reflect.DeepEqual(got, want) { @@ -401,3 +433,81 @@ Content-Disposition: form-data; name="textb" ` + textbValue + ` --MyBoundary-- ` + +func benchmarkReadRequest(b *testing.B, request string) { + request = request + "\n" // final \n + request = strings.Replace(request, "\n", "\r\n", -1) // expand \n to \r\n + b.SetBytes(int64(len(request))) + r := bufio.NewReader(&infiniteReader{buf: []byte(request)}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := ReadRequest(r) + if err != nil { + b.Fatalf("failed to read request: %v", err) + } + } +} + +// infiniteReader satisfies Read requests as if the contents of buf +// loop indefinitely. +type infiniteReader struct { + buf []byte + offset int +} + +func (r *infiniteReader) Read(b []byte) (int, error) { + n := copy(b, r.buf[r.offset:]) + r.offset = (r.offset + n) % len(r.buf) + return n, nil +} + +func BenchmarkReadRequestChrome(b *testing.B) { + // https://github.com/felixge/node-http-perf/blob/master/fixtures/get.http + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +Connection: keep-alive +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 +Cookie: __utma=1.1978842379.1323102373.1323102373.1323102373.1; EPi:NumberOfVisits=1,2012-02-28T13:42:18; CrmSession=5b707226b9563e1bc69084d07a107c98; plushContainerWidth=100%25; plushNoTopMenu=0; hudson_auto_refresh=false +`) +} + +func BenchmarkReadRequestCurl(b *testing.B) { + // curl http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +User-Agent: curl/7.27.0 +Host: localhost:8080 +Accept: */* +`) +} + +func BenchmarkReadRequestApachebench(b *testing.B) { + // ab -n 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.0 +Host: localhost:8080 +User-Agent: ApacheBench/2.3 +Accept: */* +`) +} + +func BenchmarkReadRequestSiege(b *testing.B) { + // siege -r 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +Accept: */* +Accept-Encoding: gzip +User-Agent: JoeDog/1.00 [en] (X11; I; Siege 2.70) +Connection: keep-alive +`) +} + +func BenchmarkReadRequestWrk(b *testing.B) { + // wrk -t 1 -r 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +`) +} diff --git a/libgo/go/net/http/requestwrite_test.go b/libgo/go/net/http/requestwrite_test.go index fc3186f0c0c..b27b1f7ce3b 100644 --- a/libgo/go/net/http/requestwrite_test.go +++ b/libgo/go/net/http/requestwrite_test.go @@ -93,13 +93,13 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "GET /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), WantProxy: "GET http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), }, @@ -123,14 +123,14 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "POST /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), @@ -156,7 +156,7 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "POST /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Connection: close\r\n" + "Content-Length: 6\r\n" + "\r\n" + @@ -164,7 +164,7 @@ var reqWriteTests = []reqWriteTest{ WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Connection: close\r\n" + "Content-Length: 6\r\n" + "\r\n" + @@ -187,14 +187,14 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Content-Length: 6\r\n" + "\r\n" + "abcdef", WantProxy: "POST http://example.com/ HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Content-Length: 6\r\n" + "\r\n" + "abcdef", @@ -210,7 +210,7 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "GET /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "\r\n", }, @@ -232,13 +232,13 @@ var reqWriteTests = []reqWriteTest{ // Also, nginx expects it for POST and PUT. WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Content-Length: 0\r\n" + "\r\n", WantProxy: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Content-Length: 0\r\n" + "\r\n", }, @@ -258,13 +258,13 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("x") + chunk(""), WantProxy: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("x") + chunk(""), }, @@ -325,9 +325,96 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "GET /foo HTTP/1.1\r\n" + "Host: \r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "X-Foo: X-Bar\r\n\r\n", }, + + // If no Request.Host and no Request.URL.Host, we send + // an empty Host header, and don't use + // Request.Header["Host"]. This is just testing that + // we don't change Go 1.0 behavior. + { + Req: Request{ + Method: "GET", + Host: "", + URL: &url.URL{ + Scheme: "http", + Host: "", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Host": []string{"bad.example.com"}, + }, + }, + + WantWrite: "GET /search HTTP/1.1\r\n" + + "Host: \r\n" + + "User-Agent: Go 1.1 package http\r\n\r\n", + }, + + // Opaque test #1 from golang.org/issue/4860 + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Opaque: "/%2F/%2F/", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + }, + + WantWrite: "GET /%2F/%2F/ HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n\r\n", + }, + + // Opaque test #2 from golang.org/issue/4860 + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "x.google.com", + Opaque: "//y.google.com/%2F/%2F/", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + }, + + WantWrite: "GET http://y.google.com/%2F/%2F/ HTTP/1.1\r\n" + + "Host: x.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n\r\n", + }, + + // Testing custom case in header keys. Issue 5022. + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "ALL-CAPS": {"x"}, + }, + }, + + WantWrite: "GET / HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "ALL-CAPS: x\r\n" + + "\r\n", + }, } func TestRequestWrite(t *testing.T) { @@ -411,7 +498,7 @@ func TestRequestWriteClosesBody(t *testing.T) { } expected := "POST / HTTP/1.1\r\n" + "Host: foo.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + // TODO: currently we don't buffer before chunking, so we get a // single "m" chunk before the other chunks, as this was the 1-byte diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index 7901c49f5a5..9a7e4e319b0 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -46,6 +46,9 @@ type Response struct { // The http Client and Transport guarantee that Body is always // non-nil, even on responses without a body or responses with // a zero-lengthed body. + // + // The Body is automatically dechunked if the server replied + // with a "chunked" Transfer-Encoding. Body io.ReadCloser // ContentLength records the length of the associated content. The @@ -198,9 +201,7 @@ func (r *Response) Write(w io.Writer) error { } protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor) statusCode := strconv.Itoa(r.StatusCode) + " " - if strings.HasPrefix(text, statusCode) { - text = text[len(statusCode):] - } + text = strings.TrimPrefix(text, statusCode) io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n") // Process Body,ContentLength,Close,Trailer diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index a00a4ae0a9b..02796e88b4c 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -112,8 +112,8 @@ var respTests = []respTest{ ProtoMinor: 0, Request: dummyReq("GET"), Header: Header{ - "Connection": {"close"}, // TODO(rsc): Delete? - "Content-Length": {"10"}, // TODO(rsc): Delete? + "Connection": {"close"}, + "Content-Length": {"10"}, }, Close: true, ContentLength: 10, @@ -157,7 +157,7 @@ var respTests = []respTest{ "Content-Length: 10\r\n" + "\r\n" + "0a\r\n" + - "Body here\n" + + "Body here\n\r\n" + "0\r\n" + "\r\n", @@ -170,7 +170,7 @@ var respTests = []respTest{ Request: dummyReq("GET"), Header: Header{}, Close: false, - ContentLength: -1, // TODO(rsc): Fix? + ContentLength: -1, TransferEncoding: []string{"chunked"}, }, @@ -324,16 +324,37 @@ var respTests = []respTest{ "", }, + + // golang.org/issue/4767: don't special-case multipart/byteranges responses + { + `HTTP/1.1 206 Partial Content +Connection: close +Content-Type: multipart/byteranges; boundary=18a75608c8f47cef + +some body`, + Response{ + Status: "206 Partial Content", + StatusCode: 206, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Content-Type": []string{"multipart/byteranges; boundary=18a75608c8f47cef"}, + }, + Close: true, + ContentLength: -1, + }, + + "some body", + }, } func TestReadResponse(t *testing.T) { - for i := range respTests { - tt := &respTests[i] - var braw bytes.Buffer - braw.WriteString(tt.Raw) - resp, err := ReadResponse(bufio.NewReader(&braw), tt.Resp.Request) + for i, tt := range respTests { + resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request) if err != nil { - t.Errorf("#%d: %s", i, err) + t.Errorf("#%d: %v", i, err) continue } rbody := resp.Body @@ -341,7 +362,11 @@ func TestReadResponse(t *testing.T) { diff(t, fmt.Sprintf("#%d Response", i), resp, &tt.Resp) var bout bytes.Buffer if rbody != nil { - io.Copy(&bout, rbody) + _, err = io.Copy(&bout, rbody) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } rbody.Close() } body := bout.String() @@ -351,6 +376,22 @@ func TestReadResponse(t *testing.T) { } } +func TestWriteResponse(t *testing.T) { + for i, tt := range respTests { + resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } + bout := bytes.NewBuffer(nil) + err = resp.Write(bout) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } + } +} + var readResponseCloseInMiddleTests = []struct { chunked, compressed bool }{ @@ -425,7 +466,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) { if test.compressed { gzReader, err := gzip.NewReader(resp.Body) checkErr(err, "gzip.NewReader") - resp.Body = &readFirstCloseBoth{gzReader, resp.Body} + resp.Body = &readerAndCloser{gzReader, resp.Body} } rbuf := make([]byte, 2500) diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 886ed4e8f74..d7b321597c4 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -10,6 +10,7 @@ import ( "bufio" "bytes" "crypto/tls" + "errors" "fmt" "io" "io/ioutil" @@ -64,10 +65,39 @@ func (a dummyAddr) String() string { return string(a) } +type noopConn struct{} + +func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") } +func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") } +func (noopConn) SetDeadline(t time.Time) error { return nil } +func (noopConn) SetReadDeadline(t time.Time) error { return nil } +func (noopConn) SetWriteDeadline(t time.Time) error { return nil } + +type rwTestConn struct { + io.Reader + io.Writer + noopConn + + closeFunc func() error // called if non-nil + closec chan bool // else, if non-nil, send value to it on close +} + +func (c *rwTestConn) Close() error { + if c.closeFunc != nil { + return c.closeFunc() + } + select { + case c.closec <- true: + default: + } + return nil +} + type testConn struct { readBuf bytes.Buffer writeBuf bytes.Buffer closec chan bool // if non-nil, send value to it on close + noopConn } func (c *testConn) Read(b []byte) (int, error) { @@ -86,26 +116,6 @@ func (c *testConn) Close() error { return nil } -func (c *testConn) LocalAddr() net.Addr { - return dummyAddr("local-addr") -} - -func (c *testConn) RemoteAddr() net.Addr { - return dummyAddr("remote-addr") -} - -func (c *testConn) SetDeadline(t time.Time) error { - return nil -} - -func (c *testConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (c *testConn) SetWriteDeadline(t time.Time) error { - return nil -} - func TestConsumingBodyOnNextConn(t *testing.T) { conn := new(testConn) for i := 0; i < 2; i++ { @@ -184,6 +194,7 @@ var vtests = []struct { } func TestHostHandlers(t *testing.T) { + defer afterTest(t) mux := NewServeMux() for _, h := range handlers { mux.Handle(h.pattern, stringHandler(h.msg)) @@ -256,28 +267,22 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } func TestServerTimeouts(t *testing.T) { - // TODO(bradfitz): convert this to use httptest.Server - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen error: %v", err) - } - addr, _ := l.Addr().(*net.TCPAddr) - + defer afterTest(t) reqNum := 0 - handler := HandlerFunc(func(res ResponseWriter, req *Request) { + ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ fmt.Fprintf(res, "req=%d", reqNum) - }) - - server := &Server{Handler: handler, ReadTimeout: 250 * time.Millisecond, WriteTimeout: 250 * time.Millisecond} - go server.Serve(l) - - url := fmt.Sprintf("http://%s/", addr) + })) + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.WriteTimeout = 250 * time.Millisecond + ts.Start() + defer ts.Close() // Hit the HTTP server successfully. tr := &Transport{DisableKeepAlives: true} // they interfere with this test + defer tr.CloseIdleConnections() c := &Client{Transport: tr} - r, err := c.Get(url) + r, err := c.Get(ts.URL) if err != nil { t.Fatalf("http Get #1: %v", err) } @@ -290,13 +295,13 @@ func TestServerTimeouts(t *testing.T) { // Slow client that should timeout. t1 := time.Now() - conn, err := net.Dial("tcp", addr.String()) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) } buf := make([]byte, 1) n, err := conn.Read(buf) - latency := time.Now().Sub(t1) + latency := time.Since(t1) if n != 0 || err != io.EOF { t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF) } @@ -307,7 +312,7 @@ func TestServerTimeouts(t *testing.T) { // Hit the HTTP server successfully again, verifying that the // previous slow connection didn't run our handler. (that we // get "req=2", not "req=3") - r, err = Get(url) + r, err = Get(ts.URL) if err != nil { t.Fatalf("http Get #2: %v", err) } @@ -317,11 +322,87 @@ func TestServerTimeouts(t *testing.T) { t.Errorf("Get #2 got %q, want %q", string(got), expected) } - l.Close() + if !testing.Short() { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + go io.Copy(ioutil.Discard, conn) + for i := 0; i < 5; i++ { + _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")) + if err != nil { + t.Fatalf("on write %d: %v", i, err) + } + time.Sleep(ts.Config.ReadTimeout / 2) + } + } +} + +// golang.org/issue/4741 -- setting only a write timeout that triggers +// shouldn't cause a handler to block forever on reads (next HTTP +// request) that will never happen. +func TestOnlyWriteTimeout(t *testing.T) { + defer afterTest(t) + var conn net.Conn + var afterTimeoutErrc = make(chan error, 1) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { + buf := make([]byte, 512<<10) + _, err := w.Write(buf) + if err != nil { + t.Errorf("handler Write error: %v", err) + return + } + conn.SetWriteDeadline(time.Now().Add(-30 * time.Second)) + _, err = w.Write(buf) + afterTimeoutErrc <- err + })) + ts.Listener = trackLastConnListener{ts.Listener, &conn} + ts.Start() + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + errc := make(chan error) + go func() { + res, err := c.Get(ts.URL) + if err != nil { + errc <- err + return + } + _, err = io.Copy(ioutil.Discard, res.Body) + errc <- err + }() + select { + case err := <-errc: + if err == nil { + t.Errorf("expected an error from Get request") + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for Get error") + } + if err := <-afterTimeoutErrc; err == nil { + t.Error("expected write error after timeout") + } +} + +// trackLastConnListener tracks the last net.Conn that was accepted. +type trackLastConnListener struct { + net.Listener + last *net.Conn // destination +} + +func (l trackLastConnListener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + *l.last = c + return } // TestIdentityResponse verifies that a handler can unset func TestIdentityResponse(t *testing.T) { + defer afterTest(t) handler := HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Length", "3") rw.Header().Set("Transfer-Encoding", req.FormValue("te")) @@ -367,10 +448,12 @@ func TestIdentityResponse(t *testing.T) { // Verify that ErrContentLength is returned url := ts.URL + "/?overwrite=1" - _, err := Get(url) + res, err := Get(url) if err != nil { t.Fatalf("error with Get of %s: %v", url, err) } + res.Body.Close() + // Verify that the connection is closed when the declared Content-Length // is larger than what the handler wrote. conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -395,6 +478,7 @@ func TestIdentityResponse(t *testing.T) { } func testTCPConnectionCloses(t *testing.T, req string, h Handler) { + defer afterTest(t) s := httptest.NewServer(h) defer s.Close() @@ -465,6 +549,7 @@ func TestHandlersCanSetConnectionClose10(t *testing.T) { } func TestSetsRemoteAddr(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) })) @@ -485,6 +570,7 @@ func TestSetsRemoteAddr(t *testing.T) { } func TestChunkedResponseHeaders(t *testing.T) { + defer afterTest(t) log.SetOutput(ioutil.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) @@ -499,6 +585,7 @@ func TestChunkedResponseHeaders(t *testing.T) { if err != nil { t.Fatalf("Get error: %v", err) } + defer res.Body.Close() if g, e := res.ContentLength, int64(-1); g != e { t.Errorf("expected ContentLength of %d; got %d", e, g) } @@ -514,6 +601,7 @@ func TestChunkedResponseHeaders(t *testing.T) { // chunking in their response headers and aren't allowed to produce // output. func Test304Responses(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) _, err := w.Write([]byte("illegal body")) @@ -543,6 +631,7 @@ func Test304Responses(t *testing.T) { // allowed to produce output, and don't set a Content-Type since // the real type of the body data cannot be inferred. func TestHeadResponses(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("Ignored body")) if err != ErrBodyNotAllowed { @@ -577,6 +666,7 @@ func TestHeadResponses(t *testing.T) { } func TestTLSHandshakeTimeout(t *testing.T) { + defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) ts.Config.ReadTimeout = 250 * time.Millisecond ts.StartTLS() @@ -596,6 +686,7 @@ func TestTLSHandshakeTimeout(t *testing.T) { } func TestTLSServer(t *testing.T) { + defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { w.Header().Set("X-TLS-Set", "true") @@ -678,6 +769,7 @@ var serverExpectTests = []serverExpectTest{ // Tests that the server responds to the "Expect" request header // correctly. func TestServerExpect(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { // Note using r.FormValue("readbody") because for POST // requests that would read from r.Body, which we only @@ -815,6 +907,7 @@ func TestServerUnreadRequestBodyLarge(t *testing.T) { } func TestTimeoutHandler(t *testing.T) { + defer afterTest(t) sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -889,6 +982,7 @@ func TestRedirectMunging(t *testing.T) { // the previous request's body, which is not optimal for zero-lengthed bodies, // as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. func TestZeroLengthPostAndResponse(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := ioutil.ReadAll(r.Body) if err != nil { @@ -939,6 +1033,7 @@ func TestHandlerPanicWithHijack(t *testing.T) { } func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { + defer afterTest(t) // Unlike the other tests that set the log output to ioutil.Discard // to quiet the output, this test uses a pipe. The pipe serves three // purposes: @@ -958,6 +1053,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { pr, pw := io.Pipe() log.SetOutput(pw) defer log.SetOutput(os.Stderr) + defer pw.Close() ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if withHijack { @@ -979,7 +1075,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { buf := make([]byte, 4<<10) _, err := pr.Read(buf) pr.Close() - if err != nil { + if err != nil && err != io.EOF { t.Error(err) } done <- true @@ -1003,6 +1099,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { } func TestNoDate(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()["Date"] = nil })) @@ -1018,6 +1115,7 @@ func TestNoDate(t *testing.T) { } func TestStripPrefix(t *testing.T) { + defer afterTest(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) }) @@ -1031,6 +1129,7 @@ func TestStripPrefix(t *testing.T) { if g, e := res.Header.Get("X-Path"), "/bar"; g != e { t.Errorf("test 1: got %s, want %s", g, e) } + res.Body.Close() res, err = Get(ts.URL + "/bar") if err != nil { @@ -1039,9 +1138,11 @@ func TestStripPrefix(t *testing.T) { if g, e := res.StatusCode, 404; g != e { t.Errorf("test 2: got status %v, want %v", g, e) } + res.Body.Close() } func TestRequestLimit(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") })) @@ -1058,6 +1159,7 @@ func TestRequestLimit(t *testing.T) { // we do support it (at least currently), so we expect a response below. t.Fatalf("Do: %v", err) } + defer res.Body.Close() if res.StatusCode != 413 { t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status) } @@ -1084,6 +1186,7 @@ func (cr countReader) Read(p []byte) (n int, err error) { } func TestRequestBodyLimit(t *testing.T) { + defer afterTest(t) const limit = 1 << 20 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) @@ -1120,6 +1223,7 @@ func TestRequestBodyLimit(t *testing.T) { // TestClientWriteShutdown tests that if the client shuts down the write // side of their TCP connection, the server doesn't send a 400 Bad Request. func TestClientWriteShutdown(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -1174,6 +1278,7 @@ func TestServerBufferedChunking(t *testing.T) { // closing the TCP connection, causing the client to get a RST. // See http://golang.org/issue/3595 func TestServerGracefulClose(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, "bye", StatusUnauthorized) })) @@ -1216,6 +1321,7 @@ func TestServerGracefulClose(t *testing.T) { } func TestCaseSensitiveMethod(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "get" { t.Errorf(`Got method %q; want "get"`, r.Method) @@ -1264,6 +1370,7 @@ func TestContentLengthZero(t *testing.T) { } func TestCloseNotifier(t *testing.T) { + defer afterTest(t) gotReq := make(chan bool, 1) sawClose := make(chan bool, 1) ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -1299,6 +1406,31 @@ For: ts.Close() } +func TestCloseNotifierChanLeak(t *testing.T) { + defer afterTest(t) + req := []byte(strings.Replace(`GET / HTTP/1.0 +Host: golang.org + +`, "\n", "\r\n", -1)) + for i := 0; i < 20; i++ { + var output bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &output, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + // Ignore the return value and never read from + // it, testing that we don't leak goroutines + // on the sending side: + _ = rw.(CloseNotifier).CloseNotify() + }) + go Serve(ln, handler) + <-conn.closec + } +} + func TestOptions(t *testing.T) { uric := make(chan string, 2) // only expect 1, but leave space for 2 mux := NewServeMux() @@ -1351,6 +1483,198 @@ func TestOptions(t *testing.T) { } } +// Tests regarding the ordering of Write, WriteHeader, Header, and +// Flush calls. In Go 1.0, rw.WriteHeader immediately flushed the +// (*response).header to the wire. In Go 1.1, the actual wire flush is +// delayed, so we could maybe tack on a Content-Length and better +// Content-Type after we see more (or all) of the output. To preserve +// compatibility with Go 1, we need to be careful to track which +// headers were live at the time of WriteHeader, so we write the same +// ones, even if the handler modifies them (~erroneously) after the +// first Write. +func TestHeaderToWire(t *testing.T) { + req := []byte(strings.Replace(`GET / HTTP/1.1 +Host: golang.org + +`, "\n", "\r\n", -1)) + + tests := []struct { + name string + handler func(ResponseWriter, *Request) + check func(output string) error + }{ + { + name: "write without Header", + handler: func(rw ResponseWriter, r *Request) { + rw.Write([]byte("hello world")) + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Length:") { + return errors.New("no content-length") + } + if !strings.Contains(got, "Content-Type: text/plain") { + return errors.New("no content-length") + } + return nil + }, + }, + { + name: "Header mutation before write", + handler: func(rw ResponseWriter, r *Request) { + h := rw.Header() + h.Set("Content-Type", "some/type") + rw.Write([]byte("hello world")) + h.Set("Too-Late", "bogus") + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Length:") { + return errors.New("no content-length") + } + if !strings.Contains(got, "Content-Type: some/type") { + return errors.New("wrong content-type") + } + if strings.Contains(got, "Too-Late") { + return errors.New("don't want too-late header") + } + return nil + }, + }, + { + name: "write then useless Header mutation", + handler: func(rw ResponseWriter, r *Request) { + rw.Write([]byte("hello world")) + rw.Header().Set("Too-Late", "Write already wrote headers") + }, + check: func(got string) error { + if strings.Contains(got, "Too-Late") { + return errors.New("header appeared from after WriteHeader") + } + return nil + }, + }, + { + name: "flush then write", + handler: func(rw ResponseWriter, r *Request) { + rw.(Flusher).Flush() + rw.Write([]byte("post-flush")) + rw.Header().Set("Too-Late", "Write already wrote headers") + }, + check: func(got string) error { + if !strings.Contains(got, "Transfer-Encoding: chunked") { + return errors.New("not chunked") + } + if strings.Contains(got, "Too-Late") { + return errors.New("header appeared from after WriteHeader") + } + return nil + }, + }, + { + name: "header then flush", + handler: func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "some/type") + rw.(Flusher).Flush() + rw.Write([]byte("post-flush")) + rw.Header().Set("Too-Late", "Write already wrote headers") + }, + check: func(got string) error { + if !strings.Contains(got, "Transfer-Encoding: chunked") { + return errors.New("not chunked") + } + if strings.Contains(got, "Too-Late") { + return errors.New("header appeared from after WriteHeader") + } + if !strings.Contains(got, "Content-Type: some/type") { + return errors.New("wrong content-length") + } + return nil + }, + }, + { + name: "sniff-on-first-write content-type", + handler: func(rw ResponseWriter, r *Request) { + rw.Write([]byte("<html><head></head><body>some html</body></html>")) + rw.Header().Set("Content-Type", "x/wrong") + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Type: text/html") { + return errors.New("wrong content-length; want html") + } + return nil + }, + }, + { + name: "explicit content-type wins", + handler: func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "some/type") + rw.Write([]byte("<html><head></head><body>some html</body></html>")) + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Type: some/type") { + return errors.New("wrong content-length; want html") + } + return nil + }, + }, + { + name: "empty handler", + handler: func(rw ResponseWriter, r *Request) { + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Type: text/plain") { + return errors.New("wrong content-length; want text/plain") + } + if !strings.Contains(got, "Content-Length: 0") { + return errors.New("want 0 content-length") + } + return nil + }, + }, + { + name: "only Header, no write", + handler: func(rw ResponseWriter, r *Request) { + rw.Header().Set("Some-Header", "some-value") + }, + check: func(got string) error { + if !strings.Contains(got, "Some-Header") { + return errors.New("didn't get header") + } + return nil + }, + }, + { + name: "WriteHeader call", + handler: func(rw ResponseWriter, r *Request) { + rw.WriteHeader(404) + rw.Header().Set("Too-Late", "some-value") + }, + check: func(got string) error { + if !strings.Contains(got, "404") { + return errors.New("wrong status") + } + if strings.Contains(got, "Some-Header") { + return errors.New("shouldn't have seen Too-Late") + } + return nil + }, + }, + } + for _, tc := range tests { + var output bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &output, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + go Serve(ln, HandlerFunc(tc.handler)) + <-conn.closec + if err := tc.check(output.String()); err != nil { + t.Errorf("%s: %v\nGot response:\n%s", tc.name, err, output.Bytes()) + } + } +} + // goTimeout runs f, failing t if f takes more than ns to complete. func goTimeout(t *testing.T, d time.Duration, f func()) { ch := make(chan bool, 2) @@ -1524,3 +1848,179 @@ func BenchmarkServer(b *testing.B) { b.Errorf("Test failure: %v, with output: %s", err, out) } } + +func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) { + b.ReportAllocs() + req := []byte(strings.Replace(`GET / HTTP/1.0 +Host: golang.org +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 + +`, "\n", "\r\n", -1)) + res := []byte("Hello world!\n") + + conn := &testConn{ + // testConn.Close will not push into the channel + // if it's full. + closec: make(chan bool, 1), + } + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Write(res) + }) + ln := new(oneConnListener) + for i := 0; i < b.N; i++ { + conn.readBuf.Reset() + conn.writeBuf.Reset() + conn.readBuf.Write(req) + ln.conn = conn + Serve(ln, handler) + <-conn.closec + } +} + +// repeatReader reads content count times, then EOFs. +type repeatReader struct { + content []byte + count int + off int +} + +func (r *repeatReader) Read(p []byte) (n int, err error) { + if r.count <= 0 { + return 0, io.EOF + } + n = copy(p, r.content[r.off:]) + r.off += n + if r.off == len(r.content) { + r.count-- + r.off = 0 + } + return +} + +func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) { + b.ReportAllocs() + + req := []byte(strings.Replace(`GET / HTTP/1.1 +Host: golang.org +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 + +`, "\n", "\r\n", -1)) + res := []byte("Hello world!\n") + + conn := &rwTestConn{ + Reader: &repeatReader{content: req, count: b.N}, + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + handled := 0 + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + handled++ + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Write(res) + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + if b.N != handled { + b.Errorf("b.N=%d but handled %d", b.N, handled) + } +} + +// same as above, but representing the most simple possible request +// and handler. Notably: the handler does not call rw.Header(). +func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) { + b.ReportAllocs() + + req := []byte(strings.Replace(`GET / HTTP/1.1 +Host: golang.org + +`, "\n", "\r\n", -1)) + res := []byte("Hello world!\n") + + conn := &rwTestConn{ + Reader: &repeatReader{content: req, count: b.N}, + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + handled := 0 + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + handled++ + rw.Write(res) + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + if b.N != handled { + b.Errorf("b.N=%d but handled %d", b.N, handled) + } +} + +const someResponse = "<html>some response</html>" + +// A Response that's just no bigger than 2KB, the buffer-before-chunking threshold. +var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse)) + +// Both Content-Type and Content-Length set. Should be no buffering. +func BenchmarkServerHandlerTypeLen(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Content-Length", strconv.Itoa(len(response))) + w.Write(response) + })) +} + +// A Content-Type is set, but no length. No sniffing, but will count the Content-Length. +func BenchmarkServerHandlerNoLen(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/html") + w.Write(response) + })) +} + +// A Content-Length is set, but the Content-Type will be sniffed. +func BenchmarkServerHandlerNoType(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", strconv.Itoa(len(response))) + w.Write(response) + })) +} + +// Neither a Content-Type or Content-Length, so sniffed and counted. +func BenchmarkServerHandlerNoHeader(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write(response) + })) +} + +func benchmarkHandler(b *testing.B, h Handler) { + b.ReportAllocs() + req := []byte(strings.Replace(`GET / HTTP/1.1 +Host: golang.org + +`, "\n", "\r\n", -1)) + conn := &rwTestConn{ + Reader: &repeatReader{content: req, count: b.N}, + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + handled := 0 + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + handled++ + h.ServeHTTP(rw, r) + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + if b.N != handled { + b.Errorf("b.N=%d but handled %d", b.N, handled) + } +} diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index 434943d49a3..b2596070500 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -4,9 +4,6 @@ // HTTP server. See RFC 2616. -// TODO(rsc): -// logging - package http import ( @@ -109,9 +106,11 @@ type conn struct { remoteAddr string // network address of remote side server *Server // the Server on which the connection arrived rwc net.Conn // i/o connection - sr switchReader // where the LimitReader reads from; usually the rwc + sr liveSwitchReader // where the LimitReader reads from; usually the rwc lr *io.LimitedReader // io.LimitReader(sr) buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc + bufswr *switchReader // the *switchReader io.Reader source of buf + bufsww *switchWriter // the *switchWriter io.Writer dest of buf tlsState *tls.ConnectionState // or nil when not using TLS mu sync.Mutex // guards the following @@ -147,7 +146,7 @@ func (c *conn) closeNotify() <-chan bool { c.mu.Lock() defer c.mu.Unlock() if c.closeNotifyc == nil { - c.closeNotifyc = make(chan bool) + c.closeNotifyc = make(chan bool, 1) if c.hijackedv { // to obey the function signature, even though // it'll never receive a value. @@ -180,12 +179,26 @@ func (c *conn) noteClientGone() { c.clientGone = true } +// A switchReader can have its Reader changed at runtime. +// It's not safe for concurrent Reads and switches. type switchReader struct { + io.Reader +} + +// A switchWriter can have its Writer changed at runtime. +// It's not safe for concurrent Writes and switches. +type switchWriter struct { + io.Writer +} + +// A liveSwitchReader is a switchReader that's safe for concurrent +// reads and switches, if its mutex is held. +type liveSwitchReader struct { sync.Mutex r io.Reader } -func (sr *switchReader) Read(p []byte) (n int, err error) { +func (sr *liveSwitchReader) Read(p []byte) (n int, err error) { sr.Lock() r := sr.r sr.Unlock() @@ -206,15 +219,28 @@ const bufferBeforeChunkingSize = 2048 // // See the comment above (*response).Write for the entire write flow. type chunkWriter struct { - res *response - header Header // a deep copy of r.Header, once WriteHeader is called - wroteHeader bool // whether the header's been sent + res *response + + // header is either nil or a deep clone of res.handlerHeader + // at the time of res.WriteHeader, if res.WriteHeader is + // called and extra buffering is being done to calculate + // Content-Type and/or Content-Length. + header Header + + // wroteHeader tells whether the header's been written to "the + // wire" (or rather: w.conn.buf). this is unlike + // (*response).wroteHeader, which tells only whether it was + // logically written. + wroteHeader bool // set by the writeHeader method: chunking bool // using chunked transfer encoding for reply body } -var crlf = []byte("\r\n") +var ( + crlf = []byte("\r\n") + colonSpace = []byte(": ") +) func (cw *chunkWriter) Write(p []byte) (n int, err error) { if !cw.wroteHeader { @@ -223,6 +249,7 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { if cw.chunking { _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p)) if err != nil { + cw.res.conn.rwc.Close() return } } @@ -230,6 +257,9 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { if cw.chunking && err == nil { _, err = cw.res.conn.buf.Write(crlf) } + if err != nil { + cw.res.conn.rwc.Close() + } return } @@ -260,13 +290,15 @@ type response struct { wroteContinue bool // 100 Continue response was written w *bufio.Writer // buffers output in chunks to chunkWriter - cw *chunkWriter + cw chunkWriter + sw *switchWriter // of the bufio.Writer, for return to putBufioWriter // handlerHeader is the Header that Handlers get access to, // which may be retained and mutated even after WriteHeader. // handlerHeader is copied into cw.header at WriteHeader // time, and privately mutated thereafter. handlerHeader Header + calledHeader bool // handler accessed handlerHeader via Header written int64 // number of bytes written in body contentLength int64 // explicitly-declared Content-Length; or -1 @@ -358,14 +390,98 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { if debugServerConnections { c.rwc = newLoggingConn("server", c.rwc) } - c.sr = switchReader{r: c.rwc} + c.sr = liveSwitchReader{r: c.rwc} c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader) - br := bufio.NewReader(c.lr) - bw := bufio.NewWriter(c.rwc) + br, sr := newBufioReader(c.lr) + bw, sw := newBufioWriterSize(c.rwc, 4<<10) c.buf = bufio.NewReadWriter(br, bw) + c.bufswr = sr + c.bufsww = sw return c, nil } +// TODO: remove this, if issue 5100 is fixed +type bufioReaderPair struct { + br *bufio.Reader + sr *switchReader // from which the bufio.Reader is reading +} + +// TODO: remove this, if issue 5100 is fixed +type bufioWriterPair struct { + bw *bufio.Writer + sw *switchWriter // to which the bufio.Writer is writing +} + +// TODO: use a sync.Cache instead +var ( + bufioReaderCache = make(chan bufioReaderPair, 4) + bufioWriterCache2k = make(chan bufioWriterPair, 4) + bufioWriterCache4k = make(chan bufioWriterPair, 4) +) + +func bufioWriterCache(size int) chan bufioWriterPair { + switch size { + case 2 << 10: + return bufioWriterCache2k + case 4 << 10: + return bufioWriterCache4k + } + return nil +} + +func newBufioReader(r io.Reader) (*bufio.Reader, *switchReader) { + select { + case p := <-bufioReaderCache: + p.sr.Reader = r + return p.br, p.sr + default: + sr := &switchReader{r} + return bufio.NewReader(sr), sr + } +} + +func putBufioReader(br *bufio.Reader, sr *switchReader) { + if n := br.Buffered(); n > 0 { + io.CopyN(ioutil.Discard, br, int64(n)) + } + br.Read(nil) // clears br.err + sr.Reader = nil + select { + case bufioReaderCache <- bufioReaderPair{br, sr}: + default: + } +} + +func newBufioWriterSize(w io.Writer, size int) (*bufio.Writer, *switchWriter) { + select { + case p := <-bufioWriterCache(size): + p.sw.Writer = w + return p.bw, p.sw + default: + sw := &switchWriter{w} + return bufio.NewWriterSize(sw, size), sw + } +} + +func putBufioWriter(bw *bufio.Writer, sw *switchWriter) { + if bw.Buffered() > 0 { + // It must have failed to flush to its target + // earlier. We can't reuse this bufio.Writer. + return + } + if err := bw.Flush(); err != nil { + // Its sticky error field is set, which is returned by + // Flush even when there's no data buffered. This + // bufio Writer is dead to us. Don't reuse it. + return + } + sw.Writer = nil + select { + case bufioWriterCache(bw.Available()) <- bufioWriterPair{bw, sw}: + default: + } +} + // DefaultMaxHeaderBytes is the maximum permitted size of the headers // in an HTTP request. // This can be overridden by setting Server.MaxHeaderBytes. @@ -416,6 +532,16 @@ func (c *conn) readRequest() (w *response, err error) { if c.hijacked() { return nil, ErrHijacked } + + if d := c.server.ReadTimeout; d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + } + if d := c.server.WriteTimeout; d != 0 { + defer func() { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + }() + } + c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */ var req *Request if req, err = ReadRequest(c.buf.Reader); err != nil { @@ -434,14 +560,20 @@ func (c *conn) readRequest() (w *response, err error) { req: req, handlerHeader: make(Header), contentLength: -1, - cw: new(chunkWriter), } w.cw.res = w - w.w = bufio.NewWriterSize(w.cw, bufferBeforeChunkingSize) + w.w, w.sw = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize) return w, nil } func (w *response) Header() Header { + if w.cw.header == nil && w.wroteHeader && !w.cw.wroteHeader { + // Accessing the header between logically writing it + // and physically writing it means we need to allocate + // a clone to snapshot the logically written state. + w.cw.header = w.handlerHeader.clone() + } + w.calledHeader = true return w.handlerHeader } @@ -468,15 +600,48 @@ func (w *response) WriteHeader(code int) { w.wroteHeader = true w.status = code - w.cw.header = w.handlerHeader.clone() + if w.calledHeader && w.cw.header == nil { + w.cw.header = w.handlerHeader.clone() + } - if cl := w.cw.header.get("Content-Length"); cl != "" { + if cl := w.handlerHeader.get("Content-Length"); cl != "" { v, err := strconv.ParseInt(cl, 10, 64) if err == nil && v >= 0 { w.contentLength = v } else { log.Printf("http: invalid Content-Length of %q", cl) - w.cw.header.Del("Content-Length") + w.handlerHeader.Del("Content-Length") + } + } +} + +// extraHeader is the set of headers sometimes added by chunkWriter.writeHeader. +// This type is used to avoid extra allocations from cloning and/or populating +// the response Header map and all its 1-element slices. +type extraHeader struct { + contentType string + contentLength string + connection string + date string + transferEncoding string +} + +// Sorted the same as extraHeader.Write's loop. +var extraHeaderKeys = [][]byte{ + []byte("Content-Type"), []byte("Content-Length"), + []byte("Connection"), []byte("Date"), []byte("Transfer-Encoding"), +} + +// The value receiver, despite copying 5 strings to the stack, +// prevents an extra allocation. The escape analysis isn't smart +// enough to realize this doesn't mutate h. +func (h extraHeader) Write(w io.Writer) { + for i, v := range []string{h.contentType, h.contentLength, h.connection, h.date, h.transferEncoding} { + if v != "" { + w.Write(extraHeaderKeys[i]) + w.Write(colonSpace) + io.WriteString(w, v) + w.Write(crlf) } } } @@ -496,23 +661,47 @@ func (cw *chunkWriter) writeHeader(p []byte) { cw.wroteHeader = true w := cw.res - code := w.status - done := w.handlerDone + + // header is written out to w.conn.buf below. Depending on the + // state of the handler, we either own the map or not. If we + // don't own it, the exclude map is created lazily for + // WriteSubset to remove headers. The setHeader struct holds + // headers we need to add. + header := cw.header + owned := header != nil + if !owned { + header = w.handlerHeader + } + var excludeHeader map[string]bool + delHeader := func(key string) { + if owned { + header.Del(key) + return + } + if _, ok := header[key]; !ok { + return + } + if excludeHeader == nil { + excludeHeader = make(map[string]bool) + } + excludeHeader[key] = true + } + var setHeader extraHeader // If the handler is done but never sent a Content-Length // response header and this is our first (and last) write, set // it, even to zero. This helps HTTP/1.0 clients keep their // "keep-alive" connections alive. - if done && cw.header.get("Content-Length") == "" && w.req.Method != "HEAD" { + if w.handlerDone && header.get("Content-Length") == "" && w.req.Method != "HEAD" { w.contentLength = int64(len(p)) - cw.header.Set("Content-Length", strconv.Itoa(len(p))) + setHeader.contentLength = strconv.Itoa(len(p)) } // If this was an HTTP/1.0 request with keep-alive and we sent a // Content-Length back, we can make this a keep-alive response ... if w.req.wantsHttp10KeepAlive() { - sentLength := cw.header.get("Content-Length") != "" - if sentLength && cw.header.get("Connection") == "keep-alive" { + sentLength := header.get("Content-Length") != "" + if sentLength && header.get("Connection") == "keep-alive" { w.closeAfterReply = false } } @@ -521,15 +710,15 @@ func (cw *chunkWriter) writeHeader(p []byte) { hasCL := w.contentLength != -1 if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { - _, connectionHeaderSet := cw.header["Connection"] + _, connectionHeaderSet := header["Connection"] if !connectionHeaderSet { - cw.header.Set("Connection", "keep-alive") + setHeader.connection = "keep-alive" } } else if !w.req.ProtoAtLeast(1, 1) || w.req.wantsClose() { w.closeAfterReply = true } - if cw.header.get("Connection") == "close" { + if header.get("Connection") == "close" { w.closeAfterReply = true } @@ -543,49 +732,49 @@ func (cw *chunkWriter) writeHeader(p []byte) { n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1) if n >= maxPostHandlerReadBytes { w.requestTooLarge() - cw.header.Set("Connection", "close") + delHeader("Connection") + setHeader.connection = "close" } else { w.req.Body.Close() } } } + code := w.status if code == StatusNotModified { // Must not have body. - for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { - // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" - if cw.header.get(header) != "" { - cw.header.Del(header) - } + // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" + for _, k := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { + delHeader(k) } } else { // If no content type, apply sniffing algorithm to body. - if cw.header.get("Content-Type") == "" && w.req.Method != "HEAD" { - cw.header.Set("Content-Type", DetectContentType(p)) + if header.get("Content-Type") == "" && w.req.Method != "HEAD" { + setHeader.contentType = DetectContentType(p) } } - if _, ok := cw.header["Date"]; !ok { - cw.header.Set("Date", time.Now().UTC().Format(TimeFormat)) + if _, ok := header["Date"]; !ok { + setHeader.date = time.Now().UTC().Format(TimeFormat) } - te := cw.header.get("Transfer-Encoding") + te := header.get("Transfer-Encoding") hasTE := te != "" if hasCL && hasTE && te != "identity" { // TODO: return an error if WriteHeader gets a return parameter // For now just ignore the Content-Length. log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", te, w.contentLength) - cw.header.Del("Content-Length") + delHeader("Content-Length") hasCL = false } if w.req.Method == "HEAD" || code == StatusNotModified { // do nothing } else if code == StatusNoContent { - cw.header.Del("Transfer-Encoding") + delHeader("Transfer-Encoding") } else if hasCL { - cw.header.Del("Transfer-Encoding") + delHeader("Transfer-Encoding") } else if w.req.ProtoAtLeast(1, 1) { // HTTP/1.1 or greater: use chunked transfer encoding // to avoid closing the connection at EOF. @@ -593,29 +782,63 @@ func (cw *chunkWriter) writeHeader(p []byte) { // might have set. Deal with that as need arises once we have a valid // use case. cw.chunking = true - cw.header.Set("Transfer-Encoding", "chunked") + setHeader.transferEncoding = "chunked" } else { // HTTP version < 1.1: cannot do chunked transfer // encoding and we don't know the Content-Length so // signal EOF by closing connection. w.closeAfterReply = true - cw.header.Del("Transfer-Encoding") // in case already set + delHeader("Transfer-Encoding") // in case already set } // Cannot use Content-Length with non-identity Transfer-Encoding. if cw.chunking { - cw.header.Del("Content-Length") + delHeader("Content-Length") } if !w.req.ProtoAtLeast(1, 0) { return } if w.closeAfterReply && !hasToken(cw.header.get("Connection"), "close") { - cw.header.Set("Connection", "close") + delHeader("Connection") + setHeader.connection = "close" } + io.WriteString(w.conn.buf, statusLine(w.req, code)) + cw.header.WriteSubset(w.conn.buf, excludeHeader) + setHeader.Write(w.conn.buf) + w.conn.buf.Write(crlf) +} + +// statusLines is a cache of Status-Line strings, keyed by code (for +// HTTP/1.1) or negative code (for HTTP/1.0). This is faster than a +// map keyed by struct of two fields. This map's max size is bounded +// by 2*len(statusText), two protocol types for each known official +// status code in the statusText map. +var ( + statusMu sync.RWMutex + statusLines = make(map[int]string) +) + +// statusLine returns a response Status-Line (RFC 2616 Section 6.1) +// for the given request and response status code. +func statusLine(req *Request, code int) string { + // Fast path: + key := code + proto11 := req.ProtoAtLeast(1, 1) + if !proto11 { + key = -key + } + statusMu.RLock() + line, ok := statusLines[key] + statusMu.RUnlock() + if ok { + return line + } + + // Slow path: proto := "HTTP/1.0" - if w.req.ProtoAtLeast(1, 1) { + if proto11 { proto = "HTTP/1.1" } codestring := strconv.Itoa(code) @@ -623,9 +846,13 @@ func (cw *chunkWriter) writeHeader(p []byte) { if !ok { text = "status code " + codestring } - io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n") - cw.header.Write(w.conn.buf) - w.conn.buf.Write(crlf) + line = proto + " " + codestring + " " + text + "\r\n" + if ok { + statusMu.Lock() + defer statusMu.Unlock() + statusLines[key] = line + } + return line } // bodyAllowed returns true if a Write is allowed for this response type. @@ -641,7 +868,7 @@ func (w *response) bodyAllowed() bool { // // Handler starts. No header has been sent. The handler can either // write a header, or just start writing. Writing before sending a header -// sends an implicity empty 200 OK header. +// sends an implicitly empty 200 OK header. // // If the handler didn't declare a Content-Length up front, we either // go into chunking mode or, if the handler finishes running before @@ -699,6 +926,7 @@ func (w *response) finishRequest() { } w.w.Flush() + putBufioWriter(w.w, w.sw) w.cw.close() w.conn.buf.Flush() @@ -728,6 +956,15 @@ func (w *response) Flush() { func (c *conn) finalFlush() { if c.buf != nil { c.buf.Flush() + + // Steal the bufio.Reader (~4KB worth of memory) and its associated + // reader for a future connection. + putBufioReader(c.buf.Reader, c.bufswr) + + // Steal the bufio.Writer (~4KB worth of memory) and its associated + // writer for a future connection. + putBufioWriter(c.buf.Writer, c.bufsww) + c.buf = nil } } @@ -764,6 +1001,18 @@ func (c *conn) closeWriteAndWait() { time.Sleep(rstAvoidanceDelay) } +// validNPN returns whether the proto is not a blacklisted Next +// Protocol Negotiation protocol. Empty and built-in protocol types +// are blacklisted and can't be overridden with alternate +// implementations. +func validNPN(proto string) bool { + switch proto { + case "", "http/1.1", "http/1.0": + return false + } + return true +} + // Serve a new connection. func (c *conn) serve() { defer func() { @@ -779,11 +1028,24 @@ func (c *conn) serve() { }() if tlsConn, ok := c.rwc.(*tls.Conn); ok { + if d := c.server.ReadTimeout; d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + } + if d := c.server.WriteTimeout; d != 0 { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + } if err := tlsConn.Handshake(); err != nil { return } c.tlsState = new(tls.ConnectionState) *c.tlsState = tlsConn.ConnectionState() + if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) { + if fn := c.server.TLSNextProto[proto]; fn != nil { + h := initNPNRequest{tlsConn, serverHandler{c.server}} + fn(c.server, tlsConn, h) + } + return + } } for { @@ -826,20 +1088,12 @@ func (c *conn) serve() { break } - handler := c.server.Handler - if handler == nil { - handler = DefaultServeMux - } - if req.RequestURI == "*" && req.Method == "OPTIONS" { - handler = globalOptionsHandler{} - } - // HTTP cannot have multiple simultaneous active requests.[*] // Until the server replies to this request, it can't read another, // so we might as well run the handler in this goroutine. // [*] Not strictly true: HTTP pipelining. We could let them all process // in parallel even if their responses need to be serialized. - handler.ServeHTTP(w, w.req) + serverHandler{c.server}.ServeHTTP(w, w.req) if c.hijacked() { return } @@ -917,13 +1171,16 @@ func NotFoundHandler() Handler { return HandlerFunc(NotFound) } // request for a path that doesn't begin with prefix by // replying with an HTTP 404 not found error. func StripPrefix(prefix string, h Handler) Handler { + if prefix == "" { + return h + } return HandlerFunc(func(w ResponseWriter, r *Request) { - if !strings.HasPrefix(r.URL.Path, prefix) { + if p := strings.TrimPrefix(r.URL.Path, prefix); len(p) < len(r.URL.Path) { + r.URL.Path = p + h.ServeHTTP(w, r) + } else { NotFound(w, r) - return } - r.URL.Path = r.URL.Path[len(prefix):] - h.ServeHTTP(w, r) }) } @@ -965,9 +1222,9 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { } // clean up but preserve trailing slash - trailing := urlStr[len(urlStr)-1] == '/' + trailing := strings.HasSuffix(urlStr, "/") urlStr = path.Clean(urlStr) - if trailing && urlStr[len(urlStr)-1] != '/' { + if trailing && !strings.HasSuffix(urlStr, "/") { urlStr += "/" } urlStr += query @@ -1232,6 +1489,32 @@ type Server struct { WriteTimeout time.Duration // maximum duration before timing out write of the response MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0 TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + + // TLSNextProto optionally specifies a function to take over + // ownership of the provided TLS connection when an NPN + // protocol upgrade has occurred. The map key is the protocol + // name negotiated. The Handler argument should be used to + // handle HTTP requests and will initialize the Request's TLS + // and RemoteAddr if not already set. The connection is + // automatically closed when the function returns. + TLSNextProto map[string]func(*Server, *tls.Conn, Handler) +} + +// serverHandler delegates to either the server's Handler or +// DefaultServeMux and also handles "OPTIONS *" requests. +type serverHandler struct { + srv *Server +} + +func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) { + handler := sh.srv.Handler + if handler == nil { + handler = DefaultServeMux + } + if req.RequestURI == "*" && req.Method == "OPTIONS" { + handler = globalOptionsHandler{} + } + handler.ServeHTTP(rw, req) } // ListenAndServe listens on the TCP network address srv.Addr and then @@ -1274,19 +1557,12 @@ func (srv *Server) Serve(l net.Listener) error { return e } tempDelay = 0 - if srv.ReadTimeout != 0 { - rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) - } - if srv.WriteTimeout != 0 { - rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout)) - } c, err := srv.newConn(rw) if err != nil { continue } go c.serve() } - panic("not reached") } // ListenAndServe listens on the TCP network address addr @@ -1425,7 +1701,7 @@ func (h *timeoutHandler) errorBody() string { } func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { - done := make(chan bool) + done := make(chan bool, 1) tw := &timeoutWriter{w: w} go func() { h.handler.ServeHTTP(tw, r) @@ -1494,6 +1770,31 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { } } +// eofReader is a non-nil io.ReadCloser that always returns EOF. +var eofReader = ioutil.NopCloser(strings.NewReader("")) + +// initNPNRequest is an HTTP handler that initializes certain +// uninitialized fields in its *Request. Such partially-initialized +// Requests come from NPN protocol handlers. +type initNPNRequest struct { + c *tls.Conn + h serverHandler +} + +func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) { + if req.TLS == nil { + req.TLS = &tls.ConnectionState{} + *req.TLS = h.c.ConnectionState() + } + if req.Body == nil { + req.Body = eofReader + } + if req.RemoteAddr == "" { + req.RemoteAddr = h.c.RemoteAddr().String() + } + h.h.ServeHTTP(rw, req) +} + // loggingConn is used for debugging. type loggingConn struct { name string diff --git a/libgo/go/net/http/server_test.go b/libgo/go/net/http/server_test.go index 8b4e8c6d6f6..e8b69f76cce 100644 --- a/libgo/go/net/http/server_test.go +++ b/libgo/go/net/http/server_test.go @@ -2,9 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http +package http_test import ( + . "net/http" + "net/http/httptest" "net/url" "testing" ) @@ -76,20 +78,27 @@ func TestServeMuxHandler(t *testing.T) { }, } h, pattern := mux.Handler(r) - cs := &codeSaver{h: Header{}} - h.ServeHTTP(cs, r) - if pattern != tt.pattern || cs.code != tt.code { - t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, cs.code, pattern, tt.code, tt.pattern) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, r) + if pattern != tt.pattern || rr.Code != tt.code { + t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern) } } } -// A codeSaver is a ResponseWriter that saves the code passed to WriteHeader. -type codeSaver struct { - h Header - code int +func TestServerRedirect(t *testing.T) { + // This used to crash. It's not valid input (bad path), but it + // shouldn't crash. + rr := httptest.NewRecorder() + req := &Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Path: "not-empty-but-no-leading-slash", // bogus + }, + } + Redirect(rr, req, "", 304) + if rr.Code != 304 { + t.Errorf("Code = %d; want 304", rr.Code) + } } - -func (cs *codeSaver) Header() Header { return cs.h } -func (cs *codeSaver) Write(p []byte) (int, error) { return len(p), nil } -func (cs *codeSaver) WriteHeader(code int) { cs.code = code } diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go index 8ab72ac23f9..106d94ec1cb 100644 --- a/libgo/go/net/http/sniff_test.go +++ b/libgo/go/net/http/sniff_test.go @@ -54,6 +54,7 @@ func TestDetectContentType(t *testing.T) { } func TestServerContentType(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { i, _ := strconv.Atoi(r.FormValue("i")) tt := sniffTests[i] @@ -84,6 +85,8 @@ func TestServerContentType(t *testing.T) { } func TestContentTypeWithCopy(t *testing.T) { + defer afterTest(t) + const ( input = "\n<html>\n\t<head>\n" expected = "text/html; charset=utf-8" @@ -116,6 +119,7 @@ func TestContentTypeWithCopy(t *testing.T) { } func TestSniffWriteSize(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { size, _ := strconv.Atoi(r.FormValue("size")) written, err := io.WriteString(w, strings.Repeat("a", size)) @@ -133,6 +137,11 @@ func TestSniffWriteSize(t *testing.T) { if err != nil { t.Fatalf("size %d: %v", size, err) } - res.Body.Close() + if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + t.Fatalf("size %d: io.Copy of body = %v", size, err) + } + if err := res.Body.Close(); err != nil { + t.Fatalf("size %d: body Close = %v", size, err) + } } } diff --git a/libgo/go/net/http/status.go b/libgo/go/net/http/status.go index 5af0b77c428..d253bd5cb54 100644 --- a/libgo/go/net/http/status.go +++ b/libgo/go/net/http/status.go @@ -51,6 +51,13 @@ const ( StatusServiceUnavailable = 503 StatusGatewayTimeout = 504 StatusHTTPVersionNotSupported = 505 + + // New HTTP status codes from RFC 6585. Not exported yet in Go 1.1. + // See discussion at https://codereview.appspot.com/7678043/ + statusPreconditionRequired = 428 + statusTooManyRequests = 429 + statusRequestHeaderFieldsTooLarge = 431 + statusNetworkAuthenticationRequired = 511 ) var statusText = map[int]string{ @@ -99,6 +106,11 @@ var statusText = map[int]string{ StatusServiceUnavailable: "Service Unavailable", StatusGatewayTimeout: "Gateway Timeout", StatusHTTPVersionNotSupported: "HTTP Version Not Supported", + + statusPreconditionRequired: "Precondition Required", + statusTooManyRequests: "Too Many Requests", + statusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large", + statusNetworkAuthenticationRequired: "Network Authentication Required", } // StatusText returns a text for the HTTP status code. It returns the empty diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index 25b34addec7..53569bcc2fc 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -194,10 +194,11 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) { ncopy, err = io.Copy(w, t.Body) } else { ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength)) - nextra, err := io.Copy(ioutil.Discard, t.Body) if err != nil { return err } + var nextra int64 + nextra, err = io.Copy(ioutil.Discard, t.Body) ncopy += nextra } if err != nil { @@ -208,7 +209,7 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) { } } - if t.ContentLength != -1 && t.ContentLength != ncopy { + if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy { return fmt.Errorf("http: Request.ContentLength=%d with Body length %d", t.ContentLength, ncopy) } @@ -326,9 +327,14 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // or close connection when finished, since multipart is not supported yet switch { case chunked(t.TransferEncoding): - t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} - case realLength >= 0: - // TODO: limit the Content-Length. This is an easy DoS vector. + if noBodyExpected(t.RequestMethod) { + t.Body = &body{Reader: eofReader, closing: t.Close} + } else { + t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} + } + case realLength == 0: + t.Body = &body{Reader: eofReader, closing: t.Close} + case realLength > 0: t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close} default: // realLength < 0, i.e. "Content-Length" not mentioned in header @@ -337,7 +343,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { t.Body = &body{Reader: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) - t.Body = &body{Reader: io.LimitReader(r, 0), closing: t.Close} + t.Body = &body{Reader: eofReader, closing: t.Close} } } @@ -449,13 +455,6 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, return 0, nil } - // Logic based on media type. The purpose of the following code is just - // to detect whether the unsupported "multipart/byteranges" is being - // used. A proper Content-Type parser is needed in the future. - if strings.Contains(strings.ToLower(header.get("Content-Type")), "multipart/byteranges") { - return -1, ErrNotSupported - } - // Body-EOF logic based on other methods (like closing, or chunked coding) return -1, nil } @@ -614,30 +613,26 @@ func (b *body) Close() error { if b.closed { return nil } - defer func() { - b.closed = true - }() - if b.hdr == nil && b.closing { + var err error + switch { + case b.hdr == nil && b.closing: // no trailer and closing the connection next. // no point in reading to EOF. - return nil - } - - // In a server request, don't continue reading from the client - // if we've already hit the maximum body size set by the - // handler. If this is set, that also means the TCP connection - // is about to be closed, so getting to the next HTTP request - // in the stream is not necessary. - if b.res != nil && b.res.requestBodyLimitHit { - return nil - } - - // Fully consume the body, which will also lead to us reading - // the trailer headers after the body, if present. - if _, err := io.Copy(ioutil.Discard, b); err != nil { - return err + case b.res != nil && b.res.requestBodyLimitHit: + // In a server request, don't continue reading from the client + // if we've already hit the maximum body size set by the + // handler. If this is set, that also means the TCP connection + // is about to be closed, so getting to the next HTTP request + // in the stream is not necessary. + case b.Reader == eofReader: + // Nothing to read. No need to io.Copy from it. + default: + // Fully consume the body, which will also lead to us reading + // the trailer headers after the body, if present. + _, err = io.Copy(ioutil.Discard, b) } - return nil + b.closed = true + return err } // parseContentLength trims whitespace from s and returns -1 if no value diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 98e198e78af..4cd0533ffc2 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -17,7 +17,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net" "net/url" @@ -42,14 +41,13 @@ const DefaultMaxIdleConnsPerHost = 2 // https, and http proxies (for either http or https with CONNECT). // Transport can also cache connections for future re-use. type Transport struct { - idleLk sync.Mutex - idleConn map[string][]*persistConn - altLk sync.RWMutex - altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper - - // TODO: tunable on global max cached connections - // TODO: tunable on timeout on cached connections - // TODO: optional pipelining + idleMu sync.Mutex + idleConn map[string][]*persistConn + idleConnCh map[string]chan *persistConn + reqMu sync.Mutex + reqConn map[*Request]*persistConn + altMu sync.RWMutex + altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the @@ -60,19 +58,39 @@ type Transport struct { // Dial specifies the dial function for creating TCP // connections. // If Dial is nil, net.Dial is used. - Dial func(net, addr string) (c net.Conn, err error) + Dial func(network, addr string) (net.Conn, error) // TLSClientConfig specifies the TLS configuration to use with // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config - DisableKeepAlives bool + // DisableKeepAlives, if true, prevents re-use of TCP connections + // between different HTTP requests. + DisableKeepAlives bool + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. DisableCompression bool // MaxIdleConnsPerHost, if non-zero, controls the maximum idle // (keep-alive) to keep per-host. If zero, // DefaultMaxIdleConnsPerHost is used. MaxIdleConnsPerHost int + + // ResponseHeaderTimeout, if non-zero, specifies the amount of + // time to wait for a server's response headers after fully + // writing the request (including its body, if any). This + // time does not include the time to read the response body. + ResponseHeaderTimeout time.Duration + + // TODO: tunable on global max cached connections + // TODO: tunable on timeout on cached connections } // ProxyFromEnvironment returns the URL of the proxy to use for a @@ -125,6 +143,9 @@ func (tr *transportRequest) extraHeaders() Header { } // RoundTrip implements the RoundTripper interface. +// +// For higher-level HTTP client support (such as handling of cookies +// and redirects), see Get, Post, and the Client type. func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { if req.URL == nil { return nil, errors.New("http: nil Request.URL") @@ -133,12 +154,12 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { return nil, errors.New("http: nil Request.Header") } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { - t.altLk.RLock() + t.altMu.RLock() var rt RoundTripper if t.altProto != nil { rt = t.altProto[req.URL.Scheme] } - t.altLk.RUnlock() + t.altMu.RUnlock() if rt == nil { return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} } @@ -175,8 +196,8 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { if scheme == "http" || scheme == "https" { panic("protocol " + scheme + " already registered") } - t.altLk.Lock() - defer t.altLk.Unlock() + t.altMu.Lock() + defer t.altMu.Unlock() if t.altProto == nil { t.altProto = make(map[string]RoundTripper) } @@ -191,10 +212,10 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { // a "keep-alive" state. It does not interrupt any connections currently // in use. func (t *Transport) CloseIdleConnections() { - t.idleLk.Lock() + t.idleMu.Lock() m := t.idleConn t.idleConn = nil - t.idleLk.Unlock() + t.idleMu.Unlock() if m == nil { return } @@ -205,6 +226,17 @@ func (t *Transport) CloseIdleConnections() { } } +// CancelRequest cancels an in-flight request by closing its +// connection. +func (t *Transport) CancelRequest(req *Request) { + t.reqMu.Lock() + pc := t.reqConn[req] + t.reqMu.Unlock() + if pc != nil { + pc.conn.Close() + } +} + // // Private implementation past this point. // @@ -260,12 +292,23 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { if max == 0 { max = DefaultMaxIdleConnsPerHost } - t.idleLk.Lock() + t.idleMu.Lock() + select { + case t.idleConnCh[key] <- pconn: + // We're done with this pconn and somebody else is + // currently waiting for a conn of this type (they're + // actively dialing, but this conn is ready + // first). Chrome calls this socket late binding. See + // https://insouciant.org/tech/connection-management-in-chromium/ + t.idleMu.Unlock() + return true + default: + } if t.idleConn == nil { t.idleConn = make(map[string][]*persistConn) } if len(t.idleConn[key]) >= max { - t.idleLk.Unlock() + t.idleMu.Unlock() pconn.close() return false } @@ -275,14 +318,29 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { } } t.idleConn[key] = append(t.idleConn[key], pconn) - t.idleLk.Unlock() + t.idleMu.Unlock() return true } +func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn { + key := cm.key() + t.idleMu.Lock() + defer t.idleMu.Unlock() + if t.idleConnCh == nil { + t.idleConnCh = make(map[string]chan *persistConn) + } + ch, ok := t.idleConnCh[key] + if !ok { + ch = make(chan *persistConn) + t.idleConnCh[key] = ch + } + return ch +} + func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { - key := cm.String() - t.idleLk.Lock() - defer t.idleLk.Unlock() + key := cm.key() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return nil } @@ -304,7 +362,19 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { return } } - panic("unreachable") +} + +func (t *Transport) setReqConn(r *Request, pc *persistConn) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqConn == nil { + t.reqConn = make(map[*Request]*persistConn) + } + if pc != nil { + t.reqConn[r] = pc + } else { + delete(t.reqConn, r) + } } func (t *Transport) dial(network, addr string) (c net.Conn, err error) { @@ -323,6 +393,37 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { return pc, nil } + type dialRes struct { + pc *persistConn + err error + } + dialc := make(chan dialRes) + go func() { + pc, err := t.dialConn(cm) + dialc <- dialRes{pc, err} + }() + + idleConnCh := t.getIdleConnCh(cm) + select { + case v := <-dialc: + // Our dial finished. + return v.pc, v.err + case pc := <-idleConnCh: + // Another request finished first and its net.Conn + // became available before our dial. Or somebody + // else's dial that they didn't use. + // But our dial is still going, so give it away + // when it finishes: + go func() { + if v := <-dialc; v.err == nil { + t.putIdleConn(v.pc) + } + }() + return pc, nil + } +} + +func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) { conn, err := t.dial("tcp", cm.addr()) if err != nil { if cm.proxyURL != nil { @@ -335,7 +436,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { pconn := &persistConn{ t: t, - cacheKey: cm.String(), + cacheKey: cm.key(), conn: conn, reqch: make(chan requestAndChan, 50), writech: make(chan writeRequest, 50), @@ -485,6 +586,10 @@ type connectMethod struct { targetAddr string // Not used if proxy + http targetScheme (4th example in table) } +func (ck *connectMethod) key() string { + return ck.String() // TODO: use a struct type instead +} + func (ck *connectMethod) String() string { proxyStr := "" targetAddr := ck.targetAddr @@ -529,14 +634,13 @@ type persistConn struct { closech chan struct{} // broadcast close when readLoop (TCP connection) closes isProxy bool + lk sync.Mutex // guards following 3 fields + numExpectedResponses int + broken bool // an error has happened on this connection; marked broken so it's not reused. // mutateHeaderFunc is an optional func to modify extra // headers on each outbound request before it's written. (the // original Request given to RoundTrip is not modified) mutateHeaderFunc func(Header) - - lk sync.Mutex // guards numExpectedResponses and broken - numExpectedResponses int - broken bool // an error has happened on this connection; marked broken so it's not reused. } func (pc *persistConn) isBroken() bool { @@ -561,7 +665,6 @@ func remoteSideClosed(err error) bool { func (pc *persistConn) readLoop() { defer close(pc.closech) alive := true - var lastbody io.ReadCloser // last response body, if any, read on this connection for alive { pb, err := pc.br.Peek(1) @@ -580,22 +683,23 @@ func (pc *persistConn) readLoop() { rc := <-pc.reqch - // Advance past the previous response's body, if the - // caller hasn't done so. - if lastbody != nil { - lastbody.Close() // assumed idempotent - lastbody = nil - } - var resp *Response if err == nil { resp, err = ReadResponse(pc.br, rc.req) + if err == nil && resp.StatusCode == 100 { + // Skip any 100-continue for now. + // TODO(bradfitz): if rc.req had "Expect: 100-continue", + // actually block the request body write and signal the + // writeLoop now to begin sending it. (Issue 2184) For now we + // eat it, since we're never expecting one. + resp, err = ReadResponse(pc.br, rc.req) + } } + hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0 if err != nil { pc.close() } else { - hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 if rc.addedGzip && hasBody && resp.Header.Get("Content-Encoding") == "gzip" { resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") @@ -605,21 +709,29 @@ func (pc *persistConn) readLoop() { pc.close() err = zerr } else { - resp.Body = &readFirstCloseBoth{&discardOnCloseReadCloser{gzReader}, resp.Body} + resp.Body = &readerAndCloser{gzReader, resp.Body} } } resp.Body = &bodyEOFSignal{body: resp.Body} } - if err != nil || resp.Close || rc.req.Close { + if err != nil || resp.Close || rc.req.Close || resp.StatusCode <= 199 { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. alive = false } - hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0 var waitForBodyRead chan bool if hasBody { - lastbody = resp.Body - waitForBodyRead = make(chan bool, 1) + waitForBodyRead = make(chan bool, 2) + resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error { + // Sending false here sets alive to + // false and closes the connection + // below. + waitForBodyRead <- false + return nil + } resp.Body.(*bodyEOFSignal).fn = func(err error) { alive1 := alive if err != nil { @@ -636,15 +748,6 @@ func (pc *persistConn) readLoop() { } if alive && !hasBody { - // When there's no response body, we immediately - // reuse the TCP connection (putIdleConn), but - // we need to prevent ClientConn.Read from - // closing the Response.Body on the next - // loop, otherwise it might close the body - // before the client code has had a chance to - // read it (even though it'll just be 0, EOF). - lastbody = nil - if !pc.t.putIdleConn(pc) { alive = false } @@ -658,6 +761,8 @@ func (pc *persistConn) readLoop() { alive = <-waitForBodyRead } + pc.t.setReqConn(rc.req, nil) + if !alive { pc.close() } @@ -711,8 +816,14 @@ type writeRequest struct { } func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { - if pc.mutateHeaderFunc != nil { - pc.mutateHeaderFunc(req.extraHeaders()) + pc.t.setReqConn(req.Request, pc) + pc.lk.Lock() + pc.numExpectedResponses++ + headerFn := pc.mutateHeaderFunc + pc.lk.Unlock() + + if headerFn != nil { + headerFn(req.extraHeaders()) } // Ask for a compressed version if the caller didn't set their @@ -728,10 +839,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err req.extraHeaders().Set("Accept-Encoding", "gzip") } - pc.lk.Lock() - pc.numExpectedResponses++ - pc.lk.Unlock() - // Write the request concurrently with waiting for a response, // in case the server decides to reply before reading our full // request body. @@ -744,6 +851,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err var re responseAndError var pconnDeadCh = pc.closech var failTicker <-chan time.Time + var respHeaderTimer <-chan time.Time WaitResponse: for { select { @@ -753,6 +861,9 @@ WaitResponse: pc.close() break WaitResponse } + if d := pc.t.ResponseHeaderTimeout; d > 0 { + respHeaderTimer = time.After(d) + } case <-pconnDeadCh: // The persist connection is dead. This shouldn't // usually happen (only with Connection: close responses @@ -769,7 +880,11 @@ WaitResponse: pconnDeadCh = nil // avoid spinning failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc case <-failTicker: - re = responseAndError{nil, errors.New("net/http: transport closed before response was received")} + re = responseAndError{err: errors.New("net/http: transport closed before response was received")} + break WaitResponse + case <-respHeaderTimer: + pc.close() + re = responseAndError{err: errors.New("net/http: timeout awaiting response headers")} break WaitResponse case re = <-resc: break WaitResponse @@ -780,6 +895,9 @@ WaitResponse: pc.numExpectedResponses-- pc.lk.Unlock() + if re.err != nil { + pc.t.setReqConn(req.Request, nil) + } return re.res, re.err } @@ -823,13 +941,16 @@ func canonicalAddr(url *url.URL) string { // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most // once, right before its final (error-producing) Read or Close call -// returns. +// returns. If earlyCloseFn is non-nil and Close is called before +// io.EOF is seen, earlyCloseFn is called instead of fn, and its +// return value is the return value from Close. type bodyEOFSignal struct { - body io.ReadCloser - mu sync.Mutex // guards closed, rerr and fn - closed bool // whether Close has been called - rerr error // sticky Read error - fn func(error) // error will be nil on Read io.EOF + body io.ReadCloser + mu sync.Mutex // guards following 4 fields + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) // error will be nil on Read io.EOF + earlyCloseFn func() error // optional alt Close func used if io.EOF not seen } func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { @@ -862,6 +983,9 @@ func (es *bodyEOFSignal) Close() error { return nil } es.closed = true + if es.earlyCloseFn != nil && es.rerr != io.EOF { + return es.earlyCloseFn() + } err := es.body.Close() es.condfn(err) return err @@ -879,28 +1003,7 @@ func (es *bodyEOFSignal) condfn(err error) { es.fn = nil } -type readFirstCloseBoth struct { - io.ReadCloser +type readerAndCloser struct { + io.Reader io.Closer } - -func (r *readFirstCloseBoth) Close() error { - if err := r.ReadCloser.Close(); err != nil { - r.Closer.Close() - return err - } - if err := r.Closer.Close(); err != nil { - return err - } - return nil -} - -// discardOnCloseReadCloser consumes all its input on Close. -type discardOnCloseReadCloser struct { - io.ReadCloser -} - -func (d *discardOnCloseReadCloser) Close() error { - io.Copy(ioutil.Discard, d.ReadCloser) // ignore errors; likely invalid or already closed - return d.ReadCloser.Close() -} diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index daaecae341b..9f64a6e4b5f 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -7,6 +7,7 @@ package http_test import ( + "bufio" "bytes" "compress/gzip" "crypto/rand" @@ -102,11 +103,13 @@ func (tcs *testConnSet) check(t *testing.T) { // Two subsequent requests and verify their response is the same. // The response from the server is our own IP:port func TestTransportKeepAlives(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() for _, disableKeepAlive := range []bool{false, true} { tr := &Transport{DisableKeepAlives: disableKeepAlive} + defer tr.CloseIdleConnections() c := &Client{Transport: tr} fetch := func(n int) string { @@ -133,6 +136,7 @@ func TestTransportKeepAlives(t *testing.T) { } func TestTransportConnectionCloseOnResponse(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -183,6 +187,7 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { } func TestTransportConnectionCloseOnRequest(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -233,6 +238,7 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { } func TestTransportIdleCacheKeys(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -265,6 +271,7 @@ func TestTransportIdleCacheKeys(t *testing.T) { } func TestTransportMaxPerHostIdleConns(t *testing.T) { + defer afterTest(t) resch := make(chan string) gotReq := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -333,6 +340,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } func TestTransportServerClosingUnexpectedly(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -389,6 +397,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { // Test for http://golang.org/issue/2616 (appropriate issue number) // This fails pretty reliably with GOMAXPROCS=100 or something high. func TestStressSurpriseServerCloses(t *testing.T) { + defer afterTest(t) if testing.Short() { t.Skip("skipping test in short mode") } @@ -444,6 +453,7 @@ func TestStressSurpriseServerCloses(t *testing.T) { // TestTransportHeadResponses verifies that we deal with Content-Lengths // with no bodies properly func TestTransportHeadResponses(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) @@ -472,6 +482,7 @@ func TestTransportHeadResponses(t *testing.T) { // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding // on responses to HEAD requests. func TestTransportHeadChunkedResponse(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) @@ -513,6 +524,7 @@ var roundTripTests = []struct { // Test that the modification made to the Request by the RoundTripper is cleaned up func TestRoundTripGzip(t *testing.T) { + defer afterTest(t) const responseBody = "test response body" ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { accept := req.Header.Get("Accept-Encoding") @@ -569,6 +581,7 @@ func TestRoundTripGzip(t *testing.T) { } func TestTransportGzip(t *testing.T) { + defer afterTest(t) const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -661,6 +674,7 @@ func TestTransportGzip(t *testing.T) { } func TestTransportProxy(t *testing.T) { + defer afterTest(t) ch := make(chan string, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ch <- "real server" @@ -689,6 +703,7 @@ func TestTransportProxy(t *testing.T) { // but checks that we don't recurse forever, and checks that // Content-Encoding is removed. func TestTransportGzipRecursive(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write(rgz) @@ -715,6 +730,7 @@ func TestTransportGzipRecursive(t *testing.T) { // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { + defer afterTest(t) gotReqCh := make(chan bool) unblockCh := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -780,6 +796,7 @@ func TestTransportPersistConnLeak(t *testing.T) { // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) defer ts.Close() @@ -818,6 +835,7 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) { // This used to crash; http://golang.org/issue/3266 func TestTransportIdleConnCrash(t *testing.T) { + defer afterTest(t) tr := &Transport{} c := &Client{Transport: tr} @@ -847,6 +865,7 @@ func TestTransportIdleConnCrash(t *testing.T) { // which sadly lacked a triggering test. The large response body made // the old race easier to trigger. func TestIssue3644(t *testing.T) { + defer afterTest(t) const numFoos = 5000 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") @@ -874,6 +893,7 @@ func TestIssue3644(t *testing.T) { // Test that a client receives a server's reply, even if the server doesn't read // the entire request body. func TestIssue3595(t *testing.T) { + defer afterTest(t) const deniedMsg = "sorry, denied." ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, deniedMsg, StatusUnauthorized) @@ -898,6 +918,7 @@ func TestIssue3595(t *testing.T) { // From http://golang.org/issue/4454 , // "client fails to handle requests with no body and chunked encoding" func TestChunkedNoContent(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) })) @@ -920,20 +941,38 @@ func TestChunkedNoContent(t *testing.T) { } func TestTransportConcurrency(t *testing.T) { - const maxProcs = 16 - const numReqs = 500 + defer afterTest(t) + maxProcs, numReqs := 16, 500 + if testing.Short() { + maxProcs, numReqs = 4, 50 + } defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%v", r.FormValue("echo")) })) defer ts.Close() - tr := &Transport{} + + var wg sync.WaitGroup + wg.Add(numReqs) + + tr := &Transport{ + Dial: func(netw, addr string) (c net.Conn, err error) { + // Due to the Transport's "socket late + // binding" (see idleConnCh in transport.go), + // the numReqs HTTP requests below can finish + // with a dial still outstanding. So count + // our dials as work too so the leak checker + // doesn't complain at us. + wg.Add(1) + defer wg.Done() + return net.Dial(netw, addr) + }, + } + defer tr.CloseIdleConnections() c := &Client{Transport: tr} reqs := make(chan string) defer close(reqs) - var wg sync.WaitGroup - wg.Add(numReqs) for i := 0; i < maxProcs*2; i++ { go func() { for req := range reqs { @@ -952,8 +991,8 @@ func TestTransportConcurrency(t *testing.T) { if string(all) != req { t.Errorf("body of req %s = %q; want %q", req, all, req) } - wg.Done() res.Body.Close() + wg.Done() } }() } @@ -964,12 +1003,14 @@ func TestTransportConcurrency(t *testing.T) { } func TestIssue4191_InfiniteGetTimeout(t *testing.T) { + defer afterTest(t) const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { io.Copy(w, neverEnding('a')) }) ts := httptest.NewServer(mux) + timeout := 100 * time.Millisecond client := &Client{ Transport: &Transport{ @@ -978,7 +1019,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { if err != nil { return nil, err } - conn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + conn.SetDeadline(time.Now().Add(timeout)) if debug { conn = NewLoggingConn("client", conn) } @@ -988,6 +1029,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { }, } + getFailed := false nRuns := 5 if testing.Short() { nRuns = 1 @@ -998,6 +1040,14 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { } sres, err := client.Get(ts.URL + "/get") if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } t.Errorf("Error issuing GET: %v", err) break } @@ -1014,6 +1064,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { + defer afterTest(t) const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { @@ -1024,6 +1075,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { io.Copy(ioutil.Discard, r.Body) }) ts := httptest.NewServer(mux) + timeout := 100 * time.Millisecond client := &Client{ Transport: &Transport{ @@ -1032,7 +1084,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { if err != nil { return nil, err } - conn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + conn.SetDeadline(time.Now().Add(timeout)) if debug { conn = NewLoggingConn("client", conn) } @@ -1042,6 +1094,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { }, } + getFailed := false nRuns := 5 if testing.Short() { nRuns = 1 @@ -1052,6 +1105,14 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { } sres, err := client.Get(ts.URL + "/get") if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } t.Errorf("Error issuing GET: %v", err) break } @@ -1070,6 +1131,171 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { ts.Close() } +func TestTransportResponseHeaderTimeout(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping timeout test in -short mode") + } + mux := NewServeMux() + mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {}) + mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { + time.Sleep(2 * time.Second) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + tr := &Transport{ + ResponseHeaderTimeout: 500 * time.Millisecond, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + tests := []struct { + path string + want int + wantErr string + }{ + {path: "/fast", want: 200}, + {path: "/slow", wantErr: "timeout awaiting response headers"}, + {path: "/fast", want: 200}, + } + for i, tt := range tests { + res, err := c.Get(ts.URL + tt.path) + if err != nil { + if strings.Contains(err.Error(), tt.wantErr) { + continue + } + t.Errorf("%d. unexpected error: %v", i, err) + continue + } + if tt.wantErr != "" { + t.Errorf("%d. no error. expected error: %v", i, tt.wantErr) + continue + } + if res.StatusCode != tt.want { + t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want) + } + } +} + +func TestTransportCancelRequest(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello") + w.(Flusher).Flush() // send headers and some body + <-unblockc + })) + defer ts.Close() + defer close(unblockc) + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + req, _ := NewRequest("GET", ts.URL, nil) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + go func() { + time.Sleep(1 * time.Second) + tr.CancelRequest(req) + }() + t0 := time.Now() + body, err := ioutil.ReadAll(res.Body) + d := time.Since(t0) + + if err == nil { + t.Error("expected an error reading the body") + } + if string(body) != "Hello" { + t.Errorf("Body = %q; want Hello", body) + } + if d < 500*time.Millisecond { + t.Errorf("expected ~1 second delay; got %v", d) + } + // Verify no outstanding requests after readLoop/writeLoop + // goroutines shut down. + for tries := 3; tries > 0; tries-- { + n := tr.NumPendingRequestsForTesting() + if n == 0 { + break + } + time.Sleep(100 * time.Millisecond) + if tries == 1 { + t.Errorf("pending requests = %d; want 0", n) + } + } +} + +// golang.org/issue/3672 -- Client can't close HTTP stream +// Calling Close on a Response.Body used to just read until EOF. +// Now it actually closes the TCP connection. +func TestTransportCloseResponseBody(t *testing.T) { + defer afterTest(t) + writeErr := make(chan error, 1) + msg := []byte("young\n") + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + for { + _, err := w.Write(msg) + if err != nil { + writeErr <- err + return + } + w.(Flusher).Flush() + } + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + req, _ := NewRequest("GET", ts.URL, nil) + defer tr.CancelRequest(req) + + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + + const repeats = 3 + buf := make([]byte, len(msg)*repeats) + want := bytes.Repeat(msg, repeats) + + _, err = io.ReadFull(res.Body, buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, want) { + t.Errorf("read %q; want %q", buf, want) + } + didClose := make(chan error, 1) + go func() { + didClose <- res.Body.Close() + }() + select { + case err := <-didClose: + if err != nil { + t.Errorf("Close = %v", err) + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for close") + } + select { + case err := <-writeErr: + if err == nil { + t.Errorf("expected non-nil write error") + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for write error") + } +} + type fooProto struct{} func (fooProto) RoundTrip(req *Request) (*Response, error) { @@ -1083,6 +1309,7 @@ func (fooProto) RoundTrip(req *Request) (*Response, error) { } func TestTransportAltProto(t *testing.T) { + defer afterTest(t) tr := &Transport{} c := &Client{Transport: tr} tr.RegisterProtocol("foo", fooProto{}) @@ -1101,6 +1328,7 @@ func TestTransportAltProto(t *testing.T) { } func TestTransportNoHost(t *testing.T) { + defer afterTest(t) tr := &Transport{} _, err := tr.RoundTrip(&Request{ Header: make(Header), @@ -1114,6 +1342,172 @@ func TestTransportNoHost(t *testing.T) { } } +func TestTransportSocketLateBinding(t *testing.T) { + defer afterTest(t) + + mux := NewServeMux() + fooGate := make(chan bool, 1) + mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) { + w.Header().Set("foo-ipport", r.RemoteAddr) + w.(Flusher).Flush() + <-fooGate + }) + mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) { + w.Header().Set("bar-ipport", r.RemoteAddr) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + dialGate := make(chan bool, 1) + tr := &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + <-dialGate + return net.Dial(n, addr) + }, + DisableKeepAlives: false, + } + defer tr.CloseIdleConnections() + c := &Client{ + Transport: tr, + } + + dialGate <- true // only allow one dial + fooRes, err := c.Get(ts.URL + "/foo") + if err != nil { + t.Fatal(err) + } + fooAddr := fooRes.Header.Get("foo-ipport") + if fooAddr == "" { + t.Fatal("No addr on /foo request") + } + time.AfterFunc(200*time.Millisecond, func() { + // let the foo response finish so we can use its + // connection for /bar + fooGate <- true + io.Copy(ioutil.Discard, fooRes.Body) + fooRes.Body.Close() + }) + + barRes, err := c.Get(ts.URL + "/bar") + if err != nil { + t.Fatal(err) + } + barAddr := barRes.Header.Get("bar-ipport") + if barAddr != fooAddr { + t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) + } + barRes.Body.Close() + dialGate <- true +} + +// Issue 2184 +func TestTransportReading100Continue(t *testing.T) { + defer afterTest(t) + + const numReqs = 5 + reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) } + reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) } + + send100Response := func(w *io.PipeWriter, r *io.PipeReader) { + defer w.Close() + defer r.Close() + br := bufio.NewReader(r) + n := 0 + for { + n++ + req, err := ReadRequest(br) + if err == io.EOF { + return + } + if err != nil { + t.Error(err) + return + } + slurp, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Errorf("Server request body slurp: %v", err) + return + } + id := req.Header.Get("Request-Id") + resCode := req.Header.Get("X-Want-Response-Code") + if resCode == "" { + resCode = "100 Continue" + if string(slurp) != reqBody(n) { + t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n)) + } + } + body := fmt.Sprintf("Response number %d", n) + v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s +Date: Thu, 28 Feb 2013 17:55:41 GMT + +HTTP/1.1 200 OK +Content-Type: text/html +Echo-Request-Id: %s +Content-Length: %d + +%s`, resCode, id, len(body), body), "\n", "\r\n", -1)) + w.Write(v) + if id == reqID(numReqs) { + return + } + } + + } + + tr := &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + sr, sw := io.Pipe() // server read/write + cr, cw := io.Pipe() // client read/write + conn := &rwTestConn{ + Reader: cr, + Writer: sw, + closeFunc: func() error { + sw.Close() + cw.Close() + return nil + }, + } + go send100Response(cw, sr) + return conn, nil + }, + DisableKeepAlives: false, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + testResponse := func(req *Request, name string, wantCode int) { + res, err := c.Do(req) + if err != nil { + t.Fatalf("%s: Do: %v", name, err) + } + if res.StatusCode != wantCode { + t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode) + } + if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { + t.Errorf("%s: response id %q != request id %q", name, idBack, id) + } + _, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("%s: Slurp error: %v", name, err) + } + } + + // Few 100 responses, making sure we're not off-by-one. + for i := 1; i <= numReqs; i++ { + req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i))) + req.Header.Set("Request-Id", reqID(i)) + testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200) + } + + // And some other informational 1xx but non-100 responses, to test + // we return them but don't re-use the connection. + for i := 1; i <= numReqs; i++ { + req, _ := NewRequest("POST", "http://other.tld/", strings.NewReader(reqBody(i))) + req.Header.Set("X-Want-Response-Code", "123 Sesame Street") + testResponse(req, fmt.Sprintf("123, %d/%d", i, numReqs), 123) + } +} + type proxyFromEnvTest struct { req string // URL to fetch; blank means "http://example.com" env string diff --git a/libgo/go/net/http/z_last_test.go b/libgo/go/net/http/z_last_test.go new file mode 100644 index 00000000000..2161db7365d --- /dev/null +++ b/libgo/go/net/http/z_last_test.go @@ -0,0 +1,98 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "net/http" + "runtime" + "sort" + "strings" + "testing" + "time" +) + +func interestingGoroutines() (gs []string) { + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + for _, g := range strings.Split(string(buf), "\n\n") { + sl := strings.SplitN(g, "\n", 2) + if len(sl) != 2 { + continue + } + stack := strings.TrimSpace(sl[1]) + if stack == "" || + strings.Contains(stack, "created by net.newPollServer") || + strings.Contains(stack, "created by net.startServer") || + strings.Contains(stack, "created by testing.RunTests") || + strings.Contains(stack, "closeWriteAndWait") || + strings.Contains(stack, "testing.Main(") || + // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) + strings.Contains(stack, "runtime.goexit") || + strings.Contains(stack, "created by runtime.gc") || + strings.Contains(stack, "runtime.MHeap_Scavenger") { + continue + } + gs = append(gs, stack) + } + sort.Strings(gs) + return +} + +// Verify the other tests didn't leave any goroutines running. +// This is in a file named z_last_test.go so it sorts at the end. +func TestGoroutinesRunning(t *testing.T) { + if testing.Short() { + t.Skip("not counting goroutines for leakage in -short mode") + } + gs := interestingGoroutines() + + n := 0 + stackCount := make(map[string]int) + for _, g := range gs { + stackCount[g]++ + n++ + } + + t.Logf("num goroutines = %d", n) + if n > 0 { + t.Error("Too many goroutines.") + for stack, count := range stackCount { + t.Logf("%d instances of:\n%s", count, stack) + } + } +} + +func afterTest(t *testing.T) { + http.DefaultTransport.(*http.Transport).CloseIdleConnections() + if testing.Short() { + return + } + var bad string + badSubstring := map[string]string{ + ").readLoop(": "a Transport", + ").writeLoop(": "a Transport", + "created by net/http/httptest.(*Server).Start": "an httptest.Server", + "timeoutHandler": "a TimeoutHandler", + "net.(*netFD).connect(": "a timing out dial", + ").noteClientGone(": "a closenotifier sender", + } + var stacks string + for i := 0; i < 4; i++ { + bad = "" + stacks = strings.Join(interestingGoroutines(), "\n\n") + for substr, what := range badSubstring { + if strings.Contains(stacks, substr) { + bad = what + } + } + if bad == "" { + return + } + // Bad stuff found, but goroutines might just still be + // shutting down, so give it some time. + time.Sleep(250 * time.Millisecond) + } + t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks) +} |