summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Benvenuto <mark.benvenuto@mongodb.com>2018-02-27 14:10:06 -0500
committerMark Benvenuto <mark.benvenuto@mongodb.com>2018-02-27 14:10:06 -0500
commitf2a8d9f2350f8cd5122cf3394b6783e85da5c390 (patch)
tree4782eb9f741bebb3954cd429e36b5b66f6f8d747
parentd4ae81a7154ab57a266b38d4fe41dd12a3c4540a (diff)
downloadmongo-f2a8d9f2350f8cd5122cf3394b6783e85da5c390.tar.gz
SERVER-22411 ASIO SChannel stream implementation
-rw-r--r--src/mongo/util/net/ssl/context_schannel.hpp19
-rw-r--r--src/mongo/util/net/ssl/detail/engine_schannel.hpp181
-rw-r--r--src/mongo/util/net/ssl/detail/impl/engine_schannel.ipp150
-rw-r--r--src/mongo/util/net/ssl/detail/impl/schannel.ipp747
-rw-r--r--src/mongo/util/net/ssl/detail/schannel.hpp600
-rw-r--r--src/mongo/util/net/ssl/impl/context_schannel.ipp26
-rw-r--r--src/mongo/util/net/ssl/impl/error.ipp5
-rw-r--r--src/mongo/util/net/ssl/impl/src.hpp1
-rw-r--r--src/mongo/util/net/ssl_manager_windows.cpp654
-rw-r--r--src/mongo/util/net/ssl_stream.cpp2
10 files changed, 2228 insertions, 157 deletions
diff --git a/src/mongo/util/net/ssl/context_schannel.hpp b/src/mongo/util/net/ssl/context_schannel.hpp
index 12293cbc8e9..043d8cfe4ca 100644
--- a/src/mongo/util/net/ssl/context_schannel.hpp
+++ b/src/mongo/util/net/ssl/context_schannel.hpp
@@ -30,20 +30,17 @@
#include "asio/detail/config.hpp"
-#include <string>
#include "asio/buffer.hpp"
#include "asio/io_context.hpp"
#include "mongo/util/net/ssl/context_base.hpp"
+#include <string>
#include "asio/detail/push_options.hpp"
namespace asio {
namespace ssl {
-class context
- : public context_base,
- private noncopyable
-{
+class context : public context_base, private noncopyable {
public:
/// The native handle type of the SSL context.
typedef SCHANNEL_CRED* native_handle_type;
@@ -77,7 +74,7 @@ public:
* @li As a target for move-assignment.
*/
ASIO_DECL context& operator=(context&& other);
-#endif // defined(ASIO_HAS_MOVE) || defined(GENERATING_DOCUMENTATION)
+#endif // defined(ASIO_HAS_MOVE) || defined(GENERATING_DOCUMENTATION)
/// Destructor.
ASIO_DECL ~context();
@@ -91,6 +88,8 @@ public:
ASIO_DECL native_handle_type native_handle();
private:
+ SCHANNEL_CRED _cred;
+
// The underlying native implementation.
native_handle_type handle_;
};
@@ -98,8 +97,8 @@ private:
#include "asio/detail/pop_options.hpp"
#if defined(ASIO_HEADER_ONLY)
-# include "mongo/util/net/ssl/impl/context_schannel.ipp"
-#endif // defined(ASIO_HEADER_ONLY)
+#include "mongo/util/net/ssl/impl/context_schannel.ipp"
+#endif // defined(ASIO_HEADER_ONLY)
-} // namespace ssl
-} // namespace asio
+} // namespace ssl
+} // namespace asio
diff --git a/src/mongo/util/net/ssl/detail/engine_schannel.hpp b/src/mongo/util/net/ssl/detail/engine_schannel.hpp
index fa3bad1a7af..f56b71ff8c0 100644
--- a/src/mongo/util/net/ssl/detail/engine_schannel.hpp
+++ b/src/mongo/util/net/ssl/detail/engine_schannel.hpp
@@ -32,6 +32,7 @@
#include "asio/buffer.hpp"
#include "asio/detail/static_mutex.hpp"
+#include "mongo/util/net/ssl/detail/schannel.hpp"
#include "mongo/util/net/ssl/stream_base.hpp"
#include "asio/detail/push_options.hpp"
@@ -40,81 +41,125 @@ namespace asio {
namespace ssl {
namespace detail {
-class engine
-{
+class engine {
public:
- enum want
- {
- // Returned by functions to indicate that the engine wants input. The input
- // buffer should be updated to point to the data. The engine then needs to
- // be called again to retry the operation.
- want_input_and_retry = -2,
-
- // Returned by functions to indicate that the engine wants to write output.
- // The output buffer points to the data to be written. The engine then
- // needs to be called again to retry the operation.
- want_output_and_retry = -1,
-
- // Returned by functions to indicate that the engine doesn't need input or
- // output.
- want_nothing = 0,
-
- // Returned by functions to indicate that the engine wants to write output.
- // The output buffer points to the data to be written. After that the
- // operation is complete, and the engine does not need to be called again.
- want_output = 1
- };
-
- // Construct a new engine for the specified context.
- ASIO_DECL explicit engine(SCHANNEL_CRED* context);
-
- // Destructor.
- ASIO_DECL ~engine();
-
- // Get the underlying implementation in the native type.
- ASIO_DECL PCtxtHandle native_handle();
-
- // Perform an SSL handshake using either SSL_connect (client-side) or
- // SSL_accept (server-side).
- ASIO_DECL want handshake(
- stream_base::handshake_type type, asio::error_code& ec);
-
- // Perform a graceful shutdown of the SSL session.
- ASIO_DECL want shutdown(asio::error_code& ec);
-
- // Write bytes to the SSL session.
- ASIO_DECL want write(const asio::const_buffer& data,
- asio::error_code& ec, std::size_t& bytes_transferred);
-
- // Read bytes from the SSL session.
- ASIO_DECL want read(const asio::mutable_buffer& data,
- asio::error_code& ec, std::size_t& bytes_transferred);
-
- // Get output data to be written to the transport.
- ASIO_DECL asio::mutable_buffer get_output(
- const asio::mutable_buffer& data);
-
- // Put input data that was read from the transport.
- ASIO_DECL asio::const_buffer put_input(
- const asio::const_buffer& data);
-
- // Map an error::eof code returned by the underlying transport according to
- // the type and state of the SSL session. Returns a const reference to the
- // error code object, suitable for passing to a completion handler.
- ASIO_DECL const asio::error_code& map_error_code(
- asio::error_code& ec) const;
+ enum want {
+ // Returned by functions to indicate that the engine wants input. The input
+ // buffer should be updated to point to the data. The engine then needs to
+ // be called again to retry the operation.
+ want_input_and_retry = -2,
+
+ // Returned by functions to indicate that the engine wants to write output.
+ // The output buffer points to the data to be written. The engine then
+ // needs to be called again to retry the operation.
+ want_output_and_retry = -1,
+
+ // Returned by functions to indicate that the engine doesn't need input or
+ // output.
+ want_nothing = 0,
+
+ // Returned by functions to indicate that the engine wants to write output.
+ // The output buffer points to the data to be written. After that the
+ // operation is complete, and the engine does not need to be called again.
+ want_output = 1
+ };
+
+ // Construct a new engine for the specified context.
+ ASIO_DECL explicit engine(SCHANNEL_CRED* context);
+
+ // Destructor.
+ ASIO_DECL ~engine();
+
+ // Get the underlying implementation in the native type.
+ ASIO_DECL PCtxtHandle native_handle();
+
+ // Perform an SSL handshake using either SSL_connect (client-side) or
+ // SSL_accept (server-side).
+ ASIO_DECL want handshake(stream_base::handshake_type type, asio::error_code& ec);
+
+ // Perform a graceful shutdown of the SSL session.
+ ASIO_DECL want shutdown(asio::error_code& ec);
+
+ // Write bytes to the SSL session.
+ ASIO_DECL want write(const asio::const_buffer& data,
+ asio::error_code& ec,
+ std::size_t& bytes_transferred);
+
+ // Read bytes from the SSL session.
+ ASIO_DECL want read(const asio::mutable_buffer& data,
+ asio::error_code& ec,
+ std::size_t& bytes_transferred);
+
+ // Get output data to be written to the transport.
+ ASIO_DECL asio::mutable_buffer get_output(const asio::mutable_buffer& data);
+
+ // Put input data that was read from the transport.
+ ASIO_DECL asio::const_buffer put_input(const asio::const_buffer& data);
+
+ // Map an error::eof code returned by the underlying transport according to
+ // the type and state of the SSL session. Returns a const reference to the
+ // error code object, suitable for passing to a completion handler.
+ ASIO_DECL const asio::error_code& map_error_code(asio::error_code& ec) const;
+
+ // MONGODB additions:
+ // Set the Server name for TLS SNI purposes.
+ ASIO_DECL void set_server_name(const std::wstring name);
private:
- // Disallow copying and assignment.
- engine(const engine&);
- engine& operator=(const engine&);
+ // Disallow copying and assignment.
+ engine(const engine&);
+ engine& operator=(const engine&);
private:
+ // SChannel context handle
+ CtxtHandle _hcxt;
+
+ // Credential handle
+ CredHandle _hcred;
+
+ // Credentials for TLS handshake
+ SCHANNEL_CRED* _pCred;
+
+ // TLS SNI server name
+ std::wstring _serverName;
+
+ // Engine State machine
+ //
+ enum class EngineState {
+ // Initial State
+ NeedsHandshake,
+
+ // Normal SSL Conversation in progress
+ InProgress,
+
+ // In SSL shutdown
+ InShutdown,
+ };
+
+ // Engine state
+ EngineState _state{EngineState::NeedsHandshake};
+
+ // Data received from remote side, shared across state machines
+ ReusableBuffer _inBuffer;
+
+ // Data to send to remote side, shared across state machines
+ ReusableBuffer _outBuffer;
+
+ // Extra buffer - for when more then one packet is read from the remote side
+ ReusableBuffer _extraBuffer;
+
+ // Handshake state machine
+ SSLHandshakeManager _handshakeManager;
+
+ // Read state machine
+ SSLReadManager _readManager;
+ // Write state machine
+ SSLWriteManager _writeManager;
};
#include "asio/detail/pop_options.hpp"
-} // namespace detail
-} // namespace ssl
-} // namespace asio
+} // namespace detail
+} // namespace ssl
+} // namespace asio
diff --git a/src/mongo/util/net/ssl/detail/impl/engine_schannel.ipp b/src/mongo/util/net/ssl/detail/impl/engine_schannel.ipp
index 41aca7f789f..a2d5d8c21f5 100644
--- a/src/mongo/util/net/ssl/detail/impl/engine_schannel.ipp
+++ b/src/mongo/util/net/ssl/detail/impl/engine_schannel.ipp
@@ -41,62 +41,144 @@ namespace asio {
namespace ssl {
namespace detail {
+
engine::engine(SCHANNEL_CRED* context)
-{
+ : _pCred(context),
+ _hcxt{0, 0},
+ _hcred{0, 0},
+ _inBuffer(kDefaultBufferSize),
+ _outBuffer(kDefaultBufferSize),
+ _extraBuffer(kDefaultBufferSize),
+ _handshakeManager(
+ &_hcxt, &_hcred, _serverName, &_inBuffer, &_outBuffer, &_extraBuffer, _pCred),
+ _readManager(&_hcxt, &_hcred, &_inBuffer, &_extraBuffer),
+ _writeManager(&_hcxt, &_outBuffer) {}
+
+engine::~engine() {
+ DeleteSecurityContext(&_hcxt);
+ FreeCredentialsHandle(&_hcred);
}
-engine::~engine()
-{
+PCtxtHandle engine::native_handle() {
+ return &_hcxt;
}
-PCtxtHandle engine::native_handle()
-{
- return nullptr;
+engine::want ssl_want_to_engine(ssl_want want) {
+ static_assert(static_cast<int>(ssl_want::want_input_and_retry) ==
+ static_cast<int>(engine::want_input_and_retry),
+ "bad");
+ static_assert(static_cast<int>(ssl_want::want_output_and_retry) ==
+ static_cast<int>(engine::want_output_and_retry),
+ "bad");
+ static_assert(
+ static_cast<int>(ssl_want::want_nothing) == static_cast<int>(engine::want_nothing), "bad");
+ static_assert(static_cast<int>(ssl_want::want_output) == static_cast<int>(engine::want_output),
+ "bad");
+
+ return static_cast<engine::want>(want);
}
-engine::want engine::handshake(
- stream_base::handshake_type type, asio::error_code& ec)
-{
- return want::want_nothing;
+engine::want engine::handshake(stream_base::handshake_type type, asio::error_code& ec) {
+ // ASIO will call handshake once more after we send out the last data
+ // so we need to tell them we are done with data to send.
+ if (_state != EngineState::NeedsHandshake) {
+ return want::want_nothing;
+ }
+
+ _handshakeManager.setMode((type == asio::ssl::stream_base::client)
+ ? SSLHandshakeManager::HandshakeMode::Client
+ : SSLHandshakeManager::HandshakeMode::Server);
+ SSLHandshakeManager::HandshakeState state;
+ auto w = _handshakeManager.nextHandshake(ec, &state);
+ if (w == ssl_want::want_nothing || state == SSLHandshakeManager::HandshakeState::Done) {
+ _state = EngineState::InProgress;
+ }
+
+ return ssl_want_to_engine(w);
}
-engine::want engine::shutdown(asio::error_code& ec)
-{
- return want::want_nothing;
+engine::want engine::shutdown(asio::error_code& ec) {
+ return ssl_want_to_engine(_handshakeManager.beginShutdown(ec));
}
engine::want engine::write(const asio::const_buffer& data,
- asio::error_code& ec, std::size_t& bytes_transferred)
-{
- return want::want_nothing;
+ asio::error_code& ec,
+ std::size_t& bytes_transferred) {
+ if (data.size() == 0) {
+ ec = asio::error_code();
+ return engine::want_nothing;
+ }
+
+ if (_state == EngineState::NeedsHandshake || _state == EngineState::InShutdown) {
+ // Why are we trying to write before the handshake is done?
+ ASIO_ASSERT(false);
+ return want::want_nothing;
+ } else {
+ return ssl_want_to_engine(
+ _writeManager.writeUnencryptedData(data.data(), data.size(), bytes_transferred, ec));
+ }
}
engine::want engine::read(const asio::mutable_buffer& data,
- asio::error_code& ec, std::size_t& bytes_transferred)
-{
- return want::want_nothing;
+ asio::error_code& ec,
+ std::size_t& bytes_transferred) {
+ if (data.size() == 0) {
+ ec = asio::error_code();
+ return engine::want_nothing;
+ }
+
+
+ if (_state == EngineState::NeedsHandshake) {
+ // Why are we trying to read before the handshake is done?
+ ASIO_ASSERT(false);
+ return want::want_nothing;
+ } else {
+ SSLReadManager::DecryptState decryptState;
+ auto want = ssl_want_to_engine(_readManager.readDecryptedData(
+ data.data(), data.size(), ec, bytes_transferred, &decryptState));
+ if (ec) {
+ return want;
+ }
+
+ if (decryptState != SSLReadManager::DecryptState::Continue) {
+ if (decryptState == SSLReadManager::DecryptState::Shutdown) {
+ _state = EngineState::InShutdown;
+
+ return ssl_want_to_engine(_handshakeManager.beginShutdown(ec));
+ }
+ }
+
+ return want;
+ }
}
-asio::mutable_buffer engine::get_output(
- const asio::mutable_buffer& data)
-{
- return asio::mutable_buffer(nullptr, 0);
+asio::mutable_buffer engine::get_output(const asio::mutable_buffer& data) {
+ std::size_t length;
+ _outBuffer.readInto(data.data(), data.size(), length);
+
+ return asio::buffer(data, length);
+}
+
+asio::const_buffer engine::put_input(const asio::const_buffer& data) {
+ if (_state == EngineState::NeedsHandshake) {
+ _handshakeManager.writeEncryptedData(data.data(), data.size());
+ } else {
+ _readManager.writeData(data.data(), data.size());
+ }
+
+ return asio::buffer(data + data.size());
}
-asio::const_buffer engine::put_input(
- const asio::const_buffer& data)
-{
- return asio::const_buffer(nullptr, 0);
+void engine::set_server_name(const std::wstring name) {
+ _serverName = name;
}
-const asio::error_code& engine::map_error_code(
- asio::error_code& ec) const
-{
- return ec;
+const asio::error_code& engine::map_error_code(asio::error_code& ec) const {
+ return ec;
}
#include "asio/detail/pop_options.hpp"
-} // namespace detail
-} // namespace ssl
-} // namespace asio
+} // namespace detail
+} // namespace ssl
+} // namespace asio
diff --git a/src/mongo/util/net/ssl/detail/impl/schannel.ipp b/src/mongo/util/net/ssl/detail/impl/schannel.ipp
new file mode 100644
index 00000000000..14da4977b6d
--- /dev/null
+++ b/src/mongo/util/net/ssl/detail/impl/schannel.ipp
@@ -0,0 +1,747 @@
+/**
+ * Copyright (C) 2018 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects
+ * for all of the code used other than as permitted herein. If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so. If you do not
+ * wish to do so, delete this exception statement from your version. If you
+ * delete this exception statement from all source files in the program,
+ * then also delete it in the license file.
+ */
+
+#pragma once
+
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+
+#include "asio/detail/assert.hpp"
+
+namespace asio {
+namespace ssl {
+namespace detail {
+
+/**
+ * Start or continue SSL handshake.
+ *
+ * Must be called until HandshakeState::Done is returned.
+ *
+ * Return status code to indicate whether it needs more data or if data needs to be sent to the
+ * other side.
+ */
+ssl_want SSLHandshakeManager::nextHandshake(asio::error_code& ec, HandshakeState* pHandshakeState) {
+ ASIO_ASSERT(_mode != HandshakeMode::Unknown);
+ ec = asio::error_code();
+ *pHandshakeState = HandshakeState::Continue;
+
+ if (_state == State::HandshakeStart) {
+ ssl_want want;
+
+ if (_mode == HandshakeMode::Server) {
+ // ASIO will ask for the handshake to start when the input buffer is empty
+ // but we want data first so tell ASIO to give us data
+ if (_pInBuffer->empty()) {
+ return ssl_want::want_input_and_retry;
+ }
+
+ startServerHandshake(ec);
+ if (ec) {
+ return ssl_want::want_nothing;
+ }
+
+ want = doServerHandshake(true, ec, pHandshakeState);
+ if (ec) {
+ return want;
+ }
+
+ } else {
+ startClientHandshake(ec);
+ if (ec) {
+ return ssl_want::want_nothing;
+ }
+
+ want = doClientHandshake(ec);
+ if (ec) {
+ return want;
+ }
+ }
+
+ setState(State::NeedMoreHandshakeData);
+
+ return want;
+ } else if (_state == State::NeedMoreHandshakeData) {
+ return ssl_want::want_input_and_retry;
+ } else {
+ ssl_want want;
+
+ if (_mode == HandshakeMode::Server) {
+ want = doServerHandshake(false, ec, pHandshakeState);
+ } else {
+ want = doClientHandshake(ec);
+ }
+
+ if (ec) {
+ return want;
+ }
+
+ if (want == ssl_want::want_nothing || *pHandshakeState == HandshakeState::Done) {
+ setState(State::Done);
+ } else {
+ setState(State::NeedMoreHandshakeData);
+ }
+
+ return want;
+ }
+}
+
+/**
+ * Begin graceful SSL shutdown. Either:
+ * - respond to already received alert signalling connection shutdown on remote side
+ * - start SSL shutdown by signalling remote side
+ */
+ssl_want SSLHandshakeManager::beginShutdown(asio::error_code& ec) {
+ ASIO_ASSERT(_mode != HandshakeMode::Unknown);
+ _state = State::HandshakeStart;
+
+ return startShutdown(ec);
+}
+
+/*
+ * Injest data from ASIO that has been received.
+ */
+void SSLHandshakeManager::writeEncryptedData(const void* data, std::size_t length) {
+ // We have more data, it may not be enough to decode. We will decide if we have enough on
+ // the next nextHandshake call.
+ if (_state != State::HandshakeStart) {
+ setState(State::HaveEncryptedData);
+ }
+
+ _pInBuffer->append(data, length);
+}
+
+
+void SSLHandshakeManager::startServerHandshake(asio::error_code& ec) {
+ TimeStamp lifetime;
+ SECURITY_STATUS ss = AcquireCredentialsHandleW(NULL,
+ const_cast<LPWSTR>(UNISP_NAME),
+ SECPKG_CRED_INBOUND,
+ NULL,
+ _cred,
+ NULL,
+ NULL,
+ _phcred,
+ &lifetime);
+ if (ss != SEC_E_OK) {
+ ec = asio::error_code(ss, asio::error::get_ssl_category());
+ return;
+ }
+}
+
+void SSLHandshakeManager::startClientHandshake(asio::error_code& ec) {
+ TimeStamp lifetime;
+ SECURITY_STATUS ss = AcquireCredentialsHandleW(NULL,
+ const_cast<LPWSTR>(UNISP_NAME),
+ SECPKG_CRED_OUTBOUND,
+ NULL,
+ _cred,
+ NULL,
+ NULL,
+ _phcred,
+ &lifetime);
+
+ if (ss != SEC_E_OK) {
+ ec = asio::error_code(ss, asio::error::get_ssl_category());
+ return;
+ }
+}
+
+ssl_want SSLHandshakeManager::startShutdown(asio::error_code& ec) {
+ DWORD shutdownCode = SCHANNEL_SHUTDOWN;
+
+ std::array<SecBuffer, 1> inputBuffers;
+ inputBuffers[0].cbBuffer = sizeof(shutdownCode);
+ inputBuffers[0].BufferType = SECBUFFER_TOKEN;
+ inputBuffers[0].pvBuffer = &shutdownCode;
+
+ SecBufferDesc inputBufferDesc;
+ inputBufferDesc.ulVersion = SECBUFFER_VERSION;
+ inputBufferDesc.cBuffers = inputBuffers.size();
+ inputBufferDesc.pBuffers = inputBuffers.data();
+
+ SECURITY_STATUS ss = ApplyControlToken(_phctxt, &inputBufferDesc);
+
+ if (ss != SEC_E_OK) {
+ ec = asio::error_code(ss, asio::error::get_ssl_category());
+ return ssl_want::want_nothing;
+ }
+
+ TimeStamp lifetime;
+
+ std::array<SecBuffer, 1> outputBuffers;
+ outputBuffers[0].cbBuffer = 0;
+ outputBuffers[0].BufferType = SECBUFFER_TOKEN;
+ outputBuffers[0].pvBuffer = NULL;
+ ContextBufferDeleter deleter(&outputBuffers[0].pvBuffer);
+
+ SecBufferDesc outputBufferDesc;
+ outputBufferDesc.ulVersion = SECBUFFER_VERSION;
+ outputBufferDesc.cBuffers = outputBuffers.size();
+ outputBufferDesc.pBuffers = outputBuffers.data();
+
+ if (_mode == HandshakeMode::Server) {
+ ULONG attribs = getServerFlags() | ASC_REQ_ALLOCATE_MEMORY;
+
+ SECURITY_STATUS ss = AcceptSecurityContext(
+ _phcred, _phctxt, NULL, attribs, 0, _phctxt, &outputBufferDesc, &attribs, &lifetime);
+
+ if (ss != SEC_E_OK) {
+ ec = asio::error_code(ss, asio::error::get_ssl_category());
+ return ssl_want::want_nothing;
+ }
+
+ _pOutBuffer->reset();
+ _pOutBuffer->append(outputBuffers[0].pvBuffer, outputBuffers[0].cbBuffer);
+
+ if (SEC_E_OK == ss && outputBuffers[0].cbBuffer != 0) {
+ ec = asio::error::eof;
+ return ssl_want::want_output;
+ } else {
+ return ssl_want::want_nothing;
+ }
+ } else {
+ ULONG ContextAttributes;
+ DWORD sspiFlags = getClientFlags() | ISC_REQ_ALLOCATE_MEMORY;
+
+ ss = InitializeSecurityContextW(_phcred,
+ _phctxt,
+ const_cast<SEC_WCHAR*>(_serverName.c_str()),
+ sspiFlags,
+ 0,
+ 0,
+ NULL,
+ 0,
+ _phctxt,
+ &outputBufferDesc,
+ &ContextAttributes,
+ &lifetime);
+
+ if (ss != SEC_E_OK) {
+ ec = asio::error_code(ss, asio::error::get_ssl_category());
+ return ssl_want::want_nothing;
+ }
+
+ // TODO - I have not found a way to hit this code path
+ ASIO_ASSERT(false);
+ }
+
+ return ssl_want::want_nothing;
+}
+
+ssl_want SSLHandshakeManager::doServerHandshake(bool newConversation,
+ asio::error_code& ec,
+ HandshakeState* pHandshakeState) {
+ TimeStamp lifetime;
+
+ _pOutBuffer->resize(kDefaultBufferSize);
+ _alertBuffer.resize(1024);
+
+ std::array<SecBuffer, 2> outputBuffers;
+ outputBuffers[0].cbBuffer = _pOutBuffer->size();
+ outputBuffers[0].BufferType = SECBUFFER_TOKEN;
+ outputBuffers[0].pvBuffer = _pOutBuffer->data();
+
+ outputBuffers[1].cbBuffer = _alertBuffer.size();
+ outputBuffers[1].BufferType = SECBUFFER_ALERT;
+ outputBuffers[1].pvBuffer = _alertBuffer.data();
+
+ SecBufferDesc outputBufferDesc;
+ outputBufferDesc.ulVersion = SECBUFFER_VERSION;
+ outputBufferDesc.cBuffers = outputBuffers.size();
+ outputBufferDesc.pBuffers = outputBuffers.data();
+
+ std::array<SecBuffer, 2> inputBuffers;
+ inputBuffers[0].cbBuffer = _pInBuffer->size();
+ inputBuffers[0].BufferType = SECBUFFER_TOKEN;
+ inputBuffers[0].pvBuffer = _pInBuffer->data();
+
+ inputBuffers[1].cbBuffer = 0;
+ inputBuffers[1].BufferType = SECBUFFER_EMPTY;
+ inputBuffers[1].pvBuffer = NULL;
+
+ SecBufferDesc inputBufferDesc;
+ inputBufferDesc.ulVersion = SECBUFFER_VERSION;
+ inputBufferDesc.cBuffers = inputBuffers.size();
+ inputBufferDesc.pBuffers = inputBuffers.data();
+
+ ULONG attribs = getServerFlags();
+ ULONG retAttribs = 0;
+
+ SECURITY_STATUS ss = AcceptSecurityContext(_phcred,
+ newConversation ? NULL : _phctxt,
+ &inputBufferDesc,
+ attribs,
+ 0,
+ _phctxt,
+ &outputBufferDesc,
+ &retAttribs,
+ &lifetime);
+
+ if (ss < SEC_E_OK) {
+ if (ss == SEC_E_INCOMPLETE_MESSAGE) {
+ // TODO: consider using SECBUFFER_MISSING and approriate optimizations
+ return ssl_want::want_input_and_retry;
+ }
+
+ ec = asio::error_code(ss, asio::error::get_ssl_category());
+
+ if ((retAttribs & ASC_RET_EXTENDED_ERROR) && (outputBuffers[1].cbBuffer > 0)) {
+ _pOutBuffer->resize(outputBuffers[0].cbBuffer);
+
+ // Tell ASIO we have something to send back the last data
+ return ssl_want::want_output;
+ }
+
+ return ssl_want::want_nothing;
+ }
+ invariant(attribs == retAttribs);
+
+ if (inputBuffers[1].BufferType == SECBUFFER_EXTRA) {
+ _pExtraEncryptedBuffer->reset();
+ _pExtraEncryptedBuffer->append(inputBuffers[1].pvBuffer, inputBuffers[1].cbBuffer);
+ }
+
+
+ // Next, figure out if we need to send any data out
+ bool needOutput{false};
+
+ // Did AcceptSecurityContext say we need to continue or is it done but left data in the
+ // output buffer then we need to sent the data out.
+ if (SEC_I_CONTINUE_NEEDED == ss || SEC_I_COMPLETE_AND_CONTINUE == ss ||
+ (SEC_E_OK == ss && outputBuffers[0].cbBuffer != 0)) {
+ needOutput = true;
+ }
+
+ // Tell the reusable buffer size of the data written.
+ _pOutBuffer->resize(outputBuffers[0].cbBuffer);
+
+ // Reset the input buffer
+ _pInBuffer->reset();
+
+ // Check if we have any additional encrypted data
+ if (!_pExtraEncryptedBuffer->empty()) {
+ _pInBuffer->swap(*_pExtraEncryptedBuffer);
+ _pExtraEncryptedBuffer->reset();
+
+ setState(State::HaveEncryptedData);
+ }
+
+ if (needOutput) {
+ // If AcceptSecurityContext returns SEC_E_OK, then the handshake is done
+ if (SEC_E_OK == ss && outputBuffers[0].cbBuffer != 0) {
+ *pHandshakeState = HandshakeState::Done;
+
+ // We have output, but no need to retry anymore
+ return ssl_want::want_output;
+ }
+
+ return ssl_want::want_output_and_retry;
+ }
+
+ return ssl_want::want_nothing;
+}
+
+ssl_want SSLHandshakeManager::doClientHandshake(asio::error_code& ec) {
+ DWORD sspiFlags = getClientFlags() | ISC_REQ_ALLOCATE_MEMORY;
+
+ std::array<SecBuffer, 3> outputBuffers;
+
+ outputBuffers[0].cbBuffer = 0;
+ outputBuffers[0].BufferType = SECBUFFER_TOKEN;
+ outputBuffers[0].pvBuffer = NULL;
+ ContextBufferDeleter deleter(&outputBuffers[0].pvBuffer);
+
+ outputBuffers[1].cbBuffer = 0;
+ outputBuffers[1].BufferType = SECBUFFER_ALERT;
+ outputBuffers[1].pvBuffer = NULL;
+ ContextBufferDeleter alertDeleter(&outputBuffers[1].pvBuffer);
+
+ outputBuffers[2].cbBuffer = 0;
+ outputBuffers[2].BufferType = SECBUFFER_EMPTY;
+ outputBuffers[2].pvBuffer = NULL;
+
+ SecBufferDesc outputBufferDesc;
+ outputBufferDesc.ulVersion = SECBUFFER_VERSION;
+ outputBufferDesc.cBuffers = outputBuffers.size();
+ outputBufferDesc.pBuffers = outputBuffers.data();
+
+ std::array<SecBuffer, 2> inputBuffers;
+
+ SECURITY_STATUS ss;
+ TimeStamp lifetime;
+ ULONG retAttribs;
+
+ // If the input buffer is empty, this is the start of the client handshake.
+ if (!_pInBuffer->empty()) {
+ inputBuffers[0].cbBuffer = _pInBuffer->size();
+ inputBuffers[0].BufferType = SECBUFFER_TOKEN;
+ inputBuffers[0].pvBuffer = _pInBuffer->data();
+
+ inputBuffers[1].cbBuffer = 0;
+ inputBuffers[1].BufferType = SECBUFFER_EMPTY;
+ inputBuffers[1].pvBuffer = NULL;
+
+ SecBufferDesc inputBufferDesc;
+ inputBufferDesc.ulVersion = SECBUFFER_VERSION;
+ inputBufferDesc.cBuffers = inputBuffers.size();
+ inputBufferDesc.pBuffers = inputBuffers.data();
+
+ ss = InitializeSecurityContextW(_phcred,
+ _phctxt,
+ const_cast<SEC_WCHAR*>(_serverName.c_str()),
+ sspiFlags,
+ 0,
+ 0,
+ &inputBufferDesc,
+ 0,
+ _phctxt,
+ &outputBufferDesc,
+ &retAttribs,
+ &lifetime);
+ } else {
+ ss = InitializeSecurityContextW(_phcred,
+ NULL,
+ const_cast<SEC_WCHAR*>(_serverName.c_str()),
+ sspiFlags,
+ 0,
+ 0,
+ NULL,
+ 0,
+ _phctxt,
+ &outputBufferDesc,
+ &retAttribs,
+ &lifetime);
+ }
+
+ if (ss < SEC_E_OK) {
+ if (ss == SEC_E_INCOMPLETE_MESSAGE) {
+ return ssl_want::want_input_and_retry;
+ }
+
+ ec = asio::error_code(ss, asio::error::get_ssl_category());
+
+ if ((retAttribs & ISC_RET_EXTENDED_ERROR) && (outputBuffers[1].cbBuffer > 0)) {
+ _pOutBuffer->reset();
+ _pOutBuffer->append(outputBuffers[0].pvBuffer, outputBuffers[0].cbBuffer);
+
+ // Tell ASIO we have something to send back the last data
+ return ssl_want::want_output;
+ }
+
+ return ssl_want::want_nothing;
+ }
+ invariant(sspiFlags == retAttribs);
+
+ if (_pInBuffer->size()) {
+ // Locate (optional) extra buffer
+ if (inputBuffers[1].BufferType == SECBUFFER_EXTRA) {
+ _pExtraEncryptedBuffer->reset();
+ _pExtraEncryptedBuffer->append(inputBuffers[1].pvBuffer, inputBuffers[1].cbBuffer);
+ }
+ }
+
+ // Next, figure out if we need to send any data out
+ bool needOutput{false};
+
+ // Did AcceptSecurityContext say we need to continue or is it done but left data in the
+ // output buffer then we need to sent the data out.
+ if (SEC_I_CONTINUE_NEEDED == ss || SEC_I_COMPLETE_AND_CONTINUE == ss ||
+ (SEC_E_OK == ss && outputBuffers[0].cbBuffer != 0)) {
+ needOutput = true;
+ }
+
+ _pOutBuffer->reset();
+ _pOutBuffer->append(outputBuffers[0].pvBuffer, outputBuffers[0].cbBuffer);
+
+ // Reset the input buffer
+ _pInBuffer->reset();
+
+ // Check if we have any additional encrypted data
+ if (!_pExtraEncryptedBuffer->empty()) {
+ _pInBuffer->swap(*_pExtraEncryptedBuffer);
+ _pExtraEncryptedBuffer->reset();
+
+ setState(State::HaveEncryptedData);
+ }
+
+ if (needOutput) {
+ return ssl_want::want_output_and_retry;
+ }
+
+ return ssl_want::want_nothing;
+}
+
+/**
+ * Read decrypted data if encrypted data was provided via writeData and succesfully decrypted.
+ */
+ssl_want SSLReadManager::readDecryptedData(void* data,
+ std::size_t length,
+ asio::error_code& ec,
+ std::size_t& bytes_transferred,
+ DecryptState* pDecryptState) {
+ bytes_transferred = 0;
+ ec = asio::error_code();
+ *pDecryptState = DecryptState::Continue;
+
+ // Our last state was that we needed more encrypted data, so tell ASIO we still want some
+ if (_state == State::NeedMoreEncryptedData) {
+ return ssl_want::want_input_and_retry;
+ }
+
+ // If we have encrypted data, try to decrypt it
+ if (_state == State::HaveEncryptedData) {
+ ssl_want wantState = decryptBuffer(ec, pDecryptState);
+ if (ec) {
+ return wantState;
+ }
+
+ // If remote side started shutdown, bail
+ if (*pDecryptState != DecryptState::Continue) {
+ return ssl_want::want_nothing;
+ }
+
+ if (wantState == ssl_want::want_input_and_retry) {
+ setState(State::NeedMoreEncryptedData);
+ }
+
+ if (wantState != ssl_want::want_nothing) {
+ return wantState;
+ }
+ }
+
+ // We decrypted data in the past, hand it back to ASIO until we are out of decrypted data
+ ASIO_ASSERT(_state == State::HaveDecryptedData);
+
+ _pInBuffer->readInto(data, length, bytes_transferred);
+
+ // Have we read all the decrypted data?
+ if (_pInBuffer->empty()) {
+ // If we have some extra encrypted data, it needs to be checked if it is at least a
+ // valid SSL packet, so set the state machine to reflect that we have some encrypted
+ // data.
+ if (!_pExtraEncryptedBuffer->empty()) {
+ _pInBuffer->swap(*_pExtraEncryptedBuffer);
+ _pExtraEncryptedBuffer->reset();
+ setState(State::HaveEncryptedData);
+ } else {
+ // We are empty so reset our state to need encrypted data for the next call
+ setState(State::NeedMoreEncryptedData);
+ }
+ }
+
+ return ssl_want::want_nothing;
+}
+
+ssl_want SSLReadManager::decryptBuffer(asio::error_code& ec, DecryptState* pDecryptState) {
+ std::array<SecBuffer, 4> securityBuffers;
+ securityBuffers[0].cbBuffer = _pInBuffer->size();
+ securityBuffers[0].BufferType = SECBUFFER_DATA;
+ securityBuffers[0].pvBuffer = _pInBuffer->data();
+
+ securityBuffers[1].cbBuffer = 0;
+ securityBuffers[1].BufferType = SECBUFFER_EMPTY;
+ securityBuffers[1].pvBuffer = NULL;
+
+ securityBuffers[2].cbBuffer = 0;
+ securityBuffers[2].BufferType = SECBUFFER_EMPTY;
+ securityBuffers[2].pvBuffer = NULL;
+
+ securityBuffers[3].cbBuffer = 0;
+ securityBuffers[3].BufferType = SECBUFFER_EMPTY;
+ securityBuffers[3].pvBuffer = NULL;
+
+ SecBufferDesc bufferDesc;
+ bufferDesc.ulVersion = SECBUFFER_VERSION;
+ bufferDesc.cBuffers = securityBuffers.size();
+ bufferDesc.pBuffers = securityBuffers.data();
+
+ SECURITY_STATUS ss = DecryptMessage(_phctxt, &bufferDesc, 0, NULL);
+
+ if (ss < SEC_E_OK) {
+ if (ss == SEC_E_INCOMPLETE_MESSAGE) {
+ return ssl_want::want_input_and_retry;
+ } else {
+ ec = asio::error_code(ss, asio::error::get_ssl_category());
+ return ssl_want::want_nothing;
+ }
+ }
+
+ // Shutdown has been initiated at the client side
+ if (ss == SEC_I_CONTEXT_EXPIRED) {
+ *pDecryptState = DecryptState::Shutdown;
+ } else if (ss == SEC_I_RENEGOTIATE) {
+ *pDecryptState = DecryptState::Renegotiate;
+
+ // Fail the connection on SSL renegotiations
+ ec = asio::ssl::error::stream_truncated;
+ return ssl_want::want_nothing;
+ }
+
+ if (securityBuffers[1].cbBuffer > 0) {
+ _pInBuffer->resetPos(securityBuffers[1].pvBuffer, securityBuffers[1].cbBuffer);
+ }
+
+ // The network layer may have read more then 1 SSL packet so remember the extra data.
+ if (securityBuffers[3].BufferType == SECBUFFER_EXTRA && securityBuffers[3].cbBuffer > 0) {
+ ASIO_ASSERT(_pExtraEncryptedBuffer->empty());
+ _pExtraEncryptedBuffer->append(securityBuffers[3].pvBuffer, securityBuffers[3].cbBuffer);
+ }
+
+ setState(State::HaveDecryptedData);
+
+ return ssl_want::want_nothing;
+}
+
+
+/**
+ * Encrypts data to be sent to the remote side.
+ *
+ * If the message is >= max packet size, it will return want_output_and_retry, and expects
+ * callers to continue to call it with the same parameters until want_output is returned.
+ */
+ssl_want SSLWriteManager::writeUnencryptedData(const void* pMessage,
+ std::size_t messageLength,
+ std::size_t& bytes_transferred,
+ asio::error_code& ec) {
+ ec = asio::error_code();
+
+ if (_securityTrailerLength == ULONG_MAX) {
+ SecPkgContext_StreamSizes secPkgContextStreamSizes;
+
+ SECURITY_STATUS ss =
+ QueryContextAttributes(_phctxt, SECPKG_ATTR_STREAM_SIZES, &secPkgContextStreamSizes);
+
+ if (ss < SEC_E_OK) {
+ ec = asio::error_code(ss, asio::error::get_ssl_category());
+ return ssl_want::want_nothing;
+ }
+
+ _securityTrailerLength = secPkgContextStreamSizes.cbTrailer;
+ _securityMaxMessageLength = secPkgContextStreamSizes.cbMaximumMessage;
+ _securityHeaderLength = secPkgContextStreamSizes.cbHeader;
+ }
+
+ // Do we need to fragment the message out?
+ if (messageLength > _securityMaxMessageLength) {
+ // Since the message is too large for SSL, we have to write out fragments. We rely on
+ // the fact that ASIO will keep giving us the same buffer back as long as it is asked to
+ // retry.
+ std::size_t fragmentLength =
+ std::min(_securityMaxMessageLength, messageLength - _lastWriteOffset);
+ ssl_want want = encryptMessage(reinterpret_cast<const char*>(pMessage) + _lastWriteOffset,
+ fragmentLength,
+ bytes_transferred,
+ ec);
+ if (ec) {
+ return want;
+ }
+
+ _lastWriteOffset += fragmentLength;
+
+ // We have more data to give ASIO after this fragment
+ if (_lastWriteOffset < messageLength) {
+ return ssl_want::want_output_and_retry;
+ }
+
+ // We have consumed all the data given to us over multiple consecutive calls, reset
+ // position.
+ _lastWriteOffset = 0;
+
+ // ASIO's buffering of engine calls assumes that bytes_transfered refers to all the
+ // bytes we transfered total when want_output is returned. It ignores bytes_transfered
+ // when want_output_and_retry is returned;
+ bytes_transferred = messageLength;
+
+ return ssl_want::want_output;
+ } else {
+ // Reset fragmentation position
+ _lastWriteOffset = 0;
+
+ // Send message as is without fragmentation
+ return encryptMessage(pMessage, messageLength, bytes_transferred, ec);
+ }
+}
+
+ssl_want SSLWriteManager::encryptMessage(const void* pMessage,
+ std::size_t messageLength,
+ std::size_t& bytes_transferred,
+ asio::error_code& ec) {
+ ASIO_ASSERT(_pOutBuffer->empty());
+ _pOutBuffer->resize(_securityTrailerLength + _securityHeaderLength + messageLength);
+
+ std::array<SecBuffer, 4> securityBuffers;
+
+ securityBuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
+ securityBuffers[0].cbBuffer = _securityHeaderLength;
+ securityBuffers[0].pvBuffer = _pOutBuffer->data();
+
+ memcpy_s(_pOutBuffer->data() + _securityHeaderLength,
+ _pOutBuffer->size() - _securityHeaderLength - _securityTrailerLength,
+ pMessage,
+ messageLength);
+
+ securityBuffers[1].BufferType = SECBUFFER_DATA;
+ securityBuffers[1].cbBuffer = messageLength;
+ securityBuffers[1].pvBuffer = _pOutBuffer->data() + _securityHeaderLength;
+
+ securityBuffers[2].cbBuffer = _securityTrailerLength;
+ securityBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
+ securityBuffers[2].pvBuffer = _pOutBuffer->data() + _securityHeaderLength + messageLength;
+
+ securityBuffers[3].cbBuffer = 0;
+ securityBuffers[3].BufferType = SECBUFFER_EMPTY;
+ securityBuffers[3].pvBuffer = 0;
+
+ SecBufferDesc bufferDesc;
+
+ bufferDesc.ulVersion = SECBUFFER_VERSION;
+ bufferDesc.cBuffers = securityBuffers.size();
+ bufferDesc.pBuffers = securityBuffers.data();
+
+ SECURITY_STATUS ss = EncryptMessage(_phctxt, 0, &bufferDesc, 0);
+
+ if (ss < SEC_E_OK) {
+ ec = asio::error_code(ss, asio::error::get_ssl_category());
+ return ssl_want::want_nothing;
+ }
+
+ size_t size =
+ securityBuffers[0].cbBuffer + securityBuffers[1].cbBuffer + securityBuffers[2].cbBuffer;
+
+ _pOutBuffer->resize(size);
+
+ // Tell asio that all the clear text was transfered.
+ bytes_transferred = messageLength;
+
+ return ssl_want::want_output;
+}
+
+} // namespace detail
+} // namespace ssl
+} // namespace asio
diff --git a/src/mongo/util/net/ssl/detail/schannel.hpp b/src/mongo/util/net/ssl/detail/schannel.hpp
new file mode 100644
index 00000000000..4cf00271f0e
--- /dev/null
+++ b/src/mongo/util/net/ssl/detail/schannel.hpp
@@ -0,0 +1,600 @@
+/**
+ * Copyright (C) 2018 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects
+ * for all of the code used other than as permitted herein. If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so. If you do not
+ * wish to do so, delete this exception statement from your version. If you
+ * delete this exception statement from all source files in the program,
+ * then also delete it in the license file.
+ */
+
+#pragma once
+
+#include "asio/detail/config.hpp"
+
+#include <algorithm>
+#include <cstddef>
+#include <memory>
+
+#include "asio/detail/push_options.hpp"
+
+
+#include "asio/detail/assert.hpp"
+
+#define ASSERT_STATE_TRANSITION(orig, dest) ASIO_ASSERT(!(orig) || (dest));
+
+namespace asio {
+namespace ssl {
+namespace detail {
+
+/**
+ * Reusable buffer. Behaves as a sort of producer consumer queue in a sense.
+ *
+ * Data is added to the buffer then removed.
+ *
+ * Typical workflow:
+ * - Write data
+ * - Write more data
+ * - Read some data
+ * - Keeping reading until empty
+ *
+ * Invariants:
+ * - Once reading from a buffer is started, no more writes are permitted until
+ * consumer has read all the entire buffer.
+ */
+class ReusableBuffer {
+public:
+ ReusableBuffer(std::size_t initialSize) {
+ _buffer = std::make_unique<std::uint8_t[]>(initialSize);
+ _capacity = initialSize;
+ }
+
+ /**
+ * Is buffer empty?
+ */
+ bool empty() const {
+ return _size == 0;
+ }
+
+ /**
+ * Get raw pointer to buffer.
+ */
+ std::uint8_t* data() {
+ return _buffer.get();
+ }
+
+ /**
+ * Get current number of elements in buffer.
+ */
+ std::size_t size() const {
+ return _size;
+ }
+
+ /**
+ * Reset to empty state.
+ */
+ void reset() {
+ _bufPos = 0;
+ _size = 0;
+ }
+
+ /**
+ * Add data to empty buffer.
+ */
+ void fill(const std::vector<std::uint8_t>& vec) {
+ ASIO_ASSERT(_size == 0);
+ ASIO_ASSERT(_bufPos == 0);
+ append(vec.data(), vec.size());
+ }
+
+ /**
+ * Reset current position to specified pointer in buffer.
+ */
+ void resetPos(void* pos, std::size_t size) {
+ ASIO_ASSERT(pos >= _buffer.get() && pos < (_buffer.get() + _size));
+ _bufPos = (std::uint8_t*)pos - _buffer.get();
+ resize(_bufPos + size);
+ }
+
+ /**
+ * Append data to buffer.
+ */
+ void append(const void* data, std::size_t length) {
+ ASIO_ASSERT(_bufPos == 0);
+ auto originalSize = _size;
+ resize(_size + length);
+ std::copy(reinterpret_cast<const std::uint8_t*>(data),
+ reinterpret_cast<const std::uint8_t*>(data) + length,
+ _buffer.get() + originalSize);
+ }
+
+ /**
+ * Read data from buffer. Can be a partial read.
+ */
+ void readInto(void* data, std::size_t length, std::size_t& outLength) {
+ if (length >= (size() - _bufPos)) {
+ // We have less then ASIO wants, give them everything we have
+ outLength = size() - _bufPos;
+ memcpy_s(data, length, _buffer.get() + _bufPos, size() - _bufPos);
+
+ // We are empty so reset our state to need encrypted data for the next call
+ _bufPos = 0;
+ _size = 0;
+ } else {
+ // ASIO wants less then we have so give them just what they want
+ outLength = length;
+ memcpy_s(data, length, _buffer.get() + _bufPos, length);
+
+ _bufPos += length;
+ }
+ }
+
+ /**
+ * Realloc buffer preserving existing data.
+ */
+ void resize(std::size_t size) {
+ if (size > _capacity) {
+ auto temp = std::make_unique<unsigned char[]>(size);
+
+ memcpy_s(temp.get(), size, _buffer.get(), _size);
+ _buffer.swap(temp);
+ _capacity = size;
+ }
+ _size = size;
+ }
+
+ /**
+ * Swap current buffer with other buffer.
+ */
+ void swap(ReusableBuffer& other) {
+ std::swap(_buffer, other._buffer);
+ std::swap(_bufPos, other._bufPos);
+ std::swap(_size, other._size);
+ std::swap(_capacity, other._capacity);
+ }
+
+private:
+ // Buffer of data
+ std::unique_ptr<std::uint8_t[]> _buffer;
+
+ // Current read position in buffer
+ std::size_t _bufPos{0};
+
+ // Count of elements in buffer
+ std::size_t _size{0};
+
+ // Capacity of buffer for elements, always >= _size
+ std::size_t _capacity;
+};
+
+// Default buffer size. SSL has a max encapsulated packet size of 16 kb.
+const std::size_t kDefaultBufferSize = 17 * 1024;
+
+// This enum mirrors the engine::want enum. The values must be kept in sync
+// to support a simple conversion from ssl_want to engine::want, see ssl_want_to_engine.
+enum class ssl_want {
+ // Returned by functions to indicate that the engine wants input. The input
+ // buffer should be updated to point to the data. The engine then needs to
+ // be called again to retry the operation.
+ want_input_and_retry = -2,
+
+ // Returned by functions to indicate that the engine wants to write output.
+ // The output buffer points to the data to be written. The engine then
+ // needs to be called again to retry the operation.
+ want_output_and_retry = -1,
+
+ // Returned by functions to indicate that the engine doesn't need input or
+ // output.
+ want_nothing = 0,
+
+ // Returned by functions to indicate that the engine wants to write output.
+ // The output buffer points to the data to be written. After that the
+ // operation is complete, and the engine does not need to be called again.
+ want_output = 1
+};
+
+/**
+ * Manages the SSL handshake and shutdown state machines.
+ *
+ * Handshakes are always the first set of events during SSL connection initiation.
+ * Shutdown can occur anytime after the handshake has succesfully finished
+ * as a result of a read event or explicit shutdown request from the engine.
+ */
+class SSLHandshakeManager {
+public:
+ /**
+ * Handshake Mode indicates whether this a for a client or server side.
+ *
+ * Each given connection can only be a client or server, and it cannot change once set.
+ */
+ enum class HandshakeMode {
+ // Initial state, illegal for clients to set
+ Unknown,
+
+ // Client handshake, connect side
+ Client,
+
+ // Server handshake, accept side
+ Server,
+ };
+
+ /**
+ * Handshake state indicates to the caller if nextHandshake needs to be called next.
+ */
+ enum class HandshakeState {
+ // Caller should continue to call nextHandshake, the handshake is not done.
+ Continue,
+
+ // Caller should not continue to call nextHandshake, the handshake is done.
+ Done
+ };
+
+ SSLHandshakeManager(PCtxtHandle hctxt,
+ PCredHandle phcred,
+ std::wstring& serverName,
+ ReusableBuffer* pInBuffer,
+ ReusableBuffer* pOutBuffer,
+ ReusableBuffer* pExtraBuffer,
+ SCHANNEL_CRED* cred)
+ : _state(State::HandshakeStart),
+ _phctxt(hctxt),
+ _cred(cred),
+ _serverName(serverName),
+ _phcred(phcred),
+ _pInBuffer(pInBuffer),
+ _pOutBuffer(pOutBuffer),
+ _pExtraEncryptedBuffer(pExtraBuffer),
+ _alertBuffer(1024),
+ _mode(HandshakeMode::Unknown) {}
+
+ /**
+ * Set the current handdshake mode as client or server.
+ *
+ * Idempotent if called with same mode otherwise it asserts.
+ */
+ void setMode(HandshakeMode mode) {
+ ASIO_ASSERT(_mode == HandshakeMode::Unknown || _mode == mode);
+ ASIO_ASSERT(mode != HandshakeMode::Unknown);
+ _mode = mode;
+ }
+
+ /**
+ * Start or continue SSL handshake.
+ *
+ * Must be called until HandshakeState::Done is returned.
+ *
+ * Return status code to indicate whether it needs more data or if data needs to be sent to the
+ * other side.
+ */
+ ssl_want nextHandshake(asio::error_code& ec, HandshakeState* pHandshakeState);
+
+ /**
+ * Begin graceful SSL shutdown. Either:
+ * - respond to already received alert signalling connection shutdown on remote side
+ * - start SSL shutdown by signalling remote side
+ */
+ ssl_want beginShutdown(asio::error_code& ec);
+
+ /*
+ * Ingest data from ASIO that has been received.
+ */
+ void writeEncryptedData(const void* data, std::size_t length);
+
+ /**
+ * Returns true if there is data to send over the wire.
+ */
+ bool hasOutputData() {
+ return !_pOutBuffer->empty();
+ }
+
+ /**
+ * Get data to sent over the network.
+ */
+ void readOutputBuffer(void* data, size_t inLength, size_t& outLength) {
+ _pOutBuffer->readInto(data, inLength, outLength);
+ }
+
+private:
+ void startServerHandshake(asio::error_code& ec);
+
+ void startClientHandshake(asio::error_code& ec);
+
+ DWORD getServerFlags() {
+ return ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY |
+ ASC_REQ_EXTENDED_ERROR | ASC_REQ_STREAM;
+ }
+
+ DWORD getClientFlags() {
+ return ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY |
+ ISC_REQ_EXTENDED_ERROR | ISC_REQ_STREAM | ISC_REQ_USE_SUPPLIED_CREDS |
+ ISC_REQ_MANUAL_CRED_VALIDATION;
+ }
+
+ /**
+ * RAII class to free a buffer allocated by SSPI.
+ */
+ class ContextBufferDeleter {
+ public:
+ ContextBufferDeleter(void** buf) : _buf(buf) {}
+
+ ~ContextBufferDeleter() {
+ if (*_buf != nullptr) {
+ FreeContextBuffer(*_buf);
+ }
+ }
+
+ private:
+ void** _buf;
+ };
+
+ ssl_want startShutdown(asio::error_code& ec);
+
+ ssl_want doServerHandshake(bool newConversation,
+ asio::error_code& ec,
+ HandshakeState* pHandshakeState);
+
+ ssl_want doClientHandshake(asio::error_code& ec);
+
+private:
+ /**
+ * Handshake State machine:
+ * +-----------------------------+
+ * v |
+ * +----------------+ +-----------------------+ +-------------------+ +------+
+ * | HandshakeStart | --> | NeedMoreHandshakeData | --> | HaveEncryptedData | --> | Done |
+ * +----------------+ +-----------------------+ +-------------------+ +------+
+ *
+ * "[ HandshakeStart ] --> [ NeedMoreHandshakeData ] --> [HaveEncryptedData] -> [
+ * NeedMoreHandshakeData], [Done] " | graph-easy
+ */
+ enum class State {
+ // Initial state
+ HandshakeStart,
+
+ // Handshake needs more data before it decode the next message
+ NeedMoreHandshakeData,
+
+ // Handshake just received some data, and can now try to decrypt it
+ HaveEncryptedData,
+
+ // Handshake is done
+ Done,
+ };
+
+ /**
+ * Transition state machine
+ */
+ void setState(State s) {
+ ASSERT_STATE_TRANSITION(_state == State::HandshakeStart, s == State::NeedMoreHandshakeData);
+ ASSERT_STATE_TRANSITION(_state == State::NeedMoreHandshakeData,
+ s == State::HaveEncryptedData);
+ ASSERT_STATE_TRANSITION(_state == State::HaveEncryptedData,
+ s == State::NeedMoreHandshakeData || s == State::Done);
+ _state = s;
+ }
+
+private:
+ // State machine
+ State _state;
+
+ // Handshake mode - client or server
+ HandshakeMode _mode;
+
+ // Server name for TLS SNI purposes
+ std::wstring& _serverName;
+
+ // Buffer of data received from remote side
+ ReusableBuffer* _pInBuffer;
+
+ // Scratch buffer to capture extra handshake data
+ ReusableBuffer* _pExtraEncryptedBuffer;
+
+ // Buffer to data to send to remote side
+ ReusableBuffer* _pOutBuffer;
+
+ // Buffer of data received from remote side
+ ReusableBuffer _alertBuffer;
+
+ // SChannel Credentials
+ SCHANNEL_CRED* _cred;
+
+ // SChannel context
+ PCtxtHandle _phctxt;
+
+ // Credential handle
+ PCredHandle _phcred;
+};
+
+/**
+ * Manages the SSL read state machine.
+ *
+ * Notifies callers of graceful SSL shutdown events.
+ */
+class SSLReadManager {
+public:
+ /**
+ * Indicates whether client should continue to decrypt data or it needs to handle other protocol
+ * signals.
+ */
+ enum class DecryptState {
+ // SSL connection is proceeding normally
+ Continue,
+
+ // Remote side has signaled graceful SSL shutdown
+ Shutdown,
+
+ // Remote side has signaled renegtiation
+ Renegotiate,
+ };
+
+ SSLReadManager(PCtxtHandle hctxt,
+ PCredHandle hcred,
+ ReusableBuffer* pInBuffer,
+ ReusableBuffer* pExtraBuffer)
+ : _state(State::NeedMoreEncryptedData),
+ _phctxt(hctxt),
+ _phcred(hcred),
+ _pInBuffer(pInBuffer),
+ _pExtraEncryptedBuffer(pExtraBuffer) {}
+
+ /**
+ * Read decrypted data if encrypted data was provided via writeData and succesfully decrypted.
+ */
+ ssl_want readDecryptedData(void* data,
+ std::size_t length,
+ asio::error_code& ec,
+ std::size_t& bytes_transferred,
+ DecryptState* pDecryptState);
+
+ /**
+ * Receive more data from ASIO.
+ */
+ void writeData(const void* data, std::size_t length) {
+ ASIO_ASSERT(_pExtraEncryptedBuffer->empty());
+
+ // We have more data, it may not be enough to decode but we will figure that out later.
+ setState(State::HaveEncryptedData);
+
+ _pInBuffer->append(data, length);
+ }
+
+private:
+ ssl_want decryptBuffer(asio::error_code& ec, DecryptState* pDecryptState);
+
+private:
+ /**
+ * Read State machine:
+ *
+ * +------------------------------------------------------------+
+ * | |
+ * | |
+ * | +-----------------------------+ |
+ * | v | |
+ * | +-----------------------+ +-------------------+ +-------------------+
+ * +> | NeedMoreEncryptedData | --> | HaveEncryptedData | --> | HaveDecryptedData |
+ * +-----------------------+ +-------------------+ +-------------------+
+ * ^ | ^ |
+ * +-------------------+ +-------------------------+
+ *
+ * "[ NeedMoreEncryptedData ] --> [ HaveEncryptedData ] --> [HaveDecryptedData] ->
+ * [NeedMoreEncryptedData], [HaveEncryptedData] --> [NeedMoreEncryptedData] " | graph-easy
+ *
+ */
+ enum class State {
+ // Initial state, Need more data from remote side
+ NeedMoreEncryptedData,
+
+ // Have some encrypted data, unknown if it is a complete packet
+ HaveEncryptedData,
+
+ // Was able to decrypt a packet, give decrypted data back to client
+ HaveDecryptedData,
+ };
+
+ /**
+ * Transition state machine
+ */
+ void setState(State s) {
+ ASSERT_STATE_TRANSITION(_state == State::NeedMoreEncryptedData,
+ s == State::HaveEncryptedData);
+ ASSERT_STATE_TRANSITION(
+ _state == State::HaveEncryptedData,
+ (s == State::NeedMoreEncryptedData || s == State::HaveDecryptedData));
+ ASSERT_STATE_TRANSITION(
+ _state == State::HaveDecryptedData,
+ (s == State::NeedMoreEncryptedData || s == State::HaveEncryptedData));
+ _state = s;
+ }
+
+private:
+ // State machine
+ State _state;
+
+ // Scratch buffer to capture extra decryption data
+ ReusableBuffer* _pExtraEncryptedBuffer;
+
+ // Buffer of data from the remote side
+ ReusableBuffer* _pInBuffer;
+
+ // SChannel context handle
+ PCtxtHandle _phctxt;
+
+ // Credential handle
+ PCredHandle _phcred;
+};
+
+/**
+ * Manages the SSL write state machine.
+ */
+class SSLWriteManager {
+public:
+ SSLWriteManager(PCtxtHandle hctxt, ReusableBuffer* pOutBuffer)
+ : _phctxt(hctxt), _pOutBuffer(pOutBuffer) {}
+
+ /**
+ * Encrypts data to be sent to the remote side.
+ *
+ * If the message is >= max packet side, it will return want_output_and_retry, and expects
+ * callers to continue to call it with the same parameters until want_output is returned.
+ */
+ ssl_want writeUnencryptedData(const void* pMessage,
+ std::size_t messageLength,
+ std::size_t& bytes_transferred,
+ asio::error_code& ec);
+
+ /**
+ * Read encrypted data to be sent to the remote side.
+ */
+ void readOutputBuffer(void* data, size_t inLength, size_t& outLength) {
+ _pOutBuffer->readInto(data, inLength, outLength);
+ }
+
+private:
+ ssl_want encryptMessage(const void* pMessage,
+ std::size_t messageLength,
+ std::size_t& bytes_transferred,
+ asio::error_code& ec);
+
+private:
+ // Buffer of data to send to the remote side
+ ReusableBuffer* _pOutBuffer;
+
+ // SChannel context handle
+ PCtxtHandle _phctxt;
+
+ // Position to start encrypting from for messages needing fragmentation
+ std::size_t _lastWriteOffset{0};
+
+ // TLS packet header length
+ std::size_t _securityHeaderLength{ULONG_MAX};
+
+ // TLS max packet size - 16kb typically
+ std::size_t _securityMaxMessageLength{ULONG_MAX};
+
+ // TLS packet trailer length
+ std::size_t _securityTrailerLength{ULONG_MAX};
+};
+
+#include "asio/detail/pop_options.hpp"
+
+} // namespace detail
+} // namespace ssl
+} // namespace asio
diff --git a/src/mongo/util/net/ssl/impl/context_schannel.ipp b/src/mongo/util/net/ssl/impl/context_schannel.ipp
index 99f6541c24d..52bd1b86797 100644
--- a/src/mongo/util/net/ssl/impl/context_schannel.ipp
+++ b/src/mongo/util/net/ssl/impl/context_schannel.ipp
@@ -30,11 +30,11 @@
#include "asio/detail/config.hpp"
-#include <cstring>
#include "asio/detail/throw_error.hpp"
#include "asio/error.hpp"
#include "mongo/util/net/ssl/context.hpp"
#include "mongo/util/net/ssl/error.hpp"
+#include <cstring>
#include "asio/detail/push_options.hpp"
@@ -42,37 +42,31 @@ namespace asio {
namespace ssl {
-context::context(context::method m)
- : handle_(0)
-{
+context::context(context::method m) : handle_(&_cred) {
+ memset(&_cred, 0, sizeof(_cred));
}
#if defined(ASIO_HAS_MOVE) || defined(GENERATING_DOCUMENTATION)
-context::context(context&& other)
-{
+context::context(context&& other) {
handle_ = other.handle_;
other.handle_ = 0;
}
-context& context::operator=(context&& other)
-{
+context& context::operator=(context&& other) {
context tmp(ASIO_MOVE_CAST(context)(*this));
handle_ = other.handle_;
other.handle_ = 0;
return *this;
}
-#endif // defined(ASIO_HAS_MOVE) || defined(GENERATING_DOCUMENTATION)
+#endif // defined(ASIO_HAS_MOVE) || defined(GENERATING_DOCUMENTATION)
-context::~context()
-{
-}
+context::~context() {}
-context::native_handle_type context::native_handle()
-{
+context::native_handle_type context::native_handle() {
return handle_;
}
-} // namespace ssl
-} // namespace asio
+} // namespace ssl
+} // namespace asio
#include "asio/detail/pop_options.hpp"
diff --git a/src/mongo/util/net/ssl/impl/error.ipp b/src/mongo/util/net/ssl/impl/error.ipp
index 5de8a890454..64d3a28bae7 100644
--- a/src/mongo/util/net/ssl/impl/error.ipp
+++ b/src/mongo/util/net/ssl/impl/error.ipp
@@ -17,6 +17,7 @@
#include "asio/detail/config.hpp"
#include "mongo/util/net/ssl/error.hpp"
+#include "mongo/util/errno_util.h"
#include "asio/detail/push_options.hpp"
@@ -35,9 +36,7 @@ public:
#if MONGO_CONFIG_SSL_PROVIDER == SSL_PROVIDER_WINDOWS
std::string message(int value) const
{
- // TODO: call FormatMessage
- ASIO_ASSERT(false);
- return "asio.ssl error";
+ return mongo::errnoWithDescription(value);
}
#elif MONGO_CONFIG_SSL_PROVIDER == SSL_PROVIDER_OPENSSL
std::string message(int value) const
diff --git a/src/mongo/util/net/ssl/impl/src.hpp b/src/mongo/util/net/ssl/impl/src.hpp
index 81b903d0f5e..fd940e414af 100644
--- a/src/mongo/util/net/ssl/impl/src.hpp
+++ b/src/mongo/util/net/ssl/impl/src.hpp
@@ -24,6 +24,7 @@
#include "mongo/util/net/ssl/impl/context_schannel.ipp"
#include "mongo/util/net/ssl/impl/error.ipp"
#include "mongo/util/net/ssl/detail/impl/engine_schannel.ipp"
+#include "mongo/util/net/ssl/detail/impl/schannel.ipp"
#elif MONGO_CONFIG_SSL_PROVIDER == SSL_PROVIDER_OPENSSL
diff --git a/src/mongo/util/net/ssl_manager_windows.cpp b/src/mongo/util/net/ssl_manager_windows.cpp
index bcac27724a0..10c48d7c325 100644
--- a/src/mongo/util/net/ssl_manager_windows.cpp
+++ b/src/mongo/util/net/ssl_manager_windows.cpp
@@ -46,6 +46,7 @@
#include "mongo/base/initializer_context.h"
#include "mongo/bson/bsonobjbuilder.h"
#include "mongo/config.h"
+#include "mongo/db/server_options.h"
#include "mongo/db/server_parameters.h"
#include "mongo/platform/atomic_word.h"
#include "mongo/stdx/memory.h"
@@ -70,8 +71,92 @@ namespace mongo {
namespace {
SimpleMutex sslManagerMtx;
-SSLManagerInterface* theSSLManagerWindows = NULL;
+SSLManagerInterface* theSSLManager = NULL;
+/**
+* Free a Certificate Context.
+*/
+struct CERTFree {
+ void operator()(const CERT_CONTEXT* p) noexcept {
+ if (p) {
+ ::CertFreeCertificateContext(p);
+ }
+ }
+};
+
+using UniqueCertificate = std::unique_ptr<const CERT_CONTEXT, CERTFree>;
+
+/**
+* A simple generic class to manage Windows handle like things. Behaves similiar to std::unique_ptr.
+*
+* Only supports move.
+*/
+template <typename HandleT, class Deleter>
+class AutoHandle {
+public:
+ AutoHandle() : _handle(0) {}
+ AutoHandle(HandleT handle) : _handle(handle) {}
+ AutoHandle(AutoHandle<HandleT, Deleter>&& handle) : _handle(handle._handle) {
+ handle._handle = 0;
+ }
+
+ ~AutoHandle() {
+ if (_handle != 0) {
+ Deleter()(_handle);
+ }
+ }
+
+ AutoHandle(const AutoHandle&) = delete;
+
+ /**
+ * Take ownership of the handle.
+ */
+ AutoHandle& operator=(const HandleT other) {
+ _handle = other;
+ return *this;
+ }
+
+ AutoHandle& operator=(const AutoHandle<HandleT, Deleter>& other) = delete;
+
+ AutoHandle& operator=(AutoHandle<HandleT, Deleter>&& other) {
+ _handle = other._handle;
+ other._handle = 0;
+ return *this;
+ }
+
+ operator HandleT() {
+ return _handle;
+ }
+
+private:
+ HandleT _handle;
+};
+
+/**
+* Free a HCRYPTPROV Handle
+*/
+struct CryptProviderFree {
+ void operator()(HCRYPTPROV const h) noexcept {
+ if (h) {
+ ::CryptReleaseContext(h, 0);
+ }
+ }
+};
+
+using UniqueCryptProvider = AutoHandle<HCRYPTPROV, CryptProviderFree>;
+
+/**
+* Free a HCRYPTKEY Handle
+*/
+struct CryptKeyFree {
+ void operator()(HCRYPTKEY const h) noexcept {
+ if (h) {
+ ::CryptDestroyKey(h);
+ }
+ }
+};
+
+using UniqueCryptKey = AutoHandle<HCRYPTKEY, CryptKeyFree>;
} // namespace
@@ -80,9 +165,20 @@ SSLManagerInterface* theSSLManagerWindows = NULL;
*/
class SSLConnectionWindows : public SSLConnectionInterface {
public:
+ SCHANNEL_CRED* _cred;
+ Socket* socket;
+ asio::ssl::detail::engine _engine;
+
+ std::vector<char> _tempBuffer;
+
+ SSLConnectionWindows(SCHANNEL_CRED* cred, Socket* sock, const char* initialBytes, int len);
+
~SSLConnectionWindows();
- std::string getSNIServerName() const final;
+ std::string getSNIServerName() const final {
+ // TODO
+ return "";
+ };
};
@@ -98,32 +194,44 @@ public:
const SSLParams& params,
ConnectionDirection direction) final;
- virtual SSLConnectionInterface* connect(Socket* socket);
+ SSLConnectionInterface* connect(Socket* socket) final;
- virtual SSLConnectionInterface* accept(Socket* socket, const char* initialBytes, int len);
+ SSLConnectionInterface* accept(Socket* socket, const char* initialBytes, int len) final;
- virtual SSLPeerInfo parseAndValidatePeerCertificateDeprecated(
- const SSLConnectionInterface* conn, const std::string& remoteHost);
+ SSLPeerInfo parseAndValidatePeerCertificateDeprecated(const SSLConnectionInterface* conn,
+ const std::string& remoteHost) final;
StatusWith<boost::optional<SSLPeerInfo>> parseAndValidatePeerCertificate(
PCtxtHandle ssl, const std::string& remoteHost) final;
- virtual const SSLConfiguration& getSSLConfiguration() const {
+ const SSLConfiguration& getSSLConfiguration() const final {
return _sslConfiguration;
}
- virtual int SSL_read(SSLConnectionInterface* conn, void* buf, int num);
+ int SSL_read(SSLConnectionInterface* conn, void* buf, int num) final;
- virtual int SSL_write(SSLConnectionInterface* conn, const void* buf, int num);
+ int SSL_write(SSLConnectionInterface* conn, const void* buf, int num) final;
- virtual int SSL_shutdown(SSLConnectionInterface* conn);
+ int SSL_shutdown(SSLConnectionInterface* conn) final;
private:
bool _weakValidation;
bool _allowInvalidCertificates;
bool _allowInvalidHostnames;
SSLConfiguration _sslConfiguration;
+
+ SCHANNEL_CRED _clientCred;
+ SCHANNEL_CRED _serverCred;
+
+ UniqueCertificate _pemCertificate;
+ UniqueCertificate _clusterPEMCertificate;
+ PCCERT_CONTEXT _clientCertificates[1];
+ PCCERT_CONTEXT _serverCertificates[1];
+
+ Status loadCertificates(const SSLParams& params);
+
+ void handshake(SSLConnectionWindows* conn, bool client);
};
// Global variable indicating if this is a server or a client instance
@@ -132,19 +240,27 @@ bool isSSLServer = false;
MONGO_INITIALIZER(SSLManager)(InitializerContext*) {
stdx::lock_guard<SimpleMutex> lck(sslManagerMtx);
if (!isSSLServer || (sslGlobalParams.sslMode.load() != SSLParams::SSLMode_disabled)) {
- theSSLManagerWindows = new SSLManagerWindows(sslGlobalParams, isSSLServer);
+ theSSLManager = new SSLManagerWindows(sslGlobalParams, isSSLServer);
}
return Status::OK();
}
-SSLConnectionWindows::~SSLConnectionWindows() {}
+SSLConnectionWindows::SSLConnectionWindows(SCHANNEL_CRED* cred,
+ Socket* sock,
+ const char* initialBytes,
+ int len)
+ : _cred(cred), socket(sock), _engine(_cred) {
-std::string SSLConnectionWindows::getSNIServerName() const {
- invariant(false);
- return "";
+ _tempBuffer.resize(17 * 1024);
+
+ if (len > 0) {
+ _engine.put_input(asio::const_buffer(initialBytes, len));
+ }
}
+SSLConnectionWindows::~SSLConnectionWindows() {}
+
std::unique_ptr<SSLManagerInterface> SSLManagerInterface::create(const SSLParams& params,
bool isServer) {
return stdx::make_unique<SSLManagerWindows>(params, isServer);
@@ -152,24 +268,111 @@ std::unique_ptr<SSLManagerInterface> SSLManagerInterface::create(const SSLParams
SSLManagerInterface* getSSLManager() {
stdx::lock_guard<SimpleMutex> lck(sslManagerMtx);
- if (theSSLManagerWindows)
- return theSSLManagerWindows;
+ if (theSSLManager)
+ return theSSLManager;
return NULL;
}
SSLManagerWindows::SSLManagerWindows(const SSLParams& params, bool isServer)
: _weakValidation(params.sslWeakCertificateValidation),
_allowInvalidCertificates(params.sslAllowInvalidCertificates),
- _allowInvalidHostnames(params.sslAllowInvalidHostnames) {}
+ _allowInvalidHostnames(params.sslAllowInvalidHostnames) {
+
+ uassertStatusOK(loadCertificates(params));
+
+ uassertStatusOK(initSSLContext(&_clientCred, params, ConnectionDirection::kOutgoing));
+
+ // TODO: validate client certificate
+
+ // SSL server specific initialization
+ if (isServer) {
+ uassertStatusOK(initSSLContext(&_serverCred, params, ConnectionDirection::kIncoming));
+
+ // TODO: validate server certificate
+ }
+}
int SSLManagerWindows::SSL_read(SSLConnectionInterface* connInterface, void* buf, int num) {
- invariant(false);
- return 0;
+ SSLConnectionWindows* conn = static_cast<SSLConnectionWindows*>(connInterface);
+
+ while (true) {
+ size_t bytes_transferred;
+ asio::error_code ec;
+ asio::ssl::detail::engine::want want =
+ conn->_engine.read(asio::mutable_buffer(buf, num), ec, bytes_transferred);
+ if (ec) {
+ throwSocketError(SocketErrorKind::RECV_ERROR, ec.message());
+ }
+
+ switch (want) {
+ case asio::ssl::detail::engine::want_input_and_retry: {
+ // ASIO wants more data before it can continue:
+ // 1. fetch some from the network
+ // 2. give it to ASIO
+ // 3. retry
+ int ret =
+ recv(conn->socket->rawFD(), reinterpret_cast<char*>(buf), num, portRecvFlags);
+ if (ret == SOCKET_ERROR) {
+ conn->socket->handleRecvError(ret, num);
+ }
+
+ conn->_engine.put_input(asio::const_buffer(buf, ret));
+
+ continue;
+ }
+ case asio::ssl::detail::engine::want_nothing: {
+ // ASIO wants nothing, return to caller with anything transfered.
+ return bytes_transferred;
+ }
+ default:
+ severe() << "Unexpected ASIO state: " << static_cast<int>(want);
+ MONGO_UNREACHABLE;
+ }
+ }
}
int SSLManagerWindows::SSL_write(SSLConnectionInterface* connInterface, const void* buf, int num) {
- invariant(false);
- return 0;
+ SSLConnectionWindows* conn = static_cast<SSLConnectionWindows*>(connInterface);
+
+ while (true) {
+ size_t bytes_transferred;
+ asio::error_code ec;
+ asio::ssl::detail::engine::want want =
+ conn->_engine.write(asio::const_buffer(buf, num), ec, bytes_transferred);
+ if (ec) {
+ throwSocketError(SocketErrorKind::SEND_ERROR, ec.message());
+ }
+
+ switch (want) {
+ case asio::ssl::detail::engine::want_output:
+ case asio::ssl::detail::engine::want_output_and_retry: {
+ // ASIO wants us to send data out:
+ // 1. get data from ASIO
+ // 2. give it to the network
+ // 3. retry if needed
+
+ asio::mutable_buffer outBuf = conn->_engine.get_output(
+ asio::mutable_buffer(conn->_tempBuffer.data(), conn->_tempBuffer.size()));
+
+ int ret = send(conn->socket->rawFD(),
+ reinterpret_cast<const char*>(outBuf.data()),
+ outBuf.size(),
+ portSendFlags);
+ if (ret == SOCKET_ERROR) {
+ conn->socket->handleSendError(ret, "");
+ }
+
+ if (want == asio::ssl::detail::engine::want_output_and_retry) {
+ continue;
+ }
+
+ return bytes_transferred;
+ }
+ default:
+ severe() << "Unexpected ASIO state: " << static_cast<int>(want);
+ MONGO_UNREACHABLE;
+ }
+ }
}
int SSLManagerWindows::SSL_shutdown(SSLConnectionInterface* conn) {
@@ -177,26 +380,426 @@ int SSLManagerWindows::SSL_shutdown(SSLConnectionInterface* conn) {
return 0;
}
+StatusWith<UniqueCertificate> readPEMFile(StringData fileName, StringData password) {
+
+ std::ifstream pemFile(fileName.toString(), std::ios::binary);
+ if (!pemFile.is_open()) {
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "Failed to open PEM file: " << fileName);
+ }
+
+ std::string buf((std::istreambuf_iterator<char>(pemFile)), std::istreambuf_iterator<char>());
+
+ pemFile.close();
+
+ // Search the buffer for the various strings that make up a PEM file
+ size_t publicKey = buf.find("-----BEGIN CERTIFICATE-----");
+ if (publicKey == std::string::npos) {
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "Failed to find Certifiate in: " << fileName);
+ }
+
+ // TODO: decode encrypted pem
+ // StringData encryptedPrivateKey = buf.find("-----BEGIN ENCRYPTED PRIVATE KEY-----");
+
+ // TODO: check if we need both
+ size_t privateKey = buf.find("-----BEGIN RSA PRIVATE KEY-----");
+ if (privateKey == std::string::npos) {
+ privateKey = buf.find("-----BEGIN PRIVATE KEY-----");
+ }
+
+ if (privateKey == std::string::npos) {
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "Failed to find privateKey in: " << fileName);
+ }
+
+ CERT_BLOB certBlob;
+ certBlob.cbData = buf.size() - publicKey;
+ certBlob.pbData = reinterpret_cast<BYTE*>(const_cast<char*>(buf.data() + publicKey));
+
+ PCCERT_CONTEXT cert;
+ BOOL ret = CryptQueryObject(CERT_QUERY_OBJECT_BLOB,
+ &certBlob,
+ CERT_QUERY_CONTENT_FLAG_ALL,
+ CERT_QUERY_FORMAT_FLAG_ALL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ reinterpret_cast<const void**>(&cert));
+ if (!ret) {
+ DWORD gle = GetLastError();
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "CryptQueryObject failed to get cert: "
+ << errnoWithDescription(gle));
+ }
+
+ UniqueCertificate certHolder(cert);
+ DWORD privateKeyLen{0};
+
+ ret = CryptStringToBinaryA(buf.c_str() + privateKey,
+ 0, // null terminated string
+ CRYPT_STRING_BASE64HEADER | CRYPT_STRING_STRICT,
+ NULL,
+ &privateKeyLen,
+ NULL,
+ NULL);
+ if (!ret) {
+ DWORD gle = GetLastError();
+ if (gle != ERROR_MORE_DATA) {
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "CryptStringToBinary failed to get size of key: "
+ << errnoWithDescription(gle));
+ }
+ }
+
+ std::unique_ptr<BYTE[]> privateKeyBuf = std::make_unique<BYTE[]>(privateKeyLen);
+ ret = CryptStringToBinaryA(buf.c_str() + privateKey,
+ 0, // null terminated string
+ CRYPT_STRING_BASE64HEADER | CRYPT_STRING_STRICT,
+ privateKeyBuf.get(),
+ &privateKeyLen,
+ NULL,
+ NULL);
+ if (!ret) {
+ DWORD gle = GetLastError();
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "CryptStringToBinary failed to read key: "
+ << errnoWithDescription(gle));
+ }
+
+
+ DWORD privateBlobLen{0};
+
+ ret = CryptDecodeObjectEx(X509_ASN_ENCODING,
+ PKCS_RSA_PRIVATE_KEY,
+ privateKeyBuf.get(),
+ privateKeyLen,
+ CRYPT_DECODE_SHARE_OID_STRING_FLAG,
+ NULL,
+ NULL,
+ &privateBlobLen);
+ if (!ret) {
+ DWORD gle = GetLastError();
+ if (gle != ERROR_MORE_DATA) {
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "CryptDecodeObjectEx failed to get size of key: "
+ << errnoWithDescription(gle));
+ }
+ }
+
+ std::unique_ptr<BYTE[]> privateBlobBuf = std::make_unique<BYTE[]>(privateBlobLen);
+
+ ret = CryptDecodeObjectEx(X509_ASN_ENCODING,
+ PKCS_RSA_PRIVATE_KEY,
+ privateKeyBuf.get(),
+ privateKeyLen,
+ CRYPT_DECODE_SHARE_OID_STRING_FLAG,
+ NULL,
+ privateBlobBuf.get(),
+ &privateBlobLen);
+ if (!ret) {
+ DWORD gle = GetLastError();
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "CryptDecodeObjectEx failed to read key: "
+ << errnoWithDescription(gle));
+ }
+
+ HCRYPTPROV hProv;
+ std::wstring wstr;
+
+ // Create the right Crypto context depending on whether we running in a server or outside.
+ // See https://msdn.microsoft.com/en-us/library/windows/desktop/aa375195(v=vs.85).aspx
+ if (isSSLServer) {
+ // Generate a unique name for our key container
+ // Use the the log file if possible
+ if (!serverGlobalParams.logpath.empty()) {
+ wstr = toNativeString(serverGlobalParams.logpath.c_str());
+ } else {
+ auto us = UUID::gen().toString();
+ wstr = toNativeString(us.c_str());
+ }
+
+ // Use a new key container for the key. We cannot use the default container since the
+ // default
+ // container is shared across processes owned by the same user.
+ // Note: Server side Schannel requires CRYPT_VERIFYCONTEXT off
+ ret = CryptAcquireContextW(
+ &hProv, wstr.c_str(), MS_ENHANCED_PROV, PROV_RSA_FULL, CRYPT_NEWKEYSET | CRYPT_SILENT);
+ if (!ret) {
+ DWORD gle = GetLastError();
+
+ if (gle == NTE_EXISTS) {
+
+ ret = CryptAcquireContextW(
+ &hProv, wstr.c_str(), MS_ENHANCED_PROV, PROV_RSA_FULL, CRYPT_SILENT);
+ if (!ret) {
+ DWORD gle = GetLastError();
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "CryptAcquireContextW failed "
+ << errnoWithDescription(gle));
+ }
+
+ } else {
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "CryptAcquireContextW failed "
+ << errnoWithDescription(gle));
+ }
+ }
+ } else {
+ // Use a transient key container for the key
+ ret = CryptAcquireContextW(
+ &hProv, NULL, MS_ENHANCED_PROV, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT | CRYPT_SILENT);
+ if (!ret) {
+ DWORD gle = GetLastError();
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "CryptAcquireContextW failed "
+ << errnoWithDescription(gle));
+ }
+ }
+ UniqueCryptProvider cryptProvider(hProv);
+
+ HCRYPTKEY hkey;
+ ret = CryptImportKey(hProv, privateBlobBuf.get(), privateBlobLen, 0, 0, &hkey);
+ if (!ret) {
+ DWORD gle = GetLastError();
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "CryptImportKey failed " << errnoWithDescription(gle));
+ }
+ UniqueCryptKey(hKey);
+
+ if (isSSLServer) {
+ // Server-side SChannel requires a different way of attaching the private key to the
+ // certificate
+ CRYPT_KEY_PROV_INFO keyProvInfo;
+ memset(&keyProvInfo, 0, sizeof(keyProvInfo));
+ keyProvInfo.pwszContainerName = const_cast<wchar_t*>(wstr.c_str());
+ keyProvInfo.pwszProvName = const_cast<wchar_t*>(MS_ENHANCED_PROV);
+ keyProvInfo.dwFlags = CERT_SET_KEY_PROV_HANDLE_PROP_ID | CERT_SET_KEY_CONTEXT_PROP_ID;
+ keyProvInfo.dwProvType = PROV_RSA_FULL;
+ keyProvInfo.dwKeySpec = AT_KEYEXCHANGE;
+
+ if (!CertSetCertificateContextProperty(
+ certHolder.get(), CERT_KEY_PROV_INFO_PROP_ID, 0, &keyProvInfo)) {
+ DWORD gle = GetLastError();
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "CertSetCertificateContextProperty Failed "
+ << errnoWithDescription(gle));
+ }
+ }
+
+ // NOTE: This is used to set the certificate for client side SChannel
+ ret = CertSetCertificateContextProperty(
+ cert, CERT_KEY_PROV_HANDLE_PROP_ID, 0, (const void*)hProv);
+ if (!ret) {
+ DWORD gle = GetLastError();
+ return Status(ErrorCodes::InvalidSSLConfiguration,
+ str::stream() << "CertSetCertificateContextProperty failed "
+ << errnoWithDescription(gle));
+ }
+
+ return std::move(certHolder);
+}
+
+Status SSLManagerWindows::loadCertificates(const SSLParams& params) {
+ _clientCertificates[0] = nullptr;
+ _serverCertificates[0] = nullptr;
+
+ // Load the normal PEM file
+ if (!params.sslPEMKeyFile.empty()) {
+ auto swCertificate = readPEMFile(params.sslPEMKeyFile, params.sslPEMKeyPassword);
+ if (!swCertificate.isOK()) {
+ return swCertificate.getStatus();
+ }
+
+ _pemCertificate = std::move(swCertificate.getValue());
+ }
+
+ // Load the cluster PEM file, only applies to server side code
+ if (!params.sslClusterFile.empty()) {
+ auto swCertificate = readPEMFile(params.sslClusterFile, params.sslClusterPassword);
+ if (!swCertificate.isOK()) {
+ return swCertificate.getStatus();
+ }
+
+ _clusterPEMCertificate = std::move(swCertificate.getValue());
+ }
+
+ if (_pemCertificate) {
+ _clientCertificates[0] = _pemCertificate.get();
+ _serverCertificates[0] = _pemCertificate.get();
+ }
+
+ if (_clusterPEMCertificate) {
+ _clientCertificates[0] = _clusterPEMCertificate.get();
+ }
+
+ return Status::OK();
+}
+
Status SSLManagerWindows::initSSLContext(SCHANNEL_CRED* cred,
const SSLParams& params,
ConnectionDirection direction) {
+ memset(cred, 0, sizeof(*cred));
+ cred->dwVersion = SCHANNEL_CRED_VERSION;
+ cred->dwFlags = SCH_USE_STRONG_CRYPTO; // Use strong crypto;
+
+ uint32_t supportedProtocols = 0;
+
+ if (direction == ConnectionDirection::kIncoming) {
+ supportedProtocols = SP_PROT_TLS1_SERVER | SP_PROT_TLS1_0_SERVER | SP_PROT_TLS1_1_SERVER |
+ SP_PROT_TLS1_2_SERVER;
+
+ cred->dwFlags = cred->dwFlags // Flags
+ | SCH_CRED_REVOCATION_CHECK_CHAIN // Check certificate revocation
+ | SCH_CRED_SNI_CREDENTIAL // Pass along SNI creds
+ | SCH_CRED_SNI_ENABLE_OCSP // Enable OCSP
+ | SCH_CRED_NO_SYSTEM_MAPPER // Do not map certificate to user account
+ | SCH_CRED_DISABLE_RECONNECTS; // Do not support reconnects
+ } else {
+ supportedProtocols = SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_0_CLIENT | SP_PROT_TLS1_1_CLIENT |
+ SP_PROT_TLS1_2_CLIENT;
+
+ cred->dwFlags = cred->dwFlags // Flags
+ | SCH_CRED_REVOCATION_CHECK_CHAIN // Check certificate revocation
+ | SCH_CRED_NO_SERVERNAME_CHECK // Do not validate server name against cert
+ | SCH_CRED_NO_DEFAULT_CREDS // No Default Certificate
+ | SCH_CRED_MANUAL_CRED_VALIDATION; // Validate Certificate Manually
+ }
+
+ // Set the supported TLS protocols. Allow --sslDisabledProtocols to disable selected ciphers.
+ for (const SSLParams::Protocols& protocol : params.sslDisabledProtocols) {
+ if (protocol == SSLParams::Protocols::TLS1_0) {
+ supportedProtocols &= ~(SP_PROT_TLS1_0_CLIENT | SP_PROT_TLS1_0_SERVER);
+ } else if (protocol == SSLParams::Protocols::TLS1_1) {
+ supportedProtocols &= ~(SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_1_SERVER);
+ } else if (protocol == SSLParams::Protocols::TLS1_2) {
+ supportedProtocols &= ~(SP_PROT_TLS1_2_CLIENT | SP_PROT_TLS1_2_SERVER);
+ }
+ }
+
+ cred->grbitEnabledProtocols = supportedProtocols;
+
+ if (!params.sslCipherConfig.empty()) {
+ warning()
+ << "sslCipherConfig parameter is not supported with Windows SChannel and is ignored.";
+ }
+
+ if (direction == ConnectionDirection::kOutgoing) {
+ if (_clientCertificates[0]) {
+ cred->cCreds = 1;
+ cred->paCred = _clientCertificates;
+ }
+ } else {
+ cred->cCreds = 1;
+ cred->paCred = _serverCertificates;
+ }
+
return Status::OK();
}
SSLConnectionInterface* SSLManagerWindows::connect(Socket* socket) {
- return nullptr;
+ std::unique_ptr<SSLConnectionWindows> sslConn =
+ stdx::make_unique<SSLConnectionWindows>(&_clientCred, socket, nullptr, 0);
+
+ handshake(sslConn.get(), true);
+ return sslConn.release();
}
SSLConnectionInterface* SSLManagerWindows::accept(Socket* socket,
const char* initialBytes,
int len) {
- return nullptr;
+ std::unique_ptr<SSLConnectionWindows> sslConn =
+ stdx::make_unique<SSLConnectionWindows>(&_serverCred, socket, initialBytes, len);
+
+ handshake(sslConn.get(), false);
+
+ return sslConn.release();
+}
+
+void SSLManagerWindows::handshake(SSLConnectionWindows* conn, bool client) {
+ initSSLContext(conn->_cred,
+ getSSLGlobalParams(),
+ client ? SSLManagerInterface::ConnectionDirection::kOutgoing
+ : SSLManagerInterface::ConnectionDirection::kIncoming);
+
+ while (true) {
+ asio::error_code ec;
+ asio::ssl::detail::engine::want want =
+ conn->_engine.handshake(client ? asio::ssl::stream_base::handshake_type::client
+ : asio::ssl::stream_base::handshake_type::server,
+ ec);
+ if (ec) {
+ throwSocketError(SocketErrorKind::RECV_ERROR, ec.message());
+ }
+
+ switch (want) {
+ case asio::ssl::detail::engine::want_input_and_retry: {
+ // ASIO wants more data before it can continue,
+ // 1. fetch some from the network
+ // 2. give it to ASIO
+ // 3. retry
+ int ret = recv(conn->socket->rawFD(),
+ conn->_tempBuffer.data(),
+ conn->_tempBuffer.size(),
+ portRecvFlags);
+ if (ret == SOCKET_ERROR) {
+ conn->socket->handleRecvError(ret, conn->_tempBuffer.size());
+ }
+
+ conn->_engine.put_input(asio::const_buffer(conn->_tempBuffer.data(), ret));
+
+ continue;
+ }
+ case asio::ssl::detail::engine::want_output:
+ case asio::ssl::detail::engine::want_output_and_retry: {
+ // ASIO wants us to send data out
+ // 1. get data from ASIO
+ // 2. give it to the network
+ // 3. retry if needed
+ asio::mutable_buffer outBuf = conn->_engine.get_output(
+ asio::mutable_buffer(conn->_tempBuffer.data(), conn->_tempBuffer.size()));
+
+ int ret = send(conn->socket->rawFD(),
+ reinterpret_cast<const char*>(outBuf.data()),
+ outBuf.size(),
+ portSendFlags);
+ if (ret == SOCKET_ERROR) {
+ conn->socket->handleSendError(ret, "");
+ }
+
+ if (want == asio::ssl::detail::engine::want_output_and_retry) {
+ continue;
+ }
+
+ // ASIO wants nothing, return to caller since we are done with handshake
+ return;
+ }
+ case asio::ssl::detail::engine::want_nothing: {
+ // ASIO wants nothing, return to caller since we are done with handshake
+ return;
+ }
+ default:
+ MONGO_UNREACHABLE;
+ }
+ }
}
SSLPeerInfo SSLManagerWindows::parseAndValidatePeerCertificateDeprecated(
const SSLConnectionInterface* conn, const std::string& remoteHost) {
- return SSLPeerInfo();
+ auto swPeerSubjectName = parseAndValidatePeerCertificate(
+ const_cast<SSLConnectionWindows*>(static_cast<const SSLConnectionWindows*>(conn))
+ ->_engine.native_handle(),
+ remoteHost);
+ // We can't use uassertStatusOK here because we need to throw a SocketException.
+ if (!swPeerSubjectName.isOK()) {
+ throwSocketError(SocketErrorKind::CONNECT_ERROR, swPeerSubjectName.getStatus().reason());
+ }
+
+ return swPeerSubjectName.getValue().get_value_or(SSLPeerInfo());
}
StatusWith<boost::optional<SSLPeerInfo>> SSLManagerWindows::parseAndValidatePeerCertificate(
@@ -205,5 +808,4 @@ StatusWith<boost::optional<SSLPeerInfo>> SSLManagerWindows::parseAndValidatePeer
return {boost::none};
}
-
} // namespace mongo
diff --git a/src/mongo/util/net/ssl_stream.cpp b/src/mongo/util/net/ssl_stream.cpp
index b7022df9532..a99723d4a32 100644
--- a/src/mongo/util/net/ssl_stream.cpp
+++ b/src/mongo/util/net/ssl_stream.cpp
@@ -7,6 +7,8 @@
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
+#include "mongo/platform/basic.h"
+
#include "mongo/config.h"
#ifdef MONGO_CONFIG_SSL