diff options
author | Igor Drozdov <idrozdov@gitlab.com> | 2021-04-23 14:15:01 +0000 |
---|---|---|
committer | Igor Drozdov <idrozdov@gitlab.com> | 2021-04-23 14:15:01 +0000 |
commit | 39792693a2a2d06669103714e7fa9da83b0e9b12 (patch) | |
tree | c06a1d0b6b5a1497866ff22425f29976f530865c | |
parent | 4ce33557c9364bb1771abaa6a475b1dba9a1ad3e (diff) | |
parent | 339fa88a70a4450c75965bbf7ba6b64780c2d92d (diff) | |
download | gitlab-shell-39792693a2a2d06669103714e7fa9da83b0e9b12.tar.gz |
Merge branch '511-extract-session' into 'main'
Extract sshd connections and sessions into their own files and structs
See merge request gitlab-org/gitlab-shell!463
-rw-r--r-- | internal/sshd/connection.go | 89 | ||||
-rw-r--r-- | internal/sshd/session.go | 151 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 192 |
3 files changed, 257 insertions, 175 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go new file mode 100644 index 0000000..c8d1456 --- /dev/null +++ b/internal/sshd/connection.go @@ -0,0 +1,89 @@ +package sshd + +import ( + "context" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "golang.org/x/sync/semaphore" +) + +const ( + namespace = "gitlab_shell" + sshdSubsystem = "sshd" +) + +var ( + sshdConnectionDuration = promauto.NewHistogram( + prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: sshdSubsystem, + Name: "connection_duration_seconds", + Help: "A histogram of latencies for connections to gitlab-shell sshd.", + Buckets: []float64{ + 0.005, /* 5ms */ + 0.025, /* 25ms */ + 0.1, /* 100ms */ + 0.5, /* 500ms */ + 1.0, /* 1s */ + 10.0, /* 10s */ + 30.0, /* 30s */ + 60.0, /* 1m */ + 300.0, /* 5m */ + }, + }, + ) + + sshdHitMaxSessions = promauto.NewCounter( + prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: sshdSubsystem, + Name: "concurrent_limited_sessions_total", + Help: "The number of times the concurrent sessions limit was hit in gitlab-shell sshd.", + }, + ) +) + +type connection struct { + begin time.Time + concurrentSessions *semaphore.Weighted +} + +type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) + +func newConnection(maxSessions int64) *connection { + return &connection{ + begin: time.Now(), + concurrentSessions: semaphore.NewWeighted(maxSessions), + } +} + +func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) { + defer sshdConnectionDuration.Observe(time.Since(c.begin).Seconds()) + + for newChannel := range chans { + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + if !c.concurrentSessions.TryAcquire(1) { + newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") + sshdHitMaxSessions.Inc() + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Infof("Could not accept channel: %v", err) + c.concurrentSessions.Release(1) + continue + } + + go func() { + defer c.concurrentSessions.Release(1) + handler(ctx, channel, requests) + }() + } +} diff --git a/internal/sshd/session.go b/internal/sshd/session.go new file mode 100644 index 0000000..22cb715 --- /dev/null +++ b/internal/sshd/session.go @@ -0,0 +1,151 @@ +package sshd + +import ( + "context" + "fmt" + + "golang.org/x/crypto/ssh" + + "gitlab.com/gitlab-org/gitlab-shell/internal/command" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" +) + +type session struct { + // State set up by the connection + cfg *config.Config + channel ssh.Channel + gitlabKeyId string + remoteAddr string + + // State managed by the session + execCmd string + gitProtocolVersion string +} + +type execRequest struct { + Command string +} + +type envRequest struct { + Name string + Value string +} + +type exitStatusReq struct { + ExitStatus uint32 +} + +func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { + for req := range requests { + var shouldContinue bool + switch req.Type { + case "env": + shouldContinue = s.handleEnv(req) + case "exec": + shouldContinue = s.handleExec(ctx, req) + case "shell": + shouldContinue = false + s.exit(s.handleShell(ctx, req)) + default: + // Ignore unknown requests but don't terminate the session + shouldContinue = true + if req.WantReply { + req.Reply(false, []byte{}) + } + } + + if !shouldContinue { + s.channel.Close() + break + } + } +} + +func (s *session) handleEnv(req *ssh.Request) bool { + var accepted bool + var envRequest envRequest + + if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil { + return false + } + + switch envRequest.Name { + case sshenv.GitProtocolEnv: + s.gitProtocolVersion = envRequest.Value + accepted = true + default: + // Client requested a forbidden envvar, nothing to do + } + + if req.WantReply { + req.Reply(accepted, []byte{}) + } + + return true +} + +func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool { + var execRequest execRequest + if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil { + return false + } + + s.execCmd = execRequest.Command + + s.exit(s.handleShell(ctx, req)) + return false +} + +func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { + if req.WantReply { + req.Reply(true, []byte{}) + } + + args := &commandargs.Shell{ + GitlabKeyId: s.gitlabKeyId, + Env: sshenv.Env{ + IsSSHConnection: true, + OriginalCommand: s.execCmd, + GitProtocolVersion: s.gitProtocolVersion, + RemoteAddr: s.remoteAddr, + }, + } + + if err := args.ParseCommand(s.execCmd); err != nil { + s.toStderr("Failed to parse command: %v\n", err.Error()) + return 128 + } + + rw := &readwriter.ReadWriter{ + Out: s.channel, + In: s.channel, + ErrOut: s.channel.Stderr(), + } + + cmd := command.BuildShellCommand(args, s.cfg, rw) + if cmd == nil { + s.toStderr("Unknown command: %v\n", args.CommandType) + return 128 + } + + if err := cmd.Execute(ctx); err != nil { + s.toStderr("remote: ERROR: %v\n", err.Error()) + return 1 + } + + return 0 +} + +func (s *session) toStderr(format string, args ...interface{}) { + fmt.Fprintf(s.channel.Stderr(), format, args...) +} + +func (s *session) exit(status uint32) { + req := exitStatusReq{ExitStatus: status} + + s.channel.CloseWrite() + s.channel.SendRequest("exit-status", false, ssh.Marshal(req)) +} diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 7bd81ff..f046d60 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -13,57 +13,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/pires/go-proxyproto" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" "golang.org/x/crypto/ssh" - "golang.org/x/sync/semaphore" - "gitlab.com/gitlab-org/gitlab-shell/internal/command" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" - "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" -) - -const ( - namespace = "gitlab_shell" - sshdSubsystem = "sshd" -) - -func secondsDurationBuckets() []float64 { - return []float64{ - 0.005, /* 5ms */ - 0.025, /* 25ms */ - 0.1, /* 100ms */ - 0.5, /* 500ms */ - 1.0, /* 1s */ - 10.0, /* 10s */ - 30.0, /* 30s */ - 60.0, /* 1m */ - 300.0, /* 10m */ - } -} - -var ( - sshdConnectionDuration = promauto.NewHistogram( - prometheus.HistogramOpts{ - Namespace: namespace, - Subsystem: sshdSubsystem, - Name: "connection_duration_seconds", - Help: "A histogram of latencies for connections to gitlab-shell sshd.", - Buckets: secondsDurationBuckets(), - }, - ) - - sshdHitMaxSessions = promauto.NewCounter( - prometheus.CounterOpts{ - Namespace: namespace, - Subsystem: sshdSubsystem, - Name: "concurrent_limited_sessions_total", - Help: "The number of times the concurrent sessions limit was hit in gitlab-shell sshd.", - }, - ) ) func Run(cfg *config.Config) error { @@ -85,7 +38,7 @@ func Run(cfg *config.Config) error { log.Infof("Listening on %v", sshListener.Addr().String()) - config := &ssh.ServerConfig{ + sshCfg := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { if conn.User() != cfg.User { return nil, errors.New("unknown user") @@ -122,7 +75,7 @@ func Run(cfg *config.Config) error { continue } loadedHostKeys++ - config.AddHostKey(key) + sshCfg.AddHostKey(key) } if loadedHostKeys == 0 { return fmt.Errorf("No host keys could be loaded, aborting") @@ -135,144 +88,33 @@ func Run(cfg *config.Config) error { continue } - go handleConn(nconn, config, cfg) - } -} - -type execRequest struct { - Command string -} - -type exitStatusReq struct { - ExitStatus uint32 -} - -type envRequest struct { - Name string - Value string -} - -func exitSession(ch ssh.Channel, exitStatus uint32) { - exitStatusReq := exitStatusReq{ - ExitStatus: exitStatus, + go handleConn(cfg, sshCfg, nconn) } - ch.CloseWrite() - ch.SendRequest("exit-status", false, ssh.Marshal(exitStatusReq)) - ch.Close() } -func handleConn(nconn net.Conn, sshCfg *ssh.ServerConfig, cfg *config.Config) { - begin := time.Now() - defer func() { - sshdConnectionDuration.Observe(time.Since(begin).Seconds()) - }() +func handleConn(cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) { + defer nconn.Close() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - defer nconn.Close() - conn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg) + + sconn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg) if err != nil { log.Infof("Failed to initialize SSH connection: %v", err) return } - concurrentSessions := semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit) - go ssh.DiscardRequests(reqs) - for newChannel := range chans { - if newChannel.ChannelType() != "session" { - newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") - continue - } - if !concurrentSessions.TryAcquire(1) { - newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") - sshdHitMaxSessions.Inc() - continue - } - ch, requests, err := newChannel.Accept() - if err != nil { - log.Infof("Could not accept channel: %v", err) - concurrentSessions.Release(1) - continue - } - - go handleSession(ctx, concurrentSessions, ch, requests, conn, nconn, cfg) - } -} - -func handleSession(ctx context.Context, concurrentSessions *semaphore.Weighted, ch ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn, nconn net.Conn, cfg *config.Config) { - defer concurrentSessions.Release(1) - - rw := &readwriter.ReadWriter{ - Out: ch, - In: ch, - ErrOut: ch.Stderr(), - } - var gitProtocolVersion string - - for req := range requests { - var execCmd string - switch req.Type { - case "env": - var envRequest envRequest - if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil { - ch.Close() - return - } - var accepted bool - if envRequest.Name == sshenv.GitProtocolEnv { - gitProtocolVersion = envRequest.Value - accepted = true - } - if req.WantReply { - req.Reply(accepted, []byte{}) - } - case "exec": - var execRequest execRequest - if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil { - ch.Close() - return - } - execCmd = execRequest.Command - fallthrough - case "shell": - if req.WantReply { - req.Reply(true, []byte{}) - } - args := &commandargs.Shell{ - GitlabKeyId: conn.Permissions.Extensions["key-id"], - Env: sshenv.Env{ - IsSSHConnection: true, - OriginalCommand: execCmd, - GitProtocolVersion: gitProtocolVersion, - RemoteAddr: nconn.RemoteAddr().(*net.TCPAddr).String(), - }, - } - - if err := args.ParseCommand(execCmd); err != nil { - fmt.Fprintf(ch.Stderr(), "Failed to parse command: %v\n", err.Error()) - exitSession(ch, 128) - return - } - - cmd := command.BuildShellCommand(args, cfg, rw) - if cmd == nil { - fmt.Fprintf(ch.Stderr(), "Unknown command: %v\n", args.CommandType) - exitSession(ch, 128) - return - } - if err := cmd.Execute(ctx); err != nil { - fmt.Fprintf(ch.Stderr(), "remote: ERROR: %v\n", err.Error()) - exitSession(ch, 1) - return - } - exitSession(ch, 0) - return - default: - if req.WantReply { - req.Reply(false, []byte{}) - } + conn := newConnection(cfg.Server.ConcurrentSessionsLimit) + conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) { + session := &session{ + cfg: cfg, + channel: channel, + gitlabKeyId: sconn.Permissions.Extensions["key-id"], + remoteAddr: nconn.RemoteAddr().(*net.TCPAddr).String(), } - } + + session.handle(ctx, requests) + }) } |