diff options
Diffstat (limited to 'src/mongo/db/query/optimizer/interval_simplify_test.cpp')
-rw-r--r-- | src/mongo/db/query/optimizer/interval_simplify_test.cpp | 177 |
1 files changed, 123 insertions, 54 deletions
diff --git a/src/mongo/db/query/optimizer/interval_simplify_test.cpp b/src/mongo/db/query/optimizer/interval_simplify_test.cpp index 9454a896922..288ccec7c75 100644 --- a/src/mongo/db/query/optimizer/interval_simplify_test.cpp +++ b/src/mongo/db/query/optimizer/interval_simplify_test.cpp @@ -355,8 +355,12 @@ TEST_F(IntervalIntersection, VariableIntervals1) { ASSERT_INTERVAL( "{\n" " {\n" - " {[If [] BinaryOp [Gte] Variable [v2] Variable [v1] Const [maxKey] Variable [v1]," - " If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable [v2]]}\n" + " {[If [] BinaryOp [And] BinaryOp [And] BinaryOp [Or] BinaryOp [Or] BinaryOp [And] " + "BinaryOp [Lt] Variable [v2] Variable [v1] Const [true] BinaryOp [And] BinaryOp [Lt] " + "Variable [v2] Const [maxKey] Const [true] BinaryOp [Or] BinaryOp [And] BinaryOp [Lt] " + "Variable [v1] Variable [v2] BinaryOp [Lt] Variable [v2] Const [maxKey] Const [true] " + "BinaryOp [Lt] Variable [v2] Const [maxKey] BinaryOp [Gt] Variable [v1] Variable [v2] " + "Variable [v1] Const [maxKey], Variable [v1]]}\n" " }\n" " U \n" " {\n" @@ -404,16 +408,19 @@ TEST_F(IntervalIntersection, VariableIntervals3) { ASSERT_INTERVAL( "{\n" " {\n" - " {[If [] BinaryOp [Gte] Variable [v1] Variable [v2] Const [maxKey] Variable " - "[v2], If [] BinaryOp [Lte] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable " - "[v1] Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] " - "Variable [v4] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable " - "[v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n" + " {[If [] BinaryOp [And] BinaryOp [And] BinaryOp [Or] BinaryOp [Or] BinaryOp [And] " + "BinaryOp [Lt] Variable [v2] Variable [v1] BinaryOp [Lt] Variable [v1] Variable [v4] " + "BinaryOp [And] BinaryOp [Lte] Variable [v2] Variable [v3] BinaryOp [Lte] Variable [v3] " + "Variable [v4] BinaryOp [Or] BinaryOp [And] BinaryOp [Lt] Variable [v1] Variable [v2] " + "BinaryOp [Lte] Variable [v2] Variable [v3] BinaryOp [And] BinaryOp [Lt] Variable [v1] " + "Variable [v4] BinaryOp [Lte] Variable [v4] Variable [v3] BinaryOp [And] BinaryOp [Lt] " + "Variable [v1] Variable [v3] BinaryOp [Lte] Variable [v2] Variable [v4] BinaryOp [Gt] " + "Variable [v2] Variable [v1] Variable [v2] Const [maxKey], Variable [v2]]}\n" " }\n" " U \n" " {\n" - " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable " - "[v2], If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n" + " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable [v2], " + "If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n" " }\n" "}\n", *result); @@ -434,25 +441,30 @@ TEST_F(IntervalIntersection, VariableIntervals4) { ASSERT_INTERVAL( "{\n" " {\n" - " {[If [] BinaryOp [Gte] Variable [v1] Variable [v2] Const [maxKey] Variable " - "[v2], If [] BinaryOp [Lte] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable " - "[v1] Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] " - "Variable [v4] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable " - "[v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4]]}\n" + " {[If [] BinaryOp [And] BinaryOp [And] BinaryOp [Or] BinaryOp [Or] BinaryOp [And] " + "BinaryOp [Lt] Variable [v2] Variable [v1] BinaryOp [Lt] Variable [v1] Variable [v4] " + "BinaryOp [And] BinaryOp [Lte] Variable [v2] Variable [v3] BinaryOp [Lt] Variable [v3] " + "Variable [v4] BinaryOp [Or] BinaryOp [And] BinaryOp [Lt] Variable [v1] Variable [v2] " + "BinaryOp [Lte] Variable [v2] Variable [v3] BinaryOp [And] BinaryOp [Lt] Variable [v1] " + "Variable [v4] BinaryOp [Lt] Variable [v4] Variable [v3] BinaryOp [And] BinaryOp [Lt] " + "Variable [v1] Variable [v3] BinaryOp [Lt] Variable [v2] Variable [v4] BinaryOp [Gt] " + "Variable [v2] Variable [v1] Variable [v2] Const [maxKey], Variable [v2]]}\n" " }\n" " U \n" " {\n" - " {[If [] BinaryOp [Gte] If [] BinaryOp [Gte] Variable [v1] Variable [v2] " - "Variable [v1] Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable " - "[v3] Variable [v4] If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] " - "Variable [v2] If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable " - "[v4], If [] BinaryOp [Lte] Variable [v4] Variable [v3] Const [minKey] Variable " - "[v3]]}\n" + " {[Variable [v3], If [] BinaryOp [And] BinaryOp [And] BinaryOp [Or] BinaryOp [Or] " + "BinaryOp [And] BinaryOp [Lt] Variable [v2] Variable [v1] BinaryOp [Lt] Variable [v1] " + "Variable [v4] BinaryOp [And] BinaryOp [Lte] Variable [v2] Variable [v3] BinaryOp [Lt] " + "Variable [v3] Variable [v4] BinaryOp [Or] BinaryOp [And] BinaryOp [Lt] Variable [v1] " + "Variable [v2] BinaryOp [Lte] Variable [v2] Variable [v3] BinaryOp [And] BinaryOp [Lt] " + "Variable [v1] Variable [v4] BinaryOp [Lt] Variable [v4] Variable [v3] BinaryOp [And] " + "BinaryOp [Lt] Variable [v1] Variable [v3] BinaryOp [Lt] Variable [v2] Variable [v4] " + "BinaryOp [Lt] Variable [v3] Variable [v4] Variable [v3] Const [minKey]]}\n" " }\n" " U \n" " {\n" - " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable " - "[v2], If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4])}\n" + " {(If [] BinaryOp [Gte] Variable [v1] Variable [v2] Variable [v1] Variable [v2], " + "If [] BinaryOp [Lte] Variable [v3] Variable [v4] Variable [v3] Variable [v4])}\n" " }\n" "}\n", *result); @@ -630,13 +642,27 @@ bool compareIntervals(const IntervalReqExpr::Node& original, return transport.computeInclusion(original) == transport.computeInclusion(simplified); } +void constFoldBounds(IntervalReqExpr::Node& node) { + for (auto& disjunct : node.cast<IntervalReqExpr::Disjunction>()->nodes()) { + for (auto& conjunct : disjunct.cast<IntervalReqExpr::Conjunction>()->nodes()) { + auto& interval = conjunct.cast<IntervalReqExpr::Atom>()->getExpr(); + ABT lowABT = interval.getLowBound().getBound(); + ABT highABT = interval.getHighBound().getBound(); + ConstEval::constFold(lowABT); + ConstEval::constFold(highABT); + interval = IntervalRequirement{ + BoundRequirement{interval.getLowBound().isInclusive(), std::move(lowABT)}, + BoundRequirement{interval.getHighBound().isInclusive(), std::move(highABT)}, + }; + } + } +} + /* * Create two random intervals composed of constants and test intersection/union on them. */ template <int N> void testIntervalPermutation(int permutation) { - auto prefixId = PrefixId::createForTests(); - const bool low1Inc = decode<2>(permutation); const int low1 = decode<N>(permutation); const bool high1Inc = decode<2>(permutation); @@ -645,17 +671,33 @@ void testIntervalPermutation(int permutation) { const int low2 = decode<N>(permutation); const bool high2Inc = decode<2>(permutation); const int high2 = decode<N>(permutation); + const bool useRealConstFold = decode<2>(permutation); + + // This function can be passed as a substitute for the real constant folding function, to test + // that our simplification methods work when we cannot constant fold anything. + const auto noOpConstFold = [](ABT& n) { + // No-op. + }; // Test intersection. { auto original = _disj( _conj(_interval({low1Inc, Constant::int32(low1)}, {high1Inc, Constant::int32(high1)}), _interval({low2Inc, Constant::int32(low2)}, {high2Inc, Constant::int32(high2)}))); - auto simplified = intersectDNFIntervals(original, ConstEval::constFold); + auto simplified = intersectDNFIntervals( + original, useRealConstFold ? ConstEval::constFold : noOpConstFold); + if (simplified) { - // Since we are testing with constants, we should have at most one interval. - ASSERT_TRUE(IntervalReqExpr::getSingularDNF(*simplified)); - compareIntervals<N>(original, *simplified); + if (useRealConstFold) { + // Since we are testing with constants, we should have at most one interval as long + // as we use real constant folding. + ASSERT_TRUE(IntervalReqExpr::getSingularDNF(*simplified)); + } else { + // If we didn't use the real constant folding function, we have to constant fold + // now, because our bounds will have If's. + constFoldBounds(*simplified); + } + ASSERT(compareIntervals<N>(original, *simplified)); } else { IntervalInclusionTransport<N> transport; ASSERT(transport.computeInclusion(original) == ExtendedBitset<N>()); @@ -715,36 +757,38 @@ void testIntervalFuzz(const uint64_t seed, PseudoRandom& threadLocalRNG) { } EvalVariables varEval(std::move(varMap)); - // Create between one and five intervals. + // Create three intervals. + constexpr size_t numIntervals = 3; - // TODO SERVER-71656 we should be able to uncomment this after SERVER-71656 is complete. // Intersect with multiple intervals. - // const size_t numIntervals = 2 + threadLocalRNG.nextInt32(5); - // { - // IntervalReqExpr::NodeVector intervalVec; - // for (size_t i = 0; i < numIntervals; i++) { - // intervalVec.push_back(IntervalReqExpr::make<IntervalReqExpr::Atom>( - // IntervalRequirement{makeRandomBound(), makeRandomBound()})); - // } - - // auto original = - // IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{ - // IntervalReqExpr::make<IntervalReqExpr::Conjunction>(std::move(intervalVec))}); - // auto simplified = intersectDNFIntervals(original, ConstEval::constFold); - - // varEval.replaceVarsInInterval(original); - // if (simplified) { - // varEval.replaceVarsInInterval(*simplified); - // compareIntervals<N>(original, *simplified); - // } else { - // IntervalInclusionTransport<N> transport; - // ASSERT(transport.computeInclusion(original) == ExtendedBitset<N>()); - // } - // } + { + IntervalReqExpr::NodeVector intervalVec; + for (size_t i = 0; i < numIntervals; i++) { + intervalVec.push_back(IntervalReqExpr::make<IntervalReqExpr::Atom>( + IntervalRequirement{makeRandomBound<N, true>(threadLocalRNG, vars), + makeRandomBound<N, false>(threadLocalRNG, vars)})); + } + + auto original = + IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Conjunction>(std::move(intervalVec))}); + auto simplified = intersectDNFIntervals(original, ConstEval::constFold); + + varEval.replaceVarsInInterval(original); + if (simplified) { + varEval.replaceVarsInInterval(*simplified); + ASSERT(compareIntervals<N>(original, *simplified)); + } else { + // 'simplified' false means the simplified interval is empty (always-false) and can't be + // represented as a BoolExpr, so check that 'original' returns false on every example. + IntervalInclusionTransport<N> transport; + ASSERT(transport.computeInclusion(original) == ExtendedBitset<N>()); + } + } // TODO SERVER-71175 should include a union test here - // TODO SERVER-71656 should extend this to have DNF intervals that are not purely disjunctions + // TODO SERVER-71175 should extend this to have DNF intervals that are not purely disjunctions // or purely conjunctions. A mix of ORs and ANDs seems necessary, to test that the two // simplification methods (intersecting and unioning) don't interfere with each other. } @@ -753,8 +797,12 @@ static constexpr int bitsetSize = 10; static const size_t numThreads = ProcessInfo::getNumCores(); TEST_F(IntervalIntersection, IntervalPermutations) { + // Number of permutations is bitsetSize^4 * 2^4 * 2 + // The first term is needed because we generate four bounds to intersect two intervals. The + // second term is for the inclusivity of the four bounds. The third term is to determine if we + // test with real constant folding or a no-op constant folding function. static constexpr int numPermutations = - bitsetSize * bitsetSize * bitsetSize * bitsetSize * 2 * 2 * 2 * 2; + (bitsetSize * bitsetSize * bitsetSize * bitsetSize) * (2 * 2 * 2 * 2) * 2; /** * Test for interval intersection. Generate intervals with constants in the * range of [0, N), with random inclusion/exclusion of the endpoints. Intersect the intervals @@ -816,5 +864,26 @@ TEST_F(IntervalIntersection, IntervalFuzz) { std::cout << "...done. Took: " << elapsedFuzz << " s.\n"; } +TEST(IntervalIntersection, IntersectionSpecialCase) { + auto original = IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{ + IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ + {true, make<Variable>("var1")}, {true, make<Variable>("var1")}}), + IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ + {false, make<Variable>("var2")}, {false, make<Variable>("var3")}})})}); + + auto simplified = intersectDNFIntervals(original, ConstEval::constFold); + ASSERT(simplified); + + EvalVariables varEval({ + {"var1", Constant::int32(3)}, + {"var2", Constant::int32(0)}, + {"var3", Constant::int32(3)}, + }); + varEval.replaceVarsInInterval(original); + varEval.replaceVarsInInterval(*simplified); + ASSERT(compareIntervals<bitsetSize>(original, *simplified)); +} + } // namespace } // namespace mongo::optimizer |