summaryrefslogtreecommitdiff
path: root/workhorse/internal/channel/wrappers.go
blob: 6fd955bedc7366a2e23eb8a4deacbe2a7adb67b8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package channel

import (
	"encoding/base64"
	"net"
	"time"

	"github.com/gorilla/websocket"
)

// Wrap adapts conn to the framing required by the negotiated websocket
// subprotocol.
//
// The Kubernetes channel protocols ("channel.k8s.io" and its base64 variant)
// are handled by kubeWrapper. The GitLab terminal protocols
// ("terminal.gitlab.com" and its base64 variant) are handled by gitlabWrapper.
// The "base64." prefix selects base64 text payloads instead of raw binary.
// Any other subprotocol leaves conn unwrapped.
func Wrap(conn Connection, subprotocol string) Connection {
	switch subprotocol {
	case "channel.k8s.io", "base64.channel.k8s.io":
		return &kubeWrapper{
			base64: subprotocol == "base64.channel.k8s.io",
			conn:   conn,
		}
	case "terminal.gitlab.com", "base64.terminal.gitlab.com":
		return &gitlabWrapper{
			base64: subprotocol == "base64.terminal.gitlab.com",
			conn:   conn,
		}
	default:
		return conn
	}
}

// kubeWrapper adapts a Connection to the Kubernetes channel subprotocols
// ("channel.k8s.io" / "base64.channel.k8s.io"). On data messages it strips
// the one-byte stream prefix when reading and adds it when writing.
//
// Fields:
//   - base64: set for "base64.channel.k8s.io". Payloads are base64 text and
//     the outgoing stream prefix is the ASCII digit '0' rather than a raw 0 byte.
//   - conn: the wrapped connection that frames are read from and written to.
type kubeWrapper struct {
	base64 bool
	conn   Connection
}

// gitlabWrapper adapts a Connection to the GitLab terminal subprotocols
// ("terminal.gitlab.com" / "base64.terminal.gitlab.com"). Unlike kubeWrapper,
// it adds no stream prefix. It only normalises the message type and, in base64
// mode, encodes/decodes the payload.
//
// Fields:
//   - base64: set for "base64.terminal.gitlab.com". Data is sent as base64
//     text messages and decoded from base64 on receipt.
//   - conn: the wrapped connection that frames are read from and written to.
type gitlabWrapper struct {
	base64 bool
	conn   Connection
}

// ReadMessage reads one frame from the wrapped connection.
//
// Read errors and non-data (control) frames are returned exactly as received.
// Data frames are always reported as websocket.BinaryMessage. In base64 mode
// their payload is decoded first. A decode failure returns the bytes decoded
// so far together with the decode error.
func (w *gitlabWrapper) ReadMessage() (int, []byte, error) {
	mt, data, err := w.conn.ReadMessage()
	if err != nil || !isData(mt) {
		return mt, data, err
	}

	if !w.base64 {
		return websocket.BinaryMessage, data, nil
	}

	decoded, err := decodeBase64(data)
	return websocket.BinaryMessage, decoded, err
}

// WriteMessage writes one frame to the wrapped connection.
//
// Data messages are rewritten for the negotiated variant:
//   - base64 mode: base64-encoded and sent as websocket.TextMessage;
//   - binary mode: sent unchanged as websocket.BinaryMessage.
//
// Any other message type is forwarded untouched.
func (w *gitlabWrapper) WriteMessage(mt int, data []byte) error {
	switch {
	case !isData(mt):
		// Not a data frame: pass through as-is.
	case w.base64:
		mt, data = websocket.TextMessage, encodeBase64(data)
	default:
		mt = websocket.BinaryMessage
	}

	return w.conn.WriteMessage(mt, data)
}

// WriteControl forwards a control frame (e.g. ping/pong/close) unchanged to the
// wrapped connection. gitlabWrapper only rewrites data messages.
func (w *gitlabWrapper) WriteControl(mt int, data []byte, deadline time.Time) error {
	return w.conn.WriteControl(mt, data, deadline)
}

// UnderlyingConn returns the net.Conn reported by the wrapped connection.
func (w *gitlabWrapper) UnderlyingConn() net.Conn {
	return w.conn.UnderlyingConn()
}

// ReadMessage reads one frame from the wrapped connection and unwraps the
// Kubernetes channel framing.
//
// Read errors and non-data (control) frames are returned exactly as received.
// Data frames are always reported as websocket.BinaryMessage:
//   - the leading stream-number byte is dropped without inspection, so all
//     streams are coalesced into a single one (in practice data is only
//     expected on stream 1);
//   - in base64 mode the remainder is then base64-decoded;
//   - an empty data frame is returned as an empty binary message.
func (w *kubeWrapper) ReadMessage() (int, []byte, error) {
	mt, data, err := w.conn.ReadMessage()
	if err != nil || !isData(mt) {
		return mt, data, err
	}

	if len(data) == 0 {
		return websocket.BinaryMessage, data, nil
	}

	payload := data[1:]
	if !w.base64 {
		return websocket.BinaryMessage, payload, nil
	}

	decoded, decodeErr := decodeBase64(payload)
	return websocket.BinaryMessage, decoded, decodeErr
}

// WriteMessage writes one frame to the wrapped connection, always targeting
// stream 0 (conventionally stdin for Kubernetes exec/attach).
//
// Data messages are prefixed with the stream number:
//   - base64 mode: the ASCII digit '0' followed by the base64 text, sent as
//     websocket.TextMessage;
//   - binary mode: a raw 0 byte followed by the payload, sent as
//     websocket.BinaryMessage.
//
// Any other message type is forwarded untouched.
func (w *kubeWrapper) WriteMessage(mt int, data []byte) error {
	if !isData(mt) {
		return w.conn.WriteMessage(mt, data)
	}

	var frame []byte
	if w.base64 {
		mt = websocket.TextMessage
		frame = append([]byte{'0'}, encodeBase64(data)...)
	} else {
		mt = websocket.BinaryMessage
		frame = append([]byte{0}, data...)
	}

	return w.conn.WriteMessage(mt, frame)
}

// WriteControl forwards a control frame (e.g. ping/pong/close) unchanged to the
// wrapped connection. No stream prefix is added; only data messages are framed.
func (w *kubeWrapper) WriteControl(mt int, data []byte, deadline time.Time) error {
	return w.conn.WriteControl(mt, data, deadline)
}

// UnderlyingConn returns the net.Conn reported by the wrapped connection.
func (w *kubeWrapper) UnderlyingConn() net.Conn {
	return w.conn.UnderlyingConn()
}

// isData reports whether mt is a websocket data message type (binary or text),
// as opposed to a control frame.
func isData(mt int) bool {
	switch mt {
	case websocket.BinaryMessage, websocket.TextMessage:
		return true
	default:
		return false
	}
}

// encodeBase64 returns the padded standard base64 encoding
// (base64.StdEncoding) of data in a newly allocated slice.
func encodeBase64(data []byte) []byte {
	enc := base64.StdEncoding
	encoded := make([]byte, enc.EncodedLen(len(data)))
	enc.Encode(encoded, data)
	return encoded
}

// decodeBase64 decodes padded standard base64 (base64.StdEncoding) from data.
//
// On corrupt input it returns the bytes successfully decoded before the
// failure together with the decoding error.
func decodeBase64(data []byte) ([]byte, error) {
	enc := base64.StdEncoding
	decoded := make([]byte, enc.DecodedLen(len(data)))

	written, err := enc.Decode(decoded, data)
	if err != nil {
		return decoded[:written], err
	}
	return decoded[:written], nil
}