diff options
author | Igor Drozdov <idrozdov@gitlab.com> | 2021-07-09 14:41:41 +0300 |
---|---|---|
committer | Igor Drozdov <idrozdov@gitlab.com> | 2021-07-15 14:39:33 +0300 |
commit | 569a0197cacc75270776217c27e9d709907a9dfa (patch) | |
tree | b693a2244f3d715d48df83eac70bdf6630e51d3a | |
parent | d3711d8d7e781dbff01d8ae5c7a1d5b800c5c8a2 (diff) | |
download | gitlab-shell-569a0197cacc75270776217c27e9d709907a9dfa.tar.gz |
Shutdown sshd gracefully
When interruption signal is sent, we are closing ssh listener to
prevent it from accepting new connections
Then after configured grace period, we cancel the context to
cancel all ongoing operations
-rw-r--r-- | cmd/gitlab-sshd/main.go | 28 | ||||
-rw-r--r-- | internal/config/config.go | 7 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 98 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 49 |
4 files changed, 159 insertions, 23 deletions
diff --git a/cmd/gitlab-sshd/main.go b/cmd/gitlab-sshd/main.go index 866bc8d..7cecbf5 100644 --- a/cmd/gitlab-sshd/main.go +++ b/cmd/gitlab-sshd/main.go @@ -3,6 +3,10 @@ package main import ( "flag" "os" + "os/signal" + "context" + "syscall" + "time" log "github.com/sirupsen/logrus" @@ -63,6 +67,8 @@ func main() { ctx, finished := command.Setup("gitlab-sshd", cfg) defer finished() + server := sshd.Server{Config: cfg} + // Startup monitoring endpoint. if cfg.Server.WebListen != "" { go func() { @@ -75,7 +81,27 @@ func main() { }() } - if err := sshd.Run(ctx, cfg); err != nil { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + done := make(chan os.Signal, 1) + signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-done + signal.Reset(syscall.SIGINT, syscall.SIGTERM) + + log.WithFields(log.Fields{"shutdown_timeout_s": cfg.Server.GracePeriodSeconds, "signal": sig.String()}).Infof("Shutdown initiated") + + server.Shutdown() + + <-time.After(cfg.Server.GracePeriod()) + + cancel() + + }() + + if err := server.ListenAndServe(ctx); err != nil { log.Fatalf("Failed to start GitLab built-in sshd: %v", err) } } diff --git a/internal/config/config.go b/internal/config/config.go index 23044cd..c58ea7d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,6 +8,7 @@ import ( "path" "path/filepath" "sync" + "time" "gitlab.com/gitlab-org/labkit/tracing" yaml "gopkg.in/yaml.v2" @@ -25,6 +26,7 @@ type ServerConfig struct { ProxyProtocol bool `yaml:"proxy_protocol,omitempty"` WebListen string `yaml:"web_listen,omitempty"` ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"` + GracePeriodSeconds uint64 `yaml:"grace_period"` HostKeyFiles []string `yaml:"host_key_files,omitempty"` } @@ -69,6 +71,7 @@ var ( Listen: "[::]:22", WebListen: "localhost:9122", ConcurrentSessionsLimit: 10, + GracePeriodSeconds: 10, HostKeyFiles: []string{ "/run/secrets/ssh-hostkeys/ssh_host_rsa_key", "/run/secrets/ssh-hostkeys/ssh_host_ecdsa_key", @@ -77,6 +80,10 @@ var ( } ) +func (sc *ServerConfig) GracePeriod() time.Duration { + return time.Duration(sc.GracePeriodSeconds) * time.Second +} + func (c *Config) ApplyGlobalState() { if c.SslCertDir != "" { os.Setenv("SSL_CERT_DIR", c.SslCertDir) diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index b04366e..ef401dc 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -9,6 +9,7 @@ import ( "net" "strconv" "time" + "sync" log "github.com/sirupsen/logrus" @@ -20,28 +21,87 @@ import ( "gitlab.com/gitlab-org/labkit/correlation" ) -func Run(ctx context.Context, cfg *config.Config) error { - authorizedKeysClient, err := authorizedkeys.NewClient(cfg) - if err != nil { - return fmt.Errorf("failed to initialize GitLab client: %w", err) +type Server struct { + Config *config.Config + + onShutdown bool + wg sync.WaitGroup + listener net.Listener +} + +func (s *Server) ListenAndServe(ctx context.Context) error { + if err := s.listen(); err != nil { + return err } + defer s.listener.Close() + + return s.serve(ctx) +} + +func (s *Server) Shutdown() error { + if s.listener == nil { + return nil + } + + s.onShutdown = true + + return s.listener.Close() +} - sshListener, err := net.Listen("tcp", cfg.Server.Listen) +func (s *Server) listen() error { + sshListener, err := net.Listen("tcp", s.Config.Server.Listen) if err != nil { return fmt.Errorf("failed to listen for connection: %w", err) } - if cfg.Server.ProxyProtocol { + + if s.Config.Server.ProxyProtocol { sshListener = &proxyproto.Listener{Listener: sshListener} log.Info("Proxy protocol is enabled") } - defer sshListener.Close() log.Infof("Listening on %v", sshListener.Addr().String()) + s.listener = sshListener + + return nil +} + +func (s *Server) serve(ctx context.Context) error { + sshCfg, err := s.initConfig(ctx) + if err != nil { + return err + } + + for { + nconn, err := s.listener.Accept() + if err != nil { + if s.onShutdown { + break + } + + log.Warnf("Failed to accept connection: %v\n", err) + continue + } + + s.wg.Add(1) + go s.handleConn(ctx, sshCfg, nconn) + } + + s.wg.Wait() + + return nil +} + +func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) { + authorizedKeysClient, err := authorizedkeys.NewClient(s.Config) + if err != nil { + return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) + } + sshCfg := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - if conn.User() != cfg.User { + if conn.User() != s.Config.User { return nil, errors.New("unknown user") } if key.Type() == ssh.KeyAlgoDSA { @@ -64,7 +124,7 @@ func Run(ctx context.Context, cfg *config.Config) error { } var loadedHostKeys uint - for _, filename := range cfg.Server.HostKeyFiles { + for _, filename := range s.Config.Server.HostKeyFiles { keyRaw, err := ioutil.ReadFile(filename) if err != nil { log.Warnf("Failed to read host key %v: %v", filename, err) @@ -79,23 +139,17 @@ func Run(ctx context.Context, cfg *config.Config) error { sshCfg.AddHostKey(key) } if loadedHostKeys == 0 { - return fmt.Errorf("No host keys could be loaded, aborting") + return nil, fmt.Errorf("No host keys could be loaded, aborting") } - for { - nconn, err := sshListener.Accept() - if err != nil { - log.Warnf("Failed to accept connection: %v\n", err) - continue - } - - go handleConn(ctx, cfg, sshCfg, nconn) - } + return sshCfg, nil } -func handleConn(ctx context.Context, cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) { + +func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn net.Conn) { remoteAddr := nconn.RemoteAddr().String() + defer s.wg.Done() defer nconn.Close() // Prevent a panic in a single connection from taking out the whole server @@ -116,10 +170,10 @@ func handleConn(ctx context.Context, cfg *config.Config, sshCfg *ssh.ServerConfi go ssh.DiscardRequests(reqs) - conn := newConnection(cfg.Server.ConcurrentSessionsLimit, remoteAddr) + conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, remoteAddr) conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) { session := &session{ - cfg: cfg, + cfg: s.Config, channel: channel, gitlabKeyId: sconn.Permissions.Extensions["key-id"], remoteAddr: remoteAddr, diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go new file mode 100644 index 0000000..d1891ec --- /dev/null +++ b/internal/sshd/sshd_test.go @@ -0,0 +1,49 @@ +package sshd + +import ( + "testing" + "context" + "path" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +const serverUrl = "127.0.0.1:50000" + +func TestShutdown(t *testing.T) { + s := setupServer(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan bool, 1) + go func() { + require.NoError(t, s.serve(ctx)) + done <- true + }() + + require.NoError(t, s.Shutdown()) + + require.True(t, <-done, "the accepting loop must be interrupted") +} + +func setupServer(t *testing.T) *Server { + testhelper.PrepareTestRootDir(t) + + url := testserver.StartSocketHttpServer(t, []testserver.TestRequestHandler{}) + srvCfg := config.ServerConfig{ + Listen: serverUrl, + HostKeyFiles: []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")}, + } + + cfg := &config.Config{RootDir: "/tmp", GitlabUrl: url, Server: srvCfg} + + s := &Server{Config: cfg} + require.NoError(t, s.listen()) + + return s +} |