diff options
author | Nick Thomas <nick@gitlab.com> | 2021-07-15 12:44:48 +0000 |
---|---|---|
committer | Nick Thomas <nick@gitlab.com> | 2021-07-15 12:44:48 +0000 |
commit | 60a91e977cb76e2cc49faeeac134640dbf2e2e6c (patch) | |
tree | b693a2244f3d715d48df83eac70bdf6630e51d3a | |
parent | d3711d8d7e781dbff01d8ae5c7a1d5b800c5c8a2 (diff) | |
parent | 569a0197cacc75270776217c27e9d709907a9dfa (diff) | |
download | gitlab-shell-60a91e977cb76e2cc49faeeac134640dbf2e2e6c.tar.gz |
Merge branch 'id-cancelable-sshd' into 'main'
Shutdown sshd gracefully
See merge request gitlab-org/gitlab-shell!484
-rw-r--r-- | cmd/gitlab-sshd/main.go | 28 | ||||
-rw-r--r-- | internal/config/config.go | 7 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 98 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 49 |
4 files changed, 159 insertions, 23 deletions
diff --git a/cmd/gitlab-sshd/main.go b/cmd/gitlab-sshd/main.go index 866bc8d..7cecbf5 100644 --- a/cmd/gitlab-sshd/main.go +++ b/cmd/gitlab-sshd/main.go @@ -3,6 +3,10 @@ package main import ( "flag" "os" + "os/signal" + "context" + "syscall" + "time" log "github.com/sirupsen/logrus" @@ -63,6 +67,8 @@ func main() { ctx, finished := command.Setup("gitlab-sshd", cfg) defer finished() + server := sshd.Server{Config: cfg} + // Startup monitoring endpoint. if cfg.Server.WebListen != "" { go func() { @@ -75,7 +81,27 @@ func main() { }() } - if err := sshd.Run(ctx, cfg); err != nil { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + done := make(chan os.Signal, 1) + signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-done + signal.Reset(syscall.SIGINT, syscall.SIGTERM) + + log.WithFields(log.Fields{"shutdown_timeout_s": cfg.Server.GracePeriodSeconds, "signal": sig.String()}).Infof("Shutdown initiated") + + server.Shutdown() + + <-time.After(cfg.Server.GracePeriod()) + + cancel() + + }() + + if err := server.ListenAndServe(ctx); err != nil { log.Fatalf("Failed to start GitLab built-in sshd: %v", err) } } diff --git a/internal/config/config.go b/internal/config/config.go index 23044cd..c58ea7d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,6 +8,7 @@ import ( "path" "path/filepath" "sync" + "time" "gitlab.com/gitlab-org/labkit/tracing" yaml "gopkg.in/yaml.v2" @@ -25,6 +26,7 @@ type ServerConfig struct { ProxyProtocol bool `yaml:"proxy_protocol,omitempty"` WebListen string `yaml:"web_listen,omitempty"` ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"` + GracePeriodSeconds uint64 `yaml:"grace_period"` HostKeyFiles []string `yaml:"host_key_files,omitempty"` } @@ -69,6 +71,7 @@ var ( Listen: "[::]:22", WebListen: "localhost:9122", ConcurrentSessionsLimit: 10, + GracePeriodSeconds: 10, HostKeyFiles: []string{ "/run/secrets/ssh-hostkeys/ssh_host_rsa_key", "/run/secrets/ssh-hostkeys/ssh_host_ecdsa_key", @@ -77,6 +80,10 @@ var ( } ) +func (sc *ServerConfig) GracePeriod() time.Duration { + return time.Duration(sc.GracePeriodSeconds) * time.Second +} + func (c *Config) ApplyGlobalState() { if c.SslCertDir != "" { os.Setenv("SSL_CERT_DIR", c.SslCertDir) diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index b04366e..ef401dc 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -9,6 +9,7 @@ import ( "net" "strconv" "time" + "sync" log "github.com/sirupsen/logrus" @@ -20,28 +21,87 @@ import ( "gitlab.com/gitlab-org/labkit/correlation" ) -func Run(ctx context.Context, cfg *config.Config) error { - authorizedKeysClient, err := authorizedkeys.NewClient(cfg) - if err != nil { - return fmt.Errorf("failed to initialize GitLab client: %w", err) +type Server struct { + Config *config.Config + + onShutdown bool + wg sync.WaitGroup + listener net.Listener +} + +func (s *Server) ListenAndServe(ctx context.Context) error { + if err := s.listen(); err != nil { + return err } + defer s.listener.Close() + + return s.serve(ctx) +} + +func (s *Server) Shutdown() error { + if s.listener == nil { + return nil + } + + s.onShutdown = true + + return s.listener.Close() +} - sshListener, err := net.Listen("tcp", cfg.Server.Listen) +func (s *Server) listen() error { + sshListener, err := net.Listen("tcp", s.Config.Server.Listen) if err != nil { return fmt.Errorf("failed to listen for connection: %w", err) } - if cfg.Server.ProxyProtocol { + + if s.Config.Server.ProxyProtocol { sshListener = &proxyproto.Listener{Listener: sshListener} log.Info("Proxy protocol is enabled") } - defer sshListener.Close() log.Infof("Listening on %v", sshListener.Addr().String()) + s.listener = sshListener + + return nil +} + +func (s *Server) serve(ctx context.Context) error { + sshCfg, err := s.initConfig(ctx) + if err != nil { + return err + } + + for { + nconn, err := s.listener.Accept() + if err != nil { + if s.onShutdown { + break + } + + log.Warnf("Failed to accept connection: %v\n", err) + continue + } + + s.wg.Add(1) + go s.handleConn(ctx, sshCfg, nconn) + } + + s.wg.Wait() + + return nil +} + +func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) { + authorizedKeysClient, err := authorizedkeys.NewClient(s.Config) + if err != nil { + return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) + } + sshCfg := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - if conn.User() != cfg.User { + if conn.User() != s.Config.User { return nil, errors.New("unknown user") } if key.Type() == ssh.KeyAlgoDSA { @@ -64,7 +124,7 @@ func Run(ctx context.Context, cfg *config.Config) error { } var loadedHostKeys uint - for _, filename := range cfg.Server.HostKeyFiles { + for _, filename := range s.Config.Server.HostKeyFiles { keyRaw, err := ioutil.ReadFile(filename) if err != nil { log.Warnf("Failed to read host key %v: %v", filename, err) @@ -79,23 +139,17 @@ func Run(ctx context.Context, cfg *config.Config) error { sshCfg.AddHostKey(key) } if loadedHostKeys == 0 { - return fmt.Errorf("No host keys could be loaded, aborting") + return nil, fmt.Errorf("No host keys could be loaded, aborting") } - for { - nconn, err := sshListener.Accept() - if err != nil { - log.Warnf("Failed to accept connection: %v\n", err) - continue - } - - go handleConn(ctx, cfg, sshCfg, nconn) - } + return sshCfg, nil } -func handleConn(ctx context.Context, cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) { + +func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn net.Conn) { remoteAddr := nconn.RemoteAddr().String() + defer s.wg.Done() defer nconn.Close() // Prevent a panic in a single connection from taking out the whole server @@ -116,10 +170,10 @@ func handleConn(ctx context.Context, cfg *config.Config, sshCfg *ssh.ServerConfi go ssh.DiscardRequests(reqs) - conn := newConnection(cfg.Server.ConcurrentSessionsLimit, remoteAddr) + conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, remoteAddr) conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) { session := &session{ - cfg: cfg, + cfg: s.Config, channel: channel, gitlabKeyId: sconn.Permissions.Extensions["key-id"], remoteAddr: remoteAddr, diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go new file mode 100644 index 0000000..d1891ec --- /dev/null +++ b/internal/sshd/sshd_test.go @@ -0,0 +1,49 @@ +package sshd + +import ( + "testing" + "context" + "path" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +const serverUrl = "127.0.0.1:50000" + +func TestShutdown(t *testing.T) { + s := setupServer(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan bool, 1) + go func() { + require.NoError(t, s.serve(ctx)) + done <- true + }() + + require.NoError(t, s.Shutdown()) + + require.True(t, <-done, "the accepting loop must be interrupted") +} + +func setupServer(t *testing.T) *Server { + testhelper.PrepareTestRootDir(t) + + url := testserver.StartSocketHttpServer(t, []testserver.TestRequestHandler{}) + srvCfg := config.ServerConfig{ + Listen: serverUrl, + HostKeyFiles: []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")}, + } + + cfg := &config.Config{RootDir: "/tmp", GitlabUrl: url, Server: srvCfg} + + s := &Server{Config: cfg} + require.NoError(t, s.listen()) + + return s +} |