diff options
author | Igor <idrozdov@gitlab.com> | 2019-10-21 16:25:53 +0000 |
---|---|---|
committer | Igor <idrozdov@gitlab.com> | 2019-10-21 16:25:53 +0000 |
commit | 629e3bf9c31687f7b824cf29ba07ad2ce402e280 (patch) | |
tree | 0f80f7394231d39970f23a08ba9ba2ce7051e22c /internal | |
parent | 7d5229db263a62661653431881bef8b46984d0de (diff) | |
parent | ede41ee451dd0aa6d0ecd958c7fadbdb3b63f3e4 (diff) | |
download | gitlab-shell-629e3bf9c31687f7b824cf29ba07ad2ce402e280.tar.gz |
Merge branch '173-move-go-code-up-one-level' into 'master'
Move Go code up one level
See merge request gitlab-org/gitlab-shell!350
Diffstat (limited to 'internal')
81 files changed, 5705 insertions, 0 deletions
diff --git a/internal/command/authorizedkeys/authorized_keys.go b/internal/command/authorizedkeys/authorized_keys.go new file mode 100644 index 0000000..f1cab45 --- /dev/null +++ b/internal/command/authorizedkeys/authorized_keys.go @@ -0,0 +1,61 @@ +package authorizedkeys + +import ( + "fmt" + "strconv" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" + "gitlab.com/gitlab-org/gitlab-shell/internal/keyline" +) + +type Command struct { + Config *config.Config + Args *commandargs.AuthorizedKeys + ReadWriter *readwriter.ReadWriter +} + +func (c *Command) Execute() error { + // Do and return nothing when the expected and actual user don't match. + // This can happen when the user in sshd_config doesn't match the user + // trying to login. When nothing is printed, the user will be denied access. + if c.Args.ExpectedUser != c.Args.ActualUser { + // TODO: Log this event once we have a consistent way to log in Go. + // See https://gitlab.com/gitlab-org/gitlab-shell/issues/192 for more info. + return nil + } + + if err := c.printKeyLine(); err != nil { + return err + } + + return nil +} + +func (c *Command) printKeyLine() error { + response, err := c.getAuthorizedKey() + if err != nil { + fmt.Fprintln(c.ReadWriter.Out, fmt.Sprintf("# No key was found for %s", c.Args.Key)) + return nil + } + + keyLine, err := keyline.NewPublicKeyLine(strconv.FormatInt(response.Id, 10), response.Key, c.Config.RootDir) + if err != nil { + return err + } + + fmt.Fprintln(c.ReadWriter.Out, keyLine.ToString()) + + return nil +} + +func (c *Command) getAuthorizedKey() (*authorizedkeys.Response, error) { + client, err := authorizedkeys.NewClient(c.Config) + if err != nil { + return nil, err + } + + return client.GetByKey(c.Args.Key) +} diff --git a/internal/command/authorizedkeys/authorized_keys_test.go b/internal/command/authorizedkeys/authorized_keys_test.go new file mode 100644 index 0000000..3bf4153 --- /dev/null +++ b/internal/command/authorizedkeys/authorized_keys_test.go @@ -0,0 +1,90 @@ +package authorizedkeys + +import ( + "bytes" + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" +) + +var ( + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/authorized_keys", + Handler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("key") == "key" { + body := map[string]interface{}{ + "id": 1, + "key": "public-key", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("key") == "broken-message" { + body := map[string]string{ + "message": "Forbidden!", + } + w.WriteHeader(http.StatusForbidden) + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("key") == "broken" { + w.WriteHeader(http.StatusInternalServerError) + } else { + w.WriteHeader(http.StatusNotFound) + } + }, + }, + } +) + +func TestExecute(t *testing.T) { + url, cleanup := testserver.StartSocketHttpServer(t, requests) + defer cleanup() + + testCases := []struct { + desc string + arguments *commandargs.AuthorizedKeys + expectedOutput string + }{ + { + desc: "With matching username and key", + arguments: &commandargs.AuthorizedKeys{ExpectedUser: "user", ActualUser: "user", Key: "key"}, + expectedOutput: "command=\"/tmp/bin/gitlab-shell key-1\",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty public-key\n", + }, + { + desc: "When key doesn't match any existing key", + arguments: &commandargs.AuthorizedKeys{ExpectedUser: "user", ActualUser: "user", Key: "not-found"}, + expectedOutput: "# No key was found for not-found\n", + }, + { + desc: "When the API returns an error", + arguments: &commandargs.AuthorizedKeys{ExpectedUser: "user", ActualUser: "user", Key: "broken-message"}, + expectedOutput: "# No key was found for broken-message\n", + }, + { + desc: "When the API fails", + arguments: &commandargs.AuthorizedKeys{ExpectedUser: "user", ActualUser: "user", Key: "broken"}, + expectedOutput: "# No key was found for broken\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + buffer := &bytes.Buffer{} + cmd := &Command{ + Config: &config.Config{RootDir: "/tmp", GitlabUrl: url}, + Args: tc.arguments, + ReadWriter: &readwriter.ReadWriter{Out: buffer}, + } + + err := cmd.Execute() + + require.NoError(t, err) + require.Equal(t, tc.expectedOutput, buffer.String()) + }) + } +} diff --git a/internal/command/authorizedprincipals/authorized_principals.go b/internal/command/authorizedprincipals/authorized_principals.go new file mode 100644 index 0000000..10ae70e --- /dev/null +++ b/internal/command/authorizedprincipals/authorized_principals.go @@ -0,0 +1,47 @@ +package authorizedprincipals + +import ( + "fmt" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/keyline" +) + +type Command struct { + Config *config.Config + Args *commandargs.AuthorizedPrincipals + ReadWriter *readwriter.ReadWriter +} + +func (c *Command) Execute() error { + if err := c.printPrincipalLines(); err != nil { + return err + } + + return nil +} + +func (c *Command) printPrincipalLines() error { + principals := c.Args.Principals + + for _, principal := range principals { + if err := c.printPrincipalLine(principal); err != nil { + return err + } + } + + return nil +} + +func (c *Command) printPrincipalLine(principal string) error { + principalKeyLine, err := keyline.NewPrincipalKeyLine(c.Args.KeyId, principal, c.Config.RootDir) + if err != nil { + return err + } + + fmt.Fprintln(c.ReadWriter.Out, principalKeyLine.ToString()) + + return nil +} diff --git a/internal/command/authorizedprincipals/authorized_principals_test.go b/internal/command/authorizedprincipals/authorized_principals_test.go new file mode 100644 index 0000000..f0334e5 --- /dev/null +++ b/internal/command/authorizedprincipals/authorized_principals_test.go @@ -0,0 +1,47 @@ +package authorizedprincipals + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" +) + +func TestExecute(t *testing.T) { + testCases := []struct { + desc string + arguments *commandargs.AuthorizedPrincipals + expectedOutput string + }{ + { + desc: "With single principal", + arguments: &commandargs.AuthorizedPrincipals{KeyId: "key", Principals: []string{"principal"}}, + expectedOutput: "command=\"/tmp/bin/gitlab-shell username-key\",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty principal\n", + }, + { + desc: "With multiple principals", + arguments: &commandargs.AuthorizedPrincipals{KeyId: "key", Principals: []string{"principal-1", "principal-2"}}, + expectedOutput: "command=\"/tmp/bin/gitlab-shell username-key\",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty principal-1\ncommand=\"/tmp/bin/gitlab-shell username-key\",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty principal-2\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + buffer := &bytes.Buffer{} + cmd := &Command{ + Config: &config.Config{RootDir: "/tmp"}, + Args: tc.arguments, + ReadWriter: &readwriter.ReadWriter{Out: buffer}, + } + + err := cmd.Execute() + + require.NoError(t, err) + require.Equal(t, tc.expectedOutput, buffer.String()) + }) + } +} diff --git a/internal/command/command.go b/internal/command/command.go new file mode 100644 index 0000000..af63862 --- /dev/null +++ b/internal/command/command.go @@ -0,0 +1,81 @@ +package command + +import ( + "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/executable" +) + +type Command interface { + Execute() error +} + +func New(e *executable.Executable, arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) { + args, err := commandargs.Parse(e, arguments) + if err != nil { + return nil, err + } + + if cmd := buildCommand(e, args, config, readWriter); cmd != nil { + return cmd, nil + } + + return nil, disallowedcommand.Error +} + +func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command { + switch e.Name { + case executable.GitlabShell: + return buildShellCommand(args.(*commandargs.Shell), config, readWriter) + case executable.AuthorizedKeysCheck: + return buildAuthorizedKeysCommand(args.(*commandargs.AuthorizedKeys), config, readWriter) + case executable.AuthorizedPrincipalsCheck: + return buildAuthorizedPrincipalsCommand(args.(*commandargs.AuthorizedPrincipals), config, readWriter) + case executable.Healthcheck: + return buildHealthcheckCommand(config, readWriter) + } + + return nil +} + +func buildShellCommand(args *commandargs.Shell, config *config.Config, readWriter *readwriter.ReadWriter) Command { + switch args.CommandType { + case commandargs.Discover: + return &discover.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.TwoFactorRecover: + return &twofactorrecover.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.LfsAuthenticate: + return &lfsauthenticate.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.ReceivePack: + return &receivepack.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.UploadPack: + return &uploadpack.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.UploadArchive: + return &uploadarchive.Command{Config: config, Args: args, ReadWriter: readWriter} + } + + return nil +} + +func buildAuthorizedKeysCommand(args *commandargs.AuthorizedKeys, config *config.Config, readWriter *readwriter.ReadWriter) Command { + return &authorizedkeys.Command{Config: config, Args: args, ReadWriter: readWriter} +} + +func buildAuthorizedPrincipalsCommand(args *commandargs.AuthorizedPrincipals, config *config.Config, readWriter *readwriter.ReadWriter) Command { + return &authorizedprincipals.Command{Config: config, Args: args, ReadWriter: readWriter} +} + +func buildHealthcheckCommand(config *config.Config, readWriter *readwriter.ReadWriter) Command { + return &healthcheck.Command{Config: config, ReadWriter: readWriter} +} diff --git a/internal/command/command_test.go b/internal/command/command_test.go new file mode 100644 index 0000000..2ca319e --- /dev/null +++ b/internal/command/command_test.go @@ -0,0 +1,146 @@ +package command + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/executable" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +var ( + authorizedKeysExec = &executable.Executable{Name: executable.AuthorizedKeysCheck} + authorizedPrincipalsExec = &executable.Executable{Name: executable.AuthorizedPrincipalsCheck} + checkExec = &executable.Executable{Name: executable.Healthcheck} + gitlabShellExec = &executable.Executable{Name: executable.GitlabShell} + + basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"} +) + +func buildEnv(command string) map[string]string { + return map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": command, + } +} + +func TestNew(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + environment map[string]string + arguments []string + expectedType interface{} + }{ + { + desc: "it returns a Discover command", + executable: gitlabShellExec, + environment: buildEnv(""), + expectedType: &discover.Command{}, + }, + { + desc: "it returns a TwoFactorRecover command", + executable: gitlabShellExec, + environment: buildEnv("2fa_recovery_codes"), + expectedType: &twofactorrecover.Command{}, + }, + { + desc: "it returns an LfsAuthenticate command", + executable: gitlabShellExec, + environment: buildEnv("git-lfs-authenticate"), + expectedType: &lfsauthenticate.Command{}, + }, + { + desc: "it returns a ReceivePack command", + executable: gitlabShellExec, + environment: buildEnv("git-receive-pack"), + expectedType: &receivepack.Command{}, + }, + { + desc: "it returns an UploadPack command", + executable: gitlabShellExec, + environment: buildEnv("git-upload-pack"), + expectedType: &uploadpack.Command{}, + }, + { + desc: "it returns an UploadArchive command", + executable: gitlabShellExec, + environment: buildEnv("git-upload-archive"), + expectedType: &uploadarchive.Command{}, + }, + { + desc: "it returns a Healthcheck command", + executable: checkExec, + expectedType: &healthcheck.Command{}, + }, + { + desc: "it returns a AuthorizedKeys command", + executable: authorizedKeysExec, + arguments: []string{"git", "git", "key"}, + expectedType: &authorizedkeys.Command{}, + }, + { + desc: "it returns a AuthorizedPrincipals command", + executable: authorizedPrincipalsExec, + arguments: []string{"key", "principal"}, + expectedType: &authorizedprincipals.Command{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + restoreEnv := testhelper.TempEnv(tc.environment) + defer restoreEnv() + + command, err := New(tc.executable, tc.arguments, basicConfig, nil) + + require.NoError(t, err) + require.IsType(t, tc.expectedType, command) + }) + } +} + +func TestFailingNew(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + environment map[string]string + expectedError error + }{ + { + desc: "Parsing environment failed", + executable: gitlabShellExec, + expectedError: errors.New("Only SSH allowed"), + }, + { + desc: "Unknown command given", + executable: gitlabShellExec, + environment: buildEnv("unknown"), + expectedError: disallowedcommand.Error, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + restoreEnv := testhelper.TempEnv(tc.environment) + defer restoreEnv() + + command, err := New(tc.executable, []string{}, basicConfig, nil) + require.Nil(t, command) + require.Equal(t, tc.expectedError, err) + }) + } +} diff --git a/internal/command/commandargs/authorized_keys.go b/internal/command/commandargs/authorized_keys.go new file mode 100644 index 0000000..2733954 --- /dev/null +++ b/internal/command/commandargs/authorized_keys.go @@ -0,0 +1,51 @@ +package commandargs + +import ( + "errors" + "fmt" +) + +type AuthorizedKeys struct { + Arguments []string + ExpectedUser string + ActualUser string + Key string +} + +func (ak *AuthorizedKeys) Parse() error { + if err := ak.validate(); err != nil { + return err + } + + ak.ExpectedUser = ak.Arguments[0] + ak.ActualUser = ak.Arguments[1] + ak.Key = ak.Arguments[2] + + return nil +} + +func (ak *AuthorizedKeys) GetArguments() []string { + return ak.Arguments +} + +func (ak *AuthorizedKeys) validate() error { + argsSize := len(ak.Arguments) + + if argsSize != 3 { + return errors.New(fmt.Sprintf("# Insufficient arguments. %d. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>", argsSize)) + } + + expectedUsername := ak.Arguments[0] + actualUsername := ak.Arguments[1] + key := ak.Arguments[2] + + if expectedUsername == "" || actualUsername == "" { + return errors.New("# No username provided") + } + + if key == "" { + return errors.New("# No key provided") + } + + return nil +} diff --git a/internal/command/commandargs/authorized_principals.go b/internal/command/commandargs/authorized_principals.go new file mode 100644 index 0000000..746ae3f --- /dev/null +++ b/internal/command/commandargs/authorized_principals.go @@ -0,0 +1,50 @@ +package commandargs + +import ( + "errors" + "fmt" +) + +type AuthorizedPrincipals struct { + Arguments []string + KeyId string + Principals []string +} + +func (ap *AuthorizedPrincipals) Parse() error { + if err := ap.validate(); err != nil { + return err + } + + ap.KeyId = ap.Arguments[0] + ap.Principals = ap.Arguments[1:] + + return nil +} + +func (ap *AuthorizedPrincipals) GetArguments() []string { + return ap.Arguments +} + +func (ap *AuthorizedPrincipals) validate() error { + argsSize := len(ap.Arguments) + + if argsSize < 2 { + return errors.New(fmt.Sprintf("# Insufficient arguments. %d. Usage\n#\tgitlab-shell-authorized-principals-check <key-id> <principal1> [<principal2>...]", argsSize)) + } + + keyId := ap.Arguments[0] + principals := ap.Arguments[1:] + + if keyId == "" { + return errors.New("# No key_id provided") + } + + for _, principal := range principals { + if principal == "" { + return errors.New("# An invalid principal was provided") + } + } + + return nil +} diff --git a/internal/command/commandargs/command_args.go b/internal/command/commandargs/command_args.go new file mode 100644 index 0000000..b4bf334 --- /dev/null +++ b/internal/command/commandargs/command_args.go @@ -0,0 +1,31 @@ +package commandargs + +import ( + "gitlab.com/gitlab-org/gitlab-shell/internal/executable" +) + +type CommandType string + +type CommandArgs interface { + Parse() error + GetArguments() []string +} + +func Parse(e *executable.Executable, arguments []string) (CommandArgs, error) { + var args CommandArgs = &GenericArgs{Arguments: arguments} + + switch e.Name { + case executable.GitlabShell: + args = &Shell{Arguments: arguments} + case executable.AuthorizedKeysCheck: + args = &AuthorizedKeys{Arguments: arguments} + case executable.AuthorizedPrincipalsCheck: + args = &AuthorizedPrincipals{Arguments: arguments} + } + + if err := args.Parse(); err != nil { + return nil, err + } + + return args, nil +} diff --git a/internal/command/commandargs/command_args_test.go b/internal/command/commandargs/command_args_test.go new file mode 100644 index 0000000..aa74237 --- /dev/null +++ b/internal/command/commandargs/command_args_test.go @@ -0,0 +1,231 @@ +package commandargs + +import ( + "testing" + + "gitlab.com/gitlab-org/gitlab-shell/internal/executable" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" + + "github.com/stretchr/testify/require" +) + +func TestParseSuccess(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + environment map[string]string + arguments []string + expectedArgs CommandArgs + }{ + // Setting the used env variables for every case to ensure we're + // not using anything set in the original env. + { + desc: "It sets discover as the command when the command string was empty", + executable: &executable.Executable{Name: executable.GitlabShell}, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "", + }, + arguments: []string{}, + expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{}, CommandType: Discover}, + }, + { + desc: "It finds the key id in any passed arguments", + executable: &executable.Executable{Name: executable.GitlabShell}, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "", + }, + arguments: []string{"hello", "key-123"}, + expectedArgs: &Shell{Arguments: []string{"hello", "key-123"}, SshArgs: []string{}, CommandType: Discover, GitlabKeyId: "123"}, + }, { + desc: "It finds the username in any passed arguments", + executable: &executable.Executable{Name: executable.GitlabShell}, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "", + }, + arguments: []string{"hello", "username-jane-doe"}, + expectedArgs: &Shell{Arguments: []string{"hello", "username-jane-doe"}, SshArgs: []string{}, CommandType: Discover, GitlabUsername: "jane-doe"}, + }, { + desc: "It parses 2fa_recovery_codes command", + executable: &executable.Executable{Name: executable.GitlabShell}, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "2fa_recovery_codes", + }, + arguments: []string{}, + expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"2fa_recovery_codes"}, CommandType: TwoFactorRecover}, + }, { + desc: "It parses git-receive-pack command", + executable: &executable.Executable{Name: executable.GitlabShell}, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "git-receive-pack group/repo", + }, + arguments: []string{}, + expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack}, + }, { + desc: "It parses git-receive-pack command and a project with single quotes", + executable: &executable.Executable{Name: executable.GitlabShell}, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "git receive-pack 'group/repo'", + }, + arguments: []string{}, + expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack}, + }, { + desc: `It parses "git receive-pack" command`, + executable: &executable.Executable{Name: executable.GitlabShell}, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": `git receive-pack "group/repo"`, + }, + arguments: []string{}, + expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack}, + }, { + desc: `It parses a command followed by control characters`, + executable: &executable.Executable{Name: executable.GitlabShell}, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": `git-receive-pack group/repo; any command`, + }, + arguments: []string{}, + expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack}, + }, { + desc: "It parses git-upload-pack command", + executable: &executable.Executable{Name: executable.GitlabShell}, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": `git upload-pack "group/repo"`, + }, + arguments: []string{}, + expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-upload-pack", "group/repo"}, CommandType: UploadPack}, + }, { + desc: "It parses git-upload-archive command", + executable: &executable.Executable{Name: executable.GitlabShell}, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "git-upload-archive 'group/repo'", + }, + arguments: []string{}, + expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-upload-archive", "group/repo"}, CommandType: UploadArchive}, + }, { + desc: "It parses git-lfs-authenticate command", + executable: &executable.Executable{Name: executable.GitlabShell}, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "git-lfs-authenticate 'group/repo' download", + }, + arguments: []string{}, + expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-lfs-authenticate", "group/repo", "download"}, CommandType: LfsAuthenticate}, + }, { + desc: "It parses authorized-keys command", + executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, + arguments: []string{"git", "git", "key"}, + expectedArgs: &AuthorizedKeys{Arguments: []string{"git", "git", "key"}, ExpectedUser: "git", ActualUser: "git", Key: "key"}, + }, { + desc: "It parses authorized-principals command", + executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}, + arguments: []string{"key", "principal-1", "principal-2"}, + expectedArgs: &AuthorizedPrincipals{Arguments: []string{"key", "principal-1", "principal-2"}, KeyId: "key", Principals: []string{"principal-1", "principal-2"}}, + }, { + desc: "Unknown executable", + executable: &executable.Executable{Name: "unknown"}, + arguments: []string{}, + expectedArgs: &GenericArgs{Arguments: []string{}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + restoreEnv := testhelper.TempEnv(tc.environment) + defer restoreEnv() + + result, err := Parse(tc.executable, tc.arguments) + + require.NoError(t, err) + require.Equal(t, tc.expectedArgs, result) + }) + } +} + +func TestParseFailure(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + environment map[string]string + arguments []string + expectedError string + }{ + { + desc: "It fails if SSH connection is not set", + executable: &executable.Executable{Name: executable.GitlabShell}, + arguments: []string{}, + expectedError: "Only SSH allowed", + }, + { + desc: "It fails if SSH command is invalid", + executable: &executable.Executable{Name: executable.GitlabShell}, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": `git receive-pack "`, + }, + arguments: []string{}, + expectedError: "Invalid SSH command", + }, + { + desc: "With not enough arguments for the AuthorizedKeysCheck", + executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, + arguments: []string{"user"}, + expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>", + }, + { + desc: "With too many arguments for the AuthorizedKeysCheck", + executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, + arguments: []string{"user", "user", "key", "something-else"}, + expectedError: "# Insufficient arguments. 4. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>", + }, + { + desc: "With missing username for the AuthorizedKeysCheck", + executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, + arguments: []string{"user", "", "key"}, + expectedError: "# No username provided", + }, + { + desc: "With missing key for the AuthorizedKeysCheck", + executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, + arguments: []string{"user", "user", ""}, + expectedError: "# No key provided", + }, + { + desc: "With not enough arguments for the AuthorizedPrincipalsCheck", + executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}, + arguments: []string{"key"}, + expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-principals-check <key-id> <principal1> [<principal2>...]", + }, + { + desc: "With missing key_id for the AuthorizedPrincipalsCheck", + executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}, + arguments: []string{"", "principal"}, + expectedError: "# No key_id provided", + }, + { + desc: "With blank principal for the AuthorizedPrincipalsCheck", + executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}, + arguments: []string{"key", "principal", ""}, + expectedError: "# An invalid principal was provided", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + restoreEnv := testhelper.TempEnv(tc.environment) + defer restoreEnv() + + _, err := Parse(tc.executable, tc.arguments) + + require.EqualError(t, err, tc.expectedError) + }) + } +} diff --git a/internal/command/commandargs/generic_args.go b/internal/command/commandargs/generic_args.go new file mode 100644 index 0000000..96bed99 --- /dev/null +++ b/internal/command/commandargs/generic_args.go @@ -0,0 +1,14 @@ +package commandargs + +type GenericArgs struct { + Arguments []string +} + +func (b *GenericArgs) Parse() error { + // Do nothing + return nil +} + +func (b *GenericArgs) GetArguments() []string { + return b.Arguments +} diff --git a/internal/command/commandargs/shell.go b/internal/command/commandargs/shell.go new file mode 100644 index 0000000..7e2b72e --- /dev/null +++ b/internal/command/commandargs/shell.go @@ -0,0 +1,131 @@ +package commandargs + +import ( + "errors" + "os" + "regexp" + + "github.com/mattn/go-shellwords" +) + +const ( + Discover CommandType = "discover" + TwoFactorRecover CommandType = "2fa_recovery_codes" + LfsAuthenticate CommandType = "git-lfs-authenticate" + ReceivePack CommandType = "git-receive-pack" + UploadPack CommandType = "git-upload-pack" + UploadArchive CommandType = "git-upload-archive" +) + +var ( + whoKeyRegex = regexp.MustCompile(`\bkey-(?P<keyid>\d+)\b`) + whoUsernameRegex = regexp.MustCompile(`\busername-(?P<username>\S+)\b`) +) + +type Shell struct { + Arguments []string + GitlabUsername string + GitlabKeyId string + SshArgs []string + CommandType CommandType +} + +func (s *Shell) Parse() error { + if err := s.validate(); err != nil { + return err + } + + s.parseWho() + s.defineCommandType() + + return nil +} + +func (s *Shell) GetArguments() []string { + return s.Arguments +} + +func (s *Shell) validate() error { + if !s.isSshConnection() { + return errors.New("Only SSH allowed") + } + + if !s.isValidSshCommand() { + return errors.New("Invalid SSH command") + } + + return nil +} + +func (s *Shell) isSshConnection() bool { + ok := os.Getenv("SSH_CONNECTION") + return ok != "" +} + +func (s *Shell) isValidSshCommand() bool { + err := s.parseCommand(os.Getenv("SSH_ORIGINAL_COMMAND")) + return err == nil +} + +func (s *Shell) parseWho() { + for _, argument := range s.Arguments { + if keyId := tryParseKeyId(argument); keyId != "" { + s.GitlabKeyId = keyId + break + } + + if username := tryParseUsername(argument); username != "" { + s.GitlabUsername = username + break + } + } +} + +func tryParseKeyId(argument string) string { + matchInfo := whoKeyRegex.FindStringSubmatch(argument) + if len(matchInfo) == 2 { + // The first element is the full matched string + // The second element is the named `keyid` + return matchInfo[1] + } + + return "" +} + +func tryParseUsername(argument string) string { + matchInfo := whoUsernameRegex.FindStringSubmatch(argument) + if len(matchInfo) == 2 { + // The first element is the full matched string + // The second element is the named `username` + return matchInfo[1] + } + + return "" +} + +func (s *Shell) parseCommand(commandString string) error { + args, err := shellwords.Parse(commandString) + if err != nil { + return err + } + + // Handle Git for Windows 2.14 using "git upload-pack" instead of git-upload-pack + if len(args) > 1 && args[0] == "git" { + command := args[0] + "-" + args[1] + commandArgs := args[2:] + + args = append([]string{command}, commandArgs...) + } + + s.SshArgs = args + + return nil +} + +func (s *Shell) defineCommandType() { + if len(s.SshArgs) == 0 { + s.CommandType = Discover + } else { + s.CommandType = CommandType(s.SshArgs[0]) + } +} diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go new file mode 100644 index 0000000..3aa7456 --- /dev/null +++ b/internal/command/discover/discover.go @@ -0,0 +1,40 @@ +package discover + +import ( + "fmt" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/discover" +) + +type Command struct { + Config *config.Config + Args *commandargs.Shell + ReadWriter *readwriter.ReadWriter +} + +func (c *Command) Execute() error { + response, err := c.getUserInfo() + if err != nil { + return fmt.Errorf("Failed to get username: %v", err) + } + + if response.IsAnonymous() { + fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, Anonymous!\n") + } else { + fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, @%s!\n", response.Username) + } + + return nil +} + +func (c *Command) getUserInfo() (*discover.Response, error) { + client, err := discover.NewClient(c.Config) + if err != nil { + return nil, err + } + + return client.GetByCommandArgs(c.Args) +} diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go new file mode 100644 index 0000000..3878286 --- /dev/null +++ b/internal/command/discover/discover_test.go @@ -0,0 +1,135 @@ +package discover + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" +) + +var ( + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("key_id") == "1" || r.URL.Query().Get("username") == "alex-doe" { + body := map[string]interface{}{ + "id": 2, + "username": "alex-doe", + "name": "Alex Doe", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "broken_message" { + body := map[string]string{ + "message": "Forbidden!", + } + w.WriteHeader(http.StatusForbidden) + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "broken" { + w.WriteHeader(http.StatusInternalServerError) + } else { + fmt.Fprint(w, "null") + } + }, + }, + } +) + +func TestExecute(t *testing.T) { + url, cleanup := testserver.StartSocketHttpServer(t, requests) + defer cleanup() + + testCases := []struct { + desc string + arguments *commandargs.Shell + expectedOutput string + }{ + { + desc: "With a known username", + arguments: &commandargs.Shell{GitlabUsername: "alex-doe"}, + expectedOutput: "Welcome to GitLab, @alex-doe!\n", + }, + { + desc: "With a known key id", + arguments: &commandargs.Shell{GitlabKeyId: "1"}, + expectedOutput: "Welcome to GitLab, @alex-doe!\n", + }, + { + desc: "With an unknown key", + arguments: &commandargs.Shell{GitlabKeyId: "-1"}, + expectedOutput: "Welcome to GitLab, Anonymous!\n", + }, + { + desc: "With an unknown username", + arguments: &commandargs.Shell{GitlabUsername: "unknown"}, + expectedOutput: "Welcome to GitLab, Anonymous!\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + buffer := &bytes.Buffer{} + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: tc.arguments, + ReadWriter: &readwriter.ReadWriter{Out: buffer}, + } + + err := cmd.Execute() + + require.NoError(t, err) + require.Equal(t, tc.expectedOutput, buffer.String()) + }) + } +} + +func TestFailingExecute(t *testing.T) { + url, cleanup := testserver.StartSocketHttpServer(t, requests) + defer cleanup() + + testCases := []struct { + desc string + arguments *commandargs.Shell + expectedError string + }{ + { + desc: "With missing arguments", + arguments: &commandargs.Shell{}, + expectedError: "Failed to get username: who='' is invalid", + }, + { + desc: "When the API returns an error", + arguments: &commandargs.Shell{GitlabUsername: "broken_message"}, + expectedError: "Failed to get username: Forbidden!", + }, + { + desc: "When the API fails", + arguments: &commandargs.Shell{GitlabUsername: "broken"}, + expectedError: "Failed to get username: Internal API error (500)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + buffer := &bytes.Buffer{} + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: tc.arguments, + ReadWriter: &readwriter.ReadWriter{Out: buffer}, + } + + err := cmd.Execute() + + require.Empty(t, buffer.String()) + require.EqualError(t, err, tc.expectedError) + }) + } +} diff --git a/internal/command/healthcheck/healthcheck.go b/internal/command/healthcheck/healthcheck.go new file mode 100644 index 0000000..bbc73dc --- /dev/null +++ b/internal/command/healthcheck/healthcheck.go @@ -0,0 +1,49 @@ +package healthcheck + +import ( + "fmt" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/healthcheck" +) + +var ( + apiMessage = "Internal API available" + redisMessage = "Redis available via internal API" +) + +type Command struct { + Config *config.Config + ReadWriter *readwriter.ReadWriter +} + +func (c *Command) Execute() error { + response, err := c.runCheck() + if err != nil { + return fmt.Errorf("%v: FAILED - %v", apiMessage, err) + } + + fmt.Fprintf(c.ReadWriter.Out, "%v: OK\n", apiMessage) + + if !response.Redis { + return fmt.Errorf("%v: FAILED", redisMessage) + } + + fmt.Fprintf(c.ReadWriter.Out, "%v: OK\n", redisMessage) + return nil +} + +func (c *Command) runCheck() (*healthcheck.Response, error) { + client, err := healthcheck.NewClient(c.Config) + if err != nil { + return nil, err + } + + response, err := client.Check() + if err != nil { + return nil, err + } + + return response, nil +} diff --git a/internal/command/healthcheck/healthcheck_test.go b/internal/command/healthcheck/healthcheck_test.go new file mode 100644 index 0000000..e59c5a2 --- /dev/null +++ b/internal/command/healthcheck/healthcheck_test.go @@ -0,0 +1,90 @@ +package healthcheck + +import ( + "bytes" + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/healthcheck" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" +) + +var ( + okResponse = &healthcheck.Response{ + APIVersion: "v4", + GitlabVersion: "v12.0.0-ee", + GitlabRevision: "3b13818e8330f68625d80d9bf5d8049c41fbe197", + Redis: true, + } + + badRedisResponse = &healthcheck.Response{Redis: false} + + okHandlers = buildTestHandlers(200, okResponse) + badRedisHandlers = buildTestHandlers(200, badRedisResponse) + brokenHandlers = buildTestHandlers(500, nil) +) + +func buildTestHandlers(code int, rsp *healthcheck.Response) []testserver.TestRequestHandler { + return []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/check", + Handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(code) + if rsp != nil { + json.NewEncoder(w).Encode(rsp) + } + }, + }, + } +} + +func TestExecute(t *testing.T) { + url, cleanup := testserver.StartSocketHttpServer(t, okHandlers) + defer cleanup() + + buffer := &bytes.Buffer{} + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + ReadWriter: &readwriter.ReadWriter{Out: buffer}, + } + + err := cmd.Execute() + + require.NoError(t, err) + require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String()) +} + +func TestFailingRedisExecute(t *testing.T) { + url, cleanup := testserver.StartSocketHttpServer(t, badRedisHandlers) + defer cleanup() + + buffer := &bytes.Buffer{} + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + ReadWriter: &readwriter.ReadWriter{Out: buffer}, + } + + err := cmd.Execute() + require.Error(t, err, "Redis available via internal API: FAILED") + require.Equal(t, "Internal API available: OK\n", buffer.String()) +} + +func TestFailingAPIExecute(t *testing.T) { + url, cleanup := testserver.StartSocketHttpServer(t, brokenHandlers) + defer cleanup() + + buffer := &bytes.Buffer{} + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + ReadWriter: &readwriter.ReadWriter{Out: buffer}, + } + + err := cmd.Execute() + require.Empty(t, buffer.String()) + require.EqualError(t, err, "Internal API available: FAILED - Internal API error (500)") +} diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go new file mode 100644 index 0000000..1b2a742 --- /dev/null +++ b/internal/command/lfsauthenticate/lfsauthenticate.go @@ -0,0 +1,104 @@ +package lfsauthenticate + +import ( + "encoding/base64" + "encoding/json" + "fmt" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/lfsauthenticate" +) + +const ( + downloadAction = "download" + uploadAction = "upload" +) + +type Command struct { + Config *config.Config + Args *commandargs.Shell + ReadWriter *readwriter.ReadWriter +} + +type PayloadHeader struct { + Auth string `json:"Authorization"` +} + +type Payload struct { + Header PayloadHeader `json:"header"` + Href string `json:"href"` + ExpiresIn int `json:"expires_in,omitempty"` +} + +func (c *Command) Execute() error { + args := c.Args.SshArgs + if len(args) < 3 { + return disallowedcommand.Error + } + + repo := args[1] + action, err := actionToCommandType(args[2]) + if err != nil { + return err + } + + accessResponse, err := c.verifyAccess(action, repo) + if err != nil { + return err + } + + payload, err := c.authenticate(action, repo, accessResponse.UserId) + if err != nil { + // return nothing just like Ruby's GitlabShell#lfs_authenticate does + return nil + } + + fmt.Fprintf(c.ReadWriter.Out, "%s\n", payload) + + return nil +} + +func actionToCommandType(action string) (commandargs.CommandType, error) { + var accessAction commandargs.CommandType + switch action { + case downloadAction: + accessAction = commandargs.UploadPack + case uploadAction: + accessAction = commandargs.ReceivePack + default: + return "", disallowedcommand.Error + } + + return accessAction, nil +} + +func (c *Command) verifyAccess(action commandargs.CommandType, repo string) (*accessverifier.Response, error) { + cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} + + return cmd.Verify(action, repo) +} + +func (c *Command) authenticate(action commandargs.CommandType, repo, userId string) ([]byte, error) { + client, err := lfsauthenticate.NewClient(c.Config, c.Args) + if err != nil { + return nil, err + } + + response, err := client.Authenticate(action, repo, userId) + if err != nil { + return nil, err + } + + basicAuth := base64.StdEncoding.EncodeToString([]byte(response.Username + ":" + response.LfsToken)) + payload := &Payload{ + Header: PayloadHeader{Auth: "Basic " + basicAuth}, + Href: response.RepoPath + "/info/lfs", + ExpiresIn: response.ExpiresIn, + } + + return json.Marshal(payload) +} diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go new file mode 100644 index 0000000..f2ccc20 --- /dev/null +++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go @@ -0,0 +1,153 @@ +package lfsauthenticate + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/lfsauthenticate" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers" +) + +func TestFailedRequests(t *testing.T) { + requests := requesthandlers.BuildDisallowedByApiHandlers(t) + url, cleanup := testserver.StartHttpServer(t, requests) + defer cleanup() + + testCases := []struct { + desc string + arguments *commandargs.Shell + expectedOutput string + }{ + { + desc: "With missing arguments", + arguments: &commandargs.Shell{}, + expectedOutput: "> GitLab: Disallowed command", + }, + { + desc: "With disallowed command", + arguments: &commandargs.Shell{GitlabKeyId: "1", SshArgs: []string{"git-lfs-authenticate", "group/repo", "unknown"}}, + expectedOutput: "> GitLab: Disallowed command", + }, + { + desc: "With disallowed user", + arguments: &commandargs.Shell{GitlabKeyId: "disallowed", SshArgs: []string{"git-lfs-authenticate", "group/repo", "download"}}, + expectedOutput: "Disallowed by API call", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + output := &bytes.Buffer{} + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: tc.arguments, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, + } + + err := cmd.Execute() + require.Error(t, err) + + require.Equal(t, tc.expectedOutput, err.Error()) + }) + } +} + +func TestLfsAuthenticateRequests(t *testing.T) { + userId := "123" + + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/lfs_authenticate", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + require.NoError(t, err) + + var request *lfsauthenticate.Request + require.NoError(t, json.Unmarshal(b, &request)) + + if request.UserId == userId { + body := map[string]interface{}{ + "username": "john", + "lfs_token": "sometoken", + "repository_http_path": "https://gitlab.com/repo/path", + "expires_in": 1800, + } + require.NoError(t, json.NewEncoder(w).Encode(body)) + } else { + w.WriteHeader(http.StatusForbidden) + } + }, + }, + { + Path: "/api/v4/internal/allowed", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + require.NoError(t, err) + + var request *accessverifier.Request + require.NoError(t, json.Unmarshal(b, &request)) + + var glId string + if request.Username == "somename" { + glId = userId + } else { + glId = "100" + } + + body := map[string]interface{}{ + "gl_id": glId, + "status": true, + } + require.NoError(t, json.NewEncoder(w).Encode(body)) + }, + }, + } + + url, cleanup := testserver.StartHttpServer(t, requests) + defer cleanup() + + testCases := []struct { + desc string + username string + expectedOutput string + }{ + { + desc: "With successful response from API", + username: "somename", + expectedOutput: "{\"header\":{\"Authorization\":\"Basic am9objpzb21ldG9rZW4=\"},\"href\":\"https://gitlab.com/repo/path/info/lfs\",\"expires_in\":1800}\n", + }, + { + desc: "With forbidden response from API", + username: "anothername", + expectedOutput: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + output := &bytes.Buffer{} + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: &commandargs.Shell{GitlabUsername: tc.username, SshArgs: []string{"git-lfs-authenticate", "group/repo", "upload"}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, + } + + err := cmd.Execute() + require.NoError(t, err) + + require.Equal(t, tc.expectedOutput, output.String()) + }) + } +} diff --git a/internal/command/readwriter/readwriter.go b/internal/command/readwriter/readwriter.go new file mode 100644 index 0000000..da18d30 --- /dev/null +++ b/internal/command/readwriter/readwriter.go @@ -0,0 +1,9 @@ +package readwriter + +import "io" + +type ReadWriter struct { + Out io.Writer + In io.Reader + ErrOut io.Writer +} diff --git a/internal/command/receivepack/customaction.go b/internal/command/receivepack/customaction.go new file mode 100644 index 0000000..c94ae4c --- /dev/null +++ b/internal/command/receivepack/customaction.go @@ -0,0 +1,99 @@ +package receivepack + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "strings" + + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier" +) + +type Request struct { + SecretToken []byte `json:"secret_token"` + Data accessverifier.CustomPayloadData `json:"data"` + Output []byte `json:"output"` +} + +type Response struct { + Result []byte `json:"result"` + Message string `json:"message"` +} + +func (c *Command) processCustomAction(response *accessverifier.Response) error { + data := response.Payload.Data + apiEndpoints := data.ApiEndpoints + + if len(apiEndpoints) == 0 { + return errors.New("Custom action error: Empty API endpoints") + } + + c.displayInfoMessage(data.InfoMessage) + + return c.processApiEndpoints(response) +} + +func (c *Command) displayInfoMessage(infoMessage string) { + messages := strings.Split(infoMessage, "\n") + + for _, msg := range messages { + fmt.Fprintf(c.ReadWriter.ErrOut, "> GitLab: %v\n", msg) + } +} + +func (c *Command) processApiEndpoints(response *accessverifier.Response) error { + client, err := gitlabnet.GetClient(c.Config) + + if err != nil { + return err + } + + data := response.Payload.Data + request := &Request{Data: data} + request.Data.UserId = response.Who + + for _, endpoint := range data.ApiEndpoints { + response, err := c.performRequest(client, endpoint, request) + if err != nil { + return err + } + + if err = c.displayResult(response.Result); err != nil { + return err + } + + // In the context of the git push sequence of events, it's necessary to read + // stdin in order to capture output to pass onto subsequent commands + output, err := ioutil.ReadAll(c.ReadWriter.In) + if err != nil { + return err + } + request.Output = output + } + + return nil +} + +func (c *Command) performRequest(client *gitlabnet.GitlabClient, endpoint string, request *Request) (*Response, error) { + response, err := client.DoRequest(http.MethodPost, endpoint, request) + if err != nil { + return nil, err + } + defer response.Body.Close() + + cr := &Response{} + if err := gitlabnet.ParseJSON(response, cr); err != nil { + return nil, err + } + + return cr, nil +} + +func (c *Command) displayResult(result []byte) error { + _, err := io.Copy(c.ReadWriter.Out, bytes.NewReader(result)) + return err +} diff --git a/internal/command/receivepack/customaction_test.go b/internal/command/receivepack/customaction_test.go new file mode 100644 index 0000000..2a4a718 --- /dev/null +++ b/internal/command/receivepack/customaction_test.go @@ -0,0 +1,105 @@ +package receivepack + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" +) + +func TestCustomReceivePack(t *testing.T) { + repo := "group/repo" + keyId := "1" + + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/allowed", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var request *accessverifier.Request + require.NoError(t, json.Unmarshal(b, &request)) + + require.Equal(t, "1", request.KeyId) + + body := map[string]interface{}{ + "status": true, + "gl_id": "1", + "payload": map[string]interface{}{ + "action": "geo_proxy_to_primary", + "data": map[string]interface{}{ + "api_endpoints": []string{"/geo/proxy_git_push_ssh/info_refs", "/geo/proxy_git_push_ssh/push"}, + "gl_username": "custom", + "primary_repo": "https://repo/path", + "info_message": "info_message\none more message", + }, + }, + } + w.WriteHeader(http.StatusMultipleChoices) + require.NoError(t, json.NewEncoder(w).Encode(body)) + }, + }, + { + Path: "/geo/proxy_git_push_ssh/info_refs", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var request *Request + require.NoError(t, json.Unmarshal(b, &request)) + + require.Equal(t, request.Data.UserId, "key-"+keyId) + require.Empty(t, request.Output) + + err = json.NewEncoder(w).Encode(Response{Result: []byte("custom")}) + require.NoError(t, err) + }, + }, + { + Path: "/geo/proxy_git_push_ssh/push", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var request *Request + require.NoError(t, json.Unmarshal(b, &request)) + + require.Equal(t, request.Data.UserId, "key-"+keyId) + require.Equal(t, "input", string(request.Output)) + + err = json.NewEncoder(w).Encode(Response{Result: []byte("output")}) + require.NoError(t, err) + }, + }, + } + + url, cleanup := testserver.StartSocketHttpServer(t, requests) + defer cleanup() + + outBuf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + input := bytes.NewBufferString("input") + + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: &commandargs.Shell{GitlabKeyId: keyId, CommandType: commandargs.ReceivePack, SshArgs: []string{"git-receive-pack", repo}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: errBuf, Out: outBuf, In: input}, + } + + require.NoError(t, cmd.Execute()) + + // expect printing of info message, "custom" string from the first request + // and "output" string from the second request + require.Equal(t, "> GitLab: info_message\n> GitLab: one more message\n", errBuf.String()) + require.Equal(t, "customoutput", outBuf.String()) +} diff --git a/internal/command/receivepack/gitalycall.go b/internal/command/receivepack/gitalycall.go new file mode 100644 index 0000000..f440672 --- /dev/null +++ b/internal/command/receivepack/gitalycall.go @@ -0,0 +1,39 @@ +package receivepack + +import ( + "context" + + "google.golang.org/grpc" + + "gitlab.com/gitlab-org/gitaly/client" + pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/internal/handler" +) + +func (c *Command) performGitalyCall(response *accessverifier.Response) error { + gc := &handler.GitalyCommand{ + Config: c.Config, + ServiceName: string(commandargs.ReceivePack), + Address: response.Gitaly.Address, + Token: response.Gitaly.Token, + } + + request := &pb.SSHReceivePackRequest{ + Repository: &response.Gitaly.Repo, + GlId: response.UserId, + GlRepository: response.Repo, + GlUsername: response.Username, + GitProtocol: response.GitProtocol, + GitConfigOptions: response.GitConfigOptions, + } + + return gc.RunGitalyCommand(func(ctx context.Context, conn *grpc.ClientConn) (int32, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + rw := c.ReadWriter + return client.ReceivePack(ctx, conn, rw.In, rw.Out, rw.ErrOut, request) + }) +} diff --git a/internal/command/receivepack/gitalycall_test.go b/internal/command/receivepack/gitalycall_test.go new file mode 100644 index 0000000..361596a --- /dev/null +++ b/internal/command/receivepack/gitalycall_test.go @@ -0,0 +1,40 @@ +package receivepack + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers" +) + +func TestReceivePack(t *testing.T) { + gitalyAddress, cleanup := testserver.StartGitalyServer(t) + defer cleanup() + + requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) + url, cleanup := testserver.StartHttpServer(t, requests) + defer cleanup() + + output := &bytes.Buffer{} + input := &bytes.Buffer{} + + userId := "1" + repo := "group/repo" + + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: &commandargs.Shell{GitlabKeyId: userId, CommandType: commandargs.ReceivePack, SshArgs: []string{"git-receive-pack", repo}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input}, + } + + err := cmd.Execute() + require.NoError(t, err) + + require.Equal(t, "ReceivePack: "+userId+" "+repo, output.String()) +} diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go new file mode 100644 index 0000000..aaaf7b0 --- /dev/null +++ b/internal/command/receivepack/receivepack.go @@ -0,0 +1,40 @@ +package receivepack + +import ( + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" +) + +type Command struct { + Config *config.Config + Args *commandargs.Shell + ReadWriter *readwriter.ReadWriter +} + +func (c *Command) Execute() error { + args := c.Args.SshArgs + if len(args) != 2 { + return disallowedcommand.Error + } + + repo := args[1] + response, err := c.verifyAccess(repo) + if err != nil { + return err + } + + if response.IsCustomAction() { + return c.processCustomAction(response) + } + + return c.performGitalyCall(response) +} + +func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { + cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} + + return cmd.Verify(c.Args.CommandType, repo) +} diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go new file mode 100644 index 0000000..1d7bd21 --- /dev/null +++ b/internal/command/receivepack/receivepack_test.go @@ -0,0 +1,32 @@ +package receivepack + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers" +) + +func TestForbiddenAccess(t *testing.T) { + requests := requesthandlers.BuildDisallowedByApiHandlers(t) + url, cleanup := testserver.StartHttpServer(t, requests) + defer cleanup() + + output := &bytes.Buffer{} + input := bytes.NewBufferString("input") + + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: &commandargs.Shell{GitlabKeyId: "disallowed", SshArgs: []string{"git-receive-pack", "group/repo"}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input}, + } + + err := cmd.Execute() + require.Equal(t, "Disallowed by API call", err.Error()) +} diff --git a/internal/command/shared/accessverifier/accessverifier.go b/internal/command/shared/accessverifier/accessverifier.go new file mode 100644 index 0000000..3aaf98d --- /dev/null +++ b/internal/command/shared/accessverifier/accessverifier.go @@ -0,0 +1,45 @@ +package accessverifier + +import ( + "errors" + "fmt" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier" +) + +type Response = accessverifier.Response + +type Command struct { + Config *config.Config + Args *commandargs.Shell + ReadWriter *readwriter.ReadWriter +} + +func (c *Command) Verify(action commandargs.CommandType, repo string) (*Response, error) { + client, err := accessverifier.NewClient(c.Config) + if err != nil { + return nil, err + } + + response, err := client.Verify(c.Args, action, repo) + if err != nil { + return nil, err + } + + c.displayConsoleMessages(response.ConsoleMessages) + + if !response.Success { + return nil, errors.New(response.Message) + } + + return response, nil +} + +func (c *Command) displayConsoleMessages(messages []string) { + for _, msg := range messages { + fmt.Fprintf(c.ReadWriter.ErrOut, "> GitLab: %v\n", msg) + } +} diff --git a/internal/command/shared/accessverifier/accessverifier_test.go b/internal/command/shared/accessverifier/accessverifier_test.go new file mode 100644 index 0000000..39c2a66 --- /dev/null +++ b/internal/command/shared/accessverifier/accessverifier_test.go @@ -0,0 +1,82 @@ +package accessverifier + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" +) + +var ( + repo = "group/repo" + action = commandargs.ReceivePack +) + +func setup(t *testing.T) (*Command, *bytes.Buffer, *bytes.Buffer, func()) { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/allowed", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var requestBody *accessverifier.Request + err = json.Unmarshal(b, &requestBody) + require.NoError(t, err) + + if requestBody.KeyId == "1" { + body := map[string]interface{}{ + "gl_console_messages": []string{"console", "message"}, + } + require.NoError(t, json.NewEncoder(w).Encode(body)) + } else { + body := map[string]interface{}{ + "status": false, + "message": "missing user", + } + require.NoError(t, json.NewEncoder(w).Encode(body)) + } + }, + }, + } + + url, cleanup := testserver.StartSocketHttpServer(t, requests) + + errBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + + readWriter := &readwriter.ReadWriter{Out: outBuf, ErrOut: errBuf} + cmd := &Command{Config: &config.Config{GitlabUrl: url}, ReadWriter: readWriter} + + return cmd, errBuf, outBuf, cleanup +} + +func TestMissingUser(t *testing.T) { + cmd, _, _, cleanup := setup(t) + defer cleanup() + + cmd.Args = &commandargs.Shell{GitlabKeyId: "2"} + _, err := cmd.Verify(action, repo) + + require.Equal(t, "missing user", err.Error()) +} + +func TestConsoleMessages(t *testing.T) { + cmd, errBuf, outBuf, cleanup := setup(t) + defer cleanup() + + cmd.Args = &commandargs.Shell{GitlabKeyId: "1"} + cmd.Verify(action, repo) + + require.Equal(t, "> GitLab: console\n> GitLab: message\n", errBuf.String()) + require.Empty(t, outBuf.String()) +} diff --git a/internal/command/shared/disallowedcommand/disallowedcommand.go b/internal/command/shared/disallowedcommand/disallowedcommand.go new file mode 100644 index 0000000..3c98bcc --- /dev/null +++ b/internal/command/shared/disallowedcommand/disallowedcommand.go @@ -0,0 +1,7 @@ +package disallowedcommand + +import "errors" + +var ( + Error = errors.New("> GitLab: Disallowed command") +) diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go new file mode 100644 index 0000000..2f13cc5 --- /dev/null +++ b/internal/command/twofactorrecover/twofactorrecover.go @@ -0,0 +1,65 @@ +package twofactorrecover + +import ( + "fmt" + "strings" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/twofactorrecover" +) + +type Command struct { + Config *config.Config + Args *commandargs.Shell + ReadWriter *readwriter.ReadWriter +} + +func (c *Command) Execute() error { + if c.canContinue() { + c.displayRecoveryCodes() + } else { + fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.") + } + + return nil +} + +func (c *Command) canContinue() bool { + question := + "Are you sure you want to generate new two-factor recovery codes?\n" + + "Any existing recovery codes you saved will be invalidated. (yes/no)" + fmt.Fprintln(c.ReadWriter.Out, question) + + var answer string + fmt.Fscanln(c.ReadWriter.In, &answer) + + return answer == "yes" +} + +func (c *Command) displayRecoveryCodes() { + codes, err := c.getRecoveryCodes() + + if err == nil { + messageWithCodes := + "\nYour two-factor authentication recovery codes are:\n\n" + + strings.Join(codes, "\n") + + "\n\nDuring sign in, use one of the codes above when prompted for\n" + + "your two-factor code. Then, visit your Profile Settings and add\n" + + "a new device so you do not lose access to your account again.\n" + fmt.Fprint(c.ReadWriter.Out, messageWithCodes) + } else { + fmt.Fprintf(c.ReadWriter.Out, "\nAn error occurred while trying to generate new recovery codes.\n%v\n", err) + } +} + +func (c *Command) getRecoveryCodes() ([]string, error) { + client, err := twofactorrecover.NewClient(c.Config) + + if err != nil { + return nil, err + } + + return client.GetRecoveryCodes(c.Args) +} diff --git a/internal/command/twofactorrecover/twofactorrecover_test.go b/internal/command/twofactorrecover/twofactorrecover_test.go new file mode 100644 index 0000000..283c45a --- /dev/null +++ b/internal/command/twofactorrecover/twofactorrecover_test.go @@ -0,0 +1,136 @@ +package twofactorrecover + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/twofactorrecover" +) + +var ( + requests []testserver.TestRequestHandler +) + +func setup(t *testing.T) { + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/two_factor_recovery_codes", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + + require.NoError(t, err) + + var requestBody *twofactorrecover.RequestBody + json.Unmarshal(b, &requestBody) + + switch requestBody.KeyId { + case "1": + body := map[string]interface{}{ + "success": true, + "recovery_codes": [2]string{"recovery", "codes"}, + } + json.NewEncoder(w).Encode(body) + case "forbidden": + body := map[string]interface{}{ + "success": false, + "message": "Forbidden!", + } + json.NewEncoder(w).Encode(body) + case "broken": + w.WriteHeader(http.StatusInternalServerError) + } + }, + }, + } +} + +const ( + question = "Are you sure you want to generate new two-factor recovery codes?\n" + + "Any existing recovery codes you saved will be invalidated. (yes/no)\n\n" + errorHeader = "An error occurred while trying to generate new recovery codes.\n" +) + +func TestExecute(t *testing.T) { + setup(t) + + url, cleanup := testserver.StartSocketHttpServer(t, requests) + defer cleanup() + + testCases := []struct { + desc string + arguments *commandargs.Shell + answer string + expectedOutput string + }{ + { + desc: "With a known key id", + arguments: &commandargs.Shell{GitlabKeyId: "1"}, + answer: "yes\n", + expectedOutput: question + + "Your two-factor authentication recovery codes are:\n\nrecovery\ncodes\n\n" + + "During sign in, use one of the codes above when prompted for\n" + + "your two-factor code. Then, visit your Profile Settings and add\n" + + "a new device so you do not lose access to your account again.\n", + }, + { + desc: "With bad response", + arguments: &commandargs.Shell{GitlabKeyId: "-1"}, + answer: "yes\n", + expectedOutput: question + errorHeader + "Parsing failed\n", + }, + { + desc: "With API returns an error", + arguments: &commandargs.Shell{GitlabKeyId: "forbidden"}, + answer: "yes\n", + expectedOutput: question + errorHeader + "Forbidden!\n", + }, + { + desc: "With API fails", + arguments: &commandargs.Shell{GitlabKeyId: "broken"}, + answer: "yes\n", + expectedOutput: question + errorHeader + "Internal API error (500)\n", + }, + { + desc: "With missing arguments", + arguments: &commandargs.Shell{}, + answer: "yes\n", + expectedOutput: question + errorHeader + "who='' is invalid\n", + }, + { + desc: "With negative answer", + arguments: &commandargs.Shell{}, + answer: "no\n", + expectedOutput: question + + "New recovery codes have *not* been generated. Existing codes will remain valid.\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + output := &bytes.Buffer{} + input := bytes.NewBufferString(tc.answer) + + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: tc.arguments, + ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, + } + + err := cmd.Execute() + + assert.NoError(t, err) + assert.Equal(t, tc.expectedOutput, output.String()) + }) + } +} diff --git a/internal/command/uploadarchive/gitalycall.go b/internal/command/uploadarchive/gitalycall.go new file mode 100644 index 0000000..1dfc864 --- /dev/null +++ b/internal/command/uploadarchive/gitalycall.go @@ -0,0 +1,32 @@ +package uploadarchive + +import ( + "context" + + "google.golang.org/grpc" + + "gitlab.com/gitlab-org/gitaly/client" + pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/internal/handler" +) + +func (c *Command) performGitalyCall(response *accessverifier.Response) error { + gc := &handler.GitalyCommand{ + Config: c.Config, + ServiceName: string(commandargs.UploadArchive), + Address: response.Gitaly.Address, + Token: response.Gitaly.Token, + } + + request := &pb.SSHUploadArchiveRequest{Repository: &response.Gitaly.Repo} + + return gc.RunGitalyCommand(func(ctx context.Context, conn *grpc.ClientConn) (int32, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + rw := c.ReadWriter + return client.UploadArchive(ctx, conn, rw.In, rw.Out, rw.ErrOut, request) + }) +} diff --git a/internal/command/uploadarchive/gitalycall_test.go b/internal/command/uploadarchive/gitalycall_test.go new file mode 100644 index 0000000..5c5353f --- /dev/null +++ b/internal/command/uploadarchive/gitalycall_test.go @@ -0,0 +1,40 @@ +package uploadarchive + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers" +) + +func TestUploadPack(t *testing.T) { + gitalyAddress, cleanup := testserver.StartGitalyServer(t) + defer cleanup() + + requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) + url, cleanup := testserver.StartHttpServer(t, requests) + defer cleanup() + + output := &bytes.Buffer{} + input := &bytes.Buffer{} + + userId := "1" + repo := "group/repo" + + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: &commandargs.Shell{GitlabKeyId: userId, CommandType: commandargs.UploadArchive, SshArgs: []string{"git-upload-archive", repo}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input}, + } + + err := cmd.Execute() + require.NoError(t, err) + + require.Equal(t, "UploadArchive: "+repo, output.String()) +} diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go new file mode 100644 index 0000000..9d4fbe0 --- /dev/null +++ b/internal/command/uploadarchive/uploadarchive.go @@ -0,0 +1,36 @@ +package uploadarchive + +import ( + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" +) + +type Command struct { + Config *config.Config + Args *commandargs.Shell + ReadWriter *readwriter.ReadWriter +} + +func (c *Command) Execute() error { + args := c.Args.SshArgs + if len(args) != 2 { + return disallowedcommand.Error + } + + repo := args[1] + response, err := c.verifyAccess(repo) + if err != nil { + return err + } + + return c.performGitalyCall(response) +} + +func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { + cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} + + return cmd.Verify(c.Args.CommandType, repo) +} diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go new file mode 100644 index 0000000..50f3f7e --- /dev/null +++ b/internal/command/uploadarchive/uploadarchive_test.go @@ -0,0 +1,31 @@ +package uploadarchive + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers" +) + +func TestForbiddenAccess(t *testing.T) { + requests := requesthandlers.BuildDisallowedByApiHandlers(t) + url, cleanup := testserver.StartHttpServer(t, requests) + defer cleanup() + + output := &bytes.Buffer{} + + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: &commandargs.Shell{GitlabKeyId: "disallowed", SshArgs: []string{"git-upload-archive", "group/repo"}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, + } + + err := cmd.Execute() + require.Equal(t, "Disallowed by API call", err.Error()) +} diff --git a/internal/command/uploadpack/gitalycall.go b/internal/command/uploadpack/gitalycall.go new file mode 100644 index 0000000..8b97dee --- /dev/null +++ b/internal/command/uploadpack/gitalycall.go @@ -0,0 +1,36 @@ +package uploadpack + +import ( + "context" + + "google.golang.org/grpc" + + "gitlab.com/gitlab-org/gitaly/client" + pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/internal/handler" +) + +func (c *Command) performGitalyCall(response *accessverifier.Response) error { + gc := &handler.GitalyCommand{ + Config: c.Config, + ServiceName: string(commandargs.UploadPack), + Address: response.Gitaly.Address, + Token: response.Gitaly.Token, + } + + request := &pb.SSHUploadPackRequest{ + Repository: &response.Gitaly.Repo, + GitProtocol: response.GitProtocol, + GitConfigOptions: response.GitConfigOptions, + } + + return gc.RunGitalyCommand(func(ctx context.Context, conn *grpc.ClientConn) (int32, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + rw := c.ReadWriter + return client.UploadPack(ctx, conn, rw.In, rw.Out, rw.ErrOut, request) + }) +} diff --git a/internal/command/uploadpack/gitalycall_test.go b/internal/command/uploadpack/gitalycall_test.go new file mode 100644 index 0000000..71a253b --- /dev/null +++ b/internal/command/uploadpack/gitalycall_test.go @@ -0,0 +1,40 @@ +package uploadpack + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers" +) + +func TestUploadPack(t *testing.T) { + gitalyAddress, cleanup := testserver.StartGitalyServer(t) + defer cleanup() + + requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) + url, cleanup := testserver.StartHttpServer(t, requests) + defer cleanup() + + output := &bytes.Buffer{} + input := &bytes.Buffer{} + + userId := "1" + repo := "group/repo" + + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: &commandargs.Shell{GitlabKeyId: userId, CommandType: commandargs.UploadPack, SshArgs: []string{"git-upload-pack", repo}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input}, + } + + err := cmd.Execute() + require.NoError(t, err) + + require.Equal(t, "UploadPack: "+repo, output.String()) +} diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go new file mode 100644 index 0000000..a5c71b2 --- /dev/null +++ b/internal/command/uploadpack/uploadpack.go @@ -0,0 +1,36 @@ +package uploadpack + +import ( + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" +) + +type Command struct { + Config *config.Config + Args *commandargs.Shell + ReadWriter *readwriter.ReadWriter +} + +func (c *Command) Execute() error { + args := c.Args.SshArgs + if len(args) != 2 { + return disallowedcommand.Error + } + + repo := args[1] + response, err := c.verifyAccess(repo) + if err != nil { + return err + } + + return c.performGitalyCall(response) +} + +func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { + cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} + + return cmd.Verify(c.Args.CommandType, repo) +} diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go new file mode 100644 index 0000000..04fe2ba --- /dev/null +++ b/internal/command/uploadpack/uploadpack_test.go @@ -0,0 +1,31 @@ +package uploadpack + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper/requesthandlers" +) + +func TestForbiddenAccess(t *testing.T) { + requests := requesthandlers.BuildDisallowedByApiHandlers(t) + url, cleanup := testserver.StartHttpServer(t, requests) + defer cleanup() + + output := &bytes.Buffer{} + + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: &commandargs.Shell{GitlabKeyId: "disallowed", SshArgs: []string{"git-upload-pack", "group/repo"}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, + } + + err := cmd.Execute() + require.Equal(t, "Disallowed by API call", err.Error()) +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..2231851 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,123 @@ +package config + +import ( + "io/ioutil" + "net/url" + "os" + "path" + "path/filepath" + + yaml "gopkg.in/yaml.v2" +) + +const ( + configFile = "config.yml" + logFile = "gitlab-shell.log" + defaultSecretFileName = ".gitlab_shell_secret" +) + +type HttpSettingsConfig struct { + User string `yaml:"user"` + Password string `yaml:"password"` + ReadTimeoutSeconds uint64 `yaml:"read_timeout"` + CaFile string `yaml:"ca_file"` + CaPath string `yaml:"ca_path"` + SelfSignedCert bool `yaml:"self_signed_cert"` +} + +type Config struct { + RootDir string + LogFile string `yaml:"log_file"` + LogFormat string `yaml:"log_format"` + GitlabUrl string `yaml:"gitlab_url"` + GitlabTracing string `yaml:"gitlab_tracing"` + SecretFilePath string `yaml:"secret_file"` + Secret string `yaml:"secret"` + HttpSettings HttpSettingsConfig `yaml:"http_settings"` + HttpClient *HttpClient +} + +func New() (*Config, error) { + dir, err := os.Getwd() + if err != nil { + return nil, err + } + + return NewFromDir(dir) +} + +func NewFromDir(dir string) (*Config, error) { + return newFromFile(path.Join(dir, configFile)) +} + +func newFromFile(filename string) (*Config, error) { + cfg := &Config{RootDir: path.Dir(filename)} + + configBytes, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + + if err := parseConfig(configBytes, cfg); err != nil { + return nil, err + } + + return cfg, nil +} + +// parseConfig expects YAML data in configBytes and a Config instance with RootDir set. +func parseConfig(configBytes []byte, cfg *Config) error { + if err := yaml.Unmarshal(configBytes, cfg); err != nil { + return err + } + + if cfg.LogFile == "" { + cfg.LogFile = logFile + } + + if len(cfg.LogFile) > 0 && cfg.LogFile[0] != '/' { + cfg.LogFile = path.Join(cfg.RootDir, cfg.LogFile) + } + + if cfg.LogFormat == "" { + cfg.LogFormat = "text" + } + + if cfg.GitlabUrl != "" { + unescapedUrl, err := url.PathUnescape(cfg.GitlabUrl) + if err != nil { + return err + } + + cfg.GitlabUrl = unescapedUrl + } + + if err := parseSecret(cfg); err != nil { + return err + } + + return nil +} + +func parseSecret(cfg *Config) error { + // The secret was parsed from yaml no need to read another file + if cfg.Secret != "" { + return nil + } + + if cfg.SecretFilePath == "" { + cfg.SecretFilePath = defaultSecretFileName + } + + if !filepath.IsAbs(cfg.SecretFilePath) { + cfg.SecretFilePath = path.Join(cfg.RootDir, cfg.SecretFilePath) + } + + secretFileContent, err := ioutil.ReadFile(cfg.SecretFilePath) + if err != nil { + return err + } + cfg.Secret = string(secretFileContent) + + return nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..202db6d --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,112 @@ +package config + +import ( + "fmt" + "path" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +const ( + customSecret = "custom/my-contents-is-secret" +) + +var ( + testRoot = testhelper.TestRoot +) + +func TestParseConfig(t *testing.T) { + cleanup, err := testhelper.PrepareTestRootDir() + require.NoError(t, err) + defer cleanup() + + testCases := []struct { + yaml string + path string + format string + gitlabUrl string + secret string + httpSettings HttpSettingsConfig + }{ + { + path: path.Join(testRoot, "gitlab-shell.log"), + format: "text", + secret: "default-secret-content", + }, + { + yaml: "log_file: my-log.log", + path: path.Join(testRoot, "my-log.log"), + format: "text", + secret: "default-secret-content", + }, + { + yaml: "log_file: /qux/my-log.log", + path: "/qux/my-log.log", + format: "text", + secret: "default-secret-content", + }, + { + yaml: "log_format: json", + path: path.Join(testRoot, "gitlab-shell.log"), + format: "json", + secret: "default-secret-content", + }, + { + yaml: "gitlab_url: http+unix://%2Fpath%2Fto%2Fgitlab%2Fgitlab.socket", + path: path.Join(testRoot, "gitlab-shell.log"), + format: "text", + gitlabUrl: "http+unix:///path/to/gitlab/gitlab.socket", + secret: "default-secret-content", + }, + { + yaml: fmt.Sprintf("secret_file: %s", customSecret), + path: path.Join(testRoot, "gitlab-shell.log"), + format: "text", + secret: "custom-secret-content", + }, + { + yaml: fmt.Sprintf("secret_file: %s", path.Join(testRoot, customSecret)), + path: path.Join(testRoot, "gitlab-shell.log"), + format: "text", + secret: "custom-secret-content", + }, + { + yaml: "secret: an inline secret", + path: path.Join(testRoot, "gitlab-shell.log"), + format: "text", + secret: "an inline secret", + }, + { + yaml: "http_settings:\n user: user_basic_auth\n password: password_basic_auth\n read_timeout: 500", + path: path.Join(testRoot, "gitlab-shell.log"), + format: "text", + secret: "default-secret-content", + httpSettings: HttpSettingsConfig{User: "user_basic_auth", Password: "password_basic_auth", ReadTimeoutSeconds: 500}, + }, + { + yaml: "http_settings:\n ca_file: /etc/ssl/cert.pem\n ca_path: /etc/pki/tls/certs\n self_signed_cert: true", + path: path.Join(testRoot, "gitlab-shell.log"), + format: "text", + secret: "default-secret-content", + httpSettings: HttpSettingsConfig{CaFile: "/etc/ssl/cert.pem", CaPath: "/etc/pki/tls/certs", SelfSignedCert: true}, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("yaml input: %q", tc.yaml), func(t *testing.T) { + cfg := Config{RootDir: testRoot} + + err := parseConfig([]byte(tc.yaml), &cfg) + require.NoError(t, err) + + assert.Equal(t, tc.path, cfg.LogFile) + assert.Equal(t, tc.format, cfg.LogFormat) + assert.Equal(t, tc.gitlabUrl, cfg.GitlabUrl) + assert.Equal(t, tc.secret, cfg.Secret) + assert.Equal(t, tc.httpSettings, cfg.HttpSettings) + }) + } +} diff --git a/internal/config/httpclient.go b/internal/config/httpclient.go new file mode 100644 index 0000000..c71efad --- /dev/null +++ b/internal/config/httpclient.go @@ -0,0 +1,122 @@ +package config + +import ( + "context" + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net" + "net/http" + "path/filepath" + "strings" + "time" +) + +const ( + socketBaseUrl = "http://unix" + unixSocketProtocol = "http+unix://" + httpProtocol = "http://" + httpsProtocol = "https://" + defaultReadTimeoutSeconds = 300 +) + +type HttpClient struct { + HttpClient *http.Client + Host string +} + +func (c *Config) GetHttpClient() *HttpClient { + if c.HttpClient != nil { + return c.HttpClient + } + + var transport *http.Transport + var host string + if strings.HasPrefix(c.GitlabUrl, unixSocketProtocol) { + transport, host = c.buildSocketTransport() + } else if strings.HasPrefix(c.GitlabUrl, httpProtocol) { + transport, host = c.buildHttpTransport() + } else if strings.HasPrefix(c.GitlabUrl, httpsProtocol) { + transport, host = c.buildHttpsTransport() + } else { + return nil + } + + httpClient := &http.Client{ + Transport: transport, + Timeout: c.readTimeout(), + } + + client := &HttpClient{HttpClient: httpClient, Host: host} + + c.HttpClient = client + + return client +} + +func (c *Config) buildSocketTransport() (*http.Transport, string) { + socketPath := strings.TrimPrefix(c.GitlabUrl, unixSocketProtocol) + transport := &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + dialer := net.Dialer{} + return dialer.DialContext(ctx, "unix", socketPath) + }, + } + + return transport, socketBaseUrl +} + +func (c *Config) buildHttpsTransport() (*http.Transport, string) { + certPool, err := x509.SystemCertPool() + + if err != nil { + certPool = x509.NewCertPool() + } + + caFile := c.HttpSettings.CaFile + if caFile != "" { + addCertToPool(certPool, caFile) + } + + caPath := c.HttpSettings.CaPath + if caPath != "" { + fis, _ := ioutil.ReadDir(caPath) + for _, fi := range fis { + if fi.IsDir() { + continue + } + + addCertToPool(certPool, filepath.Join(caPath, fi.Name())) + } + } + + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certPool, + InsecureSkipVerify: c.HttpSettings.SelfSignedCert, + }, + } + + return transport, c.GitlabUrl +} + +func addCertToPool(certPool *x509.CertPool, fileName string) { + cert, err := ioutil.ReadFile(fileName) + if err == nil { + certPool.AppendCertsFromPEM(cert) + } +} + +func (c *Config) buildHttpTransport() (*http.Transport, string) { + return &http.Transport{}, c.GitlabUrl +} + +func (c *Config) readTimeout() time.Duration { + timeoutSeconds := c.HttpSettings.ReadTimeoutSeconds + + if timeoutSeconds == 0 { + timeoutSeconds = defaultReadTimeoutSeconds + } + + return time.Duration(timeoutSeconds) * time.Second +} diff --git a/internal/config/httpclient_test.go b/internal/config/httpclient_test.go new file mode 100644 index 0000000..474deba --- /dev/null +++ b/internal/config/httpclient_test.go @@ -0,0 +1,22 @@ +package config + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReadTimeout(t *testing.T) { + expectedSeconds := uint64(300) + + config := &Config{ + GitlabUrl: "http://localhost:3000", + HttpSettings: HttpSettingsConfig{ReadTimeoutSeconds: expectedSeconds}, + } + client := config.GetHttpClient() + + require.NotNil(t, client) + assert.Equal(t, time.Duration(expectedSeconds)*time.Second, client.HttpClient.Timeout) +} diff --git a/internal/executable/executable.go b/internal/executable/executable.go new file mode 100644 index 0000000..c6355b9 --- /dev/null +++ b/internal/executable/executable.go @@ -0,0 +1,60 @@ +package executable + +import ( + "os" + "path/filepath" +) + +const ( + BinDir = "bin" + Healthcheck = "check" + GitlabShell = "gitlab-shell" + AuthorizedKeysCheck = "gitlab-shell-authorized-keys-check" + AuthorizedPrincipalsCheck = "gitlab-shell-authorized-principals-check" +) + +type Executable struct { + Name string + RootDir string +} + +var ( + // osExecutable is overridden in tests + osExecutable = os.Executable +) + +func New(name string) (*Executable, error) { + path, err := osExecutable() + if err != nil { + return nil, err + } + + rootDir, err := findRootDir(path) + if err != nil { + return nil, err + } + + executable := &Executable{ + Name: name, + RootDir: rootDir, + } + + return executable, nil +} + +func findRootDir(path string) (string, error) { + // Start: /opt/.../gitlab-shell/bin/gitlab-shell + // Ends: /opt/.../gitlab-shell + rootDir := filepath.Dir(filepath.Dir(path)) + pathFromEnv := os.Getenv("GITLAB_SHELL_DIR") + + if pathFromEnv != "" { + if _, err := os.Stat(pathFromEnv); os.IsNotExist(err) { + return "", err + } + + rootDir = pathFromEnv + } + + return rootDir, nil +} diff --git a/internal/executable/executable_test.go b/internal/executable/executable_test.go new file mode 100644 index 0000000..3915f1a --- /dev/null +++ b/internal/executable/executable_test.go @@ -0,0 +1,104 @@ +package executable + +import ( + "errors" + "testing" + + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" + + "github.com/stretchr/testify/require" +) + +type fakeOs struct { + OldExecutable func() (string, error) + Path string + Error error +} + +func (f *fakeOs) Executable() (string, error) { + return f.Path, f.Error +} + +func (f *fakeOs) Setup() { + f.OldExecutable = osExecutable + osExecutable = f.Executable +} + +func (f *fakeOs) Cleanup() { + osExecutable = f.OldExecutable +} + +func TestNewSuccess(t *testing.T) { + testCases := []struct { + desc string + fakeOs *fakeOs + environment map[string]string + expectedRootDir string + }{ + { + desc: "GITLAB_SHELL_DIR env var is not defined", + fakeOs: &fakeOs{Path: "/tmp/bin/gitlab-shell"}, + expectedRootDir: "/tmp", + }, + { + desc: "GITLAB_SHELL_DIR env var is defined", + fakeOs: &fakeOs{Path: "/opt/bin/gitlab-shell"}, + environment: map[string]string{ + "GITLAB_SHELL_DIR": "/tmp", + }, + expectedRootDir: "/tmp", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + restoreEnv := testhelper.TempEnv(tc.environment) + defer restoreEnv() + + fake := tc.fakeOs + fake.Setup() + defer fake.Cleanup() + + result, err := New("gitlab-shell") + + require.NoError(t, err) + require.Equal(t, result.Name, "gitlab-shell") + require.Equal(t, result.RootDir, tc.expectedRootDir) + }) + } +} + +func TestNewFailure(t *testing.T) { + testCases := []struct { + desc string + fakeOs *fakeOs + environment map[string]string + }{ + { + desc: "failed to determine executable", + fakeOs: &fakeOs{Path: "", Error: errors.New("error")}, + }, + { + desc: "GITLAB_SHELL_DIR doesn't exist", + fakeOs: &fakeOs{Path: "/tmp/bin/gitlab-shell"}, + environment: map[string]string{ + "GITLAB_SHELL_DIR": "/tmp/non/existing/directory", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + restoreEnv := testhelper.TempEnv(tc.environment) + defer restoreEnv() + + fake := tc.fakeOs + fake.Setup() + defer fake.Cleanup() + + _, err := New("gitlab-shell") + + require.Error(t, err) + }) + } +} diff --git a/internal/gitlabnet/accessverifier/client.go b/internal/gitlabnet/accessverifier/client.go new file mode 100644 index 0000000..217dcdb --- /dev/null +++ b/internal/gitlabnet/accessverifier/client.go @@ -0,0 +1,115 @@ +package accessverifier + +import ( + "fmt" + "net/http" + + pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" +) + +const ( + protocol = "ssh" + anyChanges = "_any" +) + +type Client struct { + client *gitlabnet.GitlabClient +} + +type Request struct { + Action commandargs.CommandType `json:"action"` + Repo string `json:"project"` + Changes string `json:"changes"` + Protocol string `json:"protocol"` + KeyId string `json:"key_id,omitempty"` + Username string `json:"username,omitempty"` + CheckIp string `json:"check_ip,omitempty"` +} + +type Gitaly struct { + Repo pb.Repository `json:"repository"` + Address string `json:"address"` + Token string `json:"token"` +} + +type CustomPayloadData struct { + ApiEndpoints []string `json:"api_endpoints"` + Username string `json:"gl_username"` + PrimaryRepo string `json:"primary_repo"` + InfoMessage string `json:"info_message"` + UserId string `json:"gl_id,omitempty"` +} + +type CustomPayload struct { + Action string `json:"action"` + Data CustomPayloadData `json:"data"` +} + +type Response struct { + Success bool `json:"status"` + Message string `json:"message"` + Repo string `json:"gl_repository"` + UserId string `json:"gl_id"` + Username string `json:"gl_username"` + GitConfigOptions []string `json:"git_config_options"` + Gitaly Gitaly `json:"gitaly"` + GitProtocol string `json:"git_protocol"` + Payload CustomPayload `json:"payload"` + ConsoleMessages []string `json:"gl_console_messages"` + Who string + StatusCode int +} + +func NewClient(config *config.Config) (*Client, error) { + client, err := gitlabnet.GetClient(config) + if err != nil { + return nil, fmt.Errorf("Error creating http client: %v", err) + } + + return &Client{client: client}, nil +} + +func (c *Client) Verify(args *commandargs.Shell, action commandargs.CommandType, repo string) (*Response, error) { + request := &Request{Action: action, Repo: repo, Protocol: protocol, Changes: anyChanges} + + if args.GitlabUsername != "" { + request.Username = args.GitlabUsername + } else { + request.KeyId = args.GitlabKeyId + } + + request.CheckIp = sshenv.LocalAddr() + + response, err := c.client.Post("/allowed", request) + if err != nil { + return nil, err + } + defer response.Body.Close() + + return parse(response, args) +} + +func parse(hr *http.Response, args *commandargs.Shell) (*Response, error) { + response := &Response{} + if err := gitlabnet.ParseJSON(hr, response); err != nil { + return nil, err + } + + if args.GitlabKeyId != "" { + response.Who = "key-" + args.GitlabKeyId + } else { + response.Who = response.UserId + } + + response.StatusCode = hr.StatusCode + + return response, nil +} + +func (r *Response) IsCustomAction() bool { + return r.StatusCode == http.StatusMultipleChoices +} diff --git a/internal/gitlabnet/accessverifier/client_test.go b/internal/gitlabnet/accessverifier/client_test.go new file mode 100644 index 0000000..96c80a7 --- /dev/null +++ b/internal/gitlabnet/accessverifier/client_test.go @@ -0,0 +1,209 @@ +package accessverifier + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "path" + "testing" + + "github.com/stretchr/testify/require" + + pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +var ( + repo = "group/private" + action = commandargs.ReceivePack +) + +func buildExpectedResponse(who string) *Response { + response := &Response{ + Success: true, + UserId: "user-1", + Repo: "project-26", + Username: "root", + GitConfigOptions: []string{"option"}, + Gitaly: Gitaly{ + Repo: pb.Repository{ + StorageName: "default", + RelativePath: "@hashed/5f/9c/5f9c4ab08cac7457e9111a30e4664920607ea2c115a1433d7be98e97e64244ca.git", + GitObjectDirectory: "path/to/git_object_directory", + GitAlternateObjectDirectories: []string{"path/to/git_alternate_object_directory"}, + GlRepository: "project-26", + GlProjectPath: repo, + }, + Address: "unix:gitaly.socket", + Token: "token", + }, + GitProtocol: "protocol", + Payload: CustomPayload{}, + ConsoleMessages: []string{"console", "message"}, + Who: who, + StatusCode: 200, + } + + return response +} + +func TestSuccessfulResponses(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + testCases := []struct { + desc string + args *commandargs.Shell + who string + }{ + { + desc: "Provide key id within the request", + args: &commandargs.Shell{GitlabKeyId: "1"}, + who: "key-1", + }, { + desc: "Provide username within the request", + args: &commandargs.Shell{GitlabUsername: "first"}, + who: "user-1", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := client.Verify(tc.args, action, repo) + require.NoError(t, err) + + response := buildExpectedResponse(tc.who) + require.Equal(t, response, result) + }) + } +} + +func TestGetCustomAction(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + args := &commandargs.Shell{GitlabUsername: "custom"} + result, err := client.Verify(args, action, repo) + require.NoError(t, err) + + response := buildExpectedResponse("user-1") + response.Payload = CustomPayload{ + Action: "geo_proxy_to_primary", + Data: CustomPayloadData{ + ApiEndpoints: []string{"geo/proxy_git_push_ssh/info_refs", "geo/proxy_git_push_ssh/push"}, + Username: "custom", + PrimaryRepo: "https://repo/path", + InfoMessage: "message", + }, + } + response.StatusCode = 300 + + require.True(t, response.IsCustomAction()) + require.Equal(t, response, result) +} + +func TestErrorResponses(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + testCases := []struct { + desc string + fakeId string + expectedError string + }{ + { + desc: "A response with an error message", + fakeId: "2", + expectedError: "Not allowed!", + }, + { + desc: "A response with bad JSON", + fakeId: "3", + expectedError: "Parsing failed", + }, + { + desc: "An error response without message", + fakeId: "4", + expectedError: "Internal API error (403)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + args := &commandargs.Shell{GitlabKeyId: tc.fakeId} + resp, err := client.Verify(args, action, repo) + + require.EqualError(t, err, tc.expectedError) + require.Nil(t, resp) + }) + } +} + +func setup(t *testing.T) (*Client, func()) { + testDirCleanup, err := testhelper.PrepareTestRootDir() + require.NoError(t, err) + defer testDirCleanup() + + body, err := ioutil.ReadFile(path.Join(testhelper.TestRoot, "responses/allowed.json")) + require.NoError(t, err) + + allowedWithPayloadPath := path.Join(testhelper.TestRoot, "responses/allowed_with_payload.json") + bodyWithPayload, err := ioutil.ReadFile(allowedWithPayloadPath) + require.NoError(t, err) + + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/allowed", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var requestBody *Request + require.NoError(t, json.Unmarshal(b, &requestBody)) + + switch requestBody.Username { + case "first": + _, err = w.Write(body) + require.NoError(t, err) + case "second": + errBody := map[string]interface{}{ + "status": false, + "message": "missing user", + } + require.NoError(t, json.NewEncoder(w).Encode(errBody)) + case "custom": + w.WriteHeader(http.StatusMultipleChoices) + _, err = w.Write(bodyWithPayload) + require.NoError(t, err) + } + + switch requestBody.KeyId { + case "1": + _, err = w.Write(body) + require.NoError(t, err) + case "2": + w.WriteHeader(http.StatusForbidden) + errBody := &gitlabnet.ErrorResponse{ + Message: "Not allowed!", + } + require.NoError(t, json.NewEncoder(w).Encode(errBody)) + case "3": + w.Write([]byte("{ \"message\": \"broken json!\"")) + case "4": + w.WriteHeader(http.StatusForbidden) + } + }, + }, + } + + url, cleanup := testserver.StartSocketHttpServer(t, requests) + + client, err := NewClient(&config.Config{GitlabUrl: url}) + require.NoError(t, err) + + return client, cleanup +} diff --git a/internal/gitlabnet/authorizedkeys/client.go b/internal/gitlabnet/authorizedkeys/client.go new file mode 100644 index 0000000..ac23a96 --- /dev/null +++ b/internal/gitlabnet/authorizedkeys/client.go @@ -0,0 +1,65 @@ +package authorizedkeys + +import ( + "fmt" + "net/url" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" +) + +const ( + AuthorizedKeysPath = "/authorized_keys" +) + +type Client struct { + config *config.Config + client *gitlabnet.GitlabClient +} + +type Response struct { + Id int64 `json:"id"` + Key string `json:"key"` +} + +func NewClient(config *config.Config) (*Client, error) { + client, err := gitlabnet.GetClient(config) + if err != nil { + return nil, fmt.Errorf("Error creating http client: %v", err) + } + + return &Client{config: config, client: client}, nil +} + +func (c *Client) GetByKey(key string) (*Response, error) { + path, err := pathWithKey(key) + if err != nil { + return nil, err + } + + response, err := c.client.Get(path) + if err != nil { + return nil, err + } + defer response.Body.Close() + + parsedResponse := &Response{} + if err := gitlabnet.ParseJSON(response, parsedResponse); err != nil { + return nil, err + } + + return parsedResponse, nil +} + +func pathWithKey(key string) (string, error) { + u, err := url.Parse(AuthorizedKeysPath) + if err != nil { + return "", err + } + + params := u.Query() + params.Set("key", key) + u.RawQuery = params.Encode() + + return u.String(), nil +} diff --git a/internal/gitlabnet/authorizedkeys/client_test.go b/internal/gitlabnet/authorizedkeys/client_test.go new file mode 100644 index 0000000..965025f --- /dev/null +++ b/internal/gitlabnet/authorizedkeys/client_test.go @@ -0,0 +1,105 @@ +package authorizedkeys + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" +) + +var ( + requests []testserver.TestRequestHandler +) + +func init() { + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/authorized_keys", + Handler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("key") == "key" { + body := &Response{ + Id: 1, + Key: "public-key", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("key") == "broken-message" { + w.WriteHeader(http.StatusForbidden) + body := &gitlabnet.ErrorResponse{ + Message: "Not allowed!", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("key") == "broken-json" { + w.Write([]byte("{ \"message\": \"broken json!\"")) + } else if r.URL.Query().Get("key") == "broken-empty" { + w.WriteHeader(http.StatusForbidden) + } else { + w.WriteHeader(http.StatusNotFound) + } + }, + }, + } +} + +func TestGetByKey(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + result, err := client.GetByKey("key") + require.NoError(t, err) + require.Equal(t, &Response{Id: 1, Key: "public-key"}, result) +} + +func TestGetByKeyErrorResponses(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + testCases := []struct { + desc string + key string + expectedError string + }{ + { + desc: "A response with an error message", + key: "broken-message", + expectedError: "Not allowed!", + }, + { + desc: "A response with bad JSON", + key: "broken-json", + expectedError: "Parsing failed", + }, + { + desc: "A forbidden (403) response without message", + key: "broken-empty", + expectedError: "Internal API error (403)", + }, + { + desc: "A not found (404) response without message", + key: "not-found", + expectedError: "Internal API error (404)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + resp, err := client.GetByKey(tc.key) + + require.EqualError(t, err, tc.expectedError) + require.Nil(t, resp) + }) + } +} + +func setup(t *testing.T) (*Client, func()) { + url, cleanup := testserver.StartSocketHttpServer(t, requests) + + client, err := NewClient(&config.Config{GitlabUrl: url}) + require.NoError(t, err) + + return client, cleanup +} diff --git a/internal/gitlabnet/client.go b/internal/gitlabnet/client.go new file mode 100644 index 0000000..bb8655a --- /dev/null +++ b/internal/gitlabnet/client.go @@ -0,0 +1,132 @@ +package gitlabnet + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" +) + +const ( + internalApiPath = "/api/v4/internal" + secretHeaderName = "Gitlab-Shared-Secret" +) + +var ( + ParsingError = fmt.Errorf("Parsing failed") +) + +type ErrorResponse struct { + Message string `json:"message"` +} + +type GitlabClient struct { + httpClient *http.Client + config *config.Config + host string +} + +func GetClient(config *config.Config) (*GitlabClient, error) { + client := config.GetHttpClient() + + if client == nil { + return nil, fmt.Errorf("Unsupported protocol") + } + + return &GitlabClient{httpClient: client.HttpClient, config: config, host: client.Host}, nil +} + +func normalizePath(path string) string { + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + + if !strings.HasPrefix(path, internalApiPath) { + path = internalApiPath + path + } + return path +} + +func newRequest(method, host, path string, data interface{}) (*http.Request, error) { + var jsonReader io.Reader + if data != nil { + jsonData, err := json.Marshal(data) + if err != nil { + return nil, err + } + + jsonReader = bytes.NewReader(jsonData) + } + + request, err := http.NewRequest(method, host+path, jsonReader) + if err != nil { + return nil, err + } + + return request, nil +} + +func parseError(resp *http.Response) error { + if resp.StatusCode >= 200 && resp.StatusCode <= 399 { + return nil + } + defer resp.Body.Close() + parsedResponse := &ErrorResponse{} + + if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil { + return fmt.Errorf("Internal API error (%v)", resp.StatusCode) + } else { + return fmt.Errorf(parsedResponse.Message) + } + +} + +func (c *GitlabClient) Get(path string) (*http.Response, error) { + return c.DoRequest(http.MethodGet, normalizePath(path), nil) +} + +func (c *GitlabClient) Post(path string, data interface{}) (*http.Response, error) { + return c.DoRequest(http.MethodPost, normalizePath(path), data) +} + +func (c *GitlabClient) DoRequest(method, path string, data interface{}) (*http.Response, error) { + request, err := newRequest(method, c.host, path, data) + if err != nil { + return nil, err + } + + user, password := c.config.HttpSettings.User, c.config.HttpSettings.Password + if user != "" && password != "" { + request.SetBasicAuth(user, password) + } + + encodedSecret := base64.StdEncoding.EncodeToString([]byte(c.config.Secret)) + request.Header.Set(secretHeaderName, encodedSecret) + + request.Header.Add("Content-Type", "application/json") + request.Close = true + + response, err := c.httpClient.Do(request) + if err != nil { + return nil, fmt.Errorf("Internal API unreachable") + } + + if err := parseError(response); err != nil { + return nil, err + } + + return response, nil +} + +func ParseJSON(hr *http.Response, response interface{}) error { + if err := json.NewDecoder(hr.Body).Decode(response); err != nil { + return ParsingError + } + + return nil +} diff --git a/internal/gitlabnet/client_test.go b/internal/gitlabnet/client_test.go new file mode 100644 index 0000000..3f96b41 --- /dev/null +++ b/internal/gitlabnet/client_test.go @@ -0,0 +1,219 @@ +package gitlabnet + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "path" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +func TestClients(t *testing.T) { + testDirCleanup, err := testhelper.PrepareTestRootDir() + require.NoError(t, err) + defer testDirCleanup() + + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/hello", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + + fmt.Fprint(w, "Hello") + }, + }, + { + Path: "/api/v4/internal/post_endpoint", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + b, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + + require.NoError(t, err) + + fmt.Fprint(w, "Echo: "+string(b)) + }, + }, + { + Path: "/api/v4/internal/auth", + Handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, r.Header.Get(secretHeaderName)) + }, + }, + { + Path: "/api/v4/internal/error", + Handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + body := map[string]string{ + "message": "Don't do that", + } + json.NewEncoder(w).Encode(body) + }, + }, + { + Path: "/api/v4/internal/broken", + Handler: func(w http.ResponseWriter, r *http.Request) { + panic("Broken") + }, + }, + } + + testCases := []struct { + desc string + config *config.Config + server func(*testing.T, []testserver.TestRequestHandler) (string, func()) + }{ + { + desc: "Socket client", + config: &config.Config{}, + server: testserver.StartSocketHttpServer, + }, + { + desc: "Http client", + config: &config.Config{}, + server: testserver.StartHttpServer, + }, + { + desc: "Https client", + config: &config.Config{ + HttpSettings: config.HttpSettingsConfig{CaFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt")}, + }, + server: testserver.StartHttpsServer, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + url, cleanup := tc.server(t, requests) + defer cleanup() + + tc.config.GitlabUrl = url + tc.config.Secret = "sssh, it's a secret" + + client, err := GetClient(tc.config) + require.NoError(t, err) + + testBrokenRequest(t, client) + testSuccessfulGet(t, client) + testSuccessfulPost(t, client) + testMissing(t, client) + testErrorMessage(t, client) + testAuthenticationHeader(t, client) + }) + } +} + +func testSuccessfulGet(t *testing.T, client *GitlabClient) { + t.Run("Successful get", func(t *testing.T) { + response, err := client.Get("/hello") + require.NoError(t, err) + require.NotNil(t, response) + + defer response.Body.Close() + + responseBody, err := ioutil.ReadAll(response.Body) + assert.NoError(t, err) + assert.Equal(t, string(responseBody), "Hello") + }) +} + +func testSuccessfulPost(t *testing.T, client *GitlabClient) { + t.Run("Successful Post", func(t *testing.T) { + data := map[string]string{"key": "value"} + + response, err := client.Post("/post_endpoint", data) + require.NoError(t, err) + require.NotNil(t, response) + + defer response.Body.Close() + + responseBody, err := ioutil.ReadAll(response.Body) + assert.NoError(t, err) + assert.Equal(t, "Echo: {\"key\":\"value\"}", string(responseBody)) + }) +} + +func testMissing(t *testing.T, client *GitlabClient) { + t.Run("Missing error for GET", func(t *testing.T) { + response, err := client.Get("/missing") + assert.EqualError(t, err, "Internal API error (404)") + assert.Nil(t, response) + }) + + t.Run("Missing error for POST", func(t *testing.T) { + response, err := client.Post("/missing", map[string]string{}) + assert.EqualError(t, err, "Internal API error (404)") + assert.Nil(t, response) + }) +} + +func testErrorMessage(t *testing.T, client *GitlabClient) { + t.Run("Error with message for GET", func(t *testing.T) { + response, err := client.Get("/error") + assert.EqualError(t, err, "Don't do that") + assert.Nil(t, response) + }) + + t.Run("Error with message for POST", func(t *testing.T) { + response, err := client.Post("/error", map[string]string{}) + assert.EqualError(t, err, "Don't do that") + assert.Nil(t, response) + }) +} + +func testBrokenRequest(t *testing.T, client *GitlabClient) { + t.Run("Broken request for GET", func(t *testing.T) { + response, err := client.Get("/broken") + assert.EqualError(t, err, "Internal API unreachable") + assert.Nil(t, response) + }) + + t.Run("Broken request for POST", func(t *testing.T) { + response, err := client.Post("/broken", map[string]string{}) + assert.EqualError(t, err, "Internal API unreachable") + assert.Nil(t, response) + }) +} + +func testAuthenticationHeader(t *testing.T, client *GitlabClient) { + t.Run("Authentication headers for GET", func(t *testing.T) { + response, err := client.Get("/auth") + require.NoError(t, err) + require.NotNil(t, response) + + defer response.Body.Close() + + responseBody, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + header, err := base64.StdEncoding.DecodeString(string(responseBody)) + require.NoError(t, err) + assert.Equal(t, "sssh, it's a secret", string(header)) + }) + + t.Run("Authentication headers for POST", func(t *testing.T) { + response, err := client.Post("/auth", map[string]string{}) + require.NoError(t, err) + require.NotNil(t, response) + + defer response.Body.Close() + + responseBody, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + header, err := base64.StdEncoding.DecodeString(string(responseBody)) + require.NoError(t, err) + assert.Equal(t, "sssh, it's a secret", string(header)) + }) +} diff --git a/internal/gitlabnet/discover/client.go b/internal/gitlabnet/discover/client.go new file mode 100644 index 0000000..3faef53 --- /dev/null +++ b/internal/gitlabnet/discover/client.go @@ -0,0 +1,71 @@ +package discover + +import ( + "fmt" + "net/http" + "net/url" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" +) + +type Client struct { + config *config.Config + client *gitlabnet.GitlabClient +} + +type Response struct { + UserId int64 `json:"id"` + Name string `json:"name"` + Username string `json:"username"` +} + +func NewClient(config *config.Config) (*Client, error) { + client, err := gitlabnet.GetClient(config) + if err != nil { + return nil, fmt.Errorf("Error creating http client: %v", err) + } + + return &Client{config: config, client: client}, nil +} + +func (c *Client) GetByCommandArgs(args *commandargs.Shell) (*Response, error) { + params := url.Values{} + if args.GitlabUsername != "" { + params.Add("username", args.GitlabUsername) + } else if args.GitlabKeyId != "" { + params.Add("key_id", args.GitlabKeyId) + } else { + // There was no 'who' information, this matches the ruby error + // message. + return nil, fmt.Errorf("who='' is invalid") + } + + return c.getResponse(params) +} + +func (c *Client) getResponse(params url.Values) (*Response, error) { + path := "/discover?" + params.Encode() + + response, err := c.client.Get(path) + if err != nil { + return nil, err + } + defer response.Body.Close() + + return parse(response) +} + +func parse(hr *http.Response) (*Response, error) { + response := &Response{} + if err := gitlabnet.ParseJSON(hr, response); err != nil { + return nil, err + } + + return response, nil +} + +func (r *Response) IsAnonymous() bool { + return r.UserId < 1 +} diff --git a/internal/gitlabnet/discover/client_test.go b/internal/gitlabnet/discover/client_test.go new file mode 100644 index 0000000..66e234b --- /dev/null +++ b/internal/gitlabnet/discover/client_test.go @@ -0,0 +1,137 @@ +package discover + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "testing" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + requests []testserver.TestRequestHandler +) + +func init() { + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("key_id") == "1" { + body := &Response{ + UserId: 2, + Username: "alex-doe", + Name: "Alex Doe", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "jane-doe" { + body := &Response{ + UserId: 1, + Username: "jane-doe", + Name: "Jane Doe", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "broken_message" { + w.WriteHeader(http.StatusForbidden) + body := &gitlabnet.ErrorResponse{ + Message: "Not allowed!", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "broken_json" { + w.Write([]byte("{ \"message\": \"broken json!\"")) + } else if r.URL.Query().Get("username") == "broken_empty" { + w.WriteHeader(http.StatusForbidden) + } else { + fmt.Fprint(w, "null") + } + }, + }, + } +} + +func TestGetByKeyId(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + params := url.Values{} + params.Add("key_id", "1") + result, err := client.getResponse(params) + assert.NoError(t, err) + assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result) +} + +func TestGetByUsername(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + params := url.Values{} + params.Add("username", "jane-doe") + result, err := client.getResponse(params) + assert.NoError(t, err) + assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result) +} + +func TestMissingUser(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + params := url.Values{} + params.Add("username", "missing") + result, err := client.getResponse(params) + assert.NoError(t, err) + assert.True(t, result.IsAnonymous()) +} + +func TestErrorResponses(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + testCases := []struct { + desc string + fakeUsername string + expectedError string + }{ + { + desc: "A response with an error message", + fakeUsername: "broken_message", + expectedError: "Not allowed!", + }, + { + desc: "A response with bad JSON", + fakeUsername: "broken_json", + expectedError: "Parsing failed", + }, + { + desc: "An error response without message", + fakeUsername: "broken_empty", + expectedError: "Internal API error (403)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + params := url.Values{} + params.Add("username", tc.fakeUsername) + resp, err := client.getResponse(params) + + assert.EqualError(t, err, tc.expectedError) + assert.Nil(t, resp) + }) + } +} + +func setup(t *testing.T) (*Client, func()) { + url, cleanup := testserver.StartSocketHttpServer(t, requests) + + client, err := NewClient(&config.Config{GitlabUrl: url}) + require.NoError(t, err) + + return client, cleanup +} diff --git a/internal/gitlabnet/healthcheck/client.go b/internal/gitlabnet/healthcheck/client.go new file mode 100644 index 0000000..7db682a --- /dev/null +++ b/internal/gitlabnet/healthcheck/client.go @@ -0,0 +1,54 @@ +package healthcheck + +import ( + "fmt" + "net/http" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" +) + +const ( + checkPath = "/check" +) + +type Client struct { + config *config.Config + client *gitlabnet.GitlabClient +} + +type Response struct { + APIVersion string `json:"api_version"` + GitlabVersion string `json:"gitlab_version"` + GitlabRevision string `json:"gitlab_rev"` + Redis bool `json:"redis"` +} + +func NewClient(config *config.Config) (*Client, error) { + client, err := gitlabnet.GetClient(config) + if err != nil { + return nil, fmt.Errorf("Error creating http client: %v", err) + } + + return &Client{config: config, client: client}, nil +} + +func (c *Client) Check() (*Response, error) { + resp, err := c.client.Get(checkPath) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + return parse(resp) +} + +func parse(hr *http.Response) (*Response, error) { + response := &Response{} + if err := gitlabnet.ParseJSON(hr, response); err != nil { + return nil, err + } + + return response, nil +} diff --git a/internal/gitlabnet/healthcheck/client_test.go b/internal/gitlabnet/healthcheck/client_test.go new file mode 100644 index 0000000..d7212b0 --- /dev/null +++ b/internal/gitlabnet/healthcheck/client_test.go @@ -0,0 +1,48 @@ +package healthcheck + +import ( + "encoding/json" + "net/http" + "testing" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + + "github.com/stretchr/testify/require" +) + +var ( + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/check", + Handler: func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(testResponse) + }, + }, + } + + testResponse = &Response{ + APIVersion: "v4", + GitlabVersion: "v12.0.0-ee", + GitlabRevision: "3b13818e8330f68625d80d9bf5d8049c41fbe197", + Redis: true, + } +) + +func TestCheck(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + result, err := client.Check() + require.NoError(t, err) + require.Equal(t, testResponse, result) +} + +func setup(t *testing.T) (*Client, func()) { + url, cleanup := testserver.StartSocketHttpServer(t, requests) + + client, err := NewClient(&config.Config{GitlabUrl: url}) + require.NoError(t, err) + + return client, cleanup +} diff --git a/internal/gitlabnet/httpclient_test.go b/internal/gitlabnet/httpclient_test.go new file mode 100644 index 0000000..a40ab6d --- /dev/null +++ b/internal/gitlabnet/httpclient_test.go @@ -0,0 +1,96 @@ +package gitlabnet + +import ( + "encoding/base64" + "fmt" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" +) + +const ( + username = "basic_auth_user" + password = "basic_auth_password" +) + +func TestBasicAuthSettings(t *testing.T) { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/get_endpoint", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + + fmt.Fprint(w, r.Header.Get("Authorization")) + }, + }, + { + Path: "/api/v4/internal/post_endpoint", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + fmt.Fprint(w, r.Header.Get("Authorization")) + }, + }, + } + config := &config.Config{HttpSettings: config.HttpSettingsConfig{User: username, Password: password}} + + client, cleanup := setup(t, config, requests) + defer cleanup() + + response, err := client.Get("/get_endpoint") + require.NoError(t, err) + testBasicAuthHeaders(t, response) + + response, err = client.Post("/post_endpoint", nil) + require.NoError(t, err) + testBasicAuthHeaders(t, response) +} + +func testBasicAuthHeaders(t *testing.T, response *http.Response) { + defer response.Body.Close() + + require.NotNil(t, response) + responseBody, err := ioutil.ReadAll(response.Body) + assert.NoError(t, err) + + headerParts := strings.Split(string(responseBody), " ") + assert.Equal(t, "Basic", headerParts[0]) + + credentials, err := base64.StdEncoding.DecodeString(headerParts[1]) + require.NoError(t, err) + + assert.Equal(t, username+":"+password, string(credentials)) +} + +func TestEmptyBasicAuthSettings(t *testing.T) { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/empty_basic_auth", + Handler: func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "", r.Header.Get("Authorization")) + }, + }, + } + + client, cleanup := setup(t, &config.Config{}, requests) + defer cleanup() + + _, err := client.Get("/empty_basic_auth") + require.NoError(t, err) +} + +func setup(t *testing.T, config *config.Config, requests []testserver.TestRequestHandler) (*GitlabClient, func()) { + url, cleanup := testserver.StartHttpServer(t, requests) + + config.GitlabUrl = url + client, err := GetClient(config) + require.NoError(t, err) + + return client, cleanup +} diff --git a/internal/gitlabnet/httpsclient_test.go b/internal/gitlabnet/httpsclient_test.go new file mode 100644 index 0000000..0acd425 --- /dev/null +++ b/internal/gitlabnet/httpsclient_test.go @@ -0,0 +1,125 @@ +package gitlabnet + +import ( + "fmt" + "io/ioutil" + "net/http" + "path" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +func TestSuccessfulRequests(t *testing.T) { + testCases := []struct { + desc string + config *config.Config + }{ + { + desc: "Valid CaFile", + config: &config.Config{ + HttpSettings: config.HttpSettingsConfig{CaFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt")}, + }, + }, + { + desc: "Valid CaPath", + config: &config.Config{ + HttpSettings: config.HttpSettingsConfig{CaPath: path.Join(testhelper.TestRoot, "certs/valid")}, + }, + }, + { + desc: "Self signed cert option enabled", + config: &config.Config{ + HttpSettings: config.HttpSettingsConfig{SelfSignedCert: true}, + }, + }, + { + desc: "Invalid cert with self signed cert option enabled", + config: &config.Config{ + HttpSettings: config.HttpSettingsConfig{SelfSignedCert: true, CaFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt")}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + client, cleanup := setupWithRequests(t, tc.config) + defer cleanup() + + response, err := client.Get("/hello") + require.NoError(t, err) + require.NotNil(t, response) + + defer response.Body.Close() + + responseBody, err := ioutil.ReadAll(response.Body) + assert.NoError(t, err) + assert.Equal(t, string(responseBody), "Hello") + }) + } +} + +func TestFailedRequests(t *testing.T) { + testCases := []struct { + desc string + config *config.Config + }{ + { + desc: "Invalid CaFile", + config: &config.Config{ + HttpSettings: config.HttpSettingsConfig{CaFile: path.Join(testhelper.TestRoot, "certs/invalid/server.crt")}, + }, + }, + { + desc: "Invalid CaPath", + config: &config.Config{ + HttpSettings: config.HttpSettingsConfig{CaPath: path.Join(testhelper.TestRoot, "certs/invalid")}, + }, + }, + { + desc: "Empty config", + config: &config.Config{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + client, cleanup := setupWithRequests(t, tc.config) + defer cleanup() + + _, err := client.Get("/hello") + require.Error(t, err) + + assert.Equal(t, err.Error(), "Internal API unreachable") + }) + } +} + +func setupWithRequests(t *testing.T, config *config.Config) (*GitlabClient, func()) { + testDirCleanup, err := testhelper.PrepareTestRootDir() + require.NoError(t, err) + defer testDirCleanup() + + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/hello", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + + fmt.Fprint(w, "Hello") + }, + }, + } + + url, cleanup := testserver.StartHttpsServer(t, requests) + + config.GitlabUrl = url + client, err := GetClient(config) + require.NoError(t, err) + + return client, cleanup +} diff --git a/internal/gitlabnet/lfsauthenticate/client.go b/internal/gitlabnet/lfsauthenticate/client.go new file mode 100644 index 0000000..d7469dd --- /dev/null +++ b/internal/gitlabnet/lfsauthenticate/client.go @@ -0,0 +1,66 @@ +package lfsauthenticate + +import ( + "fmt" + "net/http" + "strings" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" +) + +type Client struct { + config *config.Config + client *gitlabnet.GitlabClient + args *commandargs.Shell +} + +type Request struct { + Action commandargs.CommandType `json:"operation"` + Repo string `json:"project"` + KeyId string `json:"key_id,omitempty"` + UserId string `json:"user_id,omitempty"` +} + +type Response struct { + Username string `json:"username"` + LfsToken string `json:"lfs_token"` + RepoPath string `json:"repository_http_path"` + ExpiresIn int `json:"expires_in"` +} + +func NewClient(config *config.Config, args *commandargs.Shell) (*Client, error) { + client, err := gitlabnet.GetClient(config) + if err != nil { + return nil, fmt.Errorf("Error creating http client: %v", err) + } + + return &Client{config: config, client: client, args: args}, nil +} + +func (c *Client) Authenticate(action commandargs.CommandType, repo, userId string) (*Response, error) { + request := &Request{Action: action, Repo: repo} + if c.args.GitlabKeyId != "" { + request.KeyId = c.args.GitlabKeyId + } else { + request.UserId = strings.TrimPrefix(userId, "user-") + } + + response, err := c.client.Post("/lfs_authenticate", request) + if err != nil { + return nil, err + } + defer response.Body.Close() + + return parse(response) +} + +func parse(hr *http.Response) (*Response, error) { + response := &Response{} + if err := gitlabnet.ParseJSON(hr, response); err != nil { + return nil, err + } + + return response, nil +} diff --git a/internal/gitlabnet/lfsauthenticate/client_test.go b/internal/gitlabnet/lfsauthenticate/client_test.go new file mode 100644 index 0000000..6faaa63 --- /dev/null +++ b/internal/gitlabnet/lfsauthenticate/client_test.go @@ -0,0 +1,117 @@ +package lfsauthenticate + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" +) + +const ( + keyId = "123" + repo = "group/repo" + action = commandargs.UploadPack +) + +func setup(t *testing.T) []testserver.TestRequestHandler { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/lfs_authenticate", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + require.NoError(t, err) + + var request *Request + require.NoError(t, json.Unmarshal(b, &request)) + + switch request.KeyId { + case keyId: + body := map[string]interface{}{ + "username": "john", + "lfs_token": "sometoken", + "repository_http_path": "https://gitlab.com/repo/path", + "expires_in": 1800, + } + require.NoError(t, json.NewEncoder(w).Encode(body)) + case "forbidden": + w.WriteHeader(http.StatusForbidden) + case "broken": + w.WriteHeader(http.StatusInternalServerError) + } + }, + }, + } + + return requests +} + +func TestFailedRequests(t *testing.T) { + requests := setup(t) + url, cleanup := testserver.StartHttpServer(t, requests) + defer cleanup() + + testCases := []struct { + desc string + args *commandargs.Shell + expectedOutput string + }{ + { + desc: "With bad response", + args: &commandargs.Shell{GitlabKeyId: "-1", CommandType: commandargs.UploadPack}, + expectedOutput: "Parsing failed", + }, + { + desc: "With API returns an error", + args: &commandargs.Shell{GitlabKeyId: "forbidden", CommandType: commandargs.UploadPack}, + expectedOutput: "Internal API error (403)", + }, + { + desc: "With API fails", + args: &commandargs.Shell{GitlabKeyId: "broken", CommandType: commandargs.UploadPack}, + expectedOutput: "Internal API error (500)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + client, err := NewClient(&config.Config{GitlabUrl: url}, tc.args) + require.NoError(t, err) + + repo := "group/repo" + + _, err = client.Authenticate(tc.args.CommandType, repo, "") + require.Error(t, err) + + require.Equal(t, tc.expectedOutput, err.Error()) + }) + } +} + +func TestSuccessfulRequests(t *testing.T) { + requests := setup(t) + url, cleanup := testserver.StartHttpServer(t, requests) + defer cleanup() + + args := &commandargs.Shell{GitlabKeyId: keyId, CommandType: commandargs.LfsAuthenticate} + client, err := NewClient(&config.Config{GitlabUrl: url}, args) + require.NoError(t, err) + + response, err := client.Authenticate(action, repo, "") + require.NoError(t, err) + + expectedResponse := &Response{ + Username: "john", + LfsToken: "sometoken", + RepoPath: "https://gitlab.com/repo/path", + ExpiresIn: 1800, + } + + require.Equal(t, expectedResponse, response) +} diff --git a/internal/gitlabnet/testserver/gitalyserver.go b/internal/gitlabnet/testserver/gitalyserver.go new file mode 100644 index 0000000..694fd41 --- /dev/null +++ b/internal/gitlabnet/testserver/gitalyserver.go @@ -0,0 +1,78 @@ +package testserver + +import ( + "io/ioutil" + "net" + "os" + "path" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" +) + +type testGitalyServer struct{} + +func (s *testGitalyServer) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error { + req, err := stream.Recv() + if err != nil { + return err + } + + response := []byte("ReceivePack: " + req.GlId + " " + req.Repository.GlRepository) + stream.Send(&pb.SSHReceivePackResponse{Stdout: response}) + + return nil +} + +func (s *testGitalyServer) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error { + req, err := stream.Recv() + if err != nil { + return err + } + + response := []byte("UploadPack: " + req.Repository.GlRepository) + stream.Send(&pb.SSHUploadPackResponse{Stdout: response}) + + return nil +} + +func (s *testGitalyServer) SSHUploadArchive(stream pb.SSHService_SSHUploadArchiveServer) error { + req, err := stream.Recv() + if err != nil { + return err + } + + response := []byte("UploadArchive: " + req.Repository.GlRepository) + stream.Send(&pb.SSHUploadArchiveResponse{Stdout: response}) + + return nil +} + +func StartGitalyServer(t *testing.T) (string, func()) { + tempDir, _ := ioutil.TempDir("", "gitlab-shell-test-api") + gitalySocketPath := path.Join(tempDir, "gitaly.sock") + + err := os.MkdirAll(filepath.Dir(gitalySocketPath), 0700) + require.NoError(t, err) + + server := grpc.NewServer() + + listener, err := net.Listen("unix", gitalySocketPath) + require.NoError(t, err) + + pb.RegisterSSHServiceServer(server, &testGitalyServer{}) + + go server.Serve(listener) + + gitalySocketUrl := "unix:" + gitalySocketPath + cleanup := func() { + server.Stop() + os.RemoveAll(tempDir) + } + + return gitalySocketUrl, cleanup +} diff --git a/internal/gitlabnet/testserver/testserver.go b/internal/gitlabnet/testserver/testserver.go new file mode 100644 index 0000000..f3b7b71 --- /dev/null +++ b/internal/gitlabnet/testserver/testserver.go @@ -0,0 +1,82 @@ +package testserver + +import ( + "crypto/tls" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/httptest" + "os" + "path" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +var ( + tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api") + testSocket = path.Join(tempDir, "internal.sock") +) + +type TestRequestHandler struct { + Path string + Handler func(w http.ResponseWriter, r *http.Request) +} + +func StartSocketHttpServer(t *testing.T, handlers []TestRequestHandler) (string, func()) { + err := os.MkdirAll(filepath.Dir(testSocket), 0700) + require.NoError(t, err) + + socketListener, err := net.Listen("unix", testSocket) + require.NoError(t, err) + + server := http.Server{ + Handler: buildHandler(handlers), + // We'll put this server through some nasty stuff we don't want + // in our test output + ErrorLog: log.New(ioutil.Discard, "", 0), + } + go server.Serve(socketListener) + + url := "http+unix://" + testSocket + + return url, cleanupSocket +} + +func StartHttpServer(t *testing.T, handlers []TestRequestHandler) (string, func()) { + server := httptest.NewServer(buildHandler(handlers)) + + return server.URL, server.Close +} + +func StartHttpsServer(t *testing.T, handlers []TestRequestHandler) (string, func()) { + crt := path.Join(testhelper.TestRoot, "certs/valid/server.crt") + key := path.Join(testhelper.TestRoot, "certs/valid/server.key") + + server := httptest.NewUnstartedServer(buildHandler(handlers)) + cer, err := tls.LoadX509KeyPair(crt, key) + require.NoError(t, err) + + server.TLS = &tls.Config{Certificates: []tls.Certificate{cer}} + server.StartTLS() + + return server.URL, server.Close +} + +func cleanupSocket() { + os.RemoveAll(tempDir) +} + +func buildHandler(handlers []TestRequestHandler) http.Handler { + h := http.NewServeMux() + + for _, handler := range handlers { + h.HandleFunc(handler.Path, handler.Handler) + } + + return h +} diff --git a/internal/gitlabnet/twofactorrecover/client.go b/internal/gitlabnet/twofactorrecover/client.go new file mode 100644 index 0000000..a3052f8 --- /dev/null +++ b/internal/gitlabnet/twofactorrecover/client.go @@ -0,0 +1,89 @@ +package twofactorrecover + +import ( + "errors" + "fmt" + "net/http" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/discover" +) + +type Client struct { + config *config.Config + client *gitlabnet.GitlabClient +} + +type Response struct { + Success bool `json:"success"` + RecoveryCodes []string `json:"recovery_codes"` + Message string `json:"message"` +} + +type RequestBody struct { + KeyId string `json:"key_id,omitempty"` + UserId int64 `json:"user_id,omitempty"` +} + +func NewClient(config *config.Config) (*Client, error) { + client, err := gitlabnet.GetClient(config) + if err != nil { + return nil, fmt.Errorf("Error creating http client: %v", err) + } + + return &Client{config: config, client: client}, nil +} + +func (c *Client) GetRecoveryCodes(args *commandargs.Shell) ([]string, error) { + requestBody, err := c.getRequestBody(args) + + if err != nil { + return nil, err + } + + response, err := c.client.Post("/two_factor_recovery_codes", requestBody) + if err != nil { + return nil, err + } + defer response.Body.Close() + + return parse(response) +} + +func parse(hr *http.Response) ([]string, error) { + response := &Response{} + if err := gitlabnet.ParseJSON(hr, response); err != nil { + return nil, err + } + + if !response.Success { + return nil, errors.New(response.Message) + } + + return response.RecoveryCodes, nil +} + +func (c *Client) getRequestBody(args *commandargs.Shell) (*RequestBody, error) { + client, err := discover.NewClient(c.config) + + if err != nil { + return nil, err + } + + var requestBody *RequestBody + if args.GitlabKeyId != "" { + requestBody = &RequestBody{KeyId: args.GitlabKeyId} + } else { + userInfo, err := client.GetByCommandArgs(args) + + if err != nil { + return nil, err + } + + requestBody = &RequestBody{UserId: userInfo.UserId} + } + + return requestBody, nil +} diff --git a/internal/gitlabnet/twofactorrecover/client_test.go b/internal/gitlabnet/twofactorrecover/client_test.go new file mode 100644 index 0000000..d5073e3 --- /dev/null +++ b/internal/gitlabnet/twofactorrecover/client_test.go @@ -0,0 +1,158 @@ +package twofactorrecover + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/discover" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" +) + +var ( + requests []testserver.TestRequestHandler +) + +func initialize(t *testing.T) { + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/two_factor_recovery_codes", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + + require.NoError(t, err) + + var requestBody *RequestBody + json.Unmarshal(b, &requestBody) + + switch requestBody.KeyId { + case "0": + body := map[string]interface{}{ + "success": true, + "recovery_codes": [2]string{"recovery 1", "codes 1"}, + } + json.NewEncoder(w).Encode(body) + case "1": + body := map[string]interface{}{ + "success": false, + "message": "missing user", + } + json.NewEncoder(w).Encode(body) + case "2": + w.WriteHeader(http.StatusForbidden) + body := &gitlabnet.ErrorResponse{ + Message: "Not allowed!", + } + json.NewEncoder(w).Encode(body) + case "3": + w.Write([]byte("{ \"message\": \"broken json!\"")) + case "4": + w.WriteHeader(http.StatusForbidden) + } + + if requestBody.UserId == 1 { + body := map[string]interface{}{ + "success": true, + "recovery_codes": [2]string{"recovery 2", "codes 2"}, + } + json.NewEncoder(w).Encode(body) + } + }, + }, + { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + body := &discover.Response{ + UserId: 1, + Username: "jane-doe", + Name: "Jane Doe", + } + json.NewEncoder(w).Encode(body) + }, + }, + } +} + +func TestGetRecoveryCodesByKeyId(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + args := &commandargs.Shell{GitlabKeyId: "0"} + result, err := client.GetRecoveryCodes(args) + assert.NoError(t, err) + assert.Equal(t, []string{"recovery 1", "codes 1"}, result) +} + +func TestGetRecoveryCodesByUsername(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + args := &commandargs.Shell{GitlabUsername: "jane-doe"} + result, err := client.GetRecoveryCodes(args) + assert.NoError(t, err) + assert.Equal(t, []string{"recovery 2", "codes 2"}, result) +} + +func TestMissingUser(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + args := &commandargs.Shell{GitlabKeyId: "1"} + _, err := client.GetRecoveryCodes(args) + assert.Equal(t, "missing user", err.Error()) +} + +func TestErrorResponses(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + testCases := []struct { + desc string + fakeId string + expectedError string + }{ + { + desc: "A response with an error message", + fakeId: "2", + expectedError: "Not allowed!", + }, + { + desc: "A response with bad JSON", + fakeId: "3", + expectedError: "Parsing failed", + }, + { + desc: "An error response without message", + fakeId: "4", + expectedError: "Internal API error (403)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + args := &commandargs.Shell{GitlabKeyId: tc.fakeId} + resp, err := client.GetRecoveryCodes(args) + + assert.EqualError(t, err, tc.expectedError) + assert.Nil(t, resp) + }) + } +} + +func setup(t *testing.T) (*Client, func()) { + initialize(t) + url, cleanup := testserver.StartSocketHttpServer(t, requests) + + client, err := NewClient(&config.Config{GitlabUrl: url}) + require.NoError(t, err) + + return client, cleanup +} diff --git a/internal/handler/exec.go b/internal/handler/exec.go new file mode 100644 index 0000000..ba9a4ff --- /dev/null +++ b/internal/handler/exec.go @@ -0,0 +1,96 @@ +package handler + +import ( + "context" + "fmt" + "os" + + "gitlab.com/gitlab-org/gitaly/auth" + "gitlab.com/gitlab-org/gitaly/client" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/labkit/tracing" + "google.golang.org/grpc" +) + +// GitalyHandlerFunc implementations are responsible for making +// an appropriate Gitaly call using the provided client and context +// and returning an error from the Gitaly call. +type GitalyHandlerFunc func(ctx context.Context, client *grpc.ClientConn) (int32, error) + +type GitalyConn struct { + ctx context.Context + conn *grpc.ClientConn + close func() +} + +type GitalyCommand struct { + Config *config.Config + ServiceName string + Address string + Token string +} + +// RunGitalyCommand provides a bootstrap for Gitaly commands executed +// through GitLab-Shell. It ensures that logging, tracing and other +// common concerns are configured before executing the `handler`. +func (gc *GitalyCommand) RunGitalyCommand(handler GitalyHandlerFunc) error { + gitalyConn, err := getConn(gc) + + if err != nil { + return err + } + + _, err = handler(gitalyConn.ctx, gitalyConn.conn) + + gitalyConn.close() + + return err +} + +func getConn(gc *GitalyCommand) (*GitalyConn, error) { + if gc.Address == "" { + return nil, fmt.Errorf("no gitaly_address given") + } + + connOpts := client.DefaultDialOpts + if gc.Token != "" { + connOpts = append(client.DefaultDialOpts, grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(gc.Token))) + } + + // Use a working directory that won't get removed or unmounted. + if err := os.Chdir("/"); err != nil { + return nil, err + } + + // Configure distributed tracing + serviceName := fmt.Sprintf("gitlab-shell-%v", gc.ServiceName) + closer := tracing.Initialize( + tracing.WithServiceName(serviceName), + + // For GitLab-Shell, we explicitly initialize tracing from a config file + // instead of the default environment variable (using GITLAB_TRACING) + // This decision was made owing to the difficulty in passing environment + // variables into GitLab-Shell processes. + // + // Processes are spawned as children of the SSH daemon, which tightly + // controls environment variables; doing this means we don't have to + // enable PermitUserEnvironment + tracing.WithConnectionString(gc.Config.GitlabTracing), + ) + + ctx, finished := tracing.ExtractFromEnv(context.Background()) + + conn, err := client.Dial(gc.Address, connOpts) + if err != nil { + return nil, err + } + + finish := func() { + finished() + closer.Close() + conn.Close() + } + + return &GitalyConn{ctx: ctx, conn: conn, close: finish}, nil +} diff --git a/internal/handler/exec_test.go b/internal/handler/exec_test.go new file mode 100644 index 0000000..6c7d3f5 --- /dev/null +++ b/internal/handler/exec_test.go @@ -0,0 +1,42 @@ +package handler + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" +) + +func makeHandler(t *testing.T, err error) func(context.Context, *grpc.ClientConn) (int32, error) { + return func(ctx context.Context, client *grpc.ClientConn) (int32, error) { + require.NotNil(t, ctx) + require.NotNil(t, client) + + return 0, err + } +} + +func TestRunGitalyCommand(t *testing.T) { + cmd := GitalyCommand{ + Config: &config.Config{}, + Address: "tcp://localhost:9999", + } + + err := cmd.RunGitalyCommand(makeHandler(t, nil)) + require.NoError(t, err) + + expectedErr := errors.New("error") + err = cmd.RunGitalyCommand(makeHandler(t, expectedErr)) + require.Equal(t, err, expectedErr) +} + +func TestMissingGitalyAddress(t *testing.T) { + cmd := GitalyCommand{Config: &config.Config{}} + + err := cmd.RunGitalyCommand(makeHandler(t, nil)) + require.EqualError(t, err, "no gitaly_address given") +} diff --git a/internal/keyline/key_line.go b/internal/keyline/key_line.go new file mode 100644 index 0000000..c29a320 --- /dev/null +++ b/internal/keyline/key_line.go @@ -0,0 +1,62 @@ +package keyline + +import ( + "errors" + "fmt" + "path" + "regexp" + "strings" + + "gitlab.com/gitlab-org/gitlab-shell/internal/executable" +) + +var ( + keyRegex = regexp.MustCompile(`\A[a-z0-9-]+\z`) +) + +const ( + PublicKeyPrefix = "key" + PrincipalPrefix = "username" + SshOptions = "no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty" +) + +type KeyLine struct { + Id string // This can be either an ID of a Key or username + Value string // This can be either a public key or a principal name + Prefix string + RootDir string +} + +func NewPublicKeyLine(id string, publicKey string, rootDir string) (*KeyLine, error) { + return newKeyLine(id, publicKey, PublicKeyPrefix, rootDir) +} + +func NewPrincipalKeyLine(keyId string, principal string, rootDir string) (*KeyLine, error) { + return newKeyLine(keyId, principal, PrincipalPrefix, rootDir) +} + +func (k *KeyLine) ToString() string { + command := fmt.Sprintf("%s %s-%s", path.Join(k.RootDir, executable.BinDir, executable.GitlabShell), k.Prefix, k.Id) + + return fmt.Sprintf(`command="%s",%s %s`, command, SshOptions, k.Value) +} + +func newKeyLine(id string, value string, prefix string, rootDir string) (*KeyLine, error) { + if err := validate(id, value); err != nil { + return nil, err + } + + return &KeyLine{Id: id, Value: value, Prefix: prefix, RootDir: rootDir}, nil +} + +func validate(id string, value string) error { + if !keyRegex.MatchString(id) { + return errors.New(fmt.Sprintf("Invalid key_id: %s", id)) + } + + if strings.Contains(value, "\n") { + return errors.New(fmt.Sprintf("Invalid value: %s", value)) + } + + return nil +} diff --git a/internal/keyline/key_line_test.go b/internal/keyline/key_line_test.go new file mode 100644 index 0000000..c6883c0 --- /dev/null +++ b/internal/keyline/key_line_test.go @@ -0,0 +1,82 @@ +package keyline + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFailingNewPublicKeyLine(t *testing.T) { + testCases := []struct { + desc string + id string + publicKey string + expectedError string + }{ + { + desc: "When Id has non-alphanumeric and non-dash characters in it", + id: "key\n1", + publicKey: "public-key", + expectedError: "Invalid key_id: key\n1", + }, + { + desc: "When public key has newline in it", + id: "key", + publicKey: "public\nkey", + expectedError: "Invalid value: public\nkey", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := NewPublicKeyLine(tc.id, tc.publicKey, "root-dir") + + require.Empty(t, result) + require.EqualError(t, err, tc.expectedError) + }) + } +} + +func TestFailingNewPrincipalKeyLine(t *testing.T) { + testCases := []struct { + desc string + keyId string + principal string + expectedError string + }{ + { + desc: "When username has non-alphanumeric and non-dash characters in it", + keyId: "username\n1", + principal: "principal", + expectedError: "Invalid key_id: username\n1", + }, + { + desc: "When principal has newline in it", + keyId: "username", + principal: "principal\n1", + expectedError: "Invalid value: principal\n1", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := NewPrincipalKeyLine(tc.keyId, tc.principal, "root-dir") + + require.Empty(t, result) + require.EqualError(t, err, tc.expectedError) + }) + } +} + +func TestToString(t *testing.T) { + keyLine := &KeyLine{ + Id: "1", + Value: "public-key", + Prefix: "key", + RootDir: "/tmp", + } + + result := keyLine.ToString() + + require.Equal(t, `command="/tmp/bin/gitlab-shell key-1",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty public-key`, result) +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 0000000..bbc6a51 --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,82 @@ +package logger + +import ( + "fmt" + "io" + golog "log" + "log/syslog" + "os" + "sync" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + + log "github.com/sirupsen/logrus" +) + +var ( + logWriter io.Writer + bootstrapLogger *golog.Logger + pid int + mutex sync.Mutex + ProgName string +) + +func Configure(cfg *config.Config) error { + mutex.Lock() + defer mutex.Unlock() + + pid = os.Getpid() + + var err error + logWriter, err = os.OpenFile(cfg.LogFile, os.O_WRONLY|os.O_APPEND, 0) + if err != nil { + return err + } + + log.SetOutput(logWriter) + if cfg.LogFormat == "json" { + log.SetFormatter(&log.JSONFormatter{}) + } + + return nil +} + +func logPrint(msg string, err error) { + mutex.Lock() + defer mutex.Unlock() + + if logWriter == nil { + bootstrapLogPrint(msg, err) + return + } + + log.WithError(err).WithFields(log.Fields{ + "pid": pid, + }).Error(msg) +} + +func Fatal(msg string, err error) { + logPrint(msg, err) + // We don't show the error to the end user because it can leak + // information that is private to the GitLab server. + fmt.Fprintf(os.Stderr, "%s: fatal: %s\n", ProgName, msg) + os.Exit(1) +} + +// If our log file is not available we want to log somewhere else, but +// not to standard error because that leaks information to the user. This +// function attemps to log to syslog. +// +// We assume the logging mutex is already locked. +func bootstrapLogPrint(msg string, err error) { + if bootstrapLogger == nil { + var err error + bootstrapLogger, err = syslog.NewLogger(syslog.LOG_ERR|syslog.LOG_USER, 0) + if err != nil { + // The message will not be logged. + return + } + } + + bootstrapLogger.Print(ProgName+":", msg+":", err) +} diff --git a/internal/sshenv/sshenv.go b/internal/sshenv/sshenv.go new file mode 100644 index 0000000..387feb2 --- /dev/null +++ b/internal/sshenv/sshenv.go @@ -0,0 +1,15 @@ +package sshenv + +import ( + "os" + "strings" +) + +func LocalAddr() string { + address := os.Getenv("SSH_CONNECTION") + + if address != "" { + return strings.Fields(address)[0] + } + return "" +} diff --git a/internal/sshenv/sshenv_test.go b/internal/sshenv/sshenv_test.go new file mode 100644 index 0000000..e4a640e --- /dev/null +++ b/internal/sshenv/sshenv_test.go @@ -0,0 +1,20 @@ +package sshenv + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +func TestLocalAddr(t *testing.T) { + cleanup, err := testhelper.Setenv("SSH_CONNECTION", "127.0.0.1 0") + require.NoError(t, err) + defer cleanup() + + require.Equal(t, LocalAddr(), "127.0.0.1") +} + +func TestEmptyLocalAddr(t *testing.T) { + require.Equal(t, LocalAddr(), "") +} diff --git a/internal/testhelper/requesthandlers/requesthandlers.go b/internal/testhelper/requesthandlers/requesthandlers.go new file mode 100644 index 0000000..11817e8 --- /dev/null +++ b/internal/testhelper/requesthandlers/requesthandlers.go @@ -0,0 +1,58 @@ +package requesthandlers + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/testserver" +) + +func BuildDisallowedByApiHandlers(t *testing.T) []testserver.TestRequestHandler { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/allowed", + Handler: func(w http.ResponseWriter, r *http.Request) { + body := map[string]interface{}{ + "status": false, + "message": "Disallowed by API call", + } + w.WriteHeader(http.StatusForbidden) + require.NoError(t, json.NewEncoder(w).Encode(body)) + }, + }, + } + + return requests +} + +func BuildAllowedWithGitalyHandlers(t *testing.T, gitalyAddress string) []testserver.TestRequestHandler { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/allowed", + Handler: func(w http.ResponseWriter, r *http.Request) { + body := map[string]interface{}{ + "status": true, + "gl_id": "1", + "gitaly": map[string]interface{}{ + "repository": map[string]interface{}{ + "storage_name": "storage_name", + "relative_path": "relative_path", + "git_object_directory": "path/to/git_object_directory", + "git_alternate_object_directories": []string{"path/to/git_alternate_object_directory"}, + "gl_repository": "group/repo", + "gl_project_path": "group/project-path", + }, + "address": gitalyAddress, + "token": "token", + }, + } + require.NoError(t, json.NewEncoder(w).Encode(body)) + }, + }, + } + + return requests +} diff --git a/internal/testhelper/testdata/testroot/.gitlab_shell_secret b/internal/testhelper/testdata/testroot/.gitlab_shell_secret new file mode 100644 index 0000000..9bd459d --- /dev/null +++ b/internal/testhelper/testdata/testroot/.gitlab_shell_secret @@ -0,0 +1 @@ +default-secret-content
\ No newline at end of file diff --git a/internal/testhelper/testdata/testroot/certs/invalid/server.crt b/internal/testhelper/testdata/testroot/certs/invalid/server.crt new file mode 100644 index 0000000..f8a42c1 --- /dev/null +++ b/internal/testhelper/testdata/testroot/certs/invalid/server.crt @@ -0,0 +1,10 @@ +-----BEGIN CERTIFICATE----- +MinvalidcertAOvHjs6cs1R9MAoGCCqGSM49BAMCMBQxEjAQBgNVBAMMCWxvY2Fs +ainvalidcertOTA0MjQxNjM4NTBaFw0yOTA0MjExNjM4NTBaMBQxEjAQBgNVBAMM +CinvalidcertdDB2MBAGByqGSM49AgEGBSuBBAAiA2IABJ5m7oW9OuL7aTAC04sL +3invalidcertdB2L0GsVCImav4PEpx6UAjkoiNGW9j0zPdNgxTYDjiCaGmr1aY2X +kinvalidcert7MNq7H8v7Ce/vrKkcDMOX8Gd/ddT3dEVqzAKBggqhkjOPQQDAgNp +AinvalidcertswcyjiB+A+ZjMSfaOsA2hAP0I3fkTcry386DePViMfnaIjm7rcuu +Jinvalidcert5V5CHypOxio1tOtGjaDkSH2FCdoatMyIe02+F6TIo44i4J/zjN52 +Jinvalidcert +-----END CERTIFICATE----- diff --git a/internal/testhelper/testdata/testroot/certs/valid/dir/.gitkeep b/internal/testhelper/testdata/testroot/certs/valid/dir/.gitkeep new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/internal/testhelper/testdata/testroot/certs/valid/dir/.gitkeep diff --git a/internal/testhelper/testdata/testroot/certs/valid/server.crt b/internal/testhelper/testdata/testroot/certs/valid/server.crt new file mode 100644 index 0000000..11f1da7 --- /dev/null +++ b/internal/testhelper/testdata/testroot/certs/valid/server.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDrjCCApagAwIBAgIUHVNTmyz3p+7xSEMkSfhPz4BZfqwwDQYJKoZIhvcNAQEL +BQAwTjELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExEjAQBgNVBAcM +CVRoZSBDbG91ZDEWMBQGA1UECgwNTXkgQ29tcGFueSBDQTAeFw0xOTA5MjAxMDQ3 +NTlaFw0yOTA5MTcxMDQ3NTlaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp +Zm9ybmlhMRIwEAYDVQQHDAlUaGUgQ2xvdWQxDTALBgNVBAoMBERlbW8xFzAVBgNV +BAMMDk15IENlcnRpZmljYXRlMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAmte3G/eD+quamwyFl+2jEo8ngSAT0FWeY5ZAwRvdF4FgtTBLvbAdTnyi7pHM +esCSUkyxXHHPazM4SDV6uiu5LNKF0iz/NY76rLtFoqSGUgygTZHVbZ6NRXCNUZ0P +slD95wOCWvS9t9xgNXry66k8+mfZNhE+cFQfrO/pN5WpNuGyWTfKlUQw5NVL3mob +j3tSjI+wzSpbPMvbTQoBiZ/VHkyyc15YdrbePwFB2dJbxE/Xgsyk/TwWSUFnAs6i +1x2t+423NIm9rIDTdW2YYJJXv3MUcdDIxJnY0beGePMIymn9ZIRUJtK/ZXmwMb52 +v70+YTcsG67uSm31CR8jNt8qpQIDAQABo3QwcjAJBgNVHRMEAjAAMB0GA1UdDgQW +BBTxZ9SORmIwDs90TW8UXIVhDst4kjALBgNVHQ8EBAMCBaAwHQYDVR0lBBYwFAYI +KwYBBQUHAwIGCCsGAQUFBwMBMBoGA1UdEQQTMBGHBH8AAAGCCWxvY2FsaG9zdDAN +BgkqhkiG9w0BAQsFAAOCAQEAf4Iq94Su9TlkReMS4x2N5xZru9YoKQtrrxqWSRbp +oh5Lwtk9rJPy6q4IEPXzDsRI1YWCZe1Fw7zdiNfmoFRxjs59MBJ9YVrcFeyeAILg +LiAiRcGth2THpikCnLxmniGHUUX1WfjmcDEYMIs6BZ98N64VWwtuZqcJnJPmQs64 +lDrgW9oz6/8hPMeW58ok8PjkiG+E+srBaURoKwNe7vfPRVyq45N67/juH+4o6QBd +WP6ACjDM3RnxyWyW0S+sl3i3EAGgtwM6RIDhOG238HOIiA/I/+CCmITsvujz6jMN +bLdoPfnatZ7f5m9DuoOsGlYAZbLfOl2NywgO0jAlnHJGEQ== +-----END CERTIFICATE----- diff --git a/internal/testhelper/testdata/testroot/certs/valid/server.key b/internal/testhelper/testdata/testroot/certs/valid/server.key new file mode 100644 index 0000000..acec0fb --- /dev/null +++ b/internal/testhelper/testdata/testroot/certs/valid/server.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAmte3G/eD+quamwyFl+2jEo8ngSAT0FWeY5ZAwRvdF4FgtTBL +vbAdTnyi7pHMesCSUkyxXHHPazM4SDV6uiu5LNKF0iz/NY76rLtFoqSGUgygTZHV +bZ6NRXCNUZ0PslD95wOCWvS9t9xgNXry66k8+mfZNhE+cFQfrO/pN5WpNuGyWTfK +lUQw5NVL3mobj3tSjI+wzSpbPMvbTQoBiZ/VHkyyc15YdrbePwFB2dJbxE/Xgsyk +/TwWSUFnAs6i1x2t+423NIm9rIDTdW2YYJJXv3MUcdDIxJnY0beGePMIymn9ZIRU +JtK/ZXmwMb52v70+YTcsG67uSm31CR8jNt8qpQIDAQABAoIBAEJQyNdtdlTRUfG9 +tymOWR0FuoGO322GfcNhAnKyIEqE2oo/GPEwkByhPJa4Ur7v4rrkpcFV7OOYmC40 +2U8KktAjibSuGM8zYSDBQ92YYP6a8bzHDIVaNl7bCWs+vQ49qcBavGWAFBC+jWXa +Nle/r6H/AAQr9nXdUYObbGKl8kbSUBNAqQHILsNyxQsAo12oqRnUWhIbfzUFBr1m +us93OsvpOYWgkbaBWk0brjp2X0eNGHctTboFxRknJcU6MQVL5degbgXhnCm4ir4O +E2KMubEwxePr5fPotWNQXCVin85OQv1eb70anfwoA2b5/ykb57jo5EDoiUoFsjLz +KLAaRQECgYEAzZNP/CpwCh5s31SDr7ajYfNIu8ie370g2Qbf4jrqVrOJ8Sj1LRYB +lS5+QbSRu4W6Ani3AQwZA09lS608G8w5rD7YGRVDCFuwJt+Yz5GcsSkso9B8DR4h +vCe2WuDutz7M5ikP1DAc/9x5HIzjQijxM1JJCNU2nR6QoFvV6wpVcpECgYEAwNK9 +oTqyb7UjNinAo9PFrFpnbX+DoGokGPsRyUwi9UkyRR0Uf7Kxjoq2C8zsCvnGdrE7 +kwUiWjyfAgMDF8+iWHYO1vD7m6NL31h/AAmo0NEQIBs0LFj0lF0xORzvXdTjhvuG +LxXhm927z4WBOCLTn8FAsBUjVBpmB6ffyZCVWNUCgYA3P4j2fz0/KvAdkSwW9CGy +uFxqwz8XaE/Eo9lVhnnmNTg0TMqfhFOGkUkzRWEJIaZc9a5RJLwwLI1Pqk4GNnul +c/pFu3YZb/LGb780wbB32FX77JL6P4fXdmDGyb6+Fq2giZaMcyXICauu5ZpJ9JDm +Nw4TxqF31ngN8MBr+4n9UQKBgAkxAoEQ/zh79fW6/8fPbHjOxmdd0LRw2s+mCC8E +RhZTKuZIgJWluvkEe7EMT6QmS+OUhzZ25DBQ+3NpGVilOSPmXMa6LgQ5QIChA0zJ +KRbrIE2nflEu3FnGJ3aFfpOGdmIU00yjSmHXrAA0aPh4EIZo++Bo4Yo8x+hNhElj +bvsRAoGADYZTUchbiVndk5QtnbwlDjrF5PmgjoDboBfv9/6FU+DzQRyOpl3kr0hs +OcZGE6xPZJidv1Bcv60L1VzTMj7spvMRTeumn2zEQGjkl6i/fSZzawjmKaKXKNkC +YfoV0RepB4TlNYGICaTcV+aKRIXivcpBGfduZEb39iUKCjh9Afg= +-----END RSA PRIVATE KEY----- diff --git a/internal/testhelper/testdata/testroot/config.yml b/internal/testhelper/testdata/testroot/config.yml new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/internal/testhelper/testdata/testroot/config.yml diff --git a/internal/testhelper/testdata/testroot/custom/my-contents-is-secret b/internal/testhelper/testdata/testroot/custom/my-contents-is-secret new file mode 100644 index 0000000..645b575 --- /dev/null +++ b/internal/testhelper/testdata/testroot/custom/my-contents-is-secret @@ -0,0 +1 @@ +custom-secret-content
\ No newline at end of file diff --git a/internal/testhelper/testdata/testroot/gitlab-shell.log b/internal/testhelper/testdata/testroot/gitlab-shell.log new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/internal/testhelper/testdata/testroot/gitlab-shell.log diff --git a/internal/testhelper/testdata/testroot/responses/allowed.json b/internal/testhelper/testdata/testroot/responses/allowed.json new file mode 100644 index 0000000..d0403d9 --- /dev/null +++ b/internal/testhelper/testdata/testroot/responses/allowed.json @@ -0,0 +1,22 @@ +{ + "status": true, + "gl_repository": "project-26", + "gl_project_path": "group/private", + "gl_id": "user-1", + "gl_username": "root", + "git_config_options": ["option"], + "gitaly": { + "repository": { + "storage_name": "default", + "relative_path": "@hashed/5f/9c/5f9c4ab08cac7457e9111a30e4664920607ea2c115a1433d7be98e97e64244ca.git", + "git_object_directory": "path/to/git_object_directory", + "git_alternate_object_directories": ["path/to/git_alternate_object_directory"], + "gl_repository": "project-26", + "gl_project_path": "group/private" + }, + "address": "unix:gitaly.socket", + "token": "token" + }, + "git_protocol": "protocol", + "gl_console_messages": ["console", "message"] +} diff --git a/internal/testhelper/testdata/testroot/responses/allowed_with_payload.json b/internal/testhelper/testdata/testroot/responses/allowed_with_payload.json new file mode 100644 index 0000000..331c3a9 --- /dev/null +++ b/internal/testhelper/testdata/testroot/responses/allowed_with_payload.json @@ -0,0 +1,31 @@ +{ + "status": true, + "gl_repository": "project-26", + "gl_project_path": "group/private", + "gl_id": "user-1", + "gl_username": "root", + "git_config_options": ["option"], + "gitaly": { + "repository": { + "storage_name": "default", + "relative_path": "@hashed/5f/9c/5f9c4ab08cac7457e9111a30e4664920607ea2c115a1433d7be98e97e64244ca.git", + "git_object_directory": "path/to/git_object_directory", + "git_alternate_object_directories": ["path/to/git_alternate_object_directory"], + "gl_repository": "project-26", + "gl_project_path": "group/private" + }, + "address": "unix:gitaly.socket", + "token": "token" + }, + "payload" : { + "action": "geo_proxy_to_primary", + "data": { + "api_endpoints": ["geo/proxy_git_push_ssh/info_refs", "geo/proxy_git_push_ssh/push"], + "gl_username": "custom", + "primary_repo": "https://repo/path", + "info_message": "message" + } + }, + "git_protocol": "protocol", + "gl_console_messages": ["console", "message"] +} diff --git a/internal/testhelper/testhelper.go b/internal/testhelper/testhelper.go new file mode 100644 index 0000000..a925c79 --- /dev/null +++ b/internal/testhelper/testhelper.go @@ -0,0 +1,93 @@ +package testhelper + +import ( + "fmt" + "io/ioutil" + "os" + "path" + "runtime" + + "github.com/otiai10/copy" +) + +var ( + TestRoot, _ = ioutil.TempDir("", "test-gitlab-shell") +) + +func TempEnv(env map[string]string) func() { + var original = make(map[string]string) + for key, value := range env { + original[key] = os.Getenv(key) + os.Setenv(key, value) + } + + return func() { + for key, originalValue := range original { + os.Setenv(key, originalValue) + } + } +} + +func PrepareTestRootDir() (func(), error) { + if err := os.MkdirAll(TestRoot, 0700); err != nil { + return nil, err + } + + var oldWd string + cleanup := func() { + if oldWd != "" { + err := os.Chdir(oldWd) + if err != nil { + panic(err) + } + } + + if err := os.RemoveAll(TestRoot); err != nil { + panic(err) + } + } + + if err := copyTestData(); err != nil { + cleanup() + return nil, err + } + + oldWd, err := os.Getwd() + if err != nil { + cleanup() + return nil, err + } + + if err := os.Chdir(TestRoot); err != nil { + cleanup() + return nil, err + } + + return cleanup, nil +} + +func copyTestData() error { + testDataDir, err := getTestDataDir() + if err != nil { + return err + } + + testdata := path.Join(testDataDir, "testroot") + + return copy.Copy(testdata, TestRoot) +} + +func getTestDataDir() (string, error) { + _, currentFile, _, ok := runtime.Caller(0) + if !ok { + return "", fmt.Errorf("Could not get caller info") + } + + return path.Join(path.Dir(currentFile), "testdata"), nil +} + +func Setenv(key, value string) (func(), error) { + oldValue := os.Getenv(key) + err := os.Setenv(key, value) + return func() { os.Setenv(key, oldValue) }, err +} |