summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Thomas <nick@gitlab.com>2021-07-21 14:32:33 +0100
committerNick Thomas <nick@gitlab.com>2021-07-22 10:48:31 +0100
commitdba8402032824a215f1e1e80c6e285e0ef66f646 (patch)
tree806cb9df9fd5456bcea69a66887b5757894f9bef
parenta8b2088d6d40e365445fcf4bea5183f83e31cc51 (diff)
downloadgitlab-shell-dba8402032824a215f1e1e80c6e285e0ef66f646.tar.gz
Unit tests for internal/sshd/connection.go
-rw-r--r--internal/sshd/connection_test.go78
1 files changed, 73 insertions, 5 deletions
diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go
index f48750e..d6bd3c0 100644
--- a/internal/sshd/connection_test.go
+++ b/internal/sshd/connection_test.go
@@ -2,6 +2,7 @@ package sshd
import (
"context"
+ "errors"
"testing"
"github.com/stretchr/testify/require"
@@ -9,22 +10,31 @@ import (
)
type rejectCall struct {
- reason ssh.RejectionReason
+ reason ssh.RejectionReason
message string
}
type fakeNewChannel struct {
channelType string
extraData []byte
+ acceptErr error
+
+ acceptCh chan struct{}
rejectCh chan rejectCall
}
func (f *fakeNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
- return nil, nil, nil
+ if f.acceptCh != nil {
+ f.acceptCh <- struct{}{}
+ }
+
+ return nil, nil, f.acceptErr
}
func (f *fakeNewChannel) Reject(reason ssh.RejectionReason, message string) error {
- f.rejectCh <- rejectCall{reason: reason, message: message}
+ if f.rejectCh != nil {
+ f.rejectCh <- rejectCall{reason: reason, message: message}
+ }
return nil
}
@@ -63,7 +73,9 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) {
}
func TestUnknownChannelType(t *testing.T) {
- rejectCh := make(chan rejectCall, 1)
+ rejectCh := make(chan rejectCall)
+ defer close(rejectCh)
+
newChannel := &fakeNewChannel{channelType: "unknown session", rejectCh: rejectCh}
conn, chans := setup(1, newChannel)
@@ -72,8 +84,64 @@ func TestUnknownChannelType(t *testing.T) {
}()
rejectionData := <-rejectCh
- close(rejectCh)
expectedRejection := rejectCall{reason: ssh.UnknownChannelType, message: "unknown channel type"}
require.Equal(t, expectedRejection, rejectionData)
}
+
+func TestTooManySessions(t *testing.T) {
+ rejectCh := make(chan rejectCall)
+ defer close(rejectCh)
+
+ newChannel := &fakeNewChannel{channelType: "session", rejectCh: rejectCh}
+ conn, chans := setup(1, newChannel)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ go func() {
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ <-ctx.Done() // Keep the accepted channel open until the end of the test
+ })
+ }()
+
+ chans <- newChannel
+ require.Equal(t, <-rejectCh, rejectCall{reason: ssh.ResourceShortage, message: "too many concurrent sessions"})
+}
+
+func TestAcceptSessionSucceeds(t *testing.T) {
+ newChannel := &fakeNewChannel{channelType: "session"}
+ conn, chans := setup(1, newChannel)
+
+ channelHandled := false
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ channelHandled = true
+ close(chans)
+ })
+
+ require.True(t, channelHandled)
+}
+
+func TestAcceptSessionFails(t *testing.T) {
+ acceptCh := make(chan struct{})
+ defer close(acceptCh)
+
+ acceptErr := errors.New("some failure")
+ newChannel := &fakeNewChannel{channelType: "session", acceptCh: acceptCh, acceptErr: acceptErr}
+ conn, chans := setup(1, newChannel)
+
+ channelHandled := false
+ go func() {
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ channelHandled = true
+ })
+ }()
+
+ require.Equal(t, <-acceptCh, struct{}{})
+
+ // Waits until the number of sessions is back to 0, since we can only have 1
+ conn.concurrentSessions.Acquire(context.Background(), 1)
+ defer conn.concurrentSessions.Release(1)
+
+ require.False(t, channelHandled)
+}