diff options
-rw-r--r-- | src/mongo/db/query/optimizer/SConscript | 4 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/interval_simplify_test.cpp (renamed from src/mongo/db/query/optimizer/interval_intersection_test.cpp) | 317 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/utils/interval_utils.h | 2 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/utils/unit_test_abt_literals.h | 12 |
4 files changed, 277 insertions, 58 deletions
diff --git a/src/mongo/db/query/optimizer/SConscript b/src/mongo/db/query/optimizer/SConscript index bc551c21af4..fd8661d9ec3 100644 --- a/src/mongo/db/query/optimizer/SConscript +++ b/src/mongo/db/query/optimizer/SConscript @@ -140,9 +140,9 @@ env.CppUnitTest( ) env.CppUnitTest( - target='interval_intersection_test', + target='interval_simplify_test', source=[ - "interval_intersection_test.cpp", + "interval_simplify_test.cpp", ], LIBDEPS=[ "optimizer", diff --git a/src/mongo/db/query/optimizer/interval_intersection_test.cpp b/src/mongo/db/query/optimizer/interval_simplify_test.cpp index 34c7fd0a105..31475035ffe 100644 --- a/src/mongo/db/query/optimizer/interval_intersection_test.cpp +++ b/src/mongo/db/query/optimizer/interval_simplify_test.cpp @@ -43,6 +43,7 @@ namespace mongo::optimizer { namespace { +using namespace unit_test_abt_literals; ABT optimizedQueryPlan(const std::string& query, const opt::unordered_map<std::string, IndexDefinition>& indexes) { @@ -341,8 +342,6 @@ TEST(IntervalIntersection, MultiFieldIntersection) { } TEST(IntervalIntersection, VariableIntervals1) { - using namespace unit_test_abt_literals; - auto interval = _disj( _conj(_interval(_incl("v1"_var), _plusInf()), _interval(_excl("v2"_var), _plusInf()))); @@ -370,8 +369,6 @@ TEST(IntervalIntersection, VariableIntervals1) { } TEST(IntervalIntersection, VariableIntervals2) { - using namespace unit_test_abt_literals; - auto interval = _disj(_conj(_interval(_incl("v1"_var), _incl("v3"_var)), _interval(_incl("v2"_var), _incl("v4"_var)))); @@ -395,8 +392,6 @@ TEST(IntervalIntersection, VariableIntervals2) { } TEST(IntervalIntersection, VariableIntervals3) { - using namespace unit_test_abt_literals; - auto interval = _disj(_conj(_interval(_excl("v1"_var), _incl("v3"_var)), _interval(_incl("v2"_var), _incl("v4"_var)))); @@ -427,8 +422,6 @@ TEST(IntervalIntersection, VariableIntervals3) { } TEST(IntervalIntersection, VariableIntervals4) { - using namespace unit_test_abt_literals; - auto interval = _disj(_conj(_interval(_excl("v1"_var), _incl("v3"_var)), _interval(_incl("v2"_var), _excl("v4"_var)))); @@ -467,37 +460,90 @@ TEST(IntervalIntersection, VariableIntervals4) { ASSERT_TRUE(*result == *result1); } +/* + * Bitset with extra flags to indicate whether MinKey and MaxKey are included. + * The first two bits are MinKey and MaxKey, the rest represent integers [0, N). + */ template <int N> -void updateResults(const bool lowInc, - const int low, - const bool highInc, - const int high, - std::bitset<N>& inclusion) { - for (int v = 0; v < low + (lowInc ? 0 : 1); v++) { - inclusion.set(v, false); +class ExtendedBitset { +public: + ExtendedBitset() {} + + void set(const int i, const bool b) { + invariant(i >= 0 && i < N); + _b.set(i + 2, b); } - for (int v = high + (highInc ? 1 : 0); v < N; v++) { - inclusion.set(v, false); + + static ExtendedBitset<N> minKey() { + ExtendedBitset<N> b; + b._b.set(0); + return b; } -} + static ExtendedBitset<N> maxKey() { + ExtendedBitset<N> b; + b._b.set(1); + return b; + } + + ExtendedBitset& operator&=(const ExtendedBitset& rhs) { + _b &= rhs._b; + return *this; + } + + ExtendedBitset& operator|=(const ExtendedBitset& rhs) { + _b |= rhs._b; + return *this; + } + + bool operator==(const ExtendedBitset& rhs) const { + return _b == rhs._b; + } + + bool operator!=(const ExtendedBitset& rhs) const { + return !(*this == rhs); + } + +private: + std::bitset<N + 2> _b; +}; + +/* + * Calculates the extended bitset of a given interval in any form (not just DNF). + */ template <int N> class IntervalInclusionTransport { public: - using ResultType = std::bitset<N>; + using ResultType = ExtendedBitset<N>; ResultType transport(const IntervalReqExpr::Atom& node) { const auto& expr = node.getExpr(); const auto& lb = expr.getLowBound(); const auto& hb = expr.getHighBound(); - std::bitset<N> result; - result.flip(); - updateResults<N>(lb.isInclusive(), - lb.getBound().cast<Constant>()->getValueInt32(), - hb.isInclusive(), - hb.getBound().cast<Constant>()->getValueInt32(), - result); + ExtendedBitset<N> result; + if (lb.getBound() == Constant::maxKey() || hb.getBound() == Constant::minKey()) { + return result; + } + + int lbInt = 0; + if (lb.getBound() == Constant::minKey()) { + result |= ExtendedBitset<N>::minKey(); + } else { + lbInt = lb.getBound().cast<Constant>()->getValueInt32() + (lb.isInclusive() ? 0 : 1); + } + + int hbInt = N; + if (hb.getBound() == Constant::maxKey()) { + result |= ExtendedBitset<N>::maxKey(); + } else { + hbInt = hb.getBound().cast<Constant>()->getValueInt32() + (hb.isInclusive() ? 1 : 0); + } + + for (int v = lbInt; v < hbInt; v++) { + result.set(v, true); + } + return result; } @@ -522,6 +568,51 @@ public: } }; +/* + * Replaces variables with their value in the given varMap. + */ +class EvalVariables { +public: + EvalVariables(ProjectionNameMap<ABT> varMap) : _varMap(std::move(varMap)) {} + + void transport(ABT& n, const Variable& node) { + const auto it = _varMap.find(ProjectionName(node.name().value())); + if (it != _varMap.end()) { + n = it->second; + } + } + + template <typename T, typename... Ts> + void transport(ABT& /*n*/, const T& /*node*/, Ts&&...) { + invariant((std::is_base_of_v<If, T> || std::is_base_of_v<BinaryOp, T> || + std::is_base_of_v<UnaryOp, T> || std::is_base_of_v<Constant, T> || + std::is_base_of_v<Variable, T>)); + } + + void evalVars(ABT& n) { + algebra::transport<true>(n, *this); + ConstEval::constFold(n); + invariant(n.is<Constant>()); + } + + void replaceVarsInInterval(IntervalReqExpr::Node& interval) { + for (auto& disjunct : interval.cast<IntervalReqExpr::Disjunction>()->nodes()) { + for (auto& conjunct : disjunct.cast<IntervalReqExpr::Conjunction>()->nodes()) { + auto& interval = conjunct.cast<IntervalReqExpr::Atom>()->getExpr(); + ABT lowBound = interval.getLowBound().getBound(); + ABT highBound = interval.getHighBound().getBound(); + evalVars(lowBound); + evalVars(highBound); + interval = {{interval.getLowBound().isInclusive(), std::move(lowBound)}, + {interval.getHighBound().isInclusive(), std::move(highBound)}}; + } + } + } + +private: + ProjectionNameMap<ABT> _varMap; +}; + template <int V> int decode(int& permutation) { const int result = permutation % V; @@ -530,7 +621,19 @@ int decode(int& permutation) { } template <int N> -void testInterval(int permutation) { +bool compareIntervals(const IntervalReqExpr::Node& original, + const IntervalReqExpr::Node& simplified) { + IntervalInclusionTransport<N> transport; + return transport.computeInclusion(original) == transport.computeInclusion(simplified); +} + +/* + * Create two random intervals composed of constants and test intersection/union on them. + */ +template <int N> +void testIntervalPermutation(int permutation) { + auto prefixId = PrefixId::createForTests(); + const bool low1Inc = decode<2>(permutation); const int low1 = decode<N>(permutation); const bool high1Inc = decode<2>(permutation); @@ -540,44 +643,120 @@ void testInterval(int permutation) { const bool high2Inc = decode<2>(permutation); const int high2 = decode<N>(permutation); - auto interval = IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{ - IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{ - IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ - {low1Inc, Constant::int32(low1)}, {high1Inc, Constant::int32(high1)}}), - IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{ - {low2Inc, Constant::int32(low2)}, {high2Inc, Constant::int32(high2)}})})}); - - auto result = intersectDNFIntervals(interval, ConstEval::constFold); - std::bitset<N> inclusionActual; - if (result) { - // Since we are testing with constants, we should have at most one interval. - ASSERT_TRUE(IntervalReqExpr::getSingularDNF(*result)); - - IntervalInclusionTransport<N> transport; - // Compute actual inclusion bitset based on the interval intersection code. - inclusionActual = transport.computeInclusion(*result); + // Test intersection. + { + auto original = _disj( + _conj(_interval({low1Inc, Constant::int32(low1)}, {high1Inc, Constant::int32(high1)}), + _interval({low2Inc, Constant::int32(low2)}, {high2Inc, Constant::int32(high2)}))); + auto simplified = intersectDNFIntervals(original, ConstEval::constFold); + if (simplified) { + // Since we are testing with constants, we should have at most one interval. + ASSERT_TRUE(IntervalReqExpr::getSingularDNF(*simplified)); + compareIntervals<N>(original, *simplified); + } else { + IntervalInclusionTransport<N> transport; + ASSERT(transport.computeInclusion(original) == ExtendedBitset<N>()); + } } +} - std::bitset<N> inclusionExpected; - inclusionExpected.flip(); +// Generates a random integer bound. If isLow is true, lower values are more likely. If isLow is +// false, higher values are more likely. +template <int N, bool isLow> +ABT makeRandomIntBound(PseudoRandom& threadLocalRNG) { + // This is a trick to create a skewed distribution on [0, N). Say N=3, + // potential values of r = 0 1 2 3 4 5 6 7 8 + // (int) sqrt(r) = 0 1 1 1 2 2 2 2 2 + // The higher the number is (as long as its <N), the more likely it is to occur. + const int r = threadLocalRNG.nextInt32(N * N); + const int bound = (int)std::sqrt(r); + invariant(0 <= bound && bound < N); + return Constant::int32(isLow ? N - 1 - bound : bound); +} - // Compute ground truth. - updateResults<N>(low1Inc, low1, high1Inc, high1, inclusionExpected); - updateResults<N>(low2Inc, low2, high2Inc, high2, inclusionExpected); +template <int N, bool isLow> +BoundRequirement makeRandomBound(PseudoRandom& threadLocalRNG, + const std::vector<ProjectionName>& vars) { + const bool isInclusive = threadLocalRNG.nextInt32(2); + // We can return one of: N+2 constants (+2 because of minkey and maxkey), or 8 variables. + const int r = threadLocalRNG.nextInt32(N + 10); + if (r == 0) { + return {isInclusive, Constant::minKey()}; + } else if (r == 1) { + return {isInclusive, Constant::maxKey()}; + } else if (r < N + 2) { + return {isInclusive, makeRandomIntBound<N, isLow>(threadLocalRNG)}; + } else { + return {isInclusive, make<Variable>(vars.at(r - (N + 2)))}; + } +}; - ASSERT_EQ(inclusionExpected, inclusionActual); +template <int N> +void testIntervalFuzz(const uint64_t seed, PseudoRandom& threadLocalRNG) { + // Generate values for the eight variables we have. + auto prefixId = PrefixId::createForTests(); + ProjectionNameMap<ABT> varMap; + std::vector<ProjectionName> vars; + for (size_t i = 0; i < 8; i++) { + // minkey=0, maxkey=1, anything else is a constant + const int type = threadLocalRNG.nextInt32(N + 2); + ABT val = Constant::int32(type - 2); + if (type == 0) { + val = Constant::minKey(); + } else if (type == 1) { + val = Constant::maxKey(); + } + ProjectionName var = prefixId.getNextId("var"); + varMap.emplace(var.value(), val); + vars.push_back(var); + } + EvalVariables varEval(std::move(varMap)); + + // Create between one and five intervals. + + // TODO SERVER-71656 we should be able to uncomment this after SERVER-71656 is complete. + // Intersect with multiple intervals. + // const size_t numIntervals = 2 + threadLocalRNG.nextInt32(5); + // { + // IntervalReqExpr::NodeVector intervalVec; + // for (size_t i = 0; i < numIntervals; i++) { + // intervalVec.push_back(IntervalReqExpr::make<IntervalReqExpr::Atom>( + // IntervalRequirement{makeRandomBound(), makeRandomBound()})); + // } + + // auto original = + // IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{ + // IntervalReqExpr::make<IntervalReqExpr::Conjunction>(std::move(intervalVec))}); + // auto simplified = intersectDNFIntervals(original, ConstEval::constFold); + + // varEval.replaceVarsInInterval(original); + // if (simplified) { + // varEval.replaceVarsInInterval(*simplified); + // compareIntervals<N>(original, *simplified); + // } else { + // IntervalInclusionTransport<N> transport; + // ASSERT(transport.computeInclusion(original) == ExtendedBitset<N>()); + // } + // } + + // TODO SERVER-71175 should include a union test here + + // TODO SERVER-71656 should extend this to have DNF intervals that are not purely disjunctions + // or purely conjunctions. A mix of ORs and ANDs seems necessary, to test that the two + // simplification methods (intersecting and unioning) don't interfere with each other. } -TEST(IntervalIntersection, IntervalPermutations) { - static constexpr int N = 10; - static constexpr int numPermutations = N * N * N * N * 2 * 2 * 2 * 2; +static constexpr int bitsetSize = 10; +static const size_t numThreads = ProcessInfo::getNumCores(); +TEST(IntervalIntersection, IntervalPermutations) { + static constexpr int numPermutations = + bitsetSize * bitsetSize * bitsetSize * bitsetSize * 2 * 2 * 2 * 2; /** * Test for interval intersection. Generate intervals with constants in the * range of [0, N), with random inclusion/exclusion of the endpoints. Intersect the intervals * and verify against ground truth. */ - const size_t numThreads = ProcessInfo::getNumCores(); std::cout << "Testing " << numPermutations << " interval permutations using " << numThreads << " cores...\n"; auto timeBegin = Date_t::now(); @@ -591,7 +770,7 @@ TEST(IntervalIntersection, IntervalPermutations) { if (nextP >= numPermutations) { break; } - testInterval<N>(nextP); + testIntervalPermutation<bitsetSize>(nextP); } }); } @@ -604,5 +783,35 @@ TEST(IntervalIntersection, IntervalPermutations) { std::cout << "...done. Took: " << elapsed << " s.\n"; } +TEST(IntervalIntersection, IntervalFuzz) { + static constexpr int numFuzzTests = 5000; + /** + * Generate random intervals with a mix of variables and constants, and test that they intersect + * and union correctly. + */ + std::cout << "Testing " << numFuzzTests << " fuzzed intervals using " << numThreads + << " cores...\n"; + const auto timeBeginFuzz = Date_t::now(); + + std::vector<stdx::thread> threads; + for (size_t i = 0; i < numThreads; i++) { + threads.emplace_back([]() { + const auto seed = SecureRandom().nextInt64(); + std::cout << "Using random seed: " << seed << "\n"; + PseudoRandom threadLocalRNG(seed); + for (size_t i = 0; i < numFuzzTests / numThreads; i++) { + testIntervalFuzz<bitsetSize>(seed, threadLocalRNG); + } + }); + } + for (auto& thread : threads) { + thread.join(); + } + + const auto elapsedFuzz = + (Date_t::now().toMillisSinceEpoch() - timeBeginFuzz.toMillisSinceEpoch()) / 1000.0; + std::cout << "...done. Took: " << elapsedFuzz << " s.\n"; +} + } // namespace } // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/interval_utils.h b/src/mongo/db/query/optimizer/utils/interval_utils.h index 90156e0aeb7..7b629174506 100644 --- a/src/mongo/db/query/optimizer/utils/interval_utils.h +++ b/src/mongo/db/query/optimizer/utils/interval_utils.h @@ -31,6 +31,8 @@ #include "mongo/db/query/optimizer/index_bounds.h" +#include "mongo/db/query/optimizer/utils/utils.h" + namespace mongo::optimizer { /** diff --git a/src/mongo/db/query/optimizer/utils/unit_test_abt_literals.h b/src/mongo/db/query/optimizer/utils/unit_test_abt_literals.h index 607765d46ce..a6b3ddafba4 100644 --- a/src/mongo/db/query/optimizer/utils/unit_test_abt_literals.h +++ b/src/mongo/db/query/optimizer/utils/unit_test_abt_literals.h @@ -419,18 +419,26 @@ private: /** * Interval expressions. */ +inline auto _disj(IntervalReqExpr::NodeVector v) { + return IntervalReqExpr::make<IntervalReqExpr::Disjunction>(std::move(v)); +} + template <typename... Ts> inline auto _disj(Ts&&... pack) { IntervalReqExpr::NodeVector v; (v.push_back(std::forward<Ts>(pack)), ...); - return IntervalReqExpr::make<IntervalReqExpr::Disjunction>(std::move(v)); + return _disj(std::move(v)); +} + +inline auto _conj(IntervalReqExpr::NodeVector v) { + return IntervalReqExpr::make<IntervalReqExpr::Conjunction>(std::move(v)); } template <typename... Ts> inline auto _conj(Ts&&... pack) { IntervalReqExpr::NodeVector v; (v.push_back(std::forward<Ts>(pack)), ...); - return IntervalReqExpr::make<IntervalReqExpr::Conjunction>(std::move(v)); + return _conj(std::move(v)); } inline auto _interval(IntervalRequirement req) { |