diff options
Diffstat (limited to 'internal/sshd/connection.go')
-rw-r--r-- | internal/sshd/connection.go | 72 |
1 files changed, 66 insertions, 6 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index 1312833..e878e9c 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -1,7 +1,9 @@ package sshd import ( + "net" "context" + "time" "golang.org/x/crypto/ssh" "golang.org/x/sync/semaphore" @@ -13,19 +15,71 @@ import ( type connection struct { concurrentSessions *semaphore.Weighted - remoteAddr string + nconn net.Conn + remoteAddr string + started time.Time } -type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) +type channelHandler func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error -func newConnection(maxSessions int64, remoteAddr string) *connection { +func newConnection(maxSessions int64, nconn net.Conn) *connection { return &connection{ concurrentSessions: semaphore.NewWeighted(maxSessions), - remoteAddr: remoteAddr, + nconn: nconn, + remoteAddr: nconn.RemoteAddr().String(), + started: time.Now(), } } -func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) { +func (c *connection) handle(ctx context.Context, cfg *ssh.ServerConfig, handler channelHandler) { + ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) + + // Prevent a panic in a single connection from taking out the whole server + defer func() { + if err := recover(); err != nil { + ctxlog.Warn("panic handling session") + } + + metrics.SliSshdSessionsErrorsTotal.Inc() + }() + + ctxlog.Info("server: handleConn: start") + + metrics.SshdConnectionsInFlight.Inc() + defer func() { + metrics.SshdConnectionsInFlight.Dec() + metrics.SshdSessionDuration.Observe(time.Since(c.started).Seconds()) + }() + + // Initialize the connection with server + sconn, chans, reqs, err := ssh.NewServerConn(c.nconn, cfg) + + // Track the time required to establish a session + establishSessionDuration := time.Since(c.started).Seconds() + metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration) + + // Most of the times a connection failes due to the client's misconfiguration or when + // a client cancels a request, so we shouldn't treat them as an error + // Warnings will helps us to track the errors whether they happend on the server side + if err != nil { + ctxlog.WithError(err).WithFields(log.Fields{ + "establish_session_duration_s": establishSessionDuration, + }).Warn("conn: init: failed to initialize SSH connection") + + return + } + go ssh.DiscardRequests(reqs) + + // Handle incoming requests + c.handleRequests(ctx, sconn, chans, handler) + + ctxlog.WithFields(log.Fields{ + "duration_s": time.Since(c.started).Seconds(), + "establish_session_duration_s": establishSessionDuration, + }).Info("server: handleConn: done") +} + +func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) { ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) for newChannel := range chans { @@ -55,10 +109,16 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha defer func() { if err := recover(); err != nil { ctxlog.WithField("recovered_error", err).Warn("panic handling session") + + metrics.SliSshdSessionsErrorsTotal.Inc() } }() - handler(ctx, channel, requests) + err := handler(ctx, sconn, channel, requests) + if err != nil { + metrics.SliSshdSessionsErrorsTotal.Inc() + } + ctxlog.Info("connection: handle: done") }() } |