diff options
author | Igor Drozdov <idrozdov@gitlab.com> | 2022-05-06 11:33:02 +0400 |
---|---|---|
committer | Igor Drozdov <idrozdov@gitlab.com> | 2022-05-06 11:47:52 +0400 |
commit | 8e0c2360a98c59b65be02de7629f007e362aa849 (patch) | |
tree | 0ac0acf71f211b4e2b39624206d7167cb891d679 | |
parent | c8ba21bd0c40e29cd40c7ed513c5d8a4e308a6dd (diff) | |
download | gitlab-shell-id-improve-metrics.tar.gz |
Refactor sshd.go and move the connection logic to connection.goid-improve-metrics
-rw-r--r-- | internal/sshd/connection.go | 72 | ||||
-rw-r--r-- | internal/sshd/connection_test.go | 20 | ||||
-rw-r--r-- | internal/sshd/session.go | 39 | ||||
-rw-r--r-- | internal/sshd/session_test.go | 37 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 54 |
5 files changed, 134 insertions, 88 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index 1312833..e878e9c 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -1,7 +1,9 @@ package sshd import ( + "net" "context" + "time" "golang.org/x/crypto/ssh" "golang.org/x/sync/semaphore" @@ -13,19 +15,71 @@ import ( type connection struct { concurrentSessions *semaphore.Weighted - remoteAddr string + nconn net.Conn + remoteAddr string + started time.Time } -type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) +type channelHandler func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error -func newConnection(maxSessions int64, remoteAddr string) *connection { +func newConnection(maxSessions int64, nconn net.Conn) *connection { return &connection{ concurrentSessions: semaphore.NewWeighted(maxSessions), - remoteAddr: remoteAddr, + nconn: nconn, + remoteAddr: nconn.RemoteAddr().String(), + started: time.Now(), } } -func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) { +func (c *connection) handle(ctx context.Context, cfg *ssh.ServerConfig, handler channelHandler) { + ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) + + // Prevent a panic in a single connection from taking out the whole server + defer func() { + if err := recover(); err != nil { + ctxlog.Warn("panic handling session") + } + + metrics.SliSshdSessionsErrorsTotal.Inc() + }() + + ctxlog.Info("server: handleConn: start") + + metrics.SshdConnectionsInFlight.Inc() + defer func() { + metrics.SshdConnectionsInFlight.Dec() + metrics.SshdSessionDuration.Observe(time.Since(c.started).Seconds()) + }() + + // Initialize the connection with server + sconn, chans, reqs, err := ssh.NewServerConn(c.nconn, cfg) + + // Track the time required to establish a session + establishSessionDuration := time.Since(c.started).Seconds() + metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration) + + // Most of the times a connection failes due to the client's misconfiguration or when + // a client cancels a request, so we shouldn't treat them as an error + // Warnings will helps us to track the errors whether they happend on the server side + if err != nil { + ctxlog.WithError(err).WithFields(log.Fields{ + "establish_session_duration_s": establishSessionDuration, + }).Warn("conn: init: failed to initialize SSH connection") + + return + } + go ssh.DiscardRequests(reqs) + + // Handle incoming requests + c.handleRequests(ctx, sconn, chans, handler) + + ctxlog.WithFields(log.Fields{ + "duration_s": time.Since(c.started).Seconds(), + "establish_session_duration_s": establishSessionDuration, + }).Info("server: handleConn: done") +} + +func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) { ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) for newChannel := range chans { @@ -55,10 +109,16 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha defer func() { if err := recover(); err != nil { ctxlog.WithField("recovered_error", err).Warn("panic handling session") + + metrics.SliSshdSessionsErrorsTotal.Inc() } }() - handler(ctx, channel, requests) + err := handler(ctx, sconn, channel, requests) + if err != nil { + metrics.SliSshdSessionsErrorsTotal.Inc() + } + ctxlog.Info("connection: handle: done") }() } diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go index d6bd3c0..9b5e158 100644 --- a/internal/sshd/connection_test.go +++ b/internal/sshd/connection_test.go @@ -5,6 +5,8 @@ import ( "errors" "testing" + "golang.org/x/sync/semaphore" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" ) @@ -48,7 +50,9 @@ func (f *fakeNewChannel) ExtraData() []byte { } func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) { - conn := newConnection(sessionsNum, "127.0.0.1:50000") + conn := &connection{ + concurrentSessions: semaphore.NewWeighted(sessionsNum), + } chans := make(chan ssh.NewChannel, 1) chans <- newChannel @@ -62,10 +66,11 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) { numSessions := 0 require.NotPanics(t, func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { numSessions += 1 close(chans) panic("This is a panic") + return nil }) }) @@ -80,7 +85,7 @@ func TestUnknownChannelType(t *testing.T) { conn, chans := setup(1, newChannel) go func() { - conn.handle(context.Background(), chans, nil) + conn.handleRequests(context.Background(), nil, chans, nil) }() rejectionData := <-rejectCh @@ -100,8 +105,9 @@ func TestTooManySessions(t *testing.T) { defer cancel() go func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { <-ctx.Done() // Keep the accepted channel open until the end of the test + return nil }) }() @@ -114,9 +120,10 @@ func TestAcceptSessionSucceeds(t *testing.T) { conn, chans := setup(1, newChannel) channelHandled := false - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { channelHandled = true close(chans) + return nil }) require.True(t, channelHandled) @@ -132,8 +139,9 @@ func TestAcceptSessionFails(t *testing.T) { channelHandled := false go func() { - conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { channelHandled = true + return nil }) }() diff --git a/internal/sshd/session.go b/internal/sshd/session.go index beb529e..831beb8 100644 --- a/internal/sshd/session.go +++ b/internal/sshd/session.go @@ -22,7 +22,6 @@ type session struct { channel ssh.Channel gitlabKeyId string remoteAddr string - success bool // State managed by the session execCmd string @@ -42,11 +41,12 @@ type exitStatusReq struct { ExitStatus uint32 } -func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { +func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) error { ctxlog := log.ContextLogger(ctx) ctxlog.Debug("session: handle: entering request loop") + var err error for req := range requests { sessionLog := ctxlog.WithFields(log.Fields{ "bytesize": len(req.Payload), @@ -58,12 +58,14 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { var shouldContinue bool switch req.Type { case "env": - shouldContinue = s.handleEnv(ctx, req) + shouldContinue, err = s.handleEnv(ctx, req) case "exec": - shouldContinue = s.handleExec(ctx, req) + shouldContinue, err = s.handleExec(ctx, req) case "shell": shouldContinue = false - s.exit(ctx, s.handleShell(ctx, req)) + var status uint32 + status, err = s.handleShell(ctx, req) + s.exit(ctx, status) default: // Ignore unknown requests but don't terminate the session shouldContinue = true @@ -84,15 +86,17 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { } ctxlog.Debug("session: handle: exiting request loop") + + return err } -func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool { +func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) { var accepted bool var envRequest envRequest if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil { log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request") - return false + return false, err } switch envRequest.Name { @@ -113,23 +117,24 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool { ctx, log.Fields{"accepted": accepted, "env_request": envRequest}, ).Debug("session: handleEnv: processed") - return true + return true, nil } -func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool { +func (s *session) handleExec(ctx context.Context, req *ssh.Request) (bool, error) { var execRequest execRequest if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil { - return false + return false, err } s.execCmd = execRequest.Command - s.exit(ctx, s.handleShell(ctx, req)) + status, err := s.handleShell(ctx, req) + s.exit(ctx, status) - return false + return false, err } -func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { +func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, error) { ctxlog := log.ContextLogger(ctx) if req.WantReply { @@ -157,7 +162,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { s.toStderr(ctx, "Failed to parse command: %v\n", err.Error()) } s.toStderr(ctx, "Unknown command: %v\n", s.execCmd) - return 128 + return 128, err } cmdName := reflect.TypeOf(cmd).String() @@ -165,12 +170,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { if err := cmd.Execute(ctx); err != nil { s.toStderr(ctx, "remote: ERROR: %v\n", err.Error()) - return 1 + return 1, err } ctxlog.Info("session: handleShell: command executed successfully") - return 0 + return 0, nil } func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) { @@ -183,8 +188,6 @@ func (s *session) exit(ctx context.Context, status uint32) { log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting") req := exitStatusReq{ExitStatus: status} - s.success = status == 0 - s.channel.CloseWrite() s.channel.SendRequest("exit-status", false, ssh.Marshal(req)) } diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go index d0cc8d4..7a01eb2 100644 --- a/internal/sshd/session_test.go +++ b/internal/sshd/session_test.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "testing" + "errors" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" @@ -60,22 +61,26 @@ func TestHandleEnv(t *testing.T) { testCases := []struct { desc string payload []byte + expectedErr error expectedProtocolVersion string expectedResult bool }{ { desc: "invalid payload", payload: []byte("invalid"), + expectedErr: errors.New("ssh: unmarshal error for field Name of type envRequest"), expectedProtocolVersion: "1", expectedResult: false, }, { desc: "valid payload", payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}), + expectedErr: nil, expectedProtocolVersion: "2", expectedResult: true, }, { desc: "valid payload with forbidden env var", payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}), + expectedErr: nil, expectedProtocolVersion: "1", expectedResult: true, }, @@ -86,8 +91,11 @@ func TestHandleEnv(t *testing.T) { s := &session{gitProtocolVersion: "1"} r := &ssh.Request{Payload: tc.payload} - require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult) - require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion) + shouldContinue, err := s.handleEnv(context.Background(), r) + + require.Equal(t, tc.expectedErr, err) + require.Equal(t, tc.expectedResult, shouldContinue) + require.Equal(t, tc.expectedProtocolVersion, s.gitProtocolVersion) }) } } @@ -96,23 +104,24 @@ func TestHandleExec(t *testing.T) { testCases := []struct { desc string payload []byte + expectedErr error expectedExecCmd string sentRequestName string sentRequestPayload []byte - success bool }{ { desc: "invalid payload", payload: []byte("invalid"), + expectedErr: errors.New("ssh: unmarshal error for field Command of type execRequest"), expectedExecCmd: "", sentRequestName: "", }, { desc: "valid payload", payload: ssh.Marshal(execRequest{Command: "discover"}), + expectedErr: nil, expectedExecCmd: "discover", sentRequestName: "exit-status", sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}), - success: true, }, } @@ -129,10 +138,12 @@ func TestHandleExec(t *testing.T) { } r := &ssh.Request{Payload: tc.payload} - require.Equal(t, false, s.handleExec(context.Background(), r)) + shouldContinue, err := s.handleExec(context.Background(), r) + + require.Equal(t, tc.expectedErr, err) + require.Equal(t, false, shouldContinue) require.Equal(t, tc.sentRequestName, f.sentRequestName) require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload) - require.Equal(t, tc.success, s.success) }) } } @@ -143,32 +154,36 @@ func TestHandleShell(t *testing.T) { cmd string errMsg string gitlabKeyId string + expectedErrString string expectedExitCode uint32 - success bool }{ { desc: "fails to parse command", cmd: `\`, errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n", gitlabKeyId: "root", + expectedErrString: "Invalid SSH command: invalid command line string", expectedExitCode: 128, }, { desc: "specified command is unknown", cmd: "unknown-command", errMsg: "Unknown command: unknown-command\n", gitlabKeyId: "root", + expectedErrString: "Disallowed command", expectedExitCode: 128, }, { desc: "fails to parse command", cmd: "discover", gitlabKeyId: "", errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n", + expectedErrString: "Failed to get username: who='' is invalid", expectedExitCode: 1, }, { desc: "fails to parse command", cmd: "discover", errMsg: "", gitlabKeyId: "root", + expectedErrString: "", expectedExitCode: 0, }, } @@ -186,7 +201,13 @@ func TestHandleShell(t *testing.T) { } r := &ssh.Request{} - require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r)) + exitCode, err := s.handleShell(context.Background(), r) + + if tc.expectedErrString != "" { + require.Equal(t, tc.expectedErrString, err.Error()) + } + + require.Equal(t, tc.expectedExitCode, exitCode) require.Equal(t, tc.errMsg, out.String()) }) } diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 49b8ab9..ebd7022 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -12,7 +12,6 @@ import ( "golang.org/x/crypto/ssh" "gitlab.com/gitlab-org/gitlab-shell/internal/config" - "gitlab.com/gitlab-org/gitlab-shell/internal/metrics" "gitlab.com/gitlab-org/labkit/correlation" "gitlab.com/gitlab-org/labkit/log" @@ -146,68 +145,23 @@ func (s *Server) getStatus() status { } func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { - success := false - - metrics.SshdConnectionsInFlight.Inc() - started := time.Now() - defer func() { - metrics.SshdConnectionsInFlight.Dec() - metrics.SshdSessionDuration.Observe(time.Since(started).Seconds()) - - metrics.SliSshdSessionsTotal.Inc() - if !success { - metrics.SliSshdSessionsErrorsTotal.Inc() - } - }() - - remoteAddr := nconn.RemoteAddr().String() - defer s.wg.Done() defer nconn.Close() ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) defer cancel() - ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr}) - - // Prevent a panic in a single connection from taking out the whole server - defer func() { - if err := recover(); err != nil { - ctxlog.Warn("panic handling session") - } - }() - - ctxlog.Info("server: handleConn: start") - - sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx)) - if err != nil { - ctxlog.WithError(err).Error("server: handleConn: failed to initialize SSH connection") - return - } - go ssh.DiscardRequests(reqs) - - var establishSessionDuration float64 - conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, remoteAddr) - conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) { - establishSessionDuration = time.Since(started).Seconds() - metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration) - + conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, nconn) + conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) error { session := &session{ cfg: s.Config, channel: channel, gitlabKeyId: sconn.Permissions.Extensions["key-id"], - remoteAddr: remoteAddr, + remoteAddr: nconn.RemoteAddr().String(), } - session.handle(ctx, requests) - - success = session.success + return session.handle(ctx, requests) }) - - ctxlog.WithFields(log.Fields{ - "duration_s": time.Since(started).Seconds(), - "establish_session_duration_s": establishSessionDuration, - }).Info("server: handleConn: done") } func unconditionalRequirePolicy(_ net.Addr) (proxyproto.Policy, error) { |