diff options
author | Sara Golemon <sara.golemon@mongodb.com> | 2021-01-08 22:49:02 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-01-12 20:30:56 +0000 |
commit | d9dbede70d8f429cb46876ad0e2462d3d489eccc (patch) | |
tree | f5c4d4fb2c6a0641622451c95c41781650ef6732 /src/mongo | |
parent | a67689bc26cad79cc008833458aaad496fece9c4 (diff) | |
download | mongo-d9dbede70d8f429cb46876ad0e2462d3d489eccc.tar.gz |
SERVER-53154 Convert saslStart/saslContinue to TypedCommand
Diffstat (limited to 'src/mongo')
-rw-r--r-- | src/mongo/db/auth/SConscript | 2 | ||||
-rw-r--r-- | src/mongo/db/auth/sasl_commands.cpp | 359 | ||||
-rw-r--r-- | src/mongo/db/auth/sasl_commands.idl | 101 | ||||
-rw-r--r-- | src/mongo/db/auth/sasl_payload.cpp | 80 | ||||
-rw-r--r-- | src/mongo/db/auth/sasl_payload.h | 76 |
5 files changed, 392 insertions, 226 deletions
diff --git a/src/mongo/db/auth/SConscript b/src/mongo/db/auth/SConscript index 8a5498f04ba..865b4d8c474 100644 --- a/src/mongo/db/auth/SConscript +++ b/src/mongo/db/auth/SConscript @@ -224,6 +224,8 @@ env.Library( source=[ 'authz_session_external_state_server_common.cpp', 'sasl_commands.cpp', + 'sasl_commands.idl', + 'sasl_payload.cpp', 'enable_localhost_auth_bypass_parameter.idl', ], LIBDEPS=[ diff --git a/src/mongo/db/auth/sasl_commands.cpp b/src/mongo/db/auth/sasl_commands.cpp index e5e00a6b160..72108f01a0e 100644 --- a/src/mongo/db/auth/sasl_commands.cpp +++ b/src/mongo/db/auth/sasl_commands.cpp @@ -47,6 +47,7 @@ #include "mongo/db/auth/authz_manager_external_state_mock.h" #include "mongo/db/auth/authz_session_external_state_mock.h" #include "mongo/db/auth/sasl_command_constants.h" +#include "mongo/db/auth/sasl_commands_gen.h" #include "mongo/db/auth/sasl_options.h" #include "mongo/db/client.h" #include "mongo/db/commands.h" @@ -59,143 +60,90 @@ #include "mongo/util/str.h" namespace mongo { +namespace auth { namespace { using std::stringstream; -const bool autoAuthorizeDefault = true; - -class CmdSaslStart : public BasicCommand { +class CmdSaslStart : public SaslStartCmdVersion1Gen<CmdSaslStart> { public: - static constexpr StringData kPayloadField = "payload"_sd; + std::set<StringData> sensitiveFieldNames() const final { + return {SaslStartCommand::kPayloadFieldName}; + } - CmdSaslStart(); - virtual ~CmdSaslStart(); + class Invocation final : public InvocationBaseGen { + public: + using InvocationBaseGen::InvocationBaseGen; - const std::set<std::string>& apiVersions() const { - return kApiVersions1; - } + bool supportsWriteConcern() const final { + return false; + } - virtual void addRequiredPrivileges(const std::string&, - const BSONObj&, - std::vector<Privilege>*) const {} + NamespaceString ns() const final { + return NamespaceString(request().getDbName()); + } - std::set<StringData> sensitiveFieldNames() const final { - return {kPayloadField}; - } + void doCheckAuthorization(OperationContext*) const final {} - virtual bool run(OperationContext* opCtx, - const std::string& db, - const BSONObj& cmdObj, - BSONObjBuilder& result); + Reply typedRun(OperationContext* opCtx); + }; - virtual std::string help() const override; - virtual bool supportsWriteConcern(const BSONObj& cmd) const override { - return false; + std::string help() const final { + return "First step in a SASL authentication conversation."; } - AllowedOnSecondary secondaryAllowed(ServiceContext*) const override { + + AllowedOnSecondary secondaryAllowed(ServiceContext*) const final { return AllowedOnSecondary::kAlways; } + bool requiresAuth() const override { return false; } -}; +} cmdSaslStart; -class CmdSaslContinue : public BasicCommand { +class CmdSaslContinue : public SaslContinueCmdVersion1Gen<CmdSaslContinue> { public: - static constexpr StringData kPayloadField = "payload"_sd; - - CmdSaslContinue(); - virtual ~CmdSaslContinue(); - - const std::set<std::string>& apiVersions() const { - return kApiVersions1; - } - virtual void addRequiredPrivileges(const std::string&, - const BSONObj&, - std::vector<Privilege>*) const {} - std::set<StringData> sensitiveFieldNames() const final { - return {kPayloadField}; - } - - virtual bool run(OperationContext* opCtx, - const std::string& db, - const BSONObj& cmdObj, - BSONObjBuilder& result); - - std::string help() const override; - virtual bool supportsWriteConcern(const BSONObj& cmd) const override { - return false; - } - AllowedOnSecondary secondaryAllowed(ServiceContext*) const override { - return AllowedOnSecondary::kAlways; - } - bool requiresAuth() const override { - return false; + return {SaslContinueCommand::kPayloadFieldName}; } -}; -CmdSaslStart cmdSaslStart; -CmdSaslContinue cmdSaslContinue; + class Invocation final : public InvocationBaseGen { + public: + using InvocationBaseGen::InvocationBaseGen; -Status buildResponse(const AuthenticationSession* session, - const std::string& responsePayload, - BSONType responsePayloadType, - BSONObjBuilder* result) { - result->appendIntOrLL(saslCommandConversationIdFieldName, 1); - result->appendBool(saslCommandDoneFieldName, session->getMechanism().isSuccess()); + bool supportsWriteConcern() const final { + return false; + } - if (responsePayload.size() > size_t(std::numeric_limits<int>::max())) { - return Status(ErrorCodes::InvalidLength, "Response payload too long"); - } - if (responsePayloadType == BinData) { - result->appendBinData(saslCommandPayloadFieldName, - int(responsePayload.size()), - BinDataGeneral, - responsePayload.data()); - } else if (responsePayloadType == String) { - result->append(saslCommandPayloadFieldName, base64::encode(responsePayload)); - } else { - fassertFailed(4003); - } + NamespaceString ns() const final { + return NamespaceString(request().getDbName()); + } - return Status::OK(); -} + void doCheckAuthorization(OperationContext*) const final {} -Status extractConversationId(const BSONObj& cmdObj, int64_t* conversationId) { - BSONElement element; - Status status = bsonExtractField(cmdObj, saslCommandConversationIdFieldName, &element); - if (!status.isOK()) - return status; + Reply typedRun(OperationContext* opCtx); + }; - if (!element.isNumber()) { - return Status(ErrorCodes::TypeMismatch, - str::stream() << "Wrong type for field; expected number for " << element); + std::string help() const final { + return "Subsequent steps in a SASL authentication conversation."; } - *conversationId = element.numberLong(); - return Status::OK(); -} -Status extractMechanism(const BSONObj& cmdObj, std::string* mechanism) { - return bsonExtractStringField(cmdObj, saslCommandMechanismFieldName, mechanism); -} + AllowedOnSecondary secondaryAllowed(ServiceContext*) const final { + return AllowedOnSecondary::kAlways; + } -Status doSaslStep(OperationContext* opCtx, - AuthenticationSession* session, - const BSONObj& cmdObj, - BSONObjBuilder* result) { - std::string payload; - BSONType type = EOO; - Status status = saslExtractPayload(cmdObj, &payload, &type); - if (!status.isOK()) { - return status; + bool requiresAuth() const final { + return false; } +} cmdSaslContinue; +StatusWith<SaslReply> doSaslStep(OperationContext* opCtx, + const SaslPayload& payload, + AuthenticationSession* session) try { auto& mechanism = session->getMechanism(); // Passing in a payload and extracting a responsePayload - StatusWith<std::string> swResponse = mechanism.step(opCtx, payload); + StatusWith<std::string> swResponse = mechanism.step(opCtx, payload.get()); if (!swResponse.isOK()) { int64_t dLevel = 0; @@ -221,18 +169,10 @@ Status doSaslStep(OperationContext* opCtx, return AuthorizationManager::authenticationFailedStatus; } - status = buildResponse(session, swResponse.getValue(), type, result); - if (!status.isOK()) { - return status; - } - if (mechanism.isSuccess()) { UserName userName(mechanism.getPrincipalName(), mechanism.getAuthenticationDatabase()); - status = - AuthorizationSession::get(opCtx->getClient())->addAndAuthorizeUser(opCtx, userName); - if (!status.isOK()) { - return status; - } + uassertStatusOK( + AuthorizationSession::get(opCtx->getClient())->addAndAuthorizeUser(opCtx, userName)); if (!serverGlobalParams.quiet.load()) { LOGV2(20250, @@ -245,98 +185,63 @@ Status doSaslStep(OperationContext* opCtx, "remote"_attr = opCtx->getClient()->session()->remote()); } if (session->isSpeculative()) { - status = authCounter.incSpeculativeAuthenticateSuccessful( - mechanism.mechanismName().toString()); + uassertStatusOK(authCounter.incSpeculativeAuthenticateSuccessful( + mechanism.mechanismName().toString())); } } - return status; -} -StatusWith<std::unique_ptr<AuthenticationSession>> doSaslStart(OperationContext* opCtx, - const std::string& db, - const BSONObj& cmdObj, - BSONObjBuilder* result, - std::string* principalName, - bool speculative) { - bool autoAuthorize = false; - Status status = bsonExtractBooleanFieldWithDefault( - cmdObj, saslCommandAutoAuthorizeFieldName, autoAuthorizeDefault, &autoAuthorize); - if (!status.isOK()) - return status; - - std::string mechanismName; - status = extractMechanism(cmdObj, &mechanismName); - if (!status.isOK()) - return status; - - StatusWith<std::unique_ptr<ServerMechanismBase>> swMech = - SASLServerMechanismRegistry::get(opCtx->getServiceContext()) - .getServerMechanism(mechanismName, db); + SaslReply reply; + reply.setConversationId(1); + reply.setDone(mechanism.isSuccess()); - if (!swMech.isOK()) { - return swMech.getStatus(); - } + SaslPayload replyPayload(swResponse.getValue()); + replyPayload.serializeAsBase64(payload.getSerializeAsBase64()); + reply.setPayload(std::move(replyPayload)); - auto session = - std::make_unique<AuthenticationSession>(std::move(swMech.getValue()), speculative); + return reply; +} catch (const DBException& ex) { + return ex.toStatus(); +} - if (speculative && - !session->getMechanism().properties().hasAllProperties( - SecurityPropertySet({SecurityProperty::kNoPlainText}))) { - return {ErrorCodes::BadValue, - "Plaintext mechanisms may not be used with speculativeSaslStart"}; - } +SaslReply doSaslStart(OperationContext* opCtx, + const SaslStartCommand& request, + bool speculative, + std::string* principalName, + std::unique_ptr<AuthenticationSession>* session) { + auto mechanism = uassertStatusOK( + SASLServerMechanismRegistry::get(opCtx->getServiceContext()) + .getServerMechanism(request.getMechanism(), request.getDbName().toString())); - auto options = cmdObj["options"]; - if (!options.eoo()) { - if (options.type() != Object) { - return {ErrorCodes::BadValue, "saslStart.options must be an object"}; - } - status = session->setOptions(options.Obj()); - if (!status.isOK()) { - return status; - } - } + uassert(ErrorCodes::BadValue, + "Plaintext mechanisms may not be used with speculativeSaslStart", + !speculative || + mechanism->properties().hasAllProperties( + SecurityPropertySet({SecurityProperty::kNoPlainText}))); - Status statusStep = doSaslStep(opCtx, session.get(), cmdObj, result); + auto newSession = std::make_unique<AuthenticationSession>(std::move(mechanism), speculative); - if (!statusStep.isOK() || session->getMechanism().isSuccess()) { - // Only attempt to populate principal name if we're done (successfully or not). - *principalName = session->getMechanism().getPrincipalName().toString(); + if (auto options = request.getOptions()) { + uassertStatusOK(newSession->setOptions(options->getOwned())); } - if (!statusStep.isOK()) { - return statusStep; + auto swReply = doSaslStep(opCtx, request.getPayload(), newSession.get()); + if (!swReply.isOK() || newSession->getMechanism().isSuccess()) { + // Only attempt to populate principal name if we're done (successfully or not). + *principalName = newSession->getMechanism().getPrincipalName().toString(); } - return std::move(session); -} - -Status doSaslContinue(OperationContext* opCtx, - AuthenticationSession* session, - const BSONObj& cmdObj, - BSONObjBuilder* result) { - int64_t conversationId = 0; - Status status = extractConversationId(cmdObj, &conversationId); - if (!status.isOK()) - return status; - if (conversationId != 1) - return Status(ErrorCodes::ProtocolError, "sasl: Mismatched conversation id"); - - return doSaslStep(opCtx, session, cmdObj, result); + auto reply = uassertStatusOK(swReply); + session->reset(newSession.release()); + return reply; } -bool runSaslStart(OperationContext* opCtx, - const std::string& db, - const BSONObj& cmdObj, - BSONObjBuilder& result, - bool speculative) { +SaslReply runSaslStart(OperationContext* opCtx, const SaslStartCommand& request, bool speculative) { opCtx->markKillOnClientDisconnect(); - Client* client = opCtx->getClient(); + auto client = opCtx->getClient(); AuthenticationSession::set(client, std::unique_ptr<AuthenticationSession>()); - std::string mechanismName; - uassertStatusOK(extractMechanism(cmdObj, &mechanismName)); + auto db = request.getDbName(); + auto mechanismName = request.getMechanism().toString(); auto status = authCounter.incAuthenticateReceived(mechanismName); if (!status.isOK()) { @@ -345,10 +250,12 @@ bool runSaslStart(OperationContext* opCtx, MONGO_UNREACHABLE; } + SaslReply reply; std::string principalName; try { - auto session = - uassertStatusOK(doSaslStart(opCtx, db, cmdObj, &result, &principalName, speculative)); + std::unique_ptr<AuthenticationSession> session; + reply = doSaslStart(opCtx, request, speculative, &principalName, &session); + const bool isClusterMember = session->getMechanism().isClusterMember(); if (isClusterMember) { uassertStatusOK(authCounter.incClusterAuthenticateReceived(mechanismName)); @@ -359,7 +266,7 @@ bool runSaslStart(OperationContext* opCtx, uassertStatusOK(authCounter.incClusterAuthenticateSuccessful(mechanismName)); } audit::logAuthentication( - client, mechanismName, UserName(principalName, db), Status::OK().code()); + client, mechanismName, UserName(principalName, db), ErrorCodes::OK); } else { AuthenticationSession::swap(client, session); } @@ -368,36 +275,18 @@ bool runSaslStart(OperationContext* opCtx, throw; } - return true; -} - -CmdSaslStart::CmdSaslStart() : BasicCommand(saslStartCommandName) {} -CmdSaslStart::~CmdSaslStart() {} - -std::string CmdSaslStart::help() const { - return "First step in a SASL authentication conversation."; + return reply; } -bool CmdSaslStart::run(OperationContext* opCtx, - const std::string& db, - const BSONObj& cmdObj, - BSONObjBuilder& result) { - return runSaslStart(opCtx, db, cmdObj, result, false); +SaslReply CmdSaslStart::Invocation::typedRun(OperationContext* opCtx) { + return runSaslStart(opCtx, request(), false); } -CmdSaslContinue::CmdSaslContinue() : BasicCommand(saslContinueCommandName) {} -CmdSaslContinue::~CmdSaslContinue() {} - -std::string CmdSaslContinue::help() const { - return "Subsequent steps in a SASL authentication conversation."; -} +SaslReply CmdSaslContinue::Invocation::typedRun(OperationContext* opCtx) { + auto cmd = request(); -bool CmdSaslContinue::run(OperationContext* opCtx, - const std::string& db, - const BSONObj& cmdObj, - BSONObjBuilder& result) { opCtx->markKillOnClientDisconnect(); - Client* client = Client::getCurrent(); + auto* client = Client::getCurrent(); std::unique_ptr<AuthenticationSession> sessionGuard; AuthenticationSession::swap(client, sessionGuard); @@ -405,25 +294,28 @@ bool CmdSaslContinue::run(OperationContext* opCtx, uasserted(ErrorCodes::ProtocolError, "No SASL session state found"); } - AuthenticationSession* session = static_cast<AuthenticationSession*>(sessionGuard.get()); + auto* session = static_cast<AuthenticationSession*>(sessionGuard.get()); auto& mechanism = session->getMechanism(); // Authenticating the __system@local user to the admin database on mongos is required // by the auth passthrough test suite. - if (mechanism.getAuthenticationDatabase() != db && !getTestCommandsEnabled()) { + if (mechanism.getAuthenticationDatabase() != cmd.getDbName() && !getTestCommandsEnabled()) { uasserted(ErrorCodes::ProtocolError, "Attempt to switch database target during SASL authentication."); } - Status status = doSaslContinue(opCtx, session, cmdObj, &result); - CommandHelpers::appendCommandStatusNoThrow(result, status); + uassert(ErrorCodes::ProtocolError, + "sasl: Mismatched conversation id", + cmd.getConversationId() == 1); + + auto swReply = doSaslStep(opCtx, cmd.getPayload(), session); - if (mechanism.isSuccess() || !status.isOK()) { + if (mechanism.isSuccess() || !swReply.isOK()) { audit::logAuthentication( client, mechanism.mechanismName(), UserName(mechanism.getPrincipalName(), mechanism.getAuthenticationDatabase()), - status.code()); + swReply.getStatus().code()); if (mechanism.isSuccess()) { uassertStatusOK( authCounter.incAuthenticateSuccessful(mechanism.mechanismName().toString())); @@ -436,7 +328,7 @@ bool CmdSaslContinue::run(OperationContext* opCtx, AuthenticationSession::swap(client, sessionGuard); } - return status.isOK(); + return uassertStatusOK(swReply); } // The CyrusSaslCommands Enterprise initializer is dependent on PreSaslCommands @@ -446,7 +338,9 @@ MONGO_INITIALIZER(PreSaslCommands) disableAuthMechanism(kX509AuthMechanism); } +constexpr auto kDBFieldName = "db"_sd; } // namespace +} // namespace auth void doSpeculativeSaslStart(OperationContext* opCtx, BSONObj cmdObj, BSONObjBuilder* result) try { auto mechElem = cmdObj["mechanism"]; @@ -457,15 +351,28 @@ void doSpeculativeSaslStart(OperationContext* opCtx, BSONObj cmdObj, BSONObjBuil // Run will make sure an audit entry happens. Let it reach that point. authCounter.incSpeculativeAuthenticateReceived(mechElem.String()).ignore(); - auto dbElement = cmdObj["db"]; - if (dbElement.type() != String) { + // TypedCommands expect DB overrides in the "$db" field, + // but saslStart coming from the Hello command has it in the "db" field. + // Rewrite it for handling here. + BSONObjBuilder cmd; + bool hasDBField = false; + for (const auto& elem : cmdObj) { + if (elem.fieldName() == auth::kDBFieldName) { + cmd.appendAs(elem, auth::SaslStartCommand::kDbNameFieldName); + hasDBField = true; + } else { + cmd.append(elem); + } + } + if (!hasDBField) { return; } - BSONObjBuilder saslStartResult; - if (runSaslStart(opCtx, dbElement.String(), cmdObj, saslStartResult, true)) { - result->append(auth::kSpeculativeAuthenticate, saslStartResult.obj()); - } + auto reply = auth::runSaslStart( + opCtx, + auth::SaslStartCommand::parse(IDLParserErrorContext("speculative saslStart"), cmd.obj()), + true); + result->append(auth::kSpeculativeAuthenticate, reply.toBSON()); } catch (...) { // Treat failure like we never even got a speculative start. } diff --git a/src/mongo/db/auth/sasl_commands.idl b/src/mongo/db/auth/sasl_commands.idl new file mode 100644 index 00000000000..c6ef7cad125 --- /dev/null +++ b/src/mongo/db/auth/sasl_commands.idl @@ -0,0 +1,101 @@ +# Copyright (C) 2020-present MongoDB, Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the Server Side Public License, version 1, +# as published by MongoDB, Inc. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Server Side Public License for more details. +# +# You should have received a copy of the Server Side Public License +# along with this program. If not, see +# <http://www.mongodb.com/licensing/server-side-public-license>. +# +# As a special exception, the copyright holders give permission to link the +# code of portions of this program with the OpenSSL library under certain +# conditions as described in each individual source file and distribute +# linked combinations including the program with the OpenSSL library. You +# must comply with the Server Side Public License in all respects for +# all of the code used other than as permitted herein. If you modify file(s) +# with this exception, you may extend this exception to your version of the +# file(s), but you are not obligated to do so. If you do not wish to do so, +# delete this exception statement from your version. If you delete this +# exception statement from all source files in the program, then also delete +# it in the license file. +# +global: + cpp_namespace: "mongo::auth" + cpp_includes: + - "mongo/db/auth/sasl_payload.h" + +imports: + - "mongo/idl/basic_types.idl" + +types: + SaslPayload: + description: "A base64 string or BinData value" + bson_serialization_type: any + cpp_type: SaslPayload + deserializer: "mongo::auth::SaslPayload::parseFromBSON" + serializer: "mongo::auth::SaslPayload::serializeToBSON" + +structs: + SaslReply: + description: "Response for saslStart and saslContinue commands" + strict: false + fields: + conversationId: + # In practice, this field is always populated as 1. + description: "Unique identifier for this SASL authentication session" + type: int + done: + description: "Whether or not the authentication has completed" + type: bool + payload: + description: "SASL payload" + type: SaslPayload + +commands: + saslStart: + description: "Begin a SASL based authentication session" + api_version: "1" + command_name: saslStart + namespace: ignored + cpp_name: SaslStartCommand + reply_type: SaslReply + strict: true + fields: + mechanism: + description: "SASL mechanism used for authentication" + type: string + autoAuthorize: + # This field is ignored and assumed to always be true. + description: "Automatically authorized user once authenticated" + type: safeBool + default: true + options: + description: "SASL mechanism specific options" + type: object + optional: true + payload: + description: "Initial client message for SASL exchange" + type: SaslPayload + + saslContinue: + description: "Continue a SASL based authentication session" + api_version: "1" + command_name: saslContinue + namespace: ignored + cpp_name: SaslContinueCommand + reply_type: SaslReply + strict: true + fields: + conversationId: + # This field is expected to be 1, any other value generates an error. + description: "Unique identifier for this SASL authentication session" + type: int + payload: + description: "SASL payload" + type: SaslPayload diff --git a/src/mongo/db/auth/sasl_payload.cpp b/src/mongo/db/auth/sasl_payload.cpp new file mode 100644 index 00000000000..5061c96ba36 --- /dev/null +++ b/src/mongo/db/auth/sasl_payload.cpp @@ -0,0 +1,80 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/auth/sasl_payload.h" + +#include "mongo/util/base64.h" + +namespace mongo { +namespace auth { + +SaslPayload SaslPayload::parseFromBSON(const BSONElement& elem) { + if (elem.type() == String) { + try { + SaslPayload ret(base64::decode(elem.valueStringDataSafe())); + ret.serializeAsBase64(true); + return ret; + } catch (...) { + auto status = exceptionToStatus(); + uasserted(status.code(), + str::stream() << "Failed decoding SASL payload: " << status.reason()); + } + } else if (elem.type() == BinData) { + uassert(ErrorCodes::BadValue, + str::stream() << "Invalid SASLPayload subtype. Expected BinDataGeneral, got: " + << typeName(elem.binDataType()), + elem.binDataType() == BinDataGeneral); + int len = 0; + const char* data = elem.binData(len); + return SaslPayload(std::string(data, len)); + } else { + uasserted(ErrorCodes::BadValue, + str::stream() << "Invalid SASLPayload type. Expected Base64 or BinData, got: " + << typeName(elem.type())); + } +} + +void SaslPayload::serializeToBSON(StringData fieldName, BSONObjBuilder* bob) const { + if (_serializeAsBase64) { + bob->append(fieldName, base64::encode(_payload)); + } else { + bob->appendBinData(fieldName, int(_payload.size()), BinDataGeneral, _payload.c_str()); + } +} + +void SaslPayload::serializeToBSON(BSONArrayBuilder* bob) const { + if (_serializeAsBase64) { + bob->append(base64::encode(_payload)); + } else { + bob->appendBinData(int(_payload.size()), BinDataGeneral, _payload.c_str()); + } +} + +} // namespace auth +} // namespace mongo diff --git a/src/mongo/db/auth/sasl_payload.h b/src/mongo/db/auth/sasl_payload.h new file mode 100644 index 00000000000..671c75fbce8 --- /dev/null +++ b/src/mongo/db/auth/sasl_payload.h @@ -0,0 +1,76 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <string> + +#include "mongo/base/string_data.h" +#include "mongo/bson/bsonelement.h" +#include "mongo/bson/bsonobjbuilder.h" + +namespace mongo { +namespace auth { + +/** + * IDL support class for proxying different ways of passing/returning + * a SASL payload. + * + * Payload may be either a base64 encoded string, + * or a BinDataGeneral containing the raw payload bytes. + */ +class SaslPayload { +public: + SaslPayload() = default; + + explicit SaslPayload(std::string data) : _payload(std::move(data)) {} + + bool getSerializeAsBase64() const { + return _serializeAsBase64; + } + + void serializeAsBase64(bool opt) { + _serializeAsBase64 = opt; + } + + const std::string& get() const { + return _payload; + } + + static SaslPayload parseFromBSON(const BSONElement& elem); + void serializeToBSON(StringData fieldName, BSONObjBuilder* bob) const; + void serializeToBSON(BSONArrayBuilder* bob) const; + +private: + bool _serializeAsBase64 = false; + std::string _payload; +}; + +} // namespace auth +} // namespace mongo |