summaryrefslogtreecommitdiff
path: root/workhorse/internal/dependencyproxy
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/internal/dependencyproxy')
-rw-r--r--workhorse/internal/dependencyproxy/dependencyproxy.go123
-rw-r--r--workhorse/internal/dependencyproxy/dependencyproxy_test.go183
2 files changed, 306 insertions, 0 deletions
diff --git a/workhorse/internal/dependencyproxy/dependencyproxy.go b/workhorse/internal/dependencyproxy/dependencyproxy.go
new file mode 100644
index 00000000000..cfb3045544f
--- /dev/null
+++ b/workhorse/internal/dependencyproxy/dependencyproxy.go
@@ -0,0 +1,123 @@
+package dependencyproxy
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "time"
+
+ "gitlab.com/gitlab-org/labkit/correlation"
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/tracing"
+
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/senddata"
+)
+
+// httpTransport defines a http.Transport with values
+// that are more restrictive than for http.DefaultTransport,
+// they define shorter TLS Handshake, and more aggressive connection closing
+// to prevent the connection hanging and reduce FD usage
+var httpTransport = tracing.NewRoundTripper(correlation.NewInstrumentedRoundTripper(&http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ DialContext: (&net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 10 * time.Second,
+ }).DialContext,
+ MaxIdleConns: 2,
+ IdleConnTimeout: 30 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ExpectContinueTimeout: 10 * time.Second,
+ ResponseHeaderTimeout: 30 * time.Second,
+}))
+
+var httpClient = &http.Client{
+ Transport: httpTransport,
+}
+
+type Injector struct {
+ senddata.Prefix
+ uploadHandler http.Handler
+}
+
+type entryParams struct {
+ Url string
+ Header http.Header
+}
+
+type nullResponseWriter struct {
+ header http.Header
+ status int
+}
+
+func (nullResponseWriter) Write(p []byte) (int, error) {
+ return len(p), nil
+}
+
+func (w *nullResponseWriter) Header() http.Header {
+ return w.header
+}
+
+func (w *nullResponseWriter) WriteHeader(status int) {
+ if w.status == 0 {
+ w.status = status
+ }
+}
+
+func NewInjector() *Injector {
+ return &Injector{Prefix: "send-dependency:"}
+}
+
+func (p *Injector) SetUploadHandler(uploadHandler http.Handler) {
+ p.uploadHandler = uploadHandler
+}
+
+func (p *Injector) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
+ dependencyResponse, err := p.fetchUrl(r.Context(), sendData)
+ if err != nil {
+ helper.Fail500(w, r, err)
+ return
+ }
+ defer dependencyResponse.Body.Close()
+ if dependencyResponse.StatusCode >= 400 {
+ w.WriteHeader(dependencyResponse.StatusCode)
+ io.Copy(w, dependencyResponse.Body)
+ return
+ }
+
+ teeReader := io.TeeReader(dependencyResponse.Body, w)
+ saveFileRequest, err := http.NewRequestWithContext(r.Context(), "POST", r.URL.String()+"/upload", teeReader)
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("dependency proxy: failed to create request: %w", err))
+ }
+ saveFileRequest.Header = helper.HeaderClone(r.Header)
+ saveFileRequest.ContentLength = dependencyResponse.ContentLength
+
+ w.Header().Del("Content-Length")
+
+ nrw := &nullResponseWriter{header: make(http.Header)}
+ p.uploadHandler.ServeHTTP(nrw, saveFileRequest)
+
+ if nrw.status != http.StatusOK {
+ fields := log.Fields{"code": nrw.status}
+
+ helper.Fail500WithFields(nrw, r, fmt.Errorf("dependency proxy: failed to upload file"), fields)
+ }
+}
+
+func (p *Injector) fetchUrl(ctx context.Context, sendData string) (*http.Response, error) {
+ var params entryParams
+ if err := p.Unpack(&params, sendData); err != nil {
+ return nil, fmt.Errorf("dependency proxy: unpack sendData: %v", err)
+ }
+
+ r, err := http.NewRequestWithContext(ctx, "GET", params.Url, nil)
+ if err != nil {
+ return nil, fmt.Errorf("dependency proxy: failed to fetch dependency: %v", err)
+ }
+ r.Header = params.Header
+
+ return httpClient.Do(r)
+}
diff --git a/workhorse/internal/dependencyproxy/dependencyproxy_test.go b/workhorse/internal/dependencyproxy/dependencyproxy_test.go
new file mode 100644
index 00000000000..37e54c0b756
--- /dev/null
+++ b/workhorse/internal/dependencyproxy/dependencyproxy_test.go
@@ -0,0 +1,183 @@
+package dependencyproxy
+
+import (
+ "encoding/base64"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/testhelper"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/upload"
+)
+
+type fakeUploadHandler struct {
+ request *http.Request
+ body []byte
+ handler func(w http.ResponseWriter, r *http.Request)
+}
+
+func (f *fakeUploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ f.request = r
+
+ f.body, _ = io.ReadAll(r.Body)
+
+ f.handler(w, r)
+}
+
+type errWriter struct{ writes int }
+
+func (w *errWriter) Header() http.Header { return nil }
+func (w *errWriter) WriteHeader(h int) {}
+
+// First call of Write function succeeds while all the subsequent ones fail
+func (w *errWriter) Write(p []byte) (int, error) {
+ if w.writes > 0 {
+ return 0, fmt.Errorf("client error")
+ }
+
+ w.writes++
+
+ return len(p), nil
+}
+
+type fakePreAuthHandler struct{}
+
+func (f *fakePreAuthHandler) PreAuthorizeHandler(handler api.HandleFunc, _ string) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ handler(w, r, &api.Response{TempPath: "../../testdata/scratch"})
+ })
+}
+
+func TestInject(t *testing.T) {
+ contentLength := 32768 + 1
+ content := strings.Repeat("p", contentLength)
+
+ testCases := []struct {
+ desc string
+ responseWriter http.ResponseWriter
+ contentLength int
+ handlerMustBeCalled bool
+ }{
+ {
+ desc: "the uploading successfully finalized",
+ responseWriter: httptest.NewRecorder(),
+ contentLength: contentLength,
+ handlerMustBeCalled: true,
+ }, {
+ desc: "a user failed to receive the response",
+ responseWriter: &errWriter{},
+ contentLength: contentLength,
+ handlerMustBeCalled: false,
+ }, {
+ desc: "the origin resource server returns partial response",
+ responseWriter: httptest.NewRecorder(),
+ contentLength: contentLength + 1,
+ handlerMustBeCalled: false,
+ },
+ }
+ testhelper.ConfigureSecret()
+
+ for _, tc := range testCases {
+ originResourceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Length", strconv.Itoa(tc.contentLength))
+ w.Write([]byte(content))
+ }))
+ defer originResourceServer.Close()
+
+ // BodyUploader expects http.Handler as its second param, we can create a stub function and verify that
+ // it's only called for successful requests
+ handlerIsCalled := false
+ handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handlerIsCalled = true })
+
+ bodyUploader := upload.BodyUploader(&fakePreAuthHandler{}, handlerFunc, &upload.DefaultPreparer{})
+
+ injector := NewInjector()
+ injector.SetUploadHandler(bodyUploader)
+
+ r := httptest.NewRequest("GET", "/target", nil)
+ sendData := base64.StdEncoding.EncodeToString([]byte(`{"Token": "token", "Url": "` + originResourceServer.URL + `/url"}`))
+
+ injector.Inject(tc.responseWriter, r, sendData)
+
+ require.Equal(t, tc.handlerMustBeCalled, handlerIsCalled, "a partial file must not be saved")
+ }
+}
+
+func TestSuccessfullRequest(t *testing.T) {
+ content := []byte("result")
+ originResourceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Length", strconv.Itoa(len(content)))
+ w.Write(content)
+ }))
+
+ uploadHandler := &fakeUploadHandler{
+ handler: func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(200)
+ },
+ }
+
+ injector := NewInjector()
+ injector.SetUploadHandler(uploadHandler)
+
+ response := makeRequest(injector, `{"Token": "token", "Url": "`+originResourceServer.URL+`/url"}`)
+
+ require.Equal(t, "/target/upload", uploadHandler.request.URL.Path)
+ require.Equal(t, int64(6), uploadHandler.request.ContentLength)
+
+ require.Equal(t, content, uploadHandler.body)
+
+ require.Equal(t, 200, response.Code)
+ require.Equal(t, string(content), response.Body.String())
+}
+
+func TestIncorrectSendData(t *testing.T) {
+ response := makeRequest(NewInjector(), "")
+
+ require.Equal(t, 500, response.Code)
+ require.Equal(t, "Internal server error\n", response.Body.String())
+}
+
+func TestIncorrectSendDataUrl(t *testing.T) {
+ response := makeRequest(NewInjector(), `{"Token": "token", "Url": "url"}`)
+
+ require.Equal(t, 500, response.Code)
+ require.Equal(t, "Internal server error\n", response.Body.String())
+}
+
+func TestFailedOriginServer(t *testing.T) {
+ originResourceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(404)
+ w.Write([]byte("Not found"))
+ }))
+
+ uploadHandler := &fakeUploadHandler{
+ handler: func(w http.ResponseWriter, r *http.Request) {
+ require.FailNow(t, "the error response must not be uploaded")
+ },
+ }
+
+ injector := NewInjector()
+ injector.SetUploadHandler(uploadHandler)
+
+ response := makeRequest(injector, `{"Token": "token", "Url": "`+originResourceServer.URL+`/url"}`)
+
+ require.Equal(t, 404, response.Code)
+ require.Equal(t, "Not found", response.Body.String())
+}
+
+func makeRequest(injector *Injector, data string) *httptest.ResponseRecorder {
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest("GET", "/target", nil)
+
+ sendData := base64.StdEncoding.EncodeToString([]byte(data))
+ injector.Inject(w, r, sendData)
+
+ return w
+}