/**
* Copyright (C) 2015 MongoDB Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, version 3,
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* As a special exception, the copyright holders give permission to link the
* code of portions of this program with the OpenSSL library under certain
* conditions as described in each individual source file and distribute
* linked combinations including the program with the OpenSSL library. You
* must comply with the GNU Affero General Public License in all respects for
* all of the code used other than as permitted herein. If you modify file(s)
* with this exception, you may extend this exception to your version of the
* file(s), but you are not obligated to do so. If you do not wish to do so,
* delete this exception statement from your version. If you delete this
* exception statement from all source files in the program, then also delete
* it in the license file.
*/
#include "mongo/platform/basic.h"
#include "mongo/s/sharding_test_fixture.h"
#include
#include
#include "mongo/base/status_with.h"
#include "mongo/bson/simple_bsonobj_comparator.h"
#include "mongo/client/remote_command_targeter_factory_mock.h"
#include "mongo/client/remote_command_targeter_mock.h"
#include "mongo/db/client.h"
#include "mongo/db/commands.h"
#include "mongo/db/namespace_string.h"
#include "mongo/db/query/collation/collator_factory_mock.h"
#include "mongo/db/query/query_request.h"
#include "mongo/db/repl/read_concern_args.h"
#include "mongo/db/s/sharding_task_executor.h"
#include "mongo/db/service_context_noop.h"
#include "mongo/executor/network_interface_mock.h"
#include "mongo/executor/task_executor_pool.h"
#include "mongo/executor/thread_pool_task_executor_test_fixture.h"
#include "mongo/rpc/metadata/repl_set_metadata.h"
#include "mongo/rpc/metadata/tracking_metadata.h"
#include "mongo/s/balancer_configuration.h"
#include "mongo/s/catalog/dist_lock_manager_mock.h"
#include "mongo/s/catalog/sharding_catalog_client_impl.h"
#include "mongo/s/catalog/sharding_catalog_manager.h"
#include "mongo/s/catalog/type_changelog.h"
#include "mongo/s/catalog/type_collection.h"
#include "mongo/s/catalog/type_shard.h"
#include "mongo/s/catalog_cache.h"
#include "mongo/s/client/shard_factory.h"
#include "mongo/s/client/shard_registry.h"
#include "mongo/s/client/shard_remote.h"
#include "mongo/s/grid.h"
#include "mongo/s/query/cluster_cursor_manager.h"
#include "mongo/s/set_shard_version_request.h"
#include "mongo/s/sharding_egress_metadata_hook_for_mongos.h"
#include "mongo/s/write_ops/batched_command_request.h"
#include "mongo/s/write_ops/batched_command_response.h"
#include "mongo/stdx/memory.h"
#include "mongo/transport/mock_session.h"
#include "mongo/transport/transport_layer.h"
#include "mongo/transport/transport_layer_mock.h"
#include "mongo/util/clock_source_mock.h"
#include "mongo/util/tick_source_mock.h"
namespace mongo {
// Names pulled into namespace mongo so the fixture code below can use them unqualified.
// Each name is declared exactly once.
using executor::NetworkInterfaceMock;
using executor::NetworkTestEnv;
using executor::RemoteCommandRequest;
using executor::RemoteCommandResponse;
using executor::ShardingTaskExecutor;
using rpc::ShardingEgressMetadataHookForMongos;
using std::string;
using std::vector;
using unittest::assertGet;
namespace {
// Builds the executor object used by this fixture from a network interface.
// The interface is first wrapped by makeThreadPoolTestExecutor(); the resulting executor is then
// moved into the object constructed by stdx::make_unique, which is what the caller receives.
// Ownership of `net` is transferred in (it is moved from).
std::unique_ptr makeShardingTestExecutor(
std::unique_ptr net) {
auto testExecutor = makeThreadPoolTestExecutor(std::move(net));
return stdx::make_unique(std::move(testExecutor));
}
}  // namespace
// Compiler-generated constructor and destructor; the fixture's members are populated in setUp()
// and released in tearDown().
ShardingTestFixture::ShardingTestFixture() = default;
ShardingTestFixture::~ShardingTestFixture() = default;
// Timeout of 5 Seconds; setupShards() passes it to future.timed_get().
const Seconds ShardingTestFixture::kFutureTimeout{5};
// Builds the environment every test in this fixture runs against:
//   1. a service context (installed as the global one) with clock sources, a tick source and a
//      transport layer;
//   2. a transport session, client and operation context;
//   3. two executors, each with its own network interface and NetworkTestEnv: the "fixed" one and
//      one that is placed in the executor pool;
//   4. a dist lock manager, catalog client, targeter factory/config targeter, shard factory and
//      shard registry;
//   5. the Grid, initialized with all of the above.
// Raw pointers to several of these objects are stored in members before ownership is moved
// elsewhere, so the accessors below can hand them out.
void ShardingTestFixture::setUp() {
{
auto service = stdx::make_unique();
service->setFastClockSource(stdx::make_unique());
service->setPreciseClockSource(stdx::make_unique());
service->setTickSource(stdx::make_unique());
auto tlMock = stdx::make_unique();
// Keep a raw pointer: ownership of the transport layer moves to the service on the next line.
_transportLayer = tlMock.get();
service->setTransportLayer(std::move(tlMock));
// The status returned by start() is deliberately ignored.
_transportLayer->start().transitional_ignore();
// Set the newly created service context to be the current global context so that tests,
// which invoke code still referencing getGlobalServiceContext will work properly.
setGlobalServiceContext(std::move(service));
}
CollatorFactoryInterface::set(serviceContext(), stdx::make_unique());
// Session -> client -> operation context; tearDown() resets these three members.
_transportSession = transport::MockSession::create(_transportLayer);
_client = serviceContext()->makeClient("ShardingTestFixture", _transportSession);
_opCtx = _client->makeOperationContext();
// Set up executor pool used for most operations.
auto fixedNet = stdx::make_unique();
fixedNet->setEgressMetadataHook(stdx::make_unique());
// Raw pointers to the fixed network/executor are captured before ownership is moved away.
_mockNetwork = fixedNet.get();
auto fixedExec = makeShardingTestExecutor(std::move(fixedNet));
_networkTestEnv = stdx::make_unique(fixedExec.get(), _mockNetwork);
_executor = fixedExec.get();
// Second network interface + executor, destined for the pool's executor list.
auto netForPool = stdx::make_unique();
netForPool->setEgressMetadataHook(stdx::make_unique());
// NOTE(review): despite the member-style leading underscore this is a function-local variable;
// no member is assigned here. Consider renaming it to avoid confusion.
auto _mockNetworkForPool = netForPool.get();
auto execForPool = makeShardingTestExecutor(std::move(netForPool));
_networkTestEnvForPool =
stdx::make_unique(execForPool.get(), _mockNetworkForPool);
std::vector> executorsForPool;
executorsForPool.emplace_back(std::move(execForPool));
auto executorPool = stdx::make_unique();
// The pool takes ownership of both the pooled executor list and the fixed executor.
executorPool->addExecutors(std::move(executorsForPool), std::move(fixedExec));
auto uniqueDistLockManager = stdx::make_unique(nullptr);
_distLockManager = uniqueDistLockManager.get();
// The catalog client takes ownership of the dist lock manager.
std::unique_ptr catalogClient(
stdx::make_unique(std::move(uniqueDistLockManager)));
_catalogClient = catalogClient.get();
catalogClient->startup();
// Connection string of the config server replica set used by the fixture.
ConnectionString configCS = ConnectionString::forReplicaSet(
"configRS", {HostAndPort{"TestHost1"}, HostAndPort{"TestHost2"}});
auto targeterFactory(stdx::make_unique());
// The builders below capture this raw pointer because ownership of the factory itself is moved
// into the shard factory further down.
auto targeterFactoryPtr = targeterFactory.get();
_targeterFactory = targeterFactoryPtr;
auto configTargeter(stdx::make_unique());
_configTargeter = configTargeter.get();
// Register the config targeter so the factory returns it for configCS.
_targeterFactory->addTargeterToReturn(configCS, std::move(configTargeter));
// Two builders with identical bodies, registered for the SET and MASTER connection string types:
// each constructs a shard object from the shard id, the connection string and a targeter created
// by the factory for that connection string.
ShardFactory::BuilderCallable setBuilder =
[targeterFactoryPtr](const ShardId& shardId, const ConnectionString& connStr) {
return stdx::make_unique(
shardId, connStr, targeterFactoryPtr->create(connStr));
};
ShardFactory::BuilderCallable masterBuilder =
[targeterFactoryPtr](const ShardId& shardId, const ConnectionString& connStr) {
return stdx::make_unique(
shardId, connStr, targeterFactoryPtr->create(connStr));
};
ShardFactory::BuildersMap buildersMap{
{ConnectionString::SET, std::move(setBuilder)},
{ConnectionString::MASTER, std::move(masterBuilder)},
};
auto shardFactory =
stdx::make_unique(std::move(buildersMap), std::move(targeterFactory));
auto shardRegistry(stdx::make_unique(std::move(shardFactory), configCS));
// Start the pool before handing it to the Grid.
executorPool->startup();
// For now initialize the global grid object. All sharding objects will be accessible from there
// until we get rid of it.
Grid::get(operationContext())
->init(std::move(catalogClient),
nullptr,
stdx::make_unique(),
std::move(shardRegistry),
stdx::make_unique(serviceContext()->getPreciseClockSource()),
stdx::make_unique(),
std::move(executorPool),
_mockNetwork);
}
// Undoes setUp(): shuts down and joins the executor pool, shuts down the catalog client, clears
// the Grid, and finally releases the session, operation context and client (in that order).
void ShardingTestFixture::tearDown() {
    auto* const grid = Grid::get(operationContext());

    grid->getExecutorPool()->shutdownAndJoin();
    grid->catalogClient()->shutDown(_opCtx.get());
    grid->clearForUnitTests();

    _transportSession.reset();
    _opCtx.reset();
    _client.reset();
}
// Shuts down the fixed executor. Does nothing when no executor has been recorded.
void ShardingTestFixture::shutdownExecutor() {
    if (!_executor) {
        return;
    }
    _executor->shutdown();
}
// Returns the catalog client held by the Grid of the current operation context.
ShardingCatalogClient* ShardingTestFixture::catalogClient() const {
    auto* const grid = Grid::get(operationContext());
    return grid->catalogClient();
}
// Returns the raw catalog client pointer recorded in setUp(). Unlike the other raw-pointer
// accessors in this file, no invariant is checked, so the result may be null before setUp().
ShardingCatalogClientImpl* ShardingTestFixture::getCatalogClient() const {
return _catalogClient;
}
// Returns the shard registry held by the Grid of the current operation context.
ShardRegistry* ShardingTestFixture::shardRegistry() const {
    auto* const grid = Grid::get(operationContext());
    return grid->shardRegistry();
}
// Returns the targeter factory recorded in setUp(); trips an invariant if it is null.
RemoteCommandTargeterFactoryMock* ShardingTestFixture::targeterFactory() const {
invariant(_targeterFactory);
return _targeterFactory;
}
// Returns the config server targeter recorded in setUp(); trips an invariant if it is null.
RemoteCommandTargeterMock* ShardingTestFixture::configTargeter() const {
invariant(_configTargeter);
return _configTargeter;
}
// Returns the network interface of the fixed executor (the one also handed to Grid::init in
// setUp()); trips an invariant if it is null.
executor::NetworkInterfaceMock* ShardingTestFixture::network() const {
invariant(_mockNetwork);
return _mockNetwork;
}
// Returns the fixed executor recorded in setUp(); trips an invariant if it is null.
executor::TaskExecutor* ShardingTestFixture::executor() const {
invariant(_executor);
return _executor;
}
// Returns the dist lock manager recorded in setUp(); trips an invariant if it is null.
DistLockManagerMock* ShardingTestFixture::distLock() const {
invariant(_distLockManager);
return _distLockManager;
}
// Returns the global service context (setUp() installs the one it creates as the global).
ServiceContext* ShardingTestFixture::serviceContext() const {
return getGlobalServiceContext();
}
// Returns the operation context created in setUp(); trips an invariant if it has not been created
// or has already been reset by tearDown().
OperationContext* ShardingTestFixture::operationContext() const {
invariant(_opCtx);
return _opCtx.get();
}
// Forwards `func` to the fixed executor's NetworkTestEnv::onCommand.
// `func` is a by-value parameter that is not used again here, so it is moved into the callee
// rather than copied a second time.
void ShardingTestFixture::onCommand(NetworkTestEnv::OnCommandFunction func) {
    _networkTestEnv->onCommand(std::move(func));
}
// Forwards `func` to the fixed executor's NetworkTestEnv::onCommandWithMetadata.
// `func` is a by-value parameter that is not used again here, so it is moved into the callee
// rather than copied a second time.
void ShardingTestFixture::onCommandWithMetadata(
    NetworkTestEnv::OnCommandWithMetadataFunction func) {
    _networkTestEnv->onCommandWithMetadata(std::move(func));
}
// Forwards `func` to the fixed executor's NetworkTestEnv::onFindCommand.
// `func` is a by-value parameter that is not used again here, so it is moved into the callee
// rather than copied a second time.
void ShardingTestFixture::onFindCommand(NetworkTestEnv::OnFindCommandFunction func) {
    _networkTestEnv->onFindCommand(std::move(func));
}
// Forwards `func` to the fixed executor's NetworkTestEnv::onFindWithMetadataCommand.
// `func` is a by-value parameter that is not used again here, so it is moved into the callee
// rather than copied a second time.
void ShardingTestFixture::onFindWithMetadataCommand(
    NetworkTestEnv::OnFindCommandWithMetadataFunction func) {
    _networkTestEnv->onFindWithMetadataCommand(std::move(func));
}
// Same as onCommand(), but targets the NetworkTestEnv of the executor that lives in the pool.
// `func` is a by-value parameter that is not used again here, so it is moved into the callee
// rather than copied a second time.
void ShardingTestFixture::onCommandForPoolExecutor(NetworkTestEnv::OnCommandFunction func) {
    _networkTestEnvForPool->onCommand(std::move(func));
}
// Makes the shard registry reload itself and answers the resulting shard lookup with `shards`.
// The reload is launched asynchronously; this thread then registers the expected find request via
// expectGetShards() and finally waits (up to kFutureTimeout) for the reload to finish.
void ShardingTestFixture::setupShards(const std::vector& shards) {
auto future = launchAsync([this] { shardRegistry()->reload(operationContext()); });
expectGetShards(shards);
future.timed_get(kFutureTimeout);
}
// Expects one find command against ShardType::ConfigNS with an empty filter, an empty sort and no
// limit, carrying the read concern verified by checkReadConcern(). Replies with the BSON form of
// every entry in `shards`.
void ShardingTestFixture::expectGetShards(const std::vector& shards) {
// `shards` is captured by reference and read when the handler runs.
onFindCommand([this, &shards](const RemoteCommandRequest& request) {
// The namespace is the request's database plus the value of the command's first element.
const NamespaceString nss(request.dbname, request.cmdObj.firstElement().String());
ASSERT_EQ(nss.toString(), ShardType::ConfigNS);
auto queryResult = QueryRequest::makeFromFindCommand(nss, request.cmdObj, false);
ASSERT_OK(queryResult.getStatus());
const auto& query = queryResult.getValue();
ASSERT_EQ(query->ns(), ShardType::ConfigNS);
ASSERT_BSONOBJ_EQ(query->getFilter(), BSONObj());
ASSERT_BSONOBJ_EQ(query->getSort(), BSONObj());
ASSERT_FALSE(query->getLimit().is_initialized());
checkReadConcern(request.cmdObj, Timestamp(0, 0), repl::OpTime::kUninitializedTerm);
// Convert each shard to BSON, preserving the input order.
vector shardsToReturn;
std::transform(shards.begin(),
shards.end(),
std::back_inserter(shardsToReturn),
[](const ShardType& shard) { return shard.toBSON(); });
return shardsToReturn;
});
}
// Expects one batched insert command into `nss` whose documents equal `expected`, element by
// element and in the same order. Replies with an OK BatchedCommandResponse.
void ShardingTestFixture::expectInserts(const NamespaceString& nss,
const std::vector& expected) {
onCommand([&nss, &expected](const RemoteCommandRequest& request) {
ASSERT_EQUALS(nss.db(), request.dbname);
BatchedInsertRequest actualBatchedInsert;
actualBatchedInsert.parseRequest(
OpMsgRequest::fromDBAndBody(request.dbname, request.cmdObj));
ASSERT_EQUALS(nss.toString(), actualBatchedInsert.getNS().toString());
auto inserted = actualBatchedInsert.getDocuments();
// Sizes are asserted equal first, so the lock-step iteration below stays within `expected`.
ASSERT_EQUALS(expected.size(), inserted.size());
auto itInserted = inserted.begin();
auto itExpected = expected.begin();
for (; itInserted != inserted.end(); itInserted++, itExpected++) {
ASSERT_BSONOBJ_EQ(*itExpected, *itInserted);
}
BatchedCommandResponse response;
response.setOk(true);
return response.toBSON();
});
}
// Expects one "create" command sent to `configHost` on the "config" database that creates
// `collName` as a capped collection of `cappedSize`, with write concern
// {w: "majority", wtimeout: 15000} and maxTimeMS 30000. Replies with `response`.
void ShardingTestFixture::expectConfigCollectionCreate(const HostAndPort& configHost,
                                                       StringData collName,
                                                       int cappedSize,
                                                       const BSONObj& response) {
    onCommand([&configHost, &collName, &cappedSize, &response](
                  const RemoteCommandRequest& request) {
        ASSERT_EQUALS(configHost, request.target);
        ASSERT_EQUALS("config", request.dbname);

        // Build the expected command piecewise: write concern first, then the full command.
        const BSONObj writeConcern = BSON("w"
                                          << "majority"
                                          << "wtimeout"
                                          << 15000);
        const BSONObj expectedCreateCmd = BSON("create" << collName << "capped" << true << "size"
                                                        << cappedSize
                                                        << "writeConcern"
                                                        << writeConcern
                                                        << "maxTimeMS"
                                                        << 30000);
        ASSERT_BSONOBJ_EQ(expectedCreateCmd, request.cmdObj);

        return response;
    });
}
// Expects one batched insert of a single change-log style document into config.<collName> on
// `configHost`. The document must carry this client's address, `detail`, `ns`, the mock
// network's host name, `timestamp` and `what`, and a change id of the form
// "<server>-<time>-<oid>". Replies with an OK BatchedCommandResponse.
void ShardingTestFixture::expectConfigCollectionInsert(const HostAndPort& configHost,
                                                       StringData collName,
                                                       Date_t timestamp,
                                                       const std::string& what,
                                                       const std::string& ns,
                                                       const BSONObj& detail) {
    onCommand([&](const RemoteCommandRequest& request) {
        ASSERT_EQUALS(configHost, request.target);
        ASSERT_EQUALS("config", request.dbname);

        BatchedInsertRequest insertRequest;
        insertRequest.parseRequest(OpMsgRequest::fromDBAndBody(request.dbname, request.cmdObj));
        ASSERT_EQ("config", insertRequest.getNS().db());
        ASSERT_EQ(collName, insertRequest.getNS().coll());

        auto insertedDocs = insertRequest.getDocuments();
        ASSERT_EQUALS(1U, insertedDocs.size());

        const ChangeLogType& loggedEntry =
            assertGet(ChangeLogType::fromBSON(insertedDocs.front()));
        ASSERT_EQUALS(operationContext()->getClient()->clientAddress(true),
                      loggedEntry.getClientAddr());
        ASSERT_BSONOBJ_EQ(detail, loggedEntry.getDetails());
        ASSERT_EQUALS(ns, loggedEntry.getNS());
        ASSERT_EQUALS(network()->getHostName(), loggedEntry.getServer());
        ASSERT_EQUALS(timestamp, loggedEntry.getTime());
        ASSERT_EQUALS(what, loggedEntry.getWhat());

        // The change id embeds a freshly generated OID that cannot be predicted, so it is split
        // on its first and last dash and each piece is checked separately.
        const std::string changeId = loggedEntry.getChangeId();
        const size_t firstDash = changeId.find('-');
        const size_t lastDash = changeId.rfind('-');

        const std::string serverPiece = changeId.substr(0, firstDash);
        const std::string timePiece = changeId.substr(firstDash + 1, lastDash - firstDash - 1);
        const std::string oidPiece = changeId.substr(lastDash + 1);

        ASSERT_EQUALS(Grid::get(operationContext())->getNetwork()->getHostName(), serverPiece);
        ASSERT_EQUALS(timestamp.toString(), timePiece);

        // Only verify that the OID piece can be parsed without throwing; its value is not checked.
        OID generatedOID;
        generatedOID.init(oidPiece);

        BatchedCommandResponse reply;
        reply.setOk(true);
        return reply.toBSON();
    });
}
// Expects the creation of the capped "changelog" config collection (10 * 1024 * 1024 capped size)
// on `configHost` and replies with `response`.
void ShardingTestFixture::expectChangeLogCreate(const HostAndPort& configHost,
                                                const BSONObj& response) {
    constexpr int kChangeLogCappedSize = 10 * 1024 * 1024;
    expectConfigCollectionCreate(configHost, "changelog", kChangeLogCappedSize, response);
}
// Expects a single-document insert into the "changelog" config collection on `configHost`.
// All checks are delegated to expectConfigCollectionInsert() with the collection name fixed.
void ShardingTestFixture::expectChangeLogInsert(const HostAndPort& configHost,
Date_t timestamp,
const std::string& what,
const std::string& ns,
const BSONObj& detail) {
expectConfigCollectionInsert(configHost, "changelog", timestamp, what, ns, detail);
}
// Expects one batched update command sent to `expectedHost` on the "config" database, targeting
// CollectionType::ConfigNS. It must contain exactly one non-multi update whose upsert flag equals
// `expectUpsert`, whose query selects the namespace of `coll`, and whose update expression is
// coll.toBSON(). Replies OK with nModified = 1.
void ShardingTestFixture::expectUpdateCollection(const HostAndPort& expectedHost,
                                                 const CollectionType& coll,
                                                 bool expectUpsert) {
    onCommand([&expectedHost, &coll, &expectUpsert](const RemoteCommandRequest& request) {
        ASSERT_EQUALS(expectedHost, request.target);
        ASSERT_BSONOBJ_EQ(BSON(rpc::kReplSetMetadataFieldName << 1),
                          rpc::TrackingMetadata::removeTrackingData(request.metadata));
        ASSERT_EQUALS("config", request.dbname);

        BatchedUpdateRequest parsedUpdate;
        parsedUpdate.parseRequest(OpMsgRequest::fromDBAndBody(request.dbname, request.cmdObj));
        ASSERT_EQUALS(CollectionType::ConfigNS, parsedUpdate.getNS().ns());

        auto updateEntries = parsedUpdate.getUpdates();
        ASSERT_EQUALS(1U, updateEntries.size());

        auto onlyUpdate = updateEntries.front();
        ASSERT_EQ(expectUpsert, onlyUpdate->getUpsert());
        ASSERT_FALSE(onlyUpdate->getMulti());
        ASSERT_BSONOBJ_EQ(onlyUpdate->getQuery(),
                          BSON(CollectionType::fullNs(coll.getNs().toString())));
        ASSERT_BSONOBJ_EQ(onlyUpdate->getUpdateExpr(), coll.toBSON());

        BatchedCommandResponse reply;
        reply.setOk(true);
        reply.setNModified(1);
        return reply.toBSON();
    });
}
// Expects one setShardVersion request sent to `expectedHost` with empty metadata (after tracking
// data is removed). The parsed request must be authoritative, must not be an init request, and
// must name the host of `expectedShard`, `expectedNs` and `expectedChunkVersion`.
// Replies with {ok: true}.
void ShardingTestFixture::expectSetShardVersion(const HostAndPort& expectedHost,
                                                const ShardType& expectedShard,
                                                const NamespaceString& expectedNs,
                                                const ChunkVersion& expectedChunkVersion) {
    onCommand([&expectedHost, &expectedShard, &expectedNs, &expectedChunkVersion](
                  const RemoteCommandRequest& request) {
        ASSERT_EQ(expectedHost, request.target);
        ASSERT_BSONOBJ_EQ(rpc::makeEmptyMetadata(),
                          rpc::TrackingMetadata::removeTrackingData(request.metadata));

        SetShardVersionRequest parsedRequest =
            assertGet(SetShardVersionRequest::parseFromBSON(request.cmdObj));

        ASSERT(!parsedRequest.isInit());
        ASSERT(parsedRequest.isAuthoritative());
        ASSERT_EQ(expectedShard.getHost(), parsedRequest.getShardConnectionString().toString());
        ASSERT_EQ(expectedNs.toString(), parsedRequest.getNS().ns());
        ASSERT_EQ(expectedChunkVersion.toString(), parsedRequest.getNSVersion().toString());

        return BSON("ok" << true);
    });
}
// Expects one "count" command sent to `configHost` for `expectedNs` whose "query" field matches
// `expectedQuery`. When `response` is OK the reply is {ok: 1, n: <value>}; otherwise the reply is
// built from the response's status via Command::appendCommandStatus.
void ShardingTestFixture::expectCount(const HostAndPort& configHost,
const NamespaceString& expectedNs,
const BSONObj& expectedQuery,
const StatusWith& response) {
onCommand([&](const RemoteCommandRequest& request) {
ASSERT_EQUALS(configHost, request.target);
string cmdName = request.cmdObj.firstElement().fieldName();
ASSERT_EQUALS("count", cmdName);
// The namespace is the request's database plus the value of the command's first element.
const NamespaceString nss(request.dbname, request.cmdObj.firstElement().String());
ASSERT_EQUALS(expectedNs.toString(), nss.toString());
// An empty expected query matches both a missing "query" field and an empty "query" object.
if (expectedQuery.isEmpty()) {
auto queryElem = request.cmdObj["query"];
ASSERT_TRUE(queryElem.eoo() || queryElem.Obj().isEmpty());
} else {
ASSERT_BSONOBJ_EQ(expectedQuery, request.cmdObj["query"].Obj());
}
if (response.isOK()) {
return BSON("ok" << 1 << "n" << response.getValue());
}
// NOTE(review): because of the early return above, the read concern is only verified when
// `response` is not OK. Confirm whether it was meant to be checked on the success path too.
checkReadConcern(request.cmdObj, Timestamp(0, 0), repl::OpTime::kUninitializedTerm);
BSONObjBuilder responseBuilder;
Command::appendCommandStatus(responseBuilder, response.getStatus());
return responseBuilder.obj();
});
}
// Replaces the fixture's transport session with a new mock session created for `remote`, a
// default-constructed HostAndPort and the fixture's transport layer.
// NOTE(review): only _transportSession is reassigned here; _client and _opCtx are left untouched.
void ShardingTestFixture::setRemote(const HostAndPort& remote) {
_transportSession = transport::MockSession::create(remote, HostAndPort{}, _transportLayer);
}
// Asserts that `cmdObj` carries a read concern sub-object with level "majority" and an
// after-optime sub-object whose timestamp equals `expectedTS` and whose term equals
// `expectedTerm`.
void ShardingTestFixture::checkReadConcern(const BSONObj& cmdObj,
                                           const Timestamp& expectedTS,
                                           long long expectedTerm) const {
    // The read concern itself must be an embedded object.
    const auto rcElem = cmdObj[repl::ReadConcernArgs::kReadConcernFieldName];
    ASSERT_EQ(Object, rcElem.type());

    const auto rcObj = rcElem.Obj();
    ASSERT_EQ("majority", rcObj[repl::ReadConcernArgs::kLevelFieldName].str());

    // So must the after-optime entry inside it.
    const auto afterOpTimeElem = rcObj[repl::ReadConcernArgs::kAfterOpTimeFieldName];
    ASSERT_EQ(Object, afterOpTimeElem.type());

    const auto afterOpTimeObj = afterOpTimeElem.Obj();

    ASSERT_TRUE(afterOpTimeObj.hasField(repl::OpTime::kTimestampFieldName));
    ASSERT_EQ(expectedTS, afterOpTimeObj[repl::OpTime::kTimestampFieldName].timestamp());

    ASSERT_TRUE(afterOpTimeObj.hasField(repl::OpTime::kTermFieldName));
    ASSERT_EQ(expectedTerm, afterOpTimeObj[repl::OpTime::kTermFieldName].numberLong());
}
} // namespace mongo