summaryrefslogtreecommitdiff
path: root/internal/sshd/connection.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/sshd/connection.go')
-rw-r--r--internal/sshd/connection.go72
1 files changed, 66 insertions, 6 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
index 1312833..e878e9c 100644
--- a/internal/sshd/connection.go
+++ b/internal/sshd/connection.go
@@ -1,7 +1,9 @@
package sshd
import (
+ "net"
"context"
+ "time"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/semaphore"
@@ -13,19 +15,71 @@ import (
type connection struct {
concurrentSessions *semaphore.Weighted
- remoteAddr string
+ nconn net.Conn
+ remoteAddr string
+ started time.Time
}
-type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request)
+type channelHandler func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error
-func newConnection(maxSessions int64, remoteAddr string) *connection {
+func newConnection(maxSessions int64, nconn net.Conn) *connection {
return &connection{
concurrentSessions: semaphore.NewWeighted(maxSessions),
- remoteAddr: remoteAddr,
+ nconn: nconn,
+ remoteAddr: nconn.RemoteAddr().String(),
+ started: time.Now(),
}
}
-func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
+func (c *connection) handle(ctx context.Context, cfg *ssh.ServerConfig, handler channelHandler) {
+ ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
+
+ // Prevent a panic in a single connection from taking out the whole server
+ defer func() {
+ if err := recover(); err != nil {
+ ctxlog.Warn("panic handling session")
+ }
+
+ metrics.SliSshdSessionsErrorsTotal.Inc()
+ }()
+
+ ctxlog.Info("server: handleConn: start")
+
+ metrics.SshdConnectionsInFlight.Inc()
+ defer func() {
+ metrics.SshdConnectionsInFlight.Dec()
+ metrics.SshdSessionDuration.Observe(time.Since(c.started).Seconds())
+ }()
+
+ // Initialize the connection with server
+ sconn, chans, reqs, err := ssh.NewServerConn(c.nconn, cfg)
+
+ // Track the time required to establish a session
+ establishSessionDuration := time.Since(c.started).Seconds()
+ metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
+
+ // Most of the times a connection failes due to the client's misconfiguration or when
+ // a client cancels a request, so we shouldn't treat them as an error
+ // Warnings will helps us to track the errors whether they happend on the server side
+ if err != nil {
+ ctxlog.WithError(err).WithFields(log.Fields{
+ "establish_session_duration_s": establishSessionDuration,
+ }).Warn("conn: init: failed to initialize SSH connection")
+
+ return
+ }
+ go ssh.DiscardRequests(reqs)
+
+ // Handle incoming requests
+ c.handleRequests(ctx, sconn, chans, handler)
+
+ ctxlog.WithFields(log.Fields{
+ "duration_s": time.Since(c.started).Seconds(),
+ "establish_session_duration_s": establishSessionDuration,
+ }).Info("server: handleConn: done")
+}
+
+func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for newChannel := range chans {
@@ -55,10 +109,16 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
defer func() {
if err := recover(); err != nil {
ctxlog.WithField("recovered_error", err).Warn("panic handling session")
+
+ metrics.SliSshdSessionsErrorsTotal.Inc()
}
}()
- handler(ctx, channel, requests)
+ err := handler(ctx, sconn, channel, requests)
+ if err != nil {
+ metrics.SliSshdSessionsErrorsTotal.Inc()
+ }
+
ctxlog.Info("connection: handle: done")
}()
}