summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIgor Drozdov <idrozdov@gitlab.com>2022-05-16 01:17:52 +0400
committerIgor Drozdov <idrozdov@gitlab.com>2022-05-16 12:05:32 +0400
commita77babe96fac9c880061fa63fffabfc8406f11bf (patch)
tree8ee023b0bc368fec094e57301875535b14ac20ec
parent7cde0770f2a29010181f95eef4c1744e16f5e0d8 (diff)
downloadgitlab-shell-a77babe96fac9c880061fa63fffabfc8406f11bf.tar.gz
Return error from session handler
-rw-r--r--internal/sshd/connection.go9
-rw-r--r--internal/sshd/connection_test.go16
-rw-r--r--internal/sshd/session.go39
-rw-r--r--internal/sshd/session_test.go87
-rw-r--r--internal/sshd/sshd.go8
5 files changed, 96 insertions, 63 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
index 5b1232d..060156d 100644
--- a/internal/sshd/connection.go
+++ b/internal/sshd/connection.go
@@ -22,7 +22,7 @@ type connection struct {
sconn *ssh.ServerConn
}
-type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request)
+type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) error
func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection {
return &connection{
@@ -76,7 +76,12 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
}
}()
- handler(ctx, channel, requests)
+ metrics.SliSshdSessionsTotal.Inc()
+ err := handler(ctx, channel, requests)
+ if err != nil {
+ metrics.SliSshdSessionsErrorsTotal.Inc()
+ }
+
ctxlog.Info("connection: handle: done")
}()
}
diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go
index 3bd9bf8..0a06255 100644
--- a/internal/sshd/connection_test.go
+++ b/internal/sshd/connection_test.go
@@ -55,9 +55,14 @@ type fakeConn struct {
ssh.Conn
sentRequestName string
+ waitErr error
mu sync.Mutex
}
+func (f *fakeConn) Wait() error {
+ return f.waitErr
+}
+
func (f *fakeConn) SentRequestName() string {
f.mu.Lock()
defer f.mu.Unlock()
@@ -90,7 +95,7 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) {
numSessions := 0
require.NotPanics(t, func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
numSessions += 1
close(chans)
panic("This is a panic")
@@ -128,8 +133,9 @@ func TestTooManySessions(t *testing.T) {
defer cancel()
go func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
<-ctx.Done() // Keep the accepted channel open until the end of the test
+ return nil
})
}()
@@ -142,9 +148,10 @@ func TestAcceptSessionSucceeds(t *testing.T) {
conn, chans := setup(1, newChannel)
channelHandled := false
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
close(chans)
+ return nil
})
require.True(t, channelHandled)
@@ -160,8 +167,9 @@ func TestAcceptSessionFails(t *testing.T) {
channelHandled := false
go func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
+ return nil
})
}()
diff --git a/internal/sshd/session.go b/internal/sshd/session.go
index beb529e..831beb8 100644
--- a/internal/sshd/session.go
+++ b/internal/sshd/session.go
@@ -22,7 +22,6 @@ type session struct {
channel ssh.Channel
gitlabKeyId string
remoteAddr string
- success bool
// State managed by the session
execCmd string
@@ -42,11 +41,12 @@ type exitStatusReq struct {
ExitStatus uint32
}
-func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
+func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) error {
ctxlog := log.ContextLogger(ctx)
ctxlog.Debug("session: handle: entering request loop")
+ var err error
for req := range requests {
sessionLog := ctxlog.WithFields(log.Fields{
"bytesize": len(req.Payload),
@@ -58,12 +58,14 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
var shouldContinue bool
switch req.Type {
case "env":
- shouldContinue = s.handleEnv(ctx, req)
+ shouldContinue, err = s.handleEnv(ctx, req)
case "exec":
- shouldContinue = s.handleExec(ctx, req)
+ shouldContinue, err = s.handleExec(ctx, req)
case "shell":
shouldContinue = false
- s.exit(ctx, s.handleShell(ctx, req))
+ var status uint32
+ status, err = s.handleShell(ctx, req)
+ s.exit(ctx, status)
default:
// Ignore unknown requests but don't terminate the session
shouldContinue = true
@@ -84,15 +86,17 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
}
ctxlog.Debug("session: handle: exiting request loop")
+
+ return err
}
-func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
+func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) {
var accepted bool
var envRequest envRequest
if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil {
log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request")
- return false
+ return false, err
}
switch envRequest.Name {
@@ -113,23 +117,24 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
ctx, log.Fields{"accepted": accepted, "env_request": envRequest},
).Debug("session: handleEnv: processed")
- return true
+ return true, nil
}
-func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool {
+func (s *session) handleExec(ctx context.Context, req *ssh.Request) (bool, error) {
var execRequest execRequest
if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil {
- return false
+ return false, err
}
s.execCmd = execRequest.Command
- s.exit(ctx, s.handleShell(ctx, req))
+ status, err := s.handleShell(ctx, req)
+ s.exit(ctx, status)
- return false
+ return false, err
}
-func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
+func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, error) {
ctxlog := log.ContextLogger(ctx)
if req.WantReply {
@@ -157,7 +162,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
s.toStderr(ctx, "Failed to parse command: %v\n", err.Error())
}
s.toStderr(ctx, "Unknown command: %v\n", s.execCmd)
- return 128
+ return 128, err
}
cmdName := reflect.TypeOf(cmd).String()
@@ -165,12 +170,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
if err := cmd.Execute(ctx); err != nil {
s.toStderr(ctx, "remote: ERROR: %v\n", err.Error())
- return 1
+ return 1, err
}
ctxlog.Info("session: handleShell: command executed successfully")
- return 0
+ return 0, nil
}
func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) {
@@ -183,8 +188,6 @@ func (s *session) exit(ctx context.Context, status uint32) {
log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting")
req := exitStatusReq{ExitStatus: status}
- s.success = status == 0
-
s.channel.CloseWrite()
s.channel.SendRequest("exit-status", false, ssh.Marshal(req))
}
diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go
index d0cc8d4..5bc8e7c 100644
--- a/internal/sshd/session_test.go
+++ b/internal/sshd/session_test.go
@@ -3,6 +3,7 @@ package sshd
import (
"bytes"
"context"
+ "errors"
"io"
"net/http"
"testing"
@@ -60,22 +61,26 @@ func TestHandleEnv(t *testing.T) {
testCases := []struct {
desc string
payload []byte
+ expectedErr error
expectedProtocolVersion string
expectedResult bool
}{
{
desc: "invalid payload",
payload: []byte("invalid"),
+ expectedErr: errors.New("ssh: unmarshal error for field Name of type envRequest"),
expectedProtocolVersion: "1",
expectedResult: false,
}, {
desc: "valid payload",
payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}),
+ expectedErr: nil,
expectedProtocolVersion: "2",
expectedResult: true,
}, {
desc: "valid payload with forbidden env var",
payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}),
+ expectedErr: nil,
expectedProtocolVersion: "1",
expectedResult: true,
},
@@ -86,8 +91,11 @@ func TestHandleEnv(t *testing.T) {
s := &session{gitProtocolVersion: "1"}
r := &ssh.Request{Payload: tc.payload}
- require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult)
- require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion)
+ shouldContinue, err := s.handleEnv(context.Background(), r)
+
+ require.Equal(t, tc.expectedErr, err)
+ require.Equal(t, tc.expectedResult, shouldContinue)
+ require.Equal(t, tc.expectedProtocolVersion, s.gitProtocolVersion)
})
}
}
@@ -96,23 +104,24 @@ func TestHandleExec(t *testing.T) {
testCases := []struct {
desc string
payload []byte
+ expectedErr error
expectedExecCmd string
sentRequestName string
sentRequestPayload []byte
- success bool
}{
{
desc: "invalid payload",
payload: []byte("invalid"),
+ expectedErr: errors.New("ssh: unmarshal error for field Command of type execRequest"),
expectedExecCmd: "",
sentRequestName: "",
}, {
desc: "valid payload",
payload: ssh.Marshal(execRequest{Command: "discover"}),
+ expectedErr: nil,
expectedExecCmd: "discover",
sentRequestName: "exit-status",
sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}),
- success: true,
},
}
@@ -129,47 +138,53 @@ func TestHandleExec(t *testing.T) {
}
r := &ssh.Request{Payload: tc.payload}
- require.Equal(t, false, s.handleExec(context.Background(), r))
+ shouldContinue, err := s.handleExec(context.Background(), r)
+
+ require.Equal(t, tc.expectedErr, err)
+ require.Equal(t, false, shouldContinue)
require.Equal(t, tc.sentRequestName, f.sentRequestName)
require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload)
- require.Equal(t, tc.success, s.success)
})
}
}
func TestHandleShell(t *testing.T) {
testCases := []struct {
- desc string
- cmd string
- errMsg string
- gitlabKeyId string
- expectedExitCode uint32
- success bool
+ desc string
+ cmd string
+ errMsg string
+ gitlabKeyId string
+ expectedErrString string
+ expectedExitCode uint32
}{
{
- desc: "fails to parse command",
- cmd: `\`,
- errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
- gitlabKeyId: "root",
- expectedExitCode: 128,
+ desc: "fails to parse command",
+ cmd: `\`,
+ errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
+ gitlabKeyId: "root",
+ expectedErrString: "Invalid SSH command: invalid command line string",
+ expectedExitCode: 128,
}, {
- desc: "specified command is unknown",
- cmd: "unknown-command",
- errMsg: "Unknown command: unknown-command\n",
- gitlabKeyId: "root",
- expectedExitCode: 128,
+ desc: "specified command is unknown",
+ cmd: "unknown-command",
+ errMsg: "Unknown command: unknown-command\n",
+ gitlabKeyId: "root",
+ expectedErrString: "Disallowed command",
+ expectedExitCode: 128,
}, {
- desc: "fails to parse command",
- cmd: "discover",
- gitlabKeyId: "",
- errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
- expectedExitCode: 1,
+ desc: "fails to parse command",
+ cmd: "discover",
+ gitlabKeyId: "",
+ errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
+ expectedErrString: "Failed to get username: who='' is invalid",
+ expectedExitCode: 1,
}, {
- desc: "fails to parse command",
- cmd: "discover",
- errMsg: "",
- gitlabKeyId: "root",
- expectedExitCode: 0,
+ desc: "fails to parse command",
+ cmd: "discover",
+ errMsg: "",
+ gitlabKeyId: "root",
+ expectedErrString: "",
+ expectedExitCode: 0,
},
}
@@ -186,7 +201,13 @@ func TestHandleShell(t *testing.T) {
}
r := &ssh.Request{}
- require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r))
+ exitCode, err := s.handleShell(context.Background(), r)
+
+ if tc.expectedErrString != "" {
+ require.Equal(t, tc.expectedErrString, err.Error())
+ }
+
+ require.Equal(t, tc.expectedExitCode, exitCode)
require.Equal(t, tc.errMsg, out.String())
})
}
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index 242e4f2..a9cd302 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -181,7 +181,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
started := time.Now()
var establishSessionDuration float64
conn := newConnection(s.Config, remoteAddr, sconn)
- conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
+ conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) error {
establishSessionDuration = time.Since(started).Seconds()
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
@@ -192,11 +192,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
remoteAddr: remoteAddr,
}
- metrics.SliSshdSessionsTotal.Inc()
- session.handle(ctx, requests)
- if !session.success {
- metrics.SliSshdSessionsErrorsTotal.Inc()
- }
+ return session.handle(ctx, requests)
})
reason := sconn.Wait()