diff options
Diffstat (limited to 'internal/sshd/server_config.go')
-rw-r--r-- | internal/sshd/server_config.go | 94 |
1 files changed, 94 insertions, 0 deletions
diff --git a/internal/sshd/server_config.go b/internal/sshd/server_config.go new file mode 100644 index 0000000..7306944 --- /dev/null +++ b/internal/sshd/server_config.go @@ -0,0 +1,94 @@ +package sshd + +import ( + "context" + "encoding/base64" + "fmt" + "os" + "strconv" + "time" + + "golang.org/x/crypto/ssh" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" + + "gitlab.com/gitlab-org/labkit/log" +) + +type serverConfig struct { + cfg *config.Config + hostKeys []ssh.Signer + authorizedKeysClient *authorizedkeys.Client +} + +func newServerConfig(cfg *config.Config) (*serverConfig, error) { + authorizedKeysClient, err := authorizedkeys.NewClient(cfg) + if err != nil { + return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) + } + + var hostKeys []ssh.Signer + for _, filename := range cfg.Server.HostKeyFiles { + keyRaw, err := os.ReadFile(filename) + if err != nil { + log.WithError(err).Warnf("Failed to read host key %v", filename) + continue + } + key, err := ssh.ParsePrivateKey(keyRaw) + if err != nil { + log.WithError(err).Warnf("Failed to parse host key %v", filename) + continue + } + + hostKeys = append(hostKeys, key) + } + if len(hostKeys) == 0 { + return nil, fmt.Errorf("No host keys could be loaded, aborting") + } + + return &serverConfig{cfg: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil +} + +func (s *serverConfig) getAuthKey(ctx context.Context, user string, key ssh.PublicKey) (*authorizedkeys.Response, error) { + if user != s.cfg.User { + return nil, fmt.Errorf("unknown user") + } + if key.Type() == ssh.KeyAlgoDSA { + return nil, fmt.Errorf("DSA is prohibited") + } + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) + if err != nil { + return nil, err + } + + return res, nil +} + +func (s *serverConfig) get(ctx context.Context) *ssh.ServerConfig { + sshCfg := &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + res, err := s.getAuthKey(ctx, conn.User(), key) + if err != nil { + return nil, err + } + + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "key-id": strconv.FormatInt(res.Id, 10), + }, + }, nil + }, + } + + for _, key := range s.hostKeys { + sshCfg.AddHostKey(key) + } + + return sshCfg +} |