summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorfeistel <6742251-feistel@users.noreply.gitlab.com>2021-08-11 19:02:04 +0000
committerfeistel <6742251-feistel@users.noreply.gitlab.com>2021-08-11 19:02:04 +0000
commit883615685b54c5a15856b2bbb469c9875bdcbd68 (patch)
treee643d41ef3020f829ad788d45c9e1760a4871355
parent7b44ce1d4a0716d27acabb4f826eb5613dade082 (diff)
downloadgitlab-shell-883615685b54c5a15856b2bbb469c9875bdcbd68.tar.gz
fix: validate client cert paths exist on disk before proceeding
-rw-r--r--client/httpclient.go28
-rw-r--r--client/httpsclient_test.go45
2 files changed, 44 insertions, 29 deletions
diff --git a/client/httpclient.go b/client/httpclient.go
index f2e82e5..15bae25 100644
--- a/client/httpclient.go
+++ b/client/httpclient.go
@@ -5,9 +5,11 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
+ "fmt"
"io/ioutil"
"net"
"net/http"
+ "os"
"path/filepath"
"strings"
"time"
@@ -25,6 +27,10 @@ const (
defaultReadTimeoutSeconds = 300
)
+var (
+ ErrCafileNotFound = errors.New("cafile not found")
+)
+
type HttpClient struct {
*http.Client
Host string
@@ -60,15 +66,6 @@ func NewHTTPClient(gitlabURL, gitlabRelativeURLRoot, caFile, caPath string, self
// NewHTTPClientWithOpts builds an HTTP client using the provided options
func NewHTTPClientWithOpts(gitlabURL, gitlabRelativeURLRoot, caFile, caPath string, selfSignedCert bool, readTimeoutSeconds uint64, opts []HTTPClientOpt) (*HttpClient, error) {
- hcc := &httpClientCfg{
- caFile: caFile,
- caPath: caPath,
- }
-
- for _, opt := range opts {
- opt(hcc)
- }
-
var transport *http.Transport
var host string
var err error
@@ -77,6 +74,19 @@ func NewHTTPClientWithOpts(gitlabURL, gitlabRelativeURLRoot, caFile, caPath stri
} else if strings.HasPrefix(gitlabURL, httpProtocol) {
transport, host = buildHttpTransport(gitlabURL)
} else if strings.HasPrefix(gitlabURL, httpsProtocol) {
+ hcc := &httpClientCfg{
+ caFile: caFile,
+ caPath: caPath,
+ }
+
+ for _, opt := range opts {
+ opt(hcc)
+ }
+
+ if _, err := os.Stat(caFile); err != nil {
+ return nil, fmt.Errorf("cannot find cafile '%s': %w", caFile, ErrCafileNotFound)
+ }
+
transport, host, err = buildHttpsTransport(*hcc, selfSignedCert, gitlabURL)
if err != nil {
return nil, err
diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go
index 48efa91..be1d49c 100644
--- a/client/httpsclient_test.go
+++ b/client/httpsclient_test.go
@@ -28,10 +28,7 @@ func TestSuccessfulRequests(t *testing.T) {
{
desc: "Valid CaPath",
caPath: path.Join(testhelper.TestRoot, "certs/valid"),
- },
- {
- desc: "Self signed cert option enabled",
- selfSigned: true,
+ caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"),
},
{
desc: "Invalid cert with self signed cert option enabled",
@@ -51,7 +48,8 @@ func TestSuccessfulRequests(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- client := setupWithRequests(t, tc.caFile, tc.caPath, tc.clientCAPath, tc.clientCertPath, tc.clientKeyPath, tc.selfSigned)
+ client, err := setupWithRequests(t, tc.caFile, tc.caPath, tc.clientCAPath, tc.clientCertPath, tc.clientKeyPath, tc.selfSigned)
+ require.NoError(t, err)
response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
@@ -68,13 +66,15 @@ func TestSuccessfulRequests(t *testing.T) {
func TestFailedRequests(t *testing.T) {
testCases := []struct {
- desc string
- caFile string
- caPath string
+ desc string
+ caFile string
+ caPath string
+ expectedError string
}{
{
- desc: "Invalid CaFile",
- caFile: path.Join(testhelper.TestRoot, "certs/invalid/server.crt"),
+ desc: "Invalid CaFile",
+ caFile: path.Join(testhelper.TestRoot, "certs/invalid/server.crt"),
+ expectedError: "Internal API unreachable",
},
{
desc: "Invalid CaPath",
@@ -87,17 +87,21 @@ func TestFailedRequests(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- client := setupWithRequests(t, tc.caFile, tc.caPath, "", "", "", false)
-
- _, err := client.Get(context.Background(), "/hello")
- require.Error(t, err)
-
- require.Equal(t, err.Error(), "Internal API unreachable")
+ client, err := setupWithRequests(t, tc.caFile, tc.caPath, "", "", "", false)
+ if tc.caFile == "" {
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrCafileNotFound)
+ } else {
+ _, err = client.Get(context.Background(), "/hello")
+ require.Error(t, err)
+
+ require.Equal(t, err.Error(), tc.expectedError)
+ }
})
}
}
-func setupWithRequests(t *testing.T, caFile, caPath, clientCAPath, clientCertPath, clientKeyPath string, selfSigned bool) *GitlabNetClient {
+func setupWithRequests(t *testing.T, caFile, caPath, clientCAPath, clientCertPath, clientKeyPath string, selfSigned bool) (*GitlabNetClient, error) {
testhelper.PrepareTestRootDir(t)
requests := []testserver.TestRequestHandler{
@@ -119,10 +123,11 @@ func setupWithRequests(t *testing.T, caFile, caPath, clientCAPath, clientCertPat
}
httpClient, err := NewHTTPClientWithOpts(url, "", caFile, caPath, selfSigned, 1, opts)
- require.NoError(t, err)
+ if err != nil {
+ return nil, err
+ }
client, err := NewGitlabNetClient("", "", "", httpClient)
- require.NoError(t, err)
- return client
+ return client, err
}