diff options
Diffstat (limited to 'libgo/go/net/http/httptest/server.go')
-rw-r--r-- | libgo/go/net/http/httptest/server.go | 47 |
1 files changed, 46 insertions, 1 deletions
diff --git a/libgo/go/net/http/httptest/server.go b/libgo/go/net/http/httptest/server.go index 5e9ace591f3..e543672b1e8 100644 --- a/libgo/go/net/http/httptest/server.go +++ b/libgo/go/net/http/httptest/server.go @@ -9,6 +9,7 @@ package httptest import ( "bytes" "crypto/tls" + "crypto/x509" "flag" "fmt" "log" @@ -35,6 +36,9 @@ type Server struct { // before Start or StartTLS. Config *http.Server + // certificate is a parsed version of the TLS config certificate, if present. + certificate *x509.Certificate + // wg counts the number of outstanding HTTP requests on this server. // Close blocks until all requests are finished. wg sync.WaitGroup @@ -42,6 +46,10 @@ type Server struct { mu sync.Mutex // guards closed and conns closed bool conns map[net.Conn]http.ConnState // except terminal states + + // client is configured for use with the server. + // Its transport is automatically closed when Close is called. + client *http.Client } func newLocalListener() net.Listener { @@ -93,6 +101,9 @@ func (s *Server) Start() { if s.URL != "" { panic("Server already started") } + if s.client == nil { + s.client = &http.Client{Transport: &http.Transport{}} + } s.URL = "http://" + s.Listener.Addr().String() s.wrap() s.goServe() @@ -107,6 +118,9 @@ func (s *Server) StartTLS() { if s.URL != "" { panic("Server already started") } + if s.client == nil { + s.client = &http.Client{Transport: &http.Transport{}} + } cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) if err != nil { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) @@ -124,6 +138,17 @@ func (s *Server) StartTLS() { if len(s.TLS.Certificates) == 0 { s.TLS.Certificates = []tls.Certificate{cert} } + s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0]) + if err != nil { + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + certpool := x509.NewCertPool() + certpool.AddCert(s.certificate) + s.client.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certpool, + }, + } s.Listener = tls.NewListener(s.Listener, s.TLS) s.URL = "https://" + s.Listener.Addr().String() s.wrap() @@ -186,6 +211,13 @@ func (s *Server) Close() { t.CloseIdleConnections() } + // Also close the client idle connections. + if s.client != nil { + if t, ok := s.client.Transport.(closeIdleTransport); ok { + t.CloseIdleConnections() + } + } + s.wg.Wait() } @@ -206,7 +238,7 @@ func (s *Server) CloseClientConnections() { nconn := len(s.conns) ch := make(chan struct{}, nconn) for c := range s.conns { - s.closeConnChan(c, ch) + go s.closeConnChan(c, ch) } s.mu.Unlock() @@ -228,6 +260,19 @@ func (s *Server) CloseClientConnections() { } } +// Certificate returns the certificate used by the server, or nil if +// the server doesn't use TLS. +func (s *Server) Certificate() *x509.Certificate { + return s.certificate +} + +// Client returns an HTTP client configured for making requests to the server. +// It is configured to trust the server's TLS test certificate and will +// close its idle connections on Server.Close. +func (s *Server) Client() *http.Client { + return s.client +} + func (s *Server) goServe() { s.wg.Add(1) go func() { |