diff options
author | Mark Benvenuto <mark.benvenuto@mongodb.com> | 2018-02-27 14:10:06 -0500 |
---|---|---|
committer | Mark Benvenuto <mark.benvenuto@mongodb.com> | 2018-02-27 14:10:06 -0500 |
commit | f2a8d9f2350f8cd5122cf3394b6783e85da5c390 (patch) | |
tree | 4782eb9f741bebb3954cd429e36b5b66f6f8d747 /src/mongo/util/net | |
parent | d4ae81a7154ab57a266b38d4fe41dd12a3c4540a (diff) | |
download | mongo-f2a8d9f2350f8cd5122cf3394b6783e85da5c390.tar.gz |
SERVER-22411 ASIO SChannel stream implementation
Diffstat (limited to 'src/mongo/util/net')
-rw-r--r-- | src/mongo/util/net/ssl/context_schannel.hpp | 19 | ||||
-rw-r--r-- | src/mongo/util/net/ssl/detail/engine_schannel.hpp | 181 | ||||
-rw-r--r-- | src/mongo/util/net/ssl/detail/impl/engine_schannel.ipp | 150 | ||||
-rw-r--r-- | src/mongo/util/net/ssl/detail/impl/schannel.ipp | 747 | ||||
-rw-r--r-- | src/mongo/util/net/ssl/detail/schannel.hpp | 600 | ||||
-rw-r--r-- | src/mongo/util/net/ssl/impl/context_schannel.ipp | 26 | ||||
-rw-r--r-- | src/mongo/util/net/ssl/impl/error.ipp | 5 | ||||
-rw-r--r-- | src/mongo/util/net/ssl/impl/src.hpp | 1 | ||||
-rw-r--r-- | src/mongo/util/net/ssl_manager_windows.cpp | 654 | ||||
-rw-r--r-- | src/mongo/util/net/ssl_stream.cpp | 2 |
10 files changed, 2228 insertions, 157 deletions
diff --git a/src/mongo/util/net/ssl/context_schannel.hpp b/src/mongo/util/net/ssl/context_schannel.hpp index 12293cbc8e9..043d8cfe4ca 100644 --- a/src/mongo/util/net/ssl/context_schannel.hpp +++ b/src/mongo/util/net/ssl/context_schannel.hpp @@ -30,20 +30,17 @@ #include "asio/detail/config.hpp" -#include <string> #include "asio/buffer.hpp" #include "asio/io_context.hpp" #include "mongo/util/net/ssl/context_base.hpp" +#include <string> #include "asio/detail/push_options.hpp" namespace asio { namespace ssl { -class context - : public context_base, - private noncopyable -{ +class context : public context_base, private noncopyable { public: /// The native handle type of the SSL context. typedef SCHANNEL_CRED* native_handle_type; @@ -77,7 +74,7 @@ public: * @li As a target for move-assignment. */ ASIO_DECL context& operator=(context&& other); -#endif // defined(ASIO_HAS_MOVE) || defined(GENERATING_DOCUMENTATION) +#endif // defined(ASIO_HAS_MOVE) || defined(GENERATING_DOCUMENTATION) /// Destructor. ASIO_DECL ~context(); @@ -91,6 +88,8 @@ public: ASIO_DECL native_handle_type native_handle(); private: + SCHANNEL_CRED _cred; + // The underlying native implementation. native_handle_type handle_; }; @@ -98,8 +97,8 @@ private: #include "asio/detail/pop_options.hpp" #if defined(ASIO_HEADER_ONLY) -# include "mongo/util/net/ssl/impl/context_schannel.ipp" -#endif // defined(ASIO_HEADER_ONLY) +#include "mongo/util/net/ssl/impl/context_schannel.ipp" +#endif // defined(ASIO_HEADER_ONLY) -} // namespace ssl -} // namespace asio +} // namespace ssl +} // namespace asio diff --git a/src/mongo/util/net/ssl/detail/engine_schannel.hpp b/src/mongo/util/net/ssl/detail/engine_schannel.hpp index fa3bad1a7af..f56b71ff8c0 100644 --- a/src/mongo/util/net/ssl/detail/engine_schannel.hpp +++ b/src/mongo/util/net/ssl/detail/engine_schannel.hpp @@ -32,6 +32,7 @@ #include "asio/buffer.hpp" #include "asio/detail/static_mutex.hpp" +#include "mongo/util/net/ssl/detail/schannel.hpp" #include "mongo/util/net/ssl/stream_base.hpp" #include "asio/detail/push_options.hpp" @@ -40,81 +41,125 @@ namespace asio { namespace ssl { namespace detail { -class engine -{ +class engine { public: - enum want - { - // Returned by functions to indicate that the engine wants input. The input - // buffer should be updated to point to the data. The engine then needs to - // be called again to retry the operation. - want_input_and_retry = -2, - - // Returned by functions to indicate that the engine wants to write output. - // The output buffer points to the data to be written. The engine then - // needs to be called again to retry the operation. - want_output_and_retry = -1, - - // Returned by functions to indicate that the engine doesn't need input or - // output. - want_nothing = 0, - - // Returned by functions to indicate that the engine wants to write output. - // The output buffer points to the data to be written. After that the - // operation is complete, and the engine does not need to be called again. - want_output = 1 - }; - - // Construct a new engine for the specified context. - ASIO_DECL explicit engine(SCHANNEL_CRED* context); - - // Destructor. - ASIO_DECL ~engine(); - - // Get the underlying implementation in the native type. - ASIO_DECL PCtxtHandle native_handle(); - - // Perform an SSL handshake using either SSL_connect (client-side) or - // SSL_accept (server-side). - ASIO_DECL want handshake( - stream_base::handshake_type type, asio::error_code& ec); - - // Perform a graceful shutdown of the SSL session. - ASIO_DECL want shutdown(asio::error_code& ec); - - // Write bytes to the SSL session. - ASIO_DECL want write(const asio::const_buffer& data, - asio::error_code& ec, std::size_t& bytes_transferred); - - // Read bytes from the SSL session. - ASIO_DECL want read(const asio::mutable_buffer& data, - asio::error_code& ec, std::size_t& bytes_transferred); - - // Get output data to be written to the transport. - ASIO_DECL asio::mutable_buffer get_output( - const asio::mutable_buffer& data); - - // Put input data that was read from the transport. - ASIO_DECL asio::const_buffer put_input( - const asio::const_buffer& data); - - // Map an error::eof code returned by the underlying transport according to - // the type and state of the SSL session. Returns a const reference to the - // error code object, suitable for passing to a completion handler. - ASIO_DECL const asio::error_code& map_error_code( - asio::error_code& ec) const; + enum want { + // Returned by functions to indicate that the engine wants input. The input + // buffer should be updated to point to the data. The engine then needs to + // be called again to retry the operation. + want_input_and_retry = -2, + + // Returned by functions to indicate that the engine wants to write output. + // The output buffer points to the data to be written. The engine then + // needs to be called again to retry the operation. + want_output_and_retry = -1, + + // Returned by functions to indicate that the engine doesn't need input or + // output. + want_nothing = 0, + + // Returned by functions to indicate that the engine wants to write output. + // The output buffer points to the data to be written. After that the + // operation is complete, and the engine does not need to be called again. + want_output = 1 + }; + + // Construct a new engine for the specified context. + ASIO_DECL explicit engine(SCHANNEL_CRED* context); + + // Destructor. + ASIO_DECL ~engine(); + + // Get the underlying implementation in the native type. + ASIO_DECL PCtxtHandle native_handle(); + + // Perform an SSL handshake using either SSL_connect (client-side) or + // SSL_accept (server-side). + ASIO_DECL want handshake(stream_base::handshake_type type, asio::error_code& ec); + + // Perform a graceful shutdown of the SSL session. + ASIO_DECL want shutdown(asio::error_code& ec); + + // Write bytes to the SSL session. + ASIO_DECL want write(const asio::const_buffer& data, + asio::error_code& ec, + std::size_t& bytes_transferred); + + // Read bytes from the SSL session. + ASIO_DECL want read(const asio::mutable_buffer& data, + asio::error_code& ec, + std::size_t& bytes_transferred); + + // Get output data to be written to the transport. + ASIO_DECL asio::mutable_buffer get_output(const asio::mutable_buffer& data); + + // Put input data that was read from the transport. + ASIO_DECL asio::const_buffer put_input(const asio::const_buffer& data); + + // Map an error::eof code returned by the underlying transport according to + // the type and state of the SSL session. Returns a const reference to the + // error code object, suitable for passing to a completion handler. + ASIO_DECL const asio::error_code& map_error_code(asio::error_code& ec) const; + + // MONGODB additions: + // Set the Server name for TLS SNI purposes. + ASIO_DECL void set_server_name(const std::wstring name); private: - // Disallow copying and assignment. - engine(const engine&); - engine& operator=(const engine&); + // Disallow copying and assignment. + engine(const engine&); + engine& operator=(const engine&); private: + // SChannel context handle + CtxtHandle _hcxt; + + // Credential handle + CredHandle _hcred; + + // Credentials for TLS handshake + SCHANNEL_CRED* _pCred; + + // TLS SNI server name + std::wstring _serverName; + + // Engine State machine + // + enum class EngineState { + // Initial State + NeedsHandshake, + + // Normal SSL Conversation in progress + InProgress, + + // In SSL shutdown + InShutdown, + }; + + // Engine state + EngineState _state{EngineState::NeedsHandshake}; + + // Data received from remote side, shared across state machines + ReusableBuffer _inBuffer; + + // Data to send to remote side, shared across state machines + ReusableBuffer _outBuffer; + + // Extra buffer - for when more then one packet is read from the remote side + ReusableBuffer _extraBuffer; + + // Handshake state machine + SSLHandshakeManager _handshakeManager; + + // Read state machine + SSLReadManager _readManager; + // Write state machine + SSLWriteManager _writeManager; }; #include "asio/detail/pop_options.hpp" -} // namespace detail -} // namespace ssl -} // namespace asio +} // namespace detail +} // namespace ssl +} // namespace asio diff --git a/src/mongo/util/net/ssl/detail/impl/engine_schannel.ipp b/src/mongo/util/net/ssl/detail/impl/engine_schannel.ipp index 41aca7f789f..a2d5d8c21f5 100644 --- a/src/mongo/util/net/ssl/detail/impl/engine_schannel.ipp +++ b/src/mongo/util/net/ssl/detail/impl/engine_schannel.ipp @@ -41,62 +41,144 @@ namespace asio { namespace ssl { namespace detail { + engine::engine(SCHANNEL_CRED* context) -{ + : _pCred(context), + _hcxt{0, 0}, + _hcred{0, 0}, + _inBuffer(kDefaultBufferSize), + _outBuffer(kDefaultBufferSize), + _extraBuffer(kDefaultBufferSize), + _handshakeManager( + &_hcxt, &_hcred, _serverName, &_inBuffer, &_outBuffer, &_extraBuffer, _pCred), + _readManager(&_hcxt, &_hcred, &_inBuffer, &_extraBuffer), + _writeManager(&_hcxt, &_outBuffer) {} + +engine::~engine() { + DeleteSecurityContext(&_hcxt); + FreeCredentialsHandle(&_hcred); } -engine::~engine() -{ +PCtxtHandle engine::native_handle() { + return &_hcxt; } -PCtxtHandle engine::native_handle() -{ - return nullptr; +engine::want ssl_want_to_engine(ssl_want want) { + static_assert(static_cast<int>(ssl_want::want_input_and_retry) == + static_cast<int>(engine::want_input_and_retry), + "bad"); + static_assert(static_cast<int>(ssl_want::want_output_and_retry) == + static_cast<int>(engine::want_output_and_retry), + "bad"); + static_assert( + static_cast<int>(ssl_want::want_nothing) == static_cast<int>(engine::want_nothing), "bad"); + static_assert(static_cast<int>(ssl_want::want_output) == static_cast<int>(engine::want_output), + "bad"); + + return static_cast<engine::want>(want); } -engine::want engine::handshake( - stream_base::handshake_type type, asio::error_code& ec) -{ - return want::want_nothing; +engine::want engine::handshake(stream_base::handshake_type type, asio::error_code& ec) { + // ASIO will call handshake once more after we send out the last data + // so we need to tell them we are done with data to send. + if (_state != EngineState::NeedsHandshake) { + return want::want_nothing; + } + + _handshakeManager.setMode((type == asio::ssl::stream_base::client) + ? SSLHandshakeManager::HandshakeMode::Client + : SSLHandshakeManager::HandshakeMode::Server); + SSLHandshakeManager::HandshakeState state; + auto w = _handshakeManager.nextHandshake(ec, &state); + if (w == ssl_want::want_nothing || state == SSLHandshakeManager::HandshakeState::Done) { + _state = EngineState::InProgress; + } + + return ssl_want_to_engine(w); } -engine::want engine::shutdown(asio::error_code& ec) -{ - return want::want_nothing; +engine::want engine::shutdown(asio::error_code& ec) { + return ssl_want_to_engine(_handshakeManager.beginShutdown(ec)); } engine::want engine::write(const asio::const_buffer& data, - asio::error_code& ec, std::size_t& bytes_transferred) -{ - return want::want_nothing; + asio::error_code& ec, + std::size_t& bytes_transferred) { + if (data.size() == 0) { + ec = asio::error_code(); + return engine::want_nothing; + } + + if (_state == EngineState::NeedsHandshake || _state == EngineState::InShutdown) { + // Why are we trying to write before the handshake is done? + ASIO_ASSERT(false); + return want::want_nothing; + } else { + return ssl_want_to_engine( + _writeManager.writeUnencryptedData(data.data(), data.size(), bytes_transferred, ec)); + } } engine::want engine::read(const asio::mutable_buffer& data, - asio::error_code& ec, std::size_t& bytes_transferred) -{ - return want::want_nothing; + asio::error_code& ec, + std::size_t& bytes_transferred) { + if (data.size() == 0) { + ec = asio::error_code(); + return engine::want_nothing; + } + + + if (_state == EngineState::NeedsHandshake) { + // Why are we trying to read before the handshake is done? + ASIO_ASSERT(false); + return want::want_nothing; + } else { + SSLReadManager::DecryptState decryptState; + auto want = ssl_want_to_engine(_readManager.readDecryptedData( + data.data(), data.size(), ec, bytes_transferred, &decryptState)); + if (ec) { + return want; + } + + if (decryptState != SSLReadManager::DecryptState::Continue) { + if (decryptState == SSLReadManager::DecryptState::Shutdown) { + _state = EngineState::InShutdown; + + return ssl_want_to_engine(_handshakeManager.beginShutdown(ec)); + } + } + + return want; + } } -asio::mutable_buffer engine::get_output( - const asio::mutable_buffer& data) -{ - return asio::mutable_buffer(nullptr, 0); +asio::mutable_buffer engine::get_output(const asio::mutable_buffer& data) { + std::size_t length; + _outBuffer.readInto(data.data(), data.size(), length); + + return asio::buffer(data, length); +} + +asio::const_buffer engine::put_input(const asio::const_buffer& data) { + if (_state == EngineState::NeedsHandshake) { + _handshakeManager.writeEncryptedData(data.data(), data.size()); + } else { + _readManager.writeData(data.data(), data.size()); + } + + return asio::buffer(data + data.size()); } -asio::const_buffer engine::put_input( - const asio::const_buffer& data) -{ - return asio::const_buffer(nullptr, 0); +void engine::set_server_name(const std::wstring name) { + _serverName = name; } -const asio::error_code& engine::map_error_code( - asio::error_code& ec) const -{ - return ec; +const asio::error_code& engine::map_error_code(asio::error_code& ec) const { + return ec; } #include "asio/detail/pop_options.hpp" -} // namespace detail -} // namespace ssl -} // namespace asio +} // namespace detail +} // namespace ssl +} // namespace asio diff --git a/src/mongo/util/net/ssl/detail/impl/schannel.ipp b/src/mongo/util/net/ssl/detail/impl/schannel.ipp new file mode 100644 index 00000000000..14da4977b6d --- /dev/null +++ b/src/mongo/util/net/ssl/detail/impl/schannel.ipp @@ -0,0 +1,747 @@ +/** + * Copyright (C) 2018 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects + * for all of the code used other than as permitted herein. If you modify + * file(s) with this exception, you may extend this exception to your + * version of the file(s), but you are not obligated to do so. If you do not + * wish to do so, delete this exception statement from your version. If you + * delete this exception statement from all source files in the program, + * then also delete it in the license file. + */ + +#pragma once + +#include <algorithm> +#include <cstddef> +#include <memory> + +#include "asio/detail/assert.hpp" + +namespace asio { +namespace ssl { +namespace detail { + +/** + * Start or continue SSL handshake. + * + * Must be called until HandshakeState::Done is returned. + * + * Return status code to indicate whether it needs more data or if data needs to be sent to the + * other side. + */ +ssl_want SSLHandshakeManager::nextHandshake(asio::error_code& ec, HandshakeState* pHandshakeState) { + ASIO_ASSERT(_mode != HandshakeMode::Unknown); + ec = asio::error_code(); + *pHandshakeState = HandshakeState::Continue; + + if (_state == State::HandshakeStart) { + ssl_want want; + + if (_mode == HandshakeMode::Server) { + // ASIO will ask for the handshake to start when the input buffer is empty + // but we want data first so tell ASIO to give us data + if (_pInBuffer->empty()) { + return ssl_want::want_input_and_retry; + } + + startServerHandshake(ec); + if (ec) { + return ssl_want::want_nothing; + } + + want = doServerHandshake(true, ec, pHandshakeState); + if (ec) { + return want; + } + + } else { + startClientHandshake(ec); + if (ec) { + return ssl_want::want_nothing; + } + + want = doClientHandshake(ec); + if (ec) { + return want; + } + } + + setState(State::NeedMoreHandshakeData); + + return want; + } else if (_state == State::NeedMoreHandshakeData) { + return ssl_want::want_input_and_retry; + } else { + ssl_want want; + + if (_mode == HandshakeMode::Server) { + want = doServerHandshake(false, ec, pHandshakeState); + } else { + want = doClientHandshake(ec); + } + + if (ec) { + return want; + } + + if (want == ssl_want::want_nothing || *pHandshakeState == HandshakeState::Done) { + setState(State::Done); + } else { + setState(State::NeedMoreHandshakeData); + } + + return want; + } +} + +/** + * Begin graceful SSL shutdown. Either: + * - respond to already received alert signalling connection shutdown on remote side + * - start SSL shutdown by signalling remote side + */ +ssl_want SSLHandshakeManager::beginShutdown(asio::error_code& ec) { + ASIO_ASSERT(_mode != HandshakeMode::Unknown); + _state = State::HandshakeStart; + + return startShutdown(ec); +} + +/* + * Injest data from ASIO that has been received. + */ +void SSLHandshakeManager::writeEncryptedData(const void* data, std::size_t length) { + // We have more data, it may not be enough to decode. We will decide if we have enough on + // the next nextHandshake call. + if (_state != State::HandshakeStart) { + setState(State::HaveEncryptedData); + } + + _pInBuffer->append(data, length); +} + + +void SSLHandshakeManager::startServerHandshake(asio::error_code& ec) { + TimeStamp lifetime; + SECURITY_STATUS ss = AcquireCredentialsHandleW(NULL, + const_cast<LPWSTR>(UNISP_NAME), + SECPKG_CRED_INBOUND, + NULL, + _cred, + NULL, + NULL, + _phcred, + &lifetime); + if (ss != SEC_E_OK) { + ec = asio::error_code(ss, asio::error::get_ssl_category()); + return; + } +} + +void SSLHandshakeManager::startClientHandshake(asio::error_code& ec) { + TimeStamp lifetime; + SECURITY_STATUS ss = AcquireCredentialsHandleW(NULL, + const_cast<LPWSTR>(UNISP_NAME), + SECPKG_CRED_OUTBOUND, + NULL, + _cred, + NULL, + NULL, + _phcred, + &lifetime); + + if (ss != SEC_E_OK) { + ec = asio::error_code(ss, asio::error::get_ssl_category()); + return; + } +} + +ssl_want SSLHandshakeManager::startShutdown(asio::error_code& ec) { + DWORD shutdownCode = SCHANNEL_SHUTDOWN; + + std::array<SecBuffer, 1> inputBuffers; + inputBuffers[0].cbBuffer = sizeof(shutdownCode); + inputBuffers[0].BufferType = SECBUFFER_TOKEN; + inputBuffers[0].pvBuffer = &shutdownCode; + + SecBufferDesc inputBufferDesc; + inputBufferDesc.ulVersion = SECBUFFER_VERSION; + inputBufferDesc.cBuffers = inputBuffers.size(); + inputBufferDesc.pBuffers = inputBuffers.data(); + + SECURITY_STATUS ss = ApplyControlToken(_phctxt, &inputBufferDesc); + + if (ss != SEC_E_OK) { + ec = asio::error_code(ss, asio::error::get_ssl_category()); + return ssl_want::want_nothing; + } + + TimeStamp lifetime; + + std::array<SecBuffer, 1> outputBuffers; + outputBuffers[0].cbBuffer = 0; + outputBuffers[0].BufferType = SECBUFFER_TOKEN; + outputBuffers[0].pvBuffer = NULL; + ContextBufferDeleter deleter(&outputBuffers[0].pvBuffer); + + SecBufferDesc outputBufferDesc; + outputBufferDesc.ulVersion = SECBUFFER_VERSION; + outputBufferDesc.cBuffers = outputBuffers.size(); + outputBufferDesc.pBuffers = outputBuffers.data(); + + if (_mode == HandshakeMode::Server) { + ULONG attribs = getServerFlags() | ASC_REQ_ALLOCATE_MEMORY; + + SECURITY_STATUS ss = AcceptSecurityContext( + _phcred, _phctxt, NULL, attribs, 0, _phctxt, &outputBufferDesc, &attribs, &lifetime); + + if (ss != SEC_E_OK) { + ec = asio::error_code(ss, asio::error::get_ssl_category()); + return ssl_want::want_nothing; + } + + _pOutBuffer->reset(); + _pOutBuffer->append(outputBuffers[0].pvBuffer, outputBuffers[0].cbBuffer); + + if (SEC_E_OK == ss && outputBuffers[0].cbBuffer != 0) { + ec = asio::error::eof; + return ssl_want::want_output; + } else { + return ssl_want::want_nothing; + } + } else { + ULONG ContextAttributes; + DWORD sspiFlags = getClientFlags() | ISC_REQ_ALLOCATE_MEMORY; + + ss = InitializeSecurityContextW(_phcred, + _phctxt, + const_cast<SEC_WCHAR*>(_serverName.c_str()), + sspiFlags, + 0, + 0, + NULL, + 0, + _phctxt, + &outputBufferDesc, + &ContextAttributes, + &lifetime); + + if (ss != SEC_E_OK) { + ec = asio::error_code(ss, asio::error::get_ssl_category()); + return ssl_want::want_nothing; + } + + // TODO - I have not found a way to hit this code path + ASIO_ASSERT(false); + } + + return ssl_want::want_nothing; +} + +ssl_want SSLHandshakeManager::doServerHandshake(bool newConversation, + asio::error_code& ec, + HandshakeState* pHandshakeState) { + TimeStamp lifetime; + + _pOutBuffer->resize(kDefaultBufferSize); + _alertBuffer.resize(1024); + + std::array<SecBuffer, 2> outputBuffers; + outputBuffers[0].cbBuffer = _pOutBuffer->size(); + outputBuffers[0].BufferType = SECBUFFER_TOKEN; + outputBuffers[0].pvBuffer = _pOutBuffer->data(); + + outputBuffers[1].cbBuffer = _alertBuffer.size(); + outputBuffers[1].BufferType = SECBUFFER_ALERT; + outputBuffers[1].pvBuffer = _alertBuffer.data(); + + SecBufferDesc outputBufferDesc; + outputBufferDesc.ulVersion = SECBUFFER_VERSION; + outputBufferDesc.cBuffers = outputBuffers.size(); + outputBufferDesc.pBuffers = outputBuffers.data(); + + std::array<SecBuffer, 2> inputBuffers; + inputBuffers[0].cbBuffer = _pInBuffer->size(); + inputBuffers[0].BufferType = SECBUFFER_TOKEN; + inputBuffers[0].pvBuffer = _pInBuffer->data(); + + inputBuffers[1].cbBuffer = 0; + inputBuffers[1].BufferType = SECBUFFER_EMPTY; + inputBuffers[1].pvBuffer = NULL; + + SecBufferDesc inputBufferDesc; + inputBufferDesc.ulVersion = SECBUFFER_VERSION; + inputBufferDesc.cBuffers = inputBuffers.size(); + inputBufferDesc.pBuffers = inputBuffers.data(); + + ULONG attribs = getServerFlags(); + ULONG retAttribs = 0; + + SECURITY_STATUS ss = AcceptSecurityContext(_phcred, + newConversation ? NULL : _phctxt, + &inputBufferDesc, + attribs, + 0, + _phctxt, + &outputBufferDesc, + &retAttribs, + &lifetime); + + if (ss < SEC_E_OK) { + if (ss == SEC_E_INCOMPLETE_MESSAGE) { + // TODO: consider using SECBUFFER_MISSING and approriate optimizations + return ssl_want::want_input_and_retry; + } + + ec = asio::error_code(ss, asio::error::get_ssl_category()); + + if ((retAttribs & ASC_RET_EXTENDED_ERROR) && (outputBuffers[1].cbBuffer > 0)) { + _pOutBuffer->resize(outputBuffers[0].cbBuffer); + + // Tell ASIO we have something to send back the last data + return ssl_want::want_output; + } + + return ssl_want::want_nothing; + } + invariant(attribs == retAttribs); + + if (inputBuffers[1].BufferType == SECBUFFER_EXTRA) { + _pExtraEncryptedBuffer->reset(); + _pExtraEncryptedBuffer->append(inputBuffers[1].pvBuffer, inputBuffers[1].cbBuffer); + } + + + // Next, figure out if we need to send any data out + bool needOutput{false}; + + // Did AcceptSecurityContext say we need to continue or is it done but left data in the + // output buffer then we need to sent the data out. + if (SEC_I_CONTINUE_NEEDED == ss || SEC_I_COMPLETE_AND_CONTINUE == ss || + (SEC_E_OK == ss && outputBuffers[0].cbBuffer != 0)) { + needOutput = true; + } + + // Tell the reusable buffer size of the data written. + _pOutBuffer->resize(outputBuffers[0].cbBuffer); + + // Reset the input buffer + _pInBuffer->reset(); + + // Check if we have any additional encrypted data + if (!_pExtraEncryptedBuffer->empty()) { + _pInBuffer->swap(*_pExtraEncryptedBuffer); + _pExtraEncryptedBuffer->reset(); + + setState(State::HaveEncryptedData); + } + + if (needOutput) { + // If AcceptSecurityContext returns SEC_E_OK, then the handshake is done + if (SEC_E_OK == ss && outputBuffers[0].cbBuffer != 0) { + *pHandshakeState = HandshakeState::Done; + + // We have output, but no need to retry anymore + return ssl_want::want_output; + } + + return ssl_want::want_output_and_retry; + } + + return ssl_want::want_nothing; +} + +ssl_want SSLHandshakeManager::doClientHandshake(asio::error_code& ec) { + DWORD sspiFlags = getClientFlags() | ISC_REQ_ALLOCATE_MEMORY; + + std::array<SecBuffer, 3> outputBuffers; + + outputBuffers[0].cbBuffer = 0; + outputBuffers[0].BufferType = SECBUFFER_TOKEN; + outputBuffers[0].pvBuffer = NULL; + ContextBufferDeleter deleter(&outputBuffers[0].pvBuffer); + + outputBuffers[1].cbBuffer = 0; + outputBuffers[1].BufferType = SECBUFFER_ALERT; + outputBuffers[1].pvBuffer = NULL; + ContextBufferDeleter alertDeleter(&outputBuffers[1].pvBuffer); + + outputBuffers[2].cbBuffer = 0; + outputBuffers[2].BufferType = SECBUFFER_EMPTY; + outputBuffers[2].pvBuffer = NULL; + + SecBufferDesc outputBufferDesc; + outputBufferDesc.ulVersion = SECBUFFER_VERSION; + outputBufferDesc.cBuffers = outputBuffers.size(); + outputBufferDesc.pBuffers = outputBuffers.data(); + + std::array<SecBuffer, 2> inputBuffers; + + SECURITY_STATUS ss; + TimeStamp lifetime; + ULONG retAttribs; + + // If the input buffer is empty, this is the start of the client handshake. + if (!_pInBuffer->empty()) { + inputBuffers[0].cbBuffer = _pInBuffer->size(); + inputBuffers[0].BufferType = SECBUFFER_TOKEN; + inputBuffers[0].pvBuffer = _pInBuffer->data(); + + inputBuffers[1].cbBuffer = 0; + inputBuffers[1].BufferType = SECBUFFER_EMPTY; + inputBuffers[1].pvBuffer = NULL; + + SecBufferDesc inputBufferDesc; + inputBufferDesc.ulVersion = SECBUFFER_VERSION; + inputBufferDesc.cBuffers = inputBuffers.size(); + inputBufferDesc.pBuffers = inputBuffers.data(); + + ss = InitializeSecurityContextW(_phcred, + _phctxt, + const_cast<SEC_WCHAR*>(_serverName.c_str()), + sspiFlags, + 0, + 0, + &inputBufferDesc, + 0, + _phctxt, + &outputBufferDesc, + &retAttribs, + &lifetime); + } else { + ss = InitializeSecurityContextW(_phcred, + NULL, + const_cast<SEC_WCHAR*>(_serverName.c_str()), + sspiFlags, + 0, + 0, + NULL, + 0, + _phctxt, + &outputBufferDesc, + &retAttribs, + &lifetime); + } + + if (ss < SEC_E_OK) { + if (ss == SEC_E_INCOMPLETE_MESSAGE) { + return ssl_want::want_input_and_retry; + } + + ec = asio::error_code(ss, asio::error::get_ssl_category()); + + if ((retAttribs & ISC_RET_EXTENDED_ERROR) && (outputBuffers[1].cbBuffer > 0)) { + _pOutBuffer->reset(); + _pOutBuffer->append(outputBuffers[0].pvBuffer, outputBuffers[0].cbBuffer); + + // Tell ASIO we have something to send back the last data + return ssl_want::want_output; + } + + return ssl_want::want_nothing; + } + invariant(sspiFlags == retAttribs); + + if (_pInBuffer->size()) { + // Locate (optional) extra buffer + if (inputBuffers[1].BufferType == SECBUFFER_EXTRA) { + _pExtraEncryptedBuffer->reset(); + _pExtraEncryptedBuffer->append(inputBuffers[1].pvBuffer, inputBuffers[1].cbBuffer); + } + } + + // Next, figure out if we need to send any data out + bool needOutput{false}; + + // Did AcceptSecurityContext say we need to continue or is it done but left data in the + // output buffer then we need to sent the data out. + if (SEC_I_CONTINUE_NEEDED == ss || SEC_I_COMPLETE_AND_CONTINUE == ss || + (SEC_E_OK == ss && outputBuffers[0].cbBuffer != 0)) { + needOutput = true; + } + + _pOutBuffer->reset(); + _pOutBuffer->append(outputBuffers[0].pvBuffer, outputBuffers[0].cbBuffer); + + // Reset the input buffer + _pInBuffer->reset(); + + // Check if we have any additional encrypted data + if (!_pExtraEncryptedBuffer->empty()) { + _pInBuffer->swap(*_pExtraEncryptedBuffer); + _pExtraEncryptedBuffer->reset(); + + setState(State::HaveEncryptedData); + } + + if (needOutput) { + return ssl_want::want_output_and_retry; + } + + return ssl_want::want_nothing; +} + +/** + * Read decrypted data if encrypted data was provided via writeData and succesfully decrypted. + */ +ssl_want SSLReadManager::readDecryptedData(void* data, + std::size_t length, + asio::error_code& ec, + std::size_t& bytes_transferred, + DecryptState* pDecryptState) { + bytes_transferred = 0; + ec = asio::error_code(); + *pDecryptState = DecryptState::Continue; + + // Our last state was that we needed more encrypted data, so tell ASIO we still want some + if (_state == State::NeedMoreEncryptedData) { + return ssl_want::want_input_and_retry; + } + + // If we have encrypted data, try to decrypt it + if (_state == State::HaveEncryptedData) { + ssl_want wantState = decryptBuffer(ec, pDecryptState); + if (ec) { + return wantState; + } + + // If remote side started shutdown, bail + if (*pDecryptState != DecryptState::Continue) { + return ssl_want::want_nothing; + } + + if (wantState == ssl_want::want_input_and_retry) { + setState(State::NeedMoreEncryptedData); + } + + if (wantState != ssl_want::want_nothing) { + return wantState; + } + } + + // We decrypted data in the past, hand it back to ASIO until we are out of decrypted data + ASIO_ASSERT(_state == State::HaveDecryptedData); + + _pInBuffer->readInto(data, length, bytes_transferred); + + // Have we read all the decrypted data? + if (_pInBuffer->empty()) { + // If we have some extra encrypted data, it needs to be checked if it is at least a + // valid SSL packet, so set the state machine to reflect that we have some encrypted + // data. + if (!_pExtraEncryptedBuffer->empty()) { + _pInBuffer->swap(*_pExtraEncryptedBuffer); + _pExtraEncryptedBuffer->reset(); + setState(State::HaveEncryptedData); + } else { + // We are empty so reset our state to need encrypted data for the next call + setState(State::NeedMoreEncryptedData); + } + } + + return ssl_want::want_nothing; +} + +ssl_want SSLReadManager::decryptBuffer(asio::error_code& ec, DecryptState* pDecryptState) { + std::array<SecBuffer, 4> securityBuffers; + securityBuffers[0].cbBuffer = _pInBuffer->size(); + securityBuffers[0].BufferType = SECBUFFER_DATA; + securityBuffers[0].pvBuffer = _pInBuffer->data(); + + securityBuffers[1].cbBuffer = 0; + securityBuffers[1].BufferType = SECBUFFER_EMPTY; + securityBuffers[1].pvBuffer = NULL; + + securityBuffers[2].cbBuffer = 0; + securityBuffers[2].BufferType = SECBUFFER_EMPTY; + securityBuffers[2].pvBuffer = NULL; + + securityBuffers[3].cbBuffer = 0; + securityBuffers[3].BufferType = SECBUFFER_EMPTY; + securityBuffers[3].pvBuffer = NULL; + + SecBufferDesc bufferDesc; + bufferDesc.ulVersion = SECBUFFER_VERSION; + bufferDesc.cBuffers = securityBuffers.size(); + bufferDesc.pBuffers = securityBuffers.data(); + + SECURITY_STATUS ss = DecryptMessage(_phctxt, &bufferDesc, 0, NULL); + + if (ss < SEC_E_OK) { + if (ss == SEC_E_INCOMPLETE_MESSAGE) { + return ssl_want::want_input_and_retry; + } else { + ec = asio::error_code(ss, asio::error::get_ssl_category()); + return ssl_want::want_nothing; + } + } + + // Shutdown has been initiated at the client side + if (ss == SEC_I_CONTEXT_EXPIRED) { + *pDecryptState = DecryptState::Shutdown; + } else if (ss == SEC_I_RENEGOTIATE) { + *pDecryptState = DecryptState::Renegotiate; + + // Fail the connection on SSL renegotiations + ec = asio::ssl::error::stream_truncated; + return ssl_want::want_nothing; + } + + if (securityBuffers[1].cbBuffer > 0) { + _pInBuffer->resetPos(securityBuffers[1].pvBuffer, securityBuffers[1].cbBuffer); + } + + // The network layer may have read more then 1 SSL packet so remember the extra data. + if (securityBuffers[3].BufferType == SECBUFFER_EXTRA && securityBuffers[3].cbBuffer > 0) { + ASIO_ASSERT(_pExtraEncryptedBuffer->empty()); + _pExtraEncryptedBuffer->append(securityBuffers[3].pvBuffer, securityBuffers[3].cbBuffer); + } + + setState(State::HaveDecryptedData); + + return ssl_want::want_nothing; +} + + +/** + * Encrypts data to be sent to the remote side. + * + * If the message is >= max packet size, it will return want_output_and_retry, and expects + * callers to continue to call it with the same parameters until want_output is returned. + */ +ssl_want SSLWriteManager::writeUnencryptedData(const void* pMessage, + std::size_t messageLength, + std::size_t& bytes_transferred, + asio::error_code& ec) { + ec = asio::error_code(); + + if (_securityTrailerLength == ULONG_MAX) { + SecPkgContext_StreamSizes secPkgContextStreamSizes; + + SECURITY_STATUS ss = + QueryContextAttributes(_phctxt, SECPKG_ATTR_STREAM_SIZES, &secPkgContextStreamSizes); + + if (ss < SEC_E_OK) { + ec = asio::error_code(ss, asio::error::get_ssl_category()); + return ssl_want::want_nothing; + } + + _securityTrailerLength = secPkgContextStreamSizes.cbTrailer; + _securityMaxMessageLength = secPkgContextStreamSizes.cbMaximumMessage; + _securityHeaderLength = secPkgContextStreamSizes.cbHeader; + } + + // Do we need to fragment the message out? + if (messageLength > _securityMaxMessageLength) { + // Since the message is too large for SSL, we have to write out fragments. We rely on + // the fact that ASIO will keep giving us the same buffer back as long as it is asked to + // retry. + std::size_t fragmentLength = + std::min(_securityMaxMessageLength, messageLength - _lastWriteOffset); + ssl_want want = encryptMessage(reinterpret_cast<const char*>(pMessage) + _lastWriteOffset, + fragmentLength, + bytes_transferred, + ec); + if (ec) { + return want; + } + + _lastWriteOffset += fragmentLength; + + // We have more data to give ASIO after this fragment + if (_lastWriteOffset < messageLength) { + return ssl_want::want_output_and_retry; + } + + // We have consumed all the data given to us over multiple consecutive calls, reset + // position. + _lastWriteOffset = 0; + + // ASIO's buffering of engine calls assumes that bytes_transfered refers to all the + // bytes we transfered total when want_output is returned. It ignores bytes_transfered + // when want_output_and_retry is returned; + bytes_transferred = messageLength; + + return ssl_want::want_output; + } else { + // Reset fragmentation position + _lastWriteOffset = 0; + + // Send message as is without fragmentation + return encryptMessage(pMessage, messageLength, bytes_transferred, ec); + } +} + +ssl_want SSLWriteManager::encryptMessage(const void* pMessage, + std::size_t messageLength, + std::size_t& bytes_transferred, + asio::error_code& ec) { + ASIO_ASSERT(_pOutBuffer->empty()); + _pOutBuffer->resize(_securityTrailerLength + _securityHeaderLength + messageLength); + + std::array<SecBuffer, 4> securityBuffers; + + securityBuffers[0].BufferType = SECBUFFER_STREAM_HEADER; + securityBuffers[0].cbBuffer = _securityHeaderLength; + securityBuffers[0].pvBuffer = _pOutBuffer->data(); + + memcpy_s(_pOutBuffer->data() + _securityHeaderLength, + _pOutBuffer->size() - _securityHeaderLength - _securityTrailerLength, + pMessage, + messageLength); + + securityBuffers[1].BufferType = SECBUFFER_DATA; + securityBuffers[1].cbBuffer = messageLength; + securityBuffers[1].pvBuffer = _pOutBuffer->data() + _securityHeaderLength; + + securityBuffers[2].cbBuffer = _securityTrailerLength; + securityBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER; + securityBuffers[2].pvBuffer = _pOutBuffer->data() + _securityHeaderLength + messageLength; + + securityBuffers[3].cbBuffer = 0; + securityBuffers[3].BufferType = SECBUFFER_EMPTY; + securityBuffers[3].pvBuffer = 0; + + SecBufferDesc bufferDesc; + + bufferDesc.ulVersion = SECBUFFER_VERSION; + bufferDesc.cBuffers = securityBuffers.size(); + bufferDesc.pBuffers = securityBuffers.data(); + + SECURITY_STATUS ss = EncryptMessage(_phctxt, 0, &bufferDesc, 0); + + if (ss < SEC_E_OK) { + ec = asio::error_code(ss, asio::error::get_ssl_category()); + return ssl_want::want_nothing; + } + + size_t size = + securityBuffers[0].cbBuffer + securityBuffers[1].cbBuffer + securityBuffers[2].cbBuffer; + + _pOutBuffer->resize(size); + + // Tell asio that all the clear text was transfered. + bytes_transferred = messageLength; + + return ssl_want::want_output; +} + +} // namespace detail +} // namespace ssl +} // namespace asio diff --git a/src/mongo/util/net/ssl/detail/schannel.hpp b/src/mongo/util/net/ssl/detail/schannel.hpp new file mode 100644 index 00000000000..4cf00271f0e --- /dev/null +++ b/src/mongo/util/net/ssl/detail/schannel.hpp @@ -0,0 +1,600 @@ +/** + * Copyright (C) 2018 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects + * for all of the code used other than as permitted herein. If you modify + * file(s) with this exception, you may extend this exception to your + * version of the file(s), but you are not obligated to do so. If you do not + * wish to do so, delete this exception statement from your version. If you + * delete this exception statement from all source files in the program, + * then also delete it in the license file. + */ + +#pragma once + +#include "asio/detail/config.hpp" + +#include <algorithm> +#include <cstddef> +#include <memory> + +#include "asio/detail/push_options.hpp" + + +#include "asio/detail/assert.hpp" + +#define ASSERT_STATE_TRANSITION(orig, dest) ASIO_ASSERT(!(orig) || (dest)); + +namespace asio { +namespace ssl { +namespace detail { + +/** + * Reusable buffer. Behaves as a sort of producer consumer queue in a sense. + * + * Data is added to the buffer then removed. + * + * Typical workflow: + * - Write data + * - Write more data + * - Read some data + * - Keeping reading until empty + * + * Invariants: + * - Once reading from a buffer is started, no more writes are permitted until + * consumer has read all the entire buffer. + */ +class ReusableBuffer { +public: + ReusableBuffer(std::size_t initialSize) { + _buffer = std::make_unique<std::uint8_t[]>(initialSize); + _capacity = initialSize; + } + + /** + * Is buffer empty? + */ + bool empty() const { + return _size == 0; + } + + /** + * Get raw pointer to buffer. + */ + std::uint8_t* data() { + return _buffer.get(); + } + + /** + * Get current number of elements in buffer. + */ + std::size_t size() const { + return _size; + } + + /** + * Reset to empty state. + */ + void reset() { + _bufPos = 0; + _size = 0; + } + + /** + * Add data to empty buffer. + */ + void fill(const std::vector<std::uint8_t>& vec) { + ASIO_ASSERT(_size == 0); + ASIO_ASSERT(_bufPos == 0); + append(vec.data(), vec.size()); + } + + /** + * Reset current position to specified pointer in buffer. + */ + void resetPos(void* pos, std::size_t size) { + ASIO_ASSERT(pos >= _buffer.get() && pos < (_buffer.get() + _size)); + _bufPos = (std::uint8_t*)pos - _buffer.get(); + resize(_bufPos + size); + } + + /** + * Append data to buffer. + */ + void append(const void* data, std::size_t length) { + ASIO_ASSERT(_bufPos == 0); + auto originalSize = _size; + resize(_size + length); + std::copy(reinterpret_cast<const std::uint8_t*>(data), + reinterpret_cast<const std::uint8_t*>(data) + length, + _buffer.get() + originalSize); + } + + /** + * Read data from buffer. Can be a partial read. + */ + void readInto(void* data, std::size_t length, std::size_t& outLength) { + if (length >= (size() - _bufPos)) { + // We have less then ASIO wants, give them everything we have + outLength = size() - _bufPos; + memcpy_s(data, length, _buffer.get() + _bufPos, size() - _bufPos); + + // We are empty so reset our state to need encrypted data for the next call + _bufPos = 0; + _size = 0; + } else { + // ASIO wants less then we have so give them just what they want + outLength = length; + memcpy_s(data, length, _buffer.get() + _bufPos, length); + + _bufPos += length; + } + } + + /** + * Realloc buffer preserving existing data. + */ + void resize(std::size_t size) { + if (size > _capacity) { + auto temp = std::make_unique<unsigned char[]>(size); + + memcpy_s(temp.get(), size, _buffer.get(), _size); + _buffer.swap(temp); + _capacity = size; + } + _size = size; + } + + /** + * Swap current buffer with other buffer. + */ + void swap(ReusableBuffer& other) { + std::swap(_buffer, other._buffer); + std::swap(_bufPos, other._bufPos); + std::swap(_size, other._size); + std::swap(_capacity, other._capacity); + } + +private: + // Buffer of data + std::unique_ptr<std::uint8_t[]> _buffer; + + // Current read position in buffer + std::size_t _bufPos{0}; + + // Count of elements in buffer + std::size_t _size{0}; + + // Capacity of buffer for elements, always >= _size + std::size_t _capacity; +}; + +// Default buffer size. SSL has a max encapsulated packet size of 16 kb. +const std::size_t kDefaultBufferSize = 17 * 1024; + +// This enum mirrors the engine::want enum. The values must be kept in sync +// to support a simple conversion from ssl_want to engine::want, see ssl_want_to_engine. +enum class ssl_want { + // Returned by functions to indicate that the engine wants input. The input + // buffer should be updated to point to the data. The engine then needs to + // be called again to retry the operation. + want_input_and_retry = -2, + + // Returned by functions to indicate that the engine wants to write output. + // The output buffer points to the data to be written. The engine then + // needs to be called again to retry the operation. + want_output_and_retry = -1, + + // Returned by functions to indicate that the engine doesn't need input or + // output. + want_nothing = 0, + + // Returned by functions to indicate that the engine wants to write output. + // The output buffer points to the data to be written. After that the + // operation is complete, and the engine does not need to be called again. + want_output = 1 +}; + +/** + * Manages the SSL handshake and shutdown state machines. + * + * Handshakes are always the first set of events during SSL connection initiation. + * Shutdown can occur anytime after the handshake has succesfully finished + * as a result of a read event or explicit shutdown request from the engine. + */ +class SSLHandshakeManager { +public: + /** + * Handshake Mode indicates whether this a for a client or server side. + * + * Each given connection can only be a client or server, and it cannot change once set. + */ + enum class HandshakeMode { + // Initial state, illegal for clients to set + Unknown, + + // Client handshake, connect side + Client, + + // Server handshake, accept side + Server, + }; + + /** + * Handshake state indicates to the caller if nextHandshake needs to be called next. + */ + enum class HandshakeState { + // Caller should continue to call nextHandshake, the handshake is not done. + Continue, + + // Caller should not continue to call nextHandshake, the handshake is done. + Done + }; + + SSLHandshakeManager(PCtxtHandle hctxt, + PCredHandle phcred, + std::wstring& serverName, + ReusableBuffer* pInBuffer, + ReusableBuffer* pOutBuffer, + ReusableBuffer* pExtraBuffer, + SCHANNEL_CRED* cred) + : _state(State::HandshakeStart), + _phctxt(hctxt), + _cred(cred), + _serverName(serverName), + _phcred(phcred), + _pInBuffer(pInBuffer), + _pOutBuffer(pOutBuffer), + _pExtraEncryptedBuffer(pExtraBuffer), + _alertBuffer(1024), + _mode(HandshakeMode::Unknown) {} + + /** + * Set the current handdshake mode as client or server. + * + * Idempotent if called with same mode otherwise it asserts. + */ + void setMode(HandshakeMode mode) { + ASIO_ASSERT(_mode == HandshakeMode::Unknown || _mode == mode); + ASIO_ASSERT(mode != HandshakeMode::Unknown); + _mode = mode; + } + + /** + * Start or continue SSL handshake. + * + * Must be called until HandshakeState::Done is returned. + * + * Return status code to indicate whether it needs more data or if data needs to be sent to the + * other side. + */ + ssl_want nextHandshake(asio::error_code& ec, HandshakeState* pHandshakeState); + + /** + * Begin graceful SSL shutdown. Either: + * - respond to already received alert signalling connection shutdown on remote side + * - start SSL shutdown by signalling remote side + */ + ssl_want beginShutdown(asio::error_code& ec); + + /* + * Ingest data from ASIO that has been received. + */ + void writeEncryptedData(const void* data, std::size_t length); + + /** + * Returns true if there is data to send over the wire. + */ + bool hasOutputData() { + return !_pOutBuffer->empty(); + } + + /** + * Get data to sent over the network. + */ + void readOutputBuffer(void* data, size_t inLength, size_t& outLength) { + _pOutBuffer->readInto(data, inLength, outLength); + } + +private: + void startServerHandshake(asio::error_code& ec); + + void startClientHandshake(asio::error_code& ec); + + DWORD getServerFlags() { + return ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY | + ASC_REQ_EXTENDED_ERROR | ASC_REQ_STREAM; + } + + DWORD getClientFlags() { + return ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY | + ISC_REQ_EXTENDED_ERROR | ISC_REQ_STREAM | ISC_REQ_USE_SUPPLIED_CREDS | + ISC_REQ_MANUAL_CRED_VALIDATION; + } + + /** + * RAII class to free a buffer allocated by SSPI. + */ + class ContextBufferDeleter { + public: + ContextBufferDeleter(void** buf) : _buf(buf) {} + + ~ContextBufferDeleter() { + if (*_buf != nullptr) { + FreeContextBuffer(*_buf); + } + } + + private: + void** _buf; + }; + + ssl_want startShutdown(asio::error_code& ec); + + ssl_want doServerHandshake(bool newConversation, + asio::error_code& ec, + HandshakeState* pHandshakeState); + + ssl_want doClientHandshake(asio::error_code& ec); + +private: + /** + * Handshake State machine: + * +-----------------------------+ + * v | + * +----------------+ +-----------------------+ +-------------------+ +------+ + * | HandshakeStart | --> | NeedMoreHandshakeData | --> | HaveEncryptedData | --> | Done | + * +----------------+ +-----------------------+ +-------------------+ +------+ + * + * "[ HandshakeStart ] --> [ NeedMoreHandshakeData ] --> [HaveEncryptedData] -> [ + * NeedMoreHandshakeData], [Done] " | graph-easy + */ + enum class State { + // Initial state + HandshakeStart, + + // Handshake needs more data before it decode the next message + NeedMoreHandshakeData, + + // Handshake just received some data, and can now try to decrypt it + HaveEncryptedData, + + // Handshake is done + Done, + }; + + /** + * Transition state machine + */ + void setState(State s) { + ASSERT_STATE_TRANSITION(_state == State::HandshakeStart, s == State::NeedMoreHandshakeData); + ASSERT_STATE_TRANSITION(_state == State::NeedMoreHandshakeData, + s == State::HaveEncryptedData); + ASSERT_STATE_TRANSITION(_state == State::HaveEncryptedData, + s == State::NeedMoreHandshakeData || s == State::Done); + _state = s; + } + +private: + // State machine + State _state; + + // Handshake mode - client or server + HandshakeMode _mode; + + // Server name for TLS SNI purposes + std::wstring& _serverName; + + // Buffer of data received from remote side + ReusableBuffer* _pInBuffer; + + // Scratch buffer to capture extra handshake data + ReusableBuffer* _pExtraEncryptedBuffer; + + // Buffer to data to send to remote side + ReusableBuffer* _pOutBuffer; + + // Buffer of data received from remote side + ReusableBuffer _alertBuffer; + + // SChannel Credentials + SCHANNEL_CRED* _cred; + + // SChannel context + PCtxtHandle _phctxt; + + // Credential handle + PCredHandle _phcred; +}; + +/** + * Manages the SSL read state machine. + * + * Notifies callers of graceful SSL shutdown events. + */ +class SSLReadManager { +public: + /** + * Indicates whether client should continue to decrypt data or it needs to handle other protocol + * signals. + */ + enum class DecryptState { + // SSL connection is proceeding normally + Continue, + + // Remote side has signaled graceful SSL shutdown + Shutdown, + + // Remote side has signaled renegtiation + Renegotiate, + }; + + SSLReadManager(PCtxtHandle hctxt, + PCredHandle hcred, + ReusableBuffer* pInBuffer, + ReusableBuffer* pExtraBuffer) + : _state(State::NeedMoreEncryptedData), + _phctxt(hctxt), + _phcred(hcred), + _pInBuffer(pInBuffer), + _pExtraEncryptedBuffer(pExtraBuffer) {} + + /** + * Read decrypted data if encrypted data was provided via writeData and succesfully decrypted. + */ + ssl_want readDecryptedData(void* data, + std::size_t length, + asio::error_code& ec, + std::size_t& bytes_transferred, + DecryptState* pDecryptState); + + /** + * Receive more data from ASIO. + */ + void writeData(const void* data, std::size_t length) { + ASIO_ASSERT(_pExtraEncryptedBuffer->empty()); + + // We have more data, it may not be enough to decode but we will figure that out later. + setState(State::HaveEncryptedData); + + _pInBuffer->append(data, length); + } + +private: + ssl_want decryptBuffer(asio::error_code& ec, DecryptState* pDecryptState); + +private: + /** + * Read State machine: + * + * +------------------------------------------------------------+ + * | | + * | | + * | +-----------------------------+ | + * | v | | + * | +-----------------------+ +-------------------+ +-------------------+ + * +> | NeedMoreEncryptedData | --> | HaveEncryptedData | --> | HaveDecryptedData | + * +-----------------------+ +-------------------+ +-------------------+ + * ^ | ^ | + * +-------------------+ +-------------------------+ + * + * "[ NeedMoreEncryptedData ] --> [ HaveEncryptedData ] --> [HaveDecryptedData] -> + * [NeedMoreEncryptedData], [HaveEncryptedData] --> [NeedMoreEncryptedData] " | graph-easy + * + */ + enum class State { + // Initial state, Need more data from remote side + NeedMoreEncryptedData, + + // Have some encrypted data, unknown if it is a complete packet + HaveEncryptedData, + + // Was able to decrypt a packet, give decrypted data back to client + HaveDecryptedData, + }; + + /** + * Transition state machine + */ + void setState(State s) { + ASSERT_STATE_TRANSITION(_state == State::NeedMoreEncryptedData, + s == State::HaveEncryptedData); + ASSERT_STATE_TRANSITION( + _state == State::HaveEncryptedData, + (s == State::NeedMoreEncryptedData || s == State::HaveDecryptedData)); + ASSERT_STATE_TRANSITION( + _state == State::HaveDecryptedData, + (s == State::NeedMoreEncryptedData || s == State::HaveEncryptedData)); + _state = s; + } + +private: + // State machine + State _state; + + // Scratch buffer to capture extra decryption data + ReusableBuffer* _pExtraEncryptedBuffer; + + // Buffer of data from the remote side + ReusableBuffer* _pInBuffer; + + // SChannel context handle + PCtxtHandle _phctxt; + + // Credential handle + PCredHandle _phcred; +}; + +/** + * Manages the SSL write state machine. + */ +class SSLWriteManager { +public: + SSLWriteManager(PCtxtHandle hctxt, ReusableBuffer* pOutBuffer) + : _phctxt(hctxt), _pOutBuffer(pOutBuffer) {} + + /** + * Encrypts data to be sent to the remote side. + * + * If the message is >= max packet side, it will return want_output_and_retry, and expects + * callers to continue to call it with the same parameters until want_output is returned. + */ + ssl_want writeUnencryptedData(const void* pMessage, + std::size_t messageLength, + std::size_t& bytes_transferred, + asio::error_code& ec); + + /** + * Read encrypted data to be sent to the remote side. + */ + void readOutputBuffer(void* data, size_t inLength, size_t& outLength) { + _pOutBuffer->readInto(data, inLength, outLength); + } + +private: + ssl_want encryptMessage(const void* pMessage, + std::size_t messageLength, + std::size_t& bytes_transferred, + asio::error_code& ec); + +private: + // Buffer of data to send to the remote side + ReusableBuffer* _pOutBuffer; + + // SChannel context handle + PCtxtHandle _phctxt; + + // Position to start encrypting from for messages needing fragmentation + std::size_t _lastWriteOffset{0}; + + // TLS packet header length + std::size_t _securityHeaderLength{ULONG_MAX}; + + // TLS max packet size - 16kb typically + std::size_t _securityMaxMessageLength{ULONG_MAX}; + + // TLS packet trailer length + std::size_t _securityTrailerLength{ULONG_MAX}; +}; + +#include "asio/detail/pop_options.hpp" + +} // namespace detail +} // namespace ssl +} // namespace asio diff --git a/src/mongo/util/net/ssl/impl/context_schannel.ipp b/src/mongo/util/net/ssl/impl/context_schannel.ipp index 99f6541c24d..52bd1b86797 100644 --- a/src/mongo/util/net/ssl/impl/context_schannel.ipp +++ b/src/mongo/util/net/ssl/impl/context_schannel.ipp @@ -30,11 +30,11 @@ #include "asio/detail/config.hpp" -#include <cstring> #include "asio/detail/throw_error.hpp" #include "asio/error.hpp" #include "mongo/util/net/ssl/context.hpp" #include "mongo/util/net/ssl/error.hpp" +#include <cstring> #include "asio/detail/push_options.hpp" @@ -42,37 +42,31 @@ namespace asio { namespace ssl { -context::context(context::method m) - : handle_(0) -{ +context::context(context::method m) : handle_(&_cred) { + memset(&_cred, 0, sizeof(_cred)); } #if defined(ASIO_HAS_MOVE) || defined(GENERATING_DOCUMENTATION) -context::context(context&& other) -{ +context::context(context&& other) { handle_ = other.handle_; other.handle_ = 0; } -context& context::operator=(context&& other) -{ +context& context::operator=(context&& other) { context tmp(ASIO_MOVE_CAST(context)(*this)); handle_ = other.handle_; other.handle_ = 0; return *this; } -#endif // defined(ASIO_HAS_MOVE) || defined(GENERATING_DOCUMENTATION) +#endif // defined(ASIO_HAS_MOVE) || defined(GENERATING_DOCUMENTATION) -context::~context() -{ -} +context::~context() {} -context::native_handle_type context::native_handle() -{ +context::native_handle_type context::native_handle() { return handle_; } -} // namespace ssl -} // namespace asio +} // namespace ssl +} // namespace asio #include "asio/detail/pop_options.hpp" diff --git a/src/mongo/util/net/ssl/impl/error.ipp b/src/mongo/util/net/ssl/impl/error.ipp index 5de8a890454..64d3a28bae7 100644 --- a/src/mongo/util/net/ssl/impl/error.ipp +++ b/src/mongo/util/net/ssl/impl/error.ipp @@ -17,6 +17,7 @@ #include "asio/detail/config.hpp" #include "mongo/util/net/ssl/error.hpp" +#include "mongo/util/errno_util.h" #include "asio/detail/push_options.hpp" @@ -35,9 +36,7 @@ public: #if MONGO_CONFIG_SSL_PROVIDER == SSL_PROVIDER_WINDOWS std::string message(int value) const { - // TODO: call FormatMessage - ASIO_ASSERT(false); - return "asio.ssl error"; + return mongo::errnoWithDescription(value); } #elif MONGO_CONFIG_SSL_PROVIDER == SSL_PROVIDER_OPENSSL std::string message(int value) const diff --git a/src/mongo/util/net/ssl/impl/src.hpp b/src/mongo/util/net/ssl/impl/src.hpp index 81b903d0f5e..fd940e414af 100644 --- a/src/mongo/util/net/ssl/impl/src.hpp +++ b/src/mongo/util/net/ssl/impl/src.hpp @@ -24,6 +24,7 @@ #include "mongo/util/net/ssl/impl/context_schannel.ipp" #include "mongo/util/net/ssl/impl/error.ipp" #include "mongo/util/net/ssl/detail/impl/engine_schannel.ipp" +#include "mongo/util/net/ssl/detail/impl/schannel.ipp" #elif MONGO_CONFIG_SSL_PROVIDER == SSL_PROVIDER_OPENSSL diff --git a/src/mongo/util/net/ssl_manager_windows.cpp b/src/mongo/util/net/ssl_manager_windows.cpp index bcac27724a0..10c48d7c325 100644 --- a/src/mongo/util/net/ssl_manager_windows.cpp +++ b/src/mongo/util/net/ssl_manager_windows.cpp @@ -46,6 +46,7 @@ #include "mongo/base/initializer_context.h" #include "mongo/bson/bsonobjbuilder.h" #include "mongo/config.h" +#include "mongo/db/server_options.h" #include "mongo/db/server_parameters.h" #include "mongo/platform/atomic_word.h" #include "mongo/stdx/memory.h" @@ -70,8 +71,92 @@ namespace mongo { namespace { SimpleMutex sslManagerMtx; -SSLManagerInterface* theSSLManagerWindows = NULL; +SSLManagerInterface* theSSLManager = NULL; +/** +* Free a Certificate Context. +*/ +struct CERTFree { + void operator()(const CERT_CONTEXT* p) noexcept { + if (p) { + ::CertFreeCertificateContext(p); + } + } +}; + +using UniqueCertificate = std::unique_ptr<const CERT_CONTEXT, CERTFree>; + +/** +* A simple generic class to manage Windows handle like things. Behaves similiar to std::unique_ptr. +* +* Only supports move. +*/ +template <typename HandleT, class Deleter> +class AutoHandle { +public: + AutoHandle() : _handle(0) {} + AutoHandle(HandleT handle) : _handle(handle) {} + AutoHandle(AutoHandle<HandleT, Deleter>&& handle) : _handle(handle._handle) { + handle._handle = 0; + } + + ~AutoHandle() { + if (_handle != 0) { + Deleter()(_handle); + } + } + + AutoHandle(const AutoHandle&) = delete; + + /** + * Take ownership of the handle. + */ + AutoHandle& operator=(const HandleT other) { + _handle = other; + return *this; + } + + AutoHandle& operator=(const AutoHandle<HandleT, Deleter>& other) = delete; + + AutoHandle& operator=(AutoHandle<HandleT, Deleter>&& other) { + _handle = other._handle; + other._handle = 0; + return *this; + } + + operator HandleT() { + return _handle; + } + +private: + HandleT _handle; +}; + +/** +* Free a HCRYPTPROV Handle +*/ +struct CryptProviderFree { + void operator()(HCRYPTPROV const h) noexcept { + if (h) { + ::CryptReleaseContext(h, 0); + } + } +}; + +using UniqueCryptProvider = AutoHandle<HCRYPTPROV, CryptProviderFree>; + +/** +* Free a HCRYPTKEY Handle +*/ +struct CryptKeyFree { + void operator()(HCRYPTKEY const h) noexcept { + if (h) { + ::CryptDestroyKey(h); + } + } +}; + +using UniqueCryptKey = AutoHandle<HCRYPTKEY, CryptKeyFree>; } // namespace @@ -80,9 +165,20 @@ SSLManagerInterface* theSSLManagerWindows = NULL; */ class SSLConnectionWindows : public SSLConnectionInterface { public: + SCHANNEL_CRED* _cred; + Socket* socket; + asio::ssl::detail::engine _engine; + + std::vector<char> _tempBuffer; + + SSLConnectionWindows(SCHANNEL_CRED* cred, Socket* sock, const char* initialBytes, int len); + ~SSLConnectionWindows(); - std::string getSNIServerName() const final; + std::string getSNIServerName() const final { + // TODO + return ""; + }; }; @@ -98,32 +194,44 @@ public: const SSLParams& params, ConnectionDirection direction) final; - virtual SSLConnectionInterface* connect(Socket* socket); + SSLConnectionInterface* connect(Socket* socket) final; - virtual SSLConnectionInterface* accept(Socket* socket, const char* initialBytes, int len); + SSLConnectionInterface* accept(Socket* socket, const char* initialBytes, int len) final; - virtual SSLPeerInfo parseAndValidatePeerCertificateDeprecated( - const SSLConnectionInterface* conn, const std::string& remoteHost); + SSLPeerInfo parseAndValidatePeerCertificateDeprecated(const SSLConnectionInterface* conn, + const std::string& remoteHost) final; StatusWith<boost::optional<SSLPeerInfo>> parseAndValidatePeerCertificate( PCtxtHandle ssl, const std::string& remoteHost) final; - virtual const SSLConfiguration& getSSLConfiguration() const { + const SSLConfiguration& getSSLConfiguration() const final { return _sslConfiguration; } - virtual int SSL_read(SSLConnectionInterface* conn, void* buf, int num); + int SSL_read(SSLConnectionInterface* conn, void* buf, int num) final; - virtual int SSL_write(SSLConnectionInterface* conn, const void* buf, int num); + int SSL_write(SSLConnectionInterface* conn, const void* buf, int num) final; - virtual int SSL_shutdown(SSLConnectionInterface* conn); + int SSL_shutdown(SSLConnectionInterface* conn) final; private: bool _weakValidation; bool _allowInvalidCertificates; bool _allowInvalidHostnames; SSLConfiguration _sslConfiguration; + + SCHANNEL_CRED _clientCred; + SCHANNEL_CRED _serverCred; + + UniqueCertificate _pemCertificate; + UniqueCertificate _clusterPEMCertificate; + PCCERT_CONTEXT _clientCertificates[1]; + PCCERT_CONTEXT _serverCertificates[1]; + + Status loadCertificates(const SSLParams& params); + + void handshake(SSLConnectionWindows* conn, bool client); }; // Global variable indicating if this is a server or a client instance @@ -132,19 +240,27 @@ bool isSSLServer = false; MONGO_INITIALIZER(SSLManager)(InitializerContext*) { stdx::lock_guard<SimpleMutex> lck(sslManagerMtx); if (!isSSLServer || (sslGlobalParams.sslMode.load() != SSLParams::SSLMode_disabled)) { - theSSLManagerWindows = new SSLManagerWindows(sslGlobalParams, isSSLServer); + theSSLManager = new SSLManagerWindows(sslGlobalParams, isSSLServer); } return Status::OK(); } -SSLConnectionWindows::~SSLConnectionWindows() {} +SSLConnectionWindows::SSLConnectionWindows(SCHANNEL_CRED* cred, + Socket* sock, + const char* initialBytes, + int len) + : _cred(cred), socket(sock), _engine(_cred) { -std::string SSLConnectionWindows::getSNIServerName() const { - invariant(false); - return ""; + _tempBuffer.resize(17 * 1024); + + if (len > 0) { + _engine.put_input(asio::const_buffer(initialBytes, len)); + } } +SSLConnectionWindows::~SSLConnectionWindows() {} + std::unique_ptr<SSLManagerInterface> SSLManagerInterface::create(const SSLParams& params, bool isServer) { return stdx::make_unique<SSLManagerWindows>(params, isServer); @@ -152,24 +268,111 @@ std::unique_ptr<SSLManagerInterface> SSLManagerInterface::create(const SSLParams SSLManagerInterface* getSSLManager() { stdx::lock_guard<SimpleMutex> lck(sslManagerMtx); - if (theSSLManagerWindows) - return theSSLManagerWindows; + if (theSSLManager) + return theSSLManager; return NULL; } SSLManagerWindows::SSLManagerWindows(const SSLParams& params, bool isServer) : _weakValidation(params.sslWeakCertificateValidation), _allowInvalidCertificates(params.sslAllowInvalidCertificates), - _allowInvalidHostnames(params.sslAllowInvalidHostnames) {} + _allowInvalidHostnames(params.sslAllowInvalidHostnames) { + + uassertStatusOK(loadCertificates(params)); + + uassertStatusOK(initSSLContext(&_clientCred, params, ConnectionDirection::kOutgoing)); + + // TODO: validate client certificate + + // SSL server specific initialization + if (isServer) { + uassertStatusOK(initSSLContext(&_serverCred, params, ConnectionDirection::kIncoming)); + + // TODO: validate server certificate + } +} int SSLManagerWindows::SSL_read(SSLConnectionInterface* connInterface, void* buf, int num) { - invariant(false); - return 0; + SSLConnectionWindows* conn = static_cast<SSLConnectionWindows*>(connInterface); + + while (true) { + size_t bytes_transferred; + asio::error_code ec; + asio::ssl::detail::engine::want want = + conn->_engine.read(asio::mutable_buffer(buf, num), ec, bytes_transferred); + if (ec) { + throwSocketError(SocketErrorKind::RECV_ERROR, ec.message()); + } + + switch (want) { + case asio::ssl::detail::engine::want_input_and_retry: { + // ASIO wants more data before it can continue: + // 1. fetch some from the network + // 2. give it to ASIO + // 3. retry + int ret = + recv(conn->socket->rawFD(), reinterpret_cast<char*>(buf), num, portRecvFlags); + if (ret == SOCKET_ERROR) { + conn->socket->handleRecvError(ret, num); + } + + conn->_engine.put_input(asio::const_buffer(buf, ret)); + + continue; + } + case asio::ssl::detail::engine::want_nothing: { + // ASIO wants nothing, return to caller with anything transfered. + return bytes_transferred; + } + default: + severe() << "Unexpected ASIO state: " << static_cast<int>(want); + MONGO_UNREACHABLE; + } + } } int SSLManagerWindows::SSL_write(SSLConnectionInterface* connInterface, const void* buf, int num) { - invariant(false); - return 0; + SSLConnectionWindows* conn = static_cast<SSLConnectionWindows*>(connInterface); + + while (true) { + size_t bytes_transferred; + asio::error_code ec; + asio::ssl::detail::engine::want want = + conn->_engine.write(asio::const_buffer(buf, num), ec, bytes_transferred); + if (ec) { + throwSocketError(SocketErrorKind::SEND_ERROR, ec.message()); + } + + switch (want) { + case asio::ssl::detail::engine::want_output: + case asio::ssl::detail::engine::want_output_and_retry: { + // ASIO wants us to send data out: + // 1. get data from ASIO + // 2. give it to the network + // 3. retry if needed + + asio::mutable_buffer outBuf = conn->_engine.get_output( + asio::mutable_buffer(conn->_tempBuffer.data(), conn->_tempBuffer.size())); + + int ret = send(conn->socket->rawFD(), + reinterpret_cast<const char*>(outBuf.data()), + outBuf.size(), + portSendFlags); + if (ret == SOCKET_ERROR) { + conn->socket->handleSendError(ret, ""); + } + + if (want == asio::ssl::detail::engine::want_output_and_retry) { + continue; + } + + return bytes_transferred; + } + default: + severe() << "Unexpected ASIO state: " << static_cast<int>(want); + MONGO_UNREACHABLE; + } + } } int SSLManagerWindows::SSL_shutdown(SSLConnectionInterface* conn) { @@ -177,26 +380,426 @@ int SSLManagerWindows::SSL_shutdown(SSLConnectionInterface* conn) { return 0; } +StatusWith<UniqueCertificate> readPEMFile(StringData fileName, StringData password) { + + std::ifstream pemFile(fileName.toString(), std::ios::binary); + if (!pemFile.is_open()) { + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "Failed to open PEM file: " << fileName); + } + + std::string buf((std::istreambuf_iterator<char>(pemFile)), std::istreambuf_iterator<char>()); + + pemFile.close(); + + // Search the buffer for the various strings that make up a PEM file + size_t publicKey = buf.find("-----BEGIN CERTIFICATE-----"); + if (publicKey == std::string::npos) { + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "Failed to find Certifiate in: " << fileName); + } + + // TODO: decode encrypted pem + // StringData encryptedPrivateKey = buf.find("-----BEGIN ENCRYPTED PRIVATE KEY-----"); + + // TODO: check if we need both + size_t privateKey = buf.find("-----BEGIN RSA PRIVATE KEY-----"); + if (privateKey == std::string::npos) { + privateKey = buf.find("-----BEGIN PRIVATE KEY-----"); + } + + if (privateKey == std::string::npos) { + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "Failed to find privateKey in: " << fileName); + } + + CERT_BLOB certBlob; + certBlob.cbData = buf.size() - publicKey; + certBlob.pbData = reinterpret_cast<BYTE*>(const_cast<char*>(buf.data() + publicKey)); + + PCCERT_CONTEXT cert; + BOOL ret = CryptQueryObject(CERT_QUERY_OBJECT_BLOB, + &certBlob, + CERT_QUERY_CONTENT_FLAG_ALL, + CERT_QUERY_FORMAT_FLAG_ALL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + reinterpret_cast<const void**>(&cert)); + if (!ret) { + DWORD gle = GetLastError(); + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "CryptQueryObject failed to get cert: " + << errnoWithDescription(gle)); + } + + UniqueCertificate certHolder(cert); + DWORD privateKeyLen{0}; + + ret = CryptStringToBinaryA(buf.c_str() + privateKey, + 0, // null terminated string + CRYPT_STRING_BASE64HEADER | CRYPT_STRING_STRICT, + NULL, + &privateKeyLen, + NULL, + NULL); + if (!ret) { + DWORD gle = GetLastError(); + if (gle != ERROR_MORE_DATA) { + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "CryptStringToBinary failed to get size of key: " + << errnoWithDescription(gle)); + } + } + + std::unique_ptr<BYTE[]> privateKeyBuf = std::make_unique<BYTE[]>(privateKeyLen); + ret = CryptStringToBinaryA(buf.c_str() + privateKey, + 0, // null terminated string + CRYPT_STRING_BASE64HEADER | CRYPT_STRING_STRICT, + privateKeyBuf.get(), + &privateKeyLen, + NULL, + NULL); + if (!ret) { + DWORD gle = GetLastError(); + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "CryptStringToBinary failed to read key: " + << errnoWithDescription(gle)); + } + + + DWORD privateBlobLen{0}; + + ret = CryptDecodeObjectEx(X509_ASN_ENCODING, + PKCS_RSA_PRIVATE_KEY, + privateKeyBuf.get(), + privateKeyLen, + CRYPT_DECODE_SHARE_OID_STRING_FLAG, + NULL, + NULL, + &privateBlobLen); + if (!ret) { + DWORD gle = GetLastError(); + if (gle != ERROR_MORE_DATA) { + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "CryptDecodeObjectEx failed to get size of key: " + << errnoWithDescription(gle)); + } + } + + std::unique_ptr<BYTE[]> privateBlobBuf = std::make_unique<BYTE[]>(privateBlobLen); + + ret = CryptDecodeObjectEx(X509_ASN_ENCODING, + PKCS_RSA_PRIVATE_KEY, + privateKeyBuf.get(), + privateKeyLen, + CRYPT_DECODE_SHARE_OID_STRING_FLAG, + NULL, + privateBlobBuf.get(), + &privateBlobLen); + if (!ret) { + DWORD gle = GetLastError(); + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "CryptDecodeObjectEx failed to read key: " + << errnoWithDescription(gle)); + } + + HCRYPTPROV hProv; + std::wstring wstr; + + // Create the right Crypto context depending on whether we running in a server or outside. + // See https://msdn.microsoft.com/en-us/library/windows/desktop/aa375195(v=vs.85).aspx + if (isSSLServer) { + // Generate a unique name for our key container + // Use the the log file if possible + if (!serverGlobalParams.logpath.empty()) { + wstr = toNativeString(serverGlobalParams.logpath.c_str()); + } else { + auto us = UUID::gen().toString(); + wstr = toNativeString(us.c_str()); + } + + // Use a new key container for the key. We cannot use the default container since the + // default + // container is shared across processes owned by the same user. + // Note: Server side Schannel requires CRYPT_VERIFYCONTEXT off + ret = CryptAcquireContextW( + &hProv, wstr.c_str(), MS_ENHANCED_PROV, PROV_RSA_FULL, CRYPT_NEWKEYSET | CRYPT_SILENT); + if (!ret) { + DWORD gle = GetLastError(); + + if (gle == NTE_EXISTS) { + + ret = CryptAcquireContextW( + &hProv, wstr.c_str(), MS_ENHANCED_PROV, PROV_RSA_FULL, CRYPT_SILENT); + if (!ret) { + DWORD gle = GetLastError(); + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "CryptAcquireContextW failed " + << errnoWithDescription(gle)); + } + + } else { + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "CryptAcquireContextW failed " + << errnoWithDescription(gle)); + } + } + } else { + // Use a transient key container for the key + ret = CryptAcquireContextW( + &hProv, NULL, MS_ENHANCED_PROV, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT | CRYPT_SILENT); + if (!ret) { + DWORD gle = GetLastError(); + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "CryptAcquireContextW failed " + << errnoWithDescription(gle)); + } + } + UniqueCryptProvider cryptProvider(hProv); + + HCRYPTKEY hkey; + ret = CryptImportKey(hProv, privateBlobBuf.get(), privateBlobLen, 0, 0, &hkey); + if (!ret) { + DWORD gle = GetLastError(); + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "CryptImportKey failed " << errnoWithDescription(gle)); + } + UniqueCryptKey(hKey); + + if (isSSLServer) { + // Server-side SChannel requires a different way of attaching the private key to the + // certificate + CRYPT_KEY_PROV_INFO keyProvInfo; + memset(&keyProvInfo, 0, sizeof(keyProvInfo)); + keyProvInfo.pwszContainerName = const_cast<wchar_t*>(wstr.c_str()); + keyProvInfo.pwszProvName = const_cast<wchar_t*>(MS_ENHANCED_PROV); + keyProvInfo.dwFlags = CERT_SET_KEY_PROV_HANDLE_PROP_ID | CERT_SET_KEY_CONTEXT_PROP_ID; + keyProvInfo.dwProvType = PROV_RSA_FULL; + keyProvInfo.dwKeySpec = AT_KEYEXCHANGE; + + if (!CertSetCertificateContextProperty( + certHolder.get(), CERT_KEY_PROV_INFO_PROP_ID, 0, &keyProvInfo)) { + DWORD gle = GetLastError(); + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "CertSetCertificateContextProperty Failed " + << errnoWithDescription(gle)); + } + } + + // NOTE: This is used to set the certificate for client side SChannel + ret = CertSetCertificateContextProperty( + cert, CERT_KEY_PROV_HANDLE_PROP_ID, 0, (const void*)hProv); + if (!ret) { + DWORD gle = GetLastError(); + return Status(ErrorCodes::InvalidSSLConfiguration, + str::stream() << "CertSetCertificateContextProperty failed " + << errnoWithDescription(gle)); + } + + return std::move(certHolder); +} + +Status SSLManagerWindows::loadCertificates(const SSLParams& params) { + _clientCertificates[0] = nullptr; + _serverCertificates[0] = nullptr; + + // Load the normal PEM file + if (!params.sslPEMKeyFile.empty()) { + auto swCertificate = readPEMFile(params.sslPEMKeyFile, params.sslPEMKeyPassword); + if (!swCertificate.isOK()) { + return swCertificate.getStatus(); + } + + _pemCertificate = std::move(swCertificate.getValue()); + } + + // Load the cluster PEM file, only applies to server side code + if (!params.sslClusterFile.empty()) { + auto swCertificate = readPEMFile(params.sslClusterFile, params.sslClusterPassword); + if (!swCertificate.isOK()) { + return swCertificate.getStatus(); + } + + _clusterPEMCertificate = std::move(swCertificate.getValue()); + } + + if (_pemCertificate) { + _clientCertificates[0] = _pemCertificate.get(); + _serverCertificates[0] = _pemCertificate.get(); + } + + if (_clusterPEMCertificate) { + _clientCertificates[0] = _clusterPEMCertificate.get(); + } + + return Status::OK(); +} + Status SSLManagerWindows::initSSLContext(SCHANNEL_CRED* cred, const SSLParams& params, ConnectionDirection direction) { + memset(cred, 0, sizeof(*cred)); + cred->dwVersion = SCHANNEL_CRED_VERSION; + cred->dwFlags = SCH_USE_STRONG_CRYPTO; // Use strong crypto; + + uint32_t supportedProtocols = 0; + + if (direction == ConnectionDirection::kIncoming) { + supportedProtocols = SP_PROT_TLS1_SERVER | SP_PROT_TLS1_0_SERVER | SP_PROT_TLS1_1_SERVER | + SP_PROT_TLS1_2_SERVER; + + cred->dwFlags = cred->dwFlags // Flags + | SCH_CRED_REVOCATION_CHECK_CHAIN // Check certificate revocation + | SCH_CRED_SNI_CREDENTIAL // Pass along SNI creds + | SCH_CRED_SNI_ENABLE_OCSP // Enable OCSP + | SCH_CRED_NO_SYSTEM_MAPPER // Do not map certificate to user account + | SCH_CRED_DISABLE_RECONNECTS; // Do not support reconnects + } else { + supportedProtocols = SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_0_CLIENT | SP_PROT_TLS1_1_CLIENT | + SP_PROT_TLS1_2_CLIENT; + + cred->dwFlags = cred->dwFlags // Flags + | SCH_CRED_REVOCATION_CHECK_CHAIN // Check certificate revocation + | SCH_CRED_NO_SERVERNAME_CHECK // Do not validate server name against cert + | SCH_CRED_NO_DEFAULT_CREDS // No Default Certificate + | SCH_CRED_MANUAL_CRED_VALIDATION; // Validate Certificate Manually + } + + // Set the supported TLS protocols. Allow --sslDisabledProtocols to disable selected ciphers. + for (const SSLParams::Protocols& protocol : params.sslDisabledProtocols) { + if (protocol == SSLParams::Protocols::TLS1_0) { + supportedProtocols &= ~(SP_PROT_TLS1_0_CLIENT | SP_PROT_TLS1_0_SERVER); + } else if (protocol == SSLParams::Protocols::TLS1_1) { + supportedProtocols &= ~(SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_1_SERVER); + } else if (protocol == SSLParams::Protocols::TLS1_2) { + supportedProtocols &= ~(SP_PROT_TLS1_2_CLIENT | SP_PROT_TLS1_2_SERVER); + } + } + + cred->grbitEnabledProtocols = supportedProtocols; + + if (!params.sslCipherConfig.empty()) { + warning() + << "sslCipherConfig parameter is not supported with Windows SChannel and is ignored."; + } + + if (direction == ConnectionDirection::kOutgoing) { + if (_clientCertificates[0]) { + cred->cCreds = 1; + cred->paCred = _clientCertificates; + } + } else { + cred->cCreds = 1; + cred->paCred = _serverCertificates; + } + return Status::OK(); } SSLConnectionInterface* SSLManagerWindows::connect(Socket* socket) { - return nullptr; + std::unique_ptr<SSLConnectionWindows> sslConn = + stdx::make_unique<SSLConnectionWindows>(&_clientCred, socket, nullptr, 0); + + handshake(sslConn.get(), true); + return sslConn.release(); } SSLConnectionInterface* SSLManagerWindows::accept(Socket* socket, const char* initialBytes, int len) { - return nullptr; + std::unique_ptr<SSLConnectionWindows> sslConn = + stdx::make_unique<SSLConnectionWindows>(&_serverCred, socket, initialBytes, len); + + handshake(sslConn.get(), false); + + return sslConn.release(); +} + +void SSLManagerWindows::handshake(SSLConnectionWindows* conn, bool client) { + initSSLContext(conn->_cred, + getSSLGlobalParams(), + client ? SSLManagerInterface::ConnectionDirection::kOutgoing + : SSLManagerInterface::ConnectionDirection::kIncoming); + + while (true) { + asio::error_code ec; + asio::ssl::detail::engine::want want = + conn->_engine.handshake(client ? asio::ssl::stream_base::handshake_type::client + : asio::ssl::stream_base::handshake_type::server, + ec); + if (ec) { + throwSocketError(SocketErrorKind::RECV_ERROR, ec.message()); + } + + switch (want) { + case asio::ssl::detail::engine::want_input_and_retry: { + // ASIO wants more data before it can continue, + // 1. fetch some from the network + // 2. give it to ASIO + // 3. retry + int ret = recv(conn->socket->rawFD(), + conn->_tempBuffer.data(), + conn->_tempBuffer.size(), + portRecvFlags); + if (ret == SOCKET_ERROR) { + conn->socket->handleRecvError(ret, conn->_tempBuffer.size()); + } + + conn->_engine.put_input(asio::const_buffer(conn->_tempBuffer.data(), ret)); + + continue; + } + case asio::ssl::detail::engine::want_output: + case asio::ssl::detail::engine::want_output_and_retry: { + // ASIO wants us to send data out + // 1. get data from ASIO + // 2. give it to the network + // 3. retry if needed + asio::mutable_buffer outBuf = conn->_engine.get_output( + asio::mutable_buffer(conn->_tempBuffer.data(), conn->_tempBuffer.size())); + + int ret = send(conn->socket->rawFD(), + reinterpret_cast<const char*>(outBuf.data()), + outBuf.size(), + portSendFlags); + if (ret == SOCKET_ERROR) { + conn->socket->handleSendError(ret, ""); + } + + if (want == asio::ssl::detail::engine::want_output_and_retry) { + continue; + } + + // ASIO wants nothing, return to caller since we are done with handshake + return; + } + case asio::ssl::detail::engine::want_nothing: { + // ASIO wants nothing, return to caller since we are done with handshake + return; + } + default: + MONGO_UNREACHABLE; + } + } } SSLPeerInfo SSLManagerWindows::parseAndValidatePeerCertificateDeprecated( const SSLConnectionInterface* conn, const std::string& remoteHost) { - return SSLPeerInfo(); + auto swPeerSubjectName = parseAndValidatePeerCertificate( + const_cast<SSLConnectionWindows*>(static_cast<const SSLConnectionWindows*>(conn)) + ->_engine.native_handle(), + remoteHost); + // We can't use uassertStatusOK here because we need to throw a SocketException. + if (!swPeerSubjectName.isOK()) { + throwSocketError(SocketErrorKind::CONNECT_ERROR, swPeerSubjectName.getStatus().reason()); + } + + return swPeerSubjectName.getValue().get_value_or(SSLPeerInfo()); } StatusWith<boost::optional<SSLPeerInfo>> SSLManagerWindows::parseAndValidatePeerCertificate( @@ -205,5 +808,4 @@ StatusWith<boost::optional<SSLPeerInfo>> SSLManagerWindows::parseAndValidatePeer return {boost::none}; } - } // namespace mongo diff --git a/src/mongo/util/net/ssl_stream.cpp b/src/mongo/util/net/ssl_stream.cpp index b7022df9532..a99723d4a32 100644 --- a/src/mongo/util/net/ssl_stream.cpp +++ b/src/mongo/util/net/ssl_stream.cpp @@ -7,6 +7,8 @@ // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // +#include "mongo/platform/basic.h" + #include "mongo/config.h" #ifdef MONGO_CONFIG_SSL |