From a8bfcc13011c5e859a10e56ce882a0d53a0a2031 Mon Sep 17 00:00:00 2001 From: Jonathan Reams Date: Mon, 5 Nov 2018 18:22:59 -0500 Subject: SERVER-32978 Advertise SCRAM-SHA-256 authentication for the internal user --- src/mongo/client/sasl_client_authenticate_impl.cpp | 81 +++++++++------------- 1 file changed, 34 insertions(+), 47 deletions(-) (limited to 'src/mongo/client/sasl_client_authenticate_impl.cpp') diff --git a/src/mongo/client/sasl_client_authenticate_impl.cpp b/src/mongo/client/sasl_client_authenticate_impl.cpp index 6623966f612..2940e0eb809 100644 --- a/src/mongo/client/sasl_client_authenticate_impl.cpp +++ b/src/mongo/client/sasl_client_authenticate_impl.cpp @@ -167,19 +167,18 @@ Status configureSession(SaslClientSession* session, return session->initialize(); } -void asyncSaslConversation(auth::RunCommandHook runCommand, - const std::shared_ptr& session, - const BSONObj& saslCommandPrefix, - const BSONObj& inputObj, - std::string targetDatabase, - int saslLogLevel, - auth::AuthCompletionHandler handler) { +Future asyncSaslConversation(auth::RunCommandHook runCommand, + const std::shared_ptr& session, + const BSONObj& saslCommandPrefix, + const BSONObj& inputObj, + std::string targetDatabase, + int saslLogLevel) { // Extract payload from previous step std::string payload; BSONType type; auto status = saslExtractPayload(inputObj, &payload, &type); if (!status.isOK()) - return handler(std::move(status)); + return status; LOG(saslLogLevel) << "sasl client input: " << base64::encode(payload) << endl; @@ -187,7 +186,7 @@ void asyncSaslConversation(auth::RunCommandHook runCommand, std::string responsePayload; status = session->step(payload, &responsePayload); if (!status.isOK()) - return handler(std::move(status)); + return status; LOG(saslLogLevel) << "sasl client output: " << base64::encode(responsePayload) << endl; @@ -202,41 +201,31 @@ void asyncSaslConversation(auth::RunCommandHook runCommand, if (!conversationId.eoo()) commandBuilder.append(conversationId); - auto request = RemoteCommandRequest(); - request.dbname = targetDatabase; - request.cmdObj = commandBuilder.obj(); - // Asynchronously continue the conversation - runCommand( - request, - [runCommand, session, targetDatabase, saslLogLevel, handler](auth::AuthResponse response) { - if (!response.isOK()) { - return handler(std::move(response)); - } - - auto serverResponse = response.data.getOwned(); + return runCommand(OpMsgRequest::fromDBAndBody(targetDatabase, commandBuilder.obj())) + .then([runCommand, session, targetDatabase, saslLogLevel]( + BSONObj serverResponse) -> Future { auto status = getStatusFromCommandResult(serverResponse); if (!status.isOK()) { - return handler(status); + return status; } // Exit if we have finished if (session->isDone()) { bool isServerDone = serverResponse[saslCommandDoneFieldName].trueValue(); if (!isServerDone) { - return handler({ErrorCodes::ProtocolError, "Client finished before server."}); + return Status(ErrorCodes::ProtocolError, "Client finished before server."); } - return handler(std::move(response)); + return Status::OK(); } - BSONObj saslFollowupCommandPrefix = BSON(saslContinueCommandName << 1); - asyncSaslConversation(runCommand, - session, - std::move(saslFollowupCommandPrefix), - std::move(serverResponse), - std::move(targetDatabase), - saslLogLevel, - handler); + static const BSONObj saslFollowupCommandPrefix = BSON(saslContinueCommandName << 1); + return asyncSaslConversation(runCommand, + session, + std::move(saslFollowupCommandPrefix), + std::move(serverResponse), + std::move(targetDatabase), + saslLogLevel); }); } @@ -244,26 +233,25 @@ void asyncSaslConversation(auth::RunCommandHook runCommand, * Driver for the client side of a sasl authentication session, conducted synchronously over * "client". */ -void saslClientAuthenticateImpl(auth::RunCommandHook runCommand, - const HostAndPort& hostname, - const BSONObj& saslParameters, - auth::AuthCompletionHandler handler) { +Future saslClientAuthenticateImpl(auth::RunCommandHook runCommand, + const HostAndPort& hostname, + const BSONObj& saslParameters) { int saslLogLevel = getSaslClientLogLevel(saslParameters); std::string targetDatabase; try { Status status = bsonExtractStringFieldWithDefault( saslParameters, saslCommandUserDBFieldName, saslDefaultDBName, &targetDatabase); if (!status.isOK()) - return handler(std::move(status)); + return status; } catch (const DBException& ex) { - return handler(ex.toStatus()); + return ex.toStatus(); } std::string mechanism; Status status = bsonExtractStringField(saslParameters, saslCommandMechanismFieldName, &mechanism); if (!status.isOK()) { - return handler(std::move(status)); + return status; } // NOTE: this must be a shared_ptr so that we can capture it in a lambda later on. @@ -272,19 +260,18 @@ void saslClientAuthenticateImpl(auth::RunCommandHook runCommand, status = configureSession(session.get(), hostname, targetDatabase, saslParameters); if (!status.isOK()) - return handler(std::move(status)); + return status; BSONObj saslFirstCommandPrefix = BSON(saslStartCommandName << 1 << saslCommandMechanismFieldName << session->getParameter(SaslClientSession::parameterMechanism)); BSONObj inputObj = BSON(saslCommandPayloadFieldName << ""); - asyncSaslConversation(runCommand, - session, - std::move(saslFirstCommandPrefix), - std::move(inputObj), - targetDatabase, - saslLogLevel, - handler); + return asyncSaslConversation(runCommand, + session, + std::move(saslFirstCommandPrefix), + std::move(inputObj), + targetDatabase, + saslLogLevel); } MONGO_INITIALIZER(SaslClientAuthenticateFunction)(InitializerContext* context) { -- cgit v1.2.1