diff options
author | Igor Drozdov <idrozdov@gitlab.com> | 2022-07-01 11:02:59 +0000 |
---|---|---|
committer | Igor Drozdov <idrozdov@gitlab.com> | 2022-07-01 11:02:59 +0000 |
commit | 0d7ef238cb8c05eabaec85e62bec70a40147d1df (patch) | |
tree | 9179705f9e8b6ee309d456323fbaedaa70141c7e | |
parent | 01f4e022c04b29b896eb383e6e6a33f96a6beeb1 (diff) | |
parent | 9b60ce49460876d0e599f2fec65f02856930dbcd (diff) | |
download | gitlab-shell-0d7ef238cb8c05eabaec85e62bec70a40147d1df.tar.gz |
Merge branch 'sshd-forwarded-for' into 'main'
Pass original IP from PROXY requests to internal API calls
See merge request gitlab-org/gitlab-shell!665
-rw-r--r-- | client/client_test.go | 22 | ||||
-rw-r--r-- | client/gitlabnet.go | 8 | ||||
-rw-r--r-- | internal/gitlabnet/accessverifier/client.go | 18 | ||||
-rw-r--r-- | internal/gitlabnet/client.go | 16 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 17 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 15 |
6 files changed, 76 insertions, 20 deletions
diff --git a/client/client_test.go b/client/client_test.go index 66ce2d8..06036b6 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -76,6 +76,7 @@ func TestClients(t *testing.T) { testErrorMessage(t, client) testAuthenticationHeader(t, client) testJWTAuthenticationHeader(t, client) + testXForwardedForHeader(t, client) }) } } @@ -221,6 +222,21 @@ func testJWTAuthenticationHeader(t *testing.T, client *GitlabNetClient) { }) } +func testXForwardedForHeader(t *testing.T, client *GitlabNetClient) { + t.Run("X-Forwarded-For Header inserted if original address in context", func(t *testing.T) { + ctx := context.WithValue(context.Background(), OriginalRemoteIPContextKey{}, "196.7.0.238") + response, err := client.Get(ctx, "/x_forwarded_for") + require.NoError(t, err) + require.NotNil(t, response) + + defer response.Body.Close() + + responseBody, err := io.ReadAll(response.Body) + require.NoError(t, err) + require.Equal(t, "196.7.0.238", string(responseBody)) + }) +} + func buildRequests(t *testing.T, relativeURLRoot string) []testserver.TestRequestHandler { requests := []testserver.TestRequestHandler{ { @@ -257,6 +273,12 @@ func buildRequests(t *testing.T, relativeURLRoot string) []testserver.TestReques }, }, { + Path: "/api/v4/internal/x_forwarded_for", + Handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, r.Header.Get("X-Forwarded-For")) + }, + }, + { Path: "/api/v4/internal/error", Handler: func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") diff --git a/client/gitlabnet.go b/client/gitlabnet.go index 31131d2..c34f148 100644 --- a/client/gitlabnet.go +++ b/client/gitlabnet.go @@ -41,6 +41,9 @@ type ApiError struct { Msg string } +// To use as the key in a Context to set an X-Forwarded-For header in a request +type OriginalRemoteIPContextKey struct{} + func (e *ApiError) Error() string { return e.Msg } @@ -150,6 +153,11 @@ func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, da } request.Header.Set(apiSecretHeaderName, tokenString) + originalRemoteIP, ok := ctx.Value(OriginalRemoteIPContextKey{}).(string) + if ok { + request.Header.Add("X-Forwarded-For", originalRemoteIP) + } + request.Header.Add("Content-Type", "application/json") request.Header.Add("User-Agent", c.userAgent) request.Close = true diff --git a/internal/gitlabnet/accessverifier/client.go b/internal/gitlabnet/accessverifier/client.go index c46a16f..adeccd6 100644 --- a/internal/gitlabnet/accessverifier/client.go +++ b/internal/gitlabnet/accessverifier/client.go @@ -3,7 +3,6 @@ package accessverifier import ( "context" "fmt" - "net" "net/http" pb "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb" @@ -86,7 +85,7 @@ func (c *Client) Verify(ctx context.Context, args *commandargs.Shell, action com request.KeyId = args.GitlabKeyId } - request.CheckIp = parseIP(args.Env.RemoteAddr) + request.CheckIp = gitlabnet.ParseIP(args.Env.RemoteAddr) response, err := c.client.Post(ctx, "/allowed", request) if err != nil { @@ -117,18 +116,3 @@ func parse(hr *http.Response, args *commandargs.Shell) (*Response, error) { func (r *Response) IsCustomAction() bool { return r.StatusCode == http.StatusMultipleChoices } - -func parseIP(remoteAddr string) string { - // The remoteAddr field can be filled by: - // 1. An IP address via the SSH_CONNECTION environment variable - // 2. A host:port combination via the PROXY protocol - ip, _, err := net.SplitHostPort(remoteAddr) - - // If we don't have a port or can't parse this address for some reason, - // just return the original string. - if err != nil { - return remoteAddr - } - - return ip -} diff --git a/internal/gitlabnet/client.go b/internal/gitlabnet/client.go index 39c3320..9bcf6db 100644 --- a/internal/gitlabnet/client.go +++ b/internal/gitlabnet/client.go @@ -3,6 +3,7 @@ package gitlabnet import ( "encoding/json" "fmt" + "net" "net/http" "gitlab.com/gitlab-org/gitlab-shell/client" @@ -34,3 +35,18 @@ func ParseJSON(hr *http.Response, response interface{}) error { return nil } + +func ParseIP(remoteAddr string) string { + // The remoteAddr field can be filled by: + // 1. An IP address via the SSH_CONNECTION environment variable + // 2. A host:port combination via the PROXY protocol + ip, _, err := net.SplitHostPort(remoteAddr) + + // If we don't have a port or can't parse this address for some reason, + // just return the original string. + if err != nil { + return remoteAddr + } + + return ip +} diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 43c4d7b..d275193 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -12,7 +12,9 @@ import ( "github.com/pires/go-proxyproto" "golang.org/x/crypto/ssh" + "gitlab.com/gitlab-org/gitlab-shell/client" "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" "gitlab.com/gitlab-org/gitlab-shell/internal/metrics" "gitlab.com/gitlab-org/labkit/correlation" @@ -145,13 +147,26 @@ func (s *Server) getStatus() status { return s.status } +func contextWithValues(parent context.Context, nconn net.Conn) context.Context { + ctx := correlation.ContextWithCorrelation(parent, correlation.SafeRandomID()) + + // If we're dealing with a PROXY connection, register the original requester's IP + mconn, ok := nconn.(*proxyproto.Conn) + if ok { + ip := gitlabnet.ParseIP(mconn.Raw().RemoteAddr().String()) + ctx = context.WithValue(ctx, client.OriginalRemoteIPContextKey{}, ip) + } + + return ctx +} + func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { defer s.wg.Done() metrics.SshdConnectionsInFlight.Inc() defer metrics.SshdConnectionsInFlight.Dec() - ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) + ctx, cancel := context.WithCancel(contextWithValues(ctx, nconn)) defer cancel() go func() { <-ctx.Done() diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index 8f52125..e3fbeeb 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -27,6 +27,7 @@ const ( var ( correlationId = "" + xForwardedFor = "" ) func TestListenAndServe(t *testing.T) { @@ -63,6 +64,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin }, DestinationAddr: target, } + xForwardedFor = "127.0.0.1" + defer func() { + xForwardedFor = "" // Cleanup for other test cases + }() testCases := []struct { desc string @@ -132,9 +137,9 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin require.NoError(t, err) } - sshConn, _, _, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t)) + sshConn, sshChans, sshRequs, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t)) if sshConn != nil { - sshConn.Close() + defer sshConn.Close() } if tc.isRejected { @@ -142,6 +147,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin require.Regexp(t, "ssh: handshake failed", err.Error()) } else { require.NoError(t, err) + client := ssh.NewClient(sshConn, sshChans, sshRequs) + defer client.Close() + + holdSession(t, client) } }) } @@ -306,6 +315,7 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex correlationId = r.Header.Get("X-Request-Id") require.NotEmpty(t, correlationId) + require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For")) fmt.Fprint(w, `{"id": 1000, "key": "key"}`) }, @@ -313,6 +323,7 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex Path: "/api/v4/internal/discover", Handler: func(w http.ResponseWriter, r *http.Request) { require.Equal(t, correlationId, r.Header.Get("X-Request-Id")) + require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For")) fmt.Fprint(w, `{"id": 1000, "name": "Test User", "username": "test-user"}`) }, |