diff options
Diffstat (limited to 'workhorse/internal/upstream/roundtripper/roundtripper_test.go')
-rw-r--r-- | workhorse/internal/upstream/roundtripper/roundtripper_test.go | 56 |
1 files changed, 55 insertions, 1 deletions
diff --git a/workhorse/internal/upstream/roundtripper/roundtripper_test.go b/workhorse/internal/upstream/roundtripper/roundtripper_test.go index 79ffa244918..eed71cc5bae 100644 --- a/workhorse/internal/upstream/roundtripper/roundtripper_test.go +++ b/workhorse/internal/upstream/roundtripper/roundtripper_test.go @@ -1,6 +1,13 @@ package roundtripper import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" "strconv" "testing" @@ -12,6 +19,7 @@ func TestMustParseAddress(t *testing.T) { {"1.2.3.4:56", "http", "1.2.3.4:56"}, {"[::1]:23", "http", "::1:23"}, {"4.5.6.7", "http", "4.5.6.7:http"}, + {"4.5.6.7", "https", "4.5.6.7:https"}, } for i, example := range successExamples { t.Run(strconv.Itoa(i), func(t *testing.T) { @@ -23,7 +31,6 @@ func TestMustParseAddress(t *testing.T) { func TestMustParseAddressPanic(t *testing.T) { panicExamples := []struct{ address, scheme string }{ {"1.2.3.4", ""}, - {"1.2.3.4", "https"}, } for i, panicExample := range panicExamples { @@ -37,3 +44,50 @@ func TestMustParseAddressPanic(t *testing.T) { }) } } + +func TestSupportsHTTPBackend(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + fmt.Fprint(w, "successful response") + })) + defer ts.Close() + + testNewBackendRoundTripper(t, ts, nil, "successful response") +} + +func TestSupportsHTTPSBackend(t *testing.T) { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + fmt.Fprint(w, "successful response") + })) + defer ts.Close() + + certpool := x509.NewCertPool() + certpool.AddCert(ts.Certificate()) + tlsClientConfig := &tls.Config{ + RootCAs: certpool, + } + + testNewBackendRoundTripper(t, ts, tlsClientConfig, "successful response") +} + +func testNewBackendRoundTripper(t *testing.T, ts *httptest.Server, tlsClientConfig *tls.Config, expectedResponseBody string) { + t.Helper() + + backend, err := url.Parse(ts.URL) + require.NoError(t, err, "parse url") + + rt := newBackendRoundTripper(backend, "", 0, true, tlsClientConfig) + + req, err := http.NewRequest("GET", ts.URL+"/", nil) + require.NoError(t, err, "build request") + + response, err := rt.RoundTrip(req) + require.NoError(t, err, "perform roundtrip") + defer response.Body.Close() + + body, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + require.Equal(t, expectedResponseBody, string(body)) +} |