summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIgor Drozdov <idrozdov@gitlab.com>2021-07-09 14:41:41 +0300
committerIgor Drozdov <idrozdov@gitlab.com>2021-07-15 14:39:33 +0300
commit569a0197cacc75270776217c27e9d709907a9dfa (patch)
treeb693a2244f3d715d48df83eac70bdf6630e51d3a
parentd3711d8d7e781dbff01d8ae5c7a1d5b800c5c8a2 (diff)
downloadgitlab-shell-569a0197cacc75270776217c27e9d709907a9dfa.tar.gz
Shutdown sshd gracefully
When interruption signal is sent, we are closing ssh listener to prevent it from accepting new connections Then after configured grace period, we cancel the context to cancel all ongoing operations
-rw-r--r--cmd/gitlab-sshd/main.go28
-rw-r--r--internal/config/config.go7
-rw-r--r--internal/sshd/sshd.go98
-rw-r--r--internal/sshd/sshd_test.go49
4 files changed, 159 insertions, 23 deletions
diff --git a/cmd/gitlab-sshd/main.go b/cmd/gitlab-sshd/main.go
index 866bc8d..7cecbf5 100644
--- a/cmd/gitlab-sshd/main.go
+++ b/cmd/gitlab-sshd/main.go
@@ -3,6 +3,10 @@ package main
import (
"flag"
"os"
+ "os/signal"
+ "context"
+ "syscall"
+ "time"
log "github.com/sirupsen/logrus"
@@ -63,6 +67,8 @@ func main() {
ctx, finished := command.Setup("gitlab-sshd", cfg)
defer finished()
+ server := sshd.Server{Config: cfg}
+
// Startup monitoring endpoint.
if cfg.Server.WebListen != "" {
go func() {
@@ -75,7 +81,27 @@ func main() {
}()
}
- if err := sshd.Run(ctx, cfg); err != nil {
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ done := make(chan os.Signal, 1)
+ signal.Notify(done, syscall.SIGINT, syscall.SIGTERM)
+
+ go func() {
+ sig := <-done
+ signal.Reset(syscall.SIGINT, syscall.SIGTERM)
+
+ log.WithFields(log.Fields{"shutdown_timeout_s": cfg.Server.GracePeriodSeconds, "signal": sig.String()}).Infof("Shutdown initiated")
+
+ server.Shutdown()
+
+ <-time.After(cfg.Server.GracePeriod())
+
+ cancel()
+
+ }()
+
+ if err := server.ListenAndServe(ctx); err != nil {
log.Fatalf("Failed to start GitLab built-in sshd: %v", err)
}
}
diff --git a/internal/config/config.go b/internal/config/config.go
index 23044cd..c58ea7d 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -8,6 +8,7 @@ import (
"path"
"path/filepath"
"sync"
+ "time"
"gitlab.com/gitlab-org/labkit/tracing"
yaml "gopkg.in/yaml.v2"
@@ -25,6 +26,7 @@ type ServerConfig struct {
ProxyProtocol bool `yaml:"proxy_protocol,omitempty"`
WebListen string `yaml:"web_listen,omitempty"`
ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"`
+ GracePeriodSeconds uint64 `yaml:"grace_period"`
HostKeyFiles []string `yaml:"host_key_files,omitempty"`
}
@@ -69,6 +71,7 @@ var (
Listen: "[::]:22",
WebListen: "localhost:9122",
ConcurrentSessionsLimit: 10,
+ GracePeriodSeconds: 10,
HostKeyFiles: []string{
"/run/secrets/ssh-hostkeys/ssh_host_rsa_key",
"/run/secrets/ssh-hostkeys/ssh_host_ecdsa_key",
@@ -77,6 +80,10 @@ var (
}
)
+func (sc *ServerConfig) GracePeriod() time.Duration {
+ return time.Duration(sc.GracePeriodSeconds) * time.Second
+}
+
func (c *Config) ApplyGlobalState() {
if c.SslCertDir != "" {
os.Setenv("SSL_CERT_DIR", c.SslCertDir)
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index b04366e..ef401dc 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -9,6 +9,7 @@ import (
"net"
"strconv"
"time"
+ "sync"
log "github.com/sirupsen/logrus"
@@ -20,28 +21,87 @@ import (
"gitlab.com/gitlab-org/labkit/correlation"
)
-func Run(ctx context.Context, cfg *config.Config) error {
- authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
- if err != nil {
- return fmt.Errorf("failed to initialize GitLab client: %w", err)
+type Server struct {
+ Config *config.Config
+
+ onShutdown bool
+ wg sync.WaitGroup
+ listener net.Listener
+}
+
+func (s *Server) ListenAndServe(ctx context.Context) error {
+ if err := s.listen(); err != nil {
+ return err
}
+ defer s.listener.Close()
+
+ return s.serve(ctx)
+}
+
+func (s *Server) Shutdown() error {
+ if s.listener == nil {
+ return nil
+ }
+
+ s.onShutdown = true
+
+ return s.listener.Close()
+}
- sshListener, err := net.Listen("tcp", cfg.Server.Listen)
+func (s *Server) listen() error {
+ sshListener, err := net.Listen("tcp", s.Config.Server.Listen)
if err != nil {
return fmt.Errorf("failed to listen for connection: %w", err)
}
- if cfg.Server.ProxyProtocol {
+
+ if s.Config.Server.ProxyProtocol {
sshListener = &proxyproto.Listener{Listener: sshListener}
log.Info("Proxy protocol is enabled")
}
- defer sshListener.Close()
log.Infof("Listening on %v", sshListener.Addr().String())
+ s.listener = sshListener
+
+ return nil
+}
+
+func (s *Server) serve(ctx context.Context) error {
+ sshCfg, err := s.initConfig(ctx)
+ if err != nil {
+ return err
+ }
+
+ for {
+ nconn, err := s.listener.Accept()
+ if err != nil {
+ if s.onShutdown {
+ break
+ }
+
+ log.Warnf("Failed to accept connection: %v\n", err)
+ continue
+ }
+
+ s.wg.Add(1)
+ go s.handleConn(ctx, sshCfg, nconn)
+ }
+
+ s.wg.Wait()
+
+ return nil
+}
+
+func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) {
+ authorizedKeysClient, err := authorizedkeys.NewClient(s.Config)
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ }
+
sshCfg := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
- if conn.User() != cfg.User {
+ if conn.User() != s.Config.User {
return nil, errors.New("unknown user")
}
if key.Type() == ssh.KeyAlgoDSA {
@@ -64,7 +124,7 @@ func Run(ctx context.Context, cfg *config.Config) error {
}
var loadedHostKeys uint
- for _, filename := range cfg.Server.HostKeyFiles {
+ for _, filename := range s.Config.Server.HostKeyFiles {
keyRaw, err := ioutil.ReadFile(filename)
if err != nil {
log.Warnf("Failed to read host key %v: %v", filename, err)
@@ -79,23 +139,17 @@ func Run(ctx context.Context, cfg *config.Config) error {
sshCfg.AddHostKey(key)
}
if loadedHostKeys == 0 {
- return fmt.Errorf("No host keys could be loaded, aborting")
+ return nil, fmt.Errorf("No host keys could be loaded, aborting")
}
- for {
- nconn, err := sshListener.Accept()
- if err != nil {
- log.Warnf("Failed to accept connection: %v\n", err)
- continue
- }
-
- go handleConn(ctx, cfg, sshCfg, nconn)
- }
+ return sshCfg, nil
}
-func handleConn(ctx context.Context, cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) {
+
+func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn net.Conn) {
remoteAddr := nconn.RemoteAddr().String()
+ defer s.wg.Done()
defer nconn.Close()
// Prevent a panic in a single connection from taking out the whole server
@@ -116,10 +170,10 @@ func handleConn(ctx context.Context, cfg *config.Config, sshCfg *ssh.ServerConfi
go ssh.DiscardRequests(reqs)
- conn := newConnection(cfg.Server.ConcurrentSessionsLimit, remoteAddr)
+ conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, remoteAddr)
conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
session := &session{
- cfg: cfg,
+ cfg: s.Config,
channel: channel,
gitlabKeyId: sconn.Permissions.Extensions["key-id"],
remoteAddr: remoteAddr,
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
new file mode 100644
index 0000000..d1891ec
--- /dev/null
+++ b/internal/sshd/sshd_test.go
@@ -0,0 +1,49 @@
+package sshd
+
+import (
+ "testing"
+ "context"
+ "path"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/client/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+const serverUrl = "127.0.0.1:50000"
+
+func TestShutdown(t *testing.T) {
+ s := setupServer(t)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ done := make(chan bool, 1)
+ go func() {
+ require.NoError(t, s.serve(ctx))
+ done <- true
+ }()
+
+ require.NoError(t, s.Shutdown())
+
+ require.True(t, <-done, "the accepting loop must be interrupted")
+}
+
+func setupServer(t *testing.T) *Server {
+ testhelper.PrepareTestRootDir(t)
+
+ url := testserver.StartSocketHttpServer(t, []testserver.TestRequestHandler{})
+ srvCfg := config.ServerConfig{
+ Listen: serverUrl,
+ HostKeyFiles: []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")},
+ }
+
+ cfg := &config.Config{RootDir: "/tmp", GitlabUrl: url, Server: srvCfg}
+
+ s := &Server{Config: cfg}
+ require.NoError(t, s.listen())
+
+ return s
+}