summaryrefslogtreecommitdiff
path: root/lib/d
diff options
context:
space:
mode:
authorJames Lacey <jamlacey@gmail.com>2020-04-06 09:17:59 -0700
committerJens Geyer <jensg@apache.org>2020-04-09 21:52:46 +0200
commit6bbdb1a46ce6ba0ac4e27e29b2c9c9eef107186c (patch)
tree7c0d2a28567427313cb857e607460c87afba52cd /lib/d
parentf44b6ee8db9342d804c6ba01da9953e791021bfe (diff)
downloadthrift-6bbdb1a46ce6ba0ac4e27e29b2c9c9eef107186c.tar.gz
THRIFT-5166: Add support for using WebSockets as a server transport.
Client: d Patch: James Lacey This closes #2087
Diffstat (limited to 'lib/d')
-rw-r--r--lib/d/src/thrift/transport/websocket.d388
1 files changed, 388 insertions, 0 deletions
diff --git a/lib/d/src/thrift/transport/websocket.d b/lib/d/src/thrift/transport/websocket.d
new file mode 100644
index 000000000..25d1f7d01
--- /dev/null
+++ b/lib/d/src/thrift/transport/websocket.d
@@ -0,0 +1,388 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+module thrift.transport.websocket;
+
+import std.algorithm;
+import std.algorithm.searching;
+import std.base64;
+import std.bitmanip;
+import std.conv;
+import std.digest.sha;
+import std.stdio;
+import std.string;
+import std.uni;
+import thrift.base : VERSION;
+import thrift.transport.base;
+import thrift.transport.http;
+
+/**
+ * WebSocket server transport.
+ */
+final class TServerWebSocketTransport(bool binary) : THttpTransport {
+ /**
+ * Constructs a new instance.
+ *
+ * Param:
+ * transport = The underlying transport used for the actual I/O.
+ */
+ this(TTransport transport) {
+ super(transport);
+ transport_ = transport;
+ }
+
+ override size_t read(ubyte[] buf) {
+ // If we do not have a good handshake, the client will attempt one.
+ if (!handshakeComplete) {
+ resetHandshake();
+ super.read(buf);
+ // If we did not get everything we expected, the handshake failed
+ // and we need to send a 400 response back.
+ if (!handshakeComplete) {
+ sendBadRequest();
+ return 0;
+ }
+ // Otherwise, send back the 101 response.
+ super.flush();
+ }
+
+ // If the buffer is empty, read a new frame off the wire.
+ if (readBuffer_.empty) {
+ if (!readFrame()) {
+ return 0;
+ }
+ }
+
+ auto size = min(readBuffer_.length, buf.length);
+ buf[0..size] = readBuffer_[0..size];
+ readBuffer_ = readBuffer_[size..$];
+ return size;
+ }
+
+ override void write(in ubyte[] buf) {
+ writeBuffer_ ~= buf;
+ }
+
+ override void flush() {
+ if (writeBuffer_.empty) {
+ return;
+ }
+
+ // Properly reset the write buffer even some of the protocol operations go
+ // wrong.
+ scope (exit) {
+ writeBuffer_.length = 0;
+ writeBuffer_.assumeSafeAppend();
+ }
+
+ writeFrameHeader();
+ transport_.write(writeBuffer_);
+ transport_.flush();
+ }
+
+protected:
+ override string getHeader(size_t dataLength) {
+ return "HTTP/1.1 101 Switching Protocols\r\n" ~
+ "Server: Thrift/" ~ VERSION ~ "\r\n" ~
+ "Upgrade: websocket\r\n" ~
+ "Connection: Upgrade\r\n" ~
+ "Sec-WebSocket-Accept: " ~ acceptKey_ ~ "\r\n" ~
+ "\r\n";
+ }
+
+ override void parseHeader(const(ubyte)[] header) {
+ auto split = findSplit(header, [':']);
+ if (split[1].empty) {
+ // No colon found.
+ return;
+ }
+
+ static bool compToLower(ubyte a, ubyte b) {
+ return toLower(a) == toLower(b);
+ }
+
+ if (startsWith!compToLower(split[0], cast(ubyte[])"upgrade")) {
+ auto upgrade = stripLeft(cast(const(char)[])split[2]);
+ upgrade_ = sicmp(upgrade, "websocket") == 0;
+ } else if (startsWith!compToLower(split[0], cast(ubyte[])"connection")) {
+ auto connection = stripLeft(cast(const(char)[])split[2]);
+ connection_ = sicmp(connection, "upgrade") == 0;
+ } else if (startsWith!compToLower(split[0], cast(ubyte[])"sec-websocket-key")) {
+ auto secWebSocketKey = stripLeft(cast(const(char)[])split[2]);
+ auto hash = sha1Of(secWebSocketKey ~ WEBSOCKET_GUID);
+ acceptKey_ = Base64.encode(hash);
+ secWebSocketKey_ = true;
+ } else if (startsWith!compToLower(split[0], cast(ubyte[])"sec-websocket-version")) {
+ auto secWebSocketVersion = stripLeft(cast(const(char)[])split[2]);
+ secWebSocketVersion_ = sicmp(secWebSocketVersion, "13") == 0;
+ }
+ }
+
+ override bool parseStatusLine(const(ubyte)[] status) {
+ // Method SP Request-URI SP HTTP-Version CRLF.
+ auto split = findSplit(status, [' ']);
+ if (split[1].empty) {
+ throw new TTransportException("Bad status: " ~ to!string(status),
+ TTransportException.Type.CORRUPTED_DATA);
+ }
+
+ auto uriVersion = split[2][countUntil!"a != b"(split[2], ' ') .. $];
+ if (!canFind(uriVersion, ' ')) {
+ throw new TTransportException("Bad status: " ~ to!string(status),
+ TTransportException.Type.CORRUPTED_DATA);
+ }
+
+ if (split[0] == "GET") {
+ // GET method ok, looking for content.
+ return true;
+ }
+
+ throw new TTransportException("Bad status (unsupported method): " ~
+ to!string(status), TTransportException.Type.CORRUPTED_DATA);
+ }
+
+private:
+ @property bool handshakeComplete() {
+ return upgrade_ && connection_ && secWebSocketKey_ && secWebSocketVersion_;
+ }
+
+ void failConnection(CloseCode reason) {
+ writeFrameHeader(Opcode.Close);
+ transport_.write(nativeToBigEndian!ushort(reason));
+ transport_.flush();
+ transport_.close();
+ }
+
+ void pong() {
+ writeFrameHeader(Opcode.Pong);
+ transport_.write(readBuffer_);
+ transport_.flush();
+ }
+
+ bool readFrame() {
+ ubyte[8] headerBuffer;
+
+ auto read = transport_.read(headerBuffer[0..2]);
+ if (read < 2) {
+ return false;
+ }
+ // Since Thrift has its own message end marker and we read frame by frame,
+ // it doesn't really matter if the frame is marked as FIN.
+ // Capture it only for debugging only.
+ debug auto fin = (headerBuffer[0] & 0x80) != 0;
+
+ // RSV1, RSV2, RSV3
+ if ((headerBuffer[0] & 0x70) != 0) {
+ failConnection(CloseCode.ProtocolError);
+ throw new TTransportException("Reserved bits must be zeroes", TTransportException.Type.CORRUPTED_DATA);
+ }
+
+ Opcode opcode;
+ try {
+ opcode = to!Opcode(headerBuffer[0] & 0x0F);
+ } catch (ConvException) {
+ failConnection(CloseCode.ProtocolError);
+ throw new TTransportException("Unknown opcode", TTransportException.Type.CORRUPTED_DATA);
+ }
+
+ // Mask
+ if ((headerBuffer[1] & 0x80) == 0) {
+ failConnection(CloseCode.ProtocolError);
+ throw new TTransportException("Messages from the client must be masked", TTransportException.Type.CORRUPTED_DATA);
+ }
+
+ // Read the length
+ ulong payloadLength = headerBuffer[1] & 0x7F;
+ if (payloadLength == 126) {
+ read = transport_.read(headerBuffer[0..2]);
+ if (read < 2) {
+ return false;
+ }
+ payloadLength = bigEndianToNative!ushort(headerBuffer[0..2]);
+ } else if (payloadLength == 127) {
+ read = transport_.read(headerBuffer);
+ if (read < headerBuffer.length) {
+ return false;
+ }
+ payloadLength = bigEndianToNative!ulong(headerBuffer);
+ if ((payloadLength & 0x8000000000000000) != 0) {
+ failConnection(CloseCode.ProtocolError);
+ throw new TTransportException("The most significant bit of the payload length must be zero",
+ TTransportException.Type.CORRUPTED_DATA);
+ }
+ }
+
+ // size_t is smaller than a ulong on a 32-bit system
+ static if (size_t.max < ulong.max) {
+ if(payloadLength > size_t.max) {
+ failConnection(CloseCode.MessageTooBig);
+ return false;
+ }
+ }
+
+ auto length = cast(size_t)payloadLength;
+
+ if (length > 0) {
+ // Read the masking key
+ read = transport_.read(headerBuffer[0..4]);
+ if (read < 4) {
+ return false;
+ }
+
+ readBuffer_ = new ubyte[](length);
+ read = transport_.read(readBuffer_);
+ if (read < length) {
+ return false;
+ }
+
+ // Unmask the data
+ for (size_t i = 0; i < length; i++) {
+ readBuffer_[i] ^= headerBuffer[i % 4];
+ }
+
+ debug writef("FIN=%d, Opcode=%X, length=%d, payload=%s\n",
+ fin,
+ opcode,
+ length,
+ binary ? readBuffer_.toHexString() : cast(string)readBuffer_);
+ }
+
+ switch (opcode) {
+ case Opcode.Close:
+ debug {
+ if (length >= 2) {
+ CloseCode closeCode;
+ try {
+ closeCode = to!CloseCode(bigEndianToNative!ushort(readBuffer_[0..2]));
+ } catch (ConvException) {
+ closeCode = CloseCode.NoStatusCode;
+ }
+
+ string closeReason;
+ if (length == 2) {
+ closeReason = to!string(cast(CloseCode)closeCode);
+ } else {
+ closeReason = cast(string)readBuffer_[2..$];
+ }
+
+ writef("Connection closed: %d %s\n", closeCode, closeReason);
+ }
+ }
+ transport_.close();
+ return false;
+ case Opcode.Ping:
+ pong();
+ return readFrame();
+ default:
+ return true;
+ }
+ }
+
+ void resetHandshake() {
+ connection_ = false;
+ secWebSocketKey_ = false;
+ secWebSocketVersion_ = false;
+ upgrade_ = false;
+ }
+
+ void sendBadRequest() {
+ auto header = "HTTP/1.1 400 Bad Request\r\n" ~
+ "Server: Thrift/" ~ VERSION ~ "\r\n" ~
+ "\r\n";
+ transport_.write(cast(const(ubyte[]))header);
+ transport_.flush();
+ transport_.close();
+ }
+
+ void writeFrameHeader(Opcode opcode = Opcode.Continuation) {
+ size_t headerSize = 1;
+ if (writeBuffer_.length < 126) {
+ ++headerSize;
+ } else if (writeBuffer_.length < 65536) {
+ headerSize += 3;
+ } else {
+ headerSize += 9;
+ }
+ // The server does not mask the response
+
+ ubyte[] header = new ubyte[headerSize];
+ if (opcode == Opcode.Continuation) {
+ header[0] = binary ? Opcode.Binary : Opcode.Text;
+ }
+ else {
+ header[0] = opcode;
+ }
+ header[0] |= 0x80;
+ if (writeBuffer_.length < 126) {
+ header[1] = cast(ubyte)writeBuffer_.length;
+ } else if (writeBuffer_.length < 65536) {
+ header[1] = 126;
+ header[2..4] = nativeToBigEndian(cast(ushort)writeBuffer_.length);
+ } else {
+ header[1] = 127;
+ header[2..10] = nativeToBigEndian(cast(ulong)writeBuffer_.length);
+ }
+
+ transport_.write(header);
+ }
+
+ enum WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+
+ TTransport transport_;
+
+ string acceptKey_;
+ bool connection_;
+ bool secWebSocketKey_;
+ bool secWebSocketVersion_;
+ bool upgrade_;
+ ubyte[] readBuffer_;
+ ubyte[] writeBuffer_;
+}
+
+class TServerWebSocketTransportFactory(bool binary) : TTransportFactory {
+ override TTransport getTransport(TTransport trans) {
+ return new TServerWebSocketTransport!binary(trans);
+ }
+}
+
+alias TServerBinaryWebSocketTransportFactory = TServerWebSocketTransportFactory!true;
+alias TServerTextWebSocketTransportFactory = TServerWebSocketTransportFactory!false;
+
+private enum CloseCode : ushort {
+ NormalClosure = 1000,
+ GoingAway = 1001,
+ ProtocolError = 1002,
+ UnsupportedDataType = 1003,
+ NoStatusCode = 1005,
+ AbnormalClosure = 1006,
+ InvalidData = 1007,
+ PolicyViolation = 1008,
+ MessageTooBig = 1009,
+ ExtensionExpected = 1010,
+ UnexpectedError = 1011,
+ NotSecure = 1015
+}
+
+private enum Opcode : ubyte {
+ Continuation = 0x0,
+ Text = 0x1,
+ Binary = 0x2,
+ Close = 0x8,
+ Ping = 0x9,
+ Pong = 0xA
+}