diff options
-rw-r--r-- | internal/sshd/session_test.go | 145 |
1 files changed, 145 insertions, 0 deletions
diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go index 152506e..e4e312c 100644 --- a/internal/sshd/session_test.go +++ b/internal/sshd/session_test.go @@ -1,12 +1,61 @@ package sshd import ( + "bytes" + "context" + "io" + "net/http" "testing" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" + + "gitlab.com/gitlab-org/gitlab-shell/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" ) +type fakeChannel struct { + stdErr io.ReadWriter + sentRequestName string + sentRequestPayload []byte +} + +func (f *fakeChannel) Read(data []byte) (int, error) { + return 0, nil +} + +func (f *fakeChannel) Write(data []byte) (int, error) { + return 0, nil +} + +func (f *fakeChannel) Close() error { + return nil +} + +func (f *fakeChannel) CloseWrite() error { + return nil +} + +func (f *fakeChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + f.sentRequestName = name + f.sentRequestPayload = payload + + return true, nil +} + +func (f *fakeChannel) Stderr() io.ReadWriter { + return f.stdErr +} + +var requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`{"id": 1000, "name": "Test User", "username": "test-user"}`)) + }, + }, +} + func TestHandleEnv(t *testing.T) { testCases := []struct { desc string @@ -42,3 +91,99 @@ func TestHandleEnv(t *testing.T) { }) } } + +func TestHandleExec(t *testing.T) { + testCases := []struct { + desc string + payload []byte + expectedExecCmd string + sentRequestName string + sentRequestPayload []byte + }{ + { + desc: "invalid payload", + payload: []byte("invalid"), + expectedExecCmd: "", + sentRequestName: "", + }, { + desc: "valid payload", + payload: ssh.Marshal(execRequest{Command: "discover"}), + expectedExecCmd: "discover", + sentRequestName: "exit-status", + sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}), + }, + } + + url := testserver.StartHttpServer(t, requests) + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + out := &bytes.Buffer{} + f := &fakeChannel{stdErr: out} + s := &session{ + gitlabKeyId: "root", + channel: f, + cfg: &config.Config{GitlabUrl: url}, + } + r := &ssh.Request{Payload: tc.payload} + + require.Equal(t, false, s.handleExec(context.Background(), r)) + require.Equal(t, tc.sentRequestName, f.sentRequestName) + require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload) + }) + } +} + +func TestHandleShell(t *testing.T) { + testCases := []struct { + desc string + cmd string + errMsg string + gitlabKeyId string + expectedExitCode uint32 + }{ + { + desc: "fails to parse command", + cmd: `\`, + errMsg: "Failed to parse command: invalid command line string\n", + gitlabKeyId: "root", + expectedExitCode: 128, + }, { + desc: "specified command is unknown", + cmd: "unknown-command", + errMsg: "Unknown command: unknown-command\n", + gitlabKeyId: "root", + expectedExitCode: 128, + }, { + desc: "fails to parse command", + cmd: "discover", + gitlabKeyId: "", + errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n", + expectedExitCode: 1, + }, { + desc: "fails to parse command", + cmd: "discover", + errMsg: "", + gitlabKeyId: "root", + expectedExitCode: 0, + }, + } + + url := testserver.StartHttpServer(t, requests) + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + out := &bytes.Buffer{} + s := &session{ + gitlabKeyId: tc.gitlabKeyId, + execCmd: tc.cmd, + channel: &fakeChannel{stdErr: out}, + cfg: &config.Config{GitlabUrl: url}, + } + r := &ssh.Request{} + + require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r)) + require.Equal(t, tc.errMsg, out.String()) + }) + } +} |