From 25d521ca3283fcb21e125b30403e0036d05a8338 Mon Sep 17 00:00:00 2001 From: Spencer Jackson Date: Fri, 12 Jan 2018 19:11:14 -0500 Subject: SERVER-32966: Add SASL server mechanism registry --- .../db/auth/sasl_authentication_session_test.cpp | 89 +++++++++++++--------- 1 file changed, 52 insertions(+), 37 deletions(-) (limited to 'src/mongo/db/auth/sasl_authentication_session_test.cpp') diff --git a/src/mongo/db/auth/sasl_authentication_session_test.cpp b/src/mongo/db/auth/sasl_authentication_session_test.cpp index 7226395ece6..4babe8c816b 100644 --- a/src/mongo/db/auth/sasl_authentication_session_test.cpp +++ b/src/mongo/db/auth/sasl_authentication_session_test.cpp @@ -16,10 +16,13 @@ #include "mongo/db/auth/authorization_session.h" #include "mongo/db/auth/authz_manager_external_state_mock.h" #include "mongo/db/auth/authz_session_external_state_mock.h" -#include "mongo/db/auth/sasl_authentication_session.h" +#include "mongo/db/auth/sasl_mechanism_registry.h" #include "mongo/db/auth/sasl_options.h" +#include "mongo/db/auth/sasl_plain_server_conversation.h" +#include "mongo/db/auth/sasl_scram_server_conversation.h" #include "mongo/db/jsobj.h" #include "mongo/db/operation_context_noop.h" +#include "mongo/db/service_context_noop.h" #include "mongo/unittest/unittest.h" #include "mongo/util/log.h" #include "mongo/util/password_digest.h" @@ -38,12 +41,16 @@ public: void testWrongClientMechanism(); void testWrongServerMechanism(); + ServiceContextNoop serviceContext; + ServiceContext::UniqueClient opClient; + ServiceContext::UniqueOperationContext opCtx; AuthzManagerExternalStateMock* authManagerExternalState; - AuthorizationManager authManager; + AuthorizationManager* authManager; std::unique_ptr authSession; + SASLServerMechanismRegistry registry; std::string mechanism; std::unique_ptr client; - std::unique_ptr server; + std::unique_ptr server; private: void assertConversationFailure(); @@ -58,17 +65,27 @@ const std::string mockServiceName = "mocksvc"; const std::string mockHostName = "host.mockery.com"; SaslConversation::SaslConversation(std::string mech) - : authManagerExternalState(new AuthzManagerExternalStateMock), - authManager(std::unique_ptr(authManagerExternalState)), - authSession(authManager.makeAuthorizationSession()), + : opClient(serviceContext.makeClient("saslTest")), + opCtx(serviceContext.makeOperationContext(opClient.get())), + authManagerExternalState(new AuthzManagerExternalStateMock), + authManager(new AuthorizationManager( + std::unique_ptr(authManagerExternalState))), + authSession(authManager->makeAuthorizationSession()), mechanism(mech) { - OperationContextNoop opCtx; + + AuthorizationManager::set(&serviceContext, std::unique_ptr(authManager)); client.reset(SaslClientSession::create(mechanism)); - server.reset(SaslAuthenticationSession::create(authSession.get(), "test", mechanism)); + + registry.registerFactory( + SASLServerMechanismRegistry::kNoValidateGlobalMechanisms); + registry.registerFactory( + SASLServerMechanismRegistry::kNoValidateGlobalMechanisms); + registry.registerFactory( + SASLServerMechanismRegistry::kNoValidateGlobalMechanisms); ASSERT_OK(authManagerExternalState->updateOne( - &opCtx, + opCtx.get(), AuthorizationManager::versionCollectionNamespace, AuthorizationManager::versionDocumentQuery, BSON("$set" << BSON(AuthorizationManager::schemaVersionFieldName @@ -88,7 +105,7 @@ SaslConversation::SaslConversation(std::string mech) << scram::Secrets::generateCredentials( "frim", saslGlobalParams.scramSHA256IterationCount.load())); - ASSERT_OK(authManagerExternalState->insert(&opCtx, + ASSERT_OK(authManagerExternalState->insert(opCtx.get(), NamespaceString("admin.system.users"), BSON("_id" << "test.andy" @@ -107,16 +124,16 @@ void SaslConversation::assertConversationFailure() { std::string clientMessage; std::string serverMessage; Status clientStatus(ErrorCodes::InternalError, ""); - Status serverStatus(ErrorCodes::InternalError, ""); + StatusWith serverResponse(""); do { - clientStatus = client->step(serverMessage, &clientMessage); + clientStatus = client->step(serverResponse.getValue(), &clientMessage); if (!clientStatus.isOK()) break; - serverStatus = server->step(clientMessage, &serverMessage); - if (!serverStatus.isOK()) + serverResponse = server->step(opCtx.get(), clientMessage); + if (!serverResponse.isOK()) break; } while (!client->isDone()); - ASSERT_FALSE(serverStatus.isOK() && clientStatus.isOK() && client->isDone() && + ASSERT_FALSE(serverResponse.isOK() && clientStatus.isOK() && client->isDone() && server->isDone()); } @@ -128,13 +145,12 @@ void SaslConversation::testSuccessfulAuthentication() { client->setParameter(SaslClientSession::parameterPassword, "frim"); ASSERT_OK(client->initialize()); - ASSERT_OK(server->start("test", mechanism, mockServiceName, mockHostName, 1, true)); - std::string clientMessage; - std::string serverMessage; + StatusWith serverResponse(""); do { - ASSERT_OK(client->step(serverMessage, &clientMessage)); - ASSERT_OK(server->step(clientMessage, &serverMessage)); + ASSERT_OK(client->step(serverResponse.getValue(), &clientMessage)); + serverResponse = server->step(opCtx.get(), clientMessage); + ASSERT_OK(serverResponse.getStatus()); } while (!client->isDone()); ASSERT_TRUE(server->isDone()); } @@ -147,8 +163,6 @@ void SaslConversation::testNoSuchUser() { client->setParameter(SaslClientSession::parameterPassword, "frim"); ASSERT_OK(client->initialize()); - ASSERT_OK(server->start("test", mechanism, mockServiceName, mockHostName, 1, true)); - assertConversationFailure(); } @@ -160,8 +174,6 @@ void SaslConversation::testBadPassword() { client->setParameter(SaslClientSession::parameterPassword, "WRONG"); ASSERT_OK(client->initialize()); - ASSERT_OK(server->start("test", mechanism, mockServiceName, mockHostName, 1, true)); - assertConversationFailure(); } @@ -175,8 +187,6 @@ void SaslConversation::testWrongClientMechanism() { client->setParameter(SaslClientSession::parameterPassword, "frim"); ASSERT_OK(client->initialize()); - ASSERT_OK(server->start("test", mechanism, mockServiceName, mockHostName, 1, true)); - assertConversationFailure(); } @@ -188,19 +198,22 @@ void SaslConversation::testWrongServerMechanism() { client->setParameter(SaslClientSession::parameterPassword, "frim"); ASSERT_OK(client->initialize()); - ASSERT_OK(server->start("test", - mechanism != "SCRAM-SHA-1" ? "SCRAM-SHA-1" : "PLAIN", - mockServiceName, - mockHostName, - 1, - true)); + auto swServer = + registry.getServerMechanism(mechanism != "SCRAM-SHA-1" ? "SCRAM-SHA-1" : "PLAIN", "test"); + ASSERT_OK(swServer.getStatus()); + server = std::move(swServer.getValue()); + assertConversationFailure(); } -#define DEFINE_MECHANISM_FIXTURE(CLASS_SUFFIX, MECH_NAME) \ - class SaslConversation##CLASS_SUFFIX : public SaslConversation { \ - public: \ - SaslConversation##CLASS_SUFFIX() : SaslConversation(MECH_NAME) {} \ +#define DEFINE_MECHANISM_FIXTURE(CLASS_SUFFIX, MECH_NAME) \ + class SaslConversation##CLASS_SUFFIX : public SaslConversation { \ + public: \ + SaslConversation##CLASS_SUFFIX() : SaslConversation(MECH_NAME) { \ + auto swServer = registry.getServerMechanism(MECH_NAME, "test"); \ + ASSERT_OK(swServer.getStatus()); \ + server = std::move(swServer.getValue()); \ + } \ } #define DEFINE_MECHANISM_TEST(FIXTURE_NAME, TEST_NAME) \ @@ -236,7 +249,9 @@ TEST_F(SaslIllegalConversation, IllegalClientMechanism) { } TEST_F(SaslIllegalConversation, IllegalServerMechanism) { - ASSERT_NOT_OK(server->start("test", "FAKE", mockServiceName, mockHostName, 1, true)); + SASLServerMechanismRegistry registry; + auto swServer = registry.getServerMechanism("FAKE", "test"); + ASSERT_NOT_OK(swServer.getStatus()); } } // namespace -- cgit v1.2.1