diff options
author | James Lacey <jamlacey@gmail.com> | 2020-04-06 09:17:59 -0700 |
---|---|---|
committer | Jens Geyer <jensg@apache.org> | 2020-04-09 21:52:46 +0200 |
commit | 6bbdb1a46ce6ba0ac4e27e29b2c9c9eef107186c (patch) | |
tree | 7c0d2a28567427313cb857e607460c87afba52cd /lib/d | |
parent | f44b6ee8db9342d804c6ba01da9953e791021bfe (diff) | |
download | thrift-6bbdb1a46ce6ba0ac4e27e29b2c9c9eef107186c.tar.gz |
THRIFT-5166: Add support for using WebSockets as a server transport.
Client: d
Patch: James Lacey
This closes #2087
Diffstat (limited to 'lib/d')
-rw-r--r-- | lib/d/src/thrift/transport/websocket.d | 388 |
1 files changed, 388 insertions, 0 deletions
diff --git a/lib/d/src/thrift/transport/websocket.d b/lib/d/src/thrift/transport/websocket.d new file mode 100644 index 000000000..25d1f7d01 --- /dev/null +++ b/lib/d/src/thrift/transport/websocket.d @@ -0,0 +1,388 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ +module thrift.transport.websocket; + +import std.algorithm; +import std.algorithm.searching; +import std.base64; +import std.bitmanip; +import std.conv; +import std.digest.sha; +import std.stdio; +import std.string; +import std.uni; +import thrift.base : VERSION; +import thrift.transport.base; +import thrift.transport.http; + +/** + * WebSocket server transport. + */ +final class TServerWebSocketTransport(bool binary) : THttpTransport { + /** + * Constructs a new instance. + * + * Param: + * transport = The underlying transport used for the actual I/O. + */ + this(TTransport transport) { + super(transport); + transport_ = transport; + } + + override size_t read(ubyte[] buf) { + // If we do not have a good handshake, the client will attempt one. + if (!handshakeComplete) { + resetHandshake(); + super.read(buf); + // If we did not get everything we expected, the handshake failed + // and we need to send a 400 response back. + if (!handshakeComplete) { + sendBadRequest(); + return 0; + } + // Otherwise, send back the 101 response. + super.flush(); + } + + // If the buffer is empty, read a new frame off the wire. + if (readBuffer_.empty) { + if (!readFrame()) { + return 0; + } + } + + auto size = min(readBuffer_.length, buf.length); + buf[0..size] = readBuffer_[0..size]; + readBuffer_ = readBuffer_[size..$]; + return size; + } + + override void write(in ubyte[] buf) { + writeBuffer_ ~= buf; + } + + override void flush() { + if (writeBuffer_.empty) { + return; + } + + // Properly reset the write buffer even some of the protocol operations go + // wrong. + scope (exit) { + writeBuffer_.length = 0; + writeBuffer_.assumeSafeAppend(); + } + + writeFrameHeader(); + transport_.write(writeBuffer_); + transport_.flush(); + } + +protected: + override string getHeader(size_t dataLength) { + return "HTTP/1.1 101 Switching Protocols\r\n" ~ + "Server: Thrift/" ~ VERSION ~ "\r\n" ~ + "Upgrade: websocket\r\n" ~ + "Connection: Upgrade\r\n" ~ + "Sec-WebSocket-Accept: " ~ acceptKey_ ~ "\r\n" ~ + "\r\n"; + } + + override void parseHeader(const(ubyte)[] header) { + auto split = findSplit(header, [':']); + if (split[1].empty) { + // No colon found. + return; + } + + static bool compToLower(ubyte a, ubyte b) { + return toLower(a) == toLower(b); + } + + if (startsWith!compToLower(split[0], cast(ubyte[])"upgrade")) { + auto upgrade = stripLeft(cast(const(char)[])split[2]); + upgrade_ = sicmp(upgrade, "websocket") == 0; + } else if (startsWith!compToLower(split[0], cast(ubyte[])"connection")) { + auto connection = stripLeft(cast(const(char)[])split[2]); + connection_ = sicmp(connection, "upgrade") == 0; + } else if (startsWith!compToLower(split[0], cast(ubyte[])"sec-websocket-key")) { + auto secWebSocketKey = stripLeft(cast(const(char)[])split[2]); + auto hash = sha1Of(secWebSocketKey ~ WEBSOCKET_GUID); + acceptKey_ = Base64.encode(hash); + secWebSocketKey_ = true; + } else if (startsWith!compToLower(split[0], cast(ubyte[])"sec-websocket-version")) { + auto secWebSocketVersion = stripLeft(cast(const(char)[])split[2]); + secWebSocketVersion_ = sicmp(secWebSocketVersion, "13") == 0; + } + } + + override bool parseStatusLine(const(ubyte)[] status) { + // Method SP Request-URI SP HTTP-Version CRLF. + auto split = findSplit(status, [' ']); + if (split[1].empty) { + throw new TTransportException("Bad status: " ~ to!string(status), + TTransportException.Type.CORRUPTED_DATA); + } + + auto uriVersion = split[2][countUntil!"a != b"(split[2], ' ') .. $]; + if (!canFind(uriVersion, ' ')) { + throw new TTransportException("Bad status: " ~ to!string(status), + TTransportException.Type.CORRUPTED_DATA); + } + + if (split[0] == "GET") { + // GET method ok, looking for content. + return true; + } + + throw new TTransportException("Bad status (unsupported method): " ~ + to!string(status), TTransportException.Type.CORRUPTED_DATA); + } + +private: + @property bool handshakeComplete() { + return upgrade_ && connection_ && secWebSocketKey_ && secWebSocketVersion_; + } + + void failConnection(CloseCode reason) { + writeFrameHeader(Opcode.Close); + transport_.write(nativeToBigEndian!ushort(reason)); + transport_.flush(); + transport_.close(); + } + + void pong() { + writeFrameHeader(Opcode.Pong); + transport_.write(readBuffer_); + transport_.flush(); + } + + bool readFrame() { + ubyte[8] headerBuffer; + + auto read = transport_.read(headerBuffer[0..2]); + if (read < 2) { + return false; + } + // Since Thrift has its own message end marker and we read frame by frame, + // it doesn't really matter if the frame is marked as FIN. + // Capture it only for debugging only. + debug auto fin = (headerBuffer[0] & 0x80) != 0; + + // RSV1, RSV2, RSV3 + if ((headerBuffer[0] & 0x70) != 0) { + failConnection(CloseCode.ProtocolError); + throw new TTransportException("Reserved bits must be zeroes", TTransportException.Type.CORRUPTED_DATA); + } + + Opcode opcode; + try { + opcode = to!Opcode(headerBuffer[0] & 0x0F); + } catch (ConvException) { + failConnection(CloseCode.ProtocolError); + throw new TTransportException("Unknown opcode", TTransportException.Type.CORRUPTED_DATA); + } + + // Mask + if ((headerBuffer[1] & 0x80) == 0) { + failConnection(CloseCode.ProtocolError); + throw new TTransportException("Messages from the client must be masked", TTransportException.Type.CORRUPTED_DATA); + } + + // Read the length + ulong payloadLength = headerBuffer[1] & 0x7F; + if (payloadLength == 126) { + read = transport_.read(headerBuffer[0..2]); + if (read < 2) { + return false; + } + payloadLength = bigEndianToNative!ushort(headerBuffer[0..2]); + } else if (payloadLength == 127) { + read = transport_.read(headerBuffer); + if (read < headerBuffer.length) { + return false; + } + payloadLength = bigEndianToNative!ulong(headerBuffer); + if ((payloadLength & 0x8000000000000000) != 0) { + failConnection(CloseCode.ProtocolError); + throw new TTransportException("The most significant bit of the payload length must be zero", + TTransportException.Type.CORRUPTED_DATA); + } + } + + // size_t is smaller than a ulong on a 32-bit system + static if (size_t.max < ulong.max) { + if(payloadLength > size_t.max) { + failConnection(CloseCode.MessageTooBig); + return false; + } + } + + auto length = cast(size_t)payloadLength; + + if (length > 0) { + // Read the masking key + read = transport_.read(headerBuffer[0..4]); + if (read < 4) { + return false; + } + + readBuffer_ = new ubyte[](length); + read = transport_.read(readBuffer_); + if (read < length) { + return false; + } + + // Unmask the data + for (size_t i = 0; i < length; i++) { + readBuffer_[i] ^= headerBuffer[i % 4]; + } + + debug writef("FIN=%d, Opcode=%X, length=%d, payload=%s\n", + fin, + opcode, + length, + binary ? readBuffer_.toHexString() : cast(string)readBuffer_); + } + + switch (opcode) { + case Opcode.Close: + debug { + if (length >= 2) { + CloseCode closeCode; + try { + closeCode = to!CloseCode(bigEndianToNative!ushort(readBuffer_[0..2])); + } catch (ConvException) { + closeCode = CloseCode.NoStatusCode; + } + + string closeReason; + if (length == 2) { + closeReason = to!string(cast(CloseCode)closeCode); + } else { + closeReason = cast(string)readBuffer_[2..$]; + } + + writef("Connection closed: %d %s\n", closeCode, closeReason); + } + } + transport_.close(); + return false; + case Opcode.Ping: + pong(); + return readFrame(); + default: + return true; + } + } + + void resetHandshake() { + connection_ = false; + secWebSocketKey_ = false; + secWebSocketVersion_ = false; + upgrade_ = false; + } + + void sendBadRequest() { + auto header = "HTTP/1.1 400 Bad Request\r\n" ~ + "Server: Thrift/" ~ VERSION ~ "\r\n" ~ + "\r\n"; + transport_.write(cast(const(ubyte[]))header); + transport_.flush(); + transport_.close(); + } + + void writeFrameHeader(Opcode opcode = Opcode.Continuation) { + size_t headerSize = 1; + if (writeBuffer_.length < 126) { + ++headerSize; + } else if (writeBuffer_.length < 65536) { + headerSize += 3; + } else { + headerSize += 9; + } + // The server does not mask the response + + ubyte[] header = new ubyte[headerSize]; + if (opcode == Opcode.Continuation) { + header[0] = binary ? Opcode.Binary : Opcode.Text; + } + else { + header[0] = opcode; + } + header[0] |= 0x80; + if (writeBuffer_.length < 126) { + header[1] = cast(ubyte)writeBuffer_.length; + } else if (writeBuffer_.length < 65536) { + header[1] = 126; + header[2..4] = nativeToBigEndian(cast(ushort)writeBuffer_.length); + } else { + header[1] = 127; + header[2..10] = nativeToBigEndian(cast(ulong)writeBuffer_.length); + } + + transport_.write(header); + } + + enum WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + TTransport transport_; + + string acceptKey_; + bool connection_; + bool secWebSocketKey_; + bool secWebSocketVersion_; + bool upgrade_; + ubyte[] readBuffer_; + ubyte[] writeBuffer_; +} + +class TServerWebSocketTransportFactory(bool binary) : TTransportFactory { + override TTransport getTransport(TTransport trans) { + return new TServerWebSocketTransport!binary(trans); + } +} + +alias TServerBinaryWebSocketTransportFactory = TServerWebSocketTransportFactory!true; +alias TServerTextWebSocketTransportFactory = TServerWebSocketTransportFactory!false; + +private enum CloseCode : ushort { + NormalClosure = 1000, + GoingAway = 1001, + ProtocolError = 1002, + UnsupportedDataType = 1003, + NoStatusCode = 1005, + AbnormalClosure = 1006, + InvalidData = 1007, + PolicyViolation = 1008, + MessageTooBig = 1009, + ExtensionExpected = 1010, + UnexpectedError = 1011, + NotSecure = 1015 +} + +private enum Opcode : ubyte { + Continuation = 0x0, + Text = 0x1, + Binary = 0x2, + Close = 0x8, + Ping = 0x9, + Pong = 0xA +} |