From 026975dc7fae48b8ba5d86f6c10b1112cf775d57 Mon Sep 17 00:00:00 2001 From: James Fargher Date: Tue, 8 Nov 2022 09:24:30 +1300 Subject: sshd: Add ProxyAllowed setting to limit PROXY protocol IP addresses Changelog: added --- internal/config/config.go | 1 + internal/sshd/sshd.go | 4 +++ internal/sshd/sshd_test.go | 67 +++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 66 insertions(+), 6 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index debab3f..35d8e74 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -27,6 +27,7 @@ type ServerConfig struct { Listen string `yaml:"listen,omitempty"` ProxyProtocol bool `yaml:"proxy_protocol,omitempty"` ProxyPolicy string `yaml:"proxy_policy,omitempty"` + ProxyAllowed []string `yaml:"proxy_allowed,omitempty"` WebListen string `yaml:"web_listen,omitempty"` ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"` ClientAliveInterval YamlDuration `yaml:"client_alive_interval,omitempty"` diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 19dc96a..c61d527 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -201,6 +201,10 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { } func (s *Server) requirePolicy() proxyproto.PolicyFunc { + if len(s.Config.Server.ProxyAllowed) > 0 { + return proxyproto.MustStrictWhiteListPolicy(s.Config.Server.ProxyAllowed) + } + // Set the Policy value based on config // Values are taken from https://github.com/pires/go-proxyproto/blob/195fedcfbfc1be163f3a0d507fac1709e9d81fed/policy.go#L20 switch strings.ToLower(s.Config.Server.ProxyPolicy) { diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index a197430..c14a9f5 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -50,7 +50,7 @@ func TestListenAndServe(t *testing.T) { verifyStatus(t, s, StatusClosed) } -func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testing.T) { +func TestListenAndServe_proxyProtocolEnabled(t *testing.T) { target, err := net.ResolveTCPAddr("tcp", serverUrl) require.NoError(t, err) @@ -70,10 +70,11 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin }() testCases := []struct { - desc string - proxyPolicy string - header *proxyproto.Header - isRejected bool + desc string + proxyPolicy string + proxyAllowed []string + header *proxyproto.Header + isRejected bool }{ { desc: "USE (default) without a header", @@ -123,11 +124,65 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin header: header, isRejected: false, }, + { + desc: "Allow-listed IP with a header", + proxyAllowed: []string{"127.0.0.1"}, + header: header, + isRejected: false, + }, + { + desc: "Allow-listed IP without a header", + proxyAllowed: []string{"127.0.0.1"}, + header: nil, + isRejected: false, + }, + { + desc: "Allow-listed range with a header", + proxyAllowed: []string{"127.0.0.0/24"}, + header: header, + isRejected: false, + }, + { + desc: "Allow-listed range without a header", + proxyAllowed: []string{"127.0.0.0/24"}, + header: nil, + isRejected: false, + }, + { + desc: "Not allow-listed IP with a header", + proxyAllowed: []string{"192.168.1.1"}, + header: header, + isRejected: true, + }, + { + desc: "Not allow-listed IP without a header", + proxyAllowed: []string{"192.168.1.1"}, + header: nil, + isRejected: false, + }, + { + desc: "Not allow-listed range with a header", + proxyAllowed: []string{"192.168.1.0/24"}, + header: header, + isRejected: true, + }, + { + desc: "Not allow-listed range without a header", + proxyAllowed: []string{"192.168.1.0/24"}, + header: nil, + isRejected: false, + }, } for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - setupServerWithConfig(t, &config.Config{Server: config.ServerConfig{ProxyProtocol: true, ProxyPolicy: tc.proxyPolicy}}) + setupServerWithConfig(t, &config.Config{ + Server: config.ServerConfig{ + ProxyProtocol: true, + ProxyPolicy: tc.proxyPolicy, + ProxyAllowed: tc.proxyAllowed, + }, + }) conn, err := net.DialTCP("tcp", nil, target) require.NoError(t, err) -- cgit v1.2.1