summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSvilen Mihaylov <svilen.mihaylov@mongodb.com>2020-09-11 12:20:10 -0400
committerSvilen Mihaylov <svilen.mihaylov@mongodb.com>2020-09-11 12:20:35 -0400
commit435225797704e51cae5fd7db4d02c2e86e784eea (patch)
tree39c4b33adea9b5781a2f52173b686e8fe57d3358
parent6dee6f0a7b62faaad5a43a4564a76e4ab99b0012 (diff)
downloadmongo-435225797704e51cae5fd7db4d02c2e86e784eea.tar.gz
use PolyValue
-rw-r--r--src/mongo/db/query/optimizer/SConscript1
-rw-r--r--src/mongo/db/query/optimizer/algebra/operator.h305
-rw-r--r--src/mongo/db/query/optimizer/algebra/polyvalue.h381
-rw-r--r--src/mongo/db/query/optimizer/memo.cpp43
-rw-r--r--src/mongo/db/query/optimizer/memo.h (renamed from src/mongo/db/query/optimizer/visitor.h)24
-rw-r--r--src/mongo/db/query/optimizer/node.cpp216
-rw-r--r--src/mongo/db/query/optimizer/node.h151
-rw-r--r--src/mongo/db/query/optimizer/optimizer_test.cpp20
8 files changed, 841 insertions, 300 deletions
diff --git a/src/mongo/db/query/optimizer/SConscript b/src/mongo/db/query/optimizer/SConscript
index 175b109625d..0863192a593 100644
--- a/src/mongo/db/query/optimizer/SConscript
+++ b/src/mongo/db/query/optimizer/SConscript
@@ -8,6 +8,7 @@ env.Library(
target="optimizer",
source=[
"defs.cpp",
+ "memo.cpp",
"node.cpp",
],
LIBDEPS=[
diff --git a/src/mongo/db/query/optimizer/algebra/operator.h b/src/mongo/db/query/optimizer/algebra/operator.h
new file mode 100644
index 00000000000..524b7246413
--- /dev/null
+++ b/src/mongo/db/query/optimizer/algebra/operator.h
@@ -0,0 +1,305 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <vector>
+
+#include "mongo/db/query/optimizer/algebra/polyvalue.h"
+
+namespace mongo::optimizer {
+namespace algebra {
+
+template <typename T, int S>
+struct OpNodeStorage {
+ T _nodes[S];
+
+ template <typename... Ts>
+ OpNodeStorage(Ts&&... vals) : _nodes{std::forward<Ts>(vals)...} {}
+};
+
+template <typename T>
+struct OpNodeStorage<T, 0> {};
+
+/*=====-----
+ *
+ * Arity of operator can be:
+ * 1. statically known - A, A, A, ...
+ * 2. dynamic prefix with optional statically know - vector<A>, A, A, A, ...
+ *
+ * Denotations map A to some B.
+ * So static arity <A,A,A> is mapped to <B,B,B>.
+ * Similarly, arity <vector<A>,A> is mapped to <vector<B>,B>
+ *
+ * There is a wrinkle when B is a reference (if allowed at all)
+ * Arity <vector<A>, A, A> is mapped to <vector<B>&, B&, B&> - note that the reference is lifted
+ * outside of the vector.
+ *
+ */
+template <typename Slot, typename Derived, int Arity>
+class OpSpecificArity : public OpNodeStorage<Slot, Arity> {
+ using Base = OpNodeStorage<Slot, Arity>;
+
+public:
+ template <typename... Ts>
+ OpSpecificArity(Ts&&... vals) : Base({std::forward<Ts>(vals)...}) {
+ static_assert(sizeof...(Ts) == Arity, "constructor paramaters do not match");
+ }
+
+ template <int I, std::enable_if_t<(I >= 0 && I < Arity), int> = 0>
+ auto& get() noexcept {
+ return this->_nodes[I];
+ }
+
+ template <int I, std::enable_if_t<(I >= 0 && I < Arity), int> = 0>
+ const auto& get() const noexcept {
+ return this->_nodes[I];
+ }
+};
+/*=====-----
+ *
+ * Operator with dynamic arity
+ *
+ */
+template <typename Slot, typename Derived, int Arity>
+class OpSpecificDynamicArity : public OpSpecificArity<Slot, Derived, Arity> {
+ using Base = OpSpecificArity<Slot, Derived, Arity>;
+
+ std::vector<Slot> _dyNodes;
+
+public:
+ template <typename... Ts>
+ OpSpecificDynamicArity(std::vector<Slot> nodes, Ts&&... vals)
+ : Base({std::forward<Ts>(vals)...}), _dyNodes(std::move(nodes)) {}
+
+ auto& nodes() {
+ return _dyNodes;
+ }
+ const auto& nodes() const {
+ return _dyNodes;
+ }
+};
+
+/*=====-----
+ *
+ * Semantic transport interface
+ *
+ */
+namespace detail {
+template <typename D, typename T, typename = std::void_t<>>
+struct has_prepare : std::false_type {};
+template <typename D, typename T>
+struct has_prepare<D, T, std::void_t<decltype(std::declval<D>().prepare(std::declval<T&>()))>>
+ : std::true_type {};
+
+template <typename D, typename T>
+inline constexpr auto has_prepare_v = has_prepare<D, T>::value;
+
+template <typename Slot, typename Derived, int Arity>
+inline constexpr int get_arity(const OpSpecificArity<Slot, Derived, Arity>*) {
+ return Arity;
+}
+
+template <typename Slot, typename Derived, int Arity>
+inline constexpr bool is_dynamic(const OpSpecificArity<Slot, Derived, Arity>*) {
+ return false;
+}
+
+template <typename Slot, typename Derived, int Arity>
+inline constexpr bool is_dynamic(const OpSpecificDynamicArity<Slot, Derived, Arity>*) {
+ return true;
+}
+
+template <typename T>
+using OpConcreteType = typename std::remove_reference_t<T>::template get_t<0>;
+} // namespace detail
+
+template <typename D, bool withSlot>
+class OpTransporter {
+ D& _domain;
+
+ template <typename T, bool B>
+ struct Deducer {};
+ template <typename T>
+ struct Deducer<T, true> {
+ using type = decltype(std::declval<D>().transport(
+ std::declval<T>(), std::declval<detail::OpConcreteType<T>&>()));
+ };
+ template <typename T>
+ struct Deducer<T, false> {
+ using type =
+ decltype(std::declval<D>().transport(std::declval<detail::OpConcreteType<T>&>()));
+ };
+ template <typename T>
+ using deduced_t = typename Deducer<T, withSlot>::type;
+
+ template <typename N, typename T, typename... Ts>
+ auto transformStep(N&& slot, T&& op, Ts&&... args) {
+ if constexpr (withSlot) {
+ return _domain.transport(
+ std::forward<N>(slot), std::forward<T>(op), std::forward<Ts>(args)...);
+ } else {
+ return _domain.transport(std::forward<T>(op), std::forward<Ts>(args)...);
+ }
+ }
+
+ template <typename N, typename T, size_t... I>
+ auto transportUnpack(N&& slot, T&& op, std::index_sequence<I...>) {
+ return transformStep(
+ std::forward<N>(slot), std::forward<T>(op), op.template get<I>().visit(*this)...);
+ }
+ template <typename N, typename T, size_t... I>
+ auto transportDynamicUnpack(N&& slot, T&& op, std::index_sequence<I...>) {
+ std::vector<decltype(slot.visit(*this))> v;
+ for (auto& node : op.nodes()) {
+ v.emplace_back(node.visit(*this));
+ }
+ return transformStep(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::move(v),
+ op.template get<I>().visit(*this)...);
+ }
+ template <typename N, typename T, size_t... I>
+ void transportUnpackVoid(N&& slot, T&& op, std::index_sequence<I...>) {
+ (op.template get<I>().visit(*this), ...);
+ return transformStep(std::forward<N>(slot), std::forward<T>(op), op.template get<I>()...);
+ }
+ template <typename N, typename T, size_t... I>
+ void transportDynamicUnpackVoid(N&& slot, T&& op, std::index_sequence<I...>) {
+ for (auto& node : op.nodes()) {
+ node.visit(*this);
+ }
+ (op.template get<I>().visit(*this), ...);
+ return transformStep(
+ std::forward<N>(slot), std::forward<T>(op), op.nodes(), op.template get<I>()...);
+ }
+
+public:
+ OpTransporter(D& domain) : _domain(domain) {}
+
+ template <typename N, typename T, typename R = deduced_t<N>>
+ R operator()(N&& slot, T&& op) {
+ // N is either `PolyValue<Ts...>&` or `const PolyValue<Ts...>&` i.e. reference
+ // T is either `A&` or `const A&` where A is one of Ts
+ using type = std::remove_reference_t<T>;
+
+ constexpr int arity = detail::get_arity(static_cast<type*>(nullptr));
+ constexpr bool is_dynamic = detail::is_dynamic(static_cast<type*>(nullptr));
+
+ if constexpr (detail::has_prepare_v<D, type>) {
+ _domain.prepare(std::forward<T>(op));
+ }
+ if constexpr (is_dynamic) {
+ if constexpr (std::is_same_v<R, void>) {
+ return transportDynamicUnpackVoid(
+ std::forward<N>(slot), std::forward<T>(op), std::make_index_sequence<arity>{});
+ } else {
+ return transportDynamicUnpack(
+ std::forward<N>(slot), std::forward<T>(op), std::make_index_sequence<arity>{});
+ }
+ } else {
+ if constexpr (std::is_same_v<R, void>) {
+ return transportUnpackVoid(
+ std::forward<N>(slot), std::forward<T>(op), std::make_index_sequence<arity>{});
+ } else {
+ return transportUnpack(
+ std::forward<N>(slot), std::forward<T>(op), std::make_index_sequence<arity>{});
+ }
+ }
+ }
+};
+
+template <typename D, bool withSlot>
+class OpWalker {
+ D& _domain;
+
+ template <typename N, typename T, typename... Ts>
+ auto walkStep(N&& slot, T&& op, Ts&&... args) {
+ if constexpr (withSlot) {
+ return _domain.walk(
+ std::forward<N>(slot), std::forward<T>(op), std::forward<Ts>(args)...);
+ } else {
+ return _domain.walk(std::forward<T>(op), std::forward<Ts>(args)...);
+ }
+ }
+
+ template <typename N, typename T, typename... Args, size_t... I>
+ auto walkUnpack(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) {
+ return walkStep(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::forward<Args>(args)...,
+ op.template get<I>()...);
+ }
+ template <typename N, typename T, typename... Args, size_t... I>
+ auto walkDynamicUnpack(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) {
+ return walkStep(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::forward<Args>(args)...,
+ op.nodes(),
+ op.template get<I>()...);
+ }
+
+public:
+ OpWalker(D& domain) : _domain(domain) {}
+
+ template <typename N, typename T, typename... Args>
+ auto operator()(N&& slot, T&& op, Args&&... args) {
+ // N is either `PolyValue<Ts...>&` or `const PolyValue<Ts...>&` i.e. reference
+ // T is either `A&` or `const A&` where A is one of Ts
+ using type = std::remove_reference_t<T>;
+
+ constexpr int arity = detail::get_arity(static_cast<type*>(nullptr));
+ constexpr bool is_dynamic = detail::is_dynamic(static_cast<type*>(nullptr));
+
+ if constexpr (is_dynamic) {
+ return walkDynamicUnpack(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::make_index_sequence<arity>{},
+ std::forward<Args>(args)...);
+ } else {
+ return walkUnpack(std::forward<N>(slot),
+ std::forward<T>(op),
+ std::make_index_sequence<arity>{},
+ std::forward<Args>(args)...);
+ }
+ }
+};
+
+template <bool withSlot = false, typename D, typename N>
+auto transport(N&& node, D& domain) {
+ return node.visit(OpTransporter<D, withSlot>{domain});
+}
+
+template <bool withSlot = false, typename D, typename N, typename... Args>
+auto walk(N&& node, D& domain, Args&&... args) {
+ return node.visit(OpWalker<D, withSlot>{domain}, std::forward<Args>(args)...);
+}
+
+} // namespace algebra
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/algebra/polyvalue.h b/src/mongo/db/query/optimizer/algebra/polyvalue.h
new file mode 100644
index 00000000000..374041c5704
--- /dev/null
+++ b/src/mongo/db/query/optimizer/algebra/polyvalue.h
@@ -0,0 +1,381 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <array>
+#include <stdexcept>
+#include <type_traits>
+
+namespace mongo::optimizer {
+namespace algebra {
+namespace detail {
+
+template <typename T, typename... Args>
+inline constexpr bool is_one_of_v = std::disjunction_v<std::is_same<T, Args>...>;
+
+template <typename T, typename... Args>
+inline constexpr bool is_one_of_f() {
+ return is_one_of_v<T, Args...>;
+}
+
+template <typename... Args>
+struct is_unique_t : std::true_type {};
+
+template <typename H, typename... T>
+struct is_unique_t<H, T...>
+ : std::bool_constant<!is_one_of_f<H, T...>() && is_unique_t<T...>::value> {};
+
+template <typename... Args>
+inline constexpr bool is_unique_v = is_unique_t<Args...>::value;
+
+// Given the type T find its index in Ts
+template <typename T, typename... Ts>
+static inline constexpr int find_index() {
+ static_assert(detail::is_unique_v<Ts...>, "Types must be unique");
+ constexpr bool matchVector[] = {std::is_same<T, Ts>::value...};
+
+ for (int index = 0; index < static_cast<int>(sizeof...(Ts)); ++index) {
+ if (matchVector[index]) {
+ return index;
+ }
+ }
+
+ return -1;
+}
+
+template <int N, typename T, typename... Ts>
+struct get_type_by_index_impl {
+ using type = typename get_type_by_index_impl<N - 1, Ts...>::type;
+};
+template <typename T, typename... Ts>
+struct get_type_by_index_impl<0, T, Ts...> {
+ using type = T;
+};
+
+// Given the index I return the type from Ts
+template <int I, typename... Ts>
+using get_type_by_index = typename get_type_by_index_impl<I, Ts...>::type;
+
+} // namespace detail
+
+/*=====-----
+ *
+ * The overload trick to construct visitors from lambdas.
+ *
+ */
+template <class... Ts>
+struct overload : Ts... {
+ using Ts::operator()...;
+};
+template <class... Ts>
+overload(Ts...)->overload<Ts...>;
+
+/*=====-----
+ *
+ * Forward declarations
+ *
+ */
+template <typename... Ts>
+class PolyValue;
+
+template <typename T, typename... Ts>
+class ControlBlockVTable;
+
+/*=====-----
+ *
+ * The base control block that PolyValue holds.
+ *
+ * It does not contain anything else by the runtime tag.
+ *
+ */
+template <typename... Ts>
+class ControlBlock {
+ const int _tag;
+
+protected:
+ ControlBlock(int tag) noexcept : _tag(tag) {}
+
+public:
+ auto getRuntimeTag() const noexcept {
+ return _tag;
+ }
+};
+
+/*=====-----
+ *
+ * The concrete control block VTable generator.
+ *
+ * It must be empty ad PolyValue derives from the generators
+ * and we want EBO to kick in.
+ *
+ */
+template <typename T, typename... Ts>
+class ControlBlockVTable {
+ static constexpr int _staticTag = detail::find_index<T, Ts...>();
+ static_assert(_staticTag != -1, "Type must be on the list");
+
+ using AbstractType = ControlBlock<Ts...>;
+ using PolyValueType = PolyValue<Ts...>;
+
+ /*=====-----
+ *
+ * The concrete control block for every type T of Ts.
+ *
+ * It derives from the ControlBlock. All methods are private and only
+ * the friend class ControlBlockVTable can call them.
+ *
+ */
+ class ConcreteType : public AbstractType {
+ T _t;
+
+ public:
+ template <typename... Args>
+ ConcreteType(Args&&... args) : AbstractType(_staticTag), _t(std::forward<Args>(args)...) {}
+
+ const T* getPtr() const {
+ return &_t;
+ }
+
+ T* getPtr() {
+ return &_t;
+ }
+ };
+
+ static constexpr auto concrete(AbstractType* block) noexcept {
+ return static_cast<ConcreteType*>(block);
+ }
+
+ static constexpr auto concrete(const AbstractType* block) noexcept {
+ return static_cast<const ConcreteType*>(block);
+ }
+
+public:
+ template <typename... Args>
+ static AbstractType* make(Args&&... args) {
+ return new ConcreteType(std::forward<Args>(args)...);
+ }
+
+ static AbstractType* clone(const AbstractType* block) {
+ return new ConcreteType(*concrete(block));
+ }
+
+ static void destroy(AbstractType* block) noexcept {
+ delete concrete(block);
+ }
+
+ static bool compareEq(AbstractType* blockLhs, AbstractType* blockRhs) noexcept {
+ if (blockLhs->getRuntimeTag() == blockRhs->getRuntimeTag()) {
+ return *castConst<T>(blockLhs) == *castConst<T>(blockRhs);
+ }
+ return false;
+ }
+
+ template <typename U>
+ static constexpr bool is_v = std::is_base_of_v<U, T>;
+
+ template <typename U>
+ static U* cast(AbstractType* block) {
+ if constexpr (is_v<U>) {
+ return static_cast<U*>(concrete(block)->getPtr());
+ } else {
+ // gcc bug 81676
+ (void)block;
+ return nullptr;
+ }
+ }
+
+ template <typename U>
+ static const U* castConst(const AbstractType* block) {
+ if constexpr (is_v<U>) {
+ return static_cast<const U*>(concrete(block)->getPtr());
+ } else {
+ // gcc bug 81676
+ (void)block;
+ return nullptr;
+ }
+ }
+
+ template <typename V, typename... Args>
+ static auto visit(V&& v, PolyValueType& holder, AbstractType* block, Args&&... args) {
+ return v(holder, *cast<T>(block), std::forward<Args>(args)...);
+ }
+
+ template <typename V, typename... Args>
+ static auto visitConst(V&& v,
+ const PolyValueType& holder,
+ const AbstractType* block,
+ Args&&... args) {
+ return v(holder, *castConst<T>(block), std::forward<Args>(args)...);
+ }
+};
+
+/*=====-----
+ *
+ * This is a variation on variant and polymorphic value theme.
+ *
+ * A tag based dispatch
+ *
+ * Supported operations:
+ * - construction
+ * - destruction
+ * - clone a = b;
+ * - cast a.cast<T>()
+ * - multi-method cast to common base a.cast<B>()
+ * - multi-method visit
+ */
+template <typename... Ts>
+class PolyValue : private ControlBlockVTable<Ts, Ts...>... {
+ static_assert(detail::is_unique_v<Ts...>, "Types must be unique");
+ static_assert(std::conjunction_v<std::is_empty<ControlBlockVTable<Ts, Ts...>>...>,
+ "VTable base classes must be empty");
+
+ ControlBlock<Ts...>* _object{nullptr};
+
+ PolyValue(ControlBlock<Ts...>* object) noexcept : _object(object) {}
+
+ auto tag() const noexcept {
+ return _object->getRuntimeTag();
+ }
+
+ void check() const {
+ if (!_object) {
+ throw std::logic_error("PolyValue is empty");
+ }
+ }
+
+ static void destroy(ControlBlock<Ts...>* object) {
+ static constexpr std::array destroyTbl = {&ControlBlockVTable<Ts, Ts...>::destroy...};
+
+ destroyTbl[object->getRuntimeTag()](object);
+ }
+
+public:
+ PolyValue() = delete;
+
+ PolyValue(const PolyValue& other) {
+ static constexpr std::array cloneTbl = {&ControlBlockVTable<Ts, Ts...>::clone...};
+ if (other._object) {
+ _object = cloneTbl[other.tag()](other._object);
+ }
+ }
+
+ PolyValue(PolyValue&& other) noexcept {
+ swap(other);
+ }
+
+ ~PolyValue() noexcept {
+ if (_object) {
+ destroy(_object);
+ }
+ }
+
+ PolyValue& operator=(PolyValue other) noexcept {
+ swap(other);
+ return *this;
+ }
+
+ template <typename T, typename... Args>
+ static PolyValue make(Args&&... args) {
+ return PolyValue{ControlBlockVTable<T, Ts...>::make(std::forward<Args>(args)...)};
+ }
+
+ template <int I>
+ using get_t = detail::get_type_by_index<I, Ts...>;
+
+ template <typename V, typename... Args>
+ auto visit(V&& v, Args&&... args) {
+ // unfortunately gcc rejects much nicer code, clang and msvc accept
+ // static constexpr std::array visitTbl = { &ControlBlockVTable<Ts, Ts...>::template
+ // visit<V>... };
+
+ using FunPtrType =
+ decltype(&ControlBlockVTable<get_t<0>, Ts...>::template visit<V, Args...>);
+ static constexpr FunPtrType visitTbl[] = {
+ &ControlBlockVTable<Ts, Ts...>::template visit<V, Args...>...};
+
+ check();
+ return visitTbl[tag()](std::forward<V>(v), *this, _object, std::forward<Args>(args)...);
+ }
+
+ template <typename V, typename... Args>
+ auto visit(V&& v, Args&&... args) const {
+ // unfortunately gcc rejects much nicer code, clang and msvc accept
+ // static constexpr std::array visitTbl = { &ControlBlockVTable<Ts, Ts...>::template
+ // visitConst<V>... };
+
+ using FunPtrType =
+ decltype(&ControlBlockVTable<get_t<0>, Ts...>::template visitConst<V, Args...>);
+ static constexpr FunPtrType visitTbl[] = {
+ &ControlBlockVTable<Ts, Ts...>::template visitConst<V, Args...>...};
+
+ check();
+ return visitTbl[tag()](std::forward<V>(v), *this, _object, std::forward<Args>(args)...);
+ }
+
+ template <typename T>
+ T* cast() {
+ check();
+ static constexpr std::array castTbl = {&ControlBlockVTable<Ts, Ts...>::template cast<T>...};
+ return castTbl[tag()](_object);
+ }
+
+ template <typename T>
+ const T* cast() const {
+ static constexpr std::array castTbl = {
+ &ControlBlockVTable<Ts, Ts...>::template castConst<T>...};
+
+ check();
+ return castTbl[tag()](_object);
+ }
+
+ template <typename T>
+ bool is() const {
+ static constexpr std::array isTbl = {ControlBlockVTable<Ts, Ts...>::template is_v<T>...};
+
+ check();
+ return isTbl[tag()];
+ }
+
+ bool empty() const {
+ return !_object;
+ }
+
+ void swap(PolyValue& other) noexcept {
+ std::swap(other._object, _object);
+ }
+
+ bool operator==(const PolyValue& rhs) const noexcept {
+ static constexpr std::array cmp = {ControlBlockVTable<Ts, Ts...>::compareEq...};
+ return cmp[tag()](_object, rhs._object);
+ }
+};
+
+} // namespace algebra
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/memo.cpp b/src/mongo/db/query/optimizer/memo.cpp
new file mode 100644
index 00000000000..c4dadbb3d5a
--- /dev/null
+++ b/src/mongo/db/query/optimizer/memo.cpp
@@ -0,0 +1,43 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/query/optimizer/algebra/operator.h"
+#include "mongo/db/query/optimizer/memo.h"
+#include "mongo/db/query/optimizer/node.h"
+
+namespace mongo::optimizer {
+
+std::string MemoGenerator::generateMemo(const PolymorphicNode& e) {
+ _os.str("");
+ _os.clear();
+ algebra::transport<false>(e, *this);
+ return _os.str();
+}
+
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/visitor.h b/src/mongo/db/query/optimizer/memo.h
index 1aa0a886fab..ad3703f8fd8 100644
--- a/src/mongo/db/query/optimizer/visitor.h
+++ b/src/mongo/db/query/optimizer/memo.h
@@ -31,16 +31,24 @@
#include <string>
+#include "mongo/db/query/optimizer/node.h"
+
namespace mongo::optimizer {
-class AbstractVisitor {
+class MemoGenerator {
public:
- virtual void visit(const ScanNode& node) = 0;
- virtual void visit(const MultiJoinNode& node) = 0;
- virtual void visit(const UnionNode& node) = 0;
- virtual void visit(const GroupByNode& node) = 0;
- virtual void visit(const UnwindNode& node) = 0;
- virtual void visit(const WindNode& node) = 0;
+ template <typename T, typename... Ts>
+ void transport(const T&, Ts&&...) {}
+
+ template <typename T>
+ void prepare(const T& n) {
+ n.generateMemo(_os);
+ }
+
+ std::string generateMemo(const PolymorphicNode& e);
+
+private:
+ std::ostringstream _os;
};
-} // namespace mongo::optimizer
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/node.cpp b/src/mongo/db/query/optimizer/node.cpp
index 4836dcce39e..a1455efd60f 100644
--- a/src/mongo/db/query/optimizer/node.cpp
+++ b/src/mongo/db/query/optimizer/node.cpp
@@ -30,130 +30,19 @@
#include <functional>
#include <stack>
+#include "mongo/db/query/optimizer/memo.h"
#include "mongo/db/query/optimizer/node.h"
-#include "mongo/db/query/optimizer/visitor.h"
-#include "mongo/util/assert_util.h"
namespace mongo::optimizer {
-Node::Node(Context& ctx) : _nodeId(ctx.getNextNodeId()), _children() {}
-
-Node::Node(Context& ctx, NodePtr child) : _nodeId(ctx.getNextNodeId()) {
- _children.push_back(std::move(child));
-}
-
-Node::Node(Context& ctx, ChildVector children)
- : _nodeId(ctx.getNextNodeId()), _children(std::move(children)) {}
+Node::Node(Context& ctx) : _nodeId(ctx.getNextNodeId()) {}
void Node::generateMemoBase(std::ostringstream& os) const {
os << "NodeId: " << _nodeId << "\n";
}
-void Node::visitPreOrder(AbstractVisitor& visitor) const {
- visit(visitor);
- for (const NodePtr& ptr : _children) {
- ptr->visitPreOrder(visitor);
- }
-}
-
-void Node::visitPostOrder(AbstractVisitor& visitor) const {
- for (const NodePtr& ptr : _children) {
- ptr->visitPostOrder(visitor);
- }
- visit(visitor);
-}
-
-std::string Node::generateMemo() const {
- class MemoVisitor : public AbstractVisitor {
- protected:
- void visit(const ScanNode& node) override {
- node.generateMemo(_os);
- }
- void visit(const MultiJoinNode& node) override {
- node.generateMemo(_os);
- }
- void visit(const UnionNode& node) override {
- node.generateMemo(_os);
- }
- void visit(const GroupByNode& node) override {
- node.generateMemo(_os);
- }
- void visit(const UnwindNode& node) override {
- node.generateMemo(_os);
- }
- void visit(const WindNode& node) override {
- node.generateMemo(_os);
- }
-
- public:
- std::ostringstream _os;
- };
-
- MemoVisitor visitor;
- visitPreOrder(visitor);
- return visitor._os.str();
-}
-
-NodePtr Node::clone(Context& ctx) const {
- class CloneVisitor : public AbstractVisitor {
- public:
- explicit CloneVisitor(Context& ctx) : _ctx(ctx), _childStack() {}
-
- protected:
- void visit(const ScanNode& node) override {
- doClone(node, [&](ChildVector v){ return ScanNode::clone(_ctx, node); });
- }
- void visit(const MultiJoinNode& node) override {
- doClone(node, [&](ChildVector v){ return MultiJoinNode::clone(_ctx, node, std::move(v)); });
- }
- void visit(const UnionNode& node) override {
- doClone(node, [&](ChildVector v){ return UnionNode::clone(_ctx, node, std::move(v)); });
- }
- void visit(const GroupByNode& node) override {
- doClone(node, [&](ChildVector v){ return GroupByNode::clone(_ctx, node, std::move(v.at(0))); });
- }
- void visit(const UnwindNode& node) override {
- doClone(node, [&](ChildVector v){ return UnwindNode::clone(_ctx, node, std::move(v.at(0))); });
- }
- void visit(const WindNode& node) override {
- doClone(node, [&](ChildVector v){ return WindNode::clone(_ctx, node, std::move(v.at(0))); });
- }
-
- private:
- void doClone(const Node& node, const std::function<NodePtr(ChildVector newChildren)>& cloneFn) {
- ChildVector newChildren;
- for (int i = 0; i < node.getChildCount(); i++) {
- newChildren.push_back(std::move(_childStack.top()));
- _childStack.pop();
- }
- _childStack.push(cloneFn(std::move(newChildren)));
- }
-
- public:
- Context& _ctx;
- std::stack<NodePtr> _childStack;
- };
-
- CloneVisitor visitor(ctx);
- visitPostOrder(visitor);
- invariant(visitor._childStack.size() == 1);
- return std::move(visitor._childStack.top());
-}
-
-int Node::getChildCount() const {
- return _children.size();
-}
-
-NodePtr ScanNode::create(Context& ctx, CollectionNameType collectionName) {
- return NodePtr(new ScanNode(ctx, std::move(collectionName)));
-}
-
-NodePtr ScanNode::clone(Context& ctx, const ScanNode& other) {
- return create(ctx, other._collectionName);
-}
-
ScanNode::ScanNode(Context& ctx, CollectionNameType collectionName)
- : Node(ctx), _collectionName(std::move(collectionName)) {}
+ : Base(), Node(ctx), _collectionName(std::move(collectionName)) {}
void ScanNode::generateMemo(std::ostringstream& os) const {
Node::generateMemoBase(os);
@@ -161,27 +50,12 @@ void ScanNode::generateMemo(std::ostringstream& os) const {
<< "\n";
}
-void ScanNode::visit(AbstractVisitor& visitor) const {
- visitor.visit(*this);
-}
-
-NodePtr MultiJoinNode::create(Context& ctx,
- FilterSet filterSet,
- ProjectionMap projectionMap,
- ChildVector children) {
- return NodePtr(new MultiJoinNode(
- ctx, std::move(filterSet), std::move(projectionMap), std::move(children)));
-}
-
-NodePtr MultiJoinNode::clone(Context& ctx, const MultiJoinNode& other, ChildVector newChildren) {
- return create(ctx, other._filterSet, other._projectionMap, std::move(newChildren));
-}
-
MultiJoinNode::MultiJoinNode(Context& ctx,
FilterSet filterSet,
ProjectionMap projectionMap,
- ChildVector children)
- : Node(ctx, std::move(children)),
+ PolymorphicNodeVector children)
+ : Base(std::move(children)),
+ Node(ctx),
_filterSet(std::move(filterSet)),
_projectionMap(std::move(projectionMap)) {}
@@ -191,20 +65,8 @@ void MultiJoinNode::generateMemo(std::ostringstream& os) const {
<< "\n";
}
-void MultiJoinNode::visit(AbstractVisitor& visitor) const {
- visitor.visit(*this);
-}
-
-NodePtr UnionNode::create(Context& ctx, ChildVector children) {
- return NodePtr(new UnionNode(ctx, std::move(children)));
-}
-
-NodePtr UnionNode::clone(Context& ctx, const UnionNode& other, ChildVector newChildren) {
- return create(ctx, std::move(newChildren));
-}
-
-UnionNode::UnionNode(Context& ctx, ChildVector children)
- : Node(ctx, std::move(children)) {}
+UnionNode::UnionNode(Context& ctx, PolymorphicNodeVector children)
+ : Base(std::move(children)), Node(ctx) {}
void UnionNode::generateMemo(std::ostringstream& os) const {
Node::generateMemoBase(os);
@@ -212,27 +74,12 @@ void UnionNode::generateMemo(std::ostringstream& os) const {
<< "\n";
}
-void UnionNode::visit(AbstractVisitor& visitor) const {
- visitor.visit(*this);
-}
-
-NodePtr GroupByNode::create(Context& ctx,
- GroupByNode::GroupByVector groupByVector,
- GroupByNode::ProjectionMap projectionMap,
- NodePtr child) {
- return NodePtr(
- new GroupByNode(ctx, std::move(groupByVector), std::move(projectionMap), std::move(child)));
-}
-
-NodePtr GroupByNode::clone(Context& ctx, const GroupByNode& other, NodePtr newChild) {
- return create(ctx, other._groupByVector, other._projectionMap, std::move(newChild));
-}
-
GroupByNode::GroupByNode(Context& ctx,
GroupByNode::GroupByVector groupByVector,
GroupByNode::ProjectionMap projectionMap,
- NodePtr child)
- : Node(ctx, std::move(child)),
+ PolymorphicNode child)
+ : Base(std::move(child)),
+ Node(ctx),
_groupByVector(std::move(groupByVector)),
_projectionMap(std::move(projectionMap)) {}
@@ -242,27 +89,12 @@ void GroupByNode::generateMemo(std::ostringstream& os) const {
<< "\n";
}
-void GroupByNode::visit(AbstractVisitor& visitor) const {
- visitor.visit(*this);
-}
-
-NodePtr UnwindNode::create(Context& ctx,
- ProjectionName projectionName,
- const bool retainNonArrays,
- NodePtr child) {
- return NodePtr(
- new UnwindNode(ctx, std::move(projectionName), retainNonArrays, std::move(child)));
-}
-
-NodePtr UnwindNode::clone(Context& ctx, const UnwindNode& other, NodePtr newChild) {
- return create(ctx, other._projectionName, other._retainNonArrays, std::move(newChild));
-}
-
UnwindNode::UnwindNode(Context& ctx,
ProjectionName projectionName,
const bool retainNonArrays,
- NodePtr child)
- : Node(ctx, std::move(child)),
+ PolymorphicNode child)
+ : Base(std::move(child)),
+ Node(ctx),
_projectionName(std::move(projectionName)),
_retainNonArrays(retainNonArrays) {}
@@ -272,20 +104,8 @@ void UnwindNode::generateMemo(std::ostringstream& os) const {
<< "\n";
}
-void UnwindNode::visit(AbstractVisitor& visitor) const {
- visitor.visit(*this);
-}
-
-NodePtr WindNode::create(Context& ctx, ProjectionName projectionName, NodePtr child) {
- return NodePtr(new WindNode(ctx, std::move(projectionName), std::move(child)));
-}
-
-NodePtr WindNode::clone(Context& ctx, const WindNode& other, NodePtr newChild) {
- return create(ctx, other._projectionName, std::move(newChild));
-}
-
-WindNode::WindNode(Context& ctx, ProjectionName projectionName, NodePtr child)
- : Node(ctx, std::move(child)), _projectionName(std::move(projectionName)) {}
+WindNode::WindNode(Context& ctx, ProjectionName projectionName, PolymorphicNode child)
+ : Base(std::move(child)), Node(ctx), _projectionName(std::move(projectionName)) {}
void WindNode::generateMemo(std::ostringstream& os) const {
Node::generateMemoBase(os);
@@ -293,8 +113,4 @@ void WindNode::generateMemo(std::ostringstream& os) const {
<< "\n";
}
-void WindNode::visit(AbstractVisitor& visitor) const {
- visitor.visit(*this);
-}
-
} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/node.h b/src/mongo/db/query/optimizer/node.h
index 78010d7d333..de56ade7b56 100644
--- a/src/mongo/db/query/optimizer/node.h
+++ b/src/mongo/db/query/optimizer/node.h
@@ -37,6 +37,7 @@
#include <utility>
#include <vector>
+#include "mongo/db/query/optimizer/algebra/operator.h"
#include "mongo/db/query/optimizer/defs.h"
#include "mongo/db/query/optimizer/filter.h"
#include "mongo/db/query/optimizer/projection.h"
@@ -45,157 +46,137 @@
namespace mongo::optimizer {
-class Node;
-using NodePtr = std::unique_ptr<Node>;
-class AbstractVisitor;
+class ScanNode;
+class MultiJoinNode;
+class UnionNode;
+class GroupByNode;
+class UnwindNode;
+class WindNode;
-class Node {
-public:
- using ChildVector = std::vector<NodePtr>;
+using PolymorphicNode =
+ algebra::PolyValue<ScanNode, MultiJoinNode, UnionNode, GroupByNode, UnwindNode, WindNode>;
+
+template <typename Derived, size_t Arity>
+using Operator = algebra::OpSpecificArity<PolymorphicNode, Derived, Arity>;
+
+template <typename Derived, size_t Arity>
+using OperatorDynamic = algebra::OpSpecificDynamicArity<PolymorphicNode, Derived, Arity>;
+
+template <typename Derived>
+using OperatorDynamicHomogenous = OperatorDynamic<Derived, 0>;
+
+using PolymorphicNodeVector = std::vector<PolymorphicNode>;
+
+template <typename T, typename... Args>
+inline auto make(Args&&... args) {
+ return PolymorphicNode::make<T>(std::forward<Args>(args)...);
+}
+template <typename... Args>
+inline auto makeSeq(Args&&... args) {
+ PolymorphicNodeVector seq;
+ (seq.emplace_back(std::forward<Args>(args)), ...);
+ return seq;
+}
+
+class Node {
protected:
explicit Node(Context& ctx);
- explicit Node(Context& ctx, NodePtr child);
- explicit Node(Context& ctx, ChildVector children);
void generateMemoBase(std::ostringstream& os) const;
- virtual void visit(AbstractVisitor& visitor) const = 0;
- void visitPreOrder(AbstractVisitor& visitor) const;
- void visitPostOrder(AbstractVisitor& visitor) const;
-
- // clone
public:
Node() = delete;
- std::string generateMemo() const;
-
- NodePtr clone(Context& ctx) const;
-
- int getChildCount() const;
-
private:
const NodeIdType _nodeId;
- ChildVector _children;
};
-class ScanNode : public Node {
+class ScanNode final : public Operator<ScanNode, 0>, public Node {
+ using Base = Operator<ScanNode, 0>;
+
public:
- static NodePtr create(Context& ctx, CollectionNameType collectionName);
- static NodePtr clone(Context& ctx, const ScanNode& other);
+ explicit ScanNode(Context& ctx, CollectionNameType collectionName);
void generateMemo(std::ostringstream& os) const;
-protected:
- void visit(AbstractVisitor& visitor) const override;
-
private:
- explicit ScanNode(Context& ctx, CollectionNameType collectionName);
-
const CollectionNameType _collectionName;
};
-class MultiJoinNode : public Node {
+class MultiJoinNode final : public OperatorDynamicHomogenous<MultiJoinNode>, public Node {
+ using Base = OperatorDynamicHomogenous<MultiJoinNode>;
+
public:
using FilterSet = std::unordered_set<FilterType>;
using ProjectionMap = std::unordered_map<ProjectionName, ProjectionType>;
- static NodePtr create(Context& ctx,
- FilterSet filterSet,
- ProjectionMap projectionMap,
- ChildVector children);
- static NodePtr clone(Context& ctx, const MultiJoinNode& other, ChildVector newChildren);
-
- void generateMemo(std::ostringstream& os) const;
-
-protected:
- void visit(AbstractVisitor& visitor) const override;
-
-private:
explicit MultiJoinNode(Context& ctx,
FilterSet filterSet,
ProjectionMap projectionMap,
- ChildVector children);
+ PolymorphicNodeVector children);
+ void generateMemo(std::ostringstream& os) const;
+
+private:
FilterSet _filterSet;
ProjectionMap _projectionMap;
};
-class UnionNode : public Node {
+class UnionNode final : public OperatorDynamicHomogenous<UnionNode>, public Node {
+ using Base = OperatorDynamicHomogenous<UnionNode>;
+
public:
- static NodePtr create(Context& ctx, ChildVector children);
- static NodePtr clone(Context& ctx, const UnionNode& other, ChildVector newChildren);
+ explicit UnionNode(Context& ctx, PolymorphicNodeVector children);
void generateMemo(std::ostringstream& os) const;
-
-protected:
- void visit(AbstractVisitor& visitor) const override;
-
-private:
- explicit UnionNode(Context& ctx, ChildVector children);
};
-class GroupByNode : public Node {
+class GroupByNode : public Operator<GroupByNode, 1>, public Node {
+ using Base = Operator<GroupByNode, 1>;
+
public:
using GroupByVector = std::vector<ProjectionName>;
using ProjectionMap = std::unordered_map<ProjectionName, ProjectionType>;
- static NodePtr create(Context& ctx,
- GroupByVector groupByVector,
- ProjectionMap projectionMap,
- NodePtr child);
- static NodePtr clone(Context& ctx, const GroupByNode& other, NodePtr newChild);
-
- void generateMemo(std::ostringstream& os) const;
-
-protected:
- void visit(AbstractVisitor& visitor) const override;
-
-private:
explicit GroupByNode(Context& ctx,
GroupByVector groupByVector,
ProjectionMap projectionMap,
- NodePtr child);
+ PolymorphicNode child);
+
+ void generateMemo(std::ostringstream& os) const;
+private:
GroupByVector _groupByVector;
ProjectionMap _projectionMap;
};
-class UnwindNode : public Node {
+class UnwindNode final : public Operator<UnwindNode, 1>, public Node {
+ using Base = Operator<UnwindNode, 1>;
+
public:
- static NodePtr create(Context& ctx,
- ProjectionName projectionName,
- bool retainNonArrays,
- NodePtr child);
- static NodePtr clone(Context& ctx, const UnwindNode& other, NodePtr newChild);
+ explicit UnwindNode(Context& ctx,
+ ProjectionName projectionName,
+ bool retainNonArrays,
+ PolymorphicNode child);
void generateMemo(std::ostringstream& os) const;
-protected:
- void visit(AbstractVisitor& visitor) const override;
-
private:
- UnwindNode(Context& ctx, ProjectionName projectionName, bool retainNonArrays, NodePtr child);
-
const ProjectionName _projectionName;
const bool _retainNonArrays;
};
-class WindNode : public Node {
+class WindNode final : public Operator<WindNode, 1>, public Node {
+ using Base = Operator<WindNode, 1>;
+
public:
- static NodePtr create(Context& ctx, ProjectionName projectionName, NodePtr child);
- static NodePtr clone(Context& ctx, const WindNode& other, NodePtr newChild);
+ explicit WindNode(Context& ctx, ProjectionName projectionName, PolymorphicNode child);
void generateMemo(std::ostringstream& os) const;
-protected:
- void visit(AbstractVisitor& visitor) const override;
-
private:
- WindNode(Context& ctx, ProjectionName projectionName, NodePtr child);
-
const ProjectionName _projectionName;
};
-
} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/optimizer_test.cpp b/src/mongo/db/query/optimizer/optimizer_test.cpp
index 86966e05a7e..f1cffe77303 100644
--- a/src/mongo/db/query/optimizer/optimizer_test.cpp
+++ b/src/mongo/db/query/optimizer/optimizer_test.cpp
@@ -27,6 +27,7 @@
* it in the license file.
*/
+#include "mongo/db/query/optimizer/memo.h"
#include "mongo/db/query/optimizer/node.h"
#include "mongo/unittest/unittest.h"
@@ -35,15 +36,20 @@ namespace {
TEST(Optimizer, Basic) {
Context ctx;
+ MemoGenerator gen;
- NodePtr ptrScan = ScanNode::create(ctx, "test");
- Node::ChildVector v;
- v.push_back(std::move(ptrScan));
- NodePtr ptrJoin = MultiJoinNode::create(ctx, {}, {}, std::move(v));
- ASSERT_EQ("NodeId: 1\nMultiJoin\nNodeId: 0\nScan\n", ptrJoin->generateMemo());
+ PolymorphicNode scanNode = make<ScanNode>(ctx, "test");
+ ASSERT_EQ("NodeId: 0\nScan\n", gen.generateMemo(scanNode));
- NodePtr cloned = ptrJoin->clone(ctx);
- ASSERT_EQ("NodeId: 3\nMultiJoin\nNodeId: 2\nScan\n", cloned->generateMemo());
+ PolymorphicNode joinNode = make<MultiJoinNode>(ctx,
+ MultiJoinNode::FilterSet{},
+ MultiJoinNode::ProjectionMap{},
+ makeSeq(std::move(scanNode)));
+ ASSERT_EQ("NodeId: 1\nMultiJoin\nNodeId: 0\nScan\n", gen.generateMemo(joinNode));
+
+
+ PolymorphicNode cloned = joinNode;
+ ASSERT_EQ("NodeId: 1\nMultiJoin\nNodeId: 0\nScan\n", gen.generateMemo(cloned));
}
} // namespace