summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIgor Drozdov <idrozdov@gitlab.com>2022-05-11 20:25:57 +0400
committerIgor Drozdov <idrozdov@gitlab.com>2022-05-12 09:53:48 +0400
commita16dcb3e6ca3361ba23fabb369dc6566e693ba9d (patch)
tree4f861f7a8cc32105660feddefc8e9623ebd11028
parent42cf058b7292527b250d48167b2db5ec85736f37 (diff)
downloadgitlab-shell-a16dcb3e6ca3361ba23fabb369dc6566e693ba9d.tar.gz
Implement ClientKeepAlive option
Git clients sometimes open a connection and leave it idling, like when compressing objects. Settings like timeout client in HAProxy might cause these idle connections to be terminated. Let's send the keepalive message in order to prevent a client from closing
-rw-r--r--config.yml.example2
-rw-r--r--internal/config/config.go36
-rw-r--r--internal/sshd/connection.go32
-rw-r--r--internal/sshd/connection_test.go43
-rw-r--r--internal/sshd/sshd.go2
-rw-r--r--internal/sshd/sshd_test.go1
6 files changed, 97 insertions, 19 deletions
diff --git a/config.yml.example b/config.yml.example
index f23c5d1..a453c0a 100644
--- a/config.yml.example
+++ b/config.yml.example
@@ -76,6 +76,8 @@ sshd:
web_listen: "localhost:9122"
# Maximum number of concurrent sessions allowed on a single SSH connection. Defaults to 10.
concurrent_sessions_limit: 10
+ # Sets an interval after which server will send keepalive message to a client
+ client_alive_interval: 15
# The server waits for this time (in seconds) for the ongoing connections to complete before shutting down. Defaults to 10.
grace_period: 10
# The endpoint that returns 200 OK if the server is ready to receive incoming connections; otherwise, it returns 503 Service Unavailable. Defaults to "/start".
diff --git a/internal/config/config.go b/internal/config/config.go
index ab88d72..b45082e 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -22,15 +22,16 @@ const (
)
type ServerConfig struct {
- Listen string `yaml:"listen,omitempty"`
- ProxyProtocol bool `yaml:"proxy_protocol,omitempty"`
- ProxyPolicy string `yaml:"proxy_policy,omitempty"`
- WebListen string `yaml:"web_listen,omitempty"`
- ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"`
- GracePeriodSeconds uint64 `yaml:"grace_period"`
- ReadinessProbe string `yaml:"readiness_probe"`
- LivenessProbe string `yaml:"liveness_probe"`
- HostKeyFiles []string `yaml:"host_key_files,omitempty"`
+ Listen string `yaml:"listen,omitempty"`
+ ProxyProtocol bool `yaml:"proxy_protocol,omitempty"`
+ ProxyPolicy string `yaml:"proxy_policy,omitempty"`
+ WebListen string `yaml:"web_listen,omitempty"`
+ ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"`
+ ClientAliveIntervalSeconds int64 `yaml:"client_alive_interval,omitempty"`
+ GracePeriodSeconds uint64 `yaml:"grace_period"`
+ ReadinessProbe string `yaml:"readiness_probe"`
+ LivenessProbe string `yaml:"liveness_probe"`
+ HostKeyFiles []string `yaml:"host_key_files,omitempty"`
}
type HttpSettingsConfig struct {
@@ -75,12 +76,13 @@ var (
}
DefaultServerConfig = ServerConfig{
- Listen: "[::]:22",
- WebListen: "localhost:9122",
- ConcurrentSessionsLimit: 10,
- GracePeriodSeconds: 10,
- ReadinessProbe: "/start",
- LivenessProbe: "/health",
+ Listen: "[::]:22",
+ WebListen: "localhost:9122",
+ ConcurrentSessionsLimit: 10,
+ GracePeriodSeconds: 10,
+ ClientAliveIntervalSeconds: 15,
+ ReadinessProbe: "/start",
+ LivenessProbe: "/health",
HostKeyFiles: []string{
"/run/secrets/ssh-hostkeys/ssh_host_rsa_key",
"/run/secrets/ssh-hostkeys/ssh_host_ecdsa_key",
@@ -89,6 +91,10 @@ var (
}
)
+func (sc *ServerConfig) ClientAliveInterval() time.Duration {
+ return time.Duration(sc.ClientAliveIntervalSeconds) * time.Second
+}
+
func (sc *ServerConfig) GracePeriod() time.Duration {
return time.Duration(sc.GracePeriodSeconds) * time.Second
}
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
index 278bb37..5b1232d 100644
--- a/internal/sshd/connection.go
+++ b/internal/sshd/connection.go
@@ -7,28 +7,41 @@ import (
"golang.org/x/crypto/ssh"
"golang.org/x/sync/semaphore"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/metrics"
"gitlab.com/gitlab-org/labkit/log"
)
+const KeepAliveMsg = "keepalive@openssh.com"
+
type connection struct {
+ cfg *config.Config
concurrentSessions *semaphore.Weighted
remoteAddr string
+ sconn *ssh.ServerConn
}
type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request)
-func newConnection(maxSessions int64, remoteAddr string) *connection {
+func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection {
return &connection{
- concurrentSessions: semaphore.NewWeighted(maxSessions),
+ cfg: cfg,
+ concurrentSessions: semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit),
remoteAddr: remoteAddr,
+ sconn: sconn,
}
}
func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
+ if c.cfg.Server.ClientAliveIntervalSeconds > 0 {
+ ticker := time.NewTicker(c.cfg.Server.ClientAliveInterval())
+ defer ticker.Stop()
+ go c.sendKeepAliveMsg(ctx, ticker)
+ }
+
for newChannel := range chans {
ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested")
if newChannel.ChannelType() != "session" {
@@ -68,3 +81,18 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
}()
}
}
+
+func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) {
+ ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-ticker.C:
+ ctxlog.Debug("session: handleShell: send keepalive message to a client")
+
+ c.sconn.SendRequest(KeepAliveMsg, true, nil)
+ }
+ }
+}
diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go
index d6bd3c0..3bd9bf8 100644
--- a/internal/sshd/connection_test.go
+++ b/internal/sshd/connection_test.go
@@ -3,10 +3,14 @@ package sshd
import (
"context"
"errors"
+ "sync"
"testing"
+ "time"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
)
type rejectCall struct {
@@ -47,8 +51,32 @@ func (f *fakeNewChannel) ExtraData() []byte {
return f.extraData
}
+type fakeConn struct {
+ ssh.Conn
+
+ sentRequestName string
+ mu sync.Mutex
+}
+
+func (f *fakeConn) SentRequestName() string {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ return f.sentRequestName
+}
+
+func (f *fakeConn) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.sentRequestName = name
+
+ return true, nil, nil
+}
+
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) {
- conn := newConnection(sessionsNum, "127.0.0.1:50000")
+ cfg := &config.Config{Server: config.ServerConfig{ConcurrentSessionsLimit: sessionsNum, ClientAliveIntervalSeconds: 1}}
+ conn := newConnection(cfg, "127.0.0.1:50000", &ssh.ServerConn{&fakeConn{}, nil})
chans := make(chan ssh.NewChannel, 1)
chans <- newChannel
@@ -145,3 +173,16 @@ func TestAcceptSessionFails(t *testing.T) {
require.False(t, channelHandled)
}
+
+func TestClientAliveInterval(t *testing.T) {
+ f := &fakeConn{}
+
+ conn := newConnection(&config.Config{}, "127.0.0.1:50000", &ssh.ServerConn{f, nil})
+
+ ticker := time.NewTicker(time.Millisecond)
+ defer ticker.Stop()
+
+ go conn.sendKeepAliveMsg(context.Background(), ticker)
+
+ require.Eventually(t, func() bool { return KeepAliveMsg == f.SentRequestName() }, time.Second, time.Millisecond)
+}
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index 99fa3c9..242e4f2 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -180,7 +180,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
started := time.Now()
var establishSessionDuration float64
- conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, remoteAddr)
+ conn := newConnection(s.Config, remoteAddr, sconn)
conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
establishSessionDuration = time.Since(started).Seconds()
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index d725add..80495f6 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -265,6 +265,7 @@ func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server {
cfg.User = user
cfg.Server.Listen = serverUrl
cfg.Server.ConcurrentSessionsLimit = 1
+ cfg.Server.ClientAliveIntervalSeconds = 15
cfg.Server.HostKeyFiles = []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")}
s, err := NewServer(cfg)