summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIgor <idrozdov@gitlab.com>2019-04-24 11:12:28 +0000
committerNick Thomas <nick@gitlab.com>2019-04-24 11:12:28 +0000
commit9d9e1617ab7173ed245860612280284c7c904b58 (patch)
tree0c7eae65818430d632f60433195aa08ce122aa8f
parent6e9b4dec537171c643e67fc6a73b79dd29afd068 (diff)
downloadgitlab-shell-9d9e1617ab7173ed245860612280284c7c904b58.tar.gz
Support calling internal api using HTTP
-rw-r--r--go/internal/command/discover/discover_test.go11
-rw-r--r--go/internal/command/twofactorrecover/twofactorrecover_test.go8
-rw-r--r--go/internal/config/config.go24
-rw-r--r--go/internal/config/config_test.go30
-rw-r--r--go/internal/config/httpclient.go74
-rw-r--r--go/internal/config/httpclient_test.go22
-rw-r--r--go/internal/gitlabnet/client.go72
-rw-r--r--go/internal/gitlabnet/client_test.go41
-rw-r--r--go/internal/gitlabnet/discover/client.go2
-rw-r--r--go/internal/gitlabnet/discover/client_test.go8
-rw-r--r--go/internal/gitlabnet/httpclient_test.go97
-rw-r--r--go/internal/gitlabnet/socketclient.go66
-rw-r--r--go/internal/gitlabnet/testserver/testserver.go23
-rw-r--r--go/internal/gitlabnet/twofactorrecover/client.go2
-rw-r--r--go/internal/gitlabnet/twofactorrecover/client_test.go8
15 files changed, 348 insertions, 140 deletions
diff --git a/go/internal/command/discover/discover_test.go b/go/internal/command/discover/discover_test.go
index a57f07e..f85add8 100644
--- a/go/internal/command/discover/discover_test.go
+++ b/go/internal/command/discover/discover_test.go
@@ -17,8 +17,7 @@ import (
)
var (
- testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
- requests = []testserver.TestRequestHandler{
+ requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/discover",
Handler: func(w http.ResponseWriter, r *http.Request) {
@@ -46,7 +45,7 @@ var (
)
func TestExecute(t *testing.T) {
- cleanup, err := testserver.StartSocketHttpServer(requests)
+ cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
defer cleanup()
@@ -79,7 +78,7 @@ func TestExecute(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- cmd := &Command{Config: testConfig, Args: tc.arguments}
+ cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments}
buffer := &bytes.Buffer{}
err := cmd.Execute(&readwriter.ReadWriter{Out: buffer})
@@ -91,7 +90,7 @@ func TestExecute(t *testing.T) {
}
func TestFailingExecute(t *testing.T) {
- cleanup, err := testserver.StartSocketHttpServer(requests)
+ cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
defer cleanup()
@@ -119,7 +118,7 @@ func TestFailingExecute(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- cmd := &Command{Config: testConfig, Args: tc.arguments}
+ cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments}
buffer := &bytes.Buffer{}
err := cmd.Execute(&readwriter.ReadWriter{Out: buffer})
diff --git a/go/internal/command/twofactorrecover/twofactorrecover_test.go b/go/internal/command/twofactorrecover/twofactorrecover_test.go
index 908ee13..be76520 100644
--- a/go/internal/command/twofactorrecover/twofactorrecover_test.go
+++ b/go/internal/command/twofactorrecover/twofactorrecover_test.go
@@ -18,12 +18,10 @@ import (
)
var (
- testConfig *config.Config
- requests []testserver.TestRequestHandler
+ requests []testserver.TestRequestHandler
)
func setup(t *testing.T) {
- testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/two_factor_recovery_codes",
@@ -66,7 +64,7 @@ const (
func TestExecute(t *testing.T) {
setup(t)
- cleanup, err := testserver.StartSocketHttpServer(requests)
+ cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
defer cleanup()
@@ -124,7 +122,7 @@ func TestExecute(t *testing.T) {
output := &bytes.Buffer{}
input := bytes.NewBufferString(tc.answer)
- cmd := &Command{Config: testConfig, Args: tc.arguments}
+ cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments}
err := cmd.Execute(&readwriter.ReadWriter{Out: output, In: input})
diff --git a/go/internal/config/config.go b/go/internal/config/config.go
index d2afcdc..6085493 100644
--- a/go/internal/config/config.go
+++ b/go/internal/config/config.go
@@ -22,15 +22,23 @@ type MigrationConfig struct {
Features []string `yaml:"features"`
}
+type HttpSettingsConfig struct {
+ User string `yaml:"user"`
+ Password string `yaml:"password"`
+ ReadTimeoutSeconds uint64 `yaml:"read_timeout"`
+}
+
type Config struct {
RootDir string
- LogFile string `yaml:"log_file"`
- LogFormat string `yaml:"log_format"`
- Migration MigrationConfig `yaml:"migration"`
- GitlabUrl string `yaml:"gitlab_url"`
- GitlabTracing string `yaml:"gitlab_tracing"`
- SecretFilePath string `yaml:"secret_file"`
- Secret string `yaml:"secret"`
+ LogFile string `yaml:"log_file"`
+ LogFormat string `yaml:"log_format"`
+ Migration MigrationConfig `yaml:"migration"`
+ GitlabUrl string `yaml:"gitlab_url"`
+ GitlabTracing string `yaml:"gitlab_tracing"`
+ SecretFilePath string `yaml:"secret_file"`
+ Secret string `yaml:"secret"`
+ HttpSettings HttpSettingsConfig `yaml:"http_settings"`
+ HttpClient *HttpClient
}
func New() (*Config, error) {
@@ -51,7 +59,7 @@ func (c *Config) FeatureEnabled(featureName string) bool {
return false
}
- if !strings.HasPrefix(c.GitlabUrl, "http+unix://") {
+ if !strings.HasPrefix(c.GitlabUrl, "http+unix://") && !strings.HasPrefix(c.GitlabUrl, "http://") {
return false
}
diff --git a/go/internal/config/config_test.go b/go/internal/config/config_test.go
index e1e49d7..d48d3db 100644
--- a/go/internal/config/config_test.go
+++ b/go/internal/config/config_test.go
@@ -24,12 +24,13 @@ func TestParseConfig(t *testing.T) {
defer cleanup()
testCases := []struct {
- yaml string
- path string
- format string
- gitlabUrl string
- migration MigrationConfig
- secret string
+ yaml string
+ path string
+ format string
+ gitlabUrl string
+ migration MigrationConfig
+ secret string
+ httpSettings HttpSettingsConfig
}{
{
path: path.Join(testRoot, "gitlab-shell.log"),
@@ -86,6 +87,13 @@ func TestParseConfig(t *testing.T) {
format: "text",
secret: "an inline secret",
},
+ {
+ yaml: "http_settings:\n user: user_basic_auth\n password: password_basic_auth\n read_timeout: 500",
+ path: path.Join(testRoot, "gitlab-shell.log"),
+ format: "text",
+ secret: "default-secret-content",
+ httpSettings: HttpSettingsConfig{User: "user_basic_auth", Password: "password_basic_auth", ReadTimeoutSeconds: 500},
+ },
}
for _, tc := range testCases {
@@ -101,6 +109,7 @@ func TestParseConfig(t *testing.T) {
assert.Equal(t, tc.format, cfg.LogFormat)
assert.Equal(t, tc.gitlabUrl, cfg.GitlabUrl)
assert.Equal(t, tc.secret, cfg.Secret)
+ assert.Equal(t, tc.httpSettings, cfg.HttpSettings)
})
}
}
@@ -140,6 +149,15 @@ func TestFeatureEnabled(t *testing.T) {
expectEnabled: false,
},
{
+ desc: "When the protocol is http and the feature enabled",
+ config: &Config{
+ GitlabUrl: "http://localhost:3000",
+ Migration: MigrationConfig{Enabled: true, Features: []string{"discover"}},
+ },
+ feature: "discover",
+ expectEnabled: true,
+ },
+ {
desc: "When the protocol is not supported",
config: &Config{
GitlabUrl: "https://localhost:3000",
diff --git a/go/internal/config/httpclient.go b/go/internal/config/httpclient.go
new file mode 100644
index 0000000..82807a6
--- /dev/null
+++ b/go/internal/config/httpclient.go
@@ -0,0 +1,74 @@
+package config
+
+import (
+ "context"
+ "net"
+ "net/http"
+ "strings"
+ "time"
+)
+
+const (
+ socketBaseUrl = "http://unix"
+ UnixSocketProtocol = "http+unix://"
+ HttpProtocol = "http://"
+ defaultReadTimeoutSeconds = 300
+)
+
+type HttpClient struct {
+ HttpClient *http.Client
+ Host string
+}
+
+func (c *Config) GetHttpClient() *HttpClient {
+ if c.HttpClient != nil {
+ return c.HttpClient
+ }
+
+ var transport *http.Transport
+ var host string
+ if strings.HasPrefix(c.GitlabUrl, UnixSocketProtocol) {
+ transport, host = c.buildSocketTransport()
+ } else if strings.HasPrefix(c.GitlabUrl, HttpProtocol) {
+ transport, host = c.buildHttpTransport()
+ } else {
+ return nil
+ }
+
+ httpClient := &http.Client{
+ Transport: transport,
+ Timeout: c.readTimeout(),
+ }
+
+ client := &HttpClient{HttpClient: httpClient, Host: host}
+
+ c.HttpClient = client
+
+ return client
+}
+
+func (c *Config) buildSocketTransport() (*http.Transport, string) {
+ socketPath := strings.TrimPrefix(c.GitlabUrl, UnixSocketProtocol)
+ transport := &http.Transport{
+ DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
+ dialer := net.Dialer{}
+ return dialer.DialContext(ctx, "unix", socketPath)
+ },
+ }
+
+ return transport, socketBaseUrl
+}
+
+func (c *Config) buildHttpTransport() (*http.Transport, string) {
+ return &http.Transport{}, c.GitlabUrl
+}
+
+func (c *Config) readTimeout() time.Duration {
+ timeoutSeconds := c.HttpSettings.ReadTimeoutSeconds
+
+ if timeoutSeconds == 0 {
+ timeoutSeconds = defaultReadTimeoutSeconds
+ }
+
+ return time.Duration(timeoutSeconds) * time.Second
+}
diff --git a/go/internal/config/httpclient_test.go b/go/internal/config/httpclient_test.go
new file mode 100644
index 0000000..474deba
--- /dev/null
+++ b/go/internal/config/httpclient_test.go
@@ -0,0 +1,22 @@
+package config
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestReadTimeout(t *testing.T) {
+ expectedSeconds := uint64(300)
+
+ config := &Config{
+ GitlabUrl: "http://localhost:3000",
+ HttpSettings: HttpSettingsConfig{ReadTimeoutSeconds: expectedSeconds},
+ }
+ client := config.GetHttpClient()
+
+ require.NotNil(t, client)
+ assert.Equal(t, time.Duration(expectedSeconds)*time.Second, client.HttpClient.Timeout)
+}
diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go
index c2453e5..c0f7f97 100644
--- a/go/internal/gitlabnet/client.go
+++ b/go/internal/gitlabnet/client.go
@@ -1,9 +1,11 @@
package gitlabnet
import (
+ "bytes"
"encoding/base64"
"encoding/json"
"fmt"
+ "io"
"net/http"
"strings"
@@ -15,22 +17,24 @@ const (
secretHeaderName = "Gitlab-Shared-Secret"
)
-type GitlabClient interface {
- Get(path string) (*http.Response, error)
- Post(path string, data interface{}) (*http.Response, error)
-}
-
type ErrorResponse struct {
Message string `json:"message"`
}
-func GetClient(config *config.Config) (GitlabClient, error) {
- url := config.GitlabUrl
- if strings.HasPrefix(url, UnixSocketProtocol) {
- return buildSocketClient(config), nil
+type GitlabClient struct {
+ httpClient *http.Client
+ config *config.Config
+ host string
+}
+
+func GetClient(config *config.Config) (*GitlabClient, error) {
+ client := config.GetHttpClient()
+
+ if client == nil {
+ return nil, fmt.Errorf("Unsupported protocol")
}
- return nil, fmt.Errorf("Unsupported protocol")
+ return &GitlabClient{httpClient: client.HttpClient, config: config, host: client.Host}, nil
}
func normalizePath(path string) string {
@@ -44,6 +48,27 @@ func normalizePath(path string) string {
return path
}
+func newRequest(method, host, path string, data interface{}) (*http.Request, error) {
+ path = normalizePath(path)
+
+ var jsonReader io.Reader
+ if data != nil {
+ jsonData, err := json.Marshal(data)
+ if err != nil {
+ return nil, err
+ }
+
+ jsonReader = bytes.NewReader(jsonData)
+ }
+
+ request, err := http.NewRequest(method, host+path, jsonReader)
+ if err != nil {
+ return nil, err
+ }
+
+ return request, nil
+}
+
func parseError(resp *http.Response) error {
if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
return nil
@@ -59,11 +84,32 @@ func parseError(resp *http.Response) error {
}
-func doRequest(client *http.Client, config *config.Config, request *http.Request) (*http.Response, error) {
- encodedSecret := base64.StdEncoding.EncodeToString([]byte(config.Secret))
+func (c *GitlabClient) Get(path string) (*http.Response, error) {
+ return c.doRequest("GET", path, nil)
+}
+
+func (c *GitlabClient) Post(path string, data interface{}) (*http.Response, error) {
+ return c.doRequest("POST", path, data)
+}
+
+func (c *GitlabClient) doRequest(method, path string, data interface{}) (*http.Response, error) {
+ request, err := newRequest(method, c.host, path, data)
+ if err != nil {
+ return nil, err
+ }
+
+ user, password := c.config.HttpSettings.User, c.config.HttpSettings.Password
+ if user != "" && password != "" {
+ request.SetBasicAuth(user, password)
+ }
+
+ encodedSecret := base64.StdEncoding.EncodeToString([]byte(c.config.Secret))
request.Header.Set(secretHeaderName, encodedSecret)
- response, err := client.Do(request)
+ request.Header.Add("Content-Type", "application/json")
+ request.Close = true
+
+ response, err := c.httpClient.Do(request)
if err != nil {
return nil, fmt.Errorf("Internal API unreachable")
}
diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go
index c1d08a1..f9aa289 100644
--- a/go/internal/gitlabnet/client_test.go
+++ b/go/internal/gitlabnet/client_test.go
@@ -61,37 +61,44 @@ func TestClients(t *testing.T) {
},
},
}
- testConfig := &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket, Secret: "sssh, it's a secret"}
testCases := []struct {
desc string
- client GitlabClient
- server func([]testserver.TestRequestHandler) (func(), error)
+ secret string
+ server func([]testserver.TestRequestHandler) (func(), string, error)
}{
{
desc: "Socket client",
- client: buildSocketClient(testConfig),
+ secret: "sssh, it's a secret",
server: testserver.StartSocketHttpServer,
},
+ {
+ desc: "Http client",
+ secret: "sssh, it's a secret",
+ server: testserver.StartHttpServer,
+ },
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- cleanup, err := tc.server(requests)
+ cleanup, url, err := tc.server(requests)
defer cleanup()
require.NoError(t, err)
- testBrokenRequest(t, tc.client)
- testSuccessfulGet(t, tc.client)
- testSuccessfulPost(t, tc.client)
- testMissing(t, tc.client)
- testErrorMessage(t, tc.client)
- testAuthenticationHeader(t, tc.client)
+ client, err := GetClient(&config.Config{GitlabUrl: url, Secret: tc.secret})
+ require.NoError(t, err)
+
+ testBrokenRequest(t, client)
+ testSuccessfulGet(t, client)
+ testSuccessfulPost(t, client)
+ testMissing(t, client)
+ testErrorMessage(t, client)
+ testAuthenticationHeader(t, client)
})
}
}
-func testSuccessfulGet(t *testing.T, client GitlabClient) {
+func testSuccessfulGet(t *testing.T, client *GitlabClient) {
t.Run("Successful get", func(t *testing.T) {
response, err := client.Get("/hello")
defer response.Body.Close()
@@ -105,7 +112,7 @@ func testSuccessfulGet(t *testing.T, client GitlabClient) {
})
}
-func testSuccessfulPost(t *testing.T, client GitlabClient) {
+func testSuccessfulPost(t *testing.T, client *GitlabClient) {
t.Run("Successful Post", func(t *testing.T) {
data := map[string]string{"key": "value"}
@@ -121,7 +128,7 @@ func testSuccessfulPost(t *testing.T, client GitlabClient) {
})
}
-func testMissing(t *testing.T, client GitlabClient) {
+func testMissing(t *testing.T, client *GitlabClient) {
t.Run("Missing error for GET", func(t *testing.T) {
response, err := client.Get("/missing")
assert.EqualError(t, err, "Internal API error (404)")
@@ -135,7 +142,7 @@ func testMissing(t *testing.T, client GitlabClient) {
})
}
-func testErrorMessage(t *testing.T, client GitlabClient) {
+func testErrorMessage(t *testing.T, client *GitlabClient) {
t.Run("Error with message for GET", func(t *testing.T) {
response, err := client.Get("/error")
assert.EqualError(t, err, "Don't do that")
@@ -149,7 +156,7 @@ func testErrorMessage(t *testing.T, client GitlabClient) {
})
}
-func testBrokenRequest(t *testing.T, client GitlabClient) {
+func testBrokenRequest(t *testing.T, client *GitlabClient) {
t.Run("Broken request for GET", func(t *testing.T) {
response, err := client.Get("/broken")
assert.EqualError(t, err, "Internal API unreachable")
@@ -163,7 +170,7 @@ func testBrokenRequest(t *testing.T, client GitlabClient) {
})
}
-func testAuthenticationHeader(t *testing.T, client GitlabClient) {
+func testAuthenticationHeader(t *testing.T, client *GitlabClient) {
t.Run("Authentication headers for GET", func(t *testing.T) {
response, err := client.Get("/auth")
defer response.Body.Close()
diff --git a/go/internal/gitlabnet/discover/client.go b/go/internal/gitlabnet/discover/client.go
index e84b1b4..1266379 100644
--- a/go/internal/gitlabnet/discover/client.go
+++ b/go/internal/gitlabnet/discover/client.go
@@ -13,7 +13,7 @@ import (
type Client struct {
config *config.Config
- client gitlabnet.GitlabClient
+ client *gitlabnet.GitlabClient
}
type Response struct {
diff --git a/go/internal/gitlabnet/discover/client_test.go b/go/internal/gitlabnet/discover/client_test.go
index e88cedd..006568a 100644
--- a/go/internal/gitlabnet/discover/client_test.go
+++ b/go/internal/gitlabnet/discover/client_test.go
@@ -15,12 +15,10 @@ import (
)
var (
- testConfig *config.Config
- requests []testserver.TestRequestHandler
+ requests []testserver.TestRequestHandler
)
func init() {
- testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/discover",
@@ -121,10 +119,10 @@ func TestErrorResponses(t *testing.T) {
}
func setup(t *testing.T) (*Client, func()) {
- cleanup, err := testserver.StartSocketHttpServer(requests)
+ cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
- client, err := NewClient(testConfig)
+ client, err := NewClient(&config.Config{GitlabUrl: url})
require.NoError(t, err)
return client, cleanup
diff --git a/go/internal/gitlabnet/httpclient_test.go b/go/internal/gitlabnet/httpclient_test.go
new file mode 100644
index 0000000..885a6d1
--- /dev/null
+++ b/go/internal/gitlabnet/httpclient_test.go
@@ -0,0 +1,97 @@
+package gitlabnet
+
+import (
+ "encoding/base64"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+)
+
+const (
+ username = "basic_auth_user"
+ password = "basic_auth_password"
+)
+
+func TestBasicAuthSettings(t *testing.T) {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/get_endpoint",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+
+ fmt.Fprint(w, r.Header.Get("Authorization"))
+ },
+ },
+ {
+ Path: "/api/v4/internal/post_endpoint",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodPost, r.Method)
+
+ fmt.Fprint(w, r.Header.Get("Authorization"))
+ },
+ },
+ }
+ config := &config.Config{HttpSettings: config.HttpSettingsConfig{User: username, Password: password}}
+
+ client, cleanup := setup(t, config, requests)
+ defer cleanup()
+
+ response, err := client.Get("/get_endpoint")
+ require.NoError(t, err)
+ testBasicAuthHeaders(t, response)
+
+ response, err = client.Post("/post_endpoint", nil)
+ require.NoError(t, err)
+ testBasicAuthHeaders(t, response)
+}
+
+func testBasicAuthHeaders(t *testing.T, response *http.Response) {
+ defer response.Body.Close()
+
+ require.NotNil(t, response)
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+
+ headerParts := strings.Split(string(responseBody), " ")
+ assert.Equal(t, "Basic", headerParts[0])
+
+ credentials, err := base64.StdEncoding.DecodeString(headerParts[1])
+ require.NoError(t, err)
+
+ assert.Equal(t, username+":"+password, string(credentials))
+}
+
+func TestEmptyBasicAuthSettings(t *testing.T) {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/empty_basic_auth",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, "", r.Header.Get("Authorization"))
+ },
+ },
+ }
+
+ client, cleanup := setup(t, &config.Config{}, requests)
+ defer cleanup()
+
+ _, err := client.Get("/empty_basic_auth")
+ require.NoError(t, err)
+}
+
+func setup(t *testing.T, config *config.Config, requests []testserver.TestRequestHandler) (*GitlabClient, func()) {
+ cleanup, url, err := testserver.StartHttpServer(requests)
+ require.NoError(t, err)
+
+ config.GitlabUrl = url
+ client, err := GetClient(config)
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/go/internal/gitlabnet/socketclient.go b/go/internal/gitlabnet/socketclient.go
deleted file mode 100644
index fd97535..0000000
--- a/go/internal/gitlabnet/socketclient.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package gitlabnet
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "net"
- "net/http"
- "strings"
-
- "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
-)
-
-const (
- // We need to set the base URL to something starting with HTTP, the host
- // itself is ignored as we're talking over a socket.
- socketBaseUrl = "http://unix"
- UnixSocketProtocol = "http+unix://"
-)
-
-type GitlabSocketClient struct {
- httpClient *http.Client
- config *config.Config
-}
-
-func buildSocketClient(config *config.Config) *GitlabSocketClient {
- path := strings.TrimPrefix(config.GitlabUrl, UnixSocketProtocol)
- httpClient := &http.Client{
- Transport: &http.Transport{
- DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
- return net.Dial("unix", path)
- },
- },
- }
-
- return &GitlabSocketClient{httpClient: httpClient, config: config}
-}
-
-func (c *GitlabSocketClient) Get(path string) (*http.Response, error) {
- path = normalizePath(path)
-
- request, err := http.NewRequest("GET", socketBaseUrl+path, nil)
- if err != nil {
- return nil, err
- }
-
- return doRequest(c.httpClient, c.config, request)
-}
-
-func (c *GitlabSocketClient) Post(path string, data interface{}) (*http.Response, error) {
- path = normalizePath(path)
-
- jsonData, err := json.Marshal(data)
- if err != nil {
- return nil, err
- }
-
- request, err := http.NewRequest("POST", socketBaseUrl+path, bytes.NewReader(jsonData))
- request.Header.Add("Content-Type", "application/json")
-
- if err != nil {
- return nil, err
- }
-
- return doRequest(c.httpClient, c.config, request)
-}
diff --git a/go/internal/gitlabnet/testserver/testserver.go b/go/internal/gitlabnet/testserver/testserver.go
index 9640fd7..3e6499d 100644
--- a/go/internal/gitlabnet/testserver/testserver.go
+++ b/go/internal/gitlabnet/testserver/testserver.go
@@ -5,6 +5,7 @@ import (
"log"
"net"
"net/http"
+ "net/http/httptest"
"os"
"path"
"path/filepath"
@@ -12,7 +13,7 @@ import (
var (
tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api")
- TestSocket = path.Join(tempDir, "internal.sock")
+ testSocket = path.Join(tempDir, "internal.sock")
)
type TestRequestHandler struct {
@@ -20,14 +21,14 @@ type TestRequestHandler struct {
Handler func(w http.ResponseWriter, r *http.Request)
}
-func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) {
- if err := os.MkdirAll(filepath.Dir(TestSocket), 0700); err != nil {
- return nil, err
+func StartSocketHttpServer(handlers []TestRequestHandler) (func(), string, error) {
+ if err := os.MkdirAll(filepath.Dir(testSocket), 0700); err != nil {
+ return nil, "", err
}
- socketListener, err := net.Listen("unix", TestSocket)
+ socketListener, err := net.Listen("unix", testSocket)
if err != nil {
- return nil, err
+ return nil, "", err
}
server := http.Server{
@@ -38,7 +39,15 @@ func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) {
}
go server.Serve(socketListener)
- return cleanupSocket, nil
+ url := "http+unix://" + testSocket
+
+ return cleanupSocket, url, nil
+}
+
+func StartHttpServer(handlers []TestRequestHandler) (func(), string, error) {
+ server := httptest.NewServer(buildHandler(handlers))
+
+ return server.Close, server.URL, nil
}
func cleanupSocket() {
diff --git a/go/internal/gitlabnet/twofactorrecover/client.go b/go/internal/gitlabnet/twofactorrecover/client.go
index 2e47c64..d26b141 100644
--- a/go/internal/gitlabnet/twofactorrecover/client.go
+++ b/go/internal/gitlabnet/twofactorrecover/client.go
@@ -15,7 +15,7 @@ import (
type Client struct {
config *config.Config
- client gitlabnet.GitlabClient
+ client *gitlabnet.GitlabClient
}
type Response struct {
diff --git a/go/internal/gitlabnet/twofactorrecover/client_test.go b/go/internal/gitlabnet/twofactorrecover/client_test.go
index 5cbc011..56f7958 100644
--- a/go/internal/gitlabnet/twofactorrecover/client_test.go
+++ b/go/internal/gitlabnet/twofactorrecover/client_test.go
@@ -17,12 +17,10 @@ import (
)
var (
- testConfig *config.Config
- requests []testserver.TestRequestHandler
+ requests []testserver.TestRequestHandler
)
func initialize(t *testing.T) {
- testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/two_factor_recovery_codes",
@@ -151,10 +149,10 @@ func TestErrorResponses(t *testing.T) {
func setup(t *testing.T) (*Client, func()) {
initialize(t)
- cleanup, err := testserver.StartSocketHttpServer(requests)
+ cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
- client, err := NewClient(testConfig)
+ client, err := NewClient(&config.Config{GitlabUrl: url})
require.NoError(t, err)
return client, cleanup