summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Thomas <nick@gitlab.com>2021-07-15 12:44:48 +0000
committerNick Thomas <nick@gitlab.com>2021-07-15 12:44:48 +0000
commit60a91e977cb76e2cc49faeeac134640dbf2e2e6c (patch)
treeb693a2244f3d715d48df83eac70bdf6630e51d3a
parentd3711d8d7e781dbff01d8ae5c7a1d5b800c5c8a2 (diff)
parent569a0197cacc75270776217c27e9d709907a9dfa (diff)
downloadgitlab-shell-60a91e977cb76e2cc49faeeac134640dbf2e2e6c.tar.gz
Merge branch 'id-cancelable-sshd' into 'main'
Shutdown sshd gracefully See merge request gitlab-org/gitlab-shell!484
-rw-r--r--cmd/gitlab-sshd/main.go28
-rw-r--r--internal/config/config.go7
-rw-r--r--internal/sshd/sshd.go98
-rw-r--r--internal/sshd/sshd_test.go49
4 files changed, 159 insertions, 23 deletions
diff --git a/cmd/gitlab-sshd/main.go b/cmd/gitlab-sshd/main.go
index 866bc8d..7cecbf5 100644
--- a/cmd/gitlab-sshd/main.go
+++ b/cmd/gitlab-sshd/main.go
@@ -3,6 +3,10 @@ package main
import (
"flag"
"os"
+ "os/signal"
+ "context"
+ "syscall"
+ "time"
log "github.com/sirupsen/logrus"
@@ -63,6 +67,8 @@ func main() {
ctx, finished := command.Setup("gitlab-sshd", cfg)
defer finished()
+ server := sshd.Server{Config: cfg}
+
// Startup monitoring endpoint.
if cfg.Server.WebListen != "" {
go func() {
@@ -75,7 +81,27 @@ func main() {
}()
}
- if err := sshd.Run(ctx, cfg); err != nil {
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ done := make(chan os.Signal, 1)
+ signal.Notify(done, syscall.SIGINT, syscall.SIGTERM)
+
+ go func() {
+ sig := <-done
+ signal.Reset(syscall.SIGINT, syscall.SIGTERM)
+
+ log.WithFields(log.Fields{"shutdown_timeout_s": cfg.Server.GracePeriodSeconds, "signal": sig.String()}).Infof("Shutdown initiated")
+
+ server.Shutdown()
+
+ <-time.After(cfg.Server.GracePeriod())
+
+ cancel()
+
+ }()
+
+ if err := server.ListenAndServe(ctx); err != nil {
log.Fatalf("Failed to start GitLab built-in sshd: %v", err)
}
}
diff --git a/internal/config/config.go b/internal/config/config.go
index 23044cd..c58ea7d 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -8,6 +8,7 @@ import (
"path"
"path/filepath"
"sync"
+ "time"
"gitlab.com/gitlab-org/labkit/tracing"
yaml "gopkg.in/yaml.v2"
@@ -25,6 +26,7 @@ type ServerConfig struct {
ProxyProtocol bool `yaml:"proxy_protocol,omitempty"`
WebListen string `yaml:"web_listen,omitempty"`
ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"`
+ GracePeriodSeconds uint64 `yaml:"grace_period"`
HostKeyFiles []string `yaml:"host_key_files,omitempty"`
}
@@ -69,6 +71,7 @@ var (
Listen: "[::]:22",
WebListen: "localhost:9122",
ConcurrentSessionsLimit: 10,
+ GracePeriodSeconds: 10,
HostKeyFiles: []string{
"/run/secrets/ssh-hostkeys/ssh_host_rsa_key",
"/run/secrets/ssh-hostkeys/ssh_host_ecdsa_key",
@@ -77,6 +80,10 @@ var (
}
)
+func (sc *ServerConfig) GracePeriod() time.Duration {
+ return time.Duration(sc.GracePeriodSeconds) * time.Second
+}
+
func (c *Config) ApplyGlobalState() {
if c.SslCertDir != "" {
os.Setenv("SSL_CERT_DIR", c.SslCertDir)
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index b04366e..ef401dc 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -9,6 +9,7 @@ import (
"net"
"strconv"
"time"
+ "sync"
log "github.com/sirupsen/logrus"
@@ -20,28 +21,87 @@ import (
"gitlab.com/gitlab-org/labkit/correlation"
)
-func Run(ctx context.Context, cfg *config.Config) error {
- authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
- if err != nil {
- return fmt.Errorf("failed to initialize GitLab client: %w", err)
+type Server struct {
+ Config *config.Config
+
+ onShutdown bool
+ wg sync.WaitGroup
+ listener net.Listener
+}
+
+func (s *Server) ListenAndServe(ctx context.Context) error {
+ if err := s.listen(); err != nil {
+ return err
}
+ defer s.listener.Close()
+
+ return s.serve(ctx)
+}
+
+func (s *Server) Shutdown() error {
+ if s.listener == nil {
+ return nil
+ }
+
+ s.onShutdown = true
+
+ return s.listener.Close()
+}
- sshListener, err := net.Listen("tcp", cfg.Server.Listen)
+func (s *Server) listen() error {
+ sshListener, err := net.Listen("tcp", s.Config.Server.Listen)
if err != nil {
return fmt.Errorf("failed to listen for connection: %w", err)
}
- if cfg.Server.ProxyProtocol {
+
+ if s.Config.Server.ProxyProtocol {
sshListener = &proxyproto.Listener{Listener: sshListener}
log.Info("Proxy protocol is enabled")
}
- defer sshListener.Close()
log.Infof("Listening on %v", sshListener.Addr().String())
+ s.listener = sshListener
+
+ return nil
+}
+
+func (s *Server) serve(ctx context.Context) error {
+ sshCfg, err := s.initConfig(ctx)
+ if err != nil {
+ return err
+ }
+
+ for {
+ nconn, err := s.listener.Accept()
+ if err != nil {
+ if s.onShutdown {
+ break
+ }
+
+ log.Warnf("Failed to accept connection: %v\n", err)
+ continue
+ }
+
+ s.wg.Add(1)
+ go s.handleConn(ctx, sshCfg, nconn)
+ }
+
+ s.wg.Wait()
+
+ return nil
+}
+
+func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) {
+ authorizedKeysClient, err := authorizedkeys.NewClient(s.Config)
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ }
+
sshCfg := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
- if conn.User() != cfg.User {
+ if conn.User() != s.Config.User {
return nil, errors.New("unknown user")
}
if key.Type() == ssh.KeyAlgoDSA {
@@ -64,7 +124,7 @@ func Run(ctx context.Context, cfg *config.Config) error {
}
var loadedHostKeys uint
- for _, filename := range cfg.Server.HostKeyFiles {
+ for _, filename := range s.Config.Server.HostKeyFiles {
keyRaw, err := ioutil.ReadFile(filename)
if err != nil {
log.Warnf("Failed to read host key %v: %v", filename, err)
@@ -79,23 +139,17 @@ func Run(ctx context.Context, cfg *config.Config) error {
sshCfg.AddHostKey(key)
}
if loadedHostKeys == 0 {
- return fmt.Errorf("No host keys could be loaded, aborting")
+ return nil, fmt.Errorf("No host keys could be loaded, aborting")
}
- for {
- nconn, err := sshListener.Accept()
- if err != nil {
- log.Warnf("Failed to accept connection: %v\n", err)
- continue
- }
-
- go handleConn(ctx, cfg, sshCfg, nconn)
- }
+ return sshCfg, nil
}
-func handleConn(ctx context.Context, cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) {
+
+func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn net.Conn) {
remoteAddr := nconn.RemoteAddr().String()
+ defer s.wg.Done()
defer nconn.Close()
// Prevent a panic in a single connection from taking out the whole server
@@ -116,10 +170,10 @@ func handleConn(ctx context.Context, cfg *config.Config, sshCfg *ssh.ServerConfi
go ssh.DiscardRequests(reqs)
- conn := newConnection(cfg.Server.ConcurrentSessionsLimit, remoteAddr)
+ conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, remoteAddr)
conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
session := &session{
- cfg: cfg,
+ cfg: s.Config,
channel: channel,
gitlabKeyId: sconn.Permissions.Extensions["key-id"],
remoteAddr: remoteAddr,
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
new file mode 100644
index 0000000..d1891ec
--- /dev/null
+++ b/internal/sshd/sshd_test.go
@@ -0,0 +1,49 @@
+package sshd
+
+import (
+ "testing"
+ "context"
+ "path"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/client/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+const serverUrl = "127.0.0.1:50000"
+
+func TestShutdown(t *testing.T) {
+ s := setupServer(t)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ done := make(chan bool, 1)
+ go func() {
+ require.NoError(t, s.serve(ctx))
+ done <- true
+ }()
+
+ require.NoError(t, s.Shutdown())
+
+ require.True(t, <-done, "the accepting loop must be interrupted")
+}
+
+func setupServer(t *testing.T) *Server {
+ testhelper.PrepareTestRootDir(t)
+
+ url := testserver.StartSocketHttpServer(t, []testserver.TestRequestHandler{})
+ srvCfg := config.ServerConfig{
+ Listen: serverUrl,
+ HostKeyFiles: []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")},
+ }
+
+ cfg := &config.Config{RootDir: "/tmp", GitlabUrl: url, Server: srvCfg}
+
+ s := &Server{Config: cfg}
+ require.NoError(t, s.listen())
+
+ return s
+}