summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorIgor Drozdov <idrozdov@gitlab.com>2022-05-10 23:16:22 +0400
committerIgor Drozdov <idrozdov@gitlab.com>2022-05-10 23:23:53 +0400
commit709c5dd75a7c1a2a0f3296d76ddc654191841213 (patch)
treed80a8b1ed3d340116770122b99b56bf43d2bad88 /internal
parent733845f9abec43b6573ba3a1167cc27ff2bfc199 (diff)
downloadgitlab-shell-709c5dd75a7c1a2a0f3296d76ddc654191841213.tar.gz
Make PROXY policy configurable
It would give us more flexibility when we decide to enable PROXY protocol
Diffstat (limited to 'internal')
-rw-r--r--internal/config/config.go1
-rw-r--r--internal/sshd/sshd.go18
-rw-r--r--internal/sshd/sshd_test.go106
3 files changed, 110 insertions, 15 deletions
diff --git a/internal/config/config.go b/internal/config/config.go
index ff0c79a..ab88d72 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -24,6 +24,7 @@ const (
type ServerConfig struct {
Listen string `yaml:"listen,omitempty"`
ProxyProtocol bool `yaml:"proxy_protocol,omitempty"`
+ ProxyPolicy string `yaml:"proxy_policy,omitempty"`
WebListen string `yaml:"web_listen,omitempty"`
ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"`
GracePeriodSeconds uint64 `yaml:"grace_period"`
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index 49b8ab9..c2758f0 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/http"
+ "strings"
"sync"
"time"
@@ -95,7 +96,7 @@ func (s *Server) listen(ctx context.Context) error {
if s.Config.Server.ProxyProtocol {
sshListener = &proxyproto.Listener{
Listener: sshListener,
- Policy: unconditionalRequirePolicy,
+ Policy: s.requirePolicy,
ReadHeaderTimeout: ProxyHeaderTimeout,
}
@@ -210,6 +211,17 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
}).Info("server: handleConn: done")
}
-func unconditionalRequirePolicy(_ net.Addr) (proxyproto.Policy, error) {
- return proxyproto.REQUIRE, nil
+func (s *Server) requirePolicy(_ net.Addr) (proxyproto.Policy, error) {
+ // Set the Policy value based on config
+ // Values are taken from https://github.com/pires/go-proxyproto/blob/195fedcfbfc1be163f3a0d507fac1709e9d81fed/policy.go#L20
+ switch strings.ToLower(s.Config.Server.ProxyPolicy) {
+ case "require":
+ return proxyproto.REQUIRE, nil
+ case "ignore":
+ return proxyproto.IGNORE, nil
+ case "reject":
+ return proxyproto.REJECT, nil
+ default:
+ return proxyproto.USE, nil
+ }
}
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index 0c6a8ec..d725add 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -3,6 +3,7 @@ package sshd
import (
"context"
"fmt"
+ "net"
"net/http"
"net/http/httptest"
"os"
@@ -10,6 +11,7 @@ import (
"testing"
"time"
+ "github.com/pires/go-proxyproto"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
@@ -48,15 +50,101 @@ func TestListenAndServe(t *testing.T) {
}
func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testing.T) {
- setupServerWithProxyProtocolEnabled(t)
+ target, err := net.ResolveTCPAddr("tcp", serverUrl)
+ require.NoError(t, err)
- client, err := ssh.Dial("tcp", serverUrl, clientConfig(t))
- if client != nil {
- client.Close()
+ header := &proxyproto.Header{
+ Version: 2,
+ Command: proxyproto.PROXY,
+ TransportProtocol: proxyproto.TCPv4,
+ SourceAddr: &net.TCPAddr{
+ IP: net.ParseIP("10.1.1.1"),
+ Port: 1000,
+ },
+ DestinationAddr: target,
+ }
+
+ testCases := []struct {
+ desc string
+ proxyPolicy string
+ header *proxyproto.Header
+ isRejected bool
+ }{
+ {
+ desc: "USE (default) without a header",
+ proxyPolicy: "",
+ header: nil,
+ isRejected: false,
+ },
+ {
+ desc: "USE (default) with a header",
+ proxyPolicy: "",
+ header: header,
+ isRejected: false,
+ },
+ {
+ desc: "REQUIRE without a header",
+ proxyPolicy: "require",
+ header: nil,
+ isRejected: true,
+ },
+ {
+ desc: "REQUIRE with a header",
+ proxyPolicy: "require",
+ header: header,
+ isRejected: false,
+ },
+ {
+ desc: "REJECT without a header",
+ proxyPolicy: "reject",
+ header: nil,
+ isRejected: false,
+ },
+ {
+ desc: "REJECT with a header",
+ proxyPolicy: "reject",
+ header: header,
+ isRejected: true,
+ },
+ {
+ desc: "IGNORE without a header",
+ proxyPolicy: "ignore",
+ header: nil,
+ isRejected: false,
+ },
+ {
+ desc: "IGNORE with a header",
+ proxyPolicy: "ignore",
+ header: header,
+ isRejected: false,
+ },
}
- require.Error(t, err, "Expected plain SSH request to be failed")
- require.Regexp(t, "ssh: handshake failed", err.Error())
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ setupServerWithConfig(t, &config.Config{Server: config.ServerConfig{ProxyProtocol: true, ProxyPolicy: tc.proxyPolicy}})
+
+ conn, err := net.DialTCP("tcp", nil, target)
+ require.NoError(t, err)
+
+ if tc.header != nil {
+ _, err := header.WriteTo(conn)
+ require.NoError(t, err)
+ }
+
+ sshConn, _, _, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t))
+ if sshConn != nil {
+ sshConn.Close()
+ }
+
+ if tc.isRejected {
+ require.Error(t, err, "Expected plain SSH request to be failed")
+ require.Regexp(t, "ssh: handshake failed", err.Error())
+ } else {
+ require.NoError(t, err)
+ }
+ })
+ }
}
func TestCorrelationId(t *testing.T) {
@@ -140,12 +228,6 @@ func setupServer(t *testing.T) *Server {
return setupServerWithConfig(t, nil)
}
-func setupServerWithProxyProtocolEnabled(t *testing.T) *Server {
- t.Helper()
-
- return setupServerWithConfig(t, &config.Config{Server: config.ServerConfig{ProxyProtocol: true}})
-}
-
func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server {
t.Helper()