summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Thomas <nick@gitlab.com>2019-03-21 11:53:09 +0000
committerNick Thomas <nick@gitlab.com>2019-03-21 11:53:09 +0000
commitc0e60a07b4fa169132942e0e6389decfea214041 (patch)
treea3fdc408786fd0342bd3eb28ad841e70d3d7ac6e
parent81bed658f083a165e65b16f7ef86c18938349e33 (diff)
parent98dbdfb758703428626d54b2a257565a44509a55 (diff)
downloadgitlab-shell-c0e60a07b4fa169132942e0e6389decfea214041.tar.gz
Merge branch 'id-go-recovery-codes' into 'master'
Provide go implementation for 2fa_recovery_codes command See merge request gitlab-org/gitlab-shell!285
-rw-r--r--go/cmd/gitlab-shell/main.go22
-rw-r--r--go/internal/command/command.go7
-rw-r--r--go/internal/command/command_test.go14
-rw-r--r--go/internal/command/commandargs/command_args.go7
-rw-r--r--go/internal/command/discover/discover.go18
-rw-r--r--go/internal/command/discover/discover_test.go6
-rw-r--r--go/internal/command/fallback/fallback.go4
-rw-r--r--go/internal/command/readwriter/readwriter.go9
-rw-r--r--go/internal/command/reporting/reporter.go8
-rw-r--r--go/internal/command/twofactorrecover/twofactorrecover.go64
-rw-r--r--go/internal/command/twofactorrecover/twofactorrecover_test.go135
-rw-r--r--go/internal/gitlabnet/client.go3
-rw-r--r--go/internal/gitlabnet/client_test.go73
-rw-r--r--go/internal/gitlabnet/discover/client.go13
-rw-r--r--go/internal/gitlabnet/socketclient.go20
-rw-r--r--go/internal/gitlabnet/twofactorrecover/client.go104
-rw-r--r--go/internal/gitlabnet/twofactorrecover/client_test.go161
-rw-r--r--spec/gitlab_shell_authorized_keys_check_spec.rb45
-rw-r--r--spec/gitlab_shell_gitlab_shell_spec.rb66
-rw-r--r--spec/gitlab_shell_two_factor_recovery_spec.rb128
-rw-r--r--spec/support/gitlab_shell_setup.rb58
21 files changed, 825 insertions, 140 deletions
diff --git a/go/cmd/gitlab-shell/main.go b/go/cmd/gitlab-shell/main.go
index 2ed319d..51b5210 100644
--- a/go/cmd/gitlab-shell/main.go
+++ b/go/cmd/gitlab-shell/main.go
@@ -7,28 +7,28 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback"
- "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
var (
- binDir string
- rootDir string
- reporter *reporting.Reporter
+ binDir string
+ rootDir string
+ readWriter *readwriter.ReadWriter
)
func init() {
binDir = filepath.Dir(os.Args[0])
rootDir = filepath.Dir(binDir)
- reporter = &reporting.Reporter{Out: os.Stdout, ErrOut: os.Stderr}
+ readWriter = &readwriter.ReadWriter{Out: os.Stdout, In: os.Stdin, ErrOut: os.Stderr}
}
// rubyExec will never return. It either replaces the current process with a
// Ruby interpreter, or outputs an error and kills the process.
func execRuby() {
cmd := &fallback.Command{}
- if err := cmd.Execute(reporter); err != nil {
- fmt.Fprintf(reporter.ErrOut, "Failed to exec: %v\n", err)
+ if err := cmd.Execute(readWriter); err != nil {
+ fmt.Fprintf(readWriter.ErrOut, "Failed to exec: %v\n", err)
os.Exit(1)
}
}
@@ -38,7 +38,7 @@ func main() {
// warning as this isn't something we can sustain indefinitely
config, err := config.NewFromDir(rootDir)
if err != nil {
- fmt.Fprintln(reporter.ErrOut, "Failed to read config, falling back to gitlab-shell-ruby")
+ fmt.Fprintln(readWriter.ErrOut, "Failed to read config, falling back to gitlab-shell-ruby")
execRuby()
}
@@ -46,14 +46,14 @@ func main() {
if err != nil {
// For now this could happen if `SSH_CONNECTION` is not set on
// the environment
- fmt.Fprintf(reporter.ErrOut, "%v\n", err)
+ fmt.Fprintf(readWriter.ErrOut, "%v\n", err)
os.Exit(1)
}
// The command will write to STDOUT on execution or replace the current
// process in case of the `fallback.Command`
- if err = cmd.Execute(reporter); err != nil {
- fmt.Fprintf(reporter.ErrOut, "%v\n", err)
+ if err = cmd.Execute(readWriter); err != nil {
+ fmt.Fprintf(readWriter.ErrOut, "%v\n", err)
os.Exit(1)
}
}
diff --git a/go/internal/command/command.go b/go/internal/command/command.go
index d4649de..b3bdcba 100644
--- a/go/internal/command/command.go
+++ b/go/internal/command/command.go
@@ -4,12 +4,13 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback"
- "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/twofactorrecover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
type Command interface {
- Execute(*reporting.Reporter) error
+ Execute(*readwriter.ReadWriter) error
}
func New(arguments []string, config *config.Config) (Command, error) {
@@ -30,6 +31,8 @@ func buildCommand(args *commandargs.CommandArgs, config *config.Config) Command
switch args.CommandType {
case commandargs.Discover:
return &discover.Command{Config: config, Args: args}
+ case commandargs.TwoFactorRecover:
+ return &twofactorrecover.Command{Config: config, Args: args}
}
return nil
diff --git a/go/internal/command/command_test.go b/go/internal/command/command_test.go
index 02fc0d0..42c5112 100644
--- a/go/internal/command/command_test.go
+++ b/go/internal/command/command_test.go
@@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/twofactorrecover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper"
)
@@ -44,6 +45,19 @@ func TestNew(t *testing.T) {
},
expectedType: &fallback.Command{},
},
+ {
+ desc: "it returns a TwoFactorRecover command if the feature is enabled",
+ arguments: []string{},
+ config: &config.Config{
+ GitlabUrl: "http+unix://gitlab.socket",
+ Migration: config.MigrationConfig{Enabled: true, Features: []string{"2fa_recovery_codes"}},
+ },
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "2fa_recovery_codes",
+ },
+ expectedType: &twofactorrecover.Command{},
+ },
}
for _, tc := range testCases {
diff --git a/go/internal/command/commandargs/command_args.go b/go/internal/command/commandargs/command_args.go
index 9e679d3..e801889 100644
--- a/go/internal/command/commandargs/command_args.go
+++ b/go/internal/command/commandargs/command_args.go
@@ -9,7 +9,8 @@ import (
type CommandType string
const (
- Discover CommandType = "discover"
+ Discover CommandType = "discover"
+ TwoFactorRecover CommandType = "2fa_recovery_codes"
)
var (
@@ -79,4 +80,8 @@ func (c *CommandArgs) parseCommand(commandString string) {
if commandString == "" {
c.CommandType = Discover
}
+
+ if CommandType(commandString) == TwoFactorRecover {
+ c.CommandType = TwoFactorRecover
+ }
}
diff --git a/go/internal/command/discover/discover.go b/go/internal/command/discover/discover.go
index 8ad2868..9bb442f 100644
--- a/go/internal/command/discover/discover.go
+++ b/go/internal/command/discover/discover.go
@@ -4,7 +4,7 @@ import (
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
- "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
)
@@ -14,16 +14,16 @@ type Command struct {
Args *commandargs.CommandArgs
}
-func (c *Command) Execute(reporter *reporting.Reporter) error {
+func (c *Command) Execute(readWriter *readwriter.ReadWriter) error {
response, err := c.getUserInfo()
if err != nil {
return fmt.Errorf("Failed to get username: %v", err)
}
if response.IsAnonymous() {
- fmt.Fprintf(reporter.Out, "Welcome to GitLab, Anonymous!\n")
+ fmt.Fprintf(readWriter.Out, "Welcome to GitLab, Anonymous!\n")
} else {
- fmt.Fprintf(reporter.Out, "Welcome to GitLab, @%s!\n", response.Username)
+ fmt.Fprintf(readWriter.Out, "Welcome to GitLab, @%s!\n", response.Username)
}
return nil
@@ -35,13 +35,5 @@ func (c *Command) getUserInfo() (*discover.Response, error) {
return nil, err
}
- if c.Args.GitlabKeyId != "" {
- return client.GetByKeyId(c.Args.GitlabKeyId)
- } else if c.Args.GitlabUsername != "" {
- return client.GetByUsername(c.Args.GitlabUsername)
- } else {
- // There was no 'who' information, this matches the ruby error
- // message.
- return nil, fmt.Errorf("who='' is invalid")
- }
+ return client.GetByCommandArgs(c.Args)
}
diff --git a/go/internal/command/discover/discover_test.go b/go/internal/command/discover/discover_test.go
index ec6f931..a57f07e 100644
--- a/go/internal/command/discover/discover_test.go
+++ b/go/internal/command/discover/discover_test.go
@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
- "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
)
@@ -82,7 +82,7 @@ func TestExecute(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments}
buffer := &bytes.Buffer{}
- err := cmd.Execute(&reporting.Reporter{Out: buffer})
+ err := cmd.Execute(&readwriter.ReadWriter{Out: buffer})
assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, buffer.String())
@@ -122,7 +122,7 @@ func TestFailingExecute(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments}
buffer := &bytes.Buffer{}
- err := cmd.Execute(&reporting.Reporter{Out: buffer})
+ err := cmd.Execute(&readwriter.ReadWriter{Out: buffer})
assert.Empty(t, buffer.String())
assert.EqualError(t, err, tc.expectedError)
diff --git a/go/internal/command/fallback/fallback.go b/go/internal/command/fallback/fallback.go
index a2c73ed..6e6d526 100644
--- a/go/internal/command/fallback/fallback.go
+++ b/go/internal/command/fallback/fallback.go
@@ -5,7 +5,7 @@ import (
"path/filepath"
"syscall"
- "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
)
type Command struct{}
@@ -14,7 +14,7 @@ var (
binDir = filepath.Dir(os.Args[0])
)
-func (c *Command) Execute(_ *reporting.Reporter) error {
+func (c *Command) Execute(_ *readwriter.ReadWriter) error {
rubyCmd := filepath.Join(binDir, "gitlab-shell-ruby")
execErr := syscall.Exec(rubyCmd, os.Args, os.Environ())
return execErr
diff --git a/go/internal/command/readwriter/readwriter.go b/go/internal/command/readwriter/readwriter.go
new file mode 100644
index 0000000..da18d30
--- /dev/null
+++ b/go/internal/command/readwriter/readwriter.go
@@ -0,0 +1,9 @@
+package readwriter
+
+import "io"
+
+type ReadWriter struct {
+ Out io.Writer
+ In io.Reader
+ ErrOut io.Writer
+}
diff --git a/go/internal/command/reporting/reporter.go b/go/internal/command/reporting/reporter.go
deleted file mode 100644
index 74bca59..0000000
--- a/go/internal/command/reporting/reporter.go
+++ /dev/null
@@ -1,8 +0,0 @@
-package reporting
-
-import "io"
-
-type Reporter struct {
- Out io.Writer
- ErrOut io.Writer
-}
diff --git a/go/internal/command/twofactorrecover/twofactorrecover.go b/go/internal/command/twofactorrecover/twofactorrecover.go
new file mode 100644
index 0000000..e77a334
--- /dev/null
+++ b/go/internal/command/twofactorrecover/twofactorrecover.go
@@ -0,0 +1,64 @@
+package twofactorrecover
+
+import (
+ "fmt"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/twofactorrecover"
+)
+
+type Command struct {
+ Config *config.Config
+ Args *commandargs.CommandArgs
+}
+
+func (c *Command) Execute(readWriter *readwriter.ReadWriter) error {
+ if c.canContinue(readWriter) {
+ c.displayRecoveryCodes(readWriter)
+ } else {
+ fmt.Fprintln(readWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.")
+ }
+
+ return nil
+}
+
+func (c *Command) canContinue(readWriter *readwriter.ReadWriter) bool {
+ question :=
+ "Are you sure you want to generate new two-factor recovery codes?\n" +
+ "Any existing recovery codes you saved will be invalidated. (yes/no)"
+ fmt.Fprintln(readWriter.Out, question)
+
+ var answer string
+ fmt.Fscanln(readWriter.In, &answer)
+
+ return answer == "yes"
+}
+
+func (c *Command) displayRecoveryCodes(readWriter *readwriter.ReadWriter) {
+ codes, err := c.getRecoveryCodes()
+
+ if err == nil {
+ messageWithCodes :=
+ "\nYour two-factor authentication recovery codes are:\n\n" +
+ strings.Join(codes, "\n") +
+ "\n\nDuring sign in, use one of the codes above when prompted for\n" +
+ "your two-factor code. Then, visit your Profile Settings and add\n" +
+ "a new device so you do not lose access to your account again.\n"
+ fmt.Fprint(readWriter.Out, messageWithCodes)
+ } else {
+ fmt.Fprintf(readWriter.Out, "\nAn error occurred while trying to generate new recovery codes.\n%v\n", err)
+ }
+}
+
+func (c *Command) getRecoveryCodes() ([]string, error) {
+ client, err := twofactorrecover.NewClient(c.Config)
+
+ if err != nil {
+ return nil, err
+ }
+
+ return client.GetRecoveryCodes(c.Args)
+}
diff --git a/go/internal/command/twofactorrecover/twofactorrecover_test.go b/go/internal/command/twofactorrecover/twofactorrecover_test.go
new file mode 100644
index 0000000..908ee13
--- /dev/null
+++ b/go/internal/command/twofactorrecover/twofactorrecover_test.go
@@ -0,0 +1,135 @@
+package twofactorrecover
+
+import (
+ "bytes"
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/twofactorrecover"
+)
+
+var (
+ testConfig *config.Config
+ requests []testserver.TestRequestHandler
+)
+
+func setup(t *testing.T) {
+ testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/two_factor_recovery_codes",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+
+ require.NoError(t, err)
+
+ var requestBody *twofactorrecover.RequestBody
+ json.Unmarshal(b, &requestBody)
+
+ switch requestBody.KeyId {
+ case "1":
+ body := map[string]interface{}{
+ "success": true,
+ "recovery_codes": [2]string{"recovery", "codes"},
+ }
+ json.NewEncoder(w).Encode(body)
+ case "forbidden":
+ body := map[string]interface{}{
+ "success": false,
+ "message": "Forbidden!",
+ }
+ json.NewEncoder(w).Encode(body)
+ case "broken":
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+ },
+ },
+ }
+}
+
+const (
+ question = "Are you sure you want to generate new two-factor recovery codes?\n" +
+ "Any existing recovery codes you saved will be invalidated. (yes/no)\n\n"
+ errorHeader = "An error occurred while trying to generate new recovery codes.\n"
+)
+
+func TestExecute(t *testing.T) {
+ setup(t)
+
+ cleanup, err := testserver.StartSocketHttpServer(requests)
+ require.NoError(t, err)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ arguments *commandargs.CommandArgs
+ answer string
+ expectedOutput string
+ }{
+ {
+ desc: "With a known key id",
+ arguments: &commandargs.CommandArgs{GitlabKeyId: "1"},
+ answer: "yes\n",
+ expectedOutput: question +
+ "Your two-factor authentication recovery codes are:\n\nrecovery\ncodes\n\n" +
+ "During sign in, use one of the codes above when prompted for\n" +
+ "your two-factor code. Then, visit your Profile Settings and add\n" +
+ "a new device so you do not lose access to your account again.\n",
+ },
+ {
+ desc: "With bad response",
+ arguments: &commandargs.CommandArgs{GitlabKeyId: "-1"},
+ answer: "yes\n",
+ expectedOutput: question + errorHeader + "Parsing failed\n",
+ },
+ {
+ desc: "With API returns an error",
+ arguments: &commandargs.CommandArgs{GitlabKeyId: "forbidden"},
+ answer: "yes\n",
+ expectedOutput: question + errorHeader + "Forbidden!\n",
+ },
+ {
+ desc: "With API fails",
+ arguments: &commandargs.CommandArgs{GitlabKeyId: "broken"},
+ answer: "yes\n",
+ expectedOutput: question + errorHeader + "Internal API error (500)\n",
+ },
+ {
+ desc: "With missing arguments",
+ arguments: &commandargs.CommandArgs{},
+ answer: "yes\n",
+ expectedOutput: question + errorHeader + "who='' is invalid\n",
+ },
+ {
+ desc: "With negative answer",
+ arguments: &commandargs.CommandArgs{},
+ answer: "no\n",
+ expectedOutput: question +
+ "New recovery codes have *not* been generated. Existing codes will remain valid.\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ output := &bytes.Buffer{}
+ input := bytes.NewBufferString(tc.answer)
+
+ cmd := &Command{Config: testConfig, Args: tc.arguments}
+
+ err := cmd.Execute(&readwriter.ReadWriter{Out: output, In: input})
+
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expectedOutput, output.String())
+ })
+ }
+}
diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go
index abc218f..c2453e5 100644
--- a/go/internal/gitlabnet/client.go
+++ b/go/internal/gitlabnet/client.go
@@ -17,8 +17,7 @@ const (
type GitlabClient interface {
Get(path string) (*http.Response, error)
- // TODO: implement posts
- // Post(path string) (http.Response, error)
+ Post(path string, data interface{}) (*http.Response, error)
}
type ErrorResponse struct {
diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go
index f69f284..c1d08a1 100644
--- a/go/internal/gitlabnet/client_test.go
+++ b/go/internal/gitlabnet/client_test.go
@@ -19,10 +19,25 @@ func TestClients(t *testing.T) {
{
Path: "/api/v4/internal/hello",
Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+
fmt.Fprint(w, "Hello")
},
},
{
+ Path: "/api/v4/internal/post_endpoint",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodPost, r.Method)
+
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+
+ require.NoError(t, err)
+
+ fmt.Fprint(w, "Echo: "+string(b))
+ },
+ },
+ {
Path: "/api/v4/internal/auth",
Handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, r.Header.Get(secretHeaderName))
@@ -68,6 +83,7 @@ func TestClients(t *testing.T) {
testBrokenRequest(t, tc.client)
testSuccessfulGet(t, tc.client)
+ testSuccessfulPost(t, tc.client)
testMissing(t, tc.client)
testErrorMessage(t, tc.client)
testAuthenticationHeader(t, tc.client)
@@ -89,32 +105,66 @@ func testSuccessfulGet(t *testing.T, client GitlabClient) {
})
}
+func testSuccessfulPost(t *testing.T, client GitlabClient) {
+ t.Run("Successful Post", func(t *testing.T) {
+ data := map[string]string{"key": "value"}
+
+ response, err := client.Post("/post_endpoint", data)
+ defer response.Body.Close()
+
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, "Echo: {\"key\":\"value\"}", string(responseBody))
+ })
+}
+
func testMissing(t *testing.T, client GitlabClient) {
- t.Run("Missing error", func(t *testing.T) {
+ t.Run("Missing error for GET", func(t *testing.T) {
response, err := client.Get("/missing")
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
})
+
+ t.Run("Missing error for POST", func(t *testing.T) {
+ response, err := client.Post("/missing", map[string]string{})
+ assert.EqualError(t, err, "Internal API error (404)")
+ assert.Nil(t, response)
+ })
}
func testErrorMessage(t *testing.T, client GitlabClient) {
- t.Run("Error with message", func(t *testing.T) {
+ t.Run("Error with message for GET", func(t *testing.T) {
response, err := client.Get("/error")
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
+
+ t.Run("Error with message for POST", func(t *testing.T) {
+ response, err := client.Post("/error", map[string]string{})
+ assert.EqualError(t, err, "Don't do that")
+ assert.Nil(t, response)
+ })
}
func testBrokenRequest(t *testing.T, client GitlabClient) {
- t.Run("Broken request", func(t *testing.T) {
+ t.Run("Broken request for GET", func(t *testing.T) {
response, err := client.Get("/broken")
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
})
+
+ t.Run("Broken request for POST", func(t *testing.T) {
+ response, err := client.Post("/broken", map[string]string{})
+ assert.EqualError(t, err, "Internal API unreachable")
+ assert.Nil(t, response)
+ })
}
func testAuthenticationHeader(t *testing.T, client GitlabClient) {
- t.Run("Authentication headers", func(t *testing.T) {
+ t.Run("Authentication headers for GET", func(t *testing.T) {
response, err := client.Get("/auth")
defer response.Body.Close()
@@ -128,4 +178,19 @@ func testAuthenticationHeader(t *testing.T, client GitlabClient) {
require.NoError(t, err)
assert.Equal(t, "sssh, it's a secret", string(header))
})
+
+ t.Run("Authentication headers for POST", func(t *testing.T) {
+ response, err := client.Post("/auth", map[string]string{})
+ defer response.Body.Close()
+
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ require.NoError(t, err)
+
+ header, err := base64.StdEncoding.DecodeString(string(responseBody))
+ require.NoError(t, err)
+ assert.Equal(t, "sssh, it's a secret", string(header))
+ })
}
diff --git a/go/internal/gitlabnet/discover/client.go b/go/internal/gitlabnet/discover/client.go
index 8df78fb..e84b1b4 100644
--- a/go/internal/gitlabnet/discover/client.go
+++ b/go/internal/gitlabnet/discover/client.go
@@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
)
@@ -30,6 +31,18 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
+func (c *Client) GetByCommandArgs(args *commandargs.CommandArgs) (*Response, error) {
+ if args.GitlabKeyId != "" {
+ return c.GetByKeyId(args.GitlabKeyId)
+ } else if args.GitlabUsername != "" {
+ return c.GetByUsername(args.GitlabUsername)
+ } else {
+ // There was no 'who' information, this matches the ruby error
+ // message.
+ return nil, fmt.Errorf("who='' is invalid")
+ }
+}
+
func (c *Client) GetByKeyId(keyId string) (*Response, error) {
params := url.Values{}
params.Add("key_id", keyId)
diff --git a/go/internal/gitlabnet/socketclient.go b/go/internal/gitlabnet/socketclient.go
index 3bd7c70..fd97535 100644
--- a/go/internal/gitlabnet/socketclient.go
+++ b/go/internal/gitlabnet/socketclient.go
@@ -1,7 +1,9 @@
package gitlabnet
import (
+ "bytes"
"context"
+ "encoding/json"
"net"
"net/http"
"strings"
@@ -44,3 +46,21 @@ func (c *GitlabSocketClient) Get(path string) (*http.Response, error) {
return doRequest(c.httpClient, c.config, request)
}
+
+func (c *GitlabSocketClient) Post(path string, data interface{}) (*http.Response, error) {
+ path = normalizePath(path)
+
+ jsonData, err := json.Marshal(data)
+ if err != nil {
+ return nil, err
+ }
+
+ request, err := http.NewRequest("POST", socketBaseUrl+path, bytes.NewReader(jsonData))
+ request.Header.Add("Content-Type", "application/json")
+
+ if err != nil {
+ return nil, err
+ }
+
+ return doRequest(c.httpClient, c.config, request)
+}
diff --git a/go/internal/gitlabnet/twofactorrecover/client.go b/go/internal/gitlabnet/twofactorrecover/client.go
new file mode 100644
index 0000000..2e47c64
--- /dev/null
+++ b/go/internal/gitlabnet/twofactorrecover/client.go
@@ -0,0 +1,104 @@
+package twofactorrecover
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
+)
+
+type Client struct {
+ config *config.Config
+ client gitlabnet.GitlabClient
+}
+
+type Response struct {
+ Success bool `json:"success"`
+ RecoveryCodes []string `json:"recovery_codes"`
+ Message string `json:"message"`
+}
+
+type RequestBody struct {
+ KeyId string `json:"key_id,omitempty"`
+ UserId int64 `json:"user_id,omitempty"`
+}
+
+func NewClient(config *config.Config) (*Client, error) {
+ client, err := gitlabnet.GetClient(config)
+ if err != nil {
+ return nil, fmt.Errorf("Error creating http client: %v", err)
+ }
+
+ return &Client{config: config, client: client}, nil
+}
+
+func (c *Client) GetRecoveryCodes(args *commandargs.CommandArgs) ([]string, error) {
+ requestBody, err := c.getRequestBody(args)
+
+ if err != nil {
+ return nil, err
+ }
+
+ response, err := c.client.Post("/two_factor_recovery_codes", requestBody)
+
+ if err != nil {
+ return nil, err
+ }
+
+ defer response.Body.Close()
+ parsedResponse, err := c.parseResponse(response)
+
+ if err != nil {
+ return nil, fmt.Errorf("Parsing failed")
+ }
+
+ if parsedResponse.Success {
+ return parsedResponse.RecoveryCodes, nil
+ } else {
+ return nil, errors.New(parsedResponse.Message)
+ }
+}
+
+func (c *Client) parseResponse(resp *http.Response) (*Response, error) {
+ parsedResponse := &Response{}
+ body, err := ioutil.ReadAll(resp.Body)
+
+ if err != nil {
+ return nil, err
+ }
+
+ if err := json.Unmarshal(body, parsedResponse); err != nil {
+ return nil, err
+ } else {
+ return parsedResponse, nil
+ }
+}
+
+func (c *Client) getRequestBody(args *commandargs.CommandArgs) (*RequestBody, error) {
+ client, err := discover.NewClient(c.config)
+
+ if err != nil {
+ return nil, err
+ }
+
+ var requestBody *RequestBody
+ if args.GitlabKeyId != "" {
+ requestBody = &RequestBody{KeyId: args.GitlabKeyId}
+ } else {
+ userInfo, err := client.GetByCommandArgs(args)
+
+ if err != nil {
+ return nil, err
+ }
+
+ requestBody = &RequestBody{UserId: userInfo.UserId}
+ }
+
+ return requestBody, nil
+}
diff --git a/go/internal/gitlabnet/twofactorrecover/client_test.go b/go/internal/gitlabnet/twofactorrecover/client_test.go
new file mode 100644
index 0000000..5cbc011
--- /dev/null
+++ b/go/internal/gitlabnet/twofactorrecover/client_test.go
@@ -0,0 +1,161 @@
+package twofactorrecover
+
+import (
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+)
+
+var (
+ testConfig *config.Config
+ requests []testserver.TestRequestHandler
+)
+
+func initialize(t *testing.T) {
+ testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/two_factor_recovery_codes",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+
+ require.NoError(t, err)
+
+ var requestBody *RequestBody
+ json.Unmarshal(b, &requestBody)
+
+ switch requestBody.KeyId {
+ case "0":
+ body := map[string]interface{}{
+ "success": true,
+ "recovery_codes": [2]string{"recovery 1", "codes 1"},
+ }
+ json.NewEncoder(w).Encode(body)
+ case "1":
+ body := map[string]interface{}{
+ "success": false,
+ "message": "missing user",
+ }
+ json.NewEncoder(w).Encode(body)
+ case "2":
+ w.WriteHeader(http.StatusForbidden)
+ body := &gitlabnet.ErrorResponse{
+ Message: "Not allowed!",
+ }
+ json.NewEncoder(w).Encode(body)
+ case "3":
+ w.Write([]byte("{ \"message\": \"broken json!\""))
+ case "4":
+ w.WriteHeader(http.StatusForbidden)
+ }
+
+ if requestBody.UserId == 1 {
+ body := map[string]interface{}{
+ "success": true,
+ "recovery_codes": [2]string{"recovery 2", "codes 2"},
+ }
+ json.NewEncoder(w).Encode(body)
+ }
+ },
+ },
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ body := &discover.Response{
+ UserId: 1,
+ Username: "jane-doe",
+ Name: "Jane Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ },
+ },
+ }
+}
+
+func TestGetRecoveryCodesByKeyId(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ args := &commandargs.CommandArgs{GitlabKeyId: "0"}
+ result, err := client.GetRecoveryCodes(args)
+ assert.NoError(t, err)
+ assert.Equal(t, []string{"recovery 1", "codes 1"}, result)
+}
+
+func TestGetRecoveryCodesByUsername(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ args := &commandargs.CommandArgs{GitlabUsername: "jane-doe"}
+ result, err := client.GetRecoveryCodes(args)
+ assert.NoError(t, err)
+ assert.Equal(t, []string{"recovery 2", "codes 2"}, result)
+}
+
+func TestMissingUser(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ args := &commandargs.CommandArgs{GitlabKeyId: "1"}
+ _, err := client.GetRecoveryCodes(args)
+ assert.Equal(t, "missing user", err.Error())
+}
+
+func TestErrorResponses(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ fakeId string
+ expectedError string
+ }{
+ {
+ desc: "A response with an error message",
+ fakeId: "2",
+ expectedError: "Not allowed!",
+ },
+ {
+ desc: "A response with bad JSON",
+ fakeId: "3",
+ expectedError: "Parsing failed",
+ },
+ {
+ desc: "An error response without message",
+ fakeId: "4",
+ expectedError: "Internal API error (403)",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ args := &commandargs.CommandArgs{GitlabKeyId: tc.fakeId}
+ resp, err := client.GetRecoveryCodes(args)
+
+ assert.EqualError(t, err, tc.expectedError)
+ assert.Nil(t, resp)
+ })
+ }
+}
+
+func setup(t *testing.T) (*Client, func()) {
+ initialize(t)
+ cleanup, err := testserver.StartSocketHttpServer(requests)
+ require.NoError(t, err)
+
+ client, err := NewClient(testConfig)
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/spec/gitlab_shell_authorized_keys_check_spec.rb b/spec/gitlab_shell_authorized_keys_check_spec.rb
index baaa560..7050604 100644
--- a/spec/gitlab_shell_authorized_keys_check_spec.rb
+++ b/spec/gitlab_shell_authorized_keys_check_spec.rb
@@ -1,20 +1,7 @@
require_relative 'spec_helper'
describe 'bin/gitlab-shell-authorized-keys-check' do
- def original_root_path
- ROOT_PATH
- end
-
- # All this test boilerplate is mostly copy/pasted between
- # gitlab_shell_gitlab_shell_spec.rb and
- # gitlab_shell_authorized_keys_check_spec.rb
- def tmp_root_path
- @tmp_root_path ||= File.realpath(Dir.mktmpdir)
- end
-
- def config_path
- File.join(tmp_root_path, 'config.yml')
- end
+ include_context 'gitlab shell'
def tmp_socket_path
# This has to be a relative path shorter than 100 bytes due to
@@ -22,12 +9,8 @@ describe 'bin/gitlab-shell-authorized-keys-check' do
'tmp/gitlab-shell-authorized-keys-check-socket'
end
- before(:all) do
- FileUtils.mkdir_p(File.dirname(tmp_socket_path))
- FileUtils.touch(File.join(tmp_root_path, '.gitlab_shell_secret'))
-
- @server = HTTPUNIXServer.new(BindAddress: tmp_socket_path)
- @server.mount_proc('/api/v4/internal/authorized_keys') do |req, res|
+ def mock_server(server)
+ server.mount_proc('/api/v4/internal/authorized_keys') do |req, res|
if req.query['key'] == 'known-rsa-key'
res.status = 200
res.content_type = 'application/json'
@@ -36,28 +19,14 @@ describe 'bin/gitlab-shell-authorized-keys-check' do
res.status = 404
end
end
-
- @webrick_thread = Thread.new { @server.start }
-
- sleep(0.1) while @webrick_thread.alive? && @server.status != :Running
- raise "Couldn't start stub GitlabNet server" unless @server.status == :Running
-
- File.open(config_path, 'w') do |f|
- f.write("---\ngitlab_url: http+unix://#{CGI.escape(tmp_socket_path)}\n")
- end
-
- copy_dirs = ['bin', 'lib']
- FileUtils.rm_rf(copy_dirs.map { |d| File.join(tmp_root_path, d) })
- FileUtils.cp_r(copy_dirs, tmp_root_path)
end
- after(:all) do
- @server.shutdown if @server
- @webrick_thread.join if @webrick_thread
- FileUtils.rm_rf(tmp_root_path)
+ before(:all) do
+ write_config(
+ "gitlab_url" => "http+unix://#{CGI.escape(tmp_socket_path)}",
+ )
end
- let(:gitlab_shell_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell') }
let(:authorized_keys_check_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell-authorized-keys-check') }
it 'succeeds when a valid key is given' do
diff --git a/spec/gitlab_shell_gitlab_shell_spec.rb b/spec/gitlab_shell_gitlab_shell_spec.rb
index cb3fd9c..6d6e172 100644
--- a/spec/gitlab_shell_gitlab_shell_spec.rb
+++ b/spec/gitlab_shell_gitlab_shell_spec.rb
@@ -3,33 +3,10 @@ require_relative 'spec_helper'
require 'open3'
describe 'bin/gitlab-shell' do
- def original_root_path
- ROOT_PATH
- end
-
- # All this test boilerplate is mostly copy/pasted between
- # gitlab_shell_gitlab_shell_spec.rb and
- # gitlab_shell_authorized_keys_check_spec.rb
- def tmp_root_path
- @tmp_root_path ||= File.realpath(Dir.mktmpdir)
- end
-
- def config_path
- File.join(tmp_root_path, 'config.yml')
- end
-
- def tmp_socket_path
- # This has to be a relative path shorter than 100 bytes due to
- # limitations in how Unix sockets work.
- 'tmp/gitlab-shell-socket'
- end
-
- before(:all) do
- FileUtils.mkdir_p(File.dirname(tmp_socket_path))
- FileUtils.touch(File.join(tmp_root_path, '.gitlab_shell_secret'))
+ include_context 'gitlab shell'
- @server = HTTPUNIXServer.new(BindAddress: tmp_socket_path)
- @server.mount_proc('/api/v4/internal/discover') do |req, res|
+ def mock_server(server)
+ server.mount_proc('/api/v4/internal/discover') do |req, res|
identifier = req.query['key_id'] || req.query['username'] || req.query['user_id']
known_identifiers = %w(10 someuser 100)
if known_identifiers.include?(identifier)
@@ -47,24 +24,16 @@ describe 'bin/gitlab-shell' do
res.status = 500
end
end
-
- @webrick_thread = Thread.new { @server.start }
-
- sleep(0.1) while @webrick_thread.alive? && @server.status != :Running
- raise "Couldn't start stub GitlabNet server" unless @server.status == :Running
- system(original_root_path, 'bin/compile')
- copy_dirs = ['bin', 'lib']
- FileUtils.rm_rf(copy_dirs.map { |d| File.join(tmp_root_path, d) })
- FileUtils.cp_r(copy_dirs, tmp_root_path)
end
- after(:all) do
- @server.shutdown if @server
- @webrick_thread.join if @webrick_thread
- FileUtils.rm_rf(tmp_root_path)
- end
+ def run!(args, env: {'SSH_CONNECTION' => 'fake'})
+ cmd = [
+ gitlab_shell_path,
+ args
+ ].flatten.compact.join(' ')
- let(:gitlab_shell_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell') }
+ Open3.capture3(env, cmd)
+ end
shared_examples 'results with keys' do
# Basic valid input
@@ -175,19 +144,4 @@ describe 'bin/gitlab-shell' do
expect(status).not_to be_success
end
end
-
- def run!(args, env: {'SSH_CONNECTION' => 'fake'})
- cmd = [
- gitlab_shell_path,
- args
- ].flatten.compact.join(' ')
-
- Open3.capture3(env, cmd)
- end
-
- def write_config(config)
- File.open(config_path, 'w') do |f|
- f.write(config.to_yaml)
- end
- end
end
diff --git a/spec/gitlab_shell_two_factor_recovery_spec.rb b/spec/gitlab_shell_two_factor_recovery_spec.rb
new file mode 100644
index 0000000..19999e5
--- /dev/null
+++ b/spec/gitlab_shell_two_factor_recovery_spec.rb
@@ -0,0 +1,128 @@
+require_relative 'spec_helper'
+
+require 'open3'
+
+describe 'bin/gitlab-shell 2fa_recovery_codes' do
+ include_context 'gitlab shell'
+
+ def mock_server(server)
+ server.mount_proc('/api/v4/internal/two_factor_recovery_codes') do |req, res|
+ res.content_type = 'application/json'
+ res.status = 200
+
+ key_id = req.query['key_id'] || req.query['user_id']
+
+ unless key_id
+ body = JSON.parse(req.body)
+ key_id = body['key_id'] || body['user_id'].to_s
+ end
+
+ if key_id == '100'
+ res.body = '{"success":true, "recovery_codes": ["1", "2"]}'
+ else
+ res.body = '{"success":false, "message": "Forbidden!"}'
+ end
+ end
+
+ server.mount_proc('/api/v4/internal/discover') do |req, res|
+ res.status = 200
+ res.content_type = 'application/json'
+ res.body = '{"id":100, "name": "Some User", "username": "someuser"}'
+ end
+ end
+
+ shared_examples 'dialog for regenerating recovery keys' do
+ context 'when the user agrees to regenerate keys' do
+ def verify_successful_regeneration!(cmd)
+ Open3.popen2(env, cmd) do |stdin, stdout|
+ expect(stdout.gets).to eq("Are you sure you want to generate new two-factor recovery codes?\n")
+ expect(stdout.gets).to eq("Any existing recovery codes you saved will be invalidated. (yes/no)\n")
+
+ stdin.puts('yes')
+
+ expect(stdout.flush.read).to eq(
+ "\nYour two-factor authentication recovery codes are:\n\n" \
+ "1\n2\n\n" \
+ "During sign in, use one of the codes above when prompted for\n" \
+ "your two-factor code. Then, visit your Profile Settings and add\n" \
+ "a new device so you do not lose access to your account again.\n"
+ )
+ end
+ end
+
+ context 'when key is provided' do
+ let(:cmd) { "#{gitlab_shell_path} key-100" }
+
+ it 'the recovery keys are regenerated' do
+ verify_successful_regeneration!(cmd)
+ end
+ end
+
+ context 'when username is provided' do
+ let(:cmd) { "#{gitlab_shell_path} username-someone" }
+
+ it 'the recovery keys are regenerated' do
+ verify_successful_regeneration!(cmd)
+ end
+ end
+ end
+
+ context 'when the user disagrees to regenerate keys' do
+ let(:cmd) { "#{gitlab_shell_path} key-100" }
+
+ it 'the recovery keys are not regenerated' do
+ Open3.popen2(env, cmd) do |stdin, stdout|
+ expect(stdout.gets).to eq("Are you sure you want to generate new two-factor recovery codes?\n")
+ expect(stdout.gets).to eq("Any existing recovery codes you saved will be invalidated. (yes/no)\n")
+
+ stdin.puts('no')
+
+ expect(stdout.flush.read).to eq(
+ "\nNew recovery codes have *not* been generated. Existing codes will remain valid.\n"
+ )
+ end
+ end
+ end
+
+ context 'when API error occurs' do
+ let(:cmd) { "#{gitlab_shell_path} key-101" }
+
+ context 'when the user agrees to regenerate keys' do
+ it 'the recovery keys are regenerated' do
+ Open3.popen2(env, cmd) do |stdin, stdout|
+ expect(stdout.gets).to eq("Are you sure you want to generate new two-factor recovery codes?\n")
+ expect(stdout.gets).to eq("Any existing recovery codes you saved will be invalidated. (yes/no)\n")
+
+ stdin.puts('yes')
+
+ expect(stdout.flush.read).to eq("\nAn error occurred while trying to generate new recovery codes.\nForbidden!\n")
+ end
+ end
+ end
+ end
+ end
+
+ let(:env) { {'SSH_CONNECTION' => 'fake', 'SSH_ORIGINAL_COMMAND' => '2fa_recovery_codes' } }
+
+ describe 'without go features' do
+ before(:context) do
+ write_config(
+ "gitlab_url" => "http+unix://#{CGI.escape(tmp_socket_path)}",
+ )
+ end
+
+ it_behaves_like 'dialog for regenerating recovery keys'
+ end
+
+ describe 'with go features' do
+ before(:context) do
+ write_config(
+ "gitlab_url" => "http+unix://#{CGI.escape(tmp_socket_path)}",
+ "migration" => { "enabled" => true,
+ "features" => ["2fa_recovery_codes"] }
+ )
+ end
+
+ it_behaves_like 'dialog for regenerating recovery keys'
+ end
+end
diff --git a/spec/support/gitlab_shell_setup.rb b/spec/support/gitlab_shell_setup.rb
new file mode 100644
index 0000000..eddd2d1
--- /dev/null
+++ b/spec/support/gitlab_shell_setup.rb
@@ -0,0 +1,58 @@
+RSpec.shared_context 'gitlab shell', shared_context: :metadata do
+ def original_root_path
+ ROOT_PATH
+ end
+
+ def config_path
+ File.join(tmp_root_path, 'config.yml')
+ end
+
+ def write_config(config)
+ File.open(config_path, 'w') do |f|
+ f.write(config.to_yaml)
+ end
+ end
+
+ def tmp_root_path
+ @tmp_root_path ||= File.realpath(Dir.mktmpdir)
+ end
+
+ def mock_server(server)
+ raise NotImplementedError.new(
+ 'mock_server method must be implemented in order to include gitlab shell context'
+ )
+ end
+
+ # This has to be a relative path shorter than 100 bytes due to
+ # limitations in how Unix sockets work.
+ def tmp_socket_path
+ 'tmp/gitlab-shell-socket'
+ end
+
+ let(:gitlab_shell_path) { File.join(tmp_root_path, 'bin', 'gitlab-shell') }
+
+ before(:all) do
+ FileUtils.mkdir_p(File.dirname(tmp_socket_path))
+ FileUtils.touch(File.join(tmp_root_path, '.gitlab_shell_secret'))
+
+ @server = HTTPUNIXServer.new(BindAddress: tmp_socket_path)
+
+ mock_server(@server)
+
+ @webrick_thread = Thread.new { @server.start }
+
+ sleep(0.1) while @webrick_thread.alive? && @server.status != :Running
+ raise "Couldn't start stub GitlabNet server" unless @server.status == :Running
+ system(original_root_path, 'bin/compile')
+
+ copy_dirs = ['bin', 'lib']
+ FileUtils.rm_rf(copy_dirs.map { |d| File.join(tmp_root_path, d) })
+ FileUtils.cp_r(copy_dirs, tmp_root_path)
+ end
+
+ after(:all) do
+ @server.shutdown if @server
+ @webrick_thread.join if @webrick_thread
+ FileUtils.rm_rf(tmp_root_path)
+ end
+end