1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
|
package channel
import (
"testing"
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
)
// checkerSeries builds an AuthCheckerFunc that returns the supplied settings
// one per call, in the order given. Once every value has been handed out,
// each further call returns nil.
//
// The closure advances an unsynchronized counter, so it should only be
// invoked from one goroutine at a time.
func checkerSeries(values ...*api.ChannelSettings) AuthCheckerFunc {
	next := 0
	return func() *api.ChannelSettings {
		if next >= len(values) {
			return nil
		}
		current := values[next]
		next++
		return current
	}
}
// TestAuthCheckerStopsWhenAuthFails checks that the auth checker keeps running
// while the auth function returns settings matching the template. It then
// expects ErrAuthChanged on stopCh once the function starts returning nil.
func TestAuthCheckerStopsWhenAuthFails(t *testing.T) {
	template := &api.ChannelSettings{Url: "ws://example.com"}
	stopCh := make(chan error)
	// Three matching results, then nil on every later call.
	series := checkerSeries(template, template, template)
	ac := NewAuthChecker(series, template, stopCh)

	go ac.Loop(1 * time.Millisecond)

	// Bound the wait so a checker that never reports fails this test promptly
	// instead of hanging until the global `go test` timeout.
	select {
	case err := <-stopCh:
		if err != ErrAuthChanged {
			t.Fatalf("Expected ErrAuthChanged, got %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("Timed out waiting for the auth checker to stop")
	}

	// NOTE(review): reading ac.Count here is race-free only if Loop updates it
	// in its own goroutine before sending on stopCh. Confirm against
	// AuthChecker.Loop.
	if ac.Count != 3 {
		t.Fatalf("Expected 3 successful checks, got %v", ac.Count)
	}
}
// TestAuthCheckerStopsWhenAuthChanges checks that the auth checker reports
// ErrAuthChanged on stopCh as soon as the auth function returns settings that
// differ from the template (here, a different Url). Only the single check
// before the change should be counted as successful.
func TestAuthCheckerStopsWhenAuthChanges(t *testing.T) {
	template := &api.ChannelSettings{Url: "ws://example.com"}
	changed := template.Clone()
	changed.Url = "wss://example.com"
	stopCh := make(chan error)
	// One matching result, then a mismatch that should stop the checker.
	series := checkerSeries(template, changed, template)
	ac := NewAuthChecker(series, template, stopCh)

	go ac.Loop(1 * time.Millisecond)

	// Bound the wait so a checker that never reports fails this test promptly
	// instead of hanging until the global `go test` timeout.
	select {
	case err := <-stopCh:
		if err != ErrAuthChanged {
			t.Fatalf("Expected ErrAuthChanged, got %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("Timed out waiting for the auth checker to stop")
	}

	// NOTE(review): reading ac.Count here is race-free only if Loop updates it
	// in its own goroutine before sending on stopCh. Confirm against
	// AuthChecker.Loop.
	if ac.Count != 1 {
		t.Fatalf("Expected 1 successful check, got %v", ac.Count)
	}
}
|