summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStan Hu <stanhu@gmail.com>2022-05-23 17:00:22 +0000
committerStan Hu <stanhu@gmail.com>2022-05-23 17:00:22 +0000
commitc40ad688ed72357a58ba8481ba9382cabfc59375 (patch)
tree51d53adfefe6bb22d0741dabac8fc5c87f6f7d4e
parent4d2459f3b5af7de6684fdf1ea012b386ed17f424 (diff)
parent0110b9ea4b49d9236e537fd984d3db7f7b7a2702 (diff)
downloadgitlab-shell-c40ad688ed72357a58ba8481ba9382cabfc59375.tar.gz
Merge branch 'id-login-grace-time' into 'main'
Close the connection when context is canceled See merge request gitlab-org/gitlab-shell!646
-rw-r--r--internal/sshd/connection.go82
-rw-r--r--internal/sshd/connection_test.go26
-rw-r--r--internal/sshd/sshd.go51
-rw-r--r--internal/sshd/sshd_test.go37
4 files changed, 121 insertions, 75 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
index 61234a3..eaae5ca 100644
--- a/internal/sshd/connection.go
+++ b/internal/sshd/connection.go
@@ -3,6 +3,8 @@ package sshd
import (
"context"
"errors"
+ "net"
+ "strings"
"time"
"golang.org/x/crypto/ssh"
@@ -22,52 +24,91 @@ const KeepAliveMsg = "keepalive@openssh.com"
var EOFTimeout = 10 * time.Second
type connection struct {
- cfg *config.Config
- concurrentSessions *semaphore.Weighted
- remoteAddr string
- sconn *ssh.ServerConn
- maxSessions int64
+ cfg *config.Config
+ concurrentSessions *semaphore.Weighted
+ nconn net.Conn
+ maxSessions int64
+ remoteAddr string
+ started time.Time
+ establishSessionDuration float64
}
-type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) error
+type channelHandler func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error
-func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection {
+func newConnection(cfg *config.Config, nconn net.Conn) *connection {
maxSessions := cfg.Server.ConcurrentSessionsLimit
return &connection{
cfg: cfg,
maxSessions: maxSessions,
concurrentSessions: semaphore.NewWeighted(maxSessions),
- remoteAddr: remoteAddr,
- sconn: sconn,
+ nconn: nconn,
+ remoteAddr: nconn.RemoteAddr().String(),
+ started: time.Now(),
}
}
-func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
- ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
+func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) {
+ sconn, chans, err := c.initServerConn(ctx, srvCfg)
+ if err != nil {
+ return
+ }
if c.cfg.Server.ClientAliveInterval > 0 {
ticker := time.NewTicker(time.Duration(c.cfg.Server.ClientAliveInterval))
defer ticker.Stop()
- go c.sendKeepAliveMsg(ctx, ticker)
+ go c.sendKeepAliveMsg(ctx, sconn, ticker)
+ }
+
+ c.handleRequests(ctx, sconn, chans, handler)
+
+ reason := sconn.Wait()
+ log.WithContextFields(ctx, log.Fields{
+ "duration_s": time.Since(c.started).Seconds(),
+ "establish_session_duration_s": c.establishSessionDuration,
+ "reason": reason,
+ }).Info("server: handleConn: done")
+}
+
+func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, error) {
+ sconn, chans, reqs, err := ssh.NewServerConn(c.nconn, srvCfg)
+ if err != nil {
+ msg := "connection: initServerConn: failed to initialize SSH connection"
+
+ logger := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}).WithError(err)
+
+ if strings.Contains(err.Error(), "no common algorithm for host key") || err.Error() == "EOF" {
+ logger.Debug(msg)
+ } else {
+ logger.Warn(msg)
+ }
+
+ return nil, nil, err
}
+ go ssh.DiscardRequests(reqs)
+
+ return sconn, chans, err
+}
+
+func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) {
+ ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for newChannel := range chans {
ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested")
if newChannel.ChannelType() != "session" {
- ctxlog.Info("connection: handle: unknown channel type")
+ ctxlog.Info("connection: handleRequests: unknown channel type")
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
if !c.concurrentSessions.TryAcquire(1) {
- ctxlog.Info("connection: handle: too many concurrent sessions")
+ ctxlog.Info("connection: handleRequests: too many concurrent sessions")
newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
metrics.SshdHitMaxSessions.Inc()
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
- ctxlog.WithError(err).Error("connection: handle: accepting channel failed")
+ ctxlog.WithError(err).Error("connection: handleRequests: accepting channel failed")
c.concurrentSessions.Release(1)
continue
}
@@ -76,6 +117,7 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
defer func(started time.Time) {
metrics.SshdSessionDuration.Observe(time.Since(started).Seconds())
}(time.Now())
+ c.establishSessionDuration = time.Since(c.started).Seconds()
defer c.concurrentSessions.Release(1)
@@ -87,12 +129,12 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
}()
metrics.SliSshdSessionsTotal.Inc()
- err := handler(ctx, channel, requests)
+ err := handler(sconn, channel, requests)
if err != nil {
c.trackError(err)
}
- ctxlog.Info("connection: handle: done")
+ ctxlog.Info("connection: handleRequests: done")
}()
}
@@ -105,7 +147,7 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
c.concurrentSessions.Acquire(ctx, c.maxSessions)
}
-func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) {
+func (c *connection) sendKeepAliveMsg(ctx context.Context, sconn *ssh.ServerConn, ticker *time.Ticker) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for {
@@ -113,9 +155,9 @@ func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker)
case <-ctx.Done():
return
case <-ticker.C:
- ctxlog.Debug("session: handleShell: send keepalive message to a client")
+ ctxlog.Debug("connection: sendKeepAliveMsg: send keepalive message to a client")
- c.sconn.SendRequest(KeepAliveMsg, true, nil)
+ sconn.SendRequest(KeepAliveMsg, true, nil)
}
}
}
diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go
index a6dad8d..a5225b2 100644
--- a/internal/sshd/connection_test.go
+++ b/internal/sshd/connection_test.go
@@ -10,6 +10,7 @@ import (
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
+ "golang.org/x/sync/semaphore"
grpccodes "google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
@@ -81,7 +82,7 @@ func (f *fakeConn) SendRequest(name string, wantReply bool, payload []byte) (boo
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) {
cfg := &config.Config{Server: config.ServerConfig{ConcurrentSessionsLimit: sessionsNum}}
- conn := newConnection(cfg, "127.0.0.1:50000", &ssh.ServerConn{&fakeConn{}, nil})
+ conn := &connection{cfg: cfg, concurrentSessions: semaphore.NewWeighted(sessionsNum)}
chans := make(chan ssh.NewChannel, 1)
chans <- newChannel
@@ -95,7 +96,7 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) {
numSessions := 0
require.NotPanics(t, func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
+ conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
numSessions += 1
close(chans)
panic("This is a panic")
@@ -113,7 +114,7 @@ func TestUnknownChannelType(t *testing.T) {
conn, chans := setup(1, newChannel)
go func() {
- conn.handle(context.Background(), chans, nil)
+ conn.handleRequests(context.Background(), nil, chans, nil)
}()
rejectionData := <-rejectCh
@@ -133,7 +134,7 @@ func TestTooManySessions(t *testing.T) {
defer cancel()
go func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
+ conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
<-ctx.Done() // Keep the accepted channel open until the end of the test
return nil
})
@@ -148,7 +149,7 @@ func TestAcceptSessionSucceeds(t *testing.T) {
conn, chans := setup(1, newChannel)
channelHandled := false
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
+ conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
close(chans)
return nil
@@ -167,7 +168,7 @@ func TestAcceptSessionFails(t *testing.T) {
channelHandled := false
go func() {
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
+ conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
return nil
})
@@ -185,12 +186,11 @@ func TestAcceptSessionFails(t *testing.T) {
func TestClientAliveInterval(t *testing.T) {
f := &fakeConn{}
- conn := newConnection(&config.Config{}, "127.0.0.1:50000", &ssh.ServerConn{f, nil})
-
ticker := time.NewTicker(time.Millisecond)
defer ticker.Stop()
- go conn.sendKeepAliveMsg(context.Background(), ticker)
+ conn := &connection{}
+ go conn.sendKeepAliveMsg(context.Background(), &ssh.ServerConn{f, nil}, ticker)
require.Eventually(t, func() bool { return KeepAliveMsg == f.SentRequestName() }, time.Second, time.Millisecond)
}
@@ -204,7 +204,7 @@ func TestSessionsMetrics(t *testing.T) {
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans := setup(1, newChannel)
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
+ conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
close(chans)
return errors.New("custom error")
})
@@ -213,7 +213,7 @@ func TestSessionsMetrics(t *testing.T) {
require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1)
conn, chans = setup(1, newChannel)
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
+ conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
close(chans)
return grpcstatus.Error(grpccodes.Canceled, "canceled")
})
@@ -222,7 +222,7 @@ func TestSessionsMetrics(t *testing.T) {
require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1)
conn, chans = setup(1, newChannel)
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
+ conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
close(chans)
return &client.ApiError{"api error"}
})
@@ -231,7 +231,7 @@ func TestSessionsMetrics(t *testing.T) {
require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1)
conn, chans = setup(1, newChannel)
- conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
+ conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
close(chans)
return grpcstatus.Error(grpccodes.Unavailable, "unavailable")
})
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index dbb8709..d927268 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -10,7 +10,6 @@ import (
"time"
"github.com/pires/go-proxyproto"
- "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
@@ -39,18 +38,6 @@ type Server struct {
serverConfig *serverConfig
}
-func logSSHInitError(ctxlog *logrus.Entry, err error) {
- msg := "server: handleConn: failed to initialize SSH connection"
-
- logger := ctxlog.WithError(err)
-
- if strings.Contains(err.Error(), "no common algorithm for host key") || err.Error() == "EOF" {
- logger.Debug(msg)
- } else {
- logger.Warn(msg)
- }
-}
-
func NewServer(cfg *config.Config) (*Server, error) {
serverConfig, err := newServerConfig(cfg)
if err != nil {
@@ -159,18 +146,21 @@ func (s *Server) getStatus() status {
}
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
+ defer s.wg.Done()
+
metrics.SshdConnectionsInFlight.Inc()
defer metrics.SshdConnectionsInFlight.Dec()
- remoteAddr := nconn.RemoteAddr().String()
-
- defer s.wg.Done()
- defer nconn.Close()
-
ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
defer cancel()
+ go func() {
+ <-ctx.Done()
+ nconn.Close() // Close the connection when context is cancelled
+ }()
+ remoteAddr := nconn.RemoteAddr().String()
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr})
+ ctxlog.Debug("server: handleConn: start")
// Prevent a panic in a single connection from taking out the whole server
defer func() {
@@ -181,22 +171,8 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
}
}()
- ctxlog.Debug("server: handleConn: start")
-
- sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx))
- if err != nil {
- logSSHInitError(ctxlog, err)
- return
- }
- go ssh.DiscardRequests(reqs)
-
- started := time.Now()
- var establishSessionDuration float64
- conn := newConnection(s.Config, remoteAddr, sconn)
- conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) error {
- establishSessionDuration = time.Since(started).Seconds()
- metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
-
+ conn := newConnection(s.Config, nconn)
+ conn.handle(ctx, s.serverConfig.get(ctx), func(sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) error {
session := &session{
cfg: s.Config,
channel: channel,
@@ -206,13 +182,6 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
return session.handle(ctx, requests)
})
-
- reason := sconn.Wait()
- ctxlog.WithFields(log.Fields{
- "duration_s": time.Since(started).Seconds(),
- "establish_session_duration_s": establishSessionDuration,
- "reason": reason,
- }).Info("server: handleConn: done")
}
func (s *Server) requirePolicy(_ net.Addr) (proxyproto.Policy, error) {
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index d725add..36adc57 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -222,6 +222,35 @@ func TestInvalidServerConfig(t *testing.T) {
require.Nil(t, s.Shutdown())
}
+func TestClosingHangedConnections(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ s := setupServerWithContext(t, nil, ctx)
+
+ unauthenticatedRequestStatus := make(chan string)
+ completed := make(chan bool)
+
+ clientCfg := clientConfig(t)
+ clientCfg.HostKeyCallback = func(_ string, _ net.Addr, _ ssh.PublicKey) error {
+ unauthenticatedRequestStatus <- "authentication-started"
+ <-completed // Wait infinitely
+
+ return nil
+ }
+
+ go func() {
+ // Start an SSH connection that never ends
+ ssh.Dial("tcp", serverUrl, clientCfg)
+ }()
+
+ require.Equal(t, "authentication-started", <-unauthenticatedRequestStatus)
+
+ require.NoError(t, s.Shutdown())
+ cancel()
+ verifyStatus(t, s, StatusClosed)
+}
+
func setupServer(t *testing.T) *Server {
t.Helper()
@@ -231,6 +260,12 @@ func setupServer(t *testing.T) *Server {
func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server {
t.Helper()
+ return setupServerWithContext(t, cfg, context.Background())
+}
+
+func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Context) *Server {
+ t.Helper()
+
requests := []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/authorized_keys",
@@ -270,7 +305,7 @@ func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server {
s, err := NewServer(cfg)
require.NoError(t, err)
- go func() { require.NoError(t, s.ListenAndServe(context.Background())) }()
+ go func() { require.NoError(t, s.ListenAndServe(ctx)) }()
t.Cleanup(func() { s.Shutdown() })
verifyStatus(t, s, StatusReady)