summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNick Thomas <nick@gitlab.com>2021-09-15 16:51:41 +0000
committerNick Thomas <nick@gitlab.com>2021-09-15 16:51:41 +0000
commit3f640bdafe429501897541cadf2c268b13c4cf9f (patch)
tree5713e182841a41acc9e2adfd48697cbfe9eb8e67
parent7d60d7a09658041c959c92a7776feceb64b735f4 (diff)
parente96e13301904bfa6eb514667df9a7803828a7da9 (diff)
downloadgitlab-shell-3f640bdafe429501897541cadf2c268b13c4cf9f.tar.gz
Merge branch 'id-sshd-test-3' into 'main'
Extract server config related code out of sshd.go Closes #523 See merge request gitlab-org/gitlab-shell!523
-rw-r--r--internal/sshd/server_config.go94
-rw-r--r--internal/sshd/server_config_test.go105
-rw-r--r--internal/sshd/sshd.go75
-rw-r--r--internal/sshd/sshd_test.go16
4 files changed, 217 insertions, 73 deletions
diff --git a/internal/sshd/server_config.go b/internal/sshd/server_config.go
new file mode 100644
index 0000000..7306944
--- /dev/null
+++ b/internal/sshd/server_config.go
@@ -0,0 +1,94 @@
+package sshd
+
+import (
+ "context"
+ "encoding/base64"
+ "fmt"
+ "os"
+ "strconv"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
+
+ "gitlab.com/gitlab-org/labkit/log"
+)
+
+type serverConfig struct {
+ cfg *config.Config
+ hostKeys []ssh.Signer
+ authorizedKeysClient *authorizedkeys.Client
+}
+
+func newServerConfig(cfg *config.Config) (*serverConfig, error) {
+ authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ }
+
+ var hostKeys []ssh.Signer
+ for _, filename := range cfg.Server.HostKeyFiles {
+ keyRaw, err := os.ReadFile(filename)
+ if err != nil {
+ log.WithError(err).Warnf("Failed to read host key %v", filename)
+ continue
+ }
+ key, err := ssh.ParsePrivateKey(keyRaw)
+ if err != nil {
+ log.WithError(err).Warnf("Failed to parse host key %v", filename)
+ continue
+ }
+
+ hostKeys = append(hostKeys, key)
+ }
+ if len(hostKeys) == 0 {
+ return nil, fmt.Errorf("No host keys could be loaded, aborting")
+ }
+
+ return &serverConfig{cfg: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil
+}
+
+func (s *serverConfig) getAuthKey(ctx context.Context, user string, key ssh.PublicKey) (*authorizedkeys.Response, error) {
+ if user != s.cfg.User {
+ return nil, fmt.Errorf("unknown user")
+ }
+ if key.Type() == ssh.KeyAlgoDSA {
+ return nil, fmt.Errorf("DSA is prohibited")
+ }
+
+ ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+func (s *serverConfig) get(ctx context.Context) *ssh.ServerConfig {
+ sshCfg := &ssh.ServerConfig{
+ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
+ res, err := s.getAuthKey(ctx, conn.User(), key)
+ if err != nil {
+ return nil, err
+ }
+
+ return &ssh.Permissions{
+ // Record the public key used for authentication.
+ Extensions: map[string]string{
+ "key-id": strconv.FormatInt(res.Id, 10),
+ },
+ }, nil
+ },
+ }
+
+ for _, key := range s.hostKeys {
+ sshCfg.AddHostKey(key)
+ }
+
+ return sshCfg
+}
diff --git a/internal/sshd/server_config_test.go b/internal/sshd/server_config_test.go
new file mode 100644
index 0000000..58bd3e1
--- /dev/null
+++ b/internal/sshd/server_config_test.go
@@ -0,0 +1,105 @@
+package sshd
+
+import (
+ "context"
+ "crypto/dsa"
+ "crypto/rand"
+ "crypto/rsa"
+ "path"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+func TestNewServerConfigWithoutHosts(t *testing.T) {
+ _, err := newServerConfig(&config.Config{GitlabUrl: "http://localhost"})
+
+ require.Error(t, err)
+ require.Equal(t, "No host keys could be loaded, aborting", err.Error())
+}
+
+func TestFailedAuthorizedKeysClient(t *testing.T) {
+ _, err := newServerConfig(&config.Config{GitlabUrl: "ftp://localhost"})
+
+ require.Error(t, err)
+ require.Equal(t, "failed to initialize GitLab client: Error creating http client: unknown GitLab URL prefix", err.Error())
+}
+
+func TestFailedGetAuthKey(t *testing.T) {
+ testhelper.PrepareTestRootDir(t)
+
+ srvCfg := config.ServerConfig{
+ Listen: "127.0.0.1",
+ ConcurrentSessionsLimit: 1,
+ HostKeyFiles: []string{
+ path.Join(testhelper.TestRoot, "certs/valid/server.key"),
+ path.Join(testhelper.TestRoot, "certs/invalid-path.key"),
+ path.Join(testhelper.TestRoot, "certs/invalid/server.crt"),
+ },
+ }
+
+ cfg, err := newServerConfig(
+ &config.Config{GitlabUrl: "http://localhost", User: "user", Server: srvCfg},
+ )
+ require.NoError(t, err)
+
+ testCases := []struct {
+ desc string
+ user string
+ key ssh.PublicKey
+ expectedError string
+ }{
+ {
+ desc: "wrong user",
+ user: "wrong-user",
+ key: rsaPublicKey(t),
+ expectedError: "unknown user",
+ }, {
+ desc: "prohibited dsa key",
+ user: "user",
+ key: dsaPublicKey(t),
+ expectedError: "DSA is prohibited",
+ }, {
+ desc: "API error",
+ user: "user",
+ key: rsaPublicKey(t),
+ expectedError: "Internal API unreachable",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ _, err = cfg.getAuthKey(context.Background(), tc.user, tc.key)
+ require.Error(t, err)
+ require.Equal(t, tc.expectedError, err.Error())
+ })
+ }
+}
+
+func rsaPublicKey(t *testing.T) ssh.PublicKey {
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
+ require.NoError(t, err)
+
+ return publicKey
+}
+
+func dsaPublicKey(t *testing.T) ssh.PublicKey {
+ privateKey := new(dsa.PrivateKey)
+ params := new(dsa.Parameters)
+ require.NoError(t, dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160))
+
+ privateKey.PublicKey.Parameters = *params
+ require.NoError(t, dsa.GenerateKey(privateKey, rand.Reader))
+
+ publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
+ require.NoError(t, err)
+
+ return publicKey
+}
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index de5fbd4..ff9e765 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -2,13 +2,9 @@ package sshd
import (
"context"
- "encoding/base64"
- "errors"
"fmt"
"net"
"net/http"
- "os"
- "strconv"
"sync"
"time"
@@ -16,7 +12,6 @@ import (
"golang.org/x/crypto/ssh"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
- "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/log"
@@ -35,40 +30,20 @@ const (
type Server struct {
Config *config.Config
- status status
- statusMu sync.Mutex
- wg sync.WaitGroup
- listener net.Listener
- hostKeys []ssh.Signer
- authorizedKeysClient *authorizedkeys.Client
+ status status
+ statusMu sync.Mutex
+ wg sync.WaitGroup
+ listener net.Listener
+ serverConfig *serverConfig
}
func NewServer(cfg *config.Config) (*Server, error) {
- authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
+ serverConfig, err := newServerConfig(cfg)
if err != nil {
- return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ return nil, err
}
- var hostKeys []ssh.Signer
- for _, filename := range cfg.Server.HostKeyFiles {
- keyRaw, err := os.ReadFile(filename)
- if err != nil {
- log.WithError(err).Warnf("Failed to read host key %v", filename)
- continue
- }
- key, err := ssh.ParsePrivateKey(keyRaw)
- if err != nil {
- log.WithError(err).Warnf("Failed to parse host key %v", filename)
- continue
- }
-
- hostKeys = append(hostKeys, key)
- }
- if len(hostKeys) == 0 {
- return nil, fmt.Errorf("No host keys could be loaded, aborting")
- }
-
- return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil
+ return &Server{Config: cfg, serverConfig: serverConfig}, nil
}
func (s *Server) ListenAndServe(ctx context.Context) error {
@@ -168,38 +143,6 @@ func (s *Server) getStatus() status {
return s.status
}
-func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig {
- sshCfg := &ssh.ServerConfig{
- PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
- if conn.User() != s.Config.User {
- return nil, errors.New("unknown user")
- }
- if key.Type() == ssh.KeyAlgoDSA {
- return nil, errors.New("DSA is prohibited")
- }
- ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
- defer cancel()
- res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
- if err != nil {
- return nil, err
- }
-
- return &ssh.Permissions{
- // Record the public key used for authentication.
- Extensions: map[string]string{
- "key-id": strconv.FormatInt(res.Id, 10),
- },
- }, nil
- },
- }
-
- for _, key := range s.hostKeys {
- sshCfg.AddHostKey(key)
- }
-
- return sshCfg
-}
-
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
remoteAddr := nconn.RemoteAddr().String()
@@ -216,7 +159,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
defer cancel()
- sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx))
+ sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx))
if err != nil {
log.WithError(err).Info("Failed to initialize SSH connection")
return
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index cba1c3f..71f7733 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -104,13 +104,6 @@ func TestLivenessProbe(t *testing.T) {
require.Equal(t, 200, r.Result().StatusCode)
}
-func TestNewServerWithoutHosts(t *testing.T) {
- _, err := NewServer(&config.Config{GitlabUrl: "http://localhost"})
-
- require.Error(t, err)
- require.Equal(t, "No host keys could be loaded, aborting", err.Error())
-}
-
func TestInvalidClientConfig(t *testing.T) {
setupServer(t)
@@ -120,6 +113,15 @@ func TestInvalidClientConfig(t *testing.T) {
require.Error(t, err)
}
+func TestInvalidServerConfig(t *testing.T) {
+ s := &Server{Config: &config.Config{Server: config.ServerConfig{Listen: "invalid"}}}
+ err := s.ListenAndServe(context.Background())
+
+ require.Error(t, err)
+ require.Equal(t, "failed to listen for connection: listen tcp: address invalid: missing port in address", err.Error())
+ require.Nil(t, s.Shutdown())
+}
+
func setupServer(t *testing.T) *Server {
t.Helper()