summaryrefslogtreecommitdiff
path: root/workhorse/internal
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/internal')
-rw-r--r--workhorse/internal/api/api.go345
-rw-r--r--workhorse/internal/api/block.go61
-rw-r--r--workhorse/internal/api/block_test.go56
-rw-r--r--workhorse/internal/api/channel_settings.go122
-rw-r--r--workhorse/internal/api/channel_settings_test.go154
-rw-r--r--workhorse/internal/artifacts/artifacts_store_test.go338
-rw-r--r--workhorse/internal/artifacts/artifacts_test.go19
-rw-r--r--workhorse/internal/artifacts/artifacts_upload.go167
-rw-r--r--workhorse/internal/artifacts/artifacts_upload_test.go322
-rw-r--r--workhorse/internal/artifacts/entry.go123
-rw-r--r--workhorse/internal/artifacts/entry_test.go134
-rw-r--r--workhorse/internal/artifacts/escape_quotes.go10
-rw-r--r--workhorse/internal/badgateway/roundtripper.go115
-rw-r--r--workhorse/internal/badgateway/roundtripper_test.go56
-rw-r--r--workhorse/internal/builds/register.go163
-rw-r--r--workhorse/internal/builds/register_test.go108
-rw-r--r--workhorse/internal/channel/auth_checker.go69
-rw-r--r--workhorse/internal/channel/auth_checker_test.go53
-rw-r--r--workhorse/internal/channel/channel.go132
-rw-r--r--workhorse/internal/channel/proxy.go56
-rw-r--r--workhorse/internal/channel/wrappers.go134
-rw-r--r--workhorse/internal/channel/wrappers_test.go155
-rw-r--r--workhorse/internal/config/config.go154
-rw-r--r--workhorse/internal/config/config_test.go111
-rw-r--r--workhorse/internal/config/url_openers.go51
-rw-r--r--workhorse/internal/config/url_openers_test.go117
-rw-r--r--workhorse/internal/filestore/file_handler.go257
-rw-r--r--workhorse/internal/filestore/file_handler_test.go551
-rw-r--r--workhorse/internal/filestore/multi_hash.go48
-rw-r--r--workhorse/internal/filestore/reader.go17
-rw-r--r--workhorse/internal/filestore/reader_test.go46
-rw-r--r--workhorse/internal/filestore/save_file_opts.go171
-rw-r--r--workhorse/internal/filestore/save_file_opts_test.go331
-rw-r--r--workhorse/internal/git/archive.go216
-rw-r--r--workhorse/internal/git/archive_test.go87
-rw-r--r--workhorse/internal/git/blob.go47
-rw-r--r--workhorse/internal/git/blob_test.go17
-rw-r--r--workhorse/internal/git/diff.go48
-rw-r--r--workhorse/internal/git/error.go4
-rw-r--r--workhorse/internal/git/format-patch.go48
-rw-r--r--workhorse/internal/git/git-http.go100
-rw-r--r--workhorse/internal/git/info-refs.go76
-rw-r--r--workhorse/internal/git/pktline.go59
-rw-r--r--workhorse/internal/git/pktline_test.go39
-rw-r--r--workhorse/internal/git/receive-pack.go33
-rw-r--r--workhorse/internal/git/responsewriter.go75
-rw-r--r--workhorse/internal/git/snapshot.go64
-rw-r--r--workhorse/internal/git/upload-pack.go57
-rw-r--r--workhorse/internal/git/upload-pack_test.go85
-rw-r--r--workhorse/internal/gitaly/blob.go41
-rw-r--r--workhorse/internal/gitaly/diff.go55
-rw-r--r--workhorse/internal/gitaly/gitaly.go188
-rw-r--r--workhorse/internal/gitaly/gitaly_test.go80
-rw-r--r--workhorse/internal/gitaly/namespace.go8
-rw-r--r--workhorse/internal/gitaly/repository.go45
-rw-r--r--workhorse/internal/gitaly/smarthttp.go139
-rw-r--r--workhorse/internal/gitaly/unmarshal_test.go35
-rw-r--r--workhorse/internal/headers/content_headers.go109
-rw-r--r--workhorse/internal/headers/headers.go62
-rw-r--r--workhorse/internal/headers/headers_test.go24
-rw-r--r--workhorse/internal/helper/context_reader.go40
-rw-r--r--workhorse/internal/helper/context_reader_test.go83
-rw-r--r--workhorse/internal/helper/countingresponsewriter.go56
-rw-r--r--workhorse/internal/helper/countingresponsewriter_test.go50
-rw-r--r--workhorse/internal/helper/helpers.go217
-rw-r--r--workhorse/internal/helper/helpers_test.go258
-rw-r--r--workhorse/internal/helper/raven.go58
-rw-r--r--workhorse/internal/helper/tempfile.go35
-rw-r--r--workhorse/internal/helper/writeafterreader.go144
-rw-r--r--workhorse/internal/helper/writeafterreader_test.go115
-rw-r--r--workhorse/internal/httprs/LICENSE19
-rw-r--r--workhorse/internal/httprs/README.md2
-rw-r--r--workhorse/internal/httprs/httprs.go217
-rw-r--r--workhorse/internal/httprs/httprs_test.go257
-rw-r--r--workhorse/internal/imageresizer/image_resizer.go449
-rw-r--r--workhorse/internal/imageresizer/image_resizer_caching.go44
-rw-r--r--workhorse/internal/imageresizer/image_resizer_test.go259
-rw-r--r--workhorse/internal/lfs/lfs.go55
-rw-r--r--workhorse/internal/lfs/lfs_test.go61
-rw-r--r--workhorse/internal/lsif_transformer/parser/cache.go56
-rw-r--r--workhorse/internal/lsif_transformer/parser/cache_test.go33
-rw-r--r--workhorse/internal/lsif_transformer/parser/code_hover.go124
-rw-r--r--workhorse/internal/lsif_transformer/parser/code_hover_test.go106
-rw-r--r--workhorse/internal/lsif_transformer/parser/docs.go144
-rw-r--r--workhorse/internal/lsif_transformer/parser/docs_test.go54
-rw-r--r--workhorse/internal/lsif_transformer/parser/errors.go30
-rw-r--r--workhorse/internal/lsif_transformer/parser/errors_test.go26
-rw-r--r--workhorse/internal/lsif_transformer/parser/hovers.go162
-rw-r--r--workhorse/internal/lsif_transformer/parser/hovers_test.go30
-rw-r--r--workhorse/internal/lsif_transformer/parser/id.go52
-rw-r--r--workhorse/internal/lsif_transformer/parser/id_test.go28
-rw-r--r--workhorse/internal/lsif_transformer/parser/parser.go109
-rw-r--r--workhorse/internal/lsif_transformer/parser/parser_test.go80
-rw-r--r--workhorse/internal/lsif_transformer/parser/performance_test.go47
-rw-r--r--workhorse/internal/lsif_transformer/parser/ranges.go214
-rw-r--r--workhorse/internal/lsif_transformer/parser/ranges_test.go61
-rw-r--r--workhorse/internal/lsif_transformer/parser/references.go107
-rw-r--r--workhorse/internal/lsif_transformer/parser/references_test.go44
-rw-r--r--workhorse/internal/lsif_transformer/parser/testdata/dump.lsif.zipbin0 -> 2023 bytes
-rw-r--r--workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/main.go.json208
-rw-r--r--workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/morestrings/reverse.go.json249
-rw-r--r--workhorse/internal/lsif_transformer/parser/testdata/workhorse.lsif.zipbin0 -> 2120741 bytes
-rw-r--r--workhorse/internal/objectstore/gocloud_object.go100
-rw-r--r--workhorse/internal/objectstore/gocloud_object_test.go56
-rw-r--r--workhorse/internal/objectstore/multipart.go188
-rw-r--r--workhorse/internal/objectstore/multipart_test.go64
-rw-r--r--workhorse/internal/objectstore/object.go114
-rw-r--r--workhorse/internal/objectstore/object_test.go155
-rw-r--r--workhorse/internal/objectstore/prometheus.go39
-rw-r--r--workhorse/internal/objectstore/s3_complete_multipart_api.go51
-rw-r--r--workhorse/internal/objectstore/s3_object.go119
-rw-r--r--workhorse/internal/objectstore/s3_object_test.go174
-rw-r--r--workhorse/internal/objectstore/s3_session.go94
-rw-r--r--workhorse/internal/objectstore/s3_session_test.go57
-rw-r--r--workhorse/internal/objectstore/test/consts.go19
-rw-r--r--workhorse/internal/objectstore/test/gocloud_stub.go47
-rw-r--r--workhorse/internal/objectstore/test/objectstore_stub.go278
-rw-r--r--workhorse/internal/objectstore/test/objectstore_stub_test.go167
-rw-r--r--workhorse/internal/objectstore/test/s3_stub.go142
-rw-r--r--workhorse/internal/objectstore/upload_strategy.go46
-rw-r--r--workhorse/internal/objectstore/uploader.go115
-rw-r--r--workhorse/internal/proxy/proxy.go62
-rw-r--r--workhorse/internal/queueing/queue.go201
-rw-r--r--workhorse/internal/queueing/queue_test.go62
-rw-r--r--workhorse/internal/queueing/requests.go51
-rw-r--r--workhorse/internal/queueing/requests_test.go76
-rw-r--r--workhorse/internal/redis/keywatcher.go198
-rw-r--r--workhorse/internal/redis/keywatcher_test.go162
-rw-r--r--workhorse/internal/redis/redis.go295
-rw-r--r--workhorse/internal/redis/redis_test.go234
-rw-r--r--workhorse/internal/secret/jwt.go25
-rw-r--r--workhorse/internal/secret/roundtripper.go35
-rw-r--r--workhorse/internal/secret/secret.go77
-rw-r--r--workhorse/internal/senddata/contentprocessor/contentprocessor.go126
-rw-r--r--workhorse/internal/senddata/contentprocessor/contentprocessor_test.go293
-rw-r--r--workhorse/internal/senddata/injecter.go35
-rw-r--r--workhorse/internal/senddata/senddata.go105
-rw-r--r--workhorse/internal/senddata/writer_test.go71
-rw-r--r--workhorse/internal/sendfile/sendfile.go162
-rw-r--r--workhorse/internal/sendfile/sendfile_test.go171
-rw-r--r--workhorse/internal/sendfile/testdata/sent-file.txt1
-rw-r--r--workhorse/internal/sendurl/sendurl.go167
-rw-r--r--workhorse/internal/sendurl/sendurl_test.go197
-rw-r--r--workhorse/internal/staticpages/deploy_page.go26
-rw-r--r--workhorse/internal/staticpages/deploy_page_test.go59
-rw-r--r--workhorse/internal/staticpages/error_pages.go138
-rw-r--r--workhorse/internal/staticpages/error_pages_test.go191
-rw-r--r--workhorse/internal/staticpages/servefile.go84
-rw-r--r--workhorse/internal/staticpages/servefile_test.go134
-rw-r--r--workhorse/internal/staticpages/static.go5
-rw-r--r--workhorse/internal/testhelper/gitaly.go384
-rw-r--r--workhorse/internal/testhelper/testhelper.go152
-rw-r--r--workhorse/internal/upload/accelerate.go32
-rw-r--r--workhorse/internal/upload/body_uploader.go90
-rw-r--r--workhorse/internal/upload/body_uploader_test.go195
-rw-r--r--workhorse/internal/upload/exif/exif.go107
-rw-r--r--workhorse/internal/upload/exif/exif_test.go95
-rw-r--r--workhorse/internal/upload/exif/testdata/sample_exif.jpgbin0 -> 33881 bytes
-rw-r--r--workhorse/internal/upload/object_storage_preparer.go28
-rw-r--r--workhorse/internal/upload/object_storage_preparer_test.go62
-rw-r--r--workhorse/internal/upload/rewrite.go203
-rw-r--r--workhorse/internal/upload/saved_file_tracker.go55
-rw-r--r--workhorse/internal/upload/saved_file_tracker_test.go39
-rw-r--r--workhorse/internal/upload/skip_rails_authorizer.go22
-rw-r--r--workhorse/internal/upload/uploads.go66
-rw-r--r--workhorse/internal/upload/uploads_test.go475
-rw-r--r--workhorse/internal/upstream/development_test.go39
-rw-r--r--workhorse/internal/upstream/handlers.go39
-rw-r--r--workhorse/internal/upstream/handlers_test.go67
-rw-r--r--workhorse/internal/upstream/metrics.go117
-rw-r--r--workhorse/internal/upstream/notfoundunless.go11
-rw-r--r--workhorse/internal/upstream/roundtripper/roundtripper.go61
-rw-r--r--workhorse/internal/upstream/roundtripper/roundtripper_test.go39
-rw-r--r--workhorse/internal/upstream/roundtripper/transport.go27
-rw-r--r--workhorse/internal/upstream/routes.go345
-rw-r--r--workhorse/internal/upstream/upstream.go123
-rw-r--r--workhorse/internal/urlprefix/urlprefix.go35
-rw-r--r--workhorse/internal/utils/svg/LICENSE24
-rw-r--r--workhorse/internal/utils/svg/README.md45
-rw-r--r--workhorse/internal/utils/svg/svg.go42
-rw-r--r--workhorse/internal/zipartifacts/.gitignore1
-rw-r--r--workhorse/internal/zipartifacts/entry.go13
-rw-r--r--workhorse/internal/zipartifacts/errors.go57
-rw-r--r--workhorse/internal/zipartifacts/errors_test.go32
-rw-r--r--workhorse/internal/zipartifacts/metadata.go117
-rw-r--r--workhorse/internal/zipartifacts/metadata_test.go102
-rw-r--r--workhorse/internal/zipartifacts/open_archive.go138
-rw-r--r--workhorse/internal/zipartifacts/open_archive_test.go68
188 files changed, 19985 insertions, 0 deletions
diff --git a/workhorse/internal/api/api.go b/workhorse/internal/api/api.go
new file mode 100644
index 00000000000..17fea398029
--- /dev/null
+++ b/workhorse/internal/api/api.go
@@ -0,0 +1,345 @@
+package api
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
+)
+
+const (
+ // Custom content type for API responses, to catch routing / programming mistakes
+ ResponseContentType = "application/vnd.gitlab-workhorse+json"
+
+ failureResponseLimit = 32768
+)
+
+type API struct {
+ Client *http.Client
+ URL *url.URL
+ Version string
+}
+
+var (
+ requestsCounter = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_internal_api_requests",
+ Help: "How many internal API requests have been completed by gitlab-workhorse, partitioned by status code and HTTP method.",
+ },
+ []string{"code", "method"},
+ )
+ bytesTotal = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_internal_api_failure_response_bytes",
+ Help: "How many bytes have been returned by upstream GitLab in API failure/rejection response bodies.",
+ },
+ )
+)
+
+func NewAPI(myURL *url.URL, version string, roundTripper http.RoundTripper) *API {
+ return &API{
+ Client: &http.Client{Transport: roundTripper},
+ URL: myURL,
+ Version: version,
+ }
+}
+
+type HandleFunc func(http.ResponseWriter, *http.Request, *Response)
+
+type MultipartUploadParams struct {
+ // PartSize is the exact size of each uploaded part. Only the last one can be smaller
+ PartSize int64
+ // PartURLs contains the presigned URLs for each part
+ PartURLs []string
+ // CompleteURL is a presigned URL for CompleteMulipartUpload
+ CompleteURL string
+ // AbortURL is a presigned URL for AbortMultipartUpload
+ AbortURL string
+}
+
+type ObjectStorageParams struct {
+ Provider string
+ S3Config config.S3Config
+ GoCloudConfig config.GoCloudConfig
+}
+
+type RemoteObject struct {
+ // GetURL is an S3 GetObject URL
+ GetURL string
+ // DeleteURL is a presigned S3 RemoveObject URL
+ DeleteURL string
+ // StoreURL is the temporary presigned S3 PutObject URL to which upload the first found file
+ StoreURL string
+ // Boolean to indicate whether to use headers included in PutHeaders
+ CustomPutHeaders bool
+ // PutHeaders are HTTP headers (e.g. Content-Type) to be sent with StoreURL
+ PutHeaders map[string]string
+ // Whether to ignore Rails pre-signed URLs and have Workhorse directly access object storage provider
+ UseWorkhorseClient bool
+ // Remote, temporary object name where Rails will move to the final destination
+ RemoteTempObjectID string
+ // ID is a unique identifier of object storage upload
+ ID string
+ // Timeout is a number that represents timeout in seconds for sending data to StoreURL
+ Timeout int
+ // MultipartUpload contains presigned URLs for S3 MultipartUpload
+ MultipartUpload *MultipartUploadParams
+ // Object storage config for Workhorse client
+ ObjectStorage *ObjectStorageParams
+}
+
+type Response struct {
+ // GL_ID is an environment variable used by gitlab-shell hooks during 'git
+ // push' and 'git pull'
+ GL_ID string
+
+ // GL_USERNAME holds gitlab username of the user who is taking the action causing hooks to be invoked
+ GL_USERNAME string
+
+ // GL_REPOSITORY is an environment variable used by gitlab-shell hooks during
+ // 'git push' and 'git pull'
+ GL_REPOSITORY string
+ // GitConfigOptions holds the custom options that we want to pass to the git command
+ GitConfigOptions []string
+ // StoreLFSPath is provided by the GitLab Rails application to mark where the tmp file should be placed.
+ // This field is deprecated. GitLab will use TempPath instead
+ StoreLFSPath string
+ // LFS object id
+ LfsOid string
+ // LFS object size
+ LfsSize int64
+ // TmpPath is the path where we should store temporary files
+ // This is set by authorization middleware
+ TempPath string
+ // RemoteObject is provided by the GitLab Rails application
+ // and defines a way to store object on remote storage
+ RemoteObject RemoteObject
+ // Archive is the path where the artifacts archive is stored
+ Archive string `json:"archive"`
+ // Entry is a filename inside the archive point to file that needs to be extracted
+ Entry string `json:"entry"`
+ // Used to communicate channel session details
+ Channel *ChannelSettings
+ // GitalyServer specifies an address and authentication token for a gitaly server we should connect to.
+ GitalyServer gitaly.Server
+ // Repository object for making gRPC requests to Gitaly.
+ Repository gitalypb.Repository
+ // For git-http, does the requestor have the right to view all refs?
+ ShowAllRefs bool
+ // Detects whether an artifact is used for code intelligence
+ ProcessLsif bool
+ // Detects whether LSIF artifact will be parsed with references
+ ProcessLsifReferences bool
+ // The maximum accepted size in bytes of the upload
+ MaximumSize int64
+}
+
+// singleJoiningSlash is taken from reverseproxy.go:NewSingleHostReverseProxy
+func singleJoiningSlash(a, b string) string {
+ aslash := strings.HasSuffix(a, "/")
+ bslash := strings.HasPrefix(b, "/")
+ switch {
+ case aslash && bslash:
+ return a + b[1:]
+ case !aslash && !bslash:
+ return a + "/" + b
+ }
+ return a + b
+}
+
+// rebaseUrl is taken from reverseproxy.go:NewSingleHostReverseProxy
+func rebaseUrl(url *url.URL, onto *url.URL, suffix string) *url.URL {
+ newUrl := *url
+ newUrl.Scheme = onto.Scheme
+ newUrl.Host = onto.Host
+ if suffix != "" {
+ newUrl.Path = singleJoiningSlash(url.Path, suffix)
+ }
+ if onto.RawQuery == "" || newUrl.RawQuery == "" {
+ newUrl.RawQuery = onto.RawQuery + newUrl.RawQuery
+ } else {
+ newUrl.RawQuery = onto.RawQuery + "&" + newUrl.RawQuery
+ }
+ return &newUrl
+}
+
+func (api *API) newRequest(r *http.Request, suffix string) (*http.Request, error) {
+ authReq := &http.Request{
+ Method: r.Method,
+ URL: rebaseUrl(r.URL, api.URL, suffix),
+ Header: helper.HeaderClone(r.Header),
+ }
+
+ authReq = authReq.WithContext(r.Context())
+
+ // Clean some headers when issuing a new request without body
+ authReq.Header.Del("Content-Type")
+ authReq.Header.Del("Content-Encoding")
+ authReq.Header.Del("Content-Length")
+ authReq.Header.Del("Content-Disposition")
+ authReq.Header.Del("Accept-Encoding")
+
+ // Hop-by-hop headers. These are removed when sent to the backend.
+ // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
+ authReq.Header.Del("Transfer-Encoding")
+ authReq.Header.Del("Connection")
+ authReq.Header.Del("Keep-Alive")
+ authReq.Header.Del("Proxy-Authenticate")
+ authReq.Header.Del("Proxy-Authorization")
+ authReq.Header.Del("Te")
+ authReq.Header.Del("Trailers")
+ authReq.Header.Del("Upgrade")
+
+ // Also forward the Host header, which is excluded from the Header map by the http library.
+ // This allows the Host header received by the backend to be consistent with other
+ // requests not going through gitlab-workhorse.
+ authReq.Host = r.Host
+
+ return authReq, nil
+}
+
+// PreAuthorize performs a pre-authorization check against the API for the given HTTP request
+//
+// If `outErr` is set, the other fields will be nil and it should be treated as
+// a 500 error.
+//
+// If httpResponse is present, the caller is responsible for closing its body
+//
+// authResponse will only be present if the authorization check was successful
+func (api *API) PreAuthorize(suffix string, r *http.Request) (httpResponse *http.Response, authResponse *Response, outErr error) {
+ authReq, err := api.newRequest(r, suffix)
+ if err != nil {
+ return nil, nil, fmt.Errorf("preAuthorizeHandler newUpstreamRequest: %v", err)
+ }
+
+ httpResponse, err = api.doRequestWithoutRedirects(authReq)
+ if err != nil {
+ return nil, nil, fmt.Errorf("preAuthorizeHandler: do request: %v", err)
+ }
+ defer func() {
+ if outErr != nil {
+ httpResponse.Body.Close()
+ httpResponse = nil
+ }
+ }()
+ requestsCounter.WithLabelValues(strconv.Itoa(httpResponse.StatusCode), authReq.Method).Inc()
+
+ // This may be a false positive, e.g. for .../info/refs, rather than a
+ // failure, so pass the response back
+ if httpResponse.StatusCode != http.StatusOK || !validResponseContentType(httpResponse) {
+ return httpResponse, nil, nil
+ }
+
+ authResponse = &Response{}
+ // The auth backend validated the client request and told us additional
+ // request metadata. We must extract this information from the auth
+ // response body.
+ if err := json.NewDecoder(httpResponse.Body).Decode(authResponse); err != nil {
+ return httpResponse, nil, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err)
+ }
+
+ return httpResponse, authResponse, nil
+}
+
+func (api *API) PreAuthorizeHandler(next HandleFunc, suffix string) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ httpResponse, authResponse, err := api.PreAuthorize(suffix, r)
+ if httpResponse != nil {
+ defer httpResponse.Body.Close()
+ }
+
+ if err != nil {
+ helper.Fail500(w, r, err)
+ return
+ }
+
+ // The response couldn't be interpreted as a valid auth response, so
+ // pass it back (mostly) unmodified
+ if httpResponse != nil && authResponse == nil {
+ passResponseBack(httpResponse, w, r)
+ return
+ }
+
+ httpResponse.Body.Close() // Free up the Unicorn worker
+
+ copyAuthHeader(httpResponse, w)
+
+ next(w, r, authResponse)
+ })
+}
+
+func (api *API) doRequestWithoutRedirects(authReq *http.Request) (*http.Response, error) {
+ signingTripper := secret.NewRoundTripper(api.Client.Transport, api.Version)
+
+ return signingTripper.RoundTrip(authReq)
+}
+
+func copyAuthHeader(httpResponse *http.Response, w http.ResponseWriter) {
+ // Negotiate authentication (Kerberos) may need to return a WWW-Authenticate
+ // header to the client even in case of success as per RFC4559.
+ for k, v := range httpResponse.Header {
+ // Case-insensitive comparison as per RFC7230
+ if strings.EqualFold(k, "WWW-Authenticate") {
+ w.Header()[k] = v
+ }
+ }
+}
+
+func passResponseBack(httpResponse *http.Response, w http.ResponseWriter, r *http.Request) {
+ // NGINX response buffering is disabled on this path (with
+ // X-Accel-Buffering: no) but we still want to free up the Unicorn worker
+ // that generated httpResponse as fast as possible. To do this we buffer
+ // the entire response body in memory before sending it on.
+ responseBody, err := bufferResponse(httpResponse.Body)
+ if err != nil {
+ helper.Fail500(w, r, err)
+ return
+ }
+ httpResponse.Body.Close() // Free up the Unicorn worker
+ bytesTotal.Add(float64(responseBody.Len()))
+
+ for k, v := range httpResponse.Header {
+ // Accommodate broken clients that do case-sensitive header lookup
+ if k == "Www-Authenticate" {
+ w.Header()["WWW-Authenticate"] = v
+ } else {
+ w.Header()[k] = v
+ }
+ }
+ w.WriteHeader(httpResponse.StatusCode)
+ if _, err := io.Copy(w, responseBody); err != nil {
+ helper.LogError(r, err)
+ }
+}
+
+func bufferResponse(r io.Reader) (*bytes.Buffer, error) {
+ responseBody := &bytes.Buffer{}
+ n, err := io.Copy(responseBody, io.LimitReader(r, failureResponseLimit))
+ if err != nil {
+ return nil, err
+ }
+
+ if n == failureResponseLimit {
+ return nil, fmt.Errorf("response body exceeded maximum buffer size (%d bytes)", failureResponseLimit)
+ }
+
+ return responseBody, nil
+}
+
+func validResponseContentType(resp *http.Response) bool {
+ return helper.IsContentType(ResponseContentType, resp.Header.Get("Content-Type"))
+}
diff --git a/workhorse/internal/api/block.go b/workhorse/internal/api/block.go
new file mode 100644
index 00000000000..92322906c03
--- /dev/null
+++ b/workhorse/internal/api/block.go
@@ -0,0 +1,61 @@
+package api
+
+import (
+ "fmt"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+// Prevent internal API responses intended for gitlab-workhorse from
+// leaking to the end user
+func Block(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ rw := &blocker{rw: w, r: r}
+ defer rw.flush()
+ h.ServeHTTP(rw, r)
+ })
+}
+
+type blocker struct {
+ rw http.ResponseWriter
+ r *http.Request
+ hijacked bool
+ status int
+}
+
+func (b *blocker) Header() http.Header {
+ return b.rw.Header()
+}
+
+func (b *blocker) Write(data []byte) (int, error) {
+ if b.status == 0 {
+ b.WriteHeader(http.StatusOK)
+ }
+ if b.hijacked {
+ return len(data), nil
+ }
+
+ return b.rw.Write(data)
+}
+
+func (b *blocker) WriteHeader(status int) {
+ if b.status != 0 {
+ return
+ }
+
+ if helper.IsContentType(ResponseContentType, b.Header().Get("Content-Type")) {
+ b.status = 500
+ b.Header().Del("Content-Length")
+ b.hijacked = true
+ helper.Fail500(b.rw, b.r, fmt.Errorf("api.blocker: forbidden content-type: %q", ResponseContentType))
+ return
+ }
+
+ b.status = status
+ b.rw.WriteHeader(b.status)
+}
+
+func (b *blocker) flush() {
+ b.WriteHeader(http.StatusOK)
+}
diff --git a/workhorse/internal/api/block_test.go b/workhorse/internal/api/block_test.go
new file mode 100644
index 00000000000..85ad54f3cfd
--- /dev/null
+++ b/workhorse/internal/api/block_test.go
@@ -0,0 +1,56 @@
+package api
+
+import (
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestBlocker(t *testing.T) {
+ upstreamResponse := "hello world"
+
+ testCases := []struct {
+ desc string
+ contentType string
+ out string
+ }{
+ {
+ desc: "blocked",
+ contentType: ResponseContentType,
+ out: "Internal server error\n",
+ },
+ {
+ desc: "pass",
+ contentType: "text/plain",
+ out: upstreamResponse,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ r, err := http.NewRequest("GET", "/foo", nil)
+ require.NoError(t, err)
+
+ rw := httptest.NewRecorder()
+ bl := &blocker{rw: rw, r: r}
+ bl.Header().Set("Content-Type", tc.contentType)
+
+ upstreamBody := []byte(upstreamResponse)
+ n, err := bl.Write(upstreamBody)
+ require.NoError(t, err)
+ require.Equal(t, len(upstreamBody), n, "bytes written")
+
+ rw.Flush()
+
+ body := rw.Result().Body
+ data, err := ioutil.ReadAll(body)
+ require.NoError(t, err)
+ require.NoError(t, body.Close())
+
+ require.Equal(t, tc.out, string(data))
+ })
+ }
+}
diff --git a/workhorse/internal/api/channel_settings.go b/workhorse/internal/api/channel_settings.go
new file mode 100644
index 00000000000..bf3094c9c91
--- /dev/null
+++ b/workhorse/internal/api/channel_settings.go
@@ -0,0 +1,122 @@
+package api
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "net/http"
+ "net/url"
+
+ "github.com/gorilla/websocket"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+type ChannelSettings struct {
+ // The channel provider may require use of a particular subprotocol. If so,
+ // it must be specified here, and Workhorse must have a matching codec.
+ Subprotocols []string
+
+ // The websocket URL to connect to.
+ Url string
+
+ // Any headers (e.g., Authorization) to send with the websocket request
+ Header http.Header
+
+ // The CA roots to validate the remote endpoint with, for wss:// URLs. The
+ // system-provided CA pool will be used if this is blank. PEM-encoded data.
+ CAPem string
+
+ // The value is specified in seconds. It is converted to time.Duration
+ // later.
+ MaxSessionTime int
+}
+
+func (t *ChannelSettings) URL() (*url.URL, error) {
+ return url.Parse(t.Url)
+}
+
+func (t *ChannelSettings) Dialer() *websocket.Dialer {
+ dialer := &websocket.Dialer{
+ Subprotocols: t.Subprotocols,
+ }
+
+ if len(t.CAPem) > 0 {
+ pool := x509.NewCertPool()
+ pool.AppendCertsFromPEM([]byte(t.CAPem))
+ dialer.TLSClientConfig = &tls.Config{RootCAs: pool}
+ }
+
+ return dialer
+}
+
+func (t *ChannelSettings) Clone() *ChannelSettings {
+ // Doesn't clone the strings, but that's OK as strings are immutable in go
+ cloned := *t
+ cloned.Header = helper.HeaderClone(t.Header)
+ return &cloned
+}
+
+func (t *ChannelSettings) Dial() (*websocket.Conn, *http.Response, error) {
+ return t.Dialer().Dial(t.Url, t.Header)
+}
+
+func (t *ChannelSettings) Validate() error {
+ if t == nil {
+ return fmt.Errorf("channel details not specified")
+ }
+
+ if len(t.Subprotocols) == 0 {
+ return fmt.Errorf("no subprotocol specified")
+ }
+
+ parsedURL, err := t.URL()
+ if err != nil {
+ return fmt.Errorf("invalid URL")
+ }
+
+ if parsedURL.Scheme != "ws" && parsedURL.Scheme != "wss" {
+ return fmt.Errorf("invalid websocket scheme: %q", parsedURL.Scheme)
+ }
+
+ return nil
+}
+
+func (t *ChannelSettings) IsEqual(other *ChannelSettings) bool {
+ if t == nil && other == nil {
+ return true
+ }
+
+ if t == nil || other == nil {
+ return false
+ }
+
+ if len(t.Subprotocols) != len(other.Subprotocols) {
+ return false
+ }
+
+ for i, subprotocol := range t.Subprotocols {
+ if other.Subprotocols[i] != subprotocol {
+ return false
+ }
+ }
+
+ if len(t.Header) != len(other.Header) {
+ return false
+ }
+
+ for header, values := range t.Header {
+ if len(values) != len(other.Header[header]) {
+ return false
+ }
+ for i, value := range values {
+ if other.Header[header][i] != value {
+ return false
+ }
+ }
+ }
+
+ return t.Url == other.Url &&
+ t.CAPem == other.CAPem &&
+ t.MaxSessionTime == other.MaxSessionTime
+}
diff --git a/workhorse/internal/api/channel_settings_test.go b/workhorse/internal/api/channel_settings_test.go
new file mode 100644
index 00000000000..4aa2c835579
--- /dev/null
+++ b/workhorse/internal/api/channel_settings_test.go
@@ -0,0 +1,154 @@
+package api
+
+import (
+ "net/http"
+ "testing"
+)
+
+func channel(url string, subprotocols ...string) *ChannelSettings {
+ return &ChannelSettings{
+ Url: url,
+ Subprotocols: subprotocols,
+ MaxSessionTime: 0,
+ }
+}
+
+func ca(channel *ChannelSettings) *ChannelSettings {
+ channel = channel.Clone()
+ channel.CAPem = "Valid CA data"
+
+ return channel
+}
+
+func timeout(channel *ChannelSettings) *ChannelSettings {
+ channel = channel.Clone()
+ channel.MaxSessionTime = 600
+
+ return channel
+}
+
+func header(channel *ChannelSettings, values ...string) *ChannelSettings {
+ if len(values) == 0 {
+ values = []string{"Dummy Value"}
+ }
+
+ channel = channel.Clone()
+ channel.Header = http.Header{
+ "Header": values,
+ }
+
+ return channel
+}
+
+func TestClone(t *testing.T) {
+ a := ca(header(channel("ws:", "", "")))
+ b := a.Clone()
+
+ if a == b {
+ t.Fatalf("Address of cloned channel didn't change")
+ }
+
+ if &a.Subprotocols == &b.Subprotocols {
+ t.Fatalf("Address of cloned subprotocols didn't change")
+ }
+
+ if &a.Header == &b.Header {
+ t.Fatalf("Address of cloned header didn't change")
+ }
+}
+
+func TestValidate(t *testing.T) {
+ for i, tc := range []struct {
+ channel *ChannelSettings
+ valid bool
+ msg string
+ }{
+ {nil, false, "nil channel"},
+ {channel("", ""), false, "empty URL"},
+ {channel("ws:"), false, "empty subprotocols"},
+ {channel("ws:", "foo"), true, "any subprotocol"},
+ {channel("ws:", "foo", "bar"), true, "multiple subprotocols"},
+ {channel("ws:", ""), true, "websocket URL"},
+ {channel("wss:", ""), true, "secure websocket URL"},
+ {channel("http:", ""), false, "HTTP URL"},
+ {channel("https:", ""), false, " HTTPS URL"},
+ {ca(channel("ws:", "")), true, "any CA pem"},
+ {header(channel("ws:", "")), true, "any headers"},
+ {ca(header(channel("ws:", ""))), true, "PEM and headers"},
+ } {
+ if err := tc.channel.Validate(); (err != nil) == tc.valid {
+ t.Fatalf("test case %d: "+tc.msg+": valid=%v: %s: %+v", i, tc.valid, err, tc.channel)
+ }
+ }
+}
+
+func TestDialer(t *testing.T) {
+ channel := channel("ws:", "foo")
+ dialer := channel.Dialer()
+
+ if len(dialer.Subprotocols) != len(channel.Subprotocols) {
+ t.Fatalf("Subprotocols don't match: %+v vs. %+v", channel.Subprotocols, dialer.Subprotocols)
+ }
+
+ for i, subprotocol := range channel.Subprotocols {
+ if dialer.Subprotocols[i] != subprotocol {
+ t.Fatalf("Subprotocols don't match: %+v vs. %+v", channel.Subprotocols, dialer.Subprotocols)
+ }
+ }
+
+ if dialer.TLSClientConfig != nil {
+ t.Fatalf("Unexpected TLSClientConfig: %+v", dialer)
+ }
+
+ channel = ca(channel)
+ dialer = channel.Dialer()
+
+ if dialer.TLSClientConfig == nil || dialer.TLSClientConfig.RootCAs == nil {
+ t.Fatalf("Custom CA certificates not recognised!")
+ }
+}
+
+func TestIsEqual(t *testing.T) {
+ chann := channel("ws:", "foo")
+
+ chann_header2 := header(chann, "extra")
+ chann_header3 := header(chann)
+ chann_header3.Header.Add("Extra", "extra")
+
+ chann_ca2 := ca(chann)
+ chann_ca2.CAPem = "other value"
+
+ for i, tc := range []struct {
+ channelA *ChannelSettings
+ channelB *ChannelSettings
+ expected bool
+ }{
+ {nil, nil, true},
+ {chann, nil, false},
+ {nil, chann, false},
+ {chann, chann, true},
+ {chann.Clone(), chann.Clone(), true},
+ {chann, channel("foo:"), false},
+ {chann, channel(chann.Url), false},
+ {header(chann), header(chann), true},
+ {chann_header2, chann_header2, true},
+ {chann_header3, chann_header3, true},
+ {header(chann), chann_header2, false},
+ {header(chann), chann_header3, false},
+ {header(chann), chann, false},
+ {chann, header(chann), false},
+ {ca(chann), ca(chann), true},
+ {ca(chann), chann, false},
+ {chann, ca(chann), false},
+ {ca(header(chann)), ca(header(chann)), true},
+ {chann_ca2, ca(chann), false},
+ {chann, timeout(chann), false},
+ } {
+ if actual := tc.channelA.IsEqual(tc.channelB); tc.expected != actual {
+ t.Fatalf(
+ "test case %d: Comparison:\n-%+v\n+%+v\nexpected=%v: actual=%v",
+ i, tc.channelA, tc.channelB, tc.expected, actual,
+ )
+ }
+ }
+}
diff --git a/workhorse/internal/artifacts/artifacts_store_test.go b/workhorse/internal/artifacts/artifacts_store_test.go
new file mode 100644
index 00000000000..bd56d9ea725
--- /dev/null
+++ b/workhorse/internal/artifacts/artifacts_store_test.go
@@ -0,0 +1,338 @@
+package artifacts
+
+import (
+ "archive/zip"
+ "bytes"
+ "crypto/md5"
+ "encoding/hex"
+ "fmt"
+ "io/ioutil"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+)
+
+func createTestZipArchive(t *testing.T) (data []byte, md5Hash string) {
+ var buffer bytes.Buffer
+ archive := zip.NewWriter(&buffer)
+ fileInArchive, err := archive.Create("test.file")
+ require.NoError(t, err)
+ fmt.Fprint(fileInArchive, "test")
+ archive.Close()
+ data = buffer.Bytes()
+
+ hasher := md5.New()
+ hasher.Write(data)
+ hexHash := hasher.Sum(nil)
+ md5Hash = hex.EncodeToString(hexHash)
+
+ return data, md5Hash
+}
+
+func createTestMultipartForm(t *testing.T, data []byte) (bytes.Buffer, string) {
+ var buffer bytes.Buffer
+ writer := multipart.NewWriter(&buffer)
+ file, err := writer.CreateFormFile("file", "my.file")
+ require.NoError(t, err)
+ file.Write(data)
+ writer.Close()
+ return buffer, writer.FormDataContentType()
+}
+
+func testUploadArtifactsFromTestZip(t *testing.T, ts *httptest.Server) *httptest.ResponseRecorder {
+ archiveData, _ := createTestZipArchive(t)
+ contentBuffer, contentType := createTestMultipartForm(t, archiveData)
+
+ return testUploadArtifacts(t, contentType, ts.URL+Path, &contentBuffer)
+}
+
+func TestUploadHandlerSendingToExternalStorage(t *testing.T) {
+ tempPath, err := ioutil.TempDir("", "uploads")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(tempPath)
+
+ archiveData, md5 := createTestZipArchive(t)
+ archiveFile, err := ioutil.TempFile("", "artifact.zip")
+ require.NoError(t, err)
+ defer os.Remove(archiveFile.Name())
+ _, err = archiveFile.Write(archiveData)
+ require.NoError(t, err)
+ archiveFile.Close()
+
+ storeServerCalled := 0
+ storeServerMux := http.NewServeMux()
+ storeServerMux.HandleFunc("/url/put", func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "PUT", r.Method)
+
+ receivedData, err := ioutil.ReadAll(r.Body)
+ require.NoError(t, err)
+ require.Equal(t, archiveData, receivedData)
+
+ storeServerCalled++
+ w.Header().Set("ETag", md5)
+ w.WriteHeader(200)
+ })
+ storeServerMux.HandleFunc("/store-id", func(w http.ResponseWriter, r *http.Request) {
+ http.ServeFile(w, r, archiveFile.Name())
+ })
+
+ responseProcessorCalled := 0
+ responseProcessor := func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "store-id", r.FormValue("file.remote_id"))
+ require.NotEmpty(t, r.FormValue("file.remote_url"))
+ w.WriteHeader(200)
+ responseProcessorCalled++
+ }
+
+ storeServer := httptest.NewServer(storeServerMux)
+ defer storeServer.Close()
+
+ qs := fmt.Sprintf("?%s=%s", ArtifactFormatKey, ArtifactFormatZip)
+
+ tests := []struct {
+ name string
+ preauth api.Response
+ }{
+ {
+ name: "ObjectStore Upload",
+ preauth: api.Response{
+ RemoteObject: api.RemoteObject{
+ StoreURL: storeServer.URL + "/url/put" + qs,
+ ID: "store-id",
+ GetURL: storeServer.URL + "/store-id",
+ },
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ storeServerCalled = 0
+ responseProcessorCalled = 0
+
+ ts := testArtifactsUploadServer(t, test.preauth, responseProcessor)
+ defer ts.Close()
+
+ contentBuffer, contentType := createTestMultipartForm(t, archiveData)
+ response := testUploadArtifacts(t, contentType, ts.URL+Path+qs, &contentBuffer)
+ require.Equal(t, http.StatusOK, response.Code)
+ testhelper.RequireResponseHeader(t, response, MetadataHeaderKey, MetadataHeaderPresent)
+ require.Equal(t, 1, storeServerCalled, "store should be called only once")
+ require.Equal(t, 1, responseProcessorCalled, "response processor should be called only once")
+ })
+ }
+}
+
+func TestUploadHandlerSendingToExternalStorageAndStorageServerUnreachable(t *testing.T) {
+ tempPath, err := ioutil.TempDir("", "uploads")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(tempPath)
+
+ responseProcessor := func(w http.ResponseWriter, r *http.Request) {
+ t.Fatal("it should not be called")
+ }
+
+ authResponse := api.Response{
+ TempPath: tempPath,
+ RemoteObject: api.RemoteObject{
+ StoreURL: "http://localhost:12323/invalid/url",
+ ID: "store-id",
+ },
+ }
+
+ ts := testArtifactsUploadServer(t, authResponse, responseProcessor)
+ defer ts.Close()
+
+ response := testUploadArtifactsFromTestZip(t, ts)
+ require.Equal(t, http.StatusInternalServerError, response.Code)
+}
+
+func TestUploadHandlerSendingToExternalStorageAndInvalidURLIsUsed(t *testing.T) {
+ tempPath, err := ioutil.TempDir("", "uploads")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(tempPath)
+
+ responseProcessor := func(w http.ResponseWriter, r *http.Request) {
+ t.Fatal("it should not be called")
+ }
+
+ authResponse := api.Response{
+ TempPath: tempPath,
+ RemoteObject: api.RemoteObject{
+ StoreURL: "htt:////invalid-url",
+ ID: "store-id",
+ },
+ }
+
+ ts := testArtifactsUploadServer(t, authResponse, responseProcessor)
+ defer ts.Close()
+
+ response := testUploadArtifactsFromTestZip(t, ts)
+ require.Equal(t, http.StatusInternalServerError, response.Code)
+}
+
+func TestUploadHandlerSendingToExternalStorageAndItReturnsAnError(t *testing.T) {
+ putCalledTimes := 0
+
+ storeServerMux := http.NewServeMux()
+ storeServerMux.HandleFunc("/url/put", func(w http.ResponseWriter, r *http.Request) {
+ putCalledTimes++
+ require.Equal(t, "PUT", r.Method)
+ w.WriteHeader(510)
+ })
+
+ responseProcessor := func(w http.ResponseWriter, r *http.Request) {
+ t.Fatal("it should not be called")
+ }
+
+ storeServer := httptest.NewServer(storeServerMux)
+ defer storeServer.Close()
+
+ authResponse := api.Response{
+ RemoteObject: api.RemoteObject{
+ StoreURL: storeServer.URL + "/url/put",
+ ID: "store-id",
+ },
+ }
+
+ ts := testArtifactsUploadServer(t, authResponse, responseProcessor)
+ defer ts.Close()
+
+ response := testUploadArtifactsFromTestZip(t, ts)
+ require.Equal(t, http.StatusInternalServerError, response.Code)
+ require.Equal(t, 1, putCalledTimes, "upload should be called only once")
+}
+
+func TestUploadHandlerSendingToExternalStorageAndSupportRequestTimeout(t *testing.T) {
+ putCalledTimes := 0
+
+ storeServerMux := http.NewServeMux()
+ storeServerMux.HandleFunc("/url/put", func(w http.ResponseWriter, r *http.Request) {
+ putCalledTimes++
+ require.Equal(t, "PUT", r.Method)
+ time.Sleep(10 * time.Second)
+ w.WriteHeader(510)
+ })
+
+ responseProcessor := func(w http.ResponseWriter, r *http.Request) {
+ t.Fatal("it should not be called")
+ }
+
+ storeServer := httptest.NewServer(storeServerMux)
+ defer storeServer.Close()
+
+ authResponse := api.Response{
+ RemoteObject: api.RemoteObject{
+ StoreURL: storeServer.URL + "/url/put",
+ ID: "store-id",
+ Timeout: 1,
+ },
+ }
+
+ ts := testArtifactsUploadServer(t, authResponse, responseProcessor)
+ defer ts.Close()
+
+ response := testUploadArtifactsFromTestZip(t, ts)
+ require.Equal(t, http.StatusInternalServerError, response.Code)
+ require.Equal(t, 1, putCalledTimes, "upload should be called only once")
+}
+
+func TestUploadHandlerMultipartUploadSizeLimit(t *testing.T) {
+ os, server := test.StartObjectStore()
+ defer server.Close()
+
+ err := os.InitiateMultipartUpload(test.ObjectPath)
+ require.NoError(t, err)
+
+ objectURL := server.URL + test.ObjectPath
+
+ uploadSize := 10
+ preauth := api.Response{
+ RemoteObject: api.RemoteObject{
+ ID: "store-id",
+ MultipartUpload: &api.MultipartUploadParams{
+ PartSize: 1,
+ PartURLs: []string{objectURL + "?partNumber=1"},
+ AbortURL: objectURL, // DELETE
+ CompleteURL: objectURL, // POST
+ },
+ },
+ }
+
+ responseProcessor := func(w http.ResponseWriter, r *http.Request) {
+ t.Fatal("it should not be called")
+ }
+
+ ts := testArtifactsUploadServer(t, preauth, responseProcessor)
+ defer ts.Close()
+
+ contentBuffer, contentType := createTestMultipartForm(t, make([]byte, uploadSize))
+ response := testUploadArtifacts(t, contentType, ts.URL+Path, &contentBuffer)
+ require.Equal(t, http.StatusRequestEntityTooLarge, response.Code)
+
+ // Poll because AbortMultipartUpload is async
+ for i := 0; os.IsMultipartUpload(test.ObjectPath) && i < 100; i++ {
+ time.Sleep(10 * time.Millisecond)
+ }
+ require.False(t, os.IsMultipartUpload(test.ObjectPath), "MultipartUpload should not be in progress anymore")
+ require.Empty(t, os.GetObjectMD5(test.ObjectPath), "upload should have failed, so the object should not exists")
+}
+
+func TestUploadHandlerMultipartUploadMaximumSizeFromApi(t *testing.T) {
+ os, server := test.StartObjectStore()
+ defer server.Close()
+
+ err := os.InitiateMultipartUpload(test.ObjectPath)
+ require.NoError(t, err)
+
+ objectURL := server.URL + test.ObjectPath
+
+ uploadSize := int64(10)
+ maxSize := uploadSize - 1
+ preauth := api.Response{
+ MaximumSize: maxSize,
+ RemoteObject: api.RemoteObject{
+ ID: "store-id",
+ MultipartUpload: &api.MultipartUploadParams{
+ PartSize: uploadSize,
+ PartURLs: []string{objectURL + "?partNumber=1"},
+ AbortURL: objectURL, // DELETE
+ CompleteURL: objectURL, // POST
+ },
+ },
+ }
+
+ responseProcessor := func(w http.ResponseWriter, r *http.Request) {
+ t.Fatal("it should not be called")
+ }
+
+ ts := testArtifactsUploadServer(t, preauth, responseProcessor)
+ defer ts.Close()
+
+ contentBuffer, contentType := createTestMultipartForm(t, make([]byte, uploadSize))
+ response := testUploadArtifacts(t, contentType, ts.URL+Path, &contentBuffer)
+ require.Equal(t, http.StatusRequestEntityTooLarge, response.Code)
+
+ testhelper.Retry(t, 5*time.Second, func() error {
+ if os.GetObjectMD5(test.ObjectPath) == "" {
+ return nil
+ }
+
+ return fmt.Errorf("file is still present")
+ })
+}
diff --git a/workhorse/internal/artifacts/artifacts_test.go b/workhorse/internal/artifacts/artifacts_test.go
new file mode 100644
index 00000000000..b9a42cc60c1
--- /dev/null
+++ b/workhorse/internal/artifacts/artifacts_test.go
@@ -0,0 +1,19 @@
+package artifacts
+
+import (
+ "os"
+ "testing"
+
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+)
+
+func TestMain(m *testing.M) {
+ if err := testhelper.BuildExecutables(); err != nil {
+ log.WithError(err).Fatal()
+ }
+
+ os.Exit(m.Run())
+
+}
diff --git a/workhorse/internal/artifacts/artifacts_upload.go b/workhorse/internal/artifacts/artifacts_upload.go
new file mode 100644
index 00000000000..3d4b8bf0931
--- /dev/null
+++ b/workhorse/internal/artifacts/artifacts_upload.go
@@ -0,0 +1,167 @@
+package artifacts
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "os"
+ "os/exec"
+ "strings"
+ "syscall"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts"
+)
+
+// Sent by the runner: https://gitlab.com/gitlab-org/gitlab-runner/blob/c24da19ecce8808d9d2950896f70c94f5ea1cc2e/network/gitlab.go#L580
+const (
+ ArtifactFormatKey = "artifact_format"
+ ArtifactFormatZip = "zip"
+ ArtifactFormatDefault = ""
+)
+
+var zipSubcommandsErrorsCounter = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_zip_subcommand_errors_total",
+ Help: "Errors comming from subcommands used for processing ZIP archives",
+ }, []string{"error"})
+
+type artifactsUploadProcessor struct {
+ opts *filestore.SaveFileOpts
+ format string
+
+ upload.SavedFileTracker
+}
+
+func (a *artifactsUploadProcessor) generateMetadataFromZip(ctx context.Context, file *filestore.FileHandler) (*filestore.FileHandler, error) {
+ metaReader, metaWriter := io.Pipe()
+ defer metaWriter.Close()
+
+ metaOpts := &filestore.SaveFileOpts{
+ LocalTempPath: a.opts.LocalTempPath,
+ TempFilePrefix: "metadata.gz",
+ }
+ if metaOpts.LocalTempPath == "" {
+ metaOpts.LocalTempPath = os.TempDir()
+ }
+
+ fileName := file.LocalPath
+ if fileName == "" {
+ fileName = file.RemoteURL
+ }
+
+ zipMd := exec.CommandContext(ctx, "gitlab-zip-metadata", fileName)
+ zipMd.Stderr = log.ContextLogger(ctx).Writer()
+ zipMd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+ zipMd.Stdout = metaWriter
+
+ if err := zipMd.Start(); err != nil {
+ return nil, err
+ }
+ defer helper.CleanUpProcessGroup(zipMd)
+
+ type saveResult struct {
+ error
+ *filestore.FileHandler
+ }
+ done := make(chan saveResult)
+ go func() {
+ var result saveResult
+ result.FileHandler, result.error = filestore.SaveFileFromReader(ctx, metaReader, -1, metaOpts)
+
+ done <- result
+ }()
+
+ if err := zipMd.Wait(); err != nil {
+ st, ok := helper.ExitStatus(err)
+
+ if !ok {
+ return nil, err
+ }
+
+ zipSubcommandsErrorsCounter.WithLabelValues(zipartifacts.ErrorLabelByCode(st)).Inc()
+
+ if st == zipartifacts.CodeNotZip {
+ return nil, nil
+ }
+
+ if st == zipartifacts.CodeLimitsReached {
+ return nil, zipartifacts.ErrBadMetadata
+ }
+ }
+
+ metaWriter.Close()
+ result := <-done
+ return result.FileHandler, result.error
+}
+
+func (a *artifactsUploadProcessor) ProcessFile(ctx context.Context, formName string, file *filestore.FileHandler, writer *multipart.Writer) error {
+ // ProcessFile for artifacts requires file form-data field name to eq `file`
+
+ if formName != "file" {
+ return fmt.Errorf("invalid form field: %q", formName)
+ }
+ if a.Count() > 0 {
+ return fmt.Errorf("artifacts request contains more than one file")
+ }
+ a.Track(formName, file.LocalPath)
+
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("ProcessFile: context done")
+ default:
+ }
+
+ if !strings.EqualFold(a.format, ArtifactFormatZip) && a.format != ArtifactFormatDefault {
+ return nil
+ }
+
+ // TODO: can we rely on disk for shipping metadata? Not if we split workhorse and rails in 2 different PODs
+ metadata, err := a.generateMetadataFromZip(ctx, file)
+ if err != nil {
+ return err
+ }
+
+ if metadata != nil {
+ fields, err := metadata.GitLabFinalizeFields("metadata")
+ if err != nil {
+ return fmt.Errorf("finalize metadata field error: %v", err)
+ }
+
+ for k, v := range fields {
+ writer.WriteField(k, v)
+ }
+
+ a.Track("metadata", metadata.LocalPath)
+ }
+
+ return nil
+}
+
+func (a *artifactsUploadProcessor) Name() string {
+ return "artifacts"
+}
+
+func UploadArtifacts(myAPI *api.API, h http.Handler, p upload.Preparer) http.Handler {
+ return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
+ opts, _, err := p.Prepare(a)
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("UploadArtifacts: error preparing file storage options"))
+ return
+ }
+
+ format := r.URL.Query().Get(ArtifactFormatKey)
+
+ mg := &artifactsUploadProcessor{opts: opts, format: format, SavedFileTracker: upload.SavedFileTracker{Request: r}}
+ upload.HandleFileUploads(w, r, h, a, mg, opts)
+ }, "/authorize")
+}
diff --git a/workhorse/internal/artifacts/artifacts_upload_test.go b/workhorse/internal/artifacts/artifacts_upload_test.go
new file mode 100644
index 00000000000..c82ae791239
--- /dev/null
+++ b/workhorse/internal/artifacts/artifacts_upload_test.go
@@ -0,0 +1,322 @@
+package artifacts
+
+import (
+ "archive/zip"
+ "bytes"
+ "compress/gzip"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
+
+ "github.com/dgrijalva/jwt-go"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/proxy"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream/roundtripper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts"
+
+ "github.com/stretchr/testify/require"
+)
+
+const (
+ MetadataHeaderKey = "Metadata-Status"
+ MetadataHeaderPresent = "present"
+ MetadataHeaderMissing = "missing"
+ Path = "/url/path"
+)
+
+func testArtifactsUploadServer(t *testing.T, authResponse api.Response, bodyProcessor func(w http.ResponseWriter, r *http.Request)) *httptest.Server {
+ mux := http.NewServeMux()
+ mux.HandleFunc(Path+"/authorize", func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != "POST" {
+ t.Fatal("Expected POST request")
+ }
+
+ w.Header().Set("Content-Type", api.ResponseContentType)
+
+ data, err := json.Marshal(&authResponse)
+ if err != nil {
+ t.Fatal("Expected to marshal")
+ }
+ w.Write(data)
+ })
+ mux.HandleFunc(Path, func(w http.ResponseWriter, r *http.Request) {
+ opts, err := filestore.GetOpts(&authResponse)
+ require.NoError(t, err)
+
+ if r.Method != "POST" {
+ t.Fatal("Expected POST request")
+ }
+ if opts.IsLocal() {
+ if r.FormValue("file.path") == "" {
+ t.Fatal("Expected file to be present")
+ return
+ }
+
+ _, err := ioutil.ReadFile(r.FormValue("file.path"))
+ if err != nil {
+ t.Fatal("Expected file to be readable")
+ return
+ }
+ } else {
+ if r.FormValue("file.remote_url") == "" {
+ t.Fatal("Expected file to be remote accessible")
+ return
+ }
+ }
+
+ if r.FormValue("metadata.path") != "" {
+ metadata, err := ioutil.ReadFile(r.FormValue("metadata.path"))
+ if err != nil {
+ t.Fatal("Expected metadata to be readable")
+ return
+ }
+ gz, err := gzip.NewReader(bytes.NewReader(metadata))
+ if err != nil {
+ t.Fatal("Expected metadata to be valid gzip")
+ return
+ }
+ defer gz.Close()
+ metadata, err = ioutil.ReadAll(gz)
+ if err != nil {
+ t.Fatal("Expected metadata to be valid")
+ return
+ }
+ if !bytes.HasPrefix(metadata, []byte(zipartifacts.MetadataHeaderPrefix+zipartifacts.MetadataHeader)) {
+ t.Fatal("Expected metadata to be of valid format")
+ return
+ }
+
+ w.Header().Set(MetadataHeaderKey, MetadataHeaderPresent)
+
+ } else {
+ w.Header().Set(MetadataHeaderKey, MetadataHeaderMissing)
+ }
+
+ w.WriteHeader(http.StatusOK)
+
+ if bodyProcessor != nil {
+ bodyProcessor(w, r)
+ }
+ })
+ return testhelper.TestServerWithHandler(nil, mux.ServeHTTP)
+}
+
+type testServer struct {
+ url string
+ writer *multipart.Writer
+ buffer *bytes.Buffer
+ fileWriter io.Writer
+ cleanup func()
+}
+
+func setupWithTmpPath(t *testing.T, filename string, includeFormat bool, format string, authResponse *api.Response, bodyProcessor func(w http.ResponseWriter, r *http.Request)) *testServer {
+ tempPath, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+
+ if authResponse == nil {
+ authResponse = &api.Response{TempPath: tempPath}
+ }
+
+ ts := testArtifactsUploadServer(t, *authResponse, bodyProcessor)
+
+ var buffer bytes.Buffer
+ writer := multipart.NewWriter(&buffer)
+ fileWriter, err := writer.CreateFormFile(filename, "my.file")
+ require.NotNil(t, fileWriter)
+ require.NoError(t, err)
+
+ cleanup := func() {
+ ts.Close()
+ require.NoError(t, os.RemoveAll(tempPath))
+ require.NoError(t, writer.Close())
+ }
+
+ qs := ""
+
+ if includeFormat {
+ qs = fmt.Sprintf("?%s=%s", ArtifactFormatKey, format)
+ }
+
+ return &testServer{url: ts.URL + Path + qs, writer: writer, buffer: &buffer, fileWriter: fileWriter, cleanup: cleanup}
+}
+
+func testUploadArtifacts(t *testing.T, contentType, url string, body io.Reader) *httptest.ResponseRecorder {
+ httpRequest, err := http.NewRequest("POST", url, body)
+ require.NoError(t, err)
+
+ httpRequest.Header.Set("Content-Type", contentType)
+ response := httptest.NewRecorder()
+ parsedURL := helper.URLMustParse(url)
+ roundTripper := roundtripper.NewTestBackendRoundTripper(parsedURL)
+ testhelper.ConfigureSecret()
+ apiClient := api.NewAPI(parsedURL, "123", roundTripper)
+ proxyClient := proxy.NewProxy(parsedURL, "123", roundTripper)
+ UploadArtifacts(apiClient, proxyClient, &upload.DefaultPreparer{}).ServeHTTP(response, httpRequest)
+ return response
+}
+
+func TestUploadHandlerAddingMetadata(t *testing.T) {
+ testCases := []struct {
+ desc string
+ format string
+ includeFormat bool
+ }{
+ {
+ desc: "ZIP format",
+ format: ArtifactFormatZip,
+ includeFormat: true,
+ },
+ {
+ desc: "default format",
+ format: ArtifactFormatDefault,
+ includeFormat: true,
+ },
+ {
+ desc: "default format without artifact_format",
+ format: ArtifactFormatDefault,
+ includeFormat: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ s := setupWithTmpPath(t, "file", tc.includeFormat, tc.format, nil,
+ func(w http.ResponseWriter, r *http.Request) {
+ token, err := jwt.ParseWithClaims(r.Header.Get(upload.RewrittenFieldsHeader), &upload.MultipartClaims{}, testhelper.ParseJWT)
+ require.NoError(t, err)
+
+ rewrittenFields := token.Claims.(*upload.MultipartClaims).RewrittenFields
+ require.Equal(t, 2, len(rewrittenFields))
+
+ require.Contains(t, rewrittenFields, "file")
+ require.Contains(t, rewrittenFields, "metadata")
+ require.Contains(t, r.PostForm, "file.gitlab-workhorse-upload")
+ require.Contains(t, r.PostForm, "metadata.gitlab-workhorse-upload")
+ },
+ )
+ defer s.cleanup()
+
+ archive := zip.NewWriter(s.fileWriter)
+ file, err := archive.Create("test.file")
+ require.NotNil(t, file)
+ require.NoError(t, err)
+
+ require.NoError(t, archive.Close())
+ require.NoError(t, s.writer.Close())
+
+ response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer)
+ require.Equal(t, http.StatusOK, response.Code)
+ testhelper.RequireResponseHeader(t, response, MetadataHeaderKey, MetadataHeaderPresent)
+ })
+ }
+}
+
+func TestUploadHandlerTarArtifact(t *testing.T) {
+ s := setupWithTmpPath(t, "file", true, "tar", nil,
+ func(w http.ResponseWriter, r *http.Request) {
+ token, err := jwt.ParseWithClaims(r.Header.Get(upload.RewrittenFieldsHeader), &upload.MultipartClaims{}, testhelper.ParseJWT)
+ require.NoError(t, err)
+
+ rewrittenFields := token.Claims.(*upload.MultipartClaims).RewrittenFields
+ require.Equal(t, 1, len(rewrittenFields))
+
+ require.Contains(t, rewrittenFields, "file")
+ require.Contains(t, r.PostForm, "file.gitlab-workhorse-upload")
+ },
+ )
+ defer s.cleanup()
+
+ file, err := os.Open("../../testdata/tarfile.tar")
+ require.NoError(t, err)
+
+ _, err = io.Copy(s.fileWriter, file)
+ require.NoError(t, err)
+ require.NoError(t, file.Close())
+ require.NoError(t, s.writer.Close())
+
+ response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer)
+ require.Equal(t, http.StatusOK, response.Code)
+ testhelper.RequireResponseHeader(t, response, MetadataHeaderKey, MetadataHeaderMissing)
+}
+
+func TestUploadHandlerForUnsupportedArchive(t *testing.T) {
+ s := setupWithTmpPath(t, "file", true, "other", nil, nil)
+ defer s.cleanup()
+ require.NoError(t, s.writer.Close())
+
+ response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer)
+ require.Equal(t, http.StatusOK, response.Code)
+ testhelper.RequireResponseHeader(t, response, MetadataHeaderKey, MetadataHeaderMissing)
+}
+
+func TestUploadHandlerForMultipleFiles(t *testing.T) {
+ s := setupWithTmpPath(t, "file", true, "", nil, nil)
+ defer s.cleanup()
+
+ file, err := s.writer.CreateFormFile("file", "my.file")
+ require.NotNil(t, file)
+ require.NoError(t, err)
+ require.NoError(t, s.writer.Close())
+
+ response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer)
+ require.Equal(t, http.StatusInternalServerError, response.Code)
+}
+
+func TestUploadFormProcessing(t *testing.T) {
+ s := setupWithTmpPath(t, "metadata", true, "", nil, nil)
+ defer s.cleanup()
+ require.NoError(t, s.writer.Close())
+
+ response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer)
+ require.Equal(t, http.StatusInternalServerError, response.Code)
+}
+
+func TestLsifFileProcessing(t *testing.T) {
+ tempPath, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+
+ s := setupWithTmpPath(t, "file", true, "zip", &api.Response{TempPath: tempPath, ProcessLsif: true}, nil)
+ defer s.cleanup()
+
+ file, err := os.Open("../../testdata/lsif/valid.lsif.zip")
+ require.NoError(t, err)
+
+ _, err = io.Copy(s.fileWriter, file)
+ require.NoError(t, err)
+ require.NoError(t, file.Close())
+ require.NoError(t, s.writer.Close())
+
+ response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer)
+ require.Equal(t, http.StatusOK, response.Code)
+ testhelper.RequireResponseHeader(t, response, MetadataHeaderKey, MetadataHeaderPresent)
+}
+
+func TestInvalidLsifFileProcessing(t *testing.T) {
+ tempPath, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+
+ s := setupWithTmpPath(t, "file", true, "zip", &api.Response{TempPath: tempPath, ProcessLsif: true}, nil)
+ defer s.cleanup()
+
+ file, err := os.Open("../../testdata/lsif/invalid.lsif.zip")
+ require.NoError(t, err)
+
+ _, err = io.Copy(s.fileWriter, file)
+ require.NoError(t, err)
+ require.NoError(t, file.Close())
+ require.NoError(t, s.writer.Close())
+
+ response := testUploadArtifacts(t, s.writer.FormDataContentType(), s.url, s.buffer)
+ require.Equal(t, http.StatusInternalServerError, response.Code)
+}
diff --git a/workhorse/internal/artifacts/entry.go b/workhorse/internal/artifacts/entry.go
new file mode 100644
index 00000000000..0c697d40020
--- /dev/null
+++ b/workhorse/internal/artifacts/entry.go
@@ -0,0 +1,123 @@
+package artifacts
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "io"
+ "mime"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "syscall"
+
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts"
+)
+
+type entry struct{ senddata.Prefix }
+type entryParams struct{ Archive, Entry string }
+
+var SendEntry = &entry{"artifacts-entry:"}
+
+// Artifacts downloader doesn't support ranges when downloading a single file
+func (e *entry) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
+ var params entryParams
+ if err := e.Unpack(&params, sendData); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("SendEntry: unpack sendData: %v", err))
+ return
+ }
+
+ log.WithContextFields(r.Context(), log.Fields{
+ "entry": params.Entry,
+ "archive": params.Archive,
+ "path": r.URL.Path,
+ }).Print("SendEntry: sending")
+
+ if params.Archive == "" || params.Entry == "" {
+ helper.Fail500(w, r, fmt.Errorf("SendEntry: Archive or Entry is empty"))
+ return
+ }
+
+ err := unpackFileFromZip(r.Context(), params.Archive, params.Entry, w.Header(), w)
+
+ if os.IsNotExist(err) {
+ http.NotFound(w, r)
+ } else if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("SendEntry: %v", err))
+ }
+}
+
+func detectFileContentType(fileName string) string {
+ contentType := mime.TypeByExtension(filepath.Ext(fileName))
+ if contentType == "" {
+ contentType = "application/octet-stream"
+ }
+ return contentType
+}
+
+func unpackFileFromZip(ctx context.Context, archivePath, encodedFilename string, headers http.Header, output io.Writer) error {
+ fileName, err := zipartifacts.DecodeFileEntry(encodedFilename)
+ if err != nil {
+ return err
+ }
+
+ catFile := exec.Command("gitlab-zip-cat")
+ catFile.Env = append(os.Environ(),
+ "ARCHIVE_PATH="+archivePath,
+ "ENCODED_FILE_NAME="+encodedFilename,
+ )
+ catFile.Stderr = log.ContextLogger(ctx).Writer()
+ catFile.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+ stdout, err := catFile.StdoutPipe()
+ if err != nil {
+ return fmt.Errorf("create gitlab-zip-cat stdout pipe: %v", err)
+ }
+
+ if err := catFile.Start(); err != nil {
+ return fmt.Errorf("start %v: %v", catFile.Args, err)
+ }
+ defer helper.CleanUpProcessGroup(catFile)
+
+ basename := filepath.Base(fileName)
+ reader := bufio.NewReader(stdout)
+ contentLength, err := reader.ReadString('\n')
+ if err != nil {
+ if catFileErr := waitCatFile(catFile); catFileErr != nil {
+ return catFileErr
+ }
+ return fmt.Errorf("read content-length: %v", err)
+ }
+ contentLength = strings.TrimSuffix(contentLength, "\n")
+
+ // Write http headers about the file
+ headers.Set("Content-Length", contentLength)
+ headers.Set("Content-Type", detectFileContentType(fileName))
+ headers.Set("Content-Disposition", "attachment; filename=\""+escapeQuotes(basename)+"\"")
+ // Copy file body to client
+ if _, err := io.Copy(output, reader); err != nil {
+ return fmt.Errorf("copy stdout of %v: %v", catFile.Args, err)
+ }
+
+ return waitCatFile(catFile)
+}
+
+func waitCatFile(cmd *exec.Cmd) error {
+ err := cmd.Wait()
+ if err == nil {
+ return nil
+ }
+
+ st, ok := helper.ExitStatus(err)
+
+ if ok && (st == zipartifacts.CodeArchiveNotFound || st == zipartifacts.CodeEntryNotFound) {
+ return os.ErrNotExist
+ }
+ return fmt.Errorf("wait for %v to finish: %v", cmd.Args, err)
+
+}
diff --git a/workhorse/internal/artifacts/entry_test.go b/workhorse/internal/artifacts/entry_test.go
new file mode 100644
index 00000000000..6f1e9d360aa
--- /dev/null
+++ b/workhorse/internal/artifacts/entry_test.go
@@ -0,0 +1,134 @@
+package artifacts
+
+import (
+ "archive/zip"
+ "encoding/base64"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+)
+
+func testEntryServer(t *testing.T, archive string, entry string) *httptest.ResponseRecorder {
+ mux := http.NewServeMux()
+ mux.HandleFunc("/url/path", func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "GET", r.Method)
+
+ encodedEntry := base64.StdEncoding.EncodeToString([]byte(entry))
+ jsonParams := fmt.Sprintf(`{"Archive":"%s","Entry":"%s"}`, archive, encodedEntry)
+ data := base64.URLEncoding.EncodeToString([]byte(jsonParams))
+
+ SendEntry.Inject(w, r, data)
+ })
+
+ httpRequest, err := http.NewRequest("GET", "/url/path", nil)
+ require.NoError(t, err)
+ response := httptest.NewRecorder()
+ mux.ServeHTTP(response, httpRequest)
+ return response
+}
+
+func TestDownloadingFromValidArchive(t *testing.T) {
+ tempFile, err := ioutil.TempFile("", "uploads")
+ require.NoError(t, err)
+ defer tempFile.Close()
+ defer os.Remove(tempFile.Name())
+
+ archive := zip.NewWriter(tempFile)
+ defer archive.Close()
+ fileInArchive, err := archive.Create("test.txt")
+ require.NoError(t, err)
+ fmt.Fprint(fileInArchive, "testtest")
+ archive.Close()
+
+ response := testEntryServer(t, tempFile.Name(), "test.txt")
+
+ require.Equal(t, 200, response.Code)
+
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Type",
+ "text/plain; charset=utf-8")
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Disposition",
+ "attachment; filename=\"test.txt\"")
+
+ testhelper.RequireResponseBody(t, response, "testtest")
+}
+
+func TestDownloadingFromValidHTTPArchive(t *testing.T) {
+ tempDir, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+ defer os.RemoveAll(tempDir)
+
+ f, err := os.Create(filepath.Join(tempDir, "archive.zip"))
+ require.NoError(t, err)
+ defer f.Close()
+
+ archive := zip.NewWriter(f)
+ defer archive.Close()
+ fileInArchive, err := archive.Create("test.txt")
+ require.NoError(t, err)
+ fmt.Fprint(fileInArchive, "testtest")
+ archive.Close()
+ f.Close()
+
+ fileServer := httptest.NewServer(http.FileServer(http.Dir(tempDir)))
+ defer fileServer.Close()
+
+ response := testEntryServer(t, fileServer.URL+"/archive.zip", "test.txt")
+
+ require.Equal(t, 200, response.Code)
+
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Type",
+ "text/plain; charset=utf-8")
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Disposition",
+ "attachment; filename=\"test.txt\"")
+
+ testhelper.RequireResponseBody(t, response, "testtest")
+}
+
+func TestDownloadingNonExistingFile(t *testing.T) {
+ tempFile, err := ioutil.TempFile("", "uploads")
+ require.NoError(t, err)
+ defer tempFile.Close()
+ defer os.Remove(tempFile.Name())
+
+ archive := zip.NewWriter(tempFile)
+ defer archive.Close()
+ archive.Close()
+
+ response := testEntryServer(t, tempFile.Name(), "test")
+ require.Equal(t, 404, response.Code)
+}
+
+func TestDownloadingFromInvalidArchive(t *testing.T) {
+ response := testEntryServer(t, "path/to/non/existing/file", "test")
+ require.Equal(t, 404, response.Code)
+}
+
+func TestIncompleteApiResponse(t *testing.T) {
+ response := testEntryServer(t, "", "")
+ require.Equal(t, 500, response.Code)
+}
+
+func TestDownloadingFromNonExistingHTTPArchive(t *testing.T) {
+ tempDir, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+ defer os.RemoveAll(tempDir)
+
+ fileServer := httptest.NewServer(http.FileServer(http.Dir(tempDir)))
+ defer fileServer.Close()
+
+ response := testEntryServer(t, fileServer.URL+"/not-existing-archive-file.zip", "test.txt")
+
+ require.Equal(t, 404, response.Code)
+}
diff --git a/workhorse/internal/artifacts/escape_quotes.go b/workhorse/internal/artifacts/escape_quotes.go
new file mode 100644
index 00000000000..94db2be39b7
--- /dev/null
+++ b/workhorse/internal/artifacts/escape_quotes.go
@@ -0,0 +1,10 @@
+package artifacts
+
+import "strings"
+
+// taken from mime/multipart/writer.go
+var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
+
+func escapeQuotes(s string) string {
+ return quoteEscaper.Replace(s)
+}
diff --git a/workhorse/internal/badgateway/roundtripper.go b/workhorse/internal/badgateway/roundtripper.go
new file mode 100644
index 00000000000..a36cc9f4a9a
--- /dev/null
+++ b/workhorse/internal/badgateway/roundtripper.go
@@ -0,0 +1,115 @@
+package badgateway
+
+import (
+ "bytes"
+ "fmt"
+ "html/template"
+ "io/ioutil"
+ "net/http"
+ "strings"
+ "time"
+
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+// Error is a custom error for pretty Sentry 'issues'
+type sentryError struct{ error }
+
+type roundTripper struct {
+ next http.RoundTripper
+ developmentMode bool
+}
+
+// NewRoundTripper creates a RoundTripper with a provided underlying next transport
+func NewRoundTripper(developmentMode bool, next http.RoundTripper) http.RoundTripper {
+ return &roundTripper{next: next, developmentMode: developmentMode}
+}
+
+func (t *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
+ start := time.Now()
+
+ res, err := t.next.RoundTrip(r)
+ if err == nil {
+ return res, err
+ }
+
+ // httputil.ReverseProxy translates all errors from this
+ // RoundTrip function into 500 errors. But the most likely error
+ // is that the Rails app is not responding, in which case users
+ // and administrators expect to see a 502 error. To show 502s
+ // instead of 500s we catch the RoundTrip error here and inject a
+ // 502 response.
+ fields := log.Fields{"duration_ms": int64(time.Since(start).Seconds() * 1000)}
+ helper.LogErrorWithFields(
+ r,
+ &sentryError{fmt.Errorf("badgateway: failed to receive response: %v", err)},
+ fields,
+ )
+
+ injectedResponse := &http.Response{
+ StatusCode: http.StatusBadGateway,
+ Status: http.StatusText(http.StatusBadGateway),
+
+ Request: r,
+ ProtoMajor: r.ProtoMajor,
+ ProtoMinor: r.ProtoMinor,
+ Proto: r.Proto,
+ Header: make(http.Header),
+ Trailer: make(http.Header),
+ }
+
+ message := "GitLab is not responding"
+ contentType := "text/plain"
+ if t.developmentMode {
+ message, contentType = developmentModeResponse(err)
+ }
+
+ injectedResponse.Body = ioutil.NopCloser(strings.NewReader(message))
+ injectedResponse.Header.Set("Content-Type", contentType)
+
+ return injectedResponse, nil
+}
+
+func developmentModeResponse(err error) (body string, contentType string) {
+ data := TemplateData{
+ Time: time.Now().Format("15:04:05"),
+ Error: err.Error(),
+ ReloadSeconds: 5,
+ }
+
+ buf := &bytes.Buffer{}
+ if err := developmentErrorTemplate.Execute(buf, data); err != nil {
+ return data.Error, "text/plain"
+ }
+
+ return buf.String(), "text/html"
+}
+
+type TemplateData struct {
+ Time string
+ Error string
+ ReloadSeconds int
+}
+
+var developmentErrorTemplate = template.Must(template.New("error502").Parse(`
+<html>
+<head>
+<title>502: GitLab is not responding</title>
+<script>
+window.setTimeout(function() { location.reload() }, {{.ReloadSeconds}} * 1000)
+</script>
+</head>
+
+<body>
+<h1>502</h1>
+<p>GitLab is not responding. The error was:</p>
+
+<pre>{{.Error}}</pre>
+
+<p>If you just started GDK it can take 60-300 seconds before GitLab has finished booting. This page will automatically reload every {{.ReloadSeconds}} seconds.</p>
+<footer>Generated by gitlab-workhorse running in development mode at {{.Time}}.</footer>
+</body>
+</html>
+`))
diff --git a/workhorse/internal/badgateway/roundtripper_test.go b/workhorse/internal/badgateway/roundtripper_test.go
new file mode 100644
index 00000000000..fc7132f9bd7
--- /dev/null
+++ b/workhorse/internal/badgateway/roundtripper_test.go
@@ -0,0 +1,56 @@
+package badgateway
+
+import (
+ "errors"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+type roundtrip502 struct{}
+
+func (roundtrip502) RoundTrip(*http.Request) (*http.Response, error) {
+ return nil, errors.New("something went wrong")
+}
+
+func TestErrorPage502(t *testing.T) {
+ tests := []struct {
+ name string
+ devMode bool
+ contentType string
+ responseSnippet string
+ }{
+ {
+ name: "production mode",
+ contentType: "text/plain",
+ responseSnippet: "GitLab is not responding",
+ },
+ {
+ name: "development mode",
+ devMode: true,
+ contentType: "text/html",
+ responseSnippet: "This page will automatically reload",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ req, err := http.NewRequest("GET", "/", nil)
+ require.NoError(t, err, "build request")
+
+ rt := NewRoundTripper(tc.devMode, roundtrip502{})
+ response, err := rt.RoundTrip(req)
+ require.NoError(t, err, "perform roundtrip")
+ defer response.Body.Close()
+
+ body, err := ioutil.ReadAll(response.Body)
+ require.NoError(t, err)
+
+ require.Equal(t, tc.contentType, response.Header.Get("content-type"), "content type")
+ require.Equal(t, 502, response.StatusCode, "response status")
+ require.Contains(t, string(body), tc.responseSnippet)
+ })
+ }
+}
diff --git a/workhorse/internal/builds/register.go b/workhorse/internal/builds/register.go
new file mode 100644
index 00000000000..77685889cfd
--- /dev/null
+++ b/workhorse/internal/builds/register.go
@@ -0,0 +1,163 @@
+package builds
+
+import (
+ "encoding/json"
+ "errors"
+ "net/http"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/redis"
+)
+
+const (
+ maxRegisterBodySize = 32 * 1024
+ runnerBuildQueue = "runner:build_queue:"
+ runnerBuildQueueHeaderKey = "Gitlab-Ci-Builds-Polling"
+ runnerBuildQueueHeaderValue = "yes"
+)
+
+var (
+ registerHandlerRequests = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_builds_register_handler_requests",
+ Help: "Describes how many requests in different states hit a register handler",
+ },
+ []string{"status"},
+ )
+ registerHandlerOpen = promauto.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "gitlab_workhorse_builds_register_handler_open",
+ Help: "Describes how many requests is currently open in given state",
+ },
+ []string{"state"},
+ )
+
+ registerHandlerOpenAtReading = registerHandlerOpen.WithLabelValues("reading")
+ registerHandlerOpenAtProxying = registerHandlerOpen.WithLabelValues("proxying")
+ registerHandlerOpenAtWatching = registerHandlerOpen.WithLabelValues("watching")
+
+ registerHandlerBodyReadErrors = registerHandlerRequests.WithLabelValues("body-read-error")
+ registerHandlerBodyParseErrors = registerHandlerRequests.WithLabelValues("body-parse-error")
+ registerHandlerMissingValues = registerHandlerRequests.WithLabelValues("missing-values")
+ registerHandlerWatchErrors = registerHandlerRequests.WithLabelValues("watch-error")
+ registerHandlerAlreadyChangedRequests = registerHandlerRequests.WithLabelValues("already-changed")
+ registerHandlerSeenChangeRequests = registerHandlerRequests.WithLabelValues("seen-change")
+ registerHandlerTimeoutRequests = registerHandlerRequests.WithLabelValues("timeout")
+ registerHandlerNoChangeRequests = registerHandlerRequests.WithLabelValues("no-change")
+)
+
+type largeBodyError struct{ error }
+
+type WatchKeyHandler func(key, value string, timeout time.Duration) (redis.WatchKeyStatus, error)
+
+type runnerRequest struct {
+ Token string `json:"token,omitempty"`
+ LastUpdate string `json:"last_update,omitempty"`
+}
+
+func readRunnerBody(w http.ResponseWriter, r *http.Request) ([]byte, error) {
+ registerHandlerOpenAtReading.Inc()
+ defer registerHandlerOpenAtReading.Dec()
+
+ return helper.ReadRequestBody(w, r, maxRegisterBodySize)
+}
+
+func readRunnerRequest(r *http.Request, body []byte) (*runnerRequest, error) {
+ if !helper.IsApplicationJson(r) {
+ return nil, errors.New("invalid content-type received")
+ }
+
+ var runnerRequest runnerRequest
+ err := json.Unmarshal(body, &runnerRequest)
+ if err != nil {
+ return nil, err
+ }
+
+ return &runnerRequest, nil
+}
+
+func proxyRegisterRequest(h http.Handler, w http.ResponseWriter, r *http.Request) {
+ registerHandlerOpenAtProxying.Inc()
+ defer registerHandlerOpenAtProxying.Dec()
+
+ h.ServeHTTP(w, r)
+}
+
+func watchForRunnerChange(watchHandler WatchKeyHandler, token, lastUpdate string, duration time.Duration) (redis.WatchKeyStatus, error) {
+ registerHandlerOpenAtWatching.Inc()
+ defer registerHandlerOpenAtWatching.Dec()
+
+ return watchHandler(runnerBuildQueue+token, lastUpdate, duration)
+}
+
+func RegisterHandler(h http.Handler, watchHandler WatchKeyHandler, pollingDuration time.Duration) http.Handler {
+ if pollingDuration == 0 {
+ return h
+ }
+
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set(runnerBuildQueueHeaderKey, runnerBuildQueueHeaderValue)
+
+ requestBody, err := readRunnerBody(w, r)
+ if err != nil {
+ registerHandlerBodyReadErrors.Inc()
+ helper.RequestEntityTooLarge(w, r, &largeBodyError{err})
+ return
+ }
+
+ newRequest := helper.CloneRequestWithNewBody(r, requestBody)
+
+ runnerRequest, err := readRunnerRequest(r, requestBody)
+ if err != nil {
+ registerHandlerBodyParseErrors.Inc()
+ proxyRegisterRequest(h, w, newRequest)
+ return
+ }
+
+ if runnerRequest.Token == "" || runnerRequest.LastUpdate == "" {
+ registerHandlerMissingValues.Inc()
+ proxyRegisterRequest(h, w, newRequest)
+ return
+ }
+
+ result, err := watchForRunnerChange(watchHandler, runnerRequest.Token,
+ runnerRequest.LastUpdate, pollingDuration)
+ if err != nil {
+ registerHandlerWatchErrors.Inc()
+ proxyRegisterRequest(h, w, newRequest)
+ return
+ }
+
+ switch result {
+ // It means that we detected a change before starting watching on change,
+ // We proxy request to Rails, to see whether we have a build to receive
+ case redis.WatchKeyStatusAlreadyChanged:
+ registerHandlerAlreadyChangedRequests.Inc()
+ proxyRegisterRequest(h, w, newRequest)
+
+ // It means that we detected a change after watching.
+ // We could potentially proxy request to Rails, but...
+ // We can end-up with unreliable responses,
+ // as don't really know whether ResponseWriter is still in a sane state,
+ // for example the connection is dead
+ case redis.WatchKeyStatusSeenChange:
+ registerHandlerSeenChangeRequests.Inc()
+ w.WriteHeader(http.StatusNoContent)
+
+ // When we receive one of these statuses, it means that we detected no change,
+ // so we return to runner 204, which means nothing got changed,
+ // and there's no new builds to process
+ case redis.WatchKeyStatusTimeout:
+ registerHandlerTimeoutRequests.Inc()
+ w.WriteHeader(http.StatusNoContent)
+
+ case redis.WatchKeyStatusNoChange:
+ registerHandlerNoChangeRequests.Inc()
+ w.WriteHeader(http.StatusNoContent)
+ }
+ })
+}
diff --git a/workhorse/internal/builds/register_test.go b/workhorse/internal/builds/register_test.go
new file mode 100644
index 00000000000..a72d82dc2b7
--- /dev/null
+++ b/workhorse/internal/builds/register_test.go
@@ -0,0 +1,108 @@
+package builds
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/redis"
+)
+
+const upstreamResponseCode = 999
+
+func echoRequest(rw http.ResponseWriter, req *http.Request) {
+ rw.WriteHeader(upstreamResponseCode)
+ io.Copy(rw, req.Body)
+}
+
+var echoRequestFunc = http.HandlerFunc(echoRequest)
+
+func expectHandlerWithWatcher(t *testing.T, watchHandler WatchKeyHandler, data string, contentType string, expectedHttpStatus int, msgAndArgs ...interface{}) {
+ h := RegisterHandler(echoRequestFunc, watchHandler, time.Second)
+
+ rw := httptest.NewRecorder()
+ req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(data))
+ req.Header.Set("Content-Type", contentType)
+
+ h.ServeHTTP(rw, req)
+
+ require.Equal(t, expectedHttpStatus, rw.Code, msgAndArgs...)
+}
+
+func expectHandler(t *testing.T, data string, contentType string, expectedHttpStatus int, msgAndArgs ...interface{}) {
+ expectHandlerWithWatcher(t, nil, data, contentType, expectedHttpStatus, msgAndArgs...)
+}
+
+func TestRegisterHandlerLargeBody(t *testing.T) {
+ data := strings.Repeat(".", maxRegisterBodySize+5)
+ expectHandler(t, data, "application/json", http.StatusRequestEntityTooLarge,
+ "rejects body with entity too large")
+}
+
+func TestRegisterHandlerInvalidRunnerRequest(t *testing.T) {
+ expectHandler(t, "invalid content", "text/plain", upstreamResponseCode,
+ "proxies request to upstream")
+}
+
+func TestRegisterHandlerInvalidJsonPayload(t *testing.T) {
+ expectHandler(t, `{[`, "application/json", upstreamResponseCode,
+ "fails on parsing body and proxies request to upstream")
+}
+
+func TestRegisterHandlerMissingData(t *testing.T) {
+ testCases := []string{
+ `{"token":"token"}`,
+ `{"last_update":"data"}`,
+ }
+
+ for _, testCase := range testCases {
+ expectHandler(t, testCase, "application/json", upstreamResponseCode,
+ "fails on argument validation and proxies request to upstream")
+ }
+}
+
+func expectWatcherToBeExecuted(t *testing.T, watchKeyStatus redis.WatchKeyStatus, watchKeyError error,
+ httpStatus int, msgAndArgs ...interface{}) {
+ executed := false
+ watchKeyHandler := func(key, value string, timeout time.Duration) (redis.WatchKeyStatus, error) {
+ executed = true
+ return watchKeyStatus, watchKeyError
+ }
+
+ parsableData := `{"token":"token","last_update":"last_update"}`
+
+ expectHandlerWithWatcher(t, watchKeyHandler, parsableData, "application/json", httpStatus, msgAndArgs...)
+ require.True(t, executed, msgAndArgs...)
+}
+
+func TestRegisterHandlerWatcherError(t *testing.T) {
+ expectWatcherToBeExecuted(t, redis.WatchKeyStatusNoChange, errors.New("redis connection"),
+ upstreamResponseCode, "proxies data to upstream")
+}
+
+func TestRegisterHandlerWatcherAlreadyChanged(t *testing.T) {
+ expectWatcherToBeExecuted(t, redis.WatchKeyStatusAlreadyChanged, nil,
+ upstreamResponseCode, "proxies data to upstream")
+}
+
+func TestRegisterHandlerWatcherSeenChange(t *testing.T) {
+ expectWatcherToBeExecuted(t, redis.WatchKeyStatusSeenChange, nil,
+ http.StatusNoContent)
+}
+
+func TestRegisterHandlerWatcherTimeout(t *testing.T) {
+ expectWatcherToBeExecuted(t, redis.WatchKeyStatusTimeout, nil,
+ http.StatusNoContent)
+}
+
+func TestRegisterHandlerWatcherNoChange(t *testing.T) {
+ expectWatcherToBeExecuted(t, redis.WatchKeyStatusNoChange, nil,
+ http.StatusNoContent)
+}
diff --git a/workhorse/internal/channel/auth_checker.go b/workhorse/internal/channel/auth_checker.go
new file mode 100644
index 00000000000..f44850e0861
--- /dev/null
+++ b/workhorse/internal/channel/auth_checker.go
@@ -0,0 +1,69 @@
+package channel
+
+import (
+ "errors"
+ "net/http"
+ "time"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+)
+
+type AuthCheckerFunc func() *api.ChannelSettings
+
+// Regularly checks that authorization is still valid for a channel, outputting
+// to the stopper when it isn't
+type AuthChecker struct {
+ Checker AuthCheckerFunc
+ Template *api.ChannelSettings
+ StopCh chan error
+ Done chan struct{}
+ Count int64
+}
+
+var ErrAuthChanged = errors.New("connection closed: authentication changed or endpoint unavailable")
+
+func NewAuthChecker(f AuthCheckerFunc, template *api.ChannelSettings, stopCh chan error) *AuthChecker {
+ return &AuthChecker{
+ Checker: f,
+ Template: template,
+ StopCh: stopCh,
+ Done: make(chan struct{}),
+ }
+}
+func (c *AuthChecker) Loop(interval time.Duration) {
+ for {
+ select {
+ case <-time.After(interval):
+ settings := c.Checker()
+ if !c.Template.IsEqual(settings) {
+ c.StopCh <- ErrAuthChanged
+ return
+ }
+ c.Count = c.Count + 1
+ case <-c.Done:
+ return
+ }
+ }
+}
+
+func (c *AuthChecker) Close() error {
+ close(c.Done)
+ return nil
+}
+
+// Generates a CheckerFunc from an *api.API + request needing authorization
+func authCheckFunc(myAPI *api.API, r *http.Request, suffix string) AuthCheckerFunc {
+ return func() *api.ChannelSettings {
+ httpResponse, authResponse, err := myAPI.PreAuthorize(suffix, r)
+ if err != nil {
+ return nil
+ }
+ defer httpResponse.Body.Close()
+
+ if httpResponse.StatusCode != http.StatusOK || authResponse == nil {
+ return nil
+ }
+
+ return authResponse.Channel
+ }
+}
diff --git a/workhorse/internal/channel/auth_checker_test.go b/workhorse/internal/channel/auth_checker_test.go
new file mode 100644
index 00000000000..18beb45cf3a
--- /dev/null
+++ b/workhorse/internal/channel/auth_checker_test.go
@@ -0,0 +1,53 @@
+package channel
+
+import (
+ "testing"
+ "time"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+)
+
+func checkerSeries(values ...*api.ChannelSettings) AuthCheckerFunc {
+ return func() *api.ChannelSettings {
+ if len(values) == 0 {
+ return nil
+ }
+ out := values[0]
+ values = values[1:]
+ return out
+ }
+}
+
+func TestAuthCheckerStopsWhenAuthFails(t *testing.T) {
+ template := &api.ChannelSettings{Url: "ws://example.com"}
+ stopCh := make(chan error)
+ series := checkerSeries(template, template, template)
+ ac := NewAuthChecker(series, template, stopCh)
+
+ go ac.Loop(1 * time.Millisecond)
+ if err := <-stopCh; err != ErrAuthChanged {
+ t.Fatalf("Expected ErrAuthChanged, got %v", err)
+ }
+
+ if ac.Count != 3 {
+ t.Fatalf("Expected 3 successful checks, got %v", ac.Count)
+ }
+}
+
+func TestAuthCheckerStopsWhenAuthChanges(t *testing.T) {
+ template := &api.ChannelSettings{Url: "ws://example.com"}
+ changed := template.Clone()
+ changed.Url = "wss://example.com"
+ stopCh := make(chan error)
+ series := checkerSeries(template, changed, template)
+ ac := NewAuthChecker(series, template, stopCh)
+
+ go ac.Loop(1 * time.Millisecond)
+ if err := <-stopCh; err != ErrAuthChanged {
+ t.Fatalf("Expected ErrAuthChanged, got %v", err)
+ }
+
+ if ac.Count != 1 {
+ t.Fatalf("Expected 1 successful check, got %v", ac.Count)
+ }
+}
diff --git a/workhorse/internal/channel/channel.go b/workhorse/internal/channel/channel.go
new file mode 100644
index 00000000000..381ce95df82
--- /dev/null
+++ b/workhorse/internal/channel/channel.go
@@ -0,0 +1,132 @@
+package channel
+
+import (
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/gorilla/websocket"
+
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+var (
+ // See doc/channel.md for documentation of this subprotocol
+ subprotocols = []string{"terminal.gitlab.com", "base64.terminal.gitlab.com"}
+ upgrader = &websocket.Upgrader{Subprotocols: subprotocols}
+ ReauthenticationInterval = 5 * time.Minute
+ BrowserPingInterval = 30 * time.Second
+)
+
+func Handler(myAPI *api.API) http.Handler {
+ return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
+ if err := a.Channel.Validate(); err != nil {
+ helper.Fail500(w, r, err)
+ return
+ }
+
+ proxy := NewProxy(2) // two stoppers: auth checker, max time
+ checker := NewAuthChecker(
+ authCheckFunc(myAPI, r, "authorize"),
+ a.Channel,
+ proxy.StopCh,
+ )
+ defer checker.Close()
+ go checker.Loop(ReauthenticationInterval)
+ go closeAfterMaxTime(proxy, a.Channel.MaxSessionTime)
+
+ ProxyChannel(w, r, a.Channel, proxy)
+ }, "authorize")
+}
+
+func ProxyChannel(w http.ResponseWriter, r *http.Request, settings *api.ChannelSettings, proxy *Proxy) {
+ server, err := connectToServer(settings, r)
+ if err != nil {
+ helper.Fail500(w, r, err)
+ log.ContextLogger(r.Context()).WithError(err).Print("Channel: connecting to server failed")
+ return
+ }
+ defer server.UnderlyingConn().Close()
+ serverAddr := server.UnderlyingConn().RemoteAddr().String()
+
+ client, err := upgradeClient(w, r)
+ if err != nil {
+ log.ContextLogger(r.Context()).WithError(err).Print("Channel: upgrading client to websocket failed")
+ return
+ }
+
+ // Regularly send ping messages to the browser to keep the websocket from
+ // being timed out by intervening proxies.
+ go pingLoop(client)
+
+ defer client.UnderlyingConn().Close()
+ clientAddr := getClientAddr(r) // We can't know the port with confidence
+
+ logEntry := log.WithContextFields(r.Context(), log.Fields{
+ "clientAddr": clientAddr,
+ "serverAddr": serverAddr,
+ })
+
+ logEntry.Print("Channel: started proxying")
+
+ defer logEntry.Print("Channel: finished proxying")
+
+ if err := proxy.Serve(server, client, serverAddr, clientAddr); err != nil {
+ logEntry.WithError(err).Print("Channel: error proxying")
+ }
+}
+
+// In the future, we might want to look at X-Client-Ip or X-Forwarded-For
+func getClientAddr(r *http.Request) string {
+ return r.RemoteAddr
+}
+
+func upgradeClient(w http.ResponseWriter, r *http.Request) (Connection, error) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ return Wrap(conn, conn.Subprotocol()), nil
+}
+
+func pingLoop(conn Connection) {
+ for {
+ time.Sleep(BrowserPingInterval)
+ deadline := time.Now().Add(5 * time.Second)
+ if err := conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
+ // Either the connection was already closed so no further pings are
+ // needed, or this connection is now dead and no further pings can
+ // be sent.
+ break
+ }
+ }
+}
+
+func connectToServer(settings *api.ChannelSettings, r *http.Request) (Connection, error) {
+ settings = settings.Clone()
+
+ helper.SetForwardedFor(&settings.Header, r)
+
+ conn, _, err := settings.Dial()
+ if err != nil {
+ return nil, err
+ }
+
+ return Wrap(conn, conn.Subprotocol()), nil
+}
+
+func closeAfterMaxTime(proxy *Proxy, maxSessionTime int) {
+ if maxSessionTime == 0 {
+ return
+ }
+
+ <-time.After(time.Duration(maxSessionTime) * time.Second)
+ proxy.StopCh <- fmt.Errorf(
+ "connection closed: session time greater than maximum time allowed - %v seconds",
+ maxSessionTime,
+ )
+}
diff --git a/workhorse/internal/channel/proxy.go b/workhorse/internal/channel/proxy.go
new file mode 100644
index 00000000000..71f58092276
--- /dev/null
+++ b/workhorse/internal/channel/proxy.go
@@ -0,0 +1,56 @@
+package channel
+
+import (
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/gorilla/websocket"
+)
+
+// ANSI "end of channel" code
+var eot = []byte{0x04}
+
+// An abstraction of gorilla's *websocket.Conn
+type Connection interface {
+ UnderlyingConn() net.Conn
+ ReadMessage() (int, []byte, error)
+ WriteMessage(int, []byte) error
+ WriteControl(int, []byte, time.Time) error
+}
+
+type Proxy struct {
+ StopCh chan error
+}
+
+// stoppers is the number of goroutines that may attempt to call Stop()
+func NewProxy(stoppers int) *Proxy {
+ return &Proxy{
+ StopCh: make(chan error, stoppers+2), // each proxy() call is a stopper
+ }
+}
+
+func (p *Proxy) Serve(upstream, downstream Connection, upstreamAddr, downstreamAddr string) error {
+ // This signals the upstream channel to kill the exec'd process
+ defer upstream.WriteMessage(websocket.BinaryMessage, eot)
+
+ go p.proxy(upstream, downstream, upstreamAddr, downstreamAddr)
+ go p.proxy(downstream, upstream, downstreamAddr, upstreamAddr)
+
+ return <-p.StopCh
+}
+
+func (p *Proxy) proxy(to, from Connection, toAddr, fromAddr string) {
+ for {
+ messageType, data, err := from.ReadMessage()
+ if err != nil {
+ p.StopCh <- fmt.Errorf("reading from %s: %s", fromAddr, err)
+ break
+ }
+
+ if err := to.WriteMessage(messageType, data); err != nil {
+ p.StopCh <- fmt.Errorf("writing to %s: %s", toAddr, err)
+ break
+ }
+ }
+}
diff --git a/workhorse/internal/channel/wrappers.go b/workhorse/internal/channel/wrappers.go
new file mode 100644
index 00000000000..6fd955bedc7
--- /dev/null
+++ b/workhorse/internal/channel/wrappers.go
@@ -0,0 +1,134 @@
+package channel
+
+import (
+ "encoding/base64"
+ "net"
+ "time"
+
+ "github.com/gorilla/websocket"
+)
+
+func Wrap(conn Connection, subprotocol string) Connection {
+ switch subprotocol {
+ case "channel.k8s.io":
+ return &kubeWrapper{base64: false, conn: conn}
+ case "base64.channel.k8s.io":
+ return &kubeWrapper{base64: true, conn: conn}
+ case "terminal.gitlab.com":
+ return &gitlabWrapper{base64: false, conn: conn}
+ case "base64.terminal.gitlab.com":
+ return &gitlabWrapper{base64: true, conn: conn}
+ }
+
+ return conn
+}
+
+type kubeWrapper struct {
+ base64 bool
+ conn Connection
+}
+
+type gitlabWrapper struct {
+ base64 bool
+ conn Connection
+}
+
+func (w *gitlabWrapper) ReadMessage() (int, []byte, error) {
+ mt, data, err := w.conn.ReadMessage()
+ if err != nil {
+ return mt, data, err
+ }
+
+ if isData(mt) {
+ mt = websocket.BinaryMessage
+ if w.base64 {
+ data, err = decodeBase64(data)
+ }
+ }
+
+ return mt, data, err
+}
+
+func (w *gitlabWrapper) WriteMessage(mt int, data []byte) error {
+ if isData(mt) {
+ if w.base64 {
+ mt = websocket.TextMessage
+ data = encodeBase64(data)
+ } else {
+ mt = websocket.BinaryMessage
+ }
+ }
+
+ return w.conn.WriteMessage(mt, data)
+}
+
+func (w *gitlabWrapper) WriteControl(mt int, data []byte, deadline time.Time) error {
+ return w.conn.WriteControl(mt, data, deadline)
+}
+
+func (w *gitlabWrapper) UnderlyingConn() net.Conn {
+ return w.conn.UnderlyingConn()
+}
+
+// Coalesces all wsstreams into a single stream. In practice, we should only
+// receive data on stream 1.
+func (w *kubeWrapper) ReadMessage() (int, []byte, error) {
+ mt, data, err := w.conn.ReadMessage()
+ if err != nil {
+ return mt, data, err
+ }
+
+ if isData(mt) {
+ mt = websocket.BinaryMessage
+
+ // Remove the WSStream channel number, decode to raw
+ if len(data) > 0 {
+ data = data[1:]
+ if w.base64 {
+ data, err = decodeBase64(data)
+ }
+ }
+ }
+
+ return mt, data, err
+}
+
+// Always sends to wsstream 0
+func (w *kubeWrapper) WriteMessage(mt int, data []byte) error {
+ if isData(mt) {
+ if w.base64 {
+ mt = websocket.TextMessage
+ data = append([]byte{'0'}, encodeBase64(data)...)
+ } else {
+ mt = websocket.BinaryMessage
+ data = append([]byte{0}, data...)
+ }
+ }
+
+ return w.conn.WriteMessage(mt, data)
+}
+
+func (w *kubeWrapper) WriteControl(mt int, data []byte, deadline time.Time) error {
+ return w.conn.WriteControl(mt, data, deadline)
+}
+
+func (w *kubeWrapper) UnderlyingConn() net.Conn {
+ return w.conn.UnderlyingConn()
+}
+
+func isData(mt int) bool {
+ return mt == websocket.BinaryMessage || mt == websocket.TextMessage
+}
+
+func encodeBase64(data []byte) []byte {
+ buf := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
+ base64.StdEncoding.Encode(buf, data)
+
+ return buf
+}
+
+func decodeBase64(data []byte) ([]byte, error) {
+ buf := make([]byte, base64.StdEncoding.DecodedLen(len(data)))
+ n, err := base64.StdEncoding.Decode(buf, data)
+ return buf[:n], err
+}
diff --git a/workhorse/internal/channel/wrappers_test.go b/workhorse/internal/channel/wrappers_test.go
new file mode 100644
index 00000000000..1e0226f85d8
--- /dev/null
+++ b/workhorse/internal/channel/wrappers_test.go
@@ -0,0 +1,155 @@
+package channel
+
+import (
+ "bytes"
+ "errors"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/gorilla/websocket"
+)
+
+type testcase struct {
+ input *fakeConn
+ expected *fakeConn
+}
+
+type fakeConn struct {
+ // WebSocket message type
+ mt int
+ data []byte
+ err error
+}
+
+func (f *fakeConn) ReadMessage() (int, []byte, error) {
+ return f.mt, f.data, f.err
+}
+
+func (f *fakeConn) WriteMessage(mt int, data []byte) error {
+ f.mt = mt
+ f.data = data
+ return f.err
+}
+
+func (f *fakeConn) WriteControl(mt int, data []byte, _ time.Time) error {
+ f.mt = mt
+ f.data = data
+ return f.err
+}
+
+func (f *fakeConn) UnderlyingConn() net.Conn {
+ return nil
+}
+
+func fake(mt int, data []byte, err error) *fakeConn {
+ return &fakeConn{mt: mt, data: []byte(data), err: err}
+}
+
+var (
+ msg = []byte("foo bar")
+ msgBase64 = []byte("Zm9vIGJhcg==")
+ kubeMsg = append([]byte{0}, msg...)
+ kubeMsgBase64 = append([]byte{'0'}, msgBase64...)
+
+ errFake = errors.New("fake error")
+
+ text = websocket.TextMessage
+ binary = websocket.BinaryMessage
+ other = 999
+
+ fakeOther = fake(other, []byte("foo"), nil)
+)
+
+func requireEqualConn(t *testing.T, expected, actual *fakeConn, msg string, args ...interface{}) {
+ if expected.mt != actual.mt {
+ t.Logf("messageType expected to be %v but was %v", expected.mt, actual.mt)
+ t.Fatalf(msg, args...)
+ }
+
+ if !bytes.Equal(expected.data, actual.data) {
+ t.Logf("data expected to be %q but was %q: ", expected.data, actual.data)
+ t.Fatalf(msg, args...)
+ }
+
+ if expected.err != actual.err {
+ t.Logf("error expected to be %v but was %v", expected.err, actual.err)
+ t.Fatalf(msg, args...)
+ }
+}
+
+func TestReadMessage(t *testing.T) {
+ testCases := map[string][]testcase{
+ "channel.k8s.io": {
+ {fake(binary, kubeMsg, errFake), fake(binary, kubeMsg, errFake)},
+ {fake(binary, kubeMsg, nil), fake(binary, msg, nil)},
+ {fake(text, kubeMsg, nil), fake(binary, msg, nil)},
+ {fakeOther, fakeOther},
+ },
+ "base64.channel.k8s.io": {
+ {fake(text, kubeMsgBase64, errFake), fake(text, kubeMsgBase64, errFake)},
+ {fake(text, kubeMsgBase64, nil), fake(binary, msg, nil)},
+ {fake(binary, kubeMsgBase64, nil), fake(binary, msg, nil)},
+ {fakeOther, fakeOther},
+ },
+ "terminal.gitlab.com": {
+ {fake(binary, msg, errFake), fake(binary, msg, errFake)},
+ {fake(binary, msg, nil), fake(binary, msg, nil)},
+ {fake(text, msg, nil), fake(binary, msg, nil)},
+ {fakeOther, fakeOther},
+ },
+ "base64.terminal.gitlab.com": {
+ {fake(text, msgBase64, errFake), fake(text, msgBase64, errFake)},
+ {fake(text, msgBase64, nil), fake(binary, msg, nil)},
+ {fake(binary, msgBase64, nil), fake(binary, msg, nil)},
+ {fakeOther, fakeOther},
+ },
+ }
+
+ for subprotocol, cases := range testCases {
+ for i, tc := range cases {
+ conn := Wrap(tc.input, subprotocol)
+ mt, data, err := conn.ReadMessage()
+ actual := fake(mt, data, err)
+ requireEqualConn(t, tc.expected, actual, "%s test case %v", subprotocol, i)
+ }
+ }
+}
+
+func TestWriteMessage(t *testing.T) {
+ testCases := map[string][]testcase{
+ "channel.k8s.io": {
+ {fake(binary, msg, errFake), fake(binary, kubeMsg, errFake)},
+ {fake(binary, msg, nil), fake(binary, kubeMsg, nil)},
+ {fake(text, msg, nil), fake(binary, kubeMsg, nil)},
+ {fakeOther, fakeOther},
+ },
+ "base64.channel.k8s.io": {
+ {fake(binary, msg, errFake), fake(text, kubeMsgBase64, errFake)},
+ {fake(binary, msg, nil), fake(text, kubeMsgBase64, nil)},
+ {fake(text, msg, nil), fake(text, kubeMsgBase64, nil)},
+ {fakeOther, fakeOther},
+ },
+ "terminal.gitlab.com": {
+ {fake(binary, msg, errFake), fake(binary, msg, errFake)},
+ {fake(binary, msg, nil), fake(binary, msg, nil)},
+ {fake(text, msg, nil), fake(binary, msg, nil)},
+ {fakeOther, fakeOther},
+ },
+ "base64.terminal.gitlab.com": {
+ {fake(binary, msg, errFake), fake(text, msgBase64, errFake)},
+ {fake(binary, msg, nil), fake(text, msgBase64, nil)},
+ {fake(text, msg, nil), fake(text, msgBase64, nil)},
+ {fakeOther, fakeOther},
+ },
+ }
+
+ for subprotocol, cases := range testCases {
+ for i, tc := range cases {
+ actual := fake(0, nil, tc.input.err)
+ conn := Wrap(actual, subprotocol)
+ actual.err = conn.WriteMessage(tc.input.mt, tc.input.data)
+ requireEqualConn(t, tc.expected, actual, "%s test case %v", subprotocol, i)
+ }
+ }
+}
diff --git a/workhorse/internal/config/config.go b/workhorse/internal/config/config.go
new file mode 100644
index 00000000000..84849c72744
--- /dev/null
+++ b/workhorse/internal/config/config.go
@@ -0,0 +1,154 @@
+package config
+
+import (
+ "math"
+ "net/url"
+ "runtime"
+ "strings"
+ "time"
+
+ "github.com/Azure/azure-storage-blob-go/azblob"
+ "github.com/BurntSushi/toml"
+ "gitlab.com/gitlab-org/labkit/log"
+ "gocloud.dev/blob"
+ "gocloud.dev/blob/azureblob"
+)
+
+type TomlURL struct {
+ url.URL
+}
+
+func (u *TomlURL) UnmarshalText(text []byte) error {
+ temp, err := url.Parse(string(text))
+ u.URL = *temp
+ return err
+}
+
+type TomlDuration struct {
+ time.Duration
+}
+
+func (d *TomlDuration) UnmarshalTest(text []byte) error {
+ temp, err := time.ParseDuration(string(text))
+ d.Duration = temp
+ return err
+}
+
+type ObjectStorageCredentials struct {
+ Provider string
+
+ S3Credentials S3Credentials `toml:"s3"`
+ AzureCredentials AzureCredentials `toml:"azurerm"`
+}
+
+type ObjectStorageConfig struct {
+ URLMux *blob.URLMux `toml:"-"`
+}
+
+type S3Credentials struct {
+ AwsAccessKeyID string `toml:"aws_access_key_id"`
+ AwsSecretAccessKey string `toml:"aws_secret_access_key"`
+}
+
+type S3Config struct {
+ Region string `toml:"-"`
+ Bucket string `toml:"-"`
+ PathStyle bool `toml:"-"`
+ Endpoint string `toml:"-"`
+ UseIamProfile bool `toml:"-"`
+ ServerSideEncryption string `toml:"-"` // Server-side encryption mode (e.g. AES256, aws:kms)
+ SSEKMSKeyID string `toml:"-"` // Server-side encryption key-management service key ID (e.g. arn:aws:xxx)
+}
+
+type GoCloudConfig struct {
+ URL string `toml:"-"`
+}
+
+type AzureCredentials struct {
+ AccountName string `toml:"azure_storage_account_name"`
+ AccountKey string `toml:"azure_storage_access_key"`
+}
+
+type RedisConfig struct {
+ URL TomlURL
+ Sentinel []TomlURL
+ SentinelMaster string
+ Password string
+ DB *int
+ ReadTimeout *TomlDuration
+ WriteTimeout *TomlDuration
+ KeepAlivePeriod *TomlDuration
+ MaxIdle *int
+ MaxActive *int
+}
+
+type ImageResizerConfig struct {
+ MaxScalerProcs uint32 `toml:"max_scaler_procs"`
+ MaxFilesize uint64 `toml:"max_filesize"`
+}
+
+type Config struct {
+ Redis *RedisConfig `toml:"redis"`
+ Backend *url.URL `toml:"-"`
+ CableBackend *url.URL `toml:"-"`
+ Version string `toml:"-"`
+ DocumentRoot string `toml:"-"`
+ DevelopmentMode bool `toml:"-"`
+ Socket string `toml:"-"`
+ CableSocket string `toml:"-"`
+ ProxyHeadersTimeout time.Duration `toml:"-"`
+ APILimit uint `toml:"-"`
+ APIQueueLimit uint `toml:"-"`
+ APIQueueTimeout time.Duration `toml:"-"`
+ APICILongPollingDuration time.Duration `toml:"-"`
+ ObjectStorageConfig ObjectStorageConfig `toml:"-"`
+ ObjectStorageCredentials ObjectStorageCredentials `toml:"object_storage"`
+ PropagateCorrelationID bool `toml:"-"`
+ ImageResizerConfig ImageResizerConfig `toml:"image_resizer"`
+ AltDocumentRoot string `toml:"alt_document_root"`
+}
+
+var DefaultImageResizerConfig = ImageResizerConfig{
+ MaxScalerProcs: uint32(math.Max(2, float64(runtime.NumCPU())/2)),
+ MaxFilesize: 250 * 1000, // 250kB,
+}
+
+func LoadConfig(data string) (*Config, error) {
+ cfg := &Config{ImageResizerConfig: DefaultImageResizerConfig}
+
+ if _, err := toml.Decode(data, cfg); err != nil {
+ return nil, err
+ }
+
+ return cfg, nil
+}
+
+func (c *Config) RegisterGoCloudURLOpeners() error {
+ c.ObjectStorageConfig.URLMux = new(blob.URLMux)
+
+ creds := c.ObjectStorageCredentials
+ if strings.EqualFold(creds.Provider, "AzureRM") && creds.AzureCredentials.AccountName != "" && creds.AzureCredentials.AccountKey != "" {
+ accountName := azureblob.AccountName(creds.AzureCredentials.AccountName)
+ accountKey := azureblob.AccountKey(creds.AzureCredentials.AccountKey)
+
+ credential, err := azureblob.NewCredential(accountName, accountKey)
+ if err != nil {
+ log.WithError(err).Error("error creating Azure credentials")
+ return err
+ }
+
+ pipeline := azureblob.NewPipeline(credential, azblob.PipelineOptions{})
+
+ azureURLOpener := &azureURLOpener{
+ &azureblob.URLOpener{
+ AccountName: accountName,
+ Pipeline: pipeline,
+ Options: azureblob.Options{Credential: credential},
+ },
+ }
+
+ c.ObjectStorageConfig.URLMux.RegisterBucket(azureblob.Scheme, azureURLOpener)
+ }
+
+ return nil
+}
diff --git a/workhorse/internal/config/config_test.go b/workhorse/internal/config/config_test.go
new file mode 100644
index 00000000000..102b29a0813
--- /dev/null
+++ b/workhorse/internal/config/config_test.go
@@ -0,0 +1,111 @@
+package config
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+const azureConfig = `
+[object_storage]
+provider = "AzureRM"
+
+[object_storage.azurerm]
+azure_storage_account_name = "azuretester"
+azure_storage_access_key = "deadbeef"
+`
+
+func TestLoadEmptyConfig(t *testing.T) {
+ config := ``
+
+ cfg, err := LoadConfig(config)
+ require.NoError(t, err)
+
+ require.Empty(t, cfg.AltDocumentRoot)
+ require.Equal(t, cfg.ImageResizerConfig.MaxFilesize, uint64(250000))
+ require.GreaterOrEqual(t, cfg.ImageResizerConfig.MaxScalerProcs, uint32(2))
+
+ require.Equal(t, ObjectStorageCredentials{}, cfg.ObjectStorageCredentials)
+ require.NoError(t, cfg.RegisterGoCloudURLOpeners())
+}
+
+func TestLoadObjectStorageConfig(t *testing.T) {
+ config := `
+[object_storage]
+provider = "AWS"
+
+[object_storage.s3]
+aws_access_key_id = "minio"
+aws_secret_access_key = "gdk-minio"
+`
+
+ cfg, err := LoadConfig(config)
+ require.NoError(t, err)
+
+ require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials")
+
+ expected := ObjectStorageCredentials{
+ Provider: "AWS",
+ S3Credentials: S3Credentials{
+ AwsAccessKeyID: "minio",
+ AwsSecretAccessKey: "gdk-minio",
+ },
+ }
+
+ require.Equal(t, expected, cfg.ObjectStorageCredentials)
+}
+
+func TestRegisterGoCloudURLOpeners(t *testing.T) {
+ cfg, err := LoadConfig(azureConfig)
+ require.NoError(t, err)
+
+ require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials")
+
+ expected := ObjectStorageCredentials{
+ Provider: "AzureRM",
+ AzureCredentials: AzureCredentials{
+ AccountName: "azuretester",
+ AccountKey: "deadbeef",
+ },
+ }
+
+ require.Equal(t, expected, cfg.ObjectStorageCredentials)
+ require.Nil(t, cfg.ObjectStorageConfig.URLMux)
+
+ require.NoError(t, cfg.RegisterGoCloudURLOpeners())
+ require.NotNil(t, cfg.ObjectStorageConfig.URLMux)
+
+ require.True(t, cfg.ObjectStorageConfig.URLMux.ValidBucketScheme("azblob"))
+ require.Equal(t, []string{"azblob"}, cfg.ObjectStorageConfig.URLMux.BucketSchemes())
+}
+
+func TestLoadImageResizerConfig(t *testing.T) {
+ config := `
+[image_resizer]
+max_scaler_procs = 200
+max_filesize = 350000
+`
+
+ cfg, err := LoadConfig(config)
+ require.NoError(t, err)
+
+ require.NotNil(t, cfg.ImageResizerConfig, "Expected image resizer config")
+
+ expected := ImageResizerConfig{
+ MaxScalerProcs: 200,
+ MaxFilesize: 350000,
+ }
+
+ require.Equal(t, expected, cfg.ImageResizerConfig)
+}
+
+func TestAltDocumentConfig(t *testing.T) {
+ config := `
+alt_document_root = "/path/to/documents"
+`
+
+ cfg, err := LoadConfig(config)
+ require.NoError(t, err)
+
+ require.Equal(t, "/path/to/documents", cfg.AltDocumentRoot)
+}
diff --git a/workhorse/internal/config/url_openers.go b/workhorse/internal/config/url_openers.go
new file mode 100644
index 00000000000..d3c96ee9eef
--- /dev/null
+++ b/workhorse/internal/config/url_openers.go
@@ -0,0 +1,51 @@
+package config
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+
+ "gocloud.dev/blob"
+ "gocloud.dev/blob/azureblob"
+)
+
+// This code can be removed once https://github.com/google/go-cloud/pull/2851 is merged.
+
+// URLOpener opens Azure URLs like "azblob://mybucket".
+//
+// The URL host is used as the bucket name.
+//
+// The following query options are supported:
+// - domain: The domain name used to access the Azure Blob storage (e.g. blob.core.windows.net)
+type azureURLOpener struct {
+ *azureblob.URLOpener
+}
+
+func (o *azureURLOpener) OpenBucketURL(ctx context.Context, u *url.URL) (*blob.Bucket, error) {
+ opts := new(azureblob.Options)
+ *opts = o.Options
+
+ err := setOptionsFromURLParams(u.Query(), opts)
+ if err != nil {
+ return nil, err
+ }
+ return azureblob.OpenBucket(ctx, o.Pipeline, o.AccountName, u.Host, opts)
+}
+
+func setOptionsFromURLParams(q url.Values, opts *azureblob.Options) error {
+ for param, values := range q {
+ if len(values) > 1 {
+ return fmt.Errorf("multiple values of %v not allowed", param)
+ }
+
+ value := values[0]
+ switch param {
+ case "domain":
+ opts.StorageDomain = azureblob.StorageDomain(value)
+ default:
+ return fmt.Errorf("unknown query parameter %q", param)
+ }
+ }
+
+ return nil
+}
diff --git a/workhorse/internal/config/url_openers_test.go b/workhorse/internal/config/url_openers_test.go
new file mode 100644
index 00000000000..6a851cacbb8
--- /dev/null
+++ b/workhorse/internal/config/url_openers_test.go
@@ -0,0 +1,117 @@
+package config
+
+import (
+ "context"
+ "net/url"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gocloud.dev/blob/azureblob"
+)
+
+func TestURLOpeners(t *testing.T) {
+ cfg, err := LoadConfig(azureConfig)
+ require.NoError(t, err)
+
+ require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials")
+
+ require.NoError(t, cfg.RegisterGoCloudURLOpeners())
+ require.NotNil(t, cfg.ObjectStorageConfig.URLMux)
+
+ tests := []struct {
+ url string
+ valid bool
+ }{
+
+ {
+ url: "azblob://container/object",
+ valid: true,
+ },
+ {
+ url: "azblob://container/object?domain=core.windows.net",
+ valid: true,
+ },
+ {
+ url: "azblob://container/object?domain=core.windows.net&domain=test",
+ valid: false,
+ },
+ {
+ url: "azblob://container/object?param=value",
+ valid: false,
+ },
+ {
+ url: "s3://bucket/object",
+ valid: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.url, func(t *testing.T) {
+ ctx := context.Background()
+ url, err := url.Parse(test.url)
+ require.NoError(t, err)
+
+ bucket, err := cfg.ObjectStorageConfig.URLMux.OpenBucketURL(ctx, url)
+ if bucket != nil {
+ defer bucket.Close()
+ }
+
+ if test.valid {
+ require.NotNil(t, bucket)
+ require.NoError(t, err)
+ } else {
+ require.Error(t, err)
+ }
+ })
+ }
+}
+
+func TestTestURLOpenersForParams(t *testing.T) {
+ tests := []struct {
+ name string
+ currOpts azureblob.Options
+ query url.Values
+ wantOpts azureblob.Options
+ wantErr bool
+ }{
+ {
+ name: "InvalidParam",
+ query: url.Values{
+ "foo": {"bar"},
+ },
+ wantErr: true,
+ },
+ {
+ name: "StorageDomain",
+ query: url.Values{
+ "domain": {"blob.core.usgovcloudapi.net"},
+ },
+ wantOpts: azureblob.Options{StorageDomain: "blob.core.usgovcloudapi.net"},
+ },
+ {
+ name: "duplicate StorageDomain",
+ query: url.Values{
+ "domain": {"blob.core.usgovcloudapi.net", "blob.core.windows.net"},
+ },
+ wantErr: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ o := &azureURLOpener{
+ URLOpener: &azureblob.URLOpener{
+ Options: test.currOpts,
+ },
+ }
+ err := setOptionsFromURLParams(test.query, &o.Options)
+
+ if test.wantErr {
+ require.NotNil(t, err)
+ } else {
+ require.Nil(t, err)
+ require.Equal(t, test.wantOpts, o.Options)
+ }
+ })
+ }
+}
diff --git a/workhorse/internal/filestore/file_handler.go b/workhorse/internal/filestore/file_handler.go
new file mode 100644
index 00000000000..19764e9a5cf
--- /dev/null
+++ b/workhorse/internal/filestore/file_handler.go
@@ -0,0 +1,257 @@
+package filestore
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "strconv"
+ "time"
+
+ "github.com/dgrijalva/jwt-go"
+
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
+)
+
+type SizeError error
+
+// ErrEntityTooLarge means that the uploaded content is bigger then maximum allowed size
+var ErrEntityTooLarge = errors.New("entity is too large")
+
+// FileHandler represent a file that has been processed for upload
+// it may be either uploaded to an ObjectStore and/or saved on local path.
+type FileHandler struct {
+ // LocalPath is the path on the disk where file has been stored
+ LocalPath string
+
+ // RemoteID is the objectID provided by GitLab Rails
+ RemoteID string
+ // RemoteURL is ObjectStore URL provided by GitLab Rails
+ RemoteURL string
+
+ // Size is the persisted file size
+ Size int64
+
+ // Name is the resource name to send back to GitLab rails.
+ // It differ from the real file name in order to avoid file collisions
+ Name string
+
+ // a map containing different hashes
+ hashes map[string]string
+}
+
+type uploadClaims struct {
+ Upload map[string]string `json:"upload"`
+ jwt.StandardClaims
+}
+
+// SHA256 hash of the handled file
+func (fh *FileHandler) SHA256() string {
+ return fh.hashes["sha256"]
+}
+
+// MD5 hash of the handled file
+func (fh *FileHandler) MD5() string {
+ return fh.hashes["md5"]
+}
+
+// GitLabFinalizeFields returns a map with all the fields GitLab Rails needs in order to finalize the upload.
+func (fh *FileHandler) GitLabFinalizeFields(prefix string) (map[string]string, error) {
+ // TODO: remove `data` these once rails fully and exclusively support `signedData` (https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/263)
+ data := make(map[string]string)
+ signedData := make(map[string]string)
+ key := func(field string) string {
+ if prefix == "" {
+ return field
+ }
+
+ return fmt.Sprintf("%s.%s", prefix, field)
+ }
+
+ for k, v := range map[string]string{
+ "name": fh.Name,
+ "path": fh.LocalPath,
+ "remote_url": fh.RemoteURL,
+ "remote_id": fh.RemoteID,
+ "size": strconv.FormatInt(fh.Size, 10),
+ } {
+ data[key(k)] = v
+ signedData[k] = v
+ }
+
+ for hashName, hash := range fh.hashes {
+ data[key(hashName)] = hash
+ signedData[hashName] = hash
+ }
+
+ claims := uploadClaims{Upload: signedData, StandardClaims: secret.DefaultClaims}
+ jwtData, err := secret.JWTTokenString(claims)
+ if err != nil {
+ return nil, err
+ }
+ data[key("gitlab-workhorse-upload")] = jwtData
+
+ return data, nil
+}
+
+type consumer interface {
+ Consume(context.Context, io.Reader, time.Time) (int64, error)
+}
+
+// SaveFileFromReader persists the provided reader content to all the location specified in opts. A cleanup will be performed once ctx is Done
+// Make sure the provided context will not expire before finalizing upload with GitLab Rails.
+func SaveFileFromReader(ctx context.Context, reader io.Reader, size int64, opts *SaveFileOpts) (fh *FileHandler, err error) {
+ var uploadDestination consumer
+ fh = &FileHandler{
+ Name: opts.TempFilePrefix,
+ RemoteID: opts.RemoteID,
+ RemoteURL: opts.RemoteURL,
+ }
+ hashes := newMultiHash()
+ reader = io.TeeReader(reader, hashes.Writer)
+
+ var clientMode string
+
+ switch {
+ case opts.IsLocal():
+ clientMode = "local"
+ uploadDestination, err = fh.uploadLocalFile(ctx, opts)
+ case opts.UseWorkhorseClientEnabled() && opts.ObjectStorageConfig.IsGoCloud():
+ clientMode = fmt.Sprintf("go_cloud:%s", opts.ObjectStorageConfig.Provider)
+ p := &objectstore.GoCloudObjectParams{
+ Ctx: ctx,
+ Mux: opts.ObjectStorageConfig.URLMux,
+ BucketURL: opts.ObjectStorageConfig.GoCloudConfig.URL,
+ ObjectName: opts.RemoteTempObjectID,
+ }
+ uploadDestination, err = objectstore.NewGoCloudObject(p)
+ case opts.UseWorkhorseClientEnabled() && opts.ObjectStorageConfig.IsAWS() && opts.ObjectStorageConfig.IsValid():
+ clientMode = "s3"
+ uploadDestination, err = objectstore.NewS3Object(
+ opts.RemoteTempObjectID,
+ opts.ObjectStorageConfig.S3Credentials,
+ opts.ObjectStorageConfig.S3Config,
+ )
+ case opts.IsMultipart():
+ clientMode = "multipart"
+ uploadDestination, err = objectstore.NewMultipart(
+ opts.PresignedParts,
+ opts.PresignedCompleteMultipart,
+ opts.PresignedAbortMultipart,
+ opts.PresignedDelete,
+ opts.PutHeaders,
+ opts.PartSize,
+ )
+ default:
+ clientMode = "http"
+ uploadDestination, err = objectstore.NewObject(
+ opts.PresignedPut,
+ opts.PresignedDelete,
+ opts.PutHeaders,
+ size,
+ )
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ if opts.MaximumSize > 0 {
+ if size > opts.MaximumSize {
+ return nil, SizeError(fmt.Errorf("the upload size %d is over maximum of %d bytes", size, opts.MaximumSize))
+ }
+
+ hlr := &hardLimitReader{r: reader, n: opts.MaximumSize}
+ reader = hlr
+ defer func() {
+ if hlr.n < 0 {
+ err = ErrEntityTooLarge
+ }
+ }()
+ }
+
+ fh.Size, err = uploadDestination.Consume(ctx, reader, opts.Deadline)
+ if err != nil {
+ if err == objectstore.ErrNotEnoughParts {
+ err = ErrEntityTooLarge
+ }
+ return nil, err
+ }
+
+ if size != -1 && size != fh.Size {
+ return nil, SizeError(fmt.Errorf("expected %d bytes but got only %d", size, fh.Size))
+ }
+
+ logger := log.WithContextFields(ctx, log.Fields{
+ "copied_bytes": fh.Size,
+ "is_local": opts.IsLocal(),
+ "is_multipart": opts.IsMultipart(),
+ "is_remote": !opts.IsLocal(),
+ "remote_id": opts.RemoteID,
+ "temp_file_prefix": opts.TempFilePrefix,
+ "client_mode": clientMode,
+ })
+
+ if opts.IsLocal() {
+ logger = logger.WithField("local_temp_path", opts.LocalTempPath)
+ } else {
+ logger = logger.WithField("remote_temp_object", opts.RemoteTempObjectID)
+ }
+
+ logger.Info("saved file")
+ fh.hashes = hashes.finish()
+ return fh, nil
+}
+
+func (fh *FileHandler) uploadLocalFile(ctx context.Context, opts *SaveFileOpts) (consumer, error) {
+ // make sure TempFolder exists
+ err := os.MkdirAll(opts.LocalTempPath, 0700)
+ if err != nil {
+ return nil, fmt.Errorf("uploadLocalFile: mkdir %q: %v", opts.LocalTempPath, err)
+ }
+
+ file, err := ioutil.TempFile(opts.LocalTempPath, opts.TempFilePrefix)
+ if err != nil {
+ return nil, fmt.Errorf("uploadLocalFile: create file: %v", err)
+ }
+
+ go func() {
+ <-ctx.Done()
+ os.Remove(file.Name())
+ }()
+
+ fh.LocalPath = file.Name()
+ return &localUpload{file}, nil
+}
+
+type localUpload struct{ io.WriteCloser }
+
+func (loc *localUpload) Consume(_ context.Context, r io.Reader, _ time.Time) (int64, error) {
+ n, err := io.Copy(loc.WriteCloser, r)
+ errClose := loc.Close()
+ if err == nil {
+ err = errClose
+ }
+ return n, err
+}
+
+// SaveFileFromDisk open the local file fileName and calls SaveFileFromReader
+func SaveFileFromDisk(ctx context.Context, fileName string, opts *SaveFileOpts) (fh *FileHandler, err error) {
+ file, err := os.Open(fileName)
+ if err != nil {
+ return nil, err
+ }
+ defer file.Close()
+
+ fi, err := file.Stat()
+ if err != nil {
+ return nil, err
+ }
+
+ return SaveFileFromReader(ctx, file, fi.Size(), opts)
+}
diff --git a/workhorse/internal/filestore/file_handler_test.go b/workhorse/internal/filestore/file_handler_test.go
new file mode 100644
index 00000000000..e79e9d0f292
--- /dev/null
+++ b/workhorse/internal/filestore/file_handler_test.go
@@ -0,0 +1,551 @@
+package filestore_test
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/dgrijalva/jwt-go"
+ "github.com/stretchr/testify/require"
+ "gocloud.dev/blob"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+)
+
+func testDeadline() time.Time {
+ return time.Now().Add(filestore.DefaultObjectStoreTimeout)
+}
+
+func requireFileGetsRemovedAsync(t *testing.T, filePath string) {
+ var err error
+
+ // Poll because the file removal is async
+ for i := 0; i < 100; i++ {
+ _, err = os.Stat(filePath)
+ if err != nil {
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ }
+
+ require.True(t, os.IsNotExist(err), "File hasn't been deleted during cleanup")
+}
+
+func requireObjectStoreDeletedAsync(t *testing.T, expectedDeletes int, osStub *test.ObjectstoreStub) {
+ // Poll because the object removal is async
+ for i := 0; i < 100; i++ {
+ if osStub.DeletesCnt() == expectedDeletes {
+ break
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+
+ require.Equal(t, expectedDeletes, osStub.DeletesCnt(), "Object not deleted")
+}
+
+func TestSaveFileWrongSize(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpFolder)
+
+ opts := &filestore.SaveFileOpts{LocalTempPath: tmpFolder, TempFilePrefix: "test-file"}
+ fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize+1, opts)
+ require.Error(t, err)
+ _, isSizeError := err.(filestore.SizeError)
+ require.True(t, isSizeError, "Should fail with SizeError")
+ require.Nil(t, fh)
+}
+
+func TestSaveFileWithKnownSizeExceedLimit(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpFolder)
+
+ opts := &filestore.SaveFileOpts{LocalTempPath: tmpFolder, TempFilePrefix: "test-file", MaximumSize: test.ObjectSize - 1}
+ fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize, opts)
+ require.Error(t, err)
+ _, isSizeError := err.(filestore.SizeError)
+ require.True(t, isSizeError, "Should fail with SizeError")
+ require.Nil(t, fh)
+}
+
+func TestSaveFileWithUnknownSizeExceedLimit(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpFolder)
+
+ opts := &filestore.SaveFileOpts{LocalTempPath: tmpFolder, TempFilePrefix: "test-file", MaximumSize: test.ObjectSize - 1}
+ fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), -1, opts)
+ require.Equal(t, err, filestore.ErrEntityTooLarge)
+ require.Nil(t, fh)
+}
+
+func TestSaveFromDiskNotExistingFile(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ fh, err := filestore.SaveFileFromDisk(ctx, "/I/do/not/exist", &filestore.SaveFileOpts{})
+ require.Error(t, err, "SaveFileFromDisk should fail")
+ require.True(t, os.IsNotExist(err), "Provided file should not exists")
+ require.Nil(t, fh, "On error FileHandler should be nil")
+}
+
+func TestSaveFileWrongETag(t *testing.T) {
+ tests := []struct {
+ name string
+ multipart bool
+ }{
+ {name: "single part"},
+ {name: "multi part", multipart: true},
+ }
+
+ for _, spec := range tests {
+ t.Run(spec.name, func(t *testing.T) {
+ osStub, ts := test.StartObjectStoreWithCustomMD5(map[string]string{test.ObjectPath: "brokenMD5"})
+ defer ts.Close()
+
+ objectURL := ts.URL + test.ObjectPath
+
+ opts := &filestore.SaveFileOpts{
+ RemoteID: "test-file",
+ RemoteURL: objectURL,
+ PresignedPut: objectURL + "?Signature=ASignature",
+ PresignedDelete: objectURL + "?Signature=AnotherSignature",
+ Deadline: testDeadline(),
+ }
+ if spec.multipart {
+ opts.PresignedParts = []string{objectURL + "?partNumber=1"}
+ opts.PresignedCompleteMultipart = objectURL + "?Signature=CompleteSig"
+ opts.PresignedAbortMultipart = objectURL + "?Signature=AbortSig"
+ opts.PartSize = test.ObjectSize
+
+ osStub.InitiateMultipartUpload(test.ObjectPath)
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize, opts)
+ require.Nil(t, fh)
+ require.Error(t, err)
+ require.Equal(t, 1, osStub.PutsCnt(), "File not uploaded")
+
+ cancel() // this will trigger an async cleanup
+ requireObjectStoreDeletedAsync(t, 1, osStub)
+ require.False(t, spec.multipart && osStub.IsMultipartUpload(test.ObjectPath), "there must be no multipart upload in progress now")
+ })
+ }
+}
+
+func TestSaveFileFromDiskToLocalPath(t *testing.T) {
+ f, err := ioutil.TempFile("", "workhorse-test")
+ require.NoError(t, err)
+ defer os.Remove(f.Name())
+
+ _, err = fmt.Fprint(f, test.ObjectContent)
+ require.NoError(t, err)
+
+ tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpFolder)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ opts := &filestore.SaveFileOpts{LocalTempPath: tmpFolder}
+ fh, err := filestore.SaveFileFromDisk(ctx, f.Name(), opts)
+ require.NoError(t, err)
+ require.NotNil(t, fh)
+
+ require.NotEmpty(t, fh.LocalPath, "File not persisted on disk")
+ _, err = os.Stat(fh.LocalPath)
+ require.NoError(t, err)
+}
+
+func TestSaveFile(t *testing.T) {
+ testhelper.ConfigureSecret()
+
+ type remote int
+ const (
+ notRemote remote = iota
+ remoteSingle
+ remoteMultipart
+ )
+
+ tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpFolder)
+
+ tests := []struct {
+ name string
+ local bool
+ remote remote
+ }{
+ {name: "Local only", local: true},
+ {name: "Remote Single only", remote: remoteSingle},
+ {name: "Remote Multipart only", remote: remoteMultipart},
+ }
+
+ for _, spec := range tests {
+ t.Run(spec.name, func(t *testing.T) {
+ var opts filestore.SaveFileOpts
+ var expectedDeletes, expectedPuts int
+
+ osStub, ts := test.StartObjectStore()
+ defer ts.Close()
+
+ switch spec.remote {
+ case remoteSingle:
+ objectURL := ts.URL + test.ObjectPath
+
+ opts.RemoteID = "test-file"
+ opts.RemoteURL = objectURL
+ opts.PresignedPut = objectURL + "?Signature=ASignature"
+ opts.PresignedDelete = objectURL + "?Signature=AnotherSignature"
+ opts.Deadline = testDeadline()
+
+ expectedDeletes = 1
+ expectedPuts = 1
+ case remoteMultipart:
+ objectURL := ts.URL + test.ObjectPath
+
+ opts.RemoteID = "test-file"
+ opts.RemoteURL = objectURL
+ opts.PresignedDelete = objectURL + "?Signature=AnotherSignature"
+ opts.PartSize = int64(len(test.ObjectContent)/2) + 1
+ opts.PresignedParts = []string{objectURL + "?partNumber=1", objectURL + "?partNumber=2"}
+ opts.PresignedCompleteMultipart = objectURL + "?Signature=CompleteSignature"
+ opts.Deadline = testDeadline()
+
+ osStub.InitiateMultipartUpload(test.ObjectPath)
+ expectedDeletes = 1
+ expectedPuts = 2
+ }
+
+ if spec.local {
+ opts.LocalTempPath = tmpFolder
+ opts.TempFilePrefix = "test-file"
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize, &opts)
+ require.NoError(t, err)
+ require.NotNil(t, fh)
+
+ require.Equal(t, opts.RemoteID, fh.RemoteID)
+ require.Equal(t, opts.RemoteURL, fh.RemoteURL)
+
+ if spec.local {
+ require.NotEmpty(t, fh.LocalPath, "File not persisted on disk")
+ _, err := os.Stat(fh.LocalPath)
+ require.NoError(t, err)
+
+ dir := path.Dir(fh.LocalPath)
+ require.Equal(t, opts.LocalTempPath, dir)
+ filename := path.Base(fh.LocalPath)
+ beginsWithPrefix := strings.HasPrefix(filename, opts.TempFilePrefix)
+ require.True(t, beginsWithPrefix, fmt.Sprintf("LocalPath filename %q do not begin with TempFilePrefix %q", filename, opts.TempFilePrefix))
+ } else {
+ require.Empty(t, fh.LocalPath, "LocalPath must be empty for non local uploads")
+ }
+
+ require.Equal(t, test.ObjectSize, fh.Size)
+ require.Equal(t, test.ObjectMD5, fh.MD5())
+ require.Equal(t, test.ObjectSHA256, fh.SHA256())
+
+ require.Equal(t, expectedPuts, osStub.PutsCnt(), "ObjectStore PutObject count mismatch")
+ require.Equal(t, 0, osStub.DeletesCnt(), "File deleted too early")
+
+ cancel() // this will trigger an async cleanup
+ requireObjectStoreDeletedAsync(t, expectedDeletes, osStub)
+ requireFileGetsRemovedAsync(t, fh.LocalPath)
+
+ // checking generated fields
+ fields, err := fh.GitLabFinalizeFields("file")
+ require.NoError(t, err)
+
+ checkFileHandlerWithFields(t, fh, fields, "file")
+
+ token, jwtErr := jwt.ParseWithClaims(fields["file.gitlab-workhorse-upload"], &testhelper.UploadClaims{}, testhelper.ParseJWT)
+ require.NoError(t, jwtErr)
+
+ uploadFields := token.Claims.(*testhelper.UploadClaims).Upload
+
+ checkFileHandlerWithFields(t, fh, uploadFields, "")
+ })
+ }
+}
+
+func TestSaveFileWithS3WorkhorseClient(t *testing.T) {
+ tests := []struct {
+ name string
+ objectSize int64
+ maxSize int64
+ expectedErr error
+ }{
+ {
+ name: "known size with no limit",
+ objectSize: test.ObjectSize,
+ },
+ {
+ name: "unknown size with no limit",
+ objectSize: -1,
+ },
+ {
+ name: "unknown object size with limit",
+ objectSize: -1,
+ maxSize: test.ObjectSize - 1,
+ expectedErr: filestore.ErrEntityTooLarge,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+
+ s3Creds, s3Config, sess, ts := test.SetupS3(t, "")
+ defer ts.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ remoteObject := "tmp/test-file/1"
+ opts := filestore.SaveFileOpts{
+ RemoteID: "test-file",
+ Deadline: testDeadline(),
+ UseWorkhorseClient: true,
+ RemoteTempObjectID: remoteObject,
+ ObjectStorageConfig: filestore.ObjectStorageConfig{
+ Provider: "AWS",
+ S3Credentials: s3Creds,
+ S3Config: s3Config,
+ },
+ MaximumSize: tc.maxSize,
+ }
+
+ _, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), tc.objectSize, &opts)
+
+ if tc.expectedErr == nil {
+ require.NoError(t, err)
+ test.S3ObjectExists(t, sess, s3Config, remoteObject, test.ObjectContent)
+ } else {
+ require.Equal(t, tc.expectedErr, err)
+ test.S3ObjectDoesNotExist(t, sess, s3Config, remoteObject)
+ }
+ })
+ }
+}
+
+func TestSaveFileWithAzureWorkhorseClient(t *testing.T) {
+ mux, bucketDir, cleanup := test.SetupGoCloudFileBucket(t, "azblob")
+ defer cleanup()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ remoteObject := "tmp/test-file/1"
+ opts := filestore.SaveFileOpts{
+ RemoteID: "test-file",
+ Deadline: testDeadline(),
+ UseWorkhorseClient: true,
+ RemoteTempObjectID: remoteObject,
+ ObjectStorageConfig: filestore.ObjectStorageConfig{
+ Provider: "AzureRM",
+ URLMux: mux,
+ GoCloudConfig: config.GoCloudConfig{URL: "azblob://test-container"},
+ },
+ }
+
+ _, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize, &opts)
+ require.NoError(t, err)
+
+ test.GoCloudObjectExists(t, bucketDir, remoteObject)
+}
+
+func TestSaveFileWithUnknownGoCloudScheme(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ mux := new(blob.URLMux)
+
+ remoteObject := "tmp/test-file/1"
+ opts := filestore.SaveFileOpts{
+ RemoteID: "test-file",
+ Deadline: testDeadline(),
+ UseWorkhorseClient: true,
+ RemoteTempObjectID: remoteObject,
+ ObjectStorageConfig: filestore.ObjectStorageConfig{
+ Provider: "SomeCloud",
+ URLMux: mux,
+ GoCloudConfig: config.GoCloudConfig{URL: "foo://test-container"},
+ },
+ }
+
+ _, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize, &opts)
+ require.Error(t, err)
+}
+
+func TestSaveMultipartInBodyFailure(t *testing.T) {
+ osStub, ts := test.StartObjectStore()
+ defer ts.Close()
+
+ // this is a broken path because it contains bucket name but no key
+ // this is the only way to get an in-body failure from our ObjectStoreStub
+ objectPath := "/bucket-but-no-object-key"
+ objectURL := ts.URL + objectPath
+ opts := filestore.SaveFileOpts{
+ RemoteID: "test-file",
+ RemoteURL: objectURL,
+ PartSize: test.ObjectSize,
+ PresignedParts: []string{objectURL + "?partNumber=1", objectURL + "?partNumber=2"},
+ PresignedCompleteMultipart: objectURL + "?Signature=CompleteSignature",
+ Deadline: testDeadline(),
+ }
+
+ osStub.InitiateMultipartUpload(objectPath)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(test.ObjectContent), test.ObjectSize, &opts)
+ require.Nil(t, fh)
+ require.Error(t, err)
+ require.EqualError(t, err, test.MultipartUploadInternalError().Error())
+}
+
+func TestSaveRemoteFileWithLimit(t *testing.T) {
+ testhelper.ConfigureSecret()
+
+ type remote int
+ const (
+ notRemote remote = iota
+ remoteSingle
+ remoteMultipart
+ )
+
+ remoteTypes := []remote{remoteSingle, remoteMultipart}
+
+ tests := []struct {
+ name string
+ objectSize int64
+ maxSize int64
+ expectedErr error
+ testData string
+ }{
+ {
+ name: "known size with no limit",
+ testData: test.ObjectContent,
+ objectSize: test.ObjectSize,
+ },
+ {
+ name: "unknown size with no limit",
+ testData: test.ObjectContent,
+ objectSize: -1,
+ },
+ {
+ name: "unknown object size with limit",
+ testData: test.ObjectContent,
+ objectSize: -1,
+ maxSize: test.ObjectSize - 1,
+ expectedErr: filestore.ErrEntityTooLarge,
+ },
+ {
+ name: "large object with unknown size with limit",
+ testData: string(make([]byte, 20000)),
+ objectSize: -1,
+ maxSize: 19000,
+ expectedErr: filestore.ErrEntityTooLarge,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ var opts filestore.SaveFileOpts
+
+ for _, remoteType := range remoteTypes {
+ tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpFolder)
+
+ osStub, ts := test.StartObjectStore()
+ defer ts.Close()
+
+ switch remoteType {
+ case remoteSingle:
+ objectURL := ts.URL + test.ObjectPath
+
+ opts.RemoteID = "test-file"
+ opts.RemoteURL = objectURL
+ opts.PresignedPut = objectURL + "?Signature=ASignature"
+ opts.PresignedDelete = objectURL + "?Signature=AnotherSignature"
+ opts.Deadline = testDeadline()
+ opts.MaximumSize = tc.maxSize
+ case remoteMultipart:
+ objectURL := ts.URL + test.ObjectPath
+
+ opts.RemoteID = "test-file"
+ opts.RemoteURL = objectURL
+ opts.PresignedDelete = objectURL + "?Signature=AnotherSignature"
+ opts.PartSize = int64(len(tc.testData)/2) + 1
+ opts.PresignedParts = []string{objectURL + "?partNumber=1", objectURL + "?partNumber=2"}
+ opts.PresignedCompleteMultipart = objectURL + "?Signature=CompleteSignature"
+ opts.Deadline = testDeadline()
+ opts.MaximumSize = tc.maxSize
+
+ require.Less(t, int64(len(tc.testData)), int64(len(opts.PresignedParts))*opts.PartSize, "check part size calculation")
+
+ osStub.InitiateMultipartUpload(test.ObjectPath)
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(tc.testData), tc.objectSize, &opts)
+
+ if tc.expectedErr == nil {
+ require.NoError(t, err)
+ require.NotNil(t, fh)
+ } else {
+ require.True(t, errors.Is(err, tc.expectedErr))
+ require.Nil(t, fh)
+ }
+ }
+ })
+ }
+}
+
+func checkFileHandlerWithFields(t *testing.T, fh *filestore.FileHandler, fields map[string]string, prefix string) {
+ key := func(field string) string {
+ if prefix == "" {
+ return field
+ }
+
+ return fmt.Sprintf("%s.%s", prefix, field)
+ }
+
+ require.Equal(t, fh.Name, fields[key("name")])
+ require.Equal(t, fh.LocalPath, fields[key("path")])
+ require.Equal(t, fh.RemoteURL, fields[key("remote_url")])
+ require.Equal(t, fh.RemoteID, fields[key("remote_id")])
+ require.Equal(t, strconv.FormatInt(test.ObjectSize, 10), fields[key("size")])
+ require.Equal(t, test.ObjectMD5, fields[key("md5")])
+ require.Equal(t, test.ObjectSHA1, fields[key("sha1")])
+ require.Equal(t, test.ObjectSHA256, fields[key("sha256")])
+ require.Equal(t, test.ObjectSHA512, fields[key("sha512")])
+}
diff --git a/workhorse/internal/filestore/multi_hash.go b/workhorse/internal/filestore/multi_hash.go
new file mode 100644
index 00000000000..40efd3a5c1f
--- /dev/null
+++ b/workhorse/internal/filestore/multi_hash.go
@@ -0,0 +1,48 @@
+package filestore
+
+import (
+ "crypto/md5"
+ "crypto/sha1"
+ "crypto/sha256"
+ "crypto/sha512"
+ "encoding/hex"
+ "hash"
+ "io"
+)
+
+var hashFactories = map[string](func() hash.Hash){
+ "md5": md5.New,
+ "sha1": sha1.New,
+ "sha256": sha256.New,
+ "sha512": sha512.New,
+}
+
+type multiHash struct {
+ io.Writer
+ hashes map[string]hash.Hash
+}
+
+func newMultiHash() (m *multiHash) {
+ m = &multiHash{}
+ m.hashes = make(map[string]hash.Hash)
+
+ var writers []io.Writer
+ for hash, hashFactory := range hashFactories {
+ writer := hashFactory()
+
+ m.hashes[hash] = writer
+ writers = append(writers, writer)
+ }
+
+ m.Writer = io.MultiWriter(writers...)
+ return m
+}
+
+func (m *multiHash) finish() map[string]string {
+ h := make(map[string]string)
+ for hashName, hash := range m.hashes {
+ checksum := hash.Sum(nil)
+ h[hashName] = hex.EncodeToString(checksum)
+ }
+ return h
+}
diff --git a/workhorse/internal/filestore/reader.go b/workhorse/internal/filestore/reader.go
new file mode 100644
index 00000000000..b1045b991fc
--- /dev/null
+++ b/workhorse/internal/filestore/reader.go
@@ -0,0 +1,17 @@
+package filestore
+
+import "io"
+
+type hardLimitReader struct {
+ r io.Reader
+ n int64
+}
+
+func (h *hardLimitReader) Read(p []byte) (int, error) {
+ nRead, err := h.r.Read(p)
+ h.n -= int64(nRead)
+ if h.n < 0 {
+ err = ErrEntityTooLarge
+ }
+ return nRead, err
+}
diff --git a/workhorse/internal/filestore/reader_test.go b/workhorse/internal/filestore/reader_test.go
new file mode 100644
index 00000000000..424d921ecaf
--- /dev/null
+++ b/workhorse/internal/filestore/reader_test.go
@@ -0,0 +1,46 @@
+package filestore
+
+import (
+ "fmt"
+ "io/ioutil"
+ "strings"
+ "testing"
+ "testing/iotest"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestHardLimitReader(t *testing.T) {
+ const text = "hello world"
+ r := iotest.OneByteReader(
+ &hardLimitReader{
+ r: strings.NewReader(text),
+ n: int64(len(text)),
+ },
+ )
+
+ out, err := ioutil.ReadAll(r)
+ require.NoError(t, err)
+ require.Equal(t, text, string(out))
+}
+
+func TestHardLimitReaderFail(t *testing.T) {
+ const text = "hello world"
+
+ for bufSize := len(text) / 2; bufSize < len(text)*2; bufSize++ {
+ t.Run(fmt.Sprintf("bufsize:%d", bufSize), func(t *testing.T) {
+ r := &hardLimitReader{
+ r: iotest.DataErrReader(strings.NewReader(text)),
+ n: int64(len(text)) - 1,
+ }
+ buf := make([]byte, bufSize)
+
+ var err error
+ for i := 0; err == nil && i < 1000; i++ {
+ _, err = r.Read(buf)
+ }
+
+ require.Equal(t, ErrEntityTooLarge, err)
+ })
+ }
+}
diff --git a/workhorse/internal/filestore/save_file_opts.go b/workhorse/internal/filestore/save_file_opts.go
new file mode 100644
index 00000000000..1eb708c3f55
--- /dev/null
+++ b/workhorse/internal/filestore/save_file_opts.go
@@ -0,0 +1,171 @@
+package filestore
+
+import (
+ "errors"
+ "strings"
+ "time"
+
+ "gocloud.dev/blob"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+)
+
+// DefaultObjectStoreTimeout is the timeout for ObjectStore upload operation
+const DefaultObjectStoreTimeout = 4 * time.Hour
+
+type ObjectStorageConfig struct {
+ Provider string
+
+ S3Credentials config.S3Credentials
+ S3Config config.S3Config
+
+ // GoCloud mux that maps azureblob:// and future URLs (e.g. s3://, gcs://, etc.) to a handler
+ URLMux *blob.URLMux
+
+ // Azure credentials are registered at startup in the GoCloud URLMux, so only the container name is needed
+ GoCloudConfig config.GoCloudConfig
+}
+
+// SaveFileOpts represents all the options available for saving a file to object store
+type SaveFileOpts struct {
+ // TempFilePrefix is the prefix used to create temporary local file
+ TempFilePrefix string
+ // LocalTempPath is the directory where to write a local copy of the file
+ LocalTempPath string
+ // RemoteID is the remote ObjectID provided by GitLab
+ RemoteID string
+ // RemoteURL is the final URL of the file
+ RemoteURL string
+ // PresignedPut is a presigned S3 PutObject compatible URL
+ PresignedPut string
+ // PresignedDelete is a presigned S3 DeleteObject compatible URL.
+ PresignedDelete string
+ // HTTP headers to be sent along with PUT request
+ PutHeaders map[string]string
+ // Whether to ignore Rails pre-signed URLs and have Workhorse directly access object storage provider
+ UseWorkhorseClient bool
+ // If UseWorkhorseClient is true, this is the temporary object name to store the file
+ RemoteTempObjectID string
+ // Workhorse object storage client (e.g. S3) parameters
+ ObjectStorageConfig ObjectStorageConfig
+ // Deadline it the S3 operation deadline, the upload will be aborted if not completed in time
+ Deadline time.Time
+ // The maximum accepted size in bytes of the upload
+ MaximumSize int64
+
+ //MultipartUpload parameters
+ // PartSize is the exact size of each uploaded part. Only the last one can be smaller
+ PartSize int64
+ // PresignedParts contains the presigned URLs for each part
+ PresignedParts []string
+ // PresignedCompleteMultipart is a presigned URL for CompleteMulipartUpload
+ PresignedCompleteMultipart string
+ // PresignedAbortMultipart is a presigned URL for AbortMultipartUpload
+ PresignedAbortMultipart string
+}
+
+// UseWorkhorseClientEnabled checks if the options require direct access to object storage
+func (s *SaveFileOpts) UseWorkhorseClientEnabled() bool {
+ return s.UseWorkhorseClient && s.ObjectStorageConfig.IsValid() && s.RemoteTempObjectID != ""
+}
+
+// IsLocal checks if the options require the writing of the file on disk
+func (s *SaveFileOpts) IsLocal() bool {
+ return s.LocalTempPath != ""
+}
+
+// IsMultipart checks if the options requires a Multipart upload
+func (s *SaveFileOpts) IsMultipart() bool {
+ return s.PartSize > 0
+}
+
+// GetOpts converts GitLab api.Response to a proper SaveFileOpts
+func GetOpts(apiResponse *api.Response) (*SaveFileOpts, error) {
+ timeout := time.Duration(apiResponse.RemoteObject.Timeout) * time.Second
+ if timeout == 0 {
+ timeout = DefaultObjectStoreTimeout
+ }
+
+ opts := SaveFileOpts{
+ LocalTempPath: apiResponse.TempPath,
+ RemoteID: apiResponse.RemoteObject.ID,
+ RemoteURL: apiResponse.RemoteObject.GetURL,
+ PresignedPut: apiResponse.RemoteObject.StoreURL,
+ PresignedDelete: apiResponse.RemoteObject.DeleteURL,
+ PutHeaders: apiResponse.RemoteObject.PutHeaders,
+ UseWorkhorseClient: apiResponse.RemoteObject.UseWorkhorseClient,
+ RemoteTempObjectID: apiResponse.RemoteObject.RemoteTempObjectID,
+ Deadline: time.Now().Add(timeout),
+ MaximumSize: apiResponse.MaximumSize,
+ }
+
+ if opts.LocalTempPath != "" && opts.RemoteID != "" {
+ return nil, errors.New("API response has both TempPath and RemoteObject")
+ }
+
+ if opts.LocalTempPath == "" && opts.RemoteID == "" {
+ return nil, errors.New("API response has neither TempPath nor RemoteObject")
+ }
+
+ objectStorageParams := apiResponse.RemoteObject.ObjectStorage
+ if opts.UseWorkhorseClient && objectStorageParams != nil {
+ opts.ObjectStorageConfig.Provider = objectStorageParams.Provider
+ opts.ObjectStorageConfig.S3Config = objectStorageParams.S3Config
+ opts.ObjectStorageConfig.GoCloudConfig = objectStorageParams.GoCloudConfig
+ }
+
+ // Backwards compatibility to ensure API servers that do not include the
+ // CustomPutHeaders flag will default to the original content type.
+ if !apiResponse.RemoteObject.CustomPutHeaders {
+ opts.PutHeaders = make(map[string]string)
+ opts.PutHeaders["Content-Type"] = "application/octet-stream"
+ }
+
+ if multiParams := apiResponse.RemoteObject.MultipartUpload; multiParams != nil {
+ opts.PartSize = multiParams.PartSize
+ opts.PresignedCompleteMultipart = multiParams.CompleteURL
+ opts.PresignedAbortMultipart = multiParams.AbortURL
+ opts.PresignedParts = append([]string(nil), multiParams.PartURLs...)
+ }
+
+ return &opts, nil
+}
+
+func (c *ObjectStorageConfig) IsAWS() bool {
+ return strings.EqualFold(c.Provider, "AWS") || strings.EqualFold(c.Provider, "S3")
+}
+
+func (c *ObjectStorageConfig) IsAzure() bool {
+ return strings.EqualFold(c.Provider, "AzureRM")
+}
+
+func (c *ObjectStorageConfig) IsGoCloud() bool {
+ return c.GoCloudConfig.URL != ""
+}
+
+func (c *ObjectStorageConfig) IsValid() bool {
+ if c.IsAWS() {
+ return c.S3Config.Bucket != "" && c.S3Config.Region != "" && c.s3CredentialsValid()
+ } else if c.IsGoCloud() {
+ // We could parse and validate the URL, but GoCloud providers
+ // such as AzureRM don't have a fallback to normal HTTP, so we
+ // always want to try the GoCloud path if there is a URL.
+ return true
+ }
+
+ return false
+}
+
+func (c *ObjectStorageConfig) s3CredentialsValid() bool {
+ // We need to be able to distinguish between two cases of AWS access:
+ // 1. AWS access via key and secret, but credentials not configured in Workhorse
+ // 2. IAM instance profiles used
+ if c.S3Config.UseIamProfile {
+ return true
+ } else if c.S3Credentials.AwsAccessKeyID != "" && c.S3Credentials.AwsSecretAccessKey != "" {
+ return true
+ }
+
+ return false
+}
diff --git a/workhorse/internal/filestore/save_file_opts_test.go b/workhorse/internal/filestore/save_file_opts_test.go
new file mode 100644
index 00000000000..2d6cd683b51
--- /dev/null
+++ b/workhorse/internal/filestore/save_file_opts_test.go
@@ -0,0 +1,331 @@
+package filestore_test
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test"
+)
+
+func TestSaveFileOptsLocalAndRemote(t *testing.T) {
+ tests := []struct {
+ name string
+ localTempPath string
+ presignedPut string
+ partSize int64
+ isLocal bool
+ isRemote bool
+ isMultipart bool
+ }{
+ {
+ name: "Only LocalTempPath",
+ localTempPath: "/tmp",
+ isLocal: true,
+ },
+ {
+ name: "No paths",
+ },
+ {
+ name: "Only remoteUrl",
+ presignedPut: "http://example.com",
+ },
+ {
+ name: "Multipart",
+ partSize: 10,
+ isMultipart: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := filestore.SaveFileOpts{
+ LocalTempPath: test.localTempPath,
+ PresignedPut: test.presignedPut,
+ PartSize: test.partSize,
+ }
+
+ require.Equal(t, test.isLocal, opts.IsLocal(), "IsLocal() mismatch")
+ require.Equal(t, test.isMultipart, opts.IsMultipart(), "IsMultipart() mismatch")
+ })
+ }
+}
+
+func TestGetOpts(t *testing.T) {
+ tests := []struct {
+ name string
+ multipart *api.MultipartUploadParams
+ customPutHeaders bool
+ putHeaders map[string]string
+ }{
+ {
+ name: "Single upload",
+ }, {
+ name: "Multipart upload",
+ multipart: &api.MultipartUploadParams{
+ PartSize: 10,
+ CompleteURL: "http://complete",
+ AbortURL: "http://abort",
+ PartURLs: []string{"http://part1", "http://part2"},
+ },
+ },
+ {
+ name: "Single upload with custom content type",
+ customPutHeaders: true,
+ putHeaders: map[string]string{"Content-Type": "image/jpeg"},
+ }, {
+ name: "Multipart upload with custom content type",
+ multipart: &api.MultipartUploadParams{
+ PartSize: 10,
+ CompleteURL: "http://complete",
+ AbortURL: "http://abort",
+ PartURLs: []string{"http://part1", "http://part2"},
+ },
+ customPutHeaders: true,
+ putHeaders: map[string]string{"Content-Type": "image/jpeg"},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ apiResponse := &api.Response{
+ RemoteObject: api.RemoteObject{
+ Timeout: 10,
+ ID: "id",
+ GetURL: "http://get",
+ StoreURL: "http://store",
+ DeleteURL: "http://delete",
+ MultipartUpload: test.multipart,
+ CustomPutHeaders: test.customPutHeaders,
+ PutHeaders: test.putHeaders,
+ },
+ }
+ deadline := time.Now().Add(time.Duration(apiResponse.RemoteObject.Timeout) * time.Second)
+ opts, err := filestore.GetOpts(apiResponse)
+ require.NoError(t, err)
+
+ require.Equal(t, apiResponse.TempPath, opts.LocalTempPath)
+ require.WithinDuration(t, deadline, opts.Deadline, time.Second)
+ require.Equal(t, apiResponse.RemoteObject.ID, opts.RemoteID)
+ require.Equal(t, apiResponse.RemoteObject.GetURL, opts.RemoteURL)
+ require.Equal(t, apiResponse.RemoteObject.StoreURL, opts.PresignedPut)
+ require.Equal(t, apiResponse.RemoteObject.DeleteURL, opts.PresignedDelete)
+ if test.customPutHeaders {
+ require.Equal(t, opts.PutHeaders, apiResponse.RemoteObject.PutHeaders)
+ } else {
+ require.Equal(t, opts.PutHeaders, map[string]string{"Content-Type": "application/octet-stream"})
+ }
+
+ if test.multipart == nil {
+ require.False(t, opts.IsMultipart())
+ require.Empty(t, opts.PresignedCompleteMultipart)
+ require.Empty(t, opts.PresignedAbortMultipart)
+ require.Zero(t, opts.PartSize)
+ require.Empty(t, opts.PresignedParts)
+ } else {
+ require.True(t, opts.IsMultipart())
+ require.Equal(t, test.multipart.CompleteURL, opts.PresignedCompleteMultipart)
+ require.Equal(t, test.multipart.AbortURL, opts.PresignedAbortMultipart)
+ require.Equal(t, test.multipart.PartSize, opts.PartSize)
+ require.Equal(t, test.multipart.PartURLs, opts.PresignedParts)
+ }
+ })
+ }
+}
+
+func TestGetOptsFail(t *testing.T) {
+ testCases := []struct {
+ desc string
+ in api.Response
+ }{
+ {
+ desc: "neither local nor remote",
+ in: api.Response{},
+ },
+ {
+ desc: "both local and remote",
+ in: api.Response{TempPath: "/foobar", RemoteObject: api.RemoteObject{ID: "id"}},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ _, err := filestore.GetOpts(&tc.in)
+ require.Error(t, err, "expect input to be rejected")
+ })
+ }
+}
+
+func TestGetOptsDefaultTimeout(t *testing.T) {
+ deadline := time.Now().Add(filestore.DefaultObjectStoreTimeout)
+ opts, err := filestore.GetOpts(&api.Response{TempPath: "/foo/bar"})
+ require.NoError(t, err)
+
+ require.WithinDuration(t, deadline, opts.Deadline, time.Minute)
+}
+
+func TestUseWorkhorseClientEnabled(t *testing.T) {
+ cfg := filestore.ObjectStorageConfig{
+ Provider: "AWS",
+ S3Config: config.S3Config{
+ Bucket: "test-bucket",
+ Region: "test-region",
+ },
+ S3Credentials: config.S3Credentials{
+ AwsAccessKeyID: "test-key",
+ AwsSecretAccessKey: "test-secret",
+ },
+ }
+
+ missingCfg := cfg
+ missingCfg.S3Credentials = config.S3Credentials{}
+
+ iamConfig := missingCfg
+ iamConfig.S3Config.UseIamProfile = true
+
+ tests := []struct {
+ name string
+ UseWorkhorseClient bool
+ remoteTempObjectID string
+ objectStorageConfig filestore.ObjectStorageConfig
+ expected bool
+ }{
+ {
+ name: "all direct access settings used",
+ UseWorkhorseClient: true,
+ remoteTempObjectID: "test-object",
+ objectStorageConfig: cfg,
+ expected: true,
+ },
+ {
+ name: "missing AWS credentials",
+ UseWorkhorseClient: true,
+ remoteTempObjectID: "test-object",
+ objectStorageConfig: missingCfg,
+ expected: false,
+ },
+ {
+ name: "direct access disabled",
+ UseWorkhorseClient: false,
+ remoteTempObjectID: "test-object",
+ objectStorageConfig: cfg,
+ expected: false,
+ },
+ {
+ name: "with IAM instance profile",
+ UseWorkhorseClient: true,
+ remoteTempObjectID: "test-object",
+ objectStorageConfig: iamConfig,
+ expected: true,
+ },
+ {
+ name: "missing remote temp object ID",
+ UseWorkhorseClient: true,
+ remoteTempObjectID: "",
+ objectStorageConfig: cfg,
+ expected: false,
+ },
+ {
+ name: "missing S3 config",
+ UseWorkhorseClient: true,
+ remoteTempObjectID: "test-object",
+ expected: false,
+ },
+ {
+ name: "missing S3 bucket",
+ UseWorkhorseClient: true,
+ remoteTempObjectID: "test-object",
+ objectStorageConfig: filestore.ObjectStorageConfig{
+ Provider: "AWS",
+ S3Config: config.S3Config{},
+ },
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ apiResponse := &api.Response{
+ RemoteObject: api.RemoteObject{
+ Timeout: 10,
+ ID: "id",
+ UseWorkhorseClient: test.UseWorkhorseClient,
+ RemoteTempObjectID: test.remoteTempObjectID,
+ },
+ }
+ deadline := time.Now().Add(time.Duration(apiResponse.RemoteObject.Timeout) * time.Second)
+ opts, err := filestore.GetOpts(apiResponse)
+ require.NoError(t, err)
+ opts.ObjectStorageConfig = test.objectStorageConfig
+
+ require.Equal(t, apiResponse.TempPath, opts.LocalTempPath)
+ require.WithinDuration(t, deadline, opts.Deadline, time.Second)
+ require.Equal(t, apiResponse.RemoteObject.ID, opts.RemoteID)
+ require.Equal(t, apiResponse.RemoteObject.UseWorkhorseClient, opts.UseWorkhorseClient)
+ require.Equal(t, test.expected, opts.UseWorkhorseClientEnabled())
+ })
+ }
+}
+
+func TestGoCloudConfig(t *testing.T) {
+ mux, _, cleanup := test.SetupGoCloudFileBucket(t, "azblob")
+ defer cleanup()
+
+ tests := []struct {
+ name string
+ provider string
+ url string
+ valid bool
+ }{
+ {
+ name: "valid AzureRM config",
+ provider: "AzureRM",
+ url: "azblob:://test-container",
+ valid: true,
+ },
+ {
+ name: "invalid GoCloud scheme",
+ provider: "AzureRM",
+ url: "unknown:://test-container",
+ valid: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ apiResponse := &api.Response{
+ RemoteObject: api.RemoteObject{
+ Timeout: 10,
+ ID: "id",
+ UseWorkhorseClient: true,
+ RemoteTempObjectID: "test-object",
+ ObjectStorage: &api.ObjectStorageParams{
+ Provider: test.provider,
+ GoCloudConfig: config.GoCloudConfig{
+ URL: test.url,
+ },
+ },
+ },
+ }
+ deadline := time.Now().Add(time.Duration(apiResponse.RemoteObject.Timeout) * time.Second)
+ opts, err := filestore.GetOpts(apiResponse)
+ require.NoError(t, err)
+ opts.ObjectStorageConfig.URLMux = mux
+
+ require.Equal(t, apiResponse.TempPath, opts.LocalTempPath)
+ require.Equal(t, apiResponse.RemoteObject.RemoteTempObjectID, opts.RemoteTempObjectID)
+ require.WithinDuration(t, deadline, opts.Deadline, time.Second)
+ require.Equal(t, apiResponse.RemoteObject.ID, opts.RemoteID)
+ require.Equal(t, apiResponse.RemoteObject.UseWorkhorseClient, opts.UseWorkhorseClient)
+ require.Equal(t, test.provider, opts.ObjectStorageConfig.Provider)
+ require.Equal(t, apiResponse.RemoteObject.ObjectStorage.GoCloudConfig, opts.ObjectStorageConfig.GoCloudConfig)
+ require.True(t, opts.UseWorkhorseClientEnabled())
+ require.Equal(t, test.valid, opts.ObjectStorageConfig.IsValid())
+ require.False(t, opts.IsLocal())
+ })
+ }
+}
diff --git a/workhorse/internal/git/archive.go b/workhorse/internal/git/archive.go
new file mode 100644
index 00000000000..b7575be2c02
--- /dev/null
+++ b/workhorse/internal/git/archive.go
@@ -0,0 +1,216 @@
+/*
+In this file we handle 'git archive' downloads
+*/
+
+package git
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "os"
+ "path"
+ "path/filepath"
+ "regexp"
+ "time"
+
+ "github.com/golang/protobuf/proto" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/274
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
+)
+
+type archive struct{ senddata.Prefix }
+type archiveParams struct {
+ ArchivePath string
+ ArchivePrefix string
+ CommitId string
+ GitalyServer gitaly.Server
+ GitalyRepository gitalypb.Repository
+ DisableCache bool
+ GetArchiveRequest []byte
+}
+
+var (
+ SendArchive = &archive{"git-archive:"}
+ gitArchiveCache = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_git_archive_cache",
+ Help: "Cache hits and misses for 'git archive' streaming",
+ },
+ []string{"result"},
+ )
+)
+
+func (a *archive) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
+ var params archiveParams
+ if err := a.Unpack(&params, sendData); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("SendArchive: unpack sendData: %v", err))
+ return
+ }
+
+ urlPath := r.URL.Path
+ format, ok := parseBasename(filepath.Base(urlPath))
+ if !ok {
+ helper.Fail500(w, r, fmt.Errorf("SendArchive: invalid format: %s", urlPath))
+ return
+ }
+
+ cacheEnabled := !params.DisableCache
+ archiveFilename := path.Base(params.ArchivePath)
+
+ if cacheEnabled {
+ cachedArchive, err := os.Open(params.ArchivePath)
+ if err == nil {
+ defer cachedArchive.Close()
+ gitArchiveCache.WithLabelValues("hit").Inc()
+ setArchiveHeaders(w, format, archiveFilename)
+ // Even if somebody deleted the cachedArchive from disk since we opened
+ // the file, Unix file semantics guarantee we can still read from the
+ // open file in this process.
+ http.ServeContent(w, r, "", time.Unix(0, 0), cachedArchive)
+ return
+ }
+ }
+
+ gitArchiveCache.WithLabelValues("miss").Inc()
+
+ var tempFile *os.File
+ var err error
+
+ if cacheEnabled {
+ // We assume the tempFile has a unique name so that concurrent requests are
+ // safe. We create the tempfile in the same directory as the final cached
+ // archive we want to create so that we can use an atomic link(2) operation
+ // to finalize the cached archive.
+ tempFile, err = prepareArchiveTempfile(path.Dir(params.ArchivePath), archiveFilename)
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("SendArchive: create tempfile: %v", err))
+ return
+ }
+ defer tempFile.Close()
+ defer os.Remove(tempFile.Name())
+ }
+
+ var archiveReader io.Reader
+
+ archiveReader, err = handleArchiveWithGitaly(r, params, format)
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("operations.GetArchive: %v", err))
+ return
+ }
+
+ reader := archiveReader
+ if cacheEnabled {
+ reader = io.TeeReader(archiveReader, tempFile)
+ }
+
+ // Start writing the response
+ setArchiveHeaders(w, format, archiveFilename)
+ w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
+ if _, err := io.Copy(w, reader); err != nil {
+ helper.LogError(r, &copyError{fmt.Errorf("SendArchive: copy 'git archive' output: %v", err)})
+ return
+ }
+
+ if cacheEnabled {
+ err := finalizeCachedArchive(tempFile, params.ArchivePath)
+ if err != nil {
+ helper.LogError(r, fmt.Errorf("SendArchive: finalize cached archive: %v", err))
+ return
+ }
+ }
+}
+
+func handleArchiveWithGitaly(r *http.Request, params archiveParams, format gitalypb.GetArchiveRequest_Format) (io.Reader, error) {
+ var request *gitalypb.GetArchiveRequest
+ ctx, c, err := gitaly.NewRepositoryClient(r.Context(), params.GitalyServer)
+ if err != nil {
+ return nil, err
+ }
+
+ if params.GetArchiveRequest != nil {
+ request = &gitalypb.GetArchiveRequest{}
+
+ if err := proto.Unmarshal(params.GetArchiveRequest, request); err != nil {
+ return nil, fmt.Errorf("unmarshal GetArchiveRequest: %v", err)
+ }
+ } else {
+ request = &gitalypb.GetArchiveRequest{
+ Repository: &params.GitalyRepository,
+ CommitId: params.CommitId,
+ Prefix: params.ArchivePrefix,
+ Format: format,
+ }
+ }
+
+ return c.ArchiveReader(ctx, request)
+}
+
+func setArchiveHeaders(w http.ResponseWriter, format gitalypb.GetArchiveRequest_Format, archiveFilename string) {
+ w.Header().Del("Content-Length")
+ w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, archiveFilename))
+ // Caching proxies usually don't cache responses with Set-Cookie header
+ // present because it implies user-specific data, which is not the case
+ // for repository archives.
+ w.Header().Del("Set-Cookie")
+ if format == gitalypb.GetArchiveRequest_ZIP {
+ w.Header().Set("Content-Type", "application/zip")
+ } else {
+ w.Header().Set("Content-Type", "application/octet-stream")
+ }
+ w.Header().Set("Content-Transfer-Encoding", "binary")
+}
+
+func prepareArchiveTempfile(dir string, prefix string) (*os.File, error) {
+ if err := os.MkdirAll(dir, 0700); err != nil {
+ return nil, err
+ }
+ return ioutil.TempFile(dir, prefix)
+}
+
+func finalizeCachedArchive(tempFile *os.File, archivePath string) error {
+ if err := tempFile.Close(); err != nil {
+ return err
+ }
+ if err := os.Link(tempFile.Name(), archivePath); err != nil && !os.IsExist(err) {
+ return err
+ }
+
+ return nil
+}
+
+var (
+ patternZip = regexp.MustCompile(`\.zip$`)
+ patternTar = regexp.MustCompile(`\.tar$`)
+ patternTarGz = regexp.MustCompile(`\.(tar\.gz|tgz|gz)$`)
+ patternTarBz2 = regexp.MustCompile(`\.(tar\.bz2|tbz|tbz2|tb2|bz2)$`)
+)
+
+func parseBasename(basename string) (gitalypb.GetArchiveRequest_Format, bool) {
+ var format gitalypb.GetArchiveRequest_Format
+
+ switch {
+ case (basename == "archive"):
+ format = gitalypb.GetArchiveRequest_TAR_GZ
+ case patternZip.MatchString(basename):
+ format = gitalypb.GetArchiveRequest_ZIP
+ case patternTar.MatchString(basename):
+ format = gitalypb.GetArchiveRequest_TAR
+ case patternTarGz.MatchString(basename):
+ format = gitalypb.GetArchiveRequest_TAR_GZ
+ case patternTarBz2.MatchString(basename):
+ format = gitalypb.GetArchiveRequest_TAR_BZ2
+ default:
+ return format, false
+ }
+
+ return format, true
+}
diff --git a/workhorse/internal/git/archive_test.go b/workhorse/internal/git/archive_test.go
new file mode 100644
index 00000000000..4b0753499e5
--- /dev/null
+++ b/workhorse/internal/git/archive_test.go
@@ -0,0 +1,87 @@
+package git
+
+import (
+ "io/ioutil"
+ "net/http/httptest"
+ "testing"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseBasename(t *testing.T) {
+ for _, testCase := range []struct {
+ in string
+ out gitalypb.GetArchiveRequest_Format
+ }{
+ {"archive", gitalypb.GetArchiveRequest_TAR_GZ},
+ {"master.tar.gz", gitalypb.GetArchiveRequest_TAR_GZ},
+ {"foo-master.tgz", gitalypb.GetArchiveRequest_TAR_GZ},
+ {"foo-v1.2.1.gz", gitalypb.GetArchiveRequest_TAR_GZ},
+ {"foo.tar.bz2", gitalypb.GetArchiveRequest_TAR_BZ2},
+ {"archive.tbz", gitalypb.GetArchiveRequest_TAR_BZ2},
+ {"archive.tbz2", gitalypb.GetArchiveRequest_TAR_BZ2},
+ {"archive.tb2", gitalypb.GetArchiveRequest_TAR_BZ2},
+ {"archive.bz2", gitalypb.GetArchiveRequest_TAR_BZ2},
+ } {
+ basename := testCase.in
+ out, ok := parseBasename(basename)
+ if !ok {
+ t.Fatalf("parseBasename did not recognize %q", basename)
+ }
+
+ if out != testCase.out {
+ t.Fatalf("expected %q, got %q", testCase.out, out)
+ }
+ }
+}
+
+func TestFinalizeArchive(t *testing.T) {
+ tempFile, err := ioutil.TempFile("", "gitlab-workhorse-test")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer tempFile.Close()
+
+ // Deliberately cause an EEXIST error: we know tempFile.Name() already exists
+ err = finalizeCachedArchive(tempFile, tempFile.Name())
+ if err != nil {
+ t.Fatalf("expected nil from finalizeCachedArchive, received %v", err)
+ }
+}
+
+func TestSetArchiveHeaders(t *testing.T) {
+ for _, testCase := range []struct {
+ in gitalypb.GetArchiveRequest_Format
+ out string
+ }{
+ {gitalypb.GetArchiveRequest_ZIP, "application/zip"},
+ {gitalypb.GetArchiveRequest_TAR, "application/octet-stream"},
+ {gitalypb.GetArchiveRequest_TAR_GZ, "application/octet-stream"},
+ {gitalypb.GetArchiveRequest_TAR_BZ2, "application/octet-stream"},
+ } {
+ w := httptest.NewRecorder()
+
+ // These should be replaced, not appended to
+ w.Header().Set("Content-Type", "test")
+ w.Header().Set("Content-Length", "test")
+ w.Header().Set("Content-Disposition", "test")
+
+ // This should be deleted
+ w.Header().Set("Set-Cookie", "test")
+
+ // This should be preserved
+ w.Header().Set("Cache-Control", "public, max-age=3600")
+
+ setArchiveHeaders(w, testCase.in, "filename")
+
+ testhelper.RequireResponseHeader(t, w, "Content-Type", testCase.out)
+ testhelper.RequireResponseHeader(t, w, "Content-Length")
+ testhelper.RequireResponseHeader(t, w, "Content-Disposition", `attachment; filename="filename"`)
+ testhelper.RequireResponseHeader(t, w, "Cache-Control", "public, max-age=3600")
+ require.Empty(t, w.Header().Get("Set-Cookie"), "remove Set-Cookie")
+ }
+}
diff --git a/workhorse/internal/git/blob.go b/workhorse/internal/git/blob.go
new file mode 100644
index 00000000000..472f5d0bc96
--- /dev/null
+++ b/workhorse/internal/git/blob.go
@@ -0,0 +1,47 @@
+package git
+
+import (
+ "fmt"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
+)
+
+type blob struct{ senddata.Prefix }
+type blobParams struct {
+ GitalyServer gitaly.Server
+ GetBlobRequest gitalypb.GetBlobRequest
+}
+
+var SendBlob = &blob{"git-blob:"}
+
+func (b *blob) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
+ var params blobParams
+ if err := b.Unpack(&params, sendData); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("SendBlob: unpack sendData: %v", err))
+ return
+ }
+
+ ctx, blobClient, err := gitaly.NewBlobClient(r.Context(), params.GitalyServer)
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("blob.GetBlob: %v", err))
+ return
+ }
+
+ setBlobHeaders(w)
+ if err := blobClient.SendBlob(ctx, w, &params.GetBlobRequest); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("blob.GetBlob: %v", err))
+ return
+ }
+}
+
+func setBlobHeaders(w http.ResponseWriter) {
+ // Caching proxies usually don't cache responses with Set-Cookie header
+ // present because it implies user-specific data, which is not the case
+ // for blobs.
+ w.Header().Del("Set-Cookie")
+}
diff --git a/workhorse/internal/git/blob_test.go b/workhorse/internal/git/blob_test.go
new file mode 100644
index 00000000000..ec28c2adb2f
--- /dev/null
+++ b/workhorse/internal/git/blob_test.go
@@ -0,0 +1,17 @@
+package git
+
+import (
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestSetBlobHeaders(t *testing.T) {
+ w := httptest.NewRecorder()
+ w.Header().Set("Set-Cookie", "gitlab_cookie=123456")
+
+ setBlobHeaders(w)
+
+ require.Empty(t, w.Header().Get("Set-Cookie"), "remove Set-Cookie")
+}
diff --git a/workhorse/internal/git/diff.go b/workhorse/internal/git/diff.go
new file mode 100644
index 00000000000..b1a1c17a650
--- /dev/null
+++ b/workhorse/internal/git/diff.go
@@ -0,0 +1,48 @@
+package git
+
+import (
+ "fmt"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
+)
+
+type diff struct{ senddata.Prefix }
+type diffParams struct {
+ GitalyServer gitaly.Server
+ RawDiffRequest string
+}
+
+var SendDiff = &diff{"git-diff:"}
+
+func (d *diff) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
+ var params diffParams
+ if err := d.Unpack(&params, sendData); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("SendDiff: unpack sendData: %v", err))
+ return
+ }
+
+ request := &gitalypb.RawDiffRequest{}
+ if err := gitaly.UnmarshalJSON(params.RawDiffRequest, request); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("diff.RawDiff: %v", err))
+ return
+ }
+
+ ctx, diffClient, err := gitaly.NewDiffClient(r.Context(), params.GitalyServer)
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("diff.RawDiff: %v", err))
+ return
+ }
+
+ if err := diffClient.SendRawDiff(ctx, w, request); err != nil {
+ helper.LogError(
+ r,
+ &copyError{fmt.Errorf("diff.RawDiff: request=%v, err=%v", request, err)},
+ )
+ return
+ }
+}
diff --git a/workhorse/internal/git/error.go b/workhorse/internal/git/error.go
new file mode 100644
index 00000000000..2b7cad6bb64
--- /dev/null
+++ b/workhorse/internal/git/error.go
@@ -0,0 +1,4 @@
+package git
+
+// For cosmetic purposes in Sentry
+type copyError struct{ error }
diff --git a/workhorse/internal/git/format-patch.go b/workhorse/internal/git/format-patch.go
new file mode 100644
index 00000000000..db96029b07e
--- /dev/null
+++ b/workhorse/internal/git/format-patch.go
@@ -0,0 +1,48 @@
+package git
+
+import (
+ "fmt"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
+)
+
+type patch struct{ senddata.Prefix }
+type patchParams struct {
+ GitalyServer gitaly.Server
+ RawPatchRequest string
+}
+
+var SendPatch = &patch{"git-format-patch:"}
+
+func (p *patch) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
+ var params patchParams
+ if err := p.Unpack(&params, sendData); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("SendPatch: unpack sendData: %v", err))
+ return
+ }
+
+ request := &gitalypb.RawPatchRequest{}
+ if err := gitaly.UnmarshalJSON(params.RawPatchRequest, request); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("diff.RawPatch: %v", err))
+ return
+ }
+
+ ctx, diffClient, err := gitaly.NewDiffClient(r.Context(), params.GitalyServer)
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("diff.RawPatch: %v", err))
+ return
+ }
+
+ if err := diffClient.SendRawPatch(ctx, w, request); err != nil {
+ helper.LogError(
+ r,
+ &copyError{fmt.Errorf("diff.RawPatch: request=%v, err=%v", request, err)},
+ )
+ return
+ }
+}
diff --git a/workhorse/internal/git/git-http.go b/workhorse/internal/git/git-http.go
new file mode 100644
index 00000000000..5df20a68bb7
--- /dev/null
+++ b/workhorse/internal/git/git-http.go
@@ -0,0 +1,100 @@
+/*
+In this file we handle the Git 'smart HTTP' protocol
+*/
+
+package git
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "path/filepath"
+ "sync"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+const (
+ // We have to use a negative transfer.hideRefs since this is the only way
+ // to undo an already set parameter: https://www.spinics.net/lists/git/msg256772.html
+ GitConfigShowAllRefs = "transfer.hideRefs=!refs"
+)
+
+func ReceivePack(a *api.API) http.Handler {
+ return postRPCHandler(a, "handleReceivePack", handleReceivePack)
+}
+
+func UploadPack(a *api.API) http.Handler {
+ return postRPCHandler(a, "handleUploadPack", handleUploadPack)
+}
+
+func gitConfigOptions(a *api.Response) []string {
+ var out []string
+
+ if a.ShowAllRefs {
+ out = append(out, GitConfigShowAllRefs)
+ }
+
+ return out
+}
+
+func postRPCHandler(a *api.API, name string, handler func(*HttpResponseWriter, *http.Request, *api.Response) error) http.Handler {
+ return repoPreAuthorizeHandler(a, func(rw http.ResponseWriter, r *http.Request, ar *api.Response) {
+ cr := &countReadCloser{ReadCloser: r.Body}
+ r.Body = cr
+
+ w := NewHttpResponseWriter(rw)
+ defer func() {
+ w.Log(r, cr.Count())
+ }()
+
+ if err := handler(w, r, ar); err != nil {
+ // If the handler already wrote a response this WriteHeader call is a
+ // no-op. It never reaches net/http because GitHttpResponseWriter calls
+ // WriteHeader on its underlying ResponseWriter at most once.
+ w.WriteHeader(500)
+ helper.LogError(r, fmt.Errorf("%s: %v", name, err))
+ }
+ })
+}
+
+func repoPreAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.Handler {
+ return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
+ handleFunc(w, r, a)
+ }, "")
+}
+
+func writePostRPCHeader(w http.ResponseWriter, action string) {
+ w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-result", action))
+ w.Header().Set("Cache-Control", "no-cache")
+}
+
+func getService(r *http.Request) string {
+ if r.Method == "GET" {
+ return r.URL.Query().Get("service")
+ }
+ return filepath.Base(r.URL.Path)
+}
+
+type countReadCloser struct {
+ n int64
+ io.ReadCloser
+ sync.Mutex
+}
+
+func (c *countReadCloser) Read(p []byte) (n int, err error) {
+ n, err = c.ReadCloser.Read(p)
+
+ c.Lock()
+ defer c.Unlock()
+ c.n += int64(n)
+
+ return n, err
+}
+
+func (c *countReadCloser) Count() int64 {
+ c.Lock()
+ defer c.Unlock()
+ return c.n
+}
diff --git a/workhorse/internal/git/info-refs.go b/workhorse/internal/git/info-refs.go
new file mode 100644
index 00000000000..e5491a7b733
--- /dev/null
+++ b/workhorse/internal/git/info-refs.go
@@ -0,0 +1,76 @@
+package git
+
+import (
+ "compress/gzip"
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+
+ "github.com/golang/gddo/httputil"
+
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+func GetInfoRefsHandler(a *api.API) http.Handler {
+ return repoPreAuthorizeHandler(a, handleGetInfoRefs)
+}
+
+func handleGetInfoRefs(rw http.ResponseWriter, r *http.Request, a *api.Response) {
+ responseWriter := NewHttpResponseWriter(rw)
+ // Log 0 bytes in because we ignore the request body (and there usually is none anyway).
+ defer responseWriter.Log(r, 0)
+
+ rpc := getService(r)
+ if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") {
+ // The 'dumb' Git HTTP protocol is not supported
+ http.Error(responseWriter, "Not Found", 404)
+ return
+ }
+
+ responseWriter.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-advertisement", rpc))
+ responseWriter.Header().Set("Cache-Control", "no-cache")
+
+ gitProtocol := r.Header.Get("Git-Protocol")
+
+ offers := []string{"gzip", "identity"}
+ encoding := httputil.NegotiateContentEncoding(r, offers)
+
+ if err := handleGetInfoRefsWithGitaly(r.Context(), responseWriter, a, rpc, gitProtocol, encoding); err != nil {
+ helper.Fail500(responseWriter, r, fmt.Errorf("handleGetInfoRefs: %v", err))
+ }
+}
+
+func handleGetInfoRefsWithGitaly(ctx context.Context, responseWriter *HttpResponseWriter, a *api.Response, rpc, gitProtocol, encoding string) error {
+ ctx, smarthttp, err := gitaly.NewSmartHTTPClient(ctx, a.GitalyServer)
+ if err != nil {
+ return fmt.Errorf("GetInfoRefsHandler: %v", err)
+ }
+
+ infoRefsResponseReader, err := smarthttp.InfoRefsResponseReader(ctx, &a.Repository, rpc, gitConfigOptions(a), gitProtocol)
+ if err != nil {
+ return fmt.Errorf("GetInfoRefsHandler: %v", err)
+ }
+
+ var w io.Writer
+
+ if encoding == "gzip" {
+ gzWriter := gzip.NewWriter(responseWriter)
+ w = gzWriter
+ defer gzWriter.Close()
+
+ responseWriter.Header().Set("Content-Encoding", "gzip")
+ } else {
+ w = responseWriter
+ }
+
+ if _, err = io.Copy(w, infoRefsResponseReader); err != nil {
+ log.WithError(err).Error("GetInfoRefsHandler: error copying gitaly response")
+ }
+
+ return nil
+}
diff --git a/workhorse/internal/git/pktline.go b/workhorse/internal/git/pktline.go
new file mode 100644
index 00000000000..e970f60182d
--- /dev/null
+++ b/workhorse/internal/git/pktline.go
@@ -0,0 +1,59 @@
+package git
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "strconv"
+)
+
+func scanDeepen(body io.Reader) bool {
+ scanner := bufio.NewScanner(body)
+ scanner.Split(pktLineSplitter)
+ for scanner.Scan() {
+ if bytes.HasPrefix(scanner.Bytes(), []byte("deepen")) && scanner.Err() == nil {
+ return true
+ }
+ }
+
+ return false
+}
+
+func pktLineSplitter(data []byte, atEOF bool) (advance int, token []byte, err error) {
+ if len(data) < 4 {
+ if atEOF && len(data) > 0 {
+ return 0, nil, fmt.Errorf("pktLineSplitter: incomplete length prefix on %q", data)
+ }
+ return 0, nil, nil // want more data
+ }
+
+ if bytes.HasPrefix(data, []byte("0000")) {
+ // special case: "0000" terminator packet: return empty token
+ return 4, data[:0], nil
+ }
+
+ // We have at least 4 bytes available so we can decode the 4-hex digit
+ // length prefix of the packet line.
+ pktLength64, err := strconv.ParseInt(string(data[:4]), 16, 0)
+ if err != nil {
+ return 0, nil, fmt.Errorf("pktLineSplitter: decode length: %v", err)
+ }
+
+ // Cast is safe because we requested an int-size number from strconv.ParseInt
+ pktLength := int(pktLength64)
+
+ if pktLength < 0 {
+ return 0, nil, fmt.Errorf("pktLineSplitter: invalid length: %d", pktLength)
+ }
+
+ if len(data) < pktLength {
+ if atEOF {
+ return 0, nil, fmt.Errorf("pktLineSplitter: less than %d bytes in input %q", pktLength, data)
+ }
+ return 0, nil, nil // want more data
+ }
+
+ // return "pkt" token without length prefix
+ return pktLength, data[4:pktLength], nil
+}
diff --git a/workhorse/internal/git/pktline_test.go b/workhorse/internal/git/pktline_test.go
new file mode 100644
index 00000000000..d4be8634538
--- /dev/null
+++ b/workhorse/internal/git/pktline_test.go
@@ -0,0 +1,39 @@
+package git
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestSuccessfulScanDeepen(t *testing.T) {
+ examples := []struct {
+ input string
+ output bool
+ }{
+ {"000dsomething000cdeepen 10000", true},
+ {"000dsomething0000000cdeepen 1", true},
+ {"000dsomething0000", false},
+ }
+
+ for _, example := range examples {
+ hasDeepen := scanDeepen(bytes.NewReader([]byte(example.input)))
+
+ if hasDeepen != example.output {
+ t.Fatalf("scanDeepen %q: expected %v, got %v", example.input, example.output, hasDeepen)
+ }
+ }
+}
+
+func TestFailedScanDeepen(t *testing.T) {
+ examples := []string{
+ "invalid data",
+ "deepen",
+ "000cdeepen",
+ }
+
+ for _, example := range examples {
+ if scanDeepen(bytes.NewReader([]byte(example))) {
+ t.Fatalf("scanDeepen %q: expected result to be false, got true", example)
+ }
+ }
+}
diff --git a/workhorse/internal/git/receive-pack.go b/workhorse/internal/git/receive-pack.go
new file mode 100644
index 00000000000..e72d8be5174
--- /dev/null
+++ b/workhorse/internal/git/receive-pack.go
@@ -0,0 +1,33 @@
+package git
+
+import (
+ "fmt"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+// Will not return a non-nil error after the response body has been
+// written to.
+func handleReceivePack(w *HttpResponseWriter, r *http.Request, a *api.Response) error {
+ action := getService(r)
+ writePostRPCHeader(w, action)
+
+ cr, cw := helper.NewWriteAfterReader(r.Body, w)
+ defer cw.Flush()
+
+ gitProtocol := r.Header.Get("Git-Protocol")
+
+ ctx, smarthttp, err := gitaly.NewSmartHTTPClient(r.Context(), a.GitalyServer)
+ if err != nil {
+ return fmt.Errorf("smarthttp.ReceivePack: %v", err)
+ }
+
+ if err := smarthttp.ReceivePack(ctx, &a.Repository, a.GL_ID, a.GL_USERNAME, a.GL_REPOSITORY, a.GitConfigOptions, cr, cw, gitProtocol); err != nil {
+ return fmt.Errorf("smarthttp.ReceivePack: %v", err)
+ }
+
+ return nil
+}
diff --git a/workhorse/internal/git/responsewriter.go b/workhorse/internal/git/responsewriter.go
new file mode 100644
index 00000000000..c4d4ac252d4
--- /dev/null
+++ b/workhorse/internal/git/responsewriter.go
@@ -0,0 +1,75 @@
+package git
+
+import (
+ "net/http"
+ "strconv"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+const (
+ directionIn = "in"
+ directionOut = "out"
+)
+
+var (
+ gitHTTPSessionsActive = promauto.NewGauge(prometheus.GaugeOpts{
+ Name: "gitlab_workhorse_git_http_sessions_active",
+ Help: "Number of Git HTTP request-response cycles currently being handled by gitlab-workhorse.",
+ })
+
+ gitHTTPRequests = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_git_http_requests",
+ Help: "How many Git HTTP requests have been processed by gitlab-workhorse, partitioned by request type and agent.",
+ },
+ []string{"method", "code", "service", "agent"},
+ )
+
+ gitHTTPBytes = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_git_http_bytes",
+ Help: "How many Git HTTP bytes have been sent by gitlab-workhorse, partitioned by request type, agent and direction.",
+ },
+ []string{"method", "code", "service", "agent", "direction"},
+ )
+)
+
+type HttpResponseWriter struct {
+ helper.CountingResponseWriter
+}
+
+func NewHttpResponseWriter(rw http.ResponseWriter) *HttpResponseWriter {
+ gitHTTPSessionsActive.Inc()
+ return &HttpResponseWriter{
+ CountingResponseWriter: helper.NewCountingResponseWriter(rw),
+ }
+}
+
+func (w *HttpResponseWriter) Log(r *http.Request, writtenIn int64) {
+ service := getService(r)
+ agent := getRequestAgent(r)
+
+ gitHTTPSessionsActive.Dec()
+ gitHTTPRequests.WithLabelValues(r.Method, strconv.Itoa(w.Status()), service, agent).Inc()
+ gitHTTPBytes.WithLabelValues(r.Method, strconv.Itoa(w.Status()), service, agent, directionIn).
+ Add(float64(writtenIn))
+ gitHTTPBytes.WithLabelValues(r.Method, strconv.Itoa(w.Status()), service, agent, directionOut).
+ Add(float64(w.Count()))
+}
+
+func getRequestAgent(r *http.Request) string {
+ u, _, ok := r.BasicAuth()
+ if !ok {
+ return "anonymous"
+ }
+
+ if u == "gitlab-ci-token" {
+ return "gitlab-ci"
+ }
+
+ return "logged"
+}
diff --git a/workhorse/internal/git/snapshot.go b/workhorse/internal/git/snapshot.go
new file mode 100644
index 00000000000..eb38becbd06
--- /dev/null
+++ b/workhorse/internal/git/snapshot.go
@@ -0,0 +1,64 @@
+package git
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
+)
+
+type snapshot struct {
+ senddata.Prefix
+}
+
+type snapshotParams struct {
+ GitalyServer gitaly.Server
+ GetSnapshotRequest string
+}
+
+var (
+ SendSnapshot = &snapshot{"git-snapshot:"}
+)
+
+func (s *snapshot) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
+ var params snapshotParams
+
+ if err := s.Unpack(&params, sendData); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("SendSnapshot: unpack sendData: %v", err))
+ return
+ }
+
+ request := &gitalypb.GetSnapshotRequest{}
+ if err := gitaly.UnmarshalJSON(params.GetSnapshotRequest, request); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("SendSnapshot: unmarshal GetSnapshotRequest: %v", err))
+ return
+ }
+
+ ctx, c, err := gitaly.NewRepositoryClient(r.Context(), params.GitalyServer)
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("SendSnapshot: gitaly.NewRepositoryClient: %v", err))
+ return
+ }
+
+ reader, err := c.SnapshotReader(ctx, request)
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("SendSnapshot: client.SnapshotReader: %v", err))
+ return
+ }
+
+ w.Header().Del("Content-Length")
+ w.Header().Set("Content-Disposition", `attachment; filename="snapshot.tar"`)
+ w.Header().Set("Content-Type", "application/x-tar")
+ w.Header().Set("Content-Transfer-Encoding", "binary")
+ w.Header().Set("Cache-Control", "private")
+ w.WriteHeader(http.StatusOK) // Errors aren't detectable beyond this point
+
+ if _, err := io.Copy(w, reader); err != nil {
+ helper.LogError(r, fmt.Errorf("SendSnapshot: copy gitaly output: %v", err))
+ }
+}
diff --git a/workhorse/internal/git/upload-pack.go b/workhorse/internal/git/upload-pack.go
new file mode 100644
index 00000000000..a3dbf2f2e02
--- /dev/null
+++ b/workhorse/internal/git/upload-pack.go
@@ -0,0 +1,57 @@
+package git
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+var (
+ uploadPackTimeout = 10 * time.Minute
+)
+
+// Will not return a non-nil error after the response body has been
+// written to.
+func handleUploadPack(w *HttpResponseWriter, r *http.Request, a *api.Response) error {
+ ctx := r.Context()
+
+ // Prevent the client from holding the connection open indefinitely. A
+ // transfer rate of 17KiB/sec is sufficient to send 10MiB of data in
+ // ten minutes, which seems adequate. Most requests will be much smaller.
+ // This mitigates a use-after-check issue.
+ //
+ // We can't reliably interrupt the read from a http handler, but we can
+ // ensure the request will (eventually) fail: https://github.com/golang/go/issues/16100
+ readerCtx, cancel := context.WithTimeout(ctx, uploadPackTimeout)
+ defer cancel()
+
+ limited := helper.NewContextReader(readerCtx, r.Body)
+ cr, cw := helper.NewWriteAfterReader(limited, w)
+ defer cw.Flush()
+
+ action := getService(r)
+ writePostRPCHeader(w, action)
+
+ gitProtocol := r.Header.Get("Git-Protocol")
+
+ return handleUploadPackWithGitaly(ctx, a, cr, cw, gitProtocol)
+}
+
+func handleUploadPackWithGitaly(ctx context.Context, a *api.Response, clientRequest io.Reader, clientResponse io.Writer, gitProtocol string) error {
+ ctx, smarthttp, err := gitaly.NewSmartHTTPClient(ctx, a.GitalyServer)
+ if err != nil {
+ return fmt.Errorf("smarthttp.UploadPack: %v", err)
+ }
+
+ if err := smarthttp.UploadPack(ctx, &a.Repository, clientRequest, clientResponse, gitConfigOptions(a), gitProtocol); err != nil {
+ return fmt.Errorf("smarthttp.UploadPack: %v", err)
+ }
+
+ return nil
+}
diff --git a/workhorse/internal/git/upload-pack_test.go b/workhorse/internal/git/upload-pack_test.go
new file mode 100644
index 00000000000..c198939d5df
--- /dev/null
+++ b/workhorse/internal/git/upload-pack_test.go
@@ -0,0 +1,85 @@
+package git
+
+import (
+ "fmt"
+ "io/ioutil"
+ "net"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "google.golang.org/grpc"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly"
+)
+
+var (
+ originalUploadPackTimeout = uploadPackTimeout
+)
+
+type fakeReader struct {
+ n int
+ err error
+}
+
+func (f *fakeReader) Read(b []byte) (int, error) {
+ return f.n, f.err
+}
+
+type smartHTTPServiceServer struct {
+ gitalypb.UnimplementedSmartHTTPServiceServer
+ PostUploadPackFunc func(gitalypb.SmartHTTPService_PostUploadPackServer) error
+}
+
+func (srv *smartHTTPServiceServer) PostUploadPack(s gitalypb.SmartHTTPService_PostUploadPackServer) error {
+ return srv.PostUploadPackFunc(s)
+}
+
+func TestUploadPackTimesOut(t *testing.T) {
+ uploadPackTimeout = time.Millisecond
+ defer func() { uploadPackTimeout = originalUploadPackTimeout }()
+
+ addr, cleanUp := startSmartHTTPServer(t, &smartHTTPServiceServer{
+ PostUploadPackFunc: func(stream gitalypb.SmartHTTPService_PostUploadPackServer) error {
+ _, err := stream.Recv() // trigger a read on the client request body
+ require.NoError(t, err)
+ return nil
+ },
+ })
+ defer cleanUp()
+
+ body := &fakeReader{n: 0, err: nil}
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest("GET", "/", body)
+ a := &api.Response{GitalyServer: gitaly.Server{Address: addr}}
+
+ err := handleUploadPack(NewHttpResponseWriter(w), r, a)
+ require.EqualError(t, err, "smarthttp.UploadPack: busyReader: context deadline exceeded")
+}
+
+func startSmartHTTPServer(t testing.TB, s gitalypb.SmartHTTPServiceServer) (string, func()) {
+ tmp, err := ioutil.TempDir("", "")
+ require.NoError(t, err)
+
+ socket := filepath.Join(tmp, "gitaly.sock")
+ ln, err := net.Listen("unix", socket)
+ require.NoError(t, err)
+
+ srv := grpc.NewServer()
+ gitalypb.RegisterSmartHTTPServiceServer(srv, s)
+ go func() {
+ require.NoError(t, srv.Serve(ln))
+ }()
+
+ return fmt.Sprintf("%s://%s", ln.Addr().Network(), ln.Addr().String()), func() {
+ srv.GracefulStop()
+ require.NoError(t, os.RemoveAll(tmp), "error removing temp dir %q", tmp)
+ }
+}
diff --git a/workhorse/internal/gitaly/blob.go b/workhorse/internal/gitaly/blob.go
new file mode 100644
index 00000000000..c6f5d6676f3
--- /dev/null
+++ b/workhorse/internal/gitaly/blob.go
@@ -0,0 +1,41 @@
+package gitaly
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "strconv"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/gitaly/streamio"
+)
+
+type BlobClient struct {
+ gitalypb.BlobServiceClient
+}
+
+func (client *BlobClient) SendBlob(ctx context.Context, w http.ResponseWriter, request *gitalypb.GetBlobRequest) error {
+ c, err := client.GetBlob(ctx, request)
+ if err != nil {
+ return fmt.Errorf("rpc failed: %v", err)
+ }
+
+ firstResponseReceived := false
+ rr := streamio.NewReader(func() ([]byte, error) {
+ resp, err := c.Recv()
+
+ if !firstResponseReceived && err == nil {
+ firstResponseReceived = true
+ w.Header().Set("Content-Length", strconv.FormatInt(resp.GetSize(), 10))
+ }
+
+ return resp.GetData(), err
+ })
+
+ if _, err := io.Copy(w, rr); err != nil {
+ return fmt.Errorf("copy rpc data: %v", err)
+ }
+
+ return nil
+}
diff --git a/workhorse/internal/gitaly/diff.go b/workhorse/internal/gitaly/diff.go
new file mode 100644
index 00000000000..035a58ec6fd
--- /dev/null
+++ b/workhorse/internal/gitaly/diff.go
@@ -0,0 +1,55 @@
+package gitaly
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/gitaly/streamio"
+)
+
+type DiffClient struct {
+ gitalypb.DiffServiceClient
+}
+
+func (client *DiffClient) SendRawDiff(ctx context.Context, w http.ResponseWriter, request *gitalypb.RawDiffRequest) error {
+ c, err := client.RawDiff(ctx, request)
+ if err != nil {
+ return fmt.Errorf("rpc failed: %v", err)
+ }
+
+ w.Header().Del("Content-Length")
+
+ rr := streamio.NewReader(func() ([]byte, error) {
+ resp, err := c.Recv()
+ return resp.GetData(), err
+ })
+
+ if _, err := io.Copy(w, rr); err != nil {
+ return fmt.Errorf("copy rpc data: %v", err)
+ }
+
+ return nil
+}
+
+func (client *DiffClient) SendRawPatch(ctx context.Context, w http.ResponseWriter, request *gitalypb.RawPatchRequest) error {
+ c, err := client.RawPatch(ctx, request)
+ if err != nil {
+ return fmt.Errorf("rpc failed: %v", err)
+ }
+
+ w.Header().Del("Content-Length")
+
+ rr := streamio.NewReader(func() ([]byte, error) {
+ resp, err := c.Recv()
+ return resp.GetData(), err
+ })
+
+ if _, err := io.Copy(w, rr); err != nil {
+ return fmt.Errorf("copy rpc data: %v", err)
+ }
+
+ return nil
+}
diff --git a/workhorse/internal/gitaly/gitaly.go b/workhorse/internal/gitaly/gitaly.go
new file mode 100644
index 00000000000..c739ac8d9b2
--- /dev/null
+++ b/workhorse/internal/gitaly/gitaly.go
@@ -0,0 +1,188 @@
+package gitaly
+
+import (
+ "context"
+ "strings"
+ "sync"
+
+ "github.com/golang/protobuf/jsonpb" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/274
+ "github.com/golang/protobuf/proto" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/274
+ grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
+ grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ gitalyauth "gitlab.com/gitlab-org/gitaly/auth"
+ gitalyclient "gitlab.com/gitlab-org/gitaly/client"
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+
+ grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
+ grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
+)
+
+type Server struct {
+ Address string `json:"address"`
+ Token string `json:"token"`
+ Features map[string]string `json:"features"`
+}
+
+type cacheKey struct{ address, token string }
+
+func (server Server) cacheKey() cacheKey {
+ return cacheKey{address: server.Address, token: server.Token}
+}
+
+type connectionsCache struct {
+ sync.RWMutex
+ connections map[cacheKey]*grpc.ClientConn
+}
+
+var (
+ jsonUnMarshaler = jsonpb.Unmarshaler{AllowUnknownFields: true}
+ cache = connectionsCache{
+ connections: make(map[cacheKey]*grpc.ClientConn),
+ }
+
+ connectionsTotal = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_gitaly_connections_total",
+ Help: "Number of Gitaly connections that have been established",
+ },
+ []string{"status"},
+ )
+)
+
+func withOutgoingMetadata(ctx context.Context, features map[string]string) context.Context {
+ md := metadata.New(nil)
+ for k, v := range features {
+ if !strings.HasPrefix(k, "gitaly-feature-") {
+ continue
+ }
+ md.Append(k, v)
+ }
+
+ return metadata.NewOutgoingContext(ctx, md)
+}
+
+func NewSmartHTTPClient(ctx context.Context, server Server) (context.Context, *SmartHTTPClient, error) {
+ conn, err := getOrCreateConnection(server)
+ if err != nil {
+ return nil, nil, err
+ }
+ grpcClient := gitalypb.NewSmartHTTPServiceClient(conn)
+ return withOutgoingMetadata(ctx, server.Features), &SmartHTTPClient{grpcClient}, nil
+}
+
+func NewBlobClient(ctx context.Context, server Server) (context.Context, *BlobClient, error) {
+ conn, err := getOrCreateConnection(server)
+ if err != nil {
+ return nil, nil, err
+ }
+ grpcClient := gitalypb.NewBlobServiceClient(conn)
+ return withOutgoingMetadata(ctx, server.Features), &BlobClient{grpcClient}, nil
+}
+
+func NewRepositoryClient(ctx context.Context, server Server) (context.Context, *RepositoryClient, error) {
+ conn, err := getOrCreateConnection(server)
+ if err != nil {
+ return nil, nil, err
+ }
+ grpcClient := gitalypb.NewRepositoryServiceClient(conn)
+ return withOutgoingMetadata(ctx, server.Features), &RepositoryClient{grpcClient}, nil
+}
+
+// NewNamespaceClient is only used by the Gitaly integration tests at present
+func NewNamespaceClient(ctx context.Context, server Server) (context.Context, *NamespaceClient, error) {
+ conn, err := getOrCreateConnection(server)
+ if err != nil {
+ return nil, nil, err
+ }
+ grpcClient := gitalypb.NewNamespaceServiceClient(conn)
+ return withOutgoingMetadata(ctx, server.Features), &NamespaceClient{grpcClient}, nil
+}
+
+func NewDiffClient(ctx context.Context, server Server) (context.Context, *DiffClient, error) {
+ conn, err := getOrCreateConnection(server)
+ if err != nil {
+ return nil, nil, err
+ }
+ grpcClient := gitalypb.NewDiffServiceClient(conn)
+ return withOutgoingMetadata(ctx, server.Features), &DiffClient{grpcClient}, nil
+}
+
+func getOrCreateConnection(server Server) (*grpc.ClientConn, error) {
+ key := server.cacheKey()
+
+ cache.RLock()
+ conn := cache.connections[key]
+ cache.RUnlock()
+
+ if conn != nil {
+ return conn, nil
+ }
+
+ cache.Lock()
+ defer cache.Unlock()
+
+ if conn := cache.connections[key]; conn != nil {
+ return conn, nil
+ }
+
+ conn, err := newConnection(server)
+ if err != nil {
+ return nil, err
+ }
+
+ cache.connections[key] = conn
+
+ return conn, nil
+}
+
+func CloseConnections() {
+ cache.Lock()
+ defer cache.Unlock()
+
+ for _, conn := range cache.connections {
+ conn.Close()
+ }
+}
+
+func newConnection(server Server) (*grpc.ClientConn, error) {
+ connOpts := append(gitalyclient.DefaultDialOpts,
+ grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(server.Token)),
+ grpc.WithStreamInterceptor(
+ grpc_middleware.ChainStreamClient(
+ grpctracing.StreamClientTracingInterceptor(),
+ grpc_prometheus.StreamClientInterceptor,
+ grpccorrelation.StreamClientCorrelationInterceptor(
+ grpccorrelation.WithClientName("gitlab-workhorse"),
+ ),
+ ),
+ ),
+
+ grpc.WithUnaryInterceptor(
+ grpc_middleware.ChainUnaryClient(
+ grpctracing.UnaryClientTracingInterceptor(),
+ grpc_prometheus.UnaryClientInterceptor,
+ grpccorrelation.UnaryClientCorrelationInterceptor(
+ grpccorrelation.WithClientName("gitlab-workhorse"),
+ ),
+ ),
+ ),
+ )
+
+ conn, connErr := gitalyclient.Dial(server.Address, connOpts)
+
+ label := "ok"
+ if connErr != nil {
+ label = "fail"
+ }
+ connectionsTotal.WithLabelValues(label).Inc()
+
+ return conn, connErr
+}
+
+func UnmarshalJSON(s string, msg proto.Message) error {
+ return jsonUnMarshaler.Unmarshal(strings.NewReader(s), msg)
+}
diff --git a/workhorse/internal/gitaly/gitaly_test.go b/workhorse/internal/gitaly/gitaly_test.go
new file mode 100644
index 00000000000..b17fb5c1d7b
--- /dev/null
+++ b/workhorse/internal/gitaly/gitaly_test.go
@@ -0,0 +1,80 @@
+package gitaly
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "google.golang.org/grpc/metadata"
+)
+
+func TestNewSmartHTTPClient(t *testing.T) {
+ ctx, _, err := NewSmartHTTPClient(context.Background(), serverFixture())
+ require.NoError(t, err)
+ testOutgoingMetadata(t, ctx)
+}
+
+func TestNewBlobClient(t *testing.T) {
+ ctx, _, err := NewBlobClient(context.Background(), serverFixture())
+ require.NoError(t, err)
+ testOutgoingMetadata(t, ctx)
+}
+
+func TestNewRepositoryClient(t *testing.T) {
+ ctx, _, err := NewRepositoryClient(context.Background(), serverFixture())
+ require.NoError(t, err)
+ testOutgoingMetadata(t, ctx)
+}
+
+func TestNewNamespaceClient(t *testing.T) {
+ ctx, _, err := NewNamespaceClient(context.Background(), serverFixture())
+ require.NoError(t, err)
+ testOutgoingMetadata(t, ctx)
+}
+
+func TestNewDiffClient(t *testing.T) {
+ ctx, _, err := NewDiffClient(context.Background(), serverFixture())
+ require.NoError(t, err)
+ testOutgoingMetadata(t, ctx)
+}
+
+func testOutgoingMetadata(t *testing.T, ctx context.Context) {
+ md, ok := metadata.FromOutgoingContext(ctx)
+ require.True(t, ok, "get metadata from context")
+
+ for k, v := range allowedFeatures() {
+ actual := md[k]
+ require.Len(t, actual, 1, "expect one value for %v", k)
+ require.Equal(t, v, actual[0], "value for %v", k)
+ }
+
+ for k := range badFeatureMetadata() {
+ require.Empty(t, md[k], "value for bad key %v", k)
+ }
+}
+
+func serverFixture() Server {
+ features := make(map[string]string)
+ for k, v := range allowedFeatures() {
+ features[k] = v
+ }
+ for k, v := range badFeatureMetadata() {
+ features[k] = v
+ }
+
+ return Server{Address: "tcp://localhost:123", Features: features}
+}
+
+func allowedFeatures() map[string]string {
+ return map[string]string{
+ "gitaly-feature-foo": "bar",
+ "gitaly-feature-qux": "baz",
+ }
+}
+
+func badFeatureMetadata() map[string]string {
+ return map[string]string{
+ "bad-metadata-1": "bad-value-1",
+ "bad-metadata-2": "bad-value-2",
+ }
+}
diff --git a/workhorse/internal/gitaly/namespace.go b/workhorse/internal/gitaly/namespace.go
new file mode 100644
index 00000000000..6db6ed4fc32
--- /dev/null
+++ b/workhorse/internal/gitaly/namespace.go
@@ -0,0 +1,8 @@
+package gitaly
+
+import "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+
+// NamespaceClient encapsulates NamespaceService calls
+type NamespaceClient struct {
+ gitalypb.NamespaceServiceClient
+}
diff --git a/workhorse/internal/gitaly/repository.go b/workhorse/internal/gitaly/repository.go
new file mode 100644
index 00000000000..e3ec3257a85
--- /dev/null
+++ b/workhorse/internal/gitaly/repository.go
@@ -0,0 +1,45 @@
+package gitaly
+
+import (
+ "context"
+ "fmt"
+ "io"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/gitaly/streamio"
+)
+
+// RepositoryClient encapsulates RepositoryService calls
+type RepositoryClient struct {
+ gitalypb.RepositoryServiceClient
+}
+
+// ArchiveReader performs a GetArchive Gitaly request and returns an io.Reader
+// for the response
+func (client *RepositoryClient) ArchiveReader(ctx context.Context, request *gitalypb.GetArchiveRequest) (io.Reader, error) {
+ c, err := client.GetArchive(ctx, request)
+ if err != nil {
+ return nil, fmt.Errorf("RepositoryService::GetArchive: %v", err)
+ }
+
+ return streamio.NewReader(func() ([]byte, error) {
+ resp, err := c.Recv()
+
+ return resp.GetData(), err
+ }), nil
+}
+
+// SnapshotReader performs a GetSnapshot Gitaly request and returns an io.Reader
+// for the response
+func (client *RepositoryClient) SnapshotReader(ctx context.Context, request *gitalypb.GetSnapshotRequest) (io.Reader, error) {
+ c, err := client.GetSnapshot(ctx, request)
+ if err != nil {
+ return nil, fmt.Errorf("RepositoryService::GetSnapshot: %v", err)
+ }
+
+ return streamio.NewReader(func() ([]byte, error) {
+ resp, err := c.Recv()
+
+ return resp.GetData(), err
+ }), nil
+}
diff --git a/workhorse/internal/gitaly/smarthttp.go b/workhorse/internal/gitaly/smarthttp.go
new file mode 100644
index 00000000000..d1fe6fae5ba
--- /dev/null
+++ b/workhorse/internal/gitaly/smarthttp.go
@@ -0,0 +1,139 @@
+package gitaly
+
+import (
+ "context"
+ "fmt"
+ "io"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/gitaly/streamio"
+)
+
+type SmartHTTPClient struct {
+ gitalypb.SmartHTTPServiceClient
+}
+
+func (client *SmartHTTPClient) InfoRefsResponseReader(ctx context.Context, repo *gitalypb.Repository, rpc string, gitConfigOptions []string, gitProtocol string) (io.Reader, error) {
+ rpcRequest := &gitalypb.InfoRefsRequest{
+ Repository: repo,
+ GitConfigOptions: gitConfigOptions,
+ GitProtocol: gitProtocol,
+ }
+
+ switch rpc {
+ case "git-upload-pack":
+ stream, err := client.InfoRefsUploadPack(ctx, rpcRequest)
+ return infoRefsReader(stream), err
+ case "git-receive-pack":
+ stream, err := client.InfoRefsReceivePack(ctx, rpcRequest)
+ return infoRefsReader(stream), err
+ default:
+ return nil, fmt.Errorf("InfoRefsResponseWriterTo: Unsupported RPC: %q", rpc)
+ }
+}
+
+type infoRefsClient interface {
+ Recv() (*gitalypb.InfoRefsResponse, error)
+}
+
+func infoRefsReader(stream infoRefsClient) io.Reader {
+ return streamio.NewReader(func() ([]byte, error) {
+ resp, err := stream.Recv()
+ return resp.GetData(), err
+ })
+}
+
+func (client *SmartHTTPClient) ReceivePack(ctx context.Context, repo *gitalypb.Repository, glId string, glUsername string, glRepository string, gitConfigOptions []string, clientRequest io.Reader, clientResponse io.Writer, gitProtocol string) error {
+ stream, err := client.PostReceivePack(ctx)
+ if err != nil {
+ return err
+ }
+
+ rpcRequest := &gitalypb.PostReceivePackRequest{
+ Repository: repo,
+ GlId: glId,
+ GlUsername: glUsername,
+ GlRepository: glRepository,
+ GitConfigOptions: gitConfigOptions,
+ GitProtocol: gitProtocol,
+ }
+
+ if err := stream.Send(rpcRequest); err != nil {
+ return fmt.Errorf("initial request: %v", err)
+ }
+
+ numStreams := 2
+ errC := make(chan error, numStreams)
+
+ go func() {
+ rr := streamio.NewReader(func() ([]byte, error) {
+ response, err := stream.Recv()
+ return response.GetData(), err
+ })
+ _, err := io.Copy(clientResponse, rr)
+ errC <- err
+ }()
+
+ go func() {
+ sw := streamio.NewWriter(func(data []byte) error {
+ return stream.Send(&gitalypb.PostReceivePackRequest{Data: data})
+ })
+ _, err := io.Copy(sw, clientRequest)
+ stream.CloseSend()
+ errC <- err
+ }()
+
+ for i := 0; i < numStreams; i++ {
+ if err := <-errC; err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (client *SmartHTTPClient) UploadPack(ctx context.Context, repo *gitalypb.Repository, clientRequest io.Reader, clientResponse io.Writer, gitConfigOptions []string, gitProtocol string) error {
+ stream, err := client.PostUploadPack(ctx)
+ if err != nil {
+ return err
+ }
+
+ rpcRequest := &gitalypb.PostUploadPackRequest{
+ Repository: repo,
+ GitConfigOptions: gitConfigOptions,
+ GitProtocol: gitProtocol,
+ }
+
+ if err := stream.Send(rpcRequest); err != nil {
+ return fmt.Errorf("initial request: %v", err)
+ }
+
+ numStreams := 2
+ errC := make(chan error, numStreams)
+
+ go func() {
+ rr := streamio.NewReader(func() ([]byte, error) {
+ response, err := stream.Recv()
+ return response.GetData(), err
+ })
+ _, err := io.Copy(clientResponse, rr)
+ errC <- err
+ }()
+
+ go func() {
+ sw := streamio.NewWriter(func(data []byte) error {
+ return stream.Send(&gitalypb.PostUploadPackRequest{Data: data})
+ })
+ _, err := io.Copy(sw, clientRequest)
+ stream.CloseSend()
+ errC <- err
+ }()
+
+ for i := 0; i < numStreams; i++ {
+ if err := <-errC; err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/workhorse/internal/gitaly/unmarshal_test.go b/workhorse/internal/gitaly/unmarshal_test.go
new file mode 100644
index 00000000000..e2256903339
--- /dev/null
+++ b/workhorse/internal/gitaly/unmarshal_test.go
@@ -0,0 +1,35 @@
+package gitaly
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+)
+
+func TestUnmarshalJSON(t *testing.T) {
+ testCases := []struct {
+ desc string
+ in string
+ out gitalypb.Repository
+ }{
+ {
+ desc: "basic example",
+ in: `{"relative_path":"foo/bar.git"}`,
+ out: gitalypb.Repository{RelativePath: "foo/bar.git"},
+ },
+ {
+ desc: "unknown field",
+ in: `{"relative_path":"foo/bar.git","unknown_field":12345}`,
+ out: gitalypb.Repository{RelativePath: "foo/bar.git"},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ result := gitalypb.Repository{}
+ require.NoError(t, UnmarshalJSON(tc.in, &result))
+ require.Equal(t, tc.out, result)
+ })
+ }
+}
diff --git a/workhorse/internal/headers/content_headers.go b/workhorse/internal/headers/content_headers.go
new file mode 100644
index 00000000000..e43f10745d4
--- /dev/null
+++ b/workhorse/internal/headers/content_headers.go
@@ -0,0 +1,109 @@
+package headers
+
+import (
+ "net/http"
+ "regexp"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/utils/svg"
+)
+
+var (
+ ImageTypeRegex = regexp.MustCompile(`^image/*`)
+ SvgMimeTypeRegex = regexp.MustCompile(`^image/svg\+xml$`)
+
+ TextTypeRegex = regexp.MustCompile(`^text/*`)
+
+ VideoTypeRegex = regexp.MustCompile(`^video/*`)
+
+ PdfTypeRegex = regexp.MustCompile(`application\/pdf`)
+
+ AttachmentRegex = regexp.MustCompile(`^attachment`)
+ InlineRegex = regexp.MustCompile(`^inline`)
+)
+
+// Mime types that can't be inlined. Usually subtypes of main types
+var forbiddenInlineTypes = []*regexp.Regexp{SvgMimeTypeRegex}
+
+// Mime types that can be inlined. We can add global types like "image/" or
+// specific types like "text/plain". If there is a specific type inside a global
+// allowed type that can't be inlined we must add it to the forbiddenInlineTypes var.
+// One example of this is the mime type "image". We allow all images to be
+// inlined except for SVGs.
+var allowedInlineTypes = []*regexp.Regexp{ImageTypeRegex, TextTypeRegex, VideoTypeRegex, PdfTypeRegex}
+
+func SafeContentHeaders(data []byte, contentDisposition string) (string, string) {
+ contentType := safeContentType(data)
+ contentDisposition = safeContentDisposition(contentType, contentDisposition)
+ return contentType, contentDisposition
+}
+
+func safeContentType(data []byte) string {
+ // Special case for svg because DetectContentType detects it as text
+ if svg.Is(data) {
+ return "image/svg+xml"
+ }
+
+ // Override any existing Content-Type header from other ResponseWriters
+ contentType := http.DetectContentType(data)
+
+ // If the content is text type, we set to plain, because we don't
+ // want to render it inline if they're html or javascript
+ if isType(contentType, TextTypeRegex) {
+ return "text/plain; charset=utf-8"
+ }
+
+ return contentType
+}
+
+func safeContentDisposition(contentType string, contentDisposition string) string {
+ // If the existing disposition is attachment we return that. This allow us
+ // to force a download from GitLab (ie: RawController)
+ if AttachmentRegex.MatchString(contentDisposition) {
+ return contentDisposition
+ }
+
+ // Checks for mime types that are forbidden to be inline
+ for _, element := range forbiddenInlineTypes {
+ if isType(contentType, element) {
+ return attachmentDisposition(contentDisposition)
+ }
+ }
+
+ // Checks for mime types allowed to be inline
+ for _, element := range allowedInlineTypes {
+ if isType(contentType, element) {
+ return inlineDisposition(contentDisposition)
+ }
+ }
+
+ // Anything else is set to attachment
+ return attachmentDisposition(contentDisposition)
+}
+
+func attachmentDisposition(contentDisposition string) string {
+ if contentDisposition == "" {
+ return "attachment"
+ }
+
+ if InlineRegex.MatchString(contentDisposition) {
+ return InlineRegex.ReplaceAllString(contentDisposition, "attachment")
+ }
+
+ return contentDisposition
+}
+
+func inlineDisposition(contentDisposition string) string {
+ if contentDisposition == "" {
+ return "inline"
+ }
+
+ if AttachmentRegex.MatchString(contentDisposition) {
+ return AttachmentRegex.ReplaceAllString(contentDisposition, "inline")
+ }
+
+ return contentDisposition
+}
+
+func isType(contentType string, mimeType *regexp.Regexp) bool {
+ return mimeType.MatchString(contentType)
+}
diff --git a/workhorse/internal/headers/headers.go b/workhorse/internal/headers/headers.go
new file mode 100644
index 00000000000..63b39a6aa41
--- /dev/null
+++ b/workhorse/internal/headers/headers.go
@@ -0,0 +1,62 @@
+package headers
+
+import (
+ "net/http"
+ "strconv"
+)
+
+// Max number of bytes that http.DetectContentType needs to get the content type
+// Fixme: Go back to 512 bytes once https://gitlab.com/gitlab-org/gitlab-workhorse/issues/208
+// has been merged
+const MaxDetectSize = 4096
+
+// HTTP Headers
+const (
+ ContentDispositionHeader = "Content-Disposition"
+ ContentTypeHeader = "Content-Type"
+
+ // Workhorse related headers
+ GitlabWorkhorseSendDataHeader = "Gitlab-Workhorse-Send-Data"
+ XSendFileHeader = "X-Sendfile"
+ XSendFileTypeHeader = "X-Sendfile-Type"
+
+ // Signal header that indicates Workhorse should detect and set the content headers
+ GitlabWorkhorseDetectContentTypeHeader = "Gitlab-Workhorse-Detect-Content-Type"
+)
+
+var ResponseHeaders = []string{
+ XSendFileHeader,
+ GitlabWorkhorseSendDataHeader,
+ GitlabWorkhorseDetectContentTypeHeader,
+}
+
+func IsDetectContentTypeHeaderPresent(rw http.ResponseWriter) bool {
+ header, err := strconv.ParseBool(rw.Header().Get(GitlabWorkhorseDetectContentTypeHeader))
+ if err != nil || !header {
+ return false
+ }
+
+ return true
+}
+
+// AnyResponseHeaderPresent checks in the ResponseWriter if there is any Response Header
+func AnyResponseHeaderPresent(rw http.ResponseWriter) bool {
+ // If this header is not present means that we want the old behavior
+ if !IsDetectContentTypeHeaderPresent(rw) {
+ return false
+ }
+
+ for _, header := range ResponseHeaders {
+ if rw.Header().Get(header) != "" {
+ return true
+ }
+ }
+ return false
+}
+
+// RemoveResponseHeaders removes any ResponseHeader from the ResponseWriter
+func RemoveResponseHeaders(rw http.ResponseWriter) {
+ for _, header := range ResponseHeaders {
+ rw.Header().Del(header)
+ }
+}
diff --git a/workhorse/internal/headers/headers_test.go b/workhorse/internal/headers/headers_test.go
new file mode 100644
index 00000000000..555406ff165
--- /dev/null
+++ b/workhorse/internal/headers/headers_test.go
@@ -0,0 +1,24 @@
+package headers
+
+import (
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestIsDetectContentTypeHeaderPresent(t *testing.T) {
+ rw := httptest.NewRecorder()
+
+ rw.Header().Del(GitlabWorkhorseDetectContentTypeHeader)
+ require.Equal(t, false, IsDetectContentTypeHeaderPresent(rw))
+
+ rw.Header().Set(GitlabWorkhorseDetectContentTypeHeader, "true")
+ require.Equal(t, true, IsDetectContentTypeHeaderPresent(rw))
+
+ rw.Header().Set(GitlabWorkhorseDetectContentTypeHeader, "false")
+ require.Equal(t, false, IsDetectContentTypeHeaderPresent(rw))
+
+ rw.Header().Set(GitlabWorkhorseDetectContentTypeHeader, "foobar")
+ require.Equal(t, false, IsDetectContentTypeHeaderPresent(rw))
+}
diff --git a/workhorse/internal/helper/context_reader.go b/workhorse/internal/helper/context_reader.go
new file mode 100644
index 00000000000..a4764043147
--- /dev/null
+++ b/workhorse/internal/helper/context_reader.go
@@ -0,0 +1,40 @@
+package helper
+
+import (
+ "context"
+ "io"
+)
+
+type ContextReader struct {
+ ctx context.Context
+ underlyingReader io.Reader
+}
+
+func NewContextReader(ctx context.Context, underlyingReader io.Reader) *ContextReader {
+ return &ContextReader{
+ ctx: ctx,
+ underlyingReader: underlyingReader,
+ }
+}
+
+func (r *ContextReader) Read(b []byte) (int, error) {
+ if r.canceled() {
+ return 0, r.err()
+ }
+
+ n, err := r.underlyingReader.Read(b)
+
+ if r.canceled() {
+ err = r.err()
+ }
+
+ return n, err
+}
+
+func (r *ContextReader) canceled() bool {
+ return r.err() != nil
+}
+
+func (r *ContextReader) err() error {
+ return r.ctx.Err()
+}
diff --git a/workhorse/internal/helper/context_reader_test.go b/workhorse/internal/helper/context_reader_test.go
new file mode 100644
index 00000000000..257ec4e35f2
--- /dev/null
+++ b/workhorse/internal/helper/context_reader_test.go
@@ -0,0 +1,83 @@
+package helper
+
+import (
+ "context"
+ "io"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type fakeReader struct {
+ n int
+ err error
+}
+
+func (f *fakeReader) Read(b []byte) (int, error) {
+ return f.n, f.err
+}
+
+type fakeContextWithTimeout struct {
+ n int
+ threshold int
+}
+
+func (*fakeContextWithTimeout) Deadline() (deadline time.Time, ok bool) {
+ return
+}
+
+func (*fakeContextWithTimeout) Done() <-chan struct{} {
+ return nil
+}
+
+func (*fakeContextWithTimeout) Value(key interface{}) interface{} {
+ return nil
+}
+
+func (f *fakeContextWithTimeout) Err() error {
+ f.n++
+ if f.n > f.threshold {
+ return context.DeadlineExceeded
+ }
+
+ return nil
+}
+
+func TestContextReaderRead(t *testing.T) {
+ underlyingReader := &fakeReader{n: 1, err: io.EOF}
+
+ for _, tc := range []struct {
+ desc string
+ ctx *fakeContextWithTimeout
+ expectedN int
+ expectedErr error
+ }{
+ {
+ desc: "Before and after read deadline checks are fine",
+ ctx: &fakeContextWithTimeout{n: 0, threshold: 2},
+ expectedN: underlyingReader.n,
+ expectedErr: underlyingReader.err,
+ },
+ {
+ desc: "Before read deadline check fails",
+ ctx: &fakeContextWithTimeout{n: 0, threshold: 0},
+ expectedN: 0,
+ expectedErr: context.DeadlineExceeded,
+ },
+ {
+ desc: "After read deadline check fails",
+ ctx: &fakeContextWithTimeout{n: 0, threshold: 1},
+ expectedN: underlyingReader.n,
+ expectedErr: context.DeadlineExceeded,
+ },
+ } {
+ t.Run(tc.desc, func(t *testing.T) {
+ cr := NewContextReader(tc.ctx, underlyingReader)
+
+ n, err := cr.Read(nil)
+ require.Equal(t, tc.expectedN, n)
+ require.Equal(t, tc.expectedErr, err)
+ })
+ }
+}
diff --git a/workhorse/internal/helper/countingresponsewriter.go b/workhorse/internal/helper/countingresponsewriter.go
new file mode 100644
index 00000000000..a79d51d4c6a
--- /dev/null
+++ b/workhorse/internal/helper/countingresponsewriter.go
@@ -0,0 +1,56 @@
+package helper
+
+import (
+ "net/http"
+)
+
+type CountingResponseWriter interface {
+ http.ResponseWriter
+ Count() int64
+ Status() int
+}
+
+type countingResponseWriter struct {
+ rw http.ResponseWriter
+ status int
+ count int64
+}
+
+func NewCountingResponseWriter(rw http.ResponseWriter) CountingResponseWriter {
+ return &countingResponseWriter{rw: rw}
+}
+
+func (c *countingResponseWriter) Header() http.Header {
+ return c.rw.Header()
+}
+
+func (c *countingResponseWriter) Write(data []byte) (int, error) {
+ if c.status == 0 {
+ c.WriteHeader(http.StatusOK)
+ }
+
+ n, err := c.rw.Write(data)
+ c.count += int64(n)
+ return n, err
+}
+
+func (c *countingResponseWriter) WriteHeader(status int) {
+ if c.status != 0 {
+ return
+ }
+
+ c.status = status
+ c.rw.WriteHeader(status)
+}
+
+// Count returns the number of bytes written to the ResponseWriter. This
+// function is not thread-safe.
+func (c *countingResponseWriter) Count() int64 {
+ return c.count
+}
+
+// Status returns the first HTTP status value that was written to the
+// ResponseWriter. This function is not thread-safe.
+func (c *countingResponseWriter) Status() int {
+ return c.status
+}
diff --git a/workhorse/internal/helper/countingresponsewriter_test.go b/workhorse/internal/helper/countingresponsewriter_test.go
new file mode 100644
index 00000000000..f9f2f4ced5b
--- /dev/null
+++ b/workhorse/internal/helper/countingresponsewriter_test.go
@@ -0,0 +1,50 @@
+package helper
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+ "testing"
+ "testing/iotest"
+
+ "github.com/stretchr/testify/require"
+)
+
+type testResponseWriter struct {
+ data []byte
+}
+
+func (*testResponseWriter) WriteHeader(int) {}
+func (*testResponseWriter) Header() http.Header { return nil }
+
+func (trw *testResponseWriter) Write(p []byte) (int, error) {
+ trw.data = append(trw.data, p...)
+ return len(p), nil
+}
+
+func TestCountingResponseWriterStatus(t *testing.T) {
+ crw := NewCountingResponseWriter(&testResponseWriter{})
+ crw.WriteHeader(123)
+ crw.WriteHeader(456)
+ require.Equal(t, 123, crw.Status())
+}
+
+func TestCountingResponseWriterCount(t *testing.T) {
+ crw := NewCountingResponseWriter(&testResponseWriter{})
+ for _, n := range []int{1, 2, 4, 8, 16, 32} {
+ _, err := crw.Write(bytes.Repeat([]byte{'.'}, n))
+ require.NoError(t, err)
+ }
+ require.Equal(t, int64(63), crw.Count())
+}
+
+func TestCountingResponseWriterWrite(t *testing.T) {
+ trw := &testResponseWriter{}
+ crw := NewCountingResponseWriter(trw)
+
+ testData := []byte("test data")
+ _, err := io.Copy(crw, iotest.OneByteReader(bytes.NewReader(testData)))
+ require.NoError(t, err)
+
+ require.Equal(t, string(testData), string(trw.data))
+}
diff --git a/workhorse/internal/helper/helpers.go b/workhorse/internal/helper/helpers.go
new file mode 100644
index 00000000000..5f1e5fc51b3
--- /dev/null
+++ b/workhorse/internal/helper/helpers.go
@@ -0,0 +1,217 @@
+package helper
+
+import (
+ "bytes"
+ "errors"
+ "io/ioutil"
+ "mime"
+ "net"
+ "net/http"
+ "net/url"
+ "os"
+ "os/exec"
+ "strings"
+ "syscall"
+
+ "github.com/sebest/xff"
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/mask"
+)
+
+const NginxResponseBufferHeader = "X-Accel-Buffering"
+
+func LogError(r *http.Request, err error) {
+ LogErrorWithFields(r, err, nil)
+}
+
+func LogErrorWithFields(r *http.Request, err error, fields log.Fields) {
+ if err != nil {
+ captureRavenError(r, err, fields)
+ }
+
+ printError(r, err, fields)
+}
+
+func CaptureAndFail(w http.ResponseWriter, r *http.Request, err error, msg string, code int) {
+ http.Error(w, msg, code)
+ LogError(r, err)
+}
+
+func CaptureAndFailWithFields(w http.ResponseWriter, r *http.Request, err error, msg string, code int, fields log.Fields) {
+ http.Error(w, msg, code)
+ LogErrorWithFields(r, err, fields)
+}
+
+func Fail500(w http.ResponseWriter, r *http.Request, err error) {
+ CaptureAndFail(w, r, err, "Internal server error", http.StatusInternalServerError)
+}
+
+func Fail500WithFields(w http.ResponseWriter, r *http.Request, err error, fields log.Fields) {
+ CaptureAndFailWithFields(w, r, err, "Internal server error", http.StatusInternalServerError, fields)
+}
+
+func RequestEntityTooLarge(w http.ResponseWriter, r *http.Request, err error) {
+ CaptureAndFail(w, r, err, "Request Entity Too Large", http.StatusRequestEntityTooLarge)
+}
+
+func printError(r *http.Request, err error, fields log.Fields) {
+ if r != nil {
+ entry := log.WithContextFields(r.Context(), log.Fields{
+ "method": r.Method,
+ "uri": mask.URL(r.RequestURI),
+ })
+ entry.WithFields(fields).WithError(err).Error("error")
+ } else {
+ log.WithFields(fields).WithError(err).Error("unknown error")
+ }
+}
+
+func SetNoCacheHeaders(header http.Header) {
+ header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
+ header.Set("Pragma", "no-cache")
+ header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
+}
+
+func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) {
+ file, err = os.Open(path)
+ if err != nil {
+ return
+ }
+
+ defer func() {
+ if err != nil {
+ file.Close()
+ }
+ }()
+
+ fi, err = file.Stat()
+ if err != nil {
+ return
+ }
+
+ // The os.Open can also open directories
+ if fi.IsDir() {
+ err = &os.PathError{
+ Op: "open",
+ Path: path,
+ Err: errors.New("path is directory"),
+ }
+ return
+ }
+
+ return
+}
+
+func URLMustParse(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil {
+ log.WithError(err).WithField("url", s).Fatal("urlMustParse")
+ }
+ return u
+}
+
+func HTTPError(w http.ResponseWriter, r *http.Request, error string, code int) {
+ if r.ProtoAtLeast(1, 1) {
+ // Force client to disconnect if we render request error
+ w.Header().Set("Connection", "close")
+ }
+
+ http.Error(w, error, code)
+}
+
+func HeaderClone(h http.Header) http.Header {
+ h2 := make(http.Header, len(h))
+ for k, vv := range h {
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ h2[k] = vv2
+ }
+ return h2
+}
+
+func CleanUpProcessGroup(cmd *exec.Cmd) {
+ if cmd == nil {
+ return
+ }
+
+ process := cmd.Process
+ if process != nil && process.Pid > 0 {
+ // Send SIGTERM to the process group of cmd
+ syscall.Kill(-process.Pid, syscall.SIGTERM)
+ }
+
+ // reap our child process
+ cmd.Wait()
+}
+
+func ExitStatus(err error) (int, bool) {
+ exitError, ok := err.(*exec.ExitError)
+ if !ok {
+ return 0, false
+ }
+
+ waitStatus, ok := exitError.Sys().(syscall.WaitStatus)
+ if !ok {
+ return 0, false
+ }
+
+ return waitStatus.ExitStatus(), true
+}
+
+func DisableResponseBuffering(w http.ResponseWriter) {
+ w.Header().Set(NginxResponseBufferHeader, "no")
+}
+
+func AllowResponseBuffering(w http.ResponseWriter) {
+ w.Header().Del(NginxResponseBufferHeader)
+}
+
+func FixRemoteAddr(r *http.Request) {
+ // Unix domain sockets have a remote addr of @. This will make the
+ // xff package lookup the X-Forwarded-For address if available.
+ if r.RemoteAddr == "@" {
+ r.RemoteAddr = "127.0.0.1:0"
+ }
+ r.RemoteAddr = xff.GetRemoteAddr(r)
+}
+
+func SetForwardedFor(newHeaders *http.Header, originalRequest *http.Request) {
+ if clientIP, _, err := net.SplitHostPort(originalRequest.RemoteAddr); err == nil {
+ var header string
+
+ // If we aren't the first proxy retain prior
+ // X-Forwarded-For information as a comma+space
+ // separated list and fold multiple headers into one.
+ if prior, ok := originalRequest.Header["X-Forwarded-For"]; ok {
+ header = strings.Join(prior, ", ") + ", " + clientIP
+ } else {
+ header = clientIP
+ }
+ newHeaders.Set("X-Forwarded-For", header)
+ }
+}
+
+func IsContentType(expected, actual string) bool {
+ parsed, _, err := mime.ParseMediaType(actual)
+ return err == nil && parsed == expected
+}
+
+func IsApplicationJson(r *http.Request) bool {
+ contentType := r.Header.Get("Content-Type")
+ return IsContentType("application/json", contentType)
+}
+
+func ReadRequestBody(w http.ResponseWriter, r *http.Request, maxBodySize int64) ([]byte, error) {
+ limitedBody := http.MaxBytesReader(w, r.Body, maxBodySize)
+ defer limitedBody.Close()
+
+ return ioutil.ReadAll(limitedBody)
+}
+
+func CloneRequestWithNewBody(r *http.Request, body []byte) *http.Request {
+ newReq := *r
+ newReq.Body = ioutil.NopCloser(bytes.NewReader(body))
+ newReq.Header = HeaderClone(r.Header)
+ newReq.ContentLength = int64(len(body))
+ return &newReq
+}
diff --git a/workhorse/internal/helper/helpers_test.go b/workhorse/internal/helper/helpers_test.go
new file mode 100644
index 00000000000..6a895aded03
--- /dev/null
+++ b/workhorse/internal/helper/helpers_test.go
@@ -0,0 +1,258 @@
+package helper
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/require"
+)
+
+func TestFixRemoteAddr(t *testing.T) {
+ testCases := []struct {
+ initial string
+ forwarded string
+ expected string
+ }{
+ {initial: "@", forwarded: "", expected: "127.0.0.1:0"},
+ {initial: "@", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
+ {initial: "@", forwarded: "127.0.0.1", expected: "127.0.0.1:0"},
+ {initial: "@", forwarded: "192.168.0.1", expected: "127.0.0.1:0"},
+ {initial: "192.168.1.1:0", forwarded: "", expected: "192.168.1.1:0"},
+ {initial: "192.168.1.1:0", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
+ }
+
+ for _, tc := range testCases {
+ req, err := http.NewRequest("POST", "unix:///tmp/test.socket/info/refs", nil)
+ require.NoError(t, err)
+
+ req.RemoteAddr = tc.initial
+
+ if tc.forwarded != "" {
+ req.Header.Add("X-Forwarded-For", tc.forwarded)
+ }
+
+ FixRemoteAddr(req)
+
+ require.Equal(t, tc.expected, req.RemoteAddr)
+ }
+}
+
+func TestSetForwardedForGeneratesHeader(t *testing.T) {
+ testCases := []struct {
+ remoteAddr string
+ previousForwardedFor []string
+ expected string
+ }{
+ {
+ "8.8.8.8:3000",
+ nil,
+ "8.8.8.8",
+ },
+ {
+ "8.8.8.8:3000",
+ []string{"138.124.33.63, 151.146.211.237"},
+ "138.124.33.63, 151.146.211.237, 8.8.8.8",
+ },
+ {
+ "8.8.8.8:3000",
+ []string{"8.154.76.107", "115.206.118.179"},
+ "8.154.76.107, 115.206.118.179, 8.8.8.8",
+ },
+ }
+ for _, tc := range testCases {
+ headers := http.Header{}
+ originalRequest := http.Request{
+ RemoteAddr: tc.remoteAddr,
+ }
+
+ if tc.previousForwardedFor != nil {
+ originalRequest.Header = http.Header{
+ "X-Forwarded-For": tc.previousForwardedFor,
+ }
+ }
+
+ SetForwardedFor(&headers, &originalRequest)
+
+ result := headers.Get("X-Forwarded-For")
+ if result != tc.expected {
+ t.Fatalf("Expected %v, got %v", tc.expected, result)
+ }
+ }
+}
+
+func TestReadRequestBody(t *testing.T) {
+ data := []byte("123456")
+ rw := httptest.NewRecorder()
+ req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data))
+
+ result, err := ReadRequestBody(rw, req, 1000)
+ require.NoError(t, err)
+ require.Equal(t, data, result)
+}
+
+func TestReadRequestBodyLimit(t *testing.T) {
+ data := []byte("123456")
+ rw := httptest.NewRecorder()
+ req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data))
+
+ _, err := ReadRequestBody(rw, req, 2)
+ require.Error(t, err)
+}
+
+func TestCloneRequestWithBody(t *testing.T) {
+ input := []byte("test")
+ newInput := []byte("new body")
+ req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(input))
+ newReq := CloneRequestWithNewBody(req, newInput)
+
+ require.NotEqual(t, req, newReq)
+ require.NotEqual(t, req.Body, newReq.Body)
+ require.NotEqual(t, len(newInput), newReq.ContentLength)
+
+ var buffer bytes.Buffer
+ io.Copy(&buffer, newReq.Body)
+ require.Equal(t, newInput, buffer.Bytes())
+}
+
+func TestApplicationJson(t *testing.T) {
+ req, _ := http.NewRequest("POST", "/test", nil)
+ req.Header.Set("Content-Type", "application/json")
+
+ require.True(t, IsApplicationJson(req), "expected to match 'application/json' as 'application/json'")
+
+ req.Header.Set("Content-Type", "application/json; charset=utf-8")
+ require.True(t, IsApplicationJson(req), "expected to match 'application/json; charset=utf-8' as 'application/json'")
+
+ req.Header.Set("Content-Type", "text/plain")
+ require.False(t, IsApplicationJson(req), "expected not to match 'text/plain' as 'application/json'")
+}
+
+func TestFail500WorksWithNils(t *testing.T) {
+ body := bytes.NewBuffer(nil)
+ w := httptest.NewRecorder()
+ w.Body = body
+
+ Fail500(w, nil, nil)
+
+ require.Equal(t, http.StatusInternalServerError, w.Code)
+ require.Equal(t, "Internal server error\n", body.String())
+}
+
+func TestLogError(t *testing.T) {
+ tests := []struct {
+ name string
+ method string
+ uri string
+ err error
+ logMatchers []string
+ }{
+ {
+ name: "nil_request",
+ err: fmt.Errorf("crash"),
+ logMatchers: []string{
+ `level=error msg="unknown error" error=crash`,
+ },
+ },
+ {
+ name: "nil_request_nil_error",
+ err: nil,
+ logMatchers: []string{
+ `level=error msg="unknown error" error="<nil>"`,
+ },
+ },
+ {
+ name: "basic_url",
+ method: "GET",
+ uri: "http://localhost:3000/",
+ err: fmt.Errorf("error"),
+ logMatchers: []string{
+ `level=error msg=error correlation_id= error=error method=GET uri="http://localhost:3000/"`,
+ },
+ },
+ {
+ name: "secret_url",
+ method: "GET",
+ uri: "http://localhost:3000/path?certificate=123&sharedSecret=123&import_url=the_url&my_password_string=password",
+ err: fmt.Errorf("error"),
+ logMatchers: []string{
+ `level=error msg=error correlation_id= error=error method=GET uri="http://localhost:3000/path\?certificate=\[FILTERED\]&sharedSecret=\[FILTERED\]&import_url=\[FILTERED\]&my_password_string=\[FILTERED\]"`,
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := &bytes.Buffer{}
+
+ oldOut := logrus.StandardLogger().Out
+ logrus.StandardLogger().Out = buf
+ defer func() {
+ logrus.StandardLogger().Out = oldOut
+ }()
+
+ var r *http.Request
+ if tt.uri != "" {
+ r = httptest.NewRequest(tt.method, tt.uri, nil)
+ }
+
+ LogError(r, tt.err)
+
+ logString := buf.String()
+ for _, v := range tt.logMatchers {
+ require.Regexp(t, v, logString)
+ }
+ })
+ }
+}
+
+func TestLogErrorWithFields(t *testing.T) {
+ tests := []struct {
+ name string
+ request *http.Request
+ err error
+ fields map[string]interface{}
+ logMatcher string
+ }{
+ {
+ name: "nil_request",
+ err: fmt.Errorf("crash"),
+ fields: map[string]interface{}{"extra_one": 123},
+ logMatcher: `level=error msg="unknown error" error=crash extra_one=123`,
+ },
+ {
+ name: "nil_request_nil_error",
+ err: nil,
+ fields: map[string]interface{}{"extra_one": 123, "extra_two": "test"},
+ logMatcher: `level=error msg="unknown error" error="<nil>" extra_one=123 extra_two=test`,
+ },
+ {
+ name: "basic_url",
+ request: httptest.NewRequest("GET", "http://localhost:3000/", nil),
+ err: fmt.Errorf("error"),
+ fields: map[string]interface{}{"extra_one": 123, "extra_two": "test"},
+ logMatcher: `level=error msg=error correlation_id= error=error extra_one=123 extra_two=test method=GET uri="http://localhost:3000/`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := &bytes.Buffer{}
+
+ oldOut := logrus.StandardLogger().Out
+ logrus.StandardLogger().Out = buf
+ defer func() {
+ logrus.StandardLogger().Out = oldOut
+ }()
+
+ LogErrorWithFields(tt.request, tt.err, tt.fields)
+
+ logString := buf.String()
+ require.Contains(t, logString, tt.logMatcher)
+ })
+ }
+}
diff --git a/workhorse/internal/helper/raven.go b/workhorse/internal/helper/raven.go
new file mode 100644
index 00000000000..ea1d0e1f6cc
--- /dev/null
+++ b/workhorse/internal/helper/raven.go
@@ -0,0 +1,58 @@
+package helper
+
+import (
+ "net/http"
+ "reflect"
+
+ raven "github.com/getsentry/raven-go"
+
+ //lint:ignore SA1019 this was recently deprecated. Update workhorse to use labkit errortracking package.
+ correlation "gitlab.com/gitlab-org/labkit/correlation/raven"
+
+ "gitlab.com/gitlab-org/labkit/log"
+)
+
+var ravenHeaderBlacklist = []string{
+ "Authorization",
+ "Private-Token",
+}
+
+func captureRavenError(r *http.Request, err error, fields log.Fields) {
+ client := raven.DefaultClient
+ extra := raven.Extra{}
+
+ for k, v := range fields {
+ extra[k] = v
+ }
+
+ interfaces := []raven.Interface{}
+ if r != nil {
+ CleanHeadersForRaven(r)
+ interfaces = append(interfaces, raven.NewHttp(r))
+
+ //lint:ignore SA1019 this was recently deprecated. Update workhorse to use labkit errortracking package.
+ extra = correlation.SetExtra(r.Context(), extra)
+ }
+
+ exception := &raven.Exception{
+ Stacktrace: raven.NewStacktrace(2, 3, nil),
+ Value: err.Error(),
+ Type: reflect.TypeOf(err).String(),
+ }
+ interfaces = append(interfaces, exception)
+
+ packet := raven.NewPacketWithExtra(err.Error(), extra, interfaces...)
+ client.Capture(packet, nil)
+}
+
+func CleanHeadersForRaven(r *http.Request) {
+ if r == nil {
+ return
+ }
+
+ for _, key := range ravenHeaderBlacklist {
+ if r.Header.Get(key) != "" {
+ r.Header.Set(key, "[redacted]")
+ }
+ }
+}
diff --git a/workhorse/internal/helper/tempfile.go b/workhorse/internal/helper/tempfile.go
new file mode 100644
index 00000000000..d8fc0d44698
--- /dev/null
+++ b/workhorse/internal/helper/tempfile.go
@@ -0,0 +1,35 @@
+package helper
+
+import (
+ "io"
+ "io/ioutil"
+ "os"
+)
+
+func ReadAllTempfile(r io.Reader) (tempfile *os.File, err error) {
+ tempfile, err = ioutil.TempFile("", "gitlab-workhorse-read-all-tempfile")
+ if err != nil {
+ return nil, err
+ }
+
+ defer func() {
+ // Avoid leaking an open file if the function returns with an error
+ if err != nil {
+ tempfile.Close()
+ }
+ }()
+
+ if err := os.Remove(tempfile.Name()); err != nil {
+ return nil, err
+ }
+
+ if _, err := io.Copy(tempfile, r); err != nil {
+ return nil, err
+ }
+
+ if _, err := tempfile.Seek(0, 0); err != nil {
+ return nil, err
+ }
+
+ return tempfile, nil
+}
diff --git a/workhorse/internal/helper/writeafterreader.go b/workhorse/internal/helper/writeafterreader.go
new file mode 100644
index 00000000000..d583ae4a9b8
--- /dev/null
+++ b/workhorse/internal/helper/writeafterreader.go
@@ -0,0 +1,144 @@
+package helper
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "sync"
+)
+
+type WriteFlusher interface {
+ io.Writer
+ Flush() error
+}
+
+// Couple r and w so that until r has been drained (before r.Read() has
+// returned some error), all writes to w are sent to a tempfile first.
+// The caller must call Flush() on the returned WriteFlusher to ensure
+// all data is propagated to w.
+func NewWriteAfterReader(r io.Reader, w io.Writer) (io.Reader, WriteFlusher) {
+ br := &busyReader{Reader: r}
+ return br, &coupledWriter{Writer: w, busyReader: br}
+}
+
+type busyReader struct {
+ io.Reader
+
+ error
+ errorMutex sync.RWMutex
+}
+
+func (r *busyReader) Read(p []byte) (int, error) {
+ if err := r.getError(); err != nil {
+ return 0, err
+ }
+
+ n, err := r.Reader.Read(p)
+ if err != nil {
+ if err != io.EOF {
+ err = fmt.Errorf("busyReader: %v", err)
+ }
+ r.setError(err)
+ }
+ return n, err
+}
+
+func (r *busyReader) IsBusy() bool {
+ return r.getError() == nil
+}
+
+func (r *busyReader) getError() error {
+ r.errorMutex.RLock()
+ defer r.errorMutex.RUnlock()
+ return r.error
+}
+
+func (r *busyReader) setError(err error) {
+ if err == nil {
+ panic("busyReader: attempt to reset error to nil")
+ }
+ r.errorMutex.Lock()
+ defer r.errorMutex.Unlock()
+ r.error = err
+}
+
+type coupledWriter struct {
+ io.Writer
+ *busyReader
+
+ tempfile *os.File
+ tempfileMutex sync.Mutex
+
+ writeError error
+}
+
+func (w *coupledWriter) Write(data []byte) (int, error) {
+ if w.writeError != nil {
+ return 0, w.writeError
+ }
+
+ if w.busyReader.IsBusy() {
+ n, err := w.tempfileWrite(data)
+ if err != nil {
+ w.writeError = fmt.Errorf("coupledWriter: %v", err)
+ }
+ return n, w.writeError
+ }
+
+ if err := w.Flush(); err != nil {
+ w.writeError = fmt.Errorf("coupledWriter: %v", err)
+ return 0, w.writeError
+ }
+
+ return w.Writer.Write(data)
+}
+
+func (w *coupledWriter) Flush() error {
+ w.tempfileMutex.Lock()
+ defer w.tempfileMutex.Unlock()
+
+ tempfile := w.tempfile
+ if tempfile == nil {
+ return nil
+ }
+
+ w.tempfile = nil
+ defer tempfile.Close()
+
+ if _, err := tempfile.Seek(0, 0); err != nil {
+ return err
+ }
+ if _, err := io.Copy(w.Writer, tempfile); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (w *coupledWriter) tempfileWrite(data []byte) (int, error) {
+ w.tempfileMutex.Lock()
+ defer w.tempfileMutex.Unlock()
+
+ if w.tempfile == nil {
+ tempfile, err := w.newTempfile()
+ if err != nil {
+ return 0, err
+ }
+ w.tempfile = tempfile
+ }
+
+ return w.tempfile.Write(data)
+}
+
+func (*coupledWriter) newTempfile() (tempfile *os.File, err error) {
+ tempfile, err = ioutil.TempFile("", "gitlab-workhorse-coupledWriter")
+ if err != nil {
+ return nil, err
+ }
+ if err := os.Remove(tempfile.Name()); err != nil {
+ tempfile.Close()
+ return nil, err
+ }
+
+ return tempfile, nil
+}
diff --git a/workhorse/internal/helper/writeafterreader_test.go b/workhorse/internal/helper/writeafterreader_test.go
new file mode 100644
index 00000000000..67cb3e6e542
--- /dev/null
+++ b/workhorse/internal/helper/writeafterreader_test.go
@@ -0,0 +1,115 @@
+package helper
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "testing"
+ "testing/iotest"
+)
+
+func TestBusyReader(t *testing.T) {
+ testData := "test data"
+ r := testReader(testData)
+ br, _ := NewWriteAfterReader(r, &bytes.Buffer{})
+
+ result, err := ioutil.ReadAll(br)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if string(result) != testData {
+ t.Fatalf("expected %q, got %q", testData, result)
+ }
+}
+
+func TestFirstWriteAfterReadDone(t *testing.T) {
+ writeRecorder := &bytes.Buffer{}
+ br, cw := NewWriteAfterReader(&bytes.Buffer{}, writeRecorder)
+ if _, err := io.Copy(ioutil.Discard, br); err != nil {
+ t.Fatalf("copy from busyreader: %v", err)
+ }
+ testData := "test data"
+ if _, err := io.Copy(cw, testReader(testData)); err != nil {
+ t.Fatalf("copy test data: %v", err)
+ }
+ if err := cw.Flush(); err != nil {
+ t.Fatalf("flush error: %v", err)
+ }
+ if result := writeRecorder.String(); result != testData {
+ t.Fatalf("expected %q, got %q", testData, result)
+ }
+}
+
+func TestWriteDelay(t *testing.T) {
+ writeRecorder := &bytes.Buffer{}
+ w := &complainingWriter{Writer: writeRecorder}
+ br, cw := NewWriteAfterReader(&bytes.Buffer{}, w)
+
+ testData1 := "1 test"
+ if _, err := io.Copy(cw, testReader(testData1)); err != nil {
+ t.Fatalf("error on first copy: %v", err)
+ }
+
+ // Unblock the coupled writer by draining the reader
+ if _, err := io.Copy(ioutil.Discard, br); err != nil {
+ t.Fatalf("copy from busyreader: %v", err)
+ }
+ // Now it is no longer an error if 'w' receives a Write()
+ w.CheerUp()
+
+ testData2 := "2 experiment"
+ if _, err := io.Copy(cw, testReader(testData2)); err != nil {
+ t.Fatalf("error on second copy: %v", err)
+ }
+
+ if err := cw.Flush(); err != nil {
+ t.Fatalf("flush error: %v", err)
+ }
+
+ expected := testData1 + testData2
+ if result := writeRecorder.String(); result != expected {
+ t.Fatalf("total write: expected %q, got %q", expected, result)
+ }
+}
+
+func TestComplainingWriterSanity(t *testing.T) {
+ recorder := &bytes.Buffer{}
+ w := &complainingWriter{Writer: recorder}
+
+ testData := "test data"
+ if _, err := io.Copy(w, testReader(testData)); err == nil {
+ t.Error("error expected, none received")
+ }
+
+ w.CheerUp()
+ if _, err := io.Copy(w, testReader(testData)); err != nil {
+ t.Errorf("copy after CheerUp: %v", err)
+ }
+
+ if result := recorder.String(); result != testData {
+ t.Errorf("expected %q, got %q", testData, result)
+ }
+}
+
+func testReader(data string) io.Reader {
+ return iotest.OneByteReader(bytes.NewBuffer([]byte(data)))
+}
+
+type complainingWriter struct {
+ happy bool
+ io.Writer
+}
+
+func (comp *complainingWriter) Write(data []byte) (int, error) {
+ if comp.happy {
+ return comp.Writer.Write(data)
+ }
+
+ return 0, fmt.Errorf("I am unhappy about you wanting to write %q", data)
+}
+
+func (comp *complainingWriter) CheerUp() {
+ comp.happy = true
+}
diff --git a/workhorse/internal/httprs/LICENSE b/workhorse/internal/httprs/LICENSE
new file mode 100644
index 00000000000..58b9dd5ced1
--- /dev/null
+++ b/workhorse/internal/httprs/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2015 Jean-François Bustarret
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE. \ No newline at end of file
diff --git a/workhorse/internal/httprs/README.md b/workhorse/internal/httprs/README.md
new file mode 100644
index 00000000000..4f42489ab73
--- /dev/null
+++ b/workhorse/internal/httprs/README.md
@@ -0,0 +1,2 @@
+This directory contains a vendored copy of https://github.com/jfbus/httprs at commit SHA
+b0af8319bb15446bbf29715477f841a49330a1e7.
diff --git a/workhorse/internal/httprs/httprs.go b/workhorse/internal/httprs/httprs.go
new file mode 100644
index 00000000000..a38230c1968
--- /dev/null
+++ b/workhorse/internal/httprs/httprs.go
@@ -0,0 +1,217 @@
+/*
+Package httprs provides a ReadSeeker for http.Response.Body.
+
+Usage :
+
+ resp, err := http.Get(url)
+ rs := httprs.NewHttpReadSeeker(resp)
+ defer rs.Close()
+ io.ReadFull(rs, buf) // reads the first bytes from the response body
+ rs.Seek(1024, 0) // moves the position, but does no range request
+ io.ReadFull(rs, buf) // does a range request and reads from the response body
+
+If you want use a specific http.Client for additional range requests :
+ rs := httprs.NewHttpReadSeeker(resp, client)
+*/
+package httprs
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+
+ "github.com/mitchellh/copystructure"
+)
+
+const shortSeekBytes = 1024
+
+// A HttpReadSeeker reads from a http.Response.Body. It can Seek
+// by doing range requests.
+type HttpReadSeeker struct {
+ c *http.Client
+ req *http.Request
+ res *http.Response
+ ctx context.Context
+ r io.ReadCloser
+ pos int64
+
+ Requests int
+}
+
+var _ io.ReadCloser = (*HttpReadSeeker)(nil)
+var _ io.Seeker = (*HttpReadSeeker)(nil)
+
+var (
+ // ErrNoContentLength is returned by Seek when the initial http response did not include a Content-Length header
+ ErrNoContentLength = errors.New("header Content-Length was not set")
+ // ErrRangeRequestsNotSupported is returned by Seek and Read
+ // when the remote server does not allow range requests (Accept-Ranges was not set)
+ ErrRangeRequestsNotSupported = errors.New("range requests are not supported by the remote server")
+ // ErrInvalidRange is returned by Read when trying to read past the end of the file
+ ErrInvalidRange = errors.New("invalid range")
+ // ErrContentHasChanged is returned by Read when the content has changed since the first request
+ ErrContentHasChanged = errors.New("content has changed since first request")
+)
+
+// NewHttpReadSeeker returns a HttpReadSeeker, using the http.Response and, optionaly, the http.Client
+// that needs to be used for future range requests. If no http.Client is given, http.DefaultClient will
+// be used.
+//
+// res.Request will be reused for range requests, headers may be added/removed
+func NewHttpReadSeeker(res *http.Response, client ...*http.Client) *HttpReadSeeker {
+ r := &HttpReadSeeker{
+ req: res.Request,
+ ctx: res.Request.Context(),
+ res: res,
+ r: res.Body,
+ }
+ if len(client) > 0 {
+ r.c = client[0]
+ } else {
+ r.c = http.DefaultClient
+ }
+ return r
+}
+
+// Clone clones the reader to enable parallel downloads of ranges
+func (r *HttpReadSeeker) Clone() (*HttpReadSeeker, error) {
+ req, err := copystructure.Copy(r.req)
+ if err != nil {
+ return nil, err
+ }
+ return &HttpReadSeeker{
+ req: req.(*http.Request),
+ res: r.res,
+ r: nil,
+ c: r.c,
+ }, nil
+}
+
+// Read reads from the response body. It does a range request if Seek was called before.
+//
+// May return ErrRangeRequestsNotSupported, ErrInvalidRange or ErrContentHasChanged
+func (r *HttpReadSeeker) Read(p []byte) (n int, err error) {
+ if r.r == nil {
+ err = r.rangeRequest()
+ }
+ if r.r != nil {
+ n, err = r.r.Read(p)
+ r.pos += int64(n)
+ }
+ return
+}
+
+// ReadAt reads from the response body starting at offset off.
+//
+// May return ErrRangeRequestsNotSupported, ErrInvalidRange or ErrContentHasChanged
+func (r *HttpReadSeeker) ReadAt(p []byte, off int64) (n int, err error) {
+ var nn int
+
+ r.Seek(off, 0)
+
+ for n < len(p) && err == nil {
+ nn, err = r.Read(p[n:])
+ n += nn
+ }
+ return
+}
+
+// Close closes the response body
+func (r *HttpReadSeeker) Close() error {
+ if r.r != nil {
+ return r.r.Close()
+ }
+ return nil
+}
+
+// Seek moves the reader position to a new offset.
+//
+// It does not send http requests, allowing for multiple seeks without overhead.
+// The http request will be sent by the next Read call.
+//
+// May return ErrNoContentLength or ErrRangeRequestsNotSupported
+func (r *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) {
+ var err error
+ switch whence {
+ case 0:
+ case 1:
+ offset += r.pos
+ case 2:
+ if r.res.ContentLength <= 0 {
+ return 0, ErrNoContentLength
+ }
+ offset = r.res.ContentLength - offset
+ }
+ if r.r != nil {
+ // Try to read, which is cheaper than doing a request
+ if r.pos < offset && offset-r.pos <= shortSeekBytes {
+ _, err := io.CopyN(ioutil.Discard, r, offset-r.pos)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ if r.pos != offset {
+ err = r.r.Close()
+ r.r = nil
+ }
+ }
+ r.pos = offset
+ return r.pos, err
+}
+
+func cloneHeader(h http.Header) http.Header {
+ h2 := make(http.Header, len(h))
+ for k, vv := range h {
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ h2[k] = vv2
+ }
+ return h2
+}
+
+func (r *HttpReadSeeker) newRequest() *http.Request {
+ newreq := r.req.WithContext(r.ctx) // includes shallow copies of maps, but okay
+ if r.req.ContentLength == 0 {
+ newreq.Body = nil // Issue 16036: nil Body for http.Transport retries
+ }
+ newreq.Header = cloneHeader(r.req.Header)
+ return newreq
+}
+
+func (r *HttpReadSeeker) rangeRequest() error {
+ r.req = r.newRequest()
+ r.req.Header.Set("Range", fmt.Sprintf("bytes=%d-", r.pos))
+ etag, last := r.res.Header.Get("ETag"), r.res.Header.Get("Last-Modified")
+ switch {
+ case last != "":
+ r.req.Header.Set("If-Range", last)
+ case etag != "":
+ r.req.Header.Set("If-Range", etag)
+ }
+
+ r.Requests++
+
+ res, err := r.c.Do(r.req)
+ if err != nil {
+ return err
+ }
+ switch res.StatusCode {
+ case http.StatusRequestedRangeNotSatisfiable:
+ return ErrInvalidRange
+ case http.StatusOK:
+ // some servers return 200 OK for bytes=0-
+ if r.pos > 0 ||
+ (etag != "" && etag != res.Header.Get("ETag")) {
+ return ErrContentHasChanged
+ }
+ fallthrough
+ case http.StatusPartialContent:
+ r.r = res.Body
+ return nil
+ }
+ return ErrRangeRequestsNotSupported
+}
diff --git a/workhorse/internal/httprs/httprs_test.go b/workhorse/internal/httprs/httprs_test.go
new file mode 100644
index 00000000000..62279d895c9
--- /dev/null
+++ b/workhorse/internal/httprs/httprs_test.go
@@ -0,0 +1,257 @@
+package httprs
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ . "github.com/smartystreets/goconvey/convey"
+)
+
+type fakeResponseWriter struct {
+ code int
+ h http.Header
+ tmp *os.File
+}
+
+func (f *fakeResponseWriter) Header() http.Header {
+ return f.h
+}
+
+func (f *fakeResponseWriter) Write(b []byte) (int, error) {
+ return f.tmp.Write(b)
+}
+
+func (f *fakeResponseWriter) Close(b []byte) error {
+ return f.tmp.Close()
+}
+
+func (f *fakeResponseWriter) WriteHeader(code int) {
+ f.code = code
+}
+
+func (f *fakeResponseWriter) Response() *http.Response {
+ f.tmp.Seek(0, io.SeekStart)
+ return &http.Response{Body: f.tmp, StatusCode: f.code, Header: f.h}
+}
+
+type fakeRoundTripper struct {
+ src *os.File
+ downgradeZeroToNoRange bool
+}
+
+func (f *fakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
+ fw := &fakeResponseWriter{h: http.Header{}}
+ var err error
+ fw.tmp, err = ioutil.TempFile(os.TempDir(), "httprs")
+ if err != nil {
+ return nil, err
+ }
+ if f.downgradeZeroToNoRange {
+ // There are implementations that downgrades bytes=0- to a normal un-ranged GET
+ if r.Header.Get("Range") == "bytes=0-" {
+ r.Header.Del("Range")
+ }
+ }
+ http.ServeContent(fw, r, "temp.txt", time.Now(), f.src)
+
+ return fw.Response(), nil
+}
+
+const SZ = 4096
+
+const (
+ downgradeZeroToNoRange = 1 << iota
+ sendAcceptRanges
+)
+
+type RSFactory func() *HttpReadSeeker
+
+func newRSFactory(flags int) RSFactory {
+ return func() *HttpReadSeeker {
+ tmp, err := ioutil.TempFile(os.TempDir(), "httprs")
+ if err != nil {
+ return nil
+ }
+ for i := 0; i < SZ; i++ {
+ tmp.WriteString(fmt.Sprintf("%04d", i))
+ }
+
+ req, err := http.NewRequest("GET", "http://www.example.com", nil)
+ if err != nil {
+ return nil
+ }
+ res := &http.Response{
+ Request: req,
+ ContentLength: SZ * 4,
+ }
+
+ if flags&sendAcceptRanges > 0 {
+ res.Header = http.Header{"Accept-Ranges": []string{"bytes"}}
+ }
+
+ downgradeZeroToNoRange := (flags & downgradeZeroToNoRange) > 0
+ return NewHttpReadSeeker(res, &http.Client{Transport: &fakeRoundTripper{src: tmp, downgradeZeroToNoRange: downgradeZeroToNoRange}})
+ }
+}
+
+func TestHttpWebServer(t *testing.T) {
+ Convey("Scenario: testing WebServer", t, func() {
+ dir, err := ioutil.TempDir("", "webserver")
+ So(err, ShouldBeNil)
+ defer os.RemoveAll(dir)
+
+ err = ioutil.WriteFile(filepath.Join(dir, "file"), make([]byte, 10000), 0755)
+ So(err, ShouldBeNil)
+
+ server := httptest.NewServer(http.FileServer(http.Dir(dir)))
+
+ Convey("When requesting /file", func() {
+ res, err := http.Get(server.URL + "/file")
+ So(err, ShouldBeNil)
+
+ stream := NewHttpReadSeeker(res)
+ So(stream, ShouldNotBeNil)
+
+ Convey("Can read 100 bytes from start of file", func() {
+ n, err := stream.Read(make([]byte, 100))
+ So(err, ShouldBeNil)
+ So(n, ShouldEqual, 100)
+
+ Convey("When seeking 4KiB forward", func() {
+ pos, err := stream.Seek(4096, io.SeekCurrent)
+ So(err, ShouldBeNil)
+ So(pos, ShouldEqual, 4096+100)
+
+ Convey("Can read 100 bytes", func() {
+ n, err := stream.Read(make([]byte, 100))
+ So(err, ShouldBeNil)
+ So(n, ShouldEqual, 100)
+ })
+ })
+ })
+ })
+ })
+}
+
+func TestHttpReaderSeeker(t *testing.T) {
+ tests := []struct {
+ name string
+ newRS func() *HttpReadSeeker
+ }{
+ {name: "with no flags", newRS: newRSFactory(0)},
+ {name: "with only Accept-Ranges", newRS: newRSFactory(sendAcceptRanges)},
+ {name: "downgrade 0-range to no range", newRS: newRSFactory(downgradeZeroToNoRange)},
+ {name: "downgrade 0-range with Accept-Ranges", newRS: newRSFactory(downgradeZeroToNoRange | sendAcceptRanges)},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ testHttpReaderSeeker(t, test.newRS)
+ })
+ }
+}
+
+func testHttpReaderSeeker(t *testing.T, newRS RSFactory) {
+ Convey("Scenario: testing HttpReaderSeeker", t, func() {
+
+ Convey("Read should start at the beginning", func() {
+ r := newRS()
+ So(r, ShouldNotBeNil)
+ defer r.Close()
+ buf := make([]byte, 4)
+ n, err := io.ReadFull(r, buf)
+ So(n, ShouldEqual, 4)
+ So(err, ShouldBeNil)
+ So(string(buf), ShouldEqual, "0000")
+ })
+
+ Convey("Seek w SEEK_SET should seek to right offset", func() {
+ r := newRS()
+ So(r, ShouldNotBeNil)
+ defer r.Close()
+ s, err := r.Seek(4*64, io.SeekStart)
+ So(s, ShouldEqual, 4*64)
+ So(err, ShouldBeNil)
+ buf := make([]byte, 4)
+ n, err := io.ReadFull(r, buf)
+ So(n, ShouldEqual, 4)
+ So(err, ShouldBeNil)
+ So(string(buf), ShouldEqual, "0064")
+ })
+
+ Convey("Read + Seek w SEEK_CUR should seek to right offset", func() {
+ r := newRS()
+ So(r, ShouldNotBeNil)
+ defer r.Close()
+ buf := make([]byte, 4)
+ io.ReadFull(r, buf)
+ s, err := r.Seek(4*64, os.SEEK_CUR)
+ So(s, ShouldEqual, 4*64+4)
+ So(err, ShouldBeNil)
+ n, err := io.ReadFull(r, buf)
+ So(n, ShouldEqual, 4)
+ So(err, ShouldBeNil)
+ So(string(buf), ShouldEqual, "0065")
+ })
+
+ Convey("Seek w SEEK_END should seek to right offset", func() {
+ r := newRS()
+ So(r, ShouldNotBeNil)
+ defer r.Close()
+ buf := make([]byte, 4)
+ io.ReadFull(r, buf)
+ s, err := r.Seek(4, os.SEEK_END)
+ So(s, ShouldEqual, SZ*4-4)
+ So(err, ShouldBeNil)
+ n, err := io.ReadFull(r, buf)
+ So(n, ShouldEqual, 4)
+ So(err, ShouldBeNil)
+ So(string(buf), ShouldEqual, fmt.Sprintf("%04d", SZ-1))
+ })
+
+ Convey("Short seek should consume existing request", func() {
+ r := newRS()
+ So(r, ShouldNotBeNil)
+ defer r.Close()
+ buf := make([]byte, 4)
+ So(r.Requests, ShouldEqual, 0)
+ io.ReadFull(r, buf)
+ So(r.Requests, ShouldEqual, 1)
+ s, err := r.Seek(shortSeekBytes, os.SEEK_CUR)
+ So(r.Requests, ShouldEqual, 1)
+ So(s, ShouldEqual, shortSeekBytes+4)
+ So(err, ShouldBeNil)
+ n, err := io.ReadFull(r, buf)
+ So(n, ShouldEqual, 4)
+ So(err, ShouldBeNil)
+ So(string(buf), ShouldEqual, "0257")
+ So(r.Requests, ShouldEqual, 1)
+ })
+
+ Convey("Long seek should do a new request", func() {
+ r := newRS()
+ So(r, ShouldNotBeNil)
+ defer r.Close()
+ buf := make([]byte, 4)
+ So(r.Requests, ShouldEqual, 0)
+ io.ReadFull(r, buf)
+ So(r.Requests, ShouldEqual, 1)
+ s, err := r.Seek(shortSeekBytes+1, os.SEEK_CUR)
+ So(r.Requests, ShouldEqual, 1)
+ So(s, ShouldEqual, shortSeekBytes+4+1)
+ So(err, ShouldBeNil)
+ n, err := io.ReadFull(r, buf)
+ So(n, ShouldEqual, 4)
+ So(err, ShouldBeNil)
+ So(string(buf), ShouldEqual, "2570")
+ So(r.Requests, ShouldEqual, 2)
+ })
+ })
+}
diff --git a/workhorse/internal/imageresizer/image_resizer.go b/workhorse/internal/imageresizer/image_resizer.go
new file mode 100644
index 00000000000..77318ed1c46
--- /dev/null
+++ b/workhorse/internal/imageresizer/image_resizer.go
@@ -0,0 +1,449 @@
+package imageresizer
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "os"
+ "os/exec"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+
+ "gitlab.com/gitlab-org/labkit/correlation"
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/mask"
+ "gitlab.com/gitlab-org/labkit/tracing"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
+)
+
+type Resizer struct {
+ config.Config
+ senddata.Prefix
+ numScalerProcs processCounter
+}
+
+type resizeParams struct {
+ Location string
+ ContentType string
+ Width uint
+}
+
+type processCounter struct {
+ n int32
+}
+
+type resizeStatus = string
+
+type imageFile struct {
+ reader io.ReadCloser
+ contentLength int64
+ lastModified time.Time
+}
+
+// Carries information about how the scaler succeeded or failed.
+type resizeOutcome struct {
+ bytesWritten int64
+ originalFileSize int64
+ status resizeStatus
+ err error
+}
+
+const (
+ statusSuccess = "success" // a rescaled image was served
+ statusClientCache = "success-client-cache" // scaling was skipped because client cache was fresh
+ statusServedOriginal = "served-original" // scaling failed but the original image was served
+ statusRequestFailure = "request-failed" // no image was served
+ statusUnknown = "unknown" // indicates an unhandled status case
+)
+
+var envInjector = tracing.NewEnvInjector()
+
+// Images might be located remotely in object storage, in which case we need to stream
+// it via http(s)
+var httpTransport = tracing.NewRoundTripper(correlation.NewInstrumentedRoundTripper(&http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ DialContext: (&net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 10 * time.Second,
+ }).DialContext,
+ MaxIdleConns: 2,
+ IdleConnTimeout: 30 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ExpectContinueTimeout: 10 * time.Second,
+ ResponseHeaderTimeout: 30 * time.Second,
+}))
+
+var httpClient = &http.Client{
+ Transport: httpTransport,
+}
+
+const (
+ namespace = "gitlab_workhorse"
+ subsystem = "image_resize"
+ logSystem = "imageresizer"
+)
+
+var (
+ imageResizeConcurrencyLimitExceeds = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Namespace: namespace,
+ Subsystem: subsystem,
+ Name: "concurrency_limit_exceeds_total",
+ Help: "Amount of image resizing requests that exceeded the maximum allowed scaler processes",
+ },
+ )
+ imageResizeProcesses = promauto.NewGauge(
+ prometheus.GaugeOpts{
+ Namespace: namespace,
+ Subsystem: subsystem,
+ Name: "processes",
+ Help: "Amount of image scaler processes working now",
+ },
+ )
+ imageResizeMaxProcesses = promauto.NewGauge(
+ prometheus.GaugeOpts{
+ Namespace: namespace,
+ Subsystem: subsystem,
+ Name: "max_processes",
+ Help: "The maximum amount of image scaler processes allowed to run concurrently",
+ },
+ )
+ imageResizeRequests = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Namespace: namespace,
+ Subsystem: subsystem,
+ Name: "requests_total",
+ Help: "Image resizing operations requested",
+ },
+ []string{"status"},
+ )
+ imageResizeDurations = promauto.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Namespace: namespace,
+ Subsystem: subsystem,
+ Name: "duration_seconds",
+ Help: "Breakdown of total time spent serving successful image resizing requests (incl. data transfer)",
+ Buckets: []float64{
+ 0.025, /* 25ms */
+ 0.050, /* 50ms */
+ 0.1, /* 100ms */
+ 0.2, /* 200ms */
+ 0.4, /* 400ms */
+ 0.8, /* 800ms */
+ },
+ },
+ []string{"content_type", "width"},
+ )
+)
+
+const (
+ jpegMagic = "\xff\xd8" // 2 bytes
+ pngMagic = "\x89PNG\r\n\x1a\n" // 8 bytes
+ maxMagicLen = 8 // 8 first bytes is enough to detect PNG or JPEG
+)
+
+func NewResizer(cfg config.Config) *Resizer {
+ imageResizeMaxProcesses.Set(float64(cfg.ImageResizerConfig.MaxScalerProcs))
+
+ return &Resizer{Config: cfg, Prefix: "send-scaled-img:"}
+}
+
+// Inject forks into a dedicated scaler process to resize an image identified by path or URL
+// and streams the resized image back to the client
+func (r *Resizer) Inject(w http.ResponseWriter, req *http.Request, paramsData string) {
+ var outcome = resizeOutcome{status: statusUnknown, originalFileSize: 0, bytesWritten: 0}
+ start := time.Now()
+ params, err := r.unpackParameters(paramsData)
+
+ defer func() {
+ imageResizeRequests.WithLabelValues(outcome.status).Inc()
+ handleOutcome(w, req, start, params, &outcome)
+ }()
+
+ if err != nil {
+ // This means the response header coming from Rails was malformed; there is no way
+ // to sensibly recover from this other than failing fast
+ outcome.error(fmt.Errorf("read image resize params: %v", err))
+ return
+ }
+
+ imageFile, err := openSourceImage(params.Location)
+ if err != nil {
+ // This means we cannot even read the input image; fail fast.
+ outcome.error(fmt.Errorf("open image data stream: %v", err))
+ return
+ }
+ defer imageFile.reader.Close()
+
+ outcome.originalFileSize = imageFile.contentLength
+
+ setLastModified(w, imageFile.lastModified)
+ // If the original file has not changed, then any cached resized versions have not changed either.
+ if checkNotModified(req, imageFile.lastModified) {
+ writeNotModified(w)
+ outcome.ok(statusClientCache)
+ return
+ }
+
+ // We first attempt to rescale the image; if this should fail for any reason, imageReader
+ // will point to the original image, i.e. we render it unchanged.
+ imageReader, resizeCmd, err := r.tryResizeImage(req, imageFile, params, r.Config.ImageResizerConfig)
+ if err != nil {
+ // Something failed, but we can still write out the original image, so don't return early.
+ // We need to log this separately since the subsequent steps might add other failures.
+ helper.LogErrorWithFields(req, err, *logFields(start, params, &outcome))
+ }
+ defer helper.CleanUpProcessGroup(resizeCmd)
+
+ w.Header().Del("Content-Length")
+ outcome.bytesWritten, err = serveImage(imageReader, w, resizeCmd)
+
+ // We failed serving image data; this is a hard failure.
+ if err != nil {
+ outcome.error(err)
+ return
+ }
+
+ // This means we served the original image because rescaling failed; this is a soft failure
+ if resizeCmd == nil {
+ outcome.ok(statusServedOriginal)
+ return
+ }
+
+ widthLabelVal := strconv.Itoa(int(params.Width))
+ imageResizeDurations.WithLabelValues(params.ContentType, widthLabelVal).Observe(time.Since(start).Seconds())
+
+ outcome.ok(statusSuccess)
+}
+
+// Streams image data from the given reader to the given writer and returns the number of bytes written.
+func serveImage(r io.Reader, w io.Writer, resizeCmd *exec.Cmd) (int64, error) {
+ bytesWritten, err := io.Copy(w, r)
+ if err != nil {
+ return bytesWritten, err
+ }
+
+ if resizeCmd != nil {
+ // If a scaler process had been forked, wait for the command to finish.
+ if err = resizeCmd.Wait(); err != nil {
+ // err will be an ExitError; this is not useful beyond knowing the exit code since anything
+ // interesting has been written to stderr, so we turn that into an error we can return.
+ stdErr := resizeCmd.Stderr.(*strings.Builder)
+ return bytesWritten, fmt.Errorf(stdErr.String())
+ }
+ }
+
+ return bytesWritten, nil
+}
+
+func (r *Resizer) unpackParameters(paramsData string) (*resizeParams, error) {
+ var params resizeParams
+ if err := r.Unpack(&params, paramsData); err != nil {
+ return nil, err
+ }
+
+ if params.Location == "" {
+ return nil, fmt.Errorf("'Location' not set")
+ }
+
+ if params.ContentType == "" {
+ return nil, fmt.Errorf("'ContentType' must be set")
+ }
+
+ return &params, nil
+}
+
+// Attempts to rescale the given image data, or in case of errors, falls back to the original image.
+func (r *Resizer) tryResizeImage(req *http.Request, f *imageFile, params *resizeParams, cfg config.ImageResizerConfig) (io.Reader, *exec.Cmd, error) {
+ if f.contentLength > int64(cfg.MaxFilesize) {
+ return f.reader, nil, fmt.Errorf("%d bytes exceeds maximum file size of %d bytes", f.contentLength, cfg.MaxFilesize)
+ }
+
+ if f.contentLength < maxMagicLen {
+ return f.reader, nil, fmt.Errorf("file is too small to resize: %d bytes", f.contentLength)
+ }
+
+ if !r.numScalerProcs.tryIncrement(int32(cfg.MaxScalerProcs)) {
+ return f.reader, nil, fmt.Errorf("too many running scaler processes (%d / %d)", r.numScalerProcs.n, cfg.MaxScalerProcs)
+ }
+
+ ctx := req.Context()
+ go func() {
+ <-ctx.Done()
+ r.numScalerProcs.decrement()
+ }()
+
+ // Creating buffered Reader is required for us to Peek into first bytes of the image file to detect the format
+ // without advancing the reader (we need to read from the file start in the Scaler binary).
+ // We set `8` as the minimal buffer size by the length of PNG magic bytes sequence (JPEG needs only 2).
+ // In fact, `NewReaderSize` will immediately override it with `16` using its `minReadBufferSize` -
+ // here we are just being explicit about the buffer size required for our code to operate correctly.
+ // Having a reader with such tiny buffer will not hurt the performance during further operations,
+ // because Golang `bufio.Read` avoids double copy: https://golang.org/src/bufio/bufio.go?s=1768:1804#L212
+ buffered := bufio.NewReaderSize(f.reader, maxMagicLen)
+
+ headerBytes, err := buffered.Peek(maxMagicLen)
+ if err != nil {
+ return buffered, nil, fmt.Errorf("peek stream: %v", err)
+ }
+
+ // Check magic bytes to identify file type.
+ if string(headerBytes) != pngMagic && string(headerBytes[0:2]) != jpegMagic {
+ return buffered, nil, fmt.Errorf("unrecognized file signature: %v", headerBytes)
+ }
+
+ resizeCmd, resizedImageReader, err := startResizeImageCommand(ctx, buffered, params)
+ if err != nil {
+ return buffered, nil, fmt.Errorf("fork into scaler process: %w", err)
+ }
+ return resizedImageReader, resizeCmd, nil
+}
+
+func startResizeImageCommand(ctx context.Context, imageReader io.Reader, params *resizeParams) (*exec.Cmd, io.ReadCloser, error) {
+ cmd := exec.CommandContext(ctx, "gitlab-resize-image")
+ cmd.Stdin = imageReader
+ cmd.Stderr = &strings.Builder{}
+ cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+ cmd.Env = []string{
+ "GL_RESIZE_IMAGE_WIDTH=" + strconv.Itoa(int(params.Width)),
+ }
+ cmd.Env = envInjector(ctx, cmd.Env)
+
+ stdout, err := cmd.StdoutPipe()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ if err := cmd.Start(); err != nil {
+ return nil, nil, err
+ }
+
+ return cmd, stdout, nil
+}
+
+func isURL(location string) bool {
+ return strings.HasPrefix(location, "http://") || strings.HasPrefix(location, "https://")
+}
+
+func openSourceImage(location string) (*imageFile, error) {
+ if isURL(location) {
+ return openFromURL(location)
+ }
+
+ return openFromFile(location)
+}
+
+func openFromURL(location string) (*imageFile, error) {
+ res, err := httpClient.Get(location)
+ if err != nil {
+ return nil, err
+ }
+
+ switch res.StatusCode {
+ case http.StatusOK, http.StatusNotModified:
+ // Extract headers for conditional GETs from response.
+ lastModified, err := http.ParseTime(res.Header.Get("Last-Modified"))
+ if err != nil {
+ // This is unlikely to happen, coming from an object storage provider.
+ lastModified = time.Now().UTC()
+ }
+ return &imageFile{res.Body, res.ContentLength, lastModified}, nil
+ default:
+ res.Body.Close()
+ return nil, fmt.Errorf("stream data from %q: %d %s", location, res.StatusCode, res.Status)
+ }
+}
+
+func openFromFile(location string) (*imageFile, error) {
+ file, err := os.Open(location)
+ if err != nil {
+ return nil, err
+ }
+
+ fi, err := file.Stat()
+ if err != nil {
+ file.Close()
+ return nil, err
+ }
+
+ return &imageFile{file, fi.Size(), fi.ModTime()}, nil
+}
+
+// Only allow more scaling requests if we haven't yet reached the maximum
+// allowed number of concurrent scaler processes
+func (c *processCounter) tryIncrement(maxScalerProcs int32) bool {
+ if p := atomic.AddInt32(&c.n, 1); p > maxScalerProcs {
+ c.decrement()
+ imageResizeConcurrencyLimitExceeds.Inc()
+
+ return false
+ }
+
+ imageResizeProcesses.Set(float64(c.n))
+ return true
+}
+
+func (c *processCounter) decrement() {
+ atomic.AddInt32(&c.n, -1)
+ imageResizeProcesses.Set(float64(c.n))
+}
+
+func (o *resizeOutcome) ok(status resizeStatus) {
+ o.status = status
+ o.err = nil
+}
+
+func (o *resizeOutcome) error(err error) {
+ o.status = statusRequestFailure
+ o.err = err
+}
+
+func logFields(startTime time.Time, params *resizeParams, outcome *resizeOutcome) *log.Fields {
+ var targetWidth, contentType string
+ if params != nil {
+ targetWidth = fmt.Sprint(params.Width)
+ contentType = fmt.Sprint(params.ContentType)
+ }
+ return &log.Fields{
+ "subsystem": logSystem,
+ "written_bytes": outcome.bytesWritten,
+ "duration_s": time.Since(startTime).Seconds(),
+ logSystem + ".status": outcome.status,
+ logSystem + ".target_width": targetWidth,
+ logSystem + ".content_type": contentType,
+ logSystem + ".original_filesize": outcome.originalFileSize,
+ }
+}
+
+func handleOutcome(w http.ResponseWriter, req *http.Request, startTime time.Time, params *resizeParams, outcome *resizeOutcome) {
+ logger := log.ContextLogger(req.Context())
+ fields := *logFields(startTime, params, outcome)
+
+ switch outcome.status {
+ case statusRequestFailure:
+ if outcome.bytesWritten <= 0 {
+ helper.Fail500WithFields(w, req, outcome.err, fields)
+ } else {
+ helper.LogErrorWithFields(req, outcome.err, fields)
+ }
+ default:
+ logger.WithFields(fields).WithFields(
+ log.Fields{
+ "method": req.Method,
+ "uri": mask.URL(req.RequestURI),
+ },
+ ).Printf(outcome.status)
+ }
+}
diff --git a/workhorse/internal/imageresizer/image_resizer_caching.go b/workhorse/internal/imageresizer/image_resizer_caching.go
new file mode 100644
index 00000000000..bbe0e3260d7
--- /dev/null
+++ b/workhorse/internal/imageresizer/image_resizer_caching.go
@@ -0,0 +1,44 @@
+// This file contains code derived from https://github.com/golang/go/blob/master/src/net/http/fs.go
+//
+// Copyright 2020 GitLab Inc. All rights reserved.
+// Copyright 2009 The Go Authors. All rights reserved.
+
+package imageresizer
+
+import (
+ "net/http"
+ "time"
+)
+
+func checkNotModified(r *http.Request, modtime time.Time) bool {
+ ims := r.Header.Get("If-Modified-Since")
+ if ims == "" || isZeroTime(modtime) {
+ // Treat bogus times as if there was no such header at all
+ return false
+ }
+ t, err := http.ParseTime(ims)
+ if err != nil {
+ return false
+ }
+ // The Last-Modified header truncates sub-second precision so
+ // the modtime needs to be truncated too.
+ return !modtime.Truncate(time.Second).After(t)
+}
+
+// isZeroTime reports whether t is obviously unspecified (either zero or Unix epoch time).
+func isZeroTime(t time.Time) bool {
+ return t.IsZero() || t.Equal(time.Unix(0, 0))
+}
+
+func setLastModified(w http.ResponseWriter, modtime time.Time) {
+ if !isZeroTime(modtime) {
+ w.Header().Set("Last-Modified", modtime.UTC().Format(http.TimeFormat))
+ }
+}
+
+func writeNotModified(w http.ResponseWriter) {
+ h := w.Header()
+ h.Del("Content-Type")
+ h.Del("Content-Length")
+ w.WriteHeader(http.StatusNotModified)
+}
diff --git a/workhorse/internal/imageresizer/image_resizer_test.go b/workhorse/internal/imageresizer/image_resizer_test.go
new file mode 100644
index 00000000000..bacc97738b8
--- /dev/null
+++ b/workhorse/internal/imageresizer/image_resizer_test.go
@@ -0,0 +1,259 @@
+package imageresizer
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "image"
+ "image/png"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+
+ _ "image/jpeg" // need this for image.Decode with JPEG
+)
+
+const imagePath = "../../testdata/image.png"
+
+func TestMain(m *testing.M) {
+ if err := testhelper.BuildExecutables(); err != nil {
+ log.WithError(err).Fatal()
+ }
+
+ os.Exit(m.Run())
+}
+
+func requestScaledImage(t *testing.T, httpHeaders http.Header, params resizeParams, cfg config.ImageResizerConfig) *http.Response {
+ httpRequest := httptest.NewRequest("GET", "/image", nil)
+ if httpHeaders != nil {
+ httpRequest.Header = httpHeaders
+ }
+ responseWriter := httptest.NewRecorder()
+ paramsJSON := encodeParams(t, &params)
+
+ NewResizer(config.Config{ImageResizerConfig: cfg}).Inject(responseWriter, httpRequest, paramsJSON)
+
+ return responseWriter.Result()
+}
+
+func TestRequestScaledImageFromPath(t *testing.T) {
+ cfg := config.DefaultImageResizerConfig
+
+ testCases := []struct {
+ desc string
+ imagePath string
+ contentType string
+ }{
+ {
+ desc: "PNG",
+ imagePath: imagePath,
+ contentType: "image/png",
+ },
+ {
+ desc: "JPEG",
+ imagePath: "../../testdata/image.jpg",
+ contentType: "image/jpeg",
+ },
+ {
+ desc: "JPEG < 1kb",
+ imagePath: "../../testdata/image_single_pixel.jpg",
+ contentType: "image/jpeg",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ params := resizeParams{Location: tc.imagePath, ContentType: tc.contentType, Width: 64}
+
+ resp := requestScaledImage(t, nil, params, cfg)
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ bounds := imageFromResponse(t, resp).Bounds()
+ require.Equal(t, int(params.Width), bounds.Size().X, "wrong width after resizing")
+ })
+ }
+}
+
+func TestRequestScaledImageWithConditionalGetAndImageNotChanged(t *testing.T) {
+ cfg := config.DefaultImageResizerConfig
+ params := resizeParams{Location: imagePath, ContentType: "image/png", Width: 64}
+
+ clientTime := testImageLastModified(t)
+ header := http.Header{}
+ header.Set("If-Modified-Since", httpTimeStr(clientTime))
+
+ resp := requestScaledImage(t, header, params, cfg)
+ require.Equal(t, http.StatusNotModified, resp.StatusCode)
+ require.Equal(t, httpTimeStr(testImageLastModified(t)), resp.Header.Get("Last-Modified"))
+ require.Empty(t, resp.Header.Get("Content-Type"))
+ require.Empty(t, resp.Header.Get("Content-Length"))
+}
+
+func TestRequestScaledImageWithConditionalGetAndImageChanged(t *testing.T) {
+ cfg := config.DefaultImageResizerConfig
+ params := resizeParams{Location: imagePath, ContentType: "image/png", Width: 64}
+
+ clientTime := testImageLastModified(t).Add(-1 * time.Second)
+ header := http.Header{}
+ header.Set("If-Modified-Since", httpTimeStr(clientTime))
+
+ resp := requestScaledImage(t, header, params, cfg)
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+ require.Equal(t, httpTimeStr(testImageLastModified(t)), resp.Header.Get("Last-Modified"))
+}
+
+func TestRequestScaledImageWithConditionalGetAndInvalidClientTime(t *testing.T) {
+ cfg := config.DefaultImageResizerConfig
+ params := resizeParams{Location: imagePath, ContentType: "image/png", Width: 64}
+
+ header := http.Header{}
+ header.Set("If-Modified-Since", "0")
+
+ resp := requestScaledImage(t, header, params, cfg)
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+ require.Equal(t, httpTimeStr(testImageLastModified(t)), resp.Header.Get("Last-Modified"))
+}
+
+func TestServeOriginalImageWhenSourceImageTooLarge(t *testing.T) {
+ originalImage := testImage(t)
+ cfg := config.ImageResizerConfig{MaxScalerProcs: 1, MaxFilesize: 1}
+ params := resizeParams{Location: imagePath, ContentType: "image/png", Width: 64}
+
+ resp := requestScaledImage(t, nil, params, cfg)
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ img := imageFromResponse(t, resp)
+ require.Equal(t, originalImage.Bounds(), img.Bounds(), "expected original image size")
+}
+
+func TestFailFastOnOpenStreamFailure(t *testing.T) {
+ cfg := config.DefaultImageResizerConfig
+ params := resizeParams{Location: "does_not_exist.png", ContentType: "image/png", Width: 64}
+ resp := requestScaledImage(t, nil, params, cfg)
+
+ require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
+}
+
+func TestIgnoreContentTypeMismatchIfImageFormatIsAllowed(t *testing.T) {
+ cfg := config.DefaultImageResizerConfig
+ params := resizeParams{Location: imagePath, ContentType: "image/jpeg", Width: 64}
+ resp := requestScaledImage(t, nil, params, cfg)
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ bounds := imageFromResponse(t, resp).Bounds()
+ require.Equal(t, int(params.Width), bounds.Size().X, "wrong width after resizing")
+}
+
+func TestUnpackParametersReturnsParamsInstanceForValidInput(t *testing.T) {
+ r := Resizer{}
+ inParams := resizeParams{Location: imagePath, Width: 64, ContentType: "image/png"}
+
+ outParams, err := r.unpackParameters(encodeParams(t, &inParams))
+
+ require.NoError(t, err, "unexpected error when unpacking params")
+ require.Equal(t, inParams, *outParams)
+}
+
+func TestUnpackParametersReturnsErrorWhenLocationBlank(t *testing.T) {
+ r := Resizer{}
+ inParams := resizeParams{Location: "", Width: 64, ContentType: "image/jpg"}
+
+ _, err := r.unpackParameters(encodeParams(t, &inParams))
+
+ require.Error(t, err, "expected error when Location is blank")
+}
+
+func TestUnpackParametersReturnsErrorWhenContentTypeBlank(t *testing.T) {
+ r := Resizer{}
+ inParams := resizeParams{Location: imagePath, Width: 64, ContentType: ""}
+
+ _, err := r.unpackParameters(encodeParams(t, &inParams))
+
+ require.Error(t, err, "expected error when ContentType is blank")
+}
+
+func TestServeOriginalImageWhenSourceImageFormatIsNotAllowed(t *testing.T) {
+ cfg := config.DefaultImageResizerConfig
+ // SVG images are not allowed to be resized
+ svgImagePath := "../../testdata/image.svg"
+ svgImage, err := ioutil.ReadFile(svgImagePath)
+ require.NoError(t, err)
+ // ContentType is no longer used to perform the format validation.
+ // To make the test more strict, we'll use allowed, but incorrect ContentType.
+ params := resizeParams{Location: svgImagePath, ContentType: "image/png", Width: 64}
+
+ resp := requestScaledImage(t, nil, params, cfg)
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ responseData, err := ioutil.ReadAll(resp.Body)
+ require.NoError(t, err)
+ require.Equal(t, svgImage, responseData, "expected original image")
+}
+
+func TestServeOriginalImageWhenSourceImageIsTooSmall(t *testing.T) {
+ content := []byte("PNG") // 3 bytes only, invalid as PNG/JPEG image
+
+ img, err := ioutil.TempFile("", "*.png")
+ require.NoError(t, err)
+
+ defer img.Close()
+ defer os.Remove(img.Name())
+
+ _, err = img.Write(content)
+ require.NoError(t, err)
+
+ cfg := config.DefaultImageResizerConfig
+ params := resizeParams{Location: img.Name(), ContentType: "image/png", Width: 64}
+
+ resp := requestScaledImage(t, nil, params, cfg)
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ responseData, err := ioutil.ReadAll(resp.Body)
+ require.NoError(t, err)
+ require.Equal(t, content, responseData, "expected original image")
+}
+
+// The Rails applications sends a Base64 encoded JSON string carrying
+// these parameters in an HTTP response header
+func encodeParams(t *testing.T, p *resizeParams) string {
+ json, err := json.Marshal(*p)
+ if err != nil {
+ require.NoError(t, err, "JSON encoder encountered unexpected error")
+ }
+ return base64.StdEncoding.EncodeToString(json)
+}
+
+func testImage(t *testing.T) image.Image {
+ f, err := os.Open(imagePath)
+ require.NoError(t, err)
+
+ image, err := png.Decode(f)
+ require.NoError(t, err, "decode original image")
+
+ return image
+}
+
+func testImageLastModified(t *testing.T) time.Time {
+ fi, err := os.Stat(imagePath)
+ require.NoError(t, err)
+
+ return fi.ModTime()
+}
+
+func imageFromResponse(t *testing.T, resp *http.Response) image.Image {
+ img, _, err := image.Decode(resp.Body)
+ require.NoError(t, err, "decode resized image")
+ return img
+}
+
+func httpTimeStr(time time.Time) string {
+ return time.UTC().Format(http.TimeFormat)
+}
diff --git a/workhorse/internal/lfs/lfs.go b/workhorse/internal/lfs/lfs.go
new file mode 100644
index 00000000000..ec48dc05ef9
--- /dev/null
+++ b/workhorse/internal/lfs/lfs.go
@@ -0,0 +1,55 @@
+/*
+In this file we handle git lfs objects downloads and uploads
+*/
+
+package lfs
+
+import (
+ "fmt"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
+)
+
+type object struct {
+ size int64
+ oid string
+}
+
+func (l *object) Verify(fh *filestore.FileHandler) error {
+ if fh.Size != l.size {
+ return fmt.Errorf("LFSObject: expected size %d, wrote %d", l.size, fh.Size)
+ }
+
+ if fh.SHA256() != l.oid {
+ return fmt.Errorf("LFSObject: expected sha256 %s, got %s", l.oid, fh.SHA256())
+ }
+
+ return nil
+}
+
+type uploadPreparer struct {
+ objectPreparer upload.Preparer
+}
+
+func NewLfsUploadPreparer(c config.Config, objectPreparer upload.Preparer) upload.Preparer {
+ return &uploadPreparer{objectPreparer: objectPreparer}
+}
+
+func (l *uploadPreparer) Prepare(a *api.Response) (*filestore.SaveFileOpts, upload.Verifier, error) {
+ opts, _, err := l.objectPreparer.Prepare(a)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ opts.TempFilePrefix = a.LfsOid
+
+ return opts, &object{oid: a.LfsOid, size: a.LfsSize}, nil
+}
+
+func PutStore(a *api.API, h http.Handler, p upload.Preparer) http.Handler {
+ return upload.BodyUploader(a, h, p)
+}
diff --git a/workhorse/internal/lfs/lfs_test.go b/workhorse/internal/lfs/lfs_test.go
new file mode 100644
index 00000000000..828ed1bfe90
--- /dev/null
+++ b/workhorse/internal/lfs/lfs_test.go
@@ -0,0 +1,61 @@
+package lfs_test
+
+import (
+ "testing"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/lfs"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestLfsUploadPreparerWithConfig(t *testing.T) {
+ lfsOid := "abcd1234"
+ creds := config.S3Credentials{
+ AwsAccessKeyID: "test-key",
+ AwsSecretAccessKey: "test-secret",
+ }
+
+ c := config.Config{
+ ObjectStorageCredentials: config.ObjectStorageCredentials{
+ Provider: "AWS",
+ S3Credentials: creds,
+ },
+ }
+
+ r := &api.Response{
+ LfsOid: lfsOid,
+ RemoteObject: api.RemoteObject{
+ ID: "the upload ID",
+ UseWorkhorseClient: true,
+ ObjectStorage: &api.ObjectStorageParams{
+ Provider: "AWS",
+ },
+ },
+ }
+
+ uploadPreparer := upload.NewObjectStoragePreparer(c)
+ lfsPreparer := lfs.NewLfsUploadPreparer(c, uploadPreparer)
+ opts, verifier, err := lfsPreparer.Prepare(r)
+
+ require.NoError(t, err)
+ require.Equal(t, lfsOid, opts.TempFilePrefix)
+ require.True(t, opts.ObjectStorageConfig.IsAWS())
+ require.True(t, opts.UseWorkhorseClient)
+ require.Equal(t, creds, opts.ObjectStorageConfig.S3Credentials)
+ require.NotNil(t, verifier)
+}
+
+func TestLfsUploadPreparerWithNoConfig(t *testing.T) {
+ c := config.Config{}
+ r := &api.Response{RemoteObject: api.RemoteObject{ID: "the upload ID"}}
+ uploadPreparer := upload.NewObjectStoragePreparer(c)
+ lfsPreparer := lfs.NewLfsUploadPreparer(c, uploadPreparer)
+ opts, verifier, err := lfsPreparer.Prepare(r)
+
+ require.NoError(t, err)
+ require.False(t, opts.UseWorkhorseClient)
+ require.NotNil(t, verifier)
+}
diff --git a/workhorse/internal/lsif_transformer/parser/cache.go b/workhorse/internal/lsif_transformer/parser/cache.go
new file mode 100644
index 00000000000..395069cd217
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/cache.go
@@ -0,0 +1,56 @@
+package parser
+
+import (
+ "encoding/binary"
+ "io"
+ "io/ioutil"
+ "os"
+)
+
+// This cache implementation is using a temp file to provide key-value data storage
+// It allows to avoid storing intermediate calculations in RAM
+// The stored data must be a fixed-size value or a slice of fixed-size values, or a pointer to such data
+type cache struct {
+ file *os.File
+ chunkSize int64
+}
+
+func newCache(tempDir, filename string, data interface{}) (*cache, error) {
+ f, err := ioutil.TempFile(tempDir, filename)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := os.Remove(f.Name()); err != nil {
+ return nil, err
+ }
+
+ return &cache{file: f, chunkSize: int64(binary.Size(data))}, nil
+}
+
+func (c *cache) SetEntry(id Id, data interface{}) error {
+ if err := c.setOffset(id); err != nil {
+ return err
+ }
+
+ return binary.Write(c.file, binary.LittleEndian, data)
+}
+
+func (c *cache) Entry(id Id, data interface{}) error {
+ if err := c.setOffset(id); err != nil {
+ return err
+ }
+
+ return binary.Read(c.file, binary.LittleEndian, data)
+}
+
+func (c *cache) Close() error {
+ return c.file.Close()
+}
+
+func (c *cache) setOffset(id Id) error {
+ offset := int64(id) * c.chunkSize
+ _, err := c.file.Seek(offset, io.SeekStart)
+
+ return err
+}
diff --git a/workhorse/internal/lsif_transformer/parser/cache_test.go b/workhorse/internal/lsif_transformer/parser/cache_test.go
new file mode 100644
index 00000000000..23a2ac6e9a9
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/cache_test.go
@@ -0,0 +1,33 @@
+package parser
+
+import (
+ "io/ioutil"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+type chunk struct {
+ A int16
+ B int16
+}
+
+func TestCache(t *testing.T) {
+ cache, err := newCache("", "test-chunks", chunk{})
+ require.NoError(t, err)
+ defer cache.Close()
+
+ c := chunk{A: 1, B: 2}
+ require.NoError(t, cache.SetEntry(1, &c))
+ require.NoError(t, cache.setOffset(0))
+
+ content, err := ioutil.ReadAll(cache.file)
+ require.NoError(t, err)
+
+ expected := []byte{0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x2, 0x0}
+ require.Equal(t, expected, content)
+
+ var nc chunk
+ require.NoError(t, cache.Entry(1, &nc))
+ require.Equal(t, c, nc)
+}
diff --git a/workhorse/internal/lsif_transformer/parser/code_hover.go b/workhorse/internal/lsif_transformer/parser/code_hover.go
new file mode 100644
index 00000000000..dbdaba643d1
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/code_hover.go
@@ -0,0 +1,124 @@
+package parser
+
+import (
+ "encoding/json"
+ "strings"
+ "unicode/utf8"
+
+ "github.com/alecthomas/chroma"
+ "github.com/alecthomas/chroma/lexers"
+)
+
+const maxValueSize = 250
+
+type token struct {
+ Class string `json:"class,omitempty"`
+ Value string `json:"value"`
+}
+
+type codeHover struct {
+ TruncatedValue *truncatableString `json:"value,omitempty"`
+ Tokens [][]token `json:"tokens,omitempty"`
+ Language string `json:"language,omitempty"`
+ Truncated bool `json:"truncated,omitempty"`
+}
+
+type truncatableString struct {
+ Value string
+ Truncated bool
+}
+
+func (ts *truncatableString) UnmarshalText(b []byte) error {
+ s := 0
+ for i := 0; s < len(b); i++ {
+ if i >= maxValueSize {
+ ts.Truncated = true
+ break
+ }
+
+ _, size := utf8.DecodeRune(b[s:])
+
+ s += size
+ }
+
+ ts.Value = string(b[0:s])
+
+ return nil
+}
+
+func (ts *truncatableString) MarshalJSON() ([]byte, error) {
+ return json.Marshal(ts.Value)
+}
+
+func newCodeHover(content json.RawMessage) (*codeHover, error) {
+ // Hover value can be either an object: { "value": "func main()", "language": "go" }
+ // Or a string with documentation
+ // We try to unmarshal the content into a string and if we fail, we unmarshal it into an object
+ var c codeHover
+ if err := json.Unmarshal(content, &c.TruncatedValue); err != nil {
+ if err := json.Unmarshal(content, &c); err != nil {
+ return nil, err
+ }
+
+ c.setTokens()
+ }
+
+ c.Truncated = c.TruncatedValue.Truncated
+
+ if len(c.Tokens) > 0 {
+ c.TruncatedValue = nil // remove value for hovers which have tokens
+ }
+
+ return &c, nil
+}
+
+func (c *codeHover) setTokens() {
+ lexer := lexers.Get(c.Language)
+ if lexer == nil {
+ return
+ }
+
+ iterator, err := lexer.Tokenise(nil, c.TruncatedValue.Value)
+ if err != nil {
+ return
+ }
+
+ var tokenLines [][]token
+ for _, tokenLine := range chroma.SplitTokensIntoLines(iterator.Tokens()) {
+ var tokens []token
+ var rawToken string
+ for _, t := range tokenLine {
+ class := c.classFor(t.Type)
+
+ // accumulate consequent raw values in a single string to store them as
+ // [{ Class: "kd", Value: "func" }, { Value: " main() {" }] instead of
+ // [{ Class: "kd", Value: "func" }, { Value: " " }, { Value: "main" }, { Value: "(" }...]
+ if class == "" {
+ rawToken = rawToken + t.Value
+ } else {
+ if rawToken != "" {
+ tokens = append(tokens, token{Value: rawToken})
+ rawToken = ""
+ }
+
+ tokens = append(tokens, token{Class: class, Value: t.Value})
+ }
+ }
+
+ if rawToken != "" {
+ tokens = append(tokens, token{Value: rawToken})
+ }
+
+ tokenLines = append(tokenLines, tokens)
+ }
+
+ c.Tokens = tokenLines
+}
+
+func (c *codeHover) classFor(tokenType chroma.TokenType) string {
+ if strings.HasPrefix(tokenType.String(), "Keyword") || tokenType == chroma.String || tokenType == chroma.Comment {
+ return chroma.StandardTypes[tokenType]
+ }
+
+ return ""
+}
diff --git a/workhorse/internal/lsif_transformer/parser/code_hover_test.go b/workhorse/internal/lsif_transformer/parser/code_hover_test.go
new file mode 100644
index 00000000000..2030e530155
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/code_hover_test.go
@@ -0,0 +1,106 @@
+package parser
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestHighlight(t *testing.T) {
+ tests := []struct {
+ name string
+ language string
+ value string
+ want [][]token
+ }{
+ {
+ name: "go function definition",
+ language: "go",
+ value: "func main()",
+ want: [][]token{{{Class: "kd", Value: "func"}, {Value: " main()"}}},
+ },
+ {
+ name: "go struct definition",
+ language: "go",
+ value: "type Command struct",
+ want: [][]token{{{Class: "kd", Value: "type"}, {Value: " Command "}, {Class: "kd", Value: "struct"}}},
+ },
+ {
+ name: "go struct multiline definition",
+ language: "go",
+ value: `struct {\nConfig *Config\nReadWriter *ReadWriter\nEOFSent bool\n}`,
+ want: [][]token{
+ {{Class: "kd", Value: "struct"}, {Value: " {\n"}},
+ {{Value: "Config *Config\n"}},
+ {{Value: "ReadWriter *ReadWriter\n"}},
+ {{Value: "EOFSent "}, {Class: "kt", Value: "bool"}, {Value: "\n"}},
+ {{Value: "}"}},
+ },
+ },
+ {
+ name: "ruby method definition",
+ language: "ruby",
+ value: "def read(line)",
+ want: [][]token{{{Class: "k", Value: "def"}, {Value: " read(line)"}}},
+ },
+ {
+ name: "ruby multiline method definition",
+ language: "ruby",
+ value: `def read(line)\nend`,
+ want: [][]token{
+ {{Class: "k", Value: "def"}, {Value: " read(line)\n"}},
+ {{Class: "k", Value: "end"}},
+ },
+ },
+ {
+ name: "unknown/malicious language is passed",
+ language: "<lang> alert(1); </lang>",
+ value: `def a;\nend`,
+ want: [][]token(nil),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ raw := []byte(fmt.Sprintf(`{"language":"%s","value":"%s"}`, tt.language, tt.value))
+ c, err := newCodeHover(json.RawMessage(raw))
+
+ require.NoError(t, err)
+ require.Equal(t, tt.want, c.Tokens)
+ })
+ }
+}
+
+func TestMarkdown(t *testing.T) {
+ value := `"This method reverses a string \n\n"`
+ c, err := newCodeHover(json.RawMessage(value))
+
+ require.NoError(t, err)
+ require.Equal(t, "This method reverses a string \n\n", c.TruncatedValue.Value)
+}
+
+func TestTruncatedValue(t *testing.T) {
+ value := strings.Repeat("a", 500)
+ rawValue, err := json.Marshal(value)
+ require.NoError(t, err)
+
+ c, err := newCodeHover(rawValue)
+ require.NoError(t, err)
+
+ require.Equal(t, value[0:maxValueSize], c.TruncatedValue.Value)
+ require.True(t, c.TruncatedValue.Truncated)
+}
+
+func TestTruncatingMultiByteChars(t *testing.T) {
+ value := strings.Repeat("ಅ", 500)
+ rawValue, err := json.Marshal(value)
+ require.NoError(t, err)
+
+ c, err := newCodeHover(rawValue)
+ require.NoError(t, err)
+
+ symbolSize := 3
+ require.Equal(t, value[0:maxValueSize*symbolSize], c.TruncatedValue.Value)
+}
diff --git a/workhorse/internal/lsif_transformer/parser/docs.go b/workhorse/internal/lsif_transformer/parser/docs.go
new file mode 100644
index 00000000000..c626e07d3fe
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/docs.go
@@ -0,0 +1,144 @@
+package parser
+
+import (
+ "archive/zip"
+ "bufio"
+ "encoding/json"
+ "io"
+ "strings"
+)
+
+const maxScanTokenSize = 1024 * 1024
+
+type Line struct {
+ Type string `json:"label"`
+}
+
+type Docs struct {
+ Root string
+ Entries map[Id]string
+ DocRanges map[Id][]Id
+ Ranges *Ranges
+}
+
+type Document struct {
+ Id Id `json:"id"`
+ Uri string `json:"uri"`
+}
+
+type DocumentRange struct {
+ OutV Id `json:"outV"`
+ RangeIds []Id `json:"inVs"`
+}
+
+type Metadata struct {
+ Root string `json:"projectRoot"`
+}
+
+func NewDocs(config Config) (*Docs, error) {
+ ranges, err := NewRanges(config)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Docs{
+ Root: "file:///",
+ Entries: make(map[Id]string),
+ DocRanges: make(map[Id][]Id),
+ Ranges: ranges,
+ }, nil
+}
+
+func (d *Docs) Parse(r io.Reader) error {
+ scanner := bufio.NewScanner(r)
+ buf := make([]byte, 0, bufio.MaxScanTokenSize)
+ scanner.Buffer(buf, maxScanTokenSize)
+
+ for scanner.Scan() {
+ if err := d.process(scanner.Bytes()); err != nil {
+ return err
+ }
+ }
+
+ return scanner.Err()
+}
+
+func (d *Docs) process(line []byte) error {
+ l := Line{}
+ if err := json.Unmarshal(line, &l); err != nil {
+ return err
+ }
+
+ switch l.Type {
+ case "metaData":
+ if err := d.addMetadata(line); err != nil {
+ return err
+ }
+ case "document":
+ if err := d.addDocument(line); err != nil {
+ return err
+ }
+ case "contains":
+ if err := d.addDocRanges(line); err != nil {
+ return err
+ }
+ default:
+ return d.Ranges.Read(l.Type, line)
+ }
+
+ return nil
+}
+
+func (d *Docs) Close() error {
+ return d.Ranges.Close()
+}
+
+func (d *Docs) SerializeEntries(w *zip.Writer) error {
+ for id, path := range d.Entries {
+ filePath := Lsif + "/" + path + ".json"
+
+ f, err := w.Create(filePath)
+ if err != nil {
+ return err
+ }
+
+ if err := d.Ranges.Serialize(f, d.DocRanges[id], d.Entries); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (d *Docs) addMetadata(line []byte) error {
+ var metadata Metadata
+ if err := json.Unmarshal(line, &metadata); err != nil {
+ return err
+ }
+
+ d.Root = strings.TrimSpace(metadata.Root) + "/"
+
+ return nil
+}
+
+func (d *Docs) addDocument(line []byte) error {
+ var doc Document
+ if err := json.Unmarshal(line, &doc); err != nil {
+ return err
+ }
+
+ d.Entries[doc.Id] = strings.TrimPrefix(doc.Uri, d.Root)
+
+ return nil
+}
+
+func (d *Docs) addDocRanges(line []byte) error {
+ var docRange DocumentRange
+ if err := json.Unmarshal(line, &docRange); err != nil {
+ return err
+ }
+
+ d.DocRanges[docRange.OutV] = append(d.DocRanges[docRange.OutV], docRange.RangeIds...)
+
+ return nil
+}
diff --git a/workhorse/internal/lsif_transformer/parser/docs_test.go b/workhorse/internal/lsif_transformer/parser/docs_test.go
new file mode 100644
index 00000000000..57dca8e773d
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/docs_test.go
@@ -0,0 +1,54 @@
+package parser
+
+import (
+ "bytes"
+ "fmt"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func createLine(id, label, uri string) []byte {
+ return []byte(fmt.Sprintf(`{"id":"%s","label":"%s","uri":"%s"}`+"\n", id, label, uri))
+}
+
+func TestParse(t *testing.T) {
+ d, err := NewDocs(Config{})
+ require.NoError(t, err)
+ defer d.Close()
+
+ data := []byte(`{"id":"1","label":"metaData","projectRoot":"file:///Users/nested"}` + "\n")
+ data = append(data, createLine("2", "document", "file:///Users/nested/file.rb")...)
+ data = append(data, createLine("3", "document", "file:///Users/nested/folder/file.rb")...)
+ data = append(data, createLine("4", "document", "file:///Users/wrong/file.rb")...)
+
+ require.NoError(t, d.Parse(bytes.NewReader(data)))
+
+ require.Equal(t, d.Entries[2], "file.rb")
+ require.Equal(t, d.Entries[3], "folder/file.rb")
+ require.Equal(t, d.Entries[4], "file:///Users/wrong/file.rb")
+}
+
+func TestParseContainsLine(t *testing.T) {
+ d, err := NewDocs(Config{})
+ require.NoError(t, err)
+ defer d.Close()
+
+ data := []byte(`{"id":"5","label":"contains","outV":"1", "inVs": ["2", "3"]}` + "\n")
+ data = append(data, []byte(`{"id":"6","label":"contains","outV":"1", "inVs": [4]}`+"\n")...)
+
+ require.NoError(t, d.Parse(bytes.NewReader(data)))
+
+ require.Equal(t, []Id{2, 3, 4}, d.DocRanges[1])
+}
+
+func TestParsingVeryLongLine(t *testing.T) {
+ d, err := NewDocs(Config{})
+ require.NoError(t, err)
+ defer d.Close()
+
+ line := []byte(`{"id": "` + strings.Repeat("a", 64*1024) + `"}`)
+
+ require.NoError(t, d.Parse(bytes.NewReader(line)))
+}
diff --git a/workhorse/internal/lsif_transformer/parser/errors.go b/workhorse/internal/lsif_transformer/parser/errors.go
new file mode 100644
index 00000000000..1040a789413
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/errors.go
@@ -0,0 +1,30 @@
+package parser
+
+import (
+ "errors"
+ "strings"
+)
+
+func combineErrors(errsOrNil ...error) error {
+ var errs []error
+ for _, err := range errsOrNil {
+ if err != nil {
+ errs = append(errs, err)
+ }
+ }
+
+ if len(errs) == 0 {
+ return nil
+ }
+
+ if len(errs) == 1 {
+ return errs[0]
+ }
+
+ var msgs []string
+ for _, err := range errs {
+ msgs = append(msgs, err.Error())
+ }
+
+ return errors.New(strings.Join(msgs, "\n"))
+}
diff --git a/workhorse/internal/lsif_transformer/parser/errors_test.go b/workhorse/internal/lsif_transformer/parser/errors_test.go
new file mode 100644
index 00000000000..31a7130d05e
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/errors_test.go
@@ -0,0 +1,26 @@
+package parser
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+type customErr struct {
+ err string
+}
+
+func (e customErr) Error() string {
+ return e.err
+}
+
+func TestCombineErrors(t *testing.T) {
+ err := combineErrors(nil, errors.New("first"), nil, customErr{"second"})
+ require.EqualError(t, err, "first\nsecond")
+
+ err = customErr{"custom error"}
+ require.Equal(t, err, combineErrors(nil, err, nil))
+
+ require.Nil(t, combineErrors(nil, nil, nil))
+}
diff --git a/workhorse/internal/lsif_transformer/parser/hovers.go b/workhorse/internal/lsif_transformer/parser/hovers.go
new file mode 100644
index 00000000000..e96d7e4fca3
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/hovers.go
@@ -0,0 +1,162 @@
+package parser
+
+import (
+ "encoding/json"
+ "io/ioutil"
+ "os"
+)
+
+type Offset struct {
+ At int32
+ Len int32
+}
+
+type Hovers struct {
+ File *os.File
+ Offsets *cache
+ CurrentOffset int
+}
+
+type RawResult struct {
+ Contents []json.RawMessage `json:"contents"`
+}
+
+type RawData struct {
+ Id Id `json:"id"`
+ Result RawResult `json:"result"`
+}
+
+type HoverRef struct {
+ ResultSetId Id `json:"outV"`
+ HoverId Id `json:"inV"`
+}
+
+type ResultSetRef struct {
+ ResultSetId Id `json:"outV"`
+ RefId Id `json:"inV"`
+}
+
+func NewHovers(config Config) (*Hovers, error) {
+ tempPath := config.TempPath
+
+ file, err := ioutil.TempFile(tempPath, "hovers")
+ if err != nil {
+ return nil, err
+ }
+
+ if err := os.Remove(file.Name()); err != nil {
+ return nil, err
+ }
+
+ offsets, err := newCache(tempPath, "hovers-indexes", Offset{})
+ if err != nil {
+ return nil, err
+ }
+
+ return &Hovers{
+ File: file,
+ Offsets: offsets,
+ CurrentOffset: 0,
+ }, nil
+}
+
+func (h *Hovers) Read(label string, line []byte) error {
+ switch label {
+ case "hoverResult":
+ if err := h.addData(line); err != nil {
+ return err
+ }
+ case "textDocument/hover":
+ if err := h.addHoverRef(line); err != nil {
+ return err
+ }
+ case "textDocument/references":
+ if err := h.addResultSetRef(line); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (h *Hovers) For(refId Id) json.RawMessage {
+ var offset Offset
+ if err := h.Offsets.Entry(refId, &offset); err != nil || offset.Len == 0 {
+ return nil
+ }
+
+ hover := make([]byte, offset.Len)
+ _, err := h.File.ReadAt(hover, int64(offset.At))
+ if err != nil {
+ return nil
+ }
+
+ return json.RawMessage(hover)
+}
+
+func (h *Hovers) Close() error {
+ return combineErrors(
+ h.File.Close(),
+ h.Offsets.Close(),
+ )
+}
+
+func (h *Hovers) addData(line []byte) error {
+ var rawData RawData
+ if err := json.Unmarshal(line, &rawData); err != nil {
+ return err
+ }
+
+ codeHovers := []*codeHover{}
+ for _, rawContent := range rawData.Result.Contents {
+ c, err := newCodeHover(rawContent)
+ if err != nil {
+ return err
+ }
+
+ codeHovers = append(codeHovers, c)
+ }
+
+ codeHoversData, err := json.Marshal(codeHovers)
+ if err != nil {
+ return err
+ }
+
+ n, err := h.File.Write(codeHoversData)
+ if err != nil {
+ return err
+ }
+
+ offset := Offset{At: int32(h.CurrentOffset), Len: int32(n)}
+ h.CurrentOffset += n
+
+ return h.Offsets.SetEntry(rawData.Id, &offset)
+}
+
+func (h *Hovers) addHoverRef(line []byte) error {
+ var hoverRef HoverRef
+ if err := json.Unmarshal(line, &hoverRef); err != nil {
+ return err
+ }
+
+ var offset Offset
+ if err := h.Offsets.Entry(hoverRef.HoverId, &offset); err != nil {
+ return err
+ }
+
+ return h.Offsets.SetEntry(hoverRef.ResultSetId, &offset)
+}
+
+func (h *Hovers) addResultSetRef(line []byte) error {
+ var ref ResultSetRef
+ if err := json.Unmarshal(line, &ref); err != nil {
+ return err
+ }
+
+ var offset Offset
+ if err := h.Offsets.Entry(ref.ResultSetId, &offset); err != nil {
+ return nil
+ }
+
+ return h.Offsets.SetEntry(ref.RefId, &offset)
+}
diff --git a/workhorse/internal/lsif_transformer/parser/hovers_test.go b/workhorse/internal/lsif_transformer/parser/hovers_test.go
new file mode 100644
index 00000000000..3037be103af
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/hovers_test.go
@@ -0,0 +1,30 @@
+package parser
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestHoversRead(t *testing.T) {
+ h := setupHovers(t)
+
+ var offset Offset
+ require.NoError(t, h.Offsets.Entry(2, &offset))
+ require.Equal(t, Offset{At: 0, Len: 19}, offset)
+
+ require.Equal(t, `[{"value":"hello"}]`, string(h.For(1)))
+
+ require.NoError(t, h.Close())
+}
+
+func setupHovers(t *testing.T) *Hovers {
+ h, err := NewHovers(Config{})
+ require.NoError(t, err)
+
+ require.NoError(t, h.Read("hoverResult", []byte(`{"id":"2","label":"hoverResult","result":{"contents": ["hello"]}}`)))
+ require.NoError(t, h.Read("textDocument/hover", []byte(`{"id":4,"label":"textDocument/hover","outV":"3","inV":2}`)))
+ require.NoError(t, h.Read("textDocument/references", []byte(`{"id":"3","label":"textDocument/references","outV":3,"inV":"1"}`)))
+
+ return h
+}
diff --git a/workhorse/internal/lsif_transformer/parser/id.go b/workhorse/internal/lsif_transformer/parser/id.go
new file mode 100644
index 00000000000..2adc4e092f5
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/id.go
@@ -0,0 +1,52 @@
+package parser
+
+import (
+ "encoding/json"
+ "errors"
+ "strconv"
+)
+
+const (
+ minId = 1
+ maxId = 20 * 1000 * 1000
+)
+
+type Id int32
+
+func (id *Id) UnmarshalJSON(b []byte) error {
+ if len(b) > 0 && b[0] != '"' {
+ if err := id.unmarshalInt(b); err != nil {
+ return err
+ }
+ } else {
+ if err := id.unmarshalString(b); err != nil {
+ return err
+ }
+ }
+
+ if *id < minId || *id > maxId {
+ return errors.New("json: id is invalid")
+ }
+
+ return nil
+}
+
+func (id *Id) unmarshalInt(b []byte) error {
+ return json.Unmarshal(b, (*int32)(id))
+}
+
+func (id *Id) unmarshalString(b []byte) error {
+ var s string
+ if err := json.Unmarshal(b, &s); err != nil {
+ return err
+ }
+
+ i, err := strconv.Atoi(s)
+ if err != nil {
+ return err
+ }
+
+ *id = Id(i)
+
+ return nil
+}
diff --git a/workhorse/internal/lsif_transformer/parser/id_test.go b/workhorse/internal/lsif_transformer/parser/id_test.go
new file mode 100644
index 00000000000..c1c53928378
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/id_test.go
@@ -0,0 +1,28 @@
+package parser
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+type jsonWithId struct {
+ Value Id `json:"value"`
+}
+
+func TestId(t *testing.T) {
+ var v jsonWithId
+ require.NoError(t, json.Unmarshal([]byte(`{ "value": 1230 }`), &v))
+ require.Equal(t, Id(1230), v.Value)
+
+ require.NoError(t, json.Unmarshal([]byte(`{ "value": "1230" }`), &v))
+ require.Equal(t, Id(1230), v.Value)
+
+ require.Error(t, json.Unmarshal([]byte(`{ "value": "1.5" }`), &v))
+ require.Error(t, json.Unmarshal([]byte(`{ "value": 1.5 }`), &v))
+ require.Error(t, json.Unmarshal([]byte(`{ "value": "-1" }`), &v))
+ require.Error(t, json.Unmarshal([]byte(`{ "value": -1 }`), &v))
+ require.Error(t, json.Unmarshal([]byte(`{ "value": 21000000 }`), &v))
+ require.Error(t, json.Unmarshal([]byte(`{ "value": "21000000" }`), &v))
+}
diff --git a/workhorse/internal/lsif_transformer/parser/parser.go b/workhorse/internal/lsif_transformer/parser/parser.go
new file mode 100644
index 00000000000..085e7a856aa
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/parser.go
@@ -0,0 +1,109 @@
+package parser
+
+import (
+ "archive/zip"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+
+ "gitlab.com/gitlab-org/labkit/log"
+)
+
+var (
+ Lsif = "lsif"
+)
+
+type Parser struct {
+ Docs *Docs
+
+ pr *io.PipeReader
+}
+
+type Config struct {
+ TempPath string
+}
+
+func NewParser(ctx context.Context, r io.Reader, config Config) (io.ReadCloser, error) {
+ docs, err := NewDocs(config)
+ if err != nil {
+ return nil, err
+ }
+
+ // ZIP files need to be seekable. Don't hold it all in RAM, use a tempfile
+ tempFile, err := ioutil.TempFile(config.TempPath, Lsif)
+ if err != nil {
+ return nil, err
+ }
+
+ defer tempFile.Close()
+
+ if err := os.Remove(tempFile.Name()); err != nil {
+ return nil, err
+ }
+
+ size, err := io.Copy(tempFile, r)
+ if err != nil {
+ return nil, err
+ }
+ log.WithContextFields(ctx, log.Fields{"lsif_zip_cache_bytes": size}).Print("cached incoming LSIF zip on disk")
+
+ zr, err := zip.NewReader(tempFile, size)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(zr.File) == 0 {
+ return nil, errors.New("empty zip file")
+ }
+
+ file, err := zr.File[0].Open()
+ if err != nil {
+ return nil, err
+ }
+
+ defer file.Close()
+
+ if err := docs.Parse(file); err != nil {
+ return nil, err
+ }
+
+ pr, pw := io.Pipe()
+ parser := &Parser{
+ Docs: docs,
+ pr: pr,
+ }
+
+ go parser.transform(pw)
+
+ return parser, nil
+}
+
+func (p *Parser) Read(b []byte) (int, error) {
+ return p.pr.Read(b)
+}
+
+func (p *Parser) Close() error {
+ p.pr.Close()
+
+ return p.Docs.Close()
+}
+
+func (p *Parser) transform(pw *io.PipeWriter) {
+ zw := zip.NewWriter(pw)
+
+ if err := p.Docs.SerializeEntries(zw); err != nil {
+ zw.Close() // Free underlying resources only
+ pw.CloseWithError(fmt.Errorf("lsif parser: Docs.SerializeEntries: %v", err))
+ return
+ }
+
+ if err := zw.Close(); err != nil {
+ pw.CloseWithError(fmt.Errorf("lsif parser: ZipWriter.Close: %v", err))
+ return
+ }
+
+ pw.Close()
+}
diff --git a/workhorse/internal/lsif_transformer/parser/parser_test.go b/workhorse/internal/lsif_transformer/parser/parser_test.go
new file mode 100644
index 00000000000..3a4d72360e2
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/parser_test.go
@@ -0,0 +1,80 @@
+package parser
+
+import (
+ "archive/zip"
+ "bytes"
+ "context"
+ "encoding/json"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestGenerate(t *testing.T) {
+ filePath := "testdata/dump.lsif.zip"
+ tmpDir := filePath + ".tmp"
+ defer os.RemoveAll(tmpDir)
+
+ createFiles(t, filePath, tmpDir)
+
+ verifyCorrectnessOf(t, tmpDir, "lsif/main.go.json")
+ verifyCorrectnessOf(t, tmpDir, "lsif/morestrings/reverse.go.json")
+}
+
+func verifyCorrectnessOf(t *testing.T, tmpDir, fileName string) {
+ file, err := ioutil.ReadFile(filepath.Join(tmpDir, fileName))
+ require.NoError(t, err)
+
+ var buf bytes.Buffer
+ require.NoError(t, json.Indent(&buf, file, "", " "))
+
+ expected, err := ioutil.ReadFile(filepath.Join("testdata/expected/", fileName))
+ require.NoError(t, err)
+
+ require.Equal(t, string(expected), buf.String())
+}
+
+func createFiles(t *testing.T, filePath, tmpDir string) {
+ t.Helper()
+ file, err := os.Open(filePath)
+ require.NoError(t, err)
+
+ parser, err := NewParser(context.Background(), file, Config{})
+ require.NoError(t, err)
+
+ zipFileName := tmpDir + ".zip"
+ w, err := os.Create(zipFileName)
+ require.NoError(t, err)
+ defer os.RemoveAll(zipFileName)
+
+ _, err = io.Copy(w, parser)
+ require.NoError(t, err)
+ require.NoError(t, parser.Close())
+
+ extractZipFiles(t, tmpDir, zipFileName)
+}
+
+func extractZipFiles(t *testing.T, tmpDir, zipFileName string) {
+ zipReader, err := zip.OpenReader(zipFileName)
+ require.NoError(t, err)
+
+ for _, file := range zipReader.Reader.File {
+ zippedFile, err := file.Open()
+ require.NoError(t, err)
+ defer zippedFile.Close()
+
+ fileDir, fileName := filepath.Split(file.Name)
+ require.NoError(t, os.MkdirAll(filepath.Join(tmpDir, fileDir), os.ModePerm))
+
+ outputFile, err := os.Create(filepath.Join(tmpDir, fileDir, fileName))
+ require.NoError(t, err)
+ defer outputFile.Close()
+
+ _, err = io.Copy(outputFile, zippedFile)
+ require.NoError(t, err)
+ }
+}
diff --git a/workhorse/internal/lsif_transformer/parser/performance_test.go b/workhorse/internal/lsif_transformer/parser/performance_test.go
new file mode 100644
index 00000000000..5a12d90072f
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/performance_test.go
@@ -0,0 +1,47 @@
+package parser
+
+import (
+ "context"
+ "io"
+ "io/ioutil"
+ "os"
+ "runtime"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func BenchmarkGenerate(b *testing.B) {
+ filePath := "testdata/workhorse.lsif.zip"
+ tmpDir := filePath + ".tmp"
+ defer os.RemoveAll(tmpDir)
+
+ var memoryUsage float64
+ for i := 0; i < b.N; i++ {
+ memoryUsage += measureMemory(func() {
+ file, err := os.Open(filePath)
+ require.NoError(b, err)
+
+ parser, err := NewParser(context.Background(), file, Config{})
+ require.NoError(b, err)
+
+ _, err = io.Copy(ioutil.Discard, parser)
+ require.NoError(b, err)
+ require.NoError(b, parser.Close())
+ })
+ }
+
+ b.ReportMetric(memoryUsage/float64(b.N), "MiB/op")
+}
+
+func measureMemory(f func()) float64 {
+ var m, m1 runtime.MemStats
+ runtime.ReadMemStats(&m)
+
+ f()
+
+ runtime.ReadMemStats(&m1)
+ runtime.GC()
+
+ return float64(m1.Alloc-m.Alloc) / 1024 / 1024
+}
diff --git a/workhorse/internal/lsif_transformer/parser/ranges.go b/workhorse/internal/lsif_transformer/parser/ranges.go
new file mode 100644
index 00000000000..a11a66d70ca
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/ranges.go
@@ -0,0 +1,214 @@
+package parser
+
+import (
+ "encoding/json"
+ "errors"
+ "io"
+ "strconv"
+)
+
+const (
+ definitions = "definitions"
+ references = "references"
+)
+
+type Ranges struct {
+ DefRefs map[Id]Item
+ References *References
+ Hovers *Hovers
+ Cache *cache
+}
+
+type RawRange struct {
+ Id Id `json:"id"`
+ Data Range `json:"start"`
+}
+
+type Range struct {
+ Line int32 `json:"line"`
+ Character int32 `json:"character"`
+ RefId Id
+}
+
+type RawItem struct {
+ Property string `json:"property"`
+ RefId Id `json:"outV"`
+ RangeIds []Id `json:"inVs"`
+ DocId Id `json:"document"`
+}
+
+type Item struct {
+ Line int32
+ DocId Id
+}
+
+type SerializedRange struct {
+ StartLine int32 `json:"start_line"`
+ StartChar int32 `json:"start_char"`
+ DefinitionPath string `json:"definition_path,omitempty"`
+ Hover json.RawMessage `json:"hover"`
+ References []SerializedReference `json:"references,omitempty"`
+}
+
+func NewRanges(config Config) (*Ranges, error) {
+ hovers, err := NewHovers(config)
+ if err != nil {
+ return nil, err
+ }
+
+ references, err := NewReferences(config)
+ if err != nil {
+ return nil, err
+ }
+
+ cache, err := newCache(config.TempPath, "ranges", Range{})
+ if err != nil {
+ return nil, err
+ }
+
+ return &Ranges{
+ DefRefs: make(map[Id]Item),
+ References: references,
+ Hovers: hovers,
+ Cache: cache,
+ }, nil
+}
+
+func (r *Ranges) Read(label string, line []byte) error {
+ switch label {
+ case "range":
+ if err := r.addRange(line); err != nil {
+ return err
+ }
+ case "item":
+ if err := r.addItem(line); err != nil {
+ return err
+ }
+ default:
+ return r.Hovers.Read(label, line)
+ }
+
+ return nil
+}
+
+func (r *Ranges) Serialize(f io.Writer, rangeIds []Id, docs map[Id]string) error {
+ encoder := json.NewEncoder(f)
+ n := len(rangeIds)
+
+ if _, err := f.Write([]byte("[")); err != nil {
+ return err
+ }
+
+ for i, rangeId := range rangeIds {
+ entry, err := r.getRange(rangeId)
+ if err != nil {
+ continue
+ }
+
+ serializedRange := SerializedRange{
+ StartLine: entry.Line,
+ StartChar: entry.Character,
+ DefinitionPath: r.definitionPathFor(docs, entry.RefId),
+ Hover: r.Hovers.For(entry.RefId),
+ References: r.References.For(docs, entry.RefId),
+ }
+ if err := encoder.Encode(serializedRange); err != nil {
+ return err
+ }
+ if i+1 < n {
+ if _, err := f.Write([]byte(",")); err != nil {
+ return err
+ }
+ }
+ }
+
+ if _, err := f.Write([]byte("]")); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (r *Ranges) Close() error {
+ return combineErrors(
+ r.Cache.Close(),
+ r.References.Close(),
+ r.Hovers.Close(),
+ )
+}
+
+func (r *Ranges) definitionPathFor(docs map[Id]string, refId Id) string {
+ defRef, ok := r.DefRefs[refId]
+ if !ok {
+ return ""
+ }
+
+ defPath := docs[defRef.DocId] + "#L" + strconv.Itoa(int(defRef.Line))
+
+ return defPath
+}
+
+func (r *Ranges) addRange(line []byte) error {
+ var rg RawRange
+ if err := json.Unmarshal(line, &rg); err != nil {
+ return err
+ }
+
+ return r.Cache.SetEntry(rg.Id, &rg.Data)
+}
+
+func (r *Ranges) addItem(line []byte) error {
+ var rawItem RawItem
+ if err := json.Unmarshal(line, &rawItem); err != nil {
+ return err
+ }
+
+ if rawItem.Property != definitions && rawItem.Property != references {
+ return nil
+ }
+
+ if len(rawItem.RangeIds) == 0 {
+ return errors.New("no range IDs")
+ }
+
+ var references []Item
+
+ for _, rangeId := range rawItem.RangeIds {
+ rg, err := r.getRange(rangeId)
+ if err != nil {
+ return err
+ }
+
+ rg.RefId = rawItem.RefId
+
+ if err := r.Cache.SetEntry(rangeId, rg); err != nil {
+ return err
+ }
+
+ item := Item{
+ Line: rg.Line + 1,
+ DocId: rawItem.DocId,
+ }
+
+ if rawItem.Property == definitions {
+ r.DefRefs[rawItem.RefId] = item
+ } else {
+ references = append(references, item)
+ }
+ }
+
+ if err := r.References.Store(rawItem.RefId, references); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (r *Ranges) getRange(rangeId Id) (*Range, error) {
+ var rg Range
+ if err := r.Cache.Entry(rangeId, &rg); err != nil {
+ return nil, err
+ }
+
+ return &rg, nil
+}
diff --git a/workhorse/internal/lsif_transformer/parser/ranges_test.go b/workhorse/internal/lsif_transformer/parser/ranges_test.go
new file mode 100644
index 00000000000..c1400ba61da
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/ranges_test.go
@@ -0,0 +1,61 @@
+package parser
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestRangesRead(t *testing.T) {
+ r, cleanup := setup(t)
+ defer cleanup()
+
+ firstRange := Range{Line: 1, Character: 2, RefId: 4}
+ rg, err := r.getRange(1)
+ require.NoError(t, err)
+ require.Equal(t, &firstRange, rg)
+
+ secondRange := Range{Line: 5, Character: 4, RefId: 4}
+ rg, err = r.getRange(2)
+ require.NoError(t, err)
+ require.Equal(t, &secondRange, rg)
+
+ thirdRange := Range{Line: 7, Character: 4, RefId: 4}
+ rg, err = r.getRange(3)
+ require.NoError(t, err)
+ require.Equal(t, &thirdRange, rg)
+}
+
+func TestSerialize(t *testing.T) {
+ r, cleanup := setup(t)
+ defer cleanup()
+
+ docs := map[Id]string{6: "def-path", 7: "ref-path"}
+
+ var buf bytes.Buffer
+ err := r.Serialize(&buf, []Id{1}, docs)
+ want := `[{"start_line":1,"start_char":2,"definition_path":"def-path#L2","hover":null,"references":[{"path":"ref-path#L6"},{"path":"ref-path#L8"}]}` + "\n]"
+
+ require.NoError(t, err)
+ require.Equal(t, want, buf.String())
+}
+
+func setup(t *testing.T) (*Ranges, func()) {
+ r, err := NewRanges(Config{})
+ require.NoError(t, err)
+
+ require.NoError(t, r.Read("range", []byte(`{"id":1,"label":"range","start":{"line":1,"character":2}}`)))
+ require.NoError(t, r.Read("range", []byte(`{"id":"2","label":"range","start":{"line":5,"character":4}}`)))
+ require.NoError(t, r.Read("range", []byte(`{"id":"3","label":"range","start":{"line":7,"character":4}}`)))
+
+ require.NoError(t, r.Read("item", []byte(`{"id":5,"label":"item","property":"definitions","outV":"4","inVs":[1],"document":"6"}`)))
+ require.NoError(t, r.Read("item", []byte(`{"id":"6","label":"item","property":"references","outV":4,"inVs":["2"],"document":"7"}`)))
+ require.NoError(t, r.Read("item", []byte(`{"id":"7","label":"item","property":"references","outV":4,"inVs":["3"],"document":"7"}`)))
+
+ cleanup := func() {
+ require.NoError(t, r.Close())
+ }
+
+ return r, cleanup
+}
diff --git a/workhorse/internal/lsif_transformer/parser/references.go b/workhorse/internal/lsif_transformer/parser/references.go
new file mode 100644
index 00000000000..58ff9a61c02
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/references.go
@@ -0,0 +1,107 @@
+package parser
+
+import (
+ "strconv"
+)
+
+type ReferencesOffset struct {
+ Id Id
+ Len int32
+}
+
+type References struct {
+ Items *cache
+ Offsets *cache
+ CurrentOffsetId Id
+}
+
+type SerializedReference struct {
+ Path string `json:"path"`
+}
+
+func NewReferences(config Config) (*References, error) {
+ tempPath := config.TempPath
+
+ items, err := newCache(tempPath, "references", Item{})
+ if err != nil {
+ return nil, err
+ }
+
+ offsets, err := newCache(tempPath, "references-offsets", ReferencesOffset{})
+ if err != nil {
+ return nil, err
+ }
+
+ return &References{
+ Items: items,
+ Offsets: offsets,
+ CurrentOffsetId: 0,
+ }, nil
+}
+
+// Store is responsible for keeping track of references that will be used when
+// serializing in `For`.
+//
+// The references are stored in a file to cache them. It is like
+// `map[Id][]Item` (where `Id` is `refId`) but relies on caching the array and
+// its offset in files for storage to reduce RAM usage. The items can be
+// fetched by calling `getItems`.
+func (r *References) Store(refId Id, references []Item) error {
+ size := len(references)
+
+ if size == 0 {
+ return nil
+ }
+
+ items := append(r.getItems(refId), references...)
+ err := r.Items.SetEntry(r.CurrentOffsetId, items)
+ if err != nil {
+ return err
+ }
+
+ size = len(items)
+ r.Offsets.SetEntry(refId, ReferencesOffset{Id: r.CurrentOffsetId, Len: int32(size)})
+ r.CurrentOffsetId += Id(size)
+
+ return nil
+}
+
+func (r *References) For(docs map[Id]string, refId Id) []SerializedReference {
+ references := r.getItems(refId)
+ if references == nil {
+ return nil
+ }
+
+ var serializedReferences []SerializedReference
+
+ for _, reference := range references {
+ serializedReference := SerializedReference{
+ Path: docs[reference.DocId] + "#L" + strconv.Itoa(int(reference.Line)),
+ }
+
+ serializedReferences = append(serializedReferences, serializedReference)
+ }
+
+ return serializedReferences
+}
+
+func (r *References) Close() error {
+ return combineErrors(
+ r.Items.Close(),
+ r.Offsets.Close(),
+ )
+}
+
+func (r *References) getItems(refId Id) []Item {
+ var offset ReferencesOffset
+ if err := r.Offsets.Entry(refId, &offset); err != nil || offset.Len == 0 {
+ return nil
+ }
+
+ items := make([]Item, offset.Len)
+ if err := r.Items.Entry(offset.Id, &items); err != nil {
+ return nil
+ }
+
+ return items
+}
diff --git a/workhorse/internal/lsif_transformer/parser/references_test.go b/workhorse/internal/lsif_transformer/parser/references_test.go
new file mode 100644
index 00000000000..7b47513bc53
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/references_test.go
@@ -0,0 +1,44 @@
+package parser
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestReferencesStore(t *testing.T) {
+ const (
+ docId = 1
+ refId = 3
+ )
+
+ r, err := NewReferences(Config{})
+ require.NoError(t, err)
+
+ err = r.Store(refId, []Item{{Line: 2, DocId: docId}, {Line: 3, DocId: docId}})
+ require.NoError(t, err)
+
+ docs := map[Id]string{docId: "doc.go"}
+ serializedReferences := r.For(docs, refId)
+
+ require.Contains(t, serializedReferences, SerializedReference{Path: "doc.go#L2"})
+ require.Contains(t, serializedReferences, SerializedReference{Path: "doc.go#L3"})
+
+ require.NoError(t, r.Close())
+}
+
+func TestReferencesStoreEmpty(t *testing.T) {
+ const refId = 3
+
+ r, err := NewReferences(Config{})
+ require.NoError(t, err)
+
+ err = r.Store(refId, []Item{})
+ require.NoError(t, err)
+
+ docs := map[Id]string{1: "doc.go"}
+ serializedReferences := r.For(docs, refId)
+
+ require.Nil(t, serializedReferences)
+ require.NoError(t, r.Close())
+}
diff --git a/workhorse/internal/lsif_transformer/parser/testdata/dump.lsif.zip b/workhorse/internal/lsif_transformer/parser/testdata/dump.lsif.zip
new file mode 100644
index 00000000000..e7c9ef2da66
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/testdata/dump.lsif.zip
Binary files differ
diff --git a/workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/main.go.json b/workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/main.go.json
new file mode 100644
index 00000000000..781cb78fc1a
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/main.go.json
@@ -0,0 +1,208 @@
+[
+ {
+ "start_line": 7,
+ "start_char": 1,
+ "definition_path": "main.go#L4",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kn",
+ "value": "package"
+ },
+ {
+ "value": " "
+ },
+ {
+ "class": "s",
+ "value": "\"github.com/user/hello/morestrings\""
+ }
+ ]
+ ],
+ "language": "go"
+ },
+ {
+ "value": "Package morestrings implements additional functions to manipulate UTF-8 encoded strings, beyond what is provided in the standard \"strings\" package. \n\n"
+ }
+ ],
+ "references": [
+ {
+ "path": "main.go#L8"
+ },
+ {
+ "path": "main.go#L9"
+ }
+ ]
+ },
+ {
+ "start_line": 7,
+ "start_char": 13,
+ "definition_path": "morestrings/reverse.go#L12",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kd",
+ "value": "func"
+ },
+ {
+ "value": " Reverse(s "
+ },
+ {
+ "class": "kt",
+ "value": "string"
+ },
+ {
+ "value": ") "
+ },
+ {
+ "class": "kt",
+ "value": "string"
+ }
+ ]
+ ],
+ "language": "go"
+ },
+ {
+ "value": "This method reverses a string \n\n"
+ }
+ ],
+ "references": [
+ {
+ "path": "main.go#L8"
+ }
+ ]
+ },
+ {
+ "start_line": 8,
+ "start_char": 1,
+ "definition_path": "main.go#L4",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kn",
+ "value": "package"
+ },
+ {
+ "value": " "
+ },
+ {
+ "class": "s",
+ "value": "\"github.com/user/hello/morestrings\""
+ }
+ ]
+ ],
+ "language": "go"
+ },
+ {
+ "value": "Package morestrings implements additional functions to manipulate UTF-8 encoded strings, beyond what is provided in the standard \"strings\" package. \n\n"
+ }
+ ],
+ "references": [
+ {
+ "path": "main.go#L8"
+ },
+ {
+ "path": "main.go#L9"
+ }
+ ]
+ },
+ {
+ "start_line": 8,
+ "start_char": 13,
+ "definition_path": "morestrings/reverse.go#L5",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kd",
+ "value": "func"
+ },
+ {
+ "value": " Func2(i "
+ },
+ {
+ "class": "kt",
+ "value": "int"
+ },
+ {
+ "value": ") "
+ },
+ {
+ "class": "kt",
+ "value": "string"
+ }
+ ]
+ ],
+ "language": "go"
+ }
+ ],
+ "references": [
+ {
+ "path": "main.go#L9"
+ }
+ ]
+ },
+ {
+ "start_line": 6,
+ "start_char": 5,
+ "definition_path": "main.go#L7",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kd",
+ "value": "func"
+ },
+ {
+ "value": " main()"
+ }
+ ]
+ ],
+ "language": "go"
+ }
+ ]
+ },
+ {
+ "start_line": 3,
+ "start_char": 2,
+ "definition_path": "main.go#L4",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kn",
+ "value": "package"
+ },
+ {
+ "value": " "
+ },
+ {
+ "class": "s",
+ "value": "\"github.com/user/hello/morestrings\""
+ }
+ ]
+ ],
+ "language": "go"
+ },
+ {
+ "value": "Package morestrings implements additional functions to manipulate UTF-8 encoded strings, beyond what is provided in the standard \"strings\" package. \n\n"
+ }
+ ],
+ "references": [
+ {
+ "path": "main.go#L8"
+ },
+ {
+ "path": "main.go#L9"
+ }
+ ]
+ }
+] \ No newline at end of file
diff --git a/workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/morestrings/reverse.go.json b/workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/morestrings/reverse.go.json
new file mode 100644
index 00000000000..1d238413d53
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/testdata/expected/lsif/morestrings/reverse.go.json
@@ -0,0 +1,249 @@
+[
+ {
+ "start_line": 11,
+ "start_char": 5,
+ "definition_path": "morestrings/reverse.go#L12",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kd",
+ "value": "func"
+ },
+ {
+ "value": " Reverse(s "
+ },
+ {
+ "class": "kt",
+ "value": "string"
+ },
+ {
+ "value": ") "
+ },
+ {
+ "class": "kt",
+ "value": "string"
+ }
+ ]
+ ],
+ "language": "go"
+ },
+ {
+ "value": "This method reverses a string \n\n"
+ }
+ ],
+ "references": [
+ {
+ "path": "main.go#L8"
+ }
+ ]
+ },
+ {
+ "start_line": 4,
+ "start_char": 11,
+ "definition_path": "morestrings/reverse.go#L5",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kd",
+ "value": "var"
+ },
+ {
+ "value": " i "
+ },
+ {
+ "class": "kt",
+ "value": "int"
+ }
+ ]
+ ],
+ "language": "go"
+ }
+ ]
+ },
+ {
+ "start_line": 11,
+ "start_char": 13,
+ "definition_path": "morestrings/reverse.go#L12",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kd",
+ "value": "var"
+ },
+ {
+ "value": " s "
+ },
+ {
+ "class": "kt",
+ "value": "string"
+ }
+ ]
+ ],
+ "language": "go"
+ }
+ ]
+ },
+ {
+ "start_line": 12,
+ "start_char": 1,
+ "definition_path": "morestrings/reverse.go#L13",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kd",
+ "value": "var"
+ },
+ {
+ "value": " a "
+ },
+ {
+ "class": "kt",
+ "value": "string"
+ }
+ ]
+ ],
+ "language": "go"
+ }
+ ],
+ "references": [
+ {
+ "path": "morestrings/reverse.go#L15"
+ }
+ ]
+ },
+ {
+ "start_line": 5,
+ "start_char": 1,
+ "definition_path": "morestrings/reverse.go#L6",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kd",
+ "value": "var"
+ },
+ {
+ "value": " b "
+ },
+ {
+ "class": "kt",
+ "value": "string"
+ }
+ ]
+ ],
+ "language": "go"
+ }
+ ],
+ "references": [
+ {
+ "path": "morestrings/reverse.go#L8"
+ }
+ ]
+ },
+ {
+ "start_line": 14,
+ "start_char": 8,
+ "definition_path": "morestrings/reverse.go#L13",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kd",
+ "value": "var"
+ },
+ {
+ "value": " a "
+ },
+ {
+ "class": "kt",
+ "value": "string"
+ }
+ ]
+ ],
+ "language": "go"
+ }
+ ],
+ "references": [
+ {
+ "path": "morestrings/reverse.go#L15"
+ }
+ ]
+ },
+ {
+ "start_line": 7,
+ "start_char": 8,
+ "definition_path": "morestrings/reverse.go#L6",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kd",
+ "value": "var"
+ },
+ {
+ "value": " b "
+ },
+ {
+ "class": "kt",
+ "value": "string"
+ }
+ ]
+ ],
+ "language": "go"
+ }
+ ],
+ "references": [
+ {
+ "path": "morestrings/reverse.go#L8"
+ }
+ ]
+ },
+ {
+ "start_line": 4,
+ "start_char": 5,
+ "definition_path": "morestrings/reverse.go#L5",
+ "hover": [
+ {
+ "tokens": [
+ [
+ {
+ "class": "kd",
+ "value": "func"
+ },
+ {
+ "value": " Func2(i "
+ },
+ {
+ "class": "kt",
+ "value": "int"
+ },
+ {
+ "value": ") "
+ },
+ {
+ "class": "kt",
+ "value": "string"
+ }
+ ]
+ ],
+ "language": "go"
+ }
+ ],
+ "references": [
+ {
+ "path": "main.go#L9"
+ }
+ ]
+ }
+] \ No newline at end of file
diff --git a/workhorse/internal/lsif_transformer/parser/testdata/workhorse.lsif.zip b/workhorse/internal/lsif_transformer/parser/testdata/workhorse.lsif.zip
new file mode 100644
index 00000000000..76491ed8a93
--- /dev/null
+++ b/workhorse/internal/lsif_transformer/parser/testdata/workhorse.lsif.zip
Binary files differ
diff --git a/workhorse/internal/objectstore/gocloud_object.go b/workhorse/internal/objectstore/gocloud_object.go
new file mode 100644
index 00000000000..38545086994
--- /dev/null
+++ b/workhorse/internal/objectstore/gocloud_object.go
@@ -0,0 +1,100 @@
+package objectstore
+
+import (
+ "context"
+ "io"
+ "time"
+
+ "gitlab.com/gitlab-org/labkit/log"
+ "gocloud.dev/blob"
+ "gocloud.dev/gcerrors"
+)
+
+type GoCloudObject struct {
+ bucket *blob.Bucket
+ mux *blob.URLMux
+ bucketURL string
+ objectName string
+ *uploader
+}
+
+type GoCloudObjectParams struct {
+ Ctx context.Context
+ Mux *blob.URLMux
+ BucketURL string
+ ObjectName string
+}
+
+func NewGoCloudObject(p *GoCloudObjectParams) (*GoCloudObject, error) {
+ bucket, err := p.Mux.OpenBucket(p.Ctx, p.BucketURL)
+ if err != nil {
+ return nil, err
+ }
+
+ o := &GoCloudObject{
+ bucket: bucket,
+ mux: p.Mux,
+ bucketURL: p.BucketURL,
+ objectName: p.ObjectName,
+ }
+
+ o.uploader = newUploader(o)
+ return o, nil
+}
+
+func (o *GoCloudObject) Upload(ctx context.Context, r io.Reader) error {
+ defer o.bucket.Close()
+
+ writer, err := o.bucket.NewWriter(ctx, o.objectName, nil)
+ if err != nil {
+ log.ContextLogger(ctx).WithError(err).Error("error creating GoCloud bucket")
+ return err
+ }
+
+ if _, err = io.Copy(writer, r); err != nil {
+ log.ContextLogger(ctx).WithError(err).Error("error writing to GoCloud bucket")
+ writer.Close()
+ return err
+ }
+
+ if err := writer.Close(); err != nil {
+ log.ContextLogger(ctx).WithError(err).Error("error closing GoCloud bucket")
+ return err
+ }
+
+ return nil
+}
+
+func (o *GoCloudObject) ETag() string {
+ return ""
+}
+
+func (o *GoCloudObject) Abort() {
+ o.Delete()
+}
+
+// Delete will always attempt to delete the temporary file.
+// According to https://github.com/google/go-cloud/blob/7818961b5c9a112f7e092d3a2d8479cbca80d187/blob/azureblob/azureblob.go#L881-L883,
+// if the writer is closed before any Write is called, Close will create an empty file.
+func (o *GoCloudObject) Delete() {
+ if o.bucketURL == "" || o.objectName == "" {
+ return
+ }
+
+ // Note we can't use the request context because in a successful
+ // case, the original request has already completed.
+ deleteCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // lint:allow context.Background
+ defer cancel()
+
+ bucket, err := o.mux.OpenBucket(deleteCtx, o.bucketURL)
+ if err != nil {
+ log.WithError(err).Error("error opening bucket for delete")
+ return
+ }
+
+ if err := bucket.Delete(deleteCtx, o.objectName); err != nil {
+ if gcerrors.Code(err) != gcerrors.NotFound {
+ log.WithError(err).Error("error deleting object")
+ }
+ }
+}
diff --git a/workhorse/internal/objectstore/gocloud_object_test.go b/workhorse/internal/objectstore/gocloud_object_test.go
new file mode 100644
index 00000000000..4dc9d2d75cc
--- /dev/null
+++ b/workhorse/internal/objectstore/gocloud_object_test.go
@@ -0,0 +1,56 @@
+package objectstore_test
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+)
+
+func TestGoCloudObjectUpload(t *testing.T) {
+ mux, _, cleanup := test.SetupGoCloudFileBucket(t, "azuretest")
+ defer cleanup()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ deadline := time.Now().Add(testTimeout)
+
+ objectName := "test.png"
+ testURL := "azuretest://azure.example.com/test-container"
+ p := &objectstore.GoCloudObjectParams{Ctx: ctx, Mux: mux, BucketURL: testURL, ObjectName: objectName}
+ object, err := objectstore.NewGoCloudObject(p)
+ require.NotNil(t, object)
+ require.NoError(t, err)
+
+ // copy data
+ n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
+ require.NoError(t, err)
+ require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch")
+
+ bucket, err := mux.OpenBucket(ctx, testURL)
+ require.NoError(t, err)
+
+ // Verify the data was copied correctly.
+ received, err := bucket.ReadAll(ctx, objectName)
+ require.NoError(t, err)
+ require.Equal(t, []byte(test.ObjectContent), received)
+
+ cancel()
+
+ testhelper.Retry(t, 5*time.Second, func() error {
+ exists, err := bucket.Exists(ctx, objectName)
+ require.NoError(t, err)
+
+ if exists {
+ return fmt.Errorf("file %s is still present", objectName)
+ } else {
+ return nil
+ }
+ })
+}
diff --git a/workhorse/internal/objectstore/multipart.go b/workhorse/internal/objectstore/multipart.go
new file mode 100644
index 00000000000..fd1c0ed487d
--- /dev/null
+++ b/workhorse/internal/objectstore/multipart.go
@@ -0,0 +1,188 @@
+package objectstore
+
+import (
+ "bytes"
+ "context"
+ "encoding/xml"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "os"
+
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/mask"
+)
+
+// ErrNotEnoughParts will be used when writing more than size * len(partURLs)
+var ErrNotEnoughParts = errors.New("not enough Parts")
+
+// Multipart represents a MultipartUpload on a S3 compatible Object Store service.
+// It can be used as io.WriteCloser for uploading an object
+type Multipart struct {
+ PartURLs []string
+ // CompleteURL is a presigned URL for CompleteMultipartUpload
+ CompleteURL string
+ // AbortURL is a presigned URL for AbortMultipartUpload
+ AbortURL string
+ // DeleteURL is a presigned URL for RemoveObject
+ DeleteURL string
+ PutHeaders map[string]string
+ partSize int64
+ etag string
+
+ *uploader
+}
+
+// NewMultipart provides Multipart pointer that can be used for uploading. Data written will be split buffered on disk up to size bytes
+// then uploaded with S3 Upload Part. Once Multipart is Closed a final call to CompleteMultipartUpload will be sent.
+// In case of any error a call to AbortMultipartUpload will be made to cleanup all the resources
+func NewMultipart(partURLs []string, completeURL, abortURL, deleteURL string, putHeaders map[string]string, partSize int64) (*Multipart, error) {
+ m := &Multipart{
+ PartURLs: partURLs,
+ CompleteURL: completeURL,
+ AbortURL: abortURL,
+ DeleteURL: deleteURL,
+ PutHeaders: putHeaders,
+ partSize: partSize,
+ }
+
+ m.uploader = newUploader(m)
+ return m, nil
+}
+
+func (m *Multipart) Upload(ctx context.Context, r io.Reader) error {
+ cmu := &CompleteMultipartUpload{}
+ for i, partURL := range m.PartURLs {
+ src := io.LimitReader(r, m.partSize)
+ part, err := m.readAndUploadOnePart(ctx, partURL, m.PutHeaders, src, i+1)
+ if err != nil {
+ return err
+ }
+ if part == nil {
+ break
+ } else {
+ cmu.Part = append(cmu.Part, part)
+ }
+ }
+
+ n, err := io.Copy(ioutil.Discard, r)
+ if err != nil {
+ return fmt.Errorf("drain pipe: %v", err)
+ }
+ if n > 0 {
+ return ErrNotEnoughParts
+ }
+
+ if err := m.complete(ctx, cmu); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *Multipart) ETag() string {
+ return m.etag
+}
+func (m *Multipart) Abort() {
+ deleteURL(m.AbortURL)
+}
+
+func (m *Multipart) Delete() {
+ deleteURL(m.DeleteURL)
+}
+
+func (m *Multipart) readAndUploadOnePart(ctx context.Context, partURL string, putHeaders map[string]string, src io.Reader, partNumber int) (*completeMultipartUploadPart, error) {
+ file, err := ioutil.TempFile("", "part-buffer")
+ if err != nil {
+ return nil, fmt.Errorf("create temporary buffer file: %v", err)
+ }
+ defer func(path string) {
+ if err := os.Remove(path); err != nil {
+ log.WithError(err).WithField("file", path).Warning("Unable to delete temporary file")
+ }
+ }(file.Name())
+
+ n, err := io.Copy(file, src)
+ if err != nil {
+ return nil, err
+ }
+ if n == 0 {
+ return nil, nil
+ }
+
+ if _, err = file.Seek(0, io.SeekStart); err != nil {
+ return nil, fmt.Errorf("rewind part %d temporary dump : %v", partNumber, err)
+ }
+
+ etag, err := m.uploadPart(ctx, partURL, putHeaders, file, n)
+ if err != nil {
+ return nil, fmt.Errorf("upload part %d: %v", partNumber, err)
+ }
+ return &completeMultipartUploadPart{PartNumber: partNumber, ETag: etag}, nil
+}
+
+func (m *Multipart) uploadPart(ctx context.Context, url string, headers map[string]string, body io.Reader, size int64) (string, error) {
+ deadline, ok := ctx.Deadline()
+ if !ok {
+ return "", fmt.Errorf("missing deadline")
+ }
+
+ part, err := newObject(url, "", headers, size, false)
+ if err != nil {
+ return "", err
+ }
+
+ if n, err := part.Consume(ctx, io.LimitReader(body, size), deadline); err != nil || n < size {
+ if err == nil {
+ err = io.ErrUnexpectedEOF
+ }
+ return "", err
+ }
+
+ return part.ETag(), nil
+}
+
+func (m *Multipart) complete(ctx context.Context, cmu *CompleteMultipartUpload) error {
+ body, err := xml.Marshal(cmu)
+ if err != nil {
+ return fmt.Errorf("marshal CompleteMultipartUpload request: %v", err)
+ }
+
+ req, err := http.NewRequest("POST", m.CompleteURL, bytes.NewReader(body))
+ if err != nil {
+ return fmt.Errorf("create CompleteMultipartUpload request: %v", err)
+ }
+ req.ContentLength = int64(len(body))
+ req.Header.Set("Content-Type", "application/xml")
+ req = req.WithContext(ctx)
+
+ resp, err := httpClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("CompleteMultipartUpload request %q: %v", mask.URL(m.CompleteURL), err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("CompleteMultipartUpload request %v returned: %s", mask.URL(m.CompleteURL), resp.Status)
+ }
+
+ result := &compoundCompleteMultipartUploadResult{}
+ decoder := xml.NewDecoder(resp.Body)
+ if err := decoder.Decode(&result); err != nil {
+ return fmt.Errorf("decode CompleteMultipartUpload answer: %v", err)
+ }
+
+ if result.isError() {
+ return result
+ }
+
+ if result.CompleteMultipartUploadResult == nil {
+ return fmt.Errorf("empty CompleteMultipartUploadResult")
+ }
+
+ m.etag = extractETag(result.ETag)
+
+ return nil
+}
diff --git a/workhorse/internal/objectstore/multipart_test.go b/workhorse/internal/objectstore/multipart_test.go
new file mode 100644
index 00000000000..00d6efc0982
--- /dev/null
+++ b/workhorse/internal/objectstore/multipart_test.go
@@ -0,0 +1,64 @@
+package objectstore_test
+
+import (
+ "context"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test"
+)
+
+func TestMultipartUploadWithUpcaseETags(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ var putCnt, postCnt int
+
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, err := ioutil.ReadAll(r.Body)
+ require.NoError(t, err)
+ defer r.Body.Close()
+
+ // Part upload request
+ if r.Method == "PUT" {
+ putCnt++
+
+ w.Header().Set("ETag", strings.ToUpper(test.ObjectMD5))
+ }
+
+ // POST with CompleteMultipartUpload request
+ if r.Method == "POST" {
+ completeBody := `<CompleteMultipartUploadResult>
+ <Bucket>test-bucket</Bucket>
+ <ETag>No Longer Checked</ETag>
+ </CompleteMultipartUploadResult>`
+ postCnt++
+
+ w.Write([]byte(completeBody))
+ }
+ }))
+ defer ts.Close()
+
+ deadline := time.Now().Add(testTimeout)
+
+ m, err := objectstore.NewMultipart(
+ []string{ts.URL}, // a single presigned part URL
+ ts.URL, // the complete multipart upload URL
+ "", // no abort
+ "", // no delete
+ map[string]string{}, // no custom headers
+ test.ObjectSize) // parts size equal to the whole content. Only 1 part
+ require.NoError(t, err)
+
+ _, err = m.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
+ require.NoError(t, err)
+ require.Equal(t, 1, putCnt, "1 part expected")
+ require.Equal(t, 1, postCnt, "1 complete multipart upload expected")
+}
diff --git a/workhorse/internal/objectstore/object.go b/workhorse/internal/objectstore/object.go
new file mode 100644
index 00000000000..eaf3bfb2e36
--- /dev/null
+++ b/workhorse/internal/objectstore/object.go
@@ -0,0 +1,114 @@
+package objectstore
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "time"
+
+ "gitlab.com/gitlab-org/labkit/correlation"
+ "gitlab.com/gitlab-org/labkit/mask"
+ "gitlab.com/gitlab-org/labkit/tracing"
+)
+
+// httpTransport defines a http.Transport with values
+// that are more restrictive than for http.DefaultTransport,
+// they define shorter TLS Handshake, and more aggressive connection closing
+// to prevent the connection hanging and reduce FD usage
+var httpTransport = tracing.NewRoundTripper(correlation.NewInstrumentedRoundTripper(&http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ DialContext: (&net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 10 * time.Second,
+ }).DialContext,
+ MaxIdleConns: 2,
+ IdleConnTimeout: 30 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ExpectContinueTimeout: 10 * time.Second,
+ ResponseHeaderTimeout: 30 * time.Second,
+}))
+
+var httpClient = &http.Client{
+ Transport: httpTransport,
+}
+
+// Object represents an object on a S3 compatible Object Store service.
+// It can be used as io.WriteCloser for uploading an object
+type Object struct {
+ // putURL is a presigned URL for PutObject
+ putURL string
+ // deleteURL is a presigned URL for RemoveObject
+ deleteURL string
+ putHeaders map[string]string
+ size int64
+ etag string
+ metrics bool
+
+ *uploader
+}
+
+type StatusCodeError error
+
+// NewObject opens an HTTP connection to Object Store and returns an Object pointer that can be used for uploading.
+func NewObject(putURL, deleteURL string, putHeaders map[string]string, size int64) (*Object, error) {
+ return newObject(putURL, deleteURL, putHeaders, size, true)
+}
+
+func newObject(putURL, deleteURL string, putHeaders map[string]string, size int64, metrics bool) (*Object, error) {
+ o := &Object{
+ putURL: putURL,
+ deleteURL: deleteURL,
+ putHeaders: putHeaders,
+ size: size,
+ metrics: metrics,
+ }
+
+ o.uploader = newETagCheckUploader(o, metrics)
+ return o, nil
+}
+
+func (o *Object) Upload(ctx context.Context, r io.Reader) error {
+ // we should prevent pr.Close() otherwise it may shadow error set with pr.CloseWithError(err)
+ req, err := http.NewRequest(http.MethodPut, o.putURL, ioutil.NopCloser(r))
+
+ if err != nil {
+ return fmt.Errorf("PUT %q: %v", mask.URL(o.putURL), err)
+ }
+ req.ContentLength = o.size
+
+ for k, v := range o.putHeaders {
+ req.Header.Set(k, v)
+ }
+
+ resp, err := httpClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("PUT request %q: %v", mask.URL(o.putURL), err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ if o.metrics {
+ objectStorageUploadRequestsInvalidStatus.Inc()
+ }
+ return StatusCodeError(fmt.Errorf("PUT request %v returned: %s", mask.URL(o.putURL), resp.Status))
+ }
+
+ o.etag = extractETag(resp.Header.Get("ETag"))
+
+ return nil
+}
+
+func (o *Object) ETag() string {
+ return o.etag
+}
+
+func (o *Object) Abort() {
+ o.Delete()
+}
+
+func (o *Object) Delete() {
+ deleteURL(o.deleteURL)
+}
diff --git a/workhorse/internal/objectstore/object_test.go b/workhorse/internal/objectstore/object_test.go
new file mode 100644
index 00000000000..2ec45520e97
--- /dev/null
+++ b/workhorse/internal/objectstore/object_test.go
@@ -0,0 +1,155 @@
+package objectstore_test
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test"
+)
+
+const testTimeout = 10 * time.Second
+
+type osFactory func() (*test.ObjectstoreStub, *httptest.Server)
+
+func testObjectUploadNoErrors(t *testing.T, startObjectStore osFactory, useDeleteURL bool, contentType string) {
+ osStub, ts := startObjectStore()
+ defer ts.Close()
+
+ objectURL := ts.URL + test.ObjectPath
+ var deleteURL string
+ if useDeleteURL {
+ deleteURL = objectURL
+ }
+
+ putHeaders := map[string]string{"Content-Type": contentType}
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ deadline := time.Now().Add(testTimeout)
+ object, err := objectstore.NewObject(objectURL, deleteURL, putHeaders, test.ObjectSize)
+ require.NoError(t, err)
+
+ // copy data
+ n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
+ require.NoError(t, err)
+ require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch")
+
+ require.Equal(t, contentType, osStub.GetHeader(test.ObjectPath, "Content-Type"))
+
+ // Checking MD5 extraction
+ require.Equal(t, osStub.GetObjectMD5(test.ObjectPath), object.ETag())
+
+ // Checking cleanup
+ cancel()
+ require.Equal(t, 1, osStub.PutsCnt(), "Object hasn't been uploaded")
+
+ var expectedDeleteCnt int
+ if useDeleteURL {
+ expectedDeleteCnt = 1
+ }
+ // Poll because the object removal is async
+ for i := 0; i < 100; i++ {
+ if osStub.DeletesCnt() == expectedDeleteCnt {
+ break
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+
+ if useDeleteURL {
+ require.Equal(t, 1, osStub.DeletesCnt(), "Object hasn't been deleted")
+ } else {
+ require.Equal(t, 0, osStub.DeletesCnt(), "Object has been deleted")
+ }
+}
+
+func TestObjectUpload(t *testing.T) {
+ t.Run("with delete URL", func(t *testing.T) {
+ testObjectUploadNoErrors(t, test.StartObjectStore, true, "application/octet-stream")
+ })
+ t.Run("without delete URL", func(t *testing.T) {
+ testObjectUploadNoErrors(t, test.StartObjectStore, false, "application/octet-stream")
+ })
+ t.Run("with custom content type", func(t *testing.T) {
+ testObjectUploadNoErrors(t, test.StartObjectStore, false, "image/jpeg")
+ })
+ t.Run("with upcase ETAG", func(t *testing.T) {
+ factory := func() (*test.ObjectstoreStub, *httptest.Server) {
+ md5s := map[string]string{
+ test.ObjectPath: strings.ToUpper(test.ObjectMD5),
+ }
+
+ return test.StartObjectStoreWithCustomMD5(md5s)
+ }
+
+ testObjectUploadNoErrors(t, factory, false, "application/octet-stream")
+ })
+}
+
+func TestObjectUpload404(t *testing.T) {
+ ts := httptest.NewServer(http.NotFoundHandler())
+ defer ts.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ deadline := time.Now().Add(testTimeout)
+ objectURL := ts.URL + test.ObjectPath
+ object, err := objectstore.NewObject(objectURL, "", map[string]string{}, test.ObjectSize)
+ require.NoError(t, err)
+ _, err = object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
+
+ require.Error(t, err)
+ _, isStatusCodeError := err.(objectstore.StatusCodeError)
+ require.True(t, isStatusCodeError, "Should fail with StatusCodeError")
+ require.Contains(t, err.Error(), "404")
+}
+
+type endlessReader struct{}
+
+func (e *endlessReader) Read(p []byte) (n int, err error) {
+ for i := 0; i < len(p); i++ {
+ p[i] = '*'
+ }
+
+ return len(p), nil
+}
+
+// TestObjectUploadBrokenConnection purpose is to ensure that errors caused by the upload destination get propagated back correctly.
+// This is important for troubleshooting in production.
+func TestObjectUploadBrokenConnection(t *testing.T) {
+ // This test server closes connection immediately
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ hj, ok := w.(http.Hijacker)
+ if !ok {
+ require.FailNow(t, "webserver doesn't support hijacking")
+ }
+ conn, _, err := hj.Hijack()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ conn.Close()
+ }))
+ defer ts.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ deadline := time.Now().Add(testTimeout)
+ objectURL := ts.URL + test.ObjectPath
+ object, err := objectstore.NewObject(objectURL, "", map[string]string{}, -1)
+ require.NoError(t, err)
+
+ _, copyErr := object.Consume(ctx, &endlessReader{}, deadline)
+ require.Error(t, copyErr)
+ require.NotEqual(t, io.ErrClosedPipe, copyErr, "We are shadowing the real error")
+}
diff --git a/workhorse/internal/objectstore/prometheus.go b/workhorse/internal/objectstore/prometheus.go
new file mode 100644
index 00000000000..20762fb52bc
--- /dev/null
+++ b/workhorse/internal/objectstore/prometheus.go
@@ -0,0 +1,39 @@
+package objectstore
+
+import (
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+var (
+ objectStorageUploadRequests = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_object_storage_upload_requests",
+ Help: "How many object storage requests have been processed",
+ },
+ []string{"status"},
+ )
+ objectStorageUploadsOpen = promauto.NewGauge(
+ prometheus.GaugeOpts{
+ Name: "gitlab_workhorse_object_storage_upload_open",
+ Help: "Describes many object storage requests are open now",
+ },
+ )
+ objectStorageUploadBytes = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_object_storage_upload_bytes",
+ Help: "How many bytes were sent to object storage",
+ },
+ )
+ objectStorageUploadTime = promauto.NewHistogram(
+ prometheus.HistogramOpts{
+ Name: "gitlab_workhorse_object_storage_upload_time",
+ Help: "How long it took to upload objects",
+ Buckets: objectStorageUploadTimeBuckets,
+ })
+
+ objectStorageUploadRequestsRequestFailed = objectStorageUploadRequests.WithLabelValues("request-failed")
+ objectStorageUploadRequestsInvalidStatus = objectStorageUploadRequests.WithLabelValues("invalid-status")
+
+ objectStorageUploadTimeBuckets = []float64{.1, .25, .5, 1, 2.5, 5, 10, 25, 50, 100}
+)
diff --git a/workhorse/internal/objectstore/s3_complete_multipart_api.go b/workhorse/internal/objectstore/s3_complete_multipart_api.go
new file mode 100644
index 00000000000..b84f5757f49
--- /dev/null
+++ b/workhorse/internal/objectstore/s3_complete_multipart_api.go
@@ -0,0 +1,51 @@
+package objectstore
+
+import (
+ "encoding/xml"
+ "fmt"
+)
+
+// CompleteMultipartUpload is the S3 CompleteMultipartUpload body
+type CompleteMultipartUpload struct {
+ Part []*completeMultipartUploadPart
+}
+
+type completeMultipartUploadPart struct {
+ PartNumber int
+ ETag string
+}
+
+// CompleteMultipartUploadResult is the S3 answer to CompleteMultipartUpload request
+type CompleteMultipartUploadResult struct {
+ Location string
+ Bucket string
+ Key string
+ ETag string
+}
+
+// CompleteMultipartUploadError is the in-body error structure
+// https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadComplete.html#mpUploadComplete-examples
+// the answer contains other fields we are not using
+type CompleteMultipartUploadError struct {
+ XMLName xml.Name `xml:"Error"`
+ Code string
+ Message string
+}
+
+func (c *CompleteMultipartUploadError) Error() string {
+ return fmt.Sprintf("CompleteMultipartUpload remote error %q: %s", c.Code, c.Message)
+}
+
+// compoundCompleteMultipartUploadResult holds both CompleteMultipartUploadResult and CompleteMultipartUploadError
+// this allow us to deserialize the response body where the root element can either be Error orCompleteMultipartUploadResult
+type compoundCompleteMultipartUploadResult struct {
+ *CompleteMultipartUploadResult
+ *CompleteMultipartUploadError
+
+ // XMLName this overrides CompleteMultipartUploadError.XMLName tags
+ XMLName xml.Name
+}
+
+func (c *compoundCompleteMultipartUploadResult) isError() bool {
+ return c.CompleteMultipartUploadError != nil
+}
diff --git a/workhorse/internal/objectstore/s3_object.go b/workhorse/internal/objectstore/s3_object.go
new file mode 100644
index 00000000000..1f79f88224f
--- /dev/null
+++ b/workhorse/internal/objectstore/s3_object.go
@@ -0,0 +1,119 @@
+package objectstore
+
+import (
+ "context"
+ "io"
+ "time"
+
+ "github.com/aws/aws-sdk-go/aws"
+ "github.com/aws/aws-sdk-go/aws/awserr"
+ "github.com/aws/aws-sdk-go/service/s3"
+ "github.com/aws/aws-sdk-go/service/s3/s3manager"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+)
+
+type S3Object struct {
+ credentials config.S3Credentials
+ config config.S3Config
+ objectName string
+ uploaded bool
+
+ *uploader
+}
+
+func NewS3Object(objectName string, s3Credentials config.S3Credentials, s3Config config.S3Config) (*S3Object, error) {
+ o := &S3Object{
+ credentials: s3Credentials,
+ config: s3Config,
+ objectName: objectName,
+ }
+
+ o.uploader = newUploader(o)
+ return o, nil
+}
+
+func setEncryptionOptions(input *s3manager.UploadInput, s3Config config.S3Config) {
+ if s3Config.ServerSideEncryption != "" {
+ input.ServerSideEncryption = aws.String(s3Config.ServerSideEncryption)
+
+ if s3Config.ServerSideEncryption == s3.ServerSideEncryptionAwsKms && s3Config.SSEKMSKeyID != "" {
+ input.SSEKMSKeyId = aws.String(s3Config.SSEKMSKeyID)
+ }
+ }
+}
+
+func (s *S3Object) Upload(ctx context.Context, r io.Reader) error {
+ sess, err := setupS3Session(s.credentials, s.config)
+ if err != nil {
+ log.WithError(err).Error("error creating S3 session")
+ return err
+ }
+
+ uploader := s3manager.NewUploader(sess)
+
+ input := &s3manager.UploadInput{
+ Bucket: aws.String(s.config.Bucket),
+ Key: aws.String(s.objectName),
+ Body: r,
+ }
+
+ setEncryptionOptions(input, s.config)
+
+ _, err = uploader.UploadWithContext(ctx, input)
+ if err != nil {
+ log.WithError(err).Error("error uploading S3 session")
+ // Get the root cause, such as ErrEntityTooLarge, so we can return the proper HTTP status code
+ return unwrapAWSError(err)
+ }
+
+ s.uploaded = true
+
+ return nil
+}
+
+func (s *S3Object) ETag() string {
+ return ""
+}
+
+func (s *S3Object) Abort() {
+ s.Delete()
+}
+
+func (s *S3Object) Delete() {
+ if !s.uploaded {
+ return
+ }
+
+ session, err := setupS3Session(s.credentials, s.config)
+ if err != nil {
+ log.WithError(err).Error("error setting up S3 session in delete")
+ return
+ }
+
+ svc := s3.New(session)
+ input := &s3.DeleteObjectInput{
+ Bucket: aws.String(s.config.Bucket),
+ Key: aws.String(s.objectName),
+ }
+
+ // Note we can't use the request context because in a successful
+ // case, the original request has already completed.
+ deleteCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // lint:allow context.Background
+ defer cancel()
+
+ _, err = svc.DeleteObjectWithContext(deleteCtx, input)
+ if err != nil {
+ log.WithError(err).Error("error deleting S3 object", err)
+ }
+}
+
+// This is needed until https://github.com/aws/aws-sdk-go/issues/2820 is closed.
+func unwrapAWSError(e error) error {
+ if awsErr, ok := e.(awserr.Error); ok {
+ return unwrapAWSError(awsErr.OrigErr())
+ }
+
+ return e
+}
diff --git a/workhorse/internal/objectstore/s3_object_test.go b/workhorse/internal/objectstore/s3_object_test.go
new file mode 100644
index 00000000000..d9ebbd7f979
--- /dev/null
+++ b/workhorse/internal/objectstore/s3_object_test.go
@@ -0,0 +1,174 @@
+package objectstore_test
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/aws/aws-sdk-go/aws/awserr"
+ "github.com/aws/aws-sdk-go/aws/session"
+ "github.com/aws/aws-sdk-go/service/s3"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+)
+
+type failedReader struct {
+ io.Reader
+}
+
+func (r *failedReader) Read(p []byte) (int, error) {
+ origErr := fmt.Errorf("entity is too large")
+ return 0, awserr.New("Read", "read failed", origErr)
+}
+
+func TestS3ObjectUpload(t *testing.T) {
+ testCases := []struct {
+ encryption string
+ }{
+ {encryption: ""},
+ {encryption: s3.ServerSideEncryptionAes256},
+ {encryption: s3.ServerSideEncryptionAwsKms},
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("encryption=%v", tc.encryption), func(t *testing.T) {
+ creds, config, sess, ts := test.SetupS3(t, tc.encryption)
+ defer ts.Close()
+
+ deadline := time.Now().Add(testTimeout)
+ tmpDir, err := ioutil.TempDir("", "workhorse-test-")
+ require.NoError(t, err)
+ defer os.Remove(tmpDir)
+
+ objectName := filepath.Join(tmpDir, "s3-test-data")
+ ctx, cancel := context.WithCancel(context.Background())
+
+ object, err := objectstore.NewS3Object(objectName, creds, config)
+ require.NoError(t, err)
+
+ // copy data
+ n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
+ require.NoError(t, err)
+ require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch")
+
+ test.S3ObjectExists(t, sess, config, objectName, test.ObjectContent)
+ test.CheckS3Metadata(t, sess, config, objectName)
+
+ cancel()
+
+ testhelper.Retry(t, 5*time.Second, func() error {
+ if test.S3ObjectDoesNotExist(t, sess, config, objectName) {
+ return nil
+ }
+
+ return fmt.Errorf("file is still present")
+ })
+ })
+ }
+}
+
+func TestConcurrentS3ObjectUpload(t *testing.T) {
+ creds, uploadsConfig, uploadsSession, uploadServer := test.SetupS3WithBucket(t, "uploads", "")
+ defer uploadServer.Close()
+
+ // This will return a separate S3 endpoint
+ _, artifactsConfig, artifactsSession, artifactsServer := test.SetupS3WithBucket(t, "artifacts", "")
+ defer artifactsServer.Close()
+
+ deadline := time.Now().Add(testTimeout)
+ tmpDir, err := ioutil.TempDir("", "workhorse-test-")
+ require.NoError(t, err)
+ defer os.Remove(tmpDir)
+
+ var wg sync.WaitGroup
+
+ for i := 0; i < 4; i++ {
+ wg.Add(1)
+
+ go func(index int) {
+ var sess *session.Session
+ var config config.S3Config
+
+ if index%2 == 0 {
+ sess = uploadsSession
+ config = uploadsConfig
+ } else {
+ sess = artifactsSession
+ config = artifactsConfig
+ }
+
+ name := fmt.Sprintf("s3-test-data-%d", index)
+ objectName := filepath.Join(tmpDir, name)
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ object, err := objectstore.NewS3Object(objectName, creds, config)
+ require.NoError(t, err)
+
+ // copy data
+ n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
+ require.NoError(t, err)
+ require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch")
+
+ test.S3ObjectExists(t, sess, config, objectName, test.ObjectContent)
+ wg.Done()
+ }(i)
+ }
+
+ wg.Wait()
+}
+
+func TestS3ObjectUploadCancel(t *testing.T) {
+ creds, config, _, ts := test.SetupS3(t, "")
+ defer ts.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+
+ deadline := time.Now().Add(testTimeout)
+ tmpDir, err := ioutil.TempDir("", "workhorse-test-")
+ require.NoError(t, err)
+ defer os.Remove(tmpDir)
+
+ objectName := filepath.Join(tmpDir, "s3-test-data")
+
+ object, err := objectstore.NewS3Object(objectName, creds, config)
+
+ require.NoError(t, err)
+
+ // Cancel the transfer before the data has been copied to ensure
+ // we handle this gracefully.
+ cancel()
+
+ _, err = object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
+ require.Error(t, err)
+ require.Equal(t, "context canceled", err.Error())
+}
+
+func TestS3ObjectUploadLimitReached(t *testing.T) {
+ creds, config, _, ts := test.SetupS3(t, "")
+ defer ts.Close()
+
+ deadline := time.Now().Add(testTimeout)
+ tmpDir, err := ioutil.TempDir("", "workhorse-test-")
+ require.NoError(t, err)
+ defer os.Remove(tmpDir)
+
+ objectName := filepath.Join(tmpDir, "s3-test-data")
+ object, err := objectstore.NewS3Object(objectName, creds, config)
+ require.NoError(t, err)
+
+ _, err = object.Consume(context.Background(), &failedReader{}, deadline)
+ require.Error(t, err)
+ require.Equal(t, "entity is too large", err.Error())
+}
diff --git a/workhorse/internal/objectstore/s3_session.go b/workhorse/internal/objectstore/s3_session.go
new file mode 100644
index 00000000000..ebc8daf534c
--- /dev/null
+++ b/workhorse/internal/objectstore/s3_session.go
@@ -0,0 +1,94 @@
+package objectstore
+
+import (
+ "sync"
+ "time"
+
+ "github.com/aws/aws-sdk-go/aws"
+ "github.com/aws/aws-sdk-go/aws/credentials"
+ "github.com/aws/aws-sdk-go/aws/session"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+)
+
+type s3Session struct {
+ session *session.Session
+ expiry time.Time
+}
+
+type s3SessionCache struct {
+ // An S3 session is cached by its input configuration (e.g. region,
+ // endpoint, path style, etc.), but the bucket is actually
+ // determined by the type of object to be uploaded (e.g. CI
+ // artifact, LFS, etc.) during runtime. In practice, we should only
+ // need one session per Workhorse process if we only allow one
+ // configuration for many different buckets. However, using a map
+ // indexed by the config avoids potential pitfalls in case the
+ // bucket configuration is supplied at startup or we need to support
+ // multiple S3 endpoints.
+ sessions map[config.S3Config]*s3Session
+ sync.Mutex
+}
+
+func (s *s3Session) isExpired() bool {
+ return time.Now().After(s.expiry)
+}
+
+func newS3SessionCache() *s3SessionCache {
+ return &s3SessionCache{sessions: make(map[config.S3Config]*s3Session)}
+}
+
+var (
+ // By default, it looks like IAM instance profiles may last 6 hours
+ // (via curl http://169.254.169.254/latest/meta-data/iam/security-credentials/<role_name>),
+ // but this may be configurable from anywhere for 15 minutes to 12
+ // hours. To be safe, refresh AWS sessions every 10 minutes.
+ sessionExpiration = time.Duration(10 * time.Minute)
+ sessionCache = newS3SessionCache()
+)
+
+// SetupS3Session initializes a new AWS S3 session and refreshes one if
+// necessary. As recommended in https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/sessions.html,
+// sessions should be cached when possible. Sessions are safe to use
+// concurrently as long as the session isn't modified.
+func setupS3Session(s3Credentials config.S3Credentials, s3Config config.S3Config) (*session.Session, error) {
+ sessionCache.Lock()
+ defer sessionCache.Unlock()
+
+ if s, ok := sessionCache.sessions[s3Config]; ok && !s.isExpired() {
+ return s.session, nil
+ }
+
+ cfg := &aws.Config{
+ Region: aws.String(s3Config.Region),
+ S3ForcePathStyle: aws.Bool(s3Config.PathStyle),
+ }
+
+ // In case IAM profiles aren't being used, use the static credentials
+ if s3Credentials.AwsAccessKeyID != "" && s3Credentials.AwsSecretAccessKey != "" {
+ cfg.Credentials = credentials.NewStaticCredentials(s3Credentials.AwsAccessKeyID, s3Credentials.AwsSecretAccessKey, "")
+ }
+
+ if s3Config.Endpoint != "" {
+ cfg.Endpoint = aws.String(s3Config.Endpoint)
+ }
+
+ sess, err := session.NewSession(cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ sessionCache.sessions[s3Config] = &s3Session{
+ expiry: time.Now().Add(sessionExpiration),
+ session: sess,
+ }
+
+ return sess, nil
+}
+
+func ResetS3Session(s3Config config.S3Config) {
+ sessionCache.Lock()
+ defer sessionCache.Unlock()
+
+ delete(sessionCache.sessions, s3Config)
+}
diff --git a/workhorse/internal/objectstore/s3_session_test.go b/workhorse/internal/objectstore/s3_session_test.go
new file mode 100644
index 00000000000..8601f305917
--- /dev/null
+++ b/workhorse/internal/objectstore/s3_session_test.go
@@ -0,0 +1,57 @@
+package objectstore
+
+import (
+ "testing"
+ "time"
+
+ "github.com/aws/aws-sdk-go/aws"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+)
+
+func TestS3SessionSetup(t *testing.T) {
+ credentials := config.S3Credentials{}
+ cfg := config.S3Config{Region: "us-west-1", PathStyle: true}
+
+ sess, err := setupS3Session(credentials, cfg)
+ require.NoError(t, err)
+
+ require.Equal(t, aws.StringValue(sess.Config.Region), "us-west-1")
+ require.True(t, aws.BoolValue(sess.Config.S3ForcePathStyle))
+
+ require.Equal(t, len(sessionCache.sessions), 1)
+ anotherConfig := cfg
+ _, err = setupS3Session(credentials, anotherConfig)
+ require.NoError(t, err)
+ require.Equal(t, len(sessionCache.sessions), 1)
+
+ ResetS3Session(cfg)
+}
+
+func TestS3SessionExpiry(t *testing.T) {
+ credentials := config.S3Credentials{}
+ cfg := config.S3Config{Region: "us-west-1", PathStyle: true}
+
+ sess, err := setupS3Session(credentials, cfg)
+ require.NoError(t, err)
+
+ require.Equal(t, aws.StringValue(sess.Config.Region), "us-west-1")
+ require.True(t, aws.BoolValue(sess.Config.S3ForcePathStyle))
+
+ firstSession, ok := sessionCache.sessions[cfg]
+ require.True(t, ok)
+ require.False(t, firstSession.isExpired())
+
+ firstSession.expiry = time.Now().Add(-1 * time.Second)
+ require.True(t, firstSession.isExpired())
+
+ _, err = setupS3Session(credentials, cfg)
+ require.NoError(t, err)
+
+ nextSession, ok := sessionCache.sessions[cfg]
+ require.True(t, ok)
+ require.False(t, nextSession.isExpired())
+
+ ResetS3Session(cfg)
+}
diff --git a/workhorse/internal/objectstore/test/consts.go b/workhorse/internal/objectstore/test/consts.go
new file mode 100644
index 00000000000..7a1bcc28d45
--- /dev/null
+++ b/workhorse/internal/objectstore/test/consts.go
@@ -0,0 +1,19 @@
+package test
+
+// Some useful const for testing purpose
+const (
+ // ObjectContent an example textual content
+ ObjectContent = "TEST OBJECT CONTENT"
+ // ObjectSize is the ObjectContent size
+ ObjectSize = int64(len(ObjectContent))
+ // Objectpath is an example remote object path (including bucket name)
+ ObjectPath = "/bucket/object"
+ // ObjectMD5 is ObjectContent MD5 hash
+ ObjectMD5 = "42d000eea026ee0760677e506189cb33"
+ // ObjectSHA1 is ObjectContent SHA1 hash
+ ObjectSHA1 = "173cfd58c6b60cb910f68a26cbb77e3fc5017a6d"
+ // ObjectSHA256 is ObjectContent SHA256 hash
+ ObjectSHA256 = "b0257e9e657ef19b15eed4fbba975bd5238d651977564035ef91cb45693647aa"
+ // ObjectSHA512 is ObjectContent SHA512 hash
+ ObjectSHA512 = "51af8197db2047f7894652daa7437927bf831d5aa63f1b0b7277c4800b06f5e3057251f0e4c2d344ca8c2daf1ffc08a28dd3b2f5fe0e316d3fd6c3af58c34b97"
+)
diff --git a/workhorse/internal/objectstore/test/gocloud_stub.go b/workhorse/internal/objectstore/test/gocloud_stub.go
new file mode 100644
index 00000000000..cf22075e407
--- /dev/null
+++ b/workhorse/internal/objectstore/test/gocloud_stub.go
@@ -0,0 +1,47 @@
+package test
+
+import (
+ "context"
+ "io/ioutil"
+ "net/url"
+ "os"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gocloud.dev/blob"
+ "gocloud.dev/blob/fileblob"
+)
+
+type dirOpener struct {
+ tmpDir string
+}
+
+func (o *dirOpener) OpenBucketURL(ctx context.Context, u *url.URL) (*blob.Bucket, error) {
+ return fileblob.OpenBucket(o.tmpDir, nil)
+}
+
+func SetupGoCloudFileBucket(t *testing.T, scheme string) (m *blob.URLMux, bucketDir string, cleanup func()) {
+ tmpDir, err := ioutil.TempDir("", "")
+ require.NoError(t, err)
+
+ mux := new(blob.URLMux)
+ fake := &dirOpener{tmpDir: tmpDir}
+ mux.RegisterBucket(scheme, fake)
+ cleanup = func() {
+ os.RemoveAll(tmpDir)
+ }
+
+ return mux, tmpDir, cleanup
+}
+
+func GoCloudObjectExists(t *testing.T, bucketDir string, objectName string) {
+ bucket, err := fileblob.OpenBucket(bucketDir, nil)
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithCancel(context.Background()) // lint:allow context.Background
+ defer cancel()
+
+ exists, err := bucket.Exists(ctx, objectName)
+ require.NoError(t, err)
+ require.True(t, exists)
+}
diff --git a/workhorse/internal/objectstore/test/objectstore_stub.go b/workhorse/internal/objectstore/test/objectstore_stub.go
new file mode 100644
index 00000000000..31ef4913305
--- /dev/null
+++ b/workhorse/internal/objectstore/test/objectstore_stub.go
@@ -0,0 +1,278 @@
+package test
+
+import (
+ "crypto/md5"
+ "encoding/hex"
+ "encoding/xml"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "strings"
+ "sync"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore"
+)
+
+type partsEtagMap map[int]string
+
+// ObjectstoreStub is a testing implementation of ObjectStore.
+// Instead of storing objects it will just save md5sum.
+type ObjectstoreStub struct {
+ // bucket contains md5sum of uploaded objects
+ bucket map[string]string
+ // overwriteMD5 contains overwrites for md5sum that should be return instead of the regular hash
+ overwriteMD5 map[string]string
+ // multipart is a map of MultipartUploads
+ multipart map[string]partsEtagMap
+ // HTTP header sent along request
+ headers map[string]*http.Header
+
+ puts int
+ deletes int
+
+ m sync.Mutex
+}
+
+// StartObjectStore will start an ObjectStore stub
+func StartObjectStore() (*ObjectstoreStub, *httptest.Server) {
+ return StartObjectStoreWithCustomMD5(make(map[string]string))
+}
+
+// StartObjectStoreWithCustomMD5 will start an ObjectStore stub: md5Hashes contains overwrites for md5sum that should be return on PutObject
+func StartObjectStoreWithCustomMD5(md5Hashes map[string]string) (*ObjectstoreStub, *httptest.Server) {
+ os := &ObjectstoreStub{
+ bucket: make(map[string]string),
+ multipart: make(map[string]partsEtagMap),
+ overwriteMD5: make(map[string]string),
+ headers: make(map[string]*http.Header),
+ }
+
+ for k, v := range md5Hashes {
+ os.overwriteMD5[k] = v
+ }
+
+ return os, httptest.NewServer(os)
+}
+
+// PutsCnt counts PutObject invocations
+func (o *ObjectstoreStub) PutsCnt() int {
+ o.m.Lock()
+ defer o.m.Unlock()
+
+ return o.puts
+}
+
+// DeletesCnt counts DeleteObject invocation of a valid object
+func (o *ObjectstoreStub) DeletesCnt() int {
+ o.m.Lock()
+ defer o.m.Unlock()
+
+ return o.deletes
+}
+
+// GetObjectMD5 return the calculated MD5 of the object uploaded to path
+// it will return an empty string if no object has been uploaded on such path
+func (o *ObjectstoreStub) GetObjectMD5(path string) string {
+ o.m.Lock()
+ defer o.m.Unlock()
+
+ return o.bucket[path]
+}
+
+// GetHeader returns a given HTTP header of the object uploaded to the path
+func (o *ObjectstoreStub) GetHeader(path, key string) string {
+ o.m.Lock()
+ defer o.m.Unlock()
+
+ if val, ok := o.headers[path]; ok {
+ return val.Get(key)
+ }
+
+ return ""
+}
+
+// InitiateMultipartUpload prepare the ObjectstoreStob to receive a MultipartUpload on path
+// It will return an error if a MultipartUpload is already in progress on that path
+// InitiateMultipartUpload is only used during test setup.
+// Workhorse's production code does not know how to initiate a multipart upload.
+//
+// Real S3 multipart uploads are more complicated than what we do here,
+// but this is enough to verify that workhorse's production code behaves as intended.
+func (o *ObjectstoreStub) InitiateMultipartUpload(path string) error {
+ o.m.Lock()
+ defer o.m.Unlock()
+
+ if o.multipart[path] != nil {
+ return fmt.Errorf("MultipartUpload for %q already in progress", path)
+ }
+
+ o.multipart[path] = make(partsEtagMap)
+ return nil
+}
+
+// IsMultipartUpload check if the given path has a MultipartUpload in progress
+func (o *ObjectstoreStub) IsMultipartUpload(path string) bool {
+ o.m.Lock()
+ defer o.m.Unlock()
+
+ return o.isMultipartUpload(path)
+}
+
+// isMultipartUpload is the lock free version of IsMultipartUpload
+func (o *ObjectstoreStub) isMultipartUpload(path string) bool {
+ return o.multipart[path] != nil
+}
+
+func (o *ObjectstoreStub) removeObject(w http.ResponseWriter, r *http.Request) {
+ o.m.Lock()
+ defer o.m.Unlock()
+
+ objectPath := r.URL.Path
+ if o.isMultipartUpload(objectPath) {
+ o.deletes++
+ delete(o.multipart, objectPath)
+
+ w.WriteHeader(200)
+ } else if _, ok := o.bucket[objectPath]; ok {
+ o.deletes++
+ delete(o.bucket, objectPath)
+
+ w.WriteHeader(200)
+ } else {
+ w.WriteHeader(404)
+ }
+}
+
+func (o *ObjectstoreStub) putObject(w http.ResponseWriter, r *http.Request) {
+ o.m.Lock()
+ defer o.m.Unlock()
+
+ objectPath := r.URL.Path
+
+ etag, overwritten := o.overwriteMD5[objectPath]
+ if !overwritten {
+ hasher := md5.New()
+ io.Copy(hasher, r.Body)
+
+ checksum := hasher.Sum(nil)
+ etag = hex.EncodeToString(checksum)
+ }
+
+ o.headers[objectPath] = &r.Header
+ o.puts++
+ if o.isMultipartUpload(objectPath) {
+ pNumber := r.URL.Query().Get("partNumber")
+ idx, err := strconv.Atoi(pNumber)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("malformed partNumber: %v", err), 400)
+ return
+ }
+
+ o.multipart[objectPath][idx] = etag
+ } else {
+ o.bucket[objectPath] = etag
+ }
+
+ w.Header().Set("ETag", etag)
+ w.WriteHeader(200)
+}
+
+func MultipartUploadInternalError() *objectstore.CompleteMultipartUploadError {
+ return &objectstore.CompleteMultipartUploadError{Code: "InternalError", Message: "malformed object path"}
+}
+
+func (o *ObjectstoreStub) completeMultipartUpload(w http.ResponseWriter, r *http.Request) {
+ o.m.Lock()
+ defer o.m.Unlock()
+
+ objectPath := r.URL.Path
+
+ multipart := o.multipart[objectPath]
+ if multipart == nil {
+ http.Error(w, "Unknown MultipartUpload", 404)
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ http.Error(w, err.Error(), 500)
+ return
+ }
+
+ var msg objectstore.CompleteMultipartUpload
+ err = xml.Unmarshal(buf, &msg)
+ if err != nil {
+ http.Error(w, err.Error(), 400)
+ return
+ }
+
+ for _, part := range msg.Part {
+ etag := multipart[part.PartNumber]
+ if etag != part.ETag {
+ msg := fmt.Sprintf("ETag mismatch on part %d. Expected %q got %q", part.PartNumber, etag, part.ETag)
+ http.Error(w, msg, 400)
+ return
+ }
+ }
+
+ etag, overwritten := o.overwriteMD5[objectPath]
+ if !overwritten {
+ etag = "CompleteMultipartUploadETag"
+ }
+
+ o.bucket[objectPath] = etag
+ delete(o.multipart, objectPath)
+
+ w.Header().Set("ETag", etag)
+ split := strings.SplitN(objectPath[1:], "/", 2)
+ if len(split) < 2 {
+ encodeXMLAnswer(w, MultipartUploadInternalError())
+ return
+ }
+
+ bucket := split[0]
+ key := split[1]
+ answer := objectstore.CompleteMultipartUploadResult{
+ Location: r.URL.String(),
+ Bucket: bucket,
+ Key: key,
+ ETag: etag,
+ }
+ encodeXMLAnswer(w, answer)
+}
+
+func encodeXMLAnswer(w http.ResponseWriter, answer interface{}) {
+ w.Header().Set("Content-Type", "text/xml")
+
+ enc := xml.NewEncoder(w)
+ if err := enc.Encode(answer); err != nil {
+ http.Error(w, err.Error(), 500)
+ }
+}
+
+func (o *ObjectstoreStub) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ if r.Body != nil {
+ defer r.Body.Close()
+ }
+
+ fmt.Println("ObjectStore Stub:", r.Method, r.URL.String())
+
+ if r.URL.Path == "" {
+ http.Error(w, "No path provided", 404)
+ return
+ }
+
+ switch r.Method {
+ case "DELETE":
+ o.removeObject(w, r)
+ case "PUT":
+ o.putObject(w, r)
+ case "POST":
+ o.completeMultipartUpload(w, r)
+ default:
+ w.WriteHeader(404)
+ }
+}
diff --git a/workhorse/internal/objectstore/test/objectstore_stub_test.go b/workhorse/internal/objectstore/test/objectstore_stub_test.go
new file mode 100644
index 00000000000..8c0d52a2d79
--- /dev/null
+++ b/workhorse/internal/objectstore/test/objectstore_stub_test.go
@@ -0,0 +1,167 @@
+package test
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func doRequest(method, url string, body io.Reader) error {
+ req, err := http.NewRequest(method, url, body)
+ if err != nil {
+ return err
+ }
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return err
+ }
+
+ return resp.Body.Close()
+}
+
+func TestObjectStoreStub(t *testing.T) {
+ stub, ts := StartObjectStore()
+ defer ts.Close()
+
+ require.Equal(t, 0, stub.PutsCnt())
+ require.Equal(t, 0, stub.DeletesCnt())
+
+ objectURL := ts.URL + ObjectPath
+
+ require.NoError(t, doRequest(http.MethodPut, objectURL, strings.NewReader(ObjectContent)))
+
+ require.Equal(t, 1, stub.PutsCnt())
+ require.Equal(t, 0, stub.DeletesCnt())
+ require.Equal(t, ObjectMD5, stub.GetObjectMD5(ObjectPath))
+
+ require.NoError(t, doRequest(http.MethodDelete, objectURL, nil))
+
+ require.Equal(t, 1, stub.PutsCnt())
+ require.Equal(t, 1, stub.DeletesCnt())
+}
+
+func TestObjectStoreStubDelete404(t *testing.T) {
+ stub, ts := StartObjectStore()
+ defer ts.Close()
+
+ objectURL := ts.URL + ObjectPath
+
+ req, err := http.NewRequest(http.MethodDelete, objectURL, nil)
+ require.NoError(t, err)
+
+ resp, err := http.DefaultClient.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+ require.Equal(t, 404, resp.StatusCode)
+
+ require.Equal(t, 0, stub.DeletesCnt())
+}
+
+func TestObjectStoreInitiateMultipartUpload(t *testing.T) {
+ stub, ts := StartObjectStore()
+ defer ts.Close()
+
+ path := "/my-multipart"
+ err := stub.InitiateMultipartUpload(path)
+ require.NoError(t, err)
+
+ err = stub.InitiateMultipartUpload(path)
+ require.Error(t, err, "second attempt to open the same MultipartUpload")
+}
+
+func TestObjectStoreCompleteMultipartUpload(t *testing.T) {
+ stub, ts := StartObjectStore()
+ defer ts.Close()
+
+ objectURL := ts.URL + ObjectPath
+ parts := []struct {
+ number int
+ content string
+ contentMD5 string
+ }{
+ {
+ number: 1,
+ content: "first part",
+ contentMD5: "550cf6b6e60f65a0e3104a26e70fea42",
+ }, {
+ number: 2,
+ content: "second part",
+ contentMD5: "920b914bca0a70780b40881b8f376135",
+ },
+ }
+
+ stub.InitiateMultipartUpload(ObjectPath)
+
+ require.True(t, stub.IsMultipartUpload(ObjectPath))
+ require.Equal(t, 0, stub.PutsCnt())
+ require.Equal(t, 0, stub.DeletesCnt())
+
+ // Workhorse knows nothing about S3 MultipartUpload, it receives some URLs
+ // from GitLab-rails and PUTs chunk of data to each of them.
+ // Then it completes the upload with a final POST
+ partPutURLs := []string{
+ fmt.Sprintf("%s?partNumber=%d", objectURL, 1),
+ fmt.Sprintf("%s?partNumber=%d", objectURL, 2),
+ }
+ completePostURL := objectURL
+
+ for i, partPutURL := range partPutURLs {
+ part := parts[i]
+
+ require.NoError(t, doRequest(http.MethodPut, partPutURL, strings.NewReader(part.content)))
+
+ require.Equal(t, i+1, stub.PutsCnt())
+ require.Equal(t, 0, stub.DeletesCnt())
+ require.Equal(t, part.contentMD5, stub.multipart[ObjectPath][part.number], "Part %d was not uploaded into ObjectStorage", part.number)
+ require.Empty(t, stub.GetObjectMD5(ObjectPath), "Part %d was mistakenly uploaded as a single object", part.number)
+ require.True(t, stub.IsMultipartUpload(ObjectPath), "MultipartUpload completed or aborted")
+ }
+
+ completeBody := fmt.Sprintf(`<CompleteMultipartUpload>
+ <Part>
+ <PartNumber>1</PartNumber>
+ <ETag>%s</ETag>
+ </Part>
+ <Part>
+ <PartNumber>2</PartNumber>
+ <ETag>%s</ETag>
+ </Part>
+ </CompleteMultipartUpload>`, parts[0].contentMD5, parts[1].contentMD5)
+ require.NoError(t, doRequest(http.MethodPost, completePostURL, strings.NewReader(completeBody)))
+
+ require.Equal(t, len(parts), stub.PutsCnt())
+ require.Equal(t, 0, stub.DeletesCnt())
+ require.False(t, stub.IsMultipartUpload(ObjectPath), "MultipartUpload is still in progress")
+}
+
+func TestObjectStoreAbortMultipartUpload(t *testing.T) {
+ stub, ts := StartObjectStore()
+ defer ts.Close()
+
+ stub.InitiateMultipartUpload(ObjectPath)
+
+ require.True(t, stub.IsMultipartUpload(ObjectPath))
+ require.Equal(t, 0, stub.PutsCnt())
+ require.Equal(t, 0, stub.DeletesCnt())
+
+ objectURL := ts.URL + ObjectPath
+ require.NoError(t, doRequest(http.MethodPut, fmt.Sprintf("%s?partNumber=%d", objectURL, 1), strings.NewReader(ObjectContent)))
+
+ require.Equal(t, 1, stub.PutsCnt())
+ require.Equal(t, 0, stub.DeletesCnt())
+ require.Equal(t, ObjectMD5, stub.multipart[ObjectPath][1], "Part was not uploaded into ObjectStorage")
+ require.Empty(t, stub.GetObjectMD5(ObjectPath), "Part was mistakenly uploaded as a single object")
+ require.True(t, stub.IsMultipartUpload(ObjectPath), "MultipartUpload completed or aborted")
+
+ require.NoError(t, doRequest(http.MethodDelete, objectURL, nil))
+
+ require.Equal(t, 1, stub.PutsCnt())
+ require.Equal(t, 1, stub.DeletesCnt())
+ require.Empty(t, stub.GetObjectMD5(ObjectPath), "MultiUpload has been completed")
+ require.False(t, stub.IsMultipartUpload(ObjectPath), "MultiUpload is still in progress")
+}
diff --git a/workhorse/internal/objectstore/test/s3_stub.go b/workhorse/internal/objectstore/test/s3_stub.go
new file mode 100644
index 00000000000..36514b3b887
--- /dev/null
+++ b/workhorse/internal/objectstore/test/s3_stub.go
@@ -0,0 +1,142 @@
+package test
+
+import (
+ "io/ioutil"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+
+ "github.com/aws/aws-sdk-go/aws"
+ "github.com/aws/aws-sdk-go/aws/credentials"
+ "github.com/aws/aws-sdk-go/aws/session"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+
+ "github.com/aws/aws-sdk-go/service/s3"
+ "github.com/aws/aws-sdk-go/service/s3/s3manager"
+
+ "github.com/johannesboyne/gofakes3"
+ "github.com/johannesboyne/gofakes3/backend/s3mem"
+)
+
+func SetupS3(t *testing.T, encryption string) (config.S3Credentials, config.S3Config, *session.Session, *httptest.Server) {
+ return SetupS3WithBucket(t, "test-bucket", encryption)
+}
+
+func SetupS3WithBucket(t *testing.T, bucket string, encryption string) (config.S3Credentials, config.S3Config, *session.Session, *httptest.Server) {
+ backend := s3mem.New()
+ faker := gofakes3.New(backend)
+ ts := httptest.NewServer(faker.Server())
+
+ creds := config.S3Credentials{
+ AwsAccessKeyID: "YOUR-ACCESSKEYID",
+ AwsSecretAccessKey: "YOUR-SECRETACCESSKEY",
+ }
+
+ config := config.S3Config{
+ Bucket: bucket,
+ Endpoint: ts.URL,
+ Region: "eu-central-1",
+ PathStyle: true,
+ }
+
+ if encryption != "" {
+ config.ServerSideEncryption = encryption
+
+ if encryption == s3.ServerSideEncryptionAwsKms {
+ config.SSEKMSKeyID = "arn:aws:1234"
+ }
+ }
+
+ sess, err := session.NewSession(&aws.Config{
+ Credentials: credentials.NewStaticCredentials(creds.AwsAccessKeyID, creds.AwsSecretAccessKey, ""),
+ Endpoint: aws.String(ts.URL),
+ Region: aws.String(config.Region),
+ DisableSSL: aws.Bool(true),
+ S3ForcePathStyle: aws.Bool(true),
+ })
+ require.NoError(t, err)
+
+ // Create S3 service client
+ svc := s3.New(sess)
+
+ _, err = svc.CreateBucket(&s3.CreateBucketInput{
+ Bucket: aws.String(bucket),
+ })
+ require.NoError(t, err)
+
+ return creds, config, sess, ts
+}
+
+// S3ObjectExists will fail the test if the file does not exist.
+func S3ObjectExists(t *testing.T, sess *session.Session, config config.S3Config, objectName string, expectedBytes string) {
+ downloadObject(t, sess, config, objectName, func(tmpfile *os.File, numBytes int64, err error) {
+ require.NoError(t, err)
+ require.Equal(t, int64(len(expectedBytes)), numBytes)
+
+ output, err := ioutil.ReadFile(tmpfile.Name())
+ require.NoError(t, err)
+
+ require.Equal(t, []byte(expectedBytes), output)
+ })
+}
+
+func CheckS3Metadata(t *testing.T, sess *session.Session, config config.S3Config, objectName string) {
+ // In a real S3 provider, s3crypto.NewDecryptionClient should probably be used
+ svc := s3.New(sess)
+ result, err := svc.GetObject(&s3.GetObjectInput{
+ Bucket: aws.String(config.Bucket),
+ Key: aws.String(objectName),
+ })
+ require.NoError(t, err)
+
+ if config.ServerSideEncryption != "" {
+ require.Equal(t, aws.String(config.ServerSideEncryption), result.ServerSideEncryption)
+
+ if config.ServerSideEncryption == s3.ServerSideEncryptionAwsKms {
+ require.Equal(t, aws.String(config.SSEKMSKeyID), result.SSEKMSKeyId)
+ } else {
+ require.Nil(t, result.SSEKMSKeyId)
+ }
+ } else {
+ require.Nil(t, result.ServerSideEncryption)
+ require.Nil(t, result.SSEKMSKeyId)
+ }
+}
+
+// S3ObjectDoesNotExist returns true if the object has been deleted,
+// false otherwise. The return signature is different from
+// S3ObjectExists because deletion may need to be retried since deferred
+// clean up callsinternal/objectstore/test/s3_stub.go may cause the actual deletion to happen after the
+// initial check.
+func S3ObjectDoesNotExist(t *testing.T, sess *session.Session, config config.S3Config, objectName string) bool {
+ deleted := false
+
+ downloadObject(t, sess, config, objectName, func(tmpfile *os.File, numBytes int64, err error) {
+ if err != nil && strings.Contains(err.Error(), "NoSuchKey") {
+ deleted = true
+ }
+ })
+
+ return deleted
+}
+
+func downloadObject(t *testing.T, sess *session.Session, config config.S3Config, objectName string, handler func(tmpfile *os.File, numBytes int64, err error)) {
+ tmpDir, err := ioutil.TempDir("", "workhorse-test-")
+ require.NoError(t, err)
+ defer os.Remove(tmpDir)
+
+ tmpfile, err := ioutil.TempFile(tmpDir, "s3-output")
+ require.NoError(t, err)
+ defer os.Remove(tmpfile.Name())
+
+ downloadSvc := s3manager.NewDownloader(sess)
+ numBytes, err := downloadSvc.Download(tmpfile, &s3.GetObjectInput{
+ Bucket: aws.String(config.Bucket),
+ Key: aws.String(objectName),
+ })
+
+ handler(tmpfile, numBytes, err)
+}
diff --git a/workhorse/internal/objectstore/upload_strategy.go b/workhorse/internal/objectstore/upload_strategy.go
new file mode 100644
index 00000000000..5707ba5f24e
--- /dev/null
+++ b/workhorse/internal/objectstore/upload_strategy.go
@@ -0,0 +1,46 @@
+package objectstore
+
+import (
+ "context"
+ "io"
+ "net/http"
+
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/mask"
+)
+
+type uploadStrategy interface {
+ Upload(ctx context.Context, r io.Reader) error
+ ETag() string
+ Abort()
+ Delete()
+}
+
+func deleteURL(url string) {
+ if url == "" {
+ return
+ }
+
+ req, err := http.NewRequest("DELETE", url, nil)
+ if err != nil {
+ log.WithError(err).WithField("object", mask.URL(url)).Warning("Delete failed")
+ return
+ }
+ // TODO: consider adding the context to the outgoing request for better instrumentation
+
+ // here we are not using u.ctx because we must perform cleanup regardless of parent context
+ resp, err := httpClient.Do(req)
+ if err != nil {
+ log.WithError(err).WithField("object", mask.URL(url)).Warning("Delete failed")
+ return
+ }
+ resp.Body.Close()
+}
+
+func extractETag(rawETag string) string {
+ if rawETag != "" && rawETag[0] == '"' {
+ rawETag = rawETag[1 : len(rawETag)-1]
+ }
+
+ return rawETag
+}
diff --git a/workhorse/internal/objectstore/uploader.go b/workhorse/internal/objectstore/uploader.go
new file mode 100644
index 00000000000..aedfbe55ead
--- /dev/null
+++ b/workhorse/internal/objectstore/uploader.go
@@ -0,0 +1,115 @@
+package objectstore
+
+import (
+ "context"
+ "crypto/md5"
+ "encoding/hex"
+ "fmt"
+ "hash"
+ "io"
+ "strings"
+ "time"
+
+ "gitlab.com/gitlab-org/labkit/log"
+)
+
+// uploader consumes an io.Reader and uploads it using a pluggable uploadStrategy.
+type uploader struct {
+ strategy uploadStrategy
+
+ // In the case of S3 uploads, we have a multipart upload which
+ // instantiates uploads for the individual parts. We don't want to
+ // increment metrics for the individual parts, so that is why we have
+ // this boolean flag.
+ metrics bool
+
+ // With S3 we compare the MD5 of the data we sent with the ETag returned
+ // by the object storage server.
+ checkETag bool
+}
+
+func newUploader(strategy uploadStrategy) *uploader {
+ return &uploader{strategy: strategy, metrics: true}
+}
+
+func newETagCheckUploader(strategy uploadStrategy, metrics bool) *uploader {
+ return &uploader{strategy: strategy, metrics: metrics, checkETag: true}
+}
+
+func hexString(h hash.Hash) string { return hex.EncodeToString(h.Sum(nil)) }
+
+// Consume reads the reader until it reaches EOF or an error. It spawns a
+// goroutine that waits for outerCtx to be done, after which the remote
+// file is deleted. The deadline applies to the upload performed inside
+// Consume, not to outerCtx.
+func (u *uploader) Consume(outerCtx context.Context, reader io.Reader, deadline time.Time) (_ int64, err error) {
+ if u.metrics {
+ objectStorageUploadsOpen.Inc()
+ defer func(started time.Time) {
+ objectStorageUploadsOpen.Dec()
+ objectStorageUploadTime.Observe(time.Since(started).Seconds())
+ if err != nil {
+ objectStorageUploadRequestsRequestFailed.Inc()
+ }
+ }(time.Now())
+ }
+
+ defer func() {
+ // We do this mainly to abort S3 multipart uploads: it is not enough to
+ // "delete" them.
+ if err != nil {
+ u.strategy.Abort()
+ }
+ }()
+
+ go func() {
+ // Once gitlab-rails is done handling the request, we are supposed to
+ // delete the upload from its temporary location.
+ <-outerCtx.Done()
+ u.strategy.Delete()
+ }()
+
+ uploadCtx, cancelFn := context.WithDeadline(outerCtx, deadline)
+ defer cancelFn()
+
+ var hasher hash.Hash
+ if u.checkETag {
+ hasher = md5.New()
+ reader = io.TeeReader(reader, hasher)
+ }
+
+ cr := &countReader{r: reader}
+ if err := u.strategy.Upload(uploadCtx, cr); err != nil {
+ return cr.n, err
+ }
+
+ if u.checkETag {
+ if err := compareMD5(hexString(hasher), u.strategy.ETag()); err != nil {
+ log.ContextLogger(uploadCtx).WithError(err).Error("error comparing MD5 checksum")
+ return cr.n, err
+ }
+ }
+
+ objectStorageUploadBytes.Add(float64(cr.n))
+
+ return cr.n, nil
+}
+
+func compareMD5(local, remote string) error {
+ if !strings.EqualFold(local, remote) {
+ return fmt.Errorf("ETag mismatch. expected %q got %q", local, remote)
+ }
+
+ return nil
+}
+
+type countReader struct {
+ r io.Reader
+ n int64
+}
+
+func (cr *countReader) Read(p []byte) (int, error) {
+ nRead, err := cr.r.Read(p)
+ cr.n += int64(nRead)
+ return nRead, err
+}
diff --git a/workhorse/internal/proxy/proxy.go b/workhorse/internal/proxy/proxy.go
new file mode 100644
index 00000000000..1bc417a841f
--- /dev/null
+++ b/workhorse/internal/proxy/proxy.go
@@ -0,0 +1,62 @@
+package proxy
+
+import (
+ "fmt"
+ "net/http"
+ "net/http/httputil"
+ "net/url"
+ "time"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+var (
+ defaultTarget = helper.URLMustParse("http://localhost")
+)
+
+type Proxy struct {
+ Version string
+ reverseProxy *httputil.ReverseProxy
+ AllowResponseBuffering bool
+}
+
+func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper) *Proxy {
+ p := Proxy{Version: version, AllowResponseBuffering: true}
+
+ if myURL == nil {
+ myURL = defaultTarget
+ }
+
+ u := *myURL // Make a copy of p.URL
+ u.Path = ""
+ p.reverseProxy = httputil.NewSingleHostReverseProxy(&u)
+ p.reverseProxy.Transport = roundTripper
+ return &p
+}
+
+func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ // Clone request
+ req := *r
+ req.Header = helper.HeaderClone(r.Header)
+
+ // Set Workhorse version
+ req.Header.Set("Gitlab-Workhorse", p.Version)
+ req.Header.Set("Gitlab-Workhorse-Proxy-Start", fmt.Sprintf("%d", time.Now().UnixNano()))
+
+ if p.AllowResponseBuffering {
+ helper.AllowResponseBuffering(w)
+ }
+
+ // If the ultimate client disconnects when the response isn't fully written
+ // to them yet, httputil.ReverseProxy panics with a net/http.ErrAbortHandler
+ // error. We can catch and discard this to keep the error log clean
+ defer func() {
+ if err := recover(); err != nil {
+ if err != http.ErrAbortHandler {
+ panic(err)
+ }
+ }
+ }()
+
+ p.reverseProxy.ServeHTTP(w, &req)
+}
diff --git a/workhorse/internal/queueing/queue.go b/workhorse/internal/queueing/queue.go
new file mode 100644
index 00000000000..db082cf19c6
--- /dev/null
+++ b/workhorse/internal/queueing/queue.go
@@ -0,0 +1,201 @@
+package queueing
+
+import (
+ "errors"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+type errTooManyRequests struct{ error }
+type errQueueingTimedout struct{ error }
+
+var ErrTooManyRequests = &errTooManyRequests{errors.New("too many requests queued")}
+var ErrQueueingTimedout = &errQueueingTimedout{errors.New("queueing timedout")}
+
+type queueMetrics struct {
+ queueingLimit prometheus.Gauge
+ queueingQueueLimit prometheus.Gauge
+ queueingQueueTimeout prometheus.Gauge
+ queueingBusy prometheus.Gauge
+ queueingWaiting prometheus.Gauge
+ queueingWaitingTime prometheus.Histogram
+ queueingErrors *prometheus.CounterVec
+}
+
+// newQueueMetrics prepares Prometheus metrics for queueing mechanism
+// name specifies name of the queue, used to label metrics with ConstLabel `queue_name`
+// Don't call newQueueMetrics twice with the same name argument!
+// timeout specifies the timeout of storing a request in queue - queueMetrics
+// uses it to calculate histogram buckets for gitlab_workhorse_queueing_waiting_time
+// metric
+func newQueueMetrics(name string, timeout time.Duration) *queueMetrics {
+ waitingTimeBuckets := []float64{
+ timeout.Seconds() * 0.01,
+ timeout.Seconds() * 0.05,
+ timeout.Seconds() * 0.10,
+ timeout.Seconds() * 0.25,
+ timeout.Seconds() * 0.50,
+ timeout.Seconds() * 0.75,
+ timeout.Seconds() * 0.90,
+ timeout.Seconds() * 0.95,
+ timeout.Seconds() * 0.99,
+ timeout.Seconds(),
+ }
+
+ metrics := &queueMetrics{
+ queueingLimit: promauto.NewGauge(prometheus.GaugeOpts{
+ Name: "gitlab_workhorse_queueing_limit",
+ Help: "Current limit set for the queueing mechanism",
+ ConstLabels: prometheus.Labels{
+ "queue_name": name,
+ },
+ }),
+
+ queueingQueueLimit: promauto.NewGauge(prometheus.GaugeOpts{
+ Name: "gitlab_workhorse_queueing_queue_limit",
+ Help: "Current queueLimit set for the queueing mechanism",
+ ConstLabels: prometheus.Labels{
+ "queue_name": name,
+ },
+ }),
+
+ queueingQueueTimeout: promauto.NewGauge(prometheus.GaugeOpts{
+ Name: "gitlab_workhorse_queueing_queue_timeout",
+ Help: "Current queueTimeout set for the queueing mechanism",
+ ConstLabels: prometheus.Labels{
+ "queue_name": name,
+ },
+ }),
+
+ queueingBusy: promauto.NewGauge(prometheus.GaugeOpts{
+ Name: "gitlab_workhorse_queueing_busy",
+ Help: "How many queued requests are now processed",
+ ConstLabels: prometheus.Labels{
+ "queue_name": name,
+ },
+ }),
+
+ queueingWaiting: promauto.NewGauge(prometheus.GaugeOpts{
+ Name: "gitlab_workhorse_queueing_waiting",
+ Help: "How many requests are now queued",
+ ConstLabels: prometheus.Labels{
+ "queue_name": name,
+ },
+ }),
+
+ queueingWaitingTime: promauto.NewHistogram(prometheus.HistogramOpts{
+ Name: "gitlab_workhorse_queueing_waiting_time",
+ Help: "How many time a request spent in queue",
+ ConstLabels: prometheus.Labels{
+ "queue_name": name,
+ },
+ Buckets: waitingTimeBuckets,
+ }),
+
+ queueingErrors: promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_queueing_errors",
+ Help: "How many times the TooManyRequests or QueueintTimedout errors were returned while queueing, partitioned by error type",
+ ConstLabels: prometheus.Labels{
+ "queue_name": name,
+ },
+ },
+ []string{"type"},
+ ),
+ }
+
+ return metrics
+}
+
+type Queue struct {
+ *queueMetrics
+
+ name string
+ busyCh chan struct{}
+ waitingCh chan time.Time
+ timeout time.Duration
+}
+
+// newQueue creates a new queue
+// name specifies name used to label queue metrics.
+// Don't call newQueue twice with the same name argument!
+// limit specifies number of requests run concurrently
+// queueLimit specifies maximum number of requests that can be queued
+// timeout specifies the time limit of storing the request in the queue
+// if the number of requests is above the limit
+func newQueue(name string, limit, queueLimit uint, timeout time.Duration) *Queue {
+ queue := &Queue{
+ name: name,
+ busyCh: make(chan struct{}, limit),
+ waitingCh: make(chan time.Time, limit+queueLimit),
+ timeout: timeout,
+ }
+
+ queue.queueMetrics = newQueueMetrics(name, timeout)
+ queue.queueingLimit.Set(float64(limit))
+ queue.queueingQueueLimit.Set(float64(queueLimit))
+ queue.queueingQueueTimeout.Set(timeout.Seconds())
+
+ return queue
+}
+
+// Acquire takes one slot from the Queue
+// and returns when a request should be processed
+// it allows up to (limit) of requests running at a time
+// it allows to queue up to (queue-limit) requests
+func (s *Queue) Acquire() (err error) {
+ // push item to a queue to claim your own slot (non-blocking)
+ select {
+ case s.waitingCh <- time.Now():
+ s.queueingWaiting.Inc()
+ break
+ default:
+ s.queueingErrors.WithLabelValues("too_many_requests").Inc()
+ return ErrTooManyRequests
+ }
+
+ defer func() {
+ if err != nil {
+ waitStarted := <-s.waitingCh
+ s.queueingWaiting.Dec()
+ s.queueingWaitingTime.Observe(float64(time.Since(waitStarted).Seconds()))
+ }
+ }()
+
+ // fast path: push item to current processed items (non-blocking)
+ select {
+ case s.busyCh <- struct{}{}:
+ s.queueingBusy.Inc()
+ return nil
+ default:
+ break
+ }
+
+ timer := time.NewTimer(s.timeout)
+ defer timer.Stop()
+
+ // push item to current processed items (blocking)
+ select {
+ case s.busyCh <- struct{}{}:
+ s.queueingBusy.Inc()
+ return nil
+
+ case <-timer.C:
+ s.queueingErrors.WithLabelValues("queueing_timedout").Inc()
+ return ErrQueueingTimedout
+ }
+}
+
+// Release marks the finish of processing of requests
+// It triggers next request to be processed if it's in queue
+func (s *Queue) Release() {
+ // dequeue from queue to allow next request to be processed
+ waitStarted := <-s.waitingCh
+ s.queueingWaiting.Dec()
+ s.queueingWaitingTime.Observe(float64(time.Since(waitStarted).Seconds()))
+
+ <-s.busyCh
+ s.queueingBusy.Dec()
+}
diff --git a/workhorse/internal/queueing/queue_test.go b/workhorse/internal/queueing/queue_test.go
new file mode 100644
index 00000000000..7f5ed9154f4
--- /dev/null
+++ b/workhorse/internal/queueing/queue_test.go
@@ -0,0 +1,62 @@
+package queueing
+
+import (
+ "testing"
+ "time"
+)
+
+func TestNormalQueueing(t *testing.T) {
+ q := newQueue("queue 1", 2, 1, time.Microsecond)
+ err1 := q.Acquire()
+ if err1 != nil {
+ t.Fatal("we should acquire a new slot")
+ }
+
+ err2 := q.Acquire()
+ if err2 != nil {
+ t.Fatal("we should acquire a new slot")
+ }
+
+ err3 := q.Acquire()
+ if err3 != ErrQueueingTimedout {
+ t.Fatal("we should timeout")
+ }
+
+ q.Release()
+
+ err4 := q.Acquire()
+ if err4 != nil {
+ t.Fatal("we should acquire a new slot")
+ }
+}
+
+func TestQueueLimit(t *testing.T) {
+ q := newQueue("queue 2", 1, 0, time.Microsecond)
+ err1 := q.Acquire()
+ if err1 != nil {
+ t.Fatal("we should acquire a new slot")
+ }
+
+ err2 := q.Acquire()
+ if err2 != ErrTooManyRequests {
+ t.Fatal("we should fail because of not enough slots in queue")
+ }
+}
+
+func TestQueueProcessing(t *testing.T) {
+ q := newQueue("queue 3", 1, 1, time.Second)
+ err1 := q.Acquire()
+ if err1 != nil {
+ t.Fatal("we should acquire a new slot")
+ }
+
+ go func() {
+ time.Sleep(50 * time.Microsecond)
+ q.Release()
+ }()
+
+ err2 := q.Acquire()
+ if err2 != nil {
+ t.Fatal("we should acquire slot after the previous one finished")
+ }
+}
diff --git a/workhorse/internal/queueing/requests.go b/workhorse/internal/queueing/requests.go
new file mode 100644
index 00000000000..409a7656fa4
--- /dev/null
+++ b/workhorse/internal/queueing/requests.go
@@ -0,0 +1,51 @@
+package queueing
+
+import (
+ "net/http"
+ "time"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+const (
+ DefaultTimeout = 30 * time.Second
+ httpStatusTooManyRequests = 429
+)
+
+// QueueRequests creates a new request queue
+// name specifies the name of queue, used to label Prometheus metrics
+// Don't call QueueRequests twice with the same name argument!
+// h specifies a http.Handler which will handle the queue requests
+// limit specifies number of requests run concurrently
+// queueLimit specifies maximum number of requests that can be queued
+// queueTimeout specifies the time limit of storing the request in the queue
+func QueueRequests(name string, h http.Handler, limit, queueLimit uint, queueTimeout time.Duration) http.Handler {
+ if limit == 0 {
+ return h
+ }
+ if queueTimeout == 0 {
+ queueTimeout = DefaultTimeout
+ }
+
+ queue := newQueue(name, limit, queueLimit, queueTimeout)
+
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ err := queue.Acquire()
+
+ switch err {
+ case nil:
+ defer queue.Release()
+ h.ServeHTTP(w, r)
+
+ case ErrTooManyRequests:
+ http.Error(w, "Too Many Requests", httpStatusTooManyRequests)
+
+ case ErrQueueingTimedout:
+ http.Error(w, "Service Unavailable", http.StatusServiceUnavailable)
+
+ default:
+ helper.Fail500(w, r, err)
+ }
+
+ })
+}
diff --git a/workhorse/internal/queueing/requests_test.go b/workhorse/internal/queueing/requests_test.go
new file mode 100644
index 00000000000..f1c52e5c6f5
--- /dev/null
+++ b/workhorse/internal/queueing/requests_test.go
@@ -0,0 +1,76 @@
+package queueing
+
+import (
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+)
+
+var httpHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "OK")
+})
+
+func pausedHttpHandler(pauseCh chan struct{}) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ <-pauseCh
+ fmt.Fprintln(w, "OK")
+ })
+}
+
+func TestNormalRequestProcessing(t *testing.T) {
+ w := httptest.NewRecorder()
+ h := QueueRequests("Normal request processing", httpHandler, 1, 1, time.Second)
+ h.ServeHTTP(w, nil)
+ if w.Code != 200 {
+ t.Fatal("QueueRequests should process request")
+ }
+}
+
+// testSlowRequestProcessing creates a new queue,
+// then it runs a number of requests that are going through queue,
+// we return the response of first finished request,
+// where status of request can be 200, 429 or 503
+func testSlowRequestProcessing(name string, count int, limit, queueLimit uint, queueTimeout time.Duration) *httptest.ResponseRecorder {
+ pauseCh := make(chan struct{})
+ defer close(pauseCh)
+
+ handler := QueueRequests("Slow request processing: "+name, pausedHttpHandler(pauseCh), limit, queueLimit, queueTimeout)
+
+ respCh := make(chan *httptest.ResponseRecorder, count)
+
+ // queue requests to use up the queue
+ for i := 0; i < count; i++ {
+ go func() {
+ w := httptest.NewRecorder()
+ handler.ServeHTTP(w, nil)
+ respCh <- w
+ }()
+ }
+
+ // dequeue first request
+ return <-respCh
+}
+
+// TestQueueingTimeout performs 2 requests
+// the queue limit and length is 1,
+// the second request gets timed-out
+func TestQueueingTimeout(t *testing.T) {
+ w := testSlowRequestProcessing("timeout", 2, 1, 1, time.Microsecond)
+
+ if w.Code != 503 {
+ t.Fatal("QueueRequests should timeout queued request")
+ }
+}
+
+// TestQueueingTooManyRequests performs 3 requests
+// the queue limit and length is 1,
+// so the third request has to be rejected with 429
+func TestQueueingTooManyRequests(t *testing.T) {
+ w := testSlowRequestProcessing("too many requests", 3, 1, 1, time.Minute)
+
+ if w.Code != 429 {
+ t.Fatal("QueueRequests should return immediately and return too many requests")
+ }
+}
diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go
new file mode 100644
index 00000000000..96e33a64b85
--- /dev/null
+++ b/workhorse/internal/redis/keywatcher.go
@@ -0,0 +1,198 @@
+package redis
+
+import (
+ "fmt"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/gomodule/redigo/redis"
+ "github.com/jpillora/backoff"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+var (
+ keyWatcher = make(map[string][]chan string)
+ keyWatcherMutex sync.Mutex
+ redisReconnectTimeout = backoff.Backoff{
+ //These are the defaults
+ Min: 100 * time.Millisecond,
+ Max: 60 * time.Second,
+ Factor: 2,
+ Jitter: true,
+ }
+ keyWatchers = promauto.NewGauge(
+ prometheus.GaugeOpts{
+ Name: "gitlab_workhorse_keywatcher_keywatchers",
+ Help: "The number of keys that is being watched by gitlab-workhorse",
+ },
+ )
+ totalMessages = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_keywatcher_total_messages",
+ Help: "How many messages gitlab-workhorse has received in total on pubsub.",
+ },
+ )
+)
+
+const (
+ keySubChannel = "workhorse:notifications"
+)
+
+// KeyChan holds a key and a channel
+type KeyChan struct {
+ Key string
+ Chan chan string
+}
+
+func processInner(conn redis.Conn) error {
+ defer conn.Close()
+ psc := redis.PubSubConn{Conn: conn}
+ if err := psc.Subscribe(keySubChannel); err != nil {
+ return err
+ }
+ defer psc.Unsubscribe(keySubChannel)
+
+ for {
+ switch v := psc.Receive().(type) {
+ case redis.Message:
+ totalMessages.Inc()
+ dataStr := string(v.Data)
+ msg := strings.SplitN(dataStr, "=", 2)
+ if len(msg) != 2 {
+ helper.LogError(nil, fmt.Errorf("keywatcher: invalid notification: %q", dataStr))
+ continue
+ }
+ key, value := msg[0], msg[1]
+ notifyChanWatchers(key, value)
+ case error:
+ helper.LogError(nil, fmt.Errorf("keywatcher: pubsub receive: %v", v))
+ // Intermittent error, return nil so that it doesn't wait before reconnect
+ return nil
+ }
+ }
+}
+
+func dialPubSub(dialer redisDialerFunc) (redis.Conn, error) {
+ conn, err := dialer()
+ if err != nil {
+ return nil, err
+ }
+
+ // Make sure Redis is actually connected
+ conn.Do("PING")
+ if err := conn.Err(); err != nil {
+ conn.Close()
+ return nil, err
+ }
+
+ return conn, nil
+}
+
+// Process redis subscriptions
+//
+// NOTE: There Can Only Be One!
+func Process() {
+ log.Info("keywatcher: starting process loop")
+ for {
+ conn, err := dialPubSub(workerDialFunc)
+ if err != nil {
+ helper.LogError(nil, fmt.Errorf("keywatcher: %v", err))
+ time.Sleep(redisReconnectTimeout.Duration())
+ continue
+ }
+ redisReconnectTimeout.Reset()
+
+ if err = processInner(conn); err != nil {
+ helper.LogError(nil, fmt.Errorf("keywatcher: process loop: %v", err))
+ }
+ }
+}
+
+func notifyChanWatchers(key, value string) {
+ keyWatcherMutex.Lock()
+ defer keyWatcherMutex.Unlock()
+ if chanList, ok := keyWatcher[key]; ok {
+ for _, c := range chanList {
+ c <- value
+ keyWatchers.Dec()
+ }
+ delete(keyWatcher, key)
+ }
+}
+
+func addKeyChan(kc *KeyChan) {
+ keyWatcherMutex.Lock()
+ defer keyWatcherMutex.Unlock()
+ keyWatcher[kc.Key] = append(keyWatcher[kc.Key], kc.Chan)
+ keyWatchers.Inc()
+}
+
+func delKeyChan(kc *KeyChan) {
+ keyWatcherMutex.Lock()
+ defer keyWatcherMutex.Unlock()
+ if chans, ok := keyWatcher[kc.Key]; ok {
+ for i, c := range chans {
+ if kc.Chan == c {
+ keyWatcher[kc.Key] = append(chans[:i], chans[i+1:]...)
+ keyWatchers.Dec()
+ break
+ }
+ }
+ if len(keyWatcher[kc.Key]) == 0 {
+ delete(keyWatcher, kc.Key)
+ }
+ }
+}
+
+// WatchKeyStatus is used to tell how WatchKey returned
+type WatchKeyStatus int
+
+const (
+ // WatchKeyStatusTimeout is returned when the watch timeout provided by the caller was exceeded
+ WatchKeyStatusTimeout WatchKeyStatus = iota
+ // WatchKeyStatusAlreadyChanged is returned when the value passed by the caller was never observed
+ WatchKeyStatusAlreadyChanged
+ // WatchKeyStatusSeenChange is returned when we have seen the value passed by the caller get changed
+ WatchKeyStatusSeenChange
+ // WatchKeyStatusNoChange is returned when the function had to return before observing a change.
+ // Also returned on errors.
+ WatchKeyStatusNoChange
+)
+
+// WatchKey waits for a key to be updated or expired
+func WatchKey(key, value string, timeout time.Duration) (WatchKeyStatus, error) {
+ kw := &KeyChan{
+ Key: key,
+ Chan: make(chan string, 1),
+ }
+
+ addKeyChan(kw)
+ defer delKeyChan(kw)
+
+ currentValue, err := GetString(key)
+ if err != nil {
+ return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err)
+ }
+ if currentValue != value {
+ return WatchKeyStatusAlreadyChanged, nil
+ }
+
+ select {
+ case currentValue := <-kw.Chan:
+ if currentValue == "" {
+ return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed")
+ }
+ if currentValue == value {
+ return WatchKeyStatusNoChange, nil
+ }
+ return WatchKeyStatusSeenChange, nil
+
+ case <-time.After(timeout):
+ return WatchKeyStatusTimeout, nil
+ }
+}
diff --git a/workhorse/internal/redis/keywatcher_test.go b/workhorse/internal/redis/keywatcher_test.go
new file mode 100644
index 00000000000..f1ee77e2194
--- /dev/null
+++ b/workhorse/internal/redis/keywatcher_test.go
@@ -0,0 +1,162 @@
+package redis
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/rafaeljusto/redigomock"
+ "github.com/stretchr/testify/require"
+)
+
+const (
+ runnerKey = "runner:build_queue:10"
+)
+
+func createSubscriptionMessage(key, data string) []interface{} {
+ return []interface{}{
+ []byte("message"),
+ []byte(key),
+ []byte(data),
+ }
+}
+
+func createSubscribeMessage(key string) []interface{} {
+ return []interface{}{
+ []byte("subscribe"),
+ []byte(key),
+ []byte("1"),
+ }
+}
+func createUnsubscribeMessage(key string) []interface{} {
+ return []interface{}{
+ []byte("unsubscribe"),
+ []byte(key),
+ []byte("1"),
+ }
+}
+
+func countWatchers(key string) int {
+ keyWatcherMutex.Lock()
+ defer keyWatcherMutex.Unlock()
+ return len(keyWatcher[key])
+}
+
+func deleteWatchers(key string) {
+ keyWatcherMutex.Lock()
+ defer keyWatcherMutex.Unlock()
+ delete(keyWatcher, key)
+}
+
+// Forces a run of the `Process` loop against a mock PubSubConn.
+func processMessages(numWatchers int, value string) {
+ psc := redigomock.NewConn()
+
+ // Setup the initial subscription message
+ psc.Command("SUBSCRIBE", keySubChannel).Expect(createSubscribeMessage(keySubChannel))
+ psc.Command("UNSUBSCRIBE", keySubChannel).Expect(createUnsubscribeMessage(keySubChannel))
+ psc.AddSubscriptionMessage(createSubscriptionMessage(keySubChannel, runnerKey+"="+value))
+
+ // Wait for all the `WatchKey` calls to be registered
+ for countWatchers(runnerKey) != numWatchers {
+ time.Sleep(time.Millisecond)
+ }
+
+ processInner(psc)
+}
+
+func TestWatchKeySeenChange(t *testing.T) {
+ conn, td := setupMockPool()
+ defer td()
+
+ conn.Command("GET", runnerKey).Expect("something")
+
+ wg := &sync.WaitGroup{}
+ wg.Add(1)
+
+ go func() {
+ val, err := WatchKey(runnerKey, "something", time.Second)
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, WatchKeyStatusSeenChange, val, "Expected value to change")
+ wg.Done()
+ }()
+
+ processMessages(1, "somethingelse")
+ wg.Wait()
+}
+
+func TestWatchKeyNoChange(t *testing.T) {
+ conn, td := setupMockPool()
+ defer td()
+
+ conn.Command("GET", runnerKey).Expect("something")
+
+ wg := &sync.WaitGroup{}
+ wg.Add(1)
+
+ go func() {
+ val, err := WatchKey(runnerKey, "something", time.Second)
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, WatchKeyStatusNoChange, val, "Expected notification without change to value")
+ wg.Done()
+ }()
+
+ processMessages(1, "something")
+ wg.Wait()
+}
+
+func TestWatchKeyTimeout(t *testing.T) {
+ conn, td := setupMockPool()
+ defer td()
+
+ conn.Command("GET", runnerKey).Expect("something")
+
+ val, err := WatchKey(runnerKey, "something", time.Millisecond)
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, WatchKeyStatusTimeout, val, "Expected value to not change")
+
+ // Clean up watchers since Process isn't doing that for us (not running)
+ deleteWatchers(runnerKey)
+}
+
+func TestWatchKeyAlreadyChanged(t *testing.T) {
+ conn, td := setupMockPool()
+ defer td()
+
+ conn.Command("GET", runnerKey).Expect("somethingelse")
+
+ val, err := WatchKey(runnerKey, "something", time.Second)
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, WatchKeyStatusAlreadyChanged, val, "Expected value to have already changed")
+
+ // Clean up watchers since Process isn't doing that for us (not running)
+ deleteWatchers(runnerKey)
+}
+
+func TestWatchKeyMassivelyParallel(t *testing.T) {
+ runTimes := 100 // 100 parallel watchers
+
+ conn, td := setupMockPool()
+ defer td()
+
+ wg := &sync.WaitGroup{}
+ wg.Add(runTimes)
+
+ getCmd := conn.Command("GET", runnerKey)
+
+ for i := 0; i < runTimes; i++ {
+ getCmd = getCmd.Expect("something")
+ }
+
+ for i := 0; i < runTimes; i++ {
+ go func() {
+ val, err := WatchKey(runnerKey, "something", time.Second)
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, WatchKeyStatusSeenChange, val, "Expected value to change")
+ wg.Done()
+ }()
+ }
+
+ processMessages(runTimes, "somethingelse")
+ wg.Wait()
+}
diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go
new file mode 100644
index 00000000000..0029a2a9e2b
--- /dev/null
+++ b/workhorse/internal/redis/redis.go
@@ -0,0 +1,295 @@
+package redis
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "net/url"
+ "time"
+
+ "github.com/FZambia/sentinel"
+ "github.com/gomodule/redigo/redis"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+var (
+ pool *redis.Pool
+ sntnl *sentinel.Sentinel
+)
+
+const (
+ // Max Idle Connections in the pool.
+ defaultMaxIdle = 1
+ // Max Active Connections in the pool.
+ defaultMaxActive = 1
+ // Timeout for Read operations on the pool. 1 second is technically overkill,
+ // it's just for sanity.
+ defaultReadTimeout = 1 * time.Second
+ // Timeout for Write operations on the pool. 1 second is technically overkill,
+ // it's just for sanity.
+ defaultWriteTimeout = 1 * time.Second
+ // Timeout before killing Idle connections in the pool. 3 minutes seemed good.
+ // If you _actually_ hit this timeout often, you should consider turning of
+ // redis-support since it's not necessary at that point...
+ defaultIdleTimeout = 3 * time.Minute
+ // KeepAlivePeriod is to keep a TCP connection open for an extended period of
+ // time without being killed. This is used both in the pool, and in the
+ // worker-connection.
+ // See https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more
+ // information.
+ defaultKeepAlivePeriod = 5 * time.Minute
+)
+
+var (
+ totalConnections = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_redis_total_connections",
+ Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process",
+ },
+ )
+
+ errorCounter = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_redis_errors",
+ Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)",
+ },
+ []string{"type", "dst"},
+ )
+)
+
+func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel {
+ if len(urls) == 0 {
+ return nil
+ }
+ var addrs []string
+ for _, url := range urls {
+ h := url.URL.String()
+ log.WithFields(log.Fields{
+ "scheme": url.URL.Scheme,
+ "host": url.URL.Host,
+ }).Printf("redis: using sentinel")
+ addrs = append(addrs, h)
+ }
+ return &sentinel.Sentinel{
+ Addrs: addrs,
+ MasterName: master,
+ Dial: func(addr string) (redis.Conn, error) {
+ // This timeout is recommended for Sentinel-support according to the guidelines.
+ // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel
+ // For every address it should try to connect to the Sentinel,
+ // using a short timeout (in the order of a few hundreds of milliseconds).
+ timeout := 500 * time.Millisecond
+ url := helper.URLMustParse(addr)
+
+ var c redis.Conn
+ var err error
+ options := []redis.DialOption{
+ redis.DialConnectTimeout(timeout),
+ redis.DialReadTimeout(timeout),
+ redis.DialWriteTimeout(timeout),
+ }
+
+ if url.Scheme == "redis" || url.Scheme == "rediss" {
+ c, err = redis.DialURL(addr, options...)
+ } else {
+ c, err = redis.Dial("tcp", url.Host, options...)
+ }
+
+ if err != nil {
+ errorCounter.WithLabelValues("dial", "sentinel").Inc()
+ return nil, err
+ }
+ return c, nil
+ },
+ }
+}
+
+var poolDialFunc func() (redis.Conn, error)
+var workerDialFunc func() (redis.Conn, error)
+
+func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption {
+ readTimeout := defaultReadTimeout
+ writeTimeout := defaultWriteTimeout
+
+ if cfg != nil {
+ if cfg.ReadTimeout != nil {
+ readTimeout = cfg.ReadTimeout.Duration
+ }
+
+ if cfg.WriteTimeout != nil {
+ writeTimeout = cfg.WriteTimeout.Duration
+ }
+ }
+ return []redis.DialOption{
+ redis.DialReadTimeout(readTimeout),
+ redis.DialWriteTimeout(writeTimeout),
+ }
+}
+
+func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption {
+ var dopts []redis.DialOption
+ if setTimeouts {
+ dopts = timeoutDialOptions(cfg)
+ }
+ if cfg == nil {
+ return dopts
+ }
+ if cfg.Password != "" {
+ dopts = append(dopts, redis.DialPassword(cfg.Password))
+ }
+ if cfg.DB != nil {
+ dopts = append(dopts, redis.DialDatabase(*cfg.DB))
+ }
+ return dopts
+}
+
+func keepAliveDialer(timeout time.Duration) func(string, string) (net.Conn, error) {
+ return func(network, address string) (net.Conn, error) {
+ addr, err := net.ResolveTCPAddr(network, address)
+ if err != nil {
+ return nil, err
+ }
+ tc, err := net.DialTCP(network, nil, addr)
+ if err != nil {
+ return nil, err
+ }
+ if err := tc.SetKeepAlive(true); err != nil {
+ return nil, err
+ }
+ if err := tc.SetKeepAlivePeriod(timeout); err != nil {
+ return nil, err
+ }
+ return tc, nil
+ }
+}
+
+type redisDialerFunc func() (redis.Conn, error)
+
+func sentinelDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration) redisDialerFunc {
+ return func() (redis.Conn, error) {
+ address, err := sntnl.MasterAddr()
+ if err != nil {
+ errorCounter.WithLabelValues("master", "sentinel").Inc()
+ return nil, err
+ }
+ dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod)))
+ return redisDial("tcp", address, dopts...)
+ }
+}
+
+func defaultDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration, url url.URL) redisDialerFunc {
+ return func() (redis.Conn, error) {
+ if url.Scheme == "unix" {
+ return redisDial(url.Scheme, url.Path, dopts...)
+ }
+
+ dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod)))
+
+ // redis.DialURL only works with redis[s]:// URLs
+ if url.Scheme == "redis" || url.Scheme == "rediss" {
+ return redisURLDial(url, dopts...)
+ }
+
+ return redisDial(url.Scheme, url.Host, dopts...)
+ }
+}
+
+func redisURLDial(url url.URL, options ...redis.DialOption) (redis.Conn, error) {
+ log.WithFields(log.Fields{
+ "scheme": url.Scheme,
+ "address": url.Host,
+ }).Printf("redis: dialing")
+
+ return redis.DialURL(url.String(), options...)
+}
+
+func redisDial(network, address string, options ...redis.DialOption) (redis.Conn, error) {
+ log.WithFields(log.Fields{
+ "network": network,
+ "address": address,
+ }).Printf("redis: dialing")
+
+ return redis.Dial(network, address, options...)
+}
+
+func countDialer(dialer redisDialerFunc) redisDialerFunc {
+ return func() (redis.Conn, error) {
+ c, err := dialer()
+ if err != nil {
+ errorCounter.WithLabelValues("dial", "redis").Inc()
+ } else {
+ totalConnections.Inc()
+ }
+ return c, err
+ }
+}
+
+// DefaultDialFunc should always used. Only exception is for unit-tests.
+func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) {
+ keepAlivePeriod := defaultKeepAlivePeriod
+ if cfg.KeepAlivePeriod != nil {
+ keepAlivePeriod = cfg.KeepAlivePeriod.Duration
+ }
+ dopts := dialOptionsBuilder(cfg, setReadTimeout)
+ if sntnl != nil {
+ return countDialer(sentinelDialer(dopts, keepAlivePeriod))
+ }
+ return countDialer(defaultDialer(dopts, keepAlivePeriod, cfg.URL.URL))
+}
+
+// Configure redis-connection
+func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) {
+ if cfg == nil {
+ return
+ }
+ maxIdle := defaultMaxIdle
+ if cfg.MaxIdle != nil {
+ maxIdle = *cfg.MaxIdle
+ }
+ maxActive := defaultMaxActive
+ if cfg.MaxActive != nil {
+ maxActive = *cfg.MaxActive
+ }
+ sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel)
+ workerDialFunc = dialFunc(cfg, false)
+ poolDialFunc = dialFunc(cfg, true)
+ pool = &redis.Pool{
+ MaxIdle: maxIdle, // Keep at most X hot connections
+ MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited
+ IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed
+ Dial: poolDialFunc,
+ Wait: true,
+ }
+ if sntnl != nil {
+ pool.TestOnBorrow = func(c redis.Conn, t time.Time) error {
+ if !sentinel.TestRole(c, "master") {
+ return errors.New("role check failed")
+ }
+ return nil
+ }
+ }
+}
+
+// Get a connection for the Redis-pool
+func Get() redis.Conn {
+ if pool != nil {
+ return pool.Get()
+ }
+ return nil
+}
+
+// GetString fetches the value of a key in Redis as a string
+func GetString(key string) (string, error) {
+ conn := Get()
+ if conn == nil {
+ return "", fmt.Errorf("redis: could not get connection from pool")
+ }
+ defer conn.Close()
+
+ return redis.String(conn.Do("GET", key))
+}
diff --git a/workhorse/internal/redis/redis_test.go b/workhorse/internal/redis/redis_test.go
new file mode 100644
index 00000000000..f4b4120517d
--- /dev/null
+++ b/workhorse/internal/redis/redis_test.go
@@ -0,0 +1,234 @@
+package redis
+
+import (
+ "net"
+ "testing"
+ "time"
+
+ "github.com/gomodule/redigo/redis"
+ "github.com/rafaeljusto/redigomock"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+func mockRedisServer(t *testing.T, connectReceived *bool) string {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+
+ require.Nil(t, err)
+
+ go func() {
+ defer ln.Close()
+ conn, err := ln.Accept()
+ require.Nil(t, err)
+ *connectReceived = true
+ conn.Write([]byte("OK\n"))
+ }()
+
+ return ln.Addr().String()
+}
+
+// Setup a MockPool for Redis
+//
+// Returns a teardown-function and the mock-connection
+func setupMockPool() (*redigomock.Conn, func()) {
+ conn := redigomock.NewConn()
+ cfg := &config.RedisConfig{URL: config.TomlURL{}}
+ Configure(cfg, func(_ *config.RedisConfig, _ bool) func() (redis.Conn, error) {
+ return func() (redis.Conn, error) {
+ return conn, nil
+ }
+ })
+ return conn, func() {
+ pool = nil
+ }
+}
+
+func TestDefaultDialFunc(t *testing.T) {
+ testCases := []struct {
+ scheme string
+ }{
+ {
+ scheme: "tcp",
+ },
+ {
+ scheme: "redis",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.scheme, func(t *testing.T) {
+ connectReceived := false
+ a := mockRedisServer(t, &connectReceived)
+
+ parsedURL := helper.URLMustParse(tc.scheme + "://" + a)
+ cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}}
+
+ dialer := DefaultDialFunc(cfg, true)
+ conn, err := dialer()
+
+ require.Nil(t, err)
+ conn.Receive()
+
+ require.True(t, connectReceived)
+ })
+ }
+}
+
+func TestConfigureNoConfig(t *testing.T) {
+ pool = nil
+ Configure(nil, nil)
+ require.Nil(t, pool, "Pool should be nil")
+}
+
+func TestConfigureMinimalConfig(t *testing.T) {
+ cfg := &config.RedisConfig{URL: config.TomlURL{}, Password: ""}
+ Configure(cfg, DefaultDialFunc)
+
+ require.NotNil(t, pool, "Pool should not be nil")
+ require.Equal(t, 1, pool.MaxIdle)
+ require.Equal(t, 1, pool.MaxActive)
+ require.Equal(t, 3*time.Minute, pool.IdleTimeout)
+
+ pool = nil
+}
+
+func TestConfigureFullConfig(t *testing.T) {
+ i, a := 4, 10
+ r := config.TomlDuration{Duration: 3}
+ cfg := &config.RedisConfig{
+ URL: config.TomlURL{},
+ Password: "",
+ MaxIdle: &i,
+ MaxActive: &a,
+ ReadTimeout: &r,
+ }
+ Configure(cfg, DefaultDialFunc)
+
+ require.NotNil(t, pool, "Pool should not be nil")
+ require.Equal(t, i, pool.MaxIdle)
+ require.Equal(t, a, pool.MaxActive)
+ require.Equal(t, 3*time.Minute, pool.IdleTimeout)
+
+ pool = nil
+}
+
+func TestGetConnFail(t *testing.T) {
+ conn := Get()
+ require.Nil(t, conn, "Expected `conn` to be nil")
+}
+
+func TestGetConnPass(t *testing.T) {
+ _, teardown := setupMockPool()
+ defer teardown()
+ conn := Get()
+ require.NotNil(t, conn, "Expected `conn` to be non-nil")
+}
+
+func TestGetStringPass(t *testing.T) {
+ conn, teardown := setupMockPool()
+ defer teardown()
+ conn.Command("GET", "foobar").Expect("baz")
+ str, err := GetString("foobar")
+
+ require.NoError(t, err, "Expected `err` to be nil")
+ var value string
+ require.IsType(t, value, str, "Expected value to be a string")
+ require.Equal(t, "baz", str, "Expected it to be equal")
+}
+
+func TestGetStringFail(t *testing.T) {
+ _, err := GetString("foobar")
+ require.Error(t, err, "Expected error when not connected to redis")
+}
+
+func TestSentinelConnNoSentinel(t *testing.T) {
+ s := sentinelConn("", []config.TomlURL{})
+
+ require.Nil(t, s, "Sentinel without urls should return nil")
+}
+
+func TestSentinelConnDialURL(t *testing.T) {
+ testCases := []struct {
+ scheme string
+ }{
+ {
+ scheme: "tcp",
+ },
+ {
+ scheme: "redis",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.scheme, func(t *testing.T) {
+ connectReceived := false
+ a := mockRedisServer(t, &connectReceived)
+
+ addrs := []string{tc.scheme + "://" + a}
+ var sentinelUrls []config.TomlURL
+
+ for _, a := range addrs {
+ parsedURL := helper.URLMustParse(a)
+ sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL})
+ }
+
+ s := sentinelConn("foobar", sentinelUrls)
+ require.Equal(t, len(addrs), len(s.Addrs))
+
+ for i := range addrs {
+ require.Equal(t, addrs[i], s.Addrs[i])
+ }
+
+ conn, err := s.Dial(s.Addrs[0])
+
+ require.Nil(t, err)
+ conn.Receive()
+
+ require.True(t, connectReceived)
+ })
+ }
+}
+
+func TestSentinelConnTwoURLs(t *testing.T) {
+ addrs := []string{"tcp://10.0.0.1:12345", "tcp://10.0.0.2:12345"}
+ var sentinelUrls []config.TomlURL
+
+ for _, a := range addrs {
+ parsedURL := helper.URLMustParse(a)
+ sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL})
+ }
+
+ s := sentinelConn("foobar", sentinelUrls)
+ require.Equal(t, len(addrs), len(s.Addrs))
+
+ for i := range addrs {
+ require.Equal(t, addrs[i], s.Addrs[i])
+ }
+}
+
+func TestDialOptionsBuildersPassword(t *testing.T) {
+ dopts := dialOptionsBuilder(&config.RedisConfig{Password: "foo"}, false)
+ require.Equal(t, 1, len(dopts))
+}
+
+func TestDialOptionsBuildersSetTimeouts(t *testing.T) {
+ dopts := dialOptionsBuilder(nil, true)
+ require.Equal(t, 2, len(dopts))
+}
+
+func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) {
+ cfg := &config.RedisConfig{
+ ReadTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)},
+ WriteTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)},
+ }
+ dopts := dialOptionsBuilder(cfg, true)
+ require.Equal(t, 2, len(dopts))
+}
+
+func TestDialOptionsBuildersSelectDB(t *testing.T) {
+ db := 3
+ dopts := dialOptionsBuilder(&config.RedisConfig{DB: &db}, false)
+ require.Equal(t, 1, len(dopts))
+}
diff --git a/workhorse/internal/secret/jwt.go b/workhorse/internal/secret/jwt.go
new file mode 100644
index 00000000000..04335e58f76
--- /dev/null
+++ b/workhorse/internal/secret/jwt.go
@@ -0,0 +1,25 @@
+package secret
+
+import (
+ "fmt"
+
+ "github.com/dgrijalva/jwt-go"
+)
+
+var (
+ DefaultClaims = jwt.StandardClaims{Issuer: "gitlab-workhorse"}
+)
+
+func JWTTokenString(claims jwt.Claims) (string, error) {
+ secretBytes, err := Bytes()
+ if err != nil {
+ return "", fmt.Errorf("secret.JWTTokenString: %v", err)
+ }
+
+ tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(secretBytes)
+ if err != nil {
+ return "", fmt.Errorf("secret.JWTTokenString: sign JWT: %v", err)
+ }
+
+ return tokenString, nil
+}
diff --git a/workhorse/internal/secret/roundtripper.go b/workhorse/internal/secret/roundtripper.go
new file mode 100644
index 00000000000..50bf7fff5b8
--- /dev/null
+++ b/workhorse/internal/secret/roundtripper.go
@@ -0,0 +1,35 @@
+package secret
+
+import (
+ "net/http"
+)
+
+const (
+ // This header carries the JWT token for gitlab-rails
+ RequestHeader = "Gitlab-Workhorse-Api-Request"
+)
+
+type roundTripper struct {
+ next http.RoundTripper
+ version string
+}
+
+// NewRoundTripper creates a RoundTripper that adds the JWT token header to a
+// request. This is used to verify that a request came from workhorse
+func NewRoundTripper(next http.RoundTripper, version string) http.RoundTripper {
+ return &roundTripper{next: next, version: version}
+}
+
+func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ tokenString, err := JWTTokenString(DefaultClaims)
+ if err != nil {
+ return nil, err
+ }
+
+ // Set a custom header for the request. This can be used in some
+ // configurations (Passenger) to solve auth request routing problems.
+ req.Header.Set("Gitlab-Workhorse", r.version)
+ req.Header.Set(RequestHeader, tokenString)
+
+ return r.next.RoundTrip(req)
+}
diff --git a/workhorse/internal/secret/secret.go b/workhorse/internal/secret/secret.go
new file mode 100644
index 00000000000..e8c7c25393c
--- /dev/null
+++ b/workhorse/internal/secret/secret.go
@@ -0,0 +1,77 @@
+package secret
+
+import (
+ "encoding/base64"
+ "fmt"
+ "io/ioutil"
+ "sync"
+)
+
+const numSecretBytes = 32
+
+type sec struct {
+ path string
+ bytes []byte
+ sync.RWMutex
+}
+
+var (
+ theSecret = &sec{}
+)
+
+func SetPath(path string) {
+ theSecret.Lock()
+ defer theSecret.Unlock()
+ theSecret.path = path
+ theSecret.bytes = nil
+}
+
+// Lazy access to the HMAC secret key. We must be lazy because if the key
+// is not already there, it will be generated by gitlab-rails, and
+// gitlab-rails is slow.
+func Bytes() ([]byte, error) {
+ if bytes := getBytes(); bytes != nil {
+ return copyBytes(bytes), nil
+ }
+
+ return setBytes()
+}
+
+func getBytes() []byte {
+ theSecret.RLock()
+ defer theSecret.RUnlock()
+ return theSecret.bytes
+}
+
+func copyBytes(bytes []byte) []byte {
+ out := make([]byte, len(bytes))
+ copy(out, bytes)
+ return out
+}
+
+func setBytes() ([]byte, error) {
+ theSecret.Lock()
+ defer theSecret.Unlock()
+
+ if theSecret.bytes != nil {
+ return theSecret.bytes, nil
+ }
+
+ base64Bytes, err := ioutil.ReadFile(theSecret.path)
+ if err != nil {
+ return nil, fmt.Errorf("secret.setBytes: read %q: %v", theSecret.path, err)
+ }
+
+ secretBytes := make([]byte, base64.StdEncoding.DecodedLen(len(base64Bytes)))
+ n, err := base64.StdEncoding.Decode(secretBytes, base64Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("secret.setBytes: decode secret: %v", err)
+ }
+
+ if n != numSecretBytes {
+ return nil, fmt.Errorf("secret.setBytes: expected %d secretBytes in %s, found %d", numSecretBytes, theSecret.path, n)
+ }
+
+ theSecret.bytes = secretBytes
+ return copyBytes(theSecret.bytes), nil
+}
diff --git a/workhorse/internal/senddata/contentprocessor/contentprocessor.go b/workhorse/internal/senddata/contentprocessor/contentprocessor.go
new file mode 100644
index 00000000000..a5cc0fee013
--- /dev/null
+++ b/workhorse/internal/senddata/contentprocessor/contentprocessor.go
@@ -0,0 +1,126 @@
+package contentprocessor
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers"
+)
+
+type contentDisposition struct {
+ rw http.ResponseWriter
+ buf *bytes.Buffer
+ wroteHeader bool
+ flushed bool
+ active bool
+ removedResponseHeaders bool
+ status int
+ sentStatus bool
+}
+
+// SetContentHeaders buffers the response if Gitlab-Workhorse-Detect-Content-Type
+// header is found and set the proper content headers based on the current
+// value of content type and disposition
+func SetContentHeaders(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ cd := &contentDisposition{
+ rw: w,
+ buf: &bytes.Buffer{},
+ status: http.StatusOK,
+ }
+
+ defer cd.flush()
+
+ h.ServeHTTP(cd, r)
+ })
+}
+
+func (cd *contentDisposition) Header() http.Header {
+ return cd.rw.Header()
+}
+
+func (cd *contentDisposition) Write(data []byte) (int, error) {
+ // Normal write if we don't need to buffer
+ if cd.isUnbuffered() {
+ cd.WriteHeader(cd.status)
+ return cd.rw.Write(data)
+ }
+
+ // Write the new data into the buffer
+ n, _ := cd.buf.Write(data)
+
+ // If we have enough data to calculate the content headers then flush the Buffer
+ var err error
+ if cd.buf.Len() >= headers.MaxDetectSize {
+ err = cd.flushBuffer()
+ }
+
+ return n, err
+}
+
+func (cd *contentDisposition) flushBuffer() error {
+ if cd.isUnbuffered() {
+ return nil
+ }
+
+ cd.flushed = true
+
+ // If the buffer has any content then we calculate the content headers and
+ // write in the response
+ if cd.buf.Len() > 0 {
+ cd.writeContentHeaders()
+ cd.WriteHeader(cd.status)
+ _, err := io.Copy(cd.rw, cd.buf)
+ return err
+ }
+
+ // If no content is present in the buffer we still need to send the headers
+ cd.WriteHeader(cd.status)
+ return nil
+}
+
+func (cd *contentDisposition) writeContentHeaders() {
+ if cd.wroteHeader {
+ return
+ }
+
+ cd.wroteHeader = true
+ contentType, contentDisposition := headers.SafeContentHeaders(cd.buf.Bytes(), cd.Header().Get(headers.ContentDispositionHeader))
+ cd.Header().Set(headers.ContentTypeHeader, contentType)
+ cd.Header().Set(headers.ContentDispositionHeader, contentDisposition)
+}
+
+func (cd *contentDisposition) WriteHeader(status int) {
+ if cd.sentStatus {
+ return
+ }
+
+ cd.status = status
+
+ if cd.isUnbuffered() {
+ cd.rw.WriteHeader(cd.status)
+ cd.sentStatus = true
+ }
+}
+
+// If we find any response header, then we must calculate the content headers
+// If we don't find any, the data is not buffered and it works as
+// a usual ResponseWriter
+func (cd *contentDisposition) isUnbuffered() bool {
+ if !cd.removedResponseHeaders {
+ if headers.IsDetectContentTypeHeaderPresent(cd.rw) {
+ cd.active = true
+ }
+
+ cd.removedResponseHeaders = true
+ // We ensure to clear any response header from the response
+ headers.RemoveResponseHeaders(cd.rw)
+ }
+
+ return cd.flushed || !cd.active
+}
+
+func (cd *contentDisposition) flush() {
+ cd.flushBuffer()
+}
diff --git a/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go b/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go
new file mode 100644
index 00000000000..5e3a74f04f9
--- /dev/null
+++ b/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go
@@ -0,0 +1,293 @@
+package contentprocessor
+
+import (
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestFailSetContentTypeAndDisposition(t *testing.T) {
+ testCaseBody := "Hello world!"
+
+ h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ _, err := io.WriteString(w, testCaseBody)
+ require.NoError(t, err)
+ })
+
+ resp := makeRequest(t, h, testCaseBody, "")
+
+ require.Equal(t, "", resp.Header.Get(headers.ContentDispositionHeader))
+ require.Equal(t, "", resp.Header.Get(headers.ContentTypeHeader))
+}
+
+func TestSuccessSetContentTypeAndDispositionFeatureEnabled(t *testing.T) {
+ testCaseBody := "Hello world!"
+
+ resp := makeRequest(t, nil, testCaseBody, "")
+
+ require.Equal(t, "inline", resp.Header.Get(headers.ContentDispositionHeader))
+ require.Equal(t, "text/plain; charset=utf-8", resp.Header.Get(headers.ContentTypeHeader))
+}
+
+func TestSetProperContentTypeAndDisposition(t *testing.T) {
+ testCases := []struct {
+ desc string
+ contentType string
+ contentDisposition string
+ body string
+ }{
+ {
+ desc: "text type",
+ contentType: "text/plain; charset=utf-8",
+ contentDisposition: "inline",
+ body: "Hello world!",
+ },
+ {
+ desc: "HTML type",
+ contentType: "text/plain; charset=utf-8",
+ contentDisposition: "inline",
+ body: "<html><body>Hello world!</body></html>",
+ },
+ {
+ desc: "Javascript type",
+ contentType: "text/plain; charset=utf-8",
+ contentDisposition: "inline",
+ body: "<script>alert(\"foo\")</script>",
+ },
+ {
+ desc: "Image type",
+ contentType: "image/png",
+ contentDisposition: "inline",
+ body: testhelper.LoadFile(t, "testdata/image.png"),
+ },
+ {
+ desc: "SVG type",
+ contentType: "image/svg+xml",
+ contentDisposition: "attachment",
+ body: testhelper.LoadFile(t, "testdata/image.svg"),
+ },
+ {
+ desc: "Partial SVG type",
+ contentType: "image/svg+xml",
+ contentDisposition: "attachment",
+ body: "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" viewBox=\"0 0 330 82\"><title>SVG logo combined with the W3C logo, set horizontally</title><desc>The logo combines three entities displayed horizontall</desc><metadata>",
+ },
+ {
+ desc: "Application type",
+ contentType: "application/pdf",
+ contentDisposition: "attachment",
+ body: testhelper.LoadFile(t, "testdata/file.pdf"),
+ },
+ {
+ desc: "Application type pdf with inline disposition",
+ contentType: "application/pdf",
+ contentDisposition: "inline",
+ body: testhelper.LoadFile(t, "testdata/file.pdf"),
+ },
+ {
+ desc: "Application executable type",
+ contentType: "application/octet-stream",
+ contentDisposition: "attachment",
+ body: testhelper.LoadFile(t, "testdata/file.swf"),
+ },
+ {
+ desc: "Video type",
+ contentType: "video/mp4",
+ contentDisposition: "inline",
+ body: testhelper.LoadFile(t, "testdata/video.mp4"),
+ },
+ {
+ desc: "Audio type",
+ contentType: "audio/mpeg",
+ contentDisposition: "attachment",
+ body: testhelper.LoadFile(t, "testdata/audio.mp3"),
+ },
+ {
+ desc: "JSON type",
+ contentType: "text/plain; charset=utf-8",
+ contentDisposition: "inline",
+ body: "{ \"glossary\": { \"title\": \"example glossary\", \"GlossDiv\": { \"title\": \"S\" } } }",
+ },
+ {
+ desc: "Forged file with png extension but SWF content",
+ contentType: "application/octet-stream",
+ contentDisposition: "attachment",
+ body: testhelper.LoadFile(t, "testdata/forgedfile.png"),
+ },
+ {
+ desc: "BMPR file",
+ contentType: "application/octet-stream",
+ contentDisposition: "attachment",
+ body: testhelper.LoadFile(t, "testdata/file.bmpr"),
+ },
+ {
+ desc: "STL file",
+ contentType: "application/octet-stream",
+ contentDisposition: "attachment",
+ body: testhelper.LoadFile(t, "testdata/file.stl"),
+ },
+ {
+ desc: "RDoc file",
+ contentType: "text/plain; charset=utf-8",
+ contentDisposition: "inline",
+ body: testhelper.LoadFile(t, "testdata/file.rdoc"),
+ },
+ {
+ desc: "IPYNB file",
+ contentType: "text/plain; charset=utf-8",
+ contentDisposition: "inline",
+ body: testhelper.LoadFile(t, "testdata/file.ipynb"),
+ },
+ {
+ desc: "Sketch file",
+ contentType: "application/zip",
+ contentDisposition: "attachment",
+ body: testhelper.LoadFile(t, "testdata/file.sketch"),
+ },
+ {
+ desc: "PDF file with non-ASCII characters in filename",
+ contentType: "application/pdf",
+ contentDisposition: `attachment; filename="file-ä.pdf"; filename*=UTF-8''file-%c3.pdf`,
+ body: testhelper.LoadFile(t, "testdata/file-ä.pdf"),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ resp := makeRequest(t, nil, tc.body, tc.contentDisposition)
+
+ require.Equal(t, tc.contentType, resp.Header.Get(headers.ContentTypeHeader))
+ require.Equal(t, tc.contentDisposition, resp.Header.Get(headers.ContentDispositionHeader))
+ })
+ }
+}
+
+func TestFailOverrideContentType(t *testing.T) {
+ testCase := struct {
+ contentType string
+ body string
+ }{
+ contentType: "text/plain; charset=utf-8",
+ body: "<html><body>Hello world!</body></html>",
+ }
+
+ h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ // We are pretending to be upstream or an inner layer of the ResponseWriter chain
+ w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true")
+ w.Header().Set(headers.ContentTypeHeader, "text/html; charset=utf-8")
+ _, err := io.WriteString(w, testCase.body)
+ require.NoError(t, err)
+ })
+
+ resp := makeRequest(t, h, testCase.body, "")
+
+ require.Equal(t, testCase.contentType, resp.Header.Get(headers.ContentTypeHeader))
+}
+
+func TestSuccessOverrideContentDispositionFromInlineToAttachment(t *testing.T) {
+ testCaseBody := "Hello world!"
+
+ h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ // We are pretending to be upstream or an inner layer of the ResponseWriter chain
+ w.Header().Set(headers.ContentDispositionHeader, "attachment")
+ w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true")
+ _, err := io.WriteString(w, testCaseBody)
+ require.NoError(t, err)
+ })
+
+ resp := makeRequest(t, h, testCaseBody, "")
+
+ require.Equal(t, "attachment", resp.Header.Get(headers.ContentDispositionHeader))
+}
+
+func TestInlineContentDispositionForPdfFiles(t *testing.T) {
+ testCaseBody := testhelper.LoadFile(t, "testdata/file.pdf")
+
+ h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ // We are pretending to be upstream or an inner layer of the ResponseWriter chain
+ w.Header().Set(headers.ContentDispositionHeader, "inline")
+ w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true")
+ _, err := io.WriteString(w, testCaseBody)
+ require.NoError(t, err)
+ })
+
+ resp := makeRequest(t, h, testCaseBody, "")
+
+ require.Equal(t, "inline", resp.Header.Get(headers.ContentDispositionHeader))
+}
+
+func TestFailOverrideContentDispositionFromAttachmentToInline(t *testing.T) {
+ testCaseBody := testhelper.LoadFile(t, "testdata/image.svg")
+
+ h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ // We are pretending to be upstream or an inner layer of the ResponseWriter chain
+ w.Header().Set(headers.ContentDispositionHeader, "inline")
+ w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true")
+ _, err := io.WriteString(w, testCaseBody)
+ require.NoError(t, err)
+ })
+
+ resp := makeRequest(t, h, testCaseBody, "")
+
+ require.Equal(t, "attachment", resp.Header.Get(headers.ContentDispositionHeader))
+}
+
+func TestHeadersDelete(t *testing.T) {
+ for _, code := range []int{200, 400} {
+ recorder := httptest.NewRecorder()
+ rw := &contentDisposition{rw: recorder}
+ for _, name := range headers.ResponseHeaders {
+ rw.Header().Set(name, "foobar")
+ }
+
+ rw.WriteHeader(code)
+
+ for _, name := range headers.ResponseHeaders {
+ if header := recorder.Header().Get(name); header != "" {
+ t.Fatalf("HTTP %d response: expected header to be empty, found %q", code, name)
+ }
+ }
+ }
+}
+
+func TestWriteHeadersCalledOnce(t *testing.T) {
+ recorder := httptest.NewRecorder()
+ rw := &contentDisposition{rw: recorder}
+ rw.WriteHeader(400)
+ require.Equal(t, 400, rw.status)
+ require.Equal(t, true, rw.sentStatus)
+
+ rw.WriteHeader(200)
+ require.Equal(t, 400, rw.status)
+}
+
+func makeRequest(t *testing.T, handler http.HandlerFunc, body string, disposition string) *http.Response {
+ if handler == nil {
+ handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ // We are pretending to be upstream
+ w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true")
+ w.Header().Set(headers.ContentDispositionHeader, disposition)
+ _, err := io.WriteString(w, body)
+ require.NoError(t, err)
+ })
+ }
+ req, _ := http.NewRequest("GET", "/", nil)
+
+ rw := httptest.NewRecorder()
+ SetContentHeaders(handler).ServeHTTP(rw, req)
+
+ resp := rw.Result()
+ respBody, err := ioutil.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ require.Equal(t, body, string(respBody))
+
+ return resp
+}
diff --git a/workhorse/internal/senddata/injecter.go b/workhorse/internal/senddata/injecter.go
new file mode 100644
index 00000000000..d5739d2a053
--- /dev/null
+++ b/workhorse/internal/senddata/injecter.go
@@ -0,0 +1,35 @@
+package senddata
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "net/http"
+ "strings"
+)
+
+type Injecter interface {
+ Match(string) bool
+ Inject(http.ResponseWriter, *http.Request, string)
+ Name() string
+}
+
+type Prefix string
+
+func (p Prefix) Match(s string) bool {
+ return strings.HasPrefix(s, string(p))
+}
+
+func (p Prefix) Unpack(result interface{}, sendData string) error {
+ jsonBytes, err := base64.URLEncoding.DecodeString(strings.TrimPrefix(sendData, string(p)))
+ if err != nil {
+ return err
+ }
+ if err := json.Unmarshal([]byte(jsonBytes), result); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (p Prefix) Name() string {
+ return strings.TrimSuffix(string(p), ":")
+}
diff --git a/workhorse/internal/senddata/senddata.go b/workhorse/internal/senddata/senddata.go
new file mode 100644
index 00000000000..c287d2574fa
--- /dev/null
+++ b/workhorse/internal/senddata/senddata.go
@@ -0,0 +1,105 @@
+package senddata
+
+import (
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata/contentprocessor"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+)
+
+var (
+ sendDataResponses = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_senddata_responses",
+ Help: "How many HTTP responses have been hijacked by a workhorse senddata injecter",
+ },
+ []string{"injecter"},
+ )
+ sendDataResponseBytes = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_senddata_response_bytes",
+ Help: "How many bytes have been written by workhorse senddata response injecters",
+ },
+ []string{"injecter"},
+ )
+)
+
+type sendDataResponseWriter struct {
+ rw http.ResponseWriter
+ status int
+ hijacked bool
+ req *http.Request
+ injecters []Injecter
+}
+
+func SendData(h http.Handler, injecters ...Injecter) http.Handler {
+ return contentprocessor.SetContentHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ s := sendDataResponseWriter{
+ rw: w,
+ req: r,
+ injecters: injecters,
+ }
+ defer s.flush()
+ h.ServeHTTP(&s, r)
+ }))
+}
+
+func (s *sendDataResponseWriter) Header() http.Header {
+ return s.rw.Header()
+}
+
+func (s *sendDataResponseWriter) Write(data []byte) (int, error) {
+ if s.status == 0 {
+ s.WriteHeader(http.StatusOK)
+ }
+ if s.hijacked {
+ return len(data), nil
+ }
+ return s.rw.Write(data)
+}
+
+func (s *sendDataResponseWriter) WriteHeader(status int) {
+ if s.status != 0 {
+ return
+ }
+ s.status = status
+
+ if s.status == http.StatusOK && s.tryInject() {
+ return
+ }
+
+ s.rw.WriteHeader(s.status)
+}
+
+func (s *sendDataResponseWriter) tryInject() bool {
+ if s.hijacked {
+ return false
+ }
+
+ header := s.Header().Get(headers.GitlabWorkhorseSendDataHeader)
+ if header == "" {
+ return false
+ }
+
+ for _, injecter := range s.injecters {
+ if injecter.Match(header) {
+ s.hijacked = true
+ helper.DisableResponseBuffering(s.rw)
+ crw := helper.NewCountingResponseWriter(s.rw)
+ injecter.Inject(crw, s.req, header)
+ sendDataResponses.WithLabelValues(injecter.Name()).Inc()
+ sendDataResponseBytes.WithLabelValues(injecter.Name()).Add(float64(crw.Count()))
+ return true
+ }
+ }
+
+ return false
+}
+
+func (s *sendDataResponseWriter) flush() {
+ s.WriteHeader(http.StatusOK)
+}
diff --git a/workhorse/internal/senddata/writer_test.go b/workhorse/internal/senddata/writer_test.go
new file mode 100644
index 00000000000..1262acd5472
--- /dev/null
+++ b/workhorse/internal/senddata/writer_test.go
@@ -0,0 +1,71 @@
+package senddata
+
+import (
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers"
+)
+
+func TestWriter(t *testing.T) {
+ upstreamResponse := "hello world"
+
+ testCases := []struct {
+ desc string
+ headerValue string
+ out string
+ }{
+ {
+ desc: "inject",
+ headerValue: testInjecterName + ":" + testInjecterName,
+ out: testInjecterData,
+ },
+ {
+ desc: "pass",
+ headerValue: "",
+ out: upstreamResponse,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ recorder := httptest.NewRecorder()
+ rw := &sendDataResponseWriter{rw: recorder, injecters: []Injecter{&testInjecter{}}}
+
+ rw.Header().Set(headers.GitlabWorkhorseSendDataHeader, tc.headerValue)
+
+ n, err := rw.Write([]byte(upstreamResponse))
+ require.NoError(t, err)
+ require.Equal(t, len(upstreamResponse), n, "bytes written")
+
+ recorder.Flush()
+
+ body := recorder.Result().Body
+ data, err := ioutil.ReadAll(body)
+ require.NoError(t, err)
+ require.NoError(t, body.Close())
+
+ require.Equal(t, tc.out, string(data))
+ })
+ }
+}
+
+const (
+ testInjecterName = "test-injecter"
+ testInjecterData = "hello this is injected data"
+)
+
+type testInjecter struct{}
+
+func (ti *testInjecter) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
+ io.WriteString(w, testInjecterData)
+}
+
+func (ti *testInjecter) Match(s string) bool { return strings.HasPrefix(s, testInjecterName+":") }
+func (ti *testInjecter) Name() string { return testInjecterName }
diff --git a/workhorse/internal/sendfile/sendfile.go b/workhorse/internal/sendfile/sendfile.go
new file mode 100644
index 00000000000..d009f216eb9
--- /dev/null
+++ b/workhorse/internal/sendfile/sendfile.go
@@ -0,0 +1,162 @@
+/*
+The xSendFile middleware transparently sends static files in HTTP responses
+via the X-Sendfile mechanism. All that is needed in the Rails code is the
+'send_file' method.
+*/
+
+package sendfile
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "regexp"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/mask"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+var (
+ sendFileRequests = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_sendfile_requests",
+ Help: "How many X-Sendfile requests have been processed by gitlab-workhorse, partitioned by sendfile type.",
+ },
+ []string{"type"},
+ )
+
+ sendFileBytes = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_sendfile_bytes",
+ Help: "How many X-Sendfile bytes have been sent by gitlab-workhorse, partitioned by sendfile type.",
+ },
+ []string{"type"},
+ )
+
+ artifactsSendFile = regexp.MustCompile("builds/[0-9]+/artifacts")
+)
+
+type sendFileResponseWriter struct {
+ rw http.ResponseWriter
+ status int
+ hijacked bool
+ req *http.Request
+}
+
+func SendFile(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ s := &sendFileResponseWriter{
+ rw: rw,
+ req: req,
+ }
+ // Advertise to upstream (Rails) that we support X-Sendfile
+ req.Header.Set(headers.XSendFileTypeHeader, headers.XSendFileHeader)
+ defer s.flush()
+ h.ServeHTTP(s, req)
+ })
+}
+
+func (s *sendFileResponseWriter) Header() http.Header {
+ return s.rw.Header()
+}
+
+func (s *sendFileResponseWriter) Write(data []byte) (int, error) {
+ if s.status == 0 {
+ s.WriteHeader(http.StatusOK)
+ }
+ if s.hijacked {
+ return len(data), nil
+ }
+ return s.rw.Write(data)
+}
+
+func (s *sendFileResponseWriter) WriteHeader(status int) {
+ if s.status != 0 {
+ return
+ }
+
+ s.status = status
+ if s.status != http.StatusOK {
+ s.rw.WriteHeader(s.status)
+ return
+ }
+
+ file := s.Header().Get(headers.XSendFileHeader)
+ if file != "" && !s.hijacked {
+ // Mark this connection as hijacked
+ s.hijacked = true
+
+ // Serve the file
+ helper.DisableResponseBuffering(s.rw)
+ sendFileFromDisk(s.rw, s.req, file)
+ return
+ }
+
+ s.rw.WriteHeader(s.status)
+}
+
+func sendFileFromDisk(w http.ResponseWriter, r *http.Request, file string) {
+ log.WithContextFields(r.Context(), log.Fields{
+ "file": file,
+ "method": r.Method,
+ "uri": mask.URL(r.RequestURI),
+ }).Print("Send file")
+
+ contentTypeHeaderPresent := false
+
+ if headers.IsDetectContentTypeHeaderPresent(w) {
+ // Removing the GitlabWorkhorseDetectContentTypeHeader header to
+ // avoid handling the response by the senddata handler
+ w.Header().Del(headers.GitlabWorkhorseDetectContentTypeHeader)
+ contentTypeHeaderPresent = true
+ }
+
+ content, fi, err := helper.OpenFile(file)
+ if err != nil {
+ http.NotFound(w, r)
+ return
+ }
+ defer content.Close()
+
+ countSendFileMetrics(fi.Size(), r)
+
+ if contentTypeHeaderPresent {
+ data, err := ioutil.ReadAll(io.LimitReader(content, headers.MaxDetectSize))
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("content type detection: %v", err))
+ return
+ }
+
+ content.Seek(0, io.SeekStart)
+
+ contentType, contentDisposition := headers.SafeContentHeaders(data, w.Header().Get(headers.ContentDispositionHeader))
+ w.Header().Set(headers.ContentTypeHeader, contentType)
+ w.Header().Set(headers.ContentDispositionHeader, contentDisposition)
+ }
+
+ http.ServeContent(w, r, "", fi.ModTime(), content)
+}
+
+func countSendFileMetrics(size int64, r *http.Request) {
+ var requestType string
+ switch {
+ case artifactsSendFile.MatchString(r.RequestURI):
+ requestType = "artifacts"
+ default:
+ requestType = "other"
+ }
+
+ sendFileRequests.WithLabelValues(requestType).Inc()
+ sendFileBytes.WithLabelValues(requestType).Add(float64(size))
+}
+
+func (s *sendFileResponseWriter) flush() {
+ s.WriteHeader(http.StatusOK)
+}
diff --git a/workhorse/internal/sendfile/sendfile_test.go b/workhorse/internal/sendfile/sendfile_test.go
new file mode 100644
index 00000000000..d424814b5e5
--- /dev/null
+++ b/workhorse/internal/sendfile/sendfile_test.go
@@ -0,0 +1,171 @@
+package sendfile
+
+import (
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers"
+)
+
+func TestResponseWriter(t *testing.T) {
+ upstreamResponse := "hello world"
+
+ fixturePath := "testdata/sent-file.txt"
+ fixtureContent, err := ioutil.ReadFile(fixturePath)
+ require.NoError(t, err)
+
+ testCases := []struct {
+ desc string
+ sendfileHeader string
+ out string
+ }{
+ {
+ desc: "send a file",
+ sendfileHeader: fixturePath,
+ out: string(fixtureContent),
+ },
+ {
+ desc: "pass through unaltered",
+ sendfileHeader: "",
+ out: upstreamResponse,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ r, err := http.NewRequest("GET", "/foo", nil)
+ require.NoError(t, err)
+
+ rw := httptest.NewRecorder()
+ sf := &sendFileResponseWriter{rw: rw, req: r}
+ sf.Header().Set(headers.XSendFileHeader, tc.sendfileHeader)
+
+ upstreamBody := []byte(upstreamResponse)
+ n, err := sf.Write(upstreamBody)
+ require.NoError(t, err)
+ require.Equal(t, len(upstreamBody), n, "bytes written")
+
+ rw.Flush()
+
+ body := rw.Result().Body
+ data, err := ioutil.ReadAll(body)
+ require.NoError(t, err)
+ require.NoError(t, body.Close())
+
+ require.Equal(t, tc.out, string(data))
+ })
+ }
+}
+
+func TestAllowExistentContentHeaders(t *testing.T) {
+ fixturePath := "../../testdata/forgedfile.png"
+
+ httpHeaders := map[string]string{
+ headers.ContentTypeHeader: "image/png",
+ headers.ContentDispositionHeader: "inline",
+ }
+
+ resp := makeRequest(t, fixturePath, httpHeaders)
+ require.Equal(t, "image/png", resp.Header.Get(headers.ContentTypeHeader))
+ require.Equal(t, "inline", resp.Header.Get(headers.ContentDispositionHeader))
+}
+
+func TestSuccessOverrideContentHeadersFeatureEnabled(t *testing.T) {
+ fixturePath := "../../testdata/forgedfile.png"
+
+ httpHeaders := make(map[string]string)
+ httpHeaders[headers.ContentTypeHeader] = "image/png"
+ httpHeaders[headers.ContentDispositionHeader] = "inline"
+ httpHeaders["Range"] = "bytes=1-2"
+
+ resp := makeRequest(t, fixturePath, httpHeaders)
+ require.Equal(t, "image/png", resp.Header.Get(headers.ContentTypeHeader))
+ require.Equal(t, "inline", resp.Header.Get(headers.ContentDispositionHeader))
+}
+
+func TestSuccessOverrideContentHeadersRangeRequestFeatureEnabled(t *testing.T) {
+ fixturePath := "../../testdata/forgedfile.png"
+
+ fixtureContent, err := ioutil.ReadFile(fixturePath)
+ require.NoError(t, err)
+
+ r, err := http.NewRequest("GET", "/foo", nil)
+ r.Header.Set("Range", "bytes=1-2")
+ require.NoError(t, err)
+
+ rw := httptest.NewRecorder()
+ sf := &sendFileResponseWriter{rw: rw, req: r}
+
+ sf.Header().Set(headers.XSendFileHeader, fixturePath)
+ sf.Header().Set(headers.ContentTypeHeader, "image/png")
+ sf.Header().Set(headers.ContentDispositionHeader, "inline")
+ sf.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true")
+
+ upstreamBody := []byte(fixtureContent)
+ _, err = sf.Write(upstreamBody)
+ require.NoError(t, err)
+
+ rw.Flush()
+
+ resp := rw.Result()
+ body := resp.Body
+ data, err := ioutil.ReadAll(body)
+ require.NoError(t, err)
+ require.NoError(t, body.Close())
+
+ require.Len(t, data, 2)
+
+ require.Equal(t, "application/octet-stream", resp.Header.Get(headers.ContentTypeHeader))
+ require.Equal(t, "attachment", resp.Header.Get(headers.ContentDispositionHeader))
+}
+
+func TestSuccessInlineWhitelistedTypesFeatureEnabled(t *testing.T) {
+ fixturePath := "../../testdata/image.png"
+
+ httpHeaders := map[string]string{
+ headers.ContentDispositionHeader: "inline",
+ headers.GitlabWorkhorseDetectContentTypeHeader: "true",
+ }
+
+ resp := makeRequest(t, fixturePath, httpHeaders)
+
+ require.Equal(t, "image/png", resp.Header.Get(headers.ContentTypeHeader))
+ require.Equal(t, "inline", resp.Header.Get(headers.ContentDispositionHeader))
+}
+
+func makeRequest(t *testing.T, fixturePath string, httpHeaders map[string]string) *http.Response {
+ fixtureContent, err := ioutil.ReadFile(fixturePath)
+ require.NoError(t, err)
+
+ r, err := http.NewRequest("GET", "/foo", nil)
+ require.NoError(t, err)
+
+ rw := httptest.NewRecorder()
+ sf := &sendFileResponseWriter{rw: rw, req: r}
+
+ sf.Header().Set(headers.XSendFileHeader, fixturePath)
+ for name, value := range httpHeaders {
+ sf.Header().Set(name, value)
+ }
+
+ upstreamBody := []byte("hello")
+ n, err := sf.Write(upstreamBody)
+ require.NoError(t, err)
+ require.Equal(t, len(upstreamBody), n, "bytes written")
+
+ rw.Flush()
+
+ resp := rw.Result()
+ body := resp.Body
+ data, err := ioutil.ReadAll(body)
+ require.NoError(t, err)
+ require.NoError(t, body.Close())
+
+ require.Equal(t, fixtureContent, data)
+
+ return resp
+}
diff --git a/workhorse/internal/sendfile/testdata/sent-file.txt b/workhorse/internal/sendfile/testdata/sent-file.txt
new file mode 100644
index 00000000000..40e33f8a628
--- /dev/null
+++ b/workhorse/internal/sendfile/testdata/sent-file.txt
@@ -0,0 +1 @@
+This file is sent with X-SendFile
diff --git a/workhorse/internal/sendurl/sendurl.go b/workhorse/internal/sendurl/sendurl.go
new file mode 100644
index 00000000000..cf3d14a2bf0
--- /dev/null
+++ b/workhorse/internal/sendurl/sendurl.go
@@ -0,0 +1,167 @@
+package sendurl
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+
+ "gitlab.com/gitlab-org/labkit/correlation"
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/mask"
+ "gitlab.com/gitlab-org/labkit/tracing"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
+)
+
+type entry struct{ senddata.Prefix }
+
+type entryParams struct {
+ URL string
+ AllowRedirects bool
+}
+
+var SendURL = &entry{"send-url:"}
+
+var rangeHeaderKeys = []string{
+ "If-Match",
+ "If-Unmodified-Since",
+ "If-None-Match",
+ "If-Modified-Since",
+ "If-Range",
+ "Range",
+}
+
+// Keep cache headers from the original response, not the proxied response. The
+// original response comes from the Rails application, which should be the
+// source of truth for caching.
+var preserveHeaderKeys = map[string]bool{
+ "Cache-Control": true,
+ "Expires": true,
+ "Date": true, // Support for HTTP 1.0 proxies
+ "Pragma": true, // Support for HTTP 1.0 proxies
+}
+
+// httpTransport defines a http.Transport with values
+// that are more restrictive than for http.DefaultTransport,
+// they define shorter TLS Handshake, and more aggressive connection closing
+// to prevent the connection hanging and reduce FD usage
+var httpTransport = tracing.NewRoundTripper(correlation.NewInstrumentedRoundTripper(&http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ DialContext: (&net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 10 * time.Second,
+ }).DialContext,
+ MaxIdleConns: 2,
+ IdleConnTimeout: 30 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ExpectContinueTimeout: 10 * time.Second,
+ ResponseHeaderTimeout: 30 * time.Second,
+}))
+
+var httpClient = &http.Client{
+ Transport: httpTransport,
+}
+
+var (
+ sendURLRequests = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_send_url_requests",
+ Help: "How many send URL requests have been processed",
+ },
+ []string{"status"},
+ )
+ sendURLOpenRequests = promauto.NewGauge(
+ prometheus.GaugeOpts{
+ Name: "gitlab_workhorse_send_url_open_requests",
+ Help: "Describes how many send URL requests are open now",
+ },
+ )
+ sendURLBytes = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_send_url_bytes",
+ Help: "How many bytes were passed with send URL",
+ },
+ )
+
+ sendURLRequestsInvalidData = sendURLRequests.WithLabelValues("invalid-data")
+ sendURLRequestsRequestFailed = sendURLRequests.WithLabelValues("request-failed")
+ sendURLRequestsSucceeded = sendURLRequests.WithLabelValues("succeeded")
+)
+
+func (e *entry) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
+ var params entryParams
+
+ sendURLOpenRequests.Inc()
+ defer sendURLOpenRequests.Dec()
+
+ if err := e.Unpack(&params, sendData); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("SendURL: unpack sendData: %v", err))
+ return
+ }
+
+ log.WithContextFields(r.Context(), log.Fields{
+ "url": mask.URL(params.URL),
+ "path": r.URL.Path,
+ }).Info("SendURL: sending")
+
+ if params.URL == "" {
+ sendURLRequestsInvalidData.Inc()
+ helper.Fail500(w, r, fmt.Errorf("SendURL: URL is empty"))
+ return
+ }
+
+ // create new request and copy range headers
+ newReq, err := http.NewRequest("GET", params.URL, nil)
+ if err != nil {
+ sendURLRequestsInvalidData.Inc()
+ helper.Fail500(w, r, fmt.Errorf("SendURL: NewRequest: %v", err))
+ return
+ }
+ newReq = newReq.WithContext(r.Context())
+
+ for _, header := range rangeHeaderKeys {
+ newReq.Header[header] = r.Header[header]
+ }
+
+ // execute new request
+ var resp *http.Response
+ if params.AllowRedirects {
+ resp, err = httpClient.Do(newReq)
+ } else {
+ resp, err = httpTransport.RoundTrip(newReq)
+ }
+ if err != nil {
+ sendURLRequestsRequestFailed.Inc()
+ helper.Fail500(w, r, fmt.Errorf("SendURL: Do request: %v", err))
+ return
+ }
+
+ // Prevent Go from adding a Content-Length header automatically
+ w.Header().Del("Content-Length")
+
+ // copy response headers and body, except the headers from preserveHeaderKeys
+ for key, value := range resp.Header {
+ if !preserveHeaderKeys[key] {
+ w.Header()[key] = value
+ }
+ }
+ w.WriteHeader(resp.StatusCode)
+
+ defer resp.Body.Close()
+ n, err := io.Copy(w, resp.Body)
+ sendURLBytes.Add(float64(n))
+
+ if err != nil {
+ sendURLRequestsRequestFailed.Inc()
+ helper.LogError(r, fmt.Errorf("SendURL: Copy response: %v", err))
+ return
+ }
+
+ sendURLRequestsSucceeded.Inc()
+}
diff --git a/workhorse/internal/sendurl/sendurl_test.go b/workhorse/internal/sendurl/sendurl_test.go
new file mode 100644
index 00000000000..41e1dbb8e0f
--- /dev/null
+++ b/workhorse/internal/sendurl/sendurl_test.go
@@ -0,0 +1,197 @@
+package sendurl
+
+import (
+ "encoding/base64"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+)
+
+const testData = `123456789012345678901234567890`
+const testDataEtag = `W/"myetag"`
+
+func testEntryServer(t *testing.T, requestURL string, httpHeaders http.Header, allowRedirects bool) *httptest.ResponseRecorder {
+ requestHandler := func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "GET", r.Method)
+
+ url := r.URL.String() + "/file"
+ jsonParams := fmt.Sprintf(`{"URL":%q,"AllowRedirects":%s}`,
+ url, strconv.FormatBool(allowRedirects))
+ data := base64.URLEncoding.EncodeToString([]byte(jsonParams))
+
+ // The server returns a Content-Disposition
+ w.Header().Set("Content-Disposition", "attachment; filename=\"archive.txt\"")
+ w.Header().Set("Cache-Control", "no-cache")
+ w.Header().Set("Expires", "")
+ w.Header().Set("Date", "Wed, 21 Oct 2015 05:28:00 GMT")
+ w.Header().Set("Pragma", "no-cache")
+
+ SendURL.Inject(w, r, data)
+ }
+ serveFile := func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "GET", r.Method)
+
+ tempFile, err := ioutil.TempFile("", "download_file")
+ require.NoError(t, err)
+ require.NoError(t, os.Remove(tempFile.Name()))
+ defer tempFile.Close()
+ _, err = tempFile.Write([]byte(testData))
+ require.NoError(t, err)
+
+ w.Header().Set("Etag", testDataEtag)
+ w.Header().Set("Cache-Control", "public")
+ w.Header().Set("Expires", "Wed, 21 Oct 2015 07:28:00 GMT")
+ w.Header().Set("Date", "Wed, 21 Oct 2015 06:28:00 GMT")
+ w.Header().Set("Pragma", "")
+
+ http.ServeContent(w, r, "archive.txt", time.Now(), tempFile)
+ }
+ redirectFile := func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "GET", r.Method)
+ http.Redirect(w, r, r.URL.String()+"/download", http.StatusTemporaryRedirect)
+ }
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/get/request", requestHandler)
+ mux.HandleFunc("/get/request/file", serveFile)
+ mux.HandleFunc("/get/redirect", requestHandler)
+ mux.HandleFunc("/get/redirect/file", redirectFile)
+ mux.HandleFunc("/get/redirect/file/download", serveFile)
+ mux.HandleFunc("/get/file-not-existing", requestHandler)
+
+ server := httptest.NewServer(mux)
+ defer server.Close()
+
+ httpRequest, err := http.NewRequest("GET", server.URL+requestURL, nil)
+ require.NoError(t, err)
+ if httpHeaders != nil {
+ httpRequest.Header = httpHeaders
+ }
+
+ response := httptest.NewRecorder()
+ mux.ServeHTTP(response, httpRequest)
+ return response
+}
+
+func TestDownloadingUsingSendURL(t *testing.T) {
+ response := testEntryServer(t, "/get/request", nil, false)
+ require.Equal(t, http.StatusOK, response.Code)
+
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Type",
+ "text/plain; charset=utf-8")
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Disposition",
+ "attachment; filename=\"archive.txt\"")
+
+ testhelper.RequireResponseBody(t, response, testData)
+}
+
+func TestDownloadingAChunkOfDataWithSendURL(t *testing.T) {
+ httpHeaders := http.Header{
+ "Range": []string{
+ "bytes=1-2",
+ },
+ }
+
+ response := testEntryServer(t, "/get/request", httpHeaders, false)
+ require.Equal(t, http.StatusPartialContent, response.Code)
+
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Type",
+ "text/plain; charset=utf-8")
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Disposition",
+ "attachment; filename=\"archive.txt\"")
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Range",
+ "bytes 1-2/30")
+
+ testhelper.RequireResponseBody(t, response, "23")
+}
+
+func TestAccessingAlreadyDownloadedFileWithSendURL(t *testing.T) {
+ httpHeaders := http.Header{
+ "If-None-Match": []string{testDataEtag},
+ }
+
+ response := testEntryServer(t, "/get/request", httpHeaders, false)
+ require.Equal(t, http.StatusNotModified, response.Code)
+}
+
+func TestAccessingRedirectWithSendURL(t *testing.T) {
+ response := testEntryServer(t, "/get/redirect", nil, false)
+ require.Equal(t, http.StatusTemporaryRedirect, response.Code)
+}
+
+func TestAccessingAllowedRedirectWithSendURL(t *testing.T) {
+ response := testEntryServer(t, "/get/redirect", nil, true)
+ require.Equal(t, http.StatusOK, response.Code)
+
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Type",
+ "text/plain; charset=utf-8")
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Disposition",
+ "attachment; filename=\"archive.txt\"")
+}
+
+func TestAccessingAllowedRedirectWithChunkOfDataWithSendURL(t *testing.T) {
+ httpHeaders := http.Header{
+ "Range": []string{
+ "bytes=1-2",
+ },
+ }
+
+ response := testEntryServer(t, "/get/redirect", httpHeaders, true)
+ require.Equal(t, http.StatusPartialContent, response.Code)
+
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Type",
+ "text/plain; charset=utf-8")
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Disposition",
+ "attachment; filename=\"archive.txt\"")
+ testhelper.RequireResponseHeader(t, response,
+ "Content-Range",
+ "bytes 1-2/30")
+
+ testhelper.RequireResponseBody(t, response, "23")
+}
+
+func TestOriginalCacheHeadersPreservedWithSendURL(t *testing.T) {
+ response := testEntryServer(t, "/get/redirect", nil, true)
+ require.Equal(t, http.StatusOK, response.Code)
+
+ testhelper.RequireResponseHeader(t, response,
+ "Cache-Control",
+ "no-cache")
+ testhelper.RequireResponseHeader(t, response,
+ "Expires",
+ "")
+ testhelper.RequireResponseHeader(t, response,
+ "Date",
+ "Wed, 21 Oct 2015 05:28:00 GMT")
+ testhelper.RequireResponseHeader(t, response,
+ "Pragma",
+ "no-cache")
+}
+
+func TestDownloadingNonExistingFileUsingSendURL(t *testing.T) {
+ response := testEntryServer(t, "/invalid/path", nil, false)
+ require.Equal(t, http.StatusNotFound, response.Code)
+}
+
+func TestDownloadingNonExistingRemoteFileWithSendURL(t *testing.T) {
+ response := testEntryServer(t, "/get/file-not-existing", nil, false)
+ require.Equal(t, http.StatusNotFound, response.Code)
+}
diff --git a/workhorse/internal/staticpages/deploy_page.go b/workhorse/internal/staticpages/deploy_page.go
new file mode 100644
index 00000000000..d08ed449ae6
--- /dev/null
+++ b/workhorse/internal/staticpages/deploy_page.go
@@ -0,0 +1,26 @@
+package staticpages
+
+import (
+ "io/ioutil"
+ "net/http"
+ "path/filepath"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+func (s *Static) DeployPage(handler http.Handler) http.Handler {
+ deployPage := filepath.Join(s.DocumentRoot, "index.html")
+
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ data, err := ioutil.ReadFile(deployPage)
+ if err != nil {
+ handler.ServeHTTP(w, r)
+ return
+ }
+
+ helper.SetNoCacheHeaders(w.Header())
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ w.WriteHeader(http.StatusOK)
+ w.Write(data)
+ })
+}
diff --git a/workhorse/internal/staticpages/deploy_page_test.go b/workhorse/internal/staticpages/deploy_page_test.go
new file mode 100644
index 00000000000..4b081e73a97
--- /dev/null
+++ b/workhorse/internal/staticpages/deploy_page_test.go
@@ -0,0 +1,59 @@
+package staticpages
+
+import (
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestIfNoDeployPageExist(t *testing.T) {
+ dir, err := ioutil.TempDir("", "deploy")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(dir)
+
+ w := httptest.NewRecorder()
+
+ executed := false
+ st := &Static{dir}
+ st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
+ executed = true
+ })).ServeHTTP(w, nil)
+ if !executed {
+ t.Error("The handler should get executed")
+ }
+}
+
+func TestIfDeployPageExist(t *testing.T) {
+ dir, err := ioutil.TempDir("", "deploy")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(dir)
+
+ deployPage := "DEPLOY"
+ ioutil.WriteFile(filepath.Join(dir, "index.html"), []byte(deployPage), 0600)
+
+ w := httptest.NewRecorder()
+
+ executed := false
+ st := &Static{dir}
+ st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
+ executed = true
+ })).ServeHTTP(w, nil)
+ if executed {
+ t.Error("The handler should not get executed")
+ }
+ w.Flush()
+
+ require.Equal(t, 200, w.Code)
+ testhelper.RequireResponseBody(t, w, deployPage)
+}
diff --git a/workhorse/internal/staticpages/error_pages.go b/workhorse/internal/staticpages/error_pages.go
new file mode 100644
index 00000000000..3cc89d9f811
--- /dev/null
+++ b/workhorse/internal/staticpages/error_pages.go
@@ -0,0 +1,138 @@
+package staticpages
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "path/filepath"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+var (
+ staticErrorResponses = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_static_error_responses",
+ Help: "How many HTTP responses have been changed to a static error page, by HTTP status code.",
+ },
+ []string{"code"},
+ )
+)
+
+type ErrorFormat int
+
+const (
+ ErrorFormatHTML ErrorFormat = iota
+ ErrorFormatJSON
+ ErrorFormatText
+)
+
+type errorPageResponseWriter struct {
+ rw http.ResponseWriter
+ status int
+ hijacked bool
+ path string
+ format ErrorFormat
+}
+
+func (s *errorPageResponseWriter) Header() http.Header {
+ return s.rw.Header()
+}
+
+func (s *errorPageResponseWriter) Write(data []byte) (int, error) {
+ if s.status == 0 {
+ s.WriteHeader(http.StatusOK)
+ }
+ if s.hijacked {
+ return len(data), nil
+ }
+ return s.rw.Write(data)
+}
+
+func (s *errorPageResponseWriter) WriteHeader(status int) {
+ if s.status != 0 {
+ return
+ }
+
+ s.status = status
+
+ if s.status < 400 || s.status > 599 || s.rw.Header().Get("X-GitLab-Custom-Error") != "" {
+ s.rw.WriteHeader(status)
+ return
+ }
+
+ var contentType string
+ var data []byte
+ switch s.format {
+ case ErrorFormatText:
+ contentType, data = s.writeText()
+ case ErrorFormatJSON:
+ contentType, data = s.writeJSON()
+ default:
+ contentType, data = s.writeHTML()
+ }
+
+ if contentType == "" {
+ s.rw.WriteHeader(status)
+ return
+ }
+
+ s.hijacked = true
+ staticErrorResponses.WithLabelValues(fmt.Sprintf("%d", s.status)).Inc()
+
+ helper.SetNoCacheHeaders(s.rw.Header())
+ s.rw.Header().Set("Content-Type", contentType)
+ s.rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
+ s.rw.Header().Del("Transfer-Encoding")
+ s.rw.WriteHeader(s.status)
+ s.rw.Write(data)
+}
+
+func (s *errorPageResponseWriter) writeHTML() (string, []byte) {
+ if s.rw.Header().Get("Content-Type") != "application/json" {
+ errorPageFile := filepath.Join(s.path, fmt.Sprintf("%d.html", s.status))
+
+ // check if custom error page exists, serve this page instead
+ if data, err := ioutil.ReadFile(errorPageFile); err == nil {
+ return "text/html; charset=utf-8", data
+ }
+ }
+
+ return "", nil
+}
+
+func (s *errorPageResponseWriter) writeJSON() (string, []byte) {
+ message, err := json.Marshal(map[string]interface{}{"error": http.StatusText(s.status), "status": s.status})
+ if err != nil {
+ return "", nil
+ }
+
+ return "application/json; charset=utf-8", append(message, "\n"...)
+}
+
+func (s *errorPageResponseWriter) writeText() (string, []byte) {
+ return "text/plain; charset=utf-8", []byte(http.StatusText(s.status) + "\n")
+}
+
+func (s *errorPageResponseWriter) flush() {
+ s.WriteHeader(http.StatusOK)
+}
+
+func (st *Static) ErrorPagesUnless(disabled bool, format ErrorFormat, handler http.Handler) http.Handler {
+ if disabled {
+ return handler
+ }
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ rw := errorPageResponseWriter{
+ rw: w,
+ path: st.DocumentRoot,
+ format: format,
+ }
+ defer rw.flush()
+ handler.ServeHTTP(&rw, r)
+ })
+}
diff --git a/workhorse/internal/staticpages/error_pages_test.go b/workhorse/internal/staticpages/error_pages_test.go
new file mode 100644
index 00000000000..05ec06cd429
--- /dev/null
+++ b/workhorse/internal/staticpages/error_pages_test.go
@@ -0,0 +1,191 @@
+package staticpages
+
+import (
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+)
+
+func TestIfErrorPageIsPresented(t *testing.T) {
+ dir, err := ioutil.TempDir("", "error_page")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(dir)
+
+ errorPage := "ERROR"
+ ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600)
+
+ w := httptest.NewRecorder()
+ h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(404)
+ upstreamBody := "Not Found"
+ n, err := fmt.Fprint(w, upstreamBody)
+ require.NoError(t, err)
+ require.Equal(t, len(upstreamBody), n, "bytes written")
+ })
+ st := &Static{dir}
+ st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
+ w.Flush()
+
+ require.Equal(t, 404, w.Code)
+ testhelper.RequireResponseBody(t, w, errorPage)
+ testhelper.RequireResponseHeader(t, w, "Content-Type", "text/html; charset=utf-8")
+}
+
+func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
+ dir, err := ioutil.TempDir("", "error_page")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(dir)
+
+ w := httptest.NewRecorder()
+ errorResponse := "ERROR"
+ h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(404)
+ fmt.Fprint(w, errorResponse)
+ })
+ st := &Static{dir}
+ st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
+ w.Flush()
+
+ require.Equal(t, 404, w.Code)
+ testhelper.RequireResponseBody(t, w, errorResponse)
+}
+
+func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
+ dir, err := ioutil.TempDir("", "error_page")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(dir)
+
+ errorPage := "ERROR"
+ ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600)
+
+ w := httptest.NewRecorder()
+ serverError := "Interesting Server Error"
+ h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(500)
+ fmt.Fprint(w, serverError)
+ })
+ st := &Static{dir}
+ st.ErrorPagesUnless(true, ErrorFormatHTML, h).ServeHTTP(w, nil)
+ w.Flush()
+ require.Equal(t, 500, w.Code)
+ testhelper.RequireResponseBody(t, w, serverError)
+}
+
+func TestIfErrorPageIsIgnoredIfCustomError(t *testing.T) {
+ dir, err := ioutil.TempDir("", "error_page")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(dir)
+
+ errorPage := "ERROR"
+ ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600)
+
+ w := httptest.NewRecorder()
+ serverError := "Interesting Server Error"
+ h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Add("X-GitLab-Custom-Error", "1")
+ w.WriteHeader(500)
+ fmt.Fprint(w, serverError)
+ })
+ st := &Static{dir}
+ st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
+ w.Flush()
+ require.Equal(t, 500, w.Code)
+ testhelper.RequireResponseBody(t, w, serverError)
+}
+
+func TestErrorPageInterceptedByContentType(t *testing.T) {
+ testCases := []struct {
+ contentType string
+ intercepted bool
+ }{
+ {contentType: "application/json", intercepted: false},
+ {contentType: "text/plain", intercepted: true},
+ {contentType: "text/html", intercepted: true},
+ {contentType: "", intercepted: true},
+ }
+
+ for _, tc := range testCases {
+ dir, err := ioutil.TempDir("", "error_page")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(dir)
+
+ errorPage := "ERROR"
+ ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600)
+
+ w := httptest.NewRecorder()
+ serverError := "Interesting Server Error"
+ h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Add("Content-Type", tc.contentType)
+ w.WriteHeader(500)
+ fmt.Fprint(w, serverError)
+ })
+ st := &Static{dir}
+ st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
+ w.Flush()
+ require.Equal(t, 500, w.Code)
+
+ if tc.intercepted {
+ testhelper.RequireResponseBody(t, w, errorPage)
+ } else {
+ testhelper.RequireResponseBody(t, w, serverError)
+ }
+ }
+}
+
+func TestIfErrorPageIsPresentedJSON(t *testing.T) {
+ errorPage := "{\"error\":\"Not Found\",\"status\":404}\n"
+
+ w := httptest.NewRecorder()
+ h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(404)
+ upstreamBody := "This string is ignored"
+ n, err := fmt.Fprint(w, upstreamBody)
+ require.NoError(t, err)
+ require.Equal(t, len(upstreamBody), n, "bytes written")
+ })
+ st := &Static{""}
+ st.ErrorPagesUnless(false, ErrorFormatJSON, h).ServeHTTP(w, nil)
+ w.Flush()
+
+ require.Equal(t, 404, w.Code)
+ testhelper.RequireResponseBody(t, w, errorPage)
+ testhelper.RequireResponseHeader(t, w, "Content-Type", "application/json; charset=utf-8")
+}
+
+func TestIfErrorPageIsPresentedText(t *testing.T) {
+ errorPage := "Not Found\n"
+
+ w := httptest.NewRecorder()
+ h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(404)
+ upstreamBody := "This string is ignored"
+ n, err := fmt.Fprint(w, upstreamBody)
+ require.NoError(t, err)
+ require.Equal(t, len(upstreamBody), n, "bytes written")
+ })
+ st := &Static{""}
+ st.ErrorPagesUnless(false, ErrorFormatText, h).ServeHTTP(w, nil)
+ w.Flush()
+
+ require.Equal(t, 404, w.Code)
+ testhelper.RequireResponseBody(t, w, errorPage)
+ testhelper.RequireResponseHeader(t, w, "Content-Type", "text/plain; charset=utf-8")
+}
diff --git a/workhorse/internal/staticpages/servefile.go b/workhorse/internal/staticpages/servefile.go
new file mode 100644
index 00000000000..c98bc030bc2
--- /dev/null
+++ b/workhorse/internal/staticpages/servefile.go
@@ -0,0 +1,84 @@
+package staticpages
+
+import (
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/mask"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/urlprefix"
+)
+
+type CacheMode int
+
+const (
+ CacheDisabled CacheMode = iota
+ CacheExpireMax
+)
+
+// BUG/QUIRK: If a client requests 'foo%2Fbar' and 'foo/bar' exists,
+// handleServeFile will serve foo/bar instead of passing the request
+// upstream.
+func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoundHandler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ file := filepath.Join(s.DocumentRoot, prefix.Strip(r.URL.Path))
+
+ // The filepath.Join does Clean traversing directories up
+ if !strings.HasPrefix(file, s.DocumentRoot) {
+ helper.Fail500(w, r, &os.PathError{
+ Op: "open",
+ Path: file,
+ Err: os.ErrInvalid,
+ })
+ return
+ }
+
+ var content *os.File
+ var fi os.FileInfo
+ var err error
+
+ // Serve pre-gzipped assets
+ if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") {
+ content, fi, err = helper.OpenFile(file + ".gz")
+ if err == nil {
+ w.Header().Set("Content-Encoding", "gzip")
+ }
+ }
+
+ // If not found, open the original file
+ if content == nil || err != nil {
+ content, fi, err = helper.OpenFile(file)
+ }
+ if err != nil {
+ if notFoundHandler != nil {
+ notFoundHandler.ServeHTTP(w, r)
+ } else {
+ http.NotFound(w, r)
+ }
+ return
+ }
+ defer content.Close()
+
+ switch cache {
+ case CacheExpireMax:
+ // Cache statically served files for 1 year
+ cacheUntil := time.Now().AddDate(1, 0, 0).Format(http.TimeFormat)
+ w.Header().Set("Cache-Control", "public")
+ w.Header().Set("Expires", cacheUntil)
+ }
+
+ log.WithContextFields(r.Context(), log.Fields{
+ "file": file,
+ "encoding": w.Header().Get("Content-Encoding"),
+ "method": r.Method,
+ "uri": mask.URL(r.RequestURI),
+ }).Info("Send static file")
+
+ http.ServeContent(w, r, filepath.Base(file), fi.ModTime(), content)
+ })
+}
diff --git a/workhorse/internal/staticpages/servefile_test.go b/workhorse/internal/staticpages/servefile_test.go
new file mode 100644
index 00000000000..e136b876298
--- /dev/null
+++ b/workhorse/internal/staticpages/servefile_test.go
@@ -0,0 +1,134 @@
+package staticpages
+
+import (
+ "bytes"
+ "compress/gzip"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestServingNonExistingFile(t *testing.T) {
+ dir := "/path/to/non/existing/directory"
+ httpRequest, _ := http.NewRequest("GET", "/file", nil)
+
+ w := httptest.NewRecorder()
+ st := &Static{dir}
+ st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
+ require.Equal(t, 404, w.Code)
+}
+
+func TestServingDirectory(t *testing.T) {
+ dir, err := ioutil.TempDir("", "deploy")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(dir)
+
+ httpRequest, _ := http.NewRequest("GET", "/file", nil)
+ w := httptest.NewRecorder()
+ st := &Static{dir}
+ st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
+ require.Equal(t, 404, w.Code)
+}
+
+func TestServingMalformedUri(t *testing.T) {
+ dir := "/path/to/non/existing/directory"
+ httpRequest, _ := http.NewRequest("GET", "/../../../static/file", nil)
+
+ w := httptest.NewRecorder()
+ st := &Static{dir}
+ st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
+ require.Equal(t, 404, w.Code)
+}
+
+func TestExecutingHandlerWhenNoFileFound(t *testing.T) {
+ dir := "/path/to/non/existing/directory"
+ httpRequest, _ := http.NewRequest("GET", "/file", nil)
+
+ executed := false
+ st := &Static{dir}
+ st.ServeExisting("/", CacheDisabled, http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+ executed = (r == httpRequest)
+ })).ServeHTTP(nil, httpRequest)
+ if !executed {
+ t.Error("The handler should get executed")
+ }
+}
+
+func TestServingTheActualFile(t *testing.T) {
+ dir, err := ioutil.TempDir("", "deploy")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(dir)
+
+ httpRequest, _ := http.NewRequest("GET", "/file", nil)
+
+ fileContent := "STATIC"
+ ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
+
+ w := httptest.NewRecorder()
+ st := &Static{dir}
+ st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
+ require.Equal(t, 200, w.Code)
+ if w.Body.String() != fileContent {
+ t.Error("We should serve the file: ", w.Body.String())
+ }
+}
+
+func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
+ dir, err := ioutil.TempDir("", "deploy")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(dir)
+
+ httpRequest, _ := http.NewRequest("GET", "/file", nil)
+
+ if enableGzip {
+ httpRequest.Header.Set("Accept-Encoding", "gzip, deflate")
+ }
+
+ fileContent := "STATIC"
+
+ var fileGzipContent bytes.Buffer
+ fileGzip := gzip.NewWriter(&fileGzipContent)
+ fileGzip.Write([]byte(fileContent))
+ fileGzip.Close()
+
+ ioutil.WriteFile(filepath.Join(dir, "file.gz"), fileGzipContent.Bytes(), 0600)
+ ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
+
+ w := httptest.NewRecorder()
+ st := &Static{dir}
+ st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
+ require.Equal(t, 200, w.Code)
+ if enableGzip {
+ testhelper.RequireResponseHeader(t, w, "Content-Encoding", "gzip")
+ if !bytes.Equal(w.Body.Bytes(), fileGzipContent.Bytes()) {
+ t.Error("We should serve the pregzipped file")
+ }
+ } else {
+ require.Equal(t, 200, w.Code)
+ testhelper.RequireResponseHeader(t, w, "Content-Encoding")
+ if w.Body.String() != fileContent {
+ t.Error("We should serve the file: ", w.Body.String())
+ }
+ }
+}
+
+func TestServingThePregzippedFile(t *testing.T) {
+ testServingThePregzippedFile(t, true)
+}
+
+func TestServingThePregzippedFileWithoutEncoding(t *testing.T) {
+ testServingThePregzippedFile(t, false)
+}
diff --git a/workhorse/internal/staticpages/static.go b/workhorse/internal/staticpages/static.go
new file mode 100644
index 00000000000..b42351f15f5
--- /dev/null
+++ b/workhorse/internal/staticpages/static.go
@@ -0,0 +1,5 @@
+package staticpages
+
+type Static struct {
+ DocumentRoot string
+}
diff --git a/workhorse/internal/testhelper/gitaly.go b/workhorse/internal/testhelper/gitaly.go
new file mode 100644
index 00000000000..24884505440
--- /dev/null
+++ b/workhorse/internal/testhelper/gitaly.go
@@ -0,0 +1,384 @@
+package testhelper
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "path"
+ "strings"
+ "sync"
+
+ "github.com/golang/protobuf/jsonpb" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/274
+ "github.com/golang/protobuf/proto" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab-workhorse/-/issues/274
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/labkit/log"
+ "golang.org/x/net/context"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/status"
+)
+
+type GitalyTestServer struct {
+ finalMessageCode codes.Code
+ sync.WaitGroup
+ LastIncomingMetadata metadata.MD
+ gitalypb.UnimplementedRepositoryServiceServer
+ gitalypb.UnimplementedBlobServiceServer
+ gitalypb.UnimplementedDiffServiceServer
+}
+
+var (
+ GitalyInfoRefsResponseMock = strings.Repeat("Mock Gitaly InfoRefsResponse data", 100000)
+ GitalyGetBlobResponseMock = strings.Repeat("Mock Gitaly GetBlobResponse data", 100000)
+ GitalyGetArchiveResponseMock = strings.Repeat("Mock Gitaly GetArchiveResponse data", 100000)
+ GitalyGetDiffResponseMock = strings.Repeat("Mock Gitaly GetDiffResponse data", 100000)
+ GitalyGetPatchResponseMock = strings.Repeat("Mock Gitaly GetPatchResponse data", 100000)
+
+ GitalyGetSnapshotResponseMock = strings.Repeat("Mock Gitaly GetSnapshotResponse data", 100000)
+
+ GitalyReceivePackResponseMock []byte
+ GitalyUploadPackResponseMock []byte
+)
+
+func init() {
+ var err error
+ if GitalyReceivePackResponseMock, err = ioutil.ReadFile(path.Join(RootDir(), "testdata/receive-pack-fixture.txt")); err != nil {
+ log.WithError(err).Fatal("Unable to read pack response")
+ }
+ if GitalyUploadPackResponseMock, err = ioutil.ReadFile(path.Join(RootDir(), "testdata/upload-pack-fixture.txt")); err != nil {
+ log.WithError(err).Fatal("Unable to read pack response")
+ }
+}
+
+func NewGitalyServer(finalMessageCode codes.Code) *GitalyTestServer {
+ return &GitalyTestServer{finalMessageCode: finalMessageCode}
+}
+
+func (s *GitalyTestServer) InfoRefsUploadPack(in *gitalypb.InfoRefsRequest, stream gitalypb.SmartHTTPService_InfoRefsUploadPackServer) error {
+ s.WaitGroup.Add(1)
+ defer s.WaitGroup.Done()
+
+ if err := validateRepository(in.GetRepository()); err != nil {
+ return err
+ }
+
+ fmt.Printf("Result: %+v\n", in)
+
+ marshaler := &jsonpb.Marshaler{}
+ jsonString, err := marshaler.MarshalToString(in)
+ if err != nil {
+ return err
+ }
+
+ data := []byte(strings.Join([]string{
+ jsonString,
+ "git-upload-pack",
+ GitalyInfoRefsResponseMock,
+ }, "\000"))
+
+ s.LastIncomingMetadata = nil
+ if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
+ s.LastIncomingMetadata = md
+ }
+
+ return s.sendInfoRefs(stream, data)
+}
+
+func (s *GitalyTestServer) InfoRefsReceivePack(in *gitalypb.InfoRefsRequest, stream gitalypb.SmartHTTPService_InfoRefsReceivePackServer) error {
+ s.WaitGroup.Add(1)
+ defer s.WaitGroup.Done()
+
+ if err := validateRepository(in.GetRepository()); err != nil {
+ return err
+ }
+
+ fmt.Printf("Result: %+v\n", in)
+
+ jsonString, err := marshalJSON(in)
+ if err != nil {
+ return err
+ }
+
+ data := []byte(strings.Join([]string{
+ jsonString,
+ "git-receive-pack",
+ GitalyInfoRefsResponseMock,
+ }, "\000"))
+
+ return s.sendInfoRefs(stream, data)
+}
+
+func marshalJSON(msg proto.Message) (string, error) {
+ marshaler := &jsonpb.Marshaler{}
+ return marshaler.MarshalToString(msg)
+}
+
+type infoRefsSender interface {
+ Send(*gitalypb.InfoRefsResponse) error
+}
+
+func (s *GitalyTestServer) sendInfoRefs(stream infoRefsSender, data []byte) error {
+ nSends, err := sendBytes(data, 100, func(p []byte) error {
+ return stream.Send(&gitalypb.InfoRefsResponse{Data: p})
+ })
+ if err != nil {
+ return err
+ }
+ if nSends <= 1 {
+ panic("should have sent more than one message")
+ }
+
+ return s.finalError()
+}
+
+func (s *GitalyTestServer) PostReceivePack(stream gitalypb.SmartHTTPService_PostReceivePackServer) error {
+ s.WaitGroup.Add(1)
+ defer s.WaitGroup.Done()
+
+ req, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ repo := req.GetRepository()
+ if err := validateRepository(repo); err != nil {
+ return err
+ }
+
+ jsonString, err := marshalJSON(req)
+ if err != nil {
+ return err
+ }
+
+ data := []byte(jsonString + "\000")
+
+ // The body of the request starts in the second message
+ for {
+ req, err := stream.Recv()
+ if err != nil {
+ if err != io.EOF {
+ return err
+ }
+ break
+ }
+
+ // We want to echo the request data back
+ data = append(data, req.GetData()...)
+ }
+
+ nSends, _ := sendBytes(data, 100, func(p []byte) error {
+ return stream.Send(&gitalypb.PostReceivePackResponse{Data: p})
+ })
+
+ if nSends <= 1 {
+ panic("should have sent more than one message")
+ }
+
+ return s.finalError()
+}
+
+func (s *GitalyTestServer) PostUploadPack(stream gitalypb.SmartHTTPService_PostUploadPackServer) error {
+ s.WaitGroup.Add(1)
+ defer s.WaitGroup.Done()
+
+ req, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ if err := validateRepository(req.GetRepository()); err != nil {
+ return err
+ }
+
+ jsonString, err := marshalJSON(req)
+ if err != nil {
+ return err
+ }
+
+ if err := stream.Send(&gitalypb.PostUploadPackResponse{
+ Data: []byte(strings.Join([]string{jsonString}, "\000") + "\000"),
+ }); err != nil {
+ return err
+ }
+
+ nSends := 0
+ // The body of the request starts in the second message. Gitaly streams PostUploadPack responses
+ // as soon as possible without reading the request completely first. We stream messages here
+ // directly back to the client to simulate the streaming of the actual implementation.
+ for {
+ req, err := stream.Recv()
+ if err != nil {
+ if err != io.EOF {
+ return err
+ }
+ break
+ }
+
+ if err := stream.Send(&gitalypb.PostUploadPackResponse{Data: req.GetData()}); err != nil {
+ return err
+ }
+
+ nSends++
+ }
+
+ if nSends <= 1 {
+ panic("should have sent more than one message")
+ }
+
+ return s.finalError()
+}
+
+func (s *GitalyTestServer) CommitIsAncestor(ctx context.Context, in *gitalypb.CommitIsAncestorRequest) (*gitalypb.CommitIsAncestorResponse, error) {
+ return nil, nil
+}
+
+func (s *GitalyTestServer) GetBlob(in *gitalypb.GetBlobRequest, stream gitalypb.BlobService_GetBlobServer) error {
+ s.WaitGroup.Add(1)
+ defer s.WaitGroup.Done()
+
+ if err := validateRepository(in.GetRepository()); err != nil {
+ return err
+ }
+
+ response := &gitalypb.GetBlobResponse{
+ Oid: in.GetOid(),
+ Size: int64(len(GitalyGetBlobResponseMock)),
+ }
+ nSends, err := sendBytes([]byte(GitalyGetBlobResponseMock), 100, func(p []byte) error {
+ response.Data = p
+
+ if err := stream.Send(response); err != nil {
+ return err
+ }
+
+ // Use a new response so we don't send other fields (Size, ...) over and over
+ response = &gitalypb.GetBlobResponse{}
+
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+ if nSends <= 1 {
+ panic("should have sent more than one message")
+ }
+
+ return s.finalError()
+}
+
+func (s *GitalyTestServer) GetArchive(in *gitalypb.GetArchiveRequest, stream gitalypb.RepositoryService_GetArchiveServer) error {
+ s.WaitGroup.Add(1)
+ defer s.WaitGroup.Done()
+
+ if err := validateRepository(in.GetRepository()); err != nil {
+ return err
+ }
+
+ nSends, err := sendBytes([]byte(GitalyGetArchiveResponseMock), 100, func(p []byte) error {
+ return stream.Send(&gitalypb.GetArchiveResponse{Data: p})
+ })
+ if err != nil {
+ return err
+ }
+ if nSends <= 1 {
+ panic("should have sent more than one message")
+ }
+
+ return s.finalError()
+}
+
+func (s *GitalyTestServer) RawDiff(in *gitalypb.RawDiffRequest, stream gitalypb.DiffService_RawDiffServer) error {
+ nSends, err := sendBytes([]byte(GitalyGetDiffResponseMock), 100, func(p []byte) error {
+ return stream.Send(&gitalypb.RawDiffResponse{
+ Data: p,
+ })
+ })
+ if err != nil {
+ return err
+ }
+ if nSends <= 1 {
+ panic("should have sent more than one message")
+ }
+
+ return s.finalError()
+}
+
+func (s *GitalyTestServer) RawPatch(in *gitalypb.RawPatchRequest, stream gitalypb.DiffService_RawPatchServer) error {
+ s.WaitGroup.Add(1)
+ defer s.WaitGroup.Done()
+
+ if err := validateRepository(in.GetRepository()); err != nil {
+ return err
+ }
+
+ nSends, err := sendBytes([]byte(GitalyGetPatchResponseMock), 100, func(p []byte) error {
+ return stream.Send(&gitalypb.RawPatchResponse{
+ Data: p,
+ })
+ })
+ if err != nil {
+ return err
+ }
+ if nSends <= 1 {
+ panic("should have sent more than one message")
+ }
+
+ return s.finalError()
+}
+
+func (s *GitalyTestServer) GetSnapshot(in *gitalypb.GetSnapshotRequest, stream gitalypb.RepositoryService_GetSnapshotServer) error {
+ s.WaitGroup.Add(1)
+ defer s.WaitGroup.Done()
+
+ if err := validateRepository(in.GetRepository()); err != nil {
+ return err
+ }
+
+ nSends, err := sendBytes([]byte(GitalyGetSnapshotResponseMock), 100, func(p []byte) error {
+ return stream.Send(&gitalypb.GetSnapshotResponse{Data: p})
+ })
+ if err != nil {
+ return err
+ }
+ if nSends <= 1 {
+ panic("should have sent more than one message")
+ }
+
+ return s.finalError()
+}
+
+// sendBytes returns the number of times the 'sender' function was called and an error.
+func sendBytes(data []byte, chunkSize int, sender func([]byte) error) (int, error) {
+ i := 0
+ for ; len(data) > 0; i++ {
+ n := chunkSize
+ if n > len(data) {
+ n = len(data)
+ }
+
+ if err := sender(data[:n]); err != nil {
+ return i, err
+ }
+ data = data[n:]
+ }
+
+ return i, nil
+}
+
+func (s *GitalyTestServer) finalError() error {
+ if code := s.finalMessageCode; code != codes.OK {
+ return status.Errorf(code, "error as specified by test")
+ }
+
+ return nil
+}
+
+func validateRepository(repo *gitalypb.Repository) error {
+ if len(repo.GetStorageName()) == 0 {
+ return fmt.Errorf("missing storage_name: %v", repo)
+ }
+ if len(repo.GetRelativePath()) == 0 {
+ return fmt.Errorf("missing relative_path: %v", repo)
+ }
+ return nil
+}
diff --git a/workhorse/internal/testhelper/testhelper.go b/workhorse/internal/testhelper/testhelper.go
new file mode 100644
index 00000000000..40097bd453a
--- /dev/null
+++ b/workhorse/internal/testhelper/testhelper.go
@@ -0,0 +1,152 @@
+package testhelper
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path"
+ "regexp"
+ "runtime"
+ "testing"
+ "time"
+
+ "github.com/dgrijalva/jwt-go"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
+)
+
+func ConfigureSecret() {
+ secret.SetPath(path.Join(RootDir(), "testdata/test-secret"))
+}
+
+func RequireResponseBody(t *testing.T, response *httptest.ResponseRecorder, expectedBody string) {
+ t.Helper()
+ require.Equal(t, expectedBody, response.Body.String(), "response body")
+}
+
+func RequireResponseHeader(t *testing.T, w interface{}, header string, expected ...string) {
+ t.Helper()
+ var actual []string
+
+ header = http.CanonicalHeaderKey(header)
+ type headerer interface{ Header() http.Header }
+
+ switch resp := w.(type) {
+ case *http.Response:
+ actual = resp.Header[header]
+ case headerer:
+ actual = resp.Header()[header]
+ default:
+ t.Fatal("invalid type of w passed RequireResponseHeader")
+ }
+
+ require.Equal(t, expected, actual, "values for HTTP header %s", header)
+}
+
+func TestServerWithHandler(url *regexp.Regexp, handler http.HandlerFunc) *httptest.Server {
+ return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ logEntry := log.WithFields(log.Fields{
+ "method": r.Method,
+ "url": r.URL,
+ "action": "DENY",
+ })
+
+ if url != nil && !url.MatchString(r.URL.Path) {
+ logEntry.Info("UPSTREAM")
+ w.WriteHeader(404)
+ return
+ }
+
+ if version := r.Header.Get("Gitlab-Workhorse"); version == "" {
+ logEntry.Info("UPSTREAM")
+ w.WriteHeader(403)
+ return
+ }
+
+ handler(w, r)
+ }))
+}
+
+var workhorseExecutables = []string{"gitlab-workhorse", "gitlab-zip-cat", "gitlab-zip-metadata", "gitlab-resize-image"}
+
+func BuildExecutables() error {
+ rootDir := RootDir()
+
+ for _, exe := range workhorseExecutables {
+ if _, err := os.Stat(path.Join(rootDir, exe)); os.IsNotExist(err) {
+ return fmt.Errorf("cannot find executable %s. Please run 'make prepare-tests'", exe)
+ }
+ }
+
+ oldPath := os.Getenv("PATH")
+ testPath := fmt.Sprintf("%s:%s", rootDir, oldPath)
+ if err := os.Setenv("PATH", testPath); err != nil {
+ return fmt.Errorf("failed to set PATH to %v", testPath)
+ }
+
+ return nil
+}
+
+func RootDir() string {
+ _, currentFile, _, ok := runtime.Caller(0)
+ if !ok {
+ panic(errors.New("RootDir: calling runtime.Caller failed"))
+ }
+ return path.Join(path.Dir(currentFile), "../..")
+}
+
+func LoadFile(t *testing.T, filePath string) string {
+ t.Helper()
+ content, err := ioutil.ReadFile(path.Join(RootDir(), filePath))
+ require.NoError(t, err)
+ return string(content)
+}
+
+func ReadAll(t *testing.T, r io.Reader) []byte {
+ t.Helper()
+
+ b, err := ioutil.ReadAll(r)
+ require.NoError(t, err)
+ return b
+}
+
+func ParseJWT(token *jwt.Token) (interface{}, error) {
+ // Don't forget to validate the alg is what you expect:
+ if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+ return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
+ }
+
+ ConfigureSecret()
+ secretBytes, err := secret.Bytes()
+ if err != nil {
+ return nil, fmt.Errorf("read secret from file: %v", err)
+ }
+
+ return secretBytes, nil
+}
+
+// UploadClaims represents the JWT claim for upload parameters
+type UploadClaims struct {
+ Upload map[string]string `json:"upload"`
+ jwt.StandardClaims
+}
+
+func Retry(t testing.TB, timeout time.Duration, fn func() error) {
+ t.Helper()
+ start := time.Now()
+ var err error
+ for ; time.Since(start) < timeout; time.Sleep(time.Millisecond) {
+ err = fn()
+ if err == nil {
+ return
+ }
+ }
+ t.Fatalf("test timeout after %v; last error: %v", timeout, err)
+}
diff --git a/workhorse/internal/upload/accelerate.go b/workhorse/internal/upload/accelerate.go
new file mode 100644
index 00000000000..7d8ea51b14d
--- /dev/null
+++ b/workhorse/internal/upload/accelerate.go
@@ -0,0 +1,32 @@
+package upload
+
+import (
+ "fmt"
+ "net/http"
+
+ "github.com/dgrijalva/jwt-go"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+const RewrittenFieldsHeader = "Gitlab-Workhorse-Multipart-Fields"
+
+type MultipartClaims struct {
+ RewrittenFields map[string]string `json:"rewritten_fields"`
+ jwt.StandardClaims
+}
+
+func Accelerate(rails PreAuthorizer, h http.Handler, p Preparer) http.Handler {
+ return rails.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
+ s := &SavedFileTracker{Request: r}
+
+ opts, _, err := p.Prepare(a)
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("Accelerate: error preparing file storage options"))
+ return
+ }
+
+ HandleFileUploads(w, r, h, a, s, opts)
+ }, "/authorize")
+}
diff --git a/workhorse/internal/upload/body_uploader.go b/workhorse/internal/upload/body_uploader.go
new file mode 100644
index 00000000000..2cee90195fb
--- /dev/null
+++ b/workhorse/internal/upload/body_uploader.go
@@ -0,0 +1,90 @@
+package upload
+
+import (
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+type PreAuthorizer interface {
+ PreAuthorizeHandler(next api.HandleFunc, suffix string) http.Handler
+}
+
+// Verifier allows to check an upload before sending it to rails
+type Verifier interface {
+ // Verify can abort the upload returning an error
+ Verify(handler *filestore.FileHandler) error
+}
+
+// Preparer allows to customize BodyUploader configuration
+type Preparer interface {
+ // Prepare converts api.Response into a *SaveFileOpts, it can optionally return an Verifier that will be
+ // invoked after the real upload, before the finalization with rails
+ Prepare(a *api.Response) (*filestore.SaveFileOpts, Verifier, error)
+}
+
+type DefaultPreparer struct{}
+
+func (s *DefaultPreparer) Prepare(a *api.Response) (*filestore.SaveFileOpts, Verifier, error) {
+ opts, err := filestore.GetOpts(a)
+ return opts, nil, err
+}
+
+// BodyUploader is an http.Handler that perform a pre authorization call to rails before hijacking the request body and
+// uploading it.
+// Providing an Preparer allows to customize the upload process
+func BodyUploader(rails PreAuthorizer, h http.Handler, p Preparer) http.Handler {
+ return rails.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
+ opts, verifier, err := p.Prepare(a)
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("BodyUploader: preparation failed: %v", err))
+ return
+ }
+
+ fh, err := filestore.SaveFileFromReader(r.Context(), r.Body, r.ContentLength, opts)
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("BodyUploader: upload failed: %v", err))
+ return
+ }
+
+ if verifier != nil {
+ if err := verifier.Verify(fh); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("BodyUploader: verification failed: %v", err))
+ return
+ }
+ }
+
+ data := url.Values{}
+ fields, err := fh.GitLabFinalizeFields("file")
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("BodyUploader: finalize fields failed: %v", err))
+ return
+ }
+
+ for k, v := range fields {
+ data.Set(k, v)
+ }
+
+ // Hijack body
+ body := data.Encode()
+ r.Body = ioutil.NopCloser(strings.NewReader(body))
+ r.ContentLength = int64(len(body))
+ r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ sft := SavedFileTracker{Request: r}
+ sft.Track("file", fh.LocalPath)
+ if err := sft.Finalize(r.Context()); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("BodyUploader: finalize failed: %v", err))
+ return
+ }
+
+ // And proxy the request
+ h.ServeHTTP(w, r)
+ }, "/authorize")
+}
diff --git a/workhorse/internal/upload/body_uploader_test.go b/workhorse/internal/upload/body_uploader_test.go
new file mode 100644
index 00000000000..451d7c97fab
--- /dev/null
+++ b/workhorse/internal/upload/body_uploader_test.go
@@ -0,0 +1,195 @@
+package upload
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strconv"
+ "strings"
+ "testing"
+
+ "github.com/dgrijalva/jwt-go"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+)
+
+const (
+ fileContent = "A test file content"
+ fileLen = len(fileContent)
+)
+
+func TestBodyUploader(t *testing.T) {
+ testhelper.ConfigureSecret()
+
+ body := strings.NewReader(fileContent)
+
+ resp := testUpload(&rails{}, &alwaysLocalPreparer{}, echoProxy(t, fileLen), body)
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ uploadEcho, err := ioutil.ReadAll(resp.Body)
+
+ require.NoError(t, err, "Can't read response body")
+ require.Equal(t, fileContent, string(uploadEcho))
+}
+
+func TestBodyUploaderCustomPreparer(t *testing.T) {
+ body := strings.NewReader(fileContent)
+
+ resp := testUpload(&rails{}, &alwaysLocalPreparer{}, echoProxy(t, fileLen), body)
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ uploadEcho, err := ioutil.ReadAll(resp.Body)
+ require.NoError(t, err, "Can't read response body")
+ require.Equal(t, fileContent, string(uploadEcho))
+}
+
+func TestBodyUploaderCustomVerifier(t *testing.T) {
+ body := strings.NewReader(fileContent)
+ verifier := &mockVerifier{}
+
+ resp := testUpload(&rails{}, &alwaysLocalPreparer{verifier: verifier}, echoProxy(t, fileLen), body)
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ uploadEcho, err := ioutil.ReadAll(resp.Body)
+ require.NoError(t, err, "Can't read response body")
+ require.Equal(t, fileContent, string(uploadEcho))
+ require.True(t, verifier.invoked, "Verifier.Verify not invoked")
+}
+
+func TestBodyUploaderAuthorizationFailure(t *testing.T) {
+ testNoProxyInvocation(t, http.StatusUnauthorized, &rails{unauthorized: true}, &alwaysLocalPreparer{})
+}
+
+func TestBodyUploaderErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ preparer *alwaysLocalPreparer
+ }{
+ {name: "Prepare failure", preparer: &alwaysLocalPreparer{prepareError: fmt.Errorf("")}},
+ {name: "Verify failure", preparer: &alwaysLocalPreparer{verifier: &alwaysFailsVerifier{}}},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ testNoProxyInvocation(t, http.StatusInternalServerError, &rails{}, test.preparer)
+ })
+ }
+}
+
+func testNoProxyInvocation(t *testing.T, expectedStatus int, auth PreAuthorizer, preparer Preparer) {
+ proxy := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Fail(t, "request proxied upstream")
+ })
+
+ resp := testUpload(auth, preparer, proxy, nil)
+ require.Equal(t, expectedStatus, resp.StatusCode)
+}
+
+func testUpload(auth PreAuthorizer, preparer Preparer, proxy http.Handler, body io.Reader) *http.Response {
+ req := httptest.NewRequest("POST", "http://example.com/upload", body)
+ w := httptest.NewRecorder()
+
+ BodyUploader(auth, proxy, preparer).ServeHTTP(w, req)
+
+ return w.Result()
+}
+
+func echoProxy(t *testing.T, expectedBodyLength int) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ err := r.ParseForm()
+ require.NoError(t, err)
+
+ require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"), "Wrong Content-Type header")
+
+ require.Contains(t, r.PostForm, "file.md5")
+ require.Contains(t, r.PostForm, "file.sha1")
+ require.Contains(t, r.PostForm, "file.sha256")
+ require.Contains(t, r.PostForm, "file.sha512")
+
+ require.Contains(t, r.PostForm, "file.path")
+ require.Contains(t, r.PostForm, "file.size")
+ require.Contains(t, r.PostForm, "file.gitlab-workhorse-upload")
+ require.Equal(t, strconv.Itoa(expectedBodyLength), r.PostFormValue("file.size"))
+
+ token, err := jwt.ParseWithClaims(r.Header.Get(RewrittenFieldsHeader), &MultipartClaims{}, testhelper.ParseJWT)
+ require.NoError(t, err, "Wrong JWT header")
+
+ rewrittenFields := token.Claims.(*MultipartClaims).RewrittenFields
+ if len(rewrittenFields) != 1 || len(rewrittenFields["file"]) == 0 {
+ t.Fatalf("Unexpected rewritten_fields value: %v", rewrittenFields)
+ }
+
+ token, jwtErr := jwt.ParseWithClaims(r.PostFormValue("file.gitlab-workhorse-upload"), &testhelper.UploadClaims{}, testhelper.ParseJWT)
+ require.NoError(t, jwtErr, "Wrong signed upload fields")
+
+ uploadFields := token.Claims.(*testhelper.UploadClaims).Upload
+ require.Contains(t, uploadFields, "name")
+ require.Contains(t, uploadFields, "path")
+ require.Contains(t, uploadFields, "remote_url")
+ require.Contains(t, uploadFields, "remote_id")
+ require.Contains(t, uploadFields, "size")
+ require.Contains(t, uploadFields, "md5")
+ require.Contains(t, uploadFields, "sha1")
+ require.Contains(t, uploadFields, "sha256")
+ require.Contains(t, uploadFields, "sha512")
+
+ path := r.PostFormValue("file.path")
+ uploaded, err := os.Open(path)
+ require.NoError(t, err, "File not uploaded")
+
+ //sending back the file for testing purpose
+ io.Copy(w, uploaded)
+ })
+}
+
+type rails struct {
+ unauthorized bool
+}
+
+func (r *rails) PreAuthorizeHandler(next api.HandleFunc, _ string) http.Handler {
+ if r.unauthorized {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusUnauthorized)
+ })
+ }
+
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ next(w, r, &api.Response{TempPath: os.TempDir()})
+ })
+}
+
+type alwaysLocalPreparer struct {
+ verifier Verifier
+ prepareError error
+}
+
+func (a *alwaysLocalPreparer) Prepare(_ *api.Response) (*filestore.SaveFileOpts, Verifier, error) {
+ opts, err := filestore.GetOpts(&api.Response{TempPath: os.TempDir()})
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return opts, a.verifier, a.prepareError
+}
+
+type alwaysFailsVerifier struct{}
+
+func (alwaysFailsVerifier) Verify(handler *filestore.FileHandler) error {
+ return fmt.Errorf("Verification failed")
+}
+
+type mockVerifier struct {
+ invoked bool
+}
+
+func (m *mockVerifier) Verify(handler *filestore.FileHandler) error {
+ m.invoked = true
+
+ return nil
+}
diff --git a/workhorse/internal/upload/exif/exif.go b/workhorse/internal/upload/exif/exif.go
new file mode 100644
index 00000000000..a9307b1ca90
--- /dev/null
+++ b/workhorse/internal/upload/exif/exif.go
@@ -0,0 +1,107 @@
+package exif
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os/exec"
+ "regexp"
+
+ "gitlab.com/gitlab-org/labkit/log"
+)
+
+var ErrRemovingExif = errors.New("error while removing EXIF")
+
+type cleaner struct {
+ ctx context.Context
+ cmd *exec.Cmd
+ stdout io.Reader
+ stderr bytes.Buffer
+ eof bool
+}
+
+func NewCleaner(ctx context.Context, stdin io.Reader) (io.ReadCloser, error) {
+ c := &cleaner{ctx: ctx}
+
+ if err := c.startProcessing(stdin); err != nil {
+ return nil, err
+ }
+
+ return c, nil
+}
+
+func (c *cleaner) Close() error {
+ if c.cmd == nil {
+ return nil
+ }
+
+ return c.cmd.Wait()
+}
+
+func (c *cleaner) Read(p []byte) (int, error) {
+ if c.eof {
+ return 0, io.EOF
+ }
+
+ n, err := c.stdout.Read(p)
+ if err == io.EOF {
+ if waitErr := c.cmd.Wait(); waitErr != nil {
+ log.WithContextFields(c.ctx, log.Fields{
+ "command": c.cmd.Args,
+ "stderr": c.stderr.String(),
+ "error": waitErr.Error(),
+ }).Print("exiftool command failed")
+
+ return n, ErrRemovingExif
+ }
+
+ c.eof = true
+ }
+
+ return n, err
+}
+
+func (c *cleaner) startProcessing(stdin io.Reader) error {
+ var err error
+
+ whitelisted_tags := []string{
+ "-ResolutionUnit",
+ "-XResolution",
+ "-YResolution",
+ "-YCbCrSubSampling",
+ "-YCbCrPositioning",
+ "-BitsPerSample",
+ "-ImageHeight",
+ "-ImageWidth",
+ "-ImageSize",
+ "-Copyright",
+ "-CopyrightNotice",
+ "-Orientation",
+ }
+
+ args := append([]string{"-all=", "--IPTC:all", "--XMP-iptcExt:all", "-tagsFromFile", "@"}, whitelisted_tags...)
+ args = append(args, "-")
+ c.cmd = exec.CommandContext(c.ctx, "exiftool", args...)
+
+ c.cmd.Stderr = &c.stderr
+ c.cmd.Stdin = stdin
+
+ c.stdout, err = c.cmd.StdoutPipe()
+ if err != nil {
+ return fmt.Errorf("failed to create stdout pipe: %v", err)
+ }
+
+ if err = c.cmd.Start(); err != nil {
+ return fmt.Errorf("start %v: %v", c.cmd.Args, err)
+ }
+
+ return nil
+}
+
+func IsExifFile(filename string) bool {
+ filenameMatch := regexp.MustCompile(`(?i)\.(jpg|jpeg|tiff)$`)
+
+ return filenameMatch.MatchString(filename)
+}
diff --git a/workhorse/internal/upload/exif/exif_test.go b/workhorse/internal/upload/exif/exif_test.go
new file mode 100644
index 00000000000..373d97f7fce
--- /dev/null
+++ b/workhorse/internal/upload/exif/exif_test.go
@@ -0,0 +1,95 @@
+package exif
+
+import (
+ "context"
+ "io"
+ "io/ioutil"
+ "os"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestIsExifFile(t *testing.T) {
+ tests := []struct {
+ name string
+ expected bool
+ }{
+ {
+ name: "/full/path.jpg",
+ expected: true,
+ },
+ {
+ name: "path.jpeg",
+ expected: true,
+ },
+ {
+ name: "path.tiff",
+ expected: true,
+ },
+ {
+ name: "path.JPG",
+ expected: true,
+ },
+ {
+ name: "path.tar",
+ expected: false,
+ },
+ {
+ name: "path",
+ expected: false,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ require.Equal(t, test.expected, IsExifFile(test.name))
+ })
+ }
+}
+
+func TestNewCleanerWithValidFile(t *testing.T) {
+ input, err := os.Open("testdata/sample_exif.jpg")
+ require.NoError(t, err)
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ cleaner, err := NewCleaner(ctx, input)
+ require.NoError(t, err, "Expected no error when creating cleaner command")
+
+ size, err := io.Copy(ioutil.Discard, cleaner)
+ require.NoError(t, err, "Expected no error when reading output")
+
+ sizeAfterStrip := int64(25399)
+ require.Equal(t, sizeAfterStrip, size, "Different size of converted image")
+}
+
+func TestNewCleanerWithInvalidFile(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ cleaner, err := NewCleaner(ctx, strings.NewReader("invalid image"))
+ require.NoError(t, err, "Expected no error when creating cleaner command")
+
+ size, err := io.Copy(ioutil.Discard, cleaner)
+ require.Error(t, err, "Expected error when reading output")
+ require.Equal(t, int64(0), size, "Size of invalid image should be 0")
+}
+
+func TestNewCleanerReadingAfterEOF(t *testing.T) {
+ input, err := os.Open("testdata/sample_exif.jpg")
+ require.NoError(t, err)
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ cleaner, err := NewCleaner(ctx, input)
+ require.NoError(t, err, "Expected no error when creating cleaner command")
+
+ _, err = io.Copy(ioutil.Discard, cleaner)
+ require.NoError(t, err, "Expected no error when reading output")
+
+ buf := make([]byte, 1)
+ size, err := cleaner.Read(buf)
+ require.Equal(t, 0, size, "The output was already consumed by previous reads")
+ require.Equal(t, io.EOF, err, "We return EOF")
+}
diff --git a/workhorse/internal/upload/exif/testdata/sample_exif.jpg b/workhorse/internal/upload/exif/testdata/sample_exif.jpg
new file mode 100644
index 00000000000..05eda3f7f95
--- /dev/null
+++ b/workhorse/internal/upload/exif/testdata/sample_exif.jpg
Binary files differ
diff --git a/workhorse/internal/upload/object_storage_preparer.go b/workhorse/internal/upload/object_storage_preparer.go
new file mode 100644
index 00000000000..7a113fae80a
--- /dev/null
+++ b/workhorse/internal/upload/object_storage_preparer.go
@@ -0,0 +1,28 @@
+package upload
+
+import (
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+)
+
+type ObjectStoragePreparer struct {
+ config config.ObjectStorageConfig
+ credentials config.ObjectStorageCredentials
+}
+
+func NewObjectStoragePreparer(c config.Config) Preparer {
+ return &ObjectStoragePreparer{credentials: c.ObjectStorageCredentials, config: c.ObjectStorageConfig}
+}
+
+func (p *ObjectStoragePreparer) Prepare(a *api.Response) (*filestore.SaveFileOpts, Verifier, error) {
+ opts, err := filestore.GetOpts(a)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ opts.ObjectStorageConfig.URLMux = p.config.URLMux
+ opts.ObjectStorageConfig.S3Credentials = p.credentials.S3Credentials
+
+ return opts, nil, nil
+}
diff --git a/workhorse/internal/upload/object_storage_preparer_test.go b/workhorse/internal/upload/object_storage_preparer_test.go
new file mode 100644
index 00000000000..613b6071275
--- /dev/null
+++ b/workhorse/internal/upload/object_storage_preparer_test.go
@@ -0,0 +1,62 @@
+package upload_test
+
+import (
+ "testing"
+
+ "gocloud.dev/blob"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestPrepareWithS3Config(t *testing.T) {
+ creds := config.S3Credentials{
+ AwsAccessKeyID: "test-key",
+ AwsSecretAccessKey: "test-secret",
+ }
+
+ c := config.Config{
+ ObjectStorageCredentials: config.ObjectStorageCredentials{
+ Provider: "AWS",
+ S3Credentials: creds,
+ },
+ ObjectStorageConfig: config.ObjectStorageConfig{
+ URLMux: new(blob.URLMux),
+ },
+ }
+
+ r := &api.Response{
+ RemoteObject: api.RemoteObject{
+ ID: "the ID",
+ UseWorkhorseClient: true,
+ ObjectStorage: &api.ObjectStorageParams{
+ Provider: "AWS",
+ },
+ },
+ }
+
+ p := upload.NewObjectStoragePreparer(c)
+ opts, v, err := p.Prepare(r)
+
+ require.NoError(t, err)
+ require.True(t, opts.ObjectStorageConfig.IsAWS())
+ require.True(t, opts.UseWorkhorseClient)
+ require.Equal(t, creds, opts.ObjectStorageConfig.S3Credentials)
+ require.NotNil(t, opts.ObjectStorageConfig.URLMux)
+ require.Equal(t, nil, v)
+}
+
+func TestPrepareWithNoConfig(t *testing.T) {
+ c := config.Config{}
+ r := &api.Response{RemoteObject: api.RemoteObject{ID: "id"}}
+ p := upload.NewObjectStoragePreparer(c)
+ opts, v, err := p.Prepare(r)
+
+ require.NoError(t, err)
+ require.False(t, opts.UseWorkhorseClient)
+ require.Nil(t, v)
+ require.Nil(t, opts.ObjectStorageConfig.URLMux)
+}
diff --git a/workhorse/internal/upload/rewrite.go b/workhorse/internal/upload/rewrite.go
new file mode 100644
index 00000000000..e51604c6ed9
--- /dev/null
+++ b/workhorse/internal/upload/rewrite.go
@@ -0,0 +1,203 @@
+package upload
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "mime/multipart"
+ "net/http"
+ "strings"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/lsif_transformer/parser"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload/exif"
+)
+
+// ErrInjectedClientParam means that the client sent a parameter that overrides one of our own fields
+var ErrInjectedClientParam = errors.New("injected client parameter")
+
+var (
+ multipartUploadRequests = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+
+ Name: "gitlab_workhorse_multipart_upload_requests",
+ Help: "How many multipart upload requests have been processed by gitlab-workhorse. Partitioned by type.",
+ },
+ []string{"type"},
+ )
+
+ multipartFileUploadBytes = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_multipart_upload_bytes",
+ Help: "How many disk bytes of multipart file parts have been successfully written by gitlab-workhorse. Partitioned by type.",
+ },
+ []string{"type"},
+ )
+
+ multipartFiles = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_multipart_upload_files",
+ Help: "How many multipart file parts have been processed by gitlab-workhorse. Partitioned by type.",
+ },
+ []string{"type"},
+ )
+)
+
+type rewriter struct {
+ writer *multipart.Writer
+ preauth *api.Response
+ filter MultipartFormProcessor
+ finalizedFields map[string]bool
+}
+
+func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, preauth *api.Response, filter MultipartFormProcessor, opts *filestore.SaveFileOpts) error {
+ // Create multipart reader
+ reader, err := r.MultipartReader()
+ if err != nil {
+ if err == http.ErrNotMultipart {
+ // We want to be able to recognize http.ErrNotMultipart elsewhere so no fmt.Errorf
+ return http.ErrNotMultipart
+ }
+ return fmt.Errorf("get multipart reader: %v", err)
+ }
+
+ multipartUploadRequests.WithLabelValues(filter.Name()).Inc()
+
+ rew := &rewriter{
+ writer: writer,
+ preauth: preauth,
+ filter: filter,
+ finalizedFields: make(map[string]bool),
+ }
+
+ for {
+ p, err := reader.NextPart()
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ return err
+ }
+
+ name := p.FormName()
+ if name == "" {
+ continue
+ }
+
+ if rew.finalizedFields[name] {
+ return ErrInjectedClientParam
+ }
+
+ if p.FileName() != "" {
+ err = rew.handleFilePart(r.Context(), name, p, opts)
+ } else {
+ err = rew.copyPart(r.Context(), name, p)
+ }
+
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (rew *rewriter) handleFilePart(ctx context.Context, name string, p *multipart.Part, opts *filestore.SaveFileOpts) error {
+ multipartFiles.WithLabelValues(rew.filter.Name()).Inc()
+
+ filename := p.FileName()
+
+ if strings.Contains(filename, "/") || filename == "." || filename == ".." {
+ return fmt.Errorf("illegal filename: %q", filename)
+ }
+
+ opts.TempFilePrefix = filename
+
+ var inputReader io.ReadCloser
+ var err error
+ switch {
+ case exif.IsExifFile(filename):
+ inputReader, err = handleExifUpload(ctx, p, filename)
+ if err != nil {
+ return err
+ }
+ case rew.preauth.ProcessLsif:
+ inputReader, err = handleLsifUpload(ctx, p, opts.LocalTempPath, filename, rew.preauth)
+ if err != nil {
+ return err
+ }
+ default:
+ inputReader = ioutil.NopCloser(p)
+ }
+
+ defer inputReader.Close()
+
+ fh, err := filestore.SaveFileFromReader(ctx, inputReader, -1, opts)
+ if err != nil {
+ switch err {
+ case filestore.ErrEntityTooLarge, exif.ErrRemovingExif:
+ return err
+ default:
+ return fmt.Errorf("persisting multipart file: %v", err)
+ }
+ }
+
+ fields, err := fh.GitLabFinalizeFields(name)
+ if err != nil {
+ return fmt.Errorf("failed to finalize fields: %v", err)
+ }
+
+ for key, value := range fields {
+ rew.writer.WriteField(key, value)
+ rew.finalizedFields[key] = true
+ }
+
+ multipartFileUploadBytes.WithLabelValues(rew.filter.Name()).Add(float64(fh.Size))
+
+ return rew.filter.ProcessFile(ctx, name, fh, rew.writer)
+}
+
+func handleExifUpload(ctx context.Context, r io.Reader, filename string) (io.ReadCloser, error) {
+ log.WithContextFields(ctx, log.Fields{
+ "filename": filename,
+ }).Print("running exiftool to remove any metadata")
+
+ cleaner, err := exif.NewCleaner(ctx, r)
+ if err != nil {
+ return nil, err
+ }
+
+ return cleaner, nil
+}
+
+func handleLsifUpload(ctx context.Context, reader io.Reader, tempPath, filename string, preauth *api.Response) (io.ReadCloser, error) {
+ parserConfig := parser.Config{
+ TempPath: tempPath,
+ }
+
+ return parser.NewParser(ctx, reader, parserConfig)
+}
+
+func (rew *rewriter) copyPart(ctx context.Context, name string, p *multipart.Part) error {
+ np, err := rew.writer.CreatePart(p.Header)
+ if err != nil {
+ return fmt.Errorf("create multipart field: %v", err)
+ }
+
+ if _, err := io.Copy(np, p); err != nil {
+ return fmt.Errorf("duplicate multipart field: %v", err)
+ }
+
+ if err := rew.filter.ProcessField(ctx, name, rew.writer); err != nil {
+ return fmt.Errorf("process multipart field: %v", err)
+ }
+
+ return nil
+}
diff --git a/workhorse/internal/upload/saved_file_tracker.go b/workhorse/internal/upload/saved_file_tracker.go
new file mode 100644
index 00000000000..7b6cade4faa
--- /dev/null
+++ b/workhorse/internal/upload/saved_file_tracker.go
@@ -0,0 +1,55 @@
+package upload
+
+import (
+ "context"
+ "fmt"
+ "mime/multipart"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
+)
+
+type SavedFileTracker struct {
+ Request *http.Request
+ rewrittenFields map[string]string
+}
+
+func (s *SavedFileTracker) Track(fieldName string, localPath string) {
+ if s.rewrittenFields == nil {
+ s.rewrittenFields = make(map[string]string)
+ }
+ s.rewrittenFields[fieldName] = localPath
+}
+
+func (s *SavedFileTracker) Count() int {
+ return len(s.rewrittenFields)
+}
+
+func (s *SavedFileTracker) ProcessFile(_ context.Context, fieldName string, file *filestore.FileHandler, _ *multipart.Writer) error {
+ s.Track(fieldName, file.LocalPath)
+ return nil
+}
+
+func (s *SavedFileTracker) ProcessField(_ context.Context, _ string, _ *multipart.Writer) error {
+ return nil
+}
+
+func (s *SavedFileTracker) Finalize(_ context.Context) error {
+ if s.rewrittenFields == nil {
+ return nil
+ }
+
+ claims := MultipartClaims{RewrittenFields: s.rewrittenFields, StandardClaims: secret.DefaultClaims}
+ tokenString, err := secret.JWTTokenString(claims)
+ if err != nil {
+ return fmt.Errorf("savedFileTracker.Finalize: %v", err)
+ }
+
+ s.Request.Header.Set(RewrittenFieldsHeader, tokenString)
+ return nil
+}
+
+func (s *SavedFileTracker) Name() string {
+ return "accelerate"
+}
diff --git a/workhorse/internal/upload/saved_file_tracker_test.go b/workhorse/internal/upload/saved_file_tracker_test.go
new file mode 100644
index 00000000000..e5a5e8f23a7
--- /dev/null
+++ b/workhorse/internal/upload/saved_file_tracker_test.go
@@ -0,0 +1,39 @@
+package upload
+
+import (
+ "context"
+
+ "github.com/dgrijalva/jwt-go"
+
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+)
+
+func TestSavedFileTracking(t *testing.T) {
+ testhelper.ConfigureSecret()
+
+ r, err := http.NewRequest("PUT", "/url/path", nil)
+ require.NoError(t, err)
+
+ tracker := SavedFileTracker{Request: r}
+ require.Equal(t, "accelerate", tracker.Name())
+
+ file := &filestore.FileHandler{}
+ ctx := context.Background()
+ tracker.ProcessFile(ctx, "test", file, nil)
+ require.Equal(t, 1, tracker.Count())
+
+ tracker.Finalize(ctx)
+ token, err := jwt.ParseWithClaims(r.Header.Get(RewrittenFieldsHeader), &MultipartClaims{}, testhelper.ParseJWT)
+ require.NoError(t, err)
+
+ rewrittenFields := token.Claims.(*MultipartClaims).RewrittenFields
+ require.Equal(t, 1, len(rewrittenFields))
+
+ require.Contains(t, rewrittenFields, "test")
+}
diff --git a/workhorse/internal/upload/skip_rails_authorizer.go b/workhorse/internal/upload/skip_rails_authorizer.go
new file mode 100644
index 00000000000..716467b8841
--- /dev/null
+++ b/workhorse/internal/upload/skip_rails_authorizer.go
@@ -0,0 +1,22 @@
+package upload
+
+import (
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+)
+
+// SkipRailsAuthorizer implements a fake PreAuthorizer that do not calls rails API and
+// authorize each call as a local only upload to TempPath
+type SkipRailsAuthorizer struct {
+ // TempPath is the temporary path for a local only upload
+ TempPath string
+}
+
+// PreAuthorizeHandler implements PreAuthorizer. It always grant the upload.
+// The fake API response contains only TempPath
+func (l *SkipRailsAuthorizer) PreAuthorizeHandler(next api.HandleFunc, _ string) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ next(w, r, &api.Response{TempPath: l.TempPath})
+ })
+}
diff --git a/workhorse/internal/upload/uploads.go b/workhorse/internal/upload/uploads.go
new file mode 100644
index 00000000000..3be39f9518f
--- /dev/null
+++ b/workhorse/internal/upload/uploads.go
@@ -0,0 +1,66 @@
+package upload
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io/ioutil"
+ "mime/multipart"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload/exif"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts"
+)
+
+// These methods are allowed to have thread-unsafe implementations.
+type MultipartFormProcessor interface {
+ ProcessFile(ctx context.Context, formName string, file *filestore.FileHandler, writer *multipart.Writer) error
+ ProcessField(ctx context.Context, formName string, writer *multipart.Writer) error
+ Finalize(ctx context.Context) error
+ Name() string
+}
+
+func HandleFileUploads(w http.ResponseWriter, r *http.Request, h http.Handler, preauth *api.Response, filter MultipartFormProcessor, opts *filestore.SaveFileOpts) {
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ defer writer.Close()
+
+ // Rewrite multipart form data
+ err := rewriteFormFilesFromMultipart(r, writer, preauth, filter, opts)
+ if err != nil {
+ switch err {
+ case ErrInjectedClientParam:
+ helper.CaptureAndFail(w, r, err, "Bad Request", http.StatusBadRequest)
+ case http.ErrNotMultipart:
+ h.ServeHTTP(w, r)
+ case filestore.ErrEntityTooLarge:
+ helper.RequestEntityTooLarge(w, r, err)
+ case zipartifacts.ErrBadMetadata:
+ helper.RequestEntityTooLarge(w, r, err)
+ case exif.ErrRemovingExif:
+ helper.CaptureAndFail(w, r, err, "Failed to process image", http.StatusUnprocessableEntity)
+ default:
+ helper.Fail500(w, r, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err))
+ }
+ return
+ }
+
+ // Close writer
+ writer.Close()
+
+ // Hijack the request
+ r.Body = ioutil.NopCloser(&body)
+ r.ContentLength = int64(body.Len())
+ r.Header.Set("Content-Type", writer.FormDataContentType())
+
+ if err := filter.Finalize(r.Context()); err != nil {
+ helper.Fail500(w, r, fmt.Errorf("handleFileUploads: Finalize: %v", err))
+ return
+ }
+
+ // Proxy the request
+ h.ServeHTTP(w, r)
+}
diff --git a/workhorse/internal/upload/uploads_test.go b/workhorse/internal/upload/uploads_test.go
new file mode 100644
index 00000000000..fc1a1ac57ef
--- /dev/null
+++ b/workhorse/internal/upload/uploads_test.go
@@ -0,0 +1,475 @@
+package upload
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io/ioutil"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "regexp"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore/test"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/proxy"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream/roundtripper"
+)
+
+var nilHandler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
+
+type testFormProcessor struct{}
+
+func (a *testFormProcessor) ProcessFile(ctx context.Context, formName string, file *filestore.FileHandler, writer *multipart.Writer) error {
+ return nil
+}
+
+func (a *testFormProcessor) ProcessField(ctx context.Context, formName string, writer *multipart.Writer) error {
+ if formName != "token" && !strings.HasPrefix(formName, "file.") && !strings.HasPrefix(formName, "other.") {
+ return fmt.Errorf("illegal field: %v", formName)
+ }
+ return nil
+}
+
+func (a *testFormProcessor) Finalize(ctx context.Context) error {
+ return nil
+}
+
+func (a *testFormProcessor) Name() string {
+ return ""
+}
+
+func TestUploadTempPathRequirement(t *testing.T) {
+ apiResponse := &api.Response{}
+ preparer := &DefaultPreparer{}
+ _, _, err := preparer.Prepare(apiResponse)
+ require.Error(t, err)
+}
+
+func TestUploadHandlerForwardingRawData(t *testing.T) {
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "PATCH", r.Method, "method")
+
+ body, err := ioutil.ReadAll(r.Body)
+ require.NoError(t, err)
+ require.Equal(t, "REQUEST", string(body), "request body")
+
+ w.WriteHeader(202)
+ fmt.Fprint(w, "RESPONSE")
+ })
+ defer ts.Close()
+
+ httpRequest, err := http.NewRequest("PATCH", ts.URL+"/url/path", bytes.NewBufferString("REQUEST"))
+ require.NoError(t, err)
+
+ tempPath, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+ defer os.RemoveAll(tempPath)
+
+ response := httptest.NewRecorder()
+
+ handler := newProxy(ts.URL)
+ apiResponse := &api.Response{TempPath: tempPath}
+ preparer := &DefaultPreparer{}
+ opts, _, err := preparer.Prepare(apiResponse)
+ require.NoError(t, err)
+
+ HandleFileUploads(response, httpRequest, handler, apiResponse, nil, opts)
+
+ require.Equal(t, 202, response.Code)
+ require.Equal(t, "RESPONSE", response.Body.String(), "response body")
+}
+
+func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
+ var filePath string
+
+ tempPath, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+ defer os.RemoveAll(tempPath)
+
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "PUT", r.Method, "method")
+ require.NoError(t, r.ParseMultipartForm(100000))
+
+ require.Empty(t, r.MultipartForm.File, "Expected to not receive any files")
+ require.Equal(t, "test", r.FormValue("token"), "Expected to receive token")
+ require.Equal(t, "my.file", r.FormValue("file.name"), "Expected to receive a filename")
+
+ filePath = r.FormValue("file.path")
+ require.True(t, strings.HasPrefix(filePath, tempPath), "Expected to the file to be in tempPath")
+
+ require.Empty(t, r.FormValue("file.remote_url"), "Expected to receive empty remote_url")
+ require.Empty(t, r.FormValue("file.remote_id"), "Expected to receive empty remote_id")
+ require.Equal(t, "4", r.FormValue("file.size"), "Expected to receive the file size")
+
+ hashes := map[string]string{
+ "md5": "098f6bcd4621d373cade4e832627b4f6",
+ "sha1": "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3",
+ "sha256": "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08",
+ "sha512": "ee26b0dd4af7e749aa1a8ee3c10ae9923f618980772e473f8819a5d4940e0db27ac185f8a0e1d5f84f88bc887fd67b143732c304cc5fa9ad8e6f57f50028a8ff",
+ }
+
+ for algo, hash := range hashes {
+ require.Equal(t, hash, r.FormValue("file."+algo), "file hash %s", algo)
+ }
+
+ require.Len(t, r.MultipartForm.Value, 11, "multipart form values")
+
+ w.WriteHeader(202)
+ fmt.Fprint(w, "RESPONSE")
+ })
+
+ var buffer bytes.Buffer
+
+ writer := multipart.NewWriter(&buffer)
+ writer.WriteField("token", "test")
+ file, err := writer.CreateFormFile("file", "my.file")
+ require.NoError(t, err)
+ fmt.Fprint(file, "test")
+ writer.Close()
+
+ httpRequest, err := http.NewRequest("PUT", ts.URL+"/url/path", nil)
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ httpRequest = httpRequest.WithContext(ctx)
+ httpRequest.Body = ioutil.NopCloser(&buffer)
+ httpRequest.ContentLength = int64(buffer.Len())
+ httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
+ response := httptest.NewRecorder()
+
+ handler := newProxy(ts.URL)
+
+ apiResponse := &api.Response{TempPath: tempPath}
+ preparer := &DefaultPreparer{}
+ opts, _, err := preparer.Prepare(apiResponse)
+ require.NoError(t, err)
+
+ HandleFileUploads(response, httpRequest, handler, apiResponse, &testFormProcessor{}, opts)
+ require.Equal(t, 202, response.Code)
+
+ cancel() // this will trigger an async cleanup
+ waitUntilDeleted(t, filePath)
+}
+
+func TestUploadHandlerDetectingInjectedMultiPartData(t *testing.T) {
+ var filePath string
+
+ tempPath, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+ defer os.RemoveAll(tempPath)
+
+ tests := []struct {
+ name string
+ field string
+ response int
+ }{
+ {
+ name: "injected file.path",
+ field: "file.path",
+ response: 400,
+ },
+ {
+ name: "injected file.remote_id",
+ field: "file.remote_id",
+ response: 400,
+ },
+ {
+ name: "field with other prefix",
+ field: "other.path",
+ response: 202,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "PUT", r.Method, "method")
+
+ w.WriteHeader(202)
+ fmt.Fprint(w, "RESPONSE")
+ })
+
+ var buffer bytes.Buffer
+
+ writer := multipart.NewWriter(&buffer)
+ file, err := writer.CreateFormFile("file", "my.file")
+ require.NoError(t, err)
+ fmt.Fprint(file, "test")
+
+ writer.WriteField(test.field, "value")
+ writer.Close()
+
+ httpRequest, err := http.NewRequest("PUT", ts.URL+"/url/path", &buffer)
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ httpRequest = httpRequest.WithContext(ctx)
+ httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
+ response := httptest.NewRecorder()
+
+ handler := newProxy(ts.URL)
+ apiResponse := &api.Response{TempPath: tempPath}
+ preparer := &DefaultPreparer{}
+ opts, _, err := preparer.Prepare(apiResponse)
+ require.NoError(t, err)
+
+ HandleFileUploads(response, httpRequest, handler, apiResponse, &testFormProcessor{}, opts)
+ require.Equal(t, test.response, response.Code)
+
+ cancel() // this will trigger an async cleanup
+ waitUntilDeleted(t, filePath)
+ })
+ }
+}
+
+func TestUploadProcessingField(t *testing.T) {
+ tempPath, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+ defer os.RemoveAll(tempPath)
+
+ var buffer bytes.Buffer
+
+ writer := multipart.NewWriter(&buffer)
+ writer.WriteField("token2", "test")
+ writer.Close()
+
+ httpRequest, err := http.NewRequest("PUT", "/url/path", &buffer)
+ require.NoError(t, err)
+ httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
+
+ response := httptest.NewRecorder()
+ apiResponse := &api.Response{TempPath: tempPath}
+ preparer := &DefaultPreparer{}
+ opts, _, err := preparer.Prepare(apiResponse)
+ require.NoError(t, err)
+
+ HandleFileUploads(response, httpRequest, nilHandler, apiResponse, &testFormProcessor{}, opts)
+
+ require.Equal(t, 500, response.Code)
+}
+
+func TestUploadProcessingFile(t *testing.T) {
+ tempPath, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+ defer os.RemoveAll(tempPath)
+
+ _, testServer := test.StartObjectStore()
+ defer testServer.Close()
+
+ storeUrl := testServer.URL + test.ObjectPath
+
+ tests := []struct {
+ name string
+ preauth api.Response
+ }{
+ {
+ name: "FileStore Upload",
+ preauth: api.Response{TempPath: tempPath},
+ },
+ {
+ name: "ObjectStore Upload",
+ preauth: api.Response{RemoteObject: api.RemoteObject{StoreURL: storeUrl}},
+ },
+ {
+ name: "ObjectStore and FileStore Upload",
+ preauth: api.Response{
+ TempPath: tempPath,
+ RemoteObject: api.RemoteObject{StoreURL: storeUrl},
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var buffer bytes.Buffer
+ writer := multipart.NewWriter(&buffer)
+ file, err := writer.CreateFormFile("file", "my.file")
+ require.NoError(t, err)
+ fmt.Fprint(file, "test")
+ writer.Close()
+
+ httpRequest, err := http.NewRequest("PUT", "/url/path", &buffer)
+ require.NoError(t, err)
+ httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
+
+ response := httptest.NewRecorder()
+ apiResponse := &api.Response{TempPath: tempPath}
+ preparer := &DefaultPreparer{}
+ opts, _, err := preparer.Prepare(apiResponse)
+ require.NoError(t, err)
+
+ HandleFileUploads(response, httpRequest, nilHandler, apiResponse, &testFormProcessor{}, opts)
+
+ require.Equal(t, 200, response.Code)
+ })
+ }
+
+}
+
+func TestInvalidFileNames(t *testing.T) {
+ testhelper.ConfigureSecret()
+
+ tempPath, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+ defer os.RemoveAll(tempPath)
+
+ for _, testCase := range []struct {
+ filename string
+ code int
+ }{
+ {"foobar", 200}, // sanity check for test setup below
+ {"foo/bar", 500},
+ {"/../../foobar", 500},
+ {".", 500},
+ {"..", 500},
+ } {
+ buffer := &bytes.Buffer{}
+
+ writer := multipart.NewWriter(buffer)
+ file, err := writer.CreateFormFile("file", testCase.filename)
+ require.NoError(t, err)
+ fmt.Fprint(file, "test")
+ writer.Close()
+
+ httpRequest, err := http.NewRequest("POST", "/example", buffer)
+ require.NoError(t, err)
+ httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
+
+ response := httptest.NewRecorder()
+ apiResponse := &api.Response{TempPath: tempPath}
+ preparer := &DefaultPreparer{}
+ opts, _, err := preparer.Prepare(apiResponse)
+ require.NoError(t, err)
+
+ HandleFileUploads(response, httpRequest, nilHandler, apiResponse, &SavedFileTracker{Request: httpRequest}, opts)
+ require.Equal(t, testCase.code, response.Code)
+ }
+}
+
+func TestUploadHandlerRemovingExif(t *testing.T) {
+ tempPath, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+ defer os.RemoveAll(tempPath)
+
+ var buffer bytes.Buffer
+
+ content, err := ioutil.ReadFile("exif/testdata/sample_exif.jpg")
+ require.NoError(t, err)
+
+ writer := multipart.NewWriter(&buffer)
+ file, err := writer.CreateFormFile("file", "test.jpg")
+ require.NoError(t, err)
+
+ _, err = file.Write(content)
+ require.NoError(t, err)
+
+ err = writer.Close()
+ require.NoError(t, err)
+
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
+ err := r.ParseMultipartForm(100000)
+ require.NoError(t, err)
+
+ size, err := strconv.Atoi(r.FormValue("file.size"))
+ require.NoError(t, err)
+ require.True(t, size < len(content), "Expected the file to be smaller after removal of exif")
+ require.True(t, size > 0, "Expected to receive not empty file")
+
+ w.WriteHeader(200)
+ fmt.Fprint(w, "RESPONSE")
+ })
+ defer ts.Close()
+
+ httpRequest, err := http.NewRequest("POST", ts.URL+"/url/path", &buffer)
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ httpRequest = httpRequest.WithContext(ctx)
+ httpRequest.ContentLength = int64(buffer.Len())
+ httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
+ response := httptest.NewRecorder()
+
+ handler := newProxy(ts.URL)
+ apiResponse := &api.Response{TempPath: tempPath}
+ preparer := &DefaultPreparer{}
+ opts, _, err := preparer.Prepare(apiResponse)
+ require.NoError(t, err)
+
+ HandleFileUploads(response, httpRequest, handler, apiResponse, &testFormProcessor{}, opts)
+ require.Equal(t, 200, response.Code)
+}
+
+func TestUploadHandlerRemovingInvalidExif(t *testing.T) {
+ tempPath, err := ioutil.TempDir("", "uploads")
+ require.NoError(t, err)
+ defer os.RemoveAll(tempPath)
+
+ var buffer bytes.Buffer
+
+ writer := multipart.NewWriter(&buffer)
+ file, err := writer.CreateFormFile("file", "test.jpg")
+ require.NoError(t, err)
+
+ fmt.Fprint(file, "this is not valid image data")
+ err = writer.Close()
+ require.NoError(t, err)
+
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
+ err := r.ParseMultipartForm(100000)
+ require.Error(t, err)
+ })
+ defer ts.Close()
+
+ httpRequest, err := http.NewRequest("POST", ts.URL+"/url/path", &buffer)
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ httpRequest = httpRequest.WithContext(ctx)
+ httpRequest.ContentLength = int64(buffer.Len())
+ httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
+ response := httptest.NewRecorder()
+
+ handler := newProxy(ts.URL)
+ apiResponse := &api.Response{TempPath: tempPath}
+ preparer := &DefaultPreparer{}
+ opts, _, err := preparer.Prepare(apiResponse)
+ require.NoError(t, err)
+
+ HandleFileUploads(response, httpRequest, handler, apiResponse, &testFormProcessor{}, opts)
+ require.Equal(t, 422, response.Code)
+}
+
+func newProxy(url string) *proxy.Proxy {
+ parsedURL := helper.URLMustParse(url)
+ return proxy.NewProxy(parsedURL, "123", roundtripper.NewTestBackendRoundTripper(parsedURL))
+}
+
+func waitUntilDeleted(t *testing.T, path string) {
+ var err error
+
+ // Poll because the file removal is async
+ for i := 0; i < 100; i++ {
+ _, err = os.Stat(path)
+ if err != nil {
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ }
+
+ require.True(t, os.IsNotExist(err), "expected the file to be deleted")
+}
diff --git a/workhorse/internal/upstream/development_test.go b/workhorse/internal/upstream/development_test.go
new file mode 100644
index 00000000000..d2957abb18b
--- /dev/null
+++ b/workhorse/internal/upstream/development_test.go
@@ -0,0 +1,39 @@
+package upstream
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestDevelopmentModeEnabled(t *testing.T) {
+ developmentMode := true
+
+ r, _ := http.NewRequest("GET", "/something", nil)
+ w := httptest.NewRecorder()
+
+ executed := false
+ NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
+ executed = true
+ })).ServeHTTP(w, r)
+
+ require.True(t, executed, "The handler should get executed")
+}
+
+func TestDevelopmentModeDisabled(t *testing.T) {
+ developmentMode := false
+
+ r, _ := http.NewRequest("GET", "/something", nil)
+ w := httptest.NewRecorder()
+
+ executed := false
+ NotFoundUnless(developmentMode, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
+ executed = true
+ })).ServeHTTP(w, r)
+
+ require.False(t, executed, "The handler should not get executed")
+
+ require.Equal(t, 404, w.Code)
+}
diff --git a/workhorse/internal/upstream/handlers.go b/workhorse/internal/upstream/handlers.go
new file mode 100644
index 00000000000..a6aa148a4ae
--- /dev/null
+++ b/workhorse/internal/upstream/handlers.go
@@ -0,0 +1,39 @@
+package upstream
+
+import (
+ "compress/gzip"
+ "fmt"
+ "io"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+func contentEncodingHandler(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ var body io.ReadCloser
+ var err error
+
+ // The client request body may have been gzipped.
+ contentEncoding := r.Header.Get("Content-Encoding")
+ switch contentEncoding {
+ case "":
+ body = r.Body
+ case "gzip":
+ body, err = gzip.NewReader(r.Body)
+ default:
+ err = fmt.Errorf("unsupported content encoding: %s", contentEncoding)
+ }
+
+ if err != nil {
+ helper.Fail500(w, r, fmt.Errorf("contentEncodingHandler: %v", err))
+ return
+ }
+ defer body.Close()
+
+ r.Body = body
+ r.Header.Del("Content-Encoding")
+
+ h.ServeHTTP(w, r)
+ })
+}
diff --git a/workhorse/internal/upstream/handlers_test.go b/workhorse/internal/upstream/handlers_test.go
new file mode 100644
index 00000000000..10c7479f5c5
--- /dev/null
+++ b/workhorse/internal/upstream/handlers_test.go
@@ -0,0 +1,67 @@
+package upstream
+
+import (
+ "bytes"
+ "compress/gzip"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestGzipEncoding(t *testing.T) {
+ resp := httptest.NewRecorder()
+
+ var b bytes.Buffer
+ w := gzip.NewWriter(&b)
+ fmt.Fprint(w, "test")
+ w.Close()
+
+ body := ioutil.NopCloser(&b)
+
+ req, err := http.NewRequest("POST", "http://address/test", body)
+ require.NoError(t, err)
+ req.Header.Set("Content-Encoding", "gzip")
+
+ contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+ require.IsType(t, &gzip.Reader{}, r.Body, "body type")
+ require.Empty(t, r.Header.Get("Content-Encoding"), "Content-Encoding should be deleted")
+ })).ServeHTTP(resp, req)
+
+ require.Equal(t, 200, resp.Code)
+}
+
+func TestNoEncoding(t *testing.T) {
+ resp := httptest.NewRecorder()
+
+ var b bytes.Buffer
+ body := ioutil.NopCloser(&b)
+
+ req, err := http.NewRequest("POST", "http://address/test", body)
+ require.NoError(t, err)
+ req.Header.Set("Content-Encoding", "")
+
+ contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+ require.Equal(t, body, r.Body, "Expected the same body")
+ require.Empty(t, r.Header.Get("Content-Encoding"), "Content-Encoding should be deleted")
+ })).ServeHTTP(resp, req)
+
+ require.Equal(t, 200, resp.Code)
+}
+
+func TestInvalidEncoding(t *testing.T) {
+ resp := httptest.NewRecorder()
+
+ req, err := http.NewRequest("POST", "http://address/test", nil)
+ require.NoError(t, err)
+ req.Header.Set("Content-Encoding", "application/unknown")
+
+ contentEncodingHandler(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
+ t.Fatal("it shouldn't be executed")
+ })).ServeHTTP(resp, req)
+
+ require.Equal(t, 500, resp.Code)
+}
diff --git a/workhorse/internal/upstream/metrics.go b/workhorse/internal/upstream/metrics.go
new file mode 100644
index 00000000000..38528056d43
--- /dev/null
+++ b/workhorse/internal/upstream/metrics.go
@@ -0,0 +1,117 @@
+package upstream
+
+import (
+ "net/http"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "github.com/prometheus/client_golang/prometheus/promhttp"
+)
+
+const (
+ namespace = "gitlab_workhorse"
+ httpSubsystem = "http"
+)
+
+func secondsDurationBuckets() []float64 {
+ return []float64{
+ 0.005, /* 5ms */
+ 0.025, /* 25ms */
+ 0.1, /* 100ms */
+ 0.5, /* 500ms */
+ 1.0, /* 1s */
+ 10.0, /* 10s */
+ 30.0, /* 30s */
+ 60.0, /* 1m */
+ 300.0, /* 10m */
+ }
+}
+
+func byteSizeBuckets() []float64 {
+ return []float64{
+ 10,
+ 64,
+ 256,
+ 1024, /* 1kB */
+ 64 * 1024, /* 64kB */
+ 256 * 1024, /* 256kB */
+ 1024 * 1024, /* 1mB */
+ 64 * 1024 * 1024, /* 64mB */
+ }
+}
+
+var (
+ httpInFlightRequests = promauto.NewGauge(prometheus.GaugeOpts{
+ Namespace: namespace,
+ Subsystem: httpSubsystem,
+ Name: "in_flight_requests",
+ Help: "A gauge of requests currently being served by workhorse.",
+ })
+
+ httpRequestsTotal = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Namespace: namespace,
+ Subsystem: httpSubsystem,
+ Name: "requests_total",
+ Help: "A counter for requests to workhorse.",
+ },
+ []string{"code", "method", "route"},
+ )
+
+ httpRequestDurationSeconds = promauto.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Namespace: namespace,
+ Subsystem: httpSubsystem,
+ Name: "request_duration_seconds",
+ Help: "A histogram of latencies for requests to workhorse.",
+ Buckets: secondsDurationBuckets(),
+ },
+ []string{"code", "method", "route"},
+ )
+
+ httpRequestSizeBytes = promauto.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Namespace: namespace,
+ Subsystem: httpSubsystem,
+ Name: "request_size_bytes",
+ Help: "A histogram of sizes of requests to workhorse.",
+ Buckets: byteSizeBuckets(),
+ },
+ []string{"code", "method", "route"},
+ )
+
+ httpResponseSizeBytes = promauto.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Namespace: namespace,
+ Subsystem: httpSubsystem,
+ Name: "response_size_bytes",
+ Help: "A histogram of response sizes for requests to workhorse.",
+ Buckets: byteSizeBuckets(),
+ },
+ []string{"code", "method", "route"},
+ )
+
+ httpTimeToWriteHeaderSeconds = promauto.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Namespace: namespace,
+ Subsystem: httpSubsystem,
+ Name: "time_to_write_header_seconds",
+ Help: "A histogram of request durations until the response headers are written.",
+ Buckets: secondsDurationBuckets(),
+ },
+ []string{"code", "method", "route"},
+ )
+)
+
+func instrumentRoute(next http.Handler, method string, regexpStr string) http.Handler {
+ handler := next
+
+ handler = promhttp.InstrumentHandlerCounter(httpRequestsTotal.MustCurryWith(map[string]string{"route": regexpStr}), handler)
+ handler = promhttp.InstrumentHandlerDuration(httpRequestDurationSeconds.MustCurryWith(map[string]string{"route": regexpStr}), handler)
+ handler = promhttp.InstrumentHandlerInFlight(httpInFlightRequests, handler)
+ handler = promhttp.InstrumentHandlerRequestSize(httpRequestSizeBytes.MustCurryWith(map[string]string{"route": regexpStr}), handler)
+ handler = promhttp.InstrumentHandlerResponseSize(httpResponseSizeBytes.MustCurryWith(map[string]string{"route": regexpStr}), handler)
+ handler = promhttp.InstrumentHandlerTimeToWriteHeader(httpTimeToWriteHeaderSeconds.MustCurryWith(map[string]string{"route": regexpStr}), handler)
+
+ return handler
+}
diff --git a/workhorse/internal/upstream/notfoundunless.go b/workhorse/internal/upstream/notfoundunless.go
new file mode 100644
index 00000000000..3bbe3e873a4
--- /dev/null
+++ b/workhorse/internal/upstream/notfoundunless.go
@@ -0,0 +1,11 @@
+package upstream
+
+import "net/http"
+
+func NotFoundUnless(pass bool, handler http.Handler) http.Handler {
+ if pass {
+ return handler
+ }
+
+ return http.HandlerFunc(http.NotFound)
+}
diff --git a/workhorse/internal/upstream/roundtripper/roundtripper.go b/workhorse/internal/upstream/roundtripper/roundtripper.go
new file mode 100644
index 00000000000..84f1983b471
--- /dev/null
+++ b/workhorse/internal/upstream/roundtripper/roundtripper.go
@@ -0,0 +1,61 @@
+package roundtripper
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "net/url"
+ "time"
+
+ "gitlab.com/gitlab-org/labkit/correlation"
+ "gitlab.com/gitlab-org/labkit/tracing"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway"
+)
+
+func mustParseAddress(address, scheme string) string {
+ if scheme == "https" {
+ panic("TLS is not supported for backend connections")
+ }
+
+ for _, suffix := range []string{"", ":" + scheme} {
+ address += suffix
+ if host, port, err := net.SplitHostPort(address); err == nil && host != "" && port != "" {
+ return host + ":" + port
+ }
+ }
+
+ panic(fmt.Errorf("could not parse host:port from address %q and scheme %q", address, scheme))
+}
+
+// NewBackendRoundTripper returns a new RoundTripper instance using the provided values
+func NewBackendRoundTripper(backend *url.URL, socket string, proxyHeadersTimeout time.Duration, developmentMode bool) http.RoundTripper {
+ // Copied from the definition of http.DefaultTransport. We can't literally copy http.DefaultTransport because of its hidden internal state.
+ transport, dialer := newBackendTransport()
+ transport.ResponseHeaderTimeout = proxyHeadersTimeout
+
+ if backend != nil && socket == "" {
+ address := mustParseAddress(backend.Host, backend.Scheme)
+ transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return dialer.DialContext(ctx, "tcp", address)
+ }
+ } else if socket != "" {
+ transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return dialer.DialContext(ctx, "unix", socket)
+ }
+ } else {
+ panic("backend is nil and socket is empty")
+ }
+
+ return tracing.NewRoundTripper(
+ correlation.NewInstrumentedRoundTripper(
+ badgateway.NewRoundTripper(developmentMode, transport),
+ ),
+ )
+}
+
+// NewTestBackendRoundTripper sets up a RoundTripper for testing purposes
+func NewTestBackendRoundTripper(backend *url.URL) http.RoundTripper {
+ return NewBackendRoundTripper(backend, "", 0, true)
+}
diff --git a/workhorse/internal/upstream/roundtripper/roundtripper_test.go b/workhorse/internal/upstream/roundtripper/roundtripper_test.go
new file mode 100644
index 00000000000..79ffa244918
--- /dev/null
+++ b/workhorse/internal/upstream/roundtripper/roundtripper_test.go
@@ -0,0 +1,39 @@
+package roundtripper
+
+import (
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestMustParseAddress(t *testing.T) {
+ successExamples := []struct{ address, scheme, expected string }{
+ {"1.2.3.4:56", "http", "1.2.3.4:56"},
+ {"[::1]:23", "http", "::1:23"},
+ {"4.5.6.7", "http", "4.5.6.7:http"},
+ }
+ for i, example := range successExamples {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ require.Equal(t, example.expected, mustParseAddress(example.address, example.scheme))
+ })
+ }
+}
+
+func TestMustParseAddressPanic(t *testing.T) {
+ panicExamples := []struct{ address, scheme string }{
+ {"1.2.3.4", ""},
+ {"1.2.3.4", "https"},
+ }
+
+ for i, panicExample := range panicExamples {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Fatal("expected panic")
+ }
+ }()
+ mustParseAddress(panicExample.address, panicExample.scheme)
+ })
+ }
+}
diff --git a/workhorse/internal/upstream/roundtripper/transport.go b/workhorse/internal/upstream/roundtripper/transport.go
new file mode 100644
index 00000000000..84d9623b129
--- /dev/null
+++ b/workhorse/internal/upstream/roundtripper/transport.go
@@ -0,0 +1,27 @@
+package roundtripper
+
+import (
+ "net"
+ "net/http"
+ "time"
+)
+
+// newBackendTransport setups the default HTTP transport which Workhorse uses
+// to communicate with the upstream
+func newBackendTransport() (*http.Transport, *net.Dialer) {
+ dialler := &net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }
+
+ transport := &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ DialContext: dialler.DialContext,
+ MaxIdleConns: 100,
+ IdleConnTimeout: 90 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ExpectContinueTimeout: 1 * time.Second,
+ }
+
+ return transport, dialler
+}
diff --git a/workhorse/internal/upstream/routes.go b/workhorse/internal/upstream/routes.go
new file mode 100644
index 00000000000..5bbd245719b
--- /dev/null
+++ b/workhorse/internal/upstream/routes.go
@@ -0,0 +1,345 @@
+package upstream
+
+import (
+ "net/http"
+ "net/url"
+ "path"
+ "regexp"
+
+ "github.com/gorilla/websocket"
+
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/tracing"
+
+ apipkg "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/artifacts"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/builds"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/channel"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/git"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/imageresizer"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/lfs"
+ proxypkg "gitlab.com/gitlab-org/gitlab-workhorse/internal/proxy"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/queueing"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/redis"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/sendfile"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/sendurl"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/staticpages"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
+)
+
+type matcherFunc func(*http.Request) bool
+
+type routeEntry struct {
+ method string
+ regex *regexp.Regexp
+ handler http.Handler
+ matchers []matcherFunc
+}
+
+type routeOptions struct {
+ tracing bool
+ matchers []matcherFunc
+}
+
+type uploadPreparers struct {
+ artifacts upload.Preparer
+ lfs upload.Preparer
+ packages upload.Preparer
+ uploads upload.Preparer
+}
+
+const (
+ apiPattern = `^/api/`
+ ciAPIPattern = `^/ci/api/`
+ gitProjectPattern = `^/([^/]+/){1,}[^/]+\.git/`
+ projectPattern = `^/([^/]+/){1,}[^/]+/`
+ snippetUploadPattern = `^/uploads/personal_snippet`
+ userUploadPattern = `^/uploads/user`
+ importPattern = `^/import/`
+)
+
+func compileRegexp(regexpStr string) *regexp.Regexp {
+ if len(regexpStr) == 0 {
+ return nil
+ }
+
+ return regexp.MustCompile(regexpStr)
+}
+
+func withMatcher(f matcherFunc) func(*routeOptions) {
+ return func(options *routeOptions) {
+ options.matchers = append(options.matchers, f)
+ }
+}
+
+func withoutTracing() func(*routeOptions) {
+ return func(options *routeOptions) {
+ options.tracing = false
+ }
+}
+
+func (u *upstream) observabilityMiddlewares(handler http.Handler, method string, regexpStr string) http.Handler {
+ handler = log.AccessLogger(
+ handler,
+ log.WithAccessLogger(u.accessLogger),
+ log.WithExtraFields(func(r *http.Request) log.Fields {
+ return log.Fields{
+ "route": regexpStr, // This field matches the `route` label in Prometheus metrics
+ }
+ }),
+ )
+
+ handler = instrumentRoute(handler, method, regexpStr) // Add prometheus metrics
+ return handler
+}
+
+func (u *upstream) route(method, regexpStr string, handler http.Handler, opts ...func(*routeOptions)) routeEntry {
+ // Instantiate a route with the defaults
+ options := routeOptions{
+ tracing: true,
+ }
+
+ for _, f := range opts {
+ f(&options)
+ }
+
+ handler = u.observabilityMiddlewares(handler, method, regexpStr)
+ handler = denyWebsocket(handler) // Disallow websockets
+ if options.tracing {
+ // Add distributed tracing
+ handler = tracing.Handler(handler, tracing.WithRouteIdentifier(regexpStr))
+ }
+
+ return routeEntry{
+ method: method,
+ regex: compileRegexp(regexpStr),
+ handler: handler,
+ matchers: options.matchers,
+ }
+}
+
+func (u *upstream) wsRoute(regexpStr string, handler http.Handler, matchers ...matcherFunc) routeEntry {
+ method := "GET"
+ handler = u.observabilityMiddlewares(handler, method, regexpStr)
+
+ return routeEntry{
+ method: method,
+ regex: compileRegexp(regexpStr),
+ handler: handler,
+ matchers: append(matchers, websocket.IsWebSocketUpgrade),
+ }
+}
+
+// Creates matcherFuncs for a particular content type.
+func isContentType(contentType string) func(*http.Request) bool {
+ return func(r *http.Request) bool {
+ return helper.IsContentType(contentType, r.Header.Get("Content-Type"))
+ }
+}
+
+func (ro *routeEntry) isMatch(cleanedPath string, req *http.Request) bool {
+ if ro.method != "" && req.Method != ro.method {
+ return false
+ }
+
+ if ro.regex != nil && !ro.regex.MatchString(cleanedPath) {
+ return false
+ }
+
+ ok := true
+ for _, matcher := range ro.matchers {
+ ok = matcher(req)
+ if !ok {
+ break
+ }
+ }
+
+ return ok
+}
+
+func buildProxy(backend *url.URL, version string, rt http.RoundTripper, cfg config.Config) http.Handler {
+ proxier := proxypkg.NewProxy(backend, version, rt)
+
+ return senddata.SendData(
+ sendfile.SendFile(apipkg.Block(proxier)),
+ git.SendArchive,
+ git.SendBlob,
+ git.SendDiff,
+ git.SendPatch,
+ git.SendSnapshot,
+ artifacts.SendEntry,
+ sendurl.SendURL,
+ imageresizer.NewResizer(cfg),
+ )
+}
+
+// Routing table
+// We match against URI not containing the relativeUrlRoot:
+// see upstream.ServeHTTP
+
+func (u *upstream) configureRoutes() {
+ api := apipkg.NewAPI(
+ u.Backend,
+ u.Version,
+ u.RoundTripper,
+ )
+
+ static := &staticpages.Static{DocumentRoot: u.DocumentRoot}
+ proxy := buildProxy(u.Backend, u.Version, u.RoundTripper, u.Config)
+ cableProxy := proxypkg.NewProxy(u.CableBackend, u.Version, u.CableRoundTripper)
+
+ assetsNotFoundHandler := NotFoundUnless(u.DevelopmentMode, proxy)
+ if u.AltDocumentRoot != "" {
+ altStatic := &staticpages.Static{DocumentRoot: u.AltDocumentRoot}
+ assetsNotFoundHandler = altStatic.ServeExisting(
+ u.URLPrefix,
+ staticpages.CacheExpireMax,
+ NotFoundUnless(u.DevelopmentMode, proxy),
+ )
+ }
+
+ signingTripper := secret.NewRoundTripper(u.RoundTripper, u.Version)
+ signingProxy := buildProxy(u.Backend, u.Version, signingTripper, u.Config)
+
+ preparers := createUploadPreparers(u.Config)
+ uploadPath := path.Join(u.DocumentRoot, "uploads/tmp")
+ uploadAccelerateProxy := upload.Accelerate(&upload.SkipRailsAuthorizer{TempPath: uploadPath}, proxy, preparers.uploads)
+ ciAPIProxyQueue := queueing.QueueRequests("ci_api_job_requests", uploadAccelerateProxy, u.APILimit, u.APIQueueLimit, u.APIQueueTimeout)
+ ciAPILongPolling := builds.RegisterHandler(ciAPIProxyQueue, redis.WatchKey, u.APICILongPollingDuration)
+
+ // Serve static files or forward the requests
+ defaultUpstream := static.ServeExisting(
+ u.URLPrefix,
+ staticpages.CacheDisabled,
+ static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatHTML, uploadAccelerateProxy)),
+ )
+ probeUpstream := static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatJSON, proxy)
+ healthUpstream := static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatText, proxy)
+
+ u.Routes = []routeEntry{
+ // Git Clone
+ u.route("GET", gitProjectPattern+`info/refs\z`, git.GetInfoRefsHandler(api)),
+ u.route("POST", gitProjectPattern+`git-upload-pack\z`, contentEncodingHandler(git.UploadPack(api)), withMatcher(isContentType("application/x-git-upload-pack-request"))),
+ u.route("POST", gitProjectPattern+`git-receive-pack\z`, contentEncodingHandler(git.ReceivePack(api)), withMatcher(isContentType("application/x-git-receive-pack-request"))),
+ u.route("PUT", gitProjectPattern+`gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`, lfs.PutStore(api, signingProxy, preparers.lfs), withMatcher(isContentType("application/octet-stream"))),
+
+ // CI Artifacts
+ u.route("POST", apiPattern+`v4/jobs/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, signingProxy, preparers.artifacts))),
+ u.route("POST", ciAPIPattern+`v1/builds/[0-9]+/artifacts\z`, contentEncodingHandler(artifacts.UploadArtifacts(api, signingProxy, preparers.artifacts))),
+
+ // ActionCable websocket
+ u.wsRoute(`^/-/cable\z`, cableProxy),
+
+ // Terminal websocket
+ u.wsRoute(projectPattern+`-/environments/[0-9]+/terminal.ws\z`, channel.Handler(api)),
+ u.wsRoute(projectPattern+`-/jobs/[0-9]+/terminal.ws\z`, channel.Handler(api)),
+
+ // Proxy Job Services
+ u.wsRoute(projectPattern+`-/jobs/[0-9]+/proxy.ws\z`, channel.Handler(api)),
+
+ // Long poll and limit capacity given to jobs/request and builds/register.json
+ u.route("", apiPattern+`v4/jobs/request\z`, ciAPILongPolling),
+ u.route("", ciAPIPattern+`v1/builds/register.json\z`, ciAPILongPolling),
+
+ // Maven Artifact Repository
+ u.route("PUT", apiPattern+`v4/projects/[0-9]+/packages/maven/`, upload.BodyUploader(api, signingProxy, preparers.packages)),
+
+ // Conan Artifact Repository
+ u.route("PUT", apiPattern+`v4/packages/conan/`, upload.BodyUploader(api, signingProxy, preparers.packages)),
+ u.route("PUT", apiPattern+`v4/projects/[0-9]+/packages/conan/`, upload.BodyUploader(api, signingProxy, preparers.packages)),
+
+ // Generic Packages Repository
+ u.route("PUT", apiPattern+`v4/projects/[0-9]+/packages/generic/`, upload.BodyUploader(api, signingProxy, preparers.packages)),
+
+ // NuGet Artifact Repository
+ u.route("PUT", apiPattern+`v4/projects/[0-9]+/packages/nuget/`, upload.Accelerate(api, signingProxy, preparers.packages)),
+
+ // PyPI Artifact Repository
+ u.route("POST", apiPattern+`v4/projects/[0-9]+/packages/pypi`, upload.Accelerate(api, signingProxy, preparers.packages)),
+
+ // Debian Artifact Repository
+ u.route("PUT", apiPattern+`v4/projects/[0-9]+/-/packages/debian/incoming/`, upload.BodyUploader(api, signingProxy, preparers.packages)),
+
+ // We are porting API to disk acceleration
+ // we need to declare each routes until we have fixed all the routes on the rails codebase.
+ // Overall status can be seen at https://gitlab.com/groups/gitlab-org/-/epics/1802#current-status
+ u.route("POST", apiPattern+`v4/projects/[0-9]+/wikis/attachments\z`, uploadAccelerateProxy),
+ u.route("POST", apiPattern+`graphql\z`, uploadAccelerateProxy),
+ u.route("POST", apiPattern+`v4/groups/import`, upload.Accelerate(api, signingProxy, preparers.uploads)),
+ u.route("POST", apiPattern+`v4/projects/import`, upload.Accelerate(api, signingProxy, preparers.uploads)),
+
+ // Project Import via UI upload acceleration
+ u.route("POST", importPattern+`gitlab_project`, upload.Accelerate(api, signingProxy, preparers.uploads)),
+ // Group Import via UI upload acceleration
+ u.route("POST", importPattern+`gitlab_group`, upload.Accelerate(api, signingProxy, preparers.uploads)),
+
+ // Metric image upload
+ u.route("POST", apiPattern+`v4/projects/[0-9]+/issues/[0-9]+/metric_images\z`, upload.Accelerate(api, signingProxy, preparers.uploads)),
+
+ // Requirements Import via UI upload acceleration
+ u.route("POST", projectPattern+`requirements_management/requirements/import_csv`, upload.Accelerate(api, signingProxy, preparers.uploads)),
+
+ // Explicitly proxy API requests
+ u.route("", apiPattern, proxy),
+ u.route("", ciAPIPattern, proxy),
+
+ // Serve assets
+ u.route(
+ "", `^/assets/`,
+ static.ServeExisting(
+ u.URLPrefix,
+ staticpages.CacheExpireMax,
+ assetsNotFoundHandler,
+ ),
+ withoutTracing(), // Tracing on assets is very noisy
+ ),
+
+ // Uploads
+ u.route("POST", projectPattern+`uploads\z`, upload.Accelerate(api, signingProxy, preparers.uploads)),
+ u.route("POST", snippetUploadPattern, upload.Accelerate(api, signingProxy, preparers.uploads)),
+ u.route("POST", userUploadPattern, upload.Accelerate(api, signingProxy, preparers.uploads)),
+
+ // For legacy reasons, user uploads are stored under the document root.
+ // To prevent anybody who knows/guesses the URL of a user-uploaded file
+ // from downloading it we make sure requests to /uploads/ do _not_ pass
+ // through static.ServeExisting.
+ u.route("", `^/uploads/`, static.ErrorPagesUnless(u.DevelopmentMode, staticpages.ErrorFormatHTML, proxy)),
+
+ // health checks don't intercept errors and go straight to rails
+ // TODO: We should probably not return a HTML deploy page?
+ // https://gitlab.com/gitlab-org/gitlab-workhorse/issues/230
+ u.route("", "^/-/(readiness|liveness)$", static.DeployPage(probeUpstream)),
+ u.route("", "^/-/health$", static.DeployPage(healthUpstream)),
+
+ // This route lets us filter out health checks from our metrics.
+ u.route("", "^/-/", defaultUpstream),
+
+ u.route("", "", defaultUpstream),
+ }
+}
+
+func createUploadPreparers(cfg config.Config) uploadPreparers {
+ defaultPreparer := upload.NewObjectStoragePreparer(cfg)
+
+ return uploadPreparers{
+ artifacts: defaultPreparer,
+ lfs: lfs.NewLfsUploadPreparer(cfg, defaultPreparer),
+ packages: defaultPreparer,
+ uploads: defaultPreparer,
+ }
+}
+
+func denyWebsocket(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if websocket.IsWebSocketUpgrade(r) {
+ helper.HTTPError(w, r, "websocket upgrade not allowed", http.StatusBadRequest)
+ return
+ }
+
+ next.ServeHTTP(w, r)
+ })
+}
diff --git a/workhorse/internal/upstream/upstream.go b/workhorse/internal/upstream/upstream.go
new file mode 100644
index 00000000000..fd3f6191a5a
--- /dev/null
+++ b/workhorse/internal/upstream/upstream.go
@@ -0,0 +1,123 @@
+/*
+The upstream type implements http.Handler.
+
+In this file we handle request routing and interaction with the authBackend.
+*/
+
+package upstream
+
+import (
+ "fmt"
+
+ "net/http"
+ "strings"
+
+ "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/labkit/correlation"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream/roundtripper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/urlprefix"
+)
+
+var (
+ DefaultBackend = helper.URLMustParse("http://localhost:8080")
+ requestHeaderBlacklist = []string{
+ upload.RewrittenFieldsHeader,
+ }
+)
+
+type upstream struct {
+ config.Config
+ URLPrefix urlprefix.Prefix
+ Routes []routeEntry
+ RoundTripper http.RoundTripper
+ CableRoundTripper http.RoundTripper
+ accessLogger *logrus.Logger
+}
+
+func NewUpstream(cfg config.Config, accessLogger *logrus.Logger) http.Handler {
+ up := upstream{
+ Config: cfg,
+ accessLogger: accessLogger,
+ }
+ if up.Backend == nil {
+ up.Backend = DefaultBackend
+ }
+ if up.CableBackend == nil {
+ up.CableBackend = up.Backend
+ }
+ if up.CableSocket == "" {
+ up.CableSocket = up.Socket
+ }
+ up.RoundTripper = roundtripper.NewBackendRoundTripper(up.Backend, up.Socket, up.ProxyHeadersTimeout, cfg.DevelopmentMode)
+ up.CableRoundTripper = roundtripper.NewBackendRoundTripper(up.CableBackend, up.CableSocket, up.ProxyHeadersTimeout, cfg.DevelopmentMode)
+ up.configureURLPrefix()
+ up.configureRoutes()
+
+ var correlationOpts []correlation.InboundHandlerOption
+ if cfg.PropagateCorrelationID {
+ correlationOpts = append(correlationOpts, correlation.WithPropagation())
+ }
+
+ handler := correlation.InjectCorrelationID(&up, correlationOpts...)
+ return handler
+}
+
+func (u *upstream) configureURLPrefix() {
+ relativeURLRoot := u.Backend.Path
+ if !strings.HasSuffix(relativeURLRoot, "/") {
+ relativeURLRoot += "/"
+ }
+ u.URLPrefix = urlprefix.Prefix(relativeURLRoot)
+}
+
+func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ helper.FixRemoteAddr(r)
+
+ helper.DisableResponseBuffering(w)
+
+ // Drop RequestURI == "*" (FIXME: why?)
+ if r.RequestURI == "*" {
+ helper.HTTPError(w, r, "Connection upgrade not allowed", http.StatusBadRequest)
+ return
+ }
+
+ // Disallow connect
+ if r.Method == "CONNECT" {
+ helper.HTTPError(w, r, "CONNECT not allowed", http.StatusBadRequest)
+ return
+ }
+
+ // Check URL Root
+ URIPath := urlprefix.CleanURIPath(r.URL.Path)
+ prefix := u.URLPrefix
+ if !prefix.Match(URIPath) {
+ helper.HTTPError(w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
+ return
+ }
+
+ // Look for a matching route
+ var route *routeEntry
+ for _, ro := range u.Routes {
+ if ro.isMatch(prefix.Strip(URIPath), r) {
+ route = &ro
+ break
+ }
+ }
+
+ if route == nil {
+ // The protocol spec in git/Documentation/technical/http-protocol.txt
+ // says we must return 403 if no matching service is found.
+ helper.HTTPError(w, r, "Forbidden", http.StatusForbidden)
+ return
+ }
+
+ for _, h := range requestHeaderBlacklist {
+ r.Header.Del(h)
+ }
+
+ route.handler.ServeHTTP(w, r)
+}
diff --git a/workhorse/internal/urlprefix/urlprefix.go b/workhorse/internal/urlprefix/urlprefix.go
new file mode 100644
index 00000000000..23eefe70c67
--- /dev/null
+++ b/workhorse/internal/urlprefix/urlprefix.go
@@ -0,0 +1,35 @@
+package urlprefix
+
+import (
+ "path"
+ "strings"
+)
+
+type Prefix string
+
+func (p Prefix) Strip(path string) string {
+ return CleanURIPath(strings.TrimPrefix(path, string(p)))
+}
+
+func (p Prefix) Match(path string) bool {
+ pre := string(p)
+ return strings.HasPrefix(path, pre) || path+"/" == pre
+}
+
+// Borrowed from: net/http/server.go
+// Return the canonical path for p, eliminating . and .. elements.
+func CleanURIPath(p string) string {
+ if p == "" {
+ return "/"
+ }
+ if p[0] != '/' {
+ p = "/" + p
+ }
+ np := path.Clean(p)
+ // path.Clean removes trailing slash except for root;
+ // put the trailing slash back if necessary.
+ if p[len(p)-1] == '/' && np != "/" {
+ np += "/"
+ }
+ return np
+}
diff --git a/workhorse/internal/utils/svg/LICENSE b/workhorse/internal/utils/svg/LICENSE
new file mode 100644
index 00000000000..f67807d0070
--- /dev/null
+++ b/workhorse/internal/utils/svg/LICENSE
@@ -0,0 +1,24 @@
+The MIT License
+
+Copyright (c) 2016 Tomas Aparicio
+
+Permission is hereby granted, free of charge, to any person
+obtaining a copy of this software and associated documentation
+files (the "Software"), to deal in the Software without
+restriction, including without limitation the rights to use,
+copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following
+conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+OTHER DEALINGS IN THE SOFTWARE.
diff --git a/workhorse/internal/utils/svg/README.md b/workhorse/internal/utils/svg/README.md
new file mode 100644
index 00000000000..e5531f47473
--- /dev/null
+++ b/workhorse/internal/utils/svg/README.md
@@ -0,0 +1,45 @@
+# go-is-svg
+
+Tiny package to verify if a given file buffer is an SVG image in Go (golang).
+
+## Installation
+
+```bash
+go get -u github.com/h2non/go-is-svg
+```
+
+## Example
+
+```go
+package main
+
+import (
+ "fmt"
+ "io/ioutil"
+
+ svg "github.com/h2non/go-is-svg"
+)
+
+func main() {
+ buf, err := ioutil.ReadFile("_example/example.svg")
+ if err != nil {
+ fmt.Printf("Error: %s\n", err)
+ return
+ }
+
+ if svg.Is(buf) {
+ fmt.Println("File is an SVG")
+ } else {
+ fmt.Println("File is NOT an SVG")
+ }
+}
+```
+
+Run example:
+```bash
+go run _example/example.go
+```
+
+## License
+
+MIT - Tomas Aparicio
diff --git a/workhorse/internal/utils/svg/svg.go b/workhorse/internal/utils/svg/svg.go
new file mode 100644
index 00000000000..b209cb5bf33
--- /dev/null
+++ b/workhorse/internal/utils/svg/svg.go
@@ -0,0 +1,42 @@
+// Copyright (c) 2016 Tomas Aparicio. All rights reserved.
+//
+// Use of this source code is governed by a MIT License
+// license that can be found in the LICENSE file or at
+// https://github.com/h2non/go-is-svg/blob/master/LICENSE.
+
+package svg
+
+import (
+ "regexp"
+ "unicode/utf8"
+)
+
+var (
+ htmlCommentRegex = regexp.MustCompile(`(?i)<!--([\s\S]*?)-->`)
+ svgRegex = regexp.MustCompile(`(?i)^\s*(?:<\?xml[^>]*>\s*)?(?:<!doctype svg[^>]*>\s*)?<svg[^>]*>`)
+)
+
+// isBinary checks if the given buffer is a binary file.
+func isBinary(buf []byte) bool {
+ if len(buf) < 24 {
+ return false
+ }
+ for i := 0; i < 24; i++ {
+ charCode, _ := utf8.DecodeRuneInString(string(buf[i]))
+ if charCode == 65533 || charCode <= 8 {
+ return true
+ }
+ }
+ return false
+}
+
+// Is returns true if the given buffer is a valid SVG image.
+func Is(buf []byte) bool {
+ return !isBinary(buf) && svgRegex.Match(htmlCommentRegex.ReplaceAll(buf, []byte{}))
+}
+
+// IsSVG returns true if the given buffer is a valid SVG image.
+// Alias to: Is()
+func IsSVG(buf []byte) bool {
+ return Is(buf)
+}
diff --git a/workhorse/internal/zipartifacts/.gitignore b/workhorse/internal/zipartifacts/.gitignore
new file mode 100644
index 00000000000..ace1063ab02
--- /dev/null
+++ b/workhorse/internal/zipartifacts/.gitignore
@@ -0,0 +1 @@
+/testdata
diff --git a/workhorse/internal/zipartifacts/entry.go b/workhorse/internal/zipartifacts/entry.go
new file mode 100644
index 00000000000..527387ceaa1
--- /dev/null
+++ b/workhorse/internal/zipartifacts/entry.go
@@ -0,0 +1,13 @@
+package zipartifacts
+
+import (
+ "encoding/base64"
+)
+
+func DecodeFileEntry(entry string) (string, error) {
+ decoded, err := base64.StdEncoding.DecodeString(entry)
+ if err != nil {
+ return "", err
+ }
+ return string(decoded), nil
+}
diff --git a/workhorse/internal/zipartifacts/errors.go b/workhorse/internal/zipartifacts/errors.go
new file mode 100644
index 00000000000..162816618f8
--- /dev/null
+++ b/workhorse/internal/zipartifacts/errors.go
@@ -0,0 +1,57 @@
+package zipartifacts
+
+import (
+ "errors"
+)
+
+// These are exit codes used by subprocesses in cmd/gitlab-zip-xxx. We also use
+// them to map errors and error messages that we use as label in Prometheus.
+const (
+ CodeNotZip = 10 + iota
+ CodeEntryNotFound
+ CodeArchiveNotFound
+ CodeLimitsReached
+ CodeUnknownError
+)
+
+var (
+ ErrorCode = map[int]error{
+ CodeNotZip: errors.New("zip archive format invalid"),
+ CodeEntryNotFound: errors.New("zip entry not found"),
+ CodeArchiveNotFound: errors.New("zip archive not found"),
+ CodeLimitsReached: errors.New("zip processing limits reached"),
+ CodeUnknownError: errors.New("zip processing unknown error"),
+ }
+
+ ErrorLabel = map[int]string{
+ CodeNotZip: "archive_invalid",
+ CodeEntryNotFound: "entry_not_found",
+ CodeArchiveNotFound: "archive_not_found",
+ CodeLimitsReached: "limits_reached",
+ CodeUnknownError: "unknown_error",
+ }
+
+ ErrBadMetadata = errors.New("zip artifacts metadata invalid")
+)
+
+// ExitCodeByError find an os.Exit code for a corresponding error.
+// CodeUnkownError in case it can not be found.
+func ExitCodeByError(err error) int {
+ for c, e := range ErrorCode {
+ if err == e {
+ return c
+ }
+ }
+
+ return CodeUnknownError
+}
+
+// ErrorLabelByCode returns a Prometheus counter label associated with an exit code.
+func ErrorLabelByCode(code int) string {
+ label, ok := ErrorLabel[code]
+ if ok {
+ return label
+ }
+
+ return ErrorLabel[CodeUnknownError]
+}
diff --git a/workhorse/internal/zipartifacts/errors_test.go b/workhorse/internal/zipartifacts/errors_test.go
new file mode 100644
index 00000000000..6fce160b3bc
--- /dev/null
+++ b/workhorse/internal/zipartifacts/errors_test.go
@@ -0,0 +1,32 @@
+package zipartifacts
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestExitCodeByError(t *testing.T) {
+ t.Run("when error has been recognized", func(t *testing.T) {
+ code := ExitCodeByError(ErrorCode[CodeLimitsReached])
+
+ require.Equal(t, code, CodeLimitsReached)
+ require.Greater(t, code, 10)
+ })
+
+ t.Run("when error is an unknown one", func(t *testing.T) {
+ code := ExitCodeByError(errors.New("unknown error"))
+
+ require.Equal(t, code, CodeUnknownError)
+ require.Greater(t, code, 10)
+ })
+}
+
+func TestErrorLabels(t *testing.T) {
+ for code := range ErrorCode {
+ _, ok := ErrorLabel[code]
+
+ require.True(t, ok)
+ }
+}
diff --git a/workhorse/internal/zipartifacts/metadata.go b/workhorse/internal/zipartifacts/metadata.go
new file mode 100644
index 00000000000..1ecf52deafb
--- /dev/null
+++ b/workhorse/internal/zipartifacts/metadata.go
@@ -0,0 +1,117 @@
+package zipartifacts
+
+import (
+ "archive/zip"
+ "compress/gzip"
+ "encoding/binary"
+ "encoding/json"
+ "io"
+ "path"
+ "sort"
+ "strconv"
+)
+
+type metadata struct {
+ Modified int64 `json:"modified,omitempty"`
+ Mode string `json:"mode,omitempty"`
+ CRC uint32 `json:"crc,omitempty"`
+ Size uint64 `json:"size,omitempty"`
+ Zipped uint64 `json:"zipped,omitempty"`
+ Comment string `json:"comment,omitempty"`
+}
+
+const MetadataHeaderPrefix = "\x00\x00\x00&" // length of string below, encoded properly
+const MetadataHeader = "GitLab Build Artifacts Metadata 0.0.2\n"
+
+func newMetadata(file *zip.File) metadata {
+ if file == nil {
+ return metadata{}
+ }
+
+ return metadata{
+ //lint:ignore SA1019 Remove this once the minimum supported version is go 1.10 (go 1.9 and down do not support an alternative)
+ Modified: file.ModTime().Unix(),
+ Mode: strconv.FormatUint(uint64(file.Mode().Perm()), 8),
+ CRC: file.CRC32,
+ Size: file.UncompressedSize64,
+ Zipped: file.CompressedSize64,
+ Comment: file.Comment,
+ }
+}
+
+func (m metadata) writeEncoded(output io.Writer) error {
+ j, err := json.Marshal(m)
+ if err != nil {
+ return err
+ }
+ j = append(j, byte('\n'))
+ return writeBytes(output, j)
+}
+
+func writeZipEntryMetadata(output io.Writer, path string, entry *zip.File) error {
+ if err := writeString(output, path); err != nil {
+ return err
+ }
+
+ if err := newMetadata(entry).writeEncoded(output); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func GenerateZipMetadata(w io.Writer, archive *zip.Reader) error {
+ output := gzip.NewWriter(w)
+ defer output.Close()
+
+ if err := writeString(output, MetadataHeader); err != nil {
+ return err
+ }
+
+ // Write empty error header that we may need in the future
+ if err := writeString(output, "{}"); err != nil {
+ return err
+ }
+
+ // Create map of files in zip archive
+ zipMap := make(map[string]*zip.File, len(archive.File))
+
+ // Add missing entries
+ for _, entry := range archive.File {
+ zipMap[entry.Name] = entry
+
+ for d := path.Dir(entry.Name); d != "." && d != "/"; d = path.Dir(d) {
+ entryDir := d + "/"
+ if _, ok := zipMap[entryDir]; !ok {
+ zipMap[entryDir] = nil
+ }
+ }
+ }
+
+ // Sort paths
+ sortedPaths := make([]string, 0, len(zipMap))
+ for path := range zipMap {
+ sortedPaths = append(sortedPaths, path)
+ }
+ sort.Strings(sortedPaths)
+
+ // Write all files
+ for _, path := range sortedPaths {
+ if err := writeZipEntryMetadata(output, path, zipMap[path]); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func writeBytes(output io.Writer, data []byte) error {
+ err := binary.Write(output, binary.BigEndian, uint32(len(data)))
+ if err == nil {
+ _, err = output.Write(data)
+ }
+ return err
+}
+
+func writeString(output io.Writer, str string) error {
+ return writeBytes(output, []byte(str))
+}
diff --git a/workhorse/internal/zipartifacts/metadata_test.go b/workhorse/internal/zipartifacts/metadata_test.go
new file mode 100644
index 00000000000..0f130ab4c15
--- /dev/null
+++ b/workhorse/internal/zipartifacts/metadata_test.go
@@ -0,0 +1,102 @@
+package zipartifacts_test
+
+import (
+ "archive/zip"
+ "bytes"
+ "compress/gzip"
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts"
+)
+
+func generateTestArchive(w io.Writer) error {
+ archive := zip.NewWriter(w)
+
+ // non-POSIX paths are here just to test if we never enter infinite loop
+ files := []string{"file1", "some/file/dir/", "some/file/dir/file2", "../../test12/test",
+ "/usr/bin/test", `c:\windows\win32.exe`, `c:/windows/win.dll`, "./f/asd", "/"}
+
+ for _, file := range files {
+ archiveFile, err := archive.Create(file)
+ if err != nil {
+ return err
+ }
+
+ fmt.Fprint(archiveFile, file)
+ }
+
+ return archive.Close()
+}
+
+func validateMetadata(r io.Reader) error {
+ gz, err := gzip.NewReader(r)
+ if err != nil {
+ return err
+ }
+
+ meta, err := ioutil.ReadAll(gz)
+ if err != nil {
+ return err
+ }
+
+ paths := []string{"file1", "some/", "some/file/", "some/file/dir/", "some/file/dir/file2"}
+ for _, path := range paths {
+ if !bytes.Contains(meta, []byte(path+"\x00")) {
+ return fmt.Errorf(fmt.Sprintf("zipartifacts: metadata for path %q not found", path))
+ }
+ }
+
+ return nil
+}
+
+func TestGenerateZipMetadataFromFile(t *testing.T) {
+ var metaBuffer bytes.Buffer
+
+ f, err := ioutil.TempFile("", "workhorse-metadata.zip-")
+ if f != nil {
+ defer os.Remove(f.Name())
+ }
+ require.NoError(t, err)
+ defer f.Close()
+
+ err = generateTestArchive(f)
+ require.NoError(t, err)
+ f.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ archive, err := zipartifacts.OpenArchive(ctx, f.Name())
+ require.NoError(t, err, "zipartifacts: OpenArchive failed")
+
+ err = zipartifacts.GenerateZipMetadata(&metaBuffer, archive)
+ require.NoError(t, err, "zipartifacts: GenerateZipMetadata failed")
+
+ err = validateMetadata(&metaBuffer)
+ require.NoError(t, err)
+}
+
+func TestErrNotAZip(t *testing.T) {
+ f, err := ioutil.TempFile("", "workhorse-metadata.zip-")
+ if f != nil {
+ defer os.Remove(f.Name())
+ }
+ require.NoError(t, err)
+ defer f.Close()
+
+ _, err = fmt.Fprint(f, "Not a zip file")
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ _, err = zipartifacts.OpenArchive(ctx, f.Name())
+ require.Equal(t, zipartifacts.ErrorCode[zipartifacts.CodeNotZip], err, "OpenArchive requires a zip file")
+}
diff --git a/workhorse/internal/zipartifacts/open_archive.go b/workhorse/internal/zipartifacts/open_archive.go
new file mode 100644
index 00000000000..30b86b66c49
--- /dev/null
+++ b/workhorse/internal/zipartifacts/open_archive.go
@@ -0,0 +1,138 @@
+package zipartifacts
+
+import (
+ "archive/zip"
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "os"
+ "strings"
+ "time"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/httprs"
+
+ "gitlab.com/gitlab-org/labkit/correlation"
+ "gitlab.com/gitlab-org/labkit/mask"
+ "gitlab.com/gitlab-org/labkit/tracing"
+)
+
+var httpClient = &http.Client{
+ Transport: tracing.NewRoundTripper(correlation.NewInstrumentedRoundTripper(&http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ DialContext: (&net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 10 * time.Second,
+ }).DialContext,
+ IdleConnTimeout: 30 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ExpectContinueTimeout: 10 * time.Second,
+ ResponseHeaderTimeout: 30 * time.Second,
+ DisableCompression: true,
+ })),
+}
+
+type archive struct {
+ reader io.ReaderAt
+ size int64
+}
+
+// OpenArchive will open a zip.Reader from a local path or a remote object store URL
+// in case of remote url it will make use of ranged requestes to support seeking.
+// If the path do not exists error will be ErrArchiveNotFound,
+// if the file isn't a zip archive error will be ErrNotAZip
+func OpenArchive(ctx context.Context, archivePath string) (*zip.Reader, error) {
+ archive, err := openArchiveLocation(ctx, archivePath)
+ if err != nil {
+ return nil, err
+ }
+
+ return openZipReader(archive.reader, archive.size)
+}
+
+// OpenArchiveWithReaderFunc opens a zip.Reader from either local path or a
+// remote object, similarly to OpenArchive function. The difference is that it
+// allows passing a readerFunc that takes a io.ReaderAt that is either going to
+// be os.File or a custom reader we use to read from object storage. The
+// readerFunc can augment the archive reader and return a type that satisfies
+// io.ReaderAt.
+func OpenArchiveWithReaderFunc(ctx context.Context, location string, readerFunc func(io.ReaderAt, int64) io.ReaderAt) (*zip.Reader, error) {
+ archive, err := openArchiveLocation(ctx, location)
+ if err != nil {
+ return nil, err
+ }
+
+ return openZipReader(readerFunc(archive.reader, archive.size), archive.size)
+}
+
+func openArchiveLocation(ctx context.Context, location string) (*archive, error) {
+ if isURL(location) {
+ return openHTTPArchive(ctx, location)
+ }
+
+ return openFileArchive(ctx, location)
+}
+
+func isURL(path string) bool {
+ return strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://")
+}
+
+func openHTTPArchive(ctx context.Context, archivePath string) (*archive, error) {
+ scrubbedArchivePath := mask.URL(archivePath)
+ req, err := http.NewRequest(http.MethodGet, archivePath, nil)
+ if err != nil {
+ return nil, fmt.Errorf("can't create HTTP GET %q: %v", scrubbedArchivePath, err)
+ }
+ req = req.WithContext(ctx)
+
+ resp, err := httpClient.Do(req.WithContext(ctx))
+ if err != nil {
+ return nil, fmt.Errorf("HTTP GET %q: %v", scrubbedArchivePath, err)
+ } else if resp.StatusCode == http.StatusNotFound {
+ return nil, ErrorCode[CodeArchiveNotFound]
+ } else if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("HTTP GET %q: %d: %v", scrubbedArchivePath, resp.StatusCode, resp.Status)
+ }
+
+ rs := httprs.NewHttpReadSeeker(resp, httpClient)
+
+ go func() {
+ <-ctx.Done()
+ resp.Body.Close()
+ rs.Close()
+ }()
+
+ return &archive{reader: rs, size: resp.ContentLength}, nil
+}
+
+func openFileArchive(ctx context.Context, archivePath string) (*archive, error) {
+ file, err := os.Open(archivePath)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, ErrorCode[CodeArchiveNotFound]
+ }
+ }
+
+ go func() {
+ <-ctx.Done()
+ // We close the archive from this goroutine so that we can safely return a *zip.Reader instead of a *zip.ReadCloser
+ file.Close()
+ }()
+
+ stat, err := file.Stat()
+ if err != nil {
+ return nil, err
+ }
+
+ return &archive{reader: file, size: stat.Size()}, nil
+}
+
+func openZipReader(archive io.ReaderAt, size int64) (*zip.Reader, error) {
+ reader, err := zip.NewReader(archive, size)
+ if err != nil {
+ return nil, ErrorCode[CodeNotZip]
+ }
+
+ return reader, nil
+}
diff --git a/workhorse/internal/zipartifacts/open_archive_test.go b/workhorse/internal/zipartifacts/open_archive_test.go
new file mode 100644
index 00000000000..f7624d053d9
--- /dev/null
+++ b/workhorse/internal/zipartifacts/open_archive_test.go
@@ -0,0 +1,68 @@
+package zipartifacts
+
+import (
+ "archive/zip"
+ "context"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenHTTPArchive(t *testing.T) {
+ const (
+ zipFile = "test.zip"
+ entryName = "hello.txt"
+ contents = "world"
+ testRoot = "testdata/public"
+ )
+
+ require.NoError(t, os.MkdirAll(testRoot, 0755))
+ f, err := os.Create(filepath.Join(testRoot, zipFile))
+ require.NoError(t, err, "create file")
+ defer f.Close()
+
+ zw := zip.NewWriter(f)
+ w, err := zw.Create(entryName)
+ require.NoError(t, err, "create zip entry")
+ _, err = fmt.Fprint(w, contents)
+ require.NoError(t, err, "write zip entry contents")
+ require.NoError(t, zw.Close(), "close zip writer")
+ require.NoError(t, f.Close(), "close file")
+
+ srv := httptest.NewServer(http.FileServer(http.Dir(testRoot)))
+ defer srv.Close()
+
+ zr, err := OpenArchive(context.Background(), srv.URL+"/"+zipFile)
+ require.NoError(t, err, "call OpenArchive")
+ require.Len(t, zr.File, 1)
+
+ zf := zr.File[0]
+ require.Equal(t, entryName, zf.Name, "zip entry name")
+
+ entry, err := zf.Open()
+ require.NoError(t, err, "get zip entry reader")
+ defer entry.Close()
+
+ actualContents, err := ioutil.ReadAll(entry)
+ require.NoError(t, err, "read zip entry contents")
+ require.Equal(t, contents, string(actualContents), "compare zip entry contents")
+}
+
+func TestOpenHTTPArchiveNotSendingAcceptEncodingHeader(t *testing.T) {
+ requestHandler := func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "GET", r.Method)
+ require.Nil(t, r.Header["Accept-Encoding"])
+ w.WriteHeader(http.StatusOK)
+ }
+
+ srv := httptest.NewServer(http.HandlerFunc(requestHandler))
+ defer srv.Close()
+
+ OpenArchive(context.Background(), srv.URL)
+}