summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSpencer Jackson <spencer.jackson@mongodb.com>2018-02-15 15:30:46 -0500
committerSpencer Jackson <spencer.jackson@mongodb.com>2018-05-04 17:45:29 -0400
commitbc99911f0bbe0d0c18d46bd0ad44c0b136a162ff (patch)
tree9e28da8adbdb7f8357080f374618abea934c34c1
parenta5923c25181622e8374c6891770267c9735bc3f1 (diff)
downloadmongo-bc99911f0bbe0d0c18d46bd0ad44c0b136a162ff.tar.gz
SERVER-33329: Make server and shell emit TLS protocol_version alerts
(cherry picked from commit 51af489a86f1862de87b51f26a9e818ec3b5df04) (cherry picked from commit 56e653fdd204e1ad091e0736454aefc005b5ce3f)
-rw-r--r--jstests/ssl/ssl_alert_reporting.js48
-rw-r--r--src/mongo/util/net/message_port.cpp12
-rw-r--r--src/mongo/util/net/ssl_manager.cpp158
-rw-r--r--src/mongo/util/net/ssl_manager.h6
4 files changed, 216 insertions, 8 deletions
diff --git a/jstests/ssl/ssl_alert_reporting.js b/jstests/ssl/ssl_alert_reporting.js
new file mode 100644
index 00000000000..f5ca5650896
--- /dev/null
+++ b/jstests/ssl/ssl_alert_reporting.js
@@ -0,0 +1,48 @@
+// Ensure that TLS version alerts are correctly propagated
+
+load('jstests/ssl/libs/ssl_helpers.js');
+
+(function() {
+ 'use strict';
+
+ const clientOptions = [
+ "--ssl",
+ "--sslPEMKeyFile",
+ "jstests/libs/client.pem",
+ "--sslCAFile",
+ "jstests/libs/ca.pem",
+ "--eval",
+ ";"
+ ];
+
+ function runTest(serverDisabledProtos, clientDisabledProtos) {
+ let expectedRegex = /tlsv1 alert protocol version/;
+
+ var md = MongoRunner.runMongod({
+ nopreallocj: "",
+ sslMode: "requireSSL",
+ sslCAFile: "jstests/libs/ca.pem",
+ sslPEMKeyFile: "jstests/libs/server.pem",
+ sslDisabledProtocols: serverDisabledProtos,
+ waitForConnect: false,
+ });
+
+ assert.soon(function() {
+ clearRawMongoProgramOutput();
+ let shell = runMongoProgram("mongo",
+ "--port",
+ md.port,
+ ...clientOptions,
+ "--sslDisabledProtocols",
+ clientDisabledProtos);
+ let mongoOutput = rawMongoProgramOutput();
+ return mongoOutput.match(expectedRegex);
+ });
+
+ MongoRunner.stopMongod(md);
+ }
+
+ // Client recieves and reports a protocol version alert if it advertises a protocol older than
+ // the server's oldest supported protocol
+ runTest("TLS1_0", "TLS1_1,TLS1_2");
+}());
diff --git a/src/mongo/util/net/message_port.cpp b/src/mongo/util/net/message_port.cpp
index 0f4f543087e..28c81da19ff 100644
--- a/src/mongo/util/net/message_port.cpp
+++ b/src/mongo/util/net/message_port.cpp
@@ -126,6 +126,18 @@ bool MessagingPort::recv(Message& m) {
uassert(17132,
"SSL handshake received but server is started without SSL support",
sslGlobalParams.sslMode.load() != SSLParams::SSLMode_disabled);
+
+ auto tlsAlert = checkTLSRequest(
+ ConstDataRange(reinterpret_cast<const char*>(&header), sizeof(header)));
+
+ if (tlsAlert) {
+ _psock->send(reinterpret_cast<const char*>(tlsAlert->data()),
+ tlsAlert->size(),
+ "tls protocol mismatch");
+ log() << "SSL handshake failed, as client requested disabled protocol";
+ return false;
+ }
+
setX509PeerInfo(
_psock->doSSLHandshake(reinterpret_cast<const char*>(&header), sizeof(header)));
LOG(1) << "new ssl connection, SNI server name [" << _psock->getSNIServerName()
diff --git a/src/mongo/util/net/ssl_manager.cpp b/src/mongo/util/net/ssl_manager.cpp
index 4bb0befb4b0..025f4932867 100644
--- a/src/mongo/util/net/ssl_manager.cpp
+++ b/src/mongo/util/net/ssl_manager.cpp
@@ -53,6 +53,7 @@
#include "mongo/util/exit.h"
#include "mongo/util/log.h"
#include "mongo/util/mongoutils/str.h"
+#include "mongo/util/net/message.h"
#include "mongo/util/net/sock.h"
#include "mongo/util/net/socket_exception.h"
#include "mongo/util/net/ssl_expiration.h"
@@ -327,7 +328,7 @@ private:
* Given an error code from an SSL-type IO function, logs an
* appropriate message and throws a SocketException
*/
- MONGO_COMPILER_NORETURN void _handleSSLError(int code, int ret);
+ MONGO_COMPILER_NORETURN void _handleSSLError(SSLConnection* conn, int ret);
/*
* Init the SSL context using parameters provided in params. This SSL context will
@@ -625,7 +626,7 @@ int SSLManager::SSL_read(SSLConnection* conn, void* buf, int num) {
} while (!_doneWithSSLOp(conn, status));
if (status <= 0)
- _handleSSLError(SSL_get_error(conn, status), status);
+ _handleSSLError(conn, status);
return status;
}
@@ -636,7 +637,7 @@ int SSLManager::SSL_write(SSLConnection* conn, const void* buf, int num) {
} while (!_doneWithSSLOp(conn, status));
if (status <= 0)
- _handleSSLError(SSL_get_error(conn, status), status);
+ _handleSSLError(conn, status);
return status;
}
@@ -659,7 +660,7 @@ int SSLManager::SSL_shutdown(SSLConnection* conn) {
} while (!_doneWithSSLOp(conn, status));
if (status < 0)
- _handleSSLError(SSL_get_error(conn, status), status);
+ _handleSSLError(conn, status);
return status;
}
@@ -1166,14 +1167,14 @@ SSLConnection* SSLManager::connect(Socket* socket) {
int ret = ::SSL_set_tlsext_host_name(sslConn->ssl, socket->remoteAddr().hostOrIp().c_str());
if (ret != 1)
- _handleSSLError(SSL_get_error(sslConn.get(), ret), ret);
+ _handleSSLError(sslConn.get(), ret);
do {
ret = ::SSL_connect(sslConn->ssl);
} while (!_doneWithSSLOp(sslConn.get(), ret));
if (ret != 1)
- _handleSSLError(SSL_get_error(sslConn.get(), ret), ret);
+ _handleSSLError(sslConn.get(), ret);
return sslConn.release();
}
@@ -1188,7 +1189,7 @@ SSLConnection* SSLManager::accept(Socket* socket, const char* initialBytes, int
} while (!_doneWithSSLOp(sslConn.get(), ret));
if (ret != 1)
- _handleSSLError(SSL_get_error(sslConn.get(), ret), ret);
+ _handleSSLError(sslConn.get(), ret);
return sslConn.release();
}
@@ -1459,7 +1460,8 @@ std::string SSLManagerInterface::getSSLErrorMessage(int code) {
return msg;
}
-void SSLManager::_handleSSLError(int code, int ret) {
+void SSLManager::_handleSSLError(SSLConnection* conn, int ret) {
+ int code = SSL_get_error(conn, ret);
int err = ERR_get_error();
switch (code) {
@@ -1496,8 +1498,148 @@ void SSLManager::_handleSSLError(int code, int ret) {
error() << "unrecognized SSL error";
break;
}
+ _flushNetworkBIO(conn);
throw SocketException(SocketException::CONNECT_ERROR, "");
}
+
+boost::optional<std::array<std::uint8_t, 7>> checkTLSRequest(ConstDataRange dataRange) {
+ // This method's caller should have read in at least one MSGHEADER::Value's worth of data.
+ // The fragment we are about to examine must be strictly smaller.
+ static const size_t sizeOfTLSFragmentToRead = 11;
+ invariant(dataRange.length() >= sizeOfTLSFragmentToRead);
+
+ static_assert(sizeOfTLSFragmentToRead < sizeof(MSGHEADER::Value),
+ "checkTLSRequest's caller read a MSGHEADER::Value, which must be larger than "
+ "message containing the TLS version");
+
+ ConstDataRangeCursor cdr(dataRange);
+
+ /**
+ * The fragment we are to examine is a record, containing a handshake, containing a
+ * ClientHello. We wish to examine the advertised protocol version in the ClientHello.
+ * The following roughly describes the contents of these structures. Note that we do not
+ * need, or wish to, examine the entire ClientHello, we're looking exclusively for the
+ * client_version.
+ *
+ * Below is a rough description of the payload we will be examining. We shall perform some
+ * basic checks to ensure the payload matches these expectations. If it does not, we should
+ * bail out, and not emit protocol version alerts.
+ *
+ * enum {alert(21), handshake(22)} ContentType;
+ * TLSPlaintext {
+ * ContentType type = handshake(22),
+ * ProtocolVersion version; // Irrelevant. Clients send the real version in ClientHello.
+ * uint16 length;
+ * fragment, see Handshake stuct for contents
+ * ...
+ * }
+ *
+ * enum {client_hello(1)} HandshakeType;
+ * Handshake {
+ * HandshakeType msg_type = client_hello(1);
+ * uint24_t length;
+ * ClientHello body;
+ * }
+ *
+ * ClientHello {
+ * ProtocolVersion client_version; // <- This is the value we want to extract.
+ * }
+ */
+
+ static const std::uint8_t ContentType_handshake = 22;
+ static const std::uint8_t HandshakeType_client_hello = 1;
+
+ using ProtocolVersion = std::array<std::uint8_t, 2>;
+ static const ProtocolVersion tls10VersionBytes{3, 1};
+ static const ProtocolVersion tls11VersionBytes{3, 2};
+
+ // Parse the record header.
+ // Extract the ContentType from the header, and ensure it is a handshake.
+ StatusWith<std::uint8_t> record_ContentType = cdr.readAndAdvance<std::uint8_t>();
+ if (!record_ContentType.isOK() || record_ContentType.getValue() != ContentType_handshake) {
+ return boost::none;
+ }
+ // Skip the record's ProtocolVersion. Clients tend to send TLS 1.0 in
+ // the record, but then their real protocol version in the enclosed ClientHello.
+ StatusWith<ProtocolVersion> record_protocol_version = cdr.readAndAdvance<ProtocolVersion>();
+ if (!record_protocol_version.isOK()) {
+ return boost::none;
+ }
+ // Parse the record length. It should be be larger than the remaining expected payload.
+ auto record_length = cdr.readAndAdvance<BigEndian<std::uint16_t>>();
+ if (!record_length.isOK() || record_length.getValue() < cdr.length()) {
+ return boost::none;
+ }
+
+ // Parse the handshake header.
+ // Extract the HandshakeType, and ensure it is a ClientHello.
+ StatusWith<std::uint8_t> handshake_type = cdr.readAndAdvance<std::uint8_t>();
+ if (!handshake_type.isOK() || handshake_type.getValue() != HandshakeType_client_hello) {
+ return boost::none;
+ }
+ // Extract the handshake length, and ensure it is larger than the remaining expected
+ // payload. This requires a little work because the packet represents it with a uint24_t.
+ StatusWith<std::array<std::uint8_t, 3>> handshake_length_bytes =
+ cdr.readAndAdvance<std::array<std::uint8_t, 3>>();
+ if (!handshake_length_bytes.isOK()) {
+ return boost::none;
+ }
+ std::uint32_t handshake_length = 0;
+ for (std::uint8_t handshake_byte : handshake_length_bytes.getValue()) {
+ handshake_length <<= 8;
+ handshake_length |= handshake_byte;
+ }
+ if (handshake_length < cdr.length()) {
+ return boost::none;
+ }
+ StatusWith<ProtocolVersion> client_version = cdr.readAndAdvance<ProtocolVersion>();
+ if (!client_version.isOK()) {
+ return boost::none;
+ }
+
+ // Invariant: We read exactly as much data as expected.
+ invariant(cdr.data() - dataRange.data() == sizeOfTLSFragmentToRead);
+
+ auto isProtocolDisabled = [](SSLParams::Protocols protocol) {
+ const auto& params = getSSLGlobalParams();
+ return std::find(params.sslDisabledProtocols.begin(),
+ params.sslDisabledProtocols.end(),
+ protocol) != params.sslDisabledProtocols.end();
+ };
+
+ auto makeTLSProtocolVersionAlert =
+ [](const std::array<std::uint8_t, 2>& versionBytes) -> std::array<std::uint8_t, 7> {
+ /**
+ * The structure for this alert packet is as follows:
+ * TLSPlaintext {
+ * ContentType type = alert(21);
+ * ProtocolVersion = versionBytes;
+ * uint16_t length = 2
+ * fragment = AlertDescription {
+ * AlertLevel level = fatal(2);
+ * AlertDescription = protocol_version(70);
+ * }
+ *
+ */
+ return std::array<std::uint8_t, 7>{
+ 0x15, versionBytes[0], versionBytes[1], 0x00, 0x02, 0x02, 0x46};
+ };
+
+ ProtocolVersion version = client_version.getValue();
+ if (version == tls10VersionBytes && isProtocolDisabled(SSLParams::Protocols::TLS1_0)) {
+ return makeTLSProtocolVersionAlert(version);
+ } else if (client_version == tls11VersionBytes &&
+ isProtocolDisabled(SSLParams::Protocols::TLS1_1)) {
+ return makeTLSProtocolVersionAlert(version);
+ }
+ // TLS1.2 cannot be distinguished from TLS1.3, just by looking at the ProtocolVersion bytes.
+ // TLS 1.3 compatible clients advertise a "supported_versions" extension, which we would
+ // have to extract here.
+ // Hopefully by the time this matters, OpenSSL will properly emit protocol_version alerts.
+
+ return boost::none;
+}
+
#else
MONGO_INITIALIZER(SSLManager)(InitializerContext*) {
diff --git a/src/mongo/util/net/ssl_manager.h b/src/mongo/util/net/ssl_manager.h
index ef7ad5c403b..bbdafdabff1 100644
--- a/src/mongo/util/net/ssl_manager.h
+++ b/src/mongo/util/net/ssl_manager.h
@@ -199,5 +199,11 @@ extern bool isSSLServer;
* "EndStartupOptionStorage" as a prerequisite.
*/
const SSLParams& getSSLGlobalParams();
+
+/**
+ * Peeks at a fragment of a client issued TLS handshake packet. Returns a TLS alert
+ * packet if the client has selected a protocol which has been disabled by the server.
+ */
+boost::optional<std::array<std::uint8_t, 7>> checkTLSRequest(ConstDataRange cdr);
}
#endif // #ifdef MONGO_CONFIG_SSL