diff options
author | bstarynk <bstarynk@138bc75d-0d04-0410-961f-82ee72b054a4> | 2012-10-25 08:02:28 +0000 |
---|---|---|
committer | bstarynk <bstarynk@138bc75d-0d04-0410-961f-82ee72b054a4> | 2012-10-25 08:02:28 +0000 |
commit | f9a64dbd998f7761e6a06fc71052346d7f76c7f4 (patch) | |
tree | 3608e9a4fa99bbcc7d88dda34b1619a4ac4b122b /libgo/go/net/http | |
parent | 29a742dc2ec93b766a342fa6fb65da055c5417fc (diff) | |
download | gcc-f9a64dbd998f7761e6a06fc71052346d7f76c7f4.tar.gz |
2012-10-25 Basile Starynkevitch <basile@starynkevitch.net>
MELT branch merged with trunk rev 192797 using svnmerge.py
git-svn-id: svn+ssh://gcc.gnu.org/svn/gcc/branches/melt-branch@192798 138bc75d-0d04-0410-961f-82ee72b054a4
Diffstat (limited to 'libgo/go/net/http')
24 files changed, 1521 insertions, 445 deletions
diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 89441424e1d..02891db9adc 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -33,10 +33,11 @@ type Client struct { // CheckRedirect specifies the policy for handling redirects. // If CheckRedirect is not nil, the client calls it before - // following an HTTP redirect. The arguments req and via - // are the upcoming request and the requests made already, - // oldest first. If CheckRedirect returns an error, the client - // returns that error (wrapped in a url.Error) instead of + // following an HTTP redirect. The arguments req and via are + // the upcoming request and the requests made already, oldest + // first. If CheckRedirect returns an error, the Client's Get + // method returns both the previous Response and + // CheckRedirect's error (wrapped in a url.Error) instead of // issuing the Request req. // // If CheckRedirect is nil, the Client uses its default policy, @@ -95,7 +96,7 @@ type readClose struct { // // When err is nil, resp always contains a non-nil resp.Body. // -// Callers should close res.Body when done reading from it. If +// Callers should close resp.Body when done reading from it. If // resp.Body is not closed, the Client's underlying RoundTripper // (typically Transport) may not be able to re-use a persistent TCP // connection to the server for a subsequent "keep-alive" request. @@ -221,6 +222,7 @@ func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error) req := ireq urlStr := "" // next relative or absolute URL to fetch (after first request) + redirectFailed := false for redirect := 0; ; redirect++ { if redirect != 0 { req = new(Request) @@ -239,6 +241,7 @@ func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error) err = redirectChecker(req, via) if err != nil { + redirectFailed = true break } } @@ -268,16 +271,24 @@ func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error) return } - if resp != nil { - resp.Body.Close() - } - method := ireq.Method - return nil, &url.Error{ + urlErr := &url.Error{ Op: method[0:1] + strings.ToLower(method[1:]), URL: urlStr, Err: err, } + + if redirectFailed { + // Special case for Go 1 compatibility: return both the response + // and an error if the CheckRedirect function failed. + // See http://golang.org/issue/3795 + return resp, urlErr + } + + if resp != nil { + resp.Body.Close() + } + return nil, urlErr } func defaultCheckRedirect(req *Request, via []*Request) error { diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 09fcc1c0b40..c61b17d289b 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -235,6 +235,12 @@ func TestRedirects(t *testing.T) { if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr { t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err) } + if res == nil { + t.Fatalf("Expected a non-nil Response on CheckRedirect failure (http://golang.org/issue/3795)") + } + if res.Header.Get("Location") == "" { + t.Errorf("no Location header in Response") + } } var expectedCookies = []*Cookie{ diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go index 2e30bbff177..43f519d1fba 100644 --- a/libgo/go/net/http/cookie.go +++ b/libgo/go/net/http/cookie.go @@ -258,10 +258,5 @@ func parseCookieValueUsing(raw string, validByte func(byte) bool) (string, bool) } func isCookieNameValid(raw string) bool { - for _, c := range raw { - if !isToken(byte(c)) { - return false - } - } - return true + return strings.IndexFunc(raw, isNotToken) < 0 } diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index 208d6cabb2c..b6bea0dfaad 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -100,6 +100,9 @@ func dirList(w ResponseWriter, f File) { // The content's Seek method must work: ServeContent uses // a seek to the end of the content to determine its size. // +// If the caller has set w's ETag header, ServeContent uses it to +// handle requests using If-Range and If-None-Match. +// // Note that *os.File implements the io.ReadSeeker interface. func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) { size, err := content.Seek(0, os.SEEK_END) @@ -122,6 +125,10 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, if checkLastModified(w, r, modtime) { return } + rangeReq, done := checkETag(w, r) + if done { + return + } code := StatusOK @@ -148,7 +155,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, sendSize := size var sendContent io.Reader = content if size >= 0 { - ranges, err := parseRange(r.Header.Get("Range"), size) + ranges, err := parseRange(rangeReq, size) if err != nil { Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) return @@ -240,6 +247,9 @@ func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool { // The Date-Modified header truncates sub-second precision, so // use mtime < t+1s instead of mtime <= t to check for unmodified. if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modtime.Before(t.Add(1*time.Second)) { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") w.WriteHeader(StatusNotModified) return true } @@ -247,6 +257,58 @@ func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool { return false } +// checkETag implements If-None-Match and If-Range checks. +// The ETag must have been previously set in the ResponseWriter's headers. +// +// The return value is the effective request "Range" header to use and +// whether this request is now considered done. +func checkETag(w ResponseWriter, r *Request) (rangeReq string, done bool) { + etag := w.Header().get("Etag") + rangeReq = r.Header.get("Range") + + // Invalidate the range request if the entity doesn't match the one + // the client was expecting. + // "If-Range: version" means "ignore the Range: header unless version matches the + // current file." + // We only support ETag versions. + // The caller must have set the ETag on the response already. + if ir := r.Header.get("If-Range"); ir != "" && ir != etag { + // TODO(bradfitz): handle If-Range requests with Last-Modified + // times instead of ETags? I'd rather not, at least for + // now. That seems like a bug/compromise in the RFC 2616, and + // I've never heard of anybody caring about that (yet). + rangeReq = "" + } + + if inm := r.Header.get("If-None-Match"); inm != "" { + // Must know ETag. + if etag == "" { + return rangeReq, false + } + + // TODO(bradfitz): non-GET/HEAD requests require more work: + // sending a different status code on matches, and + // also can't use weak cache validators (those with a "W/ + // prefix). But most users of ServeContent will be using + // it on GET or HEAD, so only support those for now. + if r.Method != "GET" && r.Method != "HEAD" { + return rangeReq, false + } + + // TODO(bradfitz): deal with comma-separated or multiple-valued + // list of If-None-match values. For now just handle the common + // case of a single item. + if inm == etag || inm == "*" { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + w.WriteHeader(StatusNotModified) + return "", true + } + } + return rangeReq, false +} + // name is '/'-separated, not filepath.Separator. func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) { const indexPage = "/index.html" diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index 17329fbd59a..09d5cfaf2d4 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -81,6 +81,7 @@ func TestServeFile(t *testing.T) { } // Range tests +Cases: for _, rt := range ServeFileRangeTests { if rt.r != "" { req.Header.Set("Range", rt.r) @@ -109,7 +110,7 @@ func TestServeFile(t *testing.T) { t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody) } if strings.HasPrefix(ct, "multipart/byteranges") { - t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r) + t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r, ct) } } if len(rt.ranges) > 1 { @@ -119,37 +120,41 @@ func TestServeFile(t *testing.T) { continue } if typ != "multipart/byteranges" { - t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r) + t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r, typ) continue } if params["boundary"] == "" { t.Errorf("range=%q content-type = %q; lacks boundary", rt.r, ct) + continue } if g, w := resp.ContentLength, int64(len(body)); g != w { t.Errorf("range=%q Content-Length = %d; want %d", rt.r, g, w) + continue } mr := multipart.NewReader(bytes.NewReader(body), params["boundary"]) for ri, rng := range rt.ranges { part, err := mr.NextPart() if err != nil { - t.Fatalf("range=%q, reading part index %d: %v", rt.r, ri, err) + t.Errorf("range=%q, reading part index %d: %v", rt.r, ri, err) + continue Cases + } + wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen) + if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w { + t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w) } body, err := ioutil.ReadAll(part) if err != nil { - t.Fatalf("range=%q, reading part index %d body: %v", rt.r, ri, err) + t.Errorf("range=%q, reading part index %d body: %v", rt.r, ri, err) + continue Cases } - wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen) wantBody := file[rng.start:rng.end] if !bytes.Equal(body, wantBody) { t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody) } - if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w { - t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w) - } } _, err = mr.NextPart() if err != io.EOF { - t.Errorf("range=%q; expected final error io.EOF; got %v", err) + t.Errorf("range=%q; expected final error io.EOF; got %v", rt.r, err) } } } @@ -335,11 +340,6 @@ func TestServeFileMimeType(t *testing.T) { } func TestServeFileFromCWD(t *testing.T) { - if runtime.GOOS == "windows" { - // TODO(brainman): find out why this test is broken - t.Logf("Temporarily skipping test on Windows; see http://golang.org/issue/3917") - return - } ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "fs_test.go") })) @@ -348,6 +348,7 @@ func TestServeFileFromCWD(t *testing.T) { if err != nil { t.Fatal(err) } + r.Body.Close() if r.StatusCode != 200 { t.Fatalf("expected 200 OK, got %s", r.Status) } @@ -522,51 +523,140 @@ func TestDirectoryIfNotModified(t *testing.T) { res.Body.Close() } -func TestServeContent(t *testing.T) { - type req struct { - name string - modtime time.Time - content io.ReadSeeker +func mustStat(t *testing.T, fileName string) os.FileInfo { + fi, err := os.Stat(fileName) + if err != nil { + t.Fatal(err) } - ch := make(chan req, 1) + return fi +} + +func TestServeContent(t *testing.T) { + type serveParam struct { + name string + modtime time.Time + content io.ReadSeeker + contentType string + etag string + } + servec := make(chan serveParam, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - p := <-ch + p := <-servec + if p.etag != "" { + w.Header().Set("ETag", p.etag) + } + if p.contentType != "" { + w.Header().Set("Content-Type", p.contentType) + } ServeContent(w, r, p.name, p.modtime, p.content) })) defer ts.Close() - css, err := os.Open("testdata/style.css") - if err != nil { - t.Fatal(err) - } - defer css.Close() - - ch <- req{"style.css", time.Time{}, css} - res, err := Get(ts.URL) - if err != nil { - t.Fatal(err) - } - if g, e := res.Header.Get("Content-Type"), "text/css; charset=utf-8"; g != e { - t.Errorf("style.css: content type = %q, want %q", g, e) - } - if g := res.Header.Get("Last-Modified"); g != "" { - t.Errorf("want empty Last-Modified; got %q", g) + type testCase struct { + file string + modtime time.Time + serveETag string // optional + serveContentType string // optional + reqHeader map[string]string + wantLastMod string + wantContentType string + wantStatus int + } + htmlModTime := mustStat(t, "testdata/index.html").ModTime() + tests := map[string]testCase{ + "no_last_modified": { + file: "testdata/style.css", + wantContentType: "text/css; charset=utf-8", + wantStatus: 200, + }, + "with_last_modified": { + file: "testdata/index.html", + wantContentType: "text/html; charset=utf-8", + modtime: htmlModTime, + wantLastMod: htmlModTime.UTC().Format(TimeFormat), + wantStatus: 200, + }, + "not_modified_modtime": { + file: "testdata/style.css", + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 304, + }, + "not_modified_modtime_with_contenttype": { + file: "testdata/style.css", + serveContentType: "text/css", // explicit content type + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 304, + }, + "not_modified_etag": { + file: "testdata/style.css", + serveETag: `"foo"`, + reqHeader: map[string]string{ + "If-None-Match": `"foo"`, + }, + wantStatus: 304, + }, + "range_good": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + }, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + }, + // An If-Range resource for entity "A", but entity "B" is now current. + // The Range request should be ignored. + "range_no_match": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": `"B"`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, } + for testName, tt := range tests { + f, err := os.Open(tt.file) + if err != nil { + t.Fatalf("test %q: %v", testName, err) + } + defer f.Close() - fi, err := css.Stat() - if err != nil { - t.Fatal(err) - } - ch <- req{"style.html", fi.ModTime(), css} - res, err = Get(ts.URL) - if err != nil { - t.Fatal(err) - } - if g, e := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != e { - t.Errorf("style.html: content type = %q, want %q", g, e) - } - if g := res.Header.Get("Last-Modified"); g == "" { - t.Errorf("want non-empty last-modified") + servec <- serveParam{ + name: filepath.Base(tt.file), + content: f, + modtime: tt.modtime, + etag: tt.serveETag, + contentType: tt.serveContentType, + } + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + for k, v := range tt.reqHeader { + req.Header.Set(k, v) + } + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != tt.wantStatus { + t.Errorf("test %q: status = %d; want %d", testName, res.StatusCode, tt.wantStatus) + } + if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e { + t.Errorf("test %q: content-type = %q, want %q", testName, g, e) + } + if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e { + t.Errorf("test %q: last-modified = %q, want %q", testName, g, e) + } } } diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index 6be94f98e74..91417366ae8 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -5,11 +5,11 @@ package http import ( - "fmt" "io" "net/textproto" "sort" "strings" + "time" ) // A Header represents the key-value pairs in an HTTP header. @@ -36,6 +36,14 @@ func (h Header) Get(key string) string { return textproto.MIMEHeader(h).Get(key) } +// get is like Get, but key must already be in CanonicalHeaderKey form. +func (h Header) get(key string) string { + if v := h[key]; len(v) > 0 { + return v[0] + } + return "" +} + // Del deletes the values associated with key. func (h Header) Del(key string) { textproto.MIMEHeader(h).Del(key) @@ -46,24 +54,77 @@ func (h Header) Write(w io.Writer) error { return h.WriteSubset(w, nil) } +var timeFormats = []string{ + TimeFormat, + time.RFC850, + time.ANSIC, +} + +// ParseTime parses a time header (such as the Date: header), +// trying each of the three formats allowed by HTTP/1.1: +// TimeFormat, time.RFC850, and time.ANSIC. +func ParseTime(text string) (t time.Time, err error) { + for _, layout := range timeFormats { + t, err = time.Parse(layout, text) + if err == nil { + return + } + } + return +} + var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") +type writeStringer interface { + WriteString(string) (int, error) +} + +// stringWriter implements WriteString on a Writer. +type stringWriter struct { + w io.Writer +} + +func (w stringWriter) WriteString(s string) (n int, err error) { + return w.w.Write([]byte(s)) +} + +type keyValues struct { + key string + values []string +} + +type byKey []keyValues + +func (s byKey) Len() int { return len(s) } +func (s byKey) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s byKey) Less(i, j int) bool { return s[i].key < s[j].key } + +func (h Header) sortedKeyValues(exclude map[string]bool) []keyValues { + kvs := make([]keyValues, 0, len(h)) + for k, vv := range h { + if !exclude[k] { + kvs = append(kvs, keyValues{k, vv}) + } + } + sort.Sort(byKey(kvs)) + return kvs +} + // WriteSubset writes a header in wire format. // If exclude is not nil, keys where exclude[key] == true are not written. func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { - keys := make([]string, 0, len(h)) - for k := range h { - if exclude == nil || !exclude[k] { - keys = append(keys, k) - } + ws, ok := w.(writeStringer) + if !ok { + ws = stringWriter{w} } - sort.Strings(keys) - for _, k := range keys { - for _, v := range h[k] { + for _, kv := range h.sortedKeyValues(exclude) { + for _, v := range kv.values { v = headerNewlineToSpace.Replace(v) - v = strings.TrimSpace(v) - if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { - return err + v = textproto.TrimString(v) + for _, s := range []string{kv.key, ": ", v, "\r\n"} { + if _, err := ws.WriteString(s); err != nil { + return err + } } } } diff --git a/libgo/go/net/http/header_test.go b/libgo/go/net/http/header_test.go index ccdee8a97bd..fd971a61d05 100644 --- a/libgo/go/net/http/header_test.go +++ b/libgo/go/net/http/header_test.go @@ -6,7 +6,9 @@ package http import ( "bytes" + "runtime" "testing" + "time" ) var headerWriteTests = []struct { @@ -67,6 +69,24 @@ var headerWriteTests = []struct { nil, "Blank: \r\nDouble-Blank: \r\nDouble-Blank: \r\n", }, + // Tests header sorting when over the insertion sort threshold side: + { + Header{ + "k1": {"1a", "1b"}, + "k2": {"2a", "2b"}, + "k3": {"3a", "3b"}, + "k4": {"4a", "4b"}, + "k5": {"5a", "5b"}, + "k6": {"6a", "6b"}, + "k7": {"7a", "7b"}, + "k8": {"8a", "8b"}, + "k9": {"9a", "9b"}, + }, + map[string]bool{"k5": true}, + "k1: 1a\r\nk1: 1b\r\nk2: 2a\r\nk2: 2b\r\nk3: 3a\r\nk3: 3b\r\n" + + "k4: 4a\r\nk4: 4b\r\nk6: 6a\r\nk6: 6b\r\n" + + "k7: 7a\r\nk7: 7b\r\nk8: 8a\r\nk8: 8b\r\nk9: 9a\r\nk9: 9b\r\n", + }, } func TestHeaderWrite(t *testing.T) { @@ -79,3 +99,113 @@ func TestHeaderWrite(t *testing.T) { buf.Reset() } } + +var parseTimeTests = []struct { + h Header + err bool +}{ + {Header{"Date": {""}}, true}, + {Header{"Date": {"invalid"}}, true}, + {Header{"Date": {"1994-11-06T08:49:37Z00:00"}}, true}, + {Header{"Date": {"Sun, 06 Nov 1994 08:49:37 GMT"}}, false}, + {Header{"Date": {"Sunday, 06-Nov-94 08:49:37 GMT"}}, false}, + {Header{"Date": {"Sun Nov 6 08:49:37 1994"}}, false}, +} + +func TestParseTime(t *testing.T) { + expect := time.Date(1994, 11, 6, 8, 49, 37, 0, time.UTC) + for i, test := range parseTimeTests { + d, err := ParseTime(test.h.Get("Date")) + if err != nil { + if !test.err { + t.Errorf("#%d:\n got err: %v", i, err) + } + continue + } + if test.err { + t.Errorf("#%d:\n should err", i) + continue + } + if !expect.Equal(d) { + t.Errorf("#%d:\n got: %v\nwant: %v", i, d, expect) + } + } +} + +type hasTokenTest struct { + header string + token string + want bool +} + +var hasTokenTests = []hasTokenTest{ + {"", "", false}, + {"", "foo", false}, + {"foo", "foo", true}, + {"foo ", "foo", true}, + {" foo", "foo", true}, + {" foo ", "foo", true}, + {"foo,bar", "foo", true}, + {"bar,foo", "foo", true}, + {"bar, foo", "foo", true}, + {"bar,foo, baz", "foo", true}, + {"bar, foo,baz", "foo", true}, + {"bar,foo, baz", "foo", true}, + {"bar, foo, baz", "foo", true}, + {"FOO", "foo", true}, + {"FOO ", "foo", true}, + {" FOO", "foo", true}, + {" FOO ", "foo", true}, + {"FOO,BAR", "foo", true}, + {"BAR,FOO", "foo", true}, + {"BAR, FOO", "foo", true}, + {"BAR,FOO, baz", "foo", true}, + {"BAR, FOO,BAZ", "foo", true}, + {"BAR,FOO, BAZ", "foo", true}, + {"BAR, FOO, BAZ", "foo", true}, + {"foobar", "foo", false}, + {"barfoo ", "foo", false}, +} + +func TestHasToken(t *testing.T) { + for _, tt := range hasTokenTests { + if hasToken(tt.header, tt.token) != tt.want { + t.Errorf("hasToken(%q, %q) = %v; want %v", tt.header, tt.token, !tt.want, tt.want) + } + } +} + +func BenchmarkHeaderWriteSubset(b *testing.B) { + doHeaderWriteSubset(b.N, b) +} + +func TestHeaderWriteSubsetMallocs(t *testing.T) { + doHeaderWriteSubset(100, t) +} + +type errorfer interface { + Errorf(string, ...interface{}) +} + +func doHeaderWriteSubset(n int, t errorfer) { + h := Header(map[string][]string{ + "Content-Length": {"123"}, + "Content-Type": {"text/plain"}, + "Date": {"some date at some time Z"}, + "Server": {"Go http package"}, + }) + var buf bytes.Buffer + var m0 runtime.MemStats + runtime.ReadMemStats(&m0) + for i := 0; i < n; i++ { + buf.Reset() + h.WriteSubset(&buf, nil) + } + var m1 runtime.MemStats + runtime.ReadMemStats(&m1) + if mallocs := m1.Mallocs - m0.Mallocs; n >= 100 && mallocs >= uint64(n) { + // TODO(bradfitz,rsc): once we can sort with allocating, + // make this an error. See http://golang.org/issue/3761 + // t.Errorf("did %d mallocs (>= %d iterations); should have avoided mallocs", mallocs, n) + } +} diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go index 9aa0d510bd4..5451f54234c 100644 --- a/libgo/go/net/http/httptest/recorder.go +++ b/libgo/go/net/http/httptest/recorder.go @@ -17,6 +17,8 @@ type ResponseRecorder struct { HeaderMap http.Header // the HTTP response headers Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to Flushed bool + + wroteHeader bool } // NewRecorder returns an initialized ResponseRecorder. @@ -24,6 +26,7 @@ func NewRecorder() *ResponseRecorder { return &ResponseRecorder{ HeaderMap: make(http.Header), Body: new(bytes.Buffer), + Code: 200, } } @@ -33,26 +36,37 @@ const DefaultRemoteAddr = "1.2.3.4" // Header returns the response headers. func (rw *ResponseRecorder) Header() http.Header { - return rw.HeaderMap + m := rw.HeaderMap + if m == nil { + m = make(http.Header) + rw.HeaderMap = m + } + return m } // Write always succeeds and writes to rw.Body, if not nil. func (rw *ResponseRecorder) Write(buf []byte) (int, error) { + if !rw.wroteHeader { + rw.WriteHeader(200) + } if rw.Body != nil { rw.Body.Write(buf) } - if rw.Code == 0 { - rw.Code = http.StatusOK - } return len(buf), nil } // WriteHeader sets rw.Code. func (rw *ResponseRecorder) WriteHeader(code int) { - rw.Code = code + if !rw.wroteHeader { + rw.Code = code + } + rw.wroteHeader = true } // Flush sets rw.Flushed to true. func (rw *ResponseRecorder) Flush() { + if !rw.wroteHeader { + rw.WriteHeader(200) + } rw.Flushed = true } diff --git a/libgo/go/net/http/httptest/recorder_test.go b/libgo/go/net/http/httptest/recorder_test.go new file mode 100644 index 00000000000..2b563260c76 --- /dev/null +++ b/libgo/go/net/http/httptest/recorder_test.go @@ -0,0 +1,90 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "fmt" + "net/http" + "testing" +) + +func TestRecorder(t *testing.T) { + type checkFunc func(*ResponseRecorder) error + check := func(fns ...checkFunc) []checkFunc { return fns } + + hasStatus := func(wantCode int) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Code != wantCode { + return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode) + } + return nil + } + } + hasContents := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Body.String() != want { + return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want) + } + return nil + } + } + hasFlush := func(want bool) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Flushed != want { + return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want) + } + return nil + } + } + + tests := []struct { + name string + h func(w http.ResponseWriter, r *http.Request) + checks []checkFunc + }{ + { + "200 default", + func(w http.ResponseWriter, r *http.Request) {}, + check(hasStatus(200), hasContents("")), + }, + { + "first code only", + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(201) + w.WriteHeader(202) + w.Write([]byte("hi")) + }, + check(hasStatus(201), hasContents("hi")), + }, + { + "write sends 200", + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi first")) + w.WriteHeader(201) + w.WriteHeader(202) + }, + check(hasStatus(200), hasContents("hi first"), hasFlush(false)), + }, + { + "flush", + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() // also sends a 200 + w.WriteHeader(201) + }, + check(hasStatus(200), hasFlush(true)), + }, + } + r, _ := http.NewRequest("GET", "http://foo.com/", nil) + for _, tt := range tests { + h := http.HandlerFunc(tt.h) + rec := NewRecorder() + h.ServeHTTP(rec, r) + for _, check := range tt.checks { + if err := check(rec); err != nil { + t.Errorf("%s: %v", tt.name, err) + } + } + } +} diff --git a/libgo/go/net/http/httputil/dump.go b/libgo/go/net/http/httputil/dump.go index 0fb2eeb8c00..0b003566165 100644 --- a/libgo/go/net/http/httputil/dump.go +++ b/libgo/go/net/http/httputil/dump.go @@ -75,7 +75,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { // Use the actual Transport code to record what we would send // on the wire, but not using TCP. Use a Transport with a - // customer dialer that returns a fake net.Conn that waits + // custom dialer that returns a fake net.Conn that waits // for the full input (and recording it), and then responds // with a dummy response. var buf bytes.Buffer // records the output diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 9c4bd6e09a5..134c452999d 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -17,6 +17,10 @@ import ( "time" ) +// onExitFlushLoop is a callback set by tests to detect the state of the +// flushLoop() goroutine. +var onExitFlushLoop func() + // ReverseProxy is an HTTP Handler that takes an incoming request and // sends it to another server, proxying the response back to the // client. @@ -102,8 +106,14 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.Header.Del("Connection") } - if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - outreq.Header.Set("X-Forwarded-For", clientIp) + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + outreq.Header.Set("X-Forwarded-For", clientIP) } res, err := transport.RoundTrip(outreq) @@ -112,20 +122,29 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusInternalServerError) return } + defer res.Body.Close() copyHeader(rw.Header(), res.Header) rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) +} - if res.Body != nil { - var dst io.Writer = rw - if p.FlushInterval != 0 { - if wf, ok := rw.(writeFlusher); ok { - dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval} +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { + if p.FlushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: p.FlushInterval, + done: make(chan bool), } + go mlw.flushLoop() + defer mlw.stop() + dst = mlw } - io.Copy(dst, res.Body) } + + io.Copy(dst, src) } type writeFlusher interface { @@ -137,22 +156,14 @@ type maxLatencyWriter struct { dst writeFlusher latency time.Duration - lk sync.Mutex // protects init of done, as well Write + Flush + lk sync.Mutex // protects Write + Flush done chan bool } -func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { +func (m *maxLatencyWriter) Write(p []byte) (int, error) { m.lk.Lock() defer m.lk.Unlock() - if m.done == nil { - m.done = make(chan bool) - go m.flushLoop() - } - n, err = m.dst.Write(p) - if err != nil { - m.done <- true - } - return + return m.dst.Write(p) } func (m *maxLatencyWriter) flushLoop() { @@ -160,13 +171,18 @@ func (m *maxLatencyWriter) flushLoop() { defer t.Stop() for { select { + case <-m.done: + if onExitFlushLoop != nil { + onExitFlushLoop() + } + return case <-t.C: m.lk.Lock() m.dst.Flush() m.lk.Unlock() - case <-m.done: - return } } panic("unreached") } + +func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index 28e9c90ad36..8639271626f 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -11,7 +11,9 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" + "time" ) func TestReverseProxy(t *testing.T) { @@ -70,6 +72,47 @@ func TestReverseProxy(t *testing.T) { } } +func TestXForwardedFor(t *testing.T) { + const prevForwardedFor = "client ip" + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { + t.Errorf("X-Forwarded-For didn't contain prior data") + } + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close") + getReq.Header.Set("X-Forwarded-For", prevForwardedFor) + getReq.Close = true + res, err := http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + var proxyQueryTests = []struct { baseSuffix string // suffix to add to backend URL reqSuffix string // suffix to add to frontend's request URL @@ -107,3 +150,44 @@ func TestReverseProxyQuery(t *testing.T) { frontend.Close() } } + +func TestReverseProxyFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + done := make(chan bool) + onExitFlushLoop = func() { done <- true } + defer func() { onExitFlushLoop = nil }() + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + + select { + case <-done: + // OK + case <-time.After(5 * time.Second): + t.Error("maxLatencyWriter flushLoop() never exited") + } +} diff --git a/libgo/go/net/http/lex.go b/libgo/go/net/http/lex.go index ffb393ccf6a..cb33318f49b 100644 --- a/libgo/go/net/http/lex.go +++ b/libgo/go/net/http/lex.go @@ -6,131 +6,91 @@ package http // This file deals with lexical matters of HTTP -func isSeparator(c byte) bool { - switch c { - case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t': - return true - } - return false +var isTokenTable = [127]bool{ + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, } -func isCtl(c byte) bool { return (0 <= c && c <= 31) || c == 127 } - -func isChar(c byte) bool { return 0 <= c && c <= 127 } - -func isAnyText(c byte) bool { return !isCtl(c) } - -func isQdText(c byte) bool { return isAnyText(c) && c != '"' } - -func isToken(c byte) bool { return isChar(c) && !isCtl(c) && !isSeparator(c) } - -// Valid escaped sequences are not specified in RFC 2616, so for now, we assume -// that they coincide with the common sense ones used by GO. Malformed -// characters should probably not be treated as errors by a robust (forgiving) -// parser, so we replace them with the '?' character. -func httpUnquotePair(b byte) byte { - // skip the first byte, which should always be '\' - switch b { - case 'a': - return '\a' - case 'b': - return '\b' - case 'f': - return '\f' - case 'n': - return '\n' - case 'r': - return '\r' - case 't': - return '\t' - case 'v': - return '\v' - case '\\': - return '\\' - case '\'': - return '\'' - case '"': - return '"' - } - return '?' -} - -// raw must begin with a valid quoted string. Only the first quoted string is -// parsed and is unquoted in result. eaten is the number of bytes parsed, or -1 -// upon failure. -func httpUnquote(raw []byte) (eaten int, result string) { - buf := make([]byte, len(raw)) - if raw[0] != '"' { - return -1, "" - } - eaten = 1 - j := 0 // # of bytes written in buf - for i := 1; i < len(raw); i++ { - switch b := raw[i]; b { - case '"': - eaten++ - buf = buf[0:j] - return i + 1, string(buf) - case '\\': - if len(raw) < i+2 { - return -1, "" - } - buf[j] = httpUnquotePair(raw[i+1]) - eaten += 2 - j++ - i++ - default: - if isQdText(b) { - buf[j] = b - } else { - buf[j] = '?' - } - eaten++ - j++ - } - } - return -1, "" +func isToken(r rune) bool { + i := int(r) + return i < len(isTokenTable) && isTokenTable[i] } -// This is a best effort parse, so errors are not returned, instead not all of -// the input string might be parsed. result is always non-nil. -func httpSplitFieldValue(fv string) (eaten int, result []string) { - result = make([]string, 0, len(fv)) - raw := []byte(fv) - i := 0 - chunk := "" - for i < len(raw) { - b := raw[i] - switch { - case b == '"': - eaten, unq := httpUnquote(raw[i:len(raw)]) - if eaten < 0 { - return i, result - } else { - i += eaten - chunk += unq - } - case isSeparator(b): - if chunk != "" { - result = result[0 : len(result)+1] - result[len(result)-1] = chunk - chunk = "" - } - i++ - case isToken(b): - chunk += string(b) - i++ - case b == '\n' || b == '\r': - i++ - default: - chunk += "?" - i++ - } - } - if chunk != "" { - result = result[0 : len(result)+1] - result[len(result)-1] = chunk - chunk = "" - } - return i, result +func isNotToken(r rune) bool { + return !isToken(r) } diff --git a/libgo/go/net/http/lex_test.go b/libgo/go/net/http/lex_test.go index 5386f7534db..6d9d294f703 100644 --- a/libgo/go/net/http/lex_test.go +++ b/libgo/go/net/http/lex_test.go @@ -8,63 +8,24 @@ import ( "testing" ) -type lexTest struct { - Raw string - Parsed int // # of parsed characters - Result []string -} +func isChar(c rune) bool { return c <= 127 } -var lexTests = []lexTest{ - { - Raw: `"abc"def,:ghi`, - Parsed: 13, - Result: []string{"abcdef", "ghi"}, - }, - // My understanding of the RFC is that escape sequences outside of - // quotes are not interpreted? - { - Raw: `"\t"\t"\t"`, - Parsed: 10, - Result: []string{"\t", "t\t"}, - }, - { - Raw: `"\yab"\r\n`, - Parsed: 10, - Result: []string{"?ab", "r", "n"}, - }, - { - Raw: "ab\f", - Parsed: 3, - Result: []string{"ab?"}, - }, - { - Raw: "\"ab \" c,de f, gh, ij\n\t\r", - Parsed: 23, - Result: []string{"ab ", "c", "de", "f", "gh", "ij"}, - }, -} +func isCtl(c rune) bool { return c <= 31 || c == 127 } -func min(x, y int) int { - if x <= y { - return x +func isSeparator(c rune) bool { + switch c { + case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t': + return true } - return y + return false } -func TestSplitFieldValue(t *testing.T) { - for k, l := range lexTests { - parsed, result := httpSplitFieldValue(l.Raw) - if parsed != l.Parsed { - t.Errorf("#%d: Parsed %d, expected %d", k, parsed, l.Parsed) - } - if len(result) != len(l.Result) { - t.Errorf("#%d: Result len %d, expected %d", k, len(result), len(l.Result)) - } - for i := 0; i < min(len(result), len(l.Result)); i++ { - if result[i] != l.Result[i] { - t.Errorf("#%d: %d-th entry mismatch. Have {%s}, expect {%s}", - k, i, result[i], l.Result[i]) - } +func TestIsToken(t *testing.T) { + for i := 0; i <= 130; i++ { + r := rune(i) + expected := isChar(r) && !isCtl(r) && !isSeparator(r) + if isToken(r) != expected { + t.Errorf("isToken(0x%x) = %v", r, !expected) } } } diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go index 7a9f465c477..d70bf4ed9d3 100644 --- a/libgo/go/net/http/pprof/pprof.go +++ b/libgo/go/net/http/pprof/pprof.go @@ -30,6 +30,10 @@ // // go tool pprof http://localhost:6060/debug/pprof/profile // +// Or to look at the goroutine blocking profile: +// +// go tool pprof http://localhost:6060/debug/pprof/block +// // Or to view all available profiles: // // go tool pprof http://localhost:6060/debug/pprof/ diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index f5bc6eb9100..61557ff8302 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -19,6 +19,7 @@ import ( "mime/multipart" "net/textproto" "net/url" + "strconv" "strings" ) @@ -131,6 +132,12 @@ type Request struct { // The HTTP client ignores Form and uses Body instead. Form url.Values + // PostForm contains the parsed form data from POST or PUT + // body parameters. + // This field is only available after ParseForm is called. + // The HTTP client ignores PostForm and uses Body instead. + PostForm url.Values + // MultipartForm is the parsed multipart form, including file uploads. // This field is only available after ParseMultipartForm is called. // The HTTP client ignores MultipartForm and uses Body instead. @@ -369,36 +376,29 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err return bw.Flush() } -// Convert decimal at s[i:len(s)] to integer, -// returning value, string position where the digits stopped, -// and whether there was a valid number (digits, not too big). -func atoi(s string, i int) (n, i1 int, ok bool) { - const Big = 1000000 - if i >= len(s) || s[i] < '0' || s[i] > '9' { - return 0, 0, false - } - n = 0 - for ; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ { - n = n*10 + int(s[i]-'0') - if n > Big { - return 0, 0, false - } - } - return n, i, true -} - // ParseHTTPVersion parses a HTTP version string. // "HTTP/1.0" returns (1, 0, true). func ParseHTTPVersion(vers string) (major, minor int, ok bool) { - if len(vers) < 5 || vers[0:5] != "HTTP/" { + const Big = 1000000 // arbitrary upper bound + switch vers { + case "HTTP/1.1": + return 1, 1, true + case "HTTP/1.0": + return 1, 0, true + } + if !strings.HasPrefix(vers, "HTTP/") { + return 0, 0, false + } + dot := strings.Index(vers, ".") + if dot < 0 { return 0, 0, false } - major, i, ok := atoi(vers, 5) - if !ok || i >= len(vers) || vers[i] != '.' { + major, err := strconv.Atoi(vers[5:dot]) + if err != nil || major < 0 || major > Big { return 0, 0, false } - minor, i, ok = atoi(vers, i+1) - if !ok || i != len(vers) { + minor, err = strconv.Atoi(vers[dot+1:]) + if err != nil || minor < 0 || minor > Big { return 0, 0, false } return major, minor, true @@ -513,9 +513,9 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) { // the same. In the second case, any Host line is ignored. req.Host = req.URL.Host if req.Host == "" { - req.Host = req.Header.Get("Host") + req.Host = req.Header.get("Host") } - req.Header.Del("Host") + delete(req.Header, "Host") fixPragmaCacheControl(req.Header) @@ -594,66 +594,93 @@ func (l *maxBytesReader) Close() error { return l.r.Close() } +func copyValues(dst, src url.Values) { + for k, vs := range src { + for _, value := range vs { + dst.Add(k, value) + } + } +} + +func parsePostForm(r *Request) (vs url.Values, err error) { + if r.Body == nil { + err = errors.New("missing form body") + return + } + ct := r.Header.Get("Content-Type") + ct, _, err = mime.ParseMediaType(ct) + switch { + case ct == "application/x-www-form-urlencoded": + var reader io.Reader = r.Body + maxFormSize := int64(1<<63 - 1) + if _, ok := r.Body.(*maxBytesReader); !ok { + maxFormSize = int64(10 << 20) // 10 MB is a lot of text. + reader = io.LimitReader(r.Body, maxFormSize+1) + } + b, e := ioutil.ReadAll(reader) + if e != nil { + if err == nil { + err = e + } + break + } + if int64(len(b)) > maxFormSize { + err = errors.New("http: POST too large") + return + } + vs, e = url.ParseQuery(string(b)) + if err == nil { + err = e + } + case ct == "multipart/form-data": + // handled by ParseMultipartForm (which is calling us, or should be) + // TODO(bradfitz): there are too many possible + // orders to call too many functions here. + // Clean this up and write more tests. + // request_test.go contains the start of this, + // in TestRequestMultipartCallOrder. + } + return +} + // ParseForm parses the raw query from the URL. // // For POST or PUT requests, it also parses the request body as a form. +// POST and PUT body parameters take precedence over URL query string values. // If the request Body's size has not already been limited by MaxBytesReader, // the size is capped at 10MB. // // ParseMultipartForm calls ParseForm automatically. // It is idempotent. func (r *Request) ParseForm() (err error) { - if r.Form != nil { - return - } - if r.URL != nil { - r.Form, err = url.ParseQuery(r.URL.RawQuery) + if r.PostForm == nil { + if r.Method == "POST" || r.Method == "PUT" { + r.PostForm, err = parsePostForm(r) + } + if r.PostForm == nil { + r.PostForm = make(url.Values) + } } - if r.Method == "POST" || r.Method == "PUT" { - if r.Body == nil { - return errors.New("missing form body") + if r.Form == nil { + if len(r.PostForm) > 0 { + r.Form = make(url.Values) + copyValues(r.Form, r.PostForm) } - ct := r.Header.Get("Content-Type") - ct, _, err = mime.ParseMediaType(ct) - switch { - case ct == "application/x-www-form-urlencoded": - var reader io.Reader = r.Body - maxFormSize := int64(1<<63 - 1) - if _, ok := r.Body.(*maxBytesReader); !ok { - maxFormSize = int64(10 << 20) // 10 MB is a lot of text. - reader = io.LimitReader(r.Body, maxFormSize+1) - } - b, e := ioutil.ReadAll(reader) - if e != nil { - if err == nil { - err = e - } - break - } - if int64(len(b)) > maxFormSize { - return errors.New("http: POST too large") - } - var newValues url.Values - newValues, e = url.ParseQuery(string(b)) + var newValues url.Values + if r.URL != nil { + var e error + newValues, e = url.ParseQuery(r.URL.RawQuery) if err == nil { err = e } - if r.Form == nil { - r.Form = make(url.Values) - } - // Copy values into r.Form. TODO: make this smoother. - for k, vs := range newValues { - for _, value := range vs { - r.Form.Add(k, value) - } - } - case ct == "multipart/form-data": - // handled by ParseMultipartForm (which is calling us, or should be) - // TODO(bradfitz): there are too many possible - // orders to call too many functions here. - // Clean this up and write more tests. - // request_test.go contains the start of this, - // in TestRequestMultipartCallOrder. + } + if newValues == nil { + newValues = make(url.Values) + } + if r.Form == nil { + r.Form = newValues + } else { + copyValues(r.Form, newValues) } } return err @@ -699,6 +726,7 @@ func (r *Request) ParseMultipartForm(maxMemory int64) error { } // FormValue returns the first value for the named component of the query. +// POST and PUT body parameters take precedence over URL query string values. // FormValue calls ParseMultipartForm and ParseForm if necessary. func (r *Request) FormValue(key string) string { if r.Form == nil { @@ -710,6 +738,19 @@ func (r *Request) FormValue(key string) string { return "" } +// PostFormValue returns the first value for the named component of the POST +// or PUT request body. URL query parameters are ignored. +// PostFormValue calls ParseMultipartForm and ParseForm if necessary. +func (r *Request) PostFormValue(key string) string { + if r.PostForm == nil { + r.ParseMultipartForm(defaultMaxMemory) + } + if vs := r.PostForm[key]; len(vs) > 0 { + return vs[0] + } + return "" +} + // FormFile returns the first file for the provided form key. // FormFile calls ParseMultipartForm and ParseForm if necessary. func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) { @@ -732,12 +773,16 @@ func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, e } func (r *Request) expectsContinue() bool { - return strings.ToLower(r.Header.Get("Expect")) == "100-continue" + return hasToken(r.Header.get("Expect"), "100-continue") } func (r *Request) wantsHttp10KeepAlive() bool { if r.ProtoMajor != 1 || r.ProtoMinor != 0 { return false } - return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "keep-alive") + return hasToken(r.Header.get("Connection"), "keep-alive") +} + +func (r *Request) wantsClose() bool { + return hasToken(r.Header.get("Connection"), "close") } diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index 6e00b9bfd39..db7419b26fe 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -30,8 +30,8 @@ func TestQuery(t *testing.T) { } func TestPostQuery(t *testing.T) { - req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x", - strings.NewReader("z=post&both=y")) + req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not", + strings.NewReader("z=post&both=y&prio=2&empty=")) req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") if q := req.FormValue("q"); q != "foo" { @@ -40,8 +40,23 @@ func TestPostQuery(t *testing.T) { if z := req.FormValue("z"); z != "post" { t.Errorf(`req.FormValue("z") = %q, want "post"`, z) } - if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"x", "y"}) { - t.Errorf(`req.FormValue("both") = %q, want ["x", "y"]`, both) + if bq, found := req.PostForm["q"]; found { + t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq) + } + if bz := req.PostFormValue("z"); bz != "post" { + t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz) + } + if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) { + t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs) + } + if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) { + t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both) + } + if prio := req.FormValue("prio"); prio != "2" { + t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio) + } + if empty := req.FormValue("empty"); empty != "" { + t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty) } } @@ -76,6 +91,23 @@ func TestParseFormUnknownContentType(t *testing.T) { } } +func TestParseFormInitializeOnError(t *testing.T) { + nilBody, _ := NewRequest("POST", "http://www.google.com/search?q=foo", nil) + tests := []*Request{ + nilBody, + {Method: "GET", URL: nil}, + } + for i, req := range tests { + err := req.ParseForm() + if req.Form == nil { + t.Errorf("%d. Form not initialized, error %v", i, err) + } + if req.PostForm == nil { + t.Errorf("%d. PostForm not initialized, error %v", i, err) + } + } +} + func TestMultipartReader(t *testing.T) { req := &Request{ Method: "POST", diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index 945ecd8a4b0..92d2f499839 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -107,7 +107,6 @@ func ReadResponse(r *bufio.Reader, req *Request) (resp *Response, err error) { resp = new(Response) resp.Request = req - resp.Request.Method = strings.ToUpper(resp.Request.Method) // Parse the first line of the response. line, err := tp.ReadLine() @@ -188,11 +187,6 @@ func (r *Response) ProtoAtLeast(major, minor int) bool { // func (r *Response) Write(w io.Writer) error { - // RequestMethod should be upper-case - if r.Request != nil { - r.Request.Method = strings.ToUpper(r.Request.Method) - } - // Status line text := r.Status if text == "" { diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index c9d73932bb9..71b7b3fb6b7 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -20,8 +20,13 @@ import ( "net/http/httputil" "net/url" "os" + "os/exec" "reflect" + "runtime" + "strconv" "strings" + "sync" + "sync/atomic" "syscall" "testing" "time" @@ -168,6 +173,9 @@ var vtests = []struct { {"http://someHost.com/someDir/apage", "someHost.com/someDir"}, {"http://otherHost.com/someDir/apage", "someDir"}, {"http://otherHost.com/aDir/apage", "Default"}, + // redirections for trees + {"http://localhost/someDir", "/someDir/"}, + {"http://someHost.com/someDir", "/someDir/"}, } func TestHostHandlers(t *testing.T) { @@ -199,9 +207,19 @@ func TestHostHandlers(t *testing.T) { t.Errorf("reading response: %v", err) continue } - s := r.Header.Get("Result") - if s != vt.expected { - t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) + switch r.StatusCode { + case StatusOK: + s := r.Header.Get("Result") + if s != vt.expected { + t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) + } + case StatusMovedPermanently: + s := r.Header.Get("Location") + if s != vt.expected { + t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) + } + default: + t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode) } } } @@ -370,7 +388,7 @@ func TestIdentityResponse(t *testing.T) { }) } -func testTcpConnectionCloses(t *testing.T, req string, h Handler) { +func testTCPConnectionCloses(t *testing.T, req string, h Handler) { s := httptest.NewServer(h) defer s.Close() @@ -414,21 +432,28 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) { // TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive. func TestServeHTTP10Close(t *testing.T) { - testTcpConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") })) } +// TestClientCanClose verifies that clients can also force a connection to close. +func TestClientCanClose(t *testing.T) { + testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + // Nothing. + })) +} + // TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close, // even for HTTP/1.1 requests. func TestHandlersCanSetConnectionClose11(t *testing.T) { - testTcpConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") })) } func TestHandlersCanSetConnectionClose10(t *testing.T) { - testTcpConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") })) } @@ -665,30 +690,51 @@ func TestServerExpect(t *testing.T) { t.Fatalf("Dial: %v", err) } defer conn.Close() - sendf := func(format string, args ...interface{}) { - _, err := fmt.Fprintf(conn, format, args...) - if err != nil { - t.Fatalf("On test %#v, error writing %q: %v", test, format, err) - } - } + + // Only send the body immediately if we're acting like an HTTP client + // that doesn't send 100-continue expectations. + writeBody := test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" + go func() { - sendf("POST /?readbody=%v HTTP/1.1\r\n"+ + _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+ "Connection: close\r\n"+ "Content-Length: %d\r\n"+ "Expect: %s\r\nHost: foo\r\n\r\n", test.readBody, test.contentLength, test.expectation) - if test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" { + if err != nil { + t.Errorf("On test %#v, error writing request headers: %v", test, err) + return + } + if writeBody { body := strings.Repeat("A", test.contentLength) - sendf(body) + _, err = fmt.Fprint(conn, body) + if err != nil { + if !test.readBody { + // Server likely already hung up on us. + // See larger comment below. + t.Logf("On test %#v, acceptable error writing request body: %v", test, err) + return + } + t.Errorf("On test %#v, error writing request body: %v", test, err) + } } }() bufr := bufio.NewReader(conn) line, err := bufr.ReadString('\n') if err != nil { - t.Fatalf("ReadString: %v", err) + if writeBody && !test.readBody { + // This is an acceptable failure due to a possible TCP race: + // We were still writing data and the server hung up on us. A TCP + // implementation may send a RST if our request body data was known + // to be lost, which may trigger our reads to fail. + // See RFC 1122 page 88. + t.Logf("On test %#v, acceptable error from ReadString: %v", test, err) + return + } + t.Fatalf("On test %#v, ReadString: %v", test, err) } if !strings.Contains(line, test.expectedResponse) { - t.Errorf("for test %#v got first line=%q", test, line) + t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse) } } @@ -1112,6 +1158,68 @@ func TestServerBufferedChunking(t *testing.T) { } } +// Tests that the server flushes its response headers out when it's +// ignoring the response body and waits a bit before forcefully +// closing the TCP connection, causing the client to get a RST. +// See http://golang.org/issue/3595 +func TestServerGracefulClose(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Error(w, "bye", StatusUnauthorized) + })) + defer ts.Close() + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + const bodySize = 5 << 20 + req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize)) + for i := 0; i < bodySize; i++ { + req = append(req, 'x') + } + writeErr := make(chan error) + go func() { + _, err := conn.Write(req) + writeErr <- err + }() + br := bufio.NewReader(conn) + lineNum := 0 + for { + line, err := br.ReadString('\n') + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("ReadLine: %v", err) + } + lineNum++ + if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") { + t.Errorf("Response line = %q; want a 401", line) + } + } + // Wait for write to finish. This is a broken pipe on both + // Darwin and Linux, but checking this isn't the point of + // the test. + <-writeErr +} + +func TestCaseSensitiveMethod(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "get" { + t.Errorf(`Got method %q; want "get"`, r.Method) + } + })) + defer ts.Close() + req, _ := NewRequest("get", ts.URL, nil) + res, err := DefaultClient.Do(req) + if err != nil { + t.Error(err) + return + } + res.Body.Close() +} + // TestContentLengthZero tests that for both an HTTP/1.0 and HTTP/1.1 // request (both keep-alive), when a Handler never writes any // response, the net/http package adds a "Content-Length: 0" response @@ -1220,3 +1328,100 @@ func BenchmarkClientServer(b *testing.B) { b.StopTimer() } + +func BenchmarkClientServerParallel4(b *testing.B) { + benchmarkClientServerParallel(b, 4) +} + +func BenchmarkClientServerParallel64(b *testing.B) { + benchmarkClientServerParallel(b, 64) +} + +func benchmarkClientServerParallel(b *testing.B, conc int) { + b.StopTimer() + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + fmt.Fprintf(rw, "Hello world.\n") + })) + defer ts.Close() + b.StartTimer() + + numProcs := runtime.GOMAXPROCS(-1) * conc + var wg sync.WaitGroup + wg.Add(numProcs) + n := int32(b.N) + for p := 0; p < numProcs; p++ { + go func() { + for atomic.AddInt32(&n, -1) >= 0 { + res, err := Get(ts.URL) + if err != nil { + b.Logf("Get: %v", err) + continue + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + b.Logf("ReadAll: %v", err) + continue + } + body := string(all) + if body != "Hello world.\n" { + panic("Got body: " + body) + } + } + wg.Done() + }() + } + wg.Wait() +} + +// A benchmark for profiling the server without the HTTP client code. +// The client code runs in a subprocess. +// +// For use like: +// $ go test -c +// $ ./http.test -test.run=XX -test.bench=BenchmarkServer -test.benchtime=15s -test.cpuprofile=http.prof +// $ go tool pprof http.test http.prof +// (pprof) web +func BenchmarkServer(b *testing.B) { + // Child process mode; + if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" { + n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N")) + if err != nil { + panic(err) + } + for i := 0; i < n; i++ { + res, err := Get(url) + if err != nil { + log.Panicf("Get: %v", err) + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + log.Panicf("ReadAll: %v", err) + } + body := string(all) + if body != "Hello world.\n" { + log.Panicf("Got body: %q", body) + } + } + os.Exit(0) + return + } + + var res = []byte("Hello world.\n") + b.StopTimer() + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Write(res) + })) + defer ts.Close() + b.StartTimer() + + cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer") + cmd.Env = append([]string{ + fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N), + fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL), + }, os.Environ()...) + out, err := cmd.CombinedOutput() + if err != nil { + b.Errorf("Test failure: %v, with output: %s", err, out) + } +} diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index b74b7629809..ee57e01276d 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -129,7 +129,7 @@ type response struct { // maxBytesReader hits its max size. It is checked in // WriteHeader, to make sure we don't consume the the // remaining request body to try to advance to the next HTTP - // request. Instead, when this is set, we stop doing + // request. Instead, when this is set, we stop reading // subsequent requests on this connection and stop reading // input from it. requestBodyLimitHit bool @@ -287,7 +287,7 @@ func (w *response) WriteHeader(code int) { // Check for a explicit (and valid) Content-Length header. var hasCL bool var contentLength int64 - if clenStr := w.header.Get("Content-Length"); clenStr != "" { + if clenStr := w.header.get("Content-Length"); clenStr != "" { var err error contentLength, err = strconv.ParseInt(clenStr, 10, 64) if err == nil { @@ -303,12 +303,11 @@ func (w *response) WriteHeader(code int) { if !connectionHeaderSet { w.header.Set("Connection", "keep-alive") } - } else if !w.req.ProtoAtLeast(1, 1) { - // Client did not ask to keep connection alive. + } else if !w.req.ProtoAtLeast(1, 1) || w.req.wantsClose() { w.closeAfterReply = true } - if w.header.Get("Connection") == "close" { + if w.header.get("Connection") == "close" { w.closeAfterReply = true } @@ -332,7 +331,7 @@ func (w *response) WriteHeader(code int) { if code == StatusNotModified { // Must not have body. for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { - if w.header.Get(header) != "" { + if w.header.get(header) != "" { // TODO: return an error if WriteHeader gets a return parameter // or set a flag on w to make future Writes() write an error page? // for now just log and drop the header. @@ -342,7 +341,7 @@ func (w *response) WriteHeader(code int) { } } else { // If no content type, apply sniffing algorithm to body. - if w.header.Get("Content-Type") == "" && w.req.Method != "HEAD" { + if w.header.get("Content-Type") == "" && w.req.Method != "HEAD" { w.needSniff = true } } @@ -351,7 +350,7 @@ func (w *response) WriteHeader(code int) { w.Header().Set("Date", time.Now().UTC().Format(TimeFormat)) } - te := w.header.Get("Transfer-Encoding") + te := w.header.get("Transfer-Encoding") hasTE := te != "" if hasCL && hasTE && te != "identity" { // TODO: return an error if WriteHeader gets a return parameter @@ -391,7 +390,7 @@ func (w *response) WriteHeader(code int) { return } - if w.closeAfterReply && !hasToken(w.header.Get("Connection"), "close") { + if w.closeAfterReply && !hasToken(w.header.get("Connection"), "close") { w.header.Set("Connection", "close") } @@ -518,14 +517,14 @@ func (w *response) finishRequest() { // HTTP/1.0 clients keep their "keep-alive" connections alive, and for // HTTP/1.1 clients is just as good as the alternative: sending a // chunked response and immediately sending the zero-length EOF chunk. - if w.written == 0 && w.header.Get("Content-Length") == "" { + if w.written == 0 && w.header.get("Content-Length") == "" { w.header.Set("Content-Length", "0") } // If this was an HTTP/1.0 request with keep-alive and we sent a // Content-Length back, we can make this a keep-alive response ... if w.req.wantsHttp10KeepAlive() { - sentLength := w.header.Get("Content-Length") != "" - if sentLength && w.header.Get("Connection") == "keep-alive" { + sentLength := w.header.get("Content-Length") != "" + if sentLength && w.header.get("Connection") == "keep-alive" { w.closeAfterReply = false } } @@ -564,18 +563,31 @@ func (w *response) Flush() { w.conn.buf.Flush() } -// Close the connection. -func (c *conn) close() { +func (c *conn) finalFlush() { if c.buf != nil { c.buf.Flush() c.buf = nil } +} + +// Close the connection. +func (c *conn) close() { + c.finalFlush() if c.rwc != nil { c.rwc.Close() c.rwc = nil } } +// closeWrite flushes any outstanding data and sends a FIN packet (if client +// is connected via TCP), signalling that we're done. +func (c *conn) closeWrite() { + c.finalFlush() + if tcp, ok := c.rwc.(*net.TCPConn); ok { + tcp.CloseWrite() + } +} + // Serve a new connection. func (c *conn) serve() { defer func() { @@ -637,7 +649,7 @@ func (c *conn) serve() { break } req.Header.Del("Expect") - } else if req.Header.Get("Expect") != "" { + } else if req.Header.get("Expect") != "" { // TODO(bradfitz): let ServeHTTP handlers handle // requests with non-standard expectation[s]? Seems // theoretical at best, and doesn't fit into the @@ -672,6 +684,20 @@ func (c *conn) serve() { } w.finishRequest() if w.closeAfterReply { + if w.requestBodyLimitHit { + // Flush our response and send a FIN packet and wait a bit + // before closing the connection, so the client has a chance + // to read our response before they possibly get a RST from + // our TCP stack from ignoring their unread body. + // See http://golang.org/issue/3595 + c.closeWrite() + // Now wait a bit for our machine to send the FIN and the client's + // machine's HTTP client to read the request before we close + // the connection, which might send a RST (on BSDs, at least). + // 250ms is somewhat arbitrary (~latency around half the planet), + // but this doesn't need to be a full second probably. + time.Sleep(250 * time.Millisecond) + } break } } @@ -849,13 +875,15 @@ func RedirectHandler(url string, code int) Handler { // redirecting any request containing . or .. elements to an // equivalent .- and ..-free URL. type ServeMux struct { - mu sync.RWMutex - m map[string]muxEntry + mu sync.RWMutex + m map[string]muxEntry + hosts bool // whether any patterns contain hostnames } type muxEntry struct { explicit bool h Handler + pattern string } // NewServeMux allocates and returns a new ServeMux. @@ -896,8 +924,7 @@ func cleanPath(p string) string { // Find a handler on a handler map given a path string // Most-specific (longest) pattern wins -func (mux *ServeMux) match(path string) Handler { - var h Handler +func (mux *ServeMux) match(path string) (h Handler, pattern string) { var n = 0 for k, v := range mux.m { if !pathMatch(k, path) { @@ -906,39 +933,59 @@ func (mux *ServeMux) match(path string) Handler { if h == nil || len(k) > n { n = len(k) h = v.h + pattern = v.pattern + } + } + return +} + +// Handler returns the handler to use for the given request, +// consulting r.Method, r.Host, and r.URL.Path. It always returns +// a non-nil handler. If the path is not in its canonical form, the +// handler will be an internally-generated handler that redirects +// to the canonical path. +// +// Handler also returns the registered pattern that matches the +// request or, in the case of internally-generated redirects, +// the pattern that will match after following the redirect. +// +// If there is no registered handler that applies to the request, +// Handler returns a ``page not found'' handler and an empty pattern. +func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { + if r.Method != "CONNECT" { + if p := cleanPath(r.URL.Path); p != r.URL.Path { + _, pattern = mux.handler(r.Host, p) + return RedirectHandler(p, StatusMovedPermanently), pattern } } - return h + + return mux.handler(r.Host, r.URL.Path) } -// handler returns the handler to use for the request r. -func (mux *ServeMux) handler(r *Request) Handler { +// handler is the main implementation of Handler. +// The path is known to be in canonical form, except for CONNECT methods. +func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) { mux.mu.RLock() defer mux.mu.RUnlock() // Host-specific pattern takes precedence over generic ones - h := mux.match(r.Host + r.URL.Path) + if mux.hosts { + h, pattern = mux.match(host + path) + } if h == nil { - h = mux.match(r.URL.Path) + h, pattern = mux.match(path) } if h == nil { - h = NotFoundHandler() + h, pattern = NotFoundHandler(), "" } - return h + return } // ServeHTTP dispatches the request to the handler whose // pattern most closely matches the request URL. func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { - if r.Method != "CONNECT" { - // Clean path to canonical form and redirect. - if p := cleanPath(r.URL.Path); p != r.URL.Path { - w.Header().Set("Location", p) - w.WriteHeader(StatusMovedPermanently) - return - } - } - mux.handler(r).ServeHTTP(w, r) + h, _ := mux.Handler(r) + h.ServeHTTP(w, r) } // Handle registers the handler for the given pattern. @@ -957,14 +1004,26 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) { panic("http: multiple registrations for " + pattern) } - mux.m[pattern] = muxEntry{explicit: true, h: handler} + mux.m[pattern] = muxEntry{explicit: true, h: handler, pattern: pattern} + + if pattern[0] != '/' { + mux.hosts = true + } // Helpful behavior: // If pattern is /tree/, insert an implicit permanent redirect for /tree. // It can be overridden by an explicit registration. n := len(pattern) if n > 0 && pattern[n-1] == '/' && !mux.m[pattern[0:n-1]].explicit { - mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(pattern, StatusMovedPermanently)} + // If pattern contains a host name, strip it and use remaining + // path for redirect. + path := pattern + if pattern[0] != '/' { + // In pattern, at least the last character is a '/', so + // strings.Index can't be -1. + path = pattern[strings.Index(pattern, "/"):] + } + mux.m[pattern[0:n-1]] = muxEntry{h: RedirectHandler(path, StatusMovedPermanently), pattern: pattern} } } diff --git a/libgo/go/net/http/server_test.go b/libgo/go/net/http/server_test.go new file mode 100644 index 00000000000..8b4e8c6d6f6 --- /dev/null +++ b/libgo/go/net/http/server_test.go @@ -0,0 +1,95 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "net/url" + "testing" +) + +var serveMuxRegister = []struct { + pattern string + h Handler +}{ + {"/dir/", serve(200)}, + {"/search", serve(201)}, + {"codesearch.google.com/search", serve(202)}, + {"codesearch.google.com/", serve(203)}, +} + +// serve returns a handler that sends a response with the given code. +func serve(code int) HandlerFunc { + return func(w ResponseWriter, r *Request) { + w.WriteHeader(code) + } +} + +var serveMuxTests = []struct { + method string + host string + path string + code int + pattern string +}{ + {"GET", "google.com", "/", 404, ""}, + {"GET", "google.com", "/dir", 301, "/dir/"}, + {"GET", "google.com", "/dir/", 200, "/dir/"}, + {"GET", "google.com", "/dir/file", 200, "/dir/"}, + {"GET", "google.com", "/search", 201, "/search"}, + {"GET", "google.com", "/search/", 404, ""}, + {"GET", "google.com", "/search/foo", 404, ""}, + {"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"}, + {"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"}, + {"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"}, + {"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"}, + {"GET", "images.google.com", "/search", 201, "/search"}, + {"GET", "images.google.com", "/search/", 404, ""}, + {"GET", "images.google.com", "/search/foo", 404, ""}, + {"GET", "google.com", "/../search", 301, "/search"}, + {"GET", "google.com", "/dir/..", 301, ""}, + {"GET", "google.com", "/dir/..", 301, ""}, + {"GET", "google.com", "/dir/./file", 301, "/dir/"}, + + // The /foo -> /foo/ redirect applies to CONNECT requests + // but the path canonicalization does not. + {"CONNECT", "google.com", "/dir", 301, "/dir/"}, + {"CONNECT", "google.com", "/../search", 404, ""}, + {"CONNECT", "google.com", "/dir/..", 200, "/dir/"}, + {"CONNECT", "google.com", "/dir/..", 200, "/dir/"}, + {"CONNECT", "google.com", "/dir/./file", 200, "/dir/"}, +} + +func TestServeMuxHandler(t *testing.T) { + mux := NewServeMux() + for _, e := range serveMuxRegister { + mux.Handle(e.pattern, e.h) + } + + for _, tt := range serveMuxTests { + r := &Request{ + Method: tt.method, + Host: tt.host, + URL: &url.URL{ + Path: tt.path, + }, + } + h, pattern := mux.Handler(r) + cs := &codeSaver{h: Header{}} + h.ServeHTTP(cs, r) + if pattern != tt.pattern || cs.code != tt.code { + t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, cs.code, pattern, tt.code, tt.pattern) + } + } +} + +// A codeSaver is a ResponseWriter that saves the code passed to WriteHeader. +type codeSaver struct { + h Header + code int +} + +func (cs *codeSaver) Header() Header { return cs.h } +func (cs *codeSaver) Write(p []byte) (int, error) { return len(p), nil } +func (cs *codeSaver) WriteHeader(code int) { cs.code = code } diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index 9e9d84172d0..1fc1e63a960 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -432,7 +432,7 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, } // Logic based on Content-Length - cl := strings.TrimSpace(header.Get("Content-Length")) + cl := strings.TrimSpace(header.get("Content-Length")) if cl != "" { n, err := strconv.ParseInt(cl, 10, 64) if err != nil || n < 0 { @@ -454,7 +454,7 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, // Logic based on media type. The purpose of the following code is just // to detect whether the unsupported "multipart/byteranges" is being // used. A proper Content-Type parser is needed in the future. - if strings.Contains(strings.ToLower(header.Get("Content-Type")), "multipart/byteranges") { + if strings.Contains(strings.ToLower(header.get("Content-Type")), "multipart/byteranges") { return -1, ErrNotSupported } @@ -469,14 +469,14 @@ func shouldClose(major, minor int, header Header) bool { if major < 1 { return true } else if major == 1 && minor == 0 { - if !strings.Contains(strings.ToLower(header.Get("Connection")), "keep-alive") { + if !strings.Contains(strings.ToLower(header.get("Connection")), "keep-alive") { return true } return false } else { // TODO: Should split on commas, toss surrounding white space, // and check each field. - if strings.ToLower(header.Get("Connection")) == "close" { + if strings.ToLower(header.get("Connection")) == "close" { header.Del("Connection") return true } @@ -486,7 +486,7 @@ func shouldClose(major, minor int, header Header) bool { // Parse the trailer header func fixTrailer(header Header, te []string) (Header, error) { - raw := header.Get("Trailer") + raw := header.get("Trailer") if raw == "" { return nil, nil } diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 6131d0d1ee1..651f3ce0081 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -24,6 +24,7 @@ import ( "os" "strings" "sync" + "time" ) // DefaultTransport is the default implementation of Transport and is @@ -265,6 +266,11 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { pconn.close() return false } + for _, exist := range t.idleConn[key] { + if exist == pconn { + log.Fatalf("dup idle pconn %p in freelist", pconn) + } + } t.idleConn[key] = append(t.idleConn[key], pconn) t.idleLk.Unlock() return true @@ -295,7 +301,7 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { return } } - return + panic("unreachable") } func (t *Transport) dial(network, addr string) (c net.Conn, err error) { @@ -329,6 +335,8 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { cacheKey: cm.String(), conn: conn, reqch: make(chan requestAndChan, 50), + writech: make(chan writeRequest, 50), + closech: make(chan struct{}), } switch { @@ -373,7 +381,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { // Initiate TLS and check remote host name against certificate. cfg := t.TLSClientConfig if cfg == nil || cfg.ServerName == "" { - host, _, _ := net.SplitHostPort(cm.addr()) + host := cm.tlsHost() if cfg == nil { cfg = &tls.Config{ServerName: host} } else { @@ -397,6 +405,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { pconn.br = bufio.NewReader(pconn.conn) pconn.bw = bufio.NewWriter(pconn.conn) go pconn.readLoop() + go pconn.writeLoop() return pconn, nil } @@ -504,7 +513,9 @@ type persistConn struct { closed bool // whether conn has been closed br *bufio.Reader // from conn bw *bufio.Writer // to conn - reqch chan requestAndChan // written by roundTrip(); read by readLoop() + reqch chan requestAndChan // written by roundTrip; read by readLoop + writech chan writeRequest // written by roundTrip; read by writeLoop + closech chan struct{} // broadcast close when readLoop (TCP connection) closes isProxy bool // mutateHeaderFunc is an optional func to modify extra @@ -537,6 +548,7 @@ func remoteSideClosed(err error) bool { } func (pc *persistConn) readLoop() { + defer close(pc.closech) alive := true var lastbody io.ReadCloser // last response body, if any, read on this connection @@ -563,7 +575,11 @@ func (pc *persistConn) readLoop() { lastbody.Close() // assumed idempotent lastbody = nil } - resp, err := ReadResponse(pc.br, rc.req) + + var resp *Response + if err == nil { + resp, err = ReadResponse(pc.br, rc.req) + } if err != nil { pc.close() @@ -592,12 +608,12 @@ func (pc *persistConn) readLoop() { var waitForBodyRead chan bool if hasBody { lastbody = resp.Body - waitForBodyRead = make(chan bool) + waitForBodyRead = make(chan bool, 1) resp.Body.(*bodyEOFSignal).fn = func() { if alive && !pc.t.putIdleConn(pc) { alive = false } - if !alive { + if !alive || pc.isBroken() { pc.close() } waitForBodyRead <- true @@ -633,6 +649,28 @@ func (pc *persistConn) readLoop() { } } +func (pc *persistConn) writeLoop() { + for { + select { + case wr := <-pc.writech: + if pc.isBroken() { + wr.ch <- errors.New("http: can't write HTTP request on broken connection") + continue + } + err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra) + if err == nil { + err = pc.bw.Flush() + } + if err != nil { + pc.markBroken() + } + wr.ch <- err + case <-pc.closech: + return + } + } +} + type responseAndError struct { res *Response err error @@ -648,6 +686,15 @@ type requestAndChan struct { addedGzip bool } +// A writeRequest is sent by the readLoop's goroutine to the +// writeLoop's goroutine to write a request while the read loop +// concurrently waits on both the write response and the server's +// reply. +type writeRequest struct { + req *transportRequest + ch chan<- error +} + func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { if pc.mutateHeaderFunc != nil { pc.mutateHeaderFunc(req.extraHeaders()) @@ -670,16 +717,49 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err pc.numExpectedResponses++ pc.lk.Unlock() - err = req.Request.write(pc.bw, pc.isProxy, req.extra) - if err != nil { - pc.close() - return + // Write the request concurrently with waiting for a response, + // in case the server decides to reply before reading our full + // request body. + writeErrCh := make(chan error, 1) + pc.writech <- writeRequest{req, writeErrCh} + + resc := make(chan responseAndError, 1) + pc.reqch <- requestAndChan{req.Request, resc, requestedGzip} + + var re responseAndError + var pconnDeadCh = pc.closech + var failTicker <-chan time.Time +WaitResponse: + for { + select { + case err := <-writeErrCh: + if err != nil { + re = responseAndError{nil, err} + break WaitResponse + } + case <-pconnDeadCh: + // The persist connection is dead. This shouldn't + // usually happen (only with Connection: close responses + // with no response bodies), but if it does happen it + // means either a) the remote server hung up on us + // prematurely, or b) the readLoop sent us a response & + // closed its closech at roughly the same time, and we + // selected this case first, in which case a response + // might still be coming soon. + // + // We can't avoid the select race in b) by using a unbuffered + // resc channel instead, because then goroutines can + // leak if we exit due to other errors. + pconnDeadCh = nil // avoid spinning + failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc + case <-failTicker: + re = responseAndError{nil, errors.New("net/http: transport closed before response was received")} + break WaitResponse + case re = <-resc: + break WaitResponse + } } - pc.bw.Flush() - ch := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{req.Request, ch, requestedGzip} - re := <-ch pc.lk.Lock() pc.numExpectedResponses-- pc.lk.Unlock() @@ -687,6 +767,15 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return re.res, re.err } +// markBroken marks a connection as broken (so it's not reused). +// It differs from close in that it doesn't close the underlying +// connection for use when it's still being read. +func (pc *persistConn) markBroken() { + pc.lk.Lock() + defer pc.lk.Unlock() + pc.broken = true +} + func (pc *persistConn) close() { pc.lk.Lock() defer pc.lk.Unlock() @@ -728,6 +817,7 @@ type bodyEOFSignal struct { body io.ReadCloser fn func() isClosed bool + once sync.Once } func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { @@ -735,9 +825,8 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { if es.isClosed && n > 0 { panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725") } - if err == io.EOF && es.fn != nil { - es.fn() - es.fn = nil + if err == io.EOF { + es.condfn() } return } @@ -748,13 +837,18 @@ func (es *bodyEOFSignal) Close() (err error) { } es.isClosed = true err = es.body.Close() - if err == nil && es.fn != nil { - es.fn() - es.fn = nil + if err == nil { + es.condfn() } return } +func (es *bodyEOFSignal) condfn() { + if es.fn != nil { + es.once.Do(es.fn) + } +} + type readFirstCloseBoth struct { io.ReadCloser io.Closer diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index e676bf6db39..e4072e88fed 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -833,6 +833,74 @@ func TestIssue3644(t *testing.T) { } } +// Test that a client receives a server's reply, even if the server doesn't read +// the entire request body. +func TestIssue3595(t *testing.T) { + const deniedMsg = "sorry, denied." + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Error(w, deniedMsg, StatusUnauthorized) + })) + defer ts.Close() + tr := &Transport{} + c := &Client{Transport: tr} + res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) + if err != nil { + t.Errorf("Post: %v", err) + return + } + got, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("Body ReadAll: %v", err) + } + if !strings.Contains(string(got), deniedMsg) { + t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) + } +} + +func TestTransportConcurrency(t *testing.T) { + const maxProcs = 16 + const numReqs = 500 + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%v", r.FormValue("echo")) + })) + defer ts.Close() + tr := &Transport{} + c := &Client{Transport: tr} + reqs := make(chan string) + defer close(reqs) + + var wg sync.WaitGroup + wg.Add(numReqs) + for i := 0; i < maxProcs*2; i++ { + go func() { + for req := range reqs { + res, err := c.Get(ts.URL + "/?echo=" + req) + if err != nil { + t.Errorf("error on req %s: %v", req, err) + wg.Done() + continue + } + all, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Errorf("read error on req %s: %v", req, err) + wg.Done() + continue + } + if string(all) != req { + t.Errorf("body of req %s = %q; want %q", req, all, req) + } + wg.Done() + res.Body.Close() + } + }() + } + for i := 0; i < numReqs; i++ { + reqs <- fmt.Sprintf("request-%d", i) + } + wg.Wait() +} + type fooProto struct{} func (fooProto) RoundTrip(req *Request) (*Response, error) { |