summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/sshd/connection.go12
-rw-r--r--internal/sshd/connection_test.go49
-rw-r--r--internal/sshd/sshd.go13
3 files changed, 71 insertions, 3 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
index c8d1456..a9b9e97 100644
--- a/internal/sshd/connection.go
+++ b/internal/sshd/connection.go
@@ -50,14 +50,16 @@ var (
type connection struct {
begin time.Time
concurrentSessions *semaphore.Weighted
+ remoteAddr string
}
type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request)
-func newConnection(maxSessions int64) *connection {
+func newConnection(maxSessions int64, remoteAddr string) *connection {
return &connection{
begin: time.Now(),
concurrentSessions: semaphore.NewWeighted(maxSessions),
+ remoteAddr: remoteAddr,
}
}
@@ -83,6 +85,14 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
go func() {
defer c.concurrentSessions.Release(1)
+
+ // Prevent a panic in a single session from taking out the whole server
+ defer func() {
+ if err := recover(); err != nil {
+ log.Warnf("panic handling session from %s: recovered: %#+v", c.remoteAddr, err)
+ }
+ }()
+
handler(ctx, channel, requests)
}()
}
diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go
new file mode 100644
index 0000000..03e9209
--- /dev/null
+++ b/internal/sshd/connection_test.go
@@ -0,0 +1,49 @@
+package sshd
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
+)
+
+type fakeNewChannel struct {
+ channelType string
+ extraData []byte
+}
+
+func (f *fakeNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
+ return nil, nil, nil
+}
+
+func (f *fakeNewChannel) Reject(reason ssh.RejectionReason, message string) error {
+ return nil
+}
+
+func (f *fakeNewChannel) ChannelType() string {
+ return f.channelType
+}
+
+func (f *fakeNewChannel) ExtraData() []byte {
+ return f.extraData
+}
+
+func TestPanicDuringSessionIsRecovered(t *testing.T) {
+ numSessions := 0
+ conn := newConnection(1, "127.0.0.1:50000")
+
+ newChannel := &fakeNewChannel{channelType: "session"}
+ chans := make(chan ssh.NewChannel, 1)
+ chans <- newChannel
+
+ require.NotPanics(t, func() {
+ conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
+ numSessions += 1
+ close(chans)
+ panic("This is a panic")
+ })
+ })
+
+ require.Equal(t, numSessions, 1)
+}
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index f046d60..a9e797b 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -93,8 +93,17 @@ func Run(cfg *config.Config) error {
}
func handleConn(cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) {
+ remoteAddr := nconn.RemoteAddr().String()
+
defer nconn.Close()
+ // Prevent a panic in a single connection from taking out the whole server
+ defer func() {
+ if err := recover(); err != nil {
+ log.Warnf("panic handling connection from %s: recovered: %#+v", remoteAddr, err)
+ }
+ }()
+
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -106,13 +115,13 @@ func handleConn(cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) {
go ssh.DiscardRequests(reqs)
- conn := newConnection(cfg.Server.ConcurrentSessionsLimit)
+ conn := newConnection(cfg.Server.ConcurrentSessionsLimit, remoteAddr)
conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
session := &session{
cfg: cfg,
channel: channel,
gitlabKeyId: sconn.Permissions.Extensions["key-id"],
- remoteAddr: nconn.RemoteAddr().(*net.TCPAddr).String(),
+ remoteAddr: remoteAddr,
}
session.handle(ctx, requests)