diff options
author | Cheahuychou Mao <mao.cheahuychou@gmail.com> | 2020-12-16 17:09:45 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-01-05 05:58:56 +0000 |
commit | dc3ef13edd2ec8054f97fd160e72dae5edec3061 (patch) | |
tree | b8198ba35ab8715f53df1b4ead6e493f034e2f1a /src/mongo | |
parent | 1dfe8355a2b034ded045191f4e3d4be827365621 (diff) | |
download | mongo-dc3ef13edd2ec8054f97fd160e72dae5edec3061.tar.gz |
SERVER-52707 Make tenant migration recipient use x509 certificate to connect to donor
Diffstat (limited to 'src/mongo')
18 files changed, 159 insertions, 87 deletions
diff --git a/src/mongo/client/connection_string.h b/src/mongo/client/connection_string.h index 77f5f73a277..d505e1e5d24 100644 --- a/src/mongo/client/connection_string.h +++ b/src/mongo/client/connection_string.h @@ -46,6 +46,7 @@ namespace mongo { class ClientAPIVersionParameters; class DBClientBase; class MongoURI; +struct TransientSSLParams; /** * ConnectionString handles parsing different ways to connect to mongo and determining method @@ -150,7 +151,8 @@ public: std::string& errmsg, double socketTimeout = 0, const MongoURI* uri = nullptr, - const ClientAPIVersionParameters* apiParameters = nullptr) const; + const ClientAPIVersionParameters* apiParameters = nullptr, + const TransientSSLParams* transientSSLParams = nullptr) const; static StatusWith<ConnectionString> parse(const std::string& url); diff --git a/src/mongo/client/connection_string_connect.cpp b/src/mongo/client/connection_string_connect.cpp index ae4bf55d6fd..28cc448639c 100644 --- a/src/mongo/client/connection_string_connect.cpp +++ b/src/mongo/client/connection_string_connect.cpp @@ -51,7 +51,8 @@ std::unique_ptr<DBClientBase> ConnectionString::connect( std::string& errmsg, double socketTimeout, const MongoURI* uri, - const ClientAPIVersionParameters* apiParameters) const { + const ClientAPIVersionParameters* apiParameters, + const TransientSSLParams* transientSSLParams) const { MongoURI newURI{}; if (uri) { newURI = *uri; @@ -69,7 +70,11 @@ std::unique_ptr<DBClientBase> ConnectionString::connect( "Creating new connection to: {hostAndPort}", "Creating new connection", "hostAndPort"_attr = server); - if (!c->connect(server, applicationName, errmsg)) { + if (!c->connect(server, + applicationName, + errmsg, + transientSSLParams ? boost::make_optional(*transientSSLParams) + : boost::none)) { continue; } LOGV2_DEBUG(20110, 1, "Connected connection!"); diff --git a/src/mongo/client/dbclient_connection.cpp b/src/mongo/client/dbclient_connection.cpp index 49b581d5d2d..383be9ee5d7 100644 --- a/src/mongo/client/dbclient_connection.cpp +++ b/src/mongo/client/dbclient_connection.cpp @@ -275,8 +275,9 @@ Status DBClientConnection::authenticateInternalUser(auth::StepDownBehavior stepD bool DBClientConnection::connect(const HostAndPort& server, StringData applicationName, - std::string& errmsg) { - auto connectStatus = connect(server, applicationName); + std::string& errmsg, + boost::optional<TransientSSLParams> transientSSLParams) { + auto connectStatus = connect(server, applicationName, transientSSLParams); if (!connectStatus.isOK()) { errmsg = connectStatus.reason(); return false; @@ -284,8 +285,10 @@ bool DBClientConnection::connect(const HostAndPort& server, return true; } -Status DBClientConnection::connect(const HostAndPort& serverAddress, StringData applicationName) { - auto connectStatus = connectSocketOnly(serverAddress); +Status DBClientConnection::connect(const HostAndPort& serverAddress, + StringData applicationName, + boost::optional<TransientSSLParams> transientSSLParams) { + auto connectStatus = connectSocketOnly(serverAddress, transientSSLParams); if (!connectStatus.isOK()) { return connectStatus; } @@ -391,7 +394,8 @@ Status DBClientConnection::connect(const HostAndPort& serverAddress, StringData return Status::OK(); } -Status DBClientConnection::connectSocketOnly(const HostAndPort& serverAddress) { +Status DBClientConnection::connectSocketOnly( + const HostAndPort& serverAddress, boost::optional<TransientSSLParams> transientSSLParams) { _serverAddress = serverAddress; _markFailed(kReleaseSession); @@ -415,7 +419,10 @@ Status DBClientConnection::connectSocketOnly(const HostAndPort& serverAddress) { } auto sws = getGlobalServiceContext()->getTransportLayer()->connect( - serverAddress, _uri.getSSLMode(), _socketTimeout.value_or(Milliseconds{5000})); + serverAddress, + transientSSLParams ? transport::kEnableSSL : _uri.getSSLMode(), + _socketTimeout.value_or(Milliseconds{5000}), + transientSSLParams); if (!sws.isOK()) { return Status(ErrorCodes::HostUnreachable, str::stream() << "couldn't connect to server " << _serverAddress.toString() diff --git a/src/mongo/client/dbclient_connection.h b/src/mongo/client/dbclient_connection.h index 8532ddf231f..845af96d900 100644 --- a/src/mongo/client/dbclient_connection.h +++ b/src/mongo/client/dbclient_connection.h @@ -112,7 +112,10 @@ public: * @param errmsg any relevant error message will appended to the string * @return false if fails to connect. */ - bool connect(const HostAndPort& server, StringData applicationName, std::string& errmsg); + bool connect(const HostAndPort& server, + StringData applicationName, + std::string& errmsg, + boost::optional<TransientSSLParams> transientSSLParams = boost::none); /** * Semantically equivalent to the previous connect method, but returns a Status @@ -120,7 +123,9 @@ public: * * @param server The server to connect to. */ - virtual Status connect(const HostAndPort& server, StringData applicationName); + virtual Status connect(const HostAndPort& server, + StringData applicationName, + boost::optional<TransientSSLParams> transientSSLParams = boost::none); /** * This version of connect does not run 'isMaster' after creating a TCP connection to the @@ -129,7 +134,8 @@ public: * * @param server The server to connect to. */ - Status connectSocketOnly(const HostAndPort& server); + Status connectSocketOnly(const HostAndPort& server, + boost::optional<TransientSSLParams> transientSSLParams = boost::none); /** * Logs out the connection for the given database. diff --git a/src/mongo/db/repl/tenant_migration_donor_service.cpp b/src/mongo/db/repl/tenant_migration_donor_service.cpp index 827b2c16a88..50c5dcab032 100644 --- a/src/mongo/db/repl/tenant_migration_donor_service.cpp +++ b/src/mongo/db/repl/tenant_migration_donor_service.cpp @@ -84,7 +84,7 @@ bool shouldStopUpdatingDonorStateDoc(Status status, const CancelationToken& toke } bool shouldStopSendingRecipientCommand(Status status, const CancelationToken& token) { - return status.isOK() || token.isCanceled(); + return status.isOK() || !ErrorCodes::isRetriableError(status) || token.isCanceled(); } } // namespace @@ -486,7 +486,10 @@ ExecutorFuture<void> TenantMigrationDonorService::Instance::_sendCommandToRecipi if (!response.isOK()) { return response.status; } - return getStatusFromCommandResult(response.data); + auto commandStatus = getStatusFromCommandResult(response.data); + commandStatus.addContext( + "Tenant migration recipient command failed"); + return commandStatus; }); }); }) diff --git a/src/mongo/db/repl/tenant_migration_recipient_service.cpp b/src/mongo/db/repl/tenant_migration_recipient_service.cpp index 2443da81f08..e2cf7f52888 100644 --- a/src/mongo/db/repl/tenant_migration_recipient_service.cpp +++ b/src/mongo/db/repl/tenant_migration_recipient_service.cpp @@ -67,6 +67,7 @@ constexpr StringData kOplogBufferPrefix = "repl.migration.oplog_"_sd; // A convenient place to set test-specific parameters. MONGO_FAIL_POINT_DEFINE(pauseBeforeRunTenantMigrationRecipientInstance); MONGO_FAIL_POINT_DEFINE(pauseAfterRunTenantMigrationRecipientInstance); +MONGO_FAIL_POINT_DEFINE(skipTenantMigrationRecipientAuth); MONGO_FAIL_POINT_DEFINE(autoRecipientForgetMigration); // Fails before waiting for the state doc to be majority replicated. @@ -267,9 +268,17 @@ OpTime TenantMigrationRecipientService::Instance::waitUntilTimestampIsMajorityCo } std::unique_ptr<DBClientConnection> TenantMigrationRecipientService::Instance::_connectAndAuth( - const HostAndPort& serverAddress, StringData applicationName, BSONObj authParams) { + const HostAndPort& serverAddress, + StringData applicationName, + const TransientSSLParams* transientSSLParams) { std::string errMsg; - auto clientBase = ConnectionString(serverAddress).connect(applicationName, errMsg); + auto clientBase = ConnectionString(serverAddress) + .connect(applicationName, + errMsg, + 0 /* socketTimeout */, + nullptr /* uri */, + nullptr /* apiParameters */, + transientSSLParams); if (!clientBase) { LOGV2_ERROR(4880400, "Failed to connect to migration donor", @@ -278,24 +287,21 @@ std::unique_ptr<DBClientConnection> TenantMigrationRecipientService::Instance::_ "serverAddress"_attr = serverAddress, "applicationName"_attr = applicationName, "error"_attr = errMsg); - uasserted(ErrorCodes::HostNotFound, errMsg); + // TODO (SERVER-53423): Make ConnectString::connect return a status instead of setting error + // message + uasserted(errMsg.find("InvalidSSLConfiguration") != std::string::npos + ? ErrorCodes::InvalidSSLConfiguration + : ErrorCodes::HostUnreachable, + errMsg); } - // Authenticate connection to the donor. - uassertStatusOK(replAuthenticate(clientBase.get()) - .withContext(str::stream() - << "TenantMigrationRecipientService failed to authenticate to " - << serverAddress)); - // ConnectionString::connect() always returns a DBClientConnection in a unique_ptr of // DBClientBase type. std::unique_ptr<DBClientConnection> client( checked_cast<DBClientConnection*>(clientBase.release())); - if (!authParams.isEmpty()) { - client->auth(authParams); - } else { - // Tenant migration in production should always require auth. - uassert(4880405, "No auth data provided to tenant migration", getTestCommandsEnabled()); + + if (MONGO_likely(!skipTenantMigrationRecipientAuth.shouldFail())) { + client->auth(auth::createInternalX509AuthDocument()); } return client; @@ -309,8 +315,7 @@ TenantMigrationRecipientService::Instance::_createAndConnectClients() { "tenantId"_attr = getTenantId(), "migrationId"_attr = getMigrationUUID(), "connectionString"_attr = _donorConnectionString, - "readPreference"_attr = _readPreference, - "authParams"_attr = redact(_authParams)); + "readPreference"_attr = _readPreference); auto connectionStringWithStatus = ConnectionString::parse(_donorConnectionString); if (!connectionStringWithStatus.isOK()) { LOGV2_ERROR(4880403, @@ -322,11 +327,11 @@ TenantMigrationRecipientService::Instance::_createAndConnectClients() { return SemiFuture<ConnectionPair>::makeReady(connectionStringWithStatus.getStatus()); } - auto connectionString = std::move(connectionStringWithStatus.getValue()); - const auto& servers = connectionString.getServers(); + auto donorConnectionString = std::move(connectionStringWithStatus.getValue()); + const auto& servers = donorConnectionString.getServers(); stdx::lock_guard lk(_mutex); _donorReplicaSetMonitor = ReplicaSetMonitor::createIfNeeded( - connectionString.getSetName(), std::set<HostAndPort>(servers.begin(), servers.end())); + donorConnectionString.getSetName(), std::set<HostAndPort>(servers.begin(), servers.end())); // Only ever used to cancel when the setTenantMigrationRecipientInstanceHostTimeout failpoint is // set. @@ -342,7 +347,8 @@ TenantMigrationRecipientService::Instance::_createAndConnectClients() { return _donorReplicaSetMonitor->getHostOrRefresh(_readPreference, getHostCancelSource.token()) .thenRunOn(**_scopedExecutor) - .then([this, self = shared_from_this()](const HostAndPort& serverAddress) { + .then([this, self = shared_from_this(), donorConnectionString]( + const HostAndPort& serverAddress) { // Application name is constructed such that it doesn't exceeds // kMaxApplicationNameByteLength (128 bytes). // "TenantMigration_" (16 bytes) + <tenantId> (61 bytes) + "_" (1 byte) + @@ -352,14 +358,22 @@ TenantMigrationRecipientService::Instance::_createAndConnectClients() { // character long, the maximum length of tenantId can only be 61 bytes. auto applicationName = "TenantMigration_" + getTenantId() + "_" + getMigrationUUID().toString(); - auto client = _connectAndAuth(serverAddress, applicationName, _authParams); + + auto recipientCertificate = _stateDoc.getRecipientCertificateForDonor(); + auto recipientSSLClusterPEMPayload = recipientCertificate.getCertificate().toString() + + "\n" + recipientCertificate.getPrivateKey().toString(); + const TransientSSLParams transientSSLParams{donorConnectionString, + std::move(recipientSSLClusterPEMPayload)}; + + auto client = _connectAndAuth(serverAddress, applicationName, &transientSSLParams); // Application name is constructed such that it doesn't exceeds // kMaxApplicationNameByteLength (128 bytes). // "TenantMigration_" (16 bytes) + <tenantId> (61 bytes) + "_" (1 byte) + // <migrationUuid> (36 bytes) + _oplogFetcher" (13 bytes) = 127 bytes length. applicationName += "_oplogFetcher"; - auto oplogFetcherClient = _connectAndAuth(serverAddress, applicationName, _authParams); + auto oplogFetcherClient = + _connectAndAuth(serverAddress, applicationName, &transientSSLParams); return ConnectionPair(std::move(client), std::move(oplogFetcherClient)); }) .onError( diff --git a/src/mongo/db/repl/tenant_migration_recipient_service.h b/src/mongo/db/repl/tenant_migration_recipient_service.h index 320eb9b5366..d0b4f13a4f6 100644 --- a/src/mongo/db/repl/tenant_migration_recipient_service.h +++ b/src/mongo/db/repl/tenant_migration_recipient_service.h @@ -286,9 +286,10 @@ public: * non-empty. Throws a user assertion on failure. * */ - std::unique_ptr<DBClientConnection> _connectAndAuth(const HostAndPort& serverAddress, - StringData applicationName, - BSONObj authParams); + std::unique_ptr<DBClientConnection> _connectAndAuth( + const HostAndPort& serverAddress, + StringData applicationName, + const TransientSSLParams* transientSSLParams); /** * Creates and connects both the oplog fetcher client and the client used for other @@ -387,8 +388,6 @@ public: const UUID _migrationUuid; // (R) const std::string _donorConnectionString; // (R) const ReadPreferenceSetting _readPreference; // (R) - // TODO(SERVER-50670): Populate authParams - const BSONObj _authParams; // (M) std::shared_ptr<ReplicaSetMonitor> _donorReplicaSetMonitor; // (M) diff --git a/src/mongo/db/repl/tenant_migration_recipient_service_test.cpp b/src/mongo/db/repl/tenant_migration_recipient_service_test.cpp index 92d620d41d3..fed2bcff73e 100644 --- a/src/mongo/db/repl/tenant_migration_recipient_service_test.cpp +++ b/src/mongo/db/repl/tenant_migration_recipient_service_test.cpp @@ -187,9 +187,16 @@ public: _service = _registry->lookupServiceByName( TenantMigrationRecipientService::kTenantMigrationRecipientServiceName); ASSERT(_service); + + // MockReplicaSet uses custom connection string which does not support auth. + auto authFp = globalFailPointRegistry().find("skipTenantMigrationRecipientAuth"); + authFp->setMode(FailPoint::alwaysOn); } void tearDown() override { + auto authFp = globalFailPointRegistry().find("skipTenantMigrationRecipientAuth"); + authFp->setMode(FailPoint::off); + WaitForMajorityService::get(getServiceContext()).shutDown(); _registry->onShutdown(); diff --git a/src/mongo/dbtests/mock/mock_dbclient_connection.h b/src/mongo/dbtests/mock/mock_dbclient_connection.h index 37190d624cc..0d4180ed952 100644 --- a/src/mongo/dbtests/mock/mock_dbclient_connection.h +++ b/src/mongo/dbtests/mock/mock_dbclient_connection.h @@ -107,7 +107,9 @@ public: bool connect(const char* hostName, StringData applicationName, std::string& errmsg); - Status connect(const HostAndPort& host, StringData applicationName) override { + Status connect(const HostAndPort& host, + StringData applicationName, + boost::optional<TransientSSLParams> transientSSLParams = boost::none) override { std::string errmsg; if (!connect(host.toString().c_str(), applicationName, errmsg)) { return {ErrorCodes::HostUnreachable, errmsg}; diff --git a/src/mongo/executor/network_interface_tl.cpp b/src/mongo/executor/network_interface_tl.cpp index d5389deb331..14460069e5e 100644 --- a/src/mongo/executor/network_interface_tl.cpp +++ b/src/mongo/executor/network_interface_tl.cpp @@ -130,7 +130,7 @@ NetworkInterfaceTL::NetworkInterfaceTL(std::string instanceName, #ifdef MONGO_CONFIG_SSL if (_connPoolOpts.transientSSLParams) { auto statusOrContext = - _tl->createTransientSSLContext(_connPoolOpts.transientSSLParams.get(), nullptr); + _tl->createTransientSSLContext(_connPoolOpts.transientSSLParams.get()); uassertStatusOK(statusOrContext.getStatus()); transientSSLContext = std::move(statusOrContext.getValue()); } diff --git a/src/mongo/transport/transport_layer.h b/src/mongo/transport/transport_layer.h index 10745f16198..a3f4cc326a8 100644 --- a/src/mongo/transport/transport_layer.h +++ b/src/mongo/transport/transport_layer.h @@ -87,9 +87,11 @@ public: virtual ~TransportLayer() = default; - virtual StatusWith<SessionHandle> connect(HostAndPort peer, - ConnectSSLMode sslMode, - Milliseconds timeout) = 0; + virtual StatusWith<SessionHandle> connect( + HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + boost::optional<TransientSSLParams> transientSSLParams = boost::none) = 0; virtual Future<SessionHandle> asyncConnect( HostAndPort peer, @@ -142,8 +144,7 @@ public: * used. */ virtual StatusWith<std::shared_ptr<const transport::SSLConnectionContext>> - createTransientSSLContext(const TransientSSLParams& transientSSLParams, - const SSLManagerInterface* optionalManager) = 0; + createTransientSSLContext(const TransientSSLParams& transientSSLParams) = 0; #endif private: diff --git a/src/mongo/transport/transport_layer_asio.cpp b/src/mongo/transport/transport_layer_asio.cpp index 02d7e7b6655..60d66c0c04b 100644 --- a/src/mongo/transport/transport_layer_asio.cpp +++ b/src/mongo/transport/transport_layer_asio.cpp @@ -454,9 +454,19 @@ Status makeConnectError(Status status, const HostAndPort& peer, const WrappedEnd } -StatusWith<SessionHandle> TransportLayerASIO::connect(HostAndPort peer, - ConnectSSLMode sslMode, - Milliseconds timeout) { +StatusWith<SessionHandle> TransportLayerASIO::connect( + HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + boost::optional<TransientSSLParams> transientSSLParams) { + if (transientSSLParams) { + uassert(ErrorCodes::InvalidSSLConfiguration, + "Specified transient SSL params but connection SSL mode is not set", + sslMode == kEnableSSL); + LOGV2_DEBUG( + 5270701, 2, "Connecting to peer using transient SSL connection", "peer"_attr = peer); + } + std::error_code ec; GenericSocket sock(*_egressReactor); WrappedResolver resolver(*_egressReactor); @@ -473,7 +483,7 @@ StatusWith<SessionHandle> TransportLayerASIO::connect(HostAndPort peer, } auto endpoints = std::move(swEndpoints.getValue()); - auto sws = _doSyncConnect(endpoints.front(), peer, timeout); + auto sws = _doSyncConnect(endpoints.front(), peer, timeout, transientSSLParams); if (!sws.isOK()) { return sws.getStatus(); } @@ -515,7 +525,10 @@ StatusWith<SessionHandle> TransportLayerASIO::connect(HostAndPort peer, template <typename Endpoint> StatusWith<TransportLayerASIO::ASIOSessionHandle> TransportLayerASIO::_doSyncConnect( - Endpoint endpoint, const HostAndPort& peer, const Milliseconds& timeout) { + Endpoint endpoint, + const HostAndPort& peer, + const Milliseconds& timeout, + boost::optional<TransientSSLParams> transientSSLParams) { GenericSocket sock(*_egressReactor); std::error_code ec; @@ -563,7 +576,16 @@ StatusWith<TransportLayerASIO::ASIOSessionHandle> TransportLayerASIO::_doSyncCon sock.non_blocking(false); try { - return std::make_shared<ASIOSession>(this, std::move(sock), false, *endpoint); + std::shared_ptr<const transport::SSLConnectionContext> transientSSLContext; +#ifdef MONGO_CONFIG_SSL + if (transientSSLParams) { + auto statusOrContext = createTransientSSLContext(transientSSLParams.get()); + uassertStatusOK(statusOrContext.getStatus()); + transientSSLContext = std::move(statusOrContext.getValue()); + } +#endif + return std::make_shared<ASIOSession>( + this, std::move(sock), false, *endpoint, transientSSLContext); } catch (const DBException& e) { return e.toStatus(); } @@ -1277,9 +1299,7 @@ TransportLayerASIO::_createSSLContext(std::shared_ptr<SSLManagerInterface>& mana } StatusWith<std::shared_ptr<const transport::SSLConnectionContext>> -TransportLayerASIO::createTransientSSLContext(const TransientSSLParams& transientSSLParams, - const SSLManagerInterface* optionalManager) { - +TransportLayerASIO::createTransientSSLContext(const TransientSSLParams& transientSSLParams) { auto manager = getSSLManager(); if (!manager) { return Status(ErrorCodes::InvalidSSLConfiguration, "TransportLayerASIO has no SSL manager"); diff --git a/src/mongo/transport/transport_layer_asio.h b/src/mongo/transport/transport_layer_asio.h index 6eac75fae2f..04d2d136427 100644 --- a/src/mongo/transport/transport_layer_asio.h +++ b/src/mongo/transport/transport_layer_asio.h @@ -122,7 +122,8 @@ public: StatusWith<SessionHandle> connect(HostAndPort peer, ConnectSSLMode sslMode, - Milliseconds timeout) final; + Milliseconds timeout, + boost::optional<TransientSSLParams> transientSSLParams) final; Future<SessionHandle> asyncConnect( HostAndPort peer, @@ -166,8 +167,7 @@ public: * used. */ StatusWith<std::shared_ptr<const transport::SSLConnectionContext>> createTransientSSLContext( - const TransientSSLParams& transientSSLParams, - const SSLManagerInterface* optionalManager) override; + const TransientSSLParams& transientSSLParams) override; #endif private: @@ -182,9 +182,11 @@ private: void _acceptConnection(GenericAcceptor& acceptor); template <typename Endpoint> - StatusWith<ASIOSessionHandle> _doSyncConnect(Endpoint endpoint, - const HostAndPort& peer, - const Milliseconds& timeout); + StatusWith<ASIOSessionHandle> _doSyncConnect( + Endpoint endpoint, + const HostAndPort& peer, + const Milliseconds& timeout, + boost::optional<TransientSSLParams> transientSSLParams); StatusWith<std::shared_ptr<const transport::SSLConnectionContext>> _createSSLContext( std::shared_ptr<SSLManagerInterface>& manager, diff --git a/src/mongo/transport/transport_layer_manager.cpp b/src/mongo/transport/transport_layer_manager.cpp index b3a183cebd6..536c7b39148 100644 --- a/src/mongo/transport/transport_layer_manager.cpp +++ b/src/mongo/transport/transport_layer_manager.cpp @@ -61,10 +61,12 @@ void TransportLayerManager::_foreach(Callable&& cb) const { } } -StatusWith<SessionHandle> TransportLayerManager::connect(HostAndPort peer, - ConnectSSLMode sslMode, - Milliseconds timeout) { - return _tls.front()->connect(peer, sslMode, timeout); +StatusWith<SessionHandle> TransportLayerManager::connect( + HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + boost::optional<TransientSSLParams> transientSSLParams) { + return _tls.front()->connect(peer, sslMode, timeout, transientSSLParams); } Future<SessionHandle> TransportLayerManager::asyncConnect( @@ -156,13 +158,12 @@ Status TransportLayerManager::rotateCertificates(std::shared_ptr<SSLManagerInter } StatusWith<std::shared_ptr<const transport::SSLConnectionContext>> -TransportLayerManager::createTransientSSLContext(const TransientSSLParams& transientSSLParams, - const SSLManagerInterface* optionalManager) { +TransportLayerManager::createTransientSSLContext(const TransientSSLParams& transientSSLParams) { Status firstError(ErrorCodes::InvalidSSLConfiguration, "Failure creating transient SSL context"); for (auto&& tl : _tls) { - auto statusOrContext = tl->createTransientSSLContext(transientSSLParams, optionalManager); + auto statusOrContext = tl->createTransientSSLContext(transientSSLParams); if (statusOrContext.isOK()) { return std::move(statusOrContext.getValue()); } diff --git a/src/mongo/transport/transport_layer_manager.h b/src/mongo/transport/transport_layer_manager.h index b561a67f591..3cc6538a6c5 100644 --- a/src/mongo/transport/transport_layer_manager.h +++ b/src/mongo/transport/transport_layer_manager.h @@ -63,9 +63,11 @@ public: explicit TransportLayerManager(const WireSpec& wireSpec = WireSpec::instance()) : TransportLayer(wireSpec) {} - StatusWith<SessionHandle> connect(HostAndPort peer, - ConnectSSLMode sslMode, - Milliseconds timeout) override; + StatusWith<SessionHandle> connect( + HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + boost::optional<TransientSSLParams> transientSSLParams) override; Future<SessionHandle> asyncConnect( HostAndPort peer, ConnectSSLMode sslMode, @@ -109,8 +111,7 @@ public: bool asyncOCSPStaple) override; StatusWith<std::shared_ptr<const transport::SSLConnectionContext>> createTransientSSLContext( - const TransientSSLParams& transientSSLParams, - const SSLManagerInterface* optionalManager) override; + const TransientSSLParams& transientSSLParams) override; #endif private: template <typename Callable> diff --git a/src/mongo/transport/transport_layer_mock.cpp b/src/mongo/transport/transport_layer_mock.cpp index af00a2bca22..a640b2d49ad 100644 --- a/src/mongo/transport/transport_layer_mock.cpp +++ b/src/mongo/transport/transport_layer_mock.cpp @@ -62,9 +62,11 @@ bool TransportLayerMock::owns(Session::Id id) { return _sessions.count(id) > 0; } -StatusWith<SessionHandle> TransportLayerMock::connect(HostAndPort peer, - ConnectSSLMode sslMode, - Milliseconds timeout) { +StatusWith<SessionHandle> TransportLayerMock::connect( + HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + boost::optional<TransientSSLParams> transientSSLParams) { MONGO_UNREACHABLE; } @@ -106,8 +108,7 @@ TransportLayerMock::~TransportLayerMock() { #ifdef MONGO_CONFIG_SSL StatusWith<std::shared_ptr<const transport::SSLConnectionContext>> -TransportLayerMock::createTransientSSLContext(const TransientSSLParams& transientSSLParams, - const SSLManagerInterface* optionalManager) { +TransportLayerMock::createTransientSSLContext(const TransientSSLParams& transientSSLParams) { return Status(ErrorCodes::InvalidSSLConfiguration, "Failure creating transient SSL context"); } diff --git a/src/mongo/transport/transport_layer_mock.h b/src/mongo/transport/transport_layer_mock.h index a376f66fc95..cd07c75a68c 100644 --- a/src/mongo/transport/transport_layer_mock.h +++ b/src/mongo/transport/transport_layer_mock.h @@ -57,9 +57,11 @@ public: SessionHandle get(Session::Id id); bool owns(Session::Id id); - StatusWith<SessionHandle> connect(HostAndPort peer, - ConnectSSLMode sslMode, - Milliseconds timeout) override; + StatusWith<SessionHandle> connect( + HostAndPort peer, + ConnectSSLMode sslMode, + Milliseconds timeout, + boost::optional<TransientSSLParams> transientSSLParams) override; Future<SessionHandle> asyncConnect( HostAndPort peer, ConnectSSLMode sslMode, @@ -85,8 +87,7 @@ public: } StatusWith<std::shared_ptr<const transport::SSLConnectionContext>> createTransientSSLContext( - const TransientSSLParams& transientSSLParams, - const SSLManagerInterface* optionalManager) override; + const TransientSSLParams& transientSSLParams) override; #endif private: diff --git a/src/mongo/util/net/ssl_manager_test.cpp b/src/mongo/util/net/ssl_manager_test.cpp index 60349b4b2c2..a435dc05d2c 100644 --- a/src/mongo/util/net/ssl_manager_test.cpp +++ b/src/mongo/util/net/ssl_manager_test.cpp @@ -625,7 +625,7 @@ TEST(SSLManager, TransientSSLParams) { transientSSLParams.sslClusterPEMPayload = loadFile("jstests/libs/client.pem"); transientSSLParams.targetedClusterConnectionString = ConnectionString::forLocal(); - auto result = tla.createTransientSSLContext(transientSSLParams, manager.get()); + auto result = tla.createTransientSSLContext(transientSSLParams); // This will fail because we need to rotate certificates first to // initialize the default SSL context inside TransportLayerASIO. @@ -634,7 +634,7 @@ TEST(SSLManager, TransientSSLParams) { // Init the transport properly. uassertStatusOK(tla.rotateCertificates(manager, false /* asyncOCSPStaple */)); - result = tla.createTransientSSLContext(transientSSLParams, manager.get()); + result = tla.createTransientSSLContext(transientSSLParams); uassertStatusOK(result.getStatus()); } |