diff options
author | Svilen Mihaylov <svilen.mihaylov@mongodb.com> | 2020-09-11 12:20:10 -0400 |
---|---|---|
committer | Svilen Mihaylov <svilen.mihaylov@mongodb.com> | 2020-09-11 12:20:35 -0400 |
commit | 435225797704e51cae5fd7db4d02c2e86e784eea (patch) | |
tree | 39c4b33adea9b5781a2f52173b686e8fe57d3358 | |
parent | 6dee6f0a7b62faaad5a43a4564a76e4ab99b0012 (diff) | |
download | mongo-435225797704e51cae5fd7db4d02c2e86e784eea.tar.gz |
use PolyValue
-rw-r--r-- | src/mongo/db/query/optimizer/SConscript | 1 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/algebra/operator.h | 305 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/algebra/polyvalue.h | 381 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/memo.cpp | 43 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/memo.h (renamed from src/mongo/db/query/optimizer/visitor.h) | 24 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/node.cpp | 216 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/node.h | 151 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/optimizer_test.cpp | 20 |
8 files changed, 841 insertions, 300 deletions
diff --git a/src/mongo/db/query/optimizer/SConscript b/src/mongo/db/query/optimizer/SConscript index 175b109625d..0863192a593 100644 --- a/src/mongo/db/query/optimizer/SConscript +++ b/src/mongo/db/query/optimizer/SConscript @@ -8,6 +8,7 @@ env.Library( target="optimizer", source=[ "defs.cpp", + "memo.cpp", "node.cpp", ], LIBDEPS=[ diff --git a/src/mongo/db/query/optimizer/algebra/operator.h b/src/mongo/db/query/optimizer/algebra/operator.h new file mode 100644 index 00000000000..524b7246413 --- /dev/null +++ b/src/mongo/db/query/optimizer/algebra/operator.h @@ -0,0 +1,305 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <vector> + +#include "mongo/db/query/optimizer/algebra/polyvalue.h" + +namespace mongo::optimizer { +namespace algebra { + +template <typename T, int S> +struct OpNodeStorage { + T _nodes[S]; + + template <typename... Ts> + OpNodeStorage(Ts&&... vals) : _nodes{std::forward<Ts>(vals)...} {} +}; + +template <typename T> +struct OpNodeStorage<T, 0> {}; + +/*=====----- + * + * Arity of operator can be: + * 1. statically known - A, A, A, ... + * 2. dynamic prefix with optional statically know - vector<A>, A, A, A, ... + * + * Denotations map A to some B. + * So static arity <A,A,A> is mapped to <B,B,B>. + * Similarly, arity <vector<A>,A> is mapped to <vector<B>,B> + * + * There is a wrinkle when B is a reference (if allowed at all) + * Arity <vector<A>, A, A> is mapped to <vector<B>&, B&, B&> - note that the reference is lifted + * outside of the vector. + * + */ +template <typename Slot, typename Derived, int Arity> +class OpSpecificArity : public OpNodeStorage<Slot, Arity> { + using Base = OpNodeStorage<Slot, Arity>; + +public: + template <typename... Ts> + OpSpecificArity(Ts&&... vals) : Base({std::forward<Ts>(vals)...}) { + static_assert(sizeof...(Ts) == Arity, "constructor paramaters do not match"); + } + + template <int I, std::enable_if_t<(I >= 0 && I < Arity), int> = 0> + auto& get() noexcept { + return this->_nodes[I]; + } + + template <int I, std::enable_if_t<(I >= 0 && I < Arity), int> = 0> + const auto& get() const noexcept { + return this->_nodes[I]; + } +}; +/*=====----- + * + * Operator with dynamic arity + * + */ +template <typename Slot, typename Derived, int Arity> +class OpSpecificDynamicArity : public OpSpecificArity<Slot, Derived, Arity> { + using Base = OpSpecificArity<Slot, Derived, Arity>; + + std::vector<Slot> _dyNodes; + +public: + template <typename... Ts> + OpSpecificDynamicArity(std::vector<Slot> nodes, Ts&&... vals) + : Base({std::forward<Ts>(vals)...}), _dyNodes(std::move(nodes)) {} + + auto& nodes() { + return _dyNodes; + } + const auto& nodes() const { + return _dyNodes; + } +}; + +/*=====----- + * + * Semantic transport interface + * + */ +namespace detail { +template <typename D, typename T, typename = std::void_t<>> +struct has_prepare : std::false_type {}; +template <typename D, typename T> +struct has_prepare<D, T, std::void_t<decltype(std::declval<D>().prepare(std::declval<T&>()))>> + : std::true_type {}; + +template <typename D, typename T> +inline constexpr auto has_prepare_v = has_prepare<D, T>::value; + +template <typename Slot, typename Derived, int Arity> +inline constexpr int get_arity(const OpSpecificArity<Slot, Derived, Arity>*) { + return Arity; +} + +template <typename Slot, typename Derived, int Arity> +inline constexpr bool is_dynamic(const OpSpecificArity<Slot, Derived, Arity>*) { + return false; +} + +template <typename Slot, typename Derived, int Arity> +inline constexpr bool is_dynamic(const OpSpecificDynamicArity<Slot, Derived, Arity>*) { + return true; +} + +template <typename T> +using OpConcreteType = typename std::remove_reference_t<T>::template get_t<0>; +} // namespace detail + +template <typename D, bool withSlot> +class OpTransporter { + D& _domain; + + template <typename T, bool B> + struct Deducer {}; + template <typename T> + struct Deducer<T, true> { + using type = decltype(std::declval<D>().transport( + std::declval<T>(), std::declval<detail::OpConcreteType<T>&>())); + }; + template <typename T> + struct Deducer<T, false> { + using type = + decltype(std::declval<D>().transport(std::declval<detail::OpConcreteType<T>&>())); + }; + template <typename T> + using deduced_t = typename Deducer<T, withSlot>::type; + + template <typename N, typename T, typename... Ts> + auto transformStep(N&& slot, T&& op, Ts&&... args) { + if constexpr (withSlot) { + return _domain.transport( + std::forward<N>(slot), std::forward<T>(op), std::forward<Ts>(args)...); + } else { + return _domain.transport(std::forward<T>(op), std::forward<Ts>(args)...); + } + } + + template <typename N, typename T, size_t... I> + auto transportUnpack(N&& slot, T&& op, std::index_sequence<I...>) { + return transformStep( + std::forward<N>(slot), std::forward<T>(op), op.template get<I>().visit(*this)...); + } + template <typename N, typename T, size_t... I> + auto transportDynamicUnpack(N&& slot, T&& op, std::index_sequence<I...>) { + std::vector<decltype(slot.visit(*this))> v; + for (auto& node : op.nodes()) { + v.emplace_back(node.visit(*this)); + } + return transformStep(std::forward<N>(slot), + std::forward<T>(op), + std::move(v), + op.template get<I>().visit(*this)...); + } + template <typename N, typename T, size_t... I> + void transportUnpackVoid(N&& slot, T&& op, std::index_sequence<I...>) { + (op.template get<I>().visit(*this), ...); + return transformStep(std::forward<N>(slot), std::forward<T>(op), op.template get<I>()...); + } + template <typename N, typename T, size_t... I> + void transportDynamicUnpackVoid(N&& slot, T&& op, std::index_sequence<I...>) { + for (auto& node : op.nodes()) { + node.visit(*this); + } + (op.template get<I>().visit(*this), ...); + return transformStep( + std::forward<N>(slot), std::forward<T>(op), op.nodes(), op.template get<I>()...); + } + +public: + OpTransporter(D& domain) : _domain(domain) {} + + template <typename N, typename T, typename R = deduced_t<N>> + R operator()(N&& slot, T&& op) { + // N is either `PolyValue<Ts...>&` or `const PolyValue<Ts...>&` i.e. reference + // T is either `A&` or `const A&` where A is one of Ts + using type = std::remove_reference_t<T>; + + constexpr int arity = detail::get_arity(static_cast<type*>(nullptr)); + constexpr bool is_dynamic = detail::is_dynamic(static_cast<type*>(nullptr)); + + if constexpr (detail::has_prepare_v<D, type>) { + _domain.prepare(std::forward<T>(op)); + } + if constexpr (is_dynamic) { + if constexpr (std::is_same_v<R, void>) { + return transportDynamicUnpackVoid( + std::forward<N>(slot), std::forward<T>(op), std::make_index_sequence<arity>{}); + } else { + return transportDynamicUnpack( + std::forward<N>(slot), std::forward<T>(op), std::make_index_sequence<arity>{}); + } + } else { + if constexpr (std::is_same_v<R, void>) { + return transportUnpackVoid( + std::forward<N>(slot), std::forward<T>(op), std::make_index_sequence<arity>{}); + } else { + return transportUnpack( + std::forward<N>(slot), std::forward<T>(op), std::make_index_sequence<arity>{}); + } + } + } +}; + +template <typename D, bool withSlot> +class OpWalker { + D& _domain; + + template <typename N, typename T, typename... Ts> + auto walkStep(N&& slot, T&& op, Ts&&... args) { + if constexpr (withSlot) { + return _domain.walk( + std::forward<N>(slot), std::forward<T>(op), std::forward<Ts>(args)...); + } else { + return _domain.walk(std::forward<T>(op), std::forward<Ts>(args)...); + } + } + + template <typename N, typename T, typename... Args, size_t... I> + auto walkUnpack(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) { + return walkStep(std::forward<N>(slot), + std::forward<T>(op), + std::forward<Args>(args)..., + op.template get<I>()...); + } + template <typename N, typename T, typename... Args, size_t... I> + auto walkDynamicUnpack(N&& slot, T&& op, std::index_sequence<I...>, Args&&... args) { + return walkStep(std::forward<N>(slot), + std::forward<T>(op), + std::forward<Args>(args)..., + op.nodes(), + op.template get<I>()...); + } + +public: + OpWalker(D& domain) : _domain(domain) {} + + template <typename N, typename T, typename... Args> + auto operator()(N&& slot, T&& op, Args&&... args) { + // N is either `PolyValue<Ts...>&` or `const PolyValue<Ts...>&` i.e. reference + // T is either `A&` or `const A&` where A is one of Ts + using type = std::remove_reference_t<T>; + + constexpr int arity = detail::get_arity(static_cast<type*>(nullptr)); + constexpr bool is_dynamic = detail::is_dynamic(static_cast<type*>(nullptr)); + + if constexpr (is_dynamic) { + return walkDynamicUnpack(std::forward<N>(slot), + std::forward<T>(op), + std::make_index_sequence<arity>{}, + std::forward<Args>(args)...); + } else { + return walkUnpack(std::forward<N>(slot), + std::forward<T>(op), + std::make_index_sequence<arity>{}, + std::forward<Args>(args)...); + } + } +}; + +template <bool withSlot = false, typename D, typename N> +auto transport(N&& node, D& domain) { + return node.visit(OpTransporter<D, withSlot>{domain}); +} + +template <bool withSlot = false, typename D, typename N, typename... Args> +auto walk(N&& node, D& domain, Args&&... args) { + return node.visit(OpWalker<D, withSlot>{domain}, std::forward<Args>(args)...); +} + +} // namespace algebra +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/algebra/polyvalue.h b/src/mongo/db/query/optimizer/algebra/polyvalue.h new file mode 100644 index 00000000000..374041c5704 --- /dev/null +++ b/src/mongo/db/query/optimizer/algebra/polyvalue.h @@ -0,0 +1,381 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <array> +#include <stdexcept> +#include <type_traits> + +namespace mongo::optimizer { +namespace algebra { +namespace detail { + +template <typename T, typename... Args> +inline constexpr bool is_one_of_v = std::disjunction_v<std::is_same<T, Args>...>; + +template <typename T, typename... Args> +inline constexpr bool is_one_of_f() { + return is_one_of_v<T, Args...>; +} + +template <typename... Args> +struct is_unique_t : std::true_type {}; + +template <typename H, typename... T> +struct is_unique_t<H, T...> + : std::bool_constant<!is_one_of_f<H, T...>() && is_unique_t<T...>::value> {}; + +template <typename... Args> +inline constexpr bool is_unique_v = is_unique_t<Args...>::value; + +// Given the type T find its index in Ts +template <typename T, typename... Ts> +static inline constexpr int find_index() { + static_assert(detail::is_unique_v<Ts...>, "Types must be unique"); + constexpr bool matchVector[] = {std::is_same<T, Ts>::value...}; + + for (int index = 0; index < static_cast<int>(sizeof...(Ts)); ++index) { + if (matchVector[index]) { + return index; + } + } + + return -1; +} + +template <int N, typename T, typename... Ts> +struct get_type_by_index_impl { + using type = typename get_type_by_index_impl<N - 1, Ts...>::type; +}; +template <typename T, typename... Ts> +struct get_type_by_index_impl<0, T, Ts...> { + using type = T; +}; + +// Given the index I return the type from Ts +template <int I, typename... Ts> +using get_type_by_index = typename get_type_by_index_impl<I, Ts...>::type; + +} // namespace detail + +/*=====----- + * + * The overload trick to construct visitors from lambdas. + * + */ +template <class... Ts> +struct overload : Ts... { + using Ts::operator()...; +}; +template <class... Ts> +overload(Ts...)->overload<Ts...>; + +/*=====----- + * + * Forward declarations + * + */ +template <typename... Ts> +class PolyValue; + +template <typename T, typename... Ts> +class ControlBlockVTable; + +/*=====----- + * + * The base control block that PolyValue holds. + * + * It does not contain anything else by the runtime tag. + * + */ +template <typename... Ts> +class ControlBlock { + const int _tag; + +protected: + ControlBlock(int tag) noexcept : _tag(tag) {} + +public: + auto getRuntimeTag() const noexcept { + return _tag; + } +}; + +/*=====----- + * + * The concrete control block VTable generator. + * + * It must be empty ad PolyValue derives from the generators + * and we want EBO to kick in. + * + */ +template <typename T, typename... Ts> +class ControlBlockVTable { + static constexpr int _staticTag = detail::find_index<T, Ts...>(); + static_assert(_staticTag != -1, "Type must be on the list"); + + using AbstractType = ControlBlock<Ts...>; + using PolyValueType = PolyValue<Ts...>; + + /*=====----- + * + * The concrete control block for every type T of Ts. + * + * It derives from the ControlBlock. All methods are private and only + * the friend class ControlBlockVTable can call them. + * + */ + class ConcreteType : public AbstractType { + T _t; + + public: + template <typename... Args> + ConcreteType(Args&&... args) : AbstractType(_staticTag), _t(std::forward<Args>(args)...) {} + + const T* getPtr() const { + return &_t; + } + + T* getPtr() { + return &_t; + } + }; + + static constexpr auto concrete(AbstractType* block) noexcept { + return static_cast<ConcreteType*>(block); + } + + static constexpr auto concrete(const AbstractType* block) noexcept { + return static_cast<const ConcreteType*>(block); + } + +public: + template <typename... Args> + static AbstractType* make(Args&&... args) { + return new ConcreteType(std::forward<Args>(args)...); + } + + static AbstractType* clone(const AbstractType* block) { + return new ConcreteType(*concrete(block)); + } + + static void destroy(AbstractType* block) noexcept { + delete concrete(block); + } + + static bool compareEq(AbstractType* blockLhs, AbstractType* blockRhs) noexcept { + if (blockLhs->getRuntimeTag() == blockRhs->getRuntimeTag()) { + return *castConst<T>(blockLhs) == *castConst<T>(blockRhs); + } + return false; + } + + template <typename U> + static constexpr bool is_v = std::is_base_of_v<U, T>; + + template <typename U> + static U* cast(AbstractType* block) { + if constexpr (is_v<U>) { + return static_cast<U*>(concrete(block)->getPtr()); + } else { + // gcc bug 81676 + (void)block; + return nullptr; + } + } + + template <typename U> + static const U* castConst(const AbstractType* block) { + if constexpr (is_v<U>) { + return static_cast<const U*>(concrete(block)->getPtr()); + } else { + // gcc bug 81676 + (void)block; + return nullptr; + } + } + + template <typename V, typename... Args> + static auto visit(V&& v, PolyValueType& holder, AbstractType* block, Args&&... args) { + return v(holder, *cast<T>(block), std::forward<Args>(args)...); + } + + template <typename V, typename... Args> + static auto visitConst(V&& v, + const PolyValueType& holder, + const AbstractType* block, + Args&&... args) { + return v(holder, *castConst<T>(block), std::forward<Args>(args)...); + } +}; + +/*=====----- + * + * This is a variation on variant and polymorphic value theme. + * + * A tag based dispatch + * + * Supported operations: + * - construction + * - destruction + * - clone a = b; + * - cast a.cast<T>() + * - multi-method cast to common base a.cast<B>() + * - multi-method visit + */ +template <typename... Ts> +class PolyValue : private ControlBlockVTable<Ts, Ts...>... { + static_assert(detail::is_unique_v<Ts...>, "Types must be unique"); + static_assert(std::conjunction_v<std::is_empty<ControlBlockVTable<Ts, Ts...>>...>, + "VTable base classes must be empty"); + + ControlBlock<Ts...>* _object{nullptr}; + + PolyValue(ControlBlock<Ts...>* object) noexcept : _object(object) {} + + auto tag() const noexcept { + return _object->getRuntimeTag(); + } + + void check() const { + if (!_object) { + throw std::logic_error("PolyValue is empty"); + } + } + + static void destroy(ControlBlock<Ts...>* object) { + static constexpr std::array destroyTbl = {&ControlBlockVTable<Ts, Ts...>::destroy...}; + + destroyTbl[object->getRuntimeTag()](object); + } + +public: + PolyValue() = delete; + + PolyValue(const PolyValue& other) { + static constexpr std::array cloneTbl = {&ControlBlockVTable<Ts, Ts...>::clone...}; + if (other._object) { + _object = cloneTbl[other.tag()](other._object); + } + } + + PolyValue(PolyValue&& other) noexcept { + swap(other); + } + + ~PolyValue() noexcept { + if (_object) { + destroy(_object); + } + } + + PolyValue& operator=(PolyValue other) noexcept { + swap(other); + return *this; + } + + template <typename T, typename... Args> + static PolyValue make(Args&&... args) { + return PolyValue{ControlBlockVTable<T, Ts...>::make(std::forward<Args>(args)...)}; + } + + template <int I> + using get_t = detail::get_type_by_index<I, Ts...>; + + template <typename V, typename... Args> + auto visit(V&& v, Args&&... args) { + // unfortunately gcc rejects much nicer code, clang and msvc accept + // static constexpr std::array visitTbl = { &ControlBlockVTable<Ts, Ts...>::template + // visit<V>... }; + + using FunPtrType = + decltype(&ControlBlockVTable<get_t<0>, Ts...>::template visit<V, Args...>); + static constexpr FunPtrType visitTbl[] = { + &ControlBlockVTable<Ts, Ts...>::template visit<V, Args...>...}; + + check(); + return visitTbl[tag()](std::forward<V>(v), *this, _object, std::forward<Args>(args)...); + } + + template <typename V, typename... Args> + auto visit(V&& v, Args&&... args) const { + // unfortunately gcc rejects much nicer code, clang and msvc accept + // static constexpr std::array visitTbl = { &ControlBlockVTable<Ts, Ts...>::template + // visitConst<V>... }; + + using FunPtrType = + decltype(&ControlBlockVTable<get_t<0>, Ts...>::template visitConst<V, Args...>); + static constexpr FunPtrType visitTbl[] = { + &ControlBlockVTable<Ts, Ts...>::template visitConst<V, Args...>...}; + + check(); + return visitTbl[tag()](std::forward<V>(v), *this, _object, std::forward<Args>(args)...); + } + + template <typename T> + T* cast() { + check(); + static constexpr std::array castTbl = {&ControlBlockVTable<Ts, Ts...>::template cast<T>...}; + return castTbl[tag()](_object); + } + + template <typename T> + const T* cast() const { + static constexpr std::array castTbl = { + &ControlBlockVTable<Ts, Ts...>::template castConst<T>...}; + + check(); + return castTbl[tag()](_object); + } + + template <typename T> + bool is() const { + static constexpr std::array isTbl = {ControlBlockVTable<Ts, Ts...>::template is_v<T>...}; + + check(); + return isTbl[tag()]; + } + + bool empty() const { + return !_object; + } + + void swap(PolyValue& other) noexcept { + std::swap(other._object, _object); + } + + bool operator==(const PolyValue& rhs) const noexcept { + static constexpr std::array cmp = {ControlBlockVTable<Ts, Ts...>::compareEq...}; + return cmp[tag()](_object, rhs._object); + } +}; + +} // namespace algebra +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/memo.cpp b/src/mongo/db/query/optimizer/memo.cpp new file mode 100644 index 00000000000..c4dadbb3d5a --- /dev/null +++ b/src/mongo/db/query/optimizer/memo.cpp @@ -0,0 +1,43 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/query/optimizer/algebra/operator.h" +#include "mongo/db/query/optimizer/memo.h" +#include "mongo/db/query/optimizer/node.h" + +namespace mongo::optimizer { + +std::string MemoGenerator::generateMemo(const PolymorphicNode& e) { + _os.str(""); + _os.clear(); + algebra::transport<false>(e, *this); + return _os.str(); +} + +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/visitor.h b/src/mongo/db/query/optimizer/memo.h index 1aa0a886fab..ad3703f8fd8 100644 --- a/src/mongo/db/query/optimizer/visitor.h +++ b/src/mongo/db/query/optimizer/memo.h @@ -31,16 +31,24 @@ #include <string> +#include "mongo/db/query/optimizer/node.h" + namespace mongo::optimizer { -class AbstractVisitor { +class MemoGenerator { public: - virtual void visit(const ScanNode& node) = 0; - virtual void visit(const MultiJoinNode& node) = 0; - virtual void visit(const UnionNode& node) = 0; - virtual void visit(const GroupByNode& node) = 0; - virtual void visit(const UnwindNode& node) = 0; - virtual void visit(const WindNode& node) = 0; + template <typename T, typename... Ts> + void transport(const T&, Ts&&...) {} + + template <typename T> + void prepare(const T& n) { + n.generateMemo(_os); + } + + std::string generateMemo(const PolymorphicNode& e); + +private: + std::ostringstream _os; }; -} // namespace mongo::optimizer +} // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/node.cpp b/src/mongo/db/query/optimizer/node.cpp index 4836dcce39e..a1455efd60f 100644 --- a/src/mongo/db/query/optimizer/node.cpp +++ b/src/mongo/db/query/optimizer/node.cpp @@ -30,130 +30,19 @@ #include <functional> #include <stack> +#include "mongo/db/query/optimizer/memo.h" #include "mongo/db/query/optimizer/node.h" -#include "mongo/db/query/optimizer/visitor.h" -#include "mongo/util/assert_util.h" namespace mongo::optimizer { -Node::Node(Context& ctx) : _nodeId(ctx.getNextNodeId()), _children() {} - -Node::Node(Context& ctx, NodePtr child) : _nodeId(ctx.getNextNodeId()) { - _children.push_back(std::move(child)); -} - -Node::Node(Context& ctx, ChildVector children) - : _nodeId(ctx.getNextNodeId()), _children(std::move(children)) {} +Node::Node(Context& ctx) : _nodeId(ctx.getNextNodeId()) {} void Node::generateMemoBase(std::ostringstream& os) const { os << "NodeId: " << _nodeId << "\n"; } -void Node::visitPreOrder(AbstractVisitor& visitor) const { - visit(visitor); - for (const NodePtr& ptr : _children) { - ptr->visitPreOrder(visitor); - } -} - -void Node::visitPostOrder(AbstractVisitor& visitor) const { - for (const NodePtr& ptr : _children) { - ptr->visitPostOrder(visitor); - } - visit(visitor); -} - -std::string Node::generateMemo() const { - class MemoVisitor : public AbstractVisitor { - protected: - void visit(const ScanNode& node) override { - node.generateMemo(_os); - } - void visit(const MultiJoinNode& node) override { - node.generateMemo(_os); - } - void visit(const UnionNode& node) override { - node.generateMemo(_os); - } - void visit(const GroupByNode& node) override { - node.generateMemo(_os); - } - void visit(const UnwindNode& node) override { - node.generateMemo(_os); - } - void visit(const WindNode& node) override { - node.generateMemo(_os); - } - - public: - std::ostringstream _os; - }; - - MemoVisitor visitor; - visitPreOrder(visitor); - return visitor._os.str(); -} - -NodePtr Node::clone(Context& ctx) const { - class CloneVisitor : public AbstractVisitor { - public: - explicit CloneVisitor(Context& ctx) : _ctx(ctx), _childStack() {} - - protected: - void visit(const ScanNode& node) override { - doClone(node, [&](ChildVector v){ return ScanNode::clone(_ctx, node); }); - } - void visit(const MultiJoinNode& node) override { - doClone(node, [&](ChildVector v){ return MultiJoinNode::clone(_ctx, node, std::move(v)); }); - } - void visit(const UnionNode& node) override { - doClone(node, [&](ChildVector v){ return UnionNode::clone(_ctx, node, std::move(v)); }); - } - void visit(const GroupByNode& node) override { - doClone(node, [&](ChildVector v){ return GroupByNode::clone(_ctx, node, std::move(v.at(0))); }); - } - void visit(const UnwindNode& node) override { - doClone(node, [&](ChildVector v){ return UnwindNode::clone(_ctx, node, std::move(v.at(0))); }); - } - void visit(const WindNode& node) override { - doClone(node, [&](ChildVector v){ return WindNode::clone(_ctx, node, std::move(v.at(0))); }); - } - - private: - void doClone(const Node& node, const std::function<NodePtr(ChildVector newChildren)>& cloneFn) { - ChildVector newChildren; - for (int i = 0; i < node.getChildCount(); i++) { - newChildren.push_back(std::move(_childStack.top())); - _childStack.pop(); - } - _childStack.push(cloneFn(std::move(newChildren))); - } - - public: - Context& _ctx; - std::stack<NodePtr> _childStack; - }; - - CloneVisitor visitor(ctx); - visitPostOrder(visitor); - invariant(visitor._childStack.size() == 1); - return std::move(visitor._childStack.top()); -} - -int Node::getChildCount() const { - return _children.size(); -} - -NodePtr ScanNode::create(Context& ctx, CollectionNameType collectionName) { - return NodePtr(new ScanNode(ctx, std::move(collectionName))); -} - -NodePtr ScanNode::clone(Context& ctx, const ScanNode& other) { - return create(ctx, other._collectionName); -} - ScanNode::ScanNode(Context& ctx, CollectionNameType collectionName) - : Node(ctx), _collectionName(std::move(collectionName)) {} + : Base(), Node(ctx), _collectionName(std::move(collectionName)) {} void ScanNode::generateMemo(std::ostringstream& os) const { Node::generateMemoBase(os); @@ -161,27 +50,12 @@ void ScanNode::generateMemo(std::ostringstream& os) const { << "\n"; } -void ScanNode::visit(AbstractVisitor& visitor) const { - visitor.visit(*this); -} - -NodePtr MultiJoinNode::create(Context& ctx, - FilterSet filterSet, - ProjectionMap projectionMap, - ChildVector children) { - return NodePtr(new MultiJoinNode( - ctx, std::move(filterSet), std::move(projectionMap), std::move(children))); -} - -NodePtr MultiJoinNode::clone(Context& ctx, const MultiJoinNode& other, ChildVector newChildren) { - return create(ctx, other._filterSet, other._projectionMap, std::move(newChildren)); -} - MultiJoinNode::MultiJoinNode(Context& ctx, FilterSet filterSet, ProjectionMap projectionMap, - ChildVector children) - : Node(ctx, std::move(children)), + PolymorphicNodeVector children) + : Base(std::move(children)), + Node(ctx), _filterSet(std::move(filterSet)), _projectionMap(std::move(projectionMap)) {} @@ -191,20 +65,8 @@ void MultiJoinNode::generateMemo(std::ostringstream& os) const { << "\n"; } -void MultiJoinNode::visit(AbstractVisitor& visitor) const { - visitor.visit(*this); -} - -NodePtr UnionNode::create(Context& ctx, ChildVector children) { - return NodePtr(new UnionNode(ctx, std::move(children))); -} - -NodePtr UnionNode::clone(Context& ctx, const UnionNode& other, ChildVector newChildren) { - return create(ctx, std::move(newChildren)); -} - -UnionNode::UnionNode(Context& ctx, ChildVector children) - : Node(ctx, std::move(children)) {} +UnionNode::UnionNode(Context& ctx, PolymorphicNodeVector children) + : Base(std::move(children)), Node(ctx) {} void UnionNode::generateMemo(std::ostringstream& os) const { Node::generateMemoBase(os); @@ -212,27 +74,12 @@ void UnionNode::generateMemo(std::ostringstream& os) const { << "\n"; } -void UnionNode::visit(AbstractVisitor& visitor) const { - visitor.visit(*this); -} - -NodePtr GroupByNode::create(Context& ctx, - GroupByNode::GroupByVector groupByVector, - GroupByNode::ProjectionMap projectionMap, - NodePtr child) { - return NodePtr( - new GroupByNode(ctx, std::move(groupByVector), std::move(projectionMap), std::move(child))); -} - -NodePtr GroupByNode::clone(Context& ctx, const GroupByNode& other, NodePtr newChild) { - return create(ctx, other._groupByVector, other._projectionMap, std::move(newChild)); -} - GroupByNode::GroupByNode(Context& ctx, GroupByNode::GroupByVector groupByVector, GroupByNode::ProjectionMap projectionMap, - NodePtr child) - : Node(ctx, std::move(child)), + PolymorphicNode child) + : Base(std::move(child)), + Node(ctx), _groupByVector(std::move(groupByVector)), _projectionMap(std::move(projectionMap)) {} @@ -242,27 +89,12 @@ void GroupByNode::generateMemo(std::ostringstream& os) const { << "\n"; } -void GroupByNode::visit(AbstractVisitor& visitor) const { - visitor.visit(*this); -} - -NodePtr UnwindNode::create(Context& ctx, - ProjectionName projectionName, - const bool retainNonArrays, - NodePtr child) { - return NodePtr( - new UnwindNode(ctx, std::move(projectionName), retainNonArrays, std::move(child))); -} - -NodePtr UnwindNode::clone(Context& ctx, const UnwindNode& other, NodePtr newChild) { - return create(ctx, other._projectionName, other._retainNonArrays, std::move(newChild)); -} - UnwindNode::UnwindNode(Context& ctx, ProjectionName projectionName, const bool retainNonArrays, - NodePtr child) - : Node(ctx, std::move(child)), + PolymorphicNode child) + : Base(std::move(child)), + Node(ctx), _projectionName(std::move(projectionName)), _retainNonArrays(retainNonArrays) {} @@ -272,20 +104,8 @@ void UnwindNode::generateMemo(std::ostringstream& os) const { << "\n"; } -void UnwindNode::visit(AbstractVisitor& visitor) const { - visitor.visit(*this); -} - -NodePtr WindNode::create(Context& ctx, ProjectionName projectionName, NodePtr child) { - return NodePtr(new WindNode(ctx, std::move(projectionName), std::move(child))); -} - -NodePtr WindNode::clone(Context& ctx, const WindNode& other, NodePtr newChild) { - return create(ctx, other._projectionName, std::move(newChild)); -} - -WindNode::WindNode(Context& ctx, ProjectionName projectionName, NodePtr child) - : Node(ctx, std::move(child)), _projectionName(std::move(projectionName)) {} +WindNode::WindNode(Context& ctx, ProjectionName projectionName, PolymorphicNode child) + : Base(std::move(child)), Node(ctx), _projectionName(std::move(projectionName)) {} void WindNode::generateMemo(std::ostringstream& os) const { Node::generateMemoBase(os); @@ -293,8 +113,4 @@ void WindNode::generateMemo(std::ostringstream& os) const { << "\n"; } -void WindNode::visit(AbstractVisitor& visitor) const { - visitor.visit(*this); -} - } // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/node.h b/src/mongo/db/query/optimizer/node.h index 78010d7d333..de56ade7b56 100644 --- a/src/mongo/db/query/optimizer/node.h +++ b/src/mongo/db/query/optimizer/node.h @@ -37,6 +37,7 @@ #include <utility> #include <vector> +#include "mongo/db/query/optimizer/algebra/operator.h" #include "mongo/db/query/optimizer/defs.h" #include "mongo/db/query/optimizer/filter.h" #include "mongo/db/query/optimizer/projection.h" @@ -45,157 +46,137 @@ namespace mongo::optimizer { -class Node; -using NodePtr = std::unique_ptr<Node>; -class AbstractVisitor; +class ScanNode; +class MultiJoinNode; +class UnionNode; +class GroupByNode; +class UnwindNode; +class WindNode; -class Node { -public: - using ChildVector = std::vector<NodePtr>; +using PolymorphicNode = + algebra::PolyValue<ScanNode, MultiJoinNode, UnionNode, GroupByNode, UnwindNode, WindNode>; + +template <typename Derived, size_t Arity> +using Operator = algebra::OpSpecificArity<PolymorphicNode, Derived, Arity>; + +template <typename Derived, size_t Arity> +using OperatorDynamic = algebra::OpSpecificDynamicArity<PolymorphicNode, Derived, Arity>; + +template <typename Derived> +using OperatorDynamicHomogenous = OperatorDynamic<Derived, 0>; + +using PolymorphicNodeVector = std::vector<PolymorphicNode>; + +template <typename T, typename... Args> +inline auto make(Args&&... args) { + return PolymorphicNode::make<T>(std::forward<Args>(args)...); +} +template <typename... Args> +inline auto makeSeq(Args&&... args) { + PolymorphicNodeVector seq; + (seq.emplace_back(std::forward<Args>(args)), ...); + return seq; +} + +class Node { protected: explicit Node(Context& ctx); - explicit Node(Context& ctx, NodePtr child); - explicit Node(Context& ctx, ChildVector children); void generateMemoBase(std::ostringstream& os) const; - virtual void visit(AbstractVisitor& visitor) const = 0; - void visitPreOrder(AbstractVisitor& visitor) const; - void visitPostOrder(AbstractVisitor& visitor) const; - - // clone public: Node() = delete; - std::string generateMemo() const; - - NodePtr clone(Context& ctx) const; - - int getChildCount() const; - private: const NodeIdType _nodeId; - ChildVector _children; }; -class ScanNode : public Node { +class ScanNode final : public Operator<ScanNode, 0>, public Node { + using Base = Operator<ScanNode, 0>; + public: - static NodePtr create(Context& ctx, CollectionNameType collectionName); - static NodePtr clone(Context& ctx, const ScanNode& other); + explicit ScanNode(Context& ctx, CollectionNameType collectionName); void generateMemo(std::ostringstream& os) const; -protected: - void visit(AbstractVisitor& visitor) const override; - private: - explicit ScanNode(Context& ctx, CollectionNameType collectionName); - const CollectionNameType _collectionName; }; -class MultiJoinNode : public Node { +class MultiJoinNode final : public OperatorDynamicHomogenous<MultiJoinNode>, public Node { + using Base = OperatorDynamicHomogenous<MultiJoinNode>; + public: using FilterSet = std::unordered_set<FilterType>; using ProjectionMap = std::unordered_map<ProjectionName, ProjectionType>; - static NodePtr create(Context& ctx, - FilterSet filterSet, - ProjectionMap projectionMap, - ChildVector children); - static NodePtr clone(Context& ctx, const MultiJoinNode& other, ChildVector newChildren); - - void generateMemo(std::ostringstream& os) const; - -protected: - void visit(AbstractVisitor& visitor) const override; - -private: explicit MultiJoinNode(Context& ctx, FilterSet filterSet, ProjectionMap projectionMap, - ChildVector children); + PolymorphicNodeVector children); + void generateMemo(std::ostringstream& os) const; + +private: FilterSet _filterSet; ProjectionMap _projectionMap; }; -class UnionNode : public Node { +class UnionNode final : public OperatorDynamicHomogenous<UnionNode>, public Node { + using Base = OperatorDynamicHomogenous<UnionNode>; + public: - static NodePtr create(Context& ctx, ChildVector children); - static NodePtr clone(Context& ctx, const UnionNode& other, ChildVector newChildren); + explicit UnionNode(Context& ctx, PolymorphicNodeVector children); void generateMemo(std::ostringstream& os) const; - -protected: - void visit(AbstractVisitor& visitor) const override; - -private: - explicit UnionNode(Context& ctx, ChildVector children); }; -class GroupByNode : public Node { +class GroupByNode : public Operator<GroupByNode, 1>, public Node { + using Base = Operator<GroupByNode, 1>; + public: using GroupByVector = std::vector<ProjectionName>; using ProjectionMap = std::unordered_map<ProjectionName, ProjectionType>; - static NodePtr create(Context& ctx, - GroupByVector groupByVector, - ProjectionMap projectionMap, - NodePtr child); - static NodePtr clone(Context& ctx, const GroupByNode& other, NodePtr newChild); - - void generateMemo(std::ostringstream& os) const; - -protected: - void visit(AbstractVisitor& visitor) const override; - -private: explicit GroupByNode(Context& ctx, GroupByVector groupByVector, ProjectionMap projectionMap, - NodePtr child); + PolymorphicNode child); + + void generateMemo(std::ostringstream& os) const; +private: GroupByVector _groupByVector; ProjectionMap _projectionMap; }; -class UnwindNode : public Node { +class UnwindNode final : public Operator<UnwindNode, 1>, public Node { + using Base = Operator<UnwindNode, 1>; + public: - static NodePtr create(Context& ctx, - ProjectionName projectionName, - bool retainNonArrays, - NodePtr child); - static NodePtr clone(Context& ctx, const UnwindNode& other, NodePtr newChild); + explicit UnwindNode(Context& ctx, + ProjectionName projectionName, + bool retainNonArrays, + PolymorphicNode child); void generateMemo(std::ostringstream& os) const; -protected: - void visit(AbstractVisitor& visitor) const override; - private: - UnwindNode(Context& ctx, ProjectionName projectionName, bool retainNonArrays, NodePtr child); - const ProjectionName _projectionName; const bool _retainNonArrays; }; -class WindNode : public Node { +class WindNode final : public Operator<WindNode, 1>, public Node { + using Base = Operator<WindNode, 1>; + public: - static NodePtr create(Context& ctx, ProjectionName projectionName, NodePtr child); - static NodePtr clone(Context& ctx, const WindNode& other, NodePtr newChild); + explicit WindNode(Context& ctx, ProjectionName projectionName, PolymorphicNode child); void generateMemo(std::ostringstream& os) const; -protected: - void visit(AbstractVisitor& visitor) const override; - private: - WindNode(Context& ctx, ProjectionName projectionName, NodePtr child); - const ProjectionName _projectionName; }; - } // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/optimizer_test.cpp b/src/mongo/db/query/optimizer/optimizer_test.cpp index 86966e05a7e..f1cffe77303 100644 --- a/src/mongo/db/query/optimizer/optimizer_test.cpp +++ b/src/mongo/db/query/optimizer/optimizer_test.cpp @@ -27,6 +27,7 @@ * it in the license file. */ +#include "mongo/db/query/optimizer/memo.h" #include "mongo/db/query/optimizer/node.h" #include "mongo/unittest/unittest.h" @@ -35,15 +36,20 @@ namespace { TEST(Optimizer, Basic) { Context ctx; + MemoGenerator gen; - NodePtr ptrScan = ScanNode::create(ctx, "test"); - Node::ChildVector v; - v.push_back(std::move(ptrScan)); - NodePtr ptrJoin = MultiJoinNode::create(ctx, {}, {}, std::move(v)); - ASSERT_EQ("NodeId: 1\nMultiJoin\nNodeId: 0\nScan\n", ptrJoin->generateMemo()); + PolymorphicNode scanNode = make<ScanNode>(ctx, "test"); + ASSERT_EQ("NodeId: 0\nScan\n", gen.generateMemo(scanNode)); - NodePtr cloned = ptrJoin->clone(ctx); - ASSERT_EQ("NodeId: 3\nMultiJoin\nNodeId: 2\nScan\n", cloned->generateMemo()); + PolymorphicNode joinNode = make<MultiJoinNode>(ctx, + MultiJoinNode::FilterSet{}, + MultiJoinNode::ProjectionMap{}, + makeSeq(std::move(scanNode))); + ASSERT_EQ("NodeId: 1\nMultiJoin\nNodeId: 0\nScan\n", gen.generateMemo(joinNode)); + + + PolymorphicNode cloned = joinNode; + ASSERT_EQ("NodeId: 1\nMultiJoin\nNodeId: 0\nScan\n", gen.generateMemo(cloned)); } } // namespace |