summaryrefslogtreecommitdiff
path: root/internal/sshd/session_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/sshd/session_test.go')
-rw-r--r--internal/sshd/session_test.go87
1 files changed, 54 insertions, 33 deletions
diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go
index d0cc8d4..5bc8e7c 100644
--- a/internal/sshd/session_test.go
+++ b/internal/sshd/session_test.go
@@ -3,6 +3,7 @@ package sshd
import (
"bytes"
"context"
+ "errors"
"io"
"net/http"
"testing"
@@ -60,22 +61,26 @@ func TestHandleEnv(t *testing.T) {
testCases := []struct {
desc string
payload []byte
+ expectedErr error
expectedProtocolVersion string
expectedResult bool
}{
{
desc: "invalid payload",
payload: []byte("invalid"),
+ expectedErr: errors.New("ssh: unmarshal error for field Name of type envRequest"),
expectedProtocolVersion: "1",
expectedResult: false,
}, {
desc: "valid payload",
payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}),
+ expectedErr: nil,
expectedProtocolVersion: "2",
expectedResult: true,
}, {
desc: "valid payload with forbidden env var",
payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}),
+ expectedErr: nil,
expectedProtocolVersion: "1",
expectedResult: true,
},
@@ -86,8 +91,11 @@ func TestHandleEnv(t *testing.T) {
s := &session{gitProtocolVersion: "1"}
r := &ssh.Request{Payload: tc.payload}
- require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult)
- require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion)
+ shouldContinue, err := s.handleEnv(context.Background(), r)
+
+ require.Equal(t, tc.expectedErr, err)
+ require.Equal(t, tc.expectedResult, shouldContinue)
+ require.Equal(t, tc.expectedProtocolVersion, s.gitProtocolVersion)
})
}
}
@@ -96,23 +104,24 @@ func TestHandleExec(t *testing.T) {
testCases := []struct {
desc string
payload []byte
+ expectedErr error
expectedExecCmd string
sentRequestName string
sentRequestPayload []byte
- success bool
}{
{
desc: "invalid payload",
payload: []byte("invalid"),
+ expectedErr: errors.New("ssh: unmarshal error for field Command of type execRequest"),
expectedExecCmd: "",
sentRequestName: "",
}, {
desc: "valid payload",
payload: ssh.Marshal(execRequest{Command: "discover"}),
+ expectedErr: nil,
expectedExecCmd: "discover",
sentRequestName: "exit-status",
sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}),
- success: true,
},
}
@@ -129,47 +138,53 @@ func TestHandleExec(t *testing.T) {
}
r := &ssh.Request{Payload: tc.payload}
- require.Equal(t, false, s.handleExec(context.Background(), r))
+ shouldContinue, err := s.handleExec(context.Background(), r)
+
+ require.Equal(t, tc.expectedErr, err)
+ require.Equal(t, false, shouldContinue)
require.Equal(t, tc.sentRequestName, f.sentRequestName)
require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload)
- require.Equal(t, tc.success, s.success)
})
}
}
func TestHandleShell(t *testing.T) {
testCases := []struct {
- desc string
- cmd string
- errMsg string
- gitlabKeyId string
- expectedExitCode uint32
- success bool
+ desc string
+ cmd string
+ errMsg string
+ gitlabKeyId string
+ expectedErrString string
+ expectedExitCode uint32
}{
{
- desc: "fails to parse command",
- cmd: `\`,
- errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
- gitlabKeyId: "root",
- expectedExitCode: 128,
+ desc: "fails to parse command",
+ cmd: `\`,
+ errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
+ gitlabKeyId: "root",
+ expectedErrString: "Invalid SSH command: invalid command line string",
+ expectedExitCode: 128,
}, {
- desc: "specified command is unknown",
- cmd: "unknown-command",
- errMsg: "Unknown command: unknown-command\n",
- gitlabKeyId: "root",
- expectedExitCode: 128,
+ desc: "specified command is unknown",
+ cmd: "unknown-command",
+ errMsg: "Unknown command: unknown-command\n",
+ gitlabKeyId: "root",
+ expectedErrString: "Disallowed command",
+ expectedExitCode: 128,
}, {
- desc: "fails to parse command",
- cmd: "discover",
- gitlabKeyId: "",
- errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
- expectedExitCode: 1,
+ desc: "fails to parse command",
+ cmd: "discover",
+ gitlabKeyId: "",
+ errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
+ expectedErrString: "Failed to get username: who='' is invalid",
+ expectedExitCode: 1,
}, {
- desc: "fails to parse command",
- cmd: "discover",
- errMsg: "",
- gitlabKeyId: "root",
- expectedExitCode: 0,
+ desc: "fails to parse command",
+ cmd: "discover",
+ errMsg: "",
+ gitlabKeyId: "root",
+ expectedErrString: "",
+ expectedExitCode: 0,
},
}
@@ -186,7 +201,13 @@ func TestHandleShell(t *testing.T) {
}
r := &ssh.Request{}
- require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r))
+ exitCode, err := s.handleShell(context.Background(), r)
+
+ if tc.expectedErrString != "" {
+ require.Equal(t, tc.expectedErrString, err.Error())
+ }
+
+ require.Equal(t, tc.expectedExitCode, exitCode)
require.Equal(t, tc.errMsg, out.String())
})
}