package sshd

import (
	"context"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	log "github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
	"golang.org/x/sync/semaphore"
)

const (
	namespace     = "gitlab_shell"
	sshdSubsystem = "sshd"
)

var (
	// sshdConnectionDuration records, in seconds, the time from
	// newConnection until connection.handle returns.
	sshdConnectionDuration = promauto.NewHistogram(
		prometheus.HistogramOpts{
			Namespace: namespace,
			Subsystem: sshdSubsystem,
			Name:      "connection_duration_seconds",
			Help:      "A histogram of latencies for connections to gitlab-shell sshd.",
			Buckets: []float64{
				0.005, // 5ms
				0.025, // 25ms
				0.1,   // 100ms
				0.5,   // 500ms
				1.0,   // 1s
				10.0,  // 10s
				30.0,  // 30s
				60.0,  // 1m
				300.0, // 5m
			},
		},
	)

	// sshdHitMaxSessions counts session channels rejected by
	// connection.handle because the per-connection session limit was reached.
	sshdHitMaxSessions = promauto.NewCounter(
		prometheus.CounterOpts{
			Namespace: namespace,
			Subsystem: sshdSubsystem,
			Name:      "concurrent_limited_sessions_total",
			Help:      "The number of times the concurrent sessions limit was hit in gitlab-shell sshd.",
		},
	)
)

// connection holds the per-connection state used while serving channels.
type connection struct {
	// begin is when the connection was created; used for the duration metric.
	begin time.Time
	// concurrentSessions bounds how many session channels run at once.
	concurrentSessions *semaphore.Weighted
	// remoteAddr identifies the peer in log messages.
	remoteAddr string
}

// channelHandler serves one accepted channel together with its request stream.
type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request)

// newConnection returns a connection that allows at most maxSessions
// concurrent session channels. It starts the duration clock immediately.
func newConnection(maxSessions int64, remoteAddr string) *connection {
	return &connection{
		begin:              time.Now(),
		concurrentSessions: semaphore.NewWeighted(maxSessions),
		remoteAddr:         remoteAddr,
	}
}

// handle consumes new-channel requests from chans until it is closed.
//
// Behavior per channel type:
//   - Non-"session" channels are rejected with ssh.UnknownChannelType.
//   - Session channels beyond the concurrency limit are rejected with
//     ssh.ResourceShortage, and sshdHitMaxSessions is incremented.
//   - Every other session channel is accepted and passed to handler in its
//     own goroutine. A panic in that goroutine is recovered and logged.
//
// When handle returns, the connection's lifetime is recorded in
// sshdConnectionDuration. handle does not wait for in-flight session
// goroutines to finish.
func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
	// Arguments of a deferred call are evaluated when the defer statement
	// runs. time.Since must therefore be called inside the closure, otherwise
	// the observed duration would be measured at the start of handle (~0).
	defer func() {
		sshdConnectionDuration.Observe(time.Since(c.begin).Seconds())
	}()

	for newChannel := range chans {
		if newChannel.ChannelType() != "session" {
			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
			continue
		}

		if !c.concurrentSessions.TryAcquire(1) {
			newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
			sshdHitMaxSessions.Inc()
			continue
		}

		channel, requests, err := newChannel.Accept()
		if err != nil {
			log.Infof("Could not accept channel: %v", err)
			// Give back the slot acquired above; no goroutine will own it.
			c.concurrentSessions.Release(1)
			continue
		}

		// channel and requests are declared with := inside the loop body, so
		// each goroutine captures its own copies.
		go func() {
			// Runs last (defers are LIFO): frees the session slot even after
			// a recovered panic.
			defer c.concurrentSessions.Release(1)

			// Prevent a panic in a single session from taking out the whole server
			defer func() {
				if err := recover(); err != nil {
					log.Warnf("panic handling session from %s: recovered: %#+v", c.remoteAddr, err)
				}
			}()

			handler(ctx, channel, requests)
		}()
	}
}