diff options
Diffstat (limited to 'workhorse/internal/staticpages')
-rw-r--r-- | workhorse/internal/staticpages/deploy_page.go | 26 | ||||
-rw-r--r-- | workhorse/internal/staticpages/deploy_page_test.go | 59 | ||||
-rw-r--r-- | workhorse/internal/staticpages/error_pages.go | 138 | ||||
-rw-r--r-- | workhorse/internal/staticpages/error_pages_test.go | 191 | ||||
-rw-r--r-- | workhorse/internal/staticpages/servefile.go | 84 | ||||
-rw-r--r-- | workhorse/internal/staticpages/servefile_test.go | 134 | ||||
-rw-r--r-- | workhorse/internal/staticpages/static.go | 5 |
7 files changed, 637 insertions, 0 deletions
diff --git a/workhorse/internal/staticpages/deploy_page.go b/workhorse/internal/staticpages/deploy_page.go new file mode 100644 index 00000000000..d08ed449ae6 --- /dev/null +++ b/workhorse/internal/staticpages/deploy_page.go @@ -0,0 +1,26 @@ +package staticpages + +import ( + "io/ioutil" + "net/http" + "path/filepath" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +func (s *Static) DeployPage(handler http.Handler) http.Handler { + deployPage := filepath.Join(s.DocumentRoot, "index.html") + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data, err := ioutil.ReadFile(deployPage) + if err != nil { + handler.ServeHTTP(w, r) + return + } + + helper.SetNoCacheHeaders(w.Header()) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write(data) + }) +} diff --git a/workhorse/internal/staticpages/deploy_page_test.go b/workhorse/internal/staticpages/deploy_page_test.go new file mode 100644 index 00000000000..4b081e73a97 --- /dev/null +++ b/workhorse/internal/staticpages/deploy_page_test.go @@ -0,0 +1,59 @@ +package staticpages + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" + + "github.com/stretchr/testify/require" +) + +func TestIfNoDeployPageExist(t *testing.T) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + w := httptest.NewRecorder() + + executed := false + st := &Static{dir} + st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + executed = true + })).ServeHTTP(w, nil) + if !executed { + t.Error("The handler should get executed") + } +} + +func TestIfDeployPageExist(t *testing.T) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + deployPage := "DEPLOY" + ioutil.WriteFile(filepath.Join(dir, "index.html"), []byte(deployPage), 0600) + + w := httptest.NewRecorder() + + executed := false + st := &Static{dir} + st.DeployPage(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + executed = true + })).ServeHTTP(w, nil) + if executed { + t.Error("The handler should not get executed") + } + w.Flush() + + require.Equal(t, 200, w.Code) + testhelper.RequireResponseBody(t, w, deployPage) +} diff --git a/workhorse/internal/staticpages/error_pages.go b/workhorse/internal/staticpages/error_pages.go new file mode 100644 index 00000000000..3cc89d9f811 --- /dev/null +++ b/workhorse/internal/staticpages/error_pages.go @@ -0,0 +1,138 @@ +package staticpages + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "path/filepath" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +var ( + staticErrorResponses = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_static_error_responses", + Help: "How many HTTP responses have been changed to a static error page, by HTTP status code.", + }, + []string{"code"}, + ) +) + +type ErrorFormat int + +const ( + ErrorFormatHTML ErrorFormat = iota + ErrorFormatJSON + ErrorFormatText +) + +type errorPageResponseWriter struct { + rw http.ResponseWriter + status int + hijacked bool + path string + format ErrorFormat +} + +func (s *errorPageResponseWriter) Header() http.Header { + return s.rw.Header() +} + +func (s *errorPageResponseWriter) Write(data []byte) (int, error) { + if s.status == 0 { + s.WriteHeader(http.StatusOK) + } + if s.hijacked { + return len(data), nil + } + return s.rw.Write(data) +} + +func (s *errorPageResponseWriter) WriteHeader(status int) { + if s.status != 0 { + return + } + + s.status = status + + if s.status < 400 || s.status > 599 || s.rw.Header().Get("X-GitLab-Custom-Error") != "" { + s.rw.WriteHeader(status) + return + } + + var contentType string + var data []byte + switch s.format { + case ErrorFormatText: + contentType, data = s.writeText() + case ErrorFormatJSON: + contentType, data = s.writeJSON() + default: + contentType, data = s.writeHTML() + } + + if contentType == "" { + s.rw.WriteHeader(status) + return + } + + s.hijacked = true + staticErrorResponses.WithLabelValues(fmt.Sprintf("%d", s.status)).Inc() + + helper.SetNoCacheHeaders(s.rw.Header()) + s.rw.Header().Set("Content-Type", contentType) + s.rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) + s.rw.Header().Del("Transfer-Encoding") + s.rw.WriteHeader(s.status) + s.rw.Write(data) +} + +func (s *errorPageResponseWriter) writeHTML() (string, []byte) { + if s.rw.Header().Get("Content-Type") != "application/json" { + errorPageFile := filepath.Join(s.path, fmt.Sprintf("%d.html", s.status)) + + // check if custom error page exists, serve this page instead + if data, err := ioutil.ReadFile(errorPageFile); err == nil { + return "text/html; charset=utf-8", data + } + } + + return "", nil +} + +func (s *errorPageResponseWriter) writeJSON() (string, []byte) { + message, err := json.Marshal(map[string]interface{}{"error": http.StatusText(s.status), "status": s.status}) + if err != nil { + return "", nil + } + + return "application/json; charset=utf-8", append(message, "\n"...) +} + +func (s *errorPageResponseWriter) writeText() (string, []byte) { + return "text/plain; charset=utf-8", []byte(http.StatusText(s.status) + "\n") +} + +func (s *errorPageResponseWriter) flush() { + s.WriteHeader(http.StatusOK) +} + +func (st *Static) ErrorPagesUnless(disabled bool, format ErrorFormat, handler http.Handler) http.Handler { + if disabled { + return handler + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rw := errorPageResponseWriter{ + rw: w, + path: st.DocumentRoot, + format: format, + } + defer rw.flush() + handler.ServeHTTP(&rw, r) + }) +} diff --git a/workhorse/internal/staticpages/error_pages_test.go b/workhorse/internal/staticpages/error_pages_test.go new file mode 100644 index 00000000000..05ec06cd429 --- /dev/null +++ b/workhorse/internal/staticpages/error_pages_test.go @@ -0,0 +1,191 @@ +package staticpages + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" +) + +func TestIfErrorPageIsPresented(t *testing.T) { + dir, err := ioutil.TempDir("", "error_page") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + errorPage := "ERROR" + ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600) + + w := httptest.NewRecorder() + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(404) + upstreamBody := "Not Found" + n, err := fmt.Fprint(w, upstreamBody) + require.NoError(t, err) + require.Equal(t, len(upstreamBody), n, "bytes written") + }) + st := &Static{dir} + st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil) + w.Flush() + + require.Equal(t, 404, w.Code) + testhelper.RequireResponseBody(t, w, errorPage) + testhelper.RequireResponseHeader(t, w, "Content-Type", "text/html; charset=utf-8") +} + +func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) { + dir, err := ioutil.TempDir("", "error_page") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + w := httptest.NewRecorder() + errorResponse := "ERROR" + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(404) + fmt.Fprint(w, errorResponse) + }) + st := &Static{dir} + st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil) + w.Flush() + + require.Equal(t, 404, w.Code) + testhelper.RequireResponseBody(t, w, errorResponse) +} + +func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) { + dir, err := ioutil.TempDir("", "error_page") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + errorPage := "ERROR" + ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) + + w := httptest.NewRecorder() + serverError := "Interesting Server Error" + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(500) + fmt.Fprint(w, serverError) + }) + st := &Static{dir} + st.ErrorPagesUnless(true, ErrorFormatHTML, h).ServeHTTP(w, nil) + w.Flush() + require.Equal(t, 500, w.Code) + testhelper.RequireResponseBody(t, w, serverError) +} + +func TestIfErrorPageIsIgnoredIfCustomError(t *testing.T) { + dir, err := ioutil.TempDir("", "error_page") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + errorPage := "ERROR" + ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) + + w := httptest.NewRecorder() + serverError := "Interesting Server Error" + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Add("X-GitLab-Custom-Error", "1") + w.WriteHeader(500) + fmt.Fprint(w, serverError) + }) + st := &Static{dir} + st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil) + w.Flush() + require.Equal(t, 500, w.Code) + testhelper.RequireResponseBody(t, w, serverError) +} + +func TestErrorPageInterceptedByContentType(t *testing.T) { + testCases := []struct { + contentType string + intercepted bool + }{ + {contentType: "application/json", intercepted: false}, + {contentType: "text/plain", intercepted: true}, + {contentType: "text/html", intercepted: true}, + {contentType: "", intercepted: true}, + } + + for _, tc := range testCases { + dir, err := ioutil.TempDir("", "error_page") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + errorPage := "ERROR" + ioutil.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) + + w := httptest.NewRecorder() + serverError := "Interesting Server Error" + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Add("Content-Type", tc.contentType) + w.WriteHeader(500) + fmt.Fprint(w, serverError) + }) + st := &Static{dir} + st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil) + w.Flush() + require.Equal(t, 500, w.Code) + + if tc.intercepted { + testhelper.RequireResponseBody(t, w, errorPage) + } else { + testhelper.RequireResponseBody(t, w, serverError) + } + } +} + +func TestIfErrorPageIsPresentedJSON(t *testing.T) { + errorPage := "{\"error\":\"Not Found\",\"status\":404}\n" + + w := httptest.NewRecorder() + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(404) + upstreamBody := "This string is ignored" + n, err := fmt.Fprint(w, upstreamBody) + require.NoError(t, err) + require.Equal(t, len(upstreamBody), n, "bytes written") + }) + st := &Static{""} + st.ErrorPagesUnless(false, ErrorFormatJSON, h).ServeHTTP(w, nil) + w.Flush() + + require.Equal(t, 404, w.Code) + testhelper.RequireResponseBody(t, w, errorPage) + testhelper.RequireResponseHeader(t, w, "Content-Type", "application/json; charset=utf-8") +} + +func TestIfErrorPageIsPresentedText(t *testing.T) { + errorPage := "Not Found\n" + + w := httptest.NewRecorder() + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(404) + upstreamBody := "This string is ignored" + n, err := fmt.Fprint(w, upstreamBody) + require.NoError(t, err) + require.Equal(t, len(upstreamBody), n, "bytes written") + }) + st := &Static{""} + st.ErrorPagesUnless(false, ErrorFormatText, h).ServeHTTP(w, nil) + w.Flush() + + require.Equal(t, 404, w.Code) + testhelper.RequireResponseBody(t, w, errorPage) + testhelper.RequireResponseHeader(t, w, "Content-Type", "text/plain; charset=utf-8") +} diff --git a/workhorse/internal/staticpages/servefile.go b/workhorse/internal/staticpages/servefile.go new file mode 100644 index 00000000000..c98bc030bc2 --- /dev/null +++ b/workhorse/internal/staticpages/servefile.go @@ -0,0 +1,84 @@ +package staticpages + +import ( + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/mask" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/urlprefix" +) + +type CacheMode int + +const ( + CacheDisabled CacheMode = iota + CacheExpireMax +) + +// BUG/QUIRK: If a client requests 'foo%2Fbar' and 'foo/bar' exists, +// handleServeFile will serve foo/bar instead of passing the request +// upstream. +func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoundHandler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + file := filepath.Join(s.DocumentRoot, prefix.Strip(r.URL.Path)) + + // The filepath.Join does Clean traversing directories up + if !strings.HasPrefix(file, s.DocumentRoot) { + helper.Fail500(w, r, &os.PathError{ + Op: "open", + Path: file, + Err: os.ErrInvalid, + }) + return + } + + var content *os.File + var fi os.FileInfo + var err error + + // Serve pre-gzipped assets + if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") { + content, fi, err = helper.OpenFile(file + ".gz") + if err == nil { + w.Header().Set("Content-Encoding", "gzip") + } + } + + // If not found, open the original file + if content == nil || err != nil { + content, fi, err = helper.OpenFile(file) + } + if err != nil { + if notFoundHandler != nil { + notFoundHandler.ServeHTTP(w, r) + } else { + http.NotFound(w, r) + } + return + } + defer content.Close() + + switch cache { + case CacheExpireMax: + // Cache statically served files for 1 year + cacheUntil := time.Now().AddDate(1, 0, 0).Format(http.TimeFormat) + w.Header().Set("Cache-Control", "public") + w.Header().Set("Expires", cacheUntil) + } + + log.WithContextFields(r.Context(), log.Fields{ + "file": file, + "encoding": w.Header().Get("Content-Encoding"), + "method": r.Method, + "uri": mask.URL(r.RequestURI), + }).Info("Send static file") + + http.ServeContent(w, r, filepath.Base(file), fi.ModTime(), content) + }) +} diff --git a/workhorse/internal/staticpages/servefile_test.go b/workhorse/internal/staticpages/servefile_test.go new file mode 100644 index 00000000000..e136b876298 --- /dev/null +++ b/workhorse/internal/staticpages/servefile_test.go @@ -0,0 +1,134 @@ +package staticpages + +import ( + "bytes" + "compress/gzip" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" + + "github.com/stretchr/testify/require" +) + +func TestServingNonExistingFile(t *testing.T) { + dir := "/path/to/non/existing/directory" + httpRequest, _ := http.NewRequest("GET", "/file", nil) + + w := httptest.NewRecorder() + st := &Static{dir} + st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) + require.Equal(t, 404, w.Code) +} + +func TestServingDirectory(t *testing.T) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + httpRequest, _ := http.NewRequest("GET", "/file", nil) + w := httptest.NewRecorder() + st := &Static{dir} + st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) + require.Equal(t, 404, w.Code) +} + +func TestServingMalformedUri(t *testing.T) { + dir := "/path/to/non/existing/directory" + httpRequest, _ := http.NewRequest("GET", "/../../../static/file", nil) + + w := httptest.NewRecorder() + st := &Static{dir} + st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) + require.Equal(t, 404, w.Code) +} + +func TestExecutingHandlerWhenNoFileFound(t *testing.T) { + dir := "/path/to/non/existing/directory" + httpRequest, _ := http.NewRequest("GET", "/file", nil) + + executed := false + st := &Static{dir} + st.ServeExisting("/", CacheDisabled, http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + executed = (r == httpRequest) + })).ServeHTTP(nil, httpRequest) + if !executed { + t.Error("The handler should get executed") + } +} + +func TestServingTheActualFile(t *testing.T) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + httpRequest, _ := http.NewRequest("GET", "/file", nil) + + fileContent := "STATIC" + ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) + + w := httptest.NewRecorder() + st := &Static{dir} + st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) + require.Equal(t, 200, w.Code) + if w.Body.String() != fileContent { + t.Error("We should serve the file: ", w.Body.String()) + } +} + +func testServingThePregzippedFile(t *testing.T, enableGzip bool) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + httpRequest, _ := http.NewRequest("GET", "/file", nil) + + if enableGzip { + httpRequest.Header.Set("Accept-Encoding", "gzip, deflate") + } + + fileContent := "STATIC" + + var fileGzipContent bytes.Buffer + fileGzip := gzip.NewWriter(&fileGzipContent) + fileGzip.Write([]byte(fileContent)) + fileGzip.Close() + + ioutil.WriteFile(filepath.Join(dir, "file.gz"), fileGzipContent.Bytes(), 0600) + ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) + + w := httptest.NewRecorder() + st := &Static{dir} + st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest) + require.Equal(t, 200, w.Code) + if enableGzip { + testhelper.RequireResponseHeader(t, w, "Content-Encoding", "gzip") + if !bytes.Equal(w.Body.Bytes(), fileGzipContent.Bytes()) { + t.Error("We should serve the pregzipped file") + } + } else { + require.Equal(t, 200, w.Code) + testhelper.RequireResponseHeader(t, w, "Content-Encoding") + if w.Body.String() != fileContent { + t.Error("We should serve the file: ", w.Body.String()) + } + } +} + +func TestServingThePregzippedFile(t *testing.T) { + testServingThePregzippedFile(t, true) +} + +func TestServingThePregzippedFileWithoutEncoding(t *testing.T) { + testServingThePregzippedFile(t, false) +} diff --git a/workhorse/internal/staticpages/static.go b/workhorse/internal/staticpages/static.go new file mode 100644 index 00000000000..b42351f15f5 --- /dev/null +++ b/workhorse/internal/staticpages/static.go @@ -0,0 +1,5 @@ +package staticpages + +type Static struct { + DocumentRoot string +} |