summaryrefslogtreecommitdiff
path: root/src/mongo
diff options
context:
space:
mode:
authorSara Golemon <sara.golemon@mongodb.com>2021-01-08 22:49:02 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2021-01-12 20:30:56 +0000
commitd9dbede70d8f429cb46876ad0e2462d3d489eccc (patch)
treef5c4d4fb2c6a0641622451c95c41781650ef6732 /src/mongo
parenta67689bc26cad79cc008833458aaad496fece9c4 (diff)
downloadmongo-d9dbede70d8f429cb46876ad0e2462d3d489eccc.tar.gz
SERVER-53154 Convert saslStart/saslContinue to TypedCommand
Diffstat (limited to 'src/mongo')
-rw-r--r--src/mongo/db/auth/SConscript2
-rw-r--r--src/mongo/db/auth/sasl_commands.cpp359
-rw-r--r--src/mongo/db/auth/sasl_commands.idl101
-rw-r--r--src/mongo/db/auth/sasl_payload.cpp80
-rw-r--r--src/mongo/db/auth/sasl_payload.h76
5 files changed, 392 insertions, 226 deletions
diff --git a/src/mongo/db/auth/SConscript b/src/mongo/db/auth/SConscript
index 8a5498f04ba..865b4d8c474 100644
--- a/src/mongo/db/auth/SConscript
+++ b/src/mongo/db/auth/SConscript
@@ -224,6 +224,8 @@ env.Library(
source=[
'authz_session_external_state_server_common.cpp',
'sasl_commands.cpp',
+ 'sasl_commands.idl',
+ 'sasl_payload.cpp',
'enable_localhost_auth_bypass_parameter.idl',
],
LIBDEPS=[
diff --git a/src/mongo/db/auth/sasl_commands.cpp b/src/mongo/db/auth/sasl_commands.cpp
index e5e00a6b160..72108f01a0e 100644
--- a/src/mongo/db/auth/sasl_commands.cpp
+++ b/src/mongo/db/auth/sasl_commands.cpp
@@ -47,6 +47,7 @@
#include "mongo/db/auth/authz_manager_external_state_mock.h"
#include "mongo/db/auth/authz_session_external_state_mock.h"
#include "mongo/db/auth/sasl_command_constants.h"
+#include "mongo/db/auth/sasl_commands_gen.h"
#include "mongo/db/auth/sasl_options.h"
#include "mongo/db/client.h"
#include "mongo/db/commands.h"
@@ -59,143 +60,90 @@
#include "mongo/util/str.h"
namespace mongo {
+namespace auth {
namespace {
using std::stringstream;
-const bool autoAuthorizeDefault = true;
-
-class CmdSaslStart : public BasicCommand {
+class CmdSaslStart : public SaslStartCmdVersion1Gen<CmdSaslStart> {
public:
- static constexpr StringData kPayloadField = "payload"_sd;
+ std::set<StringData> sensitiveFieldNames() const final {
+ return {SaslStartCommand::kPayloadFieldName};
+ }
- CmdSaslStart();
- virtual ~CmdSaslStart();
+ class Invocation final : public InvocationBaseGen {
+ public:
+ using InvocationBaseGen::InvocationBaseGen;
- const std::set<std::string>& apiVersions() const {
- return kApiVersions1;
- }
+ bool supportsWriteConcern() const final {
+ return false;
+ }
- virtual void addRequiredPrivileges(const std::string&,
- const BSONObj&,
- std::vector<Privilege>*) const {}
+ NamespaceString ns() const final {
+ return NamespaceString(request().getDbName());
+ }
- std::set<StringData> sensitiveFieldNames() const final {
- return {kPayloadField};
- }
+ void doCheckAuthorization(OperationContext*) const final {}
- virtual bool run(OperationContext* opCtx,
- const std::string& db,
- const BSONObj& cmdObj,
- BSONObjBuilder& result);
+ Reply typedRun(OperationContext* opCtx);
+ };
- virtual std::string help() const override;
- virtual bool supportsWriteConcern(const BSONObj& cmd) const override {
- return false;
+ std::string help() const final {
+ return "First step in a SASL authentication conversation.";
}
- AllowedOnSecondary secondaryAllowed(ServiceContext*) const override {
+
+ AllowedOnSecondary secondaryAllowed(ServiceContext*) const final {
return AllowedOnSecondary::kAlways;
}
+
bool requiresAuth() const override {
return false;
}
-};
+} cmdSaslStart;
-class CmdSaslContinue : public BasicCommand {
+class CmdSaslContinue : public SaslContinueCmdVersion1Gen<CmdSaslContinue> {
public:
- static constexpr StringData kPayloadField = "payload"_sd;
-
- CmdSaslContinue();
- virtual ~CmdSaslContinue();
-
- const std::set<std::string>& apiVersions() const {
- return kApiVersions1;
- }
- virtual void addRequiredPrivileges(const std::string&,
- const BSONObj&,
- std::vector<Privilege>*) const {}
-
std::set<StringData> sensitiveFieldNames() const final {
- return {kPayloadField};
- }
-
- virtual bool run(OperationContext* opCtx,
- const std::string& db,
- const BSONObj& cmdObj,
- BSONObjBuilder& result);
-
- std::string help() const override;
- virtual bool supportsWriteConcern(const BSONObj& cmd) const override {
- return false;
- }
- AllowedOnSecondary secondaryAllowed(ServiceContext*) const override {
- return AllowedOnSecondary::kAlways;
- }
- bool requiresAuth() const override {
- return false;
+ return {SaslContinueCommand::kPayloadFieldName};
}
-};
-CmdSaslStart cmdSaslStart;
-CmdSaslContinue cmdSaslContinue;
+ class Invocation final : public InvocationBaseGen {
+ public:
+ using InvocationBaseGen::InvocationBaseGen;
-Status buildResponse(const AuthenticationSession* session,
- const std::string& responsePayload,
- BSONType responsePayloadType,
- BSONObjBuilder* result) {
- result->appendIntOrLL(saslCommandConversationIdFieldName, 1);
- result->appendBool(saslCommandDoneFieldName, session->getMechanism().isSuccess());
+ bool supportsWriteConcern() const final {
+ return false;
+ }
- if (responsePayload.size() > size_t(std::numeric_limits<int>::max())) {
- return Status(ErrorCodes::InvalidLength, "Response payload too long");
- }
- if (responsePayloadType == BinData) {
- result->appendBinData(saslCommandPayloadFieldName,
- int(responsePayload.size()),
- BinDataGeneral,
- responsePayload.data());
- } else if (responsePayloadType == String) {
- result->append(saslCommandPayloadFieldName, base64::encode(responsePayload));
- } else {
- fassertFailed(4003);
- }
+ NamespaceString ns() const final {
+ return NamespaceString(request().getDbName());
+ }
- return Status::OK();
-}
+ void doCheckAuthorization(OperationContext*) const final {}
-Status extractConversationId(const BSONObj& cmdObj, int64_t* conversationId) {
- BSONElement element;
- Status status = bsonExtractField(cmdObj, saslCommandConversationIdFieldName, &element);
- if (!status.isOK())
- return status;
+ Reply typedRun(OperationContext* opCtx);
+ };
- if (!element.isNumber()) {
- return Status(ErrorCodes::TypeMismatch,
- str::stream() << "Wrong type for field; expected number for " << element);
+ std::string help() const final {
+ return "Subsequent steps in a SASL authentication conversation.";
}
- *conversationId = element.numberLong();
- return Status::OK();
-}
-Status extractMechanism(const BSONObj& cmdObj, std::string* mechanism) {
- return bsonExtractStringField(cmdObj, saslCommandMechanismFieldName, mechanism);
-}
+ AllowedOnSecondary secondaryAllowed(ServiceContext*) const final {
+ return AllowedOnSecondary::kAlways;
+ }
-Status doSaslStep(OperationContext* opCtx,
- AuthenticationSession* session,
- const BSONObj& cmdObj,
- BSONObjBuilder* result) {
- std::string payload;
- BSONType type = EOO;
- Status status = saslExtractPayload(cmdObj, &payload, &type);
- if (!status.isOK()) {
- return status;
+ bool requiresAuth() const final {
+ return false;
}
+} cmdSaslContinue;
+StatusWith<SaslReply> doSaslStep(OperationContext* opCtx,
+ const SaslPayload& payload,
+ AuthenticationSession* session) try {
auto& mechanism = session->getMechanism();
// Passing in a payload and extracting a responsePayload
- StatusWith<std::string> swResponse = mechanism.step(opCtx, payload);
+ StatusWith<std::string> swResponse = mechanism.step(opCtx, payload.get());
if (!swResponse.isOK()) {
int64_t dLevel = 0;
@@ -221,18 +169,10 @@ Status doSaslStep(OperationContext* opCtx,
return AuthorizationManager::authenticationFailedStatus;
}
- status = buildResponse(session, swResponse.getValue(), type, result);
- if (!status.isOK()) {
- return status;
- }
-
if (mechanism.isSuccess()) {
UserName userName(mechanism.getPrincipalName(), mechanism.getAuthenticationDatabase());
- status =
- AuthorizationSession::get(opCtx->getClient())->addAndAuthorizeUser(opCtx, userName);
- if (!status.isOK()) {
- return status;
- }
+ uassertStatusOK(
+ AuthorizationSession::get(opCtx->getClient())->addAndAuthorizeUser(opCtx, userName));
if (!serverGlobalParams.quiet.load()) {
LOGV2(20250,
@@ -245,98 +185,63 @@ Status doSaslStep(OperationContext* opCtx,
"remote"_attr = opCtx->getClient()->session()->remote());
}
if (session->isSpeculative()) {
- status = authCounter.incSpeculativeAuthenticateSuccessful(
- mechanism.mechanismName().toString());
+ uassertStatusOK(authCounter.incSpeculativeAuthenticateSuccessful(
+ mechanism.mechanismName().toString()));
}
}
- return status;
-}
-StatusWith<std::unique_ptr<AuthenticationSession>> doSaslStart(OperationContext* opCtx,
- const std::string& db,
- const BSONObj& cmdObj,
- BSONObjBuilder* result,
- std::string* principalName,
- bool speculative) {
- bool autoAuthorize = false;
- Status status = bsonExtractBooleanFieldWithDefault(
- cmdObj, saslCommandAutoAuthorizeFieldName, autoAuthorizeDefault, &autoAuthorize);
- if (!status.isOK())
- return status;
-
- std::string mechanismName;
- status = extractMechanism(cmdObj, &mechanismName);
- if (!status.isOK())
- return status;
-
- StatusWith<std::unique_ptr<ServerMechanismBase>> swMech =
- SASLServerMechanismRegistry::get(opCtx->getServiceContext())
- .getServerMechanism(mechanismName, db);
+ SaslReply reply;
+ reply.setConversationId(1);
+ reply.setDone(mechanism.isSuccess());
- if (!swMech.isOK()) {
- return swMech.getStatus();
- }
+ SaslPayload replyPayload(swResponse.getValue());
+ replyPayload.serializeAsBase64(payload.getSerializeAsBase64());
+ reply.setPayload(std::move(replyPayload));
- auto session =
- std::make_unique<AuthenticationSession>(std::move(swMech.getValue()), speculative);
+ return reply;
+} catch (const DBException& ex) {
+ return ex.toStatus();
+}
- if (speculative &&
- !session->getMechanism().properties().hasAllProperties(
- SecurityPropertySet({SecurityProperty::kNoPlainText}))) {
- return {ErrorCodes::BadValue,
- "Plaintext mechanisms may not be used with speculativeSaslStart"};
- }
+SaslReply doSaslStart(OperationContext* opCtx,
+ const SaslStartCommand& request,
+ bool speculative,
+ std::string* principalName,
+ std::unique_ptr<AuthenticationSession>* session) {
+ auto mechanism = uassertStatusOK(
+ SASLServerMechanismRegistry::get(opCtx->getServiceContext())
+ .getServerMechanism(request.getMechanism(), request.getDbName().toString()));
- auto options = cmdObj["options"];
- if (!options.eoo()) {
- if (options.type() != Object) {
- return {ErrorCodes::BadValue, "saslStart.options must be an object"};
- }
- status = session->setOptions(options.Obj());
- if (!status.isOK()) {
- return status;
- }
- }
+ uassert(ErrorCodes::BadValue,
+ "Plaintext mechanisms may not be used with speculativeSaslStart",
+ !speculative ||
+ mechanism->properties().hasAllProperties(
+ SecurityPropertySet({SecurityProperty::kNoPlainText})));
- Status statusStep = doSaslStep(opCtx, session.get(), cmdObj, result);
+ auto newSession = std::make_unique<AuthenticationSession>(std::move(mechanism), speculative);
- if (!statusStep.isOK() || session->getMechanism().isSuccess()) {
- // Only attempt to populate principal name if we're done (successfully or not).
- *principalName = session->getMechanism().getPrincipalName().toString();
+ if (auto options = request.getOptions()) {
+ uassertStatusOK(newSession->setOptions(options->getOwned()));
}
- if (!statusStep.isOK()) {
- return statusStep;
+ auto swReply = doSaslStep(opCtx, request.getPayload(), newSession.get());
+ if (!swReply.isOK() || newSession->getMechanism().isSuccess()) {
+ // Only attempt to populate principal name if we're done (successfully or not).
+ *principalName = newSession->getMechanism().getPrincipalName().toString();
}
- return std::move(session);
-}
-
-Status doSaslContinue(OperationContext* opCtx,
- AuthenticationSession* session,
- const BSONObj& cmdObj,
- BSONObjBuilder* result) {
- int64_t conversationId = 0;
- Status status = extractConversationId(cmdObj, &conversationId);
- if (!status.isOK())
- return status;
- if (conversationId != 1)
- return Status(ErrorCodes::ProtocolError, "sasl: Mismatched conversation id");
-
- return doSaslStep(opCtx, session, cmdObj, result);
+ auto reply = uassertStatusOK(swReply);
+ session->reset(newSession.release());
+ return reply;
}
-bool runSaslStart(OperationContext* opCtx,
- const std::string& db,
- const BSONObj& cmdObj,
- BSONObjBuilder& result,
- bool speculative) {
+SaslReply runSaslStart(OperationContext* opCtx, const SaslStartCommand& request, bool speculative) {
opCtx->markKillOnClientDisconnect();
- Client* client = opCtx->getClient();
+ auto client = opCtx->getClient();
AuthenticationSession::set(client, std::unique_ptr<AuthenticationSession>());
- std::string mechanismName;
- uassertStatusOK(extractMechanism(cmdObj, &mechanismName));
+ auto db = request.getDbName();
+ auto mechanismName = request.getMechanism().toString();
auto status = authCounter.incAuthenticateReceived(mechanismName);
if (!status.isOK()) {
@@ -345,10 +250,12 @@ bool runSaslStart(OperationContext* opCtx,
MONGO_UNREACHABLE;
}
+ SaslReply reply;
std::string principalName;
try {
- auto session =
- uassertStatusOK(doSaslStart(opCtx, db, cmdObj, &result, &principalName, speculative));
+ std::unique_ptr<AuthenticationSession> session;
+ reply = doSaslStart(opCtx, request, speculative, &principalName, &session);
+
const bool isClusterMember = session->getMechanism().isClusterMember();
if (isClusterMember) {
uassertStatusOK(authCounter.incClusterAuthenticateReceived(mechanismName));
@@ -359,7 +266,7 @@ bool runSaslStart(OperationContext* opCtx,
uassertStatusOK(authCounter.incClusterAuthenticateSuccessful(mechanismName));
}
audit::logAuthentication(
- client, mechanismName, UserName(principalName, db), Status::OK().code());
+ client, mechanismName, UserName(principalName, db), ErrorCodes::OK);
} else {
AuthenticationSession::swap(client, session);
}
@@ -368,36 +275,18 @@ bool runSaslStart(OperationContext* opCtx,
throw;
}
- return true;
-}
-
-CmdSaslStart::CmdSaslStart() : BasicCommand(saslStartCommandName) {}
-CmdSaslStart::~CmdSaslStart() {}
-
-std::string CmdSaslStart::help() const {
- return "First step in a SASL authentication conversation.";
+ return reply;
}
-bool CmdSaslStart::run(OperationContext* opCtx,
- const std::string& db,
- const BSONObj& cmdObj,
- BSONObjBuilder& result) {
- return runSaslStart(opCtx, db, cmdObj, result, false);
+SaslReply CmdSaslStart::Invocation::typedRun(OperationContext* opCtx) {
+ return runSaslStart(opCtx, request(), false);
}
-CmdSaslContinue::CmdSaslContinue() : BasicCommand(saslContinueCommandName) {}
-CmdSaslContinue::~CmdSaslContinue() {}
-
-std::string CmdSaslContinue::help() const {
- return "Subsequent steps in a SASL authentication conversation.";
-}
+SaslReply CmdSaslContinue::Invocation::typedRun(OperationContext* opCtx) {
+ auto cmd = request();
-bool CmdSaslContinue::run(OperationContext* opCtx,
- const std::string& db,
- const BSONObj& cmdObj,
- BSONObjBuilder& result) {
opCtx->markKillOnClientDisconnect();
- Client* client = Client::getCurrent();
+ auto* client = Client::getCurrent();
std::unique_ptr<AuthenticationSession> sessionGuard;
AuthenticationSession::swap(client, sessionGuard);
@@ -405,25 +294,28 @@ bool CmdSaslContinue::run(OperationContext* opCtx,
uasserted(ErrorCodes::ProtocolError, "No SASL session state found");
}
- AuthenticationSession* session = static_cast<AuthenticationSession*>(sessionGuard.get());
+ auto* session = static_cast<AuthenticationSession*>(sessionGuard.get());
auto& mechanism = session->getMechanism();
// Authenticating the __system@local user to the admin database on mongos is required
// by the auth passthrough test suite.
- if (mechanism.getAuthenticationDatabase() != db && !getTestCommandsEnabled()) {
+ if (mechanism.getAuthenticationDatabase() != cmd.getDbName() && !getTestCommandsEnabled()) {
uasserted(ErrorCodes::ProtocolError,
"Attempt to switch database target during SASL authentication.");
}
- Status status = doSaslContinue(opCtx, session, cmdObj, &result);
- CommandHelpers::appendCommandStatusNoThrow(result, status);
+ uassert(ErrorCodes::ProtocolError,
+ "sasl: Mismatched conversation id",
+ cmd.getConversationId() == 1);
+
+ auto swReply = doSaslStep(opCtx, cmd.getPayload(), session);
- if (mechanism.isSuccess() || !status.isOK()) {
+ if (mechanism.isSuccess() || !swReply.isOK()) {
audit::logAuthentication(
client,
mechanism.mechanismName(),
UserName(mechanism.getPrincipalName(), mechanism.getAuthenticationDatabase()),
- status.code());
+ swReply.getStatus().code());
if (mechanism.isSuccess()) {
uassertStatusOK(
authCounter.incAuthenticateSuccessful(mechanism.mechanismName().toString()));
@@ -436,7 +328,7 @@ bool CmdSaslContinue::run(OperationContext* opCtx,
AuthenticationSession::swap(client, sessionGuard);
}
- return status.isOK();
+ return uassertStatusOK(swReply);
}
// The CyrusSaslCommands Enterprise initializer is dependent on PreSaslCommands
@@ -446,7 +338,9 @@ MONGO_INITIALIZER(PreSaslCommands)
disableAuthMechanism(kX509AuthMechanism);
}
+constexpr auto kDBFieldName = "db"_sd;
} // namespace
+} // namespace auth
void doSpeculativeSaslStart(OperationContext* opCtx, BSONObj cmdObj, BSONObjBuilder* result) try {
auto mechElem = cmdObj["mechanism"];
@@ -457,15 +351,28 @@ void doSpeculativeSaslStart(OperationContext* opCtx, BSONObj cmdObj, BSONObjBuil
// Run will make sure an audit entry happens. Let it reach that point.
authCounter.incSpeculativeAuthenticateReceived(mechElem.String()).ignore();
- auto dbElement = cmdObj["db"];
- if (dbElement.type() != String) {
+ // TypedCommands expect DB overrides in the "$db" field,
+ // but saslStart coming from the Hello command has it in the "db" field.
+ // Rewrite it for handling here.
+ BSONObjBuilder cmd;
+ bool hasDBField = false;
+ for (const auto& elem : cmdObj) {
+ if (elem.fieldName() == auth::kDBFieldName) {
+ cmd.appendAs(elem, auth::SaslStartCommand::kDbNameFieldName);
+ hasDBField = true;
+ } else {
+ cmd.append(elem);
+ }
+ }
+ if (!hasDBField) {
return;
}
- BSONObjBuilder saslStartResult;
- if (runSaslStart(opCtx, dbElement.String(), cmdObj, saslStartResult, true)) {
- result->append(auth::kSpeculativeAuthenticate, saslStartResult.obj());
- }
+ auto reply = auth::runSaslStart(
+ opCtx,
+ auth::SaslStartCommand::parse(IDLParserErrorContext("speculative saslStart"), cmd.obj()),
+ true);
+ result->append(auth::kSpeculativeAuthenticate, reply.toBSON());
} catch (...) {
// Treat failure like we never even got a speculative start.
}
diff --git a/src/mongo/db/auth/sasl_commands.idl b/src/mongo/db/auth/sasl_commands.idl
new file mode 100644
index 00000000000..c6ef7cad125
--- /dev/null
+++ b/src/mongo/db/auth/sasl_commands.idl
@@ -0,0 +1,101 @@
+# Copyright (C) 2020-present MongoDB, Inc.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the Server Side Public License, version 1,
+# as published by MongoDB, Inc.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# Server Side Public License for more details.
+#
+# You should have received a copy of the Server Side Public License
+# along with this program. If not, see
+# <http://www.mongodb.com/licensing/server-side-public-license>.
+#
+# As a special exception, the copyright holders give permission to link the
+# code of portions of this program with the OpenSSL library under certain
+# conditions as described in each individual source file and distribute
+# linked combinations including the program with the OpenSSL library. You
+# must comply with the Server Side Public License in all respects for
+# all of the code used other than as permitted herein. If you modify file(s)
+# with this exception, you may extend this exception to your version of the
+# file(s), but you are not obligated to do so. If you do not wish to do so,
+# delete this exception statement from your version. If you delete this
+# exception statement from all source files in the program, then also delete
+# it in the license file.
+#
+global:
+ cpp_namespace: "mongo::auth"
+ cpp_includes:
+ - "mongo/db/auth/sasl_payload.h"
+
+imports:
+ - "mongo/idl/basic_types.idl"
+
+types:
+ SaslPayload:
+ description: "A base64 string or BinData value"
+ bson_serialization_type: any
+ cpp_type: SaslPayload
+ deserializer: "mongo::auth::SaslPayload::parseFromBSON"
+ serializer: "mongo::auth::SaslPayload::serializeToBSON"
+
+structs:
+ SaslReply:
+ description: "Response for saslStart and saslContinue commands"
+ strict: false
+ fields:
+ conversationId:
+ # In practice, this field is always populated as 1.
+ description: "Unique identifier for this SASL authentication session"
+ type: int
+ done:
+ description: "Whether or not the authentication has completed"
+ type: bool
+ payload:
+ description: "SASL payload"
+ type: SaslPayload
+
+commands:
+ saslStart:
+ description: "Begin a SASL based authentication session"
+ api_version: "1"
+ command_name: saslStart
+ namespace: ignored
+ cpp_name: SaslStartCommand
+ reply_type: SaslReply
+ strict: true
+ fields:
+ mechanism:
+ description: "SASL mechanism used for authentication"
+ type: string
+ autoAuthorize:
+ # This field is ignored and assumed to always be true.
+ description: "Automatically authorized user once authenticated"
+ type: safeBool
+ default: true
+ options:
+ description: "SASL mechanism specific options"
+ type: object
+ optional: true
+ payload:
+ description: "Initial client message for SASL exchange"
+ type: SaslPayload
+
+ saslContinue:
+ description: "Continue a SASL based authentication session"
+ api_version: "1"
+ command_name: saslContinue
+ namespace: ignored
+ cpp_name: SaslContinueCommand
+ reply_type: SaslReply
+ strict: true
+ fields:
+ conversationId:
+ # This field is expected to be 1, any other value generates an error.
+ description: "Unique identifier for this SASL authentication session"
+ type: int
+ payload:
+ description: "SASL payload"
+ type: SaslPayload
diff --git a/src/mongo/db/auth/sasl_payload.cpp b/src/mongo/db/auth/sasl_payload.cpp
new file mode 100644
index 00000000000..5061c96ba36
--- /dev/null
+++ b/src/mongo/db/auth/sasl_payload.cpp
@@ -0,0 +1,80 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/auth/sasl_payload.h"
+
+#include "mongo/util/base64.h"
+
+namespace mongo {
+namespace auth {
+
+SaslPayload SaslPayload::parseFromBSON(const BSONElement& elem) {
+ if (elem.type() == String) {
+ try {
+ SaslPayload ret(base64::decode(elem.valueStringDataSafe()));
+ ret.serializeAsBase64(true);
+ return ret;
+ } catch (...) {
+ auto status = exceptionToStatus();
+ uasserted(status.code(),
+ str::stream() << "Failed decoding SASL payload: " << status.reason());
+ }
+ } else if (elem.type() == BinData) {
+ uassert(ErrorCodes::BadValue,
+ str::stream() << "Invalid SASLPayload subtype. Expected BinDataGeneral, got: "
+ << typeName(elem.binDataType()),
+ elem.binDataType() == BinDataGeneral);
+ int len = 0;
+ const char* data = elem.binData(len);
+ return SaslPayload(std::string(data, len));
+ } else {
+ uasserted(ErrorCodes::BadValue,
+ str::stream() << "Invalid SASLPayload type. Expected Base64 or BinData, got: "
+ << typeName(elem.type()));
+ }
+}
+
+void SaslPayload::serializeToBSON(StringData fieldName, BSONObjBuilder* bob) const {
+ if (_serializeAsBase64) {
+ bob->append(fieldName, base64::encode(_payload));
+ } else {
+ bob->appendBinData(fieldName, int(_payload.size()), BinDataGeneral, _payload.c_str());
+ }
+}
+
+void SaslPayload::serializeToBSON(BSONArrayBuilder* bob) const {
+ if (_serializeAsBase64) {
+ bob->append(base64::encode(_payload));
+ } else {
+ bob->appendBinData(int(_payload.size()), BinDataGeneral, _payload.c_str());
+ }
+}
+
+} // namespace auth
+} // namespace mongo
diff --git a/src/mongo/db/auth/sasl_payload.h b/src/mongo/db/auth/sasl_payload.h
new file mode 100644
index 00000000000..671c75fbce8
--- /dev/null
+++ b/src/mongo/db/auth/sasl_payload.h
@@ -0,0 +1,76 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <string>
+
+#include "mongo/base/string_data.h"
+#include "mongo/bson/bsonelement.h"
+#include "mongo/bson/bsonobjbuilder.h"
+
+namespace mongo {
+namespace auth {
+
+/**
+ * IDL support class for proxying different ways of passing/returning
+ * a SASL payload.
+ *
+ * Payload may be either a base64 encoded string,
+ * or a BinDataGeneral containing the raw payload bytes.
+ */
+class SaslPayload {
+public:
+ SaslPayload() = default;
+
+ explicit SaslPayload(std::string data) : _payload(std::move(data)) {}
+
+ bool getSerializeAsBase64() const {
+ return _serializeAsBase64;
+ }
+
+ void serializeAsBase64(bool opt) {
+ _serializeAsBase64 = opt;
+ }
+
+ const std::string& get() const {
+ return _payload;
+ }
+
+ static SaslPayload parseFromBSON(const BSONElement& elem);
+ void serializeToBSON(StringData fieldName, BSONObjBuilder* bob) const;
+ void serializeToBSON(BSONArrayBuilder* bob) const;
+
+private:
+ bool _serializeAsBase64 = false;
+ std::string _payload;
+};
+
+} // namespace auth
+} // namespace mongo