diff options
Diffstat (limited to 'workhorse/internal/api')
-rw-r--r-- | workhorse/internal/api/api.go | 345 | ||||
-rw-r--r-- | workhorse/internal/api/block.go | 61 | ||||
-rw-r--r-- | workhorse/internal/api/block_test.go | 56 | ||||
-rw-r--r-- | workhorse/internal/api/channel_settings.go | 122 | ||||
-rw-r--r-- | workhorse/internal/api/channel_settings_test.go | 154 |
5 files changed, 738 insertions(+), 0 deletions(-)
diff --git a/workhorse/internal/api/api.go b/workhorse/internal/api/api.go new file mode 100644 index 00000000000..17fea398029 --- /dev/null +++ b/workhorse/internal/api/api.go @@ -0,0 +1,345 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" +) + +const ( + // Custom content type for API responses, to catch routing / programming mistakes + ResponseContentType = "application/vnd.gitlab-workhorse+json" + + failureResponseLimit = 32768 +) + +type API struct { + Client *http.Client + URL *url.URL + Version string +} + +var ( + requestsCounter = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_internal_api_requests", + Help: "How many internal API requests have been completed by gitlab-workhorse, partitioned by status code and HTTP method.", + }, + []string{"code", "method"}, + ) + bytesTotal = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_internal_api_failure_response_bytes", + Help: "How many bytes have been returned by upstream GitLab in API failure/rejection response bodies.", + }, + ) +) + +func NewAPI(myURL *url.URL, version string, roundTripper http.RoundTripper) *API { + return &API{ + Client: &http.Client{Transport: roundTripper}, + URL: myURL, + Version: version, + } +} + +type HandleFunc func(http.ResponseWriter, *http.Request, *Response) + +type MultipartUploadParams struct { + // PartSize is the exact size of each uploaded part. 
Only the last one can be smaller + PartSize int64 + // PartURLs contains the presigned URLs for each part + PartURLs []string + // CompleteURL is a presigned URL for CompleteMulipartUpload + CompleteURL string + // AbortURL is a presigned URL for AbortMultipartUpload + AbortURL string +} + +type ObjectStorageParams struct { + Provider string + S3Config config.S3Config + GoCloudConfig config.GoCloudConfig +} + +type RemoteObject struct { + // GetURL is an S3 GetObject URL + GetURL string + // DeleteURL is a presigned S3 RemoveObject URL + DeleteURL string + // StoreURL is the temporary presigned S3 PutObject URL to which upload the first found file + StoreURL string + // Boolean to indicate whether to use headers included in PutHeaders + CustomPutHeaders bool + // PutHeaders are HTTP headers (e.g. Content-Type) to be sent with StoreURL + PutHeaders map[string]string + // Whether to ignore Rails pre-signed URLs and have Workhorse directly access object storage provider + UseWorkhorseClient bool + // Remote, temporary object name where Rails will move to the final destination + RemoteTempObjectID string + // ID is a unique identifier of object storage upload + ID string + // Timeout is a number that represents timeout in seconds for sending data to StoreURL + Timeout int + // MultipartUpload contains presigned URLs for S3 MultipartUpload + MultipartUpload *MultipartUploadParams + // Object storage config for Workhorse client + ObjectStorage *ObjectStorageParams +} + +type Response struct { + // GL_ID is an environment variable used by gitlab-shell hooks during 'git + // push' and 'git pull' + GL_ID string + + // GL_USERNAME holds gitlab username of the user who is taking the action causing hooks to be invoked + GL_USERNAME string + + // GL_REPOSITORY is an environment variable used by gitlab-shell hooks during + // 'git push' and 'git pull' + GL_REPOSITORY string + // GitConfigOptions holds the custom options that we want to pass to the git command + GitConfigOptions 
[]string + // StoreLFSPath is provided by the GitLab Rails application to mark where the tmp file should be placed. + // This field is deprecated. GitLab will use TempPath instead + StoreLFSPath string + // LFS object id + LfsOid string + // LFS object size + LfsSize int64 + // TmpPath is the path where we should store temporary files + // This is set by authorization middleware + TempPath string + // RemoteObject is provided by the GitLab Rails application + // and defines a way to store object on remote storage + RemoteObject RemoteObject + // Archive is the path where the artifacts archive is stored + Archive string `json:"archive"` + // Entry is a filename inside the archive point to file that needs to be extracted + Entry string `json:"entry"` + // Used to communicate channel session details + Channel *ChannelSettings + // GitalyServer specifies an address and authentication token for a gitaly server we should connect to. + GitalyServer gitaly.Server + // Repository object for making gRPC requests to Gitaly. + Repository gitalypb.Repository + // For git-http, does the requestor have the right to view all refs? 
+ ShowAllRefs bool + // Detects whether an artifact is used for code intelligence + ProcessLsif bool + // Detects whether LSIF artifact will be parsed with references + ProcessLsifReferences bool + // The maximum accepted size in bytes of the upload + MaximumSize int64 +} + +// singleJoiningSlash is taken from reverseproxy.go:NewSingleHostReverseProxy +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// rebaseUrl is taken from reverseproxy.go:NewSingleHostReverseProxy +func rebaseUrl(url *url.URL, onto *url.URL, suffix string) *url.URL { + newUrl := *url + newUrl.Scheme = onto.Scheme + newUrl.Host = onto.Host + if suffix != "" { + newUrl.Path = singleJoiningSlash(url.Path, suffix) + } + if onto.RawQuery == "" || newUrl.RawQuery == "" { + newUrl.RawQuery = onto.RawQuery + newUrl.RawQuery + } else { + newUrl.RawQuery = onto.RawQuery + "&" + newUrl.RawQuery + } + return &newUrl +} + +func (api *API) newRequest(r *http.Request, suffix string) (*http.Request, error) { + authReq := &http.Request{ + Method: r.Method, + URL: rebaseUrl(r.URL, api.URL, suffix), + Header: helper.HeaderClone(r.Header), + } + + authReq = authReq.WithContext(r.Context()) + + // Clean some headers when issuing a new request without body + authReq.Header.Del("Content-Type") + authReq.Header.Del("Content-Encoding") + authReq.Header.Del("Content-Length") + authReq.Header.Del("Content-Disposition") + authReq.Header.Del("Accept-Encoding") + + // Hop-by-hop headers. These are removed when sent to the backend. 
+ // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html + authReq.Header.Del("Transfer-Encoding") + authReq.Header.Del("Connection") + authReq.Header.Del("Keep-Alive") + authReq.Header.Del("Proxy-Authenticate") + authReq.Header.Del("Proxy-Authorization") + authReq.Header.Del("Te") + authReq.Header.Del("Trailers") + authReq.Header.Del("Upgrade") + + // Also forward the Host header, which is excluded from the Header map by the http library. + // This allows the Host header received by the backend to be consistent with other + // requests not going through gitlab-workhorse. + authReq.Host = r.Host + + return authReq, nil +} + +// PreAuthorize performs a pre-authorization check against the API for the given HTTP request +// +// If `outErr` is set, the other fields will be nil and it should be treated as +// a 500 error. +// +// If httpResponse is present, the caller is responsible for closing its body +// +// authResponse will only be present if the authorization check was successful +func (api *API) PreAuthorize(suffix string, r *http.Request) (httpResponse *http.Response, authResponse *Response, outErr error) { + authReq, err := api.newRequest(r, suffix) + if err != nil { + return nil, nil, fmt.Errorf("preAuthorizeHandler newUpstreamRequest: %v", err) + } + + httpResponse, err = api.doRequestWithoutRedirects(authReq) + if err != nil { + return nil, nil, fmt.Errorf("preAuthorizeHandler: do request: %v", err) + } + defer func() { + if outErr != nil { + httpResponse.Body.Close() + httpResponse = nil + } + }() + requestsCounter.WithLabelValues(strconv.Itoa(httpResponse.StatusCode), authReq.Method).Inc() + + // This may be a false positive, e.g. for .../info/refs, rather than a + // failure, so pass the response back + if httpResponse.StatusCode != http.StatusOK || !validResponseContentType(httpResponse) { + return httpResponse, nil, nil + } + + authResponse = &Response{} + // The auth backend validated the client request and told us additional + // request metadata. 
We must extract this information from the auth + // response body. + if err := json.NewDecoder(httpResponse.Body).Decode(authResponse); err != nil { + return httpResponse, nil, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err) + } + + return httpResponse, authResponse, nil +} + +func (api *API) PreAuthorizeHandler(next HandleFunc, suffix string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + httpResponse, authResponse, err := api.PreAuthorize(suffix, r) + if httpResponse != nil { + defer httpResponse.Body.Close() + } + + if err != nil { + helper.Fail500(w, r, err) + return + } + + // The response couldn't be interpreted as a valid auth response, so + // pass it back (mostly) unmodified + if httpResponse != nil && authResponse == nil { + passResponseBack(httpResponse, w, r) + return + } + + httpResponse.Body.Close() // Free up the Unicorn worker + + copyAuthHeader(httpResponse, w) + + next(w, r, authResponse) + }) +} + +func (api *API) doRequestWithoutRedirects(authReq *http.Request) (*http.Response, error) { + signingTripper := secret.NewRoundTripper(api.Client.Transport, api.Version) + + return signingTripper.RoundTrip(authReq) +} + +func copyAuthHeader(httpResponse *http.Response, w http.ResponseWriter) { + // Negotiate authentication (Kerberos) may need to return a WWW-Authenticate + // header to the client even in case of success as per RFC4559. + for k, v := range httpResponse.Header { + // Case-insensitive comparison as per RFC7230 + if strings.EqualFold(k, "WWW-Authenticate") { + w.Header()[k] = v + } + } +} + +func passResponseBack(httpResponse *http.Response, w http.ResponseWriter, r *http.Request) { + // NGINX response buffering is disabled on this path (with + // X-Accel-Buffering: no) but we still want to free up the Unicorn worker + // that generated httpResponse as fast as possible. To do this we buffer + // the entire response body in memory before sending it on. 
+ responseBody, err := bufferResponse(httpResponse.Body) + if err != nil { + helper.Fail500(w, r, err) + return + } + httpResponse.Body.Close() // Free up the Unicorn worker + bytesTotal.Add(float64(responseBody.Len())) + + for k, v := range httpResponse.Header { + // Accommodate broken clients that do case-sensitive header lookup + if k == "Www-Authenticate" { + w.Header()["WWW-Authenticate"] = v + } else { + w.Header()[k] = v + } + } + w.WriteHeader(httpResponse.StatusCode) + if _, err := io.Copy(w, responseBody); err != nil { + helper.LogError(r, err) + } +} + +func bufferResponse(r io.Reader) (*bytes.Buffer, error) { + responseBody := &bytes.Buffer{} + n, err := io.Copy(responseBody, io.LimitReader(r, failureResponseLimit)) + if err != nil { + return nil, err + } + + if n == failureResponseLimit { + return nil, fmt.Errorf("response body exceeded maximum buffer size (%d bytes)", failureResponseLimit) + } + + return responseBody, nil +} + +func validResponseContentType(resp *http.Response) bool { + return helper.IsContentType(ResponseContentType, resp.Header.Get("Content-Type")) +} diff --git a/workhorse/internal/api/block.go b/workhorse/internal/api/block.go new file mode 100644 index 00000000000..92322906c03 --- /dev/null +++ b/workhorse/internal/api/block.go @@ -0,0 +1,61 @@ +package api + +import ( + "fmt" + "net/http" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +// Prevent internal API responses intended for gitlab-workhorse from +// leaking to the end user +func Block(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rw := &blocker{rw: w, r: r} + defer rw.flush() + h.ServeHTTP(rw, r) + }) +} + +type blocker struct { + rw http.ResponseWriter + r *http.Request + hijacked bool + status int +} + +func (b *blocker) Header() http.Header { + return b.rw.Header() +} + +func (b *blocker) Write(data []byte) (int, error) { + if b.status == 0 { + b.WriteHeader(http.StatusOK) + } + if 
b.hijacked { + return len(data), nil + } + + return b.rw.Write(data) +} + +func (b *blocker) WriteHeader(status int) { + if b.status != 0 { + return + } + + if helper.IsContentType(ResponseContentType, b.Header().Get("Content-Type")) { + b.status = 500 + b.Header().Del("Content-Length") + b.hijacked = true + helper.Fail500(b.rw, b.r, fmt.Errorf("api.blocker: forbidden content-type: %q", ResponseContentType)) + return + } + + b.status = status + b.rw.WriteHeader(b.status) +} + +func (b *blocker) flush() { + b.WriteHeader(http.StatusOK) +} diff --git a/workhorse/internal/api/block_test.go b/workhorse/internal/api/block_test.go new file mode 100644 index 00000000000..85ad54f3cfd --- /dev/null +++ b/workhorse/internal/api/block_test.go @@ -0,0 +1,56 @@ +package api + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBlocker(t *testing.T) { + upstreamResponse := "hello world" + + testCases := []struct { + desc string + contentType string + out string + }{ + { + desc: "blocked", + contentType: ResponseContentType, + out: "Internal server error\n", + }, + { + desc: "pass", + contentType: "text/plain", + out: upstreamResponse, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + r, err := http.NewRequest("GET", "/foo", nil) + require.NoError(t, err) + + rw := httptest.NewRecorder() + bl := &blocker{rw: rw, r: r} + bl.Header().Set("Content-Type", tc.contentType) + + upstreamBody := []byte(upstreamResponse) + n, err := bl.Write(upstreamBody) + require.NoError(t, err) + require.Equal(t, len(upstreamBody), n, "bytes written") + + rw.Flush() + + body := rw.Result().Body + data, err := ioutil.ReadAll(body) + require.NoError(t, err) + require.NoError(t, body.Close()) + + require.Equal(t, tc.out, string(data)) + }) + } +} diff --git a/workhorse/internal/api/channel_settings.go b/workhorse/internal/api/channel_settings.go new file mode 100644 index 00000000000..bf3094c9c91 
--- /dev/null +++ b/workhorse/internal/api/channel_settings.go @@ -0,0 +1,122 @@ +package api + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "net/url" + + "github.com/gorilla/websocket" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +type ChannelSettings struct { + // The channel provider may require use of a particular subprotocol. If so, + // it must be specified here, and Workhorse must have a matching codec. + Subprotocols []string + + // The websocket URL to connect to. + Url string + + // Any headers (e.g., Authorization) to send with the websocket request + Header http.Header + + // The CA roots to validate the remote endpoint with, for wss:// URLs. The + // system-provided CA pool will be used if this is blank. PEM-encoded data. + CAPem string + + // The value is specified in seconds. It is converted to time.Duration + // later. + MaxSessionTime int +} + +func (t *ChannelSettings) URL() (*url.URL, error) { + return url.Parse(t.Url) +} + +func (t *ChannelSettings) Dialer() *websocket.Dialer { + dialer := &websocket.Dialer{ + Subprotocols: t.Subprotocols, + } + + if len(t.CAPem) > 0 { + pool := x509.NewCertPool() + pool.AppendCertsFromPEM([]byte(t.CAPem)) + dialer.TLSClientConfig = &tls.Config{RootCAs: pool} + } + + return dialer +} + +func (t *ChannelSettings) Clone() *ChannelSettings { + // Doesn't clone the strings, but that's OK as strings are immutable in go + cloned := *t + cloned.Header = helper.HeaderClone(t.Header) + return &cloned +} + +func (t *ChannelSettings) Dial() (*websocket.Conn, *http.Response, error) { + return t.Dialer().Dial(t.Url, t.Header) +} + +func (t *ChannelSettings) Validate() error { + if t == nil { + return fmt.Errorf("channel details not specified") + } + + if len(t.Subprotocols) == 0 { + return fmt.Errorf("no subprotocol specified") + } + + parsedURL, err := t.URL() + if err != nil { + return fmt.Errorf("invalid URL") + } + + if parsedURL.Scheme != "ws" && parsedURL.Scheme != "wss" { + 
return fmt.Errorf("invalid websocket scheme: %q", parsedURL.Scheme) + } + + return nil +} + +func (t *ChannelSettings) IsEqual(other *ChannelSettings) bool { + if t == nil && other == nil { + return true + } + + if t == nil || other == nil { + return false + } + + if len(t.Subprotocols) != len(other.Subprotocols) { + return false + } + + for i, subprotocol := range t.Subprotocols { + if other.Subprotocols[i] != subprotocol { + return false + } + } + + if len(t.Header) != len(other.Header) { + return false + } + + for header, values := range t.Header { + if len(values) != len(other.Header[header]) { + return false + } + for i, value := range values { + if other.Header[header][i] != value { + return false + } + } + } + + return t.Url == other.Url && + t.CAPem == other.CAPem && + t.MaxSessionTime == other.MaxSessionTime +} diff --git a/workhorse/internal/api/channel_settings_test.go b/workhorse/internal/api/channel_settings_test.go new file mode 100644 index 00000000000..4aa2c835579 --- /dev/null +++ b/workhorse/internal/api/channel_settings_test.go @@ -0,0 +1,154 @@ +package api + +import ( + "net/http" + "testing" +) + +func channel(url string, subprotocols ...string) *ChannelSettings { + return &ChannelSettings{ + Url: url, + Subprotocols: subprotocols, + MaxSessionTime: 0, + } +} + +func ca(channel *ChannelSettings) *ChannelSettings { + channel = channel.Clone() + channel.CAPem = "Valid CA data" + + return channel +} + +func timeout(channel *ChannelSettings) *ChannelSettings { + channel = channel.Clone() + channel.MaxSessionTime = 600 + + return channel +} + +func header(channel *ChannelSettings, values ...string) *ChannelSettings { + if len(values) == 0 { + values = []string{"Dummy Value"} + } + + channel = channel.Clone() + channel.Header = http.Header{ + "Header": values, + } + + return channel +} + +func TestClone(t *testing.T) { + a := ca(header(channel("ws:", "", ""))) + b := a.Clone() + + if a == b { + t.Fatalf("Address of cloned channel didn't change") + } 
+ + if &a.Subprotocols == &b.Subprotocols { + t.Fatalf("Address of cloned subprotocols didn't change") + } + + if &a.Header == &b.Header { + t.Fatalf("Address of cloned header didn't change") + } +} + +func TestValidate(t *testing.T) { + for i, tc := range []struct { + channel *ChannelSettings + valid bool + msg string + }{ + {nil, false, "nil channel"}, + {channel("", ""), false, "empty URL"}, + {channel("ws:"), false, "empty subprotocols"}, + {channel("ws:", "foo"), true, "any subprotocol"}, + {channel("ws:", "foo", "bar"), true, "multiple subprotocols"}, + {channel("ws:", ""), true, "websocket URL"}, + {channel("wss:", ""), true, "secure websocket URL"}, + {channel("http:", ""), false, "HTTP URL"}, + {channel("https:", ""), false, " HTTPS URL"}, + {ca(channel("ws:", "")), true, "any CA pem"}, + {header(channel("ws:", "")), true, "any headers"}, + {ca(header(channel("ws:", ""))), true, "PEM and headers"}, + } { + if err := tc.channel.Validate(); (err != nil) == tc.valid { + t.Fatalf("test case %d: "+tc.msg+": valid=%v: %s: %+v", i, tc.valid, err, tc.channel) + } + } +} + +func TestDialer(t *testing.T) { + channel := channel("ws:", "foo") + dialer := channel.Dialer() + + if len(dialer.Subprotocols) != len(channel.Subprotocols) { + t.Fatalf("Subprotocols don't match: %+v vs. %+v", channel.Subprotocols, dialer.Subprotocols) + } + + for i, subprotocol := range channel.Subprotocols { + if dialer.Subprotocols[i] != subprotocol { + t.Fatalf("Subprotocols don't match: %+v vs. 
%+v", channel.Subprotocols, dialer.Subprotocols) + } + } + + if dialer.TLSClientConfig != nil { + t.Fatalf("Unexpected TLSClientConfig: %+v", dialer) + } + + channel = ca(channel) + dialer = channel.Dialer() + + if dialer.TLSClientConfig == nil || dialer.TLSClientConfig.RootCAs == nil { + t.Fatalf("Custom CA certificates not recognised!") + } +} + +func TestIsEqual(t *testing.T) { + chann := channel("ws:", "foo") + + chann_header2 := header(chann, "extra") + chann_header3 := header(chann) + chann_header3.Header.Add("Extra", "extra") + + chann_ca2 := ca(chann) + chann_ca2.CAPem = "other value" + + for i, tc := range []struct { + channelA *ChannelSettings + channelB *ChannelSettings + expected bool + }{ + {nil, nil, true}, + {chann, nil, false}, + {nil, chann, false}, + {chann, chann, true}, + {chann.Clone(), chann.Clone(), true}, + {chann, channel("foo:"), false}, + {chann, channel(chann.Url), false}, + {header(chann), header(chann), true}, + {chann_header2, chann_header2, true}, + {chann_header3, chann_header3, true}, + {header(chann), chann_header2, false}, + {header(chann), chann_header3, false}, + {header(chann), chann, false}, + {chann, header(chann), false}, + {ca(chann), ca(chann), true}, + {ca(chann), chann, false}, + {chann, ca(chann), false}, + {ca(header(chann)), ca(header(chann)), true}, + {chann_ca2, ca(chann), false}, + {chann, timeout(chann), false}, + } { + if actual := tc.channelA.IsEqual(tc.channelB); tc.expected != actual { + t.Fatalf( + "test case %d: Comparison:\n-%+v\n+%+v\nexpected=%v: actual=%v", + i, tc.channelA, tc.channelB, tc.expected, actual, + ) + } + } +} |