summaryrefslogtreecommitdiff
path: root/src/mongo/util/net/ssl_manager_windows.cpp
diff options
context:
space:
mode:
authorMark Benvenuto <mark.benvenuto@mongodb.com>2018-03-08 13:52:35 -0500
committerMark Benvenuto <mark.benvenuto@mongodb.com>2018-03-08 13:52:35 -0500
commite7dedcc3b42e3a601bd7080743d1e3c1d10e3cfa (patch)
tree52d41f3019672765b24152d87897c539f697af9f /src/mongo/util/net/ssl_manager_windows.cpp
parent8522318f0a92ce96b8d90e5bd5d52771b2d63c04 (diff)
downloadmongo-e7dedcc3b42e3a601bd7080743d1e3c1d10e3cfa.tar.gz
SERVER-22411 Refactor Windows Certificate PEM file loading
Diffstat (limited to 'src/mongo/util/net/ssl_manager_windows.cpp')
-rw-r--r--src/mongo/util/net/ssl_manager_windows.cpp284
1 files changed, 176 insertions, 108 deletions
diff --git a/src/mongo/util/net/ssl_manager_windows.cpp b/src/mongo/util/net/ssl_manager_windows.cpp
index 10c48d7c325..1f123cef0c7 100644
--- a/src/mongo/util/net/ssl_manager_windows.cpp
+++ b/src/mongo/util/net/ssl_manager_windows.cpp
@@ -33,13 +33,10 @@
#include "mongo/util/net/ssl_manager.h"
#include <asio.hpp>
-#include <boost/algorithm/string.hpp>
-#include <boost/date_time/posix_time/posix_time.hpp>
+#include <boost/algorithm/string/replace.hpp>
#include <fstream>
-#include <iostream>
-#include <sstream>
-#include <stack>
#include <string>
+#include <tuple>
#include <vector>
#include "mongo/base/init.h"
@@ -48,9 +45,7 @@
#include "mongo/config.h"
#include "mongo/db/server_options.h"
#include "mongo/db/server_parameters.h"
-#include "mongo/platform/atomic_word.h"
#include "mongo/stdx/memory.h"
-#include "mongo/transport/session.h"
#include "mongo/util/concurrency/mutex.h"
#include "mongo/util/debug_util.h"
#include "mongo/util/exit.h"
@@ -62,7 +57,6 @@
#include "mongo/util/net/ssl.hpp"
#include "mongo/util/net/ssl_options.h"
#include "mongo/util/net/ssl_types.h"
-#include "mongo/util/scopeguard.h"
#include "mongo/util/text.h"
#include "mongo/util/uuid.h"
@@ -158,6 +152,13 @@ struct CryptKeyFree {
using UniqueCryptKey = AutoHandle<HCRYPTKEY, CryptKeyFree>;
+/**
+ * The lifetime of a private key of a certificate loaded from a PEM is bound to the CryptContext's
+ * lifetime
+ * so we treat the certificate and cryptcontext as a pair.
+ */
+using UniqueCertificateWithPrivateKey = std::tuple<UniqueCertificate, UniqueCryptProvider>;
+
} // namespace
/**
@@ -216,6 +217,11 @@ public:
int SSL_shutdown(SSLConnectionInterface* conn) final;
private:
+ Status _loadCertificates(const SSLParams& params);
+
+ void _handshake(SSLConnectionWindows* conn, bool client);
+
+private:
bool _weakValidation;
bool _allowInvalidCertificates;
bool _allowInvalidHostnames;
@@ -224,14 +230,10 @@ private:
SCHANNEL_CRED _clientCred;
SCHANNEL_CRED _serverCred;
- UniqueCertificate _pemCertificate;
- UniqueCertificate _clusterPEMCertificate;
- PCCERT_CONTEXT _clientCertificates[1];
- PCCERT_CONTEXT _serverCertificates[1];
-
- Status loadCertificates(const SSLParams& params);
-
- void handshake(SSLConnectionWindows* conn, bool client);
+ UniqueCertificateWithPrivateKey _pemCertificate;
+ UniqueCertificateWithPrivateKey _clusterPEMCertificate;
+ std::array<PCCERT_CONTEXT, 1> _clientCertificates;
+ std::array<PCCERT_CONTEXT, 1> _serverCertificates;
};
// Global variable indicating if this is a server or a client instance
@@ -278,7 +280,7 @@ SSLManagerWindows::SSLManagerWindows(const SSLParams& params, bool isServer)
_allowInvalidCertificates(params.sslAllowInvalidCertificates),
_allowInvalidHostnames(params.sslAllowInvalidHostnames) {
- uassertStatusOK(loadCertificates(params));
+ uassertStatusOK(_loadCertificates(params));
uassertStatusOK(initSSLContext(&_clientCred, params, ConnectionDirection::kOutgoing));
@@ -380,8 +382,7 @@ int SSLManagerWindows::SSL_shutdown(SSLConnectionInterface* conn) {
return 0;
}
-StatusWith<UniqueCertificate> readPEMFile(StringData fileName, StringData password) {
-
+StatusWith<std::string> readFile(StringData fileName) {
std::ifstream pemFile(fileName.toString(), std::ios::binary);
if (!pemFile.is_open()) {
return Status(ErrorCodes::InvalidSSLConfiguration,
@@ -392,60 +393,37 @@ StatusWith<UniqueCertificate> readPEMFile(StringData fileName, StringData passwo
pemFile.close();
- // Search the buffer for the various strings that make up a PEM file
- size_t publicKey = buf.find("-----BEGIN CERTIFICATE-----");
- if (publicKey == std::string::npos) {
- return Status(ErrorCodes::InvalidSSLConfiguration,
- str::stream() << "Failed to find Certifiate in: " << fileName);
- }
+ return buf;
+}
- // TODO: decode encrypted pem
- // StringData encryptedPrivateKey = buf.find("-----BEGIN ENCRYPTED PRIVATE KEY-----");
+// Find a specific kind of PEM blob marked by BEGIN and END in a string
+StatusWith<StringData> findPEMBlob(StringData blob, StringData type, size_t position = 0) {
+ std::string header = str::stream() << "-----BEGIN " << type << "-----";
+ std::string trailer = str::stream() << "-----END " << type << "-----";
- // TODO: check if we need both
- size_t privateKey = buf.find("-----BEGIN RSA PRIVATE KEY-----");
- if (privateKey == std::string::npos) {
- privateKey = buf.find("-----BEGIN PRIVATE KEY-----");
+ size_t headerPosition = blob.find(header, position);
+ if (headerPosition == std::string::npos) {
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "Failed to find PEM blob header: " << header);
}
- if (privateKey == std::string::npos) {
+ size_t trailerPosition = blob.find(trailer, headerPosition);
+ if (trailerPosition == std::string::npos) {
return Status(ErrorCodes::InvalidSSLConfiguration,
- str::stream() << "Failed to find privateKey in: " << fileName);
- }
-
- CERT_BLOB certBlob;
- certBlob.cbData = buf.size() - publicKey;
- certBlob.pbData = reinterpret_cast<BYTE*>(const_cast<char*>(buf.data() + publicKey));
-
- PCCERT_CONTEXT cert;
- BOOL ret = CryptQueryObject(CERT_QUERY_OBJECT_BLOB,
- &certBlob,
- CERT_QUERY_CONTENT_FLAG_ALL,
- CERT_QUERY_FORMAT_FLAG_ALL,
- NULL,
- NULL,
- NULL,
- NULL,
- NULL,
- NULL,
- reinterpret_cast<const void**>(&cert));
- if (!ret) {
- DWORD gle = GetLastError();
- return Status(ErrorCodes::InvalidSSLConfiguration,
- str::stream() << "CryptQueryObject failed to get cert: "
- << errnoWithDescription(gle));
+ str::stream() << "Failed to find PEM blob trailer: " << trailer);
}
- UniqueCertificate certHolder(cert);
- DWORD privateKeyLen{0};
+ trailerPosition += trailer.size();
- ret = CryptStringToBinaryA(buf.c_str() + privateKey,
- 0, // null terminated string
- CRYPT_STRING_BASE64HEADER | CRYPT_STRING_STRICT,
- NULL,
- &privateKeyLen,
- NULL,
- NULL);
+ return StringData(blob.rawData() + headerPosition, trailerPosition - headerPosition);
+}
+
+// Decode a base-64 PEM blob with headers into a binary blob
+StatusWith<std::vector<BYTE>> decodePEMBlob(StringData blob) {
+ DWORD decodeLen{0};
+
+ BOOL ret = CryptStringToBinaryA(
+ blob.rawData(), blob.size(), CRYPT_STRING_BASE64HEADER, NULL, &decodeLen, NULL, NULL);
if (!ret) {
DWORD gle = GetLastError();
if (gle != ERROR_MORE_DATA) {
@@ -455,12 +433,14 @@ StatusWith<UniqueCertificate> readPEMFile(StringData fileName, StringData passwo
}
}
- std::unique_ptr<BYTE[]> privateKeyBuf = std::make_unique<BYTE[]>(privateKeyLen);
- ret = CryptStringToBinaryA(buf.c_str() + privateKey,
- 0, // null terminated string
- CRYPT_STRING_BASE64HEADER | CRYPT_STRING_STRICT,
- privateKeyBuf.get(),
- &privateKeyLen,
+ std::vector<BYTE> binaryBlobBuf;
+ binaryBlobBuf.resize(decodeLen);
+
+ ret = CryptStringToBinaryA(blob.rawData(),
+ blob.size(),
+ CRYPT_STRING_BASE64HEADER,
+ binaryBlobBuf.data(),
+ &decodeLen,
NULL,
NULL);
if (!ret) {
@@ -470,45 +450,131 @@ StatusWith<UniqueCertificate> readPEMFile(StringData fileName, StringData passwo
<< errnoWithDescription(gle));
}
+ return std::move(binaryBlobBuf);
+}
- DWORD privateBlobLen{0};
+StatusWith<std::vector<BYTE>> decodeObject(const char* structType,
+ const BYTE* data,
+ size_t length) {
+ DWORD decodeLen{0};
- ret = CryptDecodeObjectEx(X509_ASN_ENCODING,
- PKCS_RSA_PRIVATE_KEY,
- privateKeyBuf.get(),
- privateKeyLen,
- CRYPT_DECODE_SHARE_OID_STRING_FLAG,
- NULL,
- NULL,
- &privateBlobLen);
+ BOOL ret =
+ CryptDecodeObjectEx(X509_ASN_ENCODING, structType, data, length, 0, NULL, NULL, &decodeLen);
if (!ret) {
DWORD gle = GetLastError();
if (gle != ERROR_MORE_DATA) {
return Status(ErrorCodes::InvalidSSLConfiguration,
- str::stream() << "CryptDecodeObjectEx failed to get size of key: "
+ str::stream() << "CryptDecodeObjectEx failed to get size of object: "
<< errnoWithDescription(gle));
}
}
- std::unique_ptr<BYTE[]> privateBlobBuf = std::make_unique<BYTE[]>(privateBlobLen);
+ std::vector<BYTE> binaryBlobBuf;
+ binaryBlobBuf.resize(decodeLen);
- ret = CryptDecodeObjectEx(X509_ASN_ENCODING,
- PKCS_RSA_PRIVATE_KEY,
- privateKeyBuf.get(),
- privateKeyLen,
- CRYPT_DECODE_SHARE_OID_STRING_FLAG,
- NULL,
- privateBlobBuf.get(),
- &privateBlobLen);
+ ret = CryptDecodeObjectEx(
+ X509_ASN_ENCODING, structType, data, length, 0, NULL, binaryBlobBuf.data(), &decodeLen);
if (!ret) {
DWORD gle = GetLastError();
return Status(ErrorCodes::InvalidSSLConfiguration,
- str::stream() << "CryptDecodeObjectEx failed to read key: "
+ str::stream() << "CryptDecodeObjectEx failed to read object: "
+ << errnoWithDescription(gle));
+ }
+
+ return std::move(binaryBlobBuf);
+}
+
+// Read a Certificate PEM file with a private key from disk
+StatusWith<UniqueCertificateWithPrivateKey> readCertPEMFile(StringData fileName,
+ StringData password) {
+ auto swBuf = readFile(fileName);
+ if (!swBuf.isOK()) {
+ return swBuf.getStatus();
+ }
+
+ std::string buf = std::move(swBuf.getValue());
+
+ size_t encryptedPrivateKey = buf.find("-----BEGIN ENCRYPTED PRIVATE KEY-----");
+ if (encryptedPrivateKey != std::string::npos) {
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "Encrypted private keys are not supported, use the Windows "
+ "certificate store instead: "
+ << fileName);
+ }
+
+ // Search the buffer for the various strings that make up a PEM file
+ auto swPublicKeyBlob = findPEMBlob(buf, "CERTIFICATE"_sd);
+ if (!swPublicKeyBlob.isOK()) {
+ return swPublicKeyBlob.getStatus();
+ }
+
+ auto publicKeyBlob = swPublicKeyBlob.getValue();
+
+ // Multiple certificates in a PEM file are not supported since these certs need to be in the ca
+ // file.
+ auto secondPublicKeyBlobPosition =
+ buf.find("CERTIFICATE", (publicKeyBlob.rawData() + publicKeyBlob.size()) - buf.data());
+ if (secondPublicKeyBlobPosition != std::string::npos) {
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "Certificate PEM files should only have one certificate, "
+ "intermediate CA certificates belong in the CA file.");
+ }
+
+ // PEM files can have either private key format
+ // Also the private key can either come before or after the certificate
+ auto swPrivateKeyBlob = findPEMBlob(buf, "RSA PRIVATE KEY"_sd);
+ // We expect to find at least one certificate
+ if (!swPrivateKeyBlob.isOK()) {
+ // A "PRIVATE KEY" is actually a PKCS #8 PrivateKeyInfo ASN.1 type. We do not support it for
+ // now so tell the user how to fix it.
+ // Warn user rsa -in roles.key -out roles2.key
+ swPrivateKeyBlob = findPEMBlob(buf, "PRIVATE KEY"_sd);
+ if (!swPrivateKeyBlob.isOK()) {
+ return swPrivateKeyBlob.getStatus();
+ } else {
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "Expected to find 'RSA PRIVATE KEY' in PEM file, found "
+ "'PRIVATE KEY' instead.");
+ }
+ }
+
+ auto privateKeyBlob = swPrivateKeyBlob.getValue();
+
+ auto swCert = decodePEMBlob(publicKeyBlob);
+ if (!swCert.isOK()) {
+ return swCert.getStatus();
+ }
+
+ auto certBuf = swCert.getValue();
+
+ PCCERT_CONTEXT cert =
+ CertCreateCertificateContext(X509_ASN_ENCODING, certBuf.data(), certBuf.size());
+
+ if (cert == NULL) {
+ DWORD gle = GetLastError();
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "CertCreateCertificateContext failed to decode cert: "
<< errnoWithDescription(gle));
}
+ UniqueCertificate certHolder(cert);
+
+ auto swPrivateKeyBuf = decodePEMBlob(privateKeyBlob);
+ if (!swPrivateKeyBuf.isOK()) {
+ return swPrivateKeyBuf.getStatus();
+ }
+
+ auto privateKeyBuf = swPrivateKeyBuf.getValue();
+
+ auto swPrivateKey =
+ decodeObject(PKCS_RSA_PRIVATE_KEY, privateKeyBuf.data(), privateKeyBuf.size());
+ if (!swPrivateKey.isOK()) {
+ return swPrivateKey.getStatus();
+ }
+
HCRYPTPROV hProv;
std::wstring wstr;
+ BOOL ret;
// Create the right Crypto context depending on whether we running in a server or outside.
// See https://msdn.microsoft.com/en-us/library/windows/desktop/aa375195(v=vs.85).aspx
@@ -562,13 +628,14 @@ StatusWith<UniqueCertificate> readPEMFile(StringData fileName, StringData passwo
UniqueCryptProvider cryptProvider(hProv);
HCRYPTKEY hkey;
- ret = CryptImportKey(hProv, privateBlobBuf.get(), privateBlobLen, 0, 0, &hkey);
+ ret = CryptImportKey(
+ hProv, swPrivateKey.getValue().data(), swPrivateKey.getValue().size(), 0, 0, &hkey);
if (!ret) {
DWORD gle = GetLastError();
return Status(ErrorCodes::InvalidSSLConfiguration,
str::stream() << "CryptImportKey failed " << errnoWithDescription(gle));
}
- UniqueCryptKey(hKey);
+ UniqueCryptKey keyHolder(hkey);
if (isSSLServer) {
// Server-side SChannel requires a different way of attaching the private key to the
@@ -600,16 +667,17 @@ StatusWith<UniqueCertificate> readPEMFile(StringData fileName, StringData passwo
<< errnoWithDescription(gle));
}
- return std::move(certHolder);
+ return std::move(
+ UniqueCertificateWithPrivateKey(std::move(certHolder), std::move(cryptProvider)));
}
-Status SSLManagerWindows::loadCertificates(const SSLParams& params) {
+Status SSLManagerWindows::_loadCertificates(const SSLParams& params) {
_clientCertificates[0] = nullptr;
_serverCertificates[0] = nullptr;
// Load the normal PEM file
if (!params.sslPEMKeyFile.empty()) {
- auto swCertificate = readPEMFile(params.sslPEMKeyFile, params.sslPEMKeyPassword);
+ auto swCertificate = readCertPEMFile(params.sslPEMKeyFile, params.sslPEMKeyPassword);
if (!swCertificate.isOK()) {
return swCertificate.getStatus();
}
@@ -619,7 +687,7 @@ Status SSLManagerWindows::loadCertificates(const SSLParams& params) {
// Load the cluster PEM file, only applies to server side code
if (!params.sslClusterFile.empty()) {
- auto swCertificate = readPEMFile(params.sslClusterFile, params.sslClusterPassword);
+ auto swCertificate = readCertPEMFile(params.sslClusterFile, params.sslClusterPassword);
if (!swCertificate.isOK()) {
return swCertificate.getStatus();
}
@@ -627,13 +695,13 @@ Status SSLManagerWindows::loadCertificates(const SSLParams& params) {
_clusterPEMCertificate = std::move(swCertificate.getValue());
}
- if (_pemCertificate) {
- _clientCertificates[0] = _pemCertificate.get();
- _serverCertificates[0] = _pemCertificate.get();
+ if (std::get<0>(_pemCertificate)) {
+ _clientCertificates[0] = std::get<0>(_pemCertificate).get();
+ _serverCertificates[0] = std::get<0>(_pemCertificate).get();
}
- if (_clusterPEMCertificate) {
- _clientCertificates[0] = _clusterPEMCertificate.get();
+ if (std::get<0>(_clusterPEMCertificate)) {
+ _clientCertificates[0] = std::get<0>(_clusterPEMCertificate).get();
}
return Status::OK();
@@ -653,7 +721,7 @@ Status SSLManagerWindows::initSSLContext(SCHANNEL_CRED* cred,
supportedProtocols = SP_PROT_TLS1_SERVER | SP_PROT_TLS1_0_SERVER | SP_PROT_TLS1_1_SERVER |
SP_PROT_TLS1_2_SERVER;
- cred->dwFlags = cred->dwFlags // Flags
+ cred->dwFlags = cred->dwFlags // flags
| SCH_CRED_REVOCATION_CHECK_CHAIN // Check certificate revocation
| SCH_CRED_SNI_CREDENTIAL // Pass along SNI creds
| SCH_CRED_SNI_ENABLE_OCSP // Enable OCSP
@@ -691,11 +759,11 @@ Status SSLManagerWindows::initSSLContext(SCHANNEL_CRED* cred,
if (direction == ConnectionDirection::kOutgoing) {
if (_clientCertificates[0]) {
cred->cCreds = 1;
- cred->paCred = _clientCertificates;
+ cred->paCred = _clientCertificates.data();
}
} else {
cred->cCreds = 1;
- cred->paCred = _serverCertificates;
+ cred->paCred = _serverCertificates.data();
}
return Status::OK();
@@ -705,7 +773,7 @@ SSLConnectionInterface* SSLManagerWindows::connect(Socket* socket) {
std::unique_ptr<SSLConnectionWindows> sslConn =
stdx::make_unique<SSLConnectionWindows>(&_clientCred, socket, nullptr, 0);
- handshake(sslConn.get(), true);
+ _handshake(sslConn.get(), true);
return sslConn.release();
}
@@ -715,12 +783,12 @@ SSLConnectionInterface* SSLManagerWindows::accept(Socket* socket,
std::unique_ptr<SSLConnectionWindows> sslConn =
stdx::make_unique<SSLConnectionWindows>(&_serverCred, socket, initialBytes, len);
- handshake(sslConn.get(), false);
+ _handshake(sslConn.get(), false);
return sslConn.release();
}
-void SSLManagerWindows::handshake(SSLConnectionWindows* conn, bool client) {
+void SSLManagerWindows::_handshake(SSLConnectionWindows* conn, bool client) {
initSSLContext(conn->_cred,
getSSLGlobalParams(),
client ? SSLManagerInterface::ConnectionDirection::kOutgoing