summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Thomas <nick@gitlab.com>2019-03-15 17:16:17 +0000
committerNick Thomas <nick@gitlab.com>2019-03-15 17:16:17 +0000
commitf237aba6df1c1873f1f9d5ba18c3b8924d85cb51 (patch)
tree22d69b9450693bb153e58dbe8b7cd6feb3f8e1e0
parent049beb74303a03d9fa598d23b150e0ccea3cd60d (diff)
parent83c0f18e1de04b3bad9c424084e738e911c47336 (diff)
downloadgitlab-shell-f237aba6df1c1873f1f9d5ba18c3b8924d85cb51.tar.gz
Merge branch 'bvl-discover-command' into 'master'
Call gitlab "/internal/discover" from go Closes #175 See merge request gitlab-org/gitlab-shell!283
-rw-r--r--go/cmd/gitlab-shell/main.go19
-rw-r--r--go/internal/command/command.go3
-rw-r--r--go/internal/command/discover/discover.go34
-rw-r--r--go/internal/command/discover/discover_test.go131
-rw-r--r--go/internal/command/fallback/fallback.go4
-rw-r--r--go/internal/command/reporting/reporter.go8
-rw-r--r--go/internal/gitlabnet/client.go77
-rw-r--r--go/internal/gitlabnet/client_test.go131
-rw-r--r--go/internal/gitlabnet/discover/client.go76
-rw-r--r--go/internal/gitlabnet/discover/client_test.go131
-rw-r--r--go/internal/gitlabnet/socketclient.go46
-rw-r--r--go/internal/gitlabnet/testserver/testserver.go56
-rw-r--r--spec/gitlab_shell_gitlab_shell_spec.rb33
13 files changed, 729 insertions, 20 deletions
diff --git a/go/cmd/gitlab-shell/main.go b/go/cmd/gitlab-shell/main.go
index 07623b4..2ed319d 100644
--- a/go/cmd/gitlab-shell/main.go
+++ b/go/cmd/gitlab-shell/main.go
@@ -7,25 +7,28 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
var (
- binDir string
- rootDir string
+ binDir string
+ rootDir string
+ reporter *reporting.Reporter
)
func init() {
binDir = filepath.Dir(os.Args[0])
rootDir = filepath.Dir(binDir)
+ reporter = &reporting.Reporter{Out: os.Stdout, ErrOut: os.Stderr}
}
// rubyExec will never return. It either replaces the current process with a
// Ruby interpreter, or outputs an error and kills the process.
func execRuby() {
cmd := &fallback.Command{}
- if err := cmd.Execute(); err != nil {
- fmt.Fprintf(os.Stderr, "Failed to exec: %v\n", err)
+ if err := cmd.Execute(reporter); err != nil {
+ fmt.Fprintf(reporter.ErrOut, "Failed to exec: %v\n", err)
os.Exit(1)
}
}
@@ -35,7 +38,7 @@ func main() {
// warning as this isn't something we can sustain indefinitely
config, err := config.NewFromDir(rootDir)
if err != nil {
- fmt.Fprintln(os.Stderr, "Failed to read config, falling back to gitlab-shell-ruby")
+ fmt.Fprintln(reporter.ErrOut, "Failed to read config, falling back to gitlab-shell-ruby")
execRuby()
}
@@ -43,14 +46,14 @@ func main() {
if err != nil {
// For now this could happen if `SSH_CONNECTION` is not set on
// the environment
- fmt.Fprintf(os.Stderr, "%v\n", err)
+ fmt.Fprintf(reporter.ErrOut, "%v\n", err)
os.Exit(1)
}
// The command will write to STDOUT on execution or replace the current
// process in case of the `fallback.Command`
- if err = cmd.Execute(); err != nil {
- fmt.Fprintf(os.Stderr, "%v\n", err)
+ if err = cmd.Execute(reporter); err != nil {
+ fmt.Fprintf(reporter.ErrOut, "%v\n", err)
os.Exit(1)
}
}
diff --git a/go/internal/command/command.go b/go/internal/command/command.go
index cb2acdc..d4649de 100644
--- a/go/internal/command/command.go
+++ b/go/internal/command/command.go
@@ -4,11 +4,12 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
type Command interface {
- Execute() error
+ Execute(*reporting.Reporter) error
}
func New(arguments []string, config *config.Config) (Command, error) {
diff --git a/go/internal/command/discover/discover.go b/go/internal/command/discover/discover.go
index 63a7a32..8ad2868 100644
--- a/go/internal/command/discover/discover.go
+++ b/go/internal/command/discover/discover.go
@@ -4,7 +4,9 @@ import (
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
)
type Command struct {
@@ -12,6 +14,34 @@ type Command struct {
Args *commandargs.CommandArgs
}
-func (c *Command) Execute() error {
- return fmt.Errorf("No feature is implemented yet")
+func (c *Command) Execute(reporter *reporting.Reporter) error {
+ response, err := c.getUserInfo()
+ if err != nil {
+ return fmt.Errorf("Failed to get username: %v", err)
+ }
+
+ if response.IsAnonymous() {
+ fmt.Fprintf(reporter.Out, "Welcome to GitLab, Anonymous!\n")
+ } else {
+ fmt.Fprintf(reporter.Out, "Welcome to GitLab, @%s!\n", response.Username)
+ }
+
+ return nil
+}
+
+func (c *Command) getUserInfo() (*discover.Response, error) {
+ client, err := discover.NewClient(c.Config)
+ if err != nil {
+ return nil, err
+ }
+
+ if c.Args.GitlabKeyId != "" {
+ return client.GetByKeyId(c.Args.GitlabKeyId)
+ } else if c.Args.GitlabUsername != "" {
+ return client.GetByUsername(c.Args.GitlabUsername)
+ } else {
+ // There was no 'who' information, this matches the ruby error
+ // message.
+ return nil, fmt.Errorf("who='' is invalid")
+ }
}
diff --git a/go/internal/command/discover/discover_test.go b/go/internal/command/discover/discover_test.go
new file mode 100644
index 0000000..ec6f931
--- /dev/null
+++ b/go/internal/command/discover/discover_test.go
@@ -0,0 +1,131 @@
+package discover
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+)
+
+var (
+ testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("key_id") == "1" || r.URL.Query().Get("username") == "alex-doe" {
+ body := map[string]interface{}{
+ "id": 2,
+ "username": "alex-doe",
+ "name": "Alex Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "broken_message" {
+ body := map[string]string{
+ "message": "Forbidden!",
+ }
+ w.WriteHeader(http.StatusForbidden)
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "broken" {
+ w.WriteHeader(http.StatusInternalServerError)
+ } else {
+ fmt.Fprint(w, "null")
+ }
+ },
+ },
+ }
+)
+
+func TestExecute(t *testing.T) {
+ cleanup, err := testserver.StartSocketHttpServer(requests)
+ require.NoError(t, err)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ arguments *commandargs.CommandArgs
+ expectedOutput string
+ }{
+ {
+ desc: "With a known username",
+ arguments: &commandargs.CommandArgs{GitlabUsername: "alex-doe"},
+ expectedOutput: "Welcome to GitLab, @alex-doe!\n",
+ },
+ {
+ desc: "With a known key id",
+ arguments: &commandargs.CommandArgs{GitlabKeyId: "1"},
+ expectedOutput: "Welcome to GitLab, @alex-doe!\n",
+ },
+ {
+ desc: "With an unknown key",
+ arguments: &commandargs.CommandArgs{GitlabKeyId: "-1"},
+ expectedOutput: "Welcome to GitLab, Anonymous!\n",
+ },
+ {
+ desc: "With an unknown username",
+ arguments: &commandargs.CommandArgs{GitlabUsername: "unknown"},
+ expectedOutput: "Welcome to GitLab, Anonymous!\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ cmd := &Command{Config: testConfig, Args: tc.arguments}
+ buffer := &bytes.Buffer{}
+
+ err := cmd.Execute(&reporting.Reporter{Out: buffer})
+
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expectedOutput, buffer.String())
+ })
+ }
+}
+
+func TestFailingExecute(t *testing.T) {
+ cleanup, err := testserver.StartSocketHttpServer(requests)
+ require.NoError(t, err)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ arguments *commandargs.CommandArgs
+ expectedError string
+ }{
+ {
+ desc: "With missing arguments",
+ arguments: &commandargs.CommandArgs{},
+ expectedError: "Failed to get username: who='' is invalid",
+ },
+ {
+ desc: "When the API returns an error",
+ arguments: &commandargs.CommandArgs{GitlabUsername: "broken_message"},
+ expectedError: "Failed to get username: Forbidden!",
+ },
+ {
+ desc: "When the API fails",
+ arguments: &commandargs.CommandArgs{GitlabUsername: "broken"},
+ expectedError: "Failed to get username: Internal API error (500)",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ cmd := &Command{Config: testConfig, Args: tc.arguments}
+ buffer := &bytes.Buffer{}
+
+ err := cmd.Execute(&reporting.Reporter{Out: buffer})
+
+ assert.Empty(t, buffer.String())
+ assert.EqualError(t, err, tc.expectedError)
+ })
+ }
+}
diff --git a/go/internal/command/fallback/fallback.go b/go/internal/command/fallback/fallback.go
index a136657..a2c73ed 100644
--- a/go/internal/command/fallback/fallback.go
+++ b/go/internal/command/fallback/fallback.go
@@ -4,6 +4,8 @@ import (
"os"
"path/filepath"
"syscall"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
)
type Command struct{}
@@ -12,7 +14,7 @@ var (
binDir = filepath.Dir(os.Args[0])
)
-func (c *Command) Execute() error {
+func (c *Command) Execute(_ *reporting.Reporter) error {
rubyCmd := filepath.Join(binDir, "gitlab-shell-ruby")
execErr := syscall.Exec(rubyCmd, os.Args, os.Environ())
return execErr
diff --git a/go/internal/command/reporting/reporter.go b/go/internal/command/reporting/reporter.go
new file mode 100644
index 0000000..74bca59
--- /dev/null
+++ b/go/internal/command/reporting/reporter.go
@@ -0,0 +1,8 @@
+package reporting
+
+import "io"
+
+type Reporter struct {
+ Out io.Writer
+ ErrOut io.Writer
+}
diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go
new file mode 100644
index 0000000..abc218f
--- /dev/null
+++ b/go/internal/gitlabnet/client.go
@@ -0,0 +1,77 @@
+package gitlabnet
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+)
+
+const (
+ internalApiPath = "/api/v4/internal"
+ secretHeaderName = "Gitlab-Shared-Secret"
+)
+
+type GitlabClient interface {
+ Get(path string) (*http.Response, error)
+ // TODO: implement posts
+ // Post(path string) (http.Response, error)
+}
+
+type ErrorResponse struct {
+ Message string `json:"message"`
+}
+
+func GetClient(config *config.Config) (GitlabClient, error) {
+ url := config.GitlabUrl
+ if strings.HasPrefix(url, UnixSocketProtocol) {
+ return buildSocketClient(config), nil
+ }
+
+ return nil, fmt.Errorf("Unsupported protocol")
+}
+
+func normalizePath(path string) string {
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+
+ if !strings.HasPrefix(path, internalApiPath) {
+ path = internalApiPath + path
+ }
+ return path
+}
+
+func parseError(resp *http.Response) error {
+ if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
+ return nil
+ }
+ defer resp.Body.Close()
+ parsedResponse := &ErrorResponse{}
+
+ if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
+ return fmt.Errorf("Internal API error (%v)", resp.StatusCode)
+ } else {
+ return fmt.Errorf(parsedResponse.Message)
+ }
+
+}
+
+func doRequest(client *http.Client, config *config.Config, request *http.Request) (*http.Response, error) {
+ encodedSecret := base64.StdEncoding.EncodeToString([]byte(config.Secret))
+ request.Header.Set(secretHeaderName, encodedSecret)
+
+ response, err := client.Do(request)
+ if err != nil {
+ return nil, fmt.Errorf("Internal API unreachable")
+ }
+
+ if err := parseError(response); err != nil {
+ return nil, err
+ }
+
+ return response, nil
+}
diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go
new file mode 100644
index 0000000..f69f284
--- /dev/null
+++ b/go/internal/gitlabnet/client_test.go
@@ -0,0 +1,131 @@
+package gitlabnet
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+)
+
+func TestClients(t *testing.T) {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/hello",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, "Hello")
+ },
+ },
+ {
+ Path: "/api/v4/internal/auth",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, r.Header.Get(secretHeaderName))
+ },
+ },
+ {
+ Path: "/api/v4/internal/error",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusBadRequest)
+ body := map[string]string{
+ "message": "Don't do that",
+ }
+ json.NewEncoder(w).Encode(body)
+ },
+ },
+ {
+ Path: "/api/v4/internal/broken",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ panic("Broken")
+ },
+ },
+ }
+ testConfig := &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket, Secret: "sssh, it's a secret"}
+
+ testCases := []struct {
+ desc string
+ client GitlabClient
+ server func([]testserver.TestRequestHandler) (func(), error)
+ }{
+ {
+ desc: "Socket client",
+ client: buildSocketClient(testConfig),
+ server: testserver.StartSocketHttpServer,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ cleanup, err := tc.server(requests)
+ defer cleanup()
+ require.NoError(t, err)
+
+ testBrokenRequest(t, tc.client)
+ testSuccessfulGet(t, tc.client)
+ testMissing(t, tc.client)
+ testErrorMessage(t, tc.client)
+ testAuthenticationHeader(t, tc.client)
+ })
+ }
+}
+
+func testSuccessfulGet(t *testing.T, client GitlabClient) {
+ t.Run("Successful get", func(t *testing.T) {
+ response, err := client.Get("/hello")
+ defer response.Body.Close()
+
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, string(responseBody), "Hello")
+ })
+}
+
+func testMissing(t *testing.T, client GitlabClient) {
+ t.Run("Missing error", func(t *testing.T) {
+ response, err := client.Get("/missing")
+ assert.EqualError(t, err, "Internal API error (404)")
+ assert.Nil(t, response)
+ })
+}
+
+func testErrorMessage(t *testing.T, client GitlabClient) {
+ t.Run("Error with message", func(t *testing.T) {
+ response, err := client.Get("/error")
+ assert.EqualError(t, err, "Don't do that")
+ assert.Nil(t, response)
+ })
+}
+
+func testBrokenRequest(t *testing.T, client GitlabClient) {
+ t.Run("Broken request", func(t *testing.T) {
+ response, err := client.Get("/broken")
+ assert.EqualError(t, err, "Internal API unreachable")
+ assert.Nil(t, response)
+ })
+}
+
+func testAuthenticationHeader(t *testing.T, client GitlabClient) {
+ t.Run("Authentication headers", func(t *testing.T) {
+ response, err := client.Get("/auth")
+ defer response.Body.Close()
+
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ require.NoError(t, err)
+
+ header, err := base64.StdEncoding.DecodeString(string(responseBody))
+ require.NoError(t, err)
+ assert.Equal(t, "sssh, it's a secret", string(header))
+ })
+}
diff --git a/go/internal/gitlabnet/discover/client.go b/go/internal/gitlabnet/discover/client.go
new file mode 100644
index 0000000..8df78fb
--- /dev/null
+++ b/go/internal/gitlabnet/discover/client.go
@@ -0,0 +1,76 @@
+package discover
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
+)
+
+type Client struct {
+ config *config.Config
+ client gitlabnet.GitlabClient
+}
+
+type Response struct {
+ UserId int64 `json:"id"`
+ Name string `json:"name"`
+ Username string `json:"username"`
+}
+
+func NewClient(config *config.Config) (*Client, error) {
+ client, err := gitlabnet.GetClient(config)
+ if err != nil {
+ return nil, fmt.Errorf("Error creating http client: %v", err)
+ }
+
+ return &Client{config: config, client: client}, nil
+}
+
+func (c *Client) GetByKeyId(keyId string) (*Response, error) {
+ params := url.Values{}
+ params.Add("key_id", keyId)
+
+ return c.getResponse(params)
+}
+
+func (c *Client) GetByUsername(username string) (*Response, error) {
+ params := url.Values{}
+ params.Add("username", username)
+
+ return c.getResponse(params)
+}
+
+func (c *Client) parseResponse(resp *http.Response) (*Response, error) {
+ parsedResponse := &Response{}
+
+ if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
+ return nil, err
+ } else {
+ return parsedResponse, nil
+ }
+}
+
+func (c *Client) getResponse(params url.Values) (*Response, error) {
+ path := "/discover?" + params.Encode()
+ response, err := c.client.Get(path)
+
+ if err != nil {
+ return nil, err
+ }
+
+ defer response.Body.Close()
+ parsedResponse, err := c.parseResponse(response)
+ if err != nil {
+ return nil, fmt.Errorf("Parsing failed")
+ }
+
+ return parsedResponse, nil
+}
+
+func (r *Response) IsAnonymous() bool {
+ return r.UserId < 1
+}
diff --git a/go/internal/gitlabnet/discover/client_test.go b/go/internal/gitlabnet/discover/client_test.go
new file mode 100644
index 0000000..e88cedd
--- /dev/null
+++ b/go/internal/gitlabnet/discover/client_test.go
@@ -0,0 +1,131 @@
+package discover
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ testConfig *config.Config
+ requests []testserver.TestRequestHandler
+)
+
+func init() {
+ testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("key_id") == "1" {
+ body := &Response{
+ UserId: 2,
+ Username: "alex-doe",
+ Name: "Alex Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "jane-doe" {
+ body := &Response{
+ UserId: 1,
+ Username: "jane-doe",
+ Name: "Jane Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "broken_message" {
+ w.WriteHeader(http.StatusForbidden)
+ body := &gitlabnet.ErrorResponse{
+ Message: "Not allowed!",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "broken_json" {
+ w.Write([]byte("{ \"message\": \"broken json!\""))
+ } else if r.URL.Query().Get("username") == "broken_empty" {
+ w.WriteHeader(http.StatusForbidden)
+ } else {
+ fmt.Fprint(w, "null")
+ }
+ },
+ },
+ }
+}
+
+func TestGetByKeyId(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ result, err := client.GetByKeyId("1")
+ assert.NoError(t, err)
+ assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result)
+}
+
+func TestGetByUsername(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ result, err := client.GetByUsername("jane-doe")
+ assert.NoError(t, err)
+ assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result)
+}
+
+func TestMissingUser(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ result, err := client.GetByUsername("missing")
+ assert.NoError(t, err)
+ assert.True(t, result.IsAnonymous())
+}
+
+func TestErrorResponses(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ fakeUsername string
+ expectedError string
+ }{
+ {
+ desc: "A response with an error message",
+ fakeUsername: "broken_message",
+ expectedError: "Not allowed!",
+ },
+ {
+ desc: "A response with bad JSON",
+ fakeUsername: "broken_json",
+ expectedError: "Parsing failed",
+ },
+ {
+ desc: "An error response without message",
+ fakeUsername: "broken_empty",
+ expectedError: "Internal API error (403)",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ resp, err := client.GetByUsername(tc.fakeUsername)
+
+ assert.EqualError(t, err, tc.expectedError)
+ assert.Nil(t, resp)
+ })
+ }
+}
+
+func setup(t *testing.T) (*Client, func()) {
+ cleanup, err := testserver.StartSocketHttpServer(requests)
+ require.NoError(t, err)
+
+ client, err := NewClient(testConfig)
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/go/internal/gitlabnet/socketclient.go b/go/internal/gitlabnet/socketclient.go
new file mode 100644
index 0000000..3bd7c70
--- /dev/null
+++ b/go/internal/gitlabnet/socketclient.go
@@ -0,0 +1,46 @@
+package gitlabnet
+
+import (
+ "context"
+ "net"
+ "net/http"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+)
+
+const (
+ // We need to set the base URL to something starting with HTTP, the host
+ // itself is ignored as we're talking over a socket.
+ socketBaseUrl = "http://unix"
+ UnixSocketProtocol = "http+unix://"
+)
+
+type GitlabSocketClient struct {
+ httpClient *http.Client
+ config *config.Config
+}
+
+func buildSocketClient(config *config.Config) *GitlabSocketClient {
+ path := strings.TrimPrefix(config.GitlabUrl, UnixSocketProtocol)
+ httpClient := &http.Client{
+ Transport: &http.Transport{
+ DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
+ return net.Dial("unix", path)
+ },
+ },
+ }
+
+ return &GitlabSocketClient{httpClient: httpClient, config: config}
+}
+
+func (c *GitlabSocketClient) Get(path string) (*http.Response, error) {
+ path = normalizePath(path)
+
+ request, err := http.NewRequest("GET", socketBaseUrl+path, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ return doRequest(c.httpClient, c.config, request)
+}
diff --git a/go/internal/gitlabnet/testserver/testserver.go b/go/internal/gitlabnet/testserver/testserver.go
new file mode 100644
index 0000000..9640fd7
--- /dev/null
+++ b/go/internal/gitlabnet/testserver/testserver.go
@@ -0,0 +1,56 @@
+package testserver
+
+import (
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "path"
+ "path/filepath"
+)
+
+var (
+ tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api")
+ TestSocket = path.Join(tempDir, "internal.sock")
+)
+
+type TestRequestHandler struct {
+ Path string
+ Handler func(w http.ResponseWriter, r *http.Request)
+}
+
+func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) {
+ if err := os.MkdirAll(filepath.Dir(TestSocket), 0700); err != nil {
+ return nil, err
+ }
+
+ socketListener, err := net.Listen("unix", TestSocket)
+ if err != nil {
+ return nil, err
+ }
+
+ server := http.Server{
+ Handler: buildHandler(handlers),
+ // We'll put this server through some nasty stuff we don't want
+ // in our test output
+ ErrorLog: log.New(ioutil.Discard, "", 0),
+ }
+ go server.Serve(socketListener)
+
+ return cleanupSocket, nil
+}
+
+func cleanupSocket() {
+ os.RemoveAll(tempDir)
+}
+
+func buildHandler(handlers []TestRequestHandler) http.Handler {
+ h := http.NewServeMux()
+
+ for _, handler := range handlers {
+ h.HandleFunc(handler.Path, handler.Handler)
+ }
+
+ return h
+}
diff --git a/spec/gitlab_shell_gitlab_shell_spec.rb b/spec/gitlab_shell_gitlab_shell_spec.rb
index 11692d3..cb3fd9c 100644
--- a/spec/gitlab_shell_gitlab_shell_spec.rb
+++ b/spec/gitlab_shell_gitlab_shell_spec.rb
@@ -30,12 +30,19 @@ describe 'bin/gitlab-shell' do
@server = HTTPUNIXServer.new(BindAddress: tmp_socket_path)
@server.mount_proc('/api/v4/internal/discover') do |req, res|
- if req.query['key_id'] == '100' ||
- req.query['user_id'] == '10' ||
- req.query['username'] == 'someuser'
+ identifier = req.query['key_id'] || req.query['username'] || req.query['user_id']
+ known_identifiers = %w(10 someuser 100)
+ if known_identifiers.include?(identifier)
res.status = 200
res.content_type = 'application/json'
res.body = '{"id":1, "name": "Some User", "username": "someuser"}'
+ elsif identifier == 'broken_message'
+ res.status = 401
+ res.body = '{"message": "Forbidden!"}'
+ elsif identifier && identifier != 'broken'
+ res.status = 200
+ res.content_type = 'application/json'
+ res.body = 'null'
else
res.status = 500
end
@@ -145,11 +152,7 @@ describe 'bin/gitlab-shell' do
)
end
- it_behaves_like 'results with keys' do
- before do
- pending
- end
- end
+ it_behaves_like 'results with keys'
it 'outputs "Only ssh allowed"' do
_, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-someuser"], env: {})
@@ -157,6 +160,20 @@ describe 'bin/gitlab-shell' do
expect(stderr).to eq("Only ssh allowed\n")
expect(status).not_to be_success
end
+
+ it 'returns an error message when the API call fails with a message' do
+ _, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-broken_message"])
+
+ expect(stderr).to match(/Failed to get username: Forbidden!/)
+ expect(status).not_to be_success
+ end
+
+ it 'returns an error message when the API call fails without a message' do
+ _, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-broken"])
+
+ expect(stderr).to match(/Failed to get username: Internal API error \(500\)/)
+ expect(status).not_to be_success
+ end
end
def run!(args, env: {'SSH_CONNECTION' => 'fake'})