diff options
-rw-r--r-- | cmd/gitlab-sshd/acceptance_test.go | 2 | ||||
-rw-r--r-- | config.yml.example | 4 | ||||
-rw-r--r-- | internal/config/config.go | 1 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 29 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 67 |
5 files changed, 90 insertions, 13 deletions
diff --git a/cmd/gitlab-sshd/acceptance_test.go b/cmd/gitlab-sshd/acceptance_test.go index e36c629..af263c8 100644 --- a/cmd/gitlab-sshd/acceptance_test.go +++ b/cmd/gitlab-sshd/acceptance_test.go @@ -477,6 +477,7 @@ func TestGitUploadArchiveSuccess(t *testing.T) { require.NoError(t, err) _, err = fmt.Fprintln(stdin, "0012argument HEAD\n0000") + require.NoError(t, err) line, err := reader.ReadString('\n') require.Equal(t, "0008ACK\n", line) @@ -489,5 +490,6 @@ func TestGitUploadArchiveSuccess(t *testing.T) { output, err := io.ReadAll(stdout) require.NoError(t, err) + t.Logf("output: %q", output) require.Equal(t, []byte("0000"), output[len(output)-4:]) } diff --git a/config.yml.example b/config.yml.example index 501b61b..0154723 100644 --- a/config.yml.example +++ b/config.yml.example @@ -72,6 +72,10 @@ sshd: # Proxy protocol policy ("use", "require", "reject", "ignore"), "use" is the default value # Values: https://github.com/pires/go-proxyproto/blob/195fedcfbfc1be163f3a0d507fac1709e9d81fed/policy.go#L20 proxy_policy: "use" + # Proxy allowed IP addresses. Takes precedent over proxy_policy. Disabled by default. + # proxy_allowed: + # - "192.168.0.1" + # - "192.168.1.0/24" # Address which the server listens on HTTP for monitoring/health checks. Defaults to localhost:9122. web_listen: "localhost:9122" # Maximum number of concurrent sessions allowed on a single SSH connection. Defaults to 10. diff --git a/internal/config/config.go b/internal/config/config.go index debab3f..35d8e74 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -27,6 +27,7 @@ type ServerConfig struct { Listen string `yaml:"listen,omitempty"` ProxyProtocol bool `yaml:"proxy_protocol,omitempty"` ProxyPolicy string `yaml:"proxy_policy,omitempty"` + ProxyAllowed []string `yaml:"proxy_allowed,omitempty"` WebListen string `yaml:"web_listen,omitempty"` ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"` ClientAliveInterval YamlDuration `yaml:"client_alive_interval,omitempty"` diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index b08b386..d20286a 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -9,7 +9,7 @@ import ( "sync" "time" - "github.com/pires/go-proxyproto" + proxyproto "github.com/pires/go-proxyproto" "golang.org/x/crypto/ssh" "gitlab.com/gitlab-org/gitlab-shell/v14/client" @@ -95,9 +95,14 @@ func (s *Server) listen(ctx context.Context) error { } if s.Config.Server.ProxyProtocol { + policy, err := s.proxyPolicy() + if err != nil { + return fmt.Errorf("invalid policy configuration: %w", err) + } + sshListener = &proxyproto.Listener{ Listener: sshListener, - Policy: s.requirePolicy, + Policy: policy, ReadHeaderTimeout: time.Duration(s.Config.Server.ProxyHeaderTimeout), } @@ -200,17 +205,27 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { }) } -func (s *Server) requirePolicy(_ net.Addr) (proxyproto.Policy, error) { +func (s *Server) proxyPolicy() (proxyproto.PolicyFunc, error) { + if len(s.Config.Server.ProxyAllowed) > 0 { + return proxyproto.StrictWhiteListPolicy(s.Config.Server.ProxyAllowed) + } + // Set the Policy value based on config // Values are taken from https://github.com/pires/go-proxyproto/blob/195fedcfbfc1be163f3a0d507fac1709e9d81fed/policy.go#L20 switch strings.ToLower(s.Config.Server.ProxyPolicy) { case "require": - return proxyproto.REQUIRE, nil + return staticProxyPolicy(proxyproto.REQUIRE), nil case "ignore": - return proxyproto.IGNORE, nil + return staticProxyPolicy(proxyproto.IGNORE), nil case "reject": - return proxyproto.REJECT, nil + return staticProxyPolicy(proxyproto.REJECT), nil default: - return proxyproto.USE, nil + return staticProxyPolicy(proxyproto.USE), nil + } +} + +func staticProxyPolicy(policy proxyproto.Policy) proxyproto.PolicyFunc { + return func(_ net.Addr) (proxyproto.Policy, error) { + return policy, nil } } diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index a197430..c14a9f5 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -50,7 +50,7 @@ func TestListenAndServe(t *testing.T) { verifyStatus(t, s, StatusClosed) } -func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testing.T) { +func TestListenAndServe_proxyProtocolEnabled(t *testing.T) { target, err := net.ResolveTCPAddr("tcp", serverUrl) require.NoError(t, err) @@ -70,10 +70,11 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin }() testCases := []struct { - desc string - proxyPolicy string - header *proxyproto.Header - isRejected bool + desc string + proxyPolicy string + proxyAllowed []string + header *proxyproto.Header + isRejected bool }{ { desc: "USE (default) without a header", @@ -123,11 +124,65 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin header: header, isRejected: false, }, + { + desc: "Allow-listed IP with a header", + proxyAllowed: []string{"127.0.0.1"}, + header: header, + isRejected: false, + }, + { + desc: "Allow-listed IP without a header", + proxyAllowed: []string{"127.0.0.1"}, + header: nil, + isRejected: false, + }, + { + desc: "Allow-listed range with a header", + proxyAllowed: []string{"127.0.0.0/24"}, + header: header, + isRejected: false, + }, + { + desc: "Allow-listed range without a header", + proxyAllowed: []string{"127.0.0.0/24"}, + header: nil, + isRejected: false, + }, + { + desc: "Not allow-listed IP with a header", + proxyAllowed: []string{"192.168.1.1"}, + header: header, + isRejected: true, + }, + { + desc: "Not allow-listed IP without a header", + proxyAllowed: []string{"192.168.1.1"}, + header: nil, + isRejected: false, + }, + { + desc: "Not allow-listed range with a header", + proxyAllowed: []string{"192.168.1.0/24"}, + header: header, + isRejected: true, + }, + { + desc: "Not allow-listed range without a header", + proxyAllowed: []string{"192.168.1.0/24"}, + header: nil, + isRejected: false, + }, } for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - setupServerWithConfig(t, &config.Config{Server: config.ServerConfig{ProxyProtocol: true, ProxyPolicy: tc.proxyPolicy}}) + setupServerWithConfig(t, &config.Config{ + Server: config.ServerConfig{ + ProxyProtocol: true, + ProxyPolicy: tc.proxyPolicy, + ProxyAllowed: tc.proxyAllowed, + }, + }) conn, err := net.DialTCP("tcp", nil, target) require.NoError(t, err) |