summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/mongo/db/exec/sbe/SConscript1
-rw-r--r--src/mongo/db/exec/sbe/parser/parser.cpp34
-rw-r--r--src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp100
-rw-r--r--src/mongo/db/exec/sbe/sbe_hash_join_test.cpp137
-rw-r--r--src/mongo/db/exec/sbe/stages/hash_agg.cpp42
-rw-r--r--src/mongo/db/exec/sbe/stages/hash_agg.h8
-rw-r--r--src/mongo/db/exec/sbe/stages/hash_join.cpp34
-rw-r--r--src/mongo/db/exec/sbe/stages/hash_join.h13
-rw-r--r--src/mongo/db/exec/sbe/values/slot.h19
-rw-r--r--src/mongo/db/query/sbe_stage_builder.cpp4
-rw-r--r--src/mongo/db/query/sbe_stage_builder_expression.cpp3
11 files changed, 363 insertions, 32 deletions
diff --git a/src/mongo/db/exec/sbe/SConscript b/src/mongo/db/exec/sbe/SConscript
index 461e9410f4c..4df8209c6f2 100644
--- a/src/mongo/db/exec/sbe/SConscript
+++ b/src/mongo/db/exec/sbe/SConscript
@@ -144,6 +144,7 @@ env.CppUnitTest(
'parser/sbe_parser_test.cpp',
'sbe_filter_test.cpp',
'sbe_hash_agg_test.cpp',
+ 'sbe_hash_join_test.cpp',
'sbe_key_string_test.cpp',
'sbe_limit_skip_test.cpp',
'sbe_math_builtins_test.cpp',
diff --git a/src/mongo/db/exec/sbe/parser/parser.cpp b/src/mongo/db/exec/sbe/parser/parser.cpp
index 19e6385b8ec..19fe41378ef 100644
--- a/src/mongo/db/exec/sbe/parser/parser.cpp
+++ b/src/mongo/db/exec/sbe/parser/parser.cpp
@@ -138,8 +138,9 @@ static constexpr auto kSyntax = R"(
MKOBJ_FLAG # Return old object
OPERATOR # child
- GROUP <- 'group' IDENT_LIST PROJECT_LIST OPERATOR
- HJOIN <- 'hj' LEFT RIGHT
+ GROUP <- 'group' IDENT_LIST PROJECT_LIST OPERATOR ( EXPR )? # optional collator
+ HJOIN <- 'hj' ( EXPR )? # optional collator
+ LEFT RIGHT
LEFT <- 'left' IDENT_LIST IDENT_LIST OPERATOR
RIGHT <- 'right' IDENT_LIST IDENT_LIST OPERATOR
@@ -978,21 +979,38 @@ void Parser::walkMkObj(AstQuery& ast) {
void Parser::walkGroup(AstQuery& ast) {
walkChildren(ast);
+ boost::optional<value::SlotId> collatorSlot;
+ if (ast.nodes.size() == 3) {
+ collatorSlot = lookupSlot(std::move(ast.nodes[2]->identifier));
+ }
+
ast.stage = makeS<HashAggStage>(std::move(ast.nodes[2]->stage),
lookupSlots(std::move(ast.nodes[0]->identifiers)),
lookupSlots(std::move(ast.nodes[1]->projects)),
+ collatorSlot,
getCurrentPlanNodeId());
}
void Parser::walkHashJoin(AstQuery& ast) {
walkChildren(ast);
+
+ boost::optional<value::SlotId> collatorSlot;
+ auto outerNode = ast.nodes[0];
+ auto innerNode = ast.nodes[1];
+ if (ast.nodes.size() == 3) {
+ outerNode = ast.nodes[1];
+ innerNode = ast.nodes[2];
+ collatorSlot = lookupSlot(ast.nodes[0]->identifier);
+ }
+
ast.stage =
- makeS<HashJoinStage>(std::move(ast.nodes[0]->nodes[2]->stage), // outer
- std::move(ast.nodes[1]->nodes[2]->stage), // inner
- lookupSlots(ast.nodes[0]->nodes[0]->identifiers), // outer conditions
- lookupSlots(ast.nodes[0]->nodes[1]->identifiers), // outer projections
- lookupSlots(ast.nodes[1]->nodes[0]->identifiers), // inner conditions
- lookupSlots(ast.nodes[1]->nodes[1]->identifiers), // inner projections
+ makeS<HashJoinStage>(std::move(outerNode->nodes[2]->stage), // outer
+ std::move(innerNode->nodes[2]->stage), // inner
+ lookupSlots(outerNode->nodes[0]->identifiers), // outer conditions
+ lookupSlots(outerNode->nodes[1]->identifiers), // outer projections
+ lookupSlots(innerNode->nodes[0]->identifiers), // inner conditions
+ lookupSlots(innerNode->nodes[1]->identifiers), // inner projections
+ collatorSlot, // collator
getCurrentPlanNodeId());
}
diff --git a/src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp b/src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp
index 8ca2c4c5398..281a7cf0684 100644
--- a/src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp
+++ b/src/mongo/db/exec/sbe/sbe_hash_agg_test.cpp
@@ -28,7 +28,7 @@
*/
/**
- * This file contains tests for sbe::FilterStage.
+ * This file contains tests for sbe::HashAggStage.
*/
#include "mongo/platform/basic.h"
@@ -82,6 +82,7 @@ TEST_F(HashAggStageTest, HashAggMinMaxTest) {
collMaxSlot,
stage_builder::makeFunction(
"collMax", collExpr->clone(), makeE<EVariable>(scanSlot))),
+ boost::none,
kEmptyPlanNodeId);
auto outSlot = generateSlotId();
@@ -135,6 +136,7 @@ TEST_F(HashAggStageTest, HashAggAddToSetTest) {
makeEM(hashAggSlot,
stage_builder::makeFunction(
"collAddToSet", std::move(collExpr), makeE<EVariable>(scanSlot))),
+ boost::none,
kEmptyPlanNodeId);
return std::make_pair(hashAggSlot, std::move(hashAggStage));
@@ -183,4 +185,100 @@ TEST_F(HashAggStageTest, HashAggAddToSetTest) {
ASSERT_TRUE(resultsEnumerator.atEnd());
}
+TEST_F(HashAggStageTest, HashAggCollationTest) {
+ using namespace std::literals;
+ for (auto useCollator : {false, true}) {
+
+ BSONArrayBuilder bab1;
+ bab1.append("A").append("a").append("b").append("c").append("B").append("a");
+ auto [inputTag, inputVal] = stage_builder::makeValue(bab1.arr());
+ value::ValueGuard inputGuard{inputTag, inputVal};
+
+ BSONArrayBuilder bab2;
+ if (useCollator) {
+ // Collator groups the values as: ["A", "a", "a"], ["B", "b"], ["c"].
+ bab2.append(3).append(2).append(1);
+ } else {
+ // No Collator groups the values as: ["a", "a"], ["A"], ["B"], ["b"], ["c"].
+ bab2.append(2).append(1).append(1).append(1).append(1);
+ }
+ auto [expectedTag, expectedVal] = stage_builder::makeValue(bab2.arr());
+ value::ValueGuard expectedGuard{expectedTag, expectedVal};
+
+ auto collatorSlot = generateSlotId();
+
+ auto makeStageFn = [this, collatorSlot, useCollator](value::SlotId scanSlot,
+ std::unique_ptr<PlanStage> scanStage) {
+ // Build a HashAggStage to make sure HashAgg groups use collator correctly.
+ auto countsSlot = generateSlotId();
+
+ auto hashAggStage =
+ makeS<HashAggStage>(std::move(scanStage),
+ makeSV(scanSlot),
+ makeEM(countsSlot,
+ stage_builder::makeFunction(
+ "sum",
+ makeE<EConstant>(value::TypeTags::NumberInt64,
+ value::bitcastFrom<int64_t>(1)))),
+ boost::optional<value::SlotId>{useCollator, collatorSlot},
+ kEmptyPlanNodeId);
+
+ return std::make_pair(countsSlot, std::move(hashAggStage));
+ };
+
+ auto ctx = makeCompileCtx();
+
+ // Setup collator and insert it into the ctx.
+ auto collator = std::make_unique<CollatorInterfaceMock>(
+ CollatorInterfaceMock::MockType::kToLowerString);
+ value::OwnedValueAccessor collatorAccessor;
+ ctx->pushCorrelated(collatorSlot, &collatorAccessor);
+ collatorAccessor.reset(value::TypeTags::collator,
+ value::bitcastFrom<CollatorInterface*>(collator.get()));
+
+ // Generate a mock scan from `input` with a single output slot.
+ inputGuard.reset();
+ auto [scanSlot, scanStage] = generateVirtualScan(inputTag, inputVal);
+
+ // Call the `makeStage` function to create the HashAggStage, passing in the mock scan
+ // subtree and the subtree's output slot.
+ auto [outputSlot, stage] = makeStageFn(scanSlot, std::move(scanStage));
+
+ // Prepare the tree and get the `SlotAccessor` for the output slot.
+ auto resultAccessor = prepareTree(ctx.get(), stage.get(), outputSlot);
+
+ // Get all the results produced.
+ auto [resultsTag, resultsVal] = getAllResults(stage.get(), resultAccessor);
+ value::ValueGuard resultsGuard{resultsTag, resultsVal};
+
+ // Sort results for stable compare, since the counts could come out in any order
+ using valuePair = std::pair<value::TypeTags, value::Value>;
+ std::vector<valuePair> resultsContents;
+ auto resultsView = value::getArrayView(resultsVal);
+ for (size_t i = 0; i < resultsView->size(); i++) {
+ resultsContents.push_back(resultsView->getAt(i));
+ }
+ std::sort(resultsContents.begin(),
+ resultsContents.end(),
+ [](const valuePair& lhs, const valuePair& rhs) -> bool {
+ auto [lhsTag, lhsVal] = lhs;
+ auto [rhsTag, rhsVal] = rhs;
+ auto [compareTag, compareVal] =
+ value::compareValue(lhsTag, lhsVal, rhsTag, rhsVal);
+ ASSERT_EQ(compareTag, value::TypeTags::NumberInt32);
+ return compareVal == 1;
+ });
+
+ auto [sortedResultsTag, sortedResultsVal] = value::makeNewArray();
+ value::ValueGuard sortedResultsGuard{sortedResultsTag, sortedResultsVal};
+ auto sortedResultsView = value::getArrayView(sortedResultsVal);
+ for (auto [tag, val] : resultsContents) {
+ auto [tagCopy, valCopy] = copyValue(tag, val);
+ sortedResultsView->push_back(tagCopy, valCopy);
+ }
+
+ assertValuesEqual(sortedResultsTag, sortedResultsVal, expectedTag, expectedVal);
+ }
+}
+
} // namespace mongo::sbe
diff --git a/src/mongo/db/exec/sbe/sbe_hash_join_test.cpp b/src/mongo/db/exec/sbe/sbe_hash_join_test.cpp
new file mode 100644
index 00000000000..2c74ad4296f
--- /dev/null
+++ b/src/mongo/db/exec/sbe/sbe_hash_join_test.cpp
@@ -0,0 +1,137 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+/**
+ * This file contains tests for sbe::HashJoinStage.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include <string_view>
+
+#include "mongo/db/exec/sbe/sbe_plan_stage_test.h"
+#include "mongo/db/exec/sbe/stages/hash_join.h"
+#include "mongo/db/query/collation/collator_interface_mock.h"
+
+namespace mongo::sbe {
+
+using HashJoinStageTest = PlanStageTestFixture;
+
+TEST_F(HashJoinStageTest, HashJoinCollationTest) {
+ using namespace std::literals;
+ for (auto useCollator : {false, true}) {
+ auto [innerTag, innerVal] = stage_builder::makeValue(BSON_ARRAY("a"
+ << "b"
+ << "c"));
+ value::ValueGuard innerGuard{innerTag, innerVal};
+
+ auto [outerTag, outerVal] = stage_builder::makeValue(BSON_ARRAY("a"
+ << "b"
+ << "A"));
+ value::ValueGuard outerGuard{outerTag, outerVal};
+
+ // After running the join we expect to get back pairs of the keys that were
+ // matched up.
+ std::vector<std::pair<std::string, std::string>> expectedVec;
+ if (useCollator) {
+ expectedVec = {{"a", "A"}, {"a", "a"}, {"b", "b"}};
+ } else {
+ expectedVec = {{"a", "a"}, {"b", "b"}};
+ }
+
+ auto collatorSlot = generateSlotId();
+
+ auto makeStageFn = [this, collatorSlot, useCollator](
+ value::SlotId outerCondSlot,
+ value::SlotId innerCondSlot,
+ std::unique_ptr<PlanStage> outerStage,
+ std::unique_ptr<PlanStage> innerStage) {
+ auto hashJoinStage =
+ makeS<HashJoinStage>(std::move(outerStage),
+ std::move(innerStage),
+ makeSV(outerCondSlot),
+ makeSV(),
+ makeSV(innerCondSlot),
+ makeSV(),
+ boost::optional<value::SlotId>{useCollator, collatorSlot},
+ kEmptyPlanNodeId);
+
+ return std::make_pair(makeSV(innerCondSlot, outerCondSlot), std::move(hashJoinStage));
+ };
+
+ auto ctx = makeCompileCtx();
+
+ // Setup collator and insert it into the ctx.
+ auto collator = std::make_unique<CollatorInterfaceMock>(
+ CollatorInterfaceMock::MockType::kToLowerString);
+ value::OwnedValueAccessor collatorAccessor;
+ ctx->pushCorrelated(collatorSlot, &collatorAccessor);
+ collatorAccessor.reset(value::TypeTags::collator,
+ value::bitcastFrom<CollatorInterface*>(collator.get()));
+
+ // Two separate virtual scans are needed since HashJoinStage needs two child stages.
+ outerGuard.reset();
+ auto [outerCondSlot, outerStage] = generateVirtualScan(outerTag, outerVal);
+
+ innerGuard.reset();
+ auto [innerCondSlot, innerStage] = generateVirtualScan(innerTag, innerVal);
+
+ // Call the `makeStage` callback to create the HashJoinStage, passing in the mock scan
+ // subtrees and the subtree's output slots.
+ auto [outputSlots, stage] =
+ makeStageFn(outerCondSlot, innerCondSlot, std::move(outerStage), std::move(innerStage));
+
+ // Prepare the tree and get the SlotAccessor for the output slots.
+ auto resultAccessors = prepareTree(ctx.get(), stage.get(), outputSlots);
+
+ // Get all the results produced by HashJoin.
+ auto [resultsTag, resultsVal] = getAllResultsMulti(stage.get(), resultAccessors);
+ value::ValueGuard resultsGuard{resultsTag, resultsVal};
+ ASSERT_EQ(resultsTag, value::TypeTags::Array);
+ auto resultsView = value::getArrayView(resultsVal);
+
+ // make sure all the expected pairs occur in the result
+ ASSERT_EQ(resultsView->size(), expectedVec.size());
+ for (auto [outer, inner] : expectedVec) {
+ auto [expectedTag, expectedVal] = stage_builder::makeValue(BSON_ARRAY(outer << inner));
+ bool found = false;
+ for (size_t i = 0; i < resultsView->size(); i++) {
+ auto [tag, val] = resultsView->getAt(i);
+ auto [cmpTag, cmpVal] = compareValue(expectedTag, expectedVal, tag, val);
+ if (cmpVal == 0) {
+ found = true;
+ break;
+ }
+ }
+ ASSERT_TRUE(found);
+ }
+ }
+}
+
+} // namespace mongo::sbe
diff --git a/src/mongo/db/exec/sbe/stages/hash_agg.cpp b/src/mongo/db/exec/sbe/stages/hash_agg.cpp
index 2424ba019de..b4987417aa6 100644
--- a/src/mongo/db/exec/sbe/stages/hash_agg.cpp
+++ b/src/mongo/db/exec/sbe/stages/hash_agg.cpp
@@ -38,8 +38,12 @@ namespace sbe {
HashAggStage::HashAggStage(std::unique_ptr<PlanStage> input,
value::SlotVector gbs,
value::SlotMap<std::unique_ptr<EExpression>> aggs,
+ boost::optional<value::SlotId> collatorSlot,
PlanNodeId planNodeId)
- : PlanStage("group"_sd, planNodeId), _gbs(std::move(gbs)), _aggs(std::move(aggs)) {
+ : PlanStage("group"_sd, planNodeId),
+ _gbs(std::move(gbs)),
+ _aggs(std::move(aggs)),
+ _collatorSlot(collatorSlot) {
_children.emplace_back(std::move(input));
}
@@ -49,12 +53,19 @@ std::unique_ptr<PlanStage> HashAggStage::clone() const {
aggs.emplace(k, v->clone());
}
return std::make_unique<HashAggStage>(
- _children[0]->clone(), _gbs, std::move(aggs), _commonStats.nodeId);
+ _children[0]->clone(), _gbs, std::move(aggs), _collatorSlot, _commonStats.nodeId);
}
void HashAggStage::prepare(CompileCtx& ctx) {
_children[0]->prepare(ctx);
+ if (_collatorSlot) {
+ _collatorAccessor = getAccessor(ctx, *_collatorSlot);
+ tassert(5402501,
+ "collator accessor should exist if collator slot provided to HashAggStage",
+ _collatorAccessor != nullptr);
+ }
+
value::SlotSet dupCheck;
size_t counter = 0;
// Process group by columns.
@@ -107,8 +118,15 @@ void HashAggStage::open(bool reOpen) {
_commonStats.opens++;
_children[0]->open(reOpen);
- if (reOpen) {
- _ht.clear();
+ if (_collatorAccessor) {
+ auto [tag, collatorVal] = _collatorAccessor->getViewOfValue();
+ uassert(5402503, "collatorSlot must be of collator type", tag == value::TypeTags::collator);
+ auto collatorView = value::getCollatorView(collatorVal);
+ const value::MaterializedRowHasher hasher(collatorView);
+ const value::MaterializedRowEq equator(collatorView);
+ _ht.emplace(0, hasher, equator);
+ } else {
+ _ht.emplace();
}
while (_children[0]->getNext() == PlanState::ADVANCED) {
@@ -120,7 +138,7 @@ void HashAggStage::open(bool reOpen) {
key.reset(idx++, false, tag, val);
}
- auto [it, inserted] = _ht.try_emplace(std::move(key), value::MaterializedRow{0});
+ auto [it, inserted] = _ht->try_emplace(std::move(key), value::MaterializedRow{0});
if (inserted) {
// Copy keys.
const_cast<value::MaterializedRow&>(it->first).makeOwned();
@@ -138,19 +156,19 @@ void HashAggStage::open(bool reOpen) {
_children[0]->close();
- _htIt = _ht.end();
+ _htIt = _ht->end();
}
PlanState HashAggStage::getNext() {
auto optTimer(getOptTimer(_opCtx));
- if (_htIt == _ht.end()) {
- _htIt = _ht.begin();
+ if (_htIt == _ht->end()) {
+ _htIt = _ht->begin();
} else {
++_htIt;
}
- if (_htIt == _ht.end()) {
+ if (_htIt == _ht->end()) {
return trackPlanState(PlanState::IS_EOF);
}
@@ -185,7 +203,7 @@ void HashAggStage::close() {
auto optTimer(getOptTimer(_opCtx));
_commonStats.closes++;
- _ht.clear();
+ _ht = boost::none;
}
std::vector<DebugPrinter::Block> HashAggStage::debugPrint() const {
@@ -215,6 +233,10 @@ std::vector<DebugPrinter::Block> HashAggStage::debugPrint() const {
});
ret.emplace_back("`]");
+ if (_collatorSlot) {
+ DebugPrinter::addIdentifier(ret, *_collatorSlot);
+ }
+
DebugPrinter::addNewLine(ret);
DebugPrinter::addBlocks(ret, _children[0]->debugPrint());
diff --git a/src/mongo/db/exec/sbe/stages/hash_agg.h b/src/mongo/db/exec/sbe/stages/hash_agg.h
index 2c2ec5ab6aa..06c0e211b31 100644
--- a/src/mongo/db/exec/sbe/stages/hash_agg.h
+++ b/src/mongo/db/exec/sbe/stages/hash_agg.h
@@ -43,6 +43,7 @@ public:
HashAggStage(std::unique_ptr<PlanStage> input,
value::SlotVector gbs,
value::SlotMap<std::unique_ptr<EExpression>> aggs,
+ boost::optional<value::SlotId> collatorSlot,
PlanNodeId planNodeId);
std::unique_ptr<PlanStage> clone() const final;
@@ -68,6 +69,7 @@ private:
const value::SlotVector _gbs;
const value::SlotMap<std::unique_ptr<EExpression>> _aggs;
+ const boost::optional<value::SlotId> _collatorSlot;
value::SlotAccessorMap _outAccessors;
std::vector<value::SlotAccessor*> _inKeyAccessors;
@@ -76,8 +78,10 @@ private:
std::vector<std::unique_ptr<HashAggAccessor>> _outAggAccessors;
std::vector<std::unique_ptr<vm::CodeFragment>> _aggCodes;
- // TODO SERVER-54025: Update HashAggStage so that group-bys are collation-aware.
- TableType _ht;
+ // Only set if collator slot provided on construction.
+ value::SlotAccessor* _collatorAccessor = nullptr;
+
+ boost::optional<TableType> _ht;
TableType::iterator _htIt;
vm::ByteCode _bytecode;
diff --git a/src/mongo/db/exec/sbe/stages/hash_join.cpp b/src/mongo/db/exec/sbe/stages/hash_join.cpp
index e92bbebdab4..190b8d7dd05 100644
--- a/src/mongo/db/exec/sbe/stages/hash_join.cpp
+++ b/src/mongo/db/exec/sbe/stages/hash_join.cpp
@@ -42,12 +42,14 @@ HashJoinStage::HashJoinStage(std::unique_ptr<PlanStage> outer,
value::SlotVector outerProjects,
value::SlotVector innerCond,
value::SlotVector innerProjects,
+ boost::optional<value::SlotId> collatorSlot,
PlanNodeId planNodeId)
: PlanStage("hj"_sd, planNodeId),
_outerCond(std::move(outerCond)),
_outerProjects(std::move(outerProjects)),
_innerCond(std::move(innerCond)),
_innerProjects(std::move(innerProjects)),
+ _collatorSlot(collatorSlot),
_probeKey(0) {
if (_outerCond.size() != _innerCond.size()) {
uasserted(4822823, "left and right size do not match");
@@ -64,6 +66,7 @@ std::unique_ptr<PlanStage> HashJoinStage::clone() const {
_outerProjects,
_innerCond,
_innerProjects,
+ _collatorSlot,
_commonStats.nodeId);
}
@@ -71,6 +74,13 @@ void HashJoinStage::prepare(CompileCtx& ctx) {
_children[0]->prepare(ctx);
_children[1]->prepare(ctx);
+ if (_collatorSlot) {
+ _collatorAccessor = getAccessor(ctx, *_collatorSlot);
+ tassert(5402502,
+ "collator accessor should exist if collator slot provided to HashJoinStage",
+ _collatorAccessor != nullptr);
+ }
+
size_t counter = 0;
value::SlotSet dupCheck;
for (auto& slot : _outerCond) {
@@ -121,6 +131,17 @@ value::SlotAccessor* HashJoinStage::getAccessor(CompileCtx& ctx, value::SlotId s
void HashJoinStage::open(bool reOpen) {
auto optTimer(getOptTimer(_opCtx));
+ if (_collatorAccessor) {
+ auto [tag, collatorVal] = _collatorAccessor->getViewOfValue();
+ uassert(5402504, "collatorSlot must be of collator type", tag == value::TypeTags::collator);
+ auto collatorView = value::getCollatorView(collatorVal);
+ const value::MaterializedRowHasher hasher(collatorView);
+ const value::MaterializedRowEq equator(collatorView);
+ _ht.emplace(0, hasher, equator);
+ } else {
+ _ht.emplace();
+ }
+
_commonStats.opens++;
_children[0]->open(reOpen);
// Insert the outer side into the hash table.
@@ -142,15 +163,15 @@ void HashJoinStage::open(bool reOpen) {
project.reset(idx++, true, tag, val);
}
- _ht.emplace(std::move(key), std::move(project));
+ _ht->emplace(std::move(key), std::move(project));
}
_children[0]->close();
_children[1]->open(reOpen);
- _htIt = _ht.end();
- _htItEnd = _ht.end();
+ _htIt = _ht->end();
+ _htItEnd = _ht->end();
}
PlanState HashJoinStage::getNext() {
@@ -175,7 +196,7 @@ PlanState HashJoinStage::getNext() {
_probeKey.reset(idx++, false, tag, val);
}
- auto [low, hi] = _ht.equal_range(_probeKey);
+ auto [low, hi] = _ht->equal_range(_probeKey);
_htIt = low;
_htItEnd = hi;
// If _htIt == _htItEnd (i.e. no match) then RIGHT and OUTER joins
@@ -191,6 +212,7 @@ void HashJoinStage::close() {
_commonStats.closes++;
_children[1]->close();
+ _ht = boost::none;
}
std::unique_ptr<PlanStageStats> HashJoinStage::getStats(bool includeDebugInfo) const {
@@ -207,6 +229,10 @@ const SpecificStats* HashJoinStage::getSpecificStats() const {
std::vector<DebugPrinter::Block> HashJoinStage::debugPrint() const {
auto ret = PlanStage::debugPrint();
+ if (_collatorSlot) {
+ DebugPrinter::addIdentifier(ret, *_collatorSlot);
+ }
+
ret.emplace_back(DebugPrinter::Block::cmdIncIndent);
DebugPrinter::addKeyword(ret, "left");
diff --git a/src/mongo/db/exec/sbe/stages/hash_join.h b/src/mongo/db/exec/sbe/stages/hash_join.h
index f40af886a02..05ec1ec294a 100644
--- a/src/mongo/db/exec/sbe/stages/hash_join.h
+++ b/src/mongo/db/exec/sbe/stages/hash_join.h
@@ -43,6 +43,7 @@ public:
value::SlotVector outerProjects,
value::SlotVector innerCond,
value::SlotVector innerProjects,
+ boost::optional<value::SlotId> collatorSlot,
PlanNodeId planNodeId);
std::unique_ptr<PlanStage> clone() const final;
@@ -70,11 +71,12 @@ private:
const value::SlotVector _outerProjects;
const value::SlotVector _innerCond;
const value::SlotVector _innerProjects;
+ const boost::optional<value::SlotId> _collatorSlot;
// All defined values from the outer side (i.e. they come from the hash table).
value::SlotAccessorMap _outOuterAccessors;
- // Accessors of input codition values (keys) that are being inserted into the hash table.
+ // Accessors of input condition values (keys) that are being inserted into the hash table.
std::vector<value::SlotAccessor*> _inOuterKeyAccessors;
// Accessors of output keys.
@@ -86,15 +88,16 @@ private:
// Accessors of output projections.
std::vector<std::unique_ptr<HashProjectAccessor>> _outOuterProjectAccessors;
- // Accessors of input codition values (keys) that are being inserted into the hash table.
+ // Accessors of input condition values (keys) that are being inserted into the hash table.
std::vector<value::SlotAccessor*> _inInnerKeyAccessors;
+ // Accessor for collator. Only set if collatorSlot provided during construction.
+ value::SlotAccessor* _collatorAccessor = nullptr;
+
// Key used to probe inside the hash table.
value::MaterializedRow _probeKey;
- // TODO SERVER-54025: Update HashJoinStage so that it's mechanism for matching outer keys and
- // inner keys is collation-aware.
- TableType _ht;
+ boost::optional<TableType> _ht;
TableType::iterator _htIt;
TableType::iterator _htItEnd;
diff --git a/src/mongo/db/exec/sbe/values/slot.h b/src/mongo/db/exec/sbe/values/slot.h
index b5abc318a9c..f7cc78601cd 100644
--- a/src/mongo/db/exec/sbe/values/slot.h
+++ b/src/mongo/db/exec/sbe/values/slot.h
@@ -508,11 +508,16 @@ private:
};
struct MaterializedRowEq {
+ using ComparatorType = StringData::ComparatorInterface*;
+
+ explicit MaterializedRowEq(const ComparatorType comparator = nullptr)
+ : _comparator(comparator) {}
+
bool operator()(const MaterializedRow& lhs, const MaterializedRow& rhs) const {
for (size_t idx = 0; idx < lhs.size(); ++idx) {
auto [lhsTag, lhsVal] = lhs.getViewOfValue(idx);
auto [rhsTag, rhsVal] = rhs.getViewOfValue(idx);
- auto [tag, val] = compareValue(lhsTag, lhsVal, rhsTag, rhsVal);
+ auto [tag, val] = compareValue(lhsTag, lhsVal, rhsTag, rhsVal, _comparator);
if (tag != value::TypeTags::NumberInt32 || value::bitcastTo<int32_t>(val) != 0) {
return false;
@@ -521,17 +526,27 @@ struct MaterializedRowEq {
return true;
}
+
+private:
+ const ComparatorType _comparator = nullptr;
};
struct MaterializedRowHasher {
+ using CollatorType = CollatorInterface*;
+
+ explicit MaterializedRowHasher(const CollatorType collator = nullptr) : _collator(collator) {}
+
std::size_t operator()(const MaterializedRow& k) const {
size_t res = hashInit();
for (size_t idx = 0; idx < k.size(); ++idx) {
auto [tag, val] = k.getViewOfValue(idx);
- res = hashCombine(res, hashValue(tag, val));
+ res = hashCombine(res, hashValue(tag, val, _collator));
}
return res;
}
+
+private:
+ const CollatorType _collator = nullptr;
};
/**
diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp
index 3ff4deb6abd..709fab540c7 100644
--- a/src/mongo/db/query/sbe_stage_builder.cpp
+++ b/src/mongo/db/query/sbe_stage_builder.cpp
@@ -1237,6 +1237,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
auto innerCondSlots = sbe::makeSV(innerIdSlot);
auto innerProjectSlots = sbe::makeSV(innerResultSlot);
+ auto collatorSlot = _data.env->getSlotIfExists("collator"_sd);
+
// Designate outputs.
PlanStageSlots outputs(reqs, &_slotIdGenerator);
if (reqs.has(kRecordId)) {
@@ -1252,6 +1254,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
outerProjectSlots,
innerCondSlots,
innerProjectSlots,
+ collatorSlot,
root->nodeId());
// If there are more than 2 children, iterate all remaining children and hash
@@ -1271,6 +1274,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
projectSlots,
innerCondSlots,
innerProjectSlots,
+ collatorSlot,
root->nodeId());
}
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index 397fe5bdb0f..19e00bfb6be 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -1125,6 +1125,7 @@ public:
makeLimitTree(std::move(unionWithNullStage.stage), _context->planNodeId, numChildren);
// Create a group stage to aggregate elements into a single array.
+ auto collatorSlot = _context->runtimeEnvironment->getSlotIfExists("collator"_sd);
auto addToArrayExpr =
makeFunction("addToArray", sbe::makeE<sbe::EVariable>(unionWithNullSlot));
auto groupSlot = _context->slotIdGenerator->generate();
@@ -1132,6 +1133,7 @@ public:
sbe::makeS<sbe::HashAggStage>(std::move(limitNumChildren),
sbe::makeSV(),
sbe::makeEM(groupSlot, std::move(addToArrayExpr)),
+ collatorSlot,
_context->planNodeId);
EvalStage groupEvalStage = {std::move(groupStage), sbe::makeSV(groupSlot)};
@@ -1160,6 +1162,7 @@ public:
std::move(unwindEvalStage.stage),
sbe::makeSV(),
sbe::makeEM(finalGroupSlot, std::move(finalAddToArrayExpr)),
+ collatorSlot,
_context->planNodeId);
// Create a branch stage to select between the branch that produces one null if any eleemnts