summaryrefslogtreecommitdiff
path: root/workhorse/internal/senddata/senddata.go
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/internal/senddata/senddata.go')
-rw-r--r--workhorse/internal/senddata/senddata.go105
1 files changed, 105 insertions, 0 deletions
diff --git a/workhorse/internal/senddata/senddata.go b/workhorse/internal/senddata/senddata.go
new file mode 100644
index 00000000000..c287d2574fa
--- /dev/null
+++ b/workhorse/internal/senddata/senddata.go
@@ -0,0 +1,105 @@
+package senddata
+
+import (
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata/contentprocessor"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+var (
+ sendDataResponses = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_senddata_responses",
+ Help: "How many HTTP responses have been hijacked by a workhorse senddata injecter",
+ },
+ []string{"injecter"},
+ )
+ sendDataResponseBytes = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_senddata_response_bytes",
+ Help: "How many bytes have been written by workhorse senddata response injecters",
+ },
+ []string{"injecter"},
+ )
+)
+
+type sendDataResponseWriter struct {
+ rw http.ResponseWriter
+ status int
+ hijacked bool
+ req *http.Request
+ injecters []Injecter
+}
+
+func SendData(h http.Handler, injecters ...Injecter) http.Handler {
+ return contentprocessor.SetContentHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ s := sendDataResponseWriter{
+ rw: w,
+ req: r,
+ injecters: injecters,
+ }
+ defer s.flush()
+ h.ServeHTTP(&s, r)
+ }))
+}
+
+func (s *sendDataResponseWriter) Header() http.Header {
+ return s.rw.Header()
+}
+
+func (s *sendDataResponseWriter) Write(data []byte) (int, error) {
+ if s.status == 0 {
+ s.WriteHeader(http.StatusOK)
+ }
+ if s.hijacked {
+ return len(data), nil
+ }
+ return s.rw.Write(data)
+}
+
+func (s *sendDataResponseWriter) WriteHeader(status int) {
+ if s.status != 0 {
+ return
+ }
+ s.status = status
+
+ if s.status == http.StatusOK && s.tryInject() {
+ return
+ }
+
+ s.rw.WriteHeader(s.status)
+}
+
+func (s *sendDataResponseWriter) tryInject() bool {
+ if s.hijacked {
+ return false
+ }
+
+ header := s.Header().Get(headers.GitlabWorkhorseSendDataHeader)
+ if header == "" {
+ return false
+ }
+
+ for _, injecter := range s.injecters {
+ if injecter.Match(header) {
+ s.hijacked = true
+ helper.DisableResponseBuffering(s.rw)
+ crw := helper.NewCountingResponseWriter(s.rw)
+ injecter.Inject(crw, s.req, header)
+ sendDataResponses.WithLabelValues(injecter.Name()).Inc()
+ sendDataResponseBytes.WithLabelValues(injecter.Name()).Add(float64(crw.Count()))
+ return true
+ }
+ }
+
+ return false
+}
+
+func (s *sendDataResponseWriter) flush() {
+ s.WriteHeader(http.StatusOK)
+}