From bc99911f0bbe0d0c18d46bd0ad44c0b136a162ff Mon Sep 17 00:00:00 2001 From: Spencer Jackson Date: Thu, 15 Feb 2018 15:30:46 -0500 Subject: SERVER-33329: Make server and shell emit TLS protocol_version alerts (cherry picked from commit 51af489a86f1862de87b51f26a9e818ec3b5df04) (cherry picked from commit 56e653fdd204e1ad091e0736454aefc005b5ce3f) --- jstests/ssl/ssl_alert_reporting.js | 48 +++++++++++ src/mongo/util/net/message_port.cpp | 12 +++ src/mongo/util/net/ssl_manager.cpp | 158 ++++++++++++++++++++++++++++++++++-- src/mongo/util/net/ssl_manager.h | 6 ++ 4 files changed, 216 insertions(+), 8 deletions(-) create mode 100644 jstests/ssl/ssl_alert_reporting.js diff --git a/jstests/ssl/ssl_alert_reporting.js b/jstests/ssl/ssl_alert_reporting.js new file mode 100644 index 00000000000..f5ca5650896 --- /dev/null +++ b/jstests/ssl/ssl_alert_reporting.js @@ -0,0 +1,48 @@ +// Ensure that TLS version alerts are correctly propagated + +load('jstests/ssl/libs/ssl_helpers.js'); + +(function() { + 'use strict'; + + const clientOptions = [ + "--ssl", + "--sslPEMKeyFile", + "jstests/libs/client.pem", + "--sslCAFile", + "jstests/libs/ca.pem", + "--eval", + ";" + ]; + + function runTest(serverDisabledProtos, clientDisabledProtos) { + let expectedRegex = /tlsv1 alert protocol version/; + + var md = MongoRunner.runMongod({ + nopreallocj: "", + sslMode: "requireSSL", + sslCAFile: "jstests/libs/ca.pem", + sslPEMKeyFile: "jstests/libs/server.pem", + sslDisabledProtocols: serverDisabledProtos, + waitForConnect: false, + }); + + assert.soon(function() { + clearRawMongoProgramOutput(); + let shell = runMongoProgram("mongo", + "--port", + md.port, + ...clientOptions, + "--sslDisabledProtocols", + clientDisabledProtos); + let mongoOutput = rawMongoProgramOutput(); + return mongoOutput.match(expectedRegex); + }); + + MongoRunner.stopMongod(md); + } + + // Client recieves and reports a protocol version alert if it advertises a protocol older than + // the server's oldest supported protocol + runTest("TLS1_0", "TLS1_1,TLS1_2"); +}()); diff --git a/src/mongo/util/net/message_port.cpp b/src/mongo/util/net/message_port.cpp index 0f4f543087e..28c81da19ff 100644 --- a/src/mongo/util/net/message_port.cpp +++ b/src/mongo/util/net/message_port.cpp @@ -126,6 +126,18 @@ bool MessagingPort::recv(Message& m) { uassert(17132, "SSL handshake received but server is started without SSL support", sslGlobalParams.sslMode.load() != SSLParams::SSLMode_disabled); + + auto tlsAlert = checkTLSRequest( + ConstDataRange(reinterpret_cast(&header), sizeof(header))); + + if (tlsAlert) { + _psock->send(reinterpret_cast(tlsAlert->data()), + tlsAlert->size(), + "tls protocol mismatch"); + log() << "SSL handshake failed, as client requested disabled protocol"; + return false; + } + setX509PeerInfo( _psock->doSSLHandshake(reinterpret_cast(&header), sizeof(header))); LOG(1) << "new ssl connection, SNI server name [" << _psock->getSNIServerName() diff --git a/src/mongo/util/net/ssl_manager.cpp b/src/mongo/util/net/ssl_manager.cpp index 4bb0befb4b0..025f4932867 100644 --- a/src/mongo/util/net/ssl_manager.cpp +++ b/src/mongo/util/net/ssl_manager.cpp @@ -53,6 +53,7 @@ #include "mongo/util/exit.h" #include "mongo/util/log.h" #include "mongo/util/mongoutils/str.h" +#include "mongo/util/net/message.h" #include "mongo/util/net/sock.h" #include "mongo/util/net/socket_exception.h" #include "mongo/util/net/ssl_expiration.h" @@ -327,7 +328,7 @@ private: * Given an error code from an SSL-type IO function, logs an * appropriate message and throws a SocketException */ - MONGO_COMPILER_NORETURN void _handleSSLError(int code, int ret); + MONGO_COMPILER_NORETURN void _handleSSLError(SSLConnection* conn, int ret); /* * Init the SSL context using parameters provided in params. This SSL context will @@ -625,7 +626,7 @@ int SSLManager::SSL_read(SSLConnection* conn, void* buf, int num) { } while (!_doneWithSSLOp(conn, status)); if (status <= 0) - _handleSSLError(SSL_get_error(conn, status), status); + _handleSSLError(conn, status); return status; } @@ -636,7 +637,7 @@ int SSLManager::SSL_write(SSLConnection* conn, const void* buf, int num) { } while (!_doneWithSSLOp(conn, status)); if (status <= 0) - _handleSSLError(SSL_get_error(conn, status), status); + _handleSSLError(conn, status); return status; } @@ -659,7 +660,7 @@ int SSLManager::SSL_shutdown(SSLConnection* conn) { } while (!_doneWithSSLOp(conn, status)); if (status < 0) - _handleSSLError(SSL_get_error(conn, status), status); + _handleSSLError(conn, status); return status; } @@ -1166,14 +1167,14 @@ SSLConnection* SSLManager::connect(Socket* socket) { int ret = ::SSL_set_tlsext_host_name(sslConn->ssl, socket->remoteAddr().hostOrIp().c_str()); if (ret != 1) - _handleSSLError(SSL_get_error(sslConn.get(), ret), ret); + _handleSSLError(sslConn.get(), ret); do { ret = ::SSL_connect(sslConn->ssl); } while (!_doneWithSSLOp(sslConn.get(), ret)); if (ret != 1) - _handleSSLError(SSL_get_error(sslConn.get(), ret), ret); + _handleSSLError(sslConn.get(), ret); return sslConn.release(); } @@ -1188,7 +1189,7 @@ SSLConnection* SSLManager::accept(Socket* socket, const char* initialBytes, int } while (!_doneWithSSLOp(sslConn.get(), ret)); if (ret != 1) - _handleSSLError(SSL_get_error(sslConn.get(), ret), ret); + _handleSSLError(sslConn.get(), ret); return sslConn.release(); } @@ -1459,7 +1460,8 @@ std::string SSLManagerInterface::getSSLErrorMessage(int code) { return msg; } -void SSLManager::_handleSSLError(int code, int ret) { +void SSLManager::_handleSSLError(SSLConnection* conn, int ret) { + int code = SSL_get_error(conn, ret); int err = ERR_get_error(); switch (code) { @@ -1496,8 +1498,148 @@ void SSLManager::_handleSSLError(int code, int ret) { error() << "unrecognized SSL error"; break; } + _flushNetworkBIO(conn); throw SocketException(SocketException::CONNECT_ERROR, ""); } + +boost::optional> checkTLSRequest(ConstDataRange dataRange) { + // This method's caller should have read in at least one MSGHEADER::Value's worth of data. + // The fragment we are about to examine must be strictly smaller. + static const size_t sizeOfTLSFragmentToRead = 11; + invariant(dataRange.length() >= sizeOfTLSFragmentToRead); + + static_assert(sizeOfTLSFragmentToRead < sizeof(MSGHEADER::Value), + "checkTLSRequest's caller read a MSGHEADER::Value, which must be larger than " + "message containing the TLS version"); + + ConstDataRangeCursor cdr(dataRange); + + /** + * The fragment we are to examine is a record, containing a handshake, containing a + * ClientHello. We wish to examine the advertised protocol version in the ClientHello. + * The following roughly describes the contents of these structures. Note that we do not + * need, or wish to, examine the entire ClientHello, we're looking exclusively for the + * client_version. + * + * Below is a rough description of the payload we will be examining. We shall perform some + * basic checks to ensure the payload matches these expectations. If it does not, we should + * bail out, and not emit protocol version alerts. + * + * enum {alert(21), handshake(22)} ContentType; + * TLSPlaintext { + * ContentType type = handshake(22), + * ProtocolVersion version; // Irrelevant. Clients send the real version in ClientHello. + * uint16 length; + * fragment, see Handshake stuct for contents + * ... + * } + * + * enum {client_hello(1)} HandshakeType; + * Handshake { + * HandshakeType msg_type = client_hello(1); + * uint24_t length; + * ClientHello body; + * } + * + * ClientHello { + * ProtocolVersion client_version; // <- This is the value we want to extract. + * } + */ + + static const std::uint8_t ContentType_handshake = 22; + static const std::uint8_t HandshakeType_client_hello = 1; + + using ProtocolVersion = std::array; + static const ProtocolVersion tls10VersionBytes{3, 1}; + static const ProtocolVersion tls11VersionBytes{3, 2}; + + // Parse the record header. + // Extract the ContentType from the header, and ensure it is a handshake. + StatusWith record_ContentType = cdr.readAndAdvance(); + if (!record_ContentType.isOK() || record_ContentType.getValue() != ContentType_handshake) { + return boost::none; + } + // Skip the record's ProtocolVersion. Clients tend to send TLS 1.0 in + // the record, but then their real protocol version in the enclosed ClientHello. + StatusWith record_protocol_version = cdr.readAndAdvance(); + if (!record_protocol_version.isOK()) { + return boost::none; + } + // Parse the record length. It should be be larger than the remaining expected payload. + auto record_length = cdr.readAndAdvance>(); + if (!record_length.isOK() || record_length.getValue() < cdr.length()) { + return boost::none; + } + + // Parse the handshake header. + // Extract the HandshakeType, and ensure it is a ClientHello. + StatusWith handshake_type = cdr.readAndAdvance(); + if (!handshake_type.isOK() || handshake_type.getValue() != HandshakeType_client_hello) { + return boost::none; + } + // Extract the handshake length, and ensure it is larger than the remaining expected + // payload. This requires a little work because the packet represents it with a uint24_t. + StatusWith> handshake_length_bytes = + cdr.readAndAdvance>(); + if (!handshake_length_bytes.isOK()) { + return boost::none; + } + std::uint32_t handshake_length = 0; + for (std::uint8_t handshake_byte : handshake_length_bytes.getValue()) { + handshake_length <<= 8; + handshake_length |= handshake_byte; + } + if (handshake_length < cdr.length()) { + return boost::none; + } + StatusWith client_version = cdr.readAndAdvance(); + if (!client_version.isOK()) { + return boost::none; + } + + // Invariant: We read exactly as much data as expected. + invariant(cdr.data() - dataRange.data() == sizeOfTLSFragmentToRead); + + auto isProtocolDisabled = [](SSLParams::Protocols protocol) { + const auto& params = getSSLGlobalParams(); + return std::find(params.sslDisabledProtocols.begin(), + params.sslDisabledProtocols.end(), + protocol) != params.sslDisabledProtocols.end(); + }; + + auto makeTLSProtocolVersionAlert = + [](const std::array& versionBytes) -> std::array { + /** + * The structure for this alert packet is as follows: + * TLSPlaintext { + * ContentType type = alert(21); + * ProtocolVersion = versionBytes; + * uint16_t length = 2 + * fragment = AlertDescription { + * AlertLevel level = fatal(2); + * AlertDescription = protocol_version(70); + * } + * + */ + return std::array{ + 0x15, versionBytes[0], versionBytes[1], 0x00, 0x02, 0x02, 0x46}; + }; + + ProtocolVersion version = client_version.getValue(); + if (version == tls10VersionBytes && isProtocolDisabled(SSLParams::Protocols::TLS1_0)) { + return makeTLSProtocolVersionAlert(version); + } else if (client_version == tls11VersionBytes && + isProtocolDisabled(SSLParams::Protocols::TLS1_1)) { + return makeTLSProtocolVersionAlert(version); + } + // TLS1.2 cannot be distinguished from TLS1.3, just by looking at the ProtocolVersion bytes. + // TLS 1.3 compatible clients advertise a "supported_versions" extension, which we would + // have to extract here. + // Hopefully by the time this matters, OpenSSL will properly emit protocol_version alerts. + + return boost::none; +} + #else MONGO_INITIALIZER(SSLManager)(InitializerContext*) { diff --git a/src/mongo/util/net/ssl_manager.h b/src/mongo/util/net/ssl_manager.h index ef7ad5c403b..bbdafdabff1 100644 --- a/src/mongo/util/net/ssl_manager.h +++ b/src/mongo/util/net/ssl_manager.h @@ -199,5 +199,11 @@ extern bool isSSLServer; * "EndStartupOptionStorage" as a prerequisite. */ const SSLParams& getSSLGlobalParams(); + +/** + * Peeks at a fragment of a client issued TLS handshake packet. Returns a TLS alert + * packet if the client has selected a protocol which has been disabled by the server. + */ +boost::optional> checkTLSRequest(ConstDataRange cdr); } #endif // #ifdef MONGO_CONFIG_SSL -- cgit v1.2.1