summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Storch <david.storch@mongodb.com>2022-11-28 21:39:33 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2022-11-28 22:33:20 +0000
commite2d6db6b489a9f4514c7d66462d6f3b0d1835bd4 (patch)
tree9398ce4f6c2ad44411caa8c5b329c67e5b1625fb
parent24d547dfe6adf659165ced87b135d7245d085c03 (diff)
downloadmongo-e2d6db6b489a9f4514c7d66462d6f3b0d1835bd4.tar.gz
SERVER-70395 Make stage builders generate partial agg combining exprs (part 1)
This patch handles $min, $max, $first, and $last. The remaining accumulators will be implemented as follow-up work.
-rw-r--r--src/mongo/db/exec/sbe/expression_test_base.h18
-rw-r--r--src/mongo/db/query/sbe_stage_builder.cpp4
-rw-r--r--src/mongo/db/query/sbe_stage_builder_accumulator.cpp162
-rw-r--r--src/mongo/db/query/sbe_stage_builder_accumulator.h19
-rw-r--r--src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp227
5 files changed, 390 insertions, 40 deletions
diff --git a/src/mongo/db/exec/sbe/expression_test_base.h b/src/mongo/db/exec/sbe/expression_test_base.h
index 2cc0f8ded6d..e3565016243 100644
--- a/src/mongo/db/exec/sbe/expression_test_base.h
+++ b/src/mongo/db/exec/sbe/expression_test_base.h
@@ -73,6 +73,24 @@ protected:
}
/**
+ * Compiles 'expr' to bytecode when 'expr' is computing an aggregate. The current aggregate
+ * value can be read out of the provided 'aggAccessor'.
+ *
+ * Note that when actually executing the resulting bytecode, the caller is responsible for
+ * setting the value of 'aggAccessor' to the new resulting aggregate value.
+ */
+ std::unique_ptr<vm::CodeFragment> compileAggExpression(const EExpression& expr,
+ value::SlotAccessor* aggAccessor) {
+ ON_BLOCK_EXIT([this] {
+ _ctx.aggExpression = false;
+ _ctx.accumulator = nullptr;
+ });
+ _ctx.aggExpression = true;
+ _ctx.accumulator = aggAccessor;
+ return expr.compile(_ctx);
+ }
+
+ /**
* The caller takes ownership of the Value returned by this function and must call
* 'releaseValue()' on it. The preferred way to ensure the Value is properly released is to
* immediately store it in a ValueGuard.
diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp
index a3d3bf854e8..c0da82f8fe9 100644
--- a/src/mongo/db/query/sbe_stage_builder.cpp
+++ b/src/mongo/db/query/sbe_stage_builder.cpp
@@ -2407,7 +2407,9 @@ std::tuple<sbe::value::SlotVector, EvalStage> generateAccumulator(
// One accumulator may be translated to multiple accumulator expressions. For example, The
// $avg will have two accumulators expressions, a sum(..) and a count which is implemented
// as sum(1).
- auto accExprs = stage_builder::buildAccumulator(state, accStmt, std::move(argExpr));
+ auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd);
+ auto accExprs = stage_builder::buildAccumulator(
+ accStmt, std::move(argExpr), collatorSlot, *state.frameIdGenerator);
sbe::value::SlotVector aggSlots;
for (auto& accExpr : accExprs) {
diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator.cpp b/src/mongo/db/query/sbe_stage_builder_accumulator.cpp
index 08f8cfee02c..07a863a2988 100644
--- a/src/mongo/db/query/sbe_stage_builder_accumulator.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_accumulator.cpp
@@ -42,9 +42,9 @@
namespace mongo::stage_builder {
namespace {
-std::unique_ptr<sbe::EExpression> wrapMinMaxArg(StageBuilderState& state,
- std::unique_ptr<sbe::EExpression> arg) {
- return makeLocalBind(state.frameIdGenerator,
+std::unique_ptr<sbe::EExpression> wrapMinMaxArg(std::unique_ptr<sbe::EExpression> arg,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
+ return makeLocalBind(&frameIdGenerator,
[](sbe::EVariable input) {
return sbe::makeE<sbe::EIf>(
generateNullOrMissing(input),
@@ -55,21 +55,33 @@ std::unique_ptr<sbe::EExpression> wrapMinMaxArg(StageBuilderState& state,
}
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorMin(
- StageBuilderState& state,
const AccumulationExpression& expr,
- std::unique_ptr<sbe::EExpression> arg) {
+ std::unique_ptr<sbe::EExpression> arg,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
std::vector<std::unique_ptr<sbe::EExpression>> aggs;
- auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd);
if (collatorSlot) {
aggs.push_back(makeFunction("collMin"_sd,
sbe::makeE<sbe::EVariable>(*collatorSlot),
- wrapMinMaxArg(state, std::move(arg))));
+ wrapMinMaxArg(std::move(arg), frameIdGenerator)));
} else {
- aggs.push_back(makeFunction("min"_sd, wrapMinMaxArg(state, std::move(arg))));
+ aggs.push_back(makeFunction("min"_sd, wrapMinMaxArg(std::move(arg), frameIdGenerator)));
}
return aggs;
}
+std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggsMin(
+ const AccumulationExpression& expr,
+ const sbe::value::SlotVector& inputSlots,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
+ tassert(7039501,
+ "partial agg combiner for $min should have exactly one input slot",
+ inputSlots.size() == 1);
+ auto arg = makeVariable(inputSlots[0]);
+ return buildAccumulatorMin(expr, std::move(arg), collatorSlot, frameIdGenerator);
+}
+
std::unique_ptr<sbe::EExpression> buildFinalizeMin(StageBuilderState& state,
const AccumulationExpression& expr,
const sbe::value::SlotVector& minSlots) {
@@ -84,21 +96,33 @@ std::unique_ptr<sbe::EExpression> buildFinalizeMin(StageBuilderState& state,
}
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorMax(
- StageBuilderState& state,
const AccumulationExpression& expr,
- std::unique_ptr<sbe::EExpression> arg) {
+ std::unique_ptr<sbe::EExpression> arg,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
std::vector<std::unique_ptr<sbe::EExpression>> aggs;
- auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd);
if (collatorSlot) {
aggs.push_back(makeFunction("collMax"_sd,
sbe::makeE<sbe::EVariable>(*collatorSlot),
- wrapMinMaxArg(state, std::move(arg))));
+ wrapMinMaxArg(std::move(arg), frameIdGenerator)));
} else {
- aggs.push_back(makeFunction("max"_sd, wrapMinMaxArg(state, std::move(arg))));
+ aggs.push_back(makeFunction("max"_sd, wrapMinMaxArg(std::move(arg), frameIdGenerator)));
}
return aggs;
}
+std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggsMax(
+ const AccumulationExpression& expr,
+ const sbe::value::SlotVector& inputSlots,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
+ tassert(7039502,
+ "partial agg combiner for $max should have exactly one input slot",
+ inputSlots.size() == 1);
+ auto arg = makeVariable(inputSlots[0]);
+ return buildAccumulatorMax(expr, std::move(arg), collatorSlot, frameIdGenerator);
+}
+
std::unique_ptr<sbe::EExpression> buildFinalizeMax(StageBuilderState& state,
const AccumulationExpression& expr,
const sbe::value::SlotVector& maxSlots) {
@@ -111,34 +135,61 @@ std::unique_ptr<sbe::EExpression> buildFinalizeMax(StageBuilderState& state,
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorFirst(
- StageBuilderState& state,
const AccumulationExpression& expr,
- std::unique_ptr<sbe::EExpression> arg) {
+ std::unique_ptr<sbe::EExpression> arg,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
std::vector<std::unique_ptr<sbe::EExpression>> aggs;
aggs.push_back(makeFunction("first", makeFillEmptyNull(std::move(arg))));
return aggs;
}
+std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggsFirst(
+ const AccumulationExpression& expr,
+ const sbe::value::SlotVector& inputSlots,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
+ tassert(7039503,
+ "partial agg combiner for $first should have exactly one input slot",
+ inputSlots.size() == 1);
+ auto arg = makeVariable(inputSlots[0]);
+ return buildAccumulatorFirst(expr, std::move(arg), collatorSlot, frameIdGenerator);
+}
+
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorLast(
- StageBuilderState& state,
const AccumulationExpression& expr,
- std::unique_ptr<sbe::EExpression> arg) {
+ std::unique_ptr<sbe::EExpression> arg,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
std::vector<std::unique_ptr<sbe::EExpression>> aggs;
aggs.push_back(makeFunction("last", makeFillEmptyNull(std::move(arg))));
return aggs;
}
+std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggsLast(
+ const AccumulationExpression& expr,
+ const sbe::value::SlotVector& inputSlots,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
+ tassert(7039504,
+ "partial agg combiner for $last should have exactly one input slot",
+ inputSlots.size() == 1);
+ auto arg = makeVariable(inputSlots[0]);
+ return buildAccumulatorLast(expr, std::move(arg), collatorSlot, frameIdGenerator);
+}
+
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorAvg(
- StageBuilderState& state,
const AccumulationExpression& expr,
- std::unique_ptr<sbe::EExpression> arg) {
+ std::unique_ptr<sbe::EExpression> arg,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
std::vector<std::unique_ptr<sbe::EExpression>> aggs;
// 'aggDoubleDoubleSum' will ignore non-numeric values automatically.
aggs.push_back(makeFunction("aggDoubleDoubleSum", arg->clone()));
// For the counter we need to skip non-numeric values ourselves.
- auto addend = makeLocalBind(state.frameIdGenerator,
+ auto addend = makeLocalBind(&frameIdGenerator,
[](sbe::EVariable input) {
return sbe::makeE<sbe::EIf>(
makeBinaryOp(sbe::EPrimBinary::logicOr,
@@ -225,9 +276,10 @@ getCountAddend(const AccumulationExpression& expr) {
} // namespace
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorSum(
- StageBuilderState& state,
const AccumulationExpression& expr,
- std::unique_ptr<sbe::EExpression> arg) {
+ std::unique_ptr<sbe::EExpression> arg,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
std::vector<std::unique_ptr<sbe::EExpression>> aggs;
// Optimize for a count-like accumulator like {$sum: 1}.
@@ -274,12 +326,12 @@ std::unique_ptr<sbe::EExpression> buildFinalizeSum(StageBuilderState& state,
}
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorAddToSet(
- StageBuilderState& state,
const AccumulationExpression& expr,
- std::unique_ptr<sbe::EExpression> arg) {
+ std::unique_ptr<sbe::EExpression> arg,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
std::vector<std::unique_ptr<sbe::EExpression>> aggs;
const int cap = internalQueryMaxAddToSetBytes.load();
- auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd);
if (collatorSlot) {
aggs.push_back(makeFunction(
"collAddToSetCapped"_sd,
@@ -317,9 +369,10 @@ std::unique_ptr<sbe::EExpression> buildFinalizeCappedAccumulator(
}
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorPush(
- StageBuilderState& state,
const AccumulationExpression& expr,
- std::unique_ptr<sbe::EExpression> arg) {
+ std::unique_ptr<sbe::EExpression> arg,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
const int cap = internalQueryMaxPushBytes.load();
std::vector<std::unique_ptr<sbe::EExpression>> aggs;
aggs.push_back(makeFunction(
@@ -330,9 +383,10 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorPush(
}
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorStdDev(
- StageBuilderState& state,
const AccumulationExpression& expr,
- std::unique_ptr<sbe::EExpression> arg) {
+ std::unique_ptr<sbe::EExpression> arg,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
std::vector<std::unique_ptr<sbe::EExpression>> aggs;
aggs.push_back(makeFunction("aggStdDev", std::move(arg)));
return aggs;
@@ -400,13 +454,14 @@ std::unique_ptr<sbe::EExpression> buildFinalizeStdDevSamp(
}
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorMergeObjects(
- StageBuilderState& state,
const AccumulationExpression& expr,
- std::unique_ptr<sbe::EExpression> arg) {
+ std::unique_ptr<sbe::EExpression> arg,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
std::vector<std::unique_ptr<sbe::EExpression>> aggs;
auto filterExpr =
- makeLocalBind(state.frameIdGenerator,
+ makeLocalBind(&frameIdGenerator,
[](sbe::EVariable input) {
auto typeCheckExpr =
makeBinaryOp(sbe::EPrimBinary::logicOr,
@@ -438,11 +493,15 @@ std::pair<std::unique_ptr<sbe::EExpression>, EvalStage> buildArgument(
}
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator(
- StageBuilderState& state,
const AccumulationStatement& acc,
- std::unique_ptr<sbe::EExpression> inputExpr) {
+ std::unique_ptr<sbe::EExpression> argExpr,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
using BuildAccumulatorFn = std::function<std::vector<std::unique_ptr<sbe::EExpression>>(
- StageBuilderState&, const AccumulationExpression&, std::unique_ptr<sbe::EExpression>)>;
+ const AccumulationExpression&,
+ std::unique_ptr<sbe::EExpression>,
+ boost::optional<sbe::value::SlotId>,
+ sbe::value::FrameIdGenerator&)>;
static const StringDataMap<BuildAccumulatorFn> kAccumulatorBuilders = {
{AccumulatorMin::kName, &buildAccumulatorMin},
@@ -463,7 +522,38 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator(
str::stream() << "Unsupported Accumulator in SBE accumulator builder: " << accExprName,
kAccumulatorBuilders.find(accExprName) != kAccumulatorBuilders.end());
- return std::invoke(kAccumulatorBuilders.at(accExprName), state, acc.expr, std::move(inputExpr));
+ return std::invoke(kAccumulatorBuilders.at(accExprName),
+ acc.expr,
+ std::move(argExpr),
+ collatorSlot,
+ frameIdGenerator);
+}
+
+std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates(
+ const AccumulationStatement& acc,
+ const sbe::value::SlotVector& inputSlots,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
+ using BuildAggCombinerFn = std::function<std::vector<std::unique_ptr<sbe::EExpression>>(
+ const AccumulationExpression&,
+ const sbe::value::SlotVector&,
+ boost::optional<sbe::value::SlotId>,
+ sbe::value::FrameIdGenerator&)>;
+
+ static const StringDataMap<BuildAggCombinerFn> kAggCombinerBuilders = {
+ {AccumulatorFirst::kName, &buildCombinePartialAggsFirst},
+ {AccumulatorLast::kName, &buildCombinePartialAggsLast},
+ {AccumulatorMax::kName, &buildCombinePartialAggsMax},
+ {AccumulatorMin::kName, &buildCombinePartialAggsMin},
+ };
+
+ auto accExprName = acc.expr.name;
+ uassert(7039500,
+ str::stream() << "Unsupported Accumulator in SBE accumulator builder: " << accExprName,
+ kAggCombinerBuilders.find(accExprName) != kAggCombinerBuilders.end());
+
+ return std::invoke(
+ kAggCombinerBuilders.at(accExprName), acc.expr, inputSlots, collatorSlot, frameIdGenerator);
}
std::unique_ptr<sbe::EExpression> buildFinalize(StageBuilderState& state,
diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator.h b/src/mongo/db/query/sbe_stage_builder_accumulator.h
index 6d8b1a5b112..cd201508823 100644
--- a/src/mongo/db/query/sbe_stage_builder_accumulator.h
+++ b/src/mongo/db/query/sbe_stage_builder_accumulator.h
@@ -54,12 +54,25 @@ std::pair<std::unique_ptr<sbe::EExpression>, EvalStage> buildArgument(
/**
* Translates an input AccumulationStatement into an SBE EExpression for accumulation expressions.
- * The 'stage' parameter provides the input subtree to build on top of.
*/
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator(
- StageBuilderState& state,
const AccumulationStatement& acc,
- std::unique_ptr<sbe::EExpression> argExpr);
+ std::unique_ptr<sbe::EExpression> argExpr,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator&);
+
+/**
+ * When SBE hash aggregation spills to disk, it spills partial aggregates which need to be combined
+ * later. This function returns the expressions that can be used to combine partial aggregates for
+ * the given accumulator 'acc'. The aggregate-of-aggregates will be stored in a slots owned by the
+ * hash agg stage, while the new partial aggregates to combine can be read from the given
+ * 'inputSlots'.
+ */
+std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates(
+ const AccumulationStatement& acc,
+ const sbe::value::SlotVector& inputSlots,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator&);
/**
* Translates an input AccumulationStatement into an SBE EExpression that represents an
diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp b/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp
index 395840ccbbd..ff4e958984c 100644
--- a/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp
@@ -31,10 +31,12 @@
#include <fmt/printf.h>
+#include "mongo/db/exec/sbe/expression_test_base.h"
#include "mongo/db/pipeline/document_source_group.h"
#include "mongo/db/pipeline/expression_context_for_test.h"
#include "mongo/db/query/collation/collator_interface_mock.h"
#include "mongo/db/query/query_solution.h"
+#include "mongo/db/query/sbe_stage_builder_accumulator.h"
#include "mongo/db/query/sbe_stage_builder_test_fixture.h"
#include "mongo/unittest/unittest.h"
@@ -1677,4 +1679,229 @@ TEST_F(SbeStageBuilderGroupTest, SbeIncompatibleExpressionInGroup) {
}
}
+/**
+ * A test fixture designed to test that the expressions generated to combine partial aggregates
+ * that have been spilled to disk work correctly. We use 'EExpressionTestFixture' rather than
+ * something like 'SbeStageBuilderTestFixture' so that the expressions can be tested in isolation,
+ * without actually requiring a hash agg stage or without actually spilling any data to disk.
+ */
+class SbeStageBuilderGroupAggCombinerTest : public sbe::EExpressionTestFixture {
+public:
+ explicit SbeStageBuilderGroupAggCombinerTest()
+ : _expCtx{make_intrusive<ExpressionContextForTest>()},
+ _inputSlotId{bindAccessor(&_inputAccessor)},
+ _collatorSlotId{bindAccessor(&_collatorAccessor)} {}
+
+ AccumulationStatement makeAccumulationStatement(StringData accumName) {
+ _accumulationStmtBson = BSON("unused" << BSON(accumName << "unused"));
+ VariablesParseState vps = _expCtx->variablesParseState;
+ return AccumulationStatement::parseAccumulationStatement(
+ _expCtx.get(), _accumulationStmtBson.firstElement(), vps);
+ }
+
+ /**
+ * Verifies that executing the bytecode ('code') for combining partial aggregates for $group
+ * spilling produces the 'expected' outputs given 'inputs'.
+ *
+ * The inputs and expected outputs are expressed as BSON arrays as a convenience to the caller,
+ * and should have the same length. The bytecode is executed over each element of 'inputs'
+ * one-by-one, with the result stored into a slot holding the aggregate value. At each step,
+ * this function asserts that the current aggregate value is equal to the matching element in
+ * 'expected'.
+ *
+ * The string "MISSING" can be used as a sentinel in either 'inputs' or 'outputs' in order to
+ * represent the Nothing value (since nothingness cannot literally be stored in a BSON array).
+ */
+ void aggregateAndAssertResults(BSONArray inputs,
+ BSONArray expected,
+ const sbe::vm::CodeFragment* code) {
+ // Make sure we are starting from a clean state.
+ _inputAccessor.reset();
+ _aggAccessor.reset();
+
+ auto [inputTag, inputVal] = makeArray(inputs);
+ sbe::value::ValueGuard inputGuard{inputTag, inputVal};
+ auto [expectedTag, expectedVal] = makeArray(expected);
+ sbe::value::ValueGuard expectedGuard{expectedTag, expectedVal};
+
+ sbe::value::ArrayEnumerator inputEnumerator{inputTag, inputVal};
+ sbe::value::ArrayEnumerator expectedEnumerator{expectedTag, expectedVal};
+
+ // Aggregate the inputs one-by-one, and at each step validate that the resulting accumulator
+ // state is as expected.
+ while (!inputEnumerator.atEnd()) {
+ ASSERT_FALSE(expectedEnumerator.atEnd());
+ auto [nextInputTag, nextInputVal] = inputEnumerator.getViewOfValue();
+
+ // Feed in the input value, treating "MISSING" as a special sentinel to indicate the
+ // Nothing value.
+ if (sbe::value::isString(nextInputTag) &&
+ sbe::value::getStringView(nextInputTag, nextInputVal) == "MISSING"_sd) {
+ _inputAccessor.reset();
+ } else {
+ auto [copyTag, copyVal] = sbe::value::copyValue(nextInputTag, nextInputVal);
+ _inputAccessor.reset(true, copyTag, copyVal);
+ }
+
+ auto [outputTag, outputVal] = runCompiledExpression(code);
+
+ // Validate that the output value equals the expected value, and then put the output
+ // value into the slot that holds the accumulation state.
+ auto [expectedOutputTag, expectedOutputValue] = expectedEnumerator.getViewOfValue();
+ if (sbe::value::isString(expectedOutputTag) &&
+ sbe::value::getStringView(expectedOutputTag, expectedOutputValue) == "MISSING"_sd) {
+ expectedOutputTag = sbe::value::TypeTags::Nothing;
+ expectedOutputValue = 0;
+ }
+ auto [compareTag, compareValue] = sbe::value::compareValue(
+ outputTag, outputVal, expectedOutputTag, expectedOutputValue);
+ ASSERT_EQ(compareTag, sbe::value::TypeTags::NumberInt32);
+ ASSERT_EQ(compareValue, 0);
+ _aggAccessor.reset(true, outputTag, outputVal);
+
+ inputEnumerator.advance();
+ expectedEnumerator.advance();
+ }
+ }
+
+ /**
+ * Convenience method for producing bytecode which combines partial aggregates for the given
+ * 'AccumulationStatement'.
+ *
+ * Requires that accumulation statement results in a single aggregate with one input and one
+ * output. Furthermore, cannot be used when the test case involves a non-simple collation.
+ */
+ std::unique_ptr<sbe::vm::CodeFragment> compileSingleInputNoCollator(
+ const AccumulationStatement& accStatement) {
+ auto exprs = stage_builder::buildCombinePartialAggregates(
+ accStatement, {_inputSlotId}, boost::none, _frameIdGenerator);
+ ASSERT_EQ(exprs.size(), 1u);
+ _expr = std::move(exprs[0]);
+
+ return compileAggExpression(*_expr, &_aggAccessor);
+ }
+
+protected:
+ sbe::value::FrameIdGenerator _frameIdGenerator;
+ boost::intrusive_ptr<ExpressionContextForTest> _expCtx;
+
+ // Accessor and corresponding slot id that holds the input to the agg expression. Each time we
+ // "turn the crank" this will hold the next partial aggregate to be aggregated into
+ // '_aggAccessor'.
+ sbe::value::OwnedValueAccessor _inputAccessor;
+ sbe::value::SlotId _inputSlotId;
+
+ // The accessor which holds the final output resulting from combining all partial outputs. We
+ // check that the intermediate value is as expected after every turn of the crank.
+ sbe::value::OwnedValueAccessor _aggAccessor;
+
+ sbe::value::OwnedValueAccessor _collatorAccessor;
+ sbe::value::SlotId _collatorSlotId;
+
+private:
+ BSONObj _accumulationStmtBson;
+ std::unique_ptr<sbe::EExpression> _expr;
+};
+
+TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsMin) {
+ auto accStatement = makeAccumulationStatement("$min"_sd);
+ auto compiledExpr = compileSingleInputNoCollator(accStatement);
+
+ auto inputValues = BSON_ARRAY(8 << 7 << 9 << BSONNULL << 6);
+ auto expectedAggStates = BSON_ARRAY(8 << 7 << 7 << 7 << 6);
+ aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get());
+
+ // Test that Nothing values are treated as expected.
+ inputValues = BSON_ARRAY("MISSING" << 9 << 7 << "MISSING" << 6);
+ expectedAggStates = BSON_ARRAY("MISSING" << 9 << 7 << 7 << 6);
+ aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get());
+}
+
+TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsMinWithCollation) {
+ auto accStatement = makeAccumulationStatement("$min"_sd);
+
+ auto exprs = stage_builder::buildCombinePartialAggregates(
+ accStatement, {_inputSlotId}, {_collatorSlotId}, _frameIdGenerator);
+ ASSERT_EQ(exprs.size(), 1u);
+ auto expr = std::move(exprs[0]);
+
+ CollatorInterfaceMock collator{CollatorInterfaceMock::MockType::kReverseString};
+ _collatorAccessor.reset(false,
+ sbe::value::TypeTags::collator,
+ sbe::value::bitcastFrom<const CollatorInterface*>(&collator));
+
+ auto compiledExpr = compileAggExpression(*expr, &_aggAccessor);
+
+ // The strings in reverse have the opposite ordering as compared to forwards.
+ auto inputValues = BSON_ARRAY("az"
+ << "by"
+ << "cx");
+ auto expectedAggStates = BSON_ARRAY("az"
+ << "by"
+ << "cx");
+ aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get());
+}
+
+TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsMax) {
+ auto accStatement = makeAccumulationStatement("$max"_sd);
+ auto compiledExpr = compileSingleInputNoCollator(accStatement);
+
+ auto inputValues = BSON_ARRAY(3 << 1 << 4 << BSONNULL << 8);
+ auto expectedAggStates = BSON_ARRAY(3 << 3 << 4 << 4 << 8);
+ aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get());
+
+ // Test that Nothing values are treated as expected.
+ inputValues = BSON_ARRAY("MISSING" << 7 << 9 << "MISSING" << 10);
+ expectedAggStates = BSON_ARRAY("MISSING" << 7 << 9 << 9 << 10);
+ aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get());
+}
+
+TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsMaxWithCollation) {
+ auto accStatement = makeAccumulationStatement("$max"_sd);
+
+ auto exprs = stage_builder::buildCombinePartialAggregates(
+ accStatement, {_inputSlotId}, {_collatorSlotId}, _frameIdGenerator);
+ ASSERT_EQ(exprs.size(), 1u);
+ auto expr = std::move(exprs[0]);
+
+ CollatorInterfaceMock collator{CollatorInterfaceMock::MockType::kReverseString};
+ _collatorAccessor.reset(false,
+ sbe::value::TypeTags::collator,
+ sbe::value::bitcastFrom<const CollatorInterface*>(&collator));
+
+ auto compiledExpr = compileAggExpression(*expr, &_aggAccessor);
+
+ // The strings in reverse have the opposite ordering as compared to forwards.
+ auto inputValues = BSON_ARRAY("cx"
+ << "by"
+ << "az");
+ auto expectedAggStates = BSON_ARRAY("cx"
+ << "by"
+ << "az");
+ aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get());
+}
+
+TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsFirst) {
+ auto accStatement = makeAccumulationStatement("$first"_sd);
+ auto compiledExpr = compileSingleInputNoCollator(accStatement);
+
+ auto inputValues = BSON_ARRAY(3 << 1 << BSONNULL << "MISSING" << 8);
+ auto expectedAggStates = BSON_ARRAY(3 << 3 << 3 << 3 << 3);
+ aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get());
+
+ // When the first value is missing, the resulting value is a literal null.
+ inputValues = BSON_ARRAY("MISSING" << 1 << BSONNULL << "MISSING" << 8);
+ expectedAggStates = BSON_ARRAY(BSONNULL << BSONNULL << BSONNULL << BSONNULL << BSONNULL);
+ aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get());
+}
+
+TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsLast) {
+ auto accStatement = makeAccumulationStatement("$last"_sd);
+ auto compiledExpr = compileSingleInputNoCollator(accStatement);
+
+ auto inputValues = BSON_ARRAY(3 << 1 << BSONNULL << "MISSING" << 8);
+ auto expectedAggStates = BSON_ARRAY(3 << 1 << BSONNULL << BSONNULL << 8);
+ aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get());
+}
+
} // namespace mongo