summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIgor Drozdov <idrozdov@gitlab.com>2022-05-06 11:33:02 +0400
committerIgor Drozdov <idrozdov@gitlab.com>2022-05-06 11:47:52 +0400
commit8e0c2360a98c59b65be02de7629f007e362aa849 (patch)
tree0ac0acf71f211b4e2b39624206d7167cb891d679
parentc8ba21bd0c40e29cd40c7ed513c5d8a4e308a6dd (diff)
downloadgitlab-shell-id-improve-metrics.tar.gz
Refactor sshd.go and move the connection logic to connection.goid-improve-metrics
-rw-r--r--internal/sshd/connection.go72
-rw-r--r--internal/sshd/connection_test.go20
-rw-r--r--internal/sshd/session.go39
-rw-r--r--internal/sshd/session_test.go37
-rw-r--r--internal/sshd/sshd.go54
5 files changed, 134 insertions, 88 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
index 1312833..e878e9c 100644
--- a/internal/sshd/connection.go
+++ b/internal/sshd/connection.go
@@ -1,7 +1,9 @@
package sshd
import (
+ "net"
"context"
+ "time"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/semaphore"
@@ -13,19 +15,71 @@ import (
type connection struct {
concurrentSessions *semaphore.Weighted
- remoteAddr string
+ nconn net.Conn
+ remoteAddr string
+ started time.Time
}
-type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request)
+type channelHandler func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error
-func newConnection(maxSessions int64, remoteAddr string) *connection {
+func newConnection(maxSessions int64, nconn net.Conn) *connection {
return &connection{
concurrentSessions: semaphore.NewWeighted(maxSessions),
- remoteAddr: remoteAddr,
+ nconn: nconn,
+ remoteAddr: nconn.RemoteAddr().String(),
+ started: time.Now(),
}
}
-func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
+func (c *connection) handle(ctx context.Context, cfg *ssh.ServerConfig, handler channelHandler) {
+ ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
+
+ // Prevent a panic in a single connection from taking out the whole server
+ defer func() {
+ if err := recover(); err != nil {
+ ctxlog.Warn("panic handling session")
+ }
+
+ metrics.SliSshdSessionsErrorsTotal.Inc()
+ }()
+
+ ctxlog.Info("server: handleConn: start")
+
+ metrics.SshdConnectionsInFlight.Inc()
+ defer func() {
+ metrics.SshdConnectionsInFlight.Dec()
+ metrics.SshdSessionDuration.Observe(time.Since(c.started).Seconds())
+ }()
+
+ // Initialize the connection with server
+ sconn, chans, reqs, err := ssh.NewServerConn(c.nconn, cfg)
+
+ // Track the time required to establish a session
+ establishSessionDuration := time.Since(c.started).Seconds()
+ metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
+
+ // Most of the times a connection failes due to the client's misconfiguration or when
+ // a client cancels a request, so we shouldn't treat them as an error
+ // Warnings will helps us to track the errors whether they happend on the server side
+ if err != nil {
+ ctxlog.WithError(err).WithFields(log.Fields{
+ "establish_session_duration_s": establishSessionDuration,
+ }).Warn("conn: init: failed to initialize SSH connection")
+
+ return
+ }
+ go ssh.DiscardRequests(reqs)
+
+ // Handle incoming requests
+ c.handleRequests(ctx, sconn, chans, handler)
+
+ ctxlog.WithFields(log.Fields{
+ "duration_s": time.Since(c.started).Seconds(),
+ "establish_session_duration_s": establishSessionDuration,
+ }).Info("server: handleConn: done")
+}
+
+func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for newChannel := range chans {
@@ -55,10 +109,16 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
defer func() {
if err := recover(); err != nil {
ctxlog.WithField("recovered_error", err).Warn("panic handling session")
+
+ metrics.SliSshdSessionsErrorsTotal.Inc()
}
}()
- handler(ctx, channel, requests)
+ err := handler(ctx, sconn, channel, requests)
+ if err != nil {
+ metrics.SliSshdSessionsErrorsTotal.Inc()
+ }
+
ctxlog.Info("connection: handle: done")
}()
}
diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go
index d6bd3c0..9b5e158 100644
--- a/internal/sshd/connection_test.go
+++ b/internal/sshd/connection_test.go
@@ -5,6 +5,8 @@ import (
"errors"
"testing"
+ "golang.org/x/sync/semaphore"
+
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)
@@ -48,7 +50,9 @@ func (f *fakeNewChannel) ExtraData() []byte {
}
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) {
- conn := newConnection(sessionsNum, "127.0.0.1:50000")
+ conn := &connection{
+ concurrentSessions: semaphore.NewWeighted(sessionsNum),
+ }
chans := make(chan ssh.NewChannel, 1)
chans <- newChannel
@@ -62,10 +66,11 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) {
numSessions := 0
require.NotPanics(t, func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
numSessions += 1
close(chans)
panic("This is a panic")
+ return nil
})
})
@@ -80,7 +85,7 @@ func TestUnknownChannelType(t *testing.T) {
conn, chans := setup(1, newChannel)
go func() {
- conn.handle(context.Background(), chans, nil)
+ conn.handleRequests(context.Background(), nil, chans, nil)
}()
rejectionData := <-rejectCh
@@ -100,8 +105,9 @@ func TestTooManySessions(t *testing.T) {
defer cancel()
go func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
<-ctx.Done() // Keep the accepted channel open until the end of the test
+ return nil
})
}()
@@ -114,9 +120,10 @@ func TestAcceptSessionSucceeds(t *testing.T) {
conn, chans := setup(1, newChannel)
channelHandled := false
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
close(chans)
+ return nil
})
require.True(t, channelHandled)
@@ -132,8 +139,9 @@ func TestAcceptSessionFails(t *testing.T) {
channelHandled := false
go func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
+ return nil
})
}()
diff --git a/internal/sshd/session.go b/internal/sshd/session.go
index beb529e..831beb8 100644
--- a/internal/sshd/session.go
+++ b/internal/sshd/session.go
@@ -22,7 +22,6 @@ type session struct {
channel ssh.Channel
gitlabKeyId string
remoteAddr string
- success bool
// State managed by the session
execCmd string
@@ -42,11 +41,12 @@ type exitStatusReq struct {
ExitStatus uint32
}
-func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
+func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) error {
ctxlog := log.ContextLogger(ctx)
ctxlog.Debug("session: handle: entering request loop")
+ var err error
for req := range requests {
sessionLog := ctxlog.WithFields(log.Fields{
"bytesize": len(req.Payload),
@@ -58,12 +58,14 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
var shouldContinue bool
switch req.Type {
case "env":
- shouldContinue = s.handleEnv(ctx, req)
+ shouldContinue, err = s.handleEnv(ctx, req)
case "exec":
- shouldContinue = s.handleExec(ctx, req)
+ shouldContinue, err = s.handleExec(ctx, req)
case "shell":
shouldContinue = false
- s.exit(ctx, s.handleShell(ctx, req))
+ var status uint32
+ status, err = s.handleShell(ctx, req)
+ s.exit(ctx, status)
default:
// Ignore unknown requests but don't terminate the session
shouldContinue = true
@@ -84,15 +86,17 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
}
ctxlog.Debug("session: handle: exiting request loop")
+
+ return err
}
-func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
+func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) {
var accepted bool
var envRequest envRequest
if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil {
log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request")
- return false
+ return false, err
}
switch envRequest.Name {
@@ -113,23 +117,24 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
ctx, log.Fields{"accepted": accepted, "env_request": envRequest},
).Debug("session: handleEnv: processed")
- return true
+ return true, nil
}
-func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool {
+func (s *session) handleExec(ctx context.Context, req *ssh.Request) (bool, error) {
var execRequest execRequest
if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil {
- return false
+ return false, err
}
s.execCmd = execRequest.Command
- s.exit(ctx, s.handleShell(ctx, req))
+ status, err := s.handleShell(ctx, req)
+ s.exit(ctx, status)
- return false
+ return false, err
}
-func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
+func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, error) {
ctxlog := log.ContextLogger(ctx)
if req.WantReply {
@@ -157,7 +162,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
s.toStderr(ctx, "Failed to parse command: %v\n", err.Error())
}
s.toStderr(ctx, "Unknown command: %v\n", s.execCmd)
- return 128
+ return 128, err
}
cmdName := reflect.TypeOf(cmd).String()
@@ -165,12 +170,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
if err := cmd.Execute(ctx); err != nil {
s.toStderr(ctx, "remote: ERROR: %v\n", err.Error())
- return 1
+ return 1, err
}
ctxlog.Info("session: handleShell: command executed successfully")
- return 0
+ return 0, nil
}
func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) {
@@ -183,8 +188,6 @@ func (s *session) exit(ctx context.Context, status uint32) {
log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting")
req := exitStatusReq{ExitStatus: status}
- s.success = status == 0
-
s.channel.CloseWrite()
s.channel.SendRequest("exit-status", false, ssh.Marshal(req))
}
diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go
index d0cc8d4..7a01eb2 100644
--- a/internal/sshd/session_test.go
+++ b/internal/sshd/session_test.go
@@ -6,6 +6,7 @@ import (
"io"
"net/http"
"testing"
+ "errors"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
@@ -60,22 +61,26 @@ func TestHandleEnv(t *testing.T) {
testCases := []struct {
desc string
payload []byte
+ expectedErr error
expectedProtocolVersion string
expectedResult bool
}{
{
desc: "invalid payload",
payload: []byte("invalid"),
+ expectedErr: errors.New("ssh: unmarshal error for field Name of type envRequest"),
expectedProtocolVersion: "1",
expectedResult: false,
}, {
desc: "valid payload",
payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}),
+ expectedErr: nil,
expectedProtocolVersion: "2",
expectedResult: true,
}, {
desc: "valid payload with forbidden env var",
payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}),
+ expectedErr: nil,
expectedProtocolVersion: "1",
expectedResult: true,
},
@@ -86,8 +91,11 @@ func TestHandleEnv(t *testing.T) {
s := &session{gitProtocolVersion: "1"}
r := &ssh.Request{Payload: tc.payload}
- require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult)
- require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion)
+ shouldContinue, err := s.handleEnv(context.Background(), r)
+
+ require.Equal(t, tc.expectedErr, err)
+ require.Equal(t, tc.expectedResult, shouldContinue)
+ require.Equal(t, tc.expectedProtocolVersion, s.gitProtocolVersion)
})
}
}
@@ -96,23 +104,24 @@ func TestHandleExec(t *testing.T) {
testCases := []struct {
desc string
payload []byte
+ expectedErr error
expectedExecCmd string
sentRequestName string
sentRequestPayload []byte
- success bool
}{
{
desc: "invalid payload",
payload: []byte("invalid"),
+ expectedErr: errors.New("ssh: unmarshal error for field Command of type execRequest"),
expectedExecCmd: "",
sentRequestName: "",
}, {
desc: "valid payload",
payload: ssh.Marshal(execRequest{Command: "discover"}),
+ expectedErr: nil,
expectedExecCmd: "discover",
sentRequestName: "exit-status",
sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}),
- success: true,
},
}
@@ -129,10 +138,12 @@ func TestHandleExec(t *testing.T) {
}
r := &ssh.Request{Payload: tc.payload}
- require.Equal(t, false, s.handleExec(context.Background(), r))
+ shouldContinue, err := s.handleExec(context.Background(), r)
+
+ require.Equal(t, tc.expectedErr, err)
+ require.Equal(t, false, shouldContinue)
require.Equal(t, tc.sentRequestName, f.sentRequestName)
require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload)
- require.Equal(t, tc.success, s.success)
})
}
}
@@ -143,32 +154,36 @@ func TestHandleShell(t *testing.T) {
cmd string
errMsg string
gitlabKeyId string
+ expectedErrString string
expectedExitCode uint32
- success bool
}{
{
desc: "fails to parse command",
cmd: `\`,
errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
gitlabKeyId: "root",
+ expectedErrString: "Invalid SSH command: invalid command line string",
expectedExitCode: 128,
}, {
desc: "specified command is unknown",
cmd: "unknown-command",
errMsg: "Unknown command: unknown-command\n",
gitlabKeyId: "root",
+ expectedErrString: "Disallowed command",
expectedExitCode: 128,
}, {
desc: "fails to parse command",
cmd: "discover",
gitlabKeyId: "",
errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
+ expectedErrString: "Failed to get username: who='' is invalid",
expectedExitCode: 1,
}, {
desc: "fails to parse command",
cmd: "discover",
errMsg: "",
gitlabKeyId: "root",
+ expectedErrString: "",
expectedExitCode: 0,
},
}
@@ -186,7 +201,13 @@ func TestHandleShell(t *testing.T) {
}
r := &ssh.Request{}
- require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r))
+ exitCode, err := s.handleShell(context.Background(), r)
+
+ if tc.expectedErrString != "" {
+ require.Equal(t, tc.expectedErrString, err.Error())
+ }
+
+ require.Equal(t, tc.expectedExitCode, exitCode)
require.Equal(t, tc.errMsg, out.String())
})
}
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index 49b8ab9..ebd7022 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -12,7 +12,6 @@ import (
"golang.org/x/crypto/ssh"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
- "gitlab.com/gitlab-org/gitlab-shell/internal/metrics"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/log"
@@ -146,68 +145,23 @@ func (s *Server) getStatus() status {
}
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
- success := false
-
- metrics.SshdConnectionsInFlight.Inc()
- started := time.Now()
- defer func() {
- metrics.SshdConnectionsInFlight.Dec()
- metrics.SshdSessionDuration.Observe(time.Since(started).Seconds())
-
- metrics.SliSshdSessionsTotal.Inc()
- if !success {
- metrics.SliSshdSessionsErrorsTotal.Inc()
- }
- }()
-
- remoteAddr := nconn.RemoteAddr().String()
-
defer s.wg.Done()
defer nconn.Close()
ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
defer cancel()
- ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr})
-
- // Prevent a panic in a single connection from taking out the whole server
- defer func() {
- if err := recover(); err != nil {
- ctxlog.Warn("panic handling session")
- }
- }()
-
- ctxlog.Info("server: handleConn: start")
-
- sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx))
- if err != nil {
- ctxlog.WithError(err).Error("server: handleConn: failed to initialize SSH connection")
- return
- }
- go ssh.DiscardRequests(reqs)
-
- var establishSessionDuration float64
- conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, remoteAddr)
- conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
- establishSessionDuration = time.Since(started).Seconds()
- metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
-
+ conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, nconn)
+ conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) error {
session := &session{
cfg: s.Config,
channel: channel,
gitlabKeyId: sconn.Permissions.Extensions["key-id"],
- remoteAddr: remoteAddr,
+ remoteAddr: nconn.RemoteAddr().String(),
}
- session.handle(ctx, requests)
-
- success = session.success
+ return session.handle(ctx, requests)
})
-
- ctxlog.WithFields(log.Fields{
- "duration_s": time.Since(started).Seconds(),
- "establish_session_duration_s": establishSessionDuration,
- }).Info("server: handleConn: done")
}
func unconditionalRequirePolicy(_ net.Addr) (proxyproto.Policy, error) {