summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAsh McKenzie <amckenzie@gitlab.com>2020-09-21 04:47:01 +0000
committerAsh McKenzie <amckenzie@gitlab.com>2020-09-21 04:47:01 +0000
commit1a2bfecd2f0ebb8e31f9833e0522c4643797041b (patch)
treed17cf7bff45492a587027851bb6e0bcb493cff58
parentf100e7e83943b3bb5db232f5bf79a616fdba88f1 (diff)
parenta487572a904cc149840488eefdfe121173d8bcb5 (diff)
downloadgitlab-shell-1a2bfecd2f0ebb8e31f9833e0522c4643797041b.tar.gz
Merge branch 'sh-extract-context-from-env' into 'master'
Make it possible to propagate correlation ID across processes Closes #474 See merge request gitlab-org/gitlab-shell!413
-rw-r--r--client/client_test.go21
-rw-r--r--client/gitlabnet.go28
-rw-r--r--client/httpclient_test.go7
-rw-r--r--client/httpsclient_test.go5
-rw-r--r--cmd/check/main.go5
-rw-r--r--cmd/gitlab-shell-authorized-keys-check/main.go5
-rw-r--r--cmd/gitlab-shell-authorized-principals-check/main.go5
-rw-r--r--cmd/gitlab-shell/main.go5
-rw-r--r--go.sum1
-rw-r--r--internal/command/authorizedkeys/authorized_keys.go13
-rw-r--r--internal/command/authorizedkeys/authorized_keys_test.go3
-rw-r--r--internal/command/authorizedprincipals/authorized_principals.go3
-rw-r--r--internal/command/authorizedprincipals/authorized_principals_test.go3
-rw-r--r--internal/command/command.go29
-rw-r--r--internal/command/command_test.go66
-rw-r--r--internal/command/discover/discover.go9
-rw-r--r--internal/command/discover/discover_test.go5
-rw-r--r--internal/command/healthcheck/healthcheck.go9
-rw-r--r--internal/command/healthcheck/healthcheck_test.go7
-rw-r--r--internal/command/lfsauthenticate/lfsauthenticate.go15
-rw-r--r--internal/command/lfsauthenticate/lfsauthenticate_test.go5
-rw-r--r--internal/command/personalaccesstoken/personalaccesstoken.go9
-rw-r--r--internal/command/personalaccesstoken/personalaccesstoken_test.go3
-rw-r--r--internal/command/receivepack/gitalycall_test.go3
-rw-r--r--internal/command/receivepack/receivepack.go12
-rw-r--r--internal/command/receivepack/receivepack_test.go5
-rw-r--r--internal/command/shared/accessverifier/accessverifier.go5
-rw-r--r--internal/command/shared/accessverifier/accessverifier_test.go5
-rw-r--r--internal/command/shared/customaction/customaction.go13
-rw-r--r--internal/command/shared/customaction/customaction_test.go5
-rw-r--r--internal/command/twofactorrecover/twofactorrecover.go13
-rw-r--r--internal/command/twofactorrecover/twofactorrecover_test.go3
-rw-r--r--internal/command/uploadarchive/gitalycall_test.go3
-rw-r--r--internal/command/uploadarchive/uploadarchive.go10
-rw-r--r--internal/command/uploadarchive/uploadarchive_test.go3
-rw-r--r--internal/command/uploadpack/gitalycall_test.go3
-rw-r--r--internal/command/uploadpack/uploadpack.go12
-rw-r--r--internal/command/uploadpack/uploadpack_test.go3
-rw-r--r--internal/config/config.go2
-rw-r--r--internal/gitlabnet/accessverifier/client.go5
-rw-r--r--internal/gitlabnet/accessverifier/client_test.go9
-rw-r--r--internal/gitlabnet/authorizedkeys/client.go5
-rw-r--r--internal/gitlabnet/authorizedkeys/client_test.go5
-rw-r--r--internal/gitlabnet/discover/client.go9
-rw-r--r--internal/gitlabnet/discover/client_test.go9
-rw-r--r--internal/gitlabnet/healthcheck/client.go5
-rw-r--r--internal/gitlabnet/healthcheck/client_test.go3
-rw-r--r--internal/gitlabnet/lfsauthenticate/client.go5
-rw-r--r--internal/gitlabnet/lfsauthenticate/client_test.go5
-rw-r--r--internal/gitlabnet/personalaccesstoken/client.go11
-rw-r--r--internal/gitlabnet/personalaccesstoken/client_test.go9
-rw-r--r--internal/gitlabnet/twofactorrecover/client.go11
-rw-r--r--internal/gitlabnet/twofactorrecover/client_test.go9
53 files changed, 304 insertions, 157 deletions
diff --git a/client/client_test.go b/client/client_test.go
index e92093a..e0650b2 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"encoding/base64"
"encoding/json"
"fmt"
@@ -78,7 +79,7 @@ func TestClients(t *testing.T) {
func testSuccessfulGet(t *testing.T, client *GitlabNetClient) {
t.Run("Successful get", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Get("/hello")
+ response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
require.NotNil(t, response)
@@ -104,7 +105,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
hook := testhelper.SetupLogger()
data := map[string]string{"key": "value"}
- response, err := client.Post("/post_endpoint", data)
+ response, err := client.Post(context.Background(), "/post_endpoint", data)
require.NoError(t, err)
require.NotNil(t, response)
@@ -128,7 +129,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
func testMissing(t *testing.T, client *GitlabNetClient) {
t.Run("Missing error for GET", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Get("/missing")
+ response, err := client.Get(context.Background(), "/missing")
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
@@ -144,7 +145,7 @@ func testMissing(t *testing.T, client *GitlabNetClient) {
t.Run("Missing error for POST", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Post("/missing", map[string]string{})
+ response, err := client.Post(context.Background(), "/missing", map[string]string{})
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
@@ -161,13 +162,13 @@ func testMissing(t *testing.T, client *GitlabNetClient) {
func testErrorMessage(t *testing.T, client *GitlabNetClient) {
t.Run("Error with message for GET", func(t *testing.T) {
- response, err := client.Get("/error")
+ response, err := client.Get(context.Background(), "/error")
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
t.Run("Error with message for POST", func(t *testing.T) {
- response, err := client.Post("/error", map[string]string{})
+ response, err := client.Post(context.Background(), "/error", map[string]string{})
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
@@ -177,7 +178,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
t.Run("Broken request for GET", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Get("/broken")
+ response, err := client.Get(context.Background(), "/broken")
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
@@ -194,7 +195,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
t.Run("Broken request for POST", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Post("/broken", map[string]string{})
+ response, err := client.Post(context.Background(), "/broken", map[string]string{})
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
@@ -211,7 +212,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
t.Run("Authentication headers for GET", func(t *testing.T) {
- response, err := client.Get("/auth")
+ response, err := client.Get(context.Background(), "/auth")
require.NoError(t, err)
require.NotNil(t, response)
@@ -226,7 +227,7 @@ func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
})
t.Run("Authentication headers for POST", func(t *testing.T) {
- response, err := client.Post("/auth", map[string]string{})
+ response, err := client.Post(context.Background(), "/auth", map[string]string{})
require.NoError(t, err)
require.NotNil(t, response)
diff --git a/client/gitlabnet.go b/client/gitlabnet.go
index 0657ca0..b908d04 100644
--- a/client/gitlabnet.go
+++ b/client/gitlabnet.go
@@ -11,8 +11,9 @@ import (
"strings"
"time"
- log "github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/labkit/correlation"
+
+ log "github.com/sirupsen/logrus"
)
const (
@@ -59,7 +60,7 @@ func normalizePath(path string) string {
return path
}
-func newRequest(method, host, path string, data interface{}) (*http.Request, string, error) {
+func newRequest(ctx context.Context, method, host, path string, data interface{}) (*http.Request, string, error) {
var jsonReader io.Reader
if data != nil {
jsonData, err := json.Marshal(data)
@@ -70,20 +71,13 @@ func newRequest(method, host, path string, data interface{}) (*http.Request, str
jsonReader = bytes.NewReader(jsonData)
}
- correlationID, err := correlation.RandomID()
- ctx := context.Background()
-
- if err != nil {
- log.WithError(err).Warn("unable to generate correlation ID")
- } else {
- ctx = correlation.ContextWithCorrelation(ctx, correlationID)
- }
-
request, err := http.NewRequestWithContext(ctx, method, host+path, jsonReader)
if err != nil {
return nil, "", err
}
+ correlationID := correlation.ExtractFromContext(ctx)
+
return request, correlationID, nil
}
@@ -102,16 +96,16 @@ func parseError(resp *http.Response) error {
}
-func (c *GitlabNetClient) Get(path string) (*http.Response, error) {
- return c.DoRequest(http.MethodGet, normalizePath(path), nil)
+func (c *GitlabNetClient) Get(ctx context.Context, path string) (*http.Response, error) {
+ return c.DoRequest(ctx, http.MethodGet, normalizePath(path), nil)
}
-func (c *GitlabNetClient) Post(path string, data interface{}) (*http.Response, error) {
- return c.DoRequest(http.MethodPost, normalizePath(path), data)
+func (c *GitlabNetClient) Post(ctx context.Context, path string, data interface{}) (*http.Response, error) {
+ return c.DoRequest(ctx, http.MethodPost, normalizePath(path), data)
}
-func (c *GitlabNetClient) DoRequest(method, path string, data interface{}) (*http.Response, error) {
- request, correlationID, err := newRequest(method, c.httpClient.Host, path, data)
+func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, data interface{}) (*http.Response, error) {
+ request, correlationID, err := newRequest(ctx, method, c.httpClient.Host, path, data)
if err != nil {
return nil, err
}
diff --git a/client/httpclient_test.go b/client/httpclient_test.go
index fce0cd5..97e1384 100644
--- a/client/httpclient_test.go
+++ b/client/httpclient_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"encoding/base64"
"fmt"
"io/ioutil"
@@ -51,11 +52,11 @@ func TestBasicAuthSettings(t *testing.T) {
client, cleanup := setup(t, username, password, requests)
defer cleanup()
- response, err := client.Get("/get_endpoint")
+ response, err := client.Get(context.Background(), "/get_endpoint")
require.NoError(t, err)
testBasicAuthHeaders(t, response)
- response, err = client.Post("/post_endpoint", nil)
+ response, err = client.Post(context.Background(), "/post_endpoint", nil)
require.NoError(t, err)
testBasicAuthHeaders(t, response)
}
@@ -89,7 +90,7 @@ func TestEmptyBasicAuthSettings(t *testing.T) {
client, cleanup := setup(t, "", "", requests)
defer cleanup()
- _, err := client.Get("/empty_basic_auth")
+ _, err := client.Get(context.Background(), "/empty_basic_auth")
require.NoError(t, err)
}
diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go
index 1c7435f..0cf77e3 100644
--- a/client/httpsclient_test.go
+++ b/client/httpsclient_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"fmt"
"io/ioutil"
"net/http"
@@ -43,7 +44,7 @@ func TestSuccessfulRequests(t *testing.T) {
client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, tc.selfSigned)
defer cleanup()
- response, err := client.Get("/hello")
+ response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
require.NotNil(t, response)
@@ -80,7 +81,7 @@ func TestFailedRequests(t *testing.T) {
client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, false)
defer cleanup()
- _, err := client.Get("/hello")
+ _, err := client.Get(context.Background(), "/hello")
require.Error(t, err)
assert.Equal(t, err.Error(), "Internal API unreachable")
diff --git a/cmd/check/main.go b/cmd/check/main.go
index e88b9fe..28634f4 100644
--- a/cmd/check/main.go
+++ b/cmd/check/main.go
@@ -38,7 +38,10 @@ func main() {
os.Exit(1)
}
- if err = cmd.Execute(); err != nil {
+ ctx, finished := command.ContextWithCorrelationID()
+ defer finished()
+
+ if err = cmd.Execute(ctx); err != nil {
fmt.Fprintf(readWriter.ErrOut, "%v\n", err)
os.Exit(1)
}
diff --git a/cmd/gitlab-shell-authorized-keys-check/main.go b/cmd/gitlab-shell-authorized-keys-check/main.go
index 4b3949c..3a7dcbb 100644
--- a/cmd/gitlab-shell-authorized-keys-check/main.go
+++ b/cmd/gitlab-shell-authorized-keys-check/main.go
@@ -41,7 +41,10 @@ func main() {
os.Exit(1)
}
- if err = cmd.Execute(); err != nil {
+ ctx, finished := command.ContextWithCorrelationID()
+ defer finished()
+
+ if err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
diff --git a/cmd/gitlab-shell-authorized-principals-check/main.go b/cmd/gitlab-shell-authorized-principals-check/main.go
index fc46180..ea8d140 100644
--- a/cmd/gitlab-shell-authorized-principals-check/main.go
+++ b/cmd/gitlab-shell-authorized-principals-check/main.go
@@ -41,7 +41,10 @@ func main() {
os.Exit(1)
}
- if err = cmd.Execute(); err != nil {
+ ctx, finished := command.ContextWithCorrelationID()
+ defer finished()
+
+ if err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
diff --git a/cmd/gitlab-shell/main.go b/cmd/gitlab-shell/main.go
index 8df781c..763aa5e 100644
--- a/cmd/gitlab-shell/main.go
+++ b/cmd/gitlab-shell/main.go
@@ -41,7 +41,10 @@ func main() {
os.Exit(1)
}
- if err = cmd.Execute(); err != nil {
+ ctx, finished := command.ContextWithCorrelationID()
+ defer finished()
+
+ if err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
diff --git a/go.sum b/go.sum
index a90a46a..c4a5ed4 100644
--- a/go.sum
+++ b/go.sum
@@ -248,6 +248,7 @@ github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb6
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
diff --git a/internal/command/authorizedkeys/authorized_keys.go b/internal/command/authorizedkeys/authorized_keys.go
index 7554761..736aeed 100644
--- a/internal/command/authorizedkeys/authorized_keys.go
+++ b/internal/command/authorizedkeys/authorized_keys.go
@@ -1,6 +1,7 @@
package authorizedkeys
import (
+ "context"
"fmt"
"strconv"
@@ -17,7 +18,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
// Do and return nothing when the expected and actual user don't match.
// This can happen when the user in sshd_config doesn't match the user
// trying to login. When nothing is printed, the user will be denied access.
@@ -27,15 +28,15 @@ func (c *Command) Execute() error {
return nil
}
- if err := c.printKeyLine(); err != nil {
+ if err := c.printKeyLine(ctx); err != nil {
return err
}
return nil
}
-func (c *Command) printKeyLine() error {
- response, err := c.getAuthorizedKey()
+func (c *Command) printKeyLine(ctx context.Context) error {
+ response, err := c.getAuthorizedKey(ctx)
if err != nil {
fmt.Fprintln(c.ReadWriter.Out, fmt.Sprintf("# No key was found for %s", c.Args.Key))
return nil
@@ -51,11 +52,11 @@ func (c *Command) printKeyLine() error {
return nil
}
-func (c *Command) getAuthorizedKey() (*authorizedkeys.Response, error) {
+func (c *Command) getAuthorizedKey(ctx context.Context) (*authorizedkeys.Response, error) {
client, err := authorizedkeys.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetByKey(c.Args.Key)
+ return client.GetByKey(ctx, c.Args.Key)
}
diff --git a/internal/command/authorizedkeys/authorized_keys_test.go b/internal/command/authorizedkeys/authorized_keys_test.go
index e12f4fa..f15c34d 100644
--- a/internal/command/authorizedkeys/authorized_keys_test.go
+++ b/internal/command/authorizedkeys/authorized_keys_test.go
@@ -2,6 +2,7 @@ package authorizedkeys
import (
"bytes"
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -97,7 +98,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
diff --git a/internal/command/authorizedprincipals/authorized_principals.go b/internal/command/authorizedprincipals/authorized_principals.go
index ab5f2f8..44f6c47 100644
--- a/internal/command/authorizedprincipals/authorized_principals.go
+++ b/internal/command/authorizedprincipals/authorized_principals.go
@@ -1,6 +1,7 @@
package authorizedprincipals
import (
+ "context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -15,7 +16,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
if err := c.printPrincipalLines(); err != nil {
return err
}
diff --git a/internal/command/authorizedprincipals/authorized_principals_test.go b/internal/command/authorizedprincipals/authorized_principals_test.go
index f11dd0f..ec97b65 100644
--- a/internal/command/authorizedprincipals/authorized_principals_test.go
+++ b/internal/command/authorizedprincipals/authorized_principals_test.go
@@ -2,6 +2,7 @@ package authorizedprincipals
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -54,7 +55,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
diff --git a/internal/command/command.go b/internal/command/command.go
index 283b4a1..c69219b 100644
--- a/internal/command/command.go
+++ b/internal/command/command.go
@@ -1,6 +1,8 @@
package command
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -16,10 +18,13 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+ "gitlab.com/gitlab-org/labkit/correlation"
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/tracing"
)
type Command interface {
- Execute() error
+ Execute(ctx context.Context) error
}
func New(e *executable.Executable, arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) {
@@ -35,6 +40,28 @@ func New(e *executable.Executable, arguments []string, config *config.Config, re
return nil, disallowedcommand.Error
}
+// ContextWithCorrelationID() will always return a background Context
+// with a correlation ID. It will first attempt to extract the ID from
+// an environment variable. If is not available, a random one will be
+// generated.
+func ContextWithCorrelationID() (context.Context, func()) {
+ ctx, finished := tracing.ExtractFromEnv(context.Background())
+ defer finished()
+
+ correlationID := correlation.ExtractFromContext(ctx)
+ if correlationID == "" {
+ correlationID, err := correlation.RandomID()
+ if err != nil {
+ log.WithError(err).Warn("unable to generate correlation ID")
+ } else {
+ log.Info("generated random correlation ID")
+ ctx = correlation.ContextWithCorrelation(ctx, correlationID)
+ }
+ }
+
+ return ctx, finished
+}
+
func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command {
switch e.Name {
case executable.GitlabShell:
diff --git a/internal/command/command_test.go b/internal/command/command_test.go
index db55e7d..9160abf 100644
--- a/internal/command/command_test.go
+++ b/internal/command/command_test.go
@@ -2,6 +2,7 @@ package command
import (
"errors"
+ "os"
"testing"
"github.com/stretchr/testify/require"
@@ -20,6 +21,7 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
"gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+ "gitlab.com/gitlab-org/labkit/correlation"
)
var (
@@ -151,3 +153,67 @@ func TestFailingNew(t *testing.T) {
})
}
}
+
+func TestContextWithCorrelationID(t *testing.T) {
+ testCases := []struct {
+ name string
+ additionalEnv map[string]string
+ expectedCorrelationID string
+ }{
+ {
+ name: "no CORRELATION_ID in environment",
+ },
+ {
+ name: "CORRELATION_ID in environment",
+ additionalEnv: map[string]string{
+ "CORRELATION_ID": "abc123",
+ },
+ expectedCorrelationID: "abc123",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ resetEnvironment := addAdditionalEnv(tc.additionalEnv)
+ defer resetEnvironment()
+
+ ctx, finished := ContextWithCorrelationID()
+ require.NotNil(t, ctx, "ctx is nil")
+ require.NotNil(t, finished, "finished is nil")
+ correlationID := correlation.ExtractFromContext(ctx)
+ require.NotEmpty(t, correlationID)
+
+ if tc.expectedCorrelationID != "" {
+ require.Equal(t, tc.expectedCorrelationID, correlationID)
+ }
+ defer finished()
+ })
+ }
+}
+
+// addAdditionalEnv will configure additional environment values
+// and return a deferrable function to reset the environment to
+// it's original state after the test
+func addAdditionalEnv(envMap map[string]string) func() {
+ prevValues := map[string]string{}
+ unsetValues := []string{}
+ for k, v := range envMap {
+ value, exists := os.LookupEnv(k)
+ if exists {
+ prevValues[k] = value
+ } else {
+ unsetValues = append(unsetValues, k)
+ }
+ os.Setenv(k, v)
+ }
+
+ return func() {
+ for k, v := range prevValues {
+ os.Setenv(k, v)
+ }
+
+ for _, k := range unsetValues {
+ os.Unsetenv(k)
+ }
+
+ }
+}
diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go
index 3aa7456..822be32 100644
--- a/internal/command/discover/discover.go
+++ b/internal/command/discover/discover.go
@@ -1,6 +1,7 @@
package discover
import (
+ "context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -15,8 +16,8 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
- response, err := c.getUserInfo()
+func (c *Command) Execute(ctx context.Context) error {
+ response, err := c.getUserInfo(ctx)
if err != nil {
return fmt.Errorf("Failed to get username: %v", err)
}
@@ -30,11 +31,11 @@ func (c *Command) Execute() error {
return nil
}
-func (c *Command) getUserInfo() (*discover.Response, error) {
+func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) {
client, err := discover.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetByCommandArgs(c.Args)
+ return client.GetByCommandArgs(ctx, c.Args)
}
diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go
index 8edbcb9..5431410 100644
--- a/internal/command/discover/discover_test.go
+++ b/internal/command/discover/discover_test.go
@@ -2,6 +2,7 @@ package discover
import (
"bytes"
+ "context"
"encoding/json"
"fmt"
"net/http"
@@ -83,7 +84,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
@@ -126,7 +127,7 @@ func TestFailingExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, tc.expectedError)
diff --git a/internal/command/healthcheck/healthcheck.go b/internal/command/healthcheck/healthcheck.go
index bbc73dc..b04eb0d 100644
--- a/internal/command/healthcheck/healthcheck.go
+++ b/internal/command/healthcheck/healthcheck.go
@@ -1,6 +1,7 @@
package healthcheck
import (
+ "context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
@@ -18,8 +19,8 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
- response, err := c.runCheck()
+func (c *Command) Execute(ctx context.Context) error {
+ response, err := c.runCheck(ctx)
if err != nil {
return fmt.Errorf("%v: FAILED - %v", apiMessage, err)
}
@@ -34,13 +35,13 @@ func (c *Command) Execute() error {
return nil
}
-func (c *Command) runCheck() (*healthcheck.Response, error) {
+func (c *Command) runCheck(ctx context.Context) (*healthcheck.Response, error) {
client, err := healthcheck.NewClient(c.Config)
if err != nil {
return nil, err
}
- response, err := client.Check()
+ response, err := client.Check(ctx)
if err != nil {
return nil, err
}
diff --git a/internal/command/healthcheck/healthcheck_test.go b/internal/command/healthcheck/healthcheck_test.go
index 7479bcb..d05e563 100644
--- a/internal/command/healthcheck/healthcheck_test.go
+++ b/internal/command/healthcheck/healthcheck_test.go
@@ -2,6 +2,7 @@ package healthcheck
import (
"bytes"
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -53,7 +54,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String())
@@ -69,7 +70,7 @@ func TestFailingRedisExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Error(t, err, "Redis available via internal API: FAILED")
require.Equal(t, "Internal API available: OK\n", buffer.String())
}
@@ -84,7 +85,7 @@ func TestFailingAPIExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, "Internal API available: FAILED - Internal API error (500)")
}
diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go
index 2aaac2a..dab69ab 100644
--- a/internal/command/lfsauthenticate/lfsauthenticate.go
+++ b/internal/command/lfsauthenticate/lfsauthenticate.go
@@ -1,6 +1,7 @@
package lfsauthenticate
import (
+ "context"
"encoding/base64"
"encoding/json"
"fmt"
@@ -34,7 +35,7 @@ type Payload struct {
ExpiresIn int `json:"expires_in,omitempty"`
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) < 3 {
return disallowedcommand.Error
@@ -49,12 +50,12 @@ func (c *Command) Execute() error {
return err
}
- accessResponse, err := c.verifyAccess(action, repo)
+ accessResponse, err := c.verifyAccess(ctx, action, repo)
if err != nil {
return err
}
- payload, err := c.authenticate(operation, repo, accessResponse.UserId)
+ payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId)
if err != nil {
// return nothing just like Ruby's GitlabShell#lfs_authenticate does
return nil
@@ -80,19 +81,19 @@ func actionFromOperation(operation string) (commandargs.CommandType, error) {
return action, nil
}
-func (c *Command) verifyAccess(action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(action, repo)
+ return cmd.Verify(ctx, action, repo)
}
-func (c *Command) authenticate(operation string, repo, userId string) ([]byte, error) {
+func (c *Command) authenticate(ctx context.Context, operation string, repo, userId string) ([]byte, error) {
client, err := lfsauthenticate.NewClient(c.Config, c.Args)
if err != nil {
return nil, err
}
- response, err := client.Authenticate(operation, repo, userId)
+ response, err := client.Authenticate(ctx, operation, repo, userId)
if err != nil {
return nil, err
}
diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go
index a1c7aec..55998ab 100644
--- a/internal/command/lfsauthenticate/lfsauthenticate_test.go
+++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go
@@ -2,6 +2,7 @@ package lfsauthenticate
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -54,7 +55,7 @@ func TestFailedRequests(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Error(t, err)
require.Equal(t, tc.expectedOutput, err.Error())
@@ -146,7 +147,7 @@ func TestLfsAuthenticateRequests(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, output.String())
diff --git a/internal/command/personalaccesstoken/personalaccesstoken.go b/internal/command/personalaccesstoken/personalaccesstoken.go
index b283890..6f3d03e 100644
--- a/internal/command/personalaccesstoken/personalaccesstoken.go
+++ b/internal/command/personalaccesstoken/personalaccesstoken.go
@@ -1,6 +1,7 @@
package personalaccesstoken
import (
+ "context"
"errors"
"fmt"
"strconv"
@@ -31,13 +32,13 @@ type tokenArgs struct {
ExpiresDate string // Calculated, a TTL is passed from command-line.
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
err := c.parseTokenArgs()
if err != nil {
return err
}
- response, err := c.getPersonalAccessToken()
+ response, err := c.getPersonalAccessToken(ctx)
if err != nil {
return err
}
@@ -76,11 +77,11 @@ func (c *Command) parseTokenArgs() error {
return nil
}
-func (c *Command) getPersonalAccessToken() (*personalaccesstoken.Response, error) {
+func (c *Command) getPersonalAccessToken(ctx context.Context) (*personalaccesstoken.Response, error) {
client, err := personalaccesstoken.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetPersonalAccessToken(c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate)
+ return client.GetPersonalAccessToken(ctx, c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate)
}
diff --git a/internal/command/personalaccesstoken/personalaccesstoken_test.go b/internal/command/personalaccesstoken/personalaccesstoken_test.go
index bc748ab..5970142 100644
--- a/internal/command/personalaccesstoken/personalaccesstoken_test.go
+++ b/internal/command/personalaccesstoken/personalaccesstoken_test.go
@@ -2,6 +2,7 @@ package personalaccesstoken
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -170,7 +171,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
if tc.expectedError == "" {
assert.NoError(t, err)
diff --git a/internal/command/receivepack/gitalycall_test.go b/internal/command/receivepack/gitalycall_test.go
index 8bee484..2a0c146 100644
--- a/internal/command/receivepack/gitalycall_test.go
+++ b/internal/command/receivepack/gitalycall_test.go
@@ -2,6 +2,7 @@ package receivepack
import (
"bytes"
+ "context"
"testing"
"github.com/sirupsen/logrus"
@@ -42,7 +43,7 @@ func TestReceivePack(t *testing.T) {
hook := testhelper.SetupLogger()
- err = cmd.Execute()
+ err = cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "ReceivePack: "+userId+" "+repo, output.String())
diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go
index 7271264..4d5c686 100644
--- a/internal/command/receivepack/receivepack.go
+++ b/internal/command/receivepack/receivepack.go
@@ -1,6 +1,8 @@
package receivepack
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -15,14 +17,14 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
}
repo := args[1]
- response, err := c.verifyAccess(repo)
+ response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
}
@@ -33,14 +35,14 @@ func (c *Command) Execute() error {
ReadWriter: c.ReadWriter,
EOFSent: true,
}
- return customAction.Execute(response)
+ return customAction.Execute(ctx, response)
}
return c.performGitalyCall(response)
}
-func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(c.Args.CommandType, repo)
+ return cmd.Verify(ctx, c.Args.CommandType, repo)
}
diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go
index a4632b4..44cb680 100644
--- a/internal/command/receivepack/receivepack_test.go
+++ b/internal/command/receivepack/receivepack_test.go
@@ -2,6 +2,7 @@ package receivepack
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -18,7 +19,7 @@ func TestForbiddenAccess(t *testing.T) {
cmd, _, cleanup := setup(t, "disallowed", requests)
defer cleanup()
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
@@ -26,7 +27,7 @@ func TestCustomReceivePack(t *testing.T) {
cmd, output, cleanup := setup(t, "1", requesthandlers.BuildAllowedWithCustomActionsHandlers(t))
defer cleanup()
- require.NoError(t, cmd.Execute())
+ require.NoError(t, cmd.Execute(context.Background()))
require.Equal(t, "customoutput", output.String())
}
diff --git a/internal/command/shared/accessverifier/accessverifier.go b/internal/command/shared/accessverifier/accessverifier.go
index 5d2d709..9fcdde4 100644
--- a/internal/command/shared/accessverifier/accessverifier.go
+++ b/internal/command/shared/accessverifier/accessverifier.go
@@ -1,6 +1,7 @@
package accessverifier
import (
+ "context"
"errors"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -18,13 +19,13 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Verify(action commandargs.CommandType, repo string) (*Response, error) {
+func (c *Command) Verify(ctx context.Context, action commandargs.CommandType, repo string) (*Response, error) {
client, err := accessverifier.NewClient(c.Config)
if err != nil {
return nil, err
}
- response, err := client.Verify(c.Args, action, repo)
+ response, err := client.Verify(ctx, c.Args, action, repo)
if err != nil {
return nil, err
}
diff --git a/internal/command/shared/accessverifier/accessverifier_test.go b/internal/command/shared/accessverifier/accessverifier_test.go
index 998e622..8ad87b8 100644
--- a/internal/command/shared/accessverifier/accessverifier_test.go
+++ b/internal/command/shared/accessverifier/accessverifier_test.go
@@ -2,6 +2,7 @@ package accessverifier
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -65,7 +66,7 @@ func TestMissingUser(t *testing.T) {
defer cleanup()
cmd.Args = &commandargs.Shell{GitlabKeyId: "2"}
- _, err := cmd.Verify(action, repo)
+ _, err := cmd.Verify(context.Background(), action, repo)
require.Equal(t, "missing user", err.Error())
}
@@ -75,7 +76,7 @@ func TestConsoleMessages(t *testing.T) {
defer cleanup()
cmd.Args = &commandargs.Shell{GitlabKeyId: "1"}
- cmd.Verify(action, repo)
+ cmd.Verify(context.Background(), action, repo)
require.Equal(t, "remote: \nremote: console\nremote: message\nremote: \n", errBuf.String())
require.Empty(t, outBuf.String())
diff --git a/internal/command/shared/customaction/customaction.go b/internal/command/shared/customaction/customaction.go
index 2ba1091..0675d36 100644
--- a/internal/command/shared/customaction/customaction.go
+++ b/internal/command/shared/customaction/customaction.go
@@ -2,6 +2,7 @@ package customaction
import (
"bytes"
+ "context"
"errors"
"gitlab.com/gitlab-org/gitlab-shell/client"
@@ -34,7 +35,7 @@ type Command struct {
EOFSent bool
}
-func (c *Command) Execute(response *accessverifier.Response) error {
+func (c *Command) Execute(ctx context.Context, response *accessverifier.Response) error {
data := response.Payload.Data
apiEndpoints := data.ApiEndpoints
@@ -42,10 +43,10 @@ func (c *Command) Execute(response *accessverifier.Response) error {
return errors.New("Custom action error: Empty API endpoints")
}
- return c.processApiEndpoints(response)
+ return c.processApiEndpoints(ctx, response)
}
-func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
+func (c *Command) processApiEndpoints(ctx context.Context, response *accessverifier.Response) error {
client, err := gitlabnet.GetClient(c.Config)
if err != nil {
@@ -64,7 +65,7 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
log.WithFields(fields).Info("Performing custom action")
- response, err := c.performRequest(client, endpoint, request)
+ response, err := c.performRequest(ctx, client, endpoint, request)
if err != nil {
return err
}
@@ -95,8 +96,8 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
return nil
}
-func (c *Command) performRequest(client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) {
- response, err := client.DoRequest(http.MethodPost, endpoint, request)
+func (c *Command) performRequest(ctx context.Context, client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) {
+ response, err := client.DoRequest(ctx, http.MethodPost, endpoint, request)
if err != nil {
return nil, err
}
diff --git a/internal/command/shared/customaction/customaction_test.go b/internal/command/shared/customaction/customaction_test.go
index 46c5f32..119da5b 100644
--- a/internal/command/shared/customaction/customaction_test.go
+++ b/internal/command/shared/customaction/customaction_test.go
@@ -2,6 +2,7 @@ package customaction
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -78,7 +79,7 @@ func TestExecuteEOFSent(t *testing.T) {
EOFSent: true,
}
- require.NoError(t, cmd.Execute(response))
+ require.NoError(t, cmd.Execute(context.Background(), response))
// expect printing of info message, "custom" string from the first request
// and "output" string from the second request
@@ -148,7 +149,7 @@ func TestExecuteNoEOFSent(t *testing.T) {
EOFSent: false,
}
- require.NoError(t, cmd.Execute(response))
+ require.NoError(t, cmd.Execute(context.Background(), response))
// expect printing of info message, "custom" string from the first request
// and "output" string from the second request
diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go
index 2f13cc5..f0a9e7b 100644
--- a/internal/command/twofactorrecover/twofactorrecover.go
+++ b/internal/command/twofactorrecover/twofactorrecover.go
@@ -1,6 +1,7 @@
package twofactorrecover
import (
+ "context"
"fmt"
"strings"
@@ -16,9 +17,9 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
if c.canContinue() {
- c.displayRecoveryCodes()
+ c.displayRecoveryCodes(ctx)
} else {
fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.")
}
@@ -38,8 +39,8 @@ func (c *Command) canContinue() bool {
return answer == "yes"
}
-func (c *Command) displayRecoveryCodes() {
- codes, err := c.getRecoveryCodes()
+func (c *Command) displayRecoveryCodes(ctx context.Context) {
+ codes, err := c.getRecoveryCodes(ctx)
if err == nil {
messageWithCodes :=
@@ -54,12 +55,12 @@ func (c *Command) displayRecoveryCodes() {
}
}
-func (c *Command) getRecoveryCodes() ([]string, error) {
+func (c *Command) getRecoveryCodes(ctx context.Context) ([]string, error) {
client, err := twofactorrecover.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetRecoveryCodes(c.Args)
+ return client.GetRecoveryCodes(ctx, c.Args)
}
diff --git a/internal/command/twofactorrecover/twofactorrecover_test.go b/internal/command/twofactorrecover/twofactorrecover_test.go
index d2f931b..ea6abd6 100644
--- a/internal/command/twofactorrecover/twofactorrecover_test.go
+++ b/internal/command/twofactorrecover/twofactorrecover_test.go
@@ -2,6 +2,7 @@ package twofactorrecover
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -127,7 +128,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, output.String())
diff --git a/internal/command/uploadarchive/gitalycall_test.go b/internal/command/uploadarchive/gitalycall_test.go
index eaeb2b7..f74093a 100644
--- a/internal/command/uploadarchive/gitalycall_test.go
+++ b/internal/command/uploadarchive/gitalycall_test.go
@@ -2,6 +2,7 @@ package uploadarchive
import (
"bytes"
+ "context"
"testing"
"github.com/sirupsen/logrus"
@@ -38,7 +39,7 @@ func TestUploadPack(t *testing.T) {
hook := testhelper.SetupLogger()
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "UploadArchive: "+repo, output.String())
diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go
index 9d4fbe0..178b42b 100644
--- a/internal/command/uploadarchive/uploadarchive.go
+++ b/internal/command/uploadarchive/uploadarchive.go
@@ -1,6 +1,8 @@
package uploadarchive
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -14,14 +16,14 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
}
repo := args[1]
- response, err := c.verifyAccess(repo)
+ response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
}
@@ -29,8 +31,8 @@ func (c *Command) Execute() error {
return c.performGitalyCall(response)
}
-func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(c.Args.CommandType, repo)
+ return cmd.Verify(ctx, c.Args.CommandType, repo)
}
diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go
index 7b03009..5426569 100644
--- a/internal/command/uploadarchive/uploadarchive_test.go
+++ b/internal/command/uploadarchive/uploadarchive_test.go
@@ -2,6 +2,7 @@ package uploadarchive
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
diff --git a/internal/command/uploadpack/gitalycall_test.go b/internal/command/uploadpack/gitalycall_test.go
index d6762a2..22189b8 100644
--- a/internal/command/uploadpack/gitalycall_test.go
+++ b/internal/command/uploadpack/gitalycall_test.go
@@ -2,6 +2,7 @@ package uploadpack
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/assert"
@@ -37,7 +38,7 @@ func TestUploadPack(t *testing.T) {
hook := testhelper.SetupLogger()
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "UploadPack: "+repo, output.String())
diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go
index 56814d7..fca3823 100644
--- a/internal/command/uploadpack/uploadpack.go
+++ b/internal/command/uploadpack/uploadpack.go
@@ -1,6 +1,8 @@
package uploadpack
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -15,14 +17,14 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
}
repo := args[1]
- response, err := c.verifyAccess(repo)
+ response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
}
@@ -33,14 +35,14 @@ func (c *Command) Execute() error {
ReadWriter: c.ReadWriter,
EOFSent: false,
}
- return customAction.Execute(response)
+ return customAction.Execute(ctx, response)
}
return c.performGitalyCall(response)
}
-func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(c.Args.CommandType, repo)
+ return cmd.Verify(ctx, c.Args.CommandType, repo)
}
diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go
index 7ea8e5d..20edb57 100644
--- a/internal/command/uploadpack/uploadpack_test.go
+++ b/internal/command/uploadpack/uploadpack_test.go
@@ -2,6 +2,7 @@ package uploadpack
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
diff --git a/internal/config/config.go b/internal/config/config.go
index e7abd59..79c2a36 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -37,7 +37,7 @@ type Config struct {
Secret string `yaml:"secret"`
SslCertDir string `yaml:"ssl_cert_dir"`
HttpSettings HttpSettingsConfig `yaml:"http_settings"`
- HttpClient *client.HttpClient
+ HttpClient *client.HttpClient `-`
}
func (c *Config) GetHttpClient() *client.HttpClient {
diff --git a/internal/gitlabnet/accessverifier/client.go b/internal/gitlabnet/accessverifier/client.go
index 00b9d76..7e120e0 100644
--- a/internal/gitlabnet/accessverifier/client.go
+++ b/internal/gitlabnet/accessverifier/client.go
@@ -1,6 +1,7 @@
package accessverifier
import (
+ "context"
"fmt"
"net/http"
@@ -77,7 +78,7 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{client: client}, nil
}
-func (c *Client) Verify(args *commandargs.Shell, action commandargs.CommandType, repo string) (*Response, error) {
+func (c *Client) Verify(ctx context.Context, args *commandargs.Shell, action commandargs.CommandType, repo string) (*Response, error) {
request := &Request{Action: action, Repo: repo, Protocol: protocol, Changes: anyChanges}
if args.GitlabUsername != "" {
@@ -88,7 +89,7 @@ func (c *Client) Verify(args *commandargs.Shell, action commandargs.CommandType,
request.CheckIp = sshenv.LocalAddr()
- response, err := c.client.Post("/allowed", request)
+ response, err := c.client.Post(ctx, "/allowed", request)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/accessverifier/client_test.go b/internal/gitlabnet/accessverifier/client_test.go
index 7ddbb5e..3681968 100644
--- a/internal/gitlabnet/accessverifier/client_test.go
+++ b/internal/gitlabnet/accessverifier/client_test.go
@@ -1,6 +1,7 @@
package accessverifier
import (
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -73,7 +74,7 @@ func TestSuccessfulResponses(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- result, err := client.Verify(tc.args, receivePackAction, repo)
+ result, err := client.Verify(context.Background(), tc.args, receivePackAction, repo)
require.NoError(t, err)
response := buildExpectedResponse(tc.who)
@@ -87,7 +88,7 @@ func TestGeoPushGetCustomAction(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabUsername: "custom"}
- result, err := client.Verify(args, receivePackAction, repo)
+ result, err := client.Verify(context.Background(), args, receivePackAction, repo)
require.NoError(t, err)
response := buildExpectedResponse("user-1")
@@ -110,7 +111,7 @@ func TestGeoPullGetCustomAction(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabUsername: "custom"}
- result, err := client.Verify(args, uploadPackAction, repo)
+ result, err := client.Verify(context.Background(), args, uploadPackAction, repo)
require.NoError(t, err)
response := buildExpectedResponse("user-1")
@@ -157,7 +158,7 @@ func TestErrorResponses(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: tc.fakeId}
- resp, err := client.Verify(args, receivePackAction, repo)
+ resp, err := client.Verify(context.Background(), args, receivePackAction, repo)
require.EqualError(t, err, tc.expectedError)
require.Nil(t, resp)
diff --git a/internal/gitlabnet/authorizedkeys/client.go b/internal/gitlabnet/authorizedkeys/client.go
index e4fec28..0a00034 100644
--- a/internal/gitlabnet/authorizedkeys/client.go
+++ b/internal/gitlabnet/authorizedkeys/client.go
@@ -1,6 +1,7 @@
package authorizedkeys
import (
+ "context"
"fmt"
"net/url"
@@ -32,13 +33,13 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) GetByKey(key string) (*Response, error) {
+func (c *Client) GetByKey(ctx context.Context, key string) (*Response, error) {
path, err := pathWithKey(key)
if err != nil {
return nil, err
}
- response, err := c.client.Get(path)
+ response, err := c.client.Get(ctx, path)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/authorizedkeys/client_test.go b/internal/gitlabnet/authorizedkeys/client_test.go
index c9c76a1..e72840c 100644
--- a/internal/gitlabnet/authorizedkeys/client_test.go
+++ b/internal/gitlabnet/authorizedkeys/client_test.go
@@ -1,6 +1,7 @@
package authorizedkeys
import (
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -48,7 +49,7 @@ func TestGetByKey(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
- result, err := client.GetByKey("key")
+ result, err := client.GetByKey(context.Background(), "key")
require.NoError(t, err)
require.Equal(t, &Response{Id: 1, Key: "public-key"}, result)
}
@@ -86,7 +87,7 @@ func TestGetByKeyErrorResponses(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- resp, err := client.GetByKey(tc.key)
+ resp, err := client.GetByKey(context.Background(), tc.key)
require.EqualError(t, err, tc.expectedError)
require.Nil(t, resp)
diff --git a/internal/gitlabnet/discover/client.go b/internal/gitlabnet/discover/client.go
index d1e1906..cc7f516 100644
--- a/internal/gitlabnet/discover/client.go
+++ b/internal/gitlabnet/discover/client.go
@@ -1,6 +1,7 @@
package discover
import (
+ "context"
"fmt"
"net/http"
"net/url"
@@ -31,7 +32,7 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) GetByCommandArgs(args *commandargs.Shell) (*Response, error) {
+func (c *Client) GetByCommandArgs(ctx context.Context, args *commandargs.Shell) (*Response, error) {
params := url.Values{}
if args.GitlabUsername != "" {
params.Add("username", args.GitlabUsername)
@@ -43,13 +44,13 @@ func (c *Client) GetByCommandArgs(args *commandargs.Shell) (*Response, error) {
return nil, fmt.Errorf("who='' is invalid")
}
- return c.getResponse(params)
+ return c.getResponse(ctx, params)
}
-func (c *Client) getResponse(params url.Values) (*Response, error) {
+func (c *Client) getResponse(ctx context.Context, params url.Values) (*Response, error) {
path := "/discover?" + params.Encode()
- response, err := c.client.Get(path)
+ response, err := c.client.Get(ctx, path)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/discover/client_test.go b/internal/gitlabnet/discover/client_test.go
index 96b3162..cb46dd7 100644
--- a/internal/gitlabnet/discover/client_test.go
+++ b/internal/gitlabnet/discover/client_test.go
@@ -1,6 +1,7 @@
package discover
import (
+ "context"
"encoding/json"
"fmt"
"net/http"
@@ -62,7 +63,7 @@ func TestGetByKeyId(t *testing.T) {
params := url.Values{}
params.Add("key_id", "1")
- result, err := client.getResponse(params)
+ result, err := client.getResponse(context.Background(), params)
assert.NoError(t, err)
assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result)
}
@@ -73,7 +74,7 @@ func TestGetByUsername(t *testing.T) {
params := url.Values{}
params.Add("username", "jane-doe")
- result, err := client.getResponse(params)
+ result, err := client.getResponse(context.Background(), params)
assert.NoError(t, err)
assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result)
}
@@ -84,7 +85,7 @@ func TestMissingUser(t *testing.T) {
params := url.Values{}
params.Add("username", "missing")
- result, err := client.getResponse(params)
+ result, err := client.getResponse(context.Background(), params)
assert.NoError(t, err)
assert.True(t, result.IsAnonymous())
}
@@ -119,7 +120,7 @@ func TestErrorResponses(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
params := url.Values{}
params.Add("username", tc.fakeUsername)
- resp, err := client.getResponse(params)
+ resp, err := client.getResponse(context.Background(), params)
assert.EqualError(t, err, tc.expectedError)
assert.Nil(t, resp)
diff --git a/internal/gitlabnet/healthcheck/client.go b/internal/gitlabnet/healthcheck/client.go
index 09b45af..f148504 100644
--- a/internal/gitlabnet/healthcheck/client.go
+++ b/internal/gitlabnet/healthcheck/client.go
@@ -1,6 +1,7 @@
package healthcheck
import (
+ "context"
"fmt"
"net/http"
@@ -34,8 +35,8 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) Check() (*Response, error) {
- resp, err := c.client.Get(checkPath)
+func (c *Client) Check(ctx context.Context) (*Response, error) {
+ resp, err := c.client.Get(ctx, checkPath)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/healthcheck/client_test.go b/internal/gitlabnet/healthcheck/client_test.go
index c66ddbd..81ae209 100644
--- a/internal/gitlabnet/healthcheck/client_test.go
+++ b/internal/gitlabnet/healthcheck/client_test.go
@@ -1,6 +1,7 @@
package healthcheck
import (
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -33,7 +34,7 @@ func TestCheck(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
- result, err := client.Check()
+ result, err := client.Check(context.Background())
require.NoError(t, err)
require.Equal(t, testResponse, result)
}
diff --git a/internal/gitlabnet/lfsauthenticate/client.go b/internal/gitlabnet/lfsauthenticate/client.go
index fffc225..834cbe1 100644
--- a/internal/gitlabnet/lfsauthenticate/client.go
+++ b/internal/gitlabnet/lfsauthenticate/client.go
@@ -1,6 +1,7 @@
package lfsauthenticate
import (
+ "context"
"fmt"
"net/http"
"strings"
@@ -40,7 +41,7 @@ func NewClient(config *config.Config, args *commandargs.Shell) (*Client, error)
return &Client{config: config, client: client, args: args}, nil
}
-func (c *Client) Authenticate(operation, repo, userId string) (*Response, error) {
+func (c *Client) Authenticate(ctx context.Context, operation, repo, userId string) (*Response, error) {
request := &Request{Operation: operation, Repo: repo}
if c.args.GitlabKeyId != "" {
request.KeyId = c.args.GitlabKeyId
@@ -48,7 +49,7 @@ func (c *Client) Authenticate(operation, repo, userId string) (*Response, error)
request.UserId = strings.TrimPrefix(userId, "user-")
}
- response, err := c.client.Post("/lfs_authenticate", request)
+ response, err := c.client.Post(ctx, "/lfs_authenticate", request)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/lfsauthenticate/client_test.go b/internal/gitlabnet/lfsauthenticate/client_test.go
index 82e364b..2bd0451 100644
--- a/internal/gitlabnet/lfsauthenticate/client_test.go
+++ b/internal/gitlabnet/lfsauthenticate/client_test.go
@@ -1,6 +1,7 @@
package lfsauthenticate
import (
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -85,7 +86,7 @@ func TestFailedRequests(t *testing.T) {
operation := tc.args.SshArgs[2]
- _, err = client.Authenticate(operation, repo, "")
+ _, err = client.Authenticate(context.Background(), operation, repo, "")
require.Error(t, err)
require.Equal(t, tc.expectedOutput, err.Error())
@@ -119,7 +120,7 @@ func TestSuccessfulRequests(t *testing.T) {
client, err := NewClient(&config.Config{GitlabUrl: url}, args)
require.NoError(t, err)
- response, err := client.Authenticate(operation, repo, "")
+ response, err := client.Authenticate(context.Background(), operation, repo, "")
require.NoError(t, err)
expectedResponse := &Response{
diff --git a/internal/gitlabnet/personalaccesstoken/client.go b/internal/gitlabnet/personalaccesstoken/client.go
index 588bead..abbd395 100644
--- a/internal/gitlabnet/personalaccesstoken/client.go
+++ b/internal/gitlabnet/personalaccesstoken/client.go
@@ -1,6 +1,7 @@
package personalaccesstoken
import (
+ "context"
"errors"
"fmt"
"net/http"
@@ -42,13 +43,13 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) GetPersonalAccessToken(args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*Response, error) {
- requestBody, err := c.getRequestBody(args, name, scopes, expiresAt)
+func (c *Client) GetPersonalAccessToken(ctx context.Context, args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*Response, error) {
+ requestBody, err := c.getRequestBody(ctx, args, name, scopes, expiresAt)
if err != nil {
return nil, err
}
- response, err := c.client.Post("/personal_access_token", requestBody)
+ response, err := c.client.Post(ctx, "/personal_access_token", requestBody)
if err != nil {
return nil, err
}
@@ -70,7 +71,7 @@ func parse(hr *http.Response) (*Response, error) {
return response, nil
}
-func (c *Client) getRequestBody(args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*RequestBody, error) {
+func (c *Client) getRequestBody(ctx context.Context, args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*RequestBody, error) {
client, err := discover.NewClient(c.config)
if err != nil {
return nil, err
@@ -83,7 +84,7 @@ func (c *Client) getRequestBody(args *commandargs.Shell, name string, scopes *[]
return requestBody, nil
}
- userInfo, err := client.GetByCommandArgs(args)
+ userInfo, err := client.GetByCommandArgs(ctx, args)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/personalaccesstoken/client_test.go b/internal/gitlabnet/personalaccesstoken/client_test.go
index de45975..140a7b2 100644
--- a/internal/gitlabnet/personalaccesstoken/client_test.go
+++ b/internal/gitlabnet/personalaccesstoken/client_test.go
@@ -1,6 +1,7 @@
package personalaccesstoken
import (
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -90,7 +91,7 @@ func TestGetPersonalAccessTokenByKeyId(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: "0"}
result, err := client.GetPersonalAccessToken(
- args, "newtoken", &[]string{"read_api", "read_repository"}, "",
+ context.Background(), args, "newtoken", &[]string{"read_api", "read_repository"}, "",
)
assert.NoError(t, err)
response := &Response{
@@ -109,7 +110,7 @@ func TestGetRecoveryCodesByUsername(t *testing.T) {
args := &commandargs.Shell{GitlabUsername: "jane-doe"}
result, err := client.GetPersonalAccessToken(
- args, "newtoken", &[]string{"api"}, "",
+ context.Background(), args, "newtoken", &[]string{"api"}, "",
)
assert.NoError(t, err)
response := &Response{true, "YXuxvUgCEmeePY3G1YAa", []string{"api"}, "", ""}
@@ -122,7 +123,7 @@ func TestMissingUser(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: "1"}
_, err := client.GetPersonalAccessToken(
- args, "newtoken", &[]string{"api"}, "",
+ context.Background(), args, "newtoken", &[]string{"api"}, "",
)
assert.Equal(t, "missing user", err.Error())
}
@@ -157,7 +158,7 @@ func TestErrorResponses(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: tc.fakeId}
resp, err := client.GetPersonalAccessToken(
- args, "newtoken", &[]string{"api"}, "",
+ context.Background(), args, "newtoken", &[]string{"api"}, "",
)
assert.EqualError(t, err, tc.expectedError)
diff --git a/internal/gitlabnet/twofactorrecover/client.go b/internal/gitlabnet/twofactorrecover/client.go
index d22daca..456f892 100644
--- a/internal/gitlabnet/twofactorrecover/client.go
+++ b/internal/gitlabnet/twofactorrecover/client.go
@@ -1,6 +1,7 @@
package twofactorrecover
import (
+ "context"
"errors"
"fmt"
"net/http"
@@ -37,14 +38,14 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) GetRecoveryCodes(args *commandargs.Shell) ([]string, error) {
- requestBody, err := c.getRequestBody(args)
+func (c *Client) GetRecoveryCodes(ctx context.Context, args *commandargs.Shell) ([]string, error) {
+ requestBody, err := c.getRequestBody(ctx, args)
if err != nil {
return nil, err
}
- response, err := c.client.Post("/two_factor_recovery_codes", requestBody)
+ response, err := c.client.Post(ctx, "/two_factor_recovery_codes", requestBody)
if err != nil {
return nil, err
}
@@ -66,7 +67,7 @@ func parse(hr *http.Response) ([]string, error) {
return response.RecoveryCodes, nil
}
-func (c *Client) getRequestBody(args *commandargs.Shell) (*RequestBody, error) {
+func (c *Client) getRequestBody(ctx context.Context, args *commandargs.Shell) (*RequestBody, error) {
client, err := discover.NewClient(c.config)
if err != nil {
@@ -77,7 +78,7 @@ func (c *Client) getRequestBody(args *commandargs.Shell) (*RequestBody, error) {
if args.GitlabKeyId != "" {
requestBody = &RequestBody{KeyId: args.GitlabKeyId}
} else {
- userInfo, err := client.GetByCommandArgs(args)
+ userInfo, err := client.GetByCommandArgs(ctx, args)
if err != nil {
return nil, err
diff --git a/internal/gitlabnet/twofactorrecover/client_test.go b/internal/gitlabnet/twofactorrecover/client_test.go
index 372afec..46291aa 100644
--- a/internal/gitlabnet/twofactorrecover/client_test.go
+++ b/internal/gitlabnet/twofactorrecover/client_test.go
@@ -1,6 +1,7 @@
package twofactorrecover
import (
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -85,7 +86,7 @@ func TestGetRecoveryCodesByKeyId(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabKeyId: "0"}
- result, err := client.GetRecoveryCodes(args)
+ result, err := client.GetRecoveryCodes(context.Background(), args)
assert.NoError(t, err)
assert.Equal(t, []string{"recovery 1", "codes 1"}, result)
}
@@ -95,7 +96,7 @@ func TestGetRecoveryCodesByUsername(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabUsername: "jane-doe"}
- result, err := client.GetRecoveryCodes(args)
+ result, err := client.GetRecoveryCodes(context.Background(), args)
assert.NoError(t, err)
assert.Equal(t, []string{"recovery 2", "codes 2"}, result)
}
@@ -105,7 +106,7 @@ func TestMissingUser(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabKeyId: "1"}
- _, err := client.GetRecoveryCodes(args)
+ _, err := client.GetRecoveryCodes(context.Background(), args)
assert.Equal(t, "missing user", err.Error())
}
@@ -138,7 +139,7 @@ func TestErrorResponses(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: tc.fakeId}
- resp, err := client.GetRecoveryCodes(args)
+ resp, err := client.GetRecoveryCodes(context.Background(), args)
assert.EqualError(t, err, tc.expectedError)
assert.Nil(t, resp)