Diffstat (limited to 'workhorse/internal')
188 files changed, 19985 insertions, 0 deletions
diff --git a/workhorse/internal/api/api.go b/workhorse/internal/api/api.go new file mode 100644 index 00000000000..17fea398029 --- /dev/null +++ b/workhorse/internal/api/api.go @@ -0,0 +1,345 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" +) + +const ( + // Custom content type for API responses, to catch routing / programming mistakes + ResponseContentType = "application/vnd.gitlab-workhorse+json" + + failureResponseLimit = 32768 +) + +type API struct { + Client *http.Client + URL *url.URL + Version string +} + +var ( + requestsCounter = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_internal_api_requests", + Help: "How many internal API requests have been completed by gitlab-workhorse, partitioned by status code and HTTP method.", + }, + []string{"code", "method"}, + ) + bytesTotal = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_internal_api_failure_response_bytes", + Help: "How many bytes have been returned by upstream GitLab in API failure/rejection response bodies.", + }, + ) +) + +func NewAPI(myURL *url.URL, version string, roundTripper http.RoundTripper) *API { + return &API{ + Client: &http.Client{Transport: roundTripper}, + URL: myURL, + Version: version, + } +} + +type HandleFunc func(http.ResponseWriter, *http.Request, *Response) + +type MultipartUploadParams struct { + // PartSize is the exact size of each uploaded part. Only the last one can be smaller + PartSize int64 + // PartURLs contains the presigned URLs for each part + PartURLs []string + // CompleteURL is a presigned URL for CompleteMulipartUpload + CompleteURL string + // AbortURL is a presigned URL for AbortMultipartUpload + AbortURL string +} + +type ObjectStorageParams struct { + Provider string + S3Config config.S3Config + GoCloudConfig config.GoCloudConfig +} + +type RemoteObject struct { + // GetURL is an S3 GetObject URL + GetURL string + // DeleteURL is a presigned S3 RemoveObject URL + DeleteURL string + // StoreURL is the temporary presigned S3 PutObject URL to which upload the first found file + StoreURL string + // Boolean to indicate whether to use headers included in PutHeaders + CustomPutHeaders bool + // PutHeaders are HTTP headers (e.g. 
Content-Type) to be sent with StoreURL + PutHeaders map[string]string + // Whether to ignore Rails pre-signed URLs and have Workhorse directly access object storage provider + UseWorkhorseClient bool + // Remote, temporary object name where Rails will move to the final destination + RemoteTempObjectID string + // ID is a unique identifier of object storage upload + ID string + // Timeout is a number that represents timeout in seconds for sending data to StoreURL + Timeout int + // MultipartUpload contains presigned URLs for S3 MultipartUpload + MultipartUpload *MultipartUploadParams + // Object storage config for Workhorse client + ObjectStorage *ObjectStorageParams +} + +type Response struct { + // GL_ID is an environment variable used by gitlab-shell hooks during 'git + // push' and 'git pull' + GL_ID string + + // GL_USERNAME holds gitlab username of the user who is taking the action causing hooks to be invoked + GL_USERNAME string + + // GL_REPOSITORY is an environment variable used by gitlab-shell hooks during + // 'git push' and 'git pull' + GL_REPOSITORY string + // GitConfigOptions holds the custom options that we want to pass to the git command + GitConfigOptions []string + // StoreLFSPath is provided by the GitLab Rails application to mark where the tmp file should be placed. + // This field is deprecated. GitLab will use TempPath instead + StoreLFSPath string + // LFS object id + LfsOid string + // LFS object size + LfsSize int64 + // TmpPath is the path where we should store temporary files + // This is set by authorization middleware + TempPath string + // RemoteObject is provided by the GitLab Rails application + // and defines a way to store object on remote storage + RemoteObject RemoteObject + // Archive is the path where the artifacts archive is stored + Archive string `json:"archive"` + // Entry is a filename inside the archive point to file that needs to be extracted + Entry string `json:"entry"` + // Used to communicate channel session details + Channel *ChannelSettings + // GitalyServer specifies an address and authentication token for a gitaly server we should connect to. + GitalyServer gitaly.Server + // Repository object for making gRPC requests to Gitaly. + Repository gitalypb.Repository + // For git-http, does the requestor have the right to view all refs? 
+ ShowAllRefs bool + // Detects whether an artifact is used for code intelligence + ProcessLsif bool + // Detects whether LSIF artifact will be parsed with references + ProcessLsifReferences bool + // The maximum accepted size in bytes of the upload + MaximumSize int64 +} + +// singleJoiningSlash is taken from reverseproxy.go:NewSingleHostReverseProxy +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// rebaseUrl is taken from reverseproxy.go:NewSingleHostReverseProxy +func rebaseUrl(url *url.URL, onto *url.URL, suffix string) *url.URL { + newUrl := *url + newUrl.Scheme = onto.Scheme + newUrl.Host = onto.Host + if suffix != "" { + newUrl.Path = singleJoiningSlash(url.Path, suffix) + } + if onto.RawQuery == "" || newUrl.RawQuery == "" { + newUrl.RawQuery = onto.RawQuery + newUrl.RawQuery + } else { + newUrl.RawQuery = onto.RawQuery + "&" + newUrl.RawQuery + } + return &newUrl +} + +func (api *API) newRequest(r *http.Request, suffix string) (*http.Request, error) { + authReq := &http.Request{ + Method: r.Method, + URL: rebaseUrl(r.URL, api.URL, suffix), + Header: helper.HeaderClone(r.Header), + } + + authReq = authReq.WithContext(r.Context()) + + // Clean some headers when issuing a new request without body + authReq.Header.Del("Content-Type") + authReq.Header.Del("Content-Encoding") + authReq.Header.Del("Content-Length") + authReq.Header.Del("Content-Disposition") + authReq.Header.Del("Accept-Encoding") + + // Hop-by-hop headers. These are removed when sent to the backend. + // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html + authReq.Header.Del("Transfer-Encoding") + authReq.Header.Del("Connection") + authReq.Header.Del("Keep-Alive") + authReq.Header.Del("Proxy-Authenticate") + authReq.Header.Del("Proxy-Authorization") + authReq.Header.Del("Te") + authReq.Header.Del("Trailers") + authReq.Header.Del("Upgrade") + + // Also forward the Host header, which is excluded from the Header map by the http library. + // This allows the Host header received by the backend to be consistent with other + // requests not going through gitlab-workhorse. + authReq.Host = r.Host + + return authReq, nil +} + +// PreAuthorize performs a pre-authorization check against the API for the given HTTP request +// +// If `outErr` is set, the other fields will be nil and it should be treated as +// a 500 error. +// +// If httpResponse is present, the caller is responsible for closing its body +// +// authResponse will only be present if the authorization check was successful +func (api *API) PreAuthorize(suffix string, r *http.Request) (httpResponse *http.Response, authResponse *Response, outErr error) { + authReq, err := api.newRequest(r, suffix) + if err != nil { + return nil, nil, fmt.Errorf("preAuthorizeHandler newUpstreamRequest: %v", err) + } + + httpResponse, err = api.doRequestWithoutRedirects(authReq) + if err != nil { + return nil, nil, fmt.Errorf("preAuthorizeHandler: do request: %v", err) + } + defer func() { + if outErr != nil { + httpResponse.Body.Close() + httpResponse = nil + } + }() + requestsCounter.WithLabelValues(strconv.Itoa(httpResponse.StatusCode), authReq.Method).Inc() + + // This may be a false positive, e.g. 
for .../info/refs, rather than a + // failure, so pass the response back + if httpResponse.StatusCode != http.StatusOK || !validResponseContentType(httpResponse) { + return httpResponse, nil, nil + } + + authResponse = &Response{} + // The auth backend validated the client request and told us additional + // request metadata. We must extract this information from the auth + // response body. + if err := json.NewDecoder(httpResponse.Body).Decode(authResponse); err != nil { + return httpResponse, nil, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err) + } + + return httpResponse, authResponse, nil +} + +func (api *API) PreAuthorizeHandler(next HandleFunc, suffix string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + httpResponse, authResponse, err := api.PreAuthorize(suffix, r) + if httpResponse != nil { + defer httpResponse.Body.Close() + } + + if err != nil { + helper.Fail500(w, r, err) + return + } + + // The response couldn't be interpreted as a valid auth response, so + // pass it back (mostly) unmodified + if httpResponse != nil && authResponse == nil { + passResponseBack(httpResponse, w, r) + return + } + + httpResponse.Body.Close() // Free up the Unicorn worker + + copyAuthHeader(httpResponse, w) + + next(w, r, authResponse) + }) +} + +func (api *API) doRequestWithoutRedirects(authReq *http.Request) (*http.Response, error) { + signingTripper := secret.NewRoundTripper(api.Client.Transport, api.Version) + + return signingTripper.RoundTrip(authReq) +} + +func copyAuthHeader(httpResponse *http.Response, w http.ResponseWriter) { + // Negotiate authentication (Kerberos) may need to return a WWW-Authenticate + // header to the client even in case of success as per RFC4559. + for k, v := range httpResponse.Header { + // Case-insensitive comparison as per RFC7230 + if strings.EqualFold(k, "WWW-Authenticate") { + w.Header()[k] = v + } + } +} + +func passResponseBack(httpResponse *http.Response, w http.ResponseWriter, r *http.Request) { + // NGINX response buffering is disabled on this path (with + // X-Accel-Buffering: no) but we still want to free up the Unicorn worker + // that generated httpResponse as fast as possible. To do this we buffer + // the entire response body in memory before sending it on. 
+ responseBody, err := bufferResponse(httpResponse.Body) + if err != nil { + helper.Fail500(w, r, err) + return + } + httpResponse.Body.Close() // Free up the Unicorn worker + bytesTotal.Add(float64(responseBody.Len())) + + for k, v := range httpResponse.Header { + // Accommodate broken clients that do case-sensitive header lookup + if k == "Www-Authenticate" { + w.Header()["WWW-Authenticate"] = v + } else { + w.Header()[k] = v + } + } + w.WriteHeader(httpResponse.StatusCode) + if _, err := io.Copy(w, responseBody); err != nil { + helper.LogError(r, err) + } +} + +func bufferResponse(r io.Reader) (*bytes.Buffer, error) { + responseBody := &bytes.Buffer{} + n, err := io.Copy(responseBody, io.LimitReader(r, failureResponseLimit)) + if err != nil { + return nil, err + } + + if n == failureResponseLimit { + return nil, fmt.Errorf("response body exceeded maximum buffer size (%d bytes)", failureResponseLimit) + } + + return responseBody, nil +} + +func validResponseContentType(resp *http.Response) bool { + return helper.IsContentType(ResponseContentType, resp.Header.Get("Content-Type")) +} diff --git a/workhorse/internal/api/block.go b/workhorse/internal/api/block.go new file mode 100644 index 00000000000..92322906c03 --- /dev/null +++ b/workhorse/internal/api/block.go @@ -0,0 +1,61 @@ +package api + +import ( + "fmt" + "net/http" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +// Prevent internal API responses intended for gitlab-workhorse from +// leaking to the end user +func Block(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rw := &blocker{rw: w, r: r} + defer rw.flush() + h.ServeHTTP(rw, r) + }) +} + +type blocker struct { + rw http.ResponseWriter + r *http.Request + hijacked bool + status int +} + +func (b *blocker) Header() http.Header { + return b.rw.Header() +} + +func (b *blocker) Write(data []byte) (int, error) { + if b.status == 0 { + b.WriteHeader(http.StatusOK) + } + if b.hijacked { + return len(data), nil + } + + return b.rw.Write(data) +} + +func (b *blocker) WriteHeader(status int) { + if b.status != 0 { + return + } + + if helper.IsContentType(ResponseContentType, b.Header().Get("Content-Type")) { + b.status = 500 + b.Header().Del("Content-Length") + b.hijacked = true + helper.Fail500(b.rw, b.r, fmt.Errorf("api.blocker: forbidden content-type: %q", ResponseContentType)) + return + } + + b.status = status + b.rw.WriteHeader(b.status) +} + +func (b *blocker) flush() { + b.WriteHeader(http.StatusOK) +} diff --git a/workhorse/internal/api/block_test.go b/workhorse/internal/api/block_test.go new file mode 100644 index 00000000000..85ad54f3cfd --- /dev/null +++ b/workhorse/internal/api/block_test.go @@ -0,0 +1,56 @@ +package api + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBlocker(t *testing.T) { + upstreamResponse := "hello world" + + testCases := []struct { + desc string + contentType string + out string + }{ + { + desc: "blocked", + contentType: ResponseContentType, + out: "Internal server error\n", + }, + { + desc: "pass", + contentType: "text/plain", + out: upstreamResponse, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + r, err := http.NewRequest("GET", "/foo", nil) + require.NoError(t, err) + + rw := httptest.NewRecorder() + bl := &blocker{rw: rw, r: r} + bl.Header().Set("Content-Type", tc.contentType) + + upstreamBody := []byte(upstreamResponse) + n, err := 
bl.Write(upstreamBody) + require.NoError(t, err) + require.Equal(t, len(upstreamBody), n, "bytes written") + + rw.Flush() + + body := rw.Result().Body + data, err := ioutil.ReadAll(body) + require.NoError(t, err) + require.NoError(t, body.Close()) + + require.Equal(t, tc.out, string(data)) + }) + } +} diff --git a/workhorse/internal/api/channel_settings.go b/workhorse/internal/api/channel_settings.go new file mode 100644 index 00000000000..bf3094c9c91 --- /dev/null +++ b/workhorse/internal/api/channel_settings.go @@ -0,0 +1,122 @@ +package api + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "net/url" + + "github.com/gorilla/websocket" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +type ChannelSettings struct { + // The channel provider may require use of a particular subprotocol. If so, + // it must be specified here, and Workhorse must have a matching codec. + Subprotocols []string + + // The websocket URL to connect to. + Url string + + // Any headers (e.g., Authorization) to send with the websocket request + Header http.Header + + // The CA roots to validate the remote endpoint with, for wss:// URLs. The + // system-provided CA pool will be used if this is blank. PEM-encoded data. + CAPem string + + // The value is specified in seconds. It is converted to time.Duration + // later. + MaxSessionTime int +} + +func (t *ChannelSettings) URL() (*url.URL, error) { + return url.Parse(t.Url) +} + +func (t *ChannelSettings) Dialer() *websocket.Dialer { + dialer := &websocket.Dialer{ + Subprotocols: t.Subprotocols, + } + + if len(t.CAPem) > 0 { + pool := x509.NewCertPool() + pool.AppendCertsFromPEM([]byte(t.CAPem)) + dialer.TLSClientConfig = &tls.Config{RootCAs: pool} + } + + return dialer +} + +func (t *ChannelSettings) Clone() *ChannelSettings { + // Doesn't clone the strings, but that's OK as strings are immutable in go + cloned := *t + cloned.Header = helper.HeaderClone(t.Header) + return &cloned +} + +func (t *ChannelSettings) Dial() (*websocket.Conn, *http.Response, error) { + return t.Dialer().Dial(t.Url, t.Header) +} + +func (t *ChannelSettings) Validate() error { + if t == nil { + return fmt.Errorf("channel details not specified") + } + + if len(t.Subprotocols) == 0 { + return fmt.Errorf("no subprotocol specified") + } + + parsedURL, err := t.URL() + if err != nil { + return fmt.Errorf("invalid URL") + } + + if parsedURL.Scheme != "ws" && parsedURL.Scheme != "wss" { + return fmt.Errorf("invalid websocket scheme: %q", parsedURL.Scheme) + } + + return nil +} + +func (t *ChannelSettings) IsEqual(other *ChannelSettings) bool { + if t == nil && other == nil { + return true + } + + if t == nil || other == nil { + return false + } + + if len(t.Subprotocols) != len(other.Subprotocols) { + return false + } + + for i, subprotocol := range t.Subprotocols { + if other.Subprotocols[i] != subprotocol { + return false + } + } + + if len(t.Header) != len(other.Header) { + return false + } + + for header, values := range t.Header { + if len(values) != len(other.Header[header]) { + return false + } + for i, value := range values { + if other.Header[header][i] != value { + return false + } + } + } + + return t.Url == other.Url && + t.CAPem == other.CAPem && + t.MaxSessionTime == other.MaxSessionTime +} diff --git a/workhorse/internal/api/channel_settings_test.go b/workhorse/internal/api/channel_settings_test.go new file mode 100644 index 00000000000..4aa2c835579 --- /dev/null +++ b/workhorse/internal/api/channel_settings_test.go @@ -0,0 +1,154 @@ +package api + +import 
( + "net/http" + "testing" +) + +func channel(url string, subprotocols ...string) *ChannelSettings { + return &ChannelSettings{ + Url: url, + Subprotocols: subprotocols, + MaxSessionTime: 0, + } +} + +func ca(channel *ChannelSettings) *ChannelSettings { + channel = channel.Clone() + channel.CAPem = "Valid CA data" + + return channel +} + +func timeout(channel *ChannelSettings) *ChannelSettings { + channel = channel.Clone() + channel.MaxSessionTime = 600 + + return channel +} + +func header(channel *ChannelSettings, values ...string) *ChannelSettings { + if len(values) == 0 { + values = []string{"Dummy Value"} + } + + channel = channel.Clone() + channel.Header = http.Header{ + "Header": values, + } + + return channel +} + +func TestClone(t *testing.T) { + a := ca(header(channel("ws:", "", ""))) + b := a.Clone() + + if a == b { + t.Fatalf("Address of cloned channel didn't change") + } + + if &a.Subprotocols == &b.Subprotocols { + t.Fatalf("Address of cloned subprotocols didn't change") + } + + if &a.Header == &b.Header { + t.Fatalf("Address of cloned header didn't change") + } +} + +func TestValidate(t *testing.T) { + for i, tc := range []struct { + channel *ChannelSettings + valid bool + msg string + }{ + {nil, false, "nil channel"}, + {channel("", ""), false, "empty URL"}, + {channel("ws:"), false, "empty subprotocols"}, + {channel("ws:", "foo"), true, "any subprotocol"}, + {channel("ws:", "foo", "bar"), true, "multiple subprotocols"}, + {channel("ws:", ""), true, "websocket URL"}, + {channel("wss:", ""), true, "secure websocket URL"}, + {channel("http:", ""), false, "HTTP URL"}, + {channel("https:", ""), false, " HTTPS URL"}, + {ca(channel("ws:", "")), true, "any CA pem"}, + {header(channel("ws:", "")), true, "any headers"}, + {ca(header(channel("ws:", ""))), true, "PEM and headers"}, + } { + if err := tc.channel.Validate(); (err != nil) == tc.valid { + t.Fatalf("test case %d: "+tc.msg+": valid=%v: %s: %+v", i, tc.valid, err, tc.channel) + } + } +} + +func TestDialer(t *testing.T) { + channel := channel("ws:", "foo") + dialer := channel.Dialer() + + if len(dialer.Subprotocols) != len(channel.Subprotocols) { + t.Fatalf("Subprotocols don't match: %+v vs. %+v", channel.Subprotocols, dialer.Subprotocols) + } + + for i, subprotocol := range channel.Subprotocols { + if dialer.Subprotocols[i] != subprotocol { + t.Fatalf("Subprotocols don't match: %+v vs. 
%+v", channel.Subprotocols, dialer.Subprotocols) + } + } + + if dialer.TLSClientConfig != nil { + t.Fatalf("Unexpected TLSClientConfig: %+v", dialer) + } + + channel = ca(channel) + dialer = channel.Dialer() + + if dialer.TLSClientConfig == nil || dialer.TLSClientConfig.RootCAs == nil { + t.Fatalf("Custom CA certificates not recognised!") + } +} + +func TestIsEqual(t *testing.T) { + chann := channel("ws:", "foo") + + chann_header2 := header(chann, "extra") + chann_header3 := header(chann) + chann_header3.Header.Add("Extra", "extra") + + chann_ca2 := ca(chann) + chann_ca2.CAPem = "other value" + + for i, tc := range []struct { + channelA *ChannelSettings + channelB *ChannelSettings + expected bool + }{ + {nil, nil, true}, + {chann, nil, false}, + {nil, chann, false}, + {chann, chann, true}, + {chann.Clone(), chann.Clone(), true}, + {chann, channel("foo:"), false}, + {chann, channel(chann.Url), false}, + {header(chann), header(chann), true}, + {chann_header2, chann_header2, true}, + {chann_header3, chann_header3, true}, + {header(chann), chann_header2, false}, + {header(chann), chann_header3, false}, + {header(chann), chann, false}, + {chann, header(chann), false}, + {ca(chann), ca(chann), true}, + {ca(chann), chann, false}, + {chann, ca(chann), false}, + {ca(header(chann)), ca(header(chann)), true}, + {chann_ca2, ca(chann), false}, + {chann, timeout(chann), false}, + } { + if actual := tc.channelA.IsEqual(tc.channelB); tc.expected != actual { + t.Fatalf( + "test case %d: Comparison:\n-%+v\n+%+v\nexpected=%v: actual=%v", + i, tc.channelA, tc.channelB, tc.expected, actual, + ) + } + } +} diff --git a/workhorse/internal/artifacts/artifacts_store_test.go b/workhorse/internal/artifacts/artifacts_store_test.go new file mode 100644 index 00000000000..bd56d9ea725 --- /dev/null +++ b/workhorse/internal/artifacts/artifacts_store_test.go @@ -0,0 +1,338 @@ +package artifacts + +import ( + "archive/zip" + "bytes" + "crypto/md5" + "encoding/hex" + "fmt" + "io/ioutil" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" +) + +func createTestZipArchive(t *testing.T) (data []byte, md5Hash string) { + var buffer bytes.Buffer + archive := zip.NewWriter(&buffer) + fileInArchive, err := archive.Create("test.file") + require.NoError(t, err) + fmt.Fprint(fileInArchive, "test") + archive.Close() + data = buffer.Bytes() + + hasher := md5.New() + hasher.Write(data) + hexHash := hasher.Sum(nil) + md5Hash = hex.EncodeToString(hexHash) + + return data, md5Hash +} + +func createTestMultipartForm(t *testing.T, data []byte) (bytes.Buffer, string) { + var buffer bytes.Buffer + writer := multipart.NewWriter(&buffer) + file, err := writer.CreateFormFile("file", "my.file") + require.NoError(t, err) + file.Write(data) + writer.Close() + return buffer, writer.FormDataContentType() +} + +func testUploadArtifactsFromTestZip(t *testing.T, ts *httptest.Server) *httptest.ResponseRecorder { + archiveData, _ := createTestZipArchive(t) + contentBuffer, contentType := createTestMultipartForm(t, archiveData) + + return testUploadArtifacts(t, contentType, ts.URL+Path, &contentBuffer) +} + +func TestUploadHandlerSendingToExternalStorage(t *testing.T) { + tempPath, err := ioutil.TempDir("", "uploads") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempPath) + + archiveData, 
md5 := createTestZipArchive(t) + archiveFile, err := ioutil.TempFile("", "artifact.zip") + require.NoError(t, err) + defer os.Remove(archiveFile.Name()) + _, err = archiveFile.Write(archiveData) + require.NoError(t, err) + archiveFile.Close() + + storeServerCalled := 0 + storeServerMux := http.NewServeMux() + storeServerMux.HandleFunc("/url/put", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "PUT", r.Method) + + receivedData, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + require.Equal(t, archiveData, receivedData) + + storeServerCalled++ + w.Header().Set("ETag", md5) + w.WriteHeader(200) + }) + storeServerMux.HandleFunc("/store-id", func(w http.ResponseWriter, r *http.Request) { + http.ServeFile(w, r, archiveFile.Name()) + }) + + responseProcessorCalled := 0 + responseProcessor := func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "store-id", r.FormValue("file.remote_id")) + require.NotEmpty(t, r.FormValue("file.remote_url")) + w.WriteHeader(200) + responseProcessorCalled++ + } + + storeServer := httptest.NewServer(storeServerMux) + defer storeServer.Close() + + qs := fmt.Sprintf("?%s=%s", ArtifactFormatKey, ArtifactFormatZip) + + tests := []struct { + name string + preauth api.Response + }{ + { + name: "ObjectStore Upload", + preauth: api.Response{ + RemoteObject: api.RemoteObject{ + StoreURL: storeServer.URL + "/url/put" + qs, + ID: "store-id", + GetURL: storeServer.URL + "/store-id", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + storeServerCalled = 0 + responseProcessorCalled = 0 + + ts := testArtifactsUploadServer(t, test.preauth, responseProcessor) + defer ts.Close() + + contentBuffer, contentType := createTestMultipartForm(t, archiveData) + response := testUploadArtifacts(t, contentType, ts.URL+Path+qs, &contentBuffer) + require.Equal(t, http.StatusOK, response.Code) + testhelper.RequireResponseHeader(t, response, MetadataHeaderKey, MetadataHeaderPresent) + require.Equal(t, 1, storeServerCalled, "store should be called only once") + require.Equal(t, 1, responseProcessorCalled, "response processor should be called only once") + }) + } +} + +func TestUploadHandlerSendingToExternalStorageAndStorageServerUnreachable(t *testing.T) { + tempPath, err := ioutil.TempDir("", "uploads") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempPath) + + responseProcessor := func(w http.ResponseWriter, r *http.Request) { + t.Fatal("it should not be called") + } + + authResponse := api.Response{ + TempPath: tempPath, + RemoteObject: api.RemoteObject{ + StoreURL: "http://localhost:12323/invalid/url", + ID: "store-id", + }, + } + + ts := testArtifactsUploadServer(t, authResponse, responseProcessor) + defer ts.Close() + + response := testUploadArtifactsFromTestZip(t, ts) + require.Equal(t, http.StatusInternalServerError, response.Code) +} + +func TestUploadHandlerSendingToExternalStorageAndInvalidURLIsUsed(t *testing.T) { + tempPath, err := ioutil.TempDir("", "uploads") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempPath) + + responseProcessor := func(w http.ResponseWriter, r *http.Request) { + t.Fatal("it should not be called") + } + + authResponse := api.Response{ + TempPath: tempPath, + RemoteObject: api.RemoteObject{ + StoreURL: "htt:////invalid-url", + ID: "store-id", + }, + } + + ts := testArtifactsUploadServer(t, authResponse, responseProcessor) + defer ts.Close() + + response := testUploadArtifactsFromTestZip(t, ts) + require.Equal(t, http.StatusInternalServerError, 
response.Code) +} + +func TestUploadHandlerSendingToExternalStorageAndItReturnsAnError(t *testing.T) { + putCalledTimes := 0 + + storeServerMux := http.NewServeMux() + storeServerMux.HandleFunc("/url/put", func(w http.ResponseWriter, r *http.Request) { + putCalledTimes++ + require.Equal(t, "PUT", r.Method) + w.WriteHeader(510) + }) + + responseProcessor := func(w http.ResponseWriter, r *http.Request) { + t.Fatal("it should not be called") + } + + storeServer := httptest.NewServer(storeServerMux) + defer storeServer.Close() + + authResponse := api.Response{ + RemoteObject: api.RemoteObject{ + StoreURL: storeServer.URL + "/url/put", + ID: "store-id", + }, + } + + ts := testArtifactsUploadServer(t, authResponse, responseProcessor) + defer ts.Close() + + response := testUploadArtifactsFromTestZip(t, ts) + require.Equal(t, http.StatusInternalServerError, response.Code) + require.Equal(t, 1, putCalledTimes, "upload should be called only once") +} + +func TestUploadHandlerSendingToExternalStorageAndSupportRequestTimeout(t *testing.T) { + putCalledTimes := 0 + + storeServerMux := http.NewServeMux() + storeServerMux.HandleFunc("/url/put", func(w http.ResponseWriter, r *http.Request) { + putCalledTimes++ + require.Equal(t, "PUT", r.Method) + time.Sleep(10 * time.Second) + w.WriteHeader(510) + }) + + responseProcessor := func(w http.ResponseWriter, r *http.Request) { + t.Fatal("it should not be called") + } + + storeServer := httptest.NewServer(storeServerMux) + defer storeServer.Close() + + authResponse := api.Response{ + RemoteObject: api.RemoteObject{ + StoreURL: storeServer.URL + "/url/put", + ID: "store-id", + Timeout: 1, + }, + } + + ts := testArtifactsUploadServer(t, authResponse, responseProcessor) + defer ts.Close() + + response := testUploadArtifactsFromTestZip(t, ts) + require.Equal(t, http.StatusInternalServerError, response.Code) + require.Equal(t, 1, putCalledTimes, "upload should be called only once") +} + +func TestUploadHandlerMultipartUploadSizeLimit(t *testing.T) { + os, server := test.StartObjectStore() + defer server.Close() + + err := os.InitiateMultipartUpload(test.ObjectPath) + require.NoError(t, err) + + objectURL := server.URL + test.ObjectPath + + uploadSize := 10 + preauth := api.Response{ + RemoteObject: api.RemoteObject{ + ID: "store-id", + MultipartUpload: &api.MultipartUploadParams{ + PartSize: 1, + PartURLs: []string{objectURL + "?partNumber=1"}, + AbortURL: objectURL, // DELETE + CompleteURL: objectURL, // POST + }, + }, + } + + responseProcessor := func(w http.ResponseWriter, r *http.Request) { + t.Fatal("it should not be called") + } + + ts := testArtifactsUploadServer(t, preauth, responseProcessor) + defer ts.Close() + + contentBuffer, contentType := createTestMultipartForm(t, make([]byte, uploadSize)) + response := testUploadArtifacts(t, contentType, ts.URL+Path, &contentBuffer) + require.Equal(t, http.StatusRequestEntityTooLarge, response.Code) + + // Poll because AbortMultipartUpload is async + for i := 0; os.IsMultipartUpload(test.ObjectPath) && i < 100; i++ { + time.Sleep(10 * time.Millisecond) + } + require.False(t, os.IsMultipartUpload(test.ObjectPath), "MultipartUpload should not be in progress anymore") + require.Empty(t, os.GetObjectMD5(test.ObjectPath), "upload should have failed, so the object should not exists") +} + +func TestUploadHandlerMultipartUploadMaximumSizeFromApi(t *testing.T) { + os, server := test.StartObjectStore() + defer server.Close() + + err := os.InitiateMultipartUpload(test.ObjectPath) + require.NoError(t, err) + + objectURL := 
server.URL + test.ObjectPath + + uploadSize := int64(10) + maxSize := uploadSize - 1 + preauth := api.Response{ + MaximumSize: maxSize, + RemoteObject: api.RemoteObject{ + ID: "store-id", + MultipartUpload: &api.MultipartUploadParams{ + PartSize: uploadSize, + PartURLs: []string{objectURL + "?partNumber=1"}, + AbortURL: objectURL, // DELETE + CompleteURL: objectURL, // POST + }, + }, + } + + responseProcessor := func(w http.ResponseWriter, r *http.Request) { + t.Fatal("it should not be called") + } + + ts := testArtifactsUploadServer(t, preauth, responseProcessor) + defer ts.Close() + + contentBuffer, contentType := createTestMultipartForm(t, make([]byte, uploadSize)) + response := testUploadArtifacts(t, contentType, ts.URL+Path, &contentBuffer) + require.Equal(t, http.StatusRequestEntityTooLarge, response.Code) + + testhelper.Retry(t, 5*time.Second, func() error { + if os.GetObjectMD5(test.ObjectPath) == "" { + return nil + } + + return fmt.Errorf("file is still present") + }) +} diff --git a/workhorse/internal/artifacts/artifacts_test.go b/workhorse/internal/artifacts/artifacts_test.go new file mode 100644 index 00000000000..b9a42cc60c1 --- /dev/null +++ b/workhorse/internal/artifacts/artifacts_test.go @@ -0,0 +1,19 @@ +package artifacts + +import ( + "os" + "testing" + + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" +) + +func TestMain(m *testing.M) { + if err := testhelper.BuildExecutables(); err != nil { + log.WithError(err).Fatal() + } + + os.Exit(m.Run()) + +} diff --git a/workhorse/internal/artifacts/artifacts_upload.go b/workhorse/internal/artifacts/artifacts_upload.go new file mode 100644 index 00000000000..3d4b8bf0931 --- /dev/null +++ b/workhorse/internal/artifacts/artifacts_upload.go @@ -0,0 +1,167 @@ +package artifacts + +import ( + "context" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "os/exec" + "strings" + "syscall" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts" +) + +// Sent by the runner: https://gitlab.com/gitlab-org/gitlab-runner/blob/c24da19ecce8808d9d2950896f70c94f5ea1cc2e/network/gitlab.go#L580 +const ( + ArtifactFormatKey = "artifact_format" + ArtifactFormatZip = "zip" + ArtifactFormatDefault = "" +) + +var zipSubcommandsErrorsCounter = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_zip_subcommand_errors_total", + Help: "Errors comming from subcommands used for processing ZIP archives", + }, []string{"error"}) + +type artifactsUploadProcessor struct { + opts *filestore.SaveFileOpts + format string + + upload.SavedFileTracker +} + +func (a *artifactsUploadProcessor) generateMetadataFromZip(ctx context.Context, file *filestore.FileHandler) (*filestore.FileHandler, error) { + metaReader, metaWriter := io.Pipe() + defer metaWriter.Close() + + metaOpts := &filestore.SaveFileOpts{ + LocalTempPath: a.opts.LocalTempPath, + TempFilePrefix: "metadata.gz", + } + if metaOpts.LocalTempPath == "" { + metaOpts.LocalTempPath = os.TempDir() + } + + fileName := file.LocalPath + if fileName == "" { + fileName = file.RemoteURL + } + + zipMd := exec.CommandContext(ctx, "gitlab-zip-metadata", 
fileName) + zipMd.Stderr = log.ContextLogger(ctx).Writer() + zipMd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + zipMd.Stdout = metaWriter + + if err := zipMd.Start(); err != nil { + return nil, err + } + defer helper.CleanUpProcessGroup(zipMd) + + type saveResult struct { + error + *filestore.FileHandler + } + done := make(chan saveResult) + go func() { + var result saveResult + result.FileHandler, result.error = filestore.SaveFileFromReader(ctx, metaReader, -1, metaOpts) + + done <- result + }() + + if err := zipMd.Wait(); err != nil { + st, ok := helper.ExitStatus(err) + + if !ok { + return nil, err + } + + zipSubcommandsErrorsCounter.WithLabelValues(zipartifacts.ErrorLabelByCode(st)).Inc() + + if st == zipartifacts.CodeNotZip { + return nil, nil + } + + if st == zipartifacts.CodeLimitsReached { + return nil, zipartifacts.ErrBadMetadata + } + } + + metaWriter.Close() + result := <-done + return result.FileHandler, result.error +} + +func (a *artifactsUploadProcessor) ProcessFile(ctx context.Context, formName string, file *filestore.FileHandler, writer *multipart.Writer) error { + // ProcessFile for artifacts requires file form-data field name to eq `file` + + if formName != "file" { + return fmt.Errorf("invalid form field: %q", formName) + } + if a.Count() > 0 { + return fmt.Errorf("artifacts request contains more than one file") + } + a.Track(formName, file.LocalPath) + + select { + case <-ctx.Done(): + return fmt.Errorf("ProcessFile: context done") + default: + } + + if !strings.EqualFold(a.format, ArtifactFormatZip) && a.format != ArtifactFormatDefault { + return nil + } + + // TODO: can we rely on disk for shipping metadata? Not if we split workhorse and rails in 2 different PODs + metadata, err := a.generateMetadataFromZip(ctx, file) + if err != nil { + return err + } + + if metadata != nil { + fields, err := metadata.GitLabFinalizeFields("metadata") + if err != nil { + return fmt.Errorf("finalize metadata field error: %v", err) + } + + for k, v := range fields { + writer.WriteField(k, v) + } + + a.Track("metadata", metadata.LocalPath) + } + + return nil +} + +func (a *artifactsUploadProcessor) Name() string { + return "artifacts" +} + +func UploadArtifacts(myAPI *api.API, h http.Handler, p upload.Preparer) http.Handler { + return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) { + opts, _, err := p.Prepare(a) + if err != nil { + helper.Fail500(w, r, fmt.Errorf("UploadArtifacts: error preparing file storage options")) + return + } + + format := r.URL.Query().Get(ArtifactFormatKey) + + mg := &artifactsUploadProcessor{opts: opts, format: format, SavedFileTracker: upload.SavedFileTracker{Request: r}} + upload.HandleFileUploads(w, r, h, a, mg, opts) + }, "/authorize") +} diff --git a/workhorse/internal/artifacts/artifacts_upload_test.go b/workhorse/internal/artifacts/artifacts_upload_test.go new file mode 100644 index 00000000000..c82ae791239 --- /dev/null +++ b/workhorse/internal/artifacts/artifacts_upload_test.go @@ -0,0 +1,322 @@ +package artifacts + +import ( + "archive/zip" + "bytes" + "compress/gzip" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/dgrijalva/jwt-go" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/proxy" + 
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream/roundtripper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts" + + "github.com/stretchr/testify/require" +) + +const ( + MetadataHeaderKey = "Metadata-Status" + MetadataHeaderPresent = "present" + MetadataHeaderMissing = "missing" + Path = "/url/path" +) + +func testArtifactsUploadServer(t *testing.T, authResponse api.Response, bodyProcessor func(w http.ResponseWriter, r *http.Request)) *httptest.Server { + mux := http.NewServeMux() + mux.HandleFunc(Path+"/authorize", func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Fatal("Expected POST request") + } + + w.Header().Set("Content-Type", api.ResponseContentType) + + data, err := json.Marshal(&authResponse) + if err != nil { + t.Fatal("Expected to marshal") + } + w.Write(data) + }) + mux.HandleFunc(Path, func(w http.ResponseWriter, r *http.Request) { + opts, err := filestore.GetOpts(&authResponse) + require.NoError(t, err) + + if r.Method != "POST" { + t.Fatal("Expected POST request") + } + if opts.IsLocal() { + if r.FormValue("file.path") == "" { + t.Fatal("Expected file to be present") + return + } + + _, err := ioutil.ReadFile(r.FormValue("file.path")) + if err != nil { + t.Fatal("Expected file to be readable") + return + } + } else { + if r.FormValue("file.remote_url") == "" { + t.Fatal("Expected file to be remote accessible") + return + } + } + + if r.FormValue("metadata.path") != "" { + metadata, err := ioutil.ReadFile(r.FormValue("metadata.path")) + if err != nil { + t.Fatal("Expected metadata to be readable") + return + } + gz, err := gzip.NewReader(bytes.NewReader(metadata)) + if err != nil { + t.Fatal("Expected metadata to be valid gzip") + return + } + defer gz.Close() + metadata, err = ioutil.ReadAll(gz) + if err != nil { + t.Fatal("Expected metadata to be valid") + return + } + if !bytes.HasPrefix(metadata, []byte(zipartifacts.MetadataHeaderPrefix+zipartifacts.MetadataHeader)) { + t.Fatal("Expected metadata to be of valid format") + return + } + + w.Header().Set(MetadataHeaderKey, MetadataHeaderPresent) + + } else { + w.Header().Set(MetadataHeaderKey, MetadataHeaderMissing) + } + + w.WriteHeader(http.StatusOK) + + if bodyProcessor != nil { + bodyProcessor(w, r) + } + }) + return testhelper.TestServerWithHandler(nil, mux.ServeHTTP) +} + +type testServer struct { + url string + writer *multipart.Writer + buffer *bytes.Buffer + fileWriter io.Writer + cleanup func() +} + +func setupWithTmpPath(t *testing.T, filename string, includeFormat bool, format string, authResponse *api.Response, bodyProcessor func(w http.ResponseWriter, r *http.Request)) *testServer { + tempPath, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + + if authResponse == nil { + authResponse = &api.Response{TempPath: tempPath} + } + + ts := testArtifactsUploadServer(t, *authResponse, bodyProcessor) + + var buffer bytes.Buffer + writer := multipart.NewWriter(&buffer) + fileWriter, err := writer.CreateFormFile(filename, "my.file") + require.NotNil(t, fileWriter) + require.NoError(t, err) + + cleanup := func() { + ts.Close() + require.NoError(t, os.RemoveAll(tempPath)) + require.NoError(t, writer.Close()) + } + + qs := "" + + if includeFormat { + qs = fmt.Sprintf("?%s=%s", ArtifactFormatKey, format) + } + + return &testServer{url: ts.URL + Path + qs, writer: writer, buffer: &buffer, fileWriter: fileWriter, cleanup: cleanup} +} + +func 
testUploadArtifacts(t *testing.T, contentType, url string, body io.Reader) *httptest.ResponseRecorder { + httpRequest, err := http.NewRequest("POST", url, body) + require.NoError(t, err) + + httpRequest.Header.Set("Content-Type", contentType) + response := httptest.NewRecorder() + parsedURL := helper.URLMustParse(url) + roundTripper := roundtripper.NewTestBackendRoundTripper(parsedURL) + testhelper.ConfigureSecret() + apiClient := api.NewAPI(parsedURL, "123", roundTripper) + proxyClient := proxy.NewProxy(parsedURL, "123", roundTripper) + UploadArtifacts(apiClient, proxyClient, &upload.DefaultPreparer{}).ServeHTTP(response, httpRequest) + return response +} + +func TestUploadHandlerAddingMetadata(t *testing.T) { + testCases := []struct { + desc string + format string + includeFormat bool + }{ + { + desc: "ZIP format", + format: ArtifactFormatZip, + includeFormat: true, + }, + { + desc: "default format", + format: ArtifactFormatDefault, + includeFormat: true, + }, + { + desc: "default format without artifact_format", + format: ArtifactFormatDefault, + includeFormat: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + s := setupWithTmpPath(t, "file", tc.includeFormat, tc.format, nil, + func(w http.ResponseWriter, r *http.Request) { + token, err := jwt.ParseWithClaims(r.Header.Get(upload.RewrittenFieldsHeader), &upload.MultipartClaims{}, testhelper.ParseJWT) + require.NoError(t, err) + + rewrittenFields := token.Claims.(*upload.MultipartClaims).RewrittenFields + require.Equal(t, 2, len(rewrittenFields)) + + require.Contains(t, rewrittenFields, "file") + require.Contains(t, rewrittenFields, "metadata") + require.Contains(t, r.PostForm, "file.gitlab-workhorse-upload") + require.Contains(t, r.PostForm, "metadata.gitlab-workhorse-upload") + }, + ) + defer s.cleanup() + + archive := zip.NewWriter(s.fileWriter) + file, err := archive.Create("test.file") + require.NotNil(t, file) + require.NoError(t, err) + + require.NoError(t, archive.Close()) + require.NoError(t, s.writer.Close()) + + response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer) + require.Equal(t, http.StatusOK, response.Code) + testhelper.RequireResponseHeader(t, response, MetadataHeaderKey, MetadataHeaderPresent) + }) + } +} + +func TestUploadHandlerTarArtifact(t *testing.T) { + s := setupWithTmpPath(t, "file", true, "tar", nil, + func(w http.ResponseWriter, r *http.Request) { + token, err := jwt.ParseWithClaims(r.Header.Get(upload.RewrittenFieldsHeader), &upload.MultipartClaims{}, testhelper.ParseJWT) + require.NoError(t, err) + + rewrittenFields := token.Claims.(*upload.MultipartClaims).RewrittenFields + require.Equal(t, 1, len(rewrittenFields)) + + require.Contains(t, rewrittenFields, "file") + require.Contains(t, r.PostForm, "file.gitlab-workhorse-upload") + }, + ) + defer s.cleanup() + + file, err := os.Open("../../testdata/tarfile.tar") + require.NoError(t, err) + + _, err = io.Copy(s.fileWriter, file) + require.NoError(t, err) + require.NoError(t, file.Close()) + require.NoError(t, s.writer.Close()) + + response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer) + require.Equal(t, http.StatusOK, response.Code) + testhelper.RequireResponseHeader(t, response, MetadataHeaderKey, MetadataHeaderMissing) +} + +func TestUploadHandlerForUnsupportedArchive(t *testing.T) { + s := setupWithTmpPath(t, "file", true, "other", nil, nil) + defer s.cleanup() + require.NoError(t, s.writer.Close()) + + response := testUploadArtifacts(t, 
s.writer.FormDataContentType(), s.url, s.buffer) + require.Equal(t, http.StatusOK, response.Code) + testhelper.RequireResponseHeader(t, response, MetadataHeaderKey, MetadataHeaderMissing) +} + +func TestUploadHandlerForMultipleFiles(t *testing.T) { + s := setupWithTmpPath(t, "file", true, "", nil, nil) + defer s.cleanup() + + file, err := s.writer.CreateFormFile("file", "my.file") + require.NotNil(t, file) + require.NoError(t, err) + require.NoError(t, s.writer.Close()) + + response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer) + require.Equal(t, http.StatusInternalServerError, response.Code) +} + +func TestUploadFormProcessing(t *testing.T) { + s := setupWithTmpPath(t, "metadata", true, "", nil, nil) + defer s.cleanup() + require.NoError(t, s.writer.Close()) + + response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer) + require.Equal(t, http.StatusInternalServerError, response.Code) +} + +func TestLsifFileProcessing(t *testing.T) { + tempPath, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + + s := setupWithTmpPath(t, "file", true, "zip", &api.Response{TempPath: tempPath, ProcessLsif: true}, nil) + defer s.cleanup() + + file, err := os.Open("../../testdata/lsif/valid.lsif.zip") + require.NoError(t, err) + + _, err = io.Copy(s.fileWriter, file) + require.NoError(t, err) + require.NoError(t, file.Close()) + require.NoError(t, s.writer.Close()) + + response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer) + require.Equal(t, http.StatusOK, response.Code) + testhelper.RequireResponseHeader(t, response, MetadataHeaderKey, MetadataHeaderPresent) +} + +func TestInvalidLsifFileProcessing(t *testing.T) { + tempPath, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + + s := setupWithTmpPath(t, "file", true, "zip", &api.Response{TempPath: tempPath, ProcessLsif: true}, nil) + defer s.cleanup() + + file, err := os.Open("../../testdata/lsif/invalid.lsif.zip") + require.NoError(t, err) + + _, err = io.Copy(s.fileWriter, file) + require.NoError(t, err) + require.NoError(t, file.Close()) + require.NoError(t, s.writer.Close()) + + response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer) + require.Equal(t, http.StatusInternalServerError, response.Code) +} diff --git a/workhorse/internal/artifacts/entry.go b/workhorse/internal/artifacts/entry.go new file mode 100644 index 00000000000..0c697d40020 --- /dev/null +++ b/workhorse/internal/artifacts/entry.go @@ -0,0 +1,123 @@ +package artifacts + +import ( + "bufio" + "context" + "fmt" + "io" + "mime" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "syscall" + + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts" +) + +type entry struct{ senddata.Prefix } +type entryParams struct{ Archive, Entry string } + +var SendEntry = &entry{"artifacts-entry:"} + +// Artifacts downloader doesn't support ranges when downloading a single file +func (e *entry) Inject(w http.ResponseWriter, r *http.Request, sendData string) { + var params entryParams + if err := e.Unpack(¶ms, sendData); err != nil { + helper.Fail500(w, r, fmt.Errorf("SendEntry: unpack sendData: %v", err)) + return + } + + log.WithContextFields(r.Context(), log.Fields{ + "entry": params.Entry, + "archive": params.Archive, + "path": r.URL.Path, + }).Print("SendEntry: sending") + + if 
params.Archive == "" || params.Entry == "" { + helper.Fail500(w, r, fmt.Errorf("SendEntry: Archive or Entry is empty")) + return + } + + err := unpackFileFromZip(r.Context(), params.Archive, params.Entry, w.Header(), w) + + if os.IsNotExist(err) { + http.NotFound(w, r) + } else if err != nil { + helper.Fail500(w, r, fmt.Errorf("SendEntry: %v", err)) + } +} + +func detectFileContentType(fileName string) string { + contentType := mime.TypeByExtension(filepath.Ext(fileName)) + if contentType == "" { + contentType = "application/octet-stream" + } + return contentType +} + +func unpackFileFromZip(ctx context.Context, archivePath, encodedFilename string, headers http.Header, output io.Writer) error { + fileName, err := zipartifacts.DecodeFileEntry(encodedFilename) + if err != nil { + return err + } + + catFile := exec.Command("gitlab-zip-cat") + catFile.Env = append(os.Environ(), + "ARCHIVE_PATH="+archivePath, + "ENCODED_FILE_NAME="+encodedFilename, + ) + catFile.Stderr = log.ContextLogger(ctx).Writer() + catFile.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + stdout, err := catFile.StdoutPipe() + if err != nil { + return fmt.Errorf("create gitlab-zip-cat stdout pipe: %v", err) + } + + if err := catFile.Start(); err != nil { + return fmt.Errorf("start %v: %v", catFile.Args, err) + } + defer helper.CleanUpProcessGroup(catFile) + + basename := filepath.Base(fileName) + reader := bufio.NewReader(stdout) + contentLength, err := reader.ReadString('\n') + if err != nil { + if catFileErr := waitCatFile(catFile); catFileErr != nil { + return catFileErr + } + return fmt.Errorf("read content-length: %v", err) + } + contentLength = strings.TrimSuffix(contentLength, "\n") + + // Write http headers about the file + headers.Set("Content-Length", contentLength) + headers.Set("Content-Type", detectFileContentType(fileName)) + headers.Set("Content-Disposition", "attachment; filename=\""+escapeQuotes(basename)+"\"") + // Copy file body to client + if _, err := io.Copy(output, reader); err != nil { + return fmt.Errorf("copy stdout of %v: %v", catFile.Args, err) + } + + return waitCatFile(catFile) +} + +func waitCatFile(cmd *exec.Cmd) error { + err := cmd.Wait() + if err == nil { + return nil + } + + st, ok := helper.ExitStatus(err) + + if ok && (st == zipartifacts.CodeArchiveNotFound || st == zipartifacts.CodeEntryNotFound) { + return os.ErrNotExist + } + return fmt.Errorf("wait for %v to finish: %v", cmd.Args, err) + +} diff --git a/workhorse/internal/artifacts/entry_test.go b/workhorse/internal/artifacts/entry_test.go new file mode 100644 index 00000000000..6f1e9d360aa --- /dev/null +++ b/workhorse/internal/artifacts/entry_test.go @@ -0,0 +1,134 @@ +package artifacts + +import ( + "archive/zip" + "encoding/base64" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" +) + +func testEntryServer(t *testing.T, archive string, entry string) *httptest.ResponseRecorder { + mux := http.NewServeMux() + mux.HandleFunc("/url/path", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "GET", r.Method) + + encodedEntry := base64.StdEncoding.EncodeToString([]byte(entry)) + jsonParams := fmt.Sprintf(`{"Archive":"%s","Entry":"%s"}`, archive, encodedEntry) + data := base64.URLEncoding.EncodeToString([]byte(jsonParams)) + + SendEntry.Inject(w, r, data) + }) + + httpRequest, err := http.NewRequest("GET", "/url/path", nil) + require.NoError(t, err) + response := 
httptest.NewRecorder() + mux.ServeHTTP(response, httpRequest) + return response +} + +func TestDownloadingFromValidArchive(t *testing.T) { + tempFile, err := ioutil.TempFile("", "uploads") + require.NoError(t, err) + defer tempFile.Close() + defer os.Remove(tempFile.Name()) + + archive := zip.NewWriter(tempFile) + defer archive.Close() + fileInArchive, err := archive.Create("test.txt") + require.NoError(t, err) + fmt.Fprint(fileInArchive, "testtest") + archive.Close() + + response := testEntryServer(t, tempFile.Name(), "test.txt") + + require.Equal(t, 200, response.Code) + + testhelper.RequireResponseHeader(t, response, + "Content-Type", + "text/plain; charset=utf-8") + testhelper.RequireResponseHeader(t, response, + "Content-Disposition", + "attachment; filename=\"test.txt\"") + + testhelper.RequireResponseBody(t, response, "testtest") +} + +func TestDownloadingFromValidHTTPArchive(t *testing.T) { + tempDir, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + f, err := os.Create(filepath.Join(tempDir, "archive.zip")) + require.NoError(t, err) + defer f.Close() + + archive := zip.NewWriter(f) + defer archive.Close() + fileInArchive, err := archive.Create("test.txt") + require.NoError(t, err) + fmt.Fprint(fileInArchive, "testtest") + archive.Close() + f.Close() + + fileServer := httptest.NewServer(http.FileServer(http.Dir(tempDir))) + defer fileServer.Close() + + response := testEntryServer(t, fileServer.URL+"/archive.zip", "test.txt") + + require.Equal(t, 200, response.Code) + + testhelper.RequireResponseHeader(t, response, + "Content-Type", + "text/plain; charset=utf-8") + testhelper.RequireResponseHeader(t, response, + "Content-Disposition", + "attachment; filename=\"test.txt\"") + + testhelper.RequireResponseBody(t, response, "testtest") +} + +func TestDownloadingNonExistingFile(t *testing.T) { + tempFile, err := ioutil.TempFile("", "uploads") + require.NoError(t, err) + defer tempFile.Close() + defer os.Remove(tempFile.Name()) + + archive := zip.NewWriter(tempFile) + defer archive.Close() + archive.Close() + + response := testEntryServer(t, tempFile.Name(), "test") + require.Equal(t, 404, response.Code) +} + +func TestDownloadingFromInvalidArchive(t *testing.T) { + response := testEntryServer(t, "path/to/non/existing/file", "test") + require.Equal(t, 404, response.Code) +} + +func TestIncompleteApiResponse(t *testing.T) { + response := testEntryServer(t, "", "") + require.Equal(t, 500, response.Code) +} + +func TestDownloadingFromNonExistingHTTPArchive(t *testing.T) { + tempDir, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + fileServer := httptest.NewServer(http.FileServer(http.Dir(tempDir))) + defer fileServer.Close() + + response := testEntryServer(t, fileServer.URL+"/not-existing-archive-file.zip", "test.txt") + + require.Equal(t, 404, response.Code) +} diff --git a/workhorse/internal/artifacts/escape_quotes.go b/workhorse/internal/artifacts/escape_quotes.go new file mode 100644 index 00000000000..94db2be39b7 --- /dev/null +++ b/workhorse/internal/artifacts/escape_quotes.go @@ -0,0 +1,10 @@ +package artifacts + +import "strings" + +// taken from mime/multipart/writer.go +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} diff --git a/workhorse/internal/badgateway/roundtripper.go b/workhorse/internal/badgateway/roundtripper.go new file mode 100644 index 00000000000..a36cc9f4a9a --- /dev/null +++ 
b/workhorse/internal/badgateway/roundtripper.go @@ -0,0 +1,115 @@ +package badgateway + +import ( + "bytes" + "fmt" + "html/template" + "io/ioutil" + "net/http" + "strings" + "time" + + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +// Error is a custom error for pretty Sentry 'issues' +type sentryError struct{ error } + +type roundTripper struct { + next http.RoundTripper + developmentMode bool +} + +// NewRoundTripper creates a RoundTripper with a provided underlying next transport +func NewRoundTripper(developmentMode bool, next http.RoundTripper) http.RoundTripper { + return &roundTripper{next: next, developmentMode: developmentMode} +} + +func (t *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + start := time.Now() + + res, err := t.next.RoundTrip(r) + if err == nil { + return res, err + } + + // httputil.ReverseProxy translates all errors from this + // RoundTrip function into 500 errors. But the most likely error + // is that the Rails app is not responding, in which case users + // and administrators expect to see a 502 error. To show 502s + // instead of 500s we catch the RoundTrip error here and inject a + // 502 response. + fields := log.Fields{"duration_ms": int64(time.Since(start).Seconds() * 1000)} + helper.LogErrorWithFields( + r, + &sentryError{fmt.Errorf("badgateway: failed to receive response: %v", err)}, + fields, + ) + + injectedResponse := &http.Response{ + StatusCode: http.StatusBadGateway, + Status: http.StatusText(http.StatusBadGateway), + + Request: r, + ProtoMajor: r.ProtoMajor, + ProtoMinor: r.ProtoMinor, + Proto: r.Proto, + Header: make(http.Header), + Trailer: make(http.Header), + } + + message := "GitLab is not responding" + contentType := "text/plain" + if t.developmentMode { + message, contentType = developmentModeResponse(err) + } + + injectedResponse.Body = ioutil.NopCloser(strings.NewReader(message)) + injectedResponse.Header.Set("Content-Type", contentType) + + return injectedResponse, nil +} + +func developmentModeResponse(err error) (body string, contentType string) { + data := TemplateData{ + Time: time.Now().Format("15:04:05"), + Error: err.Error(), + ReloadSeconds: 5, + } + + buf := &bytes.Buffer{} + if err := developmentErrorTemplate.Execute(buf, data); err != nil { + return data.Error, "text/plain" + } + + return buf.String(), "text/html" +} + +type TemplateData struct { + Time string + Error string + ReloadSeconds int +} + +var developmentErrorTemplate = template.Must(template.New("error502").Parse(` +<html> +<head> +<title>502: GitLab is not responding</title> +<script> +window.setTimeout(function() { location.reload() }, {{.ReloadSeconds}} * 1000) +</script> +</head> + +<body> +<h1>502</h1> +<p>GitLab is not responding. The error was:</p> + +<pre>{{.Error}}</pre> + +<p>If you just started GDK it can take 60-300 seconds before GitLab has finished booting. 
This page will automatically reload every {{.ReloadSeconds}} seconds.</p> +<footer>Generated by gitlab-workhorse running in development mode at {{.Time}}.</footer> +</body> +</html> +`)) diff --git a/workhorse/internal/badgateway/roundtripper_test.go b/workhorse/internal/badgateway/roundtripper_test.go new file mode 100644 index 00000000000..fc7132f9bd7 --- /dev/null +++ b/workhorse/internal/badgateway/roundtripper_test.go @@ -0,0 +1,56 @@ +package badgateway + +import ( + "errors" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +type roundtrip502 struct{} + +func (roundtrip502) RoundTrip(*http.Request) (*http.Response, error) { + return nil, errors.New("something went wrong") +} + +func TestErrorPage502(t *testing.T) { + tests := []struct { + name string + devMode bool + contentType string + responseSnippet string + }{ + { + name: "production mode", + contentType: "text/plain", + responseSnippet: "GitLab is not responding", + }, + { + name: "development mode", + devMode: true, + contentType: "text/html", + responseSnippet: "This page will automatically reload", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err, "build request") + + rt := NewRoundTripper(tc.devMode, roundtrip502{}) + response, err := rt.RoundTrip(req) + require.NoError(t, err, "perform roundtrip") + defer response.Body.Close() + + body, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + require.Equal(t, tc.contentType, response.Header.Get("content-type"), "content type") + require.Equal(t, 502, response.StatusCode, "response status") + require.Contains(t, string(body), tc.responseSnippet) + }) + } +} diff --git a/workhorse/internal/builds/register.go b/workhorse/internal/builds/register.go new file mode 100644 index 00000000000..77685889cfd --- /dev/null +++ b/workhorse/internal/builds/register.go @@ -0,0 +1,163 @@ +package builds + +import ( + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/redis" +) + +const ( + maxRegisterBodySize = 32 * 1024 + runnerBuildQueue = "runner:build_queue:" + runnerBuildQueueHeaderKey = "Gitlab-Ci-Builds-Polling" + runnerBuildQueueHeaderValue = "yes" +) + +var ( + registerHandlerRequests = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_builds_register_handler_requests", + Help: "Describes how many requests in different states hit a register handler", + }, + []string{"status"}, + ) + registerHandlerOpen = promauto.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "gitlab_workhorse_builds_register_handler_open", + Help: "Describes how many requests is currently open in given state", + }, + []string{"state"}, + ) + + registerHandlerOpenAtReading = registerHandlerOpen.WithLabelValues("reading") + registerHandlerOpenAtProxying = registerHandlerOpen.WithLabelValues("proxying") + registerHandlerOpenAtWatching = registerHandlerOpen.WithLabelValues("watching") + + registerHandlerBodyReadErrors = registerHandlerRequests.WithLabelValues("body-read-error") + registerHandlerBodyParseErrors = registerHandlerRequests.WithLabelValues("body-parse-error") + registerHandlerMissingValues = registerHandlerRequests.WithLabelValues("missing-values") + registerHandlerWatchErrors = 
registerHandlerRequests.WithLabelValues("watch-error") + registerHandlerAlreadyChangedRequests = registerHandlerRequests.WithLabelValues("already-changed") + registerHandlerSeenChangeRequests = registerHandlerRequests.WithLabelValues("seen-change") + registerHandlerTimeoutRequests = registerHandlerRequests.WithLabelValues("timeout") + registerHandlerNoChangeRequests = registerHandlerRequests.WithLabelValues("no-change") +) + +type largeBodyError struct{ error } + +type WatchKeyHandler func(key, value string, timeout time.Duration) (redis.WatchKeyStatus, error) + +type runnerRequest struct { + Token string `json:"token,omitempty"` + LastUpdate string `json:"last_update,omitempty"` +} + +func readRunnerBody(w http.ResponseWriter, r *http.Request) ([]byte, error) { + registerHandlerOpenAtReading.Inc() + defer registerHandlerOpenAtReading.Dec() + + return helper.ReadRequestBody(w, r, maxRegisterBodySize) +} + +func readRunnerRequest(r *http.Request, body []byte) (*runnerRequest, error) { + if !helper.IsApplicationJson(r) { + return nil, errors.New("invalid content-type received") + } + + var runnerRequest runnerRequest + err := json.Unmarshal(body, &runnerRequest) + if err != nil { + return nil, err + } + + return &runnerRequest, nil +} + +func proxyRegisterRequest(h http.Handler, w http.ResponseWriter, r *http.Request) { + registerHandlerOpenAtProxying.Inc() + defer registerHandlerOpenAtProxying.Dec() + + h.ServeHTTP(w, r) +} + +func watchForRunnerChange(watchHandler WatchKeyHandler, token, lastUpdate string, duration time.Duration) (redis.WatchKeyStatus, error) { + registerHandlerOpenAtWatching.Inc() + defer registerHandlerOpenAtWatching.Dec() + + return watchHandler(runnerBuildQueue+token, lastUpdate, duration) +} + +func RegisterHandler(h http.Handler, watchHandler WatchKeyHandler, pollingDuration time.Duration) http.Handler { + if pollingDuration == 0 { + return h + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(runnerBuildQueueHeaderKey, runnerBuildQueueHeaderValue) + + requestBody, err := readRunnerBody(w, r) + if err != nil { + registerHandlerBodyReadErrors.Inc() + helper.RequestEntityTooLarge(w, r, &largeBodyError{err}) + return + } + + newRequest := helper.CloneRequestWithNewBody(r, requestBody) + + runnerRequest, err := readRunnerRequest(r, requestBody) + if err != nil { + registerHandlerBodyParseErrors.Inc() + proxyRegisterRequest(h, w, newRequest) + return + } + + if runnerRequest.Token == "" || runnerRequest.LastUpdate == "" { + registerHandlerMissingValues.Inc() + proxyRegisterRequest(h, w, newRequest) + return + } + + result, err := watchForRunnerChange(watchHandler, runnerRequest.Token, + runnerRequest.LastUpdate, pollingDuration) + if err != nil { + registerHandlerWatchErrors.Inc() + proxyRegisterRequest(h, w, newRequest) + return + } + + switch result { + // It means that we detected a change before starting watching on change, + // We proxy request to Rails, to see whether we have a build to receive + case redis.WatchKeyStatusAlreadyChanged: + registerHandlerAlreadyChangedRequests.Inc() + proxyRegisterRequest(h, w, newRequest) + + // It means that we detected a change after watching. + // We could potentially proxy request to Rails, but... 
+ // We can end-up with unreliable responses, + // as don't really know whether ResponseWriter is still in a sane state, + // for example the connection is dead + case redis.WatchKeyStatusSeenChange: + registerHandlerSeenChangeRequests.Inc() + w.WriteHeader(http.StatusNoContent) + + // When we receive one of these statuses, it means that we detected no change, + // so we return to runner 204, which means nothing got changed, + // and there's no new builds to process + case redis.WatchKeyStatusTimeout: + registerHandlerTimeoutRequests.Inc() + w.WriteHeader(http.StatusNoContent) + + case redis.WatchKeyStatusNoChange: + registerHandlerNoChangeRequests.Inc() + w.WriteHeader(http.StatusNoContent) + } + }) +} diff --git a/workhorse/internal/builds/register_test.go b/workhorse/internal/builds/register_test.go new file mode 100644 index 00000000000..a72d82dc2b7 --- /dev/null +++ b/workhorse/internal/builds/register_test.go @@ -0,0 +1,108 @@ +package builds + +import ( + "bytes" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/redis" +) + +const upstreamResponseCode = 999 + +func echoRequest(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(upstreamResponseCode) + io.Copy(rw, req.Body) +} + +var echoRequestFunc = http.HandlerFunc(echoRequest) + +func expectHandlerWithWatcher(t *testing.T, watchHandler WatchKeyHandler, data string, contentType string, expectedHttpStatus int, msgAndArgs ...interface{}) { + h := RegisterHandler(echoRequestFunc, watchHandler, time.Second) + + rw := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(data)) + req.Header.Set("Content-Type", contentType) + + h.ServeHTTP(rw, req) + + require.Equal(t, expectedHttpStatus, rw.Code, msgAndArgs...) +} + +func expectHandler(t *testing.T, data string, contentType string, expectedHttpStatus int, msgAndArgs ...interface{}) { + expectHandlerWithWatcher(t, nil, data, contentType, expectedHttpStatus, msgAndArgs...) +} + +func TestRegisterHandlerLargeBody(t *testing.T) { + data := strings.Repeat(".", maxRegisterBodySize+5) + expectHandler(t, data, "application/json", http.StatusRequestEntityTooLarge, + "rejects body with entity too large") +} + +func TestRegisterHandlerInvalidRunnerRequest(t *testing.T) { + expectHandler(t, "invalid content", "text/plain", upstreamResponseCode, + "proxies request to upstream") +} + +func TestRegisterHandlerInvalidJsonPayload(t *testing.T) { + expectHandler(t, `{[`, "application/json", upstreamResponseCode, + "fails on parsing body and proxies request to upstream") +} + +func TestRegisterHandlerMissingData(t *testing.T) { + testCases := []string{ + `{"token":"token"}`, + `{"last_update":"data"}`, + } + + for _, testCase := range testCases { + expectHandler(t, testCase, "application/json", upstreamResponseCode, + "fails on argument validation and proxies request to upstream") + } +} + +func expectWatcherToBeExecuted(t *testing.T, watchKeyStatus redis.WatchKeyStatus, watchKeyError error, + httpStatus int, msgAndArgs ...interface{}) { + executed := false + watchKeyHandler := func(key, value string, timeout time.Duration) (redis.WatchKeyStatus, error) { + executed = true + return watchKeyStatus, watchKeyError + } + + parsableData := `{"token":"token","last_update":"last_update"}` + + expectHandlerWithWatcher(t, watchKeyHandler, parsableData, "application/json", httpStatus, msgAndArgs...) + require.True(t, executed, msgAndArgs...) 
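// Illustrative usage sketch (not part of this change): wiring RegisterHandler in
// front of an upstream handler with a stub WatchKeyHandler. The route, address
// and 50-second polling duration below are assumptions for the example only.
package main

import (
	"fmt"
	"log"
	"net/http"
	"time"

	"gitlab.com/gitlab-org/gitlab-workhorse/internal/builds"
	"gitlab.com/gitlab-org/gitlab-workhorse/internal/redis"
)

func main() {
	// upstream stands in for the proxied Rails application.
	upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "proxied to Rails")
	})

	// Stub watcher: report that the runner build queue key did not change within
	// the timeout, so the handler answers 204 without forwarding to Rails.
	watcher := func(key, value string, timeout time.Duration) (redis.WatchKeyStatus, error) {
		return redis.WatchKeyStatusTimeout, nil
	}

	http.Handle("/api/v4/jobs/request", builds.RegisterHandler(upstream, watcher, 50*time.Second))
	log.Fatal(http.ListenAndServe(":8181", nil))
}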
+} + +func TestRegisterHandlerWatcherError(t *testing.T) { + expectWatcherToBeExecuted(t, redis.WatchKeyStatusNoChange, errors.New("redis connection"), + upstreamResponseCode, "proxies data to upstream") +} + +func TestRegisterHandlerWatcherAlreadyChanged(t *testing.T) { + expectWatcherToBeExecuted(t, redis.WatchKeyStatusAlreadyChanged, nil, + upstreamResponseCode, "proxies data to upstream") +} + +func TestRegisterHandlerWatcherSeenChange(t *testing.T) { + expectWatcherToBeExecuted(t, redis.WatchKeyStatusSeenChange, nil, + http.StatusNoContent) +} + +func TestRegisterHandlerWatcherTimeout(t *testing.T) { + expectWatcherToBeExecuted(t, redis.WatchKeyStatusTimeout, nil, + http.StatusNoContent) +} + +func TestRegisterHandlerWatcherNoChange(t *testing.T) { + expectWatcherToBeExecuted(t, redis.WatchKeyStatusNoChange, nil, + http.StatusNoContent) +} diff --git a/workhorse/internal/channel/auth_checker.go b/workhorse/internal/channel/auth_checker.go new file mode 100644 index 00000000000..f44850e0861 --- /dev/null +++ b/workhorse/internal/channel/auth_checker.go @@ -0,0 +1,69 @@ +package channel + +import ( + "errors" + "net/http" + "time" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" +) + +type AuthCheckerFunc func() *api.ChannelSettings + +// Regularly checks that authorization is still valid for a channel, outputting +// to the stopper when it isn't +type AuthChecker struct { + Checker AuthCheckerFunc + Template *api.ChannelSettings + StopCh chan error + Done chan struct{} + Count int64 +} + +var ErrAuthChanged = errors.New("connection closed: authentication changed or endpoint unavailable") + +func NewAuthChecker(f AuthCheckerFunc, template *api.ChannelSettings, stopCh chan error) *AuthChecker { + return &AuthChecker{ + Checker: f, + Template: template, + StopCh: stopCh, + Done: make(chan struct{}), + } +} +func (c *AuthChecker) Loop(interval time.Duration) { + for { + select { + case <-time.After(interval): + settings := c.Checker() + if !c.Template.IsEqual(settings) { + c.StopCh <- ErrAuthChanged + return + } + c.Count = c.Count + 1 + case <-c.Done: + return + } + } +} + +func (c *AuthChecker) Close() error { + close(c.Done) + return nil +} + +// Generates a CheckerFunc from an *api.API + request needing authorization +func authCheckFunc(myAPI *api.API, r *http.Request, suffix string) AuthCheckerFunc { + return func() *api.ChannelSettings { + httpResponse, authResponse, err := myAPI.PreAuthorize(suffix, r) + if err != nil { + return nil + } + defer httpResponse.Body.Close() + + if httpResponse.StatusCode != http.StatusOK || authResponse == nil { + return nil + } + + return authResponse.Channel + } +} diff --git a/workhorse/internal/channel/auth_checker_test.go b/workhorse/internal/channel/auth_checker_test.go new file mode 100644 index 00000000000..18beb45cf3a --- /dev/null +++ b/workhorse/internal/channel/auth_checker_test.go @@ -0,0 +1,53 @@ +package channel + +import ( + "testing" + "time" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" +) + +func checkerSeries(values ...*api.ChannelSettings) AuthCheckerFunc { + return func() *api.ChannelSettings { + if len(values) == 0 { + return nil + } + out := values[0] + values = values[1:] + return out + } +} + +func TestAuthCheckerStopsWhenAuthFails(t *testing.T) { + template := &api.ChannelSettings{Url: "ws://example.com"} + stopCh := make(chan error) + series := checkerSeries(template, template, template) + ac := NewAuthChecker(series, template, stopCh) + + go ac.Loop(1 * time.Millisecond) + if err := <-stopCh; err != 
ErrAuthChanged { + t.Fatalf("Expected ErrAuthChanged, got %v", err) + } + + if ac.Count != 3 { + t.Fatalf("Expected 3 successful checks, got %v", ac.Count) + } +} + +func TestAuthCheckerStopsWhenAuthChanges(t *testing.T) { + template := &api.ChannelSettings{Url: "ws://example.com"} + changed := template.Clone() + changed.Url = "wss://example.com" + stopCh := make(chan error) + series := checkerSeries(template, changed, template) + ac := NewAuthChecker(series, template, stopCh) + + go ac.Loop(1 * time.Millisecond) + if err := <-stopCh; err != ErrAuthChanged { + t.Fatalf("Expected ErrAuthChanged, got %v", err) + } + + if ac.Count != 1 { + t.Fatalf("Expected 1 successful check, got %v", ac.Count) + } +} diff --git a/workhorse/internal/channel/channel.go b/workhorse/internal/channel/channel.go new file mode 100644 index 00000000000..381ce95df82 --- /dev/null +++ b/workhorse/internal/channel/channel.go @@ -0,0 +1,132 @@ +package channel + +import ( + "fmt" + "net/http" + "time" + + "github.com/gorilla/websocket" + + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +var ( + // See doc/channel.md for documentation of this subprotocol + subprotocols = []string{"terminal.gitlab.com", "base64.terminal.gitlab.com"} + upgrader = &websocket.Upgrader{Subprotocols: subprotocols} + ReauthenticationInterval = 5 * time.Minute + BrowserPingInterval = 30 * time.Second +) + +func Handler(myAPI *api.API) http.Handler { + return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) { + if err := a.Channel.Validate(); err != nil { + helper.Fail500(w, r, err) + return + } + + proxy := NewProxy(2) // two stoppers: auth checker, max time + checker := NewAuthChecker( + authCheckFunc(myAPI, r, "authorize"), + a.Channel, + proxy.StopCh, + ) + defer checker.Close() + go checker.Loop(ReauthenticationInterval) + go closeAfterMaxTime(proxy, a.Channel.MaxSessionTime) + + ProxyChannel(w, r, a.Channel, proxy) + }, "authorize") +} + +func ProxyChannel(w http.ResponseWriter, r *http.Request, settings *api.ChannelSettings, proxy *Proxy) { + server, err := connectToServer(settings, r) + if err != nil { + helper.Fail500(w, r, err) + log.ContextLogger(r.Context()).WithError(err).Print("Channel: connecting to server failed") + return + } + defer server.UnderlyingConn().Close() + serverAddr := server.UnderlyingConn().RemoteAddr().String() + + client, err := upgradeClient(w, r) + if err != nil { + log.ContextLogger(r.Context()).WithError(err).Print("Channel: upgrading client to websocket failed") + return + } + + // Regularly send ping messages to the browser to keep the websocket from + // being timed out by intervening proxies. 
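	// Aside: an illustrative sketch of the AuthChecker contract used by this
	// handler (the URL, channel buffer size and interval are assumptions). When
	// the checker can no longer confirm the original settings, ErrAuthChanged is
	// pushed into the stop channel and the proxied session is torn down:
	//
	//	template := &api.ChannelSettings{Url: "ws://example.com"}
	//	stopCh := make(chan error, 1)
	//
	//	// A checker returning nil simulates authorization being revoked upstream.
	//	ac := NewAuthChecker(func() *api.ChannelSettings { return nil }, template, stopCh)
	//	defer ac.Close()
	//
	//	go ac.Loop(time.Second)
	//	log.Println(<-stopCh) // connection closed: authentication changed or endpoint unavailable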
+ go pingLoop(client) + + defer client.UnderlyingConn().Close() + clientAddr := getClientAddr(r) // We can't know the port with confidence + + logEntry := log.WithContextFields(r.Context(), log.Fields{ + "clientAddr": clientAddr, + "serverAddr": serverAddr, + }) + + logEntry.Print("Channel: started proxying") + + defer logEntry.Print("Channel: finished proxying") + + if err := proxy.Serve(server, client, serverAddr, clientAddr); err != nil { + logEntry.WithError(err).Print("Channel: error proxying") + } +} + +// In the future, we might want to look at X-Client-Ip or X-Forwarded-For +func getClientAddr(r *http.Request) string { + return r.RemoteAddr +} + +func upgradeClient(w http.ResponseWriter, r *http.Request) (Connection, error) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return nil, err + } + + return Wrap(conn, conn.Subprotocol()), nil +} + +func pingLoop(conn Connection) { + for { + time.Sleep(BrowserPingInterval) + deadline := time.Now().Add(5 * time.Second) + if err := conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil { + // Either the connection was already closed so no further pings are + // needed, or this connection is now dead and no further pings can + // be sent. + break + } + } +} + +func connectToServer(settings *api.ChannelSettings, r *http.Request) (Connection, error) { + settings = settings.Clone() + + helper.SetForwardedFor(&settings.Header, r) + + conn, _, err := settings.Dial() + if err != nil { + return nil, err + } + + return Wrap(conn, conn.Subprotocol()), nil +} + +func closeAfterMaxTime(proxy *Proxy, maxSessionTime int) { + if maxSessionTime == 0 { + return + } + + <-time.After(time.Duration(maxSessionTime) * time.Second) + proxy.StopCh <- fmt.Errorf( + "connection closed: session time greater than maximum time allowed - %v seconds", + maxSessionTime, + ) +} diff --git a/workhorse/internal/channel/proxy.go b/workhorse/internal/channel/proxy.go new file mode 100644 index 00000000000..71f58092276 --- /dev/null +++ b/workhorse/internal/channel/proxy.go @@ -0,0 +1,56 @@ +package channel + +import ( + "fmt" + "net" + "time" + + "github.com/gorilla/websocket" +) + +// ANSI "end of channel" code +var eot = []byte{0x04} + +// An abstraction of gorilla's *websocket.Conn +type Connection interface { + UnderlyingConn() net.Conn + ReadMessage() (int, []byte, error) + WriteMessage(int, []byte) error + WriteControl(int, []byte, time.Time) error +} + +type Proxy struct { + StopCh chan error +} + +// stoppers is the number of goroutines that may attempt to call Stop() +func NewProxy(stoppers int) *Proxy { + return &Proxy{ + StopCh: make(chan error, stoppers+2), // each proxy() call is a stopper + } +} + +func (p *Proxy) Serve(upstream, downstream Connection, upstreamAddr, downstreamAddr string) error { + // This signals the upstream channel to kill the exec'd process + defer upstream.WriteMessage(websocket.BinaryMessage, eot) + + go p.proxy(upstream, downstream, upstreamAddr, downstreamAddr) + go p.proxy(downstream, upstream, downstreamAddr, upstreamAddr) + + return <-p.StopCh +} + +func (p *Proxy) proxy(to, from Connection, toAddr, fromAddr string) { + for { + messageType, data, err := from.ReadMessage() + if err != nil { + p.StopCh <- fmt.Errorf("reading from %s: %s", fromAddr, err) + break + } + + if err := to.WriteMessage(messageType, data); err != nil { + p.StopCh <- fmt.Errorf("writing to %s: %s", toAddr, err) + break + } + } +} diff --git a/workhorse/internal/channel/wrappers.go b/workhorse/internal/channel/wrappers.go new file mode 
100644 index 00000000000..6fd955bedc7 --- /dev/null +++ b/workhorse/internal/channel/wrappers.go @@ -0,0 +1,134 @@ +package channel + +import ( + "encoding/base64" + "net" + "time" + + "github.com/gorilla/websocket" +) + +func Wrap(conn Connection, subprotocol string) Connection { + switch subprotocol { + case "channel.k8s.io": + return &kubeWrapper{base64: false, conn: conn} + case "base64.channel.k8s.io": + return &kubeWrapper{base64: true, conn: conn} + case "terminal.gitlab.com": + return &gitlabWrapper{base64: false, conn: conn} + case "base64.terminal.gitlab.com": + return &gitlabWrapper{base64: true, conn: conn} + } + + return conn +} + +type kubeWrapper struct { + base64 bool + conn Connection +} + +type gitlabWrapper struct { + base64 bool + conn Connection +} + +func (w *gitlabWrapper) ReadMessage() (int, []byte, error) { + mt, data, err := w.conn.ReadMessage() + if err != nil { + return mt, data, err + } + + if isData(mt) { + mt = websocket.BinaryMessage + if w.base64 { + data, err = decodeBase64(data) + } + } + + return mt, data, err +} + +func (w *gitlabWrapper) WriteMessage(mt int, data []byte) error { + if isData(mt) { + if w.base64 { + mt = websocket.TextMessage + data = encodeBase64(data) + } else { + mt = websocket.BinaryMessage + } + } + + return w.conn.WriteMessage(mt, data) +} + +func (w *gitlabWrapper) WriteControl(mt int, data []byte, deadline time.Time) error { + return w.conn.WriteControl(mt, data, deadline) +} + +func (w *gitlabWrapper) UnderlyingConn() net.Conn { + return w.conn.UnderlyingConn() +} + +// Coalesces all wsstreams into a single stream. In practice, we should only +// receive data on stream 1. +func (w *kubeWrapper) ReadMessage() (int, []byte, error) { + mt, data, err := w.conn.ReadMessage() + if err != nil { + return mt, data, err + } + + if isData(mt) { + mt = websocket.BinaryMessage + + // Remove the WSStream channel number, decode to raw + if len(data) > 0 { + data = data[1:] + if w.base64 { + data, err = decodeBase64(data) + } + } + } + + return mt, data, err +} + +// Always sends to wsstream 0 +func (w *kubeWrapper) WriteMessage(mt int, data []byte) error { + if isData(mt) { + if w.base64 { + mt = websocket.TextMessage + data = append([]byte{'0'}, encodeBase64(data)...) + } else { + mt = websocket.BinaryMessage + data = append([]byte{0}, data...) 
+ } + } + + return w.conn.WriteMessage(mt, data) +} + +func (w *kubeWrapper) WriteControl(mt int, data []byte, deadline time.Time) error { + return w.conn.WriteControl(mt, data, deadline) +} + +func (w *kubeWrapper) UnderlyingConn() net.Conn { + return w.conn.UnderlyingConn() +} + +func isData(mt int) bool { + return mt == websocket.BinaryMessage || mt == websocket.TextMessage +} + +func encodeBase64(data []byte) []byte { + buf := make([]byte, base64.StdEncoding.EncodedLen(len(data))) + base64.StdEncoding.Encode(buf, data) + + return buf +} + +func decodeBase64(data []byte) ([]byte, error) { + buf := make([]byte, base64.StdEncoding.DecodedLen(len(data))) + n, err := base64.StdEncoding.Decode(buf, data) + return buf[:n], err +} diff --git a/workhorse/internal/channel/wrappers_test.go b/workhorse/internal/channel/wrappers_test.go new file mode 100644 index 00000000000..1e0226f85d8 --- /dev/null +++ b/workhorse/internal/channel/wrappers_test.go @@ -0,0 +1,155 @@ +package channel + +import ( + "bytes" + "errors" + "net" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +type testcase struct { + input *fakeConn + expected *fakeConn +} + +type fakeConn struct { + // WebSocket message type + mt int + data []byte + err error +} + +func (f *fakeConn) ReadMessage() (int, []byte, error) { + return f.mt, f.data, f.err +} + +func (f *fakeConn) WriteMessage(mt int, data []byte) error { + f.mt = mt + f.data = data + return f.err +} + +func (f *fakeConn) WriteControl(mt int, data []byte, _ time.Time) error { + f.mt = mt + f.data = data + return f.err +} + +func (f *fakeConn) UnderlyingConn() net.Conn { + return nil +} + +func fake(mt int, data []byte, err error) *fakeConn { + return &fakeConn{mt: mt, data: []byte(data), err: err} +} + +var ( + msg = []byte("foo bar") + msgBase64 = []byte("Zm9vIGJhcg==") + kubeMsg = append([]byte{0}, msg...) + kubeMsgBase64 = append([]byte{'0'}, msgBase64...) + + errFake = errors.New("fake error") + + text = websocket.TextMessage + binary = websocket.BinaryMessage + other = 999 + + fakeOther = fake(other, []byte("foo"), nil) +) + +func requireEqualConn(t *testing.T, expected, actual *fakeConn, msg string, args ...interface{}) { + if expected.mt != actual.mt { + t.Logf("messageType expected to be %v but was %v", expected.mt, actual.mt) + t.Fatalf(msg, args...) + } + + if !bytes.Equal(expected.data, actual.data) { + t.Logf("data expected to be %q but was %q: ", expected.data, actual.data) + t.Fatalf(msg, args...) + } + + if expected.err != actual.err { + t.Logf("error expected to be %v but was %v", expected.err, actual.err) + t.Fatalf(msg, args...) 
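// Worked example of the framing rules above (a sketch in this package; assumes
// an extra "fmt" import): sending binary "foo bar" through the
// base64.channel.k8s.io wrapper yields a text frame addressed to wsstream 0
// carrying a base64 payload.
func ExampleWrapBase64Kube() {
	sink := fake(0, nil, nil)
	conn := Wrap(sink, "base64.channel.k8s.io")
	_ = conn.WriteMessage(websocket.BinaryMessage, []byte("foo bar"))
	fmt.Printf("%d %s\n", sink.mt, sink.data)
	// Output: 1 0Zm9vIGJhcg==
}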
+ } +} + +func TestReadMessage(t *testing.T) { + testCases := map[string][]testcase{ + "channel.k8s.io": { + {fake(binary, kubeMsg, errFake), fake(binary, kubeMsg, errFake)}, + {fake(binary, kubeMsg, nil), fake(binary, msg, nil)}, + {fake(text, kubeMsg, nil), fake(binary, msg, nil)}, + {fakeOther, fakeOther}, + }, + "base64.channel.k8s.io": { + {fake(text, kubeMsgBase64, errFake), fake(text, kubeMsgBase64, errFake)}, + {fake(text, kubeMsgBase64, nil), fake(binary, msg, nil)}, + {fake(binary, kubeMsgBase64, nil), fake(binary, msg, nil)}, + {fakeOther, fakeOther}, + }, + "terminal.gitlab.com": { + {fake(binary, msg, errFake), fake(binary, msg, errFake)}, + {fake(binary, msg, nil), fake(binary, msg, nil)}, + {fake(text, msg, nil), fake(binary, msg, nil)}, + {fakeOther, fakeOther}, + }, + "base64.terminal.gitlab.com": { + {fake(text, msgBase64, errFake), fake(text, msgBase64, errFake)}, + {fake(text, msgBase64, nil), fake(binary, msg, nil)}, + {fake(binary, msgBase64, nil), fake(binary, msg, nil)}, + {fakeOther, fakeOther}, + }, + } + + for subprotocol, cases := range testCases { + for i, tc := range cases { + conn := Wrap(tc.input, subprotocol) + mt, data, err := conn.ReadMessage() + actual := fake(mt, data, err) + requireEqualConn(t, tc.expected, actual, "%s test case %v", subprotocol, i) + } + } +} + +func TestWriteMessage(t *testing.T) { + testCases := map[string][]testcase{ + "channel.k8s.io": { + {fake(binary, msg, errFake), fake(binary, kubeMsg, errFake)}, + {fake(binary, msg, nil), fake(binary, kubeMsg, nil)}, + {fake(text, msg, nil), fake(binary, kubeMsg, nil)}, + {fakeOther, fakeOther}, + }, + "base64.channel.k8s.io": { + {fake(binary, msg, errFake), fake(text, kubeMsgBase64, errFake)}, + {fake(binary, msg, nil), fake(text, kubeMsgBase64, nil)}, + {fake(text, msg, nil), fake(text, kubeMsgBase64, nil)}, + {fakeOther, fakeOther}, + }, + "terminal.gitlab.com": { + {fake(binary, msg, errFake), fake(binary, msg, errFake)}, + {fake(binary, msg, nil), fake(binary, msg, nil)}, + {fake(text, msg, nil), fake(binary, msg, nil)}, + {fakeOther, fakeOther}, + }, + "base64.terminal.gitlab.com": { + {fake(binary, msg, errFake), fake(text, msgBase64, errFake)}, + {fake(binary, msg, nil), fake(text, msgBase64, nil)}, + {fake(text, msg, nil), fake(text, msgBase64, nil)}, + {fakeOther, fakeOther}, + }, + } + + for subprotocol, cases := range testCases { + for i, tc := range cases { + actual := fake(0, nil, tc.input.err) + conn := Wrap(actual, subprotocol) + actual.err = conn.WriteMessage(tc.input.mt, tc.input.data) + requireEqualConn(t, tc.expected, actual, "%s test case %v", subprotocol, i) + } + } +} diff --git a/workhorse/internal/config/config.go b/workhorse/internal/config/config.go new file mode 100644 index 00000000000..84849c72744 --- /dev/null +++ b/workhorse/internal/config/config.go @@ -0,0 +1,154 @@ +package config + +import ( + "math" + "net/url" + "runtime" + "strings" + "time" + + "github.com/Azure/azure-storage-blob-go/azblob" + "github.com/BurntSushi/toml" + "gitlab.com/gitlab-org/labkit/log" + "gocloud.dev/blob" + "gocloud.dev/blob/azureblob" +) + +type TomlURL struct { + url.URL +} + +func (u *TomlURL) UnmarshalText(text []byte) error { + temp, err := url.Parse(string(text)) + u.URL = *temp + return err +} + +type TomlDuration struct { + time.Duration +} + +func (d *TomlDuration) UnmarshalTest(text []byte) error { + temp, err := time.ParseDuration(string(text)) + d.Duration = temp + return err +} + +type ObjectStorageCredentials struct { + Provider string + + S3Credentials 
S3Credentials `toml:"s3"` + AzureCredentials AzureCredentials `toml:"azurerm"` +} + +type ObjectStorageConfig struct { + URLMux *blob.URLMux `toml:"-"` +} + +type S3Credentials struct { + AwsAccessKeyID string `toml:"aws_access_key_id"` + AwsSecretAccessKey string `toml:"aws_secret_access_key"` +} + +type S3Config struct { + Region string `toml:"-"` + Bucket string `toml:"-"` + PathStyle bool `toml:"-"` + Endpoint string `toml:"-"` + UseIamProfile bool `toml:"-"` + ServerSideEncryption string `toml:"-"` // Server-side encryption mode (e.g. AES256, aws:kms) + SSEKMSKeyID string `toml:"-"` // Server-side encryption key-management service key ID (e.g. arn:aws:xxx) +} + +type GoCloudConfig struct { + URL string `toml:"-"` +} + +type AzureCredentials struct { + AccountName string `toml:"azure_storage_account_name"` + AccountKey string `toml:"azure_storage_access_key"` +} + +type RedisConfig struct { + URL TomlURL + Sentinel []TomlURL + SentinelMaster string + Password string + DB *int + ReadTimeout *TomlDuration + WriteTimeout *TomlDuration + KeepAlivePeriod *TomlDuration + MaxIdle *int + MaxActive *int +} + +type ImageResizerConfig struct { + MaxScalerProcs uint32 `toml:"max_scaler_procs"` + MaxFilesize uint64 `toml:"max_filesize"` +} + +type Config struct { + Redis *RedisConfig `toml:"redis"` + Backend *url.URL `toml:"-"` + CableBackend *url.URL `toml:"-"` + Version string `toml:"-"` + DocumentRoot string `toml:"-"` + DevelopmentMode bool `toml:"-"` + Socket string `toml:"-"` + CableSocket string `toml:"-"` + ProxyHeadersTimeout time.Duration `toml:"-"` + APILimit uint `toml:"-"` + APIQueueLimit uint `toml:"-"` + APIQueueTimeout time.Duration `toml:"-"` + APICILongPollingDuration time.Duration `toml:"-"` + ObjectStorageConfig ObjectStorageConfig `toml:"-"` + ObjectStorageCredentials ObjectStorageCredentials `toml:"object_storage"` + PropagateCorrelationID bool `toml:"-"` + ImageResizerConfig ImageResizerConfig `toml:"image_resizer"` + AltDocumentRoot string `toml:"alt_document_root"` +} + +var DefaultImageResizerConfig = ImageResizerConfig{ + MaxScalerProcs: uint32(math.Max(2, float64(runtime.NumCPU())/2)), + MaxFilesize: 250 * 1000, // 250kB, +} + +func LoadConfig(data string) (*Config, error) { + cfg := &Config{ImageResizerConfig: DefaultImageResizerConfig} + + if _, err := toml.Decode(data, cfg); err != nil { + return nil, err + } + + return cfg, nil +} + +func (c *Config) RegisterGoCloudURLOpeners() error { + c.ObjectStorageConfig.URLMux = new(blob.URLMux) + + creds := c.ObjectStorageCredentials + if strings.EqualFold(creds.Provider, "AzureRM") && creds.AzureCredentials.AccountName != "" && creds.AzureCredentials.AccountKey != "" { + accountName := azureblob.AccountName(creds.AzureCredentials.AccountName) + accountKey := azureblob.AccountKey(creds.AzureCredentials.AccountKey) + + credential, err := azureblob.NewCredential(accountName, accountKey) + if err != nil { + log.WithError(err).Error("error creating Azure credentials") + return err + } + + pipeline := azureblob.NewPipeline(credential, azblob.PipelineOptions{}) + + azureURLOpener := &azureURLOpener{ + &azureblob.URLOpener{ + AccountName: accountName, + Pipeline: pipeline, + Options: azureblob.Options{Credential: credential}, + }, + } + + c.ObjectStorageConfig.URLMux.RegisterBucket(azureblob.Scheme, azureURLOpener) + } + + return nil +} diff --git a/workhorse/internal/config/config_test.go b/workhorse/internal/config/config_test.go new file mode 100644 index 00000000000..102b29a0813 --- /dev/null +++ 
b/workhorse/internal/config/config_test.go @@ -0,0 +1,111 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +const azureConfig = ` +[object_storage] +provider = "AzureRM" + +[object_storage.azurerm] +azure_storage_account_name = "azuretester" +azure_storage_access_key = "deadbeef" +` + +func TestLoadEmptyConfig(t *testing.T) { + config := `` + + cfg, err := LoadConfig(config) + require.NoError(t, err) + + require.Empty(t, cfg.AltDocumentRoot) + require.Equal(t, cfg.ImageResizerConfig.MaxFilesize, uint64(250000)) + require.GreaterOrEqual(t, cfg.ImageResizerConfig.MaxScalerProcs, uint32(2)) + + require.Equal(t, ObjectStorageCredentials{}, cfg.ObjectStorageCredentials) + require.NoError(t, cfg.RegisterGoCloudURLOpeners()) +} + +func TestLoadObjectStorageConfig(t *testing.T) { + config := ` +[object_storage] +provider = "AWS" + +[object_storage.s3] +aws_access_key_id = "minio" +aws_secret_access_key = "gdk-minio" +` + + cfg, err := LoadConfig(config) + require.NoError(t, err) + + require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials") + + expected := ObjectStorageCredentials{ + Provider: "AWS", + S3Credentials: S3Credentials{ + AwsAccessKeyID: "minio", + AwsSecretAccessKey: "gdk-minio", + }, + } + + require.Equal(t, expected, cfg.ObjectStorageCredentials) +} + +func TestRegisterGoCloudURLOpeners(t *testing.T) { + cfg, err := LoadConfig(azureConfig) + require.NoError(t, err) + + require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials") + + expected := ObjectStorageCredentials{ + Provider: "AzureRM", + AzureCredentials: AzureCredentials{ + AccountName: "azuretester", + AccountKey: "deadbeef", + }, + } + + require.Equal(t, expected, cfg.ObjectStorageCredentials) + require.Nil(t, cfg.ObjectStorageConfig.URLMux) + + require.NoError(t, cfg.RegisterGoCloudURLOpeners()) + require.NotNil(t, cfg.ObjectStorageConfig.URLMux) + + require.True(t, cfg.ObjectStorageConfig.URLMux.ValidBucketScheme("azblob")) + require.Equal(t, []string{"azblob"}, cfg.ObjectStorageConfig.URLMux.BucketSchemes()) +} + +func TestLoadImageResizerConfig(t *testing.T) { + config := ` +[image_resizer] +max_scaler_procs = 200 +max_filesize = 350000 +` + + cfg, err := LoadConfig(config) + require.NoError(t, err) + + require.NotNil(t, cfg.ImageResizerConfig, "Expected image resizer config") + + expected := ImageResizerConfig{ + MaxScalerProcs: 200, + MaxFilesize: 350000, + } + + require.Equal(t, expected, cfg.ImageResizerConfig) +} + +func TestAltDocumentConfig(t *testing.T) { + config := ` +alt_document_root = "/path/to/documents" +` + + cfg, err := LoadConfig(config) + require.NoError(t, err) + + require.Equal(t, "/path/to/documents", cfg.AltDocumentRoot) +} diff --git a/workhorse/internal/config/url_openers.go b/workhorse/internal/config/url_openers.go new file mode 100644 index 00000000000..d3c96ee9eef --- /dev/null +++ b/workhorse/internal/config/url_openers.go @@ -0,0 +1,51 @@ +package config + +import ( + "context" + "fmt" + "net/url" + + "gocloud.dev/blob" + "gocloud.dev/blob/azureblob" +) + +// This code can be removed once https://github.com/google/go-cloud/pull/2851 is merged. + +// URLOpener opens Azure URLs like "azblob://mybucket". +// +// The URL host is used as the bucket name. +// +// The following query options are supported: +// - domain: The domain name used to access the Azure Blob storage (e.g. 
blob.core.windows.net) +type azureURLOpener struct { + *azureblob.URLOpener +} + +func (o *azureURLOpener) OpenBucketURL(ctx context.Context, u *url.URL) (*blob.Bucket, error) { + opts := new(azureblob.Options) + *opts = o.Options + + err := setOptionsFromURLParams(u.Query(), opts) + if err != nil { + return nil, err + } + return azureblob.OpenBucket(ctx, o.Pipeline, o.AccountName, u.Host, opts) +} + +func setOptionsFromURLParams(q url.Values, opts *azureblob.Options) error { + for param, values := range q { + if len(values) > 1 { + return fmt.Errorf("multiple values of %v not allowed", param) + } + + value := values[0] + switch param { + case "domain": + opts.StorageDomain = azureblob.StorageDomain(value) + default: + return fmt.Errorf("unknown query parameter %q", param) + } + } + + return nil +} diff --git a/workhorse/internal/config/url_openers_test.go b/workhorse/internal/config/url_openers_test.go new file mode 100644 index 00000000000..6a851cacbb8 --- /dev/null +++ b/workhorse/internal/config/url_openers_test.go @@ -0,0 +1,117 @@ +package config + +import ( + "context" + "net/url" + "testing" + + "github.com/stretchr/testify/require" + "gocloud.dev/blob/azureblob" +) + +func TestURLOpeners(t *testing.T) { + cfg, err := LoadConfig(azureConfig) + require.NoError(t, err) + + require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials") + + require.NoError(t, cfg.RegisterGoCloudURLOpeners()) + require.NotNil(t, cfg.ObjectStorageConfig.URLMux) + + tests := []struct { + url string + valid bool + }{ + + { + url: "azblob://container/object", + valid: true, + }, + { + url: "azblob://container/object?domain=core.windows.net", + valid: true, + }, + { + url: "azblob://container/object?domain=core.windows.net&domain=test", + valid: false, + }, + { + url: "azblob://container/object?param=value", + valid: false, + }, + { + url: "s3://bucket/object", + valid: false, + }, + } + + for _, test := range tests { + t.Run(test.url, func(t *testing.T) { + ctx := context.Background() + url, err := url.Parse(test.url) + require.NoError(t, err) + + bucket, err := cfg.ObjectStorageConfig.URLMux.OpenBucketURL(ctx, url) + if bucket != nil { + defer bucket.Close() + } + + if test.valid { + require.NotNil(t, bucket) + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} + +func TestTestURLOpenersForParams(t *testing.T) { + tests := []struct { + name string + currOpts azureblob.Options + query url.Values + wantOpts azureblob.Options + wantErr bool + }{ + { + name: "InvalidParam", + query: url.Values{ + "foo": {"bar"}, + }, + wantErr: true, + }, + { + name: "StorageDomain", + query: url.Values{ + "domain": {"blob.core.usgovcloudapi.net"}, + }, + wantOpts: azureblob.Options{StorageDomain: "blob.core.usgovcloudapi.net"}, + }, + { + name: "duplicate StorageDomain", + query: url.Values{ + "domain": {"blob.core.usgovcloudapi.net", "blob.core.windows.net"}, + }, + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + o := &azureURLOpener{ + URLOpener: &azureblob.URLOpener{ + Options: test.currOpts, + }, + } + err := setOptionsFromURLParams(test.query, &o.Options) + + if test.wantErr { + require.NotNil(t, err) + } else { + require.Nil(t, err) + require.Equal(t, test.wantOpts, o.Options) + } + }) + } +} diff --git a/workhorse/internal/filestore/file_handler.go b/workhorse/internal/filestore/file_handler.go new file mode 100644 index 00000000000..19764e9a5cf --- /dev/null +++ b/workhorse/internal/filestore/file_handler.go @@ -0,0 
+1,257 @@ +package filestore + +import ( + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "strconv" + "time" + + "github.com/dgrijalva/jwt-go" + + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" +) + +type SizeError error + +// ErrEntityTooLarge means that the uploaded content is bigger then maximum allowed size +var ErrEntityTooLarge = errors.New("entity is too large") + +// FileHandler represent a file that has been processed for upload +// it may be either uploaded to an ObjectStore and/or saved on local path. +type FileHandler struct { + // LocalPath is the path on the disk where file has been stored + LocalPath string + + // RemoteID is the objectID provided by GitLab Rails + RemoteID string + // RemoteURL is ObjectStore URL provided by GitLab Rails + RemoteURL string + + // Size is the persisted file size + Size int64 + + // Name is the resource name to send back to GitLab rails. + // It differ from the real file name in order to avoid file collisions + Name string + + // a map containing different hashes + hashes map[string]string +} + +type uploadClaims struct { + Upload map[string]string `json:"upload"` + jwt.StandardClaims +} + +// SHA256 hash of the handled file +func (fh *FileHandler) SHA256() string { + return fh.hashes["sha256"] +} + +// MD5 hash of the handled file +func (fh *FileHandler) MD5() string { + return fh.hashes["md5"] +} + +// GitLabFinalizeFields returns a map with all the fields GitLab Rails needs in order to finalize the upload. +func (fh *FileHandler) GitLabFinalizeFields(prefix string) (map[string]string, error) { + // TODO: remove `data` these once rails fully and exclusively support `signedData` (https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/263) + data := make(map[string]string) + signedData := make(map[string]string) + key := func(field string) string { + if prefix == "" { + return field + } + + return fmt.Sprintf("%s.%s", prefix, field) + } + + for k, v := range map[string]string{ + "name": fh.Name, + "path": fh.LocalPath, + "remote_url": fh.RemoteURL, + "remote_id": fh.RemoteID, + "size": strconv.FormatInt(fh.Size, 10), + } { + data[key(k)] = v + signedData[k] = v + } + + for hashName, hash := range fh.hashes { + data[key(hashName)] = hash + signedData[hashName] = hash + } + + claims := uploadClaims{Upload: signedData, StandardClaims: secret.DefaultClaims} + jwtData, err := secret.JWTTokenString(claims) + if err != nil { + return nil, err + } + data[key("gitlab-workhorse-upload")] = jwtData + + return data, nil +} + +type consumer interface { + Consume(context.Context, io.Reader, time.Time) (int64, error) +} + +// SaveFileFromReader persists the provided reader content to all the location specified in opts. A cleanup will be performed once ctx is Done +// Make sure the provided context will not expire before finalizing upload with GitLab Rails. 
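// A caller-side sketch (illustrative: the temp directory, prefix and payload are
// assumptions, and the strings/log imports are implied), saving a small payload
// to a local temp path and reading back the computed size and SHA256. Cancelling
// the context afterwards triggers the asynchronous cleanup of the temp file.
//
//	ctx, cancel := context.WithCancel(context.Background())
//	defer cancel()
//
//	content := "hello world"
//	opts := &SaveFileOpts{LocalTempPath: "/tmp/uploads", TempFilePrefix: "upload-"}
//
//	fh, err := SaveFileFromReader(ctx, strings.NewReader(content), int64(len(content)), opts)
//	if err != nil {
//		return err
//	}
//	log.Printf("stored %s (%d bytes, sha256 %s)", fh.LocalPath, fh.Size, fh.SHA256())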
+func SaveFileFromReader(ctx context.Context, reader io.Reader, size int64, opts *SaveFileOpts) (fh *FileHandler, err error) { + var uploadDestination consumer + fh = &FileHandler{ + Name: opts.TempFilePrefix, + RemoteID: opts.RemoteID, + RemoteURL: opts.RemoteURL, + } + hashes := newMultiHash() + reader = io.TeeReader(reader, hashes.Writer) + + var clientMode string + + switch { + case opts.IsLocal(): + clientMode = "local" + uploadDestination, err = fh.uploadLocalFile(ctx, opts) + case opts.UseWorkhorseClientEnabled() && opts.ObjectStorageConfig.IsGoCloud(): + clientMode = fmt.Sprintf("go_cloud:%s", opts.ObjectStorageConfig.Provider) + p := &objectstore.GoCloudObjectParams{ + Ctx: ctx, + Mux: opts.ObjectStorageConfig.URLMux, + BucketURL: opts.ObjectStorageConfig.GoCloudConfig.URL, + ObjectName: opts.RemoteTempObjectID, + } + uploadDestination, err = objectstore.NewGoCloudObject(p) + case opts.UseWorkhorseClientEnabled() && opts.ObjectStorageConfig.IsAWS() && opts.ObjectStorageConfig.IsValid(): + clientMode = "s3" + uploadDestination, err = objectstore.NewS3Object( + opts.RemoteTempObjectID, + opts.ObjectStorageConfig.S3Credentials, + opts.ObjectStorageConfig.S3Config, + ) + case opts.IsMultipart(): + clientMode = "multipart" + uploadDestination, err = objectstore.NewMultipart( + opts.PresignedParts, + opts.PresignedCompleteMultipart, + opts.PresignedAbortMultipart, + opts.PresignedDelete, + opts.PutHeaders, + opts.PartSize, + ) + default: + clientMode = "http" + uploadDestination, err = objectstore.NewObject( + opts.PresignedPut, + opts.PresignedDelete, + opts.PutHeaders, + size, + ) + } + + if err != nil { + return nil, err + } + + if opts.MaximumSize > 0 { + if size > opts.MaximumSize { + return nil, SizeError(fmt.Errorf("the upload size %d is over maximum of %d bytes", size, opts.MaximumSize)) + } + + hlr := &hardLimitReader{r: reader, n: opts.MaximumSize} + reader = hlr + defer func() { + if hlr.n < 0 { + err = ErrEntityTooLarge + } + }() + } + + fh.Size, err = uploadDestination.Consume(ctx, reader, opts.Deadline) + if err != nil { + if err == objectstore.ErrNotEnoughParts { + err = ErrEntityTooLarge + } + return nil, err + } + + if size != -1 && size != fh.Size { + return nil, SizeError(fmt.Errorf("expected %d bytes but got only %d", size, fh.Size)) + } + + logger := log.WithContextFields(ctx, log.Fields{ + "copied_bytes": fh.Size, + "is_local": opts.IsLocal(), + "is_multipart": opts.IsMultipart(), + "is_remote": !opts.IsLocal(), + "remote_id": opts.RemoteID, + "temp_file_prefix": opts.TempFilePrefix, + "client_mode": clientMode, + }) + + if opts.IsLocal() { + logger = logger.WithField("local_temp_path", opts.LocalTempPath) + } else { + logger = logger.WithField("remote_temp_object", opts.RemoteTempObjectID) + } + + logger.Info("saved file") + fh.hashes = hashes.finish() + return fh, nil +} + +func (fh *FileHandler) uploadLocalFile(ctx context.Context, opts *SaveFileOpts) (consumer, error) { + // make sure TempFolder exists + err := os.MkdirAll(opts.LocalTempPath, 0700) + if err != nil { + return nil, fmt.Errorf("uploadLocalFile: mkdir %q: %v", opts.LocalTempPath, err) + } + + file, err := ioutil.TempFile(opts.LocalTempPath, opts.TempFilePrefix) + if err != nil { + return nil, fmt.Errorf("uploadLocalFile: create file: %v", err) + } + + go func() { + <-ctx.Done() + os.Remove(file.Name()) + }() + + fh.LocalPath = file.Name() + return &localUpload{file}, nil +} + +type localUpload struct{ io.WriteCloser } + +func (loc *localUpload) Consume(_ context.Context, r io.Reader, _ time.Time) 
(int64, error) { + n, err := io.Copy(loc.WriteCloser, r) + errClose := loc.Close() + if err == nil { + err = errClose + } + return n, err +} + +// SaveFileFromDisk open the local file fileName and calls SaveFileFromReader +func SaveFileFromDisk(ctx context.Context, fileName string, opts *SaveFileOpts) (fh *FileHandler, err error) { + file, err := os.Open(fileName) + if err != nil { + return nil, err + } + defer file.Close() + + fi, err := file.Stat() + if err != nil { + return nil, err + } + + return SaveFileFromReader(ctx, file, fi.Size(), opts) +} diff --git a/workhorse/internal/filestore/file_handler_test.go b/workhorse/internal/filestore/file_handler_test.go new file mode 100644 index 00000000000..e79e9d0f292 --- /dev/null +++ b/workhorse/internal/filestore/file_handler_test.go @@ -0,0 +1,551 @@ +package filestore_test + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "os" + "path" + "strconv" + "strings" + "testing" + "time" + + "github.com/dgrijalva/jwt-go" + "github.com/stretchr/testify/require" + "gocloud.dev/blob" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" +) + +func testDeadline() time.Time { + return time.Now().Add(filestore.DefaultObjectStoreTimeout) +} + +func requireFileGetsRemovedAsync(t *testing.T, filePath string) { + var err error + + // Poll because the file removal is async + for i := 0; i < 100; i++ { + _, err = os.Stat(filePath) + if err != nil { + break + } + time.Sleep(100 * time.Millisecond) + } + + require.True(t, os.IsNotExist(err), "File hasn't been deleted during cleanup") +} + +func requireObjectStoreDeletedAsync(t *testing.T, expectedDeletes int, osStub *test.ObjectstoreStub) { + // Poll because the object removal is async + for i := 0; i < 100; i++ { + if osStub.DeletesCnt() == expectedDeletes { + break + } + time.Sleep(10 * time.Millisecond) + } + + require.Equal(t, expectedDeletes, osStub.DeletesCnt(), "Object not deleted") +} + +func TestSaveFileWrongSize(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpFolder) + + opts := &filestore.SaveFileOpts{LocalTempPath: tmpFolder, TempFilePrefix: "test-file"} + fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize+1, opts) + require.Error(t, err) + _, isSizeError := err.(filestore.SizeError) + require.True(t, isSizeError, "Should fail with SizeError") + require.Nil(t, fh) +} + +func TestSaveFileWithKnownSizeExceedLimit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpFolder) + + opts := &filestore.SaveFileOpts{LocalTempPath: tmpFolder, TempFilePrefix: "test-file", MaximumSize: test.ObjectSize - 1} + fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize, opts) + require.Error(t, err) + _, isSizeError := err.(filestore.SizeError) + require.True(t, isSizeError, "Should fail with SizeError") + require.Nil(t, fh) +} + +func TestSaveFileWithUnknownSizeExceedLimit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tmpFolder, err := ioutil.TempDir("", 
"workhorse-test-tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpFolder) + + opts := &filestore.SaveFileOpts{LocalTempPath: tmpFolder, TempFilePrefix: "test-file", MaximumSize: test.ObjectSize - 1} + fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), -1, opts) + require.Equal(t, err, filestore.ErrEntityTooLarge) + require.Nil(t, fh) +} + +func TestSaveFromDiskNotExistingFile(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + fh, err := filestore.SaveFileFromDisk(ctx, "/I/do/not/exist", &filestore.SaveFileOpts{}) + require.Error(t, err, "SaveFileFromDisk should fail") + require.True(t, os.IsNotExist(err), "Provided file should not exists") + require.Nil(t, fh, "On error FileHandler should be nil") +} + +func TestSaveFileWrongETag(t *testing.T) { + tests := []struct { + name string + multipart bool + }{ + {name: "single part"}, + {name: "multi part", multipart: true}, + } + + for _, spec := range tests { + t.Run(spec.name, func(t *testing.T) { + osStub, ts := test.StartObjectStoreWithCustomMD5(map[string]string{test.ObjectPath: "brokenMD5"}) + defer ts.Close() + + objectURL := ts.URL + test.ObjectPath + + opts := &filestore.SaveFileOpts{ + RemoteID: "test-file", + RemoteURL: objectURL, + PresignedPut: objectURL + "?Signature=ASignature", + PresignedDelete: objectURL + "?Signature=AnotherSignature", + Deadline: testDeadline(), + } + if spec.multipart { + opts.PresignedParts = []string{objectURL + "?partNumber=1"} + opts.PresignedCompleteMultipart = objectURL + "?Signature=CompleteSig" + opts.PresignedAbortMultipart = objectURL + "?Signature=AbortSig" + opts.PartSize = test.ObjectSize + + osStub.InitiateMultipartUpload(test.ObjectPath) + } + ctx, cancel := context.WithCancel(context.Background()) + fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize, opts) + require.Nil(t, fh) + require.Error(t, err) + require.Equal(t, 1, osStub.PutsCnt(), "File not uploaded") + + cancel() // this will trigger an async cleanup + requireObjectStoreDeletedAsync(t, 1, osStub) + require.False(t, spec.multipart && osStub.IsMultipartUpload(test.ObjectPath), "there must be no multipart upload in progress now") + }) + } +} + +func TestSaveFileFromDiskToLocalPath(t *testing.T) { + f, err := ioutil.TempFile("", "workhorse-test") + require.NoError(t, err) + defer os.Remove(f.Name()) + + _, err = fmt.Fprint(f, test.ObjectContent) + require.NoError(t, err) + + tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpFolder) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + opts := &filestore.SaveFileOpts{LocalTempPath: tmpFolder} + fh, err := filestore.SaveFileFromDisk(ctx, f.Name(), opts) + require.NoError(t, err) + require.NotNil(t, fh) + + require.NotEmpty(t, fh.LocalPath, "File not persisted on disk") + _, err = os.Stat(fh.LocalPath) + require.NoError(t, err) +} + +func TestSaveFile(t *testing.T) { + testhelper.ConfigureSecret() + + type remote int + const ( + notRemote remote = iota + remoteSingle + remoteMultipart + ) + + tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpFolder) + + tests := []struct { + name string + local bool + remote remote + }{ + {name: "Local only", local: true}, + {name: "Remote Single only", remote: remoteSingle}, + {name: "Remote Multipart only", remote: remoteMultipart}, + } + + for _, spec := range tests { + 
t.Run(spec.name, func(t *testing.T) { + var opts filestore.SaveFileOpts + var expectedDeletes, expectedPuts int + + osStub, ts := test.StartObjectStore() + defer ts.Close() + + switch spec.remote { + case remoteSingle: + objectURL := ts.URL + test.ObjectPath + + opts.RemoteID = "test-file" + opts.RemoteURL = objectURL + opts.PresignedPut = objectURL + "?Signature=ASignature" + opts.PresignedDelete = objectURL + "?Signature=AnotherSignature" + opts.Deadline = testDeadline() + + expectedDeletes = 1 + expectedPuts = 1 + case remoteMultipart: + objectURL := ts.URL + test.ObjectPath + + opts.RemoteID = "test-file" + opts.RemoteURL = objectURL + opts.PresignedDelete = objectURL + "?Signature=AnotherSignature" + opts.PartSize = int64(len(test.ObjectContent)/2) + 1 + opts.PresignedParts = []string{objectURL + "?partNumber=1", objectURL + "?partNumber=2"} + opts.PresignedCompleteMultipart = objectURL + "?Signature=CompleteSignature" + opts.Deadline = testDeadline() + + osStub.InitiateMultipartUpload(test.ObjectPath) + expectedDeletes = 1 + expectedPuts = 2 + } + + if spec.local { + opts.LocalTempPath = tmpFolder + opts.TempFilePrefix = "test-file" + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize, &opts) + require.NoError(t, err) + require.NotNil(t, fh) + + require.Equal(t, opts.RemoteID, fh.RemoteID) + require.Equal(t, opts.RemoteURL, fh.RemoteURL) + + if spec.local { + require.NotEmpty(t, fh.LocalPath, "File not persisted on disk") + _, err := os.Stat(fh.LocalPath) + require.NoError(t, err) + + dir := path.Dir(fh.LocalPath) + require.Equal(t, opts.LocalTempPath, dir) + filename := path.Base(fh.LocalPath) + beginsWithPrefix := strings.HasPrefix(filename, opts.TempFilePrefix) + require.True(t, beginsWithPrefix, fmt.Sprintf("LocalPath filename %q do not begin with TempFilePrefix %q", filename, opts.TempFilePrefix)) + } else { + require.Empty(t, fh.LocalPath, "LocalPath must be empty for non local uploads") + } + + require.Equal(t, test.ObjectSize, fh.Size) + require.Equal(t, test.ObjectMD5, fh.MD5()) + require.Equal(t, test.ObjectSHA256, fh.SHA256()) + + require.Equal(t, expectedPuts, osStub.PutsCnt(), "ObjectStore PutObject count mismatch") + require.Equal(t, 0, osStub.DeletesCnt(), "File deleted too early") + + cancel() // this will trigger an async cleanup + requireObjectStoreDeletedAsync(t, expectedDeletes, osStub) + requireFileGetsRemovedAsync(t, fh.LocalPath) + + // checking generated fields + fields, err := fh.GitLabFinalizeFields("file") + require.NoError(t, err) + + checkFileHandlerWithFields(t, fh, fields, "file") + + token, jwtErr := jwt.ParseWithClaims(fields["file.gitlab-workhorse-upload"], &testhelper.UploadClaims{}, testhelper.ParseJWT) + require.NoError(t, jwtErr) + + uploadFields := token.Claims.(*testhelper.UploadClaims).Upload + + checkFileHandlerWithFields(t, fh, uploadFields, "") + }) + } +} + +func TestSaveFileWithS3WorkhorseClient(t *testing.T) { + tests := []struct { + name string + objectSize int64 + maxSize int64 + expectedErr error + }{ + { + name: "known size with no limit", + objectSize: test.ObjectSize, + }, + { + name: "unknown size with no limit", + objectSize: -1, + }, + { + name: "unknown object size with limit", + objectSize: -1, + maxSize: test.ObjectSize - 1, + expectedErr: filestore.ErrEntityTooLarge, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + + s3Creds, s3Config, sess, ts := test.SetupS3(t, 
"") + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + remoteObject := "tmp/test-file/1" + opts := filestore.SaveFileOpts{ + RemoteID: "test-file", + Deadline: testDeadline(), + UseWorkhorseClient: true, + RemoteTempObjectID: remoteObject, + ObjectStorageConfig: filestore.ObjectStorageConfig{ + Provider: "AWS", + S3Credentials: s3Creds, + S3Config: s3Config, + }, + MaximumSize: tc.maxSize, + } + + _, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), tc.objectSize, &opts) + + if tc.expectedErr == nil { + require.NoError(t, err) + test.S3ObjectExists(t, sess, s3Config, remoteObject, test.ObjectContent) + } else { + require.Equal(t, tc.expectedErr, err) + test.S3ObjectDoesNotExist(t, sess, s3Config, remoteObject) + } + }) + } +} + +func TestSaveFileWithAzureWorkhorseClient(t *testing.T) { + mux, bucketDir, cleanup := test.SetupGoCloudFileBucket(t, "azblob") + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + remoteObject := "tmp/test-file/1" + opts := filestore.SaveFileOpts{ + RemoteID: "test-file", + Deadline: testDeadline(), + UseWorkhorseClient: true, + RemoteTempObjectID: remoteObject, + ObjectStorageConfig: filestore.ObjectStorageConfig{ + Provider: "AzureRM", + URLMux: mux, + GoCloudConfig: config.GoCloudConfig{URL: "azblob://test-container"}, + }, + } + + _, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize, &opts) + require.NoError(t, err) + + test.GoCloudObjectExists(t, bucketDir, remoteObject) +} + +func TestSaveFileWithUnknownGoCloudScheme(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mux := new(blob.URLMux) + + remoteObject := "tmp/test-file/1" + opts := filestore.SaveFileOpts{ + RemoteID: "test-file", + Deadline: testDeadline(), + UseWorkhorseClient: true, + RemoteTempObjectID: remoteObject, + ObjectStorageConfig: filestore.ObjectStorageConfig{ + Provider: "SomeCloud", + URLMux: mux, + GoCloudConfig: config.GoCloudConfig{URL: "foo://test-container"}, + }, + } + + _, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize, &opts) + require.Error(t, err) +} + +func TestSaveMultipartInBodyFailure(t *testing.T) { + osStub, ts := test.StartObjectStore() + defer ts.Close() + + // this is a broken path because it contains bucket name but no key + // this is the only way to get an in-body failure from our ObjectStoreStub + objectPath := "/bucket-but-no-object-key" + objectURL := ts.URL + objectPath + opts := filestore.SaveFileOpts{ + RemoteID: "test-file", + RemoteURL: objectURL, + PartSize: test.ObjectSize, + PresignedParts: []string{objectURL + "?partNumber=1", objectURL + "?partNumber=2"}, + PresignedCompleteMultipart: objectURL + "?Signature=CompleteSignature", + Deadline: testDeadline(), + } + + osStub.InitiateMultipartUpload(objectPath) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize, &opts) + require.Nil(t, fh) + require.Error(t, err) + require.EqualError(t, err, test.MultipartUploadInternalError().Error()) +} + +func TestSaveRemoteFileWithLimit(t *testing.T) { + testhelper.ConfigureSecret() + + type remote int + const ( + notRemote remote = iota + remoteSingle + remoteMultipart + ) + + remoteTypes := []remote{remoteSingle, remoteMultipart} + + tests := []struct { + name string + objectSize 
int64 + maxSize int64 + expectedErr error + testData string + }{ + { + name: "known size with no limit", + testData: test.ObjectContent, + objectSize: test.ObjectSize, + }, + { + name: "unknown size with no limit", + testData: test.ObjectContent, + objectSize: -1, + }, + { + name: "unknown object size with limit", + testData: test.ObjectContent, + objectSize: -1, + maxSize: test.ObjectSize - 1, + expectedErr: filestore.ErrEntityTooLarge, + }, + { + name: "large object with unknown size with limit", + testData: string(make([]byte, 20000)), + objectSize: -1, + maxSize: 19000, + expectedErr: filestore.ErrEntityTooLarge, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var opts filestore.SaveFileOpts + + for _, remoteType := range remoteTypes { + tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp") + require.NoError(t, err) + defer os.RemoveAll(tmpFolder) + + osStub, ts := test.StartObjectStore() + defer ts.Close() + + switch remoteType { + case remoteSingle: + objectURL := ts.URL + test.ObjectPath + + opts.RemoteID = "test-file" + opts.RemoteURL = objectURL + opts.PresignedPut = objectURL + "?Signature=ASignature" + opts.PresignedDelete = objectURL + "?Signature=AnotherSignature" + opts.Deadline = testDeadline() + opts.MaximumSize = tc.maxSize + case remoteMultipart: + objectURL := ts.URL + test.ObjectPath + + opts.RemoteID = "test-file" + opts.RemoteURL = objectURL + opts.PresignedDelete = objectURL + "?Signature=AnotherSignature" + opts.PartSize = int64(len(tc.testData)/2) + 1 + opts.PresignedParts = []string{objectURL + "?partNumber=1", objectURL + "?partNumber=2"} + opts.PresignedCompleteMultipart = objectURL + "?Signature=CompleteSignature" + opts.Deadline = testDeadline() + opts.MaximumSize = tc.maxSize + + require.Less(t, int64(len(tc.testData)), int64(len(opts.PresignedParts))*opts.PartSize, "check part size calculation") + + osStub.InitiateMultipartUpload(test.ObjectPath) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(tc.testData), tc.objectSize, &opts) + + if tc.expectedErr == nil { + require.NoError(t, err) + require.NotNil(t, fh) + } else { + require.True(t, errors.Is(err, tc.expectedErr)) + require.Nil(t, fh) + } + } + }) + } +} + +func checkFileHandlerWithFields(t *testing.T, fh *filestore.FileHandler, fields map[string]string, prefix string) { + key := func(field string) string { + if prefix == "" { + return field + } + + return fmt.Sprintf("%s.%s", prefix, field) + } + + require.Equal(t, fh.Name, fields[key("name")]) + require.Equal(t, fh.LocalPath, fields[key("path")]) + require.Equal(t, fh.RemoteURL, fields[key("remote_url")]) + require.Equal(t, fh.RemoteID, fields[key("remote_id")]) + require.Equal(t, strconv.FormatInt(test.ObjectSize, 10), fields[key("size")]) + require.Equal(t, test.ObjectMD5, fields[key("md5")]) + require.Equal(t, test.ObjectSHA1, fields[key("sha1")]) + require.Equal(t, test.ObjectSHA256, fields[key("sha256")]) + require.Equal(t, test.ObjectSHA512, fields[key("sha512")]) +} diff --git a/workhorse/internal/filestore/multi_hash.go b/workhorse/internal/filestore/multi_hash.go new file mode 100644 index 00000000000..40efd3a5c1f --- /dev/null +++ b/workhorse/internal/filestore/multi_hash.go @@ -0,0 +1,48 @@ +package filestore + +import ( + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" + "hash" + "io" +) + +var hashFactories = map[string](func() hash.Hash){ + "md5": md5.New, + "sha1": 
sha1.New, + "sha256": sha256.New, + "sha512": sha512.New, +} + +type multiHash struct { + io.Writer + hashes map[string]hash.Hash +} + +func newMultiHash() (m *multiHash) { + m = &multiHash{} + m.hashes = make(map[string]hash.Hash) + + var writers []io.Writer + for hash, hashFactory := range hashFactories { + writer := hashFactory() + + m.hashes[hash] = writer + writers = append(writers, writer) + } + + m.Writer = io.MultiWriter(writers...) + return m +} + +func (m *multiHash) finish() map[string]string { + h := make(map[string]string) + for hashName, hash := range m.hashes { + checksum := hash.Sum(nil) + h[hashName] = hex.EncodeToString(checksum) + } + return h +} diff --git a/workhorse/internal/filestore/reader.go b/workhorse/internal/filestore/reader.go new file mode 100644 index 00000000000..b1045b991fc --- /dev/null +++ b/workhorse/internal/filestore/reader.go @@ -0,0 +1,17 @@ +package filestore + +import "io" + +type hardLimitReader struct { + r io.Reader + n int64 +} + +func (h *hardLimitReader) Read(p []byte) (int, error) { + nRead, err := h.r.Read(p) + h.n -= int64(nRead) + if h.n < 0 { + err = ErrEntityTooLarge + } + return nRead, err +} diff --git a/workhorse/internal/filestore/reader_test.go b/workhorse/internal/filestore/reader_test.go new file mode 100644 index 00000000000..424d921ecaf --- /dev/null +++ b/workhorse/internal/filestore/reader_test.go @@ -0,0 +1,46 @@ +package filestore + +import ( + "fmt" + "io/ioutil" + "strings" + "testing" + "testing/iotest" + + "github.com/stretchr/testify/require" +) + +func TestHardLimitReader(t *testing.T) { + const text = "hello world" + r := iotest.OneByteReader( + &hardLimitReader{ + r: strings.NewReader(text), + n: int64(len(text)), + }, + ) + + out, err := ioutil.ReadAll(r) + require.NoError(t, err) + require.Equal(t, text, string(out)) +} + +func TestHardLimitReaderFail(t *testing.T) { + const text = "hello world" + + for bufSize := len(text) / 2; bufSize < len(text)*2; bufSize++ { + t.Run(fmt.Sprintf("bufsize:%d", bufSize), func(t *testing.T) { + r := &hardLimitReader{ + r: iotest.DataErrReader(strings.NewReader(text)), + n: int64(len(text)) - 1, + } + buf := make([]byte, bufSize) + + var err error + for i := 0; err == nil && i < 1000; i++ { + _, err = r.Read(buf) + } + + require.Equal(t, ErrEntityTooLarge, err) + }) + } +} diff --git a/workhorse/internal/filestore/save_file_opts.go b/workhorse/internal/filestore/save_file_opts.go new file mode 100644 index 00000000000..1eb708c3f55 --- /dev/null +++ b/workhorse/internal/filestore/save_file_opts.go @@ -0,0 +1,171 @@ +package filestore + +import ( + "errors" + "strings" + "time" + + "gocloud.dev/blob" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" +) + +// DefaultObjectStoreTimeout is the timeout for ObjectStore upload operation +const DefaultObjectStoreTimeout = 4 * time.Hour + +type ObjectStorageConfig struct { + Provider string + + S3Credentials config.S3Credentials + S3Config config.S3Config + + // GoCloud mux that maps azureblob:// and future URLs (e.g. s3://, gcs://, etc.) 
to a handler
+ URLMux *blob.URLMux
+
+ // Azure credentials are registered at startup in the GoCloud URLMux, so only the container name is needed
+ GoCloudConfig config.GoCloudConfig
+}
+
+// SaveFileOpts represents all the options available for saving a file to object store
+type SaveFileOpts struct {
+ // TempFilePrefix is the prefix used to create the temporary local file
+ TempFilePrefix string
+ // LocalTempPath is the directory where a local copy of the file is written
+ LocalTempPath string
+ // RemoteID is the remote ObjectID provided by GitLab
+ RemoteID string
+ // RemoteURL is the final URL of the file
+ RemoteURL string
+ // PresignedPut is a presigned S3 PutObject compatible URL
+ PresignedPut string
+ // PresignedDelete is a presigned S3 DeleteObject compatible URL.
+ PresignedDelete string
+ // HTTP headers to be sent along with the PUT request
+ PutHeaders map[string]string
+ // Whether to ignore Rails pre-signed URLs and have Workhorse directly access object storage provider
+ UseWorkhorseClient bool
+ // If UseWorkhorseClient is true, this is the temporary object name to store the file
+ RemoteTempObjectID string
+ // Workhorse object storage client (e.g. S3) parameters
+ ObjectStorageConfig ObjectStorageConfig
+ // Deadline is the S3 operation deadline; the upload will be aborted if not completed in time
+ Deadline time.Time
+ // The maximum accepted size in bytes of the upload
+ MaximumSize int64
+
+ // MultipartUpload parameters
+ // PartSize is the exact size of each uploaded part. Only the last one can be smaller
+ PartSize int64
+ // PresignedParts contains the presigned URLs for each part
+ PresignedParts []string
+ // PresignedCompleteMultipart is a presigned URL for CompleteMultipartUpload
+ PresignedCompleteMultipart string
+ // PresignedAbortMultipart is a presigned URL for AbortMultipartUpload
+ PresignedAbortMultipart string
+}
+
+// UseWorkhorseClientEnabled checks if the options require direct access to object storage
+func (s *SaveFileOpts) UseWorkhorseClientEnabled() bool {
+ return s.UseWorkhorseClient && s.ObjectStorageConfig.IsValid() && s.RemoteTempObjectID != ""
+}
+
+// IsLocal checks if the options require writing the file to disk
+func (s *SaveFileOpts) IsLocal() bool {
+ return s.LocalTempPath != ""
+}
+
+// IsMultipart checks if the options require a multipart upload
+func (s *SaveFileOpts) IsMultipart() bool {
+ return s.PartSize > 0
+}
+
+// GetOpts converts a GitLab api.Response to a proper SaveFileOpts
+func GetOpts(apiResponse *api.Response) (*SaveFileOpts, error) {
+ timeout := time.Duration(apiResponse.RemoteObject.Timeout) * time.Second
+ if timeout == 0 {
+ timeout = DefaultObjectStoreTimeout
+ }
+
+ opts := SaveFileOpts{
+ LocalTempPath: apiResponse.TempPath,
+ RemoteID: apiResponse.RemoteObject.ID,
+ RemoteURL: apiResponse.RemoteObject.GetURL,
+ PresignedPut: apiResponse.RemoteObject.StoreURL,
+ PresignedDelete: apiResponse.RemoteObject.DeleteURL,
+ PutHeaders: apiResponse.RemoteObject.PutHeaders,
+ UseWorkhorseClient: apiResponse.RemoteObject.UseWorkhorseClient,
+ RemoteTempObjectID: apiResponse.RemoteObject.RemoteTempObjectID,
+ Deadline: time.Now().Add(timeout),
+ MaximumSize: apiResponse.MaximumSize,
+ }
+
+ if opts.LocalTempPath != "" && opts.RemoteID != "" {
+ return nil, errors.New("API response has both TempPath and RemoteObject")
+ }
+
+ if opts.LocalTempPath == "" && opts.RemoteID == "" {
+ return nil, errors.New("API response has neither TempPath nor RemoteObject")
+ }
+
+ objectStorageParams := 
apiResponse.RemoteObject.ObjectStorage + if opts.UseWorkhorseClient && objectStorageParams != nil { + opts.ObjectStorageConfig.Provider = objectStorageParams.Provider + opts.ObjectStorageConfig.S3Config = objectStorageParams.S3Config + opts.ObjectStorageConfig.GoCloudConfig = objectStorageParams.GoCloudConfig + } + + // Backwards compatibility to ensure API servers that do not include the + // CustomPutHeaders flag will default to the original content type. + if !apiResponse.RemoteObject.CustomPutHeaders { + opts.PutHeaders = make(map[string]string) + opts.PutHeaders["Content-Type"] = "application/octet-stream" + } + + if multiParams := apiResponse.RemoteObject.MultipartUpload; multiParams != nil { + opts.PartSize = multiParams.PartSize + opts.PresignedCompleteMultipart = multiParams.CompleteURL + opts.PresignedAbortMultipart = multiParams.AbortURL + opts.PresignedParts = append([]string(nil), multiParams.PartURLs...) + } + + return &opts, nil +} + +func (c *ObjectStorageConfig) IsAWS() bool { + return strings.EqualFold(c.Provider, "AWS") || strings.EqualFold(c.Provider, "S3") +} + +func (c *ObjectStorageConfig) IsAzure() bool { + return strings.EqualFold(c.Provider, "AzureRM") +} + +func (c *ObjectStorageConfig) IsGoCloud() bool { + return c.GoCloudConfig.URL != "" +} + +func (c *ObjectStorageConfig) IsValid() bool { + if c.IsAWS() { + return c.S3Config.Bucket != "" && c.S3Config.Region != "" && c.s3CredentialsValid() + } else if c.IsGoCloud() { + // We could parse and validate the URL, but GoCloud providers + // such as AzureRM don't have a fallback to normal HTTP, so we + // always want to try the GoCloud path if there is a URL. + return true + } + + return false +} + +func (c *ObjectStorageConfig) s3CredentialsValid() bool { + // We need to be able to distinguish between two cases of AWS access: + // 1. AWS access via key and secret, but credentials not configured in Workhorse + // 2. 
IAM instance profiles used + if c.S3Config.UseIamProfile { + return true + } else if c.S3Credentials.AwsAccessKeyID != "" && c.S3Credentials.AwsSecretAccessKey != "" { + return true + } + + return false +} diff --git a/workhorse/internal/filestore/save_file_opts_test.go b/workhorse/internal/filestore/save_file_opts_test.go new file mode 100644 index 00000000000..2d6cd683b51 --- /dev/null +++ b/workhorse/internal/filestore/save_file_opts_test.go @@ -0,0 +1,331 @@ +package filestore_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test" +) + +func TestSaveFileOptsLocalAndRemote(t *testing.T) { + tests := []struct { + name string + localTempPath string + presignedPut string + partSize int64 + isLocal bool + isRemote bool + isMultipart bool + }{ + { + name: "Only LocalTempPath", + localTempPath: "/tmp", + isLocal: true, + }, + { + name: "No paths", + }, + { + name: "Only remoteUrl", + presignedPut: "http://example.com", + }, + { + name: "Multipart", + partSize: 10, + isMultipart: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := filestore.SaveFileOpts{ + LocalTempPath: test.localTempPath, + PresignedPut: test.presignedPut, + PartSize: test.partSize, + } + + require.Equal(t, test.isLocal, opts.IsLocal(), "IsLocal() mismatch") + require.Equal(t, test.isMultipart, opts.IsMultipart(), "IsMultipart() mismatch") + }) + } +} + +func TestGetOpts(t *testing.T) { + tests := []struct { + name string + multipart *api.MultipartUploadParams + customPutHeaders bool + putHeaders map[string]string + }{ + { + name: "Single upload", + }, { + name: "Multipart upload", + multipart: &api.MultipartUploadParams{ + PartSize: 10, + CompleteURL: "http://complete", + AbortURL: "http://abort", + PartURLs: []string{"http://part1", "http://part2"}, + }, + }, + { + name: "Single upload with custom content type", + customPutHeaders: true, + putHeaders: map[string]string{"Content-Type": "image/jpeg"}, + }, { + name: "Multipart upload with custom content type", + multipart: &api.MultipartUploadParams{ + PartSize: 10, + CompleteURL: "http://complete", + AbortURL: "http://abort", + PartURLs: []string{"http://part1", "http://part2"}, + }, + customPutHeaders: true, + putHeaders: map[string]string{"Content-Type": "image/jpeg"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + apiResponse := &api.Response{ + RemoteObject: api.RemoteObject{ + Timeout: 10, + ID: "id", + GetURL: "http://get", + StoreURL: "http://store", + DeleteURL: "http://delete", + MultipartUpload: test.multipart, + CustomPutHeaders: test.customPutHeaders, + PutHeaders: test.putHeaders, + }, + } + deadline := time.Now().Add(time.Duration(apiResponse.RemoteObject.Timeout) * time.Second) + opts, err := filestore.GetOpts(apiResponse) + require.NoError(t, err) + + require.Equal(t, apiResponse.TempPath, opts.LocalTempPath) + require.WithinDuration(t, deadline, opts.Deadline, time.Second) + require.Equal(t, apiResponse.RemoteObject.ID, opts.RemoteID) + require.Equal(t, apiResponse.RemoteObject.GetURL, opts.RemoteURL) + require.Equal(t, apiResponse.RemoteObject.StoreURL, opts.PresignedPut) + require.Equal(t, apiResponse.RemoteObject.DeleteURL, opts.PresignedDelete) + if test.customPutHeaders { + require.Equal(t, 
opts.PutHeaders, apiResponse.RemoteObject.PutHeaders) + } else { + require.Equal(t, opts.PutHeaders, map[string]string{"Content-Type": "application/octet-stream"}) + } + + if test.multipart == nil { + require.False(t, opts.IsMultipart()) + require.Empty(t, opts.PresignedCompleteMultipart) + require.Empty(t, opts.PresignedAbortMultipart) + require.Zero(t, opts.PartSize) + require.Empty(t, opts.PresignedParts) + } else { + require.True(t, opts.IsMultipart()) + require.Equal(t, test.multipart.CompleteURL, opts.PresignedCompleteMultipart) + require.Equal(t, test.multipart.AbortURL, opts.PresignedAbortMultipart) + require.Equal(t, test.multipart.PartSize, opts.PartSize) + require.Equal(t, test.multipart.PartURLs, opts.PresignedParts) + } + }) + } +} + +func TestGetOptsFail(t *testing.T) { + testCases := []struct { + desc string + in api.Response + }{ + { + desc: "neither local nor remote", + in: api.Response{}, + }, + { + desc: "both local and remote", + in: api.Response{TempPath: "/foobar", RemoteObject: api.RemoteObject{ID: "id"}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + _, err := filestore.GetOpts(&tc.in) + require.Error(t, err, "expect input to be rejected") + }) + } +} + +func TestGetOptsDefaultTimeout(t *testing.T) { + deadline := time.Now().Add(filestore.DefaultObjectStoreTimeout) + opts, err := filestore.GetOpts(&api.Response{TempPath: "/foo/bar"}) + require.NoError(t, err) + + require.WithinDuration(t, deadline, opts.Deadline, time.Minute) +} + +func TestUseWorkhorseClientEnabled(t *testing.T) { + cfg := filestore.ObjectStorageConfig{ + Provider: "AWS", + S3Config: config.S3Config{ + Bucket: "test-bucket", + Region: "test-region", + }, + S3Credentials: config.S3Credentials{ + AwsAccessKeyID: "test-key", + AwsSecretAccessKey: "test-secret", + }, + } + + missingCfg := cfg + missingCfg.S3Credentials = config.S3Credentials{} + + iamConfig := missingCfg + iamConfig.S3Config.UseIamProfile = true + + tests := []struct { + name string + UseWorkhorseClient bool + remoteTempObjectID string + objectStorageConfig filestore.ObjectStorageConfig + expected bool + }{ + { + name: "all direct access settings used", + UseWorkhorseClient: true, + remoteTempObjectID: "test-object", + objectStorageConfig: cfg, + expected: true, + }, + { + name: "missing AWS credentials", + UseWorkhorseClient: true, + remoteTempObjectID: "test-object", + objectStorageConfig: missingCfg, + expected: false, + }, + { + name: "direct access disabled", + UseWorkhorseClient: false, + remoteTempObjectID: "test-object", + objectStorageConfig: cfg, + expected: false, + }, + { + name: "with IAM instance profile", + UseWorkhorseClient: true, + remoteTempObjectID: "test-object", + objectStorageConfig: iamConfig, + expected: true, + }, + { + name: "missing remote temp object ID", + UseWorkhorseClient: true, + remoteTempObjectID: "", + objectStorageConfig: cfg, + expected: false, + }, + { + name: "missing S3 config", + UseWorkhorseClient: true, + remoteTempObjectID: "test-object", + expected: false, + }, + { + name: "missing S3 bucket", + UseWorkhorseClient: true, + remoteTempObjectID: "test-object", + objectStorageConfig: filestore.ObjectStorageConfig{ + Provider: "AWS", + S3Config: config.S3Config{}, + }, + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + apiResponse := &api.Response{ + RemoteObject: api.RemoteObject{ + Timeout: 10, + ID: "id", + UseWorkhorseClient: test.UseWorkhorseClient, + RemoteTempObjectID: test.remoteTempObjectID, + }, 
+ } + deadline := time.Now().Add(time.Duration(apiResponse.RemoteObject.Timeout) * time.Second) + opts, err := filestore.GetOpts(apiResponse) + require.NoError(t, err) + opts.ObjectStorageConfig = test.objectStorageConfig + + require.Equal(t, apiResponse.TempPath, opts.LocalTempPath) + require.WithinDuration(t, deadline, opts.Deadline, time.Second) + require.Equal(t, apiResponse.RemoteObject.ID, opts.RemoteID) + require.Equal(t, apiResponse.RemoteObject.UseWorkhorseClient, opts.UseWorkhorseClient) + require.Equal(t, test.expected, opts.UseWorkhorseClientEnabled()) + }) + } +} + +func TestGoCloudConfig(t *testing.T) { + mux, _, cleanup := test.SetupGoCloudFileBucket(t, "azblob") + defer cleanup() + + tests := []struct { + name string + provider string + url string + valid bool + }{ + { + name: "valid AzureRM config", + provider: "AzureRM", + url: "azblob:://test-container", + valid: true, + }, + { + name: "invalid GoCloud scheme", + provider: "AzureRM", + url: "unknown:://test-container", + valid: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + apiResponse := &api.Response{ + RemoteObject: api.RemoteObject{ + Timeout: 10, + ID: "id", + UseWorkhorseClient: true, + RemoteTempObjectID: "test-object", + ObjectStorage: &api.ObjectStorageParams{ + Provider: test.provider, + GoCloudConfig: config.GoCloudConfig{ + URL: test.url, + }, + }, + }, + } + deadline := time.Now().Add(time.Duration(apiResponse.RemoteObject.Timeout) * time.Second) + opts, err := filestore.GetOpts(apiResponse) + require.NoError(t, err) + opts.ObjectStorageConfig.URLMux = mux + + require.Equal(t, apiResponse.TempPath, opts.LocalTempPath) + require.Equal(t, apiResponse.RemoteObject.RemoteTempObjectID, opts.RemoteTempObjectID) + require.WithinDuration(t, deadline, opts.Deadline, time.Second) + require.Equal(t, apiResponse.RemoteObject.ID, opts.RemoteID) + require.Equal(t, apiResponse.RemoteObject.UseWorkhorseClient, opts.UseWorkhorseClient) + require.Equal(t, test.provider, opts.ObjectStorageConfig.Provider) + require.Equal(t, apiResponse.RemoteObject.ObjectStorage.GoCloudConfig, opts.ObjectStorageConfig.GoCloudConfig) + require.True(t, opts.UseWorkhorseClientEnabled()) + require.Equal(t, test.valid, opts.ObjectStorageConfig.IsValid()) + require.False(t, opts.IsLocal()) + }) + } +} diff --git a/workhorse/internal/git/archive.go b/workhorse/internal/git/archive.go new file mode 100644 index 00000000000..b7575be2c02 --- /dev/null +++ b/workhorse/internal/git/archive.go @@ -0,0 +1,216 @@ +/* +In this file we handle 'git archive' downloads +*/ + +package git + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "path" + "path/filepath" + "regexp" + "time" + + "github.com/golang/protobuf/proto" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/274 + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata" +) + +type archive struct{ senddata.Prefix } +type archiveParams struct { + ArchivePath string + ArchivePrefix string + CommitId string + GitalyServer gitaly.Server + GitalyRepository gitalypb.Repository + DisableCache bool + GetArchiveRequest []byte +} + +var ( + SendArchive = &archive{"git-archive:"} + gitArchiveCache = promauto.NewCounterVec( + prometheus.CounterOpts{ + 
Name: "gitlab_workhorse_git_archive_cache", + Help: "Cache hits and misses for 'git archive' streaming", + }, + []string{"result"}, + ) +) + +func (a *archive) Inject(w http.ResponseWriter, r *http.Request, sendData string) { + var params archiveParams + if err := a.Unpack(¶ms, sendData); err != nil { + helper.Fail500(w, r, fmt.Errorf("SendArchive: unpack sendData: %v", err)) + return + } + + urlPath := r.URL.Path + format, ok := parseBasename(filepath.Base(urlPath)) + if !ok { + helper.Fail500(w, r, fmt.Errorf("SendArchive: invalid format: %s", urlPath)) + return + } + + cacheEnabled := !params.DisableCache + archiveFilename := path.Base(params.ArchivePath) + + if cacheEnabled { + cachedArchive, err := os.Open(params.ArchivePath) + if err == nil { + defer cachedArchive.Close() + gitArchiveCache.WithLabelValues("hit").Inc() + setArchiveHeaders(w, format, archiveFilename) + // Even if somebody deleted the cachedArchive from disk since we opened + // the file, Unix file semantics guarantee we can still read from the + // open file in this process. + http.ServeContent(w, r, "", time.Unix(0, 0), cachedArchive) + return + } + } + + gitArchiveCache.WithLabelValues("miss").Inc() + + var tempFile *os.File + var err error + + if cacheEnabled { + // We assume the tempFile has a unique name so that concurrent requests are + // safe. We create the tempfile in the same directory as the final cached + // archive we want to create so that we can use an atomic link(2) operation + // to finalize the cached archive. + tempFile, err = prepareArchiveTempfile(path.Dir(params.ArchivePath), archiveFilename) + if err != nil { + helper.Fail500(w, r, fmt.Errorf("SendArchive: create tempfile: %v", err)) + return + } + defer tempFile.Close() + defer os.Remove(tempFile.Name()) + } + + var archiveReader io.Reader + + archiveReader, err = handleArchiveWithGitaly(r, params, format) + if err != nil { + helper.Fail500(w, r, fmt.Errorf("operations.GetArchive: %v", err)) + return + } + + reader := archiveReader + if cacheEnabled { + reader = io.TeeReader(archiveReader, tempFile) + } + + // Start writing the response + setArchiveHeaders(w, format, archiveFilename) + w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return + if _, err := io.Copy(w, reader); err != nil { + helper.LogError(r, ©Error{fmt.Errorf("SendArchive: copy 'git archive' output: %v", err)}) + return + } + + if cacheEnabled { + err := finalizeCachedArchive(tempFile, params.ArchivePath) + if err != nil { + helper.LogError(r, fmt.Errorf("SendArchive: finalize cached archive: %v", err)) + return + } + } +} + +func handleArchiveWithGitaly(r *http.Request, params archiveParams, format gitalypb.GetArchiveRequest_Format) (io.Reader, error) { + var request *gitalypb.GetArchiveRequest + ctx, c, err := gitaly.NewRepositoryClient(r.Context(), params.GitalyServer) + if err != nil { + return nil, err + } + + if params.GetArchiveRequest != nil { + request = &gitalypb.GetArchiveRequest{} + + if err := proto.Unmarshal(params.GetArchiveRequest, request); err != nil { + return nil, fmt.Errorf("unmarshal GetArchiveRequest: %v", err) + } + } else { + request = &gitalypb.GetArchiveRequest{ + Repository: ¶ms.GitalyRepository, + CommitId: params.CommitId, + Prefix: params.ArchivePrefix, + Format: format, + } + } + + return c.ArchiveReader(ctx, request) +} + +func setArchiveHeaders(w http.ResponseWriter, format gitalypb.GetArchiveRequest_Format, archiveFilename string) { + w.Header().Del("Content-Length") + w.Header().Set("Content-Disposition", 
fmt.Sprintf(`attachment; filename="%s"`, archiveFilename)) + // Caching proxies usually don't cache responses with Set-Cookie header + // present because it implies user-specific data, which is not the case + // for repository archives. + w.Header().Del("Set-Cookie") + if format == gitalypb.GetArchiveRequest_ZIP { + w.Header().Set("Content-Type", "application/zip") + } else { + w.Header().Set("Content-Type", "application/octet-stream") + } + w.Header().Set("Content-Transfer-Encoding", "binary") +} + +func prepareArchiveTempfile(dir string, prefix string) (*os.File, error) { + if err := os.MkdirAll(dir, 0700); err != nil { + return nil, err + } + return ioutil.TempFile(dir, prefix) +} + +func finalizeCachedArchive(tempFile *os.File, archivePath string) error { + if err := tempFile.Close(); err != nil { + return err + } + if err := os.Link(tempFile.Name(), archivePath); err != nil && !os.IsExist(err) { + return err + } + + return nil +} + +var ( + patternZip = regexp.MustCompile(`\.zip$`) + patternTar = regexp.MustCompile(`\.tar$`) + patternTarGz = regexp.MustCompile(`\.(tar\.gz|tgz|gz)$`) + patternTarBz2 = regexp.MustCompile(`\.(tar\.bz2|tbz|tbz2|tb2|bz2)$`) +) + +func parseBasename(basename string) (gitalypb.GetArchiveRequest_Format, bool) { + var format gitalypb.GetArchiveRequest_Format + + switch { + case (basename == "archive"): + format = gitalypb.GetArchiveRequest_TAR_GZ + case patternZip.MatchString(basename): + format = gitalypb.GetArchiveRequest_ZIP + case patternTar.MatchString(basename): + format = gitalypb.GetArchiveRequest_TAR + case patternTarGz.MatchString(basename): + format = gitalypb.GetArchiveRequest_TAR_GZ + case patternTarBz2.MatchString(basename): + format = gitalypb.GetArchiveRequest_TAR_BZ2 + default: + return format, false + } + + return format, true +} diff --git a/workhorse/internal/git/archive_test.go b/workhorse/internal/git/archive_test.go new file mode 100644 index 00000000000..4b0753499e5 --- /dev/null +++ b/workhorse/internal/git/archive_test.go @@ -0,0 +1,87 @@ +package git + +import ( + "io/ioutil" + "net/http/httptest" + "testing" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" + + "github.com/stretchr/testify/require" +) + +func TestParseBasename(t *testing.T) { + for _, testCase := range []struct { + in string + out gitalypb.GetArchiveRequest_Format + }{ + {"archive", gitalypb.GetArchiveRequest_TAR_GZ}, + {"master.tar.gz", gitalypb.GetArchiveRequest_TAR_GZ}, + {"foo-master.tgz", gitalypb.GetArchiveRequest_TAR_GZ}, + {"foo-v1.2.1.gz", gitalypb.GetArchiveRequest_TAR_GZ}, + {"foo.tar.bz2", gitalypb.GetArchiveRequest_TAR_BZ2}, + {"archive.tbz", gitalypb.GetArchiveRequest_TAR_BZ2}, + {"archive.tbz2", gitalypb.GetArchiveRequest_TAR_BZ2}, + {"archive.tb2", gitalypb.GetArchiveRequest_TAR_BZ2}, + {"archive.bz2", gitalypb.GetArchiveRequest_TAR_BZ2}, + } { + basename := testCase.in + out, ok := parseBasename(basename) + if !ok { + t.Fatalf("parseBasename did not recognize %q", basename) + } + + if out != testCase.out { + t.Fatalf("expected %q, got %q", testCase.out, out) + } + } +} + +func TestFinalizeArchive(t *testing.T) { + tempFile, err := ioutil.TempFile("", "gitlab-workhorse-test") + if err != nil { + t.Fatal(err) + } + defer tempFile.Close() + + // Deliberately cause an EEXIST error: we know tempFile.Name() already exists + err = finalizeCachedArchive(tempFile, tempFile.Name()) + if err != nil { + t.Fatalf("expected nil from finalizeCachedArchive, received %v", err) + } +} + +func 
TestSetArchiveHeaders(t *testing.T) { + for _, testCase := range []struct { + in gitalypb.GetArchiveRequest_Format + out string + }{ + {gitalypb.GetArchiveRequest_ZIP, "application/zip"}, + {gitalypb.GetArchiveRequest_TAR, "application/octet-stream"}, + {gitalypb.GetArchiveRequest_TAR_GZ, "application/octet-stream"}, + {gitalypb.GetArchiveRequest_TAR_BZ2, "application/octet-stream"}, + } { + w := httptest.NewRecorder() + + // These should be replaced, not appended to + w.Header().Set("Content-Type", "test") + w.Header().Set("Content-Length", "test") + w.Header().Set("Content-Disposition", "test") + + // This should be deleted + w.Header().Set("Set-Cookie", "test") + + // This should be preserved + w.Header().Set("Cache-Control", "public, max-age=3600") + + setArchiveHeaders(w, testCase.in, "filename") + + testhelper.RequireResponseHeader(t, w, "Content-Type", testCase.out) + testhelper.RequireResponseHeader(t, w, "Content-Length") + testhelper.RequireResponseHeader(t, w, "Content-Disposition", `attachment; filename="filename"`) + testhelper.RequireResponseHeader(t, w, "Cache-Control", "public, max-age=3600") + require.Empty(t, w.Header().Get("Set-Cookie"), "remove Set-Cookie") + } +} diff --git a/workhorse/internal/git/blob.go b/workhorse/internal/git/blob.go new file mode 100644 index 00000000000..472f5d0bc96 --- /dev/null +++ b/workhorse/internal/git/blob.go @@ -0,0 +1,47 @@ +package git + +import ( + "fmt" + "net/http" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata" +) + +type blob struct{ senddata.Prefix } +type blobParams struct { + GitalyServer gitaly.Server + GetBlobRequest gitalypb.GetBlobRequest +} + +var SendBlob = &blob{"git-blob:"} + +func (b *blob) Inject(w http.ResponseWriter, r *http.Request, sendData string) { + var params blobParams + if err := b.Unpack(¶ms, sendData); err != nil { + helper.Fail500(w, r, fmt.Errorf("SendBlob: unpack sendData: %v", err)) + return + } + + ctx, blobClient, err := gitaly.NewBlobClient(r.Context(), params.GitalyServer) + if err != nil { + helper.Fail500(w, r, fmt.Errorf("blob.GetBlob: %v", err)) + return + } + + setBlobHeaders(w) + if err := blobClient.SendBlob(ctx, w, ¶ms.GetBlobRequest); err != nil { + helper.Fail500(w, r, fmt.Errorf("blob.GetBlob: %v", err)) + return + } +} + +func setBlobHeaders(w http.ResponseWriter) { + // Caching proxies usually don't cache responses with Set-Cookie header + // present because it implies user-specific data, which is not the case + // for blobs. 
+ w.Header().Del("Set-Cookie") +} diff --git a/workhorse/internal/git/blob_test.go b/workhorse/internal/git/blob_test.go new file mode 100644 index 00000000000..ec28c2adb2f --- /dev/null +++ b/workhorse/internal/git/blob_test.go @@ -0,0 +1,17 @@ +package git + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSetBlobHeaders(t *testing.T) { + w := httptest.NewRecorder() + w.Header().Set("Set-Cookie", "gitlab_cookie=123456") + + setBlobHeaders(w) + + require.Empty(t, w.Header().Get("Set-Cookie"), "remove Set-Cookie") +} diff --git a/workhorse/internal/git/diff.go b/workhorse/internal/git/diff.go new file mode 100644 index 00000000000..b1a1c17a650 --- /dev/null +++ b/workhorse/internal/git/diff.go @@ -0,0 +1,48 @@ +package git + +import ( + "fmt" + "net/http" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata" +) + +type diff struct{ senddata.Prefix } +type diffParams struct { + GitalyServer gitaly.Server + RawDiffRequest string +} + +var SendDiff = &diff{"git-diff:"} + +func (d *diff) Inject(w http.ResponseWriter, r *http.Request, sendData string) { + var params diffParams + if err := d.Unpack(¶ms, sendData); err != nil { + helper.Fail500(w, r, fmt.Errorf("SendDiff: unpack sendData: %v", err)) + return + } + + request := &gitalypb.RawDiffRequest{} + if err := gitaly.UnmarshalJSON(params.RawDiffRequest, request); err != nil { + helper.Fail500(w, r, fmt.Errorf("diff.RawDiff: %v", err)) + return + } + + ctx, diffClient, err := gitaly.NewDiffClient(r.Context(), params.GitalyServer) + if err != nil { + helper.Fail500(w, r, fmt.Errorf("diff.RawDiff: %v", err)) + return + } + + if err := diffClient.SendRawDiff(ctx, w, request); err != nil { + helper.LogError( + r, + ©Error{fmt.Errorf("diff.RawDiff: request=%v, err=%v", request, err)}, + ) + return + } +} diff --git a/workhorse/internal/git/error.go b/workhorse/internal/git/error.go new file mode 100644 index 00000000000..2b7cad6bb64 --- /dev/null +++ b/workhorse/internal/git/error.go @@ -0,0 +1,4 @@ +package git + +// For cosmetic purposes in Sentry +type copyError struct{ error } diff --git a/workhorse/internal/git/format-patch.go b/workhorse/internal/git/format-patch.go new file mode 100644 index 00000000000..db96029b07e --- /dev/null +++ b/workhorse/internal/git/format-patch.go @@ -0,0 +1,48 @@ +package git + +import ( + "fmt" + "net/http" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata" +) + +type patch struct{ senddata.Prefix } +type patchParams struct { + GitalyServer gitaly.Server + RawPatchRequest string +} + +var SendPatch = &patch{"git-format-patch:"} + +func (p *patch) Inject(w http.ResponseWriter, r *http.Request, sendData string) { + var params patchParams + if err := p.Unpack(¶ms, sendData); err != nil { + helper.Fail500(w, r, fmt.Errorf("SendPatch: unpack sendData: %v", err)) + return + } + + request := &gitalypb.RawPatchRequest{} + if err := gitaly.UnmarshalJSON(params.RawPatchRequest, request); err != nil { + helper.Fail500(w, r, fmt.Errorf("diff.RawPatch: %v", err)) + return + } + + ctx, diffClient, err := gitaly.NewDiffClient(r.Context(), params.GitalyServer) + if err != nil { + helper.Fail500(w, r, 
fmt.Errorf("diff.RawPatch: %v", err)) + return + } + + if err := diffClient.SendRawPatch(ctx, w, request); err != nil { + helper.LogError( + r, + ©Error{fmt.Errorf("diff.RawPatch: request=%v, err=%v", request, err)}, + ) + return + } +} diff --git a/workhorse/internal/git/git-http.go b/workhorse/internal/git/git-http.go new file mode 100644 index 00000000000..5df20a68bb7 --- /dev/null +++ b/workhorse/internal/git/git-http.go @@ -0,0 +1,100 @@ +/* +In this file we handle the Git 'smart HTTP' protocol +*/ + +package git + +import ( + "fmt" + "io" + "net/http" + "path/filepath" + "sync" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +const ( + // We have to use a negative transfer.hideRefs since this is the only way + // to undo an already set parameter: https://www.spinics.net/lists/git/msg256772.html + GitConfigShowAllRefs = "transfer.hideRefs=!refs" +) + +func ReceivePack(a *api.API) http.Handler { + return postRPCHandler(a, "handleReceivePack", handleReceivePack) +} + +func UploadPack(a *api.API) http.Handler { + return postRPCHandler(a, "handleUploadPack", handleUploadPack) +} + +func gitConfigOptions(a *api.Response) []string { + var out []string + + if a.ShowAllRefs { + out = append(out, GitConfigShowAllRefs) + } + + return out +} + +func postRPCHandler(a *api.API, name string, handler func(*HttpResponseWriter, *http.Request, *api.Response) error) http.Handler { + return repoPreAuthorizeHandler(a, func(rw http.ResponseWriter, r *http.Request, ar *api.Response) { + cr := &countReadCloser{ReadCloser: r.Body} + r.Body = cr + + w := NewHttpResponseWriter(rw) + defer func() { + w.Log(r, cr.Count()) + }() + + if err := handler(w, r, ar); err != nil { + // If the handler already wrote a response this WriteHeader call is a + // no-op. It never reaches net/http because GitHttpResponseWriter calls + // WriteHeader on its underlying ResponseWriter at most once. 
+ w.WriteHeader(500) + helper.LogError(r, fmt.Errorf("%s: %v", name, err)) + } + }) +} + +func repoPreAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.Handler { + return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) { + handleFunc(w, r, a) + }, "") +} + +func writePostRPCHeader(w http.ResponseWriter, action string) { + w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-result", action)) + w.Header().Set("Cache-Control", "no-cache") +} + +func getService(r *http.Request) string { + if r.Method == "GET" { + return r.URL.Query().Get("service") + } + return filepath.Base(r.URL.Path) +} + +type countReadCloser struct { + n int64 + io.ReadCloser + sync.Mutex +} + +func (c *countReadCloser) Read(p []byte) (n int, err error) { + n, err = c.ReadCloser.Read(p) + + c.Lock() + defer c.Unlock() + c.n += int64(n) + + return n, err +} + +func (c *countReadCloser) Count() int64 { + c.Lock() + defer c.Unlock() + return c.n +} diff --git a/workhorse/internal/git/info-refs.go b/workhorse/internal/git/info-refs.go new file mode 100644 index 00000000000..e5491a7b733 --- /dev/null +++ b/workhorse/internal/git/info-refs.go @@ -0,0 +1,76 @@ +package git + +import ( + "compress/gzip" + "context" + "fmt" + "io" + "net/http" + + "github.com/golang/gddo/httputil" + + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +func GetInfoRefsHandler(a *api.API) http.Handler { + return repoPreAuthorizeHandler(a, handleGetInfoRefs) +} + +func handleGetInfoRefs(rw http.ResponseWriter, r *http.Request, a *api.Response) { + responseWriter := NewHttpResponseWriter(rw) + // Log 0 bytes in because we ignore the request body (and there usually is none anyway). 
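As a rough illustration of the request this handler serves (the host, path, and header values here are made up), a Git client starts a fetch or push with a ref advertisement request along these lines; the service query parameter is what getService extracts below, and the Git-Protocol header is forwarded to Gitaly.

package example

import "net/http"

// buildInfoRefsRequest sketches the first request of the Git smart HTTP protocol.
func buildInfoRefsRequest() (*http.Request, error) {
	req, err := http.NewRequest("GET",
		"https://gitlab.example.com/group/project.git/info/refs?service=git-upload-pack", nil)
	if err != nil {
		return nil, err
	}
	// Modern Git clients advertise protocol v2 via this header.
	req.Header.Set("Git-Protocol", "version=2")
	return req, nil
}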
+ defer responseWriter.Log(r, 0) + + rpc := getService(r) + if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") { + // The 'dumb' Git HTTP protocol is not supported + http.Error(responseWriter, "Not Found", 404) + return + } + + responseWriter.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-advertisement", rpc)) + responseWriter.Header().Set("Cache-Control", "no-cache") + + gitProtocol := r.Header.Get("Git-Protocol") + + offers := []string{"gzip", "identity"} + encoding := httputil.NegotiateContentEncoding(r, offers) + + if err := handleGetInfoRefsWithGitaly(r.Context(), responseWriter, a, rpc, gitProtocol, encoding); err != nil { + helper.Fail500(responseWriter, r, fmt.Errorf("handleGetInfoRefs: %v", err)) + } +} + +func handleGetInfoRefsWithGitaly(ctx context.Context, responseWriter *HttpResponseWriter, a *api.Response, rpc, gitProtocol, encoding string) error { + ctx, smarthttp, err := gitaly.NewSmartHTTPClient(ctx, a.GitalyServer) + if err != nil { + return fmt.Errorf("GetInfoRefsHandler: %v", err) + } + + infoRefsResponseReader, err := smarthttp.InfoRefsResponseReader(ctx, &a.Repository, rpc, gitConfigOptions(a), gitProtocol) + if err != nil { + return fmt.Errorf("GetInfoRefsHandler: %v", err) + } + + var w io.Writer + + if encoding == "gzip" { + gzWriter := gzip.NewWriter(responseWriter) + w = gzWriter + defer gzWriter.Close() + + responseWriter.Header().Set("Content-Encoding", "gzip") + } else { + w = responseWriter + } + + if _, err = io.Copy(w, infoRefsResponseReader); err != nil { + log.WithError(err).Error("GetInfoRefsHandler: error copying gitaly response") + } + + return nil +} diff --git a/workhorse/internal/git/pktline.go b/workhorse/internal/git/pktline.go new file mode 100644 index 00000000000..e970f60182d --- /dev/null +++ b/workhorse/internal/git/pktline.go @@ -0,0 +1,59 @@ +package git + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strconv" +) + +func scanDeepen(body io.Reader) bool { + scanner := bufio.NewScanner(body) + scanner.Split(pktLineSplitter) + for scanner.Scan() { + if bytes.HasPrefix(scanner.Bytes(), []byte("deepen")) && scanner.Err() == nil { + return true + } + } + + return false +} + +func pktLineSplitter(data []byte, atEOF bool) (advance int, token []byte, err error) { + if len(data) < 4 { + if atEOF && len(data) > 0 { + return 0, nil, fmt.Errorf("pktLineSplitter: incomplete length prefix on %q", data) + } + return 0, nil, nil // want more data + } + + if bytes.HasPrefix(data, []byte("0000")) { + // special case: "0000" terminator packet: return empty token + return 4, data[:0], nil + } + + // We have at least 4 bytes available so we can decode the 4-hex digit + // length prefix of the packet line. 
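For illustration, decoding a single pkt-line by hand (the example packet is made up) follows the same arithmetic as the parsing below: the first four bytes are a hexadecimal length that includes the prefix itself, and "0000" is the flush packet.

package main

import (
	"fmt"
	"strconv"
)

func main() {
	pkt := "0009done\n" // "0009" = total length 9: 4 prefix bytes + 5 payload bytes
	length, err := strconv.ParseInt(pkt[:4], 16, 0)
	if err != nil {
		panic(err)
	}
	fmt.Printf("payload: %q\n", pkt[4:int(length)]) // payload: "done\n"
}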
+ pktLength64, err := strconv.ParseInt(string(data[:4]), 16, 0) + if err != nil { + return 0, nil, fmt.Errorf("pktLineSplitter: decode length: %v", err) + } + + // Cast is safe because we requested an int-size number from strconv.ParseInt + pktLength := int(pktLength64) + + if pktLength < 0 { + return 0, nil, fmt.Errorf("pktLineSplitter: invalid length: %d", pktLength) + } + + if len(data) < pktLength { + if atEOF { + return 0, nil, fmt.Errorf("pktLineSplitter: less than %d bytes in input %q", pktLength, data) + } + return 0, nil, nil // want more data + } + + // return "pkt" token without length prefix + return pktLength, data[4:pktLength], nil +} diff --git a/workhorse/internal/git/pktline_test.go b/workhorse/internal/git/pktline_test.go new file mode 100644 index 00000000000..d4be8634538 --- /dev/null +++ b/workhorse/internal/git/pktline_test.go @@ -0,0 +1,39 @@ +package git + +import ( + "bytes" + "testing" +) + +func TestSuccessfulScanDeepen(t *testing.T) { + examples := []struct { + input string + output bool + }{ + {"000dsomething000cdeepen 10000", true}, + {"000dsomething0000000cdeepen 1", true}, + {"000dsomething0000", false}, + } + + for _, example := range examples { + hasDeepen := scanDeepen(bytes.NewReader([]byte(example.input))) + + if hasDeepen != example.output { + t.Fatalf("scanDeepen %q: expected %v, got %v", example.input, example.output, hasDeepen) + } + } +} + +func TestFailedScanDeepen(t *testing.T) { + examples := []string{ + "invalid data", + "deepen", + "000cdeepen", + } + + for _, example := range examples { + if scanDeepen(bytes.NewReader([]byte(example))) { + t.Fatalf("scanDeepen %q: expected result to be false, got true", example) + } + } +} diff --git a/workhorse/internal/git/receive-pack.go b/workhorse/internal/git/receive-pack.go new file mode 100644 index 00000000000..e72d8be5174 --- /dev/null +++ b/workhorse/internal/git/receive-pack.go @@ -0,0 +1,33 @@ +package git + +import ( + "fmt" + "net/http" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +// Will not return a non-nil error after the response body has been +// written to. 
+func handleReceivePack(w *HttpResponseWriter, r *http.Request, a *api.Response) error { + action := getService(r) + writePostRPCHeader(w, action) + + cr, cw := helper.NewWriteAfterReader(r.Body, w) + defer cw.Flush() + + gitProtocol := r.Header.Get("Git-Protocol") + + ctx, smarthttp, err := gitaly.NewSmartHTTPClient(r.Context(), a.GitalyServer) + if err != nil { + return fmt.Errorf("smarthttp.ReceivePack: %v", err) + } + + if err := smarthttp.ReceivePack(ctx, &a.Repository, a.GL_ID, a.GL_USERNAME, a.GL_REPOSITORY, a.GitConfigOptions, cr, cw, gitProtocol); err != nil { + return fmt.Errorf("smarthttp.ReceivePack: %v", err) + } + + return nil +} diff --git a/workhorse/internal/git/responsewriter.go b/workhorse/internal/git/responsewriter.go new file mode 100644 index 00000000000..c4d4ac252d4 --- /dev/null +++ b/workhorse/internal/git/responsewriter.go @@ -0,0 +1,75 @@ +package git + +import ( + "net/http" + "strconv" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +const ( + directionIn = "in" + directionOut = "out" +) + +var ( + gitHTTPSessionsActive = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "gitlab_workhorse_git_http_sessions_active", + Help: "Number of Git HTTP request-response cycles currently being handled by gitlab-workhorse.", + }) + + gitHTTPRequests = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_git_http_requests", + Help: "How many Git HTTP requests have been processed by gitlab-workhorse, partitioned by request type and agent.", + }, + []string{"method", "code", "service", "agent"}, + ) + + gitHTTPBytes = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_git_http_bytes", + Help: "How many Git HTTP bytes have been sent by gitlab-workhorse, partitioned by request type, agent and direction.", + }, + []string{"method", "code", "service", "agent", "direction"}, + ) +) + +type HttpResponseWriter struct { + helper.CountingResponseWriter +} + +func NewHttpResponseWriter(rw http.ResponseWriter) *HttpResponseWriter { + gitHTTPSessionsActive.Inc() + return &HttpResponseWriter{ + CountingResponseWriter: helper.NewCountingResponseWriter(rw), + } +} + +func (w *HttpResponseWriter) Log(r *http.Request, writtenIn int64) { + service := getService(r) + agent := getRequestAgent(r) + + gitHTTPSessionsActive.Dec() + gitHTTPRequests.WithLabelValues(r.Method, strconv.Itoa(w.Status()), service, agent).Inc() + gitHTTPBytes.WithLabelValues(r.Method, strconv.Itoa(w.Status()), service, agent, directionIn). + Add(float64(writtenIn)) + gitHTTPBytes.WithLabelValues(r.Method, strconv.Itoa(w.Status()), service, agent, directionOut). 
+ Add(float64(w.Count())) +} + +func getRequestAgent(r *http.Request) string { + u, _, ok := r.BasicAuth() + if !ok { + return "anonymous" + } + + if u == "gitlab-ci-token" { + return "gitlab-ci" + } + + return "logged" +} diff --git a/workhorse/internal/git/snapshot.go b/workhorse/internal/git/snapshot.go new file mode 100644 index 00000000000..eb38becbd06 --- /dev/null +++ b/workhorse/internal/git/snapshot.go @@ -0,0 +1,64 @@ +package git + +import ( + "fmt" + "io" + "net/http" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata" +) + +type snapshot struct { + senddata.Prefix +} + +type snapshotParams struct { + GitalyServer gitaly.Server + GetSnapshotRequest string +} + +var ( + SendSnapshot = &snapshot{"git-snapshot:"} +) + +func (s *snapshot) Inject(w http.ResponseWriter, r *http.Request, sendData string) { + var params snapshotParams + + if err := s.Unpack(¶ms, sendData); err != nil { + helper.Fail500(w, r, fmt.Errorf("SendSnapshot: unpack sendData: %v", err)) + return + } + + request := &gitalypb.GetSnapshotRequest{} + if err := gitaly.UnmarshalJSON(params.GetSnapshotRequest, request); err != nil { + helper.Fail500(w, r, fmt.Errorf("SendSnapshot: unmarshal GetSnapshotRequest: %v", err)) + return + } + + ctx, c, err := gitaly.NewRepositoryClient(r.Context(), params.GitalyServer) + if err != nil { + helper.Fail500(w, r, fmt.Errorf("SendSnapshot: gitaly.NewRepositoryClient: %v", err)) + return + } + + reader, err := c.SnapshotReader(ctx, request) + if err != nil { + helper.Fail500(w, r, fmt.Errorf("SendSnapshot: client.SnapshotReader: %v", err)) + return + } + + w.Header().Del("Content-Length") + w.Header().Set("Content-Disposition", `attachment; filename="snapshot.tar"`) + w.Header().Set("Content-Type", "application/x-tar") + w.Header().Set("Content-Transfer-Encoding", "binary") + w.Header().Set("Cache-Control", "private") + w.WriteHeader(http.StatusOK) // Errors aren't detectable beyond this point + + if _, err := io.Copy(w, reader); err != nil { + helper.LogError(r, fmt.Errorf("SendSnapshot: copy gitaly output: %v", err)) + } +} diff --git a/workhorse/internal/git/upload-pack.go b/workhorse/internal/git/upload-pack.go new file mode 100644 index 00000000000..a3dbf2f2e02 --- /dev/null +++ b/workhorse/internal/git/upload-pack.go @@ -0,0 +1,57 @@ +package git + +import ( + "context" + "fmt" + "io" + "net/http" + "time" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +var ( + uploadPackTimeout = 10 * time.Minute +) + +// Will not return a non-nil error after the response body has been +// written to. +func handleUploadPack(w *HttpResponseWriter, r *http.Request, a *api.Response) error { + ctx := r.Context() + + // Prevent the client from holding the connection open indefinitely. A + // transfer rate of 17KiB/sec is sufficient to send 10MiB of data in + // ten minutes, which seems adequate. Most requests will be much smaller. + // This mitigates a use-after-check issue. 
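// (Sanity check on the figure above: 10 MiB in ten minutes is 10,240 KiB / 600 s ≈ 17 KiB/s.)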
+ // + // We can't reliably interrupt the read from a http handler, but we can + // ensure the request will (eventually) fail: https://github.com/golang/go/issues/16100 + readerCtx, cancel := context.WithTimeout(ctx, uploadPackTimeout) + defer cancel() + + limited := helper.NewContextReader(readerCtx, r.Body) + cr, cw := helper.NewWriteAfterReader(limited, w) + defer cw.Flush() + + action := getService(r) + writePostRPCHeader(w, action) + + gitProtocol := r.Header.Get("Git-Protocol") + + return handleUploadPackWithGitaly(ctx, a, cr, cw, gitProtocol) +} + +func handleUploadPackWithGitaly(ctx context.Context, a *api.Response, clientRequest io.Reader, clientResponse io.Writer, gitProtocol string) error { + ctx, smarthttp, err := gitaly.NewSmartHTTPClient(ctx, a.GitalyServer) + if err != nil { + return fmt.Errorf("smarthttp.UploadPack: %v", err) + } + + if err := smarthttp.UploadPack(ctx, &a.Repository, clientRequest, clientResponse, gitConfigOptions(a), gitProtocol); err != nil { + return fmt.Errorf("smarthttp.UploadPack: %v", err) + } + + return nil +} diff --git a/workhorse/internal/git/upload-pack_test.go b/workhorse/internal/git/upload-pack_test.go new file mode 100644 index 00000000000..c198939d5df --- /dev/null +++ b/workhorse/internal/git/upload-pack_test.go @@ -0,0 +1,85 @@ +package git + +import ( + "fmt" + "io/ioutil" + "net" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly" +) + +var ( + originalUploadPackTimeout = uploadPackTimeout +) + +type fakeReader struct { + n int + err error +} + +func (f *fakeReader) Read(b []byte) (int, error) { + return f.n, f.err +} + +type smartHTTPServiceServer struct { + gitalypb.UnimplementedSmartHTTPServiceServer + PostUploadPackFunc func(gitalypb.SmartHTTPService_PostUploadPackServer) error +} + +func (srv *smartHTTPServiceServer) PostUploadPack(s gitalypb.SmartHTTPService_PostUploadPackServer) error { + return srv.PostUploadPackFunc(s) +} + +func TestUploadPackTimesOut(t *testing.T) { + uploadPackTimeout = time.Millisecond + defer func() { uploadPackTimeout = originalUploadPackTimeout }() + + addr, cleanUp := startSmartHTTPServer(t, &smartHTTPServiceServer{ + PostUploadPackFunc: func(stream gitalypb.SmartHTTPService_PostUploadPackServer) error { + _, err := stream.Recv() // trigger a read on the client request body + require.NoError(t, err) + return nil + }, + }) + defer cleanUp() + + body := &fakeReader{n: 0, err: nil} + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", body) + a := &api.Response{GitalyServer: gitaly.Server{Address: addr}} + + err := handleUploadPack(NewHttpResponseWriter(w), r, a) + require.EqualError(t, err, "smarthttp.UploadPack: busyReader: context deadline exceeded") +} + +func startSmartHTTPServer(t testing.TB, s gitalypb.SmartHTTPServiceServer) (string, func()) { + tmp, err := ioutil.TempDir("", "") + require.NoError(t, err) + + socket := filepath.Join(tmp, "gitaly.sock") + ln, err := net.Listen("unix", socket) + require.NoError(t, err) + + srv := grpc.NewServer() + gitalypb.RegisterSmartHTTPServiceServer(srv, s) + go func() { + require.NoError(t, srv.Serve(ln)) + }() + + return fmt.Sprintf("%s://%s", ln.Addr().Network(), ln.Addr().String()), func() { + srv.GracefulStop() + require.NoError(t, os.RemoveAll(tmp), "error removing temp dir %q", tmp) + } 
+} diff --git a/workhorse/internal/gitaly/blob.go b/workhorse/internal/gitaly/blob.go new file mode 100644 index 00000000000..c6f5d6676f3 --- /dev/null +++ b/workhorse/internal/gitaly/blob.go @@ -0,0 +1,41 @@ +package gitaly + +import ( + "context" + "fmt" + "io" + "net/http" + "strconv" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "gitlab.com/gitlab-org/gitaly/streamio" +) + +type BlobClient struct { + gitalypb.BlobServiceClient +} + +func (client *BlobClient) SendBlob(ctx context.Context, w http.ResponseWriter, request *gitalypb.GetBlobRequest) error { + c, err := client.GetBlob(ctx, request) + if err != nil { + return fmt.Errorf("rpc failed: %v", err) + } + + firstResponseReceived := false + rr := streamio.NewReader(func() ([]byte, error) { + resp, err := c.Recv() + + if !firstResponseReceived && err == nil { + firstResponseReceived = true + w.Header().Set("Content-Length", strconv.FormatInt(resp.GetSize(), 10)) + } + + return resp.GetData(), err + }) + + if _, err := io.Copy(w, rr); err != nil { + return fmt.Errorf("copy rpc data: %v", err) + } + + return nil +} diff --git a/workhorse/internal/gitaly/diff.go b/workhorse/internal/gitaly/diff.go new file mode 100644 index 00000000000..035a58ec6fd --- /dev/null +++ b/workhorse/internal/gitaly/diff.go @@ -0,0 +1,55 @@ +package gitaly + +import ( + "context" + "fmt" + "io" + "net/http" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "gitlab.com/gitlab-org/gitaly/streamio" +) + +type DiffClient struct { + gitalypb.DiffServiceClient +} + +func (client *DiffClient) SendRawDiff(ctx context.Context, w http.ResponseWriter, request *gitalypb.RawDiffRequest) error { + c, err := client.RawDiff(ctx, request) + if err != nil { + return fmt.Errorf("rpc failed: %v", err) + } + + w.Header().Del("Content-Length") + + rr := streamio.NewReader(func() ([]byte, error) { + resp, err := c.Recv() + return resp.GetData(), err + }) + + if _, err := io.Copy(w, rr); err != nil { + return fmt.Errorf("copy rpc data: %v", err) + } + + return nil +} + +func (client *DiffClient) SendRawPatch(ctx context.Context, w http.ResponseWriter, request *gitalypb.RawPatchRequest) error { + c, err := client.RawPatch(ctx, request) + if err != nil { + return fmt.Errorf("rpc failed: %v", err) + } + + w.Header().Del("Content-Length") + + rr := streamio.NewReader(func() ([]byte, error) { + resp, err := c.Recv() + return resp.GetData(), err + }) + + if _, err := io.Copy(w, rr); err != nil { + return fmt.Errorf("copy rpc data: %v", err) + } + + return nil +} diff --git a/workhorse/internal/gitaly/gitaly.go b/workhorse/internal/gitaly/gitaly.go new file mode 100644 index 00000000000..c739ac8d9b2 --- /dev/null +++ b/workhorse/internal/gitaly/gitaly.go @@ -0,0 +1,188 @@ +package gitaly + +import ( + "context" + "strings" + "sync" + + "github.com/golang/protobuf/jsonpb" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/274 + "github.com/golang/protobuf/proto" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/274 + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + gitalyauth "gitlab.com/gitlab-org/gitaly/auth" + gitalyclient "gitlab.com/gitlab-org/gitaly/client" + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + grpccorrelation 
"gitlab.com/gitlab-org/labkit/correlation/grpc" + grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc" +) + +type Server struct { + Address string `json:"address"` + Token string `json:"token"` + Features map[string]string `json:"features"` +} + +type cacheKey struct{ address, token string } + +func (server Server) cacheKey() cacheKey { + return cacheKey{address: server.Address, token: server.Token} +} + +type connectionsCache struct { + sync.RWMutex + connections map[cacheKey]*grpc.ClientConn +} + +var ( + jsonUnMarshaler = jsonpb.Unmarshaler{AllowUnknownFields: true} + cache = connectionsCache{ + connections: make(map[cacheKey]*grpc.ClientConn), + } + + connectionsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_gitaly_connections_total", + Help: "Number of Gitaly connections that have been established", + }, + []string{"status"}, + ) +) + +func withOutgoingMetadata(ctx context.Context, features map[string]string) context.Context { + md := metadata.New(nil) + for k, v := range features { + if !strings.HasPrefix(k, "gitaly-feature-") { + continue + } + md.Append(k, v) + } + + return metadata.NewOutgoingContext(ctx, md) +} + +func NewSmartHTTPClient(ctx context.Context, server Server) (context.Context, *SmartHTTPClient, error) { + conn, err := getOrCreateConnection(server) + if err != nil { + return nil, nil, err + } + grpcClient := gitalypb.NewSmartHTTPServiceClient(conn) + return withOutgoingMetadata(ctx, server.Features), &SmartHTTPClient{grpcClient}, nil +} + +func NewBlobClient(ctx context.Context, server Server) (context.Context, *BlobClient, error) { + conn, err := getOrCreateConnection(server) + if err != nil { + return nil, nil, err + } + grpcClient := gitalypb.NewBlobServiceClient(conn) + return withOutgoingMetadata(ctx, server.Features), &BlobClient{grpcClient}, nil +} + +func NewRepositoryClient(ctx context.Context, server Server) (context.Context, *RepositoryClient, error) { + conn, err := getOrCreateConnection(server) + if err != nil { + return nil, nil, err + } + grpcClient := gitalypb.NewRepositoryServiceClient(conn) + return withOutgoingMetadata(ctx, server.Features), &RepositoryClient{grpcClient}, nil +} + +// NewNamespaceClient is only used by the Gitaly integration tests at present +func NewNamespaceClient(ctx context.Context, server Server) (context.Context, *NamespaceClient, error) { + conn, err := getOrCreateConnection(server) + if err != nil { + return nil, nil, err + } + grpcClient := gitalypb.NewNamespaceServiceClient(conn) + return withOutgoingMetadata(ctx, server.Features), &NamespaceClient{grpcClient}, nil +} + +func NewDiffClient(ctx context.Context, server Server) (context.Context, *DiffClient, error) { + conn, err := getOrCreateConnection(server) + if err != nil { + return nil, nil, err + } + grpcClient := gitalypb.NewDiffServiceClient(conn) + return withOutgoingMetadata(ctx, server.Features), &DiffClient{grpcClient}, nil +} + +func getOrCreateConnection(server Server) (*grpc.ClientConn, error) { + key := server.cacheKey() + + cache.RLock() + conn := cache.connections[key] + cache.RUnlock() + + if conn != nil { + return conn, nil + } + + cache.Lock() + defer cache.Unlock() + + if conn := cache.connections[key]; conn != nil { + return conn, nil + } + + conn, err := newConnection(server) + if err != nil { + return nil, err + } + + cache.connections[key] = conn + + return conn, nil +} + +func CloseConnections() { + cache.Lock() + defer cache.Unlock() + + for _, conn := range cache.connections { + conn.Close() + } +} + +func 
newConnection(server Server) (*grpc.ClientConn, error) { + connOpts := append(gitalyclient.DefaultDialOpts, + grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(server.Token)), + grpc.WithStreamInterceptor( + grpc_middleware.ChainStreamClient( + grpctracing.StreamClientTracingInterceptor(), + grpc_prometheus.StreamClientInterceptor, + grpccorrelation.StreamClientCorrelationInterceptor( + grpccorrelation.WithClientName("gitlab-workhorse"), + ), + ), + ), + + grpc.WithUnaryInterceptor( + grpc_middleware.ChainUnaryClient( + grpctracing.UnaryClientTracingInterceptor(), + grpc_prometheus.UnaryClientInterceptor, + grpccorrelation.UnaryClientCorrelationInterceptor( + grpccorrelation.WithClientName("gitlab-workhorse"), + ), + ), + ), + ) + + conn, connErr := gitalyclient.Dial(server.Address, connOpts) + + label := "ok" + if connErr != nil { + label = "fail" + } + connectionsTotal.WithLabelValues(label).Inc() + + return conn, connErr +} + +func UnmarshalJSON(s string, msg proto.Message) error { + return jsonUnMarshaler.Unmarshal(strings.NewReader(s), msg) +} diff --git a/workhorse/internal/gitaly/gitaly_test.go b/workhorse/internal/gitaly/gitaly_test.go new file mode 100644 index 00000000000..b17fb5c1d7b --- /dev/null +++ b/workhorse/internal/gitaly/gitaly_test.go @@ -0,0 +1,80 @@ +package gitaly + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" +) + +func TestNewSmartHTTPClient(t *testing.T) { + ctx, _, err := NewSmartHTTPClient(context.Background(), serverFixture()) + require.NoError(t, err) + testOutgoingMetadata(t, ctx) +} + +func TestNewBlobClient(t *testing.T) { + ctx, _, err := NewBlobClient(context.Background(), serverFixture()) + require.NoError(t, err) + testOutgoingMetadata(t, ctx) +} + +func TestNewRepositoryClient(t *testing.T) { + ctx, _, err := NewRepositoryClient(context.Background(), serverFixture()) + require.NoError(t, err) + testOutgoingMetadata(t, ctx) +} + +func TestNewNamespaceClient(t *testing.T) { + ctx, _, err := NewNamespaceClient(context.Background(), serverFixture()) + require.NoError(t, err) + testOutgoingMetadata(t, ctx) +} + +func TestNewDiffClient(t *testing.T) { + ctx, _, err := NewDiffClient(context.Background(), serverFixture()) + require.NoError(t, err) + testOutgoingMetadata(t, ctx) +} + +func testOutgoingMetadata(t *testing.T, ctx context.Context) { + md, ok := metadata.FromOutgoingContext(ctx) + require.True(t, ok, "get metadata from context") + + for k, v := range allowedFeatures() { + actual := md[k] + require.Len(t, actual, 1, "expect one value for %v", k) + require.Equal(t, v, actual[0], "value for %v", k) + } + + for k := range badFeatureMetadata() { + require.Empty(t, md[k], "value for bad key %v", k) + } +} + +func serverFixture() Server { + features := make(map[string]string) + for k, v := range allowedFeatures() { + features[k] = v + } + for k, v := range badFeatureMetadata() { + features[k] = v + } + + return Server{Address: "tcp://localhost:123", Features: features} +} + +func allowedFeatures() map[string]string { + return map[string]string{ + "gitaly-feature-foo": "bar", + "gitaly-feature-qux": "baz", + } +} + +func badFeatureMetadata() map[string]string { + return map[string]string{ + "bad-metadata-1": "bad-value-1", + "bad-metadata-2": "bad-value-2", + } +} diff --git a/workhorse/internal/gitaly/namespace.go b/workhorse/internal/gitaly/namespace.go new file mode 100644 index 00000000000..6db6ed4fc32 --- /dev/null +++ b/workhorse/internal/gitaly/namespace.go @@ -0,0 +1,8 @@ 
+package gitaly + +import "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + +// NamespaceClient encapsulates NamespaceService calls +type NamespaceClient struct { + gitalypb.NamespaceServiceClient +} diff --git a/workhorse/internal/gitaly/repository.go b/workhorse/internal/gitaly/repository.go new file mode 100644 index 00000000000..e3ec3257a85 --- /dev/null +++ b/workhorse/internal/gitaly/repository.go @@ -0,0 +1,45 @@ +package gitaly + +import ( + "context" + "fmt" + "io" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "gitlab.com/gitlab-org/gitaly/streamio" +) + +// RepositoryClient encapsulates RepositoryService calls +type RepositoryClient struct { + gitalypb.RepositoryServiceClient +} + +// ArchiveReader performs a GetArchive Gitaly request and returns an io.Reader +// for the response +func (client *RepositoryClient) ArchiveReader(ctx context.Context, request *gitalypb.GetArchiveRequest) (io.Reader, error) { + c, err := client.GetArchive(ctx, request) + if err != nil { + return nil, fmt.Errorf("RepositoryService::GetArchive: %v", err) + } + + return streamio.NewReader(func() ([]byte, error) { + resp, err := c.Recv() + + return resp.GetData(), err + }), nil +} + +// SnapshotReader performs a GetSnapshot Gitaly request and returns an io.Reader +// for the response +func (client *RepositoryClient) SnapshotReader(ctx context.Context, request *gitalypb.GetSnapshotRequest) (io.Reader, error) { + c, err := client.GetSnapshot(ctx, request) + if err != nil { + return nil, fmt.Errorf("RepositoryService::GetSnapshot: %v", err) + } + + return streamio.NewReader(func() ([]byte, error) { + resp, err := c.Recv() + + return resp.GetData(), err + }), nil +} diff --git a/workhorse/internal/gitaly/smarthttp.go b/workhorse/internal/gitaly/smarthttp.go new file mode 100644 index 00000000000..d1fe6fae5ba --- /dev/null +++ b/workhorse/internal/gitaly/smarthttp.go @@ -0,0 +1,139 @@ +package gitaly + +import ( + "context" + "fmt" + "io" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "gitlab.com/gitlab-org/gitaly/streamio" +) + +type SmartHTTPClient struct { + gitalypb.SmartHTTPServiceClient +} + +func (client *SmartHTTPClient) InfoRefsResponseReader(ctx context.Context, repo *gitalypb.Repository, rpc string, gitConfigOptions []string, gitProtocol string) (io.Reader, error) { + rpcRequest := &gitalypb.InfoRefsRequest{ + Repository: repo, + GitConfigOptions: gitConfigOptions, + GitProtocol: gitProtocol, + } + + switch rpc { + case "git-upload-pack": + stream, err := client.InfoRefsUploadPack(ctx, rpcRequest) + return infoRefsReader(stream), err + case "git-receive-pack": + stream, err := client.InfoRefsReceivePack(ctx, rpcRequest) + return infoRefsReader(stream), err + default: + return nil, fmt.Errorf("InfoRefsResponseWriterTo: Unsupported RPC: %q", rpc) + } +} + +type infoRefsClient interface { + Recv() (*gitalypb.InfoRefsResponse, error) +} + +func infoRefsReader(stream infoRefsClient) io.Reader { + return streamio.NewReader(func() ([]byte, error) { + resp, err := stream.Recv() + return resp.GetData(), err + }) +} + +func (client *SmartHTTPClient) ReceivePack(ctx context.Context, repo *gitalypb.Repository, glId string, glUsername string, glRepository string, gitConfigOptions []string, clientRequest io.Reader, clientResponse io.Writer, gitProtocol string) error { + stream, err := client.PostReceivePack(ctx) + if err != nil { + return err + } + + rpcRequest := &gitalypb.PostReceivePackRequest{ + Repository: repo, + GlId: glId, + GlUsername: glUsername, + GlRepository: glRepository, + 
GitConfigOptions: gitConfigOptions, + GitProtocol: gitProtocol, + } + + if err := stream.Send(rpcRequest); err != nil { + return fmt.Errorf("initial request: %v", err) + } + + numStreams := 2 + errC := make(chan error, numStreams) + + go func() { + rr := streamio.NewReader(func() ([]byte, error) { + response, err := stream.Recv() + return response.GetData(), err + }) + _, err := io.Copy(clientResponse, rr) + errC <- err + }() + + go func() { + sw := streamio.NewWriter(func(data []byte) error { + return stream.Send(&gitalypb.PostReceivePackRequest{Data: data}) + }) + _, err := io.Copy(sw, clientRequest) + stream.CloseSend() + errC <- err + }() + + for i := 0; i < numStreams; i++ { + if err := <-errC; err != nil { + return err + } + } + + return nil +} + +func (client *SmartHTTPClient) UploadPack(ctx context.Context, repo *gitalypb.Repository, clientRequest io.Reader, clientResponse io.Writer, gitConfigOptions []string, gitProtocol string) error { + stream, err := client.PostUploadPack(ctx) + if err != nil { + return err + } + + rpcRequest := &gitalypb.PostUploadPackRequest{ + Repository: repo, + GitConfigOptions: gitConfigOptions, + GitProtocol: gitProtocol, + } + + if err := stream.Send(rpcRequest); err != nil { + return fmt.Errorf("initial request: %v", err) + } + + numStreams := 2 + errC := make(chan error, numStreams) + + go func() { + rr := streamio.NewReader(func() ([]byte, error) { + response, err := stream.Recv() + return response.GetData(), err + }) + _, err := io.Copy(clientResponse, rr) + errC <- err + }() + + go func() { + sw := streamio.NewWriter(func(data []byte) error { + return stream.Send(&gitalypb.PostUploadPackRequest{Data: data}) + }) + _, err := io.Copy(sw, clientRequest) + stream.CloseSend() + errC <- err + }() + + for i := 0; i < numStreams; i++ { + if err := <-errC; err != nil { + return err + } + } + + return nil +} diff --git a/workhorse/internal/gitaly/unmarshal_test.go b/workhorse/internal/gitaly/unmarshal_test.go new file mode 100644 index 00000000000..e2256903339 --- /dev/null +++ b/workhorse/internal/gitaly/unmarshal_test.go @@ -0,0 +1,35 @@ +package gitaly + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" +) + +func TestUnmarshalJSON(t *testing.T) { + testCases := []struct { + desc string + in string + out gitalypb.Repository + }{ + { + desc: "basic example", + in: `{"relative_path":"foo/bar.git"}`, + out: gitalypb.Repository{RelativePath: "foo/bar.git"}, + }, + { + desc: "unknown field", + in: `{"relative_path":"foo/bar.git","unknown_field":12345}`, + out: gitalypb.Repository{RelativePath: "foo/bar.git"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result := gitalypb.Repository{} + require.NoError(t, UnmarshalJSON(tc.in, &result)) + require.Equal(t, tc.out, result) + }) + } +} diff --git a/workhorse/internal/headers/content_headers.go b/workhorse/internal/headers/content_headers.go new file mode 100644 index 00000000000..e43f10745d4 --- /dev/null +++ b/workhorse/internal/headers/content_headers.go @@ -0,0 +1,109 @@ +package headers + +import ( + "net/http" + "regexp" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/utils/svg" +) + +var ( + ImageTypeRegex = regexp.MustCompile(`^image/*`) + SvgMimeTypeRegex = regexp.MustCompile(`^image/svg\+xml$`) + + TextTypeRegex = regexp.MustCompile(`^text/*`) + + VideoTypeRegex = regexp.MustCompile(`^video/*`) + + PdfTypeRegex = regexp.MustCompile(`application\/pdf`) + + AttachmentRegex = 
regexp.MustCompile(`^attachment`) + InlineRegex = regexp.MustCompile(`^inline`) +) + +// Mime types that can't be inlined. Usually subtypes of main types +var forbiddenInlineTypes = []*regexp.Regexp{SvgMimeTypeRegex} + +// Mime types that can be inlined. We can add global types like "image/" or +// specific types like "text/plain". If there is a specific type inside a global +// allowed type that can't be inlined we must add it to the forbiddenInlineTypes var. +// One example of this is the mime type "image". We allow all images to be +// inlined except for SVGs. +var allowedInlineTypes = []*regexp.Regexp{ImageTypeRegex, TextTypeRegex, VideoTypeRegex, PdfTypeRegex} + +func SafeContentHeaders(data []byte, contentDisposition string) (string, string) { + contentType := safeContentType(data) + contentDisposition = safeContentDisposition(contentType, contentDisposition) + return contentType, contentDisposition +} + +func safeContentType(data []byte) string { + // Special case for svg because DetectContentType detects it as text + if svg.Is(data) { + return "image/svg+xml" + } + + // Override any existing Content-Type header from other ResponseWriters + contentType := http.DetectContentType(data) + + // If the content is text type, we set to plain, because we don't + // want to render it inline if they're html or javascript + if isType(contentType, TextTypeRegex) { + return "text/plain; charset=utf-8" + } + + return contentType +} + +func safeContentDisposition(contentType string, contentDisposition string) string { + // If the existing disposition is attachment we return that. This allow us + // to force a download from GitLab (ie: RawController) + if AttachmentRegex.MatchString(contentDisposition) { + return contentDisposition + } + + // Checks for mime types that are forbidden to be inline + for _, element := range forbiddenInlineTypes { + if isType(contentType, element) { + return attachmentDisposition(contentDisposition) + } + } + + // Checks for mime types allowed to be inline + for _, element := range allowedInlineTypes { + if isType(contentType, element) { + return inlineDisposition(contentDisposition) + } + } + + // Anything else is set to attachment + return attachmentDisposition(contentDisposition) +} + +func attachmentDisposition(contentDisposition string) string { + if contentDisposition == "" { + return "attachment" + } + + if InlineRegex.MatchString(contentDisposition) { + return InlineRegex.ReplaceAllString(contentDisposition, "attachment") + } + + return contentDisposition +} + +func inlineDisposition(contentDisposition string) string { + if contentDisposition == "" { + return "inline" + } + + if AttachmentRegex.MatchString(contentDisposition) { + return AttachmentRegex.ReplaceAllString(contentDisposition, "inline") + } + + return contentDisposition +} + +func isType(contentType string, mimeType *regexp.Regexp) bool { + return mimeType.MatchString(contentType) +} diff --git a/workhorse/internal/headers/headers.go b/workhorse/internal/headers/headers.go new file mode 100644 index 00000000000..63b39a6aa41 --- /dev/null +++ b/workhorse/internal/headers/headers.go @@ -0,0 +1,62 @@ +package headers + +import ( + "net/http" + "strconv" +) + +// Max number of bytes that http.DetectContentType needs to get the content type +// Fixme: Go back to 512 bytes once https://gitlab.com/gitlab-org/gitlab-workhorse/issues/208 +// has been merged +const MaxDetectSize = 4096 + +// HTTP Headers +const ( + ContentDispositionHeader = "Content-Disposition" + ContentTypeHeader = "Content-Type" + + // 
Workhorse related headers + GitlabWorkhorseSendDataHeader = "Gitlab-Workhorse-Send-Data" + XSendFileHeader = "X-Sendfile" + XSendFileTypeHeader = "X-Sendfile-Type" + + // Signal header that indicates Workhorse should detect and set the content headers + GitlabWorkhorseDetectContentTypeHeader = "Gitlab-Workhorse-Detect-Content-Type" +) + +var ResponseHeaders = []string{ + XSendFileHeader, + GitlabWorkhorseSendDataHeader, + GitlabWorkhorseDetectContentTypeHeader, +} + +func IsDetectContentTypeHeaderPresent(rw http.ResponseWriter) bool { + header, err := strconv.ParseBool(rw.Header().Get(GitlabWorkhorseDetectContentTypeHeader)) + if err != nil || !header { + return false + } + + return true +} + +// AnyResponseHeaderPresent checks in the ResponseWriter if there is any Response Header +func AnyResponseHeaderPresent(rw http.ResponseWriter) bool { + // If this header is not present means that we want the old behavior + if !IsDetectContentTypeHeaderPresent(rw) { + return false + } + + for _, header := range ResponseHeaders { + if rw.Header().Get(header) != "" { + return true + } + } + return false +} + +// RemoveResponseHeaders removes any ResponseHeader from the ResponseWriter +func RemoveResponseHeaders(rw http.ResponseWriter) { + for _, header := range ResponseHeaders { + rw.Header().Del(header) + } +} diff --git a/workhorse/internal/headers/headers_test.go b/workhorse/internal/headers/headers_test.go new file mode 100644 index 00000000000..555406ff165 --- /dev/null +++ b/workhorse/internal/headers/headers_test.go @@ -0,0 +1,24 @@ +package headers + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsDetectContentTypeHeaderPresent(t *testing.T) { + rw := httptest.NewRecorder() + + rw.Header().Del(GitlabWorkhorseDetectContentTypeHeader) + require.Equal(t, false, IsDetectContentTypeHeaderPresent(rw)) + + rw.Header().Set(GitlabWorkhorseDetectContentTypeHeader, "true") + require.Equal(t, true, IsDetectContentTypeHeaderPresent(rw)) + + rw.Header().Set(GitlabWorkhorseDetectContentTypeHeader, "false") + require.Equal(t, false, IsDetectContentTypeHeaderPresent(rw)) + + rw.Header().Set(GitlabWorkhorseDetectContentTypeHeader, "foobar") + require.Equal(t, false, IsDetectContentTypeHeaderPresent(rw)) +} diff --git a/workhorse/internal/helper/context_reader.go b/workhorse/internal/helper/context_reader.go new file mode 100644 index 00000000000..a4764043147 --- /dev/null +++ b/workhorse/internal/helper/context_reader.go @@ -0,0 +1,40 @@ +package helper + +import ( + "context" + "io" +) + +type ContextReader struct { + ctx context.Context + underlyingReader io.Reader +} + +func NewContextReader(ctx context.Context, underlyingReader io.Reader) *ContextReader { + return &ContextReader{ + ctx: ctx, + underlyingReader: underlyingReader, + } +} + +func (r *ContextReader) Read(b []byte) (int, error) { + if r.canceled() { + return 0, r.err() + } + + n, err := r.underlyingReader.Read(b) + + if r.canceled() { + err = r.err() + } + + return n, err +} + +func (r *ContextReader) canceled() bool { + return r.err() != nil +} + +func (r *ContextReader) err() error { + return r.ctx.Err() +} diff --git a/workhorse/internal/helper/context_reader_test.go b/workhorse/internal/helper/context_reader_test.go new file mode 100644 index 00000000000..257ec4e35f2 --- /dev/null +++ b/workhorse/internal/helper/context_reader_test.go @@ -0,0 +1,83 @@ +package helper + +import ( + "context" + "io" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type fakeReader struct 
{ + n int + err error +} + +func (f *fakeReader) Read(b []byte) (int, error) { + return f.n, f.err +} + +type fakeContextWithTimeout struct { + n int + threshold int +} + +func (*fakeContextWithTimeout) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*fakeContextWithTimeout) Done() <-chan struct{} { + return nil +} + +func (*fakeContextWithTimeout) Value(key interface{}) interface{} { + return nil +} + +func (f *fakeContextWithTimeout) Err() error { + f.n++ + if f.n > f.threshold { + return context.DeadlineExceeded + } + + return nil +} + +func TestContextReaderRead(t *testing.T) { + underlyingReader := &fakeReader{n: 1, err: io.EOF} + + for _, tc := range []struct { + desc string + ctx *fakeContextWithTimeout + expectedN int + expectedErr error + }{ + { + desc: "Before and after read deadline checks are fine", + ctx: &fakeContextWithTimeout{n: 0, threshold: 2}, + expectedN: underlyingReader.n, + expectedErr: underlyingReader.err, + }, + { + desc: "Before read deadline check fails", + ctx: &fakeContextWithTimeout{n: 0, threshold: 0}, + expectedN: 0, + expectedErr: context.DeadlineExceeded, + }, + { + desc: "After read deadline check fails", + ctx: &fakeContextWithTimeout{n: 0, threshold: 1}, + expectedN: underlyingReader.n, + expectedErr: context.DeadlineExceeded, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + cr := NewContextReader(tc.ctx, underlyingReader) + + n, err := cr.Read(nil) + require.Equal(t, tc.expectedN, n) + require.Equal(t, tc.expectedErr, err) + }) + } +} diff --git a/workhorse/internal/helper/countingresponsewriter.go b/workhorse/internal/helper/countingresponsewriter.go new file mode 100644 index 00000000000..a79d51d4c6a --- /dev/null +++ b/workhorse/internal/helper/countingresponsewriter.go @@ -0,0 +1,56 @@ +package helper + +import ( + "net/http" +) + +type CountingResponseWriter interface { + http.ResponseWriter + Count() int64 + Status() int +} + +type countingResponseWriter struct { + rw http.ResponseWriter + status int + count int64 +} + +func NewCountingResponseWriter(rw http.ResponseWriter) CountingResponseWriter { + return &countingResponseWriter{rw: rw} +} + +func (c *countingResponseWriter) Header() http.Header { + return c.rw.Header() +} + +func (c *countingResponseWriter) Write(data []byte) (int, error) { + if c.status == 0 { + c.WriteHeader(http.StatusOK) + } + + n, err := c.rw.Write(data) + c.count += int64(n) + return n, err +} + +func (c *countingResponseWriter) WriteHeader(status int) { + if c.status != 0 { + return + } + + c.status = status + c.rw.WriteHeader(status) +} + +// Count returns the number of bytes written to the ResponseWriter. This +// function is not thread-safe. +func (c *countingResponseWriter) Count() int64 { + return c.count +} + +// Status returns the first HTTP status value that was written to the +// ResponseWriter. This function is not thread-safe. 
+func (c *countingResponseWriter) Status() int { + return c.status +} diff --git a/workhorse/internal/helper/countingresponsewriter_test.go b/workhorse/internal/helper/countingresponsewriter_test.go new file mode 100644 index 00000000000..f9f2f4ced5b --- /dev/null +++ b/workhorse/internal/helper/countingresponsewriter_test.go @@ -0,0 +1,50 @@ +package helper + +import ( + "bytes" + "io" + "net/http" + "testing" + "testing/iotest" + + "github.com/stretchr/testify/require" +) + +type testResponseWriter struct { + data []byte +} + +func (*testResponseWriter) WriteHeader(int) {} +func (*testResponseWriter) Header() http.Header { return nil } + +func (trw *testResponseWriter) Write(p []byte) (int, error) { + trw.data = append(trw.data, p...) + return len(p), nil +} + +func TestCountingResponseWriterStatus(t *testing.T) { + crw := NewCountingResponseWriter(&testResponseWriter{}) + crw.WriteHeader(123) + crw.WriteHeader(456) + require.Equal(t, 123, crw.Status()) +} + +func TestCountingResponseWriterCount(t *testing.T) { + crw := NewCountingResponseWriter(&testResponseWriter{}) + for _, n := range []int{1, 2, 4, 8, 16, 32} { + _, err := crw.Write(bytes.Repeat([]byte{'.'}, n)) + require.NoError(t, err) + } + require.Equal(t, int64(63), crw.Count()) +} + +func TestCountingResponseWriterWrite(t *testing.T) { + trw := &testResponseWriter{} + crw := NewCountingResponseWriter(trw) + + testData := []byte("test data") + _, err := io.Copy(crw, iotest.OneByteReader(bytes.NewReader(testData))) + require.NoError(t, err) + + require.Equal(t, string(testData), string(trw.data)) +} diff --git a/workhorse/internal/helper/helpers.go b/workhorse/internal/helper/helpers.go new file mode 100644 index 00000000000..5f1e5fc51b3 --- /dev/null +++ b/workhorse/internal/helper/helpers.go @@ -0,0 +1,217 @@ +package helper + +import ( + "bytes" + "errors" + "io/ioutil" + "mime" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "strings" + "syscall" + + "github.com/sebest/xff" + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/mask" +) + +const NginxResponseBufferHeader = "X-Accel-Buffering" + +func LogError(r *http.Request, err error) { + LogErrorWithFields(r, err, nil) +} + +func LogErrorWithFields(r *http.Request, err error, fields log.Fields) { + if err != nil { + captureRavenError(r, err, fields) + } + + printError(r, err, fields) +} + +func CaptureAndFail(w http.ResponseWriter, r *http.Request, err error, msg string, code int) { + http.Error(w, msg, code) + LogError(r, err) +} + +func CaptureAndFailWithFields(w http.ResponseWriter, r *http.Request, err error, msg string, code int, fields log.Fields) { + http.Error(w, msg, code) + LogErrorWithFields(r, err, fields) +} + +func Fail500(w http.ResponseWriter, r *http.Request, err error) { + CaptureAndFail(w, r, err, "Internal server error", http.StatusInternalServerError) +} + +func Fail500WithFields(w http.ResponseWriter, r *http.Request, err error, fields log.Fields) { + CaptureAndFailWithFields(w, r, err, "Internal server error", http.StatusInternalServerError, fields) +} + +func RequestEntityTooLarge(w http.ResponseWriter, r *http.Request, err error) { + CaptureAndFail(w, r, err, "Request Entity Too Large", http.StatusRequestEntityTooLarge) +} + +func printError(r *http.Request, err error, fields log.Fields) { + if r != nil { + entry := log.WithContextFields(r.Context(), log.Fields{ + "method": r.Method, + "uri": mask.URL(r.RequestURI), + }) + entry.WithFields(fields).WithError(err).Error("error") + } else { + 
log.WithFields(fields).WithError(err).Error("unknown error") + } +} + +func SetNoCacheHeaders(header http.Header) { + header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") + header.Set("Pragma", "no-cache") + header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") +} + +func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) { + file, err = os.Open(path) + if err != nil { + return + } + + defer func() { + if err != nil { + file.Close() + } + }() + + fi, err = file.Stat() + if err != nil { + return + } + + // The os.Open can also open directories + if fi.IsDir() { + err = &os.PathError{ + Op: "open", + Path: path, + Err: errors.New("path is directory"), + } + return + } + + return +} + +func URLMustParse(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + log.WithError(err).WithField("url", s).Fatal("urlMustParse") + } + return u +} + +func HTTPError(w http.ResponseWriter, r *http.Request, error string, code int) { + if r.ProtoAtLeast(1, 1) { + // Force client to disconnect if we render request error + w.Header().Set("Connection", "close") + } + + http.Error(w, error, code) +} + +func HeaderClone(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +func CleanUpProcessGroup(cmd *exec.Cmd) { + if cmd == nil { + return + } + + process := cmd.Process + if process != nil && process.Pid > 0 { + // Send SIGTERM to the process group of cmd + syscall.Kill(-process.Pid, syscall.SIGTERM) + } + + // reap our child process + cmd.Wait() +} + +func ExitStatus(err error) (int, bool) { + exitError, ok := err.(*exec.ExitError) + if !ok { + return 0, false + } + + waitStatus, ok := exitError.Sys().(syscall.WaitStatus) + if !ok { + return 0, false + } + + return waitStatus.ExitStatus(), true +} + +func DisableResponseBuffering(w http.ResponseWriter) { + w.Header().Set(NginxResponseBufferHeader, "no") +} + +func AllowResponseBuffering(w http.ResponseWriter) { + w.Header().Del(NginxResponseBufferHeader) +} + +func FixRemoteAddr(r *http.Request) { + // Unix domain sockets have a remote addr of @. This will make the + // xff package lookup the X-Forwarded-For address if available. + if r.RemoteAddr == "@" { + r.RemoteAddr = "127.0.0.1:0" + } + r.RemoteAddr = xff.GetRemoteAddr(r) +} + +func SetForwardedFor(newHeaders *http.Header, originalRequest *http.Request) { + if clientIP, _, err := net.SplitHostPort(originalRequest.RemoteAddr); err == nil { + var header string + + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. 
+ if prior, ok := originalRequest.Header["X-Forwarded-For"]; ok { + header = strings.Join(prior, ", ") + ", " + clientIP + } else { + header = clientIP + } + newHeaders.Set("X-Forwarded-For", header) + } +} + +func IsContentType(expected, actual string) bool { + parsed, _, err := mime.ParseMediaType(actual) + return err == nil && parsed == expected +} + +func IsApplicationJson(r *http.Request) bool { + contentType := r.Header.Get("Content-Type") + return IsContentType("application/json", contentType) +} + +func ReadRequestBody(w http.ResponseWriter, r *http.Request, maxBodySize int64) ([]byte, error) { + limitedBody := http.MaxBytesReader(w, r.Body, maxBodySize) + defer limitedBody.Close() + + return ioutil.ReadAll(limitedBody) +} + +func CloneRequestWithNewBody(r *http.Request, body []byte) *http.Request { + newReq := *r + newReq.Body = ioutil.NopCloser(bytes.NewReader(body)) + newReq.Header = HeaderClone(r.Header) + newReq.ContentLength = int64(len(body)) + return &newReq +} diff --git a/workhorse/internal/helper/helpers_test.go b/workhorse/internal/helper/helpers_test.go new file mode 100644 index 00000000000..6a895aded03 --- /dev/null +++ b/workhorse/internal/helper/helpers_test.go @@ -0,0 +1,258 @@ +package helper + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" +) + +func TestFixRemoteAddr(t *testing.T) { + testCases := []struct { + initial string + forwarded string + expected string + }{ + {initial: "@", forwarded: "", expected: "127.0.0.1:0"}, + {initial: "@", forwarded: "18.245.0.1", expected: "18.245.0.1:0"}, + {initial: "@", forwarded: "127.0.0.1", expected: "127.0.0.1:0"}, + {initial: "@", forwarded: "192.168.0.1", expected: "127.0.0.1:0"}, + {initial: "192.168.1.1:0", forwarded: "", expected: "192.168.1.1:0"}, + {initial: "192.168.1.1:0", forwarded: "18.245.0.1", expected: "18.245.0.1:0"}, + } + + for _, tc := range testCases { + req, err := http.NewRequest("POST", "unix:///tmp/test.socket/info/refs", nil) + require.NoError(t, err) + + req.RemoteAddr = tc.initial + + if tc.forwarded != "" { + req.Header.Add("X-Forwarded-For", tc.forwarded) + } + + FixRemoteAddr(req) + + require.Equal(t, tc.expected, req.RemoteAddr) + } +} + +func TestSetForwardedForGeneratesHeader(t *testing.T) { + testCases := []struct { + remoteAddr string + previousForwardedFor []string + expected string + }{ + { + "8.8.8.8:3000", + nil, + "8.8.8.8", + }, + { + "8.8.8.8:3000", + []string{"138.124.33.63, 151.146.211.237"}, + "138.124.33.63, 151.146.211.237, 8.8.8.8", + }, + { + "8.8.8.8:3000", + []string{"8.154.76.107", "115.206.118.179"}, + "8.154.76.107, 115.206.118.179, 8.8.8.8", + }, + } + for _, tc := range testCases { + headers := http.Header{} + originalRequest := http.Request{ + RemoteAddr: tc.remoteAddr, + } + + if tc.previousForwardedFor != nil { + originalRequest.Header = http.Header{ + "X-Forwarded-For": tc.previousForwardedFor, + } + } + + SetForwardedFor(&headers, &originalRequest) + + result := headers.Get("X-Forwarded-For") + if result != tc.expected { + t.Fatalf("Expected %v, got %v", tc.expected, result) + } + } +} + +func TestReadRequestBody(t *testing.T) { + data := []byte("123456") + rw := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data)) + + result, err := ReadRequestBody(rw, req, 1000) + require.NoError(t, err) + require.Equal(t, data, result) +} + +func TestReadRequestBodyLimit(t *testing.T) { + data := []byte("123456") + rw := 
httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data)) + + _, err := ReadRequestBody(rw, req, 2) + require.Error(t, err) +} + +func TestCloneRequestWithBody(t *testing.T) { + input := []byte("test") + newInput := []byte("new body") + req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(input)) + newReq := CloneRequestWithNewBody(req, newInput) + + require.NotEqual(t, req, newReq) + require.NotEqual(t, req.Body, newReq.Body) + require.NotEqual(t, len(newInput), newReq.ContentLength) + + var buffer bytes.Buffer + io.Copy(&buffer, newReq.Body) + require.Equal(t, newInput, buffer.Bytes()) +} + +func TestApplicationJson(t *testing.T) { + req, _ := http.NewRequest("POST", "/test", nil) + req.Header.Set("Content-Type", "application/json") + + require.True(t, IsApplicationJson(req), "expected to match 'application/json' as 'application/json'") + + req.Header.Set("Content-Type", "application/json; charset=utf-8") + require.True(t, IsApplicationJson(req), "expected to match 'application/json; charset=utf-8' as 'application/json'") + + req.Header.Set("Content-Type", "text/plain") + require.False(t, IsApplicationJson(req), "expected not to match 'text/plain' as 'application/json'") +} + +func TestFail500WorksWithNils(t *testing.T) { + body := bytes.NewBuffer(nil) + w := httptest.NewRecorder() + w.Body = body + + Fail500(w, nil, nil) + + require.Equal(t, http.StatusInternalServerError, w.Code) + require.Equal(t, "Internal server error\n", body.String()) +} + +func TestLogError(t *testing.T) { + tests := []struct { + name string + method string + uri string + err error + logMatchers []string + }{ + { + name: "nil_request", + err: fmt.Errorf("crash"), + logMatchers: []string{ + `level=error msg="unknown error" error=crash`, + }, + }, + { + name: "nil_request_nil_error", + err: nil, + logMatchers: []string{ + `level=error msg="unknown error" error="<nil>"`, + }, + }, + { + name: "basic_url", + method: "GET", + uri: "http://localhost:3000/", + err: fmt.Errorf("error"), + logMatchers: []string{ + `level=error msg=error correlation_id= error=error method=GET uri="http://localhost:3000/"`, + }, + }, + { + name: "secret_url", + method: "GET", + uri: "http://localhost:3000/path?certificate=123&sharedSecret=123&import_url=the_url&my_password_string=password", + err: fmt.Errorf("error"), + logMatchers: []string{ + `level=error msg=error correlation_id= error=error method=GET uri="http://localhost:3000/path\?certificate=\[FILTERED\]&sharedSecret=\[FILTERED\]&import_url=\[FILTERED\]&my_password_string=\[FILTERED\]"`, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := &bytes.Buffer{} + + oldOut := logrus.StandardLogger().Out + logrus.StandardLogger().Out = buf + defer func() { + logrus.StandardLogger().Out = oldOut + }() + + var r *http.Request + if tt.uri != "" { + r = httptest.NewRequest(tt.method, tt.uri, nil) + } + + LogError(r, tt.err) + + logString := buf.String() + for _, v := range tt.logMatchers { + require.Regexp(t, v, logString) + } + }) + } +} + +func TestLogErrorWithFields(t *testing.T) { + tests := []struct { + name string + request *http.Request + err error + fields map[string]interface{} + logMatcher string + }{ + { + name: "nil_request", + err: fmt.Errorf("crash"), + fields: map[string]interface{}{"extra_one": 123}, + logMatcher: `level=error msg="unknown error" error=crash extra_one=123`, + }, + { + name: "nil_request_nil_error", + err: nil, + fields: map[string]interface{}{"extra_one": 123, "extra_two": "test"}, + 
logMatcher: `level=error msg="unknown error" error="<nil>" extra_one=123 extra_two=test`, + }, + { + name: "basic_url", + request: httptest.NewRequest("GET", "http://localhost:3000/", nil), + err: fmt.Errorf("error"), + fields: map[string]interface{}{"extra_one": 123, "extra_two": "test"}, + logMatcher: `level=error msg=error correlation_id= error=error extra_one=123 extra_two=test method=GET uri="http://localhost:3000/`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := &bytes.Buffer{} + + oldOut := logrus.StandardLogger().Out + logrus.StandardLogger().Out = buf + defer func() { + logrus.StandardLogger().Out = oldOut + }() + + LogErrorWithFields(tt.request, tt.err, tt.fields) + + logString := buf.String() + require.Contains(t, logString, tt.logMatcher) + }) + } +} diff --git a/workhorse/internal/helper/raven.go b/workhorse/internal/helper/raven.go new file mode 100644 index 00000000000..ea1d0e1f6cc --- /dev/null +++ b/workhorse/internal/helper/raven.go @@ -0,0 +1,58 @@ +package helper + +import ( + "net/http" + "reflect" + + raven "github.com/getsentry/raven-go" + + //lint:ignore SA1019 this was recently deprecated. Update workhorse to use labkit errortracking package. + correlation "gitlab.com/gitlab-org/labkit/correlation/raven" + + "gitlab.com/gitlab-org/labkit/log" +) + +var ravenHeaderBlacklist = []string{ + "Authorization", + "Private-Token", +} + +func captureRavenError(r *http.Request, err error, fields log.Fields) { + client := raven.DefaultClient + extra := raven.Extra{} + + for k, v := range fields { + extra[k] = v + } + + interfaces := []raven.Interface{} + if r != nil { + CleanHeadersForRaven(r) + interfaces = append(interfaces, raven.NewHttp(r)) + + //lint:ignore SA1019 this was recently deprecated. Update workhorse to use labkit errortracking package. + extra = correlation.SetExtra(r.Context(), extra) + } + + exception := &raven.Exception{ + Stacktrace: raven.NewStacktrace(2, 3, nil), + Value: err.Error(), + Type: reflect.TypeOf(err).String(), + } + interfaces = append(interfaces, exception) + + packet := raven.NewPacketWithExtra(err.Error(), extra, interfaces...) 
+ client.Capture(packet, nil) +} + +func CleanHeadersForRaven(r *http.Request) { + if r == nil { + return + } + + for _, key := range ravenHeaderBlacklist { + if r.Header.Get(key) != "" { + r.Header.Set(key, "[redacted]") + } + } +} diff --git a/workhorse/internal/helper/tempfile.go b/workhorse/internal/helper/tempfile.go new file mode 100644 index 00000000000..d8fc0d44698 --- /dev/null +++ b/workhorse/internal/helper/tempfile.go @@ -0,0 +1,35 @@ +package helper + +import ( + "io" + "io/ioutil" + "os" +) + +func ReadAllTempfile(r io.Reader) (tempfile *os.File, err error) { + tempfile, err = ioutil.TempFile("", "gitlab-workhorse-read-all-tempfile") + if err != nil { + return nil, err + } + + defer func() { + // Avoid leaking an open file if the function returns with an error + if err != nil { + tempfile.Close() + } + }() + + if err := os.Remove(tempfile.Name()); err != nil { + return nil, err + } + + if _, err := io.Copy(tempfile, r); err != nil { + return nil, err + } + + if _, err := tempfile.Seek(0, 0); err != nil { + return nil, err + } + + return tempfile, nil +} diff --git a/workhorse/internal/helper/writeafterreader.go b/workhorse/internal/helper/writeafterreader.go new file mode 100644 index 00000000000..d583ae4a9b8 --- /dev/null +++ b/workhorse/internal/helper/writeafterreader.go @@ -0,0 +1,144 @@ +package helper + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "sync" +) + +type WriteFlusher interface { + io.Writer + Flush() error +} + +// Couple r and w so that until r has been drained (before r.Read() has +// returned some error), all writes to w are sent to a tempfile first. +// The caller must call Flush() on the returned WriteFlusher to ensure +// all data is propagated to w. +func NewWriteAfterReader(r io.Reader, w io.Writer) (io.Reader, WriteFlusher) { + br := &busyReader{Reader: r} + return br, &coupledWriter{Writer: w, busyReader: br} +} + +type busyReader struct { + io.Reader + + error + errorMutex sync.RWMutex +} + +func (r *busyReader) Read(p []byte) (int, error) { + if err := r.getError(); err != nil { + return 0, err + } + + n, err := r.Reader.Read(p) + if err != nil { + if err != io.EOF { + err = fmt.Errorf("busyReader: %v", err) + } + r.setError(err) + } + return n, err +} + +func (r *busyReader) IsBusy() bool { + return r.getError() == nil +} + +func (r *busyReader) getError() error { + r.errorMutex.RLock() + defer r.errorMutex.RUnlock() + return r.error +} + +func (r *busyReader) setError(err error) { + if err == nil { + panic("busyReader: attempt to reset error to nil") + } + r.errorMutex.Lock() + defer r.errorMutex.Unlock() + r.error = err +} + +type coupledWriter struct { + io.Writer + *busyReader + + tempfile *os.File + tempfileMutex sync.Mutex + + writeError error +} + +func (w *coupledWriter) Write(data []byte) (int, error) { + if w.writeError != nil { + return 0, w.writeError + } + + if w.busyReader.IsBusy() { + n, err := w.tempfileWrite(data) + if err != nil { + w.writeError = fmt.Errorf("coupledWriter: %v", err) + } + return n, w.writeError + } + + if err := w.Flush(); err != nil { + w.writeError = fmt.Errorf("coupledWriter: %v", err) + return 0, w.writeError + } + + return w.Writer.Write(data) +} + +func (w *coupledWriter) Flush() error { + w.tempfileMutex.Lock() + defer w.tempfileMutex.Unlock() + + tempfile := w.tempfile + if tempfile == nil { + return nil + } + + w.tempfile = nil + defer tempfile.Close() + + if _, err := tempfile.Seek(0, 0); err != nil { + return err + } + if _, err := io.Copy(w.Writer, tempfile); err != nil { + return err + } + 
return nil +} + +func (w *coupledWriter) tempfileWrite(data []byte) (int, error) { + w.tempfileMutex.Lock() + defer w.tempfileMutex.Unlock() + + if w.tempfile == nil { + tempfile, err := w.newTempfile() + if err != nil { + return 0, err + } + w.tempfile = tempfile + } + + return w.tempfile.Write(data) +} + +func (*coupledWriter) newTempfile() (tempfile *os.File, err error) { + tempfile, err = ioutil.TempFile("", "gitlab-workhorse-coupledWriter") + if err != nil { + return nil, err + } + if err := os.Remove(tempfile.Name()); err != nil { + tempfile.Close() + return nil, err + } + + return tempfile, nil +} diff --git a/workhorse/internal/helper/writeafterreader_test.go b/workhorse/internal/helper/writeafterreader_test.go new file mode 100644 index 00000000000..67cb3e6e542 --- /dev/null +++ b/workhorse/internal/helper/writeafterreader_test.go @@ -0,0 +1,115 @@ +package helper + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "testing" + "testing/iotest" +) + +func TestBusyReader(t *testing.T) { + testData := "test data" + r := testReader(testData) + br, _ := NewWriteAfterReader(r, &bytes.Buffer{}) + + result, err := ioutil.ReadAll(br) + if err != nil { + t.Fatal(err) + } + + if string(result) != testData { + t.Fatalf("expected %q, got %q", testData, result) + } +} + +func TestFirstWriteAfterReadDone(t *testing.T) { + writeRecorder := &bytes.Buffer{} + br, cw := NewWriteAfterReader(&bytes.Buffer{}, writeRecorder) + if _, err := io.Copy(ioutil.Discard, br); err != nil { + t.Fatalf("copy from busyreader: %v", err) + } + testData := "test data" + if _, err := io.Copy(cw, testReader(testData)); err != nil { + t.Fatalf("copy test data: %v", err) + } + if err := cw.Flush(); err != nil { + t.Fatalf("flush error: %v", err) + } + if result := writeRecorder.String(); result != testData { + t.Fatalf("expected %q, got %q", testData, result) + } +} + +func TestWriteDelay(t *testing.T) { + writeRecorder := &bytes.Buffer{} + w := &complainingWriter{Writer: writeRecorder} + br, cw := NewWriteAfterReader(&bytes.Buffer{}, w) + + testData1 := "1 test" + if _, err := io.Copy(cw, testReader(testData1)); err != nil { + t.Fatalf("error on first copy: %v", err) + } + + // Unblock the coupled writer by draining the reader + if _, err := io.Copy(ioutil.Discard, br); err != nil { + t.Fatalf("copy from busyreader: %v", err) + } + // Now it is no longer an error if 'w' receives a Write() + w.CheerUp() + + testData2 := "2 experiment" + if _, err := io.Copy(cw, testReader(testData2)); err != nil { + t.Fatalf("error on second copy: %v", err) + } + + if err := cw.Flush(); err != nil { + t.Fatalf("flush error: %v", err) + } + + expected := testData1 + testData2 + if result := writeRecorder.String(); result != expected { + t.Fatalf("total write: expected %q, got %q", expected, result) + } +} + +func TestComplainingWriterSanity(t *testing.T) { + recorder := &bytes.Buffer{} + w := &complainingWriter{Writer: recorder} + + testData := "test data" + if _, err := io.Copy(w, testReader(testData)); err == nil { + t.Error("error expected, none received") + } + + w.CheerUp() + if _, err := io.Copy(w, testReader(testData)); err != nil { + t.Errorf("copy after CheerUp: %v", err) + } + + if result := recorder.String(); result != testData { + t.Errorf("expected %q, got %q", testData, result) + } +} + +func testReader(data string) io.Reader { + return iotest.OneByteReader(bytes.NewBuffer([]byte(data))) +} + +type complainingWriter struct { + happy bool + io.Writer +} + +func (comp *complainingWriter) Write(data []byte) (int, error) { + if 
comp.happy { + return comp.Writer.Write(data) + } + + return 0, fmt.Errorf("I am unhappy about you wanting to write %q", data) +} + +func (comp *complainingWriter) CheerUp() { + comp.happy = true +} diff --git a/workhorse/internal/httprs/LICENSE b/workhorse/internal/httprs/LICENSE new file mode 100644 index 00000000000..58b9dd5ced1 --- /dev/null +++ b/workhorse/internal/httprs/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2015 Jean-François Bustarret + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE.
\ No newline at end of file diff --git a/workhorse/internal/httprs/README.md b/workhorse/internal/httprs/README.md new file mode 100644 index 00000000000..4f42489ab73 --- /dev/null +++ b/workhorse/internal/httprs/README.md @@ -0,0 +1,2 @@ +This directory contains a vendored copy of https://github.com/jfbus/httprs at commit SHA +b0af8319bb15446bbf29715477f841a49330a1e7. diff --git a/workhorse/internal/httprs/httprs.go b/workhorse/internal/httprs/httprs.go new file mode 100644 index 00000000000..a38230c1968 --- /dev/null +++ b/workhorse/internal/httprs/httprs.go @@ -0,0 +1,217 @@ +/* +Package httprs provides a ReadSeeker for http.Response.Body. + +Usage : + + resp, err := http.Get(url) + rs := httprs.NewHttpReadSeeker(resp) + defer rs.Close() + io.ReadFull(rs, buf) // reads the first bytes from the response body + rs.Seek(1024, 0) // moves the position, but does no range request + io.ReadFull(rs, buf) // does a range request and reads from the response body + +If you want use a specific http.Client for additional range requests : + rs := httprs.NewHttpReadSeeker(resp, client) +*/ +package httprs + +import ( + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + + "github.com/mitchellh/copystructure" +) + +const shortSeekBytes = 1024 + +// A HttpReadSeeker reads from a http.Response.Body. It can Seek +// by doing range requests. +type HttpReadSeeker struct { + c *http.Client + req *http.Request + res *http.Response + ctx context.Context + r io.ReadCloser + pos int64 + + Requests int +} + +var _ io.ReadCloser = (*HttpReadSeeker)(nil) +var _ io.Seeker = (*HttpReadSeeker)(nil) + +var ( + // ErrNoContentLength is returned by Seek when the initial http response did not include a Content-Length header + ErrNoContentLength = errors.New("header Content-Length was not set") + // ErrRangeRequestsNotSupported is returned by Seek and Read + // when the remote server does not allow range requests (Accept-Ranges was not set) + ErrRangeRequestsNotSupported = errors.New("range requests are not supported by the remote server") + // ErrInvalidRange is returned by Read when trying to read past the end of the file + ErrInvalidRange = errors.New("invalid range") + // ErrContentHasChanged is returned by Read when the content has changed since the first request + ErrContentHasChanged = errors.New("content has changed since first request") +) + +// NewHttpReadSeeker returns a HttpReadSeeker, using the http.Response and, optionaly, the http.Client +// that needs to be used for future range requests. If no http.Client is given, http.DefaultClient will +// be used. +// +// res.Request will be reused for range requests, headers may be added/removed +func NewHttpReadSeeker(res *http.Response, client ...*http.Client) *HttpReadSeeker { + r := &HttpReadSeeker{ + req: res.Request, + ctx: res.Request.Context(), + res: res, + r: res.Body, + } + if len(client) > 0 { + r.c = client[0] + } else { + r.c = http.DefaultClient + } + return r +} + +// Clone clones the reader to enable parallel downloads of ranges +func (r *HttpReadSeeker) Clone() (*HttpReadSeeker, error) { + req, err := copystructure.Copy(r.req) + if err != nil { + return nil, err + } + return &HttpReadSeeker{ + req: req.(*http.Request), + res: r.res, + r: nil, + c: r.c, + }, nil +} + +// Read reads from the response body. It does a range request if Seek was called before. 
+// +// May return ErrRangeRequestsNotSupported, ErrInvalidRange or ErrContentHasChanged +func (r *HttpReadSeeker) Read(p []byte) (n int, err error) { + if r.r == nil { + err = r.rangeRequest() + } + if r.r != nil { + n, err = r.r.Read(p) + r.pos += int64(n) + } + return +} + +// ReadAt reads from the response body starting at offset off. +// +// May return ErrRangeRequestsNotSupported, ErrInvalidRange or ErrContentHasChanged +func (r *HttpReadSeeker) ReadAt(p []byte, off int64) (n int, err error) { + var nn int + + r.Seek(off, 0) + + for n < len(p) && err == nil { + nn, err = r.Read(p[n:]) + n += nn + } + return +} + +// Close closes the response body +func (r *HttpReadSeeker) Close() error { + if r.r != nil { + return r.r.Close() + } + return nil +} + +// Seek moves the reader position to a new offset. +// +// It does not send http requests, allowing for multiple seeks without overhead. +// The http request will be sent by the next Read call. +// +// May return ErrNoContentLength or ErrRangeRequestsNotSupported +func (r *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) { + var err error + switch whence { + case 0: + case 1: + offset += r.pos + case 2: + if r.res.ContentLength <= 0 { + return 0, ErrNoContentLength + } + offset = r.res.ContentLength - offset + } + if r.r != nil { + // Try to read, which is cheaper than doing a request + if r.pos < offset && offset-r.pos <= shortSeekBytes { + _, err := io.CopyN(ioutil.Discard, r, offset-r.pos) + if err != nil { + return 0, err + } + } + + if r.pos != offset { + err = r.r.Close() + r.r = nil + } + } + r.pos = offset + return r.pos, err +} + +func cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +func (r *HttpReadSeeker) newRequest() *http.Request { + newreq := r.req.WithContext(r.ctx) // includes shallow copies of maps, but okay + if r.req.ContentLength == 0 { + newreq.Body = nil // Issue 16036: nil Body for http.Transport retries + } + newreq.Header = cloneHeader(r.req.Header) + return newreq +} + +func (r *HttpReadSeeker) rangeRequest() error { + r.req = r.newRequest() + r.req.Header.Set("Range", fmt.Sprintf("bytes=%d-", r.pos)) + etag, last := r.res.Header.Get("ETag"), r.res.Header.Get("Last-Modified") + switch { + case last != "": + r.req.Header.Set("If-Range", last) + case etag != "": + r.req.Header.Set("If-Range", etag) + } + + r.Requests++ + + res, err := r.c.Do(r.req) + if err != nil { + return err + } + switch res.StatusCode { + case http.StatusRequestedRangeNotSatisfiable: + return ErrInvalidRange + case http.StatusOK: + // some servers return 200 OK for bytes=0- + if r.pos > 0 || + (etag != "" && etag != res.Header.Get("ETag")) { + return ErrContentHasChanged + } + fallthrough + case http.StatusPartialContent: + r.r = res.Body + return nil + } + return ErrRangeRequestsNotSupported +} diff --git a/workhorse/internal/httprs/httprs_test.go b/workhorse/internal/httprs/httprs_test.go new file mode 100644 index 00000000000..62279d895c9 --- /dev/null +++ b/workhorse/internal/httprs/httprs_test.go @@ -0,0 +1,257 @@ +package httprs + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + . 
"github.com/smartystreets/goconvey/convey" +) + +type fakeResponseWriter struct { + code int + h http.Header + tmp *os.File +} + +func (f *fakeResponseWriter) Header() http.Header { + return f.h +} + +func (f *fakeResponseWriter) Write(b []byte) (int, error) { + return f.tmp.Write(b) +} + +func (f *fakeResponseWriter) Close(b []byte) error { + return f.tmp.Close() +} + +func (f *fakeResponseWriter) WriteHeader(code int) { + f.code = code +} + +func (f *fakeResponseWriter) Response() *http.Response { + f.tmp.Seek(0, io.SeekStart) + return &http.Response{Body: f.tmp, StatusCode: f.code, Header: f.h} +} + +type fakeRoundTripper struct { + src *os.File + downgradeZeroToNoRange bool +} + +func (f *fakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + fw := &fakeResponseWriter{h: http.Header{}} + var err error + fw.tmp, err = ioutil.TempFile(os.TempDir(), "httprs") + if err != nil { + return nil, err + } + if f.downgradeZeroToNoRange { + // There are implementations that downgrades bytes=0- to a normal un-ranged GET + if r.Header.Get("Range") == "bytes=0-" { + r.Header.Del("Range") + } + } + http.ServeContent(fw, r, "temp.txt", time.Now(), f.src) + + return fw.Response(), nil +} + +const SZ = 4096 + +const ( + downgradeZeroToNoRange = 1 << iota + sendAcceptRanges +) + +type RSFactory func() *HttpReadSeeker + +func newRSFactory(flags int) RSFactory { + return func() *HttpReadSeeker { + tmp, err := ioutil.TempFile(os.TempDir(), "httprs") + if err != nil { + return nil + } + for i := 0; i < SZ; i++ { + tmp.WriteString(fmt.Sprintf("%04d", i)) + } + + req, err := http.NewRequest("GET", "http://www.example.com", nil) + if err != nil { + return nil + } + res := &http.Response{ + Request: req, + ContentLength: SZ * 4, + } + + if flags&sendAcceptRanges > 0 { + res.Header = http.Header{"Accept-Ranges": []string{"bytes"}} + } + + downgradeZeroToNoRange := (flags & downgradeZeroToNoRange) > 0 + return NewHttpReadSeeker(res, &http.Client{Transport: &fakeRoundTripper{src: tmp, downgradeZeroToNoRange: downgradeZeroToNoRange}}) + } +} + +func TestHttpWebServer(t *testing.T) { + Convey("Scenario: testing WebServer", t, func() { + dir, err := ioutil.TempDir("", "webserver") + So(err, ShouldBeNil) + defer os.RemoveAll(dir) + + err = ioutil.WriteFile(filepath.Join(dir, "file"), make([]byte, 10000), 0755) + So(err, ShouldBeNil) + + server := httptest.NewServer(http.FileServer(http.Dir(dir))) + + Convey("When requesting /file", func() { + res, err := http.Get(server.URL + "/file") + So(err, ShouldBeNil) + + stream := NewHttpReadSeeker(res) + So(stream, ShouldNotBeNil) + + Convey("Can read 100 bytes from start of file", func() { + n, err := stream.Read(make([]byte, 100)) + So(err, ShouldBeNil) + So(n, ShouldEqual, 100) + + Convey("When seeking 4KiB forward", func() { + pos, err := stream.Seek(4096, io.SeekCurrent) + So(err, ShouldBeNil) + So(pos, ShouldEqual, 4096+100) + + Convey("Can read 100 bytes", func() { + n, err := stream.Read(make([]byte, 100)) + So(err, ShouldBeNil) + So(n, ShouldEqual, 100) + }) + }) + }) + }) + }) +} + +func TestHttpReaderSeeker(t *testing.T) { + tests := []struct { + name string + newRS func() *HttpReadSeeker + }{ + {name: "with no flags", newRS: newRSFactory(0)}, + {name: "with only Accept-Ranges", newRS: newRSFactory(sendAcceptRanges)}, + {name: "downgrade 0-range to no range", newRS: newRSFactory(downgradeZeroToNoRange)}, + {name: "downgrade 0-range with Accept-Ranges", newRS: newRSFactory(downgradeZeroToNoRange | sendAcceptRanges)}, + } + + for _, test := range tests 
{ + t.Run(test.name, func(t *testing.T) { + testHttpReaderSeeker(t, test.newRS) + }) + } +} + +func testHttpReaderSeeker(t *testing.T, newRS RSFactory) { + Convey("Scenario: testing HttpReaderSeeker", t, func() { + + Convey("Read should start at the beginning", func() { + r := newRS() + So(r, ShouldNotBeNil) + defer r.Close() + buf := make([]byte, 4) + n, err := io.ReadFull(r, buf) + So(n, ShouldEqual, 4) + So(err, ShouldBeNil) + So(string(buf), ShouldEqual, "0000") + }) + + Convey("Seek w SEEK_SET should seek to right offset", func() { + r := newRS() + So(r, ShouldNotBeNil) + defer r.Close() + s, err := r.Seek(4*64, io.SeekStart) + So(s, ShouldEqual, 4*64) + So(err, ShouldBeNil) + buf := make([]byte, 4) + n, err := io.ReadFull(r, buf) + So(n, ShouldEqual, 4) + So(err, ShouldBeNil) + So(string(buf), ShouldEqual, "0064") + }) + + Convey("Read + Seek w SEEK_CUR should seek to right offset", func() { + r := newRS() + So(r, ShouldNotBeNil) + defer r.Close() + buf := make([]byte, 4) + io.ReadFull(r, buf) + s, err := r.Seek(4*64, os.SEEK_CUR) + So(s, ShouldEqual, 4*64+4) + So(err, ShouldBeNil) + n, err := io.ReadFull(r, buf) + So(n, ShouldEqual, 4) + So(err, ShouldBeNil) + So(string(buf), ShouldEqual, "0065") + }) + + Convey("Seek w SEEK_END should seek to right offset", func() { + r := newRS() + So(r, ShouldNotBeNil) + defer r.Close() + buf := make([]byte, 4) + io.ReadFull(r, buf) + s, err := r.Seek(4, os.SEEK_END) + So(s, ShouldEqual, SZ*4-4) + So(err, ShouldBeNil) + n, err := io.ReadFull(r, buf) + So(n, ShouldEqual, 4) + So(err, ShouldBeNil) + So(string(buf), ShouldEqual, fmt.Sprintf("%04d", SZ-1)) + }) + + Convey("Short seek should consume existing request", func() { + r := newRS() + So(r, ShouldNotBeNil) + defer r.Close() + buf := make([]byte, 4) + So(r.Requests, ShouldEqual, 0) + io.ReadFull(r, buf) + So(r.Requests, ShouldEqual, 1) + s, err := r.Seek(shortSeekBytes, os.SEEK_CUR) + So(r.Requests, ShouldEqual, 1) + So(s, ShouldEqual, shortSeekBytes+4) + So(err, ShouldBeNil) + n, err := io.ReadFull(r, buf) + So(n, ShouldEqual, 4) + So(err, ShouldBeNil) + So(string(buf), ShouldEqual, "0257") + So(r.Requests, ShouldEqual, 1) + }) + + Convey("Long seek should do a new request", func() { + r := newRS() + So(r, ShouldNotBeNil) + defer r.Close() + buf := make([]byte, 4) + So(r.Requests, ShouldEqual, 0) + io.ReadFull(r, buf) + So(r.Requests, ShouldEqual, 1) + s, err := r.Seek(shortSeekBytes+1, os.SEEK_CUR) + So(r.Requests, ShouldEqual, 1) + So(s, ShouldEqual, shortSeekBytes+4+1) + So(err, ShouldBeNil) + n, err := io.ReadFull(r, buf) + So(n, ShouldEqual, 4) + So(err, ShouldBeNil) + So(string(buf), ShouldEqual, "2570") + So(r.Requests, ShouldEqual, 2) + }) + }) +} diff --git a/workhorse/internal/imageresizer/image_resizer.go b/workhorse/internal/imageresizer/image_resizer.go new file mode 100644 index 00000000000..77318ed1c46 --- /dev/null +++ b/workhorse/internal/imageresizer/image_resizer.go @@ -0,0 +1,449 @@ +package imageresizer + +import ( + "bufio" + "context" + "fmt" + "io" + "net" + "net/http" + "os" + "os/exec" + "strconv" + "strings" + "sync/atomic" + "syscall" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "gitlab.com/gitlab-org/labkit/correlation" + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/mask" + "gitlab.com/gitlab-org/labkit/tracing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + 
"gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata" +) + +type Resizer struct { + config.Config + senddata.Prefix + numScalerProcs processCounter +} + +type resizeParams struct { + Location string + ContentType string + Width uint +} + +type processCounter struct { + n int32 +} + +type resizeStatus = string + +type imageFile struct { + reader io.ReadCloser + contentLength int64 + lastModified time.Time +} + +// Carries information about how the scaler succeeded or failed. +type resizeOutcome struct { + bytesWritten int64 + originalFileSize int64 + status resizeStatus + err error +} + +const ( + statusSuccess = "success" // a rescaled image was served + statusClientCache = "success-client-cache" // scaling was skipped because client cache was fresh + statusServedOriginal = "served-original" // scaling failed but the original image was served + statusRequestFailure = "request-failed" // no image was served + statusUnknown = "unknown" // indicates an unhandled status case +) + +var envInjector = tracing.NewEnvInjector() + +// Images might be located remotely in object storage, in which case we need to stream +// it via http(s) +var httpTransport = tracing.NewRoundTripper(correlation.NewInstrumentedRoundTripper(&http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 10 * time.Second, + }).DialContext, + MaxIdleConns: 2, + IdleConnTimeout: 30 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 10 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, +})) + +var httpClient = &http.Client{ + Transport: httpTransport, +} + +const ( + namespace = "gitlab_workhorse" + subsystem = "image_resize" + logSystem = "imageresizer" +) + +var ( + imageResizeConcurrencyLimitExceeds = promauto.NewCounter( + prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "concurrency_limit_exceeds_total", + Help: "Amount of image resizing requests that exceeded the maximum allowed scaler processes", + }, + ) + imageResizeProcesses = promauto.NewGauge( + prometheus.GaugeOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "processes", + Help: "Amount of image scaler processes working now", + }, + ) + imageResizeMaxProcesses = promauto.NewGauge( + prometheus.GaugeOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "max_processes", + Help: "The maximum amount of image scaler processes allowed to run concurrently", + }, + ) + imageResizeRequests = promauto.NewCounterVec( + prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "requests_total", + Help: "Image resizing operations requested", + }, + []string{"status"}, + ) + imageResizeDurations = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "duration_seconds", + Help: "Breakdown of total time spent serving successful image resizing requests (incl. 
data transfer)", + Buckets: []float64{ + 0.025, /* 25ms */ + 0.050, /* 50ms */ + 0.1, /* 100ms */ + 0.2, /* 200ms */ + 0.4, /* 400ms */ + 0.8, /* 800ms */ + }, + }, + []string{"content_type", "width"}, + ) +) + +const ( + jpegMagic = "\xff\xd8" // 2 bytes + pngMagic = "\x89PNG\r\n\x1a\n" // 8 bytes + maxMagicLen = 8 // 8 first bytes is enough to detect PNG or JPEG +) + +func NewResizer(cfg config.Config) *Resizer { + imageResizeMaxProcesses.Set(float64(cfg.ImageResizerConfig.MaxScalerProcs)) + + return &Resizer{Config: cfg, Prefix: "send-scaled-img:"} +} + +// Inject forks into a dedicated scaler process to resize an image identified by path or URL +// and streams the resized image back to the client +func (r *Resizer) Inject(w http.ResponseWriter, req *http.Request, paramsData string) { + var outcome = resizeOutcome{status: statusUnknown, originalFileSize: 0, bytesWritten: 0} + start := time.Now() + params, err := r.unpackParameters(paramsData) + + defer func() { + imageResizeRequests.WithLabelValues(outcome.status).Inc() + handleOutcome(w, req, start, params, &outcome) + }() + + if err != nil { + // This means the response header coming from Rails was malformed; there is no way + // to sensibly recover from this other than failing fast + outcome.error(fmt.Errorf("read image resize params: %v", err)) + return + } + + imageFile, err := openSourceImage(params.Location) + if err != nil { + // This means we cannot even read the input image; fail fast. + outcome.error(fmt.Errorf("open image data stream: %v", err)) + return + } + defer imageFile.reader.Close() + + outcome.originalFileSize = imageFile.contentLength + + setLastModified(w, imageFile.lastModified) + // If the original file has not changed, then any cached resized versions have not changed either. + if checkNotModified(req, imageFile.lastModified) { + writeNotModified(w) + outcome.ok(statusClientCache) + return + } + + // We first attempt to rescale the image; if this should fail for any reason, imageReader + // will point to the original image, i.e. we render it unchanged. + imageReader, resizeCmd, err := r.tryResizeImage(req, imageFile, params, r.Config.ImageResizerConfig) + if err != nil { + // Something failed, but we can still write out the original image, so don't return early. + // We need to log this separately since the subsequent steps might add other failures. + helper.LogErrorWithFields(req, err, *logFields(start, params, &outcome)) + } + defer helper.CleanUpProcessGroup(resizeCmd) + + w.Header().Del("Content-Length") + outcome.bytesWritten, err = serveImage(imageReader, w, resizeCmd) + + // We failed serving image data; this is a hard failure. + if err != nil { + outcome.error(err) + return + } + + // This means we served the original image because rescaling failed; this is a soft failure + if resizeCmd == nil { + outcome.ok(statusServedOriginal) + return + } + + widthLabelVal := strconv.Itoa(int(params.Width)) + imageResizeDurations.WithLabelValues(params.ContentType, widthLabelVal).Observe(time.Since(start).Seconds()) + + outcome.ok(statusSuccess) +} + +// Streams image data from the given reader to the given writer and returns the number of bytes written. +func serveImage(r io.Reader, w io.Writer, resizeCmd *exec.Cmd) (int64, error) { + bytesWritten, err := io.Copy(w, r) + if err != nil { + return bytesWritten, err + } + + if resizeCmd != nil { + // If a scaler process had been forked, wait for the command to finish. 
+ if err = resizeCmd.Wait(); err != nil { + // err will be an ExitError; this is not useful beyond knowing the exit code since anything + // interesting has been written to stderr, so we turn that into an error we can return. + stdErr := resizeCmd.Stderr.(*strings.Builder) + return bytesWritten, fmt.Errorf(stdErr.String()) + } + } + + return bytesWritten, nil +} + +func (r *Resizer) unpackParameters(paramsData string) (*resizeParams, error) { + var params resizeParams + if err := r.Unpack(¶ms, paramsData); err != nil { + return nil, err + } + + if params.Location == "" { + return nil, fmt.Errorf("'Location' not set") + } + + if params.ContentType == "" { + return nil, fmt.Errorf("'ContentType' must be set") + } + + return ¶ms, nil +} + +// Attempts to rescale the given image data, or in case of errors, falls back to the original image. +func (r *Resizer) tryResizeImage(req *http.Request, f *imageFile, params *resizeParams, cfg config.ImageResizerConfig) (io.Reader, *exec.Cmd, error) { + if f.contentLength > int64(cfg.MaxFilesize) { + return f.reader, nil, fmt.Errorf("%d bytes exceeds maximum file size of %d bytes", f.contentLength, cfg.MaxFilesize) + } + + if f.contentLength < maxMagicLen { + return f.reader, nil, fmt.Errorf("file is too small to resize: %d bytes", f.contentLength) + } + + if !r.numScalerProcs.tryIncrement(int32(cfg.MaxScalerProcs)) { + return f.reader, nil, fmt.Errorf("too many running scaler processes (%d / %d)", r.numScalerProcs.n, cfg.MaxScalerProcs) + } + + ctx := req.Context() + go func() { + <-ctx.Done() + r.numScalerProcs.decrement() + }() + + // Creating buffered Reader is required for us to Peek into first bytes of the image file to detect the format + // without advancing the reader (we need to read from the file start in the Scaler binary). + // We set `8` as the minimal buffer size by the length of PNG magic bytes sequence (JPEG needs only 2). + // In fact, `NewReaderSize` will immediately override it with `16` using its `minReadBufferSize` - + // here we are just being explicit about the buffer size required for our code to operate correctly. + // Having a reader with such tiny buffer will not hurt the performance during further operations, + // because Golang `bufio.Read` avoids double copy: https://golang.org/src/bufio/bufio.go?s=1768:1804#L212 + buffered := bufio.NewReaderSize(f.reader, maxMagicLen) + + headerBytes, err := buffered.Peek(maxMagicLen) + if err != nil { + return buffered, nil, fmt.Errorf("peek stream: %v", err) + } + + // Check magic bytes to identify file type. 
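+ // A PNG stream starts with the eight-byte signature \x89 'P' 'N' 'G' \r \n \x1a \n
+ // (pngMagic) and a JPEG with the two-byte start-of-image marker \xff \xd8 (jpegMagic);
+ // anything else is rejected here and the original file is served unscaled.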
+ if string(headerBytes) != pngMagic && string(headerBytes[0:2]) != jpegMagic { + return buffered, nil, fmt.Errorf("unrecognized file signature: %v", headerBytes) + } + + resizeCmd, resizedImageReader, err := startResizeImageCommand(ctx, buffered, params) + if err != nil { + return buffered, nil, fmt.Errorf("fork into scaler process: %w", err) + } + return resizedImageReader, resizeCmd, nil +} + +func startResizeImageCommand(ctx context.Context, imageReader io.Reader, params *resizeParams) (*exec.Cmd, io.ReadCloser, error) { + cmd := exec.CommandContext(ctx, "gitlab-resize-image") + cmd.Stdin = imageReader + cmd.Stderr = &strings.Builder{} + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + cmd.Env = []string{ + "GL_RESIZE_IMAGE_WIDTH=" + strconv.Itoa(int(params.Width)), + } + cmd.Env = envInjector(ctx, cmd.Env) + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, nil, err + } + + if err := cmd.Start(); err != nil { + return nil, nil, err + } + + return cmd, stdout, nil +} + +func isURL(location string) bool { + return strings.HasPrefix(location, "http://") || strings.HasPrefix(location, "https://") +} + +func openSourceImage(location string) (*imageFile, error) { + if isURL(location) { + return openFromURL(location) + } + + return openFromFile(location) +} + +func openFromURL(location string) (*imageFile, error) { + res, err := httpClient.Get(location) + if err != nil { + return nil, err + } + + switch res.StatusCode { + case http.StatusOK, http.StatusNotModified: + // Extract headers for conditional GETs from response. + lastModified, err := http.ParseTime(res.Header.Get("Last-Modified")) + if err != nil { + // This is unlikely to happen, coming from an object storage provider. + lastModified = time.Now().UTC() + } + return &imageFile{res.Body, res.ContentLength, lastModified}, nil + default: + res.Body.Close() + return nil, fmt.Errorf("stream data from %q: %d %s", location, res.StatusCode, res.Status) + } +} + +func openFromFile(location string) (*imageFile, error) { + file, err := os.Open(location) + if err != nil { + return nil, err + } + + fi, err := file.Stat() + if err != nil { + file.Close() + return nil, err + } + + return &imageFile{file, fi.Size(), fi.ModTime()}, nil +} + +// Only allow more scaling requests if we haven't yet reached the maximum +// allowed number of concurrent scaler processes +func (c *processCounter) tryIncrement(maxScalerProcs int32) bool { + if p := atomic.AddInt32(&c.n, 1); p > maxScalerProcs { + c.decrement() + imageResizeConcurrencyLimitExceeds.Inc() + + return false + } + + imageResizeProcesses.Set(float64(c.n)) + return true +} + +func (c *processCounter) decrement() { + atomic.AddInt32(&c.n, -1) + imageResizeProcesses.Set(float64(c.n)) +} + +func (o *resizeOutcome) ok(status resizeStatus) { + o.status = status + o.err = nil +} + +func (o *resizeOutcome) error(err error) { + o.status = statusRequestFailure + o.err = err +} + +func logFields(startTime time.Time, params *resizeParams, outcome *resizeOutcome) *log.Fields { + var targetWidth, contentType string + if params != nil { + targetWidth = fmt.Sprint(params.Width) + contentType = fmt.Sprint(params.ContentType) + } + return &log.Fields{ + "subsystem": logSystem, + "written_bytes": outcome.bytesWritten, + "duration_s": time.Since(startTime).Seconds(), + logSystem + ".status": outcome.status, + logSystem + ".target_width": targetWidth, + logSystem + ".content_type": contentType, + logSystem + ".original_filesize": outcome.originalFileSize, + } +} + +func handleOutcome(w 
http.ResponseWriter, req *http.Request, startTime time.Time, params *resizeParams, outcome *resizeOutcome) { + logger := log.ContextLogger(req.Context()) + fields := *logFields(startTime, params, outcome) + + switch outcome.status { + case statusRequestFailure: + if outcome.bytesWritten <= 0 { + helper.Fail500WithFields(w, req, outcome.err, fields) + } else { + helper.LogErrorWithFields(req, outcome.err, fields) + } + default: + logger.WithFields(fields).WithFields( + log.Fields{ + "method": req.Method, + "uri": mask.URL(req.RequestURI), + }, + ).Printf(outcome.status) + } +} diff --git a/workhorse/internal/imageresizer/image_resizer_caching.go b/workhorse/internal/imageresizer/image_resizer_caching.go new file mode 100644 index 00000000000..bbe0e3260d7 --- /dev/null +++ b/workhorse/internal/imageresizer/image_resizer_caching.go @@ -0,0 +1,44 @@ +// This file contains code derived from https://github.com/golang/go/blob/master/src/net/http/fs.go +// +// Copyright 2020 GitLab Inc. All rights reserved. +// Copyright 2009 The Go Authors. All rights reserved. + +package imageresizer + +import ( + "net/http" + "time" +) + +func checkNotModified(r *http.Request, modtime time.Time) bool { + ims := r.Header.Get("If-Modified-Since") + if ims == "" || isZeroTime(modtime) { + // Treat bogus times as if there was no such header at all + return false + } + t, err := http.ParseTime(ims) + if err != nil { + return false + } + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + return !modtime.Truncate(time.Second).After(t) +} + +// isZeroTime reports whether t is obviously unspecified (either zero or Unix epoch time). +func isZeroTime(t time.Time) bool { + return t.IsZero() || t.Equal(time.Unix(0, 0)) +} + +func setLastModified(w http.ResponseWriter, modtime time.Time) { + if !isZeroTime(modtime) { + w.Header().Set("Last-Modified", modtime.UTC().Format(http.TimeFormat)) + } +} + +func writeNotModified(w http.ResponseWriter) { + h := w.Header() + h.Del("Content-Type") + h.Del("Content-Length") + w.WriteHeader(http.StatusNotModified) +} diff --git a/workhorse/internal/imageresizer/image_resizer_test.go b/workhorse/internal/imageresizer/image_resizer_test.go new file mode 100644 index 00000000000..bacc97738b8 --- /dev/null +++ b/workhorse/internal/imageresizer/image_resizer_test.go @@ -0,0 +1,259 @@ +package imageresizer + +import ( + "encoding/base64" + "encoding/json" + "image" + "image/png" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" + + _ "image/jpeg" // need this for image.Decode with JPEG +) + +const imagePath = "../../testdata/image.png" + +func TestMain(m *testing.M) { + if err := testhelper.BuildExecutables(); err != nil { + log.WithError(err).Fatal() + } + + os.Exit(m.Run()) +} + +func requestScaledImage(t *testing.T, httpHeaders http.Header, params resizeParams, cfg config.ImageResizerConfig) *http.Response { + httpRequest := httptest.NewRequest("GET", "/image", nil) + if httpHeaders != nil { + httpRequest.Header = httpHeaders + } + responseWriter := httptest.NewRecorder() + paramsJSON := encodeParams(t, ¶ms) + + NewResizer(config.Config{ImageResizerConfig: cfg}).Inject(responseWriter, httpRequest, paramsJSON) + + return responseWriter.Result() +} + +func TestRequestScaledImageFromPath(t *testing.T) { + cfg := 
config.DefaultImageResizerConfig + + testCases := []struct { + desc string + imagePath string + contentType string + }{ + { + desc: "PNG", + imagePath: imagePath, + contentType: "image/png", + }, + { + desc: "JPEG", + imagePath: "../../testdata/image.jpg", + contentType: "image/jpeg", + }, + { + desc: "JPEG < 1kb", + imagePath: "../../testdata/image_single_pixel.jpg", + contentType: "image/jpeg", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + params := resizeParams{Location: tc.imagePath, ContentType: tc.contentType, Width: 64} + + resp := requestScaledImage(t, nil, params, cfg) + require.Equal(t, http.StatusOK, resp.StatusCode) + + bounds := imageFromResponse(t, resp).Bounds() + require.Equal(t, int(params.Width), bounds.Size().X, "wrong width after resizing") + }) + } +} + +func TestRequestScaledImageWithConditionalGetAndImageNotChanged(t *testing.T) { + cfg := config.DefaultImageResizerConfig + params := resizeParams{Location: imagePath, ContentType: "image/png", Width: 64} + + clientTime := testImageLastModified(t) + header := http.Header{} + header.Set("If-Modified-Since", httpTimeStr(clientTime)) + + resp := requestScaledImage(t, header, params, cfg) + require.Equal(t, http.StatusNotModified, resp.StatusCode) + require.Equal(t, httpTimeStr(testImageLastModified(t)), resp.Header.Get("Last-Modified")) + require.Empty(t, resp.Header.Get("Content-Type")) + require.Empty(t, resp.Header.Get("Content-Length")) +} + +func TestRequestScaledImageWithConditionalGetAndImageChanged(t *testing.T) { + cfg := config.DefaultImageResizerConfig + params := resizeParams{Location: imagePath, ContentType: "image/png", Width: 64} + + clientTime := testImageLastModified(t).Add(-1 * time.Second) + header := http.Header{} + header.Set("If-Modified-Since", httpTimeStr(clientTime)) + + resp := requestScaledImage(t, header, params, cfg) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, httpTimeStr(testImageLastModified(t)), resp.Header.Get("Last-Modified")) +} + +func TestRequestScaledImageWithConditionalGetAndInvalidClientTime(t *testing.T) { + cfg := config.DefaultImageResizerConfig + params := resizeParams{Location: imagePath, ContentType: "image/png", Width: 64} + + header := http.Header{} + header.Set("If-Modified-Since", "0") + + resp := requestScaledImage(t, header, params, cfg) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, httpTimeStr(testImageLastModified(t)), resp.Header.Get("Last-Modified")) +} + +func TestServeOriginalImageWhenSourceImageTooLarge(t *testing.T) { + originalImage := testImage(t) + cfg := config.ImageResizerConfig{MaxScalerProcs: 1, MaxFilesize: 1} + params := resizeParams{Location: imagePath, ContentType: "image/png", Width: 64} + + resp := requestScaledImage(t, nil, params, cfg) + require.Equal(t, http.StatusOK, resp.StatusCode) + + img := imageFromResponse(t, resp) + require.Equal(t, originalImage.Bounds(), img.Bounds(), "expected original image size") +} + +func TestFailFastOnOpenStreamFailure(t *testing.T) { + cfg := config.DefaultImageResizerConfig + params := resizeParams{Location: "does_not_exist.png", ContentType: "image/png", Width: 64} + resp := requestScaledImage(t, nil, params, cfg) + + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) +} + +func TestIgnoreContentTypeMismatchIfImageFormatIsAllowed(t *testing.T) { + cfg := config.DefaultImageResizerConfig + params := resizeParams{Location: imagePath, ContentType: "image/jpeg", Width: 64} + resp := requestScaledImage(t, nil, 
params, cfg) + require.Equal(t, http.StatusOK, resp.StatusCode) + + bounds := imageFromResponse(t, resp).Bounds() + require.Equal(t, int(params.Width), bounds.Size().X, "wrong width after resizing") +} + +func TestUnpackParametersReturnsParamsInstanceForValidInput(t *testing.T) { + r := Resizer{} + inParams := resizeParams{Location: imagePath, Width: 64, ContentType: "image/png"} + + outParams, err := r.unpackParameters(encodeParams(t, &inParams)) + + require.NoError(t, err, "unexpected error when unpacking params") + require.Equal(t, inParams, *outParams) +} + +func TestUnpackParametersReturnsErrorWhenLocationBlank(t *testing.T) { + r := Resizer{} + inParams := resizeParams{Location: "", Width: 64, ContentType: "image/jpg"} + + _, err := r.unpackParameters(encodeParams(t, &inParams)) + + require.Error(t, err, "expected error when Location is blank") +} + +func TestUnpackParametersReturnsErrorWhenContentTypeBlank(t *testing.T) { + r := Resizer{} + inParams := resizeParams{Location: imagePath, Width: 64, ContentType: ""} + + _, err := r.unpackParameters(encodeParams(t, &inParams)) + + require.Error(t, err, "expected error when ContentType is blank") +} + +func TestServeOriginalImageWhenSourceImageFormatIsNotAllowed(t *testing.T) { + cfg := config.DefaultImageResizerConfig + // SVG images are not allowed to be resized + svgImagePath := "../../testdata/image.svg" + svgImage, err := ioutil.ReadFile(svgImagePath) + require.NoError(t, err) + // ContentType is no longer used to perform the format validation. + // To make the test more strict, we'll use allowed, but incorrect ContentType. + params := resizeParams{Location: svgImagePath, ContentType: "image/png", Width: 64} + + resp := requestScaledImage(t, nil, params, cfg) + require.Equal(t, http.StatusOK, resp.StatusCode) + + responseData, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, svgImage, responseData, "expected original image") +} + +func TestServeOriginalImageWhenSourceImageIsTooSmall(t *testing.T) { + content := []byte("PNG") // 3 bytes only, invalid as PNG/JPEG image + + img, err := ioutil.TempFile("", "*.png") + require.NoError(t, err) + + defer img.Close() + defer os.Remove(img.Name()) + + _, err = img.Write(content) + require.NoError(t, err) + + cfg := config.DefaultImageResizerConfig + params := resizeParams{Location: img.Name(), ContentType: "image/png", Width: 64} + + resp := requestScaledImage(t, nil, params, cfg) + require.Equal(t, http.StatusOK, resp.StatusCode) + + responseData, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, content, responseData, "expected original image") +} + +// The Rails applications sends a Base64 encoded JSON string carrying +// these parameters in an HTTP response header +func encodeParams(t *testing.T, p *resizeParams) string { + json, err := json.Marshal(*p) + if err != nil { + require.NoError(t, err, "JSON encoder encountered unexpected error") + } + return base64.StdEncoding.EncodeToString(json) +} + +func testImage(t *testing.T) image.Image { + f, err := os.Open(imagePath) + require.NoError(t, err) + + image, err := png.Decode(f) + require.NoError(t, err, "decode original image") + + return image +} + +func testImageLastModified(t *testing.T) time.Time { + fi, err := os.Stat(imagePath) + require.NoError(t, err) + + return fi.ModTime() +} + +func imageFromResponse(t *testing.T, resp *http.Response) image.Image { + img, _, err := image.Decode(resp.Body) + require.NoError(t, err, "decode resized image") + return img +} + +func 
httpTimeStr(time time.Time) string { + return time.UTC().Format(http.TimeFormat) +} diff --git a/workhorse/internal/lfs/lfs.go b/workhorse/internal/lfs/lfs.go new file mode 100644 index 00000000000..ec48dc05ef9 --- /dev/null +++ b/workhorse/internal/lfs/lfs.go @@ -0,0 +1,55 @@ +/* +In this file we handle git lfs objects downloads and uploads +*/ + +package lfs + +import ( + "fmt" + "net/http" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload" +) + +type object struct { + size int64 + oid string +} + +func (l *object) Verify(fh *filestore.FileHandler) error { + if fh.Size != l.size { + return fmt.Errorf("LFSObject: expected size %d, wrote %d", l.size, fh.Size) + } + + if fh.SHA256() != l.oid { + return fmt.Errorf("LFSObject: expected sha256 %s, got %s", l.oid, fh.SHA256()) + } + + return nil +} + +type uploadPreparer struct { + objectPreparer upload.Preparer +} + +func NewLfsUploadPreparer(c config.Config, objectPreparer upload.Preparer) upload.Preparer { + return &uploadPreparer{objectPreparer: objectPreparer} +} + +func (l *uploadPreparer) Prepare(a *api.Response) (*filestore.SaveFileOpts, upload.Verifier, error) { + opts, _, err := l.objectPreparer.Prepare(a) + if err != nil { + return nil, nil, err + } + + opts.TempFilePrefix = a.LfsOid + + return opts, &object{oid: a.LfsOid, size: a.LfsSize}, nil +} + +func PutStore(a *api.API, h http.Handler, p upload.Preparer) http.Handler { + return upload.BodyUploader(a, h, p) +} diff --git a/workhorse/internal/lfs/lfs_test.go b/workhorse/internal/lfs/lfs_test.go new file mode 100644 index 00000000000..828ed1bfe90 --- /dev/null +++ b/workhorse/internal/lfs/lfs_test.go @@ -0,0 +1,61 @@ +package lfs_test + +import ( + "testing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/lfs" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload" + + "github.com/stretchr/testify/require" +) + +func TestLfsUploadPreparerWithConfig(t *testing.T) { + lfsOid := "abcd1234" + creds := config.S3Credentials{ + AwsAccessKeyID: "test-key", + AwsSecretAccessKey: "test-secret", + } + + c := config.Config{ + ObjectStorageCredentials: config.ObjectStorageCredentials{ + Provider: "AWS", + S3Credentials: creds, + }, + } + + r := &api.Response{ + LfsOid: lfsOid, + RemoteObject: api.RemoteObject{ + ID: "the upload ID", + UseWorkhorseClient: true, + ObjectStorage: &api.ObjectStorageParams{ + Provider: "AWS", + }, + }, + } + + uploadPreparer := upload.NewObjectStoragePreparer(c) + lfsPreparer := lfs.NewLfsUploadPreparer(c, uploadPreparer) + opts, verifier, err := lfsPreparer.Prepare(r) + + require.NoError(t, err) + require.Equal(t, lfsOid, opts.TempFilePrefix) + require.True(t, opts.ObjectStorageConfig.IsAWS()) + require.True(t, opts.UseWorkhorseClient) + require.Equal(t, creds, opts.ObjectStorageConfig.S3Credentials) + require.NotNil(t, verifier) +} + +func TestLfsUploadPreparerWithNoConfig(t *testing.T) { + c := config.Config{} + r := &api.Response{RemoteObject: api.RemoteObject{ID: "the upload ID"}} + uploadPreparer := upload.NewObjectStoragePreparer(c) + lfsPreparer := lfs.NewLfsUploadPreparer(c, uploadPreparer) + opts, verifier, err := lfsPreparer.Prepare(r) + + require.NoError(t, err) + require.False(t, opts.UseWorkhorseClient) + require.NotNil(t, verifier) +} diff --git 
a/workhorse/internal/lsif_transformer/parser/cache.go b/workhorse/internal/lsif_transformer/parser/cache.go new file mode 100644 index 00000000000..395069cd217 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/cache.go @@ -0,0 +1,56 @@ +package parser + +import ( + "encoding/binary" + "io" + "io/ioutil" + "os" +) + +// This cache implementation is using a temp file to provide key-value data storage +// It allows to avoid storing intermediate calculations in RAM +// The stored data must be a fixed-size value or a slice of fixed-size values, or a pointer to such data +type cache struct { + file *os.File + chunkSize int64 +} + +func newCache(tempDir, filename string, data interface{}) (*cache, error) { + f, err := ioutil.TempFile(tempDir, filename) + if err != nil { + return nil, err + } + + if err := os.Remove(f.Name()); err != nil { + return nil, err + } + + return &cache{file: f, chunkSize: int64(binary.Size(data))}, nil +} + +func (c *cache) SetEntry(id Id, data interface{}) error { + if err := c.setOffset(id); err != nil { + return err + } + + return binary.Write(c.file, binary.LittleEndian, data) +} + +func (c *cache) Entry(id Id, data interface{}) error { + if err := c.setOffset(id); err != nil { + return err + } + + return binary.Read(c.file, binary.LittleEndian, data) +} + +func (c *cache) Close() error { + return c.file.Close() +} + +func (c *cache) setOffset(id Id) error { + offset := int64(id) * c.chunkSize + _, err := c.file.Seek(offset, io.SeekStart) + + return err +} diff --git a/workhorse/internal/lsif_transformer/parser/cache_test.go b/workhorse/internal/lsif_transformer/parser/cache_test.go new file mode 100644 index 00000000000..23a2ac6e9a9 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/cache_test.go @@ -0,0 +1,33 @@ +package parser + +import ( + "io/ioutil" + "testing" + + "github.com/stretchr/testify/require" +) + +type chunk struct { + A int16 + B int16 +} + +func TestCache(t *testing.T) { + cache, err := newCache("", "test-chunks", chunk{}) + require.NoError(t, err) + defer cache.Close() + + c := chunk{A: 1, B: 2} + require.NoError(t, cache.SetEntry(1, &c)) + require.NoError(t, cache.setOffset(0)) + + content, err := ioutil.ReadAll(cache.file) + require.NoError(t, err) + + expected := []byte{0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x2, 0x0} + require.Equal(t, expected, content) + + var nc chunk + require.NoError(t, cache.Entry(1, &nc)) + require.Equal(t, c, nc) +} diff --git a/workhorse/internal/lsif_transformer/parser/code_hover.go b/workhorse/internal/lsif_transformer/parser/code_hover.go new file mode 100644 index 00000000000..dbdaba643d1 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/code_hover.go @@ -0,0 +1,124 @@ +package parser + +import ( + "encoding/json" + "strings" + "unicode/utf8" + + "github.com/alecthomas/chroma" + "github.com/alecthomas/chroma/lexers" +) + +const maxValueSize = 250 + +type token struct { + Class string `json:"class,omitempty"` + Value string `json:"value"` +} + +type codeHover struct { + TruncatedValue *truncatableString `json:"value,omitempty"` + Tokens [][]token `json:"tokens,omitempty"` + Language string `json:"language,omitempty"` + Truncated bool `json:"truncated,omitempty"` +} + +type truncatableString struct { + Value string + Truncated bool +} + +func (ts *truncatableString) UnmarshalText(b []byte) error { + s := 0 + for i := 0; s < len(b); i++ { + if i >= maxValueSize { + ts.Truncated = true + break + } + + _, size := utf8.DecodeRune(b[s:]) + + s += size + } + + ts.Value = string(b[0:s]) + + 
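+ // s advances by whole rune widths only, so the cut above always falls on a rune
+ // boundary: maxValueSize caps the number of runes kept, not bytes, and multi-byte
+ // characters are never split (see TestTruncatingMultiByteChars).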
return nil +} + +func (ts *truncatableString) MarshalJSON() ([]byte, error) { + return json.Marshal(ts.Value) +} + +func newCodeHover(content json.RawMessage) (*codeHover, error) { + // Hover value can be either an object: { "value": "func main()", "language": "go" } + // Or a string with documentation + // We try to unmarshal the content into a string and if we fail, we unmarshal it into an object + var c codeHover + if err := json.Unmarshal(content, &c.TruncatedValue); err != nil { + if err := json.Unmarshal(content, &c); err != nil { + return nil, err + } + + c.setTokens() + } + + c.Truncated = c.TruncatedValue.Truncated + + if len(c.Tokens) > 0 { + c.TruncatedValue = nil // remove value for hovers which have tokens + } + + return &c, nil +} + +func (c *codeHover) setTokens() { + lexer := lexers.Get(c.Language) + if lexer == nil { + return + } + + iterator, err := lexer.Tokenise(nil, c.TruncatedValue.Value) + if err != nil { + return + } + + var tokenLines [][]token + for _, tokenLine := range chroma.SplitTokensIntoLines(iterator.Tokens()) { + var tokens []token + var rawToken string + for _, t := range tokenLine { + class := c.classFor(t.Type) + + // accumulate consequent raw values in a single string to store them as + // [{ Class: "kd", Value: "func" }, { Value: " main() {" }] instead of + // [{ Class: "kd", Value: "func" }, { Value: " " }, { Value: "main" }, { Value: "(" }...] + if class == "" { + rawToken = rawToken + t.Value + } else { + if rawToken != "" { + tokens = append(tokens, token{Value: rawToken}) + rawToken = "" + } + + tokens = append(tokens, token{Class: class, Value: t.Value}) + } + } + + if rawToken != "" { + tokens = append(tokens, token{Value: rawToken}) + } + + tokenLines = append(tokenLines, tokens) + } + + c.Tokens = tokenLines +} + +func (c *codeHover) classFor(tokenType chroma.TokenType) string { + if strings.HasPrefix(tokenType.String(), "Keyword") || tokenType == chroma.String || tokenType == chroma.Comment { + return chroma.StandardTypes[tokenType] + } + + return "" +} diff --git a/workhorse/internal/lsif_transformer/parser/code_hover_test.go b/workhorse/internal/lsif_transformer/parser/code_hover_test.go new file mode 100644 index 00000000000..2030e530155 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/code_hover_test.go @@ -0,0 +1,106 @@ +package parser + +import ( + "encoding/json" + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHighlight(t *testing.T) { + tests := []struct { + name string + language string + value string + want [][]token + }{ + { + name: "go function definition", + language: "go", + value: "func main()", + want: [][]token{{{Class: "kd", Value: "func"}, {Value: " main()"}}}, + }, + { + name: "go struct definition", + language: "go", + value: "type Command struct", + want: [][]token{{{Class: "kd", Value: "type"}, {Value: " Command "}, {Class: "kd", Value: "struct"}}}, + }, + { + name: "go struct multiline definition", + language: "go", + value: `struct {\nConfig *Config\nReadWriter *ReadWriter\nEOFSent bool\n}`, + want: [][]token{ + {{Class: "kd", Value: "struct"}, {Value: " {\n"}}, + {{Value: "Config *Config\n"}}, + {{Value: "ReadWriter *ReadWriter\n"}}, + {{Value: "EOFSent "}, {Class: "kt", Value: "bool"}, {Value: "\n"}}, + {{Value: "}"}}, + }, + }, + { + name: "ruby method definition", + language: "ruby", + value: "def read(line)", + want: [][]token{{{Class: "k", Value: "def"}, {Value: " read(line)"}}}, + }, + { + name: "ruby multiline method definition", + language: "ruby", + 
value: `def read(line)\nend`, + want: [][]token{ + {{Class: "k", Value: "def"}, {Value: " read(line)\n"}}, + {{Class: "k", Value: "end"}}, + }, + }, + { + name: "unknown/malicious language is passed", + language: "<lang> alert(1); </lang>", + value: `def a;\nend`, + want: [][]token(nil), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + raw := []byte(fmt.Sprintf(`{"language":"%s","value":"%s"}`, tt.language, tt.value)) + c, err := newCodeHover(json.RawMessage(raw)) + + require.NoError(t, err) + require.Equal(t, tt.want, c.Tokens) + }) + } +} + +func TestMarkdown(t *testing.T) { + value := `"This method reverses a string \n\n"` + c, err := newCodeHover(json.RawMessage(value)) + + require.NoError(t, err) + require.Equal(t, "This method reverses a string \n\n", c.TruncatedValue.Value) +} + +func TestTruncatedValue(t *testing.T) { + value := strings.Repeat("a", 500) + rawValue, err := json.Marshal(value) + require.NoError(t, err) + + c, err := newCodeHover(rawValue) + require.NoError(t, err) + + require.Equal(t, value[0:maxValueSize], c.TruncatedValue.Value) + require.True(t, c.TruncatedValue.Truncated) +} + +func TestTruncatingMultiByteChars(t *testing.T) { + value := strings.Repeat("ಅ", 500) + rawValue, err := json.Marshal(value) + require.NoError(t, err) + + c, err := newCodeHover(rawValue) + require.NoError(t, err) + + symbolSize := 3 + require.Equal(t, value[0:maxValueSize*symbolSize], c.TruncatedValue.Value) +} diff --git a/workhorse/internal/lsif_transformer/parser/docs.go b/workhorse/internal/lsif_transformer/parser/docs.go new file mode 100644 index 00000000000..c626e07d3fe --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/docs.go @@ -0,0 +1,144 @@ +package parser + +import ( + "archive/zip" + "bufio" + "encoding/json" + "io" + "strings" +) + +const maxScanTokenSize = 1024 * 1024 + +type Line struct { + Type string `json:"label"` +} + +type Docs struct { + Root string + Entries map[Id]string + DocRanges map[Id][]Id + Ranges *Ranges +} + +type Document struct { + Id Id `json:"id"` + Uri string `json:"uri"` +} + +type DocumentRange struct { + OutV Id `json:"outV"` + RangeIds []Id `json:"inVs"` +} + +type Metadata struct { + Root string `json:"projectRoot"` +} + +func NewDocs(config Config) (*Docs, error) { + ranges, err := NewRanges(config) + if err != nil { + return nil, err + } + + return &Docs{ + Root: "file:///", + Entries: make(map[Id]string), + DocRanges: make(map[Id][]Id), + Ranges: ranges, + }, nil +} + +func (d *Docs) Parse(r io.Reader) error { + scanner := bufio.NewScanner(r) + buf := make([]byte, 0, bufio.MaxScanTokenSize) + scanner.Buffer(buf, maxScanTokenSize) + + for scanner.Scan() { + if err := d.process(scanner.Bytes()); err != nil { + return err + } + } + + return scanner.Err() +} + +func (d *Docs) process(line []byte) error { + l := Line{} + if err := json.Unmarshal(line, &l); err != nil { + return err + } + + switch l.Type { + case "metaData": + if err := d.addMetadata(line); err != nil { + return err + } + case "document": + if err := d.addDocument(line); err != nil { + return err + } + case "contains": + if err := d.addDocRanges(line); err != nil { + return err + } + default: + return d.Ranges.Read(l.Type, line) + } + + return nil +} + +func (d *Docs) Close() error { + return d.Ranges.Close() +} + +func (d *Docs) SerializeEntries(w *zip.Writer) error { + for id, path := range d.Entries { + filePath := Lsif + "/" + path + ".json" + + f, err := w.Create(filePath) + if err != nil { + return err + } + + if err := 
d.Ranges.Serialize(f, d.DocRanges[id], d.Entries); err != nil { + return err + } + } + + return nil +} + +func (d *Docs) addMetadata(line []byte) error { + var metadata Metadata + if err := json.Unmarshal(line, &metadata); err != nil { + return err + } + + d.Root = strings.TrimSpace(metadata.Root) + "/" + + return nil +} + +func (d *Docs) addDocument(line []byte) error { + var doc Document + if err := json.Unmarshal(line, &doc); err != nil { + return err + } + + d.Entries[doc.Id] = strings.TrimPrefix(doc.Uri, d.Root) + + return nil +} + +func (d *Docs) addDocRanges(line []byte) error { + var docRange DocumentRange + if err := json.Unmarshal(line, &docRange); err != nil { + return err + } + + d.DocRanges[docRange.OutV] = append(d.DocRanges[docRange.OutV], docRange.RangeIds...) + + return nil +} diff --git a/workhorse/internal/lsif_transformer/parser/docs_test.go b/workhorse/internal/lsif_transformer/parser/docs_test.go new file mode 100644 index 00000000000..57dca8e773d --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/docs_test.go @@ -0,0 +1,54 @@ +package parser + +import ( + "bytes" + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func createLine(id, label, uri string) []byte { + return []byte(fmt.Sprintf(`{"id":"%s","label":"%s","uri":"%s"}`+"\n", id, label, uri)) +} + +func TestParse(t *testing.T) { + d, err := NewDocs(Config{}) + require.NoError(t, err) + defer d.Close() + + data := []byte(`{"id":"1","label":"metaData","projectRoot":"file:///Users/nested"}` + "\n") + data = append(data, createLine("2", "document", "file:///Users/nested/file.rb")...) + data = append(data, createLine("3", "document", "file:///Users/nested/folder/file.rb")...) + data = append(data, createLine("4", "document", "file:///Users/wrong/file.rb")...) + + require.NoError(t, d.Parse(bytes.NewReader(data))) + + require.Equal(t, d.Entries[2], "file.rb") + require.Equal(t, d.Entries[3], "folder/file.rb") + require.Equal(t, d.Entries[4], "file:///Users/wrong/file.rb") +} + +func TestParseContainsLine(t *testing.T) { + d, err := NewDocs(Config{}) + require.NoError(t, err) + defer d.Close() + + data := []byte(`{"id":"5","label":"contains","outV":"1", "inVs": ["2", "3"]}` + "\n") + data = append(data, []byte(`{"id":"6","label":"contains","outV":"1", "inVs": [4]}`+"\n")...) 
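+ // The two "contains" lines mix string ids ("2", "3") and a numeric id (4);
+ // Id.UnmarshalJSON accepts either encoding, so all three ranges are expected
+ // under DocRanges[1] below.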
+ + require.NoError(t, d.Parse(bytes.NewReader(data))) + + require.Equal(t, []Id{2, 3, 4}, d.DocRanges[1]) +} + +func TestParsingVeryLongLine(t *testing.T) { + d, err := NewDocs(Config{}) + require.NoError(t, err) + defer d.Close() + + line := []byte(`{"id": "` + strings.Repeat("a", 64*1024) + `"}`) + + require.NoError(t, d.Parse(bytes.NewReader(line))) +} diff --git a/workhorse/internal/lsif_transformer/parser/errors.go b/workhorse/internal/lsif_transformer/parser/errors.go new file mode 100644 index 00000000000..1040a789413 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/errors.go @@ -0,0 +1,30 @@ +package parser + +import ( + "errors" + "strings" +) + +func combineErrors(errsOrNil ...error) error { + var errs []error + for _, err := range errsOrNil { + if err != nil { + errs = append(errs, err) + } + } + + if len(errs) == 0 { + return nil + } + + if len(errs) == 1 { + return errs[0] + } + + var msgs []string + for _, err := range errs { + msgs = append(msgs, err.Error()) + } + + return errors.New(strings.Join(msgs, "\n")) +} diff --git a/workhorse/internal/lsif_transformer/parser/errors_test.go b/workhorse/internal/lsif_transformer/parser/errors_test.go new file mode 100644 index 00000000000..31a7130d05e --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/errors_test.go @@ -0,0 +1,26 @@ +package parser + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +type customErr struct { + err string +} + +func (e customErr) Error() string { + return e.err +} + +func TestCombineErrors(t *testing.T) { + err := combineErrors(nil, errors.New("first"), nil, customErr{"second"}) + require.EqualError(t, err, "first\nsecond") + + err = customErr{"custom error"} + require.Equal(t, err, combineErrors(nil, err, nil)) + + require.Nil(t, combineErrors(nil, nil, nil)) +} diff --git a/workhorse/internal/lsif_transformer/parser/hovers.go b/workhorse/internal/lsif_transformer/parser/hovers.go new file mode 100644 index 00000000000..e96d7e4fca3 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/hovers.go @@ -0,0 +1,162 @@ +package parser + +import ( + "encoding/json" + "io/ioutil" + "os" +) + +type Offset struct { + At int32 + Len int32 +} + +type Hovers struct { + File *os.File + Offsets *cache + CurrentOffset int +} + +type RawResult struct { + Contents []json.RawMessage `json:"contents"` +} + +type RawData struct { + Id Id `json:"id"` + Result RawResult `json:"result"` +} + +type HoverRef struct { + ResultSetId Id `json:"outV"` + HoverId Id `json:"inV"` +} + +type ResultSetRef struct { + ResultSetId Id `json:"outV"` + RefId Id `json:"inV"` +} + +func NewHovers(config Config) (*Hovers, error) { + tempPath := config.TempPath + + file, err := ioutil.TempFile(tempPath, "hovers") + if err != nil { + return nil, err + } + + if err := os.Remove(file.Name()); err != nil { + return nil, err + } + + offsets, err := newCache(tempPath, "hovers-indexes", Offset{}) + if err != nil { + return nil, err + } + + return &Hovers{ + File: file, + Offsets: offsets, + CurrentOffset: 0, + }, nil +} + +func (h *Hovers) Read(label string, line []byte) error { + switch label { + case "hoverResult": + if err := h.addData(line); err != nil { + return err + } + case "textDocument/hover": + if err := h.addHoverRef(line); err != nil { + return err + } + case "textDocument/references": + if err := h.addResultSetRef(line); err != nil { + return err + } + } + + return nil +} + +func (h *Hovers) For(refId Id) json.RawMessage { + var offset Offset + if err := 
h.Offsets.Entry(refId, &offset); err != nil || offset.Len == 0 { + return nil + } + + hover := make([]byte, offset.Len) + _, err := h.File.ReadAt(hover, int64(offset.At)) + if err != nil { + return nil + } + + return json.RawMessage(hover) +} + +func (h *Hovers) Close() error { + return combineErrors( + h.File.Close(), + h.Offsets.Close(), + ) +} + +func (h *Hovers) addData(line []byte) error { + var rawData RawData + if err := json.Unmarshal(line, &rawData); err != nil { + return err + } + + codeHovers := []*codeHover{} + for _, rawContent := range rawData.Result.Contents { + c, err := newCodeHover(rawContent) + if err != nil { + return err + } + + codeHovers = append(codeHovers, c) + } + + codeHoversData, err := json.Marshal(codeHovers) + if err != nil { + return err + } + + n, err := h.File.Write(codeHoversData) + if err != nil { + return err + } + + offset := Offset{At: int32(h.CurrentOffset), Len: int32(n)} + h.CurrentOffset += n + + return h.Offsets.SetEntry(rawData.Id, &offset) +} + +func (h *Hovers) addHoverRef(line []byte) error { + var hoverRef HoverRef + if err := json.Unmarshal(line, &hoverRef); err != nil { + return err + } + + var offset Offset + if err := h.Offsets.Entry(hoverRef.HoverId, &offset); err != nil { + return err + } + + return h.Offsets.SetEntry(hoverRef.ResultSetId, &offset) +} + +func (h *Hovers) addResultSetRef(line []byte) error { + var ref ResultSetRef + if err := json.Unmarshal(line, &ref); err != nil { + return err + } + + var offset Offset + if err := h.Offsets.Entry(ref.ResultSetId, &offset); err != nil { + return nil + } + + return h.Offsets.SetEntry(ref.RefId, &offset) +} diff --git a/workhorse/internal/lsif_transformer/parser/hovers_test.go b/workhorse/internal/lsif_transformer/parser/hovers_test.go new file mode 100644 index 00000000000..3037be103af --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/hovers_test.go @@ -0,0 +1,30 @@ +package parser + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHoversRead(t *testing.T) { + h := setupHovers(t) + + var offset Offset + require.NoError(t, h.Offsets.Entry(2, &offset)) + require.Equal(t, Offset{At: 0, Len: 19}, offset) + + require.Equal(t, `[{"value":"hello"}]`, string(h.For(1))) + + require.NoError(t, h.Close()) +} + +func setupHovers(t *testing.T) *Hovers { + h, err := NewHovers(Config{}) + require.NoError(t, err) + + require.NoError(t, h.Read("hoverResult", []byte(`{"id":"2","label":"hoverResult","result":{"contents": ["hello"]}}`))) + require.NoError(t, h.Read("textDocument/hover", []byte(`{"id":4,"label":"textDocument/hover","outV":"3","inV":2}`))) + require.NoError(t, h.Read("textDocument/references", []byte(`{"id":"3","label":"textDocument/references","outV":3,"inV":"1"}`))) + + return h +} diff --git a/workhorse/internal/lsif_transformer/parser/id.go b/workhorse/internal/lsif_transformer/parser/id.go new file mode 100644 index 00000000000..2adc4e092f5 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/id.go @@ -0,0 +1,52 @@ +package parser + +import ( + "encoding/json" + "errors" + "strconv" +) + +const ( + minId = 1 + maxId = 20 * 1000 * 1000 +) + +type Id int32 + +func (id *Id) UnmarshalJSON(b []byte) error { + if len(b) > 0 && b[0] != '"' { + if err := id.unmarshalInt(b); err != nil { + return err + } + } else { + if err := id.unmarshalString(b); err != nil { + return err + } + } + + if *id < minId || *id > maxId { + return errors.New("json: id is invalid") + } + + return nil +} + +func (id *Id) unmarshalInt(b []byte) error { + return 
json.Unmarshal(b, (*int32)(id)) +} + +func (id *Id) unmarshalString(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + + i, err := strconv.Atoi(s) + if err != nil { + return err + } + + *id = Id(i) + + return nil +} diff --git a/workhorse/internal/lsif_transformer/parser/id_test.go b/workhorse/internal/lsif_transformer/parser/id_test.go new file mode 100644 index 00000000000..c1c53928378 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/id_test.go @@ -0,0 +1,28 @@ +package parser + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +type jsonWithId struct { + Value Id `json:"value"` +} + +func TestId(t *testing.T) { + var v jsonWithId + require.NoError(t, json.Unmarshal([]byte(`{ "value": 1230 }`), &v)) + require.Equal(t, Id(1230), v.Value) + + require.NoError(t, json.Unmarshal([]byte(`{ "value": "1230" }`), &v)) + require.Equal(t, Id(1230), v.Value) + + require.Error(t, json.Unmarshal([]byte(`{ "value": "1.5" }`), &v)) + require.Error(t, json.Unmarshal([]byte(`{ "value": 1.5 }`), &v)) + require.Error(t, json.Unmarshal([]byte(`{ "value": "-1" }`), &v)) + require.Error(t, json.Unmarshal([]byte(`{ "value": -1 }`), &v)) + require.Error(t, json.Unmarshal([]byte(`{ "value": 21000000 }`), &v)) + require.Error(t, json.Unmarshal([]byte(`{ "value": "21000000" }`), &v)) +} diff --git a/workhorse/internal/lsif_transformer/parser/parser.go b/workhorse/internal/lsif_transformer/parser/parser.go new file mode 100644 index 00000000000..085e7a856aa --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/parser.go @@ -0,0 +1,109 @@ +package parser + +import ( + "archive/zip" + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "os" + + "gitlab.com/gitlab-org/labkit/log" +) + +var ( + Lsif = "lsif" +) + +type Parser struct { + Docs *Docs + + pr *io.PipeReader +} + +type Config struct { + TempPath string +} + +func NewParser(ctx context.Context, r io.Reader, config Config) (io.ReadCloser, error) { + docs, err := NewDocs(config) + if err != nil { + return nil, err + } + + // ZIP files need to be seekable. 
Don't hold it all in RAM, use a tempfile + tempFile, err := ioutil.TempFile(config.TempPath, Lsif) + if err != nil { + return nil, err + } + + defer tempFile.Close() + + if err := os.Remove(tempFile.Name()); err != nil { + return nil, err + } + + size, err := io.Copy(tempFile, r) + if err != nil { + return nil, err + } + log.WithContextFields(ctx, log.Fields{"lsif_zip_cache_bytes": size}).Print("cached incoming LSIF zip on disk") + + zr, err := zip.NewReader(tempFile, size) + if err != nil { + return nil, err + } + + if len(zr.File) == 0 { + return nil, errors.New("empty zip file") + } + + file, err := zr.File[0].Open() + if err != nil { + return nil, err + } + + defer file.Close() + + if err := docs.Parse(file); err != nil { + return nil, err + } + + pr, pw := io.Pipe() + parser := &Parser{ + Docs: docs, + pr: pr, + } + + go parser.transform(pw) + + return parser, nil +} + +func (p *Parser) Read(b []byte) (int, error) { + return p.pr.Read(b) +} + +func (p *Parser) Close() error { + p.pr.Close() + + return p.Docs.Close() +} + +func (p *Parser) transform(pw *io.PipeWriter) { + zw := zip.NewWriter(pw) + + if err := p.Docs.SerializeEntries(zw); err != nil { + zw.Close() // Free underlying resources only + pw.CloseWithError(fmt.Errorf("lsif parser: Docs.SerializeEntries: %v", err)) + return + } + + if err := zw.Close(); err != nil { + pw.CloseWithError(fmt.Errorf("lsif parser: ZipWriter.Close: %v", err)) + return + } + + pw.Close() +} diff --git a/workhorse/internal/lsif_transformer/parser/parser_test.go b/workhorse/internal/lsif_transformer/parser/parser_test.go new file mode 100644 index 00000000000..3a4d72360e2 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/parser_test.go @@ -0,0 +1,80 @@ +package parser + +import ( + "archive/zip" + "bytes" + "context" + "encoding/json" + "io" + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGenerate(t *testing.T) { + filePath := "testdata/dump.lsif.zip" + tmpDir := filePath + ".tmp" + defer os.RemoveAll(tmpDir) + + createFiles(t, filePath, tmpDir) + + verifyCorrectnessOf(t, tmpDir, "lsif/main.go.json") + verifyCorrectnessOf(t, tmpDir, "lsif/morestrings/reverse.go.json") +} + +func verifyCorrectnessOf(t *testing.T, tmpDir, fileName string) { + file, err := ioutil.ReadFile(filepath.Join(tmpDir, fileName)) + require.NoError(t, err) + + var buf bytes.Buffer + require.NoError(t, json.Indent(&buf, file, "", " ")) + + expected, err := ioutil.ReadFile(filepath.Join("testdata/expected/", fileName)) + require.NoError(t, err) + + require.Equal(t, string(expected), buf.String()) +} + +func createFiles(t *testing.T, filePath, tmpDir string) { + t.Helper() + file, err := os.Open(filePath) + require.NoError(t, err) + + parser, err := NewParser(context.Background(), file, Config{}) + require.NoError(t, err) + + zipFileName := tmpDir + ".zip" + w, err := os.Create(zipFileName) + require.NoError(t, err) + defer os.RemoveAll(zipFileName) + + _, err = io.Copy(w, parser) + require.NoError(t, err) + require.NoError(t, parser.Close()) + + extractZipFiles(t, tmpDir, zipFileName) +} + +func extractZipFiles(t *testing.T, tmpDir, zipFileName string) { + zipReader, err := zip.OpenReader(zipFileName) + require.NoError(t, err) + + for _, file := range zipReader.Reader.File { + zippedFile, err := file.Open() + require.NoError(t, err) + defer zippedFile.Close() + + fileDir, fileName := filepath.Split(file.Name) + require.NoError(t, os.MkdirAll(filepath.Join(tmpDir, fileDir), os.ModePerm)) + + outputFile, err := 
os.Create(filepath.Join(tmpDir, fileDir, fileName)) + require.NoError(t, err) + defer outputFile.Close() + + _, err = io.Copy(outputFile, zippedFile) + require.NoError(t, err) + } +} diff --git a/workhorse/internal/lsif_transformer/parser/performance_test.go b/workhorse/internal/lsif_transformer/parser/performance_test.go new file mode 100644 index 00000000000..5a12d90072f --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/performance_test.go @@ -0,0 +1,47 @@ +package parser + +import ( + "context" + "io" + "io/ioutil" + "os" + "runtime" + "testing" + + "github.com/stretchr/testify/require" +) + +func BenchmarkGenerate(b *testing.B) { + filePath := "testdata/workhorse.lsif.zip" + tmpDir := filePath + ".tmp" + defer os.RemoveAll(tmpDir) + + var memoryUsage float64 + for i := 0; i < b.N; i++ { + memoryUsage += measureMemory(func() { + file, err := os.Open(filePath) + require.NoError(b, err) + + parser, err := NewParser(context.Background(), file, Config{}) + require.NoError(b, err) + + _, err = io.Copy(ioutil.Discard, parser) + require.NoError(b, err) + require.NoError(b, parser.Close()) + }) + } + + b.ReportMetric(memoryUsage/float64(b.N), "MiB/op") +} + +func measureMemory(f func()) float64 { + var m, m1 runtime.MemStats + runtime.ReadMemStats(&m) + + f() + + runtime.ReadMemStats(&m1) + runtime.GC() + + return float64(m1.Alloc-m.Alloc) / 1024 / 1024 +} diff --git a/workhorse/internal/lsif_transformer/parser/ranges.go b/workhorse/internal/lsif_transformer/parser/ranges.go new file mode 100644 index 00000000000..a11a66d70ca --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/ranges.go @@ -0,0 +1,214 @@ +package parser + +import ( + "encoding/json" + "errors" + "io" + "strconv" +) + +const ( + definitions = "definitions" + references = "references" +) + +type Ranges struct { + DefRefs map[Id]Item + References *References + Hovers *Hovers + Cache *cache +} + +type RawRange struct { + Id Id `json:"id"` + Data Range `json:"start"` +} + +type Range struct { + Line int32 `json:"line"` + Character int32 `json:"character"` + RefId Id +} + +type RawItem struct { + Property string `json:"property"` + RefId Id `json:"outV"` + RangeIds []Id `json:"inVs"` + DocId Id `json:"document"` +} + +type Item struct { + Line int32 + DocId Id +} + +type SerializedRange struct { + StartLine int32 `json:"start_line"` + StartChar int32 `json:"start_char"` + DefinitionPath string `json:"definition_path,omitempty"` + Hover json.RawMessage `json:"hover"` + References []SerializedReference `json:"references,omitempty"` +} + +func NewRanges(config Config) (*Ranges, error) { + hovers, err := NewHovers(config) + if err != nil { + return nil, err + } + + references, err := NewReferences(config) + if err != nil { + return nil, err + } + + cache, err := newCache(config.TempPath, "ranges", Range{}) + if err != nil { + return nil, err + } + + return &Ranges{ + DefRefs: make(map[Id]Item), + References: references, + Hovers: hovers, + Cache: cache, + }, nil +} + +func (r *Ranges) Read(label string, line []byte) error { + switch label { + case "range": + if err := r.addRange(line); err != nil { + return err + } + case "item": + if err := r.addItem(line); err != nil { + return err + } + default: + return r.Hovers.Read(label, line) + } + + return nil +} + +func (r *Ranges) Serialize(f io.Writer, rangeIds []Id, docs map[Id]string) error { + encoder := json.NewEncoder(f) + n := len(rangeIds) + + if _, err := f.Write([]byte("[")); err != nil { + return err + } + + for i, rangeId := range rangeIds { + entry, err := 
r.getRange(rangeId) + if err != nil { + continue + } + + serializedRange := SerializedRange{ + StartLine: entry.Line, + StartChar: entry.Character, + DefinitionPath: r.definitionPathFor(docs, entry.RefId), + Hover: r.Hovers.For(entry.RefId), + References: r.References.For(docs, entry.RefId), + } + if err := encoder.Encode(serializedRange); err != nil { + return err + } + if i+1 < n { + if _, err := f.Write([]byte(",")); err != nil { + return err + } + } + } + + if _, err := f.Write([]byte("]")); err != nil { + return err + } + + return nil +} + +func (r *Ranges) Close() error { + return combineErrors( + r.Cache.Close(), + r.References.Close(), + r.Hovers.Close(), + ) +} + +func (r *Ranges) definitionPathFor(docs map[Id]string, refId Id) string { + defRef, ok := r.DefRefs[refId] + if !ok { + return "" + } + + defPath := docs[defRef.DocId] + "#L" + strconv.Itoa(int(defRef.Line)) + + return defPath +} + +func (r *Ranges) addRange(line []byte) error { + var rg RawRange + if err := json.Unmarshal(line, &rg); err != nil { + return err + } + + return r.Cache.SetEntry(rg.Id, &rg.Data) +} + +func (r *Ranges) addItem(line []byte) error { + var rawItem RawItem + if err := json.Unmarshal(line, &rawItem); err != nil { + return err + } + + if rawItem.Property != definitions && rawItem.Property != references { + return nil + } + + if len(rawItem.RangeIds) == 0 { + return errors.New("no range IDs") + } + + var references []Item + + for _, rangeId := range rawItem.RangeIds { + rg, err := r.getRange(rangeId) + if err != nil { + return err + } + + rg.RefId = rawItem.RefId + + if err := r.Cache.SetEntry(rangeId, rg); err != nil { + return err + } + + item := Item{ + Line: rg.Line + 1, + DocId: rawItem.DocId, + } + + if rawItem.Property == definitions { + r.DefRefs[rawItem.RefId] = item + } else { + references = append(references, item) + } + } + + if err := r.References.Store(rawItem.RefId, references); err != nil { + return err + } + + return nil +} + +func (r *Ranges) getRange(rangeId Id) (*Range, error) { + var rg Range + if err := r.Cache.Entry(rangeId, &rg); err != nil { + return nil, err + } + + return &rg, nil +} diff --git a/workhorse/internal/lsif_transformer/parser/ranges_test.go b/workhorse/internal/lsif_transformer/parser/ranges_test.go new file mode 100644 index 00000000000..c1400ba61da --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/ranges_test.go @@ -0,0 +1,61 @@ +package parser + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRangesRead(t *testing.T) { + r, cleanup := setup(t) + defer cleanup() + + firstRange := Range{Line: 1, Character: 2, RefId: 4} + rg, err := r.getRange(1) + require.NoError(t, err) + require.Equal(t, &firstRange, rg) + + secondRange := Range{Line: 5, Character: 4, RefId: 4} + rg, err = r.getRange(2) + require.NoError(t, err) + require.Equal(t, &secondRange, rg) + + thirdRange := Range{Line: 7, Character: 4, RefId: 4} + rg, err = r.getRange(3) + require.NoError(t, err) + require.Equal(t, &thirdRange, rg) +} + +func TestSerialize(t *testing.T) { + r, cleanup := setup(t) + defer cleanup() + + docs := map[Id]string{6: "def-path", 7: "ref-path"} + + var buf bytes.Buffer + err := r.Serialize(&buf, []Id{1}, docs) + want := `[{"start_line":1,"start_char":2,"definition_path":"def-path#L2","hover":null,"references":[{"path":"ref-path#L6"},{"path":"ref-path#L8"}]}` + "\n]" + + require.NoError(t, err) + require.Equal(t, want, buf.String()) +} + +func setup(t *testing.T) (*Ranges, func()) { + r, err := NewRanges(Config{}) + 
require.NoError(t, err) + + require.NoError(t, r.Read("range", []byte(`{"id":1,"label":"range","start":{"line":1,"character":2}}`))) + require.NoError(t, r.Read("range", []byte(`{"id":"2","label":"range","start":{"line":5,"character":4}}`))) + require.NoError(t, r.Read("range", []byte(`{"id":"3","label":"range","start":{"line":7,"character":4}}`))) + + require.NoError(t, r.Read("item", []byte(`{"id":5,"label":"item","property":"definitions","outV":"4","inVs":[1],"document":"6"}`))) + require.NoError(t, r.Read("item", []byte(`{"id":"6","label":"item","property":"references","outV":4,"inVs":["2"],"document":"7"}`))) + require.NoError(t, r.Read("item", []byte(`{"id":"7","label":"item","property":"references","outV":4,"inVs":["3"],"document":"7"}`))) + + cleanup := func() { + require.NoError(t, r.Close()) + } + + return r, cleanup +} diff --git a/workhorse/internal/lsif_transformer/parser/references.go b/workhorse/internal/lsif_transformer/parser/references.go new file mode 100644 index 00000000000..58ff9a61c02 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/references.go @@ -0,0 +1,107 @@ +package parser + +import ( + "strconv" +) + +type ReferencesOffset struct { + Id Id + Len int32 +} + +type References struct { + Items *cache + Offsets *cache + CurrentOffsetId Id +} + +type SerializedReference struct { + Path string `json:"path"` +} + +func NewReferences(config Config) (*References, error) { + tempPath := config.TempPath + + items, err := newCache(tempPath, "references", Item{}) + if err != nil { + return nil, err + } + + offsets, err := newCache(tempPath, "references-offsets", ReferencesOffset{}) + if err != nil { + return nil, err + } + + return &References{ + Items: items, + Offsets: offsets, + CurrentOffsetId: 0, + }, nil +} + +// Store is responsible for keeping track of references that will be used when +// serializing in `For`. +// +// The references are stored in a file to cache them. It is like +// `map[Id][]Item` (where `Id` is `refId`) but relies on caching the array and +// its offset in files for storage to reduce RAM usage. The items can be +// fetched by calling `getItems`. +func (r *References) Store(refId Id, references []Item) error { + size := len(references) + + if size == 0 { + return nil + } + + items := append(r.getItems(refId), references...) 
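+	// Write the merged slice at the next free offset, record (offset, length)
+	// under refId so a later getItems call can locate it, and advance the
+	// offset counter by the number of items written.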
+ err := r.Items.SetEntry(r.CurrentOffsetId, items) + if err != nil { + return err + } + + size = len(items) + r.Offsets.SetEntry(refId, ReferencesOffset{Id: r.CurrentOffsetId, Len: int32(size)}) + r.CurrentOffsetId += Id(size) + + return nil +} + +func (r *References) For(docs map[Id]string, refId Id) []SerializedReference { + references := r.getItems(refId) + if references == nil { + return nil + } + + var serializedReferences []SerializedReference + + for _, reference := range references { + serializedReference := SerializedReference{ + Path: docs[reference.DocId] + "#L" + strconv.Itoa(int(reference.Line)), + } + + serializedReferences = append(serializedReferences, serializedReference) + } + + return serializedReferences +} + +func (r *References) Close() error { + return combineErrors( + r.Items.Close(), + r.Offsets.Close(), + ) +} + +func (r *References) getItems(refId Id) []Item { + var offset ReferencesOffset + if err := r.Offsets.Entry(refId, &offset); err != nil || offset.Len == 0 { + return nil + } + + items := make([]Item, offset.Len) + if err := r.Items.Entry(offset.Id, &items); err != nil { + return nil + } + + return items +} diff --git a/workhorse/internal/lsif_transformer/parser/references_test.go b/workhorse/internal/lsif_transformer/parser/references_test.go new file mode 100644 index 00000000000..7b47513bc53 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/references_test.go @@ -0,0 +1,44 @@ +package parser + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestReferencesStore(t *testing.T) { + const ( + docId = 1 + refId = 3 + ) + + r, err := NewReferences(Config{}) + require.NoError(t, err) + + err = r.Store(refId, []Item{{Line: 2, DocId: docId}, {Line: 3, DocId: docId}}) + require.NoError(t, err) + + docs := map[Id]string{docId: "doc.go"} + serializedReferences := r.For(docs, refId) + + require.Contains(t, serializedReferences, SerializedReference{Path: "doc.go#L2"}) + require.Contains(t, serializedReferences, SerializedReference{Path: "doc.go#L3"}) + + require.NoError(t, r.Close()) +} + +func TestReferencesStoreEmpty(t *testing.T) { + const refId = 3 + + r, err := NewReferences(Config{}) + require.NoError(t, err) + + err = r.Store(refId, []Item{}) + require.NoError(t, err) + + docs := map[Id]string{1: "doc.go"} + serializedReferences := r.For(docs, refId) + + require.Nil(t, serializedReferences) + require.NoError(t, r.Close()) +} diff --git a/workhorse/internal/lsif_transformer/parser/testdata/dump.lsif.zip b/workhorse/internal/lsif_transformer/parser/testdata/dump.lsif.zip Binary files differnew file mode 100644 index 00000000000..e7c9ef2da66 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/testdata/dump.lsif.zip diff --git a/workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/main.go.json b/workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/main.go.json new file mode 100644 index 00000000000..781cb78fc1a --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/main.go.json @@ -0,0 +1,208 @@ +[ + { + "start_line": 7, + "start_char": 1, + "definition_path": "main.go#L4", + "hover": [ + { + "tokens": [ + [ + { + "class": "kn", + "value": "package" + }, + { + "value": " " + }, + { + "class": "s", + "value": "\"github.com/user/hello/morestrings\"" + } + ] + ], + "language": "go" + }, + { + "value": "Package morestrings implements additional functions to manipulate UTF-8 encoded strings, beyond what is provided in the standard \"strings\" package. 
\n\n" + } + ], + "references": [ + { + "path": "main.go#L8" + }, + { + "path": "main.go#L9" + } + ] + }, + { + "start_line": 7, + "start_char": 13, + "definition_path": "morestrings/reverse.go#L12", + "hover": [ + { + "tokens": [ + [ + { + "class": "kd", + "value": "func" + }, + { + "value": " Reverse(s " + }, + { + "class": "kt", + "value": "string" + }, + { + "value": ") " + }, + { + "class": "kt", + "value": "string" + } + ] + ], + "language": "go" + }, + { + "value": "This method reverses a string \n\n" + } + ], + "references": [ + { + "path": "main.go#L8" + } + ] + }, + { + "start_line": 8, + "start_char": 1, + "definition_path": "main.go#L4", + "hover": [ + { + "tokens": [ + [ + { + "class": "kn", + "value": "package" + }, + { + "value": " " + }, + { + "class": "s", + "value": "\"github.com/user/hello/morestrings\"" + } + ] + ], + "language": "go" + }, + { + "value": "Package morestrings implements additional functions to manipulate UTF-8 encoded strings, beyond what is provided in the standard \"strings\" package. \n\n" + } + ], + "references": [ + { + "path": "main.go#L8" + }, + { + "path": "main.go#L9" + } + ] + }, + { + "start_line": 8, + "start_char": 13, + "definition_path": "morestrings/reverse.go#L5", + "hover": [ + { + "tokens": [ + [ + { + "class": "kd", + "value": "func" + }, + { + "value": " Func2(i " + }, + { + "class": "kt", + "value": "int" + }, + { + "value": ") " + }, + { + "class": "kt", + "value": "string" + } + ] + ], + "language": "go" + } + ], + "references": [ + { + "path": "main.go#L9" + } + ] + }, + { + "start_line": 6, + "start_char": 5, + "definition_path": "main.go#L7", + "hover": [ + { + "tokens": [ + [ + { + "class": "kd", + "value": "func" + }, + { + "value": " main()" + } + ] + ], + "language": "go" + } + ] + }, + { + "start_line": 3, + "start_char": 2, + "definition_path": "main.go#L4", + "hover": [ + { + "tokens": [ + [ + { + "class": "kn", + "value": "package" + }, + { + "value": " " + }, + { + "class": "s", + "value": "\"github.com/user/hello/morestrings\"" + } + ] + ], + "language": "go" + }, + { + "value": "Package morestrings implements additional functions to manipulate UTF-8 encoded strings, beyond what is provided in the standard \"strings\" package. \n\n" + } + ], + "references": [ + { + "path": "main.go#L8" + }, + { + "path": "main.go#L9" + } + ] + } +]
\ No newline at end of file diff --git a/workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/morestrings/reverse.go.json b/workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/morestrings/reverse.go.json new file mode 100644 index 00000000000..1d238413d53 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/morestrings/reverse.go.json @@ -0,0 +1,249 @@ +[ + { + "start_line": 11, + "start_char": 5, + "definition_path": "morestrings/reverse.go#L12", + "hover": [ + { + "tokens": [ + [ + { + "class": "kd", + "value": "func" + }, + { + "value": " Reverse(s " + }, + { + "class": "kt", + "value": "string" + }, + { + "value": ") " + }, + { + "class": "kt", + "value": "string" + } + ] + ], + "language": "go" + }, + { + "value": "This method reverses a string \n\n" + } + ], + "references": [ + { + "path": "main.go#L8" + } + ] + }, + { + "start_line": 4, + "start_char": 11, + "definition_path": "morestrings/reverse.go#L5", + "hover": [ + { + "tokens": [ + [ + { + "class": "kd", + "value": "var" + }, + { + "value": " i " + }, + { + "class": "kt", + "value": "int" + } + ] + ], + "language": "go" + } + ] + }, + { + "start_line": 11, + "start_char": 13, + "definition_path": "morestrings/reverse.go#L12", + "hover": [ + { + "tokens": [ + [ + { + "class": "kd", + "value": "var" + }, + { + "value": " s " + }, + { + "class": "kt", + "value": "string" + } + ] + ], + "language": "go" + } + ] + }, + { + "start_line": 12, + "start_char": 1, + "definition_path": "morestrings/reverse.go#L13", + "hover": [ + { + "tokens": [ + [ + { + "class": "kd", + "value": "var" + }, + { + "value": " a " + }, + { + "class": "kt", + "value": "string" + } + ] + ], + "language": "go" + } + ], + "references": [ + { + "path": "morestrings/reverse.go#L15" + } + ] + }, + { + "start_line": 5, + "start_char": 1, + "definition_path": "morestrings/reverse.go#L6", + "hover": [ + { + "tokens": [ + [ + { + "class": "kd", + "value": "var" + }, + { + "value": " b " + }, + { + "class": "kt", + "value": "string" + } + ] + ], + "language": "go" + } + ], + "references": [ + { + "path": "morestrings/reverse.go#L8" + } + ] + }, + { + "start_line": 14, + "start_char": 8, + "definition_path": "morestrings/reverse.go#L13", + "hover": [ + { + "tokens": [ + [ + { + "class": "kd", + "value": "var" + }, + { + "value": " a " + }, + { + "class": "kt", + "value": "string" + } + ] + ], + "language": "go" + } + ], + "references": [ + { + "path": "morestrings/reverse.go#L15" + } + ] + }, + { + "start_line": 7, + "start_char": 8, + "definition_path": "morestrings/reverse.go#L6", + "hover": [ + { + "tokens": [ + [ + { + "class": "kd", + "value": "var" + }, + { + "value": " b " + }, + { + "class": "kt", + "value": "string" + } + ] + ], + "language": "go" + } + ], + "references": [ + { + "path": "morestrings/reverse.go#L8" + } + ] + }, + { + "start_line": 4, + "start_char": 5, + "definition_path": "morestrings/reverse.go#L5", + "hover": [ + { + "tokens": [ + [ + { + "class": "kd", + "value": "func" + }, + { + "value": " Func2(i " + }, + { + "class": "kt", + "value": "int" + }, + { + "value": ") " + }, + { + "class": "kt", + "value": "string" + } + ] + ], + "language": "go" + } + ], + "references": [ + { + "path": "main.go#L9" + } + ] + } +]
\ No newline at end of file diff --git a/workhorse/internal/lsif_transformer/parser/testdata/workhorse.lsif.zip b/workhorse/internal/lsif_transformer/parser/testdata/workhorse.lsif.zip Binary files differnew file mode 100644 index 00000000000..76491ed8a93 --- /dev/null +++ b/workhorse/internal/lsif_transformer/parser/testdata/workhorse.lsif.zip diff --git a/workhorse/internal/objectstore/gocloud_object.go b/workhorse/internal/objectstore/gocloud_object.go new file mode 100644 index 00000000000..38545086994 --- /dev/null +++ b/workhorse/internal/objectstore/gocloud_object.go @@ -0,0 +1,100 @@ +package objectstore + +import ( + "context" + "io" + "time" + + "gitlab.com/gitlab-org/labkit/log" + "gocloud.dev/blob" + "gocloud.dev/gcerrors" +) + +type GoCloudObject struct { + bucket *blob.Bucket + mux *blob.URLMux + bucketURL string + objectName string + *uploader +} + +type GoCloudObjectParams struct { + Ctx context.Context + Mux *blob.URLMux + BucketURL string + ObjectName string +} + +func NewGoCloudObject(p *GoCloudObjectParams) (*GoCloudObject, error) { + bucket, err := p.Mux.OpenBucket(p.Ctx, p.BucketURL) + if err != nil { + return nil, err + } + + o := &GoCloudObject{ + bucket: bucket, + mux: p.Mux, + bucketURL: p.BucketURL, + objectName: p.ObjectName, + } + + o.uploader = newUploader(o) + return o, nil +} + +func (o *GoCloudObject) Upload(ctx context.Context, r io.Reader) error { + defer o.bucket.Close() + + writer, err := o.bucket.NewWriter(ctx, o.objectName, nil) + if err != nil { + log.ContextLogger(ctx).WithError(err).Error("error creating GoCloud bucket") + return err + } + + if _, err = io.Copy(writer, r); err != nil { + log.ContextLogger(ctx).WithError(err).Error("error writing to GoCloud bucket") + writer.Close() + return err + } + + if err := writer.Close(); err != nil { + log.ContextLogger(ctx).WithError(err).Error("error closing GoCloud bucket") + return err + } + + return nil +} + +func (o *GoCloudObject) ETag() string { + return "" +} + +func (o *GoCloudObject) Abort() { + o.Delete() +} + +// Delete will always attempt to delete the temporary file. +// According to https://github.com/google/go-cloud/blob/7818961b5c9a112f7e092d3a2d8479cbca80d187/blob/azureblob/azureblob.go#L881-L883, +// if the writer is closed before any Write is called, Close will create an empty file. +func (o *GoCloudObject) Delete() { + if o.bucketURL == "" || o.objectName == "" { + return + } + + // Note we can't use the request context because in a successful + // case, the original request has already completed. 
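+	// A detached 60-second context keeps this cleanup working even when the
+	// original request context has been cancelled; a NotFound error is
+	// tolerated because the object may never have been written.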
+ deleteCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // lint:allow context.Background + defer cancel() + + bucket, err := o.mux.OpenBucket(deleteCtx, o.bucketURL) + if err != nil { + log.WithError(err).Error("error opening bucket for delete") + return + } + + if err := bucket.Delete(deleteCtx, o.objectName); err != nil { + if gcerrors.Code(err) != gcerrors.NotFound { + log.WithError(err).Error("error deleting object") + } + } +} diff --git a/workhorse/internal/objectstore/gocloud_object_test.go b/workhorse/internal/objectstore/gocloud_object_test.go new file mode 100644 index 00000000000..4dc9d2d75cc --- /dev/null +++ b/workhorse/internal/objectstore/gocloud_object_test.go @@ -0,0 +1,56 @@ +package objectstore_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" +) + +func TestGoCloudObjectUpload(t *testing.T) { + mux, _, cleanup := test.SetupGoCloudFileBucket(t, "azuretest") + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + deadline := time.Now().Add(testTimeout) + + objectName := "test.png" + testURL := "azuretest://azure.example.com/test-container" + p := &objectstore.GoCloudObjectParams{Ctx: ctx, Mux: mux, BucketURL: testURL, ObjectName: objectName} + object, err := objectstore.NewGoCloudObject(p) + require.NotNil(t, object) + require.NoError(t, err) + + // copy data + n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline) + require.NoError(t, err) + require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch") + + bucket, err := mux.OpenBucket(ctx, testURL) + require.NoError(t, err) + + // Verify the data was copied correctly. + received, err := bucket.ReadAll(ctx, objectName) + require.NoError(t, err) + require.Equal(t, []byte(test.ObjectContent), received) + + cancel() + + testhelper.Retry(t, 5*time.Second, func() error { + exists, err := bucket.Exists(ctx, objectName) + require.NoError(t, err) + + if exists { + return fmt.Errorf("file %s is still present", objectName) + } else { + return nil + } + }) +} diff --git a/workhorse/internal/objectstore/multipart.go b/workhorse/internal/objectstore/multipart.go new file mode 100644 index 00000000000..fd1c0ed487d --- /dev/null +++ b/workhorse/internal/objectstore/multipart.go @@ -0,0 +1,188 @@ +package objectstore + +import ( + "bytes" + "context" + "encoding/xml" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/mask" +) + +// ErrNotEnoughParts will be used when writing more than size * len(partURLs) +var ErrNotEnoughParts = errors.New("not enough Parts") + +// Multipart represents a MultipartUpload on a S3 compatible Object Store service. +// It can be used as io.WriteCloser for uploading an object +type Multipart struct { + PartURLs []string + // CompleteURL is a presigned URL for CompleteMultipartUpload + CompleteURL string + // AbortURL is a presigned URL for AbortMultipartUpload + AbortURL string + // DeleteURL is a presigned URL for RemoveObject + DeleteURL string + PutHeaders map[string]string + partSize int64 + etag string + + *uploader +} + +// NewMultipart provides Multipart pointer that can be used for uploading. Data written will be split buffered on disk up to size bytes +// then uploaded with S3 Upload Part. 
Once Multipart is Closed a final call to CompleteMultipartUpload will be sent. +// In case of any error a call to AbortMultipartUpload will be made to cleanup all the resources +func NewMultipart(partURLs []string, completeURL, abortURL, deleteURL string, putHeaders map[string]string, partSize int64) (*Multipart, error) { + m := &Multipart{ + PartURLs: partURLs, + CompleteURL: completeURL, + AbortURL: abortURL, + DeleteURL: deleteURL, + PutHeaders: putHeaders, + partSize: partSize, + } + + m.uploader = newUploader(m) + return m, nil +} + +func (m *Multipart) Upload(ctx context.Context, r io.Reader) error { + cmu := &CompleteMultipartUpload{} + for i, partURL := range m.PartURLs { + src := io.LimitReader(r, m.partSize) + part, err := m.readAndUploadOnePart(ctx, partURL, m.PutHeaders, src, i+1) + if err != nil { + return err + } + if part == nil { + break + } else { + cmu.Part = append(cmu.Part, part) + } + } + + n, err := io.Copy(ioutil.Discard, r) + if err != nil { + return fmt.Errorf("drain pipe: %v", err) + } + if n > 0 { + return ErrNotEnoughParts + } + + if err := m.complete(ctx, cmu); err != nil { + return err + } + + return nil +} + +func (m *Multipart) ETag() string { + return m.etag +} +func (m *Multipart) Abort() { + deleteURL(m.AbortURL) +} + +func (m *Multipart) Delete() { + deleteURL(m.DeleteURL) +} + +func (m *Multipart) readAndUploadOnePart(ctx context.Context, partURL string, putHeaders map[string]string, src io.Reader, partNumber int) (*completeMultipartUploadPart, error) { + file, err := ioutil.TempFile("", "part-buffer") + if err != nil { + return nil, fmt.Errorf("create temporary buffer file: %v", err) + } + defer func(path string) { + if err := os.Remove(path); err != nil { + log.WithError(err).WithField("file", path).Warning("Unable to delete temporary file") + } + }(file.Name()) + + n, err := io.Copy(file, src) + if err != nil { + return nil, err + } + if n == 0 { + return nil, nil + } + + if _, err = file.Seek(0, io.SeekStart); err != nil { + return nil, fmt.Errorf("rewind part %d temporary dump : %v", partNumber, err) + } + + etag, err := m.uploadPart(ctx, partURL, putHeaders, file, n) + if err != nil { + return nil, fmt.Errorf("upload part %d: %v", partNumber, err) + } + return &completeMultipartUploadPart{PartNumber: partNumber, ETag: etag}, nil +} + +func (m *Multipart) uploadPart(ctx context.Context, url string, headers map[string]string, body io.Reader, size int64) (string, error) { + deadline, ok := ctx.Deadline() + if !ok { + return "", fmt.Errorf("missing deadline") + } + + part, err := newObject(url, "", headers, size, false) + if err != nil { + return "", err + } + + if n, err := part.Consume(ctx, io.LimitReader(body, size), deadline); err != nil || n < size { + if err == nil { + err = io.ErrUnexpectedEOF + } + return "", err + } + + return part.ETag(), nil +} + +func (m *Multipart) complete(ctx context.Context, cmu *CompleteMultipartUpload) error { + body, err := xml.Marshal(cmu) + if err != nil { + return fmt.Errorf("marshal CompleteMultipartUpload request: %v", err) + } + + req, err := http.NewRequest("POST", m.CompleteURL, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("create CompleteMultipartUpload request: %v", err) + } + req.ContentLength = int64(len(body)) + req.Header.Set("Content-Type", "application/xml") + req = req.WithContext(ctx) + + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("CompleteMultipartUpload request %q: %v", mask.URL(m.CompleteURL), err) + } + defer resp.Body.Close() + + if resp.StatusCode != 
http.StatusOK { + return fmt.Errorf("CompleteMultipartUpload request %v returned: %s", mask.URL(m.CompleteURL), resp.Status) + } + + result := &compoundCompleteMultipartUploadResult{} + decoder := xml.NewDecoder(resp.Body) + if err := decoder.Decode(&result); err != nil { + return fmt.Errorf("decode CompleteMultipartUpload answer: %v", err) + } + + if result.isError() { + return result + } + + if result.CompleteMultipartUploadResult == nil { + return fmt.Errorf("empty CompleteMultipartUploadResult") + } + + m.etag = extractETag(result.ETag) + + return nil +} diff --git a/workhorse/internal/objectstore/multipart_test.go b/workhorse/internal/objectstore/multipart_test.go new file mode 100644 index 00000000000..00d6efc0982 --- /dev/null +++ b/workhorse/internal/objectstore/multipart_test.go @@ -0,0 +1,64 @@ +package objectstore_test + +import ( + "context" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test" +) + +func TestMultipartUploadWithUpcaseETags(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var putCnt, postCnt int + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + defer r.Body.Close() + + // Part upload request + if r.Method == "PUT" { + putCnt++ + + w.Header().Set("ETag", strings.ToUpper(test.ObjectMD5)) + } + + // POST with CompleteMultipartUpload request + if r.Method == "POST" { + completeBody := `<CompleteMultipartUploadResult> + <Bucket>test-bucket</Bucket> + <ETag>No Longer Checked</ETag> + </CompleteMultipartUploadResult>` + postCnt++ + + w.Write([]byte(completeBody)) + } + })) + defer ts.Close() + + deadline := time.Now().Add(testTimeout) + + m, err := objectstore.NewMultipart( + []string{ts.URL}, // a single presigned part URL + ts.URL, // the complete multipart upload URL + "", // no abort + "", // no delete + map[string]string{}, // no custom headers + test.ObjectSize) // parts size equal to the whole content. 
Only 1 part + require.NoError(t, err) + + _, err = m.Consume(ctx, strings.NewReader(test.ObjectContent), deadline) + require.NoError(t, err) + require.Equal(t, 1, putCnt, "1 part expected") + require.Equal(t, 1, postCnt, "1 complete multipart upload expected") +} diff --git a/workhorse/internal/objectstore/object.go b/workhorse/internal/objectstore/object.go new file mode 100644 index 00000000000..eaf3bfb2e36 --- /dev/null +++ b/workhorse/internal/objectstore/object.go @@ -0,0 +1,114 @@ +package objectstore + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "time" + + "gitlab.com/gitlab-org/labkit/correlation" + "gitlab.com/gitlab-org/labkit/mask" + "gitlab.com/gitlab-org/labkit/tracing" +) + +// httpTransport defines a http.Transport with values +// that are more restrictive than for http.DefaultTransport, +// they define shorter TLS Handshake, and more aggressive connection closing +// to prevent the connection hanging and reduce FD usage +var httpTransport = tracing.NewRoundTripper(correlation.NewInstrumentedRoundTripper(&http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 10 * time.Second, + }).DialContext, + MaxIdleConns: 2, + IdleConnTimeout: 30 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 10 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, +})) + +var httpClient = &http.Client{ + Transport: httpTransport, +} + +// Object represents an object on a S3 compatible Object Store service. +// It can be used as io.WriteCloser for uploading an object +type Object struct { + // putURL is a presigned URL for PutObject + putURL string + // deleteURL is a presigned URL for RemoveObject + deleteURL string + putHeaders map[string]string + size int64 + etag string + metrics bool + + *uploader +} + +type StatusCodeError error + +// NewObject opens an HTTP connection to Object Store and returns an Object pointer that can be used for uploading. 
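+// The returned Object is driven through the embedded uploader: callers stream
+// data with Consume and can then read the ETag extracted from the PUT
+// response (see the tests below for typical usage).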
+func NewObject(putURL, deleteURL string, putHeaders map[string]string, size int64) (*Object, error) { + return newObject(putURL, deleteURL, putHeaders, size, true) +} + +func newObject(putURL, deleteURL string, putHeaders map[string]string, size int64, metrics bool) (*Object, error) { + o := &Object{ + putURL: putURL, + deleteURL: deleteURL, + putHeaders: putHeaders, + size: size, + metrics: metrics, + } + + o.uploader = newETagCheckUploader(o, metrics) + return o, nil +} + +func (o *Object) Upload(ctx context.Context, r io.Reader) error { + // we should prevent pr.Close() otherwise it may shadow error set with pr.CloseWithError(err) + req, err := http.NewRequest(http.MethodPut, o.putURL, ioutil.NopCloser(r)) + + if err != nil { + return fmt.Errorf("PUT %q: %v", mask.URL(o.putURL), err) + } + req.ContentLength = o.size + + for k, v := range o.putHeaders { + req.Header.Set(k, v) + } + + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("PUT request %q: %v", mask.URL(o.putURL), err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if o.metrics { + objectStorageUploadRequestsInvalidStatus.Inc() + } + return StatusCodeError(fmt.Errorf("PUT request %v returned: %s", mask.URL(o.putURL), resp.Status)) + } + + o.etag = extractETag(resp.Header.Get("ETag")) + + return nil +} + +func (o *Object) ETag() string { + return o.etag +} + +func (o *Object) Abort() { + o.Delete() +} + +func (o *Object) Delete() { + deleteURL(o.deleteURL) +} diff --git a/workhorse/internal/objectstore/object_test.go b/workhorse/internal/objectstore/object_test.go new file mode 100644 index 00000000000..2ec45520e97 --- /dev/null +++ b/workhorse/internal/objectstore/object_test.go @@ -0,0 +1,155 @@ +package objectstore_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test" +) + +const testTimeout = 10 * time.Second + +type osFactory func() (*test.ObjectstoreStub, *httptest.Server) + +func testObjectUploadNoErrors(t *testing.T, startObjectStore osFactory, useDeleteURL bool, contentType string) { + osStub, ts := startObjectStore() + defer ts.Close() + + objectURL := ts.URL + test.ObjectPath + var deleteURL string + if useDeleteURL { + deleteURL = objectURL + } + + putHeaders := map[string]string{"Content-Type": contentType} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + deadline := time.Now().Add(testTimeout) + object, err := objectstore.NewObject(objectURL, deleteURL, putHeaders, test.ObjectSize) + require.NoError(t, err) + + // copy data + n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline) + require.NoError(t, err) + require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch") + + require.Equal(t, contentType, osStub.GetHeader(test.ObjectPath, "Content-Type")) + + // Checking MD5 extraction + require.Equal(t, osStub.GetObjectMD5(test.ObjectPath), object.ETag()) + + // Checking cleanup + cancel() + require.Equal(t, 1, osStub.PutsCnt(), "Object hasn't been uploaded") + + var expectedDeleteCnt int + if useDeleteURL { + expectedDeleteCnt = 1 + } + // Poll because the object removal is async + for i := 0; i < 100; i++ { + if osStub.DeletesCnt() == expectedDeleteCnt { + break + } + time.Sleep(10 * time.Millisecond) + } + + if useDeleteURL { + require.Equal(t, 1, osStub.DeletesCnt(), "Object hasn't been 
deleted") + } else { + require.Equal(t, 0, osStub.DeletesCnt(), "Object has been deleted") + } +} + +func TestObjectUpload(t *testing.T) { + t.Run("with delete URL", func(t *testing.T) { + testObjectUploadNoErrors(t, test.StartObjectStore, true, "application/octet-stream") + }) + t.Run("without delete URL", func(t *testing.T) { + testObjectUploadNoErrors(t, test.StartObjectStore, false, "application/octet-stream") + }) + t.Run("with custom content type", func(t *testing.T) { + testObjectUploadNoErrors(t, test.StartObjectStore, false, "image/jpeg") + }) + t.Run("with upcase ETAG", func(t *testing.T) { + factory := func() (*test.ObjectstoreStub, *httptest.Server) { + md5s := map[string]string{ + test.ObjectPath: strings.ToUpper(test.ObjectMD5), + } + + return test.StartObjectStoreWithCustomMD5(md5s) + } + + testObjectUploadNoErrors(t, factory, false, "application/octet-stream") + }) +} + +func TestObjectUpload404(t *testing.T) { + ts := httptest.NewServer(http.NotFoundHandler()) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + deadline := time.Now().Add(testTimeout) + objectURL := ts.URL + test.ObjectPath + object, err := objectstore.NewObject(objectURL, "", map[string]string{}, test.ObjectSize) + require.NoError(t, err) + _, err = object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline) + + require.Error(t, err) + _, isStatusCodeError := err.(objectstore.StatusCodeError) + require.True(t, isStatusCodeError, "Should fail with StatusCodeError") + require.Contains(t, err.Error(), "404") +} + +type endlessReader struct{} + +func (e *endlessReader) Read(p []byte) (n int, err error) { + for i := 0; i < len(p); i++ { + p[i] = '*' + } + + return len(p), nil +} + +// TestObjectUploadBrokenConnection purpose is to ensure that errors caused by the upload destination get propagated back correctly. +// This is important for troubleshooting in production. 
+func TestObjectUploadBrokenConnection(t *testing.T) { + // This test server closes connection immediately + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + require.FailNow(t, "webserver doesn't support hijacking") + } + conn, _, err := hj.Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + conn.Close() + })) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + deadline := time.Now().Add(testTimeout) + objectURL := ts.URL + test.ObjectPath + object, err := objectstore.NewObject(objectURL, "", map[string]string{}, -1) + require.NoError(t, err) + + _, copyErr := object.Consume(ctx, &endlessReader{}, deadline) + require.Error(t, copyErr) + require.NotEqual(t, io.ErrClosedPipe, copyErr, "We are shadowing the real error") +} diff --git a/workhorse/internal/objectstore/prometheus.go b/workhorse/internal/objectstore/prometheus.go new file mode 100644 index 00000000000..20762fb52bc --- /dev/null +++ b/workhorse/internal/objectstore/prometheus.go @@ -0,0 +1,39 @@ +package objectstore + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + objectStorageUploadRequests = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_object_storage_upload_requests", + Help: "How many object storage requests have been processed", + }, + []string{"status"}, + ) + objectStorageUploadsOpen = promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "gitlab_workhorse_object_storage_upload_open", + Help: "Describes many object storage requests are open now", + }, + ) + objectStorageUploadBytes = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_object_storage_upload_bytes", + Help: "How many bytes were sent to object storage", + }, + ) + objectStorageUploadTime = promauto.NewHistogram( + prometheus.HistogramOpts{ + Name: "gitlab_workhorse_object_storage_upload_time", + Help: "How long it took to upload objects", + Buckets: objectStorageUploadTimeBuckets, + }) + + objectStorageUploadRequestsRequestFailed = objectStorageUploadRequests.WithLabelValues("request-failed") + objectStorageUploadRequestsInvalidStatus = objectStorageUploadRequests.WithLabelValues("invalid-status") + + objectStorageUploadTimeBuckets = []float64{.1, .25, .5, 1, 2.5, 5, 10, 25, 50, 100} +) diff --git a/workhorse/internal/objectstore/s3_complete_multipart_api.go b/workhorse/internal/objectstore/s3_complete_multipart_api.go new file mode 100644 index 00000000000..b84f5757f49 --- /dev/null +++ b/workhorse/internal/objectstore/s3_complete_multipart_api.go @@ -0,0 +1,51 @@ +package objectstore + +import ( + "encoding/xml" + "fmt" +) + +// CompleteMultipartUpload is the S3 CompleteMultipartUpload body +type CompleteMultipartUpload struct { + Part []*completeMultipartUploadPart +} + +type completeMultipartUploadPart struct { + PartNumber int + ETag string +} + +// CompleteMultipartUploadResult is the S3 answer to CompleteMultipartUpload request +type CompleteMultipartUploadResult struct { + Location string + Bucket string + Key string + ETag string +} + +// CompleteMultipartUploadError is the in-body error structure +// https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadComplete.html#mpUploadComplete-examples +// the answer contains other fields we are not using +type CompleteMultipartUploadError struct { + XMLName xml.Name `xml:"Error"` + Code string + 
Message string +} + +func (c *CompleteMultipartUploadError) Error() string { + return fmt.Sprintf("CompleteMultipartUpload remote error %q: %s", c.Code, c.Message) +} + +// compoundCompleteMultipartUploadResult holds both CompleteMultipartUploadResult and CompleteMultipartUploadError +// this allow us to deserialize the response body where the root element can either be Error orCompleteMultipartUploadResult +type compoundCompleteMultipartUploadResult struct { + *CompleteMultipartUploadResult + *CompleteMultipartUploadError + + // XMLName this overrides CompleteMultipartUploadError.XMLName tags + XMLName xml.Name +} + +func (c *compoundCompleteMultipartUploadResult) isError() bool { + return c.CompleteMultipartUploadError != nil +} diff --git a/workhorse/internal/objectstore/s3_object.go b/workhorse/internal/objectstore/s3_object.go new file mode 100644 index 00000000000..1f79f88224f --- /dev/null +++ b/workhorse/internal/objectstore/s3_object.go @@ -0,0 +1,119 @@ +package objectstore + +import ( + "context" + "io" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" +) + +type S3Object struct { + credentials config.S3Credentials + config config.S3Config + objectName string + uploaded bool + + *uploader +} + +func NewS3Object(objectName string, s3Credentials config.S3Credentials, s3Config config.S3Config) (*S3Object, error) { + o := &S3Object{ + credentials: s3Credentials, + config: s3Config, + objectName: objectName, + } + + o.uploader = newUploader(o) + return o, nil +} + +func setEncryptionOptions(input *s3manager.UploadInput, s3Config config.S3Config) { + if s3Config.ServerSideEncryption != "" { + input.ServerSideEncryption = aws.String(s3Config.ServerSideEncryption) + + if s3Config.ServerSideEncryption == s3.ServerSideEncryptionAwsKms && s3Config.SSEKMSKeyID != "" { + input.SSEKMSKeyId = aws.String(s3Config.SSEKMSKeyID) + } + } +} + +func (s *S3Object) Upload(ctx context.Context, r io.Reader) error { + sess, err := setupS3Session(s.credentials, s.config) + if err != nil { + log.WithError(err).Error("error creating S3 session") + return err + } + + uploader := s3manager.NewUploader(sess) + + input := &s3manager.UploadInput{ + Bucket: aws.String(s.config.Bucket), + Key: aws.String(s.objectName), + Body: r, + } + + setEncryptionOptions(input, s.config) + + _, err = uploader.UploadWithContext(ctx, input) + if err != nil { + log.WithError(err).Error("error uploading S3 session") + // Get the root cause, such as ErrEntityTooLarge, so we can return the proper HTTP status code + return unwrapAWSError(err) + } + + s.uploaded = true + + return nil +} + +func (s *S3Object) ETag() string { + return "" +} + +func (s *S3Object) Abort() { + s.Delete() +} + +func (s *S3Object) Delete() { + if !s.uploaded { + return + } + + session, err := setupS3Session(s.credentials, s.config) + if err != nil { + log.WithError(err).Error("error setting up S3 session in delete") + return + } + + svc := s3.New(session) + input := &s3.DeleteObjectInput{ + Bucket: aws.String(s.config.Bucket), + Key: aws.String(s.objectName), + } + + // Note we can't use the request context because in a successful + // case, the original request has already completed. 
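+	// Best-effort cleanup: the delete runs on a detached context with a fixed
+	// timeout, and any failure is only logged rather than returned.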
+ deleteCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // lint:allow context.Background + defer cancel() + + _, err = svc.DeleteObjectWithContext(deleteCtx, input) + if err != nil { + log.WithError(err).Error("error deleting S3 object", err) + } +} + +// This is needed until https://github.com/aws/aws-sdk-go/issues/2820 is closed. +func unwrapAWSError(e error) error { + if awsErr, ok := e.(awserr.Error); ok { + return unwrapAWSError(awsErr.OrigErr()) + } + + return e +} diff --git a/workhorse/internal/objectstore/s3_object_test.go b/workhorse/internal/objectstore/s3_object_test.go new file mode 100644 index 00000000000..d9ebbd7f979 --- /dev/null +++ b/workhorse/internal/objectstore/s3_object_test.go @@ -0,0 +1,174 @@ +package objectstore_test + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" +) + +type failedReader struct { + io.Reader +} + +func (r *failedReader) Read(p []byte) (int, error) { + origErr := fmt.Errorf("entity is too large") + return 0, awserr.New("Read", "read failed", origErr) +} + +func TestS3ObjectUpload(t *testing.T) { + testCases := []struct { + encryption string + }{ + {encryption: ""}, + {encryption: s3.ServerSideEncryptionAes256}, + {encryption: s3.ServerSideEncryptionAwsKms}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("encryption=%v", tc.encryption), func(t *testing.T) { + creds, config, sess, ts := test.SetupS3(t, tc.encryption) + defer ts.Close() + + deadline := time.Now().Add(testTimeout) + tmpDir, err := ioutil.TempDir("", "workhorse-test-") + require.NoError(t, err) + defer os.Remove(tmpDir) + + objectName := filepath.Join(tmpDir, "s3-test-data") + ctx, cancel := context.WithCancel(context.Background()) + + object, err := objectstore.NewS3Object(objectName, creds, config) + require.NoError(t, err) + + // copy data + n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline) + require.NoError(t, err) + require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch") + + test.S3ObjectExists(t, sess, config, objectName, test.ObjectContent) + test.CheckS3Metadata(t, sess, config, objectName) + + cancel() + + testhelper.Retry(t, 5*time.Second, func() error { + if test.S3ObjectDoesNotExist(t, sess, config, objectName) { + return nil + } + + return fmt.Errorf("file is still present") + }) + }) + } +} + +func TestConcurrentS3ObjectUpload(t *testing.T) { + creds, uploadsConfig, uploadsSession, uploadServer := test.SetupS3WithBucket(t, "uploads", "") + defer uploadServer.Close() + + // This will return a separate S3 endpoint + _, artifactsConfig, artifactsSession, artifactsServer := test.SetupS3WithBucket(t, "artifacts", "") + defer artifactsServer.Close() + + deadline := time.Now().Add(testTimeout) + tmpDir, err := ioutil.TempDir("", "workhorse-test-") + require.NoError(t, err) + defer os.Remove(tmpDir) + + var wg sync.WaitGroup + + for i := 0; i < 4; i++ { + wg.Add(1) + + go func(index int) { + var sess *session.Session + var config config.S3Config + + if index%2 == 0 { + sess = uploadsSession + config = uploadsConfig + } 
else { + sess = artifactsSession + config = artifactsConfig + } + + name := fmt.Sprintf("s3-test-data-%d", index) + objectName := filepath.Join(tmpDir, name) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + object, err := objectstore.NewS3Object(objectName, creds, config) + require.NoError(t, err) + + // copy data + n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline) + require.NoError(t, err) + require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch") + + test.S3ObjectExists(t, sess, config, objectName, test.ObjectContent) + wg.Done() + }(i) + } + + wg.Wait() +} + +func TestS3ObjectUploadCancel(t *testing.T) { + creds, config, _, ts := test.SetupS3(t, "") + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + deadline := time.Now().Add(testTimeout) + tmpDir, err := ioutil.TempDir("", "workhorse-test-") + require.NoError(t, err) + defer os.Remove(tmpDir) + + objectName := filepath.Join(tmpDir, "s3-test-data") + + object, err := objectstore.NewS3Object(objectName, creds, config) + + require.NoError(t, err) + + // Cancel the transfer before the data has been copied to ensure + // we handle this gracefully. + cancel() + + _, err = object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline) + require.Error(t, err) + require.Equal(t, "context canceled", err.Error()) +} + +func TestS3ObjectUploadLimitReached(t *testing.T) { + creds, config, _, ts := test.SetupS3(t, "") + defer ts.Close() + + deadline := time.Now().Add(testTimeout) + tmpDir, err := ioutil.TempDir("", "workhorse-test-") + require.NoError(t, err) + defer os.Remove(tmpDir) + + objectName := filepath.Join(tmpDir, "s3-test-data") + object, err := objectstore.NewS3Object(objectName, creds, config) + require.NoError(t, err) + + _, err = object.Consume(context.Background(), &failedReader{}, deadline) + require.Error(t, err) + require.Equal(t, "entity is too large", err.Error()) +} diff --git a/workhorse/internal/objectstore/s3_session.go b/workhorse/internal/objectstore/s3_session.go new file mode 100644 index 00000000000..ebc8daf534c --- /dev/null +++ b/workhorse/internal/objectstore/s3_session.go @@ -0,0 +1,94 @@ +package objectstore + +import ( + "sync" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" +) + +type s3Session struct { + session *session.Session + expiry time.Time +} + +type s3SessionCache struct { + // An S3 session is cached by its input configuration (e.g. region, + // endpoint, path style, etc.), but the bucket is actually + // determined by the type of object to be uploaded (e.g. CI + // artifact, LFS, etc.) during runtime. In practice, we should only + // need one session per Workhorse process if we only allow one + // configuration for many different buckets. However, using a map + // indexed by the config avoids potential pitfalls in case the + // bucket configuration is supplied at startup or we need to support + // multiple S3 endpoints. 
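+	// sessions maps a full S3 configuration to a cached session and its
+	// expiry; setupS3Session replaces an entry once sessionExpiration elapses.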
+ sessions map[config.S3Config]*s3Session + sync.Mutex +} + +func (s *s3Session) isExpired() bool { + return time.Now().After(s.expiry) +} + +func newS3SessionCache() *s3SessionCache { + return &s3SessionCache{sessions: make(map[config.S3Config]*s3Session)} +} + +var ( + // By default, it looks like IAM instance profiles may last 6 hours + // (via curl http://169.254.169.254/latest/meta-data/iam/security-credentials/<role_name>), + // but this may be configurable from anywhere for 15 minutes to 12 + // hours. To be safe, refresh AWS sessions every 10 minutes. + sessionExpiration = time.Duration(10 * time.Minute) + sessionCache = newS3SessionCache() +) + +// SetupS3Session initializes a new AWS S3 session and refreshes one if +// necessary. As recommended in https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/sessions.html, +// sessions should be cached when possible. Sessions are safe to use +// concurrently as long as the session isn't modified. +func setupS3Session(s3Credentials config.S3Credentials, s3Config config.S3Config) (*session.Session, error) { + sessionCache.Lock() + defer sessionCache.Unlock() + + if s, ok := sessionCache.sessions[s3Config]; ok && !s.isExpired() { + return s.session, nil + } + + cfg := &aws.Config{ + Region: aws.String(s3Config.Region), + S3ForcePathStyle: aws.Bool(s3Config.PathStyle), + } + + // In case IAM profiles aren't being used, use the static credentials + if s3Credentials.AwsAccessKeyID != "" && s3Credentials.AwsSecretAccessKey != "" { + cfg.Credentials = credentials.NewStaticCredentials(s3Credentials.AwsAccessKeyID, s3Credentials.AwsSecretAccessKey, "") + } + + if s3Config.Endpoint != "" { + cfg.Endpoint = aws.String(s3Config.Endpoint) + } + + sess, err := session.NewSession(cfg) + if err != nil { + return nil, err + } + + sessionCache.sessions[s3Config] = &s3Session{ + expiry: time.Now().Add(sessionExpiration), + session: sess, + } + + return sess, nil +} + +func ResetS3Session(s3Config config.S3Config) { + sessionCache.Lock() + defer sessionCache.Unlock() + + delete(sessionCache.sessions, s3Config) +} diff --git a/workhorse/internal/objectstore/s3_session_test.go b/workhorse/internal/objectstore/s3_session_test.go new file mode 100644 index 00000000000..8601f305917 --- /dev/null +++ b/workhorse/internal/objectstore/s3_session_test.go @@ -0,0 +1,57 @@ +package objectstore + +import ( + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" +) + +func TestS3SessionSetup(t *testing.T) { + credentials := config.S3Credentials{} + cfg := config.S3Config{Region: "us-west-1", PathStyle: true} + + sess, err := setupS3Session(credentials, cfg) + require.NoError(t, err) + + require.Equal(t, aws.StringValue(sess.Config.Region), "us-west-1") + require.True(t, aws.BoolValue(sess.Config.S3ForcePathStyle)) + + require.Equal(t, len(sessionCache.sessions), 1) + anotherConfig := cfg + _, err = setupS3Session(credentials, anotherConfig) + require.NoError(t, err) + require.Equal(t, len(sessionCache.sessions), 1) + + ResetS3Session(cfg) +} + +func TestS3SessionExpiry(t *testing.T) { + credentials := config.S3Credentials{} + cfg := config.S3Config{Region: "us-west-1", PathStyle: true} + + sess, err := setupS3Session(credentials, cfg) + require.NoError(t, err) + + require.Equal(t, aws.StringValue(sess.Config.Region), "us-west-1") + require.True(t, aws.BoolValue(sess.Config.S3ForcePathStyle)) + + firstSession, ok := sessionCache.sessions[cfg] + require.True(t, ok) + 
require.False(t, firstSession.isExpired()) + + firstSession.expiry = time.Now().Add(-1 * time.Second) + require.True(t, firstSession.isExpired()) + + _, err = setupS3Session(credentials, cfg) + require.NoError(t, err) + + nextSession, ok := sessionCache.sessions[cfg] + require.True(t, ok) + require.False(t, nextSession.isExpired()) + + ResetS3Session(cfg) +} diff --git a/workhorse/internal/objectstore/test/consts.go b/workhorse/internal/objectstore/test/consts.go new file mode 100644 index 00000000000..7a1bcc28d45 --- /dev/null +++ b/workhorse/internal/objectstore/test/consts.go @@ -0,0 +1,19 @@ +package test + +// Some useful const for testing purpose +const ( + // ObjectContent an example textual content + ObjectContent = "TEST OBJECT CONTENT" + // ObjectSize is the ObjectContent size + ObjectSize = int64(len(ObjectContent)) + // Objectpath is an example remote object path (including bucket name) + ObjectPath = "/bucket/object" + // ObjectMD5 is ObjectContent MD5 hash + ObjectMD5 = "42d000eea026ee0760677e506189cb33" + // ObjectSHA1 is ObjectContent SHA1 hash + ObjectSHA1 = "173cfd58c6b60cb910f68a26cbb77e3fc5017a6d" + // ObjectSHA256 is ObjectContent SHA256 hash + ObjectSHA256 = "b0257e9e657ef19b15eed4fbba975bd5238d651977564035ef91cb45693647aa" + // ObjectSHA512 is ObjectContent SHA512 hash + ObjectSHA512 = "51af8197db2047f7894652daa7437927bf831d5aa63f1b0b7277c4800b06f5e3057251f0e4c2d344ca8c2daf1ffc08a28dd3b2f5fe0e316d3fd6c3af58c34b97" +) diff --git a/workhorse/internal/objectstore/test/gocloud_stub.go b/workhorse/internal/objectstore/test/gocloud_stub.go new file mode 100644 index 00000000000..cf22075e407 --- /dev/null +++ b/workhorse/internal/objectstore/test/gocloud_stub.go @@ -0,0 +1,47 @@ +package test + +import ( + "context" + "io/ioutil" + "net/url" + "os" + "testing" + + "github.com/stretchr/testify/require" + "gocloud.dev/blob" + "gocloud.dev/blob/fileblob" +) + +type dirOpener struct { + tmpDir string +} + +func (o *dirOpener) OpenBucketURL(ctx context.Context, u *url.URL) (*blob.Bucket, error) { + return fileblob.OpenBucket(o.tmpDir, nil) +} + +func SetupGoCloudFileBucket(t *testing.T, scheme string) (m *blob.URLMux, bucketDir string, cleanup func()) { + tmpDir, err := ioutil.TempDir("", "") + require.NoError(t, err) + + mux := new(blob.URLMux) + fake := &dirOpener{tmpDir: tmpDir} + mux.RegisterBucket(scheme, fake) + cleanup = func() { + os.RemoveAll(tmpDir) + } + + return mux, tmpDir, cleanup +} + +func GoCloudObjectExists(t *testing.T, bucketDir string, objectName string) { + bucket, err := fileblob.OpenBucket(bucketDir, nil) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) // lint:allow context.Background + defer cancel() + + exists, err := bucket.Exists(ctx, objectName) + require.NoError(t, err) + require.True(t, exists) +} diff --git a/workhorse/internal/objectstore/test/objectstore_stub.go b/workhorse/internal/objectstore/test/objectstore_stub.go new file mode 100644 index 00000000000..31ef4913305 --- /dev/null +++ b/workhorse/internal/objectstore/test/objectstore_stub.go @@ -0,0 +1,278 @@ +package test + +import ( + "crypto/md5" + "encoding/hex" + "encoding/xml" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "sync" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore" +) + +type partsEtagMap map[int]string + +// ObjectstoreStub is a testing implementation of ObjectStore. +// Instead of storing objects it will just save md5sum. 
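+// It answers PUT (object or multipart part upload), DELETE (object or
+// multipart removal) and POST (CompleteMultipartUpload), which is the subset
+// of the S3 API exercised by these tests.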
+type ObjectstoreStub struct { + // bucket contains md5sum of uploaded objects + bucket map[string]string + // overwriteMD5 contains overwrites for md5sum that should be return instead of the regular hash + overwriteMD5 map[string]string + // multipart is a map of MultipartUploads + multipart map[string]partsEtagMap + // HTTP header sent along request + headers map[string]*http.Header + + puts int + deletes int + + m sync.Mutex +} + +// StartObjectStore will start an ObjectStore stub +func StartObjectStore() (*ObjectstoreStub, *httptest.Server) { + return StartObjectStoreWithCustomMD5(make(map[string]string)) +} + +// StartObjectStoreWithCustomMD5 will start an ObjectStore stub: md5Hashes contains overwrites for md5sum that should be return on PutObject +func StartObjectStoreWithCustomMD5(md5Hashes map[string]string) (*ObjectstoreStub, *httptest.Server) { + os := &ObjectstoreStub{ + bucket: make(map[string]string), + multipart: make(map[string]partsEtagMap), + overwriteMD5: make(map[string]string), + headers: make(map[string]*http.Header), + } + + for k, v := range md5Hashes { + os.overwriteMD5[k] = v + } + + return os, httptest.NewServer(os) +} + +// PutsCnt counts PutObject invocations +func (o *ObjectstoreStub) PutsCnt() int { + o.m.Lock() + defer o.m.Unlock() + + return o.puts +} + +// DeletesCnt counts DeleteObject invocation of a valid object +func (o *ObjectstoreStub) DeletesCnt() int { + o.m.Lock() + defer o.m.Unlock() + + return o.deletes +} + +// GetObjectMD5 return the calculated MD5 of the object uploaded to path +// it will return an empty string if no object has been uploaded on such path +func (o *ObjectstoreStub) GetObjectMD5(path string) string { + o.m.Lock() + defer o.m.Unlock() + + return o.bucket[path] +} + +// GetHeader returns a given HTTP header of the object uploaded to the path +func (o *ObjectstoreStub) GetHeader(path, key string) string { + o.m.Lock() + defer o.m.Unlock() + + if val, ok := o.headers[path]; ok { + return val.Get(key) + } + + return "" +} + +// InitiateMultipartUpload prepare the ObjectstoreStob to receive a MultipartUpload on path +// It will return an error if a MultipartUpload is already in progress on that path +// InitiateMultipartUpload is only used during test setup. +// Workhorse's production code does not know how to initiate a multipart upload. +// +// Real S3 multipart uploads are more complicated than what we do here, +// but this is enough to verify that workhorse's production code behaves as intended. 
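+// Part ETags are stored per path, keyed by the partNumber query parameter of
+// each PUT; completeMultipartUpload later checks them against the ETags
+// listed in the client's CompleteMultipartUpload body.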
+func (o *ObjectstoreStub) InitiateMultipartUpload(path string) error { + o.m.Lock() + defer o.m.Unlock() + + if o.multipart[path] != nil { + return fmt.Errorf("MultipartUpload for %q already in progress", path) + } + + o.multipart[path] = make(partsEtagMap) + return nil +} + +// IsMultipartUpload check if the given path has a MultipartUpload in progress +func (o *ObjectstoreStub) IsMultipartUpload(path string) bool { + o.m.Lock() + defer o.m.Unlock() + + return o.isMultipartUpload(path) +} + +// isMultipartUpload is the lock free version of IsMultipartUpload +func (o *ObjectstoreStub) isMultipartUpload(path string) bool { + return o.multipart[path] != nil +} + +func (o *ObjectstoreStub) removeObject(w http.ResponseWriter, r *http.Request) { + o.m.Lock() + defer o.m.Unlock() + + objectPath := r.URL.Path + if o.isMultipartUpload(objectPath) { + o.deletes++ + delete(o.multipart, objectPath) + + w.WriteHeader(200) + } else if _, ok := o.bucket[objectPath]; ok { + o.deletes++ + delete(o.bucket, objectPath) + + w.WriteHeader(200) + } else { + w.WriteHeader(404) + } +} + +func (o *ObjectstoreStub) putObject(w http.ResponseWriter, r *http.Request) { + o.m.Lock() + defer o.m.Unlock() + + objectPath := r.URL.Path + + etag, overwritten := o.overwriteMD5[objectPath] + if !overwritten { + hasher := md5.New() + io.Copy(hasher, r.Body) + + checksum := hasher.Sum(nil) + etag = hex.EncodeToString(checksum) + } + + o.headers[objectPath] = &r.Header + o.puts++ + if o.isMultipartUpload(objectPath) { + pNumber := r.URL.Query().Get("partNumber") + idx, err := strconv.Atoi(pNumber) + if err != nil { + http.Error(w, fmt.Sprintf("malformed partNumber: %v", err), 400) + return + } + + o.multipart[objectPath][idx] = etag + } else { + o.bucket[objectPath] = etag + } + + w.Header().Set("ETag", etag) + w.WriteHeader(200) +} + +func MultipartUploadInternalError() *objectstore.CompleteMultipartUploadError { + return &objectstore.CompleteMultipartUploadError{Code: "InternalError", Message: "malformed object path"} +} + +func (o *ObjectstoreStub) completeMultipartUpload(w http.ResponseWriter, r *http.Request) { + o.m.Lock() + defer o.m.Unlock() + + objectPath := r.URL.Path + + multipart := o.multipart[objectPath] + if multipart == nil { + http.Error(w, "Unknown MultipartUpload", 404) + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + + var msg objectstore.CompleteMultipartUpload + err = xml.Unmarshal(buf, &msg) + if err != nil { + http.Error(w, err.Error(), 400) + return + } + + for _, part := range msg.Part { + etag := multipart[part.PartNumber] + if etag != part.ETag { + msg := fmt.Sprintf("ETag mismatch on part %d. 
Expected %q got %q", part.PartNumber, etag, part.ETag) + http.Error(w, msg, 400) + return + } + } + + etag, overwritten := o.overwriteMD5[objectPath] + if !overwritten { + etag = "CompleteMultipartUploadETag" + } + + o.bucket[objectPath] = etag + delete(o.multipart, objectPath) + + w.Header().Set("ETag", etag) + split := strings.SplitN(objectPath[1:], "/", 2) + if len(split) < 2 { + encodeXMLAnswer(w, MultipartUploadInternalError()) + return + } + + bucket := split[0] + key := split[1] + answer := objectstore.CompleteMultipartUploadResult{ + Location: r.URL.String(), + Bucket: bucket, + Key: key, + ETag: etag, + } + encodeXMLAnswer(w, answer) +} + +func encodeXMLAnswer(w http.ResponseWriter, answer interface{}) { + w.Header().Set("Content-Type", "text/xml") + + enc := xml.NewEncoder(w) + if err := enc.Encode(answer); err != nil { + http.Error(w, err.Error(), 500) + } +} + +func (o *ObjectstoreStub) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Body != nil { + defer r.Body.Close() + } + + fmt.Println("ObjectStore Stub:", r.Method, r.URL.String()) + + if r.URL.Path == "" { + http.Error(w, "No path provided", 404) + return + } + + switch r.Method { + case "DELETE": + o.removeObject(w, r) + case "PUT": + o.putObject(w, r) + case "POST": + o.completeMultipartUpload(w, r) + default: + w.WriteHeader(404) + } +} diff --git a/workhorse/internal/objectstore/test/objectstore_stub_test.go b/workhorse/internal/objectstore/test/objectstore_stub_test.go new file mode 100644 index 00000000000..8c0d52a2d79 --- /dev/null +++ b/workhorse/internal/objectstore/test/objectstore_stub_test.go @@ -0,0 +1,167 @@ +package test + +import ( + "fmt" + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func doRequest(method, url string, body io.Reader) error { + req, err := http.NewRequest(method, url, body) + if err != nil { + return err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + + return resp.Body.Close() +} + +func TestObjectStoreStub(t *testing.T) { + stub, ts := StartObjectStore() + defer ts.Close() + + require.Equal(t, 0, stub.PutsCnt()) + require.Equal(t, 0, stub.DeletesCnt()) + + objectURL := ts.URL + ObjectPath + + require.NoError(t, doRequest(http.MethodPut, objectURL, strings.NewReader(ObjectContent))) + + require.Equal(t, 1, stub.PutsCnt()) + require.Equal(t, 0, stub.DeletesCnt()) + require.Equal(t, ObjectMD5, stub.GetObjectMD5(ObjectPath)) + + require.NoError(t, doRequest(http.MethodDelete, objectURL, nil)) + + require.Equal(t, 1, stub.PutsCnt()) + require.Equal(t, 1, stub.DeletesCnt()) +} + +func TestObjectStoreStubDelete404(t *testing.T) { + stub, ts := StartObjectStore() + defer ts.Close() + + objectURL := ts.URL + ObjectPath + + req, err := http.NewRequest(http.MethodDelete, objectURL, nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, 404, resp.StatusCode) + + require.Equal(t, 0, stub.DeletesCnt()) +} + +func TestObjectStoreInitiateMultipartUpload(t *testing.T) { + stub, ts := StartObjectStore() + defer ts.Close() + + path := "/my-multipart" + err := stub.InitiateMultipartUpload(path) + require.NoError(t, err) + + err = stub.InitiateMultipartUpload(path) + require.Error(t, err, "second attempt to open the same MultipartUpload") +} + +func TestObjectStoreCompleteMultipartUpload(t *testing.T) { + stub, ts := StartObjectStore() + defer ts.Close() + + objectURL := ts.URL + ObjectPath + parts := []struct { + number int + 
content string + contentMD5 string + }{ + { + number: 1, + content: "first part", + contentMD5: "550cf6b6e60f65a0e3104a26e70fea42", + }, { + number: 2, + content: "second part", + contentMD5: "920b914bca0a70780b40881b8f376135", + }, + } + + stub.InitiateMultipartUpload(ObjectPath) + + require.True(t, stub.IsMultipartUpload(ObjectPath)) + require.Equal(t, 0, stub.PutsCnt()) + require.Equal(t, 0, stub.DeletesCnt()) + + // Workhorse knows nothing about S3 MultipartUpload, it receives some URLs + // from GitLab-rails and PUTs chunk of data to each of them. + // Then it completes the upload with a final POST + partPutURLs := []string{ + fmt.Sprintf("%s?partNumber=%d", objectURL, 1), + fmt.Sprintf("%s?partNumber=%d", objectURL, 2), + } + completePostURL := objectURL + + for i, partPutURL := range partPutURLs { + part := parts[i] + + require.NoError(t, doRequest(http.MethodPut, partPutURL, strings.NewReader(part.content))) + + require.Equal(t, i+1, stub.PutsCnt()) + require.Equal(t, 0, stub.DeletesCnt()) + require.Equal(t, part.contentMD5, stub.multipart[ObjectPath][part.number], "Part %d was not uploaded into ObjectStorage", part.number) + require.Empty(t, stub.GetObjectMD5(ObjectPath), "Part %d was mistakenly uploaded as a single object", part.number) + require.True(t, stub.IsMultipartUpload(ObjectPath), "MultipartUpload completed or aborted") + } + + completeBody := fmt.Sprintf(`<CompleteMultipartUpload> + <Part> + <PartNumber>1</PartNumber> + <ETag>%s</ETag> + </Part> + <Part> + <PartNumber>2</PartNumber> + <ETag>%s</ETag> + </Part> + </CompleteMultipartUpload>`, parts[0].contentMD5, parts[1].contentMD5) + require.NoError(t, doRequest(http.MethodPost, completePostURL, strings.NewReader(completeBody))) + + require.Equal(t, len(parts), stub.PutsCnt()) + require.Equal(t, 0, stub.DeletesCnt()) + require.False(t, stub.IsMultipartUpload(ObjectPath), "MultipartUpload is still in progress") +} + +func TestObjectStoreAbortMultipartUpload(t *testing.T) { + stub, ts := StartObjectStore() + defer ts.Close() + + stub.InitiateMultipartUpload(ObjectPath) + + require.True(t, stub.IsMultipartUpload(ObjectPath)) + require.Equal(t, 0, stub.PutsCnt()) + require.Equal(t, 0, stub.DeletesCnt()) + + objectURL := ts.URL + ObjectPath + require.NoError(t, doRequest(http.MethodPut, fmt.Sprintf("%s?partNumber=%d", objectURL, 1), strings.NewReader(ObjectContent))) + + require.Equal(t, 1, stub.PutsCnt()) + require.Equal(t, 0, stub.DeletesCnt()) + require.Equal(t, ObjectMD5, stub.multipart[ObjectPath][1], "Part was not uploaded into ObjectStorage") + require.Empty(t, stub.GetObjectMD5(ObjectPath), "Part was mistakenly uploaded as a single object") + require.True(t, stub.IsMultipartUpload(ObjectPath), "MultipartUpload completed or aborted") + + require.NoError(t, doRequest(http.MethodDelete, objectURL, nil)) + + require.Equal(t, 1, stub.PutsCnt()) + require.Equal(t, 1, stub.DeletesCnt()) + require.Empty(t, stub.GetObjectMD5(ObjectPath), "MultiUpload has been completed") + require.False(t, stub.IsMultipartUpload(ObjectPath), "MultiUpload is still in progress") +} diff --git a/workhorse/internal/objectstore/test/s3_stub.go b/workhorse/internal/objectstore/test/s3_stub.go new file mode 100644 index 00000000000..36514b3b887 --- /dev/null +++ b/workhorse/internal/objectstore/test/s3_stub.go @@ -0,0 +1,142 @@ +package test + +import ( + "io/ioutil" + "net/http/httptest" + "os" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + 
"github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + + "github.com/johannesboyne/gofakes3" + "github.com/johannesboyne/gofakes3/backend/s3mem" +) + +func SetupS3(t *testing.T, encryption string) (config.S3Credentials, config.S3Config, *session.Session, *httptest.Server) { + return SetupS3WithBucket(t, "test-bucket", encryption) +} + +func SetupS3WithBucket(t *testing.T, bucket string, encryption string) (config.S3Credentials, config.S3Config, *session.Session, *httptest.Server) { + backend := s3mem.New() + faker := gofakes3.New(backend) + ts := httptest.NewServer(faker.Server()) + + creds := config.S3Credentials{ + AwsAccessKeyID: "YOUR-ACCESSKEYID", + AwsSecretAccessKey: "YOUR-SECRETACCESSKEY", + } + + config := config.S3Config{ + Bucket: bucket, + Endpoint: ts.URL, + Region: "eu-central-1", + PathStyle: true, + } + + if encryption != "" { + config.ServerSideEncryption = encryption + + if encryption == s3.ServerSideEncryptionAwsKms { + config.SSEKMSKeyID = "arn:aws:1234" + } + } + + sess, err := session.NewSession(&aws.Config{ + Credentials: credentials.NewStaticCredentials(creds.AwsAccessKeyID, creds.AwsSecretAccessKey, ""), + Endpoint: aws.String(ts.URL), + Region: aws.String(config.Region), + DisableSSL: aws.Bool(true), + S3ForcePathStyle: aws.Bool(true), + }) + require.NoError(t, err) + + // Create S3 service client + svc := s3.New(sess) + + _, err = svc.CreateBucket(&s3.CreateBucketInput{ + Bucket: aws.String(bucket), + }) + require.NoError(t, err) + + return creds, config, sess, ts +} + +// S3ObjectExists will fail the test if the file does not exist. +func S3ObjectExists(t *testing.T, sess *session.Session, config config.S3Config, objectName string, expectedBytes string) { + downloadObject(t, sess, config, objectName, func(tmpfile *os.File, numBytes int64, err error) { + require.NoError(t, err) + require.Equal(t, int64(len(expectedBytes)), numBytes) + + output, err := ioutil.ReadFile(tmpfile.Name()) + require.NoError(t, err) + + require.Equal(t, []byte(expectedBytes), output) + }) +} + +func CheckS3Metadata(t *testing.T, sess *session.Session, config config.S3Config, objectName string) { + // In a real S3 provider, s3crypto.NewDecryptionClient should probably be used + svc := s3.New(sess) + result, err := svc.GetObject(&s3.GetObjectInput{ + Bucket: aws.String(config.Bucket), + Key: aws.String(objectName), + }) + require.NoError(t, err) + + if config.ServerSideEncryption != "" { + require.Equal(t, aws.String(config.ServerSideEncryption), result.ServerSideEncryption) + + if config.ServerSideEncryption == s3.ServerSideEncryptionAwsKms { + require.Equal(t, aws.String(config.SSEKMSKeyID), result.SSEKMSKeyId) + } else { + require.Nil(t, result.SSEKMSKeyId) + } + } else { + require.Nil(t, result.ServerSideEncryption) + require.Nil(t, result.SSEKMSKeyId) + } +} + +// S3ObjectDoesNotExist returns true if the object has been deleted, +// false otherwise. The return signature is different from +// S3ObjectExists because deletion may need to be retried since deferred +// clean up callsinternal/objectstore/test/s3_stub.go may cause the actual deletion to happen after the +// initial check. 
+func S3ObjectDoesNotExist(t *testing.T, sess *session.Session, config config.S3Config, objectName string) bool { + deleted := false + + downloadObject(t, sess, config, objectName, func(tmpfile *os.File, numBytes int64, err error) { + if err != nil && strings.Contains(err.Error(), "NoSuchKey") { + deleted = true + } + }) + + return deleted +} + +func downloadObject(t *testing.T, sess *session.Session, config config.S3Config, objectName string, handler func(tmpfile *os.File, numBytes int64, err error)) { + tmpDir, err := ioutil.TempDir("", "workhorse-test-") + require.NoError(t, err) + defer os.Remove(tmpDir) + + tmpfile, err := ioutil.TempFile(tmpDir, "s3-output") + require.NoError(t, err) + defer os.Remove(tmpfile.Name()) + + downloadSvc := s3manager.NewDownloader(sess) + numBytes, err := downloadSvc.Download(tmpfile, &s3.GetObjectInput{ + Bucket: aws.String(config.Bucket), + Key: aws.String(objectName), + }) + + handler(tmpfile, numBytes, err) +} diff --git a/workhorse/internal/objectstore/upload_strategy.go b/workhorse/internal/objectstore/upload_strategy.go new file mode 100644 index 00000000000..5707ba5f24e --- /dev/null +++ b/workhorse/internal/objectstore/upload_strategy.go @@ -0,0 +1,46 @@ +package objectstore + +import ( + "context" + "io" + "net/http" + + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/mask" +) + +type uploadStrategy interface { + Upload(ctx context.Context, r io.Reader) error + ETag() string + Abort() + Delete() +} + +func deleteURL(url string) { + if url == "" { + return + } + + req, err := http.NewRequest("DELETE", url, nil) + if err != nil { + log.WithError(err).WithField("object", mask.URL(url)).Warning("Delete failed") + return + } + // TODO: consider adding the context to the outgoing request for better instrumentation + + // here we are not using u.ctx because we must perform cleanup regardless of parent context + resp, err := httpClient.Do(req) + if err != nil { + log.WithError(err).WithField("object", mask.URL(url)).Warning("Delete failed") + return + } + resp.Body.Close() +} + +func extractETag(rawETag string) string { + if rawETag != "" && rawETag[0] == '"' { + rawETag = rawETag[1 : len(rawETag)-1] + } + + return rawETag +} diff --git a/workhorse/internal/objectstore/uploader.go b/workhorse/internal/objectstore/uploader.go new file mode 100644 index 00000000000..aedfbe55ead --- /dev/null +++ b/workhorse/internal/objectstore/uploader.go @@ -0,0 +1,115 @@ +package objectstore + +import ( + "context" + "crypto/md5" + "encoding/hex" + "fmt" + "hash" + "io" + "strings" + "time" + + "gitlab.com/gitlab-org/labkit/log" +) + +// uploader consumes an io.Reader and uploads it using a pluggable uploadStrategy. +type uploader struct { + strategy uploadStrategy + + // In the case of S3 uploads, we have a multipart upload which + // instantiates uploads for the individual parts. We don't want to + // increment metrics for the individual parts, so that is why we have + // this boolean flag. + metrics bool + + // With S3 we compare the MD5 of the data we sent with the ETag returned + // by the object storage server. + checkETag bool +} + +func newUploader(strategy uploadStrategy) *uploader { + return &uploader{strategy: strategy, metrics: true} +} + +func newETagCheckUploader(strategy uploadStrategy, metrics bool) *uploader { + return &uploader{strategy: strategy, metrics: metrics, checkETag: true} +} + +func hexString(h hash.Hash) string { return hex.EncodeToString(h.Sum(nil)) } + +// Consume reads the reader until it reaches EOF or an error. 
It spawns a +// goroutine that waits for outerCtx to be done, after which the remote +// file is deleted. The deadline applies to the upload performed inside +// Consume, not to outerCtx. +func (u *uploader) Consume(outerCtx context.Context, reader io.Reader, deadline time.Time) (_ int64, err error) { + if u.metrics { + objectStorageUploadsOpen.Inc() + defer func(started time.Time) { + objectStorageUploadsOpen.Dec() + objectStorageUploadTime.Observe(time.Since(started).Seconds()) + if err != nil { + objectStorageUploadRequestsRequestFailed.Inc() + } + }(time.Now()) + } + + defer func() { + // We do this mainly to abort S3 multipart uploads: it is not enough to + // "delete" them. + if err != nil { + u.strategy.Abort() + } + }() + + go func() { + // Once gitlab-rails is done handling the request, we are supposed to + // delete the upload from its temporary location. + <-outerCtx.Done() + u.strategy.Delete() + }() + + uploadCtx, cancelFn := context.WithDeadline(outerCtx, deadline) + defer cancelFn() + + var hasher hash.Hash + if u.checkETag { + hasher = md5.New() + reader = io.TeeReader(reader, hasher) + } + + cr := &countReader{r: reader} + if err := u.strategy.Upload(uploadCtx, cr); err != nil { + return cr.n, err + } + + if u.checkETag { + if err := compareMD5(hexString(hasher), u.strategy.ETag()); err != nil { + log.ContextLogger(uploadCtx).WithError(err).Error("error comparing MD5 checksum") + return cr.n, err + } + } + + objectStorageUploadBytes.Add(float64(cr.n)) + + return cr.n, nil +} + +func compareMD5(local, remote string) error { + if !strings.EqualFold(local, remote) { + return fmt.Errorf("ETag mismatch. expected %q got %q", local, remote) + } + + return nil +} + +type countReader struct { + r io.Reader + n int64 +} + +func (cr *countReader) Read(p []byte) (int, error) { + nRead, err := cr.r.Read(p) + cr.n += int64(nRead) + return nRead, err +} diff --git a/workhorse/internal/proxy/proxy.go b/workhorse/internal/proxy/proxy.go new file mode 100644 index 00000000000..1bc417a841f --- /dev/null +++ b/workhorse/internal/proxy/proxy.go @@ -0,0 +1,62 @@ +package proxy + +import ( + "fmt" + "net/http" + "net/http/httputil" + "net/url" + "time" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +var ( + defaultTarget = helper.URLMustParse("http://localhost") +) + +type Proxy struct { + Version string + reverseProxy *httputil.ReverseProxy + AllowResponseBuffering bool +} + +func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper) *Proxy { + p := Proxy{Version: version, AllowResponseBuffering: true} + + if myURL == nil { + myURL = defaultTarget + } + + u := *myURL // Make a copy of p.URL + u.Path = "" + p.reverseProxy = httputil.NewSingleHostReverseProxy(&u) + p.reverseProxy.Transport = roundTripper + return &p +} + +func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Clone request + req := *r + req.Header = helper.HeaderClone(r.Header) + + // Set Workhorse version + req.Header.Set("Gitlab-Workhorse", p.Version) + req.Header.Set("Gitlab-Workhorse-Proxy-Start", fmt.Sprintf("%d", time.Now().UnixNano())) + + if p.AllowResponseBuffering { + helper.AllowResponseBuffering(w) + } + + // If the ultimate client disconnects when the response isn't fully written + // to them yet, httputil.ReverseProxy panics with a net/http.ErrAbortHandler + // error. 
We can catch and discard this to keep the error log clean + defer func() { + if err := recover(); err != nil { + if err != http.ErrAbortHandler { + panic(err) + } + } + }() + + p.reverseProxy.ServeHTTP(w, &req) +} diff --git a/workhorse/internal/queueing/queue.go b/workhorse/internal/queueing/queue.go new file mode 100644 index 00000000000..db082cf19c6 --- /dev/null +++ b/workhorse/internal/queueing/queue.go @@ -0,0 +1,201 @@ +package queueing + +import ( + "errors" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +type errTooManyRequests struct{ error } +type errQueueingTimedout struct{ error } + +var ErrTooManyRequests = &errTooManyRequests{errors.New("too many requests queued")} +var ErrQueueingTimedout = &errQueueingTimedout{errors.New("queueing timedout")} + +type queueMetrics struct { + queueingLimit prometheus.Gauge + queueingQueueLimit prometheus.Gauge + queueingQueueTimeout prometheus.Gauge + queueingBusy prometheus.Gauge + queueingWaiting prometheus.Gauge + queueingWaitingTime prometheus.Histogram + queueingErrors *prometheus.CounterVec +} + +// newQueueMetrics prepares Prometheus metrics for queueing mechanism +// name specifies name of the queue, used to label metrics with ConstLabel `queue_name` +// Don't call newQueueMetrics twice with the same name argument! +// timeout specifies the timeout of storing a request in queue - queueMetrics +// uses it to calculate histogram buckets for gitlab_workhorse_queueing_waiting_time +// metric +func newQueueMetrics(name string, timeout time.Duration) *queueMetrics { + waitingTimeBuckets := []float64{ + timeout.Seconds() * 0.01, + timeout.Seconds() * 0.05, + timeout.Seconds() * 0.10, + timeout.Seconds() * 0.25, + timeout.Seconds() * 0.50, + timeout.Seconds() * 0.75, + timeout.Seconds() * 0.90, + timeout.Seconds() * 0.95, + timeout.Seconds() * 0.99, + timeout.Seconds(), + } + + metrics := &queueMetrics{ + queueingLimit: promauto.NewGauge(prometheus.GaugeOpts{ + Name: "gitlab_workhorse_queueing_limit", + Help: "Current limit set for the queueing mechanism", + ConstLabels: prometheus.Labels{ + "queue_name": name, + }, + }), + + queueingQueueLimit: promauto.NewGauge(prometheus.GaugeOpts{ + Name: "gitlab_workhorse_queueing_queue_limit", + Help: "Current queueLimit set for the queueing mechanism", + ConstLabels: prometheus.Labels{ + "queue_name": name, + }, + }), + + queueingQueueTimeout: promauto.NewGauge(prometheus.GaugeOpts{ + Name: "gitlab_workhorse_queueing_queue_timeout", + Help: "Current queueTimeout set for the queueing mechanism", + ConstLabels: prometheus.Labels{ + "queue_name": name, + }, + }), + + queueingBusy: promauto.NewGauge(prometheus.GaugeOpts{ + Name: "gitlab_workhorse_queueing_busy", + Help: "How many queued requests are now processed", + ConstLabels: prometheus.Labels{ + "queue_name": name, + }, + }), + + queueingWaiting: promauto.NewGauge(prometheus.GaugeOpts{ + Name: "gitlab_workhorse_queueing_waiting", + Help: "How many requests are now queued", + ConstLabels: prometheus.Labels{ + "queue_name": name, + }, + }), + + queueingWaitingTime: promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "gitlab_workhorse_queueing_waiting_time", + Help: "How many time a request spent in queue", + ConstLabels: prometheus.Labels{ + "queue_name": name, + }, + Buckets: waitingTimeBuckets, + }), + + queueingErrors: promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_queueing_errors", + Help: "How many times the TooManyRequests or 
QueueintTimedout errors were returned while queueing, partitioned by error type", + ConstLabels: prometheus.Labels{ + "queue_name": name, + }, + }, + []string{"type"}, + ), + } + + return metrics +} + +type Queue struct { + *queueMetrics + + name string + busyCh chan struct{} + waitingCh chan time.Time + timeout time.Duration +} + +// newQueue creates a new queue +// name specifies name used to label queue metrics. +// Don't call newQueue twice with the same name argument! +// limit specifies number of requests run concurrently +// queueLimit specifies maximum number of requests that can be queued +// timeout specifies the time limit of storing the request in the queue +// if the number of requests is above the limit +func newQueue(name string, limit, queueLimit uint, timeout time.Duration) *Queue { + queue := &Queue{ + name: name, + busyCh: make(chan struct{}, limit), + waitingCh: make(chan time.Time, limit+queueLimit), + timeout: timeout, + } + + queue.queueMetrics = newQueueMetrics(name, timeout) + queue.queueingLimit.Set(float64(limit)) + queue.queueingQueueLimit.Set(float64(queueLimit)) + queue.queueingQueueTimeout.Set(timeout.Seconds()) + + return queue +} + +// Acquire takes one slot from the Queue +// and returns when a request should be processed +// it allows up to (limit) of requests running at a time +// it allows to queue up to (queue-limit) requests +func (s *Queue) Acquire() (err error) { + // push item to a queue to claim your own slot (non-blocking) + select { + case s.waitingCh <- time.Now(): + s.queueingWaiting.Inc() + break + default: + s.queueingErrors.WithLabelValues("too_many_requests").Inc() + return ErrTooManyRequests + } + + defer func() { + if err != nil { + waitStarted := <-s.waitingCh + s.queueingWaiting.Dec() + s.queueingWaitingTime.Observe(float64(time.Since(waitStarted).Seconds())) + } + }() + + // fast path: push item to current processed items (non-blocking) + select { + case s.busyCh <- struct{}{}: + s.queueingBusy.Inc() + return nil + default: + break + } + + timer := time.NewTimer(s.timeout) + defer timer.Stop() + + // push item to current processed items (blocking) + select { + case s.busyCh <- struct{}{}: + s.queueingBusy.Inc() + return nil + + case <-timer.C: + s.queueingErrors.WithLabelValues("queueing_timedout").Inc() + return ErrQueueingTimedout + } +} + +// Release marks the finish of processing of requests +// It triggers next request to be processed if it's in queue +func (s *Queue) Release() { + // dequeue from queue to allow next request to be processed + waitStarted := <-s.waitingCh + s.queueingWaiting.Dec() + s.queueingWaitingTime.Observe(float64(time.Since(waitStarted).Seconds())) + + <-s.busyCh + s.queueingBusy.Dec() +} diff --git a/workhorse/internal/queueing/queue_test.go b/workhorse/internal/queueing/queue_test.go new file mode 100644 index 00000000000..7f5ed9154f4 --- /dev/null +++ b/workhorse/internal/queueing/queue_test.go @@ -0,0 +1,62 @@ +package queueing + +import ( + "testing" + "time" +) + +func TestNormalQueueing(t *testing.T) { + q := newQueue("queue 1", 2, 1, time.Microsecond) + err1 := q.Acquire() + if err1 != nil { + t.Fatal("we should acquire a new slot") + } + + err2 := q.Acquire() + if err2 != nil { + t.Fatal("we should acquire a new slot") + } + + err3 := q.Acquire() + if err3 != ErrQueueingTimedout { + t.Fatal("we should timeout") + } + + q.Release() + + err4 := q.Acquire() + if err4 != nil { + t.Fatal("we should acquire a new slot") + } +} + +func TestQueueLimit(t *testing.T) { + q := newQueue("queue 2", 1, 0, 
time.Microsecond) + err1 := q.Acquire() + if err1 != nil { + t.Fatal("we should acquire a new slot") + } + + err2 := q.Acquire() + if err2 != ErrTooManyRequests { + t.Fatal("we should fail because of not enough slots in queue") + } +} + +func TestQueueProcessing(t *testing.T) { + q := newQueue("queue 3", 1, 1, time.Second) + err1 := q.Acquire() + if err1 != nil { + t.Fatal("we should acquire a new slot") + } + + go func() { + time.Sleep(50 * time.Microsecond) + q.Release() + }() + + err2 := q.Acquire() + if err2 != nil { + t.Fatal("we should acquire slot after the previous one finished") + } +} diff --git a/workhorse/internal/queueing/requests.go b/workhorse/internal/queueing/requests.go new file mode 100644 index 00000000000..409a7656fa4 --- /dev/null +++ b/workhorse/internal/queueing/requests.go @@ -0,0 +1,51 @@ +package queueing + +import ( + "net/http" + "time" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +const ( + DefaultTimeout = 30 * time.Second + httpStatusTooManyRequests = 429 +) + +// QueueRequests creates a new request queue +// name specifies the name of queue, used to label Prometheus metrics +// Don't call QueueRequests twice with the same name argument! +// h specifies a http.Handler which will handle the queue requests +// limit specifies number of requests run concurrently +// queueLimit specifies maximum number of requests that can be queued +// queueTimeout specifies the time limit of storing the request in the queue +func QueueRequests(name string, h http.Handler, limit, queueLimit uint, queueTimeout time.Duration) http.Handler { + if limit == 0 { + return h + } + if queueTimeout == 0 { + queueTimeout = DefaultTimeout + } + + queue := newQueue(name, limit, queueLimit, queueTimeout) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := queue.Acquire() + + switch err { + case nil: + defer queue.Release() + h.ServeHTTP(w, r) + + case ErrTooManyRequests: + http.Error(w, "Too Many Requests", httpStatusTooManyRequests) + + case ErrQueueingTimedout: + http.Error(w, "Service Unavailable", http.StatusServiceUnavailable) + + default: + helper.Fail500(w, r, err) + } + + }) +} diff --git a/workhorse/internal/queueing/requests_test.go b/workhorse/internal/queueing/requests_test.go new file mode 100644 index 00000000000..f1c52e5c6f5 --- /dev/null +++ b/workhorse/internal/queueing/requests_test.go @@ -0,0 +1,76 @@ +package queueing + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +var httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "OK") +}) + +func pausedHttpHandler(pauseCh chan struct{}) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-pauseCh + fmt.Fprintln(w, "OK") + }) +} + +func TestNormalRequestProcessing(t *testing.T) { + w := httptest.NewRecorder() + h := QueueRequests("Normal request processing", httpHandler, 1, 1, time.Second) + h.ServeHTTP(w, nil) + if w.Code != 200 { + t.Fatal("QueueRequests should process request") + } +} + +// testSlowRequestProcessing creates a new queue, +// then it runs a number of requests that are going through queue, +// we return the response of first finished request, +// where status of request can be 200, 429 or 503 +func testSlowRequestProcessing(name string, count int, limit, queueLimit uint, queueTimeout time.Duration) *httptest.ResponseRecorder { + pauseCh := make(chan struct{}) + defer close(pauseCh) + + handler := QueueRequests("Slow request processing: "+name, 
pausedHttpHandler(pauseCh), limit, queueLimit, queueTimeout) + + respCh := make(chan *httptest.ResponseRecorder, count) + + // queue requests to use up the queue + for i := 0; i < count; i++ { + go func() { + w := httptest.NewRecorder() + handler.ServeHTTP(w, nil) + respCh <- w + }() + } + + // dequeue first request + return <-respCh +} + +// TestQueueingTimeout performs 2 requests +// the queue limit and length is 1, +// the second request gets timed-out +func TestQueueingTimeout(t *testing.T) { + w := testSlowRequestProcessing("timeout", 2, 1, 1, time.Microsecond) + + if w.Code != 503 { + t.Fatal("QueueRequests should timeout queued request") + } +} + +// TestQueueingTooManyRequests performs 3 requests +// the queue limit and length is 1, +// so the third request has to be rejected with 429 +func TestQueueingTooManyRequests(t *testing.T) { + w := testSlowRequestProcessing("too many requests", 3, 1, 1, time.Minute) + + if w.Code != 429 { + t.Fatal("QueueRequests should return immediately and return too many requests") + } +} diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go new file mode 100644 index 00000000000..96e33a64b85 --- /dev/null +++ b/workhorse/internal/redis/keywatcher.go @@ -0,0 +1,198 @@ +package redis + +import ( + "fmt" + "strings" + "sync" + "time" + + "github.com/gomodule/redigo/redis" + "github.com/jpillora/backoff" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +var ( + keyWatcher = make(map[string][]chan string) + keyWatcherMutex sync.Mutex + redisReconnectTimeout = backoff.Backoff{ + //These are the defaults + Min: 100 * time.Millisecond, + Max: 60 * time.Second, + Factor: 2, + Jitter: true, + } + keyWatchers = promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "gitlab_workhorse_keywatcher_keywatchers", + Help: "The number of keys that is being watched by gitlab-workhorse", + }, + ) + totalMessages = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_keywatcher_total_messages", + Help: "How many messages gitlab-workhorse has received in total on pubsub.", + }, + ) +) + +const ( + keySubChannel = "workhorse:notifications" +) + +// KeyChan holds a key and a channel +type KeyChan struct { + Key string + Chan chan string +} + +func processInner(conn redis.Conn) error { + defer conn.Close() + psc := redis.PubSubConn{Conn: conn} + if err := psc.Subscribe(keySubChannel); err != nil { + return err + } + defer psc.Unsubscribe(keySubChannel) + + for { + switch v := psc.Receive().(type) { + case redis.Message: + totalMessages.Inc() + dataStr := string(v.Data) + msg := strings.SplitN(dataStr, "=", 2) + if len(msg) != 2 { + helper.LogError(nil, fmt.Errorf("keywatcher: invalid notification: %q", dataStr)) + continue + } + key, value := msg[0], msg[1] + notifyChanWatchers(key, value) + case error: + helper.LogError(nil, fmt.Errorf("keywatcher: pubsub receive: %v", v)) + // Intermittent error, return nil so that it doesn't wait before reconnect + return nil + } + } +} + +func dialPubSub(dialer redisDialerFunc) (redis.Conn, error) { + conn, err := dialer() + if err != nil { + return nil, err + } + + // Make sure Redis is actually connected + conn.Do("PING") + if err := conn.Err(); err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} + +// Process redis subscriptions +// +// NOTE: There Can Only Be One! 
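// A single Process loop is meant to serve all watchers, so it is typically
// started once at boot in its own goroutine, e.g.:
//
//	go redis.Process()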
+func Process() { + log.Info("keywatcher: starting process loop") + for { + conn, err := dialPubSub(workerDialFunc) + if err != nil { + helper.LogError(nil, fmt.Errorf("keywatcher: %v", err)) + time.Sleep(redisReconnectTimeout.Duration()) + continue + } + redisReconnectTimeout.Reset() + + if err = processInner(conn); err != nil { + helper.LogError(nil, fmt.Errorf("keywatcher: process loop: %v", err)) + } + } +} + +func notifyChanWatchers(key, value string) { + keyWatcherMutex.Lock() + defer keyWatcherMutex.Unlock() + if chanList, ok := keyWatcher[key]; ok { + for _, c := range chanList { + c <- value + keyWatchers.Dec() + } + delete(keyWatcher, key) + } +} + +func addKeyChan(kc *KeyChan) { + keyWatcherMutex.Lock() + defer keyWatcherMutex.Unlock() + keyWatcher[kc.Key] = append(keyWatcher[kc.Key], kc.Chan) + keyWatchers.Inc() +} + +func delKeyChan(kc *KeyChan) { + keyWatcherMutex.Lock() + defer keyWatcherMutex.Unlock() + if chans, ok := keyWatcher[kc.Key]; ok { + for i, c := range chans { + if kc.Chan == c { + keyWatcher[kc.Key] = append(chans[:i], chans[i+1:]...) + keyWatchers.Dec() + break + } + } + if len(keyWatcher[kc.Key]) == 0 { + delete(keyWatcher, kc.Key) + } + } +} + +// WatchKeyStatus is used to tell how WatchKey returned +type WatchKeyStatus int + +const ( + // WatchKeyStatusTimeout is returned when the watch timeout provided by the caller was exceeded + WatchKeyStatusTimeout WatchKeyStatus = iota + // WatchKeyStatusAlreadyChanged is returned when the value passed by the caller was never observed + WatchKeyStatusAlreadyChanged + // WatchKeyStatusSeenChange is returned when we have seen the value passed by the caller get changed + WatchKeyStatusSeenChange + // WatchKeyStatusNoChange is returned when the function had to return before observing a change. + // Also returned on errors. 
+ WatchKeyStatusNoChange +) + +// WatchKey waits for a key to be updated or expired +func WatchKey(key, value string, timeout time.Duration) (WatchKeyStatus, error) { + kw := &KeyChan{ + Key: key, + Chan: make(chan string, 1), + } + + addKeyChan(kw) + defer delKeyChan(kw) + + currentValue, err := GetString(key) + if err != nil { + return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err) + } + if currentValue != value { + return WatchKeyStatusAlreadyChanged, nil + } + + select { + case currentValue := <-kw.Chan: + if currentValue == "" { + return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed") + } + if currentValue == value { + return WatchKeyStatusNoChange, nil + } + return WatchKeyStatusSeenChange, nil + + case <-time.After(timeout): + return WatchKeyStatusTimeout, nil + } +} diff --git a/workhorse/internal/redis/keywatcher_test.go b/workhorse/internal/redis/keywatcher_test.go new file mode 100644 index 00000000000..f1ee77e2194 --- /dev/null +++ b/workhorse/internal/redis/keywatcher_test.go @@ -0,0 +1,162 @@ +package redis + +import ( + "sync" + "testing" + "time" + + "github.com/rafaeljusto/redigomock" + "github.com/stretchr/testify/require" +) + +const ( + runnerKey = "runner:build_queue:10" +) + +func createSubscriptionMessage(key, data string) []interface{} { + return []interface{}{ + []byte("message"), + []byte(key), + []byte(data), + } +} + +func createSubscribeMessage(key string) []interface{} { + return []interface{}{ + []byte("subscribe"), + []byte(key), + []byte("1"), + } +} +func createUnsubscribeMessage(key string) []interface{} { + return []interface{}{ + []byte("unsubscribe"), + []byte(key), + []byte("1"), + } +} + +func countWatchers(key string) int { + keyWatcherMutex.Lock() + defer keyWatcherMutex.Unlock() + return len(keyWatcher[key]) +} + +func deleteWatchers(key string) { + keyWatcherMutex.Lock() + defer keyWatcherMutex.Unlock() + delete(keyWatcher, key) +} + +// Forces a run of the `Process` loop against a mock PubSubConn. 
+func processMessages(numWatchers int, value string) { + psc := redigomock.NewConn() + + // Setup the initial subscription message + psc.Command("SUBSCRIBE", keySubChannel).Expect(createSubscribeMessage(keySubChannel)) + psc.Command("UNSUBSCRIBE", keySubChannel).Expect(createUnsubscribeMessage(keySubChannel)) + psc.AddSubscriptionMessage(createSubscriptionMessage(keySubChannel, runnerKey+"="+value)) + + // Wait for all the `WatchKey` calls to be registered + for countWatchers(runnerKey) != numWatchers { + time.Sleep(time.Millisecond) + } + + processInner(psc) +} + +func TestWatchKeySeenChange(t *testing.T) { + conn, td := setupMockPool() + defer td() + + conn.Command("GET", runnerKey).Expect("something") + + wg := &sync.WaitGroup{} + wg.Add(1) + + go func() { + val, err := WatchKey(runnerKey, "something", time.Second) + require.NoError(t, err, "Expected no error") + require.Equal(t, WatchKeyStatusSeenChange, val, "Expected value to change") + wg.Done() + }() + + processMessages(1, "somethingelse") + wg.Wait() +} + +func TestWatchKeyNoChange(t *testing.T) { + conn, td := setupMockPool() + defer td() + + conn.Command("GET", runnerKey).Expect("something") + + wg := &sync.WaitGroup{} + wg.Add(1) + + go func() { + val, err := WatchKey(runnerKey, "something", time.Second) + require.NoError(t, err, "Expected no error") + require.Equal(t, WatchKeyStatusNoChange, val, "Expected notification without change to value") + wg.Done() + }() + + processMessages(1, "something") + wg.Wait() +} + +func TestWatchKeyTimeout(t *testing.T) { + conn, td := setupMockPool() + defer td() + + conn.Command("GET", runnerKey).Expect("something") + + val, err := WatchKey(runnerKey, "something", time.Millisecond) + require.NoError(t, err, "Expected no error") + require.Equal(t, WatchKeyStatusTimeout, val, "Expected value to not change") + + // Clean up watchers since Process isn't doing that for us (not running) + deleteWatchers(runnerKey) +} + +func TestWatchKeyAlreadyChanged(t *testing.T) { + conn, td := setupMockPool() + defer td() + + conn.Command("GET", runnerKey).Expect("somethingelse") + + val, err := WatchKey(runnerKey, "something", time.Second) + require.NoError(t, err, "Expected no error") + require.Equal(t, WatchKeyStatusAlreadyChanged, val, "Expected value to have already changed") + + // Clean up watchers since Process isn't doing that for us (not running) + deleteWatchers(runnerKey) +} + +func TestWatchKeyMassivelyParallel(t *testing.T) { + runTimes := 100 // 100 parallel watchers + + conn, td := setupMockPool() + defer td() + + wg := &sync.WaitGroup{} + wg.Add(runTimes) + + getCmd := conn.Command("GET", runnerKey) + + for i := 0; i < runTimes; i++ { + getCmd = getCmd.Expect("something") + } + + for i := 0; i < runTimes; i++ { + go func() { + val, err := WatchKey(runnerKey, "something", time.Second) + require.NoError(t, err, "Expected no error") + require.Equal(t, WatchKeyStatusSeenChange, val, "Expected value to change") + wg.Done() + }() + } + + processMessages(runTimes, "somethingelse") + wg.Wait() +} diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go new file mode 100644 index 00000000000..0029a2a9e2b --- /dev/null +++ b/workhorse/internal/redis/redis.go @@ -0,0 +1,295 @@ +package redis + +import ( + "errors" + "fmt" + "net" + "net/url" + "time" + + "github.com/FZambia/sentinel" + "github.com/gomodule/redigo/redis" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "gitlab.com/gitlab-org/labkit/log" + + 
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +var ( + pool *redis.Pool + sntnl *sentinel.Sentinel +) + +const ( + // Max Idle Connections in the pool. + defaultMaxIdle = 1 + // Max Active Connections in the pool. + defaultMaxActive = 1 + // Timeout for Read operations on the pool. 1 second is technically overkill, + // it's just for sanity. + defaultReadTimeout = 1 * time.Second + // Timeout for Write operations on the pool. 1 second is technically overkill, + // it's just for sanity. + defaultWriteTimeout = 1 * time.Second + // Timeout before killing Idle connections in the pool. 3 minutes seemed good. + // If you _actually_ hit this timeout often, you should consider turning of + // redis-support since it's not necessary at that point... + defaultIdleTimeout = 3 * time.Minute + // KeepAlivePeriod is to keep a TCP connection open for an extended period of + // time without being killed. This is used both in the pool, and in the + // worker-connection. + // See https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more + // information. + defaultKeepAlivePeriod = 5 * time.Minute +) + +var ( + totalConnections = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_redis_total_connections", + Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process", + }, + ) + + errorCounter = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_redis_errors", + Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)", + }, + []string{"type", "dst"}, + ) +) + +func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel { + if len(urls) == 0 { + return nil + } + var addrs []string + for _, url := range urls { + h := url.URL.String() + log.WithFields(log.Fields{ + "scheme": url.URL.Scheme, + "host": url.URL.Host, + }).Printf("redis: using sentinel") + addrs = append(addrs, h) + } + return &sentinel.Sentinel{ + Addrs: addrs, + MasterName: master, + Dial: func(addr string) (redis.Conn, error) { + // This timeout is recommended for Sentinel-support according to the guidelines. + // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel + // For every address it should try to connect to the Sentinel, + // using a short timeout (in the order of a few hundreds of milliseconds). + timeout := 500 * time.Millisecond + url := helper.URLMustParse(addr) + + var c redis.Conn + var err error + options := []redis.DialOption{ + redis.DialConnectTimeout(timeout), + redis.DialReadTimeout(timeout), + redis.DialWriteTimeout(timeout), + } + + if url.Scheme == "redis" || url.Scheme == "rediss" { + c, err = redis.DialURL(addr, options...) + } else { + c, err = redis.Dial("tcp", url.Host, options...) 
+ } + + if err != nil { + errorCounter.WithLabelValues("dial", "sentinel").Inc() + return nil, err + } + return c, nil + }, + } +} + +var poolDialFunc func() (redis.Conn, error) +var workerDialFunc func() (redis.Conn, error) + +func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption { + readTimeout := defaultReadTimeout + writeTimeout := defaultWriteTimeout + + if cfg != nil { + if cfg.ReadTimeout != nil { + readTimeout = cfg.ReadTimeout.Duration + } + + if cfg.WriteTimeout != nil { + writeTimeout = cfg.WriteTimeout.Duration + } + } + return []redis.DialOption{ + redis.DialReadTimeout(readTimeout), + redis.DialWriteTimeout(writeTimeout), + } +} + +func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption { + var dopts []redis.DialOption + if setTimeouts { + dopts = timeoutDialOptions(cfg) + } + if cfg == nil { + return dopts + } + if cfg.Password != "" { + dopts = append(dopts, redis.DialPassword(cfg.Password)) + } + if cfg.DB != nil { + dopts = append(dopts, redis.DialDatabase(*cfg.DB)) + } + return dopts +} + +func keepAliveDialer(timeout time.Duration) func(string, string) (net.Conn, error) { + return func(network, address string) (net.Conn, error) { + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + tc, err := net.DialTCP(network, nil, addr) + if err != nil { + return nil, err + } + if err := tc.SetKeepAlive(true); err != nil { + return nil, err + } + if err := tc.SetKeepAlivePeriod(timeout); err != nil { + return nil, err + } + return tc, nil + } +} + +type redisDialerFunc func() (redis.Conn, error) + +func sentinelDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration) redisDialerFunc { + return func() (redis.Conn, error) { + address, err := sntnl.MasterAddr() + if err != nil { + errorCounter.WithLabelValues("master", "sentinel").Inc() + return nil, err + } + dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod))) + return redisDial("tcp", address, dopts...) + } +} + +func defaultDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration, url url.URL) redisDialerFunc { + return func() (redis.Conn, error) { + if url.Scheme == "unix" { + return redisDial(url.Scheme, url.Path, dopts...) + } + + dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod))) + + // redis.DialURL only works with redis[s]:// URLs + if url.Scheme == "redis" || url.Scheme == "rediss" { + return redisURLDial(url, dopts...) + } + + return redisDial(url.Scheme, url.Host, dopts...) + } +} + +func redisURLDial(url url.URL, options ...redis.DialOption) (redis.Conn, error) { + log.WithFields(log.Fields{ + "scheme": url.Scheme, + "address": url.Host, + }).Printf("redis: dialing") + + return redis.DialURL(url.String(), options...) +} + +func redisDial(network, address string, options ...redis.DialOption) (redis.Conn, error) { + log.WithFields(log.Fields{ + "network": network, + "address": address, + }).Printf("redis: dialing") + + return redis.Dial(network, address, options...) +} + +func countDialer(dialer redisDialerFunc) redisDialerFunc { + return func() (redis.Conn, error) { + c, err := dialer() + if err != nil { + errorCounter.WithLabelValues("dial", "redis").Inc() + } else { + totalConnections.Inc() + } + return c, err + } +} + +// DefaultDialFunc should always used. Only exception is for unit-tests. 
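// In production code this is the dial function handed to Configure at startup,
// while tests substitute their own (see setupMockPool in redis_test.go). A
// sketch, assuming the parsed config exposes its Redis section as cfg.Redis:
//
//	redis.Configure(cfg.Redis, redis.DefaultDialFunc)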
+func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) { + keepAlivePeriod := defaultKeepAlivePeriod + if cfg.KeepAlivePeriod != nil { + keepAlivePeriod = cfg.KeepAlivePeriod.Duration + } + dopts := dialOptionsBuilder(cfg, setReadTimeout) + if sntnl != nil { + return countDialer(sentinelDialer(dopts, keepAlivePeriod)) + } + return countDialer(defaultDialer(dopts, keepAlivePeriod, cfg.URL.URL)) +} + +// Configure redis-connection +func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) { + if cfg == nil { + return + } + maxIdle := defaultMaxIdle + if cfg.MaxIdle != nil { + maxIdle = *cfg.MaxIdle + } + maxActive := defaultMaxActive + if cfg.MaxActive != nil { + maxActive = *cfg.MaxActive + } + sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel) + workerDialFunc = dialFunc(cfg, false) + poolDialFunc = dialFunc(cfg, true) + pool = &redis.Pool{ + MaxIdle: maxIdle, // Keep at most X hot connections + MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited + IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed + Dial: poolDialFunc, + Wait: true, + } + if sntnl != nil { + pool.TestOnBorrow = func(c redis.Conn, t time.Time) error { + if !sentinel.TestRole(c, "master") { + return errors.New("role check failed") + } + return nil + } + } +} + +// Get a connection for the Redis-pool +func Get() redis.Conn { + if pool != nil { + return pool.Get() + } + return nil +} + +// GetString fetches the value of a key in Redis as a string +func GetString(key string) (string, error) { + conn := Get() + if conn == nil { + return "", fmt.Errorf("redis: could not get connection from pool") + } + defer conn.Close() + + return redis.String(conn.Do("GET", key)) +} diff --git a/workhorse/internal/redis/redis_test.go b/workhorse/internal/redis/redis_test.go new file mode 100644 index 00000000000..f4b4120517d --- /dev/null +++ b/workhorse/internal/redis/redis_test.go @@ -0,0 +1,234 @@ +package redis + +import ( + "net" + "testing" + "time" + + "github.com/gomodule/redigo/redis" + "github.com/rafaeljusto/redigomock" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +func mockRedisServer(t *testing.T, connectReceived *bool) string { + ln, err := net.Listen("tcp", "127.0.0.1:0") + + require.Nil(t, err) + + go func() { + defer ln.Close() + conn, err := ln.Accept() + require.Nil(t, err) + *connectReceived = true + conn.Write([]byte("OK\n")) + }() + + return ln.Addr().String() +} + +// Setup a MockPool for Redis +// +// Returns a teardown-function and the mock-connection +func setupMockPool() (*redigomock.Conn, func()) { + conn := redigomock.NewConn() + cfg := &config.RedisConfig{URL: config.TomlURL{}} + Configure(cfg, func(_ *config.RedisConfig, _ bool) func() (redis.Conn, error) { + return func() (redis.Conn, error) { + return conn, nil + } + }) + return conn, func() { + pool = nil + } +} + +func TestDefaultDialFunc(t *testing.T) { + testCases := []struct { + scheme string + }{ + { + scheme: "tcp", + }, + { + scheme: "redis", + }, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + connectReceived := false + a := mockRedisServer(t, &connectReceived) + + parsedURL := helper.URLMustParse(tc.scheme + "://" + a) + cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}} + + dialer := DefaultDialFunc(cfg, true) + conn, err := dialer() + + 
require.Nil(t, err) + conn.Receive() + + require.True(t, connectReceived) + }) + } +} + +func TestConfigureNoConfig(t *testing.T) { + pool = nil + Configure(nil, nil) + require.Nil(t, pool, "Pool should be nil") +} + +func TestConfigureMinimalConfig(t *testing.T) { + cfg := &config.RedisConfig{URL: config.TomlURL{}, Password: ""} + Configure(cfg, DefaultDialFunc) + + require.NotNil(t, pool, "Pool should not be nil") + require.Equal(t, 1, pool.MaxIdle) + require.Equal(t, 1, pool.MaxActive) + require.Equal(t, 3*time.Minute, pool.IdleTimeout) + + pool = nil +} + +func TestConfigureFullConfig(t *testing.T) { + i, a := 4, 10 + r := config.TomlDuration{Duration: 3} + cfg := &config.RedisConfig{ + URL: config.TomlURL{}, + Password: "", + MaxIdle: &i, + MaxActive: &a, + ReadTimeout: &r, + } + Configure(cfg, DefaultDialFunc) + + require.NotNil(t, pool, "Pool should not be nil") + require.Equal(t, i, pool.MaxIdle) + require.Equal(t, a, pool.MaxActive) + require.Equal(t, 3*time.Minute, pool.IdleTimeout) + + pool = nil +} + +func TestGetConnFail(t *testing.T) { + conn := Get() + require.Nil(t, conn, "Expected `conn` to be nil") +} + +func TestGetConnPass(t *testing.T) { + _, teardown := setupMockPool() + defer teardown() + conn := Get() + require.NotNil(t, conn, "Expected `conn` to be non-nil") +} + +func TestGetStringPass(t *testing.T) { + conn, teardown := setupMockPool() + defer teardown() + conn.Command("GET", "foobar").Expect("baz") + str, err := GetString("foobar") + + require.NoError(t, err, "Expected `err` to be nil") + var value string + require.IsType(t, value, str, "Expected value to be a string") + require.Equal(t, "baz", str, "Expected it to be equal") +} + +func TestGetStringFail(t *testing.T) { + _, err := GetString("foobar") + require.Error(t, err, "Expected error when not connected to redis") +} + +func TestSentinelConnNoSentinel(t *testing.T) { + s := sentinelConn("", []config.TomlURL{}) + + require.Nil(t, s, "Sentinel without urls should return nil") +} + +func TestSentinelConnDialURL(t *testing.T) { + testCases := []struct { + scheme string + }{ + { + scheme: "tcp", + }, + { + scheme: "redis", + }, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + connectReceived := false + a := mockRedisServer(t, &connectReceived) + + addrs := []string{tc.scheme + "://" + a} + var sentinelUrls []config.TomlURL + + for _, a := range addrs { + parsedURL := helper.URLMustParse(a) + sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) + } + + s := sentinelConn("foobar", sentinelUrls) + require.Equal(t, len(addrs), len(s.Addrs)) + + for i := range addrs { + require.Equal(t, addrs[i], s.Addrs[i]) + } + + conn, err := s.Dial(s.Addrs[0]) + + require.Nil(t, err) + conn.Receive() + + require.True(t, connectReceived) + }) + } +} + +func TestSentinelConnTwoURLs(t *testing.T) { + addrs := []string{"tcp://10.0.0.1:12345", "tcp://10.0.0.2:12345"} + var sentinelUrls []config.TomlURL + + for _, a := range addrs { + parsedURL := helper.URLMustParse(a) + sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) + } + + s := sentinelConn("foobar", sentinelUrls) + require.Equal(t, len(addrs), len(s.Addrs)) + + for i := range addrs { + require.Equal(t, addrs[i], s.Addrs[i]) + } +} + +func TestDialOptionsBuildersPassword(t *testing.T) { + dopts := dialOptionsBuilder(&config.RedisConfig{Password: "foo"}, false) + require.Equal(t, 1, len(dopts)) +} + +func TestDialOptionsBuildersSetTimeouts(t *testing.T) { + dopts := dialOptionsBuilder(nil, true) + require.Equal(t, 
2, len(dopts)) +} + +func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) { + cfg := &config.RedisConfig{ + ReadTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)}, + WriteTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)}, + } + dopts := dialOptionsBuilder(cfg, true) + require.Equal(t, 2, len(dopts)) +} + +func TestDialOptionsBuildersSelectDB(t *testing.T) { + db := 3 + dopts := dialOptionsBuilder(&config.RedisConfig{DB: &db}, false) + require.Equal(t, 1, len(dopts)) +} diff --git a/workhorse/internal/secret/jwt.go b/workhorse/internal/secret/jwt.go new file mode 100644 index 00000000000..04335e58f76 --- /dev/null +++ b/workhorse/internal/secret/jwt.go @@ -0,0 +1,25 @@ +package secret + +import ( + "fmt" + + "github.com/dgrijalva/jwt-go" +) + +var ( + DefaultClaims = jwt.StandardClaims{Issuer: "gitlab-workhorse"} +) + +func JWTTokenString(claims jwt.Claims) (string, error) { + secretBytes, err := Bytes() + if err != nil { + return "", fmt.Errorf("secret.JWTTokenString: %v", err) + } + + tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(secretBytes) + if err != nil { + return "", fmt.Errorf("secret.JWTTokenString: sign JWT: %v", err) + } + + return tokenString, nil +} diff --git a/workhorse/internal/secret/roundtripper.go b/workhorse/internal/secret/roundtripper.go new file mode 100644 index 00000000000..50bf7fff5b8 --- /dev/null +++ b/workhorse/internal/secret/roundtripper.go @@ -0,0 +1,35 @@ +package secret + +import ( + "net/http" +) + +const ( + // This header carries the JWT token for gitlab-rails + RequestHeader = "Gitlab-Workhorse-Api-Request" +) + +type roundTripper struct { + next http.RoundTripper + version string +} + +// NewRoundTripper creates a RoundTripper that adds the JWT token header to a +// request. This is used to verify that a request came from workhorse +func NewRoundTripper(next http.RoundTripper, version string) http.RoundTripper { + return &roundTripper{next: next, version: version} +} + +func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + tokenString, err := JWTTokenString(DefaultClaims) + if err != nil { + return nil, err + } + + // Set a custom header for the request. This can be used in some + // configurations (Passenger) to solve auth request routing problems. + req.Header.Set("Gitlab-Workhorse", r.version) + req.Header.Set(RequestHeader, tokenString) + + return r.next.RoundTrip(req) +} diff --git a/workhorse/internal/secret/secret.go b/workhorse/internal/secret/secret.go new file mode 100644 index 00000000000..e8c7c25393c --- /dev/null +++ b/workhorse/internal/secret/secret.go @@ -0,0 +1,77 @@ +package secret + +import ( + "encoding/base64" + "fmt" + "io/ioutil" + "sync" +) + +const numSecretBytes = 32 + +type sec struct { + path string + bytes []byte + sync.RWMutex +} + +var ( + theSecret = &sec{} +) + +func SetPath(path string) { + theSecret.Lock() + defer theSecret.Unlock() + theSecret.path = path + theSecret.bytes = nil +} + +// Lazy access to the HMAC secret key. We must be lazy because if the key +// is not already there, it will be generated by gitlab-rails, and +// gitlab-rails is slow. 
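// The intended flow, sketched: SetPath is called once at startup with the
// location of the base64-encoded shared secret, and the key is then read
// lazily whenever a token is needed (the path below is purely illustrative):
//
//	secret.SetPath("/path/to/gitlab_workhorse_secret")
//
//	// later, e.g. when signing an internal API request (see roundtripper.go):
//	token, err := secret.JWTTokenString(secret.DefaultClaims)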
+func Bytes() ([]byte, error) { + if bytes := getBytes(); bytes != nil { + return copyBytes(bytes), nil + } + + return setBytes() +} + +func getBytes() []byte { + theSecret.RLock() + defer theSecret.RUnlock() + return theSecret.bytes +} + +func copyBytes(bytes []byte) []byte { + out := make([]byte, len(bytes)) + copy(out, bytes) + return out +} + +func setBytes() ([]byte, error) { + theSecret.Lock() + defer theSecret.Unlock() + + if theSecret.bytes != nil { + return theSecret.bytes, nil + } + + base64Bytes, err := ioutil.ReadFile(theSecret.path) + if err != nil { + return nil, fmt.Errorf("secret.setBytes: read %q: %v", theSecret.path, err) + } + + secretBytes := make([]byte, base64.StdEncoding.DecodedLen(len(base64Bytes))) + n, err := base64.StdEncoding.Decode(secretBytes, base64Bytes) + if err != nil { + return nil, fmt.Errorf("secret.setBytes: decode secret: %v", err) + } + + if n != numSecretBytes { + return nil, fmt.Errorf("secret.setBytes: expected %d secretBytes in %s, found %d", numSecretBytes, theSecret.path, n) + } + + theSecret.bytes = secretBytes + return copyBytes(theSecret.bytes), nil +} diff --git a/workhorse/internal/senddata/contentprocessor/contentprocessor.go b/workhorse/internal/senddata/contentprocessor/contentprocessor.go new file mode 100644 index 00000000000..a5cc0fee013 --- /dev/null +++ b/workhorse/internal/senddata/contentprocessor/contentprocessor.go @@ -0,0 +1,126 @@ +package contentprocessor + +import ( + "bytes" + "io" + "net/http" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers" +) + +type contentDisposition struct { + rw http.ResponseWriter + buf *bytes.Buffer + wroteHeader bool + flushed bool + active bool + removedResponseHeaders bool + status int + sentStatus bool +} + +// SetContentHeaders buffers the response if Gitlab-Workhorse-Detect-Content-Type +// header is found and set the proper content headers based on the current +// value of content type and disposition +func SetContentHeaders(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cd := &contentDisposition{ + rw: w, + buf: &bytes.Buffer{}, + status: http.StatusOK, + } + + defer cd.flush() + + h.ServeHTTP(cd, r) + }) +} + +func (cd *contentDisposition) Header() http.Header { + return cd.rw.Header() +} + +func (cd *contentDisposition) Write(data []byte) (int, error) { + // Normal write if we don't need to buffer + if cd.isUnbuffered() { + cd.WriteHeader(cd.status) + return cd.rw.Write(data) + } + + // Write the new data into the buffer + n, _ := cd.buf.Write(data) + + // If we have enough data to calculate the content headers then flush the Buffer + var err error + if cd.buf.Len() >= headers.MaxDetectSize { + err = cd.flushBuffer() + } + + return n, err +} + +func (cd *contentDisposition) flushBuffer() error { + if cd.isUnbuffered() { + return nil + } + + cd.flushed = true + + // If the buffer has any content then we calculate the content headers and + // write in the response + if cd.buf.Len() > 0 { + cd.writeContentHeaders() + cd.WriteHeader(cd.status) + _, err := io.Copy(cd.rw, cd.buf) + return err + } + + // If no content is present in the buffer we still need to send the headers + cd.WriteHeader(cd.status) + return nil +} + +func (cd *contentDisposition) writeContentHeaders() { + if cd.wroteHeader { + return + } + + cd.wroteHeader = true + contentType, contentDisposition := headers.SafeContentHeaders(cd.buf.Bytes(), cd.Header().Get(headers.ContentDispositionHeader)) + cd.Header().Set(headers.ContentTypeHeader, contentType) + 
cd.Header().Set(headers.ContentDispositionHeader, contentDisposition) +} + +func (cd *contentDisposition) WriteHeader(status int) { + if cd.sentStatus { + return + } + + cd.status = status + + if cd.isUnbuffered() { + cd.rw.WriteHeader(cd.status) + cd.sentStatus = true + } +} + +// If we find any response header, then we must calculate the content headers +// If we don't find any, the data is not buffered and it works as +// a usual ResponseWriter +func (cd *contentDisposition) isUnbuffered() bool { + if !cd.removedResponseHeaders { + if headers.IsDetectContentTypeHeaderPresent(cd.rw) { + cd.active = true + } + + cd.removedResponseHeaders = true + // We ensure to clear any response header from the response + headers.RemoveResponseHeaders(cd.rw) + } + + return cd.flushed || !cd.active +} + +func (cd *contentDisposition) flush() { + cd.flushBuffer() +} diff --git a/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go b/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go new file mode 100644 index 00000000000..5e3a74f04f9 --- /dev/null +++ b/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go @@ -0,0 +1,293 @@ +package contentprocessor + +import ( + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" + + "github.com/stretchr/testify/require" +) + +func TestFailSetContentTypeAndDisposition(t *testing.T) { + testCaseBody := "Hello world!" + + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err := io.WriteString(w, testCaseBody) + require.NoError(t, err) + }) + + resp := makeRequest(t, h, testCaseBody, "") + + require.Equal(t, "", resp.Header.Get(headers.ContentDispositionHeader)) + require.Equal(t, "", resp.Header.Get(headers.ContentTypeHeader)) +} + +func TestSuccessSetContentTypeAndDispositionFeatureEnabled(t *testing.T) { + testCaseBody := "Hello world!" 
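
A minimal sketch of how the SetContentHeaders middleware above is meant to be consumed (the handler body and request path are invented for illustration): the inner handler opts in by setting the Gitlab-Workhorse-Detect-Content-Type header, and the middleware buffers the body before deciding on safe Content-Type and Content-Disposition values, mirroring the plain-text case exercised by the tests.

    package main

    import (
        "fmt"
        "io"
        "net/http"
        "net/http/httptest"

        "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers"
        "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata/contentprocessor"
    )

    func main() {
        // Inner handler: opts in to content detection and writes a plain-text body.
        inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
            w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true")
            io.WriteString(w, "Hello world!")
        })

        rec := httptest.NewRecorder()
        req := httptest.NewRequest("GET", "/readme.txt", nil)
        contentprocessor.SetContentHeaders(inner).ServeHTTP(rec, req)

        // For plain text the detected headers are "text/plain; charset=utf-8" and "inline",
        // as the tests above assert.
        resp := rec.Result()
        fmt.Println(resp.Header.Get(headers.ContentTypeHeader))
        fmt.Println(resp.Header.Get(headers.ContentDispositionHeader))
    }
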
+ + resp := makeRequest(t, nil, testCaseBody, "") + + require.Equal(t, "inline", resp.Header.Get(headers.ContentDispositionHeader)) + require.Equal(t, "text/plain; charset=utf-8", resp.Header.Get(headers.ContentTypeHeader)) +} + +func TestSetProperContentTypeAndDisposition(t *testing.T) { + testCases := []struct { + desc string + contentType string + contentDisposition string + body string + }{ + { + desc: "text type", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: "Hello world!", + }, + { + desc: "HTML type", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: "<html><body>Hello world!</body></html>", + }, + { + desc: "Javascript type", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: "<script>alert(\"foo\")</script>", + }, + { + desc: "Image type", + contentType: "image/png", + contentDisposition: "inline", + body: testhelper.LoadFile(t, "testdata/image.png"), + }, + { + desc: "SVG type", + contentType: "image/svg+xml", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/image.svg"), + }, + { + desc: "Partial SVG type", + contentType: "image/svg+xml", + contentDisposition: "attachment", + body: "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" viewBox=\"0 0 330 82\"><title>SVG logo combined with the W3C logo, set horizontally</title><desc>The logo combines three entities displayed horizontall</desc><metadata>", + }, + { + desc: "Application type", + contentType: "application/pdf", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/file.pdf"), + }, + { + desc: "Application type pdf with inline disposition", + contentType: "application/pdf", + contentDisposition: "inline", + body: testhelper.LoadFile(t, "testdata/file.pdf"), + }, + { + desc: "Application executable type", + contentType: "application/octet-stream", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/file.swf"), + }, + { + desc: "Video type", + contentType: "video/mp4", + contentDisposition: "inline", + body: testhelper.LoadFile(t, "testdata/video.mp4"), + }, + { + desc: "Audio type", + contentType: "audio/mpeg", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/audio.mp3"), + }, + { + desc: "JSON type", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: "{ \"glossary\": { \"title\": \"example glossary\", \"GlossDiv\": { \"title\": \"S\" } } }", + }, + { + desc: "Forged file with png extension but SWF content", + contentType: "application/octet-stream", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/forgedfile.png"), + }, + { + desc: "BMPR file", + contentType: "application/octet-stream", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/file.bmpr"), + }, + { + desc: "STL file", + contentType: "application/octet-stream", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/file.stl"), + }, + { + desc: "RDoc file", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: testhelper.LoadFile(t, "testdata/file.rdoc"), + }, + { + desc: "IPYNB file", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: testhelper.LoadFile(t, "testdata/file.ipynb"), + }, + { + desc: "Sketch file", + contentType: "application/zip", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/file.sketch"), + }, + { + desc: 
"PDF file with non-ASCII characters in filename", + contentType: "application/pdf", + contentDisposition: `attachment; filename="file-ä.pdf"; filename*=UTF-8''file-%c3.pdf`, + body: testhelper.LoadFile(t, "testdata/file-ä.pdf"), + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + resp := makeRequest(t, nil, tc.body, tc.contentDisposition) + + require.Equal(t, tc.contentType, resp.Header.Get(headers.ContentTypeHeader)) + require.Equal(t, tc.contentDisposition, resp.Header.Get(headers.ContentDispositionHeader)) + }) + } +} + +func TestFailOverrideContentType(t *testing.T) { + testCase := struct { + contentType string + body string + }{ + contentType: "text/plain; charset=utf-8", + body: "<html><body>Hello world!</body></html>", + } + + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // We are pretending to be upstream or an inner layer of the ResponseWriter chain + w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") + w.Header().Set(headers.ContentTypeHeader, "text/html; charset=utf-8") + _, err := io.WriteString(w, testCase.body) + require.NoError(t, err) + }) + + resp := makeRequest(t, h, testCase.body, "") + + require.Equal(t, testCase.contentType, resp.Header.Get(headers.ContentTypeHeader)) +} + +func TestSuccessOverrideContentDispositionFromInlineToAttachment(t *testing.T) { + testCaseBody := "Hello world!" + + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // We are pretending to be upstream or an inner layer of the ResponseWriter chain + w.Header().Set(headers.ContentDispositionHeader, "attachment") + w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") + _, err := io.WriteString(w, testCaseBody) + require.NoError(t, err) + }) + + resp := makeRequest(t, h, testCaseBody, "") + + require.Equal(t, "attachment", resp.Header.Get(headers.ContentDispositionHeader)) +} + +func TestInlineContentDispositionForPdfFiles(t *testing.T) { + testCaseBody := testhelper.LoadFile(t, "testdata/file.pdf") + + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // We are pretending to be upstream or an inner layer of the ResponseWriter chain + w.Header().Set(headers.ContentDispositionHeader, "inline") + w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") + _, err := io.WriteString(w, testCaseBody) + require.NoError(t, err) + }) + + resp := makeRequest(t, h, testCaseBody, "") + + require.Equal(t, "inline", resp.Header.Get(headers.ContentDispositionHeader)) +} + +func TestFailOverrideContentDispositionFromAttachmentToInline(t *testing.T) { + testCaseBody := testhelper.LoadFile(t, "testdata/image.svg") + + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // We are pretending to be upstream or an inner layer of the ResponseWriter chain + w.Header().Set(headers.ContentDispositionHeader, "inline") + w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") + _, err := io.WriteString(w, testCaseBody) + require.NoError(t, err) + }) + + resp := makeRequest(t, h, testCaseBody, "") + + require.Equal(t, "attachment", resp.Header.Get(headers.ContentDispositionHeader)) +} + +func TestHeadersDelete(t *testing.T) { + for _, code := range []int{200, 400} { + recorder := httptest.NewRecorder() + rw := &contentDisposition{rw: recorder} + for _, name := range headers.ResponseHeaders { + rw.Header().Set(name, "foobar") + } + + rw.WriteHeader(code) + + for _, name := range headers.ResponseHeaders { + if header := recorder.Header().Get(name); 
header != "" { + t.Fatalf("HTTP %d response: expected header to be empty, found %q", code, name) + } + } + } +} + +func TestWriteHeadersCalledOnce(t *testing.T) { + recorder := httptest.NewRecorder() + rw := &contentDisposition{rw: recorder} + rw.WriteHeader(400) + require.Equal(t, 400, rw.status) + require.Equal(t, true, rw.sentStatus) + + rw.WriteHeader(200) + require.Equal(t, 400, rw.status) +} + +func makeRequest(t *testing.T, handler http.HandlerFunc, body string, disposition string) *http.Response { + if handler == nil { + handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // We are pretending to be upstream + w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") + w.Header().Set(headers.ContentDispositionHeader, disposition) + _, err := io.WriteString(w, body) + require.NoError(t, err) + }) + } + req, _ := http.NewRequest("GET", "/", nil) + + rw := httptest.NewRecorder() + SetContentHeaders(handler).ServeHTTP(rw, req) + + resp := rw.Result() + respBody, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + + require.Equal(t, body, string(respBody)) + + return resp +} diff --git a/workhorse/internal/senddata/injecter.go b/workhorse/internal/senddata/injecter.go new file mode 100644 index 00000000000..d5739d2a053 --- /dev/null +++ b/workhorse/internal/senddata/injecter.go @@ -0,0 +1,35 @@ +package senddata + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "strings" +) + +type Injecter interface { + Match(string) bool + Inject(http.ResponseWriter, *http.Request, string) + Name() string +} + +type Prefix string + +func (p Prefix) Match(s string) bool { + return strings.HasPrefix(s, string(p)) +} + +func (p Prefix) Unpack(result interface{}, sendData string) error { + jsonBytes, err := base64.URLEncoding.DecodeString(strings.TrimPrefix(sendData, string(p))) + if err != nil { + return err + } + if err := json.Unmarshal([]byte(jsonBytes), result); err != nil { + return err + } + return nil +} + +func (p Prefix) Name() string { + return strings.TrimSuffix(string(p), ":") +} diff --git a/workhorse/internal/senddata/senddata.go b/workhorse/internal/senddata/senddata.go new file mode 100644 index 00000000000..c287d2574fa --- /dev/null +++ b/workhorse/internal/senddata/senddata.go @@ -0,0 +1,105 @@ +package senddata + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata/contentprocessor" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + sendDataResponses = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_senddata_responses", + Help: "How many HTTP responses have been hijacked by a workhorse senddata injecter", + }, + []string{"injecter"}, + ) + sendDataResponseBytes = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_senddata_response_bytes", + Help: "How many bytes have been written by workhorse senddata response injecters", + }, + []string{"injecter"}, + ) +) + +type sendDataResponseWriter struct { + rw http.ResponseWriter + status int + hijacked bool + req *http.Request + injecters []Injecter +} + +func SendData(h http.Handler, injecters ...Injecter) http.Handler { + return contentprocessor.SetContentHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s := sendDataResponseWriter{ + rw: w, + req: r, + injecters: injecters, + } + 
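
To illustrate the Injecter contract defined in injecter.go above (a sketch only; the "greeting:" prefix and parameter struct are invented, not part of this change): a concrete injecter usually embeds senddata.Prefix, which supplies Match and Name, unpacks its base64url-encoded JSON parameters with Unpack, and then writes the hijacked response itself.

    package exampleinjecter

    import (
        "fmt"
        "net/http"

        "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
    )

    // SendGreeting is a hypothetical injecter; upstream would set
    // Gitlab-Workhorse-Send-Data to "greeting:" + base64url(JSON params).
    var SendGreeting = &greetingInjecter{"greeting:"}

    type greetingInjecter struct{ senddata.Prefix }

    type greetingParams struct {
        Name string
    }

    func (g *greetingInjecter) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
        var params greetingParams
        if err := g.Unpack(&params, sendData); err != nil {
            http.Error(w, "bad send-data parameters", http.StatusInternalServerError)
            return
        }
        fmt.Fprintf(w, "hello %s\n", params.Name)
    }

Such an injecter would be passed to SendData alongside the others when the handler chain is assembled, and tryInject below picks it by prefix match on the header value.
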
defer s.flush() + h.ServeHTTP(&s, r) + })) +} + +func (s *sendDataResponseWriter) Header() http.Header { + return s.rw.Header() +} + +func (s *sendDataResponseWriter) Write(data []byte) (int, error) { + if s.status == 0 { + s.WriteHeader(http.StatusOK) + } + if s.hijacked { + return len(data), nil + } + return s.rw.Write(data) +} + +func (s *sendDataResponseWriter) WriteHeader(status int) { + if s.status != 0 { + return + } + s.status = status + + if s.status == http.StatusOK && s.tryInject() { + return + } + + s.rw.WriteHeader(s.status) +} + +func (s *sendDataResponseWriter) tryInject() bool { + if s.hijacked { + return false + } + + header := s.Header().Get(headers.GitlabWorkhorseSendDataHeader) + if header == "" { + return false + } + + for _, injecter := range s.injecters { + if injecter.Match(header) { + s.hijacked = true + helper.DisableResponseBuffering(s.rw) + crw := helper.NewCountingResponseWriter(s.rw) + injecter.Inject(crw, s.req, header) + sendDataResponses.WithLabelValues(injecter.Name()).Inc() + sendDataResponseBytes.WithLabelValues(injecter.Name()).Add(float64(crw.Count())) + return true + } + } + + return false +} + +func (s *sendDataResponseWriter) flush() { + s.WriteHeader(http.StatusOK) +} diff --git a/workhorse/internal/senddata/writer_test.go b/workhorse/internal/senddata/writer_test.go new file mode 100644 index 00000000000..1262acd5472 --- /dev/null +++ b/workhorse/internal/senddata/writer_test.go @@ -0,0 +1,71 @@ +package senddata + +import ( + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers" +) + +func TestWriter(t *testing.T) { + upstreamResponse := "hello world" + + testCases := []struct { + desc string + headerValue string + out string + }{ + { + desc: "inject", + headerValue: testInjecterName + ":" + testInjecterName, + out: testInjecterData, + }, + { + desc: "pass", + headerValue: "", + out: upstreamResponse, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + recorder := httptest.NewRecorder() + rw := &sendDataResponseWriter{rw: recorder, injecters: []Injecter{&testInjecter{}}} + + rw.Header().Set(headers.GitlabWorkhorseSendDataHeader, tc.headerValue) + + n, err := rw.Write([]byte(upstreamResponse)) + require.NoError(t, err) + require.Equal(t, len(upstreamResponse), n, "bytes written") + + recorder.Flush() + + body := recorder.Result().Body + data, err := ioutil.ReadAll(body) + require.NoError(t, err) + require.NoError(t, body.Close()) + + require.Equal(t, tc.out, string(data)) + }) + } +} + +const ( + testInjecterName = "test-injecter" + testInjecterData = "hello this is injected data" +) + +type testInjecter struct{} + +func (ti *testInjecter) Inject(w http.ResponseWriter, r *http.Request, sendData string) { + io.WriteString(w, testInjecterData) +} + +func (ti *testInjecter) Match(s string) bool { return strings.HasPrefix(s, testInjecterName+":") } +func (ti *testInjecter) Name() string { return testInjecterName } diff --git a/workhorse/internal/sendfile/sendfile.go b/workhorse/internal/sendfile/sendfile.go new file mode 100644 index 00000000000..d009f216eb9 --- /dev/null +++ b/workhorse/internal/sendfile/sendfile.go @@ -0,0 +1,162 @@ +/* +The xSendFile middleware transparently sends static files in HTTP responses +via the X-Sendfile mechanism. All that is needed in the Rails code is the +'send_file' method. 
+*/ + +package sendfile + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" + "regexp" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/mask" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +var ( + sendFileRequests = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_sendfile_requests", + Help: "How many X-Sendfile requests have been processed by gitlab-workhorse, partitioned by sendfile type.", + }, + []string{"type"}, + ) + + sendFileBytes = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_sendfile_bytes", + Help: "How many X-Sendfile bytes have been sent by gitlab-workhorse, partitioned by sendfile type.", + }, + []string{"type"}, + ) + + artifactsSendFile = regexp.MustCompile("builds/[0-9]+/artifacts") +) + +type sendFileResponseWriter struct { + rw http.ResponseWriter + status int + hijacked bool + req *http.Request +} + +func SendFile(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + s := &sendFileResponseWriter{ + rw: rw, + req: req, + } + // Advertise to upstream (Rails) that we support X-Sendfile + req.Header.Set(headers.XSendFileTypeHeader, headers.XSendFileHeader) + defer s.flush() + h.ServeHTTP(s, req) + }) +} + +func (s *sendFileResponseWriter) Header() http.Header { + return s.rw.Header() +} + +func (s *sendFileResponseWriter) Write(data []byte) (int, error) { + if s.status == 0 { + s.WriteHeader(http.StatusOK) + } + if s.hijacked { + return len(data), nil + } + return s.rw.Write(data) +} + +func (s *sendFileResponseWriter) WriteHeader(status int) { + if s.status != 0 { + return + } + + s.status = status + if s.status != http.StatusOK { + s.rw.WriteHeader(s.status) + return + } + + file := s.Header().Get(headers.XSendFileHeader) + if file != "" && !s.hijacked { + // Mark this connection as hijacked + s.hijacked = true + + // Serve the file + helper.DisableResponseBuffering(s.rw) + sendFileFromDisk(s.rw, s.req, file) + return + } + + s.rw.WriteHeader(s.status) +} + +func sendFileFromDisk(w http.ResponseWriter, r *http.Request, file string) { + log.WithContextFields(r.Context(), log.Fields{ + "file": file, + "method": r.Method, + "uri": mask.URL(r.RequestURI), + }).Print("Send file") + + contentTypeHeaderPresent := false + + if headers.IsDetectContentTypeHeaderPresent(w) { + // Removing the GitlabWorkhorseDetectContentTypeHeader header to + // avoid handling the response by the senddata handler + w.Header().Del(headers.GitlabWorkhorseDetectContentTypeHeader) + contentTypeHeaderPresent = true + } + + content, fi, err := helper.OpenFile(file) + if err != nil { + http.NotFound(w, r) + return + } + defer content.Close() + + countSendFileMetrics(fi.Size(), r) + + if contentTypeHeaderPresent { + data, err := ioutil.ReadAll(io.LimitReader(content, headers.MaxDetectSize)) + if err != nil { + helper.Fail500(w, r, fmt.Errorf("content type detection: %v", err)) + return + } + + content.Seek(0, io.SeekStart) + + contentType, contentDisposition := headers.SafeContentHeaders(data, w.Header().Get(headers.ContentDispositionHeader)) + w.Header().Set(headers.ContentTypeHeader, contentType) + w.Header().Set(headers.ContentDispositionHeader, contentDisposition) + } + + http.ServeContent(w, r, "", fi.ModTime(), content) +} + +func countSendFileMetrics(size int64, r 
*http.Request) { + var requestType string + switch { + case artifactsSendFile.MatchString(r.RequestURI): + requestType = "artifacts" + default: + requestType = "other" + } + + sendFileRequests.WithLabelValues(requestType).Inc() + sendFileBytes.WithLabelValues(requestType).Add(float64(size)) +} + +func (s *sendFileResponseWriter) flush() { + s.WriteHeader(http.StatusOK) +} diff --git a/workhorse/internal/sendfile/sendfile_test.go b/workhorse/internal/sendfile/sendfile_test.go new file mode 100644 index 00000000000..d424814b5e5 --- /dev/null +++ b/workhorse/internal/sendfile/sendfile_test.go @@ -0,0 +1,171 @@ +package sendfile + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers" +) + +func TestResponseWriter(t *testing.T) { + upstreamResponse := "hello world" + + fixturePath := "testdata/sent-file.txt" + fixtureContent, err := ioutil.ReadFile(fixturePath) + require.NoError(t, err) + + testCases := []struct { + desc string + sendfileHeader string + out string + }{ + { + desc: "send a file", + sendfileHeader: fixturePath, + out: string(fixtureContent), + }, + { + desc: "pass through unaltered", + sendfileHeader: "", + out: upstreamResponse, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + r, err := http.NewRequest("GET", "/foo", nil) + require.NoError(t, err) + + rw := httptest.NewRecorder() + sf := &sendFileResponseWriter{rw: rw, req: r} + sf.Header().Set(headers.XSendFileHeader, tc.sendfileHeader) + + upstreamBody := []byte(upstreamResponse) + n, err := sf.Write(upstreamBody) + require.NoError(t, err) + require.Equal(t, len(upstreamBody), n, "bytes written") + + rw.Flush() + + body := rw.Result().Body + data, err := ioutil.ReadAll(body) + require.NoError(t, err) + require.NoError(t, body.Close()) + + require.Equal(t, tc.out, string(data)) + }) + } +} + +func TestAllowExistentContentHeaders(t *testing.T) { + fixturePath := "../../testdata/forgedfile.png" + + httpHeaders := map[string]string{ + headers.ContentTypeHeader: "image/png", + headers.ContentDispositionHeader: "inline", + } + + resp := makeRequest(t, fixturePath, httpHeaders) + require.Equal(t, "image/png", resp.Header.Get(headers.ContentTypeHeader)) + require.Equal(t, "inline", resp.Header.Get(headers.ContentDispositionHeader)) +} + +func TestSuccessOverrideContentHeadersFeatureEnabled(t *testing.T) { + fixturePath := "../../testdata/forgedfile.png" + + httpHeaders := make(map[string]string) + httpHeaders[headers.ContentTypeHeader] = "image/png" + httpHeaders[headers.ContentDispositionHeader] = "inline" + httpHeaders["Range"] = "bytes=1-2" + + resp := makeRequest(t, fixturePath, httpHeaders) + require.Equal(t, "image/png", resp.Header.Get(headers.ContentTypeHeader)) + require.Equal(t, "inline", resp.Header.Get(headers.ContentDispositionHeader)) +} + +func TestSuccessOverrideContentHeadersRangeRequestFeatureEnabled(t *testing.T) { + fixturePath := "../../testdata/forgedfile.png" + + fixtureContent, err := ioutil.ReadFile(fixturePath) + require.NoError(t, err) + + r, err := http.NewRequest("GET", "/foo", nil) + r.Header.Set("Range", "bytes=1-2") + require.NoError(t, err) + + rw := httptest.NewRecorder() + sf := &sendFileResponseWriter{rw: rw, req: r} + + sf.Header().Set(headers.XSendFileHeader, fixturePath) + sf.Header().Set(headers.ContentTypeHeader, "image/png") + sf.Header().Set(headers.ContentDispositionHeader, "inline") + 
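
A sketch of the upstream side of the X-Sendfile contract implemented above (the artifact path is a placeholder): the middleware advertises support via X-Sendfile-Type, upstream responds with only an X-Sendfile header naming a file on disk, and SendFile hijacks the response to stream that file itself.

    package main

    import (
        "fmt"
        "net/http"
        "net/http/httptest"

        "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers"
        "gitlab.com/gitlab-org/gitlab-workhorse/internal/sendfile"
    )

    func main() {
        // Pretend upstream (Rails): instead of writing the file body, it only
        // names the file to be served. "/tmp/artifact.txt" is a placeholder path.
        upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            w.Header().Set(headers.XSendFileHeader, "/tmp/artifact.txt")
            w.WriteHeader(http.StatusOK)
        })

        rec := httptest.NewRecorder()
        req := httptest.NewRequest("GET", "/artifact", nil)

        // SendFile hijacks the 200 response and streams the named file from disk,
        // falling back to a 404 if the file does not exist.
        sendfile.SendFile(upstream).ServeHTTP(rec, req)
        fmt.Println("status:", rec.Code)
    }
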
sf.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") + + upstreamBody := []byte(fixtureContent) + _, err = sf.Write(upstreamBody) + require.NoError(t, err) + + rw.Flush() + + resp := rw.Result() + body := resp.Body + data, err := ioutil.ReadAll(body) + require.NoError(t, err) + require.NoError(t, body.Close()) + + require.Len(t, data, 2) + + require.Equal(t, "application/octet-stream", resp.Header.Get(headers.ContentTypeHeader)) + require.Equal(t, "attachment", resp.Header.Get(headers.ContentDispositionHeader)) +} + +func TestSuccessInlineWhitelistedTypesFeatureEnabled(t *testing.T) { + fixturePath := "../../testdata/image.png" + + httpHeaders := map[string]string{ + headers.ContentDispositionHeader: "inline", + headers.GitlabWorkhorseDetectContentTypeHeader: "true", + } + + resp := makeRequest(t, fixturePath, httpHeaders) + + require.Equal(t, "image/png", resp.Header.Get(headers.ContentTypeHeader)) + require.Equal(t, "inline", resp.Header.Get(headers.ContentDispositionHeader)) +} + +func makeRequest(t *testing.T, fixturePath string, httpHeaders map[string]string) *http.Response { + fixtureContent, err := ioutil.ReadFile(fixturePath) + require.NoError(t, err) + + r, err := http.NewRequest("GET", "/foo", nil) + require.NoError(t, err) + + rw := httptest.NewRecorder() + sf := &sendFileResponseWriter{rw: rw, req: r} + + sf.Header().Set(headers.XSendFileHeader, fixturePath) + for name, value := range httpHeaders { + sf.Header().Set(name, value) + } + + upstreamBody := []byte("hello") + n, err := sf.Write(upstreamBody) + require.NoError(t, err) + require.Equal(t, len(upstreamBody), n, "bytes written") + + rw.Flush() + + resp := rw.Result() + body := resp.Body + data, err := ioutil.ReadAll(body) + require.NoError(t, err) + require.NoError(t, body.Close()) + + require.Equal(t, fixtureContent, data) + + return resp +} diff --git a/workhorse/internal/sendfile/testdata/sent-file.txt b/workhorse/internal/sendfile/testdata/sent-file.txt new file mode 100644 index 00000000000..40e33f8a628 --- /dev/null +++ b/workhorse/internal/sendfile/testdata/sent-file.txt @@ -0,0 +1 @@ +This file is sent with X-SendFile diff --git a/workhorse/internal/sendurl/sendurl.go b/workhorse/internal/sendurl/sendurl.go new file mode 100644 index 00000000000..cf3d14a2bf0 --- /dev/null +++ b/workhorse/internal/sendurl/sendurl.go @@ -0,0 +1,167 @@ +package sendurl + +import ( + "fmt" + "io" + "net" + "net/http" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "gitlab.com/gitlab-org/labkit/correlation" + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/mask" + "gitlab.com/gitlab-org/labkit/tracing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata" +) + +type entry struct{ senddata.Prefix } + +type entryParams struct { + URL string + AllowRedirects bool +} + +var SendURL = &entry{"send-url:"} + +var rangeHeaderKeys = []string{ + "If-Match", + "If-Unmodified-Since", + "If-None-Match", + "If-Modified-Since", + "If-Range", + "Range", +} + +// Keep cache headers from the original response, not the proxied response. The +// original response comes from the Rails application, which should be the +// source of truth for caching. 
+var preserveHeaderKeys = map[string]bool{ + "Cache-Control": true, + "Expires": true, + "Date": true, // Support for HTTP 1.0 proxies + "Pragma": true, // Support for HTTP 1.0 proxies +} + +// httpTransport defines a http.Transport with values +// that are more restrictive than for http.DefaultTransport, +// they define shorter TLS Handshake, and more aggressive connection closing +// to prevent the connection hanging and reduce FD usage +var httpTransport = tracing.NewRoundTripper(correlation.NewInstrumentedRoundTripper(&http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 10 * time.Second, + }).DialContext, + MaxIdleConns: 2, + IdleConnTimeout: 30 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 10 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, +})) + +var httpClient = &http.Client{ + Transport: httpTransport, +} + +var ( + sendURLRequests = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_send_url_requests", + Help: "How many send URL requests have been processed", + }, + []string{"status"}, + ) + sendURLOpenRequests = promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "gitlab_workhorse_send_url_open_requests", + Help: "Describes how many send URL requests are open now", + }, + ) + sendURLBytes = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_send_url_bytes", + Help: "How many bytes were passed with send URL", + }, + ) + + sendURLRequestsInvalidData = sendURLRequests.WithLabelValues("invalid-data") + sendURLRequestsRequestFailed = sendURLRequests.WithLabelValues("request-failed") + sendURLRequestsSucceeded = sendURLRequests.WithLabelValues("succeeded") +) + +func (e *entry) Inject(w http.ResponseWriter, r *http.Request, sendData string) { + var params entryParams + + sendURLOpenRequests.Inc() + defer sendURLOpenRequests.Dec() + + if err := e.Unpack(¶ms, sendData); err != nil { + helper.Fail500(w, r, fmt.Errorf("SendURL: unpack sendData: %v", err)) + return + } + + log.WithContextFields(r.Context(), log.Fields{ + "url": mask.URL(params.URL), + "path": r.URL.Path, + }).Info("SendURL: sending") + + if params.URL == "" { + sendURLRequestsInvalidData.Inc() + helper.Fail500(w, r, fmt.Errorf("SendURL: URL is empty")) + return + } + + // create new request and copy range headers + newReq, err := http.NewRequest("GET", params.URL, nil) + if err != nil { + sendURLRequestsInvalidData.Inc() + helper.Fail500(w, r, fmt.Errorf("SendURL: NewRequest: %v", err)) + return + } + newReq = newReq.WithContext(r.Context()) + + for _, header := range rangeHeaderKeys { + newReq.Header[header] = r.Header[header] + } + + // execute new request + var resp *http.Response + if params.AllowRedirects { + resp, err = httpClient.Do(newReq) + } else { + resp, err = httpTransport.RoundTrip(newReq) + } + if err != nil { + sendURLRequestsRequestFailed.Inc() + helper.Fail500(w, r, fmt.Errorf("SendURL: Do request: %v", err)) + return + } + + // Prevent Go from adding a Content-Length header automatically + w.Header().Del("Content-Length") + + // copy response headers and body, except the headers from preserveHeaderKeys + for key, value := range resp.Header { + if !preserveHeaderKeys[key] { + w.Header()[key] = value + } + } + w.WriteHeader(resp.StatusCode) + + defer resp.Body.Close() + n, err := io.Copy(w, resp.Body) + sendURLBytes.Add(float64(n)) + + if err != nil { + sendURLRequestsRequestFailed.Inc() + helper.LogError(r, fmt.Errorf("SendURL: Copy response: %v", 
err)) + return + } + + sendURLRequestsSucceeded.Inc() +} diff --git a/workhorse/internal/sendurl/sendurl_test.go b/workhorse/internal/sendurl/sendurl_test.go new file mode 100644 index 00000000000..41e1dbb8e0f --- /dev/null +++ b/workhorse/internal/sendurl/sendurl_test.go @@ -0,0 +1,197 @@ +package sendurl + +import ( + "encoding/base64" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" +) + +const testData = `123456789012345678901234567890` +const testDataEtag = `W/"myetag"` + +func testEntryServer(t *testing.T, requestURL string, httpHeaders http.Header, allowRedirects bool) *httptest.ResponseRecorder { + requestHandler := func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "GET", r.Method) + + url := r.URL.String() + "/file" + jsonParams := fmt.Sprintf(`{"URL":%q,"AllowRedirects":%s}`, + url, strconv.FormatBool(allowRedirects)) + data := base64.URLEncoding.EncodeToString([]byte(jsonParams)) + + // The server returns a Content-Disposition + w.Header().Set("Content-Disposition", "attachment; filename=\"archive.txt\"") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Expires", "") + w.Header().Set("Date", "Wed, 21 Oct 2015 05:28:00 GMT") + w.Header().Set("Pragma", "no-cache") + + SendURL.Inject(w, r, data) + } + serveFile := func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "GET", r.Method) + + tempFile, err := ioutil.TempFile("", "download_file") + require.NoError(t, err) + require.NoError(t, os.Remove(tempFile.Name())) + defer tempFile.Close() + _, err = tempFile.Write([]byte(testData)) + require.NoError(t, err) + + w.Header().Set("Etag", testDataEtag) + w.Header().Set("Cache-Control", "public") + w.Header().Set("Expires", "Wed, 21 Oct 2015 07:28:00 GMT") + w.Header().Set("Date", "Wed, 21 Oct 2015 06:28:00 GMT") + w.Header().Set("Pragma", "") + + http.ServeContent(w, r, "archive.txt", time.Now(), tempFile) + } + redirectFile := func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "GET", r.Method) + http.Redirect(w, r, r.URL.String()+"/download", http.StatusTemporaryRedirect) + } + + mux := http.NewServeMux() + mux.HandleFunc("/get/request", requestHandler) + mux.HandleFunc("/get/request/file", serveFile) + mux.HandleFunc("/get/redirect", requestHandler) + mux.HandleFunc("/get/redirect/file", redirectFile) + mux.HandleFunc("/get/redirect/file/download", serveFile) + mux.HandleFunc("/get/file-not-existing", requestHandler) + + server := httptest.NewServer(mux) + defer server.Close() + + httpRequest, err := http.NewRequest("GET", server.URL+requestURL, nil) + require.NoError(t, err) + if httpHeaders != nil { + httpRequest.Header = httpHeaders + } + + response := httptest.NewRecorder() + mux.ServeHTTP(response, httpRequest) + return response +} + +func TestDownloadingUsingSendURL(t *testing.T) { + response := testEntryServer(t, "/get/request", nil, false) + require.Equal(t, http.StatusOK, response.Code) + + testhelper.RequireResponseHeader(t, response, + "Content-Type", + "text/plain; charset=utf-8") + testhelper.RequireResponseHeader(t, response, + "Content-Disposition", + "attachment; filename=\"archive.txt\"") + + testhelper.RequireResponseBody(t, response, testData) +} + +func TestDownloadingAChunkOfDataWithSendURL(t *testing.T) { + httpHeaders := http.Header{ + "Range": []string{ + "bytes=1-2", + }, + } + + response := testEntryServer(t, "/get/request", httpHeaders, false) + 
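
For reference, a sketch of how an upstream response triggers this injecter (the object-storage URL is a placeholder): the Gitlab-Workhorse-Send-Data header carries the "send-url:" prefix followed by base64url-encoded JSON matching entryParams, exactly as the test server above constructs it.

    package main

    import (
        "encoding/base64"
        "encoding/json"
        "fmt"
        "net/http"

        "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers"
    )

    func main() {
        params := map[string]interface{}{
            "URL":            "https://objects.example.com/bucket/archive.zip", // placeholder
            "AllowRedirects": false,
        }
        jsonBytes, err := json.Marshal(params)
        if err != nil {
            panic(err)
        }

        // This is what upstream would set on its response; the senddata middleware
        // matches the "send-url:" prefix and hands the payload to SendURL.Inject.
        value := "send-url:" + base64.URLEncoding.EncodeToString(jsonBytes)

        h := http.Header{}
        h.Set(headers.GitlabWorkhorseSendDataHeader, value)
        fmt.Println(h.Get(headers.GitlabWorkhorseSendDataHeader))
    }
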
require.Equal(t, http.StatusPartialContent, response.Code) + + testhelper.RequireResponseHeader(t, response, + "Content-Type", + "text/plain; charset=utf-8") + testhelper.RequireResponseHeader(t, response, + "Content-Disposition", + "attachment; filename=\"archive.txt\"") + testhelper.RequireResponseHeader(t, response, + "Content-Range", + "bytes 1-2/30") + + testhelper.RequireResponseBody(t, response, "23") +} + +func TestAccessingAlreadyDownloadedFileWithSendURL(t *testing.T) { + httpHeaders := http.Header{ + "If-None-Match": []string{testDataEtag}, + } + + response := testEntryServer(t, "/get/request", httpHeaders, false) + require.Equal(t, http.StatusNotModified, response.Code) +} + +func TestAccessingRedirectWithSendURL(t *testing.T) { + response := testEntryServer(t, "/get/redirect", nil, false) + require.Equal(t, http.StatusTemporaryRedirect, response.Code) +} + +func TestAccessingAllowedRedirectWithSendURL(t *testing.T) { + response := testEntryServer(t, "/get/redirect", nil, true) + require.Equal(t, http.StatusOK, response.Code) + + testhelper.RequireResponseHeader(t, response, + "Content-Type", + "text/plain; charset=utf-8") + testhelper.RequireResponseHeader(t, response, + "Content-Disposition", + "attachment; filename=\"archive.txt\"") +} + +func TestAccessingAllowedRedirectWithChunkOfDataWithSendURL(t *testing.T) { + httpHeaders := http.Header{ + "Range": []string{ + "bytes=1-2", + }, + } + + response := testEntryServer(t, "/get/redirect", httpHeaders, true) + require.Equal(t, http.StatusPartialContent, response.Code) + + testhelper.RequireResponseHeader(t, response, + "Content-Type", + "text/plain; charset=utf-8") + testhelper.RequireResponseHeader(t, response, + "Content-Disposition", + "attachment; filename=\"archive.txt\"") + testhelper.RequireResponseHeader(t, response, + "Content-Range", + "bytes 1-2/30") + + testhelper.RequireResponseBody(t, response, "23") +} + +func TestOriginalCacheHeadersPreservedWithSendURL(t *testing.T) { + response := testEntryServer(t, "/get/redirect", nil, true) + require.Equal(t, http.StatusOK, response.Code) + + testhelper.RequireResponseHeader(t, response, + "Cache-Control", + "no-cache") + testhelper.RequireResponseHeader(t, response, + "Expires", + "") + testhelper.RequireResponseHeader(t, response, + "Date", + "Wed, 21 Oct 2015 05:28:00 GMT") + testhelper.RequireResponseHeader(t, response, + "Pragma", + "no-cache") +} + +func TestDownloadingNonExistingFileUsingSendURL(t *testing.T) { + response := testEntryServer(t, "/invalid/path", nil, false) + require.Equal(t, http.StatusNotFound, response.Code) +} + +func TestDownloadingNonExistingRemoteFileWithSendURL(t *testing.T) { + response := testEntryServer(t, "/get/file-not-existing", nil, false) + require.Equal(t, http.StatusNotFound, response.Code) +} diff --git a/workhorse/internal/staticpages/deploy_page.go b/workhorse/internal/staticpages/deploy_page.go new file mode 100644 index 00000000000..d08ed449ae6 --- /dev/null +++ b/workhorse/internal/staticpages/deploy_page.go @@ -0,0 +1,26 @@ +package staticpages + +import ( + "io/ioutil" + "net/http" + "path/filepath" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +func (s *Static) DeployPage(handler http.Handler) http.Handler { + deployPage := filepath.Join(s.DocumentRoot, "index.html") + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data, err := ioutil.ReadFile(deployPage) + if err != nil { + handler.ServeHTTP(w, r) + return + } + + helper.SetNoCacheHeaders(w.Header()) + 
w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write(data) + }) +} diff --git a/workhorse/internal/staticpages/deploy_page_test.go b/workhorse/internal/staticpages/deploy_page_test.go new file mode 100644 index 00000000000..4b081e73a97 --- /dev/null +++ b/workhorse/internal/staticpages/deploy_page_test.go @@ -0,0 +1,59 @@ +package staticpages + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" + + "github.com/stretchr/testify/require" +) + +func TestIfNoDeployPageExist(t *testing.T) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + w := httptest.NewRecorder() + + executed := false + st := &Static{dir} + st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + executed = true + })).ServeHTTP(w, nil) + if !executed { + t.Error("The handler should get executed") + } +} + +func TestIfDeployPageExist(t *testing.T) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + deployPage := "DEPLOY" + ioutil.WriteFile(filepath.Join(dir, "index.html"), []byte(deployPage), 0600) + + w := httptest.NewRecorder() + + executed := false + st := &Static{dir} + st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + executed = true + })).ServeHTTP(w, nil) + if executed { + t.Error("The handler should not get executed") + } + w.Flush() + + require.Equal(t, 200, w.Code) + testhelper.RequireResponseBody(t, w, deployPage) +} diff --git a/workhorse/internal/staticpages/error_pages.go b/workhorse/internal/staticpages/error_pages.go new file mode 100644 index 00000000000..3cc89d9f811 --- /dev/null +++ b/workhorse/internal/staticpages/error_pages.go @@ -0,0 +1,138 @@ +package staticpages + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "path/filepath" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +var ( + staticErrorResponses = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_static_error_responses", + Help: "How many HTTP responses have been changed to a static error page, by HTTP status code.", + }, + []string{"code"}, + ) +) + +type ErrorFormat int + +const ( + ErrorFormatHTML ErrorFormat = iota + ErrorFormatJSON + ErrorFormatText +) + +type errorPageResponseWriter struct { + rw http.ResponseWriter + status int + hijacked bool + path string + format ErrorFormat +} + +func (s *errorPageResponseWriter) Header() http.Header { + return s.rw.Header() +} + +func (s *errorPageResponseWriter) Write(data []byte) (int, error) { + if s.status == 0 { + s.WriteHeader(http.StatusOK) + } + if s.hijacked { + return len(data), nil + } + return s.rw.Write(data) +} + +func (s *errorPageResponseWriter) WriteHeader(status int) { + if s.status != 0 { + return + } + + s.status = status + + if s.status < 400 || s.status > 599 || s.rw.Header().Get("X-GitLab-Custom-Error") != "" { + s.rw.WriteHeader(status) + return + } + + var contentType string + var data []byte + switch s.format { + case ErrorFormatText: + contentType, data = s.writeText() + case ErrorFormatJSON: + contentType, data = s.writeJSON() + default: + contentType, data = s.writeHTML() + } + + if contentType == "" { + s.rw.WriteHeader(status) + return + } + + s.hijacked = true + 
staticErrorResponses.WithLabelValues(fmt.Sprintf("%d", s.status)).Inc() + + helper.SetNoCacheHeaders(s.rw.Header()) + s.rw.Header().Set("Content-Type", contentType) + s.rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) + s.rw.Header().Del("Transfer-Encoding") + s.rw.WriteHeader(s.status) + s.rw.Write(data) +} + +func (s *errorPageResponseWriter) writeHTML() (string, []byte) { + if s.rw.Header().Get("Content-Type") != "application/json" { + errorPageFile := filepath.Join(s.path, fmt.Sprintf("%d.html", s.status)) + + // check if custom error page exists, serve this page instead + if data, err := ioutil.ReadFile(errorPageFile); err == nil { + return "text/html; charset=utf-8", data + } + } + + return "", nil +} + +func (s *errorPageResponseWriter) writeJSON() (string, []byte) { + message, err := json.Marshal(map[string]interface{}{"error": http.StatusText(s.status), "status": s.status}) + if err != nil { + return "", nil + } + + return "application/json; charset=utf-8", append(message, "\n"...) +} + +func (s *errorPageResponseWriter) writeText() (string, []byte) { + return "text/plain; charset=utf-8", []byte(http.StatusText(s.status) + "\n") +} + +func (s *errorPageResponseWriter) flush() { + s.WriteHeader(http.StatusOK) +} + +func (st *Static) ErrorPagesUnless(disabled bool, format ErrorFormat, handler http.Handler) http.Handler { + if disabled { + return handler + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rw := errorPageResponseWriter{ + rw: w, + path: st.DocumentRoot, + format: format, + } + defer rw.flush() + handler.ServeHTTP(&rw, r) + }) +} diff --git a/workhorse/internal/staticpages/error_pages_test.go b/workhorse/internal/staticpages/error_pages_test.go new file mode 100644 index 00000000000..05ec06cd429 --- /dev/null +++ b/workhorse/internal/staticpages/error_pages_test.go @@ -0,0 +1,191 @@ +package staticpages + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" +) + +func TestIfErrorPageIsPresented(t *testing.T) { + dir, err := ioutil.TempDir("", "error_page") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + errorPage := "ERROR" + ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600) + + w := httptest.NewRecorder() + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(404) + upstreamBody := "Not Found" + n, err := fmt.Fprint(w, upstreamBody) + require.NoError(t, err) + require.Equal(t, len(upstreamBody), n, "bytes written") + }) + st := &Static{dir} + st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil) + w.Flush() + + require.Equal(t, 404, w.Code) + testhelper.RequireResponseBody(t, w, errorPage) + testhelper.RequireResponseHeader(t, w, "Content-Type", "text/html; charset=utf-8") +} + +func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) { + dir, err := ioutil.TempDir("", "error_page") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + w := httptest.NewRecorder() + errorResponse := "ERROR" + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(404) + fmt.Fprint(w, errorResponse) + }) + st := &Static{dir} + st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil) + w.Flush() + + require.Equal(t, 404, w.Code) + testhelper.RequireResponseBody(t, w, errorResponse) +} + +func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) { + dir, err 
:= ioutil.TempDir("", "error_page") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + errorPage := "ERROR" + ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) + + w := httptest.NewRecorder() + serverError := "Interesting Server Error" + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(500) + fmt.Fprint(w, serverError) + }) + st := &Static{dir} + st.ErrorPagesUnless(true, ErrorFormatHTML, h).ServeHTTP(w, nil) + w.Flush() + require.Equal(t, 500, w.Code) + testhelper.RequireResponseBody(t, w, serverError) +} + +func TestIfErrorPageIsIgnoredIfCustomError(t *testing.T) { + dir, err := ioutil.TempDir("", "error_page") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + errorPage := "ERROR" + ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) + + w := httptest.NewRecorder() + serverError := "Interesting Server Error" + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Add("X-GitLab-Custom-Error", "1") + w.WriteHeader(500) + fmt.Fprint(w, serverError) + }) + st := &Static{dir} + st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil) + w.Flush() + require.Equal(t, 500, w.Code) + testhelper.RequireResponseBody(t, w, serverError) +} + +func TestErrorPageInterceptedByContentType(t *testing.T) { + testCases := []struct { + contentType string + intercepted bool + }{ + {contentType: "application/json", intercepted: false}, + {contentType: "text/plain", intercepted: true}, + {contentType: "text/html", intercepted: true}, + {contentType: "", intercepted: true}, + } + + for _, tc := range testCases { + dir, err := ioutil.TempDir("", "error_page") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + errorPage := "ERROR" + ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) + + w := httptest.NewRecorder() + serverError := "Interesting Server Error" + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Add("Content-Type", tc.contentType) + w.WriteHeader(500) + fmt.Fprint(w, serverError) + }) + st := &Static{dir} + st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil) + w.Flush() + require.Equal(t, 500, w.Code) + + if tc.intercepted { + testhelper.RequireResponseBody(t, w, errorPage) + } else { + testhelper.RequireResponseBody(t, w, serverError) + } + } +} + +func TestIfErrorPageIsPresentedJSON(t *testing.T) { + errorPage := "{\"error\":\"Not Found\",\"status\":404}\n" + + w := httptest.NewRecorder() + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(404) + upstreamBody := "This string is ignored" + n, err := fmt.Fprint(w, upstreamBody) + require.NoError(t, err) + require.Equal(t, len(upstreamBody), n, "bytes written") + }) + st := &Static{""} + st.ErrorPagesUnless(false, ErrorFormatJSON, h).ServeHTTP(w, nil) + w.Flush() + + require.Equal(t, 404, w.Code) + testhelper.RequireResponseBody(t, w, errorPage) + testhelper.RequireResponseHeader(t, w, "Content-Type", "application/json; charset=utf-8") +} + +func TestIfErrorPageIsPresentedText(t *testing.T) { + errorPage := "Not Found\n" + + w := httptest.NewRecorder() + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(404) + upstreamBody := "This string is ignored" + n, err := fmt.Fprint(w, upstreamBody) + require.NoError(t, err) + require.Equal(t, len(upstreamBody), n, "bytes written") + }) + st := &Static{""} + st.ErrorPagesUnless(false, ErrorFormatText, 
h).ServeHTTP(w, nil) + w.Flush() + + require.Equal(t, 404, w.Code) + testhelper.RequireResponseBody(t, w, errorPage) + testhelper.RequireResponseHeader(t, w, "Content-Type", "text/plain; charset=utf-8") +} diff --git a/workhorse/internal/staticpages/servefile.go b/workhorse/internal/staticpages/servefile.go new file mode 100644 index 00000000000..c98bc030bc2 --- /dev/null +++ b/workhorse/internal/staticpages/servefile.go @@ -0,0 +1,84 @@ +package staticpages + +import ( + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/mask" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/urlprefix" +) + +type CacheMode int + +const ( + CacheDisabled CacheMode = iota + CacheExpireMax +) + +// BUG/QUIRK: If a client requests 'foo%2Fbar' and 'foo/bar' exists, +// handleServeFile will serve foo/bar instead of passing the request +// upstream. +func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoundHandler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + file := filepath.Join(s.DocumentRoot, prefix.Strip(r.URL.Path)) + + // The filepath.Join does Clean traversing directories up + if !strings.HasPrefix(file, s.DocumentRoot) { + helper.Fail500(w, r, &os.PathError{ + Op: "open", + Path: file, + Err: os.ErrInvalid, + }) + return + } + + var content *os.File + var fi os.FileInfo + var err error + + // Serve pre-gzipped assets + if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") { + content, fi, err = helper.OpenFile(file + ".gz") + if err == nil { + w.Header().Set("Content-Encoding", "gzip") + } + } + + // If not found, open the original file + if content == nil || err != nil { + content, fi, err = helper.OpenFile(file) + } + if err != nil { + if notFoundHandler != nil { + notFoundHandler.ServeHTTP(w, r) + } else { + http.NotFound(w, r) + } + return + } + defer content.Close() + + switch cache { + case CacheExpireMax: + // Cache statically served files for 1 year + cacheUntil := time.Now().AddDate(1, 0, 0).Format(http.TimeFormat) + w.Header().Set("Cache-Control", "public") + w.Header().Set("Expires", cacheUntil) + } + + log.WithContextFields(r.Context(), log.Fields{ + "file": file, + "encoding": w.Header().Get("Content-Encoding"), + "method": r.Method, + "uri": mask.URL(r.RequestURI), + }).Info("Send static file") + + http.ServeContent(w, r, filepath.Base(file), fi.ModTime(), content) + }) +} diff --git a/workhorse/internal/staticpages/servefile_test.go b/workhorse/internal/staticpages/servefile_test.go new file mode 100644 index 00000000000..e136b876298 --- /dev/null +++ b/workhorse/internal/staticpages/servefile_test.go @@ -0,0 +1,134 @@ +package staticpages + +import ( + "bytes" + "compress/gzip" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" + + "github.com/stretchr/testify/require" +) + +func TestServingNonExistingFile(t *testing.T) { + dir := "/path/to/non/existing/directory" + httpRequest, _ := http.NewRequest("GET", "/file", nil) + + w := httptest.NewRecorder() + st := &Static{dir} + st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) + require.Equal(t, 404, w.Code) +} + +func TestServingDirectory(t *testing.T) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + 
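
A sketch of how these staticpages middlewares can be chained (the document root, listen address, upstream handler and the exact ordering are assumptions for illustration, not taken from this change): DeployPage short-circuits all requests while an index.html deploy page exists, ErrorPagesUnless rewrites upstream error statuses into static pages, and ServeExisting serves files straight from disk before falling back to upstream.

    package main

    import (
        "log"
        "net/http"

        "gitlab.com/gitlab-org/gitlab-workhorse/internal/staticpages"
    )

    func main() {
        st := &staticpages.Static{DocumentRoot: "/srv/gitlab/public"} // placeholder document root

        // Placeholder upstream handler standing in for the proxy to gitlab-rails.
        upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            http.Error(w, "Not Found", http.StatusNotFound)
        })

        // While DocumentRoot/index.html exists, DeployPage serves it for every request;
        // otherwise requests fall through to the error-page and static-file layers.
        handler := st.DeployPage(
            st.ErrorPagesUnless(false, staticpages.ErrorFormatHTML,
                st.ServeExisting("/", staticpages.CacheDisabled, upstream),
            ),
        )

        log.Fatal(http.ListenAndServe("127.0.0.1:8181", handler)) // placeholder listen address
    }
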
httpRequest, _ := http.NewRequest("GET", "/file", nil) + w := httptest.NewRecorder() + st := &Static{dir} + st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) + require.Equal(t, 404, w.Code) +} + +func TestServingMalformedUri(t *testing.T) { + dir := "/path/to/non/existing/directory" + httpRequest, _ := http.NewRequest("GET", "/../../../static/file", nil) + + w := httptest.NewRecorder() + st := &Static{dir} + st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) + require.Equal(t, 404, w.Code) +} + +func TestExecutingHandlerWhenNoFileFound(t *testing.T) { + dir := "/path/to/non/existing/directory" + httpRequest, _ := http.NewRequest("GET", "/file", nil) + + executed := false + st := &Static{dir} + st.ServeExisting("/", CacheDisabled, http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + executed = (r == httpRequest) + })).ServeHTTP(nil, httpRequest) + if !executed { + t.Error("The handler should get executed") + } +} + +func TestServingTheActualFile(t *testing.T) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + httpRequest, _ := http.NewRequest("GET", "/file", nil) + + fileContent := "STATIC" + ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) + + w := httptest.NewRecorder() + st := &Static{dir} + st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) + require.Equal(t, 200, w.Code) + if w.Body.String() != fileContent { + t.Error("We should serve the file: ", w.Body.String()) + } +} + +func testServingThePregzippedFile(t *testing.T, enableGzip bool) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + httpRequest, _ := http.NewRequest("GET", "/file", nil) + + if enableGzip { + httpRequest.Header.Set("Accept-Encoding", "gzip, deflate") + } + + fileContent := "STATIC" + + var fileGzipContent bytes.Buffer + fileGzip := gzip.NewWriter(&fileGzipContent) + fileGzip.Write([]byte(fileContent)) + fileGzip.Close() + + ioutil.WriteFile(filepath.Join(dir, "file.gz"), fileGzipContent.Bytes(), 0600) + ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) + + w := httptest.NewRecorder() + st := &Static{dir} + st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) + require.Equal(t, 200, w.Code) + if enableGzip { + testhelper.RequireResponseHeader(t, w, "Content-Encoding", "gzip") + if !bytes.Equal(w.Body.Bytes(), fileGzipContent.Bytes()) { + t.Error("We should serve the pregzipped file") + } + } else { + require.Equal(t, 200, w.Code) + testhelper.RequireResponseHeader(t, w, "Content-Encoding") + if w.Body.String() != fileContent { + t.Error("We should serve the file: ", w.Body.String()) + } + } +} + +func TestServingThePregzippedFile(t *testing.T) { + testServingThePregzippedFile(t, true) +} + +func TestServingThePregzippedFileWithoutEncoding(t *testing.T) { + testServingThePregzippedFile(t, false) +} diff --git a/workhorse/internal/staticpages/static.go b/workhorse/internal/staticpages/static.go new file mode 100644 index 00000000000..b42351f15f5 --- /dev/null +++ b/workhorse/internal/staticpages/static.go @@ -0,0 +1,5 @@ +package staticpages + +type Static struct { + DocumentRoot string +} diff --git a/workhorse/internal/testhelper/gitaly.go b/workhorse/internal/testhelper/gitaly.go new file mode 100644 index 00000000000..24884505440 --- /dev/null +++ b/workhorse/internal/testhelper/gitaly.go @@ -0,0 +1,384 @@ +package testhelper + +import ( + "fmt" + "io" + "io/ioutil" + 
"path" + "strings" + "sync" + + "github.com/golang/protobuf/jsonpb" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/274 + "github.com/golang/protobuf/proto" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/274 + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "gitlab.com/gitlab-org/labkit/log" + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +type GitalyTestServer struct { + finalMessageCode codes.Code + sync.WaitGroup + LastIncomingMetadata metadata.MD + gitalypb.UnimplementedRepositoryServiceServer + gitalypb.UnimplementedBlobServiceServer + gitalypb.UnimplementedDiffServiceServer +} + +var ( + GitalyInfoRefsResponseMock = strings.Repeat("Mock Gitaly InfoRefsResponse data", 100000) + GitalyGetBlobResponseMock = strings.Repeat("Mock Gitaly GetBlobResponse data", 100000) + GitalyGetArchiveResponseMock = strings.Repeat("Mock Gitaly GetArchiveResponse data", 100000) + GitalyGetDiffResponseMock = strings.Repeat("Mock Gitaly GetDiffResponse data", 100000) + GitalyGetPatchResponseMock = strings.Repeat("Mock Gitaly GetPatchResponse data", 100000) + + GitalyGetSnapshotResponseMock = strings.Repeat("Mock Gitaly GetSnapshotResponse data", 100000) + + GitalyReceivePackResponseMock []byte + GitalyUploadPackResponseMock []byte +) + +func init() { + var err error + if GitalyReceivePackResponseMock, err = ioutil.ReadFile(path.Join(RootDir(), "testdata/receive-pack-fixture.txt")); err != nil { + log.WithError(err).Fatal("Unable to read pack response") + } + if GitalyUploadPackResponseMock, err = ioutil.ReadFile(path.Join(RootDir(), "testdata/upload-pack-fixture.txt")); err != nil { + log.WithError(err).Fatal("Unable to read pack response") + } +} + +func NewGitalyServer(finalMessageCode codes.Code) *GitalyTestServer { + return &GitalyTestServer{finalMessageCode: finalMessageCode} +} + +func (s *GitalyTestServer) InfoRefsUploadPack(in *gitalypb.InfoRefsRequest, stream gitalypb.SmartHTTPService_InfoRefsUploadPackServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + if err := validateRepository(in.GetRepository()); err != nil { + return err + } + + fmt.Printf("Result: %+v\n", in) + + marshaler := &jsonpb.Marshaler{} + jsonString, err := marshaler.MarshalToString(in) + if err != nil { + return err + } + + data := []byte(strings.Join([]string{ + jsonString, + "git-upload-pack", + GitalyInfoRefsResponseMock, + }, "\000")) + + s.LastIncomingMetadata = nil + if md, ok := metadata.FromIncomingContext(stream.Context()); ok { + s.LastIncomingMetadata = md + } + + return s.sendInfoRefs(stream, data) +} + +func (s *GitalyTestServer) InfoRefsReceivePack(in *gitalypb.InfoRefsRequest, stream gitalypb.SmartHTTPService_InfoRefsReceivePackServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + if err := validateRepository(in.GetRepository()); err != nil { + return err + } + + fmt.Printf("Result: %+v\n", in) + + jsonString, err := marshalJSON(in) + if err != nil { + return err + } + + data := []byte(strings.Join([]string{ + jsonString, + "git-receive-pack", + GitalyInfoRefsResponseMock, + }, "\000")) + + return s.sendInfoRefs(stream, data) +} + +func marshalJSON(msg proto.Message) (string, error) { + marshaler := &jsonpb.Marshaler{} + return marshaler.MarshalToString(msg) +} + +type infoRefsSender interface { + Send(*gitalypb.InfoRefsResponse) error +} + +func (s *GitalyTestServer) sendInfoRefs(stream infoRefsSender, data []byte) error { + nSends, err 
:= sendBytes(data, 100, func(p []byte) error { + return stream.Send(&gitalypb.InfoRefsResponse{Data: p}) + }) + if err != nil { + return err + } + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) PostReceivePack(stream gitalypb.SmartHTTPService_PostReceivePackServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + req, err := stream.Recv() + if err != nil { + return err + } + + repo := req.GetRepository() + if err := validateRepository(repo); err != nil { + return err + } + + jsonString, err := marshalJSON(req) + if err != nil { + return err + } + + data := []byte(jsonString + "\000") + + // The body of the request starts in the second message + for { + req, err := stream.Recv() + if err != nil { + if err != io.EOF { + return err + } + break + } + + // We want to echo the request data back + data = append(data, req.GetData()...) + } + + nSends, _ := sendBytes(data, 100, func(p []byte) error { + return stream.Send(&gitalypb.PostReceivePackResponse{Data: p}) + }) + + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) PostUploadPack(stream gitalypb.SmartHTTPService_PostUploadPackServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + req, err := stream.Recv() + if err != nil { + return err + } + + if err := validateRepository(req.GetRepository()); err != nil { + return err + } + + jsonString, err := marshalJSON(req) + if err != nil { + return err + } + + if err := stream.Send(&gitalypb.PostUploadPackResponse{ + Data: []byte(strings.Join([]string{jsonString}, "\000") + "\000"), + }); err != nil { + return err + } + + nSends := 0 + // The body of the request starts in the second message. Gitaly streams PostUploadPack responses + // as soon as possible without reading the request completely first. We stream messages here + // directly back to the client to simulate the streaming of the actual implementation. + for { + req, err := stream.Recv() + if err != nil { + if err != io.EOF { + return err + } + break + } + + if err := stream.Send(&gitalypb.PostUploadPackResponse{Data: req.GetData()}); err != nil { + return err + } + + nSends++ + } + + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) CommitIsAncestor(ctx context.Context, in *gitalypb.CommitIsAncestorRequest) (*gitalypb.CommitIsAncestorResponse, error) { + return nil, nil +} + +func (s *GitalyTestServer) GetBlob(in *gitalypb.GetBlobRequest, stream gitalypb.BlobService_GetBlobServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + if err := validateRepository(in.GetRepository()); err != nil { + return err + } + + response := &gitalypb.GetBlobResponse{ + Oid: in.GetOid(), + Size: int64(len(GitalyGetBlobResponseMock)), + } + nSends, err := sendBytes([]byte(GitalyGetBlobResponseMock), 100, func(p []byte) error { + response.Data = p + + if err := stream.Send(response); err != nil { + return err + } + + // Use a new response so we don't send other fields (Size, ...) 
over and over + response = &gitalypb.GetBlobResponse{} + + return nil + }) + if err != nil { + return err + } + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) GetArchive(in *gitalypb.GetArchiveRequest, stream gitalypb.RepositoryService_GetArchiveServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + if err := validateRepository(in.GetRepository()); err != nil { + return err + } + + nSends, err := sendBytes([]byte(GitalyGetArchiveResponseMock), 100, func(p []byte) error { + return stream.Send(&gitalypb.GetArchiveResponse{Data: p}) + }) + if err != nil { + return err + } + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) RawDiff(in *gitalypb.RawDiffRequest, stream gitalypb.DiffService_RawDiffServer) error { + nSends, err := sendBytes([]byte(GitalyGetDiffResponseMock), 100, func(p []byte) error { + return stream.Send(&gitalypb.RawDiffResponse{ + Data: p, + }) + }) + if err != nil { + return err + } + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) RawPatch(in *gitalypb.RawPatchRequest, stream gitalypb.DiffService_RawPatchServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + if err := validateRepository(in.GetRepository()); err != nil { + return err + } + + nSends, err := sendBytes([]byte(GitalyGetPatchResponseMock), 100, func(p []byte) error { + return stream.Send(&gitalypb.RawPatchResponse{ + Data: p, + }) + }) + if err != nil { + return err + } + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +func (s *GitalyTestServer) GetSnapshot(in *gitalypb.GetSnapshotRequest, stream gitalypb.RepositoryService_GetSnapshotServer) error { + s.WaitGroup.Add(1) + defer s.WaitGroup.Done() + + if err := validateRepository(in.GetRepository()); err != nil { + return err + } + + nSends, err := sendBytes([]byte(GitalyGetSnapshotResponseMock), 100, func(p []byte) error { + return stream.Send(&gitalypb.GetSnapshotResponse{Data: p}) + }) + if err != nil { + return err + } + if nSends <= 1 { + panic("should have sent more than one message") + } + + return s.finalError() +} + +// sendBytes returns the number of times the 'sender' function was called and an error. 
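+// It calls 'sender' with successive chunks of at most chunkSize bytes until all of 'data' has been
+// consumed, stopping at the first send error. The mock Gitaly handlers above rely on it to force their
+// responses to be streamed as multiple messages.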
+func sendBytes(data []byte, chunkSize int, sender func([]byte) error) (int, error) { + i := 0 + for ; len(data) > 0; i++ { + n := chunkSize + if n > len(data) { + n = len(data) + } + + if err := sender(data[:n]); err != nil { + return i, err + } + data = data[n:] + } + + return i, nil +} + +func (s *GitalyTestServer) finalError() error { + if code := s.finalMessageCode; code != codes.OK { + return status.Errorf(code, "error as specified by test") + } + + return nil +} + +func validateRepository(repo *gitalypb.Repository) error { + if len(repo.GetStorageName()) == 0 { + return fmt.Errorf("missing storage_name: %v", repo) + } + if len(repo.GetRelativePath()) == 0 { + return fmt.Errorf("missing relative_path: %v", repo) + } + return nil +} diff --git a/workhorse/internal/testhelper/testhelper.go b/workhorse/internal/testhelper/testhelper.go new file mode 100644 index 00000000000..40097bd453a --- /dev/null +++ b/workhorse/internal/testhelper/testhelper.go @@ -0,0 +1,152 @@ +package testhelper + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path" + "regexp" + "runtime" + "testing" + "time" + + "github.com/dgrijalva/jwt-go" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" +) + +func ConfigureSecret() { + secret.SetPath(path.Join(RootDir(), "testdata/test-secret")) +} + +func RequireResponseBody(t *testing.T, response *httptest.ResponseRecorder, expectedBody string) { + t.Helper() + require.Equal(t, expectedBody, response.Body.String(), "response body") +} + +func RequireResponseHeader(t *testing.T, w interface{}, header string, expected ...string) { + t.Helper() + var actual []string + + header = http.CanonicalHeaderKey(header) + type headerer interface{ Header() http.Header } + + switch resp := w.(type) { + case *http.Response: + actual = resp.Header[header] + case headerer: + actual = resp.Header()[header] + default: + t.Fatal("invalid type of w passed RequireResponseHeader") + } + + require.Equal(t, expected, actual, "values for HTTP header %s", header) +} + +func TestServerWithHandler(url *regexp.Regexp, handler http.HandlerFunc) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logEntry := log.WithFields(log.Fields{ + "method": r.Method, + "url": r.URL, + "action": "DENY", + }) + + if url != nil && !url.MatchString(r.URL.Path) { + logEntry.Info("UPSTREAM") + w.WriteHeader(404) + return + } + + if version := r.Header.Get("Gitlab-Workhorse"); version == "" { + logEntry.Info("UPSTREAM") + w.WriteHeader(403) + return + } + + handler(w, r) + })) +} + +var workhorseExecutables = []string{"gitlab-workhorse", "gitlab-zip-cat", "gitlab-zip-metadata", "gitlab-resize-image"} + +func BuildExecutables() error { + rootDir := RootDir() + + for _, exe := range workhorseExecutables { + if _, err := os.Stat(path.Join(rootDir, exe)); os.IsNotExist(err) { + return fmt.Errorf("cannot find executable %s. 
Please run 'make prepare-tests'", exe) + } + } + + oldPath := os.Getenv("PATH") + testPath := fmt.Sprintf("%s:%s", rootDir, oldPath) + if err := os.Setenv("PATH", testPath); err != nil { + return fmt.Errorf("failed to set PATH to %v", testPath) + } + + return nil +} + +func RootDir() string { + _, currentFile, _, ok := runtime.Caller(0) + if !ok { + panic(errors.New("RootDir: calling runtime.Caller failed")) + } + return path.Join(path.Dir(currentFile), "../..") +} + +func LoadFile(t *testing.T, filePath string) string { + t.Helper() + content, err := ioutil.ReadFile(path.Join(RootDir(), filePath)) + require.NoError(t, err) + return string(content) +} + +func ReadAll(t *testing.T, r io.Reader) []byte { + t.Helper() + + b, err := ioutil.ReadAll(r) + require.NoError(t, err) + return b +} + +func ParseJWT(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + ConfigureSecret() + secretBytes, err := secret.Bytes() + if err != nil { + return nil, fmt.Errorf("read secret from file: %v", err) + } + + return secretBytes, nil +} + +// UploadClaims represents the JWT claim for upload parameters +type UploadClaims struct { + Upload map[string]string `json:"upload"` + jwt.StandardClaims +} + +func Retry(t testing.TB, timeout time.Duration, fn func() error) { + t.Helper() + start := time.Now() + var err error + for ; time.Since(start) < timeout; time.Sleep(time.Millisecond) { + err = fn() + if err == nil { + return + } + } + t.Fatalf("test timeout after %v; last error: %v", timeout, err) +} diff --git a/workhorse/internal/upload/accelerate.go b/workhorse/internal/upload/accelerate.go new file mode 100644 index 00000000000..7d8ea51b14d --- /dev/null +++ b/workhorse/internal/upload/accelerate.go @@ -0,0 +1,32 @@ +package upload + +import ( + "fmt" + "net/http" + + "github.com/dgrijalva/jwt-go" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +const RewrittenFieldsHeader = "Gitlab-Workhorse-Multipart-Fields" + +type MultipartClaims struct { + RewrittenFields map[string]string `json:"rewritten_fields"` + jwt.StandardClaims +} + +func Accelerate(rails PreAuthorizer, h http.Handler, p Preparer) http.Handler { + return rails.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) { + s := &SavedFileTracker{Request: r} + + opts, _, err := p.Prepare(a) + if err != nil { + helper.Fail500(w, r, fmt.Errorf("Accelerate: error preparing file storage options")) + return + } + + HandleFileUploads(w, r, h, a, s, opts) + }, "/authorize") +} diff --git a/workhorse/internal/upload/body_uploader.go b/workhorse/internal/upload/body_uploader.go new file mode 100644 index 00000000000..2cee90195fb --- /dev/null +++ b/workhorse/internal/upload/body_uploader.go @@ -0,0 +1,90 @@ +package upload + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +type PreAuthorizer interface { + PreAuthorizeHandler(next api.HandleFunc, suffix string) http.Handler +} + +// Verifier allows to check an upload before sending it to rails +type Verifier interface { + // Verify can abort the upload returning an error + Verify(handler *filestore.FileHandler) error +} + 
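For illustration, a minimal Verifier might enforce a size cap on the finalized upload. This sketch is hypothetical (the maxSizeVerifier name and its limit are not part of the code above); it only relies on the FileHandler.Size field that this change uses elsewhere, and on the fmt and filestore imports already present in this file:

// maxSizeVerifier is a hypothetical Verifier that rejects uploads larger than limit bytes.
type maxSizeVerifier struct {
	limit int64
}

// Verify aborts the upload by returning an error when the saved file exceeds the limit.
func (v maxSizeVerifier) Verify(handler *filestore.FileHandler) error {
	if handler.Size > v.limit {
		return fmt.Errorf("upload of %d bytes exceeds limit of %d bytes", handler.Size, v.limit)
	}
	return nil
}

A Preparer that returns such a verifier makes BodyUploader respond with a 500 instead of proxying the request whenever Verify fails.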
+// Preparer allows customizing the BodyUploader configuration.
+type Preparer interface {
+	// Prepare converts an api.Response into a *filestore.SaveFileOpts. It can optionally return a Verifier that will be
+	// invoked after the upload completes, before finalization with rails.
+	Prepare(a *api.Response) (*filestore.SaveFileOpts, Verifier, error)
+}
+
+type DefaultPreparer struct{}
+
+func (s *DefaultPreparer) Prepare(a *api.Response) (*filestore.SaveFileOpts, Verifier, error) {
+	opts, err := filestore.GetOpts(a)
+	return opts, nil, err
+}
+
+// BodyUploader is an http.Handler that performs a pre-authorization call to rails before hijacking the request body and
+// uploading it.
+// Providing a Preparer allows customizing the upload process.
+func BodyUploader(rails PreAuthorizer, h http.Handler, p Preparer) http.Handler {
+	return rails.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
+		opts, verifier, err := p.Prepare(a)
+		if err != nil {
+			helper.Fail500(w, r, fmt.Errorf("BodyUploader: preparation failed: %v", err))
+			return
+		}
+
+		fh, err := filestore.SaveFileFromReader(r.Context(), r.Body, r.ContentLength, opts)
+		if err != nil {
+			helper.Fail500(w, r, fmt.Errorf("BodyUploader: upload failed: %v", err))
+			return
+		}
+
+		if verifier != nil {
+			if err := verifier.Verify(fh); err != nil {
+				helper.Fail500(w, r, fmt.Errorf("BodyUploader: verification failed: %v", err))
+				return
+			}
+		}
+
+		data := url.Values{}
+		fields, err := fh.GitLabFinalizeFields("file")
+		if err != nil {
+			helper.Fail500(w, r, fmt.Errorf("BodyUploader: finalize fields failed: %v", err))
+			return
+		}
+
+		for k, v := range fields {
+			data.Set(k, v)
+		}
+
+		// Hijack the request body: replace it with the urlencoded finalize fields
+		body := data.Encode()
+		r.Body = ioutil.NopCloser(strings.NewReader(body))
+		r.ContentLength = int64(len(body))
+		r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+		sft := SavedFileTracker{Request: r}
+		sft.Track("file", fh.LocalPath)
+		if err := sft.Finalize(r.Context()); err != nil {
+			helper.Fail500(w, r, fmt.Errorf("BodyUploader: finalize failed: %v", err))
+			return
+		}
+
+		// And proxy the request
+		h.ServeHTTP(w, r)
+	}, "/authorize")
+}
diff --git a/workhorse/internal/upload/body_uploader_test.go b/workhorse/internal/upload/body_uploader_test.go
new file mode 100644
index 00000000000..451d7c97fab
--- /dev/null
+++ b/workhorse/internal/upload/body_uploader_test.go
@@ -0,0 +1,195 @@
+package upload
+
+import (
+	"fmt"
+	"io"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"strconv"
+	"strings"
+	"testing"
+
+	"github.com/dgrijalva/jwt-go"
+	"github.com/stretchr/testify/require"
+
+	"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+	"gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+	"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+)
+
+const (
+	fileContent = "A test file content"
+	fileLen     = len(fileContent)
+)
+
+func TestBodyUploader(t *testing.T) {
+	testhelper.ConfigureSecret()
+
+	body := strings.NewReader(fileContent)
+
+	resp := testUpload(&rails{}, &alwaysLocalPreparer{}, echoProxy(t, fileLen), body)
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	uploadEcho, err := ioutil.ReadAll(resp.Body)
+
+	require.NoError(t, err, "Can't read response body")
+	require.Equal(t, fileContent, string(uploadEcho))
+}
+
+func TestBodyUploaderCustomPreparer(t *testing.T) {
+	body := strings.NewReader(fileContent)
+
+	resp := testUpload(&rails{}, &alwaysLocalPreparer{}, echoProxy(t, fileLen), body)
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	uploadEcho, err :=
ioutil.ReadAll(resp.Body) + require.NoError(t, err, "Can't read response body") + require.Equal(t, fileContent, string(uploadEcho)) +} + +func TestBodyUploaderCustomVerifier(t *testing.T) { + body := strings.NewReader(fileContent) + verifier := &mockVerifier{} + + resp := testUpload(&rails{}, &alwaysLocalPreparer{verifier: verifier}, echoProxy(t, fileLen), body) + require.Equal(t, http.StatusOK, resp.StatusCode) + + uploadEcho, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err, "Can't read response body") + require.Equal(t, fileContent, string(uploadEcho)) + require.True(t, verifier.invoked, "Verifier.Verify not invoked") +} + +func TestBodyUploaderAuthorizationFailure(t *testing.T) { + testNoProxyInvocation(t, http.StatusUnauthorized, &rails{unauthorized: true}, &alwaysLocalPreparer{}) +} + +func TestBodyUploaderErrors(t *testing.T) { + tests := []struct { + name string + preparer *alwaysLocalPreparer + }{ + {name: "Prepare failure", preparer: &alwaysLocalPreparer{prepareError: fmt.Errorf("")}}, + {name: "Verify failure", preparer: &alwaysLocalPreparer{verifier: &alwaysFailsVerifier{}}}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testNoProxyInvocation(t, http.StatusInternalServerError, &rails{}, test.preparer) + }) + } +} + +func testNoProxyInvocation(t *testing.T, expectedStatus int, auth PreAuthorizer, preparer Preparer) { + proxy := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Fail(t, "request proxied upstream") + }) + + resp := testUpload(auth, preparer, proxy, nil) + require.Equal(t, expectedStatus, resp.StatusCode) +} + +func testUpload(auth PreAuthorizer, preparer Preparer, proxy http.Handler, body io.Reader) *http.Response { + req := httptest.NewRequest("POST", "http://example.com/upload", body) + w := httptest.NewRecorder() + + BodyUploader(auth, proxy, preparer).ServeHTTP(w, req) + + return w.Result() +} + +func echoProxy(t *testing.T, expectedBodyLength int) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + require.NoError(t, err) + + require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"), "Wrong Content-Type header") + + require.Contains(t, r.PostForm, "file.md5") + require.Contains(t, r.PostForm, "file.sha1") + require.Contains(t, r.PostForm, "file.sha256") + require.Contains(t, r.PostForm, "file.sha512") + + require.Contains(t, r.PostForm, "file.path") + require.Contains(t, r.PostForm, "file.size") + require.Contains(t, r.PostForm, "file.gitlab-workhorse-upload") + require.Equal(t, strconv.Itoa(expectedBodyLength), r.PostFormValue("file.size")) + + token, err := jwt.ParseWithClaims(r.Header.Get(RewrittenFieldsHeader), &MultipartClaims{}, testhelper.ParseJWT) + require.NoError(t, err, "Wrong JWT header") + + rewrittenFields := token.Claims.(*MultipartClaims).RewrittenFields + if len(rewrittenFields) != 1 || len(rewrittenFields["file"]) == 0 { + t.Fatalf("Unexpected rewritten_fields value: %v", rewrittenFields) + } + + token, jwtErr := jwt.ParseWithClaims(r.PostFormValue("file.gitlab-workhorse-upload"), &testhelper.UploadClaims{}, testhelper.ParseJWT) + require.NoError(t, jwtErr, "Wrong signed upload fields") + + uploadFields := token.Claims.(*testhelper.UploadClaims).Upload + require.Contains(t, uploadFields, "name") + require.Contains(t, uploadFields, "path") + require.Contains(t, uploadFields, "remote_url") + require.Contains(t, uploadFields, "remote_id") + require.Contains(t, uploadFields, "size") + 
require.Contains(t, uploadFields, "md5") + require.Contains(t, uploadFields, "sha1") + require.Contains(t, uploadFields, "sha256") + require.Contains(t, uploadFields, "sha512") + + path := r.PostFormValue("file.path") + uploaded, err := os.Open(path) + require.NoError(t, err, "File not uploaded") + + //sending back the file for testing purpose + io.Copy(w, uploaded) + }) +} + +type rails struct { + unauthorized bool +} + +func (r *rails) PreAuthorizeHandler(next api.HandleFunc, _ string) http.Handler { + if r.unauthorized { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + }) + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next(w, r, &api.Response{TempPath: os.TempDir()}) + }) +} + +type alwaysLocalPreparer struct { + verifier Verifier + prepareError error +} + +func (a *alwaysLocalPreparer) Prepare(_ *api.Response) (*filestore.SaveFileOpts, Verifier, error) { + opts, err := filestore.GetOpts(&api.Response{TempPath: os.TempDir()}) + if err != nil { + return nil, nil, err + } + + return opts, a.verifier, a.prepareError +} + +type alwaysFailsVerifier struct{} + +func (alwaysFailsVerifier) Verify(handler *filestore.FileHandler) error { + return fmt.Errorf("Verification failed") +} + +type mockVerifier struct { + invoked bool +} + +func (m *mockVerifier) Verify(handler *filestore.FileHandler) error { + m.invoked = true + + return nil +} diff --git a/workhorse/internal/upload/exif/exif.go b/workhorse/internal/upload/exif/exif.go new file mode 100644 index 00000000000..a9307b1ca90 --- /dev/null +++ b/workhorse/internal/upload/exif/exif.go @@ -0,0 +1,107 @@ +package exif + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "os/exec" + "regexp" + + "gitlab.com/gitlab-org/labkit/log" +) + +var ErrRemovingExif = errors.New("error while removing EXIF") + +type cleaner struct { + ctx context.Context + cmd *exec.Cmd + stdout io.Reader + stderr bytes.Buffer + eof bool +} + +func NewCleaner(ctx context.Context, stdin io.Reader) (io.ReadCloser, error) { + c := &cleaner{ctx: ctx} + + if err := c.startProcessing(stdin); err != nil { + return nil, err + } + + return c, nil +} + +func (c *cleaner) Close() error { + if c.cmd == nil { + return nil + } + + return c.cmd.Wait() +} + +func (c *cleaner) Read(p []byte) (int, error) { + if c.eof { + return 0, io.EOF + } + + n, err := c.stdout.Read(p) + if err == io.EOF { + if waitErr := c.cmd.Wait(); waitErr != nil { + log.WithContextFields(c.ctx, log.Fields{ + "command": c.cmd.Args, + "stderr": c.stderr.String(), + "error": waitErr.Error(), + }).Print("exiftool command failed") + + return n, ErrRemovingExif + } + + c.eof = true + } + + return n, err +} + +func (c *cleaner) startProcessing(stdin io.Reader) error { + var err error + + whitelisted_tags := []string{ + "-ResolutionUnit", + "-XResolution", + "-YResolution", + "-YCbCrSubSampling", + "-YCbCrPositioning", + "-BitsPerSample", + "-ImageHeight", + "-ImageWidth", + "-ImageSize", + "-Copyright", + "-CopyrightNotice", + "-Orientation", + } + + args := append([]string{"-all=", "--IPTC:all", "--XMP-iptcExt:all", "-tagsFromFile", "@"}, whitelisted_tags...) + args = append(args, "-") + c.cmd = exec.CommandContext(c.ctx, "exiftool", args...) 
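+	// The assembled exiftool command reads the image from stdin and emits the rewritten copy on stdout
+	// (the trailing "-"): "-all=" drops metadata, and "-tagsFromFile @" copies back only the allowlisted
+	// tags listed above.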
+ + c.cmd.Stderr = &c.stderr + c.cmd.Stdin = stdin + + c.stdout, err = c.cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("failed to create stdout pipe: %v", err) + } + + if err = c.cmd.Start(); err != nil { + return fmt.Errorf("start %v: %v", c.cmd.Args, err) + } + + return nil +} + +func IsExifFile(filename string) bool { + filenameMatch := regexp.MustCompile(`(?i)\.(jpg|jpeg|tiff)$`) + + return filenameMatch.MatchString(filename) +} diff --git a/workhorse/internal/upload/exif/exif_test.go b/workhorse/internal/upload/exif/exif_test.go new file mode 100644 index 00000000000..373d97f7fce --- /dev/null +++ b/workhorse/internal/upload/exif/exif_test.go @@ -0,0 +1,95 @@ +package exif + +import ( + "context" + "io" + "io/ioutil" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsExifFile(t *testing.T) { + tests := []struct { + name string + expected bool + }{ + { + name: "/full/path.jpg", + expected: true, + }, + { + name: "path.jpeg", + expected: true, + }, + { + name: "path.tiff", + expected: true, + }, + { + name: "path.JPG", + expected: true, + }, + { + name: "path.tar", + expected: false, + }, + { + name: "path", + expected: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.Equal(t, test.expected, IsExifFile(test.name)) + }) + } +} + +func TestNewCleanerWithValidFile(t *testing.T) { + input, err := os.Open("testdata/sample_exif.jpg") + require.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cleaner, err := NewCleaner(ctx, input) + require.NoError(t, err, "Expected no error when creating cleaner command") + + size, err := io.Copy(ioutil.Discard, cleaner) + require.NoError(t, err, "Expected no error when reading output") + + sizeAfterStrip := int64(25399) + require.Equal(t, sizeAfterStrip, size, "Different size of converted image") +} + +func TestNewCleanerWithInvalidFile(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cleaner, err := NewCleaner(ctx, strings.NewReader("invalid image")) + require.NoError(t, err, "Expected no error when creating cleaner command") + + size, err := io.Copy(ioutil.Discard, cleaner) + require.Error(t, err, "Expected error when reading output") + require.Equal(t, int64(0), size, "Size of invalid image should be 0") +} + +func TestNewCleanerReadingAfterEOF(t *testing.T) { + input, err := os.Open("testdata/sample_exif.jpg") + require.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cleaner, err := NewCleaner(ctx, input) + require.NoError(t, err, "Expected no error when creating cleaner command") + + _, err = io.Copy(ioutil.Discard, cleaner) + require.NoError(t, err, "Expected no error when reading output") + + buf := make([]byte, 1) + size, err := cleaner.Read(buf) + require.Equal(t, 0, size, "The output was already consumed by previous reads") + require.Equal(t, io.EOF, err, "We return EOF") +} diff --git a/workhorse/internal/upload/exif/testdata/sample_exif.jpg b/workhorse/internal/upload/exif/testdata/sample_exif.jpg Binary files differnew file mode 100644 index 00000000000..05eda3f7f95 --- /dev/null +++ b/workhorse/internal/upload/exif/testdata/sample_exif.jpg diff --git a/workhorse/internal/upload/object_storage_preparer.go b/workhorse/internal/upload/object_storage_preparer.go new file mode 100644 index 00000000000..7a113fae80a --- /dev/null +++ b/workhorse/internal/upload/object_storage_preparer.go @@ -0,0 +1,28 @@ +package upload 
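+// ObjectStoragePreparer, defined below, copies Workhorse's object storage credentials and URL mux from
+// the loaded configuration into the SaveFileOpts used for uploads that go straight to object storage.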
+ +import ( + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore" +) + +type ObjectStoragePreparer struct { + config config.ObjectStorageConfig + credentials config.ObjectStorageCredentials +} + +func NewObjectStoragePreparer(c config.Config) Preparer { + return &ObjectStoragePreparer{credentials: c.ObjectStorageCredentials, config: c.ObjectStorageConfig} +} + +func (p *ObjectStoragePreparer) Prepare(a *api.Response) (*filestore.SaveFileOpts, Verifier, error) { + opts, err := filestore.GetOpts(a) + if err != nil { + return nil, nil, err + } + + opts.ObjectStorageConfig.URLMux = p.config.URLMux + opts.ObjectStorageConfig.S3Credentials = p.credentials.S3Credentials + + return opts, nil, nil +} diff --git a/workhorse/internal/upload/object_storage_preparer_test.go b/workhorse/internal/upload/object_storage_preparer_test.go new file mode 100644 index 00000000000..613b6071275 --- /dev/null +++ b/workhorse/internal/upload/object_storage_preparer_test.go @@ -0,0 +1,62 @@ +package upload_test + +import ( + "testing" + + "gocloud.dev/blob" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload" + + "github.com/stretchr/testify/require" +) + +func TestPrepareWithS3Config(t *testing.T) { + creds := config.S3Credentials{ + AwsAccessKeyID: "test-key", + AwsSecretAccessKey: "test-secret", + } + + c := config.Config{ + ObjectStorageCredentials: config.ObjectStorageCredentials{ + Provider: "AWS", + S3Credentials: creds, + }, + ObjectStorageConfig: config.ObjectStorageConfig{ + URLMux: new(blob.URLMux), + }, + } + + r := &api.Response{ + RemoteObject: api.RemoteObject{ + ID: "the ID", + UseWorkhorseClient: true, + ObjectStorage: &api.ObjectStorageParams{ + Provider: "AWS", + }, + }, + } + + p := upload.NewObjectStoragePreparer(c) + opts, v, err := p.Prepare(r) + + require.NoError(t, err) + require.True(t, opts.ObjectStorageConfig.IsAWS()) + require.True(t, opts.UseWorkhorseClient) + require.Equal(t, creds, opts.ObjectStorageConfig.S3Credentials) + require.NotNil(t, opts.ObjectStorageConfig.URLMux) + require.Equal(t, nil, v) +} + +func TestPrepareWithNoConfig(t *testing.T) { + c := config.Config{} + r := &api.Response{RemoteObject: api.RemoteObject{ID: "id"}} + p := upload.NewObjectStoragePreparer(c) + opts, v, err := p.Prepare(r) + + require.NoError(t, err) + require.False(t, opts.UseWorkhorseClient) + require.Nil(t, v) + require.Nil(t, opts.ObjectStorageConfig.URLMux) +} diff --git a/workhorse/internal/upload/rewrite.go b/workhorse/internal/upload/rewrite.go new file mode 100644 index 00000000000..e51604c6ed9 --- /dev/null +++ b/workhorse/internal/upload/rewrite.go @@ -0,0 +1,203 @@ +package upload + +import ( + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "mime/multipart" + "net/http" + "strings" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/lsif_transformer/parser" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload/exif" +) + +// ErrInjectedClientParam means that the client sent a parameter that overrides one of our own fields +var ErrInjectedClientParam = errors.New("injected 
client parameter") + +var ( + multipartUploadRequests = promauto.NewCounterVec( + prometheus.CounterOpts{ + + Name: "gitlab_workhorse_multipart_upload_requests", + Help: "How many multipart upload requests have been processed by gitlab-workhorse. Partitioned by type.", + }, + []string{"type"}, + ) + + multipartFileUploadBytes = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_multipart_upload_bytes", + Help: "How many disk bytes of multipart file parts have been successfully written by gitlab-workhorse. Partitioned by type.", + }, + []string{"type"}, + ) + + multipartFiles = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_multipart_upload_files", + Help: "How many multipart file parts have been processed by gitlab-workhorse. Partitioned by type.", + }, + []string{"type"}, + ) +) + +type rewriter struct { + writer *multipart.Writer + preauth *api.Response + filter MultipartFormProcessor + finalizedFields map[string]bool +} + +func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, preauth *api.Response, filter MultipartFormProcessor, opts *filestore.SaveFileOpts) error { + // Create multipart reader + reader, err := r.MultipartReader() + if err != nil { + if err == http.ErrNotMultipart { + // We want to be able to recognize http.ErrNotMultipart elsewhere so no fmt.Errorf + return http.ErrNotMultipart + } + return fmt.Errorf("get multipart reader: %v", err) + } + + multipartUploadRequests.WithLabelValues(filter.Name()).Inc() + + rew := &rewriter{ + writer: writer, + preauth: preauth, + filter: filter, + finalizedFields: make(map[string]bool), + } + + for { + p, err := reader.NextPart() + if err != nil { + if err == io.EOF { + break + } + return err + } + + name := p.FormName() + if name == "" { + continue + } + + if rew.finalizedFields[name] { + return ErrInjectedClientParam + } + + if p.FileName() != "" { + err = rew.handleFilePart(r.Context(), name, p, opts) + } else { + err = rew.copyPart(r.Context(), name, p) + } + + if err != nil { + return err + } + } + + return nil +} + +func (rew *rewriter) handleFilePart(ctx context.Context, name string, p *multipart.Part, opts *filestore.SaveFileOpts) error { + multipartFiles.WithLabelValues(rew.filter.Name()).Inc() + + filename := p.FileName() + + if strings.Contains(filename, "/") || filename == "." || filename == ".." 
{ + return fmt.Errorf("illegal filename: %q", filename) + } + + opts.TempFilePrefix = filename + + var inputReader io.ReadCloser + var err error + switch { + case exif.IsExifFile(filename): + inputReader, err = handleExifUpload(ctx, p, filename) + if err != nil { + return err + } + case rew.preauth.ProcessLsif: + inputReader, err = handleLsifUpload(ctx, p, opts.LocalTempPath, filename, rew.preauth) + if err != nil { + return err + } + default: + inputReader = ioutil.NopCloser(p) + } + + defer inputReader.Close() + + fh, err := filestore.SaveFileFromReader(ctx, inputReader, -1, opts) + if err != nil { + switch err { + case filestore.ErrEntityTooLarge, exif.ErrRemovingExif: + return err + default: + return fmt.Errorf("persisting multipart file: %v", err) + } + } + + fields, err := fh.GitLabFinalizeFields(name) + if err != nil { + return fmt.Errorf("failed to finalize fields: %v", err) + } + + for key, value := range fields { + rew.writer.WriteField(key, value) + rew.finalizedFields[key] = true + } + + multipartFileUploadBytes.WithLabelValues(rew.filter.Name()).Add(float64(fh.Size)) + + return rew.filter.ProcessFile(ctx, name, fh, rew.writer) +} + +func handleExifUpload(ctx context.Context, r io.Reader, filename string) (io.ReadCloser, error) { + log.WithContextFields(ctx, log.Fields{ + "filename": filename, + }).Print("running exiftool to remove any metadata") + + cleaner, err := exif.NewCleaner(ctx, r) + if err != nil { + return nil, err + } + + return cleaner, nil +} + +func handleLsifUpload(ctx context.Context, reader io.Reader, tempPath, filename string, preauth *api.Response) (io.ReadCloser, error) { + parserConfig := parser.Config{ + TempPath: tempPath, + } + + return parser.NewParser(ctx, reader, parserConfig) +} + +func (rew *rewriter) copyPart(ctx context.Context, name string, p *multipart.Part) error { + np, err := rew.writer.CreatePart(p.Header) + if err != nil { + return fmt.Errorf("create multipart field: %v", err) + } + + if _, err := io.Copy(np, p); err != nil { + return fmt.Errorf("duplicate multipart field: %v", err) + } + + if err := rew.filter.ProcessField(ctx, name, rew.writer); err != nil { + return fmt.Errorf("process multipart field: %v", err) + } + + return nil +} diff --git a/workhorse/internal/upload/saved_file_tracker.go b/workhorse/internal/upload/saved_file_tracker.go new file mode 100644 index 00000000000..7b6cade4faa --- /dev/null +++ b/workhorse/internal/upload/saved_file_tracker.go @@ -0,0 +1,55 @@ +package upload + +import ( + "context" + "fmt" + "mime/multipart" + "net/http" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" +) + +type SavedFileTracker struct { + Request *http.Request + rewrittenFields map[string]string +} + +func (s *SavedFileTracker) Track(fieldName string, localPath string) { + if s.rewrittenFields == nil { + s.rewrittenFields = make(map[string]string) + } + s.rewrittenFields[fieldName] = localPath +} + +func (s *SavedFileTracker) Count() int { + return len(s.rewrittenFields) +} + +func (s *SavedFileTracker) ProcessFile(_ context.Context, fieldName string, file *filestore.FileHandler, _ *multipart.Writer) error { + s.Track(fieldName, file.LocalPath) + return nil +} + +func (s *SavedFileTracker) ProcessField(_ context.Context, _ string, _ *multipart.Writer) error { + return nil +} + +func (s *SavedFileTracker) Finalize(_ context.Context) error { + if s.rewrittenFields == nil { + return nil + } + + claims := MultipartClaims{RewrittenFields: s.rewrittenFields, 
StandardClaims: secret.DefaultClaims}
+	tokenString, err := secret.JWTTokenString(claims)
+	if err != nil {
+		return fmt.Errorf("savedFileTracker.Finalize: %v", err)
+	}
+
+	s.Request.Header.Set(RewrittenFieldsHeader, tokenString)
+	return nil
+}
+
+func (s *SavedFileTracker) Name() string {
+	return "accelerate"
+}
diff --git a/workhorse/internal/upload/saved_file_tracker_test.go b/workhorse/internal/upload/saved_file_tracker_test.go
new file mode 100644
index 00000000000..e5a5e8f23a7
--- /dev/null
+++ b/workhorse/internal/upload/saved_file_tracker_test.go
@@ -0,0 +1,39 @@
+package upload
+
+import (
+	"context"
+
+	"github.com/dgrijalva/jwt-go"
+
+	"net/http"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+	"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+)
+
+func TestSavedFileTracking(t *testing.T) {
+	testhelper.ConfigureSecret()
+
+	r, err := http.NewRequest("PUT", "/url/path", nil)
+	require.NoError(t, err)
+
+	tracker := SavedFileTracker{Request: r}
+	require.Equal(t, "accelerate", tracker.Name())
+
+	file := &filestore.FileHandler{}
+	ctx := context.Background()
+	tracker.ProcessFile(ctx, "test", file, nil)
+	require.Equal(t, 1, tracker.Count())
+
+	tracker.Finalize(ctx)
+	token, err := jwt.ParseWithClaims(r.Header.Get(RewrittenFieldsHeader), &MultipartClaims{}, testhelper.ParseJWT)
+	require.NoError(t, err)
+
+	rewrittenFields := token.Claims.(*MultipartClaims).RewrittenFields
+	require.Equal(t, 1, len(rewrittenFields))
+
+	require.Contains(t, rewrittenFields, "test")
+}
diff --git a/workhorse/internal/upload/skip_rails_authorizer.go b/workhorse/internal/upload/skip_rails_authorizer.go
new file mode 100644
index 00000000000..716467b8841
--- /dev/null
+++ b/workhorse/internal/upload/skip_rails_authorizer.go
@@ -0,0 +1,22 @@
+package upload
+
+import (
+	"net/http"
+
+	"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+)
+
+// SkipRailsAuthorizer implements a fake PreAuthorizer that does not call the rails API and
+// authorizes every request as a local-only upload to TempPath.
+type SkipRailsAuthorizer struct {
+	// TempPath is the temporary path for a local-only upload
+	TempPath string
+}
+
+// PreAuthorizeHandler implements PreAuthorizer. It always grants the upload.
+// The fake API response contains only TempPath.
+func (l *SkipRailsAuthorizer) PreAuthorizeHandler(next api.HandleFunc, _ string) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		next(w, r, &api.Response{TempPath: l.TempPath})
+	})
+}
diff --git a/workhorse/internal/upload/uploads.go b/workhorse/internal/upload/uploads.go
new file mode 100644
index 00000000000..3be39f9518f
--- /dev/null
+++ b/workhorse/internal/upload/uploads.go
@@ -0,0 +1,66 @@
+package upload
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io/ioutil"
+	"mime/multipart"
+	"net/http"
+
+	"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+	"gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+	"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+	"gitlab.com/gitlab-org/gitlab-workhorse/internal/upload/exif"
+	"gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts"
+)
+
+// These methods are allowed to have thread-unsafe implementations.
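+// HandleFileUploads drives them: ProcessFile is called for every rewritten file part, ProcessField for
+// every other form field, and Finalize exactly once after the whole multipart body has been rewritten.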
+type MultipartFormProcessor interface { + ProcessFile(ctx context.Context, formName string, file *filestore.FileHandler, writer *multipart.Writer) error + ProcessField(ctx context.Context, formName string, writer *multipart.Writer) error + Finalize(ctx context.Context) error + Name() string +} + +func HandleFileUploads(w http.ResponseWriter, r *http.Request, h http.Handler, preauth *api.Response, filter MultipartFormProcessor, opts *filestore.SaveFileOpts) { + var body bytes.Buffer + writer := multipart.NewWriter(&body) + defer writer.Close() + + // Rewrite multipart form data + err := rewriteFormFilesFromMultipart(r, writer, preauth, filter, opts) + if err != nil { + switch err { + case ErrInjectedClientParam: + helper.CaptureAndFail(w, r, err, "Bad Request", http.StatusBadRequest) + case http.ErrNotMultipart: + h.ServeHTTP(w, r) + case filestore.ErrEntityTooLarge: + helper.RequestEntityTooLarge(w, r, err) + case zipartifacts.ErrBadMetadata: + helper.RequestEntityTooLarge(w, r, err) + case exif.ErrRemovingExif: + helper.CaptureAndFail(w, r, err, "Failed to process image", http.StatusUnprocessableEntity) + default: + helper.Fail500(w, r, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err)) + } + return + } + + // Close writer + writer.Close() + + // Hijack the request + r.Body = ioutil.NopCloser(&body) + r.ContentLength = int64(body.Len()) + r.Header.Set("Content-Type", writer.FormDataContentType()) + + if err := filter.Finalize(r.Context()); err != nil { + helper.Fail500(w, r, fmt.Errorf("handleFileUploads: Finalize: %v", err)) + return + } + + // Proxy the request + h.ServeHTTP(w, r) +} diff --git a/workhorse/internal/upload/uploads_test.go b/workhorse/internal/upload/uploads_test.go new file mode 100644 index 00000000000..fc1a1ac57ef --- /dev/null +++ b/workhorse/internal/upload/uploads_test.go @@ -0,0 +1,475 @@ +package upload + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "regexp" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/proxy" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream/roundtripper" +) + +var nilHandler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) + +type testFormProcessor struct{} + +func (a *testFormProcessor) ProcessFile(ctx context.Context, formName string, file *filestore.FileHandler, writer *multipart.Writer) error { + return nil +} + +func (a *testFormProcessor) ProcessField(ctx context.Context, formName string, writer *multipart.Writer) error { + if formName != "token" && !strings.HasPrefix(formName, "file.") && !strings.HasPrefix(formName, "other.") { + return fmt.Errorf("illegal field: %v", formName) + } + return nil +} + +func (a *testFormProcessor) Finalize(ctx context.Context) error { + return nil +} + +func (a *testFormProcessor) Name() string { + return "" +} + +func TestUploadTempPathRequirement(t *testing.T) { + apiResponse := &api.Response{} + preparer := &DefaultPreparer{} + _, _, err := preparer.Prepare(apiResponse) + require.Error(t, err) +} + +func TestUploadHandlerForwardingRawData(t *testing.T) { + ts := 
testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "PATCH", r.Method, "method") + + body, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + require.Equal(t, "REQUEST", string(body), "request body") + + w.WriteHeader(202) + fmt.Fprint(w, "RESPONSE") + }) + defer ts.Close() + + httpRequest, err := http.NewRequest("PATCH", ts.URL+"/url/path", bytes.NewBufferString("REQUEST")) + require.NoError(t, err) + + tempPath, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + defer os.RemoveAll(tempPath) + + response := httptest.NewRecorder() + + handler := newProxy(ts.URL) + apiResponse := &api.Response{TempPath: tempPath} + preparer := &DefaultPreparer{} + opts, _, err := preparer.Prepare(apiResponse) + require.NoError(t, err) + + HandleFileUploads(response, httpRequest, handler, apiResponse, nil, opts) + + require.Equal(t, 202, response.Code) + require.Equal(t, "RESPONSE", response.Body.String(), "response body") +} + +func TestUploadHandlerRewritingMultiPartData(t *testing.T) { + var filePath string + + tempPath, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + defer os.RemoveAll(tempPath) + + ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "PUT", r.Method, "method") + require.NoError(t, r.ParseMultipartForm(100000)) + + require.Empty(t, r.MultipartForm.File, "Expected to not receive any files") + require.Equal(t, "test", r.FormValue("token"), "Expected to receive token") + require.Equal(t, "my.file", r.FormValue("file.name"), "Expected to receive a filename") + + filePath = r.FormValue("file.path") + require.True(t, strings.HasPrefix(filePath, tempPath), "Expected to the file to be in tempPath") + + require.Empty(t, r.FormValue("file.remote_url"), "Expected to receive empty remote_url") + require.Empty(t, r.FormValue("file.remote_id"), "Expected to receive empty remote_id") + require.Equal(t, "4", r.FormValue("file.size"), "Expected to receive the file size") + + hashes := map[string]string{ + "md5": "098f6bcd4621d373cade4e832627b4f6", + "sha1": "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3", + "sha256": "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08", + "sha512": "ee26b0dd4af7e749aa1a8ee3c10ae9923f618980772e473f8819a5d4940e0db27ac185f8a0e1d5f84f88bc887fd67b143732c304cc5fa9ad8e6f57f50028a8ff", + } + + for algo, hash := range hashes { + require.Equal(t, hash, r.FormValue("file."+algo), "file hash %s", algo) + } + + require.Len(t, r.MultipartForm.Value, 11, "multipart form values") + + w.WriteHeader(202) + fmt.Fprint(w, "RESPONSE") + }) + + var buffer bytes.Buffer + + writer := multipart.NewWriter(&buffer) + writer.WriteField("token", "test") + file, err := writer.CreateFormFile("file", "my.file") + require.NoError(t, err) + fmt.Fprint(file, "test") + writer.Close() + + httpRequest, err := http.NewRequest("PUT", ts.URL+"/url/path", nil) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + httpRequest = httpRequest.WithContext(ctx) + httpRequest.Body = ioutil.NopCloser(&buffer) + httpRequest.ContentLength = int64(buffer.Len()) + httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) + response := httptest.NewRecorder() + + handler := newProxy(ts.URL) + + apiResponse := &api.Response{TempPath: tempPath} + preparer := &DefaultPreparer{} + opts, _, err := preparer.Prepare(apiResponse) + require.NoError(t, err) + + 
HandleFileUploads(response, httpRequest, handler, apiResponse, &testFormProcessor{}, opts) + require.Equal(t, 202, response.Code) + + cancel() // this will trigger an async cleanup + waitUntilDeleted(t, filePath) +} + +func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) { + var filePath string + + tempPath, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + defer os.RemoveAll(tempPath) + + tests := []struct { + name string + field string + response int + }{ + { + name: "injected file.path", + field: "file.path", + response: 400, + }, + { + name: "injected file.remote_id", + field: "file.remote_id", + response: 400, + }, + { + name: "field with other prefix", + field: "other.path", + response: 202, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "PUT", r.Method, "method") + + w.WriteHeader(202) + fmt.Fprint(w, "RESPONSE") + }) + + var buffer bytes.Buffer + + writer := multipart.NewWriter(&buffer) + file, err := writer.CreateFormFile("file", "my.file") + require.NoError(t, err) + fmt.Fprint(file, "test") + + writer.WriteField(test.field, "value") + writer.Close() + + httpRequest, err := http.NewRequest("PUT", ts.URL+"/url/path", &buffer) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + httpRequest = httpRequest.WithContext(ctx) + httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) + response := httptest.NewRecorder() + + handler := newProxy(ts.URL) + apiResponse := &api.Response{TempPath: tempPath} + preparer := &DefaultPreparer{} + opts, _, err := preparer.Prepare(apiResponse) + require.NoError(t, err) + + HandleFileUploads(response, httpRequest, handler, apiResponse, &testFormProcessor{}, opts) + require.Equal(t, test.response, response.Code) + + cancel() // this will trigger an async cleanup + waitUntilDeleted(t, filePath) + }) + } +} + +func TestUploadProcessingField(t *testing.T) { + tempPath, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + defer os.RemoveAll(tempPath) + + var buffer bytes.Buffer + + writer := multipart.NewWriter(&buffer) + writer.WriteField("token2", "test") + writer.Close() + + httpRequest, err := http.NewRequest("PUT", "/url/path", &buffer) + require.NoError(t, err) + httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) + + response := httptest.NewRecorder() + apiResponse := &api.Response{TempPath: tempPath} + preparer := &DefaultPreparer{} + opts, _, err := preparer.Prepare(apiResponse) + require.NoError(t, err) + + HandleFileUploads(response, httpRequest, nilHandler, apiResponse, &testFormProcessor{}, opts) + + require.Equal(t, 500, response.Code) +} + +func TestUploadProcessingFile(t *testing.T) { + tempPath, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + defer os.RemoveAll(tempPath) + + _, testServer := test.StartObjectStore() + defer testServer.Close() + + storeUrl := testServer.URL + test.ObjectPath + + tests := []struct { + name string + preauth api.Response + }{ + { + name: "FileStore Upload", + preauth: api.Response{TempPath: tempPath}, + }, + { + name: "ObjectStore Upload", + preauth: api.Response{RemoteObject: api.RemoteObject{StoreURL: storeUrl}}, + }, + { + name: "ObjectStore and FileStore Upload", + preauth: api.Response{ + TempPath: tempPath, + RemoteObject: api.RemoteObject{StoreURL: storeUrl}, + }, + }, + } + + for _, test := range tests { + 
t.Run(test.name, func(t *testing.T) { + var buffer bytes.Buffer + writer := multipart.NewWriter(&buffer) + file, err := writer.CreateFormFile("file", "my.file") + require.NoError(t, err) + fmt.Fprint(file, "test") + writer.Close() + + httpRequest, err := http.NewRequest("PUT", "/url/path", &buffer) + require.NoError(t, err) + httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) + + response := httptest.NewRecorder() + apiResponse := &api.Response{TempPath: tempPath} + preparer := &DefaultPreparer{} + opts, _, err := preparer.Prepare(apiResponse) + require.NoError(t, err) + + HandleFileUploads(response, httpRequest, nilHandler, apiResponse, &testFormProcessor{}, opts) + + require.Equal(t, 200, response.Code) + }) + } + +} + +func TestInvalidFileNames(t *testing.T) { + testhelper.ConfigureSecret() + + tempPath, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + defer os.RemoveAll(tempPath) + + for _, testCase := range []struct { + filename string + code int + }{ + {"foobar", 200}, // sanity check for test setup below + {"foo/bar", 500}, + {"/../../foobar", 500}, + {".", 500}, + {"..", 500}, + } { + buffer := &bytes.Buffer{} + + writer := multipart.NewWriter(buffer) + file, err := writer.CreateFormFile("file", testCase.filename) + require.NoError(t, err) + fmt.Fprint(file, "test") + writer.Close() + + httpRequest, err := http.NewRequest("POST", "/example", buffer) + require.NoError(t, err) + httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) + + response := httptest.NewRecorder() + apiResponse := &api.Response{TempPath: tempPath} + preparer := &DefaultPreparer{} + opts, _, err := preparer.Prepare(apiResponse) + require.NoError(t, err) + + HandleFileUploads(response, httpRequest, nilHandler, apiResponse, &SavedFileTracker{Request: httpRequest}, opts) + require.Equal(t, testCase.code, response.Code) + } +} + +func TestUploadHandlerRemovingExif(t *testing.T) { + tempPath, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + defer os.RemoveAll(tempPath) + + var buffer bytes.Buffer + + content, err := ioutil.ReadFile("exif/testdata/sample_exif.jpg") + require.NoError(t, err) + + writer := multipart.NewWriter(&buffer) + file, err := writer.CreateFormFile("file", "test.jpg") + require.NoError(t, err) + + _, err = file.Write(content) + require.NoError(t, err) + + err = writer.Close() + require.NoError(t, err) + + ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { + err := r.ParseMultipartForm(100000) + require.NoError(t, err) + + size, err := strconv.Atoi(r.FormValue("file.size")) + require.NoError(t, err) + require.True(t, size < len(content), "Expected the file to be smaller after removal of exif") + require.True(t, size > 0, "Expected to receive not empty file") + + w.WriteHeader(200) + fmt.Fprint(w, "RESPONSE") + }) + defer ts.Close() + + httpRequest, err := http.NewRequest("POST", ts.URL+"/url/path", &buffer) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + httpRequest = httpRequest.WithContext(ctx) + httpRequest.ContentLength = int64(buffer.Len()) + httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) + response := httptest.NewRecorder() + + handler := newProxy(ts.URL) + apiResponse := &api.Response{TempPath: tempPath} + preparer := &DefaultPreparer{} + opts, _, err := preparer.Prepare(apiResponse) + require.NoError(t, err) + + HandleFileUploads(response, httpRequest, handler, apiResponse, 
&testFormProcessor{}, opts) + require.Equal(t, 200, response.Code) +} + +func TestUploadHandlerRemovingInvalidExif(t *testing.T) { + tempPath, err := ioutil.TempDir("", "uploads") + require.NoError(t, err) + defer os.RemoveAll(tempPath) + + var buffer bytes.Buffer + + writer := multipart.NewWriter(&buffer) + file, err := writer.CreateFormFile("file", "test.jpg") + require.NoError(t, err) + + fmt.Fprint(file, "this is not valid image data") + err = writer.Close() + require.NoError(t, err) + + ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { + err := r.ParseMultipartForm(100000) + require.Error(t, err) + }) + defer ts.Close() + + httpRequest, err := http.NewRequest("POST", ts.URL+"/url/path", &buffer) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + httpRequest = httpRequest.WithContext(ctx) + httpRequest.ContentLength = int64(buffer.Len()) + httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) + response := httptest.NewRecorder() + + handler := newProxy(ts.URL) + apiResponse := &api.Response{TempPath: tempPath} + preparer := &DefaultPreparer{} + opts, _, err := preparer.Prepare(apiResponse) + require.NoError(t, err) + + HandleFileUploads(response, httpRequest, handler, apiResponse, &testFormProcessor{}, opts) + require.Equal(t, 422, response.Code) +} + +func newProxy(url string) *proxy.Proxy { + parsedURL := helper.URLMustParse(url) + return proxy.NewProxy(parsedURL, "123", roundtripper.NewTestBackendRoundTripper(parsedURL)) +} + +func waitUntilDeleted(t *testing.T, path string) { + var err error + + // Poll because the file removal is async + for i := 0; i < 100; i++ { + _, err = os.Stat(path) + if err != nil { + break + } + time.Sleep(100 * time.Millisecond) + } + + require.True(t, os.IsNotExist(err), "expected the file to be deleted") +} diff --git a/workhorse/internal/upstream/development_test.go b/workhorse/internal/upstream/development_test.go new file mode 100644 index 00000000000..d2957abb18b --- /dev/null +++ b/workhorse/internal/upstream/development_test.go @@ -0,0 +1,39 @@ +package upstream + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDevelopmentModeEnabled(t *testing.T) { + developmentMode := true + + r, _ := http.NewRequest("GET", "/something", nil) + w := httptest.NewRecorder() + + executed := false + NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + executed = true + })).ServeHTTP(w, r) + + require.True(t, executed, "The handler should get executed") +} + +func TestDevelopmentModeDisabled(t *testing.T) { + developmentMode := false + + r, _ := http.NewRequest("GET", "/something", nil) + w := httptest.NewRecorder() + + executed := false + NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + executed = true + })).ServeHTTP(w, r) + + require.False(t, executed, "The handler should not get executed") + + require.Equal(t, 404, w.Code) +} diff --git a/workhorse/internal/upstream/handlers.go b/workhorse/internal/upstream/handlers.go new file mode 100644 index 00000000000..a6aa148a4ae --- /dev/null +++ b/workhorse/internal/upstream/handlers.go @@ -0,0 +1,39 @@ +package upstream + +import ( + "compress/gzip" + "fmt" + "io" + "net/http" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +func contentEncodingHandler(h http.Handler) http.Handler { + return 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body io.ReadCloser + var err error + + // The client request body may have been gzipped. + contentEncoding := r.Header.Get("Content-Encoding") + switch contentEncoding { + case "": + body = r.Body + case "gzip": + body, err = gzip.NewReader(r.Body) + default: + err = fmt.Errorf("unsupported content encoding: %s", contentEncoding) + } + + if err != nil { + helper.Fail500(w, r, fmt.Errorf("contentEncodingHandler: %v", err)) + return + } + defer body.Close() + + r.Body = body + r.Header.Del("Content-Encoding") + + h.ServeHTTP(w, r) + }) +} diff --git a/workhorse/internal/upstream/handlers_test.go b/workhorse/internal/upstream/handlers_test.go new file mode 100644 index 00000000000..10c7479f5c5 --- /dev/null +++ b/workhorse/internal/upstream/handlers_test.go @@ -0,0 +1,67 @@ +package upstream + +import ( + "bytes" + "compress/gzip" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGzipEncoding(t *testing.T) { + resp := httptest.NewRecorder() + + var b bytes.Buffer + w := gzip.NewWriter(&b) + fmt.Fprint(w, "test") + w.Close() + + body := ioutil.NopCloser(&b) + + req, err := http.NewRequest("POST", "http://address/test", body) + require.NoError(t, err) + req.Header.Set("Content-Encoding", "gzip") + + contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + require.IsType(t, &gzip.Reader{}, r.Body, "body type") + require.Empty(t, r.Header.Get("Content-Encoding"), "Content-Encoding should be deleted") + })).ServeHTTP(resp, req) + + require.Equal(t, 200, resp.Code) +} + +func TestNoEncoding(t *testing.T) { + resp := httptest.NewRecorder() + + var b bytes.Buffer + body := ioutil.NopCloser(&b) + + req, err := http.NewRequest("POST", "http://address/test", body) + require.NoError(t, err) + req.Header.Set("Content-Encoding", "") + + contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + require.Equal(t, body, r.Body, "Expected the same body") + require.Empty(t, r.Header.Get("Content-Encoding"), "Content-Encoding should be deleted") + })).ServeHTTP(resp, req) + + require.Equal(t, 200, resp.Code) +} + +func TestInvalidEncoding(t *testing.T) { + resp := httptest.NewRecorder() + + req, err := http.NewRequest("POST", "http://address/test", nil) + require.NoError(t, err) + req.Header.Set("Content-Encoding", "application/unknown") + + contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Fatal("it shouldn't be executed") + })).ServeHTTP(resp, req) + + require.Equal(t, 500, resp.Code) +} diff --git a/workhorse/internal/upstream/metrics.go b/workhorse/internal/upstream/metrics.go new file mode 100644 index 00000000000..38528056d43 --- /dev/null +++ b/workhorse/internal/upstream/metrics.go @@ -0,0 +1,117 @@ +package upstream + +import ( + "net/http" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +const ( + namespace = "gitlab_workhorse" + httpSubsystem = "http" +) + +func secondsDurationBuckets() []float64 { + return []float64{ + 0.005, /* 5ms */ + 0.025, /* 25ms */ + 0.1, /* 100ms */ + 0.5, /* 500ms */ + 1.0, /* 1s */ + 10.0, /* 10s */ + 30.0, /* 30s */ + 60.0, /* 1m */ + 300.0, /* 10m */ + } +} + +func byteSizeBuckets() []float64 { + return []float64{ + 10, + 64, + 256, + 1024, /* 1kB */ + 64 * 1024, /* 64kB */ + 256 * 1024, /* 256kB */ 
+ 1024 * 1024, /* 1mB */ + 64 * 1024 * 1024, /* 64mB */ + } +} + +var ( + httpInFlightRequests = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: namespace, + Subsystem: httpSubsystem, + Name: "in_flight_requests", + Help: "A gauge of requests currently being served by workhorse.", + }) + + httpRequestsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: httpSubsystem, + Name: "requests_total", + Help: "A counter for requests to workhorse.", + }, + []string{"code", "method", "route"}, + ) + + httpRequestDurationSeconds = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: httpSubsystem, + Name: "request_duration_seconds", + Help: "A histogram of latencies for requests to workhorse.", + Buckets: secondsDurationBuckets(), + }, + []string{"code", "method", "route"}, + ) + + httpRequestSizeBytes = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: httpSubsystem, + Name: "request_size_bytes", + Help: "A histogram of sizes of requests to workhorse.", + Buckets: byteSizeBuckets(), + }, + []string{"code", "method", "route"}, + ) + + httpResponseSizeBytes = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: httpSubsystem, + Name: "response_size_bytes", + Help: "A histogram of response sizes for requests to workhorse.", + Buckets: byteSizeBuckets(), + }, + []string{"code", "method", "route"}, + ) + + httpTimeToWriteHeaderSeconds = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: httpSubsystem, + Name: "time_to_write_header_seconds", + Help: "A histogram of request durations until the response headers are written.", + Buckets: secondsDurationBuckets(), + }, + []string{"code", "method", "route"}, + ) +) + +func instrumentRoute(next http.Handler, method string, regexpStr string) http.Handler { + handler := next + + handler = promhttp.InstrumentHandlerCounter(httpRequestsTotal.MustCurryWith(map[string]string{"route": regexpStr}), handler) + handler = promhttp.InstrumentHandlerDuration(httpRequestDurationSeconds.MustCurryWith(map[string]string{"route": regexpStr}), handler) + handler = promhttp.InstrumentHandlerInFlight(httpInFlightRequests, handler) + handler = promhttp.InstrumentHandlerRequestSize(httpRequestSizeBytes.MustCurryWith(map[string]string{"route": regexpStr}), handler) + handler = promhttp.InstrumentHandlerResponseSize(httpResponseSizeBytes.MustCurryWith(map[string]string{"route": regexpStr}), handler) + handler = promhttp.InstrumentHandlerTimeToWriteHeader(httpTimeToWriteHeaderSeconds.MustCurryWith(map[string]string{"route": regexpStr}), handler) + + return handler +} diff --git a/workhorse/internal/upstream/notfoundunless.go b/workhorse/internal/upstream/notfoundunless.go new file mode 100644 index 00000000000..3bbe3e873a4 --- /dev/null +++ b/workhorse/internal/upstream/notfoundunless.go @@ -0,0 +1,11 @@ +package upstream + +import "net/http" + +func NotFoundUnless(pass bool, handler http.Handler) http.Handler { + if pass { + return handler + } + + return http.HandlerFunc(http.NotFound) +} diff --git a/workhorse/internal/upstream/roundtripper/roundtripper.go b/workhorse/internal/upstream/roundtripper/roundtripper.go new file mode 100644 index 00000000000..84f1983b471 --- /dev/null +++ b/workhorse/internal/upstream/roundtripper/roundtripper.go @@ -0,0 +1,61 @@ +package roundtripper + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "time" + + 
"gitlab.com/gitlab-org/labkit/correlation" + "gitlab.com/gitlab-org/labkit/tracing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway" +) + +func mustParseAddress(address, scheme string) string { + if scheme == "https" { + panic("TLS is not supported for backend connections") + } + + for _, suffix := range []string{"", ":" + scheme} { + address += suffix + if host, port, err := net.SplitHostPort(address); err == nil && host != "" && port != "" { + return host + ":" + port + } + } + + panic(fmt.Errorf("could not parse host:port from address %q and scheme %q", address, scheme)) +} + +// NewBackendRoundTripper returns a new RoundTripper instance using the provided values +func NewBackendRoundTripper(backend *url.URL, socket string, proxyHeadersTimeout time.Duration, developmentMode bool) http.RoundTripper { + // Copied from the definition of http.DefaultTransport. We can't literally copy http.DefaultTransport because of its hidden internal state. + transport, dialer := newBackendTransport() + transport.ResponseHeaderTimeout = proxyHeadersTimeout + + if backend != nil && socket == "" { + address := mustParseAddress(backend.Host, backend.Scheme) + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, "tcp", address) + } + } else if socket != "" { + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, "unix", socket) + } + } else { + panic("backend is nil and socket is empty") + } + + return tracing.NewRoundTripper( + correlation.NewInstrumentedRoundTripper( + badgateway.NewRoundTripper(developmentMode, transport), + ), + ) +} + +// NewTestBackendRoundTripper sets up a RoundTripper for testing purposes +func NewTestBackendRoundTripper(backend *url.URL) http.RoundTripper { + return NewBackendRoundTripper(backend, "", 0, true) +} diff --git a/workhorse/internal/upstream/roundtripper/roundtripper_test.go b/workhorse/internal/upstream/roundtripper/roundtripper_test.go new file mode 100644 index 00000000000..79ffa244918 --- /dev/null +++ b/workhorse/internal/upstream/roundtripper/roundtripper_test.go @@ -0,0 +1,39 @@ +package roundtripper + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMustParseAddress(t *testing.T) { + successExamples := []struct{ address, scheme, expected string }{ + {"1.2.3.4:56", "http", "1.2.3.4:56"}, + {"[::1]:23", "http", "::1:23"}, + {"4.5.6.7", "http", "4.5.6.7:http"}, + } + for i, example := range successExamples { + t.Run(strconv.Itoa(i), func(t *testing.T) { + require.Equal(t, example.expected, mustParseAddress(example.address, example.scheme)) + }) + } +} + +func TestMustParseAddressPanic(t *testing.T) { + panicExamples := []struct{ address, scheme string }{ + {"1.2.3.4", ""}, + {"1.2.3.4", "https"}, + } + + for i, panicExample := range panicExamples { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic") + } + }() + mustParseAddress(panicExample.address, panicExample.scheme) + }) + } +} diff --git a/workhorse/internal/upstream/roundtripper/transport.go b/workhorse/internal/upstream/roundtripper/transport.go new file mode 100644 index 00000000000..84d9623b129 --- /dev/null +++ b/workhorse/internal/upstream/roundtripper/transport.go @@ -0,0 +1,27 @@ +package roundtripper + +import ( + "net" + "net/http" + "time" +) + +// newBackendTransport setups the default HTTP transport which Workhorse uses +// to 
communicate with the upstream +func newBackendTransport() (*http.Transport, *net.Dialer) { + dialler := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialler.DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + return transport, dialler +} diff --git a/workhorse/internal/upstream/routes.go b/workhorse/internal/upstream/routes.go new file mode 100644 index 00000000000..5bbd245719b --- /dev/null +++ b/workhorse/internal/upstream/routes.go @@ -0,0 +1,345 @@ +package upstream + +import ( + "net/http" + "net/url" + "path" + "regexp" + + "github.com/gorilla/websocket" + + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/tracing" + + apipkg "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/artifacts" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/builds" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/channel" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/git" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/imageresizer" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/lfs" + proxypkg "gitlab.com/gitlab-org/gitlab-workhorse/internal/proxy" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/queueing" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/redis" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/sendfile" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/sendurl" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/staticpages" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload" +) + +type matcherFunc func(*http.Request) bool + +type routeEntry struct { + method string + regex *regexp.Regexp + handler http.Handler + matchers []matcherFunc +} + +type routeOptions struct { + tracing bool + matchers []matcherFunc +} + +type uploadPreparers struct { + artifacts upload.Preparer + lfs upload.Preparer + packages upload.Preparer + uploads upload.Preparer +} + +const ( + apiPattern = `^/api/` + ciAPIPattern = `^/ci/api/` + gitProjectPattern = `^/([^/]+/){1,}[^/]+\.git/` + projectPattern = `^/([^/]+/){1,}[^/]+/` + snippetUploadPattern = `^/uploads/personal_snippet` + userUploadPattern = `^/uploads/user` + importPattern = `^/import/` +) + +func compileRegexp(regexpStr string) *regexp.Regexp { + if len(regexpStr) == 0 { + return nil + } + + return regexp.MustCompile(regexpStr) +} + +func withMatcher(f matcherFunc) func(*routeOptions) { + return func(options *routeOptions) { + options.matchers = append(options.matchers, f) + } +} + +func withoutTracing() func(*routeOptions) { + return func(options *routeOptions) { + options.tracing = false + } +} + +func (u *upstream) observabilityMiddlewares(handler http.Handler, method string, regexpStr string) http.Handler { + handler = log.AccessLogger( + handler, + log.WithAccessLogger(u.accessLogger), + log.WithExtraFields(func(r *http.Request) log.Fields { + return log.Fields{ + "route": regexpStr, // This field matches the `route` label in Prometheus metrics + } + }), + ) + + handler = instrumentRoute(handler, method, regexpStr) // Add prometheus metrics + return handler +} + +func (u *upstream) route(method, 
regexpStr string, handler http.Handler, opts ...func(*routeOptions)) routeEntry { + // Instantiate a route with the defaults + options := routeOptions{ + tracing: true, + } + + for _, f := range opts { + f(&options) + } + + handler = u.observabilityMiddlewares(handler, method, regexpStr) + handler = denyWebsocket(handler) // Disallow websockets + if options.tracing { + // Add distributed tracing + handler = tracing.Handler(handler, tracing.WithRouteIdentifier(regexpStr)) + } + + return routeEntry{ + method: method, + regex: compileRegexp(regexpStr), + handler: handler, + matchers: options.matchers, + } +} + +func (u *upstream) wsRoute(regexpStr string, handler http.Handler, matchers ...matcherFunc) routeEntry { + method := "GET" + handler = u.observabilityMiddlewares(handler, method, regexpStr) + + return routeEntry{ + method: method, + regex: compileRegexp(regexpStr), + handler: handler, + matchers: append(matchers, websocket.IsWebSocketUpgrade), + } +} + +// Creates matcherFuncs for a particular content type. +func isContentType(contentType string) func(*http.Request) bool { + return func(r *http.Request) bool { + return helper.IsContentType(contentType, r.Header.Get("Content-Type")) + } +} + +func (ro *routeEntry) isMatch(cleanedPath string, req *http.Request) bool { + if ro.method != "" && req.Method != ro.method { + return false + } + + if ro.regex != nil && !ro.regex.MatchString(cleanedPath) { + return false + } + + ok := true + for _, matcher := range ro.matchers { + ok = matcher(req) + if !ok { + break + } + } + + return ok +} + +func buildProxy(backend *url.URL, version string, rt http.RoundTripper, cfg config.Config) http.Handler { + proxier := proxypkg.NewProxy(backend, version, rt) + + return senddata.SendData( + sendfile.SendFile(apipkg.Block(proxier)), + git.SendArchive, + git.SendBlob, + git.SendDiff, + git.SendPatch, + git.SendSnapshot, + artifacts.SendEntry, + sendurl.SendURL, + imageresizer.NewResizer(cfg), + ) +} + +// Routing table +// We match against URI not containing the relativeUrlRoot: +// see upstream.ServeHTTP + +func (u *upstream) configureRoutes() { + api := apipkg.NewAPI( + u.Backend, + u.Version, + u.RoundTripper, + ) + + static := &staticpages.Static{DocumentRoot: u.DocumentRoot} + proxy := buildProxy(u.Backend, u.Version, u.RoundTripper, u.Config) + cableProxy := proxypkg.NewProxy(u.CableBackend, u.Version, u.CableRoundTripper) + + assetsNotFoundHandler := NotFoundUnless(u.DevelopmentMode, proxy) + if u.AltDocumentRoot != "" { + altStatic := &staticpages.Static{DocumentRoot: u.AltDocumentRoot} + assetsNotFoundHandler = altStatic.ServeExisting( + u.URLPrefix, + staticpages.CacheExpireMax, + NotFoundUnless(u.DevelopmentMode, proxy), + ) + } + + signingTripper := secret.NewRoundTripper(u.RoundTripper, u.Version) + signingProxy := buildProxy(u.Backend, u.Version, signingTripper, u.Config) + + preparers := createUploadPreparers(u.Config) + uploadPath := path.Join(u.DocumentRoot, "uploads/tmp") + uploadAccelerateProxy := upload.Accelerate(&upload.SkipRailsAuthorizer{TempPath: uploadPath}, proxy, preparers.uploads) + ciAPIProxyQueue := queueing.QueueRequests("ci_api_job_requests", uploadAccelerateProxy, u.APILimit, u.APIQueueLimit, u.APIQueueTimeout) + ciAPILongPolling := builds.RegisterHandler(ciAPIProxyQueue, redis.WatchKey, u.APICILongPollingDuration) + + // Serve static files or forward the requests + defaultUpstream := static.ServeExisting( + u.URLPrefix, + staticpages.CacheDisabled, + static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, 
staticpages.ErrorFormatHTML, uploadAccelerateProxy)), + ) + probeUpstream := static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatJSON, proxy) + healthUpstream := static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatText, proxy) + + u.Routes = []routeEntry{ + // Git Clone + u.route("GET", gitProjectPattern+`info/refs\z`, git.GetInfoRefsHandler(api)), + u.route("POST", gitProjectPattern+`git-upload-pack\z`, contentEncodingHandler(git.UploadPack(api)), withMatcher(isContentType("application/x-git-upload-pack-request"))), + u.route("POST", gitProjectPattern+`git-receive-pack\z`, contentEncodingHandler(git.ReceivePack(api)), withMatcher(isContentType("application/x-git-receive-pack-request"))), + u.route("PUT", gitProjectPattern+`gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`, lfs.PutStore(api, signingProxy, preparers.lfs), withMatcher(isContentType("application/octet-stream"))), + + // CI Artifacts + u.route("POST", apiPattern+`v4/jobs/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, signingProxy, preparers.artifacts))), + u.route("POST", ciAPIPattern+`v1/builds/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, signingProxy, preparers.artifacts))), + + // ActionCable websocket + u.wsRoute(`^/-/cable\z`, cableProxy), + + // Terminal websocket + u.wsRoute(projectPattern+`-/environments/[0-9]+/terminal.ws\z`, channel.Handler(api)), + u.wsRoute(projectPattern+`-/jobs/[0-9]+/terminal.ws\z`, channel.Handler(api)), + + // Proxy Job Services + u.wsRoute(projectPattern+`-/jobs/[0-9]+/proxy.ws\z`, channel.Handler(api)), + + // Long poll and limit capacity given to jobs/request and builds/register.json + u.route("", apiPattern+`v4/jobs/request\z`, ciAPILongPolling), + u.route("", ciAPIPattern+`v1/builds/register.json\z`, ciAPILongPolling), + + // Maven Artifact Repository + u.route("PUT", apiPattern+`v4/projects/[0-9]+/packages/maven/`, upload.BodyUploader(api, signingProxy, preparers.packages)), + + // Conan Artifact Repository + u.route("PUT", apiPattern+`v4/packages/conan/`, upload.BodyUploader(api, signingProxy, preparers.packages)), + u.route("PUT", apiPattern+`v4/projects/[0-9]+/packages/conan/`, upload.BodyUploader(api, signingProxy, preparers.packages)), + + // Generic Packages Repository + u.route("PUT", apiPattern+`v4/projects/[0-9]+/packages/generic/`, upload.BodyUploader(api, signingProxy, preparers.packages)), + + // NuGet Artifact Repository + u.route("PUT", apiPattern+`v4/projects/[0-9]+/packages/nuget/`, upload.Accelerate(api, signingProxy, preparers.packages)), + + // PyPI Artifact Repository + u.route("POST", apiPattern+`v4/projects/[0-9]+/packages/pypi`, upload.Accelerate(api, signingProxy, preparers.packages)), + + // Debian Artifact Repository + u.route("PUT", apiPattern+`v4/projects/[0-9]+/-/packages/debian/incoming/`, upload.BodyUploader(api, signingProxy, preparers.packages)), + + // We are porting API to disk acceleration + // we need to declare each routes until we have fixed all the routes on the rails codebase. 
+ // Overall status can be seen at https://gitlab.com/groups/gitlab-org/-/epics/1802#current-status + u.route("POST", apiPattern+`v4/projects/[0-9]+/wikis/attachments\z`, uploadAccelerateProxy), + u.route("POST", apiPattern+`graphql\z`, uploadAccelerateProxy), + u.route("POST", apiPattern+`v4/groups/import`, upload.Accelerate(api, signingProxy, preparers.uploads)), + u.route("POST", apiPattern+`v4/projects/import`, upload.Accelerate(api, signingProxy, preparers.uploads)), + + // Project Import via UI upload acceleration + u.route("POST", importPattern+`gitlab_project`, upload.Accelerate(api, signingProxy, preparers.uploads)), + // Group Import via UI upload acceleration + u.route("POST", importPattern+`gitlab_group`, upload.Accelerate(api, signingProxy, preparers.uploads)), + + // Metric image upload + u.route("POST", apiPattern+`v4/projects/[0-9]+/issues/[0-9]+/metric_images\z`, upload.Accelerate(api, signingProxy, preparers.uploads)), + + // Requirements Import via UI upload acceleration + u.route("POST", projectPattern+`requirements_management/requirements/import_csv`, upload.Accelerate(api, signingProxy, preparers.uploads)), + + // Explicitly proxy API requests + u.route("", apiPattern, proxy), + u.route("", ciAPIPattern, proxy), + + // Serve assets + u.route( + "", `^/assets/`, + static.ServeExisting( + u.URLPrefix, + staticpages.CacheExpireMax, + assetsNotFoundHandler, + ), + withoutTracing(), // Tracing on assets is very noisy + ), + + // Uploads + u.route("POST", projectPattern+`uploads\z`, upload.Accelerate(api, signingProxy, preparers.uploads)), + u.route("POST", snippetUploadPattern, upload.Accelerate(api, signingProxy, preparers.uploads)), + u.route("POST", userUploadPattern, upload.Accelerate(api, signingProxy, preparers.uploads)), + + // For legacy reasons, user uploads are stored under the document root. + // To prevent anybody who knows/guesses the URL of a user-uploaded file + // from downloading it we make sure requests to /uploads/ do _not_ pass + // through static.ServeExisting. + u.route("", `^/uploads/`, static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatHTML, proxy)), + + // health checks don't intercept errors and go straight to rails + // TODO: We should probably not return a HTML deploy page? + // https://gitlab.com/gitlab-org/gitlab-workhorse/issues/230 + u.route("", "^/-/(readiness|liveness)$", static.DeployPage(probeUpstream)), + u.route("", "^/-/health$", static.DeployPage(healthUpstream)), + + // This route lets us filter out health checks from our metrics. + u.route("", "^/-/", defaultUpstream), + + u.route("", "", defaultUpstream), + } +} + +func createUploadPreparers(cfg config.Config) uploadPreparers { + defaultPreparer := upload.NewObjectStoragePreparer(cfg) + + return uploadPreparers{ + artifacts: defaultPreparer, + lfs: lfs.NewLfsUploadPreparer(cfg, defaultPreparer), + packages: defaultPreparer, + uploads: defaultPreparer, + } +} + +func denyWebsocket(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if websocket.IsWebSocketUpgrade(r) { + helper.HTTPError(w, r, "websocket upgrade not allowed", http.StatusBadRequest) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/workhorse/internal/upstream/upstream.go b/workhorse/internal/upstream/upstream.go new file mode 100644 index 00000000000..fd3f6191a5a --- /dev/null +++ b/workhorse/internal/upstream/upstream.go @@ -0,0 +1,123 @@ +/* +The upstream type implements http.Handler. 
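+It dispatches each incoming request to the first routeEntry whose method, path regexp, and matchers all match; the routing table itself is built in routes.go.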
+ +In this file we handle request routing and interaction with the authBackend. +*/ + +package upstream + +import ( + "fmt" + + "net/http" + "strings" + + "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/labkit/correlation" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream/roundtripper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/urlprefix" +) + +var ( + DefaultBackend = helper.URLMustParse("http://localhost:8080") + requestHeaderBlacklist = []string{ + upload.RewrittenFieldsHeader, + } +) + +type upstream struct { + config.Config + URLPrefix urlprefix.Prefix + Routes []routeEntry + RoundTripper http.RoundTripper + CableRoundTripper http.RoundTripper + accessLogger *logrus.Logger +} + +func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler { + up := upstream{ + Config: cfg, + accessLogger: accessLogger, + } + if up.Backend == nil { + up.Backend = DefaultBackend + } + if up.CableBackend == nil { + up.CableBackend = up.Backend + } + if up.CableSocket == "" { + up.CableSocket = up.Socket + } + up.RoundTripper = roundtripper.NewBackendRoundTripper(up.Backend, up.Socket, up.ProxyHeadersTimeout, cfg.DevelopmentMode) + up.CableRoundTripper = roundtripper.NewBackendRoundTripper(up.CableBackend, up.CableSocket, up.ProxyHeadersTimeout, cfg.DevelopmentMode) + up.configureURLPrefix() + up.configureRoutes() + + var correlationOpts []correlation.InboundHandlerOption + if cfg.PropagateCorrelationID { + correlationOpts = append(correlationOpts, correlation.WithPropagation()) + } + + handler := correlation.InjectCorrelationID(&up, correlationOpts...) + return handler +} + +func (u *upstream) configureURLPrefix() { + relativeURLRoot := u.Backend.Path + if !strings.HasSuffix(relativeURLRoot, "/") { + relativeURLRoot += "/" + } + u.URLPrefix = urlprefix.Prefix(relativeURLRoot) +} + +func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) { + helper.FixRemoteAddr(r) + + helper.DisableResponseBuffering(w) + + // Drop RequestURI == "*" (FIXME: why?) + if r.RequestURI == "*" { + helper.HTTPError(w, r, "Connection upgrade not allowed", http.StatusBadRequest) + return + } + + // Disallow connect + if r.Method == "CONNECT" { + helper.HTTPError(w, r, "CONNECT not allowed", http.StatusBadRequest) + return + } + + // Check URL Root + URIPath := urlprefix.CleanURIPath(r.URL.Path) + prefix := u.URLPrefix + if !prefix.Match(URIPath) { + helper.HTTPError(w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound) + return + } + + // Look for a matching route + var route *routeEntry + for _, ro := range u.Routes { + if ro.isMatch(prefix.Strip(URIPath), r) { + route = &ro + break + } + } + + if route == nil { + // The protocol spec in git/Documentation/technical/http-protocol.txt + // says we must return 403 if no matching service is found. 
+ helper.HTTPError(w, r, "Forbidden", http.StatusForbidden) + return + } + + for _, h := range requestHeaderBlacklist { + r.Header.Del(h) + } + + route.handler.ServeHTTP(w, r) +} diff --git a/workhorse/internal/urlprefix/urlprefix.go b/workhorse/internal/urlprefix/urlprefix.go new file mode 100644 index 00000000000..23eefe70c67 --- /dev/null +++ b/workhorse/internal/urlprefix/urlprefix.go @@ -0,0 +1,35 @@ +package urlprefix + +import ( + "path" + "strings" +) + +type Prefix string + +func (p Prefix) Strip(path string) string { + return CleanURIPath(strings.TrimPrefix(path, string(p))) +} + +func (p Prefix) Match(path string) bool { + pre := string(p) + return strings.HasPrefix(path, pre) || path+"/" == pre +} + +// Borrowed from: net/http/server.go +// Return the canonical path for p, eliminating . and .. elements. +func CleanURIPath(p string) string { + if p == "" { + return "/" + } + if p[0] != '/' { + p = "/" + p + } + np := path.Clean(p) + // path.Clean removes trailing slash except for root; + // put the trailing slash back if necessary. + if p[len(p)-1] == '/' && np != "/" { + np += "/" + } + return np +} diff --git a/workhorse/internal/utils/svg/LICENSE b/workhorse/internal/utils/svg/LICENSE new file mode 100644 index 00000000000..f67807d0070 --- /dev/null +++ b/workhorse/internal/utils/svg/LICENSE @@ -0,0 +1,24 @@ +The MIT License + +Copyright (c) 2016 Tomas Aparicio + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. diff --git a/workhorse/internal/utils/svg/README.md b/workhorse/internal/utils/svg/README.md new file mode 100644 index 00000000000..e5531f47473 --- /dev/null +++ b/workhorse/internal/utils/svg/README.md @@ -0,0 +1,45 @@ +# go-is-svg + +Tiny package to verify if a given file buffer is an SVG image in Go (golang). + +## Installation + +```bash +go get -u github.com/h2non/go-is-svg +``` + +## Example + +```go +package main + +import ( + "fmt" + "io/ioutil" + + svg "github.com/h2non/go-is-svg" +) + +func main() { + buf, err := ioutil.ReadFile("_example/example.svg") + if err != nil { + fmt.Printf("Error: %s\n", err) + return + } + + if svg.Is(buf) { + fmt.Println("File is an SVG") + } else { + fmt.Println("File is NOT an SVG") + } +} +``` + +Run example: +```bash +go run _example/example.go +``` + +## License + +MIT - Tomas Aparicio diff --git a/workhorse/internal/utils/svg/svg.go b/workhorse/internal/utils/svg/svg.go new file mode 100644 index 00000000000..b209cb5bf33 --- /dev/null +++ b/workhorse/internal/utils/svg/svg.go @@ -0,0 +1,42 @@ +// Copyright (c) 2016 Tomas Aparicio. 
All rights reserved. +// +// Use of this source code is governed by a MIT License +// license that can be found in the LICENSE file or at +// https://github.com/h2non/go-is-svg/blob/master/LICENSE. + +package svg + +import ( + "regexp" + "unicode/utf8" +) + +var ( + htmlCommentRegex = regexp.MustCompile(`(?i)<!--([\s\S]*?)-->`) + svgRegex = regexp.MustCompile(`(?i)^\s*(?:<\?xml[^>]*>\s*)?(?:<!doctype svg[^>]*>\s*)?<svg[^>]*>`) +) + +// isBinary checks if the given buffer is a binary file. +func isBinary(buf []byte) bool { + if len(buf) < 24 { + return false + } + for i := 0; i < 24; i++ { + charCode, _ := utf8.DecodeRuneInString(string(buf[i])) + if charCode == 65533 || charCode <= 8 { + return true + } + } + return false +} + +// Is returns true if the given buffer is a valid SVG image. +func Is(buf []byte) bool { + return !isBinary(buf) && svgRegex.Match(htmlCommentRegex.ReplaceAll(buf, []byte{})) +} + +// IsSVG returns true if the given buffer is a valid SVG image. +// Alias to: Is() +func IsSVG(buf []byte) bool { + return Is(buf) +} diff --git a/workhorse/internal/zipartifacts/.gitignore b/workhorse/internal/zipartifacts/.gitignore new file mode 100644 index 00000000000..ace1063ab02 --- /dev/null +++ b/workhorse/internal/zipartifacts/.gitignore @@ -0,0 +1 @@ +/testdata diff --git a/workhorse/internal/zipartifacts/entry.go b/workhorse/internal/zipartifacts/entry.go new file mode 100644 index 00000000000..527387ceaa1 --- /dev/null +++ b/workhorse/internal/zipartifacts/entry.go @@ -0,0 +1,13 @@ +package zipartifacts + +import ( + "encoding/base64" +) + +func DecodeFileEntry(entry string) (string, error) { + decoded, err := base64.StdEncoding.DecodeString(entry) + if err != nil { + return "", err + } + return string(decoded), nil +} diff --git a/workhorse/internal/zipartifacts/errors.go b/workhorse/internal/zipartifacts/errors.go new file mode 100644 index 00000000000..162816618f8 --- /dev/null +++ b/workhorse/internal/zipartifacts/errors.go @@ -0,0 +1,57 @@ +package zipartifacts + +import ( + "errors" +) + +// These are exit codes used by subprocesses in cmd/gitlab-zip-xxx. We also use +// them to map errors and error messages that we use as labels in Prometheus. +const ( + CodeNotZip = 10 + iota + CodeEntryNotFound + CodeArchiveNotFound + CodeLimitsReached + CodeUnknownError +) + +var ( + ErrorCode = map[int]error{ + CodeNotZip: errors.New("zip archive format invalid"), + CodeEntryNotFound: errors.New("zip entry not found"), + CodeArchiveNotFound: errors.New("zip archive not found"), + CodeLimitsReached: errors.New("zip processing limits reached"), + CodeUnknownError: errors.New("zip processing unknown error"), + } + + ErrorLabel = map[int]string{ + CodeNotZip: "archive_invalid", + CodeEntryNotFound: "entry_not_found", + CodeArchiveNotFound: "archive_not_found", + CodeLimitsReached: "limits_reached", + CodeUnknownError: "unknown_error", + } + + ErrBadMetadata = errors.New("zip artifacts metadata invalid") +) + +// ExitCodeByError finds an os.Exit code for a corresponding error. +// It returns CodeUnknownError if no matching error is found. +func ExitCodeByError(err error) int { + for c, e := range ErrorCode { + if err == e { + return c + } + } + + return CodeUnknownError +} + +// ErrorLabelByCode returns a Prometheus counter label associated with an exit code. 
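+// For example, ErrorLabelByCode(CodeNotZip) returns "archive_invalid".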
+func ErrorLabelByCode(code int) string { + label, ok := ErrorLabel[code] + if ok { + return label + } + + return ErrorLabel[CodeUnknownError] +} diff --git a/workhorse/internal/zipartifacts/errors_test.go b/workhorse/internal/zipartifacts/errors_test.go new file mode 100644 index 00000000000..6fce160b3bc --- /dev/null +++ b/workhorse/internal/zipartifacts/errors_test.go @@ -0,0 +1,32 @@ +package zipartifacts + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestExitCodeByError(t *testing.T) { + t.Run("when error has been recognized", func(t *testing.T) { + code := ExitCodeByError(ErrorCode[CodeLimitsReached]) + + require.Equal(t, code, CodeLimitsReached) + require.Greater(t, code, 10) + }) + + t.Run("when error is an unknown one", func(t *testing.T) { + code := ExitCodeByError(errors.New("unknown error")) + + require.Equal(t, code, CodeUnknownError) + require.Greater(t, code, 10) + }) +} + +func TestErrorLabels(t *testing.T) { + for code := range ErrorCode { + _, ok := ErrorLabel[code] + + require.True(t, ok) + } +} diff --git a/workhorse/internal/zipartifacts/metadata.go b/workhorse/internal/zipartifacts/metadata.go new file mode 100644 index 00000000000..1ecf52deafb --- /dev/null +++ b/workhorse/internal/zipartifacts/metadata.go @@ -0,0 +1,117 @@ +package zipartifacts + +import ( + "archive/zip" + "compress/gzip" + "encoding/binary" + "encoding/json" + "io" + "path" + "sort" + "strconv" +) + +type metadata struct { + Modified int64 `json:"modified,omitempty"` + Mode string `json:"mode,omitempty"` + CRC uint32 `json:"crc,omitempty"` + Size uint64 `json:"size,omitempty"` + Zipped uint64 `json:"zipped,omitempty"` + Comment string `json:"comment,omitempty"` +} + +const MetadataHeaderPrefix = "\x00\x00\x00&" // length of string below, encoded properly +const MetadataHeader = "GitLab Build Artifacts Metadata 0.0.2\n" + +func newMetadata(file *zip.File) metadata { + if file == nil { + return metadata{} + } + + return metadata{ + //lint:ignore SA1019 Remove this once the minimum supported version is go 1.10 (go 1.9 and down do not support an alternative) + Modified: file.ModTime().Unix(), + Mode: strconv.FormatUint(uint64(file.Mode().Perm()), 8), + CRC: file.CRC32, + Size: file.UncompressedSize64, + Zipped: file.CompressedSize64, + Comment: file.Comment, + } +} + +func (m metadata) writeEncoded(output io.Writer) error { + j, err := json.Marshal(m) + if err != nil { + return err + } + j = append(j, byte('\n')) + return writeBytes(output, j) +} + +func writeZipEntryMetadata(output io.Writer, path string, entry *zip.File) error { + if err := writeString(output, path); err != nil { + return err + } + + if err := newMetadata(entry).writeEncoded(output); err != nil { + return err + } + + return nil +} + +func GenerateZipMetadata(w io.Writer, archive *zip.Reader) error { + output := gzip.NewWriter(w) + defer output.Close() + + if err := writeString(output, MetadataHeader); err != nil { + return err + } + + // Write empty error header that we may need in the future + if err := writeString(output, "{}"); err != nil { + return err + } + + // Create map of files in zip archive + zipMap := make(map[string]*zip.File, len(archive.File)) + + // Add missing entries + for _, entry := range archive.File { + zipMap[entry.Name] = entry + + for d := path.Dir(entry.Name); d != "." 
&& d != "/"; d = path.Dir(d) { + entryDir := d + "/" + if _, ok := zipMap[entryDir]; !ok { + zipMap[entryDir] = nil + } + } + } + + // Sort paths + sortedPaths := make([]string, 0, len(zipMap)) + for path := range zipMap { + sortedPaths = append(sortedPaths, path) + } + sort.Strings(sortedPaths) + + // Write all files + for _, path := range sortedPaths { + if err := writeZipEntryMetadata(output, path, zipMap[path]); err != nil { + return err + } + } + return nil +} + +func writeBytes(output io.Writer, data []byte) error { + err := binary.Write(output, binary.BigEndian, uint32(len(data))) + if err == nil { + _, err = output.Write(data) + } + return err +} + +func writeString(output io.Writer, str string) error { + return writeBytes(output, []byte(str)) +} diff --git a/workhorse/internal/zipartifacts/metadata_test.go b/workhorse/internal/zipartifacts/metadata_test.go new file mode 100644 index 00000000000..0f130ab4c15 --- /dev/null +++ b/workhorse/internal/zipartifacts/metadata_test.go @@ -0,0 +1,102 @@ +package zipartifacts_test + +import ( + "archive/zip" + "bytes" + "compress/gzip" + "context" + "fmt" + "io" + "io/ioutil" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts" +) + +func generateTestArchive(w io.Writer) error { + archive := zip.NewWriter(w) + + // non-POSIX paths are here just to test if we never enter infinite loop + files := []string{"file1", "some/file/dir/", "some/file/dir/file2", "../../test12/test", + "/usr/bin/test", `c:\windows\win32.exe`, `c:/windows/win.dll`, "./f/asd", "/"} + + for _, file := range files { + archiveFile, err := archive.Create(file) + if err != nil { + return err + } + + fmt.Fprint(archiveFile, file) + } + + return archive.Close() +} + +func validateMetadata(r io.Reader) error { + gz, err := gzip.NewReader(r) + if err != nil { + return err + } + + meta, err := ioutil.ReadAll(gz) + if err != nil { + return err + } + + paths := []string{"file1", "some/", "some/file/", "some/file/dir/", "some/file/dir/file2"} + for _, path := range paths { + if !bytes.Contains(meta, []byte(path+"\x00")) { + return fmt.Errorf(fmt.Sprintf("zipartifacts: metadata for path %q not found", path)) + } + } + + return nil +} + +func TestGenerateZipMetadataFromFile(t *testing.T) { + var metaBuffer bytes.Buffer + + f, err := ioutil.TempFile("", "workhorse-metadata.zip-") + if f != nil { + defer os.Remove(f.Name()) + } + require.NoError(t, err) + defer f.Close() + + err = generateTestArchive(f) + require.NoError(t, err) + f.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + archive, err := zipartifacts.OpenArchive(ctx, f.Name()) + require.NoError(t, err, "zipartifacts: OpenArchive failed") + + err = zipartifacts.GenerateZipMetadata(&metaBuffer, archive) + require.NoError(t, err, "zipartifacts: GenerateZipMetadata failed") + + err = validateMetadata(&metaBuffer) + require.NoError(t, err) +} + +func TestErrNotAZip(t *testing.T) { + f, err := ioutil.TempFile("", "workhorse-metadata.zip-") + if f != nil { + defer os.Remove(f.Name()) + } + require.NoError(t, err) + defer f.Close() + + _, err = fmt.Fprint(f, "Not a zip file") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err = zipartifacts.OpenArchive(ctx, f.Name()) + require.Equal(t, zipartifacts.ErrorCode[zipartifacts.CodeNotZip], err, "OpenArchive requires a zip file") +} diff --git a/workhorse/internal/zipartifacts/open_archive.go 
b/workhorse/internal/zipartifacts/open_archive.go new file mode 100644 index 00000000000..30b86b66c49 --- /dev/null +++ b/workhorse/internal/zipartifacts/open_archive.go @@ -0,0 +1,138 @@ +package zipartifacts + +import ( + "archive/zip" + "context" + "fmt" + "io" + "net" + "net/http" + "os" + "strings" + "time" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/httprs" + + "gitlab.com/gitlab-org/labkit/correlation" + "gitlab.com/gitlab-org/labkit/mask" + "gitlab.com/gitlab-org/labkit/tracing" +) + +var httpClient = &http.Client{ + Transport: tracing.NewRoundTripper(correlation.NewInstrumentedRoundTripper(&http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 10 * time.Second, + }).DialContext, + IdleConnTimeout: 30 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 10 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, + DisableCompression: true, + })), +} + +type archive struct { + reader io.ReaderAt + size int64 +} + +// OpenArchive opens a zip.Reader from a local path or a remote object store URL; +// for a remote URL it makes use of ranged requests to support seeking. +// If the path does not exist, the error is ErrorCode[CodeArchiveNotFound]; +// if the file isn't a zip archive, the error is ErrorCode[CodeNotZip]. +func OpenArchive(ctx context.Context, archivePath string) (*zip.Reader, error) { + archive, err := openArchiveLocation(ctx, archivePath) + if err != nil { + return nil, err + } + + return openZipReader(archive.reader, archive.size) +} + +// OpenArchiveWithReaderFunc opens a zip.Reader from either a local path or a +// remote object, similarly to OpenArchive. The difference is that it +// allows passing a readerFunc that takes an io.ReaderAt that is either going to +// be an *os.File or a custom reader we use to read from object storage. The +// readerFunc can augment the archive reader and return a type that satisfies +// io.ReaderAt. 
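+// For example, a pass-through readerFunc such as func(r io.ReaderAt, _ int64) io.ReaderAt { return r } leaves the archive reader unchanged.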
+func OpenArchiveWithReaderFunc(ctx context.Context, location string, readerFunc func(io.ReaderAt, int64) io.ReaderAt) (*zip.Reader, error) { + archive, err := openArchiveLocation(ctx, location) + if err != nil { + return nil, err + } + + return openZipReader(readerFunc(archive.reader, archive.size), archive.size) +} + +func openArchiveLocation(ctx context.Context, location string) (*archive, error) { + if isURL(location) { + return openHTTPArchive(ctx, location) + } + + return openFileArchive(ctx, location) +} + +func isURL(path string) bool { + return strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") +} + +func openHTTPArchive(ctx context.Context, archivePath string) (*archive, error) { + scrubbedArchivePath := mask.URL(archivePath) + req, err := http.NewRequest(http.MethodGet, archivePath, nil) + if err != nil { + return nil, fmt.Errorf("can't create HTTP GET %q: %v", scrubbedArchivePath, err) + } + req = req.WithContext(ctx) + + resp, err := httpClient.Do(req.WithContext(ctx)) + if err != nil { + return nil, fmt.Errorf("HTTP GET %q: %v", scrubbedArchivePath, err) + } else if resp.StatusCode == http.StatusNotFound { + return nil, ErrorCode[CodeArchiveNotFound] + } else if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP GET %q: %d: %v", scrubbedArchivePath, resp.StatusCode, resp.Status) + } + + rs := httprs.NewHttpReadSeeker(resp, httpClient) + + go func() { + <-ctx.Done() + resp.Body.Close() + rs.Close() + }() + + return &archive{reader: rs, size: resp.ContentLength}, nil +} + +func openFileArchive(ctx context.Context, archivePath string) (*archive, error) { + file, err := os.Open(archivePath) + if err != nil { + if os.IsNotExist(err) { + return nil, ErrorCode[CodeArchiveNotFound] + } + } + + go func() { + <-ctx.Done() + // We close the archive from this goroutine so that we can safely return a *zip.Reader instead of a *zip.ReadCloser + file.Close() + }() + + stat, err := file.Stat() + if err != nil { + return nil, err + } + + return &archive{reader: file, size: stat.Size()}, nil +} + +func openZipReader(archive io.ReaderAt, size int64) (*zip.Reader, error) { + reader, err := zip.NewReader(archive, size) + if err != nil { + return nil, ErrorCode[CodeNotZip] + } + + return reader, nil +} diff --git a/workhorse/internal/zipartifacts/open_archive_test.go b/workhorse/internal/zipartifacts/open_archive_test.go new file mode 100644 index 00000000000..f7624d053d9 --- /dev/null +++ b/workhorse/internal/zipartifacts/open_archive_test.go @@ -0,0 +1,68 @@ +package zipartifacts + +import ( + "archive/zip" + "context" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOpenHTTPArchive(t *testing.T) { + const ( + zipFile = "test.zip" + entryName = "hello.txt" + contents = "world" + testRoot = "testdata/public" + ) + + require.NoError(t, os.MkdirAll(testRoot, 0755)) + f, err := os.Create(filepath.Join(testRoot, zipFile)) + require.NoError(t, err, "create file") + defer f.Close() + + zw := zip.NewWriter(f) + w, err := zw.Create(entryName) + require.NoError(t, err, "create zip entry") + _, err = fmt.Fprint(w, contents) + require.NoError(t, err, "write zip entry contents") + require.NoError(t, zw.Close(), "close zip writer") + require.NoError(t, f.Close(), "close file") + + srv := httptest.NewServer(http.FileServer(http.Dir(testRoot))) + defer srv.Close() + + zr, err := OpenArchive(context.Background(), srv.URL+"/"+zipFile) + require.NoError(t, err, "call 
OpenArchive") + require.Len(t, zr.File, 1) + + zf := zr.File[0] + require.Equal(t, entryName, zf.Name, "zip entry name") + + entry, err := zf.Open() + require.NoError(t, err, "get zip entry reader") + defer entry.Close() + + actualContents, err := ioutil.ReadAll(entry) + require.NoError(t, err, "read zip entry contents") + require.Equal(t, contents, string(actualContents), "compare zip entry contents") +} + +func TestOpenHTTPArchiveNotSendingAcceptEncodingHeader(t *testing.T) { + requestHandler := func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "GET", r.Method) + require.Nil(t, r.Header["Accept-Encoding"]) + w.WriteHeader(http.StatusOK) + } + + srv := httptest.NewServer(http.HandlerFunc(requestHandler)) + defer srv.Close() + + OpenArchive(context.Background(), srv.URL) +} |