author    Eric Cox <eric.cox@mongodb.com>  2021-12-07 21:39:55 +0000
committer Evergreen Agent <no-reply@evergreen.mongodb.com>  2021-12-07 22:28:04 +0000
commit    280b35dbc2f71e692df3e02e399e997575e00bf0 (patch)
tree      56eea75215b744829f103b5b6518a76d0d7efe5b
parent    2c0ae37a1a318a19b6d666a911738746601de908 (diff)
download  mongo-280b35dbc2f71e692df3e02e399e997575e00bf0.tar.gz
SERVER-58436 Implement spilling HashAgg
 jstests/aggregation/spill_to_disk.js                 |  12
 jstests/libs/sbe_assert_error_override.js            |   5
 jstests/noPassthroughWithMongod/group_pushdown.js    |  18
 src/mongo/db/exec/sbe/SConscript                     |   1
 src/mongo/db/exec/sbe/parser/parser.cpp              |   1
 src/mongo/db/exec/sbe/parser/sbe_parser_test.cpp     |   2
 src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp          | 448
 src/mongo/db/exec/sbe/sbe_plan_size_test.cpp         |   1
 src/mongo/db/exec/sbe/sbe_plan_stage_test.cpp        |   6
 src/mongo/db/exec/sbe/sbe_plan_stage_test.h          |  12
 src/mongo/db/exec/sbe/sbe_trial_run_tracker_test.cpp |   3
 src/mongo/db/exec/sbe/stages/hash_agg.cpp            | 309
 src/mongo/db/exec/sbe/stages/hash_agg.h              |  49
 src/mongo/db/pipeline/pipeline_d.cpp                 |   2
 src/mongo/db/query/sbe_stage_builder.cpp             |   4
 src/mongo/db/query/sbe_stage_builder_expression.cpp  |   1
 src/mongo/db/query/sbe_stage_builder_helpers.cpp     |   2
 src/mongo/db/query/sbe_stage_builder_helpers.h       |  11
 18 files changed, 803 insertions, 84 deletions
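
At a high level, HashAggStage now accumulates per-group state in its in-memory hash table until an estimated memory threshold is crossed; after that, keys that miss the table are written to a temporary record store keyed by the KeyString encoding of the group-by key. A condensed sketch of the protocol this patch adds (illustrative pseudocode in comments, not verbatim code from the patch):

    // open(): for each input row
    //   if not spilled yet:      upsert (key, aggState) into _ht
    //   else if key is in _ht:   update aggState in _ht in place
    //   else:                    read-modify-write the spilled aggState in the
    //                            temporary record store, keyed by KeyString(key)
    //   if estimated _ht size >= internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill:
    //       uassert(5843601) unless allowDiskUse, then create the record store
    //
    // getNext(): drain _ht first, then drain the record store, then EOF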
diff --git a/jstests/aggregation/spill_to_disk.js b/jstests/aggregation/spill_to_disk.js
index 0865586a1c4..be9cfa2bc20 100644
--- a/jstests/aggregation/spill_to_disk.js
+++ b/jstests/aggregation/spill_to_disk.js
@@ -60,6 +60,13 @@ function test({pipeline, expectedCodes, canSpillToDisk}) {
}
}
+assert.commandWorked(db.adminCommand({
+ setParameter: 1,
+ internalQuerySlotBasedExecutionHashAggApproxMemoryUseInBytesBeforeSpill: 1024
+}));
+assert.commandWorked(db.adminCommand(
+ {setParameter: 1, internalQuerySlotBasedExecutionHashAggMemoryUseSampleRate: 1.0}));
+
test({
pipeline: [{$group: {_id: '$_id', bigStr: {$min: '$bigStr'}}}],
expectedCodes: ErrorCodes.QueryExceededMemoryLimitNoDiskUseAllowed,
@@ -73,6 +80,7 @@ test({
expectedCodes: ErrorCodes.QueryExceededMemoryLimitNoDiskUseAllowed,
canSpillToDisk: true
});
+
test({
pipeline: [{$sort: {bigStr: 1}}], // big key and value
expectedCodes: ErrorCodes.QueryExceededMemoryLimitNoDiskUseAllowed,
@@ -92,11 +100,13 @@ test({
expectedCodes: ErrorCodes.QueryExceededMemoryLimitNoDiskUseAllowed,
canSpillToDisk: true
});
+
test({
pipeline: [{$group: {_id: '$_id', bigStr: {$min: '$bigStr'}}}, {$sort: {_id: -1}}],
expectedCodes: ErrorCodes.QueryExceededMemoryLimitNoDiskUseAllowed,
canSpillToDisk: true
});
+
test({
pipeline: [{$group: {_id: '$_id', bigStr: {$min: '$bigStr'}}}, {$sort: {random: 1}}],
expectedCodes: ErrorCodes.QueryExceededMemoryLimitNoDiskUseAllowed,
@@ -158,5 +168,5 @@ for (const op of ['$firstN', '$lastN', '$minN', '$maxN', '$topN', '$bottomN']) {
}
// don't leave large collection laying around
-assert(coll.drop());
+coll.drop();
})();
diff --git a/jstests/libs/sbe_assert_error_override.js b/jstests/libs/sbe_assert_error_override.js
index 2f4d9a1a37f..74ec08ef9bc 100644
--- a/jstests/libs/sbe_assert_error_override.js
+++ b/jstests/libs/sbe_assert_error_override.js
@@ -114,8 +114,9 @@ const equivalentErrorCodesList = [
[5338802, 5439016],
[5687301, 5687400],
[5687302, 5687401],
- [292, 5859000],
- [6045000, 5166606]
+ [292, 5859000, 5843600, 5843601],
+ [6045000, 5166606],
+ [146, 13548]
];
// This map is generated based on the contents of 'equivalentErrorCodesList'. This map should _not_
diff --git a/jstests/noPassthroughWithMongod/group_pushdown.js b/jstests/noPassthroughWithMongod/group_pushdown.js
index 78b761fa649..b71069c6ccc 100644
--- a/jstests/noPassthroughWithMongod/group_pushdown.js
+++ b/jstests/noPassthroughWithMongod/group_pushdown.js
@@ -23,8 +23,9 @@ assert.commandWorked(coll.insert([
{"_id": 5, "item": "c", "price": 5, "quantity": 10, "date": ISODate("2014-02-15T09:05:00Z")},
]));
-let assertGroupPushdown = function(coll, pipeline, expectedResults, expectedGroupCountInExplain) {
- const explain = coll.explain().aggregate(pipeline);
+let assertGroupPushdown = function(
+ coll, pipeline, expectedResults, expectedGroupCountInExplain, options = {}) {
+ const explain = coll.explain().aggregate(pipeline, options);
// When $group is pushed down it will never be present as a stage in the 'winningPlan' of
// $cursor.
if (expectedGroupCountInExplain > 1) {
@@ -33,7 +34,7 @@ let assertGroupPushdown = function(coll, pipeline, expectedResults, expectedGrou
assert.neq(null, getAggPlanStage(explain, "GROUP"), explain);
}
- let results = coll.aggregate(pipeline).toArray();
+ let results = coll.aggregate(pipeline, options).toArray();
assert.sameMembers(results, expectedResults);
};
@@ -253,11 +254,12 @@ assertNoGroupPushdown(coll,
[{$group: {_id: {"i": "$item"}, s: {$sum: "$price"}}}],
[{_id: {i: "a"}, s: 15}, {_id: {i: "b"}, s: 30}, {_id: {i: "c"}, s: 5}]);
-// Spilling isn't supported yet so $group with 'allowDiskUse' true won't get pushed down.
-assertNoGroupPushdown(coll,
- [{$group: {_id: "$item", s: {$sum: "$price"}}}],
- [{"_id": "b", "s": 30}, {"_id": "a", "s": 15}, {"_id": "c", "s": 5}],
- {allowDiskUse: true, cursor: {batchSize: 1}});
+// Run a group with spilling on and check that $group is pushed down.
+assertGroupPushdown(coll,
+ [{$group: {_id: "$item", s: {$sum: "$price"}}}],
+ [{"_id": "b", "s": 30}, {"_id": "a", "s": 15}, {"_id": "c", "s": 5}],
+ 1,
+ {allowDiskUse: true, cursor: {batchSize: 1}});
// Run a pipeline with match, sort, group to check if the whole pipeline gets pushed down.
assertGroupPushdown(coll,
diff --git a/src/mongo/db/exec/sbe/SConscript b/src/mongo/db/exec/sbe/SConscript
index 175642a5121..6f26193db8f 100644
--- a/src/mongo/db/exec/sbe/SConscript
+++ b/src/mongo/db/exec/sbe/SConscript
@@ -64,6 +64,7 @@ sbeEnv.Library(
],
LIBDEPS=[
'$BUILD_DIR/mongo/base',
+ '$BUILD_DIR/mongo/db/concurrency/lock_manager',
'$BUILD_DIR/mongo/db/concurrency/write_conflict_exception',
'$BUILD_DIR/mongo/db/exec/js_function',
'$BUILD_DIR/mongo/db/exec/scoped_timer',
diff --git a/src/mongo/db/exec/sbe/parser/parser.cpp b/src/mongo/db/exec/sbe/parser/parser.cpp
index a169d712ff8..2217992cec7 100644
--- a/src/mongo/db/exec/sbe/parser/parser.cpp
+++ b/src/mongo/db/exec/sbe/parser/parser.cpp
@@ -1151,6 +1151,7 @@ void Parser::walkGroup(AstQuery& ast) {
true,
collatorSlotPos ? lookupSlot(std::move(ast.nodes[collatorSlotPos]->identifier))
: boost::none,
+ true, // allowDiskUse
getCurrentPlanNodeId());
}
diff --git a/src/mongo/db/exec/sbe/parser/sbe_parser_test.cpp b/src/mongo/db/exec/sbe/parser/sbe_parser_test.cpp
index fb1ba49d9b5..bf21402d279 100644
--- a/src/mongo/db/exec/sbe/parser/sbe_parser_test.cpp
+++ b/src/mongo/db/exec/sbe/parser/sbe_parser_test.cpp
@@ -360,6 +360,7 @@ protected:
sbe::makeSV(),
true,
boost::none, /* optional collator slot */
+ true, /* allowDiskUse */
planNodeId),
// GROUP with a collator slot.
sbe::makeS<sbe::HashAggStage>(
@@ -374,6 +375,7 @@ protected:
sbe::makeSV(),
true,
sbe::value::SlotId{4}, /* optional collator slot */
+ true, /* allowDiskUse */
planNodeId),
// LIMIT
sbe::makeS<sbe::LimitSkipStage>(
diff --git a/src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp b/src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp
index b7ce2dd2ac3..58ecd40b98e 100644
--- a/src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp
+++ b/src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp
@@ -66,7 +66,7 @@ void HashAggStageTest::performHashAggWithSpillChecking(
auto collatorSlot = generateSlotId();
auto shouldUseCollator = optionalCollator.get() != nullptr;
- auto makeStageFn = [this, collatorSlot, shouldUseCollator](
+ auto makeStageFn = [this, collatorSlot, shouldUseCollator, shouldSpill](
value::SlotId scanSlot, std::unique_ptr<PlanStage> scanStage) {
auto countsSlot = generateSlotId();
@@ -80,6 +80,7 @@ void HashAggStageTest::performHashAggWithSpillChecking(
makeSV(),
true,
boost::optional<value::SlotId>{shouldUseCollator, collatorSlot},
+ shouldSpill,
kEmptyPlanNodeId);
return std::make_pair(countsSlot, std::move(hashAggStage));
@@ -98,17 +99,6 @@ void HashAggStageTest::performHashAggWithSpillChecking(
inputGuard.reset();
auto [scanSlot, scanStage] = generateVirtualScan(inputTag, inputVal);
- // Prepare the tree and get the 'SlotAccessor' for the output slot.
- if (shouldSpill) {
- auto hashAggStage = makeStageFn(scanSlot, std::move(scanStage));
- // 'prepareTree()' also opens the tree after preparing it thus the spilling error should
- // occur in 'prepareTree()'.
- ASSERT_THROWS_CODE(prepareTree(ctx.get(), hashAggStage.second.get(), hashAggStage.first),
- DBException,
- 5859000);
- return;
- }
-
auto [outputSlot, stage] = makeStageFn(scanSlot, std::move(scanStage));
auto resultAccessor = prepareTree(ctx.get(), stage.get(), outputSlot);
@@ -189,6 +179,7 @@ TEST_F(HashAggStageTest, HashAggMinMaxTest) {
makeSV(),
true,
boost::none,
+ false /* allowDiskUse */,
kEmptyPlanNodeId);
auto outSlot = generateSlotId();
@@ -245,6 +236,7 @@ TEST_F(HashAggStageTest, HashAggAddToSetTest) {
makeSV(),
true,
boost::none,
+ false /* allowDiskUse */,
kEmptyPlanNodeId);
return std::make_pair(hashAggSlot, std::move(hashAggStage));
@@ -342,6 +334,7 @@ TEST_F(HashAggStageTest, HashAggSeekKeysTest) {
makeSV(seekSlot),
true,
boost::none,
+ false /* allowDiskUse */,
kEmptyPlanNodeId);
return std::make_pair(countsSlot, std::move(hashAggStage));
@@ -381,6 +374,277 @@ TEST_F(HashAggStageTest, HashAggSeekKeysTest) {
stage->close();
}
+TEST_F(HashAggStageTest, HashAggBasicCountSpill) {
+ // Changing the query knobs to always re-estimate the hash table size in HashAgg and spill when
+ // estimated size is >= 4 * 8.
+ auto defaultInternalQuerySBEAggApproxMemoryUseInBytesBeforeSpill =
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.load();
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.store(4 * 8);
+ ON_BLOCK_EXIT([&] {
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.store(
+ defaultInternalQuerySBEAggApproxMemoryUseInBytesBeforeSpill);
+ });
+ auto defaultInternalQuerySBEAggMemoryUseSampleRate =
+ internalQuerySBEAggMemoryUseSampleRate.load();
+ internalQuerySBEAggMemoryUseSampleRate.store(1.0);
+ ON_BLOCK_EXIT([&] {
+ internalQuerySBEAggMemoryUseSampleRate.store(defaultInternalQuerySBEAggMemoryUseSampleRate);
+ });
+
+ auto ctx = makeCompileCtx();
+
+ // Build a scan of the [5,6,7,5,6,7,6,7,7] input array.
+ auto [inputTag, inputVal] =
+ stage_builder::makeValue(BSON_ARRAY(5 << 6 << 7 << 5 << 6 << 7 << 6 << 7 << 7));
+ auto [scanSlot, scanStage] = generateVirtualScan(inputTag, inputVal);
+
+ // Build a HashAggStage, group by the scanSlot and compute a simple count.
+ auto countsSlot = generateSlotId();
+ auto stage = makeS<HashAggStage>(
+ std::move(scanStage),
+ makeSV(scanSlot),
+ makeEM(countsSlot,
+ stage_builder::makeFunction(
+ "sum",
+ makeE<EConstant>(value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(1)))),
+ makeSV(), // Seek slot
+ true,
+ boost::none,
+ true /* allowDiskUse */,
+ kEmptyPlanNodeId);
+
+ // Prepare the tree and get the 'SlotAccessor' for the output slot.
+ auto resultAccessor = prepareTree(ctx.get(), stage.get(), countsSlot);
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res1Tag, res1Val] = resultAccessor->getViewOfValue();
+ // There are '2' occurrences of '5' in the input.
+ assertValuesEqual(res1Tag, res1Val, value::TypeTags::NumberInt32, value::bitcastFrom<int>(2));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res2Tag, res2Val] = resultAccessor->getViewOfValue();
+ // There are '3' occurrences of '6' in the input.
+ assertValuesEqual(res2Tag, res2Val, value::TypeTags::NumberInt32, value::bitcastFrom<int>(3));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res3Tag, res3Val] = resultAccessor->getViewOfValue();
+ // There are '4' occurrences of '7' in the input.
+ assertValuesEqual(res3Tag, res3Val, value::TypeTags::NumberInt32, value::bitcastFrom<int>(4));
+ ASSERT_TRUE(stage->getNext() == PlanState::IS_EOF);
+
+ stage->close();
+}
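
Every spill test in this file saves, overrides, and restores the two query knobs by hand with ON_BLOCK_EXIT. A hypothetical RAII guard (not part of this patch) could factor out that boilerplate; this is a sketch assuming only the load()/store() interface the tests already rely on:

    #include <utility>  // std::declval

    template <typename Knob>
    class ScopedKnobOverride {
    public:
        using Stored = decltype(std::declval<Knob&>().load());

        ScopedKnobOverride(Knob& knob, Stored newValue) : _knob(knob), _saved(knob.load()) {
            _knob.store(newValue);
        }
        ~ScopedKnobOverride() {
            _knob.store(_saved);  // restore the previous value on scope exit
        }

    private:
        Knob& _knob;
        Stored _saved;
    };

    // Usage, mirroring the test above:
    //   ScopedKnobOverride memLimit(internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill, 4 * 8);
    //   ScopedKnobOverride sampleRate(internalQuerySBEAggMemoryUseSampleRate, 1.0);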
+
+TEST_F(HashAggStageTest, HashAggBasicCountSpillDouble) {
+ // Changing the query knobs to always re-estimate the hash table size in HashAgg and spill when
+ // estimated size is >= 4 * 8.
+ auto defaultInternalQuerySBEAggApproxMemoryUseInBytesBeforeSpill =
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.load();
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.store(4 * 8);
+ ON_BLOCK_EXIT([&] {
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.store(
+ defaultInternalQuerySBEAggApproxMemoryUseInBytesBeforeSpill);
+ });
+ auto defaultInternalQuerySBEAggMemoryUseSampleRate =
+ internalQuerySBEAggMemoryUseSampleRate.load();
+ internalQuerySBEAggMemoryUseSampleRate.store(1.0);
+ ON_BLOCK_EXIT([&] {
+ internalQuerySBEAggMemoryUseSampleRate.store(defaultInternalQuerySBEAggMemoryUseSampleRate);
+ });
+
+ auto ctx = makeCompileCtx();
+
+ // Build a scan of the [5.0,6.0,7.0,5.0,6.0,7.0,6.0,7.0,7.0] input array of doubles.
+ auto [inputTag, inputVal] = stage_builder::makeValue(
+ BSON_ARRAY(5.0 << 6.0 << 7.0 << 5.0 << 6.0 << 7.0 << 6.0 << 7.0 << 7.0));
+ auto [scanSlot, scanStage] = generateVirtualScan(inputTag, inputVal);
+
+ // Build a HashAggStage, group by the scanSlot and compute a simple count.
+ auto countsSlot = generateSlotId();
+ auto stage = makeS<HashAggStage>(
+ std::move(scanStage),
+ makeSV(scanSlot),
+ makeEM(countsSlot,
+ stage_builder::makeFunction(
+ "sum",
+ makeE<EConstant>(value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(1)))),
+ makeSV(), // Seek slot
+ true,
+ boost::none,
+ true /* allowDiskUse */,
+ kEmptyPlanNodeId);
+
+ // Prepare the tree and get the 'SlotAccessor' for the output slot.
+ auto resultAccessor = prepareTree(ctx.get(), stage.get(), countsSlot);
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res1Tag, res1Val] = resultAccessor->getViewOfValue();
+ // There are '2' occurrences of '5' in the input.
+ assertValuesEqual(res1Tag, res1Val, value::TypeTags::NumberInt32, value::bitcastFrom<int>(2));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res2Tag, res2Val] = resultAccessor->getViewOfValue();
+ // There are '3' occurrences of '6' in the input.
+ assertValuesEqual(res2Tag, res2Val, value::TypeTags::NumberInt32, value::bitcastFrom<int>(3));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res3Tag, res3Val] = resultAccessor->getViewOfValue();
+ // There are '4' occurrences of '7' in the input.
+ assertValuesEqual(res3Tag, res3Val, value::TypeTags::NumberInt32, value::bitcastFrom<int>(4));
+ ASSERT_TRUE(stage->getNext() == PlanState::IS_EOF);
+
+ stage->close();
+}
+
+
+TEST_F(HashAggStageTest, HashAggMultipleAccSpill) {
+ // Changing the query knobs to always re-estimate the hash table size in HashAgg and spill when
+ // estimated size is >= 2 * 8.
+ auto defaultInternalQuerySBEAggApproxMemoryUseInBytesBeforeSpill =
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.load();
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.store(2 * 8);
+ ON_BLOCK_EXIT([&] {
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.store(
+ defaultInternalQuerySBEAggApproxMemoryUseInBytesBeforeSpill);
+ });
+ auto defaultInternalQuerySBEAggMemoryUseSampleRate =
+ internalQuerySBEAggMemoryUseSampleRate.load();
+ internalQuerySBEAggMemoryUseSampleRate.store(1.0);
+ ON_BLOCK_EXIT([&] {
+ internalQuerySBEAggMemoryUseSampleRate.store(defaultInternalQuerySBEAggMemoryUseSampleRate);
+ });
+
+ auto ctx = makeCompileCtx();
+
+ // Build a scan of the [5,6,7,5,6,7,6,7,7] input array.
+ auto [inputTag, inputVal] =
+ stage_builder::makeValue(BSON_ARRAY(5 << 6 << 7 << 5 << 6 << 7 << 6 << 7 << 7));
+ auto [scanSlot, scanStage] = generateVirtualScan(inputTag, inputVal);
+
+ // Build a HashAggStage, group by the scanSlot and compute a count and a sum per group.
+ auto countsSlot = generateSlotId();
+ auto sumsSlot = generateSlotId();
+ auto stage = makeS<HashAggStage>(
+ std::move(scanStage),
+ makeSV(scanSlot),
+ makeEM(countsSlot,
+ stage_builder::makeFunction(
+ "sum",
+ makeE<EConstant>(value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(1))),
+ sumsSlot,
+ stage_builder::makeFunction("sum", makeE<EVariable>(scanSlot))),
+ makeSV(), // Seek slot
+ true,
+ boost::none,
+ true /* allowDiskUse */,
+ kEmptyPlanNodeId);
+
+ // Prepare the tree and get the 'SlotAccessor' for the output slot.
+ auto resultAccessors = prepareTree(ctx.get(), stage.get(), makeSV(countsSlot, sumsSlot));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res1Tag, res1Val] = resultAccessors[0]->getViewOfValue();
+ auto [res1TagSum, res1ValSum] = resultAccessors[1]->getViewOfValue();
+
+ // There are '2' occurrences of '5' in the input.
+ assertValuesEqual(res1Tag, res1Val, value::TypeTags::NumberInt32, value::bitcastFrom<int>(2));
+ assertValuesEqual(
+ res1TagSum, res1ValSum, value::TypeTags::NumberInt32, value::bitcastFrom<int>(10));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res2Tag, res2Val] = resultAccessors[0]->getViewOfValue();
+ auto [res2TagSum, res2ValSum] = resultAccessors[1]->getViewOfValue();
+ // There are '3' occurrences of '6' in the input.
+ assertValuesEqual(res2Tag, res2Val, value::TypeTags::NumberInt32, value::bitcastFrom<int>(3));
+ assertValuesEqual(
+ res2TagSum, res2ValSum, value::TypeTags::NumberInt32, value::bitcastFrom<int>(18));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res3Tag, res3Val] = resultAccessors[0]->getViewOfValue();
+ auto [res3TagSum, res3ValSum] = resultAccessors[1]->getViewOfValue();
+ // There are '4' occurrences of '7' in the input.
+ assertValuesEqual(res3Tag, res3Val, value::TypeTags::NumberInt32, value::bitcastFrom<int>(4));
+ assertValuesEqual(
+ res3TagSum, res3ValSum, value::TypeTags::NumberInt32, value::bitcastFrom<int>(28));
+ ASSERT_TRUE(stage->getNext() == PlanState::IS_EOF);
+
+ stage->close();
+}
+
+TEST_F(HashAggStageTest, HashAggMultipleAccSpillAllToDisk) {
+ // Changing the query knobs to always re-estimate the hash table size in HashAgg and spill when
+ // estimated size is >= 0. This will spill everything to the RecordStore.
+ auto defaultInternalQuerySBEAggApproxMemoryUseInBytesBeforeSpill =
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.load();
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.store(0);
+ ON_BLOCK_EXIT([&] {
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.store(
+ defaultInternalQuerySBEAggApproxMemoryUseInBytesBeforeSpill);
+ });
+ auto defaultInternalQuerySBEAggMemoryUseSampleRate =
+ internalQuerySBEAggMemoryUseSampleRate.load();
+ internalQuerySBEAggMemoryUseSampleRate.store(1.0);
+ ON_BLOCK_EXIT([&] {
+ internalQuerySBEAggMemoryUseSampleRate.store(defaultInternalQuerySBEAggMemoryUseSampleRate);
+ });
+
+ auto ctx = makeCompileCtx();
+
+ // Build a scan of the [5,6,7,5,6,7,6,7,7] input array.
+ auto [inputTag, inputVal] =
+ stage_builder::makeValue(BSON_ARRAY(5 << 6 << 7 << 5 << 6 << 7 << 6 << 7 << 7));
+ auto [scanSlot, scanStage] = generateVirtualScan(inputTag, inputVal);
+
+ // Build a HashAggStage, group by the scanSlot and compute a count and a sum per group.
+ auto countsSlot = generateSlotId();
+ auto sumsSlot = generateSlotId();
+ auto stage = makeS<HashAggStage>(
+ std::move(scanStage),
+ makeSV(scanSlot),
+ makeEM(countsSlot,
+ stage_builder::makeFunction(
+ "sum",
+ makeE<EConstant>(value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(1))),
+ sumsSlot,
+ stage_builder::makeFunction("sum", makeE<EVariable>(scanSlot))),
+ makeSV(), // Seek slot
+ true,
+ boost::none,
+ true, // allowDiskUse=true
+ kEmptyPlanNodeId);
+
+ // Prepare the tree and get the 'SlotAccessor' for the output slot.
+ auto resultAccessors = prepareTree(ctx.get(), stage.get(), makeSV(countsSlot, sumsSlot));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res1Tag, res1Val] = resultAccessors[0]->getViewOfValue();
+ auto [res1TagSum, res1ValSum] = resultAccessors[1]->getViewOfValue();
+
+ // There are '2' occurrences of '5' in the input.
+ assertValuesEqual(res1Tag, res1Val, value::TypeTags::NumberInt32, value::bitcastFrom<int>(2));
+ assertValuesEqual(
+ res1TagSum, res1ValSum, value::TypeTags::NumberInt32, value::bitcastFrom<int>(10));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res2Tag, res2Val] = resultAccessors[0]->getViewOfValue();
+ auto [res2TagSum, res2ValSum] = resultAccessors[1]->getViewOfValue();
+ // There are '3' occurrences of '6' in the input.
+ assertValuesEqual(res2Tag, res2Val, value::TypeTags::NumberInt32, value::bitcastFrom<int>(3));
+ assertValuesEqual(
+ res2TagSum, res2ValSum, value::TypeTags::NumberInt32, value::bitcastFrom<int>(18));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res3Tag, res3Val] = resultAccessors[0]->getViewOfValue();
+ auto [res3TagSum, res3ValSum] = resultAccessors[1]->getViewOfValue();
+ // There are '4' occurrences of '7' in the input.
+ assertValuesEqual(res3Tag, res3Val, value::TypeTags::NumberInt32, value::bitcastFrom<int>(4));
+ assertValuesEqual(
+ res3TagSum, res3ValSum, value::TypeTags::NumberInt32, value::bitcastFrom<int>(28));
+ ASSERT_TRUE(stage->getNext() == PlanState::IS_EOF);
+
+ stage->close();
+}
+
TEST_F(HashAggStageTest, HashAggMemUsageTest) {
// Changing the query knobs to always re-estimate the hash table size in HashAgg and spill when
// estimated size is >= 128 * 5.
@@ -420,4 +684,164 @@ TEST_F(HashAggStageTest, HashAggMemUsageTest) {
performHashAggWithSpillChecking(spillInputArr, expectedOutputArr, true);
}
+TEST_F(HashAggStageTest, HashAggSum10Groups) {
+ // Changing the query knobs to always re-estimate the hash table size in HashAgg and spill when
+ // estimated size is >= 128. This should split the ints between the hash table and
+ // the record store somewhat evenly.
+ const auto memLimit = 128;
+ auto defaultInternalQuerySBEAggApproxMemoryUseInBytesBeforeSpill =
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.load();
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.store(memLimit);
+ ON_BLOCK_EXIT([&] {
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.store(
+ defaultInternalQuerySBEAggApproxMemoryUseInBytesBeforeSpill);
+ });
+ auto defaultInternalQuerySBEAggMemoryUseSampleRate =
+ internalQuerySBEAggMemoryUseSampleRate.load();
+ internalQuerySBEAggMemoryUseSampleRate.store(1.0);
+ ON_BLOCK_EXIT([&] {
+ internalQuerySBEAggMemoryUseSampleRate.store(defaultInternalQuerySBEAggMemoryUseSampleRate);
+ });
+
+ auto ctx = makeCompileCtx();
+
+ // Build an array of 1280 ints spread across 10 congruence groups (i % 10), tracking the
+ // expected sum for each group; e.g., group 7 holds 128 copies of 7, so its sum is 896.
+ BSONArrayBuilder builder;
+ stdx::unordered_map<int, int> sums;
+ for (int i = 0; i < 10 * memLimit; ++i) {
+ auto val = i % 10;
+ auto [it, inserted] = sums.try_emplace(val, val);
+ if (!inserted) {
+ it->second += val;
+ }
+ builder.append(val);
+ }
+
+ auto [inputTag, inputVal] = stage_builder::makeValue(BSONArray(builder.done()));
+ auto [scanSlot, scanStage] = generateVirtualScan(inputTag, inputVal);
+
+ // Build a HashAggStage, group by the scanSlot and compute a sum for each group.
+ auto sumsSlot = generateSlotId();
+ auto stage = makeS<HashAggStage>(
+ std::move(scanStage),
+ makeSV(scanSlot),
+ makeEM(sumsSlot, stage_builder::makeFunction("sum", makeE<EVariable>(scanSlot))),
+ makeSV(), // Seek slot
+ true,
+ boost::none,
+ true, // allowDiskUse=true
+ kEmptyPlanNodeId);
+
+ // Prepare the tree and get the 'SlotAccessor' for the output slot.
+ auto resultAccessors = prepareTree(ctx.get(), stage.get(), makeSV(scanSlot, sumsSlot));
+
+ while (stage->getNext() == PlanState::ADVANCED) {
+ auto [resGroupByTag, resGroupByVal] = resultAccessors[0]->getViewOfValue();
+ auto [resSumTag, resSumVal] = resultAccessors[1]->getViewOfValue();
+ auto it = sums.find(value::bitcastTo<int>(resGroupByVal));
+ ASSERT_TRUE(it != sums.end());
+ assertValuesEqual(resSumTag,
+ resSumVal,
+ value::TypeTags::NumberInt32,
+ value::bitcastFrom<int>(it->second));
+ }
+ stage->close();
+}
+
+TEST_F(HashAggStageTest, HashAggBasicCountWithRecordIds) {
+ // Changing the query knobs to always re-estimate the hash table size in HashAgg and spill when
+ // estimated size is >= 4 * 8.
+ auto defaultInternalQuerySBEAggApproxMemoryUseInBytesBeforeSpill =
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.load();
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.store(4 * 8);
+ ON_BLOCK_EXIT([&] {
+ internalQuerySBEAggApproxMemoryUseInBytesBeforeSpill.store(
+ defaultInternalQuerySBEAggApproxMemoryUseInBytesBeforeSpill);
+ });
+ auto defaultInternalQuerySBEAggMemoryUseSampleRate =
+ internalQuerySBEAggMemoryUseSampleRate.load();
+ internalQuerySBEAggMemoryUseSampleRate.store(1.0);
+ ON_BLOCK_EXIT([&] {
+ internalQuerySBEAggMemoryUseSampleRate.store(defaultInternalQuerySBEAggMemoryUseSampleRate);
+ });
+
+ auto ctx = makeCompileCtx();
+
+ // Build a scan over the input array of record ids
+ // [1,10,999,10,999,1,999,8589869056,999,10,8589869056].
+ auto [inputTag, inputVal] = sbe::value::makeNewArray();
+ auto testData = sbe::value::getArrayView(inputVal);
+
+ testData->push_back(value::TypeTags::RecordId, value::bitcastFrom<int64_t>(1));
+ testData->push_back(value::TypeTags::RecordId, value::bitcastFrom<int64_t>(10));
+ testData->push_back(value::TypeTags::RecordId, value::bitcastFrom<int64_t>(999));
+ testData->push_back(value::TypeTags::RecordId, value::bitcastFrom<int64_t>(10));
+ testData->push_back(value::TypeTags::RecordId, value::bitcastFrom<int64_t>(999));
+ testData->push_back(value::TypeTags::RecordId, value::bitcastFrom<int64_t>(1));
+ testData->push_back(value::TypeTags::RecordId, value::bitcastFrom<int64_t>(999));
+ testData->push_back(value::TypeTags::RecordId, value::bitcastFrom<int64_t>(8589869056));
+ testData->push_back(value::TypeTags::RecordId, value::bitcastFrom<int64_t>(999));
+ testData->push_back(value::TypeTags::RecordId, value::bitcastFrom<int64_t>(10));
+ testData->push_back(value::TypeTags::RecordId, value::bitcastFrom<int64_t>(8589869056));
+
+ auto [scanSlot, scanStage] = generateVirtualScan(inputTag, inputVal);
+
+ // Build a HashAggStage, group by the scanSlot and compute a simple count.
+ auto countsSlot = generateSlotId();
+ auto stage = makeS<HashAggStage>(
+ std::move(scanStage),
+ makeSV(scanSlot),
+ makeEM(countsSlot,
+ stage_builder::makeFunction(
+ "sum",
+ makeE<EConstant>(value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(1)))),
+ makeSV(), // Seek slot
+ true,
+ boost::none,
+ true, // allowDiskUse=true
+ kEmptyPlanNodeId);
+
+ // Prepare the tree and get the 'SlotAccessor' for the output slot.
+ auto resultAccessors = prepareTree(ctx.get(), stage.get(), makeSV(scanSlot, countsSlot));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res1ScanTag, res1ScanVal] = resultAccessors[0]->getViewOfValue();
+ auto [res1Tag, res1Val] = resultAccessors[1]->getViewOfValue();
+ // There are '2' occurrences of '1' in the input.
+ assertValuesEqual(
+ res1ScanTag, res1ScanVal, value::TypeTags::RecordId, value::bitcastFrom<int64_t>(1));
+ assertValuesEqual(
+ res1Tag, res1Val, value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(2));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res2ScanTag, res2ScanVal] = resultAccessors[0]->getViewOfValue();
+ auto [res2Tag, res2Val] = resultAccessors[1]->getViewOfValue();
+ // There are '2' occurrences of '8589869056' in the input.
+ assertValuesEqual(res2ScanTag,
+ res2ScanVal,
+ value::TypeTags::RecordId,
+ value::bitcastFrom<int64_t>(8589869056));
+ assertValuesEqual(
+ res2Tag, res2Val, value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(2));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res3ScanTag, res3ScanVal] = resultAccessors[0]->getViewOfValue();
+ auto [res3Tag, res3Val] = resultAccessors[1]->getViewOfValue();
+ // There are '3' occurrences of '10' in the input.
+ assertValuesEqual(
+ res3ScanTag, res3ScanVal, value::TypeTags::RecordId, value::bitcastFrom<int64_t>(10));
+ assertValuesEqual(
+ res3Tag, res3Val, value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(3));
+
+ ASSERT_TRUE(stage->getNext() == PlanState::ADVANCED);
+ auto [res4ScanTag, res4ScanVal] = resultAccessors[0]->getViewOfValue();
+ auto [res4Tag, res4Val] = resultAccessors[1]->getViewOfValue();
+ // There are '4' occurrences of '999' in the input.
+ assertValuesEqual(
+ res4ScanTag, res4ScanVal, value::TypeTags::RecordId, value::bitcastFrom<int64_t>(999));
+ assertValuesEqual(
+ res4Tag, res4Val, value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(4));
+ ASSERT_TRUE(stage->getNext() == PlanState::IS_EOF);
+
+ stage->close();
+}
} // namespace mongo::sbe
diff --git a/src/mongo/db/exec/sbe/sbe_plan_size_test.cpp b/src/mongo/db/exec/sbe/sbe_plan_size_test.cpp
index 91d97e3e498..76bd240d1a2 100644
--- a/src/mongo/db/exec/sbe/sbe_plan_size_test.cpp
+++ b/src/mongo/db/exec/sbe/sbe_plan_size_test.cpp
@@ -143,6 +143,7 @@ TEST_F(PlanSizeTest, HashAgg) {
makeSV(),
true,
generateSlotId(),
+ false,
kEmptyPlanNodeId);
assertPlanSize(*stage);
}
diff --git a/src/mongo/db/exec/sbe/sbe_plan_stage_test.cpp b/src/mongo/db/exec/sbe/sbe_plan_stage_test.cpp
index 4c1209f49c0..a57812a3c1e 100644
--- a/src/mongo/db/exec/sbe/sbe_plan_stage_test.cpp
+++ b/src/mongo/db/exec/sbe/sbe_plan_stage_test.cpp
@@ -41,12 +41,6 @@
namespace mongo::sbe {
-PlanStageTestFixture::PlanStageTestFixture() {
- auto service = getServiceContext();
- service->registerClientObserver(
- std::make_unique<LockerNoopClientObserverWithReplacementPolicy>());
-}
-
void PlanStageTestFixture::assertValuesEqual(value::TypeTags lhsTag,
value::Value lhsVal,
value::TypeTags rhsTag,
diff --git a/src/mongo/db/exec/sbe/sbe_plan_stage_test.h b/src/mongo/db/exec/sbe/sbe_plan_stage_test.h
index 05174c73747..febf4d841db 100644
--- a/src/mongo/db/exec/sbe/sbe_plan_stage_test.h
+++ b/src/mongo/db/exec/sbe/sbe_plan_stage_test.h
@@ -42,7 +42,7 @@
#include "mongo/db/exec/sbe/values/value.h"
#include "mongo/db/query/sbe_stage_builder.h"
#include "mongo/db/query/sbe_stage_builder_helpers.h"
-#include "mongo/db/service_context_test_fixture.h"
+#include "mongo/db/service_context_d_test_fixture.h"
#include "mongo/unittest/unittest.h"
namespace mongo::sbe {
@@ -80,20 +80,20 @@ using MakeStageFn = std::function<std::pair<T, std::unique_ptr<PlanStage>>(
* observe 1 output slot, use runTest(). For unittests where the PlanStage has multiple input slots
* and/or where the test needs to observe multiple output slots, use runTestMulti().
*/
-class PlanStageTestFixture : public ServiceContextTest {
+class PlanStageTestFixture : public ServiceContextMongoDTest {
public:
- PlanStageTestFixture();
+ PlanStageTestFixture() = default;
void setUp() override {
- ServiceContextTest::setUp();
- _opCtx = makeOperationContext();
+ ServiceContextMongoDTest::setUp();
+ _opCtx = cc().makeOperationContext();
_slotIdGenerator.reset(new value::SlotIdGenerator());
}
void tearDown() override {
_slotIdGenerator.reset();
_opCtx.reset();
- ServiceContextTest::tearDown();
+ ServiceContextMongoDTest::tearDown();
}
OperationContext* opCtx() {
diff --git a/src/mongo/db/exec/sbe/sbe_trial_run_tracker_test.cpp b/src/mongo/db/exec/sbe/sbe_trial_run_tracker_test.cpp
index 1e0644c35ff..b3aa236dd8d 100644
--- a/src/mongo/db/exec/sbe/sbe_trial_run_tracker_test.cpp
+++ b/src/mongo/db/exec/sbe/sbe_trial_run_tracker_test.cpp
@@ -148,6 +148,7 @@ TEST_F(TrialRunTrackerTest, TrialEndsDuringOpenPhaseOfBlockingStage) {
makeSV(), // Seek slot
true,
boost::none,
+ false /* allowDiskUse */,
kEmptyPlanNodeId);
auto tracker = std::make_unique<TrialRunTracker>(numResultsLimit, size_t{0});
@@ -216,6 +217,7 @@ TEST_F(TrialRunTrackerTest, OnlyDeepestNestedBlockingStageHasTrialRunTracker) {
makeSV(), // Seek slot
true,
boost::none,
+ false /* allowDiskUse */,
kEmptyPlanNodeId);
hashAggStage->prepare(*ctx);
@@ -282,6 +284,7 @@ TEST_F(TrialRunTrackerTest, SiblingBlockingStagesBothGetTrialRunTracker) {
makeSV(), // Seek slot
true,
boost::none,
+ false /* allowDiskUse */,
kEmptyPlanNodeId);
return std::make_pair(countsSlot, std::move(hashAggStage));
diff --git a/src/mongo/db/exec/sbe/stages/hash_agg.cpp b/src/mongo/db/exec/sbe/stages/hash_agg.cpp
index 625d4c62585..0179036395e 100644
--- a/src/mongo/db/exec/sbe/stages/hash_agg.cpp
+++ b/src/mongo/db/exec/sbe/stages/hash_agg.cpp
@@ -29,8 +29,11 @@
#include "mongo/platform/basic.h"
+#include "mongo/db/concurrency/d_concurrency.h"
+#include "mongo/db/concurrency/write_conflict_exception.h"
#include "mongo/db/exec/sbe/stages/hash_agg.h"
-
+#include "mongo/db/storage/kv/kv_engine.h"
+#include "mongo/db/storage/storage_engine.h"
#include "mongo/util/str.h"
#include "mongo/db/exec/sbe/size_estimator.h"
@@ -43,11 +46,13 @@ HashAggStage::HashAggStage(std::unique_ptr<PlanStage> input,
value::SlotVector seekKeysSlots,
bool optimizedClose,
boost::optional<value::SlotId> collatorSlot,
+ bool allowDiskUse,
PlanNodeId planNodeId)
: PlanStage("group"_sd, planNodeId),
_gbs(std::move(gbs)),
_aggs(std::move(aggs)),
_collatorSlot(collatorSlot),
+ _allowDiskUse(allowDiskUse),
_seekKeysSlots(std::move(seekKeysSlots)),
_optimizedClose(optimizedClose) {
_children.emplace_back(std::move(input));
@@ -68,9 +73,40 @@ std::unique_ptr<PlanStage> HashAggStage::clone() const {
_seekKeysSlots,
_optimizedClose,
_collatorSlot,
+ _allowDiskUse,
_commonStats.nodeId);
}
+void HashAggStage::doSaveState(bool relinquishCursor) {
+ if (relinquishCursor) {
+ if (_rsCursor) {
+ _rsCursor->save();
+ }
+ }
+ if (_rsCursor) {
+ _rsCursor->setSaveStorageCursorOnDetachFromOperationContext(!relinquishCursor);
+ }
+}
+
+void HashAggStage::doRestoreState(bool relinquishCursor) {
+ invariant(_opCtx);
+ if (_rsCursor && relinquishCursor) {
+ _rsCursor->restore();
+ }
+}
+
+void HashAggStage::doDetachFromOperationContext() {
+ if (_rsCursor) {
+ _rsCursor->detachFromOperationContext();
+ }
+}
+
+void HashAggStage::doAttachToOperationContext(OperationContext* opCtx) {
+ if (_rsCursor) {
+ _rsCursor->reattachToOperationContext(opCtx);
+ }
+}
+
void HashAggStage::prepare(CompileCtx& ctx) {
_children[0]->prepare(ctx);
@@ -89,7 +125,26 @@ void HashAggStage::prepare(CompileCtx& ctx) {
uassert(4822827, str::stream() << "duplicate field: " << slot, inserted);
_inKeyAccessors.emplace_back(_children[0]->getAccessor(ctx, slot));
- _outKeyAccessors.emplace_back(std::make_unique<HashKeyAccessor>(_htIt, counter++));
+
+ // Construct accessors for the key to be read from either the '_ht' or the
+ // '_recordStore'. Until the memory limit is reached, the '_outHashKeyAccessors' carry
+ // the group-by keys; once data has been spilled, the '_outRecordStoreKeyAccessors'
+ // carry them.
+ _outHashKeyAccessors.emplace_back(std::make_unique<HashKeyAccessor>(_htIt, counter));
+ _outRecordStoreKeyAccessors.emplace_back(
+ std::make_unique<value::MaterializedSingleRowAccessor>(_aggKeyRecordStore, counter));
+
+ counter++;
+
+ // A SwitchAccessor is used to point the '_outKeyAccessors' at the key coming from the '_ht'
+ // or the '_recordStore' when draining the HashAgg stage in getNext(). The group-by key lives
+ // in the '_ht' while it fits in memory and in the '_recordStore' once it has been spilled to
+ // disk. The SwitchAccessor toggles between the two so the parent stage can read the key
+ // through the '_outAccessors' either way.
+ _outKeyAccessors.emplace_back(
+ std::make_unique<value::SwitchAccessor>(std::vector<value::SlotAccessor*>{
+ _outHashKeyAccessors.back().get(), _outRecordStoreKeyAccessors.back().get()}));
+
_outAccessors[slot] = _outKeyAccessors.back().get();
}
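
The SwitchAccessor introduced here is what lets a single output slot transparently serve rows from two different backing stores. A minimal sketch of the idea (a simplified, assumed shape; the real value::SwitchAccessor lives in the SBE values library):

    // Presents one of several backing accessors behind a single slot; the stage
    // flips the index when it switches from in-memory rows to spilled rows.
    class SwitchLikeAccessor {
    public:
        explicit SwitchLikeAccessor(std::vector<value::SlotAccessor*> backing)
            : _backing(std::move(backing)) {}

        void setIndex(size_t idx) {
            _idx = idx;  // in this patch: 0 == hash table, 1 == record store
        }

        std::pair<value::TypeTags, value::Value> getViewOfValue() const {
            return _backing[_idx]->getViewOfValue();  // delegate to the active store
        }

    private:
        std::vector<value::SlotAccessor*> _backing;
        size_t _idx = 0;
    };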
@@ -108,7 +163,23 @@ void HashAggStage::prepare(CompileCtx& ctx) {
const auto slotId = slot;
uassert(4822828, str::stream() << "duplicate field: " << slotId, inserted);
- _outAggAccessors.emplace_back(std::make_unique<HashAggAccessor>(_htIt, counter++));
+ // Construct accessors for the agg state to be processed from either the '_ht' or the
+ // '_recordStore' by the SwitchAccessor owned by '_outAggAccessors' below.
+ _outRecordStoreAggAccessors.emplace_back(
+ std::make_unique<value::MaterializedSingleRowAccessor>(_aggValueRecordStore, counter));
+ _outHashAggAccessors.emplace_back(std::make_unique<HashAggAccessor>(_htIt, counter));
+ counter++;
+
+ // A SwitchAccessor is used to toggle the '_outAggAccessors' between the '_ht' and the
+ // '_recordStore' when updating the agg state via the bytecode. Because the agg
+ // EExpressions are compiled against the SwitchAccessor, we can load the agg value into
+ // the in-memory '_aggValueRecordStore' when the value comes from the '_recordStore', or
+ // use the agg value referenced through '_htIt', and in either case run the bytecode to
+ // mutate the value through the SwitchAccessor.
+ _outAggAccessors.emplace_back(
+ std::make_unique<value::SwitchAccessor>(std::vector<value::SlotAccessor*>{
+ _outHashAggAccessors.back().get(), _outRecordStoreAggAccessors.back().get()}));
+
_outAccessors[slot] = _outAggAccessors.back().get();
ctx.root = this;
@@ -134,15 +205,29 @@ value::SlotAccessor* HashAggStage::getAccessor(CompileCtx& ctx, value::SlotId sl
}
namespace {
-// This check makes sure we are safe to spill to disk without the need to abandon the current
-// snapshot.
-void assertIgnoreConflictsWriteBehavior(OperationContext* opCtx) {
+// Proactively assert that this operation can safely write before hitting an assertion in the
+// storage engine. We can safely write if we are enforcing prepare conflicts by blocking or if we
+// are ignoring prepare conflicts and explicitly allowing writes. Ignoring prepare conflicts
+// without allowing writes will cause this operation to fail in the storage engine.
+void assertIgnorePrepareConflictsBehavior(OperationContext* opCtx) {
tassert(5907502,
"The operation must be ignoring conflicts and allowing writes or enforcing prepare "
"conflicts entirely",
opCtx->recoveryUnit()->getPrepareConflictBehavior() !=
PrepareConflictBehavior::kIgnoreConflicts);
}
+
+/**
+ * This helper takes the 'rid' RecordId (the serialized group-by key) together with its
+ * 'typeBits' and rehydrates them into a complete KeyString::Value.
+ */
+KeyString::Value rehydrateKey(const RecordId& rid, KeyString::TypeBits typeBits) {
+ auto rawKey = rid.getStr();
+ KeyString::Builder kb{KeyString::Version::kLatestVersion};
+ kb.resetFromBuffer(rawKey.rawData(), rawKey.size());
+ kb.setTypeBits(typeBits);
+ return kb.getValueCopy();
+}
} // namespace
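
Together with the serialization done in open() below, rehydrateKey() completes a full round trip for a spilled group-by key. A sketch using only calls that appear in this patch ('key' is a value::MaterializedRow):

    // Spill side: encode the key into a memcmp-able KeyString; the TypeBits must
    // be kept separately because the KeyString bytes alone cannot reproduce them.
    KeyString::Builder kb{KeyString::Version::kLatestVersion};
    key.serializeIntoKeyString(kb);
    auto typeBits = kb.getTypeBits();
    auto rid = RecordId(kb.getBuffer(), kb.getSize());

    // Drain side: rebuild the full KeyString from the RecordId bytes plus the
    // saved TypeBits, then decode it back into a MaterializedRow.
    BufBuilder buf;
    auto restoredKey = value::MaterializedRow::deserializeFromKeyString(
        rehydrateKey(rid, typeBits), &buf);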
void HashAggStage::makeTemporaryRecordStore() {
@@ -153,11 +238,58 @@ void HashAggStage::makeTemporaryRecordStore() {
tassert(5907501,
"No storage engine so HashAggStage cannot spill to disk",
_opCtx->getServiceContext()->getStorageEngine());
- assertIgnoreConflictsWriteBehavior(_opCtx);
+ assertIgnorePrepareConflictsBehavior(_opCtx);
_recordStore = _opCtx->getServiceContext()->getStorageEngine()->makeTemporaryRecordStore(
_opCtx, KeyFormat::String);
}
+void HashAggStage::spillValueToDisk(const RecordId& key,
+ const value::MaterializedRow& val,
+ const KeyString::TypeBits& typeBits,
+ bool update) {
+ BufBuilder bufValue;
+
+ val.serializeForSorter(bufValue);
+
+ // Append the 'typeBits' to the end of the val's buffer so the 'key' can be reconstructed when
+ // draining HashAgg.
+ bufValue.appendBuf(typeBits.getBuffer(), typeBits.getSize());
+
+ assertIgnorePrepareConflictsBehavior(_opCtx);
+
+ // Take a dummy lock to avoid tripping invariants in the storage layer. This is
+ // effectively a no-op because we aren't writing to a collection, just to a temporary
+ // record store that only HashAgg will touch.
+ Lock::GlobalLock lk(_opCtx, MODE_IX);
+ WriteUnitOfWork wuow(_opCtx);
+ auto result = [&]() {
+ if (update) {
+ auto status =
+ _recordStore->rs()->updateRecord(_opCtx, key, bufValue.buf(), bufValue.len());
+ return status;
+ } else {
+ auto status = _recordStore->rs()->insertRecord(
+ _opCtx, key, bufValue.buf(), bufValue.len(), Timestamp{});
+ return status.getStatus();
+ }
+ }();
+ wuow.commit();
+ tassert(5843600,
+ str::stream() << "Failed to write to disk because " << result.reason(),
+ result.isOK());
+}
+
+boost::optional<value::MaterializedRow> HashAggStage::getFromRecordStore(const RecordId& rid) {
+ Lock::GlobalLock lk(_opCtx, MODE_IS);
+ RecordData record;
+ if (_recordStore->rs()->findRecord(_opCtx, rid, &record)) {
+ auto valueReader = BufReader(record.data(), record.size());
+ return value::MaterializedRow::deserializeForSorter(valueReader, {});
+ } else {
+ return boost::none;
+ }
+}
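
The spilled record's value buffer packs the serialized agg row and the key's TypeBits back to back. A sketch of reading one record, in the same order getNext() uses below so the reader is positioned at the TypeBits once the row has been consumed:

    // Layout: [ MaterializedRow via serializeForSorter() | KeyString::TypeBits ]
    BufReader reader(record.data(), record.size());
    auto aggRow = value::MaterializedRow::deserializeForSorter(reader, {});
    auto typeBits =
        KeyString::TypeBits::fromBuffer(KeyString::Version::kLatestVersion, &reader);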
+
void HashAggStage::open(bool reOpen) {
auto optTimer(getOptTimer(_opCtx));
@@ -166,7 +298,6 @@ void HashAggStage::open(bool reOpen) {
if (!reOpen || _seekKeysAccessors.empty()) {
_children[0]->open(_childOpened);
_childOpened = true;
-
if (_collatorAccessor) {
auto [tag, collatorVal] = _collatorAccessor->getViewOfValue();
uassert(
@@ -184,6 +315,9 @@ void HashAggStage::open(bool reOpen) {
// A counter to check memory usage periodically.
auto memoryUseCheckCounter = 0;
+ // A default (empty) agg state value used when a key is first spilled to the record store.
+ value::MaterializedRow defaultVal{_outAggAccessors.size()};
+ bool updateAggStateHt = false;
while (_children[0]->getNext() == PlanState::ADVANCED) {
value::MaterializedRow key{_inKeyAccessors.size()};
// Copy keys in order to do the lookup.
@@ -193,34 +327,80 @@ void HashAggStage::open(bool reOpen) {
key.reset(idx++, false, tag, val);
}
- auto [it, inserted] = _ht->try_emplace(std::move(key), value::MaterializedRow{0});
- if (inserted) {
- // Copy keys.
- const_cast<value::MaterializedRow&>(it->first).makeOwned();
- // Initialize accumulators.
- it->second.resize(_outAggAccessors.size());
+ if (!_recordStore) {
+ // The memory limit hasn't been reached yet, accumulate state in '_ht'.
+ auto [it, inserted] = _ht->try_emplace(std::move(key), value::MaterializedRow{0});
+ if (inserted) {
+ // Copy keys.
+ const_cast<value::MaterializedRow&>(it->first).makeOwned();
+ // Initialize accumulators.
+ it->second.resize(_outAggAccessors.size());
+ }
+ // Nothing has been spilled to disk yet, so the state is always updated in the '_ht'.
+ _htIt = it;
+ updateAggStateHt = true;
+ } else {
+ // The memory limit has been reached, so accumulate state in the '_ht' only
+ // if the key is already present in the '_ht'.
+ auto it = _ht->find(key);
+ _htIt = it;
+ updateAggStateHt = _htIt != _ht->end();
}
- // Accumulate.
- _htIt = it;
- for (size_t idx = 0; idx < _outAggAccessors.size(); ++idx) {
- auto [owned, tag, val] = _bytecode.run(_aggCodes[idx].get());
- _outAggAccessors[idx]->reset(owned, tag, val);
+ if (updateAggStateHt) {
+ // Accumulate state in the '_ht' by pointing the '_outAggAccessors' to the
+ // '_outHashAggAccessors'.
+ for (size_t idx = 0; idx < _outAggAccessors.size(); ++idx) {
+ _outAggAccessors[idx]->setIndex(0);
+ auto [owned, tag, val] = _bytecode.run(_aggCodes[idx].get());
+ _outHashAggAccessors[idx]->reset(owned, tag, val);
+ }
+ } else {
+ // The memory limit has been reached and the key wasn't in the '_ht' so we need
+ // to spill it to the '_recordStore'.
+ KeyString::Builder kb{KeyString::Version::kLatestVersion};
+
+ // It's safe to ignore the use-after-move warning since it's logically impossible to
+ // enter this block after the move occurs.
+ key.serializeIntoKeyString(kb); // NOLINT(bugprone-use-after-move)
+ auto typeBits = kb.getTypeBits();
+
+ auto rid = RecordId(kb.getBuffer(), kb.getSize());
+
+ auto valFromRs = getFromRecordStore(rid);
+ if (!valFromRs) {
+ _aggValueRecordStore = defaultVal;
+ } else {
+ _aggValueRecordStore = *valFromRs;
+ }
+
+ for (size_t idx = 0; idx < _outAggAccessors.size(); ++idx) {
+ _outAggAccessors[idx]->setIndex(1);
+ auto [owned, tag, val] = _bytecode.run(_aggCodes[idx].get());
+ _aggValueRecordStore.reset(idx, owned, tag, val);
+ }
+ spillValueToDisk(rid, _aggValueRecordStore, typeBits, valFromRs ? true : false);
}
- // Track memory usage.
- auto shouldCalculateEstimatedSize =
- _pseudoRandom.nextCanonicalDouble() < _memoryUseSampleRate;
- if (shouldCalculateEstimatedSize || ++memoryUseCheckCounter % 100 == 0) {
- memoryUseCheckCounter = 0;
- long estimatedSizeForOneRow =
- it->first.memUsageForSorter() + it->second.memUsageForSorter();
- long long estimatedTotalSize = _ht->size() * estimatedSizeForOneRow;
-
- if (estimatedTotalSize >= _approxMemoryUseInBytesBeforeSpill) {
- // TODO SERVER-58436: Remove this uassert when spilling is implemented.
- uasserted(5859000, "Need to spill to disk");
- makeTemporaryRecordStore();
+ // Track memory usage only when we haven't started spilling to the '_recordStore'.
+ if (!_recordStore) {
+ auto shouldCalculateEstimatedSize =
+ _pseudoRandom.nextCanonicalDouble() < _memoryUseSampleRate;
+ if (shouldCalculateEstimatedSize || ++memoryUseCheckCounter % 100 == 0) {
+ memoryUseCheckCounter = 0;
+ long estimatedSizeForOneRow =
+ _htIt->first.memUsageForSorter() + _htIt->second.memUsageForSorter();
+ long long estimatedTotalSize = _ht->size() * estimatedSizeForOneRow;
+
+ if (estimatedTotalSize >= _approxMemoryUseInBytesBeforeSpill) {
+ uassert(
+ 5843601,
+ "Exceeded memory limit for $group, but didn't allow external spilling."
+ " Pass allowDiskUse:true to opt in.",
+ _allowDiskUse);
+ makeTemporaryRecordStore();
+ }
}
}
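
The spill trigger above is an estimate rather than an exact byte count: the size of one resident row is sampled and multiplied by the table size. A worked example under the settings the new unit tests use (threshold 4 * 8 = 32 bytes, sample rate 1.0; the 24-byte row size is an assumed sample):

    long estimatedSizeForOneRow = 24;  // assumed memUsageForSorter() total (key + agg state)
    // With 1 group:  1 * 24 = 24  -> below 32, keep accumulating in '_ht'.
    // With 2 groups: 2 * 24 = 48  -> >= 32, so either uassert(5843601) when
    //                allowDiskUse is false, or create the record store and
    //                spill every new key from this point on.
    long long estimatedTotalSize = 2 * estimatedSizeForOneRow;
    bool mustSpill = estimatedTotalSize >= 32;  // true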
@@ -241,6 +421,7 @@ void HashAggStage::open(bool reOpen) {
}
}
+
if (!_seekKeysAccessors.empty()) {
// Copy keys in order to do the lookup.
size_t idx = 0;
@@ -251,13 +432,20 @@ void HashAggStage::open(bool reOpen) {
}
_htIt = _ht->end();
+
+ // Set the SwitchAccessors to point to the '_ht' so we can drain it first before draining the
+ // '_recordStore' in getNext().
+ for (size_t idx = 0; idx < _outAggAccessors.size(); ++idx) {
+ _outAggAccessors[idx]->setIndex(0);
+ }
+ _drainingRecordStore = false;
}
PlanState HashAggStage::getNext() {
auto optTimer(getOptTimer(_opCtx));
- if (_htIt == _ht->end()) {
- // First invocation of getNext() after open().
+ if (_htIt == _ht->end() && !_drainingRecordStore) {
+ // First invocation of getNext() after open() when not draining the '_recordStore'.
if (!_seekKeysAccessors.empty()) {
_htIt = _ht->find(_seekKeys);
} else {
@@ -266,20 +454,56 @@ PlanState HashAggStage::getNext() {
} else if (!_seekKeysAccessors.empty()) {
// Subsequent invocation with seek keys. Return only 1 single row (if any).
_htIt = _ht->end();
- } else {
- // Returning the results of the entire hash table.
+ } else if (!_drainingRecordStore) {
+ // Returning the results of the entire hash table first before draining the '_recordStore'.
++_htIt;
}
- if (_htIt == _ht->end()) {
- if (_recordStore && _seekKeysAccessors.empty()) {
- // A record store was created to spill to disk. Clean it up.
+ if (_htIt == _ht->end() && !_recordStore) {
+ // The hash table has been drained and nothing was spilled to disk.
+ return trackPlanState(PlanState::IS_EOF);
+ } else if (_htIt != _ht->end()) {
+ // Drain the '_ht' on the next 'getNext()' call.
+ return trackPlanState(PlanState::ADVANCED);
+ } else if (_seekKeysAccessors.empty()) {
+ // A record store was created to spill to disk. Drain it then clean it up.
+ if (!_rsCursor) {
+ _rsCursor = _recordStore->rs()->getCursor(_opCtx);
+ }
+ auto nextRecord = _rsCursor->next();
+ if (nextRecord) {
+ // Point the out accessors to the recordStore accessors to allow parent stages to read
+ // the agg state from the '_recordStore'.
+ if (!_drainingRecordStore) {
+ for (size_t i = 0; i < _outKeyAccessors.size(); ++i) {
+ _outKeyAccessors[i]->setIndex(1);
+ }
+ for (size_t i = 0; i < _outAggAccessors.size(); ++i) {
+ _outAggAccessors[i]->setIndex(1);
+ }
+ }
+ _drainingRecordStore = true;
+
+ // Read the agg state value from the '_recordStore' and reconstruct the key
+ // from the typeBits stored alongside the value.
+ BufReader valReader(nextRecord->data.data(), nextRecord->data.size());
+ auto val = value::MaterializedRow::deserializeForSorter(valReader, {});
+ auto typeBits =
+ KeyString::TypeBits::fromBuffer(KeyString::Version::kLatestVersion, &valReader);
+ _aggValueRecordStore = val;
+
+ BufBuilder buf;
+ _aggKeyRecordStore = value::MaterializedRow::deserializeFromKeyString(
+ rehydrateKey(nextRecord->id, typeBits), &buf);
+ return trackPlanState(PlanState::ADVANCED);
+ } else {
+ _rsCursor.reset();
_recordStore.reset();
+ return trackPlanState(PlanState::IS_EOF);
}
- return trackPlanState(PlanState::IS_EOF);
+ } else {
+ return trackPlanState(PlanState::ADVANCED);
}
-
- return trackPlanState(PlanState::ADVANCED);
}
std::unique_ptr<PlanStageStats> HashAggStage::getStats(bool includeDebugInfo) const {
@@ -314,6 +538,7 @@ void HashAggStage::close() {
if (_recordStore) {
// A record store was created to spill to disk. Clean it up.
_recordStore.reset();
+ _drainingRecordStore = false;
}
if (_childOpened) {
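
End to end, the observable contract of the reworked getNext() is: groups resident in the hash table are returned first, then spilled groups in RecordId (KeyString) order, then EOF, with close() discarding any leftover spill state. A consumer-side sketch using the same names as the new unit tests:

    while (stage->getNext() == PlanState::ADVANCED) {
        auto [tag, val] = resultAccessor->getViewOfValue();
        // Each group appears exactly once, whether it stayed in '_ht' or was
        // spilled; spilled groups arrive after all in-memory groups.
    }
    stage->close();  // also frees the temporary record store, if any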
diff --git a/src/mongo/db/exec/sbe/stages/hash_agg.h b/src/mongo/db/exec/sbe/stages/hash_agg.h
index 557f95c7f32..afb8bea06f3 100644
--- a/src/mongo/db/exec/sbe/stages/hash_agg.h
+++ b/src/mongo/db/exec/sbe/stages/hash_agg.h
@@ -74,6 +74,7 @@ public:
value::SlotVector seekKeysSlots,
bool optimizedClose,
boost::optional<value::SlotId> collatorSlot,
+ bool allowDiskUse,
PlanNodeId planNodeId);
std::unique_ptr<PlanStage> clone() const final;
@@ -90,12 +91,16 @@ public:
size_t estimateCompileTimeSize() const final;
protected:
+ void doSaveState(bool relinquishCursor) override;
+ void doRestoreState(bool relinquishCursor) override;
+ void doDetachFromOperationContext() override;
+ void doAttachToOperationContext(OperationContext* opCtx) override;
void doDetachFromTrialRunTracker() override;
TrialRunTrackerAttachResultMask doAttachToTrialRunTracker(
TrialRunTracker* tracker, TrialRunTrackerAttachResultMask childrenAttachResult) override;
private:
- void makeTemporaryRecordStore();
+ boost::optional<value::MaterializedRow> getFromRecordStore(const RecordId& rid);
using TableType = stdx::unordered_map<value::MaterializedRow,
value::MaterializedRow,
@@ -105,9 +110,32 @@ private:
using HashKeyAccessor = value::MaterializedRowKeyAccessor<TableType::iterator>;
using HashAggAccessor = value::MaterializedRowValueAccessor<TableType::iterator>;
+ void makeTemporaryRecordStore();
+
+ /**
+ * Spills a key and value pair to the '_recordStore' where the semantics are insert or update
+ * depending on the 'update' flag. When the 'update' flag is true, this method expects the
+ * 'key' to already exist in the '_recordStore'; otherwise the 'key' and 'val' pair are
+ * treated as a fresh insert.
+ *
+ * This method expects the key to be serialized into a KeyString::Value so that the key is
+ * memcmp-able and lookups can be done to update the 'val' in the '_recordStore'. Note that the
+ * 'typeBits' are needed to reconstruct the spilled 'key' when calling 'getNext' to deserialize
+ * the 'key' to a MaterializedRow. Since the '_recordStore' only stores the memcmp-able part of
+ * the KeyString, which is used as the RecordId, we need to carry the 'typeBits' separately. We
+ * do this by appending the 'typeBits' to the end of the serialized 'val' buffer, so they are
+ * stored at the leaves of the backing B-tree of the '_recordStore'.
+ */
+ void spillValueToDisk(const RecordId& key,
+ const value::MaterializedRow& val,
+ const KeyString::TypeBits& typeBits,
+ bool update);
+
+
const value::SlotVector _gbs;
const value::SlotMap<std::unique_ptr<EExpression>> _aggs;
const boost::optional<value::SlotId> _collatorSlot;
+ const bool _allowDiskUse;
const value::SlotVector _seekKeysSlots;
// When this operator does not expect to be reopened (almost always) then it can close the child
// early.
@@ -122,12 +150,25 @@ private:
value::SlotAccessorMap _outAccessors;
std::vector<value::SlotAccessor*> _inKeyAccessors;
- std::vector<std::unique_ptr<HashKeyAccessor>> _outKeyAccessors;
+
+ // Accessors for the key stored in the '_ht'; a SwitchAccessor is used so we can produce the
+ // key from either the '_ht' or the '_recordStore'.
+ std::vector<std::unique_ptr<HashKeyAccessor>> _outHashKeyAccessors;
+ std::vector<std::unique_ptr<value::SwitchAccessor>> _outKeyAccessors;
+
+ // Key and agg state value read back from the '_recordStore' when data has been spilled to
+ // disk, along with their accessors.
+ value::MaterializedRow _aggKeyRecordStore{0};
+ value::MaterializedRow _aggValueRecordStore{0};
+ std::vector<std::unique_ptr<value::MaterializedSingleRowAccessor>> _outRecordStoreKeyAccessors;
+ std::vector<std::unique_ptr<value::MaterializedSingleRowAccessor>> _outRecordStoreAggAccessors;
std::vector<value::SlotAccessor*> _seekKeysAccessors;
value::MaterializedRow _seekKeys;
- std::vector<std::unique_ptr<HashAggAccessor>> _outAggAccessors;
+ // Accessors for the agg state in the '_ht'; a SwitchAccessor is used so we can produce the
+ // agg state from either the '_ht' or the '_recordStore' when draining the HashAgg stage.
+ std::vector<std::unique_ptr<value::SwitchAccessor>> _outAggAccessors;
+ std::vector<std::unique_ptr<HashAggAccessor>> _outHashAggAccessors;
std::vector<std::unique_ptr<vm::CodeFragment>> _aggCodes;
// Only set if collator slot provided on construction.
@@ -143,6 +184,8 @@ private:
// Used when spilling to disk.
std::unique_ptr<TemporaryRecordStore> _recordStore;
+ bool _drainingRecordStore{false};
+ std::unique_ptr<SeekableRecordCursor> _rsCursor;
// If provided, used during a trial run to accumulate certain execution stats. Once the trial
// run is complete, this pointer is reset to nullptr.
diff --git a/src/mongo/db/pipeline/pipeline_d.cpp b/src/mongo/db/pipeline/pipeline_d.cpp
index c45ec2aab16..bb9fe54d46b 100644
--- a/src/mongo/db/pipeline/pipeline_d.cpp
+++ b/src/mongo/db/pipeline/pipeline_d.cpp
@@ -132,7 +132,7 @@ std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleGr
if (!feature_flags::gFeatureFlagSBEGroupPushdown.isEnabled(
serverGlobalParams.featureCompatibility) ||
- cq->getForceClassicEngine() || expCtx->allowDiskUse || queryNeedsSubplanning) {
+ cq->getForceClassicEngine() || queryNeedsSubplanning) {
return {};
}
diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp
index 64296a23db0..575f054d62c 100644
--- a/src/mongo/db/query/sbe_stage_builder.cpp
+++ b/src/mongo/db/query/sbe_stage_builder.cpp
@@ -657,7 +657,8 @@ SlotBasedStageBuilder::SlotBasedStageBuilder(OperationContext* opCtx,
&_slotIdGenerator,
&_frameIdGenerator,
&_spoolIdGenerator,
- _cq.getExpCtx()->needsMerge) {
+ _cq.getExpCtx()->needsMerge,
+ _cq.getExpCtx()->allowDiskUse) {
// SERVER-52803: In the future if we need to gather more information from the QuerySolutionNode
// tree, rather than doing one-off scans for each piece of information, we should add a formal
// analysis pass here.
@@ -2432,6 +2433,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
{groupBySlot},
std::move(accSlotToExprMap),
_state.env->getSlotIfExists("collator"_sd),
+ _cq.getExpCtx()->allowDiskUse,
nodeId);
tassert(
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index a9560298914..2675c2a25a1 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -1193,6 +1193,7 @@ public:
sbe::makeSV(),
sbe::makeEM(finalGroupSlot, std::move(finalAddToArrayExpr)),
collatorSlot,
+ _context->state.allowDiskUse,
_context->planNodeId);
// Returns true if any of our input expressions return null.
diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp
index 700454de8de..531e0d66bd2 100644
--- a/src/mongo/db/query/sbe_stage_builder_helpers.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp
@@ -455,6 +455,7 @@ EvalStage makeHashAgg(EvalStage stage,
sbe::value::SlotVector gbs,
sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> aggs,
boost::optional<sbe::value::SlotId> collatorSlot,
+ bool allowDiskUse,
PlanNodeId planNodeId) {
stage.outSlots = gbs;
for (auto& [slot, _] : aggs) {
@@ -466,6 +467,7 @@ EvalStage makeHashAgg(EvalStage stage,
sbe::makeSV(),
true /* optimized close */,
collatorSlot,
+ allowDiskUse,
planNodeId);
return stage;
}
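
With the widened signature, callers thread the user's allowDiskUse choice down to the stage. A hypothetical call site (the slots, the EvalStage 'input', and 'state'/'nodeId' are assumed to exist in the caller; the collator lookup mirrors the sbe_stage_builder.cpp hunk above):

    EvalStage grouped = makeHashAgg(
        std::move(input),
        sbe::makeSV(groupBySlot),
        sbe::makeEM(sumSlot,
                    stage_builder::makeFunction("sum", sbe::makeE<sbe::EVariable>(valueSlot))),
        state.env->getSlotIfExists("collator"_sd),
        state.allowDiskUse,
        nodeId);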
diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h
index 68441264e0f..78ab7724eb8 100644
--- a/src/mongo/db/query/sbe_stage_builder_helpers.h
+++ b/src/mongo/db/query/sbe_stage_builder_helpers.h
@@ -417,6 +417,7 @@ EvalStage makeHashAgg(EvalStage stage,
sbe::value::SlotVector gbs,
sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> aggs,
boost::optional<sbe::value::SlotId> collatorSlot,
+ bool allowDiskUse,
PlanNodeId planNodeId);
EvalStage makeMkBsonObj(EvalStage stage,
@@ -875,14 +876,16 @@ struct StageBuilderState {
sbe::value::SlotIdGenerator* slotIdGenerator,
sbe::value::FrameIdGenerator* frameIdGenerator,
sbe::value::SpoolIdGenerator* spoolIdGenerator,
- bool needsMerge)
+ bool needsMerge,
+ bool allowDiskUse)
: slotIdGenerator{slotIdGenerator},
frameIdGenerator{frameIdGenerator},
spoolIdGenerator{spoolIdGenerator},
opCtx{opCtx},
env{env},
variables{variables},
- needsMerge{needsMerge} {}
+ needsMerge{needsMerge},
+ allowDiskUse{allowDiskUse} {}
StageBuilderState(const StageBuilderState& other) = delete;
@@ -912,6 +915,10 @@ struct StageBuilderState {
// When the mongos splits the $group stage and sends it to shards, it sets the 'needsMerge'/
// 'fromMongos' flags to true so that the shards can send special partial aggregation results
// to the mongos.
bool needsMerge;
+
+ // A flag indicating whether the user has allowed disk use for spilling.
+ bool allowDiskUse;
+
// This map is used to plumb through pre-generated field expressions ('sbe::EExpression')
// corresponding to field paths to 'generateExpression' to avoid repeated expression generation.
// Key is expected to represent field paths in form CURRENT.<field_name>[.<field_name>]*.