From 98dbdfb758703428626d54b2a257565a44509a55 Mon Sep 17 00:00:00 2001 From: Igor Date: Thu, 21 Mar 2019 11:53:09 +0000 Subject: Provide go implementation for 2fa_recovery_codes command --- go/cmd/gitlab-shell/main.go | 22 +-- go/internal/command/command.go | 7 +- go/internal/command/command_test.go | 14 ++ go/internal/command/commandargs/command_args.go | 7 +- go/internal/command/discover/discover.go | 18 +-- go/internal/command/discover/discover_test.go | 6 +- go/internal/command/fallback/fallback.go | 4 +- go/internal/command/readwriter/readwriter.go | 9 ++ go/internal/command/reporting/reporter.go | 8 - .../command/twofactorrecover/twofactorrecover.go | 64 ++++++++ .../twofactorrecover/twofactorrecover_test.go | 135 +++++++++++++++++ go/internal/gitlabnet/client.go | 3 +- go/internal/gitlabnet/client_test.go | 73 +++++++++- go/internal/gitlabnet/discover/client.go | 13 ++ go/internal/gitlabnet/socketclient.go | 20 +++ go/internal/gitlabnet/twofactorrecover/client.go | 104 +++++++++++++ .../gitlabnet/twofactorrecover/client_test.go | 161 +++++++++++++++++++++ spec/gitlab_shell_authorized_keys_check_spec.rb | 45 +----- spec/gitlab_shell_gitlab_shell_spec.rb | 66 ++------- spec/gitlab_shell_two_factor_recovery_spec.rb | 128 ++++++++++++++++ spec/support/gitlab_shell_setup.rb | 58 ++++++++ 21 files changed, 825 insertions(+), 140 deletions(-) create mode 100644 go/internal/command/readwriter/readwriter.go delete mode 100644 go/internal/command/reporting/reporter.go create mode 100644 go/internal/command/twofactorrecover/twofactorrecover.go create mode 100644 go/internal/command/twofactorrecover/twofactorrecover_test.go create mode 100644 go/internal/gitlabnet/twofactorrecover/client.go create mode 100644 go/internal/gitlabnet/twofactorrecover/client_test.go create mode 100644 spec/gitlab_shell_two_factor_recovery_spec.rb create mode 100644 spec/support/gitlab_shell_setup.rb diff --git a/go/cmd/gitlab-shell/main.go b/go/cmd/gitlab-shell/main.go index 2ed319d..51b5210 100644 --- a/go/cmd/gitlab-shell/main.go +++ b/go/cmd/gitlab-shell/main.go @@ -7,28 +7,28 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/go/internal/command" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback" - "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" ) var ( - binDir string - rootDir string - reporter *reporting.Reporter + binDir string + rootDir string + readWriter *readwriter.ReadWriter ) func init() { binDir = filepath.Dir(os.Args[0]) rootDir = filepath.Dir(binDir) - reporter = &reporting.Reporter{Out: os.Stdout, ErrOut: os.Stderr} + readWriter = &readwriter.ReadWriter{Out: os.Stdout, In: os.Stdin, ErrOut: os.Stderr} } // rubyExec will never return. It either replaces the current process with a // Ruby interpreter, or outputs an error and kills the process. func execRuby() { cmd := &fallback.Command{} - if err := cmd.Execute(reporter); err != nil { - fmt.Fprintf(reporter.ErrOut, "Failed to exec: %v\n", err) + if err := cmd.Execute(readWriter); err != nil { + fmt.Fprintf(readWriter.ErrOut, "Failed to exec: %v\n", err) os.Exit(1) } } @@ -38,7 +38,7 @@ func main() { // warning as this isn't something we can sustain indefinitely config, err := config.NewFromDir(rootDir) if err != nil { - fmt.Fprintln(reporter.ErrOut, "Failed to read config, falling back to gitlab-shell-ruby") + fmt.Fprintln(readWriter.ErrOut, "Failed to read config, falling back to gitlab-shell-ruby") execRuby() } @@ -46,14 +46,14 @@ func main() { if err != nil { // For now this could happen if `SSH_CONNECTION` is not set on // the environment - fmt.Fprintf(reporter.ErrOut, "%v\n", err) + fmt.Fprintf(readWriter.ErrOut, "%v\n", err) os.Exit(1) } // The command will write to STDOUT on execution or replace the current // process in case of the `fallback.Command` - if err = cmd.Execute(reporter); err != nil { - fmt.Fprintf(reporter.ErrOut, "%v\n", err) + if err = cmd.Execute(readWriter); err != nil { + fmt.Fprintf(readWriter.ErrOut, "%v\n", err) os.Exit(1) } } diff --git a/go/internal/command/command.go b/go/internal/command/command.go index d4649de..b3bdcba 100644 --- a/go/internal/command/command.go +++ b/go/internal/command/command.go @@ -4,12 +4,13 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback" - "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/twofactorrecover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" ) type Command interface { - Execute(*reporting.Reporter) error + Execute(*readwriter.ReadWriter) error } func New(arguments []string, config *config.Config) (Command, error) { @@ -30,6 +31,8 @@ func buildCommand(args *commandargs.CommandArgs, config *config.Config) Command switch args.CommandType { case commandargs.Discover: return &discover.Command{Config: config, Args: args} + case commandargs.TwoFactorRecover: + return &twofactorrecover.Command{Config: config, Args: args} } return nil diff --git a/go/internal/command/command_test.go b/go/internal/command/command_test.go index 02fc0d0..42c5112 100644 --- a/go/internal/command/command_test.go +++ b/go/internal/command/command_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/assert" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/twofactorrecover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper" ) @@ -44,6 +45,19 @@ func TestNew(t *testing.T) { }, expectedType: &fallback.Command{}, }, + { + desc: "it returns a TwoFactorRecover command if the feature is enabled", + arguments: []string{}, + config: &config.Config{ + GitlabUrl: "http+unix://gitlab.socket", + Migration: config.MigrationConfig{Enabled: true, Features: []string{"2fa_recovery_codes"}}, + }, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "2fa_recovery_codes", + }, + expectedType: &twofactorrecover.Command{}, + }, } for _, tc := range testCases { diff --git a/go/internal/command/commandargs/command_args.go b/go/internal/command/commandargs/command_args.go index 9e679d3..e801889 100644 --- a/go/internal/command/commandargs/command_args.go +++ b/go/internal/command/commandargs/command_args.go @@ -9,7 +9,8 @@ import ( type CommandType string const ( - Discover CommandType = "discover" + Discover CommandType = "discover" + TwoFactorRecover CommandType = "2fa_recovery_codes" ) var ( @@ -79,4 +80,8 @@ func (c *CommandArgs) parseCommand(commandString string) { if commandString == "" { c.CommandType = Discover } + + if CommandType(commandString) == TwoFactorRecover { + c.CommandType = TwoFactorRecover + } } diff --git a/go/internal/command/discover/discover.go b/go/internal/command/discover/discover.go index 8ad2868..9bb442f 100644 --- a/go/internal/command/discover/discover.go +++ b/go/internal/command/discover/discover.go @@ -4,7 +4,7 @@ import ( "fmt" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" - "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover" ) @@ -14,16 +14,16 @@ type Command struct { Args *commandargs.CommandArgs } -func (c *Command) Execute(reporter *reporting.Reporter) error { +func (c *Command) Execute(readWriter *readwriter.ReadWriter) error { response, err := c.getUserInfo() if err != nil { return fmt.Errorf("Failed to get username: %v", err) } if response.IsAnonymous() { - fmt.Fprintf(reporter.Out, "Welcome to GitLab, Anonymous!\n") + fmt.Fprintf(readWriter.Out, "Welcome to GitLab, Anonymous!\n") } else { - fmt.Fprintf(reporter.Out, "Welcome to GitLab, @%s!\n", response.Username) + fmt.Fprintf(readWriter.Out, "Welcome to GitLab, @%s!\n", response.Username) } return nil @@ -35,13 +35,5 @@ func (c *Command) getUserInfo() (*discover.Response, error) { return nil, err } - if c.Args.GitlabKeyId != "" { - return client.GetByKeyId(c.Args.GitlabKeyId) - } else if c.Args.GitlabUsername != "" { - return client.GetByUsername(c.Args.GitlabUsername) - } else { - // There was no 'who' information, this matches the ruby error - // message. - return nil, fmt.Errorf("who='' is invalid") - } + return client.GetByCommandArgs(c.Args) } diff --git a/go/internal/command/discover/discover_test.go b/go/internal/command/discover/discover_test.go index ec6f931..a57f07e 100644 --- a/go/internal/command/discover/discover_test.go +++ b/go/internal/command/discover/discover_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" - "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" ) @@ -82,7 +82,7 @@ func TestExecute(t *testing.T) { cmd := &Command{Config: testConfig, Args: tc.arguments} buffer := &bytes.Buffer{} - err := cmd.Execute(&reporting.Reporter{Out: buffer}) + err := cmd.Execute(&readwriter.ReadWriter{Out: buffer}) assert.NoError(t, err) assert.Equal(t, tc.expectedOutput, buffer.String()) @@ -122,7 +122,7 @@ func TestFailingExecute(t *testing.T) { cmd := &Command{Config: testConfig, Args: tc.arguments} buffer := &bytes.Buffer{} - err := cmd.Execute(&reporting.Reporter{Out: buffer}) + err := cmd.Execute(&readwriter.ReadWriter{Out: buffer}) assert.Empty(t, buffer.String()) assert.EqualError(t, err, tc.expectedError) diff --git a/go/internal/command/fallback/fallback.go b/go/internal/command/fallback/fallback.go index a2c73ed..6e6d526 100644 --- a/go/internal/command/fallback/fallback.go +++ b/go/internal/command/fallback/fallback.go @@ -5,7 +5,7 @@ import ( "path/filepath" "syscall" - "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" ) type Command struct{} @@ -14,7 +14,7 @@ var ( binDir = filepath.Dir(os.Args[0]) ) -func (c *Command) Execute(_ *reporting.Reporter) error { +func (c *Command) Execute(_ *readwriter.ReadWriter) error { rubyCmd := filepath.Join(binDir, "gitlab-shell-ruby") execErr := syscall.Exec(rubyCmd, os.Args, os.Environ()) return execErr diff --git a/go/internal/command/readwriter/readwriter.go b/go/internal/command/readwriter/readwriter.go new file mode 100644 index 0000000..da18d30 --- /dev/null +++ b/go/internal/command/readwriter/readwriter.go @@ -0,0 +1,9 @@ +package readwriter + +import "io" + +type ReadWriter struct { + Out io.Writer + In io.Reader + ErrOut io.Writer +} diff --git a/go/internal/command/reporting/reporter.go b/go/internal/command/reporting/reporter.go deleted file mode 100644 index 74bca59..0000000 --- a/go/internal/command/reporting/reporter.go +++ /dev/null @@ -1,8 +0,0 @@ -package reporting - -import "io" - -type Reporter struct { - Out io.Writer - ErrOut io.Writer -} diff --git a/go/internal/command/twofactorrecover/twofactorrecover.go b/go/internal/command/twofactorrecover/twofactorrecover.go new file mode 100644 index 0000000..e77a334 --- /dev/null +++ b/go/internal/command/twofactorrecover/twofactorrecover.go @@ -0,0 +1,64 @@ +package twofactorrecover + +import ( + "fmt" + "strings" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/twofactorrecover" +) + +type Command struct { + Config *config.Config + Args *commandargs.CommandArgs +} + +func (c *Command) Execute(readWriter *readwriter.ReadWriter) error { + if c.canContinue(readWriter) { + c.displayRecoveryCodes(readWriter) + } else { + fmt.Fprintln(readWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.") + } + + return nil +} + +func (c *Command) canContinue(readWriter *readwriter.ReadWriter) bool { + question := + "Are you sure you want to generate new two-factor recovery codes?\n" + + "Any existing recovery codes you saved will be invalidated. (yes/no)" + fmt.Fprintln(readWriter.Out, question) + + var answer string + fmt.Fscanln(readWriter.In, &answer) + + return answer == "yes" +} + +func (c *Command) displayRecoveryCodes(readWriter *readwriter.ReadWriter) { + codes, err := c.getRecoveryCodes() + + if err == nil { + messageWithCodes := + "\nYour two-factor authentication recovery codes are:\n\n" + + strings.Join(codes, "\n") + + "\n\nDuring sign in, use one of the codes above when prompted for\n" + + "your two-factor code. Then, visit your Profile Settings and add\n" + + "a new device so you do not lose access to your account again.\n" + fmt.Fprint(readWriter.Out, messageWithCodes) + } else { + fmt.Fprintf(readWriter.Out, "\nAn error occurred while trying to generate new recovery codes.\n%v\n", err) + } +} + +func (c *Command) getRecoveryCodes() ([]string, error) { + client, err := twofactorrecover.NewClient(c.Config) + + if err != nil { + return nil, err + } + + return client.GetRecoveryCodes(c.Args) +} diff --git a/go/internal/command/twofactorrecover/twofactorrecover_test.go b/go/internal/command/twofactorrecover/twofactorrecover_test.go new file mode 100644 index 0000000..908ee13 --- /dev/null +++ b/go/internal/command/twofactorrecover/twofactorrecover_test.go @@ -0,0 +1,135 @@ +package twofactorrecover + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/twofactorrecover" +) + +var ( + testConfig *config.Config + requests []testserver.TestRequestHandler +) + +func setup(t *testing.T) { + testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket} + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/two_factor_recovery_codes", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + + require.NoError(t, err) + + var requestBody *twofactorrecover.RequestBody + json.Unmarshal(b, &requestBody) + + switch requestBody.KeyId { + case "1": + body := map[string]interface{}{ + "success": true, + "recovery_codes": [2]string{"recovery", "codes"}, + } + json.NewEncoder(w).Encode(body) + case "forbidden": + body := map[string]interface{}{ + "success": false, + "message": "Forbidden!", + } + json.NewEncoder(w).Encode(body) + case "broken": + w.WriteHeader(http.StatusInternalServerError) + } + }, + }, + } +} + +const ( + question = "Are you sure you want to generate new two-factor recovery codes?\n" + + "Any existing recovery codes you saved will be invalidated. (yes/no)\n\n" + errorHeader = "An error occurred while trying to generate new recovery codes.\n" +) + +func TestExecute(t *testing.T) { + setup(t) + + cleanup, err := testserver.StartSocketHttpServer(requests) + require.NoError(t, err) + defer cleanup() + + testCases := []struct { + desc string + arguments *commandargs.CommandArgs + answer string + expectedOutput string + }{ + { + desc: "With a known key id", + arguments: &commandargs.CommandArgs{GitlabKeyId: "1"}, + answer: "yes\n", + expectedOutput: question + + "Your two-factor authentication recovery codes are:\n\nrecovery\ncodes\n\n" + + "During sign in, use one of the codes above when prompted for\n" + + "your two-factor code. Then, visit your Profile Settings and add\n" + + "a new device so you do not lose access to your account again.\n", + }, + { + desc: "With bad response", + arguments: &commandargs.CommandArgs{GitlabKeyId: "-1"}, + answer: "yes\n", + expectedOutput: question + errorHeader + "Parsing failed\n", + }, + { + desc: "With API returns an error", + arguments: &commandargs.CommandArgs{GitlabKeyId: "forbidden"}, + answer: "yes\n", + expectedOutput: question + errorHeader + "Forbidden!\n", + }, + { + desc: "With API fails", + arguments: &commandargs.CommandArgs{GitlabKeyId: "broken"}, + answer: "yes\n", + expectedOutput: question + errorHeader + "Internal API error (500)\n", + }, + { + desc: "With missing arguments", + arguments: &commandargs.CommandArgs{}, + answer: "yes\n", + expectedOutput: question + errorHeader + "who='' is invalid\n", + }, + { + desc: "With negative answer", + arguments: &commandargs.CommandArgs{}, + answer: "no\n", + expectedOutput: question + + "New recovery codes have *not* been generated. Existing codes will remain valid.\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + output := &bytes.Buffer{} + input := bytes.NewBufferString(tc.answer) + + cmd := &Command{Config: testConfig, Args: tc.arguments} + + err := cmd.Execute(&readwriter.ReadWriter{Out: output, In: input}) + + assert.NoError(t, err) + assert.Equal(t, tc.expectedOutput, output.String()) + }) + } +} diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go index abc218f..c2453e5 100644 --- a/go/internal/gitlabnet/client.go +++ b/go/internal/gitlabnet/client.go @@ -17,8 +17,7 @@ const ( type GitlabClient interface { Get(path string) (*http.Response, error) - // TODO: implement posts - // Post(path string) (http.Response, error) + Post(path string, data interface{}) (*http.Response, error) } type ErrorResponse struct { diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go index f69f284..c1d08a1 100644 --- a/go/internal/gitlabnet/client_test.go +++ b/go/internal/gitlabnet/client_test.go @@ -19,9 +19,24 @@ func TestClients(t *testing.T) { { Path: "/api/v4/internal/hello", Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + fmt.Fprint(w, "Hello") }, }, + { + Path: "/api/v4/internal/post_endpoint", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + b, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + + require.NoError(t, err) + + fmt.Fprint(w, "Echo: "+string(b)) + }, + }, { Path: "/api/v4/internal/auth", Handler: func(w http.ResponseWriter, r *http.Request) { @@ -68,6 +83,7 @@ func TestClients(t *testing.T) { testBrokenRequest(t, tc.client) testSuccessfulGet(t, tc.client) + testSuccessfulPost(t, tc.client) testMissing(t, tc.client) testErrorMessage(t, tc.client) testAuthenticationHeader(t, tc.client) @@ -89,32 +105,66 @@ func testSuccessfulGet(t *testing.T, client GitlabClient) { }) } +func testSuccessfulPost(t *testing.T, client GitlabClient) { + t.Run("Successful Post", func(t *testing.T) { + data := map[string]string{"key": "value"} + + response, err := client.Post("/post_endpoint", data) + defer response.Body.Close() + + require.NoError(t, err) + require.NotNil(t, response) + + responseBody, err := ioutil.ReadAll(response.Body) + assert.NoError(t, err) + assert.Equal(t, "Echo: {\"key\":\"value\"}", string(responseBody)) + }) +} + func testMissing(t *testing.T, client GitlabClient) { - t.Run("Missing error", func(t *testing.T) { + t.Run("Missing error for GET", func(t *testing.T) { response, err := client.Get("/missing") assert.EqualError(t, err, "Internal API error (404)") assert.Nil(t, response) }) + + t.Run("Missing error for POST", func(t *testing.T) { + response, err := client.Post("/missing", map[string]string{}) + assert.EqualError(t, err, "Internal API error (404)") + assert.Nil(t, response) + }) } func testErrorMessage(t *testing.T, client GitlabClient) { - t.Run("Error with message", func(t *testing.T) { + t.Run("Error with message for GET", func(t *testing.T) { response, err := client.Get("/error") assert.EqualError(t, err, "Don't do that") assert.Nil(t, response) }) + + t.Run("Error with message for POST", func(t *testing.T) { + response, err := client.Post("/error", map[string]string{}) + assert.EqualError(t, err, "Don't do that") + assert.Nil(t, response) + }) } func testBrokenRequest(t *testing.T, client GitlabClient) { - t.Run("Broken request", func(t *testing.T) { + t.Run("Broken request for GET", func(t *testing.T) { response, err := client.Get("/broken") assert.EqualError(t, err, "Internal API unreachable") assert.Nil(t, response) }) + + t.Run("Broken request for POST", func(t *testing.T) { + response, err := client.Post("/broken", map[string]string{}) + assert.EqualError(t, err, "Internal API unreachable") + assert.Nil(t, response) + }) } func testAuthenticationHeader(t *testing.T, client GitlabClient) { - t.Run("Authentication headers", func(t *testing.T) { + t.Run("Authentication headers for GET", func(t *testing.T) { response, err := client.Get("/auth") defer response.Body.Close() @@ -128,4 +178,19 @@ func testAuthenticationHeader(t *testing.T, client GitlabClient) { require.NoError(t, err) assert.Equal(t, "sssh, it's a secret", string(header)) }) + + t.Run("Authentication headers for POST", func(t *testing.T) { + response, err := client.Post("/auth", map[string]string{}) + defer response.Body.Close() + + require.NoError(t, err) + require.NotNil(t, response) + + responseBody, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + header, err := base64.StdEncoding.DecodeString(string(responseBody)) + require.NoError(t, err) + assert.Equal(t, "sssh, it's a secret", string(header)) + }) } diff --git a/go/internal/gitlabnet/discover/client.go b/go/internal/gitlabnet/discover/client.go index 8df78fb..e84b1b4 100644 --- a/go/internal/gitlabnet/discover/client.go +++ b/go/internal/gitlabnet/discover/client.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" ) @@ -30,6 +31,18 @@ func NewClient(config *config.Config) (*Client, error) { return &Client{config: config, client: client}, nil } +func (c *Client) GetByCommandArgs(args *commandargs.CommandArgs) (*Response, error) { + if args.GitlabKeyId != "" { + return c.GetByKeyId(args.GitlabKeyId) + } else if args.GitlabUsername != "" { + return c.GetByUsername(args.GitlabUsername) + } else { + // There was no 'who' information, this matches the ruby error + // message. + return nil, fmt.Errorf("who='' is invalid") + } +} + func (c *Client) GetByKeyId(keyId string) (*Response, error) { params := url.Values{} params.Add("key_id", keyId) diff --git a/go/internal/gitlabnet/socketclient.go b/go/internal/gitlabnet/socketclient.go index 3bd7c70..fd97535 100644 --- a/go/internal/gitlabnet/socketclient.go +++ b/go/internal/gitlabnet/socketclient.go @@ -1,7 +1,9 @@ package gitlabnet import ( + "bytes" "context" + "encoding/json" "net" "net/http" "strings" @@ -44,3 +46,21 @@ func (c *GitlabSocketClient) Get(path string) (*http.Response, error) { return doRequest(c.httpClient, c.config, request) } + +func (c *GitlabSocketClient) Post(path string, data interface{}) (*http.Response, error) { + path = normalizePath(path) + + jsonData, err := json.Marshal(data) + if err != nil { + return nil, err + } + + request, err := http.NewRequest("POST", socketBaseUrl+path, bytes.NewReader(jsonData)) + request.Header.Add("Content-Type", "application/json") + + if err != nil { + return nil, err + } + + return doRequest(c.httpClient, c.config, request) +} diff --git a/go/internal/gitlabnet/twofactorrecover/client.go b/go/internal/gitlabnet/twofactorrecover/client.go new file mode 100644 index 0000000..2e47c64 --- /dev/null +++ b/go/internal/gitlabnet/twofactorrecover/client.go @@ -0,0 +1,104 @@ +package twofactorrecover + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover" +) + +type Client struct { + config *config.Config + client gitlabnet.GitlabClient +} + +type Response struct { + Success bool `json:"success"` + RecoveryCodes []string `json:"recovery_codes"` + Message string `json:"message"` +} + +type RequestBody struct { + KeyId string `json:"key_id,omitempty"` + UserId int64 `json:"user_id,omitempty"` +} + +func NewClient(config *config.Config) (*Client, error) { + client, err := gitlabnet.GetClient(config) + if err != nil { + return nil, fmt.Errorf("Error creating http client: %v", err) + } + + return &Client{config: config, client: client}, nil +} + +func (c *Client) GetRecoveryCodes(args *commandargs.CommandArgs) ([]string, error) { + requestBody, err := c.getRequestBody(args) + + if err != nil { + return nil, err + } + + response, err := c.client.Post("/two_factor_recovery_codes", requestBody) + + if err != nil { + return nil, err + } + + defer response.Body.Close() + parsedResponse, err := c.parseResponse(response) + + if err != nil { + return nil, fmt.Errorf("Parsing failed") + } + + if parsedResponse.Success { + return parsedResponse.RecoveryCodes, nil + } else { + return nil, errors.New(parsedResponse.Message) + } +} + +func (c *Client) parseResponse(resp *http.Response) (*Response, error) { + parsedResponse := &Response{} + body, err := ioutil.ReadAll(resp.Body) + + if err != nil { + return nil, err + } + + if err := json.Unmarshal(body, parsedResponse); err != nil { + return nil, err + } else { + return parsedResponse, nil + } +} + +func (c *Client) getRequestBody(args *commandargs.CommandArgs) (*RequestBody, error) { + client, err := discover.NewClient(c.config) + + if err != nil { + return nil, err + } + + var requestBody *RequestBody + if args.GitlabKeyId != "" { + requestBody = &RequestBody{KeyId: args.GitlabKeyId} + } else { + userInfo, err := client.GetByCommandArgs(args) + + if err != nil { + return nil, err + } + + requestBody = &RequestBody{UserId: userInfo.UserId} + } + + return requestBody, nil +} diff --git a/go/internal/gitlabnet/twofactorrecover/client_test.go b/go/internal/gitlabnet/twofactorrecover/client_test.go new file mode 100644 index 0000000..5cbc011 --- /dev/null +++ b/go/internal/gitlabnet/twofactorrecover/client_test.go @@ -0,0 +1,161 @@ +package twofactorrecover + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" +) + +var ( + testConfig *config.Config + requests []testserver.TestRequestHandler +) + +func initialize(t *testing.T) { + testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket} + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/two_factor_recovery_codes", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + + require.NoError(t, err) + + var requestBody *RequestBody + json.Unmarshal(b, &requestBody) + + switch requestBody.KeyId { + case "0": + body := map[string]interface{}{ + "success": true, + "recovery_codes": [2]string{"recovery 1", "codes 1"}, + } + json.NewEncoder(w).Encode(body) + case "1": + body := map[string]interface{}{ + "success": false, + "message": "missing user", + } + json.NewEncoder(w).Encode(body) + case "2": + w.WriteHeader(http.StatusForbidden) + body := &gitlabnet.ErrorResponse{ + Message: "Not allowed!", + } + json.NewEncoder(w).Encode(body) + case "3": + w.Write([]byte("{ \"message\": \"broken json!\"")) + case "4": + w.WriteHeader(http.StatusForbidden) + } + + if requestBody.UserId == 1 { + body := map[string]interface{}{ + "success": true, + "recovery_codes": [2]string{"recovery 2", "codes 2"}, + } + json.NewEncoder(w).Encode(body) + } + }, + }, + { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + body := &discover.Response{ + UserId: 1, + Username: "jane-doe", + Name: "Jane Doe", + } + json.NewEncoder(w).Encode(body) + }, + }, + } +} + +func TestGetRecoveryCodesByKeyId(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + args := &commandargs.CommandArgs{GitlabKeyId: "0"} + result, err := client.GetRecoveryCodes(args) + assert.NoError(t, err) + assert.Equal(t, []string{"recovery 1", "codes 1"}, result) +} + +func TestGetRecoveryCodesByUsername(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + args := &commandargs.CommandArgs{GitlabUsername: "jane-doe"} + result, err := client.GetRecoveryCodes(args) + assert.NoError(t, err) + assert.Equal(t, []string{"recovery 2", "codes 2"}, result) +} + +func TestMissingUser(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + args := &commandargs.CommandArgs{GitlabKeyId: "1"} + _, err := client.GetRecoveryCodes(args) + assert.Equal(t, "missing user", err.Error()) +} + +func TestErrorResponses(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + testCases := []struct { + desc string + fakeId string + expectedError string + }{ + { + desc: "A response with an error message", + fakeId: "2", + expectedError: "Not allowed!", + }, + { + desc: "A response with bad JSON", + fakeId: "3", + expectedError: "Parsing failed", + }, + { + desc: "An error response without message", + fakeId: "4", + expectedError: "Internal API error (403)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + args := &commandargs.CommandArgs{GitlabKeyId: tc.fakeId} + resp, err := client.GetRecoveryCodes(args) + + assert.EqualError(t, err, tc.expectedError) + assert.Nil(t, resp) + }) + } +} + +func setup(t *testing.T) (*Client, func()) { + initialize(t) + cleanup, err := testserver.StartSocketHttpServer(requests) + require.NoError(t, err) + + client, err := NewClient(testConfig) + require.NoError(t, err) + + return client, cleanup +} diff --git a/spec/gitlab_shell_authorized_keys_check_spec.rb b/spec/gitlab_shell_authorized_keys_check_spec.rb index baaa560..7050604 100644 --- a/spec/gitlab_shell_authorized_keys_check_spec.rb +++ b/spec/gitlab_shell_authorized_keys_check_spec.rb @@ -1,20 +1,7 @@ require_relative 'spec_helper' describe 'bin/gitlab-shell-authorized-keys-check' do - def original_root_path - ROOT_PATH - end - - # All this test boilerplate is mostly copy/pasted between - # gitlab_shell_gitlab_shell_spec.rb and - # gitlab_shell_authorized_keys_check_spec.rb - def tmp_root_path - @tmp_root_path ||= File.realpath(Dir.mktmpdir) - end - - def config_path - File.join(tmp_root_path, 'config.yml') - end + include_context 'gitlab shell' def tmp_socket_path # This has to be a relative path shorter than 100 bytes due to @@ -22,12 +9,8 @@ describe 'bin/gitlab-shell-authorized-keys-check' do 'tmp/gitlab-shell-authorized-keys-check-socket' end - before(:all) do - FileUtils.mkdir_p(File.dirname(tmp_socket_path)) - FileUtils.touch(File.join(tmp_root_path, '.gitlab_shell_secret')) - - @server = HTTPUNIXServer.new(BindAddress: tmp_socket_path) - @server.mount_proc('/api/v4/internal/authorized_keys') do |req, res| + def mock_server(server) + server.mount_proc('/api/v4/internal/authorized_keys') do |req, res| if req.query['key'] == 'known-rsa-key' res.status = 200 res.content_type = 'application/json' @@ -36,28 +19,14 @@ describe 'bin/gitlab-shell-authorized-keys-check' do res.status = 404 end end - - @webrick_thread = Thread.new { @server.start } - - sleep(0.1) while @webrick_thread.alive? && @server.status != :Running - raise "Couldn't start stub GitlabNet server" unless @server.status == :Running - - File.open(config_path, 'w') do |f| - f.write("---\ngitlab_url: http+unix://#{CGI.escape(tmp_socket_path)}\n") - end - - copy_dirs = ['bin', 'lib'] - FileUtils.rm_rf(copy_dirs.map { |d| File.join(tmp_root_path, d) }) - FileUtils.cp_r(copy_dirs, tmp_root_path) end - after(:all) do - @server.shutdown if @server - @webrick_thread.join if @webrick_thread - FileUtils.rm_rf(tmp_root_path) + before(:all) do + write_config( + "gitlab_url" => "http+unix://#{CGI.escape(tmp_socket_path)}", + ) end - let(:gitlab_shell_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell') } let(:authorized_keys_check_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell-authorized-keys-check') } it 'succeeds when a valid key is given' do diff --git a/spec/gitlab_shell_gitlab_shell_spec.rb b/spec/gitlab_shell_gitlab_shell_spec.rb index cb3fd9c..6d6e172 100644 --- a/spec/gitlab_shell_gitlab_shell_spec.rb +++ b/spec/gitlab_shell_gitlab_shell_spec.rb @@ -3,33 +3,10 @@ require_relative 'spec_helper' require 'open3' describe 'bin/gitlab-shell' do - def original_root_path - ROOT_PATH - end - - # All this test boilerplate is mostly copy/pasted between - # gitlab_shell_gitlab_shell_spec.rb and - # gitlab_shell_authorized_keys_check_spec.rb - def tmp_root_path - @tmp_root_path ||= File.realpath(Dir.mktmpdir) - end - - def config_path - File.join(tmp_root_path, 'config.yml') - end - - def tmp_socket_path - # This has to be a relative path shorter than 100 bytes due to - # limitations in how Unix sockets work. - 'tmp/gitlab-shell-socket' - end - - before(:all) do - FileUtils.mkdir_p(File.dirname(tmp_socket_path)) - FileUtils.touch(File.join(tmp_root_path, '.gitlab_shell_secret')) + include_context 'gitlab shell' - @server = HTTPUNIXServer.new(BindAddress: tmp_socket_path) - @server.mount_proc('/api/v4/internal/discover') do |req, res| + def mock_server(server) + server.mount_proc('/api/v4/internal/discover') do |req, res| identifier = req.query['key_id'] || req.query['username'] || req.query['user_id'] known_identifiers = %w(10 someuser 100) if known_identifiers.include?(identifier) @@ -47,24 +24,16 @@ describe 'bin/gitlab-shell' do res.status = 500 end end - - @webrick_thread = Thread.new { @server.start } - - sleep(0.1) while @webrick_thread.alive? && @server.status != :Running - raise "Couldn't start stub GitlabNet server" unless @server.status == :Running - system(original_root_path, 'bin/compile') - copy_dirs = ['bin', 'lib'] - FileUtils.rm_rf(copy_dirs.map { |d| File.join(tmp_root_path, d) }) - FileUtils.cp_r(copy_dirs, tmp_root_path) end - after(:all) do - @server.shutdown if @server - @webrick_thread.join if @webrick_thread - FileUtils.rm_rf(tmp_root_path) - end + def run!(args, env: {'SSH_CONNECTION' => 'fake'}) + cmd = [ + gitlab_shell_path, + args + ].flatten.compact.join(' ') - let(:gitlab_shell_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell') } + Open3.capture3(env, cmd) + end shared_examples 'results with keys' do # Basic valid input @@ -175,19 +144,4 @@ describe 'bin/gitlab-shell' do expect(status).not_to be_success end end - - def run!(args, env: {'SSH_CONNECTION' => 'fake'}) - cmd = [ - gitlab_shell_path, - args - ].flatten.compact.join(' ') - - Open3.capture3(env, cmd) - end - - def write_config(config) - File.open(config_path, 'w') do |f| - f.write(config.to_yaml) - end - end end diff --git a/spec/gitlab_shell_two_factor_recovery_spec.rb b/spec/gitlab_shell_two_factor_recovery_spec.rb new file mode 100644 index 0000000..19999e5 --- /dev/null +++ b/spec/gitlab_shell_two_factor_recovery_spec.rb @@ -0,0 +1,128 @@ +require_relative 'spec_helper' + +require 'open3' + +describe 'bin/gitlab-shell 2fa_recovery_codes' do + include_context 'gitlab shell' + + def mock_server(server) + server.mount_proc('/api/v4/internal/two_factor_recovery_codes') do |req, res| + res.content_type = 'application/json' + res.status = 200 + + key_id = req.query['key_id'] || req.query['user_id'] + + unless key_id + body = JSON.parse(req.body) + key_id = body['key_id'] || body['user_id'].to_s + end + + if key_id == '100' + res.body = '{"success":true, "recovery_codes": ["1", "2"]}' + else + res.body = '{"success":false, "message": "Forbidden!"}' + end + end + + server.mount_proc('/api/v4/internal/discover') do |req, res| + res.status = 200 + res.content_type = 'application/json' + res.body = '{"id":100, "name": "Some User", "username": "someuser"}' + end + end + + shared_examples 'dialog for regenerating recovery keys' do + context 'when the user agrees to regenerate keys' do + def verify_successful_regeneration!(cmd) + Open3.popen2(env, cmd) do |stdin, stdout| + expect(stdout.gets).to eq("Are you sure you want to generate new two-factor recovery codes?\n") + expect(stdout.gets).to eq("Any existing recovery codes you saved will be invalidated. (yes/no)\n") + + stdin.puts('yes') + + expect(stdout.flush.read).to eq( + "\nYour two-factor authentication recovery codes are:\n\n" \ + "1\n2\n\n" \ + "During sign in, use one of the codes above when prompted for\n" \ + "your two-factor code. Then, visit your Profile Settings and add\n" \ + "a new device so you do not lose access to your account again.\n" + ) + end + end + + context 'when key is provided' do + let(:cmd) { "#{gitlab_shell_path} key-100" } + + it 'the recovery keys are regenerated' do + verify_successful_regeneration!(cmd) + end + end + + context 'when username is provided' do + let(:cmd) { "#{gitlab_shell_path} username-someone" } + + it 'the recovery keys are regenerated' do + verify_successful_regeneration!(cmd) + end + end + end + + context 'when the user disagrees to regenerate keys' do + let(:cmd) { "#{gitlab_shell_path} key-100" } + + it 'the recovery keys are not regenerated' do + Open3.popen2(env, cmd) do |stdin, stdout| + expect(stdout.gets).to eq("Are you sure you want to generate new two-factor recovery codes?\n") + expect(stdout.gets).to eq("Any existing recovery codes you saved will be invalidated. (yes/no)\n") + + stdin.puts('no') + + expect(stdout.flush.read).to eq( + "\nNew recovery codes have *not* been generated. Existing codes will remain valid.\n" + ) + end + end + end + + context 'when API error occurs' do + let(:cmd) { "#{gitlab_shell_path} key-101" } + + context 'when the user agrees to regenerate keys' do + it 'the recovery keys are regenerated' do + Open3.popen2(env, cmd) do |stdin, stdout| + expect(stdout.gets).to eq("Are you sure you want to generate new two-factor recovery codes?\n") + expect(stdout.gets).to eq("Any existing recovery codes you saved will be invalidated. (yes/no)\n") + + stdin.puts('yes') + + expect(stdout.flush.read).to eq("\nAn error occurred while trying to generate new recovery codes.\nForbidden!\n") + end + end + end + end + end + + let(:env) { {'SSH_CONNECTION' => 'fake', 'SSH_ORIGINAL_COMMAND' => '2fa_recovery_codes' } } + + describe 'without go features' do + before(:context) do + write_config( + "gitlab_url" => "http+unix://#{CGI.escape(tmp_socket_path)}", + ) + end + + it_behaves_like 'dialog for regenerating recovery keys' + end + + describe 'with go features' do + before(:context) do + write_config( + "gitlab_url" => "http+unix://#{CGI.escape(tmp_socket_path)}", + "migration" => { "enabled" => true, + "features" => ["2fa_recovery_codes"] } + ) + end + + it_behaves_like 'dialog for regenerating recovery keys' + end +end diff --git a/spec/support/gitlab_shell_setup.rb b/spec/support/gitlab_shell_setup.rb new file mode 100644 index 0000000..eddd2d1 --- /dev/null +++ b/spec/support/gitlab_shell_setup.rb @@ -0,0 +1,58 @@ +RSpec.shared_context 'gitlab shell', shared_context: :metadata do + def original_root_path + ROOT_PATH + end + + def config_path + File.join(tmp_root_path, 'config.yml') + end + + def write_config(config) + File.open(config_path, 'w') do |f| + f.write(config.to_yaml) + end + end + + def tmp_root_path + @tmp_root_path ||= File.realpath(Dir.mktmpdir) + end + + def mock_server(server) + raise NotImplementedError.new( + 'mock_server method must be implemented in order to include gitlab shell context' + ) + end + + # This has to be a relative path shorter than 100 bytes due to + # limitations in how Unix sockets work. + def tmp_socket_path + 'tmp/gitlab-shell-socket' + end + + let(:gitlab_shell_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell') } + + before(:all) do + FileUtils.mkdir_p(File.dirname(tmp_socket_path)) + FileUtils.touch(File.join(tmp_root_path, '.gitlab_shell_secret')) + + @server = HTTPUNIXServer.new(BindAddress: tmp_socket_path) + + mock_server(@server) + + @webrick_thread = Thread.new { @server.start } + + sleep(0.1) while @webrick_thread.alive? && @server.status != :Running + raise "Couldn't start stub GitlabNet server" unless @server.status == :Running + system(original_root_path, 'bin/compile') + + copy_dirs = ['bin', 'lib'] + FileUtils.rm_rf(copy_dirs.map { |d| File.join(tmp_root_path, d) }) + FileUtils.cp_r(copy_dirs, tmp_root_path) + end + + after(:all) do + @server.shutdown if @server + @webrick_thread.join if @webrick_thread + FileUtils.rm_rf(tmp_root_path) + end +end -- cgit v1.2.1