diff options
Diffstat (limited to 'workhorse/internal/dependencyproxy')
-rw-r--r-- | workhorse/internal/dependencyproxy/dependencyproxy.go | 123 | ||||
-rw-r--r-- | workhorse/internal/dependencyproxy/dependencyproxy_test.go | 183 |
2 files changed, 306 insertions, 0 deletions
diff --git a/workhorse/internal/dependencyproxy/dependencyproxy.go b/workhorse/internal/dependencyproxy/dependencyproxy.go new file mode 100644 index 00000000000..cfb3045544f --- /dev/null +++ b/workhorse/internal/dependencyproxy/dependencyproxy.go @@ -0,0 +1,123 @@ +package dependencyproxy + +import ( + "context" + "fmt" + "io" + "net" + "net/http" + "time" + + "gitlab.com/gitlab-org/labkit/correlation" + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/tracing" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/senddata" +) + +// httpTransport defines a http.Transport with values +// that are more restrictive than for http.DefaultTransport, +// they define shorter TLS Handshake, and more aggressive connection closing +// to prevent the connection hanging and reduce FD usage +var httpTransport = tracing.NewRoundTripper(correlation.NewInstrumentedRoundTripper(&http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 10 * time.Second, + }).DialContext, + MaxIdleConns: 2, + IdleConnTimeout: 30 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 10 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, +})) + +var httpClient = &http.Client{ + Transport: httpTransport, +} + +type Injector struct { + senddata.Prefix + uploadHandler http.Handler +} + +type entryParams struct { + Url string + Header http.Header +} + +type nullResponseWriter struct { + header http.Header + status int +} + +func (nullResponseWriter) Write(p []byte) (int, error) { + return len(p), nil +} + +func (w *nullResponseWriter) Header() http.Header { + return w.header +} + +func (w *nullResponseWriter) WriteHeader(status int) { + if w.status == 0 { + w.status = status + } +} + +func NewInjector() *Injector { + return &Injector{Prefix: "send-dependency:"} +} + +func (p *Injector) SetUploadHandler(uploadHandler http.Handler) { + p.uploadHandler = uploadHandler +} + +func (p *Injector) Inject(w http.ResponseWriter, r *http.Request, sendData string) { + dependencyResponse, err := p.fetchUrl(r.Context(), sendData) + if err != nil { + helper.Fail500(w, r, err) + return + } + defer dependencyResponse.Body.Close() + if dependencyResponse.StatusCode >= 400 { + w.WriteHeader(dependencyResponse.StatusCode) + io.Copy(w, dependencyResponse.Body) + return + } + + teeReader := io.TeeReader(dependencyResponse.Body, w) + saveFileRequest, err := http.NewRequestWithContext(r.Context(), "POST", r.URL.String()+"/upload", teeReader) + if err != nil { + helper.Fail500(w, r, fmt.Errorf("dependency proxy: failed to create request: %w", err)) + } + saveFileRequest.Header = helper.HeaderClone(r.Header) + saveFileRequest.ContentLength = dependencyResponse.ContentLength + + w.Header().Del("Content-Length") + + nrw := &nullResponseWriter{header: make(http.Header)} + p.uploadHandler.ServeHTTP(nrw, saveFileRequest) + + if nrw.status != http.StatusOK { + fields := log.Fields{"code": nrw.status} + + helper.Fail500WithFields(nrw, r, fmt.Errorf("dependency proxy: failed to upload file"), fields) + } +} + +func (p *Injector) fetchUrl(ctx context.Context, sendData string) (*http.Response, error) { + var params entryParams + if err := p.Unpack(¶ms, sendData); err != nil { + return nil, fmt.Errorf("dependency proxy: unpack sendData: %v", err) + } + + r, err := http.NewRequestWithContext(ctx, "GET", params.Url, nil) + if err != nil { + return nil, fmt.Errorf("dependency proxy: failed to fetch dependency: %v", err) + } + r.Header = params.Header + + return httpClient.Do(r) +} diff --git a/workhorse/internal/dependencyproxy/dependencyproxy_test.go b/workhorse/internal/dependencyproxy/dependencyproxy_test.go new file mode 100644 index 00000000000..37e54c0b756 --- /dev/null +++ b/workhorse/internal/dependencyproxy/dependencyproxy_test.go @@ -0,0 +1,183 @@ +package dependencyproxy + +import ( + "encoding/base64" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/testhelper" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/upload" +) + +type fakeUploadHandler struct { + request *http.Request + body []byte + handler func(w http.ResponseWriter, r *http.Request) +} + +func (f *fakeUploadHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + f.request = r + + f.body, _ = io.ReadAll(r.Body) + + f.handler(w, r) +} + +type errWriter struct{ writes int } + +func (w *errWriter) Header() http.Header { return nil } +func (w *errWriter) WriteHeader(h int) {} + +// First call of Write function succeeds while all the subsequent ones fail +func (w *errWriter) Write(p []byte) (int, error) { + if w.writes > 0 { + return 0, fmt.Errorf("client error") + } + + w.writes++ + + return len(p), nil +} + +type fakePreAuthHandler struct{} + +func (f *fakePreAuthHandler) PreAuthorizeHandler(handler api.HandleFunc, _ string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler(w, r, &api.Response{TempPath: "../../testdata/scratch"}) + }) +} + +func TestInject(t *testing.T) { + contentLength := 32768 + 1 + content := strings.Repeat("p", contentLength) + + testCases := []struct { + desc string + responseWriter http.ResponseWriter + contentLength int + handlerMustBeCalled bool + }{ + { + desc: "the uploading successfully finalized", + responseWriter: httptest.NewRecorder(), + contentLength: contentLength, + handlerMustBeCalled: true, + }, { + desc: "a user failed to receive the response", + responseWriter: &errWriter{}, + contentLength: contentLength, + handlerMustBeCalled: false, + }, { + desc: "the origin resource server returns partial response", + responseWriter: httptest.NewRecorder(), + contentLength: contentLength + 1, + handlerMustBeCalled: false, + }, + } + testhelper.ConfigureSecret() + + for _, tc := range testCases { + originResourceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", strconv.Itoa(tc.contentLength)) + w.Write([]byte(content)) + })) + defer originResourceServer.Close() + + // BodyUploader expects http.Handler as its second param, we can create a stub function and verify that + // it's only called for successful requests + handlerIsCalled := false + handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handlerIsCalled = true }) + + bodyUploader := upload.BodyUploader(&fakePreAuthHandler{}, handlerFunc, &upload.DefaultPreparer{}) + + injector := NewInjector() + injector.SetUploadHandler(bodyUploader) + + r := httptest.NewRequest("GET", "/target", nil) + sendData := base64.StdEncoding.EncodeToString([]byte(`{"Token": "token", "Url": "` + originResourceServer.URL + `/url"}`)) + + injector.Inject(tc.responseWriter, r, sendData) + + require.Equal(t, tc.handlerMustBeCalled, handlerIsCalled, "a partial file must not be saved") + } +} + +func TestSuccessfullRequest(t *testing.T) { + content := []byte("result") + originResourceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", strconv.Itoa(len(content))) + w.Write(content) + })) + + uploadHandler := &fakeUploadHandler{ + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + }, + } + + injector := NewInjector() + injector.SetUploadHandler(uploadHandler) + + response := makeRequest(injector, `{"Token": "token", "Url": "`+originResourceServer.URL+`/url"}`) + + require.Equal(t, "/target/upload", uploadHandler.request.URL.Path) + require.Equal(t, int64(6), uploadHandler.request.ContentLength) + + require.Equal(t, content, uploadHandler.body) + + require.Equal(t, 200, response.Code) + require.Equal(t, string(content), response.Body.String()) +} + +func TestIncorrectSendData(t *testing.T) { + response := makeRequest(NewInjector(), "") + + require.Equal(t, 500, response.Code) + require.Equal(t, "Internal server error\n", response.Body.String()) +} + +func TestIncorrectSendDataUrl(t *testing.T) { + response := makeRequest(NewInjector(), `{"Token": "token", "Url": "url"}`) + + require.Equal(t, 500, response.Code) + require.Equal(t, "Internal server error\n", response.Body.String()) +} + +func TestFailedOriginServer(t *testing.T) { + originResourceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte("Not found")) + })) + + uploadHandler := &fakeUploadHandler{ + handler: func(w http.ResponseWriter, r *http.Request) { + require.FailNow(t, "the error response must not be uploaded") + }, + } + + injector := NewInjector() + injector.SetUploadHandler(uploadHandler) + + response := makeRequest(injector, `{"Token": "token", "Url": "`+originResourceServer.URL+`/url"}`) + + require.Equal(t, 404, response.Code) + require.Equal(t, "Not found", response.Body.String()) +} + +func makeRequest(injector *Injector, data string) *httptest.ResponseRecorder { + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/target", nil) + + sendData := base64.StdEncoding.EncodeToString([]byte(data)) + injector.Inject(w, r, sendData) + + return w +} |