summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Thomas <nick@gitlab.com>2021-07-27 15:41:54 +0000
committerNick Thomas <nick@gitlab.com>2021-07-27 15:41:54 +0000
commitb7edd7dd9f957c6b14d3bfa4407aca9ddfbe4f52 (patch)
tree034e4ae73aa5522a73db0506e67f567648bc507f
parentf9e7ffda68192d24ff26f0d5ff7fe70e376c32f2 (diff)
parentf6baecaa794ef85b144fa9cd05940e3f020b4a0e (diff)
downloadgitlab-shell-b7edd7dd9f957c6b14d3bfa4407aca9ddfbe4f52.tar.gz
Merge branch 'id-ctx-for-auth-check' into 'main'
Log same correlation_id on auth keys check of ssh connections See merge request gitlab-org/gitlab-shell!501
-rw-r--r--cmd/gitlab-sshd/main.go7
-rw-r--r--internal/sshd/sshd.go85
-rw-r--r--internal/sshd/sshd_test.go120
-rw-r--r--internal/testhelper/testdata/testroot/certs/valid/server_authorized_key1
4 files changed, 157 insertions, 56 deletions
diff --git a/cmd/gitlab-sshd/main.go b/cmd/gitlab-sshd/main.go
index d1cc84e..78690b0 100644
--- a/cmd/gitlab-sshd/main.go
+++ b/cmd/gitlab-sshd/main.go
@@ -68,7 +68,10 @@ func main() {
ctx, finished := command.Setup("gitlab-sshd", cfg)
defer finished()
- server := sshd.Server{Config: cfg}
+ server, err := sshd.NewServer(cfg)
+ if err != nil {
+ log.WithError(err).Fatal("Failed to start GitLab built-in sshd")
+ }
// Startup monitoring endpoint.
if cfg.Server.WebListen != "" {
@@ -104,6 +107,6 @@ func main() {
}()
if err := server.ListenAndServe(ctx); err != nil {
- log.WithError(err).Fatal("Failed to start GitLab built-in sshd")
+ log.WithError(err).Fatal("GitLab built-in sshd failed to listen for new connections")
}
}
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index 8b49712..b918109 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -35,10 +35,40 @@ const (
type Server struct {
Config *config.Config
- status status
- statusMu sync.Mutex
- wg sync.WaitGroup
- listener net.Listener
+ status status
+ statusMu sync.Mutex
+ wg sync.WaitGroup
+ listener net.Listener
+ hostKeys []ssh.Signer
+ authorizedKeysClient *authorizedkeys.Client
+}
+
+func NewServer(cfg *config.Config) (*Server, error) {
+ authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ }
+
+ var hostKeys []ssh.Signer
+ for _, filename := range cfg.Server.HostKeyFiles {
+ keyRaw, err := ioutil.ReadFile(filename)
+ if err != nil {
+ log.WithError(err).Warnf("Failed to read host key %v", filename)
+ continue
+ }
+ key, err := ssh.ParsePrivateKey(keyRaw)
+ if err != nil {
+ log.WithError(err).Warnf("Failed to parse host key %v", filename)
+ continue
+ }
+
+ hostKeys = append(hostKeys, key)
+ }
+ if len(hostKeys) == 0 {
+ return nil, fmt.Errorf("No host keys could be loaded, aborting")
+ }
+
+ return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil
}
func (s *Server) ListenAndServe(ctx context.Context) error {
@@ -47,7 +77,9 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
}
defer s.listener.Close()
- return s.serve(ctx)
+ s.serve(ctx)
+
+ return nil
}
func (s *Server) Shutdown() error {
@@ -100,12 +132,7 @@ func (s *Server) listen() error {
return nil
}
-func (s *Server) serve(ctx context.Context) error {
- sshCfg, err := s.initConfig(ctx)
- if err != nil {
- return err
- }
-
+func (s *Server) serve(ctx context.Context) {
s.changeStatus(StatusReady)
for {
@@ -120,14 +147,12 @@ func (s *Server) serve(ctx context.Context) error {
}
s.wg.Add(1)
- go s.handleConn(ctx, sshCfg, nconn)
+ go s.handleConn(ctx, nconn)
}
s.wg.Wait()
s.changeStatus(StatusClosed)
-
- return nil
}
func (s *Server) changeStatus(st status) {
@@ -143,12 +168,7 @@ func (s *Server) getStatus() status {
return s.status
}
-func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) {
- authorizedKeysClient, err := authorizedkeys.NewClient(s.Config)
- if err != nil {
- return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
- }
-
+func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig {
sshCfg := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if conn.User() != s.Config.User {
@@ -159,7 +179,7 @@ func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) {
}
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
- res, err := authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
+ res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
if err != nil {
return nil, err
}
@@ -173,29 +193,14 @@ func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) {
},
}
- var loadedHostKeys uint
- for _, filename := range s.Config.Server.HostKeyFiles {
- keyRaw, err := ioutil.ReadFile(filename)
- if err != nil {
- log.WithError(err).Warnf("Failed to read host key %v", filename)
- continue
- }
- key, err := ssh.ParsePrivateKey(keyRaw)
- if err != nil {
- log.WithError(err).Warnf("Failed to parse host key %v", filename)
- continue
- }
- loadedHostKeys++
+ for _, key := range s.hostKeys {
sshCfg.AddHostKey(key)
}
- if loadedHostKeys == 0 {
- return nil, fmt.Errorf("No host keys could be loaded, aborting")
- }
- return sshCfg, nil
+ return sshCfg
}
-func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn net.Conn) {
+func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
remoteAddr := nconn.RemoteAddr().String()
defer s.wg.Done()
@@ -211,7 +216,7 @@ func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn
ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
defer cancel()
- sconn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg)
+ sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx))
if err != nil {
log.WithError(err).Info("Failed to initialize SSH connection")
return
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index e5f6111..2923737 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -2,37 +2,71 @@ package sshd
import (
"context"
+ "fmt"
+ "io/ioutil"
+ "net/http"
"net/http/httptest"
"path"
"testing"
"time"
"github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
"gitlab.com/gitlab-org/gitlab-shell/client/testserver"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
)
-const serverUrl = "127.0.0.1:50000"
-
-func TestShutdown(t *testing.T) {
- s := setupServer(t)
+const (
+ serverUrl = "127.0.0.1:50000"
+ user = "git"
+)
- go func() { require.NoError(t, s.ListenAndServe(context.Background())) }()
+var (
+ correlationId = ""
+)
- verifyStatus(t, s, StatusReady)
+func TestListenAndServe(t *testing.T) {
+ s := setupServer(t)
- s.wg.Add(1)
+ client, err := ssh.Dial("tcp", serverUrl, clientConfig(t))
+ require.NoError(t, err)
+ defer client.Close()
require.NoError(t, s.Shutdown())
verifyStatus(t, s, StatusOnShutdown)
- s.wg.Done()
+ holdSession(t, client)
+
+ _, err = ssh.Dial("tcp", serverUrl, clientConfig(t))
+ require.Equal(t, err.Error(), "dial tcp 127.0.0.1:50000: connect: connection refused")
+
+ client.Close()
verifyStatus(t, s, StatusClosed)
}
+func TestCorrelationId(t *testing.T) {
+ setupServer(t)
+
+ client, err := ssh.Dial("tcp", serverUrl, clientConfig(t))
+ require.NoError(t, err)
+ defer client.Close()
+
+ holdSession(t, client)
+
+ previousCorrelationId := correlationId
+
+ client, err = ssh.Dial("tcp", serverUrl, clientConfig(t))
+ require.NoError(t, err)
+ defer client.Close()
+
+ holdSession(t, client)
+
+ require.NotEqual(t, previousCorrelationId, correlationId)
+}
+
func TestReadinessProbe(t *testing.T) {
s := &Server{Config: &config.Config{Server: config.DefaultServerConfig}}
@@ -71,17 +105,75 @@ func TestLivenessProbe(t *testing.T) {
}
func setupServer(t *testing.T) *Server {
+ t.Helper()
+
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/authorized_keys",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ correlationId = r.Header.Get("X-Request-Id")
+
+ require.NotEmpty(t, correlationId)
+
+ fmt.Fprint(w, `{"id": 1000, "key": "key"}`)
+ },
+ }, {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, correlationId, r.Header.Get("X-Request-Id"))
+
+ fmt.Fprint(w, `{"id": 1000, "name": "Test User", "username": "test-user"}`)
+ },
+ },
+ }
+
testhelper.PrepareTestRootDir(t)
- url := testserver.StartSocketHttpServer(t, []testserver.TestRequestHandler{})
+ url := testserver.StartSocketHttpServer(t, requests)
srvCfg := config.ServerConfig{
- Listen: serverUrl,
- HostKeyFiles: []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")},
+ Listen: serverUrl,
+ ConcurrentSessionsLimit: 1,
+ HostKeyFiles: []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")},
+ }
+
+ s, err := NewServer(&config.Config{User: user, RootDir: "/tmp", GitlabUrl: url, Server: srvCfg})
+ require.NoError(t, err)
+
+ go func() { require.NoError(t, s.ListenAndServe(context.Background())) }()
+ t.Cleanup(func() { s.Shutdown() })
+
+ verifyStatus(t, s, StatusReady)
+
+ return s
+}
+
+func clientConfig(t *testing.T) *ssh.ClientConfig {
+ keyRaw, err := ioutil.ReadFile(path.Join(testhelper.TestRoot, "certs/valid/server_authorized_key"))
+ pKey, _, _, _, err := ssh.ParseAuthorizedKey(keyRaw)
+ require.NoError(t, err)
+
+ key, err := ioutil.ReadFile(path.Join(testhelper.TestRoot, "certs/client/key.pem"))
+ require.NoError(t, err)
+ signer, err := ssh.ParsePrivateKey(key)
+ require.NoError(t, err)
+
+ return &ssh.ClientConfig{
+ User: user,
+ Auth: []ssh.AuthMethod{
+ ssh.PublicKeys(signer),
+ },
+ HostKeyCallback: ssh.FixedHostKey(pKey),
}
+}
- cfg := &config.Config{RootDir: "/tmp", GitlabUrl: url, Server: srvCfg}
+func holdSession(t *testing.T, c *ssh.Client) {
+ session, err := c.NewSession()
+ require.NoError(t, err)
+ defer session.Close()
- return &Server{Config: cfg}
+ output, err := session.Output("discover")
+ require.NoError(t, err)
+ require.Equal(t, "Welcome to GitLab, @test-user!\n", string(output))
}
func verifyStatus(t *testing.T, s *Server, st status) {
@@ -94,5 +186,5 @@ func verifyStatus(t *testing.T, s *Server, st status) {
time.Sleep(time.Duration(i) * time.Millisecond)
}
- require.Equal(t, s.getStatus(), st)
+ require.Equal(t, st, s.getStatus())
}
diff --git a/internal/testhelper/testdata/testroot/certs/valid/server_authorized_key b/internal/testhelper/testdata/testroot/certs/valid/server_authorized_key
new file mode 100644
index 0000000..784d80c
--- /dev/null
+++ b/internal/testhelper/testdata/testroot/certs/valid/server_authorized_key
@@ -0,0 +1 @@
+ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCa17cb94P6q5qbDIWX7aMSjyeBIBPQVZ5jlkDBG90XgWC1MEu9sB1OfKLukcx6wJJSTLFccc9rMzhINXq6K7ks0oXSLP81jvqsu0WipIZSDKBNkdVtno1FcI1RnQ+yUP3nA4Ja9L233GA1evLrqTz6Z9k2ET5wVB+s7+k3lak24bJZN8qVRDDk1UveahuPe1KMj7DNKls8y9tNCgGJn9UeTLJzXlh2tt4/AUHZ0lvET9eCzKT9PBZJQWcCzqLXHa37jbc0ib2sgNN1bZhgkle/cxRx0MjEmdjRt4Z48wjKaf1khFQm0r9lebAxvna/vT5hNywbru5KbfUJHyM23yql