summaryrefslogtreecommitdiff
path: root/src/mongo/db/query/optimizer/interval_simplify_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/query/optimizer/interval_simplify_test.cpp')
-rw-r--r--src/mongo/db/query/optimizer/interval_simplify_test.cpp177
1 files changed, 123 insertions, 54 deletions
diff --git a/src/mongo/db/query/optimizer/interval_simplify_test.cpp b/src/mongo/db/query/optimizer/interval_simplify_test.cpp
index 9454a896922..288ccec7c75 100644
--- a/src/mongo/db/query/optimizer/interval_simplify_test.cpp
+++ b/src/mongo/db/query/optimizer/interval_simplify_test.cpp
@@ -355,8 +355,12 @@ TEST_F(IntervalIntersection, VariableIntervals1) {
ASSERT_INTERVAL(
"{\n"
" {\n"
- " {[If [] BinaryOp [Gte] Variable [v2] Variable [v1] Const [maxKey] Variable [v1],"
- " If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable [v2]]}\n"
+ " {[If [] BinaryOp [And] BinaryOp [And] BinaryOp [Or] BinaryOp [Or] BinaryOp [And] "
+ "BinaryOp [Lt] Variable [v2] Variable [v1] Const [true] BinaryOp [And] BinaryOp [Lt] "
+ "Variable [v2] Const [maxKey] Const [true] BinaryOp [Or] BinaryOp [And] BinaryOp [Lt] "
+ "Variable [v1] Variable [v2] BinaryOp [Lt] Variable [v2] Const [maxKey] Const [true] "
+ "BinaryOp [Lt] Variable [v2] Const [maxKey] BinaryOp [Gt] Variable [v1] Variable [v2] "
+ "Variable [v1] Const [maxKey], Variable [v1]]}\n"
" }\n"
" U \n"
" {\n"
@@ -404,16 +408,19 @@ TEST_F(IntervalIntersection, VariableIntervals3) {
ASSERT_INTERVAL(
"{\n"
" {\n"
- " {[If [] BinaryOp [Gte] Variable [v1] Variable [v2] Const [maxKey] Variable "
- "[v2], If [] BinaryOp [Lte] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable "
- "[v1] Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] "
- "Variable [v4] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable "
- "[v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n"
+ " {[If [] BinaryOp [And] BinaryOp [And] BinaryOp [Or] BinaryOp [Or] BinaryOp [And] "
+ "BinaryOp [Lt] Variable [v2] Variable [v1] BinaryOp [Lt] Variable [v1] Variable [v4] "
+ "BinaryOp [And] BinaryOp [Lte] Variable [v2] Variable [v3] BinaryOp [Lte] Variable [v3] "
+ "Variable [v4] BinaryOp [Or] BinaryOp [And] BinaryOp [Lt] Variable [v1] Variable [v2] "
+ "BinaryOp [Lte] Variable [v2] Variable [v3] BinaryOp [And] BinaryOp [Lt] Variable [v1] "
+ "Variable [v4] BinaryOp [Lte] Variable [v4] Variable [v3] BinaryOp [And] BinaryOp [Lt] "
+ "Variable [v1] Variable [v3] BinaryOp [Lte] Variable [v2] Variable [v4] BinaryOp [Gt] "
+ "Variable [v2] Variable [v1] Variable [v2] Const [maxKey], Variable [v2]]}\n"
" }\n"
" U \n"
" {\n"
- " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable "
- "[v2], If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n"
+ " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable [v2], "
+ "If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n"
" }\n"
"}\n",
*result);
@@ -434,25 +441,30 @@ TEST_F(IntervalIntersection, VariableIntervals4) {
ASSERT_INTERVAL(
"{\n"
" {\n"
- " {[If [] BinaryOp [Gte] Variable [v1] Variable [v2] Const [maxKey] Variable "
- "[v2], If [] BinaryOp [Lte] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable "
- "[v1] Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] "
- "Variable [v4] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable "
- "[v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n"
+ " {[If [] BinaryOp [And] BinaryOp [And] BinaryOp [Or] BinaryOp [Or] BinaryOp [And] "
+ "BinaryOp [Lt] Variable [v2] Variable [v1] BinaryOp [Lt] Variable [v1] Variable [v4] "
+ "BinaryOp [And] BinaryOp [Lte] Variable [v2] Variable [v3] BinaryOp [Lt] Variable [v3] "
+ "Variable [v4] BinaryOp [Or] BinaryOp [And] BinaryOp [Lt] Variable [v1] Variable [v2] "
+ "BinaryOp [Lte] Variable [v2] Variable [v3] BinaryOp [And] BinaryOp [Lt] Variable [v1] "
+ "Variable [v4] BinaryOp [Lt] Variable [v4] Variable [v3] BinaryOp [And] BinaryOp [Lt] "
+ "Variable [v1] Variable [v3] BinaryOp [Lt] Variable [v2] Variable [v4] BinaryOp [Gt] "
+ "Variable [v2] Variable [v1] Variable [v2] Const [maxKey], Variable [v2]]}\n"
" }\n"
" U \n"
" {\n"
- " {[If [] BinaryOp [Gte] If [] BinaryOp [Gte] Variable [v1] Variable [v2] "
- "Variable [v1] Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable "
- "[v3] Variable [v4] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] "
- "Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable "
- "[v4], If [] BinaryOp [Lte] Variable [v4] Variable [v3] Const [minKey] Variable "
- "[v3]]}\n"
+ " {[Variable [v3], If [] BinaryOp [And] BinaryOp [And] BinaryOp [Or] BinaryOp [Or] "
+ "BinaryOp [And] BinaryOp [Lt] Variable [v2] Variable [v1] BinaryOp [Lt] Variable [v1] "
+ "Variable [v4] BinaryOp [And] BinaryOp [Lte] Variable [v2] Variable [v3] BinaryOp [Lt] "
+ "Variable [v3] Variable [v4] BinaryOp [Or] BinaryOp [And] BinaryOp [Lt] Variable [v1] "
+ "Variable [v2] BinaryOp [Lte] Variable [v2] Variable [v3] BinaryOp [And] BinaryOp [Lt] "
+ "Variable [v1] Variable [v4] BinaryOp [Lt] Variable [v4] Variable [v3] BinaryOp [And] "
+ "BinaryOp [Lt] Variable [v1] Variable [v3] BinaryOp [Lt] Variable [v2] Variable [v4] "
+ "BinaryOp [Lt] Variable [v3] Variable [v4] Variable [v3] Const [minKey]]}\n"
" }\n"
" U \n"
" {\n"
- " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable "
- "[v2], If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4])}\n"
+ " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable [v2], "
+ "If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4])}\n"
" }\n"
"}\n",
*result);
@@ -630,13 +642,27 @@ bool compareIntervals(const IntervalReqExpr::Node& original,
return transport.computeInclusion(original) == transport.computeInclusion(simplified);
}
+void constFoldBounds(IntervalReqExpr::Node& node) {
+ for (auto& disjunct : node.cast<IntervalReqExpr::Disjunction>()->nodes()) {
+ for (auto& conjunct : disjunct.cast<IntervalReqExpr::Conjunction>()->nodes()) {
+ auto& interval = conjunct.cast<IntervalReqExpr::Atom>()->getExpr();
+ ABT lowABT = interval.getLowBound().getBound();
+ ABT highABT = interval.getHighBound().getBound();
+ ConstEval::constFold(lowABT);
+ ConstEval::constFold(highABT);
+ interval = IntervalRequirement{
+ BoundRequirement{interval.getLowBound().isInclusive(), std::move(lowABT)},
+ BoundRequirement{interval.getHighBound().isInclusive(), std::move(highABT)},
+ };
+ }
+ }
+}
+
/*
* Create two random intervals composed of constants and test intersection/union on them.
*/
template <int N>
void testIntervalPermutation(int permutation) {
- auto prefixId = PrefixId::createForTests();
-
const bool low1Inc = decode<2>(permutation);
const int low1 = decode<N>(permutation);
const bool high1Inc = decode<2>(permutation);
@@ -645,17 +671,33 @@ void testIntervalPermutation(int permutation) {
const int low2 = decode<N>(permutation);
const bool high2Inc = decode<2>(permutation);
const int high2 = decode<N>(permutation);
+ const bool useRealConstFold = decode<2>(permutation);
+
+ // This function can be passed as a substitute for the real constant folding function, to test
+ // that our simplification methods work when we cannot constant fold anything.
+ const auto noOpConstFold = [](ABT& n) {
+ // No-op.
+ };
// Test intersection.
{
auto original = _disj(
_conj(_interval({low1Inc, Constant::int32(low1)}, {high1Inc, Constant::int32(high1)}),
_interval({low2Inc, Constant::int32(low2)}, {high2Inc, Constant::int32(high2)})));
- auto simplified = intersectDNFIntervals(original, ConstEval::constFold);
+ auto simplified = intersectDNFIntervals(
+ original, useRealConstFold ? ConstEval::constFold : noOpConstFold);
+
if (simplified) {
- // Since we are testing with constants, we should have at most one interval.
- ASSERT_TRUE(IntervalReqExpr::getSingularDNF(*simplified));
- compareIntervals<N>(original, *simplified);
+ if (useRealConstFold) {
+ // Since we are testing with constants, we should have at most one interval as long
+ // as we use real constant folding.
+ ASSERT_TRUE(IntervalReqExpr::getSingularDNF(*simplified));
+ } else {
+ // If we didn't use the real constant folding function, we have to constant fold
+ // now, because our bounds will have If's.
+ constFoldBounds(*simplified);
+ }
+ ASSERT(compareIntervals<N>(original, *simplified));
} else {
IntervalInclusionTransport<N> transport;
ASSERT(transport.computeInclusion(original) == ExtendedBitset<N>());
@@ -715,36 +757,38 @@ void testIntervalFuzz(const uint64_t seed, PseudoRandom& threadLocalRNG) {
}
EvalVariables varEval(std::move(varMap));
- // Create between one and five intervals.
+ // Create three intervals.
+ constexpr size_t numIntervals = 3;
- // TODO SERVER-71656 we should be able to uncomment this after SERVER-71656 is complete.
// Intersect with multiple intervals.
- // const size_t numIntervals = 2 + threadLocalRNG.nextInt32(5);
- // {
- // IntervalReqExpr::NodeVector intervalVec;
- // for (size_t i = 0; i < numIntervals; i++) {
- // intervalVec.push_back(IntervalReqExpr::make<IntervalReqExpr::Atom>(
- // IntervalRequirement{makeRandomBound(), makeRandomBound()}));
- // }
-
- // auto original =
- // IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{
- // IntervalReqExpr::make<IntervalReqExpr::Conjunction>(std::move(intervalVec))});
- // auto simplified = intersectDNFIntervals(original, ConstEval::constFold);
-
- // varEval.replaceVarsInInterval(original);
- // if (simplified) {
- // varEval.replaceVarsInInterval(*simplified);
- // compareIntervals<N>(original, *simplified);
- // } else {
- // IntervalInclusionTransport<N> transport;
- // ASSERT(transport.computeInclusion(original) == ExtendedBitset<N>());
- // }
- // }
+ {
+ IntervalReqExpr::NodeVector intervalVec;
+ for (size_t i = 0; i < numIntervals; i++) {
+ intervalVec.push_back(IntervalReqExpr::make<IntervalReqExpr::Atom>(
+ IntervalRequirement{makeRandomBound<N, true>(threadLocalRNG, vars),
+ makeRandomBound<N, false>(threadLocalRNG, vars)}));
+ }
+
+ auto original =
+ IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Conjunction>(std::move(intervalVec))});
+ auto simplified = intersectDNFIntervals(original, ConstEval::constFold);
+
+ varEval.replaceVarsInInterval(original);
+ if (simplified) {
+ varEval.replaceVarsInInterval(*simplified);
+ ASSERT(compareIntervals<N>(original, *simplified));
+ } else {
+ // 'simplified' false means the simplified interval is empty (always-false) and can't be
+ // represented as a BoolExpr, so check that 'original' returns false on every example.
+ IntervalInclusionTransport<N> transport;
+ ASSERT(transport.computeInclusion(original) == ExtendedBitset<N>());
+ }
+ }
// TODO SERVER-71175 should include a union test here
- // TODO SERVER-71656 should extend this to have DNF intervals that are not purely disjunctions
+ // TODO SERVER-71175 should extend this to have DNF intervals that are not purely disjunctions
// or purely conjunctions. A mix of ORs and ANDs seems necessary, to test that the two
// simplification methods (intersecting and unioning) don't interfere with each other.
}
@@ -753,8 +797,12 @@ static constexpr int bitsetSize = 10;
static const size_t numThreads = ProcessInfo::getNumCores();
TEST_F(IntervalIntersection, IntervalPermutations) {
+ // Number of permutations is bitsetSize^4 * 2^4 * 2
+ // The first term is needed because we generate four bounds to intersect two intervals. The
+ // second term is for the inclusivity of the four bounds. The third term is to determine if we
+ // test with real constant folding or a no-op constant folding function.
static constexpr int numPermutations =
- bitsetSize * bitsetSize * bitsetSize * bitsetSize * 2 * 2 * 2 * 2;
+ (bitsetSize * bitsetSize * bitsetSize * bitsetSize) * (2 * 2 * 2 * 2) * 2;
/**
* Test for interval intersection. Generate intervals with constants in the
* range of [0, N), with random inclusion/exclusion of the endpoints. Intersect the intervals
@@ -816,5 +864,26 @@ TEST_F(IntervalIntersection, IntervalFuzz) {
std::cout << "...done. Took: " << elapsedFuzz << " s.\n";
}
+TEST(IntervalIntersection, IntersectionSpecialCase) {
+ auto original = IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
+ {true, make<Variable>("var1")}, {true, make<Variable>("var1")}}),
+ IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
+ {false, make<Variable>("var2")}, {false, make<Variable>("var3")}})})});
+
+ auto simplified = intersectDNFIntervals(original, ConstEval::constFold);
+ ASSERT(simplified);
+
+ EvalVariables varEval({
+ {"var1", Constant::int32(3)},
+ {"var2", Constant::int32(0)},
+ {"var3", Constant::int32(3)},
+ });
+ varEval.replaceVarsInInterval(original);
+ varEval.replaceVarsInInterval(*simplified);
+ ASSERT(compareIntervals<bitsetSize>(original, *simplified));
+}
+
} // namespace
} // namespace mongo::optimizer