diff options
Diffstat (limited to 'libgo/go/net/http/serve_test.go')
-rw-r--r-- | libgo/go/net/http/serve_test.go | 219 |
1 files changed, 170 insertions, 49 deletions
diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 6394da3bb7c..fb18cb2c6f5 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -23,6 +23,7 @@ import ( "net" . "net/http" "net/http/httptest" + "net/http/httptrace" "net/http/httputil" "net/http/internal" "net/http/internal/testcert" @@ -2146,7 +2147,7 @@ func TestInvalidTrailerClosesConnection(t *testing.T) { // Read and Write. type slowTestConn struct { // over multiple calls to Read, time.Durations are slept, strings are read. - script []interface{} + script []any closec chan bool mu sync.Mutex // guards rd/wd @@ -2238,7 +2239,7 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) { defer afterTest(t) for _, handler := range testHandlerBodyConsumers { conn := &slowTestConn{ - script: []interface{}{ + script: []any{ "POST /public HTTP/1.1\r\n" + "Host: test\r\n" + "Content-Length: 10000\r\n" + @@ -2273,6 +2274,18 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) { } } +// cancelableTimeoutContext overwrites the error message to DeadlineExceeded +type cancelableTimeoutContext struct { + context.Context +} + +func (c cancelableTimeoutContext) Err() error { + if c.Context.Err() != nil { + return context.DeadlineExceeded + } + return nil +} + func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } func testTimeoutHandler(t *testing.T, h2 bool) { @@ -2285,8 +2298,9 @@ func testTimeoutHandler(t *testing.T, h2 bool) { _, werr := w.Write([]byte("hi")) writeErrors <- werr }) - timeout := make(chan time.Time, 1) // write to this to force timeouts - cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout)) + ctx, cancel := context.WithCancel(context.Background()) + h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) + cst := newClientServerTest(t, h2, h) defer cst.close() // Succeed without timing out: @@ -2307,7 +2321,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) { } // Times out: - timeout <- time.Time{} + cancel() + res, err = cst.c.Get(cst.ts.URL) if err != nil { t.Error(err) @@ -2428,8 +2443,9 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { _, werr := w.Write([]byte("hi")) writeErrors <- werr }) - timeout := make(chan time.Time, 1) // write to this to force timeouts - cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout)) + ctx, cancel := context.WithCancel(context.Background()) + h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) + cst := newClientServerTest(t, h1Mode, h) defer cst.close() // Succeed without timing out: @@ -2450,7 +2466,8 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { } // Times out: - timeout <- time.Time{} + cancel() + res, err = cst.c.Get(cst.ts.URL) if err != nil { t.Error(err) @@ -2500,6 +2517,47 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { } } +func TestTimeoutHandlerContextCanceled(t *testing.T) { + setParallel(t) + defer afterTest(t) + writeErrors := make(chan error, 1) + sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/plain") + var err error + // The request context has already been canceled, but + // retry the write for a while to give the timeout handler + // a chance to notice. + for i := 0; i < 100; i++ { + _, err = w.Write([]byte("a")) + if err != nil { + break + } + time.Sleep(1 * time.Millisecond) + } + writeErrors <- err + }) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + h := NewTestTimeoutHandler(sayHi, ctx) + cst := newClientServerTest(t, h1Mode, h) + defer cst.close() + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusServiceUnavailable; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ := io.ReadAll(res.Body) + if g, e := string(body), ""; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g, e := <-writeErrors, context.Canceled; g != e { + t.Errorf("got unexpected Write in handler: %v, want %g", g, e) + } +} + // https://golang.org/issue/15948 func TestTimeoutHandlerEmptyResponse(t *testing.T) { setParallel(t) @@ -2708,7 +2766,7 @@ func TestHandlerPanicWithHijack(t *testing.T) { testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing") } -func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue interface{}) { +func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue any) { defer afterTest(t) // Unlike the other tests that set the log output to io.Discard // to quiet the output, this test uses a pipe. The pipe serves three @@ -3017,22 +3075,14 @@ func TestClientWriteShutdown(t *testing.T) { if err != nil { t.Fatalf("CloseWrite: %v", err) } - donec := make(chan bool) - go func() { - defer close(donec) - bs, err := io.ReadAll(conn) - if err != nil { - t.Errorf("ReadAll: %v", err) - } - got := string(bs) - if got != "" { - t.Errorf("read %q from server; want nothing", got) - } - }() - select { - case <-donec: - case <-time.After(10 * time.Second): - t.Fatalf("timeout") + + bs, err := io.ReadAll(conn) + if err != nil { + t.Errorf("ReadAll: %v", err) + } + got := string(bs) + if got != "" { + t.Errorf("read %q from server; want nothing", got) } } @@ -3884,7 +3934,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // this test fails, it hangs. This helps debugging and I've // added this enough times "temporarily". It now gets added // full time. - errorf := func(format string, args ...interface{}) { + errorf := func(format string, args ...any) { v := fmt.Sprintf(format, args...) println(v) t.Error(v) @@ -3893,10 +3943,10 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { unblockBackend := make(chan bool) backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { gone := rw.(CloseNotifier).CloseNotify() - didCopy := make(chan interface{}) + didCopy := make(chan any) go func() { n, err := io.CopyN(rw, req.Body, bodySize) - didCopy <- []interface{}{n, err} + didCopy <- []any{n, err} }() isGone := false Loop: @@ -4888,7 +4938,7 @@ func TestServerContext_LocalAddrContextKey_h2(t *testing.T) { func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) { setParallel(t) defer afterTest(t) - ch := make(chan interface{}, 1) + ch := make(chan any, 1) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { ch <- r.Context().Value(LocalAddrContextKey) })) @@ -5689,22 +5739,37 @@ func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { } // Not parallel: messes with global variable. (http2goAwayTimeout) defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "%v", r.RemoteAddr) - })) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {})) defer cst.close() srv := cst.ts.Config srv.SetKeepAlivesEnabled(false) - a := cst.getURL(cst.ts.URL) - if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { - t.Fatalf("test server has active conns") - } - b := cst.getURL(cst.ts.URL) - if a == b { - t.Errorf("got same connection between first and second requests") - } - if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { - t.Fatalf("test server has active conns") + for try := 0; try < 2; try++ { + if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { + t.Fatalf("request %v: test server has active conns", try) + } + conns := 0 + var info httptrace.GotConnInfo + ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + GotConn: func(v httptrace.GotConnInfo) { + conns++ + info = v + }, + }) + req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if conns != 1 { + t.Fatalf("request %v: got %v conns, want 1", try, conns) + } + if info.Reused || info.WasIdle { + t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle) + } } } @@ -5933,11 +5998,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) { t.Fatal(err) } - select { - case <-done: - case <-time.After(2 * time.Second): - t.Error("timeout") - } + <-done } // Issue 18319: test that the Server validates the request method. @@ -6232,7 +6293,7 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { // setting contentEncoding as an interface instead of a string // directly, so as to differentiate between 3 states: // unset, empty string "" and set string "foo/bar". - contentEncoding interface{} + contentEncoding any wantContentType string } @@ -6490,7 +6551,7 @@ func TestDisableKeepAliveUpgrade(t *testing.T) { rwc, ok := resp.Body.(io.ReadWriteCloser) if !ok { - t.Fatalf("Response.Body is not a io.ReadWriteCloser: %T", resp.Body) + t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body) } _, err = rwc.Write([]byte("hello")) @@ -6609,3 +6670,63 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon } } } + +func TestMaxBytesHandler(t *testing.T) { + setParallel(t) + defer afterTest(t) + + for _, maxSize := range []int64{100, 1_000, 1_000_000} { + for _, requestSize := range []int64{100, 1_000, 1_000_000} { + t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize), + func(t *testing.T) { + testMaxBytesHandler(t, maxSize, requestSize) + }) + } + } +} + +func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { + var ( + handlerN int64 + handlerErr error + ) + echo := HandlerFunc(func(w ResponseWriter, r *Request) { + var buf bytes.Buffer + handlerN, handlerErr = io.Copy(&buf, r.Body) + io.Copy(w, &buf) + }) + + ts := httptest.NewServer(MaxBytesHandler(echo, maxSize)) + defer ts.Close() + + c := ts.Client() + var buf strings.Builder + body := strings.NewReader(strings.Repeat("a", int(requestSize))) + res, err := c.Post(ts.URL, "text/plain", body) + if err != nil { + t.Errorf("unexpected connection error: %v", err) + } else { + _, err = io.Copy(&buf, res.Body) + res.Body.Close() + if err != nil { + t.Errorf("unexpected read error: %v", err) + } + } + if handlerN > maxSize { + t.Errorf("expected max request body %d; got %d", maxSize, handlerN) + } + if requestSize > maxSize && handlerErr == nil { + t.Error("expected error on handler side; got nil") + } + if requestSize <= maxSize { + if handlerErr != nil { + t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr) + } + if handlerN != requestSize { + t.Errorf("expected request of size %d; got %d", requestSize, handlerN) + } + } + if buf.Len() != int(handlerN) { + t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len()) + } +} |