summaryrefslogtreecommitdiff
path: root/client
diff options
context:
space:
mode:
authorJohn Cai <jcai@gitlab.com>2020-05-04 10:59:49 -0700
committerJohn Cai <jcai@gitlab.com>2020-05-04 14:19:47 -0700
commit91f45342c4ff29a24c61812d539ac745dbb1570a (patch)
treeaec9280d8c4e1c0d02515f1ca9d46a65182cfb14 /client
parentf62a4b2fb89754372a346f24659212eb8da13601 (diff)
downloadgitlab-shell-jc-refactor-gitlabnet-client.tar.gz
Move gitlabnet client to client packagejc-refactor-gitlabnet-client
Diffstat (limited to 'client')
-rw-r--r--client/client_test.go240
-rw-r--r--client/gitlabnet.go140
-rw-r--r--client/httpclient.go113
-rw-r--r--client/httpclient_test.go105
-rw-r--r--client/httpsclient_test.go115
-rw-r--r--client/testserver/gitalyserver.go85
-rw-r--r--client/testserver/testserver.go81
7 files changed, 879 insertions, 0 deletions
diff --git a/client/client_test.go b/client/client_test.go
new file mode 100644
index 0000000..dfb1ca3
--- /dev/null
+++ b/client/client_test.go
@@ -0,0 +1,240 @@
+package client
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "path"
+ "strings"
+ "testing"
+
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/client/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+func TestClients(t *testing.T) {
+ testDirCleanup, err := testhelper.PrepareTestRootDir()
+ require.NoError(t, err)
+ defer testDirCleanup()
+
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/hello",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+
+ fmt.Fprint(w, "Hello")
+ },
+ },
+ {
+ Path: "/api/v4/internal/post_endpoint",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodPost, r.Method)
+
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+
+ require.NoError(t, err)
+
+ fmt.Fprint(w, "Echo: "+string(b))
+ },
+ },
+ {
+ Path: "/api/v4/internal/auth",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, r.Header.Get(secretHeaderName))
+ },
+ },
+ {
+ Path: "/api/v4/internal/error",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusBadRequest)
+ body := map[string]string{
+ "message": "Don't do that",
+ }
+ json.NewEncoder(w).Encode(body)
+ },
+ },
+ {
+ Path: "/api/v4/internal/broken",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ panic("Broken")
+ },
+ },
+ }
+
+ testCases := []struct {
+ desc string
+ caFile string
+ server func(*testing.T, []testserver.TestRequestHandler) (string, func())
+ }{
+ {
+ desc: "Socket client",
+ server: testserver.StartSocketHttpServer,
+ },
+ {
+ desc: "Http client",
+ server: testserver.StartHttpServer,
+ },
+ {
+ desc: "Https client",
+ caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"),
+ server: testserver.StartHttpsServer,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ url, cleanup := tc.server(t, requests)
+ defer cleanup()
+
+ secret := "sssh, it's a secret"
+
+ httpClient := NewHTTPClient(url, tc.caFile, "", false, 1)
+
+ client, err := NewGitlabNetClient("", "", secret, httpClient)
+ require.NoError(t, err)
+
+ testBrokenRequest(t, client)
+ testSuccessfulGet(t, client)
+ testSuccessfulPost(t, client)
+ testMissing(t, client)
+ testErrorMessage(t, client)
+ testAuthenticationHeader(t, client)
+ })
+ }
+}
+
+func testSuccessfulGet(t *testing.T, client *GitlabNetClient) {
+ t.Run("Successful get", func(t *testing.T) {
+ hook := testhelper.SetupLogger()
+ response, err := client.Get("/hello")
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ defer response.Body.Close()
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, string(responseBody), "Hello")
+
+ assert.Equal(t, 1, len(hook.Entries))
+ assert.Equal(t, logrus.InfoLevel, hook.LastEntry().Level)
+ assert.True(t, strings.Contains(hook.LastEntry().Message, "method=GET"))
+ assert.True(t, strings.Contains(hook.LastEntry().Message, "Finished HTTP request"))
+ })
+}
+
+func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
+ t.Run("Successful Post", func(t *testing.T) {
+ hook := testhelper.SetupLogger()
+ data := map[string]string{"key": "value"}
+
+ response, err := client.Post("/post_endpoint", data)
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ defer response.Body.Close()
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, "Echo: {\"key\":\"value\"}", string(responseBody))
+
+ assert.Equal(t, 1, len(hook.Entries))
+ assert.Equal(t, logrus.InfoLevel, hook.LastEntry().Level)
+ assert.True(t, strings.Contains(hook.LastEntry().Message, "method=POST"))
+ assert.True(t, strings.Contains(hook.LastEntry().Message, "Finished HTTP request"))
+ })
+}
+
+func testMissing(t *testing.T, client *GitlabNetClient) {
+ t.Run("Missing error for GET", func(t *testing.T) {
+ hook := testhelper.SetupLogger()
+ response, err := client.Get("/missing")
+ assert.EqualError(t, err, "Internal API error (404)")
+ assert.Nil(t, response)
+
+ assert.Equal(t, 1, len(hook.Entries))
+ assert.Equal(t, logrus.InfoLevel, hook.LastEntry().Level)
+ assert.True(t, strings.Contains(hook.LastEntry().Message, "method=GET"))
+ assert.True(t, strings.Contains(hook.LastEntry().Message, "Internal API error"))
+ })
+
+ t.Run("Missing error for POST", func(t *testing.T) {
+ hook := testhelper.SetupLogger()
+ response, err := client.Post("/missing", map[string]string{})
+ assert.EqualError(t, err, "Internal API error (404)")
+ assert.Nil(t, response)
+
+ assert.Equal(t, 1, len(hook.Entries))
+ assert.Equal(t, logrus.InfoLevel, hook.LastEntry().Level)
+ assert.True(t, strings.Contains(hook.LastEntry().Message, "method=POST"))
+ assert.True(t, strings.Contains(hook.LastEntry().Message, "Internal API error"))
+ })
+}
+
+func testErrorMessage(t *testing.T, client *GitlabNetClient) {
+ t.Run("Error with message for GET", func(t *testing.T) {
+ response, err := client.Get("/error")
+ assert.EqualError(t, err, "Don't do that")
+ assert.Nil(t, response)
+ })
+
+ t.Run("Error with message for POST", func(t *testing.T) {
+ response, err := client.Post("/error", map[string]string{})
+ assert.EqualError(t, err, "Don't do that")
+ assert.Nil(t, response)
+ })
+}
+
+func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
+ t.Run("Broken request for GET", func(t *testing.T) {
+ response, err := client.Get("/broken")
+ assert.EqualError(t, err, "Internal API unreachable")
+ assert.Nil(t, response)
+ })
+
+ t.Run("Broken request for POST", func(t *testing.T) {
+ response, err := client.Post("/broken", map[string]string{})
+ assert.EqualError(t, err, "Internal API unreachable")
+ assert.Nil(t, response)
+ })
+}
+
+func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
+ t.Run("Authentication headers for GET", func(t *testing.T) {
+ response, err := client.Get("/auth")
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ defer response.Body.Close()
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ require.NoError(t, err)
+
+ header, err := base64.StdEncoding.DecodeString(string(responseBody))
+ require.NoError(t, err)
+ assert.Equal(t, "sssh, it's a secret", string(header))
+ })
+
+ t.Run("Authentication headers for POST", func(t *testing.T) {
+ response, err := client.Post("/auth", map[string]string{})
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ defer response.Body.Close()
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ require.NoError(t, err)
+
+ header, err := base64.StdEncoding.DecodeString(string(responseBody))
+ require.NoError(t, err)
+ assert.Equal(t, "sssh, it's a secret", string(header))
+ })
+}
diff --git a/client/gitlabnet.go b/client/gitlabnet.go
new file mode 100644
index 0000000..67c48c7
--- /dev/null
+++ b/client/gitlabnet.go
@@ -0,0 +1,140 @@
+package client
+
+import (
+ "bytes"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ internalApiPath = "/api/v4/internal"
+ secretHeaderName = "Gitlab-Shared-Secret"
+)
+
+type ErrorResponse struct {
+ Message string `json:"message"`
+}
+
+type GitlabNetClient struct {
+ httpClient *HttpClient
+ user, password, secret string
+}
+
+func NewGitlabNetClient(
+ user,
+ password,
+ secret string,
+ httpClient *HttpClient,
+) (*GitlabNetClient, error) {
+
+ if httpClient == nil {
+ return nil, fmt.Errorf("Unsupported protocol")
+ }
+
+ return &GitlabNetClient{
+ httpClient: httpClient,
+ user: user,
+ password: password,
+ secret: secret,
+ }, nil
+}
+
+func normalizePath(path string) string {
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+
+ if !strings.HasPrefix(path, internalApiPath) {
+ path = internalApiPath + path
+ }
+ return path
+}
+
+func newRequest(method, host, path string, data interface{}) (*http.Request, error) {
+ var jsonReader io.Reader
+ if data != nil {
+ jsonData, err := json.Marshal(data)
+ if err != nil {
+ return nil, err
+ }
+
+ jsonReader = bytes.NewReader(jsonData)
+ }
+
+ request, err := http.NewRequest(method, host+path, jsonReader)
+ if err != nil {
+ return nil, err
+ }
+
+ return request, nil
+}
+
+func parseError(resp *http.Response) error {
+ if resp.StatusCode >= 200 && resp.StatusCode <= 399 {
+ return nil
+ }
+ defer resp.Body.Close()
+ parsedResponse := &ErrorResponse{}
+
+ if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
+ return fmt.Errorf("Internal API error (%v)", resp.StatusCode)
+ } else {
+ return fmt.Errorf(parsedResponse.Message)
+ }
+
+}
+
+func (c *GitlabNetClient) Get(path string) (*http.Response, error) {
+ return c.DoRequest(http.MethodGet, normalizePath(path), nil)
+}
+
+func (c *GitlabNetClient) Post(path string, data interface{}) (*http.Response, error) {
+ return c.DoRequest(http.MethodPost, normalizePath(path), data)
+}
+
+func (c *GitlabNetClient) DoRequest(method, path string, data interface{}) (*http.Response, error) {
+ request, err := newRequest(method, c.httpClient.Host, path, data)
+ if err != nil {
+ return nil, err
+ }
+
+ user, password := c.user, c.password
+ if user != "" && password != "" {
+ request.SetBasicAuth(user, password)
+ }
+
+ encodedSecret := base64.StdEncoding.EncodeToString([]byte(c.secret))
+ request.Header.Set(secretHeaderName, encodedSecret)
+
+ request.Header.Add("Content-Type", "application/json")
+ request.Close = true
+
+ start := time.Now()
+ response, err := c.httpClient.Do(request)
+ fields := log.Fields{
+ "method": method,
+ "url": request.URL.String(),
+ "duration_ms": time.Since(start) / time.Millisecond,
+ }
+
+ if err != nil {
+ log.WithError(err).WithFields(fields).Error("Internal API unreachable")
+ return nil, fmt.Errorf("Internal API unreachable")
+ }
+
+ if err := parseError(response); err != nil {
+ log.WithError(err).WithFields(fields).Error("Internal API error")
+ return nil, err
+ }
+
+ log.WithFields(fields).Info("Finished HTTP request")
+
+ return response, nil
+}
diff --git a/client/httpclient.go b/client/httpclient.go
new file mode 100644
index 0000000..ff0cc25
--- /dev/null
+++ b/client/httpclient.go
@@ -0,0 +1,113 @@
+package client
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "path/filepath"
+ "strings"
+ "time"
+)
+
+const (
+ socketBaseUrl = "http://unix"
+ unixSocketProtocol = "http+unix://"
+ httpProtocol = "http://"
+ httpsProtocol = "https://"
+ defaultReadTimeoutSeconds = 300
+)
+
+type HttpClient struct {
+ *http.Client
+ Host string
+}
+
+func NewHTTPClient(gitlabURL, caFile, caPath string, selfSignedCert bool, readTimeoutSeconds uint64) *HttpClient {
+
+ var transport *http.Transport
+ var host string
+ if strings.HasPrefix(gitlabURL, unixSocketProtocol) {
+ transport, host = buildSocketTransport(gitlabURL)
+ } else if strings.HasPrefix(gitlabURL, httpProtocol) {
+ transport, host = buildHttpTransport(gitlabURL)
+ } else if strings.HasPrefix(gitlabURL, httpsProtocol) {
+ transport, host = buildHttpsTransport(caFile, caPath, selfSignedCert, gitlabURL)
+ } else {
+ return nil
+ }
+
+ c := &http.Client{
+ Transport: transport,
+ Timeout: readTimeout(readTimeoutSeconds),
+ }
+
+ client := &HttpClient{Client: c, Host: host}
+
+ return client
+}
+
+func buildSocketTransport(gitlabURL string) (*http.Transport, string) {
+ socketPath := strings.TrimPrefix(gitlabURL, unixSocketProtocol)
+ transport := &http.Transport{
+ DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
+ dialer := net.Dialer{}
+ return dialer.DialContext(ctx, "unix", socketPath)
+ },
+ }
+
+ return transport, socketBaseUrl
+}
+
+func buildHttpsTransport(caFile, caPath string, selfSignedCert bool, gitlabURL string) (*http.Transport, string) {
+ certPool, err := x509.SystemCertPool()
+
+ if err != nil {
+ certPool = x509.NewCertPool()
+ }
+
+ if caFile != "" {
+ addCertToPool(certPool, caFile)
+ }
+
+ if caPath != "" {
+ fis, _ := ioutil.ReadDir(caPath)
+ for _, fi := range fis {
+ if fi.IsDir() {
+ continue
+ }
+
+ addCertToPool(certPool, filepath.Join(caPath, fi.Name()))
+ }
+ }
+
+ transport := &http.Transport{
+ TLSClientConfig: &tls.Config{
+ RootCAs: certPool,
+ InsecureSkipVerify: selfSignedCert,
+ },
+ }
+
+ return transport, gitlabURL
+}
+
+func addCertToPool(certPool *x509.CertPool, fileName string) {
+ cert, err := ioutil.ReadFile(fileName)
+ if err == nil {
+ certPool.AppendCertsFromPEM(cert)
+ }
+}
+
+func buildHttpTransport(gitlabURL string) (*http.Transport, string) {
+ return &http.Transport{}, gitlabURL
+}
+
+func readTimeout(timeoutSeconds uint64) time.Duration {
+ if timeoutSeconds == 0 {
+ timeoutSeconds = defaultReadTimeoutSeconds
+ }
+
+ return time.Duration(timeoutSeconds) * time.Second
+}
diff --git a/client/httpclient_test.go b/client/httpclient_test.go
new file mode 100644
index 0000000..1f0a4ed
--- /dev/null
+++ b/client/httpclient_test.go
@@ -0,0 +1,105 @@
+package client
+
+import (
+ "encoding/base64"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/client/testserver"
+)
+
+func TestReadTimeout(t *testing.T) {
+ expectedSeconds := uint64(300)
+
+ client := NewHTTPClient("http://localhost:3000", "", "", false, expectedSeconds)
+
+ require.NotNil(t, client)
+ assert.Equal(t, time.Duration(expectedSeconds)*time.Second, client.Client.Timeout)
+}
+
+const (
+ username = "basic_auth_user"
+ password = "basic_auth_password"
+)
+
+func TestBasicAuthSettings(t *testing.T) {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/get_endpoint",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+
+ fmt.Fprint(w, r.Header.Get("Authorization"))
+ },
+ },
+ {
+ Path: "/api/v4/internal/post_endpoint",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodPost, r.Method)
+
+ fmt.Fprint(w, r.Header.Get("Authorization"))
+ },
+ },
+ }
+
+ client, cleanup := setup(t, username, password, requests)
+ defer cleanup()
+
+ response, err := client.Get("/get_endpoint")
+ require.NoError(t, err)
+ testBasicAuthHeaders(t, response)
+
+ response, err = client.Post("/post_endpoint", nil)
+ require.NoError(t, err)
+ testBasicAuthHeaders(t, response)
+}
+
+func testBasicAuthHeaders(t *testing.T, response *http.Response) {
+ defer response.Body.Close()
+
+ require.NotNil(t, response)
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+
+ headerParts := strings.Split(string(responseBody), " ")
+ assert.Equal(t, "Basic", headerParts[0])
+
+ credentials, err := base64.StdEncoding.DecodeString(headerParts[1])
+ require.NoError(t, err)
+
+ assert.Equal(t, username+":"+password, string(credentials))
+}
+
+func TestEmptyBasicAuthSettings(t *testing.T) {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/empty_basic_auth",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, "", r.Header.Get("Authorization"))
+ },
+ },
+ }
+
+ client, cleanup := setup(t, "", "", requests)
+ defer cleanup()
+
+ _, err := client.Get("/empty_basic_auth")
+ require.NoError(t, err)
+}
+
+func setup(t *testing.T, username, password string, requests []testserver.TestRequestHandler) (*GitlabNetClient, func()) {
+ url, cleanup := testserver.StartHttpServer(t, requests)
+
+ httpClient := NewHTTPClient(url, "", "", false, 1)
+
+ client, err := NewGitlabNetClient(username, password, "", httpClient)
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go
new file mode 100644
index 0000000..6c3ae08
--- /dev/null
+++ b/client/httpsclient_test.go
@@ -0,0 +1,115 @@
+package client
+
+import (
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "path"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/client/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+func TestSuccessfulRequests(t *testing.T) {
+ testCases := []struct {
+ desc string
+ caFile, caPath string
+ selfSigned bool
+ }{
+ {
+ desc: "Valid CaFile",
+ caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"),
+ },
+ {
+ desc: "Valid CaPath",
+ caPath: path.Join(testhelper.TestRoot, "certs/valid"),
+ },
+ {
+ desc: "Self signed cert option enabled",
+ selfSigned: true,
+ },
+ {
+ desc: "Invalid cert with self signed cert option enabled",
+ caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"),
+ selfSigned: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, tc.selfSigned)
+ defer cleanup()
+
+ response, err := client.Get("/hello")
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ defer response.Body.Close()
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, string(responseBody), "Hello")
+ })
+ }
+}
+
+func TestFailedRequests(t *testing.T) {
+ testCases := []struct {
+ desc string
+ caFile string
+ caPath string
+ }{
+ {
+ desc: "Invalid CaFile",
+ caFile: path.Join(testhelper.TestRoot, "certs/invalid/server.crt"),
+ },
+ {
+ desc: "Invalid CaPath",
+ caPath: path.Join(testhelper.TestRoot, "certs/invalid"),
+ },
+ {
+ desc: "Empty config",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, false)
+ defer cleanup()
+
+ _, err := client.Get("/hello")
+ require.Error(t, err)
+
+ assert.Equal(t, err.Error(), "Internal API unreachable")
+ })
+ }
+}
+
+func setupWithRequests(t *testing.T, caFile, caPath string, selfSigned bool) (*GitlabNetClient, func()) {
+ testDirCleanup, err := testhelper.PrepareTestRootDir()
+ require.NoError(t, err)
+ defer testDirCleanup()
+
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/hello",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+
+ fmt.Fprint(w, "Hello")
+ },
+ },
+ }
+
+ url, cleanup := testserver.StartHttpsServer(t, requests)
+
+ httpClient := NewHTTPClient(url, caFile, caPath, selfSigned, 1)
+
+ client, err := NewGitlabNetClient("", "", "", httpClient)
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/client/testserver/gitalyserver.go b/client/testserver/gitalyserver.go
new file mode 100644
index 0000000..4bf14f3
--- /dev/null
+++ b/client/testserver/gitalyserver.go
@@ -0,0 +1,85 @@
+package testserver
+
+import (
+ "io/ioutil"
+ "net"
+ "os"
+ "path"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+)
+
+type TestGitalyServer struct{ ReceivedMD metadata.MD }
+
+func (s *TestGitalyServer) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error {
+ req, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context())
+
+ response := []byte("ReceivePack: " + req.GlId + " " + req.Repository.GlRepository)
+ stream.Send(&pb.SSHReceivePackResponse{Stdout: response})
+
+ return nil
+}
+
+func (s *TestGitalyServer) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error {
+ req, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context())
+
+ response := []byte("UploadPack: " + req.Repository.GlRepository)
+ stream.Send(&pb.SSHUploadPackResponse{Stdout: response})
+
+ return nil
+}
+
+func (s *TestGitalyServer) SSHUploadArchive(stream pb.SSHService_SSHUploadArchiveServer) error {
+ req, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context())
+
+ response := []byte("UploadArchive: " + req.Repository.GlRepository)
+ stream.Send(&pb.SSHUploadArchiveResponse{Stdout: response})
+
+ return nil
+}
+
+func StartGitalyServer(t *testing.T) (string, *TestGitalyServer, func()) {
+ tempDir, _ := ioutil.TempDir("", "gitlab-shell-test-api")
+ gitalySocketPath := path.Join(tempDir, "gitaly.sock")
+
+ err := os.MkdirAll(filepath.Dir(gitalySocketPath), 0700)
+ require.NoError(t, err)
+
+ server := grpc.NewServer()
+
+ listener, err := net.Listen("unix", gitalySocketPath)
+ require.NoError(t, err)
+
+ testServer := TestGitalyServer{}
+ pb.RegisterSSHServiceServer(server, &testServer)
+
+ go server.Serve(listener)
+
+ gitalySocketUrl := "unix:" + gitalySocketPath
+ cleanup := func() {
+ server.Stop()
+ os.RemoveAll(tempDir)
+ }
+
+ return gitalySocketUrl, &testServer, cleanup
+}
diff --git a/client/testserver/testserver.go b/client/testserver/testserver.go
new file mode 100644
index 0000000..377e331
--- /dev/null
+++ b/client/testserver/testserver.go
@@ -0,0 +1,81 @@
+package testserver
+
+import (
+ "crypto/tls"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+var (
+ tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api")
+ testSocket = path.Join(tempDir, "internal.sock")
+)
+
+type TestRequestHandler struct {
+ Path string
+ Handler func(w http.ResponseWriter, r *http.Request)
+}
+
+func StartSocketHttpServer(t *testing.T, handlers []TestRequestHandler) (string, func()) {
+ err := os.MkdirAll(filepath.Dir(testSocket), 0700)
+ require.NoError(t, err)
+
+ socketListener, err := net.Listen("unix", testSocket)
+ require.NoError(t, err)
+
+ server := http.Server{
+ Handler: buildHandler(handlers),
+ // We'll put this server through some nasty stuff we don't want
+ // in our test output
+ ErrorLog: log.New(ioutil.Discard, "", 0),
+ }
+ go server.Serve(socketListener)
+
+ url := "http+unix://" + testSocket
+
+ return url, cleanupSocket
+}
+
+func StartHttpServer(t *testing.T, handlers []TestRequestHandler) (string, func()) {
+ server := httptest.NewServer(buildHandler(handlers))
+
+ return server.URL, server.Close
+}
+
+func StartHttpsServer(t *testing.T, handlers []TestRequestHandler) (string, func()) {
+ crt := path.Join(testhelper.TestRoot, "certs/valid/server.crt")
+ key := path.Join(testhelper.TestRoot, "certs/valid/server.key")
+
+ server := httptest.NewUnstartedServer(buildHandler(handlers))
+ cer, err := tls.LoadX509KeyPair(crt, key)
+ require.NoError(t, err)
+
+ server.TLS = &tls.Config{Certificates: []tls.Certificate{cer}}
+ server.StartTLS()
+
+ return server.URL, server.Close
+}
+
+func cleanupSocket() {
+ os.RemoveAll(tempDir)
+}
+
+func buildHandler(handlers []TestRequestHandler) http.Handler {
+ h := http.NewServeMux()
+
+ for _, handler := range handlers {
+ h.HandleFunc(handler.Path, handler.Handler)
+ }
+
+ return h
+}