summaryrefslogtreecommitdiff
path: root/libgo/go/net/http/httputil/reverseproxy.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/net/http/httputil/reverseproxy.go')
-rw-r--r--libgo/go/net/http/httputil/reverseproxy.go58
1 files changed, 37 insertions, 21 deletions
diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go
index 9c4bd6e09a5..134c452999d 100644
--- a/libgo/go/net/http/httputil/reverseproxy.go
+++ b/libgo/go/net/http/httputil/reverseproxy.go
@@ -17,6 +17,10 @@ import (
"time"
)
+// onExitFlushLoop is a callback set by tests to detect the state of the
+// flushLoop() goroutine.
+var onExitFlushLoop func()
+
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
@@ -102,8 +106,14 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
outreq.Header.Del("Connection")
}
- if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
- outreq.Header.Set("X-Forwarded-For", clientIp)
+ if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
+ // If we aren't the first proxy retain prior
+ // X-Forwarded-For information as a comma+space
+ // separated list and fold multiple headers into one.
+ if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
+ clientIP = strings.Join(prior, ", ") + ", " + clientIP
+ }
+ outreq.Header.Set("X-Forwarded-For", clientIP)
}
res, err := transport.RoundTrip(outreq)
@@ -112,20 +122,29 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusInternalServerError)
return
}
+ defer res.Body.Close()
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
+ p.copyResponse(rw, res.Body)
+}
- if res.Body != nil {
- var dst io.Writer = rw
- if p.FlushInterval != 0 {
- if wf, ok := rw.(writeFlusher); ok {
- dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval}
+func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
+ if p.FlushInterval != 0 {
+ if wf, ok := dst.(writeFlusher); ok {
+ mlw := &maxLatencyWriter{
+ dst: wf,
+ latency: p.FlushInterval,
+ done: make(chan bool),
}
+ go mlw.flushLoop()
+ defer mlw.stop()
+ dst = mlw
}
- io.Copy(dst, res.Body)
}
+
+ io.Copy(dst, src)
}
type writeFlusher interface {
@@ -137,22 +156,14 @@ type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration
- lk sync.Mutex // protects init of done, as well Write + Flush
+ lk sync.Mutex // protects Write + Flush
done chan bool
}
-func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
+func (m *maxLatencyWriter) Write(p []byte) (int, error) {
m.lk.Lock()
defer m.lk.Unlock()
- if m.done == nil {
- m.done = make(chan bool)
- go m.flushLoop()
- }
- n, err = m.dst.Write(p)
- if err != nil {
- m.done <- true
- }
- return
+ return m.dst.Write(p)
}
func (m *maxLatencyWriter) flushLoop() {
@@ -160,13 +171,18 @@ func (m *maxLatencyWriter) flushLoop() {
defer t.Stop()
for {
select {
+ case <-m.done:
+ if onExitFlushLoop != nil {
+ onExitFlushLoop()
+ }
+ return
case <-t.C:
m.lk.Lock()
m.dst.Flush()
m.lk.Unlock()
- case <-m.done:
- return
}
}
panic("unreached")
}
+
+func (m *maxLatencyWriter) stop() { m.done <- true }