diff options
author | Nick Thomas <nick@gitlab.com> | 2021-04-22 17:33:49 +0100 |
---|---|---|
committer | Nick Thomas <nick@gitlab.com> | 2021-04-23 14:17:12 +0100 |
commit | 5f4e800f9959504c99df35d52773af4a6de6bdfd (patch) | |
tree | bac8c090f14de4cf7eb09c1b29793b0b8cf98f10 | |
parent | 31920be44344e55252505cf8c1574655197bb31b (diff) | |
download | gitlab-shell-5f4e800f9959504c99df35d52773af4a6de6bdfd.tar.gz |
sshd: Extract connections into their own file
-rw-r--r-- | internal/sshd/connection.go | 89 | ||||
-rw-r--r-- | internal/sshd/session.go | 15 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 99 |
3 files changed, 113 insertions, 90 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go new file mode 100644 index 0000000..a4c6b36 --- /dev/null +++ b/internal/sshd/connection.go @@ -0,0 +1,89 @@ +package sshd + +import ( + "context" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "golang.org/x/sync/semaphore" +) + +const ( + namespace = "gitlab_shell" + sshdSubsystem = "sshd" +) + +var ( + sshdConnectionDuration = promauto.NewHistogram( + prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: sshdSubsystem, + Name: "connection_duration_seconds", + Help: "A histogram of latencies for connections to gitlab-shell sshd.", + Buckets: []float64{ + 0.005, /* 5ms */ + 0.025, /* 25ms */ + 0.1, /* 100ms */ + 0.5, /* 500ms */ + 1.0, /* 1s */ + 10.0, /* 10s */ + 30.0, /* 30s */ + 60.0, /* 1m */ + 300.0, /* 10m */ + }, + }, + ) + + sshdHitMaxSessions = promauto.NewCounter( + prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: sshdSubsystem, + Name: "concurrent_limited_sessions_total", + Help: "The number of times the concurrent sessions limit was hit in gitlab-shell sshd.", + }, + ) +) + +type connection struct { + begin time.Time + concurrentSessions *semaphore.Weighted +} + +type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) + +func newConnection(maxSessions int64) *connection { + return &connection{ + begin: time.Now(), + concurrentSessions: semaphore.NewWeighted(maxSessions), + } +} + +func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) { + defer sshdConnectionDuration.Observe(time.Since(c.begin).Seconds()) + + for newChannel := range chans { + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + if !c.concurrentSessions.TryAcquire(1) { + newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") + sshdHitMaxSessions.Inc() + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Infof("Could not accept channel: %v", err) + c.concurrentSessions.Release(1) + continue + } + + go func() { + defer c.concurrentSessions.Release(1) + handler(ctx, channel, requests) + }() + } +} diff --git a/internal/sshd/session.go b/internal/sshd/session.go index e178fe8..22cb715 100644 --- a/internal/sshd/session.go +++ b/internal/sshd/session.go @@ -3,7 +3,6 @@ package sshd import ( "context" "fmt" - "net" "golang.org/x/crypto/ssh" @@ -15,11 +14,11 @@ import ( ) type session struct { - // State set up by handleConn - cfg *config.Config - channel ssh.Channel - sconn *ssh.ServerConn - nconn net.Conn + // State set up by the connection + cfg *config.Config + channel ssh.Channel + gitlabKeyId string + remoteAddr string // State managed by the session execCmd string @@ -106,12 +105,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { } args := &commandargs.Shell{ - GitlabKeyId: s.sconn.Permissions.Extensions["key-id"], + GitlabKeyId: s.gitlabKeyId, Env: sshenv.Env{ IsSSHConnection: true, OriginalCommand: s.execCmd, GitProtocolVersion: s.gitProtocolVersion, - RemoteAddr: s.nconn.RemoteAddr().(*net.TCPAddr).String(), + RemoteAddr: s.remoteAddr, }, } diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 7906f0d..f046d60 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -13,55 +13,12 @@ import ( log "github.com/sirupsen/logrus" "github.com/pires/go-proxyproto" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" "golang.org/x/crypto/ssh" - "golang.org/x/sync/semaphore" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" ) -const ( - namespace = "gitlab_shell" - sshdSubsystem = "sshd" -) - -func secondsDurationBuckets() []float64 { - return []float64{ - 0.005, /* 5ms */ - 0.025, /* 25ms */ - 0.1, /* 100ms */ - 0.5, /* 500ms */ - 1.0, /* 1s */ - 10.0, /* 10s */ - 30.0, /* 30s */ - 60.0, /* 1m */ - 300.0, /* 10m */ - } -} - -var ( - sshdConnectionDuration = promauto.NewHistogram( - prometheus.HistogramOpts{ - Namespace: namespace, - Subsystem: sshdSubsystem, - Name: "connection_duration_seconds", - Help: "A histogram of latencies for connections to gitlab-shell sshd.", - Buckets: secondsDurationBuckets(), - }, - ) - - sshdHitMaxSessions = promauto.NewCounter( - prometheus.CounterOpts{ - Namespace: namespace, - Subsystem: sshdSubsystem, - Name: "concurrent_limited_sessions_total", - Help: "The number of times the concurrent sessions limit was hit in gitlab-shell sshd.", - }, - ) -) - func Run(cfg *config.Config) error { authorizedKeysClient, err := authorizedkeys.NewClient(cfg) if err != nil { @@ -81,7 +38,7 @@ func Run(cfg *config.Config) error { log.Infof("Listening on %v", sshListener.Addr().String()) - config := &ssh.ServerConfig{ + sshCfg := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { if conn.User() != cfg.User { return nil, errors.New("unknown user") @@ -118,7 +75,7 @@ func Run(cfg *config.Config) error { continue } loadedHostKeys++ - config.AddHostKey(key) + sshCfg.AddHostKey(key) } if loadedHostKeys == 0 { return fmt.Errorf("No host keys could be loaded, aborting") @@ -131,55 +88,33 @@ func Run(cfg *config.Config) error { continue } - go handleConn(nconn, config, cfg) + go handleConn(cfg, sshCfg, nconn) } } -func handleConn(nconn net.Conn, sshCfg *ssh.ServerConfig, cfg *config.Config) { - begin := time.Now() - defer func() { - sshdConnectionDuration.Observe(time.Since(begin).Seconds()) - }() +func handleConn(cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) { + defer nconn.Close() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - defer nconn.Close() - conn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg) + + sconn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg) if err != nil { log.Infof("Failed to initialize SSH connection: %v", err) return } - concurrentSessions := semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit) - go ssh.DiscardRequests(reqs) - for newChannel := range chans { - if newChannel.ChannelType() != "session" { - newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") - continue - } - if !concurrentSessions.TryAcquire(1) { - newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") - sshdHitMaxSessions.Inc() - continue - } - ch, requests, err := newChannel.Accept() - if err != nil { - log.Infof("Could not accept channel: %v", err) - concurrentSessions.Release(1) - continue - } - go func() { - defer concurrentSessions.Release(1) - session := &session{ - cfg: cfg, - channel: ch, - sconn: conn, - nconn: nconn, - } + conn := newConnection(cfg.Server.ConcurrentSessionsLimit) + conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) { + session := &session{ + cfg: cfg, + channel: channel, + gitlabKeyId: sconn.Permissions.Extensions["key-id"], + remoteAddr: nconn.RemoteAddr().(*net.TCPAddr).String(), + } - session.handle(ctx, requests) - }() - } + session.handle(ctx, requests) + }) } |