summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorIgor <idrozdov@gitlab.com>2019-10-21 16:25:53 +0000
committerIgor <idrozdov@gitlab.com>2019-10-21 16:25:53 +0000
commit629e3bf9c31687f7b824cf29ba07ad2ce402e280 (patch)
tree0f80f7394231d39970f23a08ba9ba2ce7051e22c /internal
parent7d5229db263a62661653431881bef8b46984d0de (diff)
parentede41ee451dd0aa6d0ecd958c7fadbdb3b63f3e4 (diff)
downloadgitlab-shell-629e3bf9c31687f7b824cf29ba07ad2ce402e280.tar.gz
Merge branch '173-move-go-code-up-one-level' into 'master'
Move Go code up one level See merge request gitlab-org/gitlab-shell!350
Diffstat (limited to 'internal')
-rw-r--r--internal/command/authorizedkeys/authorized_keys.go61
-rw-r--r--internal/command/authorizedkeys/authorized_keys_test.go90
-rw-r--r--internal/command/authorizedprincipals/authorized_principals.go47
-rw-r--r--internal/command/authorizedprincipals/authorized_principals_test.go47
-rw-r--r--internal/command/command.go81
-rw-r--r--internal/command/command_test.go146
-rw-r--r--internal/command/commandargs/authorized_keys.go51
-rw-r--r--internal/command/commandargs/authorized_principals.go50
-rw-r--r--internal/command/commandargs/command_args.go31
-rw-r--r--internal/command/commandargs/command_args_test.go231
-rw-r--r--internal/command/commandargs/generic_args.go14
-rw-r--r--internal/command/commandargs/shell.go131
-rw-r--r--internal/command/discover/discover.go40
-rw-r--r--internal/command/discover/discover_test.go135
-rw-r--r--internal/command/healthcheck/healthcheck.go49
-rw-r--r--internal/command/healthcheck/healthcheck_test.go90
-rw-r--r--internal/command/lfsauthenticate/lfsauthenticate.go104
-rw-r--r--internal/command/lfsauthenticate/lfsauthenticate_test.go153
-rw-r--r--internal/command/readwriter/readwriter.go9
-rw-r--r--internal/command/receivepack/customaction.go99
-rw-r--r--internal/command/receivepack/customaction_test.go105
-rw-r--r--internal/command/receivepack/gitalycall.go39
-rw-r--r--internal/command/receivepack/gitalycall_test.go40
-rw-r--r--internal/command/receivepack/receivepack.go40
-rw-r--r--internal/command/receivepack/receivepack_test.go32
-rw-r--r--internal/command/shared/accessverifier/accessverifier.go45
-rw-r--r--internal/command/shared/accessverifier/accessverifier_test.go82
-rw-r--r--internal/command/shared/disallowedcommand/disallowedcommand.go7
-rw-r--r--internal/command/twofactorrecover/twofactorrecover.go65
-rw-r--r--internal/command/twofactorrecover/twofactorrecover_test.go136
-rw-r--r--internal/command/uploadarchive/gitalycall.go32
-rw-r--r--internal/command/uploadarchive/gitalycall_test.go40
-rw-r--r--internal/command/uploadarchive/uploadarchive.go36
-rw-r--r--internal/command/uploadarchive/uploadarchive_test.go31
-rw-r--r--internal/command/uploadpack/gitalycall.go36
-rw-r--r--internal/command/uploadpack/gitalycall_test.go40
-rw-r--r--internal/command/uploadpack/uploadpack.go36
-rw-r--r--internal/command/uploadpack/uploadpack_test.go31
-rw-r--r--internal/config/config.go123
-rw-r--r--internal/config/config_test.go112
-rw-r--r--internal/config/httpclient.go122
-rw-r--r--internal/config/httpclient_test.go22
-rw-r--r--internal/executable/executable.go60
-rw-r--r--internal/executable/executable_test.go104
-rw-r--r--internal/gitlabnet/accessverifier/client.go115
-rw-r--r--internal/gitlabnet/accessverifier/client_test.go209
-rw-r--r--internal/gitlabnet/authorizedkeys/client.go65
-rw-r--r--internal/gitlabnet/authorizedkeys/client_test.go105
-rw-r--r--internal/gitlabnet/client.go132
-rw-r--r--internal/gitlabnet/client_test.go219
-rw-r--r--internal/gitlabnet/discover/client.go71
-rw-r--r--internal/gitlabnet/discover/client_test.go137
-rw-r--r--internal/gitlabnet/healthcheck/client.go54
-rw-r--r--internal/gitlabnet/healthcheck/client_test.go48
-rw-r--r--internal/gitlabnet/httpclient_test.go96
-rw-r--r--internal/gitlabnet/httpsclient_test.go125
-rw-r--r--internal/gitlabnet/lfsauthenticate/client.go66
-rw-r--r--internal/gitlabnet/lfsauthenticate/client_test.go117
-rw-r--r--internal/gitlabnet/testserver/gitalyserver.go78
-rw-r--r--internal/gitlabnet/testserver/testserver.go82
-rw-r--r--internal/gitlabnet/twofactorrecover/client.go89
-rw-r--r--internal/gitlabnet/twofactorrecover/client_test.go158
-rw-r--r--internal/handler/exec.go96
-rw-r--r--internal/handler/exec_test.go42
-rw-r--r--internal/keyline/key_line.go62
-rw-r--r--internal/keyline/key_line_test.go82
-rw-r--r--internal/logger/logger.go82
-rw-r--r--internal/sshenv/sshenv.go15
-rw-r--r--internal/sshenv/sshenv_test.go20
-rw-r--r--internal/testhelper/requesthandlers/requesthandlers.go58
-rw-r--r--internal/testhelper/testdata/testroot/.gitlab_shell_secret1
-rw-r--r--internal/testhelper/testdata/testroot/certs/invalid/server.crt10
-rw-r--r--internal/testhelper/testdata/testroot/certs/valid/dir/.gitkeep0
-rw-r--r--internal/testhelper/testdata/testroot/certs/valid/server.crt22
-rw-r--r--internal/testhelper/testdata/testroot/certs/valid/server.key27
-rw-r--r--internal/testhelper/testdata/testroot/config.yml0
-rw-r--r--internal/testhelper/testdata/testroot/custom/my-contents-is-secret1
-rw-r--r--internal/testhelper/testdata/testroot/gitlab-shell.log0
-rw-r--r--internal/testhelper/testdata/testroot/responses/allowed.json22
-rw-r--r--internal/testhelper/testdata/testroot/responses/allowed_with_payload.json31
-rw-r--r--internal/testhelper/testhelper.go93
81 files changed, 5705 insertions, 0 deletions
diff --git a/internal/command/authorizedkeys/authorized_keys.go b/internal/command/authorizedkeys/authorized_keys.go
new file mode 100644
index 0000000..f1cab45
--- /dev/null
+++ b/internal/command/authorizedkeys/authorized_keys.go
@@ -0,0 +1,61 @@
+package authorizedkeys
+
+import (
+ "fmt"
+ "strconv"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/keyline"
+)
+
+type Command struct {
+ Config *config.Config
+ Args *commandargs.AuthorizedKeys
+ ReadWriter *readwriter.ReadWriter
+}
+
+func (c *Command) Execute() error {
+ // Do and return nothing when the expected and actual user don't match.
+ // This can happen when the user in sshd_config doesn't match the user
+ // trying to login. When nothing is printed, the user will be denied access.
+ if c.Args.ExpectedUser != c.Args.ActualUser {
+ // TODO: Log this event once we have a consistent way to log in Go.
+ // See https://gitlab.com/gitlab-org/gitlab-shell/issues/192 for more info.
+ return nil
+ }
+
+ if err := c.printKeyLine(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (c *Command) printKeyLine() error {
+ response, err := c.getAuthorizedKey()
+ if err != nil {
+ fmt.Fprintln(c.ReadWriter.Out, fmt.Sprintf("# No key was found for %s", c.Args.Key))
+ return nil
+ }
+
+ keyLine, err := keyline.NewPublicKeyLine(strconv.FormatInt(response.Id, 10), response.Key, c.Config.RootDir)
+ if err != nil {
+ return err
+ }
+
+ fmt.Fprintln(c.ReadWriter.Out, keyLine.ToString())
+
+ return nil
+}
+
+func (c *Command) getAuthorizedKey() (*authorizedkeys.Response, error) {
+ client, err := authorizedkeys.NewClient(c.Config)
+ if err != nil {
+ return nil, err
+ }
+
+ return client.GetByKey(c.Args.Key)
+}
diff --git a/internal/command/authorizedkeys/authorized_keys_test.go b/internal/command/authorizedkeys/authorized_keys_test.go
new file mode 100644
index 0000000..3bf4153
--- /dev/null
+++ b/internal/command/authorizedkeys/authorized_keys_test.go
@@ -0,0 +1,90 @@
+package authorizedkeys
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+)
+
+var (
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/authorized_keys",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("key") == "key" {
+ body := map[string]interface{}{
+ "id": 1,
+ "key": "public-key",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("key") == "broken-message" {
+ body := map[string]string{
+ "message": "Forbidden!",
+ }
+ w.WriteHeader(http.StatusForbidden)
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("key") == "broken" {
+ w.WriteHeader(http.StatusInternalServerError)
+ } else {
+ w.WriteHeader(http.StatusNotFound)
+ }
+ },
+ },
+ }
+)
+
+func TestExecute(t *testing.T) {
+ url, cleanup := testserver.StartSocketHttpServer(t, requests)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ arguments *commandargs.AuthorizedKeys
+ expectedOutput string
+ }{
+ {
+ desc: "With matching username and key",
+ arguments: &commandargs.AuthorizedKeys{ExpectedUser: "user", ActualUser: "user", Key: "key"},
+ expectedOutput: "command=\"/tmp/bin/gitlab-shell key-1\",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty public-key\n",
+ },
+ {
+ desc: "When key doesn't match any existing key",
+ arguments: &commandargs.AuthorizedKeys{ExpectedUser: "user", ActualUser: "user", Key: "not-found"},
+ expectedOutput: "# No key was found for not-found\n",
+ },
+ {
+ desc: "When the API returns an error",
+ arguments: &commandargs.AuthorizedKeys{ExpectedUser: "user", ActualUser: "user", Key: "broken-message"},
+ expectedOutput: "# No key was found for broken-message\n",
+ },
+ {
+ desc: "When the API fails",
+ arguments: &commandargs.AuthorizedKeys{ExpectedUser: "user", ActualUser: "user", Key: "broken"},
+ expectedOutput: "# No key was found for broken\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ buffer := &bytes.Buffer{}
+ cmd := &Command{
+ Config: &config.Config{RootDir: "/tmp", GitlabUrl: url},
+ Args: tc.arguments,
+ ReadWriter: &readwriter.ReadWriter{Out: buffer},
+ }
+
+ err := cmd.Execute()
+
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedOutput, buffer.String())
+ })
+ }
+}
diff --git a/internal/command/authorizedprincipals/authorized_principals.go b/internal/command/authorizedprincipals/authorized_principals.go
new file mode 100644
index 0000000..10ae70e
--- /dev/null
+++ b/internal/command/authorizedprincipals/authorized_principals.go
@@ -0,0 +1,47 @@
+package authorizedprincipals
+
+import (
+ "fmt"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/keyline"
+)
+
+type Command struct {
+ Config *config.Config
+ Args *commandargs.AuthorizedPrincipals
+ ReadWriter *readwriter.ReadWriter
+}
+
+func (c *Command) Execute() error {
+ if err := c.printPrincipalLines(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (c *Command) printPrincipalLines() error {
+ principals := c.Args.Principals
+
+ for _, principal := range principals {
+ if err := c.printPrincipalLine(principal); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (c *Command) printPrincipalLine(principal string) error {
+ principalKeyLine, err := keyline.NewPrincipalKeyLine(c.Args.KeyId, principal, c.Config.RootDir)
+ if err != nil {
+ return err
+ }
+
+ fmt.Fprintln(c.ReadWriter.Out, principalKeyLine.ToString())
+
+ return nil
+}
diff --git a/internal/command/authorizedprincipals/authorized_principals_test.go b/internal/command/authorizedprincipals/authorized_principals_test.go
new file mode 100644
index 0000000..f0334e5
--- /dev/null
+++ b/internal/command/authorizedprincipals/authorized_principals_test.go
@@ -0,0 +1,47 @@
+package authorizedprincipals
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+func TestExecute(t *testing.T) {
+ testCases := []struct {
+ desc string
+ arguments *commandargs.AuthorizedPrincipals
+ expectedOutput string
+ }{
+ {
+ desc: "With single principal",
+ arguments: &commandargs.AuthorizedPrincipals{KeyId: "key", Principals: []string{"principal"}},
+ expectedOutput: "command=\"/tmp/bin/gitlab-shell username-key\",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty principal\n",
+ },
+ {
+ desc: "With multiple principals",
+ arguments: &commandargs.AuthorizedPrincipals{KeyId: "key", Principals: []string{"principal-1", "principal-2"}},
+ expectedOutput: "command=\"/tmp/bin/gitlab-shell username-key\",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty principal-1\ncommand=\"/tmp/bin/gitlab-shell username-key\",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty principal-2\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ buffer := &bytes.Buffer{}
+ cmd := &Command{
+ Config: &config.Config{RootDir: "/tmp"},
+ Args: tc.arguments,
+ ReadWriter: &readwriter.ReadWriter{Out: buffer},
+ }
+
+ err := cmd.Execute()
+
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedOutput, buffer.String())
+ })
+ }
+}
diff --git a/internal/command/command.go b/internal/command/command.go
new file mode 100644
index 0000000..af63862
--- /dev/null
+++ b/internal/command/command.go
@@ -0,0 +1,81 @@
+package command
+
+import (
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+)
+
+type Command interface {
+ Execute() error
+}
+
+func New(e *executable.Executable, arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) {
+ args, err := commandargs.Parse(e, arguments)
+ if err != nil {
+ return nil, err
+ }
+
+ if cmd := buildCommand(e, args, config, readWriter); cmd != nil {
+ return cmd, nil
+ }
+
+ return nil, disallowedcommand.Error
+}
+
+func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command {
+ switch e.Name {
+ case executable.GitlabShell:
+ return buildShellCommand(args.(*commandargs.Shell), config, readWriter)
+ case executable.AuthorizedKeysCheck:
+ return buildAuthorizedKeysCommand(args.(*commandargs.AuthorizedKeys), config, readWriter)
+ case executable.AuthorizedPrincipalsCheck:
+ return buildAuthorizedPrincipalsCommand(args.(*commandargs.AuthorizedPrincipals), config, readWriter)
+ case executable.Healthcheck:
+ return buildHealthcheckCommand(config, readWriter)
+ }
+
+ return nil
+}
+
+func buildShellCommand(args *commandargs.Shell, config *config.Config, readWriter *readwriter.ReadWriter) Command {
+ switch args.CommandType {
+ case commandargs.Discover:
+ return &discover.Command{Config: config, Args: args, ReadWriter: readWriter}
+ case commandargs.TwoFactorRecover:
+ return &twofactorrecover.Command{Config: config, Args: args, ReadWriter: readWriter}
+ case commandargs.LfsAuthenticate:
+ return &lfsauthenticate.Command{Config: config, Args: args, ReadWriter: readWriter}
+ case commandargs.ReceivePack:
+ return &receivepack.Command{Config: config, Args: args, ReadWriter: readWriter}
+ case commandargs.UploadPack:
+ return &uploadpack.Command{Config: config, Args: args, ReadWriter: readWriter}
+ case commandargs.UploadArchive:
+ return &uploadarchive.Command{Config: config, Args: args, ReadWriter: readWriter}
+ }
+
+ return nil
+}
+
+func buildAuthorizedKeysCommand(args *commandargs.AuthorizedKeys, config *config.Config, readWriter *readwriter.ReadWriter) Command {
+ return &authorizedkeys.Command{Config: config, Args: args, ReadWriter: readWriter}
+}
+
+func buildAuthorizedPrincipalsCommand(args *commandargs.AuthorizedPrincipals, config *config.Config, readWriter *readwriter.ReadWriter) Command {
+ return &authorizedprincipals.Command{Config: config, Args: args, ReadWriter: readWriter}
+}
+
+func buildHealthcheckCommand(config *config.Config, readWriter *readwriter.ReadWriter) Command {
+ return &healthcheck.Command{Config: config, ReadWriter: readWriter}
+}
diff --git a/internal/command/command_test.go b/internal/command/command_test.go
new file mode 100644
index 0000000..2ca319e
--- /dev/null
+++ b/internal/command/command_test.go
@@ -0,0 +1,146 @@
+package command
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+var (
+ authorizedKeysExec = &executable.Executable{Name: executable.AuthorizedKeysCheck}
+ authorizedPrincipalsExec = &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}
+ checkExec = &executable.Executable{Name: executable.Healthcheck}
+ gitlabShellExec = &executable.Executable{Name: executable.GitlabShell}
+
+ basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"}
+)
+
+func buildEnv(command string) map[string]string {
+ return map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": command,
+ }
+}
+
+func TestNew(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ environment map[string]string
+ arguments []string
+ expectedType interface{}
+ }{
+ {
+ desc: "it returns a Discover command",
+ executable: gitlabShellExec,
+ environment: buildEnv(""),
+ expectedType: &discover.Command{},
+ },
+ {
+ desc: "it returns a TwoFactorRecover command",
+ executable: gitlabShellExec,
+ environment: buildEnv("2fa_recovery_codes"),
+ expectedType: &twofactorrecover.Command{},
+ },
+ {
+ desc: "it returns an LfsAuthenticate command",
+ executable: gitlabShellExec,
+ environment: buildEnv("git-lfs-authenticate"),
+ expectedType: &lfsauthenticate.Command{},
+ },
+ {
+ desc: "it returns a ReceivePack command",
+ executable: gitlabShellExec,
+ environment: buildEnv("git-receive-pack"),
+ expectedType: &receivepack.Command{},
+ },
+ {
+ desc: "it returns an UploadPack command",
+ executable: gitlabShellExec,
+ environment: buildEnv("git-upload-pack"),
+ expectedType: &uploadpack.Command{},
+ },
+ {
+ desc: "it returns an UploadArchive command",
+ executable: gitlabShellExec,
+ environment: buildEnv("git-upload-archive"),
+ expectedType: &uploadarchive.Command{},
+ },
+ {
+ desc: "it returns a Healthcheck command",
+ executable: checkExec,
+ expectedType: &healthcheck.Command{},
+ },
+ {
+ desc: "it returns a AuthorizedKeys command",
+ executable: authorizedKeysExec,
+ arguments: []string{"git", "git", "key"},
+ expectedType: &authorizedkeys.Command{},
+ },
+ {
+ desc: "it returns a AuthorizedPrincipals command",
+ executable: authorizedPrincipalsExec,
+ arguments: []string{"key", "principal"},
+ expectedType: &authorizedprincipals.Command{},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ restoreEnv := testhelper.TempEnv(tc.environment)
+ defer restoreEnv()
+
+ command, err := New(tc.executable, tc.arguments, basicConfig, nil)
+
+ require.NoError(t, err)
+ require.IsType(t, tc.expectedType, command)
+ })
+ }
+}
+
+func TestFailingNew(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ environment map[string]string
+ expectedError error
+ }{
+ {
+ desc: "Parsing environment failed",
+ executable: gitlabShellExec,
+ expectedError: errors.New("Only SSH allowed"),
+ },
+ {
+ desc: "Unknown command given",
+ executable: gitlabShellExec,
+ environment: buildEnv("unknown"),
+ expectedError: disallowedcommand.Error,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ restoreEnv := testhelper.TempEnv(tc.environment)
+ defer restoreEnv()
+
+ command, err := New(tc.executable, []string{}, basicConfig, nil)
+ require.Nil(t, command)
+ require.Equal(t, tc.expectedError, err)
+ })
+ }
+}
diff --git a/internal/command/commandargs/authorized_keys.go b/internal/command/commandargs/authorized_keys.go
new file mode 100644
index 0000000..2733954
--- /dev/null
+++ b/internal/command/commandargs/authorized_keys.go
@@ -0,0 +1,51 @@
+package commandargs
+
+import (
+ "errors"
+ "fmt"
+)
+
+type AuthorizedKeys struct {
+ Arguments []string
+ ExpectedUser string
+ ActualUser string
+ Key string
+}
+
+func (ak *AuthorizedKeys) Parse() error {
+ if err := ak.validate(); err != nil {
+ return err
+ }
+
+ ak.ExpectedUser = ak.Arguments[0]
+ ak.ActualUser = ak.Arguments[1]
+ ak.Key = ak.Arguments[2]
+
+ return nil
+}
+
+func (ak *AuthorizedKeys) GetArguments() []string {
+ return ak.Arguments
+}
+
+func (ak *AuthorizedKeys) validate() error {
+ argsSize := len(ak.Arguments)
+
+ if argsSize != 3 {
+ return errors.New(fmt.Sprintf("# Insufficient arguments. %d. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>", argsSize))
+ }
+
+ expectedUsername := ak.Arguments[0]
+ actualUsername := ak.Arguments[1]
+ key := ak.Arguments[2]
+
+ if expectedUsername == "" || actualUsername == "" {
+ return errors.New("# No username provided")
+ }
+
+ if key == "" {
+ return errors.New("# No key provided")
+ }
+
+ return nil
+}
diff --git a/internal/command/commandargs/authorized_principals.go b/internal/command/commandargs/authorized_principals.go
new file mode 100644
index 0000000..746ae3f
--- /dev/null
+++ b/internal/command/commandargs/authorized_principals.go
@@ -0,0 +1,50 @@
+package commandargs
+
+import (
+ "errors"
+ "fmt"
+)
+
+type AuthorizedPrincipals struct {
+ Arguments []string
+ KeyId string
+ Principals []string
+}
+
+func (ap *AuthorizedPrincipals) Parse() error {
+ if err := ap.validate(); err != nil {
+ return err
+ }
+
+ ap.KeyId = ap.Arguments[0]
+ ap.Principals = ap.Arguments[1:]
+
+ return nil
+}
+
+func (ap *AuthorizedPrincipals) GetArguments() []string {
+ return ap.Arguments
+}
+
+func (ap *AuthorizedPrincipals) validate() error {
+ argsSize := len(ap.Arguments)
+
+ if argsSize < 2 {
+ return errors.New(fmt.Sprintf("# Insufficient arguments. %d. Usage\n#\tgitlab-shell-authorized-principals-check <key-id> <principal1> [<principal2>...]", argsSize))
+ }
+
+ keyId := ap.Arguments[0]
+ principals := ap.Arguments[1:]
+
+ if keyId == "" {
+ return errors.New("# No key_id provided")
+ }
+
+ for _, principal := range principals {
+ if principal == "" {
+ return errors.New("# An invalid principal was provided")
+ }
+ }
+
+ return nil
+}
diff --git a/internal/command/commandargs/command_args.go b/internal/command/commandargs/command_args.go
new file mode 100644
index 0000000..b4bf334
--- /dev/null
+++ b/internal/command/commandargs/command_args.go
@@ -0,0 +1,31 @@
+package commandargs
+
+import (
+ "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+)
+
+type CommandType string
+
+type CommandArgs interface {
+ Parse() error
+ GetArguments() []string
+}
+
+func Parse(e *executable.Executable, arguments []string) (CommandArgs, error) {
+ var args CommandArgs = &GenericArgs{Arguments: arguments}
+
+ switch e.Name {
+ case executable.GitlabShell:
+ args = &Shell{Arguments: arguments}
+ case executable.AuthorizedKeysCheck:
+ args = &AuthorizedKeys{Arguments: arguments}
+ case executable.AuthorizedPrincipalsCheck:
+ args = &AuthorizedPrincipals{Arguments: arguments}
+ }
+
+ if err := args.Parse(); err != nil {
+ return nil, err
+ }
+
+ return args, nil
+}
diff --git a/internal/command/commandargs/command_args_test.go b/internal/command/commandargs/command_args_test.go
new file mode 100644
index 0000000..aa74237
--- /dev/null
+++ b/internal/command/commandargs/command_args_test.go
@@ -0,0 +1,231 @@
+package commandargs
+
+import (
+ "testing"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseSuccess(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ environment map[string]string
+ arguments []string
+ expectedArgs CommandArgs
+ }{
+ // Setting the used env variables for every case to ensure we're
+ // not using anything set in the original env.
+ {
+ desc: "It sets discover as the command when the command string was empty",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "",
+ },
+ arguments: []string{},
+ expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{}, CommandType: Discover},
+ },
+ {
+ desc: "It finds the key id in any passed arguments",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "",
+ },
+ arguments: []string{"hello", "key-123"},
+ expectedArgs: &Shell{Arguments: []string{"hello", "key-123"}, SshArgs: []string{}, CommandType: Discover, GitlabKeyId: "123"},
+ }, {
+ desc: "It finds the username in any passed arguments",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "",
+ },
+ arguments: []string{"hello", "username-jane-doe"},
+ expectedArgs: &Shell{Arguments: []string{"hello", "username-jane-doe"}, SshArgs: []string{}, CommandType: Discover, GitlabUsername: "jane-doe"},
+ }, {
+ desc: "It parses 2fa_recovery_codes command",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "2fa_recovery_codes",
+ },
+ arguments: []string{},
+ expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"2fa_recovery_codes"}, CommandType: TwoFactorRecover},
+ }, {
+ desc: "It parses git-receive-pack command",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "git-receive-pack group/repo",
+ },
+ arguments: []string{},
+ expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack},
+ }, {
+ desc: "It parses git-receive-pack command and a project with single quotes",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "git receive-pack 'group/repo'",
+ },
+ arguments: []string{},
+ expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack},
+ }, {
+ desc: `It parses "git receive-pack" command`,
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": `git receive-pack "group/repo"`,
+ },
+ arguments: []string{},
+ expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack},
+ }, {
+ desc: `It parses a command followed by control characters`,
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": `git-receive-pack group/repo; any command`,
+ },
+ arguments: []string{},
+ expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack},
+ }, {
+ desc: "It parses git-upload-pack command",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": `git upload-pack "group/repo"`,
+ },
+ arguments: []string{},
+ expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-upload-pack", "group/repo"}, CommandType: UploadPack},
+ }, {
+ desc: "It parses git-upload-archive command",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "git-upload-archive 'group/repo'",
+ },
+ arguments: []string{},
+ expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-upload-archive", "group/repo"}, CommandType: UploadArchive},
+ }, {
+ desc: "It parses git-lfs-authenticate command",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "git-lfs-authenticate 'group/repo' download",
+ },
+ arguments: []string{},
+ expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-lfs-authenticate", "group/repo", "download"}, CommandType: LfsAuthenticate},
+ }, {
+ desc: "It parses authorized-keys command",
+ executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
+ arguments: []string{"git", "git", "key"},
+ expectedArgs: &AuthorizedKeys{Arguments: []string{"git", "git", "key"}, ExpectedUser: "git", ActualUser: "git", Key: "key"},
+ }, {
+ desc: "It parses authorized-principals command",
+ executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
+ arguments: []string{"key", "principal-1", "principal-2"},
+ expectedArgs: &AuthorizedPrincipals{Arguments: []string{"key", "principal-1", "principal-2"}, KeyId: "key", Principals: []string{"principal-1", "principal-2"}},
+ }, {
+ desc: "Unknown executable",
+ executable: &executable.Executable{Name: "unknown"},
+ arguments: []string{},
+ expectedArgs: &GenericArgs{Arguments: []string{}},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ restoreEnv := testhelper.TempEnv(tc.environment)
+ defer restoreEnv()
+
+ result, err := Parse(tc.executable, tc.arguments)
+
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedArgs, result)
+ })
+ }
+}
+
+func TestParseFailure(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ environment map[string]string
+ arguments []string
+ expectedError string
+ }{
+ {
+ desc: "It fails if SSH connection is not set",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ arguments: []string{},
+ expectedError: "Only SSH allowed",
+ },
+ {
+ desc: "It fails if SSH command is invalid",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": `git receive-pack "`,
+ },
+ arguments: []string{},
+ expectedError: "Invalid SSH command",
+ },
+ {
+ desc: "With not enough arguments for the AuthorizedKeysCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
+ arguments: []string{"user"},
+ expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>",
+ },
+ {
+ desc: "With too many arguments for the AuthorizedKeysCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
+ arguments: []string{"user", "user", "key", "something-else"},
+ expectedError: "# Insufficient arguments. 4. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>",
+ },
+ {
+ desc: "With missing username for the AuthorizedKeysCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
+ arguments: []string{"user", "", "key"},
+ expectedError: "# No username provided",
+ },
+ {
+ desc: "With missing key for the AuthorizedKeysCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
+ arguments: []string{"user", "user", ""},
+ expectedError: "# No key provided",
+ },
+ {
+ desc: "With not enough arguments for the AuthorizedPrincipalsCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
+ arguments: []string{"key"},
+ expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-principals-check <key-id> <principal1> [<principal2>...]",
+ },
+ {
+ desc: "With missing key_id for the AuthorizedPrincipalsCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
+ arguments: []string{"", "principal"},
+ expectedError: "# No key_id provided",
+ },
+ {
+ desc: "With blank principal for the AuthorizedPrincipalsCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
+ arguments: []string{"key", "principal", ""},
+ expectedError: "# An invalid principal was provided",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ restoreEnv := testhelper.TempEnv(tc.environment)
+ defer restoreEnv()
+
+ _, err := Parse(tc.executable, tc.arguments)
+
+ require.EqualError(t, err, tc.expectedError)
+ })
+ }
+}
diff --git a/internal/command/commandargs/generic_args.go b/internal/command/commandargs/generic_args.go
new file mode 100644
index 0000000..96bed99
--- /dev/null
+++ b/internal/command/commandargs/generic_args.go
@@ -0,0 +1,14 @@
+package commandargs
+
+type GenericArgs struct {
+ Arguments []string
+}
+
+func (b *GenericArgs) Parse() error {
+ // Do nothing
+ return nil
+}
+
+func (b *GenericArgs) GetArguments() []string {
+ return b.Arguments
+}
diff --git a/internal/command/commandargs/shell.go b/internal/command/commandargs/shell.go
new file mode 100644
index 0000000..7e2b72e
--- /dev/null
+++ b/internal/command/commandargs/shell.go
@@ -0,0 +1,131 @@
+package commandargs
+
+import (
+ "errors"
+ "os"
+ "regexp"
+
+ "github.com/mattn/go-shellwords"
+)
+
+const (
+ Discover CommandType = "discover"
+ TwoFactorRecover CommandType = "2fa_recovery_codes"
+ LfsAuthenticate CommandType = "git-lfs-authenticate"
+ ReceivePack CommandType = "git-receive-pack"
+ UploadPack CommandType = "git-upload-pack"
+ UploadArchive CommandType = "git-upload-archive"
+)
+
+var (
+ whoKeyRegex = regexp.MustCompile(`\bkey-(?P<keyid>\d+)\b`)
+ whoUsernameRegex = regexp.MustCompile(`\busername-(?P<username>\S+)\b`)
+)
+
+type Shell struct {
+ Arguments []string
+ GitlabUsername string
+ GitlabKeyId string
+ SshArgs []string
+ CommandType CommandType
+}
+
+func (s *Shell) Parse() error {
+ if err := s.validate(); err != nil {
+ return err
+ }
+
+ s.parseWho()
+ s.defineCommandType()
+
+ return nil
+}
+
+func (s *Shell) GetArguments() []string {
+ return s.Arguments
+}
+
+func (s *Shell) validate() error {
+ if !s.isSshConnection() {
+ return errors.New("Only SSH allowed")
+ }
+
+ if !s.isValidSshCommand() {
+ return errors.New("Invalid SSH command")
+ }
+
+ return nil
+}
+
+func (s *Shell) isSshConnection() bool {
+ ok := os.Getenv("SSH_CONNECTION")
+ return ok != ""
+}
+
+func (s *Shell) isValidSshCommand() bool {
+ err := s.parseCommand(os.Getenv("SSH_ORIGINAL_COMMAND"))
+ return err == nil
+}
+
+func (s *Shell) parseWho() {
+ for _, argument := range s.Arguments {
+ if keyId := tryParseKeyId(argument); keyId != "" {
+ s.GitlabKeyId = keyId
+ break
+ }
+
+ if username := tryParseUsername(argument); username != "" {
+ s.GitlabUsername = username
+ break
+ }
+ }
+}
+
+func tryParseKeyId(argument string) string {
+ matchInfo := whoKeyRegex.FindStringSubmatch(argument)
+ if len(matchInfo) == 2 {
+ // The first element is the full matched string
+ // The second element is the named `keyid`
+ return matchInfo[1]
+ }
+
+ return ""
+}
+
+func tryParseUsername(argument string) string {
+ matchInfo := whoUsernameRegex.FindStringSubmatch(argument)
+ if len(matchInfo) == 2 {
+ // The first element is the full matched string
+ // The second element is the named `username`
+ return matchInfo[1]
+ }
+
+ return ""
+}
+
+func (s *Shell) parseCommand(commandString string) error {
+ args, err := shellwords.Parse(commandString)
+ if err != nil {
+ return err
+ }
+
+ // Handle Git for Windows 2.14 using "git upload-pack" instead of git-upload-pack
+ if len(args) > 1 && args[0] == "git" {
+ command := args[0] + "-" + args[1]
+ commandArgs := args[2:]
+
+ args = append([]string{command}, commandArgs...)
+ }
+
+ s.SshArgs = args
+
+ return nil
+}
+
+func (s *Shell) defineCommandType() {
+ if len(s.SshArgs) == 0 {
+ s.CommandType = Discover
+ } else {
+ s.CommandType = CommandType(s.SshArgs[0])
+ }
+}
diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go
new file mode 100644
index 0000000..3aa7456
--- /dev/null
+++ b/internal/command/discover/discover.go
@@ -0,0 +1,40 @@
+package discover
+
+import (
+ "fmt"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/discover"
+)
+
+type Command struct {
+ Config *config.Config
+ Args *commandargs.Shell
+ ReadWriter *readwriter.ReadWriter
+}
+
+func (c *Command) Execute() error {
+ response, err := c.getUserInfo()
+ if err != nil {
+ return fmt.Errorf("Failed to get username: %v", err)
+ }
+
+ if response.IsAnonymous() {
+ fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, Anonymous!\n")
+ } else {
+ fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, @%s!\n", response.Username)
+ }
+
+ return nil
+}
+
+func (c *Command) getUserInfo() (*discover.Response, error) {
+ client, err := discover.NewClient(c.Config)
+ if err != nil {
+ return nil, err
+ }
+
+ return client.GetByCommandArgs(c.Args)
+}
diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go
new file mode 100644
index 0000000..3878286
--- /dev/null
+++ b/internal/command/discover/discover_test.go
@@ -0,0 +1,135 @@
+package discover
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+)
+
+var (
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("key_id") == "1" || r.URL.Query().Get("username") == "alex-doe" {
+ body := map[string]interface{}{
+ "id": 2,
+ "username": "alex-doe",
+ "name": "Alex Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "broken_message" {
+ body := map[string]string{
+ "message": "Forbidden!",
+ }
+ w.WriteHeader(http.StatusForbidden)
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "broken" {
+ w.WriteHeader(http.StatusInternalServerError)
+ } else {
+ fmt.Fprint(w, "null")
+ }
+ },
+ },
+ }
+)
+
+func TestExecute(t *testing.T) {
+ url, cleanup := testserver.StartSocketHttpServer(t, requests)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ arguments *commandargs.Shell
+ expectedOutput string
+ }{
+ {
+ desc: "With a known username",
+ arguments: &commandargs.Shell{GitlabUsername: "alex-doe"},
+ expectedOutput: "Welcome to GitLab, @alex-doe!\n",
+ },
+ {
+ desc: "With a known key id",
+ arguments: &commandargs.Shell{GitlabKeyId: "1"},
+ expectedOutput: "Welcome to GitLab, @alex-doe!\n",
+ },
+ {
+ desc: "With an unknown key",
+ arguments: &commandargs.Shell{GitlabKeyId: "-1"},
+ expectedOutput: "Welcome to GitLab, Anonymous!\n",
+ },
+ {
+ desc: "With an unknown username",
+ arguments: &commandargs.Shell{GitlabUsername: "unknown"},
+ expectedOutput: "Welcome to GitLab, Anonymous!\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ buffer := &bytes.Buffer{}
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ Args: tc.arguments,
+ ReadWriter: &readwriter.ReadWriter{Out: buffer},
+ }
+
+ err := cmd.Execute()
+
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedOutput, buffer.String())
+ })
+ }
+}
+
+func TestFailingExecute(t *testing.T) {
+ url, cleanup := testserver.StartSocketHttpServer(t, requests)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ arguments *commandargs.Shell
+ expectedError string
+ }{
+ {
+ desc: "With missing arguments",
+ arguments: &commandargs.Shell{},
+ expectedError: "Failed to get username: who='' is invalid",
+ },
+ {
+ desc: "When the API returns an error",
+ arguments: &commandargs.Shell{GitlabUsername: "broken_message"},
+ expectedError: "Failed to get username: Forbidden!",
+ },
+ {
+ desc: "When the API fails",
+ arguments: &commandargs.Shell{GitlabUsername: "broken"},
+ expectedError: "Failed to get username: Internal API error (500)",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ buffer := &bytes.Buffer{}
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ Args: tc.arguments,
+ ReadWriter: &readwriter.ReadWriter{Out: buffer},
+ }
+
+ err := cmd.Execute()
+
+ require.Empty(t, buffer.String())
+ require.EqualError(t, err, tc.expectedError)
+ })
+ }
+}
diff --git a/internal/command/healthcheck/healthcheck.go b/internal/command/healthcheck/healthcheck.go
new file mode 100644
index 0000000..bbc73dc
--- /dev/null
+++ b/internal/command/healthcheck/healthcheck.go
@@ -0,0 +1,49 @@
+package healthcheck
+
+import (
+ "fmt"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/healthcheck"
+)
+
+var (
+ apiMessage = "Internal API available"
+ redisMessage = "Redis available via internal API"
+)
+
+type Command struct {
+ Config *config.Config
+ ReadWriter *readwriter.ReadWriter
+}
+
+func (c *Command) Execute() error {
+ response, err := c.runCheck()
+ if err != nil {
+ return fmt.Errorf("%v: FAILED - %v", apiMessage, err)
+ }
+
+ fmt.Fprintf(c.ReadWriter.Out, "%v: OK\n", apiMessage)
+
+ if !response.Redis {
+ return fmt.Errorf("%v: FAILED", redisMessage)
+ }
+
+ fmt.Fprintf(c.ReadWriter.Out, "%v: OK\n", redisMessage)
+ return nil
+}
+
+func (c *Command) runCheck() (*healthcheck.Response, error) {
+ client, err := healthcheck.NewClient(c.Config)
+ if err != nil {
+ return nil, err
+ }
+
+ response, err := client.Check()
+ if err != nil {
+ return nil, err
+ }
+
+ return response, nil
+}
diff --git a/internal/command/healthcheck/healthcheck_test.go b/internal/command/healthcheck/healthcheck_test.go
new file mode 100644
index 0000000..e59c5a2
--- /dev/null
+++ b/internal/command/healthcheck/healthcheck_test.go
@@ -0,0 +1,90 @@
+package healthcheck
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/healthcheck"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+)
+
+var (
+ okResponse = &healthcheck.Response{
+ APIVersion: "v4",
+ GitlabVersion: "v12.0.0-ee",
+ GitlabRevision: "3b13818e8330f68625d80d9bf5d8049c41fbe197",
+ Redis: true,
+ }
+
+ badRedisResponse = &healthcheck.Response{Redis: false}
+
+ okHandlers = buildTestHandlers(200, okResponse)
+ badRedisHandlers = buildTestHandlers(200, badRedisResponse)
+ brokenHandlers = buildTestHandlers(500, nil)
+)
+
+func buildTestHandlers(code int, rsp *healthcheck.Response) []testserver.TestRequestHandler {
+ return []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/check",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(code)
+ if rsp != nil {
+ json.NewEncoder(w).Encode(rsp)
+ }
+ },
+ },
+ }
+}
+
+func TestExecute(t *testing.T) {
+ url, cleanup := testserver.StartSocketHttpServer(t, okHandlers)
+ defer cleanup()
+
+ buffer := &bytes.Buffer{}
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ ReadWriter: &readwriter.ReadWriter{Out: buffer},
+ }
+
+ err := cmd.Execute()
+
+ require.NoError(t, err)
+ require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String())
+}
+
+func TestFailingRedisExecute(t *testing.T) {
+ url, cleanup := testserver.StartSocketHttpServer(t, badRedisHandlers)
+ defer cleanup()
+
+ buffer := &bytes.Buffer{}
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ ReadWriter: &readwriter.ReadWriter{Out: buffer},
+ }
+
+ err := cmd.Execute()
+ require.Error(t, err, "Redis available via internal API: FAILED")
+ require.Equal(t, "Internal API available: OK\n", buffer.String())
+}
+
+func TestFailingAPIExecute(t *testing.T) {
+ url, cleanup := testserver.StartSocketHttpServer(t, brokenHandlers)
+ defer cleanup()
+
+ buffer := &bytes.Buffer{}
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ ReadWriter: &readwriter.ReadWriter{Out: buffer},
+ }
+
+ err := cmd.Execute()
+ require.Empty(t, buffer.String())
+ require.EqualError(t, err, "Internal API available: FAILED - Internal API error (500)")
+}
diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go
new file mode 100644
index 0000000..1b2a742
--- /dev/null
+++ b/internal/command/lfsauthenticate/lfsauthenticate.go
@@ -0,0 +1,104 @@
+package lfsauthenticate
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/lfsauthenticate"
+)
+
+const (
+ downloadAction = "download"
+ uploadAction = "upload"
+)
+
+type Command struct {
+ Config *config.Config
+ Args *commandargs.Shell
+ ReadWriter *readwriter.ReadWriter
+}
+
+type PayloadHeader struct {
+ Auth string `json:"Authorization"`
+}
+
+type Payload struct {
+ Header PayloadHeader `json:"header"`
+ Href string `json:"href"`
+ ExpiresIn int `json:"expires_in,omitempty"`
+}
+
+func (c *Command) Execute() error {
+ args := c.Args.SshArgs
+ if len(args) < 3 {
+ return disallowedcommand.Error
+ }
+
+ repo := args[1]
+ action, err := actionToCommandType(args[2])
+ if err != nil {
+ return err
+ }
+
+ accessResponse, err := c.verifyAccess(action, repo)
+ if err != nil {
+ return err
+ }
+
+ payload, err := c.authenticate(action, repo, accessResponse.UserId)
+ if err != nil {
+ // return nothing just like Ruby's GitlabShell#lfs_authenticate does
+ return nil
+ }
+
+ fmt.Fprintf(c.ReadWriter.Out, "%s\n", payload)
+
+ return nil
+}
+
+func actionToCommandType(action string) (commandargs.CommandType, error) {
+ var accessAction commandargs.CommandType
+ switch action {
+ case downloadAction:
+ accessAction = commandargs.UploadPack
+ case uploadAction:
+ accessAction = commandargs.ReceivePack
+ default:
+ return "", disallowedcommand.Error
+ }
+
+ return accessAction, nil
+}
+
+func (c *Command) verifyAccess(action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
+ cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
+
+ return cmd.Verify(action, repo)
+}
+
+func (c *Command) authenticate(action commandargs.CommandType, repo, userId string) ([]byte, error) {
+ client, err := lfsauthenticate.NewClient(c.Config, c.Args)
+ if err != nil {
+ return nil, err
+ }
+
+ response, err := client.Authenticate(action, repo, userId)
+ if err != nil {
+ return nil, err
+ }
+
+ basicAuth := base64.StdEncoding.EncodeToString([]byte(response.Username + ":" + response.LfsToken))
+ payload := &Payload{
+ Header: PayloadHeader{Auth: "Basic " + basicAuth},
+ Href: response.RepoPath + "/info/lfs",
+ ExpiresIn: response.ExpiresIn,
+ }
+
+ return json.Marshal(payload)
+}
diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go
new file mode 100644
index 0000000..f2ccc20
--- /dev/null
+++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go
@@ -0,0 +1,153 @@
+package lfsauthenticate
+
+import (
+ "bytes"
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/lfsauthenticate"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers"
+)
+
+func TestFailedRequests(t *testing.T) {
+ requests := requesthandlers.BuildDisallowedByApiHandlers(t)
+ url, cleanup := testserver.StartHttpServer(t, requests)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ arguments *commandargs.Shell
+ expectedOutput string
+ }{
+ {
+ desc: "With missing arguments",
+ arguments: &commandargs.Shell{},
+ expectedOutput: "> GitLab: Disallowed command",
+ },
+ {
+ desc: "With disallowed command",
+ arguments: &commandargs.Shell{GitlabKeyId: "1", SshArgs: []string{"git-lfs-authenticate", "group/repo", "unknown"}},
+ expectedOutput: "> GitLab: Disallowed command",
+ },
+ {
+ desc: "With disallowed user",
+ arguments: &commandargs.Shell{GitlabKeyId: "disallowed", SshArgs: []string{"git-lfs-authenticate", "group/repo", "download"}},
+ expectedOutput: "Disallowed by API call",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ output := &bytes.Buffer{}
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ Args: tc.arguments,
+ ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
+ }
+
+ err := cmd.Execute()
+ require.Error(t, err)
+
+ require.Equal(t, tc.expectedOutput, err.Error())
+ })
+ }
+}
+
+func TestLfsAuthenticateRequests(t *testing.T) {
+ userId := "123"
+
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/lfs_authenticate",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+ require.NoError(t, err)
+
+ var request *lfsauthenticate.Request
+ require.NoError(t, json.Unmarshal(b, &request))
+
+ if request.UserId == userId {
+ body := map[string]interface{}{
+ "username": "john",
+ "lfs_token": "sometoken",
+ "repository_http_path": "https://gitlab.com/repo/path",
+ "expires_in": 1800,
+ }
+ require.NoError(t, json.NewEncoder(w).Encode(body))
+ } else {
+ w.WriteHeader(http.StatusForbidden)
+ }
+ },
+ },
+ {
+ Path: "/api/v4/internal/allowed",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+ require.NoError(t, err)
+
+ var request *accessverifier.Request
+ require.NoError(t, json.Unmarshal(b, &request))
+
+ var glId string
+ if request.Username == "somename" {
+ glId = userId
+ } else {
+ glId = "100"
+ }
+
+ body := map[string]interface{}{
+ "gl_id": glId,
+ "status": true,
+ }
+ require.NoError(t, json.NewEncoder(w).Encode(body))
+ },
+ },
+ }
+
+ url, cleanup := testserver.StartHttpServer(t, requests)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ username string
+ expectedOutput string
+ }{
+ {
+ desc: "With successful response from API",
+ username: "somename",
+ expectedOutput: "{\"header\":{\"Authorization\":\"Basic am9objpzb21ldG9rZW4=\"},\"href\":\"https://gitlab.com/repo/path/info/lfs\",\"expires_in\":1800}\n",
+ },
+ {
+ desc: "With forbidden response from API",
+ username: "anothername",
+ expectedOutput: "",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ output := &bytes.Buffer{}
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ Args: &commandargs.Shell{GitlabUsername: tc.username, SshArgs: []string{"git-lfs-authenticate", "group/repo", "upload"}},
+ ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
+ }
+
+ err := cmd.Execute()
+ require.NoError(t, err)
+
+ require.Equal(t, tc.expectedOutput, output.String())
+ })
+ }
+}
diff --git a/internal/command/readwriter/readwriter.go b/internal/command/readwriter/readwriter.go
new file mode 100644
index 0000000..da18d30
--- /dev/null
+++ b/internal/command/readwriter/readwriter.go
@@ -0,0 +1,9 @@
+package readwriter
+
+import "io"
+
+type ReadWriter struct {
+ Out io.Writer
+ In io.Reader
+ ErrOut io.Writer
+}
diff --git a/internal/command/receivepack/customaction.go b/internal/command/receivepack/customaction.go
new file mode 100644
index 0000000..c94ae4c
--- /dev/null
+++ b/internal/command/receivepack/customaction.go
@@ -0,0 +1,99 @@
+package receivepack
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier"
+)
+
+type Request struct {
+ SecretToken []byte `json:"secret_token"`
+ Data accessverifier.CustomPayloadData `json:"data"`
+ Output []byte `json:"output"`
+}
+
+type Response struct {
+ Result []byte `json:"result"`
+ Message string `json:"message"`
+}
+
+func (c *Command) processCustomAction(response *accessverifier.Response) error {
+ data := response.Payload.Data
+ apiEndpoints := data.ApiEndpoints
+
+ if len(apiEndpoints) == 0 {
+ return errors.New("Custom action error: Empty API endpoints")
+ }
+
+ c.displayInfoMessage(data.InfoMessage)
+
+ return c.processApiEndpoints(response)
+}
+
+func (c *Command) displayInfoMessage(infoMessage string) {
+ messages := strings.Split(infoMessage, "\n")
+
+ for _, msg := range messages {
+ fmt.Fprintf(c.ReadWriter.ErrOut, "> GitLab: %v\n", msg)
+ }
+}
+
+func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
+ client, err := gitlabnet.GetClient(c.Config)
+
+ if err != nil {
+ return err
+ }
+
+ data := response.Payload.Data
+ request := &Request{Data: data}
+ request.Data.UserId = response.Who
+
+ for _, endpoint := range data.ApiEndpoints {
+ response, err := c.performRequest(client, endpoint, request)
+ if err != nil {
+ return err
+ }
+
+ if err = c.displayResult(response.Result); err != nil {
+ return err
+ }
+
+ // In the context of the git push sequence of events, it's necessary to read
+ // stdin in order to capture output to pass onto subsequent commands
+ output, err := ioutil.ReadAll(c.ReadWriter.In)
+ if err != nil {
+ return err
+ }
+ request.Output = output
+ }
+
+ return nil
+}
+
+func (c *Command) performRequest(client *gitlabnet.GitlabClient, endpoint string, request *Request) (*Response, error) {
+ response, err := client.DoRequest(http.MethodPost, endpoint, request)
+ if err != nil {
+ return nil, err
+ }
+ defer response.Body.Close()
+
+ cr := &Response{}
+ if err := gitlabnet.ParseJSON(response, cr); err != nil {
+ return nil, err
+ }
+
+ return cr, nil
+}
+
+func (c *Command) displayResult(result []byte) error {
+ _, err := io.Copy(c.ReadWriter.Out, bytes.NewReader(result))
+ return err
+}
diff --git a/internal/command/receivepack/customaction_test.go b/internal/command/receivepack/customaction_test.go
new file mode 100644
index 0000000..2a4a718
--- /dev/null
+++ b/internal/command/receivepack/customaction_test.go
@@ -0,0 +1,105 @@
+package receivepack
+
+import (
+ "bytes"
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+)
+
+func TestCustomReceivePack(t *testing.T) {
+ repo := "group/repo"
+ keyId := "1"
+
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/allowed",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ require.NoError(t, err)
+
+ var request *accessverifier.Request
+ require.NoError(t, json.Unmarshal(b, &request))
+
+ require.Equal(t, "1", request.KeyId)
+
+ body := map[string]interface{}{
+ "status": true,
+ "gl_id": "1",
+ "payload": map[string]interface{}{
+ "action": "geo_proxy_to_primary",
+ "data": map[string]interface{}{
+ "api_endpoints": []string{"/geo/proxy_git_push_ssh/info_refs", "/geo/proxy_git_push_ssh/push"},
+ "gl_username": "custom",
+ "primary_repo": "https://repo/path",
+ "info_message": "info_message\none more message",
+ },
+ },
+ }
+ w.WriteHeader(http.StatusMultipleChoices)
+ require.NoError(t, json.NewEncoder(w).Encode(body))
+ },
+ },
+ {
+ Path: "/geo/proxy_git_push_ssh/info_refs",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ require.NoError(t, err)
+
+ var request *Request
+ require.NoError(t, json.Unmarshal(b, &request))
+
+ require.Equal(t, request.Data.UserId, "key-"+keyId)
+ require.Empty(t, request.Output)
+
+ err = json.NewEncoder(w).Encode(Response{Result: []byte("custom")})
+ require.NoError(t, err)
+ },
+ },
+ {
+ Path: "/geo/proxy_git_push_ssh/push",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ require.NoError(t, err)
+
+ var request *Request
+ require.NoError(t, json.Unmarshal(b, &request))
+
+ require.Equal(t, request.Data.UserId, "key-"+keyId)
+ require.Equal(t, "input", string(request.Output))
+
+ err = json.NewEncoder(w).Encode(Response{Result: []byte("output")})
+ require.NoError(t, err)
+ },
+ },
+ }
+
+ url, cleanup := testserver.StartSocketHttpServer(t, requests)
+ defer cleanup()
+
+ outBuf := &bytes.Buffer{}
+ errBuf := &bytes.Buffer{}
+ input := bytes.NewBufferString("input")
+
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ Args: &commandargs.Shell{GitlabKeyId: keyId, CommandType: commandargs.ReceivePack, SshArgs: []string{"git-receive-pack", repo}},
+ ReadWriter: &readwriter.ReadWriter{ErrOut: errBuf, Out: outBuf, In: input},
+ }
+
+ require.NoError(t, cmd.Execute())
+
+ // expect printing of info message, "custom" string from the first request
+ // and "output" string from the second request
+ require.Equal(t, "> GitLab: info_message\n> GitLab: one more message\n", errBuf.String())
+ require.Equal(t, "customoutput", outBuf.String())
+}
diff --git a/internal/command/receivepack/gitalycall.go b/internal/command/receivepack/gitalycall.go
new file mode 100644
index 0000000..f440672
--- /dev/null
+++ b/internal/command/receivepack/gitalycall.go
@@ -0,0 +1,39 @@
+package receivepack
+
+import (
+ "context"
+
+ "google.golang.org/grpc"
+
+ "gitlab.com/gitlab-org/gitaly/client"
+ pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/handler"
+)
+
+func (c *Command) performGitalyCall(response *accessverifier.Response) error {
+ gc := &handler.GitalyCommand{
+ Config: c.Config,
+ ServiceName: string(commandargs.ReceivePack),
+ Address: response.Gitaly.Address,
+ Token: response.Gitaly.Token,
+ }
+
+ request := &pb.SSHReceivePackRequest{
+ Repository: &response.Gitaly.Repo,
+ GlId: response.UserId,
+ GlRepository: response.Repo,
+ GlUsername: response.Username,
+ GitProtocol: response.GitProtocol,
+ GitConfigOptions: response.GitConfigOptions,
+ }
+
+ return gc.RunGitalyCommand(func(ctx context.Context, conn *grpc.ClientConn) (int32, error) {
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ rw := c.ReadWriter
+ return client.ReceivePack(ctx, conn, rw.In, rw.Out, rw.ErrOut, request)
+ })
+}
diff --git a/internal/command/receivepack/gitalycall_test.go b/internal/command/receivepack/gitalycall_test.go
new file mode 100644
index 0000000..361596a
--- /dev/null
+++ b/internal/command/receivepack/gitalycall_test.go
@@ -0,0 +1,40 @@
+package receivepack
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers"
+)
+
+func TestReceivePack(t *testing.T) {
+ gitalyAddress, cleanup := testserver.StartGitalyServer(t)
+ defer cleanup()
+
+ requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress)
+ url, cleanup := testserver.StartHttpServer(t, requests)
+ defer cleanup()
+
+ output := &bytes.Buffer{}
+ input := &bytes.Buffer{}
+
+ userId := "1"
+ repo := "group/repo"
+
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ Args: &commandargs.Shell{GitlabKeyId: userId, CommandType: commandargs.ReceivePack, SshArgs: []string{"git-receive-pack", repo}},
+ ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input},
+ }
+
+ err := cmd.Execute()
+ require.NoError(t, err)
+
+ require.Equal(t, "ReceivePack: "+userId+" "+repo, output.String())
+}
diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go
new file mode 100644
index 0000000..aaaf7b0
--- /dev/null
+++ b/internal/command/receivepack/receivepack.go
@@ -0,0 +1,40 @@
+package receivepack
+
+import (
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+type Command struct {
+ Config *config.Config
+ Args *commandargs.Shell
+ ReadWriter *readwriter.ReadWriter
+}
+
+func (c *Command) Execute() error {
+ args := c.Args.SshArgs
+ if len(args) != 2 {
+ return disallowedcommand.Error
+ }
+
+ repo := args[1]
+ response, err := c.verifyAccess(repo)
+ if err != nil {
+ return err
+ }
+
+ if response.IsCustomAction() {
+ return c.processCustomAction(response)
+ }
+
+ return c.performGitalyCall(response)
+}
+
+func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+ cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
+
+ return cmd.Verify(c.Args.CommandType, repo)
+}
diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go
new file mode 100644
index 0000000..1d7bd21
--- /dev/null
+++ b/internal/command/receivepack/receivepack_test.go
@@ -0,0 +1,32 @@
+package receivepack
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers"
+)
+
+func TestForbiddenAccess(t *testing.T) {
+ requests := requesthandlers.BuildDisallowedByApiHandlers(t)
+ url, cleanup := testserver.StartHttpServer(t, requests)
+ defer cleanup()
+
+ output := &bytes.Buffer{}
+ input := bytes.NewBufferString("input")
+
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ Args: &commandargs.Shell{GitlabKeyId: "disallowed", SshArgs: []string{"git-receive-pack", "group/repo"}},
+ ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input},
+ }
+
+ err := cmd.Execute()
+ require.Equal(t, "Disallowed by API call", err.Error())
+}
diff --git a/internal/command/shared/accessverifier/accessverifier.go b/internal/command/shared/accessverifier/accessverifier.go
new file mode 100644
index 0000000..3aaf98d
--- /dev/null
+++ b/internal/command/shared/accessverifier/accessverifier.go
@@ -0,0 +1,45 @@
+package accessverifier
+
+import (
+ "errors"
+ "fmt"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier"
+)
+
+type Response = accessverifier.Response
+
+type Command struct {
+ Config *config.Config
+ Args *commandargs.Shell
+ ReadWriter *readwriter.ReadWriter
+}
+
+func (c *Command) Verify(action commandargs.CommandType, repo string) (*Response, error) {
+ client, err := accessverifier.NewClient(c.Config)
+ if err != nil {
+ return nil, err
+ }
+
+ response, err := client.Verify(c.Args, action, repo)
+ if err != nil {
+ return nil, err
+ }
+
+ c.displayConsoleMessages(response.ConsoleMessages)
+
+ if !response.Success {
+ return nil, errors.New(response.Message)
+ }
+
+ return response, nil
+}
+
+func (c *Command) displayConsoleMessages(messages []string) {
+ for _, msg := range messages {
+ fmt.Fprintf(c.ReadWriter.ErrOut, "> GitLab: %v\n", msg)
+ }
+}
diff --git a/internal/command/shared/accessverifier/accessverifier_test.go b/internal/command/shared/accessverifier/accessverifier_test.go
new file mode 100644
index 0000000..39c2a66
--- /dev/null
+++ b/internal/command/shared/accessverifier/accessverifier_test.go
@@ -0,0 +1,82 @@
+package accessverifier
+
+import (
+ "bytes"
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+)
+
+var (
+ repo = "group/repo"
+ action = commandargs.ReceivePack
+)
+
+func setup(t *testing.T) (*Command, *bytes.Buffer, *bytes.Buffer, func()) {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/allowed",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ require.NoError(t, err)
+
+ var requestBody *accessverifier.Request
+ err = json.Unmarshal(b, &requestBody)
+ require.NoError(t, err)
+
+ if requestBody.KeyId == "1" {
+ body := map[string]interface{}{
+ "gl_console_messages": []string{"console", "message"},
+ }
+ require.NoError(t, json.NewEncoder(w).Encode(body))
+ } else {
+ body := map[string]interface{}{
+ "status": false,
+ "message": "missing user",
+ }
+ require.NoError(t, json.NewEncoder(w).Encode(body))
+ }
+ },
+ },
+ }
+
+ url, cleanup := testserver.StartSocketHttpServer(t, requests)
+
+ errBuf := &bytes.Buffer{}
+ outBuf := &bytes.Buffer{}
+
+ readWriter := &readwriter.ReadWriter{Out: outBuf, ErrOut: errBuf}
+ cmd := &Command{Config: &config.Config{GitlabUrl: url}, ReadWriter: readWriter}
+
+ return cmd, errBuf, outBuf, cleanup
+}
+
+func TestMissingUser(t *testing.T) {
+ cmd, _, _, cleanup := setup(t)
+ defer cleanup()
+
+ cmd.Args = &commandargs.Shell{GitlabKeyId: "2"}
+ _, err := cmd.Verify(action, repo)
+
+ require.Equal(t, "missing user", err.Error())
+}
+
+func TestConsoleMessages(t *testing.T) {
+ cmd, errBuf, outBuf, cleanup := setup(t)
+ defer cleanup()
+
+ cmd.Args = &commandargs.Shell{GitlabKeyId: "1"}
+ cmd.Verify(action, repo)
+
+ require.Equal(t, "> GitLab: console\n> GitLab: message\n", errBuf.String())
+ require.Empty(t, outBuf.String())
+}
diff --git a/internal/command/shared/disallowedcommand/disallowedcommand.go b/internal/command/shared/disallowedcommand/disallowedcommand.go
new file mode 100644
index 0000000..3c98bcc
--- /dev/null
+++ b/internal/command/shared/disallowedcommand/disallowedcommand.go
@@ -0,0 +1,7 @@
+package disallowedcommand
+
+import "errors"
+
+var (
+ Error = errors.New("> GitLab: Disallowed command")
+)
diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go
new file mode 100644
index 0000000..2f13cc5
--- /dev/null
+++ b/internal/command/twofactorrecover/twofactorrecover.go
@@ -0,0 +1,65 @@
+package twofactorrecover
+
+import (
+ "fmt"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/twofactorrecover"
+)
+
+type Command struct {
+ Config *config.Config
+ Args *commandargs.Shell
+ ReadWriter *readwriter.ReadWriter
+}
+
+func (c *Command) Execute() error {
+ if c.canContinue() {
+ c.displayRecoveryCodes()
+ } else {
+ fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.")
+ }
+
+ return nil
+}
+
+func (c *Command) canContinue() bool {
+ question :=
+ "Are you sure you want to generate new two-factor recovery codes?\n" +
+ "Any existing recovery codes you saved will be invalidated. (yes/no)"
+ fmt.Fprintln(c.ReadWriter.Out, question)
+
+ var answer string
+ fmt.Fscanln(c.ReadWriter.In, &answer)
+
+ return answer == "yes"
+}
+
+func (c *Command) displayRecoveryCodes() {
+ codes, err := c.getRecoveryCodes()
+
+ if err == nil {
+ messageWithCodes :=
+ "\nYour two-factor authentication recovery codes are:\n\n" +
+ strings.Join(codes, "\n") +
+ "\n\nDuring sign in, use one of the codes above when prompted for\n" +
+ "your two-factor code. Then, visit your Profile Settings and add\n" +
+ "a new device so you do not lose access to your account again.\n"
+ fmt.Fprint(c.ReadWriter.Out, messageWithCodes)
+ } else {
+ fmt.Fprintf(c.ReadWriter.Out, "\nAn error occurred while trying to generate new recovery codes.\n%v\n", err)
+ }
+}
+
+func (c *Command) getRecoveryCodes() ([]string, error) {
+ client, err := twofactorrecover.NewClient(c.Config)
+
+ if err != nil {
+ return nil, err
+ }
+
+ return client.GetRecoveryCodes(c.Args)
+}
diff --git a/internal/command/twofactorrecover/twofactorrecover_test.go b/internal/command/twofactorrecover/twofactorrecover_test.go
new file mode 100644
index 0000000..283c45a
--- /dev/null
+++ b/internal/command/twofactorrecover/twofactorrecover_test.go
@@ -0,0 +1,136 @@
+package twofactorrecover
+
+import (
+ "bytes"
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/twofactorrecover"
+)
+
+var (
+ requests []testserver.TestRequestHandler
+)
+
+func setup(t *testing.T) {
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/two_factor_recovery_codes",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+
+ require.NoError(t, err)
+
+ var requestBody *twofactorrecover.RequestBody
+ json.Unmarshal(b, &requestBody)
+
+ switch requestBody.KeyId {
+ case "1":
+ body := map[string]interface{}{
+ "success": true,
+ "recovery_codes": [2]string{"recovery", "codes"},
+ }
+ json.NewEncoder(w).Encode(body)
+ case "forbidden":
+ body := map[string]interface{}{
+ "success": false,
+ "message": "Forbidden!",
+ }
+ json.NewEncoder(w).Encode(body)
+ case "broken":
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+ },
+ },
+ }
+}
+
+const (
+ question = "Are you sure you want to generate new two-factor recovery codes?\n" +
+ "Any existing recovery codes you saved will be invalidated. (yes/no)\n\n"
+ errorHeader = "An error occurred while trying to generate new recovery codes.\n"
+)
+
+func TestExecute(t *testing.T) {
+ setup(t)
+
+ url, cleanup := testserver.StartSocketHttpServer(t, requests)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ arguments *commandargs.Shell
+ answer string
+ expectedOutput string
+ }{
+ {
+ desc: "With a known key id",
+ arguments: &commandargs.Shell{GitlabKeyId: "1"},
+ answer: "yes\n",
+ expectedOutput: question +
+ "Your two-factor authentication recovery codes are:\n\nrecovery\ncodes\n\n" +
+ "During sign in, use one of the codes above when prompted for\n" +
+ "your two-factor code. Then, visit your Profile Settings and add\n" +
+ "a new device so you do not lose access to your account again.\n",
+ },
+ {
+ desc: "With bad response",
+ arguments: &commandargs.Shell{GitlabKeyId: "-1"},
+ answer: "yes\n",
+ expectedOutput: question + errorHeader + "Parsing failed\n",
+ },
+ {
+ desc: "With API returns an error",
+ arguments: &commandargs.Shell{GitlabKeyId: "forbidden"},
+ answer: "yes\n",
+ expectedOutput: question + errorHeader + "Forbidden!\n",
+ },
+ {
+ desc: "With API fails",
+ arguments: &commandargs.Shell{GitlabKeyId: "broken"},
+ answer: "yes\n",
+ expectedOutput: question + errorHeader + "Internal API error (500)\n",
+ },
+ {
+ desc: "With missing arguments",
+ arguments: &commandargs.Shell{},
+ answer: "yes\n",
+ expectedOutput: question + errorHeader + "who='' is invalid\n",
+ },
+ {
+ desc: "With negative answer",
+ arguments: &commandargs.Shell{},
+ answer: "no\n",
+ expectedOutput: question +
+ "New recovery codes have *not* been generated. Existing codes will remain valid.\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ output := &bytes.Buffer{}
+ input := bytes.NewBufferString(tc.answer)
+
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ Args: tc.arguments,
+ ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
+ }
+
+ err := cmd.Execute()
+
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expectedOutput, output.String())
+ })
+ }
+}
diff --git a/internal/command/uploadarchive/gitalycall.go b/internal/command/uploadarchive/gitalycall.go
new file mode 100644
index 0000000..1dfc864
--- /dev/null
+++ b/internal/command/uploadarchive/gitalycall.go
@@ -0,0 +1,32 @@
+package uploadarchive
+
+import (
+ "context"
+
+ "google.golang.org/grpc"
+
+ "gitlab.com/gitlab-org/gitaly/client"
+ pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/handler"
+)
+
+func (c *Command) performGitalyCall(response *accessverifier.Response) error {
+ gc := &handler.GitalyCommand{
+ Config: c.Config,
+ ServiceName: string(commandargs.UploadArchive),
+ Address: response.Gitaly.Address,
+ Token: response.Gitaly.Token,
+ }
+
+ request := &pb.SSHUploadArchiveRequest{Repository: &response.Gitaly.Repo}
+
+ return gc.RunGitalyCommand(func(ctx context.Context, conn *grpc.ClientConn) (int32, error) {
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ rw := c.ReadWriter
+ return client.UploadArchive(ctx, conn, rw.In, rw.Out, rw.ErrOut, request)
+ })
+}
diff --git a/internal/command/uploadarchive/gitalycall_test.go b/internal/command/uploadarchive/gitalycall_test.go
new file mode 100644
index 0000000..5c5353f
--- /dev/null
+++ b/internal/command/uploadarchive/gitalycall_test.go
@@ -0,0 +1,40 @@
+package uploadarchive
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers"
+)
+
+func TestUploadPack(t *testing.T) {
+ gitalyAddress, cleanup := testserver.StartGitalyServer(t)
+ defer cleanup()
+
+ requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress)
+ url, cleanup := testserver.StartHttpServer(t, requests)
+ defer cleanup()
+
+ output := &bytes.Buffer{}
+ input := &bytes.Buffer{}
+
+ userId := "1"
+ repo := "group/repo"
+
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ Args: &commandargs.Shell{GitlabKeyId: userId, CommandType: commandargs.UploadArchive, SshArgs: []string{"git-upload-archive", repo}},
+ ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input},
+ }
+
+ err := cmd.Execute()
+ require.NoError(t, err)
+
+ require.Equal(t, "UploadArchive: "+repo, output.String())
+}
diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go
new file mode 100644
index 0000000..9d4fbe0
--- /dev/null
+++ b/internal/command/uploadarchive/uploadarchive.go
@@ -0,0 +1,36 @@
+package uploadarchive
+
+import (
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+type Command struct {
+ Config *config.Config
+ Args *commandargs.Shell
+ ReadWriter *readwriter.ReadWriter
+}
+
+func (c *Command) Execute() error {
+ args := c.Args.SshArgs
+ if len(args) != 2 {
+ return disallowedcommand.Error
+ }
+
+ repo := args[1]
+ response, err := c.verifyAccess(repo)
+ if err != nil {
+ return err
+ }
+
+ return c.performGitalyCall(response)
+}
+
+func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+ cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
+
+ return cmd.Verify(c.Args.CommandType, repo)
+}
diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go
new file mode 100644
index 0000000..50f3f7e
--- /dev/null
+++ b/internal/command/uploadarchive/uploadarchive_test.go
@@ -0,0 +1,31 @@
+package uploadarchive
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers"
+)
+
+func TestForbiddenAccess(t *testing.T) {
+ requests := requesthandlers.BuildDisallowedByApiHandlers(t)
+ url, cleanup := testserver.StartHttpServer(t, requests)
+ defer cleanup()
+
+ output := &bytes.Buffer{}
+
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ Args: &commandargs.Shell{GitlabKeyId: "disallowed", SshArgs: []string{"git-upload-archive", "group/repo"}},
+ ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
+ }
+
+ err := cmd.Execute()
+ require.Equal(t, "Disallowed by API call", err.Error())
+}
diff --git a/internal/command/uploadpack/gitalycall.go b/internal/command/uploadpack/gitalycall.go
new file mode 100644
index 0000000..8b97dee
--- /dev/null
+++ b/internal/command/uploadpack/gitalycall.go
@@ -0,0 +1,36 @@
+package uploadpack
+
+import (
+ "context"
+
+ "google.golang.org/grpc"
+
+ "gitlab.com/gitlab-org/gitaly/client"
+ pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/handler"
+)
+
+func (c *Command) performGitalyCall(response *accessverifier.Response) error {
+ gc := &handler.GitalyCommand{
+ Config: c.Config,
+ ServiceName: string(commandargs.UploadPack),
+ Address: response.Gitaly.Address,
+ Token: response.Gitaly.Token,
+ }
+
+ request := &pb.SSHUploadPackRequest{
+ Repository: &response.Gitaly.Repo,
+ GitProtocol: response.GitProtocol,
+ GitConfigOptions: response.GitConfigOptions,
+ }
+
+ return gc.RunGitalyCommand(func(ctx context.Context, conn *grpc.ClientConn) (int32, error) {
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ rw := c.ReadWriter
+ return client.UploadPack(ctx, conn, rw.In, rw.Out, rw.ErrOut, request)
+ })
+}
diff --git a/internal/command/uploadpack/gitalycall_test.go b/internal/command/uploadpack/gitalycall_test.go
new file mode 100644
index 0000000..71a253b
--- /dev/null
+++ b/internal/command/uploadpack/gitalycall_test.go
@@ -0,0 +1,40 @@
+package uploadpack
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers"
+)
+
+func TestUploadPack(t *testing.T) {
+ gitalyAddress, cleanup := testserver.StartGitalyServer(t)
+ defer cleanup()
+
+ requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress)
+ url, cleanup := testserver.StartHttpServer(t, requests)
+ defer cleanup()
+
+ output := &bytes.Buffer{}
+ input := &bytes.Buffer{}
+
+ userId := "1"
+ repo := "group/repo"
+
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ Args: &commandargs.Shell{GitlabKeyId: userId, CommandType: commandargs.UploadPack, SshArgs: []string{"git-upload-pack", repo}},
+ ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input},
+ }
+
+ err := cmd.Execute()
+ require.NoError(t, err)
+
+ require.Equal(t, "UploadPack: "+repo, output.String())
+}
diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go
new file mode 100644
index 0000000..a5c71b2
--- /dev/null
+++ b/internal/command/uploadpack/uploadpack.go
@@ -0,0 +1,36 @@
+package uploadpack
+
+import (
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+type Command struct {
+ Config *config.Config
+ Args *commandargs.Shell
+ ReadWriter *readwriter.ReadWriter
+}
+
+func (c *Command) Execute() error {
+ args := c.Args.SshArgs
+ if len(args) != 2 {
+ return disallowedcommand.Error
+ }
+
+ repo := args[1]
+ response, err := c.verifyAccess(repo)
+ if err != nil {
+ return err
+ }
+
+ return c.performGitalyCall(response)
+}
+
+func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+ cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
+
+ return cmd.Verify(c.Args.CommandType, repo)
+}
diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go
new file mode 100644
index 0000000..04fe2ba
--- /dev/null
+++ b/internal/command/uploadpack/uploadpack_test.go
@@ -0,0 +1,31 @@
+package uploadpack
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers"
+)
+
+func TestForbiddenAccess(t *testing.T) {
+ requests := requesthandlers.BuildDisallowedByApiHandlers(t)
+ url, cleanup := testserver.StartHttpServer(t, requests)
+ defer cleanup()
+
+ output := &bytes.Buffer{}
+
+ cmd := &Command{
+ Config: &config.Config{GitlabUrl: url},
+ Args: &commandargs.Shell{GitlabKeyId: "disallowed", SshArgs: []string{"git-upload-pack", "group/repo"}},
+ ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
+ }
+
+ err := cmd.Execute()
+ require.Equal(t, "Disallowed by API call", err.Error())
+}
diff --git a/internal/config/config.go b/internal/config/config.go
new file mode 100644
index 0000000..2231851
--- /dev/null
+++ b/internal/config/config.go
@@ -0,0 +1,123 @@
+package config
+
+import (
+ "io/ioutil"
+ "net/url"
+ "os"
+ "path"
+ "path/filepath"
+
+ yaml "gopkg.in/yaml.v2"
+)
+
+const (
+ configFile = "config.yml"
+ logFile = "gitlab-shell.log"
+ defaultSecretFileName = ".gitlab_shell_secret"
+)
+
+type HttpSettingsConfig struct {
+ User string `yaml:"user"`
+ Password string `yaml:"password"`
+ ReadTimeoutSeconds uint64 `yaml:"read_timeout"`
+ CaFile string `yaml:"ca_file"`
+ CaPath string `yaml:"ca_path"`
+ SelfSignedCert bool `yaml:"self_signed_cert"`
+}
+
+type Config struct {
+ RootDir string
+ LogFile string `yaml:"log_file"`
+ LogFormat string `yaml:"log_format"`
+ GitlabUrl string `yaml:"gitlab_url"`
+ GitlabTracing string `yaml:"gitlab_tracing"`
+ SecretFilePath string `yaml:"secret_file"`
+ Secret string `yaml:"secret"`
+ HttpSettings HttpSettingsConfig `yaml:"http_settings"`
+ HttpClient *HttpClient
+}
+
+func New() (*Config, error) {
+ dir, err := os.Getwd()
+ if err != nil {
+ return nil, err
+ }
+
+ return NewFromDir(dir)
+}
+
+func NewFromDir(dir string) (*Config, error) {
+ return newFromFile(path.Join(dir, configFile))
+}
+
+func newFromFile(filename string) (*Config, error) {
+ cfg := &Config{RootDir: path.Dir(filename)}
+
+ configBytes, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := parseConfig(configBytes, cfg); err != nil {
+ return nil, err
+ }
+
+ return cfg, nil
+}
+
+// parseConfig expects YAML data in configBytes and a Config instance with RootDir set.
+func parseConfig(configBytes []byte, cfg *Config) error {
+ if err := yaml.Unmarshal(configBytes, cfg); err != nil {
+ return err
+ }
+
+ if cfg.LogFile == "" {
+ cfg.LogFile = logFile
+ }
+
+ if len(cfg.LogFile) > 0 && cfg.LogFile[0] != '/' {
+ cfg.LogFile = path.Join(cfg.RootDir, cfg.LogFile)
+ }
+
+ if cfg.LogFormat == "" {
+ cfg.LogFormat = "text"
+ }
+
+ if cfg.GitlabUrl != "" {
+ unescapedUrl, err := url.PathUnescape(cfg.GitlabUrl)
+ if err != nil {
+ return err
+ }
+
+ cfg.GitlabUrl = unescapedUrl
+ }
+
+ if err := parseSecret(cfg); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func parseSecret(cfg *Config) error {
+ // The secret was parsed from yaml no need to read another file
+ if cfg.Secret != "" {
+ return nil
+ }
+
+ if cfg.SecretFilePath == "" {
+ cfg.SecretFilePath = defaultSecretFileName
+ }
+
+ if !filepath.IsAbs(cfg.SecretFilePath) {
+ cfg.SecretFilePath = path.Join(cfg.RootDir, cfg.SecretFilePath)
+ }
+
+ secretFileContent, err := ioutil.ReadFile(cfg.SecretFilePath)
+ if err != nil {
+ return err
+ }
+ cfg.Secret = string(secretFileContent)
+
+ return nil
+}
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
new file mode 100644
index 0000000..202db6d
--- /dev/null
+++ b/internal/config/config_test.go
@@ -0,0 +1,112 @@
+package config
+
+import (
+ "fmt"
+ "path"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+const (
+ customSecret = "custom/my-contents-is-secret"
+)
+
+var (
+ testRoot = testhelper.TestRoot
+)
+
+func TestParseConfig(t *testing.T) {
+ cleanup, err := testhelper.PrepareTestRootDir()
+ require.NoError(t, err)
+ defer cleanup()
+
+ testCases := []struct {
+ yaml string
+ path string
+ format string
+ gitlabUrl string
+ secret string
+ httpSettings HttpSettingsConfig
+ }{
+ {
+ path: path.Join(testRoot, "gitlab-shell.log"),
+ format: "text",
+ secret: "default-secret-content",
+ },
+ {
+ yaml: "log_file: my-log.log",
+ path: path.Join(testRoot, "my-log.log"),
+ format: "text",
+ secret: "default-secret-content",
+ },
+ {
+ yaml: "log_file: /qux/my-log.log",
+ path: "/qux/my-log.log",
+ format: "text",
+ secret: "default-secret-content",
+ },
+ {
+ yaml: "log_format: json",
+ path: path.Join(testRoot, "gitlab-shell.log"),
+ format: "json",
+ secret: "default-secret-content",
+ },
+ {
+ yaml: "gitlab_url: http+unix://%2Fpath%2Fto%2Fgitlab%2Fgitlab.socket",
+ path: path.Join(testRoot, "gitlab-shell.log"),
+ format: "text",
+ gitlabUrl: "http+unix:///path/to/gitlab/gitlab.socket",
+ secret: "default-secret-content",
+ },
+ {
+ yaml: fmt.Sprintf("secret_file: %s", customSecret),
+ path: path.Join(testRoot, "gitlab-shell.log"),
+ format: "text",
+ secret: "custom-secret-content",
+ },
+ {
+ yaml: fmt.Sprintf("secret_file: %s", path.Join(testRoot, customSecret)),
+ path: path.Join(testRoot, "gitlab-shell.log"),
+ format: "text",
+ secret: "custom-secret-content",
+ },
+ {
+ yaml: "secret: an inline secret",
+ path: path.Join(testRoot, "gitlab-shell.log"),
+ format: "text",
+ secret: "an inline secret",
+ },
+ {
+ yaml: "http_settings:\n user: user_basic_auth\n password: password_basic_auth\n read_timeout: 500",
+ path: path.Join(testRoot, "gitlab-shell.log"),
+ format: "text",
+ secret: "default-secret-content",
+ httpSettings: HttpSettingsConfig{User: "user_basic_auth", Password: "password_basic_auth", ReadTimeoutSeconds: 500},
+ },
+ {
+ yaml: "http_settings:\n ca_file: /etc/ssl/cert.pem\n ca_path: /etc/pki/tls/certs\n self_signed_cert: true",
+ path: path.Join(testRoot, "gitlab-shell.log"),
+ format: "text",
+ secret: "default-secret-content",
+ httpSettings: HttpSettingsConfig{CaFile: "/etc/ssl/cert.pem", CaPath: "/etc/pki/tls/certs", SelfSignedCert: true},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("yaml input: %q", tc.yaml), func(t *testing.T) {
+ cfg := Config{RootDir: testRoot}
+
+ err := parseConfig([]byte(tc.yaml), &cfg)
+ require.NoError(t, err)
+
+ assert.Equal(t, tc.path, cfg.LogFile)
+ assert.Equal(t, tc.format, cfg.LogFormat)
+ assert.Equal(t, tc.gitlabUrl, cfg.GitlabUrl)
+ assert.Equal(t, tc.secret, cfg.Secret)
+ assert.Equal(t, tc.httpSettings, cfg.HttpSettings)
+ })
+ }
+}
diff --git a/internal/config/httpclient.go b/internal/config/httpclient.go
new file mode 100644
index 0000000..c71efad
--- /dev/null
+++ b/internal/config/httpclient.go
@@ -0,0 +1,122 @@
+package config
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "path/filepath"
+ "strings"
+ "time"
+)
+
+const (
+ socketBaseUrl = "http://unix"
+ unixSocketProtocol = "http+unix://"
+ httpProtocol = "http://"
+ httpsProtocol = "https://"
+ defaultReadTimeoutSeconds = 300
+)
+
+type HttpClient struct {
+ HttpClient *http.Client
+ Host string
+}
+
+func (c *Config) GetHttpClient() *HttpClient {
+ if c.HttpClient != nil {
+ return c.HttpClient
+ }
+
+ var transport *http.Transport
+ var host string
+ if strings.HasPrefix(c.GitlabUrl, unixSocketProtocol) {
+ transport, host = c.buildSocketTransport()
+ } else if strings.HasPrefix(c.GitlabUrl, httpProtocol) {
+ transport, host = c.buildHttpTransport()
+ } else if strings.HasPrefix(c.GitlabUrl, httpsProtocol) {
+ transport, host = c.buildHttpsTransport()
+ } else {
+ return nil
+ }
+
+ httpClient := &http.Client{
+ Transport: transport,
+ Timeout: c.readTimeout(),
+ }
+
+ client := &HttpClient{HttpClient: httpClient, Host: host}
+
+ c.HttpClient = client
+
+ return client
+}
+
+func (c *Config) buildSocketTransport() (*http.Transport, string) {
+ socketPath := strings.TrimPrefix(c.GitlabUrl, unixSocketProtocol)
+ transport := &http.Transport{
+ DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
+ dialer := net.Dialer{}
+ return dialer.DialContext(ctx, "unix", socketPath)
+ },
+ }
+
+ return transport, socketBaseUrl
+}
+
+func (c *Config) buildHttpsTransport() (*http.Transport, string) {
+ certPool, err := x509.SystemCertPool()
+
+ if err != nil {
+ certPool = x509.NewCertPool()
+ }
+
+ caFile := c.HttpSettings.CaFile
+ if caFile != "" {
+ addCertToPool(certPool, caFile)
+ }
+
+ caPath := c.HttpSettings.CaPath
+ if caPath != "" {
+ fis, _ := ioutil.ReadDir(caPath)
+ for _, fi := range fis {
+ if fi.IsDir() {
+ continue
+ }
+
+ addCertToPool(certPool, filepath.Join(caPath, fi.Name()))
+ }
+ }
+
+ transport := &http.Transport{
+ TLSClientConfig: &tls.Config{
+ RootCAs: certPool,
+ InsecureSkipVerify: c.HttpSettings.SelfSignedCert,
+ },
+ }
+
+ return transport, c.GitlabUrl
+}
+
+func addCertToPool(certPool *x509.CertPool, fileName string) {
+ cert, err := ioutil.ReadFile(fileName)
+ if err == nil {
+ certPool.AppendCertsFromPEM(cert)
+ }
+}
+
+func (c *Config) buildHttpTransport() (*http.Transport, string) {
+ return &http.Transport{}, c.GitlabUrl
+}
+
+func (c *Config) readTimeout() time.Duration {
+ timeoutSeconds := c.HttpSettings.ReadTimeoutSeconds
+
+ if timeoutSeconds == 0 {
+ timeoutSeconds = defaultReadTimeoutSeconds
+ }
+
+ return time.Duration(timeoutSeconds) * time.Second
+}
diff --git a/internal/config/httpclient_test.go b/internal/config/httpclient_test.go
new file mode 100644
index 0000000..474deba
--- /dev/null
+++ b/internal/config/httpclient_test.go
@@ -0,0 +1,22 @@
+package config
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestReadTimeout(t *testing.T) {
+ expectedSeconds := uint64(300)
+
+ config := &Config{
+ GitlabUrl: "http://localhost:3000",
+ HttpSettings: HttpSettingsConfig{ReadTimeoutSeconds: expectedSeconds},
+ }
+ client := config.GetHttpClient()
+
+ require.NotNil(t, client)
+ assert.Equal(t, time.Duration(expectedSeconds)*time.Second, client.HttpClient.Timeout)
+}
diff --git a/internal/executable/executable.go b/internal/executable/executable.go
new file mode 100644
index 0000000..c6355b9
--- /dev/null
+++ b/internal/executable/executable.go
@@ -0,0 +1,60 @@
+package executable
+
+import (
+ "os"
+ "path/filepath"
+)
+
+const (
+ BinDir = "bin"
+ Healthcheck = "check"
+ GitlabShell = "gitlab-shell"
+ AuthorizedKeysCheck = "gitlab-shell-authorized-keys-check"
+ AuthorizedPrincipalsCheck = "gitlab-shell-authorized-principals-check"
+)
+
+type Executable struct {
+ Name string
+ RootDir string
+}
+
+var (
+ // osExecutable is overridden in tests
+ osExecutable = os.Executable
+)
+
+func New(name string) (*Executable, error) {
+ path, err := osExecutable()
+ if err != nil {
+ return nil, err
+ }
+
+ rootDir, err := findRootDir(path)
+ if err != nil {
+ return nil, err
+ }
+
+ executable := &Executable{
+ Name: name,
+ RootDir: rootDir,
+ }
+
+ return executable, nil
+}
+
+func findRootDir(path string) (string, error) {
+ // Start: /opt/.../gitlab-shell/bin/gitlab-shell
+ // Ends: /opt/.../gitlab-shell
+ rootDir := filepath.Dir(filepath.Dir(path))
+ pathFromEnv := os.Getenv("GITLAB_SHELL_DIR")
+
+ if pathFromEnv != "" {
+ if _, err := os.Stat(pathFromEnv); os.IsNotExist(err) {
+ return "", err
+ }
+
+ rootDir = pathFromEnv
+ }
+
+ return rootDir, nil
+}
diff --git a/internal/executable/executable_test.go b/internal/executable/executable_test.go
new file mode 100644
index 0000000..3915f1a
--- /dev/null
+++ b/internal/executable/executable_test.go
@@ -0,0 +1,104 @@
+package executable
+
+import (
+ "errors"
+ "testing"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+
+ "github.com/stretchr/testify/require"
+)
+
+type fakeOs struct {
+ OldExecutable func() (string, error)
+ Path string
+ Error error
+}
+
+func (f *fakeOs) Executable() (string, error) {
+ return f.Path, f.Error
+}
+
+func (f *fakeOs) Setup() {
+ f.OldExecutable = osExecutable
+ osExecutable = f.Executable
+}
+
+func (f *fakeOs) Cleanup() {
+ osExecutable = f.OldExecutable
+}
+
+func TestNewSuccess(t *testing.T) {
+ testCases := []struct {
+ desc string
+ fakeOs *fakeOs
+ environment map[string]string
+ expectedRootDir string
+ }{
+ {
+ desc: "GITLAB_SHELL_DIR env var is not defined",
+ fakeOs: &fakeOs{Path: "/tmp/bin/gitlab-shell"},
+ expectedRootDir: "/tmp",
+ },
+ {
+ desc: "GITLAB_SHELL_DIR env var is defined",
+ fakeOs: &fakeOs{Path: "/opt/bin/gitlab-shell"},
+ environment: map[string]string{
+ "GITLAB_SHELL_DIR": "/tmp",
+ },
+ expectedRootDir: "/tmp",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ restoreEnv := testhelper.TempEnv(tc.environment)
+ defer restoreEnv()
+
+ fake := tc.fakeOs
+ fake.Setup()
+ defer fake.Cleanup()
+
+ result, err := New("gitlab-shell")
+
+ require.NoError(t, err)
+ require.Equal(t, result.Name, "gitlab-shell")
+ require.Equal(t, result.RootDir, tc.expectedRootDir)
+ })
+ }
+}
+
+func TestNewFailure(t *testing.T) {
+ testCases := []struct {
+ desc string
+ fakeOs *fakeOs
+ environment map[string]string
+ }{
+ {
+ desc: "failed to determine executable",
+ fakeOs: &fakeOs{Path: "", Error: errors.New("error")},
+ },
+ {
+ desc: "GITLAB_SHELL_DIR doesn't exist",
+ fakeOs: &fakeOs{Path: "/tmp/bin/gitlab-shell"},
+ environment: map[string]string{
+ "GITLAB_SHELL_DIR": "/tmp/non/existing/directory",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ restoreEnv := testhelper.TempEnv(tc.environment)
+ defer restoreEnv()
+
+ fake := tc.fakeOs
+ fake.Setup()
+ defer fake.Cleanup()
+
+ _, err := New("gitlab-shell")
+
+ require.Error(t, err)
+ })
+ }
+}
diff --git a/internal/gitlabnet/accessverifier/client.go b/internal/gitlabnet/accessverifier/client.go
new file mode 100644
index 0000000..217dcdb
--- /dev/null
+++ b/internal/gitlabnet/accessverifier/client.go
@@ -0,0 +1,115 @@
+package accessverifier
+
+import (
+ "fmt"
+ "net/http"
+
+ pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
+)
+
+const (
+ protocol = "ssh"
+ anyChanges = "_any"
+)
+
+type Client struct {
+ client *gitlabnet.GitlabClient
+}
+
+type Request struct {
+ Action commandargs.CommandType `json:"action"`
+ Repo string `json:"project"`
+ Changes string `json:"changes"`
+ Protocol string `json:"protocol"`
+ KeyId string `json:"key_id,omitempty"`
+ Username string `json:"username,omitempty"`
+ CheckIp string `json:"check_ip,omitempty"`
+}
+
+type Gitaly struct {
+ Repo pb.Repository `json:"repository"`
+ Address string `json:"address"`
+ Token string `json:"token"`
+}
+
+type CustomPayloadData struct {
+ ApiEndpoints []string `json:"api_endpoints"`
+ Username string `json:"gl_username"`
+ PrimaryRepo string `json:"primary_repo"`
+ InfoMessage string `json:"info_message"`
+ UserId string `json:"gl_id,omitempty"`
+}
+
+type CustomPayload struct {
+ Action string `json:"action"`
+ Data CustomPayloadData `json:"data"`
+}
+
+type Response struct {
+ Success bool `json:"status"`
+ Message string `json:"message"`
+ Repo string `json:"gl_repository"`
+ UserId string `json:"gl_id"`
+ Username string `json:"gl_username"`
+ GitConfigOptions []string `json:"git_config_options"`
+ Gitaly Gitaly `json:"gitaly"`
+ GitProtocol string `json:"git_protocol"`
+ Payload CustomPayload `json:"payload"`
+ ConsoleMessages []string `json:"gl_console_messages"`
+ Who string
+ StatusCode int
+}
+
+func NewClient(config *config.Config) (*Client, error) {
+ client, err := gitlabnet.GetClient(config)
+ if err != nil {
+ return nil, fmt.Errorf("Error creating http client: %v", err)
+ }
+
+ return &Client{client: client}, nil
+}
+
+func (c *Client) Verify(args *commandargs.Shell, action commandargs.CommandType, repo string) (*Response, error) {
+ request := &Request{Action: action, Repo: repo, Protocol: protocol, Changes: anyChanges}
+
+ if args.GitlabUsername != "" {
+ request.Username = args.GitlabUsername
+ } else {
+ request.KeyId = args.GitlabKeyId
+ }
+
+ request.CheckIp = sshenv.LocalAddr()
+
+ response, err := c.client.Post("/allowed", request)
+ if err != nil {
+ return nil, err
+ }
+ defer response.Body.Close()
+
+ return parse(response, args)
+}
+
+func parse(hr *http.Response, args *commandargs.Shell) (*Response, error) {
+ response := &Response{}
+ if err := gitlabnet.ParseJSON(hr, response); err != nil {
+ return nil, err
+ }
+
+ if args.GitlabKeyId != "" {
+ response.Who = "key-" + args.GitlabKeyId
+ } else {
+ response.Who = response.UserId
+ }
+
+ response.StatusCode = hr.StatusCode
+
+ return response, nil
+}
+
+func (r *Response) IsCustomAction() bool {
+ return r.StatusCode == http.StatusMultipleChoices
+}
diff --git a/internal/gitlabnet/accessverifier/client_test.go b/internal/gitlabnet/accessverifier/client_test.go
new file mode 100644
index 0000000..96c80a7
--- /dev/null
+++ b/internal/gitlabnet/accessverifier/client_test.go
@@ -0,0 +1,209 @@
+package accessverifier
+
+import (
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "path"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+var (
+ repo = "group/private"
+ action = commandargs.ReceivePack
+)
+
+func buildExpectedResponse(who string) *Response {
+ response := &Response{
+ Success: true,
+ UserId: "user-1",
+ Repo: "project-26",
+ Username: "root",
+ GitConfigOptions: []string{"option"},
+ Gitaly: Gitaly{
+ Repo: pb.Repository{
+ StorageName: "default",
+ RelativePath: "@hashed/5f/9c/5f9c4ab08cac7457e9111a30e4664920607ea2c115a1433d7be98e97e64244ca.git",
+ GitObjectDirectory: "path/to/git_object_directory",
+ GitAlternateObjectDirectories: []string{"path/to/git_alternate_object_directory"},
+ GlRepository: "project-26",
+ GlProjectPath: repo,
+ },
+ Address: "unix:gitaly.socket",
+ Token: "token",
+ },
+ GitProtocol: "protocol",
+ Payload: CustomPayload{},
+ ConsoleMessages: []string{"console", "message"},
+ Who: who,
+ StatusCode: 200,
+ }
+
+ return response
+}
+
+func TestSuccessfulResponses(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ args *commandargs.Shell
+ who string
+ }{
+ {
+ desc: "Provide key id within the request",
+ args: &commandargs.Shell{GitlabKeyId: "1"},
+ who: "key-1",
+ }, {
+ desc: "Provide username within the request",
+ args: &commandargs.Shell{GitlabUsername: "first"},
+ who: "user-1",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ result, err := client.Verify(tc.args, action, repo)
+ require.NoError(t, err)
+
+ response := buildExpectedResponse(tc.who)
+ require.Equal(t, response, result)
+ })
+ }
+}
+
+func TestGetCustomAction(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ args := &commandargs.Shell{GitlabUsername: "custom"}
+ result, err := client.Verify(args, action, repo)
+ require.NoError(t, err)
+
+ response := buildExpectedResponse("user-1")
+ response.Payload = CustomPayload{
+ Action: "geo_proxy_to_primary",
+ Data: CustomPayloadData{
+ ApiEndpoints: []string{"geo/proxy_git_push_ssh/info_refs", "geo/proxy_git_push_ssh/push"},
+ Username: "custom",
+ PrimaryRepo: "https://repo/path",
+ InfoMessage: "message",
+ },
+ }
+ response.StatusCode = 300
+
+ require.True(t, response.IsCustomAction())
+ require.Equal(t, response, result)
+}
+
+func TestErrorResponses(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ fakeId string
+ expectedError string
+ }{
+ {
+ desc: "A response with an error message",
+ fakeId: "2",
+ expectedError: "Not allowed!",
+ },
+ {
+ desc: "A response with bad JSON",
+ fakeId: "3",
+ expectedError: "Parsing failed",
+ },
+ {
+ desc: "An error response without message",
+ fakeId: "4",
+ expectedError: "Internal API error (403)",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ args := &commandargs.Shell{GitlabKeyId: tc.fakeId}
+ resp, err := client.Verify(args, action, repo)
+
+ require.EqualError(t, err, tc.expectedError)
+ require.Nil(t, resp)
+ })
+ }
+}
+
+func setup(t *testing.T) (*Client, func()) {
+ testDirCleanup, err := testhelper.PrepareTestRootDir()
+ require.NoError(t, err)
+ defer testDirCleanup()
+
+ body, err := ioutil.ReadFile(path.Join(testhelper.TestRoot, "responses/allowed.json"))
+ require.NoError(t, err)
+
+ allowedWithPayloadPath := path.Join(testhelper.TestRoot, "responses/allowed_with_payload.json")
+ bodyWithPayload, err := ioutil.ReadFile(allowedWithPayloadPath)
+ require.NoError(t, err)
+
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/allowed",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ require.NoError(t, err)
+
+ var requestBody *Request
+ require.NoError(t, json.Unmarshal(b, &requestBody))
+
+ switch requestBody.Username {
+ case "first":
+ _, err = w.Write(body)
+ require.NoError(t, err)
+ case "second":
+ errBody := map[string]interface{}{
+ "status": false,
+ "message": "missing user",
+ }
+ require.NoError(t, json.NewEncoder(w).Encode(errBody))
+ case "custom":
+ w.WriteHeader(http.StatusMultipleChoices)
+ _, err = w.Write(bodyWithPayload)
+ require.NoError(t, err)
+ }
+
+ switch requestBody.KeyId {
+ case "1":
+ _, err = w.Write(body)
+ require.NoError(t, err)
+ case "2":
+ w.WriteHeader(http.StatusForbidden)
+ errBody := &gitlabnet.ErrorResponse{
+ Message: "Not allowed!",
+ }
+ require.NoError(t, json.NewEncoder(w).Encode(errBody))
+ case "3":
+ w.Write([]byte("{ \"message\": \"broken json!\""))
+ case "4":
+ w.WriteHeader(http.StatusForbidden)
+ }
+ },
+ },
+ }
+
+ url, cleanup := testserver.StartSocketHttpServer(t, requests)
+
+ client, err := NewClient(&config.Config{GitlabUrl: url})
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/internal/gitlabnet/authorizedkeys/client.go b/internal/gitlabnet/authorizedkeys/client.go
new file mode 100644
index 0000000..ac23a96
--- /dev/null
+++ b/internal/gitlabnet/authorizedkeys/client.go
@@ -0,0 +1,65 @@
+package authorizedkeys
+
+import (
+ "fmt"
+ "net/url"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
+)
+
+const (
+ AuthorizedKeysPath = "/authorized_keys"
+)
+
+type Client struct {
+ config *config.Config
+ client *gitlabnet.GitlabClient
+}
+
+type Response struct {
+ Id int64 `json:"id"`
+ Key string `json:"key"`
+}
+
+func NewClient(config *config.Config) (*Client, error) {
+ client, err := gitlabnet.GetClient(config)
+ if err != nil {
+ return nil, fmt.Errorf("Error creating http client: %v", err)
+ }
+
+ return &Client{config: config, client: client}, nil
+}
+
+func (c *Client) GetByKey(key string) (*Response, error) {
+ path, err := pathWithKey(key)
+ if err != nil {
+ return nil, err
+ }
+
+ response, err := c.client.Get(path)
+ if err != nil {
+ return nil, err
+ }
+ defer response.Body.Close()
+
+ parsedResponse := &Response{}
+ if err := gitlabnet.ParseJSON(response, parsedResponse); err != nil {
+ return nil, err
+ }
+
+ return parsedResponse, nil
+}
+
+func pathWithKey(key string) (string, error) {
+ u, err := url.Parse(AuthorizedKeysPath)
+ if err != nil {
+ return "", err
+ }
+
+ params := u.Query()
+ params.Set("key", key)
+ u.RawQuery = params.Encode()
+
+ return u.String(), nil
+}
diff --git a/internal/gitlabnet/authorizedkeys/client_test.go b/internal/gitlabnet/authorizedkeys/client_test.go
new file mode 100644
index 0000000..965025f
--- /dev/null
+++ b/internal/gitlabnet/authorizedkeys/client_test.go
@@ -0,0 +1,105 @@
+package authorizedkeys
+
+import (
+ "encoding/json"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+)
+
+var (
+ requests []testserver.TestRequestHandler
+)
+
+func init() {
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/authorized_keys",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("key") == "key" {
+ body := &Response{
+ Id: 1,
+ Key: "public-key",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("key") == "broken-message" {
+ w.WriteHeader(http.StatusForbidden)
+ body := &gitlabnet.ErrorResponse{
+ Message: "Not allowed!",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("key") == "broken-json" {
+ w.Write([]byte("{ \"message\": \"broken json!\""))
+ } else if r.URL.Query().Get("key") == "broken-empty" {
+ w.WriteHeader(http.StatusForbidden)
+ } else {
+ w.WriteHeader(http.StatusNotFound)
+ }
+ },
+ },
+ }
+}
+
+func TestGetByKey(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ result, err := client.GetByKey("key")
+ require.NoError(t, err)
+ require.Equal(t, &Response{Id: 1, Key: "public-key"}, result)
+}
+
+func TestGetByKeyErrorResponses(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ key string
+ expectedError string
+ }{
+ {
+ desc: "A response with an error message",
+ key: "broken-message",
+ expectedError: "Not allowed!",
+ },
+ {
+ desc: "A response with bad JSON",
+ key: "broken-json",
+ expectedError: "Parsing failed",
+ },
+ {
+ desc: "A forbidden (403) response without message",
+ key: "broken-empty",
+ expectedError: "Internal API error (403)",
+ },
+ {
+ desc: "A not found (404) response without message",
+ key: "not-found",
+ expectedError: "Internal API error (404)",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ resp, err := client.GetByKey(tc.key)
+
+ require.EqualError(t, err, tc.expectedError)
+ require.Nil(t, resp)
+ })
+ }
+}
+
+func setup(t *testing.T) (*Client, func()) {
+ url, cleanup := testserver.StartSocketHttpServer(t, requests)
+
+ client, err := NewClient(&config.Config{GitlabUrl: url})
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/internal/gitlabnet/client.go b/internal/gitlabnet/client.go
new file mode 100644
index 0000000..bb8655a
--- /dev/null
+++ b/internal/gitlabnet/client.go
@@ -0,0 +1,132 @@
+package gitlabnet
+
+import (
+ "bytes"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+const (
+ internalApiPath = "/api/v4/internal"
+ secretHeaderName = "Gitlab-Shared-Secret"
+)
+
+var (
+ ParsingError = fmt.Errorf("Parsing failed")
+)
+
+type ErrorResponse struct {
+ Message string `json:"message"`
+}
+
+type GitlabClient struct {
+ httpClient *http.Client
+ config *config.Config
+ host string
+}
+
+func GetClient(config *config.Config) (*GitlabClient, error) {
+ client := config.GetHttpClient()
+
+ if client == nil {
+ return nil, fmt.Errorf("Unsupported protocol")
+ }
+
+ return &GitlabClient{httpClient: client.HttpClient, config: config, host: client.Host}, nil
+}
+
+func normalizePath(path string) string {
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+
+ if !strings.HasPrefix(path, internalApiPath) {
+ path = internalApiPath + path
+ }
+ return path
+}
+
+func newRequest(method, host, path string, data interface{}) (*http.Request, error) {
+ var jsonReader io.Reader
+ if data != nil {
+ jsonData, err := json.Marshal(data)
+ if err != nil {
+ return nil, err
+ }
+
+ jsonReader = bytes.NewReader(jsonData)
+ }
+
+ request, err := http.NewRequest(method, host+path, jsonReader)
+ if err != nil {
+ return nil, err
+ }
+
+ return request, nil
+}
+
+func parseError(resp *http.Response) error {
+ if resp.StatusCode >= 200 && resp.StatusCode <= 399 {
+ return nil
+ }
+ defer resp.Body.Close()
+ parsedResponse := &ErrorResponse{}
+
+ if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
+ return fmt.Errorf("Internal API error (%v)", resp.StatusCode)
+ } else {
+ return fmt.Errorf(parsedResponse.Message)
+ }
+
+}
+
+func (c *GitlabClient) Get(path string) (*http.Response, error) {
+ return c.DoRequest(http.MethodGet, normalizePath(path), nil)
+}
+
+func (c *GitlabClient) Post(path string, data interface{}) (*http.Response, error) {
+ return c.DoRequest(http.MethodPost, normalizePath(path), data)
+}
+
+func (c *GitlabClient) DoRequest(method, path string, data interface{}) (*http.Response, error) {
+ request, err := newRequest(method, c.host, path, data)
+ if err != nil {
+ return nil, err
+ }
+
+ user, password := c.config.HttpSettings.User, c.config.HttpSettings.Password
+ if user != "" && password != "" {
+ request.SetBasicAuth(user, password)
+ }
+
+ encodedSecret := base64.StdEncoding.EncodeToString([]byte(c.config.Secret))
+ request.Header.Set(secretHeaderName, encodedSecret)
+
+ request.Header.Add("Content-Type", "application/json")
+ request.Close = true
+
+ response, err := c.httpClient.Do(request)
+ if err != nil {
+ return nil, fmt.Errorf("Internal API unreachable")
+ }
+
+ if err := parseError(response); err != nil {
+ return nil, err
+ }
+
+ return response, nil
+}
+
+func ParseJSON(hr *http.Response, response interface{}) error {
+ if err := json.NewDecoder(hr.Body).Decode(response); err != nil {
+ return ParsingError
+ }
+
+ return nil
+}
diff --git a/internal/gitlabnet/client_test.go b/internal/gitlabnet/client_test.go
new file mode 100644
index 0000000..3f96b41
--- /dev/null
+++ b/internal/gitlabnet/client_test.go
@@ -0,0 +1,219 @@
+package gitlabnet
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "path"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+func TestClients(t *testing.T) {
+ testDirCleanup, err := testhelper.PrepareTestRootDir()
+ require.NoError(t, err)
+ defer testDirCleanup()
+
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/hello",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+
+ fmt.Fprint(w, "Hello")
+ },
+ },
+ {
+ Path: "/api/v4/internal/post_endpoint",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodPost, r.Method)
+
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+
+ require.NoError(t, err)
+
+ fmt.Fprint(w, "Echo: "+string(b))
+ },
+ },
+ {
+ Path: "/api/v4/internal/auth",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, r.Header.Get(secretHeaderName))
+ },
+ },
+ {
+ Path: "/api/v4/internal/error",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusBadRequest)
+ body := map[string]string{
+ "message": "Don't do that",
+ }
+ json.NewEncoder(w).Encode(body)
+ },
+ },
+ {
+ Path: "/api/v4/internal/broken",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ panic("Broken")
+ },
+ },
+ }
+
+ testCases := []struct {
+ desc string
+ config *config.Config
+ server func(*testing.T, []testserver.TestRequestHandler) (string, func())
+ }{
+ {
+ desc: "Socket client",
+ config: &config.Config{},
+ server: testserver.StartSocketHttpServer,
+ },
+ {
+ desc: "Http client",
+ config: &config.Config{},
+ server: testserver.StartHttpServer,
+ },
+ {
+ desc: "Https client",
+ config: &config.Config{
+ HttpSettings: config.HttpSettingsConfig{CaFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt")},
+ },
+ server: testserver.StartHttpsServer,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ url, cleanup := tc.server(t, requests)
+ defer cleanup()
+
+ tc.config.GitlabUrl = url
+ tc.config.Secret = "sssh, it's a secret"
+
+ client, err := GetClient(tc.config)
+ require.NoError(t, err)
+
+ testBrokenRequest(t, client)
+ testSuccessfulGet(t, client)
+ testSuccessfulPost(t, client)
+ testMissing(t, client)
+ testErrorMessage(t, client)
+ testAuthenticationHeader(t, client)
+ })
+ }
+}
+
+func testSuccessfulGet(t *testing.T, client *GitlabClient) {
+ t.Run("Successful get", func(t *testing.T) {
+ response, err := client.Get("/hello")
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ defer response.Body.Close()
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, string(responseBody), "Hello")
+ })
+}
+
+func testSuccessfulPost(t *testing.T, client *GitlabClient) {
+ t.Run("Successful Post", func(t *testing.T) {
+ data := map[string]string{"key": "value"}
+
+ response, err := client.Post("/post_endpoint", data)
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ defer response.Body.Close()
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, "Echo: {\"key\":\"value\"}", string(responseBody))
+ })
+}
+
+func testMissing(t *testing.T, client *GitlabClient) {
+ t.Run("Missing error for GET", func(t *testing.T) {
+ response, err := client.Get("/missing")
+ assert.EqualError(t, err, "Internal API error (404)")
+ assert.Nil(t, response)
+ })
+
+ t.Run("Missing error for POST", func(t *testing.T) {
+ response, err := client.Post("/missing", map[string]string{})
+ assert.EqualError(t, err, "Internal API error (404)")
+ assert.Nil(t, response)
+ })
+}
+
+func testErrorMessage(t *testing.T, client *GitlabClient) {
+ t.Run("Error with message for GET", func(t *testing.T) {
+ response, err := client.Get("/error")
+ assert.EqualError(t, err, "Don't do that")
+ assert.Nil(t, response)
+ })
+
+ t.Run("Error with message for POST", func(t *testing.T) {
+ response, err := client.Post("/error", map[string]string{})
+ assert.EqualError(t, err, "Don't do that")
+ assert.Nil(t, response)
+ })
+}
+
+func testBrokenRequest(t *testing.T, client *GitlabClient) {
+ t.Run("Broken request for GET", func(t *testing.T) {
+ response, err := client.Get("/broken")
+ assert.EqualError(t, err, "Internal API unreachable")
+ assert.Nil(t, response)
+ })
+
+ t.Run("Broken request for POST", func(t *testing.T) {
+ response, err := client.Post("/broken", map[string]string{})
+ assert.EqualError(t, err, "Internal API unreachable")
+ assert.Nil(t, response)
+ })
+}
+
+func testAuthenticationHeader(t *testing.T, client *GitlabClient) {
+ t.Run("Authentication headers for GET", func(t *testing.T) {
+ response, err := client.Get("/auth")
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ defer response.Body.Close()
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ require.NoError(t, err)
+
+ header, err := base64.StdEncoding.DecodeString(string(responseBody))
+ require.NoError(t, err)
+ assert.Equal(t, "sssh, it's a secret", string(header))
+ })
+
+ t.Run("Authentication headers for POST", func(t *testing.T) {
+ response, err := client.Post("/auth", map[string]string{})
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ defer response.Body.Close()
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ require.NoError(t, err)
+
+ header, err := base64.StdEncoding.DecodeString(string(responseBody))
+ require.NoError(t, err)
+ assert.Equal(t, "sssh, it's a secret", string(header))
+ })
+}
diff --git a/internal/gitlabnet/discover/client.go b/internal/gitlabnet/discover/client.go
new file mode 100644
index 0000000..3faef53
--- /dev/null
+++ b/internal/gitlabnet/discover/client.go
@@ -0,0 +1,71 @@
+package discover
+
+import (
+ "fmt"
+ "net/http"
+ "net/url"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
+)
+
+type Client struct {
+ config *config.Config
+ client *gitlabnet.GitlabClient
+}
+
+type Response struct {
+ UserId int64 `json:"id"`
+ Name string `json:"name"`
+ Username string `json:"username"`
+}
+
+func NewClient(config *config.Config) (*Client, error) {
+ client, err := gitlabnet.GetClient(config)
+ if err != nil {
+ return nil, fmt.Errorf("Error creating http client: %v", err)
+ }
+
+ return &Client{config: config, client: client}, nil
+}
+
+func (c *Client) GetByCommandArgs(args *commandargs.Shell) (*Response, error) {
+ params := url.Values{}
+ if args.GitlabUsername != "" {
+ params.Add("username", args.GitlabUsername)
+ } else if args.GitlabKeyId != "" {
+ params.Add("key_id", args.GitlabKeyId)
+ } else {
+ // There was no 'who' information, this matches the ruby error
+ // message.
+ return nil, fmt.Errorf("who='' is invalid")
+ }
+
+ return c.getResponse(params)
+}
+
+func (c *Client) getResponse(params url.Values) (*Response, error) {
+ path := "/discover?" + params.Encode()
+
+ response, err := c.client.Get(path)
+ if err != nil {
+ return nil, err
+ }
+ defer response.Body.Close()
+
+ return parse(response)
+}
+
+func parse(hr *http.Response) (*Response, error) {
+ response := &Response{}
+ if err := gitlabnet.ParseJSON(hr, response); err != nil {
+ return nil, err
+ }
+
+ return response, nil
+}
+
+func (r *Response) IsAnonymous() bool {
+ return r.UserId < 1
+}
diff --git a/internal/gitlabnet/discover/client_test.go b/internal/gitlabnet/discover/client_test.go
new file mode 100644
index 0000000..66e234b
--- /dev/null
+++ b/internal/gitlabnet/discover/client_test.go
@@ -0,0 +1,137 @@
+package discover
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "testing"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ requests []testserver.TestRequestHandler
+)
+
+func init() {
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("key_id") == "1" {
+ body := &Response{
+ UserId: 2,
+ Username: "alex-doe",
+ Name: "Alex Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "jane-doe" {
+ body := &Response{
+ UserId: 1,
+ Username: "jane-doe",
+ Name: "Jane Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "broken_message" {
+ w.WriteHeader(http.StatusForbidden)
+ body := &gitlabnet.ErrorResponse{
+ Message: "Not allowed!",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "broken_json" {
+ w.Write([]byte("{ \"message\": \"broken json!\""))
+ } else if r.URL.Query().Get("username") == "broken_empty" {
+ w.WriteHeader(http.StatusForbidden)
+ } else {
+ fmt.Fprint(w, "null")
+ }
+ },
+ },
+ }
+}
+
+func TestGetByKeyId(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ params := url.Values{}
+ params.Add("key_id", "1")
+ result, err := client.getResponse(params)
+ assert.NoError(t, err)
+ assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result)
+}
+
+func TestGetByUsername(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ params := url.Values{}
+ params.Add("username", "jane-doe")
+ result, err := client.getResponse(params)
+ assert.NoError(t, err)
+ assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result)
+}
+
+func TestMissingUser(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ params := url.Values{}
+ params.Add("username", "missing")
+ result, err := client.getResponse(params)
+ assert.NoError(t, err)
+ assert.True(t, result.IsAnonymous())
+}
+
+func TestErrorResponses(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ fakeUsername string
+ expectedError string
+ }{
+ {
+ desc: "A response with an error message",
+ fakeUsername: "broken_message",
+ expectedError: "Not allowed!",
+ },
+ {
+ desc: "A response with bad JSON",
+ fakeUsername: "broken_json",
+ expectedError: "Parsing failed",
+ },
+ {
+ desc: "An error response without message",
+ fakeUsername: "broken_empty",
+ expectedError: "Internal API error (403)",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ params := url.Values{}
+ params.Add("username", tc.fakeUsername)
+ resp, err := client.getResponse(params)
+
+ assert.EqualError(t, err, tc.expectedError)
+ assert.Nil(t, resp)
+ })
+ }
+}
+
+func setup(t *testing.T) (*Client, func()) {
+ url, cleanup := testserver.StartSocketHttpServer(t, requests)
+
+ client, err := NewClient(&config.Config{GitlabUrl: url})
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/internal/gitlabnet/healthcheck/client.go b/internal/gitlabnet/healthcheck/client.go
new file mode 100644
index 0000000..7db682a
--- /dev/null
+++ b/internal/gitlabnet/healthcheck/client.go
@@ -0,0 +1,54 @@
+package healthcheck
+
+import (
+ "fmt"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
+)
+
+const (
+ checkPath = "/check"
+)
+
+type Client struct {
+ config *config.Config
+ client *gitlabnet.GitlabClient
+}
+
+type Response struct {
+ APIVersion string `json:"api_version"`
+ GitlabVersion string `json:"gitlab_version"`
+ GitlabRevision string `json:"gitlab_rev"`
+ Redis bool `json:"redis"`
+}
+
+func NewClient(config *config.Config) (*Client, error) {
+ client, err := gitlabnet.GetClient(config)
+ if err != nil {
+ return nil, fmt.Errorf("Error creating http client: %v", err)
+ }
+
+ return &Client{config: config, client: client}, nil
+}
+
+func (c *Client) Check() (*Response, error) {
+ resp, err := c.client.Get(checkPath)
+ if err != nil {
+ return nil, err
+ }
+
+ defer resp.Body.Close()
+
+ return parse(resp)
+}
+
+func parse(hr *http.Response) (*Response, error) {
+ response := &Response{}
+ if err := gitlabnet.ParseJSON(hr, response); err != nil {
+ return nil, err
+ }
+
+ return response, nil
+}
diff --git a/internal/gitlabnet/healthcheck/client_test.go b/internal/gitlabnet/healthcheck/client_test.go
new file mode 100644
index 0000000..d7212b0
--- /dev/null
+++ b/internal/gitlabnet/healthcheck/client_test.go
@@ -0,0 +1,48 @@
+package healthcheck
+
+import (
+ "encoding/json"
+ "net/http"
+ "testing"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/check",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ json.NewEncoder(w).Encode(testResponse)
+ },
+ },
+ }
+
+ testResponse = &Response{
+ APIVersion: "v4",
+ GitlabVersion: "v12.0.0-ee",
+ GitlabRevision: "3b13818e8330f68625d80d9bf5d8049c41fbe197",
+ Redis: true,
+ }
+)
+
+func TestCheck(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ result, err := client.Check()
+ require.NoError(t, err)
+ require.Equal(t, testResponse, result)
+}
+
+func setup(t *testing.T) (*Client, func()) {
+ url, cleanup := testserver.StartSocketHttpServer(t, requests)
+
+ client, err := NewClient(&config.Config{GitlabUrl: url})
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/internal/gitlabnet/httpclient_test.go b/internal/gitlabnet/httpclient_test.go
new file mode 100644
index 0000000..a40ab6d
--- /dev/null
+++ b/internal/gitlabnet/httpclient_test.go
@@ -0,0 +1,96 @@
+package gitlabnet
+
+import (
+ "encoding/base64"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+)
+
+const (
+ username = "basic_auth_user"
+ password = "basic_auth_password"
+)
+
+func TestBasicAuthSettings(t *testing.T) {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/get_endpoint",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+
+ fmt.Fprint(w, r.Header.Get("Authorization"))
+ },
+ },
+ {
+ Path: "/api/v4/internal/post_endpoint",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodPost, r.Method)
+
+ fmt.Fprint(w, r.Header.Get("Authorization"))
+ },
+ },
+ }
+ config := &config.Config{HttpSettings: config.HttpSettingsConfig{User: username, Password: password}}
+
+ client, cleanup := setup(t, config, requests)
+ defer cleanup()
+
+ response, err := client.Get("/get_endpoint")
+ require.NoError(t, err)
+ testBasicAuthHeaders(t, response)
+
+ response, err = client.Post("/post_endpoint", nil)
+ require.NoError(t, err)
+ testBasicAuthHeaders(t, response)
+}
+
+func testBasicAuthHeaders(t *testing.T, response *http.Response) {
+ defer response.Body.Close()
+
+ require.NotNil(t, response)
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+
+ headerParts := strings.Split(string(responseBody), " ")
+ assert.Equal(t, "Basic", headerParts[0])
+
+ credentials, err := base64.StdEncoding.DecodeString(headerParts[1])
+ require.NoError(t, err)
+
+ assert.Equal(t, username+":"+password, string(credentials))
+}
+
+func TestEmptyBasicAuthSettings(t *testing.T) {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/empty_basic_auth",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, "", r.Header.Get("Authorization"))
+ },
+ },
+ }
+
+ client, cleanup := setup(t, &config.Config{}, requests)
+ defer cleanup()
+
+ _, err := client.Get("/empty_basic_auth")
+ require.NoError(t, err)
+}
+
+func setup(t *testing.T, config *config.Config, requests []testserver.TestRequestHandler) (*GitlabClient, func()) {
+ url, cleanup := testserver.StartHttpServer(t, requests)
+
+ config.GitlabUrl = url
+ client, err := GetClient(config)
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/internal/gitlabnet/httpsclient_test.go b/internal/gitlabnet/httpsclient_test.go
new file mode 100644
index 0000000..0acd425
--- /dev/null
+++ b/internal/gitlabnet/httpsclient_test.go
@@ -0,0 +1,125 @@
+package gitlabnet
+
+import (
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "path"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+func TestSuccessfulRequests(t *testing.T) {
+ testCases := []struct {
+ desc string
+ config *config.Config
+ }{
+ {
+ desc: "Valid CaFile",
+ config: &config.Config{
+ HttpSettings: config.HttpSettingsConfig{CaFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt")},
+ },
+ },
+ {
+ desc: "Valid CaPath",
+ config: &config.Config{
+ HttpSettings: config.HttpSettingsConfig{CaPath: path.Join(testhelper.TestRoot, "certs/valid")},
+ },
+ },
+ {
+ desc: "Self signed cert option enabled",
+ config: &config.Config{
+ HttpSettings: config.HttpSettingsConfig{SelfSignedCert: true},
+ },
+ },
+ {
+ desc: "Invalid cert with self signed cert option enabled",
+ config: &config.Config{
+ HttpSettings: config.HttpSettingsConfig{SelfSignedCert: true, CaFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt")},
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ client, cleanup := setupWithRequests(t, tc.config)
+ defer cleanup()
+
+ response, err := client.Get("/hello")
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ defer response.Body.Close()
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, string(responseBody), "Hello")
+ })
+ }
+}
+
+func TestFailedRequests(t *testing.T) {
+ testCases := []struct {
+ desc string
+ config *config.Config
+ }{
+ {
+ desc: "Invalid CaFile",
+ config: &config.Config{
+ HttpSettings: config.HttpSettingsConfig{CaFile: path.Join(testhelper.TestRoot, "certs/invalid/server.crt")},
+ },
+ },
+ {
+ desc: "Invalid CaPath",
+ config: &config.Config{
+ HttpSettings: config.HttpSettingsConfig{CaPath: path.Join(testhelper.TestRoot, "certs/invalid")},
+ },
+ },
+ {
+ desc: "Empty config",
+ config: &config.Config{},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ client, cleanup := setupWithRequests(t, tc.config)
+ defer cleanup()
+
+ _, err := client.Get("/hello")
+ require.Error(t, err)
+
+ assert.Equal(t, err.Error(), "Internal API unreachable")
+ })
+ }
+}
+
+func setupWithRequests(t *testing.T, config *config.Config) (*GitlabClient, func()) {
+ testDirCleanup, err := testhelper.PrepareTestRootDir()
+ require.NoError(t, err)
+ defer testDirCleanup()
+
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/hello",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+
+ fmt.Fprint(w, "Hello")
+ },
+ },
+ }
+
+ url, cleanup := testserver.StartHttpsServer(t, requests)
+
+ config.GitlabUrl = url
+ client, err := GetClient(config)
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/internal/gitlabnet/lfsauthenticate/client.go b/internal/gitlabnet/lfsauthenticate/client.go
new file mode 100644
index 0000000..d7469dd
--- /dev/null
+++ b/internal/gitlabnet/lfsauthenticate/client.go
@@ -0,0 +1,66 @@
+package lfsauthenticate
+
+import (
+ "fmt"
+ "net/http"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
+)
+
+type Client struct {
+ config *config.Config
+ client *gitlabnet.GitlabClient
+ args *commandargs.Shell
+}
+
+type Request struct {
+ Action commandargs.CommandType `json:"operation"`
+ Repo string `json:"project"`
+ KeyId string `json:"key_id,omitempty"`
+ UserId string `json:"user_id,omitempty"`
+}
+
+type Response struct {
+ Username string `json:"username"`
+ LfsToken string `json:"lfs_token"`
+ RepoPath string `json:"repository_http_path"`
+ ExpiresIn int `json:"expires_in"`
+}
+
+func NewClient(config *config.Config, args *commandargs.Shell) (*Client, error) {
+ client, err := gitlabnet.GetClient(config)
+ if err != nil {
+ return nil, fmt.Errorf("Error creating http client: %v", err)
+ }
+
+ return &Client{config: config, client: client, args: args}, nil
+}
+
+func (c *Client) Authenticate(action commandargs.CommandType, repo, userId string) (*Response, error) {
+ request := &Request{Action: action, Repo: repo}
+ if c.args.GitlabKeyId != "" {
+ request.KeyId = c.args.GitlabKeyId
+ } else {
+ request.UserId = strings.TrimPrefix(userId, "user-")
+ }
+
+ response, err := c.client.Post("/lfs_authenticate", request)
+ if err != nil {
+ return nil, err
+ }
+ defer response.Body.Close()
+
+ return parse(response)
+}
+
+func parse(hr *http.Response) (*Response, error) {
+ response := &Response{}
+ if err := gitlabnet.ParseJSON(hr, response); err != nil {
+ return nil, err
+ }
+
+ return response, nil
+}
diff --git a/internal/gitlabnet/lfsauthenticate/client_test.go b/internal/gitlabnet/lfsauthenticate/client_test.go
new file mode 100644
index 0000000..6faaa63
--- /dev/null
+++ b/internal/gitlabnet/lfsauthenticate/client_test.go
@@ -0,0 +1,117 @@
+package lfsauthenticate
+
+import (
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+)
+
+const (
+ keyId = "123"
+ repo = "group/repo"
+ action = commandargs.UploadPack
+)
+
+func setup(t *testing.T) []testserver.TestRequestHandler {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/lfs_authenticate",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+ require.NoError(t, err)
+
+ var request *Request
+ require.NoError(t, json.Unmarshal(b, &request))
+
+ switch request.KeyId {
+ case keyId:
+ body := map[string]interface{}{
+ "username": "john",
+ "lfs_token": "sometoken",
+ "repository_http_path": "https://gitlab.com/repo/path",
+ "expires_in": 1800,
+ }
+ require.NoError(t, json.NewEncoder(w).Encode(body))
+ case "forbidden":
+ w.WriteHeader(http.StatusForbidden)
+ case "broken":
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+ },
+ },
+ }
+
+ return requests
+}
+
+func TestFailedRequests(t *testing.T) {
+ requests := setup(t)
+ url, cleanup := testserver.StartHttpServer(t, requests)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ args *commandargs.Shell
+ expectedOutput string
+ }{
+ {
+ desc: "With bad response",
+ args: &commandargs.Shell{GitlabKeyId: "-1", CommandType: commandargs.UploadPack},
+ expectedOutput: "Parsing failed",
+ },
+ {
+ desc: "With API returns an error",
+ args: &commandargs.Shell{GitlabKeyId: "forbidden", CommandType: commandargs.UploadPack},
+ expectedOutput: "Internal API error (403)",
+ },
+ {
+ desc: "With API fails",
+ args: &commandargs.Shell{GitlabKeyId: "broken", CommandType: commandargs.UploadPack},
+ expectedOutput: "Internal API error (500)",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ client, err := NewClient(&config.Config{GitlabUrl: url}, tc.args)
+ require.NoError(t, err)
+
+ repo := "group/repo"
+
+ _, err = client.Authenticate(tc.args.CommandType, repo, "")
+ require.Error(t, err)
+
+ require.Equal(t, tc.expectedOutput, err.Error())
+ })
+ }
+}
+
+func TestSuccessfulRequests(t *testing.T) {
+ requests := setup(t)
+ url, cleanup := testserver.StartHttpServer(t, requests)
+ defer cleanup()
+
+ args := &commandargs.Shell{GitlabKeyId: keyId, CommandType: commandargs.LfsAuthenticate}
+ client, err := NewClient(&config.Config{GitlabUrl: url}, args)
+ require.NoError(t, err)
+
+ response, err := client.Authenticate(action, repo, "")
+ require.NoError(t, err)
+
+ expectedResponse := &Response{
+ Username: "john",
+ LfsToken: "sometoken",
+ RepoPath: "https://gitlab.com/repo/path",
+ ExpiresIn: 1800,
+ }
+
+ require.Equal(t, expectedResponse, response)
+}
diff --git a/internal/gitlabnet/testserver/gitalyserver.go b/internal/gitlabnet/testserver/gitalyserver.go
new file mode 100644
index 0000000..694fd41
--- /dev/null
+++ b/internal/gitlabnet/testserver/gitalyserver.go
@@ -0,0 +1,78 @@
+package testserver
+
+import (
+ "io/ioutil"
+ "net"
+ "os"
+ "path"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "google.golang.org/grpc"
+
+ pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+)
+
+type testGitalyServer struct{}
+
+func (s *testGitalyServer) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error {
+ req, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ response := []byte("ReceivePack: " + req.GlId + " " + req.Repository.GlRepository)
+ stream.Send(&pb.SSHReceivePackResponse{Stdout: response})
+
+ return nil
+}
+
+func (s *testGitalyServer) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error {
+ req, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ response := []byte("UploadPack: " + req.Repository.GlRepository)
+ stream.Send(&pb.SSHUploadPackResponse{Stdout: response})
+
+ return nil
+}
+
+func (s *testGitalyServer) SSHUploadArchive(stream pb.SSHService_SSHUploadArchiveServer) error {
+ req, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ response := []byte("UploadArchive: " + req.Repository.GlRepository)
+ stream.Send(&pb.SSHUploadArchiveResponse{Stdout: response})
+
+ return nil
+}
+
+func StartGitalyServer(t *testing.T) (string, func()) {
+ tempDir, _ := ioutil.TempDir("", "gitlab-shell-test-api")
+ gitalySocketPath := path.Join(tempDir, "gitaly.sock")
+
+ err := os.MkdirAll(filepath.Dir(gitalySocketPath), 0700)
+ require.NoError(t, err)
+
+ server := grpc.NewServer()
+
+ listener, err := net.Listen("unix", gitalySocketPath)
+ require.NoError(t, err)
+
+ pb.RegisterSSHServiceServer(server, &testGitalyServer{})
+
+ go server.Serve(listener)
+
+ gitalySocketUrl := "unix:" + gitalySocketPath
+ cleanup := func() {
+ server.Stop()
+ os.RemoveAll(tempDir)
+ }
+
+ return gitalySocketUrl, cleanup
+}
diff --git a/internal/gitlabnet/testserver/testserver.go b/internal/gitlabnet/testserver/testserver.go
new file mode 100644
index 0000000..f3b7b71
--- /dev/null
+++ b/internal/gitlabnet/testserver/testserver.go
@@ -0,0 +1,82 @@
+package testserver
+
+import (
+ "crypto/tls"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+var (
+ tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api")
+ testSocket = path.Join(tempDir, "internal.sock")
+)
+
+type TestRequestHandler struct {
+ Path string
+ Handler func(w http.ResponseWriter, r *http.Request)
+}
+
+func StartSocketHttpServer(t *testing.T, handlers []TestRequestHandler) (string, func()) {
+ err := os.MkdirAll(filepath.Dir(testSocket), 0700)
+ require.NoError(t, err)
+
+ socketListener, err := net.Listen("unix", testSocket)
+ require.NoError(t, err)
+
+ server := http.Server{
+ Handler: buildHandler(handlers),
+ // We'll put this server through some nasty stuff we don't want
+ // in our test output
+ ErrorLog: log.New(ioutil.Discard, "", 0),
+ }
+ go server.Serve(socketListener)
+
+ url := "http+unix://" + testSocket
+
+ return url, cleanupSocket
+}
+
+func StartHttpServer(t *testing.T, handlers []TestRequestHandler) (string, func()) {
+ server := httptest.NewServer(buildHandler(handlers))
+
+ return server.URL, server.Close
+}
+
+func StartHttpsServer(t *testing.T, handlers []TestRequestHandler) (string, func()) {
+ crt := path.Join(testhelper.TestRoot, "certs/valid/server.crt")
+ key := path.Join(testhelper.TestRoot, "certs/valid/server.key")
+
+ server := httptest.NewUnstartedServer(buildHandler(handlers))
+ cer, err := tls.LoadX509KeyPair(crt, key)
+ require.NoError(t, err)
+
+ server.TLS = &tls.Config{Certificates: []tls.Certificate{cer}}
+ server.StartTLS()
+
+ return server.URL, server.Close
+}
+
+func cleanupSocket() {
+ os.RemoveAll(tempDir)
+}
+
+func buildHandler(handlers []TestRequestHandler) http.Handler {
+ h := http.NewServeMux()
+
+ for _, handler := range handlers {
+ h.HandleFunc(handler.Path, handler.Handler)
+ }
+
+ return h
+}
diff --git a/internal/gitlabnet/twofactorrecover/client.go b/internal/gitlabnet/twofactorrecover/client.go
new file mode 100644
index 0000000..a3052f8
--- /dev/null
+++ b/internal/gitlabnet/twofactorrecover/client.go
@@ -0,0 +1,89 @@
+package twofactorrecover
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/discover"
+)
+
+type Client struct {
+ config *config.Config
+ client *gitlabnet.GitlabClient
+}
+
+type Response struct {
+ Success bool `json:"success"`
+ RecoveryCodes []string `json:"recovery_codes"`
+ Message string `json:"message"`
+}
+
+type RequestBody struct {
+ KeyId string `json:"key_id,omitempty"`
+ UserId int64 `json:"user_id,omitempty"`
+}
+
+func NewClient(config *config.Config) (*Client, error) {
+ client, err := gitlabnet.GetClient(config)
+ if err != nil {
+ return nil, fmt.Errorf("Error creating http client: %v", err)
+ }
+
+ return &Client{config: config, client: client}, nil
+}
+
+func (c *Client) GetRecoveryCodes(args *commandargs.Shell) ([]string, error) {
+ requestBody, err := c.getRequestBody(args)
+
+ if err != nil {
+ return nil, err
+ }
+
+ response, err := c.client.Post("/two_factor_recovery_codes", requestBody)
+ if err != nil {
+ return nil, err
+ }
+ defer response.Body.Close()
+
+ return parse(response)
+}
+
+func parse(hr *http.Response) ([]string, error) {
+ response := &Response{}
+ if err := gitlabnet.ParseJSON(hr, response); err != nil {
+ return nil, err
+ }
+
+ if !response.Success {
+ return nil, errors.New(response.Message)
+ }
+
+ return response.RecoveryCodes, nil
+}
+
+func (c *Client) getRequestBody(args *commandargs.Shell) (*RequestBody, error) {
+ client, err := discover.NewClient(c.config)
+
+ if err != nil {
+ return nil, err
+ }
+
+ var requestBody *RequestBody
+ if args.GitlabKeyId != "" {
+ requestBody = &RequestBody{KeyId: args.GitlabKeyId}
+ } else {
+ userInfo, err := client.GetByCommandArgs(args)
+
+ if err != nil {
+ return nil, err
+ }
+
+ requestBody = &RequestBody{UserId: userInfo.UserId}
+ }
+
+ return requestBody, nil
+}
diff --git a/internal/gitlabnet/twofactorrecover/client_test.go b/internal/gitlabnet/twofactorrecover/client_test.go
new file mode 100644
index 0000000..d5073e3
--- /dev/null
+++ b/internal/gitlabnet/twofactorrecover/client_test.go
@@ -0,0 +1,158 @@
+package twofactorrecover
+
+import (
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/discover"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+)
+
+var (
+ requests []testserver.TestRequestHandler
+)
+
+func initialize(t *testing.T) {
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/two_factor_recovery_codes",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+
+ require.NoError(t, err)
+
+ var requestBody *RequestBody
+ json.Unmarshal(b, &requestBody)
+
+ switch requestBody.KeyId {
+ case "0":
+ body := map[string]interface{}{
+ "success": true,
+ "recovery_codes": [2]string{"recovery 1", "codes 1"},
+ }
+ json.NewEncoder(w).Encode(body)
+ case "1":
+ body := map[string]interface{}{
+ "success": false,
+ "message": "missing user",
+ }
+ json.NewEncoder(w).Encode(body)
+ case "2":
+ w.WriteHeader(http.StatusForbidden)
+ body := &gitlabnet.ErrorResponse{
+ Message: "Not allowed!",
+ }
+ json.NewEncoder(w).Encode(body)
+ case "3":
+ w.Write([]byte("{ \"message\": \"broken json!\""))
+ case "4":
+ w.WriteHeader(http.StatusForbidden)
+ }
+
+ if requestBody.UserId == 1 {
+ body := map[string]interface{}{
+ "success": true,
+ "recovery_codes": [2]string{"recovery 2", "codes 2"},
+ }
+ json.NewEncoder(w).Encode(body)
+ }
+ },
+ },
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ body := &discover.Response{
+ UserId: 1,
+ Username: "jane-doe",
+ Name: "Jane Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ },
+ },
+ }
+}
+
+func TestGetRecoveryCodesByKeyId(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ args := &commandargs.Shell{GitlabKeyId: "0"}
+ result, err := client.GetRecoveryCodes(args)
+ assert.NoError(t, err)
+ assert.Equal(t, []string{"recovery 1", "codes 1"}, result)
+}
+
+func TestGetRecoveryCodesByUsername(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ args := &commandargs.Shell{GitlabUsername: "jane-doe"}
+ result, err := client.GetRecoveryCodes(args)
+ assert.NoError(t, err)
+ assert.Equal(t, []string{"recovery 2", "codes 2"}, result)
+}
+
+func TestMissingUser(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ args := &commandargs.Shell{GitlabKeyId: "1"}
+ _, err := client.GetRecoveryCodes(args)
+ assert.Equal(t, "missing user", err.Error())
+}
+
+func TestErrorResponses(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ fakeId string
+ expectedError string
+ }{
+ {
+ desc: "A response with an error message",
+ fakeId: "2",
+ expectedError: "Not allowed!",
+ },
+ {
+ desc: "A response with bad JSON",
+ fakeId: "3",
+ expectedError: "Parsing failed",
+ },
+ {
+ desc: "An error response without message",
+ fakeId: "4",
+ expectedError: "Internal API error (403)",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ args := &commandargs.Shell{GitlabKeyId: tc.fakeId}
+ resp, err := client.GetRecoveryCodes(args)
+
+ assert.EqualError(t, err, tc.expectedError)
+ assert.Nil(t, resp)
+ })
+ }
+}
+
+func setup(t *testing.T) (*Client, func()) {
+ initialize(t)
+ url, cleanup := testserver.StartSocketHttpServer(t, requests)
+
+ client, err := NewClient(&config.Config{GitlabUrl: url})
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/internal/handler/exec.go b/internal/handler/exec.go
new file mode 100644
index 0000000..ba9a4ff
--- /dev/null
+++ b/internal/handler/exec.go
@@ -0,0 +1,96 @@
+package handler
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "gitlab.com/gitlab-org/gitaly/auth"
+ "gitlab.com/gitlab-org/gitaly/client"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/labkit/tracing"
+ "google.golang.org/grpc"
+)
+
+// GitalyHandlerFunc implementations are responsible for making
+// an appropriate Gitaly call using the provided client and context
+// and returning an error from the Gitaly call.
+type GitalyHandlerFunc func(ctx context.Context, client *grpc.ClientConn) (int32, error)
+
+type GitalyConn struct {
+ ctx context.Context
+ conn *grpc.ClientConn
+ close func()
+}
+
+type GitalyCommand struct {
+ Config *config.Config
+ ServiceName string
+ Address string
+ Token string
+}
+
+// RunGitalyCommand provides a bootstrap for Gitaly commands executed
+// through GitLab-Shell. It ensures that logging, tracing and other
+// common concerns are configured before executing the `handler`.
+func (gc *GitalyCommand) RunGitalyCommand(handler GitalyHandlerFunc) error {
+ gitalyConn, err := getConn(gc)
+
+ if err != nil {
+ return err
+ }
+
+ _, err = handler(gitalyConn.ctx, gitalyConn.conn)
+
+ gitalyConn.close()
+
+ return err
+}
+
+func getConn(gc *GitalyCommand) (*GitalyConn, error) {
+ if gc.Address == "" {
+ return nil, fmt.Errorf("no gitaly_address given")
+ }
+
+ connOpts := client.DefaultDialOpts
+ if gc.Token != "" {
+ connOpts = append(client.DefaultDialOpts, grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(gc.Token)))
+ }
+
+ // Use a working directory that won't get removed or unmounted.
+ if err := os.Chdir("/"); err != nil {
+ return nil, err
+ }
+
+ // Configure distributed tracing
+ serviceName := fmt.Sprintf("gitlab-shell-%v", gc.ServiceName)
+ closer := tracing.Initialize(
+ tracing.WithServiceName(serviceName),
+
+ // For GitLab-Shell, we explicitly initialize tracing from a config file
+ // instead of the default environment variable (using GITLAB_TRACING)
+ // This decision was made owing to the difficulty in passing environment
+ // variables into GitLab-Shell processes.
+ //
+ // Processes are spawned as children of the SSH daemon, which tightly
+ // controls environment variables; doing this means we don't have to
+ // enable PermitUserEnvironment
+ tracing.WithConnectionString(gc.Config.GitlabTracing),
+ )
+
+ ctx, finished := tracing.ExtractFromEnv(context.Background())
+
+ conn, err := client.Dial(gc.Address, connOpts)
+ if err != nil {
+ return nil, err
+ }
+
+ finish := func() {
+ finished()
+ closer.Close()
+ conn.Close()
+ }
+
+ return &GitalyConn{ctx: ctx, conn: conn, close: finish}, nil
+}
diff --git a/internal/handler/exec_test.go b/internal/handler/exec_test.go
new file mode 100644
index 0000000..6c7d3f5
--- /dev/null
+++ b/internal/handler/exec_test.go
@@ -0,0 +1,42 @@
+package handler
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "google.golang.org/grpc"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+func makeHandler(t *testing.T, err error) func(context.Context, *grpc.ClientConn) (int32, error) {
+ return func(ctx context.Context, client *grpc.ClientConn) (int32, error) {
+ require.NotNil(t, ctx)
+ require.NotNil(t, client)
+
+ return 0, err
+ }
+}
+
+func TestRunGitalyCommand(t *testing.T) {
+ cmd := GitalyCommand{
+ Config: &config.Config{},
+ Address: "tcp://localhost:9999",
+ }
+
+ err := cmd.RunGitalyCommand(makeHandler(t, nil))
+ require.NoError(t, err)
+
+ expectedErr := errors.New("error")
+ err = cmd.RunGitalyCommand(makeHandler(t, expectedErr))
+ require.Equal(t, err, expectedErr)
+}
+
+func TestMissingGitalyAddress(t *testing.T) {
+ cmd := GitalyCommand{Config: &config.Config{}}
+
+ err := cmd.RunGitalyCommand(makeHandler(t, nil))
+ require.EqualError(t, err, "no gitaly_address given")
+}
diff --git a/internal/keyline/key_line.go b/internal/keyline/key_line.go
new file mode 100644
index 0000000..c29a320
--- /dev/null
+++ b/internal/keyline/key_line.go
@@ -0,0 +1,62 @@
+package keyline
+
+import (
+ "errors"
+ "fmt"
+ "path"
+ "regexp"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+)
+
+var (
+ keyRegex = regexp.MustCompile(`\A[a-z0-9-]+\z`)
+)
+
+const (
+ PublicKeyPrefix = "key"
+ PrincipalPrefix = "username"
+ SshOptions = "no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty"
+)
+
+type KeyLine struct {
+ Id string // This can be either an ID of a Key or username
+ Value string // This can be either a public key or a principal name
+ Prefix string
+ RootDir string
+}
+
+func NewPublicKeyLine(id string, publicKey string, rootDir string) (*KeyLine, error) {
+ return newKeyLine(id, publicKey, PublicKeyPrefix, rootDir)
+}
+
+func NewPrincipalKeyLine(keyId string, principal string, rootDir string) (*KeyLine, error) {
+ return newKeyLine(keyId, principal, PrincipalPrefix, rootDir)
+}
+
+func (k *KeyLine) ToString() string {
+ command := fmt.Sprintf("%s %s-%s", path.Join(k.RootDir, executable.BinDir, executable.GitlabShell), k.Prefix, k.Id)
+
+ return fmt.Sprintf(`command="%s",%s %s`, command, SshOptions, k.Value)
+}
+
+func newKeyLine(id string, value string, prefix string, rootDir string) (*KeyLine, error) {
+ if err := validate(id, value); err != nil {
+ return nil, err
+ }
+
+ return &KeyLine{Id: id, Value: value, Prefix: prefix, RootDir: rootDir}, nil
+}
+
+func validate(id string, value string) error {
+ if !keyRegex.MatchString(id) {
+ return errors.New(fmt.Sprintf("Invalid key_id: %s", id))
+ }
+
+ if strings.Contains(value, "\n") {
+ return errors.New(fmt.Sprintf("Invalid value: %s", value))
+ }
+
+ return nil
+}
diff --git a/internal/keyline/key_line_test.go b/internal/keyline/key_line_test.go
new file mode 100644
index 0000000..c6883c0
--- /dev/null
+++ b/internal/keyline/key_line_test.go
@@ -0,0 +1,82 @@
+package keyline
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestFailingNewPublicKeyLine(t *testing.T) {
+ testCases := []struct {
+ desc string
+ id string
+ publicKey string
+ expectedError string
+ }{
+ {
+ desc: "When Id has non-alphanumeric and non-dash characters in it",
+ id: "key\n1",
+ publicKey: "public-key",
+ expectedError: "Invalid key_id: key\n1",
+ },
+ {
+ desc: "When public key has newline in it",
+ id: "key",
+ publicKey: "public\nkey",
+ expectedError: "Invalid value: public\nkey",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ result, err := NewPublicKeyLine(tc.id, tc.publicKey, "root-dir")
+
+ require.Empty(t, result)
+ require.EqualError(t, err, tc.expectedError)
+ })
+ }
+}
+
+func TestFailingNewPrincipalKeyLine(t *testing.T) {
+ testCases := []struct {
+ desc string
+ keyId string
+ principal string
+ expectedError string
+ }{
+ {
+ desc: "When username has non-alphanumeric and non-dash characters in it",
+ keyId: "username\n1",
+ principal: "principal",
+ expectedError: "Invalid key_id: username\n1",
+ },
+ {
+ desc: "When principal has newline in it",
+ keyId: "username",
+ principal: "principal\n1",
+ expectedError: "Invalid value: principal\n1",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ result, err := NewPrincipalKeyLine(tc.keyId, tc.principal, "root-dir")
+
+ require.Empty(t, result)
+ require.EqualError(t, err, tc.expectedError)
+ })
+ }
+}
+
+func TestToString(t *testing.T) {
+ keyLine := &KeyLine{
+ Id: "1",
+ Value: "public-key",
+ Prefix: "key",
+ RootDir: "/tmp",
+ }
+
+ result := keyLine.ToString()
+
+ require.Equal(t, `command="/tmp/bin/gitlab-shell key-1",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty public-key`, result)
+}
diff --git a/internal/logger/logger.go b/internal/logger/logger.go
new file mode 100644
index 0000000..bbc6a51
--- /dev/null
+++ b/internal/logger/logger.go
@@ -0,0 +1,82 @@
+package logger
+
+import (
+ "fmt"
+ "io"
+ golog "log"
+ "log/syslog"
+ "os"
+ "sync"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+
+ log "github.com/sirupsen/logrus"
+)
+
+var (
+ logWriter io.Writer
+ bootstrapLogger *golog.Logger
+ pid int
+ mutex sync.Mutex
+ ProgName string
+)
+
+func Configure(cfg *config.Config) error {
+ mutex.Lock()
+ defer mutex.Unlock()
+
+ pid = os.Getpid()
+
+ var err error
+ logWriter, err = os.OpenFile(cfg.LogFile, os.O_WRONLY|os.O_APPEND, 0)
+ if err != nil {
+ return err
+ }
+
+ log.SetOutput(logWriter)
+ if cfg.LogFormat == "json" {
+ log.SetFormatter(&log.JSONFormatter{})
+ }
+
+ return nil
+}
+
+func logPrint(msg string, err error) {
+ mutex.Lock()
+ defer mutex.Unlock()
+
+ if logWriter == nil {
+ bootstrapLogPrint(msg, err)
+ return
+ }
+
+ log.WithError(err).WithFields(log.Fields{
+ "pid": pid,
+ }).Error(msg)
+}
+
+func Fatal(msg string, err error) {
+ logPrint(msg, err)
+ // We don't show the error to the end user because it can leak
+ // information that is private to the GitLab server.
+ fmt.Fprintf(os.Stderr, "%s: fatal: %s\n", ProgName, msg)
+ os.Exit(1)
+}
+
+// If our log file is not available we want to log somewhere else, but
+// not to standard error because that leaks information to the user. This
+// function attemps to log to syslog.
+//
+// We assume the logging mutex is already locked.
+func bootstrapLogPrint(msg string, err error) {
+ if bootstrapLogger == nil {
+ var err error
+ bootstrapLogger, err = syslog.NewLogger(syslog.LOG_ERR|syslog.LOG_USER, 0)
+ if err != nil {
+ // The message will not be logged.
+ return
+ }
+ }
+
+ bootstrapLogger.Print(ProgName+":", msg+":", err)
+}
diff --git a/internal/sshenv/sshenv.go b/internal/sshenv/sshenv.go
new file mode 100644
index 0000000..387feb2
--- /dev/null
+++ b/internal/sshenv/sshenv.go
@@ -0,0 +1,15 @@
+package sshenv
+
+import (
+ "os"
+ "strings"
+)
+
+func LocalAddr() string {
+ address := os.Getenv("SSH_CONNECTION")
+
+ if address != "" {
+ return strings.Fields(address)[0]
+ }
+ return ""
+}
diff --git a/internal/sshenv/sshenv_test.go b/internal/sshenv/sshenv_test.go
new file mode 100644
index 0000000..e4a640e
--- /dev/null
+++ b/internal/sshenv/sshenv_test.go
@@ -0,0 +1,20 @@
+package sshenv
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+func TestLocalAddr(t *testing.T) {
+ cleanup, err := testhelper.Setenv("SSH_CONNECTION", "127.0.0.1 0")
+ require.NoError(t, err)
+ defer cleanup()
+
+ require.Equal(t, LocalAddr(), "127.0.0.1")
+}
+
+func TestEmptyLocalAddr(t *testing.T) {
+ require.Equal(t, LocalAddr(), "")
+}
diff --git a/internal/testhelper/requesthandlers/requesthandlers.go b/internal/testhelper/requesthandlers/requesthandlers.go
new file mode 100644
index 0000000..11817e8
--- /dev/null
+++ b/internal/testhelper/requesthandlers/requesthandlers.go
@@ -0,0 +1,58 @@
+package requesthandlers
+
+import (
+ "encoding/json"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver"
+)
+
+func BuildDisallowedByApiHandlers(t *testing.T) []testserver.TestRequestHandler {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/allowed",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ body := map[string]interface{}{
+ "status": false,
+ "message": "Disallowed by API call",
+ }
+ w.WriteHeader(http.StatusForbidden)
+ require.NoError(t, json.NewEncoder(w).Encode(body))
+ },
+ },
+ }
+
+ return requests
+}
+
+func BuildAllowedWithGitalyHandlers(t *testing.T, gitalyAddress string) []testserver.TestRequestHandler {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/allowed",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ body := map[string]interface{}{
+ "status": true,
+ "gl_id": "1",
+ "gitaly": map[string]interface{}{
+ "repository": map[string]interface{}{
+ "storage_name": "storage_name",
+ "relative_path": "relative_path",
+ "git_object_directory": "path/to/git_object_directory",
+ "git_alternate_object_directories": []string{"path/to/git_alternate_object_directory"},
+ "gl_repository": "group/repo",
+ "gl_project_path": "group/project-path",
+ },
+ "address": gitalyAddress,
+ "token": "token",
+ },
+ }
+ require.NoError(t, json.NewEncoder(w).Encode(body))
+ },
+ },
+ }
+
+ return requests
+}
diff --git a/internal/testhelper/testdata/testroot/.gitlab_shell_secret b/internal/testhelper/testdata/testroot/.gitlab_shell_secret
new file mode 100644
index 0000000..9bd459d
--- /dev/null
+++ b/internal/testhelper/testdata/testroot/.gitlab_shell_secret
@@ -0,0 +1 @@
+default-secret-content \ No newline at end of file
diff --git a/internal/testhelper/testdata/testroot/certs/invalid/server.crt b/internal/testhelper/testdata/testroot/certs/invalid/server.crt
new file mode 100644
index 0000000..f8a42c1
--- /dev/null
+++ b/internal/testhelper/testdata/testroot/certs/invalid/server.crt
@@ -0,0 +1,10 @@
+-----BEGIN CERTIFICATE-----
+MinvalidcertAOvHjs6cs1R9MAoGCCqGSM49BAMCMBQxEjAQBgNVBAMMCWxvY2Fs
+ainvalidcertOTA0MjQxNjM4NTBaFw0yOTA0MjExNjM4NTBaMBQxEjAQBgNVBAMM
+CinvalidcertdDB2MBAGByqGSM49AgEGBSuBBAAiA2IABJ5m7oW9OuL7aTAC04sL
+3invalidcertdB2L0GsVCImav4PEpx6UAjkoiNGW9j0zPdNgxTYDjiCaGmr1aY2X
+kinvalidcert7MNq7H8v7Ce/vrKkcDMOX8Gd/ddT3dEVqzAKBggqhkjOPQQDAgNp
+AinvalidcertswcyjiB+A+ZjMSfaOsA2hAP0I3fkTcry386DePViMfnaIjm7rcuu
+Jinvalidcert5V5CHypOxio1tOtGjaDkSH2FCdoatMyIe02+F6TIo44i4J/zjN52
+Jinvalidcert
+-----END CERTIFICATE-----
diff --git a/internal/testhelper/testdata/testroot/certs/valid/dir/.gitkeep b/internal/testhelper/testdata/testroot/certs/valid/dir/.gitkeep
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/internal/testhelper/testdata/testroot/certs/valid/dir/.gitkeep
diff --git a/internal/testhelper/testdata/testroot/certs/valid/server.crt b/internal/testhelper/testdata/testroot/certs/valid/server.crt
new file mode 100644
index 0000000..11f1da7
--- /dev/null
+++ b/internal/testhelper/testdata/testroot/certs/valid/server.crt
@@ -0,0 +1,22 @@
+-----BEGIN CERTIFICATE-----
+MIIDrjCCApagAwIBAgIUHVNTmyz3p+7xSEMkSfhPz4BZfqwwDQYJKoZIhvcNAQEL
+BQAwTjELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExEjAQBgNVBAcM
+CVRoZSBDbG91ZDEWMBQGA1UECgwNTXkgQ29tcGFueSBDQTAeFw0xOTA5MjAxMDQ3
+NTlaFw0yOTA5MTcxMDQ3NTlaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp
+Zm9ybmlhMRIwEAYDVQQHDAlUaGUgQ2xvdWQxDTALBgNVBAoMBERlbW8xFzAVBgNV
+BAMMDk15IENlcnRpZmljYXRlMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC
+AQEAmte3G/eD+quamwyFl+2jEo8ngSAT0FWeY5ZAwRvdF4FgtTBLvbAdTnyi7pHM
+esCSUkyxXHHPazM4SDV6uiu5LNKF0iz/NY76rLtFoqSGUgygTZHVbZ6NRXCNUZ0P
+slD95wOCWvS9t9xgNXry66k8+mfZNhE+cFQfrO/pN5WpNuGyWTfKlUQw5NVL3mob
+j3tSjI+wzSpbPMvbTQoBiZ/VHkyyc15YdrbePwFB2dJbxE/Xgsyk/TwWSUFnAs6i
+1x2t+423NIm9rIDTdW2YYJJXv3MUcdDIxJnY0beGePMIymn9ZIRUJtK/ZXmwMb52
+v70+YTcsG67uSm31CR8jNt8qpQIDAQABo3QwcjAJBgNVHRMEAjAAMB0GA1UdDgQW
+BBTxZ9SORmIwDs90TW8UXIVhDst4kjALBgNVHQ8EBAMCBaAwHQYDVR0lBBYwFAYI
+KwYBBQUHAwIGCCsGAQUFBwMBMBoGA1UdEQQTMBGHBH8AAAGCCWxvY2FsaG9zdDAN
+BgkqhkiG9w0BAQsFAAOCAQEAf4Iq94Su9TlkReMS4x2N5xZru9YoKQtrrxqWSRbp
+oh5Lwtk9rJPy6q4IEPXzDsRI1YWCZe1Fw7zdiNfmoFRxjs59MBJ9YVrcFeyeAILg
+LiAiRcGth2THpikCnLxmniGHUUX1WfjmcDEYMIs6BZ98N64VWwtuZqcJnJPmQs64
+lDrgW9oz6/8hPMeW58ok8PjkiG+E+srBaURoKwNe7vfPRVyq45N67/juH+4o6QBd
+WP6ACjDM3RnxyWyW0S+sl3i3EAGgtwM6RIDhOG238HOIiA/I/+CCmITsvujz6jMN
+bLdoPfnatZ7f5m9DuoOsGlYAZbLfOl2NywgO0jAlnHJGEQ==
+-----END CERTIFICATE-----
diff --git a/internal/testhelper/testdata/testroot/certs/valid/server.key b/internal/testhelper/testdata/testroot/certs/valid/server.key
new file mode 100644
index 0000000..acec0fb
--- /dev/null
+++ b/internal/testhelper/testdata/testroot/certs/valid/server.key
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEogIBAAKCAQEAmte3G/eD+quamwyFl+2jEo8ngSAT0FWeY5ZAwRvdF4FgtTBL
+vbAdTnyi7pHMesCSUkyxXHHPazM4SDV6uiu5LNKF0iz/NY76rLtFoqSGUgygTZHV
+bZ6NRXCNUZ0PslD95wOCWvS9t9xgNXry66k8+mfZNhE+cFQfrO/pN5WpNuGyWTfK
+lUQw5NVL3mobj3tSjI+wzSpbPMvbTQoBiZ/VHkyyc15YdrbePwFB2dJbxE/Xgsyk
+/TwWSUFnAs6i1x2t+423NIm9rIDTdW2YYJJXv3MUcdDIxJnY0beGePMIymn9ZIRU
+JtK/ZXmwMb52v70+YTcsG67uSm31CR8jNt8qpQIDAQABAoIBAEJQyNdtdlTRUfG9
+tymOWR0FuoGO322GfcNhAnKyIEqE2oo/GPEwkByhPJa4Ur7v4rrkpcFV7OOYmC40
+2U8KktAjibSuGM8zYSDBQ92YYP6a8bzHDIVaNl7bCWs+vQ49qcBavGWAFBC+jWXa
+Nle/r6H/AAQr9nXdUYObbGKl8kbSUBNAqQHILsNyxQsAo12oqRnUWhIbfzUFBr1m
+us93OsvpOYWgkbaBWk0brjp2X0eNGHctTboFxRknJcU6MQVL5degbgXhnCm4ir4O
+E2KMubEwxePr5fPotWNQXCVin85OQv1eb70anfwoA2b5/ykb57jo5EDoiUoFsjLz
+KLAaRQECgYEAzZNP/CpwCh5s31SDr7ajYfNIu8ie370g2Qbf4jrqVrOJ8Sj1LRYB
+lS5+QbSRu4W6Ani3AQwZA09lS608G8w5rD7YGRVDCFuwJt+Yz5GcsSkso9B8DR4h
+vCe2WuDutz7M5ikP1DAc/9x5HIzjQijxM1JJCNU2nR6QoFvV6wpVcpECgYEAwNK9
+oTqyb7UjNinAo9PFrFpnbX+DoGokGPsRyUwi9UkyRR0Uf7Kxjoq2C8zsCvnGdrE7
+kwUiWjyfAgMDF8+iWHYO1vD7m6NL31h/AAmo0NEQIBs0LFj0lF0xORzvXdTjhvuG
+LxXhm927z4WBOCLTn8FAsBUjVBpmB6ffyZCVWNUCgYA3P4j2fz0/KvAdkSwW9CGy
+uFxqwz8XaE/Eo9lVhnnmNTg0TMqfhFOGkUkzRWEJIaZc9a5RJLwwLI1Pqk4GNnul
+c/pFu3YZb/LGb780wbB32FX77JL6P4fXdmDGyb6+Fq2giZaMcyXICauu5ZpJ9JDm
+Nw4TxqF31ngN8MBr+4n9UQKBgAkxAoEQ/zh79fW6/8fPbHjOxmdd0LRw2s+mCC8E
+RhZTKuZIgJWluvkEe7EMT6QmS+OUhzZ25DBQ+3NpGVilOSPmXMa6LgQ5QIChA0zJ
+KRbrIE2nflEu3FnGJ3aFfpOGdmIU00yjSmHXrAA0aPh4EIZo++Bo4Yo8x+hNhElj
+bvsRAoGADYZTUchbiVndk5QtnbwlDjrF5PmgjoDboBfv9/6FU+DzQRyOpl3kr0hs
+OcZGE6xPZJidv1Bcv60L1VzTMj7spvMRTeumn2zEQGjkl6i/fSZzawjmKaKXKNkC
+YfoV0RepB4TlNYGICaTcV+aKRIXivcpBGfduZEb39iUKCjh9Afg=
+-----END RSA PRIVATE KEY-----
diff --git a/internal/testhelper/testdata/testroot/config.yml b/internal/testhelper/testdata/testroot/config.yml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/internal/testhelper/testdata/testroot/config.yml
diff --git a/internal/testhelper/testdata/testroot/custom/my-contents-is-secret b/internal/testhelper/testdata/testroot/custom/my-contents-is-secret
new file mode 100644
index 0000000..645b575
--- /dev/null
+++ b/internal/testhelper/testdata/testroot/custom/my-contents-is-secret
@@ -0,0 +1 @@
+custom-secret-content \ No newline at end of file
diff --git a/internal/testhelper/testdata/testroot/gitlab-shell.log b/internal/testhelper/testdata/testroot/gitlab-shell.log
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/internal/testhelper/testdata/testroot/gitlab-shell.log
diff --git a/internal/testhelper/testdata/testroot/responses/allowed.json b/internal/testhelper/testdata/testroot/responses/allowed.json
new file mode 100644
index 0000000..d0403d9
--- /dev/null
+++ b/internal/testhelper/testdata/testroot/responses/allowed.json
@@ -0,0 +1,22 @@
+{
+ "status": true,
+ "gl_repository": "project-26",
+ "gl_project_path": "group/private",
+ "gl_id": "user-1",
+ "gl_username": "root",
+ "git_config_options": ["option"],
+ "gitaly": {
+ "repository": {
+ "storage_name": "default",
+ "relative_path": "@hashed/5f/9c/5f9c4ab08cac7457e9111a30e4664920607ea2c115a1433d7be98e97e64244ca.git",
+ "git_object_directory": "path/to/git_object_directory",
+ "git_alternate_object_directories": ["path/to/git_alternate_object_directory"],
+ "gl_repository": "project-26",
+ "gl_project_path": "group/private"
+ },
+ "address": "unix:gitaly.socket",
+ "token": "token"
+ },
+ "git_protocol": "protocol",
+ "gl_console_messages": ["console", "message"]
+}
diff --git a/internal/testhelper/testdata/testroot/responses/allowed_with_payload.json b/internal/testhelper/testdata/testroot/responses/allowed_with_payload.json
new file mode 100644
index 0000000..331c3a9
--- /dev/null
+++ b/internal/testhelper/testdata/testroot/responses/allowed_with_payload.json
@@ -0,0 +1,31 @@
+{
+ "status": true,
+ "gl_repository": "project-26",
+ "gl_project_path": "group/private",
+ "gl_id": "user-1",
+ "gl_username": "root",
+ "git_config_options": ["option"],
+ "gitaly": {
+ "repository": {
+ "storage_name": "default",
+ "relative_path": "@hashed/5f/9c/5f9c4ab08cac7457e9111a30e4664920607ea2c115a1433d7be98e97e64244ca.git",
+ "git_object_directory": "path/to/git_object_directory",
+ "git_alternate_object_directories": ["path/to/git_alternate_object_directory"],
+ "gl_repository": "project-26",
+ "gl_project_path": "group/private"
+ },
+ "address": "unix:gitaly.socket",
+ "token": "token"
+ },
+ "payload" : {
+ "action": "geo_proxy_to_primary",
+ "data": {
+ "api_endpoints": ["geo/proxy_git_push_ssh/info_refs", "geo/proxy_git_push_ssh/push"],
+ "gl_username": "custom",
+ "primary_repo": "https://repo/path",
+ "info_message": "message"
+ }
+ },
+ "git_protocol": "protocol",
+ "gl_console_messages": ["console", "message"]
+}
diff --git a/internal/testhelper/testhelper.go b/internal/testhelper/testhelper.go
new file mode 100644
index 0000000..a925c79
--- /dev/null
+++ b/internal/testhelper/testhelper.go
@@ -0,0 +1,93 @@
+package testhelper
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path"
+ "runtime"
+
+ "github.com/otiai10/copy"
+)
+
+var (
+ TestRoot, _ = ioutil.TempDir("", "test-gitlab-shell")
+)
+
+func TempEnv(env map[string]string) func() {
+ var original = make(map[string]string)
+ for key, value := range env {
+ original[key] = os.Getenv(key)
+ os.Setenv(key, value)
+ }
+
+ return func() {
+ for key, originalValue := range original {
+ os.Setenv(key, originalValue)
+ }
+ }
+}
+
+func PrepareTestRootDir() (func(), error) {
+ if err := os.MkdirAll(TestRoot, 0700); err != nil {
+ return nil, err
+ }
+
+ var oldWd string
+ cleanup := func() {
+ if oldWd != "" {
+ err := os.Chdir(oldWd)
+ if err != nil {
+ panic(err)
+ }
+ }
+
+ if err := os.RemoveAll(TestRoot); err != nil {
+ panic(err)
+ }
+ }
+
+ if err := copyTestData(); err != nil {
+ cleanup()
+ return nil, err
+ }
+
+ oldWd, err := os.Getwd()
+ if err != nil {
+ cleanup()
+ return nil, err
+ }
+
+ if err := os.Chdir(TestRoot); err != nil {
+ cleanup()
+ return nil, err
+ }
+
+ return cleanup, nil
+}
+
+func copyTestData() error {
+ testDataDir, err := getTestDataDir()
+ if err != nil {
+ return err
+ }
+
+ testdata := path.Join(testDataDir, "testroot")
+
+ return copy.Copy(testdata, TestRoot)
+}
+
+func getTestDataDir() (string, error) {
+ _, currentFile, _, ok := runtime.Caller(0)
+ if !ok {
+ return "", fmt.Errorf("Could not get caller info")
+ }
+
+ return path.Join(path.Dir(currentFile), "testdata"), nil
+}
+
+func Setenv(key, value string) (func(), error) {
+ oldValue := os.Getenv(key)
+ err := os.Setenv(key, value)
+ return func() { os.Setenv(key, oldValue) }, err
+}