summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIgor Drozdov <idrozdov@gitlab.com>2021-04-23 14:15:01 +0000
committerIgor Drozdov <idrozdov@gitlab.com>2021-04-23 14:15:01 +0000
commit39792693a2a2d06669103714e7fa9da83b0e9b12 (patch)
treec06a1d0b6b5a1497866ff22425f29976f530865c
parent4ce33557c9364bb1771abaa6a475b1dba9a1ad3e (diff)
parent339fa88a70a4450c75965bbf7ba6b64780c2d92d (diff)
downloadgitlab-shell-39792693a2a2d06669103714e7fa9da83b0e9b12.tar.gz
Merge branch '511-extract-session' into 'main'
Extract sshd connections and sessions into their own files and structs See merge request gitlab-org/gitlab-shell!463
-rw-r--r--internal/sshd/connection.go89
-rw-r--r--internal/sshd/session.go151
-rw-r--r--internal/sshd/sshd.go192
3 files changed, 257 insertions, 175 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
new file mode 100644
index 0000000..c8d1456
--- /dev/null
+++ b/internal/sshd/connection.go
@@ -0,0 +1,89 @@
+package sshd
+
+import (
+ "context"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/crypto/ssh"
+ "golang.org/x/sync/semaphore"
+)
+
+const (
+ namespace = "gitlab_shell"
+ sshdSubsystem = "sshd"
+)
+
+var (
+ sshdConnectionDuration = promauto.NewHistogram(
+ prometheus.HistogramOpts{
+ Namespace: namespace,
+ Subsystem: sshdSubsystem,
+ Name: "connection_duration_seconds",
+ Help: "A histogram of latencies for connections to gitlab-shell sshd.",
+ Buckets: []float64{
+ 0.005, /* 5ms */
+ 0.025, /* 25ms */
+ 0.1, /* 100ms */
+ 0.5, /* 500ms */
+ 1.0, /* 1s */
+ 10.0, /* 10s */
+ 30.0, /* 30s */
+ 60.0, /* 1m */
+ 300.0, /* 5m */
+ },
+ },
+ )
+
+ sshdHitMaxSessions = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Namespace: namespace,
+ Subsystem: sshdSubsystem,
+ Name: "concurrent_limited_sessions_total",
+ Help: "The number of times the concurrent sessions limit was hit in gitlab-shell sshd.",
+ },
+ )
+)
+
+type connection struct {
+ begin time.Time
+ concurrentSessions *semaphore.Weighted
+}
+
+type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request)
+
+func newConnection(maxSessions int64) *connection {
+ return &connection{
+ begin: time.Now(),
+ concurrentSessions: semaphore.NewWeighted(maxSessions),
+ }
+}
+
+func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
+ defer sshdConnectionDuration.Observe(time.Since(c.begin).Seconds())
+
+ for newChannel := range chans {
+ if newChannel.ChannelType() != "session" {
+ newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
+ continue
+ }
+ if !c.concurrentSessions.TryAcquire(1) {
+ newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
+ sshdHitMaxSessions.Inc()
+ continue
+ }
+ channel, requests, err := newChannel.Accept()
+ if err != nil {
+ log.Infof("Could not accept channel: %v", err)
+ c.concurrentSessions.Release(1)
+ continue
+ }
+
+ go func() {
+ defer c.concurrentSessions.Release(1)
+ handler(ctx, channel, requests)
+ }()
+ }
+}
diff --git a/internal/sshd/session.go b/internal/sshd/session.go
new file mode 100644
index 0000000..22cb715
--- /dev/null
+++ b/internal/sshd/session.go
@@ -0,0 +1,151 @@
+package sshd
+
+import (
+ "context"
+ "fmt"
+
+ "golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
+)
+
+type session struct {
+ // State set up by the connection
+ cfg *config.Config
+ channel ssh.Channel
+ gitlabKeyId string
+ remoteAddr string
+
+ // State managed by the session
+ execCmd string
+ gitProtocolVersion string
+}
+
+type execRequest struct {
+ Command string
+}
+
+type envRequest struct {
+ Name string
+ Value string
+}
+
+type exitStatusReq struct {
+ ExitStatus uint32
+}
+
+func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
+ for req := range requests {
+ var shouldContinue bool
+ switch req.Type {
+ case "env":
+ shouldContinue = s.handleEnv(req)
+ case "exec":
+ shouldContinue = s.handleExec(ctx, req)
+ case "shell":
+ shouldContinue = false
+ s.exit(s.handleShell(ctx, req))
+ default:
+ // Ignore unknown requests but don't terminate the session
+ shouldContinue = true
+ if req.WantReply {
+ req.Reply(false, []byte{})
+ }
+ }
+
+ if !shouldContinue {
+ s.channel.Close()
+ break
+ }
+ }
+}
+
+func (s *session) handleEnv(req *ssh.Request) bool {
+ var accepted bool
+ var envRequest envRequest
+
+ if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil {
+ return false
+ }
+
+ switch envRequest.Name {
+ case sshenv.GitProtocolEnv:
+ s.gitProtocolVersion = envRequest.Value
+ accepted = true
+ default:
+ // Client requested a forbidden envvar, nothing to do
+ }
+
+ if req.WantReply {
+ req.Reply(accepted, []byte{})
+ }
+
+ return true
+}
+
+func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool {
+ var execRequest execRequest
+ if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil {
+ return false
+ }
+
+ s.execCmd = execRequest.Command
+
+ s.exit(s.handleShell(ctx, req))
+ return false
+}
+
+func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
+ if req.WantReply {
+ req.Reply(true, []byte{})
+ }
+
+ args := &commandargs.Shell{
+ GitlabKeyId: s.gitlabKeyId,
+ Env: sshenv.Env{
+ IsSSHConnection: true,
+ OriginalCommand: s.execCmd,
+ GitProtocolVersion: s.gitProtocolVersion,
+ RemoteAddr: s.remoteAddr,
+ },
+ }
+
+ if err := args.ParseCommand(s.execCmd); err != nil {
+ s.toStderr("Failed to parse command: %v\n", err.Error())
+ return 128
+ }
+
+ rw := &readwriter.ReadWriter{
+ Out: s.channel,
+ In: s.channel,
+ ErrOut: s.channel.Stderr(),
+ }
+
+ cmd := command.BuildShellCommand(args, s.cfg, rw)
+ if cmd == nil {
+ s.toStderr("Unknown command: %v\n", args.CommandType)
+ return 128
+ }
+
+ if err := cmd.Execute(ctx); err != nil {
+ s.toStderr("remote: ERROR: %v\n", err.Error())
+ return 1
+ }
+
+ return 0
+}
+
+func (s *session) toStderr(format string, args ...interface{}) {
+ fmt.Fprintf(s.channel.Stderr(), format, args...)
+}
+
+func (s *session) exit(status uint32) {
+ req := exitStatusReq{ExitStatus: status}
+
+ s.channel.CloseWrite()
+ s.channel.SendRequest("exit-status", false, ssh.Marshal(req))
+}
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index 7bd81ff..f046d60 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -13,57 +13,10 @@ import (
log "github.com/sirupsen/logrus"
"github.com/pires/go-proxyproto"
- "github.com/prometheus/client_golang/prometheus"
- "github.com/prometheus/client_golang/prometheus/promauto"
"golang.org/x/crypto/ssh"
- "golang.org/x/sync/semaphore"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
- "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
-)
-
-const (
- namespace = "gitlab_shell"
- sshdSubsystem = "sshd"
-)
-
-func secondsDurationBuckets() []float64 {
- return []float64{
- 0.005, /* 5ms */
- 0.025, /* 25ms */
- 0.1, /* 100ms */
- 0.5, /* 500ms */
- 1.0, /* 1s */
- 10.0, /* 10s */
- 30.0, /* 30s */
- 60.0, /* 1m */
- 300.0, /* 10m */
- }
-}
-
-var (
- sshdConnectionDuration = promauto.NewHistogram(
- prometheus.HistogramOpts{
- Namespace: namespace,
- Subsystem: sshdSubsystem,
- Name: "connection_duration_seconds",
- Help: "A histogram of latencies for connections to gitlab-shell sshd.",
- Buckets: secondsDurationBuckets(),
- },
- )
-
- sshdHitMaxSessions = promauto.NewCounter(
- prometheus.CounterOpts{
- Namespace: namespace,
- Subsystem: sshdSubsystem,
- Name: "concurrent_limited_sessions_total",
- Help: "The number of times the concurrent sessions limit was hit in gitlab-shell sshd.",
- },
- )
)
func Run(cfg *config.Config) error {
@@ -85,7 +38,7 @@ func Run(cfg *config.Config) error {
log.Infof("Listening on %v", sshListener.Addr().String())
- config := &ssh.ServerConfig{
+ sshCfg := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if conn.User() != cfg.User {
return nil, errors.New("unknown user")
@@ -122,7 +75,7 @@ func Run(cfg *config.Config) error {
continue
}
loadedHostKeys++
- config.AddHostKey(key)
+ sshCfg.AddHostKey(key)
}
if loadedHostKeys == 0 {
return fmt.Errorf("No host keys could be loaded, aborting")
@@ -135,144 +88,33 @@ func Run(cfg *config.Config) error {
continue
}
- go handleConn(nconn, config, cfg)
- }
-}
-
-type execRequest struct {
- Command string
-}
-
-type exitStatusReq struct {
- ExitStatus uint32
-}
-
-type envRequest struct {
- Name string
- Value string
-}
-
-func exitSession(ch ssh.Channel, exitStatus uint32) {
- exitStatusReq := exitStatusReq{
- ExitStatus: exitStatus,
+ go handleConn(cfg, sshCfg, nconn)
}
- ch.CloseWrite()
- ch.SendRequest("exit-status", false, ssh.Marshal(exitStatusReq))
- ch.Close()
}
-func handleConn(nconn net.Conn, sshCfg *ssh.ServerConfig, cfg *config.Config) {
- begin := time.Now()
- defer func() {
- sshdConnectionDuration.Observe(time.Since(begin).Seconds())
- }()
+func handleConn(cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) {
+ defer nconn.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- defer nconn.Close()
- conn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg)
+
+ sconn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg)
if err != nil {
log.Infof("Failed to initialize SSH connection: %v", err)
return
}
- concurrentSessions := semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit)
-
go ssh.DiscardRequests(reqs)
- for newChannel := range chans {
- if newChannel.ChannelType() != "session" {
- newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
- continue
- }
- if !concurrentSessions.TryAcquire(1) {
- newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
- sshdHitMaxSessions.Inc()
- continue
- }
- ch, requests, err := newChannel.Accept()
- if err != nil {
- log.Infof("Could not accept channel: %v", err)
- concurrentSessions.Release(1)
- continue
- }
-
- go handleSession(ctx, concurrentSessions, ch, requests, conn, nconn, cfg)
- }
-}
-
-func handleSession(ctx context.Context, concurrentSessions *semaphore.Weighted, ch ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn, nconn net.Conn, cfg *config.Config) {
- defer concurrentSessions.Release(1)
-
- rw := &readwriter.ReadWriter{
- Out: ch,
- In: ch,
- ErrOut: ch.Stderr(),
- }
- var gitProtocolVersion string
-
- for req := range requests {
- var execCmd string
- switch req.Type {
- case "env":
- var envRequest envRequest
- if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil {
- ch.Close()
- return
- }
- var accepted bool
- if envRequest.Name == sshenv.GitProtocolEnv {
- gitProtocolVersion = envRequest.Value
- accepted = true
- }
- if req.WantReply {
- req.Reply(accepted, []byte{})
- }
- case "exec":
- var execRequest execRequest
- if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil {
- ch.Close()
- return
- }
- execCmd = execRequest.Command
- fallthrough
- case "shell":
- if req.WantReply {
- req.Reply(true, []byte{})
- }
- args := &commandargs.Shell{
- GitlabKeyId: conn.Permissions.Extensions["key-id"],
- Env: sshenv.Env{
- IsSSHConnection: true,
- OriginalCommand: execCmd,
- GitProtocolVersion: gitProtocolVersion,
- RemoteAddr: nconn.RemoteAddr().(*net.TCPAddr).String(),
- },
- }
-
- if err := args.ParseCommand(execCmd); err != nil {
- fmt.Fprintf(ch.Stderr(), "Failed to parse command: %v\n", err.Error())
- exitSession(ch, 128)
- return
- }
-
- cmd := command.BuildShellCommand(args, cfg, rw)
- if cmd == nil {
- fmt.Fprintf(ch.Stderr(), "Unknown command: %v\n", args.CommandType)
- exitSession(ch, 128)
- return
- }
- if err := cmd.Execute(ctx); err != nil {
- fmt.Fprintf(ch.Stderr(), "remote: ERROR: %v\n", err.Error())
- exitSession(ch, 1)
- return
- }
- exitSession(ch, 0)
- return
- default:
- if req.WantReply {
- req.Reply(false, []byte{})
- }
+ conn := newConnection(cfg.Server.ConcurrentSessionsLimit)
+ conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
+ session := &session{
+ cfg: cfg,
+ channel: channel,
+ gitlabKeyId: sconn.Permissions.Extensions["key-id"],
+ remoteAddr: nconn.RemoteAddr().(*net.TCPAddr).String(),
}
- }
+
+ session.handle(ctx, requests)
+ })
}