summaryrefslogtreecommitdiff
path: root/src/mongo
diff options
context:
space:
mode:
authorMartin Neupauer <martin.neupauer@mongodb.com>2018-08-09 17:09:08 -0400
committerMartin Neupauer <martin.neupauer@mongodb.com>2018-08-30 14:17:06 -0400
commit47306b9f203abee01f6fc54aa8d7ab8f8e25c8c9 (patch)
tree9d1734d0958b5f07afd6dad4adede420696fba3a /src/mongo
parentb46de3f6c06fab5cf9b7ea0f4176b32ff544a4bf (diff)
downloadmongo-47306b9f203abee01f6fc54aa8d7ab8f8e25c8c9.tar.gz
SERVER-35905 Plug pieces together to perform a distributed when applicable
Diffstat (limited to 'src/mongo')
-rw-r--r--src/mongo/db/commands/run_aggregate.cpp66
-rw-r--r--src/mongo/db/pipeline/SConscript1
-rw-r--r--src/mongo/db/pipeline/aggregation_request.cpp11
-rw-r--r--src/mongo/db/pipeline/aggregation_request.h15
-rw-r--r--src/mongo/db/pipeline/aggregation_request_test.cpp16
-rw-r--r--src/mongo/db/pipeline/document_source_exchange.cpp88
-rw-r--r--src/mongo/db/pipeline/document_source_exchange.h46
-rw-r--r--src/mongo/db/pipeline/document_source_exchange_test.cpp247
-rw-r--r--src/mongo/db/query/cursor_response.cpp26
-rw-r--r--src/mongo/db/query/cursor_response.h6
-rw-r--r--src/mongo/s/query/cluster_aggregate.cpp107
-rw-r--r--src/mongo/s/query/establish_cursors.cpp24
12 files changed, 436 insertions, 217 deletions
diff --git a/src/mongo/db/commands/run_aggregate.cpp b/src/mongo/db/commands/run_aggregate.cpp
index 034f2da1826..958c33ffbfc 100644
--- a/src/mongo/db/commands/run_aggregate.cpp
+++ b/src/mongo/db/commands/run_aggregate.cpp
@@ -318,6 +318,26 @@ std::unique_ptr<CollatorInterface> resolveCollator(OperationContext* opCtx,
? collection->getDefaultCollator()->clone()
: nullptr);
}
+
+boost::intrusive_ptr<ExpressionContext> makeExpressionContext(
+ OperationContext* opCtx,
+ const AggregationRequest& request,
+ std::unique_ptr<CollatorInterface> collator,
+ boost::optional<UUID> uuid) {
+ boost::intrusive_ptr<ExpressionContext> expCtx =
+ new ExpressionContext(opCtx,
+ request,
+ std::move(collator),
+ MongoDInterface::create(opCtx),
+ uassertStatusOK(resolveInvolvedNamespaces(opCtx, request)),
+ uuid);
+ expCtx->tempDir = storageGlobalParams.dbpath + "/_tmp";
+ auto txnParticipant = TransactionParticipant::get(opCtx);
+ expCtx->inMultiDocumentTransaction =
+ txnParticipant && txnParticipant->inMultiDocumentTransaction();
+
+ return expCtx;
+}
} // namespace
Status runAggregate(OperationContext* opCtx,
@@ -462,17 +482,7 @@ Status runAggregate(OperationContext* opCtx,
}
invariant(collatorToUse);
- expCtx.reset(
- new ExpressionContext(opCtx,
- request,
- std::move(*collatorToUse),
- MongoDInterface::create(opCtx),
- uassertStatusOK(resolveInvolvedNamespaces(opCtx, request)),
- uuid));
- expCtx->tempDir = storageGlobalParams.dbpath + "/_tmp";
- auto txnParticipant = TransactionParticipant::get(opCtx);
- expCtx->inMultiDocumentTransaction =
- txnParticipant && txnParticipant->inMultiDocumentTransaction();
+ expCtx = makeExpressionContext(opCtx, request, std::move(*collatorToUse), uuid);
auto pipeline = uassertStatusOK(Pipeline::parse(request.getPipeline(), expCtx));
@@ -506,21 +516,31 @@ Status runAggregate(OperationContext* opCtx,
std::vector<std::unique_ptr<Pipeline, PipelineDeleter>> pipelines;
- pipelines.emplace_back(std::move(pipeline));
-
- auto exchange =
- dynamic_cast<DocumentSourceExchange*>(pipelines[0]->getSources().back().get());
- if (exchange) {
- for (size_t idx = 1; idx < exchange->getConsumers(); ++idx) {
- auto sources = pipelines[0]->getSources();
- sources.back() = new DocumentSourceExchange(expCtx, exchange->getExchange(), idx);
- pipelines.emplace_back(
- uassertStatusOK(Pipeline::create(std::move(sources), expCtx)));
+ if (request.getExchangeSpec()) {
+ boost::intrusive_ptr<Exchange> exchange =
+ new Exchange(request.getExchangeSpec().get(), std::move(pipeline));
+
+ for (size_t idx = 0; idx < exchange->getConsumers(); ++idx) {
+ // For every new pipeline we have create a new ExpressionContext as the context
+ // cannot be shared between threads. There is no synchronization for pieces of the
+ // execution machinery above the Exchange, so nothing above the Exchange can be
+ // shared between different exchange-producer cursors.
+ expCtx = makeExpressionContext(
+ opCtx,
+ request,
+ expCtx->getCollator() ? expCtx->getCollator()->clone() : nullptr,
+ uuid);
+
+ // Create a new pipeline for the consumer consisting of a single
+ // DocumentSourceExchange.
+ boost::intrusive_ptr<DocumentSource> consumer =
+ new DocumentSourceExchange(expCtx, exchange, idx);
+ pipelines.emplace_back(uassertStatusOK(Pipeline::create({consumer}, expCtx)));
}
+ } else {
+ pipelines.emplace_back(std::move(pipeline));
}
- // TODO we will revisit the current vector of pipelines design when we will implement
- // plan summaries, explains, etc.
for (size_t idx = 0; idx < pipelines.size(); ++idx) {
// Transfer ownership of the Pipeline to the PipelineProxyStage.
auto ws = make_unique<WorkingSet>();
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index d26d429a29a..5999ec5e3ac 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -92,6 +92,7 @@ env.Library(
'$BUILD_DIR/mongo/db/repl/read_concern_args',
'$BUILD_DIR/mongo/db/storage/storage_options',
'$BUILD_DIR/mongo/db/write_concern_options',
+ 'document_sources_idl',
'document_value',
]
)
diff --git a/src/mongo/db/pipeline/aggregation_request.cpp b/src/mongo/db/pipeline/aggregation_request.cpp
index 4f86b1c2a81..c93ab25f2f1 100644
--- a/src/mongo/db/pipeline/aggregation_request.cpp
+++ b/src/mongo/db/pipeline/aggregation_request.cpp
@@ -59,6 +59,7 @@ constexpr StringData AggregationRequest::kExplainName;
constexpr StringData AggregationRequest::kAllowDiskUseName;
constexpr StringData AggregationRequest::kHintName;
constexpr StringData AggregationRequest::kCommentName;
+constexpr StringData AggregationRequest::kExchangeName;
constexpr long long AggregationRequest::kDefaultBatchSize;
@@ -213,6 +214,13 @@ StatusWith<AggregationRequest> AggregationRequest::parseFromBSON(
<< typeName(elem.type())};
}
request.setAllowDiskUse(elem.Bool());
+ } else if (kExchangeName == fieldName) {
+ try {
+ IDLParserErrorContext ctx("internalExchange");
+ request.setExchangeSpec(ExchangeSpec::parse(ctx, elem.Obj()));
+ } catch (const DBException& ex) {
+ return ex.toStatus();
+ }
} else if (bypassDocumentValidationCommandOption() == fieldName) {
request.setBypassDocumentValidation(elem.trueValue());
} else if (!isGenericArgument(fieldName)) {
@@ -313,7 +321,8 @@ Document AggregationRequest::serializeToCommandObj() const {
_unwrappedReadPref.isEmpty() ? Value() : Value(_unwrappedReadPref)},
// Only serialize maxTimeMs if specified.
{QueryRequest::cmdOptionMaxTimeMS,
- _maxTimeMS == 0 ? Value() : Value(static_cast<int>(_maxTimeMS))}};
+ _maxTimeMS == 0 ? Value() : Value(static_cast<int>(_maxTimeMS))},
+ {kExchangeName, _exchangeSpec ? Value(_exchangeSpec->toBSON()) : Value()}};
}
} // namespace mongo
diff --git a/src/mongo/db/pipeline/aggregation_request.h b/src/mongo/db/pipeline/aggregation_request.h
index 3c0700a8375..53513f22bb1 100644
--- a/src/mongo/db/pipeline/aggregation_request.h
+++ b/src/mongo/db/pipeline/aggregation_request.h
@@ -34,6 +34,7 @@
#include "mongo/bson/bsonelement.h"
#include "mongo/bson/bsonobj.h"
#include "mongo/db/namespace_string.h"
+#include "mongo/db/pipeline/document_source_exchange_gen.h"
#include "mongo/db/query/explain_options.h"
namespace mongo {
@@ -58,6 +59,7 @@ public:
static constexpr StringData kAllowDiskUseName = "allowDiskUse"_sd;
static constexpr StringData kHintName = "hint"_sd;
static constexpr StringData kCommentName = "comment"_sd;
+ static constexpr StringData kExchangeName = "exchange"_sd;
static constexpr long long kDefaultBatchSize = 101;
@@ -186,6 +188,10 @@ public:
return _unwrappedReadPref;
}
+ const auto& getExchangeSpec() const {
+ return _exchangeSpec;
+ }
+
//
// Setters for optional fields.
//
@@ -242,6 +248,10 @@ public:
_unwrappedReadPref = unwrappedReadPref.getOwned();
}
+ void setExchangeSpec(ExchangeSpec spec) {
+ _exchangeSpec = std::move(spec);
+ }
+
private:
// Required fields.
const NamespaceString _nss;
@@ -282,5 +292,10 @@ private:
// A user-specified maxTimeMS limit, or a value of '0' if not specified.
unsigned int _maxTimeMS = 0;
+
+ // An optional exchange specification for this request. If set it means that the request
+ // represents a producer running as a part of the exchange machinery.
+ // This is an internal option; we do not expect it to be set on requests from users or drivers.
+ boost::optional<ExchangeSpec> _exchangeSpec;
};
} // namespace mongo
diff --git a/src/mongo/db/pipeline/aggregation_request_test.cpp b/src/mongo/db/pipeline/aggregation_request_test.cpp
index a5cbd390e80..441d5a19cee 100644
--- a/src/mongo/db/pipeline/aggregation_request_test.cpp
+++ b/src/mongo/db/pipeline/aggregation_request_test.cpp
@@ -59,7 +59,8 @@ TEST(AggregationRequestTest, ShouldParseAllKnownOptions) {
"{pipeline: [{$match: {a: 'abc'}}], explain: false, allowDiskUse: true, fromMongos: true, "
"needsMerge: true, bypassDocumentValidation: true, collation: {locale: 'en_US'}, cursor: "
"{batchSize: 10}, hint: {a: 1}, maxTimeMS: 100, readConcern: {level: 'linearizable'}, "
- "$queryOptions: {$readPreference: 'nearest'}, comment: 'agg_comment'}}");
+ "$queryOptions: {$readPreference: 'nearest'}, comment: 'agg_comment', exchange: {policy: "
+ "'roundrobin', consumers:NumberInt(2)}}");
auto request = unittest::assertGet(AggregationRequest::parseFromBSON(nss, inputBson));
ASSERT_FALSE(request.getExplain());
ASSERT_TRUE(request.shouldAllowDiskUse());
@@ -79,6 +80,7 @@ TEST(AggregationRequestTest, ShouldParseAllKnownOptions) {
ASSERT_BSONOBJ_EQ(request.getUnwrappedReadPref(),
BSON("$readPreference"
<< "nearest"));
+ ASSERT_TRUE(request.getExchangeSpec().is_initialized());
}
TEST(AggregationRequestTest, ShouldParseExplicitExplainTrue) {
@@ -480,6 +482,18 @@ TEST(AggregationRequestTest, ParseFromBSONOverloadsShouldProduceIdenticalRequest
ASSERT_DOCUMENT_EQ(aggReqDBName.serializeToCommandObj(), aggReqNSS.serializeToCommandObj());
}
+TEST(AggregationRequestTest, ShouldRejectExchangeNotObject) {
+ NamespaceString nss("a.collection");
+ const BSONObj inputBson = fromjson("{pipeline: [], exchage: '42'}");
+ ASSERT_NOT_OK(AggregationRequest::parseFromBSON(nss, inputBson).getStatus());
+}
+
+TEST(AggregationRequestTest, ShouldRejectExchangeInvalidSpec) {
+ NamespaceString nss("a.collection");
+ const BSONObj inputBson = fromjson("{pipeline: [], exchage: {}}");
+ ASSERT_NOT_OK(AggregationRequest::parseFromBSON(nss, inputBson).getStatus());
+}
+
//
// Ignore fields parsed elsewhere.
//
diff --git a/src/mongo/db/pipeline/document_source_exchange.cpp b/src/mongo/db/pipeline/document_source_exchange.cpp
index dddc3d74f0c..49de073af73 100644
--- a/src/mongo/db/pipeline/document_source_exchange.cpp
+++ b/src/mongo/db/pipeline/document_source_exchange.cpp
@@ -41,27 +41,8 @@
namespace mongo {
-REGISTER_DOCUMENT_SOURCE(exchange,
- LiteParsedDocumentSourceDefault::parse,
- DocumentSourceExchange::createFromBson);
-
const char* DocumentSourceExchange::getSourceName() const {
- return "$exchange";
-}
-
-boost::intrusive_ptr<DocumentSource> DocumentSourceExchange::createFromBson(
- BSONElement spec, const boost::intrusive_ptr<ExpressionContext>& pExpCtx) {
- uassert(ErrorCodes::FailedToParse,
- str::stream() << "$exchange options must be specified in an object, but found: "
- << typeName(spec.type()),
- spec.type() == BSONType::Object);
-
- IDLParserErrorContext ctx("$exchange");
- auto parsed = ExchangeSpec::parse(ctx, spec.embeddedObject());
-
- boost::intrusive_ptr<Exchange> exchange = new Exchange(parsed);
-
- return new DocumentSourceExchange(pExpCtx, exchange, 0);
+ return "$_internalExchange";
}
Value DocumentSourceExchange::serialize(boost::optional<ExplainOptions::Verbosity> explain) const {
@@ -75,31 +56,36 @@ DocumentSourceExchange::DocumentSourceExchange(
: DocumentSource(expCtx), _exchange(exchange), _consumerId(consumerId) {}
DocumentSource::GetNextResult DocumentSourceExchange::getNext() {
- return _exchange->getNext(_consumerId);
+ return _exchange->getNext(pExpCtx->opCtx, _consumerId);
}
-Exchange::Exchange(const ExchangeSpec& spec)
- : _spec(spec),
- _keyPattern(spec.getKey().getOwned()),
+Exchange::Exchange(ExchangeSpec spec, std::unique_ptr<Pipeline, PipelineDeleter> pipeline)
+ : _spec(std::move(spec)),
+ _keyPattern(_spec.getKey().getOwned()),
_ordering(extractOrdering(_keyPattern)),
- _boundaries(extractBoundaries(spec.getBoundaries())),
- _consumerIds(extractConsumerIds(spec.getConsumerids(), spec.getConsumers())),
- _policy(spec.getPolicy()),
- _orderPreserving(spec.getOrderPreserving()),
- _maxBufferSize(spec.getBufferSize()) {
- uassert(50901, "$exchange must have at least one consumer", spec.getConsumers() > 0);
-
- for (int idx = 0; idx < spec.getConsumers(); ++idx) {
+ _boundaries(extractBoundaries(_spec.getBoundaries())),
+ _consumerIds(extractConsumerIds(_spec.getConsumerids(), _spec.getConsumers())),
+ _policy(_spec.getPolicy()),
+ _orderPreserving(_spec.getOrderPreserving()),
+ _maxBufferSize(_spec.getBufferSize()),
+ _pipeline(std::move(pipeline)) {
+ uassert(50901, "Exchange must have at least one consumer", _spec.getConsumers() > 0);
+
+ for (int idx = 0; idx < _spec.getConsumers(); ++idx) {
_consumers.emplace_back(std::make_unique<ExchangeBuffer>());
}
if (_policy == ExchangePolicyEnum::kRange || _policy == ExchangePolicyEnum::kHash) {
uassert(50900,
- "$exchange boundaries do not match number of consumers.",
+ "Exchange boundaries do not match number of consumers.",
_boundaries.size() == _consumerIds.size() + 1);
} else {
- uassert(50899, "$exchange boundaries must not be specified.", _boundaries.empty());
+ uassert(50899, "Exchange boundaries must not be specified.", _boundaries.empty());
}
+
+ // We will manually detach and reattach when iterating '_pipeline', we expect it to start in the
+ // detached state.
+ _pipeline->detachFromOperationContext();
}
std::vector<std::string> Exchange::extractBoundaries(
@@ -125,7 +111,7 @@ std::vector<std::string> Exchange::extractBoundaries(
for (size_t idx = 1; idx < ret.size(); ++idx) {
uassert(50893,
- str::stream() << "$exchange range boundaries are not in ascending order.",
+ str::stream() << "Exchange range boundaries are not in ascending order.",
ret[idx - 1] < ret[idx]);
}
return ret;
@@ -151,7 +137,7 @@ std::vector<size_t> Exchange::extractConsumerIds(
}
uassert(50894,
- str::stream() << "$exchange consumers ids are invalid.",
+ str::stream() << "Exchange consumers ids are invalid.",
nConsumers > 0 && validation.size() == nConsumers && *validation.begin() == 0 &&
*validation.rbegin() == nConsumers - 1);
}
@@ -165,29 +151,29 @@ Ordering Exchange::extractOrdering(const BSONObj& obj) {
for (const auto& element : obj) {
if (element.type() == BSONType::String) {
uassert(50895,
- str::stream() << "$exchange key description is invalid: " << element,
+ str::stream() << "Exchange key description is invalid: " << element,
element.valueStringData() == "hashed"_sd);
hasHashKey = true;
} else if (element.isNumber()) {
auto num = element.number();
if (!(num == 1 || num == -1)) {
uasserted(50896,
- str::stream() << "$exchange key description is invalid: " << element);
+ str::stream() << "Exchange key description is invalid: " << element);
}
hasOrderKey = true;
} else {
- uasserted(50897, str::stream() << "$exchange key description is invalid: " << element);
+ uasserted(50897, str::stream() << "Exchange key description is invalid: " << element);
}
}
uassert(50898,
- str::stream() << "$exchange hash and order keys cannot be mixed together: " << obj,
+ str::stream() << "Exchange hash and order keys cannot be mixed together: " << obj,
!(hasHashKey && hasOrderKey));
return hasHashKey ? Ordering::make(BSONObj()) : Ordering::make(obj);
}
-DocumentSource::GetNextResult Exchange::getNext(size_t consumerId) {
+DocumentSource::GetNextResult Exchange::getNext(OperationContext* opCtx, size_t consumerId) {
// Grab a lock.
stdx::unique_lock<stdx::mutex> lk(_mutex);
@@ -212,11 +198,15 @@ DocumentSource::GetNextResult Exchange::getNext(size_t consumerId) {
// This consumer won the race and will fill the buffers.
_loadingThreadId = consumerId;
+ _pipeline->reattachToOperationContext(opCtx);
+
// This will return when some exchange buffer is full and we cannot make any forward
// progress anymore.
// The return value is an index of a full consumer buffer.
size_t fullConsumerId = loadNextBatch();
+ _pipeline->detachFromOperationContext();
+
// The loading cannot continue until the consumer with the full buffer consumes some
// documents.
_loadingThreadId = fullConsumerId;
@@ -232,9 +222,9 @@ DocumentSource::GetNextResult Exchange::getNext(size_t consumerId) {
}
size_t Exchange::loadNextBatch() {
- auto input = pSource->getNext();
+ auto input = _pipeline->getSources().back()->getNext();
- for (; input.isAdvanced(); input = pSource->getNext()) {
+ for (; input.isAdvanced(); input = _pipeline->getSources().back()->getNext()) {
// We have a document and we will deliver it to a consumer(s) based on the policy.
switch (_policy) {
case ExchangePolicyEnum::kBroadcast: {
@@ -317,6 +307,18 @@ size_t Exchange::getTargetConsumer(const Document& input) {
return cid;
}
+void Exchange::dispose(OperationContext* opCtx) {
+ stdx::lock_guard<stdx::mutex> lk(_mutex);
+
+ invariant(_disposeRunDown < getConsumers());
+
+ ++_disposeRunDown;
+
+ if (_disposeRunDown == getConsumers()) {
+ _pipeline->dispose(opCtx);
+ }
+}
+
DocumentSource::GetNextResult Exchange::ExchangeBuffer::getNext() {
invariant(!_buffer.empty());
diff --git a/src/mongo/db/pipeline/document_source_exchange.h b/src/mongo/db/pipeline/document_source_exchange.h
index ba4cb56ecdb..46277075bec 100644
--- a/src/mongo/db/pipeline/document_source_exchange.h
+++ b/src/mongo/db/pipeline/document_source_exchange.h
@@ -62,21 +62,19 @@ class Exchange : public RefCountable {
static Ordering extractOrdering(const BSONObj& obj);
public:
- explicit Exchange(const ExchangeSpec& spec);
- DocumentSource::GetNextResult getNext(size_t consumerId);
+ Exchange(ExchangeSpec spec, std::unique_ptr<Pipeline, PipelineDeleter> pipeline);
+ DocumentSource::GetNextResult getNext(OperationContext* opCtx, size_t consumerId);
size_t getConsumers() const {
return _consumers.size();
}
- void setSource(DocumentSource* source) {
- pSource = source;
- }
-
- const auto& getSpec() const {
+ auto& getSpec() const {
return _spec;
}
+ void dispose(OperationContext* opCtx);
+
private:
size_t loadNextBatch();
@@ -126,7 +124,7 @@ private:
const size_t _maxBufferSize;
// An input to the exchange operator
- DocumentSource* pSource;
+ std::unique_ptr<Pipeline, PipelineDeleter> _pipeline;
// Synchronization.
stdx::mutex _mutex;
@@ -137,14 +135,15 @@ private:
size_t _roundRobinCounter{0};
+ // A rundown counter of consumers disposing of the pipelines. Only the last consumer will
+ // dispose of the 'inner' exchange pipeline.
+ size_t _disposeRunDown{0};
+
std::vector<std::unique_ptr<ExchangeBuffer>> _consumers;
};
-class DocumentSourceExchange final : public DocumentSource, public NeedsMergerDocumentSource {
+class DocumentSourceExchange final : public DocumentSource {
public:
- static boost::intrusive_ptr<DocumentSource> createFromBson(
- BSONElement spec, const boost::intrusive_ptr<ExpressionContext>& pExpCtx);
-
DocumentSourceExchange(const boost::intrusive_ptr<ExpressionContext>& expCtx,
const boost::intrusive_ptr<Exchange> exchange,
size_t consumerId);
@@ -164,21 +163,12 @@ public:
Value serialize(boost::optional<ExplainOptions::Verbosity> explain = boost::none) const final;
- boost::intrusive_ptr<DocumentSource> getShardSource() final {
- return this;
- }
- MergingLogic mergingLogic() final {
- // TODO SERVER-35974 we have to revisit this when we implement consumers.
- return {this};
- }
-
/**
- * Set the underlying source this source should use to get Documents from. Must not throw
- * exceptions.
+ * DocumentSourceExchange does not have a direct source (it is reading through the shared
+ * Exchange pipeline).
*/
void setSource(DocumentSource* source) final {
- DocumentSource::setSource(source);
- _exchange->setSource(source);
+ invariant(!source);
}
GetNextResult getNext(size_t consumerId);
@@ -191,6 +181,14 @@ public:
return _exchange;
}
+ void doDispose() final {
+ _exchange->dispose(pExpCtx->opCtx);
+ }
+
+ auto getConsumerId() const {
+ return _consumerId;
+ }
+
private:
boost::intrusive_ptr<Exchange> _exchange;
diff --git a/src/mongo/db/pipeline/document_source_exchange_test.cpp b/src/mongo/db/pipeline/document_source_exchange_test.cpp
index de8e973f892..7906dedd8c3 100644
--- a/src/mongo/db/pipeline/document_source_exchange_test.cpp
+++ b/src/mongo/db/pipeline/document_source_exchange_test.cpp
@@ -44,10 +44,23 @@
namespace mongo {
+/**
+ * An implementation of the MongoProcessInterface that is okay with changing the OperationContext,
+ * but has no other parts of the interface implemented.
+ */
+class StubMongoProcessOkWithOpCtxChanges : public StubMongoProcessInterface {
+public:
+ void setOperationContext(OperationContext* opCtx) final {
+ return;
+ }
+};
+
class DocumentSourceExchangeTest : public AggregationContextFixture {
protected:
std::unique_ptr<executor::TaskExecutor> _executor;
virtual void setUp() override {
+ getExpCtx()->mongoProcessInterface = std::make_shared<StubMongoProcessOkWithOpCtxChanges>();
+
auto net = executor::makeNetworkInterface("ExchangeTest");
ThreadPool::Options options;
@@ -65,6 +78,7 @@ protected:
auto getMockSource(int cnt) {
auto source = DocumentSourceMock::create();
+
for (int i = 0; i < cnt; ++i)
source->queue.emplace_back(Document{{"a", i}, {"b", "aaaaaaaaaaaaaaaaaaaaaaaaaaa"_sd}});
@@ -82,12 +96,18 @@ protected:
PseudoRandom prng(seed);
auto source = DocumentSourceMock::create();
+
for (size_t i = 0; i < cnt; ++i)
source->queue.emplace_back(Document{{"a", static_cast<int>(prng.nextInt32() % cnt)},
{"b", "aaaaaaaaaaaaaaaaaaaaaaaaaaa"_sd}});
return source;
}
+
+ auto parseSpec(const BSONObj& spec) {
+ IDLParserErrorContext ctx("internalExchange");
+ return ExchangeSpec::parse(ctx, spec);
+ }
};
TEST_F(DocumentSourceExchangeTest, SimpleExchange1Consumer) {
@@ -100,14 +120,13 @@ TEST_F(DocumentSourceExchangeTest, SimpleExchange1Consumer) {
spec.setConsumers(1);
spec.setBufferSize(1024);
- boost::intrusive_ptr<Exchange> ex = new Exchange(spec);
-
- ex->setSource(source.get());
+ boost::intrusive_ptr<Exchange> ex =
+ new Exchange(spec, unittest::assertGet(Pipeline::create({source}, getExpCtx())));
- auto input = ex->getNext(0);
+ auto input = ex->getNext(getExpCtx()->opCtx, 0);
size_t docs = 0;
- for (; input.isAdvanced(); input = ex->getNext(0)) {
+ for (; input.isAdvanced(); input = ex->getNext(getExpCtx()->opCtx, 0)) {
++docs;
}
@@ -127,13 +146,13 @@ TEST_F(DocumentSourceExchangeTest, SimpleExchangeNConsumer) {
spec.setConsumers(nConsumers);
spec.setBufferSize(1024);
- boost::intrusive_ptr<Exchange> ex = new Exchange(spec);
+ boost::intrusive_ptr<Exchange> ex =
+ new Exchange(spec, unittest::assertGet(Pipeline::create({source}, getExpCtx())));
std::vector<boost::intrusive_ptr<DocumentSourceExchange>> prods;
for (size_t idx = 0; idx < nConsumers; ++idx) {
prods.push_back(new DocumentSourceExchange(getExpCtx(), ex, idx));
- prods.back()->setSource(source.get());
}
std::vector<executor::TaskExecutor::CallbackHandle> handles;
@@ -172,13 +191,13 @@ TEST_F(DocumentSourceExchangeTest, BroadcastExchangeNConsumer) {
spec.setConsumers(nConsumers);
spec.setBufferSize(1024);
- boost::intrusive_ptr<Exchange> ex = new Exchange(spec);
+ boost::intrusive_ptr<Exchange> ex =
+ new Exchange(spec, unittest::assertGet(Pipeline::create({source}, getExpCtx())));
std::vector<boost::intrusive_ptr<DocumentSourceExchange>> prods;
for (size_t idx = 0; idx < nConsumers; ++idx) {
prods.push_back(new DocumentSourceExchange(getExpCtx(), ex, idx));
- prods.back()->setSource(source.get());
}
std::vector<executor::TaskExecutor::CallbackHandle> handles;
@@ -223,13 +242,13 @@ TEST_F(DocumentSourceExchangeTest, RangeExchangeNConsumer) {
spec.setConsumers(nConsumers);
spec.setBufferSize(1024);
- boost::intrusive_ptr<Exchange> ex = new Exchange(spec);
+ boost::intrusive_ptr<Exchange> ex =
+ new Exchange(std::move(spec), unittest::assertGet(Pipeline::create({source}, getExpCtx())));
std::vector<boost::intrusive_ptr<DocumentSourceExchange>> prods;
for (size_t idx = 0; idx < nConsumers; ++idx) {
prods.push_back(new DocumentSourceExchange(getExpCtx(), ex, idx));
- prods.back()->setSource(source.get());
}
std::vector<executor::TaskExecutor::CallbackHandle> handles;
@@ -289,13 +308,13 @@ TEST_F(DocumentSourceExchangeTest, RangeShardingExchangeNConsumer) {
spec.setConsumers(nConsumers);
spec.setBufferSize(1024);
- boost::intrusive_ptr<Exchange> ex = new Exchange(spec);
+ boost::intrusive_ptr<Exchange> ex =
+ new Exchange(std::move(spec), unittest::assertGet(Pipeline::create({source}, getExpCtx())));
std::vector<boost::intrusive_ptr<DocumentSourceExchange>> prods;
for (size_t idx = 0; idx < nConsumers; ++idx) {
prods.push_back(new DocumentSourceExchange(getExpCtx(), ex, idx));
- prods.back()->setSource(source.get());
}
std::vector<executor::TaskExecutor::CallbackHandle> handles;
@@ -346,13 +365,13 @@ TEST_F(DocumentSourceExchangeTest, RangeRandomExchangeNConsumer) {
spec.setConsumers(nConsumers);
spec.setBufferSize(1024);
- boost::intrusive_ptr<Exchange> ex = new Exchange(spec);
+ boost::intrusive_ptr<Exchange> ex =
+ new Exchange(std::move(spec), unittest::assertGet(Pipeline::create({source}, getExpCtx())));
std::vector<boost::intrusive_ptr<DocumentSourceExchange>> prods;
for (size_t idx = 0; idx < nConsumers; ++idx) {
prods.push_back(new DocumentSourceExchange(getExpCtx(), ex, idx));
- prods.back()->setSource(source.get());
}
std::vector<executor::TaskExecutor::CallbackHandle> handles;
@@ -414,13 +433,13 @@ TEST_F(DocumentSourceExchangeTest, RangeRandomHashExchangeNConsumer) {
spec.setConsumers(nConsumers);
spec.setBufferSize(1024);
- boost::intrusive_ptr<Exchange> ex = new Exchange(spec);
+ boost::intrusive_ptr<Exchange> ex =
+ new Exchange(std::move(spec), unittest::assertGet(Pipeline::create({source}, getExpCtx())));
std::vector<boost::intrusive_ptr<DocumentSourceExchange>> prods;
for (size_t idx = 0; idx < nConsumers; ++idx) {
prods.push_back(new DocumentSourceExchange(getExpCtx(), ex, idx));
- prods.back()->setSource(source.get());
}
std::vector<executor::TaskExecutor::CallbackHandle> handles;
@@ -456,126 +475,122 @@ TEST_F(DocumentSourceExchangeTest, RangeRandomHashExchangeNConsumer) {
}
TEST_F(DocumentSourceExchangeTest, RejectNoConsumers) {
- BSONObj spec = BSON("$exchange" << BSON("policy"
- << "broadcast"
- << "consumers"
- << 0));
- BSONElement specElement = spec.firstElement();
- ASSERT_THROWS_CODE(DocumentSourceExchange::createFromBson(specElement, getExpCtx()),
- AssertionException,
- 50901);
+ BSONObj spec = BSON("policy"
+ << "broadcast"
+ << "consumers"
+ << 0);
+ ASSERT_THROWS_CODE(
+ Exchange(parseSpec(spec), unittest::assertGet(Pipeline::create({}, getExpCtx()))),
+ AssertionException,
+ 50901);
}
TEST_F(DocumentSourceExchangeTest, RejectInvalidKey) {
- BSONObj spec = BSON("$exchange" << BSON("policy"
- << "broadcast"
- << "consumers"
- << 1
- << "key"
- << BSON("a" << 2)));
- BSONElement specElement = spec.firstElement();
- ASSERT_THROWS_CODE(DocumentSourceExchange::createFromBson(specElement, getExpCtx()),
- AssertionException,
- 50896);
+ BSONObj spec = BSON("policy"
+ << "broadcast"
+ << "consumers"
+ << 1
+ << "key"
+ << BSON("a" << 2));
+ ASSERT_THROWS_CODE(
+ Exchange(parseSpec(spec), unittest::assertGet(Pipeline::create({}, getExpCtx()))),
+ AssertionException,
+ 50896);
}
TEST_F(DocumentSourceExchangeTest, RejectInvalidKeyHashExpected) {
- BSONObj spec = BSON("$exchange" << BSON("policy"
- << "broadcast"
- << "consumers"
- << 1
- << "key"
- << BSON("a"
- << "nothash")));
- BSONElement specElement = spec.firstElement();
- ASSERT_THROWS_CODE(DocumentSourceExchange::createFromBson(specElement, getExpCtx()),
- AssertionException,
- 50895);
+ BSONObj spec = BSON("policy"
+ << "broadcast"
+ << "consumers"
+ << 1
+ << "key"
+ << BSON("a"
+ << "nothash"));
+ ASSERT_THROWS_CODE(
+ Exchange(parseSpec(spec), unittest::assertGet(Pipeline::create({}, getExpCtx()))),
+ AssertionException,
+ 50895);
}
TEST_F(DocumentSourceExchangeTest, RejectInvalidKeyWrongType) {
- BSONObj spec = BSON("$exchange" << BSON("policy"
- << "broadcast"
- << "consumers"
- << 1
- << "key"
- << BSON("a" << true)));
- BSONElement specElement = spec.firstElement();
- ASSERT_THROWS_CODE(DocumentSourceExchange::createFromBson(specElement, getExpCtx()),
- AssertionException,
- 50897);
+ BSONObj spec = BSON("policy"
+ << "broadcast"
+ << "consumers"
+ << 1
+ << "key"
+ << BSON("a" << true));
+ ASSERT_THROWS_CODE(
+ Exchange(parseSpec(spec), unittest::assertGet(Pipeline::create({}, getExpCtx()))),
+ AssertionException,
+ 50897);
}
TEST_F(DocumentSourceExchangeTest, RejectInvalidBoundaries) {
- BSONObj spec =
- BSON("$exchange" << BSON("policy"
- << "range"
- << "consumers"
- << 1
- << "key"
- << BSON("a" << 1)
- << "boundaries"
- << BSON_ARRAY(BSON("a" << MAXKEY) << BSON("a" << MINKEY))
- << "consumerids"
- << BSON_ARRAY(0)));
- BSONElement specElement = spec.firstElement();
- ASSERT_THROWS_CODE(DocumentSourceExchange::createFromBson(specElement, getExpCtx()),
- AssertionException,
- 50893);
+ BSONObj spec = BSON("policy"
+ << "range"
+ << "consumers"
+ << 1
+ << "key"
+ << BSON("a" << 1)
+ << "boundaries"
+ << BSON_ARRAY(BSON("a" << MAXKEY) << BSON("a" << MINKEY))
+ << "consumerids"
+ << BSON_ARRAY(0));
+ ASSERT_THROWS_CODE(
+ Exchange(parseSpec(spec), unittest::assertGet(Pipeline::create({}, getExpCtx()))),
+ AssertionException,
+ 50893);
}
TEST_F(DocumentSourceExchangeTest, RejectInvalidBoundariesAndConsumerIds) {
- BSONObj spec =
- BSON("$exchange" << BSON("policy"
- << "range"
- << "consumers"
- << 2
- << "key"
- << BSON("a" << 1)
- << "boundaries"
- << BSON_ARRAY(BSON("a" << MINKEY) << BSON("a" << MAXKEY))
- << "consumerids"
- << BSON_ARRAY(0 << 1)));
- BSONElement specElement = spec.firstElement();
- ASSERT_THROWS_CODE(DocumentSourceExchange::createFromBson(specElement, getExpCtx()),
- AssertionException,
- 50900);
+ BSONObj spec = BSON("policy"
+ << "range"
+ << "consumers"
+ << 2
+ << "key"
+ << BSON("a" << 1)
+ << "boundaries"
+ << BSON_ARRAY(BSON("a" << MINKEY) << BSON("a" << MAXKEY))
+ << "consumerids"
+ << BSON_ARRAY(0 << 1));
+ ASSERT_THROWS_CODE(
+ Exchange(parseSpec(spec), unittest::assertGet(Pipeline::create({}, getExpCtx()))),
+ AssertionException,
+ 50900);
}
TEST_F(DocumentSourceExchangeTest, RejectInvalidPolicyBoundaries) {
- BSONObj spec =
- BSON("$exchange" << BSON("policy"
- << "roundrobin"
- << "consumers"
- << 1
- << "key"
- << BSON("a" << 1)
- << "boundaries"
- << BSON_ARRAY(BSON("a" << MINKEY) << BSON("a" << MAXKEY))
- << "consumerids"
- << BSON_ARRAY(0)));
- BSONElement specElement = spec.firstElement();
- ASSERT_THROWS_CODE(DocumentSourceExchange::createFromBson(specElement, getExpCtx()),
- AssertionException,
- 50899);
+ BSONObj spec = BSON("policy"
+ << "roundrobin"
+ << "consumers"
+ << 1
+ << "key"
+ << BSON("a" << 1)
+ << "boundaries"
+ << BSON_ARRAY(BSON("a" << MINKEY) << BSON("a" << MAXKEY))
+ << "consumerids"
+ << BSON_ARRAY(0));
+ ASSERT_THROWS_CODE(
+ Exchange(parseSpec(spec), unittest::assertGet(Pipeline::create({}, getExpCtx()))),
+ AssertionException,
+ 50899);
}
TEST_F(DocumentSourceExchangeTest, RejectInvalidConsumerIds) {
- BSONObj spec =
- BSON("$exchange" << BSON("policy"
- << "range"
- << "consumers"
- << 1
- << "key"
- << BSON("a" << 1)
- << "boundaries"
- << BSON_ARRAY(BSON("a" << MINKEY) << BSON("a" << MAXKEY))
- << "consumerids"
- << BSON_ARRAY(1)));
- BSONElement specElement = spec.firstElement();
- ASSERT_THROWS_CODE(DocumentSourceExchange::createFromBson(specElement, getExpCtx()),
- AssertionException,
- 50894);
+ BSONObj spec = BSON("policy"
+ << "range"
+ << "consumers"
+ << 1
+ << "key"
+ << BSON("a" << 1)
+ << "boundaries"
+ << BSON_ARRAY(BSON("a" << MINKEY) << BSON("a" << MAXKEY))
+ << "consumerids"
+ << BSON_ARRAY(1));
+ ASSERT_THROWS_CODE(
+ Exchange(parseSpec(spec), unittest::assertGet(Pipeline::create({}, getExpCtx()))),
+ AssertionException,
+ 50894);
}
} // namespace mongo
diff --git a/src/mongo/db/query/cursor_response.cpp b/src/mongo/db/query/cursor_response.cpp
index 14b8af5e8f6..da7ef1127bc 100644
--- a/src/mongo/db/query/cursor_response.cpp
+++ b/src/mongo/db/query/cursor_response.cpp
@@ -39,6 +39,7 @@ namespace mongo {
namespace {
+const char kCursorsField[] = "cursors";
const char kCursorField[] = "cursor";
const char kIdField[] = "id";
const char kNsField[] = "ns";
@@ -129,6 +130,31 @@ CursorResponse::CursorResponse(NamespaceString nss,
_latestOplogTimestamp(latestOplogTimestamp),
_writeConcernError(std::move(writeConcernError)) {}
+std::vector<StatusWith<CursorResponse>> CursorResponse::parseFromBSONMany(
+ const BSONObj& cmdResponse) {
+ std::vector<StatusWith<CursorResponse>> cursors;
+ BSONElement cursorsElt = cmdResponse[kCursorsField];
+
+ // If there is not "cursors" array then treat it as a single cursor response
+ if (cursorsElt.type() != BSONType::Array) {
+ cursors.push_back(parseFromBSON(cmdResponse));
+ } else {
+ BSONObj cursorsObj = cursorsElt.embeddedObject();
+ for (BSONElement elt : cursorsObj) {
+ if (elt.type() != BSONType::Object) {
+ cursors.push_back({ErrorCodes::BadValue,
+ str::stream()
+ << "Cursors array element contains non-object element: "
+ << elt});
+ } else {
+ cursors.push_back(parseFromBSON(elt.Obj()));
+ }
+ }
+ }
+
+ return cursors;
+}
+
StatusWith<CursorResponse> CursorResponse::parseFromBSON(const BSONObj& cmdResponse) {
Status cmdStatus = getStatusFromCommandResult(cmdResponse);
if (!cmdStatus.isOK()) {
diff --git a/src/mongo/db/query/cursor_response.h b/src/mongo/db/query/cursor_response.h
index 529654118df..52f4d30ed6a 100644
--- a/src/mongo/db/query/cursor_response.h
+++ b/src/mongo/db/query/cursor_response.h
@@ -166,6 +166,12 @@ public:
};
/**
+ * Constructs a vector of CursorResponses from a command BSON response that represents one or
+ * more cursors.
+ */
+ static std::vector<StatusWith<CursorResponse>> parseFromBSONMany(const BSONObj& cmdResponse);
+
+ /**
* Constructs a CursorResponse from the command BSON response.
*/
static StatusWith<CursorResponse> parseFromBSON(const BSONObj& cmdResponse);
diff --git a/src/mongo/s/query/cluster_aggregate.cpp b/src/mongo/s/query/cluster_aggregate.cpp
index 8293aa4b65b..5d600054c93 100644
--- a/src/mongo/s/query/cluster_aggregate.cpp
+++ b/src/mongo/s/query/cluster_aggregate.cpp
@@ -211,7 +211,8 @@ BSONObj createPassthroughCommandForShard(OperationContext* opCtx,
BSONObj createCommandForTargetedShards(OperationContext* opCtx,
const AggregationRequest& request,
const SplitPipeline& splitPipeline,
- const BSONObj collationObj) {
+ const BSONObj collationObj,
+ const boost::optional<ExchangeSpec> exchangeSpec) {
// Create the command for the shards.
MutableDocument targetedCmd(request.serializeToCommandObj());
@@ -225,6 +226,9 @@ BSONObj createCommandForTargetedShards(OperationContext* opCtx,
targetedCmd[AggregationRequest::kCursorName] =
Value(DOC(AggregationRequest::kBatchSizeName << 0));
+ targetedCmd[AggregationRequest::kExchangeName] =
+ exchangeSpec ? Value(exchangeSpec->toBSON()) : Value();
+
return genericTransformForShards(std::move(targetedCmd), opCtx, request, collationObj);
}
@@ -328,6 +332,12 @@ struct DispatchShardPipelineResults {
// The command object to send to the targeted shards.
BSONObj commandForTargetedShards;
+
+ // How many producers are running the shard part of splitPipeline.
+ size_t numProducers;
+
+ // How many consumers are running the merging.
+ size_t numConsumers;
};
/**
@@ -399,7 +409,8 @@ DispatchShardPipelineResults dispatchShardPipeline(
// Generate the command object for the targeted shards.
BSONObj targetedCommand = splitPipeline
- ? createCommandForTargetedShards(opCtx, aggRequest, *splitPipeline, collationObj)
+ ? createCommandForTargetedShards(
+ opCtx, aggRequest, *splitPipeline, collationObj, aggRequest.getExchangeSpec())
: createPassthroughCommandForShard(
opCtx, aggRequest, pipeline.get(), originalCmdObj, collationObj);
@@ -413,6 +424,7 @@ DispatchShardPipelineResults dispatchShardPipeline(
}
}
+ size_t consumers = 1;
// Explain does not produce a cursor, so instead we scatter-gather commands to the shards.
if (expCtx->explain) {
if (mustRunOnAll) {
@@ -448,6 +460,12 @@ DispatchShardPipelineResults dispatchShardPipeline(
ReadPreferenceSetting::get(opCtx),
shardQuery,
aggRequest.getCollation());
+ invariant(cursors.size() % shardIds.size() == 0,
+ str::stream() << "Number of cursors (" << cursors.size()
+ << ") is not a multiple of producers ("
+ << shardIds.size()
+ << ")");
+ consumers = cursors.size() / shardIds.size();
}
// Record the number of shards involved in the aggregation. If we are required to merge on
@@ -462,7 +480,79 @@ DispatchShardPipelineResults dispatchShardPipeline(
std::move(shardResults),
std::move(splitPipeline),
std::move(pipeline),
- targetedCommand};
+ targetedCommand,
+ shardIds.size(),
+ consumers};
+}
+
+DispatchShardPipelineResults dispatchExchangeConsumerPipeline(
+ const boost::intrusive_ptr<ExpressionContext>& expCtx,
+ const NamespaceString& executionNss,
+ BSONObj originalCmdObj,
+ const AggregationRequest& aggRequest,
+ const LiteParsedPipeline& liteParsedPipeline,
+ BSONObj collationObj,
+ DispatchShardPipelineResults* shardDispatchResults) {
+ invariant(!liteParsedPipeline.hasChangeStream());
+ auto opCtx = expCtx->opCtx;
+
+ // TODO SERVER-35905 - we will use ShardDistributionInfo to determine shards that will run the
+ // consumers. For now we simply distribute to all shards.
+ std::vector<ShardId> shardIds;
+ Grid::get(opCtx)->shardRegistry()->getAllShardIds(opCtx, &shardIds);
+
+ // For all consumers construct a request with appropriate cursor ids and send to shards.
+ std::vector<std::pair<ShardId, BSONObj>> requests;
+ for (size_t idx = 0; idx < shardDispatchResults->numConsumers; ++idx) {
+
+ // Pick this consumer's cursors from producers.
+ std::vector<RemoteCursor> producers;
+ for (size_t p = 0; p < shardDispatchResults->numProducers; ++p) {
+ producers.emplace_back(std::move(
+ shardDispatchResults->remoteCursors[p * shardDispatchResults->numConsumers + idx]));
+ }
+
+ // Create a pipeline for a consumer and add the merging stage.
+ auto consumerPipeline = uassertStatusOK(Pipeline::create(
+ shardDispatchResults->splitPipeline->mergePipeline->getSources(), expCtx));
+
+ cluster_aggregation_planner::addMergeCursorsSource(
+ consumerPipeline.get(),
+ liteParsedPipeline,
+ BSONObj(),
+ std::move(producers),
+ {},
+ shardDispatchResults->splitPipeline->shardCursorsSortSpec,
+ Grid::get(opCtx)->getExecutorPool()->getArbitraryExecutor());
+
+ SplitPipeline pipeline(std::move(consumerPipeline), nullptr, boost::none);
+
+ auto consumerCmdObj =
+ createCommandForTargetedShards(opCtx, aggRequest, pipeline, collationObj, boost::none);
+
+ requests.emplace_back(shardIds[idx % shardIds.size()], consumerCmdObj);
+ }
+ auto cursors = establishCursors(opCtx,
+ Grid::get(opCtx)->getExecutorPool()->getArbitraryExecutor(),
+ executionNss,
+ ReadPreferenceSetting::get(opCtx),
+ requests,
+ false /* do not allow partial results */);
+
+ // The merging pipeline is just a union of the results from each of the shards involved on the
+ // consumer side of the exchange.
+ auto mergePipeline = uassertStatusOK(Pipeline::create({}, expCtx));
+ mergePipeline->setSplitState(Pipeline::SplitState::kSplitForMerge);
+
+ SplitPipeline splitPipeline{nullptr, std::move(mergePipeline), boost::none};
+ return DispatchShardPipelineResults{false,
+ std::move(cursors),
+ {} /*TODO SERVER-36279*/,
+ std::move(splitPipeline),
+ nullptr,
+ BSONObj(),
+ shardDispatchResults->numConsumers,
+ 1};
}
Status appendExplainResults(DispatchShardPipelineResults&& dispatchResults,
@@ -991,6 +1081,17 @@ Status ClusterAggregate::runAggregate(OperationContext* opCtx,
remoteCursor.getShardId().toString(), reply, result);
}
+ // If we have more than 1 consumer (i.e. the exchange operator is in use) then dispatch all
+ // consumers.
+ if (shardDispatchResults.numConsumers > 1) {
+ shardDispatchResults = dispatchExchangeConsumerPipeline(expCtx,
+ namespaces.executionNss,
+ cmdObj,
+ request,
+ litePipe,
+ collationObj,
+ &shardDispatchResults);
+ }
// If we reach here, we have a merge pipeline to dispatch.
return dispatchMergingPipeline(expCtx,
namespaces,
diff --git a/src/mongo/s/query/establish_cursors.cpp b/src/mongo/s/query/establish_cursors.cpp
index 08ce5a2cb5b..8cdb30d1c0b 100644
--- a/src/mongo/s/query/establish_cursors.cpp
+++ b/src/mongo/s/query/establish_cursors.cpp
@@ -78,12 +78,24 @@ std::vector<RemoteCursor> establishCursors(OperationContext* opCtx,
// Additionally, be careful not to push into 'remoteCursors' until we are sure we
// have a valid cursor, since the error handling path will attempt to clean up
// anything in 'remoteCursors'
- RemoteCursor cursor;
- cursor.setCursorResponse(CursorResponse::parseFromBSONThrowing(
- uassertStatusOK(std::move(response.swResponse)).data));
- cursor.setShardId(std::move(response.shardId));
- cursor.setHostAndPort(*response.shardHostAndPort);
- remoteCursors.push_back(std::move(cursor));
+ auto cursors = CursorResponse::parseFromBSONMany(
+ uassertStatusOK(std::move(response.swResponse)).data);
+
+ for (auto& cursor : cursors) {
+ if (cursor.isOK()) {
+ RemoteCursor remoteCursor;
+ remoteCursor.setCursorResponse(std::move(cursor.getValue()));
+ remoteCursor.setShardId(std::move(response.shardId));
+ remoteCursor.setHostAndPort(*response.shardHostAndPort);
+ remoteCursors.push_back(std::move(remoteCursor));
+ }
+ }
+
+ // Throw if there is any error and then the catch block below will do the cleanup.
+ for (auto& cursor : cursors) {
+ uassertStatusOK(cursor.getStatus());
+ }
+
} catch (const DBException& ex) {
// Retriable errors are swallowed if 'allowPartialResults' is true.
if (allowPartialResults &&