summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/sshd/sshd.go5
-rw-r--r--internal/sshd/sshd_test.go42
2 files changed, 42 insertions, 5 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index 19fa661..d765faf 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -94,6 +94,7 @@ func (s *Server) listen(ctx context.Context) error {
if s.Config.Server.ProxyProtocol {
sshListener = &proxyproto.Listener{
Listener: sshListener,
+ Policy: unconditionalRequirePolicy,
ReadHeaderTimeout: ProxyHeaderTimeout,
}
@@ -185,3 +186,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
ctxlog.Info("server: handleConn: done")
}
+
+func unconditionalRequirePolicy(_ net.Addr) (proxyproto.Policy, error) {
+ return proxyproto.REQUIRE, nil
+}
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index 71f7733..455a830 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -47,6 +47,19 @@ func TestListenAndServe(t *testing.T) {
verifyStatus(t, s, StatusClosed)
}
+func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testing.T) {
+ s := setupServerWithProxyProtocolEnabled(t)
+ defer s.Shutdown()
+
+ client, err := ssh.Dial("tcp", serverUrl, clientConfig(t))
+ if client != nil {
+ client.Close()
+ }
+
+ require.Error(t, err, "Expected plain SSH request to be failed")
+ require.Equal(t, err.Error(), "ssh: handshake failed: EOF")
+}
+
func TestCorrelationId(t *testing.T) {
setupServer(t)
@@ -125,6 +138,18 @@ func TestInvalidServerConfig(t *testing.T) {
func setupServer(t *testing.T) *Server {
t.Helper()
+ return setupServerWithConfig(t, nil)
+}
+
+func setupServerWithProxyProtocolEnabled(t *testing.T) *Server {
+ t.Helper()
+
+ return setupServerWithConfig(t, &config.Config{Server: config.ServerConfig{ProxyProtocol: true}})
+}
+
+func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server {
+ t.Helper()
+
requests := []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/authorized_keys",
@@ -148,13 +173,20 @@ func setupServer(t *testing.T) *Server {
testhelper.PrepareTestRootDir(t)
url := testserver.StartSocketHttpServer(t, requests)
- srvCfg := config.ServerConfig{
- Listen: serverUrl,
- ConcurrentSessionsLimit: 1,
- HostKeyFiles: []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")},
+
+ if cfg == nil {
+ cfg = &config.Config{}
}
- s, err := NewServer(&config.Config{User: user, RootDir: "/tmp", GitlabUrl: url, Server: srvCfg})
+ // All things that don't need to be configurable in tests yet
+ cfg.GitlabUrl = url
+ cfg.RootDir = "/tmp"
+ cfg.User = user
+ cfg.Server.Listen = serverUrl
+ cfg.Server.ConcurrentSessionsLimit = 1
+ cfg.Server.HostKeyFiles = []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")}
+
+ s, err := NewServer(cfg)
require.NoError(t, err)
go func() { require.NoError(t, s.ListenAndServe(context.Background())) }()