diff options
-rw-r--r-- | go/internal/config/config.go | 18 | ||||
-rw-r--r-- | go/internal/config/httpclient.go | 74 | ||||
-rw-r--r-- | go/internal/config/httpclient_test.go | 22 | ||||
-rw-r--r-- | go/internal/gitlabnet/client.go | 50 | ||||
-rw-r--r-- | go/internal/gitlabnet/client_test.go | 12 | ||||
-rw-r--r-- | go/internal/gitlabnet/discover/client.go | 2 | ||||
-rw-r--r-- | go/internal/gitlabnet/httpclient.go | 69 | ||||
-rw-r--r-- | go/internal/gitlabnet/httpclient_test.go | 15 | ||||
-rw-r--r-- | go/internal/gitlabnet/twofactorrecover/client.go | 2 |
9 files changed, 143 insertions, 121 deletions
diff --git a/go/internal/config/config.go b/go/internal/config/config.go index 09fb62a..6085493 100644 --- a/go/internal/config/config.go +++ b/go/internal/config/config.go @@ -7,16 +7,14 @@ import ( "path" "path/filepath" "strings" - "time" yaml "gopkg.in/yaml.v2" ) const ( - configFile = "config.yml" - logFile = "gitlab-shell.log" - defaultSecretFileName = ".gitlab_shell_secret" - defaultReadTimeoutSeconds = 300 + configFile = "config.yml" + logFile = "gitlab-shell.log" + defaultSecretFileName = ".gitlab_shell_secret" ) type MigrationConfig struct { @@ -40,6 +38,7 @@ type Config struct { SecretFilePath string `yaml:"secret_file"` Secret string `yaml:"secret"` HttpSettings HttpSettingsConfig `yaml:"http_settings"` + HttpClient *HttpClient } func New() (*Config, error) { @@ -73,15 +72,6 @@ func (c *Config) FeatureEnabled(featureName string) bool { return false } -func (c *HttpSettingsConfig) ReadTimeout() time.Duration { - timeoutSeconds := c.ReadTimeoutSeconds - if c.ReadTimeoutSeconds == 0 { - timeoutSeconds = defaultReadTimeoutSeconds - } - - return time.Duration(timeoutSeconds) * time.Second -} - func newFromFile(filename string) (*Config, error) { cfg := &Config{RootDir: path.Dir(filename)} diff --git a/go/internal/config/httpclient.go b/go/internal/config/httpclient.go new file mode 100644 index 0000000..82807a6 --- /dev/null +++ b/go/internal/config/httpclient.go @@ -0,0 +1,74 @@ +package config + +import ( + "context" + "net" + "net/http" + "strings" + "time" +) + +const ( + socketBaseUrl = "http://unix" + UnixSocketProtocol = "http+unix://" + HttpProtocol = "http://" + defaultReadTimeoutSeconds = 300 +) + +type HttpClient struct { + HttpClient *http.Client + Host string +} + +func (c *Config) GetHttpClient() *HttpClient { + if c.HttpClient != nil { + return c.HttpClient + } + + var transport *http.Transport + var host string + if strings.HasPrefix(c.GitlabUrl, UnixSocketProtocol) { + transport, host = c.buildSocketTransport() + } else if strings.HasPrefix(c.GitlabUrl, HttpProtocol) { + transport, host = c.buildHttpTransport() + } else { + return nil + } + + httpClient := &http.Client{ + Transport: transport, + Timeout: c.readTimeout(), + } + + client := &HttpClient{HttpClient: httpClient, Host: host} + + c.HttpClient = client + + return client +} + +func (c *Config) buildSocketTransport() (*http.Transport, string) { + socketPath := strings.TrimPrefix(c.GitlabUrl, UnixSocketProtocol) + transport := &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + dialer := net.Dialer{} + return dialer.DialContext(ctx, "unix", socketPath) + }, + } + + return transport, socketBaseUrl +} + +func (c *Config) buildHttpTransport() (*http.Transport, string) { + return &http.Transport{}, c.GitlabUrl +} + +func (c *Config) readTimeout() time.Duration { + timeoutSeconds := c.HttpSettings.ReadTimeoutSeconds + + if timeoutSeconds == 0 { + timeoutSeconds = defaultReadTimeoutSeconds + } + + return time.Duration(timeoutSeconds) * time.Second +} diff --git a/go/internal/config/httpclient_test.go b/go/internal/config/httpclient_test.go new file mode 100644 index 0000000..474deba --- /dev/null +++ b/go/internal/config/httpclient_test.go @@ -0,0 +1,22 @@ +package config + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReadTimeout(t *testing.T) { + expectedSeconds := uint64(300) + + config := &Config{ + GitlabUrl: "http://localhost:3000", + HttpSettings: HttpSettingsConfig{ReadTimeoutSeconds: expectedSeconds}, + } + client := config.GetHttpClient() + + require.NotNil(t, client) + assert.Equal(t, time.Duration(expectedSeconds)*time.Second, client.HttpClient.Timeout) +} diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go index 839f206..c0f7f97 100644 --- a/go/internal/gitlabnet/client.go +++ b/go/internal/gitlabnet/client.go @@ -17,25 +17,24 @@ const ( secretHeaderName = "Gitlab-Shared-Secret" ) -type GitlabClient interface { - Get(path string) (*http.Response, error) - Post(path string, data interface{}) (*http.Response, error) -} - type ErrorResponse struct { Message string `json:"message"` } -func GetClient(config *config.Config) (GitlabClient, error) { - url := config.GitlabUrl - if strings.HasPrefix(url, UnixSocketProtocol) { - return buildSocketClient(config), nil - } - if strings.HasPrefix(url, HttpProtocol) { - return buildHttpClient(config), nil +type GitlabClient struct { + httpClient *http.Client + config *config.Config + host string +} + +func GetClient(config *config.Config) (*GitlabClient, error) { + client := config.GetHttpClient() + + if client == nil { + return nil, fmt.Errorf("Unsupported protocol") } - return nil, fmt.Errorf("Unsupported protocol") + return &GitlabClient{httpClient: client.HttpClient, config: config, host: client.Host}, nil } func normalizePath(path string) string { @@ -85,13 +84,32 @@ func parseError(resp *http.Response) error { } -func doRequest(client *http.Client, config *config.Config, request *http.Request) (*http.Response, error) { - encodedSecret := base64.StdEncoding.EncodeToString([]byte(config.Secret)) +func (c *GitlabClient) Get(path string) (*http.Response, error) { + return c.doRequest("GET", path, nil) +} + +func (c *GitlabClient) Post(path string, data interface{}) (*http.Response, error) { + return c.doRequest("POST", path, data) +} + +func (c *GitlabClient) doRequest(method, path string, data interface{}) (*http.Response, error) { + request, err := newRequest(method, c.host, path, data) + if err != nil { + return nil, err + } + + user, password := c.config.HttpSettings.User, c.config.HttpSettings.Password + if user != "" && password != "" { + request.SetBasicAuth(user, password) + } + + encodedSecret := base64.StdEncoding.EncodeToString([]byte(c.config.Secret)) request.Header.Set(secretHeaderName, encodedSecret) request.Header.Add("Content-Type", "application/json") + request.Close = true - response, err := client.Do(request) + response, err := c.httpClient.Do(request) if err != nil { return nil, fmt.Errorf("Internal API unreachable") } diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go index 5925273..f9aa289 100644 --- a/go/internal/gitlabnet/client_test.go +++ b/go/internal/gitlabnet/client_test.go @@ -98,7 +98,7 @@ func TestClients(t *testing.T) { } } -func testSuccessfulGet(t *testing.T, client GitlabClient) { +func testSuccessfulGet(t *testing.T, client *GitlabClient) { t.Run("Successful get", func(t *testing.T) { response, err := client.Get("/hello") defer response.Body.Close() @@ -112,7 +112,7 @@ func testSuccessfulGet(t *testing.T, client GitlabClient) { }) } -func testSuccessfulPost(t *testing.T, client GitlabClient) { +func testSuccessfulPost(t *testing.T, client *GitlabClient) { t.Run("Successful Post", func(t *testing.T) { data := map[string]string{"key": "value"} @@ -128,7 +128,7 @@ func testSuccessfulPost(t *testing.T, client GitlabClient) { }) } -func testMissing(t *testing.T, client GitlabClient) { +func testMissing(t *testing.T, client *GitlabClient) { t.Run("Missing error for GET", func(t *testing.T) { response, err := client.Get("/missing") assert.EqualError(t, err, "Internal API error (404)") @@ -142,7 +142,7 @@ func testMissing(t *testing.T, client GitlabClient) { }) } -func testErrorMessage(t *testing.T, client GitlabClient) { +func testErrorMessage(t *testing.T, client *GitlabClient) { t.Run("Error with message for GET", func(t *testing.T) { response, err := client.Get("/error") assert.EqualError(t, err, "Don't do that") @@ -156,7 +156,7 @@ func testErrorMessage(t *testing.T, client GitlabClient) { }) } -func testBrokenRequest(t *testing.T, client GitlabClient) { +func testBrokenRequest(t *testing.T, client *GitlabClient) { t.Run("Broken request for GET", func(t *testing.T) { response, err := client.Get("/broken") assert.EqualError(t, err, "Internal API unreachable") @@ -170,7 +170,7 @@ func testBrokenRequest(t *testing.T, client GitlabClient) { }) } -func testAuthenticationHeader(t *testing.T, client GitlabClient) { +func testAuthenticationHeader(t *testing.T, client *GitlabClient) { t.Run("Authentication headers for GET", func(t *testing.T) { response, err := client.Get("/auth") defer response.Body.Close() diff --git a/go/internal/gitlabnet/discover/client.go b/go/internal/gitlabnet/discover/client.go index e84b1b4..1266379 100644 --- a/go/internal/gitlabnet/discover/client.go +++ b/go/internal/gitlabnet/discover/client.go @@ -13,7 +13,7 @@ import ( type Client struct { config *config.Config - client gitlabnet.GitlabClient + client *gitlabnet.GitlabClient } type Response struct { diff --git a/go/internal/gitlabnet/httpclient.go b/go/internal/gitlabnet/httpclient.go deleted file mode 100644 index de40051..0000000 --- a/go/internal/gitlabnet/httpclient.go +++ /dev/null @@ -1,69 +0,0 @@ -package gitlabnet - -import ( - "context" - "net" - "net/http" - "strings" - - "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" -) - -const ( - socketBaseUrl = "http://unix" - UnixSocketProtocol = "http+unix://" - HttpProtocol = "http://" -) - -type GitlabHttpClient struct { - httpClient *http.Client - config *config.Config - host string -} - -func buildSocketClient(config *config.Config) *GitlabHttpClient { - path := strings.TrimPrefix(config.GitlabUrl, UnixSocketProtocol) - transport := &http.Transport{ - DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { - dialer := net.Dialer{} - return dialer.DialContext(ctx, "unix", path) - }, - } - - return buildClient(config, transport, socketBaseUrl) -} - -func buildHttpClient(config *config.Config) *GitlabHttpClient { - return buildClient(config, &http.Transport{}, config.GitlabUrl) -} - -func buildClient(config *config.Config, transport *http.Transport, host string) *GitlabHttpClient { - httpClient := &http.Client{ - Transport: transport, - Timeout: config.HttpSettings.ReadTimeout(), - } - - return &GitlabHttpClient{httpClient: httpClient, config: config, host: host} -} - -func (c *GitlabHttpClient) Get(path string) (*http.Response, error) { - return c.doRequest("GET", path, nil) -} - -func (c *GitlabHttpClient) Post(path string, data interface{}) (*http.Response, error) { - return c.doRequest("POST", path, data) -} - -func (c *GitlabHttpClient) doRequest(method, path string, data interface{}) (*http.Response, error) { - request, err := newRequest(method, c.host, path, data) - if err != nil { - return nil, err - } - - user, password := c.config.HttpSettings.User, c.config.HttpSettings.Password - if user != "" && password != "" { - request.SetBasicAuth(user, password) - } - - return doRequest(c.httpClient, c.config, request) -} diff --git a/go/internal/gitlabnet/httpclient_test.go b/go/internal/gitlabnet/httpclient_test.go index 5b01b24..885a6d1 100644 --- a/go/internal/gitlabnet/httpclient_test.go +++ b/go/internal/gitlabnet/httpclient_test.go @@ -7,7 +7,6 @@ import ( "net/http" "strings" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -86,19 +85,7 @@ func TestEmptyBasicAuthSettings(t *testing.T) { require.NoError(t, err) } -func TestReadTimeoutSetting(t *testing.T) { - expectedTimeout := 500 - - config := &config.Config{HttpSettings: config.HttpSettingsConfig{ReadTimeoutSeconds: uint64(expectedTimeout)}} - - client := buildHttpClient(config) - assert.Equal(t, time.Duration(expectedTimeout)*time.Second, client.httpClient.Timeout) - - socketClient := buildSocketClient(config) - assert.Equal(t, time.Duration(expectedTimeout)*time.Second, socketClient.httpClient.Timeout) -} - -func setup(t *testing.T, config *config.Config, requests []testserver.TestRequestHandler) (GitlabClient, func()) { +func setup(t *testing.T, config *config.Config, requests []testserver.TestRequestHandler) (*GitlabClient, func()) { cleanup, url, err := testserver.StartHttpServer(requests) require.NoError(t, err) diff --git a/go/internal/gitlabnet/twofactorrecover/client.go b/go/internal/gitlabnet/twofactorrecover/client.go index 2e47c64..d26b141 100644 --- a/go/internal/gitlabnet/twofactorrecover/client.go +++ b/go/internal/gitlabnet/twofactorrecover/client.go @@ -15,7 +15,7 @@ import ( type Client struct { config *config.Config - client gitlabnet.GitlabClient + client *gitlabnet.GitlabClient } type Response struct { |