diff options
Diffstat (limited to 'workhorse/internal/channel')
-rw-r--r-- | workhorse/internal/channel/auth_checker.go | 69 | ||||
-rw-r--r-- | workhorse/internal/channel/auth_checker_test.go | 53 | ||||
-rw-r--r-- | workhorse/internal/channel/channel.go | 132 | ||||
-rw-r--r-- | workhorse/internal/channel/proxy.go | 56 | ||||
-rw-r--r-- | workhorse/internal/channel/wrappers.go | 134 | ||||
-rw-r--r-- | workhorse/internal/channel/wrappers_test.go | 155 |
6 files changed, 599 insertions, 0 deletions
diff --git a/workhorse/internal/channel/auth_checker.go b/workhorse/internal/channel/auth_checker.go new file mode 100644 index 00000000000..f44850e0861 --- /dev/null +++ b/workhorse/internal/channel/auth_checker.go @@ -0,0 +1,69 @@ +package channel + +import ( + "errors" + "net/http" + "time" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" +) + +type AuthCheckerFunc func() *api.ChannelSettings + +// Regularly checks that authorization is still valid for a channel, outputting +// to the stopper when it isn't +type AuthChecker struct { + Checker AuthCheckerFunc + Template *api.ChannelSettings + StopCh chan error + Done chan struct{} + Count int64 +} + +var ErrAuthChanged = errors.New("connection closed: authentication changed or endpoint unavailable") + +func NewAuthChecker(f AuthCheckerFunc, template *api.ChannelSettings, stopCh chan error) *AuthChecker { + return &AuthChecker{ + Checker: f, + Template: template, + StopCh: stopCh, + Done: make(chan struct{}), + } +} +func (c *AuthChecker) Loop(interval time.Duration) { + for { + select { + case <-time.After(interval): + settings := c.Checker() + if !c.Template.IsEqual(settings) { + c.StopCh <- ErrAuthChanged + return + } + c.Count = c.Count + 1 + case <-c.Done: + return + } + } +} + +func (c *AuthChecker) Close() error { + close(c.Done) + return nil +} + +// Generates a CheckerFunc from an *api.API + request needing authorization +func authCheckFunc(myAPI *api.API, r *http.Request, suffix string) AuthCheckerFunc { + return func() *api.ChannelSettings { + httpResponse, authResponse, err := myAPI.PreAuthorize(suffix, r) + if err != nil { + return nil + } + defer httpResponse.Body.Close() + + if httpResponse.StatusCode != http.StatusOK || authResponse == nil { + return nil + } + + return authResponse.Channel + } +} diff --git a/workhorse/internal/channel/auth_checker_test.go b/workhorse/internal/channel/auth_checker_test.go new file mode 100644 index 00000000000..18beb45cf3a --- /dev/null +++ b/workhorse/internal/channel/auth_checker_test.go @@ -0,0 +1,53 @@ +package channel + +import ( + "testing" + "time" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" +) + +func checkerSeries(values ...*api.ChannelSettings) AuthCheckerFunc { + return func() *api.ChannelSettings { + if len(values) == 0 { + return nil + } + out := values[0] + values = values[1:] + return out + } +} + +func TestAuthCheckerStopsWhenAuthFails(t *testing.T) { + template := &api.ChannelSettings{Url: "ws://example.com"} + stopCh := make(chan error) + series := checkerSeries(template, template, template) + ac := NewAuthChecker(series, template, stopCh) + + go ac.Loop(1 * time.Millisecond) + if err := <-stopCh; err != ErrAuthChanged { + t.Fatalf("Expected ErrAuthChanged, got %v", err) + } + + if ac.Count != 3 { + t.Fatalf("Expected 3 successful checks, got %v", ac.Count) + } +} + +func TestAuthCheckerStopsWhenAuthChanges(t *testing.T) { + template := &api.ChannelSettings{Url: "ws://example.com"} + changed := template.Clone() + changed.Url = "wss://example.com" + stopCh := make(chan error) + series := checkerSeries(template, changed, template) + ac := NewAuthChecker(series, template, stopCh) + + go ac.Loop(1 * time.Millisecond) + if err := <-stopCh; err != ErrAuthChanged { + t.Fatalf("Expected ErrAuthChanged, got %v", err) + } + + if ac.Count != 1 { + t.Fatalf("Expected 1 successful check, got %v", ac.Count) + } +} diff --git a/workhorse/internal/channel/channel.go b/workhorse/internal/channel/channel.go new file mode 100644 index 00000000000..381ce95df82 --- /dev/null +++ b/workhorse/internal/channel/channel.go @@ -0,0 +1,132 @@ +package channel + +import ( + "fmt" + "net/http" + "time" + + "github.com/gorilla/websocket" + + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" +) + +var ( + // See doc/channel.md for documentation of this subprotocol + subprotocols = []string{"terminal.gitlab.com", "base64.terminal.gitlab.com"} + upgrader = &websocket.Upgrader{Subprotocols: subprotocols} + ReauthenticationInterval = 5 * time.Minute + BrowserPingInterval = 30 * time.Second +) + +func Handler(myAPI *api.API) http.Handler { + return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) { + if err := a.Channel.Validate(); err != nil { + helper.Fail500(w, r, err) + return + } + + proxy := NewProxy(2) // two stoppers: auth checker, max time + checker := NewAuthChecker( + authCheckFunc(myAPI, r, "authorize"), + a.Channel, + proxy.StopCh, + ) + defer checker.Close() + go checker.Loop(ReauthenticationInterval) + go closeAfterMaxTime(proxy, a.Channel.MaxSessionTime) + + ProxyChannel(w, r, a.Channel, proxy) + }, "authorize") +} + +func ProxyChannel(w http.ResponseWriter, r *http.Request, settings *api.ChannelSettings, proxy *Proxy) { + server, err := connectToServer(settings, r) + if err != nil { + helper.Fail500(w, r, err) + log.ContextLogger(r.Context()).WithError(err).Print("Channel: connecting to server failed") + return + } + defer server.UnderlyingConn().Close() + serverAddr := server.UnderlyingConn().RemoteAddr().String() + + client, err := upgradeClient(w, r) + if err != nil { + log.ContextLogger(r.Context()).WithError(err).Print("Channel: upgrading client to websocket failed") + return + } + + // Regularly send ping messages to the browser to keep the websocket from + // being timed out by intervening proxies. + go pingLoop(client) + + defer client.UnderlyingConn().Close() + clientAddr := getClientAddr(r) // We can't know the port with confidence + + logEntry := log.WithContextFields(r.Context(), log.Fields{ + "clientAddr": clientAddr, + "serverAddr": serverAddr, + }) + + logEntry.Print("Channel: started proxying") + + defer logEntry.Print("Channel: finished proxying") + + if err := proxy.Serve(server, client, serverAddr, clientAddr); err != nil { + logEntry.WithError(err).Print("Channel: error proxying") + } +} + +// In the future, we might want to look at X-Client-Ip or X-Forwarded-For +func getClientAddr(r *http.Request) string { + return r.RemoteAddr +} + +func upgradeClient(w http.ResponseWriter, r *http.Request) (Connection, error) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return nil, err + } + + return Wrap(conn, conn.Subprotocol()), nil +} + +func pingLoop(conn Connection) { + for { + time.Sleep(BrowserPingInterval) + deadline := time.Now().Add(5 * time.Second) + if err := conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil { + // Either the connection was already closed so no further pings are + // needed, or this connection is now dead and no further pings can + // be sent. + break + } + } +} + +func connectToServer(settings *api.ChannelSettings, r *http.Request) (Connection, error) { + settings = settings.Clone() + + helper.SetForwardedFor(&settings.Header, r) + + conn, _, err := settings.Dial() + if err != nil { + return nil, err + } + + return Wrap(conn, conn.Subprotocol()), nil +} + +func closeAfterMaxTime(proxy *Proxy, maxSessionTime int) { + if maxSessionTime == 0 { + return + } + + <-time.After(time.Duration(maxSessionTime) * time.Second) + proxy.StopCh <- fmt.Errorf( + "connection closed: session time greater than maximum time allowed - %v seconds", + maxSessionTime, + ) +} diff --git a/workhorse/internal/channel/proxy.go b/workhorse/internal/channel/proxy.go new file mode 100644 index 00000000000..71f58092276 --- /dev/null +++ b/workhorse/internal/channel/proxy.go @@ -0,0 +1,56 @@ +package channel + +import ( + "fmt" + "net" + "time" + + "github.com/gorilla/websocket" +) + +// ANSI "end of channel" code +var eot = []byte{0x04} + +// An abstraction of gorilla's *websocket.Conn +type Connection interface { + UnderlyingConn() net.Conn + ReadMessage() (int, []byte, error) + WriteMessage(int, []byte) error + WriteControl(int, []byte, time.Time) error +} + +type Proxy struct { + StopCh chan error +} + +// stoppers is the number of goroutines that may attempt to call Stop() +func NewProxy(stoppers int) *Proxy { + return &Proxy{ + StopCh: make(chan error, stoppers+2), // each proxy() call is a stopper + } +} + +func (p *Proxy) Serve(upstream, downstream Connection, upstreamAddr, downstreamAddr string) error { + // This signals the upstream channel to kill the exec'd process + defer upstream.WriteMessage(websocket.BinaryMessage, eot) + + go p.proxy(upstream, downstream, upstreamAddr, downstreamAddr) + go p.proxy(downstream, upstream, downstreamAddr, upstreamAddr) + + return <-p.StopCh +} + +func (p *Proxy) proxy(to, from Connection, toAddr, fromAddr string) { + for { + messageType, data, err := from.ReadMessage() + if err != nil { + p.StopCh <- fmt.Errorf("reading from %s: %s", fromAddr, err) + break + } + + if err := to.WriteMessage(messageType, data); err != nil { + p.StopCh <- fmt.Errorf("writing to %s: %s", toAddr, err) + break + } + } +} diff --git a/workhorse/internal/channel/wrappers.go b/workhorse/internal/channel/wrappers.go new file mode 100644 index 00000000000..6fd955bedc7 --- /dev/null +++ b/workhorse/internal/channel/wrappers.go @@ -0,0 +1,134 @@ +package channel + +import ( + "encoding/base64" + "net" + "time" + + "github.com/gorilla/websocket" +) + +func Wrap(conn Connection, subprotocol string) Connection { + switch subprotocol { + case "channel.k8s.io": + return &kubeWrapper{base64: false, conn: conn} + case "base64.channel.k8s.io": + return &kubeWrapper{base64: true, conn: conn} + case "terminal.gitlab.com": + return &gitlabWrapper{base64: false, conn: conn} + case "base64.terminal.gitlab.com": + return &gitlabWrapper{base64: true, conn: conn} + } + + return conn +} + +type kubeWrapper struct { + base64 bool + conn Connection +} + +type gitlabWrapper struct { + base64 bool + conn Connection +} + +func (w *gitlabWrapper) ReadMessage() (int, []byte, error) { + mt, data, err := w.conn.ReadMessage() + if err != nil { + return mt, data, err + } + + if isData(mt) { + mt = websocket.BinaryMessage + if w.base64 { + data, err = decodeBase64(data) + } + } + + return mt, data, err +} + +func (w *gitlabWrapper) WriteMessage(mt int, data []byte) error { + if isData(mt) { + if w.base64 { + mt = websocket.TextMessage + data = encodeBase64(data) + } else { + mt = websocket.BinaryMessage + } + } + + return w.conn.WriteMessage(mt, data) +} + +func (w *gitlabWrapper) WriteControl(mt int, data []byte, deadline time.Time) error { + return w.conn.WriteControl(mt, data, deadline) +} + +func (w *gitlabWrapper) UnderlyingConn() net.Conn { + return w.conn.UnderlyingConn() +} + +// Coalesces all wsstreams into a single stream. In practice, we should only +// receive data on stream 1. +func (w *kubeWrapper) ReadMessage() (int, []byte, error) { + mt, data, err := w.conn.ReadMessage() + if err != nil { + return mt, data, err + } + + if isData(mt) { + mt = websocket.BinaryMessage + + // Remove the WSStream channel number, decode to raw + if len(data) > 0 { + data = data[1:] + if w.base64 { + data, err = decodeBase64(data) + } + } + } + + return mt, data, err +} + +// Always sends to wsstream 0 +func (w *kubeWrapper) WriteMessage(mt int, data []byte) error { + if isData(mt) { + if w.base64 { + mt = websocket.TextMessage + data = append([]byte{'0'}, encodeBase64(data)...) + } else { + mt = websocket.BinaryMessage + data = append([]byte{0}, data...) + } + } + + return w.conn.WriteMessage(mt, data) +} + +func (w *kubeWrapper) WriteControl(mt int, data []byte, deadline time.Time) error { + return w.conn.WriteControl(mt, data, deadline) +} + +func (w *kubeWrapper) UnderlyingConn() net.Conn { + return w.conn.UnderlyingConn() +} + +func isData(mt int) bool { + return mt == websocket.BinaryMessage || mt == websocket.TextMessage +} + +func encodeBase64(data []byte) []byte { + buf := make([]byte, base64.StdEncoding.EncodedLen(len(data))) + base64.StdEncoding.Encode(buf, data) + + return buf +} + +func decodeBase64(data []byte) ([]byte, error) { + buf := make([]byte, base64.StdEncoding.DecodedLen(len(data))) + n, err := base64.StdEncoding.Decode(buf, data) + return buf[:n], err +} diff --git a/workhorse/internal/channel/wrappers_test.go b/workhorse/internal/channel/wrappers_test.go new file mode 100644 index 00000000000..1e0226f85d8 --- /dev/null +++ b/workhorse/internal/channel/wrappers_test.go @@ -0,0 +1,155 @@ +package channel + +import ( + "bytes" + "errors" + "net" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +type testcase struct { + input *fakeConn + expected *fakeConn +} + +type fakeConn struct { + // WebSocket message type + mt int + data []byte + err error +} + +func (f *fakeConn) ReadMessage() (int, []byte, error) { + return f.mt, f.data, f.err +} + +func (f *fakeConn) WriteMessage(mt int, data []byte) error { + f.mt = mt + f.data = data + return f.err +} + +func (f *fakeConn) WriteControl(mt int, data []byte, _ time.Time) error { + f.mt = mt + f.data = data + return f.err +} + +func (f *fakeConn) UnderlyingConn() net.Conn { + return nil +} + +func fake(mt int, data []byte, err error) *fakeConn { + return &fakeConn{mt: mt, data: []byte(data), err: err} +} + +var ( + msg = []byte("foo bar") + msgBase64 = []byte("Zm9vIGJhcg==") + kubeMsg = append([]byte{0}, msg...) + kubeMsgBase64 = append([]byte{'0'}, msgBase64...) + + errFake = errors.New("fake error") + + text = websocket.TextMessage + binary = websocket.BinaryMessage + other = 999 + + fakeOther = fake(other, []byte("foo"), nil) +) + +func requireEqualConn(t *testing.T, expected, actual *fakeConn, msg string, args ...interface{}) { + if expected.mt != actual.mt { + t.Logf("messageType expected to be %v but was %v", expected.mt, actual.mt) + t.Fatalf(msg, args...) + } + + if !bytes.Equal(expected.data, actual.data) { + t.Logf("data expected to be %q but was %q: ", expected.data, actual.data) + t.Fatalf(msg, args...) + } + + if expected.err != actual.err { + t.Logf("error expected to be %v but was %v", expected.err, actual.err) + t.Fatalf(msg, args...) + } +} + +func TestReadMessage(t *testing.T) { + testCases := map[string][]testcase{ + "channel.k8s.io": { + {fake(binary, kubeMsg, errFake), fake(binary, kubeMsg, errFake)}, + {fake(binary, kubeMsg, nil), fake(binary, msg, nil)}, + {fake(text, kubeMsg, nil), fake(binary, msg, nil)}, + {fakeOther, fakeOther}, + }, + "base64.channel.k8s.io": { + {fake(text, kubeMsgBase64, errFake), fake(text, kubeMsgBase64, errFake)}, + {fake(text, kubeMsgBase64, nil), fake(binary, msg, nil)}, + {fake(binary, kubeMsgBase64, nil), fake(binary, msg, nil)}, + {fakeOther, fakeOther}, + }, + "terminal.gitlab.com": { + {fake(binary, msg, errFake), fake(binary, msg, errFake)}, + {fake(binary, msg, nil), fake(binary, msg, nil)}, + {fake(text, msg, nil), fake(binary, msg, nil)}, + {fakeOther, fakeOther}, + }, + "base64.terminal.gitlab.com": { + {fake(text, msgBase64, errFake), fake(text, msgBase64, errFake)}, + {fake(text, msgBase64, nil), fake(binary, msg, nil)}, + {fake(binary, msgBase64, nil), fake(binary, msg, nil)}, + {fakeOther, fakeOther}, + }, + } + + for subprotocol, cases := range testCases { + for i, tc := range cases { + conn := Wrap(tc.input, subprotocol) + mt, data, err := conn.ReadMessage() + actual := fake(mt, data, err) + requireEqualConn(t, tc.expected, actual, "%s test case %v", subprotocol, i) + } + } +} + +func TestWriteMessage(t *testing.T) { + testCases := map[string][]testcase{ + "channel.k8s.io": { + {fake(binary, msg, errFake), fake(binary, kubeMsg, errFake)}, + {fake(binary, msg, nil), fake(binary, kubeMsg, nil)}, + {fake(text, msg, nil), fake(binary, kubeMsg, nil)}, + {fakeOther, fakeOther}, + }, + "base64.channel.k8s.io": { + {fake(binary, msg, errFake), fake(text, kubeMsgBase64, errFake)}, + {fake(binary, msg, nil), fake(text, kubeMsgBase64, nil)}, + {fake(text, msg, nil), fake(text, kubeMsgBase64, nil)}, + {fakeOther, fakeOther}, + }, + "terminal.gitlab.com": { + {fake(binary, msg, errFake), fake(binary, msg, errFake)}, + {fake(binary, msg, nil), fake(binary, msg, nil)}, + {fake(text, msg, nil), fake(binary, msg, nil)}, + {fakeOther, fakeOther}, + }, + "base64.terminal.gitlab.com": { + {fake(binary, msg, errFake), fake(text, msgBase64, errFake)}, + {fake(binary, msg, nil), fake(text, msgBase64, nil)}, + {fake(text, msg, nil), fake(text, msgBase64, nil)}, + {fakeOther, fakeOther}, + }, + } + + for subprotocol, cases := range testCases { + for i, tc := range cases { + actual := fake(0, nil, tc.input.err) + conn := Wrap(actual, subprotocol) + actual.err = conn.WriteMessage(tc.input.mt, tc.input.data) + requireEqualConn(t, tc.expected, actual, "%s test case %v", subprotocol, i) + } + } +} |