summaryrefslogtreecommitdiff
path: root/workhorse/internal/helper/helpers_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/internal/helper/helpers_test.go')
-rw-r--r--workhorse/internal/helper/helpers_test.go258
1 files changed, 258 insertions, 0 deletions
diff --git a/workhorse/internal/helper/helpers_test.go b/workhorse/internal/helper/helpers_test.go
new file mode 100644
index 00000000000..6a895aded03
--- /dev/null
+++ b/workhorse/internal/helper/helpers_test.go
@@ -0,0 +1,258 @@
+package helper
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/require"
+)
+
+func TestFixRemoteAddr(t *testing.T) {
+ testCases := []struct {
+ initial string
+ forwarded string
+ expected string
+ }{
+ {initial: "@", forwarded: "", expected: "127.0.0.1:0"},
+ {initial: "@", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
+ {initial: "@", forwarded: "127.0.0.1", expected: "127.0.0.1:0"},
+ {initial: "@", forwarded: "192.168.0.1", expected: "127.0.0.1:0"},
+ {initial: "192.168.1.1:0", forwarded: "", expected: "192.168.1.1:0"},
+ {initial: "192.168.1.1:0", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
+ }
+
+ for _, tc := range testCases {
+ req, err := http.NewRequest("POST", "unix:///tmp/test.socket/info/refs", nil)
+ require.NoError(t, err)
+
+ req.RemoteAddr = tc.initial
+
+ if tc.forwarded != "" {
+ req.Header.Add("X-Forwarded-For", tc.forwarded)
+ }
+
+ FixRemoteAddr(req)
+
+ require.Equal(t, tc.expected, req.RemoteAddr)
+ }
+}
+
+func TestSetForwardedForGeneratesHeader(t *testing.T) {
+ testCases := []struct {
+ remoteAddr string
+ previousForwardedFor []string
+ expected string
+ }{
+ {
+ "8.8.8.8:3000",
+ nil,
+ "8.8.8.8",
+ },
+ {
+ "8.8.8.8:3000",
+ []string{"138.124.33.63, 151.146.211.237"},
+ "138.124.33.63, 151.146.211.237, 8.8.8.8",
+ },
+ {
+ "8.8.8.8:3000",
+ []string{"8.154.76.107", "115.206.118.179"},
+ "8.154.76.107, 115.206.118.179, 8.8.8.8",
+ },
+ }
+ for _, tc := range testCases {
+ headers := http.Header{}
+ originalRequest := http.Request{
+ RemoteAddr: tc.remoteAddr,
+ }
+
+ if tc.previousForwardedFor != nil {
+ originalRequest.Header = http.Header{
+ "X-Forwarded-For": tc.previousForwardedFor,
+ }
+ }
+
+ SetForwardedFor(&headers, &originalRequest)
+
+ result := headers.Get("X-Forwarded-For")
+ if result != tc.expected {
+ t.Fatalf("Expected %v, got %v", tc.expected, result)
+ }
+ }
+}
+
+func TestReadRequestBody(t *testing.T) {
+ data := []byte("123456")
+ rw := httptest.NewRecorder()
+ req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data))
+
+ result, err := ReadRequestBody(rw, req, 1000)
+ require.NoError(t, err)
+ require.Equal(t, data, result)
+}
+
+func TestReadRequestBodyLimit(t *testing.T) {
+ data := []byte("123456")
+ rw := httptest.NewRecorder()
+ req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data))
+
+ _, err := ReadRequestBody(rw, req, 2)
+ require.Error(t, err)
+}
+
+func TestCloneRequestWithBody(t *testing.T) {
+ input := []byte("test")
+ newInput := []byte("new body")
+ req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(input))
+ newReq := CloneRequestWithNewBody(req, newInput)
+
+ require.NotEqual(t, req, newReq)
+ require.NotEqual(t, req.Body, newReq.Body)
+ require.NotEqual(t, len(newInput), newReq.ContentLength)
+
+ var buffer bytes.Buffer
+ io.Copy(&buffer, newReq.Body)
+ require.Equal(t, newInput, buffer.Bytes())
+}
+
+func TestApplicationJson(t *testing.T) {
+ req, _ := http.NewRequest("POST", "/test", nil)
+ req.Header.Set("Content-Type", "application/json")
+
+ require.True(t, IsApplicationJson(req), "expected to match 'application/json' as 'application/json'")
+
+ req.Header.Set("Content-Type", "application/json; charset=utf-8")
+ require.True(t, IsApplicationJson(req), "expected to match 'application/json; charset=utf-8' as 'application/json'")
+
+ req.Header.Set("Content-Type", "text/plain")
+ require.False(t, IsApplicationJson(req), "expected not to match 'text/plain' as 'application/json'")
+}
+
+func TestFail500WorksWithNils(t *testing.T) {
+ body := bytes.NewBuffer(nil)
+ w := httptest.NewRecorder()
+ w.Body = body
+
+ Fail500(w, nil, nil)
+
+ require.Equal(t, http.StatusInternalServerError, w.Code)
+ require.Equal(t, "Internal server error\n", body.String())
+}
+
+func TestLogError(t *testing.T) {
+ tests := []struct {
+ name string
+ method string
+ uri string
+ err error
+ logMatchers []string
+ }{
+ {
+ name: "nil_request",
+ err: fmt.Errorf("crash"),
+ logMatchers: []string{
+ `level=error msg="unknown error" error=crash`,
+ },
+ },
+ {
+ name: "nil_request_nil_error",
+ err: nil,
+ logMatchers: []string{
+ `level=error msg="unknown error" error="<nil>"`,
+ },
+ },
+ {
+ name: "basic_url",
+ method: "GET",
+ uri: "http://localhost:3000/",
+ err: fmt.Errorf("error"),
+ logMatchers: []string{
+ `level=error msg=error correlation_id= error=error method=GET uri="http://localhost:3000/"`,
+ },
+ },
+ {
+ name: "secret_url",
+ method: "GET",
+ uri: "http://localhost:3000/path?certificate=123&sharedSecret=123&import_url=the_url&my_password_string=password",
+ err: fmt.Errorf("error"),
+ logMatchers: []string{
+ `level=error msg=error correlation_id= error=error method=GET uri="http://localhost:3000/path\?certificate=\[FILTERED\]&sharedSecret=\[FILTERED\]&import_url=\[FILTERED\]&my_password_string=\[FILTERED\]"`,
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := &bytes.Buffer{}
+
+ oldOut := logrus.StandardLogger().Out
+ logrus.StandardLogger().Out = buf
+ defer func() {
+ logrus.StandardLogger().Out = oldOut
+ }()
+
+ var r *http.Request
+ if tt.uri != "" {
+ r = httptest.NewRequest(tt.method, tt.uri, nil)
+ }
+
+ LogError(r, tt.err)
+
+ logString := buf.String()
+ for _, v := range tt.logMatchers {
+ require.Regexp(t, v, logString)
+ }
+ })
+ }
+}
+
+func TestLogErrorWithFields(t *testing.T) {
+ tests := []struct {
+ name string
+ request *http.Request
+ err error
+ fields map[string]interface{}
+ logMatcher string
+ }{
+ {
+ name: "nil_request",
+ err: fmt.Errorf("crash"),
+ fields: map[string]interface{}{"extra_one": 123},
+ logMatcher: `level=error msg="unknown error" error=crash extra_one=123`,
+ },
+ {
+ name: "nil_request_nil_error",
+ err: nil,
+ fields: map[string]interface{}{"extra_one": 123, "extra_two": "test"},
+ logMatcher: `level=error msg="unknown error" error="<nil>" extra_one=123 extra_two=test`,
+ },
+ {
+ name: "basic_url",
+ request: httptest.NewRequest("GET", "http://localhost:3000/", nil),
+ err: fmt.Errorf("error"),
+ fields: map[string]interface{}{"extra_one": 123, "extra_two": "test"},
+ logMatcher: `level=error msg=error correlation_id= error=error extra_one=123 extra_two=test method=GET uri="http://localhost:3000/`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := &bytes.Buffer{}
+
+ oldOut := logrus.StandardLogger().Out
+ logrus.StandardLogger().Out = buf
+ defer func() {
+ logrus.StandardLogger().Out = oldOut
+ }()
+
+ LogErrorWithFields(tt.request, tt.err, tt.fields)
+
+ logString := buf.String()
+ require.Contains(t, logString, tt.logMatcher)
+ })
+ }
+}