summaryrefslogtreecommitdiff
path: root/workhorse/internal/channel/auth_checker_test.go
blob: 18beb45cf3a969cff85cdfd47b3b42a9d56aae36 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
package channel

import (
	"testing"
	"time"

	"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
)

// checkerSeries builds a fake AuthCheckerFunc for tests. Successive calls hand
// back the supplied settings one at a time, in argument order. After the last
// value has been returned, every further call yields nil.
func checkerSeries(values ...*api.ChannelSettings) AuthCheckerFunc {
	next := 0
	return func() *api.ChannelSettings {
		if next >= len(values) {
			// Series exhausted: report "no settings".
			return nil
		}
		current := values[next]
		next++
		return current
	}
}

// TestAuthCheckerStopsWhenAuthFails checks the failure path. The checker is fed
// three settings identical to the template. The fourth check gets nil, because
// checkerSeries has run out of values. The test expects that nil result to stop
// the loop with ErrAuthChanged after exactly three successful checks.
func TestAuthCheckerStopsWhenAuthFails(t *testing.T) {
	template := &api.ChannelSettings{Url: "ws://example.com"}
	stopCh := make(chan error)
	series := checkerSeries(template, template, template)
	ac := NewAuthChecker(series, template, stopCh)

	go ac.Loop(1 * time.Millisecond)

	// Bound the wait so a checker that never signals fails this test promptly.
	// Otherwise it would hang the whole test binary until the global -timeout.
	select {
	case err := <-stopCh:
		if err != ErrAuthChanged {
			t.Fatalf("Expected ErrAuthChanged, got %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("Timed out waiting for the auth checker to stop")
	}

	// NOTE(review): this read relies on Loop not touching Count after it has
	// sent on stopCh. Given that, the unbuffered send/receive above orders
	// Loop's Count updates before this read.
	if ac.Count != 3 {
		t.Fatalf("Expected 3 successful checks, got %v", ac.Count)
	}
}

// TestAuthCheckerStopsWhenAuthChanges checks the "settings changed" path. The
// first check returns the template unchanged and counts as a success. The
// second returns a copy whose Url differs, which must stop the loop with
// ErrAuthChanged. The third series value is never expected to be consumed.
func TestAuthCheckerStopsWhenAuthChanges(t *testing.T) {
	template := &api.ChannelSettings{Url: "ws://example.com"}
	changed := template.Clone()
	changed.Url = "wss://example.com"
	stopCh := make(chan error)
	series := checkerSeries(template, changed, template)
	ac := NewAuthChecker(series, template, stopCh)

	go ac.Loop(1 * time.Millisecond)

	// Bound the wait so a checker that never signals fails this test promptly.
	// Otherwise it would hang the whole test binary until the global -timeout.
	select {
	case err := <-stopCh:
		if err != ErrAuthChanged {
			t.Fatalf("Expected ErrAuthChanged, got %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("Timed out waiting for the auth checker to stop")
	}

	// NOTE(review): this read relies on Loop not touching Count after it has
	// sent on stopCh. Given that, the unbuffered send/receive above orders
	// Loop's Count updates before this read.
	if ac.Count != 1 {
		t.Fatalf("Expected 1 successful check, got %v", ac.Count)
	}
}