diff options
-rw-r--r-- | src/mongo/db/exec/sbe/SConscript | 1 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/parser/parser.cpp | 34 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp | 100 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/sbe_hash_join_test.cpp | 137 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/stages/hash_agg.cpp | 42 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/stages/hash_agg.h | 8 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/stages/hash_join.cpp | 34 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/stages/hash_join.h | 13 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/values/slot.h | 19 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder.cpp | 4 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_expression.cpp | 3 |
11 files changed, 363 insertions, 32 deletions
diff --git a/src/mongo/db/exec/sbe/SConscript b/src/mongo/db/exec/sbe/SConscript index 461e9410f4c..4df8209c6f2 100644 --- a/src/mongo/db/exec/sbe/SConscript +++ b/src/mongo/db/exec/sbe/SConscript @@ -144,6 +144,7 @@ env.CppUnitTest( 'parser/sbe_parser_test.cpp', 'sbe_filter_test.cpp', 'sbe_hash_agg_test.cpp', + 'sbe_hash_join_test.cpp', 'sbe_key_string_test.cpp', 'sbe_limit_skip_test.cpp', 'sbe_math_builtins_test.cpp', diff --git a/src/mongo/db/exec/sbe/parser/parser.cpp b/src/mongo/db/exec/sbe/parser/parser.cpp index 19e6385b8ec..19fe41378ef 100644 --- a/src/mongo/db/exec/sbe/parser/parser.cpp +++ b/src/mongo/db/exec/sbe/parser/parser.cpp @@ -138,8 +138,9 @@ static constexpr auto kSyntax = R"( MKOBJ_FLAG # Return old object OPERATOR # child - GROUP <- 'group' IDENT_LIST PROJECT_LIST OPERATOR - HJOIN <- 'hj' LEFT RIGHT + GROUP <- 'group' IDENT_LIST PROJECT_LIST OPERATOR ( EXPR )? # optional collator + HJOIN <- 'hj' ( EXPR )? # optional collator + LEFT RIGHT LEFT <- 'left' IDENT_LIST IDENT_LIST OPERATOR RIGHT <- 'right' IDENT_LIST IDENT_LIST OPERATOR @@ -978,21 +979,38 @@ void Parser::walkMkObj(AstQuery& ast) { void Parser::walkGroup(AstQuery& ast) { walkChildren(ast); + boost::optional<value::SlotId> collatorSlot; + if (ast.nodes.size() == 3) { + collatorSlot = lookupSlot(std::move(ast.nodes[2]->identifier)); + } + ast.stage = makeS<HashAggStage>(std::move(ast.nodes[2]->stage), lookupSlots(std::move(ast.nodes[0]->identifiers)), lookupSlots(std::move(ast.nodes[1]->projects)), + collatorSlot, getCurrentPlanNodeId()); } void Parser::walkHashJoin(AstQuery& ast) { walkChildren(ast); + + boost::optional<value::SlotId> collatorSlot; + auto outerNode = ast.nodes[0]; + auto innerNode = ast.nodes[1]; + if (ast.nodes.size() == 3) { + outerNode = ast.nodes[1]; + innerNode = ast.nodes[2]; + collatorSlot = lookupSlot(ast.nodes[0]->identifier); + } + ast.stage = - makeS<HashJoinStage>(std::move(ast.nodes[0]->nodes[2]->stage), // outer - std::move(ast.nodes[1]->nodes[2]->stage), // inner - lookupSlots(ast.nodes[0]->nodes[0]->identifiers), // outer conditions - lookupSlots(ast.nodes[0]->nodes[1]->identifiers), // outer projections - lookupSlots(ast.nodes[1]->nodes[0]->identifiers), // inner conditions - lookupSlots(ast.nodes[1]->nodes[1]->identifiers), // inner projections + makeS<HashJoinStage>(std::move(outerNode->nodes[2]->stage), // outer + std::move(innerNode->nodes[2]->stage), // inner + lookupSlots(outerNode->nodes[0]->identifiers), // outer conditions + lookupSlots(outerNode->nodes[1]->identifiers), // outer projections + lookupSlots(innerNode->nodes[0]->identifiers), // inner conditions + lookupSlots(innerNode->nodes[1]->identifiers), // inner projections + collatorSlot, // collator getCurrentPlanNodeId()); } diff --git a/src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp b/src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp index 8ca2c4c5398..281a7cf0684 100644 --- a/src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp +++ b/src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp @@ -28,7 +28,7 @@ */ /** - * This file contains tests for sbe::FilterStage. + * This file contains tests for sbe::HashAggStage. */ #include "mongo/platform/basic.h" @@ -82,6 +82,7 @@ TEST_F(HashAggStageTest, HashAggMinMaxTest) { collMaxSlot, stage_builder::makeFunction( "collMax", collExpr->clone(), makeE<EVariable>(scanSlot))), + boost::none, kEmptyPlanNodeId); auto outSlot = generateSlotId(); @@ -135,6 +136,7 @@ TEST_F(HashAggStageTest, HashAggAddToSetTest) { makeEM(hashAggSlot, stage_builder::makeFunction( "collAddToSet", std::move(collExpr), makeE<EVariable>(scanSlot))), + boost::none, kEmptyPlanNodeId); return std::make_pair(hashAggSlot, std::move(hashAggStage)); @@ -183,4 +185,100 @@ TEST_F(HashAggStageTest, HashAggAddToSetTest) { ASSERT_TRUE(resultsEnumerator.atEnd()); } +TEST_F(HashAggStageTest, HashAggCollationTest) { + using namespace std::literals; + for (auto useCollator : {false, true}) { + + BSONArrayBuilder bab1; + bab1.append("A").append("a").append("b").append("c").append("B").append("a"); + auto [inputTag, inputVal] = stage_builder::makeValue(bab1.arr()); + value::ValueGuard inputGuard{inputTag, inputVal}; + + BSONArrayBuilder bab2; + if (useCollator) { + // Collator groups the values as: ["A", "a", "a"], ["B", "b"], ["c"]. + bab2.append(3).append(2).append(1); + } else { + // No Collator groups the values as: ["a", "a"], ["A"], ["B"], ["b"], ["c"]. + bab2.append(2).append(1).append(1).append(1).append(1); + } + auto [expectedTag, expectedVal] = stage_builder::makeValue(bab2.arr()); + value::ValueGuard expectedGuard{expectedTag, expectedVal}; + + auto collatorSlot = generateSlotId(); + + auto makeStageFn = [this, collatorSlot, useCollator](value::SlotId scanSlot, + std::unique_ptr<PlanStage> scanStage) { + // Build a HashAggStage to make sure HashAgg groups use collator correctly. + auto countsSlot = generateSlotId(); + + auto hashAggStage = + makeS<HashAggStage>(std::move(scanStage), + makeSV(scanSlot), + makeEM(countsSlot, + stage_builder::makeFunction( + "sum", + makeE<EConstant>(value::TypeTags::NumberInt64, + value::bitcastFrom<int64_t>(1)))), + boost::optional<value::SlotId>{useCollator, collatorSlot}, + kEmptyPlanNodeId); + + return std::make_pair(countsSlot, std::move(hashAggStage)); + }; + + auto ctx = makeCompileCtx(); + + // Setup collator and insert it into the ctx. + auto collator = std::make_unique<CollatorInterfaceMock>( + CollatorInterfaceMock::MockType::kToLowerString); + value::OwnedValueAccessor collatorAccessor; + ctx->pushCorrelated(collatorSlot, &collatorAccessor); + collatorAccessor.reset(value::TypeTags::collator, + value::bitcastFrom<CollatorInterface*>(collator.get())); + + // Generate a mock scan from `input` with a single output slot. + inputGuard.reset(); + auto [scanSlot, scanStage] = generateVirtualScan(inputTag, inputVal); + + // Call the `makeStage` function to create the HashAggStage, passing in the mock scan + // subtree and the subtree's output slot. + auto [outputSlot, stage] = makeStageFn(scanSlot, std::move(scanStage)); + + // Prepare the tree and get the `SlotAccessor` for the output slot. + auto resultAccessor = prepareTree(ctx.get(), stage.get(), outputSlot); + + // Get all the results produced. + auto [resultsTag, resultsVal] = getAllResults(stage.get(), resultAccessor); + value::ValueGuard resultsGuard{resultsTag, resultsVal}; + + // Sort results for stable compare, since the counts could come out in any order + using valuePair = std::pair<value::TypeTags, value::Value>; + std::vector<valuePair> resultsContents; + auto resultsView = value::getArrayView(resultsVal); + for (size_t i = 0; i < resultsView->size(); i++) { + resultsContents.push_back(resultsView->getAt(i)); + } + std::sort(resultsContents.begin(), + resultsContents.end(), + [](const valuePair& lhs, const valuePair& rhs) -> bool { + auto [lhsTag, lhsVal] = lhs; + auto [rhsTag, rhsVal] = rhs; + auto [compareTag, compareVal] = + value::compareValue(lhsTag, lhsVal, rhsTag, rhsVal); + ASSERT_EQ(compareTag, value::TypeTags::NumberInt32); + return compareVal == 1; + }); + + auto [sortedResultsTag, sortedResultsVal] = value::makeNewArray(); + value::ValueGuard sortedResultsGuard{sortedResultsTag, sortedResultsVal}; + auto sortedResultsView = value::getArrayView(sortedResultsVal); + for (auto [tag, val] : resultsContents) { + auto [tagCopy, valCopy] = copyValue(tag, val); + sortedResultsView->push_back(tagCopy, valCopy); + } + + assertValuesEqual(sortedResultsTag, sortedResultsVal, expectedTag, expectedVal); + } +} + } // namespace mongo::sbe diff --git a/src/mongo/db/exec/sbe/sbe_hash_join_test.cpp b/src/mongo/db/exec/sbe/sbe_hash_join_test.cpp new file mode 100644 index 00000000000..2c74ad4296f --- /dev/null +++ b/src/mongo/db/exec/sbe/sbe_hash_join_test.cpp @@ -0,0 +1,137 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +/** + * This file contains tests for sbe::HashJoinStage. + */ + +#include "mongo/platform/basic.h" + +#include <string_view> + +#include "mongo/db/exec/sbe/sbe_plan_stage_test.h" +#include "mongo/db/exec/sbe/stages/hash_join.h" +#include "mongo/db/query/collation/collator_interface_mock.h" + +namespace mongo::sbe { + +using HashJoinStageTest = PlanStageTestFixture; + +TEST_F(HashJoinStageTest, HashJoinCollationTest) { + using namespace std::literals; + for (auto useCollator : {false, true}) { + auto [innerTag, innerVal] = stage_builder::makeValue(BSON_ARRAY("a" + << "b" + << "c")); + value::ValueGuard innerGuard{innerTag, innerVal}; + + auto [outerTag, outerVal] = stage_builder::makeValue(BSON_ARRAY("a" + << "b" + << "A")); + value::ValueGuard outerGuard{outerTag, outerVal}; + + // After running the join we expect to get back pairs of the keys that were + // matched up. + std::vector<std::pair<std::string, std::string>> expectedVec; + if (useCollator) { + expectedVec = {{"a", "A"}, {"a", "a"}, {"b", "b"}}; + } else { + expectedVec = {{"a", "a"}, {"b", "b"}}; + } + + auto collatorSlot = generateSlotId(); + + auto makeStageFn = [this, collatorSlot, useCollator]( + value::SlotId outerCondSlot, + value::SlotId innerCondSlot, + std::unique_ptr<PlanStage> outerStage, + std::unique_ptr<PlanStage> innerStage) { + auto hashJoinStage = + makeS<HashJoinStage>(std::move(outerStage), + std::move(innerStage), + makeSV(outerCondSlot), + makeSV(), + makeSV(innerCondSlot), + makeSV(), + boost::optional<value::SlotId>{useCollator, collatorSlot}, + kEmptyPlanNodeId); + + return std::make_pair(makeSV(innerCondSlot, outerCondSlot), std::move(hashJoinStage)); + }; + + auto ctx = makeCompileCtx(); + + // Setup collator and insert it into the ctx. + auto collator = std::make_unique<CollatorInterfaceMock>( + CollatorInterfaceMock::MockType::kToLowerString); + value::OwnedValueAccessor collatorAccessor; + ctx->pushCorrelated(collatorSlot, &collatorAccessor); + collatorAccessor.reset(value::TypeTags::collator, + value::bitcastFrom<CollatorInterface*>(collator.get())); + + // Two separate virtual scans are needed since HashJoinStage needs two child stages. + outerGuard.reset(); + auto [outerCondSlot, outerStage] = generateVirtualScan(outerTag, outerVal); + + innerGuard.reset(); + auto [innerCondSlot, innerStage] = generateVirtualScan(innerTag, innerVal); + + // Call the `makeStage` callback to create the HashJoinStage, passing in the mock scan + // subtrees and the subtree's output slots. + auto [outputSlots, stage] = + makeStageFn(outerCondSlot, innerCondSlot, std::move(outerStage), std::move(innerStage)); + + // Prepare the tree and get the SlotAccessor for the output slots. + auto resultAccessors = prepareTree(ctx.get(), stage.get(), outputSlots); + + // Get all the results produced by HashJoin. + auto [resultsTag, resultsVal] = getAllResultsMulti(stage.get(), resultAccessors); + value::ValueGuard resultsGuard{resultsTag, resultsVal}; + ASSERT_EQ(resultsTag, value::TypeTags::Array); + auto resultsView = value::getArrayView(resultsVal); + + // make sure all the expected pairs occur in the result + ASSERT_EQ(resultsView->size(), expectedVec.size()); + for (auto [outer, inner] : expectedVec) { + auto [expectedTag, expectedVal] = stage_builder::makeValue(BSON_ARRAY(outer << inner)); + bool found = false; + for (size_t i = 0; i < resultsView->size(); i++) { + auto [tag, val] = resultsView->getAt(i); + auto [cmpTag, cmpVal] = compareValue(expectedTag, expectedVal, tag, val); + if (cmpVal == 0) { + found = true; + break; + } + } + ASSERT_TRUE(found); + } + } +} + +} // namespace mongo::sbe diff --git a/src/mongo/db/exec/sbe/stages/hash_agg.cpp b/src/mongo/db/exec/sbe/stages/hash_agg.cpp index 2424ba019de..b4987417aa6 100644 --- a/src/mongo/db/exec/sbe/stages/hash_agg.cpp +++ b/src/mongo/db/exec/sbe/stages/hash_agg.cpp @@ -38,8 +38,12 @@ namespace sbe { HashAggStage::HashAggStage(std::unique_ptr<PlanStage> input, value::SlotVector gbs, value::SlotMap<std::unique_ptr<EExpression>> aggs, + boost::optional<value::SlotId> collatorSlot, PlanNodeId planNodeId) - : PlanStage("group"_sd, planNodeId), _gbs(std::move(gbs)), _aggs(std::move(aggs)) { + : PlanStage("group"_sd, planNodeId), + _gbs(std::move(gbs)), + _aggs(std::move(aggs)), + _collatorSlot(collatorSlot) { _children.emplace_back(std::move(input)); } @@ -49,12 +53,19 @@ std::unique_ptr<PlanStage> HashAggStage::clone() const { aggs.emplace(k, v->clone()); } return std::make_unique<HashAggStage>( - _children[0]->clone(), _gbs, std::move(aggs), _commonStats.nodeId); + _children[0]->clone(), _gbs, std::move(aggs), _collatorSlot, _commonStats.nodeId); } void HashAggStage::prepare(CompileCtx& ctx) { _children[0]->prepare(ctx); + if (_collatorSlot) { + _collatorAccessor = getAccessor(ctx, *_collatorSlot); + tassert(5402501, + "collator accessor should exist if collator slot provided to HashAggStage", + _collatorAccessor != nullptr); + } + value::SlotSet dupCheck; size_t counter = 0; // Process group by columns. @@ -107,8 +118,15 @@ void HashAggStage::open(bool reOpen) { _commonStats.opens++; _children[0]->open(reOpen); - if (reOpen) { - _ht.clear(); + if (_collatorAccessor) { + auto [tag, collatorVal] = _collatorAccessor->getViewOfValue(); + uassert(5402503, "collatorSlot must be of collator type", tag == value::TypeTags::collator); + auto collatorView = value::getCollatorView(collatorVal); + const value::MaterializedRowHasher hasher(collatorView); + const value::MaterializedRowEq equator(collatorView); + _ht.emplace(0, hasher, equator); + } else { + _ht.emplace(); } while (_children[0]->getNext() == PlanState::ADVANCED) { @@ -120,7 +138,7 @@ void HashAggStage::open(bool reOpen) { key.reset(idx++, false, tag, val); } - auto [it, inserted] = _ht.try_emplace(std::move(key), value::MaterializedRow{0}); + auto [it, inserted] = _ht->try_emplace(std::move(key), value::MaterializedRow{0}); if (inserted) { // Copy keys. const_cast<value::MaterializedRow&>(it->first).makeOwned(); @@ -138,19 +156,19 @@ void HashAggStage::open(bool reOpen) { _children[0]->close(); - _htIt = _ht.end(); + _htIt = _ht->end(); } PlanState HashAggStage::getNext() { auto optTimer(getOptTimer(_opCtx)); - if (_htIt == _ht.end()) { - _htIt = _ht.begin(); + if (_htIt == _ht->end()) { + _htIt = _ht->begin(); } else { ++_htIt; } - if (_htIt == _ht.end()) { + if (_htIt == _ht->end()) { return trackPlanState(PlanState::IS_EOF); } @@ -185,7 +203,7 @@ void HashAggStage::close() { auto optTimer(getOptTimer(_opCtx)); _commonStats.closes++; - _ht.clear(); + _ht = boost::none; } std::vector<DebugPrinter::Block> HashAggStage::debugPrint() const { @@ -215,6 +233,10 @@ std::vector<DebugPrinter::Block> HashAggStage::debugPrint() const { }); ret.emplace_back("`]"); + if (_collatorSlot) { + DebugPrinter::addIdentifier(ret, *_collatorSlot); + } + DebugPrinter::addNewLine(ret); DebugPrinter::addBlocks(ret, _children[0]->debugPrint()); diff --git a/src/mongo/db/exec/sbe/stages/hash_agg.h b/src/mongo/db/exec/sbe/stages/hash_agg.h index 2c2ec5ab6aa..06c0e211b31 100644 --- a/src/mongo/db/exec/sbe/stages/hash_agg.h +++ b/src/mongo/db/exec/sbe/stages/hash_agg.h @@ -43,6 +43,7 @@ public: HashAggStage(std::unique_ptr<PlanStage> input, value::SlotVector gbs, value::SlotMap<std::unique_ptr<EExpression>> aggs, + boost::optional<value::SlotId> collatorSlot, PlanNodeId planNodeId); std::unique_ptr<PlanStage> clone() const final; @@ -68,6 +69,7 @@ private: const value::SlotVector _gbs; const value::SlotMap<std::unique_ptr<EExpression>> _aggs; + const boost::optional<value::SlotId> _collatorSlot; value::SlotAccessorMap _outAccessors; std::vector<value::SlotAccessor*> _inKeyAccessors; @@ -76,8 +78,10 @@ private: std::vector<std::unique_ptr<HashAggAccessor>> _outAggAccessors; std::vector<std::unique_ptr<vm::CodeFragment>> _aggCodes; - // TODO SERVER-54025: Update HashAggStage so that group-bys are collation-aware. - TableType _ht; + // Only set if collator slot provided on construction. + value::SlotAccessor* _collatorAccessor = nullptr; + + boost::optional<TableType> _ht; TableType::iterator _htIt; vm::ByteCode _bytecode; diff --git a/src/mongo/db/exec/sbe/stages/hash_join.cpp b/src/mongo/db/exec/sbe/stages/hash_join.cpp index e92bbebdab4..190b8d7dd05 100644 --- a/src/mongo/db/exec/sbe/stages/hash_join.cpp +++ b/src/mongo/db/exec/sbe/stages/hash_join.cpp @@ -42,12 +42,14 @@ HashJoinStage::HashJoinStage(std::unique_ptr<PlanStage> outer, value::SlotVector outerProjects, value::SlotVector innerCond, value::SlotVector innerProjects, + boost::optional<value::SlotId> collatorSlot, PlanNodeId planNodeId) : PlanStage("hj"_sd, planNodeId), _outerCond(std::move(outerCond)), _outerProjects(std::move(outerProjects)), _innerCond(std::move(innerCond)), _innerProjects(std::move(innerProjects)), + _collatorSlot(collatorSlot), _probeKey(0) { if (_outerCond.size() != _innerCond.size()) { uasserted(4822823, "left and right size do not match"); @@ -64,6 +66,7 @@ std::unique_ptr<PlanStage> HashJoinStage::clone() const { _outerProjects, _innerCond, _innerProjects, + _collatorSlot, _commonStats.nodeId); } @@ -71,6 +74,13 @@ void HashJoinStage::prepare(CompileCtx& ctx) { _children[0]->prepare(ctx); _children[1]->prepare(ctx); + if (_collatorSlot) { + _collatorAccessor = getAccessor(ctx, *_collatorSlot); + tassert(5402502, + "collator accessor should exist if collator slot provided to HashJoinStage", + _collatorAccessor != nullptr); + } + size_t counter = 0; value::SlotSet dupCheck; for (auto& slot : _outerCond) { @@ -121,6 +131,17 @@ value::SlotAccessor* HashJoinStage::getAccessor(CompileCtx& ctx, value::SlotId s void HashJoinStage::open(bool reOpen) { auto optTimer(getOptTimer(_opCtx)); + if (_collatorAccessor) { + auto [tag, collatorVal] = _collatorAccessor->getViewOfValue(); + uassert(5402504, "collatorSlot must be of collator type", tag == value::TypeTags::collator); + auto collatorView = value::getCollatorView(collatorVal); + const value::MaterializedRowHasher hasher(collatorView); + const value::MaterializedRowEq equator(collatorView); + _ht.emplace(0, hasher, equator); + } else { + _ht.emplace(); + } + _commonStats.opens++; _children[0]->open(reOpen); // Insert the outer side into the hash table. @@ -142,15 +163,15 @@ void HashJoinStage::open(bool reOpen) { project.reset(idx++, true, tag, val); } - _ht.emplace(std::move(key), std::move(project)); + _ht->emplace(std::move(key), std::move(project)); } _children[0]->close(); _children[1]->open(reOpen); - _htIt = _ht.end(); - _htItEnd = _ht.end(); + _htIt = _ht->end(); + _htItEnd = _ht->end(); } PlanState HashJoinStage::getNext() { @@ -175,7 +196,7 @@ PlanState HashJoinStage::getNext() { _probeKey.reset(idx++, false, tag, val); } - auto [low, hi] = _ht.equal_range(_probeKey); + auto [low, hi] = _ht->equal_range(_probeKey); _htIt = low; _htItEnd = hi; // If _htIt == _htItEnd (i.e. no match) then RIGHT and OUTER joins @@ -191,6 +212,7 @@ void HashJoinStage::close() { _commonStats.closes++; _children[1]->close(); + _ht = boost::none; } std::unique_ptr<PlanStageStats> HashJoinStage::getStats(bool includeDebugInfo) const { @@ -207,6 +229,10 @@ const SpecificStats* HashJoinStage::getSpecificStats() const { std::vector<DebugPrinter::Block> HashJoinStage::debugPrint() const { auto ret = PlanStage::debugPrint(); + if (_collatorSlot) { + DebugPrinter::addIdentifier(ret, *_collatorSlot); + } + ret.emplace_back(DebugPrinter::Block::cmdIncIndent); DebugPrinter::addKeyword(ret, "left"); diff --git a/src/mongo/db/exec/sbe/stages/hash_join.h b/src/mongo/db/exec/sbe/stages/hash_join.h index f40af886a02..05ec1ec294a 100644 --- a/src/mongo/db/exec/sbe/stages/hash_join.h +++ b/src/mongo/db/exec/sbe/stages/hash_join.h @@ -43,6 +43,7 @@ public: value::SlotVector outerProjects, value::SlotVector innerCond, value::SlotVector innerProjects, + boost::optional<value::SlotId> collatorSlot, PlanNodeId planNodeId); std::unique_ptr<PlanStage> clone() const final; @@ -70,11 +71,12 @@ private: const value::SlotVector _outerProjects; const value::SlotVector _innerCond; const value::SlotVector _innerProjects; + const boost::optional<value::SlotId> _collatorSlot; // All defined values from the outer side (i.e. they come from the hash table). value::SlotAccessorMap _outOuterAccessors; - // Accessors of input codition values (keys) that are being inserted into the hash table. + // Accessors of input condition values (keys) that are being inserted into the hash table. std::vector<value::SlotAccessor*> _inOuterKeyAccessors; // Accessors of output keys. @@ -86,15 +88,16 @@ private: // Accessors of output projections. std::vector<std::unique_ptr<HashProjectAccessor>> _outOuterProjectAccessors; - // Accessors of input codition values (keys) that are being inserted into the hash table. + // Accessors of input condition values (keys) that are being inserted into the hash table. std::vector<value::SlotAccessor*> _inInnerKeyAccessors; + // Accessor for collator. Only set if collatorSlot provided during construction. + value::SlotAccessor* _collatorAccessor = nullptr; + // Key used to probe inside the hash table. value::MaterializedRow _probeKey; - // TODO SERVER-54025: Update HashJoinStage so that it's mechanism for matching outer keys and - // inner keys is collation-aware. - TableType _ht; + boost::optional<TableType> _ht; TableType::iterator _htIt; TableType::iterator _htItEnd; diff --git a/src/mongo/db/exec/sbe/values/slot.h b/src/mongo/db/exec/sbe/values/slot.h index b5abc318a9c..f7cc78601cd 100644 --- a/src/mongo/db/exec/sbe/values/slot.h +++ b/src/mongo/db/exec/sbe/values/slot.h @@ -508,11 +508,16 @@ private: }; struct MaterializedRowEq { + using ComparatorType = StringData::ComparatorInterface*; + + explicit MaterializedRowEq(const ComparatorType comparator = nullptr) + : _comparator(comparator) {} + bool operator()(const MaterializedRow& lhs, const MaterializedRow& rhs) const { for (size_t idx = 0; idx < lhs.size(); ++idx) { auto [lhsTag, lhsVal] = lhs.getViewOfValue(idx); auto [rhsTag, rhsVal] = rhs.getViewOfValue(idx); - auto [tag, val] = compareValue(lhsTag, lhsVal, rhsTag, rhsVal); + auto [tag, val] = compareValue(lhsTag, lhsVal, rhsTag, rhsVal, _comparator); if (tag != value::TypeTags::NumberInt32 || value::bitcastTo<int32_t>(val) != 0) { return false; @@ -521,17 +526,27 @@ struct MaterializedRowEq { return true; } + +private: + const ComparatorType _comparator = nullptr; }; struct MaterializedRowHasher { + using CollatorType = CollatorInterface*; + + explicit MaterializedRowHasher(const CollatorType collator = nullptr) : _collator(collator) {} + std::size_t operator()(const MaterializedRow& k) const { size_t res = hashInit(); for (size_t idx = 0; idx < k.size(); ++idx) { auto [tag, val] = k.getViewOfValue(idx); - res = hashCombine(res, hashValue(tag, val)); + res = hashCombine(res, hashValue(tag, val, _collator)); } return res; } + +private: + const CollatorType _collator = nullptr; }; /** diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp index 3ff4deb6abd..709fab540c7 100644 --- a/src/mongo/db/query/sbe_stage_builder.cpp +++ b/src/mongo/db/query/sbe_stage_builder.cpp @@ -1237,6 +1237,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder auto innerCondSlots = sbe::makeSV(innerIdSlot); auto innerProjectSlots = sbe::makeSV(innerResultSlot); + auto collatorSlot = _data.env->getSlotIfExists("collator"_sd); + // Designate outputs. PlanStageSlots outputs(reqs, &_slotIdGenerator); if (reqs.has(kRecordId)) { @@ -1252,6 +1254,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder outerProjectSlots, innerCondSlots, innerProjectSlots, + collatorSlot, root->nodeId()); // If there are more than 2 children, iterate all remaining children and hash @@ -1271,6 +1274,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder projectSlots, innerCondSlots, innerProjectSlots, + collatorSlot, root->nodeId()); } diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 397fe5bdb0f..19e00bfb6be 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -1125,6 +1125,7 @@ public: makeLimitTree(std::move(unionWithNullStage.stage), _context->planNodeId, numChildren); // Create a group stage to aggregate elements into a single array. + auto collatorSlot = _context->runtimeEnvironment->getSlotIfExists("collator"_sd); auto addToArrayExpr = makeFunction("addToArray", sbe::makeE<sbe::EVariable>(unionWithNullSlot)); auto groupSlot = _context->slotIdGenerator->generate(); @@ -1132,6 +1133,7 @@ public: sbe::makeS<sbe::HashAggStage>(std::move(limitNumChildren), sbe::makeSV(), sbe::makeEM(groupSlot, std::move(addToArrayExpr)), + collatorSlot, _context->planNodeId); EvalStage groupEvalStage = {std::move(groupStage), sbe::makeSV(groupSlot)}; @@ -1160,6 +1162,7 @@ public: std::move(unwindEvalStage.stage), sbe::makeSV(), sbe::makeEM(finalGroupSlot, std::move(finalAddToArrayExpr)), + collatorSlot, _context->planNodeId); // Create a branch stage to select between the branch that produces one null if any eleemnts |