diff options
Diffstat (limited to 'workhorse/internal')
19 files changed, 337 insertions, 151 deletions
diff --git a/workhorse/internal/api/api.go b/workhorse/internal/api/api.go index 7f696f70c7a..896f59a322a 100644 --- a/workhorse/internal/api/api.go +++ b/workhorse/internal/api/api.go @@ -64,7 +64,13 @@ func NewAPI(myURL *url.URL, version string, roundTripper http.RoundTripper) *API } type GeoProxyEndpointResponse struct { - GeoProxyURL string `json:"geo_proxy_url"` + GeoProxyURL string `json:"geo_proxy_url"` + GeoProxyExtraData string `json:"geo_proxy_extra_data"` +} + +type GeoProxyData struct { + GeoProxyURL *url.URL + GeoProxyExtraData string } type HandleFunc func(http.ResponseWriter, *http.Request, *Response) @@ -394,7 +400,7 @@ func validResponseContentType(resp *http.Response) bool { return helper.IsContentType(ResponseContentType, resp.Header.Get("Content-Type")) } -func (api *API) GetGeoProxyURL() (*url.URL, error) { +func (api *API) GetGeoProxyData() (*GeoProxyData, error) { geoProxyApiUrl := *api.URL geoProxyApiUrl.Path, geoProxyApiUrl.RawPath = joinURLPath(api.URL, geoProxyEndpointPath) geoProxyApiReq := &http.Request{ @@ -405,23 +411,26 @@ func (api *API) GetGeoProxyURL() (*url.URL, error) { httpResponse, err := api.doRequestWithoutRedirects(geoProxyApiReq) if err != nil { - return nil, fmt.Errorf("GetGeoProxyURL: do request: %v", err) + return nil, fmt.Errorf("GetGeoProxyData: do request: %v", err) } defer httpResponse.Body.Close() if httpResponse.StatusCode != http.StatusOK { - return nil, fmt.Errorf("GetGeoProxyURL: Received HTTP status code: %v", httpResponse.StatusCode) + return nil, fmt.Errorf("GetGeoProxyData: Received HTTP status code: %v", httpResponse.StatusCode) } response := &GeoProxyEndpointResponse{} if err := json.NewDecoder(httpResponse.Body).Decode(response); err != nil { - return nil, fmt.Errorf("GetGeoProxyURL: decode response: %v", err) + return nil, fmt.Errorf("GetGeoProxyData: decode response: %v", err) } geoProxyURL, err := url.Parse(response.GeoProxyURL) if err != nil { - return nil, fmt.Errorf("GetGeoProxyURL: Could not parse Geo proxy URL: %v, err: %v", response.GeoProxyURL, err) + return nil, fmt.Errorf("GetGeoProxyData: Could not parse Geo proxy URL: %v, err: %v", response.GeoProxyURL, err) } - return geoProxyURL, nil + return &GeoProxyData{ + GeoProxyURL: geoProxyURL, + GeoProxyExtraData: response.GeoProxyExtraData, + }, nil } diff --git a/workhorse/internal/api/api_test.go b/workhorse/internal/api/api_test.go index b82bb55fb85..346f32b4a36 100644 --- a/workhorse/internal/api/api_test.go +++ b/workhorse/internal/api/api_test.go @@ -4,7 +4,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "net/url" "regexp" "testing" @@ -18,21 +17,37 @@ import ( "gitlab.com/gitlab-org/gitlab/workhorse/internal/upstream/roundtripper" ) -func TestGetGeoProxyURLWhenGeoSecondary(t *testing.T) { - geoProxyURL, err := getGeoProxyURLGivenResponse(t, `{"geo_proxy_url":"http://primary"}`) - - require.NoError(t, err) - require.Equal(t, "http://primary", geoProxyURL.String()) -} - -func TestGetGeoProxyURLWhenGeoPrimaryOrNonGeo(t *testing.T) { - geoProxyURL, err := getGeoProxyURLGivenResponse(t, "{}") - - require.NoError(t, err) - require.Equal(t, "", geoProxyURL.String()) +func TestGetGeoProxyDataForResponses(t *testing.T) { + testCases := []struct { + desc string + json string + expectedError bool + expectedURL string + expectedExtraData string + }{ + {"when Geo secondary", `{"geo_proxy_url":"http://primary","geo_proxy_extra_data":"geo-data"}`, false, "http://primary", "geo-data"}, + {"when Geo secondary with explicit null data", `{"geo_proxy_url":"http://primary","geo_proxy_extra_data":null}`, false, "http://primary", ""}, + {"when Geo secondary without extra data", `{"geo_proxy_url":"http://primary"}`, false, "http://primary", ""}, + {"when Geo primary or no node", `{}`, false, "", ""}, + {"for malformed request", `non-json`, true, "", ""}, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + geoProxyData, err := getGeoProxyDataGivenResponse(t, tc.json) + + if tc.expectedError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedURL, geoProxyData.GeoProxyURL.String()) + require.Equal(t, tc.expectedExtraData, geoProxyData.GeoProxyExtraData) + } + }) + } } -func getGeoProxyURLGivenResponse(t *testing.T, givenInternalApiResponse string) (*url.URL, error) { +func getGeoProxyDataGivenResponse(t *testing.T, givenInternalApiResponse string) (*GeoProxyData, error) { t.Helper() ts := testRailsServer(regexp.MustCompile(`/api/v4/geo/proxy`), 200, givenInternalApiResponse) defer ts.Close() @@ -43,9 +58,9 @@ func getGeoProxyURLGivenResponse(t *testing.T, givenInternalApiResponse string) apiClient := NewAPI(backend, version, rt) - geoProxyURL, err := apiClient.GetGeoProxyURL() + geoProxyData, err := apiClient.GetGeoProxyData() - return geoProxyURL, err + return geoProxyData, err } func testRailsServer(url *regexp.Regexp, code int, body string) *httptest.Server { diff --git a/workhorse/internal/git/info-refs.go b/workhorse/internal/git/info-refs.go index 8390143b99b..b7f825839f8 100644 --- a/workhorse/internal/git/info-refs.go +++ b/workhorse/internal/git/info-refs.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "sync" "github.com/golang/gddo/httputil" grpccodes "google.golang.org/grpc/codes" @@ -64,21 +65,43 @@ func handleGetInfoRefsWithGitaly(ctx context.Context, responseWriter *HttpRespon return err } - var w io.Writer - + var w io.WriteCloser = nopCloser{responseWriter} if encoding == "gzip" { - gzWriter := gzip.NewWriter(responseWriter) - w = gzWriter - defer gzWriter.Close() + gzWriter := getGzWriter(responseWriter) + defer putGzWriter(gzWriter) + w = gzWriter responseWriter.Header().Set("Content-Encoding", "gzip") - } else { - w = responseWriter } if _, err = io.Copy(w, infoRefsResponseReader); err != nil { return err } + if err := w.Close(); err != nil { + return err + } + return nil } + +var gzipPool = &sync.Pool{New: func() interface{} { + // Invariant: the inner writer is io.Discard. We do not want to retain + // response writers of past requests in the pool. + return gzip.NewWriter(io.Discard) +}} + +func getGzWriter(w io.Writer) *gzip.Writer { + gzWriter := gzipPool.Get().(*gzip.Writer) + gzWriter.Reset(w) + return gzWriter +} + +func putGzWriter(w *gzip.Writer) { + w.Reset(io.Discard) // Maintain pool invariant + gzipPool.Put(w) +} + +type nopCloser struct{ io.Writer } + +func (nc nopCloser) Close() error { return nil } diff --git a/workhorse/internal/headers/content_headers.go b/workhorse/internal/headers/content_headers.go index 9c33ddb8c8a..8cca3d97e82 100644 --- a/workhorse/internal/headers/content_headers.go +++ b/workhorse/internal/headers/content_headers.go @@ -8,28 +8,37 @@ import ( ) var ( - ImageTypeRegex = regexp.MustCompile(`^image/*`) - SvgMimeTypeRegex = regexp.MustCompile(`^image/svg\+xml$`) + javaScriptTypeRegex = regexp.MustCompile(`^(text|application)\/javascript$`) - TextTypeRegex = regexp.MustCompile(`^text/*`) + imageTypeRegex = regexp.MustCompile(`^image/*`) + svgMimeTypeRegex = regexp.MustCompile(`^image/svg\+xml$`) - VideoTypeRegex = regexp.MustCompile(`^video/*`) + textTypeRegex = regexp.MustCompile(`^text/*`) - PdfTypeRegex = regexp.MustCompile(`application\/pdf`) + videoTypeRegex = regexp.MustCompile(`^video/*`) - AttachmentRegex = regexp.MustCompile(`^attachment`) - InlineRegex = regexp.MustCompile(`^inline`) + pdfTypeRegex = regexp.MustCompile(`application\/pdf`) + + attachmentRegex = regexp.MustCompile(`^attachment`) + inlineRegex = regexp.MustCompile(`^inline`) ) // Mime types that can't be inlined. Usually subtypes of main types -var forbiddenInlineTypes = []*regexp.Regexp{SvgMimeTypeRegex} +var forbiddenInlineTypes = []*regexp.Regexp{svgMimeTypeRegex} // Mime types that can be inlined. We can add global types like "image/" or // specific types like "text/plain". If there is a specific type inside a global // allowed type that can't be inlined we must add it to the forbiddenInlineTypes var. // One example of this is the mime type "image". We allow all images to be // inlined except for SVGs. -var allowedInlineTypes = []*regexp.Regexp{ImageTypeRegex, TextTypeRegex, VideoTypeRegex, PdfTypeRegex} +var allowedInlineTypes = []*regexp.Regexp{imageTypeRegex, textTypeRegex, videoTypeRegex, pdfTypeRegex} + +const ( + svgContentType = "image/svg+xml" + textPlainContentType = "text/plain; charset=utf-8" + attachmentDispositionText = "attachment" + inlineDispositionText = "inline" +) func SafeContentHeaders(data []byte, contentDisposition string) (string, string) { contentType := safeContentType(data) @@ -40,16 +49,24 @@ func SafeContentHeaders(data []byte, contentDisposition string) (string, string) func safeContentType(data []byte) string { // Special case for svg because DetectContentType detects it as text if svg.Is(data) { - return "image/svg+xml" + return svgContentType } // Override any existing Content-Type header from other ResponseWriters contentType := http.DetectContentType(data) + // http.DetectContentType does not support JavaScript and would only + // return text/plain. But for cautionary measures, just in case they start supporting + // it down the road and start returning application/javascript, we want to handle it now + // to avoid regressions. + if isType(contentType, javaScriptTypeRegex) { + return textPlainContentType + } + // If the content is text type, we set to plain, because we don't // want to render it inline if they're html or javascript - if isType(contentType, TextTypeRegex) { - return "text/plain; charset=utf-8" + if isType(contentType, textTypeRegex) { + return textPlainContentType } return contentType @@ -58,7 +75,7 @@ func safeContentType(data []byte) string { func safeContentDisposition(contentType string, contentDisposition string) string { // If the existing disposition is attachment we return that. This allow us // to force a download from GitLab (ie: RawController) - if AttachmentRegex.MatchString(contentDisposition) { + if attachmentRegex.MatchString(contentDisposition) { return contentDisposition } @@ -82,11 +99,11 @@ func safeContentDisposition(contentType string, contentDisposition string) strin func attachmentDisposition(contentDisposition string) string { if contentDisposition == "" { - return "attachment" + return attachmentDispositionText } - if InlineRegex.MatchString(contentDisposition) { - return InlineRegex.ReplaceAllString(contentDisposition, "attachment") + if inlineRegex.MatchString(contentDisposition) { + return inlineRegex.ReplaceAllString(contentDisposition, attachmentDispositionText) } return contentDisposition @@ -94,11 +111,11 @@ func attachmentDisposition(contentDisposition string) string { func inlineDisposition(contentDisposition string) string { if contentDisposition == "" { - return "inline" + return inlineDispositionText } - if AttachmentRegex.MatchString(contentDisposition) { - return AttachmentRegex.ReplaceAllString(contentDisposition, "inline") + if attachmentRegex.MatchString(contentDisposition) { + return attachmentRegex.ReplaceAllString(contentDisposition, inlineDispositionText) } return contentDisposition diff --git a/workhorse/internal/lsif_transformer/parser/errors.go b/workhorse/internal/lsif_transformer/parser/errors.go deleted file mode 100644 index 1040a789413..00000000000 --- a/workhorse/internal/lsif_transformer/parser/errors.go +++ /dev/null @@ -1,30 +0,0 @@ -package parser - -import ( - "errors" - "strings" -) - -func combineErrors(errsOrNil ...error) error { - var errs []error - for _, err := range errsOrNil { - if err != nil { - errs = append(errs, err) - } - } - - if len(errs) == 0 { - return nil - } - - if len(errs) == 1 { - return errs[0] - } - - var msgs []string - for _, err := range errs { - msgs = append(msgs, err.Error()) - } - - return errors.New(strings.Join(msgs, "\n")) -} diff --git a/workhorse/internal/lsif_transformer/parser/errors_test.go b/workhorse/internal/lsif_transformer/parser/errors_test.go deleted file mode 100644 index 31a7130d05e..00000000000 --- a/workhorse/internal/lsif_transformer/parser/errors_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package parser - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/require" -) - -type customErr struct { - err string -} - -func (e customErr) Error() string { - return e.err -} - -func TestCombineErrors(t *testing.T) { - err := combineErrors(nil, errors.New("first"), nil, customErr{"second"}) - require.EqualError(t, err, "first\nsecond") - - err = customErr{"custom error"} - require.Equal(t, err, combineErrors(nil, err, nil)) - - require.Nil(t, combineErrors(nil, nil, nil)) -} diff --git a/workhorse/internal/lsif_transformer/parser/hovers.go b/workhorse/internal/lsif_transformer/parser/hovers.go index 5889d595ade..a13c7e4c5c2 100644 --- a/workhorse/internal/lsif_transformer/parser/hovers.go +++ b/workhorse/internal/lsif_transformer/parser/hovers.go @@ -95,10 +95,15 @@ func (h *Hovers) For(refId Id) json.RawMessage { } func (h *Hovers) Close() error { - return combineErrors( + for _, err := range []error{ h.File.Close(), h.Offsets.Close(), - ) + } { + if err != nil { + return err + } + } + return nil } func (h *Hovers) addData(line []byte) error { diff --git a/workhorse/internal/lsif_transformer/parser/ranges.go b/workhorse/internal/lsif_transformer/parser/ranges.go index a11a66d70ca..3786e15186e 100644 --- a/workhorse/internal/lsif_transformer/parser/ranges.go +++ b/workhorse/internal/lsif_transformer/parser/ranges.go @@ -130,11 +130,16 @@ func (r *Ranges) Serialize(f io.Writer, rangeIds []Id, docs map[Id]string) error } func (r *Ranges) Close() error { - return combineErrors( + for _, err := range []error{ r.Cache.Close(), r.References.Close(), r.Hovers.Close(), - ) + } { + if err != nil { + return err + } + } + return nil } func (r *Ranges) definitionPathFor(docs map[Id]string, refId Id) string { diff --git a/workhorse/internal/lsif_transformer/parser/references.go b/workhorse/internal/lsif_transformer/parser/references.go index 58ff9a61c02..39c34105fd1 100644 --- a/workhorse/internal/lsif_transformer/parser/references.go +++ b/workhorse/internal/lsif_transformer/parser/references.go @@ -86,10 +86,15 @@ func (r *References) For(docs map[Id]string, refId Id) []SerializedReference { } func (r *References) Close() error { - return combineErrors( + for _, err := range []error{ r.Items.Close(), r.Offsets.Close(), - ) + } { + if err != nil { + return err + } + } + return nil } func (r *References) getItems(refId Id) []Item { diff --git a/workhorse/internal/proxy/proxy.go b/workhorse/internal/proxy/proxy.go index be161c833a9..06e2c65a6a8 100644 --- a/workhorse/internal/proxy/proxy.go +++ b/workhorse/internal/proxy/proxy.go @@ -19,6 +19,7 @@ type Proxy struct { reverseProxy *httputil.ReverseProxy AllowResponseBuffering bool customHeaders map[string]string + forceTargetHostHeader bool } func WithCustomHeaders(customHeaders map[string]string) func(*Proxy) { @@ -27,6 +28,12 @@ func WithCustomHeaders(customHeaders map[string]string) func(*Proxy) { } } +func WithForcedTargetHostHeader() func(*Proxy) { + return func(proxy *Proxy) { + proxy.forceTargetHostHeader = true + } +} + func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper, options ...func(*Proxy)) *Proxy { p := Proxy{Version: version, AllowResponseBuffering: true, customHeaders: make(map[string]string)} @@ -43,6 +50,25 @@ func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper, op option(&p) } + if p.forceTargetHostHeader { + // because of https://github.com/golang/go/issues/28168, the + // upstream won't receive the expected Host header unless this + // is forced in the Director func here + previousDirector := p.reverseProxy.Director + p.reverseProxy.Director = func(request *http.Request) { + previousDirector(request) + + // send original host along for the upstream + // to know it's being proxied under a different Host + // (for redirects and other stuff that depends on this) + request.Header.Set("X-Forwarded-Host", request.Host) + request.Header.Set("Forwarded", fmt.Sprintf("host=%s", request.Host)) + + // override the Host with the target + request.Host = request.URL.Host + } + } + return &p } diff --git a/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go b/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go index 2396bb0f952..b009cda1a24 100644 --- a/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go +++ b/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go @@ -56,12 +56,18 @@ func TestSetProperContentTypeAndDisposition(t *testing.T) { body: "<html><body>Hello world!</body></html>", }, { - desc: "Javascript type", + desc: "Javascript within HTML type", contentType: "text/plain; charset=utf-8", contentDisposition: "inline", body: "<script>alert(\"foo\")</script>", }, { + desc: "Javascript type", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: "alert(\"foo\")", + }, + { desc: "Image type", contentType: "image/png", contentDisposition: "inline", @@ -170,25 +176,41 @@ func TestSetProperContentTypeAndDisposition(t *testing.T) { } func TestFailOverrideContentType(t *testing.T) { - testCase := struct { - contentType string - body string + testCases := []struct { + desc string + overrideFromUpstream string + responseContentType string + body string }{ - contentType: "text/plain; charset=utf-8", - body: "<html><body>Hello world!</body></html>", + { + desc: "Force text/html into text/plain", + responseContentType: "text/plain; charset=utf-8", + overrideFromUpstream: "text/html; charset=utf-8", + body: "<html><body>Hello world!</body></html>", + }, + { + desc: "Force application/javascript into text/plain", + responseContentType: "text/plain; charset=utf-8", + overrideFromUpstream: "application/javascript; charset=utf-8", + body: "alert(1);", + }, } - h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - // We are pretending to be upstream or an inner layer of the ResponseWriter chain - w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") - w.Header().Set(headers.ContentTypeHeader, "text/html; charset=utf-8") - _, err := io.WriteString(w, testCase.body) - require.NoError(t, err) - }) + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // We are pretending to be upstream or an inner layer of the ResponseWriter chain + w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") + w.Header().Set(headers.ContentTypeHeader, tc.overrideFromUpstream) + _, err := io.WriteString(w, tc.body) + require.NoError(t, err) + }) - resp := makeRequest(t, h, testCase.body, "") + resp := makeRequest(t, h, tc.body, "") - require.Equal(t, testCase.contentType, resp.Header.Get(headers.ContentTypeHeader)) + require.Equal(t, tc.responseContentType, resp.Header.Get(headers.ContentTypeHeader)) + }) + } } func TestSuccessOverrideContentDispositionFromInlineToAttachment(t *testing.T) { diff --git a/workhorse/internal/upload/artifacts_upload_test.go b/workhorse/internal/upload/artifacts_upload_test.go index 0a9e4ef3869..96eb3810673 100644 --- a/workhorse/internal/upload/artifacts_upload_test.go +++ b/workhorse/internal/upload/artifacts_upload_test.go @@ -66,7 +66,7 @@ func testArtifactsUploadServer(t *testing.T, authResponse *api.Response, bodyPro if r.Method != "POST" { t.Fatal("Expected POST request") } - if opts.IsLocal() { + if opts.IsLocalTempFile() { if r.FormValue("file.path") == "" { t.Fatal("Expected file to be present") return diff --git a/workhorse/internal/upload/destination/destination.go b/workhorse/internal/upload/destination/destination.go index 7a030e59a64..b18b6e22a99 100644 --- a/workhorse/internal/upload/destination/destination.go +++ b/workhorse/internal/upload/destination/destination.go @@ -128,9 +128,14 @@ func Upload(ctx context.Context, reader io.Reader, size int64, opts *UploadOpts) var uploadDestination consumer var err error switch { - case opts.IsLocal(): - clientMode = "local" + // This case means Workhorse is acting as an upload proxy for Rails and buffers files + // to disk in a temporary location, see: + // https://docs.gitlab.com/ee/development/uploads/background.html#moving-disk-buffering-to-workhorse + case opts.IsLocalTempFile(): + clientMode = "local_tempfile" uploadDestination, err = fh.newLocalFile(ctx, opts) + // All cases below mean we are doing a direct upload to remote i.e. object storage, see: + // https://docs.gitlab.com/ee/development/uploads/background.html#moving-to-object-storage-and-direct-uploads case opts.UseWorkhorseClientEnabled() && opts.ObjectStorageConfig.IsGoCloud(): clientMode = fmt.Sprintf("go_cloud:%s", opts.ObjectStorageConfig.Provider) p := &objectstore.GoCloudObjectParams{ @@ -141,14 +146,14 @@ func Upload(ctx context.Context, reader io.Reader, size int64, opts *UploadOpts) } uploadDestination, err = objectstore.NewGoCloudObject(p) case opts.UseWorkhorseClientEnabled() && opts.ObjectStorageConfig.IsAWS() && opts.ObjectStorageConfig.IsValid(): - clientMode = "s3" + clientMode = "s3_client" uploadDestination, err = objectstore.NewS3Object( opts.RemoteTempObjectID, opts.ObjectStorageConfig.S3Credentials, opts.ObjectStorageConfig.S3Config, ) case opts.IsMultipart(): - clientMode = "multipart" + clientMode = "s3_multipart" uploadDestination, err = objectstore.NewMultipart( opts.PresignedParts, opts.PresignedCompleteMultipart, @@ -158,7 +163,7 @@ func Upload(ctx context.Context, reader io.Reader, size int64, opts *UploadOpts) opts.PartSize, ) default: - clientMode = "http" + clientMode = "presigned_put" uploadDestination, err = objectstore.NewObject( opts.PresignedPut, opts.PresignedDelete, @@ -195,15 +200,15 @@ func Upload(ctx context.Context, reader io.Reader, size int64, opts *UploadOpts) logger := log.WithContextFields(ctx, log.Fields{ "copied_bytes": fh.Size, - "is_local": opts.IsLocal(), + "is_local": opts.IsLocalTempFile(), "is_multipart": opts.IsMultipart(), - "is_remote": !opts.IsLocal(), + "is_remote": !opts.IsLocalTempFile(), "remote_id": opts.RemoteID, "temp_file_prefix": opts.TempFilePrefix, "client_mode": clientMode, }) - if opts.IsLocal() { + if opts.IsLocalTempFile() { logger = logger.WithField("local_temp_path", opts.LocalTempPath) } else { logger = logger.WithField("remote_temp_object", opts.RemoteTempObjectID) diff --git a/workhorse/internal/upload/destination/objectstore/s3_session.go b/workhorse/internal/upload/destination/objectstore/s3_session.go index a0c1f099145..aa38f18ed7a 100644 --- a/workhorse/internal/upload/destination/objectstore/s3_session.go +++ b/workhorse/internal/upload/destination/objectstore/s3_session.go @@ -6,6 +6,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/session" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" @@ -70,7 +71,23 @@ func setupS3Session(s3Credentials config.S3Credentials, s3Config config.S3Config } if s3Config.Endpoint != "" { - cfg.Endpoint = aws.String(s3Config.Endpoint) + // The administrator has configured an S3 endpoint override, + // e.g. to make use of S3 IPv6 support or S3 FIPS mode. We + // need to configure a custom resolver to make sure that + // the custom endpoint is only used for S3 API calls, and not + // for STS API calls. + s3CustomResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + if service == endpoints.S3ServiceID { + return endpoints.ResolvedEndpoint{ + URL: s3Config.Endpoint, + SigningRegion: region, + }, nil + } + + return endpoints.DefaultResolver().EndpointFor(service, region, optFns...) + } + + cfg.EndpointResolver = endpoints.ResolverFunc(s3CustomResolver) } sess, err := session.NewSession(cfg) diff --git a/workhorse/internal/upload/destination/objectstore/s3_session_test.go b/workhorse/internal/upload/destination/objectstore/s3_session_test.go index 5d57b4f9af8..4bbe38f90ec 100644 --- a/workhorse/internal/upload/destination/objectstore/s3_session_test.go +++ b/workhorse/internal/upload/destination/objectstore/s3_session_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" @@ -17,7 +18,9 @@ func TestS3SessionSetup(t *testing.T) { sess, err := setupS3Session(credentials, cfg) require.NoError(t, err) - require.Equal(t, aws.StringValue(sess.Config.Region), "us-west-1") + s3Config := sess.ClientConfig(endpoints.S3ServiceID) + require.Equal(t, "https://s3.us-west-1.amazonaws.com", s3Config.Endpoint) + require.Equal(t, "us-west-1", s3Config.SigningRegion) require.True(t, aws.BoolValue(sess.Config.S3ForcePathStyle)) require.Equal(t, len(sessionCache.sessions), 1) @@ -29,6 +32,26 @@ func TestS3SessionSetup(t *testing.T) { ResetS3Session(cfg) } +func TestS3SessionEndpointSetup(t *testing.T) { + credentials := config.S3Credentials{} + const customS3Endpoint = "https://example.com" + const region = "us-west-2" + cfg := config.S3Config{Region: region, PathStyle: true, Endpoint: customS3Endpoint} + + sess, err := setupS3Session(credentials, cfg) + require.NoError(t, err) + + // ClientConfig is what is ultimately used by an S3 client + s3Config := sess.ClientConfig(endpoints.S3ServiceID) + require.Equal(t, customS3Endpoint, s3Config.Endpoint) + require.Equal(t, region, s3Config.SigningRegion) + + stsConfig := sess.ClientConfig(endpoints.StsServiceID) + require.Equal(t, "https://sts.amazonaws.com", stsConfig.Endpoint, "STS should use default endpoint") + + ResetS3Session(cfg) +} + func TestS3SessionExpiry(t *testing.T) { credentials := config.S3Credentials{} cfg := config.S3Config{Region: "us-west-1", PathStyle: true} diff --git a/workhorse/internal/upload/destination/upload_opts.go b/workhorse/internal/upload/destination/upload_opts.go index 750a79d7bc2..77a8927d34f 100644 --- a/workhorse/internal/upload/destination/upload_opts.go +++ b/workhorse/internal/upload/destination/upload_opts.go @@ -70,8 +70,8 @@ func (s *UploadOpts) UseWorkhorseClientEnabled() bool { return s.UseWorkhorseClient && s.ObjectStorageConfig.IsValid() && s.RemoteTempObjectID != "" } -// IsLocal checks if the options require the writing of the file on disk -func (s *UploadOpts) IsLocal() bool { +// IsLocalTempFile checks if the options require the writing of a temporary file on disk +func (s *UploadOpts) IsLocalTempFile() bool { return s.LocalTempPath != "" } diff --git a/workhorse/internal/upload/destination/upload_opts_test.go b/workhorse/internal/upload/destination/upload_opts_test.go index fde726c985d..24a372495c6 100644 --- a/workhorse/internal/upload/destination/upload_opts_test.go +++ b/workhorse/internal/upload/destination/upload_opts_test.go @@ -49,7 +49,7 @@ func TestUploadOptsLocalAndRemote(t *testing.T) { PartSize: test.partSize, } - require.Equal(t, test.isLocal, opts.IsLocal(), "IsLocal() mismatch") + require.Equal(t, test.isLocal, opts.IsLocalTempFile(), "IsLocalTempFile() mismatch") require.Equal(t, test.isMultipart, opts.IsMultipart(), "IsMultipart() mismatch") }) } @@ -336,7 +336,7 @@ func TestGoCloudConfig(t *testing.T) { require.Equal(t, apiResponse.RemoteObject.ObjectStorage.GoCloudConfig, opts.ObjectStorageConfig.GoCloudConfig) require.True(t, opts.UseWorkhorseClientEnabled()) require.Equal(t, test.valid, opts.ObjectStorageConfig.IsValid()) - require.False(t, opts.IsLocal()) + require.False(t, opts.IsLocalTempFile()) }) } } diff --git a/workhorse/internal/upstream/upstream.go b/workhorse/internal/upstream/upstream.go index c0678b1cb3e..6d107fc28cd 100644 --- a/workhorse/internal/upstream/upstream.go +++ b/workhorse/internal/upstream/upstream.go @@ -37,7 +37,6 @@ var ( upload.RewrittenFieldsHeader, } geoProxyApiPollingInterval = 10 * time.Second - geoProxyWorkhorseHeaders = map[string]string{"Gitlab-Workhorse-Geo-Proxy": "1"} ) type upstream struct { @@ -48,6 +47,7 @@ type upstream struct { CableRoundTripper http.RoundTripper APIClient *apipkg.API geoProxyBackend *url.URL + geoProxyExtraData string geoLocalRoutes []routeEntry geoProxyCableRoute routeEntry geoProxyRoute routeEntry @@ -215,34 +215,51 @@ func (u *upstream) pollGeoProxyAPI() { // Calls /api/v4/geo/proxy and sets up routes func (u *upstream) callGeoProxyAPI() { - geoProxyURL, err := u.APIClient.GetGeoProxyURL() + geoProxyData, err := u.APIClient.GetGeoProxyData() if err != nil { log.WithError(err).WithFields(log.Fields{"geoProxyBackend": u.geoProxyBackend}).Error("Geo Proxy: Unable to determine Geo Proxy URL. Fallback on cached value.") return } - if u.geoProxyBackend.String() != geoProxyURL.String() { - log.WithFields(log.Fields{"oldGeoProxyURL": u.geoProxyBackend, "newGeoProxyURL": geoProxyURL}).Info("Geo Proxy: URL changed") - u.updateGeoProxyFields(geoProxyURL) + hasProxyDataChanged := false + if u.geoProxyBackend.String() != geoProxyData.GeoProxyURL.String() { + log.WithFields(log.Fields{"oldGeoProxyURL": u.geoProxyBackend, "newGeoProxyURL": geoProxyData.GeoProxyURL}).Info("Geo Proxy: URL changed") + hasProxyDataChanged = true + } + + if u.geoProxyExtraData != geoProxyData.GeoProxyExtraData { + // extra data is usually a JWT, thus not explicitly logging it + log.Info("Geo Proxy: signed data changed") + hasProxyDataChanged = true + } + + if hasProxyDataChanged { + u.updateGeoProxyFieldsFromData(geoProxyData) } } -func (u *upstream) updateGeoProxyFields(geoProxyURL *url.URL) { +func (u *upstream) updateGeoProxyFieldsFromData(geoProxyData *apipkg.GeoProxyData) { u.mu.Lock() defer u.mu.Unlock() - u.geoProxyBackend = geoProxyURL + u.geoProxyBackend = geoProxyData.GeoProxyURL + u.geoProxyExtraData = geoProxyData.GeoProxyExtraData if u.geoProxyBackend.String() == "" { return } + geoProxyWorkhorseHeaders := map[string]string{ + "Gitlab-Workhorse-Geo-Proxy": "1", + "Gitlab-Workhorse-Geo-Proxy-Extra-Data": u.geoProxyExtraData, + } geoProxyRoundTripper := roundtripper.NewBackendRoundTripper(u.geoProxyBackend, "", u.ProxyHeadersTimeout, u.DevelopmentMode) geoProxyUpstream := proxypkg.NewProxy( u.geoProxyBackend, u.Version, geoProxyRoundTripper, proxypkg.WithCustomHeaders(geoProxyWorkhorseHeaders), + proxypkg.WithForcedTargetHostHeader(), ) u.geoProxyCableRoute = u.wsRoute(`^/-/cable\z`, geoProxyUpstream) u.geoProxyRoute = u.route("", "", geoProxyUpstream, withGeoProxy()) diff --git a/workhorse/internal/upstream/upstream_test.go b/workhorse/internal/upstream/upstream_test.go index 80e59202b69..8f054f5ccef 100644 --- a/workhorse/internal/upstream/upstream_test.go +++ b/workhorse/internal/upstream/upstream_test.go @@ -209,21 +209,74 @@ func TestGeoProxyFeatureEnablingAndDisabling(t *testing.T) { runTestCases(t, ws, testCasesProxied) } -func TestGeoProxySetsCustomHeader(t *testing.T) { +func TestGeoProxyUpdatesExtraDataWhenChanged(t *testing.T) { + var expectedGeoProxyExtraData string + remoteServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, "1", r.Header.Get("Gitlab-Workhorse-Geo-Proxy"), "custom proxy header") + require.Equal(t, expectedGeoProxyExtraData, r.Header.Get("Gitlab-Workhorse-Geo-Proxy-Extra-Data"), "custom extra data header") w.WriteHeader(http.StatusOK) })) defer remoteServer.Close() - geoProxyEndpointResponseBody := fmt.Sprintf(`{"geo_proxy_url":"%v"}`, remoteServer.URL) + geoProxyEndpointExtraData1 := fmt.Sprintf(`{"geo_proxy_url":"%v","geo_proxy_extra_data":"data1"}`, remoteServer.URL) + geoProxyEndpointExtraData2 := fmt.Sprintf(`{"geo_proxy_url":"%v","geo_proxy_extra_data":"data2"}`, remoteServer.URL) + geoProxyEndpointExtraData3 := fmt.Sprintf(`{"geo_proxy_url":"%v"}`, remoteServer.URL) + geoProxyEndpointResponseBody := geoProxyEndpointExtraData1 + expectedGeoProxyExtraData = "data1" + railsServer, deferredClose := startRailsServer("Local Rails server", &geoProxyEndpointResponseBody) defer deferredClose() - ws, wsDeferredClose, _ := startWorkhorseServer(railsServer.URL, true) + ws, wsDeferredClose, waitForNextApiPoll := startWorkhorseServer(railsServer.URL, true) defer wsDeferredClose() http.Get(ws.URL) + + // Verify that the expected header changes after next updated poll. + geoProxyEndpointResponseBody = geoProxyEndpointExtraData2 + expectedGeoProxyExtraData = "data2" + waitForNextApiPoll() + + http.Get(ws.URL) + + // Validate that non-existing extra data results in empty header + geoProxyEndpointResponseBody = geoProxyEndpointExtraData3 + expectedGeoProxyExtraData = "" + waitForNextApiPoll() + + http.Get(ws.URL) +} + +func TestGeoProxySetsCustomHeader(t *testing.T) { + testCases := []struct { + desc string + json string + extraData string + }{ + {"no extra data", `{"geo_proxy_url":"%v"}`, ""}, + {"with extra data", `{"geo_proxy_url":"%v","geo_proxy_extra_data":"extra-geo-data"}`, "extra-geo-data"}, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + remoteServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "1", r.Header.Get("Gitlab-Workhorse-Geo-Proxy"), "custom proxy header") + require.Equal(t, tc.extraData, r.Header.Get("Gitlab-Workhorse-Geo-Proxy-Extra-Data"), "custom proxy extra data header") + w.WriteHeader(http.StatusOK) + })) + defer remoteServer.Close() + + geoProxyEndpointResponseBody := fmt.Sprintf(tc.json, remoteServer.URL) + railsServer, deferredClose := startRailsServer("Local Rails server", &geoProxyEndpointResponseBody) + defer deferredClose() + + ws, wsDeferredClose, _ := startWorkhorseServer(railsServer.URL, true) + defer wsDeferredClose() + + http.Get(ws.URL) + }) + } } func runTestCases(t *testing.T, ws *httptest.Server, testCases []testCase) { |