summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/mongo/db/query/optimizer/SConscript4
-rw-r--r--src/mongo/db/query/optimizer/interval_simplify_test.cpp (renamed from src/mongo/db/query/optimizer/interval_intersection_test.cpp)317
-rw-r--r--src/mongo/db/query/optimizer/utils/interval_utils.h2
-rw-r--r--src/mongo/db/query/optimizer/utils/unit_test_abt_literals.h12
4 files changed, 277 insertions, 58 deletions
diff --git a/src/mongo/db/query/optimizer/SConscript b/src/mongo/db/query/optimizer/SConscript
index bc551c21af4..fd8661d9ec3 100644
--- a/src/mongo/db/query/optimizer/SConscript
+++ b/src/mongo/db/query/optimizer/SConscript
@@ -140,9 +140,9 @@ env.CppUnitTest(
)
env.CppUnitTest(
- target='interval_intersection_test',
+ target='interval_simplify_test',
source=[
- "interval_intersection_test.cpp",
+ "interval_simplify_test.cpp",
],
LIBDEPS=[
"optimizer",
diff --git a/src/mongo/db/query/optimizer/interval_intersection_test.cpp b/src/mongo/db/query/optimizer/interval_simplify_test.cpp
index 34c7fd0a105..31475035ffe 100644
--- a/src/mongo/db/query/optimizer/interval_intersection_test.cpp
+++ b/src/mongo/db/query/optimizer/interval_simplify_test.cpp
@@ -43,6 +43,7 @@
namespace mongo::optimizer {
namespace {
+using namespace unit_test_abt_literals;
ABT optimizedQueryPlan(const std::string& query,
const opt::unordered_map<std::string, IndexDefinition>& indexes) {
@@ -341,8 +342,6 @@ TEST(IntervalIntersection, MultiFieldIntersection) {
}
TEST(IntervalIntersection, VariableIntervals1) {
- using namespace unit_test_abt_literals;
-
auto interval = _disj(
_conj(_interval(_incl("v1"_var), _plusInf()), _interval(_excl("v2"_var), _plusInf())));
@@ -370,8 +369,6 @@ TEST(IntervalIntersection, VariableIntervals1) {
}
TEST(IntervalIntersection, VariableIntervals2) {
- using namespace unit_test_abt_literals;
-
auto interval = _disj(_conj(_interval(_incl("v1"_var), _incl("v3"_var)),
_interval(_incl("v2"_var), _incl("v4"_var))));
@@ -395,8 +392,6 @@ TEST(IntervalIntersection, VariableIntervals2) {
}
TEST(IntervalIntersection, VariableIntervals3) {
- using namespace unit_test_abt_literals;
-
auto interval = _disj(_conj(_interval(_excl("v1"_var), _incl("v3"_var)),
_interval(_incl("v2"_var), _incl("v4"_var))));
@@ -427,8 +422,6 @@ TEST(IntervalIntersection, VariableIntervals3) {
}
TEST(IntervalIntersection, VariableIntervals4) {
- using namespace unit_test_abt_literals;
-
auto interval = _disj(_conj(_interval(_excl("v1"_var), _incl("v3"_var)),
_interval(_incl("v2"_var), _excl("v4"_var))));
@@ -467,37 +460,90 @@ TEST(IntervalIntersection, VariableIntervals4) {
ASSERT_TRUE(*result == *result1);
}
+/*
+ * Bitset with extra flags to indicate whether MinKey and MaxKey are included.
+ * The first two bits are MinKey and MaxKey, the rest represent integers [0, N).
+ */
template <int N>
-void updateResults(const bool lowInc,
- const int low,
- const bool highInc,
- const int high,
- std::bitset<N>& inclusion) {
- for (int v = 0; v < low + (lowInc ? 0 : 1); v++) {
- inclusion.set(v, false);
+class ExtendedBitset {
+public:
+ ExtendedBitset() {}
+
+ void set(const int i, const bool b) {
+ invariant(i >= 0 && i < N);
+ _b.set(i + 2, b);
}
- for (int v = high + (highInc ? 1 : 0); v < N; v++) {
- inclusion.set(v, false);
+
+ static ExtendedBitset<N> minKey() {
+ ExtendedBitset<N> b;
+ b._b.set(0);
+ return b;
}
-}
+ static ExtendedBitset<N> maxKey() {
+ ExtendedBitset<N> b;
+ b._b.set(1);
+ return b;
+ }
+
+ ExtendedBitset& operator&=(const ExtendedBitset& rhs) {
+ _b &= rhs._b;
+ return *this;
+ }
+
+ ExtendedBitset& operator|=(const ExtendedBitset& rhs) {
+ _b |= rhs._b;
+ return *this;
+ }
+
+ bool operator==(const ExtendedBitset& rhs) const {
+ return _b == rhs._b;
+ }
+
+ bool operator!=(const ExtendedBitset& rhs) const {
+ return !(*this == rhs);
+ }
+
+private:
+ std::bitset<N + 2> _b;
+};
+
+/*
+ * Calculates the extended bitset of a given interval in any form (not just DNF).
+ */
template <int N>
class IntervalInclusionTransport {
public:
- using ResultType = std::bitset<N>;
+ using ResultType = ExtendedBitset<N>;
ResultType transport(const IntervalReqExpr::Atom& node) {
const auto& expr = node.getExpr();
const auto& lb = expr.getLowBound();
const auto& hb = expr.getHighBound();
- std::bitset<N> result;
- result.flip();
- updateResults<N>(lb.isInclusive(),
- lb.getBound().cast<Constant>()->getValueInt32(),
- hb.isInclusive(),
- hb.getBound().cast<Constant>()->getValueInt32(),
- result);
+ ExtendedBitset<N> result;
+ if (lb.getBound() == Constant::maxKey() || hb.getBound() == Constant::minKey()) {
+ return result;
+ }
+
+ int lbInt = 0;
+ if (lb.getBound() == Constant::minKey()) {
+ result |= ExtendedBitset<N>::minKey();
+ } else {
+ lbInt = lb.getBound().cast<Constant>()->getValueInt32() + (lb.isInclusive() ? 0 : 1);
+ }
+
+ int hbInt = N;
+ if (hb.getBound() == Constant::maxKey()) {
+ result |= ExtendedBitset<N>::maxKey();
+ } else {
+ hbInt = hb.getBound().cast<Constant>()->getValueInt32() + (hb.isInclusive() ? 1 : 0);
+ }
+
+ for (int v = lbInt; v < hbInt; v++) {
+ result.set(v, true);
+ }
+
return result;
}
@@ -522,6 +568,51 @@ public:
}
};
+/*
+ * Replaces variables with their value in the given varMap.
+ */
+class EvalVariables {
+public:
+ EvalVariables(ProjectionNameMap<ABT> varMap) : _varMap(std::move(varMap)) {}
+
+ void transport(ABT& n, const Variable& node) {
+ const auto it = _varMap.find(ProjectionName(node.name().value()));
+ if (it != _varMap.end()) {
+ n = it->second;
+ }
+ }
+
+ template <typename T, typename... Ts>
+ void transport(ABT& /*n*/, const T& /*node*/, Ts&&...) {
+ invariant((std::is_base_of_v<If, T> || std::is_base_of_v<BinaryOp, T> ||
+ std::is_base_of_v<UnaryOp, T> || std::is_base_of_v<Constant, T> ||
+ std::is_base_of_v<Variable, T>));
+ }
+
+ void evalVars(ABT& n) {
+ algebra::transport<true>(n, *this);
+ ConstEval::constFold(n);
+ invariant(n.is<Constant>());
+ }
+
+ void replaceVarsInInterval(IntervalReqExpr::Node& interval) {
+ for (auto& disjunct : interval.cast<IntervalReqExpr::Disjunction>()->nodes()) {
+ for (auto& conjunct : disjunct.cast<IntervalReqExpr::Conjunction>()->nodes()) {
+ auto& interval = conjunct.cast<IntervalReqExpr::Atom>()->getExpr();
+ ABT lowBound = interval.getLowBound().getBound();
+ ABT highBound = interval.getHighBound().getBound();
+ evalVars(lowBound);
+ evalVars(highBound);
+ interval = {{interval.getLowBound().isInclusive(), std::move(lowBound)},
+ {interval.getHighBound().isInclusive(), std::move(highBound)}};
+ }
+ }
+ }
+
+private:
+ ProjectionNameMap<ABT> _varMap;
+};
+
template <int V>
int decode(int& permutation) {
const int result = permutation % V;
@@ -530,7 +621,19 @@ int decode(int& permutation) {
}
template <int N>
-void testInterval(int permutation) {
+bool compareIntervals(const IntervalReqExpr::Node& original,
+ const IntervalReqExpr::Node& simplified) {
+ IntervalInclusionTransport<N> transport;
+ return transport.computeInclusion(original) == transport.computeInclusion(simplified);
+}
+
+/*
+ * Create two random intervals composed of constants and test intersection/union on them.
+ */
+template <int N>
+void testIntervalPermutation(int permutation) {
+ auto prefixId = PrefixId::createForTests();
+
const bool low1Inc = decode<2>(permutation);
const int low1 = decode<N>(permutation);
const bool high1Inc = decode<2>(permutation);
@@ -540,44 +643,120 @@ void testInterval(int permutation) {
const bool high2Inc = decode<2>(permutation);
const int high2 = decode<N>(permutation);
- auto interval = IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{
- IntervalReqExpr::make<IntervalReqExpr::Conjunction>(IntervalReqExpr::NodeVector{
- IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
- {low1Inc, Constant::int32(low1)}, {high1Inc, Constant::int32(high1)}}),
- IntervalReqExpr::make<IntervalReqExpr::Atom>(IntervalRequirement{
- {low2Inc, Constant::int32(low2)}, {high2Inc, Constant::int32(high2)}})})});
-
- auto result = intersectDNFIntervals(interval, ConstEval::constFold);
- std::bitset<N> inclusionActual;
- if (result) {
- // Since we are testing with constants, we should have at most one interval.
- ASSERT_TRUE(IntervalReqExpr::getSingularDNF(*result));
-
- IntervalInclusionTransport<N> transport;
- // Compute actual inclusion bitset based on the interval intersection code.
- inclusionActual = transport.computeInclusion(*result);
+ // Test intersection.
+ {
+ auto original = _disj(
+ _conj(_interval({low1Inc, Constant::int32(low1)}, {high1Inc, Constant::int32(high1)}),
+ _interval({low2Inc, Constant::int32(low2)}, {high2Inc, Constant::int32(high2)})));
+ auto simplified = intersectDNFIntervals(original, ConstEval::constFold);
+ if (simplified) {
+ // Since we are testing with constants, we should have at most one interval.
+ ASSERT_TRUE(IntervalReqExpr::getSingularDNF(*simplified));
+ compareIntervals<N>(original, *simplified);
+ } else {
+ IntervalInclusionTransport<N> transport;
+ ASSERT(transport.computeInclusion(original) == ExtendedBitset<N>());
+ }
}
+}
- std::bitset<N> inclusionExpected;
- inclusionExpected.flip();
+// Generates a random integer bound. If isLow is true, lower values are more likely. If isLow is
+// false, higher values are more likely.
+template <int N, bool isLow>
+ABT makeRandomIntBound(PseudoRandom& threadLocalRNG) {
+ // This is a trick to create a skewed distribution on [0, N). Say N=3,
+ // potential values of r = 0 1 2 3 4 5 6 7 8
+ // (int) sqrt(r) = 0 1 1 1 2 2 2 2 2
+ // The higher the number is (as long as its <N), the more likely it is to occur.
+ const int r = threadLocalRNG.nextInt32(N * N);
+ const int bound = (int)std::sqrt(r);
+ invariant(0 <= bound && bound < N);
+ return Constant::int32(isLow ? N - 1 - bound : bound);
+}
- // Compute ground truth.
- updateResults<N>(low1Inc, low1, high1Inc, high1, inclusionExpected);
- updateResults<N>(low2Inc, low2, high2Inc, high2, inclusionExpected);
+template <int N, bool isLow>
+BoundRequirement makeRandomBound(PseudoRandom& threadLocalRNG,
+ const std::vector<ProjectionName>& vars) {
+ const bool isInclusive = threadLocalRNG.nextInt32(2);
+ // We can return one of: N+2 constants (+2 because of minkey and maxkey), or 8 variables.
+ const int r = threadLocalRNG.nextInt32(N + 10);
+ if (r == 0) {
+ return {isInclusive, Constant::minKey()};
+ } else if (r == 1) {
+ return {isInclusive, Constant::maxKey()};
+ } else if (r < N + 2) {
+ return {isInclusive, makeRandomIntBound<N, isLow>(threadLocalRNG)};
+ } else {
+ return {isInclusive, make<Variable>(vars.at(r - (N + 2)))};
+ }
+};
- ASSERT_EQ(inclusionExpected, inclusionActual);
+template <int N>
+void testIntervalFuzz(const uint64_t seed, PseudoRandom& threadLocalRNG) {
+ // Generate values for the eight variables we have.
+ auto prefixId = PrefixId::createForTests();
+ ProjectionNameMap<ABT> varMap;
+ std::vector<ProjectionName> vars;
+ for (size_t i = 0; i < 8; i++) {
+ // minkey=0, maxkey=1, anything else is a constant
+ const int type = threadLocalRNG.nextInt32(N + 2);
+ ABT val = Constant::int32(type - 2);
+ if (type == 0) {
+ val = Constant::minKey();
+ } else if (type == 1) {
+ val = Constant::maxKey();
+ }
+ ProjectionName var = prefixId.getNextId("var");
+ varMap.emplace(var.value(), val);
+ vars.push_back(var);
+ }
+ EvalVariables varEval(std::move(varMap));
+
+ // Create between one and five intervals.
+
+ // TODO SERVER-71656 we should be able to uncomment this after SERVER-71656 is complete.
+ // Intersect with multiple intervals.
+ // const size_t numIntervals = 2 + threadLocalRNG.nextInt32(5);
+ // {
+ // IntervalReqExpr::NodeVector intervalVec;
+ // for (size_t i = 0; i < numIntervals; i++) {
+ // intervalVec.push_back(IntervalReqExpr::make<IntervalReqExpr::Atom>(
+ // IntervalRequirement{makeRandomBound(), makeRandomBound()}));
+ // }
+
+ // auto original =
+ // IntervalReqExpr::make<IntervalReqExpr::Disjunction>(IntervalReqExpr::NodeVector{
+ // IntervalReqExpr::make<IntervalReqExpr::Conjunction>(std::move(intervalVec))});
+ // auto simplified = intersectDNFIntervals(original, ConstEval::constFold);
+
+ // varEval.replaceVarsInInterval(original);
+ // if (simplified) {
+ // varEval.replaceVarsInInterval(*simplified);
+ // compareIntervals<N>(original, *simplified);
+ // } else {
+ // IntervalInclusionTransport<N> transport;
+ // ASSERT(transport.computeInclusion(original) == ExtendedBitset<N>());
+ // }
+ // }
+
+ // TODO SERVER-71175 should include a union test here
+
+ // TODO SERVER-71656 should extend this to have DNF intervals that are not purely disjunctions
+ // or purely conjunctions. A mix of ORs and ANDs seems necessary, to test that the two
+ // simplification methods (intersecting and unioning) don't interfere with each other.
}
-TEST(IntervalIntersection, IntervalPermutations) {
- static constexpr int N = 10;
- static constexpr int numPermutations = N * N * N * N * 2 * 2 * 2 * 2;
+static constexpr int bitsetSize = 10;
+static const size_t numThreads = ProcessInfo::getNumCores();
+TEST(IntervalIntersection, IntervalPermutations) {
+ static constexpr int numPermutations =
+ bitsetSize * bitsetSize * bitsetSize * bitsetSize * 2 * 2 * 2 * 2;
/**
* Test for interval intersection. Generate intervals with constants in the
* range of [0, N), with random inclusion/exclusion of the endpoints. Intersect the intervals
* and verify against ground truth.
*/
- const size_t numThreads = ProcessInfo::getNumCores();
std::cout << "Testing " << numPermutations << " interval permutations using " << numThreads
<< " cores...\n";
auto timeBegin = Date_t::now();
@@ -591,7 +770,7 @@ TEST(IntervalIntersection, IntervalPermutations) {
if (nextP >= numPermutations) {
break;
}
- testInterval<N>(nextP);
+ testIntervalPermutation<bitsetSize>(nextP);
}
});
}
@@ -604,5 +783,35 @@ TEST(IntervalIntersection, IntervalPermutations) {
std::cout << "...done. Took: " << elapsed << " s.\n";
}
+TEST(IntervalIntersection, IntervalFuzz) {
+ static constexpr int numFuzzTests = 5000;
+ /**
+ * Generate random intervals with a mix of variables and constants, and test that they intersect
+ * and union correctly.
+ */
+ std::cout << "Testing " << numFuzzTests << " fuzzed intervals using " << numThreads
+ << " cores...\n";
+ const auto timeBeginFuzz = Date_t::now();
+
+ std::vector<stdx::thread> threads;
+ for (size_t i = 0; i < numThreads; i++) {
+ threads.emplace_back([]() {
+ const auto seed = SecureRandom().nextInt64();
+ std::cout << "Using random seed: " << seed << "\n";
+ PseudoRandom threadLocalRNG(seed);
+ for (size_t i = 0; i < numFuzzTests / numThreads; i++) {
+ testIntervalFuzz<bitsetSize>(seed, threadLocalRNG);
+ }
+ });
+ }
+ for (auto& thread : threads) {
+ thread.join();
+ }
+
+ const auto elapsedFuzz =
+ (Date_t::now().toMillisSinceEpoch() - timeBeginFuzz.toMillisSinceEpoch()) / 1000.0;
+ std::cout << "...done. Took: " << elapsedFuzz << " s.\n";
+}
+
} // namespace
} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/interval_utils.h b/src/mongo/db/query/optimizer/utils/interval_utils.h
index 90156e0aeb7..7b629174506 100644
--- a/src/mongo/db/query/optimizer/utils/interval_utils.h
+++ b/src/mongo/db/query/optimizer/utils/interval_utils.h
@@ -31,6 +31,8 @@
#include "mongo/db/query/optimizer/index_bounds.h"
+#include "mongo/db/query/optimizer/utils/utils.h"
+
namespace mongo::optimizer {
/**
diff --git a/src/mongo/db/query/optimizer/utils/unit_test_abt_literals.h b/src/mongo/db/query/optimizer/utils/unit_test_abt_literals.h
index 607765d46ce..a6b3ddafba4 100644
--- a/src/mongo/db/query/optimizer/utils/unit_test_abt_literals.h
+++ b/src/mongo/db/query/optimizer/utils/unit_test_abt_literals.h
@@ -419,18 +419,26 @@ private:
/**
* Interval expressions.
*/
+inline auto _disj(IntervalReqExpr::NodeVector v) {
+ return IntervalReqExpr::make<IntervalReqExpr::Disjunction>(std::move(v));
+}
+
template <typename... Ts>
inline auto _disj(Ts&&... pack) {
IntervalReqExpr::NodeVector v;
(v.push_back(std::forward<Ts>(pack)), ...);
- return IntervalReqExpr::make<IntervalReqExpr::Disjunction>(std::move(v));
+ return _disj(std::move(v));
+}
+
+inline auto _conj(IntervalReqExpr::NodeVector v) {
+ return IntervalReqExpr::make<IntervalReqExpr::Conjunction>(std::move(v));
}
template <typename... Ts>
inline auto _conj(Ts&&... pack) {
IntervalReqExpr::NodeVector v;
(v.push_back(std::forward<Ts>(pack)), ...);
- return IntervalReqExpr::make<IntervalReqExpr::Conjunction>(std::move(v));
+ return _conj(std::move(v));
}
inline auto _interval(IntervalRequirement req) {