diff options
author | David Storch <david.storch@mongodb.com> | 2022-11-28 21:39:33 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2022-11-28 22:33:20 +0000 |
commit | e2d6db6b489a9f4514c7d66462d6f3b0d1835bd4 (patch) | |
tree | 9398ce4f6c2ad44411caa8c5b329c67e5b1625fb | |
parent | 24d547dfe6adf659165ced87b135d7245d085c03 (diff) | |
download | mongo-e2d6db6b489a9f4514c7d66462d6f3b0d1835bd4.tar.gz |
SERVER-70395 Make stage builders generate partial agg combining exprs (part 1)
This patch handles $min, $max, $first, and $last. The
remaining accumulators will be implemented as follow-up
work.
-rw-r--r-- | src/mongo/db/exec/sbe/expression_test_base.h | 18 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder.cpp | 4 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_accumulator.cpp | 162 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_accumulator.h | 19 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp | 227 |
5 files changed, 390 insertions, 40 deletions
diff --git a/src/mongo/db/exec/sbe/expression_test_base.h b/src/mongo/db/exec/sbe/expression_test_base.h index 2cc0f8ded6d..e3565016243 100644 --- a/src/mongo/db/exec/sbe/expression_test_base.h +++ b/src/mongo/db/exec/sbe/expression_test_base.h @@ -73,6 +73,24 @@ protected: } /** + * Compiles 'expr' to bytecode when 'expr' is computing an aggregate. The current aggregate + * value can be read out of the provided 'aggAccessor'. + * + * Note that when actually executing the resulting bytecode, the caller is responsible for + * setting the value of 'aggAccessor' to the new resulting aggregate value. + */ + std::unique_ptr<vm::CodeFragment> compileAggExpression(const EExpression& expr, + value::SlotAccessor* aggAccessor) { + ON_BLOCK_EXIT([this] { + _ctx.aggExpression = false; + _ctx.accumulator = nullptr; + }); + _ctx.aggExpression = true; + _ctx.accumulator = aggAccessor; + return expr.compile(_ctx); + } + + /** * The caller takes ownership of the Value returned by this function and must call * 'releaseValue()' on it. The preferred way to ensure the Value is properly released is to * immediately store it in a ValueGuard. diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp index a3d3bf854e8..c0da82f8fe9 100644 --- a/src/mongo/db/query/sbe_stage_builder.cpp +++ b/src/mongo/db/query/sbe_stage_builder.cpp @@ -2407,7 +2407,9 @@ std::tuple<sbe::value::SlotVector, EvalStage> generateAccumulator( // One accumulator may be translated to multiple accumulator expressions. For example, The // $avg will have two accumulators expressions, a sum(..) and a count which is implemented // as sum(1). 
- auto accExprs = stage_builder::buildAccumulator(state, accStmt, std::move(argExpr)); + auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd); + auto accExprs = stage_builder::buildAccumulator( + accStmt, std::move(argExpr), collatorSlot, *state.frameIdGenerator); sbe::value::SlotVector aggSlots; for (auto& accExpr : accExprs) { diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator.cpp b/src/mongo/db/query/sbe_stage_builder_accumulator.cpp index 08f8cfee02c..07a863a2988 100644 --- a/src/mongo/db/query/sbe_stage_builder_accumulator.cpp +++ b/src/mongo/db/query/sbe_stage_builder_accumulator.cpp @@ -42,9 +42,9 @@ namespace mongo::stage_builder { namespace { -std::unique_ptr<sbe::EExpression> wrapMinMaxArg(StageBuilderState& state, - std::unique_ptr<sbe::EExpression> arg) { - return makeLocalBind(state.frameIdGenerator, +std::unique_ptr<sbe::EExpression> wrapMinMaxArg(std::unique_ptr<sbe::EExpression> arg, + sbe::value::FrameIdGenerator& frameIdGenerator) { + return makeLocalBind(&frameIdGenerator, [](sbe::EVariable input) { return sbe::makeE<sbe::EIf>( generateNullOrMissing(input), @@ -55,21 +55,33 @@ std::unique_ptr<sbe::EExpression> wrapMinMaxArg(StageBuilderState& state, } std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorMin( - StageBuilderState& state, const AccumulationExpression& expr, - std::unique_ptr<sbe::EExpression> arg) { + std::unique_ptr<sbe::EExpression> arg, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { std::vector<std::unique_ptr<sbe::EExpression>> aggs; - auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd); if (collatorSlot) { aggs.push_back(makeFunction("collMin"_sd, sbe::makeE<sbe::EVariable>(*collatorSlot), - wrapMinMaxArg(state, std::move(arg)))); + wrapMinMaxArg(std::move(arg), frameIdGenerator))); } else { - aggs.push_back(makeFunction("min"_sd, wrapMinMaxArg(state, std::move(arg)))); + aggs.push_back(makeFunction("min"_sd, 
wrapMinMaxArg(std::move(arg), frameIdGenerator))); } return aggs; } +std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggsMin( + const AccumulationExpression& expr, + const sbe::value::SlotVector& inputSlots, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { + tassert(7039501, + "partial agg combiner for $min should have exactly one input slot", + inputSlots.size() == 1); + auto arg = makeVariable(inputSlots[0]); + return buildAccumulatorMin(expr, std::move(arg), collatorSlot, frameIdGenerator); +} + std::unique_ptr<sbe::EExpression> buildFinalizeMin(StageBuilderState& state, const AccumulationExpression& expr, const sbe::value::SlotVector& minSlots) { @@ -84,21 +96,33 @@ std::unique_ptr<sbe::EExpression> buildFinalizeMin(StageBuilderState& state, } std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorMax( - StageBuilderState& state, const AccumulationExpression& expr, - std::unique_ptr<sbe::EExpression> arg) { + std::unique_ptr<sbe::EExpression> arg, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { std::vector<std::unique_ptr<sbe::EExpression>> aggs; - auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd); if (collatorSlot) { aggs.push_back(makeFunction("collMax"_sd, sbe::makeE<sbe::EVariable>(*collatorSlot), - wrapMinMaxArg(state, std::move(arg)))); + wrapMinMaxArg(std::move(arg), frameIdGenerator))); } else { - aggs.push_back(makeFunction("max"_sd, wrapMinMaxArg(state, std::move(arg)))); + aggs.push_back(makeFunction("max"_sd, wrapMinMaxArg(std::move(arg), frameIdGenerator))); } return aggs; } +std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggsMax( + const AccumulationExpression& expr, + const sbe::value::SlotVector& inputSlots, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { + tassert(7039502, + "partial agg combiner for $max should have 
exactly one input slot", + inputSlots.size() == 1); + auto arg = makeVariable(inputSlots[0]); + return buildAccumulatorMax(expr, std::move(arg), collatorSlot, frameIdGenerator); +} + std::unique_ptr<sbe::EExpression> buildFinalizeMax(StageBuilderState& state, const AccumulationExpression& expr, const sbe::value::SlotVector& maxSlots) { @@ -111,34 +135,61 @@ std::unique_ptr<sbe::EExpression> buildFinalizeMax(StageBuilderState& state, std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorFirst( - StageBuilderState& state, const AccumulationExpression& expr, - std::unique_ptr<sbe::EExpression> arg) { + std::unique_ptr<sbe::EExpression> arg, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { std::vector<std::unique_ptr<sbe::EExpression>> aggs; aggs.push_back(makeFunction("first", makeFillEmptyNull(std::move(arg)))); return aggs; } +std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggsFirst( + const AccumulationExpression& expr, + const sbe::value::SlotVector& inputSlots, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { + tassert(7039503, + "partial agg combiner for $first should have exactly one input slot", + inputSlots.size() == 1); + auto arg = makeVariable(inputSlots[0]); + return buildAccumulatorFirst(expr, std::move(arg), collatorSlot, frameIdGenerator); +} + std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorLast( - StageBuilderState& state, const AccumulationExpression& expr, - std::unique_ptr<sbe::EExpression> arg) { + std::unique_ptr<sbe::EExpression> arg, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { std::vector<std::unique_ptr<sbe::EExpression>> aggs; aggs.push_back(makeFunction("last", makeFillEmptyNull(std::move(arg)))); return aggs; } +std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggsLast( + const AccumulationExpression& expr, + 
const sbe::value::SlotVector& inputSlots, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { + tassert(7039504, + "partial agg combiner for $last should have exactly one input slot", + inputSlots.size() == 1); + auto arg = makeVariable(inputSlots[0]); + return buildAccumulatorLast(expr, std::move(arg), collatorSlot, frameIdGenerator); +} + std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorAvg( - StageBuilderState& state, const AccumulationExpression& expr, - std::unique_ptr<sbe::EExpression> arg) { + std::unique_ptr<sbe::EExpression> arg, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { std::vector<std::unique_ptr<sbe::EExpression>> aggs; // 'aggDoubleDoubleSum' will ignore non-numeric values automatically. aggs.push_back(makeFunction("aggDoubleDoubleSum", arg->clone())); // For the counter we need to skip non-numeric values ourselves. - auto addend = makeLocalBind(state.frameIdGenerator, + auto addend = makeLocalBind(&frameIdGenerator, [](sbe::EVariable input) { return sbe::makeE<sbe::EIf>( makeBinaryOp(sbe::EPrimBinary::logicOr, @@ -225,9 +276,10 @@ getCountAddend(const AccumulationExpression& expr) { } // namespace std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorSum( - StageBuilderState& state, const AccumulationExpression& expr, - std::unique_ptr<sbe::EExpression> arg) { + std::unique_ptr<sbe::EExpression> arg, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { std::vector<std::unique_ptr<sbe::EExpression>> aggs; // Optimize for a count-like accumulator like {$sum: 1}. 
@@ -274,12 +326,12 @@ std::unique_ptr<sbe::EExpression> buildFinalizeSum(StageBuilderState& state, } std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorAddToSet( - StageBuilderState& state, const AccumulationExpression& expr, - std::unique_ptr<sbe::EExpression> arg) { + std::unique_ptr<sbe::EExpression> arg, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { std::vector<std::unique_ptr<sbe::EExpression>> aggs; const int cap = internalQueryMaxAddToSetBytes.load(); - auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd); if (collatorSlot) { aggs.push_back(makeFunction( "collAddToSetCapped"_sd, @@ -317,9 +369,10 @@ std::unique_ptr<sbe::EExpression> buildFinalizeCappedAccumulator( } std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorPush( - StageBuilderState& state, const AccumulationExpression& expr, - std::unique_ptr<sbe::EExpression> arg) { + std::unique_ptr<sbe::EExpression> arg, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { const int cap = internalQueryMaxPushBytes.load(); std::vector<std::unique_ptr<sbe::EExpression>> aggs; aggs.push_back(makeFunction( @@ -330,9 +383,10 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorPush( } std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorStdDev( - StageBuilderState& state, const AccumulationExpression& expr, - std::unique_ptr<sbe::EExpression> arg) { + std::unique_ptr<sbe::EExpression> arg, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { std::vector<std::unique_ptr<sbe::EExpression>> aggs; aggs.push_back(makeFunction("aggStdDev", std::move(arg))); return aggs; @@ -400,13 +454,14 @@ std::unique_ptr<sbe::EExpression> buildFinalizeStdDevSamp( } std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorMergeObjects( - StageBuilderState& state, const AccumulationExpression& expr, - 
std::unique_ptr<sbe::EExpression> arg) { + std::unique_ptr<sbe::EExpression> arg, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { std::vector<std::unique_ptr<sbe::EExpression>> aggs; auto filterExpr = - makeLocalBind(state.frameIdGenerator, + makeLocalBind(&frameIdGenerator, [](sbe::EVariable input) { auto typeCheckExpr = makeBinaryOp(sbe::EPrimBinary::logicOr, @@ -438,11 +493,15 @@ std::pair<std::unique_ptr<sbe::EExpression>, EvalStage> buildArgument( } std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator( - StageBuilderState& state, const AccumulationStatement& acc, - std::unique_ptr<sbe::EExpression> inputExpr) { + std::unique_ptr<sbe::EExpression> argExpr, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { using BuildAccumulatorFn = std::function<std::vector<std::unique_ptr<sbe::EExpression>>( - StageBuilderState&, const AccumulationExpression&, std::unique_ptr<sbe::EExpression>)>; + const AccumulationExpression&, + std::unique_ptr<sbe::EExpression>, + boost::optional<sbe::value::SlotId>, + sbe::value::FrameIdGenerator&)>; static const StringDataMap<BuildAccumulatorFn> kAccumulatorBuilders = { {AccumulatorMin::kName, &buildAccumulatorMin}, @@ -463,7 +522,38 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator( str::stream() << "Unsupported Accumulator in SBE accumulator builder: " << accExprName, kAccumulatorBuilders.find(accExprName) != kAccumulatorBuilders.end()); - return std::invoke(kAccumulatorBuilders.at(accExprName), state, acc.expr, std::move(inputExpr)); + return std::invoke(kAccumulatorBuilders.at(accExprName), + acc.expr, + std::move(argExpr), + collatorSlot, + frameIdGenerator); +} + +std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates( + const AccumulationStatement& acc, + const sbe::value::SlotVector& inputSlots, + boost::optional<sbe::value::SlotId> collatorSlot, + 
sbe::value::FrameIdGenerator& frameIdGenerator) { + using BuildAggCombinerFn = std::function<std::vector<std::unique_ptr<sbe::EExpression>>( + const AccumulationExpression&, + const sbe::value::SlotVector&, + boost::optional<sbe::value::SlotId>, + sbe::value::FrameIdGenerator&)>; + + static const StringDataMap<BuildAggCombinerFn> kAggCombinerBuilders = { + {AccumulatorFirst::kName, &buildCombinePartialAggsFirst}, + {AccumulatorLast::kName, &buildCombinePartialAggsLast}, + {AccumulatorMax::kName, &buildCombinePartialAggsMax}, + {AccumulatorMin::kName, &buildCombinePartialAggsMin}, + }; + + auto accExprName = acc.expr.name; + uassert(7039500, + str::stream() << "Unsupported Accumulator in SBE accumulator builder: " << accExprName, + kAggCombinerBuilders.find(accExprName) != kAggCombinerBuilders.end()); + + return std::invoke( + kAggCombinerBuilders.at(accExprName), acc.expr, inputSlots, collatorSlot, frameIdGenerator); } std::unique_ptr<sbe::EExpression> buildFinalize(StageBuilderState& state, diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator.h b/src/mongo/db/query/sbe_stage_builder_accumulator.h index 6d8b1a5b112..cd201508823 100644 --- a/src/mongo/db/query/sbe_stage_builder_accumulator.h +++ b/src/mongo/db/query/sbe_stage_builder_accumulator.h @@ -54,12 +54,25 @@ std::pair<std::unique_ptr<sbe::EExpression>, EvalStage> buildArgument( /** * Translates an input AccumulationStatement into an SBE EExpression for accumulation expressions. - * The 'stage' parameter provides the input subtree to build on top of. */ std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator( - StageBuilderState& state, const AccumulationStatement& acc, - std::unique_ptr<sbe::EExpression> argExpr); + std::unique_ptr<sbe::EExpression> argExpr, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator&); + +/** + * When SBE hash aggregation spills to disk, it spills partial aggregates which need to be combined + * later. 
This function returns the expressions that can be used to combine partial aggregates for + * the given accumulator 'acc'. The aggregate-of-aggregates will be stored in slots owned by the + * hash agg stage, while the new partial aggregates to combine can be read from the given + * 'inputSlots'. + */ +std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates( + const AccumulationStatement& acc, + const sbe::value::SlotVector& inputSlots, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator&); /** * Translates an input AccumulationStatement into an SBE EExpression that represents an diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp b/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp index 395840ccbbd..ff4e958984c 100644 --- a/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp +++ b/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp @@ -31,10 +31,12 @@ #include <fmt/printf.h> +#include "mongo/db/exec/sbe/expression_test_base.h" #include "mongo/db/pipeline/document_source_group.h" #include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/query/collation/collator_interface_mock.h" #include "mongo/db/query/query_solution.h" +#include "mongo/db/query/sbe_stage_builder_accumulator.h" #include "mongo/db/query/sbe_stage_builder_test_fixture.h" #include "mongo/unittest/unittest.h" @@ -1677,4 +1679,229 @@ TEST_F(SbeStageBuilderGroupTest, SbeIncompatibleExpressionInGroup) { } } +/** + * A test fixture designed to test that the expressions generated to combine partial aggregates + * that have been spilled to disk work correctly. We use 'EExpressionTestFixture' rather than + * something like 'SbeStageBuilderTestFixture' so that the expressions can be tested in isolation, + * without actually requiring a hash agg stage or without actually spilling any data to disk. 
+ */ +class SbeStageBuilderGroupAggCombinerTest : public sbe::EExpressionTestFixture { +public: + explicit SbeStageBuilderGroupAggCombinerTest() + : _expCtx{make_intrusive<ExpressionContextForTest>()}, + _inputSlotId{bindAccessor(&_inputAccessor)}, + _collatorSlotId{bindAccessor(&_collatorAccessor)} {} + + AccumulationStatement makeAccumulationStatement(StringData accumName) { + _accumulationStmtBson = BSON("unused" << BSON(accumName << "unused")); + VariablesParseState vps = _expCtx->variablesParseState; + return AccumulationStatement::parseAccumulationStatement( + _expCtx.get(), _accumulationStmtBson.firstElement(), vps); + } + + /** + * Verifies that executing the bytecode ('code') for combining partial aggregates for $group + * spilling produces the 'expected' outputs given 'inputs'. + * + * The inputs and expected outputs are expressed as BSON arrays as a convenience to the caller, + * and should have the same length. The bytecode is executed over each element of 'inputs' + * one-by-one, with the result stored into a slot holding the aggregate value. At each step, + * this function asserts that the current aggregate value is equal to the matching element in + * 'expected'. + * + * The string "MISSING" can be used as a sentinel in either 'inputs' or 'expected' in order to + * represent the Nothing value (since nothingness cannot literally be stored in a BSON array). + */ + void aggregateAndAssertResults(BSONArray inputs, + BSONArray expected, + const sbe::vm::CodeFragment* code) { + // Make sure we are starting from a clean state. 
+ _inputAccessor.reset(); + _aggAccessor.reset(); + + auto [inputTag, inputVal] = makeArray(inputs); + sbe::value::ValueGuard inputGuard{inputTag, inputVal}; + auto [expectedTag, expectedVal] = makeArray(expected); + sbe::value::ValueGuard expectedGuard{expectedTag, expectedVal}; + + sbe::value::ArrayEnumerator inputEnumerator{inputTag, inputVal}; + sbe::value::ArrayEnumerator expectedEnumerator{expectedTag, expectedVal}; + + // Aggregate the inputs one-by-one, and at each step validate that the resulting accumulator + // state is as expected. + while (!inputEnumerator.atEnd()) { + ASSERT_FALSE(expectedEnumerator.atEnd()); + auto [nextInputTag, nextInputVal] = inputEnumerator.getViewOfValue(); + + // Feed in the input value, treating "MISSING" as a special sentinel to indicate the + // Nothing value. + if (sbe::value::isString(nextInputTag) && + sbe::value::getStringView(nextInputTag, nextInputVal) == "MISSING"_sd) { + _inputAccessor.reset(); + } else { + auto [copyTag, copyVal] = sbe::value::copyValue(nextInputTag, nextInputVal); + _inputAccessor.reset(true, copyTag, copyVal); + } + + auto [outputTag, outputVal] = runCompiledExpression(code); + + // Validate that the output value equals the expected value, and then put the output + // value into the slot that holds the accumulation state. 
+ auto [expectedOutputTag, expectedOutputValue] = expectedEnumerator.getViewOfValue(); + if (sbe::value::isString(expectedOutputTag) && + sbe::value::getStringView(expectedOutputTag, expectedOutputValue) == "MISSING"_sd) { + expectedOutputTag = sbe::value::TypeTags::Nothing; + expectedOutputValue = 0; + } + auto [compareTag, compareValue] = sbe::value::compareValue( + outputTag, outputVal, expectedOutputTag, expectedOutputValue); + ASSERT_EQ(compareTag, sbe::value::TypeTags::NumberInt32); + ASSERT_EQ(compareValue, 0); + _aggAccessor.reset(true, outputTag, outputVal); + + inputEnumerator.advance(); + expectedEnumerator.advance(); + } + } + + /** + * Convenience method for producing bytecode which combines partial aggregates for the given + * 'AccumulationStatement'. + * + * Requires that accumulation statement results in a single aggregate with one input and one + * output. Furthermore, cannot be used when the test case involves a non-simple collation. + */ + std::unique_ptr<sbe::vm::CodeFragment> compileSingleInputNoCollator( + const AccumulationStatement& accStatement) { + auto exprs = stage_builder::buildCombinePartialAggregates( + accStatement, {_inputSlotId}, boost::none, _frameIdGenerator); + ASSERT_EQ(exprs.size(), 1u); + _expr = std::move(exprs[0]); + + return compileAggExpression(*_expr, &_aggAccessor); + } + +protected: + sbe::value::FrameIdGenerator _frameIdGenerator; + boost::intrusive_ptr<ExpressionContextForTest> _expCtx; + + // Accessor and corresponding slot id that holds the input to the agg expression. Each time we + // "turn the crank" this will hold the next partial aggregate to be aggregated into + // '_aggAccessor'. + sbe::value::OwnedValueAccessor _inputAccessor; + sbe::value::SlotId _inputSlotId; + + // The accessor which holds the final output resulting from combining all partial outputs. We + // check that the intermediate value is as expected after every turn of the crank. 
+ sbe::value::OwnedValueAccessor _aggAccessor; + + sbe::value::OwnedValueAccessor _collatorAccessor; + sbe::value::SlotId _collatorSlotId; + +private: + BSONObj _accumulationStmtBson; + std::unique_ptr<sbe::EExpression> _expr; +}; + +TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsMin) { + auto accStatement = makeAccumulationStatement("$min"_sd); + auto compiledExpr = compileSingleInputNoCollator(accStatement); + + auto inputValues = BSON_ARRAY(8 << 7 << 9 << BSONNULL << 6); + auto expectedAggStates = BSON_ARRAY(8 << 7 << 7 << 7 << 6); + aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get()); + + // Test that Nothing values are treated as expected. + inputValues = BSON_ARRAY("MISSING" << 9 << 7 << "MISSING" << 6); + expectedAggStates = BSON_ARRAY("MISSING" << 9 << 7 << 7 << 6); + aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get()); +} + +TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsMinWithCollation) { + auto accStatement = makeAccumulationStatement("$min"_sd); + + auto exprs = stage_builder::buildCombinePartialAggregates( + accStatement, {_inputSlotId}, {_collatorSlotId}, _frameIdGenerator); + ASSERT_EQ(exprs.size(), 1u); + auto expr = std::move(exprs[0]); + + CollatorInterfaceMock collator{CollatorInterfaceMock::MockType::kReverseString}; + _collatorAccessor.reset(false, + sbe::value::TypeTags::collator, + sbe::value::bitcastFrom<const CollatorInterface*>(&collator)); + + auto compiledExpr = compileAggExpression(*expr, &_aggAccessor); + + // The strings in reverse have the opposite ordering as compared to forwards. 
+ auto inputValues = BSON_ARRAY("az" + << "by" + << "cx"); + auto expectedAggStates = BSON_ARRAY("az" + << "by" + << "cx"); + aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get()); +} + +TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsMax) { + auto accStatement = makeAccumulationStatement("$max"_sd); + auto compiledExpr = compileSingleInputNoCollator(accStatement); + + auto inputValues = BSON_ARRAY(3 << 1 << 4 << BSONNULL << 8); + auto expectedAggStates = BSON_ARRAY(3 << 3 << 4 << 4 << 8); + aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get()); + + // Test that Nothing values are treated as expected. + inputValues = BSON_ARRAY("MISSING" << 7 << 9 << "MISSING" << 10); + expectedAggStates = BSON_ARRAY("MISSING" << 7 << 9 << 9 << 10); + aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get()); +} + +TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsMaxWithCollation) { + auto accStatement = makeAccumulationStatement("$max"_sd); + + auto exprs = stage_builder::buildCombinePartialAggregates( + accStatement, {_inputSlotId}, {_collatorSlotId}, _frameIdGenerator); + ASSERT_EQ(exprs.size(), 1u); + auto expr = std::move(exprs[0]); + + CollatorInterfaceMock collator{CollatorInterfaceMock::MockType::kReverseString}; + _collatorAccessor.reset(false, + sbe::value::TypeTags::collator, + sbe::value::bitcastFrom<const CollatorInterface*>(&collator)); + + auto compiledExpr = compileAggExpression(*expr, &_aggAccessor); + + // The strings in reverse have the opposite ordering as compared to forwards. 
+ auto inputValues = BSON_ARRAY("cx" + << "by" + << "az"); + auto expectedAggStates = BSON_ARRAY("cx" + << "by" + << "az"); + aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get()); +} + +TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsFirst) { + auto accStatement = makeAccumulationStatement("$first"_sd); + auto compiledExpr = compileSingleInputNoCollator(accStatement); + + auto inputValues = BSON_ARRAY(3 << 1 << BSONNULL << "MISSING" << 8); + auto expectedAggStates = BSON_ARRAY(3 << 3 << 3 << 3 << 3); + aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get()); + + // When the first value is missing, the resulting value is a literal null. + inputValues = BSON_ARRAY("MISSING" << 1 << BSONNULL << "MISSING" << 8); + expectedAggStates = BSON_ARRAY(BSONNULL << BSONNULL << BSONNULL << BSONNULL << BSONNULL); + aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get()); +} + +TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsLast) { + auto accStatement = makeAccumulationStatement("$last"_sd); + auto compiledExpr = compileSingleInputNoCollator(accStatement); + + auto inputValues = BSON_ARRAY(3 << 1 << BSONNULL << "MISSING" << 8); + auto expectedAggStates = BSON_ARRAY(3 << 1 << BSONNULL << BSONNULL << 8); + aggregateAndAssertResults(inputValues, expectedAggStates, compiledExpr.get()); +} + } // namespace mongo |