summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStan Hu <stanhu@gmail.com>2020-09-19 03:34:49 -0700
committerStan Hu <stanhu@gmail.com>2020-09-20 21:40:40 -0700
commita487572a904cc149840488eefdfe121173d8bcb5 (patch)
treed17cf7bff45492a587027851bb6e0bcb493cff58
parentf100e7e83943b3bb5db232f5bf79a616fdba88f1 (diff)
downloadgitlab-shell-a487572a904cc149840488eefdfe121173d8bcb5.tar.gz
Make it possible to propagate correlation ID across processes
Previously, gitlab-shell did not pass a context through the application. Correlation IDs were generated down the call stack instead of passed around from the start execution. This has several potential downsides: 1. It's easier for programming mistakes to be made in future that lead to multiple correlation IDs being generated for a single request. 2. Correlation IDs cannot be passed in from upstream requests 3. Other advantages of context passing, such as distributed tracing is not possible. This commit changes the behavior: 1. Extract the correlation ID from the environment at the start of the application. 2. If no correlation ID exists, generate a random one. 3. Pass the correlation ID to the GitLabNet API requests. This change also enables other clients of GitLabNet (e.g. Gitaly) to pass along the correlation ID in the internal API requests (https://gitlab.com/gitlab-org/gitaly/-/issues/2725). Fixes https://gitlab.com/gitlab-org/gitlab-shell/-/issues/474
-rw-r--r--client/client_test.go21
-rw-r--r--client/gitlabnet.go28
-rw-r--r--client/httpclient_test.go7
-rw-r--r--client/httpsclient_test.go5
-rw-r--r--cmd/check/main.go5
-rw-r--r--cmd/gitlab-shell-authorized-keys-check/main.go5
-rw-r--r--cmd/gitlab-shell-authorized-principals-check/main.go5
-rw-r--r--cmd/gitlab-shell/main.go5
-rw-r--r--go.sum1
-rw-r--r--internal/command/authorizedkeys/authorized_keys.go13
-rw-r--r--internal/command/authorizedkeys/authorized_keys_test.go3
-rw-r--r--internal/command/authorizedprincipals/authorized_principals.go3
-rw-r--r--internal/command/authorizedprincipals/authorized_principals_test.go3
-rw-r--r--internal/command/command.go29
-rw-r--r--internal/command/command_test.go66
-rw-r--r--internal/command/discover/discover.go9
-rw-r--r--internal/command/discover/discover_test.go5
-rw-r--r--internal/command/healthcheck/healthcheck.go9
-rw-r--r--internal/command/healthcheck/healthcheck_test.go7
-rw-r--r--internal/command/lfsauthenticate/lfsauthenticate.go15
-rw-r--r--internal/command/lfsauthenticate/lfsauthenticate_test.go5
-rw-r--r--internal/command/personalaccesstoken/personalaccesstoken.go9
-rw-r--r--internal/command/personalaccesstoken/personalaccesstoken_test.go3
-rw-r--r--internal/command/receivepack/gitalycall_test.go3
-rw-r--r--internal/command/receivepack/receivepack.go12
-rw-r--r--internal/command/receivepack/receivepack_test.go5
-rw-r--r--internal/command/shared/accessverifier/accessverifier.go5
-rw-r--r--internal/command/shared/accessverifier/accessverifier_test.go5
-rw-r--r--internal/command/shared/customaction/customaction.go13
-rw-r--r--internal/command/shared/customaction/customaction_test.go5
-rw-r--r--internal/command/twofactorrecover/twofactorrecover.go13
-rw-r--r--internal/command/twofactorrecover/twofactorrecover_test.go3
-rw-r--r--internal/command/uploadarchive/gitalycall_test.go3
-rw-r--r--internal/command/uploadarchive/uploadarchive.go10
-rw-r--r--internal/command/uploadarchive/uploadarchive_test.go3
-rw-r--r--internal/command/uploadpack/gitalycall_test.go3
-rw-r--r--internal/command/uploadpack/uploadpack.go12
-rw-r--r--internal/command/uploadpack/uploadpack_test.go3
-rw-r--r--internal/config/config.go2
-rw-r--r--internal/gitlabnet/accessverifier/client.go5
-rw-r--r--internal/gitlabnet/accessverifier/client_test.go9
-rw-r--r--internal/gitlabnet/authorizedkeys/client.go5
-rw-r--r--internal/gitlabnet/authorizedkeys/client_test.go5
-rw-r--r--internal/gitlabnet/discover/client.go9
-rw-r--r--internal/gitlabnet/discover/client_test.go9
-rw-r--r--internal/gitlabnet/healthcheck/client.go5
-rw-r--r--internal/gitlabnet/healthcheck/client_test.go3
-rw-r--r--internal/gitlabnet/lfsauthenticate/client.go5
-rw-r--r--internal/gitlabnet/lfsauthenticate/client_test.go5
-rw-r--r--internal/gitlabnet/personalaccesstoken/client.go11
-rw-r--r--internal/gitlabnet/personalaccesstoken/client_test.go9
-rw-r--r--internal/gitlabnet/twofactorrecover/client.go11
-rw-r--r--internal/gitlabnet/twofactorrecover/client_test.go9
53 files changed, 304 insertions, 157 deletions
diff --git a/client/client_test.go b/client/client_test.go
index e92093a..e0650b2 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"encoding/base64"
"encoding/json"
"fmt"
@@ -78,7 +79,7 @@ func TestClients(t *testing.T) {
func testSuccessfulGet(t *testing.T, client *GitlabNetClient) {
t.Run("Successful get", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Get("/hello")
+ response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
require.NotNil(t, response)
@@ -104,7 +105,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
hook := testhelper.SetupLogger()
data := map[string]string{"key": "value"}
- response, err := client.Post("/post_endpoint", data)
+ response, err := client.Post(context.Background(), "/post_endpoint", data)
require.NoError(t, err)
require.NotNil(t, response)
@@ -128,7 +129,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
func testMissing(t *testing.T, client *GitlabNetClient) {
t.Run("Missing error for GET", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Get("/missing")
+ response, err := client.Get(context.Background(), "/missing")
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
@@ -144,7 +145,7 @@ func testMissing(t *testing.T, client *GitlabNetClient) {
t.Run("Missing error for POST", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Post("/missing", map[string]string{})
+ response, err := client.Post(context.Background(), "/missing", map[string]string{})
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
@@ -161,13 +162,13 @@ func testMissing(t *testing.T, client *GitlabNetClient) {
func testErrorMessage(t *testing.T, client *GitlabNetClient) {
t.Run("Error with message for GET", func(t *testing.T) {
- response, err := client.Get("/error")
+ response, err := client.Get(context.Background(), "/error")
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
t.Run("Error with message for POST", func(t *testing.T) {
- response, err := client.Post("/error", map[string]string{})
+ response, err := client.Post(context.Background(), "/error", map[string]string{})
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
@@ -177,7 +178,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
t.Run("Broken request for GET", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Get("/broken")
+ response, err := client.Get(context.Background(), "/broken")
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
@@ -194,7 +195,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
t.Run("Broken request for POST", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Post("/broken", map[string]string{})
+ response, err := client.Post(context.Background(), "/broken", map[string]string{})
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
@@ -211,7 +212,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
t.Run("Authentication headers for GET", func(t *testing.T) {
- response, err := client.Get("/auth")
+ response, err := client.Get(context.Background(), "/auth")
require.NoError(t, err)
require.NotNil(t, response)
@@ -226,7 +227,7 @@ func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
})
t.Run("Authentication headers for POST", func(t *testing.T) {
- response, err := client.Post("/auth", map[string]string{})
+ response, err := client.Post(context.Background(), "/auth", map[string]string{})
require.NoError(t, err)
require.NotNil(t, response)
diff --git a/client/gitlabnet.go b/client/gitlabnet.go
index 0657ca0..b908d04 100644
--- a/client/gitlabnet.go
+++ b/client/gitlabnet.go
@@ -11,8 +11,9 @@ import (
"strings"
"time"
- log "github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/labkit/correlation"
+
+ log "github.com/sirupsen/logrus"
)
const (
@@ -59,7 +60,7 @@ func normalizePath(path string) string {
return path
}
-func newRequest(method, host, path string, data interface{}) (*http.Request, string, error) {
+func newRequest(ctx context.Context, method, host, path string, data interface{}) (*http.Request, string, error) {
var jsonReader io.Reader
if data != nil {
jsonData, err := json.Marshal(data)
@@ -70,20 +71,13 @@ func newRequest(method, host, path string, data interface{}) (*http.Request, str
jsonReader = bytes.NewReader(jsonData)
}
- correlationID, err := correlation.RandomID()
- ctx := context.Background()
-
- if err != nil {
- log.WithError(err).Warn("unable to generate correlation ID")
- } else {
- ctx = correlation.ContextWithCorrelation(ctx, correlationID)
- }
-
request, err := http.NewRequestWithContext(ctx, method, host+path, jsonReader)
if err != nil {
return nil, "", err
}
+ correlationID := correlation.ExtractFromContext(ctx)
+
return request, correlationID, nil
}
@@ -102,16 +96,16 @@ func parseError(resp *http.Response) error {
}
-func (c *GitlabNetClient) Get(path string) (*http.Response, error) {
- return c.DoRequest(http.MethodGet, normalizePath(path), nil)
+func (c *GitlabNetClient) Get(ctx context.Context, path string) (*http.Response, error) {
+ return c.DoRequest(ctx, http.MethodGet, normalizePath(path), nil)
}
-func (c *GitlabNetClient) Post(path string, data interface{}) (*http.Response, error) {
- return c.DoRequest(http.MethodPost, normalizePath(path), data)
+func (c *GitlabNetClient) Post(ctx context.Context, path string, data interface{}) (*http.Response, error) {
+ return c.DoRequest(ctx, http.MethodPost, normalizePath(path), data)
}
-func (c *GitlabNetClient) DoRequest(method, path string, data interface{}) (*http.Response, error) {
- request, correlationID, err := newRequest(method, c.httpClient.Host, path, data)
+func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, data interface{}) (*http.Response, error) {
+ request, correlationID, err := newRequest(ctx, method, c.httpClient.Host, path, data)
if err != nil {
return nil, err
}
diff --git a/client/httpclient_test.go b/client/httpclient_test.go
index fce0cd5..97e1384 100644
--- a/client/httpclient_test.go
+++ b/client/httpclient_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"encoding/base64"
"fmt"
"io/ioutil"
@@ -51,11 +52,11 @@ func TestBasicAuthSettings(t *testing.T) {
client, cleanup := setup(t, username, password, requests)
defer cleanup()
- response, err := client.Get("/get_endpoint")
+ response, err := client.Get(context.Background(), "/get_endpoint")
require.NoError(t, err)
testBasicAuthHeaders(t, response)
- response, err = client.Post("/post_endpoint", nil)
+ response, err = client.Post(context.Background(), "/post_endpoint", nil)
require.NoError(t, err)
testBasicAuthHeaders(t, response)
}
@@ -89,7 +90,7 @@ func TestEmptyBasicAuthSettings(t *testing.T) {
client, cleanup := setup(t, "", "", requests)
defer cleanup()
- _, err := client.Get("/empty_basic_auth")
+ _, err := client.Get(context.Background(), "/empty_basic_auth")
require.NoError(t, err)
}
diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go
index 1c7435f..0cf77e3 100644
--- a/client/httpsclient_test.go
+++ b/client/httpsclient_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"fmt"
"io/ioutil"
"net/http"
@@ -43,7 +44,7 @@ func TestSuccessfulRequests(t *testing.T) {
client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, tc.selfSigned)
defer cleanup()
- response, err := client.Get("/hello")
+ response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
require.NotNil(t, response)
@@ -80,7 +81,7 @@ func TestFailedRequests(t *testing.T) {
client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, false)
defer cleanup()
- _, err := client.Get("/hello")
+ _, err := client.Get(context.Background(), "/hello")
require.Error(t, err)
assert.Equal(t, err.Error(), "Internal API unreachable")
diff --git a/cmd/check/main.go b/cmd/check/main.go
index e88b9fe..28634f4 100644
--- a/cmd/check/main.go
+++ b/cmd/check/main.go
@@ -38,7 +38,10 @@ func main() {
os.Exit(1)
}
- if err = cmd.Execute(); err != nil {
+ ctx, finished := command.ContextWithCorrelationID()
+ defer finished()
+
+ if err = cmd.Execute(ctx); err != nil {
fmt.Fprintf(readWriter.ErrOut, "%v\n", err)
os.Exit(1)
}
diff --git a/cmd/gitlab-shell-authorized-keys-check/main.go b/cmd/gitlab-shell-authorized-keys-check/main.go
index 4b3949c..3a7dcbb 100644
--- a/cmd/gitlab-shell-authorized-keys-check/main.go
+++ b/cmd/gitlab-shell-authorized-keys-check/main.go
@@ -41,7 +41,10 @@ func main() {
os.Exit(1)
}
- if err = cmd.Execute(); err != nil {
+ ctx, finished := command.ContextWithCorrelationID()
+ defer finished()
+
+ if err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
diff --git a/cmd/gitlab-shell-authorized-principals-check/main.go b/cmd/gitlab-shell-authorized-principals-check/main.go
index fc46180..ea8d140 100644
--- a/cmd/gitlab-shell-authorized-principals-check/main.go
+++ b/cmd/gitlab-shell-authorized-principals-check/main.go
@@ -41,7 +41,10 @@ func main() {
os.Exit(1)
}
- if err = cmd.Execute(); err != nil {
+ ctx, finished := command.ContextWithCorrelationID()
+ defer finished()
+
+ if err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
diff --git a/cmd/gitlab-shell/main.go b/cmd/gitlab-shell/main.go
index 8df781c..763aa5e 100644
--- a/cmd/gitlab-shell/main.go
+++ b/cmd/gitlab-shell/main.go
@@ -41,7 +41,10 @@ func main() {
os.Exit(1)
}
- if err = cmd.Execute(); err != nil {
+ ctx, finished := command.ContextWithCorrelationID()
+ defer finished()
+
+ if err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
diff --git a/go.sum b/go.sum
index a90a46a..c4a5ed4 100644
--- a/go.sum
+++ b/go.sum
@@ -248,6 +248,7 @@ github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb6
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
diff --git a/internal/command/authorizedkeys/authorized_keys.go b/internal/command/authorizedkeys/authorized_keys.go
index 7554761..736aeed 100644
--- a/internal/command/authorizedkeys/authorized_keys.go
+++ b/internal/command/authorizedkeys/authorized_keys.go
@@ -1,6 +1,7 @@
package authorizedkeys
import (
+ "context"
"fmt"
"strconv"
@@ -17,7 +18,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
// Do and return nothing when the expected and actual user don't match.
// This can happen when the user in sshd_config doesn't match the user
// trying to login. When nothing is printed, the user will be denied access.
@@ -27,15 +28,15 @@ func (c *Command) Execute() error {
return nil
}
- if err := c.printKeyLine(); err != nil {
+ if err := c.printKeyLine(ctx); err != nil {
return err
}
return nil
}
-func (c *Command) printKeyLine() error {
- response, err := c.getAuthorizedKey()
+func (c *Command) printKeyLine(ctx context.Context) error {
+ response, err := c.getAuthorizedKey(ctx)
if err != nil {
fmt.Fprintln(c.ReadWriter.Out, fmt.Sprintf("# No key was found for %s", c.Args.Key))
return nil
@@ -51,11 +52,11 @@ func (c *Command) printKeyLine() error {
return nil
}
-func (c *Command) getAuthorizedKey() (*authorizedkeys.Response, error) {
+func (c *Command) getAuthorizedKey(ctx context.Context) (*authorizedkeys.Response, error) {
client, err := authorizedkeys.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetByKey(c.Args.Key)
+ return client.GetByKey(ctx, c.Args.Key)
}
diff --git a/internal/command/authorizedkeys/authorized_keys_test.go b/internal/command/authorizedkeys/authorized_keys_test.go
index e12f4fa..f15c34d 100644
--- a/internal/command/authorizedkeys/authorized_keys_test.go
+++ b/internal/command/authorizedkeys/authorized_keys_test.go
@@ -2,6 +2,7 @@ package authorizedkeys
import (
"bytes"
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -97,7 +98,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
diff --git a/internal/command/authorizedprincipals/authorized_principals.go b/internal/command/authorizedprincipals/authorized_principals.go
index ab5f2f8..44f6c47 100644
--- a/internal/command/authorizedprincipals/authorized_principals.go
+++ b/internal/command/authorizedprincipals/authorized_principals.go
@@ -1,6 +1,7 @@
package authorizedprincipals
import (
+ "context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -15,7 +16,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
if err := c.printPrincipalLines(); err != nil {
return err
}
diff --git a/internal/command/authorizedprincipals/authorized_principals_test.go b/internal/command/authorizedprincipals/authorized_principals_test.go
index f11dd0f..ec97b65 100644
--- a/internal/command/authorizedprincipals/authorized_principals_test.go
+++ b/internal/command/authorizedprincipals/authorized_principals_test.go
@@ -2,6 +2,7 @@ package authorizedprincipals
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -54,7 +55,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
diff --git a/internal/command/command.go b/internal/command/command.go
index 283b4a1..c69219b 100644
--- a/internal/command/command.go
+++ b/internal/command/command.go
@@ -1,6 +1,8 @@
package command
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -16,10 +18,13 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+ "gitlab.com/gitlab-org/labkit/correlation"
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/tracing"
)
type Command interface {
- Execute() error
+ Execute(ctx context.Context) error
}
func New(e *executable.Executable, arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) {
@@ -35,6 +40,28 @@ func New(e *executable.Executable, arguments []string, config *config.Config, re
return nil, disallowedcommand.Error
}
+// ContextWithCorrelationID() will always return a background Context
+// with a correlation ID. It will first attempt to extract the ID from
+// an environment variable. If is not available, a random one will be
+// generated.
+func ContextWithCorrelationID() (context.Context, func()) {
+ ctx, finished := tracing.ExtractFromEnv(context.Background())
+ defer finished()
+
+ correlationID := correlation.ExtractFromContext(ctx)
+ if correlationID == "" {
+ correlationID, err := correlation.RandomID()
+ if err != nil {
+ log.WithError(err).Warn("unable to generate correlation ID")
+ } else {
+ log.Info("generated random correlation ID")
+ ctx = correlation.ContextWithCorrelation(ctx, correlationID)
+ }
+ }
+
+ return ctx, finished
+}
+
func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command {
switch e.Name {
case executable.GitlabShell:
diff --git a/internal/command/command_test.go b/internal/command/command_test.go
index db55e7d..9160abf 100644
--- a/internal/command/command_test.go
+++ b/internal/command/command_test.go
@@ -2,6 +2,7 @@ package command
import (
"errors"
+ "os"
"testing"
"github.com/stretchr/testify/require"
@@ -20,6 +21,7 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
"gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+ "gitlab.com/gitlab-org/labkit/correlation"
)
var (
@@ -151,3 +153,67 @@ func TestFailingNew(t *testing.T) {
})
}
}
+
+func TestContextWithCorrelationID(t *testing.T) {
+ testCases := []struct {
+ name string
+ additionalEnv map[string]string
+ expectedCorrelationID string
+ }{
+ {
+ name: "no CORRELATION_ID in environment",
+ },
+ {
+ name: "CORRELATION_ID in environment",
+ additionalEnv: map[string]string{
+ "CORRELATION_ID": "abc123",
+ },
+ expectedCorrelationID: "abc123",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ resetEnvironment := addAdditionalEnv(tc.additionalEnv)
+ defer resetEnvironment()
+
+ ctx, finished := ContextWithCorrelationID()
+ require.NotNil(t, ctx, "ctx is nil")
+ require.NotNil(t, finished, "finished is nil")
+ correlationID := correlation.ExtractFromContext(ctx)
+ require.NotEmpty(t, correlationID)
+
+ if tc.expectedCorrelationID != "" {
+ require.Equal(t, tc.expectedCorrelationID, correlationID)
+ }
+ defer finished()
+ })
+ }
+}
+
+// addAdditionalEnv will configure additional environment values
+// and return a deferrable function to reset the environment to
+// it's original state after the test
+func addAdditionalEnv(envMap map[string]string) func() {
+ prevValues := map[string]string{}
+ unsetValues := []string{}
+ for k, v := range envMap {
+ value, exists := os.LookupEnv(k)
+ if exists {
+ prevValues[k] = value
+ } else {
+ unsetValues = append(unsetValues, k)
+ }
+ os.Setenv(k, v)
+ }
+
+ return func() {
+ for k, v := range prevValues {
+ os.Setenv(k, v)
+ }
+
+ for _, k := range unsetValues {
+ os.Unsetenv(k)
+ }
+
+ }
+}
diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go
index 3aa7456..822be32 100644
--- a/internal/command/discover/discover.go
+++ b/internal/command/discover/discover.go
@@ -1,6 +1,7 @@
package discover
import (
+ "context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -15,8 +16,8 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
- response, err := c.getUserInfo()
+func (c *Command) Execute(ctx context.Context) error {
+ response, err := c.getUserInfo(ctx)
if err != nil {
return fmt.Errorf("Failed to get username: %v", err)
}
@@ -30,11 +31,11 @@ func (c *Command) Execute() error {
return nil
}
-func (c *Command) getUserInfo() (*discover.Response, error) {
+func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) {
client, err := discover.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetByCommandArgs(c.Args)
+ return client.GetByCommandArgs(ctx, c.Args)
}
diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go
index 8edbcb9..5431410 100644
--- a/internal/command/discover/discover_test.go
+++ b/internal/command/discover/discover_test.go
@@ -2,6 +2,7 @@ package discover
import (
"bytes"
+ "context"
"encoding/json"
"fmt"
"net/http"
@@ -83,7 +84,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
@@ -126,7 +127,7 @@ func TestFailingExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, tc.expectedError)
diff --git a/internal/command/healthcheck/healthcheck.go b/internal/command/healthcheck/healthcheck.go
index bbc73dc..b04eb0d 100644
--- a/internal/command/healthcheck/healthcheck.go
+++ b/internal/command/healthcheck/healthcheck.go
@@ -1,6 +1,7 @@
package healthcheck
import (
+ "context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
@@ -18,8 +19,8 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
- response, err := c.runCheck()
+func (c *Command) Execute(ctx context.Context) error {
+ response, err := c.runCheck(ctx)
if err != nil {
return fmt.Errorf("%v: FAILED - %v", apiMessage, err)
}
@@ -34,13 +35,13 @@ func (c *Command) Execute() error {
return nil
}
-func (c *Command) runCheck() (*healthcheck.Response, error) {
+func (c *Command) runCheck(ctx context.Context) (*healthcheck.Response, error) {
client, err := healthcheck.NewClient(c.Config)
if err != nil {
return nil, err
}
- response, err := client.Check()
+ response, err := client.Check(ctx)
if err != nil {
return nil, err
}
diff --git a/internal/command/healthcheck/healthcheck_test.go b/internal/command/healthcheck/healthcheck_test.go
index 7479bcb..d05e563 100644
--- a/internal/command/healthcheck/healthcheck_test.go
+++ b/internal/command/healthcheck/healthcheck_test.go
@@ -2,6 +2,7 @@ package healthcheck
import (
"bytes"
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -53,7 +54,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String())
@@ -69,7 +70,7 @@ func TestFailingRedisExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Error(t, err, "Redis available via internal API: FAILED")
require.Equal(t, "Internal API available: OK\n", buffer.String())
}
@@ -84,7 +85,7 @@ func TestFailingAPIExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, "Internal API available: FAILED - Internal API error (500)")
}
diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go
index 2aaac2a..dab69ab 100644
--- a/internal/command/lfsauthenticate/lfsauthenticate.go
+++ b/internal/command/lfsauthenticate/lfsauthenticate.go
@@ -1,6 +1,7 @@
package lfsauthenticate
import (
+ "context"
"encoding/base64"
"encoding/json"
"fmt"
@@ -34,7 +35,7 @@ type Payload struct {
ExpiresIn int `json:"expires_in,omitempty"`
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) < 3 {
return disallowedcommand.Error
@@ -49,12 +50,12 @@ func (c *Command) Execute() error {
return err
}
- accessResponse, err := c.verifyAccess(action, repo)
+ accessResponse, err := c.verifyAccess(ctx, action, repo)
if err != nil {
return err
}
- payload, err := c.authenticate(operation, repo, accessResponse.UserId)
+ payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId)
if err != nil {
// return nothing just like Ruby's GitlabShell#lfs_authenticate does
return nil
@@ -80,19 +81,19 @@ func actionFromOperation(operation string) (commandargs.CommandType, error) {
return action, nil
}
-func (c *Command) verifyAccess(action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(action, repo)
+ return cmd.Verify(ctx, action, repo)
}
-func (c *Command) authenticate(operation string, repo, userId string) ([]byte, error) {
+func (c *Command) authenticate(ctx context.Context, operation string, repo, userId string) ([]byte, error) {
client, err := lfsauthenticate.NewClient(c.Config, c.Args)
if err != nil {
return nil, err
}
- response, err := client.Authenticate(operation, repo, userId)
+ response, err := client.Authenticate(ctx, operation, repo, userId)
if err != nil {
return nil, err
}
diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go
index a1c7aec..55998ab 100644
--- a/internal/command/lfsauthenticate/lfsauthenticate_test.go
+++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go
@@ -2,6 +2,7 @@ package lfsauthenticate
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -54,7 +55,7 @@ func TestFailedRequests(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Error(t, err)
require.Equal(t, tc.expectedOutput, err.Error())
@@ -146,7 +147,7 @@ func TestLfsAuthenticateRequests(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, output.String())
diff --git a/internal/command/personalaccesstoken/personalaccesstoken.go b/internal/command/personalaccesstoken/personalaccesstoken.go
index b283890..6f3d03e 100644
--- a/internal/command/personalaccesstoken/personalaccesstoken.go
+++ b/internal/command/personalaccesstoken/personalaccesstoken.go
@@ -1,6 +1,7 @@
package personalaccesstoken
import (
+ "context"
"errors"
"fmt"
"strconv"
@@ -31,13 +32,13 @@ type tokenArgs struct {
ExpiresDate string // Calculated, a TTL is passed from command-line.
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
err := c.parseTokenArgs()
if err != nil {
return err
}
- response, err := c.getPersonalAccessToken()
+ response, err := c.getPersonalAccessToken(ctx)
if err != nil {
return err
}
@@ -76,11 +77,11 @@ func (c *Command) parseTokenArgs() error {
return nil
}
-func (c *Command) getPersonalAccessToken() (*personalaccesstoken.Response, error) {
+func (c *Command) getPersonalAccessToken(ctx context.Context) (*personalaccesstoken.Response, error) {
client, err := personalaccesstoken.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetPersonalAccessToken(c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate)
+ return client.GetPersonalAccessToken(ctx, c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate)
}
diff --git a/internal/command/personalaccesstoken/personalaccesstoken_test.go b/internal/command/personalaccesstoken/personalaccesstoken_test.go
index bc748ab..5970142 100644
--- a/internal/command/personalaccesstoken/personalaccesstoken_test.go
+++ b/internal/command/personalaccesstoken/personalaccesstoken_test.go
@@ -2,6 +2,7 @@ package personalaccesstoken
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -170,7 +171,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
if tc.expectedError == "" {
assert.NoError(t, err)
diff --git a/internal/command/receivepack/gitalycall_test.go b/internal/command/receivepack/gitalycall_test.go
index 8bee484..2a0c146 100644
--- a/internal/command/receivepack/gitalycall_test.go
+++ b/internal/command/receivepack/gitalycall_test.go
@@ -2,6 +2,7 @@ package receivepack
import (
"bytes"
+ "context"
"testing"
"github.com/sirupsen/logrus"
@@ -42,7 +43,7 @@ func TestReceivePack(t *testing.T) {
hook := testhelper.SetupLogger()
- err = cmd.Execute()
+ err = cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "ReceivePack: "+userId+" "+repo, output.String())
diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go
index 7271264..4d5c686 100644
--- a/internal/command/receivepack/receivepack.go
+++ b/internal/command/receivepack/receivepack.go
@@ -1,6 +1,8 @@
package receivepack
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -15,14 +17,14 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
}
repo := args[1]
- response, err := c.verifyAccess(repo)
+ response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
}
@@ -33,14 +35,14 @@ func (c *Command) Execute() error {
ReadWriter: c.ReadWriter,
EOFSent: true,
}
- return customAction.Execute(response)
+ return customAction.Execute(ctx, response)
}
return c.performGitalyCall(response)
}
-func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(c.Args.CommandType, repo)
+ return cmd.Verify(ctx, c.Args.CommandType, repo)
}
diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go
index a4632b4..44cb680 100644
--- a/internal/command/receivepack/receivepack_test.go
+++ b/internal/command/receivepack/receivepack_test.go
@@ -2,6 +2,7 @@ package receivepack
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -18,7 +19,7 @@ func TestForbiddenAccess(t *testing.T) {
cmd, _, cleanup := setup(t, "disallowed", requests)
defer cleanup()
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
@@ -26,7 +27,7 @@ func TestCustomReceivePack(t *testing.T) {
cmd, output, cleanup := setup(t, "1", requesthandlers.BuildAllowedWithCustomActionsHandlers(t))
defer cleanup()
- require.NoError(t, cmd.Execute())
+ require.NoError(t, cmd.Execute(context.Background()))
require.Equal(t, "customoutput", output.String())
}
diff --git a/internal/command/shared/accessverifier/accessverifier.go b/internal/command/shared/accessverifier/accessverifier.go
index 5d2d709..9fcdde4 100644
--- a/internal/command/shared/accessverifier/accessverifier.go
+++ b/internal/command/shared/accessverifier/accessverifier.go
@@ -1,6 +1,7 @@
package accessverifier
import (
+ "context"
"errors"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -18,13 +19,13 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Verify(action commandargs.CommandType, repo string) (*Response, error) {
+func (c *Command) Verify(ctx context.Context, action commandargs.CommandType, repo string) (*Response, error) {
client, err := accessverifier.NewClient(c.Config)
if err != nil {
return nil, err
}
- response, err := client.Verify(c.Args, action, repo)
+ response, err := client.Verify(ctx, c.Args, action, repo)
if err != nil {
return nil, err
}
diff --git a/internal/command/shared/accessverifier/accessverifier_test.go b/internal/command/shared/accessverifier/accessverifier_test.go
index 998e622..8ad87b8 100644
--- a/internal/command/shared/accessverifier/accessverifier_test.go
+++ b/internal/command/shared/accessverifier/accessverifier_test.go
@@ -2,6 +2,7 @@ package accessverifier
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -65,7 +66,7 @@ func TestMissingUser(t *testing.T) {
defer cleanup()
cmd.Args = &commandargs.Shell{GitlabKeyId: "2"}
- _, err := cmd.Verify(action, repo)
+ _, err := cmd.Verify(context.Background(), action, repo)
require.Equal(t, "missing user", err.Error())
}
@@ -75,7 +76,7 @@ func TestConsoleMessages(t *testing.T) {
defer cleanup()
cmd.Args = &commandargs.Shell{GitlabKeyId: "1"}
- cmd.Verify(action, repo)
+ cmd.Verify(context.Background(), action, repo)
require.Equal(t, "remote: \nremote: console\nremote: message\nremote: \n", errBuf.String())
require.Empty(t, outBuf.String())
diff --git a/internal/command/shared/customaction/customaction.go b/internal/command/shared/customaction/customaction.go
index 2ba1091..0675d36 100644
--- a/internal/command/shared/customaction/customaction.go
+++ b/internal/command/shared/customaction/customaction.go
@@ -2,6 +2,7 @@ package customaction
import (
"bytes"
+ "context"
"errors"
"gitlab.com/gitlab-org/gitlab-shell/client"
@@ -34,7 +35,7 @@ type Command struct {
EOFSent bool
}
-func (c *Command) Execute(response *accessverifier.Response) error {
+func (c *Command) Execute(ctx context.Context, response *accessverifier.Response) error {
data := response.Payload.Data
apiEndpoints := data.ApiEndpoints
@@ -42,10 +43,10 @@ func (c *Command) Execute(response *accessverifier.Response) error {
return errors.New("Custom action error: Empty API endpoints")
}
- return c.processApiEndpoints(response)
+ return c.processApiEndpoints(ctx, response)
}
-func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
+func (c *Command) processApiEndpoints(ctx context.Context, response *accessverifier.Response) error {
client, err := gitlabnet.GetClient(c.Config)
if err != nil {
@@ -64,7 +65,7 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
log.WithFields(fields).Info("Performing custom action")
- response, err := c.performRequest(client, endpoint, request)
+ response, err := c.performRequest(ctx, client, endpoint, request)
if err != nil {
return err
}
@@ -95,8 +96,8 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
return nil
}
-func (c *Command) performRequest(client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) {
- response, err := client.DoRequest(http.MethodPost, endpoint, request)
+func (c *Command) performRequest(ctx context.Context, client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) {
+ response, err := client.DoRequest(ctx, http.MethodPost, endpoint, request)
if err != nil {
return nil, err
}
diff --git a/internal/command/shared/customaction/customaction_test.go b/internal/command/shared/customaction/customaction_test.go
index 46c5f32..119da5b 100644
--- a/internal/command/shared/customaction/customaction_test.go
+++ b/internal/command/shared/customaction/customaction_test.go
@@ -2,6 +2,7 @@ package customaction
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -78,7 +79,7 @@ func TestExecuteEOFSent(t *testing.T) {
EOFSent: true,
}
- require.NoError(t, cmd.Execute(response))
+ require.NoError(t, cmd.Execute(context.Background(), response))
// expect printing of info message, "custom" string from the first request
// and "output" string from the second request
@@ -148,7 +149,7 @@ func TestExecuteNoEOFSent(t *testing.T) {
EOFSent: false,
}
- require.NoError(t, cmd.Execute(response))
+ require.NoError(t, cmd.Execute(context.Background(), response))
// expect printing of info message, "custom" string from the first request
// and "output" string from the second request
diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go
index 2f13cc5..f0a9e7b 100644
--- a/internal/command/twofactorrecover/twofactorrecover.go
+++ b/internal/command/twofactorrecover/twofactorrecover.go
@@ -1,6 +1,7 @@
package twofactorrecover
import (
+ "context"
"fmt"
"strings"
@@ -16,9 +17,9 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
if c.canContinue() {
- c.displayRecoveryCodes()
+ c.displayRecoveryCodes(ctx)
} else {
fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.")
}
@@ -38,8 +39,8 @@ func (c *Command) canContinue() bool {
return answer == "yes"
}
-func (c *Command) displayRecoveryCodes() {
- codes, err := c.getRecoveryCodes()
+func (c *Command) displayRecoveryCodes(ctx context.Context) {
+ codes, err := c.getRecoveryCodes(ctx)
if err == nil {
messageWithCodes :=
@@ -54,12 +55,12 @@ func (c *Command) displayRecoveryCodes() {
}
}
-func (c *Command) getRecoveryCodes() ([]string, error) {
+func (c *Command) getRecoveryCodes(ctx context.Context) ([]string, error) {
client, err := twofactorrecover.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetRecoveryCodes(c.Args)
+ return client.GetRecoveryCodes(ctx, c.Args)
}
diff --git a/internal/command/twofactorrecover/twofactorrecover_test.go b/internal/command/twofactorrecover/twofactorrecover_test.go
index d2f931b..ea6abd6 100644
--- a/internal/command/twofactorrecover/twofactorrecover_test.go
+++ b/internal/command/twofactorrecover/twofactorrecover_test.go
@@ -2,6 +2,7 @@ package twofactorrecover
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -127,7 +128,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, output.String())
diff --git a/internal/command/uploadarchive/gitalycall_test.go b/internal/command/uploadarchive/gitalycall_test.go
index eaeb2b7..f74093a 100644
--- a/internal/command/uploadarchive/gitalycall_test.go
+++ b/internal/command/uploadarchive/gitalycall_test.go
@@ -2,6 +2,7 @@ package uploadarchive
import (
"bytes"
+ "context"
"testing"
"github.com/sirupsen/logrus"
@@ -38,7 +39,7 @@ func TestUploadPack(t *testing.T) {
hook := testhelper.SetupLogger()
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "UploadArchive: "+repo, output.String())
diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go
index 9d4fbe0..178b42b 100644
--- a/internal/command/uploadarchive/uploadarchive.go
+++ b/internal/command/uploadarchive/uploadarchive.go
@@ -1,6 +1,8 @@
package uploadarchive
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -14,14 +16,14 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
}
repo := args[1]
- response, err := c.verifyAccess(repo)
+ response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
}
@@ -29,8 +31,8 @@ func (c *Command) Execute() error {
return c.performGitalyCall(response)
}
-func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(c.Args.CommandType, repo)
+ return cmd.Verify(ctx, c.Args.CommandType, repo)
}
diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go
index 7b03009..5426569 100644
--- a/internal/command/uploadarchive/uploadarchive_test.go
+++ b/internal/command/uploadarchive/uploadarchive_test.go
@@ -2,6 +2,7 @@ package uploadarchive
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
diff --git a/internal/command/uploadpack/gitalycall_test.go b/internal/command/uploadpack/gitalycall_test.go
index d6762a2..22189b8 100644
--- a/internal/command/uploadpack/gitalycall_test.go
+++ b/internal/command/uploadpack/gitalycall_test.go
@@ -2,6 +2,7 @@ package uploadpack
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/assert"
@@ -37,7 +38,7 @@ func TestUploadPack(t *testing.T) {
hook := testhelper.SetupLogger()
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "UploadPack: "+repo, output.String())
diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go
index 56814d7..fca3823 100644
--- a/internal/command/uploadpack/uploadpack.go
+++ b/internal/command/uploadpack/uploadpack.go
@@ -1,6 +1,8 @@
package uploadpack
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -15,14 +17,14 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
}
repo := args[1]
- response, err := c.verifyAccess(repo)
+ response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
}
@@ -33,14 +35,14 @@ func (c *Command) Execute() error {
ReadWriter: c.ReadWriter,
EOFSent: false,
}
- return customAction.Execute(response)
+ return customAction.Execute(ctx, response)
}
return c.performGitalyCall(response)
}
-func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(c.Args.CommandType, repo)
+ return cmd.Verify(ctx, c.Args.CommandType, repo)
}
diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go
index 7ea8e5d..20edb57 100644
--- a/internal/command/uploadpack/uploadpack_test.go
+++ b/internal/command/uploadpack/uploadpack_test.go
@@ -2,6 +2,7 @@ package uploadpack
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
diff --git a/internal/config/config.go b/internal/config/config.go
index e7abd59..79c2a36 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -37,7 +37,7 @@ type Config struct {
Secret string `yaml:"secret"`
SslCertDir string `yaml:"ssl_cert_dir"`
HttpSettings HttpSettingsConfig `yaml:"http_settings"`
- HttpClient *client.HttpClient
+ HttpClient *client.HttpClient `-`
}
func (c *Config) GetHttpClient() *client.HttpClient {
diff --git a/internal/gitlabnet/accessverifier/client.go b/internal/gitlabnet/accessverifier/client.go
index 00b9d76..7e120e0 100644
--- a/internal/gitlabnet/accessverifier/client.go
+++ b/internal/gitlabnet/accessverifier/client.go
@@ -1,6 +1,7 @@
package accessverifier
import (
+ "context"
"fmt"
"net/http"
@@ -77,7 +78,7 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{client: client}, nil
}
-func (c *Client) Verify(args *commandargs.Shell, action commandargs.CommandType, repo string) (*Response, error) {
+func (c *Client) Verify(ctx context.Context, args *commandargs.Shell, action commandargs.CommandType, repo string) (*Response, error) {
request := &Request{Action: action, Repo: repo, Protocol: protocol, Changes: anyChanges}
if args.GitlabUsername != "" {
@@ -88,7 +89,7 @@ func (c *Client) Verify(args *commandargs.Shell, action commandargs.CommandType,
request.CheckIp = sshenv.LocalAddr()
- response, err := c.client.Post("/allowed", request)
+ response, err := c.client.Post(ctx, "/allowed", request)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/accessverifier/client_test.go b/internal/gitlabnet/accessverifier/client_test.go
index 7ddbb5e..3681968 100644
--- a/internal/gitlabnet/accessverifier/client_test.go
+++ b/internal/gitlabnet/accessverifier/client_test.go
@@ -1,6 +1,7 @@
package accessverifier
import (
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -73,7 +74,7 @@ func TestSuccessfulResponses(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- result, err := client.Verify(tc.args, receivePackAction, repo)
+ result, err := client.Verify(context.Background(), tc.args, receivePackAction, repo)
require.NoError(t, err)
response := buildExpectedResponse(tc.who)
@@ -87,7 +88,7 @@ func TestGeoPushGetCustomAction(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabUsername: "custom"}
- result, err := client.Verify(args, receivePackAction, repo)
+ result, err := client.Verify(context.Background(), args, receivePackAction, repo)
require.NoError(t, err)
response := buildExpectedResponse("user-1")
@@ -110,7 +111,7 @@ func TestGeoPullGetCustomAction(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabUsername: "custom"}
- result, err := client.Verify(args, uploadPackAction, repo)
+ result, err := client.Verify(context.Background(), args, uploadPackAction, repo)
require.NoError(t, err)
response := buildExpectedResponse("user-1")
@@ -157,7 +158,7 @@ func TestErrorResponses(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: tc.fakeId}
- resp, err := client.Verify(args, receivePackAction, repo)
+ resp, err := client.Verify(context.Background(), args, receivePackAction, repo)
require.EqualError(t, err, tc.expectedError)
require.Nil(t, resp)
diff --git a/internal/gitlabnet/authorizedkeys/client.go b/internal/gitlabnet/authorizedkeys/client.go
index e4fec28..0a00034 100644
--- a/internal/gitlabnet/authorizedkeys/client.go
+++ b/internal/gitlabnet/authorizedkeys/client.go
@@ -1,6 +1,7 @@
package authorizedkeys
import (
+ "context"
"fmt"
"net/url"
@@ -32,13 +33,13 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) GetByKey(key string) (*Response, error) {
+func (c *Client) GetByKey(ctx context.Context, key string) (*Response, error) {
path, err := pathWithKey(key)
if err != nil {
return nil, err
}
- response, err := c.client.Get(path)
+ response, err := c.client.Get(ctx, path)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/authorizedkeys/client_test.go b/internal/gitlabnet/authorizedkeys/client_test.go
index c9c76a1..e72840c 100644
--- a/internal/gitlabnet/authorizedkeys/client_test.go
+++ b/internal/gitlabnet/authorizedkeys/client_test.go
@@ -1,6 +1,7 @@
package authorizedkeys
import (
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -48,7 +49,7 @@ func TestGetByKey(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
- result, err := client.GetByKey("key")
+ result, err := client.GetByKey(context.Background(), "key")
require.NoError(t, err)
require.Equal(t, &Response{Id: 1, Key: "public-key"}, result)
}
@@ -86,7 +87,7 @@ func TestGetByKeyErrorResponses(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- resp, err := client.GetByKey(tc.key)
+ resp, err := client.GetByKey(context.Background(), tc.key)
require.EqualError(t, err, tc.expectedError)
require.Nil(t, resp)
diff --git a/internal/gitlabnet/discover/client.go b/internal/gitlabnet/discover/client.go
index d1e1906..cc7f516 100644
--- a/internal/gitlabnet/discover/client.go
+++ b/internal/gitlabnet/discover/client.go
@@ -1,6 +1,7 @@
package discover
import (
+ "context"
"fmt"
"net/http"
"net/url"
@@ -31,7 +32,7 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) GetByCommandArgs(args *commandargs.Shell) (*Response, error) {
+func (c *Client) GetByCommandArgs(ctx context.Context, args *commandargs.Shell) (*Response, error) {
params := url.Values{}
if args.GitlabUsername != "" {
params.Add("username", args.GitlabUsername)
@@ -43,13 +44,13 @@ func (c *Client) GetByCommandArgs(args *commandargs.Shell) (*Response, error) {
return nil, fmt.Errorf("who='' is invalid")
}
- return c.getResponse(params)
+ return c.getResponse(ctx, params)
}
-func (c *Client) getResponse(params url.Values) (*Response, error) {
+func (c *Client) getResponse(ctx context.Context, params url.Values) (*Response, error) {
path := "/discover?" + params.Encode()
- response, err := c.client.Get(path)
+ response, err := c.client.Get(ctx, path)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/discover/client_test.go b/internal/gitlabnet/discover/client_test.go
index 96b3162..cb46dd7 100644
--- a/internal/gitlabnet/discover/client_test.go
+++ b/internal/gitlabnet/discover/client_test.go
@@ -1,6 +1,7 @@
package discover
import (
+ "context"
"encoding/json"
"fmt"
"net/http"
@@ -62,7 +63,7 @@ func TestGetByKeyId(t *testing.T) {
params := url.Values{}
params.Add("key_id", "1")
- result, err := client.getResponse(params)
+ result, err := client.getResponse(context.Background(), params)
assert.NoError(t, err)
assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result)
}
@@ -73,7 +74,7 @@ func TestGetByUsername(t *testing.T) {
params := url.Values{}
params.Add("username", "jane-doe")
- result, err := client.getResponse(params)
+ result, err := client.getResponse(context.Background(), params)
assert.NoError(t, err)
assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result)
}
@@ -84,7 +85,7 @@ func TestMissingUser(t *testing.T) {
params := url.Values{}
params.Add("username", "missing")
- result, err := client.getResponse(params)
+ result, err := client.getResponse(context.Background(), params)
assert.NoError(t, err)
assert.True(t, result.IsAnonymous())
}
@@ -119,7 +120,7 @@ func TestErrorResponses(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
params := url.Values{}
params.Add("username", tc.fakeUsername)
- resp, err := client.getResponse(params)
+ resp, err := client.getResponse(context.Background(), params)
assert.EqualError(t, err, tc.expectedError)
assert.Nil(t, resp)
diff --git a/internal/gitlabnet/healthcheck/client.go b/internal/gitlabnet/healthcheck/client.go
index 09b45af..f148504 100644
--- a/internal/gitlabnet/healthcheck/client.go
+++ b/internal/gitlabnet/healthcheck/client.go
@@ -1,6 +1,7 @@
package healthcheck
import (
+ "context"
"fmt"
"net/http"
@@ -34,8 +35,8 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) Check() (*Response, error) {
- resp, err := c.client.Get(checkPath)
+func (c *Client) Check(ctx context.Context) (*Response, error) {
+ resp, err := c.client.Get(ctx, checkPath)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/healthcheck/client_test.go b/internal/gitlabnet/healthcheck/client_test.go
index c66ddbd..81ae209 100644
--- a/internal/gitlabnet/healthcheck/client_test.go
+++ b/internal/gitlabnet/healthcheck/client_test.go
@@ -1,6 +1,7 @@
package healthcheck
import (
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -33,7 +34,7 @@ func TestCheck(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
- result, err := client.Check()
+ result, err := client.Check(context.Background())
require.NoError(t, err)
require.Equal(t, testResponse, result)
}
diff --git a/internal/gitlabnet/lfsauthenticate/client.go b/internal/gitlabnet/lfsauthenticate/client.go
index fffc225..834cbe1 100644
--- a/internal/gitlabnet/lfsauthenticate/client.go
+++ b/internal/gitlabnet/lfsauthenticate/client.go
@@ -1,6 +1,7 @@
package lfsauthenticate
import (
+ "context"
"fmt"
"net/http"
"strings"
@@ -40,7 +41,7 @@ func NewClient(config *config.Config, args *commandargs.Shell) (*Client, error)
return &Client{config: config, client: client, args: args}, nil
}
-func (c *Client) Authenticate(operation, repo, userId string) (*Response, error) {
+func (c *Client) Authenticate(ctx context.Context, operation, repo, userId string) (*Response, error) {
request := &Request{Operation: operation, Repo: repo}
if c.args.GitlabKeyId != "" {
request.KeyId = c.args.GitlabKeyId
@@ -48,7 +49,7 @@ func (c *Client) Authenticate(operation, repo, userId string) (*Response, error)
request.UserId = strings.TrimPrefix(userId, "user-")
}
- response, err := c.client.Post("/lfs_authenticate", request)
+ response, err := c.client.Post(ctx, "/lfs_authenticate", request)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/lfsauthenticate/client_test.go b/internal/gitlabnet/lfsauthenticate/client_test.go
index 82e364b..2bd0451 100644
--- a/internal/gitlabnet/lfsauthenticate/client_test.go
+++ b/internal/gitlabnet/lfsauthenticate/client_test.go
@@ -1,6 +1,7 @@
package lfsauthenticate
import (
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -85,7 +86,7 @@ func TestFailedRequests(t *testing.T) {
operation := tc.args.SshArgs[2]
- _, err = client.Authenticate(operation, repo, "")
+ _, err = client.Authenticate(context.Background(), operation, repo, "")
require.Error(t, err)
require.Equal(t, tc.expectedOutput, err.Error())
@@ -119,7 +120,7 @@ func TestSuccessfulRequests(t *testing.T) {
client, err := NewClient(&config.Config{GitlabUrl: url}, args)
require.NoError(t, err)
- response, err := client.Authenticate(operation, repo, "")
+ response, err := client.Authenticate(context.Background(), operation, repo, "")
require.NoError(t, err)
expectedResponse := &Response{
diff --git a/internal/gitlabnet/personalaccesstoken/client.go b/internal/gitlabnet/personalaccesstoken/client.go
index 588bead..abbd395 100644
--- a/internal/gitlabnet/personalaccesstoken/client.go
+++ b/internal/gitlabnet/personalaccesstoken/client.go
@@ -1,6 +1,7 @@
package personalaccesstoken
import (
+ "context"
"errors"
"fmt"
"net/http"
@@ -42,13 +43,13 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) GetPersonalAccessToken(args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*Response, error) {
- requestBody, err := c.getRequestBody(args, name, scopes, expiresAt)
+func (c *Client) GetPersonalAccessToken(ctx context.Context, args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*Response, error) {
+ requestBody, err := c.getRequestBody(ctx, args, name, scopes, expiresAt)
if err != nil {
return nil, err
}
- response, err := c.client.Post("/personal_access_token", requestBody)
+ response, err := c.client.Post(ctx, "/personal_access_token", requestBody)
if err != nil {
return nil, err
}
@@ -70,7 +71,7 @@ func parse(hr *http.Response) (*Response, error) {
return response, nil
}
-func (c *Client) getRequestBody(args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*RequestBody, error) {
+func (c *Client) getRequestBody(ctx context.Context, args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*RequestBody, error) {
client, err := discover.NewClient(c.config)
if err != nil {
return nil, err
@@ -83,7 +84,7 @@ func (c *Client) getRequestBody(args *commandargs.Shell, name string, scopes *[]
return requestBody, nil
}
- userInfo, err := client.GetByCommandArgs(args)
+ userInfo, err := client.GetByCommandArgs(ctx, args)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/personalaccesstoken/client_test.go b/internal/gitlabnet/personalaccesstoken/client_test.go
index de45975..140a7b2 100644
--- a/internal/gitlabnet/personalaccesstoken/client_test.go
+++ b/internal/gitlabnet/personalaccesstoken/client_test.go
@@ -1,6 +1,7 @@
package personalaccesstoken
import (
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -90,7 +91,7 @@ func TestGetPersonalAccessTokenByKeyId(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: "0"}
result, err := client.GetPersonalAccessToken(
- args, "newtoken", &[]string{"read_api", "read_repository"}, "",
+ context.Background(), args, "newtoken", &[]string{"read_api", "read_repository"}, "",
)
assert.NoError(t, err)
response := &Response{
@@ -109,7 +110,7 @@ func TestGetRecoveryCodesByUsername(t *testing.T) {
args := &commandargs.Shell{GitlabUsername: "jane-doe"}
result, err := client.GetPersonalAccessToken(
- args, "newtoken", &[]string{"api"}, "",
+ context.Background(), args, "newtoken", &[]string{"api"}, "",
)
assert.NoError(t, err)
response := &Response{true, "YXuxvUgCEmeePY3G1YAa", []string{"api"}, "", ""}
@@ -122,7 +123,7 @@ func TestMissingUser(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: "1"}
_, err := client.GetPersonalAccessToken(
- args, "newtoken", &[]string{"api"}, "",
+ context.Background(), args, "newtoken", &[]string{"api"}, "",
)
assert.Equal(t, "missing user", err.Error())
}
@@ -157,7 +158,7 @@ func TestErrorResponses(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: tc.fakeId}
resp, err := client.GetPersonalAccessToken(
- args, "newtoken", &[]string{"api"}, "",
+ context.Background(), args, "newtoken", &[]string{"api"}, "",
)
assert.EqualError(t, err, tc.expectedError)
diff --git a/internal/gitlabnet/twofactorrecover/client.go b/internal/gitlabnet/twofactorrecover/client.go
index d22daca..456f892 100644
--- a/internal/gitlabnet/twofactorrecover/client.go
+++ b/internal/gitlabnet/twofactorrecover/client.go
@@ -1,6 +1,7 @@
package twofactorrecover
import (
+ "context"
"errors"
"fmt"
"net/http"
@@ -37,14 +38,14 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) GetRecoveryCodes(args *commandargs.Shell) ([]string, error) {
- requestBody, err := c.getRequestBody(args)
+func (c *Client) GetRecoveryCodes(ctx context.Context, args *commandargs.Shell) ([]string, error) {
+ requestBody, err := c.getRequestBody(ctx, args)
if err != nil {
return nil, err
}
- response, err := c.client.Post("/two_factor_recovery_codes", requestBody)
+ response, err := c.client.Post(ctx, "/two_factor_recovery_codes", requestBody)
if err != nil {
return nil, err
}
@@ -66,7 +67,7 @@ func parse(hr *http.Response) ([]string, error) {
return response.RecoveryCodes, nil
}
-func (c *Client) getRequestBody(args *commandargs.Shell) (*RequestBody, error) {
+func (c *Client) getRequestBody(ctx context.Context, args *commandargs.Shell) (*RequestBody, error) {
client, err := discover.NewClient(c.config)
if err != nil {
@@ -77,7 +78,7 @@ func (c *Client) getRequestBody(args *commandargs.Shell) (*RequestBody, error) {
if args.GitlabKeyId != "" {
requestBody = &RequestBody{KeyId: args.GitlabKeyId}
} else {
- userInfo, err := client.GetByCommandArgs(args)
+ userInfo, err := client.GetByCommandArgs(ctx, args)
if err != nil {
return nil, err
diff --git a/internal/gitlabnet/twofactorrecover/client_test.go b/internal/gitlabnet/twofactorrecover/client_test.go
index 372afec..46291aa 100644
--- a/internal/gitlabnet/twofactorrecover/client_test.go
+++ b/internal/gitlabnet/twofactorrecover/client_test.go
@@ -1,6 +1,7 @@
package twofactorrecover
import (
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -85,7 +86,7 @@ func TestGetRecoveryCodesByKeyId(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabKeyId: "0"}
- result, err := client.GetRecoveryCodes(args)
+ result, err := client.GetRecoveryCodes(context.Background(), args)
assert.NoError(t, err)
assert.Equal(t, []string{"recovery 1", "codes 1"}, result)
}
@@ -95,7 +96,7 @@ func TestGetRecoveryCodesByUsername(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabUsername: "jane-doe"}
- result, err := client.GetRecoveryCodes(args)
+ result, err := client.GetRecoveryCodes(context.Background(), args)
assert.NoError(t, err)
assert.Equal(t, []string{"recovery 2", "codes 2"}, result)
}
@@ -105,7 +106,7 @@ func TestMissingUser(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabKeyId: "1"}
- _, err := client.GetRecoveryCodes(args)
+ _, err := client.GetRecoveryCodes(context.Background(), args)
assert.Equal(t, "missing user", err.Error())
}
@@ -138,7 +139,7 @@ func TestErrorResponses(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: tc.fakeId}
- resp, err := client.GetRecoveryCodes(args)
+ resp, err := client.GetRecoveryCodes(context.Background(), args)
assert.EqualError(t, err, tc.expectedError)
assert.Nil(t, resp)