summaryrefslogtreecommitdiff
path: root/workhorse/internal/api/channel_settings_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/internal/api/channel_settings_test.go')
-rw-r--r--workhorse/internal/api/channel_settings_test.go154
1 files changed, 154 insertions, 0 deletions
diff --git a/workhorse/internal/api/channel_settings_test.go b/workhorse/internal/api/channel_settings_test.go
new file mode 100644
index 00000000000..4aa2c835579
--- /dev/null
+++ b/workhorse/internal/api/channel_settings_test.go
@@ -0,0 +1,154 @@
+package api
+
+import (
+ "net/http"
+ "testing"
+)
+
+func channel(url string, subprotocols ...string) *ChannelSettings {
+ return &ChannelSettings{
+ Url: url,
+ Subprotocols: subprotocols,
+ MaxSessionTime: 0,
+ }
+}
+
+func ca(channel *ChannelSettings) *ChannelSettings {
+ channel = channel.Clone()
+ channel.CAPem = "Valid CA data"
+
+ return channel
+}
+
+func timeout(channel *ChannelSettings) *ChannelSettings {
+ channel = channel.Clone()
+ channel.MaxSessionTime = 600
+
+ return channel
+}
+
+func header(channel *ChannelSettings, values ...string) *ChannelSettings {
+ if len(values) == 0 {
+ values = []string{"Dummy Value"}
+ }
+
+ channel = channel.Clone()
+ channel.Header = http.Header{
+ "Header": values,
+ }
+
+ return channel
+}
+
+func TestClone(t *testing.T) {
+ a := ca(header(channel("ws:", "", "")))
+ b := a.Clone()
+
+ if a == b {
+ t.Fatalf("Address of cloned channel didn't change")
+ }
+
+ if &a.Subprotocols == &b.Subprotocols {
+ t.Fatalf("Address of cloned subprotocols didn't change")
+ }
+
+ if &a.Header == &b.Header {
+ t.Fatalf("Address of cloned header didn't change")
+ }
+}
+
+func TestValidate(t *testing.T) {
+ for i, tc := range []struct {
+ channel *ChannelSettings
+ valid bool
+ msg string
+ }{
+ {nil, false, "nil channel"},
+ {channel("", ""), false, "empty URL"},
+ {channel("ws:"), false, "empty subprotocols"},
+ {channel("ws:", "foo"), true, "any subprotocol"},
+ {channel("ws:", "foo", "bar"), true, "multiple subprotocols"},
+ {channel("ws:", ""), true, "websocket URL"},
+ {channel("wss:", ""), true, "secure websocket URL"},
+ {channel("http:", ""), false, "HTTP URL"},
+ {channel("https:", ""), false, " HTTPS URL"},
+ {ca(channel("ws:", "")), true, "any CA pem"},
+ {header(channel("ws:", "")), true, "any headers"},
+ {ca(header(channel("ws:", ""))), true, "PEM and headers"},
+ } {
+ if err := tc.channel.Validate(); (err != nil) == tc.valid {
+ t.Fatalf("test case %d: "+tc.msg+": valid=%v: %s: %+v", i, tc.valid, err, tc.channel)
+ }
+ }
+}
+
+func TestDialer(t *testing.T) {
+ channel := channel("ws:", "foo")
+ dialer := channel.Dialer()
+
+ if len(dialer.Subprotocols) != len(channel.Subprotocols) {
+ t.Fatalf("Subprotocols don't match: %+v vs. %+v", channel.Subprotocols, dialer.Subprotocols)
+ }
+
+ for i, subprotocol := range channel.Subprotocols {
+ if dialer.Subprotocols[i] != subprotocol {
+ t.Fatalf("Subprotocols don't match: %+v vs. %+v", channel.Subprotocols, dialer.Subprotocols)
+ }
+ }
+
+ if dialer.TLSClientConfig != nil {
+ t.Fatalf("Unexpected TLSClientConfig: %+v", dialer)
+ }
+
+ channel = ca(channel)
+ dialer = channel.Dialer()
+
+ if dialer.TLSClientConfig == nil || dialer.TLSClientConfig.RootCAs == nil {
+ t.Fatalf("Custom CA certificates not recognised!")
+ }
+}
+
+func TestIsEqual(t *testing.T) {
+ chann := channel("ws:", "foo")
+
+ chann_header2 := header(chann, "extra")
+ chann_header3 := header(chann)
+ chann_header3.Header.Add("Extra", "extra")
+
+ chann_ca2 := ca(chann)
+ chann_ca2.CAPem = "other value"
+
+ for i, tc := range []struct {
+ channelA *ChannelSettings
+ channelB *ChannelSettings
+ expected bool
+ }{
+ {nil, nil, true},
+ {chann, nil, false},
+ {nil, chann, false},
+ {chann, chann, true},
+ {chann.Clone(), chann.Clone(), true},
+ {chann, channel("foo:"), false},
+ {chann, channel(chann.Url), false},
+ {header(chann), header(chann), true},
+ {chann_header2, chann_header2, true},
+ {chann_header3, chann_header3, true},
+ {header(chann), chann_header2, false},
+ {header(chann), chann_header3, false},
+ {header(chann), chann, false},
+ {chann, header(chann), false},
+ {ca(chann), ca(chann), true},
+ {ca(chann), chann, false},
+ {chann, ca(chann), false},
+ {ca(header(chann)), ca(header(chann)), true},
+ {chann_ca2, ca(chann), false},
+ {chann, timeout(chann), false},
+ } {
+ if actual := tc.channelA.IsEqual(tc.channelB); tc.expected != actual {
+ t.Fatalf(
+ "test case %d: Comparison:\n-%+v\n+%+v\nexpected=%v: actual=%v",
+ i, tc.channelA, tc.channelB, tc.expected, actual,
+ )
+ }
+ }
+}